├── .gitignore ├── 1_training_loop_in_chainer.ipynb ├── 2_how_to_use_trainer.ipynb ├── 3_write_convnet_in_chainer.ipynb ├── 4_RNN-language-model.ipynb ├── 5_word2vec.ipynb ├── 6_dqn_cartpole.ipynb ├── 7_multiple_gpus.ipynb ├── 8_chainer-for-theano-users.ipynb ├── 9_vanilla-LSTM-with-cupy.ipynb ├── LICENSE ├── README.md ├── cbow.png ├── center_context_word.png ├── gentxt.py ├── input.txt ├── rnnlm.png ├── rnnlm_example.png ├── shakespear.txt ├── skipgram.png ├── skipgram_detail.png ├── train_ptb.py └── trainer.png /.gitignore: -------------------------------------------------------------------------------- 1 | agent/ 2 | 3 | .DS_Store 4 | 5 | *.model 6 | 7 | .ipynb_checkpoints/ 8 | 9 | mnist_result 10 | result 11 | ptb_result 12 | word2vec_result 13 | multi_gpu_result 14 | multi_gpu_result_2 15 | 16 | 5.png 17 | 7.png 18 | 19 | __pycache__/ 20 | -------------------------------------------------------------------------------- /1_training_loop_in_chainer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# How to write a training loop in Chainer\n", 8 | "\n", 9 | "In this tutorial section, we will learn how to train a deep neural network to classify images of hand-written digits in the popular MNIST dataset. This dataset contains 50,000 training examples and 10,000 test examples. Each example is a set of a 28 x 28 greyscale image and a corresponding class label. Since the digits from 0 to 9 are used, there are 10 classes for the labels.\n", 10 | "\n", 11 | "Chainer provides a feature called [Trainer](https://docs.chainer.org/en/latest/reference/core/generated/chainer.training.Trainer.html#chainer.training.Trainer) that can simplify the training procedure of your model. However, it is also good to know how the training works in Chainer before starting to use the useful [Trainer](https://docs.chainer.org/en/latest/reference/core/generated/chainer.training.Trainer.html#chainer.training.Trainer) class that hides the actual processes. Writing your own training loop can be useful for learning how [Trainer](https://docs.chainer.org/en/latest/reference/core/generated/chainer.training.Trainer.html#chainer.training.Trainer) works or for implementing features not included in the standard trainer.\n", 12 | "\n", 13 | "The complete training procedure consists of the following steps:\n", 14 | "\n", 15 | "1. Prepare a dataset\n", 16 | "2. Create a dataset iterator\n", 17 | "3. Define a network\n", 18 | "4. Select an optimization algorithm\n", 19 | "5. Write a training loop\n", 20 | " 1. Retrieve a set of examples (mini-batch) from the training dataset.\n", 21 | " 2. Feed the mini-batch to your network.\n", 22 | " 3. Run a forward pass of the network and compute the loss.\n", 23 | " 4. Just call the [backward()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable.backward) method from the loss [Variable](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable) to compute the gradients for all trainable parameters.\n", 24 | " 5. Run the optimizer to update those parameters.\n", 25 | "6. Save the trained model\n", 26 | "7. Perform classification by the saved model and check the network performance on validation/test sets." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "First, let's import the necessary packages for using Chainer." 
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": { 40 | "collapsed": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "import numpy as np\n", 45 | "import chainer\n", 46 | "from chainer import cuda, Function, gradient_check, report, training, utils, Variable\n", 47 | "from chainer import datasets, iterators, optimizers, serializers\n", 48 | "from chainer import Link, Chain, ChainList\n", 49 | "import chainer.functions as F\n", 50 | "import chainer.links as L\n", 51 | "from chainer.training import extensions\n", 52 | "import matplotlib.pyplot as plt\n", 53 | "from chainer.datasets import mnist" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# 1. Prepare a dataset\n", 61 | "\n", 62 | "Chainer provides built-in functions for using popular datasets such as MNIST, CIFAR10/100, etc. These can automatically download the data from servers and provide dataset objects which are easy to use.\n", 63 | "\n", 64 | "The code below shows how to retrieve the MNIST dataset from the server and save an image from its training split to make sure the images are correctly obtained." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 2, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...\n", 77 | "Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...\n", 78 | "Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...\n", 79 | "Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...\n", 80 | "label: 5\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "# Download the MNIST data if you haven't downloaded it yet\n", 86 | "train, test = mnist.get_mnist(withlabel=True, ndim=1)\n", 87 | "\n", 88 | "# Display an example from the MNIST dataset.\n", 89 | "# `x` contains the input image array and `t` contains the target class\n", 90 | "# label as an integer.\n", 91 | "x, t = train[0]\n", 92 | "plt.imshow(x.reshape(28, 28), cmap='gray')\n", 93 | "plt.savefig('5.png')\n", 94 | "print('label:', t)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "The saved image `5.png` will look like:" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 3, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAHf1JREFUeJzt3X+M14V9+PHXyZWTIvcxBLnjClawLaZaMOnw4mq7GikH\nyZy4blHTJbCZNTNHN0OqGV1RTH+w6lwMLVH/WGBsK7ZZIt3MxmZRIKZAI9YtjQmBjkaYHK5s9/lw\np4Jy7+8fxMv3WvzxvnKfN3evxyP5JNzn83nxeeXNu+ez7/vcXUtRFEUAAJDGRVUvAABAcwlAAIBk\nBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQA\nSEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhA\nAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCM\nAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAA\nyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJJprXoBxqehoaF45ZVXYtq0adHS0lL1\nOgCUVBRFnDx5Mrq6uuKii1wPykYAMiqvvPJKzJkzp+o1APg1HTlyJGbPnl31GjSZ5GdUpk2bVvUK\nAJwHPp/nJAAT27hxY1xxxRVx8cUXR3d3d/z4xz9+37O+7AswMfh8npMATOp73/terF69Ou6///54\n4YUXYuHChdHT0xOvvvpq1asBAGOspSiKouolaL7u7u5YtGhRfOc734mIs9/UMWfOnPjSl74Uf/7n\nf/6e841GI2q12livCcAYq9fr0d7eXvUaNJkrgAmdPn069u/fH4sXLx6+76KLLorFixfHnj17zjlz\n6tSpaDQaI24AwPgkABP6xS9+EWfOnImOjo4R93d0dERfX985Z9avXx+1Wm345juAAWD8EoC8L2vW\nrIl6vT58O3LkSNUrAQCj5OcAJjRjxoyYNGlSHD9+fMT9x48fj87OznPOtLW1RVtbWzPWAwDGmCuA\nCU2ePDk++clPxo4dO4bvGxoaih07dsT1119f4WYAQDO4ApjU6tWrY8WKFfEbv/Ebcd1118UjjzwS\ng4OD8Yd/+IdVrwYAjDEBmNRtt90W//M//xP33Xdf9PX1xbXXXhvbt2//lW8MAQAmHj8HkFHxcwAB\nJgY/BzAn7wEEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMA\nAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAy\nAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIA\nJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQg\nAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMm0Vr0AQFmTJk0qPVOr1cZgk/Nj1apVo5r7\n4Ac/WHpm/vz5pWd6e3tLz/zVX/1V6Zk77rij9ExExBtvvFF65i//8i9LzzzwwAOlZ+BC5QogAEAy\nAhAAIBkBmNS6deuipaVlxO2qq66qei0AoAm8BzCxq6++On74wx8Of9za6nQAgAz8Fz+x1tbW6Ozs\nrHoNAKDJfAk4sYMHD0ZXV1fMmzcvvvCFL8TLL7/8js89depUNBqNETcAYHwSgEl1d3fH5s2bY/v2\n7fHoo4/G4cOH49Of/nScPHnynM9fv3591Gq14ducOXOavDEAcL4IwKSWLVsWv//7vx8LFiyInp6e\n+Jd/+Zfo7++P73//++d8/po1a6Jerw/fjhw50uSNAYDzxXsAiYiISy+9ND72sY/FoUOHzvl4W1tb\ntLW1NXkrAGAsuAJIREQMDAzEoUOHYtasWVWvAgCMMQGY1Je//OXYtWtX/PznP48f/ehHceutt0Zr\na+uofxUTADB++BJwUkePHo077rgjTpw4EZdddlnccMMNsXfv3rjsssuqXg0AGGMCMKknnnii6hVo\nkssvv7z0zOTJk0vP/OZv/mbpmRtuuKH0TMTZ96yW9fnPf35UrzXRHD16tPTMhg0bSs/ceuutpWfe\n6acQvJf/+I//KD2za9euUb0WTBS+BAwAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQ\nACAZAQgAkIwABABIRgACACQjAAEAkmkpiqKoegnGn0ajEbVareo1Urn22mtHNffMM8+UnvFvOz4M\nDQ2VnvmjP/qj0jMDAwOlZ0bj2LFjo5r7v//7v9IzBw4cGNVrTUT1ej3a29urXoMmcwUQACAZAQgA\nkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkmmt\negHg/Xn55ZdHNXfixInSM7VabVSvNdHs27ev9Ex/f3/pmRtvvLH0TETE6dOnS8/83d/93aheC5hY\nXAEEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAA\nAMm0Vr0A8P787//+76jm7rnnntIzv/3bv1165ic/+UnpmQ0bNpSeGa0XX3yx9MznPve50jODg4Ol\nZ66++urSMxERf/ZnfzaqOQBXAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQj\nAAEAkhGAAADJCEAAgGQEIABAMi1FURRVL8H402g0olarVb0GY6S9vb30zMmTJ0vPPP7446VnIiLu\nvPPO0jN/8Ad/UHpm69atpWdgvKnX66P63zzjmyuAAADJCEAAgGQE4AS0e/fuuPnmm6OrqytaWlpi\n27ZtIx4viiLuu+++mDVrVkyZMiUWL14cBw8erGhbAKDZBOAENDg4GAsXLoyNGzee8/EHH3wwNmzY\nEI899ljs27cvpk6dGj09PfHGG280eVMAoAqtVS/A+bds2bJYtmzZOR8riiIeeeSR+OpXvxq33HJL\nRERs2bIlOjo6Ytu2bXH77bc3c1UAoAKuACZz+PDh6Ovri8WLFw/fV6vVoru7O/bs2fOOc6dOnYpG\nozHiBgCMTwIwmb6+voiI6OjoGHF/R0fH8GPnsn79+qjVasO3OXPmjOmeAMDYEYC8L2vWrIl6vT58\nO3LkSNUrAQCjJACT6ezsjIiI48ePj7j/+PHjw4+dS1tbW7S3t4+4AQDjkwBMZu7cudHZ2Rk7duwY\nvq/RaMS+ffvi+uuvr3AzAKBZfBfwBDQwMBCHDh0a/vjw
4cPx4osvxvTp0+Pyyy+Pu+++O77+9a/H\nRz/60Zg7d26sXbs2urq6Yvny5RVuDQA0iwCcgJ5//vm48cYbhz9evXp1RESsWLEiNm/eHPfee28M\nDg7GF7/4xejv748bbrghtm/fHhdffHFVKwMATdRSFEVR9RKMP41GI2q1WtVrMM499NBDo5p7+//U\nlLFr167SM///j0t6v4aGhkrPQJXq9br3dSfkPYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhG\nAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQTEtRFEXVSzD+NBqNqNVqVa/BODd16tRR\nzf3zP/9z6Znf+q3fKj2zbNmy0jP//u//XnoGqlSv16O9vb3qNWgyVwABAJIRgAAAyQhAAIBkBCAA\nQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDItRVEUVS/B+NNoNKJWq1W9\nBkldeeWVpWdeeOGF0jP9/f2lZ5599tnSM88//3zpmYiIjRs3lp7xKZ9fVq/Xo729veo1aDJXAAEA\nkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMi2F\n3wzOKDQajajValWvAe/brbfeWnpm06ZNpWemTZtWema0vvKVr5Se2bJlS+mZY8eOlZ5h/KjX69He\n3l71GjSZK4AAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQg\nAEAyAhAAIJmWoiiKqpdg/Gk0GlGr1apeA8bUNddcU3rmr//6r0vP3HTTTaVnRuvxxx8vPfONb3yj\n9Mx///d/l56hGvV6Pdrb26tegyZzBRAAIBkBCACQjACcgHbv3h0333xzdHV1RUtLS2zbtm3E4ytX\nroyWlpYRt6VLl1a0LQDQbAJwAhocHIyFCxfGxo0b3/E5S5cujWPHjg3ftm7d2sQNAYAqtVa9AOff\nsmXLYtmyZe/6nLa2tujs7GzSRgDAhcQVwKR27twZM2fOjPnz58ddd90VJ06ceNfnnzp1KhqNxogb\nADA+CcCEli5dGlu2bIkdO3bEt771rdi1a1csW7Yszpw5844z69evj1qtNnybM2dOEzcGAM4nXwJO\n6Pbbbx/+8yc+8YlYsGBBXHnllbFz5853/Hlka9asidWrVw9/3Gg0RCAAjFOuABLz5s2LGTNmxKFD\nh97xOW1tbdHe3j7iBgCMTwKQOHr0aJw4cSJmzZpV9SoAQBP4EvAENDAwMOJq3uHDh+PFF1+M6dOn\nx/Tp0+OBBx6Iz3/+89HZ2Rk/+9nP4t57742PfOQj0dPTU+HWAECzCMAJ6Pnnn48bb7xx+OO337u3\nYsWKePTRR+M///M/42//9m+jv78/urq6YsmSJfG1r30t2traqloZAGiilqIoiqqXYPxpNBpRq9Wq\nXgMuOJdeemnpmZtvvnlUr7Vp06bSMy0tLaVnnnnmmdIzn/vc50rPUI16ve593Ql5DyAAQDICEAAg\nGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAk01IU\nRVH1Eow/jUYjarVa1WtAaqdOnSo909raWnrmrbfeKj3T09NTembnzp2lZ/j11ev1aG9vr3oNmswV\nQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQ\nTPnfCg6QxIIFC0rP/N7v/V7pmUWLFpWeiYhobW3Op/CXXnqp9Mzu3bvHYBPgfHEFEAAgGQEIAJCM\nAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAk05zfJA5wHs2f\nP7/0zKpVq0rP/O7v/m7pmc7OztIzzXTmzJnSM8eOHSs9MzQ0VHoGaB5XAAEAkhGAAADJCEAAgGQE\nIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMq1VLwBMDJ2dnaVn7rjj\njlG91qpVq0rPXHHFFaN6rQvZ888/X3rmG9/4RumZf/qnfyo9A1zYXAEEAEhGAAIAJCMAJ6D169fH\nokWLYtq0aTFz5sxYvnx5HDhwYMRziqKI++67L2bNmhVTpkyJxYsXx8GDByvaGABoJgE4Ae3atSt6\ne3tj79698fTTT8ebb74ZS5YsicHBweHnPPjgg7Fhw4Z47LHHYt++fTF16tTo6emJN954o8LNAYBm\n8E0gE9D27dtHfLx58+aYOXNm7N+/Pz7zmc9EURTxyCOPxFe/+tW45ZZbIiJiy5Yt0dHREdu2bYvb\nb7+9irUBgCZxBTCBer0eERHTp0+PiIjDhw9HX19fLF68ePg5tVoturu7Y8+ePef8O06dOhWNRmPE\nDQAYnwTgBDc0NBR33313fOpTn4prrrkmIiL6+voiIqKjo2PEczs6OoYf+2Xr16+PWq02fJszZ87Y\nLg4AjBkBOMH19vbGT3/603jiiSd+rb9nzZo1Ua/Xh29Hjhw5TxsCAM3mPYAT2KpVq+Kpp56K3bt3\nx+zZs4fvf/sH9h4/fjxmzZo1fP/x48fj2muvPeff1dbWFm1tbWO7MADQFK4ATkBFUcSqVaviySef\njGeeeSbmzp074vG5c+dGZ2dn7NixY/i+RqMR+/bti+uvv77Z6wIATeYK4ATU29sb3/3ud+MHP/hB\nTJs2bfh9fbVaLaZMmRItLS1x9913x9e//vX46Ec/GnPnzo21a9dGV1dXLF++vOLtAYCxJgAnoEcf\nfTQiIj772c+OuH/Tpk2xcuXKiIi49957Y3BwML74xS9Gf39/3HDDDbF9+/a4+OKLm7wtANBsLUVR\nFFUvwfjTaDSiVqtVvQbvwy9/t/f78fGPf7z0zHe+853SM1dddVXpmQvdvn37Ss889NBDo3qtH/zg\nB6VnhoaGRvVaTFz1ej3a29urXoMm8x5AAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIR\ngAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASKa16gUgo+nTp5eeefzxx0f1Wtdee23pmXnz5o3q\ntS5kP/rRj0rPPPzww6Vn/u3f/q30zOuvv156BuDX4QogAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMA\nAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEimteoF4ELS3d1deuaee+4pPXPdddeVnvnQ\nhz5UeuZC99prr41qbsOGDaVnvvnNb5aeGRwcLD0DMB64AggAkIwABABIRgACACQjAAEAkhGAAADJ\nCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkmmtegG4kNx6661NmWmml156qfTMU089\nVXrmrbfeKj3
z8MMPl56JiOjv7x/VHABnuQIIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAA\nQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJJpKYqiqHoJxp9GoxG1Wq3qNQD4NdXr9Whvb696DZrM\nFUAAgGQEIABAMgJwAlq/fn0sWrQopk2bFjNnzozly5fHgQMHRjxn5cqV0dLSMuK2dOnSijYGAJpJ\nAE5Au3btit7e3ti7d288/fTT8eabb8aSJUticHBwxPOWLl0ax44dG75t3bq1oo0BgGZqrXoBzr/t\n27eP+Hjz5s0xc+bM2L9/f3zmM58Zvr+trS06OzubvR4AUDFXABOo1+sRETF9+vQR9+/cuTNmzpwZ\n8+fPj7vuuitOnDjxjn/HqVOnotFojLgBAOOTHwMzwQ0NDcXv/M7vRH9/fzz33HPD9z/xxBPxwQ9+\nMObOnRs/+9nP4itf+UpccsklsWfPnpg0adKv/D3r1q2LBx54oJmrA9AEfgxMTgJwgrvrrrviX//1\nX+O5556L2bNnv+Pz/uu//iuuvPLK+OEPfxg33XTTrzx+6tSpOHXq1PDHjUYj5syZMyY7A9A8AjAn\nXwKewFatWhVPPfVUPPvss+8afxER8+bNixkzZsShQ4fO+XhbW1u0t7ePuAEA45NvApmAiqKIL33p\nS/Hkk0/Gzp07Y+7cue85c/To0Thx4kTMmjWrCRsCAFVyBXAC6u3tjb//+7+P7373uzFt2rTo6+uL\nvr6+eP311yMiYmBgIO65557Yu3dv/PznP48dO3bELbfcEh/5yEeip6en4u0BgLHmPYATUEtLyznv\n37RpU6xcuTJef/31WL58efzkJz+J/v7+6OrqiiVLlsTXvva16OjoeF+v4XcBA0wM3gOYkwBkVAQg\nwMQgAHPyJWAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQ\nACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQj\nAAEAkhGAAADJCEAAgGQEIKNSFEXVKwBwHvh8npMAZFROnjxZ9QoAnAc+n+fUUkh/RmFoaCheeeWV\nmDZtWrS0tIx4rNFoxJw5c+LIkSPR3t5e0YbVcxzOchzOchzOchzOuhCOQ1EUcfLkyejq6oqLLnI9\nKJvWqhdgfLroooti9uzZ7/qc9vb21J/g3+Y4nOU4nOU4nOU4nFX1cajVapW9NtWS/AAAyQhAAIBk\nJq1bt25d1Usw8UyaNCk++9nPRmtr7ncZOA5nOQ5nOQ5nOQ5nOQ5UyTeBAAAk40vAAADJCEAAgGQE\nIABAMgIQACAZAch5tXHjxrjiiivi4osvju7u7vjxj39c9UpNtW7dumhpaRlxu+qqq6pea8zt3r07\nbr755ujq6oqWlpbYtm3biMeLooj77rsvZs2aFVOmTInFixfHwYMHK9p27LzXcVi5cuWvnB9Lly6t\naNuxs379+li0aFFMmzYtZs6cGcuXL48DBw6MeE6Gc+L9HIcs5wQXHgHIefO9730vVq9eHffff3+8\n8MILsXDhwujp6YlXX3216tWa6uqrr45jx44N35577rmqVxpzg4ODsXDhwti4ceM5H3/wwQdjw4YN\n8dhjj8W+ffti6tSp0dPTE2+88UaTNx1b73UcIiKWLl064vzYunVrEzdsjl27dkVvb2/s3bs3nn76\n6XjzzTdjyZIlMTg4OPycDOfE+zkOETnOCS5ABZwn1113XdHb2zv88ZkzZ4qurq5i/fr1FW7VXPff\nf3+xcOHCqteoVEQUTz755PDHQ0NDRWdnZ/HQQw8N39ff31+0tbUVW7durWLFpvjl41AURbFixYri\nlltuqWij6rz66qtFRBS7du0qiiLvOfHLx6Eo8p4TVM8VQM6L06dPx/79+2Px4sXD91100UWxePHi\n2LNnT4WbNd/Bgwejq6sr5s2bF1/4whfi5ZdfrnqlSh0+fDj6+vpGnBu1Wi26u7vTnRsRETt37oyZ\nM2fG/Pnz46677ooTJ05UvdKYq9frERExffr0iMh7TvzycXhbxnOC6glAzotf/OIXcebMmejo6Bhx\nf0dHR/T19VW0VfN1d3fH5s2bY/v27fHoo4/G4cOH49Of/nScPHmy6tUq8/a/f/ZzI+Lsl/q2bNkS\nO3bsiG9961uxa9euWLZsWZw5c6bq1cbM0NBQ3H333fGpT30qrrnmmojIeU6c6zhE5DwnuDD4/TNw\nHi1btmz4zwsWLIju7u748Ic/HN///vfjzjvvrHAzLgS333778J8/8YlPxIIFC+LKK6+MnTt3xk03\n3VThZmOnt7c3fvrTn6Z4L+y7eafjkPGc4MLgCiDnxYwZM2LSpElx/PjxEfcfP348Ojs7K9qqepde\neml87GMfi0OHDlW9SmXe/vd3bvyqefPmxYwZMybs+bFq1ap46qmn4tlnn43Zs2cP35/tnHin43Au\nE/2c4MIhADkvJk+eHJ/85Cdjx44dw/cNDQ3Fjh074vrrr69ws2oNDAzEoUOHYtasWVWvUpm5c+dG\nZ2fniHOj0WjEvn37Up8bERFHjx6NEydOTLjzoyiKWLVqVTz55JPxzDPPxNy5c0c8nuWceK/jcC4T\n9ZzgwjNp3bp166pegomhvb091q5dG3PmzIm2trZYu3ZtvPjii/E3f/M3cckll1S9XlN8+ctfjra2\ntoiIeOmll+JP/uRP4tVXX43HHnsspk6dWvF2Y2dgYCBeeuml6Ovri8cffzy6u7tjypQpcfr06bj0\n0kvjzJkz8c1vfjM+/vGPx+nTp+NP//RP47XXXotvf/vb0do6cd6J8m7HYdKkSfEXf/EX0d7eHm+9\n9Vbs378/7rzzzrjkkkvi4YcfnlDHobe3N/7hH/4h/vEf/zG6urpiYGAgBgYGYtKkSfGBD3wgWlpa\nUpwT73UcBgYG0pwTXICq/jZkJpZvf/vbxeWXX15Mnjy5uO6664q9e/dWvVJT3XbbbcWsWbOKyZMn\nFx/60IeK2267rTh06FDVa425Z599toiIX7mtWLGiKIqzP/Zj7dq1RUdHR9HW1lbcdNNNxYEDB6pd\negy823F47bXXiiVLlhSXXXZZ8YEPfKD48Ic/XPzxH/9x0dfXV/Xa5925jkFEFJs2bRp+ToZz4r2O\nQ6ZzggtPS1EURTODEwCAankPIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQE\nIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABI\nRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGAAADJ/D8r\n22guhwhz8AAAAABJRU5ErkJggg==\n", 112 | "text/plain": [ 
113 | "" 114 | ] 115 | }, 116 | "execution_count": 3, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "from IPython.display import Image\n", 123 | "Image('5.png')" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "# 2. Create a dataset iterator\n", 131 | "\n", 132 | "Although this is an optional step, we’d like to introduce the [Iterator](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator) class that retrieves a set of data and labels from the given dataset to easily make a mini-batch. There are some subclasses that can perform the same thing in different ways, e.g., using multi-processing to parallelize the data loading part, etc.\n", 133 | "\n", 134 | "Here, we use [SerialIterator](https://docs.chainer.org/en/latest/reference/generated/chainer.iterators.SerialIterator.html#chainer.iterators.SerialIterator), which is also a subclass of [Iterator](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator) in the example code below. The [SerialIterator](https://docs.chainer.org/en/latest/reference/generated/chainer.iterators.SerialIterator.html#chainer.iterators.SerialIterator) can provide mini-batches with or without shuffling the order of data in the given dataset.\n", 135 | "\n", 136 | "All [Iterators](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator) produce a new mini-batch by calling its [next()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator.next) method. All [Iterators](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator) also have properties to know how many times we have taken all the data from the given dataset (epoch) and whether the next mini-batch will be the start of a new epoch (`is_new_epoch`), and so on.\n", 137 | "\n", 138 | "The code below shows how to create a [SerialIterator](https://docs.chainer.org/en/latest/reference/generated/chainer.iterators.SerialIterator.html#chainer.iterators.SerialIterator) object from a dataset object." 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 4, 144 | "metadata": { 145 | "collapsed": true 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "from chainer import iterators\n", 150 | "\n", 151 | "# Choose the minibatch size.\n", 152 | "batchsize = 128\n", 153 | "\n", 154 | "train_iter = iterators.SerialIterator(train, batchsize)\n", 155 | "test_iter = iterators.SerialIterator(\n", 156 | " test, batchsize, repeat=False, shuffle=False)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "**Note**\n", 164 | "\n", 165 | "`iterator`s can take a built-in Python list as a given dataset. It means that the example code below is able to work,\n", 166 | "\n", 167 | "```\n", 168 | "train = [(x1, t1), (x2, t2), ...] # A list of tuples\n", 169 | "train_iter = iterators.SerialIterator(train, batchsize)\n", 170 | "```\n", 171 | "\n", 172 | "where `x1, x2, ...` denote the input data and `t1, t2, ...` denote the corresponding labels." 
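To see how an iterator behaves, here is a minimal sketch (the toy dataset and variable names are made up for illustration) that draws a few mini-batches from a small list dataset and watches the `epoch` and `is_new_epoch` properties change:

```python
import numpy as np
from chainer import iterators

# A toy dataset: a plain Python list of (input, label) tuples.
toy_train = [(np.full(3, i, dtype=np.float32), np.int32(i % 2)) for i in range(6)]

# Batch size 4, with the default repeat=True and shuffle=True.
toy_iter = iterators.SerialIterator(toy_train, 4)

for _ in range(4):
    batch = toy_iter.next()  # a list of (x, t) tuples
    print('minibatch size:', len(batch),
          '| epoch:', toy_iter.epoch,
          '| is_new_epoch:', toy_iter.is_new_epoch)
```

Each call to `next()` returns a plain list of examples; helpers such as `chainer.dataset.concat_examples()` (used in the training loop later in this notebook) stack such a list into arrays.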
173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "## Details of [SerialIterator](https://docs.chainer.org/en/latest/reference/generated/chainer.iterators.SerialIterator.html#chainer.iterators.SerialIterator)\n", 180 | "\n", 181 | "- [SerialIterator](https://docs.chainer.org/en/latest/reference/generated/chainer.iterators.SerialIterator.html#chainer.iterators.SerialIterator) is a built-in subclass of [Iterator](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator) that can retrieve a mini-batch from a given dataset in either sequential or shuffled order.\n", 182 | "- The [Iterator](https://docs.chainer.org/en/latest/reference/core/generated/chainer.dataset.Iterator.html#chainer.dataset.Iterator)‘s constructor takes two arguments: a dataset object and a mini-batch size.\n", 183 | "- If you want to use the same dataset repeatedly during the training process, set the `repeat` argument to `True` (default). Otherwise, the dataset will be used only one time. The latter case is actually for the evaluation.\n", 184 | "- If you want to shuffle the training dataset every epoch, set the `shuffle` argument to `True`. Otherwise, the order of each data retrieved from the dataset will be always the same at each epoch.\n", 185 | "\n", 186 | "In the example code shown above, we set `batchsize = 128` in both `train_iter` and `test_iter`. So, these iterators will provide 128 images and corresponding labels at a time." 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "# 3. Define a network\n", 194 | "\n", 195 | "Now let’s define a neural network that we will train to classify the MNIST images. For simplicity, we use a three-layer perceptron here. We set each hidden layer to have 100 units and set the output layer to have 10 units, which is corresponding to the number of class labels of the MNIST." 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## Create your network as a subclass of Chain\n", 203 | "\n", 204 | "You can create your network by writing a new subclass of [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain). The main steps are twofold:\n", 205 | "\n", 206 | "1. Register the network components which have trainable parameters to the subclass. Each of them must be instantiated and assigned to a property in the scope specified by [init_scope()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain.init_scope)\n", 207 | "2. Define a `__call__()` method that represents the actual **forward computation** of your network. This method takes one or more [Variable](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable), `numpy.array`, or `cupy.array` as its inputs and calculates the forward pass using them." 
208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 5, 213 | "metadata": { 214 | "collapsed": true 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "class MyNetwork(Chain):\n", 219 | "\n", 220 | " def __init__(self, n_mid_units=100, n_out=10):\n", 221 | " super(MyNetwork, self).__init__()\n", 222 | " with self.init_scope():\n", 223 | " self.l1 = L.Linear(None, n_mid_units)\n", 224 | " self.l2 = L.Linear(n_mid_units, n_mid_units)\n", 225 | " self.l3 = L.Linear(n_mid_units, n_out)\n", 226 | "\n", 227 | " def __call__(self, x):\n", 228 | " h = F.relu(self.l1(x))\n", 229 | " h = F.relu(self.l2(h))\n", 230 | " return self.l3(h)\n", 231 | "\n", 232 | "model = MyNetwork()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "[Link](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Link.html#chainer.Link), [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain), [ChainList](https://docs.chainer.org/en/latest/reference/core/generated/chainer.ChainList.html#chainer.ChainList), and those subclass objects which contain trainable parameters should be registered to the model by assigning it as a property inside the [init_scope()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain.init_scope). For example, a [Function](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Function.html#chainer.Function) does not contain any trainable parameters, so there is no need to keep the object as a property of your network. When you want to use [relu()](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.relu.html#chainer.functions.relu) in your network, using it as a function in `__call__()` works correctly.\n", 240 | "\n", 241 | "In Chainer, the Python code that implements the forward computation itself represents the network. In other words, we can conceptually think of the computation graph for our network being constructed dynamically as this forward computation code executes. This allows Chainer to describe networks in which different computations can be performed in each iteration, such as branched networks, intuitively and with a high degree of flexibility. This is the key feature of Chainer that we call **Define-by-Run**." 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "# 4. Select an optimization algorithm\n", 249 | "\n", 250 | "Chainer provides a wide variety of optimization algorithms that can be used to optimize the network parameters during training. They are located in `optimizers` module.\n", 251 | "\n", 252 | "Here, we are going to use the stochastic gradient descent (SGD) method with momentum, which is implemented by [MomentumSGD](https://docs.chainer.org/en/latest/reference/generated/chainer.optimizers.MomentumSGD.html#chainer.optimizers.MomentumSGD). To use the optimizer, we give the network object (typically it’s a [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) or [ChainList](https://docs.chainer.org/en/latest/reference/core/generated/chainer.ChainList.html#chainer.ChainList)) to the [setup()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Optimizer.html#chainer.Optimizer.setup) method of the optimizer object to register it. 
In this way, the [Optimizer](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Optimizer.html#chainer.Optimizer) can automatically find the model parameters and update them during training.\n", 253 | "\n", 254 | "You can easily try out other optimizers as well. Please test and observe the results of various optimizers. For example, you could try to change [MomentumSGD](https://docs.chainer.org/en/latest/reference/generated/chainer.optimizers.MomentumSGD.html#chainer.optimizers.MomentumSGD) to [Adam](https://docs.chainer.org/en/latest/reference/generated/chainer.optimizers.Adam.html#chainer.optimizers.Adam), [RMSprop](https://docs.chainer.org/en/latest/reference/generated/chainer.optimizers.RMSprop.html#chainer.optimizers.RMSprop), etc." 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 6, 260 | "metadata": { 261 | "collapsed": true 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "# Choose an optimizer algorithm\n", 266 | "optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9)\n", 267 | "\n", 268 | "# Give the optimizer a reference to the model so that it\n", 269 | "# can locate the model's parameters.\n", 270 | "optimizer.setup(model)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "**Note**\n", 278 | "\n", 279 | "In the above example, we set `lr` to 0.01 in the constructor. This value is known as the “learning rate”, one of the most important hyperparameters that need to be adjusted in order to obtain the best performance. The various optimizers may each have different hyperparameters and so be sure to check the documentation for the details." 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": {}, 285 | "source": [ 286 | "# 5. Write a training loop\n", 287 | "\n", 288 | "We now show how to write the training loop. Since we are working on a digit classification problem, we will use [softmax_cross_entropy()](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.softmax_cross_entropy.html#chainer.functions.softmax_cross_entropy) as the loss function for the optimizer to minimize. For other types of problems, such as regression models, other loss functions might be more appropriate. See the [Chainer documentation for detailed information on the various loss functions](http://docs.chainer.org/en/stable/reference/functions.html#loss-functions) for more details.\n", 289 | "\n", 290 | "Our training loop will be structured as follows.\n", 291 | "\n", 292 | "1. We will first get a mini-batch of examples from the training dataset.\n", 293 | "2. We will then feed the batch into our network by calling it (a [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) object) like a function. This will execute the forward-pass code that are written in the `__call__()` method.\n", 294 | "3. This will return the network output that represents class label predictions. We supply it to the loss function along with the true (that is, target) values. The loss function will output the loss as a [Variable](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable) object.\n", 295 | "4. We then clear any previous gradients in the network and perform the backward pass by calling the [backward()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable.backward) method on the loss variable which computes the parameter gradients. 
We need to clear the gradients first because the [backward()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable.backward) method accumulates gradients instead of overwriting the previous values.\n", 296 | "5. Since the optimizer already has a reference to the network, it has access to the parameters and the computed gradients so that we can now call the [update()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Optimizer.html#chainer.Optimizer.update) method of the optimizer which will update the model parameters.\n", 297 | "\n", 298 | "In addition to the above steps, you might want to check the performance of the network with a validation dataset. This allows you to observe how well it is generalized to new data so far, namely, you can check whether it is overfitting to the training data. The code below checks the performance on the test set at the end of each epoch. The code has the same structure as the training code except that no backpropagation is performed and we also compute the accuracy on the test data using the [accuracy()](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.accuracy.html#chainer.functions.accuracy) function.\n", 299 | "\n", 300 | "The training loop code is as follows:" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 7, 306 | "metadata": {}, 307 | "outputs": [ 308 | { 309 | "name": "stdout", 310 | "output_type": "stream", 311 | "text": [ 312 | "epoch:01 train_loss:0.3986 val_loss:0.2676 val_accuracy:0.9232\n", 313 | "epoch:02 train_loss:0.2622 val_loss:0.1954 val_accuracy:0.9414\n", 314 | "epoch:03 train_loss:0.1323 val_loss:0.1561 val_accuracy:0.9540\n", 315 | "epoch:04 train_loss:0.1114 val_loss:0.1277 val_accuracy:0.9621\n", 316 | "epoch:05 train_loss:0.0467 val_loss:0.1106 val_accuracy:0.9672\n", 317 | "epoch:06 train_loss:0.0747 val_loss:0.0971 val_accuracy:0.9710\n", 318 | "epoch:07 train_loss:0.0632 val_loss:0.0938 val_accuracy:0.9723\n", 319 | "epoch:08 train_loss:0.0468 val_loss:0.0982 val_accuracy:0.9690\n", 320 | "epoch:09 train_loss:0.0551 val_loss:0.0807 val_accuracy:0.9755\n", 321 | "epoch:10 train_loss:0.0933 val_loss:0.0767 val_accuracy:0.9773\n" 322 | ] 323 | } 324 | ], 325 | "source": [ 326 | "import numpy as np\n", 327 | "from chainer.dataset import concat_examples\n", 328 | "from chainer.cuda import to_cpu\n", 329 | "\n", 330 | "max_epoch = 10\n", 331 | "gpu_id = 0 # If you want to use GPU, set 0 (GPU ID you want to use)\n", 332 | "\n", 333 | "if gpu_id >= 0:\n", 334 | " model.to_gpu(gpu_id)\n", 335 | "\n", 336 | "while train_iter.epoch < max_epoch:\n", 337 | "\n", 338 | " # ---------- One iteration of the training loop ----------\n", 339 | " train_batch = train_iter.next()\n", 340 | " image_train, target_train = concat_examples(train_batch, gpu_id)\n", 341 | "\n", 342 | " # Calculate the prediction of the network\n", 343 | " prediction_train = model(image_train)\n", 344 | "\n", 345 | " # Calculate the loss with softmax_cross_entropy\n", 346 | " loss = F.softmax_cross_entropy(prediction_train, target_train)\n", 347 | "\n", 348 | " # Calculate the gradients in the network\n", 349 | " model.cleargrads()\n", 350 | " loss.backward()\n", 351 | "\n", 352 | " # Update all the trainable paremters\n", 353 | " optimizer.update()\n", 354 | " # --------------------- until here ---------------------\n", 355 | "\n", 356 | " # Check the validation accuracy of prediction after every epoch\n", 357 | " if train_iter.is_new_epoch: # If this 
iteration is the final iteration of the current epoch\n", 358 | "\n", 359 | " # Display the training loss\n", 360 | " print('epoch:{:02d} train_loss:{:.04f} '.format(\n", 361 | " train_iter.epoch, float(to_cpu(loss.data))), end='')\n", 362 | "\n", 363 | " test_losses = []\n", 364 | " test_accuracies = []\n", 365 | " while True:\n", 366 | " test_batch = test_iter.next()\n", 367 | " image_test, target_test = concat_examples(test_batch, gpu_id)\n", 368 | "\n", 369 | " # Forward the test data\n", 370 | " prediction_test = model(image_test)\n", 371 | "\n", 372 | " # Calculate the loss\n", 373 | " loss_test = F.softmax_cross_entropy(prediction_test, target_test)\n", 374 | " test_losses.append(to_cpu(loss_test.data))\n", 375 | "\n", 376 | " # Calculate the accuracy\n", 377 | " accuracy = F.accuracy(prediction_test, target_test)\n", 378 | " accuracy.to_cpu()\n", 379 | " test_accuracies.append(accuracy.data)\n", 380 | "\n", 381 | " if test_iter.is_new_epoch:\n", 382 | " test_iter.reset()\n", 383 | " break\n", 384 | "\n", 385 | " print('val_loss:{:.04f} val_accuracy:{:.04f}'.format(\n", 386 | " np.mean(test_losses), np.mean(test_accuracies)))" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "# 6. Save the trained model\n", 394 | "\n", 395 | "Chainer provides two types of [serializers](https://docs.chainer.org/en/latest/reference/serializers.html#module-chainer.serializers) that can be used to save and restore model state. One supports the HDF5 format and the other supports the NumPy NPZ format. For this example, we are going to use the NPZ format to save our model since it is easy to use with NumPy and doesn’t need to install any additional dependencies or libraries." 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 8, 401 | "metadata": { 402 | "collapsed": true 403 | }, 404 | "outputs": [], 405 | "source": [ 406 | "serializers.save_npz('my_mnist.model', model)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "# 7. Perform classification by the saved model\n", 414 | "\n", 415 | "Let’s use the saved model to classify a new image. In order to load the trained model parameters, we need to perform the following two steps:\n", 416 | "\n", 417 | "1. Instantiate the same network as what you trained.\n", 418 | "2. Overwrite all parameters in the model instance with the saved weights using the [load_npz()](https://docs.chainer.org/en/latest/reference/generated/chainer.serializers.load_npz.html#chainer.serializers.load_npz) function.\n", 419 | "\n", 420 | "Once the model is restored, it can be used to predict image labels on new input data." 
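As a brief aside before the NPZ-based example below: the HDF5 serializer mentioned above follows exactly the same save/restore pattern. A minimal sketch (assuming the optional `h5py` dependency is installed; the file name is made up):

```python
from chainer import serializers

# Save the trained model in HDF5 format instead of NPZ.
serializers.save_hdf5('my_mnist.h5', model)

# Restore it into a freshly created network with the same architecture.
model_hdf5 = MyNetwork()
serializers.load_hdf5('my_mnist.h5', model_hdf5)
```

The rest of this notebook sticks to the NPZ format.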
421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 9, 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "name": "stdout", 430 | "output_type": "stream", 431 | "text": [ 432 | "label: 7\n" 433 | ] 434 | } 435 | ], 436 | "source": [ 437 | "from chainer import serializers\n", 438 | "\n", 439 | "# Create an instance of the network you trained\n", 440 | "model = MyNetwork()\n", 441 | "\n", 442 | "# Load the saved paremeters into the instance\n", 443 | "serializers.load_npz('my_mnist.model', model)\n", 444 | "\n", 445 | "# Get a test image and label\n", 446 | "x, t = test[0]\n", 447 | "plt.imshow(x.reshape(28, 28), cmap='gray')\n", 448 | "plt.savefig('7.png')\n", 449 | "print('label:', t)" 450 | ] 451 | }, 452 | { 453 | "cell_type": "markdown", 454 | "metadata": {}, 455 | "source": [ 456 | "The saved test image looks like:" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 10, 462 | "metadata": {}, 463 | "outputs": [ 464 | { 465 | "data": { 466 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAHItJREFUeJzt3W9sXYV5+PHHseESQXxpFvxvMZDQQjcgYWPBQ0BFh5U/\nnRiZOolUfREqBBpyqKKooEZrIOuqWTBtQnRZ2IupWdSWUlUi1dDkjbnEDJGkIhWr0nZZzFKRCBza\n0PjG7gg0Ob8XUa2fS4Be174n8fP5SEeK771P/PT0KP32+N6kqSiKIgAASGNO2QsAANBYAhAAIBkB\nCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCS\nEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAA\nIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMA\nAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAy\nAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZFrKXoBz06lTp+K1116LefPmRVNTU9nr\nAFCnoiji+PHj0dXVFXPmuB+UjQBkSl577bXo7u4uew0AfkOHDh2KhQsXlr0GDSb5mZJ58+aVvQIA\n08Cf5zkJwMS2bNkSl19+eVxwwQXR09MT3/3ud3/tWT/2BZgd/HmekwBM6qmnnooNGzbEww8/HN/7\n3vdi6dKlsWLFinjjjTfKXg0AmGFNRVEUZS9B4/X09MSyZcvi7//+7yPi9Ic6uru74/7774/Pf/7z\nHzhfq9WiWq3O9JoAzLDR0dFobW0tew0azB3AhN5+++3Yu3dv9Pb2Tjw2Z86c6O3tjV27dp1x5sSJ\nE1Gr1SYdAMC5SQAm9NOf/jROnjwZ7e3tkx5vb2+PkZGRM8709/dHtVqdOHwCGADOXQKQX8vGjRtj\ndHR04jh06FDZKwEAU+TvAUxowYIF0dzcHEeOHJn0+JEjR6Kjo+OMM5VKJSqVSiPWAwBmmDuACZ1/\n/vlx/fXXx+Dg4MRjp06disHBwbjxxhtL3AwAaAR3AJPasGFDrF27Nv7gD/4gbrjhhnjsscdifHw8\nPvOZz5S9GgAwwwRgUnfeeWf85Cc/iYceeihGRkbiuuuui4GBgXd9MAQAmH38PYBMib8HEGB28PcA\n5uQ9gAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDIC\nEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAk\nIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAA\nQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYA\nAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwCT2rx5czQ1NU06PvrR\nj5a9FgDQAC1lL0B5rr766viP//iPia9bWlwOAJCB/8VPrKWlJTo6OspeAwBoMD8CTuzAgQPR1dUV\nixcvjk9/+tPx6quvvudrT5w4EbVabdIBAJybBGBSPT09sW3bthgYGIitW7fGwYMH45Zbbonjx4+f\n8fX9/f1RrVYnju7u7gZvDABMl6aiKIqyl6B8x44di8suuyz+7u/+Lu6+++53PX/ixIk4ceLExNe1\nWk0EAswCo6Oj0draWvYaNJj3ABIRERdffHFceeWVMTw8fMbnK5VKVCqVBm8FAMwEPwImIiLGxsZi\neHg4Ojs7y14FAJhhAjCpz33uczE0NBQ//vGP48UXX4w//dM/jZaWlvjUpz5V9moAwAzzI+CkDh8+\nHJ/61Kfi6NGjcckll8TNN98cu3fvjksuuaTs1QCAGeZDIExJrVaLarVa9hoA/IZ8CCQnPwIGAEhG\nAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMm0lL0A\nZPRnf/Zndc/cc889U/per732Wt0zb731Vt0zX/va1+qeGRkZqXsmImJ4eHhKcwCc5g4gAEAyAhAA\nIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJNNU\nFEVR9hKce2q1WlSr1bLXOGf97//+b90zl19++fQvUrLjx49Pae4HP/jBNG/CdDt8+HDdM48++uiU\nvtdLL700pTlOGx0djdbW1rLXoMHcAQQA
SEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEI\nAJCMAAQASEYAAgAkIwABAJIRgAAAybSUvQBkdM8999Q9s2TJkil9rx/96Ed1z/zO7/xO3TO///u/\nX/fMrbfeWvdMRMQf/uEf1j1z6NChume6u7vrnmmkX/ziF3XP/OQnP6l7prOzs+6ZqXj11VenNPfS\nSy9N8yYw+7kDCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhG\nAAIAJCMAAQCSaSl7AchocHCwITNTNTAw0JDv86EPfWhKc9ddd13dM3v37q17ZtmyZXXPNNJbb71V\n98z//M//1D3zox/9qO6Z+fPn1z3zyiuv1D0DTI07gAAAyQhAAIBkBOAs9Pzzz8ftt98eXV1d0dTU\nFDt27Jj0fFEU8dBDD0VnZ2fMnTs3ent748CBAyVtCwA0mgCchcbHx2Pp0qWxZcuWMz7/6KOPxuOP\nPx5PPPFE7NmzJy688MJYsWLFlN5PBACce3wIZBZatWpVrFq16ozPFUURjz32WHzhC1+IO+64IyIi\ntm/fHu3t7bFjx45Ys2ZNI1cFAErgDmAyBw8ejJGRkejt7Z14rFqtRk9PT+zates9506cOBG1Wm3S\nAQCcmwRgMiMjIxER0d7ePunx9vb2iefOpL+/P6rV6sTR3d09o3sCADNHAPJr2bhxY4yOjk4chw4d\nKnslAGCKBGAyHR0dERFx5MiRSY8fOXJk4rkzqVQq0draOukAAM5NAjCZRYsWRUdHx6R/VaJWq8We\nPXvixhtvLHEzAKBRfAp4FhobG4vh4eGJrw8ePBgvv/xyzJ8/Py699NJYv359fOlLX4qPfOQjsWjR\noti0aVN0dXXF6tWrS9waAGgUATgLvfTSS/Hxj3984usNGzZERMTatWtj27Zt8eCDD8b4+Hjce++9\ncezYsbj55ptjYGAgLrjggrJWBgAaqKkoiqLsJTj31Gq1qFarZa8B1OmTn/xk3TPf/OY3657Zt29f\n3TP///9xrcebb745pTlOGx0d9b7uhLwHEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBk\nBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJJpKXsBAKamra2t7pl/+Id/qHtmzpz67xV88Ytf\nrHvmzTffrHsGmBp3AAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGA\nAADJCEAAgGQEIABAMi1lLwDA1PT19dU9c8kll9Q987Of/azumf3799c9AzSOO4AAAMkIQACAZAQg\nAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIJmWshcAyO6mm26a\n0tznP//5ad7kzFavXl33zL59+2ZgE2C6uAMIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAA\nQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJJpKXsBgOw+8YlPTGnuvPPOq3tmcHCw7pldu3bVPQOc\n3dwBBABIRgACACQjAGeh559/Pm6//fbo6uqKpqam2LFjx6Tn77rrrmhqapp0rFy5sqRtAYBGE4Cz\n0Pj4eCxdujS2bNnynq9ZuXJlvP766xPHk08+2cANAYAy+RDILLRq1apYtWrV+76mUqlER0dHgzYC\nAM4m7gAmtXPnzmhra4urrroq7rvvvjh69Oj7vv7EiRNRq9UmHQDAuUkAJrRy5crYvn17DA4OxiOP\nPBJDQ0OxatWqOHny5HvO9Pf3R7VanTi6u7sbuDEAMJ38CDihNWvWTPz62muvjSVLlsQVV1wRO3fu\njNtuu+2MMxs3bowNGzZMfF2r1UQgAJyj3AEkFi9eHAsWLIjh4eH3fE2lUonW1tZJBwBwbhKAxOHD\nh+Po0aPR2dlZ9ioAQAP4EfAsNDY2Nulu3sGDB+Pll1+O+fPnx/z58+Mv//Iv45Of/GR0dHTEK6+8\nEg8++GB8+MMfjhUrVpS4NQDQKAJwFnrppZfi4x//+MTXv3zv3tq1a2Pr1q3x/e9/P/75n/85jh07\nFl1dXbF8+fL4q7/6q6hUKmWtDAA0UFNRFEXZS3DuqdVqUa1Wy14Dzjpz586te+aFF16Y0ve6+uqr\n6575oz/6o7pnXnzxxbpnOHeMjo56X3dC3gMIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAA\nQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAybSUvQDAbPLAAw/UPfN7v/d7U/peAwMDdc+8\n+OKLU/pewOziDiAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAg\nGQEIAJCMAAQASKal7AUAzlZ//Md/XPfMpk2b6p6p1Wp1z0REfPGLX5zSHIA7gAAAyQhAAIBkBCAA\nQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgmZayFwBohN/6rd+q\ne+bxxx+ve6a5ubnumX/913+teyYiYvfu3VOaA3AHEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAA\nyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAk01L2AgD1am5urntmYGCg7plFixbVPfPKK6/U\nPbNp06a6ZwB+E+4AAgAkIwABAJIRgLNQf39/LFu2LObNmxdtbW2xevXq2L9//6TXFEURDz30UHR2\ndsbcuXOjt7c3Dhw4UNLGAEAjCcBZaGhoKPr6+mL37t3x7LPPxjvvvBPLly+P8fHxidc8+uij8fjj\nj8cTTzwRe/bsiQsvvDBWrFgRb731VombAwCN4EMgs9Cvvtl927Zt0dbWFnv37o2PfexjURRFPPbY\nY/GFL3wh7rjjjoiI2L59e7S3t8eOHTtizZo1ZawNADSIO4AJjI6ORkTE/PnzIyLi4MGDMTIyEr29\nvROvqVar0dPTE7t27Trj73HixImo1WqTDgDg3CQAZ7lTp07F+vXr46abboprrrkmIiJGRkYiIqK9\nvX3Sa9vb2yee+1X9/f1RrVYnju7u7pldHACYMQJwluvr64t9+/bFN77xjd/o99m4cWOMjo5OHIcO\nHZqmDQGARvMewFls3bp18cwzz8Tzzz8fCxcunHi8o6MjIiKOHDkSnZ2dE48fOXIkrrvuujP+XpVK\nJSqVyswuDAA0hDuAs1BRFLFu3bp4+umn4zvf+c67/jWDRYsWRUdHRwwODk48VqvVYs+ePXHjjTc2\nel0AoMHcAZyF+vr64utf/3p8+9vfjnnz5k28r69arcbcuXOjqakp1q9fH1/60pfiIx/5SCxatCg2\nbdoUXV1dsXr16pK3BwBmmgCchbZu3RoREbfeeuukx7/yla/EXXfdFRERDz74YIyPj8e9994bx44d\
ni5tvvjkGBgbiggsuaPC2AECjNRVFUZS9BOeeWq0W1Wq17DVI6sorr6x75r//+79nYJN3++XfrVmP\nf/mXf5mBTeDXMzo6Gq2trWWvQYN5DyAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgAAAyQhA\nAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAk01L2AkBel1122ZTm/v3f/32aNzmzBx54oO6ZZ555\nZgY2AZhe7gACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQjAAEAkhGA\nAADJCEAAgGRayl4AyOvee++d0tyll146zZuc2dDQUN0zRVHMwCYA08sdQACAZAQgAEAyAhAAIBkB\nCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQTEvZCwCzw80331z3zP33\n3z8DmwDwQdwBBABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQACAZAQgAkIwABABIRgACACQj\nAAEAkhGAAADJtJS9ADA73HLLLXXPXHTRRTOwyZm98sordc+MjY3NwCYA5XMHEAAgGQEIAJCMAJyF\n+vv7Y9myZTFv3rxoa2uL1atXx/79+ye95q677oqmpqZJx8qVK0vaGABoJAE4Cw0NDUVfX1/s3r07\nnn322XjnnXdi+fLlMT4+Pul1K1eujNdff33iePLJJ0vaGABoJB8CmYUGBgYmfb1t27Zoa2uLvXv3\nxsc+9rGJxyuVSnR0dDR6PQCgZO4AJjA6OhoREfPnz5/0+M6dO6OtrS2uuuqquO++++Lo0aPv+Xuc\nOHEiarXapAMAODcJwFnu1KlTsX79+rjpppvimmuumXh85cqVsX379hgcHIxHHnkkhoaGYtWqVXHy\n5Mkz/j79/f1RrVYnju7u7kb9RwAAppkfAc9yfX19sW/fvnjhhRcmPb5mzZqJX1977bWxZMmSuOKK\nK2Lnzp1x2223vev32bhxY2zYsGHi61qtJgIB4BzlDuAstm7dunjmmWfiueeei4ULF77vaxcvXhwL\nFiyI4eHhMz5fqVSitbV10gEAnJvcAZyFiqKI+++/P55++unYuXNnLFq06ANnDh8+HEePHo3Ozs4G\nbAgAlMkdwFmor68vvvrVr8bXv/71mDdvXoyMjMTIyEj83//9X0Sc/uetHnjggdi9e3f8+Mc/jsHB\nwbjjjjviwx/+cKxYsaLk7QGAmSYAZ6GtW7fG6Oho3HrrrdHZ2TlxPPXUUxER0dzcHN///vfjT/7k\nT+LKK6+Mu+++O66//vr4z//8z6hUKiVvDwDMND8CnoWKonjf5+fOnRv/9m//1qBtAICzjQAEzjn/\n9V//VffMmT7d/kHefPPNumcAzgV+BAwAkIwABABIRgACACQjAAEAkhGAAADJCEAAgGQEIABAMgIQ\nACAZAQgAkIwABABIRgACACQjAAEAkmkqiqIoewnOPbVaLarVatlrAPAbGh0djdbW1rLXoMHcAQQA\nSEYAAgAkIwABAJIRgAAAyQhAAIBkBCAAQDICEAAgGQEIAJCMAAQASEYAAgAkIwABAJIRgEyJf0Ia\nYHbw53lOApApOX78eNkrADAN/HmeU1Mh/ZmCU6dOxWuvvRbz5s2LpqamSc/VarXo7u6OQ4cORWtr\na0kbls95OM15OM15OM15OO1sOA9FUcTx48ejq6sr5sxxPyiblrIX4Nw0Z86cWLhw4fu+prW1NfUf\n8L/kPJzmPJzmPJzmPJxW9nmoVqulfW/KJfkBAJIRgAAAyTRv3rx5c9lLMPs0NzfHrbfeGi0tud9l\n4Dyc5jyc5jyc5jyc5jxQJh8CAQBIxo+AAQCSEYAAAMkIQACAZAQgAEAyApBptWXLlrj88svjggsu\niJ6envjud79b9koNtXnz5mhqapp0fPSjHy17rRn3/PPPx+233x5dXV3R1NQUO3bsmPR8URTx0EMP\nRWdnZ8ydOzd6e3vjwIEDJW07cz7oPNx1113vuj5WrlxZ0rYzp7+/P5YtWxbz5s2Ltra2WL16dezf\nv3/SazJcE7/OechyTXD2EYBMm6eeeio2bNgQDz/8cHzve9+LpUuXxooVK+KNN94oe7WGuvrqq+P1\n11+fOF544YWyV5px4+PjsXTp0tiyZcsZn3/00Ufj8ccfjyeeeCL27NkTF154YaxYsSLeeuutBm86\nsz7oPERErFy5ctL18eSTTzZww8YYGhqKvr6+2L17dzz77LPxzjvvxPLly2N8fHziNRmuiV/nPETk\nuCY4CxUwTW644Yair69v4uuTJ08WXV1dRX9/f4lbNdbDDz9cLF26tOw1ShURxdNPPz3x9alTp4qO\njo7ib/7mbyYeO3bsWFGpVIonn3yyjBUb4lfPQ1EUxdq1a4s77rijpI3K88YbbxQRUQwNDRVFkfea\n+NXzUBR5rwnK5w4g0+Ltt9+OvXv3Rm9v78Rjc+bMid7e3ti1a1eJmzXegQMHoqurKxYvXhyf/vSn\n49VXXy17pVIdPHgwRkZGJl0b1Wo1enp60l0bERE7d+6Mtra2uOqqq+K+++6Lo0ePlr3SjBsdHY2I\niPnz50dE3mviV8/DL2W8JiifAGRa/PSnP42TJ09Ge3v7pMfb29tjZGSkpK0ar6enJ7Zt2xYDAwOx\ndevWOHjwYNxyyy1x/PjxslcrzS//+89+bUSc/lHf9u3bY3BwMB555JEYGhqKVatWxcmTJ8tebcac\nOnUq1q9fHzfddFNcc801EZHzmjjTeYjIeU1wdvDvz8A0WrVq1cSvlyxZEj09PXHZZZfFN7/5zbj7\n7rtL3IyzwZo1ayZ+fe2118aSJUviiiuuiJ07d8Ztt91W4mYzp6+vL/bt25fivbDv573OQ8ZrgrOD\nO4BMiwULFkRzc3McOXJk0uNHjhyJjo6OkrYq38UXXxxXXnllDA8Pl71KaX75379r490WL14cCxYs\nmLXXx7p16+KZZ56J5557LhYuXDjxeLZr4r3Ow5nM9muCs4cAZFqcf/75cf3118fg4ODEY6dOnYrB\nwcG48cYbS9ysXGNjYzE8PBydnZ1lr1KaRYsWRUdHx6Rro1arxZ49e1JfGxERhw8fjqNHj86666Mo\nili3bl08/fTT8Z3vfCcWLVo06fks18QHnYczma3XBGef5s2bN28uewlmh9bW1ti0aVN0d3dHpVKJ\nTZs2xcsvvxz/9E//FBdddFHZ6zXE5z73uahUKhER8cMf/jD+/M//PN5444144okn4sILLyx5u5kz\nNjYWP/zhD2NkZCT+8R//MXp6emLu3Lnx9ttvx8UXXxwnT56Mv/7rv47f/d3fjbfffjs++9nPxs9/\n/vP48pe/HC0ts+edKO93Hpqbm+Mv/uIvorW1NX7xi1/E3r174+67746LLroo/vZv/3ZWnYe+vr74\n2te+Ft/61reiq6srxsbGYmxsLJqbm+O8886LpqamFNfE
B52HsbGxNNcEZ6GyP4bM7PLlL3+5uPTS\nS4vzzz+/uOGGG4rdu3eXvVJD3XnnnUVnZ2dx/vnnF7/9279d3HnnncXw8HDZa8245557roiIdx1r\n164tiuL0X/uxadOmor29vahUKsVtt91W7N+/v9ylZ8D7nYef//znxfLly4tLLrmkOO+884rLLrus\nuOeee4qRkZGy1552ZzoHEVF85StfmXhNhmvig85DpmuCs09TURRFI4MTAIByeQ8gAEAyAhAAIBkB\nCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCS\nEYAAAMkIQACAZAQgAEAyAhAAIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMkIQACAZAQgAEAyAhAA\nIBkBCACQjAAEAEhGAAIAJCMAAQCSEYAAAMn8P6Oq2cbCfoIzAAAAAElFTkSuQmCC\n", 467 | "text/plain": [ 468 | "" 469 | ] 470 | }, 471 | "execution_count": 10, 472 | "metadata": {}, 473 | "output_type": "execute_result" 474 | } 475 | ], 476 | "source": [ 477 | "Image('7.png')" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 11, 483 | "metadata": {}, 484 | "outputs": [ 485 | { 486 | "name": "stdout", 487 | "output_type": "stream", 488 | "text": [ 489 | "(784,) -> (1, 784)\n", 490 | "predicted label: 7\n" 491 | ] 492 | } 493 | ], 494 | "source": [ 495 | "# Change the shape of the minibatch.\n", 496 | "# In this example, the size of minibatch is 1.\n", 497 | "# Inference using any mini-batch size can be performed.\n", 498 | "\n", 499 | "print(x.shape, end=' -> ')\n", 500 | "x = x[None, ...]\n", 501 | "print(x.shape)\n", 502 | "\n", 503 | "# forward calculation of the model by sending X\n", 504 | "y = model(x)\n", 505 | "\n", 506 | "# The result is given as Variable, then we can take a look at the contents by the attribute, .data.\n", 507 | "y = y.data\n", 508 | "\n", 509 | "# Look up the most probable digit number using argmax\n", 510 | "pred_label = y.argmax(axis=1)\n", 511 | "\n", 512 | "print('predicted label:', pred_label[0])" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "metadata": {}, 518 | "source": [ 519 | "**The prediction result looks correct. Yay!**" 520 | ] 521 | } 522 | ], 523 | "metadata": { 524 | "kernelspec": { 525 | "display_name": "Python 3", 526 | "language": "python", 527 | "name": "python3" 528 | }, 529 | "language_info": { 530 | "codemirror_mode": { 531 | "name": "ipython", 532 | "version": 3 533 | }, 534 | "file_extension": ".py", 535 | "mimetype": "text/x-python", 536 | "name": "python", 537 | "nbconvert_exporter": "python", 538 | "pygments_lexer": "ipython3", 539 | "version": "3.6.1" 540 | } 541 | }, 542 | "nbformat": 4, 543 | "nbformat_minor": 2 544 | } 545 | -------------------------------------------------------------------------------- /3_write_convnet_in_chainer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# How to write ConvNet models in Chainer\n", 8 | "\n", 9 | "In this notebook, you will learn how to write\n", 10 | "\n", 11 | "- A small convolutional network with a model class that is inherited from Chain,\n", 12 | "- A large convolutional network that has several building block networks with ChainList.\n", 13 | "\n", 14 | "A convolutional network (ConvNet) is mainly comprised of convolutional layers. 
This type of network is commonly used for various visual recognition tasks, e.g., classifying hand-written digits or natural images into given object classes, detecting objects from an image, and labeling all pixels of an image with the object classes (semantic segmentation), and so on.\n", 15 | "\n", 16 | "In such tasks, a typical ConvNet takes a set of images whose shape is $(N,C,H,W)$, where\n", 17 | "\n", 18 | "- $N$ denotes the number of images in a mini-batch,\n", 19 | "- $C$ denotes the number of channels of those images,\n", 20 | "- $H$ and $W$ denote the height and width of those images,\n", 21 | "\n", 22 | "respectively. Then, it typically outputs a fixed-sized vector as membership probabilities over the target object classes. It also can output a set of feature maps that have the corresponding size to the input image for a pixel labeling task, etc." 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "First, let's import the necessary packages for using Chainer." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": { 36 | "collapsed": true 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "import numpy as np\n", 41 | "import chainer\n", 42 | "from chainer import cuda, Function, gradient_check, report, training, utils, Variable\n", 43 | "from chainer import datasets, iterators, optimizers, serializers\n", 44 | "from chainer import Link, Chain, ChainList\n", 45 | "import chainer.functions as F\n", 46 | "import chainer.links as L\n", 47 | "from chainer.training import extensions" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "# LeNet5\n", 55 | "\n", 56 | "Here, let’s start by defining LeNet5 [[LeCun98]](#LeCun98) in Chainer. This is a ConvNet model that has 5 layers comprised of 3 convolutional layers and 2 fully-connected layers. This was proposed to classify hand-written digit images in 1998. In Chainer, the model can be written as follows:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": { 63 | "collapsed": true 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "class LeNet5(Chain):\n", 68 | " def __init__(self):\n", 69 | " super(LeNet5, self).__init__()\n", 70 | " with self.init_scope():\n", 71 | " self.conv1 = L.Convolution2D(\n", 72 | " in_channels=1, out_channels=6, ksize=5, stride=1)\n", 73 | " self.conv2 = L.Convolution2D(\n", 74 | " in_channels=6, out_channels=16, ksize=5, stride=1)\n", 75 | " self.conv3 = L.Convolution2D(\n", 76 | " in_channels=16, out_channels=120, ksize=4, stride=1)\n", 77 | " self.fc4 = L.Linear(None, 84)\n", 78 | " self.fc5 = L.Linear(84, 10)\n", 79 | "\n", 80 | " def __call__(self, x):\n", 81 | " h = F.sigmoid(self.conv1(x))\n", 82 | " h = F.max_pooling_2d(h, 2, 2)\n", 83 | " h = F.sigmoid(self.conv2(h))\n", 84 | " h = F.max_pooling_2d(h, 2, 2)\n", 85 | " h = F.sigmoid(self.conv3(h))\n", 86 | " h = F.sigmoid(self.fc4(h))\n", 87 | " if chainer.config.train:\n", 88 | " return self.fc5(h)\n", 89 | " return F.softmax(self.fc5(h))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "A typical way to write your network is creating a new class inherited from [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) class. 
When defining your model in this way, typically, all the layers which have trainable parameters are registered to the model by assigning the objects of [Link](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Link.html#chainer.Link) as an attribute.\n", 97 | "\n", 98 | "The model class is instantiated before the forward and backward computations. To give input images and label vectors simply by calling the model object like a function, `__call__()` is usually defined in the model class. This method performs the forward computation of the model. Chainer uses the powerful autograd system for any computational graphs written with [Function](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Function.html#chainer.Function)s and [Link](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Link.html#chainer.Link)s (actually a [Link](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Link.html#chainer.Link) calls a corresponding [Function](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Function.html#chainer.Function) inside of it), so that you don’t need to explicitly write the code for backward computations in the model. Just prepare the data, then give it to the model. The way this works is the resulting output [Variable](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable) from the forward computation has a [backward()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable.backward) method to perform autograd. In the above model, `__call__()` has a `if` statement at the end to switch its behavior by the Chainer’s running mode, i.e., training mode or not. Chainer presents the running mode as a global variable `chainer.config.train`. When it’s in training mode, `__call__()` returns the output value of the last layer as is to compute the loss later on, otherwise it returns a prediction result by calculating [softmax()](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.softmax.html#chainer.functions.softmax)." 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "### Ways to calculate loss\n", 106 | "\n", 107 | "When you train the model with label vector `t`, the loss should be calculated using the output from the model. There also are several ways to calculate the loss:" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": { 114 | "collapsed": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "model = LeNet5()\n", 119 | "\n", 120 | "# Input data and label\n", 121 | "x = np.random.rand(32, 1, 28, 28).astype(np.float32)\n", 122 | "t = np.random.randint(0, 10, size=(32,)).astype(np.int32)\n", 123 | "\n", 124 | "# Forward computation\n", 125 | "y = model(x)\n", 126 | "\n", 127 | "# Loss calculation\n", 128 | "loss = F.softmax_cross_entropy(y, t)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "This is a primitive way to calculate a loss value from the output of the model. On the other hand, the loss computation can be included in the model itself by wrapping the model object ([Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) or [ChainList](https://docs.chainer.org/en/latest/reference/core/generated/chainer.ChainList.html#chainer.ChainList) object) with a class inherited from Chain. 
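For example, such a wrapper might look like the following sketch (the class name `LossClassifier` is made up for illustration; it is a hand-rolled, simplified version of what the built-in `chainer.links.Classifier` introduced just below provides):

```python
class LossClassifier(Chain):
    """A toy wrapper that adds a loss computation to a predictor network."""

    def __init__(self, predictor):
        super(LossClassifier, self).__init__()
        with self.init_scope():
            # Registering the predictor here makes its parameters part of
            # this outer Chain, so optimizers and serializers can see them.
            self.predictor = predictor

    def __call__(self, x, t):
        y = self.predictor(x)
        return F.softmax_cross_entropy(y, t)

wrapped = LossClassifier(LeNet5())
loss = wrapped(x, t)  # forward pass and loss computation in a single call
```

This is essentially what the built-in wrapper does, minus extras such as accuracy reporting.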
The outer [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) should take the model defined above and register it with [init_scope()](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain.init_scope). [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) is actually inherited from Link, so that a [Chain](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Chain.html#chainer.Chain) itself can also be registered as a trainable [Link](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Link.html#chainer.Link) to another Chain. In fact, there is already a [Classifier](https://docs.chainer.org/en/latest/reference/generated/chainer.links.Classifier.html#chainer.links.Classifier) class that can be used to wrap the model and include the loss computation as well. It can be used like this:"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 4,
141 | "metadata": {
142 | "collapsed": true
143 | },
144 | "outputs": [],
145 | "source": [
146 | "model = L.Classifier(LeNet5())\n",
147 | "\n",
148 | "# Forward & Loss calculation\n",
149 | "loss = model(x, t)"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {},
155 | "source": [
156 | "This class takes a model object as an input argument and registers it to a `predictor` property as a trainable link. As shown above, the returned object can then be called like a function in which we pass `x` and `t` as the input arguments and the resulting loss value (which we recall is a [Variable](https://docs.chainer.org/en/latest/reference/core/generated/chainer.Variable.html#chainer.Variable)) is returned.\n",
157 | "\n",
158 | "See the detailed implementation of [Classifier](https://docs.chainer.org/en/latest/reference/generated/chainer.links.Classifier.html#chainer.links.Classifier) here: [chainer.links.Classifier](https://docs.chainer.org/en/latest/reference/generated/chainer.links.Classifier.html#chainer.links.Classifier), and check how it works by looking at the source.\n",
159 | "\n",
160 | "From the above examples, we can see that Chainer provides the flexibility to write our original network in many different ways. Such flexibility is intended to make it intuitive for users to design new and complex models."
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "# VGG16\n",
168 | "\n",
169 | "Next, let’s write some larger models in Chainer. When you write a large network consisting of several building-block networks, [ChainList](https://docs.chainer.org/en/latest/reference/core/generated/chainer.ChainList.html#chainer.ChainList) is useful. First, let’s see how to write a VGG16 [[Simonyan14]](#Simonyan14) model."
170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 5, 175 | "metadata": { 176 | "collapsed": true 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "class VGG16(chainer.ChainList):\n", 181 | "\n", 182 | " def __init__(self):\n", 183 | " w = chainer.initializers.HeNormal()\n", 184 | " super(VGG16, self).__init__(\n", 185 | " VGGBlock(64),\n", 186 | " VGGBlock(128),\n", 187 | " VGGBlock(256, 3),\n", 188 | " VGGBlock(512, 3),\n", 189 | " VGGBlock(512, 3, True))\n", 190 | "\n", 191 | " def __call__(self, x):\n", 192 | " for f in self.children():\n", 193 | " x = f(x)\n", 194 | " if chainer.config.train:\n", 195 | " return x\n", 196 | " return F.softmax(x)\n", 197 | "\n", 198 | "\n", 199 | "class VGGBlock(chainer.Chain):\n", 200 | " \n", 201 | " def __init__(self, n_channels, n_convs=2, fc=False):\n", 202 | " w = chainer.initializers.HeNormal()\n", 203 | " super(VGGBlock, self).__init__()\n", 204 | " with self.init_scope():\n", 205 | " self.conv1 = L.Convolution2D(None, n_channels, 3, 1, 1, initialW=w)\n", 206 | " self.conv2 = L.Convolution2D(\n", 207 | " n_channels, n_channels, 3, 1, 1, initialW=w)\n", 208 | " if n_convs == 3:\n", 209 | " self.conv3 = L.Convolution2D(\n", 210 | " n_channels, n_channels, 3, 1, 1, initialW=w)\n", 211 | " if fc:\n", 212 | " self.fc4 = L.Linear(None, 4096, initialW=w)\n", 213 | " self.fc5 = L.Linear(4096, 4096, initialW=w)\n", 214 | " self.fc6 = L.Linear(4096, 1000, initialW=w)\n", 215 | "\n", 216 | " self.n_convs = n_convs\n", 217 | " self.fc = fc\n", 218 | "\n", 219 | " def __call__(self, x):\n", 220 | " h = F.relu(self.conv1(x))\n", 221 | " h = F.relu(self.conv2(h))\n", 222 | " if self.n_convs == 3:\n", 223 | " h = F.relu(self.conv3(h))\n", 224 | " h = F.max_pooling_2d(h, 2, 2)\n", 225 | " if self.fc:\n", 226 | " h = F.dropout(F.relu(self.fc4(h)))\n", 227 | " h = F.dropout(F.relu(self.fc5(h)))\n", 228 | " h = self.fc6(h)\n", 229 | " return h" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "That’s it. VGG16 is a model which won the 1st place in [classification + localization task at ILSVRC 2014](http://www.image-net.org/challenges/LSVRC/2014/results#clsloc), and since then, has become one of the standard models for many different tasks as a pre-trained model. This has 16-layers, so it’s called “VGG-16”, but we can write this model without writing all layers independently. Since this model consists of several building blocks that have the same architecture, we can build the whole network by re-using the building block definition. Each part of the network is consisted of 2 or 3 convolutional layers and activation function ([relu()](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.relu.html#chainer.functions.relu)) following them, and [max_pooling_2d()](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.max_pooling_2d.html#chainer.functions.max_pooling_2d) operations. This block is written as VGGBlock in the above example code. And the whole network just calls this block one by one in sequential manner." 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "# ResNet152\n", 244 | "\n", 245 | "How about ResNet? ResNet [[He16]](#He16) came in the following year’s ILSVRC. It is a much deeper model than VGG16, having up to 152 layers. This sounds super laborious to build, but it can be implemented in almost same manner as VGG16. In the other words, it’s easy. 
One possible way to write ResNet-152 is:" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "collapsed": true 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "class ResNet152(chainer.Chain):\n", 257 | " \n", 258 | " def __init__(self, n_blocks=[3, 8, 36, 3]):\n", 259 | " w = chainer.initializers.HeNormal()\n", 260 | " super(ResNet152, self).__init__(\n", 261 | " conv1=L.Convolution2D(\n", 262 | " None, 64, 7, 2, 3, initialW=w, nobias=True),\n", 263 | " bn1=L.BatchNormalization(64),\n", 264 | " res2=ResBlock(n_blocks[0], 64, 64, 256, 1),\n", 265 | " res3=ResBlock(n_blocks[1], 256, 128, 512),\n", 266 | " res4=ResBlock(n_blocks[2], 512, 256, 1024),\n", 267 | " res5=ResBlock(n_blocks[3], 1024, 512, 2048),\n", 268 | " fc6=L.Linear(2048, 1000))\n", 269 | "\n", 270 | " def __call__(self, x):\n", 271 | " h = self.bn1(self.conv1(x))\n", 272 | " h = F.max_pooling_2d(F.relu(h), 2, 2)\n", 273 | " h = self.res2(h)\n", 274 | " h = self.res3(h)\n", 275 | " h = self.res4(h)\n", 276 | " h = self.res5(h)\n", 277 | " h = F.average_pooling_2d(h, h.shape[2:], stride=1)\n", 278 | " h = self.fc6(h)\n", 279 | " if chainer.config.train:\n", 280 | " return h\n", 281 | " return F.softmax(h)\n", 282 | "\n", 283 | "\n", 284 | "class ResBlock(chainer.ChainList):\n", 285 | " \n", 286 | " def __init__(self, n_layers, n_in, n_mid, n_out, stride=2):\n", 287 | " w = chainer.initializers.HeNormal()\n", 288 | " super(ResBlock, self).__init__()\n", 289 | " self.add_link(BottleNeck(n_in, n_mid, n_out, stride, True))\n", 290 | " for _ in range(n_layers - 1):\n", 291 | " self.add_link(BottleNeck(n_out, n_mid, n_out))\n", 292 | "\n", 293 | " def __call__(self, x):\n", 294 | " for f in self.children():\n", 295 | " x = f(x)\n", 296 | " return x\n", 297 | "\n", 298 | "\n", 299 | "class BottleNeck(chainer.Chain):\n", 300 | " \n", 301 | " def __init__(self, n_in, n_mid, n_out, stride=1, proj=False):\n", 302 | " w = chainer.initializers.HeNormal()\n", 303 | " super(BottleNeck, self).__init__()\n", 304 | " with self.init_scope():\n", 305 | " self.conv1x1a = L.Convolution2D(\n", 306 | " n_in, n_mid, 1, stride, 0, initialW=w, nobias=True)\n", 307 | " self.conv3x3b = L.Convolution2D(\n", 308 | " n_mid, n_mid, 3, 1, 1, initialW=w, nobias=True)\n", 309 | " self.conv1x1c = L.Convolution2D(\n", 310 | " n_mid, n_out, 1, 1, 0, initialW=w, nobias=True)\n", 311 | " self.bn_a = L.BatchNormalization(n_mid)\n", 312 | " self.bn_b = L.BatchNormalization(n_mid)\n", 313 | " self.bn_c = L.BatchNormalization(n_out)\n", 314 | " if proj:\n", 315 | " self.conv1x1r = L.Convolution2D(\n", 316 | " n_in, n_out, 1, stride, 0, initialW=w, nobias=True)\n", 317 | " self.bn_r = L.BatchNormalization(n_out)\n", 318 | " self.proj = proj\n", 319 | "\n", 320 | " def __call__(self, x):\n", 321 | " h = F.relu(self.bn_a(self.conv1x1a(x)))\n", 322 | " h = F.relu(self.bn_b(self.conv3x3b(h)))\n", 323 | " h = self.bn_c(self.conv1x1c(h))\n", 324 | " if self.proj:\n", 325 | " x = self.bn_r(self.conv1x1r(x))\n", 326 | " return F.relu(h + x)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "In the BottleNeck class, depending on the value of the proj argument supplied to the initializer, it will conditionally compute a convolutional layer `conv1x1r` which will extend the number of channels of the input `x` to be equal to the number of channels of the output of `conv1x1c`, and followed by a batch normalization layer before the final ReLU layer. 
Writing the building block in this way improves the re-usability of a class. It switches not only the behavior in `__call__()` by flags but also the parameter registration. In this case, when `proj` is `False`, the `BottleNeck` doesn’t have the `conv1x1r` and `bn_r` layers, so the memory usage is more efficient than registering both anyway and simply ignoring them when `proj` is `False`.\n",
334 | "\n",
335 | "Using nested `Chain`s and `ChainList` for the sequential parts enables us to write complex and very deep models easily."
336 | ]
337 | },
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {},
341 | "source": [
342 | "# Use Pre-trained Models\n",
343 | "\n",
344 | "Various ways to write your models were described above. It turns out that VGG16 and ResNet are very useful as general feature extractors for many kinds of tasks, including but not limited to image classification. So, Chainer provides you with the pre-trained VGG16 and ResNet-50/101/152 models with a simple API. You can use these models as follows:"
345 | ]
346 | },
347 | {
348 | "cell_type": "code",
349 | "execution_count": null,
350 | "metadata": {},
351 | "outputs": [
352 | {
353 | "name": "stdout",
354 | "output_type": "stream",
355 | "text": [
356 | "Downloading from http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel...\n",
357 | "Now loading caffemodel (usually it may take few minutes)\n"
358 | ]
359 | }
360 | ],
361 | "source": [
362 | "from chainer.links import VGG16Layers\n",
363 | "\n",
364 | "model = VGG16Layers()"
365 | ]
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "metadata": {},
370 | "source": [
371 | "When [VGG16Layers](https://docs.chainer.org/en/latest/reference/generated/chainer.links.VGG16Layers.html#chainer.links.VGG16Layers) is instantiated, the pre-trained parameters are automatically downloaded from the author’s server. So you can immediately start to use VGG16 with pre-trained weights as a good image feature extractor. See the details of this model here: [chainer.links.VGG16Layers](https://docs.chainer.org/en/latest/reference/generated/chainer.links.VGG16Layers.html#chainer.links.VGG16Layers)."
372 | ]
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {},
377 | "source": [
378 | "In the case of ResNet models, there are three variations differing in the number of layers. We have `chainer.links.ResNet50Layers`, `chainer.links.ResNet101Layers`, and `chainer.links.ResNet152Layers` models with an easy parameter-loading feature. ResNet’s pre-trained parameters are not available for direct downloading, so you need to download the weights from the author’s web page first, and then place them in the directory `$CHAINER_DATASET_ROOT/pfnet/chainer/models` or any other location you prefer. Once the preparation is finished, the usage is the same as VGG16:\n",
379 | "\n",
380 | "```\n",
381 | "from chainer.links import ResNet152Layers\n",
382 | "\n",
383 | "model = ResNet152Layers()\n",
384 | "```\n",
385 | "\n",
386 | "Please see the details of usage and how to prepare the pre-trained weights for ResNet here: [chainer.links.ResNet50](https://docs.chainer.org/en/latest/reference/generated/chainer.links.ResNet50Layers.html#chainer.links.ResNet50Layers)"
387 | ]
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "metadata": {},
392 | "source": [
393 | "# References\n",
394 | "\n",
395 | "\n",
396 | "[LeCun98]\tYann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. 
Proceedings of the IEEE, 86(11), 2278–2324, 1998.\n", 397 | "\n", 398 | "\n", 399 | "[Simonyan14]\tSimonyan, K. and Zisserman, A., Very Deep Convolutional Networks for Large-Scale Image Recognition. arXiv preprint arXiv:1409.1556, 2014.\n", 400 | "\n", 401 | "\n", 402 | "[He16]\tKaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770-778, 2016.\n", 403 | "Next Previous\n" 404 | ] 405 | } 406 | ], 407 | "metadata": { 408 | "kernelspec": { 409 | "display_name": "Python 3", 410 | "language": "python", 411 | "name": "python3" 412 | }, 413 | "language_info": { 414 | "codemirror_mode": { 415 | "name": "ipython", 416 | "version": 3 417 | }, 418 | "file_extension": ".py", 419 | "mimetype": "text/x-python", 420 | "name": "python", 421 | "nbconvert_exporter": "python", 422 | "pygments_lexer": "ipython3", 423 | "version": "3.6.1" 424 | } 425 | }, 426 | "nbformat": 4, 427 | "nbformat_minor": 2 428 | } 429 | -------------------------------------------------------------------------------- /4_RNN-language-model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Write an RNN Language Model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "0. Introduction\n", 15 | "================\n", 16 | "\n", 17 | "The **language model** is modeling the probability of generating natural language sentences or documents. You can use the language model to estimate how natural a sentence or a document is. Also, with the language model, you can generate new sentences or documents.\n", 18 | "\n", 19 | "Let's start with modeling the probability of generating sentences. We represent a sentence as ${\\bf X} = ({\\bf x}_0, {\\bf x}_1, \\dots, {\\bf x}_T)$, in which ${\\bf x}_t$ is a one-hot vector. Generally, ${\\bf x}_0$ is the one-hot vector of **BOS** (beginning of sentence), and ${\\bf x}_T$ is that of **EOS** (end of sentence).\n", 20 | "\n", 21 | "A language model models the probability of a word occurance under the condition of its previous words in a sentence. Let ${\\bf X}_{[i, j]}$ be $({\\bf x}_i, {\\bf x}_{i+1}, \\dots, {\\bf x}_j)$ , the occurrence probability of sentence ${\\bf X}$ can be represented as follows:\n", 22 | "\n", 23 | "$$P({\\bf X}) = P({\\bf x}_0) \\prod_{t=1}^T P({\\bf x}_t \\mid {\\bf X}_{[0, t-1]})$$\n", 24 | "\n", 25 | "So, the language model $P({\\bf X})$ can be decomposed into word probabilities conditioned with its previous words.\n", 26 | "\n", 27 | "In this notebook, we model $P({\\bf x}_t \\mid {\\bf X}_{[0, t-1]})$ with a recurrent neural network to obtain a language model $P({\\bf X})$." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "1. Basic Idea of Recurrent Neural Net Language Model\n", 35 | "=====================================================\n", 36 | "\n", 37 | "1.1 Recurrent Neural Net Language Model\n", 38 | "---------------------------------------\n", 39 | "\n", 40 | "**Recurrent Neurral Net Language Model** (RNNLM) is a type of neural net language models which contains RNNs in the network. 
Since an RNN can deal with variable-length inputs, it is suitable for modeling sequential data such as sentences in natural language.\n",
41 | "\n",
42 | "We show one layer of an RNNLM with these parameters.\n",
43 | "\n",
44 | "| Symbol | Definition |\n",
45 | "|-------:|:-----------|\n",
46 | "| ${\\bf x}_t$ | the one-hot vector of the $t$-th word |\n",
47 | "| ${\\bf y}_t$ | the $t$-th output |\n",
48 | "| ${\\bf h}_t^{(i)}$ | the $t$-th hidden vector of the $i$-th layer |\n",
49 | "| ${\\bf p}_t$ | the predicted probability distribution over the word following the $t$-th word |\n",
50 | "| ${\\bf E}$ | Embedding matrix |\n",
51 | "| ${\\bf W}_h$ | Hidden layer matrix |\n",
52 | "| ${\\bf W}_o$ | Output layer matrix |\n",
53 | "\n",
54 | "\n",
55 | "![rnnlm](rnnlm.png)\n",
56 | "\n",
57 | "**The process to get a next-word prediction from the $t$-th input word ${\\bf x}_t$**\n",
58 | "\n",
59 | "1. Get the embedding vector: ${\\bf h}_t^{(0)} = {\\bf E} {\\bf x}_t$\n",
60 | "2. Calculate the hidden layer: ${\\bf h}_t^{(1)} = {\\rm tanh}\\left({\\bf W}_h \\left[ \\begin{array}{cc} {\\bf h}_t^{(0)} \\\\ {\\bf h}_{t-1}^{(1)} \\end{array} \\right]\\right)$\n",
61 | "3. Calculate the output layer: ${\\bf y}_t = {\\bf W}_o {\\bf h}_t^{(1)}$\n",
62 | "4. Transform to probability: ${\\bf p}_t = {\\rm softmax}({\\bf y}_t)$\n",
63 | "\n",
64 | "- Note that ${\\rm tanh}$ in the above equation is applied to the input vector in an element-wise manner.\n",
65 | "- Note that $\\left[ \\begin{array}{cc} {\\bf a} \\\\ {\\bf b} \\end{array} \\right]$ denotes a concatenated vector of ${\\bf a}$ and ${\\bf b}$.\n",
66 | "- Note that ${\\rm softmax}$ in the above equation converts an arbitrary real vector to a probability vector whose elements sum to $1$."
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "1.2 Perplexity (Evaluation of the language model)\n",
74 | "-----------------------------------------------\n",
75 | "\n",
76 | "**Perplexity** is a common evaluation metric for a language model. Generally, it measures how well the proposed probability model $P_{\\rm model}({\\bf X})$ represents the target data $P^*({\\bf X})$.\n",
77 | "\n",
78 | "Let a validation dataset be $D = \\{{\\bf X}^{(n)}\\}_{n=1}^{|D|}$, a set of sentences in which the $n$-th sentence has length $T^{(n)}$. The perplexity is then represented as follows:\n",
79 | "\n",
80 | "$$\n",
81 | "\\begin{eqnarray}\n",
82 | "&& b^z \\\\\n",
83 | "&& s.t.~ ~ ~ z = - \\frac{1}{\\sum_{n=1}^{|D|} T^{(n)}}\n",
84 | "\\sum_{n=1}^{|D|} \\sum_{t=1}^{T^{(n)}} \\log_b P_{\\rm model}({\\bf x}_t^{(n)} \\mid {\\bf X}_{[0, t-1]}^{(n)})\n",
85 | "\\end{eqnarray}\n",
86 | "$$\n",
87 | "\n",
88 | "We usually use $b = 2$ or $b = e$. The perplexity shows how spread out the predicted distribution over the next word is. When a language model represents the dataset well, it should assign a high probability to the correct next word, so the cross entropy $z$ is small and the perplexity $b^z$ is small as well. In other words, a smaller perplexity means a better model."
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "During training, we minimize the following cross entropy:\n",
96 | "\n",
97 | "$$\n",
98 | "\\mathcal{H}(\\hat{P}, P_{\\rm model}) = - \\sum_{\\bf X} \\hat{P}({\\bf X}) \\log P_{\\rm model}({\\bf X})\n",
99 | "$$\n",
100 | "\n",
101 | "where $\\hat P$ is the empirical distribution of the sequences in the training dataset."
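To make the relation between this cross entropy and the perplexity concrete, here is a minimal NumPy sketch with made-up probability values (the numbers are purely illustrative and are not part of the PTB example):

```python
import numpy as np

# Hypothetical probabilities that a model assigns to the correct next word
# at five positions of a validation sequence.
p_correct = np.array([0.2, 0.1, 0.05, 0.3, 0.25])

# z is the average negative log-likelihood per word (cross entropy with b = e).
z = -np.mean(np.log(p_correct))

# The perplexity is b ** z; with b = e this is simply exp(z).
perplexity = np.exp(z)
print(perplexity)  # ~6.7: on average the model is about as uncertain as a
                   # uniform choice among ~7 words
```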
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 | "2. Implementation of Recurrent Neural Net Language Model\n",
109 | "=========================================================\n",
110 | "\n",
111 | "**There is an example of an RNN language model in the official repository, so we will explain how to implement an RNNLM in Chainer based on that: [chainer/examples/ptb](https://github.com/chainer/chainer/tree/master/examples/ptb)**\n",
112 | "\n",
113 | "2.1 Model overview\n",
114 | "-----------------\n",
115 | "\n",
116 | "![rnnlm_example](rnnlm_example.png)\n",
117 | "\n",
118 | "The RNNLM used in this notebook is depicted in the above figure. The symbols that appear in the figure are defined as follows:\n",
119 | "\n",
120 | "| Symbol | Definition |\n",
121 | "|-------:|:-----------|\n",
122 | "| ${\\bf x}_t$ | the one-hot vector of the $t$-th input |\n",
123 | "| ${\\bf y}_t$ | the $t$-th output |\n",
124 | "| ${\\bf h}_t^{(i)}$ | the $t$-th hidden vector of the $i$-th layer |\n",
125 | "| ${\\bf E}$ | Embedding matrix |\n",
126 | "| ${\\bf W}_o$ | Output layer matrix |\n",
127 | "\n",
128 | "**LSTMs** (long short-term memory) are used for the connections between hidden layers. An LSTM is one of the major recurrent neural net modules. It is designed to remember long-term dependencies, so it can capture relationships between distant words, such as a word at the beginning of a sentence and one at the end. We also use **Dropout** before both the LSTMs and the linear transformations. Dropout is a regularization technique for preventing overfitting on the training dataset."
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "2.2 Step-by-step implementation\n",
136 | "-----------------------------"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "### 2.2.1 Import Package\n",
144 | "\n",
145 | "First, let's import the necessary packages."
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 1,
151 | "metadata": {
152 | "collapsed": true
153 | },
154 | "outputs": [],
155 | "source": [
156 | "import chainer\n",
157 | "import chainer.functions as F\n",
158 | "import chainer.links as L\n",
159 | "from chainer import training\n",
160 | "from chainer.training import extensions\n",
161 | "import numpy as np"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "### 2.2.2 Define training settings\n",
169 | "\n",
170 | "Define all training settings here for ease of reference."
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": 2,
176 | "metadata": {
177 | "collapsed": true
178 | },
179 | "outputs": [],
180 | "source": [
181 | "batchsize = 20\n",
182 | "bproplen = 35\n",
183 | "epoch = 39\n",
184 | "gpu = 0  # negative value to run on CPU\n",
185 | "gradclip = 5\n",
186 | "unit = 650\n",
187 | "update_interval = 10"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "### 2.2.3 Define Network Structure\n",
195 | "\n",
196 | "An RNNLM written in Chainer is shown below. It implements the model depicted in the above figure."
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 3,
202 | "metadata": {
203 | "collapsed": true
204 | },
205 | "outputs": [],
206 | "source": [
207 | "class RNNLM(chainer.Chain):\n",
208 | "\n",
209 | "    def __init__(self, n_vocab, n_units):\n",
210 | "        super(RNNLM, self).__init__()\n",
211 | "        with self.init_scope():\n",
212 | "            self.embed = L.EmbedID(n_vocab, n_units)\n",
213 | "            self.l1 = L.LSTM(n_units, n_units)\n",
214 | "            self.l2 = L.LSTM(n_units, n_units)\n",
215 | "            self.l3 = L.Linear(n_units, n_vocab)\n",
216 | "\n",
217 | "        for param in self.params():\n",
218 | "            param.data[...] = np.random.uniform(-0.1, 0.1, param.data.shape)\n",
219 | "\n",
220 | "    def reset_state(self):\n",
221 | "        self.l1.reset_state()\n",
222 | "        self.l2.reset_state()\n",
223 | "\n",
224 | "    def __call__(self, x):\n",
225 | "        h0 = self.embed(x)\n",
226 | "        h1 = self.l1(F.dropout(h0))\n",
227 | "        h2 = self.l2(F.dropout(h1))\n",
228 | "        y = self.l3(F.dropout(h2))\n",
229 | "        return y"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "metadata": {},
235 | "source": [
236 | "- When we instantiate this class to make a model, we give the vocabulary size to `n_vocab` and the size of the hidden vectors to `n_units`.\n",
237 | "- This network uses `chainer.links.LSTM`, `chainer.links.Linear`, and `chainer.functions.dropout` as its building blocks.\n",
238 | "- All the layers are registered and initialized in the context with `self.init_scope()`.\n",
239 | "- You can access all the parameters in those layers by calling `self.params()`.\n",
240 | "- In the constructor, it initializes all parameters with values sampled from a uniform distribution $U(-0.1, 0.1)$.\n",
241 | "- The `__call__` method takes a word ID `x`, calculates the probability vector for the next word by forwarding `x` through the network, and returns the output.\n",
242 | "- Note that the word ID `x` is automatically converted to a $|\\mathcal{V}|$-dimensional one-hot vector and then multiplied with the input embedding matrix in `self.embed(x)` to obtain an embedding vector `h0` at the first line of `__call__`."
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "metadata": {},
248 | "source": [
249 | "### 2.2.4 Load the Penn Tree Bank long word sequence dataset\n",
250 | "\n",
251 | "In this notebook, we use the [Penn Tree Bank](https://www.cis.upenn.edu/~treebank/) dataset, which contains a number of sentences. Chainer provides a utility function to obtain this dataset from the server and convert it to a single long sequence of word IDs. 
`chainer.datasets.get_ptb_words()` actually returns three separated datasets which are for train, validation, and test.\n", 252 | "\n", 253 | "Let's download and make dataset objects using it:" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 4, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Downloading from https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt...\n", 266 | "Downloading from https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt...\n", 267 | "Downloading from https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt...\n" 268 | ] 269 | } 270 | ], 271 | "source": [ 272 | "train, val, test = chainer.datasets.get_ptb_words()\n", 273 | "n_vocab = max(train) + 1 " 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "### 2.2.5 Define Iterator for making a mini-batch from the dataset\n", 281 | "\n", 282 | "Dataset iterator creates a mini-batch of couple of words at different positions, namely, pairs of current word and its next word. Each example is a part of sentences starting from different offsets equally spaced within the whole sequence." 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 5, 288 | "metadata": { 289 | "collapsed": true 290 | }, 291 | "outputs": [], 292 | "source": [ 293 | "class ParallelSequentialIterator(chainer.dataset.Iterator):\n", 294 | "\n", 295 | " def __init__(self, dataset, batch_size, repeat=True):\n", 296 | " self.dataset = dataset\n", 297 | " \n", 298 | " # batch size\n", 299 | " self.batch_size = batch_size\n", 300 | " \n", 301 | " # Number of completed sweeps over the dataset. In this case, it is\n", 302 | " # incremented if every word is visited at least once after the last\n", 303 | " # increment.\n", 304 | " self.epoch = 0\n", 305 | " \n", 306 | " # True if the epoch is incremented at the last iteration.\n", 307 | " self.is_new_epoch = False\n", 308 | " self.repeat = repeat\n", 309 | " length = len(dataset)\n", 310 | " \n", 311 | " # Offsets maintain the position of each sequence in the mini-batch.\n", 312 | " self.offsets = [i * length // batch_size for i in range(batch_size)]\n", 313 | " \n", 314 | " # NOTE: this is not a count of parameter updates. It is just a count of\n", 315 | " # calls of `__next__`.\n", 316 | " self.iteration = 0\n", 317 | " \n", 318 | " # use -1 instead of None internally\n", 319 | " self._previous_epoch_detail = -1.\n", 320 | "\n", 321 | " def __next__(self):\n", 322 | " # This iterator returns a list representing a mini-batch. Each item\n", 323 | " # indicates a different position in the original sequence. Each item is\n", 324 | " # represented by a pair of two word IDs. 
The first word is at the\n", 325 | " # \"current\" position, while the second word is at the next position.\n", 326 | " # At each iteration, the iteration count is incremented, which pushes\n", 327 | " # forward the \"current\" position.\n", 328 | " length = len(self.dataset)\n", 329 | " if not self.repeat and self.iteration * self.batch_size >= length:\n", 330 | " # If not self.repeat, this iterator stops at the end of the first\n", 331 | " # epoch (i.e., when all words are visited once).\n", 332 | " raise StopIteration\n", 333 | " cur_words = self.get_words()\n", 334 | " self._previous_epoch_detail = self.epoch_detail\n", 335 | " self.iteration += 1\n", 336 | " next_words = self.get_words()\n", 337 | "\n", 338 | " epoch = self.iteration * self.batch_size // length\n", 339 | " self.is_new_epoch = self.epoch < epoch\n", 340 | " if self.is_new_epoch:\n", 341 | " self.epoch = epoch\n", 342 | "\n", 343 | " return list(zip(cur_words, next_words))\n", 344 | "\n", 345 | " @property\n", 346 | " def epoch_detail(self):\n", 347 | " # Floating point version of epoch.\n", 348 | " return self.iteration * self.batch_size / len(self.dataset)\n", 349 | "\n", 350 | " @property\n", 351 | " def previous_epoch_detail(self):\n", 352 | " if self._previous_epoch_detail < 0:\n", 353 | " return None\n", 354 | " return self._previous_epoch_detail\n", 355 | "\n", 356 | " def get_words(self):\n", 357 | " # It returns a list of current words.\n", 358 | " return [self.dataset[(offset + self.iteration) % len(self.dataset)]\n", 359 | " for offset in self.offsets]\n", 360 | "\n", 361 | " def serialize(self, serializer):\n", 362 | " # It is important to serialize the state to be recovered on resume.\n", 363 | " self.iteration = serializer('iteration', self.iteration)\n", 364 | " self.epoch = serializer('epoch', self.epoch)\n", 365 | " try:\n", 366 | " self._previous_epoch_detail = serializer(\n", 367 | " 'previous_epoch_detail', self._previous_epoch_detail)\n", 368 | " except KeyError:\n", 369 | " # guess previous_epoch_detail for older version\n", 370 | " self._previous_epoch_detail = self.epoch + \\\n", 371 | " (self.current_position - self.batch_size) / len(self.dataset)\n", 372 | " if self.epoch_detail > 0:\n", 373 | " self._previous_epoch_detail = max(\n", 374 | " self._previous_epoch_detail, 0.)\n", 375 | " else:\n", 376 | " self._previous_epoch_detail = -1." 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "### 2.2.6 Define Updater\n", 384 | "\n", 385 | "We use Backpropagation through time (BPTT) for optimize the RNNLM. BPTT can be implemented by overriding `update_core()` method of `StandardUpdater`. First, in the constructor of the `BPTTUpdater`, it takes `bprop_len` as an argument in addiotion to other arguments `StandardUpdater` needs. `bprop_len` defines the length of sequence $T$ to calculate the loss:\n", 386 | "\n", 387 | "$$\n", 388 | "\\mathcal{L} = - \\sum_{t=0}^T \\sum_{n=1}^{|\\mathcal{V}|}\n", 389 | "\\hat{P}({\\bf x}_{t+1}^{(n)})\n", 390 | "\\log\n", 391 | "P_{\\rm model}({\\bf x}_{t+1}^{(n)} \\mid {\\bf x}_t^{(n)})\n", 392 | "$$\n", 393 | "\n", 394 | "where $\\hat{P}({\\bf x}_t^n)$ is a probability for $n$-th word in the vocabulary at the position $t$ in the training data sequence." 
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": 6,
400 | "metadata": {
401 | "collapsed": true
402 | },
403 | "outputs": [],
404 | "source": [
405 | "class BPTTUpdater(training.StandardUpdater):\n",
406 | "\n",
407 | "    def __init__(self, train_iter, optimizer, bprop_len, device):\n",
408 | "        super(BPTTUpdater, self).__init__(\n",
409 | "            train_iter, optimizer, device=device)\n",
410 | "        self.bprop_len = bprop_len\n",
411 | "\n",
412 | "    # The core part of the update routine can be customized by overriding.\n",
413 | "    def update_core(self):\n",
414 | "        loss = 0\n",
415 | "        # When we pass one iterator and optimizer to StandardUpdater.__init__,\n",
416 | "        # they are automatically named 'main'.\n",
417 | "        train_iter = self.get_iterator('main')\n",
418 | "        optimizer = self.get_optimizer('main')\n",
419 | "\n",
420 | "        # Progress the dataset iterator for bprop_len words at each iteration.\n",
421 | "        for i in range(self.bprop_len):\n",
422 | "            # Get the next batch (a list of tuples of two word IDs)\n",
423 | "            batch = train_iter.__next__()\n",
424 | "\n",
425 | "            # Concatenate the word IDs to matrices and send them to the device\n",
426 | "            # self.converter does this job\n",
427 | "            # (it is chainer.dataset.concat_examples by default)\n",
428 | "            x, t = self.converter(batch, self.device)\n",
429 | "\n",
430 | "            # Compute the loss at this time step and accumulate it\n",
431 | "            loss += optimizer.target(chainer.Variable(x), chainer.Variable(t))\n",
432 | "\n",
433 | "        optimizer.target.cleargrads()  # Clear the parameter gradients\n",
434 | "        loss.backward()  # Backprop\n",
435 | "        loss.unchain_backward()  # Truncate the graph\n",
436 | "        optimizer.update()  # Update the parameters"
437 | ]
438 | },
439 | {
440 | "cell_type": "markdown",
441 | "metadata": {},
442 | "source": [
443 | "### 2.2.7 Define Evaluation Function (Perplexity)\n",
444 | "\n",
445 | "Define a function to calculate the perplexity from the loss value. If we take $e$ as $b$ in the above definition of perplexity, calculating the perplexity is just raising $e$ to the power of the loss value:"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 7,
451 | "metadata": {
452 | "collapsed": true
453 | },
454 | "outputs": [],
455 | "source": [
456 | "def compute_perplexity(result):\n",
457 | "    result['perplexity'] = np.exp(result['main/loss'])\n",
458 | "    if 'validation/main/loss' in result:\n",
459 | "        result['val_perplexity'] = np.exp(result['validation/main/loss'])"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "metadata": {},
465 | "source": [
466 | "### 2.2.8 Create iterators\n",
467 | "\n",
468 | "Here, the code below just creates iterator objects from the dataset splits (train/val/test)."
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": 8,
474 | "metadata": {
475 | "collapsed": true
476 | },
477 | "outputs": [],
478 | "source": [
479 | "train_iter = ParallelSequentialIterator(train, batchsize)\n",
480 | "val_iter = ParallelSequentialIterator(val, 1, repeat=False)\n",
481 | "test_iter = ParallelSequentialIterator(test, 1, repeat=False)"
482 | ]
483 | },
484 | {
485 | "cell_type": "markdown",
486 | "metadata": {},
487 | "source": [
488 | "### 2.2.9 Create RNN and classification model\n",
489 | "\n",
490 | "Instantiate the RNNLM model and wrap it with `L.Classifier`, because the wrapper calculates the softmax cross entropy as the loss."
491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 9, 496 | "metadata": { 497 | "collapsed": true 498 | }, 499 | "outputs": [], 500 | "source": [ 501 | "rnn = RNNLM(n_vocab, unit)\n", 502 | "model = L.Classifier(rnn)\n", 503 | "model.compute_accuracy = False # we only want the perplexity" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": {}, 509 | "source": [ 510 | "Note that `chainer.links.Classifier` computes not only the loss but also accuracy based on a given input/label pair. To learn the RNN language model, we only need the loss (cross entropy) in the `Classifier` because we calculate the perplexity instead of classification accuracy to check the performance of the model. So, we turn off computing the accuracy by giving `False` to `model.compute_accuracy` attribute." 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "### 2.2.10 Setup optimizer\n", 518 | "\n", 519 | "Prepare an optimizer. Here, we use `GradientClipping` to prevent gradient explosion. It automatically clip the gradient to be used to update the parameters in the model with given constant `gradclip`." 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 10, 525 | "metadata": { 526 | "collapsed": true 527 | }, 528 | "outputs": [], 529 | "source": [ 530 | "optimizer = chainer.optimizers.SGD(lr=1.0)\n", 531 | "optimizer.setup(model)\n", 532 | "optimizer.add_hook(chainer.optimizer.GradientClipping(gradclip))" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "### 2.2.11 Setup and run trainer\n", 540 | "\n", 541 | "Let's make an `trainer` object and start the training! Note that we add an `eval_hook` to the `Evaluator` extension to reset the internal states before starting evaluation process. It can prevent to use training data during evaluating the model." 
542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 11, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "epoch iteration perplexity val_perplexity\n", 554 | "\u001b[J1 1328 363.995 200.325 \n", 555 | "\u001b[J2 2656 179.979 148.02 \n", 556 | "\u001b[J3 3984 140.361 127.294 \n", 557 | "\u001b[J4 5312 119.012 114.39 \n", 558 | "\u001b[J5 6640 106.444 106.898 \n", 559 | "\u001b[J6 7968 97.9365 101.867 \n", 560 | "\u001b[J7 9296 90.0965 97.8341 \n", 561 | "\u001b[J8 10624 85.1498 95.9942 \n", 562 | "\u001b[J9 11952 81.3766 93.6086 \n", 563 | "\u001b[J10 13280 76.9708 92.1757 \n", 564 | "\u001b[J11 14608 70.8522 91.2738 \n", 565 | "\u001b[J12 15936 69.6306 90.0212 \n", 566 | "\u001b[J13 17264 69.8435 89.401 \n", 567 | "\u001b[J14 18592 66.0797 88.7216 \n", 568 | "\u001b[J15 19920 62.5617 88.9686 \n", 569 | "\u001b[J16 21248 60.855 88.4791 \n", 570 | "\u001b[J17 22576 59.2166 87.8589 \n", 571 | "\u001b[J18 23904 59.2627 88.0755 \n", 572 | "\u001b[J19 25232 60.0981 88.7103 \n", 573 | "\u001b[J20 26560 58.5949 87.7052 \n", 574 | "\u001b[J21 27888 56.1726 88.2218 \n", 575 | "\u001b[J22 29216 54.784 87.9091 \n", 576 | "\u001b[J23 30544 53.9637 87.8705 \n", 577 | "\u001b[J24 31872 53.4845 87.7428 \n", 578 | "\u001b[J25 33200 52.7459 89.1088 \n", 579 | "\u001b[J26 34528 53.4468 88.2932 \n", 580 | "\u001b[J27 35856 52.4166 88.5804 \n", 581 | "\u001b[J28 37184 50.5044 89.2111 \n", 582 | "\u001b[J29 38512 50.4532 89.0317 \n", 583 | "\u001b[J30 39840 50.5737 88.8173 \n", 584 | "\u001b[J31 41168 50.7383 88.2683 \n", 585 | "\u001b[J32 42496 49.8793 89.4104 \n", 586 | "\u001b[J33 43824 48.5711 89.0279 \n", 587 | "\u001b[J34 45152 49.241 89.2004 \n", 588 | "\u001b[J35 46480 47.9883 89.4998 \n", 589 | "\u001b[J36 47808 47.1426 90.068 \n", 590 | "\u001b[J37 49136 46.4937 90.2983 \n", 591 | "\u001b[J38 50464 44.9296 90.5085 \n", 592 | "\u001b[J39 51792 47.4697 89.5714 \n" 593 | ] 594 | } 595 | ], 596 | "source": [ 597 | "updater = BPTTUpdater(train_iter, optimizer, bproplen, gpu)\n", 598 | "trainer = training.Trainer(updater, (epoch, 'epoch'), out='ptb_result')\n", 599 | "\n", 600 | "eval_model = model.copy() # Model with shared params and distinct states \n", 601 | "eval_rnn = eval_model.predictor\n", 602 | "trainer.extend(extensions.Evaluator(\n", 603 | " val_iter, eval_model, device=gpu,\n", 604 | " # Reset the RNN state at the beginning of each evaluation \n", 605 | " eval_hook=lambda _: eval_rnn.reset_state()))\n", 606 | "\n", 607 | "trainer.extend(extensions.LogReport(postprocess=compute_perplexity, trigger=(1, 'epoch')))\n", 608 | "trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'perplexity', 'val_perplexity']), trigger=(1, 'epoch'))\n", 609 | "trainer.extend(extensions.snapshot())\n", 610 | "trainer.extend(extensions.snapshot_object(model, 'model_epoch_{.updater.epoch}'))\n", 611 | "\n", 612 | "trainer.run()" 613 | ] 614 | }, 615 | { 616 | "cell_type": "markdown", 617 | "metadata": {}, 618 | "source": [ 619 | "### 2.2.12 Evaluate the trained model on test dataset\n", 620 | "\n", 621 | "Let's see the perplexity on the test split. Trainer's extension can be used as just a normal function outside of Trainer." 
622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 12, 627 | "metadata": {}, 628 | "outputs": [ 629 | { 630 | "name": "stdout", 631 | "output_type": "stream", 632 | "text": [ 633 | "test perplexity: 87.1754855238\n" 634 | ] 635 | } 636 | ], 637 | "source": [ 638 | "eval_rnn.reset_state()\n", 639 | "evaluator = extensions.Evaluator(test_iter, eval_model, device=gpu)\n", 640 | "result = evaluator()\n", 641 | "print('test perplexity:', np.exp(float(result['main/loss'])))" 642 | ] 643 | }, 644 | { 645 | "cell_type": "markdown", 646 | "metadata": {}, 647 | "source": [ 648 | "2.3 Generating sentences\n", 649 | "-----------------------" 650 | ] 651 | }, 652 | { 653 | "cell_type": "markdown", 654 | "metadata": {}, 655 | "source": [ 656 | "You can generate the sentence which starts with a word in the vocabulary. In this example, we generate a sentence which starts with the word **apple**. We use the script in the PTB example of the official repository.\n", 657 | "\n", 658 | "https://github.com/chainer/chainer/tree/master/examples/ptb" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 14, 664 | "metadata": {}, 665 | "outputs": [ 666 | { 667 | "name": "stdout", 668 | "output_type": "stream", 669 | "text": [ 670 | "apple is the major public business in N years .this is a regime of the earth as \n" 671 | ] 672 | } 673 | ], 674 | "source": [ 675 | "%%bash\n", 676 | "python gentxt.py -m ptb_result/model_epoch_39 -p apple" 677 | ] 678 | } 679 | ], 680 | "metadata": { 681 | "kernelspec": { 682 | "display_name": "Python 3", 683 | "language": "python", 684 | "name": "python3" 685 | }, 686 | "language_info": { 687 | "codemirror_mode": { 688 | "name": "ipython", 689 | "version": 3 690 | }, 691 | "file_extension": ".py", 692 | "mimetype": "text/x-python", 693 | "name": "python", 694 | "nbconvert_exporter": "python", 695 | "pygments_lexer": "ipython3", 696 | "version": "3.6.1" 697 | } 698 | }, 699 | "nbformat": 4, 700 | "nbformat_minor": 2 701 | } 702 | -------------------------------------------------------------------------------- /5_word2vec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Word2Vec: Obtain word embeddings\n", 8 | "\n", 9 | "## 0. Introduction\n", 10 | "\n", 11 | "**Word2vec** is the tool for generating the distributed representation of words, which is proposed by Mikolov et al[[1]](#1). When the tool assigns a real-valued vector to each word, the closer the meanings of the words, the greater similarity the vectors will indicate.\n", 12 | "\n", 13 | "**Distributed representation** means assigning a real-valued vector for each word and representing the word by the vector. When representing a word by distributed representation, we call the vector **word embeddings**. In this notebook, we aim at explaining how to get the word embeddings from Penn Tree Bank dataset.\n", 14 | "\n", 15 | "Let's think about what the meaning of word is. Since we are human, so we can understand that the words \"animal\" and \"dog\" are deeply related each other. But what information will Word2vec use to learn the vectors for words? The words \"animal\" and \"dog\" should have similar vectors, but the words \"food\" and \"dog\" should be far from each other. How to know the features of those words automatically?" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## 1. 
Basic Idea\n",
23 | "\n",
24 | "Word2vec learns the similarity of word meanings from simple information. It learns the representations of words from sentences. The core idea is based on the assumption that the meaning of a word is affected by the words around it. This idea follows the **distributional hypothesis**[[2]](#2).\n",
25 | "\n",
26 | "The word we focus on to learn its representation is called the **\"center word\"**, and the words around it are called **\"context words\"**. The window size `C` determines how many context words are considered.\n",
27 | "\n",
28 | "Here, let's see the algorithm by using an example sentence: \"**The cute cat jumps over the lazy dog.**\"\n",
29 | "\n",
30 | "- All of the following figures consider \"cat\" as the center word.\n",
31 | "- You can see that the number of context words changes according to the window size `C`.\n",
32 | "\n",
33 | "![](center_context_word.png)"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "## 2. Main Algorithm\n",
41 | "\n",
42 | "Word2vec, the tool for creating word embeddings, is actually built with two models, which are called **Skip-gram** and **CBoW**.\n",
43 | "\n",
44 | "To explain the models with the figures below, we will use the following symbols.\n",
45 | "\n",
46 | "| Symbol | Definition |\n",
47 | "| --------: | :------------------------------------------------------- |\n",
48 | "| $|\\mathcal{V}|$ | The size of the vocabulary |\n",
49 | "| $D$ | The size of the embedding vectors |\n",
50 | "| ${\\bf v}_t$ | A one-hot center word vector |\n",
51 | "| $V_{t \\pm C}$ | The set of $2C$ context vectors around ${\\bf v}_t$, namely, $\\{{\\bf v}_{t+c}\\}_{c=-C}^C \\backslash {\\bf v}_t$ |\n",
52 | "| ${\\bf l}_H$ | An embedding vector of an input word vector |\n",
53 | "| ${\\bf l}_O$ | An output vector of the network |\n",
54 | "| ${\\bf W}_H$ | The embedding matrix for inputs |\n",
55 | "| ${\\bf W}_O$ | The embedding matrix for outputs |\n",
56 | "\n",
57 | "**Note**\n",
58 | "\n",
59 | "Using **negative sampling** or **hierarchical softmax** for the loss function is very common; however, in this notebook, we will use the **softmax over all words** and skip the other variants for the sake of simplicity."
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {},
65 | "source": [
66 | "### 2.1 Skip-gram\n",
67 | "\n",
68 | "This model learns to predict the context words $V_{t \\pm C}$ when a center word ${\\bf v}_t$ is given. In this model, each row of the input embedding matrix ${\\bf W}_H$ becomes the word embedding of the corresponding word.\n",
69 | "\n",
70 | "When you input a center word ${\\bf v}_t$ into the network, you can predict one of the context words $\\hat{\\bf v}_{t+i} \\in V_{t \\pm C}$ as follows:\n",
71 | "\n",
72 | "1. Calculate an embedding vector of the input center word vector: ${\\bf l}_H = {\\bf W}_H {\\bf v}_t$\n",
73 | "2. Calculate an output vector of the embedding vector: ${\\bf l}_O = {\\bf W}_O {\\bf l}_H$\n",
74 | "3. Calculate a probability vector of a context word: $\\hat{\\bf v}_{t+i} = \\text{softmax}({\\bf l}_O)$\n",
75 | "\n",
76 | "Each element of the $|\\mathcal{V}|$-dimensional vector $\\hat{\\bf v}_{t+i}$ is a probability that a word in the vocabulary turns out to be a context word at position $i$. 
So, the probability $p({\\bf v}_{t+i} \\mid {\\bf v}_t)$ can be estimated by a dot product of the one-hot vector ${\\bf v}_{t+i}$ which represents the actual word at the position $i$ and the output vector $\\hat{\\bf v}_{t+i}$.\n", 77 | "\n", 78 | "$p({\\bf v}_{t+i} \\mid {\\bf v}_t) = {\\bf v}_{t+i}^T \\hat{\\bf v}_{t+i}$\n", 79 | "\n", 80 | "The loss function for all the context words $V_{t \\pm C}$ given a center word ${\\bf v}_t$ is defined as following:\n", 81 | "\n", 82 | "$\n", 83 | "\\begin{eqnarray}\n", 84 | "L(V_{t \\pm C} | {\\bf v}_t; {\\bf W}_H, {\\bf W}_O)\n", 85 | "&=& \\sum_{V_{t \\pm C}} -\\log\\left(p({\\bf v}_{t+i} \\mid {\\bf v}_t)\\right) \\\\\n", 86 | "&=& \\sum_{V_{t \\pm C}} -\\log({\\bf v}_{t+i}^T \\hat{\\bf v}_{t+i})\n", 87 | "\\end{eqnarray}\n", 88 | "$" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "### 2.2 Continuous Bag of Words (CBoW)\n", 96 | "\n", 97 | "This model learns to predict the center word ${\\bf v}_t$ when context words $V_{t \\pm C}$ is given.\n", 98 | "\n", 99 | "When you give a set of context words $V_{t \\pm C}$ to the network, you can estimate the probability of the center word $\\hat{v}_t$ as follows:\n", 100 | "\n", 101 | "1. Calculate a mean embedding vector over all context words: ${\\bf l}_H = \\frac{1}{2C} \\sum_{V_{t \\pm C}} {\\bf W}_H {\\bf v}_{t+i}$\n", 102 | "2. Calculate an output vector: ${\\bf l}_O = {\\bf W}_O {\\bf l}_H$\n", 103 | "3. Calculate an probability vector: $\\hat{\\bf v}_t = \\text{softmax}({\\bf l}_O)$\n", 104 | "\n", 105 | "Each element of $\\hat{\\bf v}_t$ is a probability that a word in the vocabulary is considered as the center word. So, the prediction $p({\\bf v}_t \\mid V_{t \\pm C})$ can be calculated by ${\\bf v}_t^T \\hat{\\bf v}_t$, where ${\\bf v}_t$ denots the one-hot vector of the actual center word vector in the sentence from the dataset.\n", 106 | "\n", 107 | "The loss function for the center word prediction is defined as follows:\n", 108 | "\n", 109 | "$\n", 110 | "\\begin{eqnarray}\n", 111 | "L({\\bf v}_t|V_{t \\pm C}; W_H, W_O)\n", 112 | "&=& -\\log(p({\\bf v}_t|V_{t \\pm C})) \\\\\n", 113 | "&=& -\\log({\\bf v}_t^T \\hat{\\bf v}_t)\n", 114 | "\\end{eqnarray}\n", 115 | "$" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "## 3. Details of skip-gram\n", 123 | "\n", 124 | "In this notebook, we mainly explain skip-gram model because\n", 125 | "\n", 126 | "1. It is easier to understand the algorithm than CBoW.\n", 127 | "2. Even if the number of words increases, the accuracy is largely maintained. So, it is more scalable." 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "So, let's think about a concrete example of calculating skip-gram under this setup:\n", 135 | "\n", 136 | "* The size of vocabulary $|\\mathcal{V}|$ is 10.\n", 137 | "* The size of embedding vector $D$ is 2.\n", 138 | "* Center word is \"dog\".\n", 139 | "* Context word is \"animal\".\n", 140 | "\n", 141 | "Since there should be more than one context words, repeat the following process for each context word.\n", 142 | "\n", 143 | "1. The one-hot vector of \"dog\" is `[0 0 1 0 0 0 0 0 0 0]` and you input it as the center word.\n", 144 | "2. The third row of embedding matrix ${\\bf W}_H$ is used for the word embedding of \"dog\" ${\\bf l}_H$.\n", 145 | "3. Then multiply ${\\bf W}_O$ with ${\\bf l}_H$ to obtain the output vector ${\\bf l}_O$\n", 146 | "4. 
Give ${\\bf l}_O$ to the softmax function to make it a predicted probability vector $\\hat{\\bf v}_{t+c}$ for a context word at the position $c$.\n", 147 | "5. Calculate the error between $\\hat{\\bf v}_{t+c}$ and the one-hot vector of \"animal\"; `[1 0 0 0 0 0 0 0 0 0 0]`.\n", 148 | "6. Propagate the error back to the network to update the parameters.\n", 149 | "\n", 150 | "![](skipgram_detail.png)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## 4. Implementation of skip-gram in Chainer\n", 158 | "\n", 159 | "There is an example of Word2vec in the official repository of Chainer, so we will explain how to implement skip-gram based on this: [chainer/examples/word2vec](https://github.com/chainer/chainer/tree/master/examples/word2vec)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "### 4.1 Preparation\n", 167 | "\n", 168 | "First, let's import necessary packages:" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 1, 174 | "metadata": { 175 | "collapsed": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "import argparse\n", 180 | "import collections\n", 181 | "\n", 182 | "import numpy as np\n", 183 | "import six\n", 184 | "\n", 185 | "import chainer\n", 186 | "from chainer import cuda\n", 187 | "import chainer.functions as F\n", 188 | "import chainer.initializers as I\n", 189 | "import chainer.links as L\n", 190 | "import chainer.optimizers as O\n", 191 | "from chainer import reporter\n", 192 | "from chainer import training\n", 193 | "from chainer.training import extensions" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "### 4.2 Define a skip-gram model\n", 201 | "\n", 202 | "Next, let's define a network for skip-gram." 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 2, 208 | "metadata": { 209 | "collapsed": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "class SkipGram(chainer.Chain):\n", 214 | "\n", 215 | " def __init__(self, n_vocab, n_units):\n", 216 | " super().__init__()\n", 217 | " with self.init_scope():\n", 218 | " self.embed = L.EmbedID(\n", 219 | " n_vocab, n_units, initialW=I.Uniform(1. 
/ n_units))\n", 220 | " self.out = L.Linear(n_units, n_vocab, initialW=0)\n", 221 | "\n", 222 | " def __call__(self, x, context):\n", 223 | " e = self.embed(context)\n", 224 | " shape = e.shape\n", 225 | " x = F.broadcast_to(x[:, None], (shape[0], shape[1]))\n", 226 | " e = F.reshape(e, (shape[0] * shape[1], shape[2]))\n", 227 | " x = F.reshape(x, (shape[0] * shape[1],))\n", 228 | " center_predictions = self.out(e)\n", 229 | " loss = F.softmax_cross_entropy(center_predictions, x)\n", 230 | " reporter.report({'loss': loss}, self)\n", 231 | " return loss" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "**Note**\n", 239 | "\n", 240 | "- The weight matrix `self.embed.W` is the embbeding matrix for input vector `x`.\n", 241 | "- `__call__` takes the word ID of a center word `x` and word IDs of context words `contexts` as inputs, and outputs the error calculated by the loss function `softmax_cross_entropy`.\n", 242 | "- Note that the initial shape of `x` and `contexts` are `(batch_size,)` and `(batch_size, n_context)`, respectively.\n", 243 | "- The `batch_size` means the size of mini-batch, and `n_context` means the number of context words.\n", 244 | "\n", 245 | "First, we obtain the embedding vectors of `contexts` by `e = self.embed(contexts)`. \n", 246 | "\n", 247 | "Then `F.broadcast_to(x[:, None], (shape[0], shape[1]))` performs broadcasting of `x` (`(batch_size,)`) to `(batch_size, n_context)` by copying the same value `n_context` time to fill the second axis, and then the broadcasted `x` is reshaped into 1-D vector `(batchsize * n_context,)` while `e` is reshaped to `(batch_size * n_context, n_units)`.\n", 248 | "\n", 249 | "In skip-gram model, predicting a context word from the center word is the same as predicting the center word from a context word because the center word is always a context word when considering the context word as a center word. So, we create `batch_size * n_context` center word predictions by applying `self.out` linear layer to the embedding vectors of context words. Then, calculate softmax cross entropy between the broadcasted center word ID `x` and the predictions." 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "### 4.3 Prepare dataset and iterator\n", 257 | "\n", 258 | "Let's retrieve the Penn Tree Bank (PTB) dataset by using Chainer's dataset utility `get_ptb_words()` method." 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 3, 264 | "metadata": { 265 | "collapsed": true 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "train, val, _ = chainer.datasets.get_ptb_words()\n", 270 | "n_vocab = max(train) + 1 # The minimum word ID is 0" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "Then define an iterator to make mini-batches that contain a set of center words with their context words." 
278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 4, 283 | "metadata": { 284 | "collapsed": true 285 | }, 286 | "outputs": [], 287 | "source": [ 288 | "class WindowIterator(chainer.dataset.Iterator):\n", 289 | "\n", 290 | " def __init__(self, dataset, window, batch_size, repeat=True):\n", 291 | " self.dataset = np.array(dataset, np.int32)\n", 292 | " self.window = window\n", 293 | " self.batch_size = batch_size\n", 294 | " self._repeat = repeat\n", 295 | "\n", 296 | " self.order = np.random.permutation(\n", 297 | " len(dataset) - window * 2).astype(np.int32)\n", 298 | " self.order += window\n", 299 | " self.current_position = 0\n", 300 | " self.epoch = 0\n", 301 | " self.is_new_epoch = False\n", 302 | "\n", 303 | " def __next__(self):\n", 304 | " if not self._repeat and self.epoch > 0:\n", 305 | " raise StopIteration\n", 306 | "\n", 307 | " i = self.current_position\n", 308 | " i_end = i + self.batch_size\n", 309 | " position = self.order[i: i_end]\n", 310 | " w = np.random.randint(self.window - 1) + 1\n", 311 | " offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])\n", 312 | " pos = position[:, None] + offset[None, :]\n", 313 | " context = self.dataset.take(pos)\n", 314 | " center = self.dataset.take(position)\n", 315 | "\n", 316 | " if i_end >= len(self.order):\n", 317 | " np.random.shuffle(self.order)\n", 318 | " self.epoch += 1\n", 319 | " self.is_new_epoch = True\n", 320 | " self.current_position = 0\n", 321 | " else:\n", 322 | " self.is_new_epoch = False\n", 323 | " self.current_position = i_end\n", 324 | "\n", 325 | " return center, context\n", 326 | "\n", 327 | " @property\n", 328 | " def epoch_detail(self):\n", 329 | " return self.epoch + float(self.current_position) / len(self.order)\n", 330 | "\n", 331 | " def serialize(self, serializer):\n", 332 | " self.current_position = serializer('current_position',\n", 333 | " self.current_position)\n", 334 | " self.epoch = serializer('epoch', self.epoch)\n", 335 | " self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)\n", 336 | " if self._order is not None:\n", 337 | " serializer('_order', self._order)\n", 338 | "\n", 339 | "def convert(batch, device):\n", 340 | " center, context = batch\n", 341 | " if device >= 0:\n", 342 | " center = cuda.to_gpu(center)\n", 343 | " context = cuda.to_gpu(context)\n", 344 | " return center, context" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "- In the constructor, we create an array `self.order` which denotes shuffled indices of `[window, window + 1, ..., len(dataset) - window - 1]` in order to choose a center word randomly from `dataset` in a mini-batch.\n", 352 | "- The iterator definition `__next__` returns `batch_size` sets of center word and context words.\n", 353 | "- The code `self.order[i:i_end]` returns the indices for a set of center words from the random-ordered array `self.order`. The center word IDs `center` at the random indices are retrieved by `self.dataset.take`.\n", 354 | "- `np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])` creates a set of offsets to retrieve context words from the dataset.\n", 355 | "- The code `position[:, None] + offset[None, :]` generates the indices of context words for each center word index in `position`. The context word IDs `context` are retrieved by `self.dataset.take`." 
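The indexing trick in `__next__` above is easiest to see on a toy sequence. Below is a minimal NumPy sketch with made-up word IDs (not part of the original example) showing how `position[:, None] + offset[None, :]` builds the context indices for every center word in the mini-batch at once:

```python
import numpy as np

dataset = np.array([10, 11, 12, 13, 14, 15, 16, 17], dtype=np.int32)  # toy word IDs
position = np.array([3, 5])  # indices of two center words in this mini-batch
w = 2                        # a sampled window size
offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])  # [-2 -1  1  2]

# Broadcasting yields a (batch_size, 2 * w) matrix of context indices.
pos = position[:, None] + offset[None, :]
print(pos)                     # [[1 2 4 5]
                               #  [3 4 6 7]]
print(dataset.take(position))  # centers:  [13 15]
print(dataset.take(pos))       # contexts: [[11 12 14 15]
                               #            [13 14 16 17]]
```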
356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "### 4.4 Prepare model, optimizer, and updater" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 5, 368 | "metadata": { 369 | "collapsed": true 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "unit = 100 # number of hidden units\n", 374 | "window = 5\n", 375 | "batchsize = 1000\n", 376 | "gpu = 0\n", 377 | "\n", 378 | "# Instantiate model\n", 379 | "model = SkipGram(n_vocab, unit)\n", 380 | "\n", 381 | "if gpu >= 0:\n", 382 | " model.to_gpu(gpu)\n", 383 | "\n", 384 | "# Create optimizer\n", 385 | "optimizer = O.Adam()\n", 386 | "optimizer.setup(model)\n", 387 | "\n", 388 | "# Create iterators for both train and val datasets\n", 389 | "train_iter = WindowIterator(train, window, batchsize)\n", 390 | "val_iter = WindowIterator(val, window, batchsize, repeat=False)\n", 391 | "\n", 392 | "# Create updater\n", 393 | "updater = training.StandardUpdater(\n", 394 | " train_iter, optimizer, converter=convert, device=gpu)" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "### 4.5 Start training" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 6, 407 | "metadata": {}, 408 | "outputs": [ 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | "epoch main/loss validation/main/loss\n", 414 | "\u001b[J1 6.87469 6.49239 \n", 415 | "\u001b[J2 6.43766 6.42476 \n", 416 | "\u001b[J3 6.34942 6.36353 \n", 417 | "\u001b[J4 6.28435 6.31737 \n", 418 | "\u001b[J5 6.23287 6.283 \n", 419 | "\u001b[J6 6.20336 6.24662 \n", 420 | "\u001b[J7 6.16982 6.26347 \n", 421 | "\u001b[J8 6.14064 6.21212 \n", 422 | "\u001b[J9 6.11639 6.17406 \n", 423 | "\u001b[J10 6.0834 6.23088 \n", 424 | "\u001b[J11 6.04547 6.22154 \n", 425 | "\u001b[J12 6.04376 6.22592 \n", 426 | "\u001b[J13 6.03126 6.17224 \n", 427 | "\u001b[J14 6.00218 6.21196 \n", 428 | "\u001b[J15 6.01313 6.15059 \n", 429 | "\u001b[J16 6.00215 6.21771 \n", 430 | "\u001b[J17 5.97584 6.18996 \n", 431 | "\u001b[J18 5.96465 6.21517 \n", 432 | "\u001b[J19 5.95188 6.19383 \n", 433 | "\u001b[J20 5.9473 6.13596 \n", 434 | "\u001b[J21 5.92897 6.17627 \n", 435 | "\u001b[J22 5.92385 6.17938 \n", 436 | "\u001b[J23 5.9195 6.18834 \n", 437 | "\u001b[J24 5.90871 6.16518 \n", 438 | "\u001b[J25 5.91193 6.22971 \n", 439 | "\u001b[J26 5.91069 6.16714 \n", 440 | "\u001b[J27 5.8846 6.19366 \n", 441 | "\u001b[J28 5.90398 6.21329 \n", 442 | "\u001b[J29 5.90064 6.21574 \n", 443 | "\u001b[J30 5.88258 6.18508 \n", 444 | "\u001b[J31 5.8702 6.19865 \n", 445 | "\u001b[J32 5.86399 6.20083 \n", 446 | "\u001b[J33 5.869 6.16945 \n", 447 | "\u001b[J34 5.85692 6.21721 \n", 448 | "\u001b[J35 5.85175 6.24535 \n", 449 | "\u001b[J36 5.8673 6.22052 \n", 450 | "\u001b[J37 5.85207 6.22343 \n", 451 | "\u001b[J38 5.83775 6.19369 \n", 452 | "\u001b[J39 5.85815 6.28352 \n", 453 | "\u001b[J40 5.85035 6.2091 \n", 454 | "\u001b[J41 5.8423 6.28353 \n", 455 | "\u001b[J42 5.83698 6.22302 \n", 456 | "\u001b[J43 5.84559 6.28724 \n", 457 | "\u001b[J44 5.82942 6.24757 \n", 458 | "\u001b[J45 5.81868 6.2275 \n", 459 | "\u001b[J46 5.84483 6.31015 \n", 460 | "\u001b[J47 5.82578 6.2163 \n", 461 | "\u001b[J48 5.81876 6.27374 \n", 462 | "\u001b[J49 5.80941 6.27729 \n", 463 | "\u001b[J50 5.82054 6.30504 \n", 464 | "\u001b[J51 5.80205 6.25924 \n", 465 | "\u001b[J52 5.82125 6.24315 \n", 466 | "\u001b[J53 5.8075 6.27823 \n", 467 | "\u001b[J54 5.81458 6.29426 \n", 468 | "\u001b[J55 5.80863 6.2662 
\n", 469 | "\u001b[J56 5.81012 6.27777 \n", 470 | "\u001b[J57 5.80712 6.21873 \n", 471 | "\u001b[J58 5.82886 6.26017 \n", 472 | "\u001b[J59 5.80509 6.31339 \n", 473 | "\u001b[J60 5.80734 6.2874 \n", 474 | "\u001b[J61 5.80876 6.28746 \n", 475 | "\u001b[J62 5.82057 6.26831 \n", 476 | "\u001b[J63 5.80071 6.27469 \n", 477 | "\u001b[J64 5.82219 6.33319 \n", 478 | "\u001b[J65 5.8104 6.27122 \n", 479 | "\u001b[J66 5.80218 6.34508 \n", 480 | "\u001b[J67 5.81691 6.30597 \n", 481 | "\u001b[J68 5.80138 6.29948 \n", 482 | "\u001b[J69 5.80251 6.31008 \n", 483 | "\u001b[J70 5.79986 6.31869 \n", 484 | "\u001b[J71 5.8007 6.33197 \n", 485 | "\u001b[J72 5.7917 6.34002 \n", 486 | "\u001b[J73 5.80931 6.34416 \n", 487 | "\u001b[J74 5.80188 6.31156 \n", 488 | "\u001b[J75 5.80004 6.3059 \n", 489 | "\u001b[J76 5.78182 6.34326 \n", 490 | "\u001b[J77 5.80898 6.3267 \n", 491 | "\u001b[J78 5.80455 6.30006 \n", 492 | "\u001b[J79 5.79808 6.35709 \n", 493 | "\u001b[J80 5.80054 6.36302 \n", 494 | "\u001b[J81 5.8025 6.31012 \n", 495 | "\u001b[J82 5.78773 6.3225 \n", 496 | "\u001b[J83 5.7871 6.33635 \n", 497 | "\u001b[J84 5.80807 6.32977 \n", 498 | "\u001b[J85 5.78313 6.3488 \n", 499 | "\u001b[J86 5.79799 6.32949 \n", 500 | "\u001b[J87 5.79194 6.36984 \n", 501 | "\u001b[J88 5.78191 6.28538 \n", 502 | "\u001b[J89 5.80466 6.33613 \n", 503 | "\u001b[J90 5.79181 6.30647 \n", 504 | "\u001b[J91 5.81321 6.34016 \n", 505 | "\u001b[J92 5.80324 6.4427 \n", 506 | "\u001b[J93 5.7999 6.3277 \n", 507 | "\u001b[J94 5.78875 6.31837 \n", 508 | "\u001b[J95 5.7871 6.36152 \n", 509 | "\u001b[J96 5.78073 6.37049 \n", 510 | "\u001b[J97 5.80366 6.35217 \n", 511 | "\u001b[J98 5.78997 6.42968 \n", 512 | "\u001b[J99 5.78576 6.32253 \n", 513 | "\u001b[J100 5.78377 6.3595 \n" 514 | ] 515 | } 516 | ], 517 | "source": [ 518 | "epoch = 100\n", 519 | "\n", 520 | "trainer = training.Trainer(updater, (epoch, 'epoch'), out='word2vec_result')\n", 521 | "trainer.extend(extensions.Evaluator(val_iter, model, converter=convert, device=gpu))\n", 522 | "trainer.extend(extensions.LogReport())\n", 523 | "trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss']))\n", 524 | "trainer.run()" 525 | ] 526 | } 527 | ], 528 | "metadata": { 529 | "kernelspec": { 530 | "display_name": "Python 3", 531 | "language": "python", 532 | "name": "python3" 533 | }, 534 | "language_info": { 535 | "codemirror_mode": { 536 | "name": "ipython", 537 | "version": 3 538 | }, 539 | "file_extension": ".py", 540 | "mimetype": "text/x-python", 541 | "name": "python", 542 | "nbconvert_exporter": "python", 543 | "pygments_lexer": "ipython3", 544 | "version": "3.6.1" 545 | } 546 | }, 547 | "nbformat": 4, 548 | "nbformat_minor": 2 549 | } 550 | -------------------------------------------------------------------------------- /6_dqn_cartpole.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ChainerRL Quickstart Guide\n", 8 | "\n", 9 | "This is a quickstart guide for users who just want to try ChainerRL for the first time.\n", 10 | "\n", 11 | "If you have not yet installed ChainerRL, run the command below to install it:" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "Collecting chainerrl\n", 24 | " Downloading chainerrl-0.2.0.tar.gz (56kB)\n", 25 | "Collecting cached-property (from 
chainerrl)\n", 26 | " Downloading cached_property-1.3.1-py2.py3-none-any.whl\n", 27 | "Requirement already satisfied: chainer>=2.0.0 in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainerrl)\n", 28 | "Collecting future (from chainerrl)\n", 29 | " Downloading future-0.16.0.tar.gz (824kB)\n", 30 | "Collecting gym>=0.7.3 (from chainerrl)\n", 31 | " Downloading gym-0.9.3.tar.gz (157kB)\n", 32 | "Requirement already satisfied: numpy>=1.10.4 in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainerrl)\n", 33 | "Requirement already satisfied: pillow in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainerrl)\n", 34 | "Requirement already satisfied: scipy in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainerrl)\n", 35 | "Requirement already satisfied: six>=1.9.0 in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainer>=2.0.0->chainerrl)\n", 36 | "Requirement already satisfied: protobuf>=2.6.0 in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainer>=2.0.0->chainerrl)\n", 37 | "Requirement already satisfied: mock in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainer>=2.0.0->chainerrl)\n", 38 | "Requirement already satisfied: nose in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainer>=2.0.0->chainerrl)\n", 39 | "Requirement already satisfied: filelock in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from chainer>=2.0.0->chainerrl)\n", 40 | "Requirement already satisfied: requests>=2.0 in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from gym>=0.7.3->chainerrl)\n", 41 | "Collecting pyglet>=1.2.0 (from gym>=0.7.3->chainerrl)\n", 42 | " Downloading pyglet-1.2.4-py3-none-any.whl (964kB)\n", 43 | "Requirement already satisfied: olefile in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from pillow->chainerrl)\n", 44 | "Requirement already satisfied: setuptools in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/setuptools-27.2.0-py3.6.egg (from protobuf>=2.6.0->chainer>=2.0.0->chainerrl)\n", 45 | "Requirement already satisfied: pbr>=0.11 in /home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages (from mock->chainer>=2.0.0->chainerrl)\n", 46 | "Building wheels for collected packages: chainerrl, future, gym\n", 47 | " Running setup.py bdist_wheel for chainerrl: started\n", 48 | " Running setup.py bdist_wheel for chainerrl: finished with status 'done'\n", 49 | " Stored in directory: /home/shunta/.cache/pip/wheels/50/e1/16/d6879538da7fe0053f5b61c3d1f4e1b009464d3564b99c792c\n", 50 | " Running setup.py bdist_wheel for future: started\n", 51 | " Running setup.py bdist_wheel for future: finished with status 'done'\n", 52 | " Stored in directory: /home/shunta/.cache/pip/wheels/c2/50/7c/0d83b4baac4f63ff7a765bd16390d2ab43c93587fac9d6017a\n", 53 | " Running setup.py bdist_wheel for gym: started\n", 54 | " Running setup.py bdist_wheel for gym: finished with status 'done'\n", 55 | " Stored in directory: /home/shunta/.cache/pip/wheels/2b/16/05/14202d3528fb14912254fe7062bfc8b061ade8de9409f1abd0\n", 56 | "Successfully built chainerrl future gym\n", 57 | "Installing collected packages: cached-property, future, pyglet, gym, chainerrl\n", 58 | "Successfully installed cached-property-1.3.1 chainerrl-0.2.0 future-0.16.0 gym-0.9.3 
pyglet-1.2.4\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "%%bash\n", 64 | "pip install chainerrl" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "If you have already installed ChainerRL, let's begin!\n", 72 | "\n", 73 | "First, you need to import necessary modules. The module name of ChainerRL is `chainerrl`. Let's import `gym` and `numpy` as well since they are used later." 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 2, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "import chainer\n", 83 | "import chainer.functions as F\n", 84 | "import chainer.links as L\n", 85 | "import chainerrl\n", 86 | "import gym\n", 87 | "import numpy as np" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "ChainerRL can be used for any problems if they are modeled as \"environments\". [OpenAI Gym](https://github.com/openai/gym) provides various kinds of benchmark environments and defines the common interface among them. ChainerRL uses a subset of the interface. Specifically, an environment must define its observation space and action space and have at least two methods: `reset` and `step`.\n", 95 | "\n", 96 | "- `env.reset` will reset the environment to the initial state and return the initial observation.\n", 97 | "- `env.step` will execute a given action, move to the next state and return four values:\n", 98 | " - a next observation\n", 99 | " - a scalar reward\n", 100 | " - a boolean value indicating whether the current state is terminal or not\n", 101 | " - additional information\n", 102 | "- `env.render` will render the current state.\n", 103 | "\n", 104 | "Let's try 'CartPole-v0', which is a classic control problem. You can see below that its observation space consists of four real numbers while its action space consists of two discrete actions." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 3, 110 | "metadata": { 111 | "scrolled": false 112 | }, 113 | "outputs": [ 114 | { 115 | "name": "stderr", 116 | "output_type": "stream", 117 | "text": [ 118 | "[2017-09-23 05:17:39,776] Making new env: CartPole-v0\n" 119 | ] 120 | }, 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "observation space: Box(4,)\n", 126 | "action space: Discrete(2)\n", 127 | "initial observation: [ 0.0042395 -0.01029501 0.01804505 0.04476617]\n", 128 | "next observation: [ 0.0040336 -0.205671 0.01894037 0.3430874 ]\n", 129 | "reward: 1.0\n", 130 | "done: False\n", 131 | "info: {}\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "env = gym.make('CartPole-v0')\n", 137 | "print('observation space:', env.observation_space)\n", 138 | "print('action space:', env.action_space)\n", 139 | "\n", 140 | "obs = env.reset()\n", 141 | "env.render(close=True)\n", 142 | "print('initial observation:', obs)\n", 143 | "\n", 144 | "action = env.action_space.sample()\n", 145 | "obs, r, done, info = env.step(action)\n", 146 | "print('next observation:', obs)\n", 147 | "print('reward:', r)\n", 148 | "print('done:', done)\n", 149 | "print('info:', info)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "Now you have defined your environment. 
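To get a feel for this interface before any learning is involved, the small sketch below rolls out one full episode with a random policy; it uses only the Gym API described above:

```python
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()          # pick a random action
    obs, reward, done, info = env.step(action)
    total_reward += reward
print('return of a random-policy episode:', total_reward)
```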
Next, you need to define an agent, which will learn through interactions with the environment.\n", 157 | "\n", 158 | "ChainerRL provides various agents, each of which implements a deep reinforcement learning algorithm.\n", 159 | "\n", 160 | "To use [DQN (Deep Q-Network)](http://dx.doi.org/10.1038/nature14236), you need to define a Q-function that receives an observation and returns an expected future return for each action the agent can take. In ChainerRL, you can define your Q-function as `chainer.Link` as below. Note that the outputs are wrapped by `chainerrl.action_value.DiscreteActionValue`, which implements `chainerrl.action_value.ActionValue`. By wrapping the outputs of Q-functions, ChainerRL can treat discrete-action Q-functions like this and [NAFs (Normalized Advantage Functions)](https://arxiv.org/abs/1603.00748) in the same way." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 4, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "class QFunction(chainer.Chain):\n", 170 | "\n", 171 | " def __init__(self, obs_size, n_actions, n_hidden_channels=50):\n", 172 | " super().__init__(\n", 173 | " l0=L.Linear(obs_size, n_hidden_channels),\n", 174 | " l1=L.Linear(n_hidden_channels, n_hidden_channels),\n", 175 | " l2=L.Linear(n_hidden_channels, n_actions))\n", 176 | "\n", 177 | " def __call__(self, x, test=False):\n", 178 | " \"\"\"\n", 179 | " Args:\n", 180 | " x (ndarray or chainer.Variable): An observation\n", 181 | " test (bool): a flag indicating whether it is in test mode\n", 182 | " \"\"\"\n", 183 | " h = F.tanh(self.l0(x))\n", 184 | " h = F.tanh(self.l1(h))\n", 185 | " return chainerrl.action_value.DiscreteActionValue(self.l2(h))\n", 186 | "\n", 187 | "obs_size = env.observation_space.shape[0]\n", 188 | "n_actions = env.action_space.n\n", 189 | "q_func = QFunction(obs_size, n_actions)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "If you want to use CUDA for computation, as usual as in Chainer, call `to_gpu`." 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 5, 202 | "metadata": { 203 | "collapsed": true 204 | }, 205 | "outputs": [], 206 | "source": [ 207 | "# Uncomment to use CUDA\n", 208 | "# q_func.to_gpu(0)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "You can also use ChainerRL's predefined Q-functions." 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 6, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "_q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(\n", 225 | " obs_size, n_actions,\n", 226 | " n_hidden_layers=2, n_hidden_channels=50)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "As in Chainer, `chainer.Optimizer` is used to update models." 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 7, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "# Use Adam to optimize q_func. eps=1e-2 is for stability.\n", 243 | "optimizer = chainer.optimizers.Adam(eps=1e-2)\n", 244 | "optimizer.setup(q_func)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "A Q-function and its optimizer are used by a DQN agent. To create a DQN agent, you need to specify a bit more parameters and configurations." 
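Before creating the agent, it can help to probe the Q-function on a single observation. The sketch below assumes the `q_values` and `greedy_actions` attributes of ChainerRL's `DiscreteActionValue`; check the documentation of your ChainerRL version for the exact interface:

```python
# Feed one observation (with a batch axis) through the Q-function defined above.
dummy_obs = env.reset().astype(np.float32)[None]
action_value = q_func(dummy_obs)

print(action_value.q_values.shape)   # (1, n_actions): one Q-value per action
print(action_value.greedy_actions)   # index of the action with the highest Q-value
```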
252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 8, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "# Set the discount factor that discounts future rewards.\n", 261 | "gamma = 0.95\n", 262 | "\n", 263 | "# Use epsilon-greedy for exploration\n", 264 | "explorer = chainerrl.explorers.ConstantEpsilonGreedy(\n", 265 | " epsilon=0.3, random_action_func=env.action_space.sample)\n", 266 | "\n", 267 | "# DQN uses Experience Replay.\n", 268 | "# Specify a replay buffer and its capacity.\n", 269 | "replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)\n", 270 | "\n", 271 | "# Since observations from CartPole-v0 is numpy.float64 while\n", 272 | "# Chainer only accepts numpy.float32 by default, specify\n", 273 | "# a converter as a feature extractor function phi.\n", 274 | "phi = lambda x: x.astype(np.float32, copy=False)\n", 275 | "\n", 276 | "# Now create an agent that will interact with the environment.\n", 277 | "agent = chainerrl.agents.DoubleDQN(\n", 278 | " q_func, optimizer, replay_buffer, gamma, explorer,\n", 279 | " replay_start_size=500, update_interval=1,\n", 280 | " target_update_interval=100, phi=phi)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "Now you have an agent and an environment. It's time to start reinforcement learning!\n", 288 | "\n", 289 | "In training, use `agent.act_and_train` to select exploratory actions. `agent.stop_episode_and_train` must be called after finishing an episode. You can get training statistics of the agent via `agent.get_statistics`." 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 9, 295 | "metadata": { 296 | "scrolled": true 297 | }, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "episode: 10 R: 12.0 statistics: [('average_q', 0.0077787917633448615), ('average_loss', 0)]\n", 304 | "episode: 20 R: 43.0 statistics: [('average_q', 0.013923729594215806), ('average_loss', 0)]\n", 305 | "episode: 30 R: 10.0 statistics: [('average_q', 0.04999595856865319), ('average_loss', 0.15626195506060395)]\n", 306 | "episode: 40 R: 10.0 statistics: [('average_q', 0.18431173820404814), ('average_loss', 0.19973429628136666)]\n", 307 | "episode: 50 R: 16.0 statistics: [('average_q', 0.4329778858284125), ('average_loss', 0.12129529302886367)]\n", 308 | "episode: 60 R: 40.0 statistics: [('average_q', 1.5867962687319506), ('average_loss', 0.1231642400453139)]\n", 309 | "episode: 70 R: 36.0 statistics: [('average_q', 4.5508317081422485), ('average_loss', 0.14574642336842872)]\n", 310 | "episode: 80 R: 70.0 statistics: [('average_q', 7.293821113338115), ('average_loss', 0.222018443450522)]\n", 311 | "episode: 90 R: 42.0 statistics: [('average_q', 9.706054559843952), ('average_loss', 0.22261116615911836)]\n", 312 | "episode: 100 R: 148.0 statistics: [('average_q', 13.271654782141711), ('average_loss', 0.2537233644580171)]\n", 313 | "episode: 110 R: 185.0 statistics: [('average_q', 17.379473389886567), ('average_loss', 0.23995480935576677)]\n", 314 | "episode: 120 R: 179.0 statistics: [('average_q', 19.205810990096783), ('average_loss', 0.20982516267359438)]\n", 315 | "episode: 130 R: 200.0 statistics: [('average_q', 19.86128616157245), ('average_loss', 0.17017104907517325)]\n", 316 | "episode: 140 R: 160.0 statistics: [('average_q', 20.14523553965665), ('average_loss', 0.17918074812334736)]\n", 317 | "episode: 150 R: 200.0 statistics: [('average_q', 20.386843352118866), 
('average_loss', 0.1511973771788008)]\n", 318 | "episode: 160 R: 200.0 statistics: [('average_q', 20.524274776492966), ('average_loss', 0.181143022239863)]\n", 319 | "episode: 170 R: 200.0 statistics: [('average_q', 20.501493065164738), ('average_loss', 0.1426581032476842)]\n", 320 | "episode: 180 R: 146.0 statistics: [('average_q', 20.37513869566722), ('average_loss', 0.12322326194384814)]\n", 321 | "episode: 190 R: 55.0 statistics: [('average_q', 20.404746612680285), ('average_loss', 0.13629612704703933)]\n", 322 | "episode: 200 R: 200.0 statistics: [('average_q', 20.572537269328773), ('average_loss', 0.1488116341248042)]\n", 323 | "Finished.\n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "n_episodes = 200\n", 329 | "max_episode_len = 200\n", 330 | "for i in range(1, n_episodes + 1):\n", 331 | " obs = env.reset()\n", 332 | " reward = 0\n", 333 | " done = False\n", 334 | " R = 0 # return (sum of rewards)\n", 335 | " t = 0 # time step\n", 336 | " while not done and t < max_episode_len:\n", 337 | " # Uncomment to watch the behaviour\n", 338 | " # env.render()\n", 339 | " action = agent.act_and_train(obs, reward)\n", 340 | " obs, reward, done, _ = env.step(action)\n", 341 | " R += reward\n", 342 | " t += 1\n", 343 | " if i % 10 == 0:\n", 344 | " print('episode:', i,\n", 345 | " 'R:', R,\n", 346 | " 'statistics:', agent.get_statistics())\n", 347 | " agent.stop_episode_and_train(obs, reward, done)\n", 348 | "print('Finished.')" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "Now you finished training the agent. How good is the agent now? You can test it by using `agent.act` and `agent.stop_episode` instead. Exploration such as epsilon-greedy is not used anymore." 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 10, 361 | "metadata": {}, 362 | "outputs": [ 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "test episode: 0 R: 200.0\n", 368 | "test episode: 1 R: 200.0\n", 369 | "test episode: 2 R: 200.0\n", 370 | "test episode: 3 R: 200.0\n", 371 | "test episode: 4 R: 200.0\n", 372 | "test episode: 5 R: 200.0\n", 373 | "test episode: 6 R: 200.0\n", 374 | "test episode: 7 R: 200.0\n", 375 | "test episode: 8 R: 200.0\n", 376 | "test episode: 9 R: 200.0\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "for i in range(10):\n", 382 | " obs = env.reset()\n", 383 | " done = False\n", 384 | " R = 0\n", 385 | " t = 0\n", 386 | " while not done and t < 200:\n", 387 | " env.render(close=True)\n", 388 | " action = agent.act(obs)\n", 389 | " obs, r, done, _ = env.step(action)\n", 390 | " R += r\n", 391 | " t += 1\n", 392 | " print('test episode:', i, 'R:', R)\n", 393 | " agent.stop_episode()" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "If test scores are good enough, the only remaining task is to save the agent so that you can reuse it. What you need to do is to simply call `agent.save` to save the agent, then `agent.load` to load the saved agent." 
401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 11, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "# Save an agent to the 'agent' directory\n", 410 | "agent.save('agent')\n", 411 | "\n", 412 | "# Uncomment to load an agent from the 'agent' directory\n", 413 | "# agent.load('agent')" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "RL completed!\n", 421 | "\n", 422 | "But writing code like this every time you use RL might be boring. So, ChainerRL has utility functions that do these things." 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 12, 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "name": "stdout", 432 | "output_type": "stream", 433 | "text": [ 434 | "outdir:result step:86 episode:0 R:86.0\n", 435 | "statistics:[('average_q', 20.728489019516747), ('average_loss', 0.13604925025581077)]\n", 436 | "outdir:result step:286 episode:1 R:200.0\n", 437 | "statistics:[('average_q', 20.671014208079793), ('average_loss', 0.14984728771766473)]\n", 438 | "outdir:result step:396 episode:2 R:110.0\n", 439 | "statistics:[('average_q', 20.658295082215886), ('average_loss', 0.16141102891913808)]\n", 440 | "outdir:result step:596 episode:3 R:200.0\n", 441 | "statistics:[('average_q', 20.65092498811014), ('average_loss', 0.11670109444167831)]\n", 442 | "outdir:result step:796 episode:4 R:200.0\n", 443 | "statistics:[('average_q', 20.624282196582172), ('average_loss', 0.15006617026267832)]\n", 444 | "outdir:result step:996 episode:5 R:200.0\n", 445 | "statistics:[('average_q', 20.590381701508214), ('average_loss', 0.17453604165516437)]\n", 446 | "outdir:result step:1196 episode:6 R:200.0\n", 447 | "statistics:[('average_q', 20.571275081196642), ('average_loss', 0.16252849495287455)]\n", 448 | "test episode: 0 R: 200.0\n", 449 | "test episode: 1 R: 200.0\n", 450 | "test episode: 2 R: 200.0\n", 451 | "test episode: 3 R: 200.0\n", 452 | "test episode: 4 R: 200.0\n", 453 | "test episode: 5 R: 200.0\n", 454 | "test episode: 6 R: 200.0\n", 455 | "test episode: 7 R: 200.0\n", 456 | "test episode: 8 R: 200.0\n", 457 | "test episode: 9 R: 200.0\n", 458 | "The best score is updated -3.40282e+38 -> 200.0\n", 459 | "Saved the agent to result/1196\n", 460 | "outdir:result step:1244 episode:7 R:48.0\n", 461 | "statistics:[('average_q', 20.44840300754298), ('average_loss', 0.1455696393507992)]\n", 462 | "outdir:result step:1444 episode:8 R:200.0\n", 463 | "statistics:[('average_q', 20.443317168193577), ('average_loss', 0.1385756250812212)]\n", 464 | "outdir:result step:1644 episode:9 R:200.0\n", 465 | "statistics:[('average_q', 20.388818403317572), ('average_loss', 0.11136568147911419)]\n", 466 | "outdir:result step:1844 episode:10 R:200.0\n", 467 | "statistics:[('average_q', 20.393853468915438), ('average_loss', 0.1388451133452519)]\n", 468 | "outdir:result step:1951 episode:11 R:107.0\n", 469 | "statistics:[('average_q', 20.403746200029968), ('average_loss', 0.1201870912602859)]\n", 470 | "outdir:result step:2000 episode:12 R:49.0\n", 471 | "statistics:[('average_q', 20.413271961263554), ('average_loss', 0.13582760984249495)]\n", 472 | "test episode: 0 R: 200.0\n", 473 | "test episode: 1 R: 200.0\n", 474 | "test episode: 2 R: 200.0\n", 475 | "test episode: 3 R: 200.0\n", 476 | "test episode: 4 R: 200.0\n", 477 | "test episode: 5 R: 200.0\n", 478 | "test episode: 6 R: 200.0\n", 479 | "test episode: 7 R: 200.0\n", 480 | "test episode: 8 R: 200.0\n", 481 | "test episode: 9 R: 
200.0\n", 482 | "Saved the agent to result/2000_finish\n" 483 | ] 484 | } 485 | ], 486 | "source": [ 487 | "# Set up the logger to print info messages for understandability.\n", 488 | "import logging\n", 489 | "import sys\n", 490 | "gym.undo_logger_setup() # Turn off gym's default logger settings\n", 491 | "logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')\n", 492 | "\n", 493 | "chainerrl.experiments.train_agent_with_evaluation(\n", 494 | " agent, env,\n", 495 | " steps=2000, # Train the agent for 2000 steps\n", 496 | " eval_n_runs=10, # 10 episodes are sampled for each evaluation\n", 497 | " max_episode_len=200, # Maximum length of each episode\n", 498 | " eval_interval=1000, # Evaluate the agent after every 1000 steps\n", 499 | " outdir='result') # Save everything to 'result' directory" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": {}, 505 | "source": [ 506 | "That's all for the ChainerRL quickstart guide. To learn more about ChainerRL, please look at the `examples` directory and read and run the examples. Thank you!" 507 | ] 508 | } 509 | ], 510 | "metadata": { 511 | "kernelspec": { 512 | "display_name": "Python 3", 513 | "language": "python", 514 | "name": "python3" 515 | }, 516 | "language_info": { 517 | "codemirror_mode": { 518 | "name": "ipython", 519 | "version": 3 520 | }, 521 | "file_extension": ".py", 522 | "mimetype": "text/x-python", 523 | "name": "python", 524 | "nbconvert_exporter": "python", 525 | "pygments_lexer": "ipython3", 526 | "version": "3.6.1" 527 | } 528 | }, 529 | "nbformat": 4, 530 | "nbformat_minor": 1 531 | } 532 | -------------------------------------------------------------------------------- /8_chainer-for-theano-users.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Chainer for Theano Users\n", 8 | "\n", 9 | "As we mentioned [here](https://chainer.org/general/2017/09/29/thank-you-theano.html), the development of Theano will stop in a few weeks. Many aspects of Chainer were inspired by Theano's clean interface design, so we would like to introduce Chainer here by comparing it with Theano. We believe this article will help Theano users move to Chainer quickly." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "In this post, we assume that the modules below have been imported."
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": { 34 | "collapsed": true 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import theano\n", 39 | "import theano.tensor as T" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": { 46 | "collapsed": true 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import chainer\n", 51 | "import chainer.functions as F\n", 52 | "import chainer.links as L" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "First, let's summarize the key similarities and differences between Theano and Chainer.\n", 60 | "\n", 61 | "### Key similarities:\n", 62 | "\n", 63 | "- Python-based library\n", 64 | "- Functions can accept NumPy arrays\n", 65 | "- CPU/GPU support\n", 66 | "- Easy to write various operations as differentiable functions (custom layers)\n", 67 | "\n", 68 | "### Key differences:\n", 69 | "\n", 70 | "- Theano compiles the computational graph before running it\n", 71 | "- Chainer builds the computational graph at runtime\n", 72 | "- Chainer provides many high-level APIs for neural networks\n", 73 | "- Chainer supports distributed training with ChainerMN" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "## Define a parametric function\n", 81 | "\n", 82 | "A neural network basically consists of many parametric functions and activation functions, which are commonly called \"layers\". Let's see how creating a new parametric function differs between Theano and Chainer. In this example, we define a 2D convolution function in both libraries to show how to do the same thing with each of them. Note that Chainer already provides `chainer.links.Convolution2D`, so in practice you don't need to write the code below just to use 2D convolution as a building block of a network." 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Theano:" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": { 96 | "collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "class TheanoConvolutionLayer(object):\n", 101 | " \n", 102 | " def __init__(self, input, filter_shape, image_shape):\n", 103 | " # Prepare initial values of the parameter W\n", 104 | " spatial_dim = np.prod(filter_shape[2:])\n", 105 | " fan_in = filter_shape[1] * spatial_dim\n", 106 | " fan_out = filter_shape[0] * spatial_dim\n", 107 | " scale = np.sqrt(3.
/ fan_in)\n", 108 | " \n", 109 | " # Create the parameter W\n", 110 | " W_init = np.random.uniform(-scale, scale, filter_shape)\n", 111 | " self.W = theano.shared(W_init.astype(np.float32), borrow=True)\n", 112 | "\n", 113 | " # Create the paramter b\n", 114 | " b_init = np.zeros((filter_shape[0],))\n", 115 | " self.b = theano.shared(b_init.astype(np.float32), borrow=True)\n", 116 | "\n", 117 | " # Describe the convolution operation\n", 118 | " conv_out = T.nnet.conv2d(\n", 119 | " input=input,\n", 120 | " filters=self.W,\n", 121 | " filter_shape=filter_shape,\n", 122 | " input_shape=image_shape)\n", 123 | " \n", 124 | " # Add a bias\n", 125 | " self.output = conv_out + self.b.dimshuffle('x', 0, 'x', 'x')\n", 126 | " \n", 127 | " # Store paramters\n", 128 | " self.params = [self.W, self.b]" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "How can we use this class? In Theano, it defines the computation as code using symbols, but doesn't perform actual computation at that time. Namely, it defines the computational graph before run. To use the defined computational graph, we need to define another operator using `theano.function` which takes input variables and output variable." 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 5, 141 | "metadata": { 142 | "collapsed": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "batchsize = 32\n", 147 | "input_shape = (batchsize, 1, 28, 28)\n", 148 | "filter_shape = (6, 1, 5, 5)\n", 149 | "\n", 150 | "# Create a tensor that represents a minibatch\n", 151 | "x = T.fmatrix('x')\n", 152 | "input = x.reshape(input_shape)\n", 153 | "\n", 154 | "conv = TheanoConvolutionLayer(input, filter_shape, input_shape)\n", 155 | "f = theano.function([input], conv.output)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "`conv` is the definition of how to compute the output from the first argument `input`, and `f` is the actual operator. You can pass values to `f` to compute the result of convolution like this:" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 6, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "(32, 6, 24, 24) \n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "x_data = np.random.rand(32, 1, 28, 28).astype(np.float32)\n", 180 | "\n", 181 | "y = f(x_data)\n", 182 | "\n", 183 | "print(y.shape, type(y))" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "### Chainer:\n", 191 | "\n", 192 | "What about the case in Chainer? Theano is a more general framework for scientific calculation, while Chainer focuses on neural networks. So, Chainer has many high-level APIs that enable users to write the building blocks of neural networks easier. Well, how to write the same convolution operator in Chainer?" 
193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "metadata": { 199 | "collapsed": true 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "class ChainerConvolutionLayer(chainer.Link):\n", 204 | " \n", 205 | " def __init__(self, filter_shape):\n", 206 | " super().__init__()\n", 207 | " with self.init_scope():\n", 208 | " # Specify the way of initialize\n", 209 | " W_init = chainer.initializers.LeCunUniform()\n", 210 | " b_init = chainer.initializers.Zero()\n", 211 | " \n", 212 | " # Create a parameter object\n", 213 | " self.W = chainer.Parameter(W_init, filter_shape) \n", 214 | " self.b = chainer.Parameter(b_init, filter_shape[0])\n", 215 | " \n", 216 | " def __call__(self, x):\n", 217 | " return F.convolution_2d(x, self.W, self.b)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "Actually, as we said at the top of this article, Chainer has pre-implemented `chainer.links.Convolution2D` class for convolution. So, you don't need to implement the code above by yourself, but it shows how to do the same thing written in Theano above.\n", 225 | "\n", 226 | "You can create your own parametric function by defining a class inherited from `chainer.Link` as shown in the above. What computation will be applied to the input is described in `__call__` method.\n", 227 | "\n", 228 | "Then, how to use this class?" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 8, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "(32, 6, 24, 24) \n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "chainer_conv = ChainerConvolutionLayer(filter_shape)\n", 246 | "\n", 247 | "y = chainer_conv(x_data)\n", 248 | "\n", 249 | "print(y.shape, type(y), type(y.array))" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "Chainer provides many functions in `chainer.functions` and it takes NumPy array or `chainer.Variable` object as inputs. You can write arbitrary layer using those functions to make it differentiable. Note that a `chainer.Variable` object contains its actual data in `array` property." 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "**NOTE:**\n", 264 | "You can write the same thing using `L.Convolution2D` like this:" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 9, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "(32, 6, 24, 24) \n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "conv_link = L.Convolution2D(in_channels=1, out_channels=6, ksize=(5, 5))\n", 282 | "\n", 283 | "y = conv_link(x_data)\n", 284 | "\n", 285 | "print(y.shape, type(y), type(y.array))" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "## Use Theano function as a layer in Chainer\n", 293 | "\n", 294 | "How to port parametric functions written in Theano to `Link`s in Chainer is shown in the above chapter. But there's an easier way to port **non-parametric functions** from Theano to Chainer.\n", 295 | "\n", 296 | "Chainer provides [`TheanoFunction`](https://docs.chainer.org/en/latest/reference/generated/chainer.links.TheanoFunction.html?highlight=Theano) to wrap a Theano function as a `chainer.Link`. 
What you need to prepare is just the inputs and outputs of the Theano function you want to port to Chainer's `Link`. For example, a convolution function of Theano can be converted to a Chainer's `Link` as followings:" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 10, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "name": "stderr", 306 | "output_type": "stream", 307 | "text": [ 308 | "/home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/chainer/utils/experimental.py:104: FutureWarning: chainer.links.TheanoFunction is experimental. The interface can change in the future.\n", 309 | " FutureWarning)\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "x = T.fmatrix().reshape((32, 1, 28, 28))\n", 315 | "W = T.fmatrix().reshape((6, 1, 5, 5))\n", 316 | "b = T.fvector().reshape((6,))\n", 317 | "conv_out = T.nnet.conv2d(x, W) + b.dimshuffle('x', 0, 'x', 'x')\n", 318 | "\n", 319 | "f = L.TheanoFunction(inputs=[x, W, b], outputs=[conv_out])" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "It converts the Theano computational graph into Chainer's computational graph! So it's differentiable with the Chainer APIs, and easy to use as a building block of a network written in Chainer. But it takes `W` and `b` as input arguments, so it should be noted that it doesn't keep those parameters inside.\n", 327 | "\n", 328 | "Anyway, how to use this ported Theano function in a network in Chainer?" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 11, 334 | "metadata": { 335 | "collapsed": true 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "class MyNetworkWithTheanoConvolution(chainer.Chain):\n", 340 | " \n", 341 | " def __init__(self, theano_conv):\n", 342 | " super().__init__()\n", 343 | " self.theano_conv = theano_conv\n", 344 | " W_init = chainer.initializers.LeCunUniform()\n", 345 | " b_init = chainer.initializers.Zero()\n", 346 | " with self.init_scope():\n", 347 | " self.W = chainer.Parameter(W_init, (6, 1, 5, 5))\n", 348 | " self.b = chainer.Parameter(b_init, (6,))\n", 349 | " self.l1 = L.Linear(None, 100)\n", 350 | " self.l2 = L.Linear(100, 10)\n", 351 | " \n", 352 | " def __call__(self, x):\n", 353 | " h = self.theano_conv(x, self.W, self.b)\n", 354 | " h = F.relu(h)\n", 355 | " h = self.l1(h)\n", 356 | " h = F.relu(h)\n", 357 | " return self.l2(h)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "This class is a Chainer's model class which is inherited from `chainer.Chain`. This is a standard way to define a class in Chainer, but, look! it uses a Theano function as a layer inside `__call__` method. The first layer of this network is a convolution layer, and that layer is Theano function which runs computation with Theano.\n", 365 | "\n", 366 | "The usage of this network is completely same as the normal Chainer's models:" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 12, 372 | "metadata": {}, 373 | "outputs": [ 374 | { 375 | "name": "stderr", 376 | "output_type": "stream", 377 | "text": [ 378 | "/home/shunta/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/chainer/utils/experimental.py:104: FutureWarning: chainer.functions.TheanoFunction is experimental. 
The interface can change in the future.\n", 379 | " FutureWarning)\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "# Instantiate a model object\n", 385 | "model = MyNetworkWithTheanoConvolution(f)\n", 386 | "\n", 387 | "# And give an array/Variable to get the network output\n", 388 | "y = model(x_data)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "metadata": {}, 394 | "source": [ 395 | "This network takes a mini-batch of images whose shape is `(32, 1, 28, 28)` and outputs 10-dimensional vectors for each input image, so the shape of the output variable will be `(32, 10)`:" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 13, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "(32, 10)\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "print(y.shape)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": {}, 418 | "source": [ 419 | "This network is differentiable and the parameters of the Theano's convolution function which are defined in the constructer as `self.W` and `self.b` can be optimized through Chainer's optimizers normaly." 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 14, 425 | "metadata": { 426 | "collapsed": true 427 | }, 428 | "outputs": [], 429 | "source": [ 430 | "t = np.random.randint(0, 10, size=(32,)).astype(np.int32)\n", 431 | "loss = F.softmax_cross_entropy(y, t)\n", 432 | "\n", 433 | "model.cleargrads()\n", 434 | "loss.backward()" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "You can check the gradients calculated for the parameters `W` and `b` used in the Theano function `theano_conv`:" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 15, 447 | "metadata": { 448 | "collapsed": true 449 | }, 450 | "outputs": [], 451 | "source": [ 452 | "W_gradient = model.W.grad_var.array\n", 453 | "b_gradient = model.b.grad_var.array" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 16, 459 | "metadata": {}, 460 | "outputs": [ 461 | { 462 | "name": "stdout", 463 | "output_type": "stream", 464 | "text": [ 465 | "(6, 1, 5, 5) \n", 466 | "(6,) \n" 467 | ] 468 | } 469 | ], 470 | "source": [ 471 | "print(W_gradient.shape, type(W_gradient))\n", 472 | "print(b_gradient.shape, type(b_gradient))" 473 | ] 474 | } 475 | ], 476 | "metadata": { 477 | "kernelspec": { 478 | "display_name": "Python 3", 479 | "language": "python", 480 | "name": "python3" 481 | }, 482 | "language_info": { 483 | "codemirror_mode": { 484 | "name": "ipython", 485 | "version": 3 486 | }, 487 | "file_extension": ".py", 488 | "mimetype": "text/x-python", 489 | "name": "python", 490 | "nbconvert_exporter": "python", 491 | "pygments_lexer": "ipython3", 492 | "version": "3.6.1" 493 | } 494 | }, 495 | "nbformat": 4, 496 | "nbformat_minor": 2 497 | } 498 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Shunta Saito 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit 
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chainer-Notebooks 2 | 3 | ## Requirements 4 | 5 | - Chainer>=3.0.0rc1 6 | - CuPy>=2.0.0rc1 7 | 8 | ## Contents 9 | 10 | 1. [How to write a training loop in Chainer](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/1_training_loop_in_chainer.ipynb) 11 | 2. [Let’s try using the Trainer feature](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/2_how_to_use_trainer.ipynb) 12 | 3. [How to write ConvNet models in Chainer](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/3_write_convnet_in_chainer.ipynb) 13 | 4. [Write an RNN Language Model](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/4_RNN-language-model.ipynb) 14 | 5. [Word2Vec: Obtain word embeddings](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/5_word2vec.ipynb) 15 | 6. [ChainerRL Quickstart Guide](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/6_dqn_cartpole.ipynb) 16 | 7. [Train a ConvNet using multiple GPUs](https://github.com/mitmul/DSVM-Chainer-Notebooks/blob/master/7_multiple_gpus.ipynb) 17 | 8. [Chainer for Theano Users](https://github.com/mitmul/chainer-notebooks/blob/master/8_chainer-for-theano-users.ipynb) 18 | 9. [Vanilla LSTM with CuPy](https://github.com/mitmul/chainer-notebooks/blob/master/9_vanilla-LSTM-with-cupy.ipynb) 19 | 20 | -------------------------------------------------------------------------------- /center_context_word.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mitmul/chainer-notebooks/639db107034346055b51af98fafdffc9f4bd52d2/center_context_word.png -------------------------------------------------------------------------------- /gentxt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Example to generate text from a recurrent neural network language model. 3 | 4 | This code is ported from following implementation. 
5 | https://github.com/longjie/chainer-char-rnn/blob/master/sample.py 6 | 7 | """ 8 | import argparse 9 | import sys 10 | 11 | import numpy as np 12 | import six 13 | 14 | import chainer 15 | from chainer import cuda 16 | import chainer.functions as F 17 | import chainer.links as L 18 | from chainer import serializers 19 | 20 | import train_ptb 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--model', '-m', type=str, required=True, 26 | help='model data, saved by train_ptb.py') 27 | parser.add_argument('--primetext', '-p', type=str, required=True, 28 | default='', 29 | help='base text data, used for text generation') 30 | parser.add_argument('--seed', '-s', type=int, default=123, 31 | help='random seeds for text generation') 32 | parser.add_argument('--unit', '-u', type=int, default=650, 33 | help='number of units') 34 | parser.add_argument('--sample', type=int, default=1, 35 | help='negative value indicates NOT use random choice') 36 | parser.add_argument('--length', type=int, default=20, 37 | help='length of the generated text') 38 | parser.add_argument('--gpu', type=int, default=-1, 39 | help='GPU ID (negative value indicates CPU)') 40 | args = parser.parse_args() 41 | 42 | np.random.seed(args.seed) 43 | chainer.config.train = False 44 | 45 | xp = cuda.cupy if args.gpu >= 0 else np 46 | 47 | # load vocabulary 48 | vocab = chainer.datasets.get_ptb_words_vocabulary() 49 | ivocab = {} 50 | for c, i in vocab.items(): 51 | ivocab[i] = c 52 | 53 | # should be same as n_units , described in train_ptb.py 54 | n_units = args.unit 55 | 56 | lm = train_ptb.RNNForLM(len(vocab), n_units) 57 | model = L.Classifier(lm) 58 | 59 | serializers.load_npz(args.model, model) 60 | 61 | if args.gpu >= 0: 62 | cuda.get_device_from_id(args.gpu).use() 63 | model.to_gpu() 64 | 65 | model.predictor.reset_state() 66 | 67 | primetext = args.primetext 68 | if isinstance(primetext, six.binary_type): 69 | primetext = primetext.decode('utf-8') 70 | 71 | if primetext in vocab: 72 | prev_word = chainer.Variable(xp.array([vocab[primetext]], xp.int32)) 73 | else: 74 | print('ERROR: Unfortunately ' + primetext + ' is unknown.') 75 | exit() 76 | 77 | prob = F.softmax(model.predictor(prev_word)) 78 | sys.stdout.write(primetext + ' ') 79 | 80 | for i in six.moves.range(args.length): 81 | prob = F.softmax(model.predictor(prev_word)) 82 | if args.sample > 0: 83 | probability = cuda.to_cpu(prob.data)[0].astype(np.float64) 84 | probability /= np.sum(probability) 85 | index = np.random.choice(range(len(probability)), p=probability) 86 | else: 87 | index = np.argmax(cuda.to_cpu(prob.data)) 88 | 89 | if ivocab[index] == '': 90 | sys.stdout.write('.') 91 | else: 92 | sys.stdout.write(ivocab[index] + ' ') 93 | 94 | prev_word = chainer.Variable(xp.array([index], dtype=xp.int32)) 95 | 96 | sys.stdout.write('\n') 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /rnnlm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mitmul/chainer-notebooks/639db107034346055b51af98fafdffc9f4bd52d2/rnnlm.png -------------------------------------------------------------------------------- /rnnlm_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mitmul/chainer-notebooks/639db107034346055b51af98fafdffc9f4bd52d2/rnnlm_example.png 
-------------------------------------------------------------------------------- /skipgram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mitmul/chainer-notebooks/639db107034346055b51af98fafdffc9f4bd52d2/skipgram.png -------------------------------------------------------------------------------- /skipgram_detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mitmul/chainer-notebooks/639db107034346055b51af98fafdffc9f4bd52d2/skipgram_detail.png -------------------------------------------------------------------------------- /train_ptb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Sample script of recurrent neural network language model. 3 | 4 | This code is ported from the following implementation written in Torch. 5 | https://github.com/tomsercu/lstm 6 | 7 | """ 8 | from __future__ import division 9 | from __future__ import print_function 10 | import argparse 11 | 12 | import numpy as np 13 | 14 | import chainer 15 | import chainer.functions as F 16 | import chainer.links as L 17 | from chainer import training 18 | from chainer.training import extensions 19 | 20 | 21 | # Definition of a recurrent net for language modeling 22 | class RNNForLM(chainer.Chain): 23 | 24 | def __init__(self, n_vocab, n_units): 25 | super(RNNForLM, self).__init__() 26 | with self.init_scope(): 27 | self.embed = L.EmbedID(n_vocab, n_units) 28 | self.l1 = L.LSTM(n_units, n_units) 29 | self.l2 = L.LSTM(n_units, n_units) 30 | self.l3 = L.Linear(n_units, n_vocab) 31 | 32 | for param in self.params(): 33 | param.data[...] = np.random.uniform(-0.1, 0.1, param.data.shape) 34 | 35 | def reset_state(self): 36 | self.l1.reset_state() 37 | self.l2.reset_state() 38 | 39 | def __call__(self, x): 40 | h0 = self.embed(x) 41 | h1 = self.l1(F.dropout(h0)) 42 | h2 = self.l2(F.dropout(h1)) 43 | y = self.l3(F.dropout(h2)) 44 | return y 45 | 46 | 47 | # Dataset iterator to create a batch of sequences at different positions. 48 | # This iterator returns a pair of current words and the next words. Each 49 | # example is a part of sequences starting from the different offsets 50 | # equally spaced within the whole sequence. 51 | class ParallelSequentialIterator(chainer.dataset.Iterator): 52 | 53 | def __init__(self, dataset, batch_size, repeat=True): 54 | self.dataset = dataset 55 | self.batch_size = batch_size # batch size 56 | # Number of completed sweeps over the dataset. In this case, it is 57 | # incremented if every word is visited at least once after the last 58 | # increment. 59 | self.epoch = 0 60 | # True if the epoch is incremented at the last iteration. 61 | self.is_new_epoch = False 62 | self.repeat = repeat 63 | length = len(dataset) 64 | # Offsets maintain the position of each sequence in the mini-batch. 65 | self.offsets = [i * length // batch_size for i in range(batch_size)] 66 | # NOTE: this is not a count of parameter updates. It is just a count of 67 | # calls of ``__next__``. 68 | self.iteration = 0 69 | # use -1 instead of None internally 70 | self._previous_epoch_detail = -1. 71 | 72 | def __next__(self): 73 | # This iterator returns a list representing a mini-batch. Each item 74 | # indicates a different position in the original sequence. Each item is 75 | # represented by a pair of two word IDs. The first word is at the 76 | # "current" position, while the second word at the next position. 
77 | # At each iteration, the iteration count is incremented, which pushes 78 | # forward the "current" position. 79 | length = len(self.dataset) 80 | if not self.repeat and self.iteration * self.batch_size >= length: 81 | # If not self.repeat, this iterator stops at the end of the first 82 | # epoch (i.e., when all words are visited once). 83 | raise StopIteration 84 | cur_words = self.get_words() 85 | self._previous_epoch_detail = self.epoch_detail 86 | self.iteration += 1 87 | next_words = self.get_words() 88 | 89 | epoch = self.iteration * self.batch_size // length 90 | self.is_new_epoch = self.epoch < epoch 91 | if self.is_new_epoch: 92 | self.epoch = epoch 93 | 94 | return list(zip(cur_words, next_words)) 95 | 96 | @property 97 | def epoch_detail(self): 98 | # Floating point version of epoch. 99 | return self.iteration * self.batch_size / len(self.dataset) 100 | 101 | @property 102 | def previous_epoch_detail(self): 103 | if self._previous_epoch_detail < 0: 104 | return None 105 | return self._previous_epoch_detail 106 | 107 | def get_words(self): 108 | # It returns a list of current words. 109 | return [self.dataset[(offset + self.iteration) % len(self.dataset)] 110 | for offset in self.offsets] 111 | 112 | def serialize(self, serializer): 113 | # It is important to serialize the state to be recovered on resume. 114 | self.iteration = serializer('iteration', self.iteration) 115 | self.epoch = serializer('epoch', self.epoch) 116 | try: 117 | self._previous_epoch_detail = serializer( 118 | 'previous_epoch_detail', self._previous_epoch_detail) 119 | except KeyError: 120 | # guess previous_epoch_detail for older version 121 | self._previous_epoch_detail = self.epoch + \ 122 | (self.current_position - self.batch_size) / len(self.dataset) 123 | if self.epoch_detail > 0: 124 | self._previous_epoch_detail = max( 125 | self._previous_epoch_detail, 0.) 126 | else: 127 | self._previous_epoch_detail = -1. 128 | 129 | 130 | # Custom updater for truncated BackProp Through Time (BPTT) 131 | class BPTTUpdater(training.StandardUpdater): 132 | 133 | def __init__(self, train_iter, optimizer, bprop_len, device): 134 | super(BPTTUpdater, self).__init__( 135 | train_iter, optimizer, device=device) 136 | self.bprop_len = bprop_len 137 | 138 | # The core part of the update routine can be customized by overriding. 139 | def update_core(self): 140 | loss = 0 141 | # When we pass one iterator and optimizer to StandardUpdater.__init__, 142 | # they are automatically named 'main'. 143 | train_iter = self.get_iterator('main') 144 | optimizer = self.get_optimizer('main') 145 | 146 | # Progress the dataset iterator for bprop_len words at each iteration. 
147 | for i in range(self.bprop_len): 148 | # Get the next batch (a list of tuples of two word IDs) 149 | batch = train_iter.__next__() 150 | 151 | # Concatenate the word IDs to matrices and send them to the device 152 | # self.converter does this job 153 | # (it is chainer.dataset.concat_examples by default) 154 | x, t = self.converter(batch, self.device) 155 | 156 | # Compute the loss at this time step and accumulate it 157 | loss += optimizer.target(chainer.Variable(x), chainer.Variable(t)) 158 | 159 | optimizer.target.cleargrads() # Clear the parameter gradients 160 | loss.backward() # Backprop 161 | loss.unchain_backward() # Truncate the graph 162 | optimizer.update() # Update the parameters 163 | 164 | 165 | # Routine to rewrite the result dictionary of LogReport to add perplexity 166 | # values 167 | def compute_perplexity(result): 168 | result['perplexity'] = np.exp(result['main/loss']) 169 | if 'validation/main/loss' in result: 170 | result['val_perplexity'] = np.exp(result['validation/main/loss']) 171 | 172 | 173 | def main(): 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('--batchsize', '-b', type=int, default=20, 176 | help='Number of examples in each mini-batch') 177 | parser.add_argument('--bproplen', '-l', type=int, default=35, 178 | help='Number of words in each mini-batch ' 179 | '(= length of truncated BPTT)') 180 | parser.add_argument('--epoch', '-e', type=int, default=39, 181 | help='Number of sweeps over the dataset to train') 182 | parser.add_argument('--gpu', '-g', type=int, default=-1, 183 | help='GPU ID (negative value indicates CPU)') 184 | parser.add_argument('--gradclip', '-c', type=float, default=5, 185 | help='Gradient norm threshold to clip') 186 | parser.add_argument('--out', '-o', default='result', 187 | help='Directory to output the result') 188 | parser.add_argument('--resume', '-r', default='', 189 | help='Resume the training from snapshot') 190 | parser.add_argument('--test', action='store_true', 191 | help='Use tiny datasets for quick tests') 192 | parser.set_defaults(test=False) 193 | parser.add_argument('--unit', '-u', type=int, default=650, 194 | help='Number of LSTM units in each layer') 195 | parser.add_argument('--model', '-m', default='model.npz', 196 | help='Model file name to serialize') 197 | args = parser.parse_args() 198 | 199 | # Load the Penn Tree Bank long word sequence dataset 200 | train, val, test = chainer.datasets.get_ptb_words() 201 | n_vocab = max(train) + 1 # train is just an array of integers 202 | print('#vocab =', n_vocab) 203 | 204 | if args.test: 205 | train = train[:100] 206 | val = val[:100] 207 | test = test[:100] 208 | 209 | train_iter = ParallelSequentialIterator(train, args.batchsize) 210 | val_iter = ParallelSequentialIterator(val, 1, repeat=False) 211 | test_iter = ParallelSequentialIterator(test, 1, repeat=False) 212 | 213 | # Prepare an RNNLM model 214 | rnn = RNNForLM(n_vocab, args.unit) 215 | model = L.Classifier(rnn) 216 | model.compute_accuracy = False # we only want the perplexity 217 | if args.gpu >= 0: 218 | # Make a specified GPU current 219 | chainer.cuda.get_device_from_id(args.gpu).use() 220 | model.to_gpu() 221 | 222 | # Set up an optimizer 223 | optimizer = chainer.optimizers.SGD(lr=1.0) 224 | optimizer.setup(model) 225 | optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip)) 226 | 227 | # Set up a trainer 228 | updater = BPTTUpdater(train_iter, optimizer, args.bproplen, args.gpu) 229 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 230 | 231 
| eval_model = model.copy() # Model with shared params and distinct states 232 | eval_rnn = eval_model.predictor 233 | trainer.extend(extensions.Evaluator( 234 | val_iter, eval_model, device=args.gpu, 235 | # Reset the RNN state at the beginning of each evaluation 236 | eval_hook=lambda _: eval_rnn.reset_state())) 237 | 238 | interval = 10 if args.test else 500 239 | trainer.extend(extensions.LogReport(postprocess=compute_perplexity, 240 | trigger=(interval, 'iteration'))) 241 | trainer.extend(extensions.PrintReport( 242 | ['epoch', 'iteration', 'perplexity', 'val_perplexity'] 243 | ), trigger=(interval, 'iteration')) 244 | trainer.extend(extensions.ProgressBar( 245 | update_interval=1 if args.test else 10)) 246 | trainer.extend(extensions.snapshot()) 247 | trainer.extend(extensions.snapshot_object( 248 | model, 'model_iter_{.updater.iteration}')) 249 | if args.resume: 250 | chainer.serializers.load_npz(args.resume, trainer) 251 | 252 | trainer.run() 253 | 254 | # Evaluate the final model 255 | print('test') 256 | eval_rnn.reset_state() 257 | evaluator = extensions.Evaluator(test_iter, eval_model, device=args.gpu) 258 | result = evaluator() 259 | print('test perplexity:', np.exp(float(result['main/loss']))) 260 | 261 | # Serialize the final model 262 | chainer.serializers.save_npz(args.model, model) 263 | 264 | 265 | if __name__ == '__main__': 266 | main() 267 | -------------------------------------------------------------------------------- /trainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mitmul/chainer-notebooks/639db107034346055b51af98fafdffc9f4bd52d2/trainer.png --------------------------------------------------------------------------------
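The snippet below is a minimal usage sketch, not part of the repository: it only illustrates what the ParallelSequentialIterator defined in train_ptb.py yields. The toy word-ID sequence and batch size are made-up values, and it assumes it is run from the repository root (so that train_ptb can be imported) with Chainer and NumPy installed.

# Minimal sketch (not part of the repository): inspect the (current, next)
# word-ID pairs produced by ParallelSequentialIterator from train_ptb.py.
# The toy dataset and batch size below are made-up values.
import numpy as np

from train_ptb import ParallelSequentialIterator

words = np.arange(10, dtype=np.int32)  # pretend word IDs 0..9
it = ParallelSequentialIterator(words, batch_size=2)

# With 10 words and a batch size of 2, the read offsets are [0, 5], so the
# first mini-batch pairs each "current" word with its successor,
# i.e. the pairs (0, 1) and (5, 6).
print(next(it))

After training, main() stores the network with chainer.serializers.save_npz(args.model, model); to reuse it later, rebuild the same L.Classifier(RNNForLM(n_vocab, args.unit)) structure and restore the weights into it with chainer.serializers.load_npz.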