├── .gitignore ├── 01_the_foundation_of_deep_learning.ipynb ├── 02_pytorch_tutorial.ipynb ├── 03_tensorflow_tutorial.ipynb ├── 04_neural_network_foundation.ipynb ├── 05_neural_network.ipynb ├── 06_deep_neural_network.ipynb ├── 08_cnn_pytorch.ipynb ├── 08_cnn_tensorflow.ipynb ├── 09_rnn_pytorch.ipynb ├── 09_rnn_tensorflow.ipynb ├── LICENSE ├── README.md ├── datasets ├── ch01 │ └── cat.jpg ├── ch03 │ └── 3-3.png ├── ch06 │ ├── test.rar │ └── train.rar ├── ch08 │ ├── pytorch │ │ └── MNIST │ │ │ ├── processed │ │ │ ├── test.pt │ │ │ └── training.pt │ │ │ └── raw │ │ │ ├── t10k-images-idx3-ubyte │ │ │ ├── t10k-labels-idx1-ubyte │ │ │ ├── train-images-idx3-ubyte │ │ │ └── train-labels-idx1-ubyte │ └── tensorflow │ │ └── MNIST │ │ ├── t10k-images-idx3-ubyte.gz │ │ ├── t10k-labels-idx1-ubyte.gz │ │ ├── train-images-idx3-ubyte.gz │ │ └── train-labels-idx1-ubyte.gz └── readme.md ├── tf_logs ├── ch03 │ └── run-20190704075350 │ │ └── events.out.tfevents.1562226830.AB-201810292038 └── readme.md └── 微信公众号:AI有道.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /02_pytorch_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 第 2 章:PyTorch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "data": { 17 | "text/plain": [ 18 | "'1.0.1'" 19 | ] 20 | }, 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "output_type": "execute_result" 24 | } 25 | ], 26 | "source": [ 27 | "# 导入PyTorch库\n", 28 | "import torch\n", 29 | "import torchvision\n", 30 | "\n", 31 | "# 查看安装的PyTorch版本\n", 32 | "torch.__version__" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## 张量 Tensor" 40 | ] 
41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### 创建 Tensor" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "创建一个随机初始化的 Tensor" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "tensor([[-0.3249, 0.4365],\n", 66 | " [ 0.3976, -0.4804]])\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "x = torch.randn(2,2)\n", 72 | "print(x)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "直接把 Python 列表构建成 Tensor" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 3, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "tensor([[1, 2],\n", 92 | " [3, 4]])\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "x = torch.tensor([[1, 2], [3, 4]])\n", 98 | "print(x)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "创建一个全零 Tensor" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "tensor([[0., 0.],\n", 118 | " [0., 0.]])\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "x = torch.zeros(2,2)\n", 124 | "print(x)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "基于现有的 Tensor 创建新的 Tensor" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "tensor([[0., 0.],\n", 144 | " [0., 0.]])\n", 145 | "tensor([[1., 1.],\n", 146 | " [1., 1.]])\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "x = torch.zeros(2,2)\n", 152 | "y = 
torch.ones_like(x)\n", 153 | "print(x)\n", 154 | "print(y)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "指定 Tensor 数据类型" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "tensor([[1, 1],\n", 174 | " [1, 1]])\n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "x = torch.ones(2, 2, dtype=torch.long)\n", 180 | "print(x)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Tensor 的数学运算" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "两个 Tensor 相加" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "tensor([[2., 2.],\n", 207 | " [2., 2.]])\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "x = torch.ones(2,2)\n", 213 | "y = torch.ones(2,2)\n", 214 | "z = x + y\n", 215 | "print(z)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "也可以使用torch.add()实现Tensor相加:" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 8, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "tensor([[2., 2.],\n", 235 | " [2., 2.]])\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "x = torch.ones(2,2)\n", 241 | "y = torch.ones(2,2)\n", 242 | "z = torch.add(x, y)\n", 243 | "print(z)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "还可以使用.add_()进行原地(in-place)相加,结果直接写回调用者:" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 9, 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "name": "stdout", 260 | "output_type":
"stream", 261 | "text": [ 262 | "tensor([[2., 2.],\n", 263 | " [2., 2.]])\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "x = torch.ones(2,2)\n", 269 | "y = torch.ones(2,2)\n", 270 | "y.add_(x)\n", 271 | "print(y)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "Tensor乘法有两种形式,第一种是对应元素相乘:" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 10, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "tensor([[ 1, 4],\n", 290 | " [ 9, 16]])" 291 | ] 292 | }, 293 | "execution_count": 10, 294 | "metadata": {}, 295 | "output_type": "execute_result" 296 | } 297 | ], 298 | "source": [ 299 | "x = torch.tensor([[1, 2], [3, 4]])\n", 300 | "y = torch.tensor([[1, 2], [3, 4]])\n", 301 | "x.mul(y)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "第二种更常用的是矩阵相乘:" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 11, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "tensor([[ 7, 10],\n", 320 | " [15, 22]])" 321 | ] 322 | }, 323 | "execution_count": 11, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "x = torch.tensor([[1, 2], [3, 4]])\n", 330 | "y = torch.tensor([[1, 2], [3, 4]])\n", 331 | "x.mm(y)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "### Tensor 与 NumPy" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 12, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "# 导入numpy\n", 348 | "import numpy as np" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "#### Tensor to NumPy" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 13, 361 | "metadata": {}, 362 | "outputs": [ 363 | { 364 | "name": "stdout", 365 |
"output_type": "stream", 366 | "text": [ 367 | "\n", 368 | "\n" 369 | ] 370 | } 371 | ], 372 | "source": [ 373 | "a = torch.ones(2,2)\n", 374 | "b = a.numpy()\n", 375 | "print(type(a))\n", 376 | "print(type(b))" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "此时,如果Tensor发生改变,对应的NumPy数组也有相同的变化。" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 14, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "tensor([[2., 2.],\n", 396 | " [2., 2.]])\n", 397 | "[[2. 2.]\n", 398 | " [2. 2.]]\n" 399 | ] 400 | } 401 | ], 402 | "source": [ 403 | "a.add_(1)\n", 404 | "print(a)\n", 405 | "print(b)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "#### NumPy to Tensor" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 15, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stdout", 422 | "output_type": "stream", 423 | "text": [ 424 | "\n", 425 | "\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "a = np.array([[1, 1], [1, 1]])\n", 431 | "b = torch.from_numpy(a)\n", 432 | "print(type(a))\n", 433 | "print(type(b))" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "如果NumPy数组发生改变,对应的Tensor也有相同的变化。" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 16, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "[[2 2]\n", 453 | " [2 2]]\n", 454 | "tensor([[2, 2],\n", 455 | " [2, 2]], dtype=torch.int32)\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "np.add(a, 1, out=a)\n", 461 | "print(a)\n", 462 | "print(b)" 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": {}, 468 | "source": [ 469 | "### CUDA Tensor" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | 
"execution_count": 17, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "a = torch.ones(2,2)\n", 479 | "# 检查是否可以使用GPU\n", 480 | "if torch.cuda.is_available():\n", 481 | " a_cuda = a.cuda()\n", 482 | " print(a_cuda)" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "metadata": {}, 488 | "source": [ 489 | "因为我们安装的是CPU版本的PyTorch,所以这里不会执行if语句。如果安装了GPU版本的PyTorch并且有可用的GPU,a_cuda的打印结果如下:\n", 490 | "\n", 491 | "```\n", 492 | "tensor([[1., 1.],\n", 493 | " [1., 1.]], device='cuda:0')\n", 494 | "```" 495 | ] 496 | }, 497 | { 498 | "cell_type": "markdown", 499 | "metadata": {}, 500 | "source": [ 501 | "## 自动求导 autograd" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | "定义 Tensor x 和 y,设置参数`requires_grad=True`" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 18, 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "True\n", 521 | "True\n" 522 | ] 523 | } 524 | ], 525 | "source": [ 526 | "x = torch.ones(2, 2, requires_grad=True)\n", 527 | "y = torch.ones(2, 2, requires_grad=True)\n", 528 | "print(x.requires_grad)\n", 529 | "print(y.requires_grad)" 530 | ] 531 | }, 532 | { 533 | "cell_type": "markdown", 534 | "metadata": {}, 535 | "source": [ 536 | "### 当输出是标量时" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": {}, 542 | "source": [ 543 | "定义输出$z=\\frac14\\sum_i(x_i+y_i)$" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": 19, 549 | "metadata": {}, 550 | "outputs": [ 551 | { 552 | "name": "stdout", 553 | "output_type": "stream", 554 | "text": [ 555 | "tensor(2., grad_fn=)\n" 556 | ] 557 | } 558 | ], 559 | "source": [ 560 | "z = x + y\n", 561 | "z = z.mean()\n", 562 | "print(z)" 563 | ] 564 | }, 565 | { 566 | "cell_type": "markdown", 567 | "metadata": {}, 568 | "source": [ 569 | "反向传播" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": 20,
575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "z.backward()" 579 | ] 580 | }, 581 | { 582 | "cell_type": "markdown", 583 | "metadata": {}, 584 | "source": [ 585 | "计算$\\frac{\\partial z}{\\partial x}$和$\\frac{\\partial z}{\\partial y}$" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 21, 591 | "metadata": {}, 592 | "outputs": [ 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "tensor([[0.2500, 0.2500],\n", 598 | " [0.2500, 0.2500]])\n" 599 | ] 600 | } 601 | ], 602 | "source": [ 603 | "print(x.grad)" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 22, 609 | "metadata": {}, 610 | "outputs": [ 611 | { 612 | "name": "stdout", 613 | "output_type": "stream", 614 | "text": [ 615 | "tensor([[0.2500, 0.2500],\n", 616 | " [0.2500, 0.2500]])\n" 617 | ] 618 | } 619 | ], 620 | "source": [ 621 | "print(y.grad)" 622 | ] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": {}, 627 | "source": [ 628 | "### 当输出是多维张量时" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "metadata": {}, 634 | "source": [ 635 | "定义输出 $z=2x+3y$" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": 23, 641 | "metadata": {}, 642 | "outputs": [ 643 | { 644 | "name": "stdout", 645 | "output_type": "stream", 646 | "text": [ 647 | "tensor([[5., 5.],\n", 648 | " [5., 5.]], grad_fn=)\n" 649 | ] 650 | } 651 | ], 652 | "source": [ 653 | "x = torch.ones(2, 2, requires_grad=True)\n", 654 | "y = torch.ones(2, 2, requires_grad=True)\n", 655 | "z = 2 * x + 3 * y\n", 656 | "print(z)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "markdown", 661 | "metadata": {}, 662 | "source": [ 663 | "反向传播" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": 24, 669 | "metadata": {}, 670 | "outputs": [], 671 | "source": [ 672 | "z.backward(torch.ones_like(z))" 673 | ] 674 | }, 675 | { 676 | "cell_type": "markdown", 677 | "metadata": {}, 678 | "source": 
[ 679 | "计算$\\frac{\\partial z}{\\partial x}$和$\\frac{\\partial z}{\\partial y}$" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 25, 685 | "metadata": {}, 686 | "outputs": [ 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "tensor([[2., 2.],\n", 692 | " [2., 2.]])\n" 693 | ] 694 | } 695 | ], 696 | "source": [ 697 | "print(x.grad)" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": 26, 703 | "metadata": {}, 704 | "outputs": [ 705 | { 706 | "name": "stdout", 707 | "output_type": "stream", 708 | "text": [ 709 | "tensor([[3., 3.],\n", 710 | " [3., 3.]])\n" 711 | ] 712 | } 713 | ], 714 | "source": [ 715 | "print(y.grad)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": {}, 721 | "source": [ 722 | "### 禁止自动求导" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 27, 728 | "metadata": {}, 729 | "outputs": [ 730 | { 731 | "name": "stdout", 732 | "output_type": "stream", 733 | "text": [ 734 | "True\n", 735 | "True\n", 736 | "False\n" 737 | ] 738 | } 739 | ], 740 | "source": [ 741 | "print(x.requires_grad)\n", 742 | "print((2 * x).requires_grad)\n", 743 | "\n", 744 | "with torch.no_grad():\n", 745 | " print((2 * x).requires_grad)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "markdown", 750 | "metadata": {}, 751 | "source": [ 752 | "## 神经网络包 nn 和优化器 optim" 753 | ] 754 | }, 755 | { 756 | "cell_type": "markdown", 757 | "metadata": {}, 758 | "source": [ 759 | "### torch.nn" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 28, 765 | "metadata": {}, 766 | "outputs": [], 767 | "source": [ 768 | "import torch.nn as nn" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": 29, 774 | "metadata": {}, 775 | "outputs": [], 776 | "source": [ 777 | "class net_name(nn.Module):\n", 778 | " def __init__(self):\n", 779 | " super(net_name, self).__init__()\n", 780 | " self.fc = nn.Linear(1, 1)\n", 781 | " 
# 其它层\n", 782 | " \n", 783 | " def forward(self, x):\n", 784 | " out = self.fc(x)\n", 785 | " return out" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "新建一个该模型的对象" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 30, 798 | "metadata": {}, 799 | "outputs": [], 800 | "source": [ 801 | "net = net_name()" 802 | ] 803 | }, 804 | { 805 | "cell_type": "markdown", 806 | "metadata": {}, 807 | "source": [ 808 | "### torch.optim" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": 31, 814 | "metadata": {}, 815 | "outputs": [], 816 | "source": [ 817 | "import torch.optim as optim" 818 | ] 819 | }, 820 | { 821 | "cell_type": "markdown", 822 | "metadata": {}, 823 | "source": [ 824 | "计算预测值与真实值的均方误差" 825 | ] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "metadata": {}, 830 | "source": [ 831 | "```\n", 832 | "criterion = nn.MSELoss()\n", 833 | "loss = criterion(output, target)\n", 834 | "```" 835 | ] 836 | }, 837 | { 838 | "cell_type": "markdown", 839 | "metadata": {}, 840 | "source": [ 841 | "使用随机梯度下降(SGD)优化" 842 | ] 843 | }, 844 | { 845 | "cell_type": "code", 846 | "execution_count": null, 847 | "metadata": {}, 848 | "outputs": [], 849 | "source": [] 850 | }, 851 | { 852 | "cell_type": "markdown", 853 | "metadata": {}, 854 | "source": [ 855 | "```\n", 856 | "optimizer = optim.SGD(net.parameters(), lr=0.01)\n", 857 | "```" 858 | ] 859 | }, 860 | { 861 | "cell_type": "markdown", 862 | "metadata": {}, 863 | "source": [ 864 | "单次迭代对应的代码为" 865 | ] 866 | }, 867 | { 868 | "cell_type": "markdown", 869 | "metadata": {}, 870 | "source": [ 871 | "```\n", 872 | "optimizer.zero_grad() # 梯度清零\n", 873 | "output = net(input)\n", 874 | "loss = criterion(output, target)\n", 875 | "loss.backward()\n", 876 | "optimizer.step() # 完成更新\n", 877 | "\n", 878 | "```" 879 | ] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": {}, 884 | "source": [ 885 | "## PyTorch 线性回归" 886 
| ] 887 | }, 888 | { 889 | "cell_type": "markdown", 890 | "metadata": {}, 891 | "source": [ 892 | "### 创建数据集" 893 | ] 894 | }, 895 | { 896 | "cell_type": "code", 897 | "execution_count": 32, 898 | "metadata": {}, 899 | "outputs": [], 900 | "source": [ 901 | "# y=3x+10,后面加上torch.randn()函数制造噪音\n", 902 | "x = torch.unsqueeze(torch.linspace(-1, 1, 50), dim=1)\n", 903 | "y = 3 * x + 10 + 0.5 * torch.randn(x.size())" 904 | ] 905 | }, 906 | { 907 | "cell_type": "markdown", 908 | "metadata": {}, 909 | "source": [ 910 | "显示数据分布" 911 | ] 912 | }, 913 | { 914 | "cell_type": "code", 915 | "execution_count": 33, 916 | "metadata": {}, 917 | "outputs": [ 918 | { 919 | "data": { 920 | "text/plain": [ 921 | "
" 922 | ] 923 | }, 924 | "metadata": {}, 925 | "output_type": "display_data" 926 | } 927 | ], 928 | "source": [ 929 | "import matplotlib.pyplot as plt\n", 930 | "plt.rcParams['font.sans-serif']=['SimHei']\n", 931 | "plt.rcParams['axes.unicode_minus']=False\n", 932 | "plt.rcParams['figure.figsize'] = (10.0, 6.0) # set default size of plots\n", 933 | "\n", 934 | "plt.scatter(x.numpy(), y.numpy())\n", 935 | "plt.show()" 936 | ] 937 | }, 938 | { 939 | "cell_type": "markdown", 940 | "metadata": {}, 941 | "source": [ 942 | "### 定义模型" 943 | ] 944 | }, 945 | { 946 | "cell_type": "code", 947 | "execution_count": 34, 948 | "metadata": {}, 949 | "outputs": [], 950 | "source": [ 951 | "class LinearRegression(nn.Module):\n", 952 | " def __init__(self):\n", 953 | " super(LinearRegression, self).__init__()\n", 954 | " self.fc = nn.Linear(1, 1)\n", 955 | " \n", 956 | " def forward(self, x):\n", 957 | " out = self.fc(x)\n", 958 | " return out" 959 | ] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": 35, 964 | "metadata": {}, 965 | "outputs": [], 966 | "source": [ 967 | "model = LinearRegression()" 968 | ] 969 | }, 970 | { 971 | "cell_type": "markdown", 972 | "metadata": {}, 973 | "source": [ 974 | "### 定义 loss 和优化函数" 975 | ] 976 | }, 977 | { 978 | "cell_type": "code", 979 | "execution_count": 36, 980 | "metadata": {}, 981 | "outputs": [], 982 | "source": [ 983 | "# 定义loss和优化函数\n", 984 | "criterion = nn.MSELoss()\n", 985 | "optimizer = optim.SGD(model.parameters(), lr=5e-3)" 986 | ] 987 | }, 988 | { 989 | "cell_type": "markdown", 990 | "metadata": {}, 991 | "source": [ 992 | "### 训练" 993 | ] 994 | }, 995 | { 996 | "cell_type": "code", 997 | "execution_count": 37, 998 | "metadata": {}, 999 | "outputs": [ 1000 | { 1001 | "name": "stdout", 1002 | "output_type": "stream", 1003 | "text": [ 1004 | "Epoch[20/1000], loss: 58.123623\n", 1005 | "Epoch[40/1000], loss: 39.301388\n", 1006 | "Epoch[60/1000], loss: 26.664478\n", 1007 | "Epoch[80/1000], loss: 18.171267\n", 1008 
| "Epoch[100/1000], loss: 12.455207\n", 1009 | "Epoch[120/1000], loss: 8.601423\n", 1010 | "Epoch[140/1000], loss: 5.997334\n", 1011 | "Epoch[160/1000], loss: 4.232640\n", 1012 | "Epoch[180/1000], loss: 3.032410\n", 1013 | "Epoch[200/1000], loss: 2.212354\n", 1014 | "Epoch[220/1000], loss: 1.648835\n", 1015 | "Epoch[240/1000], loss: 1.258876\n", 1016 | "Epoch[260/1000], loss: 0.986709\n", 1017 | "Epoch[280/1000], loss: 0.794804\n", 1018 | "Epoch[300/1000], loss: 0.657871\n", 1019 | "Epoch[320/1000], loss: 0.558824\n", 1020 | "Epoch[340/1000], loss: 0.486085\n", 1021 | "Epoch[360/1000], loss: 0.431788\n", 1022 | "Epoch[380/1000], loss: 0.390558\n", 1023 | "Epoch[400/1000], loss: 0.358707\n", 1024 | "Epoch[420/1000], loss: 0.333685\n", 1025 | "Epoch[440/1000], loss: 0.313713\n", 1026 | "Epoch[460/1000], loss: 0.297540\n", 1027 | "Epoch[480/1000], loss: 0.284271\n", 1028 | "Epoch[500/1000], loss: 0.273265\n", 1029 | "Epoch[520/1000], loss: 0.264049\n", 1030 | "Epoch[540/1000], loss: 0.256270\n", 1031 | "Epoch[560/1000], loss: 0.249662\n", 1032 | "Epoch[580/1000], loss: 0.244020\n", 1033 | "Epoch[600/1000], loss: 0.239182\n", 1034 | "Epoch[620/1000], loss: 0.235021\n", 1035 | "Epoch[640/1000], loss: 0.231432\n", 1036 | "Epoch[660/1000], loss: 0.228331\n", 1037 | "Epoch[680/1000], loss: 0.225646\n", 1038 | "Epoch[700/1000], loss: 0.223320\n", 1039 | "Epoch[720/1000], loss: 0.221302\n", 1040 | "Epoch[740/1000], loss: 0.219550\n", 1041 | "Epoch[760/1000], loss: 0.218029\n", 1042 | "Epoch[780/1000], loss: 0.216707\n", 1043 | "Epoch[800/1000], loss: 0.215558\n", 1044 | "Epoch[820/1000], loss: 0.214558\n", 1045 | "Epoch[840/1000], loss: 0.213690\n", 1046 | "Epoch[860/1000], loss: 0.212934\n", 1047 | "Epoch[880/1000], loss: 0.212276\n", 1048 | "Epoch[900/1000], loss: 0.211704\n", 1049 | "Epoch[920/1000], loss: 0.211207\n", 1050 | "Epoch[940/1000], loss: 0.210774\n", 1051 | "Epoch[960/1000], loss: 0.210397\n", 1052 | "Epoch[980/1000], loss: 0.210070\n", 1053 | 
"Epoch[1000/1000], loss: 0.209784\n" 1054 | ] 1055 | } 1056 | ], 1057 | "source": [ 1058 | "num_epochs = 1000 # 遍历整个训练集的次数\n", 1059 | "for epoch in range(num_epochs):\n", 1060 | " # forward\n", 1061 | " out = model(x) #前向传播\n", 1062 | " loss = criterion(out, y) #计算loss\n", 1063 | " # backward\n", 1064 | " optimizer.zero_grad() #梯度归零\n", 1065 | " loss.backward() #反向传播\n", 1066 | " optimizer.step() #更新参数\n", 1067 | " \n", 1068 | " if (epoch+1) % 20 == 0:\n", 1069 | " print('Epoch[{}/{}], loss: {:.6f}'.format(epoch+1, num_epochs, loss.detach().numpy()))" 1070 | ] 1071 | }, 1072 | { 1073 | "cell_type": "markdown", 1074 | "metadata": {}, 1075 | "source": [ 1076 | "### 模型测试" 1077 | ] 1078 | }, 1079 | { 1080 | "cell_type": "code", 1081 | "execution_count": 38, 1082 | "metadata": {}, 1083 | "outputs": [ 1084 | { 1085 | "data": { 1086 | "image/png": "\n", 1087 | "text/plain": [ 1088 | "
" 1089 | ] 1090 | }, 1091 | "metadata": { 1092 | "needs_background": "light" 1093 | }, 1094 | "output_type": "display_data" 1095 | } 1096 | ], 1097 | "source": [ 1098 | "model.eval()\n", 1099 | "y_hat = model(x)\n", 1100 | "plt.scatter(x.numpy(), y.numpy(), label='原始数据')\n", 1101 | "plt.plot(x.numpy(), y_hat.detach().numpy(), c='r', label='拟合直线')\n", 1102 | "# 显示图例\n", 1103 | "plt.legend() \n", 1104 | "plt.show()" 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "markdown", 1109 | "metadata": {}, 1110 | "source": [ 1111 | "查看参数" 1112 | ] 1113 | }, 1114 | { 1115 | "cell_type": "code", 1116 | "execution_count": 39, 1117 | "metadata": {}, 1118 | "outputs": [ 1119 | { 1120 | "data": { 1121 | "text/plain": [ 1122 | "[('fc.weight', Parameter containing:\n", 1123 | " tensor([[2.7996]], requires_grad=True)), ('fc.bias', Parameter containing:\n", 1124 | " tensor([9.9626], requires_grad=True))]" 1125 | ] 1126 | }, 1127 | "execution_count": 39, 1128 | "metadata": {}, 1129 | "output_type": "execute_result" 1130 | } 1131 | ], 1132 | "source": [ 1133 | "list(model.named_parameters())" 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "execution_count": null, 1139 | "metadata": {}, 1140 | "outputs": [], 1141 | "source": [] 1142 | } 1143 | ], 1144 | "metadata": { 1145 | "kernelspec": { 1146 | "display_name": "Python 3", 1147 | "language": "python", 1148 | "name": "python3" 1149 | }, 1150 | "language_info": { 1151 | "codemirror_mode": { 1152 | "name": "ipython", 1153 | "version": 3 1154 | }, 1155 | "file_extension": ".py", 1156 | "mimetype": "text/x-python", 1157 | "name": "python", 1158 | "nbconvert_exporter": "python", 1159 | "pygments_lexer": "ipython3", 1160 | "version": "3.5.6" 1161 | } 1162 | }, 1163 | "nbformat": 4, 1164 | "nbformat_minor": 2 1165 | } 1166 | -------------------------------------------------------------------------------- /03_tensorflow_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 第 3 章:TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "1.10.0\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "# 导入 tensorflow\n", 25 | "import tensorflow as tf\n", 26 | "\n", 27 | "# 查看 tensorflow 版本\n", 28 | "print(tf.__version__)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## 张量 Tensor" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "### 创建 Tensor" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# 0 阶 Tensor\n", 52 | "c0 = tf.constant(2, name='c0')\n", 53 | "# 1 阶 Tensor\n", 54 | "c1 = tf.constant([1, 2, 3], name='c1')\n", 55 | "# 2 阶 Tensor\n", 56 | "c2 = tf.constant([[1, 2], [3, 4]], name='c2')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "Tensor(\"c0:0\", shape=(), dtype=int32)\n", 69 | "Tensor(\"c1:0\", shape=(3,), dtype=int32)\n", 70 | "Tensor(\"c2:0\", shape=(2, 2), dtype=int32)\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "print(c0)\n", 76 | "print(c1)\n", 77 | "print(c2)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### Tensor数学运算" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Tensor(\"add:0\", shape=(2, 2), dtype=int32)\n", 97 | "Tensor(\"sub:0\", shape=(2, 2), dtype=int32)\n", 98 | "Tensor(\"multiply:0\", shape=(2, 2), dtype=int32)\n", 99 | "Tensor(\"matmul:0\", shape=(2, 2), 
dtype=int32)\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "a = tf.constant([[1, 2], [3, 4]], name='a')\n", 105 | "b = tf.constant([[5, 6], [7, 8]], name='b')\n", 106 | "# 加法\n", 107 | "out = tf.add(a, b, name='add')\n", 108 | "print(out)\n", 109 | "# 减法\n", 110 | "out = tf.subtract(a, b, name='sub')\n", 111 | "print(out)\n", 112 | "# 对应元素相乘\n", 113 | "out = tf.multiply(a, b, name='multiply')\n", 114 | "print(out)\n", 115 | "# 矩阵相乘\n", 116 | "out = tf.matmul(a, b, name='matmul')\n", 117 | "print(out)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## 数据流图(Data Flow Graph)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "![](./datasets/ch03/3-3.png)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "c = tf.constant(2, name='const')\n", 141 | "\n", 142 | "a = tf.Variable(3, name='a')\n", 143 | "b = tf.Variable(4, name='b')\n", 144 | "\n", 145 | "f = a*a + a*b +c" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "## 会话(Session)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "### 会话机制 1 " 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 6, 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "name": "stdout", 169 | "output_type": "stream", 170 | "text": [ 171 | "23\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "sess = tf.Session()\n", 177 | "sess.run(a.initializer)\n", 178 | "sess.run(b.initializer)\n", 179 | "result = sess.run(f)\n", 180 | "print(result)\n", 181 | "\n", 182 | "sess.close()" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "### 会话机制 2" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 7, 195 | "metadata": {}, 196 | 
"outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "23\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "with tf.Session() as sess:\n", 207 | " a.initializer.run()\n", 208 | " b.initializer.run()\n", 209 | " result = f.eval()\n", 210 | "print(result)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "### 会话机制 3" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "init = tf.global_variables_initializer() # 定义全局初始化节点\n", 227 | "\n", 228 | "with tf.Session() as sess:\n", 229 | " init.run() # 初始化所有变量\n", 230 | " result = f.eval()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "### 占位符 tf.placeholder()" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 9, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "name": "stdout", 247 | "output_type": "stream", 248 | "text": [ 249 | "Situation 1:\n", 250 | " [12.]\n", 251 | "Situation 2:\n", 252 | " [[2. 4.]\n", 253 | " [6. 
8.]]\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "# 在 Tensorflow 中需要定义 placeholder 的 type,一般为 float32\n", 259 | "a = tf.placeholder(tf.float32, name='a')\n", 260 | "b = tf.placeholder(tf.float32, name='b')\n", 261 | "f = tf.multiply(a, b)\n", 262 | "\n", 263 | "with tf.Session() as sess:\n", 264 | " print('Situation 1:\\n', sess.run(f, feed_dict={a: [3.], b: [4.]}))\n", 265 | " print('Situation 2:\\n',sess.run(f, feed_dict={a:[[1.,2.],[3.,4.]], b: [2.]}))" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "## TensorFlow 线性回归" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 10, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "import numpy as np\n", 282 | "import matplotlib.pyplot as plt" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "### 创建数据集" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 11, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "x_train = np.linspace(-1, 1, 50)\n", 299 | "y_train = 3*x_train + 10 + 0.5 * np.random.randn(x_train.shape[0])" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "显示数据分布" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 12, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "data": { 316 | "image/png": "\n", 317 | "text/plain": [ 318 | "
" 319 | ] 320 | }, 321 | "metadata": { 322 | "needs_background": "light" 323 | }, 324 | "output_type": "display_data" 325 | } 326 | ], 327 | "source": [ 328 | "import matplotlib.pyplot as plt\n", 329 | "plt.rcParams['font.sans-serif']=['SimHei']\n", 330 | "plt.rcParams['axes.unicode_minus']=False\n", 331 | "plt.rcParams['figure.figsize'] = (10.0, 6.0) # set default size of plots\n", 332 | "\n", 333 | "plt.scatter(x_train, y_train)\n", 334 | "plt.show()" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": {}, 340 | "source": [ 341 | "### 创建输入节点 x 和 y,用于输入数据" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 13, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "x = tf.placeholder(tf.float32, name='x')\n", 351 | "y = tf.placeholder(tf.float32, name='y')" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "### 创建变量节点 w1 和 w0,并初始化变量" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 14, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "w1 = tf.Variable(tf.random_normal([1]), name='w1')\n", 368 | "w0 = tf.Variable(tf.zeros([1]), name='w0')" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "### 创建线性模型" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 15, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "y_hat = w0 + w1 * x" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "### 创建损失模型" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 16, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "loss = tf.reduce_mean(tf.square(y_hat - y))" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "### 创建一个梯度下降优化器" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | 
"execution_count": 17, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "optimizer = tf.train.GradientDescentOptimizer(0.01) # 学习率设为 0.01\n", 417 | "train = optimizer.minimize(loss)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "metadata": {}, 423 | "source": [ 424 | "### 创建会话 Session 用来计算模型" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 18, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "sess = tf.Session()" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "### 初始化变量" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 19, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "w1 = [-0.4311598] w0 = [0.]\n" 453 | ] 454 | } 455 | ], 456 | "source": [ 457 | "init = tf.global_variables_initializer()\n", 458 | "sess.run(init)\n", 459 | "print (\"w1 =\", sess.run(w1), \"w0 =\", sess.run(w0)) # 打印初始化的w1和w0" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "### 训练" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 20, 472 | "metadata": {}, 473 | "outputs": [ 474 | { 475 | "name": "stdout", 476 | "output_type": "stream", 477 | "text": [ 478 | "Iteration[20/500], loss: 47.436157\n", 479 | "Iteration[40/500], loss: 22.252466\n", 480 | "Iteration[60/500], loss: 10.788024\n", 481 | "Iteration[80/500], loss: 5.496624\n", 482 | "Iteration[100/500], loss: 3.000713\n", 483 | "Iteration[120/500], loss: 1.784191\n", 484 | "Iteration[140/500], loss: 1.163201\n", 485 | "Iteration[160/500], loss: 0.826786\n", 486 | "Iteration[180/500], loss: 0.631707\n", 487 | "Iteration[200/500], loss: 0.510594\n", 488 | "Iteration[220/500], loss: 0.430753\n", 489 | "Iteration[240/500], loss: 0.375594\n", 490 | "Iteration[260/500], loss: 0.336194\n", 491 | "Iteration[280/500], loss: 
0.307420\n", 492 | "Iteration[300/500], loss: 0.286107\n", 493 | "Iteration[320/500], loss: 0.270183\n", 494 | "Iteration[340/500], loss: 0.258223\n", 495 | "Iteration[360/500], loss: 0.249212\n", 496 | "Iteration[380/500], loss: 0.242410\n", 497 | "Iteration[400/500], loss: 0.237269\n", 498 | "Iteration[420/500], loss: 0.233382\n", 499 | "Iteration[440/500], loss: 0.230442\n", 500 | "Iteration[460/500], loss: 0.228216\n", 501 | "Iteration[480/500], loss: 0.226533\n", 502 | "Iteration[500/500], loss: 0.225258\n" 503 | ] 504 | } 505 | ], 506 | "source": [ 507 | "num_iter = 500\n", 508 | "for i in range(num_iter):\n", 509 | " sess.run(train, {x: x_train, y: y_train})\n", 510 | " if (i+1) % 20 == 0:\n", 511 | " print('Iteration[{}/{}], loss: {:.6f}'.format(i+1,num_iter,sess.run(loss,{x:x_train,y:y_train})))" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": {}, 517 | "source": [ 518 | "打印 w1" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 21, 524 | "metadata": {}, 525 | "outputs": [ 526 | { 527 | "data": { 528 | "text/plain": [ 529 | "array([2.938043], dtype=float32)" 530 | ] 531 | }, 532 | "execution_count": 21, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | } 536 | ], 537 | "source": [ 538 | "sess.run(w1)" 539 | ] 540 | }, 541 | { 542 | "cell_type": "markdown", 543 | "metadata": {}, 544 | "source": [ 545 | "打印 w0" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 22, 551 | "metadata": {}, 552 | "outputs": [ 553 | { 554 | "data": { 555 | "text/plain": [ 556 | "array([9.940149], dtype=float32)" 557 | ] 558 | }, 559 | "execution_count": 22, 560 | "metadata": {}, 561 | "output_type": "execute_result" 562 | } 563 | ], 564 | "source": [ 565 | "sess.run(w0)" 566 | ] 567 | }, 568 | { 569 | "cell_type": "markdown", 570 | "metadata": {}, 571 | "source": [ 572 | "### 模型测试" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 23, 578 | "metadata": {}, 579 
| "outputs": [ 580 | { 581 | "data": { 582 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlAAAAFnCAYAAABpdXfNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xt8zvX/x/HH2wwLNamUSeaYSlLLt76iJaW+nXyXKJ1LR0oqoVDOqyl9lUgncqhfoXVOSiJKjYVKsxwzOYQ5Ncz2/v3x2ea02XVt13V9rsPzfru5fbfL57o+77fru13P3ofX21hrERERERHPVXC7ASIiIiKhRgFKRERExEsKUCIiIiJeUoASERER8ZIClIiIiIiXFKBEREREvKQAJSIiIuIlBSgRERERLylAiYiIiHhJAUpERETESxX9fYMTTjjB1qtXz9+3ERERESm3hQsX/m2tPbG06/weoOrVq0daWpq/byMiIiJSbsaYNZ5cpyk8ERERES8pQImIiIh4SQFKRERExEt+XwNVnNzcXNatW8eePXvcuH1IqVKlCnXq1CE6OtrtpoiIiEgBVwLUunXrqF69OvXq1cMY40YTQoK1li1btrBu3Tri4+Pdbo6IiIgUcGUKb8+ePdSsWVPhqRTGGGrWrKmROhERkSDj2hqoUAtP27dvx1pb6nV79+4lPz/fZ/cNtX8nERGRSBCxi8jXrl1LUlJS0fejR49m/fr1LFq0iEmTJh1x/Y033siIESOOeDw3N5evvvqK2bNnM3v2bO68806eeOKJou+/+uqromvXr19/yD1vu+02FixYwLvvvuvj3omIiIg/RWSA2rdvHxUrViQmJobc3Fwee+wx6tSpw7Bhw9i2bRt//PHHIdenpKTQpk0bMjIy+OKLLw75u/z8fNatW8eGDRuK/tSqVavo63Xr1hVdO3r0aHr27MmqVasAiI6OpkaNGqxevZqsrCz/d1xERER8wpVF5G5LTU3llVdeITMzkxEjRvDLL7+QkpJCcnIyV199NRUqHMiVzz//PCtWrGDs2LHs27ePm2++mSVLltCzZ0+io6OpXLkyd9xxBzfeeCN///03v/76K7m5uVSuXJlq1aqRmpoKwLJly9i5c2fR9f379yczM5O+ffsSGxvL4MGDGTt2rFv/JCIiIuIF9wPUI4/Azz/79jXPOQdefLHEv+7UqRNt27bl8ccfp3fv3sybN48KFSrw3Xff8c033xAVFcXKlSvp1q0bVatWpVmzZgwZMgSAM888kxUrVnDOOecwduxYWrduDcDq1atJTk4+5D69e/cu+nry5MnMnz+f999/n9mzZ/P6669z8skn8/DDD9OqVSsWL17M2rVrqVu3rm//LURERIJManoWKTMyWJ+dQ+3YGHq1b0KHFnFuN8sr7gcolyxZsoRvv/2WV155peix/Px8/vnnHypXrkzVqlV58sknOeuss1i1alXRqFR+fj6NGjVi9+7dVK9evei5eXl5bNiw4ZB7HLyYfMiQISQlJTFz5kwaNmxIvXr1+PHHH3nqqaeoXbs2p556KnfddZefey0iIuKu1PQs+k5fSk5uHgBZ2Tn0nb4UIKRClPsB6igjRf4yceJEvv/+e9q0aUP37t354osv+Pzzz1mwYAHHHXcc8fHx1KpVixNPPJGzzjqL2rVrH/L8nJwc5s2bd8hjO3fuPGIKbuvWrUVfZ2Rk0KlTJ2655RbeffddmjRpQr9+/fjss89o2bIlc+bM4fTTT/dfp0VERIJAyoyMovBUKCc3j5QZGQpQwe7WW2/lyiuv5PHHHwfgzz//5NVXX2XKlCkkJibStWtXACpUqMDJJ598yE46gAsuuOCI16xatSodO3Y85LE333yz6OsaNWowYsQI6tevT/Xq1YmPjycvL49u3brx+++/M23aNF93U
0REJOisz87x+PFgnuqLyAAFFNV0ysvLo3379vTu3Ztu3bpxww03MGnSJC6++GKaNGnCkiVLaNeu3SHP3bZtW9HX+/fvxxhDpUqVOOeccw65rkqVKuTn55Ofn8+iRYuYPHkyeXl5HH/88SQnJ9OvXz/q1atH48aNyczMJDMz84gQJiIiEk5qx8aQVUxYqh0bc8j3wT7VF5EBKj8/n86dO9OpUyeioqK45557uPjii3nggQfo1q0biYmJzJ8/n0aNGtGsWbMjRqDOO++8oq+nTJnC+PHjqVq1Ks8888wh1x1zzDFcdtlldOrUiaSkJJo2bcppp53Gn3/+Sdu2bRk2bBhXX301AwcO5NJLL9UolIiIhL1e7ZscEowAYqKj6NW+ySHXBftUn/GkunZ5JCQk2LS0tEMeW7ZsGU2bNvXrfUtjrS2q8m2tZcOGDZxyyimu3B+c422qVKlS7LXB8O8lIiLiK55MzcX3+ZTiEooBViVf5be2GWMWWmsTSrsuIkeg4NAjUowxAQ1Ph98fKDE8iYiIhJsOLeJKHUXydKrPLRFZiVxERESCW6/2TYiJjjrkseKm+twSsSNQIiIi4lu+3DVX+DztwhMREZGw5Y9dc55M9blFU3glyMvLw1rL3r172bVrV9Hj1lr279/vYstERESCz9F2zYWjiB6BeuONNxgxYgRxcU66PXgnXF5eHtOmTWPFihWMHz+e0aNHA/D5558zbdo03njjjUNea8CAAVxyySV89dVXVK9enW7dutGxY0c+++wzoqIOzOHu37+fTp06MX36dLp06cL69euL/q5atWp88skn/u62iIiIz3lTIDMclClAGWNOAc4EFlhrd/q2SYETFRVFjx49uP/++wH417/+xdy5c4t2yO3YsYOff/6ZRo0asXHjRmrVqsWECRN49tln+eijj2jdujU1atRg165dHHvssXz//fds2rSJDRs2sGbNGqpWrUpUVFTRmXjGGCpWrMi+ffsAyM3NZfbs2UXtURFNEREJVcG+a87XPJrCM8bUMsbMLfi6MfB/QCvgW2NMJT+2D3DmVVslzyK+z6e0Sp5FanqWT1//8ccf56KLLiI7O5tLLrmECy+8kBdeeIHZs2fzv//9j/nz59OmTRvmz5/PRRddRL169ahSpQoPPvggANu3b2fLli0MHz6c33//nRNOOIGXX36ZP/74gzZt2lCnTh3S0tKYPn067du356effqJjx46sXbuWxMREWrduTZs2bVi0aJFP+yUiIhIowb5rztdKHYEyxtQAJgBVCx46G7jTWrvCGNMMiAf8NsEZiFLu2dnZ3HbbbZx00kkApKens2PHDq699lqmTJnChAkT+O9//0vPnj3JzMzklVde4fjjj2fXrl1MnTqViy66iNWrV/PYY4+RkZHBhg0bWLJkCUOHDqVhw4a8+uqrtGzZkpYtW3L99ddz9dVX8+CDDzJu3Djeffdd3nnnHQBuuukmn/RHREQk0IJ915yveTKFlwd0Bj4EsNZONcZUNMZcBdQA/vBj+/xayn3//v1ER0djreX444/nhBNOAODYY49l505nZnLXrl1kZ2dz6qmnMmnSJGrUqFE0xbd3714qVKjAxo0befrpp5k4cSKPPvoo0dHRPPnkkyxcuJBKlSpRv379ont+++23zJ07lw4dOvD777/Trl07Nm3aBMCrr75Kly5duPfee8vVLxERETcE8645Xys1QFlrd8ARlbOrAZ2ANXBkpXVjzL3AvQB169YtVwP9uShtw4YNnHnmmTRq1IixY8eSlZVFhQoVOOWUU7jpppvIyspi8+bNvPfee7Rr146PP/6YIUOGUKtWLQBWrlzJ+vXrWbNmDf369SMzM5PFixfz888/s2rVKpKSkgC48sorAXj77beZOnUqrVu3pmvXrnz77bdMnDiRqVOnAloDJSIiEirKtIjcWpsN3G6MmQicDyw47O/HAePAOQuvPA3056K0xYsXc8MNN3D22WfTokULVq1aR
ZUqVTDG0LlzZ6pUqcKYMWNo06YNI0eOpGrVqvTu3ZuuXbsCzqJzgFatWtGpUyd++OEHrrzySs4880yio6M599xzSU1NZcCAAQDccMMN3HrrrVxzzTWsXbuWE044gQYNGnDqqacC0KNHD7KyfLu+S0RERHzP6zpQxpgxxpg2Bd/GAtm+bdKh/LUobdeuXWRkZFC1alXuuOMOateuXbRbrnr16iQlJZGTk8Mbb7zB5MmT+eeff9ixYwcpKSkkJiaSmJjIxo0bi17v/vvv59JLL2XIkCH88ssvrFq1il9//ZXKlSsXLQ6PiYnBGIMxhm+++Ya2bdty/vnnM3v2bGbPnk3r1q3L1ScREREJjLKMQD0HTDTGWOBLa61fK2T5a1HanDlz6NKlC0uXLmX06NHs3LmTESNGMGnSJP79738TGxvLrFmzqFSpEtdddx3gTMH16tWraATq3HPPBWDbtm088MAD1K9fnx9++IHffvuNO++8kxEjRlCrVi06duzIlClTaNCgAb179+byyy9nzpw5jB07lh49epCYmAhARkZ4FhsTEREJN8bacs2wlSohIcGmpaUd8tiyZcto2rSpX+/rD4UVyCtWPDJ37t+/v+hxay35+flFBTSttYevISuSm5tLdHT0Ue8bqv9eIiIiPpGTAzGBqSdljFlorU0o7Tod5eKFihUrFhueCv+ukDHmkOrjJYUnoNTwJCIiErE2boRHHoHTToMtW9xuzSFcO8rlaKMycoC/RwhFRESCzpYtkJICL70Ee/fCHXdAkJ1D60qAqlKlClu2bKFmzZoKUUdhrWXLli1F5/OJiIiEte3b4YUXYORI2LULunSBp5+GRo3cbtkRXAlQderUYd26dWzevNmN24eUKlWqUKdOHbebISIi4j+7dsGoUTBiBGzbBh07wjPPwJlnut2yErkSoKKjo4mPj3fj1iIiIhIscnJgzBhITobNm+Hqq2HQIGjRwu2Wlcq1NVAiIiLiudT0rPA5Z27vXnj9dRg6FP76Cy67DAYPhoIC1aFAAUpERCTIpaZn0Xf60qKzYbOyc+g7fSlAaIWo3Fx4+21nlGntWmjdGt55By6+2O2WeU1lDERERIJcyoyMovBUKCc3j5QZIVKAOS8PJk2Cpk2ha1c4+WT48kv49tuQDE+gACUiIhL01hdzJuzRHg8a+fnw/vvQrBnceitUqwYffQQ//OBM24XwTnxN4YmIiAS52rExZBUTlmrH+r86d5nWXlkLH38MAwbA4sXOyNP770NSElQIj7Gb8OiFiIhIGOvVvgkx0VGHPBYTHUWv9k38et/CtVdZ2TlYDqy9Sk3PKv4J1jpTcxdcANdd55QnmDQJli51ShOESXgCBSgREZGg16FFHMOTmhEXG4MB4mJjGJ7UzO8LyL1ae1W4nql9e9iwwdllt2wZ3HwzREUdeX2I0xSeiIhICOjQIi7gO+48Wnv1ww/Qvz989RWccgqMHg133w2VKweole7QCJSIiIgUq6Q1VrVjY2DRIqfw5YUXOuucXngBVqyABx8M+/AEClAiIiJSguLWXp297U/enzkCzjsP5s+HYcNg5Uro2RNi/L+oPVhoCk9ERESKVThlmDIjg8or/6DPj+9x2ZJZmGrVnEN+e/aE445zuZXuUIASERFxSSgcz9KhRi4dlk9yKohXrgy9e8Pjj0PNmm43zVUKUCIiIi4I+uNZ1q1zzqp74w2n/MBDD0GfPlCrltstCwpaAyUiIuKCoD2eZeNGZ2quYUMnPHXt6iwOHzlS4ekgGoESERFxQdAdz7JlC6SkwEsvwd69cPvtTnmCevXcaU+QU4ASERFxgZvHsxxi+3anBMHIkU7l8JtuchaIN24c2HaEGE3hiYiIuMCt41mK7NoFw4dDfDwMGgSXX+4cuTJ5ssKTBzQCJSIi4oKDSwQEdBdeTg6MGQPJybB5M1x1lROgzj3Xv/cNMwpQIiIiLgno8Sx79zrn0w0dCn/9Be3aweDBzsG/4jUFK
BERkTByeG2pJ9rW57qlXzujTGvXwkUXwTvvOAf/SpkpQImIiISJg2tLVcjP4/x5n3HOc+/AtvVw/vnw2mtw2WVgjNtNDXkKUCIiImEiZUYGe/blcmXGfB79bjKNtvzJbyfF88Rtg3lu/FMKTj6kACUiIhIOrKVp2mxemzuJMzatIrPmqTx4XR8+b/JvMBV4TuHJpxSgREREQpm1MHMm9O/P6z/+yOrYU3jk6sf4qGkb8is4ZRLiAl1bKgIoQImIiISqOXOgXz+YOxfq1iV9QAq35Z7OzvwDo00BrS0VQVRIU0REJNT88IOzGPzii+GPP+Dll2H5cloMfJzBN7QgLjYGgzPyNDypWXAcThxmNAIlIiISKtLTYcAA+OQTOOEEeP55eOABiDkwRRfQ2lIRTAFKREQk2P36q3M+3bRpEBvrFMN8+GGoVs3tlkUsBSgREZFglZkJAwfClClOWOrfHx591AlR4ioFKBERkWCzerVzzMqECVCpEjzxBPTqBTVr+uwWh1csD8g5fGFEAUpERCRYZGXBsGFOxfAKFaB7d+jTB04+2ae3ObhiOUBWdg59py8FUIjykEe78IwxtYwxcwu+rmuMmW2MmWWMGWeMKnOJiIiUy6ZNztRcgwYwbhzcfbezu+7FF30ensCpWF4Yngrl5OaRMiPD5/cKV6WOQBljagATgKoFD90HPGCtXWaM+RxoBizxXxNFRETC1NatkJICo0bBnj1w223OLrv4eL/edn12jlePy5E8mcLLAzoDHwJYa5866O9qAn/7oV0iIiLha/t2Z3TphRdg50648UZnl12TwBS8rB0bQ1YxYam2KpZ7rNQpPGvtDmvt9sMfN8Z0Bn611q73S8tERERCVGp6Fq2SZxHf51NaJc8iNT3L+YvduyE52RlheuYZaNcOlixxdtkFKDwB9GrfhJjoqEMeU8Vy7xhrrWcXGjPbWptY8HV94P+AdiWEq3uBewHq1q173po1a3zWYBERkWB2+AJtgFj2MzE3nWaTxjjrna66CgYNgnPPdbWd2oV3JGPMQmttQqnXeRugCtZEfQF0tdYuLe15CQkJNi0tzaN7iIiIhLpWybOKpsei83LpvGQm3ee/y8m7tjojToMGwYUXutxKKYmnAaosZQz6AHWBlwo24D1trf22DK8jIiISdtZn5xCVn0fSL1/TY9671NmxiR/rnMEj1zzOu1P6ut088RGPA1Th9J21tjfQ218NEhERCVl5edyxah63zpxA/W3r+fmURvS9ojtz67UgrsYxbrdOfEiFNEVEpIjWxZRRfj588AEMGMDTv/3Gslr1ufv6/nzdoCUYowXaYcijQpoiIhL+Chc/Z2XnYDlQnbpoB5kcyVr45BM47zzo2NEJUu+9R8ans/k9IRFjDHGxMQxPaqYgGmY0AiUiIsDRq1Prw/8w1sJXXzmH+y5Y4FQQnzgRbroJoqLoAHQ471S3Wyl+pBEoEREBVJ3aY3PnQmIiXH45/PUXvP46LFsGt9wCUVGlPl3CgwKUiIgAJVehVnXqAgsWOKGpTRvIzISXX4bly51z66Kj3W6dBJgClIiIAKpOXaL0dLjmGrjgAufrESOcg367dYPKld1unbhEa6BERASgaJ2TduEV+O0353y6qVMhNhaGDoWHH4Zq1dxumQQBBSgRESnSoUVc5AamQpmZMHCgcz5dtWowYAD07OmEKJECClAiIiIAa9bA4MEwfjxUqgRPPAG9ekHNmm63TIKQApSIiES29eud6bnXXgNjoHt36NMHTj7Z7ZZJEFOAEhGRyLRpEyQnw5gxsH+/s5uuXz+oU6fEp6hSuxRSgBIRkciydauzk27UKMjJYc1VHXns9GtZWCGW2pOW06u9KTYUFVZqLyw2WlipHVCIikAqYyAiIpFhxw5ncXh8vDPydO21fDX1G65ofidpFWJLPb7maJXaJfIoQImISHjbvRuefdYJTs88A+3awZIlMGUKT2fs9zgUqVK7HEwBSkREwtJHP
6zgf1d34++T6kCfPmw44xxIS4Np0+CsswDvQpEqtcvBFKBERCSkpKZn0Sp5FvF9PqVV8qwjp9v27ePnJ4fzr8v/RY9PXyHjxLok3ZLCJYm9SK1w6M46b0KRKrXLwRSgREQkZBQu5M7KzjlyzdL+/fDmm9C4MecMf5K1x9biphuHcfONw1gU17TYqTlvQlGHFnEMT2pGXGwMBoiLjWF4UjMtII9Q2oUnIiIho7iF3Hv37mNx8mg6pE91qognJHB7y7v4Nv5cp67TQQ6fmvP2+BpVapdCClAiIhIyDg5AxubTfvn3PDp3Mo23rIWzz4YPP4RrruGPZ78BD9cxKRRJWWgKT0REQkbt2BiwlrZ//Mgn4x9hbOpwKth8+t/UH9LT4dprwRitVxK/0wiUiIh4zZWK3Nby3HEbqfbyQJpnZbAm9mR6XvUoXzZvy9CO50CFA2MC3k7NiXjLWGv9eoOEhASblpbm13uIiEjgHF6RG5zRHb8uqJ4zB/r3hzlz+Ofk2oy68EbeaNCGk2pWVzASnzLGLLTWJpR2nUagRETEK0eryO3zILNggROcZs50Dvd96SWOuece+lSuTB/f3knEK1oDJSIiXglIRe6ff3bWM11wgbO2acQIWLECuneHypV9dx+RMlKAEhERr/i1Ivdvv8ENN0CLFjB3LgwZAitXwmOPwTHHlP/1RXxEAUpERLzilx1umZlwyy3OEStffOFM261aBU89BdWrl7PFIr6nNVAiIuIVn+5wW7MGBg+G8eOhUiVnpKl3bzjhBN82WsTHFKBERMRr5S4+uX49DB0Kr73mVAvv1g369nUWiouEAAUoEREJnE2bIDkZxoxxzq67+25nmu7UU91umYhXFKBERMT/tm51dtKNGgU5OXDrrTBgANSv73bLRMpEAUpERPxnxw4YORJeeMH5+sYb4emn4fTT3W6ZSLkoQImIiO/t3g0vvwzPPeeMPnXoAAMHOgf+ioQBBSgREfGdPXtg7FgYPtxZ73TllTBoECSUejKGSEhRHSgRESm/ffucheENG0LPnk49p3nz4LPPFJ4kLClAiYhI2e3fD2+9BU2awIMPwmmnwddfO3/+/W+3WyfiNwpQIiLivbw8mDIFzjgD7roLataEzz+H776Dtm3dbp2I33kUoIwxtYwxcw/6vqkx5kP/NUtERIKStTB9OjRvDjff7Bzs+8EH8NNPcMUVTlFMkQhQaoAyxtQAJgBVC75vAKQAx/m3aSIiEjSshU8/hfPOg+uvd6bu3n0XFi92dtgpOEmE8WQEKg/oDOwo+H4ncL3fWiQiIsHDWvjqK2c909VXQ3Y2TJgAv/wCnTtDBa0EkchU6v/zrbU7rLXbD/p+k7V2r3+bJSIirps7Fy65BC67DNatg3HjICMDbrsNKqoKjkQ2v/yngzHmXmNMmjEmbfPmzf64hYiI+EvheqY2bZzANGoUZGbCPfdAdLTbrRMJCn4JUNbacdbaBGttwoknnuiPW4iIiK8tXgzXXQctW8LChZCSAitWwEMPQZUqbrdOJKhoDFZEJNItW+acT/f++xAbC4MHQ48eUL16wJqQmp5FyowM1mfnUDs2hl7tm9ChRVzA7i/iLY8DlLU28Wjfi4hIiFmxwjmfbvJkOOYY6N8fHn3UCVEBlJqeRd/pS8nJzQMgKzuHvtOXAihESdDS9gkRkUizdq2znqlJE5g6FR57DFatcs6sC3B4AkiZkVEUngrl5OaRMiMj4G0R8ZSm8EREIsX69TBsGLz2mvN9t27Qty+cfLK7zcrO8epxkWCgACUiEu42b4Znn4XRo50CmHfdBf36wamnut0yAGrHxpBVTFiqHRvjQmtEPKMpPBGRcLVtGzz1FMTHw8iR0KmTU5bg1VeDJjwB9GrfhJjoqEMei4mOolf7Ji61SKR0GoESEQmAgO4y27EDXnwRXngBtm+HG290dtmdfrp/7ldOhf8O2oUnoUQBSkTEzwK2y2z3bmea7tlnYetW+O9/nV12z
Zr57h5+0qFFnAKThBQFKBERPzvaLjOfhIY9e5xpueHDYeNGuPJKZ0ddQgKgGksi/qAAJSLiZ37bZbZvH7z1llP4MisL2raF6dOdg38LqMaSiH9oEbmIiJ+VtJuszLvM9u+H8eOdOk733w+nnQazZsHXXx8SnkA1lkT8RQFKRMTPfLbLLD8f3nkHzjwT7rwTataEzz+H776DSy4p9imqsSTiHwpQIiJ+1qFFHMOTmhEXG4MB4mJjGJ7UzPMpNGvhgw+geXPo0gUqV4bUVPjpJ7jiCjCmxKf6fPRLRACtgRIRCYgy7TKz1hlh6t8fFi1ypuzefRduuAEqePbfv73aNzlkDRSoxpKIL2gESkQkGM2aBa1awVVXOQUxJ0yAX36Bzp09Dk/gg9EvESmWRqBERILJvHnOMSuzZ0OdOjBuHNxxB0RHl/kl3ayxpBIKEq4UoEREgkFamjNV98UXUKsWjBoF99wDVaqU+JRgDycqoSDhTAFKRAIu2D/4A2rJEhgwAD780NlVl5ICDz4Ixxxz1KeFQjjxewFRERdpDZSIBFThB39Wdg6WAx/8qelZbjctsH7/3TmjrnlzZ7pu8GBYtQoef7zU8AShUd9JJRQknGkESkQCKuJHJVascI5ZmTTJCUr9+sGjj0KNGkWXeDJCFwrhpHZsDFnFtEclFCQcaARKRAIqFD74/WLtWrj3Xjj9dHj/fXjsMWfEafDgI8KTJyN0oVDfyWcFREWCkAKUiARUKHzw+9Rff8FDD0GjRk4pgvvvd0ahnnsOTjjhiMs9nZoLhXCiEgoSzjSFJyIBFTGFHTdvdkLSyy87Z9fddRc89RTUrXvUp3k6QlcYQoJ9Mb6bJRRE/EkBSkQCKlQ++Mts2zZ4/nl48UXIyYFbbnF22TVo4NHTvVk3pHAi4h4FKBEJuLD84N+xA/73Pyc8bd/uVAx/5hlnzZMXImaETiTEKUCJiJTH7t0werQzXbdlC1x3nbPL7uyzy/RyYT9CJxImFKBERMpizx549VUYPhw2boQrrnCC0/nnl/ulw3KETiTMKECJiHhj3z546y2n/EBWFiQmwrRpzsG/IhIxVMZARMQT+/fD+PHOmqb773d20339NXzzjcKTSARSgBIROZr8fHjnHTjzTLjzTqfo5Wefwbx50Lat260TEZcoQImIFMda+OAPA62pAAAd30lEQVQD56y6Ll2gUiXn+7Q0uPJKMMbtFoqIixSgREQOZq0zwnT++ZCU5Kx5eucdWLwYOnRQcBIRQAFKROSAWbOc9UxXXQVbtzqLxX/9FW68ESro16WIHKDfCCIi330Hl1wCl14Kf/4JY8fC77/DHXdARW1WFpEjKUCJSOT66SenflPr1rBsmVNJPDMT7rvPWfMkIlICBSgRiTxLljjrmVq2dBaFP/ccrFgBDz8MVaq43ToRCQEamxaRyLFsmXM+3XvvwbHHOpXDe/RwvhYR8YIClIiEvxUrYOBAmDwZYmLgySfh8cedmk4iImWgACUiR5WanhW6B9uuXQtDhsCbb0J0NPTsCb17w4knut0yEQlxHgUoY0wtYKq1trUxJhqYDhwPvGGtfdOfDRQR96SmZ9F3+lJycvMAyMrOoe/0pQDBHaL++guGDYNx45y6Tg884Iw6nXKKxy8R0sFRRPyu1EXkxpgawASgasFDDwELrbWtgI7GmOp+bJ+IuChlRkZReCqUk5tHyowMl1pUis2boVcvqF8fxoyB22+HP/6Al17yOjz1nb6UrOwcLAeCY2p6lv/aLiIhxZNdeHlAZ2BHwfeJwHsFX88BEnzfLBEJBuuzc7x63DXbtkG/fhAfDy+8ADfc4NRxGjfOOfTXSyEXHEUk4EqdwrPW7gAwB44vqAoU/mfYVqDW4c8xxtwL3AtQtwy/vEQkONSOjSGrmLBUOzbGhdYUY8cOp3bT88/D9u1OcBo4EJo2LdfLhkxwFBHXlKUO1C6g8LdnteJew1o7zlqbYK1NOFGLNUXKLTU9i1bJs
4jv8ymtkmcFbCqpV/smxERHHfJYTHQUvdo3Ccj9S/TPP5CS4kzVDRgAF18MP//slCcoZ3iCkgNi0ARHEXFdWQLUQuCigq+bA6t91hoROYKb63E6tIhjeFIz4mJjMEBcbAzDk5q5t5h6zx4YNcoJTk88AQkJsGABfPghNG/us9t4ExzdCrci4q6ylDGYAHxmjGkNnAEs8G2TRORgR1uPE4gg06FFnPu7z3JznYN9Bw+GdeucEaepU+Gii0p/bhkU9re0XXghu0tRRMrN4wBlrU0s+N81xpjLcEahBlhr8476RBEpl4hej7N/v1P8cuBAWLUKLrwQxo+Htm3hwLpMv/AkOLodbkXEPWUqpGmtXc+BnXgi4kdBv5C7gE/rJuXnO+uZnnkGMjLg3HPh5Zfhyiv9Hpy8EdHhViTC6TBhkSAXtAu5D+KzdVrWQmoqnHMO3HSTUz18+nTnwN///CeowhNosblIJFOAEglyQbeQuxjlrptkLXz+OZx/Pvz3v7B3L7zzDixe7HwfZMGpUCiEWxHxD52FJxICgmIh91GUayrrm2+cIpjz50O9es4ap5tvhorB/+vJ08XmIhJ+gv83lIgEvTKt05o/3wlO33wDcXEwdizceSdUquTHlvpesIdbEfEPTeGJSLl5NZWVluYsBm/VCn77zakk/scfcN99IReeRCRyaQRKRMrNo6mspUudquGpqXD88ZCcDN27Q9WqJbyqiEjwUoASEZ8ocSrr99+dcgTvvQfVq8OgQdCjBxx7bMDbKCLiKwpQIuIfK1c6YWniRIiJgSefhMcegxo13G6ZiEi5KUCJiG/9+ScMGQJvvunspOvZE3r3Bh0sLiJhRAFKRHzjr79g+HB49VWnrtP990PfvlC7ttstExHxOQUoESmfv/+G555zjlrZt88pRdCvH5x2mtstExHxGwUoESmb7Gx4/nl48UXYvRtuucXZZdewodstExHxOwUoEfHOzp1O7abnn3dCVKdOzi67pk3dbpmISMAoQIlEqNT0LO+OIPnnH3jlFad+05YtcO21zi675s2Dp40uvaaIRB4FKJEIlJqeRd/pS4sOAM7KzqHv9KUAR4aJvXth3DgYNgw2bID27Z3g1LJl8LTRxdcUkciko1xEXJSankWr5FnE9/mUVsmzSE3PCsh9U2ZkFIWIQjm5eaTMyDjwQG4uvPYaNGoEDz8MTZrAnDnwxRd+D08etzEIXlNEIpNGoERc4uZoyPpiDv4tejwvDyZPhoEDnWKYF1wAb70FbduCMX5tl8dtDKLXFJHIpBEoEZe4ORpSOzbmiMeMzefWtT/AWWfB7bfDccfBJ5/A/Plw6aUBDU8ltfFoj7v1miISmRSgRFzi5mhIr/ZNiImOcr6xlssyf+Dz8T0Y9M4QiIqCadNg4UK46qqAB6di21ggJjqKXu2bBNVrikhk0hSeiEtqx8aQVUxYCsRoSIcWcWAtc16axO1fvEnzDZnsqhsPU6Y4ZQmiokp/kUC0EXy6Y84frykikclYa/16g4SEBJuWlubXe4iEosPXQIEzGjI8qZn/P9Bnz3aqhc+bB/XqOQUwb73VObsuBKk0gYj4ijFmobU2obTrQvO3pUgYcGU0ZP586N8fZs2CuDgYMwbuugsqVfLfPf1MpQlExA0KUCIu6tAiLjAf8gsXOsHp88/hpJOc41fuuw+qVPH/vf3saIvxFaBExF+0iFwknC1dCklJkJAACxY4VcRXroQePcIiPIFKE4iIOxSgRMJRRgbcdJNzzMrXXzs1nVatgt69oWpVt1vnUypNICJuUIASCScrV8Idd8AZZ8DHH0Pfvk5wGjAAjj3W7db5hUoTiIgbtAZKxENBvdPrzz9hyBB4801nJ12PHtCnj7PeKcypNIGIuEEBSsQDQbvTa8MGGD4cxo4Fa52F4U8+CbVru9cmFwRsMb6ISAFN4Yl4IOgOof37b3jiCahfH0aPdmo4ZWbCyy9HXHgSEXGDRqBEPBA0O72ys+H55
50yBLt3w803w9NPQ8OGgW2HiEiEU4AS8YCbx64AsHMnjBoFI0Y4IeqGG+CZZ5zF4iIiEnCawhPxgGs7vf75xwlN9es7R69cdBGkp8N77yk8iYi4SCNQIh4I+E6vvXth3DgYNsxZKH755TBoEPzrX/65n4iIeEUBSsRDAdnplZsL48fD4MFOaYI2bZzRptat/XtfERHxiqbwRIJBXh68/Tacfjrce6+zk27mTJg9W+FJRCQIeR2gjDHxxphPjTFzjTHP+6NRIhEjP98ZYTrrLLj9dqda+Mcfw/ffQ7t2YIzbLRQRkWKUZQTqWWCwtbY1UMcYk+jbJolEAGvhww+hRQvo3BkqVICpU2HhQrj6agUnEZEgV5YA1RhYVPD1JuA43zVHJMxZC198AS1bQocOzi67SZNgyRK4/nonSImISNAry2/rqcDTxphrgCuArw+/wBhzrzEmzRiTtnnz5vK2USQ8FK5nuvJK2LwZ3ngDli1zimFGRZX6dBERCR5eByhr7RDgc6ArMMFau6uYa8ZZaxOstQknnniiD5opEsK+/x4uvRQuuQRWrYJXXoHly+Guu5yDf0VEJOSU9bf3z0Bd4CYftkUkvCxaBP37w2efwUknwciRzmG/MQGqXi4iIn5T1gDVC3jBWvuPLxsjEhaWLnXOp/vgA6hRA5KToXt3qFrV7ZaJiIiPlGnFqrX2aWvtRF83RiSkLV8OXbpA8+bkzvyK1y+9nbNvH0srez6py7Pdbp2IiPiQFmCIlNeqVc4xK2+/DTExZNzRjduPb8OGiscAsCM7h77TlwL4v5K5iIgEhPZMi5TVunVw//3QuDG8+y488gisXMldjf9bFJ4K5eTmkTIjw6WGioiIr2kESsRbGzY465rGjnUqid93Hzz5pHP8CrA+O6fYp5X0uIiIhB4FKBFPbdkCzz0HL78Me/fCnXdCv35w2mmHXFY7NoasYsJS7VjtvhMRCReawpOQkpqeRavkWcT3+ZRWybNITc/y/02zs2HAAIiPh5QUSEqC33+H1147IjwB9GrfhJjoQwtjxkRH0at9E/+3VUREAkIjUBIyUtOz6Dt9KTm5eQBk+Xtx9q5dMGqUE5qys+GGG+CZZ+CMM476tMK2pMzIYH12DrVjY+jVvokWkIuIhBEFKAkZKTMyisJTocLF2T4NJzk5TrXw5GT4+2+45hpnl90553j8Eh1axCkwiYiEMU3hScjw++LsvXth9Gho0AAefxzOPRcWLICPPvIqPImISPhTgJKQUdIi7HIvzs7Nhddfd8oRdO8OjRrBnDkwYwa0bFm+1xYRkbCkACUhw+eLs/PyYOJEaNoU7rkHTjkFZs6E2bOhdevyN1hERMKW1kBJyPDZ4uz8fJg2zTmvbtkyaNECPvkE/vMfMMYPLRcRkXCjACUhpVyLs62Fjz92ShIsXuzspps2jdTTzidlZibr534W8jvmUtOztPtPRCQAFKAiQMR/qFoLX34J/fvDTz9Bw4YweTJ07kzqkg2BLY3gRwEv8yAiEsG0BirMFX6oZmXnYDnwoRqQApTB4NtvoU0buOIK2LQJ3nzTmbbr0gWioo5aGiHUhFNfRESCnQJUmIvYD9UffoDLLoPERFi50qnrtHy5c/xKxQMDr+F0bl049UVEJNgpQIW5iPtQXbQIrr4aLrwQliyBF16AP/6ABx6ASpWOuNxvpRFcEE59EREJdgpQYS5iPlR/+QWuvx7OOw/mz4fhw2HFCujZE2JK7ms4nVsXTn0REQl2ClBhLuw/VJcvh5tvhrPPhq++cs6qW7UK+vSBatVKfXqHFnEMT2pGXGwMBoiLjWF4UrOQXHQdTn0REQl2xlrr1xskJCTYtLQ0v95Dji4sd+GtXu2cT/f221C5MvTo4Ry/cvzxbrdMRERCmDFmobU2obTrVMYgAoTVwbZZWTB0qHP0SoUK8PDDzmjTSSe53TIREYkgClASGjZuhORkGDPGqSTetSs89RTEhUkwFBGRkKIAJV4L6
JTgli2QkgIvvQR798Idd0C/flCvnn/uJyIi4gEFKPFKwKpdb9/ulCAYORJ27XIKXz79NDRq5Lt7iIiIlJF24YlX/F6Yc9cuGDYM4uOdReLt28PSpTBpksKTiIgEDY1AiVf8VpgzJ8dZ35ScDJs3O8UwBw2CFi3K97oiIiJ+oBEo8YrPC3Pu3QujR0ODBvDYY9C8uXMMy8cfKzyJiEjQUoASr/isMGdurlOKoHFj6N4dGjaE2bNh5kz4179812ARERE/0BSeeKVwoXiZd+Hl5cGUKTBwoHPUSsuWTpBq1w6M8WPLRUREfEcBSrxWpsKc+fkwdapz1MqyZc5U3ccfw1VXKTiJiEjI0RSe+Je18NFHznqmzp2dx95/HxYtchaKKzyJiEgI0giU+Ie18OWX0L8//PSTs8Zp0iS48UaIijri8rA8r09ERMKWRqDE9779Ftq0gSuucI5gef11+O03uPnmEsNT3+lLycrOwXKgOGdqelbg2y4iIuIBBSjxne+/dxaDJyY6C8RHj4bly+HuuyE6usSn+b04p4iIiI8pQEn5Fa5n+ve/YckS5wiWFSvgwQehcuVSn+634pwiIiJ+ogAlZffLL3D99XDeeTB/vnMEy8qV0LMnxHheWNPnxTlFRET8TAFKvLd8uXO479lnO4Uvn34aVq2Cvn2hWjWvX85nxTlFREQCxOtdeMaYGsBk4CRgobX2Pp+3SoLT6tXO+XQTJkCVKvDEE9CrF9SsWa6XLXdxThERkQArSxmDW4HJ1trJxpgpxpgEa22arxsmQWTdOhg6FN54AypUgIcfhj59oFYtn92iTMU5RUREXFKWALUFOMsYEwucCvzp2yZJ0Ni4EYYPh7FjnUriXbvCk09CnTput0xERMRVZQlQ3wFXAQ8Dy4Cth19gjLkXuBegbt265WmfuGHLFkhJgZdegj174PbbnYKY8fFut0xERCQoGGutd08w5k3gEWvtDmPMo8Aua+24kq5PSEiwaWma4fOUqxW5t293ShCMHAm7dsFNNzkLxBs3Dsz9RUREXGaMWWitTSjturLswqsBNDPGRAH/ArxLYFIi1ypy79oFw4axr+5pMGgQn5/SjFsefo3Ux59TeBIRESlGWabwhgNvAacB3wPv+LRFEexoFbn9MgqVkwNjxkByMmzezPyGLUn5bxd+PbkhAAunLwXQ4m4REZHDeB2grLU/Amf6oS0RL2AVuffudc6nGzoU/voL2rXj3obX8uVx9Q+5zK/hTUREJISpkGYQ8XtF7txcpxRB48bQvTs0aADffAMzZzLzsPBUSMepiIiIHEkBKoj4rSJ3Xh5MmgRnnOGUIqhVC2bMgDlznIN/0XEqIiIi3lCACiIdWsQxPKkZcbExGCAuNobhSc3KPoWWnw9Tp0KzZnDrrXDMMfDhh7BgAVx+ORhTdKmOUxEREfFcWRaRix/5pCK3tfDJJ07tpsWL4fTT4b33nIN/KxSfmcPxOBVXS0KIiEhYU4AKJ9Y6h/v27w8//uiscZo40annFBVV6tPD6TiVwpIQhbsaC0tCgHYViohI+WkKL1zMmQMXXwzt28OGDc4uu2XL4JZbPApP4eZoJSFERETKSwEq1BWuZ7r4YvjjD3j5ZVi+HO6+G6Kj3W6dawJWEkJERCKSAlSoSk+Ha66BCy6An3+G55+HFSugWzeoXNnt1rlOuwpFRMSfFKBCza+/QseOcO65MG+eUwxz5Up49FGIUTgopF2FIiLiT1pEHioyM2HgQJgyBapVcw757dkTjjvO7ZYFpXDcVSgiIsFDASrYrV4NgwfDhAnO1NwTT0CvXlCzptstC3rhtKtQRESCiwJUsMrKcqbnXn/dqd300EPQp49TRVxERERcpQAVbDZtguRkeOUV5wiWe+6BJ5+EOnXcbpmIiIgUUIAqhisVrLduhZQUGDUK9uyB226DAQMgPt6/9xURERGvKUAdJuAVrLdvh5EjnT87dzpVw59+Gho39v29RERExCfCIkD5c
sToaBWsfRqgdu+Gl16C556DbdsgKcnZZXfWWb67h4iIiPhFyAcoX48Y+b2CdU4OjB3rrHPatAmuugoGDXLqOomIiEhICPlCmr4+88xvFaz37YMxY6BhQ6fo5dlnw/z58MknCk8iIiIhJuQDlK9HjHxewXr/fnjzTWdN04MPOovCv/kGZs6ECy8s22uKiIiIq0I+QPl6xKhDiziGJzUjLjYGA8TFxjA8qZn304F5eTB5MpxxhnOw70knwRdfwNy5kJhYpraJiIhIcAj5NVC92jc5ZA0UlP/Ms3JVsM7Phw8+cEoQ/PYbNG8OH35Iap1zSflyOeu/+UzHioiIiIS4kB+B8tmIUXlZ66xnOu8857Df/Hx47z1YtIjUU8+j7we/kJWdg+XAQvfU9KzAtlFERER8IuRHoMDlM8+sha+/hn79YMECaNAA3n4bunSBKGctVcBKI4iIiEhAhPwIlKvmzHHWM112GaxfD6+9BsuWwa23FoUnCEBpBBEREQkoBaiyWLAALr8cLr4Yli93CmJmZkLXrhAdfcTlfiuNICIiIq5QgPJGejpccw1ccIHz9YgRsGIFdO8OlSuX+DSfl0YQERERV4XFGii/++0353y6qVMhNhaGDoWHHoLq1T16euE6p4AfUCwiIiJ+oQB1NJmZzvl0U6ZAtWpOaYKePZ0Q5SV/LHT35RmAIiIi4jkFqOKsXg2DB8OECVCpEjzxBPTqBTVrut2yIr4+A1BEREQ8pzVQB8vKgm7dnGNXJk921jatXOkc/BtE4Ql8fwagiIiIeE4jUACbNjkh6ZVXnCNYunaFp56COnXcblmJVBpBRETEPZEdoLZuhZQUGDUK9uyB225z1jnFx7vdslLVjo0hq5iwpNIIIiIi/heZU3jbtzuLw+Pj4dln4brrnJ12b70VEuEJVBpBRETETZE1ArV7t1P0MiXFGX1KSnKC1Flnud0yr6k0goiIiHsiI0Dl5MDYsc46p02b4D//gUGDnIN/Q5irZwCKiIhEsPCewtu3D8aMgYYN4dFHoVkzmD8fPv005MOTiIiIuKdMI1DGmAeAzgXfxgILrLX3+axV5bV/P7z9tjPKtGYNtGoFkybBJZe43TIREREJA2UagbLWjrHWJlprE4G5wGs+bVVZ5eU59ZvOOAPuvhtOPBG++ALmzlV4EhEREZ8p1xSeMSYOqGWtTfNRe8omPx+mTYOzz4ZbboGYGEhNhR9/hPbtwRhXmyciIiLhpbyLyLsBYw5/0BhzL3AvQN26dct5Cw/s2gX33AMnnQT/93/QsSNUCO/lXSIiIuIeY60t2xONqQDMA/5tj/IiCQkJNi0tAANUv/4KTZpAxcBuLNSBviIiIuHDGLPQWptQ2nXlSRutcRaPly2B+dqZZwb8ljrQV0REJDKVZ56rPTDHVw0JRTrQV0REJDKVeQTKWvukLxsSinSgr4iISGTSSutyKOngXh3oKyIiEt4iKkClpmfRKnkW8X0+pVXyLFLTs8r1ejrQV0REJDJFxll4+GfBtw70FRERiUwRE6COtuC7PIFHB/qKiIhEnoiZwtOCbxEREfGViAlQWvAtIiIivhIxAUoLvkVERMRXImYNlBZ8i4iIiK9ETIACLfgWERER34ioACWBp8OWRUQkHClAid/osGUREQlXEbOIXAJPhy2LiEi4UoASv1HtLRERCVcKUOI3qr0lIiLhSgFK/Ea1t0REJFxpEbn4jWpviYhIuFKAEr9S7S0REQlHmsITERER8ZIClIiIiIiXFKBEREREvKQAJSIiIuIlBSgRERERLylAiYiIiHhJAUpERETESwpQIiIiIl5SgBIRERHxkgKUiIiIiJeMtda/NzBmM7DGrzc54ATg7wDdK9io75FJfY9M6ntkUt8D4zRr7YmlXeT3ABVIxpg0a22C2+1wg/quvkca9V19jzTqe3D1XVN4IiIiIl5SgBIRERHxUrgFqHFuN8BF6ntkUt8jk/oemdT3IBJWa6BEREREAiHcR
qBERERE/C7kApQxppYxZm4p10QbYz42xswzxtxV0mOhxhjzhjHme2NMv6Nc84AxZnbBn5+NMa8aYyoaY9Ye9HizQLbbFzzse7H9NMYMNMb8ZIwZHbgW+46HfT/OGPO5MeZLY8wHxphKofy+e9jnI67x5HnBrrQ+hNt7fTgP+h+WP+fgUd/D8vc7lP7ZHoyf6yEVoIwxNYAJQNVSLn0IWGitbQV0NMZUL+GxkGGMSQKirLUXAvWNMY2Ku85aO8Zam2itTQTmAq8BZwPvFD5urV0asIb7gKd9p5h+GmPOAy4CWgKbjDHtAtRsn/Ci7zcDL1hrLwc2AFcQou+7J30u7hov/q2Clod9CJv3+nAe9j/sfs7Bs76H4+938PizPeg+10MqQAF5QGdgRynXJQLvFXw9B0go4bFQksiB9n+J88uiRMaYOKCWtTYNuAC42hjzY8F/4VT0a0t9LxHP+l5cPy8Gpllnsd8MoLW/G+tjiXjQd2vtK9bamQXfnghsInTf90RK73Nx13jyvGCXSCl9CLP3+nCJlP4ehuPPOXjx/98w+/0Onn22JxJkn+tBHaAKhicLhyVnA49Ya7d78NSqQFbB11uBWiU8FrSK6ftDeNf+bsCYgq9/AtpZa1sC0cB//NBknylH34vrZ0S978aYC4Ea1tofCLH3/SCevGch/zNeAo/7ECbv9eE86X/I/5yXwJs+hOzv9+JYa3d48NkedD/zQZ1UrbX3lfGpu4AYYDtQreD74h4LWof33RjzP5z2g9P+EsOvMaYCcAnwVMFDS6y1ewu+TgOCemqjHH0vrp+F73tpzw0K5XzfjwdeAq4veCik3veDePKeFXdNSL3XJfCoD2H0Xh/Ok/6H/M95CTx970P693s5BN3neij+n8wTCzkw/NkcWF3CY6HEm/a3BhbYAzUqJhpjmhtjooAOwGK/tdI/PO17cf2MiPfdGFMJeB/oa60tPHsyVN93T/ocjj/j4EEfwuy9Ppwn72E4/pyD530It9/vngq+n3lrbcj9AWYf9HVboPthf38a8CvwP5zhzajiHnO7H172+VicH4wXgGXAccAZwJBirh0GJB30/VnAEmApMNTtvvir78X1E+c/EuYVvO8ZQLzb/fFT3x8AtgGzC/50DtX3vZg+Ny+mv8X9uxzxmNt98VPfw+a9LmP/w+7n3NO+F1wXVr/fD+vb7IL/DYnP9bAtpGmMqY2TTGfYgrnV4h4LJQU7FS4D5lhrN7jdnkAqT9+NMTHAVcAia+1Kf7TPnyLxffekz8VdEw7/VuHQh/Ioa/9D/ecc9N6XJtg+18M2QImIiIj4S7iugRIRERHxGwUoERERES8pQImIiIh4SQFKRERExEsKUCIiIiJeUoASERER8dL/AyjKJVG5GDVIAAAAAElFTkSuQmCC\n", 583 | "text/plain": [ 584 | "
" 585 | ] 586 | }, 587 | "metadata": { 588 | "needs_background": "light" 589 | }, 590 | "output_type": "display_data" 591 | } 592 | ], 593 | "source": [ 594 | "y_pred = sess.run(y_hat, {x: x_train}) # 模型的预测输出 (stored as y_pred so the y_hat tensor is not overwritten and the cell can be re-run)\n", 595 | "plt.scatter(x_train, y_train, label='原始数据')\n", 596 | "plt.plot(x_train, y_pred, c='r', label='拟合直线')\n", 597 | "# 显示图例\n", 598 | "plt.legend() \n", 599 | "plt.show()" 600 | ] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "### 关闭会话 & 图复位" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 24, 612 | "metadata": {}, 613 | "outputs": [], 614 | "source": [ 615 | "sess.close() # 关闭会话 Session\n", 616 | "tf.reset_default_graph() # 图复位" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "## TensorBoard" 624 | ] 625 | }, 626 | { 627 | "cell_type": "markdown", 628 | "metadata": {}, 629 | "source": [ 630 | "### 创建日志存放目录" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 25, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "from datetime import datetime\n", 640 | "\n", 641 | "now = datetime.utcnow().strftime('%Y%m%d%H%M%S')\n", 642 | "root_logdir = 'tf_logs'\n", 643 | "logdir = '{}/ch03/run-{}'.format(root_logdir, now)" 644 | ] 645 | }, 646 | { 647 | "cell_type": "markdown", 648 | "metadata": {}, 649 | "source": [ 650 | "### 构建计算图" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": 26, 656 | "metadata": {}, 657 | "outputs": [], 658 | "source": [ 659 | "x = tf.placeholder(tf.float32, name='x')\n", 660 | "y = tf.placeholder(tf.float32, name='y')\n", 661 | "w1 = tf.Variable(tf.random_normal([1]), name='w1')\n", 662 | "w0 = tf.Variable(tf.zeros([1]), name='w0')\n", 663 | "y_hat = w0 + w1 * x\n", 664 | "loss = tf.reduce_mean(tf.square(y_hat - y))\n", 665 | "optimizer = tf.train.GradientDescentOptimizer(0.01) # 学习率设为 0.01\n", 666 | "train = optimizer.minimize(loss)"
667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": {}, 672 | "source": [ 673 | "### 设置 TensorBoard" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 27, 679 | "metadata": {}, 680 | "outputs": [], 681 | "source": [ 682 | "# 给损失模型的输出添加scalar,用来观察loss的收敛曲线\n", 683 | "loss_summary = tf.summary.scalar('loss', loss)\n", 684 | "# 模型运行产生的所有数据保存到文件夹供 TensorBoard 使用\n", 685 | "file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())" 686 | ] 687 | }, 688 | { 689 | "cell_type": "markdown", 690 | "metadata": {}, 691 | "source": [ 692 | "### 会话 Session" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": 28, 698 | "metadata": {}, 699 | "outputs": [ 700 | { 701 | "name": "stdout", 702 | "output_type": "stream", 703 | "text": [ 704 | "Iteration[20/500], loss: 49.139008\n", 705 | "Iteration[40/500], loss: 23.541363\n", 706 | "Iteration[60/500], loss: 11.763593\n", 707 | "Iteration[80/500], loss: 6.235038\n", 708 | "Iteration[100/500], loss: 3.559620\n", 709 | "Iteration[120/500], loss: 2.207230\n", 710 | "Iteration[140/500], loss: 1.483401\n", 711 | "Iteration[160/500], loss: 1.069147\n", 712 | "Iteration[180/500], loss: 0.815151\n", 713 | "Iteration[200/500], loss: 0.649443\n", 714 | "Iteration[220/500], loss: 0.535848\n", 715 | "Iteration[240/500], loss: 0.455141\n", 716 | "Iteration[260/500], loss: 0.396404\n", 717 | "Iteration[280/500], loss: 0.352993\n", 718 | "Iteration[300/500], loss: 0.320601\n", 719 | "Iteration[320/500], loss: 0.296292\n", 720 | "Iteration[340/500], loss: 0.277985\n", 721 | "Iteration[360/500], loss: 0.264170\n", 722 | "Iteration[380/500], loss: 0.253731\n", 723 | "Iteration[400/500], loss: 0.245839\n", 724 | "Iteration[420/500], loss: 0.239868\n", 725 | "Iteration[440/500], loss: 0.235351\n", 726 | "Iteration[460/500], loss: 0.231933\n", 727 | "Iteration[480/500], loss: 0.229345\n", 728 | "Iteration[500/500], loss: 0.227387\n" 729 | ] 730 | } 731 | ], 732 | "source": 
[ 733 | "init = tf.global_variables_initializer()\n", 734 | "sess = tf.Session()\n", 735 | "sess.run(init)\n", 736 | "\n", 737 | "num_iter = 500\n", 738 | "for i in range(num_iter):\n", 739 | " # 训练时传入loss_summary\n", 740 | " summary, _ = sess.run([loss_summary, train], {x: x_train, y: y_train})\n", 741 | " # 收集每次训练产生的数据\n", 742 | " file_writer.add_summary(summary, i)\n", 743 | " if (i+1) % 20 == 0:\n", 744 | " print('Iteration[{}/{}], loss: {:.6f}'.format(i+1,num_iter,sess.run(loss,{x:x_train,y:y_train})))" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": null, 750 | "metadata": {}, 751 | "outputs": [], 752 | "source": [] 753 | } 754 | ], 755 | "metadata": { 756 | "kernelspec": { 757 | "display_name": "Python 3", 758 | "language": "python", 759 | "name": "python3" 760 | }, 761 | "language_info": { 762 | "codemirror_mode": { 763 | "name": "ipython", 764 | "version": 3 765 | }, 766 | "file_extension": ".py", 767 | "mimetype": "text/x-python", 768 | "name": "python", 769 | "nbconvert_exporter": "python", 770 | "pygments_lexer": "ipython3", 771 | "version": "3.5.6" 772 | } 773 | }, 774 | "nbformat": 4, 775 | "nbformat_minor": 2 776 | } 777 | -------------------------------------------------------------------------------- /08_cnn_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 第 8 章:卷积神经网络 — PyTorch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## 导入 MNIST 数据集" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import torch\n", 24 | "import torchvision\n", 25 | "import torchvision.transforms as transforms\n", 26 | "import torch.nn as nn\n", 27 | "import torch.nn.functional as F\n", 28 | "import torch.optim as optim\n", 29 | "import matplotlib.pyplot as plt\n", 30 
| "import numpy as np\n", 31 | "\n", 32 | "transform = transforms.Compose(\n", 33 | " [transforms.ToTensor()])\n", 34 | "\n", 35 | "# 训练集\n", 36 | "trainset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录\n", 37 | " train=True,\n", 38 | " download=True, # 本地不存在时从网络上download图片\n", 39 | " transform=transform)\n", 40 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n", 41 | " shuffle=True, num_workers=2)\n", 42 | "# 测试集\n", 43 | "testset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录\n", 44 | " train=False,\n", 45 | " download=True, # 本地不存在时从网络上download图片\n", 46 | " transform=transform)\n", 47 | "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n", 48 | " shuffle=False, num_workers=2)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Dataset MNIST\n", 61 | " Number of datapoints: 60000\n", 62 | " Split: train\n", 63 | " Root Location: ./datasets/ch08/pytorch\n", 64 | " Transforms (if any): Compose(\n", 65 | " ToTensor()\n", 66 | " )\n", 67 | " Target Transforms (if any): None\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "print(trainset)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Dataset MNIST\n", 85 | " Number of datapoints: 10000\n", 86 | " Split: test\n", 87 | " Root Location: ./datasets/ch08/pytorch\n", 88 | " Transforms (if any): Compose(\n", 89 | " ToTensor()\n", 90 | " )\n", 91 | " Target Transforms (if any): None\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "print(testset)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "接下来展示一些训练样本图像" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": {}, 110
| "outputs": [ 111 | { 112 | "data": { 113 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAB6CAYAAACr63iqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAElRJREFUeJzt3XmQVNXZx/HvIyiJWwBxYYugEgV3VIpXLU3hhiuJimI0olJFueDrVlHEJS5vVPSNr0tpFEFFYwQkLqh5X0FwX1AQNSogGBGQQeKOS6ng8/7R9545I91MT09Pz/Sd36eK6qdPd9977tz2ePrcc59j7o6IiGTHOs1dARERKS817CIiGaOGXUQkY9Swi4hkjBp2EZGMUcMuIpIxathFRDKmUQ27mQ00s/lmttDMRparUiIiUjor9QYlM2sDvAscACwFXgWOc/d3ylc9ERFpqLaN+Gw/YKG7/wvAzCYAg4CCDbuZ6TZXEZGG+9jdNy32zY0ZiukKLImeL03K6jCz4WY2y8xmNWJfIiKt2QcNeXNjeuyWp2yNHrm7jwHGgHrsIiKV0Jge+1Kge/S8G7CscdUREZHGakzD/irQy8x6mtl6wBBgSnmqJSIipSp5KMbdV5nZCOAJoA1wp7u/XbaaiYhISUqe7ljSzjTGLiJSitnuvnuxb9adpyIiGaOGXUQkY9Swi4hkjBp2EZGMUcMuIpIxathFRDJGDbuISMY0JleMSKNsueWWIb788ssBmDhxYigzq01H1K9fvxA/99xzACxatCiUvffee01VTWkhRowYEeKbb74ZgEL34UyePDnEn376KQBff/11KJs3b17ez02dOhWADz5oUM6tFkc9dhGRjNGdp9Lk1l9//RCPHz8+xAMGDAhxhw4dGrzdtHcFcNJJJ4V4+fLlDd6WtHwvv/xyiPfYY48m2ccZZ5wBwG233dYk228E3XkqItKaqWEXEckYXTyVslpnndq+whFHHAHAn/70p1DWu3fvEK9atSrE7777LlD3otdjjz0W4h49eoT4oosuAuDAAw8MZVdffXWITz755JLrL63bddddB8DSpUtDWfw9rBbqsYuIZIwadhGRjNFQjJRV166165k/+OCDa7wezx8+55xzQvzEE0+sdbvxjIhtttkGgCuuuKLkerZEgwcPDvHYsWNDvPHGG4f4hRdeAGD48OGh7J133qlA7VqHdAbXsGHDQtlTTz0V4ngufEumHruISMaoYRcRyZhWfYPSQQcdBMCoUaNC2U477RTiZ599NsQjR44EYO7cuRWqXXVq06ZNiI855pg1Xn/44YdD/O2335a0j3QoZtasWaHszTffDPE+++xT0nYrKf47pbOG/vCHP4SyOJ3CkiVLQty9e3cApkypXTf+/PPPD3FNTU2IV65cWcYaN7/DDz88xKNHjwbgV7/6VSj74YcfQvz++++HeKONNgKgS5cuJe23Z8+eIV68eHFJ2yiD8t6gZGZ3mtkKM3srKutoZtPMbEHy2PDbBkVEpEnU22M3s32Ar4B73H2HpOxa4FN3v8bMRgId3P2CenfWAnrsl112WYgPOeQQAB5//PFQNmPGjBD/7ne/C3Ha+7z00ktD2S233NJU1ZS1yNdjnzBhQohPPfXUitepoeK59uPGjQNg9erVoeziiy8OcXx7+9/+9jcADj744Lzbveuuu0IcXwDMqrPPPjvEX3zxRYjjv8O2224LwGGHHRbK0l/rAPvtt99a95HJHru7Pwt8+pPiQUCa9GM88JuiqyciIk2q1OmOm7t7DYC715jZZoXeaGbDgeGFXhcRkfJq8nns7j4GGAOVHYqJ51PfcMMNIe7UqVOI06GWhQsX5t1GmvcbarMSPvTQQ6HsjjvuCPH333/fyBpLsdLMfvH87tmzZzdXdYr2i1/8IsTxUEsqHlYoN
Mw3Z84coPBQzKabbtqYKlad+L/tQubPn1/nEaB9+/Yhrm8ophqVOt3xIzPrDJA8rihflUREpDFKbdinAEOTeCjwSHmqIyIijVXvUIyZ3Q/8GuhkZkuBPwLXAJPMbBiwGBhceAvN46abbgpxPL81ngv71VdfFb29dHmtzTffPJQNGjQoxA888EBJ9ZSGS+dtf/bZZ6Es/pndUsWzK+L4888/B+qmESgkvb09vvdCmlY8RHbuuec2Y02KV2/D7u7HFXgpewNTIiIZoJQCIiIZk7nsjrvssgtQu8gD1L3tuCHDL7E08f5HH30UyuJb1zUU07QuuKD2/rcdd9wRqHsjUpz+oaXq06dP3vIbb7wRgO+++67ebcQ32eRTzDakYdJFYKqJeuwiIhmTuR77BhtsANRdqT5OCJRP27a1f4Z4ubUjjzwyxL169QJgiy22CGUvvfRS4yoraxXP+z7vvPNCnC6/98wzz1S8To2x22675S2PUyPUp0OHNdMyxUsM/vnPf254xVqJNNc61P73XIx//OMfTVGdJqUeu4hIxqhhFxHJmMwNxSxbtgyoe5EzvvU/Xkasc+fOAAwYMCCUvfHGGyG++eabQ/z6668DdefHd+vWrVzVlkSaOxtg0qRJIY5vlU+HZRYsWFC5ipVZfJEzvTBfqvjCcbyEoOSkWVzj+ehZTCMQU49dRCRj1LCLiGRM5oZi0hkwJ5xwQiiLf4IdddRRIX711VeBunOD49k0+cQzMeLFEq699toSayyx+++/P8QHHHBAiJ9//vkQx0Nk1eStt8IiZLRr1y7E6UIQ8TBgPFNr6NChIf7tb3+7xnarcZ51U9t///1DPHnyZKDu37w+8fftyy+/LF/FKkQ9dhGRjFHDLiKSMZkbiknNmzcvxOVcAzNe+X3PPfcs23ZbsjijZToUEM8kKjQ7aOrUqUDtjCKAKVOmhHi99dYLcbpGZXyDWLyG5emnnx7iOFtnNZk+fXqI02yhUDvTKh5miW/OGjhwYAVqV/3izK333ntviBsyBJMumHPnnXeGsjT7ZjVRj11EJGMy22OXhtt6661DfNVVV4V45513DnGcUK0+/fv3X6PsyiuvDHG/fv1CnK4av2jRolDWt2/fEFdjr+mn4hXuTzvttBBfeumlABx77LGhLJ7nPmPGjBBvv/32QN3UFq3ZfffdF+L4gml8P0RDpN+zRx6pXTsoTWEB8OOPP5a03UpTj11EJGPUsIuIZIyGYlqp+KfqK6+8AkCXLl1CWTxf//bbb1/j8/GQSfzeNLsmwDXXXAPAccfVLsJ1ySWX5K1Pur04x30Whl8KifP3P/744wB07NgxlK1evTrENTU1Ib777rsBOPHEE5u4hi3PZpttFuL0+AcPrl2Vs02bNmXbxyeffBLK4mHJ+IJ+aty4cSGOl2tsTvX22M2su5k9ZWZzzextMzsrKe9oZtPMbEHyuGY+URERqbhihmJWAee5e2+gP3CGmfUBRgLT3b0XMD15LiIizayYxaxrgJokXmlmc4GuwCDg18nbxgNPAxfk2URmffjhh81dhZKNHj06xOkt7ffcc08oO+mkk4re1rrrrhvieAZNPMumPunfstSlC6vZN998U+dR8hs7dmyIDz300Irtd9SoUWt9PZ7hFN9jEc/YiWeDVUKDxtjNrAewKzAT2Dxp9HH3GjPbrMBnhgPDG1dNEREpVtENu5ltCPwdONvdvzSzoj7n7mOAMck2vJRKiohI8Ypq2M1sXXKN+n3u/mBS/JGZdU56652BFU1VyZbqtddea+4qlCzOnJiaNm1a0Z/v3r17iOMZGvl+cqZZNAE23HDDEPfu3TvEPXv2BGCvvfYKZelsEamVLqoR/8333Xff5qpORcWzYlqSHj165C2PZ4Ols5mWLFlSgRoVNyvGgHHAXHe/PnppCpDmEx0KPPLTz4qISOUV02PfC/g98E8zS7M5jQKuASaZ2TBgMTC4wOelSsRJlOKV2eNkZ2nu+rjHv9VWW
4V4zpw5IX700UeBujnW0yRLAH/9619DnN42/+STT5Z+AK1AmtM9/jtut912IR40aFCI49vipfLSSQkA7du3ByrXYy9mVszzQKEB9WwvHCgiUoWUUkBEJGOUUqCVilezT+ebH3300aFs7733DnGcaiD18ccfhzhNHQBw/fXX531PPvGFUilOmv7h6aefDmVxDvsjjzwyxFkbikmXvQTYY489mnRfy5YtC3GnTp1CHK8h0JKpxy4ikjFq2EVEMkZDMa3UpEmTQrxq1aqiPzd37lyg7tJh9Q25SPnNmjUrxPFQTJ8+fUKcLgkXL9pRzUaMGBHidPbUrbfeGsrats3fnJ111lkAfPvtt0XvK17OsVevXiGOs5fmc8opp4Q430IzlaIeu4hIxqhhFxHJGHOvXPqWLOSKiX/2Hn/88SEeOnRovreLNIlNNtkkxDNnzgxxfLNYesOZUjNkwmx3373YN6vHLiKSMbp42gi//OUvm7sK0krFS7eNGTMmxPE9BUOGDAFg+fLloWz27NkVqJ00N/XYRUQyRg27iEjGaCimgVauXBnivn37hjjOM94al3eT5jNx4sQQx1k304v7L730UijTUEzroB67iEjGqGEXEckYzWNvhPhn7ZlnnhniF198sTmqIyLZpXnsIiKtmRp2EZGMqXcoxsx+BjwLtCM3i2ayu//RzHoCE4COwGvA7939+8Jbyt5QjIhIhZR9KOY7YIC77wzsAgw0s/7AaOB/3L0X8BkwrJTaiohIedXbsHtOOjF73eSfAwOAyUn5eOA3TVJDERFpkKLG2M2sjZm9DqwApgHvAZ+7e7pCw1Kga9NUUUREGqKoht3dV7v7LkA3oB/QO9/b8n3WzIab2Swzm5XvdRERKa8GzYpx98+Bp4H+QHszS1MSdAOWFfjMGHffvSED/yIiUrp6G3Yz29TM2ifxz4H9gbnAU8DRyduGAo80VSVFRKR4xSQB6wyMN7M25P5HMMndHzOzd4AJZvZfwBxgXBPWU0REilTplAL/Br4GsrqsfSd0bNVIx1adWtOxbenumxb74Yo27ABmNiur4+06tuqkY6tOOrbClFJARCRj1LCLiGRMczTsY+p/S9XSsVUnHVt10rEVUPExdhERaVoaihERyRg17CIiGVPRht3MBprZfDNbaGYjK7nvcjOz7mb2lJnNNbO3zeyspLyjmU0zswXJY4fmrmspksRvc8zsseR5TzObmRzXRDNbr7nrWAoza29mk81sXnLu/iND5+yc5Lv4lpndb2Y/q9bzZmZ3mtkKM3srKst7niznpqRdedPM+jZfzetX4NiuS76Tb5rZQ+nd/slrFybHNt/MDipmHxVr2JM7V28BDgb6AMeZWZ9K7b8JrALOc/fe5HLnnJEcz0hgepKnfnryvBqdRS51RCor+fdvBP7P3bcDdiZ3jFV/zsysK/CfwO7uvgPQBhhC9Z63u4GBPykrdJ4OBnol/4YDf6lQHUt1N2se2zRgB3ffCXgXuBAgaVOGANsnn7k1aUvXqpI99n7AQnf/V7LS0gRgUAX3X1buXuPuryXxSnINRFdyxzQ+eVtV5qk3s27AocDY5LmRgfz7ZrYxsA9J+gt3/z5JbFf15yzRFvh5kpxvfaCGKj1v7v4s8OlPigudp0HAPcnaES+TS1DYuTI1bbh8x+buU6M06C+TS6wIuWOb4O7fufv7wEJybelaVbJh7wosiZ5nJoe7mfUAdgVmApu7ew3kGn9gs+arWcluAM4Hfkyeb0I28u9vBfwbuCsZZhprZhuQgXPm7h8C/w0sJtegfwHMJhvnLVXoPGWtbTkF+N8kLunYKtmwW56yqp9raWYbAn8Hznb3L5u7Po1lZocBK9x9dlyc563VeO7aAn2Bv7j7ruTyFlXdsEs+yXjzIKAn0AXYgNwQxU9V43mrT1a+n5jZReSGee9Li/K8rd5jq2TDvhToHj0vmMO9WpjZuuQa9fvc/cGk+KP0Z2DyuKK56leivYAjzGwRu
eGyAeR68EXl32/hlgJL3X1m8nwyuYa+2s8Z5NJpv+/u/3b3H4AHgT3JxnlLFTpPmWhbzGwocBhwvNfeYFTSsVWyYX8V6JVcpV+P3AWBKRXcf1kl487jgLnufn300hRy+emhCvPUu/uF7t7N3XuQO0cz3P14MpB/392XA0vMbNukaD/gHar8nCUWA/3NbP3ku5keW9Wft0ih8zQFODGZHdMf+CIdsqkWZjYQuAA4wt2/iV6aAgwxs3Zm1pPcBeJX6t2gu1fsH3AIuSu+7wEXVXLfTXAse5P7SfQm8Hry7xBy49HTgQXJY8fmrmsjjvHXwGNJvFXyhVoIPAC0a+76lXhMuwCzkvP2MNAhK+cMuByYB7wF3Au0q9bzBtxP7lrBD+R6rcMKnSdywxW3JO3KP8nNDGr2Y2jgsS0kN5aetiW3Re+/KDm2+cDBxexDKQVERDJGd56KiGSMGnYRkYxRwy4ikjFq2EVEMkYNu4hIxqhhFxHJGDXsIiIZ8/8aIlzCxxtQlAAAAABJRU5ErkJggg==\n", 114 | "text/plain": [ 115 | "
" 116 | ] 117 | }, 118 | "metadata": { 119 | "needs_background": "light" 120 | }, 121 | "output_type": "display_data" 122 | }, 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | " 9 2 9 2\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "def imshow(img):\n", 133 | " npimg = img.numpy()\n", 134 | " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", 135 | "\n", 136 | "# 选择一个 batch 的图片\n", 137 | "dataiter = iter(trainloader)\n", 138 | "images, labels = dataiter.next()\n", 139 | "\n", 140 | "# 显示图片\n", 141 | "imshow(torchvision.utils.make_grid(images))\n", 142 | "plt.show()\n", 143 | "# 打印 labels\n", 144 | "print(' '.join('%11s' % labels[j].numpy() for j in range(4)))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "## 定义卷积神经网络" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 5, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "class Net(nn.Module):\n", 161 | " def __init__(self):\n", 162 | " super(Net, self).__init__()\n", 163 | " self.conv1 = nn.Conv2d(1, 6, 5) # 1个输入图片通道,6个输出通道,5x5 卷积核\n", 164 | " self.pool1 = nn.MaxPool2d(2, 2) # max pooling,2x2\n", 165 | " self.conv2 = nn.Conv2d(6, 16, 5) # 6个输入图片通道,16个输出通道,5x5 卷积核\n", 166 | " self.pool2 = nn.MaxPool2d(2,2) # max pooling,2x2\n", 167 | " self.fc1 = nn.Linear(16 * 4 * 4, 120) # 拉伸成一维向量,全连接层\n", 168 | " self.fc2 = nn.Linear(120, 84) # 全连接层 \n", 169 | " self.fc3 = nn.Linear(84, 10) # 全连接输出层 输出10个数字类别的 logits (softmax 包含在 CrossEntropyLoss 中)\n", 170 | " \n", 171 | " def forward(self, x):\n", 172 | " x = F.relu(self.conv1(x))\n", 173 | " x = self.pool1(x)\n", 174 | " x = F.relu(self.conv2(x))\n", 175 | " x = self.pool2(x)\n", 176 | " x = x.view(-1, 16 * 4 * 4) # 拉伸成一维向量\n", 177 | " x = F.relu(self.fc1(x))\n", 178 | " x = F.relu(self.fc2(x))\n", 179 | " x = self.fc3(x)\n", 180 | " return x" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 6, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 |
"name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "Net(\n", 193 | " (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n", 194 | " (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 195 | " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", 196 | " (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 197 | " (fc1): Linear(in_features=256, out_features=120, bias=True)\n", 198 | " (fc2): Linear(in_features=120, out_features=84, bias=True)\n", 199 | " (fc3): Linear(in_features=84, out_features=10, bias=True)\n", 200 | ")\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "net = Net()\n", 206 | "print(net) " 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## 定义损失函数与优化算法" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 7, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "criterion = nn.CrossEntropyLoss() # 交叉熵损失\n", 223 | "optimizer = optim.Adam(net.parameters(), lr=0.0001) # Adam 优化算法" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## 训练网络" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 8, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "name": "stdout", 240 | "output_type": "stream", 241 | "text": [ 242 | "[epoch: 1, mini-batch: 2000] loss: 1.085\n", 243 | "[epoch: 1, mini-batch: 4000] loss: 0.397\n", 244 | "[epoch: 1, mini-batch: 6000] loss: 0.305\n", 245 | "[epoch: 1, mini-batch: 8000] loss: 0.231\n", 246 | "[epoch: 1, mini-batch: 10000] loss: 0.191\n", 247 | "[epoch: 1, mini-batch: 12000] loss: 0.174\n", 248 | "[epoch: 1, mini-batch: 14000] loss: 0.152\n", 249 | "[epoch: 2, mini-batch: 2000] loss: 0.133\n", 250 | "[epoch: 2, mini-batch: 4000] loss: 0.119\n", 251 | "[epoch: 2, mini-batch: 6000] loss: 0.125\n", 252 | "[epoch: 2, mini-batch: 8000] loss: 0.102\n", 253 
| "[epoch: 2, mini-batch: 10000] loss: 0.110\n", 254 | "[epoch: 2, mini-batch: 12000] loss: 0.094\n", 255 | "[epoch: 2, mini-batch: 14000] loss: 0.097\n", 256 | "[epoch: 3, mini-batch: 2000] loss: 0.084\n", 257 | "[epoch: 3, mini-batch: 4000] loss: 0.087\n", 258 | "[epoch: 3, mini-batch: 6000] loss: 0.078\n", 259 | "[epoch: 3, mini-batch: 8000] loss: 0.070\n", 260 | "[epoch: 3, mini-batch: 10000] loss: 0.086\n", 261 | "[epoch: 3, mini-batch: 12000] loss: 0.073\n", 262 | "[epoch: 3, mini-batch: 14000] loss: 0.079\n", 263 | "[epoch: 4, mini-batch: 2000] loss: 0.059\n", 264 | "[epoch: 4, mini-batch: 4000] loss: 0.070\n", 265 | "[epoch: 4, mini-batch: 6000] loss: 0.070\n", 266 | "[epoch: 4, mini-batch: 8000] loss: 0.055\n", 267 | "[epoch: 4, mini-batch: 10000] loss: 0.057\n", 268 | "[epoch: 4, mini-batch: 12000] loss: 0.066\n", 269 | "[epoch: 4, mini-batch: 14000] loss: 0.061\n", 270 | "[epoch: 5, mini-batch: 2000] loss: 0.052\n", 271 | "[epoch: 5, mini-batch: 4000] loss: 0.053\n", 272 | "[epoch: 5, mini-batch: 6000] loss: 0.052\n", 273 | "[epoch: 5, mini-batch: 8000] loss: 0.059\n", 274 | "[epoch: 5, mini-batch: 10000] loss: 0.047\n", 275 | "[epoch: 5, mini-batch: 12000] loss: 0.049\n", 276 | "[epoch: 5, mini-batch: 14000] loss: 0.047\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "num_epoches = 5 # 设置 epoch 数目\n", 282 | "cost = [] # 损失函数累加\n", 283 | "\n", 284 | "for epoch in range(num_epoches): \n", 285 | " \n", 286 | " running_loss = 0.0\n", 287 | " for i, data in enumerate(trainloader):\n", 288 | " # 输入样本和标签\n", 289 | " inputs, labels = data\n", 290 | " \n", 291 | " # 每次训练梯度清零\n", 292 | " optimizer.zero_grad()\n", 293 | " \n", 294 | " # 正向传播、反向传播和优化过程\n", 295 | " outputs = net(inputs)\n", 296 | " loss = criterion(outputs, labels)\n", 297 | " loss.backward()\n", 298 | " optimizer.step()\n", 299 | " \n", 300 | " # 打印训练情况\n", 301 | " running_loss += loss.item()\n", 302 | " if (i+1) % 2000 == 0: # 每隔2000 mini-batches,打印一次\n", 303 | " print('[epoch: %d, mini-batch: 
%5d] loss: %.3f' % \n", 304 | " (epoch + 1, i + 1, running_loss / 2000))\n", 305 | " cost.append(running_loss / 2000)\n", 306 | " running_loss = 0.0" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 9, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "data": { 316 | "image/png": "\n", 317 | "text/plain": [ 318 | "
" 319 | ] 320 | }, 321 | "metadata": { 322 | "needs_background": "light" 323 | }, 324 | "output_type": "display_data" 325 | } 326 | ], 327 | "source": [ 328 | "plt.plot(cost)\n", 329 | "plt.xlabel('mini-batches(per 2000)')\n", 330 | "plt.ylabel('running_loss')\n", 331 | "plt.show()" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "## 测试数据" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 10, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "Accuracy on the 60000 train images: 98.883 %\n" 351 | ] 352 | } 353 | ], 354 | "source": [ 355 | "correct = 0\n", 356 | "total = 0\n", 357 | "with torch.no_grad():\n", 358 | " for data in trainloader:\n", 359 | " images, labels = data\n", 360 | " outputs = net(images)\n", 361 | " _, predicted = torch.max(outputs.data, 1)\n", 362 | " total += labels.size(0)\n", 363 | " correct += (predicted == labels).sum().item()\n", 364 | "\n", 365 | "print('Accuracy on the 60000 train images: %.3f %%' % \n", 366 | " (100 * correct / total))" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 11, 372 | "metadata": {}, 373 | "outputs": [ 374 | { 375 | "name": "stdout", 376 | "output_type": "stream", 377 | "text": [ 378 | "Accuracy on the 10000 test images: 98.580 %\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "correct = 0\n", 384 | "total = 0\n", 385 | "with torch.no_grad():\n", 386 | " for data in testloader:\n", 387 | " images, labels = data\n", 388 | " outputs = net(images)\n", 389 | " _, predicted = torch.max(outputs.data, 1)\n", 390 | " total += labels.size(0)\n", 391 | " correct += (predicted == labels).sum().item()\n", 392 | "\n", 393 | "print('Accuracy on the 10000 test images: %.3f %%' % \n", 394 | " (100 * correct / total))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": {}, 401 | "outputs": 
[], 402 | "source": [] 403 | } 404 | ], 405 | "metadata": { 406 | "kernelspec": { 407 | "display_name": "Python 3", 408 | "language": "python", 409 | "name": "python3" 410 | }, 411 | "language_info": { 412 | "codemirror_mode": { 413 | "name": "ipython", 414 | "version": 3 415 | }, 416 | "file_extension": ".py", 417 | "mimetype": "text/x-python", 418 | "name": "python", 419 | "nbconvert_exporter": "python", 420 | "pygments_lexer": "ipython3", 421 | "version": "3.5.6" 422 | } 423 | }, 424 | "nbformat": 4, 425 | "nbformat_minor": 2 426 | } 427 | -------------------------------------------------------------------------------- /08_cnn_tensorflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 第 8 章:卷积神经网络 — TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from tensorflow.examples.tutorials.mnist import input_data\n", 17 | "import tensorflow as tf\n", 18 | "import matplotlib.pyplot as plt" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## 载入数据集" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "WARNING:tensorflow:From :1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 38 | "Instructions for updating:\n", 39 | "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", 40 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed 
in a future version.\n", 41 | "Instructions for updating:\n", 42 | "Please write your own downloading logic.\n", 43 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 44 | "Instructions for updating:\n", 45 | "Please use tf.data to implement this functionality.\n", 46 | "Extracting ./datasets/ch08/tensorflow/MNIST\\train-images-idx3-ubyte.gz\n", 47 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 48 | "Instructions for updating:\n", 49 | "Please use tf.data to implement this functionality.\n", 50 | "Extracting ./datasets/ch08/tensorflow/MNIST\\train-labels-idx1-ubyte.gz\n", 51 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 52 | "Instructions for updating:\n", 53 | "Please use tf.one_hot on tensors.\n", 54 | "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-images-idx3-ubyte.gz\n", 55 | "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-labels-idx1-ubyte.gz\n", 56 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 57 | "Instructions for updating:\n", 58 | "Please use alternatives such as 
official/mnist/dataset.py from tensorflow/models.\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "mnist = input_data.read_data_sets('./datasets/ch08/tensorflow/MNIST',one_hot=True)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "(55000, 784)\n", 76 | "(55000, 10)\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "print(mnist.train.images.shape)\n", 82 | "print(mnist.train.labels.shape)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "(5000, 784)\n", 95 | "(5000, 10)\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "print(mnist.validation.images.shape)\n", 101 | "print(mnist.validation.labels.shape)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 5, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "(10000, 784)\n", 114 | "(10000, 10)\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "print(mnist.test.images.shape)\n", 120 | "print(mnist.test.labels.shape)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## 定义创建模型需要的函数" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# input 代表输入,filter 代表卷积核\n", 137 | "# 卷积层\n", 138 | "def conv2d(x, filter):\n", 139 | " return tf.nn.conv2d(x, \n", 140 | " filter, \n", 141 | " strides=[1,1,1,1], \n", 142 | " padding='SAME')\n", 143 | "\n", 144 | "# 池化层\n", 145 | "def max_pool(x):\n", 146 | " return tf.nn.max_pool(x, \n", 147 | " ksize=[1,2,2,1], \n", 148 | " strides=[1,2,2,1], \n", 149 | " padding='SAME')\n", 150 | "\n", 151 | "# 初始化卷积核或者是权重数组的值\n", 152 | "def weight_variable(shape):\n", 
153 | " initial = tf.truncated_normal(shape, stddev=0.1)\n", 154 | " return tf.Variable(initial)\n", 155 | "\n", 156 | "# 初始化bias的值\n", 157 | "def bias_variable(shape):\n", 158 | " initial = tf.constant(0.1, shape=shape)\n", 159 | " return tf.Variable(initial)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## 定义占位符 placeholder" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 7, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# None 代表图片数量未知\n", 176 | "x = tf.placeholder(tf.float32, [None,784])/255\n", 177 | "# 将input 重新调整结构,适用于CNN的特征提取\n", 178 | "x_image = tf.reshape(x, [-1,28,28,1])\n", 179 | "\n", 180 | "# y是样本的真实标签 (one-hot 编码)\n", 181 | "y = tf.placeholder(tf.float32, [None,10])" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "## 定义 CNN 模型" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "### CONV1" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "W_conv1 = weight_variable([5, 5, 1, 32]) # 卷积核尺寸:5x5, 输入通道:1, 输出通道:32\n", 205 | "b_conv1 = bias_variable([32])\n", 206 | "h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) # 输出尺寸:28x28x32\n", 207 | "h_pool1 = max_pool(h_conv1) # 输出尺寸:14x14x32" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "### CONV2" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 9, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "W_conv2 = weight_variable([5, 5, 32, 64]) # 卷积核尺寸:5x5, 输入通道:32, 输出通道:64\n", 224 | "b_conv2 = bias_variable([64])\n", 225 | "h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) # 输出尺寸:14x14x64\n", 226 | "h_pool2 = max_pool(h_conv2) # 输出尺寸:7x7x64" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 |
"metadata": {}, 232 | "source": [ 233 | "### FC1" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 10, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "W_fc1 = weight_variable([7*7*64, 1024])\n", 243 | "b_fc1 = bias_variable([1024])\n", 244 | "h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) # 展开为一维向量\n", 245 | "h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "### FC2" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 11, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "W_fc2 = weight_variable([1024, 10])\n", 262 | "b_fc2 = bias_variable([10])\n", 263 | "prediction = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "## 定义损失函数与优化算法" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 12, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(prediction),reduction_indices=[1])) \n", 280 | "train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 13, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "#判断预测标签和实际标签是否匹配\n", 290 | "correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y,1))\n", 291 | "accuracy = tf.reduce_mean(tf.cast(correct_prediction,\"float\"))" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "## 训练并验证准确率" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 14, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stdout", 308 | "output_type": "stream", 309 | "text": [ 310 | "train accuracy 0.810\n", 311 | "train accuracy 0.890\n", 312 | "train accuracy 
0.880\n", 313 | "train accuracy 0.960\n", 314 | "train accuracy 0.950\n", 315 | "train accuracy 0.940\n", 316 | "train accuracy 0.950\n", 317 | "train accuracy 0.940\n", 318 | "train accuracy 0.940\n", 319 | "train accuracy 0.960\n", 320 | "train accuracy 0.980\n", 321 | "train accuracy 0.950\n", 322 | "train accuracy 0.940\n", 323 | "train accuracy 0.980\n", 324 | "train accuracy 0.980\n", 325 | "train accuracy 0.980\n", 326 | "train accuracy 1.000\n", 327 | "train accuracy 0.990\n", 328 | "train accuracy 0.990\n", 329 | "train accuracy 0.990\n", 330 | "test accuracy 0.977\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "sess = tf.Session()\n", 336 | "init = tf.global_variables_initializer()\n", 337 | "sess.run(init)\n", 338 | "\n", 339 | "cost = []\n", 340 | "\n", 341 | "for i in range(1000):\n", 342 | " batch_x, batch_y = mnist.train.next_batch(100)\n", 343 | " sess.run(train_step, feed_dict={x: batch_x, y: batch_y})\n", 344 | " if (i+1) % 50 == 0:\n", 345 | " print(\"train accuracy %.3f\" % accuracy.eval(session = sess,\n", 346 | " feed_dict = {x:batch_x, y:batch_y}))\n", 347 | "print(\"test accuracy %.3f\" % accuracy.eval(session = sess,\n", 348 | " feed_dict = {x:mnist.test.images, y:mnist.test.labels}))" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [] 357 | } 358 | ], 359 | "metadata": { 360 | "kernelspec": { 361 | "display_name": "Python 3", 362 | "language": "python", 363 | "name": "python3" 364 | }, 365 | "language_info": { 366 | "codemirror_mode": { 367 | "name": "ipython", 368 | "version": 3 369 | }, 370 | "file_extension": ".py", 371 | "mimetype": "text/x-python", 372 | "name": "python", 373 | "nbconvert_exporter": "python", 374 | "pygments_lexer": "ipython3", 375 | "version": "3.5.6" 376 | } 377 | }, 378 | "nbformat": 4, 379 | "nbformat_minor": 2 380 | } 381 | -------------------------------------------------------------------------------- 
/09_rnn_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 第 9 章:循环神经网络 — PyTorch" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## 导入 MNIST 数据集" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import torch\n", 24 | "import torchvision\n", 25 | "import torchvision.transforms as transforms\n", 26 | "import torch.nn as nn\n", 27 | "import torch.nn.functional as F\n", 28 | "import torch.optim as optim\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import numpy as np\n", 31 | "\n", 32 | "transform = transforms.Compose(\n", 33 | " [transforms.ToTensor()])\n", 34 | "\n", 35 | "# 训练集\n", 36 | "trainset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录\n", 37 | " train=True,\n", 38 | " download=True, # 不从网络上download图片\n", 39 | " transform=transform)\n", 40 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n", 41 | " shuffle=True, num_workers=2)\n", 42 | "# 测试集\n", 43 | "testset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch', # 选择数据的根目录\n", 44 | " train=False,\n", 45 | " download=True, # 不从网络上download图片\n", 46 | " transform=transform)\n", 47 | "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n", 48 | " shuffle=False, num_workers=2)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Dataset MNIST\n", 61 | " Number of datapoints: 60000\n", 62 | " Split: train\n", 63 | " Root Location: ./datasets/ch08/pytorch\n", 64 | " Transforms (if any): Compose(\n", 65 | " ToTensor()\n", 66 | " )\n", 67 | " Target Transforms (if any): None\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | 
"print(trainset)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "Dataset MNIST\n", 85 | " Number of datapoints: 10000\n", 86 | " Split: test\n", 87 | " Root Location: ./datasets/ch08/pytorch\n", 88 | " Transforms (if any): Compose(\n", 89 | " ToTensor()\n", 90 | " )\n", 91 | " Target Transforms (if any): None\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "print(testset)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "接下来展示一些训练样本图像" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAB6CAYAAACr63iqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEvJJREFUeJzt3XnQVNWZx/HvIyoiBgUEZIsrcY1L3HDGmLijg2JcEgxREqlgEo06RaJoKjFWWYkiruVWVFSIoqi4EYwwxNE4RGVEJAZF5EUdZESJcQPHoOgzf/S9h/NCN728/d6m7/v7VL31Pn26+95zuc15T5977nPM3RERkfzYpNEVEBGR+lLDLiKSM2rYRURyRg27iEjOqGEXEckZNewiIjmjhl1EJGfa1LCb2RAzW2RmLWY2tl6VEhGR2lmtNyiZWSfgVeBoYBnwHHC6u79cv+qJiEi1Nm3Dew8CWtz9NQAzmwIMA0o27Gam21xFRKr3rrv3qvTFbRmK6Q+8GT1elpS1YmajzWyumc1tw75ERDqy/6nmxW3psVuRsvV65O4+AZgA6rGLiGShLT32ZcDA6PEA4K22VUdERNqqLQ37c8AgM9vRzDYHhgPT6lMtERGpVc1DMe6+xszOBWYCnYDb3f2lutVMRERqUvN0x5p2pjF2EZFaPO/uB1T6Yt15KiKSM2rYRURyRg27iEjOqGEXEckZNewiIjmjhl1EJGfUsIuI5IwadhGRnFHDLiKSM2rYRURyRg27iEjOqGEXEcmZtiy0ISKy0TvssMMAePHFF0PZBx980KjqZEI9dhGRnFHDLiKSMxqKATbffPMQ77777iHeZ599QvyTn/xkg9uYMGFCiGfMmAHAW29ppUBpu1691i5O/4Mf/ACAk08+OZQdfPDBIb777rtDfOONNwLwzDPPtHcVNwpbb711iMeNGxfiUaNGAXD11VeHsosuuii7ijWAeuwiIjmjhl1EJGc69NJ46de1+KvumWee2ebtvvDCCwAcf/zxoWzFihVt3m49DRgwIMTnn38+ACeccEIo+8pXvlL0ff/4xz8AmDx5ctHnP/rooxBPnDix4vqkw1b//Oc/K35Pnh1yyCEhvuGGG0K8//77V7yN9957D4BvfOMboeyll/K1LHGfPn1CfMstt4R42LBhIb7vvvsAOPfcc0NZ+
jluIvVdGs/MbjezFWa2ICrrYWazzGxx8rt7rbUVEZH6KttjN7PDgFXA7919r6RsHPCeu19hZmOB7u5e9mpEo3rsxx57bIh79+4d4vQvfJcuXdplv3Hv6KSTTgrxa6+91i77q0Y8pzftnf/1r38NZQccUHHnoJVNNlnbV/jiiy8qft/tt98OwBVXXBHKlixZUlMdmtW2224b4kcffTTEBx54YIhbWloA+PGPf1x0G0ceeWSIx44dC8Dw4cNDWdp7bXZpT/2Pf/xjKPvqV78a4tmzZ4f4uOOOA2D16tUZ1a5d1LfH7u5PAe+tUzwMmJTEk4CTEBGRjUKt0x37uPtyAHdfbma9S73QzEYDo2vcj4iIVKmii6dmtgMwPRqK+cDdt4mef9/dy46zZzEUk35FS79+AVx77bUh7tatW3tXoah4WOa0004L8RtvvAFk/zVxl112CfGmmxb+vi9dujSU7bTTTiHebLPNQlxqCKCYPffcE4A1a9aEskMPPXSD75k2bVqIv/Wtb1W8r2aWDsHEF/9OOeWUEKf3RcTln3zySdFt7bXXXiFOL+IPGTIklMUXt1955ZUQr1y5sqa6Zyn+PKRDdl27dg1l8dz1+IJzTtR3KKaEd8ysL0Dye+Oa8iEi0oHV2rBPA0Ym8UjgkfpUR0RE2qqSWTH3AN8EtgXeAS4FHgbuA74MLAVOc/d1L7AW21a7D8WkX8HOOeec9t5VXVx++eUAXHrppQ2uSf2lQzjxZ2ybbcIIHpdddlmIf/SjHwHw2WefhbJ42Gbu3LntVs9G6NSpU4inTp0KtJ57/dRTT4V46NChIV61atUGt9u5c+cQb7/99gCceOKJoew3v/lNiK+77roQX3jhhRXXvb317ds3xOPHjw/xqaeeGuI//OEPAPzyl78MZQsXLsygduuLZ9r16NEjxPG9K+k9BW1Q1VBM2Yun7n56iaeOLFEuIiINpJQCIiI507TZHeNZGw888ECId9ttt6q39dBDD4X4rLPOKvqan/70p0DrmyDKiWfmbLXVVkVfc8wxxwD5HIqJh1VS7777bojjYz766KMB2HnnnUNZfGNZ3oZi7rjjjhCnQzAzZ84MZRdccEGIyw2/xOLZVeeddx5QOjPp+++/X/F2szRv3rwQ9+zZM8Q//OEPQ3zvvfcC2aegiGe0pTcdfv3rXw9l/fv3D/Hrr78e4jFjxgDwyCPZXI5Uj11EJGeaNglY/Nf71ltvrfh9cS8lzW399NNPh7J6Jgf6y1/+EuLBgwcXfc2HH34IwIgRI0LZY489Vrc6NIs0Z/hBBx0UyhYsCOmJWuXGb1bxt8z4G0h6QTlO8JXOQa9WvC5A+u0zTvNwzTXXhDi+YFpN+of2Ftfl9NPXXuJLe+ntKf1mHSfBixMDxhdwly9fDsBzzz0XyuJUHf369QvxJZdcAqy9Z6QGmcxjFxGRjZQadhGRnGnai6fxnN9y4gt28fueffbZutZpXRdffHGIn3jiiaKvSZfzii9wdcShmI7gqKOOCnE8n//hhx8GWmfXLCeer37llVeG+Pvf/36I0yGYeKgyzfgIG9fwSylZ5I+PU2b8+te/BuDss88OZZMmTQrx4YcfHuLFixdvcLvpthpBPXYRkZxRwy4ikjNNOxQTLztXbmZPOvsF2n/4JRavil5OfDyST+kMqHWl89dLDY3Ec6MPPvhgAE4++eRQ9t3vfrfo+2bNmgWsTVsBrTNtNoPtttsuxPEsqbaKs5vG2V+/9KUvAa2Xjix13oqJM1CmM2GgdSqHLKjHLiKSM2rYRURypmmHYswsxMWGYuJZKHPmzGn3+sQ3gXzve98DYNdddy37vvS2+zjTnuRTuvDIutJb0+PP0N577x3iM844I8RbbrnlBveR3jQDa2+seeedd6qvbAPFqSgmTpwY4nh4pBbxIiRPPvlkiF999dUQp2lAah1+ufPOO0Mcry/729/+tqq6tpV67CIiOdO0PfZyF0yXLVsW4nqmCSglvvU5T
vBUTnrreDy/uKPYb7/9Qrzvvvuu9/z8+fOzrE67iz8X8XoBabKzOOlZreI0Ac3WU0+l32AAJk+eHOLhw4eHeMqUKVVvN76HJb61P04ZUK6nHiepS5eJjM9lnFBw5MiRIf7888+rrm9bqMcuIpIzathFRHKmaYdiyl08zUK6ujy0zjZZjSxumd6YxEvCxRed0lvk4/QPV111VXYVy0Ccnzue1/zzn/8cgLfffjuUPfrooyGOM0GmF+fii6gPPvhgiO+///461rgx4pQa99xzT4hvuummEO+xxx5A60kH5Zaf+/Of/xziU045JcTxZy4dHhw0aFAoO+KII0L8ne98J8RLliwBWudo/9Of/hTirIdfYmV77GY20MyeMLOFZvaSmZ2flPcws1lmtjj53b39qysiIuVUMhSzBhjj7rsDg4FzzGwPYCzwuLsPAh5PHouISINVvdCGmT0C3Jj8fNPdl5tZX+BJd9/gxO16LrQR17vYrdhxUv44c+Inn3wS4i222AJovbxWly5diu6vW7duQOvhgT59+oQ4Xh6rXH3TbH4Ao0aNAqqbN9vM4n+zt956a73n4yGKeAX6PEs/W59++mkoiz+T8Tz2dF73ypUrQ1nv3r1DHG8jb+K0G+n/748//jiUxQtexNLhq3QZSmidBTOeQTdw4MD13h9vN16oJJ0Lv2LFikqq31ZVLbRR1Ri7me0A7AfMAfq4+3KApHHvXeI9o4HR1exHRERqV3HDbmZbAQ8AF7j7R/HFyw1x9wnAhGQbjbnKKSLSgVQ0FGNmmwHTgZnufk1StogGDsXMmDEjxOkK95WYPn16iNOFD+KyeE3Deoq/Wnft2rVd9tEMNBRTmZ49e4Y4XtAhHbb51a9+Fcqyzhy4MejevTBXI76pK854WeyGt7gzOnv27BAvWrRovdfGs4vioZrVq1fXWOM2q++ap1b417gNWJg26olpQHpr1UjgkWpqKSIi7aNsj93MDgX+C/gbkF6lvITCOPt9wJeBpcBp7r7BiaT17LEPHTo0xHfddVeI03zKjZT+hR8xYkQoi+cot7S0ZF6njUW8XFixHnl88apYj76j+NnPfhbicePGhThNQbH//vtnXidpqPpePHX32UCpAfUjK92RiIhkQykFRERypmlTCsQXPOPMimmmtTSvcr2kQ1bxxZOXX345xNdff/165fPmzatrHfLghBNOKFqe3hZf7rbwjiK9SLqu559/PuOaSDNSj11EJGfUsIuI5EzTDsXE4mxwaXz11VeHsm9/+9sh7tevX8Xbve2220KcLlrQkedW1ype4CBe8i02fvx4oPV8/44sznwZW7BgQcY1kWakHruISM6oYRcRyZlcDMUUM2bMmBDffPPNIS63ynts4cKFIV6zZk19KtYB9erVK8SbbLK2LxHfgNSRb0YSqTf12EVEcia3PfZYuoSVNEa8hGBs/vz5IX7zzTezqs5Ga8CAASGOc6yLVEs9dhGRnFHDLiKSMx1iKEYaK12CcF2TJ0/OuCYbt3h5xFWrVoU4vvgc37MhUop67CIiOaOGXUQkZzQUI+1u5syZIT7wwANDHGfoFFi5cmWId9555wbWRJqdeuwiIjmjhl1EJGcqWfN0C+ApoDOFoZup7n6pme0ITAF6APOAM9z90zLbqtuapyIiHUhVa55W0mNfDRzh7vsA+wJDzGwwcCVwrbsPAt4HRtVSWxERqa+yDbsXpJNqN0t+HDgCmJqUTwJOapcaiohIVSoaYzezTmY2H1gBzAKWAB+4e5rycBnQv32qKCIi1aioYXf3z919X2AAcBCwe7GXFXuvmY02s7lmNrf2aoqISKWqmhXj7h8ATwKDgW3MLJ0HPwAomlDb3Se4+wHVDPyLiEjtyjbsZtbLzLZJ4i7AUcBC4Ang1ORlI4FH2quSIiJSuUruPO0LTDKzThT+ENzn7tPN7GVgipldDrwA3LahjYiISDbKzmOv687M/g58DLyb2U6ztS06tmakY2tOHenYt
nf3XqVevK5MG3YAM5ub1/F2HVtz0rE1Jx1baUopICKSM2rYRURyphEN+4QG7DMrOrbmpGNrTjq2EjIfYxcRkfaloRgRkZxRwy4ikjOZNuxmNsTMFplZi5mNzXLf9WZmA83sCTNbaGYvmdn5SXkPM5tlZouT390bXddaJInfXjCz6cnjHc1sTnJc95rZ5o2uYy3MbBszm2pmryTn7pAcnbN/Tz6LC8zsHjPbolnPm5ndbmYrzGxBVFb0PFnBDUm78qKZfa1xNS+vxLFdlXwmXzSzh9K7/ZPnLk6ObZGZHVvJPjJr2JM7V28CjgP2AE43sz2y2n87WAOMcffdKeTOOSc5nrHA40me+seTx83ofAqpI1J5yb9/PTDD3XcD9qFwjE1/zsysP3AecIC77wV0AobTvOdtIjBknbJS5+k4YFDyMxq4JaM61moi6x/bLGAvd98beBW4GCBpU4YDeybvuTlpSzcoyx77QUCLu7+WrLQ0BRiW4f7ryt2Xu/u8JF5JoYHoT+GYJiUva8o89WY2APg34HfJYyMH+ffNrBtwGEn6C3f/NEls1/TnLLEp0CVJzrclsJwmPW/u/hTw3jrFpc7TMOD3ydoRz1JIUNg3m5pWr9ixuft/RGnQn6WQWBEKxzbF3Ve7++tAC4W2dIOybNj7A29Gj3OTw93MdgD2A+YAfdx9ORQaf6B342pWs+uAC4Evksc9yUf+/Z2AvwN3JMNMvzOzruTgnLn7/wLjgaUUGvQPgefJx3lLlTpPeWtbzgIeS+Kaji3Lht2KlDX9XEsz2wp4ALjA3T9qdH3aysyGAivc/fm4uMhLm/HcbQp8DbjF3fejkLeo6YZdiknGm4cBOwL9gK4UhijW1YznrZy8fD4xs19QGOadnBYVeVnZY8uyYV8GDIwel8zh3izMbDMKjfpkd38wKX4n/RqY/F7RqPrV6F+BE83sDQrDZUdQ6MFXlH9/I7cMWObuc5LHUyk09M1+zqCQTvt1d/+7u38GPAj8C/k4b6lS5ykXbYuZjQSGAiN87Q1GNR1blg37c8Cg5Cr95hQuCEzLcP91lYw73wYsdPdroqemUchPD02Yp97dL3b3Ae6+A4Vz9J/uPoIc5N9397eBN81s16ToSOBlmvycJZYCg81sy+SzmR5b05+3SKnzNA04M5kdMxj4MB2yaRZmNgS4CDjR3f8vemoaMNzMOpvZjhQuEP932Q26e2Y/wPEUrvguAX6R5b7b4VgOpfCV6EVgfvJzPIXx6MeBxcnvHo2uaxuO8ZvA9CTeKflAtQD3A50bXb8aj2lfYG5y3h4GuuflnAGXAa8AC4A7gc7Net6AeyhcK/iMQq91VKnzRGG44qakXfkbhZlBDT+GKo+thcJYetqW3Bq9/hfJsS0CjqtkH0opICKSM7rzVEQkZ9Swi4jkjBp2EZGcUcMuIpIzathFRHJGDbuISM6oYRcRyZn/BwyyzRo47o6iAAAAAElFTkSuQmCC\n", 114 | "text/plain": [ 115 | "
" 116 | ] 117 | }, 118 | "metadata": { 119 | "needs_background": "light" 120 | }, 121 | "output_type": "display_data" 122 | }, 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | " 0 7 9 6\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "def imshow(img):\n", 133 | " npimg = img.numpy()\n", 134 | " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", 135 | "\n", 136 | "# 选择一个 batch 的图片\n", 137 | "dataiter = iter(trainloader)\n", 138 | "images, labels = dataiter.next()\n", 139 | "\n", 140 | "# 显示图片\n", 141 | "imshow(torchvision.utils.make_grid(images))\n", 142 | "plt.show()\n", 143 | "# 打印 labels\n", 144 | "print(' '.join('%11s' % labels[j].numpy() for j in range(4)))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "## 定义循环神经网络" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 5, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "class Net(nn.Module):\n", 161 | " def __init__(self):\n", 162 | " super(Net, self).__init__()\n", 163 | " self.rnn = nn.LSTM( # 使用 LSTM 结构\n", 164 | " input_size = 28, # 输入每个元素的维度,即图片每行包含 28 个像素点\n", 165 | " hidden_size = 84, # 隐藏层神经元设置为 84 个\n", 166 | " num_layers=2, # 隐藏层数目,两层\n", 167 | " batch_first=True, # 是否将 batch 放在维度的第一位,(batch, time_step, input_size)\n", 168 | " )\n", 169 | " self.out = nn.Linear(84, 10) # 输出层,包含 10 个神经元,对应 0~9 数字\n", 170 | "\n", 171 | " def forward(self, x):\n", 172 | " r_out, (h_n, h_c) = self.rnn(x, None) \n", 173 | " # 选择图片的最后一行作为 RNN 输出\n", 174 | " out = self.out(r_out[:, -1, :])\n", 175 | " return out" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 6, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "Net(\n", 188 | " (rnn): LSTM(28, 84, num_layers=2, batch_first=True)\n", 189 | " (out): Linear(in_features=84, out_features=10, bias=True)\n", 190 | ")\n" 191 | ] 192 | } 193 | ], 
194 | "source": [ 195 | "net = Net()\n", 196 | "print(net) " 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "## 定义损失函数与优化算法" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 7, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "criterion = nn.CrossEntropyLoss()\n", 213 | "optimizer = optim.Adam(net.parameters(), lr=0.0001)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## 训练网络" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "[epoch: 1, mini-batch: 2000] loss: 1.592\n", 233 | "[epoch: 1, mini-batch: 4000] loss: 0.789\n", 234 | "[epoch: 1, mini-batch: 6000] loss: 0.564\n", 235 | "[epoch: 1, mini-batch: 8000] loss: 0.429\n", 236 | "[epoch: 1, mini-batch: 10000] loss: 0.357\n", 237 | "[epoch: 1, mini-batch: 12000] loss: 0.294\n", 238 | "[epoch: 1, mini-batch: 14000] loss: 0.290\n", 239 | "[epoch: 2, mini-batch: 2000] loss: 0.236\n", 240 | "[epoch: 2, mini-batch: 4000] loss: 0.213\n", 241 | "[epoch: 2, mini-batch: 6000] loss: 0.196\n", 242 | "[epoch: 2, mini-batch: 8000] loss: 0.195\n", 243 | "[epoch: 2, mini-batch: 10000] loss: 0.173\n", 244 | "[epoch: 2, mini-batch: 12000] loss: 0.161\n", 245 | "[epoch: 2, mini-batch: 14000] loss: 0.161\n", 246 | "[epoch: 3, mini-batch: 2000] loss: 0.140\n", 247 | "[epoch: 3, mini-batch: 4000] loss: 0.124\n", 248 | "[epoch: 3, mini-batch: 6000] loss: 0.141\n", 249 | "[epoch: 3, mini-batch: 8000] loss: 0.133\n", 250 | "[epoch: 3, mini-batch: 10000] loss: 0.121\n", 251 | "[epoch: 3, mini-batch: 12000] loss: 0.129\n", 252 | "[epoch: 3, mini-batch: 14000] loss: 0.107\n", 253 | "[epoch: 4, mini-batch: 2000] loss: 0.110\n", 254 | "[epoch: 4, mini-batch: 4000] loss: 0.099\n", 255 | "[epoch: 4, mini-batch: 6000] loss: 0.101\n", 256 
| "[epoch: 4, mini-batch: 8000] loss: 0.101\n", 257 | "[epoch: 4, mini-batch: 10000] loss: 0.094\n", 258 | "[epoch: 4, mini-batch: 12000] loss: 0.087\n", 259 | "[epoch: 4, mini-batch: 14000] loss: 0.098\n", 260 | "[epoch: 5, mini-batch: 2000] loss: 0.082\n", 261 | "[epoch: 5, mini-batch: 4000] loss: 0.094\n", 262 | "[epoch: 5, mini-batch: 6000] loss: 0.082\n", 263 | "[epoch: 5, mini-batch: 8000] loss: 0.082\n", 264 | "[epoch: 5, mini-batch: 10000] loss: 0.081\n", 265 | "[epoch: 5, mini-batch: 12000] loss: 0.078\n", 266 | "[epoch: 5, mini-batch: 14000] loss: 0.078\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "num_epoches = 5 # 设置 epoch 数目\n", 272 | "cost = [] # 损失函数累加\n", 273 | "\n", 274 | "for epoch in range(num_epoches): \n", 275 | "\n", 276 | " running_loss = 0.0\n", 277 | " for i, data in enumerate(trainloader, 0):\n", 278 | " # 输入样本和标签\n", 279 | " inputs, labels = data\n", 280 | " inputs = inputs.view(-1, 28, 28) # 设置 RNN 输入维度为 (batch, time_step, input_size)\n", 281 | "\n", 282 | " # 每次训练梯度清零\n", 283 | " optimizer.zero_grad()\n", 284 | "\n", 285 | " # 正向传播、反向传播和优化过程\n", 286 | " outputs = net(inputs)\n", 287 | " loss = criterion(outputs, labels)\n", 288 | " loss.backward()\n", 289 | " optimizer.step()\n", 290 | "\n", 291 | " # 打印训练情况\n", 292 | " running_loss += loss.item()\n", 293 | " if i % 2000 == 1999: # 每隔2000 mini-batches,打印一次\n", 294 | " print('[epoch: %d, mini-batch: %5d] loss: %.3f' % \n", 295 | " (epoch + 1, i + 1, running_loss / 2000))\n", 296 | " cost.append(running_loss / 2000)\n", 297 | " running_loss = 0.0" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 9, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "image/png": "\n", 308 | "text/plain": [ 309 | "
" 310 | ] 311 | }, 312 | "metadata": { 313 | "needs_background": "light" 314 | }, 315 | "output_type": "display_data" 316 | } 317 | ], 318 | "source": [ 319 | "plt.plot(cost)\n", 320 | "plt.xlabel('mini-batches(per 2000)')\n", 321 | "plt.ylabel('running_loss')\n", 322 | "plt.show()" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "## 测试数据" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 10, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "Accuracy of the network on the 60000 test images: 96.355 %\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "correct = 0\n", 347 | "total = 0\n", 348 | "with torch.no_grad():\n", 349 | " for data in trainloader:\n", 350 | " images, labels = data\n", 351 | " images = images.view(-1, 28, 28)\n", 352 | " outputs = net(images)\n", 353 | " _, predicted = torch.max(outputs.data, 1)\n", 354 | " total += labels.size(0)\n", 355 | " correct += (predicted == labels).sum().item()\n", 356 | "\n", 357 | "print('Accuracy of the network on the 60000 test images: %.3f %%' % \n", 358 | " (100 * correct / total))" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 11, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "name": "stdout", 368 | "output_type": "stream", 369 | "text": [ 370 | "Accuracy of the network on the 10000 test images: 95.790 %\n" 371 | ] 372 | } 373 | ], 374 | "source": [ 375 | "correct = 0\n", 376 | "total = 0\n", 377 | "with torch.no_grad():\n", 378 | " for data in testloader:\n", 379 | " images, labels = data\n", 380 | " images = images.view(-1, 28, 28)\n", 381 | " outputs = net(images)\n", 382 | " _, predicted = torch.max(outputs.data, 1)\n", 383 | " total += labels.size(0)\n", 384 | " correct += (predicted == labels).sum().item()\n", 385 | "\n", 386 | "print('Accuracy of the network on the 10000 test images: %.3f %%' % \n", 387 | " 
(100 * correct / total))" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [] 396 | } 397 | ], 398 | "metadata": { 399 | "kernelspec": { 400 | "display_name": "Python 3", 401 | "language": "python", 402 | "name": "python3" 403 | }, 404 | "language_info": { 405 | "codemirror_mode": { 406 | "name": "ipython", 407 | "version": 3 408 | }, 409 | "file_extension": ".py", 410 | "mimetype": "text/x-python", 411 | "name": "python", 412 | "nbconvert_exporter": "python", 413 | "pygments_lexer": "ipython3", 414 | "version": "3.7.2" 415 | } 416 | }, 417 | "nbformat": 4, 418 | "nbformat_minor": 2 419 | } 420 | -------------------------------------------------------------------------------- /09_rnn_tensorflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 循环神经网络—TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from tensorflow.examples.tutorials.mnist import input_data\n", 17 | "import tensorflow as tf\n", 18 | "import matplotlib.pyplot as plt" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## 导入 MNIST 数据集" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "WARNING:tensorflow:From :1: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 38 | "Instructions for updating:\n", 39 | "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", 40 | "WARNING:tensorflow:From 
C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", 41 | "Instructions for updating:\n", 42 | "Please write your own downloading logic.\n", 43 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 44 | "Instructions for updating:\n", 45 | "Please use tf.data to implement this functionality.\n", 46 | "Extracting ./datasets/ch08/tensorflow/MNIST\\train-images-idx3-ubyte.gz\n", 47 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 48 | "Instructions for updating:\n", 49 | "Please use tf.data to implement this functionality.\n", 50 | "Extracting ./datasets/ch08/tensorflow/MNIST\\train-labels-idx1-ubyte.gz\n", 51 | "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 52 | "Instructions for updating:\n", 53 | "Please use tf.one_hot on tensors.\n", 54 | "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-images-idx3-ubyte.gz\n", 55 | "Extracting ./datasets/ch08/tensorflow/MNIST\\t10k-labels-idx1-ubyte.gz\n", 56 | "WARNING:tensorflow:From 
C:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", 57 | "Instructions for updating:\n", 58 | "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "mnist = input_data.read_data_sets('./datasets/ch08/tensorflow/MNIST',one_hot=True)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "(55000, 784)\n", 76 | "(55000, 10)\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "print(mnist.train.images.shape)\n", 82 | "print(mnist.train.labels.shape)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "(5000, 784)\n", 95 | "(5000, 10)\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "print(mnist.validation.images.shape)\n", 101 | "print(mnist.validation.labels.shape)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 5, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "(10000, 784)\n", 114 | "(10000, 10)\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "print(mnist.test.images.shape)\n", 120 | "print(mnist.test.labels.shape)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## 定义占位符 placeholder" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# 参数设置\n", 137 | "batch_size = 100 # BATCH 的大小,相当于一次处理100个image\n", 138 | "time_step = 28 # 
一个LSTM中,输入序列的长度Tx,image有28行\n", 139 | "input_size = 28 # 单个x向量长度,image有28列\n", 140 | "lr = 0.001 # 学习率\n", 141 | "num_units = 100 # 隐藏层多少个LSTM单元\n", 142 | "iterations =1000 # 迭代训练次数\n", 143 | "classes =10 # 输出大小,0-9十个数字的概率\n", 144 | "\n", 145 | "# 定义 placeholders\n", 146 | "# 维度是[batch_size,time_step * input_size]\n", 147 | "x = tf.placeholder(tf.float32, [None, time_step * input_size]) \n", 148 | "# 输入的是二维数据,将其还原为三维,维度是[batch_size, time_step, input_size]\n", 149 | "x_image = tf.reshape(x, [-1, time_step, input_size]) \n", 150 | "y = tf.placeholder(tf.int32, [None, classes]) " 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## 定义RNN(LSTM)结构" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 7, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=num_units) \n", 167 | "outputs,final_state = tf.nn.dynamic_rnn(\n", 168 | " cell=rnn_cell, # 选择传入的cell\n", 169 | " inputs=x_image, # 传入的数据\n", 170 | " initial_state=None, # 初始状态\n", 171 | " dtype=tf.float32, # 数据类型\n", 172 | " time_major=False, # False: (batch, time_step, input); True: (time_step, batch, input),这里根据x_image结构选择False\n", 173 | ")\n", 174 | "output = tf.layers.dense(inputs=outputs[:, -1, :], units=classes) " 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "## 定义损失函数与优化算法" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 8, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=output) # 计算loss\n", 191 | "train_step = tf.train.AdamOptimizer(lr).minimize(cross_entropy) #选择优化方法" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 9, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "#判断预测标签和实际标签是否匹配\n", 201 | "correct_prediction = tf.equal(tf.argmax(y, 
axis=1),tf.argmax(output, axis=1))\n", 202 | "accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float')) #计算正确率" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "## 训练并验证准确率" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 10, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "train accuracy 0.560\n", 222 | "train accuracy 0.740\n", 223 | "train accuracy 0.900\n", 224 | "train accuracy 0.810\n", 225 | "train accuracy 0.940\n", 226 | "train accuracy 0.890\n", 227 | "train accuracy 0.880\n", 228 | "train accuracy 0.910\n", 229 | "train accuracy 0.910\n", 230 | "train accuracy 0.930\n", 231 | "train accuracy 0.870\n", 232 | "train accuracy 0.960\n", 233 | "train accuracy 0.960\n", 234 | "train accuracy 0.940\n", 235 | "train accuracy 0.930\n", 236 | "train accuracy 0.960\n", 237 | "train accuracy 0.940\n", 238 | "train accuracy 0.980\n", 239 | "train accuracy 0.920\n", 240 | "train accuracy 0.980\n", 241 | "test accuracy 0.958\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "sess = tf.Session()\n", 247 | "init = tf.global_variables_initializer()\n", 248 | "sess.run(init)\n", 249 | "\n", 250 | "for i in range(iterations):\n", 251 | " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", 252 | " sess.run(train_step, feed_dict={x: batch_x, y: batch_y})\n", 253 | " if (i+1) % 50 == 0:\n", 254 | " print(\"train accuracy %.3f\" % accuracy.eval(session = sess,\n", 255 | " feed_dict = {x:batch_x, y:batch_y}))\n", 256 | "print(\"test accuracy %.3f\" % accuracy.eval(session = sess,\n", 257 | " feed_dict = {x:mnist.test.images, y:mnist.test.labels}))" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3", 271 | "language": 
"python", 272 | "name": "python3" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": "python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.5.6" 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 2 289 | } 290 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 《深度学习入门:基于PyTorch和TensorFlow的理论与实战》 2 | 3 | **作者:红色石头** 4 | 5 | **出版社:清华大学出版社** 6 | 7 | 19世纪70年代,电力的发明和应用掀起了第二次工业化高潮,从此改变了人们的生活方式,大大提高了人类的科技水平。现如今,深度学习(Deep Learning)技术也正在发挥同样的作用。纵观近几年,深度学习发展非常迅速,发展势头一直高歌猛进。毫无疑问,深度学习技术正在影响着我们的日常生活和行为方式。 8 | 9 | 10 | ## 深度学习怎么学 11 | 12 | 深度学习怎么学?事实上,很多初学者常常会误入两大误区:一是找不到一本真正适合自己的教材或书本来学习,陷入到海量资源中手足无措;二是受制于数学理论知识,自认为数学基础不好而影响学习的主观性和积极性。 13 | 14 | 这两大误区很容易让初学者陷入到一种迷茫的状态。所以,第一步就是要放弃海量资源,选择一份真正适合自己的资料,好好研读下去!第二步就是重视实践。深度学习涉及的理论知识很多,有些人可能基础不是特别扎实,就想着从最底层的知识开始学起,概率论、线性代数、凸优化理论等。但是这样做比较耗时间,而且容易打消学习的积极性。啃书本和推导公式相对来说是比较枯燥的,远不如自己搭建一个简单的神经网络模型更能激发学习积极性。当然,基础理论知识非常重要!只是说,在入门的时候,最好先从顶层框架上有个系统的认识,然后再从实践到理论,有的放矢地查缺补漏深度学习知识点。从宏观到微观,从整体到细节,更有利于快速入门!
15 | 16 | ## 为什么写这本书 17 | 18 | 在学习深度学习的几年时间里,我学过一些国内外优秀的深度学习公开课程,口碑都很好;我也看过不少大牛老师写的高质量书籍,收获颇丰;我也在学习的过程中走过一些弯路,遇到一些坑,这些也是宝贵的经验。 19 | 20 | 我个人觉得,任何前沿技术,如深度学习,扎实的基础知识非常重要。而最好的基础知识的获取方式还是教材和书本。但是,反观现在一些深度学习方面的书籍,或多或少存在一些问题: 21 | 22 | 1、数学理论太多,公式多,起点高,对初学者不友善,容易削弱入门学习的积极性。 23 | 24 | 2、只讲深度学习框架,教你如何调包、调用库函数,不讲深度学习理论知识。容易造成一知半解,沦为“调包侠”。 25 | 26 | 3、理论与实战脱节,过于侧重理论或者过于侧重实战,二者之间没有很好的融合。 27 | 28 | 基于以上这些问题,我认为写一本真正适合深度学习初学者的入门书籍非常必要。这样的书籍不仅要兼顾理论和实战,还应该将重难点知识通俗化讲解、全面细致。难度有阶梯性,照顾不同水平的读者。这样的书籍才能最大程度地让读者受益。 29 | 30 | 基于这样的考量,《深度学习入门:基于PyTorch和TensorFlow的理论与实战》跟大家见面了。 31 | 32 | ## 本书特色 33 | 34 | 我刚刚也说了,对于初学者而言,一本好的深度学习书籍,宗旨就是让读者能够轻轻松松地掌握知识、触类旁通。本书作为一本深度学习的入门书籍,对初学者是非常友好的。本书的内容来自于我多年的知识积累和技术沉淀,也是我的一份深度学习经验总结。 35 | 36 | 首先,这本书包含了Python的基本介绍。Python作为人工智能的首选语言,其重要性不言而喻。Python入门非常简单,本书将会对深度学习中所需的基本Python语法知识进行简明扼要的提炼和概括。如果有的读者之前没有接触过Python,那么本书将轻松带你入门。 37 | 38 | 其次,这本书介绍了如今主流的深度学习框架PyTorch和TensorFlow。通过本书,读者可以对这两个框架的基本语法和基础知识有一个系统的学习,夯实基础。如果你之前对PyTorch和TensorFlow不了解也没有关系,这本书也可以是这两个框架的知识学习手册。 39 | 40 | 然而,最重要的,这是一本关于深度学习的入门教程。我在编写该书的时候,从小白的视角出发,结合我多年的知识和经验,尽量将深度学习、神经网络的理论知识用通俗易懂的语言描绘出来。这本书能让读者真正了解、熟悉神经网络的结构和优化方法,也能帮助读者梳理一些容易被忽视的技术细节。例如最简单的梯度下降算法,它的公式来源和理论支持是什么?本书都会有详细的解释。 41 | 42 | 值得注意的是,我一贯坚持的是将复杂的理论简单化,这本书会将理论以平易化的语言描述清楚,但不会深陷于数学公式之中。这本书面向的是深度学习的入门者和初学者,不会涉及太多太复杂的理论知识。因为入门深度学习,前期整体上的感性认识尤为重要。轻松入门,往往是比较正确的学习路线。我在编写该书的时候,也一直在把握这个分寸和尺度。如果想要学习更深层次、更高级的深度学习知识,读者可以查阅更多的书籍、会议论文、前沿技术等。 43 | 除此之外,深度学习更重要的是代码实践,这也是本书一直秉承的一个重点。这本书的另一个优势就是不仅讲理论知识,也配备了完整的实战项目和代码。从简单的逻辑回归,到浅层神经网络、深层神经网络,再到CNN、RNN,都会通过一个实际项目从零搭建神经网络,或者使用PyTorch、TensorFlow来构建更复杂的例如CNN、RNN模型解决问题。 44 | 45 | 本书的所有代码,我都开源放在了GitHub上,地址如下: 46 | 47 | [https://github.com/RedstoneWill/dl-from-scratch](https://github.com/RedstoneWill/dl-from-scratch) 48 | 49 | ## 面向的读者 50 | 51 | 这是一本深度学习的入门书籍,也是一本关于Python、PyTorch、TensorFlow的工具手册;这是一本深度学习的理论书籍,也是一本教你如何编写代码构建神经网络的实战手册。我希望这本书能够帮助更多想要入门深度学习的爱好者,能够帮助读者扫清学习过程中的障碍,再上新台阶。 52 | 53 | 本书面向的读者包括:深度学习初学者; 对深度学习感兴趣的在校大学生; 有意向转行AI领域的IT从业人员。当然,这本书也是不错的深度学习工具手册,里面不仅有理论知识,也有示例代码。 54 | 55 |
值得一提的是,如果你已经有很高的深度学习水平了,那么可能本书不太适合你,你应该更关注深度学习的前沿理论和论文。 56 | 57 | ## 关于作者 58 | 59 | 红色石头,北京大学硕士。专注AI领域多年,爱好写作,累计写过的AI领域的原创文章超过150篇,累计读者达20万,文章阅读量超100万。写作风格是擅长以通俗化的语言来解释机器学习、深度学习的算法理论和技术细节。 60 | 61 | 个人网站:[www.redstonewill.com](http://www.redstonewill.com) 62 | 63 | 创办了AI技术领域的微信公众号:AI有道(ID:redstonewill),欢迎读者关注,方便第一时间获取机器学习、深度学习等有价值的干货分享和信息资源。 64 | 65 | ![](./微信公众号:AI有道.jpg) 66 | 67 | ## 书籍目录 68 | 69 | **第1章 深度学习基础** 1 70 | 71 | - 1.1 深度学习概述 1 72 | 73 | - 1.1.1 什么是深度学习 1 74 | 75 | - 1.1.2 深度学习的应用场景 3 76 | 77 | - 1.1.3 深度学习的发展动力 4 78 | 79 | - 1.1.4 深度学习的未来 5 80 | 81 | - 1.2 Python入门 6 82 | 83 | - 1.2.1 Python简介 6 84 | 85 | - 1.2.2 Python的安装 7 86 | 87 | - 1.2.3 Python基础知识 8 88 | 89 | - 1.2.4 NumPy矩阵运算 15 90 | 91 | - 1.2.5 Matplotlib绘图 20 92 | 93 | - 1.3 Anaconda与Jupyter Notebook 24 94 | 95 | - 1.3.1 Anaconda 25 96 | 97 | - 1.3.2 Jupyter Notebook 27 98 | 99 | **第2章 PyTorch** 34 100 | 101 | - 2.1 PyTorch简介 34 102 | 103 | - 2.1.1 什么是PyTorch 34 104 | 105 | - 2.1.2 为什么使用PyTorch 35 106 | 107 | - 2.2 PyTorch安装 36 108 | 109 | - 2.3 张量Tensor 39 110 | 111 | - 2.3.1 创建Tensor 39 112 | 113 | - 2.3.2 Tensor的数学运算 40 114 | 115 | - 2.3.3 Tensor与NumPy 41 116 | 117 | - 2.3.4 CUDA Tensor 42 118 | 119 | - 2.4 自动求导 autograd 43 120 | 121 | - 2.4.1 返回值是标量 43 122 | 123 | - 2.4.2 返回值是张量 44 124 | 125 | - 2.4.3 禁止自动求导 45 126 | 127 | - 2.5 神经网络包nn和优化器optim 45 128 | 129 | - 2.5.1 torch.nn 45 130 | 131 | - 2.5.2 torch.optim 46 132 | 133 | - 2.6 PyTorch线性回归 47 134 | 135 | - 2.6.1 线性回归基本原理 48 136 | 137 | - 2.6.2 PyTorch实现 49 138 | 139 | **第3章 TensorFlow** 53 140 | 141 | - 3.1 TensorFlow简介 53 142 | 143 | - 3.1.1 什么是TensorFlow 53 144 | 145 | - 3.1.2 为什么使用TensorFlow 54 146 | 147 | - 3.2 TensorFlow安装 54 148 | 149 | - 3.3 张量Tensor 56 150 | 151 | - 3.3.1 创建Tensor 56 152 | 153 | - 3.3.2 Tensor的数学运算 57 154 | 155 | - 3.4 数据流图 58 156 | 157 | - 3.5 会话Session 60 158 | 159 | - 3.6 TensorFlow线性回归 62 160 | 161 | - 3.7 TensorBoard 66 162 | 163 | - 3.7.1 TensorBoard代码 66 164 | 165 | - 3.7.2 TensorBoard显示 67 166 | 167 |
**第4章 神经网络基础** 71 168 | 169 | - 4.1 感知机 71 170 | 171 | - 4.1.1 感知机模型 71 172 | 173 | - 4.1.2 感知机与逻辑电路 72 174 | 175 | - 4.2 多层感知机 77 176 | 177 | - 4.2.1 感知机的局限性 78 178 | 179 | - 4.2.2 多层感知机实现异或 79 180 | 181 | - 4.3 逻辑回归 81 182 | 183 | - 4.3.1 基本原理 82 184 | 185 | - 4.3.2 损失函数 84 186 | 187 | - 4.3.3 梯度下降 87 188 | 189 | - 4.3.4 逻辑回归的Python实现 92 190 | 191 | **第5章 神经网络** 98 192 | 193 | - 5.1 基本结构 98 194 | 195 | - 5.2 前向传播 100 196 | 197 | - 5.3 激活函数 101 198 | 199 | - 5.4 反向传播 106 200 | 201 | - 5.5 更新参数 108 202 | 203 | - 5.6 初始化 108 204 | 205 | - 5.7 神经网络的Python实现 109 206 | 207 | - 5.7.1 准备数据 109 208 | 209 | - 5.7.2 参数初始化 110 210 | 211 | - 5.7.3 前向传播 111 212 | 213 | - 5.7.4 交叉熵损失 112 214 | 215 | - 5.7.5 反向传播 113 216 | 217 | - 5.7.6 更新参数 114 218 | 219 | - 5.7.7 构建整个神经网络模型 115 220 | 221 | - 5.7.8 训练 116 222 | 223 | - 5.7.9 预测 116 224 | 225 | **第6章 深层神经网络** 119 226 | 227 | - 6.1 神经网络为什么要深 119 228 | 229 | - 6.2 符号标记 121 230 | 231 | - 6.3 前向传播与反向传播 122 232 | 233 | - 6.4 多分类Softmax 125 234 | 235 | - 6.4.1 Softmax基本原理 126 236 | 237 | - 6.4.2 Softmax损失函数 127 238 | 239 | - 6.4.3 Softmax求导 128 240 | 241 | - 6.5 深层神经网络的Python实现 130 242 | 243 | - 6.5.1 准备数据 130 244 | 245 | - 6.5.2 参数初始化 133 246 | 247 | - 6.5.3 前向传播 134 248 | 249 | - 6.5.4 交叉熵损失 137 250 | 251 | - 6.5.5 反向传播 137 252 | 253 | - 6.5.6 更新参数 140 254 | 255 | - 6.5.7 构建整个神经网络 141 256 | 257 | - 6.5.8 训练与预测 143 258 | 259 | **第7章 优化神经网络** 146 260 | 261 | - 7.1 正则化 146 262 | 263 | - 7.1.1 什么是过拟合 146 264 | 265 | - 7.1.2 L1和L2正则化 149 266 | 267 | - 7.1.3 Dropout正则化 153 268 | 269 | - 7.1.4 其它正则化技巧 157 270 | 271 | - 7.2 梯度优化 159 272 | 273 | - 7.2.1 BGD、SGD、MBGD 159 274 | 275 | - 7.2.2 Momentum GD 163 276 | 277 | - 7.2.3 Nesterov Momentum 165 278 | 279 | - 7.2.4 AdaGrad 166 280 | 281 | - 7.2.5 RMSprop 167 282 | 283 | - 7.2.6 Adam 168 284 | 285 | - 7.2.7 Learning Rate Decay 169 286 | 287 | - 7.3 网络初始化与超参数调试 170 288 | 289 | - 7.3.1 输入标准化 171 290 | 291 | - 7.3.2 权重参数初始化 173 292 | 293 | - 7.3.3 批归一化 176 294 | 295 | - 7.3.4 超参数调试 179 296 | 297
| - 7.4 模型评估与调试 183 298 | 299 | - 7.4.1 模型评估 183 300 | 301 | - 7.4.2 训练/验证/测试集 184 302 | 303 | - 7.4.3 偏差与方差 187 304 | 305 | - 7.4.4 错误分析 187 306 | 307 | **第8章 卷积神经网络** 192 308 | 309 | - 8.1 为什么选择CNN 192 310 | 311 | - 8.2 CNN基本结构 193 312 | 313 | - 8.3 卷积层 194 314 | 315 | - 8.3.1 卷积 194 316 | 317 | - 8.3.2 边缘检测 196 318 | 319 | - 8.3.3 填充Padding 198 320 | 321 | - 8.3.4 步幅Stride 199 322 | 323 | - 8.3.5 CNN卷积 200 324 | 325 | - 8.3.6 卷积层的作用 205 326 | 327 | - 8.4 池化层 205 328 | 329 | - 8.5 全连接层 208 330 | 331 | - 8.6 CNN模型 210 332 | 333 | - 8.7 典型的CNN模型 213 334 | 335 | - 8.7.1 LeNet-5 213 336 | 337 | - 8.7.2 AlexNet 214 338 | 339 | - 8.8 CNN的PyTorch实现 215 340 | 341 | - 8.8.1 准备数据 215 342 | 343 | - 8.8.2 定义CNN模型 219 344 | 345 | - 8.8.3 损失函数与梯度优化 221 346 | 347 | - 8.8.4 训练模型 221 348 | 349 | - 8.8.5 测试模型 223 350 | 351 | - 8.9 CNN的TensorFlow实现 224 352 | 353 | - 8.9.1 准备数据 224 354 | 355 | - 8.9.2 定义CNN模型 225 356 | 357 | - 8.9.3 损失函数与优化算法 227 358 | 359 | - 8.9.4 训练并测试 228 360 | 361 | **第9章 循环神经网络** 229 362 | 363 | - 9.1 为什么选择RNN 229 364 | 365 | - 9.2 RNN基本结构 230 366 | 367 | - 9.3 模型参数 232 368 | 369 | - 9.4 梯度消失 234 370 | 371 | - 9.5 GRU 234 372 | 373 | - 9.6 LSTM 236 374 | 375 | - 9.7 多种RNN模型 237 376 | 377 | - 9.8 RNN的PyTorch实现 241 378 | 379 | - 9.8.1 准备数据 242 380 | 381 | - 9.8.2 定义RNN模型 244 382 | 383 | - 9.8.3 损失函数与梯度优化 246 384 | 385 | - 9.8.4 训练模型 246 386 | 387 | - 9.8.5 测试模型 247 388 | 389 | - 9.9 RNN的TensorFlow实现 248 390 | 391 | - 9.9.1 准备数据 249 392 | 393 | - 9.9.2 定义RNN模型 249 394 | 395 | - 9.9.3 损失函数与优化算法 250 396 | 397 | - 9.9.4 训练并测试 251 398 | 399 | **后记** 252 400 | 401 | 402 | 403 | ## 源代码目录 404 | 405 | - ## [01_the_foundation_of_deep_learning.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/01_the_foundation_of_deep_learning.ipynb) 406 | 407 | - ## [02_pytorch_tutorial.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/02_pytorch_tutorial.ipynb) 408 | 409 | - ## 
[03_tensorflow_tutorial.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/03_tensorflow_tutorial.ipynb) 410 | 411 | - ## [04_neural_network_foundation.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/04_neural_network_foundation.ipynb) 412 | 413 | - ## [05_neural_network.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/05_neural_network.ipynb) 414 | 415 | - ## [06_deep_neural_network.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/06_deep_neural_network.ipynb) 416 | 417 | - ## [08_cnn_pytorch.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/08_cnn_pytorch.ipynb) 418 | 419 | - ## [08_cnn_tensorflow.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/08_cnn_tensorflow.ipynb) 420 | 421 | - ## [09_rnn_pytorch.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/09_rnn_pytorch.ipynb) 422 | 423 | - ## [09_rnn_tensorflow.ipynb](https://github.com/RedstoneWill/dl-from-scratch/blob/master/09_rnn_tensorflow.ipynb) -------------------------------------------------------------------------------- /datasets/ch01/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch01/cat.jpg -------------------------------------------------------------------------------- /datasets/ch03/3-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch03/3-3.png -------------------------------------------------------------------------------- /datasets/ch06/test.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch06/test.rar 
-------------------------------------------------------------------------------- /datasets/ch06/train.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch06/train.rar -------------------------------------------------------------------------------- /datasets/ch08/pytorch/MNIST/processed/test.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/pytorch/MNIST/processed/test.pt -------------------------------------------------------------------------------- /datasets/ch08/pytorch/MNIST/processed/training.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/pytorch/MNIST/processed/training.pt -------------------------------------------------------------------------------- /datasets/ch08/pytorch/MNIST/raw/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/pytorch/MNIST/raw/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /datasets/ch08/pytorch/MNIST/raw/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 |datasets/ch08/pytorch/MNIST/raw/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/pytorch/MNIST/raw/train-images-idx3-ubyte 
-------------------------------------------------------------------------------- /datasets/ch08/pytorch/MNIST/raw/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/pytorch/MNIST/raw/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /datasets/ch08/tensorflow/MNIST/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/tensorflow/MNIST/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /datasets/ch08/tensorflow/MNIST/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/tensorflow/MNIST/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /datasets/ch08/tensorflow/MNIST/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/tensorflow/MNIST/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /datasets/ch08/tensorflow/MNIST/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/datasets/ch08/tensorflow/MNIST/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- 
/datasets/readme.md: -------------------------------------------------------------------------------- 1 | # Datasets -------------------------------------------------------------------------------- /tf_logs/ch03/run-20190704075350/events.out.tfevents.1562226830.AB-201810292038: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/tf_logs/ch03/run-20190704075350/events.out.tfevents.1562226830.AB-201810292038 -------------------------------------------------------------------------------- /tf_logs/readme.md: -------------------------------------------------------------------------------- 1 | # TensorBoard -------------------------------------------------------------------------------- /微信公众号:AI有道.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedstoneWill/dl-from-scratch/34c87663e5ec30384129c528031c3e9085877418/微信公众号:AI有道.jpg --------------------------------------------------------------------------------