├── Chapter_3 ├── digitRecognizer.ipynb ├── linearRegression.ipynb ├── logistRegression │ ├── data.txt │ └── logisticRegression.ipynb └── polynomialRegression.ipynb ├── Chapter_4 ├── classRNNs.ipynb └── simpleCNN.ipynb ├── Chapter_5 ├── N_Gram.ipynb ├── rnnModule.ipynb └── sequencePrediction │ ├── FullConnection.ipynb │ ├── data.csv │ ├── gru.ipynb │ ├── lstm.ipynb │ ├── rnn.ipynb │ ├── seqInit.ipynb │ └── seqInit.py ├── Chapter_6 └── autoEncoder.ipynb └── README.md /Chapter_3/digitRecognizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch import nn, optim\n", 11 | "from torch.autograd import Variable" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# 最简单的层次网络\n", 21 | "class simpleNet(nn.Module) :\n", 22 | " def __init__(self, in_dim, hidden_dim, out_dim) :\n", 23 | " super().__init__()\n", 24 | " i, self.layer = 1, nn.Sequential()\n", 25 | " for h_dim in hidden_dim :\n", 26 | " self.layer.add_module('layer_{}'.format(i), nn.Linear(in_dim, h_dim))\n", 27 | " i, in_dim = i + 1, h_dim\n", 28 | " self.layer.add_module('layer_{}'.format(i), nn.Linear(in_dim, out_dim))\n", 29 | " self.layerNum = i\n", 30 | " \n", 31 | " def forward(self, x) :\n", 32 | " x = self.layer(x)\n", 33 | " return x" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# 添加激活函数,增加网络的非线性\n", 43 | "\n", 44 | "class Activation_Net(nn.Module) :\n", 45 | " def __init__(self, in_dim, hidden_dim, out_dim) :\n", 46 | " super().__init__()\n", 47 | " i, self.layer = 1, nn.Sequential()\n", 48 | " for h_dim in hidden_dim :\n", 49 | " self.layer.add_module('layer_{}'.format(i), nn.Sequential(nn.Linear(in_dim, h_dim), 
nn.ReLU(True)))\n", 50 | " i, in_dim = i + 1, h_dim\n", 51 | " self.layer.add_module('layer_{}'.format(i), nn.Sequential(nn.Linear(in_dim, out_dim)))\n", 52 | " self.layerNum = i\n", 53 | " def forward(self, x) :\n", 54 | " x = self.layer(x)\n", 55 | " return x" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# 添加批标准化的网络\n", 65 | "\n", 66 | "class Batch_net(nn.Module) :\n", 67 | " def __init__(self, in_dim, hidden_dim, out_dim) :\n", 68 | " super().__init__()\n", 69 | " i, self.layer = 1, nn.Sequential()\n", 70 | " for h_dim in hidden_dim :\n", 71 | " self.layer.add_module('layer_{}'.format(i), \n", 72 | " nn.Sequential(nn.Linear(in_dim, h_dim), nn.BatchNorm1d(h_dim), nn.ReLU(True)))\n", 73 | " i, in_dim = i + 1, h_dim\n", 74 | " self.layer.add_module('layer_{}'.format(i), nn.Sequential(nn.Linear(in_dim, out_dim)))\n", 75 | " self.layerNum = i\n", 76 | " def forward(self, x) :\n", 77 | " x = self.layer(x)\n", 78 | " return x" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# 确定超参数\n", 88 | "\n", 89 | "epoch_size = 700\n", 90 | "learning_rate = 1e-2\n", 91 | "num_epoches = 60" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "---\n", 99 | "这里使用的手写数字数据集是从 Kaggle 上下载的,[下载链接](https://www.kaggle.com/c/digit-recognizer/data)。\n", 100 | "\n", 101 | "请将其保存在本 notebook 所在目录下名为 *data* 的文件夹中(代码通过 `./data/train.csv` 等相对路径读取),文件结构如下 :\n", 102 | "```\n", 103 | "data\n", 104 | "├── sample_submission.csv\n", 105 | "├── test.csv\n", 106 | "└── train.csv\n", 107 | "\n", 108 | "0 directories, 3 files\n", 109 | "```\n", 110 | "---" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 6, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "42000 784\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "# 获得训练数据 - 
train.csv\n", 128 | "\n", 129 | "import csv\n", 130 | "with open('./data/train.csv') as f :\n", 131 | " lines = csv.reader(f)\n", 132 | " label, attr = [], []\n", 133 | " for line in lines :\n", 134 | " if lines.line_num == 1 :\n", 135 | " continue\n", 136 | " label.append(int(line[0]))\n", 137 | " attr.append([float(j) for j in line[1:]])\n", 138 | "print(len(label), len(attr[1]))\n", 139 | "\n", 140 | "# 将数据分为 60(epoches) * 700(rows) 的数据集\n", 141 | "epoches = []\n", 142 | "for i in range(0, len(label), epoch_size) :\n", 143 | " torch_attr = torch.FloatTensor(attr[i : i + epoch_size])\n", 144 | " torch_label = torch.LongTensor(label[i : i + epoch_size])\n", 145 | " epoches.append((torch_attr, torch_label))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# 模型实例化,Activation_Net, \n", 155 | "if torch.cuda.is_available() :\n", 156 | " net = Activation_Net(28 * 28, [300, 100], 10).cuda()\n", 157 | "else :\n", 158 | " net = Activation_Net(28 * 28, [300, 100], 10)\n" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "# 损失函数\n", 168 | "criterion = nn.CrossEntropyLoss()\n", 169 | "# 优化函数\n", 170 | "optimizer = optim.SGD(net.parameters(), lr = learning_rate)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 9, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "--- train time 1 ---\n", 183 | "average loss = 11.5832\n", 184 | "average correct number = 0.4225\n", 185 | "--- train time 2 ---\n", 186 | "average loss = 0.5315\n", 187 | "average correct number = 0.8423\n", 188 | "--- train time 3 ---\n", 189 | "average loss = 0.3385\n", 190 | "average correct number = 0.8999\n", 191 | "--- train time 4 ---\n", 192 | "average loss = 0.2668\n", 193 | "average correct number = 0.9208\n", 
194 | "--- train time 5 ---\n", 195 | "average loss = 0.2259\n", 196 | "average correct number = 0.9326\n", 197 | "--- train time 6 ---\n", 198 | "average loss = 0.2012\n", 199 | "average correct number = 0.9394\n", 200 | "--- train time 7 ---\n", 201 | "average loss = 0.1813\n", 202 | "average correct number = 0.9449\n", 203 | "--- train time 8 ---\n", 204 | "average loss = 0.1658\n", 205 | "average correct number = 0.9500\n", 206 | "--- train time 9 ---\n", 207 | "average loss = 0.1531\n", 208 | "average correct number = 0.9541\n", 209 | "--- train time 10 ---\n", 210 | "average loss = 0.1429\n", 211 | "average correct number = 0.9567\n", 212 | "--- train time 11 ---\n", 213 | "average loss = 0.1339\n", 214 | "average correct number = 0.9595\n", 215 | "--- train time 12 ---\n", 216 | "average loss = 0.1257\n", 217 | "average correct number = 0.9615\n", 218 | "--- train time 13 ---\n", 219 | "average loss = 0.1180\n", 220 | "average correct number = 0.9641\n", 221 | "--- train time 14 ---\n", 222 | "average loss = 0.1121\n", 223 | "average correct number = 0.9661\n", 224 | "--- train time 15 ---\n", 225 | "average loss = 0.1066\n", 226 | "average correct number = 0.9678\n", 227 | "--- train time 16 ---\n", 228 | "average loss = 0.1019\n", 229 | "average correct number = 0.9692\n", 230 | "--- train time 17 ---\n", 231 | "average loss = 0.0966\n", 232 | "average correct number = 0.9708\n", 233 | "--- train time 18 ---\n", 234 | "average loss = 0.0925\n", 235 | "average correct number = 0.9723\n", 236 | "--- train time 19 ---\n", 237 | "average loss = 0.0883\n", 238 | "average correct number = 0.9735\n", 239 | "--- train time 20 ---\n", 240 | "average loss = 0.0843\n", 241 | "average correct number = 0.9747\n", 242 | "--- train time 40 ---\n", 243 | "average loss = 0.0390\n", 244 | "average correct number = 0.9904\n", 245 | "--- train time 60 ---\n", 246 | "average loss = 0.0206\n", 247 | "average correct number = 0.9961\n", 248 | "--- train time 80 ---\n", 249 | 
"average loss = 0.0112\n", 250 | "average correct number = 0.9985\n", 251 | "--- train time 100 ---\n", 252 | "average loss = 0.0065\n", 253 | "average correct number = 0.9995\n", 254 | "--- train time 120 ---\n", 255 | "average loss = 0.0039\n", 256 | "average correct number = 0.9998\n", 257 | "--- train time 140 ---\n", 258 | "average loss = 0.0026\n", 259 | "average correct number = 0.9999\n", 260 | "--- train time 160 ---\n", 261 | "average loss = 0.0019\n", 262 | "average correct number = 1.0000\n", 263 | "--- train time 180 ---\n", 264 | "average loss = 0.0014\n", 265 | "average correct number = 1.0000\n", 266 | "--- train time 200 ---\n", 267 | "average loss = 0.0011\n", 268 | "average correct number = 1.0000\n", 269 | "--- train time 220 ---\n", 270 | "average loss = 0.0009\n", 271 | "average correct number = 1.0000\n", 272 | "--- train time 240 ---\n", 273 | "average loss = 0.0008\n", 274 | "average correct number = 1.0000\n", 275 | "--- train time 260 ---\n", 276 | "average loss = 0.0007\n", 277 | "average correct number = 1.0000\n", 278 | "--- train time 280 ---\n", 279 | "average loss = 0.0006\n", 280 | "average correct number = 1.0000\n", 281 | "--- train time 300 ---\n", 282 | "average loss = 0.0005\n", 283 | "average correct number = 1.0000\n" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "# 训练过程\n", 289 | "def train() :\n", 290 | " epoch_num, loss_sum, cort_num_sum = 0, 0.0, 0\n", 291 | " for epoch in epoches :\n", 292 | " epoch_num += 1\n", 293 | " if torch.cuda.is_available() :\n", 294 | " inputs = Variable(epoch[0]).cuda()\n", 295 | " target = Variable(epoch[1]).cuda()\n", 296 | " else :\n", 297 | " inputs = Variable(epoch[0])\n", 298 | " target = Variable(epoch[1])\n", 299 | " output = net(inputs)\n", 300 | " loss = criterion(output, target)\n", 301 | " # reset gradients\n", 302 | " optimizer.zero_grad()\n", 303 | " # backward pass\n", 304 | " loss.backward()\n", 305 | " # update parameters\n", 306 | " optimizer.step()\n", 307 | " \n", 308 | 
" # get training infomation\n", 309 | " loss_sum += loss.data[0]\n", 310 | " _, pred = torch.max(output.data, 1)\n", 311 | " \n", 312 | " #print(pred.shape)\n", 313 | " #print(epoch[1].shape)\n", 314 | " \n", 315 | " num_correct = torch.eq(pred, epoch[1]).sum()\n", 316 | " cort_num_sum += num_correct\n", 317 | " \n", 318 | " loss_avg = loss_sum / epoch_num\n", 319 | " cort_num_avg = cort_num_sum / epoch_num / epoch_size\n", 320 | " return loss_avg, cort_num_avg\n", 321 | "\n", 322 | "# 对所有数据跑300遍模型\n", 323 | "loss, correct = [], []\n", 324 | "training_time = 300\n", 325 | "for i in range(1, training_time + 1) :\n", 326 | " loss_avg, correct_num_avg = train()\n", 327 | " loss.append(loss_avg)\n", 328 | " if i< 20 or i % 20 == 0 :\n", 329 | " print('--- train time {} ---'.format(i))\n", 330 | " print('average loss = {:.4f}'.format(loss_avg))\n", 331 | " print('average correct number = {:.4f}'.format(correct_num_avg))\n", 332 | " correct.append(correct_num_avg)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 10, 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "image/png": "\n", 343 | "text/plain": [ 344 | "" 345 | ] 346 | }, 347 | "metadata": {}, 348 | "output_type": "display_data" 349 | }, 350 | { 351 | "data": { 352 | "image/png": "\n", 353 | "text/plain": [ 354 | "" 355 | ] 356 | }, 357 | "metadata": {}, 358 | "output_type": "display_data" 359 | } 360 | ], 361 | "source": [ 362 | "# 画图输出训练过程情况\n", 363 | "\n", 364 | "import numpy as np\n", 365 | "import matplotlib.pyplot as plt\n", 366 | "% matplotlib inline\n", 367 | "# 画训练过程中的损失值图像\n", 368 | "lx = np.array(range(len(loss)))\n", 369 | "ly = np.array(loss)\n", 370 | "plt.title('loss of training')\n", 371 | "plt.plot(lx, ly)\n", 372 | "plt.show()\n", 373 | "\n", 374 | "\n", 375 | "# 画训练过程中正确率变化\n", 376 | "cx = np.array(range(len(correct)))\n", 377 | "cy = np.array(correct)\n", 378 | "plt.title('correct rate of training')\n", 379 | "plt.plot(cx, cy)\n", 380 | 
"plt.show()" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 11, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "write done.\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "# 引入测试数据\n", 398 | "\n", 399 | "with open('./data/test.csv') as f :\n", 400 | " lines = csv.reader(f)\n", 401 | " test = []\n", 402 | " for line in lines :\n", 403 | " if lines.line_num == 1 :\n", 404 | " continue\n", 405 | " test.append([float(i) for i in line])\n", 406 | "test = torch.FloatTensor(test)\n", 407 | "net.eval()\n", 408 | "# volatile = True 表示前向传播不保留缓存\n", 409 | "predict = net(Variable(test, volatile=True))\n", 410 | "_, predict = torch.max(predict, 1)\n", 411 | "predict = predict.data.numpy()\n", 412 | "\n", 413 | "with open('./data/predict.csv', 'w') as f :\n", 414 | " writer = csv.writer(f)\n", 415 | " writer.writerow(['ImageId', 'Label'])\n", 416 | " for i in range(predict.shape[0]) :\n", 417 | " result = [i + 1, predict[i]]\n", 418 | " writer.writerow(result)\n", 419 | " print('write done.')" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [] 428 | } 429 | ], 430 | "metadata": { 431 | "kernelspec": { 432 | "display_name": "Python 3", 433 | "language": "python", 434 | "name": "python3" 435 | }, 436 | "language_info": { 437 | "codemirror_mode": { 438 | "name": "ipython", 439 | "version": 3 440 | }, 441 | "file_extension": ".py", 442 | "mimetype": "text/x-python", 443 | "name": "python", 444 | "nbconvert_exporter": "python", 445 | "pygments_lexer": "ipython3", 446 | "version": "3.6.4" 447 | } 448 | }, 449 | "nbformat": 4, 450 | "nbformat_minor": 2 451 | } 452 | -------------------------------------------------------------------------------- /Chapter_3/linearRegression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 
| { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch.nn import Linear\n", 11 | "from torch import nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "[[-7.21]\n", 24 | " [-6.6 ]\n", 25 | " [-6.07]\n", 26 | " [-5.78]\n", 27 | " [-4.06]\n", 28 | " [-0.83]\n", 29 | " [ 6.69]\n", 30 | " [ 7.44]\n", 31 | " [ 9.65]\n", 32 | " [ 9.84]] \n", 33 | " [[-9.88]\n", 34 | " [-8.79]\n", 35 | " [-5.52]\n", 36 | " [-4.67]\n", 37 | " [-3.83]\n", 38 | " [ 2.03]\n", 39 | " [ 5.18]\n", 40 | " [ 7.18]\n", 41 | " [ 8.15]\n", 42 | " [ 9.94]]\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "# 获得训练数据\n", 48 | "import numpy as np\n", 49 | "import random\n", 50 | "x = sorted([random.randint(-1000, 1000) * 0.01 for i in range(10)])\n", 51 | "y = sorted([random.randint(-1000, 1000) * 0.01 for i in range(10)])\n", 52 | "x_train = [[i] for i in x]\n", 53 | "y_train = [[i] for i in y]\n", 54 | "x_train = np.array(x_train)\n", 55 | "y_train = np.array(y_train)\n", 56 | "\n", 57 | "print(x_train, '\\n', y_train)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXwAAAEACAYAAACwB81wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAEAhJREFUeJzt3V+MXOV9xvHnQd6VJqmMKAxGIslMGposqWoRK6aRUlWzihZMpQqCogR60X+rCGQi9aKVAHHhbdUL6IVvUu1NGCEqweL0gjZUarCjeCJZvegKShcr2Lhqz/InwTkKBCXqSt7gXy92bA9mdu35c+ac3ff7kUaec87MeX86Ozz78u573nFECACw811TdgEAgMkg8AEgEQQ+ACSCwAeARBD4AJAIAh8AEjGWwLfdtn3W9krPvkO237L9cvdxYBxtAQCGM64e/lOS7uyz/3BE7Os+vj+mtgAAQxhL4EfECUnv9TnkcZwfADC6osfwH7L9iu0nbV9bcFsAgC0UGfiLkj4TEbdJekfS4QLbAgBcwa6iThwRec/mdyS90O91tlnMBwCGEBEDDZuPs4dv9YzZ276p59i9kk5u9saI4BGhQ4cOlV5DVR5cC64F12LrxzDG0sO3/ayklqTrbb8h6ZCkWdu3STovKZP0wDjaAgAMZyyBHxF/3Gf3U+M4NwBgPLjTtkJarVbZJVQG1+ISrsUlXIvReNixoLEVYEfZNQDAdmNbUeIfbQEAFUbgA0AiCHwASASBDwCJIPABIBEEPgAkgsAHgEQQ+ACQCAIfABJB4ANAIgh8AEgEgQ8AiSDwASARBD4AJILAB4BEEPgAkAgCHwASQeADQCIIfABIBIEPAIkg8AEgEQQ+ACSCwAeARBD4AJAIAh8AEkHgA0AixhL4ttu2z9pe6dl3ne2jtk/bftH2teNoCwAwnHH18J+SdOdl+x6R9IOI+JykH0p6dExtAQCGMJbAj4gTkt67bPfdkp7uPn9a0j3jaAtAWvI81/LysvI8L7uUba/IMfwbI+KsJEXEO5JuLLAtADvQ0tIRNRozmpt7UI3GjJaWjpRd0rbmiBjPieyGpBciYm93+92I+M2e4z+PiOv7vC/GVQOAnSPPczUaM1pbOy5pr6QV1WqzWl09pXq9XnZ5pbOtiPAg79lVVDGSztreExFnbd8k6WebvXBhYeHi81arpVarVWBZALaDLMs0Pd3U2tre7p69mppqKMuyJAO/0+mo0+mMdI5x9vCb2ujh/253+wlJ70bEE7YflnRdRDzS53308AF8BD38rQ3Twx/XtMxnJf27pM/afsP2n0t6XNKc7dOSvtLdBoCrUq/X1W4vqlab1e7d+1SrzardXiTsRzC2Hv7QBdDDB7CFPM+VZZmazSZh32OYHj6BDwDbUGlDOgCA6iPwASARBD4AJILAB4BEEPgAkAgCHwASQeADqBxWyCwGgQ+gUlghszjceAWgMlg/5+px4xWAbe3CCpkbYS/1rpCJ0RH4ACqj2Wzq3LlM0oWvx17R+vqqms1meUXtIAQ+gMpghcxiMYYPoHJYIfPKWC0TwNgRvtXEH20BjBVTJHcWevgA+mKKZLXRwwcwNkyR3HkIfAB9MUVy5yHwAfTFFMmdhzF8AFtilk41MS0TABLBH20BAJsi8AEgEQQ+ACSCwAeARBD4AJAIAh8AEkHgA0AidhXdgO1M0vuSzktaj4jbi24TwJVxQ1V6JtHDPy+pFRFfIOyBamDZ4zQVfqet7f+V9MWI+Pkmx7nTFpgglj3eGap6p21IetH2su1vTqA9AFtg2eN0FT6GL+nLEfFT23VJx2y/FhEnel+wsLBw8Xmr1VKr1ZpAWUCaPrzs8UYPn2WPq6/T6ajT6Yx0jokunmb7kKRfRsThnn0M6QATtrR0RPPzBzU11dD6+qra7UXdf/83yi4LA6jcapm2Pybpmoj4le2PSzoq6W8i4mjPawh8oATM0tneqhj4n5b0vDbG8XdJeiYiHr/sNQQ+AAyocoF/VQUQ+AAwsKrO0gEAVACBDwCJIPABIBEEPgAkgsAHgEQQ+ACQCAIfABJB4ANAIgh
8AEgEgQ8AiSDwASARBD4AJILAB4BEEPgAkAgCHxOR57mWl5eV53nZpQDJIvBRuKWlI2o0ZjQ396AajRktLR0puyQgSXwBCgqV57kajRmtrR3XhS/MrtVmtbp6iq/VA0bAF6CgcrIs0/R0UxthL0l7NTXVUJZl5RUFJIrAR6GazabOncskrXT3rGh9fVXNZrO8ooBEEfgoVL1eV7u9qFptVrt371OtNqt2e5HhHKAEjOFjIvI8V5ZlajabhD0wBsOM4RP4ALAN8UdbAMCmCHwASASBDwCJIPABIBEEPgAkgsAHgEQUHvi2D9g+Zft12w8X3R4AoL9C5+HbvkbS65K+IuknkpYl3RcRp3pewzx8ABhQFefh3y7pTESsRsS6pOck3V1wmwCAPooO/Jslvdmz/VZ3HwBgwnaVXYAkLSwsXHzearXUarVKqwUAqqjT6ajT6Yx0jqLH8L8kaSEiDnS3H5EUEfFEz2sYwweAAVVxDH9Z0i22G7anJd0n6XsFtwkA6KPQIZ2I+MD2tyQd1cYvl3ZEvFZkmwCA/lgeGQC2oSoO6QAAKoLAB4BEEPgAkAgCHwASQeADQCIIfABIBIEPAIkg8AEgEQQ+ACSCwAeARBD4AJAIAh8AEkHgA0AiCHwASASBDwCJIPABIBEEPgAkgsAHgEQQ+ACQCAIfABJB4ANAIgj8bSzPcy0vLyvP87JLAbANEPjb1NLSETUaM5qbe1CNxoyWlo6UXRKAinNElFuAHWXXsN3kea5GY0Zra8cl7ZW0olptVqurp1Sv18suD8AE2FZEeJD30MPfhrIs0/R0UxthL0l7NTXVUJZl5RUFoPII/G2o2Wzq3LlM0kp3z4rW11fVbDbLKwpA5RH421C9Xle7vahabVa7d+9TrTardnuR4RwAW2IMv2B5nivLMjWbzbEHcpHnBlBtlRrDt33I9lu2X+4+DhTVVlUVPZOmXq9r//79hD2Aq1JYD9/2IUm/jIjDV3jdjuzhM5MGQJEq1cPvGqiYnYSZNACqpujAf8j2K7aftH1twW1VCjNpAFTNrlHebPuYpD29uySFpMckLUr624gI238n6bCk+X7nWVhYuPi81Wqp1WqNUlYlXJhJMz8/q6mphtbXV5lJA2BonU5HnU5npHNMZJaO7YakFyJib59jO2oM//KZM8ykAVCESo3h276pZ/NeSSeLaqsq+s3KYSYNgKoocpbOP0q6TdJ5SZmkByLibJ/X7YgePrNyAEzSMD38kcbwtxIRf1LUuavowqyctbWPzsoh8AFUAUsrjAmzcgBUHYE/JqxvA6DqWEtnzJiVA2AShhnDJ/ABYBuq1LRMAEC1EPgAkAgCHwASQeADQCIIfABIBIEPAIkg8AEgEQQ+ACSCwAeARBD4AJAIAh8AEkHgA0AiCHwASASBDwCJIPABIBEEPgAkgsAHgEQQ+EPI81zLy8vK87zsUgDgqhH4A1paOqJGY0Zzcw+q0ZjR0tKRsksCgKvCd9oOIM9zNRozWls7LmmvpBXVarNaXT3FF5YDmCi+07ZgWZZperqpjbCXpL2ammooy7LyigKAq0TgD6DZbOrcuUzSSnfPitbXV9VsNssrCgCuEoE/gHq9rnZ7UbXarHbv3qdabVbt9iLDOQC2Bcbwh5DnubIsU7PZJOwBlGLiY/i2v2b7pO0PbO+77Nijts/Yfs32HaO0UzX1el379+8n7AFsK6MO6bwq6auSftS70/atkr4u6VZJd0latD3Qb6IyMc8ewE40UuBHxOmIOCPp8jC/W9JzEfHriMgknZF0+yhtTQrz7AHsVEX90fZmSW/2bL/d3VdpeZ5rfv6g1taO6/33X9La2nHNzx+kpw9gR9h1pRfYPiZpT+8uSSHpsYh4oajCynBhnv3a2kfn2TNeD2C7u2LgR8TcEOd9W9Ine7Y/0d3X18LCwsXnrVZLrVZriCZH9+F59ht30jLPHkAVdDoddTqdkc4xlmmZto9L+uuIeKm7/XlJz0j6PW0
M5RyT9Nv95l9WbVrm0tIRzc8f1NRUQ+vrq2q3F3X//d8ouywA+JBhpmWOFPi275H0bUk3SPqFpFci4q7usUclzUtal/SXEXF0k3NUKvAl5tkDqL6JB/44VDHwAaDqWDwNALApAh8AEkHgA0AiCHwASASBDwCJIPABIBEEPgAkgsAHgEQQ+ACQCAIfABJB4ANAIgh8AEgEgQ8AiSDwASARBD4AJILAB4BEEPgAkAgCHwASQeADQCIIfABIBIEPAIkg8AEgEQQ+ACSCwAeARBD4AJAIAh8AEkHgA0AiCHwASMRIgW/7a7ZP2v7A9r6e/Q3b/2f75e5jcfRSAQCjGLWH/6qkr0r6UZ9j/x0R+7qPgyO2k4ROp1N2CZXBtbiEa3EJ12I0IwV+RJyOiDOS3Odwv33YAh/mS7gWl3AtLuFajKbIMfym7ZdsH7f9+wW2AwC4Cruu9ALbxyTt6d0lKSQ9FhEvbPK2n0j6VES81x3b/2fbn4+IX41cMQBgKI6I0U9iH5f0VxHx8qDHbY9eAAAkKCIGGjq/Yg9/ABcbtn2DpHcj4rzt35J0i6T/6femQQsGAAxn1GmZ99h+U9KXJP2r7X/rHvoDSSu2X5b0XUkPRMQvRisVADCKsQzpAACqr7Q7bTe7aat77FHbZ2y/ZvuOsmosg+1Dtt/quWntQNk1TZrtA7ZP2X7d9sNl11Mm25nt/7L9n7b/o+x6Jsl22/ZZ2ys9+66zfdT2adsv2r62zBonZZNrMXBWlLm0Qt+btmzfKunrkm6VdJekRdupjfMf7rlp7ftlFzNJtq+R9A+S7pT0O5Lutz1TblWlOi+pFRFfiIjbyy5mwp7Sxueg1yOSfhARn5P0Q0mPTryqcvS7FtKAWVFa4G9x09bdkp6LiF9HRCbpjKTUPuip/YLrdbukMxGxGhHrkp7TxmciVVaia15FxAlJ7122+25JT3efPy3pnokWVZJNroU0YFZU8YN0s6Q3e7bf7u5LyUO2X7H9ZCr/y9rj8p//W0rv598rJL1oe9n2N8supgJujIizkhQR70i6seR6yjZQVhQa+LaP2V7pebza/fePimy36q5wXRYlfSYibpP0jqTD5VaLkn05Ir4o6Q+18R83d61/WMqzTgbOinHOw/+IiJgb4m1vS/pkz/Ynuvt2jAGuy3ckbXY38071tqRP9WzvuJ//ICLip91/c9vPa2PI60S5VZXqrO09EXHW9k2SflZ2QWWJiLxn86qyoipDOr3jUN+TdJ/taduf1sZNW8nMTuh+iC+4V9LJsmopybKkW7pLbE9Luk8bn4nk2P6Y7d/oPv+4pDuU3ufB+mg+/Fn3+Z9K+pdJF1SiD12LYbKi0B7+VmzfI+nbkm7Qxk1br0TEXRHxY9vflfRjSeuSDkZaNwv8ve3btDE7I5P0QLnlTFZEfGD7W5KOaqND0o6I10ouqyx7JD3fXX5kl6RnIuJoyTVNjO1nJbUkXW/7DUmHJD0u6Z9s/4WkVW3M6NvxNrkWs4NmBTdeAUAiqjKkAwAoGIEPAIkg8AEgEQQ+ACSCwAeARBD4AJAIAh8AEkHgA0Ai/h8ouuEJoRhXmwAAAABJRU5ErkJggg==\n", 68 | "text/plain": [ 69 | "" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "# 画图\n", 78 | "import matplotlib.pyplot as plt\n", 79 | "plt.scatter(x_train, y_train)\n", 80 | "plt.show()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 5, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "name": "stdout", 90 
| "output_type": "stream", 91 | "text": [ 92 | "\n", 93 | "-7.2100\n", 94 | "-6.6000\n", 95 | "-6.0700\n", 96 | "-5.7800\n", 97 | "-4.0600\n", 98 | "-0.8300\n", 99 | " 6.6900\n", 100 | " 7.4400\n", 101 | " 9.6500\n", 102 | " 9.8400\n", 103 | "[torch.FloatTensor of size 10x1]\n", 104 | " \n", 105 | " \n", 106 | "-9.8800\n", 107 | "-8.7900\n", 108 | "-5.5200\n", 109 | "-4.6700\n", 110 | "-3.8300\n", 111 | " 2.0300\n", 112 | " 5.1800\n", 113 | " 7.1800\n", 114 | " 8.1500\n", 115 | " 9.9400\n", 116 | "[torch.FloatTensor of size 10x1]\n", 117 | "\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "# 将numpy变量转化为tensor\n", 123 | "x_train = torch.from_numpy(x_train)\n", 124 | "y_train = torch.from_numpy(y_train)\n", 125 | "x_train = x_train.float()\n", 126 | "y_train = y_train.float()\n", 127 | "print(x_train, '\\n', y_train)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# 定义模型, 应当包含 __init__()函数和forward()函数\n", 137 | "class LinearRegression(nn.Module):\n", 138 | " def __init__(self):\n", 139 | " super(LinearRegression, self).__init__()\n", 140 | " self.linear = Linear(1, 1)\n", 141 | " \n", 142 | " def forward(self, x) :\n", 143 | " out = self.linear(x)\n", 144 | " return out\n", 145 | "if torch.cuda.is_available():\n", 146 | " model = LinearRegression().cuda()\n", 147 | "else :\n", 148 | " model = LinearRegression()" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# 定义损失函数和优化函数\n", 158 | "criterison = nn.MSELoss()\n", 159 | "optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 8, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "Epoch[20/1000], loss : 3.845105\n", 172 | "Epoch[40/1000], loss : 2.739876\n", 173 | 
"Epoch[60/1000], loss : 2.704095\n", 174 | "Epoch[80/1000], loss : 2.689542\n", 175 | "Epoch[100/1000], loss : 2.676452\n", 176 | "Epoch[120/1000], loss : 2.664375\n", 177 | "Epoch[140/1000], loss : 2.653225\n", 178 | "Epoch[160/1000], loss : 2.642932\n", 179 | "Epoch[180/1000], loss : 2.633429\n", 180 | "Epoch[200/1000], loss : 2.624657\n", 181 | "Epoch[220/1000], loss : 2.616558\n", 182 | "Epoch[240/1000], loss : 2.609081\n", 183 | "Epoch[260/1000], loss : 2.602178\n", 184 | "Epoch[280/1000], loss : 2.595806\n", 185 | "Epoch[300/1000], loss : 2.589923\n", 186 | "Epoch[320/1000], loss : 2.584491\n", 187 | "Epoch[340/1000], loss : 2.579478\n", 188 | "Epoch[360/1000], loss : 2.574849\n", 189 | "Epoch[380/1000], loss : 2.570575\n", 190 | "Epoch[400/1000], loss : 2.566631\n", 191 | "Epoch[420/1000], loss : 2.562988\n", 192 | "Epoch[440/1000], loss : 2.559626\n", 193 | "Epoch[460/1000], loss : 2.556522\n", 194 | "Epoch[480/1000], loss : 2.553656\n", 195 | "Epoch[500/1000], loss : 2.551010\n", 196 | "Epoch[520/1000], loss : 2.548568\n", 197 | "Epoch[540/1000], loss : 2.546313\n", 198 | "Epoch[560/1000], loss : 2.544232\n", 199 | "Epoch[580/1000], loss : 2.542310\n", 200 | "Epoch[600/1000], loss : 2.540536\n", 201 | "Epoch[620/1000], loss : 2.538898\n", 202 | "Epoch[640/1000], loss : 2.537386\n", 203 | "Epoch[660/1000], loss : 2.535990\n", 204 | "Epoch[680/1000], loss : 2.534701\n", 205 | "Epoch[700/1000], loss : 2.533512\n", 206 | "Epoch[720/1000], loss : 2.532413\n", 207 | "Epoch[740/1000], loss : 2.531399\n", 208 | "Epoch[760/1000], loss : 2.530463\n", 209 | "Epoch[780/1000], loss : 2.529599\n", 210 | "Epoch[800/1000], loss : 2.528801\n", 211 | "Epoch[820/1000], loss : 2.528064\n", 212 | "Epoch[840/1000], loss : 2.527384\n", 213 | "Epoch[860/1000], loss : 2.526757\n", 214 | "Epoch[880/1000], loss : 2.526177\n", 215 | "Epoch[900/1000], loss : 2.525642\n", 216 | "Epoch[920/1000], loss : 2.525148\n", 217 | "Epoch[940/1000], loss : 2.524692\n", 218 | "Epoch[960/1000], 
loss : 2.524271\n", 219 | "Epoch[980/1000], loss : 2.523883\n", 220 | "Epoch[1000/1000], loss : 2.523524\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "# 开始训练模型\n", 226 | "from torch.autograd import Variable\n", 227 | "# 定义要跑的批(epoch)数\n", 228 | "num_epochs = 1000\n", 229 | "for epoch in range(num_epochs) :\n", 230 | " # 获得每一批数据的输出\n", 231 | " if torch.cuda.is_available() :\n", 232 | " inputs = Variable(x_train).cuda()\n", 233 | " target = Variable(y_train).cuda()\n", 234 | " else :\n", 235 | " inputs = Variable(x_train)\n", 236 | " target = Variable(y_train)\n", 237 | " # forward前向计算\n", 238 | " out = model(inputs)\n", 239 | " loss = criterison(out, target)\n", 240 | " # backward 计算误差项及更新参数\n", 241 | " optimizer.zero_grad()\n", 242 | " loss.backward()\n", 243 | " optimizer.step()\n", 244 | " \n", 245 | " # 一轮训练结束, 输出信息\n", 246 | " if(epoch + 1) % 20 == 0 :\n", 247 | " print('Epoch[{}/{}], loss : {:.6f}'.format(epoch + 1, num_epochs, loss.data[0]))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 9, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "text/plain": [ 258 | "LinearRegression(\n", 259 | " (linear): Linear(in_features=1, out_features=1)\n", 260 | ")" 261 | ] 262 | }, 263 | "execution_count": 9, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | } 267 | ], 268 | "source": [ 269 | "# 将模型变成测试模式\n", 270 | "model.eval()" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 10, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "[]" 282 | ] 283 | }, 284 | "execution_count": 10, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | }, 288 | { 289 | "data": { 290 | "image/png": "\n", 291 | "text/plain": [ 292 | "" 293 | ] 294 | }, 295 | "metadata": {}, 296 | "output_type": "display_data" 297 | } 298 | ], 299 | "source": [ 300 | "# 测试结果与真实结果作比较\n", 301 | "predict = model(Variable(x_train))\n", 302 | "predict = 
predict.data.numpy()\n", 303 | "plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')\n", 304 | "plt.plot(x_train.numpy(), predict, label='Fitting Line')" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "Python 3", 318 | "language": "python", 319 | "name": "python3" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.5.2" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 2 336 | } 337 | -------------------------------------------------------------------------------- /Chapter_3/logistRegression/data.txt: -------------------------------------------------------------------------------- 1 | 34.62365962451697,78.0246928153624,0 2 | 30.28671076822607,43.89499752400101,0 3 | 35.84740876993872,72.90219802708364,0 4 | 60.18259938620976,86.30855209546826,1 5 | 79.0327360507101,75.3443764369103,1 6 | 45.08327747668339,56.3163717815305,0 7 | 61.10666453684766,96.51142588489624,1 8 | 75.02474556738889,46.55401354116538,1 9 | 76.09878670226257,87.42056971926803,1 10 | 84.43281996120035,43.53339331072109,1 11 | 95.86155507093572,38.22527805795094,0 12 | 75.01365838958247,30.60326323428011,0 13 | 82.30705337399482,76.48196330235604,1 14 | 69.36458875970939,97.71869196188608,1 15 | 39.53833914367223,76.03681085115882,0 16 | 53.9710521485623,89.20735013750205,1 17 | 69.07014406283025,52.74046973016765,1 18 | 67.94685547711617,46.67857410673128,0 19 | 70.66150955499435,92.92713789364831,1 20 | 76.97878372747498,47.57596364975532,1 21 | 67.37202754570876,42.83843832029179,0 22 | 89.67677575072079,65.79936592745237,1 23 | 
50.534788289883,48.85581152764205,0 24 | 34.21206097786789,44.20952859866288,0 25 | 77.9240914545704,68.9723599933059,1 26 | 62.27101367004632,69.95445795447587,1 27 | 80.1901807509566,44.82162893218353,1 28 | 93.114388797442,38.80067033713209,0 29 | 61.83020602312595,50.25610789244621,0 30 | 38.78580379679423,64.99568095539578,0 31 | 61.379289447425,72.80788731317097,1 32 | 85.40451939411645,57.05198397627122,1 33 | 52.10797973193984,63.12762376881715,0 34 | 52.04540476831827,69.43286012045222,1 35 | 40.23689373545111,71.16774802184875,0 36 | 54.63510555424817,52.21388588061123,0 37 | 33.91550010906887,98.86943574220611,0 38 | 64.17698887494485,80.90806058670817,1 39 | 74.78925295941542,41.57341522824434,0 40 | 34.1836400264419,75.2377203360134,0 41 | 83.90239366249155,56.30804621605327,1 42 | 51.54772026906181,46.85629026349976,0 43 | 94.44336776917852,65.56892160559052,1 44 | 82.36875375713919,40.61825515970618,0 45 | 51.04775177128865,45.82270145776001,0 46 | 62.22267576120188,52.06099194836679,0 47 | 77.19303492601364,70.45820000180959,1 48 | 97.77159928000232,86.7278223300282,1 49 | 62.07306379667647,96.76882412413983,1 50 | 91.56497449807442,88.69629254546599,1 51 | 79.94481794066932,74.16311935043758,1 52 | 99.2725269292572,60.99903099844988,1 53 | 90.54671411399852,43.39060180650027,1 54 | 34.52451385320009,60.39634245837173,0 55 | 50.2864961189907,49.80453881323059,0 56 | 49.58667721632031,59.80895099453265,0 57 | 97.64563396007767,68.86157272420604,1 58 | 32.57720016809309,95.59854761387875,0 59 | 74.24869136721598,69.82457122657193,1 60 | 71.79646205863379,78.45356224515052,1 61 | 75.3956114656803,85.75993667331619,1 62 | 35.28611281526193,47.02051394723416,0 63 | 56.25381749711624,39.26147251058019,0 64 | 30.05882244669796,49.59297386723685,0 65 | 44.66826172480893,66.45008614558913,0 66 | 66.56089447242954,41.09209807936973,0 67 | 40.45755098375164,97.53518548909936,1 68 | 49.07256321908844,51.88321182073966,0 69 | 
80.27957401466998,92.11606081344084,1 70 | 66.74671856944039,60.99139402740988,1 71 | 32.72283304060323,43.30717306430063,0 72 | 64.0393204150601,78.03168802018232,1 73 | 72.34649422579923,96.22759296761404,1 74 | 60.45788573918959,73.09499809758037,1 75 | 58.84095621726802,75.85844831279042,1 76 | 99.82785779692128,72.36925193383885,1 77 | 47.26426910848174,88.47586499559782,1 78 | 50.45815980285988,75.80985952982456,1 79 | 60.45555629271532,42.50840943572217,0 80 | 82.22666157785568,42.71987853716458,0 81 | 88.9138964166533,69.80378889835472,1 82 | 94.83450672430196,45.69430680250754,1 83 | 67.31925746917527,66.58935317747915,1 84 | 57.23870631569862,59.51428198012956,1 85 | 80.36675600171273,90.96014789746954,1 86 | 68.46852178591112,85.59430710452014,1 87 | 42.0754545384731,78.84478600148043,0 88 | 75.47770200533905,90.42453899753964,1 89 | 78.63542434898018,96.64742716885644,1 90 | 52.34800398794107,60.76950525602592,0 91 | 94.09433112516793,77.15910509073893,1 92 | 90.44855097096364,87.50879176484702,1 93 | 55.48216114069585,35.57070347228866,0 94 | 74.49269241843041,84.84513684930135,1 95 | 89.84580670720979,45.35828361091658,1 96 | 83.48916274498238,48.38028579728175,1 97 | 42.2617008099817,87.10385094025457,1 98 | 99.31500880510394,68.77540947206617,1 99 | 55.34001756003703,64.9319380069486,1 100 | 74.77589300092767,89.52981289513276,1 101 | -------------------------------------------------------------------------------- /Chapter_3/logistRegression/logisticRegression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Logistic 回归\n", 8 | "Logistic回归中,输出(Y = 1)的对数几率是输入 x 的线性函数,思路 : \n", 9 | "\n", 10 | "1. 先拟合决策边界(不局限于线性,还可以是多项式)\n", 11 | "2. 再建立这个边界和分类概率的关系\n", 12 | "3. 
从而得到二分类情况下的概率" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import torch\n", 24 | "from torch import nn\n", 25 | "from torch import optim\n", 26 | "from torch.autograd import Variable" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/plain": [ 37 | "" 38 | ] 39 | }, 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "output_type": "execute_result" 43 | }, 44 | { 45 | "data": { 46 | "image/png": "\n", 47 | "text/plain": [ 48 | "" 49 | ] 50 | }, 51 | "metadata": {}, 52 | "output_type": "display_data" 53 | } 54 | ], 55 | "source": [ 56 | "# 获取并查看数据\n", 57 | "\n", 58 | "with open('data.txt') as f :\n", 59 | " data = f.read().split('\\n')\n", 60 | " data = [row.split(',') for row in data][:-1]\n", 61 | " label0 = np.array([(float(row[0]), float(row[1])) for row in data if row[2] == '0'])\n", 62 | " label1 = np.array([(float(row[0]), float(row[1])) for row in data if row[2] == '1'])\n", 63 | "x0, y0 = label0[:, 0], label0[:, 1]\n", 64 | "x1, y1 = label1[:, 0], label1[:, 1]\n", 65 | "plt.plot(x0, y0, 'ro', label = 'label_0')\n", 66 | "plt.plot(x1, y1, 'bo', label = 'label_1')\n", 67 | "plt.legend(loc = 'best')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "import torch\n", 77 | "\n", 78 | "# 获得训练数据\n", 79 | "x = np.concatenate((label0, label1), axis = 0)\n", 80 | "x_data = torch.from_numpy(x).float()\n", 81 | "\n", 82 | "y = [[0] for i in range(label0.shape[0])]\n", 83 | "y += [[1] for i in range(label1.shape[0])]\n", 84 | "y_data = torch.FloatTensor(y)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "# 定义模型\n", 94 | "\n", 95 | 
"class logisticRegression(nn.Module) :\n", 96 | " def __init__(self) :\n", 97 | " super().__init__()\n", 98 | " self.line = nn.Linear(2, 1)\n", 99 | " self.smd = nn.Sigmoid()\n", 100 | " def forward(self, x) :\n", 101 | " x = self.line(x)\n", 102 | " return self.smd(x)\n", 103 | "\n", 104 | "logistic = logisticRegression()\n", 105 | "\n", 106 | "if torch.cuda.is_available() :\n", 107 | " logistic.cuda()" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 5, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# 定义损失函数为二分类交叉熵(Binary Cross Entropy)\n", 117 | "criterion = nn.BCELoss()\n", 118 | "# 定义优化函数为随机梯度下降(Stochastic Gradient Descent)\n", 119 | "optimizer = optim.SGD(logistic.parameters(), lr = 1e-3, momentum = 0.9)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "解释一下两个部分\n", 127 | "\n", 128 | "#### BCELoss()\n", 129 | "它的定义是 $loss(o, t) = -\\frac 1 n \\sum_i(t[i] * log(o[i]) + (1 - t[i]) * log(1 - o[i]))$\n", 130 | "\n", 131 | "以上来源于[torch.nn.BCELoss](http://pytorch.org/docs/master/nn.html?highlight=nn%20bceloss#torch.nn.BCELoss)\n", 132 | "#### SGD函数,它带有一个参数 `momentum = 0.9`,这表示动量。\n", 133 | "\n", 134 | "这种优化方法来源于[On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf)\n", 135 | "\n", 136 | "它的优化方式为 : $$v = \\rho * v + g \\\\ p = p - lr * v$$\n", 137 | "其中,p, g, v, $\\rho$ 分别表示待更新的参数parameters,求得的梯度gradient,速度velocity,和动量momentum\n", 138 | "\n", 139 | "以上来源于[官网](http://pytorch.org/docs/master/optim.html)关于torch.optim.SGD的note" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "**********\n", 152 | "epoch 10000\n", 153 | "loss is 0.3826\n", 154 | "correct rate is 0.9100\n", 155 | "**********\n", 156 | "epoch 20000\n", 157 | "loss is 0.3146\n", 158 | 
"correct rate is 0.9200\n", 159 | "**********\n", 160 | "epoch 30000\n", 161 | "loss is 0.2826\n", 162 | "correct rate is 0.9100\n", 163 | "**********\n", 164 | "epoch 40000\n", 165 | "loss is 0.2640\n", 166 | "correct rate is 0.9100\n", 167 | "**********\n", 168 | "epoch 50000\n", 169 | "loss is 0.2517\n", 170 | "correct rate is 0.9100\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "# 开始训练\n", 176 | "epoches = 50000\n", 177 | "for epoch in range(epoches) :\n", 178 | " if torch.cuda.is_available() :\n", 179 | " x = Variable(x_data).cuda()\n", 180 | " y = Variable(y_data).cuda()\n", 181 | " else :\n", 182 | " x = Variable(x_data)\n", 183 | " y = Variable(y_data)\n", 184 | " \n", 185 | " # forward 前向计算\n", 186 | " out = logistic(x)\n", 187 | " loss = criterion(out, y)\n", 188 | " \n", 189 | " # 计算准确率\n", 190 | " print_loss = loss.data[0]\n", 191 | " mask = out.ge(0.5).float()\n", 192 | " # print('size : {} - {}'.format(mask.size(), y.size()))\n", 193 | " correct = (mask == y).sum()\n", 194 | " acc = correct.data[0] / x.size(0)\n", 195 | " \n", 196 | " # BP回朔\n", 197 | " optimizer.zero_grad()\n", 198 | " loss.backward()\n", 199 | " optimizer.step()\n", 200 | " if (epoch + 1) % 10000 == 0 :\n", 201 | " print('*' * 10)\n", 202 | " print('epoch {}'.format(epoch + 1))\n", 203 | " print('loss is {:.4f}'.format(print_loss))\n", 204 | " print('correct rate is {:.4f}'.format(acc))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 7, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "image/png": "\n", 215 | "text/plain": [ 216 | "" 217 | ] 218 | }, 219 | "metadata": {}, 220 | "output_type": "display_data" 221 | } 222 | ], 223 | "source": [ 224 | "weight = logistic.line.weight.data[0]\n", 225 | "w0, w1 = weight[0], weight[1]\n", 226 | "b = logistic.line.bias.data[0]\n", 227 | "\n", 228 | "plt.plot(x0, y0, 'ro', label = 'label_0')\n", 229 | "plt.plot(x1, y1, 'bo', label = 'label_1')\n", 230 | "plt.legend(loc = 'best')\n", 231 
| "plot_x = np.arange(30, 100, 0.1)\n", 232 | "plot_y = (-w0 * plot_x - b) / w1\n", 233 | "plt.plot(plot_x, plot_y)\n", 234 | "plt.show()" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [] 243 | } 244 | ], 245 | "metadata": { 246 | "kernelspec": { 247 | "display_name": "Python 3", 248 | "language": "python", 249 | "name": "python3" 250 | }, 251 | "language_info": { 252 | "codemirror_mode": { 253 | "name": "ipython", 254 | "version": 3 255 | }, 256 | "file_extension": ".py", 257 | "mimetype": "text/x-python", 258 | "name": "python", 259 | "nbconvert_exporter": "python", 260 | "pygments_lexer": "ipython3", 261 | "version": "3.6.3" 262 | } 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 2 266 | } 267 | -------------------------------------------------------------------------------- /Chapter_3/polynomialRegression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch.autograd import Variable\n", 11 | "from torch import nn\n", 12 | "from torch import optim" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# 定义最高次项的指数\n", 22 | "\n", 23 | "n = 4" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# 为线性相乘做准备,将x中的每一个元素排列成 [x^1, x^2, ..., x^n]\n", 33 | "# 因为有多个x,故返回的是一个4x4的二维张量\n", 34 | "\n", 35 | "def make_features(x) :\n", 36 | " x = x.unsqueeze(1)\n", 37 | " return torch.cat([x ** i for i in range(1, n + 1)], 1)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# 不加入x^0的原因是,nn.Linear默认自带bias属性(可取消), 
以下来源于http://pytorch.org/docs/master/nn.html\n", 47 | "# 在页面内搜索nn.Linear即可找到\n", 48 | "\n", 49 | "## class torch.nn.Linear(in_features, out_features, bias=True)[source]\n", 50 | "## Applies a linear transformation to the incoming data: y=Ax+b\n", 51 | "##\n", 52 | "## > Parameters:\n", 53 | "## - in_features – size of each input sample\n", 54 | "## - out_features – size of each output sample\n", 55 | "## - bias – If set to False, the layer will not learn an additive bias. Default: True" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 5, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "import random\n", 65 | "#定义目标函数,包括权重和偏执\n", 66 | "\n", 67 | "W_target = torch.FloatTensor([random.randint(-1000, 1000) * 0.01 for i in range(n)]).unsqueeze(1)\n", 68 | "b_target = torch.FloatTensor([random.randint(-100, 1000) * 0.01])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 6, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# 定义实际函数\n", 78 | "\n", 79 | "def f(x) :\n", 80 | " return x.mm(W_target) + b_target[0]" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 7, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# 生成训练集。随机取数,然后生成x,然后生成y\n", 90 | "\n", 91 | "def get_batch(batch_size = 32, random = None) :\n", 92 | " if random is None :\n", 93 | " random = torch.randn(batch_size)\n", 94 | " batch_size = random.size()[0]\n", 95 | " x = make_features(random)\n", 96 | " y = f(x)\n", 97 | " if torch.cuda.is_available() :\n", 98 | " return Variable(x).cuda(), Variable(y).cuda()\n", 99 | " else :\n", 100 | " return Variable(x), Variable(y)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 8, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# 构造训练网络模型\n", 110 | "\n", 111 | "class poly_model(nn.Module) :\n", 112 | " def __init__(self, n) :\n", 113 | " super().__init__()\n", 114 | " self.poly = nn.Linear(n, 1)\n", 115 | " def 
forward(self, x) :\n", 116 | " return self.poly(x)\n", 117 | "\n", 118 | "# 实例化网络模型\n", 119 | "if torch.cuda.is_available() :\n", 120 | " poly = poly_model(n).cuda()\n", 121 | "else :\n", 122 | " poly = poly_model(n)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 9, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "# 定义损失函数为均方误差, 定义优化函数为随机梯度下降,学习率为0.03\n", 132 | "\n", 133 | "criterion = nn.MSELoss()\n", 134 | "optimizer = optim.SGD(poly.parameters(), lr = 1e-3)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 10, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "the number of epoches : 7788\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "# 开始训练\n", 152 | "\n", 153 | "epoch = 0\n", 154 | "while True :\n", 155 | " # 获得数据\n", 156 | " batch_x, batch_y = get_batch()\n", 157 | " # 前向计算\n", 158 | " output = poly(batch_x)\n", 159 | " # 计算损失函数\n", 160 | " loss = criterion(output, batch_y)\n", 161 | " print_loss = loss.data[0]\n", 162 | " # 参数更新\n", 163 | " optimizer.zero_grad()\n", 164 | " loss.backward()\n", 165 | " optimizer.step()\n", 166 | " \n", 167 | " epoch += 1\n", 168 | " if print_loss < 1e-3 :\n", 169 | " break\n", 170 | "\n", 171 | "print(\"the number of epoches :\", epoch)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 11, 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "data": { 181 | "text/plain": [ 182 | "" 183 | ] 184 | }, 185 | "metadata": {}, 186 | "output_type": "display_data" 187 | } 188 | ], 189 | "source": [ 190 | "import matplotlib.pyplot as plt\n", 191 | "import numpy as np\n", 192 | "x = [random.randint(-200, 200) * 0.01 for i in range(20)]\n", 193 | "x = np.array(sorted(x))\n", 194 | "feature_x, y = get_batch(random = torch.from_numpy(x).float())\n", 195 | "y = y.data.numpy()\n", 196 | "plt.plot(x, y, 'ro', label='Original data')\n", 197 | "\n", 198 | 
"poly.eval()\n", 199 | "x_sample = np.arange(-2, 2, 0.01)\n", 200 | "x, y = get_batch(random = torch.from_numpy(x_sample).float())\n", 201 | "y = poly(x)\n", 202 | "y_sample = y.data.numpy()\n", 203 | "plt.plot(x_sample, y_sample, label = 'Fitting Line')\n", 204 | "plt.show()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 15, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "predicted function : y = 8.05 * x^4 + 7.44 * x^3 + -7.49 * x^2 + 9.16 * x^1 + 4.89\n", 217 | "real function : y = 8.06 * x^4 + 7.44 * x^3 + -7.54 * x^2 + 9.18 * x^1 + 4.94\n" 218 | ] 219 | } 220 | ], 221 | "source": [ 222 | "# 定义函数输出形式\n", 223 | "def func_format(weight, bias, n) :\n", 224 | " func = ''\n", 225 | " for i in range(n, 0, -1) :\n", 226 | " func += ' {:.2f} * x^{} +'.format(weight[i - 1], i)\n", 227 | " return 'y =' + func + ' {:.2f}'.format(bias[0])\n", 228 | " \n", 229 | "predict_weight = poly.poly.weight.data.numpy().flatten()\n", 230 | "predict_bias = poly.poly.bias.data.numpy().flatten()\n", 231 | "print('predicted function :', func_format(predict_weight, predict_bias, n))\n", 232 | "real_W = W_target.numpy().flatten()\n", 233 | "real_b = b_target.numpy().flatten()\n", 234 | "print('real function :', func_format(real_W, real_b, n))" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [] 243 | } 244 | ], 245 | "metadata": { 246 | "kernelspec": { 247 | "display_name": "Python 3", 248 | "language": "python", 249 | "name": "python3" 250 | }, 251 | "language_info": { 252 | "codemirror_mode": { 253 | "name": "ipython", 254 | "version": 3 255 | }, 256 | "file_extension": ".py", 257 | "mimetype": "text/x-python", 258 | "name": "python", 259 | "nbconvert_exporter": "python", 260 | "pygments_lexer": "ipython3", 261 | "version": "3.6.3" 262 | } 263 | }, 264 | "nbformat": 4, 265 | 
"nbformat_minor": 2 266 | } 267 | -------------------------------------------------------------------------------- /Chapter_4/classRNNs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 卷积层经典网络" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "from torch import nn, optim\n", 18 | "from torch.autograd import Variable\n", 19 | "from torch.nn import init" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "---\n", 27 | "## LeNet\n", 28 | "7层,其中2层卷积和2层池化层交替出现,最后输出3层全连接得到整体的效果。" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "class Lenet(nn.Module) :\n", 38 | " def __init__(self) :\n", 39 | " super().__init__()\n", 40 | " \n", 41 | " layer1 = nn.Sequential()\n", 42 | " layer1.add_module('conv1', nn.Conv2d(1, 6, 3, padding=1))\n", 43 | " layer1.add_module('pool1', nn.MaxPool2d(2, 2))\n", 44 | " self.layer1 = layer1\n", 45 | " \n", 46 | " layer2 = nn.Sequential()\n", 47 | " layer2.add_module('conv2', nn.Conv2d(6, 16, 5))\n", 48 | " layer2.add_module('pool2', nn.MaxPool2d(2, 2))\n", 49 | " self.layer2 = layer2\n", 50 | " \n", 51 | " layer3 = nn.Sequential()\n", 52 | " layer3.add_module('fc1', nn.Linear(400, 120))\n", 53 | " layer3.add_module('fc2', nn.Linear(120, 84))\n", 54 | " layer3.add_module('fc3', nn.Linear(84, 10))\n", 55 | " self.layer3 = layer3\n", 56 | " \n", 57 | " def forward(self, x) :\n", 58 | " x = self.layer1(x)\n", 59 | " x = self.layer2(x)\n", 60 | " x.view(x.size(0), -1)\n", 61 | " x = self.layer3(x)\n", 62 | " return x" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "---\n", 70 | "## AlexNet\n", 71 | 
"相比于LeNet层数更深,同时第一次引入激活层ReLU,在全连接层引入了Dropout层防止过拟合" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "class AlexNet(nn.Module) :\n", 81 | " def __init__(self, num_classes) :\n", 82 | " super().__init__()\n", 83 | " self.features = nn.Sequential(\n", 84 | " nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n", 85 | " nn.ReLU(inplace=True),\n", 86 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 87 | " nn.Conv2d(64, 192, kernel_size=5, padding=2),\n", 88 | " nn.ReLU(inplace=True),\n", 89 | " nn.MaxPlool2d(kernel_size=3, stride=2),\n", 90 | " nn.Conv2d(192, 384, kernel_size=3, padding=1),\n", 91 | " nn.ReLU(True),\n", 92 | " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", 93 | " nn.ReLU(True),\n", 94 | " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", 95 | " nn.ReLU(True),\n", 96 | " nn.MaxPool2d(kernel_size=3, stride=2)\n", 97 | " )\n", 98 | " \n", 99 | " self.classifier = nn.Sequential(\n", 100 | " nn.Dropout(),\n", 101 | " nn.Linear(256 * 6 * 6, 4096),\n", 102 | " nn.ReLU(True),\n", 103 | " nn.Dropout(),\n", 104 | " nn.Linear(4096, 4096),\n", 105 | " nn.ReLU(True),\n", 106 | " nn.Linear(4096, num_classes)\n", 107 | " )\n", 108 | " \n", 109 | " def forward(self, x) :\n", 110 | " x = self.features(x)\n", 111 | " x = x.view(x.size(0), 256 * 256 * 256)\n", 112 | " x = self.classifier(x)\n", 113 | " return x" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "---\n", 121 | "## VGGNet\n", 122 | "相比于AlexNet, 它使用更多小的滤波器,层叠很多小的滤波器的感受野和一个大的滤波器是相同的, 还能减小参数,同时有更深的网络结构" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 4, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "class VGG(nn.Module) :\n", 132 | " def __init__(self, num_classes) :\n", 133 | " super().__init__()\n", 134 | " \n", 135 | " self.features = nn.Sequential(\n", 136 | " nn.Conv2d(3, 64, kernel_size=3, pading=1),\n", 137
| " nn.ReLU(True),\n", 138 | " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", 139 | " nn.ReLU(True),\n", 140 | " nn.MaxPool2d(kernel_sze=2, stride=2),\n", 141 | " nn.Conv2d(64, 128, kernel_size=3, padding=1),\n", 142 | " nn.ReLU(True),\n", 143 | " nn.Conv2d(128, 128, kernel_size=3, padding=1),\n", 144 | " nn.ReLU(True),\n", 145 | " nn.MaxPool2d(kernel_size=2,stride=2),\n", 146 | " nn.Conv2d(128, 256, kernel_size=3, padding=1),\n", 147 | " nn.ReLU(True),\n", 148 | " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", 149 | " nn.ReLU(True),\n", 150 | " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", 151 | " nn.ReLU(True),\n", 152 | " nn.MaxPool2d(kernel_size=2, stride=2),\n", 153 | " nn.Conv2d(256, 512, kernel_size=3, padding=1),\n", 154 | " nn.ReLU(True),\n", 155 | " nn.Conv2d(512, 512, kernel_size=3, padding=1),\n", 156 | " nn.ReLU(True),\n", 157 | " nn.Conv2d(512, 512, kernel_size=3, padding=1),\n", 158 | " nn.ReLU(True),\n", 159 | " nn.Conv2d(512, 512, kernel_size=3, padding=1),\n", 160 | " nn.ReLU(True),\n", 161 | " nn.MaxPool2d(kernel_size=2, stride=2)\n", 162 | " )\n", 163 | " \n", 164 | " self.classifier = nn.Sequential(\n", 165 | " nn.Linear(512 * 7 * 7, 4096),\n", 166 | " nn.ReLU(True),\n", 167 | " nn.Dropout(),\n", 168 | " nn.Linear(4096, 4096),\n", 169 | " nn.ReLU(True),\n", 170 | " nn.Dropout(),\n", 171 | " nn.Linear(4096, num_classes)\n", 172 | " )\n", 173 | " \n", 174 | " self._initialize_weights()\n", 175 | " def forward(self, x) :\n", 176 | " x = self.features(x)\n", 177 | " x = x.view(x.size(0), -1)\n", 178 | " x = self.classifier(x)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "VGG只是对网络层进行不断叠加,并没有太多创新\n", 186 | "\n", 187 | "---\n", 188 | "## GoogleNet的Inception模块\n", 189 | "Inception模块设计了一个局部的网络拓扑结构,然后将这些模块堆叠在一起形成一个抽象网络结构。具体来说就是运用几个并行的滤波器对输入进行卷积和池化,这些滤波器有不同的感受野,最后将输出的结果按深度拼接在一起形成输出层。\n", 190 | "\n", 191 | "新的Inception模块增加了一些 1 × 1的卷积层来降低输入层的维度,使网络参数减少,从而减少了网络的复杂性。" 192 | ] 193 | }, 194 |
{ 195 | "cell_type": "code", 196 | "execution_count": 5, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "import torch.nn.functional as F\n", 201 | "\n", 202 | "class BasicConv2d(nn.Module) :\n", 203 | " def __init__(self, in_channels, out_channels, **kwargs) :\n", 204 | " super().__init__()\n", 205 | " self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)\n", 206 | " self.bn = nn.BatchNorm2d(out_channels, eps=0.001)\n", 207 | " \n", 208 | " def forward(self, x) :\n", 209 | " x = self.conv(x)\n", 210 | " x = self.bn(x)\n", 211 | " return x\n", 212 | "\n", 213 | "class Inception(nn.Module) :\n", 214 | " def __init__(self, in_channels, out_channels) :\n", 215 | " super().__init__()\n", 216 | " self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)\n", 217 | " \n", 218 | " self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)\n", 219 | " self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) \n", 220 | " \n", 221 | " self.branch3x3db1_1 = BasicConv2d(in_channels, 64, kernel_size=1)\n", 222 | " slef.branch3x3db1_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)\n", 223 | " self.branch3x3db1_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)\n", 224 | " \n", 225 | " self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)\n", 226 | " \n", 227 | " def forward(self, x) :\n", 228 | " branch1x1 = self.branch1x1(x)\n", 229 | " \n", 230 | " branch5x5 = self.branch5x5_1(x)\n", 231 | " branch5x5 = self.branch5x5_2(branch5x5)\n", 232 | " \n", 233 | " branch3x3db1 = self.branch3x3db1_1(x)\n", 234 | " branch3x3db1 = self.branch3x3db1_2(branch3x3db1)\n", 235 | " branch3x3db1 = self.branch3x3db1_3(branch3x3db1)\n", 236 | " \n", 237 | " branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)\n", 238 | " branch_pool = self.branch_pool(branch_pool)\n", 239 | " \n", 240 | " out = [branch1x1, branch5x5, branch3x3db1, branch_pool]\n", 241 | " return torch.cat(out, 1)" 242 | ] 243 | }, 244 | { 245 | 
"cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "Python 3", 255 | "language": "python", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.6.4" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 2 273 | } 274 | -------------------------------------------------------------------------------- /Chapter_4/simpleCNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch.autograd import Variable\n", 11 | "from torch import nn, optim" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "class SimpleCNN(nn.Module) :\n", 21 | " def __init__(self) :\n", 22 | " # b, 3, 32, 32\n", 23 | " super().__init__()\n", 24 | " layer1 = nn.Sequential()\n", 25 | " layer1.add_module('conv_1', nn.Conv2d(3, 32, 3, 1, padding=1))\n", 26 | " #b, 32, 32, 32\n", 27 | " layer1.add_module('relu_1', nn.ReLU(True))\n", 28 | " layer1.add_module('pool_1', nn.MaxPool2d(2, 2)) # b, 32, 16, 16\n", 29 | " self.layer1 = layer1\n", 30 | " \n", 31 | " layer2 = nn.Sequential()\n", 32 | " layer2.add_module('conv_2', nn.Conv2d(32, 64, 3, 1, padding=1))\n", 33 | " # b, 64, 16, 16\n", 34 | " layer2.add_module('relu_2', nn.ReLU(True))\n", 35 | " layer2.add_module('pool_2', nn.MaxPool2d(2, 2)) # b, 64, 8, 8\n", 36 | " self.layer2 = layer2\n", 37 | " \n", 38 | " layer3 = nn.Sequential()\n", 39 | " 
layer3.add_module('conv_3', nn.Conv2d(64, 128, 3, 1, padding=1))\n", 40 | " # b, 128, 8, 8\n", 41 | " layer3.add_module('relu_3', nn.ReLU(True))\n", 42 | " layer3.add_module('pool_3', nn.MaxPool2d(2, 2)) # b, 128, 4, 4\n", 43 | " self.layer3 = layer3\n", 44 | " \n", 45 | " layer4 = nn.Sequential()\n", 46 | " layer4.add_module('fc_1', nn.Linear(2048, 512))\n", 47 | " layer4.add_module('fc_relu1', nn.ReLU(True))\n", 48 | " layer4.add_module('fc_2', nn.Linear(512, 64))\n", 49 | " layer4.add_module('fc_relu2', nn.ReLU(True))\n", 50 | " layer4.add_module('fc_3', nn.Linear(64, 10))\n", 51 | " self.layer4 = layer4\n", 52 | " \n", 53 | " def forward(self, x) :\n", 54 | " conv1 = self.layer1(x)\n", 55 | " conv2 = self.layer2(conv1)\n", 56 | " conv3 = self.layer3(conv2)\n", 57 | " fc_input = conv3.view(conv3.size(0), -1)\n", 58 | " fc_out = self.layer4(fc_input)\n", 59 | " return fc_out" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "SimpleCNN(\n", 72 | " (layer1): Sequential(\n", 73 | " (conv_1): Conv2d (3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 74 | " (relu_1): ReLU(inplace)\n", 75 | " (pool_1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n", 76 | " )\n", 77 | " (layer2): Sequential(\n", 78 | " (conv_2): Conv2d (32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 79 | " (relu_2): ReLU(inplace)\n", 80 | " (pool_2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n", 81 | " )\n", 82 | " (layer3): Sequential(\n", 83 | " (conv_3): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 84 | " (relu_3): ReLU(inplace)\n", 85 | " (pool_3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n", 86 | " )\n", 87 | " (layer4): Sequential(\n", 88 | " (fc_1): Linear(in_features=2048, out_features=512)\n", 89 | " (fc_relu1): ReLU(inplace)\n", 90 | " 
(fc_2): Linear(in_features=512, out_features=64)\n", 91 | " (fc_relu2): ReLU(inplace)\n", 92 | " (fc_3): Linear(in_features=64, out_features=10)\n", 93 | " )\n", 94 | ")\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "# 建立模型\n", 100 | "\n", 101 | "model = SimpleCNN()\n", 102 | "print(model)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "Sequential(\n", 115 | " (0): Sequential(\n", 116 | " (conv_1): Conv2d (3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 117 | " (relu_1): ReLU(inplace)\n", 118 | " (pool_1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n", 119 | " )\n", 120 | " (1): Sequential(\n", 121 | " (conv_2): Conv2d (32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 122 | " (relu_2): ReLU(inplace)\n", 123 | " (pool_2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n", 124 | " )\n", 125 | ")\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "# 提取前两层\n", 131 | "\n", 132 | "new_model = nn.Sequential(*list(model.children())[:2])\n", 133 | "print(new_model)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 5, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "Sequential(\n", 146 | " (layer1.conv_1): Conv2d (3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 147 | " (layer2.conv_2): Conv2d (32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 148 | " (layer3.conv_3): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 149 | ")\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "# 提取所有的卷积层\n", 155 | "\n", 156 | "conv_model = nn.Sequential()\n", 157 | "for name, module in model.named_modules() :\n", 158 | " if isinstance(module, nn.Conv2d) :\n", 159 | " conv_model.add_module(name, module)\n", 160 | "\n", 
161 | "print(conv_model)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "layer1.conv_1.weight : torch.Size([32, 3, 3, 3])\n", 174 | "layer1.conv_1.bias : torch.Size([32])\n", 175 | "layer2.conv_2.weight : torch.Size([64, 32, 3, 3])\n", 176 | "layer2.conv_2.bias : torch.Size([64])\n", 177 | "layer3.conv_3.weight : torch.Size([128, 64, 3, 3])\n", 178 | "layer3.conv_3.bias : torch.Size([128])\n", 179 | "layer4.fc_1.weight : torch.Size([512, 2048])\n", 180 | "layer4.fc_1.bias : torch.Size([512])\n", 181 | "layer4.fc_2.weight : torch.Size([64, 512])\n", 182 | "layer4.fc_2.bias : torch.Size([64])\n", 183 | "layer4.fc_3.weight : torch.Size([10, 64])\n", 184 | "layer4.fc_3.bias : torch.Size([10])\n" 185 | ] 186 | } 187 | ], 188 | "source": [ 189 | "# 提取模型中的参数\n", 190 | "\n", 191 | "for name, param in model.named_parameters() :\n", 192 | " print('{} : {}'.format(name, param.shape))" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# 权重初始化\n", 202 | "from torch.nn import init\n", 203 | "\n", 204 | "for m in model.modules() :\n", 205 | " if isinstance(m, nn.Conv2d) :\n", 206 | " init.normal(m.weight.data)\n", 207 | " init.xavier_normal(m.weight.data)\n", 208 | " init.kaiming_normal(m.weight.data)\n", 209 | " m.bias.data.fill_(0)\n", 210 | " elif isinstance(m, nn.Linear) :\n", 211 | " m.weight.data.normal_()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "关于参数初始化可参考[深度学习的weight initialization](https://zhuanlan.zhihu.com/p/25110150)\n", 219 | "\n", 220 | "可以从torch的[文档](http://pytorch.org/docs/master/nn.html?highlight=init%20xavier_normal#torch.nn.init.xavier_normal)中得到\n", 221 | "\n", 222 | "- `init.xavier_uniform()`一般用于tanh的初始化,结果采样于均匀分布 $$U(-a, a) \\sim [-\\frac
{\\sqrt{6}} {\\sqrt{fan\\_in + fan\\_out}}, \\frac {\\sqrt{6}} {\\sqrt{fan\\_in + fan\\_out}}]$$\n", 223 | "- `init.xavier_normal()`,结果采样于正态分布 $$N(0, \\sqrt{\\frac 2 {fan\\_in + fan\\_out}})$$\n", 224 | "- `init.kaiming_uniform()` 结果采样于均匀分布 $$U(-a, a) \\sim [-\\frac {\\sqrt{6}} {\\sqrt{(1+a^2) \\times fan\\_in}}, \\frac {\\sqrt{6}} {\\sqrt{(1+a^2) \\times fan\\_in}}]$$\n", 225 | "- `init.kaiming_normal()`一般用于ReLU的初始化,初始化方法为正态分布 $$N(0, \\sqrt{\\frac 2 {(1 + a^2) \\times fan\\_in}})$$" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "Python 3", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.6.4" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 2 257 | } 258 | -------------------------------------------------------------------------------- /Chapter_5/N_Gram.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "---\n", 8 | "## 词嵌入\n", 9 | "\n", 10 | "词嵌入在PyTorch中通过函数 `nn.Embedding(m, n)` 来实现的。`m`表示所有的单词数目,n表示词嵌入的维度" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "from torch.autograd import Variable\n", 22 | "import torch.nn.functional as F\n", 23 | "from torch import optim" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": 
"stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "Variable containing:\n", 36 | " 0.8044 0.0947 -0.2797 0.4324 -0.5038\n", 37 | "[torch.FloatTensor of size 1x5]\n", 38 | "\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "word_to_ix = {'hello' : 0, 'world' : 1}\n", 44 | "embeds = nn.Embedding(2, 5)\n", 45 | "hello_idx = torch.LongTensor([word_to_ix['hello']])\n", 46 | "hello_idx = Variable(hello_idx)\n", 47 | "hello_embed = embeds(hello_idx)\n", 48 | "print(hello_embed)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "---\n", 56 | "## N Gram 模型" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "CONTEXT_SIZE = 2\n", 66 | "EMBEDDING_DIM = 10\n", 67 | "\n", 68 | "test_sentence = \"\"\"When forty winters shall besiege thy brow,\n", 69 | "And dig deep trenches in thy beauty's field,\n", 70 | "Thy youth's proud livery so gazed on now,\n", 71 | "Will be a totter'd weed of small worth held:\n", 72 | "Then being asked, where all thy beauty lies,\n", 73 | "Where all the treasure of thy lusty days;\n", 74 | "To say, within thine own deep sunken eyes,\n", 75 | "Were an all-eating shame, and thriftless praise.\n", 76 | "How much more praise deserv'd thy beauty's use,\n", 77 | "If thou couldst answer 'This fair child of mine\n", 78 | "Shall sum my count, and make my old excuse,'\n", 79 | "Proving his beauty by succession thine!\n", 80 | "This were to be new made when thou art old,\n", 81 | "And see thy blood warm when thou feel'st it cold.\"\"\".split()" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "trigram = [((test_sentence[i], test_sentence[i+1]), test_sentence[i+2])\n", 91 | " for i in range(len(test_sentence) - 2)]" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "outputs": [], 99 | 
"source": [ 100 | "vocb = set(test_sentence)\n", 101 | "word_to_ix = {word : i for i, word in enumerate(vocb)}\n", 102 | "idx_to_word = {word_to_ix[word] : word for word in word_to_ix}" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "class NgramModel(nn.Module) :\n", 112 | " def __init__(self, vocb_size, context_size, n_dim) :\n", 113 | " super().__init__()\n", 114 | " self.n_word = vocb_size\n", 115 | " self.embedding = nn.Embedding(self.n_word, n_dim)\n", 116 | " self.linear1 = nn.Linear(context_size * n_dim, 128)\n", 117 | " self.linear2 = nn.Linear(128, self.n_word)\n", 118 | " \n", 119 | " def forward(self, x) :\n", 120 | " emb = self.embedding(x)\n", 121 | " emb = emb.view(1, -1)\n", 122 | " out = self.linear1(emb)\n", 123 | " out = F.relu(out)\n", 124 | " out = self.linear2(out)\n", 125 | " log_prob = F.log_softmax(out, 1)\n", 126 | " return log_prob" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "net = NgramModel(len(vocb), CONTEXT_SIZE, EMBEDDING_DIM)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "criterion = nn.CrossEntropyLoss()\n", 145 | "optimizer = optim.SGD(net.parameters(), lr=1e-2, weight_decay=1e-5)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 9, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "epoch: 20, Loss : 0.809687\n", 158 | "epoch: 40, Loss : 0.152495\n", 159 | "epoch: 60, Loss : 0.095547\n", 160 | "epoch: 80, Loss : 0.076622\n", 161 | "epoch: 100, Loss : 0.066594\n", 162 | "epoch: 120, Loss : 0.060094\n", 163 | "epoch: 140, Loss : 0.055386\n", 164 | "epoch: 160, Loss : 0.051820\n", 165 | "epoch: 180, Loss : 0.049030\n", 166 | "epoch: 
200, Loss : 0.046760\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "epoches = 200\n", 172 | "for epoch in range(epoches) :\n", 173 | " train_loss = 0\n", 174 | " for word, label in trigram : \n", 175 | " word = Variable(torch.LongTensor([word_to_ix[i] for i in word]))\n", 176 | " label = Variable(torch.LongTensor([word_to_ix[label]]))\n", 177 | " out = net(word)\n", 178 | " loss = criterion(out, label)\n", 179 | " train_loss += loss.data[0]\n", 180 | " optimizer.zero_grad()\n", 181 | " loss.backward()\n", 182 | " optimizer.step()\n", 183 | " if (epoch + 1) % 20 == 0 :\n", 184 | " print('epoch: {}, Loss : {:.6f}'.format(epoch + 1, train_loss / len(trigram)))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 10, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "net = net.eval()" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "input: ('so', 'gazed')\n", 206 | "input: on\n", 207 | "\n", 208 | "67\n", 209 | "real word is \"on\", predicted word is \"on\"\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "word, label = trigram[19]\n", 215 | "print('input: {}'.format(word))\n", 216 | "print('input: {}'.format(label), end =\"\\n\\n\")\n", 217 | "\n", 218 | "word = Variable(torch.LongTensor([word_to_ix[i] for i in word]))\n", 219 | "out = net(word)\n", 220 | "\n", 221 | "pred_label_idx = out.max(1)[1].data[0]\n", 222 | "print(pred_label_idx)\n", 223 | "\n", 224 | "predict_word = idx_to_word[pred_label_idx]\n", 225 | "print('real word is \"{}\", predicted word is \"{}\"'.format(label, predict_word))" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "Python 3", 239 | "language": "python", 240 
| "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.5.2" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 2 257 | } 258 | -------------------------------------------------------------------------------- /Chapter_5/rnnModule.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from torch.autograd import Variable\n", 11 | "from torch import nn, optim\n", 12 | "from torch.nn import init" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "---\n", 20 | "\n", 21 | "## 1. 标准RNN" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# 建立简单的循环神经网络\n", 31 | "\n", 32 | "basic_rnn = nn.RNN(input_size=3, hidden_size=5, num_layers=2)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "Parameter containing:\n", 44 | " 0.0763 -0.0092 -0.1030\n", 45 | "-0.1201 -0.2317 0.3021\n", 46 | "-0.1568 -0.0355 0.3281\n", 47 | "-0.3040 -0.1278 0.2782\n", 48 | "-0.2930 -0.1496 0.2063\n", 49 | "[torch.FloatTensor of size 5x3]" 50 | ] 51 | }, 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "output_type": "execute_result" 55 | } 56 | ], 57 | "source": [ 58 | "# 访问第一层网络的 $w_{ih}$\n", 59 | "\n", 60 | "basic_rnn.weight_ih_l0" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "data": { 70 | "text/plain": [ 71 | 
"Parameter containing:\n", 72 | " 0.2065 0.3279 -0.0407 -0.1670 0.2792\n", 73 | "-0.2776 -0.3277 0.2683 0.0603 0.1309\n", 74 | "-0.3711 0.3196 0.4190 0.0326 -0.3182\n", 75 | "-0.0235 -0.0512 0.3191 -0.1126 0.0228\n", 76 | " 0.2597 -0.2444 0.3148 -0.0795 -0.1418\n", 77 | "[torch.FloatTensor of size 5x5]" 78 | ] 79 | }, 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "# 访问第二层的 $w_{ih}$\n", 87 | "\n", 88 | "basic_rnn.weight_ih_l1" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "text/plain": [ 99 | "Parameter containing:\n", 100 | " 0.2150 0.0983 0.0374 0.2345 -0.4368\n", 101 | " 0.4327 -0.0015 0.4188 0.0039 -0.2717\n", 102 | " 0.4181 -0.2944 -0.0375 -0.3815 -0.3615\n", 103 | " 0.1330 -0.4197 -0.3870 -0.2852 0.3714\n", 104 | " 0.0946 -0.0085 -0.4026 -0.1688 0.2727\n", 105 | "[torch.FloatTensor of size 5x5]" 106 | ] 107 | }, 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "# 访问第一层的 $w_{hh}$\n", 115 | "\n", 116 | "basic_rnn.weight_hh_l1" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "Parameter containing:\n", 128 | "-0.0978\n", 129 | " 0.2153\n", 130 | " 0.2239\n", 131 | " 0.2170\n", 132 | " 0.0324\n", 133 | "[torch.FloatTensor of size 5]" 134 | ] 135 | }, 136 | "execution_count": 6, 137 | "metadata": {}, 138 | "output_type": "execute_result" 139 | } 140 | ], 141 | "source": [ 142 | "# 访问第一层的 $b_{ih}$\n", 143 | "\n", 144 | "basic_rnn.bias_ih_l0" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# 输入的序列长度为5,批量是10,维度是3\n", 154 | "toy_input = Variable(torch.randn(5, 10, 3))\n", 155 | "\n", 156 | "# 初始隐藏状态 
layer * direction = 2, 批量是10,维度是5\n", 157 | "h_0 = Variable(torch.randn(2, 10, 5))" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "torch.Size([5, 10, 5])\n", 170 | "torch.Size([2, 10, 5])\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "toy_output, h_n = basic_rnn(toy_input, h_0)\n", 176 | "\n", 177 | "# toy_output的形状应该为 (5, 10, 5)\n", 178 | "# h_n 的形状应该为 (2, 10, 5)\n", 179 | "print(toy_output.shape)\n", 180 | "print(h_n.shape)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "---\n", 188 | "\n", 189 | "## 2. LSTM" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 9, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "lstm = nn.LSTM(input_size=3, hidden_size=5, num_layers=2)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "LSTM中间比标准RNN多了三个线性变换,多的三个线性变换的权重拼在一起,所以一共是4倍,同理偏置也是4倍。\n", 206 | "换句话说,LSTM里面做了4个类似标准RNN所做的运算,所以参数个数是标准RNN的4倍。" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 10, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "Parameter containing:\n", 218 | "-0.0370 0.3933 0.4261\n", 219 | " 0.0104 0.2447 0.2633\n", 220 | "-0.3971 0.1586 0.2650\n", 221 | " 0.1254 0.0325 -0.0926\n", 222 | "-0.1548 -0.2009 -0.1171\n", 223 | "-0.1834 0.3234 0.0946\n", 224 | "-0.3815 0.0475 0.3884\n", 225 | "-0.2534 -0.2354 0.2124\n", 226 | "-0.4156 -0.1013 -0.3804\n", 227 | "-0.1959 0.2669 0.0208\n", 228 | "-0.1414 -0.0864 -0.0238\n", 229 | " 0.4439 -0.2000 -0.0276\n", 230 | "-0.3399 0.2919 0.0363\n", 231 | "-0.0171 0.1917 0.4374\n", 232 | " 0.3896 0.0809 -0.4040\n", 233 | "-0.3879 -0.3903 0.1277\n", 234 | " 0.1634 0.3729 0.1317\n", 235 | "-0.2193 0.1497 0.0905\n", 236 | " 0.1066 -0.2967 
0.0568\n", 237 | "-0.2763 0.0103 0.3772\n", 238 | "[torch.FloatTensor of size 20x3]" 239 | ] 240 | }, 241 | "execution_count": 10, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "# 参数的大小将变成 (4 * hidden_size, input_size) = (4 × 5, 3) = (20, 3)\n", 248 | "\n", 249 | "lstm.weight_ih_l0" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 11, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "# toy_input的size为 (5, 10, 3)\n", 259 | "\n", 260 | "lstm_out, (h_n, c_n) = lstm(toy_input)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 12, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "name": "stdout", 270 | "output_type": "stream", 271 | "text": [ 272 | "lstm_out_size = torch.Size([5, 10, 5])\n", 273 | "h_n_size = torch.Size([2, 10, 5])\n", 274 | "c_n_size = torch.Size([2, 10, 5])\n" 275 | ] 276 | } 277 | ], 278 | "source": [ 279 | "# lstm的输出矩阵应该是 (5, 10, 5)\n", 280 | "print('lstm_out_size = ', lstm_out.size())\n", 281 | "\n", 282 | "# h_n 和 c_n 的size应为 (2, 10, 5)\n", 283 | "print('h_n_size = ', h_n.size())\n", 284 | "print('c_n_size = ', c_n.size())" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "---\n", 292 | "\n", 293 | "## 3. GRU\n", 294 | "1. GRU的参数数量(权重和偏置的大小)为标准RNN的3倍,因为它内部有3组线性变换拼接在一起;\n", 295 | "2. 网络的隐藏状态不是 $h_0 \\text{和} c_0$,而是只有 $h_0$;\n", 296 | "3. 其余部分和LSTM相同;" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": {}, 302 | "source": [ 303 | "---\n", 304 | "## 4. 
单步版本\n", 305 | "Pytorch提供 `RNNCell`, `LSTMCell`, `GRUCell`分别作为这三个模型的单步版本。\n", 306 | "\n", 307 | "它们的输入不再是一个序列,而是一个序列中的一步,也就是循环神经网络的一个循环。\n", 308 | "\n", 309 | "单步版本在序列的应用上更加灵活,能在基础上添加更多的自定义操作。" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [] 318 | } 319 | ], 320 | "metadata": { 321 | "kernelspec": { 322 | "display_name": "Python 3", 323 | "language": "python", 324 | "name": "python3" 325 | }, 326 | "language_info": { 327 | "codemirror_mode": { 328 | "name": "ipython", 329 | "version": 3 330 | }, 331 | "file_extension": ".py", 332 | "mimetype": "text/x-python", 333 | "name": "python", 334 | "nbconvert_exporter": "python", 335 | "pygments_lexer": "ipython3", 336 | "version": "3.6.4" 337 | } 338 | }, 339 | "nbformat": 4, 340 | "nbformat_minor": 2 341 | } 342 | -------------------------------------------------------------------------------- /Chapter_5/sequencePrediction/data.csv: -------------------------------------------------------------------------------- 1 | "Month","International airline passengers: monthly totals in thousands. Jan 49 ? 
Dec 60" 2 | "1949-01",112 3 | "1949-02",118 4 | "1949-03",132 5 | "1949-04",129 6 | "1949-05",121 7 | "1949-06",135 8 | "1949-07",148 9 | "1949-08",148 10 | "1949-09",136 11 | "1949-10",119 12 | "1949-11",104 13 | "1949-12",118 14 | "1950-01",115 15 | "1950-02",126 16 | "1950-03",141 17 | "1950-04",135 18 | "1950-05",125 19 | "1950-06",149 20 | "1950-07",170 21 | "1950-08",170 22 | "1950-09",158 23 | "1950-10",133 24 | "1950-11",114 25 | "1950-12",140 26 | "1951-01",145 27 | "1951-02",150 28 | "1951-03",178 29 | "1951-04",163 30 | "1951-05",172 31 | "1951-06",178 32 | "1951-07",199 33 | "1951-08",199 34 | "1951-09",184 35 | "1951-10",162 36 | "1951-11",146 37 | "1951-12",166 38 | "1952-01",171 39 | "1952-02",180 40 | "1952-03",193 41 | "1952-04",181 42 | "1952-05",183 43 | "1952-06",218 44 | "1952-07",230 45 | "1952-08",242 46 | "1952-09",209 47 | "1952-10",191 48 | "1952-11",172 49 | "1952-12",194 50 | "1953-01",196 51 | "1953-02",196 52 | "1953-03",236 53 | "1953-04",235 54 | "1953-05",229 55 | "1953-06",243 56 | "1953-07",264 57 | "1953-08",272 58 | "1953-09",237 59 | "1953-10",211 60 | "1953-11",180 61 | "1953-12",201 62 | "1954-01",204 63 | "1954-02",188 64 | "1954-03",235 65 | "1954-04",227 66 | "1954-05",234 67 | "1954-06",264 68 | "1954-07",302 69 | "1954-08",293 70 | "1954-09",259 71 | "1954-10",229 72 | "1954-11",203 73 | "1954-12",229 74 | "1955-01",242 75 | "1955-02",233 76 | "1955-03",267 77 | "1955-04",269 78 | "1955-05",270 79 | "1955-06",315 80 | "1955-07",364 81 | "1955-08",347 82 | "1955-09",312 83 | "1955-10",274 84 | "1955-11",237 85 | "1955-12",278 86 | "1956-01",284 87 | "1956-02",277 88 | "1956-03",317 89 | "1956-04",313 90 | "1956-05",318 91 | "1956-06",374 92 | "1956-07",413 93 | "1956-08",405 94 | "1956-09",355 95 | "1956-10",306 96 | "1956-11",271 97 | "1956-12",306 98 | "1957-01",315 99 | "1957-02",301 100 | "1957-03",356 101 | "1957-04",348 102 | "1957-05",355 103 | "1957-06",422 104 | "1957-07",465 105 | "1957-08",467 106 | 
"1957-09",404 107 | "1957-10",347 108 | "1957-11",305 109 | "1957-12",336 110 | "1958-01",340 111 | "1958-02",318 112 | "1958-03",362 113 | "1958-04",348 114 | "1958-05",363 115 | "1958-06",435 116 | "1958-07",491 117 | "1958-08",505 118 | "1958-09",404 119 | "1958-10",359 120 | "1958-11",310 121 | "1958-12",337 122 | "1959-01",360 123 | "1959-02",342 124 | "1959-03",406 125 | "1959-04",396 126 | "1959-05",420 127 | "1959-06",472 128 | "1959-07",548 129 | "1959-08",559 130 | "1959-09",463 131 | "1959-10",407 132 | "1959-11",362 133 | "1959-12",405 134 | "1960-01",417 135 | "1960-02",391 136 | "1960-03",419 137 | "1960-04",461 138 | "1960-05",472 139 | "1960-06",535 140 | "1960-07",622 141 | "1960-08",606 142 | "1960-09",508 143 | "1960-10",461 144 | "1960-11",390 145 | "1960-12",432 146 | -------------------------------------------------------------------------------- /Chapter_5/sequencePrediction/seqInit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 序列预测\n", 8 | "\n", 9 | "---" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn, optim\n", 20 | "from torch.autograd import Variable\n", 21 | "from torch.nn import init" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## 已知\n", 29 | "1949年到1960年每一个月的飞机客流量" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAIABJREFUeJzsnXl8HFeV77+n1VJr3xfLkmx5i+3E\njh3HceIsZCcLDGEnIQMBMgQYmGGY4UHgPbZ5DAzvsWVmGCAMMwlbIATyEjLZd7I6duIk3i3bshZr\n39WtVkvd9/1RVa3W2i11yVp8vp+PPqq6VXXrdsn+1elzzz1HjDEoiqIoixfPXA9AURRFmV1U6BVF\nURY5KvSKoiiLHBV6RVGURY4KvaIoyiJHhV5RFGWRo0KvnBRE5CERuekk3MeIyOoZXHeRiByMc06t\niFxhb39ZRP5jpuNUlJOJCv0CIVZkEjj3aRH5q9ke0xT3/7qI/Cq2zRhzjTHmzrkaUzyMMX82xqyd\nxvnfMsbM2TNWlOmgQq+MQ0RS5noM8wkR8c71GBY7+oxnFxX6BYiIfEREnhOR74pIl4gcE5Fr7GP/\nBFwE/JuI9IvIv9nt60TkMRHpFJGDIvL+mP7uEJEfi8iDIuIHLrXbfiQi/y0ifSLysoisirnmNhGp\nF5FeEdklIhfZ7VcDXwY+YN//dbs9+i1DRDwi8r9E5LiItIrIL0Qkzz5WbbtfbhKROhFpF5H/GXPf\nbSLyooh0i0iTiPybiKQl+Nw+KiL77c9zVEQ+EXPsEhFpiNmvFZEvisgbgH+sEMV+a0lgzB4RuVVE\njohIh4jcLSKFk4zxEhFpsF1D7fY4bow5/jYRec1+7vUi8vWYY+ki8iv7Ht0i8oqIlNnHPmJ/5j77\n30tsnx+zn0uXiDwiIstjjhkR+aSIHLb7/JGIiH0sRUS+Z4/zmIh8xj7fax/PE5Gf23+nRhH5pthG\nhD2e50XkByLSAXxdRFaLyDMi0mP3+btE/q5KAhhj9GcB/AC1wBX29keAIeDjQArwKeAEIPbxp4G/\nirk2C6gHPgp4gbOAduB0+/gdQA9wAdbLP91u6wC22df8GvhtTJ9/CRTZx/4BaAbS7WNfB341ZvzR\nMQEfA2qAlUA28Efgl/axasAAPwMygE3AILDePn42cJ5932pgP/B3MfcxwOpJnuHbgFWAABcDAWCL\nfewSoGHM894NVAEZE/wNop8xgTF/FngJqAR8wE+BuyYZ4yXAMPB9+9yLAT+wNub4RvvvdCbQArzT\nPvYJ4E9AJta/i7OBXPvv3xvTRzlwhr19nf23WG8/0/8FvDDmeT4A5APLgDbgavvYJ4F99ucqAB63\nz/fax++1P2sWUArsAD4R8294GPgb+74ZwF3A/2Tk3+CFc/3/brH8zPkA9CfBP9R4oa+JOZZp/wdb\nYu8/zWih/wDw5zH9/RT4mr19B/CLMcfvAP4jZv9a4MAU4+sCNtnbURGMOR4dE/AE8Ncxx9Zivbgc\n8TZAZczxHcD1k9z374B7Y/YnFfoJrv1/wGft7UsYL/Qfm+JvEP2M8caM9TK6POZYufN5JxjTJbYA\nZsW03Q18ZZLP8EPgB/b2x4AXgDPHnJMFdAPvwX5pxRx7CLg5Zt+D9QJcHvM8Lxwzllvt7Sexhdve\nv8I+3wuUYb3sMmKO3wA8FfNvuG7MWH4B3B77HPXHnR913Sxcmp0NY0zA3sye5NzlwLn2V+9uEekG\nbgSWxJxTP9U9sP7zR/sXkc/bX/d77P7ygOIEx74UOB6zf5wRcZjy3iJymog8ICLNItILfCvR+4rI\nNSLykljuq26sl9dU1070TKZisue1HLg35tnvB8KM/ryxdBlj/DH7x7GeGSJyrog8JSJtItKDZVU7\nn+GXwCPAb0XkhIj8HxFJtfv6gH1uk1juuHUxY7stZmydWN94K
hL4XEsZ/Yxit5cDqfb9nL5/imXZ\nT3Q+wBfse+8Qkb0i8rFJno8yTVToFydjU5LWA88YY/JjfrKNMZ+a4ppJEcsf/wXg/UCBMSYfy/Uj\nCfZ1AksIHJZhWbEtCdz+x8ABYI0xJhdrPkCmvgRExAf8AfguUGaP+cE417qV2rUeuGbM8083xjRO\ncn6BiGTF7C/DemYAvwHuB6qMMXnAT7A/gzFmyBjzDWPM6cD5wNuBD9vHHjHGXIn1beIAlpvJGdsn\nxowtwxjzQgKfqwnLbeNQNeYzDwLFMf3mGmPOiDln1PM1xjQbYz5ujFmK5Yb6d5lBqKwyHhX6xUkL\nlv/b4QHgNBH5kIik2j/niMj6GfafgyXMbYBXRL6K5QuOvX+1iEz27+su4HMiskJEsrGs8t8ZY4YT\nvHcv0G9bpZ+Kc75DGpbPuw0YFmvy+q0JXpssPwH+yZnkFJESEbkuzjXfEJE0+6X6duD3dnsO0GmM\nCYrINuCDzgUicqmIbLQnPHux3EMRESkTkevsl8cg0A9EYsb2JRE5w+4jT0Tel+Dnuhv4rIhUiEg+\n8EXngDGmCXgU+J6I5Io1Ib1KRC6erDMReZ+IOC+OLqwXQWSy85XEUaFfnNwGvNeOovgXY0wflqhd\nj2UZNgPfwRK+mfAI8DBwCMutEGT013BHlDpE5NUJrv9PLDfDs8Ax+/q/SfDen8cStz4sqzShyAz7\nGfwtljh12X3cn+A9k+U2+16Pikgf1sTsuVOc34w1xhNYk+CfNMYcsI/9NfCPdj9fxfo8DkuAe7BE\nfj/wDNZz9gB/b/fXiTXB+ykAY8y9WP8Wfmu7wvYA1yT4uX6GJeZvAK9hfUMaxnJLgfVtIg1rwrbL\nHlv5FP2dA7wsIv1Yz+uzxpijCY5FmQInSkNRlHmAiFyCNclbGe/c+Yb9LeknxpjlcU9WTipq0SuK\nMiNEJENErhURr4hUAF/DCqlU5hkq9IqizBQBvoHllnkNy1301TkdkTIh6rpRFEVZ5KhFryiKssiZ\nF4mEiouLTXV19VwPQ1EUZUGxa9eudmNMSbzz5oXQV1dXs3PnzrkehqIoyoJCRI7HP0tdN4qiKIse\nFXpFUZRFjgq9oijKIkeFXlEUZZGjQq8oirLIUaFXFEVZ5KjQK4qiLHJU6BVFUVzk9fpudh3vnOth\njEKFXlEUxUW+/dB+Pv3r14hE5k8eMRV6RVEUF+kODNHcG2RH7fyx6lXoFUVRXKRnYAiA+18/EefM\nk4cKvaIoiov02kL/4JtNhIbnR8lbFXpFURSXGA5H8IfCbKrMozswxHM1bXM9JCBBoReRfBG5R0QO\niMh+EdkuIoUi8piIHLZ/F9jnioj8i4jUiMgbIrJldj+CoijK/KA3OAzA284sJy8jlft3zw/3TaIW\n/W3Aw8aYdcAmrJJhtwJPGGPWAE/Y+2BVkF9j/9wC/NjVESuKosxTHLdNcbaPS9eW8OLRjjkekUVc\noReRPOAtwM8BjDEhY0w3cB1wp33ancA77e3rgF8Yi5eAfBEpd33kiqIo8wxnIjY3PZWSHB+9A8Nz\nPCKLRCz6FUAb8F8i8pqI/IeIZAFlxpgm+5xmoMzergDqY65vsNsURVEWNb1BW+gzUsn2pTIwFGYo\nPPcTsokIvRfYAvzYGHMW4GfETQOAsSqMT2t1gIjcIiI7RWRnW9v8mLBQFEVJBseCz8tIJSfdKuDn\nH5x7qz4RoW8AGowxL9v792AJf4vjkrF/t9rHG4GqmOsr7bZRGGNuN8ZsNcZsLSmJW/JQURRl3hN1\n3WR4o0LfF1wAQm+MaQbqRWSt3XQ5sA+4H7jJbrsJuM/evh/4sB19cx7QE+PiURRFWbQ4rptYi95p\nm0sSLQ7+N8CvRSQNOAp8FOslcbeI3AwcB95vn/sgcC1QAwTscxVFURY9PQNDeD1CRmoKOempAPTP\nA4s+IaE3xuwGtk5w6PIJz
jXAp5Mcl6IoyoKjd2CIvIxURIRs3wJy3SiKoiiJ0TMwRG6GZck7rpv+\nBTIZqyiKoiRAb3CYXFvgs6OTsXPvo1ehVxRFcYneGIs+1/bR96lFryiKsniIFXqf10NqiqiPXlEU\nZTHRG7QmY4HohKy6bhRFURYJxhhrMtZ22QDkpKfOi/BKFXpFURQXCA5FGAqbqEUP2Ba9Cr2iKMqi\nIDb9gUNOulcnYxVFURYLsekPHHLSU9WiVxRFOdlEIoY/vX6CQMhdAe6NyUXvkJOuk7GKoignnT+9\ncYK/ues1Ht3b4mq/I66b0UKvK2MVRVFOIsPhCLc9fhiA7kDI1b4nct04k7FWCrC5Q4VeUZRThvtf\nP8HRdj/gfg6anoDjuomdjE0lHDEEh+a2ypQKvaIopwTD4Qi3PXGY08tz8Xk9rk+S9tr9xbpu5ku+\nGxV6RVFOCV4+1snxjgCfuWw1OempUWF2i96BITLTUkhNGZFVx7qf6xBLFXpFUU4J2voGAVi7JIfc\nWYiG6RkYGuWfB+ZNOUEVekVRTgm67MnXgsw0O+zRbdfN6PQHANk+O4Olum4URVFmn67AECJOPdfU\nWbHoY1fFQkzxEbXoFUVRZp+eQIjc9FRSPDIrOWh6B4bHuW7mSzlBFXpFUU4JugJD5GeOlPlzPbwy\nJhe9w3wpPqJCryjKKUFXIER+Zhrgfg4aYwzt/YMUZ/tGtWf5UgD10SuKopwUugNDFIyx6MMRd1as\n9g8OMzgcoSgrbVS7N8VDZlqKum4URVFOBl2BEAVRi96eJHXJpdLRb0X0jLXonXvpZKyiKMpJoCfG\nRx/1nbvkUmnvt2L0i3PGC322z0vfoLpuFEVRZpWhcIS+wWHyM0Zb9G65VByhH+u6se419znpVegV\nRVn0dNsJxwqyHB+9Y9G7JfSW66ZkAot+NhZnTRcVekVRFj1OSmIn6iY76qN313VTOKFFP/c56VXo\nFUVZ9HQ5Fn1M1A2467rJz0wdldDMwVqcpT56RVGUWSVq0Y/x0buVwbKjPzRhxI11L/XRK4qizDqO\nj342o26Ks8e7bcB6qQRCYYbDc1d8JCGhF5FaEXlTRHaLyE67rVBEHhORw/bvArtdRORfRKRGRN4Q\nkS2z+QEURVk8/OzZo3z3kYOu9xvNXGn70H1eD6kp4upkbNEkFn1pTjoALXaa5LlgOhb9pcaYzcaY\nrfb+rcATxpg1wBP2PsA1wBr75xbgx24NVlGUxYt/cJgfPn6IB/c0ud53V2CI1BQhK81KSSAirmaw\nbO8fpGQSoa8oyADgRPeAK/eaCcm4bq4D7rS37wTeGdP+C2PxEpAvIuVJ3EdRlFOAP71+An8oPCur\nSLvtPDciEm1zK+wxOBSmLzg8qeumIt+y6BeC0BvgURHZJSK32G1lxhjn1dsMlNnbFUB9zLUNdtso\nROQWEdkpIjvb2tpmMHRFURYTd+2oA9wv2g2Wjz5/gupPbrxUOv2WW2gy183SfMuib+iaO6H3xj8F\ngAuNMY0iUgo8JiIHYg8aY4yITCs7kDHmduB2gK1bt7qTWUhRlAXJnsYeXm/ooTTHR2vfIOGIIcUj\n8S9MkNg8Nw5u5aSPpj+YROgz07wUZKbOf4veGNNo/24F7gW2AS2OS8b+3Wqf3ghUxVxeabcpiqJM\nyG9fqcPn9XD9OZZ0uG3Vd8fkuXGwCoQn76MfEfqJXTdgWfWN81noRSRLRHKcbeCtwB7gfuAm+7Sb\ngPvs7fuBD9vRN+cBPTEuHkVRlHG8UNPBW04robIgE3A/f/tEFr1bPvr2KTJXOizNz5hTiz4R100Z\ncK89ieEFfmOMeVhEXgHuFpGbgePA++3zHwSuBWqAAPBR10etKMqioqU3yCVrS2NSE7hbFGQiiz7X\npaibeK4bgIr8DF6oaccYM2pC+GQRV+iNMUeBTRO0dwCXT9BugE+7MjpFURY9/YPD+ENhSnN
90Rqr\nbkbeDAyFCYUj0Tw3Dk4OmmTFt70vRFZaChl26OZEVORn4A+FrbqyY144JwNdGasoypzS2hsEoCzX\nN5KDxkWLfmyeG4ecdC8RA/5QOKn+O/yDk0bcODix9HPlp1ehVxRlTmm1V4yW5qSPVH5y0aLv8o/O\nXOngpCpO9l5TpT9wcEIsVegVRTklaYmx6LN9tvi6aNF3T2HRQ/ITv+19kyc0c1g6x4umVOgVRZlT\n2myLviQnfWQy1k2LPjCxRe/MBySbwTIR101xlo80r2fOhD7RBVOKoiizQktvkPRUD7npXowBEXfD\nK0cSmo2Po4fk7tXQFaDDH6I8L33K8zweYWleOg1q0SuKMl/pDoR47rAVHug2rX2DlOakIyJ4PEJ2\nmtfVydiW3iApHqEoa7TVnetC8ZEfPVVDqsfDe8+ujHvuXMbSq0WvKMqktPYG+cp9e3jyQCtDYcNv\nPn4u568qdvUeLb1BynJHRDjbpRw0Dk09QcpyfONSKiRbN7a+M8DvdzZw47nLopOtU1GRn8Ezh+Ym\nr5da9IqiTMoj+1p4ZG8L1260EtA2dQddv4dj0Ttk+9ytsdrSG6RsAtdKspOx//rkYTwe4a8vXZ3Q\n+UvzM2jtG2RwOLlwzpmgQq8oyqTUdwZI83r45js3ACOrQN2ktXeQ0hiL3u1i2k09wQl96JlpKfi8\nnmj2yenQ0T/IH15t5MZzl1GWO7V/3qHCtvpbek5+ARIVekVRJqW+M0BlQQbZPi/pqR7Xhd4/OEz/\n4PBoi97FGqvGGJp7ghOKsYhQmuuLhndOhyNtfsIRw6VrSxO+pjjHivrp8KvQK4oyj6jvClBVkImI\nUJzto6N/+tbvVDiLpWJ99Dkuum76BocJhMKTRsWU5aRHxzAd6joDACwrzEz4mkJ7MtiJAjqZqNAr\nijIpdR0Bqgotl0NRto82ly16J/3BWB+9W+GVLT3OYqyJhX6mFn1dZwCPkNAkrEOhHcfv9ssyEVTo\nFUWZkJ6BIXqDw1GrtSQ7LZqS1y1aJrDo3Yy6abKFvjxvYkEunaFFX98ZoDwvgzRv4hLqxPGrRa8o\nyryh3nZPVNk54i3Xzcmx6P2hMOFI8jH7zXb/S6aw6PuCwwxMM7FZXWdgWm4bsD5XaorQ6Xc3134i\nqNArijIhDV220NuCVpSdRoc/RMQFAXZo7RvE5/WQmzGypMcJe/SHkrfqHddNbFRPLM4LprVveu6b\nmQi9iFCQmRZNsnYyUaFXFGVC6jutVZyxFn04YugecM8ibe0NUprrG5UP3s2c9E29QQqz0khPnThX\nvOMyaulN/JvKQChMW99gdO5iOhRmpdGprhtFUeYLdZ0BctO90UIZToZGN903Lb2DlOWMdqtE0we7\nEHnT0hOc1G0DM7Po68d805kOatErijKvqO8KjBKzIjvnupuRN619wXFulWyX0geDNRm7ZIqEYzOx\n6Os6ph9a6aAWvaIo84r6zkDUbQNQYlv0bkbejE1/ACOuGzcWTbX0Ti30eRmppHk907LoZxJD71CY\nNWLRG2MIDp2cdAgq9IqijCMSMTR0DYzyQzuum/YZhCNORGg4Ql9wmKKs8bVcIXnXzeBwmA5/aErX\njYhQmuOjdToWfWeArLQUCrOmrio1EQVZaXQPDBGOGHoGhlj3lYf55Yu10+5numj2SkVRxtHWP8jg\ncGSU1ZqXkUqKR1xbwt894OSJn7ggSLKTsY54T2XRA5bQT8dH32m5tGZSULwwMxVjrDUKTlRTaYK5\ncpJBLXpFUcbhxNBXxgi9xyMUZaXR3ueO66bLjicfaxlnu2TRO4ulprLowVo1Oy0f/QxCKx2cl1qn\nP5SUC2i6qNArijKOaGRJwWgRKs72uZbYrDNatHt05aesNHd89M5iqXjVn0pzEk+DYIxJSuidl1pX\nYEToZxK9M11U6BVFGUdjlxVDX1kwOla8KDuNdpfCA51
UAGMt+hSPuJKTPprnJp7Q56YnvDq2rc92\naRXN0KLPHLHo6zsDFGWlRV1Vs4kKvaIo42jvD5GT7h230Kgk2+faZKxj0Rdmjp/UdCOx2dF2Pznp\nXnLiCGlpjjXJPJGfPjgU5li7P7p/PEkrvHCM6+ZkWPOgQq8oygS09w9Go2xiKc6xXDdu1I7tirpu\nJhB6F4qP7DreyZZlBXEnTZ3MlhMlN/vHB/Zx7W1/joZB7m3sAWDdkpwZjWms0J8M/zyo0CvKguX7\njx7kv99ompW+O/2hCcMHi7LSGByO4J9mErCJ6AoMkePzTpgB0rLoZy70PYEhDrX0s3V5QdxzS6OL\npkZb9K19Qe7Z2cDAUJg9tsC/2dhLcXZa3AneyUhPTSEzLYW2vkFOdAdPmtBreKWiLEDCEcNPnjmK\nxwPrynNYVZLtav+d/tCEboXYWPpkfctdgdC40EqHZMsJ7qrrBGBrdWHcc50UDGNj6f/r+VqGIhEA\ndtd3s7W6kD2NPWyoyJtRaKVDQWYae0/0EI4YtegVRZmcE90DhMIRgkMRPve73QyFI672394fojh7\nvAgX5zirY5P303f6QxSMibhxyPZNPye9MSaa2viV2i68HmFzVX7c6/IzU0lL8dAS46PvCw7xq5eO\nc+2GciryM3itvpuBUJjDrX1srMib1rjGUpiVxpv2N4R556MXkRQReU1EHrD3V4jIyyJSIyK/E5E0\nu91n79fYx6tnZ+iKcupy3M638pHzq3mjoYefPH3Etb4jEUNXYHLXDbiTBmEqi34mUTef+OUuPvmr\nXQDsqu3ijIo8MtImzloZi1M7trlnROh/u6OevuAwn7x4FZur8tld182+pl4iBjYkKfQFWWkEh6wX\n80yjd6bLdCz6zwL7Y/a/A/zAGLMa6AJutttvBrrs9h/Y5ymK4iLHOqxIkE9evIpt1YU8caDVtb57\ng9YSfafGaSwltkXvRmKzTn9owogbgNyMVLoDQ9Oa9N3f3Mtj+1p4bF8Luxu6E/LPO1TkZ0RDSgFe\nPNrB2rIcNlbmsbkqn8buAZ4+aD3jpC16+1tMaorM2Nc/XRISehGpBN4G/Ie9L8BlwD32KXcC77S3\nr7P3sY9fLsk4tBRFGUdtu5/0VA+lOT6WF2XS1DMQ/6IEcaz1sTlonDYRK548Wbr8k1v0S/MzGBgK\n0xVIPMTSqcX6+d+/Tmg4wjnViQt9ZUEmjd0jz7C+M8By29revMxy//z2lXqKstLiLsCKh/OZKwsy\nSfGcHGlM1KL/IfAFwHEEFgHdxhjnu1UDUGFvVwD1APbxHvv8UYjILSKyU0R2trW1zXD4inJqcrzD\nT3VRFh6PUJ6fQWvfoGt+eie+vWgCH703xUNRli9aAnCmDA6H8YfCkyYGcxZqxVrZUxEIDRMIhdlc\nlU+PXRjl7OXxJ2Jj79fcGyQ0HMEYMypF84aleaR4hLa+waQnYmFk3cDJ8s9DAkIvIm8HWo0xu9y8\nsTHmdmPMVmPM1pKSEje7VpRFz7F2f9TiXJqXjjHjwwNnSqedtGwyEbaSgCVn0XfblnrBJK4bR+id\nxF/xcKz5D25bxrYVhZxWlh11MyVCRUEGxkBTzwBt/YMEhyJU2WPISEuJxs0n67aBEYt+2QwqVM2U\nROKjLgDeISLXAulALnAbkC8iXttqrwQa7fMbgSqgQUS8QB7Q4frIFeUUJRwx1HcOcMXpZQCU51uC\n0dQTpLIgeSuxw7HoJ/DRgxV3Pt0aq2OJrorNmjjqpjLf+hwNCVr0ThRQcU4a//mRc6ad5z32G4TP\nXg0cO1G6uSqfvSd6k56IhRGX2MkKrYQELHpjzJeMMZXGmGrgeuBJY8yNwFPAe+3TbgLus7fvt/ex\njz9p3FhGpygKMBJaWV2UBYwk7TrR7Y6f3rGOp7Top5HtcSKmWhULkJthpS5oTPAzOWMuzvaR7fNO\nuKp3KpzkbQ1dAyN
F0WNemhetKcHn9bBlWfxwzXgU2WM7mUKfzIqHLwK/FZFvAq8BP7fbfw78UkRq\ngE6sl4OiKC5Ra0fcjBX62PDAZOj0W3luJlqxClad1fb+QcIRM+PJxM5JEpo5iAgVBRkJu24ci75o\nmgLvsCQvHY9YriLnc8d+O7rqjDJ2feVKVxKQnb28gG++cwOXrStLuq9EmdaojTFPA0/b20eBbROc\nEwTe58LYFEWZgFo7hn5FsSX0Oemp5Pi80fzrydLhD00YceNQmusjYqDDP74MYKI4Fv1kPnqwhDZh\nH71/8kihREhN8bAkN52GrgFSUzwUZ/tGxeCLiGtZJlM8wl+et9yVvhJFV8YqygIjNrTSoTw/3UXX\nzeCUZfKi2R6TcN84YZNjc9HHUlmQQUPXQEKx9O39g+T4xmfbnA6VBZk0dA/YycZO3kTpyUCFXlEW\nGLXtI6GVDuV5Ga5Z9J3+0JQukBLbik8mlr7THyI33UtqyuQSVFmQQf/gML0D8VfIdvSHJgwHnQ6V\nBdaiqdjQysWCCr2iLDBqO0ZCKx2W5qe7tmgqrutmivztiTJZioVYnEiY+gTcN+39gzP2zztUFGTQ\n1DNAU09wXGWthY4KvaIsIJzQSmci1mFJbgbt/SEGh5NLHxyJGLomSVHsUOKC66bTH5o04sahsiDx\nEMuOSZKwTYfKggwixnrGVeq6URRlrujoHyQUjowr8Vee707kTW9wiOGImdI6Tk9NIS8jNalFU4lY\n9BX2+oBEQiw7/Mlb9LFRNmrRK4oyZzjiWjIm2mVpniWKJ7qTE/pEo1es1bFJuG78Q1NG3IA1UZuV\nlhI38iYcMXT6QxTPMOLGwXmxwMlNT3AyUKFXlFlgttYIOmkOynJHW69Ri743OT/9yIrVOEKfm1wa\nBMuinzziBqyQRivEcurP1BUIETEjufJnSnl+OiJW+GOyicvmGyr0iuIyD73ZxLnfeiKaXMtNHHEt\nzZ0li75/6jw3DqU56Qn76CMRw8Hmvuh+cChMIBSeNHNlLE6I5VR09E+dsiFRfN4UynLSWZqfjneK\naKCFyOL6NIoyD9hd301r3yBPHmhxvW/Hoi8Z44/OSEshPzM16cgbx3UTL4VAaY6Ptr7EioQ/vr+F\nq374LDuOWeX9Xjpqpb5anUD5w4qCDBoncd3cs6uB1t5gzKrY5Fw3YJVlXL8kN+l+5hsq9IriMifs\nCdEH32x2ve/WPmsx00TpCcrzMmhK0qLvtK3jgjhuldLcdELhSDQL5VQcbu0H4DcvHwcsgc7PTOXi\ntfGz1lYWZNAbHKY3OPo+te1+Pv/71/nRUzUjCc2SnIwF+NcbzuL7H9icdD/zDRV6RXGZJjtK5JlD\nbUkVuJ6I1t7gqBWxsZTnpUeqN2d8AAAgAElEQVRfMony0f/awdfv3xvdb7NXmPq8U68wHYmlj+++\ncVwvD+5ppq4jwKP7Wrhu09K49wCrAAkw7gXmfDt4bF9LtFBKsuGVYKWTcCvVwXxChV5RXKapJ8jy\nokxCwxGedLHEH1jCOtY/71CeN71FU8YYXj7WyR0v1PLI3mYOtfRxz66GaEWlqZjOoqmGrgCFWWmE\nhiP89W92ERqO8J6zKxMaY3l07mH053rZFvoTPUGePdSG1yPkpk/9LeRURoVeUVwkHDE09wa5dmM5\npTk+HnqzydX+W3qDlE1i0S/Nz6A7MMRAKLFFU/2DVlUmEbj1D29wyy92kuXz8t33bYp7rfOySWRC\ntqFrgO2rithclc+exl7WlGYnXMBjqR1NdGLMC2xHbQfnVBfgEXj2cBtF2WmjUkIoo1GhVxQXae0L\nEo4YKvIzuHrDEp462Eog5I77JhwxtPeHKM2d3HUDJGzVt9gi/ZlLVxMIhWnoGuDHN26hLIGC1Ym6\nbiIRQ2PXAJUFGXxw2zIA3nN2ZcLl+Epz0knxyCjXzYnuAeo7B7h6QzlblxdiTPIRN
4udxeeMUpQ5\nxAlvrMjPoKIgg1+8eJzX63vYvmpc2eRp0+G3csBPJsSOm6OpJ8jKBCJanLqv568q5twVRUSMYWt1\nYnVWs3xeOzXy1C+V1j5rJW9VQSbXnbWUnoEhbjh3WUL3ACumvSzHN8qif6XWctucu6KQSMSwo7bT\nlYibxYwKvaK4iCN85fnp5Ng+4yNt/a4IveMmmWwyNurmSDBdcUvfyOKrRF4MY6kqzKS+c+pVq05C\nssqCDHzeFD7+lpXTvk95/uhoopePdZLt87K+PJdsn5d/enD/uHBTZTTqulEUF3EEqTwvg6V56WSm\npVBjhxcmizPxOdlkrGPpJ5qu2HHdTNZfPJYVZlI3RugDoWG+dt8ePve73cBIce9katmOnWTecayT\nrdUFpHiE6uIsbthWFa2fq0yMWvSK4iInegbISkshN92LiLCqJJsjbS4JfRyLPj01haKstGn46INk\n+7wzDidcVpTJkwdbiUQMHo9wqKWPT/1qF0farFKHt16zjvpOayxjk7BNh4r8DB7d14IxVk6bmtZ+\n3r2lInr82+8+c8Z9nyqoRa8oLtLUHaQ8PyM62bi6NNs1i96xwEumyOliVZpKzKJv7R2cdGI3EaoK\nrRBSZ0L26/fvpSswxJevXQdYK2AbugKU5PiSqvxUnpdOaDhChz/E7vpuAM5eVjDj/k5FVOgVxUWa\negZGJcRaXZpNU0/QlYVTrX1BCrPSplxoZFWaStyiL5thzVewXDcAdZ0BjDHsPdHL1RuWcPOFK8lN\n9/LikQ7qOweoSsKaB8tHD9ZL9PWGHjwCGysTC89ULFToFcVFGruDo9LdriqxCoQcdcF909I7OKnb\nxmFpXnriPvq+YFIWfazQN/cG6RkYYv2SHFI8wrYVRbx4tIOG7kBS/nmISdjWM8AbDd2sKc0hM029\nztNBhV5RXGJwOEx7/2A0zBEsix5wxX3T1heMO3Fanp9BX3A47jcIYwwtvYMJxcxPRkV+BiKW0O9v\n6gVgXbmVEGz7qiKOdwRo6BpIulqTk4K5qXuANxp6OFOt+WmjQq8oLtHSY/mqHWECWF6Uhdcjrgh9\nIhZ9dNFUnBDLnoEhQsORuP1NRZrXw9K8DOo7A+xvstIQr12SA8D2lVY4qTHJRdyAVQQlzevhleNd\ndPpDKvQzQIVeUVzCWdSzNMaiT03xsLwoM2mhj0QMbf2D4wqOjCWaGyaO+8aZ2E3GogeoKsygrjPA\ngeY+KvIzovlm1i3JIT/T2k4m4gasAiRL89J5ys4bdGZl/Fw8ymhU6BXFJWIXS8Uy0xDL4XAkut3e\nb62KLY0zeZqoRT9SqSo5oV9mL5o60NTL+vKcaLvHI5y7wlpl60b91fK8DAKhMKkpwrqY+yiJoUKv\nKC7hhDXGWvRg+emPdwQYihHueDy8p5n1X32Ynz5zhO5AiM/c9RoAGyqmLoqxJM8qhzeRRW+M4aWj\nHYSGI5OWJJwuywozae0b5Gi7n3VjCnZct7mCNaXZ0VTDyeC8PNeX5yaU3lgZjU5dK4pLnOgeID8z\nlYy00UK0ujSb4YjheIef1aWJWaOv1nUxFDZ8+6ED3PbEYYbCEW67fjNnL586F01qioeSbN+EFv1D\ne5r561+/yhevXkfErgwV7xtCPJwi2uGIGWdpX7uxnGs3lifVv4Pz8lT//MxQi15RXKKxe2BCN8VM\nIm9q2/2sLs3mO+/ZSFVBJnd8dBvXba6IfyF2bpgxFn1fcIhv/MkqMPLbV+po7gmSm+4d91KaLk6I\nJTDOoncTx6JX//zMUIteUVyioWtgwjqoK4qtWPpj7VMnAIvleEeA6qJMPnDOMj5wTuLZHsGKpT/U\n0jeq7XuPHqK1b5CPXbCC/3z+GA8NNiftn4cRofd5PVQXJe+Ln4xNlfnkpnuj0TzK9Ihr0YtIuojs\nEJHXRWSviHzDbl8hIi+LSI2I/E5E0ux2n71fY
x+vnt2PoChzjzGGhq7AhBEmOemplOT4ONaemEVv\njOF4p5/lRVkzGou1OjYYLdx9tK2fX7xYy1+eu5wvXL2WgsxU2vuTi6F3KMxKIysthdPKcvCmzJ6D\nYENFHm98/aqoq0iZHon8ZQaBy4wxm4DNwNUich7wHeAHxpjVQBdws33+zUCX3f4D+zxFmRe8UttJ\nlz/ker8d/hDBocikoYQrirM41u5PqK/WvkGCQ5EZW8jLizIJhMLRHDS7jncRMfCRC6pJT03hPVus\nMn7JrIp1EBHeesYSrt6wJOm+lNkjrtAbC8cUSbV/DHAZcI/dfifwTnv7Onsf+/jlkmg5GUWZRQaH\nw9z4s5f59G9ejVq7buEUwJ5scdDKaQh9rX3eshla9GPnBGra+klL8bDctoavtys9xebkSYYffGAz\nn750tSt9KbNDQt+1RCRFRHYDrcBjwBGg2xjjrLNuAJyZogqgHsA+3gOMc6yJyC0islNEdra1tSX3\nKRQlAeo7A4TCEV440sH/293oat+NttBXTGHRt/eH6BkYitvXcTvH+0wt+rFCf6S1n+rizKhrZXVp\nNj/90Nl86LzqGfWvLDwSEnpjTNgYsxmoBLYB65K9sTHmdmPMVmPM1pKSkmS7U5S4OJOhxdk+vvnA\nfnoC8UU3UZwCG1MJvTWG+Fb98Q4/Xo+MSo42HUpzfOT4vNFFWjWt/VHxd7jqjCUsccmiV+Y/05o9\nMcZ0A08B24F8EXGidioBx0RqBKoA7ON5QIcro1WUJHAmQ//1hrPoHhjiR0/XuNZ3Q9cAeRmp0RQA\nY1lZ4gh9/AnZ2o4AFQUZM57cFBFW2Xnwg0Nh6joDE0YDKacOiUTdlIhIvr2dAVwJ7McS/Pfap90E\n3Gdv32/vYx9/0rjtEFWUGXCs3U9hVhrbVxVx9vICdh3vcq3vhq7AlBZ4VWEmHoFjbfEt+rqOwIwj\nbhxWlVhCf7wjQMTAqlIV+lOZREyGcuApEXkDeAV4zBjzAPBF4O9FpAbLB/9z+/yfA0V2+98Dt7o/\nbEWZPsfa/VEXymll2Rxu6XNtUraxe2DK5F0+bwqVBZkcjeO6McZQ2+FPOiZ9dWk2rX2DvFZnvcxW\nqUV/ShN3wZQx5g3grAnaj2L568e2B4H3uTI6RXGRY+1+LlpjzQetKc2hNzhMW9/gjItjO1gx9ANc\nuHrquaapQixfreuiOxBic1UBfcHhUStOZ4Ljk39kbzMiKvSnOroyVjkl8A8O09I7GLXo19hCeLi1\nP2mh7woMEQiF46bjXVGcxSu1nRhjiI043nGskw/9/GWGwhE+d8VpAFQn6bpxhP75mg4q8jOSTnWg\nLGw0141ySlDbYVnSjtCvLrOFfkyqgJkQL+LGYVVJ1qiFTAB7Gnu4+Y5XqCzIYE1pDt977BAA1cXJ\nWfRVBRmkpXgIhSPjIm6UUw8VeuWUwHGZOJZySbaPvIxUDrlQ+akxulgqnkVvCe7RmAnZr9y3hyyf\nl1/efC4//dDZ5KZ7EUm+KpM3xTPyUlO3zSmPCr1ySuCsNnUsZRFhTWk2NS3JC328VbEOK0pGx9Ib\nYzjc0s9VZ5SxND+D6uIsfv6Rc7j16nWkpybvallVmmX/VqE/1VGhV+YVD+9p4uofPos/TnHr6XK0\n3c+S3HQy00ampdaUZXOoNbnIGydKJifdS17GxDH0DuW56fi8nmgsfac/RP/g8KhUB+dUF/KJi1fN\neDyxOJa8um4UnYxV5g0nugf4wj1v0Bsc5kBzb9wiG9MhNrTSYXVpDt2Bejr8IYqzE0/wFRwK89i+\nFu5//QS77ILVGyviF8TweGRU5E2dneog2QibybhkXSlPHWzj9PLZyxOvLAxU6JV5QSRi+PzvX2dg\nKAxYy/bdFPradj/XjKl2FI28aelPWOiNMbzr319gf1MvS3LTuWJ9KevLc7lkbWlC168ozuKgPQHs\nCP3yWcrjv
mVZAX/6mwtnpW9lYaFCr8wL7t5ZzwtHOvind23gG3/aN61qTPHoDoToCgyxYkzI4poy\nJ/lXH9tXJVbQoq1vkP1NvXzm0tV87srTSPFMLzHriuIsHtvXwnA4wvGO2bXoFcVBffTKvODBPc2s\nKsnig9uWsbI4iyMJpApIlAPNlgXtCLvDktx0sn1eDk/jpeKcu31V0bRFHiyhH45YC6yOdwQoy/W5\nMvGqKFOhQq/MOYPDYXYc6+CiNSWjEnK5xb4TvQCcvnS0r1pEWF2aPa7s3lQ4cfdrZjjBuTIm8qau\n08/ywuQWRilKIqjQK3POq8e7CQ5FuHB1MWBFi9R3BQja/vpk2d/US3F2GqU541fArluSw4HmxCNv\nDrf2k5vupSRnZtWZorH07X7qOgMsm8U6q4rioEKvzDnP17ST4hHOXWlNvq4qzcaY0QuLkmFfUy/r\nJ4k8OX1pLt2BIZp7gwn1dbi1nzVlOcy0aFpBZip5Gansb+qlpXdQ/fPKSUGFXplznqtpZ3NVPjl2\nLncn/rumLXn3zVA4wuGW/klDDJ0XgOPeiUdNa/+M3TZguYtWlmTx7CGrqtpsRdwoSiwq9Mqc0jMw\nxBsN3Vxgu23A8mOLWCXwkuVIWz+hcGScf95h3ZIcIDGh7+gfpNMfSnoB0orirGi+G7XolZOBCr0y\np7x0tIOIIeqfB0hPTaGqINMVi35/kyXgk7luctJTWV6Uyf7m+ELvRNysKctJakwrYxZuJVtgRFES\nQYVeSZjH97VQZ8d+u8Wzh9rITEthc1X+qPbVpdmuWPT7TvSS5vWMEtexrF+SO6lFPxSO8JuX6wiE\nhkeEPmmL3ro+x+elIHPqtAmK4gYq9EpC1Lb7ueWXO/m3pw671mdwKMwDbzRx6bpS0ryj/ymuLs3m\naLufcCS5ClD7m/pYW5YzZf3V05fmcrwzQP8E+XXufbWRL9/7Jv/n4YPUtPSRlZZCeZJFtZ1UDMuK\nMmc8qaso00GFXkmI2/98lIiBgzPM9tjcE6Sjf3BU28N7mukZGOLGbcvGnb+qJIvQcIT6zpl/gzDG\nsK+pN26ul/XluRgDB8e4b4wx/NcLtYjAnS/W8vj+VlYnEXHj4GTQ1IlY5WShQq/EpbUvyD27Gkjx\nCIdb+ohM08oeCIW59LtPc/Y3H2fbPz3Ofz53DIDf7KijuiiT81aOTz+wdoklzo6PPRGer2nnU7/a\nFY2/b+m1Jk/Xl0/tU3cmase6b3Yc62R/Uy9fvmY9Jdk+GrsHknbbAGSmeXnbmeVcsb4s6b4UJRFU\n6JW43PF8LUPhCDdfuIJAKExj98C0rj/e6WdgKMy7zqpgdWk2//jAPr714H52HOvk+m3L8EyQSmB9\neQ5pKR5213cnfJ/H9rXw0J5m/vmhAwDc9oRVrWlr9dTJ0ZbmpZOXkcq+ptErZO94oZb8zFT+8rzl\nfPUvTgdgbZITsQ4/+uAW3r2l0pW+FCUemtRMmZK+4BC/fOk412xYwltPL+P2Z49yuLWPqmmEBda2\nW+6Xmy9cwdolOXzil7u4/dmjpKYI7z17YrHzeVNYvzSX16Yh9M4L6I4XagmEhrl7ZwOfuXQ1G+Kk\nEBYR1pfnsC/m20Nj9wCP7G3mlresIiMthbdtLCftQx7OSzD5maLMJ9SiV6bkrh119AWH+eTFq6Jh\nhYem6ac/btdrXVaUSWqKh3+/cQtXnl7Gh7dXT5ke+KyqfN5s6GE4HEnoPo1dA5y/qojVpdncvbOB\nS9aW8LkrT0vo2jOW5nGgqZch+16P72shYuD6c6oA62Xw1jOWkJuuUTLKwkOFXpmUweEwP3/uGOev\nKuLMynzyMlIpy/VxqHl6BbVrOwIUZqVFRTI9NYWffXgrX3n76VNet7kqn4GhcMIvlhM9A6wqyebf\nb9zCDduWcdsHzko4w+TmqnwGhyMctD/b7vpuSnJ8OmGqLApU6JVJue+1E7T
0DvLJmNJ2p5XlcKh1\nekJf1+mfkWA6sfWvN8R33/gHh+kODLE0P4PTynL49rs3kjeNGHXnXo6raHd9N5ur8jX8UVkUqNAr\nExKJGH7y7BHOWJrLRWtGVq2eVpZDTWv/tOLba9sDVM9gBejyokwKMlPZXRdf6E/Y/vml+TOLca8s\nyKAoK43ddd10B0Ica/ePW8SlKAsVFXplQt5o7OFom5+bL1wxyqo9rSyb4FDi8e2Dw2FO9AzMKKeL\niLCpKj+hyBtnIrayIGPa93Hutbkqn931XdH7naVCrywSVOiVCXGKcZy1rGBU+8iEbGLum4auAYwZ\nWSQ0XTZX5XOotW/CVauxNEYt+pkJvXOvI21+/ny4HRHYWBm/4LeiLARU6JUJOdLaT1qKh6oxFnK0\noHaCeWiciJuZJu/aXJWPMfBGHD/9ie4BvB6ZsLhIwvdaZlnwv99Zz5rS7GjaZEVZ6KjQKxNS09rP\niuKscTlictJTqcjPSHjFqhNDv3yG6XjPrLTEN14a4RPdQZbkpc+ojuvYe/UGh9U/rywq4gq9iFSJ\nyFMisk9E9orIZ+32QhF5TEQO278L7HYRkX8RkRoReUNEtsz2h1Dc50hb/6R517dWF/B8TXtC8e3H\nO/zk+LwUZqXNaByFWWmU5fpGLWaaiMaugaTcNgB5Gamssmu6bq4qiHO2oiwcErHoh4F/MMacDpwH\nfFpETgduBZ4wxqwBnrD3Aa4B1tg/twA/dn3UyqwSHApT1xmIit5Y3nr6EroCQ+w63hW3r+OdAZYX\nJ5elcd2SXA7EpCcwxvDkgRbe/q9/5ov3vAFYPvrKJIUeRgR+U5X655XFQ1yhN8Y0GWNetbf7gP1A\nBXAdcKd92p3AO+3t64BfGIuXgHwRKXd95ArGGNr7B2nvH2Qg5E4hbYDaDj8RY9VunYiL15aQluLh\nsX0tcfs63hFgeWFyxTXWlVshnc6q1c/9bjcfu2Mnh5r7+cOrDbT3D9LcG0zaoge4bvNSLj6txLWc\nNooyH5iWj15EqoGzgJeBMmNMk32oGXBS8VUA9TGXNdhtY/u6RUR2isjOtra2aQ5bAfjnhw6w9ZuP\ns/Wbj7P9n5+gLzjkSr819kTrZK6bbJ+X81cX8dj+FoyZPJ5+OGyFYSa7unT9klxC4QjH2v10+kPc\n9/oJbthWxd2f3M5wxHDnC7WEI8YVoX/LaSXc+bFtU+avV5SFRsL/mkUkG/gD8HfGmFEOU2P9b59W\n7lpjzO3GmK3GmK0lJSXTuVSxeXRfC2dW5vF3V6yhOzDEQ282u9JvTWs/IrCyePKUvFeeXsbxjsCo\n6BtjDE8daOVv73qNs/7xUTb/42MMR8yMFkvFss5OM7y/qZcXj3RgDLxvaxWbKvNYWZLFnS/UAlAx\nwxh6RVnsJCT0IpKKJfK/Nsb80W5ucVwy9u9Wu70RqIq5vNJuU1ykoSvAsXY/79xcwWcvX8PK4izu\n2dUw7X4efLOJ/3nvm6Ms8yNtfiryM8hIS5n0OieX+qN7R14uv9/VwEfveIVnD7dx+foy3r+1ik9d\nsoq3npFc3vWVxdmkpggHmvt4rqadHJ+XMyvyEBHesWkpvUErxr5ihqtiFWWxk0jUjQA/B/YbY74f\nc+h+4CZ7+ybgvpj2D9vRN+cBPTEuHsUlnq9pB+DCNcWICO85u5IdtZ3Tqukajhi+9eB+fv1yHS8e\n7Yi217ROHnHjUJabzuaq/FF++sf3tVBZkMGOL1/Bd9+3ia/+xel88ep15GfOLOLGIc3rYVVJNgea\nenm+pp3zVhVFXSvv2LQ0ep4brhtFWYwkYtFfAHwIuExEdts/1wL/DFwpIoeBK+x9gAeBo0AN8DPg\nr90ftvJcTQclOb7oAqZ3nVWBCPzh1cSt+if2t9DQNUCKR/jJM0cBS/yPtvWzuiR+JaUr1pfyekMP\nbX2DRCKGl491cv6qonH1X91gfXkuLx/
rpK4zwIWrR3LvrCzJZmNFHvmZqWSmaXkFRZmIuP8zjDHP\nAZPFxl0+wfkG+HSS41KmIBIxvFDTzltOK4mGLS7Nz+CCVcX88bUGPnv5mgmrNo3ljhdqWZqXzgfO\nWcYPHj/E3hM95PhSGRyOxLXoAS5dV8p3Hz3E0wdbWV+eS8/AENtnqTDHuiU53Pua5QG8IEboAb72\nF6fT0DW9qleKciqhoQULkAPNfXT4Q+ME711nVVDfOcCeEz1x+zjY3McLRzr40PZqPnJ+NVlpKXzj\nT/v40r1WXHoiQn96eS5LctN58kArL9mun+0ri+NcNTPW2QW+l+Smj4vv31pdyDvPGhfYpSiKjQr9\nPGayItyOf/6C1aOtZyed8ItHOsZdM5afP3cUn9fD9edUkWfXRd1xrJMjrX4+e/kazl4ef2WoiHDp\nuhL+fLidZw+3s6I4iyV5szMhun6JFXlzwepizRGvKNNEhX6e0tE/yKZvPMrDe0aHTPYFh7hnVwOr\nSrIozxs9+VhqW7uxE6sTsbu+m9/vauDGc5dTYKcm+Pu3nsb9n7mA52+9jM9deVrCYnrp2lL6B4d5\n9lAb562cvXqqJTk+Pv/W0/j4W1bM2j0UZbGiQj9Pea2um77BYX7xYm20LTgU5uO/2MmRtn7+19sm\nLsO3fVURrxzrjK4iHctwOMKX//gmpTk+Pnflmmi7z5vCmZX5004KdsHqYtLsCJjZ8s+D9e3hM5et\nYd2S3Fm7h6IsVlToZ4lwxDAUjiRc2HosbzZafvYXj3bQ2D2AMYa/v3s3Lx3t5Lvv28Sl60onvG77\nymL8oXD0+rH81/O17Gvq5et/cYYraXizfF7OXVkIwHn2b0VR5hcajzYLtPYGufz7z9BnL+T539ed\nwYe2V0+rjz2NPRRnp9HeH+LeVxtYWZLNg2828z+uWjvlxKMjti8e6WDLmKIhd+2o49sP7eeK9aVc\nvWHJ9D7UFHz60tWctawgqVzwiqLMHir0s8DTh9roCw7zibes5M+H2/m3p2p4/zlV+LyTrzQdy5uN\nPVy0poQT3QPcvbOBweEwp5fn8om3rJzyuqJsH2vLcnjpaAefvnR1tP1HT9Xwfx85yCVrS7jt+rNc\nndA8b2XRrPrnFUVJDnXdzALP17RTnO3j1mvWces162jpHeS+104kfH1rb5DWvkE2VOTxnrMrqesM\n0No3yLfevTGhZFvbVxWxs7aL0LDlNuryh/jeowe5ZsMSfvbhrWT59P2uKKcSKvQuY4zh+Zp2Llxd\nhIhw0Zpizliay0+ePTJpuORYHP/6xoo8rt1YTl5GKjdtr0646tF5K4sYGArzWp2VL/65mnYiBj7+\nlpWkalZGRTnl0P/1LnOwpY/2/pHFTCLCJy5exdE2P4/tj5+/HSyhF4EzluaS7fPy7Bcu5atvnzjK\nZiIuXFOMz+vhwTetFEPPHGojLyOVTZVaHk9RTkVU6F3mucPOYqaRFaLXbljCktx0/t9riSXx3NPY\nw8rirKiLJS8jNaGUBg7ZPi+Xry/lv99sYjgc4dlDbVy4ujipeqqKoixcVOhd5oUjHawsyRqVSdGb\n4uHclYW8Wtc1ZaEOhzcbe6KFqmfKOzYtpb0/xB0v1NLaN8jFp2nOf0U5VVGhd5GhcISXjnaMyq7o\nsGVZAS29g5zoCU7ZR2tfkJZeayI2GS5ZW0qOz8v3Hj0EwEWnzU4OGkVR5j+nXPjFkwda+O83rLQC\na5dkc8tbVrnW9+76bgKhMOevmljoAV493kXFJHnT+weH+Ye7XwdgW3Vyi4/SU1N46xlL+MOrDawt\nyxmXLkFRlFOHU8qiHwpHuPUPb/Lo3maeOdTKtx48wJ5JVpDOhB3HOgE4d8V4kV5XnkN6qodX7UiY\nsbT2Bbn+9hd54UgH//e9Z7KxMjmLHuAdm62iHG9Ra15RTmlOKaF/aE8zrX2D/MsNZ/Hk5y8hx+fl\nJ88
cca3/Xce7WF2aHU0UFktqioczK/N5ta573LFj7X7e8+MXONLq5z8+vJX3ba0ad85MuHB1MX97\n+Ro+PM1VuYqiLC5OKaG/4/ljVBdlcvFpJeSmp/LB85bx4JtNHO/wT6uff33iMDf95w76B4ejbZGI\nYWdtJ1unSO+7ZVkB+070EBwKR9sONvfxnh+/gH8wzF23nDdpDpuZkOIR/v7K06gqzHStT0VRFh6n\njNC/Xt/Nq3Xd3HR+dTRU8eYLVuD1ePjZn48m3E9wKMztzx7lmUNtfPzOnVHRrmnrpzc4zNYpfOtb\nluUzFDajEo7d+WItg0Nh/vCp8xNeEKUoijIdThmhv/OFWrLSUnjv2ZXRttLcdN69pYLf72ygJzCU\nUD+P7G2mb3CYG89dxotHO/i73+7GGMMrtZZ/fkqLfvnIhKzD7rpuzlpWwIrirMkuUxRFSYpFJfRD\n4Qi/fLF2nCumyx/igTebePeWynGpea/ftozB4QiPJ7hq9Q+vNlKRn8H/vm4DX7x6HQ/vbeaxfS3s\nqu2iONvH8qLJ3STF2T6WFWayyxb6gVCYgy19askrijKrLCqhv3/3Cb5y314u/94zfO2+PfQGLSv9\nj681EhqOcMO2ZeOu2V9fC88AAAk8SURBVFSZx9K8dB7a0xS3/5beIM8dbuPdWyrweISPX7SCVSVZ\nfPuhA7x8zPLPx8sKef6qIl480sFQOMKbjT2EI0aFXlGUWWVRCf1dO+qoLsrk/edU8auX6/gfv38d\nYwx37ahjc1U+py8dX51IRLhmYznPHmqnLzi1++be1xqJGHj3Fsv9403x8OVr13Os3U9j9wBbq+PX\nWb10XSl9g8O8UtvJ7nrLst+8TIVeUZTZY9EI/aGWPnYe7+LGc5fzrXdt5AtXreWRvS18+d491LT2\n88EJrHmHazcuIRSO8OSB1knPGQ5HuGtHHWcvH+1Pv2xdKdvtXOxTTcQ6XGiX3nvqQCu767upLMig\nONs3jU+qKIoyPRaN0N+1o460FA/vsSdb/+qilWxbUchdO+rI9nl5+6bySa89q6qAslxfNNvjRPzx\ntUaOdwT45MWjV9KKCN9+90Y+dckqNiaQtsApvffkgVZ213Wr20ZRlFlnUQh9cCjMH19t5KoNSyi0\nFyuleITvv38T+ZmpXH9OFZlpk2d78HiEazaU8/TBNvwxsfFd/hDD4QhD4Qj/+uRhNlbkccX68XHu\n1cVZfPHqdQlnh7xsXSlH2vyc6Amq0CuKMussCqH//mOH6BkY4oZto1eUVhZk8twXL+PL166P28fb\nzixncDjCI3utPDitfUEu+M6TXPXDZ/na/Xup7xzgc1eucaUE32Uxi6JU6BVFmW0WvND/+Okj3P7s\nUf7yvGVRX3ks2T5vQrncty4vYFlhJn94tQGA3+9sIBAKEzHwm5fr2FSVz6Vr3Vm1urwoi5UlWXg9\nknSWSkVRlHgs6OyVv91Rx3cePsA7Ni3lH9+xISlrW0R495YKbnviMA1dAX77Sh3bVxbxy5u38cje\nFjZU5LpaUPuWi1ZyoLmP9NTEC4YriqLMhAUt9OvLc3n3WRV8571nTqsC02S8Z0slP3z8MF+45w3q\nOwf4H1etw5vi4W1nTj6RO1OunyIKSFEUxU3ium5E5D9FpFVE9sS0FYrIYyJy2P5dYLeLiPyLiNSI\nyBsismU2B7+pKp/vf2CzawWvqwoz2baikBeOdFCQmcpVZ5S50q+iKMpckohC3gFcPabtVuAJY8wa\n4Al7H+AaYI39cwvwY3eGefJ4r70Y6r1nV+LzqltFUZSFT1yhN8Y8C3SOab4OuNPevhN4Z0z7L4zF\nS0C+iLjv95hF/mLTUm6+cAV/ddHKuR6KoiiKK8zUR19mjHFWFzUDjo+jAqiPOa/Bbhu3EklEbsGy\n+lm2bP74qzPSUvjK20+f62EoiqK4RtLObWOMAcwMrrvdGLPVGLO1p
KQk2WEoiqIokzBToW9xXDL2\nbydJTCMQu2qp0m5TFEVR5oiZCv39wE329k3AfTHtH7ajb84DemJcPIqiKMocENdHLyJ3AZcAxSLS\nAHwN+GfgbhG5GTgOvN8+/UHgWqAGCAAfnYUxK4qiKNMgrtAbY26Y5NDlE5xrgE8nOyhFURTFPRZ8\nrhtFURRlalToFUVRFjkq9IqiKIscsdzqczwIkTasSd2ZUAy0uzic2WShjHWhjBN0rLPBQhknLJyx\nztY4lxtj4i5EmhdCnwwistMYs3Wux5EIC2WsC2WcoGOdDRbKOGHhjHWux6muG0VRlEWOCr2iKMoi\nZzEI/e1zPYBpsFDGulDGCTrW2WChjBMWzljndJwL3kevKIqiTM1isOgVRVGUKVChVxRFWeQsaKEX\nkatF5KBdo/bW+FecHESkSkSeEpF9IrJXRD5rt09Ya3c+ICIpIvKaiDxg768QkZftZ/s7EUmbB2PM\nF5F7ROSAiOwXke3z9ZmKyOfsv/0eEblLRNLnyzOdz3WgExjn/7X//m+IyL0ikh9z7Ev2OA+KyFUn\na5yTjTXm2D+IiBGRYnv/pD/TBSv0IpIC/AirTu3pwA0iMl9KQw0D/2CMOR04D/i0PbbJau3OBz4L\n7I/Z/w7wA2PMaqALuHlORjWa24CHjTHrgE1Y4513z1REKoC/BbYaYzYAKcD1zJ9negcLow70HYwf\n52PABmPMmcAh4EsA9v+v64Ez7Gv+3daIk8UdjB8rIlIFvBWoi2k++c/UGLMgf4DtwCMx+18CvjTX\n45pkrPcBVwIHgXK7rRw4ONdjs8dSifWf+zLgAUCwVvF5J3rWczTGPOAYdgBBTPu8e6aMlNQsxMoQ\n+wBw1Xx6pkA1sCfecwR+Ctww0XlzMc4xx94F/NreHvX/H3gE2D6Xz9RuuwfLKKkFiufqmS5Yi57J\n69POK0SkGjgLeJnJa+3ONT8EvgBE7P0ioNsYM2zvz4dnuwJoA/7LdjH9h4hkMQ+fqTGmEfgulhXX\nBPQAu5h/zzSW6daBng98DHjI3p534xSR64BGY8zrYw6d9LEuZKGf94hINvAH4O+MMb2xx4z1Kp/z\n2FYReTvQaozZNddjiYMX2AL82BhzFuBnjJtmHj3TAuD/t3f2rlFEURT/3UIXYqMWIpIiKmIrVoIW\nghYaJDYWQhrBv0IXBP8BwUKwsZKgoAQJln7UURE14gdGDLiFYG+T4ljct7gsidiY93Y4P1jYnZni\ncJh3HnPvsPc8uTntA3awwWN9q7Ti49+IiD5ZIl2orWUjImIKuApcq60FJjvom55PGxHbyJBfkLRY\nDm82a7cmx4G5iFgD7pPlm5vAzogYDqZpwdsBMJC0XH4/JIO/RU9PA98k/ZS0DiySPrfm6SgTMwc6\nIi4B54D5silBezoPkhv927K2poHXEbGXClonOehfAofKmwzbyUbMUmVNQHbVgTvAR0k3Rk5tNmu3\nGpKuSJqWNEN6+EzSPPAcuFAuq65V0g/ge0QcLodOAR9o0FOyZHMsIqbKvTDU2pSnY0zEHOiIOEOW\nGeck/Ro5tQRcjIheROwnG50vamgEkLQiaY+kmbK2BsDRch9vvadb2az4D82PWbLz/hXo19YzousE\n+ej7DnhTPrNk7fsp8AV4AuyurXVM90ngcfl+gFwoq8ADoNeAviPAq+LrI2BXq54C14FPwHvgLtBr\nxVPgHtk7WCcD6PJmPpKN+Vtlja2QbxLV1LlK1reH6+r2yPX9ovMzcLa2p2Pn1/jTjN1yT/0XCMYY\n03EmuXRjjDHmH3DQG2NMx3HQG2NMx3HQG2NMx3HQG2NMx3HQG2NMx3HQG2NMx/kN4MlmEVmSNk4A\nAAAASUVORK5CYII=\n", 40 | "text/plain": [ 41 | "" 42 | ] 43 | }, 44 | "metadata": {}, 45 | 
"output_type": "display_data" 46 | } 47 | ], 48 | "source": [ 49 | "# 获取输入数据,并可视化\n", 50 | "\n", 51 | "import pandas as pd\n", 52 | "import numpy as np\n", 53 | "import matplotlib.pyplot as plt\n", 54 | "% matplotlib inline\n", 55 | "\n", 56 | "data = pd.read_csv('data.csv', usecols=[1])\n", 57 | "plt.title('International airline passengers')\n", 58 | "plt.plot(data)\n", 59 | "plt.show()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "---\n", 67 | "## 目标\n", 68 | "给定前三个月的客流量预测当前月的飞机客流量。\n", 69 | "\n", 70 | "以1949年到1958年的数据作为训练集, 1959和1960年的数据用作测试集。\n", 71 | "## 方法\n", 72 | "1. 模型使用某三个月的数据作为输入,获得第四个月的客流量\n", 73 | "2. 以此类推进行测试\n", 74 | "\n", 75 | "## 比较\n", 76 | "创建三个模型进行结果比较 : `FC`, `RNN`, `LSTM`, `GRU1`, `GRU2`\n", 77 | "\n", 78 | "`GRU2`比`GRU1`多一般的状态数\n", 79 | "\n", 80 | "分别从预测准确性,训练时间两方面分析。" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "---\n", 88 | "**以下部分为数据预处理的代码,会在其他几个模型中动态生成\\*.py文件并引用**" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 3, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# 定义numpy转化为Tensor的函数\n", 98 | "toTs = lambda x : torch.from_numpy(x)\n", 99 | "\n", 100 | "# 定义检测cuda的函数\n", 101 | "cudAvl = lambda x : x.cuda() if torch.cuda.is_available() else x" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "dataSet shape :\t (144, 1)\n", 114 | "train data shape : (120, 1)\n", 115 | "real data shape : (144, 1)\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "# 数据预处理\n", 121 | "data = data.dropna()\n", 122 | "dataSet = data.values\n", 123 | "dataSet = dataSet.astype('float32')\n", 124 | "print('dataSet shape :\\t', dataSet.shape)\n", 125 | "\n", 126 | "# 数据归一化\n", 127 | "def MinMaxScaler(X) :\n", 128 | " mx, mi = np.max(X), np.min(X)\n", 129 | " X_std = (X - 
mi) / (mx - mi)\n", 130 | " return X_std\n", 131 | "\n", 132 | "dataSet = MinMaxScaler(dataSet)\n", 133 | "\n", 134 | "# 将数据分为训练集和测试集\n", 135 | "train = dataSet[:12*10]\n", 136 | "real = dataSet\n", 137 | "print('train data shape :', train.shape)\n", 138 | "print('real data shape :', real.shape)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 5, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "input_size = 3" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python 3", 154 | "language": "python", 155 | "name": "python3" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.5.2" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 2 172 | } 173 | -------------------------------------------------------------------------------- /Chapter_5/sequencePrediction/seqInit.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # # 序列预测 5 | # 6 | # --- 7 | 8 | # In[1]: 9 | 10 | 11 | import torch 12 | from torch import nn, optim 13 | from torch.autograd import Variable 14 | from torch.nn import init 15 | 16 | 17 | # ## 已知 18 | # 1949年到1960年每一个月的飞机客流量 19 | 20 | # In[2]: 21 | 22 | 23 | # 获取输入数据,并可视化 24 | 25 | import pandas as pd 26 | import numpy as np 27 | import matplotlib.pyplot as plt 28 | get_ipython().run_line_magic('matplotlib', 'inline') 29 | 30 | data = pd.read_csv('data.csv', usecols=[1]) 31 | plt.title('International airline passengers') 32 | plt.plot(data) 33 | plt.show() 34 | 35 | 36 | # --- 37 | # ## 目标 38 | # 给定前三个月的客流量预测当前月的飞机客流量。 39 | # 40 | # 以1949年到1958年的数据作为训练集, 1959和1960年的数据用作测试集。 41 | # ## 方法 42 | # 1. 模型使用某三个月的数据作为输入,获得第四个月的客流量 43 | # 2. 
以此类推进行测试 44 | # 45 | # ## 比较 46 | # 创建三个模型进行结果比较 : `FC`, `RNN`, `LSTM`, `GRU1`, `GRU2` 47 | # 48 | # `GRU2`比`GRU1`多一般的状态数 49 | # 50 | # 分别从预测准确性,训练时间两方面分析。 51 | 52 | # --- 53 | # **以下部分为数据预处理的代码,会在其他几个模型中动态生成\*.py文件并引用** 54 | 55 | # In[3]: 56 | 57 | 58 | # 定义numpy转化为Tensor的函数 59 | toTs = lambda x : torch.from_numpy(x) 60 | 61 | # 定义检测cuda的函数 62 | cudAvl = lambda x : x.cuda() if torch.cuda.is_available() else x 63 | 64 | 65 | # In[4]: 66 | 67 | 68 | # 数据预处理 69 | data = data.dropna() 70 | dataSet = data.values 71 | dataSet = dataSet.astype('float32') 72 | print('dataSet shape :\t', dataSet.shape) 73 | 74 | # 数据归一化 75 | def MinMaxScaler(X) : 76 | mx, mi = np.max(X), np.min(X) 77 | X_std = (X - mi) / (mx - mi) 78 | return X_std 79 | 80 | dataSet = MinMaxScaler(dataSet) 81 | 82 | # 将数据分为训练集和测试集 83 | train = dataSet[:12*10] 84 | real = dataSet 85 | print('train data shape :', train.shape) 86 | print('real data shape :', real.shape) 87 | 88 | 89 | # In[5]: 90 | 91 | 92 | input_size = 3 93 | 94 | -------------------------------------------------------------------------------- /Chapter_6/autoEncoder.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# AutoEncoder\n", 8 | "---\n", 9 | "自动编码器的主要应用有两个方面,第一是**数据去噪**,第二是进行**可视化降维**。自动编码器还有一个功能,即**生成数据**。\n", 10 | "\n", 11 | "编码器和解码器可以是任意模型,通常使用神经网络模型作为编码器和解码器。\n", 12 | "\n", 13 | "输入的数据经过神经网络降维到一个编码(code),接着通过另外一个网络剧解码得到一个和输入原数据一模一样的生成数据,然后通过比较这两个数据,最小化它们之间的差异来训练这个网络中编码器和解码器的参数。\n", 14 | "\n", 15 | "当这个过程训练完后,拿出这个解码器,随机传入一个编码(code),通过解码器能够生成一个和原数据差不多的数据," 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "collapsed": true 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import torch\n", 27 | "import torch.nn as nn\n", 28 | "from torch.autograd import Variable\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "%matplotlib inline" 31 | ] 32 | }, 33 
| { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": { 37 | "collapsed": true 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "cudAvl = lambda x : x.cuda() if torch.cuda.is_available() else x" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": { 48 | "collapsed": true 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "class autoencoder(nn.Module):\n", 53 | " def __init__(self):\n", 54 | " super().__init__()\n", 55 | " self.encoder = nn.Sequential(\n", 56 | " nn.Linear(28*28, 128),\n", 57 | " nn.ReLU(True),\n", 58 | " nn.Linear(128, 64),\n", 59 | " nn.ReLU(True),\n", 60 | " nn.Linear(64, 12),\n", 61 | " nn.ReLU(True),\n", 62 | " nn.Linear(12, 3)\n", 63 | " )\n", 64 | " \n", 65 | " self.decoder = nn.Sequential(\n", 66 | " nn.Linear(3, 12),\n", 67 | " nn.ReLU(True),\n", 68 | " nn.Linear(12, 64),\n", 69 | " nn.ReLU(True),\n", 70 | " nn.Linear(64, 128),\n", 71 | " nn.ReLU(True),\n", 72 | " nn.Linear(128, 28*28),\n", 73 | " nn.Tanh()\n", 74 | " )\n", 75 | " \n", 76 | " def forward(self, x):\n", 77 | " code = self.encoder(x)\n", 78 | " imgTensor = self.decoder(code)\n", 79 | " return code, imgTensor\n", 80 | " \n", 81 | " def encode(self, x):\n", 82 | " return self.encoder(x)\n", 83 | " \n", 84 | " def decode(self, x):\n", 85 | " return self.decoder(x)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": { 92 | "collapsed": true 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "net = autoencoder()\n", 97 | "net = cudAvl(net)\n", 98 | "criterion = nn.MSELoss()\n", 99 | "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 5, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "42000 784\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "import csv\n", 117 | "\n", 118 | "with open('train.csv') as f :\n", 119 | " 
lines = csv.reader(f)\n", 120 | " label, attr = [], []\n", 121 | " for line in lines :\n", 122 | " if lines.line_num == 1 :\n", 123 | " continue\n", 124 | " label.append(int(line[0]))\n", 125 | " attr.append([float(j) for j in line[1:]])\n", 126 | "print(len(label), len(attr[1]))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 6, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "torch.Size([42000, 784])" 138 | ] 139 | }, 140 | "execution_count": 6, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "digTensor = torch.FloatTensor(attr)\n", 147 | "digTensor = digTensor / 255.0\n", 148 | "digTensor.shape" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "epoch [100/1500]: 0.058348119258880615\n", 161 | "epoch [200/1500]: 0.047589220106601715\n", 162 | "epoch [300/1500]: 0.042967937886714935\n", 163 | "epoch [400/1500]: 0.040192779153585434\n", 164 | "epoch [500/1500]: 0.038630422204732895\n", 165 | "epoch [600/1500]: 0.03767261654138565\n", 166 | "epoch [700/1500]: 0.037026435136795044\n", 167 | "epoch [800/1500]: 0.036425597965717316\n", 168 | "epoch [900/1500]: 0.03603523224592209\n", 169 | "epoch [1000/1500]: 0.0356963612139225\n", 170 | "epoch [1100/1500]: 0.035425879061222076\n", 171 | "epoch [1200/1500]: 0.035187773406505585\n", 172 | "epoch [1300/1500]: 0.034891996532678604\n", 173 | "epoch [1400/1500]: 0.03473861888051033\n", 174 | "epoch [1500/1500]: 0.034577593207359314\n" 175 | ] 176 | }, 177 | { 178 | "data": { 179 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xt8XHWd//HXZ66ZXJombUrvtIVyKYqIBUEF0SqgP1f0\nt6hFH4ouKz/1x89dxVVYd72gu17WVdxdWGUFb4uI4mULooCKV1hsYKFAoTcobXpNmzZt7pnM5/fH\nOQnTdJJM2iQzPfN+Ph55ZOac75z5zEnmPWe+53vOMXdHREQqQ6zUBYiIyNRR6IuIVBCFvohIBVHo\ni4hUEIW+iEgFUeiLiFQQhX6FMrPNZvaaKXqujJndaWbtZvbDqXjOvOd+0swumMrnFClnCn2ZCpcC\nxwEz3P0tk/UkZvYtM/ts/jR3P83dfzNZz3kkJvoDdyo/wOXYp9CXqXA8sN7ds6UuRI6cmSVKXYMc\nPYW+YGZpM7vezLaHP9ebWTqcN9PM7jKz/WbWZma/N7NYOO9jZrbNzA6a2TozW1Fg2Z8GPgG8zcw6\nzOwKM/uUmf1nXptFZuaDoWJmvzGzz5jZH8Nl32tmM/Pav8LMHghr2mpm7zazK4F3AB8Nn+fOsO3Q\nVvAYr/MCM2sxs6vNbLeZ7TCz94yyzuaa2apwnWw0s/fmzTvkG8fgssPb3wUWAneGdX407/VfGda1\nw8yuPtLljVDvJWb2qJkdMLNNZnbx8PUT3h/62+TVdYWZbQF+bWa/MLOrhi37MTP73+HtU8zsvnC9\nrDOzt460DqU0FPoC8HHgHOAM4EXA2cDfhfOuBlqAJoIumr8F3MxOBq4CznL3OuAiYPPwBbv7J4F/\nBG5391p3v7nImt4OvAeYBaSAjwCY2ULg58C/hjWdATzq7jcBtwJfDJ/nz8b5OgFmA/XAPOAK4AYz\naxihvtsI1stcgu6rfyz0oTecu78T2AL8WVjnF/NmvwpYClwIXFNMl80YywPAzM4GvgP8DTAdOJ8C\nf6tRvBI4leBv/D3gsrxlLyP4JvczM6sB7gvbzArb3Whmp43juWSSKfQFgi3k69x9t7u3Ap8G3hnO\n6wfmAMe7e7+7/96DEzYNAGlgmZkl3X2zu2+awJq+6e7r3b0b+AFBUA/W+kt3vy2sZ6+7P1rkMkd7\nnRC81uvC5d4NdAAnD1+ImS0AXgF8zN17wuf/xrBlHYlPu3unuz8OfJO8cD1KVwC3uPt97p5z923u\n/vQ4Hv+psK5u4CfAGWZ2fDjvHcCP3b0XeAOw2d2/6e5Zd38E+BHBh6KUCYW+QLC1+lze/efCaQD/\nBGwE7jWzZ8zsGgB33wj8NfApYLeZfd/M5jJxdubd7gJqw9sLgCP9cBntdQLsHbbfIf95hy+nzd0P\nDlvWvCOsa9DWUWo7GkezziCvrvA1/wxYGU5aSfANC4It/peG3W77zWw/wYfC7KN4bplgCn0B2E7w\nhh20MJyGux9096vdfQnwZ8CHB7sx3P177v6K8LEOfKHI5+sEqvPujycUtgInjDBvrFPGjvg6x2k7\n0GhmdcOWtS28PdbrG6nOBSPUdqTLGzTaOivmbzF8+bcBl5nZuUAGuD/veX7r7tPzfmrd/f1j1CdT\nSKEvELyJ/87MmsIdpp8ABnfmvcHMTjQzAw4QdOsMmNnJZvbqcEdoD9AdzivGo8D5ZrbQzOqBa8dR\n663Aa8zsrWaWMLMZZjbY9bMLWHIkr3M83H0r8ADwOTOrMrPTCbpQBrd4HwVeb2aNZjab4BtRvpHq\n/Hszqw77wN8D3H6Uyxt0M/AeM1thZjEzm2dmp+Qte6WZJc1sOcV1xdxN8OF5HcG+mlw4/S7gJDN7\nZ7i8pJmdZWanFrFMmSIKfQH4LNAMrAEeBx4Jp0GwY/GXBP3bDwI3huPe08DngT0EXTGzCHbyjsnd\n7yMItDXAwwRhURR33wK8nmAHcxtBaL0onH0zwT6G/Wb203G+z
vG6DFhEsDX+E+CT4esC+C7wGMHO\n0nt5PrwHfY7gw2e/mX0kb/pvCbrSfgV8yd3vPcrlAeDufyL4EPkK0B4+z+A3nr8n+Bawj2Afx/fG\neuFh//2Pgdfktw+7fi4k6PLZTvB/8QWC/xUpE6aLqIiUlpktAp4FkjqWQSabtvRFRCqIQl9EpIKo\ne0dEpIJoS19EpIIUdQKl8DwdXwXiwDfc/fPD5p8PXA+cDqx09zvC6WcA/w5MIxjO9w/uPnzkwSFm\nzpzpixYtGufLEBGpbA8//PAed28aq92YoW9mceAG4LUE5xpZbWar3H1tXrMtwLsJz4+Spwt4l7tv\nCI/WfNjM7nH3/SM936JFi2hubh6rLBERyWNmz43dqrgt/bOBje7+TLjg7wOXAEOh7+6bw3m5/Ae6\n+/q829vNbDfBSbJGDH0REZk8xfTpz+PQc4K0cATnGAnP9JeiwDlAwlPKNptZc2tr63gXLSIiRSom\n9K3AtHEN+TGzOQRHFb4n75Dt5xfmfpO7L3f35U1NY3ZJiYjIESom9Fs49ERQ8xnHSarMbBrBWfn+\nzt3/e3zliYjIRCom9FcDS81ssZmlCM6rsaqYhYftfwJ8x92n9ILYIiJyuDFDPzwXyFXAPcBTwA/c\n/Ukzu87M3ggQnkmvBXgL8HUzezJ8+FsJrtLz7vBSbY/mnRFRRESmWNkdkbt8+XLXkE0RkfExs4fd\nfflY7SJzRG5Hb5Yv37ee/9myr9SliIiUrciEfl82x7/8agOPbtUhACIiI4lM6GeScQB6+g8bESoi\nIqHIhH46EbyUnv5ir9gnIlJ5IhP6sZiRTsQU+iIio4hM6ANUJeMKfRGRUUQq9DPJON0KfRGREUUq\n9KuSMe3IFREZRcRCX1v6IiKjiVzoq09fRGRkEQt9jd4RERlNpEI/k4yrT19EZBSRCn1174iIjC5S\noa8hmyIio4tU6KfVvSMiMqpIhX5G3TsiIqOKVOhr9I6IyOgiFfqZZJxszukfUBePiEghkQr9qqFz\n6mtrX0SkkGiFfioIfY3gEREpLFqhH15IpVcjeERECopW6Kt7R0RkVJEK/cHr5Kp7R0SksEiF/uCW\nfnefQl9EpJBIhX4mFV4cPas+fRGRQiIV+trSFxEZXSRDvzer0BcRKSRSoZ/Rlr6IyKgiGfoasiki\nUlikQn+oT18HZ4mIFBSp0E+HR+RqnL6ISGGRCv1YzKhKxuhV6IuIFBSp0Iegi0db+iIihUUu9DPJ\nuEbviIiMIHKhX5WM64hcEZERRDL0taUvIlJY5EI/o+vkioiMKHKhX5WMK/RFREYQudDPaPSOiMiI\nigp9M7vYzNaZ2UYzu6bA/PPN7BEzy5rZpcPmXW5mG8Kfyyeq8JFUpbSlLyIykjFD38ziwA3A64Bl\nwGVmtmxYsy3Au4HvDXtsI/BJ4KXA2cAnzazh6MseWVUiTo9OwyAiUlAxW/pnAxvd/Rl37wO+D1yS\n38DdN7v7GmB42l4E3Ofube6+D7gPuHgC6h5RJhVT946IyAiKCf15wNa8+y3htGIU9Vgzu9LMms2s\nubW1tchFFxZs6Sv0RUQKKSb0rcA0L3L5RT3W3W9y9+XuvrypqanIRReWSQU7ct2LLVFEpHIUE/ot\nwIK8+/OB7UUu/2gee0SqknHcoVdH5YqIHKaY0F8NLDWzxWaWAlYCq4pc/j3AhWbWEO7AvTCcNmmG\nLpmonbkiIocZM/TdPQtcRRDWTwE/cPcnzew6M3sjgJmdZWYtwFuAr5vZk+Fj24DPEHxwrAauC6dN\nmqFLJqpfX0TkMIliGrn73cDdw6Z9Iu/2aoKum0KPvQW45ShqHJdMShdSEREZSeSOyK1K6Dq5IiIj\niV7op9S9IyIyksiF/mCfvrb0RUQOF7nQr1Loi4iMKHKhPzR6p09DNkVEhotc6Fclg5ekLX0RkcNF\nLvQ1Tl9EZGSRC/3B0Tva0
hcROVz0Ql/j9EVERhS50E/GjXjM1L0jIlJA5ELfzILr5Gr0jojIYSIX\n+hCM4OnJaktfRGS4iIZ+nJ4+hb6IyHCRDX316YuIHC6SoZ9J6jq5IiKFRDL0q5IxenTlLBGRw0Q0\n9NW9IyJSSGRDX907IiKHU+iLiFSQSIZ+Rn36IiIFRTL0q5JxHZwlIlJAJEM/OA2DQl9EZLhIhn46\nGac3myOX81KXIiJSViIZ+oNXz+rNql9fRCRfJEM/o4uji4gUFMnQrxoMfe3MFRE5RCRDf+g6udqZ\nKyJyiEiG/mCfvsbqi4gcKpKhnx7c0lefvojIISIZ+oPdO70KfRGRQ0Qy9LUjV0SksEiG/vM7ctWn\nLyKSL5Kh//yOXG3pi4jki2joa0euiEghkQ59bemLiBwqoqGvc++IiBQSydBPxWPETEfkiogMF8nQ\nNzNdMlFEpIBIhj4E/frakSsicqiiQt/MLjazdWa20cyuKTA/bWa3h/MfMrNF4fSkmX3bzB43s6fM\n7NqJLX9kmWRc594RERlmzNA3szhwA/A6YBlwmZktG9bsCmCfu58IfAX4Qjj9LUDa3V8IvAT4P4Mf\nCJMtnYzpiFwRkWGK2dI/G9jo7s+4ex/wfeCSYW0uAb4d3r4DWGFmBjhQY2YJIAP0AQcmpPIxZJJx\nerQjV0TkEMWE/jxga979lnBawTbungXagRkEHwCdwA5gC/Ald28b/gRmdqWZNZtZc2tr67hfRCFV\nybi29EVEhikm9K3AtOFXHB+pzdnAADAXWAxcbWZLDmvofpO7L3f35U1NTUWUNLZMMq4hmyIiwxQT\n+i3Agrz784HtI7UJu3LqgTbg7cAv3L3f3XcDfwSWH23RxahKxrQjV0RkmGJCfzWw1MwWm1kKWAms\nGtZmFXB5ePtS4Nfu7gRdOq+2QA1wDvD0xJQ+urTG6YuIHGbM0A/76K8C7gGeAn7g7k+a2XVm9saw\n2c3ADDPbCHwYGBzWeQNQCzxB8OHxTXdfM8GvoaCMQl9E5DCJYhq5+93A3cOmfSLvdg/B8Mzhj+so\nNH0qVCVj9OjcOyIih4jsEbnakSsicrjIhv7gkM1g14KIiEDEQ99dp1cWEckX6dAHXUhFRCRfZEO/\nJhWEfpf69UVEhkQ39NPBwKTO3myJKxERKR+RDf3aMPQPKvRFRIZENvS1pS8icrgIh37Qp6/QFxF5\nXmRDf7B7p6NXO3JFRAZFNvTVvSMicrjIhv7zW/oKfRGRQZEN/apknNp0gj0dvaUuRUSkbEQ29AGa\n6tLsPqjQFxEZFPnQbz2g0BcRGRT50N91sKfUZYiIlI1Ih/6SmTW07OvWSddEREKRDv3T5k5jIOc8\nvfNgqUsRESkLkQ79ZXPqAXhye3uJKxERKQ+RDv0FjRlqUnE27OoodSkiImUh0qFvZixpqmVTq0Jf\nRAQiHvoAS5pqeKa1s9RliIiUhciH/glNtWzb3023rqAlIhL90H/hvGBn7urNbSWuRESk9CIf+ucs\nmUFtOsFXf7VB4/VFpOJFPvQzqTif//MX8siWfXzu7qdKXY6ISElFPvQB3nD6XC47eyG3PrSFrW1d\npS5HRKRkKiL0Af7fq08kZsaNv9lY6lJEREqmYkJ/Tn2GlWcv4IfNLWzcrdMyiEhlqpjQB/irFUvJ\npOJcd9dTuHupyxERmXIVFfozatN86DUn8bv1rdy7dlepyxERmXIVFfoA7zr3eE6ZXcd1d66lq0/X\nzxWRylJxoZ+Ix/jMm17A9vZu/uFnGsIpIpWl4kIf4KxFjbz3vCXc+tAWfv20unlEpHJUZOgDXH3h\nSZwyu46P3vE4ezt0HV0RqQwVG/rpRJzrV57Bge5+rvnx4xrNIyIVoWJDH+CU2dP46MUnc9/aXdy+\nemupyxERmXQVHfoAf/HyxbzshBlcd9dant2j8+6LSLRVfOjHYsY/v/VFJOMxPnjb/9CXzZW
6JBGR\nSVNU6JvZxWa2zsw2mtk1Beanzez2cP5DZrYob97pZvagmT1pZo+bWdXElT8x5tRn+MKfn87j29r5\n5/vWlbocEZFJM2bom1kcuAF4HbAMuMzMlg1rdgWwz91PBL4CfCF8bAL4T+B97n4acAHQP2HVT6CL\nXzCby85ewH/87hke27q/1OWIiEyKYrb0zwY2uvsz7t4HfB+4ZFibS4Bvh7fvAFaYmQEXAmvc/TEA\nd9/r7mV7JZNrX38qTXVpPvajNfQPqJtHRKKnmNCfB+QPbWkJpxVs4+5ZoB2YAZwEuJndY2aPmNlH\nCz2BmV1pZs1m1tza2jre1zBhplUlue6SF/D0zoP8x++fKVkdIiKTpZjQtwLThg9qH6lNAngF8I7w\n95vNbMVhDd1vcvfl7r68qampiJImz0Wnzebi02bz1V9uYLNG84hIxBQT+i3Agrz784HtI7UJ+/Hr\ngbZw+m/dfY+7dwF3A2cebdGT7dOXnEYqEeNvf/I4uZwO2hKR6Cgm9FcDS81ssZmlgJXAqmFtVgGX\nh7cvBX7twSGu9wCnm1l1+GHwSmDtxJQ+eY6bVsW1rzuVBzbt5Yb7daUtEYmOxFgN3D1rZlcRBHgc\nuMXdnzSz64Bmd18F3Ax818w2Emzhrwwfu8/MvkzwweHA3e7+s0l6LRPqsrMX8Kdn9/LlX65n2dxp\nrDj1uFKXJCJy1KzczjmzfPlyb25uLnUZAPT0D3Dp1x7guT1d/PSql3NCU22pSxIRKcjMHnb35WO1\nq/gjckdTlYzz9XcuJ5WIccW3VtPW2VfqkkREjopCfwzzpme46V0vYXt7D1d+p5me/rI9zEBEZEwK\n/SK85PhGvvLWM2h+bh9/c8cajegRkWOWQr9I/+v0OXzs4lO487HtOj+PiByzxhy9I8973yuXsKWt\nixvu38S86dW8/aULS12SiMi4aEt/HMyMz1xyGhec3MTf/9cTPKoTs4nIMUahP06JeIyvrnwxx9Wl\n+dDtj9LVly11SSIiRVPoH4H6TJIvv+0MNu/t5Ku/2lDqckREiqbQP0LnLJnBm86Yx7f+uJmd7T2l\nLkdEpCgK/aPw4deeRM6dr/5qfalLEREpikL/KCxorOYdLz2eHzS36KLqInJMUOgfpQ+86gSSceP6\nX2prX0TKn0L/KM2qq+LdL1vMqse2s27nwVKXIyIyKoX+BHjfK5dQm0rwz/fqSF0RKW8K/QkwvTrF\nX563hHvX7qJ5c1upyxERGZFCf4Jccd5i5k3P8JEfPqYDtkSkbCn0J0htOsGX3vIinmvr4tOr1lJu\nF6cREQGF/oQ694QZfOCCE7i9eSu3/HFzqcsRETmMzrI5wa5+7cls2t3JZ3+2lrp0greetaDUJYmI\nDNGW/gSLxYzrV57B+Uub+OiP1vDtBzaXuiQRkSEK/UlQlYxz07tewoXLjuOTq57ki794WlfbEpGy\noNCfJOlEnBvecSZvf+lCbvzNJq667RFdX1dESk6hP4mS8Rj/8KYX8PHXn8rPn9jJm298QOfoEZGS\nUuhPMjPjvecv4RvvWs6uAz1c8m9/4OeP7yh1WSJSoRT6U2TFqcfx4/e/jMUza3j/rY/wodsfZW9H\nb6nLEpEKo9CfQotm1nDH+1/GB1cs5c7HtrPiy7/lW398lv6BXKlLE5EKodCfYsl4jA+/9iR+9sHz\nWDZnGp+6cy0XfeV33Ld2l47iFZFJp9AvkZNn13HrX76Umy9fjhm89zvNvP0/HuKJbe2lLk1EIkyh\nX0JmxopTj+MXf30+111yGut2HeQN//oHPnDrw2zYpXPzi8jEs3LrUli+fLk3NzeXuoySaO/u5+bf\nP8PNf3iWrv4BXnXyLK6+8CROm1tf6tJEpMyZ2cPuvnzMdgr98tPW2cctf3iW2/60hX1dfbzpxfP4\n4KuXsmhmTalLE5EypdCPgPaufv711xv47n8/RzbnXHr
mfK569YksaKwudWkiUmYU+hGy+0APN/5m\nE9/70xZyOeetZy3gqledyNzpmVKXJiJlQqEfQTvau7nx/k18f/UWDOOCk5v4i1cs5pwlM0pdmoiU\nmEI/wrbt7+am327izjU7aOvs49WnzOL/vuoEXnJ8Y6lLE5ESUehXgPbufm763SZufWgL+7v6edH8\net557iLecPocqpLxUpcnIlNIoV9Buvqy3PFwC99+YDObWjupSyc4bd40rjx/CectbSIZ1+EYIlGn\n0K9A7s6Dm/Zy55rt3PvkLvZ29tFYk2LFKbNYcepxvPzEGdRVJUtdpohMAoV+hevpH+B361u5a80O\n7n96Nwd7syRixpnHN/CKE2dyyRlzWdhYjZmVulQRmQATGvpmdjHwVSAOfMPdPz9sfhr4DvASYC/w\nNnffnDd/IbAW+JS7f2m051LoT7z+gRwPP7eP365v5Q8b9vDE9nbcYW59Fa88uYnzljbx8hNnUp/R\ntwCRY9WEhb6ZxYH1wGuBFmA1cJm7r81r8wHgdHd/n5mtBN7s7m/Lm/8jIAc8pNAvvS17u/jthlb+\nsKGVBzbu5WBvlphBY02axpok7zzneFacepyOAxA5hkxk6J9LsIV+UXj/WgB3/1xem3vCNg+aWQLY\nCTS5u5vZm4CXA51Ah0K/vPQP5Hh0635+v76VDbs7+M26VrrDa/kunlnDeUtncubCBl4wbxqLZ9YS\nj6k7SKQcFRv6iSKWNQ/Ymne/BXjpSG3cPWtm7cAMM+sGPkbwLeEjoxR7JXAlwMKFC4soSSZKMh7j\nrEWNnLUoGOOfyzmPtuznwU17ad7cxg+bW/jOg88Nta9OxVm+qJHXnDqLF82fzsmz6zQ8VOQYUkzo\nF9q0G/71YKQ2nwa+4u4do+0wdPebgJsg2NIvoiaZJLGYcebCBs5c2AAE3wQ2tXbwxLYDbNnbyaY9\nndz/9G5+t74VgHjMWDyzhlNm1zFveoYZtSlecnwDJx1Xp5FCImWomNBvARbk3Z8PbB+hTUvYvVMP\ntBF8I7jUzL4ITAdyZtbj7v921JXLlEjGY5wyexqnzJ42NG0g52ze28m6nQd5escBntp5kMda9nPf\n2l30Zp+/9OPc+ioWNFYzd3qGhY3VLJpZjTvMqE1z+rx6GmpSpXhJIhWtmNBfDSw1s8XANmAl8PZh\nbVYBlwMPApcCv/ZgZ8F5gw3M7FMEffoK/GNcPGac0FTLCU21vP6Fcw6Zt7ejl9Wb97F5bydP7TjA\n9v3d/OnZNn766Dbydx+ZwaIZNcxvyDC/oZoFjRlq0wnqM0lOaKolGY8xszbFjNr0FL86kWgbM/TD\nPvqrgHsIhmze4u5Pmtl1QLO7rwJuBr5rZhsJtvBXTmbRUr5m1Ka5+AWzD5vemx1ga1sXfVnniW3t\nbGrtoGVfN1v3dfHEth3s6+o/7DHxmNFYk+L4xmrqM0mq0wmqk3EyqTjTq5OctaiRE5pqmVWXJqYd\nzCJF0cFZUhY6e7N09Q2ws72Hbfu76M3m2Li7g53tPTzX1jU0f+h3X3bom0PMYF5Dhhk1aRprUkyv\nTlKfSbKmpZ1UPMaFpx3HibNqWTSjhlnT0qQTwY7n/oEcO/b3sHCGrk8gx76JHL0jMulq0glq0gma\n6tK8cP7Yl4ds6+xjTct+NrV20tbZy9a2bto6+9h1oId1Ow/S1tmH4/T053jwmb2HPLauKsH06iQH\ne7Ls7+pn3vQMZx7fwJz6Kppq01Sn49SkEmRSwe/qdJzqVJyqRJyqZJzjpqV1JLMcsxT6ckxqrElx\nwcmzuODk0dt19w2w60AP2/Z3s7Wti90He2nr7KO9u5+YGR29/exs7+HhzW3s6eijbyA3+gIJhq02\nVKeYWZtiWiZJImbUZ5I01aVJxmNUJePUphPk3Mmk4tRVJZlVl6ahOkV1Kk4iblSnEjoCWkpCoS+R\nlknFWTSzpqjrC7s
7B3uzdPcNHNKV1NUX/O7pD6Y9u6eL/V19tHb00tGbJTvgrN/VQVtnH9lcjv6B\n4rpM66oSxMyYXp2kJpUgmYhRmw4+UFLxGMl4jOnVSWIxozvs+prfkGF2fRXTMsEHyYyaNKlEjHjM\n6OkfYE59FY01KX0TkREp9EVCZsa0qiTTjvL4gv6BHJ29WWIxo6t3gIM9/ew60Et7dz+dfVlyOaet\nq4/dB3rJubOvq5/uvix9A87Bnn527D9Afy5HXzbHvs5gB3dVMkZtOsE9a3dSzG64dCJGXVWSZNyI\nmVGVjJFJxUkn4tSkE+zv6mNvRx8vmDeNuvA1z6hNUZtOkEnGiceMeMxIJ2IMuLN5TycLGquZ31BN\ndSro7mqoCT6c2rv7aT3Yy7I507RD/Rig0BeZYMEWenAMwrSqJLPrq1h6XN2ELHtvRy9mRkdPltaO\nHto6++kfyNE/kCOdiNGyr5sDPVl6+oMPm+yAM+BOb3+O7v4BuvsG2N/VR0N1ivkNGR7f1k5/1jnQ\n009X38BR1VaXToBBfSaJO8yaFnR3peIxqpIxDnRnae/uZ35Dhq6+4FtJQ02KTDKO42RzTmN1MEx3\nIJdjX1c/jTVBne5QlYxTk46zs72HAz1ZTmiqYUZNsA+mszfLtKokNenDI83d6c3mdOR4SKEvcgwZ\nPG6hsSY14aOOevoH6Ai7t9wh505X3wCOs3hmDZv3dLH7YA/dfQN09g3Q1tlL/4BTm06QTsR4cvsB\nzBjaX9J6sJdsLkdXX5a2zhy16QQLGqvZ2tZFJhXndxv20NWXpbt/gLgZsZjRlx17n8poGqqTpBKx\noQ+bZDzG3s4+9nb28sJ59UPf4nLumEF1KoF70NVWlYwNve6e/hzzGoJjRzp6s0PfnFKJGNMzyaDb\nzSw4FYFBeAszmF6dJJOM80xrJ/XVSZbMrKEqGScVjx3yTSiXc3LuJMKLHHX1Zckk45PeNafQFxEg\n2JIebWt42dxpLGPaiPOPlLsPBd3Bnn72dfYTiwXfGPZ09LGjvZu4Wbh/ZWBoB/ozrZ20dfXR3Zel\nNp1kb0cvOw/0kB3w4NtPzunP5jht3jSaatM8vq2drr4sADEzcu7s7ejDzDjY009vNkcsDPBkwvjZ\n4zsYyDnJuBW9n2Ys8ZgRD19rNpcj5zCnvorO3iwHerKcuXA6P/7AyyfkuUai0BeRksrfsq2rSh5y\nzqa6qiSLR9gJf+qcif8AytcTnm22KhknO5Cjs3eA3oEB9nf1s7+rH3fHAXfw8HRk7rCvq4+u3gEW\nNFazv6sRLnABAAAFyUlEQVSPbfu76RsI9tFkB4JuLDOGvt1s2dvJ9OoUDdUpmuom/wh0hb6ISAH5\n33oS8Rj11TEgyay6qtIVNQF0xWwRkQqi0BcRqSAKfRGRCqLQFxGpIAp9EZEKotAXEakgCn0RkQqi\n0BcRqSBld+UsM2sFnjuKRcwE9kxQOZOh3OuD8q+x3OsD1TgRyr0+KK8aj3f3prEalV3oHy0zay7m\nkmGlUu71QfnXWO71gWqcCOVeHxwbNQ6n7h0RkQqi0BcRqSBRDP2bSl3AGMq9Pij/Gsu9PlCNE6Hc\n64Njo8ZDRK5PX0RERhbFLX0RERmBQl9EpIJEJvTN7GIzW2dmG83smhLWscDM7jezp8zsSTP7q3B6\no5ndZ2Ybwt8N4XQzs38J615jZmdOUZ1xM/sfM7srvL/YzB4K67vdzFLh9HR4f2M4f9EU1TfdzO4w\ns6fDdXluOa1DM/tQ+Pd9wsxuM7OqUq9DM7vFzHab2RN508a9zszs8rD9BjO7fApq/Kfw77zGzH5i\nZtPz5l0b1rjOzC7Kmz4p7/dC9eXN+4iZuZnNDO+XZB0eNXc/5n+AOLAJWAKkgMeAZSWqZQ5wZni7\nDlgPLAO+CFwTTr8G+EJ4+/XAzwEDzgEemqI6Pwx8D7grvP8DYGV4+2vA+8PbHwC+F
t5eCdw+RfV9\nG/jL8HYKmF4u6xCYBzwLZPLW3btLvQ6B84EzgSfypo1rnQGNwDPh74bwdsMk13ghkAhvfyGvxmXh\nezkNLA7f4/HJfL8Xqi+cvgC4h+DA0ZmlXIdH/RpLXcAE/aHOBe7Ju38tcG2p6wpr+S/gtcA6YE44\nbQ6wLrz9deCyvPZD7SaxpvnAr4BXA3eF/7R78t54Q+sz/Ec/N7ydCNvZJNc3LQxVGza9LNYhQehv\nDd/UiXAdXlQO6xBYNCxQx7XOgMuAr+dNP6TdZNQ4bN6bgVvD24e8jwfX42S/3wvVB9wBvAjYzPOh\nX7J1eDQ/UeneGXwTDmoJp5VU+DX+xcBDwHHuvgMg/D0rbFaK2q8HPgrkwvszgP3uni1Qw1B94fz2\nsP1kWgK0At8Mu6C+YWY1lMk6dPdtwJeALcAOgnXyMOW1DgeNd52V+r30FwRbz4xSy5TWaGZvBLa5\n+2PDZpVFfeMVldC3AtNKOhbVzGqBHwF/7e4HRmtaYNqk1W5mbwB2u/vDRdZQinWbIPiK/e/u/mKg\nk6BrYiRTvQ4bgEsIuhzmAjXA60apoez+Pxm5ppLVamYfB7LArYOTRqhlymo0s2rg48AnCs0eoY5y\n/HsPiUrotxD0uQ2aD2wvUS2YWZIg8G919x+Hk3eZ2Zxw/hxgdzh9qmt/OfBGM9sMfJ+gi+d6YLqZ\nJQrUMFRfOL8eaJvE+gafs8XdHwrv30HwIVAu6/A1wLPu3uru/cCPgZdRXutw0HjXWUneS+HOzjcA\n7/CwT6RMajyB4MP9sfA9Mx94xMxml0l94xaV0F8NLA1HT6QIdpatKkUhZmbAzcBT7v7lvFmrgMG9\n+JcT9PUPTn9XOBLgHKB98Ov4ZHD3a919vrsvIlhPv3b3dwD3A5eOUN9g3ZeG7Sd1q8XddwJbzezk\ncNIKYC1lsg4JunXOMbPq8O89WF/ZrMM8411n9wAXmllD+I3mwnDapDGzi4GPAW90965hta8MRz8t\nBpYCf2IK3+/u/ri7z3L3ReF7poVgoMZOymgdjkupdypM1A/BnvT1BHv1P17COl5B8FVuDfBo+PN6\ngj7cXwEbwt+NYXsDbgjrfhxYPoW1XsDzo3eWELyhNgI/BNLh9Krw/sZw/pIpqu0MoDlcjz8lGAVR\nNusQ+DTwNPAE8F2CESYlXYfAbQT7GPoJwumKI1lnBP3qG8Of90xBjRsJ+sAH3y9fy2v/8bDGdcDr\n8qZPyvu9UH3D5m/m+R25JVmHR/uj0zCIiFSQqHTviIhIERT6IiIVRKEvIlJBFPoiIhVEoS8iUkEU\n+iIiFUShLyJSQf4/CcSLb4WyQggAAAAASUVORK5CYII=\n", 180 | "text/plain": [ 181 | "" 182 | ] 183 | }, 184 | "metadata": {}, 185 | "output_type": "display_data" 186 | } 187 | ], 188 | "source": [ 189 | "epoch = 1500\n", 190 | "pltX, pltY = [], []\n", 191 | "for e in range(epoch):\n", 192 | " pltX.append(e)\n", 193 | " Input = cudAvl(Variable(digTensor))\n", 194 | " Target = cudAvl(Variable(digTensor))\n", 195 | " _, Output = net(Input)\n", 196 | " loss = criterion(Output,Target)\n", 197 | " print_loss = loss.data[0]\n", 198 | " pltY.append(print_loss)\n", 199 | " optimizer.zero_grad()\n", 200 | " loss.backward()\n", 201 | " optimizer.step()\n", 202 | " if (e 
+ 1) % 100 == 0:\n", 203 | " print('epoch [%s/%s]: %s' %(e + 1, epoch, print_loss))\n", 204 | "\n", 205 | "plt.title('loss function output curve')\n", 206 | "plt.plot(pltX, pltY)\n", 207 | "plt.show()" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 8, 213 | "metadata": { 214 | "collapsed": true 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "from PIL import Image\n", 219 | "import numpy as np\n", 220 | "\n", 221 | "def getImage(matrix):\n", 222 | " dig = np.array(matrix.numpy()).reshape((28, 28))\n", 223 | " digImg = Image.fromarray(dig * 255)\n", 224 | " digImg = digImg.convert('L')\n", 225 | " return digImg\n", 226 | "\n", 227 | "def compare(matrix, testNet):\n", 228 | " before = getImage(matrix)\n", 229 | " plt.subplot(121)\n", 230 | " plt.title('before encoding')\n", 231 | " plt.imshow(before)\n", 232 | " \n", 233 | " matrix = matrix.unsqueeze(0)\n", 234 | " In = cudAvl(Variable(matrix))\n", 235 | " code, Out = testNet(In)\n", 236 | " if torch.cuda.is_available():\n", 237 | " Out = Out.cpu()\n", 238 | " after = getImage(Out.data)\n", 239 | " \n", 240 | " plt.subplot(122)\n", 241 | " plt.title('after encoding')\n", 242 | " plt.imshow(after)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 16, 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "name": "stdout", 252 | "output_type": "stream", 253 | "text": [ 254 | "image index = 36717 , number = 3 \n", 255 | "\n" 256 | ] 257 | }, 258 | { 259 | "data": { 260 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGttJREFUeJzt3XmQXVWdB/Dvt/furCQhkISQhCWZMKhhptkKSuIgCokU\nUCUqWg6MQJBB0Rl0RMoZGJEqpgp0HLUYkwEBQZESEGoAhYkwDIiYAJEtCQlLNkIWsnQn6U56+c0f\n70YffX43/V6//fT3U9XV75137r3nvvd7p2/fs9HMICIita+u0gUQEZHiUIUuIhIJVegiIpFQhS4i\nEglV6CIikVCFLiISiagrdJJvk/zoELedRfJFkp0kryx22aoNyekkjWRD8vxRkhdWulxSPCS/Q3Ir\nyXcrXZZSGe5x3FDpAlSxfwLwpJkdV+mCVIKZnVXpMkjxkJwK4CoA08xsM8npAN4C0GhmvZUsWykN\ntziO+gq9QNMAvDqUDfdfHYhUkWkA3jOzzcXYmWK8Og2HCv14kq+R3E7yJyRb9r9A8hMkl5HcQfJ3\nJD+YpP8WwEcA/JDkLpIzSY4heSfJLSTXkPwWybok/0UknyH5PZLbAFyXpH+B5PLk2L8hOS2tkCRP\nSsqwg+QfSc7Neu1Jktcnx+gk+RjJCVmvn5q17TqSFyXpBypzPcmbkn/B3wQwf0B5niR5Sdb5PZ3k\n307yLZJnZeWdQfKppGz/Q/JHJO8a4uclQ0TyapJvJJ/DayTPS9I/CuBxAJOTeL4dwFPJZjuStJOT\nvKkxm9zKuILkKgCrUsqgOK4kM4v2B8DbAF4BMBXAOADPAPhO8tpfAdgM4EQA9QAuTPI3J68/CeCS\nrH3dCeBBAKMATAfwOoCLk9cuAtAL4MvI3MZqBXAugNUAZidp3wLwu5RyTgHwHoB5yPyRPSN5fnBW\nWd4AMDPZ95MAbkxeOxxAJ4ALADQCGA9gTg5l/iKAFVnvzRMADEDDwPNPzq8HwKXJe3U5gHcAMHn9\nWQA3AWgCcCqADgB3VfrzH24/AM4HMDmJoU8D2A1gUvLaXADrs/JOz/68k7QDxmyS//EkXloVx9X3\nU/EClDjA3wbwxazn8wC8kTy+BcD1A/KvBHCaEwj1APYCOCYr72XI3GPfHyhrB+zr0f1BlzyvA7AH\nmXuYA8v5DQA/HZD2GwAXZpXlW1mv/T2AXyePvwngAWefg5X5twPem48N8kVYnZW3Lcl7aPJF7AXQ\nlvX6XbX2RYjxB8AyAOckj+di8Ar9gDGb5P+bAxxPcVzhn+Fwy2Vd1uM1yFzBAJl7ilcl/97tILkD\nmb/ykwfuAMAEZP5qrxmwrykpx9m//+9n7XsbAA7YJjvv+QPKciqASVl5snsm7AEwMnk8FZmrnnzL\nPBnhe3Mgfzq+me1JHo5M9rMtKw0I3wspA5J/yz/fQtwB4Fhk4iBXucTsgT5bxXGFDYeGjalZjw9H\n5l8sIPNh3WBmN+Swj63I/Ks2DcBrWfvakJVn4LSV+/d/dw77X4fMlc2lOeT1tj3BSR+szBsRvjdD\nsRHAOJJtWV+GqQfaQIovude9CMDpAJ41sz6Sy5CpkD3eNKu5xOyBpmdVHFfYcLhCv4LkYSTHAbgG\nwC+S9EUAvkjyRGaMIDmf5KiBOzCzPgD3AriB5Kjky/OPyPxLluY/AXyT5F8Cf2rYOT8l710Azib5\n8aSRp4XkXJKH5XB+dwP4KMlPkWwgOZ7knBzKfC+AK5P35iAAV+dwrICZrQGwFMB1JJuSxrWzh7Iv\nKcgIZCrbLQBA8u+QuUJPswVAP4AjstLyiVmP4rjChkOF/jMAjwF4M/n5DgCY2VJkGkd+CGA7Mo1B\nFx1gP19GppHpTQBPJ/u9LS2zmT0A4N8A3EOyA5nGWbdPrJmtA3AOMn9wtiBztfJ15PD5mNlaZNoG\nrkLmX+RlAD6UQ5kXIXN/848AXgBw/2DHOoDPATgZmQaw7yDzR
3NvAfuTPJnZawBuRqZhbxOADyDT\nCSAt/x4ANwB4Jrk9clI+MZuyT8Vxhe1v3RUpGpK/ALDCzK6tdFlEhqoW43g4XKFLiZE8nuSRJOtI\nnonMVdqvKl0ukXzEEMfDoVFUSu9QZP7VHQ9gPYDLzezFyhZJJG81H8e65SIiEgndchERiURBFTrJ\nM0muJLma5JC6C4lUI8W21KIh33IhWY/MnApnIHO/aQmAC5LuU64mNlsLRgzpeCKD6cZu7LO9aQNp\nclazse2d+XC7o5rPMKpqkGN5c43tQhpFT0BmXoQ3AYDkPci0CqcGfQtG4ESeXsAhRdI9Z4uLtavy\nxzZTvqt5XHCxIfw6W28Rpjr3ylalbW/eewAU6X0ogVzLm2tsF3LLZQreP9fBejjzlJBcQHIpyaU9\ntdVHX4YvxbbUpEIq9Jz+wTOzhWbWbmbtjWgu4HAiZaPYlppUyC2X9Xj/5DWH4c8TX4nUsvLHdj63\nMFJuz5TqtgLr68t2rEKllqvQ20Zpt8TcQuS+32K/j4VcoS8BcHSyykcTgM8AeKg4xRKpKMW21KQh\nX6GbWS/JLyEzMU49gNvMbEhrcIpUE8W21KqChv6b2SMAHilSWUSqhmJbapFGioqIREIVuohIJDTb\nokgxMBwkUrIeF2WWV0+MHM+trqXF377OucZ0etmk6e/s9F8o9P0t1edThAFl2XSFLiISCVXoIiKR\nUIUuIhIJVegiIpFQo6hIMZjTeFjkBq8hKXcZnP2ysSlMG+lPNczGxtwP1dsXbt/VlZK3OqcqKPbn\noCt0EZFIqEIXEYmEKnQRkUioQhcRiYQqdBGRSKiXi0iplLEnSTHypq1vWTd2TLjbyQe7eXsPag33\n2xuWYfch/gpPPa1hr5y6sDMLAKDt3X1BWvPqsEcNAPTv7AjSLK1HTF/KAR3VtviHrtBFRCKhCl1E\nJBKq0EVEIqEKXUQkEgU1ipJ8G0AngD4AvWbWXoxCxaJ+1lFu+rbvhWknTFwTpP3H5CU5H+vhPf78\n0te//okgbdw/+PvoW7k65+PFriZj2xnmzwZ/KD3rw2s5HjbJzdt9xPggbesH/MbHXdPDBsUxh+8M\n0uYc8oa7/YSmXUHa8o5D3byvrpgapI19eZqbd8ybPUFa6wZ/7vT6LTuCNG+aAQAwZ/71Ss6DX4xe\nLh8xs61F2I9ItVFsS03RLRcRkUgUWqEbgMdIPk9yQTEKJFIlFNtScwq95XKKmb1DciKAx0muMLOn\nsjMkX4YFANCCtgIPJ1I2im2pOQVdoZvZO8nvzQAeAHCCk2ehmbWbWXsj/NFhItVGsS21aMhX6CRH\nAKgzs87k8ccAfLtoJasxV65eEaTNb1uW8/YnLftkmLbZb7H3eL1kAOD3c34ZpM342qVu3pl+8rBT\n9bGdsmiF16OFLf4fmrrxBwVpXdPHuXl3HBnu1+vNAgAzZm8M0r4w9Zkg7eimd93tt/WNDNIOb97m\nl6s7nGZgY/dEN29DV1jVNXX4PcPqtjvVYlovl77+IC1tCgWv90tq3oHTD+TYGaaQWy6HAHiAmeBq\nAPAzM/t1AfsTqRaKbalJQ67QzexNAB8qYllEqoJiW2qVui2KiERCFbqISCQ0H3qe0obz/2ZH2HB0\n/espw5DnhUPsx6CwYfd/eMQvF5zpA2bftN3Nmvss0FIr2OwP0be2sEGwryWc2xsA+hucKQV6/YbZ\ndVvCxtYbdpwVpHVvGuFujxHOsPm6lBZB5zvXusm/Rm19L9xv3Z5wOgAAQE+Y3r8jnL4gb05jdrHn\nTtcVuohIJFShi4hEQhW6iEgkVKGLiERCFbqISCTUyyVPaYtArHSWPyi050qanU6PFm+IPwDMeDgc\nzz9zZe4LZ0gJFLrQQUpe6w17Z9jefW7euu0dQVrbWv/6jjYqSKvf61cd3ZvD3itN4aHQts8/h54R\n4VQF+8a6WdHkdDxp2ervt
2Vr+D7Ub/V7rnjD+etG+r1yrKsrTOsNty8XXaGLiERCFbqISCRUoYuI\nREIVuohIJNQoWiW8KQWWfy0cRg0Ab81ZFKR586kDwMxL1QBa05wGVNb7Q/S9YeT9zqr0gN+YV7d7\nj5u3lVOCtO6xY9y8fc1heRu6woZKpswzsW9SmNY9MW14fFh9jXjHbxSt37U3SLNdu/3dDpyLHOlD\n9L30YC7zhDdfvfX4jdZDpSt0EZFIqEIXEYmEKnQRkUioQhcRicSgFTrJ20huJvlKVto4ko+TXJX8\n9lvvRKqYYltik0svl9sB/BDAnVlpVwNYbGY3krw6ef6N4hcvTnvPOj5I+/oPfhqkzW/rdrf/i/+6\nPEib9i/PFl6w4ed2lDC281n9PXUfXo8WplyH5TGlgFeGvg5njD6Aho7wb1pvmz8ev99ZT6MHYbl6\nwtkEAADdh4XTF7SN93vf9OwYHebd5Pca4YbNYWJKbxT38+lPGc7vfBb0OyG5UzOkGvhZ5jgzxKBX\n6Gb2FIBtA5LPAXBH8vgOAOfmdjiR6qHYltgM9R76IWa2EQCS3xOLVySRilJsS80q+cAikgsALACA\nFrSV+nAiZaPYlmoz1Cv0TSQnAUDy27lBlWFmC82s3czaGxFOjSlSZRTbUrOGeoX+EIALAdyY/H6w\naCWKyOuLwsZPAHhrfjh0/+E94Srscy8O5zIHgGmPqgG0hIoW23mt6O41aKbu2G+gK9XQ8t5DwmH+\nPSP98u46whk23xym1bf6DZKzJ28K97nP/2O5a13YKNr4/Co3b58zBUL96HB7wG/MTv0s8xj6X4w5\n7weTS7fFnwN4FsAskutJXoxMsJ9BchWAM5LnIjVFsS2xGfQK3cwuSHnp9CKXRaSsFNsSG40UFRGJ\nhCp0EZFIqEIXEYmEFrgokp2PhAtUeAtRAMCV74S9X1a2h8OCm6HFKYaNtCH6To+Juma/14e3Wn0x\n9IwOx/PvOtw/1gePfTtImzU67LlyZIvfG3Rvf9hT50evftjNO+OZ7UFa2oIeHq/nCwDUj3LmJchj\nUZFK0hW6iEgkVKGLiERCFbqISCRUoYuIREKNokXy+zm/DNK84fwAsPpz07zUIpdIKi5tOH8ew7rd\n+dDr/OswNpZo6H9reLwRM3a6eU8ZH8bxaSNWBGmH1u91t//17plB2t6trW5e9oXnxkZnQnakvA8p\nn4PXWOp+DlVIV+giIpFQhS4iEglV6CIikVCFLiISCTWKFom3cPOKS25x885/ImxAzcdJyz6Zc96e\nBw8O0ib8WPOpl8UQ57R+H2cRYtuXshByq994mPux/EZcqw/TO9/1V3l+YsysIO2PHVODtOlt77nb\nv74rXPGveUKXm/eNz44L0g5/zF85qv6Zl4O01FGezueWltedOz1lPnSvYTW1DKVaJFpERGqDKnQR\nkUioQhcRiYQqdBGRSOSypuhtJDeTfCUr7TqSG0guS37mlbaYIsWn2JbY0AZpiSf5YQC7ANxpZscm\nadcB2GVmN+VzsNEcZydSyzVuvezkIG33lDDf2PYtOe/Tm3ogTVovmTHzanv6gedsMTpsW8p4+1Ct\nxnZdiz+lBJvCYe/9Xd1u3nymBOg/7bggbeuxfo+aHqfzS58zGr/70JTeHc3hPOtTD/N7xIxuDs9t\nxZLpbt6jf+rMnf5SOCVBURRhyoeBco3tQa/QzewpANuGXBKRKqXYltgUcg/9SyRfSv5tPahoJRKp\nPMW21KShVui3ADgSwBwAGwHcnJaR5AKSS0ku7YE/w5pIFVFsS80aUoVuZpvMrM/M+gEsAnDCAfIu\nNLN2M2tvhL8Woki1UGxLLRvS0H+Sk8xsY/L0PACvHCi/vJ839H5Cgfv8OOa46bOWhnNkpzWgzpsV\nNpb2razthtJ81XJse8PI82n8TJtLvGFH2Pg4an0YVwDQ0BU2avY1hdeNPSP9a8ndk8Iqafs
4vwH2\n7CkvhYnh+usAgG3Ph2sQjHI2z5vTAMoG/72x3nAh+KJMD5Fl0Aqd5M8BzAUwgeR6ANcCmEtyDjIz\nDLwN4LKilkqkDBTbEptBK3Qzu8BJvrUEZREpK8W2xEYjRUVEIqEKXUQkEqrQRUQioQUuIvfYI+1h\n4iVL3Lyb5jqLYQyzXi61wPrCniQAkPOcB0XQsNtfxKFxV9iTo6Uz7KPvLZoBAE07w7kDtp3m9wQ5\npmVDmJjSXexn42cEaf4SHfnxerS4vVkAt0eLt0BGZh8pUyMMQlfoIiKRUIUuIhIJVegiIpFQhS4i\nEgk1ikaiftZRbvrNn/1JzvvwpiSQ6sPGlK9tv9NYms/c3OY3tnp563r9hsr+hvAasWHnrnCXOzvc\n7VvqDw/Sxrb6c7p/oGlrkPZK11R/v9tTzq1ArPeuiVOG/jvTMAy18TONrtBFRCKhCl1EJBKq0EVE\nIqEKXUQkEqrQRUQioV4ukbji4f920+e3hT0E/vpfL3fzToB6uRQkrUfJQAUuasBmf3Uk6+oq6Fhp\nPS7qO/YEaXUj/DLU9ToLXGzaEh4rZeGNjtnhgPwzD33RzTuK4fXo/2492s07ZnlnWAY3Z368aRhS\nh/6Xga7QRUQioQpdRCQSqtBFRCKhCl1EJBK5LBI9FcCdAA4F0A9goZl9n+Q4AL8AMB2ZxXQ/ZWbb\nS1fU4SdtOP9Rd68J0rzGTwA4adkngzQN8c8oemzn2gCZx0rxqcP8HXUHjQ237/Eb6Pre2xZun9bY\n2twUpPU317t5e0eE5W2ZEQ7H7z3Yn438vbPDht1rD37NzfvqvvB9XLVkmpv3iBcLjPmUBu9Ch/4X\nWy5X6L0ArjKz2QBOAnAFyWMAXA1gsZkdDWBx8lyklii2JSqDVuhmttHMXkgedwJYDmAKgHMA3JFk\nuwPAuaUqpEgpKLYlNnndQyc5HcBxAJ4DcIiZbQQyXwwAE1O2WUByKcmlPQiXohKpBoptiUHOFTrJ\nkQDuA/BVM/PnvnSY2UIzazez9kb49+hEKkmxLbHIqUIn2YhMwN9tZvcnyZtITkpenwRgc2mKKFI6\nim2JSS69XAjgVgDLzey7WS89BOBCADcmvx8sSQmHidcXHR+kvTV/Uc7bz3j4Ujd95qVLhlym2FUs\ntgsd+j/G7yHSN2F0kNbf7H/F9409IkxMWQNiz8RwH50z/F4fPSPCc7OGg4O0sUeFvWwA4Ok54YIs\nfdbq5r1g2ReCtKmPFd6TpK6tLUys93v1wOlFVI7eLGly6RN1CoDPA3iZ5LIk7Rpkgv1ekhcDWAvg\n/NIUUaRkFNsSlUErdDN7GkDarEOnF7c4IuWj2JbYaKSoiEgkVKGLiERC86EXiTdMf/nXDnLz+o2d\ny4IUb9g+ALT8INzvzEfV+FnL0hrS3KH/afOuO+m7J/vdKTumhY18u47y50OfPSucauLzE1918x7R\nvClI6+wLGzUnN/ozKbzbF5br0yvOc/O23hdOddDy/Ao3b39D+D6yKZzSAEiZb76vz9+vM4c8nWMB\n6fPN+4UY8Fnm2I6uK3QRkUioQhcRiYQqdBGRSKhCFxGJhCp0EZFIqJdLntZ8+2Q3fcUlt+S8jyvf\nCYf5v/TPc4K0Meq5MuxZT9gzwnb484fVtYa9Mxq6/GHzDeE6Eqjr9q/vtneH+1jV5U5AiT5nnNaG\nvWGvrHs6T3C3f3nt5CBt3BMtbt6JT6wN0vr3+rNeej1a2Oa/Nx5L2W/JDHF6CF2hi4hEQhW6iEgk\nVKGLiERCFbqISCTUKHoAXgNoWuOnN0y/58FwHmgAmPDjcAXyZqgBVELW68y3ndbwt/7dIG1k5x43\nb/N748O8G/3Gx94/hHH8B/qx/cLesDGPzqj5viZ/+oIj14attY3rwsZPAOhdv8FNz1Vdf8oE8Hnw\nhvOnDf0vB12hi4hEQhW6iEgkVKGLiERCFbqISCQGrdB
JTiX5BMnlJF8l+ZUk/TqSG0guS37mlb64\nIsWj2JbY0AYZYkpyEoBJZvYCyVEAngdwLoBPAdhlZjflerDRHGcnUks1Smk8Z4vRYdvS1ggNVCy2\nvQUqhjjUezB1LX7PFVdjo5vc39lZWCGc860fP87Puy/s1dPf1e1mTVsUpOLSFiAp4DPONbZzWSR6\nI4CNyeNOkssBTBlyyUSqhGJbYpPXPXSS0wEcB+C5JOlLJF8ieRtJd701kgtILiW5tAdlnuBGJEeK\nbYlBzhU6yZEA7gPwVTPrAHALgCMBzEHmKudmbzszW2hm7WbW3gh/fUORSlJsSyxyqtBJNiIT8Heb\n2f0AYGabzKzPzPoBLALgz4cpUsUU2xKTQe+hkySAWwEsN7PvZqVPSu5BAsB5AF4pTRFFSqNisV2i\nBlBPf7ffoOg23KXlLZRzvrZrdx7bFz5E31WCxssDbl+GxvBcJh04BcDnAbxMclmSdg2AC0jOAWAA\n3gZwWVFLJlJ6im2JSi69XJ4GnGVIgEeKXxyR8lFsS2w0UlREJBKq0EVEIqEKXUQkElrgQqTWpPXO\n8KT1oihjTxuvvKm9bwrcb6pCz7dUPWKKTFfoIiKRUIUuIhIJVegiIpFQhS4iEolB50Mv6sHILQDW\nJE8nANhatoOXj86rcqaZmb8cfYllxXYtvE9DFeu51cJ55RTbZa3Q33dgcqmZtVfk4CWk8xreYn6f\nYj23mM5Lt1xERCKhCl1EJBKVrNAXVvDYpaTzGt5ifp9iPbdozqti99BFRKS4dMtFRCQSZa/QSZ5J\nciXJ1SSvLvfxiylZQHgzyVey0saRfJzkquS3u8BwNSM5leQTJJeTfJXkV5L0mj+3UoolthXXtXdu\n+5W1QidZD+BHAM4CcAwyK8McU84yFNntAM4ckHY1gMVmdjSAxcnzWtML4Cozmw3gJABXJJ9TDOdW\nEpHF9u1QXNekcl+hnwBgtZm9aWb7ANwD4Jwyl6FozOwpANsGJJ8D4I7k8R0Azi1roYrAzDaa2QvJ\n404AywFMQQTnVkLRxLbiuvbObb9yV+hTAKzLer4+SYvJIfsXGE5+T6xweQpCcjqA4wA8h8jOrchi\nj+2oPvtY47rcFbo3qbC62VQpkiMB3Afgq2bWUenyVDnFdo2IOa7LXaGvBzA16/lhAN4pcxlKbRPJ\nSQCQ/N5c4fIMCclGZIL+bjO7P0mO4txKJPbYjuKzjz2uy12hLwFwNMkZJJsAfAbAQ2UuQ6k9BODC\n5PGFAB6sYFmGhCQB3ApguZl9N+ulmj+3Eoo9tmv+sx8OcV32gUUk5wH4dwD1AG4zsxvKWoAiIvlz\nAHORma1tE4BrAfwKwL0ADgewFsD5ZjawgamqkTwVwP8BeBlAf5J8DTL3G2v63EoplthWXNfeue2n\nkaIiIpHQSFERkUioQhcRiYQqdBGRSKhCFxGJhCp0EZFIqEIXEYmEKnQRkUioQhcRicT/AysNDIFE\nb9eiAAAAAElFTkSuQmCC\n", 261 | "text/plain": [ 262 | "" 263 | ] 264 | }, 265 | "metadata": {}, 266 | "output_type": "display_data" 267 | } 268 | ], 269 | "source": [ 270 | "import random\n", 271 | "index = random.randint(0, digTensor.shape[0])\n", 272 | "print('image index =', index, ', number = ', label[index], '\\n')\n", 273 | "\n", 274 | "net.eval()\n", 275 | "compare(digTensor[index], net)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | 
"source": [ 282 | "效果还不错~\n", 283 | "\n", 284 | "---\n", 285 | "\n", 286 | "接下来试一下以**卷积神经网络**为基础的自动编码器" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 10, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "class CNNautoencoder(nn.Module):\n", 296 | " def __init__(self):\n", 297 | " super().__init__()\n", 298 | " self.encoder = nn.Sequential(\n", 299 | " nn.Conv2d(1, 16, 3, stride=3, padding=1),\n", 300 | " nn.ReLU(True),\n", 301 | " nn.MaxPool2d(2, stride=2),\n", 302 | " nn.Conv2d(16, 8, 3, stride=2, padding=1),\n", 303 | " nn.ReLU(True),\n", 304 | " nn.MaxPool2d(2, stride=1)\n", 305 | " )\n", 306 | " \n", 307 | " self.decoder = nn.Sequential(\n", 308 | " nn.ConvTranspose2d(8, 16, 3, stride=2),\n", 309 | " nn.ReLU(True),\n", 310 | " nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),\n", 311 | " nn.ReLU(True),\n", 312 | " nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),\n", 313 | " nn.Tanh()\n", 314 | " )\n", 315 | " \n", 316 | " def forward(self, x):\n", 317 | " code = self.encoder(x)\n", 318 | " transImg = self.decoder(code)\n", 319 | " return code, transImg" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 11, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "CNNnet = CNNautoencoder()\n", 329 | "CNNnet = cudAvl(CNNnet)\n", 330 | "CNNcriterion = nn.MSELoss()\n", 331 | "CNNoptimizer = torch.optim.Adam(CNNnet.parameters(), lr=1e-3, weight_decay=1e-5)\n", 332 | "\n", 333 | "CNNdigTensor = digTensor.view(digTensor.shape[0], 1, 28, 28)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 12, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "epoch [100/1000]: 0.06539251655340195\n", 346 | "epoch [200/1000]: 0.05366114154458046\n", 347 | "epoch [300/1000]: 0.04686489701271057\n", 348 | "epoch [400/1000]: 0.042550913989543915\n", 349 | "epoch [500/1000]: 0.0394083596765995\n", 
350 | "epoch [600/1000]: 0.03722544014453888\n", 351 | "epoch [700/1000]: 0.03553758189082146\n", 352 | "epoch [800/1000]: 0.034192875027656555\n", 353 | "epoch [900/1000]: 0.03310518339276314\n", 354 | "epoch [1000/1000]: 0.03222407400608063\n" 355 | ] 356 | }, 357 | { 358 | "data": { 359 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8HPV9//HXR7d13/IlWTKWDcIYA7INhMMJR0xKcJJC\nsUOpQ2lpmtKmOdqQX1tCSC/StCRNSBo3kECaYii5HEIChMMQTstgjI0v+ZZt2bIlS7JlSZb0+f2x\nY2cRkrWyJa20+34+Hvvw7sxXO5/RyO+Z/c7sd8zdERGR+JAQ7QJERGTkKPRFROKIQl9EJI4o9EVE\n4ohCX0Qkjij0RUTiiEI/TpnZdjO7coSWNc7MfmFmzWb2fyOxzLBlrzOz+SO5TJHRTKEvI+F6oAQo\ncPcbhmshZvYDM/vH8Gnufra7Pz9cyzwVQ73DHckduIx9Cn0ZCVOATe7eFe1C5NSZWVK0a5DTp9AX\nzCzVzL5uZnuCx9fNLDWYV2hmj5vZITNrNLMXzSwhmPcFM9ttZq1mttHMrujjvb8M3AncaGaHzexW\nM7vLzP4nrE25mfnxUDGz583sK2b2UvDeT5lZYVj7S8zs5aCmXWb2CTO7DbgJ+NtgOb8I2p44Ch5g\nPeebWZ2Zfc7M9pvZXjO75SS/s4lmtjz4ndSa2Z+GzXvXJ47j7x08/yFQBvwiqPNvw9b/tqCuvWb2\nuVN9v37qXWhmq82sxcy2mNmC3r+f4PWJbRNW161mthN41sx+bWa393rvt8zsY8HzM83s6eD3stHM\n/qC/36FEh0JfAP4OuBCYDZwLzAX+Ppj3OaAOKCLURfP/ADezGcDtwBx3zwI+CGzv/cbu/iXgn4FH\n3D3T3e+PsKaPA7cAxUAK8HkAMysDfgV8M6hpNrDa3ZcCPwK+Giznw4NcT4DxQA4wCbgVuM/M8vqp\n72FCv5eJhLqv/rmvnV5v7n4zsBP4cFDnV8Nmvx+oBK4G7oiky2aA9wPAzOYCDwF/A+QCl9HHtjqJ\ny4GzCG3j/wUWh713FaFPcr80swzg6aBNcdDu22Z29iCWJcNMoS8QOkK+2933u3sD8GXg5mDeMWAC\nMMXdj7n7ix4asKkbSAWqzCzZ3be7+5YhrOn77r7J3Y8CjxIK6uO1/sbdHw7qOejuqyN8z5OtJ4TW\n9e7gfZ8ADgMzer+JmZUClwBfcPf2YPnf6/Vep+LL7n7E3d8Gvk9YuJ6mW4EH3P1pd+9x993uvmEQ\nP39XUNdR4KfAbDObEsy7CfiJu3cA1wLb3f377t7l7m8APya0U5RRQqEvEDpa3RH2ekcwDeDfgFrg\nKTPbamZ3ALh7LfDXwF3AfjNbZmYTGTr1Yc/bgMzgeSlwqjuXk60nwMFe5x3Cl9v7fRrdvbXXe006\nxbqO23WS2k7H6fzOIKyuYJ1/CSwKJi0i9AkLQkf884Jut0NmdojQTmH8aSxbhphCXwD2EPoPe1xZ\nMA13b3X3z7n7VODDwGePd2O4+/+6+yXBzzpwT4TLOwKkh70eTCjsAs7oZ95AQ8b2u56DtAfIN7Os\nXu+1O3g+0Pr1V2dpP7Wd6vsdd7LfWSTbovf7PwwsNrOLgHHAc2HLWeHuuWGPTHf/8wHqkxGk0BcI\n/Sf+ezMrCk6Y3gkcP5l3rZlNMzMDWgh163Sb2Qwz+0BwIrQdOBrMi8R
q4DIzKzOzHOCLg6j1R8CV\nZvYHZpZkZgVmdrzrZx8w9VTWczDcfRfwMvAvZpZmZrMIdaEcP+JdDXzIzPLNbDyhT0Th+qvzH8ws\nPegDvwV45DTf77j7gVvM7AozSzCzSWZ2Zth7LzKzZDOrJrKumCcI7TzvJnSupieY/jgw3cxuDt4v\n2czmmNlZEbynjBCFvgD8I1ADrAHeBt4IpkHoxOJvCPVvvwJ8O7juPRX4V+AAoa6YYkIneQfk7k8T\nCrQ1wCpCYRERd98JfIjQCeZGQqF1bjD7fkLnGA6Z2c8GuZ6DtRgoJ3Q0/lPgS8F6AfwQeIvQydKn\n+F14H/cvhHY+h8zs82HTVxDqSnsG+Jq7P3Wa7weAu79OaCdyL9AcLOf4J55/IPQpoInQOY7/HWjF\ng/77nwBXhrcPun6uJtTls4fQ38U9hP5WZJQw3URFJLrMrBzYBiTruwwy3HSkLyISRxT6IiJxRN07\nIiJxREf6IiJxZNQNoFRYWOjl5eXRLkNEZExZtWrVAXcvGqjdqAv98vJyampqol2GiMiYYmY7Bm6l\n7h0Rkbii0BcRiSMKfRGROKLQFxGJIwp9EZE4otAXEYkjCn0RkTgSM6G/+9BR/vVXG9jX0h7tUkRE\nRq2YCf22ji7+a8UWnn5nX7RLEREZtWIm9KcVZzKlIF2hLyJyEhGFvpktMLONZlZ7/MbYveZfZmZv\nmFmXmV3fa94SM9scPJYMVeF91MBVZ5XwypaDHO7QfShERPoyYOibWSJwH3ANUEXohshVvZrtBD5B\nr1utmVk+8CVgHjAX+JKZ5Z1+2X27qqqEzu4eVmxsGK5FiIiMaZEc6c8Fat19q7t3AsuAheEN3H27\nu68Benr97AeBp9290d2bgKeBBUNQd58umJJHdloSKzbtH65FiIiMaZGE/iRgV9jrumBaJCL6WTO7\nzcxqzKymoeHUj9KTEhOYN7WAV7c2nvJ7iIjEskhC3/qYFunttiL6WXdf6u7V7l5dVDTgcNAnddHU\nAnY2trH70NHTeh8RkVgUSejXAaVhrycDeyJ8/9P52VNy0RkFALyy5eBwLkZEZEyKJPRXApVmVmFm\nKcAiYHmE7/8kcLWZ5QUncK8Opg2bGSVZZKUmsXpX03AuRkRkTBow9N29C7idUFivBx5193VmdreZ\nXQdgZnPMrA64Afiuma0LfrYR+AqhHcdK4O5g2rBJSDDOmZzDW7uah3MxIiJjUkS3S3T3J4Anek27\nM+z5SkJdN3397APAA6dR46CdW5rLf7+wlfZj3aQlJ47kokVERrWY+UZuuHMn59DV47yztyXapYiI\njCoxGfrnTM4FYN1udfGIiISLydCfmJNGVloSG/e1RrsUEZFRJSZD38yYUZLFxnqFvohIuJgMfYDp\n40Oh7x7p98hERGJfzIb+meOzaGnvol43VREROSFmQ39GSRYAG9TFIyJyQuyG/vhQ6G9S6IuInBCz\noZ+bnkJJdqpO5oqIhInZ0AeYMT5b3TsiImFiOvTPGp9F7f7DHOvufW8XEZH4FNOhXzUxm87uHrY0\nHI52KSIio0JMh/5ZE7IBeGePxuAREYEYD/2phRmkJCWwXgOviYgAMR76SYkJnDk+i3U60hcRAWI8\n9AGqp+SzakcT7ce6o12KiEjUxXzoXza9kI6uHl7eciDapYiIRF3Mh/6FUwsoykrlO89v0eBrIhL3\nYj7005IT+cyV01m5vYnHVtVFuxwRkaiK+dAHWDSnlJmTsvnei9t0tC8icS2i0DezBWa20cxqzeyO\nPuanmtkjwfzXzKw8mJ5iZt83s7fN7C0zmz+k1UcoIcFYNKeMjftadd9cEYlrA4a+mSUC9wHXAFXA\nYjOr6tXsVqDJ3acB9wL3BNP/FMDdzwGuAv7dzKLy6eKDZ48H4LkN+6OxeBGRUSGSAJ4L1Lr7Vnfv\nBJYBC3u1WQg8GDx/DLjCzIzQTuI
ZAHffDxwCqoei8MEqykpl5qRsVmxqiMbiRURGhUhCfxKwK+x1\nXTCtzzbu3gU0AwXAW8BCM0syswrgAqC09wLM7DYzqzGzmoaG4Qvl+dOLeWPnIZqPHhu2ZYiIjGaR\nhL71Ma332dD+2jxAaCdRA3wdeBnoek9D96XuXu3u1UVFRRGUdGrmzyiiu8f57WZdsy8i8SmS0K/j\n3Ufnk4E9/bUxsyQgB2h09y53/4y7z3b3hUAusPn0yz41s0tzyU5L4vmN6tcXkfgUSeivBCrNrMLM\nUoBFwPJebZYDS4Ln1wPPurubWbqZZQCY2VVAl7u/M0S1D1pSYgKXVhaxYlODLt0Ukbg0YOgHffS3\nA08C64FH3X2dmd1tZtcFze4HCsysFvgscPyyzmLgDTNbD3wBuHmoV2CwLp9RxP7WDtbv1R21RCT+\nJEXSyN2fAJ7oNe3OsOftwA19/Nx2YMbplTi05k8PnTN4ftN+qiZmR7kaEZGRFRffyA1XnJ3G2ROz\neWa9+vVFJP7EXehD6Itab+xsYn9Le7RLEREZUXEb+u7w1Dv7ol2KiMiIisvQn16SSUVhBk+uq492\nKSIiIyouQ9/MuPKsYl7b2khb53u+KyYiErPiMvQBLpteRGd3D69tbYx2KSIiIyZuQ39OeT5pyQka\ngE1E4krchn5aciLzKgp4YbNCX0TiR9yGPsCllYVsbThCXVNbtEsRERkRcR36lwffzn1Ro26KSJyI\n69CfVpzJhJw0XlC/vojEibgOfTPj0spCXqo9QFd3T7TLEREZdnEd+hC6dLOlvYu36pqjXYqIyLCL\n+9C/ZFohZqiLR0TiQtyHfm56CrMm5/KiLt0UkTgQ96EPcHllIat3HaK5TTdMF5HYptAHLp1eRI/D\nS1t06aaIxDaFPnBeaS6ZqUm8VKvQF5HYptAndMP086fksXK7Bl8TkdgWUeib2QIz22hmtWZ2Rx/z\nU83skWD+a2ZWHkxPNrMHzextM1tvZl8c2vKHztzyPDbtO0zTkc5olyIiMmwGDH0zSwTuA64BqoDF\nZlbVq9mtQJO7TwPuBe4Jpt8ApLr7OcAFwJ8d3yGMNnPK8wGo2dEU5UpERIZPJEf6c4Fad9/q7p3A\nMmBhrzYLgQeD548BV5iZAQ5kmFkSMA7oBFqGpPIhdm5pLimJCdSoi0dEYlgkoT8J2BX2ui6Y1mcb\nd+8CmoECQjuAI8BeYCfwNXd/T6qa2W1mVmNmNQ0N0blePi05kVmTc3hdoS8iMSyS0Lc+pnmEbeYC\n3cBEoAL4nJlNfU9D96XuXu3u1UVFRRGUNDzmVOTzdl0zRzu7o1aDiMhwiiT064DSsNeTgT39tQm6\ncnKARuDjwK/d/Zi77wdeAqpPt+jhMqc8j64e581d6tcXkdgUSeivBCrNrMLMUoBFwPJebZYDS4Ln\n1wPPursT6tL5gIVkABcCG4am9KF3wZR8zGDlNoW+iMSmAUM/6KO/HXgSWA886u7rzOxuM7suaHY/\nUGBmtcBngeOXdd4HZAJrCe08vu/ua4Z4HYZMzrhkZpRk6Xp9EYlZSZE0cvcngCd6Tbsz7Hk7ocsz\ne//c4b6mj2ZzK/J5bFUdXd09JCXqu2siEluUar3MKc+nrbObdXtG5ZWlIiKnRaHfy9yK0Je01MUj\nIrFIod9LSXYaZfnpvL5NoS8isUeh34c55fnU7GgidAGSiEjsUOj3YW5FHo1HOtnScDjapYiIDCmF\nfh+OD762cruu1xeR2KLQ70NFYQaFmSk6mSsiMUeh3wczo3pKvkJfRGKOQr8f1eV57Go8Sn1ze7RL\nEREZMgr9fuh6fRGJRQr9flRNyCY9JVE3VRGRmKLQ70dSYgLnl+XpCh4RiSkK/ZOoLs9jfX0LLe3H\nol2KiMiQUOifxNzyfNxhlY72RSRGKPRP4vwpeaQkJfBS7YFolyIiMiQU+ieRlpzI3PJ8Xtys0BeR\
n2KDQH8AllYVs3NfK/hZdry8iY59CfwCXTCsE0NG+iMQEhf4AqiZkU5CRwm/Vry8iMSCi0DezBWa2\n0cxqzeyOPuanmtkjwfzXzKw8mH6Tma0Oe/SY2eyhXYXhlZBgXFpZyIpNDXR190S7HBGR0zJg6JtZ\nInAfcA1QBSw2s6pezW4Fmtx9GnAvcA+Au//I3We7+2zgZmC7u68eyhUYCVefPZ7GI528rm/nisgY\nF8mR/lyg1t23unsnsAxY2KvNQuDB4PljwBVmZr3aLAYePp1io2X+jCLSkhP49dr6aJciInJaIgn9\nScCusNd1wbQ+27h7F9AMFPRqcyNjNPTTU5KYP72YX6+tp6dHt1AUkbErktDvfcQO0Dv5TtrGzOYB\nbe6+ts8FmN1mZjVmVtPQ0BBBSSPvmnPGs7+1gzd26tu5IjJ2RRL6dUBp2OvJwJ7+2phZEpADhHeA\nL+IkR/nuvtTdq929uqioKJK6R9wHzixmXHIiP36jLtqliIicskhCfyVQaWYVZpZCKMCX92qzHFgS\nPL8eeNbdHcDMEoAbCJ0LGLOy0pK5dtYElq/ew5GOrmiXIyJySgYM/aCP/nbgSWA98Ki7rzOzu83s\nuqDZ/UCBmdUCnwXCL+u8DKhz961DW/rIWzS3jCOd3fzird4fdERExgYLDshHjerqaq+pqYl2GX1y\nd66+9wXGpSTy8794H++9QElEJDrMbJW7Vw/UTt/IHQQzY8nF5aypa+aVLQejXY6IyKAp9Afp+gsm\nU5yVyreeq412KSIig6bQH6S05ET+9NKpvLzlIKt26Bu6IjK2KPRPwcfnlVGUlcpXHl+vL2uJyJii\n0D8FGalJ3LHgTFbvOqTr9kVkTFHon6KPnjeJ88ty+ZdfbWB/q26wIiJjg0L/FCUkGPf8/izaOrv4\n7CNv0a1uHhEZAxT6p6GyJIu7Pnw2v609wFd/vSHa5YiIDCgp2gWMdYvmlvHO3ha++8JWphZlcOOc\nsmiXJCLSLx3pD4E7r63i0spC/u6na3lZt1UUkVFMoT8EkhITuO+m85lalMGf/c8qNu9rjXZJIiJ9\nUugPkey0ZB74xBzSkhO55QcraWjtiHZJIiLvodAfQpPz0rl/STUHD3fyJw+upP1Yd7RLEhF5F4X+\nEJs1OZdvLJrNmt3NfOHHaxhto5iKSHxT6A+Dq88ez+evnsHPV+/hOyu2RLscEZETdMnmMPnU/DPY\nWN/Kvz25kenFWVxZVRLtkkREdKQ/XMyMr14/i5kTc/j0sjfZpCt6RGQUUOgPo7TkRP77j6pJT03i\nTx6soelIZ7RLEpE4p9AfZuNz0lh68wXUt7TzqR+9wbHunmiXJCJxTKE/As4ry+NfP3YOr2w9yJ0/\nX6crekQkaiIKfTNbYGYbzazWzO7oY36qmT0SzH/NzMrD5s0ys1fMbJ2ZvW1maUNX/tjxsfMn86n5\nZ/Dw6zv5yuPrFfwiEhUDXr1jZonAfcBVQB2w0syWu/s7Yc1uBZrcfZqZLQLuAW40syTgf4Cb3f0t\nMysAjg35WowRf/PBGRw91s0DL20jMQH+34fOwsyiXZaIxJFILtmcC9S6+1YAM1sGLATCQ38hcFfw\n/DHgWxZKs6uBNe7+FoC7HxyiusckM+POa6vo6XH++8VtJCQYdyw4U8EvIiMmktCfBOwKe10HzOuv\njbt3mVkzUABMB9zMngSKgGXu/tXeCzCz24DbAMrKYntoYjPjruvOptud767YSlKC8fmrZyj4RWRE\nRBL6faVR7w7p/tokAZcAc4A24BkzW+Xuz7yroftSYClAdXV1zHd2mxl3XzeT7h6477ktpCUl8pdX\nVEa7LBGJA5GEfh1QGvZ6MrCnnzZ1QT9+DtAYTF/h7gcAzOwJ4HzgGeJcQoLxTx+ZSfuxbv796U0U\nZKby8Xmx/SlHRKIvkqt3VgKVZlZhZinAImB5rzbLgSXB8+uBZ
z10ecqTwCwzSw92Bpfz7nMBcS0h\nIfSt3fkzivj7n73Nr9fujXZJIhLjBgx9d+8CbicU4OuBR919nZndbWbXBc3uBwrMrBb4LHBH8LNN\nwH8Q2nGsBt5w918O/WqMXcmJCXz7pvM5tzSXv1q2mle3xvW5bhEZZjbarhevrq72mpqaaJcx4pqO\ndHLDd19hX3M7y/7sQs6emBPtkkRkDAnOl1YP1E7fyB0l8jJSeOiP55KVlsSSB1ay4+CRaJckIjFI\noT+KTMwdx0O3zqWrp4c/euB13XJRRIacQn+UmVacxfc/MYf9LR0seeB1Wtrj9gvMIjIMFPqj0Hll\neXznD89n075WbnuoRvfaFZEho9AfpebPKOZrN5zLq1sb+etlq+nuGV0n3EVkbFLoj2IfOW8S/3Bt\nFb9eV88//HytRuYUkdOme+SOcrdeUsHBwx18+/ktpCQm8KUPV2mcHhE5ZQr9MeBvPjiDzq4evvfb\nbXT19HD3dTNJSFDwi8jgKfTHADPj737vLJISE/ivFVvo6nb++aPnKPhFZNAU+mOEmfGFBTNISjC+\n9VwtXT3OPb8/i0QFv4gMgkJ/DDEzPnf1dBITjG88s5nuHuer188iOVHn40UkMgr9McbM+MxV00lO\nNL721Caajx7jvo+fz7iUxGiXJiJjgA4Rx6jbP1DJP35kJs9t3M/Hv/cqTUc6o12SiIwBCv0x7A8v\nnMJ3bjqfdXtauP6/Xmb3oaPRLklERjmF/hi3YOYEfvjHc9nf2sHCb73Eqh2N0S5JREYxhX4MmDe1\ngJ/8+cVkpiayeOlrPLpy18A/JCJxSaEfIypLsvjZX7yPuRX5/O2P13DX8nV0dfdEuywRGWUU+jEk\nNz2FH9wyhz9+XwU/eHk7i5a+yh7184tIGIV+jElKTODOD1fxjUWzWb+3hd/7zxd5bsP+aJclIqNE\nRKFvZgvMbKOZ1ZrZHX3MTzWzR4L5r5lZeTC93MyOmtnq4PFfQ1u+9Gfh7En84i8vYXzOOG75wUr+\n+Yn1dHRpXH6ReDdg6JtZInAfcA1QBSw2s6pezW4Fmtx9GnAvcE/YvC3uPjt4fHKI6pYITC3K5Kef\nupib5pWx9IWtfPibv+XNnU3RLktEoiiSI/25QK27b3X3TmAZsLBXm4XAg8Hzx4ArTOP/jgppyYn8\n00fP4fufmEPL0S4+9p2XuWv5Olp1G0aRuBRJ6E8Cwq8BrAum9dnG3buAZqAgmFdhZm+a2Qozu7Sv\nBZjZbWZWY2Y1DQ0Ng1oBicz7zyzm6c9exh9dOIUHX9nOFf++gp+v3q0bs4jEmUhCv68j9t5J0V+b\nvUCZu58HfBb4XzPLfk9D96XuXu3u1UVFRRGUJKciKy2ZLy+cyU8/9T5KstP49LLVLP7vV1m7uzna\npYnICIkk9OuA0rDXk4E9/bUxsyQgB2h09w53Pwjg7quALcD00y1aTs/s0lx+9hfv4ysfmcnG+lau\n/eZv+cwjqzWMg0gciCT0VwKVZlZhZinAImB5rzbLgSXB8+uBZ93dzawoOBGMmU0FKoGtQ1O6nI7E\nBOPmC6fw/N+8n09efga/fHsv7//a8/zLr9bT3Kb+fpFYNWDoB330twNPAuuBR919nZndbWbXBc3u\nBwrMrJZQN87xyzovA9aY2VuETvB+0t01OMwokjMumTuuOZPnPj+fD8+ayNIXtnLJV5/lm89s5nBH\nV7TLE5EhZqPtRF51dbXX1NREu4y4taG+hf94ahNPvbOP/IwU/vzyM7j5oimkJWu8fpHRzMxWuXv1\ngO0U+tKXt3Yd4t+f3sQLmxoozEzh1kum8ocXlpGVlhzt0kSkDwp9GRIrtzfyzWdreWFTA9lpSXzi\nfRXccnE5eRkp0S5NRMIo9GVIrak7xLeereWpd/aRnpLIjXNKueqsEuZNLdDN2UVGAYW+DIuN9a18\n+/lafrlmL109zpnjs/irK
yq5uqqEJN2gXSRqFPoyrFrbj/Gb9fv4+m82s+NgGxNy0rhpXhmL55ZR\nkJka7fJE4o5CX0ZEd4/z7Ib9PPTKdl7cfICUxASuPXcCn7i4nFmTc6NdnkjciDT0k0aiGIldiQnG\nVVUlXFVVQu3+wzz0ynZ+vKqOn7yxm9mluXzi4nKuOWc8qUm65FNkNNCRvgy51vZjPLaqjode2cG2\nA0fIS0/m98+fzOJ5ZZxRlBnt8kRikrp3JOp6epyXthzg4dd38tS6fXT1OPMq8vn4vDIWzNTRv8hQ\nUujLqNLQ2sH/rdrFstd3sbOxTUf/IkNMoS+jUn9H/4vnho7+NdyDyKlR6Muo19DawWOr6nj49Z3s\nbGwjOy2Jj5w3iT+oLmXmpJxolycypij0Zczo6XFe3XaQR1bu4ldr6+ns6mHmpGxurC7lutmTyBmn\n8X5EBqLQlzGpue0YP1u9m2Urd7F+bwupSQl86JwJ3DinlHkV+ejWyyJ9U+jLmOburN3dwrKVO1m+\neg+tHV2UF6TzB3NKuf78yRRnp0W7RJFRRaEvMeNoZze/WruXZSt38fq2RhITjEsrC/nI7ElcVVVC\nRqq+Yyii0JeYtLXhMP+3qo7lq/ew+9BRxiUnclVVCR85byKXVhaRrEHfJE4p9CWm9fQ4q3Y28bM3\nd/PLt/dyqO0YWWlJXDS1gBvnlHJJZaG+/CVxRaEvcaOzq4cXNjXwm/X7+M36/Rw43MG45EQuOqOA\ny6cXcdn0IsoL0nUSWGKaBlyTuJGSlMCVVSVcWVXC0c5uXt5ygBWbGlixqYFnN+wHoCw/ncumF3L5\n9GIuOqOATJ0HkDgV0ZG+mS0AvgEkAt9z93/tNT8VeAi4ADgI3Oju28PmlwHvAHe5+9dOtiwd6ctQ\n2n7gCC9sbuCFTQ28vOUgbZ3dJCcaF0zJ4/LpxVw+vYizJmTpU4CMeUPWvWNmicAm4CqgDlgJLHb3\nd8LafAqY5e6fNLNFwEfd/caw+T8GeoDXFPoSLR1d3aza3sSKzQ2s2NjAhvpWAIqyUrmssojLZxRx\n6bRC3f9XxqSh7N6ZC9S6+9bgjZcBCwkduR+3ELgreP4Y8C0zM3d3M/sIsBU4Moj6RYZcalIiF08r\n5OJphXzxmrPY19LOC0E30DMb9vHjN+owg1mTc7m8spC5FQWcV5arS0IlpkTy1zwJ2BX2ug6Y118b\nd+8ys2agwMyOAl8g9Cnh8/0twMxuA24DKCsri7h4kdNRkp3GDdWl3FBdSnePs6buEC9sOsCKTfv5\n1nO19DxbS2KCMXNiNnPK85lTkc+c8nzy9UlAxrBIQr+vzs7efUL9tfkycK+7Hz5Zn6m7LwWWQqh7\nJ4KaRIZUYoJxXlke55Xl8ekrK2ltP8aqHU2s3N7Iym1NPPTqDr73220ATCvOZE55PnMr8phTns+k\n3HE6JyBjRiShXweUhr2eDOzpp02dmSUBOUAjoU8E15vZV4FcoMfM2t39W6dducgwykpLZv6MYubP\nKAZC5wPermvm9e2NrNzWyONr9vDw6zsBKMlOZXZpLueV5TG7NJdZk3NIT1GXkIxOkfxlrgQqzawC\n2A0sAj7j7AJDAAALVUlEQVTeq81yYAnwCnA98KyHzhBferyBmd0FHFbgy1iUmpRIdXk+1eX5MD90\nQ/iN9a2s3N7ImzubWL3rEE+u2weEPjVML8li1qQcirNTuWBKHnMr8rUjkFFhwL/CoI/+duBJQpds\nPuDu68zsbqDG3ZcD9wM/NLNaQkf4i4azaJFoS0wwqiZmUzUxmyUXlwPQeKSTt3Yd4s2dTby56xC/\nWb+Pg0c6ATCD6il5VE3I5uyJOZwzOYdpxZkaNkJGnL6RKzKMOrq6+dXb9ayvb6FmexPv7Gnh6LFu\nIPSlshklWcyclM1ZE0KPM8dnkZWm+wfI4GkYBpFRqKfH2dHYxpq6Q6zb08K6Pc2s3d1C89F
jJ9pM\nzhvHmeNDO4AzJ2Rx5vgsygsySNKnAjkJDcMgMgolJBgVhRlUFGawcPYkIHTvgL3N7Wyob2H93lbe\n2dvCxvpWntu4n+6e0EFZSlIClcWZzBgf2gkc3ykUZaXqyiEZFIW+SJSZGRNzxzExdxwfOLPkxPT2\nY93U7j/MxvpWNu5rZUN9K7/dfICfvLH7RJu89ORgR5BNaX46GSmJVBRmMFd3GZN+KPRFRqm05ERm\nTsp5z03im450sqG+lQ31oU8EG+pbebRmF22d3SfapCQmMK04k7MmZDO1KIOzJmQxtTCTyXnj1E0U\n59SnLxIDenqc5qPHONzRxcrtjazb08La3c3sONhGfUv7iXZJCUZZQTpTgy6msoIMyvLTKctPZ1Lu\nOFKStEMYq9SnLxJHEhKMvIwU8jJSKM1P52Pn/25e89FjbN7XyrYDR971eHHzATq6ek60M4MJ2WmU\nBjuBsvx0SoNHWX46hZkp6jKKAQp9kRiXMy75d18sC9PT4zQc7mBnYxs7D7axs7GNXY1t7Gpq44XN\nDexr6XhX+3HJicGOYBwTcsZxXlku00uymJw3jtx0jUc0Vij0ReJUQoJRkp1GSXYac3rtECB0Irmu\nqe3ETmFX09ETO4bXtjbyw1d3nGhbkJHC5Px0JueNY3LeOErzjj8PdRuNS9GtK0cLhb6I9CktOZFp\nxVlMK856z7yeHmdDfSs7G9vYcfAI2w8eYVfjUd7Z08LT6/bR2d3zrvbZaUmMzwntYMZnp73n+fic\nNPLTU0hIUPfRcFPoi8igJYQNQ9FbT4+zv7WDuqY26pqOsqf5KPua26lvaae+uZ1N+1ppaO2gp9c1\nJGnJCUwvyaKyOPT9gykFoU8LE3PHUZKdpltcDhH9FkVkSCUk2Imj9+ryvtt0dfdw4HDniR1BffNR\ndjYeZf3eFl6qPcDBIx0c6373XiEjJZGS7DSKs1MpzkqjJDs1eJ1GcVboeVFWKhkpiTrhfBIKfREZ\ncUmJCSd2DO8auD3Q3ePsOXQ09Gg+yv6WDva1dLCvtZ2Glg7eqjtEfXP7u64+OvHeCUZxVipVE7NP\nnLOYkPO7LqXCzFRyxyXHbVeSQl9ERp3EBDtxuWh/3J2W9i72t7SHdggt7TQc7qDl6DHqmo6yob6F\nN3YeojEY6bT3++dnpFCQkUJRVioFGSkUZqZSkJlKQWYKRcG/hZmpFGelxtQX2hT6IjImmRk545LJ\nGZdMZcl7TzYf19HVHXQhtbOvtYODhzs4cLiDg4c7OXC4kwOHO9h+8AgHD3e+61vNxyUlhJZTkJnC\nxNxxwc4hhcKM0L8FmankpSefqCUrLZnEUfwpQqEvIjEtNSmRKQUZTCnIGLBtW2dXsDPoOLFDqGtq\no6ntGA2toU8TG+tbOXi48z1XKIXLSks6sRM4/jCDnY1tXDNzAkWZqVSWZFKUlUp+RsqI3mBHoS8i\nEkhPSSI9P+mk3UoQ6lo63PG7HURT2zGaj/7u0XL03a837z9MV3cPmWlJ/NuTG/tYbiL5GSksOHs8\nf39t1XCtHqDQFxEZNDMjKy3UlVNeOPAniOPcnW0HjnC4o4u9ze3sb+3gcHsXDa0dNLV1MiF33DBW\nHaLQFxEZIWbG1KJMAGZNjk4NsXNKWkREBhRR6JvZAjPbaGa1ZnZHH/NTzeyRYP5rZlYeTJ9rZquD\nx1tm9tGhLV9ERAZjwNA3s0TgPuAaoApYbGa9zzTcCjS5+zTgXuCeYPpaoNrdZwMLgO+ambqURESi\nJJIj/blArbtvdfdOYBmwsFebhcCDwfPHgCvMzNy9zd27gulpwOi6Y4uISJyJJPQnAbvCXtcF0/ps\nE4R8M1AAYGbzzGwd8DbwybCdwAlmdpuZ1ZhZTUNDw+DXQkREIhJJ6Pf11bLeR+z9tnH319z9bGAO\n8EUzS3tPQ/el7l7t7tVFRUURlCQiIqciktCv491DIk0
G9vTXJuizzwEawxu4+3rgCDDzVIsVEZHT\nE0norwQqzazCzFKARcDyXm2WA0uC59cDz7q7Bz+TBGBmU4AZwPYhqVxERAZtwCtp3L3LzG4HngQS\ngQfcfZ2Z3Q3UuPty4H7gh2ZWS+gIf1Hw45cAd5jZMaAH+JS7HzjZ8latWnXAzHacrM0ACoGTLiPG\nxNv6gtY5XmidB2dKJI3MPbYuqDGzGnevjnYdIyXe1he0zvFC6zw89I1cEZE4otAXEYkjsRj6S6Nd\nwAiLt/UFrXO80DoPg5jr0xcRkf7F4pG+iIj0Q6EvIhJHYib0Bxr+eawys1Ize87M1pvZOjP7dDA9\n38yeNrPNwb95wXQzs/8Mfg9rzOz86K7BqTGzRDN708weD15XBMN2bw6G8U4Jpvc5rPdYZGa5ZvaY\nmW0ItvdFcbCdPxP8Xa81s4fNLC3WtrWZPWBm+81sbdi0QW9XM1sStN9sZkv6WlYkYiL0Ixz+eazq\nAj7n7mcBFwJ/EazbHcAz7l4JPBO8htDvoDJ43AZ8Z+RLHhKfBtaHvb4HuDdY3yZCw3lD/8N6j0Xf\nAH7t7mcC5xJa/5jdzmY2CfgrQsOvzyT05c9FxN62/gGhoeXDDWq7mlk+8CVgHqGRj790fEcxaO4+\n5h/ARcCTYa+/CHwx2nUN07r+HLgK2AhMCKZNADYGz78LLA5rf6LdWHkQGt/pGeADwOOEBvQ7ACT1\n3t6Evil+UfA8KWhn0V6HU1jnbGBb79pjfDsfH503P9h2jwMfjMVtDZQDa091uwKLge+GTX9Xu8E8\nYuJIn8iGfx7zgo+z5wGvASXuvhcg+Lc4aBYLv4uvA39LaOgOCA3Tfch/Nyx3+Dr1O6z3GDMVaAC+\nH3Rrfc/MMojh7ezuu4GvATuBvYS23Spif1vD4LfrkG3vWAn9SIZ/HtPMLBP4MfDX7t5ysqZ9TBsz\nvwszuxbY7+6rwif30dQjmDeWJAHnA99x9/MIjUh7snNTY369g+6JhUAFMBHIINS90VusbeuT6W8d\nh2zdYyX0Ixn+ecwys2RCgf8jd/9JMHmfmU0I5k8A9gfTx/rv4n3AdWa2ndBd2j5A6Mg/1353q83w\ndRpwWO8xog6oc/fXgtePEdoJxOp2BrgS2ObuDe5+DPgJcDGxv61h8Nt1yLZ3rIR+JMM/j0lmZoRG\nMV3v7v8RNit8OOslhPr6j0//o+AqgAuB5uMfI8cCd/+iu09293JC2/FZd78JeI7QsN3w3vV9z7De\nI1jykHD3emCXmc0IJl0BvEOMbufATuBCM0sP/s6Pr3NMb+vAYLfrk8DVZpYXfEK6Opg2eNE+wTGE\nJ0o+BGwCtgB/F+16hnC9LiH0MW4NsDp4fIhQX+YzwObg3/ygvRG6kmkLoVtUVkd7HU5j3ecDjwfP\npwKvA7XA/wGpwfS04HVtMH9qtOs+jfWdDdQE2/pnQF6sb2fgy8AGYC3wQyA11rY18DChcxbHCB2x\n33oq2xX442Dda4FbTrUeDcMgIhJHYqV7R0REIqDQFxGJIwp9EZE4otAXEYkjCn0RkTii0BcRiSMK\nfRGROPL/AeLR50AV3uSMAAAAAElFTkSuQmCC\n", 360 | "text/plain": [ 361 | "" 362 | ] 363 | }, 364 | "metadata": {}, 365 | "output_type": "display_data" 366 | } 367 | ], 368 | "source": [ 369 | "epoch = 1000\n", 370 | "pltX, pltY = [], []\n", 371 | "for e in range(epoch):\n", 372 | " pltX.append(e)\n", 373 | " Input = cudAvl(Variable(CNNdigTensor))\n", 374 | " Target = cudAvl(Variable(CNNdigTensor))\n", 375 
| " _, Output = CNNnet(Input)\n", 376 | " loss = CNNcriterion(Output,Target)\n", 377 | " print_loss = loss.data[0]\n", 378 | " pltY.append(print_loss)\n", 379 | " CNNoptimizer.zero_grad()\n", 380 | " loss.backward()\n", 381 | " CNNoptimizer.step()\n", 382 | " if (e + 1) % 100 == 0:\n", 383 | " print('epoch [%s/%s]: %s' %(e + 1, epoch, print_loss))\n", 384 | "\n", 385 | "plt.title('loss function output curve')\n", 386 | "plt.plot(pltX, pltY)\n", 387 | "plt.show()" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 17, 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "data": { 397 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGhZJREFUeJzt3X+U3XV95/Hna2byg4QkEMKPEEKCQBBaa9iGX0eOpRVb\nAVtwV6y0p4tdAXVRtIvdUo+7slVO3XPQbrf2sEKlQKWKp2rhFKwilYMoIhGjAuE3CQkJ+UESEvJr\nJjPv/eN+U4f5fG7mzsy9d+Z+5vU4Z87ced/P/d7Pvfc9n/nO9/NLEYGZmXW+rvGugJmZNYcbdDOz\nQrhBNzMrhBt0M7NCuEE3MyuEG3Qzs0IU3aBLWiXp3FE+9iRJP5W0Q9JVza7bRCNpsaSQ1FP9/C1J\nl453vax5JH1G0mZJL493XVplsudxz3hXYAL778D9EXHqeFdkPETEeeNdB2seSQuBq4FFEbFR0mLg\nBWBKROwbz7q10mTL46LP0MdoEfD4aB64/+zAbAJZBLwSERubcTDn+MQ0GRr00yQ9IWmrpL+XNH3/\nHZLeKWmFpG2Sfijp16r4vwG/CXxB0muSlkiaI+k2SZskrZb0SUldVfn3SfqBpL+StAW4tor/F0kr\nq+f+tqRF9Sop6cyqDtsk/UzSOYPuu1/Sp6vn2CHpO5LmDbr/7EGPXSPpfVX8QHXulnR99S/488AF\nQ+pzv6TLBr2+B6vyWyW9IOm8QWWPk/RAVbfvSvpbSV8e5edloyTpGknPVZ/DE5LeVcXPBe4Fjq7y\n+Rbggeph26rYWVXZujlbXcq4UtIzwDN16uA8Hk8RUewXsAp4DFgIzAV+AHymuu8/ABuBM4Bu4NKq\n/LTq/vuBywYd6zbgTmAWsBh4Gnh/dd/7gH3AR6hdxjoIuAh4Fji5in0S+GGdei4AXgHOp/ZH9u3V\nz4cPqstzwJLq2PcDn63uOxbYAVwCTAEOA5Y2UOcPAk8Oem++BwTQM/T1V6+vD7i8eq8+BKwDVN3/\nEHA9MBU4G9gOfHm8P//J9gVcDBxd5dDvAzuB+dV95wBrB5VdPPjzrmIHzNmq/L1VvhzkPJ54X+Ne\ngRYn+Crgg4N+Ph94rrp9A/DpIeWfAn4jkwjdwF7glEFlP0DtGvv+RHlxyLG+tT/pqp+7gF3UrmEO\nreefAf8wJPZt4NJBdfnkoPv+K/Cv1e0/B76ZOeZwdf63Ie/Nbw/zi/DsoLIzqrJHVb+I+4AZg+7/\ncqf9IpT4BawALqxun8PwDfoBc7Yq/1sHeD7n8Th/TYZLLmsG3V5N7QwGatcUr67+vdsmaRu1v/JH\nDz0AMI/aX+3VQ461oM7z7D/+Xw869hZAQx4zuOzF
Q+pyNjB/UJnBIxN2AQdXtxdSO+sZaZ2PJn1v\nDuTfnz8idlU3D66Os2VQDNL3wtpA0n/WLy8hbgN+lVoeNKqRnD3QZ+s8HmeToWNj4aDbx1L7Fwtq\nH9Z1EXFdA8fYTO1ftUXAE4OO9dKgMkOXrdx//NsbOP4aamc2lzdQNvfY0zPx4eq8nvS9GY31wFxJ\nMwb9Miw80AOs+apr3TcBbwMeioh+SSuoNcg5uWVWG8nZAy3P6jweZ5PhDP1KScdImgt8Arijit8E\nfFDSGaqZKekCSbOGHiAi+oGvAddJmlX98vw3av+S1fP/gD+X9Cvw7x07F9cp+2XgdyX9TtXJM13S\nOZKOaeD13Q6cK+k9knokHSZpaQN1/hpwVfXeHApc08BzJSJiNbAcuFbS1Kpz7XdHcywbk5nUGttN\nAJL+mNoZej2bgAHgDYNiI8nZHOfxOJsMDfo/At8Bnq++PgMQEcupdY58AdhKrTPofQc4zkeodTI9\nDzxYHffmeoUj4pvA/wa+Kmk7tc7Z7JjYiFgDXEjtD84mamcrf0oDn09EvEitb+Bqav8irwDe3ECd\nb6J2ffNnwKPAN4Z7rgP4Q+Asah1gn6H2R3PvGI5nIxQRTwCfo9axtwF4E7VBAPXK7wKuA35QXR45\ncyQ5W+eYzuNxtr9316xpJN0BPBkRnxrvupiNVifm8WQ4Q7cWk3SapOMldUl6B7WztH8e73qZjUQJ\neTwZOkWt9Y6i9q/uYcBa4EMR8dPxrZLZiHV8HvuSi5lZIXzJxcysEGNq0CW9Q9JTkp6VNKrhQmYT\nkXPbOtGoL7lI6qa2psLbqV1vegS4pBo+lTVV02I6M0f1fGbD2cNOemNvvYk0DXNuv5560q622Ffs\nirsTUqO5PZZO0dOprYvwPICkr1LrFa6b9NOZyRl62xie0qy+h+O+Zh1qcua28u1F99x09YD+TZta\nXRsbpNHcHssllwW8fq2DtWTWKZF0haTlkpb3ddYYfZu8nNvWkcbSoOf+nCfXbyLixohYFhHLpjBt\nDE9n1jbObetIY7nkspbXL15zDL9c+Mqsk7U2t3OXNkbSl1Xn0siIjpHRPffQOscdSEJds5IljwAY\n2LEjiWnK1PSQfb0jq9wYZfsB+vvzhcc6lLtFn08jxnKG/ghwYrXLx1TgvcBdzamW2bhybltHGvUZ\nekTsk/RhagvjdAM3R8So9uA0m0ic29apxjT1PyLuAe5pUl3MJgzntnUizxQ1MyuEG3Qzs0J4tUWz\nNlN3dxIb0cxL1TkPizqjNho0sP21xqswvfFhmnVHk7RRDLRxEcImjGZJRuU0mB4+QzczK4QbdDOz\nQrhBNzMrhBt0M7NCuFPUrM3GvPTsQOOdjPWm6Ks7PZfr3/Zqw8cd0dT9EdS3rSbCbm11lglIcqTB\nqvoM3cysEG7QzcwK4QbdzKwQbtDNzArhBt3MrBAe5WJWsNyGEyXrPmRONq45s9Pg3vxInX0vb2hm\nlQ6sySNtfIZuZlYIN+hmZoVwg25mVgg36GZmhRhTp6ikVcAOoB/YFxHLmlGpUnSfdEI2vuWv0tjp\nR6xOYv/36Ecafq67d03Pxj/99DuT2Nw/yR+j/6lnG36+0jm3D6xn0cJsfO8bDk9iUx9bk8Q2Xpj/\n3SAzE37vnPz0+N45aYfinOfyhz3s0a1JbODnab0mjKFLAjTYd9qMUS6/GRGbm3Acs4nGuW0dxZdc\nzMwKMdYGPYDvSPqJpCuaUSGzCcK5bR1nrJdc3hIR6yQdAdwr6cmIeGBwgeqX4QqA6cwY49OZtY1z\n2zrOmM7QI2Jd9X0j8E3g9EyZGyNiWUQsm0LjG8uajSfntnWiUZ+hS5oJdEXEjur2bwN/0bSadZir\nnn0yiV0wY0XDjz9zxbvT2MZFDT8+N0oG4EdL/ymJHffxy7Nll+TDk45ze3g7f+WobPzVxVOSWO9Z\nJyaxrtO2ZR/f
0zWQxGZPy0/RX7dqXhLb9Vq+STusK3Pu2tWdLTshNuTQ6M61x3LJ5Ujgm6oNr+kB\n/jEi/nUMxzObKJzb1pFG3aBHxPPAm5tYF7MJwbltncrDFs3MCuEG3cysEF4PfYTqTef/9ra0M+jT\nT+c7Neecn06xn8PYpt3/+J46U6kzywecfH06DRpqc9xtDJLp2o2vda2e/K9isvs70DVrVrbsSNY+\n75qRDrMc2LUrX/bNJyex1f8p/9pOWJxOp//IsfclsTdN3Zh9/Lr+tF7ddea9/0m8J4lt3n5Etqx2\nZzpWW9X52YzO1lHWzWfoZmaFcINuZlYIN+hmZoVwg25mVgg36GZmhfAolxGqtwnEU5ntD8Y6cqWe\nVzMjWnJT/AGOuzudz7/kqcY3zrAGCdT9+tENuREq9R9f59xq6MgZgP6xj87Q1HRUVveidIo+wOal\nhySxQ+blR0otPnhLEjuiu/HRN4/uPi6J7ejPb96ya+/UJNazM78ZBt1tPHcdx6UDfIZuZlYIN+hm\nZoVwg25mVgg36GZmhXCn6ASRW1Jg5ccPzZZ9YelNSSy3njrAksvdAdoWATGGzsroy6/5nVNviv6I\nnm/P3iS25vcOz5btXboziX3ypO9myx7Sndbt+7uWJLGX9uZz+67Hfy2JdXXnp/7rpbSzdPb6fNn+\nJ57Oxtsq08E9tCN9vxF1qA/iM3Qzs0K4QTczK4QbdDOzQrhBNzMrxLANuqSbJW2U9Nig2FxJ90p6\npvqe7+Ewm8Cc21aaRka53AJ8AbhtUOwa4L6I+Kyka6qf/6z51SvT3vNOS2J/+jf/kMQumLEn+/g3\n/t2Hktii//nQ2Cs2+dzCJM3trqPSjSB65+RHiEydlo64WNt7WLbsWtL4A5vTEVxrtqbLCQBMeXFa\nEuval5/OP31zGjv0mfzvTG5TkJFsCNIUmQ1P6o2M6j5kzut+1vY6m2YMMewZekQ8AAxdoOFC4Nbq\n9q3ARQ09m9kE4ty20oz2GvqREbEeoPqe3/fJrPM4t61jtXxikaQrgCsAppPuF2jWqZzbNtGM9gx9\ng6T5ANX3/I6vQETcGBHLImLZFNLrY2YTjHPbOtZoz9DvAi4FPlt9v7NpNSrI0zelnZ8AL1yQTt2/\ne1c6jfmc96drmQMs+pY7QFto9Lmd6fSaqAZeTv9OHbTxmGzZPcenHXK3/MtvZctqII31HZsuMxD7\n8ueSC1akB9h6Yr5DcMbGtOy0Va9ky+48+41JbPq9P82WHe20+1GpkzP9214dUqyxZSUaGbb4FeAh\n4CRJayW9n1qyv13SM8Dbq5/NOopz20oz7Bl6RFxS5663NbkuZm3l3LbSeKaomVkh3KCbmRXCDbqZ\nWSG8wUWTvHpPOr05txEFwFXr0tEvTy3rS2LT8OYU1hpdQ6aWA/Sls+MB6OlJR1jEkvy0+b7etEmZ\nNSMd5bJ7z5Ts4/ccMjOJDdQZEbrthNz56Pxs2RnrcyNt2jiapU18hm5mVgg36GZmhXCDbmZWCDfo\nZmaFcKdok/xo6T8lsdx0foBn/3BRLtrkGtmk05WZIh+ZufjAwI7Xktj8H6QdhwCvrpudxHadl+8U\nPXTOziT22o8OT2Kamp/y/mo6toCp27NF2XVs2qnZPy3fpM1+YncSy78zY5dbex3as/66z9DNzArh\nBt3MrBBu0M3MCuEG3cysEO4UbZLcxs1PXnZDtuwF30s7UEfizBXvbrhs351ph9S8L3o99cmia0Z+\nJyVl4n2z8+uOD0xNY9OnpjOb69l7WNr9GF35TlHN7U1iu7dlKgAsekO6pvvaWYdmy3btbF+n6Ig6\nP5XfAHu0a+v7DN3MrBBu0M3MCuEG3cysEG7QzcwK0cieojdL2ijpsUGxayW9JGlF9XV+a6tp1nzO\nbSuNYpjeVElvBV4DbouIX61i1wKvRcT1I3my2ZobZ8jbNW7+wFlJbOeCtNwhyz
Y1fMzc0gP11Bsl\nM+f8zl5+4OG4j+2xpc6wgdRkzm1NSUeOdB+TX0s8NxJjYHZ+9Ez/jHSd855n1qWH7M6fS8bcdJ32\nV5Ydli171kfT/QIO7s4vX/DoH5ycxPqfeDpbdiJqNLeHPUOPiAeALU2pldkE4ty20ozlGvqHJf28\n+rc1P/jTrDM5t60jjbZBvwE4HlgKrAc+V6+gpCskLZe0vI/8v0NmE4hz2zrWqBr0iNgQEf0RMQDc\nBJx+gLI3RsSyiFg2hTqbA5pNEM5t62SjmvovaX5ErK9+fBfw2IHK2+vlpt7PG+Mxf4el2fhJy9NO\nqnodqOeflHaW9j/V2R2lIzVZcjv2pVP3BzZubvjxA8+n654D5Hrt0i2mD3Dck45OYr2z832BFx3y\naBLbE/nNpx+e++tJrGcc1y0fTtJp3ddYX/+wDbqkrwDnAPMkrQU+BZwjaSkQwCrgAyOprNlE4Ny2\n0gzboEfEJZnwl1pQF7O2cm5baTxT1MysEG7QzcwK4QbdzKwQ3uCicN+5Z1kavCydMg2w4ZzMZhiT\nbJTLZKHudDOLgZ35kSvt1LM13Yhiz9yDsmUX9WxPYn3ZcTaweWm6VMERvYvzlfjxL+pXsE2ib8hG\nHw1ueOEzdDOzQrhBNzMrhBt0M7NCuEE3MyuEO0UL0X3SCdn45/7g7xs+Rm5JAuts6sn/imtquh56\n7NvX8HG7D5ubP+6c2Uls3/OrkljPMZkNAICXz0oXt+xdknaUAgzkHt8/M1t25vp0AYKeFzdmyzb+\nLkw8PkM3MyuEG3Qzs0K4QTczK4QbdDOzQrhBNzMrhEe5FOLKu/8lG79gxp4k9uv/60PZsvPwKJdO\n0LMg3QQCYOCVdL/rgT3p5w/QdXBmNMiu/PPlRsr0Z54LoOegdJr+7gvTTZ96D65zLvnOV5LQWw9f\nnykIVz73+0nsmXVHZMue9ERa330vb8jXYQS6Zqbv48Cu/BuZW25hJCOLGqpPU49mZmbjxg26mVkh\n3KCbmRXCDbqZWSEa2SR6IXAbcBS12bY3RsRfS5oL3AEspraZ7nsiYmvrqjr51JvOf8Ltq5NYrvMT\n4MwV705inuJf06m5PbA9vyt9vQ7QrKG7ygNdM9I1wwGU6fjrOyk/dX/N6ekx9mWWM983K7++94Lp\ne5NYj3KT/GHDjllJLLakrwsgpqQdks0QvX2Nl+1Plx9otkbO0PcBV0fEycCZwJWSTgGuAe6LiBOB\n+6qfzTqJc9uKMmyDHhHrI+LR6vYOYCWwALgQuLUqditwUasqadYKzm0rzYiuoUtaDJwKPAwcGRHr\nofaLAWQHgEq6QtJyScv7SP+dMpsInNtWgoYbdEkHA18HPhYR6WZ+dUTEjRGxLCKWTWHaaOpo1lLO\nbStFQw26pCnUEv72iPhGFd4gaX51/3wgv7iw2QTm3LaSNDLKRcCXgJUR8flBd90FXAp8tvp+Z0tq\nOEk8fdNpSeyFC25q+PHH3X15Nr7k8kdGXafSjVtuK7MzfYO7ugMM7MiPchmJga3poJ1609C7p05J\nYruPyv9HsvPYdETKdefdkcRW987LPv67G9+YxO5/7sRs2Tn3pcNn3vjgpmxZXmnNIKXo6224rDIj\ni+o+fmiONJgejazl8hbgj4BfSFpRxT5BLdm/Jun9wIvAxY09pdmE4dy2ogzboEfEg0DmlAKAtzW3\nOmbt49y20nimqJlZIdygm5kVwuuhN0lumv7Kj6c7mEO9zs4VSSQ3bR9g+t+kx13yLXd+dowRdICO\nWVd+yvvec09NYgf9ZFW27OrL0k7Jvje/li37x6f8OInN6t6dVqvOdP6131+YxOY9nS97yOPbklj/\nU89my3YvOT6N1ekE7t/2ajY+ViPpQB1tjvgM3cysEG7QzcwK4QbdzKwQbtDNzArhBt3MrBAe5TJC\nq//irGz8yctuaPgYV61Lp/n//H8sTWJzPH
KlsyTTtds4mqWOOOtN2fieuenol1XXLc6W/cvf+EoS\ne++s/FT6rf3pjveXPv8fk9hjP1uUffyJ1/4wG8/Jj32pU/aFNUks9jW+OUWn8Bm6mVkh3KCbmRXC\nDbqZWSHcoJuZFcKdogeQ6wCt1/mZm6bfd+fh2bLzvvhQEpuGO0A73gToBB1K/fmuw+7etK6HPZJv\nDr543FuT2D0H5ztFH9s0P4lNvSNdquLE23+UfbyNjc/QzcwK4QbdzKwQbtDNzArhBt3MrBDDNuiS\nFkr6nqSVkh6X9NEqfq2klyStqL7Ob311zZrHuW2lUQzTMy9pPjA/Ih6VNAv4CXAR8B7gtYi4vtEn\nm625cYa8VaO1xsNxH9tjS709QhNNz+2uc18fbNGol67p07Pxgd50Kru68m9H9+Hzktirb8lPx995\nZLpMwN65+bpl9rLgmC88msQG9uzJH2AkMpt3dE2fli2qGQclsf5XtuSP26LPTT3pKKKos8nGUI3m\ndiObRK8H1le3d0haCSxoqBZmE5hz20ozomvokhYDpwIPV6EPS/q5pJslZfdbk3SFpOWSlvexd0yV\nNWsV57aVoOEGXdLBwNeBj0XEduAG4HhgKbWznM/lHhcRN0bEsohYNoX8v0Nm48m5baVoqEGXNIVa\nwt8eEd8AiIgNEdEfEQPATcDpraumWWs4t60kw15DlyTgS8DKiPj8oPj86hokwLuAx1pTRbPWaHpu\nt2nq/0g6FDV9Rv6OaVOT0JxH1mWLzpyfXnHafVS+Y3ba1rRjtikdoDkD/Uko+up0Mu7O1KHNSzVE\nf1rfZmtkLZe3AH8E/ELSiir2CeASSUuBAFYBH2hJDc1ax7ltRWlklMuDQG64zD3Nr45Z+zi3rTSe\nKWpmVgg36GZmhXCDbmZWCG9wYVawgV278vFVL6bBzFR6ANako18OyowwgfyyBPktNloj+nobLlt3\nCYVWjcppw6gan6GbmRXCDbqZWSHcoJuZFcINuplZIYZdD72pTyZtAlZXP84DNrftydvHr2v8LIqI\nw8fjiQfldie8T6NV6mvrhNfVUG63tUF/3RNLyyNi2bg8eQv5dU1uJb9Ppb62kl6XL7mYmRXCDbqZ\nWSHGs0G/cRyfu5X8uia3kt+nUl9bMa9r3K6hm5lZc/mSi5lZIdreoEt6h6SnJD0r6Zp2P38zVRsI\nb5T02KDYXEn3Snqm+p7dYHgik7RQ0vckrZT0uKSPVvGOf22tVEpuO68777Xt19YGXVI38LfAecAp\n1HaGOaWddWiyW4B3DIldA9wXEScC91U/d5p9wNURcTJwJnBl9TmV8NpaorDcvgXndUdq9xn66cCz\nEfF8RPQCXwUubHMdmiYiHgC2DAlfCNxa3b4VuKitlWqCiFgfEY9Wt3cAK4EFFPDaWqiY3HZed95r\n26/dDfoCYM2gn9dWsZIcuX+D4er7EeNcnzGRtBg4FXiYwl5bk5We20V99qXmdbsb9Nz+jR5mM0FJ\nOhj4OvCxiNg+3vWZ4JzbHaLkvG53g74WWDjo52OAdPX8zrZB0nyA6vvGca7PqEiaQi3pb4+Ib1Th\nIl5bi5Se20V89qXndbsb9EeAEyUdJ2kq8F7grjbXodXuAi6tbl8K3DmOdRkVSQK+BKyMiM8Puqvj\nX1sLlZ7bHf/ZT4a8bvvEIknnA/8H6AZujojr2lqBJpL0FeAcaqu1bQA+Bfwz8DXgWOBF4OKIGNrB\nNKFJOhv4PvALfrmD2CeoXW/s6NfWSqXktvO6817bfp4pamZWCM8UNTMrhBt0M7NCuEE3MyuEG3Qz\ns0K4QTczK4QbdDOzQrhBNzMrhBt0M7NC/H9b2nahEsvkewAAAABJRU5ErkJggg==\n", 398 | "text/plain": [ 399 | "" 400 | ] 401 | }, 402 | "metadata": {}, 403 | "output_type": 
"display_data" 404 | } 405 | ], 406 | "source": [ 407 | "CNNnet.eval()\n", 408 | "compare(CNNdigTensor[index], CNNnet)" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": { 415 | "collapsed": true 416 | }, 417 | "outputs": [], 418 | "source": [] 419 | } 420 | ], 421 | "metadata": { 422 | "kernelspec": { 423 | "display_name": "Python 3", 424 | "language": "python", 425 | "name": "python3" 426 | }, 427 | "language_info": { 428 | "codemirror_mode": { 429 | "name": "ipython", 430 | "version": 3 431 | }, 432 | "file_extension": ".py", 433 | "mimetype": "text/x-python", 434 | "name": "python", 435 | "nbconvert_exporter": "python", 436 | "pygments_lexer": "ipython3", 437 | "version": "3.6.2" 438 | } 439 | }, 440 | "nbformat": 4, 441 | "nbformat_minor": 2 442 | } 443 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # learning pyTorch with SherlockLiao 2 | 3 | pyTorch初学者,发现了廖星宇编著的书籍《深度学习入门之PyTorch》。以及最近发现了新的写python的神器jupyter。于是用jupyter手撸一遍其中的代码以学习。以本书入门pyTorch, 同时学会git. 4 | 5 | 6 | 7 | --------------------------------------------------------------------------------