├── .DS_Store ├── .ipynb_checkpoints └── SimpleConvNet-checkpoint.ipynb ├── 2.jpeg ├── 5.jpeg ├── 7.jpeg └── SimpleConvNet.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanhuang/SimpleConvNet/43c1d44c83d9f82f45913e4fd42829159971912b/.DS_Store -------------------------------------------------------------------------------- /.ipynb_checkpoints/SimpleConvNet-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 简单的CNN实现。用mnist数据集训练" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Package" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 859, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import sys ,os \n", 24 | "import numpy as np\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import tensorflow as tf #只是用来加载mnist数据集\n", 27 | "from PIL import Image\n", 28 | "import pandas as pd \n", 29 | "import math" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## 加载MNIST数据集" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 860, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def one_hot_label(y):\n", 46 | " one_hot_label = np.zeros((y.shape[0],10))\n", 47 | " y = y.reshape(y.shape[0])\n", 48 | " one_hot_label[range(y.shape[0]),y] = 1\n", 49 | " return one_hot_label\n", 50 | " \n", 51 | " " 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 863, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "shape of x_train is :(60000, 1, 28, 28)\n", 64 | "shape of t_train is :(60000, 10)\n", 65 | "shape of x_test is :(10000, 1, 28, 28)\n", 66 | "shape of t_test is :(10000, 10)\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "# #(训练图像,训练标签),(测试图像,测试标签)\n", 72 | "# # mnist的图像均为28*28尺寸的数据,通道为1\n", 73 | "(x_train_origin,t_train_origin),(x_test_origin,t_test_origin) = tf.keras.datasets.mnist.load_data()\n", 74 | "X_train = x_train_origin/255.0\n", 75 | "X_test = x_test_origin/255.0\n", 76 | "m,h,w = x_train_origin.shape\n", 77 | "X_train = X_train.reshape((m,1,h,w))\n", 78 | "y_train = one_hot_label(t_train_origin)\n", 79 | "\n", 80 | "m,h,w = x_test_origin.shape\n", 81 | "X_test = X_test.reshape((m,1,h,w))\n", 82 | "y_test = one_hot_label(t_test_origin)\n", 83 | "print(\"shape of x_train is :\"+repr(X_train.shape))\n", 84 | "print(\"shape of t_train is :\"+repr(y_train.shape))\n", 85 | "print(\"shape of x_test is :\"+repr(X_test.shape))\n", 86 | "print(\"shape of t_test is :\"+repr(y_test.shape))\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## 显示图像" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 864, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "y is:5\n" 106 | ] 107 | }, 108 | { 109 | "data": { 110 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADgdJREFUeJzt3X9sXfV5x/HPs9D8QRoIXjUTpWFpIhQUIuZOJkwoGkXM5YeCggGhWkLKRBT3j1ii0hQNZX8MNAVFg2RqBKrsqqHJ1KWZBCghqpp0CZBOTBEmhF9mKQylqi2TFAWTH/zIHD/74x53Lvh+r3Pvufdc+3m/JMv3nuecex4d5ZPz8/pr7i4A8fxJ0Q0AKAbhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q1GWNXJmZ8TghUGfublOZr6Y9v5ndYWbHzex9M3ukls8C0FhW7bP9ZjZL0m8kdUgalPSqpC53H0gsw54fqLNG7PlXSHrf3T9w9wuSfi5pdQ2fB6CBagn/Akm/m/B+MJv2R8ys28z6zay/hnUByFndL/i5e5+kPonDfqCZ1LLnH5K0cML7b2bTAEwDtYT/VUnXmtm3zGy2pO9J2ptPWwDqrerDfncfNbMeSfslzZK03d3fya0zAHVV9a2+qlbGOT9Qdw15yAfA9EX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFUP0S1JZnZC0llJFyWNunt7Hk0hP7NmzUrWr7zyyrquv6enp2zt8ssvTy67dOnSZH39+vXJ+pNPPlm21tXVlVz2888/T9Y3b96crD/22GPJejOoKfyZW939oxw+B0ADcdgPBFVr+F3SATN7zcy682gIQGPUeti/0t2HzOzPJP3KzP7b3Q9PnCH7T4H/GIAmU9Oe392Hst+nJD0vacUk8/S5ezsXA4HmUnX4zWyOmc0dfy3pu5LezqsxAPVVy2F/q6TnzWz8c/7N3X+ZS1cA6q7q8Lv7B5L+IsdeZqxrrrkmWZ89e3ayfvPNNyfrK1euLFubN29ectn77rsvWS/S4OBgsr5t27ZkvbOzs2zt7NmzyWXfeOONZP3ll19O1qcDbvUBQRF+ICjCDwRF+IGgCD8QFOEHgjJ3b9zKzBq3sgZqa2tL1g8dOpSs1/trtc1qbGwsWX/ooYeS9XPnzlW97uHh4WT9448/TtaPHz9e9brrzd1tKvOx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoLjPn4OWlpZk/ciRI8n64sWL82wnV5V6HxkZSdZvvfXWsrULFy4kl436/EOtuM8PIInwA0ERfiAowg8ERfiBoAg/EBThB4LKY5Te8E6fPp2sb9iwIVlftWpVsv76668n65X+hHXKsWPHkvWOjo5k/fz588n69ddfX7b28MMPJ5dFfbHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKn6f38y2S1ol6ZS7L8+mtUjaLWmRpBOSHnD39B8618z9Pn+trrjiimS90nDSvb29ZWtr165NLvvggw8m67t27UrW0Xzy/D7/TyXd8aVpj0g66O7XSjqYvQcwjVQMv7sflvTlR9hWS9qRvd4h6Z6c+wJQZ9We87e6+/h4Rx9Kas2pHwANUvOz/e7uqXN5M+uW1F3regDkq9o9/0kzmy9J2e9T5WZ09z53b3f39irXBaAOqg3/XklrstdrJO3Jpx0AjVIx/Ga2S9J/SVpqZoNmtlbSZkkdZvaepL/J3gOYRiqe87t7V5nSbTn3EtaZM2dqWv6TTz6petl169Yl67t3707Wx8bGql43isUTfkBQhB8IivADQRF+ICjCDwRF+IGgGKJ7BpgzZ07Z2gsvvJBc9pZbbknW77zzzmT9wIEDyToajyG6ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQ3Oef4ZYsWZKsHz16NFkfGRlJ1l988cVkvb+/v2zt6aefTi7byH+bMwn3+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUNznD66zszNZf+aZZ5L1uXPnVr3ujRs3Jus7d+5M1oeHh5P1qLjPDyCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCqnif38y2S1ol6ZS7L8+mPSppnaTfZ7NtdPdfVFwZ9/mnneXLlyfrW7duTdZvu636kdx7e3uT9U2bNiXrQ0NDVa97OsvzPv9PJd0xyfR/cfe27Kdi8AE0l4rhd/fDkk43oBcADVTLOX+Pmb1pZtvN7KrcOgLQENWG/0eSlkhqkzQsaUu5Gc2s28z6zaz8H3MD0HBVhd/dT7r7RXcfk/RjSSsS8/a5e7u7t1fbJID8VRV+M5s/4W2npLfzaQdAo1xWaQYz2yXpO5K+YWaDkv5R0nfMrE2SSzoh6ft17BFAHfB9ftRk3rx5yfrdd99dtlbpbwWYpW9XHzp0KFnv6OhI1mcqvs8PIInwA0ERfiAowg8ERfiBoAg/EBS3+lCYL774Ilm/7LL0Yyijo6PJ+u2331629tJLLyWXnc641QcgifADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX7EdsMNNyTr999/f7J+4403lq1Vuo9fycDAQLJ++PDhmj5/pmPPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZ9/hlu6dGmy3tPTk6zfe++9yfrVV199yT1N1cWLF5P14eHhZH1sbCzPdmYc9vxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFTF+/xmtlDSTkmtklxSn7v/0MxaJO2WtEjSCUkPuPvH9Ws1rkr30ru6usrWKt3HX7RoUTUt5aK/vz9Z37RpU7K+d+/ePNsJZyp7/lFJf+fuyyT9laT1ZrZM0iOSDrr7tZIOZu8BTBMVw+/uw+5+NHt9VtK7khZIWi1pRzbbDkn31KtJAPm7pHN+M1sk6duSjkhqdffx5ys/VOm0AMA0MeVn+83s65KelfQDdz9j9v/Dgbm7lxuHz8y6JXXX2iiAfE1pz29mX1Mp+D9z9+eyySfNbH5Wny/p1GTLunufu7e7e3seDQPIR8XwW2kX/xNJ77r71gmlvZLWZK/XSNqTf3sA6qXiEN1mtlLSryW9JWn8O5IbVTrv/3dJ10j6rUq3+k5X+KyQQ3S3tqYvhyxbtixZf+qpp5L166677pJ7ysuRI0eS9SeeeKJsbc+e9P6Cr+RWZ6pDdFc853f3/5RU7sNuu5SmADQPnvADgiL8QFCEHwiK8ANBEX4gKMIPBMWf7p6ilpaWsrXe3t7ksm1tbcn64sWLq+opD6+88kqyvmXLlmR9//79yfpnn312yT2hMdjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQYe7z33TTTcn6hg0bkvUVK1aUrS1YsKCqnvLy6aeflq1t27Ytuezjjz+erJ8/f76qntD82PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBh7vN3dnbWVK/FwMBAsr5v375kfXR0NFlPfed+ZGQkuSziYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0GZu6dnMFsoaaekVkkuqc/df2hmj0paJ+n32awb3f0XFT4rvTIANXN3m8p8Uwn/fEnz3f2omc2V9JqkeyQ9IOmcuz851aYIP1B/Uw1/xSf83H1Y0nD2+qyZvSup2D9dA6Bml3TOb2aLJH1b0pFsUo+ZvWlm283sqjLLdJtZv5n119QpgFxVPOz/w4xmX5f0sqRN7v6cmbVK+kil6wD/pNKpwUMVPoPDfqDOcjvnlyQz+5qkfZL2u/vWSeqLJO1z9+UVPofwA3U21fBXPOw3M5P0E0nvTgx+diFwXKekty+1SQDFmcrV/pWSfi3pLUlj2eSNkroktal02H9C0vezi4Opz2LPD9RZrof9eSH8QP3ldtgPYGYi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBNXoIbo/kvTbCe+/kU1rRs3aW7P2JdFbtfLs7c+nOmNDv8//lZWb9bt7e2ENJDRrb83al0Rv1SqqNw77gaAIPxBU0eHvK3j9Kc3aW7P2JdFbtQrprdBzfgDFKXrPD6AghYTfzO4ws+Nm9r6ZPVJED+WY2Qkze8vMjhU9xFg2DNopM3t7wrQWM/uVmb2X/Z50mLSCenvUzIaybXfMzO4qqLeFZvaimQ2Y2Ttm9nA2vdBtl+irkO3W8MN+M5sl6TeSOiQNSnpVUpe7DzS0kTLM7ISkdncv/J6wmf21pHOSdo6PhmRm/yzptLtvzv7jvMrd/75JentUlzhyc516Kzey9N+qwG2X54jXeShiz79C0vvu/oG7X5D0c0mrC+ij6bn7YUmnvzR5taQd2esdKv3jabgyvTUFdx9296PZ67OSxkeWLnTbJfoqRBHhXyDpdxPeD6q5hvx2SQfM7DUz6y66mUm0ThgZ6UNJrUU2M4mKIzc30pdGlm6abVfNiNd544LfV61097+UdKek9dnhbVPy0jlbM92u+ZGkJSoN4zYsaUuRzWQjSz8r6QfufmZirchtN0lfhWy3IsI/JGnhhPffzKY1BXcfyn6fkvS8SqcpzeTk+CCp2e9TBffzB+5+0t0vuvuYpB+rwG2XjSz9rKSfuftz2eTCt91kfRW13YoI/6uSrjWzb5nZbEnfk7S3gD6+wszmZBdiZGZzJH1XzTf68F5Ja7LXayTtKbCXP9IsIzeXG1laBW+7phvx2t0b/iPpLpWu+P+PpH8ooocyfS2W9Eb2807RvUnapdJh4P+qdG1kraQ/lXRQ0nuS/kNSSxP19q8qjeb8pkpBm19QbytVOqR/U9Kx7Oeuorddoq9CthtP+AFBccEPCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ/weCC5r/92q6mAAAAABJRU5ErkJggg==\n", 111 | "text/plain": [ 112 | "
" 113 | ] 114 | }, 115 | "metadata": { 116 | "needs_background": "light" 117 | }, 118 | "output_type": "display_data" 119 | } 120 | ], 121 | "source": [ 122 | "index = 0\n", 123 | "plt.imshow(X_train[index].reshape((28,28)),cmap = plt.cm.gray)\n", 124 | "print(\"y is:\"+str(np.argmax(y_train[index])))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## 辅助函数" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "将根据滤波器的大小,步幅,填充,输入数据展开为2维数组。" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 865, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "def im2col2(input_data,fh,fw,stride=1,pad=0):\n", 148 | " '''\n", 149 | " Arguments:\n", 150 | " \n", 151 | " input_data--输入数据,shape为(Number of example,Channel,Height,Width)\n", 152 | " fh -- 滤波器的height\n", 153 | " fw --滤波器的width\n", 154 | " stride -- 步幅\n", 155 | " pad -- 填充\n", 156 | " \n", 157 | " Returns :\n", 158 | " col -- 输入数据根据滤波器、步幅等展开的二维数组,每一行代表一条卷积数据\n", 159 | " '''\n", 160 | " N,C,H,W = input_data.shape\n", 161 | " \n", 162 | " out_h = (H + 2*pad - fh)//stride+1\n", 163 | " out_w = (W+2*pad-fw)//stride+1\n", 164 | " \n", 165 | " img = np.pad(input_data,[(0,0),(0,0),(pad,pad),(pad,pad)],\"constant\")\n", 166 | " \n", 167 | " col = np.zeros((N,out_h,out_w,fh*fw*C))\n", 168 | " \n", 169 | " #将所有维度上需要卷积的值展开成一列\n", 170 | " for y in range(out_h):\n", 171 | " y_start = y * stride\n", 172 | " y_end = y_start + fh\n", 173 | " for x in range(out_w):\n", 174 | " x_start = x*stride\n", 175 | " x_end = x_start+fw\n", 176 | " col[:,y,x] = img[:,:,y_start:y_end,x_start:x_end].reshape(N,-1)\n", 177 | " col = col.reshape(N*out_h*out_w,-1)\n", 178 | " return col\n", 179 | "\n", 180 | " " 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 866, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "def col2im2(col,out_shape,fh,fw,stride=1,pad=0):\n", 190 | " '''\n", 191 | " Arguments:\n", 192 | " col: 二维数组 \n", 193 | " out_shape-- 输出的shape,shape为(Number of example,Channel,Height,Width)\n", 194 | " fh -- 滤波器的height\n", 195 | " fw --滤波器的width\n", 196 | " stride -- 步幅\n", 197 | " pad -- 填充\n", 198 | " \n", 199 | " Returns :\n", 200 | " img -- 将col转换成的img ,shape为out_shape\n", 201 | " '''\n", 202 | " N,C,H,W = out_shape\n", 203 | " \n", 204 | " col_m,col_n = col.shape\n", 205 | " \n", 206 | " out_h = (H + 2*pad - fh)//stride+1\n", 207 | " out_w = (W+2*pad-fw)//stride+1\n", 208 | "\n", 209 | " \n", 210 | "\n", 211 | " img = np.zeros((N, C, H , W))\n", 212 | " # img = np.pad(img,[(0,0),(0,0),(pad,pad),(pad,pad)],\"constant\")\n", 213 | "\n", 214 | " #将col转换成一个filter\n", 215 | " for c in range(C):\n", 216 | " for y in range(out_h):\n", 217 | " for x in range(out_w):\n", 218 | " col_index = (c*out_h*out_w)+y*out_w+x\n", 219 | " ih = y*stride\n", 220 | " iw = x*stride\n", 221 | " img[:,c,ih:ih+fh,iw:iw+fw] = col[col_index].reshape((fh,fw))\n", 222 | " return img\n" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 868, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "[[[[0 2 1]\n", 235 | " [2 0 2]\n", 236 | " [1 3 0]]]]\n", 237 | "[[0. 2. 2. 0.]\n", 238 | " [2. 1. 0. 2.]\n", 239 | " [2. 0. 1. 3.]\n", 240 | " [0. 2. 3. 0.]]\n", 241 | "===\n" 242 | ] 243 | }, 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "array([[[[0., 2., 1.],\n", 248 | " [2., 0., 2.],\n", 249 | " [1., 3., 0.]]]])" 250 | ] 251 | }, 252 | "execution_count": 868, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "\n", 259 | "a = np.random.randint(0,5,size=(2,3,4,4))\n", 260 | "c = im2col2(a,2,2,2,1)\n", 261 | "\n", 262 | "a = np.random.randint(0,5,size=(1,1,3,3))\n", 263 | "print(a)\n", 264 | "c = im2col2(a,2,2,1,0)\n", 265 | "print(c)\n", 266 | "\n", 267 | "print('===')\n", 268 | "col2im2(c,out_shape=(1,1,3,3),fh=2,fw=2,stride=1)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "### 激活函数" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 869, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "def relu(input_X):\n", 285 | " \"\"\"\n", 286 | " Arguments:\n", 287 | " input_X -- a numpy array\n", 288 | " Return :\n", 289 | " A: a numpy array. let each elements in array all greater or equal 0\n", 290 | " \"\"\"\n", 291 | " \n", 292 | " A = np.where(input_X < 0 ,0,input_X)\n", 293 | " return A\n", 294 | " " 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 870, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "def softmax(input_X):\n", 304 | " \"\"\"\n", 305 | " Arguments:\n", 306 | " input_X -- a numpy array\n", 307 | " Return :\n", 308 | " A: a numpy array same shape with input_X\n", 309 | " \"\"\"\n", 310 | " exp_a = np.exp(input_X)\n", 311 | " sum_exp_a = np.sum(exp_a,axis=1)\n", 312 | " sum_exp_a = sum_exp_a.reshape(input_X.shape[0],-1)\n", 313 | " ret = exp_a/sum_exp_a\n", 314 | " # print(ret)\n", 315 | " return ret" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "### 损失函数" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 871, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "def cross_entropy_error(labels,logits):\n", 332 | " return -np.sum(labels*np.log(logits))\n" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 872, 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "text/plain": [ 343 | "-1.6094379124341003" 344 | ] 345 | }, 346 | "execution_count": 872, 347 | "metadata": {}, 348 | "output_type": "execute_result" 349 | } 350 | ], 351 | "source": [ 352 | "a = np.array([[1,0,0],[0,1,0]])\n", 353 | "a = a.T\n", 354 | "c = np.random.randint(1, 7,size = (3,2))\n", 355 | "cross_entropy_error(a,c)\n", 356 | "\n" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "## 卷积层" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 911, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "class Convolution:\n", 373 | " def __init__(self,W,fb,stride = 1,pad = 0):\n", 374 | " \"\"\"\n", 375 | " W-- 滤波器权重,shape为(FN,NC,FH,FW),FN 为滤波器的个数\n", 376 | " fb -- 滤波器的偏置,shape 为(1,FN) \n", 377 | " stride -- 步长\n", 378 | " pad -- 填充个数\n", 379 | " \"\"\"\n", 380 | " self.W = W\n", 381 | " self.fb = fb \n", 382 | " self.stride = stride\n", 383 | " self.pad = pad\n", 384 | " \n", 385 | " \n", 386 | " self.col_X = None\n", 387 | " self.X = None\n", 388 | " self.col_W = None\n", 389 | " \n", 390 | " self.dW = None\n", 391 | " self.db = None\n", 392 | " self.out_shape = None\n", 393 | " # self.out = None\n", 394 | " \n", 395 | " def forward (self ,input_X):\n", 396 | " \"\"\"\n", 397 | " input_X-- shape为(m,nc,height,width)\n", 398 | " \"\"\" \n", 399 | " self.X = input_X\n", 400 | " FN,NC,FH,FW = self.W.shape\n", 401 | " \n", 402 | " m,input_nc, input_h,input_w = self.X.shape\n", 403 | " \n", 404 | " #先计算输出的height和widt\n", 405 | " out_h = int((input_h+2*self.pad-FH)/self.stride + 1)\n", 406 | " out_w = int((input_w+2*self.pad-FW)/self.stride + 1)\n", 407 | " \n", 408 | " #将输入数据展开成二维数组,shape为(m*out_h*out_w,FH*FW*C)\n", 409 | " self.col_X = col_X = im2col2(self.X,FH,FW,self.stride,self.pad)\n", 410 | " \n", 411 | " #将滤波器一个个按列展开(FH*FW*C,FN)\n", 412 | " self.col_W = col_W = self.W.reshape(FN,-1).T\n", 413 | " out = np.dot(col_X,col_W)+self.fb\n", 414 | " out = out.T\n", 415 | " out = out.reshape(m,FN,out_h,out_w)\n", 416 | " self.out_shape = out.shape\n", 417 | " return out\n", 418 | " \n", 419 | " def backward(self, dz,learning_rate):\n", 420 | " #print(\"==== Conv backbward ==== \")\n", 421 | " assert(dz.shape == self.out_shape)\n", 422 | " \n", 423 | " FN,NC,FH,FW = self.W.shape\n", 424 | " o_FN,o_NC,o_FH,o_FW = self.out_shape\n", 425 | " \n", 426 | " col_dz = dz.reshape(o_NC,-1)\n", 427 | " col_dz = col_dz.T\n", 428 | " \n", 429 | " self.dW = np.dot(self.col_X.T,col_dz) #shape is (FH*FW*C,FN)\n", 430 | " self.db = np.sum(col_dz,axis=0,keepdims=True)\n", 431 | "\n", 432 | " \n", 433 | " self.dW = self.dW.T.reshape(self.W.shape)\n", 434 | " self.db = self.db.reshape(self.fb.shape)\n", 435 | " \n", 436 | " \n", 437 | " d_col_x = np.dot(col_dz,self.col_W.T) #shape is (m*out_h*out_w,FH,FW*C)\n", 438 | " dx = col2im2(d_col_x,self.X.shape,FH,FW,stride=1)\n", 439 | " \n", 440 | " assert(dx.shape == self.X.shape)\n", 441 | " \n", 442 | " #更新W和b\n", 443 | " self.W = self.W - learning_rate*self.dW\n", 444 | " self.fb = self.fb -learning_rate*self.db\n", 445 | " \n", 446 | " return dx\n", 447 | " \n", 448 | " " 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 912, 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "name": "stdout", 458 | "output_type": "stream", 459 | "text": [ 460 | "(2, 1, 2, 2)\n", 461 | "(1, 2)\n", 462 | "(1, 1, 3, 3)\n", 463 | "[[[[ 6. 8.]\n", 464 | " [12. 14.]]\n", 465 | "\n", 466 | " [[ 1. 2.]\n", 467 | " [ 4. 5.]]]]\n", 468 | "(1, 2, 2, 2)\n" 469 | ] 470 | } 471 | ], 472 | "source": [ 473 | "#W = np.array([[[[1,0],[0,1]]]])\n", 474 | "#W = np.array([[[[1,0],[0,1]]],[[[1,0],[0,0]]],[[[0,0],[0,0]]]])\n", 475 | "W = np.array([[[[1,0],[0,1]]],[[[1,0],[0,0]]]])\n", 476 | "print(W.shape)\n", 477 | "b = np.zeros((1,2))\n", 478 | "print(b.shape)\n", 479 | "conv = Convolution(W,b,stride=1,pad=0)\n", 480 | "\n", 481 | "a = np.array([[[[1,2,3],[4,5,6],[7,8,9]]]])\n", 482 | "print(a.shape)\n", 483 | "out = conv.forward(a)\n", 484 | "print(out)\n", 485 | "print(out.shape)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "metadata": {}, 491 | "source": [ 492 | "## 池化层" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 875, 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [ 501 | "class Pooling:\n", 502 | " def __init__(self,pool_h,pool_w,stride = 1,pad = 0):\n", 503 | " self.pool_h = pool_h\n", 504 | " self.pool_w = pool_w\n", 505 | " self.stride = stride\n", 506 | " self.pad = pad \n", 507 | " self.X = None\n", 508 | " self.arg_max = None\n", 509 | " \n", 510 | " def forward ( self,input_X) :\n", 511 | " \"\"\"\n", 512 | " 前向传播\n", 513 | " input_X-- shape为(m,nc,height,width)\n", 514 | " \"\"\" \n", 515 | " self.X = input_X\n", 516 | " N , C, H, W = input_X.shape\n", 517 | " out_h = int(1+(H-self.pool_h)/self.stride)\n", 518 | " out_w = int(1+(W-self.pool_w)/self.stride)\n", 519 | " \n", 520 | " #展开\n", 521 | " col = im2col2(input_X,self.pool_h,self.pool_w,self.stride,self.pad)\n", 522 | " col = col.reshape(-1,self.pool_h*self.pool_w)\n", 523 | " arg_max = np.argmax(col,axis=1)\n", 524 | " #最大值\n", 525 | " out = np.max(col,axis=1)\n", 526 | " out =out.T.reshape(N,C,out_h,out_w)\n", 527 | " self.arg_max = arg_max\n", 528 | " return out\n", 529 | " \n", 530 | " def backward(self ,dz):\n", 531 | " \"\"\"\n", 532 | " 反向传播\n", 533 | " Arguments:\n", 534 | " dz-- out的导数,shape与out 一致\n", 535 | " \n", 536 | " Return:\n", 537 | " 返回前向传播是的input_X的导数\n", 538 | " \"\"\" \n", 539 | " pool_size = self.pool_h*self.pool_w\n", 540 | " dmax = np.zeros((dz.size,pool_size))\n", 541 | " dmax[np.arange(self.arg_max.size),self.arg_max.flatten()] = dz.flatten()\n", 542 | " \n", 543 | " dx = col2im2(dmax,out_shape=self.X.shape,fh=self.pool_h,fw=self.pool_w,stride=self.stride)\n", 544 | " return dx\n", 545 | " \n", 546 | " " 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 876, 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "name": "stdout", 556 | "output_type": "stream", 557 | "text": [ 558 | "[[[[5. 6.]\n", 559 | " [8. 9.]]]]\n" 560 | ] 561 | } 562 | ], 563 | "source": [ 564 | "pool = Pooling(2,2,1,0)\n", 565 | "a = np.array([[[[1,2,3],[4,5,6],[7,8,9]]]])\n", 566 | "out = pool.forward(a)\n", 567 | "print(out)\n", 568 | "# print(a.size)\n", 569 | "# a = np.array([[[[1,2,3],[4,5,6],[7,8,9]]]])\n", 570 | "# print(a.size)\n", 571 | "# a.flatten()\n" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": {}, 577 | "source": [ 578 | "## Relu层" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 877, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "class Relu:\n", 588 | " def __init__(self):\n", 589 | " self.mask = None\n", 590 | " \n", 591 | " def forward(self ,X):\n", 592 | " self.mask = X <= 0\n", 593 | " out = X\n", 594 | " out[self.mask] = 0\n", 595 | " return out\n", 596 | " \n", 597 | " def backward(self,dz):\n", 598 | " dz[self.mask] = 0\n", 599 | " dx = dz \n", 600 | " return dx" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 878, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "name": "stdout", 610 | "output_type": "stream", 611 | "text": [ 612 | "[[False False]\n", 613 | " [ True False]]\n" 614 | ] 615 | }, 616 | { 617 | "data": { 618 | "text/plain": [ 619 | "array([[1, 1],\n", 620 | " [2, 1]])" 621 | ] 622 | }, 623 | "execution_count": 878, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "arr = np.array([[1,2],[3,-0.4]])\n", 630 | "mask = arr == np.max(arr)\n", 631 | "print(mask)\n", 632 | "\n", 633 | "barr = np.array([[1,1],[1,1]])\n", 634 | "barr+mask\n" 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "metadata": {}, 640 | "source": [ 641 | "## SoftMax层" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 879, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [ 650 | "class SoftMax:\n", 651 | " def __init__ (self):\n", 652 | " self.y_hat = None\n", 653 | " \n", 654 | " def forward(self,X):\n", 655 | " \n", 656 | " self.y_hat = softmax(X)\n", 657 | " return self.y_hat\n", 658 | " \n", 659 | " def backward(self,labels):\n", 660 | " m = labels.shape[0]\n", 661 | " dx = (self.y_hat - labels)\n", 662 | " \n", 663 | " return dx\n", 664 | " " 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": 880, 670 | "metadata": {}, 671 | "outputs": [], 672 | "source": [ 673 | "def compute_cost(logits,label):\n", 674 | " return cross_entropy_error(label,logits)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "markdown", 679 | "metadata": {}, 680 | "source": [ 681 | "## Affine FC层" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": 913, 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "class Affine:\n", 691 | " def __init__(self,W,b):\n", 692 | " self.W = W # shape is (n_x,n_unit)\n", 693 | " self.b = b # shape is(1,n_unit)\n", 694 | " self.X = None\n", 695 | " self.origin_x_shape = None\n", 696 | " \n", 697 | " self.dW = None\n", 698 | " self.db = None\n", 699 | " \n", 700 | " self.out_shape =None\n", 701 | " \n", 702 | " def forward(self,X):\n", 703 | " self.origin_x_shape = X.shape \n", 704 | " self.X = X.reshape(X.shape[0],-1)#(m,n)\n", 705 | " out = np.dot(self.X, self.W)+self.b\n", 706 | " self.out_shape = out.shape\n", 707 | " return out\n", 708 | " \n", 709 | " def backward(self,dz,learning_rate):\n", 710 | " \"\"\"\n", 711 | " dz-- 前面的导数\n", 712 | " \"\"\" \n", 713 | "# print(\"Affine backward\")\n", 714 | "# print(self.X.shape)\n", 715 | "# print(dz.shape)\n", 716 | "# print(self.W.shape)\n", 717 | " \n", 718 | " assert(dz.shape == self.out_shape)\n", 719 | " \n", 720 | " m = self.X.shape[0]\n", 721 | " \n", 722 | " self.dW = np.dot(self.X.T,dz)/m\n", 723 | " self.db = np.sum(dz,axis=0,keepdims=True)/m\n", 724 | " \n", 725 | " assert(self.dW.shape == self.W.shape)\n", 726 | " assert(self.db.shape == self.b.shape)\n", 727 | " \n", 728 | " dx = np.dot(dz,self.W.T)\n", 729 | " assert(dx.shape == self.X.shape)\n", 730 | " \n", 731 | " dx = dx.reshape(self.origin_x_shape) # 保持与之前的x一样的shape\n", 732 | " \n", 733 | " #更新W和b\n", 734 | " self.W = self.W-learning_rate*self.dW\n", 735 | " self.b = self.b - learning_rate*self.db\n", 736 | " \n", 737 | " return dx\n", 738 | " " 739 | ] 740 | }, 741 | { 742 | "cell_type": "markdown", 743 | "metadata": {}, 744 | "source": [ 745 | "## 模型\n" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": 964, 751 | "metadata": { 752 | "code_folding": [] 753 | }, 754 | "outputs": [], 755 | "source": [ 756 | "class SimpleConvNet:\n", 757 | "\n", 758 | " def __init__(self):\n", 759 | " self.X = None\n", 760 | " self.Y= None\n", 761 | " self.layers = []\n", 762 | "\n", 763 | " def add_conv_layer(self,n_filter,n_c , f, stride=1, pad=0):\n", 764 | " \"\"\"\n", 765 | " 添加一层卷积层\n", 766 | " Arguments:\n", 767 | " n_c -- 输入数据通道数,也即卷积层的通道数\n", 768 | " n_filter -- 滤波器的个数\n", 769 | " f --滤波器的长/宽\n", 770 | "\n", 771 | " Return :\n", 772 | " Conv -- 卷积层\n", 773 | " \"\"\"\n", 774 | "\n", 775 | " # 初始化W,b\n", 776 | " W = np.random.randn(n_filter, n_c, f, f)*0.01\n", 777 | " fb = np.zeros((1, n_filter))\n", 778 | " # 卷积层\n", 779 | " Conv = Convolution(W, fb, stride=stride, pad=pad)\n", 780 | " return Conv\n", 781 | "\n", 782 | " def add_maxpool_layer(self, pool_shape, stride=1, pad=0):\n", 783 | " \"\"\"\n", 784 | " 添加一层池化层\n", 785 | " Arguments:\n", 786 | " pool_shape -- 滤波器的shape\n", 787 | " f -- 滤波器大小\n", 788 | " Return :\n", 789 | " Pool -- 初始化的Pool类\n", 790 | " \"\"\"\n", 791 | " pool_h, pool_w = pool_shape\n", 792 | " pool = Pooling(pool_h, pool_w, stride=stride, pad=pad)\n", 793 | " \n", 794 | " return pool\n", 795 | " \n", 796 | " def add_affine(self,n_x, n_units):\n", 797 | " \"\"\"\n", 798 | " 添加一层全连接层\n", 799 | " Arguments:\n", 800 | " n_x -- 输入个数\n", 801 | " n_units -- 神经元个数\n", 802 | " Return :\n", 803 | " fc_layer -- Affine层对象\n", 804 | " \"\"\"\n", 805 | " \n", 806 | " W= np.random.randn(n_x, n_units)*0.01\n", 807 | " \n", 808 | " b = np.zeros((1, n_units))\n", 809 | " \n", 810 | " fc_layer = Affine(W,b)\n", 811 | " \n", 812 | " return fc_layer\n", 813 | " \n", 814 | " def add_relu(self):\n", 815 | " relu_layer = Relu()\n", 816 | " return relu_layer\n", 817 | " \n", 818 | " \n", 819 | " def add_softmax(self):\n", 820 | " softmax_layer = SoftMax()\n", 821 | " return softmax_layer\n", 822 | " \n", 823 | " #计算卷积或池化后的H和W\n", 824 | " def cacl_out_hw(self,HW,f,stride = 1,pad = 0):\n", 825 | " return (HW+2*pad - f)/stride+1\n", 826 | " \n", 827 | "\n", 828 | " \n", 829 | " \n", 830 | " def init_model(self,train_X,n_classes):\n", 831 | " \"\"\"\n", 832 | " 初始化一个卷积层网络\n", 833 | " \"\"\"\n", 834 | " N,C,H,W = train_X.shape\n", 835 | " #卷积层\n", 836 | " n_filter = 4\n", 837 | " f = 7\n", 838 | " \n", 839 | " conv_layer = self.add_conv_layer(n_filter= n_filter,n_c=C,f=f,stride=1)\n", 840 | " \n", 841 | " out_h = self.cacl_out_hw(H,f)\n", 842 | " out_w = self.cacl_out_hw(W,f)\n", 843 | " out_ch = n_filter\n", 844 | " \n", 845 | " self.layers.append(conv_layer)\n", 846 | " \n", 847 | " #Relu\n", 848 | " relu_layer = self.add_relu()\n", 849 | " self.layers.append(relu_layer)\n", 850 | " \n", 851 | " #池化\n", 852 | " f = 2\n", 853 | " pool_layer = self.add_maxpool_layer(pool_shape=(f,f),stride=2)\n", 854 | " out_h = self.cacl_out_hw(out_h,f,stride=2)\n", 855 | " out_w = self.cacl_out_hw(out_w,f,stride=2)\n", 856 | " #out_ch 不改变\n", 857 | " self.layers.append(pool_layer)\n", 858 | " \n", 859 | "\n", 860 | " \n", 861 | " \n", 862 | " #Affine层\n", 863 | " n_x = int(out_h*out_w*out_ch)\n", 864 | " n_units = 32\n", 865 | " fc_layer = self.add_affine(n_x=n_x,n_units=n_units)\n", 866 | " self.layers.append(fc_layer)\n", 867 | " \n", 868 | " #Relu\n", 869 | " relu_layer = self.add_relu()\n", 870 | " self.layers.append(relu_layer)\n", 871 | " \n", 872 | " #Affine\n", 873 | " fc_layer = self.add_affine(n_x=n_units,n_units=n_classes)\n", 874 | " self.layers.append(fc_layer)\n", 875 | " \n", 876 | " #SoftMax\n", 877 | " softmax_layer = self.add_softmax()\n", 878 | " self.layers.append(softmax_layer)\n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " def forward_progation(self,train_X, print_out = False):\n", 883 | " \"\"\"\n", 884 | " 前向传播\n", 885 | " Arguments:\n", 886 | " train_X -- 训练数据\n", 887 | " f -- 滤波器大小\n", 888 | "\n", 889 | " Return :\n", 890 | " Z-- 前向传播的结果\n", 891 | " loss -- 损失值\n", 892 | " \"\"\"\n", 893 | " \n", 894 | " \n", 895 | " N,C,H,W = train_X.shape\n", 896 | " index = 0\n", 897 | " # 卷积层\n", 898 | " conv_layer = self.layers[index]\n", 899 | " X = conv_layer.forward(train_X)\n", 900 | " index =index+1\n", 901 | " if print_out:\n", 902 | " print(\"卷积之后:\"+str(X.shape))\n", 903 | " # Relu\n", 904 | " relu_layer = self.layers[index]\n", 905 | " index =index+1\n", 906 | " X = relu_layer.forward(X)\n", 907 | " if print_out:\n", 908 | " print(\"Relu:\"+str(X.shape))\n", 909 | " \n", 910 | " \n", 911 | " # 池化层\n", 912 | " pool_layer = self.layers[index]\n", 913 | " index =index+1\n", 914 | " X = pool_layer.forward(X)\n", 915 | " if print_out:\n", 916 | " print(\"池化:\"+str(X.shape))\n", 917 | "\n", 918 | "\n", 919 | " #Affine层\n", 920 | " fc_layer = self.layers[index]\n", 921 | " index =index+1\n", 922 | " X = fc_layer.forward(X)\n", 923 | " if print_out:\n", 924 | " print(\"Affline 层的X:\"+str(X.shape))\n", 925 | "\n", 926 | " #Relu\n", 927 | " relu_layer = self.layers[index]\n", 928 | " index =index+1\n", 929 | " X = relu_layer.forward(X)\n", 930 | " if print_out:\n", 931 | " print(\"Relu 层的X:\"+str(X.shape))\n", 932 | " \n", 933 | " #Affine层\n", 934 | " fc_layer = self.layers[index]\n", 935 | " index =index+1\n", 936 | " X = fc_layer.forward(X)\n", 937 | " if print_out:\n", 938 | " print(\"Affline 层的X:\"+str(X.shape))\n", 939 | "\n", 940 | " #SoftMax层\n", 941 | " sofmax_layer = self.layers[index]\n", 942 | " index =index+1\n", 943 | " A = sofmax_layer.forward(X)\n", 944 | " if print_out:\n", 945 | " print(\"Softmax 层的X:\"+str(A.shape))\n", 946 | " \n", 947 | " return A\n", 948 | " \n", 949 | " def back_progation(self,train_y,learning_rate):\n", 950 | " \"\"\"\n", 951 | " 反向传播\n", 952 | " Arguments:\n", 953 | " \n", 954 | " \"\"\"\n", 955 | " index = len(self.layers)-1\n", 956 | " sofmax_layer = self.layers[index]\n", 957 | " index -= 1\n", 958 | " dz = sofmax_layer.backward(train_y)\n", 959 | " \n", 960 | " fc_layer = self.layers[index]\n", 961 | " dz = fc_layer.backward(dz,learning_rate=learning_rate)\n", 962 | " index -= 1\n", 963 | " \n", 964 | " relu_layer = self.layers[index]\n", 965 | " dz = relu_layer.backward(dz)\n", 966 | " index -= 1\n", 967 | " \n", 968 | " fc_layer = self.layers[index]\n", 969 | " dz = fc_layer.backward(dz,learning_rate=learning_rate)\n", 970 | " index -= 1\n", 971 | " \n", 972 | " pool_layer = self.layers[index]\n", 973 | " dz = pool_layer.backward(dz)\n", 974 | " index -= 1\n", 975 | " \n", 976 | " relu_layer = self.layers[index]\n", 977 | " dz = relu_layer.backward(dz)\n", 978 | " index -= 1\n", 979 | " \n", 980 | " conv_layer = self.layers[index]\n", 981 | " conv_layer.backward(dz,learning_rate=learning_rate)\n", 982 | " index -= 1\n", 983 | " \n", 984 | " \n", 985 | " def get_minibatch(self,batch_data,minibatch_size,num):\n", 986 | " m_examples = batch_data.shape[0]\n", 987 | " minibatches = math.ceil( m_examples / minibatch_size)\n", 988 | " \n", 989 | " if(num < minibatches):\n", 990 | " return batch_data[num*minibatch_size:(num+1)*minibatch_size]\n", 991 | " else:\n", 992 | " return batch_data[num*minibatch_size:m_examples]\n", 993 | " \n", 994 | " \n", 995 | " def optimize(self,train_X, train_y,minibatch_size,learning_rate=0.05,num_iters=500):\n", 996 | " \"\"\"\n", 997 | " 优化方法\n", 998 | " Arguments:\n", 999 | " train_X -- 训练数据 \n", 1000 | " train_y -- 训练数据的标签\n", 1001 | " learning_rate -- 学习率\n", 1002 | " num_iters -- 迭代次数\n", 1003 | " minibatch_size \n", 1004 | " \"\"\"\n", 1005 | " m = train_X.shape[0]\n", 1006 | " num_batches = math.ceil(m / minibatch_size)\n", 1007 | " \n", 1008 | " costs = []\n", 1009 | " for iteration in range(num_iters):\n", 1010 | " iter_cost = 0\n", 1011 | " for batch_num in range(num_batches):\n", 1012 | " minibatch_X = self.get_minibatch(train_X,minibatch_size,batch_num)\n", 1013 | " minibatch_y = self.get_minibatch(train_y,minibatch_size,batch_num)\n", 1014 | " \n", 1015 | " # 前向传播\n", 1016 | " A = self.forward_progation(minibatch_X,print_out=False)\n", 1017 | " #损失:\n", 1018 | " cost = compute_cost (A,minibatch_y)\n", 1019 | " #反向传播\n", 1020 | " self.back_progation(minibatch_y,learning_rate)\n", 1021 | " if(iteration%20 == 0):\n", 1022 | " iter_cost += cost/num_batches\n", 1023 | " \n", 1024 | " if(iteration%20 == 0):\n", 1025 | " print(\"After %d iters ,cost is :%g\" %(iteration,iter_cost))\n", 1026 | " costs.append(iter_cost)\n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " #画出损失函数图\n", 1032 | " plt.plot(costs)\n", 1033 | " plt.xlabel(\"iterations/hundreds\")\n", 1034 | " plt.ylabel(\"costs\")\n", 1035 | " plt.show()\n", 1036 | " \n", 1037 | " \n", 1038 | " def predicate(self, train_X):\n", 1039 | " \"\"\"\n", 1040 | " 预测\n", 1041 | " \"\"\"\n", 1042 | " logits = self.forward_progation(train_X)\n", 1043 | " one_hot = np.zeros_like(logits)\n", 1044 | " one_hot[range(train_X.shape[0]),np.argmax(logits,axis=1)] = 1\n", 1045 | " return one_hot \n", 1046 | "\n", 1047 | " def fit(self,train_X, train_y):\n", 1048 | " \"\"\"\n", 1049 | " 训练\n", 1050 | " \"\"\"\n", 1051 | " self.X = train_X\n", 1052 | " self.Y = train_y\n", 1053 | " n_y = train_y.shape[1]\n", 1054 | " m = train_X.shape[0]\n", 1055 | " \n", 1056 | " #初始化模型\n", 1057 | " self.init_model(train_X,n_classes=n_y)\n", 1058 | "\n", 1059 | " self.optimize(train_X, train_y,minibatch_size=1024,learning_rate=0.08,num_iters=500)\n", 1060 | " \n", 1061 | " logits = self.predicate(train_X)\n", 1062 | " \n", 1063 | " accuracy = np.sum(np.argmax(logits,axis=1) == np.argmax(train_y,axis=1))/m\n", 1064 | " print(\"训练集的准确率为:%g\" %(accuracy))\n", 1065 | "\n", 1066 | "\n", 1067 | " \n", 1068 | " \n", 1069 | " " 1070 | ] 1071 | }, 1072 | { 1073 | "cell_type": "code", 1074 | "execution_count": 965, 1075 | "metadata": {}, 1076 | "outputs": [ 1077 | { 1078 | "name": "stdout", 1079 | "output_type": "stream", 1080 | "text": [ 1081 | "After 0 iters ,cost is :230.258\n", 1082 | "After 20 iters ,cost is :229.224\n", 1083 | "After 40 iters ,cost is :228.492\n", 1084 | "After 60 iters ,cost is :227.945\n", 1085 | "After 80 iters ,cost is :227.538\n", 1086 | "After 100 iters ,cost is :227.236\n", 1087 | "After 120 iters ,cost is :227.006\n", 1088 | "After 140 iters ,cost is :226.829\n", 1089 | "After 160 iters ,cost is :226.689\n", 1090 | "After 180 iters ,cost is :226.577\n", 1091 | "After 200 iters ,cost is :226.476\n", 1092 | "After 220 iters ,cost is :226.387\n", 1093 | "After 240 iters ,cost is :226.313\n", 1094 | "After 260 iters ,cost is :226.245\n", 1095 | "After 280 iters ,cost is :226.202\n", 1096 | "After 300 iters ,cost is :226.151\n", 1097 | "After 320 iters ,cost is :226.092\n", 1098 | "After 340 iters ,cost is :226.022\n", 1099 | "After 360 iters ,cost is :225.954\n", 1100 | "After 380 iters ,cost is :225.859\n", 1101 | "After 400 iters ,cost is :225.758\n", 1102 | "After 420 iters ,cost is :225.627\n", 1103 | "After 440 iters ,cost is :225.546\n", 1104 | "After 460 iters ,cost is :225.349\n", 1105 | "After 480 iters ,cost is :225.117\n" 1106 | ] 1107 | }, 1108 | { 1109 | "data": { 1110 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEKCAYAAADn+anLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xt8VeWd7/HPLzcSQhIScoUEkoggAQEBQQWUKiq2Wp3WXq3jpdXa01qdaU9nOnYcp2fmnLZO7WXaaavTWp1S7UWtFi+Ad1ABA3IPIPeLIQnEQCBASPI7f+wFbpFAkOysJPv7fr32Kytrr7X2b7nbfHmeZ61nmbsjIiJyPAlhFyAiIt2XQkJERNqlkBARkXYpJEREpF0KCRERaZdCQkRE2qWQEBGRdikkRESkXQoJERFpV1LYBZyO3NxcLy0tDbsMEZEeZfHixbvcPa8j2/bokCgtLaWysjLsMkREehQz29LRbdXdJCIi7VJIiIhIuxQSIiLSLoWEiIi0SyEhIiLtUkiIiEi7FBIiItKuuAyJdxoO8L1n11Cz92DYpYiIdGtxGRL7D7Xwy1c2MGfVzrBLERHp1uIyJIbm96M8L53Zq2rCLkVEpFuLy5AwMy4fWcgbG3fT0NQcdjkiIt1WXIYEwIyRhbS2Oc9X1YZdiohItxW3ITG6OIuirFRma1xCRKRdcRsSR7qcXl1Xx/5DLWGXIyLSLcVtSABcPrKQQy1tvLKuLuxSRES6pZiFhJmVmNlLZrbazFaZ2R3B+v9jZsvNbKmZzTGzgcF6M7Ofmtn64P1xsartiHNLs8lJT1GXk4hIO2LZkmgBvuHuFcB5wFfNrAK4191Hu/tYYBZwd7D9FcCZwetW4BcxrA2ApMQEpo/I58WqWppb2mL9cSIiPU7MQsLdq919SbDcCFQBg9x9b9Rm6YAHy1cDD3vEAqC/mRXFqr4jZowqpPFQC69v2BXrjxIR6XG6ZEzCzEqBc4CFwe//bmbbgOt4ryUxCNgWtdv2YN2xx7rVzCrNrLKu7vTHEi44I5d+fZLU5SQichwxDwkz6wc8Btx5pBXh7ne5ewkwE/jaqRzP3e939wnuPiEvr0PP8T6h1OREpg3PY86qGlrb/OQ7iIjEkZiGhJklEwmIme7++HE2mQl8MljeAZREvVccrIu5GaMK2b2/mcVb3u2KjxMR6TFieXWTAb8Gqtz9vqj1Z0ZtdjWwJlh+Cvjb4Cqn84A97l4dq/qiTRueT0pSAs+tVJeTiEi0WLYkJgPXAxcHl7suNbOPAt8zs5Vmthy4DLgj2P4ZYCOwHngA+F8xrO19+vVJYurQXGav2om7upxERI5IitWB3X0+YMd565l2tnfgq7Gq52QuH1XIC2tqWfXOXkYNygqrDBGRbiWu77iONn1EAYkJpi4nEZEoColATnoKE0tzeE6XwoqIHKWQiDJjVCHra/exvnZf2KWIiHQLCokol40sANCNdSIiAYVElKKsNMaU9FdIiIgEFBLHmDGykOXb97Cj4UDYpYiIhE4hcYzLgy6nOWpNiIgoJI5VnteP4QUZuhRWRASFxHFdPrKANzfXs3vfobBLEREJlULiOC4fVUibw/NVNWGXIiISKoXEcVQUZVKSk6YuJxGJewqJ4zAzLq8o5LX1u2k8eDjsckREQqOQaMeMUYU0t7bx0trTf/qdiEhPpZBox7jB2eRl9GG2upxEJI4pJNqRkGBcWlHAS2trOXi4NexyRERCoZA4gRkjC2lqbmX+27vCLkVEJBQKiRM4r3wAmalJmj5cROKWQuIEUpISuGREAc9X1dDS2hZ2OSIiXU4hcRKXjyykoekwizbVh12KiEiXU0icxEXD8khNTlCXk4jEJYXESaSlJHLRsDzmrKqhrc3DLkdEpEspJDpgxqhCdu49yLLtDWGXIiLSpRQSHXDxWQUkJZi6nEQk7igkOiArLZnzzxjA7JU7cVeXk4jED4VEB10zdhCbdzcxTzfWiUgcUUh00FVjBpKf0YcH5m0MuxQRkS6jkOiglKQEbpxcyry3d1FVvTfsckREuoRC4hRcN3EIfVMS+e95m8IuRUSkSygkTkFW32Q+PaGEp5btoGbvwbDLERGJuZiFhJmVmNlLZrbazFaZ2R3B+nvNbI2ZLTezJ8ysf7A+xcweNLMVZrbMzKbFqrbTcfPkMlrbnN++vjnsUkREYi6WLYkW4BvuXgGcB3zVzCqAucAodx8NrAO+HWx/C4C7nw1cCvzQzLpdS2fwgL7MGFXIzAVb2H+oJexyRERiKmZ/hN292t2XBMuNQBUwyN3nuPuRv64LgOJguQJ4Mdi+FmgAJsSqvtPxpanl7D3Ywh8rt4VdiohITHXJv9TNrBQ4B1h4zFs3A88Gy8uAj5tZkpmVAeOBkq6o71SNG5zNhCHZ/Oa1TZpCXER6tZiHhJn1Ax4D7nT3vVHr7yLSJTUzWPUbYDtQCfwYeB34wHNDzexWM6s0s8q6urpYl9+uL00tZ1v9AWavqgmtBhGRWItpSJhZMpGAmOnuj0etvxG4ErjOg3ku3L3F3f/O3ce6+9VAfyJjFu/j7ve7+wR3n5CXlxfL8k/o0ooChgzoywPzNmqqDhHptWJ5dZMBvwaq3P2+qPUzgG8BH3f3pqj1fc0sPVi+FGhx99Wxqu90JSYYX5pSxtJtDSze8m7Y5YiIxEQsWxKTgeuBi81safD6KPAzIAOYG6z7ZbB9PrDEzKqAfwj27dauHV9C/77J3P+qpuoQkd4pKVYHdvf5gB3nrWfa2X4zMDxW9cRCWkoiX5g0hJ+/vJ5Nu/ZTlpsedkkiIp2q292H0NP87QVDSE5I4DfzNVWHiPQ+ConTlJ+RyjXnDORPi7fx7v7msMsREelUColO8KWp5Rw83MbvFmwJuxQRkU6lkOgEwwoymDY8j4fe2MzBwx+4tUNEpMdSSHSSW6aWs2tfM08u3RF2KSIinUYh0UkuOGMAI4oyeWDeJtradHOdiPQOColOYmbcMrWM9bX7eGVdeNOFiIh0JoVEJ7py9EAKM1P1HGwR6TUUEp3oyHOwX9+wm5U79oRdjojIaVNIdLLPTRxMekoi/63WhIj0AgqJTpaVlsxnzh3MrOXVvNNwIOxyREROi0IiBm6aXEqb6znYItLzKSRioCSnL1ecXcQjC7fSePBw2OWIiHxoCokYuXVqOY2HWnh0kZ6DLSI9l0IiRsaU9OeCMwbw85fXU6+J/0Skh1JIxNC/XDWSxoMt3Dt7TdiliIh8KAqJGBpemMFNF5Ty6JvbeGurHnEqIj2PQiLG7rx0GPkZffjnJ1fSqjmdRKSHUUjEWL8+Sdz1sQpW7tjL7xdtDbscEZFTopDoAleNLuL88gHc+9wadu87FHY5IiIdppDoAmbGd68eSVNzK99/ToPYItJzKCS6yJkFGXxxShl/rNzO4i0axBaRnkEh0YVuv+RMCjNT+ee/aBBbRHoGhUQX6tcnie9cOYLV1XuZuXBL2OWIiJyUQqKLfezsIqYMzeXe2WvZpUFsEenmFBJdzMy45+MjOXi4lf/3jAaxRaR7U0iEYGh+P740tZzHlmznzc31YZcjItIuhURIbr94KAOzIoPYLa1tYZcjInJcComQ9E1J4p+vrGDNzkb+Z4EGsUWke4pZSJhZiZm9ZGarzWyVmd0RrL/XzNaY2XIze8LM+gfrk83sITNbYWZVZvbtWNXWXcwYVcjUM3O5b846ahsPhl2OiMgHxLIl0QJ8w90rgPOAr5pZBTAXGOXuo4F1wJEw+BTQx93PBsYDXzaz0hjWFzoz418/PpJDLW18T4PYItINxSwk3L3a3ZcEy41AFTDI3ee4e0uw2QKg+MguQLqZJQFpQDOwN1b1dRflef249cJyHn9rBws37g67HBGR9+mSMYmgRXAOsPCYt24Gng2W/wzsB6qBrcB/uHtcXPrz1Y8MZVD/NO5+chWHNYgtIt1IzEPCzPoBjwF3uvveqPV3EemSmhmsmgi0AgOBMuAbZlZ+nOPdamaVZlZZV1cX6/K7RFpKIndfVcHamkYefkOD2CLSfXQoJMws3cwSguVhZvZxM0vuwH7JRAJiprs/HrX+RuBK4Dp3PzKJ0eeB59z9sLvXAq8BE449prvf7+4T3H1CXl5eR8rvES6rKGDa8Dx+NHcdO/doEFtEuoeOtiReBVLNbBAwB7ge+O2JdjAzA34NVLn7fVHrZwDfAj7u7k1Ru2wFLg62SScy2B03o7lmxj1XjaS1zflfMxdzqKU17JJERDocEhb8Qf8E8F/u/ilg5En2mUwkTC42s6XB66PAz4AMYG6w7pfB9j8H+pnZKuBN4EF3X36qJ9STleamc9+nx7BkawN3/2UV7zWyRETCkdTB7czMzgeuA74YrEs80Q7uPh+w47z1TDvb7yNyGWxcu+LsIr5+8VB++uJ6KgZmcsMFpWGXJCJxrKMtiTuI3M/whLuvCgaUX4pdWfHtzunDmD6igO/OWs3rG3aFXY6IxLGOhkSBu3/c3b8P4O4bgXmxKyu+JSQYP/rMGMpz0/nqzCVsq286+U4iIjHQ0ZA43hQZvX7ajDBlpCbzwN9OoLXNueXhSvYfajn5TiIineyEIWFmV5jZfwKDzOynUa/fErnHQWKoNDed//z8ONbVNPLNPy3TQLaIdLmTtSTeASqBg8DiqNdTwOWxLU0ALhqWx7evGMGzK3fysxfXh12OiMSZE17d5O7LgGVm9nt3PwxgZtlAibu/2xUFCnxpahmrq/fyw7nrGF6YwWUjC8MuSUTiREfHJOaaWaaZ5QBLgAfM7EcxrEuimBn/7xNnM7o4i7/7w1LW1TSGXZKIxImOhkRWMO/SJ4CH3X0ScEnsypJjpSYn8qvrx5OWksQtD1fS0NQcdkkiEgc6GhJJZlYEfBqYFcN65ASKstL41fXjeKfhALc/8pYeeyoiMdfRkPguMBvY4O5vBjfTvR27sqQ944fk8G/XjGLe27v4/nNxM7WViISkQ9NyuPufgD9F/b4R+GSsipIT+8y5g1n9zl4emLeJEUWZfGJc8cl3EhH5EDo6VXhx8Dzq2uD1mJnpL1OIvnNlBeeV5/CPj69g2baGsMsRkV6qo91NDxK5N2Jg8PprsE5CkpyYwH9dN568fn245eFK1tfqiicR6XwdDYk8d3/Q3VuC12+B3vPEnx4qJz2FB286lzaHT/3yDZaqRSEinayjIbHbzL5gZonB6wvA7lgWJh0zrCCDx75yPhmpyXz+gQW8uq53PNJVRLqHjobEzUQuf90JVAPXAjfGqCY5RUMGpPPnr5zPkAHpfPGhN/nrsnfCLklEeolTuQT2BnfPc/d8IqHxr7ErS05VfkYqj956HucMzubrj77Fw29sDrskEekFOhoSo6PnanL3euCc2JQkH1ZWWjIP3zyRS84q4O4nV/Gjues0c6yInJaOhkRCMLEfAMEcTh199Kl0odTkRH75hXF8anwxP3nhbe5+chWtbQoKEflwOvqH/ofAG2Z25Ia6TwH/HpuS5HQlJSbwg2tHk9MvhV+9spH6pmbu+/QY+iSd8LHkIiIf0NE7rh82s0rg4mDVJ9x9dezKktNlZnz7ihEMSE/h/z6zhj1Nh/nV9eNJ76MGoIh0XIf/YgShoGDoYW698Ayy+6bwj4+v4PMPLODBmyaSk54Sdlki0kN0dExCerBPTSjhV18Yz5qdjVz7y9fZ0XAg7JJEpIdQSMSJ6RUF/M8XJ1HXeIhrf/G6HlwkIh2ikIgjE8ty+OOXz6elzbn6Z6/xuwVbdImsiJyQQiLOjCjKZNbtU5hQms13/rKSLz5USW3jwbDLEpFuSiERhwoyU3noponcc1UFr63fxYwfz2P2qp1hlyUi3ZBCIk4lJBg3Ti7j6a9PYWD/VL78P4v51p+Xse9QS9iliUg3opCIc0PzM3j8K5P52keG8ufF27niJ69Subk+7LJEpJtQSAgpSQl88/Lh/PHL5wPw6V+9wb2z19Dc0hZyZSIStpiFhJmVmNlLZrbazFaZ2R3B+nvNbI2ZLQ8eido/WH+dmS2NerWZ2dhY1ScfNKE0h2fvuJBrxxfz85c28IlfvKYn3onEOYvVJZBmVgQUufsSM8sAFgPXAMXAi+7eYmbfB3D3fzhm37OBv7j7GSf6jAkTJnhlZWVM6o93s1ft5NuPr2D/oRa+fcVZ/O35pSQkWNhliUgnMLPF7j6hI9vGrCXh7tXuviRYbgSqgEHuPsfdj4yOLiASGsf6HPBorGqTk7t8ZCHP3TmVC84YwD1/Xc0NDy5i6+6msMsSkS7WJWMSZlZK5PkTC49562bg2ePs8hngkdhWJSeTn5HKb248l3//m1FUbn6XS+57mXueWsWufYfCLk1EukjMQ8LM+gGPAXe6+96o9XcBLcDMY7afBDS5+8p2jnermVWaWWVdnZ7nHGtmxnWThvDy/57GteNL+J8FW7joBy/xo7nrdLmsSByI2ZgEgJklA7OA2e5+X9T6G4EvA5e4e9Mx+/wIqHP3/3uy42tMouttqNvHD+es5ZkVOxmQnsLtFw/l85OGkJKkC+VEeopTGZOI5cC1AQ8B9e5+Z9T6GcB9wEXuXnfMPgnANmCqu2882WcoJMKzdFsD3392DW9s3E1JThrfvGw4V40eqMFtkR6gWwxcA5OB64GLoy5r/SjwMyADmBus+2XUPhcC2zoSEBKusSX9+f0tk3jo5olk9EnmjkeXcuV/zufltbWaNFCkF4lpd1OsqSXRPbS1OX9d/g7/MWct2+oPcF55Dv94xQjGlvQPuzQROY5u0d3UFRQS3UtzSxuPLNrKT194m937m7l8ZAG3TC1n/JBsIr2PItIdKCQkVPsOtfDf8zby6/mbaDzYwsiBmdx4QSlXjRlIanJi2OWJxD2FhHQLTc0tPPHWDh56fTPravaRk57CZ88t4QvnDWFg/7SwyxOJWwoJ6VbcnTc27Oa3r2/m+aoazIzLKgq44YJSJpXlqCtKpIudSkgkxboYETPjgqG5XDA0l231Tfxu4Rb+8OY2nl25k7MKM7jhglKuGTuItBR1RYl0N2pJSCgONLfy1LIdPPjaZtbsbCQrLZnPnlvC5ycNZsiA9LDLE+nV1N0kPYa7s2hTPQ+9sZnZq2pobXPOHpTFlaOL+NjoIoqz+4Zdokivo5CQHql6zwH+uuwdZi2vZvn2PQCcM7g/Hzs7EhhFWRrsFukMCgnp8bbubmLWineYtaya1dWReSHPLc3mytEDueLsQvIzUkOuUKTnUkhIr7Kxbh+zllfz9PJq1tY0YgaTynIigTGqkAH9+oRdokiPopCQXmtdTSOzllcza/k7bKzbT2KCcW5pNtNHFHBpRYEGvUU6QCEhvZ67U1XdyNMr3uH51bWsrYk8i3tYQT+mjyhgekUBY4v7a1ZakeNQSEjc2bq7ieerapi7uoZFm+tpbXNy+/Vh+oh8po8oYMqZuZoSRCSgkJC4tqfpMC+vq2Xu6hpeWVtH46EWUpMTmHpmHpeOKOAjZ+WTl6FxDIlfuuNa4lpW32SuHjuIq8cOormljYWbdvP86hqer4oEB0BFUSZTz8xlypm5nFuao1aGSDvUkpC4cWQc48U1Ncxfv4vFW97lcKuTkpTAuaXZTB6ay9SheYwcmKmxDOnV1N0k0gFNzS0s2lTP/Ld3MX/9LtbsjAx+9++bzOQzIq2MKUNzKcnRXd/Su6i7SaQD+qYkMW14PtOG5wNQ23iQ19fvZt7bu5i/vo6nV1QDMGRAX84vH8D4IdmMH5JNWW66Zq6VuKGWhMhxuDsb6vYdbWUs2lTP3oMtAGT3TWb8kGzGDclm3OBsxhT31wy20qOoJSFymsyMofkZDM3P4MbJZbS1RUJj8ZZ3I6+t7/J8VS0ASQlGxcBMxg3OPtra0EOVpLdQS0LkQ6rf38xbW989GhzLtjdw8HAbAEVZqYwt6c/o4v6MLs7i7OIsMlOTQ65YJEItCZEukJOewiUjCrhkRAEAh1vbqKreezQ0lm/fw7Mrdx7dvjw3ndHFWYwu7s+YkiwqirLUTSXdnloSIjH07v5mlu/Yw/JtDZGf2xuo2XsIgMQE48z8fowp7s/okixGDcyiNDedrDS1OCS2dAmsSDdWs/cgy7Y1sGLHHpZtjwRHQ9Pho+9npiYxeEBfSrL7MjinL8U5kZ8l2WkMyk6jT5JaH3J61N0k0o0VZKZy2chCLhtZCESupNpWf4DV1XvYWt/EtvoDbK1vYm1NIy9U1dLc2nZ0XzMozEylJCcSIkMG9KU8L53y3H6U5aar+0o6nUJCJGRmxuABfRk84IM37bW1ObWNh4LwaIr8fDeyPH99HY8tOfS+7Qf1TwtCI53yvH6R5bx+FGWm6i5y+VAUEiLdWEKCUZiVSmFWKhPLcj7wflNzC5t27WdjXfDatY+Ndfv58+Lt7G9uPbpdWnIiZbnplOelM6wggxFFmZxVmEFxdppuDJQTUkiI9GB9U5IYOTCLkQOz3rfePdIC2VC3LypE9rF8+x6eXlHNkaHIjD5JnFWUwVmFme/9LMwgvY/+NEiE/pcg0guZGQWZqRRkpnLBGbnve2//oRbW1jSyprqRNTv3UlW9l7+8tYPGBS1HtxkyoC9nFR5pcWQypiSLwsxUtTriUMxCwsxKgIeBAsCB+939J2Z2L3AV0AxsAG5y94Zgn9HAr4BMoA04190PxqpGkXiU3ieJcYMjU4oc4e7saDhAVXUja6r3UrVzL2uqG5mzuuZoqyM/ow+ji/sztiSLMSX9GT2oP1l9dblubxezS2DNrAgocvclZpYBLAauAYqBF929xcy+D+Du/2BmScAS4Hp3X2ZmA4AGd29t7zN0CaxIbB1obqVq597IfR7b97B0ewMb6/Yffb8suEFwTHF/xpT0Z+TATD2bowfoFpfAuns1UB0sN5pZFTDI3edEbbYAuDZYvgxY7u7Lgn12x6o2EemYtJTED7Q69hw4zMode1i6rYFl2xpYuLGeJ5e+A0TmsRpemHG0xTG6uD9n5vcjKTEhrFOQ09QlYxJmVgqcAyw85q2bgT8Ey8MAN7PZQB7wqLv/oCvqE5GOy0pLZvLQXCYPfW+sY+eegyzb3sDy7Q0s3dbArOXv8MiirUDkyqqRAzOPTkcyprg/Qwb01fhGDxHzkDCzfsBjwJ3uvjdq/V1ACzAzqpYpwLlAE/BC0CR64Zjj3QrcCjB48OBYly8iHRC5TLeQy4MbBNvanM2797N8+54gPPYwc+EWfvNa5MbArLTkYB6r97qqCjJTwzwFaUdMp+Uws2RgFjDb3e+LWn8j8GXgEndvCtZ9FrjC3W8Ifv9n4KC739ve8TUmIdJzHG5tY11NI8uDqUiWbdvD2ppGWtsif4PKctO5aFgeHzkrn0lleu54LHWLuZss0pZ8CKh39zuj1s8A7gMucve6qPXZwAtEWhPNwHPAj9z96fY+QyEh0rMdaG5ldfUelm7bw/y363h9w24OtbSRlpzIBWcMYNpZ+UwblqdHyHay7hISU4B5wAoil7MC/BPwU6APcGRgeoG73xbs8wXg20QumX3G3b91os9QSIj0LgcPt/LGxt28vKaWl9bWsbW+CYCh+f34yPA8PjI8nwmlOaQkaSD8dHSLkOgKCgmR3svd2bRrPy+trePltbUs3FhPc2sb6SmJTB6ay0fOyufis/I1lvEhdItLYEVEToeZBZMU9uOLU8rYf6iFNzbs5qW1tby8to45q2sAGFOcxfQRBUyvKOCswgxdNdXJ1JIQkR7H3VlXs4/nq2p4vqqGt7Y2AJFZcC+tKGD6iAImlqlbqj3qbhKRuFLbeJCX1tQyd3Ut89fXcfBwGxl9krhoeB6XVhQwbVi+phCJopAQkbh1oLmV19bvCloZtezad4jEBGNiaQ4zRhVyzdhBcR8YCgkRESI39S3b3sDzVTXMXV3Dupp99ElK4GNnF/HZiYM5tzQ7LscwFBIiIsex6p09PLpoW2Rq9EMtnJGXzucmDuaT44rJTk8Ju7wuo5AQETmBpuYWZi2v5pFFW3lrawMpiQlccXYhn5s4mEllOb2+daGQEBHpoKrqvTy6aCuPv7WDxoMtlOcGrYvxxeT00taFQkJE5BQdaG7l6RWR1sXiLe+SkpjAZSMLuHJ0ERPLBvSqwFBIiIichnU1jTyyaCuPL9nBngOHATgzvx+TynOYWDaA88pyyO/Bd3orJEREOkFzSxsrdjSwYGM9izbVU7m5nv3NkYdlluWmM7E0h0nlOUwqH8Cg/mkhV9txCgkRkRhoaW1jdfVeFm6sZ+Gm3SzaVM/egy1A5G7vSeU5nFc+gKtGDyQtpftOda6QEBHpAm1tzpqdjUcDY9Gmenbvb6Y4O427r6zg0oqCbnmllEJCRCQE7s7rG3Zzz1OreLt2H9OG53HPVSMpzU0Pu7T3OZWQ0OxXIiKdxMyYPDSXZ+6Yync+NoLKze9y2Y9e5Ydz1nIgGMvoaRQSIiKdLDkxgS9NLefFb1zER88u5D9fXM/0+17huZU76Wm9NwoJEZEYyc9M5cefPYc/3HoeGalJ3Pa7xdzw4JtsrNsXdmkdppAQEYmxSeUDmHX7FO6+soK3trzLjB/P4wfPraGpuSXs0k5KISEi0gWSEhO4eUoZL3zzIq4cU8R/vbyB6T98hWdXVHfrLiiFhIhIF8rPSOW+T4/lT7edT2ZaMl+ZuYQbHnyT6j0Hwi7tuBQSIiIhOLc0h1m3T+FfrqrgzU31zPjxPJ5eXh12WR+gkBARCUlSYgI3TS7jmTumUpqbzld/v4Rv/HEZjQcPh13aUQoJEZGQleWm8+fbzufrl5zJE29t54qfzOPNzfVhlwUoJEREuoXkxAT+/tJh/Om2C0gw4zO/eoP/mL2Ww61todalkBAR6UbGD8nmmTum8slxxfzspfV88hevh3pfhUJCRKSb6dcniXs/NYZfXDeOrfVNfOyn85m5cEsol8oqJEREuqkrzi5i9p0XMqE0m7ueWMktD1eya9+hLq1BISEi0o0VZKby0E0TufvKCl59exczfvwqL66p6bLPV0iIiHRzCQnGzVPK+OvXppDbrw83/7aSf5u1ums+u0s+RURETtvwwgye/Npkbr2wnCFd9IyKpFgd2MxKgIdGt92PAAAIeUlEQVSBAsCB+939J2Z2L3AV0AxsAG5y9wYzKwWqgLXBIRa4+22xqk9EpCfqk5TIP310RJd9XixbEi3AN9y9AjgP+KqZVQBzgVHuPhpYB3w7ap8N7j42eCkgRERCFrOQcPdqd18SLDcSaSUMcvc57n5kftwFQHGsahARkdPTJWMSQVfSOcDCY966GXg26vcyM3vLzF4xs6ntHOtWM6s0s8q6urqY1CsiIhExDwkz6wc8Btzp7nuj1t9FpEtqZrCqGhjs7ucAfw/83swyjz2eu9/v7hPcfUJeXl6syxcRiWsxDQkzSyYSEDPd/fGo9TcCVwLXeXALobsfcvfdwfJiIoPaw2JZn4iInFjMQsLMDPg1UOXu90WtnwF8C/i4uzdFrc8zs8RguRw4E9gYq/pEROTkYnYJLDAZuB5YYWZLg3X/BPwU6APMjeTI0UtdLwS+a2aHgTbgNnfvHnPliojEqZiFhLvPB+w4bz3TzvaPEemaEhGRbsK68wO4T8bM6oAtp3GIXGBXJ5XT0+jc41c8n388nzu8d/5D3L1DV/706JA4XWZW6e4Twq4jDDr3+Dx3iO/zj+dzhw93/pq7SURE2qWQEBGRdsV7SNwfdgEh0rnHr3g+/3g+d/gQ5x/XYxIiInJi8d6SEBGRE4jLkDCzGWa21szWm9k/hl1PVzOzzWa2wsyWmlll2PXEkpn9xsxqzWxl1LocM5trZm8HP7PDrDGW2jn/e8xsR/D9LzWzj4ZZY6yYWYmZvWRmq81slZndEazv9d//Cc79lL/7uOtuCqb+WAdcCmwH3gQ+5+5d8yzAbsDMNgMT3L3XXy9uZhcC+4CH3X1UsO4HQL27fy/4R0K2u/9DmHXGSjvnfw+wz93/I8zaYs3MioAid19iZhnAYuAa4EZ6+fd/gnP/NKf43cdjS2IisN7dN7p7M/AocHXINUmMuPurwLHTu1wNPBQsP0Tk/zy9UjvnHxfae6YNcfD9n+DcT1k8hsQgYFvU79v5kP/xejAH5pjZYjO7NexiQlDg7tXB8k4ij9iNN18zs+VBd1Sv62451jHPtImr7/84z/M5pe8+HkNCYIq7jwOuIPJY2QvDLigswVT18dXnCr8AzgDGEnmOyw/DLSe22numDfT+7/84537K3308hsQOoCTq9+JgXdxw9x3Bz1rgCSJdcPGkJuizPdJ3WxtyPV3K3WvcvdXd24AH6MXffzvPtImL7/945/5hvvt4DIk3gTPNrMzMUoDPAk+FXFOXMbP0YCALM0sHLgNWnnivXucp4IZg+QbgyRBr6XJH/kAG/oZe+v2390wb4uD7P8HzfE75u4+7q5sAgsu+fgwkAr9x938PuaQuEzzQ6Yng1yTg9735/M3sEWAakdkva4B/Af4C/BEYTGQW4U/31meXtHP+04h0NziwGfhyVB99r2FmU4B5wAoiz6iByDNtFtLLv/8TnPvnOMXvPi5DQkREOiYeu5tERKSDFBIiItIuhYSIiLRLISEiIu1SSIiISLsUEtJtmdnrwc9SM/t8Jx/7n473WZ38GUVmNsfMppnZrM4+fvAZ95jZN0/zGJvNLLezapLeRSEh3Za7XxAslgKnFBJmlnSSTd4XElGf1ZlmALNjcNyT6sD5i3SIQkK6LTPbFyx+D5gazH//d2aWaGb3mtmbwURlXw62n2Zm88zsKWB1sO4vwUSGq45MZmhm3wPSguPNjP4si7jXzFYGz9z4TNSxXzazP5vZGjObGdzVipl9L5i3f7mZRU/BPAN4Nlju186+R/8Vb2YTzOzlYPmeYAK2l81so5l9Peq/y11mts7M5gPDo9a/bGY/tsgzQu4wszwzeyz47/SmmU0OthsQtHBWmdl/A0dqSTezp81sWXD+n+mEr1F6OnfXS69u+SIy7z1E7hCeFbX+VuA7wXIfoBIoC7bbD5RFbZsT/EwjMgXBgOhjH+ezPgnMJXI3fgGwFSgKjr2HyFxfCcAbwBRgALCW925M7R/8TASWRtX/gX2D9zYDucHyBODlYPke4PXg/HKB3UAyMJ7IXbR9gUxgPfDNYJ+Xgf+KOqffR33OYCJTNAD8FLg7WP4Ykbtvc4NzfyBq/6yw/zegV/gvNUmlJ7oMGG1m1wa/ZwFnAs3AInffFLXt183sb4LlkmC73Sc49hTgEXdvJTIR3CvAucDe4NjbAcxsKZFusAXAQeDXwbjDkbGHSbw3NTPt7Dv/JOf5tLsfAg6ZWS2R0JoKPOHuTcGxjp137A9Ry9OBiqDRApAZzAp6IfAJAHd/2szeDd5fAfzQzL5PJJTnnaQ+iQMKCemJDLjd3d/X329m04i0JKJ/nw6c7+5NQVdO6ml87qGo5VYgyd1bzGwicAlwLfA14GIi07A/d6J9g+UW3uv2Pba29vY5kf1RywnAee5+MHqDqNB4H3dfZ2bjgI8C/2ZmL7j7dzvwmdKLaUxCeoJGICPq99nAV4KpkDGzYcGMtsfKAt4NAuIs4Lyo9w4f2f8Y84DPBOMeeUT+1b2ovcKCf5lnufszwN8BY4K3LgGe78C5bSbShQSR7p6TeRW4xszSLDKb71Un2HYOcHtUrWOjjvH5YN0VQHawPBBocvffAfcC4zpQj/RyaklIT7AcaDWzZcBvgZ8Q6a5ZEgwA13H8R1A+B9xmZlVExg0WRL13P7DczJa4+3VR658AzgeWEemr/5a77wxC5ngygCfNLJVIC+fvg3A56JHHRp7MvxLpqvo/RMYUTsgjzyz+Q1BfLZGp79vzdeDnZracyP/XXwVuCz7zETNbRWTcY2uw/dnAvWbWBhwGvtKB+qWX0yywIp3MzL4AFLv798KuReR0KSRERKRdGpMQEZF2KSRERKRdCgkREWmXQkJERNqlkBARkXYpJEREpF0KCRERadf/B95B21obnCufAAAAAElFTkSuQmCC\n", 1111 | "text/plain": [ 1112 | "
" 1113 | ] 1114 | }, 1115 | "metadata": { 1116 | "needs_background": "light" 1117 | }, 1118 | "output_type": "display_data" 1119 | }, 1120 | { 1121 | "name": "stdout", 1122 | "output_type": "stream", 1123 | "text": [ 1124 | "训练集的准确率为:0.14\n" 1125 | ] 1126 | } 1127 | ], 1128 | "source": [ 1129 | "convNet = SimpleConvNet()\n", 1130 | "#拿10张先做实验\n", 1131 | "train_X = X_train[0:100]\n", 1132 | "train_y = y_train[0:100]\n", 1133 | "# train_X = X_train\n", 1134 | "# train_y = y_train\n", 1135 | "\n", 1136 | "convNet.fit(train_X,train_y)" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "markdown", 1141 | "metadata": {}, 1142 | "source": [ 1143 | "## 预测" 1144 | ] 1145 | }, 1146 | { 1147 | "cell_type": "code", 1148 | "execution_count": 920, 1149 | "metadata": {}, 1150 | "outputs": [ 1151 | { 1152 | "name": "stdout", 1153 | "output_type": "stream", 1154 | "text": [ 1155 | "测试的准确率为:0.0985\n" 1156 | ] 1157 | } 1158 | ], 1159 | "source": [ 1160 | "logits = convNet.predicate(X_test)\n", 1161 | "m = X_test.shape[0]\n", 1162 | "accuracy = np.sum(np.argmax(logits,axis=1) == np.argmax(y_test,axis=1))/m\n", 1163 | "print(\"测试的准确率为:%g\" %(accuracy))" 1164 | ] 1165 | }, 1166 | { 1167 | "cell_type": "code", 1168 | "execution_count": 921, 1169 | "metadata": {}, 1170 | "outputs": [ 1171 | { 1172 | "name": "stdout", 1173 | "output_type": "stream", 1174 | "text": [ 1175 | "y is:7\n", 1176 | "your predicate result is :1\n" 1177 | ] 1178 | }, 1179 | { 1180 | "data": { 1181 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADO5JREFUeJzt3V2IXfW5x/Hf76QpiOlFYjUMNpqeogerSKKjCMYS9VhyYiEWg9SLkkLJ9CJKCyVU7EVzWaQv1JvAlIbGkmMrpNUoYmNjMQ1qcSJqEmNiElIzMW9lhCaCtNGnF7Nsp3H2f+/st7XH5/uBYfZez3p52Mxv1lp77bX/jggByOe/6m4AQD0IP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpD7Vz43Z5uOEQI9FhFuZr6M9v+1ltvfZPmD7gU7WBaC/3O5n+23PkrRf0h2SxiW9LOneiHijsAx7fqDH+rHnv1HSgYg4FBF/l/RrSSs6WB+APuok/JdKOjLl+Xg17T/YHrE9Znusg20B6LKev+EXEaOSRiUO+4FB0sme/6ikBVOef66aBmAG6CT8L0u6wvbnbX9a0tckbelOWwB6re3D/og4a/s+Sb+XNEvShojY07XOAPRU25f62toY5/xAz/XlQz4AZi7CDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp7iG5Jsn1Y0mlJH0g6GxHD3WgKQO91FP7KrRHx1y6sB0AfcdgPJNVp+EPSVts7bY90oyEA/dHpYf+SiDhq+xJJz9p+MyK2T52h+qfAPwZgwDgiurMie52kMxHxo8I83dkYgIYiwq3M1/Zhv+0LbX/mo8eSvixpd7vrA9BfnRz2z5f0O9sfref/I+KZrnQFoOe6dtjf0sY47Ad6rueH/QBmNsIPJEX4gaQIP5AU4QeSIvxAUt24qy+FlStXNqytXr26uOw777xTrL///vvF+qZNm4r148ePN6wdOHCguCzyYs8PJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxS2+LDh061LC2cOHC/jUyjdOnTzes7dmzp4+dDJbx8fGGtYceeqi47NjYWLfb6Rtu6QVQRPiBpAg/kBThB5Ii/EBShB9IivADSXE/f4tK9+xfe+21xWX37t1brF911VXF+nXXXVesL126tGHtpptuKi575MiRYn3BggXFeifOnj1brJ86dapYHxoaanvbb7/9drE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR+ftsbJH1F0smIuKaaNk/SbyQtlHRY0j0R8W7Tjc3g+/kH2dy5cxvWFi1aVFx2586dxfoNN9zQVk+taDZewf79+4v1Zp+fmDdvXsPamjVrisuuX7++WB9k3byf/5eSlp0z7QFJ2yLiCknbqucAZpCm4Y+I7ZImzpm8QtLG6vFGSXd1uS8APdbuOf/8iDhWPT4uaX6X+gHQJx1/tj8ionQub3tE0kin2wHQXe3u+U/YHpKk6vfJRjNGxGhEDEfEcJvbAtAD7YZ/i6RV1eNVkp7oTjsA+qVp+G0/KulFSf9je9z2NyX9UNIdtt+S9L/VcwAzCN/bj4F19913F+uPPfZYsb579+6GtVtvvbW47MTEuRe4Zg6+tx9AEeEHkiL8QFKEH0iK8ANJEX4gKS71oTaXXHJJsb5r166Oll+5cmXD2ubNm4vLzmRc6gNQRPiBpAg/kBThB5Ii/EBShB9IivADSTFEN2rT7OuzL7744mL93XfL3xa/b9++8+4pE/b8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU9/Ojp26++eaGteeee6647OzZs4v1pUuXFuvbt28v1j+puJ8fQBHhB5Ii/EBShB9IivADSRF+ICnCDyTV9H5+2xskfUXSyYi4ppq2TtJqSaeq2R6MiKd71SRmruXLlzesNbuOv23btmL9xRdfbKsnTGplz/9LScummf7TiFhU/RB8YIZpGv6I2C5pog+9AOijTs7577P9uu0Ntud2rSMAfdFu+NdL+oKkRZKOSfpxoxltj9gesz3W5rYA9EBb4Y+IExHxQUR8KOnnkm4szDsaEcMRMdxukwC6r63w2x6a8vSrknZ3px0A/dLKpb5HJS2V9Fnb45J+IGmp7UWSQtJhSd/qYY8AeoD7+dGRCy64oFjfsWNHw9rVV19dXPa2224r1l944YViPSvu5wdQRPiBpAg/kBThB5Ii/EBShB9IiiG60ZG1a9cW64sXL25Ye+aZZ4rLcimvt9jzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBS3NKLojvvvLNYf/zxx4v19957r2Ft2bLpvhT631566aViHdPjll4ARYQfSIrwA0kRfiApwg8kRfiBpAg/kBT38yd30UUXFesPP/xwsT5r1qxi/emnGw/gzHX8erHnB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkmt7Pb3uBpEckzZcUkkYj4me250n6jaSFkg5Luici3m2yLu7n77Nm1+GbXWu//vrri/WDBw8W66V79psti/Z0837+s5K+GxFflHSTpDW2vyjpAUnbIuIKSduq5wBmiKbhj4hjEfFK9fi0pL2SLpW0QtLGaraNku7qVZMAuu+8zvltL5S0WNKfJc2PiGNV6bgmTwsAzBAtf7bf9hxJmyV9JyL+Zv/7tCIiotH5vO0RSSOdNgqgu1ra89uercngb4qI31aTT9gequpDkk5Ot2xEjEbEcEQMd6NhAN3RNPye3MX/QtLeiPjJlNIWSauqx6skPdH99gD0SiuX+pZI+pOkXZI+rCY/qMnz/sckXSbpL5q81DfRZF1c6uuzK6+8slh/8803O1r/ihUrivUnn3yyo/Xj/LV6qa/pOX9E7JDUaGW3n09TAAYHn/ADkiL8QFKEH0iK8ANJEX4gKcIPJMVXd38CXH755Q1rW7du7Wjda9euLdafeuqpjtaP+rDnB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuM7/CTAy0vhb0i677LKO1v38888X682+DwKDiz0/kBThB5Ii/EBShB9IivADSRF+ICnCDyTFdf4ZYMmSJcX6/fff36dO8EnCnh9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp6nd/2AkmPSJovKSSNRsTPbK+TtFrSqWrWByPi6V41mtktt9xSrM+ZM6ftdR88eLBYP3PmTNvrxmBr5UM+ZyV9NyJesf0ZSTttP1vVfhoRP+pdewB6pWn4I+KYpGPV49O290q6tNeNAeit8zrnt71Q0mJJf64m3Wf7ddsbbM9tsMyI7THbYx11CqCrWg6/7TmSNkv6TkT8TdJ6SV+QtEiTRwY/nm65iBiNiOGIGO5CvwC6pKXw256tyeBviojfSlJEnIiIDyLiQ0k/l3Rj79oE0G1Nw2/bkn4haW9E/GTK9KEps31V0u7utwegV1p5t/9mSV+XtMv2q9W0ByXda3uRJi//HZb0rZ50iI689tprxfrtt99erE9MTHSzHQyQVt7t3yHJ05S4pg/MYHzCD0iK8ANJEX4gKcIPJEX4gaQIP5CU+znEsm3GcwZ6LCKmuzT/Mez5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpfg/R/VdJf5ny/LPVtEE0qL0Nal8SvbWrm71d3uqMff2Qz8c2bo8N6nf7DWpvg9qXRG/tqqs3DvuBpAg/kFTd4R+tefslg9rboPYl0Vu7aumt1nN+APWpe88PoCa1hN/2Mtv7bB+w/UAdPTRi+7DtXbZfrXuIsWoYtJO2d0+ZNs/2s7bfqn5PO0xaTb2ts320eu1etb28pt4W2P6j7Tds77H97Wp6ra9doa9aXre+H/bbniVpv6Q7JI1LelnSvRHxRl8bacD2YUnDEVH7NWHbX5J0RtIjEXFNNe0hSRMR8cPqH+fciPjegPS2TtKZukdurgaUGZo6srSkuyR9QzW+doW+7lENr1sde/4bJR2IiEMR8XdJv5a0ooY+Bl5EbJd07qgZKyRtrB5v1OQfT9816G0gRMSxiHilenxa0kcjS9f62hX6qkUd4b9U0pEpz8c1WEN+h6SttnfaHqm7mWnMr4ZNl6TjkubX2cw0mo7c3E/njCw9MK9dOyNedxtv+H3ckoi4TtL/SVpTHd4OpJg8ZxukyzUtjdzcL9OMLP0vdb527Y543W11hP+opAVTnn+umjYQIuJo9fukpN9p8EYfPvHRIKnV75M19/MvgzRy83QjS2sAXrtBGvG6jvC/LOkK25+3/WlJX5O0pYY+Psb2hdUbMbJ9oaQva/BGH94iaVX1eJWkJ2rs5T8MysjNjUaWVs2v3cCNeB0Rff+RtFyT7/gflPT9Onpo0Nd/S3qt+tlTd2+SHtXkYeA/NPneyDclXSRpm6S3JP1B0rwB6u1XknZJel2TQRuqqbclmjykf13Sq9XP8rpfu0JftbxufMIPSIo3/ICkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPVP82g/p9/JjhUAAAAASUVORK5CYII=\n", 1182 | "text/plain": [ 1183 | "
" 1184 | ] 1185 | }, 1186 | "metadata": { 1187 | "needs_background": "light" 1188 | }, 1189 | "output_type": "display_data" 1190 | } 1191 | ], 1192 | "source": [ 1193 | "index = 0\n", 1194 | "pred_y = convNet.predicate(X_test[index:index+1])\n", 1195 | "plt.imshow(X_test[index].reshape((28,28)),cmap = plt.cm.gray)\n", 1196 | "print(\"y is:\"+str(np.argmax(y_test[index])))\n", 1197 | "print(\"your predicate result is :\"+str(np.argmax(pred_y)))" 1198 | ] 1199 | }, 1200 | { 1201 | "cell_type": "code", 1202 | "execution_count": null, 1203 | "metadata": {}, 1204 | "outputs": [], 1205 | "source": [] 1206 | }, 1207 | { 1208 | "cell_type": "code", 1209 | "execution_count": null, 1210 | "metadata": {}, 1211 | "outputs": [], 1212 | "source": [] 1213 | } 1214 | ], 1215 | "metadata": { 1216 | "kernelspec": { 1217 | "display_name": "Python 3", 1218 | "language": "python", 1219 | "name": "python3" 1220 | }, 1221 | "language_info": { 1222 | "codemirror_mode": { 1223 | "name": "ipython", 1224 | "version": 3 1225 | }, 1226 | "file_extension": ".py", 1227 | "mimetype": "text/x-python", 1228 | "name": "python", 1229 | "nbconvert_exporter": "python", 1230 | "pygments_lexer": "ipython3", 1231 | "version": "3.7.3" 1232 | }, 1233 | "toc": { 1234 | "base_numbering": 1, 1235 | "nav_menu": {}, 1236 | "number_sections": true, 1237 | "sideBar": false, 1238 | "skip_h1_title": false, 1239 | "title_cell": "Table of Contents", 1240 | "title_sidebar": "Contents", 1241 | "toc_cell": false, 1242 | "toc_position": { 1243 | "height": "681.3333129882812px", 1244 | "left": "22px", 1245 | "top": "112.33333587646484px", 1246 | "width": "252.3333282470703px" 1247 | }, 1248 | "toc_section_display": false, 1249 | "toc_window_display": true 1250 | }, 1251 | "varInspector": { 1252 | "cols": { 1253 | "lenName": 16, 1254 | "lenType": 16, 1255 | "lenVar": 40 1256 | }, 1257 | "kernels_config": { 1258 | "python": { 1259 | "delete_cmd_postfix": "", 1260 | "delete_cmd_prefix": "del ", 1261 | "library": "var_list.py", 1262 | "varRefreshCmd": "print(var_dic_list())" 1263 | }, 1264 | "r": { 1265 | "delete_cmd_postfix": ") ", 1266 | "delete_cmd_prefix": "rm(", 1267 | "library": "var_list.r", 1268 | "varRefreshCmd": "cat(var_dic_list()) " 1269 | } 1270 | }, 1271 | "types_to_exclude": [ 1272 | "module", 1273 | "function", 1274 | "builtin_function_or_method", 1275 | "instance", 1276 | "_Feature" 1277 | ], 1278 | "window_display": false 1279 | } 1280 | }, 1281 | "nbformat": 4, 1282 | "nbformat_minor": 2 1283 | } 1284 | -------------------------------------------------------------------------------- /2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanhuang/SimpleConvNet/43c1d44c83d9f82f45913e4fd42829159971912b/2.jpeg -------------------------------------------------------------------------------- /5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanhuang/SimpleConvNet/43c1d44c83d9f82f45913e4fd42829159971912b/5.jpeg -------------------------------------------------------------------------------- /7.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanhuang/SimpleConvNet/43c1d44c83d9f82f45913e4fd42829159971912b/7.jpeg -------------------------------------------------------------------------------- /SimpleConvNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# CNN实现手写数字识别" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Package" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 3, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import sys ,os \n", 24 | "import numpy as np\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import tensorflow as tf #只是用来加载mnist数据集\n", 27 | "from PIL import Image\n", 28 | "import pandas as pd \n", 29 | "import math" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## 加载MNIST数据集" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def one_hot_label(y):\n", 46 | " one_hot_label = np.zeros((y.shape[0],10))\n", 47 | " y = y.reshape(y.shape[0])\n", 48 | " one_hot_label[range(y.shape[0]),y] = 1\n", 49 | " return one_hot_label\n", 50 | " \n", 51 | " " 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 6, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "shape of x_train is :(60000, 1, 28, 28)\n", 64 | "shape of t_train is :(60000, 10)\n", 65 | "shape of x_test is :(10000, 1, 28, 28)\n", 66 | "shape of t_test is :(10000, 10)\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "# #(训练图像,训练标签),(测试图像,测试标签)\n", 72 | "# # mnist的图像均为28*28尺寸的数据,通道为1\n", 73 | "(x_train_origin,t_train_origin),(x_test_origin,t_test_origin) = tf.keras.datasets.mnist.load_data()\n", 74 | "X_train = x_train_origin/255.0\n", 75 | "X_test = x_test_origin/255.0\n", 76 | "m,h,w = x_train_origin.shape\n", 77 | "X_train = X_train.reshape((m,1,h,w))\n", 78 | "y_train = one_hot_label(t_train_origin)\n", 79 | "\n", 80 | "m,h,w = x_test_origin.shape\n", 81 | "X_test = X_test.reshape((m,1,h,w))\n", 82 | "y_test = one_hot_label(t_test_origin)\n", 83 | "print(\"shape of x_train is :\"+repr(X_train.shape))\n", 84 | "print(\"shape of t_train is :\"+repr(y_train.shape))\n", 85 | "print(\"shape of x_test is :\"+repr(X_test.shape))\n", 86 | "print(\"shape of t_test is :\"+repr(y_test.shape))\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## 显示图像" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "y is:5\n" 106 | ] 107 | }, 108 | { 109 | "data": { 110 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADgdJREFUeJzt3X9sXfV5x/HPs9D8QRoIXjUTpWFpIhQUIuZOJkwoGkXM5YeCggGhWkLKRBT3j1ii0hQNZX8MNAVFg2RqBKrsqqHJ1KWZBCghqpp0CZBOTBEmhF9mKQylqi2TFAWTH/zIHD/74x53Lvh+r3Pvufdc+3m/JMv3nuecex4d5ZPz8/pr7i4A8fxJ0Q0AKAbhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q1GWNXJmZ8TghUGfublOZr6Y9v5ndYWbHzex9M3ukls8C0FhW7bP9ZjZL0m8kdUgalPSqpC53H0gsw54fqLNG7PlXSHrf3T9w9wuSfi5pdQ2fB6CBagn/Akm/m/B+MJv2R8ys28z6zay/hnUByFndL/i5e5+kPonDfqCZ1LLnH5K0cML7b2bTAEwDtYT/VUnXmtm3zGy2pO9J2ptPWwDqrerDfncfNbMeSfslzZK03d3fya0zAHVV9a2+qlbGOT9Qdw15yAfA9EX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFUP0S1JZnZC0llJFyWNunt7Hk0hP7NmzUrWr7zyyrquv6enp2zt8ssvTy67dOnSZH39+vXJ+pNPPlm21tXVlVz2888/T9Y3b96crD/22GPJejOoKfyZW939oxw+B0ADcdgPBFVr+F3SATN7zcy682gIQGPUeti/0t2HzOzPJP3KzP7b3Q9PnCH7T4H/GIAmU9Oe392Hst+nJD0vacUk8/S5ezsXA4HmUnX4zWyOmc0dfy3pu5LezqsxAPVVy2F/q6TnzWz8c/7N3X+ZS1cA6q7q8Lv7B5L+IsdeZqxrrrkmWZ89e3ayfvPNNyfrK1euLFubN29ectn77rsvWS/S4OBgsr5t27ZkvbOzs2zt7NmzyWXfeOONZP3ll19O1qcDbvUBQRF+ICjCDwRF+IGgCD8QFOEHgjJ3b9zKzBq3sgZqa2tL1g8dOpSs1/trtc1qbGwsWX/ooYeS9XPnzlW97uHh4WT9448/TtaPHz9e9brrzd1tKvOx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoLjPn4OWlpZk/ciRI8n64sWL82wnV5V6HxkZSdZvvfXWsrULFy4kl436/EOtuM8PIInwA0ERfiAowg8ERfiBoAg/EBThB4LKY5Te8E6fPp2sb9iwIVlftWpVsv76668n65X+hHXKsWPHkvWOjo5k/fz588n69ddfX7b28MMPJ5dFfbHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKn6f38y2S1ol6ZS7L8+mtUjaLWmRpBOSHnD39B8618z9Pn+trrjiimS90nDSvb29ZWtr165NLvvggw8m67t27UrW0Xzy/D7/TyXd8aVpj0g66O7XSjqYvQcwjVQMv7sflvTlR9hWS9qRvd4h6Z6c+wJQZ9We87e6+/h4Rx9Kas2pHwANUvOz/e7uqXN5M+uW1F3regDkq9o9/0kzmy9J2e9T5WZ09z53b3f39irXBaAOqg3/XklrstdrJO3Jpx0AjVIx/Ga2S9J/SVpqZoNmtlbSZkkdZvaepL/J3gOYRiqe87t7V5nSbTn3EtaZM2dqWv6TTz6petl169Yl67t3707Wx8bGql43isUTfkBQhB8IivADQRF+ICjCDwRF+IGgGKJ7BpgzZ07Z2gsvvJBc9pZbbknW77zzzmT9wIEDyToajyG6ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQ3Oef4ZYsWZKsHz16NFkfGRlJ1l988cVkvb+/v2zt6aefTi7byH+bMwn3+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUNznD66zszNZf+aZZ5L1uXPnVr3ujRs3Jus7d+5M1oeHh5P1qLjPDyCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCqnif38y2S1ol6ZS7L8+mPSppnaTfZ7NtdPdfVFwZ9/mnneXLlyfrW7duTdZvu636kdx7e3uT9U2bNiXrQ0NDVa97OsvzPv9PJd0xyfR/cfe27Kdi8AE0l4rhd/fDkk43oBcADVTLOX+Pmb1pZtvN7KrcOgLQENWG/0eSlkhqkzQsaUu5Gc2s28z6zaz8H3MD0HBVhd/dT7r7RXcfk/RjSSsS8/a5e7u7t1fbJID8VRV+M5s/4W2npLfzaQdAo1xWaQYz2yXpO5K+YWaDkv5R0nfMrE2SSzoh6ft17BFAHfB9ftRk3rx5yfrdd99dtlbpbwWYpW9XHzp0KFnv6OhI1mcqvs8PIInwA0ERfiAowg8ERfiBoAg/EBS3+lCYL774Ilm/7LL0Yyijo6PJ+u2331629tJLLyWXnc641QcgifADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX7EdsMNNyTr999/f7J+4403lq1Vuo9fycDAQLJ++PDhmj5/pmPPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZ9/hlu6dGmy3tPTk6zfe++9yfrVV199yT1N1cWLF5P14eHhZH1sbCzPdmYc9vxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFTF+/xmtlDSTkmtklxSn7v/0MxaJO2WtEjSCUkPuPvH9Ws1rkr30ru6usrWKt3HX7RoUTUt5aK/vz9Z37RpU7K+d+/ePNsJZyp7/lFJf+fuyyT9laT1ZrZM0iOSDrr7tZIOZu8BTBMVw+/uw+5+NHt9VtK7khZIWi1pRzbbDkn31KtJAPm7pHN+M1sk6duSjkhqdffx5ys/VOm0AMA0MeVn+83s65KelfQDdz9j9v/Dgbm7lxuHz8y6JXXX2iiAfE1pz29mX1Mp+D9z9+eyySfNbH5Wny/p1GTLunufu7e7e3seDQPIR8XwW2kX/xNJ77r71gmlvZLWZK/XSNqTf3sA6qXiEN1mtlLSryW9JWn8O5IbVTrv/3dJ10j6rUq3+k5X+KyQQ3S3tqYvhyxbtixZf+qpp5L166677pJ7ysuRI0eS9SeeeKJsbc+e9P6Cr+RWZ6pDdFc853f3/5RU7sNuu5SmADQPnvADgiL8QFCEHwiK8ANBEX4gKMIPBMWf7p6ilpaWsrXe3t7ksm1tbcn64sWLq+opD6+88kqyvmXLlmR9//79yfpnn312yT2hMdjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQYe7z33TTTcn6hg0bkvUVK1aUrS1YsKCqnvLy6aeflq1t27Ytuezjjz+erJ8/f76qntD82PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBh7vN3dnbWVK/FwMBAsr5v375kfXR0NFlPfed+ZGQkuSziYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0GZu6dnMFsoaaekVkkuqc/df2hmj0paJ+n32awb3f0XFT4rvTIANXN3m8p8Uwn/fEnz3f2omc2V9JqkeyQ9IOmcuz851aYIP1B/Uw1/xSf83H1Y0nD2+qyZvSup2D9dA6Bml3TOb2aLJH1b0pFsUo+ZvWlm283sqjLLdJtZv5n119QpgFxVPOz/w4xmX5f0sqRN7v6cmbVK+kil6wD/pNKpwUMVPoPDfqDOcjvnlyQz+5qkfZL2u/vWSeqLJO1z9+UVPofwA3U21fBXPOw3M5P0E0nvTgx+diFwXKekty+1SQDFmcrV/pWSfi3pLUlj2eSNkroktal02H9C0vezi4Opz2LPD9RZrof9eSH8QP3ldtgPYGYi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBNXoIbo/kvTbCe+/kU1rRs3aW7P2JdFbtfLs7c+nOmNDv8//lZWb9bt7e2ENJDRrb83al0Rv1SqqNw77gaAIPxBU0eHvK3j9Kc3aW7P2JdFbtQrprdBzfgDFKXrPD6AghYTfzO4ws+Nm9r6ZPVJED+WY2Qkze8vMjhU9xFg2DNopM3t7wrQWM/uVmb2X/Z50mLSCenvUzIaybXfMzO4qqLeFZvaimQ2Y2Ttm9nA2vdBtl+irkO3W8MN+M5sl6TeSOiQNSnpVUpe7DzS0kTLM7ISkdncv/J6wmf21pHOSdo6PhmRm/yzptLtvzv7jvMrd/75JentUlzhyc516Kzey9N+qwG2X54jXeShiz79C0vvu/oG7X5D0c0mrC+ij6bn7YUmnvzR5taQd2esdKv3jabgyvTUFdx9296PZ67OSxkeWLnTbJfoqRBHhXyDpdxPeD6q5hvx2SQfM7DUz6y66mUm0ThgZ6UNJrUU2M4mKIzc30pdGlm6abVfNiNd544LfV61097+UdKek9dnhbVPy0jlbM92u+ZGkJSoN4zYsaUuRzWQjSz8r6QfufmZirchtN0lfhWy3IsI/JGnhhPffzKY1BXcfyn6fkvS8SqcpzeTk+CCp2e9TBffzB+5+0t0vuvuYpB+rwG2XjSz9rKSfuftz2eTCt91kfRW13YoI/6uSrjWzb5nZbEnfk7S3gD6+wszmZBdiZGZzJH1XzTf68F5Ja7LXayTtKbCXP9IsIzeXG1laBW+7phvx2t0b/iPpLpWu+P+PpH8ooocyfS2W9Eb2807RvUnapdJh4P+qdG1kraQ/lXRQ0nuS/kNSSxP19q8qjeb8pkpBm19QbytVOqR/U9Kx7Oeuorddoq9CthtP+AFBccEPCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ/weCC5r/92q6mAAAAABJRU5ErkJggg==\n", 111 | "text/plain": [ 112 | "
" 113 | ] 114 | }, 115 | "metadata": { 116 | "needs_background": "light" 117 | }, 118 | "output_type": "display_data" 119 | } 120 | ], 121 | "source": [ 122 | "index = 0\n", 123 | "plt.imshow(X_train[index].reshape((28,28)),cmap = plt.cm.gray)\n", 124 | "print(\"y is:\"+str(np.argmax(y_train[index])))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## 辅助函数" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "将根据滤波器的大小,步幅,填充,输入数据展开为2维数组。" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 8, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "def im2col2(input_data,fh,fw,stride=1,pad=0):\n", 148 | " '''\n", 149 | " Arguments:\n", 150 | " \n", 151 | " input_data--输入数据,shape为(Number of example,Channel,Height,Width)\n", 152 | " fh -- 滤波器的height\n", 153 | " fw --滤波器的width\n", 154 | " stride -- 步幅\n", 155 | " pad -- 填充\n", 156 | " \n", 157 | " Returns :\n", 158 | " col -- 输入数据根据滤波器、步幅等展开的二维数组,每一行代表一条卷积数据\n", 159 | " '''\n", 160 | " N,C,H,W = input_data.shape\n", 161 | " \n", 162 | " out_h = (H + 2*pad - fh)//stride+1\n", 163 | " out_w = (W+2*pad-fw)//stride+1\n", 164 | " \n", 165 | " img = np.pad(input_data,[(0,0),(0,0),(pad,pad),(pad,pad)],\"constant\")\n", 166 | " \n", 167 | " col = np.zeros((N,out_h,out_w,fh*fw*C))\n", 168 | " \n", 169 | " #将所有维度上需要卷积的值展开成一列\n", 170 | " for y in range(out_h):\n", 171 | " y_start = y * stride\n", 172 | " y_end = y_start + fh\n", 173 | " for x in range(out_w):\n", 174 | " x_start = x*stride\n", 175 | " x_end = x_start+fw\n", 176 | " col[:,y,x] = img[:,:,y_start:y_end,x_start:x_end].reshape(N,-1)\n", 177 | " col = col.reshape(N*out_h*out_w,-1)\n", 178 | " return col\n", 179 | "\n", 180 | " " 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "将二维数据转成image" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 11, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "def col2im2(col,out_shape,fh,fw,stride=1,pad=0):\n", 197 | " '''\n", 198 | " Arguments:\n", 199 | " col: 二维数组 \n", 200 | " out_shape-- 输出的shape,shape为(Number of example,Channel,Height,Width)\n", 201 | " fh -- 滤波器的height\n", 202 | " fw --滤波器的width\n", 203 | " stride -- 步幅\n", 204 | " pad -- 填充\n", 205 | " \n", 206 | " Returns :\n", 207 | " img -- 将col转换成的img ,shape为out_shape\n", 208 | " '''\n", 209 | " N,C,H,W = out_shape\n", 210 | " \n", 211 | " col_m,col_n = col.shape\n", 212 | " \n", 213 | " out_h = (H + 2*pad - fh)//stride+1\n", 214 | " out_w = (W+2*pad-fw)//stride+1\n", 215 | "\n", 216 | " \n", 217 | "\n", 218 | " img = np.zeros((N, C, H , W))\n", 219 | " # img = np.pad(img,[(0,0),(0,0),(pad,pad),(pad,pad)],\"constant\")\n", 220 | "\n", 221 | " #将col转换成一个filter\n", 222 | " for c in range(C):\n", 223 | " for y in range(out_h):\n", 224 | " for x in range(out_w):\n", 225 | " col_index = (c*out_h*out_w)+y*out_w+x\n", 226 | " ih = y*stride\n", 227 | " iw = x*stride\n", 228 | " img[:,c,ih:ih+fh,iw:iw+fw] = col[col_index].reshape((fh,fw))\n", 229 | " return img\n" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 12, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "def im2col2test():\n", 239 | " a = np.random.randint(0,5,size=(2,3,4,4))\n", 240 | " c = im2col2(a,2,2,2,1)\n", 241 | "\n", 242 | " a = np.random.randint(0,5,size=(1,1,3,3))\n", 243 | " print(a)\n", 244 | " c = im2col2(a,2,2,1,0)\n", 245 | " print(c)\n", 246 | " col2im2(c,out_shape=(1,1,3,3),fh=2,fw=2,stride=1)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "### 激活函数" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 13, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "def relu(input_X):\n", 263 | " \"\"\"\n", 264 | " Arguments:\n", 265 | " input_X -- a numpy array\n", 266 | " Return :\n", 267 | " A: a numpy array. let each elements in array all greater or equal 0\n", 268 | " \"\"\"\n", 269 | " \n", 270 | " A = np.where(input_X < 0 ,0,input_X)\n", 271 | " return A\n", 272 | " " 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 14, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "def softmax(input_X):\n", 282 | " \"\"\"\n", 283 | " Arguments:\n", 284 | " input_X -- a numpy array\n", 285 | " Return :\n", 286 | " A: a numpy array same shape with input_X\n", 287 | " \"\"\"\n", 288 | " exp_a = np.exp(input_X)\n", 289 | " sum_exp_a = np.sum(exp_a,axis=1)\n", 290 | " sum_exp_a = sum_exp_a.reshape(input_X.shape[0],-1)\n", 291 | " ret = exp_a/sum_exp_a\n", 292 | " # print(ret)\n", 293 | " return ret" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "### 损失函数" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 15, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "def cross_entropy_error(labels,logits):\n", 310 | " return -np.sum(labels*np.log(logits))\n" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "## 卷积层" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 17, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "class Convolution:\n", 327 | " def __init__(self,W,fb,stride = 1,pad = 0):\n", 328 | " \"\"\"\n", 329 | " W-- 滤波器权重,shape为(FN,NC,FH,FW),FN 为滤波器的个数\n", 330 | " fb -- 滤波器的偏置,shape 为(1,FN) \n", 331 | " stride -- 步长\n", 332 | " pad -- 填充个数\n", 333 | " \"\"\"\n", 334 | " self.W = W\n", 335 | " self.fb = fb \n", 336 | " self.stride = stride\n", 337 | " self.pad = pad\n", 338 | " \n", 339 | " \n", 340 | " self.col_X = None\n", 341 | " self.X = None\n", 342 | " self.col_W = None\n", 343 | " \n", 344 | " self.dW = None\n", 345 | " self.db = None\n", 346 | " self.out_shape = None\n", 347 | " # self.out = None\n", 348 | " \n", 349 | " def forward (self ,input_X):\n", 350 | " \"\"\"\n", 351 | " input_X-- shape为(m,nc,height,width)\n", 352 | " \"\"\" \n", 353 | " self.X = input_X\n", 354 | " FN,NC,FH,FW = self.W.shape\n", 355 | " \n", 356 | " m,input_nc, input_h,input_w = self.X.shape\n", 357 | " \n", 358 | " #先计算输出的height和widt\n", 359 | " out_h = int((input_h+2*self.pad-FH)/self.stride + 1)\n", 360 | " out_w = int((input_w+2*self.pad-FW)/self.stride + 1)\n", 361 | " \n", 362 | " #将输入数据展开成二维数组,shape为(m*out_h*out_w,FH*FW*C)\n", 363 | " self.col_X = col_X = im2col2(self.X,FH,FW,self.stride,self.pad)\n", 364 | " \n", 365 | " #将滤波器一个个按列展开(FH*FW*C,FN)\n", 366 | " self.col_W = col_W = self.W.reshape(FN,-1).T\n", 367 | " out = np.dot(col_X,col_W)+self.fb\n", 368 | " out = out.T\n", 369 | " out = out.reshape(m,FN,out_h,out_w)\n", 370 | " self.out_shape = out.shape\n", 371 | " return out\n", 372 | " \n", 373 | " def backward(self, dz,learning_rate):\n", 374 | " #print(\"==== Conv backbward ==== \")\n", 375 | " assert(dz.shape == self.out_shape)\n", 376 | " \n", 377 | " FN,NC,FH,FW = self.W.shape\n", 378 | " o_FN,o_NC,o_FH,o_FW = self.out_shape\n", 379 | " \n", 380 | " col_dz = dz.reshape(o_NC,-1)\n", 381 | " col_dz = col_dz.T\n", 382 | " \n", 383 | " self.dW = np.dot(self.col_X.T,col_dz) #shape is (FH*FW*C,FN)\n", 384 | " self.db = np.sum(col_dz,axis=0,keepdims=True)\n", 385 | "\n", 386 | " \n", 387 | " self.dW = self.dW.T.reshape(self.W.shape)\n", 388 | " self.db = self.db.reshape(self.fb.shape)\n", 389 | " \n", 390 | " \n", 391 | " d_col_x = np.dot(col_dz,self.col_W.T) #shape is (m*out_h*out_w,FH,FW*C)\n", 392 | " dx = col2im2(d_col_x,self.X.shape,FH,FW,stride=1)\n", 393 | " \n", 394 | " assert(dx.shape == self.X.shape)\n", 395 | " \n", 396 | " #更新W和b\n", 397 | " self.W = self.W - learning_rate*self.dW\n", 398 | " self.fb = self.fb -learning_rate*self.db\n", 399 | " \n", 400 | " return dx\n", 401 | " \n", 402 | " " 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "## 池化层" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 18, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "class Pooling:\n", 419 | " def __init__(self,pool_h,pool_w,stride = 1,pad = 0):\n", 420 | " self.pool_h = pool_h\n", 421 | " self.pool_w = pool_w\n", 422 | " self.stride = stride\n", 423 | " self.pad = pad \n", 424 | " self.X = None\n", 425 | " self.arg_max = None\n", 426 | " \n", 427 | " def forward ( self,input_X) :\n", 428 | " \"\"\"\n", 429 | " 前向传播\n", 430 | " input_X-- shape为(m,nc,height,width)\n", 431 | " \"\"\" \n", 432 | " self.X = input_X\n", 433 | " N , C, H, W = input_X.shape\n", 434 | " out_h = int(1+(H-self.pool_h)/self.stride)\n", 435 | " out_w = int(1+(W-self.pool_w)/self.stride)\n", 436 | " \n", 437 | " #展开\n", 438 | " col = im2col2(input_X,self.pool_h,self.pool_w,self.stride,self.pad)\n", 439 | " col = col.reshape(-1,self.pool_h*self.pool_w)\n", 440 | " arg_max = np.argmax(col,axis=1)\n", 441 | " #最大值\n", 442 | " out = np.max(col,axis=1)\n", 443 | " out =out.T.reshape(N,C,out_h,out_w)\n", 444 | " self.arg_max = arg_max\n", 445 | " return out\n", 446 | " \n", 447 | " def backward(self ,dz):\n", 448 | " \"\"\"\n", 449 | " 反向传播\n", 450 | " Arguments:\n", 451 | " dz-- out的导数,shape与out 一致\n", 452 | " \n", 453 | " Return:\n", 454 | " 返回前向传播是的input_X的导数\n", 455 | " \"\"\" \n", 456 | " pool_size = self.pool_h*self.pool_w\n", 457 | " dmax = np.zeros((dz.size,pool_size))\n", 458 | " dmax[np.arange(self.arg_max.size),self.arg_max.flatten()] = dz.flatten()\n", 459 | " \n", 460 | " dx = col2im2(dmax,out_shape=self.X.shape,fh=self.pool_h,fw=self.pool_w,stride=self.stride)\n", 461 | " return dx\n", 462 | " \n", 463 | " " 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "## Relu层" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 19, 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "class Relu:\n", 480 | " def __init__(self):\n", 481 | " self.mask = None\n", 482 | " \n", 483 | " def forward(self ,X):\n", 484 | " self.mask = X <= 0\n", 485 | " out = X\n", 486 | " out[self.mask] = 0\n", 487 | " return out\n", 488 | " \n", 489 | " def backward(self,dz):\n", 490 | " dz[self.mask] = 0\n", 491 | " dx = dz \n", 492 | " return dx" 493 | ] 494 | }, 495 | { 496 | "cell_type": "markdown", 497 | "metadata": {}, 498 | "source": [ 499 | "## SoftMax层" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 20, 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "class SoftMax:\n", 509 | " def __init__ (self):\n", 510 | " self.y_hat = None\n", 511 | " \n", 512 | " def forward(self,X):\n", 513 | " \n", 514 | " self.y_hat = softmax(X)\n", 515 | " return self.y_hat\n", 516 | " \n", 517 | " def backward(self,labels):\n", 518 | " m = labels.shape[0]\n", 519 | " dx = (self.y_hat - labels)\n", 520 | " \n", 521 | " return dx\n", 522 | " " 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 21, 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "def compute_cost(logits,label):\n", 532 | " return cross_entropy_error(label,logits)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "## Affine FC层" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 22, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "class Affine:\n", 549 | " def __init__(self,W,b):\n", 550 | " self.W = W # shape is (n_x,n_unit)\n", 551 | " self.b = b # shape is(1,n_unit)\n", 552 | " self.X = None\n", 553 | " self.origin_x_shape = None\n", 554 | " \n", 555 | " self.dW = None\n", 556 | " self.db = None\n", 557 | " \n", 558 | " self.out_shape =None\n", 559 | " \n", 560 | " def forward(self,X):\n", 561 | " self.origin_x_shape = X.shape \n", 562 | " self.X = X.reshape(X.shape[0],-1)#(m,n)\n", 563 | " out = np.dot(self.X, self.W)+self.b\n", 564 | " self.out_shape = out.shape\n", 565 | " return out\n", 566 | " \n", 567 | " def backward(self,dz,learning_rate):\n", 568 | " \"\"\"\n", 569 | " dz-- 前面的导数\n", 570 | " \"\"\" \n", 571 | "# print(\"Affine backward\")\n", 572 | "# print(self.X.shape)\n", 573 | "# print(dz.shape)\n", 574 | "# print(self.W.shape)\n", 575 | " \n", 576 | " assert(dz.shape == self.out_shape)\n", 577 | " \n", 578 | " m = self.X.shape[0]\n", 579 | " \n", 580 | " self.dW = np.dot(self.X.T,dz)/m\n", 581 | " self.db = np.sum(dz,axis=0,keepdims=True)/m\n", 582 | " \n", 583 | " assert(self.dW.shape == self.W.shape)\n", 584 | " assert(self.db.shape == self.b.shape)\n", 585 | " \n", 586 | " dx = np.dot(dz,self.W.T)\n", 587 | " assert(dx.shape == self.X.shape)\n", 588 | " \n", 589 | " dx = dx.reshape(self.origin_x_shape) # 保持与之前的x一样的shape\n", 590 | " \n", 591 | " #更新W和b\n", 592 | " self.W = self.W-learning_rate*self.dW\n", 593 | " self.b = self.b - learning_rate*self.db\n", 594 | " \n", 595 | " return dx\n", 596 | " " 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": {}, 602 | "source": [ 603 | "## 模型\n" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 97, 609 | "metadata": { 610 | "code_folding": [] 611 | }, 612 | "outputs": [], 613 | "source": [ 614 | "class SimpleConvNet:\n", 615 | "\n", 616 | " def __init__(self):\n", 617 | " self.X = None\n", 618 | " self.Y= None\n", 619 | " self.layers = []\n", 620 | "\n", 621 | " def add_conv_layer(self,n_filter,n_c , f, stride=1, pad=0):\n", 622 | " \"\"\"\n", 623 | " 添加一层卷积层\n", 624 | " Arguments:\n", 625 | " n_c -- 输入数据通道数,也即卷积层的通道数\n", 626 | " n_filter -- 滤波器的个数\n", 627 | " f --滤波器的长/宽\n", 628 | "\n", 629 | " Return :\n", 630 | " Conv -- 卷积层\n", 631 | " \"\"\"\n", 632 | "\n", 633 | " # 初始化W,b\n", 634 | " W = np.random.randn(n_filter, n_c, f, f)*0.01\n", 635 | " fb = np.zeros((1, n_filter))\n", 636 | " # 卷积层\n", 637 | " Conv = Convolution(W, fb, stride=stride, pad=pad)\n", 638 | " return Conv\n", 639 | "\n", 640 | " def add_maxpool_layer(self, pool_shape, stride=1, pad=0):\n", 641 | " \"\"\"\n", 642 | " 添加一层池化层\n", 643 | " Arguments:\n", 644 | " pool_shape -- 滤波器的shape\n", 645 | " f -- 滤波器大小\n", 646 | " Return :\n", 647 | " Pool -- 初始化的Pool类\n", 648 | " \"\"\"\n", 649 | " pool_h, pool_w = pool_shape\n", 650 | " pool = Pooling(pool_h, pool_w, stride=stride, pad=pad)\n", 651 | " \n", 652 | " return pool\n", 653 | " \n", 654 | " def add_affine(self,n_x, n_units):\n", 655 | " \"\"\"\n", 656 | " 添加一层全连接层\n", 657 | " Arguments:\n", 658 | " n_x -- 输入个数\n", 659 | " n_units -- 神经元个数\n", 660 | " Return :\n", 661 | " fc_layer -- Affine层对象\n", 662 | " \"\"\"\n", 663 | " \n", 664 | " W= np.random.randn(n_x, n_units)*0.01\n", 665 | " \n", 666 | " b = np.zeros((1, n_units))\n", 667 | " \n", 668 | " fc_layer = Affine(W,b)\n", 669 | " \n", 670 | " return fc_layer\n", 671 | " \n", 672 | " def add_relu(self):\n", 673 | " relu_layer = Relu()\n", 674 | " return relu_layer\n", 675 | " \n", 676 | " \n", 677 | " def add_softmax(self):\n", 678 | " softmax_layer = SoftMax()\n", 679 | " return softmax_layer\n", 680 | " \n", 681 | " #计算卷积或池化后的H和W\n", 682 | " def cacl_out_hw(self,HW,f,stride = 1,pad = 0):\n", 683 | " return (HW+2*pad - f)/stride+1\n", 684 | " \n", 685 | "\n", 686 | " \n", 687 | " \n", 688 | " def init_model(self,train_X,n_classes):\n", 689 | " \"\"\"\n", 690 | " 初始化一个卷积层网络\n", 691 | " \"\"\"\n", 692 | " N,C,H,W = train_X.shape\n", 693 | " #卷积层\n", 694 | " n_filter = 4\n", 695 | " f = 7\n", 696 | " \n", 697 | " conv_layer = self.add_conv_layer(n_filter= n_filter,n_c=C,f=f,stride=1)\n", 698 | " \n", 699 | " out_h = self.cacl_out_hw(H,f)\n", 700 | " out_w = self.cacl_out_hw(W,f)\n", 701 | " out_ch = n_filter\n", 702 | " \n", 703 | " self.layers.append(conv_layer)\n", 704 | " \n", 705 | " #Relu\n", 706 | " relu_layer = self.add_relu()\n", 707 | " self.layers.append(relu_layer)\n", 708 | " \n", 709 | " #池化\n", 710 | " f = 2\n", 711 | " pool_layer = self.add_maxpool_layer(pool_shape=(f,f),stride=2)\n", 712 | " out_h = self.cacl_out_hw(out_h,f,stride=2)\n", 713 | " out_w = self.cacl_out_hw(out_w,f,stride=2)\n", 714 | " #out_ch 不改变\n", 715 | " self.layers.append(pool_layer)\n", 716 | " \n", 717 | "\n", 718 | " \n", 719 | " \n", 720 | " #Affine层\n", 721 | " n_x = int(out_h*out_w*out_ch)\n", 722 | " n_units = 32\n", 723 | " fc_layer = self.add_affine(n_x=n_x,n_units=n_units)\n", 724 | " self.layers.append(fc_layer)\n", 725 | " \n", 726 | " #Relu\n", 727 | " relu_layer = self.add_relu()\n", 728 | " self.layers.append(relu_layer)\n", 729 | " \n", 730 | " #Affine\n", 731 | " fc_layer = self.add_affine(n_x=n_units,n_units=n_classes)\n", 732 | " self.layers.append(fc_layer)\n", 733 | " \n", 734 | " #SoftMax\n", 735 | " softmax_layer = self.add_softmax()\n", 736 | " self.layers.append(softmax_layer)\n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " def forward_progation(self,train_X, print_out = False):\n", 741 | " \"\"\"\n", 742 | " 前向传播\n", 743 | " Arguments:\n", 744 | " train_X -- 训练数据\n", 745 | " f -- 滤波器大小\n", 746 | "\n", 747 | " Return :\n", 748 | " Z-- 前向传播的结果\n", 749 | " loss -- 损失值\n", 750 | " \"\"\"\n", 751 | " \n", 752 | " \n", 753 | " N,C,H,W = train_X.shape\n", 754 | " index = 0\n", 755 | " # 卷积层\n", 756 | " conv_layer = self.layers[index]\n", 757 | " X = conv_layer.forward(train_X)\n", 758 | " index =index+1\n", 759 | " if print_out:\n", 760 | " print(\"卷积之后:\"+str(X.shape))\n", 761 | " # Relu\n", 762 | " relu_layer = self.layers[index]\n", 763 | " index =index+1\n", 764 | " X = relu_layer.forward(X)\n", 765 | " if print_out:\n", 766 | " print(\"Relu:\"+str(X.shape))\n", 767 | " \n", 768 | " \n", 769 | " # 池化层\n", 770 | " pool_layer = self.layers[index]\n", 771 | " index =index+1\n", 772 | " X = pool_layer.forward(X)\n", 773 | " if print_out:\n", 774 | " print(\"池化:\"+str(X.shape))\n", 775 | "\n", 776 | "\n", 777 | " #Affine层\n", 778 | " fc_layer = self.layers[index]\n", 779 | " index =index+1\n", 780 | " X = fc_layer.forward(X)\n", 781 | " if print_out:\n", 782 | " print(\"Affline 层的X:\"+str(X.shape))\n", 783 | "\n", 784 | " #Relu\n", 785 | " relu_layer = self.layers[index]\n", 786 | " index =index+1\n", 787 | " X = relu_layer.forward(X)\n", 788 | " if print_out:\n", 789 | " print(\"Relu 层的X:\"+str(X.shape))\n", 790 | " \n", 791 | " #Affine层\n", 792 | " fc_layer = self.layers[index]\n", 793 | " index =index+1\n", 794 | " X = fc_layer.forward(X)\n", 795 | " if print_out:\n", 796 | " print(\"Affline 层的X:\"+str(X.shape))\n", 797 | "\n", 798 | " #SoftMax层\n", 799 | " sofmax_layer = self.layers[index]\n", 800 | " index =index+1\n", 801 | " A = sofmax_layer.forward(X)\n", 802 | " if print_out:\n", 803 | " print(\"Softmax 层的X:\"+str(A.shape))\n", 804 | " \n", 805 | " return A\n", 806 | " \n", 807 | " def back_progation(self,train_y,learning_rate):\n", 808 | " \"\"\"\n", 809 | " 反向传播\n", 810 | " Arguments:\n", 811 | " \n", 812 | " \"\"\"\n", 813 | " index = len(self.layers)-1\n", 814 | " sofmax_layer = self.layers[index]\n", 815 | " index -= 1\n", 816 | " dz = sofmax_layer.backward(train_y)\n", 817 | " \n", 818 | " fc_layer = self.layers[index]\n", 819 | " dz = fc_layer.backward(dz,learning_rate=learning_rate)\n", 820 | " index -= 1\n", 821 | " \n", 822 | " relu_layer = self.layers[index]\n", 823 | " dz = relu_layer.backward(dz)\n", 824 | " index -= 1\n", 825 | " \n", 826 | " fc_layer = self.layers[index]\n", 827 | " dz = fc_layer.backward(dz,learning_rate=learning_rate)\n", 828 | " index -= 1\n", 829 | " \n", 830 | " pool_layer = self.layers[index]\n", 831 | " dz = pool_layer.backward(dz)\n", 832 | " index -= 1\n", 833 | " \n", 834 | " relu_layer = self.layers[index]\n", 835 | " dz = relu_layer.backward(dz)\n", 836 | " index -= 1\n", 837 | " \n", 838 | " conv_layer = self.layers[index]\n", 839 | " conv_layer.backward(dz,learning_rate=learning_rate)\n", 840 | " index -= 1\n", 841 | " \n", 842 | " \n", 843 | " def get_minibatch(self,batch_data,minibatch_size,num):\n", 844 | " m_examples = batch_data.shape[0]\n", 845 | " minibatches = math.ceil( m_examples / minibatch_size)\n", 846 | " \n", 847 | " if(num < minibatches):\n", 848 | " return batch_data[num*minibatch_size:(num+1)*minibatch_size]\n", 849 | " else:\n", 850 | " return batch_data[num*minibatch_size:m_examples]\n", 851 | " \n", 852 | " \n", 853 | " def optimize(self,train_X, train_y,minibatch_size,learning_rate=0.05,num_iters=500):\n", 854 | " \"\"\"\n", 855 | " 优化方法\n", 856 | " Arguments:\n", 857 | " train_X -- 训练数据 \n", 858 | " train_y -- 训练数据的标签\n", 859 | " learning_rate -- 学习率\n", 860 | " num_iters -- 迭代次数\n", 861 | " minibatch_size \n", 862 | " \"\"\"\n", 863 | " m = train_X.shape[0]\n", 864 | " num_batches = math.ceil(m / minibatch_size)\n", 865 | " \n", 866 | " costs = []\n", 867 | " for iteration in range(num_iters):\n", 868 | " iter_cost = 0\n", 869 | " for batch_num in range(num_batches):\n", 870 | " minibatch_X = self.get_minibatch(train_X,minibatch_size,batch_num)\n", 871 | " minibatch_y = self.get_minibatch(train_y,minibatch_size,batch_num)\n", 872 | " \n", 873 | " # 前向传播\n", 874 | " A = self.forward_progation(minibatch_X,print_out=False)\n", 875 | " #损失:\n", 876 | " cost = compute_cost (A,minibatch_y)\n", 877 | " #反向传播\n", 878 | " self.back_progation(minibatch_y,learning_rate)\n", 879 | " if(iteration%100 == 0):\n", 880 | " iter_cost += cost/num_batches\n", 881 | " \n", 882 | " if(iteration%100 == 0):\n", 883 | " print(\"After %d iters ,cost is :%g\" %(iteration,iter_cost))\n", 884 | " costs.append(iter_cost)\n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " #画出损失函数图\n", 890 | " plt.plot(costs)\n", 891 | " plt.xlabel(\"iterations/hundreds\")\n", 892 | " plt.ylabel(\"costs\")\n", 893 | " plt.show()\n", 894 | " \n", 895 | " \n", 896 | " def predicate(self, train_X):\n", 897 | " \"\"\"\n", 898 | " 预测\n", 899 | " \"\"\"\n", 900 | " logits = self.forward_progation(train_X)\n", 901 | " one_hot = np.zeros_like(logits)\n", 902 | " one_hot[range(train_X.shape[0]),np.argmax(logits,axis=1)] = 1\n", 903 | " return one_hot \n", 904 | "\n", 905 | " def fit(self,train_X, train_y):\n", 906 | " \"\"\"\n", 907 | " 训练\n", 908 | " \"\"\"\n", 909 | " self.X = train_X\n", 910 | " self.Y = train_y\n", 911 | " n_y = train_y.shape[1]\n", 912 | " m = train_X.shape[0]\n", 913 | " \n", 914 | " #初始化模型\n", 915 | " self.init_model(train_X,n_classes=n_y)\n", 916 | "\n", 917 | " self.optimize(train_X, train_y,minibatch_size=10,learning_rate=0.05,num_iters=800)\n", 918 | " \n", 919 | " logits = self.predicate(train_X)\n", 920 | " \n", 921 | " accuracy = np.sum(np.argmax(logits,axis=1) == np.argmax(train_y,axis=1))/m\n", 922 | " print(\"训练集的准确率为:%g\" %(accuracy))" 923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": 98, 928 | "metadata": {}, 929 | "outputs": [ 930 | { 931 | "name": "stdout", 932 | "output_type": "stream", 933 | "text": [ 934 | "After 0 iters ,cost is :23.0254\n", 935 | "After 100 iters ,cost is :14.5255\n", 936 | "After 200 iters ,cost is :6.01782\n", 937 | "After 300 iters ,cost is :5.71148\n", 938 | "After 400 iters ,cost is :5.63212\n", 939 | "After 500 iters ,cost is :5.45006\n", 940 | "After 600 iters ,cost is :5.05849\n", 941 | "After 700 iters ,cost is :4.29723\n" 942 | ] 943 | }, 944 | { 945 | "data": { 946 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEKCAYAAAAB0GKPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmUVPWd9/H3t6t6oTfWBtkbaCRiVNQOLggCMUaNUZNxMprEMcY8GBUnmeXkSSbPeZJJnjMnM5nJzCS4hKjRTNRkEiVxojEyAqLi1iAgi7IvzdbN2k03vX+fP+oCBXR1F9DVt6r78zqnTt26S9WnO4ZP/+69da+5OyIiIp3JCjuAiIhkBhWGiIgkRYUhIiJJUWGIiEhSVBgiIpIUFYaIiCRFhSEiIklRYYiISFJUGCIikpRo2AG60qBBg7y0tDTsGCIiGWPp0qV73b0kmXV7VGGUlpZSUVERdgwRkYxhZluTXVe7pEREJCkqDBERSYoKQ0REkqLCEBGRpKgwREQkKSoMERFJigpDRESS0usLo6G5lZ8t3sTbm/aFHUVEJK31+sIwg8de38yP5q8LO4qISFrr9YWRG41wz9VjeXvzft7ZvD/sOCIiaStlhWFmI81soZmtMbPVZva1YP4PzewDM1tpZvPMrF+C7beY2ftmttzMUnq9j9s+NopBhTnMWbghlR8jIpLRUjnCaAH+1t0nApcD95vZRGA+8FF3vxBYB3yrg/eY4e6T3L08hTnpkxPhK1PHsnhdNSu2H0zlR4mIZKyUFYa773L3ZcF0LbAWGO7uL7t7S7DaW8CIVGU4HV+8fDR9+2RrlCEikkC3HMMws1LgYuDtkxZ9Gfhjgs0ceNnMlprZrNSliynMjfLlKWOYv2YPa3fVpPrjREQyTsoLw8wKgWeBr7t7Tdz8bxPbbfVUgk2vcvdLgOuJ7c6aluD9Z5lZhZlVVFdXn1XWL11ZSmFulAc1yhAROUVKC8PMsomVxVPu/lzc/C8BNwJfcHdvb1t33xE8VwHzgMkJ1pvr7uXuXl5SktQ9QBLqm5/NX14xmhfe38XG6sNn9V4iIj1NKs+SMuAxYK27/yhu/nXAN4Cb3L0+wbYFZlZ0dBq4FliVqqzx7r5qDLnRLB5auLE7Pk5EJGOkcoQxBbgDmBmcGrvczG4A5gBFwPxg3iMAZjbMzF4Mth0CvG5mK4B3gBfc/aUUZj1mYGEuX7hsNL9bvoPt+9vtMxGRXskS7BHKSOXl5d4Vt2jdU9PA1H9ayK3lI/jHz1zQBclERNKTmS1N9qsLvf6b3u0ZUpzH5z42gt9WVLLr0JGw44iIpAUVRgL3TBtHmztzF28KO4qISFpQYSQwckA+n7l4OE+/vY3q2saw44iIhE6F0YH7ZpTR3NrGo69rlCEiosLowJhBBXz6omH88s2tHKhrCjuOiEioVBiduH9GGXVNrfx8yZawo4iIhEqF0YlzhxRx3fnn8MQbm6lpaA47johIaFQYSZg9s4yahhb+882tYUcREQmNCiMJHx3elxkTSnjs9c3UN7V0voGISA+kwkjS7Jnj2V/XxNNvbws7iohIKFQYSbp0dH+uHDeQuYs30dDcGnYcEZFup8I4DbNnllFV28hvllaGHUVEpNupME7DFWMHcuno/jyyaCPNrW1hxxER6VYqjNNgZsyeWcaOg0eY996OsOOIiHQrFcZpmn5uCR8dXsxDCzfQ2tZzLg0vItIZFcZpMjNmzxjPln31/GHlzrDjiIh0GxXGGbh24hDOHVLIgws30KZRhoj0EiqMM5CVZdw/o4x1ew7z8prdYccREekWKSsMMxtpZgvNbI2ZrTazrwXzB5jZfDNbHzz3T7D9ncE6683szlTlPFM3XjiMMYMK+MmCDfSk29yKiCSSyhFGC/C37j4RuBy438wmAt8EXnH38cArwesTmNkA4DvAZcBk4DuJiiUskSzj3unjWL2zhkUfVocdR0Qk5VJWGO6+y92XBdO1wFpgOHAz8GSw2pPALe1s/klgvrvvd/cDwHzgulRlPVOfuXg4w/v14ccL1muUISI9XrccwzCzUuBi4G1giLvvChbtBoa0s8lwYHvc68pgXnvvPcvMKsysorq6e//Sz45kce/0cby37SBvbtzXrZ8tItLdUl4YZlYIPAt83d1r4pd57M/ys/rT3N3nunu5u5eXlJSczVudkVsvHcGQ4lx+smBDt3+2iEh3SmlhmFk2sbJ4yt2fC2bvMbOhwfKhQFU7m+4ARsa9HhHMSzt52RFmTRvHm5v2UbFlf9hxRERSJpVnSRnwGLDW3X8Ut+h54OhZT3cCv29n8z8B15pZ/+Bg97XBvLR0++SRDCzIYc5CjTJEpOdK5QhjCnAHMNPMlgePG4AfAJ8ws/XANcFrzKzczB4FcPf9wPeBd4PH94J5aSk/J8rdU8ew6MNq3q88FHYcEZGUsJ50dk95eblXVFSE8tm1Dc1M+cECrhg3kJ/eUR5KBhGR02VmS909qX+09E3vLlKUl81dU8bwp9V7+HB3bdhxRES6nAqjC901pZSCnAgP6liGiPRAKowu1C8/hzuuKOUPK3eyqfpw2HFERLqUCqOLfWXqGHKiWTy0aGPYUUREupQKo4sNKszl9smjmPfeDrbvrw87johIl1FhpMCsaWOJmPHIqxpliEjPocJIgaF9+3Br+Qh+U1HJ7kMNYccREekSKowUuffqcbS6M3fxprCjiIh0CRVGiowckM8tk4bz9Dtb2Xu4Mew4IiJnTYWRQvfNGEdjSxuPvb457CgiImdNhZFC40oKufHCYfxiyRYO1jeFHUdE5KyoMFLs/hnjqGtq5YklW8KOIiJyVlQYKfaRc4q5duIQfv7GFmobmsOOIyJyxlQY3WD2zDIOHWnml29tCzuKiMgZU2F0gwtH9OPqc0t49LVNHGlqDTuOiMgZUWF0kwdmlrGvroln3tEoQ0Qykwqjm5SXDuDysQP46eKNNDRrlCEimSeV9/R+3MyqzGxV3Lxfx92udYuZLU+w7RYzez9YL5xb6KXAAzPHs6emkd8urQw7iojIaUvlCOMJ4Lr4Ge7+F+4+yd0nAc8Cz3Ww/Yxg3R5zv9Mrxw3k4lH9eHjRRppb28KOIyJyWlJWGO6+GNjf3jIzM+BzwDOp+vx0ZGY8MLOMHQeP8Lv3doQdR0TktIR1DGMqsMfd1ydY7sDLZrbUzGZ1Y66UmzFhMBOHFvPQoo20tnnYcUREkhZWYdxOx6OLq9z9EuB64H4zm5ZoRTObZWYVZlZRXV3d1Tm73NFRxua9dbzw/q6w44iIJK3bC8PMosBngV8nWsfddwTPVcA8YHIH685193J3Ly8pKenquCnxyfPPoWxwIQ8u2ECbRhkikiHCGGFcA3zg7u2eKmRmBWZWdHQauBZY1d66mSory5g9o4wP99Qyf+2esOOIiCQllafVPgO8CUwws0ozuztYdBsn7Y4ys2Fm9mLwcgjwupmtAN4BXnD3l1KVMyw3XjiU0QPzmbNgA+4aZYhI+oum6o3d/fYE87/UzrydwA3B9CbgolTlShfRSBb3TR/H/372fV5dV830CYPDjiQi0iF90ztEn7l4BMP79eEnGmWISAZQYYQoJ5rFV68ey9KtB3hrU7tfWRERSRsqjJD9eflISopymbMw0VdSRETSgwojZHnZEe6ZNpY3Nuxj6dYDYccREUlIhZEGPn/ZKPrnZzNngUYZIpK+VBhpID8nylemjmXhh9Ws2nEo7DgiIu1SYaSJO64YTVFelDkLNoQdRUSkXSqMNFGcl81dV5by0urdrNtTG3YcEZFTqDDSyF1TxpCfE+HBhRpliEj6UWGkkf4FOdxx+Wj+e8VONu+tCzuOiMgJVBhp5u6pY8iOZPHwIo0yRCS9qDDSzOCiPG6fPIrnlu2g8kB92HFERI5RYaShWdPGYgY/fXVT2FFERI5RYaShYf36cOulI/h1xXb21DSEHUdEBFBhpK17ry6jtc352WKNMkQkPagw0tSogfncfNEwnnp7G/sON4YdR0REhZHO7ptRRkNLK4+/sTnsKCIiKb1F6+NmVmVmq+LmfdfMdpjZ8uBxQ4JtrzOzD81sg5l9M1UZ013Z4EJuuGAoTy7ZyqH65rDjiEgvl8oRxhPAde3M/zd3nxQ8Xjx5oZlFgAeB64GJwO1mNjGFOdPa7BllHG5s4YklW8KOIiK9XMoKw90XA2dyG7nJwAZ33+TuTcCvgJu7NFwGOW9oMdecN4TH39jM4caWsOOISC8WxjGM2Wa2Mthl1b+d5cOB7XGvK4N5vdbsmWUcOtLML9/aGnYUEenFurswHgbGAZOAXcC/nu0bmtksM6sws4rq6uqzfbu0NGlkP6aOH8Sjr23iSFNr2HFEpJfq1sJw9z3u3urubcDPiO1+OtkOYGTc6xHBvETvOdfdy929vKSkpGsDp5EHZo5n7+EmfvXutrCjiEgv1a2FYWZD415+BljVzmrvAuPNbIyZ5QC3Ac93R750NnnMACaPGcBPX91EY4tGGSLS/ZIqDDMrMLOsYPpcM7vJzLI72eYZ4E1ggplVmtndwD+b2ftmthKYAfx1sO4wM3sRwN1bgNnAn4C1wH+5++oz/Pl6lAdmlrG7poFnlyYccImIpIy5e+crmS0FpgL9gTeIjQKa3P0LqY13esrLy72ioiLsGCnj7tzy0BL2HW5k4d9NJzui712KyNkxs6XuXp7Musn+i2PuXg98FnjI3f8cOP9MA8qZMTMemFFG5YEjPL98Z9hxRKSXSbowzOwK4AvAC8G8SGoiSUc+ft5gzhtazIOLNtDa1vnoUESkqyRbGF8DvgXMc/fVZjYWWJi6WJKImTF7Rhmbquv446pdYccRkV4k2cIY4u43ufs/Abj7JuC11MWSjlz30XMYV1LAnAUbaNMoQ0S6SbKF8a0k50k3iGQZ988o44PdtbzyQVXYcUSkl4h2tNDMrgduAIab2Y/jFhUDurBRiG66aBj/9j/rmLNgPdecNxgzCzuSiPRwnY0wdgIVQAOwNO7xPPDJ1EaTjkQjWdw3vYwVlYd4bf3esOOISC/QYWG4+wp3fxIoc/cng+nniV1N9kC3JJSEPnvJcIb2zWPOgg1hRxGRXiDZYxjzzazYzAYAy4Cfmdm/pTCXJCE3GuGeaWN5Z8t+3t60L+w4ItLDJVsYfd29htgX937h7pcBH09dLEnWbZNHMagwhzkLNcoQkdRKtjCiwYUDPwf8IYV55DTlZUf4X1PH8tr6vby3TXsJRSR1ki2M7xG7GOBGd383+OLe+tTFktPxhctH0y8/mwc1yhCRFEqqMNz9N+5+obvfG7ze5O5/ltpokqzC3ChfnjKG/1lbxeqdh8KOIyI9VLKXNx9hZvPMrCp4PGtmI1IdTpJ355WlFOVGeWjhxrCjiEgPlewuqZ8TO512WPD472CepIm+fbL5yytH8+KqXWyoOhx2HBHpgZItjBJ3/7m7twSPJ4Ceez/UDPXlKWPIi0Z4aJGOZYhI10u2MPaZ2RfNLBI8vgjoxP80M7Awl89fNorfL9/Jtn31YccRkR4m2cL4MrFTancDu4BbgS+lKJOchVnTxhIx4+FXdSxDRLrW6ZxWe6e7l7j7YGIF8g8dbWBmjwcHyFfFzfuhmX1gZiuDg+j9Emy7Jbj393Iz67n3XE2BIcV5fO5jI/jt0u3sOnQk7Dgi0oMkWxgXxl87yt33Axd3ss0TwHUnzZsPfNTdLwTW0fEl0me4+6Rk7zUrx90zbRzu8NNXN4UdRUR6kGQLI8vM+h99EVxTqsNLo7v7YmD/SfNedvejl0V/C9CpuSkwckA+n7l4OM+8s43q2saw44hID5FsYfwr8KaZfd/Mvg8sAf75LD/7y8AfEyxz4GUzW2pmszp6EzObZWYVZlZRXV19lpF6jnunj6O5tY1HX9coQ0S6RrLf9P4FsQsP7gken3X3/zzTDzWzbxO7AdNTCVa5yt0vAa4H7jezaR1km+vu5e5eXlKiM32PGltSyI0XDuOXb27lQF1T2HFEpAdIdoSBu69x9znBY82ZfqCZfQm4EfiCu7d7Q2p33xE8VwHzgMln+nm92f0zyqhrauXnS7aEHUVEeoCkC6MrmNl1wDeAm9y93S8KmFmBmRUdnQauBVa1t650bMI5RXzy/CE88cZmahqaw44jIhkuZYVhZs8AbwITzKzSzO4G5gBFxG7ItNzMHgnWHWZmLwabDgFeN7MVwDvAC+7+Uqpy9nSzZ4ynpqGF/3xza9hRRCTDdXim09lw99vbmf1YgnV3AjcE05uAi1KVq7e5YERfpk8o4bHXN3PXlFLyc1L2P7mI9HDduktKwvHAzDL21zXx9Nvbwo4iIhlMhdELXDp6AFeMHcjcxZtoaG4NO46IZCgVRi/xwMwyqmob+c3SyrCjiEiGUmH0EleMG8glo/rxyKKNNLe2hR1HRDKQCqOXMDMemDmeHQePMO+9HWHHEZEMpMLoRaZPKOH8YcU8tHADrW3tfmdSRCQhFUYvEhtllLFlXz1/WLkz7DgikmFUGL3MtRPPYfzgQh5cuIE2jTJE5DSoMHqZrCxj9swy1u05zMtr9oQdR0QyiAqjF/rUBUMpHZjPnIXrSXD9RxGRU6gweqFoJIv7ppexakcNi9bpHiIikhwVRi91y8XDGdY3j5+8olGGiCRHhdFL5USz+Or0cSzbdpA3N+0LO46IZAAVRi/2ufKRlBTlMmfBhrCjiEgGUGH0YnnZEWZNHcuSjftYuvVA2HFEJM2pMHq5z182iv752Ty4UKMMEemYCqOXK8iNcvdVY1jwQRWrdhwKO46IpLGUFoaZPW5mVWa2Km7eADObb2brg+f+Cba9M1hnvZndmcqcvd1fXllKUV5UowwR6VCqRxhPANedNO+bwCvuPh54JXh9AjMbAHwHuAyYDHwnUbHI2SvOy+ZLV5byx1W7WbenNuw4IpKmUloY7r4Y2H/S7JuBJ4PpJ4Fb2tn0k8B8d9/v7geA+ZxaPNKF7poyhvycCA9plCEiCYRxDGOIu+8KpncDQ9pZZziwPe51ZTBPUmRAQQ5fvHw0z6/YyZa9dWHHEZE0FOpBb499xfisvmZsZrPMrMLMKqqrdZmLs/GVqWOIRrJ4eNHGsKOISBoKozD2mNlQgOC5qp11dgAj416PCOadwt3nunu5u5eXlJR0edjeZHBRHrd/bCTPLqtkx8EjYccRkTQTRmE8Dxw96+lO4PftrPMn4Foz6x8c7L42mCcpNuvqcZjBT1/VKENETpTq02qfAd4EJphZpZndDfwA+ISZrQeuCV5jZuVm9iiAu+8Hvg+8Gzy+F8yTFBverw9/dskIfvXudqpqGsKOIyJpxHrSlUrLy8u9oqIi7BgZb+u+Omb8yyLuvmoM3/7UxLDjiEgKmdlSdy9PZl1901tOMXpgATdPGs4v39rG/rqmsOOISJpQYUi77ps+joaWVh5/fXPYUUQkTagwpF3jhxRx/UfP4cklWzh0pDnsOCKSBlQYktD9M8qobWzhF0u2hB1FRNKACkMSOn9YXz7+kcE89sZm6hpbwo4jIiFTYUiH7p9ZxsH6Zp56e2vYUUQkZCoM6dAlo/pzVdkg5i7eTENza9hxRCREKgzp1OyZZew93Miv393e+coi0mOpMKRTl40ZwMdK+/PIqxtpamkLO46IhESFIZ0yM+6fUcauQw08t6wy7DgiEhIVhiTl6nNLuGB4Xx5atJGWVo0yRHojFYYkxcyYPbOMbfvr+e+VO8OOIyIhUGFI0j5x3hAmDClizoINtLX1nItWikhyVBiStKws4/6ZZWysruOl1bvDjiMi3UyFIaflUxcMZcygAuYs2EBPujS+iHROhSGnJZJl3Dd9HGt21bDww/burisiPZUKQ07bLRcPZ3i/Pvz4FY0yRHqTbi8MM5tgZsvjHjVm9vWT1pluZofi1vm/3Z1TEsuOZHHv9HEs336QJRv3hR1HRLpJtxeGu3/o7pPcfRJwKVAPzGtn1deOrufu3+velNKZWy8dwZDiXH6yYH3YUUSkm4S9S+rjwEZ316VQM0xedoRZ08bx1qb9vLtlf9hxRKQbhF0YtwHPJFh2hZmtMLM/mtn53RlKknP75JEMLMhhzoINYUcRkW4QWmGYWQ5wE/CbdhYvA0a7+0XAT4DfdfA+s8yswswqqqurUxNW2pWfE+XuqWN4dV01KysPhh1HRFIszBHG9cAyd99z8gJ3r3H3w8H0i0C2mQ1q703cfa67l7t7eUlJSWoTyynuuHw0xXlRjTJEeoEwC+N2EuyOMrNzzMyC6cnEcup0nDRUlJfNXVPG8PKaPXywuybsOCKSQqEUhpkVAJ8Anoub91Uz+2rw8lZglZmtAH4M3OY64T9t3TWllIKcCA8u3Bh2FBFJoVAKw93r3H2gux+Km/eIuz8STM9x9/Pd/SJ3v9zdl4SRU5LTLz+HO64o5Q8rd7Kx+nDYcUQkRcI+S0p6iK9MHUNuNIuHF2mUIdJTqTCkSwwqzOX2yaOY994Otu+vDzuOiKSACkO6zKxpY4mY8cirGmWI9EQqDOkyQ/v24dbyEfymopLdhxrCjiMiXUyFIV3q3qvH0erO3MWbwo4iIl1MhSFdauSAfG6ZNJyn39nK3sONYccRkS6kwpAud9+McTS2tPHY65vDjiIiXUiFIV1uXEkhN1wwlF8s2cLB+qaw44hIF1FhSErMnlFGXVMrTyzZEnYUEekiKgxJifOGFnPNeUP4+RtbqG1oDjuOiHSBaNgBpOeaPbOMWx58g8v+8RUGFOTQPz+HfvnZ9M/PoX9+Nv2C5/4FOceng3UKc6ME158UkTShwpCUmTSyHz+5/WKWbTvAwfpmDtQ3caC+mW376zlQ10RNQ0vCbbMjRt8+xwvleJmcWCxHl/XLz6Ffn2yiEQ2aRVJFhSEp9emLhvHpi4a1u6yltY1DR44XyYG6phOK5WB907HpzXvrWFZ/kIP1TTS3Jr5wcVFe9JQRTOw5h/4F2acWTn4O+TkRjWZEkqDCkNBEI1kMLMxlYGFu0tu4O3VNrSeVy/Hp+MI5UN/Epr2HOVjXTG1j4tFMTiTrWHkU94kSzcoiGjGyI1lEs4LniBHNyiI7YidNx9Y5vs2py07dPovsrGB5xMiO2zaSdeI2R5cdnc7KUrFJeFQYklHMjMLcKIW5UUYOSH675tY2Dh4btTSzv67p2HT8SOZwQwstbW0caXZa2tpoaXWaW9toafMTpptbY8ta2to6HPF0tSzjhMLpkx0hPydCn5wIBTlR+uTEXufnRIPn49MdLcvPiZCfG6VPdoSISkkSUGFIr5AdyaKkKJeSouRHM8lyd1rb/KQiObVwTi6ZllanOVinpbWN5rbY89H5rW0erHfi9s1x79vQ3Ep909FHCwfrm9hxsJUjwev6plYaW9pO6+fJjWZREJRHe6UTX0wnrJcbJT++wOKWFeRGyY1maddfhlNhiJwls2A3VQTysiNhxzlFS2sbR5pjJVIXFMnR6SNBqdTHFUxsWfx0bL3dNc0nLKtvaqW1LfnRVXbEKCnMPVbcJUV5x6YHH50XLE/H36OoMER6vGgki6JIFkV52V36vu5OU2tbMJrprHRaqDnSQnVtI9WHG9lxsIHl2w+yr66J9m6+XJwXDYqknVKJm9+vT7aO63Sj0ArDzLYAtUAr0OLu5SctN+A/gBuAeuBL7r6su3OKSPvMjNxohNxohH75Z/YeLa1t7K9roqq2MVYmtY1U1TYcK5aqmkZWVB6kqqaRI82tp2yfHTEGFR4fnQwuzo0bxeSdUDQatZy9sEcYM9x9b4Jl1wPjg8dlwMPBs4j0ENFIFoOL8xhcnNfpuocbW04tldrGY2Wz61ADKyoPsa+usd1RS9GxUUtQJicVzNHp/vk5GrUkEHZhdORm4Bfu7sBbZtbPzIa6+66wg4lI9zt6dtyYQQUdrtfS2sb++iaqamKjlOraEx9VtQ28X3mQqtpG6ptOHbVEs2KjlnGDC5g4tJjzgkfZ4EKye/kXQ8MsDAdeNjMHfuruc09aPhzYHve6Mph3QmGY2SxgFsCoUaNSl1ZEMkI0ksXgojwGF3U+aqkLRi3Hd4k1UFXbyO6aBtbvOcyTb26lKTjLLCeSRdngQs4bWszEYcWcN7SIiUOL6Zefk+ofKW2EWRhXufsOMxsMzDezD9x98em+SVA0cwHKy8u774R4Ecl4BblRCnKjlCYYtbS0trF5bx1rdtWwZlcNa3fVsnh9Nc8uqzy2ztC+eSeMRCYOK2b0gPweuVsrtMJw9x3Bc5WZzQMmA/GFsQMYGfd6RDBPRKRbRCNZjB9SxPghRdw8afix+dW1jazdVcPaY0VSw6J11cdOM87PiTDhnKITiuQj5xRRkJvORwE6F0p6MysAsty9Npi+FvjeSas9D8w2s18RO9h9SMcvRCQdxM7CKmHauSXH5jU0t7J+z+FjJbJmVw3Pr9jJU29vA8AMSgcWHNuVdbRIhvbNy5gvNIZVd0OAecEvKQo87e4vmdlXAdz9EeBFYqfUbiB2Wu1dIWUVEelUXnaEC0b05YIRfY/Nc3d2HDzCmp2x3Vlrd9WwemcNL76/+9g6/fKzOe+c4hOOjYwfXERONP0OsJu3d/5ZhiovL/eKioqwY4iIdKi2oZkPd9fGjUZq+XB3DQ3NsQPs0SyjbHDhsZFIrEiKGVDQ9QfYzWzpyd+DSySzd6iJiGSgorxsyksHUF56/AqarW3O5r11x46NrN1Vwxsb9/Lce8cP3Q4pzj3lAHvpwIJuu2CkCkNEJA1EglFF2eDCE+4hs7+uKTYS2Xn8IPtr6/fSEhxgz8vO4oLhffmve65I+bEQFYaISBobUJDDlLJBTCkbdGxeY0srG6oOHzsuUtfY0i0HzlUYIiIZJjca4fxhfTl/WN/OV+5C6XcYXkRE0pIKQ0REkqLCEBGRpKgwREQkKSoMERFJigpDRESSosIQEZGkqDBERCQpPerig2ZWDWw9w80HAYnuL55uMikrZFbeTMoKmZU3k7JCZuU9m6yj3b2k89V6WGGcDTOrSPaKjWHLpKyQWXkzKStkVt5MygqZlbe7smqXlIiIJEWFISIiSVFhHDc37ACnIZOyQmblzaSskFl6XFaQAAAHXElEQVR5MykrZFbebsmqYxgiIpIUjTBERCQpvb4wzOw6M/vQzDaY2TfDztMRM3vczKrMbFXYWTpjZiPNbKGZrTGz1Wb2tbAzdcTM8szsHTNbEeT9h7AzdcbMImb2npn9IewsnTGzLWb2vpktN7OKsPN0xMz6mdlvzewDM1trZleEnSkRM5sQ/E6PPmrM7Osp+7zevEvKzCLAOuATQCXwLnC7u68JNVgCZjYNOAz8wt0/GnaejpjZUGCouy8zsyJgKXBLGv9uDShw98Nmlg28DnzN3d8KOVpCZvY3QDlQ7O43hp2nI2a2BSh397T/XoOZPQm85u6PmlkOkO/uB8PO1Zng37MdwGXufqbfR+tQbx9hTAY2uPsmd28CfgXcHHKmhNx9MbA/7BzJcPdd7r4smK4F1gLDw02VmMccDl5mB4+0/WvKzEYAnwIeDTtLT2JmfYFpwGMA7t6UCWUR+DiwMVVlASqM4cD2uNeVpPE/apnKzEqBi4G3w03SsWAXz3KgCpjv7umc99+BbwBtYQdJkgMvm9lSM5sVdpgOjAGqgZ8Hu/seNbOCsEMl6TbgmVR+QG8vDEkxMysEngW+7u41YefpiLu3uvskYAQw2czScrefmd0IVLn70rCznIar3P0S4Hrg/mD3ajqKApcAD7v7xUAdkNbHNgGCXWc3Ab9J5ef09sLYAYyMez0imCddIDgW8CzwlLs/F3aeZAW7IBYC14WdJYEpwE3BcYFfATPN7JfhRuqYu+8InquAecR2B6ejSqAybnT5W2IFku6uB5a5+55UfkhvL4x3gfFmNiZo6NuA50PO1CMEB5EfA9a6+4/CztMZMysxs37BdB9iJ0J8EG6q9rn7t9x9hLuXEvtvdoG7fzHkWAmZWUFw4gPB7p1rgbQ808/ddwPbzWxCMOvjQFqeqHGS20nx7iiIDb96LXdvMbPZwJ+ACPC4u68OOVZCZvYMMB0YZGaVwHfc/bFwUyU0BbgDeD84LgDw9+7+YoiZOjIUeDI40yQL+C93T/vTVTPEEGBe7G8IosDT7v5SuJE69ADwVPBH5CbgrpDzdCgo4U8A96T8s3rzabUiIpK83r5LSkREkqTCEBGRpKgwREQkKSoMERFJigpDRESSosKQjGBmS4LnUjP7fBe/99+391ld/BlDzexlM5ueqqvLmtl3zezvzvI9tpjZoK7KJD2LCkMygrtfGUyWAqdVGGbW2feNTiiMuM/qStcR+75Pt0vi5xdJigpDMoKZHb2S7A+AqcG1//86uGDgD83sXTNbaWb3BOtPN7PXzOx5gm/qmtnvgovfrT56ATwz+wHQJ3i/p+I/y2J+aGargns5/EXcey+Ku2fCU8E32zGzHwT3AFlpZv8S9yNcB/wxmC5MsO2xv+7NrNzMFgXT37XYvVAWmdkmM/uruN/Lt81snZm9DkyIm7/IzP7dYvee+FrwTfZng9/Tu2Y2JVhvYDDyWW1mjwJHsxSY2QsWuz/IqqM/u/Ry7q6HHmn/AA4Hz9OBP8TNnwX8n2A6F6ggdsXR6cQuHDcmbt0BwXMfYpemGBj/3u181p8B84ldBWAIsI3YN8KnA4eIXXssC3gTuAoYCHzI8S/E9gueI8DyuPynbBss2wIMCqbLgUXB9HeBJcHPNwjYR+zy65cC7wP5QDGwAfi7YJtFwENxP9PTcZ8zitglWwB+DPzfYPpTxK4qOyj42X8Wt33fsP8b0CP8h4aqkumuBS40s1uD132B8UAT8I67b45b96/M7DPB9MhgvX0dvPdVwDPu3grsMbNXgY8BNcF7VwIElz4pBd4CGoDHguMUR49VXMaJl3Zvb9vXO/k5X3D3RqDRzKqIFdhUYJ671wfvdfJ10H4dN30NMDEYzAAUB1cSngZ8FsDdXzCzA8Hy94F/NbN/IlbQr3WST3oBFYZkOgMecPcTjg+Y2XRiI4z419cAV7h7fbC7J+8sPrcxbroViHrs2mSTiV2w7lZgNjCT2JVEX+po22C6heO7iU/OlmibjtTFTWcBl7t7Q/wKcQVyAndfZ2aXADcA/8/MXnH37yXxmdKD6RiGZJpaoCju9Z+Aey12KXXM7Fxr/4Y3fYEDQVl8BLg8blnz0e1P8hrwF8FxkhJif42/kyhY8Bd7X49dYPGvgYuCRR8H/ieJn20Lsd1MENsl1JnFwC1m1ie4GuynO1j3ZWIX1TuadVLce3w+mHc90D+YHgbUu/svgR+SGZf4lhTTCEMyzUqg1cxWAE8A/0Fsl86y4OBxNXBLO9u9BHzVzNYSO84Qf6/uucBKM1vm7l+Imz8PuAJYQWzf/jfcfXdQOO0pAn5vZnnERj5/ExRNg8duU9uZfyC2O+v7xI5BdMhj90v/dZCvitjl+hP5K+BBM1tJ7P/3i4GvBp/5jJmtJnacZFuw/gXAD82sDWgG7k0iv/RwulqtSAqZ2ReBEe7+g7CziJwtFYaIiCRFxzBERCQpKgwREUmKCkNERJKiwhARkaSoMEREJCkqDBERSYoKQ0REkvL/AQq4JIbEQNd4AAAAAElFTkSuQmCC\n", 947 | "text/plain": [ 948 | "
" 949 | ] 950 | }, 951 | "metadata": { 952 | "needs_background": "light" 953 | }, 954 | "output_type": "display_data" 955 | }, 956 | { 957 | "name": "stdout", 958 | "output_type": "stream", 959 | "text": [ 960 | "训练集的准确率为:0.9\n" 961 | ] 962 | } 963 | ], 964 | "source": [ 965 | "convNet = SimpleConvNet()\n", 966 | "#拿20张先做实验\n", 967 | "train_X = X_train[0:10]\n", 968 | "train_y = y_train[0:10]\n", 969 | "convNet.fit(train_X,train_y)" 970 | ] 971 | }, 972 | { 973 | "cell_type": "markdown", 974 | "metadata": {}, 975 | "source": [ 976 | "## 预测" 977 | ] 978 | }, 979 | { 980 | "cell_type": "code", 981 | "execution_count": 120, 982 | "metadata": {}, 983 | "outputs": [ 984 | { 985 | "name": "stdout", 986 | "output_type": "stream", 987 | "text": [ 988 | "训练的准确率为:0.9\n" 989 | ] 990 | } 991 | ], 992 | "source": [ 993 | "logits = convNet.predicate(X_train[0:10])\n", 994 | "m = 10\n", 995 | "accuracy = np.sum(np.argmax(logits,axis=1) == np.argmax(y_train[0:10],axis=1))/m\n", 996 | "print(\"训练的准确率为:%g\" %(accuracy))" 997 | ] 998 | }, 999 | { 1000 | "cell_type": "code", 1001 | "execution_count": 121, 1002 | "metadata": {}, 1003 | "outputs": [ 1004 | { 1005 | "name": "stdout", 1006 | "output_type": "stream", 1007 | "text": [ 1008 | "y is:5\n", 1009 | "your predicate result is :5\n" 1010 | ] 1011 | }, 1012 | { 1013 | "data": { 1014 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADgdJREFUeJzt3X9sXfV5x/HPs9D8QRoIXjUTpWFpIhQUIuZOJkwoGkXM5YeCggGhWkLKRBT3j1ii0hQNZX8MNAVFg2RqBKrsqqHJ1KWZBCghqpp0CZBOTBEmhF9mKQylqi2TFAWTH/zIHD/74x53Lvh+r3Pvufdc+3m/JMv3nuecex4d5ZPz8/pr7i4A8fxJ0Q0AKAbhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q1GWNXJmZ8TghUGfublOZr6Y9v5ndYWbHzex9M3ukls8C0FhW7bP9ZjZL0m8kdUgalPSqpC53H0gsw54fqLNG7PlXSHrf3T9w9wuSfi5pdQ2fB6CBagn/Akm/m/B+MJv2R8ys28z6zay/hnUByFndL/i5e5+kPonDfqCZ1LLnH5K0cML7b2bTAEwDtYT/VUnXmtm3zGy2pO9J2ptPWwDqrerDfncfNbMeSfslzZK03d3fya0zAHVV9a2+qlbGOT9Qdw15yAfA9EX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFUP0S1JZnZC0llJFyWNunt7Hk0hP7NmzUrWr7zyyrquv6enp2zt8ssvTy67dOnSZH39+vXJ+pNPPlm21tXVlVz2888/T9Y3b96crD/22GPJejOoKfyZW939oxw+B0ADcdgPBFVr+F3SATN7zcy682gIQGPUeti/0t2HzOzPJP3KzP7b3Q9PnCH7T4H/GIAmU9Oe392Hst+nJD0vacUk8/S5ezsXA4HmUnX4zWyOmc0dfy3pu5LezqsxAPVVy2F/q6TnzWz8c/7N3X+ZS1cA6q7q8Lv7B5L+IsdeZqxrrrkmWZ89e3ayfvPNNyfrK1euLFubN29ectn77rsvWS/S4OBgsr5t27ZkvbOzs2zt7NmzyWXfeOONZP3ll19O1qcDbvUBQRF+ICjCDwRF+IGgCD8QFOEHgjJ3b9zKzBq3sgZqa2tL1g8dOpSs1/trtc1qbGwsWX/ooYeS9XPnzlW97uHh4WT9448/TtaPHz9e9brrzd1tKvOx5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoLjPn4OWlpZk/ciRI8n64sWL82wnV5V6HxkZSdZvvfXWsrULFy4kl436/EOtuM8PIInwA0ERfiAowg8ERfiBoAg/EBThB4LKY5Te8E6fPp2sb9iwIVlftWpVsv76668n65X+hHXKsWPHkvWOjo5k/fz588n69ddfX7b28MMPJ5dFfbHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKn6f38y2S1ol6ZS7L8+mtUjaLWmRpBOSHnD39B8618z9Pn+trrjiimS90nDSvb29ZWtr165NLvvggw8m67t27UrW0Xzy/D7/TyXd8aVpj0g66O7XSjqYvQcwjVQMv7sflvTlR9hWS9qRvd4h6Z6c+wJQZ9We87e6+/h4Rx9Kas2pHwANUvOz/e7uqXN5M+uW1F3regDkq9o9/0kzmy9J2e9T5WZ09z53b3f39irXBaAOqg3/XklrstdrJO3Jpx0AjVIx/Ga2S9J/SVpqZoNmtlbSZkkdZvaepL/J3gOYRiqe87t7V5nSbTn3EtaZM2dqWv6TTz6petl169Yl67t3707Wx8bGql43isUTfkBQhB8IivADQRF+ICjCDwRF+IGgGKJ7BpgzZ07Z2gsvvJBc9pZbbknW77zzzmT9wIEDyToajyG6ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQ3Oef4ZYsWZKsHz16NFkfGRlJ1l988cVkvb+/v2zt6aefTi7byH+bMwn3+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUNznD66zszNZf+aZZ5L1uXPnVr3ujRs3Jus7d+5M1oeHh5P1qLjPDyCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCqnif38y2S1ol6ZS7L8+mPSppnaTfZ7NtdPdfVFwZ9/mnneXLlyfrW7duTdZvu636kdx7e3uT9U2bNiXrQ0NDVa97OsvzPv9PJd0xyfR/cfe27Kdi8AE0l4rhd/fDkk43oBcADVTLOX+Pmb1pZtvN7KrcOgLQENWG/0eSlkhqkzQsaUu5Gc2s28z6zaz8H3MD0HBVhd/dT7r7RXcfk/RjSSsS8/a5e7u7t1fbJID8VRV+M5s/4W2npLfzaQdAo1xWaQYz2yXpO5K+YWaDkv5R0nfMrE2SSzoh6ft17BFAHfB9ftRk3rx5yfrdd99dtlbpbwWYpW9XHzp0KFnv6OhI1mcqvs8PIInwA0ERfiAowg8ERfiBoAg/EBS3+lCYL774Ilm/7LL0Yyijo6PJ+u2331629tJLLyWXnc641QcgifADQRF+ICjCDwRF+IGgCD8QFOEHgqr4fX7EdsMNNyTr999/f7J+4403lq1Vuo9fycDAQLJ++PDhmj5/pmPPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZ9/hlu6dGmy3tPTk6zfe++9yfrVV199yT1N1cWLF5P14eHhZH1sbCzPdmYc9vxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFTF+/xmtlDSTkmtklxSn7v/0MxaJO2WtEjSCUkPuPvH9Ws1rkr30ru6usrWKt3HX7RoUTUt5aK/vz9Z37RpU7K+d+/ePNsJZyp7/lFJf+fuyyT9laT1ZrZM0iOSDrr7tZIOZu8BTBMVw+/uw+5+NHt9VtK7khZIWi1pRzbbDkn31KtJAPm7pHN+M1sk6duSjkhqdffx5ys/VOm0AMA0MeVn+83s65KelfQDdz9j9v/Dgbm7lxuHz8y6JXXX2iiAfE1pz29mX1Mp+D9z9+eyySfNbH5Wny/p1GTLunufu7e7e3seDQPIR8XwW2kX/xNJ77r71gmlvZLWZK/XSNqTf3sA6qXiEN1mtlLSryW9JWn8O5IbVTrv/3dJ10j6rUq3+k5X+KyQQ3S3tqYvhyxbtixZf+qpp5L166677pJ7ysuRI0eS9SeeeKJsbc+e9P6Cr+RWZ6pDdFc853f3/5RU7sNuu5SmADQPnvADgiL8QFCEHwiK8ANBEX4gKMIPBMWf7p6ilpaWsrXe3t7ksm1tbcn64sWLq+opD6+88kqyvmXLlmR9//79yfpnn312yT2hMdjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQYe7z33TTTcn6hg0bkvUVK1aUrS1YsKCqnvLy6aeflq1t27Ytuezjjz+erJ8/f76qntD82PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBh7vN3dnbWVK/FwMBAsr5v375kfXR0NFlPfed+ZGQkuSziYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0GZu6dnMFsoaaekVkkuqc/df2hmj0paJ+n32awb3f0XFT4rvTIANXN3m8p8Uwn/fEnz3f2omc2V9JqkeyQ9IOmcuz851aYIP1B/Uw1/xSf83H1Y0nD2+qyZvSup2D9dA6Bml3TOb2aLJH1b0pFsUo+ZvWlm283sqjLLdJtZv5n119QpgFxVPOz/w4xmX5f0sqRN7v6cmbVK+kil6wD/pNKpwUMVPoPDfqDOcjvnlyQz+5qkfZL2u/vWSeqLJO1z9+UVPofwA3U21fBXPOw3M5P0E0nvTgx+diFwXKekty+1SQDFmcrV/pWSfi3pLUlj2eSNkroktal02H9C0vezi4Opz2LPD9RZrof9eSH8QP3ldtgPYGYi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBNXoIbo/kvTbCe+/kU1rRs3aW7P2JdFbtfLs7c+nOmNDv8//lZWb9bt7e2ENJDRrb83al0Rv1SqqNw77gaAIPxBU0eHvK3j9Kc3aW7P2JdFbtQrprdBzfgDFKXrPD6AghYTfzO4ws+Nm9r6ZPVJED+WY2Qkze8vMjhU9xFg2DNopM3t7wrQWM/uVmb2X/Z50mLSCenvUzIaybXfMzO4qqLeFZvaimQ2Y2Ttm9nA2vdBtl+irkO3W8MN+M5sl6TeSOiQNSnpVUpe7DzS0kTLM7ISkdncv/J6wmf21pHOSdo6PhmRm/yzptLtvzv7jvMrd/75JentUlzhyc516Kzey9N+qwG2X54jXeShiz79C0vvu/oG7X5D0c0mrC+ij6bn7YUmnvzR5taQd2esdKv3jabgyvTUFdx9296PZ67OSxkeWLnTbJfoqRBHhXyDpdxPeD6q5hvx2SQfM7DUz6y66mUm0ThgZ6UNJrUU2M4mKIzc30pdGlm6abVfNiNd544LfV61097+UdKek9dnhbVPy0jlbM92u+ZGkJSoN4zYsaUuRzWQjSz8r6QfufmZirchtN0lfhWy3IsI/JGnhhPffzKY1BXcfyn6fkvS8SqcpzeTk+CCp2e9TBffzB+5+0t0vuvuYpB+rwG2XjSz9rKSfuftz2eTCt91kfRW13YoI/6uSrjWzb5nZbEnfk7S3gD6+wszmZBdiZGZzJH1XzTf68F5Ja7LXayTtKbCXP9IsIzeXG1laBW+7phvx2t0b/iPpLpWu+P+PpH8ooocyfS2W9Eb2807RvUnapdJh4P+qdG1kraQ/lXRQ0nuS/kNSSxP19q8qjeb8pkpBm19QbytVOqR/U9Kx7Oeuorddoq9CthtP+AFBccEPCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ/weCC5r/92q6mAAAAABJRU5ErkJggg==\n", 1015 | "text/plain": [ 1016 | "
" 1017 | ] 1018 | }, 1019 | "metadata": { 1020 | "needs_background": "light" 1021 | }, 1022 | "output_type": "display_data" 1023 | } 1024 | ], 1025 | "source": [ 1026 | "index = 0\n", 1027 | "plt.imshow(X_train[index].reshape((28,28)),cmap = plt.cm.gray)\n", 1028 | "print(\"y is:\"+str(np.argmax(y_train[index])))\n", 1029 | "print(\"your predicate result is :\"+str(np.argmax(logits[index])))" 1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "code", 1034 | "execution_count": 99, 1035 | "metadata": {}, 1036 | "outputs": [ 1037 | { 1038 | "name": "stdout", 1039 | "output_type": "stream", 1040 | "text": [ 1041 | "测试的准确率为:0.1031\n" 1042 | ] 1043 | } 1044 | ], 1045 | "source": [ 1046 | "logits = convNet.predicate(X_test)\n", 1047 | "m = X_test.shape[0]\n", 1048 | "accuracy = np.sum(np.argmax(logits,axis=1) == np.argmax(y_test,axis=1))/m\n", 1049 | "print(\"测试的准确率为:%g\" %(accuracy))" 1050 | ] 1051 | }, 1052 | { 1053 | "cell_type": "code", 1054 | "execution_count": null, 1055 | "metadata": {}, 1056 | "outputs": [], 1057 | "source": [] 1058 | }, 1059 | { 1060 | "cell_type": "code", 1061 | "execution_count": null, 1062 | "metadata": {}, 1063 | "outputs": [], 1064 | "source": [] 1065 | } 1066 | ], 1067 | "metadata": { 1068 | "kernelspec": { 1069 | "display_name": "Python 3", 1070 | "language": "python", 1071 | "name": "python3" 1072 | }, 1073 | "language_info": { 1074 | "codemirror_mode": { 1075 | "name": "ipython", 1076 | "version": 3 1077 | }, 1078 | "file_extension": ".py", 1079 | "mimetype": "text/x-python", 1080 | "name": "python", 1081 | "nbconvert_exporter": "python", 1082 | "pygments_lexer": "ipython3", 1083 | "version": "3.7.3" 1084 | }, 1085 | "toc": { 1086 | "base_numbering": 1, 1087 | "nav_menu": {}, 1088 | "number_sections": true, 1089 | "sideBar": false, 1090 | "skip_h1_title": false, 1091 | "title_cell": "Table of Contents", 1092 | "title_sidebar": "Contents", 1093 | "toc_cell": false, 1094 | "toc_position": { 1095 | "height": "681.328125px", 1096 | "left": "22px", 1097 | "top": "112.328125px", 1098 | "width": "165px" 1099 | }, 1100 | "toc_section_display": true, 1101 | "toc_window_display": true 1102 | }, 1103 | "varInspector": { 1104 | "cols": { 1105 | "lenName": 16, 1106 | "lenType": 16, 1107 | "lenVar": 40 1108 | }, 1109 | "kernels_config": { 1110 | "python": { 1111 | "delete_cmd_postfix": "", 1112 | "delete_cmd_prefix": "del ", 1113 | "library": "var_list.py", 1114 | "varRefreshCmd": "print(var_dic_list())" 1115 | }, 1116 | "r": { 1117 | "delete_cmd_postfix": ") ", 1118 | "delete_cmd_prefix": "rm(", 1119 | "library": "var_list.r", 1120 | "varRefreshCmd": "cat(var_dic_list()) " 1121 | } 1122 | }, 1123 | "types_to_exclude": [ 1124 | "module", 1125 | "function", 1126 | "builtin_function_or_method", 1127 | "instance", 1128 | "_Feature" 1129 | ], 1130 | "window_display": false 1131 | } 1132 | }, 1133 | "nbformat": 4, 1134 | "nbformat_minor": 2 1135 | } 1136 | --------------------------------------------------------------------------------