├── README.md
├── 《繁凡的深度学习笔记》前言、目录大纲
└── 《繁凡的深度学习笔记》前言、目录大纲.pdf
├── 深度学习经典书籍PDF版免费下载.md
├── 第 01 章 深度学习综述
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 02 章 回归问题与神经元模型
├── PDF文件(已更新)
│ └── 第 2 章 回归问题与神经元模型.pdf
└── 全部源码(已更新)
│ ├── Jupyter Notebook
│ └── 《繁凡的深度学习笔记》第 2 章 回归问题与神经元模型 2.2.3 神经元线性模型实战(1).ipynb
│ └── PyCharm
│ └── 《繁凡的深度学习笔记》第 2 章 回归问题与神经元模型 2.2.3 神经元线性模型实战.py
├── 第 03 章 分类问题与信息论基础
├── PDF文件(已更新)
│ └── 《繁凡的深度学习笔记》第 3 章 分类问题与信息论基础(DL笔记整理系列).pdf
└── 全部源码(已更新)
│ ├── PyTorch实现
│ ├── Jupyter Notebook
│ │ ├── (PyTorch)手写数字识别实战.ipynb
│ │ └── utils.py
│ └── PyCharm
│ │ ├── utils.py
│ │ └── 手写数字识别-pytorch.py
│ └── TensorFlow2.0实现
│ ├── Jupyter Notebook
│ ├── MNIST数据集的前向传播训练误差曲线.png
│ └── 第 03 章 TensorFlow2.0 手写数字图片实战 - 神经网络.ipynb
│ └── PyCharm
│ ├── MNIST数据集的前向传播训练误差曲线.png
│ └── 第 03 章 TensorFlow2.0 手写数字图片实战 - 神经网络.py
├── 第 04 章 TensorFlow2.0从入门到升天
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 05 章 PyTorch 从入门到升天
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 06 章 神经网络与反向传播算法
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 07 章 过拟合、优化算法与参数优化
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 08 章 卷积神经网络 (CNN) 从入门到升天
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 09 章 循环神经网络 (RNN) 从入门到升天
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 10 章 注意力机制与Transformer
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 11 章 图神经网络(万字综述)
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 12 章 自编码器(万字综述)
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 13 章 生成对抗网络(万字综述)
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 14 章 强化学习(万字综述)
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 15 章 元学习(万字综述)
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
├── 第 16 章 对抗攻击与防御(万字综述)
├── PDF文件(待更)
│ └── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
│ └── 正文很快更新哟^q^.pdf
└── 第 17 章 迁移学习(万字综述)
├── PDF文件(待更)
└── 正文很快更新哟^q^.pdf
└── 全部源码(待更)
└── 正文很快更新哟^q^.pdf
/README.md:
--------------------------------------------------------------------------------
1 |
这里是《繁凡的深度学习笔记》官方代码、PDF文件仓库如果觉得还不错的话,欢迎 ⭐ Starred !谢谢^q^
2 |
3 | 《繁凡的深度学习笔记》前言、目录大纲 (DL笔记整理系列)
4 |
5 | 一文弄懂深度学习所有基础 !
6 |
7 | 3043331995@qq.com
8 |
9 | https://fanfansann.blog.csdn.net/
10 |
11 | https://github.com/fanfansann/fanfan-deep-learning-note
12 |
13 | 作者:繁凡
14 |
15 | version 1.0 2022-1-20
16 |
17 |
18 |
19 | **声明:**
20 |
21 | 1)《繁凡的深度学习笔记》是我自学完成深度学习相关的教材、课程、论文、项目实战等内容之后,自我总结整理创作的学习笔记。写文章就图一乐,大家能看得开心,能学到些许知识,对我而言就已经足够了 \^q\^ 。这是我写的第二本书! (第一本: [《算法竞赛中的初等数论》](https://blog.csdn.net/weixin_45697774/article/details/113765056))现在竞赛退役的我更希望先沉淀,再输出,所以更新速度相较以前会稍慢一些,但相信质量也会更高,请大家见谅。
22 |
23 | 2)因个人时间、能力和水平有限,本文并非由我个人完全原创,文章部分内容整理自互联网上的各种资源,引用内容标注在每章末的参考资料之中。
24 |
25 | 3)本文仅供学术交流,非商用。所以每一部分具体的参考资料并没有详细对应。如果某部分不小心侵犯了大家的利益,还望海涵,并联系博主删除,非常感谢各位为知识传播做出的贡献!
26 |
27 | 4)本人才疏学浅,整理总结的时候难免出错,还望各位前辈不吝指正,谢谢。
28 |
29 | 5)本文由我个人( CSDN 博主 「繁凡さん」(博客) , 知乎答主 「繁凡」(专栏), Github 「fanfansann」(全部源码) , 微信公众号 「繁凡的小岛来信」(文章 P D F 下载))整理创作而成,且仅发布于这四个平台,仅做交流学习使用,无任何商业用途。
30 |
31 | 6)「我希望能够创作出一本清晰易懂、可爱有趣、内容详实的深度学习笔记,而不仅仅只是知识的简单堆砌。」
32 |
33 | 7)本文《繁凡的深度学习笔记》全汇总链接:[《繁凡的深度学习笔记》前言、目录大纲](https://fanfansann.blog.csdn.net/article/details/121702108) [https://fanfansann.blog.csdn.net/article/details/121702108](https://fanfansann.blog.csdn.net/article/details/121702108)
34 |
35 | 8)本文的Github 地址:[https://github.com/fanfansann/fanfan-deep-learning-note/](https://github.com/fanfansann/fanfan-deep-learning-note/) 孩子的第一个 [『Github』](https://github.com/fanfansann/fanfan-deep-learning-note/)!给我个 ⭐ Starred 嘛!谢谢!!o(〃^▽^〃)o
36 |
37 | 9)此属 version 1.0 ,若有错误,还需继续修正与增删,还望大家多多指点。本文会随着我的深入学习不断地进行完善更新,[Github](https://github.com/fanfansann/fanfan-deep-learning-note/) 中的 P D F 版也会尽量每月进行一次更新,所以建议点赞收藏分享加关注,以便经常过来回看!
38 |
39 |
40 | 10)本文同步于 [CSDN「繁凡さん」](https://blog.csdn.net/weixin_45697774/article/details/121702108)、[知乎「繁凡」](https://zhuanlan.zhihu.com/p/448041547) 与 [微信公众号 「繁凡的小岛来信」](https://mp.weixin.qq.com/s/Emv2OSEZLyawjqfao7WIOA)。
41 |
42 |
43 |
44 | **[更好的阅读体验!](https://mp.weixin.qq.com/s/Emv2OSEZLyawjqfao7WIOA)(我光排版就排了一个多小时!)**
45 |
46 | # 前言
47 |
48 | 这是一本面向深度学习初学者的 **深度学习笔记** 。相信大家在入门深度学习的时候都有一种感觉,各种好评如潮的深度学习课程,各种深入浅出的深度学习书籍,虽然都能很容易理解,但是学完之后会有一种 “我到底学了什么” 的空虚感。特别是在你完成了大量资料的学习,想要复习的时候,需要翻阅大量博客与听课笔记、读书笔记等,很是麻烦。如果大家阅读过花书《深度学习》这类人类圣经,往往会有一种晦涩难懂的感觉。这类书籍尽管比较全面,但是并不适合完全零基础的初学者阅读。此外,深度学习课程与书籍往往会因为篇幅限制而省略掉很多知识讲解,大多内容只是一概而论,并没有深入探讨。而本书以网络博客为载体,就不会有这方面问题的限制。因此本文旨在使用 **简明清晰、通俗易懂** 的语言帮助大家构建起 **全面** 的深度学习 **完整** 知识框架,轻松学懂学会深度学习。本书中对于深度学习的各个研究方向进行了详细讲解,同时各种知识拓展组成了每章的万字长文综述,也使得本文可以作为一本资料书进行使用。书中代码均使用 TensorFlow2.0 以及 Pytorch 双料实现,实用性强。
49 |
50 | 本书暂时共有 18 章,分为四个部分,后期将会慢慢继续拓展,敬请期待。
51 |
52 | **第一部分 第 1 章 ~ 第 3 章 为深度学习基础认知**
53 |
54 | **第 1 章 深度学习综述** 学习一个新的领域最好的入门方法就是阅读一篇综述。本章通过对深度学习几十年来的发展进行简要综述,帮助大家快速对深度学习这一领域建立起一个基础的认知,搭建起大致的知识体系框架。本章综述中讲到的大多数深度学习的研究方向内容在本书中相应章节均有详细讲解,帮助大家扎实深度学习基础。
55 |
56 | **第 2 章 回归问题与神经元模型** 回归问题是机器学习中较早就开始应用的学习模型,多用来预测一个具体的数值。在深度学习中,可以使用大量方法予以解决。作为深度学习入门要解决的第一个问题,本章引入对于深度学习非常重要的神经元模型,并利用神经元模型解决回归问题。最后探讨了非线性模型,并直观地展示了激活函数的作用。
57 |
58 | **第 3 章 分类问题** 回归问题是对真实值的一种逼近预测,而分类问题则是为事物打上标签,得到一个离散的值。回归模型与分类模型在本质上是相同的:分类模型可将回归模型的输出离散化,回归模型也可将分类模型的输出连续化。本章通过实战引入了分类问题,详细讲解了逻辑回归、 softmax 回归以及信息论基础的相关内容,并探讨了逻辑回归与 softmax 回归的关系。
59 |
60 | **第二部分 第 4 章 ~ 第 5 章 为深度学习框架讲解**
61 |
62 | **第 4 章 TensorFlow2.0从入门到升天** TensorFlow 是谷歌开源的一款深度学习框架,首发于 2015 年,采用静态图的TensorFlow1.x 尽管在性能方面较为强劲,但是由于实现以及调试方面的困难一直令人诟病。2019 年谷歌发布了TensorFlow2.0,采用动态图优先模式运行,避免 TensorFlow 1.x 版本的诸多缺陷,获得了业界的广泛认可,在工业界的部署应用最为广泛。本章从零开始讲解TensorFlow的使用、API、部署等全方位的知识,从零基础开始入门 TensorFlow2.0 直至升天。
63 |
64 | **第 5 章 PyTorch 从入门到升天** PyTorch 是 Facebook (Meta) 于 2017 年发布的一款深度学习框架,凭借其简洁优雅的设计、统一易用的接口、追风逐电的速度和变化无方的灵活性受到业界的一致好评。经过数年的发展,在学术界中逐渐占据主导地位。本章将详细讲解 PyTorch 框架的基础知识和使用方法,从零基础开始入门 Pytorch 直至升天。
65 |
66 | **第三部分 第 6 章 ~ 第 7 章 为深度学习基础**
67 |
68 | **第 6 章 神经网络与反向传播算法** 神经网络(Neural Network,NN),在机器学习和认知科学领域,是一种模仿生物神经网络的结构和功能的数学模型或计算模型,用于对函数进行估计或近似。近年来人们应用深层神经网络技术在计算机视觉、自然语言处理、机器人等领域取得了重大突破,部分任务上甚至超越了人类智能水平,引领了以深层神经网络为代表的第三次人工智能复兴。而深层神经网络也有另一个名字:深度学习。反向传播算法(Backpropagation,BP),是一种与最优化方法结合使用的,用来训练人工神经网络的常见方法。本章从感知机模型出发,引入神经网络模型,并介绍了十数种常用激活函数,通过数次实战扎实基础,最后详细地全流程推导了反向传播算法。
69 |
70 | **第 7 章 过拟合、优化算法与参数优化** 在深度学习实战训练中,我们往往会遇到各种问题。本章对这些问题进行总结,并给出一些训练常用技巧,引入过拟合概念并详细介绍了如何避免。通过前面的学习,我们知道深度学习训练中最常用的反向传播算法是一种与最优化方法相结合的算法,因此本章介绍了一些常用的优化算法,最后对参数优化进行了一些探讨。
71 |
72 | **第四部分 第 8 章 ~ 第 17 章 为深度学习研究方向综述**
73 |
74 | **第 8 章 卷积神经网络 (CNN) 从入门到升天** 卷积神经网络(Convolutional Neural Network, CNN),一种前馈神经网络,由若干卷积层和池化层组成,尤其在图像处理方面表现十分出色。卷积神经网络作为一种非常重要的主流深度学习模型,需要大家理解掌握。本章详细讲解了卷积神经网络的原理、实现、变种以及各种应用,从零基础开始入门卷积神经网络直至升天。
75 |
76 | **第 9 章 循环神经网络 (RNN) 从入门到升天** 循环神经网络(Recurrent Neural Network, RNN),一类以序列数据为输入,在序列的演进方向进行递归且所有节点按链式连接的递归神经网络。其特有的循环概念及其最重要的结构 “长短时记忆网络” 使得它在处理和预测序列数据的问题上有着良好的表现。循环神经网络作为一种非常重要的主流深度学习模型,需要大家理解掌握。本章详细讲解了循环神经网络的原理、实现、变种以及各种应用,从零基础开始入门循环神经网络直至升天。
77 |
78 | **第 10 章 注意力机制与Transformer** 注意力机制(Attention Mechanism),人们在机器学习模型中嵌入的一种特殊结构,用来自动学习和计算输入数据对输出数据的贡献大小。Transformer 是一种采用 self-attention 的深度学习模型,对输入数据的每个部分的重要性进行差分加权。在自然语言处理和计算机视觉领域有着非常广泛的应用。Attention Is All You Need!本章对注意力机制与 Transformer 的原理、实现、常用变种与应用进行了清晰易懂的综述详解,适合零基础入门研究学习。
79 |
80 | **第 11 章 图神经网络(万字综述)** 图神经网络(Graph Neural Networks,GNN),一种基于图结构的深度学习方法,从其定义中可以看出图神经网络主要由两部分组成,即图论中的图数据结构与深度学习中的神经网络(正巧我大学在 ACM 竞赛中就专门研究图论 x )。GNN 在处理非结构化数据时的出色能力使其在网络数据分析、推荐系统、物理建模、自然语言处理和图上的组合优化问题方面都取得了新的突破,成为各大深度学习顶会的研究热点。本章从图这一数据结构开始讲解,对图神经网络的原理、实现、常用变种与应用进行了清晰易懂的综述详解,适合零基础入门研究学习。
81 |
82 |
83 | **第 12 章 自编码器(万字综述)** 自编码器(autoencoder, AE),一类在半监督学习和非监督学习中使用的人工神经网络,其功能是通过将输入信息作为学习目标,对输入信息进行表征学习。自编码器具有一般意义上表征学习算法的功能,被应用于降维和异常值检测。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪 、神经风格迁移等。本章对自编码器的原理、实现、常用变种与应用进行了清晰易懂的综述详解,适合零基础入门研究学习。
84 |
85 |
86 | **第 13 章 生成对抗网络(万字综述)** 生成对抗网络(Generative Adversarial Network,GAN),一种非监督式学习的方法,通过让两个神经网络相互博弈的方式进行学习,是近年来复杂分布上无监督学习最具前景的方法之一。GAN 在图像生成,如超分辨率任务,语义分割等方面有着非常出色的表现。本章从动漫头像生成实战入手,对生成对抗网络的训练公式以及纳什均衡进行了详细的推导证明,分析了GAN的训练难题,并给出了相应的解决办法,然后对 WGAN 的原理和实现进行了详细讲解,最后探讨了 GAN 的应用,适合零基础入门研究学习。
87 |
88 | **第 14 章 强化学习(万字综述)** 强化学习(Reinforcement learning,RL),机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益。强化学习是除了监督学习和非监督学习之外的第三种基本的机器学习方法。深度强化学习(Deep Reinforcement Learning,DRL)是深度学习与强化学习相结合的产物,它集成了深度学习在视觉等感知问题上强大的理解能力,以及强化学习的决策能力,实现了端到端学习。深度强化学习的出现使得强化学习技术真正走向实用,得以解决现实场景中的复杂问题。本章对强化学习的原理、实现、常用变种与应用进行了清晰易懂的综述,适合零基础入门研究学习。
89 |
90 | **第 15 章 元学习(万字综述)** 元学习(Meta Learing),一种全新的机器学习方法,尝试学习如何学习。元学习的诞生可以追溯到上世纪八十年代,随着深度学习逐渐火热,元学习也回到了大众的视野当中。元学习与强化学习的结合更是借着深度学习的大潮,在各个领域扩展到了极致(例如人脸识别领域等,均可用元学习来加以强化 cross domain 的性能)。本章综述首先给出元学习相关的术语,并尝试对元学习进行一个定义,给出了一个简单示例去帮助理解元学习这一概念。按照实现方法将元学习分为了三种,并对每种方法的经典算法进行了详细的剖析。最后探讨了元学习的一些应用以及未来展望。
91 |
92 | **第 16 章 对抗攻击与防御(万字综述)** 对抗攻击(Adversarial Attack),对目标机器学习模型的原输入施加轻微扰动以生成对抗样本(Adversarial Example)来欺骗目标模型的过程。在深度学习算法驱动的数据计算时代,确保算法的安全性和健壮性至关重要。研究者发现深度学习算法无法有效地处理对抗样本。这些伪造的样本对人类的判断没有太大影响,但会使深度学习模型输出意想不到的结果。自 2013 年发现这一现象以来,它引起了人工智能多个子领域研究人员的极大关注,也是我目前研究的方向。本章综述首先讲解对抗攻击的概念以及如何简单地实现对抗攻击。然后给出对抗攻击相关的术语及其定义,对第一代最简单的对抗攻击的原理及实现进行详细介绍,然后分类探讨不同的对抗攻击,接着对对抗攻击的防御进行讲解。最后尝试对对抗攻击的理论原理以及机器学习的可解释性进行简单讨论。
93 |
94 | **第 17 章 迁移学习** 迁移学习(Transfer Learning),一种机器学习方法,把一个领域(即源领域)的知识,迁移到另外一个领域(即目标领域),使得目标领域能够取得更好的学习效果。在许多机器学习和数据挖掘算法中,一个重要的假设就是目前的训练数据和将来的训练数据,一定要在相同的特征空间并且具有相同的分布。然而,在许多现实的应用案例中,这个假设可能不会成立。比如,我们有时候在某个感兴趣的领域有分类任务,但是我们只有另一个感兴趣领域的足够训练数据,并且后者的数据可能处于与之前领域不同的特征空间或者遵循不同的数据分布。这类情况下,如果知识的迁移做的成功,我们将会通过避免花费大量昂贵的标记样本数据的代价,使得学习性能取得显著的提升。近年来,为了解决这类问题,迁移学习作为一个新的学习框架出现在人们面前。本章对迁移学习的原理、实现、常用变种与应用进行了清晰易懂的综述详解,适合零基础入门研究学习。
95 |
96 | 深度学习作为国家人工智能战略发展的重要一环,是一个非常前沿且广袤的研究领域。而自古以来,我国知识分子素有 “为天地立心,为生民立命,为往圣继绝学,为万世开太平” 的志向和传统。作为新一代青年,自当立时代之潮头。笔者希望通过这份学习笔记,分享和传播更多的知识。如能帮助各位读者搭建深度学习知识体系,在自己的研究领域发光发热,那便是笔者最大的荣幸。
97 |
98 | 繁凡
99 | 2021 年 12 月 20 日
100 | fanfansann.blog.csdn.net
101 | zhihu.com/people/fanfansann
102 |
103 |
104 | ---
105 |
106 |
107 | **请注意:这里仅是《繁凡的深度学习笔记》的前言、目录大纲,文章的正文还未正式发布。如果对本文感兴趣想让我更新的话,可以给在点赞后移步评论区文明催更,我将视情况 ~~(看心情 )~~ 进行更新 o(〃^▽^〃)o 同时也请给本文的 [Github项目](https://github.com/fanfansann/fanfan-deep-learning-note/) 点一个 ⭐ Starred ,这样我才有更新下去的动力!更新将以每章一篇文章为单位进行发布,届时将会把文章链接更新至本文目录中,点击目录对应章节标题即可进行跳转。所以也请给我一个关注,防止错过更新 \^q\^**
108 |
109 |
110 |
113 |
114 | 谢谢!!!
118 |
119 |
120 |
121 | **请注意:这里仅是《繁凡的深度学习笔记》的目录大纲,文章的正文还未正式发布。如果对本文感兴趣想让我更新的话,可以给在点赞后移步评论区文明催更,我将视情况 ~~(看心情 )~~ 进行更新 o(〃^▽^〃)o 同时也请给本文的 [Github项目](https://github.com/fanfansann/fanfan-deep-learning-note/) 点一个 ⭐ Starred ,这样我才有更新下去的动力!更新将以每章一篇文章为单位发布,届时将在将文章链接更新至本文目录中,点击目录对应章节标题即可进行跳转。所以也请给我一个关注,防止错过更新 \^q\^**
122 |
123 |
124 |
125 |
126 |
127 |
128 |
131 |
132 | 谢谢!
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
/《繁凡的深度学习笔记》前言、目录大纲/《繁凡的深度学习笔记》前言、目录大纲.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/《繁凡的深度学习笔记》前言、目录大纲/《繁凡的深度学习笔记》前言、目录大纲.pdf
--------------------------------------------------------------------------------
/深度学习经典书籍PDF版免费下载.md:
--------------------------------------------------------------------------------
1 | **目录**
2 |
3 | [《统计学习方法》李航](#jump1)
4 |
5 | [《机器学习》周志华](#jump2)
6 |
7 | [《深度学习》花书](#jump3)
8 |
9 | [《信息论基础》](#jump4)
10 |
11 | [《繁凡的深度学习笔记》](#jump5)
12 |
13 | ---
14 |
15 |
16 |
17 | ## 1. 《统计学习方法》李航 第二版
18 |
19 | 《统计学习方法》第二版 PDF + 课件
20 |
21 | 在我的公众号:**繁凡的小岛来信** 中
22 |
23 | 回复关键字:**统计学习方法**
24 |
25 | 即可免费获得 **《统计学习方法》第二版 PDF + 课件文件百度网盘下载链接**
26 |
27 | ## 2. 《机器学习》周志华
28 |
29 | 《机器学习》周志华
30 |
31 | 在我的公众号:**繁凡的小岛来信** 中
32 |
33 | 回复关键字:**西瓜书**
34 |
35 | 即可免费获得 **《机器学习》周志华PDF文件百度网盘下载链接**
36 |
37 | ## 3. 《深度学习》花书
38 |
39 | 《深度学习》花书 中文版
40 |
41 | 在我的公众号:**繁凡的小岛来信** 中
42 |
43 | 回复关键字:**花书**
44 |
45 | 即可免费获得 **《深度学习》花书中文版 PDF 版本文件百度网盘下载链接**
46 |
47 | ## 4.《信息论基础》
48 |
49 | 《信息论基础》
50 |
51 | 在我的公众号:**繁凡的小岛来信** 中
52 |
53 | 回复关键字:**信息论**
54 |
55 | 即可免费获得 **信息论经典之作《Elements of Information Theory》(Thomas M. Cover)及其中译本《信息论基础》的 PDF 版本文件百度网盘下载链接**
56 |
57 |
58 |
59 | ## 5. 《繁凡的深度学习笔记》
60 |
61 | 《繁凡的深度学习笔记》
62 |
63 | 在我的公众号:**繁凡的小岛来信** 中
64 |
65 | 回复关键字:**深度学习笔记+第 x 章**,例如:深度学习笔记第二章
66 |
67 | 即可免费获得 **已更新的《繁凡的深度学习笔记》的 PDF 版本文件百度网盘下载链接**
68 |
69 |
70 |
71 |
72 | 谢谢!
76 |
77 |
78 |
79 |
80 |
81 |
84 |
85 |
89 |
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/第 01 章 深度学习综述/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 01 章 深度学习综述/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 01 章 深度学习综述/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 01 章 深度学习综述/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 02 章 回归问题与神经元模型/PDF文件(已更新)/第 2 章 回归问题与神经元模型.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 02 章 回归问题与神经元模型/PDF文件(已更新)/第 2 章 回归问题与神经元模型.pdf
--------------------------------------------------------------------------------
/第 02 章 回归问题与神经元模型/全部源码(已更新)/Jupyter Notebook/《繁凡的深度学习笔记》第 2 章 回归问题与神经元模型 2.2.3 神经元线性模型实战(1).ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "4a665c09",
6 | "metadata": {},
7 | "source": [
8 | " 在理解了神经元线性模型的原理以及各种优化算法以后,我们来实战训练单输入神经元线性模型。\n",
9 | "\n",
10 | " 首先我们引入需要的包。"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "498ca646",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "%matplotlib inline\n",
21 | "import numpy as np\n",
22 | "import math\n",
23 | "from matplotlib import pyplot as plt\n",
24 | "# cal y = 1.477x + 0.089 + epsilon,epsilon ~ N(0, 0.01^2)"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "2e390add",
30 | "metadata": {},
31 | "source": [
32 | "**1. 生成数据集**\n",
33 | "\n",
34 | " 我们需要采样自真实模型的多组数据,对于已知真实模型的 **玩具样例** (Toy Example),我们直接从指定的 $w = 1.477 , b = 0.089$ 的真实模型中直接采样:\n",
35 | "$$\n",
36 | "y=1.477 \\times x+0.089\n",
37 | "$$\n",
38 | "\n",
39 | "\n",
40 | "\n",
41 | " 为了能够很好地模拟真实样本的观测误差,我们给模型添加误差自变量 $\\epsilon$ ,它采样自均值为 $0$ ,方差为 $0.01$ 的高斯分布:\n",
42 | "$$\n",
43 | "y=1.477 x+0.089+\\epsilon, \\epsilon \\sim \\mathcal{N}(0,0.01)\n",
44 | "$$\n",
45 | "\n",
46 | " 我们通过随机采样 $n = 100$ 次,我们获得 $n$ 个样本的训练数据集 $\\mathbb D_{\\mathrm{train}}$ ,然后循环进行 $100$ 次采样,每次从均匀分布 $U ( -10,10)$ 中随机采样一个数据 $x$ 同时从均值为 $0$ ,方差为 $0.1^{2}$ 的高斯分布 $\\mathcal{N}\\left(0,0.1^{2}\\right)$ 中随机采样噪声 $\\epsilon$,根据真实模型生成 $y$ 的数据,并保存为 $\\text{Numpy}$ 数组。"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "id": "4bbabdb8",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "def get_data():\n",
57 | " # 计算均方误差\n",
58 | " #保存样本集的列表\n",
59 | " data = [] \n",
60 | " for i in range(100):\n",
61 | " x = np.random.uniform(-10., 10.) # 随机采样 x\n",
62 | " # 高斯噪声\n",
63 | " eps = np.random.normal(0., 0.01) # 均值和方差\n",
64 | " # 得到模型的输出\n",
65 | " y = 1.477 * x + 0.089 + eps\n",
66 | " # 保存样本点\n",
67 | " data.append([x, y]) \n",
68 | " # 转换为2D Numpy数组\n",
69 | " data = np.array(data) \n",
70 | " return data"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "id": "85e8343d",
76 | "metadata": {},
77 | "source": [
78 | "**2. 计算误差**\n",
79 | "\n",
80 | " 循环计算在每个点 $\\left(x^{(i)}, y^{(i)}\\right)$ 处的预测值与真实值之间差的平方并累加,从而获得训练集上的均方差损失值。\n",
81 | "\n",
82 | " 最后的误差和除以数据样本总数,从而得到每个样本上的平均误差。"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 3,
88 | "id": "9d85924c",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "def mse(b, w, points) :\n",
93 | " totalError = 0\n",
94 | " # 根据当前的w,b参数计算均方差损失\n",
95 | " for i in range(0, len(points)) : # 循环迭代所有点\n",
96 | " # 获得 i 号点的输入 x\n",
97 | " x = points[i, 0]\n",
98 | " # 获得 i 号点的输出 y\n",
99 | " y = points[i, 1]\n",
100 | " # 计算差的平方,并累加\n",
101 | " totalError += (y - (w * x + b)) ** 2\n",
102 | " # 将累加的误差求平均,得到均方误差\n",
103 | " return totalError / float(len(points))"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "8953d873",
109 | "metadata": {},
110 | "source": [
111 | "**3. 计算梯度**\n",
112 | "\n",
113 | " 这里我们使用更加简单好用的梯度下降算法。我们需要计算出函数在每一个点上的梯度信息: $\\left(\\dfrac{\\partial \\mathcal{L}}{\\partial w}, \\dfrac{\\partial \\mathcal{L}}{\\partial b}\\right)$。我们来推导一下梯度的表达式,首先考虑 $\\dfrac{\\partial \\mathcal{L}}{\\partial w}$ ,将均方差函数展开: \n",
114 | "\n",
115 | "$$\n",
116 | "\\begin{aligned}\\frac{\\displaystyle \\partial \\mathcal{L}}{\\partial w}&=\\frac{\\displaystyle \\partial \\frac{1}{n} \\sum_{i=1}^{n}\\left(w x^{(i)}+b-y^{(i)}\\right)^{2}}{\\partial w}&\\\\&=\\frac{1}{n} \\sum_{i=1}^{n} \\frac{\\partial\\left(w x^{(i)}+b-y^{(i)}\\right)^{2}}{\\partial w}\\end{aligned}\n",
117 | "$$\n",
118 | "\n",
119 | "由于:\n",
120 | "\n",
121 | "$$\n",
122 | "\\frac{\\partial g^{2}}{\\partial w}=2 \\cdot g \\cdot \\frac{\\partial g}{\\partial w}\n",
123 | "$$\n",
124 | "\n",
125 | "则有:\n",
126 | "\n",
127 | "$$\n",
128 | "\\begin{aligned}\\frac{\\partial \\mathcal{L}}{\\partial w}&=\\frac{1}{n} \\sum_{i=1}^{n} 2\\left(w x^{(i)}+b-y^{(i)}\\right) \\cdot \\frac{\\partial\\left(w x^{(i)}+b-y^{(i)}\\right)}{\\partial w} &\\\\&=\\frac{1}{n} \\sum_{i=1}^{n} 2\\left(w x^{(i)}+b-y^{(i)}\\right) \\cdot x^{(i)} &\\\\&=\\frac{2}{n} \\sum_{i=1}^{n}\\left(w x^{(i)}+b-y^{(i)}\\right) \\cdot x^{(i)}\\end{aligned}\n",
129 | "$$\n",
130 | "\n",
131 | "\n",
132 | "$$\n",
133 | "\\begin{aligned}\\dfrac{\\partial \\mathcal{L}}{\\partial b}&=\\dfrac{\\displaystyle \\partial \\dfrac{1}{n} \\sum_{i=1}^{n}\\left(w x^{(i)}+b-y^{(i)}\\right)^{2}}{\\partial b}\\\\&=\\frac{1}{n} \\sum_{i=1}^{n} \\frac{\\partial\\left(w x^{(i)}+b-y^{(i)}\\right)^{2}}{\\partial b} &\\\\&=\\frac{1}{n} \\sum_{i=1}^{n} 2\\left(w x^{(i)}+b-y^{(i)}\\right) \\cdot \\frac{\\partial\\left(w x^{(i)}+b-y^{(i)}\\right)}{\\partial b} &\\\\&=\\frac{1}{n} \\sum_{i=1}^{n} 2\\left(w x^{(i)}+b-y^{(i)}\\right) \\cdot 1 &\\\\&=\\frac{2}{n} \\sum_{i=1}^{n}\\left(w x^{(i)}+b-y^{(i)}\\right)\\end{aligned}\n",
134 | "$$\n",
135 | "\n",
136 | " 根据上面偏导数的表达式,我们只需要计算在每一个点上面的 $\\left(w x^{(i)}+b-y^{(i)}\\right)$ 和 $\\left(w x^{(i)}+b-y^{(i)}\\right)$ 值,平均后即可得到偏导数 $\\dfrac{\\partial \\mathcal{L}}{\\partial w}$ 和 $\\dfrac{\\partial \\mathcal{L}}{\\partial b}$ 。 "
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 4,
142 | "id": "1e63ea91",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "# 计算偏导数\n",
147 | "def step_gradient(b_current, w_current, points, lr) :\n",
148 | " # 计算误差函数在所有点上的异数,并更新w,b\n",
149 | " b_gradient = 0\n",
150 | " w_gradient = 0\n",
151 | " # 总体样本\n",
152 | " M = float(len(points))\n",
153 | " for i in range(0, len(points)) :\n",
154 | " x = points[i, 0]\n",
155 | " y = points[i, 1]\n",
156 | " # 偏b\n",
157 | " b_gradient += (2 / M) * ((w_current * x + b_current) - y)\n",
158 | " # 偏w\n",
159 | " w_gradient += (2 / M) * x * ((w_current * x + b_current) - y)\n",
160 | " # 根据梯度下降算法更新的 w',b',其中lr为学习率\n",
161 | " new_b = b_current - (lr * b_gradient)\n",
162 | " new_w = w_current - (lr * w_gradient)\n",
163 | " return [new_b, new_w]\n",
164 | " \n",
165 | "plt.rcParams['font.size'] = 16\n",
166 | "plt.rcParams['font.family'] = ['STKaiti']\n",
167 | "plt.rcParams['axes.unicode_minus'] = False\n",
168 | "\n",
169 | "# 梯度更新\n",
170 | "def gradient_descent(points, starting_b, starting_w, lr, num_iterations) :\n",
171 | " b = starting_b\n",
172 | " w = starting_w\n",
173 | " MSE = []\n",
174 | " Epoch = []\n",
175 | " for step in range(num_iterations) :\n",
176 | " b, w = step_gradient(b, w, np.array(points), lr)\n",
177 | " # 计算当前的均方误差,用于监控训练进度\n",
178 | " loss = mse(b, w, points)\n",
179 | " MSE.append(loss)\n",
180 | " Epoch.append(step)\n",
181 | " if step % 50 == 0 :\n",
182 | " print(f\"iteration:{step}, loss:{loss}, w:{w}, b:{b}\")\n",
183 | " plt.plot(Epoch, MSE, color='C1', label='均方差')\n",
184 | " plt.xlabel('epoch')\n",
185 | " plt.ylabel('MSE')\n",
186 | " plt.title('MSE function')\n",
187 | " plt.legend(loc = 1)\n",
188 | " plt.show()\n",
189 | " return [b, w] "
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "id": "f7a02b2d",
195 | "metadata": {},
196 | "source": [
197 | "**4. 主函数**\n",
198 | "\n"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 5,
204 | "id": "3ee69c2a",
205 | "metadata": {
206 | "scrolled": true
207 | },
208 | "outputs": [
209 | {
210 | "name": "stdout",
211 | "output_type": "stream",
212 | "text": [
213 | "iteration:0, loss:8.52075121569461, w:0.9683621336270813, b:0.018598967590321615\n",
214 | "iteration:50, loss:0.0005939300597845278, w:1.477514542941938, b:0.06613823978315139\n",
215 | "iteration:100, loss:0.00016616611251547874, w:1.4772610937560182, b:0.08026637756911292\n",
216 | "iteration:150, loss:0.00010824080152649426, w:1.4771678278317757, b:0.08546534407456151\n",
217 | "iteration:200, loss:0.00010039689211198855, w:1.4771335072140424, b:0.08737849449355252\n",
218 | "iteration:250, loss:9.933471542527609e-05, w:1.4771208776838989, b:0.08808250836281861\n",
219 | "iteration:300, loss:9.919088162325623e-05, w:1.477116230185043, b:0.08834157608859644\n",
220 | "iteration:350, loss:9.917140448460003e-05, w:1.4771145199673728, b:0.08843690956066241\n",
221 | "iteration:400, loss:9.916876700352793e-05, w:1.477113890630052, b:0.08847199100845321\n",
222 | "iteration:450, loss:9.916840985114966e-05, w:1.4771136590423013, b:0.08848490051392309\n",
223 | "iteration:500, loss:9.916836148764827e-05, w:1.4771135738210939, b:0.08848965103996947\n",
224 | "iteration:550, loss:9.916835493854371e-05, w:1.4771135424608248, b:0.08849139917027324\n",
225 | "iteration:600, loss:9.916835405170177e-05, w:1.4771135309206636, b:0.08849204245893828\n",
226 | "iteration:650, loss:9.916835393161082e-05, w:1.4771135266740378, b:0.08849227918059785\n",
227 | "iteration:700, loss:9.916835391534817e-05, w:1.4771135251113363, b:0.08849236629101521\n",
228 | "iteration:750, loss:9.916835391314785e-05, w:1.477113524536283, b:0.08849239834648838\n",
229 | "iteration:800, loss:9.916835391284828e-05, w:1.477113524324671, b:0.08849241014247554\n",
230 | "iteration:850, loss:9.916835391280702e-05, w:1.4771135242468005, b:0.08849241448324166\n",
231 | "iteration:900, loss:9.916835391280325e-05, w:1.4771135242181452, b:0.08849241608058574\n",
232 | "iteration:950, loss:9.916835391280336e-05, w:1.4771135242076006, b:0.08849241666838711\n"
233 | ]
234 | },
235 | {
236 | "data": {
237 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEjCAYAAAA1ymrVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhmklEQVR4nO3deZhcVb3u8e+bpBMSMAlKJ8RACOFoQAMHsFU8xIgMKnjzPGoExBBA0RDUqMHhhEEPOZKTeC6DghckiASuuUo4wPUgg4IMJipg8BLmQRkiYQgJGSATGX73j72rqd7dXemu7qrq3vV+nqee7trTWquq+61Va0+KCMzMrH70qXUFzMysuhz8ZmZ1xsFvZlZnHPxmZnXGwW9mVmcc/GZmdcbBb1UjqZ+kPunv/SWpg+v1kTSosrUrWf7RkhZJ2iTpeUmjalUXs+7Qr9YVsLqyF3CopJnAfsAU4BdtLSipH/AcsDtwBXA7cEM67zPAvwAbgZHAMODfI+J+SbsBE4HJwBHp9v+Y2fzewHeBgRGxqVSFJX0C2CUiPizpAOAQYEXnmt39JC0Ffh8RZ9S6Ltb7yCdwWbVJ+gFwIrAmIg5qZ5nPAlcBD0fEvxRN/zIwOiLOziw7KSJOKJp2Srr+RyPi7ja2fydwROzgH0DSvRFxSCeaVxWSvgU8FhG31rou1vt4qMdqYRvwI+BASYe3s8zhwAPAm5nppwMLiydExH8Bf+pkHX7cgdDfC3hnJ7fb7SSNkjS0eFpEXODQt3I5+K1WrgTWAN/KzpC0H/BEO+ttAs6SNKB4YkRc0pnCI+LXHVjs3cD2zmy3Qi4Chta6EpYfDn6riYh4A7gcOFrSvpnZJwPXtLPqRcBxwEOSJnV0B3ExSed0YJkJJN9KRkq6LX0cJ2mqpLWSzk2X21vSGZKWSZpftH4/SbdK2lXSf0h6TNJqSXMy5fSRdJqkX6WP1ZJukLSLpL0k3QB8Bpgn6f9KOi9db5KkszLbOkTSTyR9XdLpki6XND6zzDslXS1prKSrJL0k6QVJkzr7OlovFhF++FHVB3Bu+nMkyVDOvKJ5uwAXpr/fDdzdxvqfA1YBQTIcdFgby5ySzv8J8M2ix0XAXzpQx51JPmCijXl3F9pQNG0+MD/9fXfg0bT8HwHvTKefmk4bV7TepcASkh3NAEcCfwD2SJ+PTtcZXbTONGBrcR3S9Z4BhhZN2xX4R+H1ASYAq4FXSHZuDyI5wOMq4FXSfX5+5P/hHr/VTEQsB34JTJHUmE4+MZ1War1fAWNJQnUccJeki9pZ/L8i4kdFjxnAYx2o23rKPHonIl4Gvpw+PSMiXkynXwm8BrwXQNKuwFTghxGxMV3mjoiYEBEvlNj+T4GbMpMvBq6MiDVFy60mGVK7MH3+B+B/Ak9GxH9GxIaI2JrO343k6CirAw5+q7ULgZ2Ar6bPPxARf9nRShGxMg3xfUl64N+U9LEOlnlbORXtpELYZ/cRvA4MTH8fC/Ql+XbQWWsLv6SHsO4HPN7Gcg8BBxXtHH6R1vstXk9/DsTqgoPfaioilgJ3AF+RdBStj7lvQdKhmfWfBT4OPAUc2MFir0u3tZOkWvZyC+E9tJu2lz0CCt4K+YZuKsNywMFvPcH5QCPwv9jBMA9wdnZCRLxJchTQ8o4Ulg5vQLITeUTHq9lsIzBgh0vt2N9Ixtwnllhm2442EhErSdqf3UkOcADwRES8WlYNLZcc/FYL2UMxfws8AtwaERuKZvWh6OzydLjiaEnnFB/NI+kgYDhwfUcrkG7rTEoPs/RNl31bZvr9wMckDZM0VNIXgH8G3pk9zLStotMHEbEFOBeYIenE9AifnSR9VdKH0uUL3wpGpHVpbzjmDGBacV0lDSfZh/DVdtYprlPxT8s5X7LBqkbS/wA+CpwiaQNJj/fGSC6bcD5JoCLpfcCngCagr6RLgcXAtcBJwBeA8ZIeJul9bwc+ERGb0uvoTAA+kRY7XVJTpipvIzlEcmBR7z9b1w8A306f/krSb4CfR8Rm4ALgfcCzaZ3PAO4BRqVtu5fkRDMkzQN+BrxAcjTSO4GTJT0XEXdFxMXph8Vc4MckPfe5EfFngIhYJ+laksM57wH+IellkhPcXpL0UETcEBG3SjoVuEDSoyQfmu8BPhMRhdf1GJLLZBws6d+BeSQ7dL+UtvNMSRdExJNtvoGWG75kg/UIkhQ96I9RUl+SQzk7fQJX8brpN5M+EbHDIZsObrsfyTeRN0nCfaf0CKSOrrstIqIr7bPez8FvZlZnPMZvZlZnHPxmZnXGwW9mVmd6xVE9u+22W4wePbrW1TAz61UeeOCBlRHRmJ3eK4J/9OjRLFmypNbVMDPrVSQ939Z0D/WYmdUZB7+ZWZ1x8JuZ1RkHv5lZnXHwm5nVGQe/mVmd6RWHc5pZ9a1bt44VK1awZcuWWlfFMhoaGhg2bBiDBw8ua/18B//9V8CrT8AnL6h1Tcx6lXXr1vHKK68wcuRIBg4cSNHtD6zGIoKNGzeyfHly36Fywj/fQz3LH4Cnf1frWpj1OitWrGDkyJEMGjTIod/DSGLQoEGMHDmSFStWlLWNfAc/gK86bdZpW7ZsYeBA33u9Jxs4cGDZw3A5D373VMzK5Z5+z9aV9yfnwQ/u8puZtZTv4HePxcyslXwHv5mZtZL/4Pc9hc2syLXXXssnP/lJ1qxZ02L6V77yFWbNmtWpbS1YsIDBgwdz2mmnsX17x+9b/8orrzBu3Dhuvvnm5mnf//73mTRpEtu2betUHcqR8+D3UI+ZtTRx4kRuu+02nn++5aXq77//fv7pn/6pU9uaPHkyV111FQsWLOCxxx7r8HrDhw/nnHPO4Y9//COQfIC8973vZcGCBfTt27dTdShHVU/gUrIbegawCdgZGAj8IKKS3XL3+M3sLYMGDWLo0KG8613vap527733sv/++zN58uROb2/SpEkcddRRLU6k2rBhAwMGDGg3xO+//362bNnCTjvtxLRp01i/fj2bNm3ioYce4txzz6WhoaHzDeuEap+5+yXgpYj4JYCk7wEnA/MrUpo7/Gbd69aZ8PLDta3D7vvD0XO7tIldd92VQYMGAcmZsLNmzWLixInMnDmTxYsX8+yzzzJ37lymTJlScjsRwTPPPMODDz7I0qVLWbp0KQ899BDLly9nyJAh3HLLLbz//e9vtd64cePYc889efrpp5k9eza//e1vu9Sezqp28E8EZhY9/0s6bX7FSvQYv5kV2bhxIxHBjBkzGD9+POvWrWPOnDk88sgjTJ48mblz5zJ9+nSuueaaFsH/8ssv8/DDD/Pkk0/y6KOPsnr1avr06cPixYv5/Oc/z/HHH8/06dN5xzveQZ8+fYiIdo+1X7ZsGffddx833XQTL730Ep///OcZNmwY5513HrvsskvFX4NqB/+LwGeAwmDYwcCfK1ecu/xm3aqLPe1aigi+973vcd999/Hmm29yzDHHMGrUKJYtW8bgwYM5/vjjm4dY9tprL97znve0WH/t2rUsXryYAw44gOOPP57GxkYigr333pv99tuPgw46qMXypU6w2mOPPdiwYQMXXnghV199NQcccABbt27l5ZdfpqGhgQEDBnT/C1Ck2sE/F/izpHcBtwCrIuIXbS0oaSowFWDUqFFdKNI9fjNLgvi8884DYMKECRx11FGsWbOGsWPHcvHFF3P55Zdz55130r9/f1avXs3uu+/eYv2xY8e2Oupn4cKFvP3tb+fvf/87kyZN4kMf+hAzZswouYN23bp1LFu2jMcff5yXXnqJWbNm8dxzz7Fq1Sr69+/Phz/8YebPn9/t7S9W1eCPiOcknUwy3HM5cKOkBRHxRhvLzgPmATQ1NZWX3j6By8wytm3bxvbt25k7dy5HHnkkTU1NfP3rX+f555/nxz/+Meeeey6rVq1q3gfQnrVr1/Ld736XG2+8kYMPPpglS5bwgQ98gPHjx3PIIYe0uc4f/vAH5syZQ2NjI4sWLeLSSy9lwoQJDBs2rBJNbVe1j+o5Fvh4RBwu6e3Ar0jG9z9bzXqYWX2KCG666SYaGxs5/fTTGThwIHfddRcXXHABJ510EscddxwATz/9NCeccELJbU2dOpWTTjqJgw8+GIDNmzfzwQ9+sN3Qh+SbxoQJE1i8eDERwaRJk1i2bBm/+c1veOSRR2hsbOTUU0/tvga3o9pDPd8EzgCIiNcknQC8JGlwRKyrSIneuWtmwM9+9jPmzJnDlClTuOGGG5DE3XffzWmnncbChQs58MADAdi6dStLlizh3e9+d7vbmj17Nps2bWox9HP77bczYcKEHdZjxYoVzJw5k2uuuYazzjqLO+64g/79+zNt2jQOPfTQLrezI6od/LsAzQEfEaskrQIqdIsfD/WYWWLKlCl89KMfZZ999mmedthhh/HUU08BsHr1aoYOHcott9zCgAEDGDFiRJvbmT17Ng8++CALFy4kIprH5q+44gquvvrqknVYs2YNU6ZM4ZJLLmHMmDHMmTOHfffdl7vuumuHh452p2qfuXsbcFjhiaT9gTsjYmPlinSP38xgwIABLUI/a/Xq1eyzzz587nOf49hjj201f/ny5Zx++uk0NjaycOHC5hO0TjzxREaMGMHuu+/OEUcc0e72X3zxRb74xS8yd+7c5iOA1q9fzwsvvMCjjz7KxIkTGTduHDvvvDN77LEHN954Y9cb3Y5q9/jPBWZJ+jbwBjACmFax0rxz18w6aMyYMXz729/m+uuv54c//GGr+du2beOSSy6hX7+WsXnOOecwefJkrrnmmnYP4Xzssce47rrruPLKK1m0aBFnnnkmI0eOZPTo0eyzzz5ceOGFDB06tPlwUkmMHDmy+xuZUkWvltBNmpqaYsmSJZ1f8aZvwBO3wHee7v5KmeXY448/zn777VfratgO7Oh9kvRARDRlp/sibWZmdSbnwQ8e4zczaynfwe8xfjOzVvId/ODj+M3K1Bv2/9Wzrrw/OQ9+9/jNytHQ0MDGjRU8ytq6bOPGjWVftz/nwW9m5Rg2bBjLly9nw4YN7vn3MBHBhg0bWL58ednX+Kn2cfw14D9as84q3E3qxRdfZMuWCp1Yb2VraGhg+PDhLe761Rn5Dn7v3DUr2+DBg8sOFuvZ8j/U46+pZmYt5Dz43eM3M8vKefCDx/jNzFrKd/B7jN/MrJV8Bz94jN/MLCPnwe8ev5lZVs6D38zMsuog+D3UY2ZWLN/B7527Zmat5Dv4wR1+M7OMnAe/e/xmZlk5D35wl9/MrKV8B7/H+M3MWsl38INP4DIzy8h58LvHb2aWlfPgNzOzrDoIfg/1mJkVy3fwe+eumVkr+Q5+8M5dM7OM/Ae/mZm1UAfB7x6/mVmxfAe/x/jNzFrJd/CDx/jNzDJyHvzu8ZuZZeU8+M3MLKvmwS+pr6T+lSvBQz1mZsX61apgSQ3AV4CBwGXAmxUopNs3aWbW29Wkxy9pCHAz8FREzI2ItRUrzDt3zcxaqHrwS+oD3ABcGxG3Vri0ym7ezKwXqkWP/8vAoIi4sjrFucdvZlasFsE/A7hE0v6SzpR0TMVK8hi/mVkrVd25K2lPYCzwu4hYKekp4H5Jb0TEHzLLTgWmAowaNar8Qj3Gb2bWQrV7/HsA6yJiJUBEbAauA76UXTAi5kVEU0Q0NTY2llmce/xmZlnVDv61bZS5guQDwczMqqDawf8s0FfS7kXThgDLK1ekh3rMzIpVNfgjYiNwBXBS0eTDgMoc4eOdu2ZmrdTizN0zgQslfQNoAK6LiLsrVpp37pqZtVD14I+IDcC06pTmHr+ZWVbNL9JWee7xm5kVy3fwe4zfzKyVfAc/eIzfzCwj58HvHr+ZWVbOg9/MzLLqIPg91GNmVizfwe+du2ZmreQ7+M3MrJWcB797/GZmWTkP/pQP6TQza5bv4PcYv5lZK/kO/gL3+M3MmuU8+N3jNzPLynnwm5lZVp0Ev4d6zMwK8h383rlrZtZKvoO/wDt3zcya5Tz43eM3M8vKefAXuMdvZlaQ7+B3h9/MrJWyg1/SoO6sSEV5jN/MrFm7wS/t8JCYCZJulPTnbq5TN3KX38wsq1SP/3VJ8yR9pK2ZEXEb8FlgZUVq1q3c4zczKygV/N+KiKkRcQ+ApFMlPSHpbEn7AkTENuC2alTUzMy6R78S814vfhIRV0rqExFXZJbb0v3V6iY+gcvMrJVSPf62UnNzB5frWbxz18ysWang72ha9uBU7fmfSWZm1VZqqGeipOz88W0c7PNJYF631qrb9eDPJjOzKisV/Menj6wvZZ733FT1GL+ZWSulhnrOiIg+O3oAX6hWZcvmMX4zs2algv/2Dm7j0e6oSGW4x29mltXuUE9EtBnokt4BHAqsAe6NiAcqU7Xu5B6/mVlBqUs2HCbpUknji6aNA54AfgZ8HbhJ0h6Vr2aZPMZvZtZKqZ27ZwAzIuLvAJL6Ar8EtgIfiIh/SBoCfAv4fsVramZm3aLUGP8DhdBPTQfeC3wzIv4BEBFrgZfLLVzSnpIuLnf9DvPOXTOzZqWCf13hl3Rc/3vAnRFxbWa5vbpQ/k+BwV1Yfwc81GNmllUq+PtIGp+G/i9IUvS04gXSMf9DyylY0onAs+Ws23nu8ZuZFZQK/p8CRwJ3pM+PKBrv303SLOBeyrg6p6QRwAhgSWfX7WRBFd28mVlv1GbwS2oAjouIcyPioIg4OiL+X2F+RKyMiH8DhgBryyj3G0Dlx/YLPMZvZtaszaN6ImKLpOmSAtheYn0BJwOXdLRASZOA2yJic6mbfEmaCkwFGDVqVEc330b1zMysWKnDOfcDfgS8RvsJKpIhmw6RNBQ4JCK+s6NlI2Ie6cXfmpqauthld4/fzKygVPAPByYB7wT+Dvw6IjZmF5I0sxPlnQAcJene9HkjMDh9fmZE3NWJbe2Yx/jNzFopdcmGdcBVAJLGAKdLGgj8MSLuLlr0Fx0tLCIuAy4rPJd0CnBYRJzSqVqbmVnZSvX4m0XEM8CFAOkhnucAbwL/HRFPVLB+3cM7d83MmpU6nLM9f0vXmw48KGmH4/W146EeM7OsDvX4ASQdCUwDJgLLgJ8AV0XEinILj4j5wPxy1+9ESZUvwsyslygZ/JLeTnKjlanAaOC/gU9GxB1Fy+wWESsrWcmyeeeumVkr7Z3A1UfSNcALwOnAz4E9I+LY4tBPfazCdew6j/GbmTVr7wSu7ZKOBW4guRNXAJ9o44SrvsCJwP+pZCXL5x6/mVlWqaGes0iCvxSRXM/HzMx6iTaDP73pyi0R8fyONiCph/b28Ri/mVkb2hvq2QY82ZENRMTN3VojMzOrqHKO4+99vHPXzKxZzoPfQz1mZlk5D/4C9/jNzAryHfzeuWtm1kq+g7/AY/xmZs3qI/jNzKxZnQS/e/xmZgX5Dn6P8ZuZtZLv4Dczs1bqI/i9c9fMrFnOg99DPWZmWTkP/gL3+M3MCvId/N65a2bWSr6Dv8Bj/GZmzXIe/O7xm5ll5Tz4C9zjNzMryHfwe4zfzKyVfAe/mZm1Uh/B7527ZmbNch78HuoxM8vKefAXuMdvZlaQ7+D3zl0zs1byHfwFHuM3M2uW8+B3j9/MLCvnwV/gHr+ZWUG+g99j/GZmreQ7+M3MrJX6CH7v3DUza5bz4PdQj5lZVtWDX9JOkn4i6UlJz0g6tfKlusdvZlZQix7/mcC1ETEW+DRwgaQjK1KSd+6ambVS1eCXNAB4JSIWAUTEUmAByQdA5XiM38ysWb8ql7cV+HlmWgAbK1Oce/xmZllV7fFHxLaI2JSZ/BHgl9llJU2VtETSkldffbWrJXdxfTOz/KjpUT2SPgfcHBF/zc6LiHkR0RQRTY2NjeUW0MUampnlT7WHeppJ2hs4HJhW8cI8xm9m1qwmPX5JuwLfAL4WEdtrUQczs3pVk+P4gX8Dzo6IN9NpgytUWmU2a2bWi9Wix38+cFFErAeQJGB2DephZlaXqjrGL2kSMBU4UW/teG0AXqhQgRXZrJlZb1bV4I+I64H+1SwzLbjqRZqZ9VS+SJuZWZ3JefAXuMdvZlaQ7+D3GL+ZWSv5Dv4Cj/GbmTWrj+A3M7NmDn4zszpTJ8HvoR4zs4J8B7937pqZtZLv4C/wzl0zs2Y5D373+M3MsnIe/AXu8ZuZFeQ7+D3Gb2bWSr6Dv8Bj/GZmzXIe/O7xm5ll5Tz4zcwsK9/Br7R5sa229TAz60HyHfx9+iY/tzv4zcwKch786Q3G3OM3M2uW8+B3j9/MLCvnwZ/2+LdvrW09zMx6kHwHvwo9fge/mVlBvoPfPX4zs1bqJPi317YeZmY9SM6D30M9ZmZZDn4zszqT8+D3GL+ZWVZ9BL9P4DIza1Yfwe8TuMzMmuU8+D3Gb2aWle/g9wlcZmat5Dv4PdRjZtZKnQS/e/xmZgU5D35fndPMLKtOgt89fjOzgn7VLlCSgGlAAAOARRHx14oU5qEeM7NWqh78wDeAxRGxBEDSZZJmRsTabi/JwW9m1kpVh3ok9QVOKIR+6j7glIoU2Lc/NOwM61dWZPNmZr1RtXv8E4ANmWlPAhcBP+720iQYuieseBQ2rknH/PXWvOzvUmHFNn4vWrZ5nplZ71Pt4B8DvJaZ9lo6vTJ23x8evg5+uFeFCvCHgFmPkreO2Sm3wF4f6tZNVjv4hwHZAfctwG6S+kRE8x1TJE0FpgKMGjWq/BKPOR9Gfxg2vw7Nmw+IyPyePm/xe3vLZn83s54hh/+TQ0Z2+yarHfwrgP6ZaQ3AyuLQB4iIecA8gKampvLfzYFD4X0nl726mVneVPs4/meA3TLTdgWerXI9zMzqVrWDfxEwRFJxue8GflnlepiZ1a2qBn9EbAWuBD5WNPmDwPxq1sPMrJ7V4gSui4HpkkaRjO9fFRFralAPM7O6VPXgj4ggCX8zM6uBfF+kzczMWnHwm5nVGQe/mVmdUfSCs08lvQo8X+bquwH1dpU2t7k+uM31oStt3isiGrMTe0Xwd4WkJRHRVOt6VJPbXB/c5vpQiTZ7qMfMrM44+M3M6kw9BP+8WlegBtzm+uA214dub3Pux/jNzKyleujxm5lZEQe/WQ8nqa+k7H0szMpWi4u0VYUkAdNIbskzAFgUEX+tba26RtJOwPnAUSQXuJsdEVem8w4ADgc2AX2Bywo3t8nLayFpT+A7EfH19Hmu2yypAfgKMBC4DHgzr21O6z6DpF07k7T5BxER6ft+PPA6sAtwaURsLFp3Sjq9D/BkRNxR7fp3VHpJ+iuBeyJiftH0st/XUuu2KyJy+QC+CTQVPb8MGFLrenWxTbOAD6e//zOwBjiS5I/+p0XLHQj8a95eC+BmYH76e67bDAwBfgccXTQtt20GvgycUPT8e8ApJDe1ns9b+yOHAxcVLfcp4LNFz2cDY2rdnnbauBNwEfAgcEp3vK87Wre9Ry6HeiT1JfkjWlI0+T6SP6ReSdIA4JWIWAQQEUuBBcCngROBBwrLRsSDwKcl9c/LayHpRFreqS23bU57hTcA10bErUWzcttmYCKwtOj5X4D3Ax8HlkeaahHxCvAeSYWzUb8N/KZovduB6ZWvblk+DvwHSfAX68r72u66pSqSy+AHJgAbMtOeBE6oQV26y1bg55lpQdLOE4CnMvPWknz96/WvhaQRwAig+I8/z23+MjAo0mG8Inlu84vAZ4qeHwz8mbbb/BwwSdLewMiI2FQ0r8e2OSJ+HRGvtjGrK+9rqXXbldfgHwO8lpn2Wjq9V4qIbZk/cICPkNy2slR78/BafIPW93DIc5tnAJdI2l/SmZKOSafnuc1zga9KulrS8cCqiPgF5bV5uKRBla5wN+rK+1rW+57XnbvDSHrIxbYAu0nqEzva8dELSPoccHNE/FVSe+0dBmxuZ16veC0kTQJui4jNyT6uZrlsc7ojcyzwu4hYKekp4H5Jb5DTNgNExHOSTgZmApcDN0paQPttHlFiHum85ypW4e5V9vu6g3Xbldce/wogO8bVAKzs6f8AHZF+xT0cOCud1F57V5SY1+NfC0lDgUMi4u42ZueyzcAewLqIWAkQEZuB64Avkd82I+lY4LiIOJyktzqSZKduOW0mnddbdOV9LbVuu/La43+G5FKmxXal5c7BXknSriRDH18r+ocu1d5NJeb1dCcAR0m6N33eCAxOn+e1zWtp3SFbQfJBn9c2Q3LkyhkAEfGapBOAl0iO5GqrXUtJ2tbWvBURkR0X78m68r6WlXV57fEvAoakX4UK3k0yHt5rpcfx/xtwdkS8mU4bTNKucZnFhwK/pxe/FhFxWUQcGBGHRMQhwA9IhrcOIadtJvmH7Stp96JpQ4Dl5LfNkByWuK7wJCJWAatIrlOTbfNo4PqIeAZ4Ie0MFfSmNhd05X0ttW67chn8EbGV5CSJjxVN/iDJV8fe7HySY5jXQ/OJHbOBXwDjCwtJGgf8OiLezPFrkcs2R3Ji0hXASUWTDyNpTy7bnLqNpJ0ASNofuDOdvnfh8ERJu5GcpFUYyvghycldBUcCl1Sjwt2oK+9ru+uWKjC3F2lLQ3E6yVelBuC+zLGwvUq6k/OXtDy0qwF4ISLGpv8ox5D0kvoDl0fEtnTdXLwWkk4BDouIU9LnuWxzekTKhcDjvDWeOz+dl9c2DyQ5QXEF8AbJztvzI+L1dIf3ScDLwNuAecVDOZK+QHJGawB/jx565m76/kwBzgOeJvnmfm86r+z3tdS67dYlr8FvZmZty+VQj5mZtc/Bb2ZWZxz8ZmZ1xsFvZlZnHPxmZnXGwW9mVmcc/GZmdcbBb1Ylkk6UtEbSYbWui9U3B79ZlaTXl1+6wwXNKszBb1ZdPlXeas7Bb5Yh6SO1roNZJeX1evxmOyRpJPAtkgtm7QVcD5wMHC3pbOCzQBPJ/WC/GBFPpOsNJ7lo1lLgHSSXxf1+RLxWNH8WyUWzPk1ySeVPR8QbadEDJJ1HcmGtTcDE9DLEZlXhi7RZ3ZJ0O/CFiHghvSLmn0gu8Xsf8ImIuFdSX2ABsC9wEMnVEZcAn4mIv6XbOQL414j4WPr8HmB+RFwlaQjJDUb+MyLWS7ob+AfJjXTWSvrfwBMRMbuKTbc65x6/1SVJ7wOGA+PTe/n2BZ4gCfY1hcvlRsQ2STNJbpAyBjgQ2FAI/XSZ30u6SlIT8ApwKHBUOm8tyc1zil2dTge4K13erGoc/Fav9gEej4hfFU1bIGl0dsH0RuDrSe5sNIbkloBZL5LcJL0PyQdDqRthFN8cezvJh45Z1XjnrtWrl4CDMre0a5OkBpL/lado+x6nkNw45DmSbwZvk7RX91XVrHs5+K1e/Tn9+R+S+gFI+jiwB8k9Tou/DU8BLomI14HfAEOLvxlIOpLkA+FPEfEqcB1wsaQB6fyPSNq30g0y6ygP9VhdioitkiYCVwErJP0VuAB4gaRD9C1Jz5P05IcAZ6frbZZ0NPA1SQ+Q/A/tD3wq3jpSYhrwc+AZSU8D10XEPZKOIxliOl3SOuBVkqN+xkj6SETcU53WW73zUT1mRdKe/N0RMboC2+6XfuAI6LOj+6KaVYp7/GZVEhFb058BOPStZjzGb2ZWZxz8Zqn0JK7PAsMlfdlH5lheeYzfLJWepTsIWE/SKeobEZtrWyuz7ufgNzOrMx7qMTOrMw5+M7M64+A3M6szDn4zszrz/wF1r11LhQ14cQAAAABJRU5ErkJggg==\n",
238 | "text/plain": [
239 | ""
240 | ]
241 | },
242 | "metadata": {
243 | "needs_background": "light"
244 | },
245 | "output_type": "display_data"
246 | },
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "Final loss:9.916835391280157e-05, w1.4771135242037658, b0.08849241688214672\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "def solve(data) :\n",
257 | " # 学习率\n",
258 | " lr = 0.01\n",
259 | " initial_b = 0\n",
260 | " initial_w = 0\n",
261 | " num_iterations = 1000\n",
262 | " [b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)\n",
263 | " loss = mse(b, w, data)\n",
264 | " print(f'Final loss:{loss}, w{w}, b{b}')\n",
265 | "\n",
266 | "if __name__ == \"__main__\": \n",
267 | " data = get_data() \n",
268 | " solve(data)"
269 | ]
270 | }
271 | ],
272 | "metadata": {
273 | "kernelspec": {
274 | "display_name": "Python 3 (ipykernel)",
275 | "language": "python",
276 | "name": "python3"
277 | },
278 | "language_info": {
279 | "codemirror_mode": {
280 | "name": "ipython",
281 | "version": 3
282 | },
283 | "file_extension": ".py",
284 | "mimetype": "text/x-python",
285 | "name": "python",
286 | "nbconvert_exporter": "python",
287 | "pygments_lexer": "ipython3",
288 | "version": "3.7.11"
289 | }
290 | },
291 | "nbformat": 4,
292 | "nbformat_minor": 5
293 | }
294 |
--------------------------------------------------------------------------------
/第 02 章 回归问题与神经元模型/全部源码(已更新)/PyCharm/《繁凡的深度学习笔记》第 2 章 回归问题与神经元模型 2.2.3 神经元线性模型实战.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | from matplotlib import pyplot as plt
4 | # cal y = 1.477x + 0.089 + epsilon,epsilon ~ N(0, 0.01^2)
5 |
6 | # plt参数设置
7 | plt.rcParams['font.size'] = 16
8 | plt.rcParams['font.family'] = ['STKaiti']
9 | plt.rcParams['axes.unicode_minus'] = False
10 |
11 | # 生成数据
12 | def get_data():
13 | # 计算均方误差
14 | #保存样本集的列表
15 | data = []
16 | for i in range(100):
17 | x = np.random.uniform(-10., 10.) # 随机采样 x
18 | # 高斯噪声
19 | eps = np.random.normal(0., 0.01) # 均值和方差
20 | # 得到模型的输出
21 | y = 1.477 * x + 0.089 + eps
22 | # 保存样本点
23 | data.append([x, y])
24 | # 转换为2D Numpy数组
25 | data = np.array(data)
26 | return data
27 |
28 | # mse 损失函数
29 | def mse(b, w, points) :
30 | totalError = 0
31 | # 根据当前的w,b参数计算均方差损失
32 | for i in range(0, len(points)) : # 循环迭代所有点
33 | # 获得 i 号点的输入 x
34 | x = points[i, 0]
35 | # 获得 i 号点的输出 y
36 | y = points[i, 1]
37 | # 计算差的平方,并累加
38 | totalError += (y - (w * x + b)) ** 2
39 | # 将累加的误差求平均,得到均方误差
40 | return totalError / float(len(points))
41 |
42 |
43 | # 计算偏导数
44 | def step_gradient(b_current, w_current, points, lr) :
45 | # 计算误差函数在所有点上的异数,并更新w,b
46 | b_gradient = 0
47 | w_gradient = 0
48 | # 总体样本
49 | M = float(len(points))
50 | for i in range(0, len(points)) :
51 | x = points[i, 0]
52 | y = points[i, 1]
53 | # 偏b
54 | b_gradient += (2 / M) * ((w_current * x + b_current) - y)
55 | # 偏w
56 | w_gradient += (2 / M) * x * ((w_current * x + b_current) - y)
57 | # 根据梯度下降算法更新的 w',b',其中lr为学习率
58 | new_b = b_current - (lr * b_gradient)
59 | new_w = w_current - (lr * w_gradient)
60 | return [new_b, new_w]
61 |
62 |
63 | # 梯度更新
64 | def gradient_descent(points, starting_b, starting_w, lr, num_iterations) :
65 | b = starting_b
66 | w = starting_w
67 | MSE = []
68 | Epoch = []
69 | for step in range(num_iterations) :
70 | b, w = step_gradient(b, w, np.array(points), lr)
71 | # 计算当前的均方误差,用于监控训练进度
72 | loss = mse(b, w, points)
73 | MSE.append(loss)
74 | Epoch.append(step)
75 | if step % 50 == 0 :
76 | print(f"iteration:{step}, loss:{loss}, w:{w}, b:{b}")
77 | plt.plot(Epoch, MSE, color='C1', label='均方差')
78 | plt.xlabel('epoch')
79 | plt.ylabel('MSE')
80 | plt.title('MSE function')
81 | plt.legend(loc = 1)
82 | plt.show()
83 | return [b, w]
84 |
85 | # 主函数
86 | def solve(data) :
87 | # 学习率
88 | lr = 0.01
89 | initial_b = 0
90 | initial_w = 0
91 | num_iterations = 1000
92 | [b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)
93 | loss = mse(b, w, data)
94 | print(f'Final loss:{loss}, w{w}, b{b}')
95 |
96 |
97 | if __name__ == "__main__":
98 | data = get_data()
99 | solve(data)
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/PDF文件(已更新)/《繁凡的深度学习笔记》第 3 章 分类问题与信息论基础(DL笔记整理系列).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 03 章 分类问题与信息论基础/PDF文件(已更新)/《繁凡的深度学习笔记》第 3 章 分类问题与信息论基础(DL笔记整理系列).pdf
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/PyTorch实现/Jupyter Notebook/(PyTorch)手写数字识别实战.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "sys.version_info(major=3, minor=8, micro=5, releaselevel='final', serial=0)\n",
13 | "torch 1.10.2\n",
14 | "torchvision 0.11.3\n",
15 | "numpy 1.19.3\n",
16 | "pandas 1.3.3\n",
17 | "sklearn 0.22.2.post1\n"
18 | ]
19 | }
20 | ],
21 | "source": [
22 | "%matplotlib inline\n",
23 | "import os\n",
24 | "import sys\n",
25 | "import time\n",
26 | "import torch # 导入 pytorch \n",
27 | "import sklearn\n",
28 | "import torchvision # 导入视觉库\n",
29 | "import numpy as np\n",
30 | "import pandas as pd\n",
31 | "from torch import nn # 导入网络层子库\n",
32 | "from torch import optim # 导入优化器\n",
33 | "from torchsummary import summary # 从 torchsummary 工具包中导入 summary 函数\n",
34 | "from torch.nn import functional as F # 导入网络层函数子库\n",
35 | "from matplotlib import pyplot as plt\n",
36 | "from utils import one_hot, plot_curve, plot_image\n",
37 | "\n",
38 | "print(sys.version_info)\n",
39 | "for module in torch, torchvision, np, pd, sklearn:\n",
40 | " print(module.__name__, module.__version__)"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 2,
46 | "metadata": {},
47 | "outputs": [
48 | {
49 | "data": {
50 | "text/plain": [
51 | "' Hyperparameters '"
52 | ]
53 | },
54 | "execution_count": 2,
55 | "metadata": {},
56 | "output_type": "execute_result"
57 | }
58 | ],
59 | "source": [
60 | "\n",
61 | "'''' 超参数 '''\n",
62 | "\n",
63 | "batch_size = 512 # 批大小\n",
64 | "n_epochs = 3\n",
65 | "# 学习率\n",
66 | "learning_rate = 0.01\n",
67 | "# 动量\n",
68 | "momentum = 0.9\n",
69 | "\n",
70 | "''' Hyperparameters '''"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 3,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "''' Step 1 下载训练集和测试集,对数据进行预处理 '''\n",
80 | "\n",
81 | "# 训练数据集,从网络下载 MNIST数据集 保存至 mnist_data 文件夹中\n",
82 | "# 创建 DataLoader 对象 (iterable, 类似 list,可用 iter() 进行访问),方便批量训练,将数据标准化在 0 附近并随机打散\n",
83 | "train_loader = torch.utils.data.DataLoader(\n",
84 | " torchvision.datasets.MNIST('mnist_data', train = True, download = True,\n",
85 | " # 图片预处理\n",
86 | " transform = torchvision.transforms.Compose([\n",
87 | " # 转换为张量\n",
88 | " torchvision.transforms.ToTensor(),\n",
89 | " # 标准化\n",
90 | " torchvision.transforms.Normalize(\n",
91 | " (0.1307,), (0.3081,))\n",
92 | " ])),\n",
93 | " batch_size = batch_size, shuffle = True)\n",
94 | "\n",
95 | "test_loader = torch.utils.data.DataLoader(\n",
96 | " torchvision.datasets.MNIST('mnist_data/', train = False, download = True,\n",
97 | " transform = torchvision.transforms.Compose([\n",
98 | " torchvision.transforms.ToTensor(),\n",
99 | " torchvision.transforms.Normalize(\n",
100 | " (0.1307,), (0.3081,)) # 使用训练集的均值和方差\n",
101 | " ])),\n",
102 | " batch_size = batch_size, shuffle = False)"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 4,
108 | "metadata": {},
109 | "outputs": [
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(2.8215)\n"
115 | ]
116 | },
117 | {
118 | "data": {
119 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAgfklEQVR4nO3de7xUVf3/8fdHBFIJTEBMAQmxQEwrJdFEUUFEwXuZl/ShUMrjkZJoxtcbIil+LeWSmaJlAqXhJfXrQ5CkQP3lFbwkoogEaoCCAuINAvbvjz1nu9b2zJyZOWtu57yejwcP1oc1s/eaM4vzmbXWnrUtiiIBANBY21S6AQCApoGEAgAIgoQCAAiChAIACIKEAgAIgoQCAAii4IRiZgvNrH/4pjQtZtbNzCIz27bSbakW9J380He+iL6Tn0r3nYITShRFvaMomluCtiAPZjanVn/Z0Hcqw8wuNLNVZvahmf3BzFpXuk2Fou9URqF9hymvGmJmp0tqWel2oHaY2SBJoyUdIWl3Sd0lja1oo1ATiuo7URQV9EfSMkkDMuWrJN0jabqkDZL+Jenrkv5H0nuS3pZ0pPPcsyUtyjx2qaRzU8e+RNJKSSskDZcUSeqRqWst6deS3pL0rqRbJG2XpY09JM2TtF7SGkl/ceomZdr1oaT5kvo5dYW+nrmSxkt6NnO8ByXtlKnrlmn/tpm4naTfZ17ffyT9UlKLAn7u7SQtltTXPW4t/aHvlL/vSPqzpGud+AhJqyrdF+g7TbPvhBihDJU0TdJXJL0g6VHFI5/dJF0t6Vbnse9JGiKpreI3eYKZfUeSzOwoSaMkDVD8xvRPnec6xT/kb2Xqd5N0ZZY2jZM0O9OmzpJ+49Q9lznGTop/YPeY2ZeKfD2SdKakcyR9VdJmSZOztOmPmfoekr4t6UjFnVdm1tXM1plZ1yzPlaRrJf1O0qocj6k19J3S953ekl5y4pckdTKz9lkeXyvoO9XYdwJ8UvibUzdU0kfKZEBJX1acLXfMcqwHJI3MlP8gaXwq20eZv03Sx5L2cOoPlPTvLMedKmmKpM55vJ61kvYt5vUo/qRwnfP4vSRtktRCzicFSZ0kbZTzyUbSqZL+kefPfH9JL2aOlRy3Ep8UG/OHvlORvvOmpKOcuGXm2N0q3R/oO02v74QYobzrlD+VtCaKoi1OLEltJMnMBpvZ02b2gZmtk3S0pA6Zx+yqeGhXxy13lLS9pPmZjLpO0qzMv9fnEsWd4dnM1SHn1FWY2cVmtsjM1meO085pQ0Gvp552Llf8Q3ePJ8Xzjy0lrXTaf6uknbO0P2Fm20i6WfF/gM0NPb7G0Hc+F7zvZHyk+JN5nbryhjyfX63oO5+rmr5TtiuFMlcH3Kd4qPZgFEX/NbMHFL8BUjzH19l5ShenvEbxD7V3FEX/aehcURStkvTjzHkPlvSYmT2ueHh4ieK5wIVRFG01s7VOG4rhtrOrpP9m2uv++9uKPyl0KCIptFU8QvmLmUnxpxBJesfMvh9F0RNFtbqG0HeK7juStFDSvpJmZOJ9Jb0bRdH7RRyr5tB3ytt3ynmVVyvFC1yrJW02s8GK5/PqzJB0tpn1MrPtJV1RVxFF0VZJtyme+9xZksxst8xVCF9gZt83s7pOslbxMG2r4qHj5kwbtjWzK+Vn4GKcYWZ7Zdp8taR7nU8Wde1fqXhu9QYza2tm25jZHmZ2aB7HX6/4U9S3Mn+Ozvz7fpKeaWTbawV9p7i+I8XTMMMy59lR0uWK59WbC/pOGftO2RJKFEUbJF2g+A1cK+k0SQ859TMVLyz9Q9ISSU9nqjZm/v5F3b+b2YeSHpP0jSyn6yPpGTP7KHOOkVEULVW80DVL8dVSyyV9Jn/oWIxpin/IqyR9KfMa63Om4s79quLXf6/iTy51i2Mf1bc4FsVW1f1R3Cml+JPCpka2vSbQd4rrO5IURdEsSdcr/tm8lWn7mEa2u2bQd8rbdyyz2FJ1zKyXpFckta7WtQMzmytpehRFt1e6LfgcfQfFou80TlV9sdHMTjCz1mb2FUn/K+n/qvVNRXWh76BY9J1wqiqhSDpX8TXjb0raImlEZZuDGkLfQbHoO4FU7ZQXAKC2VNsIBQBQo0goAIAgCvpio5kxP1aFoihqzBekSo5+U7XWRFGU7VvfVYG+U7Xq7TuMUIDma3mlG4CaVW/fIaEAAIIgoQAAgiChAACCIKEAAIIgoQAAgiChAACCIKEAAIIgoQAAgiChAACCIKEAAIIgoQAAgiChAACCIKEAAIIgoQAAgiChAACCKOgGW9VmxIgRXtyhQwcvXrJkSVK+6667ytImAKjPfvvt58WHHHJIUl65cqVXd/fdd5elTaExQgEABEFCAQAEQUIBAARR9Wso++yzjxfPnj07KafXTMzMi2fNmpWUWUMBUGp9+/ZNyl26dPHqbr31Vi9u27ZtUt60aZNXt9dee3nxv/71r6R8zz33NLqdpcIIBQAQBAkFABBE1U959enTx4s7duxY9jacddZZXjx58uSk/OGHH3p1xxxzTFJ++eWXS9swVJX27dt7cfoy0Xnz5iXljRs3lqVNCCv9Hg8cONCL3d8NO+20k1eXnpKPoigpt2rVyqu79NJLs7bhwgsv9OIXX3wxKd98881e3SuvvJL1OKXACAUAEAQJBQAQBAkFABBEza2hlEPv3r29+JZbbvFid76zTZs2Xt1BBx2UlFlDqU1DhgxJygMGDPDq+vfv78XuPPiXv/xlr65bt25evGjRoqS8ZcsWr+7Pf/5zUr7++usLai9Ky31fp06d6tUNGjSo3M3RAQcc4MXf/e53k/JJJ53k1R133HFe/PTTT5euYWKEAgAIhIQCAAiChAIACKLq1lB69uzpxT/4wQ/K3obRo0d7cfoacdSeMWPGJOVzzz0352Pd7w9su63/XyTXdwka0qtXr6zHufrqq5PyCSec4NUde+yxXrx69eq8z4nGe+ihh5Jyv379gh3X/f7I9OnT837eDTfckLUu/T2Zzp07F9yuxmCEAgAIgoQCAAii6qa8Ro4c6cXt2rUr+lgvvfRSUc9LX/6ZyyeffOLFb7zxRlHnROFat26dlN3LtSV/mkLyL+/eunVr0efcZhv/M1ixx0ofp2XLlknZvQxUkm666SYvHj58eFLesGFDUefH58444wwvvvPOO73Yfa/S7/d1113nxT169EjKJ598slc3ceJEL3anOdevX593e9PHmTZtWlI+7bTTvLr0DsdLly5NygsWLMj7nPlihAIACIKEAgAIgoQCAAii6tZQ0vOOhVixYoUX33777Y1tToOefPJJL54zZ07Jz9lcpbeDHz9+fFI+/PDDcz7Xnfsu5FLfV1991YvTW6Y8/PDDSfn111/P+7iHHHKIF5944olJeccdd/Tq0ttpPP/880n5V7/6Vd7nxOfOPvvspOxuOS99sX+89tprSTm99Ur6El53PSN9nIsuuqi4xjbgt7/9bVLu1KmTV5f+f/HTn/40KZ9zzjnB28IIBQAQBAkFABAECQUAEERVrKGMGzcuKRfyvZP03Gf6+uzly5c3ql35mDFjRsnP0Zy525Ck3++vfvWrQc7x/vvve/HPfvazpHzfffd5dZs2bQpyzvRWG+53ak4//fScz01/9wkNO+KII7x4woQJSXm77bbz6pYtW+bFRx11VFJu6HeK20fTa1+l4m5Jf+WVV3p16TWUM888MymzhgIAqFokFABAEBWZ8jrrrLO82N3dN70lRS4LFy70Ync31/pi1+677+7Fw4YNS8ru9gkNueKKK7x4xIgRWR+bvuTU3fLhzTffzPuczYm7PUWoKa61a9d68SmnnOLFc+fODXKeUnH7nHvJKLLr3r27F7tb8aT/7x1zzDFeXMjU+b///e+kXIltcdJtTW8/te+++5b0/IxQAABBkFAAAEGQUAAAQVRkDaVv375eXMi6iSu9NXMlpNdi0nEu7jYt3bp1C9WkJmXNmjXBj3n//fd7cbWvmaR16NCh0k2oOZdffnnWussuu8yLlyxZUvR53O2fjj/+eK/uuOOO8+IHH3yw6PNkk/79U+o1kzRGKACAIEgoAIAgSCgAgCDKtobizmH++Mc/Ltdpq1qXLl0q3YSq9+677yblffbZx6vr3LlzUp45c2bO47z33ntJedKkSYFal1t6Sw93+4/Bgwd7denb0KLx3NsdpG/r/dRTTyXle++9tyTnnzdvXkmOm0v//v292MzKen5GKACAIEgoAIAgyjbl5e4o7N49r9aldybdYYcdknLHjh3L3Jqm54c//GHWOnfKq6G7MLp31kvfhbEx3GmtgQMHenUXX3yxFx900EFJOT0VUchdJO+8885CmthsuXcnbNu2rVf34osvlrk1pTNkyJCknN5tON2v3N/DpcAIBQAQBAkFABAECQUAEETZ1lBWrVqVlHfeeee8n/fBBx948TvvvFPU+W+//XYvTs9hn3/++Um5oe3r3bv2XXTRRV7d0qVLkzLbqdS+9GWY6Tv/uXfEO+CAA0rShpUrV3pxus+hfu7dCdNrCek7ZtaS9NZV7lcyWrVq5dU9++yzXjxlypTSNUyMUAAAgZBQAABBkFAAAEFYIde/m1n+D07p2bNnUt577729OndO8OWXX/bq0t8ZeP7554ttQk7uVhgPP/xwzse63z3ZY489StKeQkRRVN79FQrUmH6Ti7udRnqueM899/RidwsXdxuWhqRvI92yZUsvLuT/jyvX91DSaybpbc8XLFhQ1DnrMT+Kov1DHawUGtN33Ftup9+n733ve0n5mWeeKfYUJdO+fXsvdr+PNXnyZK/OfW0bN2706k4++WQvbmiLogLU23cYoQAAgiChAACCKNtlw6+99lq9Zal0u30W4oILLsj7se6WKoceeqhXV4kdRpurDRs2JOUnn3zSq0tPee2yyy5JuVOnTqVtWB5+/etfe7E7FZHeFmTdunVlaFHzcuyxxyblapjyGjlypBefd955Xpzuz9mMHTvWiwNOceWFEQoAIAgSCgAgCBIKACCIsq2hNCXuFvXdu3f36lhDqYz0XUDT26B885vfTMqNuX3CNtv4n8Hc9Y1p06Z5den563LPZzd37rY4Dz30kFc3evTopLzrrrt6demvLriXaacv6W7Tpo0Xp9dUXTfeeKMXF9sPX3jhBS92LyOeOnVqUccMhREKACAIEgoAIAimvNAk7bPPPl7s7rK6//7Ffzn88ccf9+KJEycm5fTdO1FZ7vRzeir66KOPTso/+tGPch7n448/Tsrpb6K3aNHCi9u1a5f1OOkprly7LDz11FNe7O6ynr6keP369VmPU26MUAAAQZBQAABBkFAAAEGwhpLxyCOPJOUjjzzSq0vPfbrzmdWwbQMa9pOf/KTSTUAFTZgwwYu7du2alNO7n6e5lwa7Xxko1IoVK7zYXXP7zW9+49XNmTPHi99///2iz1tOjFAAAEGQUAAAQZBQAABBlO2OjdXO3d78jjvu8OrS2+v//ve/L0ub8tVc79iIRmvSd2zMxb0j4mGHHebVDR061Iv79++flNPrqcuXL/fim266Kes533jjDS9+6aWX8mprleKOjQCA0iGhAACCYMqrCWDKC0VqtlNeaDSmvAAApUNCAQAEQUIBAARBQgEABEFCAQAEQUIBAARBQgEABEFCAQAEQUIBAARBQgEABEFCAQAEQUIBAARBQgEABEFCAQAEsW2Bj18jaXmDj0I57V7pBuSBflOd6DsoVr19p6D7oQAAkA1TXgCAIEgoAIAgSCgAgCBIKACAIEgoAIAgSCgAgCBIKACAIEgoAIAgSCgAgCBIKACAIEgoAIAgSCgAgCBIKACAIApOKGa20Mz6h29K02Jm3cwsMrNCbxHQZNF38kPf+SL6Tn4q3XcKTihRFPWOomhuCdqCHMysu5k9bGYbzGyNmV1f6TYVir5TfmZ2i5l95PzZaGYbKt2uQtF3KsPMLjSzVWb2oZn9wcxa53o8U141wMxaSfqbpL9L2kVSZ0nTK9oo1IQois6LoqhN3R9Jd0m6p9LtQvUzs0GSRks6QvENtbpLGpvzSVEUFfRH0jJJAzLlqxR3zumSNkj6l6SvS/ofSe9JelvSkc5zz5a0KPPYpZLOTR37EkkrJa2QNFxSJKlHpq61pF9LekvSu5JukbRdljb2kDRP0nrFd3z7i1M3KdOuDyXNl9TPqSv09cyVNF7Ss5njPShpp0xdt0z7t83E7ST9PvP6/iPpl5Ja5Pkz/4mkJwp9r6rtD32n/H0n9dp2yLTt0Er3BfpO9fcdSX+WdK0THyFpVa7nhBihDJU0TdJXJL0g6VHFI5/dJF0t6Vbnse9JGiKpreI3eYKZfUeSzOwoSaMkDVD8xvRPnec6xT/kb2Xqd5N0ZZY2jZM0O9OmzpJ+49Q9lznGTop/YPeY2ZeKfD2SdKakcyR9VdJmSZOztOmPmfoekr4t6UjFnVdm1tXM1plZ1yzP7StpmZnNzEx3zTWzb2Z5bC2h75S+77hOkrRa0uN5PLba0XdK33d6S3rJiV+S1MnM2md5fJARyt+cuqGSPlImA0r6suJsuWOWYz0gaWSm/AdJ41PZPsr8bZI+lrSHU3+gpH9nOe5USVMkdc7j9ayVtG8xr0fxJ4XrnMfvJWmTpBZyPilI6iRpo5xPNpJOlfSPPH/msyX9V9JgSa0k/VzxJ61Wlfi0WOwf+k75+06qvXMkXVXpfkDfqY2+I+lNSUc5ccvMsbtle06IEcq7TvlTSWuiKNrixJLURpLMbLCZPW1mH5jZOklHS+qQecyuiod2ddxyR0nbS5qfyajrJM3K/Ht9LlHcGZ7NXB1yTl2FmV1sZovMbH3mOO2cNhT0eupp53LFP3T3eFI8/9hS0kqn/bdK2jlL+9M+lfRkFEUzoyjapHgI3l5SrzyfX63oO58rVd+pa3tXxZ++pxbyvCpG3/lcqfrOR4pHdXXqylkv6ijbpWWZqwPuUzxUezCKov+a2QOK3wApnuPr7Dyli1Neo/iH2juKov80dK4oilZJ+nHmvAdLeszMHlc8PLxE8VzgwiiKtprZWqcNxXDb2VXxSGJN6t/fVvxJoUMURZuLOMfLkr5XdAtrHH2nUX2nzo8k/b8oipY24hg1h77TqL6zUNK+kmZk4n0lvRtF0fvZnlDOq7xaKV7gWi1ps5kNVjyfV2eGpLPNrJeZbS/pirqKKIq2SrpN8dznzpJkZrtlrkL4AjP7vpnVdZK1iodpWxUPHTdn2rCtmV0pPwMX4wwz2yvT5qsl3et8sqhr/0rF01Y3mFlbM9vGzPYws0PzPMd0SX3NbICZtZD0M8WdZ1Ej214r6DvF9506ZyqeT29u6DvF952pkoZlzrOjpMvVQB8qW0KJomiDpAsUv4FrJZ0m6SGnfqbihaV/SFoi6elM1cbM37+o+3cz+1DSY5K+keV0fSQ9Y2YfZc4xMvPJ7FHFQ9bFioeJn8kfOhZjmuIf8ipJX8q8xvqcqbhzv6r49d+r+JNL3eLYR9kWx6Ioel3SGYqvMFkr6ThJx2amv5o8+k7xfSfzmAMVfwpvdpcL03ca9XtnlqTrFf9s3sq0fUyuRllmsaXqmFkvSa9Iat3IoX7JmNlcSdOjKLq90m3B5+g7KBZ9p3Gq6ouNZnaCmbU2s69I+l9J/1etbyqqC30HxaLvhFNVCUXSuYqvGX9T0hZJIyrbHNQQ+g6KRd8JpGqnvAAAtaXaRigAgBpFQgEABFHQFxvNjPmxKhRFUWO+IFVy9JuqtSaKomzf+q4K9J2qVW/fYYQCNF/LK90A1Kx6+w4JBQAQBAkFABAECQUAEAQJBQAQBAkFABAECQUAEAQJBQAQBAkFABBE2W4B3JTsvffeSXno0KFZHzds2DAv3rTJvx/WXnvtFbZhAFBBjFAAAEGQUAAAQZBQAABBsIaSh3PPPdeLx44dm5Q3bNjg1c2cOTMpP/LII17de++9V4LWAWhqpk2b5sWDBg1KyqNGjfLqpk+fXpY25YMRCgAgCBIKACCIgu4p35RudtOyZUsvPvXUU5Py+PHjvbpOnTp58dKlS5PykCFDvLrFixeHamLeuMEWijQ/iqL9K92IXJpL3+nZs6cXP/fcc168/fbbJ+UnnnjCq+vfv3/J2pVDvX2HEQoAIAgSCgAgCBIKACCIZnPZ8P77+9N9e+65pxffcccdSdnMX5J48803vdhdN6nEmgmApuX+++/3YnfNRPJ/J/Xr18+ru+yyy5LyNddcU4LW5Y8RCgAgCBIKACCIJn3Z8A477JCU00PKAQMGZH3e1KlTvfiqq67y4uXLlze+cQE1l8uGZ8yY4cUnnXRSUcdZvXq1F0+aNKnoNpXD7NmzvXj+/PmhDs1lwxV0wgknJOV7773Xq0v/XnanvNJ1L7zwQlLu06dPyCbmwmXDAIDSIaEAAIIgoQAAgmhSlw23adPGiydOnJiU02smW7Zs8WJ3XvqXv/ylV1dtaybN1VNPPeXFJ554YlHH6dChgxePGzcu7+emLykvZA2yWCNHjvTiXXbZpeTnROO5ayTSFy/p/cY3vpGU0/0qLVe9u1acvtz4k08+abCdITFCAQAEQUIBAARBQgEABFHTayju3KHkr5lI0tlnn52UN2/e7NVdf/31XnzFFVeEbRyCmzJlihcffvjhSXnTpk1e3fr1673YXTc5+uijS9C6cNLz3uedd16FWoJCuesm6e+zpdc33PW3v/71r17d8ccfn/Uc6XU7dy0mvQ3+ggULcjc4MEYoAIAgSCgAgCBqesrrhhtu8GJ3iivtpptu8mKmuGrPxx9/7MVnnHFGUv7000+9uvQU2HbbbZeUu3btmvM8rVu3TspXXnmlV/f444/n11hJjz76qBe7xzrllFOyPm/MmDFe/MADD+R9TpSXu9Ov5F+Cnp6aevvtt73Y7b9PPvlkzvPst99+SfnZZ5/16hq65LicGKEAAIIgoQAAgiChAACCqOk1lG9/+9s56996662knL7kFLUvfWlwLu4ay+uvv573804++eSC2uTq37+/F+fqr8uWLUvK06ZNK/qcKK2OHTt68fDhw73YXTd59dVXvbr0elxD6yauXr161XuOasMIBQAQBAkFABAECQUAEERNr6E05MILL0zKhcybAyGMGjXKi7/+9a9nfeySJUuS8po1a0rWJjTOvHnzvDj9nSZ3qxP3949U2JpJ2sEHH5yU0987cc9Z7q1W0hihAACCIKEAAIKouSkvdzfPPn36eHUrV6704sWLF5elTYAk9ejRw4vdSz3T0lvFvPjii6VoEgJwf+e4O/tKX7yE97bbbkvKjZniytWG9DkXLVoU7DyNxQgFABAECQUAEAQJBQAQRM2tofTt2zcpp+cS3a1WJH/rg4EDB3p16blQ17XXXuvF//znP734pJNOSsrpLdXRfP3973/34l133TXrY9O3XrjqqqtK0SQUYffdd/fiW265JSmnL9lN32mx2C2e0nefTd/t0d3yZfXq1V5d+vdVJTFCAQAEQUIBAARBQgEABFFzayhDhw7NWtehQwcvnj17dlI+8MADvbrtt98+63HS86Tp9ZcZM2Yk5dNOO82rK2RLddQe9/bAknTNNdck5Xbt2uV87ssvv5yUX3nllbANQzA33nijF7dv3z4pp9dtQ61fjB492ouPO+44L3bPmz7na6+9FqQNITBCAQAEQUIBAARRc1NeuXTv3j1nHMpRRx2VlHfZZRevjimvpm3mzJlefMghh2R97MKFC734mGOOScrpbYJQOen30N3mRPKnm9J3XWzM7r7ueS677LKs50yfd9KkSUWfs9QYoQAAgiChAACCIKEAAIJoUmsouWzcuNGL58yZ48Vjx45NyuPGjfPqjjzyyNI1DDXlsMMO8+KtW7dmfax7F0aJdZNqdfzxx3txev3Cjd3LxAuVXptxt1dJnzN9nsact5wYoQAAgiChAACCIKEAAIKo+jWU9PpFly5d8n7uZ599lpR/8YtfeHW/+93vvPjSSy9NyoMGDfLq0vOba9asqfccaBoOPfTQpDxq1CivLr1m4vaNP/7xj16duy6H6uJuUX/66ad7demtl5544omizpFeM0mvg7jbP6W/zzJ58uSizllpjFAAAEGQUAAAQVT9lFd6V+Btt82/yStWrEjKjz32mFeX3t3TvWNeesj7wQcfePEpp5ySlJcvX553e1CdBg8e7MV33313Uk7fSS/trrvuSsoXX3yxV7du3brGNw4lMXz48KTs7iYsNW5H4WnTpiXl9OXI6d9l7nnSfdCdVq8ljFAAAEGQUAAAQZBQAABBVP0aygMPPODFV199dVIeM2aMV9eyZUsvdrevT28lnsuGDRu8+Oabb/biuXPn5n0sVD93TUzKvW4yb948Lx42bFhS3rRpU9iGoWR23nnnpJxeM01fwuve+XX//ff36u68804v7tWrV1JOr8Wkj+uum9TqmkkaIxQAQBAkFABAEFU/5ZU2fvz4pDxw4ECv7uCDD/biFi1aZD1Oejh63333JeUJEyZ4dU8//XTB7UR1OfbYY5OyuyuCJH3ta1/L+zjLli3z4i1btjSqXagM9/9/+ndBz549vfi5555Lyt/5zneyHicd33///V7diBEjvLipTHO5GKEAAIIgoQAAgiChAACCqLk1FNfhhx/uxT//+c+9+LrrrkvK99xzj1f3zDPPeHF63QS1rXfv3l7srpukL/3M5Z133vFidz5dYg2lVq1evToppy8bbtOmjRe76ybpx6bXQf70pz8l5fSWLU1xzSSNEQoAIAgSCgAgCBIKACAIS19HnfPBZvk/GGUTRZE1/KjKKUe/GTp0qBffcccdXrzjjjtmfe7mzZu92N2m59RTT/XqFi9eXGQLq9L8KIryX1CqgFL1HXcr+alTp3p1/fr182L3d+Rtt93m1aXjt956K1QTq129fYcRCgAgCBIKACAIpryagOY65eVezjlr1iyvbqeddsr7OBMnTvTi9J0Xm7BmO+WFRmPKCwBQOiQUAEAQJBQAQBA1vfUKkC93i5QpU6Z4delLPwEUhxEKACAIEgoAIAgSCgAgCNZQULMWLFiQlKdPn+7VXXDBBV48f/78pHz++eeXtmFAM8UIBQAQBAkFABAEW680Ac116xU0GluvoFhsvQIAKB0SCgAgCBIKACCIQi8bXiNpeSkagqLtXukG5IF+U53oOyhWvX2noEV5AACyYcoLABAECQUAEAQJBQAQBAkFABAECQUAEAQJBQAQBAkFABAECQUAEAQJBQAQxP8HBKH1tZvvgs0AAAAASUVORK5CYII=\n",
120 | "text/plain": [
121 | ""
122 | ]
123 | },
124 | "metadata": {},
125 | "output_type": "display_data"
126 | }
127 | ],
128 | "source": [
129 | "''' Step 2。 展示样本数据 '''\n",
130 | "\n",
131 | "\n",
132 | "def show_sample_image():\n",
133 | " # 使用 iter() 从 DataLoader 中取出 迭代器, next() 选取下一个迭代器\n",
134 | " x, y = next(iter(train_loader))\n",
135 | " # 输出数据的 shape,以及输入图片的最小最大强度值\n",
136 | " print(x.shape, y.shape, x.min(), x.max())\n",
137 | " # 使用自己封装的 polt_image() 函数对图片进行展示\n",
138 | " plot_image(x, y, 'image sample')\n",
139 | " \n",
140 | "show_sample_image() "
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 5,
146 | "metadata": {},
147 | "outputs": [
148 | {
149 | "name": "stdout",
150 | "output_type": "stream",
151 | "text": [
152 | "True\n",
153 | "Net(\n",
154 | " (fc1): Linear(in_features=784, out_features=256, bias=True)\n",
155 | " (fc2): Linear(in_features=256, out_features=64, bias=True)\n",
156 | " (fc3): Linear(in_features=64, out_features=10, bias=True)\n",
157 | ")\n",
158 | "----------------------------------------------------------------\n",
159 | " Layer (type) Output Shape Param #\n",
160 | "================================================================\n",
161 | " Linear-1 [-1, 1, 256] 200,960\n",
162 | " Linear-2 [-1, 1, 64] 16,448\n",
163 | " Linear-3 [-1, 1, 10] 650\n",
164 | "================================================================\n",
165 | "Total params: 218,058\n",
166 | "Trainable params: 218,058\n",
167 | "Non-trainable params: 0\n",
168 | "----------------------------------------------------------------\n",
169 | "Input size (MB): 0.00\n",
170 | "Forward/backward pass size (MB): 0.00\n",
171 | "Params size (MB): 0.83\n",
172 | "Estimated Total Size (MB): 0.84\n",
173 | "----------------------------------------------------------------\n"
174 | ]
175 | }
176 | ],
177 | "source": [
178 | "''' Step 3。 搭建网络模型 '''\n",
179 | "\n",
180 | "class Net(nn.Module):\n",
181 | " # 网络初始化\n",
182 | " def __init__(self):\n",
183 | " super(Net, self).__init__()\n",
184 | " # y = wx + b\n",
185 | " # 三层全连接层神经网络\n",
186 | " self.fc1 = nn.Linear(28 * 28, 256)\n",
187 | " self.fc2 = nn.Linear(256, 64)\n",
188 | " self.fc3 = nn.Linear(64, 10)\n",
189 | "\n",
190 | " # 定义神经网络前向传播逻辑\n",
191 | " def forward(self, x):\n",
192 | " # x : [b, 1, 28, 28]\n",
193 | " # h1 = relu(w1x + b1)\n",
194 | " x = F.relu(self.fc1(x))\n",
195 | " # h2 = relu(w2x + b2)\n",
196 | " x = F.relu(self.fc2(x))\n",
197 | " # h3 = w3h2 + b3\n",
198 | " x = self.fc3(x)\n",
199 | " # 直接返回向量 [b, 10], 通过 argmax 即可得到分类预测值\n",
200 | " return x \n",
201 | " '''\n",
202 | " 也可直接将向量经过 softmax 函数得到分类预测值\n",
203 | " return F.log_softmax(x, dim = 1)\n",
204 | " '''\n",
205 | " \n",
206 | "# 使用 summary 函数之前,需要使用 device 来指定网络在 GPU 还是 CPU 运行\n",
207 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"gpu\") \n",
208 | "print(torch.cuda.is_available())\n",
209 | "net = Net().to(device)\n",
210 | "# 所有的张量都需要进行 `.to(device)` \n",
211 | "print(net)\n",
212 | "summary(net, (1, 28 * 28))\n",
213 | "# summary(your_model, input_size=(channels, H, W))\n",
214 | "# input_size 要求符合模型的输入要求, 用来进行前向传播"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 6,
220 | "metadata": {},
221 | "outputs": [
222 | {
223 | "name": "stdout",
224 | "output_type": "stream",
225 | "text": [
226 | "epoch:0, iteration:0, loss:0.12122594565153122\n",
227 | "epoch:0, iteration:10, loss:0.09926386922597885\n",
228 | "epoch:0, iteration:20, loss:0.08644484728574753\n",
229 | "epoch:0, iteration:30, loss:0.07972756773233414\n",
230 | "epoch:0, iteration:40, loss:0.07488133013248444\n",
231 | "epoch:0, iteration:50, loss:0.07044463604688644\n",
232 | "epoch:0, iteration:60, loss:0.06637724488973618\n",
233 | "epoch:0, iteration:70, loss:0.06356173753738403\n",
234 | "epoch:0, iteration:80, loss:0.05986356735229492\n",
235 | "epoch:0, iteration:90, loss:0.05715459585189819\n",
236 | "epoch:0, iteration:100, loss:0.05593841150403023\n",
237 | "epoch:0, iteration:110, loss:0.05120903253555298\n",
238 | "epoch:1, iteration:0, loss:0.05355915054678917\n",
239 | "epoch:1, iteration:10, loss:0.050536610186100006\n",
240 | "epoch:1, iteration:20, loss:0.04767081141471863\n",
241 | "epoch:1, iteration:30, loss:0.048206694424152374\n",
242 | "epoch:1, iteration:40, loss:0.04414692148566246\n",
243 | "epoch:1, iteration:50, loss:0.04389065504074097\n",
244 | "epoch:1, iteration:60, loss:0.042195286601781845\n",
245 | "epoch:1, iteration:70, loss:0.044043783098459244\n",
246 | "epoch:1, iteration:80, loss:0.04229837656021118\n",
247 | "epoch:1, iteration:90, loss:0.0413779579102993\n",
248 | "epoch:1, iteration:100, loss:0.04181772097945213\n",
249 | "epoch:1, iteration:110, loss:0.03836136683821678\n",
250 | "epoch:2, iteration:0, loss:0.03855973109602928\n",
251 | "epoch:2, iteration:10, loss:0.038034625351428986\n",
252 | "epoch:2, iteration:20, loss:0.03956493362784386\n",
253 | "epoch:2, iteration:30, loss:0.03763895854353905\n",
254 | "epoch:2, iteration:40, loss:0.03816075250506401\n",
255 | "epoch:2, iteration:50, loss:0.03620471805334091\n",
256 | "epoch:2, iteration:60, loss:0.03597486764192581\n",
257 | "epoch:2, iteration:70, loss:0.03391267731785774\n",
258 | "epoch:2, iteration:80, loss:0.0356239415705204\n",
259 | "epoch:2, iteration:90, loss:0.035082779824733734\n",
260 | "epoch:2, iteration:100, loss:0.031901001930236816\n",
261 | "epoch:2, iteration:110, loss:0.03308051824569702\n"
262 | ]
263 | },
264 | {
265 | "data": {
266 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtxUlEQVR4nO3dd5iU9dX/8feBpYhgoYhLURBRqnU12AsqYAGNEjQ+RqOJjz3YEow+iqZYYixJ/NkVEwsYjAGNBBtojChNFAERFAhNKSICSt3z++PMZGaX2WWBnZ3Z3c/ruua6+8yZ+1rm8L2/zdwdERGR0urkOgAREclPShAiIpKREoSIiGSkBCEiIhkpQYiISEYFuQ6gsjRv3tzbtWuX6zBERKqVSZMmLXP3FpmO1ZgE0a5dOyZOnJjrMEREqhUzm1fWMT1iEhGRjJQgREQkIyUIERHJqMbUQYiIVMSGDRtYsGABa9euzXUoVaphw4a0adOGevXqVfgaJQgRqVUWLFhAkyZNaNeuHWaW63CqhLuzfPlyFixYQPv27St8nR4xiUitsnbtWpo1a1ZrkgOAmdGsWbOtLjUpQYhIrVObkkPStnxnJQhg40Z4+GFYsybXkYiI5A8lCODll+GSS+CCC3IdiYhISY0bN87ZZytBAO+9F8vhw+HLL3Mbi4hIvlCCAN55J7U+e3bu4hCRmm/QoEE88MAD/90ePHgwv/71r+nZsycHHXQQ3bt3Z8SIEZtdN3bsWE499dT/bl9xxRUMGTIEgEmTJnHMMcdw8MEH06tXLxYvXlwpsWa1mauZ9QbuB+oCj7n7HaWOHw3cB+wHnO3uwxP7DwAeBHYCNgG/cfdh2Yhx3TqYOBF69YLRo2HRomx8iojko4EDYcqUyn3PAw6A++4r+/iAAQMYOHAgl19+OQDPP/88o0eP5qqrrmKnnXZi2bJl9OjRg759+1aoYnnDhg1ceeWVjBgxghYtWjBs2DBuvPFGnnjiie3+LllLEGZWF3gAOBFYAEwws5HuPj3ttP8AFwDXlbr8W+BH7j7LzFoBk8xstLt/XdlxfvVVJIcBAyJBLFxY2Z8gIpJy4IEHsmTJEhYtWsTSpUvZdddd2X333bn66qt5++23qVOnDgsXLuTLL79k99133+L7zZw5k48//pgTTzwRgE2bNlFYWFgpsWazBHEoMNvdPwcws6FAP+C/CcLd5yaOFadf6O6fpq0vMrMlQAvg68oOsrAQRowAd7jwQiUIkdqkvP/pZ1P//v0ZPnw4X3zxBQMGDOCZZ55h6dKlTJo0iXr16tGuXbvN+iwUFBRQXJz6qUwed3e6du3KuHHjKj3ObNZBtAbmp20vSOzbKmZ2KFAf+KyS4irjc6B1az1iEpHsGzBgAEOHDmX48OH079+flStXsttuu1GvXj3GjBnDvHmbj8C95557Mn36dNatW8fXX3/NG2+8AcC+++7L0qVL/5sgNmzYwLRp0yolzrweasPMCoG/AOe7e3GG4xcDFwPsscce2/15rVurBCEi2de1a1dWrVpF69atKSws5Nxzz+W0006je/fuFBUV0alTp82uadu2LT/4wQ/o1q0b7du358ADDwSgfv36DB8+nKuuuoqVK1eyceNGBg4cSNeuXbc7TnP37X6TjG9sdhgw2N17JbZvAHD32zOcOwR4OVlJndi3EzAW+G36/rIUFRX59k4YdM45UWE9a9Z2vY2I5LEZM2bQuXPnXIeRE5m+u5lNcveiTOdn8xHTBKCjmbU3s/rA2cDIilyYOP9F4M8VSQ6VpVWreMSUpZwpIlKtZC1BuPtG4ApgNDADeN7dp5nZbWbWF8DMDjGzBUB/4GEzSz44+wFwNHCBmU1JvA7IVqxJrVvDt9/CypXZ/iQRkfyX1ToId38FeKXUvpvT1icAbTJc9zTwdDZjy6R1ogp94ULYZZeq/nQRqSruXusG7NuW6gT1pE7TqlUsVVEtUnM1bNiQ5cuXb9MPZnWVnA+iYcOGW3VdXrdiqmrJEoSauorUXG3atGHBggUsXbo016FUqeSMcltDCSKNShAiNV+9evW2ala12kyPmNI0bAhNmypBiIiAEsRm1JtaRCQoQZSi3tQiIkEJopQ994TPPlNnORERJYhS9tsPVqyABQtyHYmISG4pQZSy//6x/PDD3MYhIpJrShCl7LdfLCt7likRkepGCaKUJk2gQwclCBERJYgM9t4bMszXISJSqyhBZJAc9ltEpDZTgsigsBC+/BI2bcp1JCIiuaMEkUGrVpEcatlYXiIiJShBZFBYGMvFi3Mbh4hILilBZKAEISKiBJFRcthvVVSLSG2mBJHB7rvHUiUIEanNlCAyaNAAWrSAzz/PdSQiIrmjBFGGo4+G117TqK4iUnspQZTh5JNjXoiPPsp1JCIiuaEEUYY+fWL5z3/mNg4RkVxRgihDYSF06QJjx+Y6EhGR3FCCKMexx8I778DGjbmORESk6ilBlOPYY2H1apg0KdeRiIhUPSWIchQVxXLatNzGISKSC0oQ5WjbFurWhTlzch2JiEjVU4IoR0EBtGmjBCEitZMSxBa0b68EISK1kxLEFrRrB3Pn5joKEZGqpwSxBe3bx6iua9fmOhIRkaqlBLEF7drFct68nIYhIlLllCC2YJ99Yvnpp7mNQ0SkqmU1QZhZbzObaWazzWxQhuNHm9lkM9toZmeVOna+mc1KvM7PZpzl2XffWH7ySa4iEBHJjawlCDOrCzwA9AG6AOeYWZdSp/0HuAB4ttS1TYFbgO8BhwK3mNmu2Yq1PLvuCi1bKkGISO2TzRLEocBsd//c3dcDQ4F+6Se4+1x3/wgoLnVtL+A1d//K3VcArwG9sxhruTp1UoIQkdonmwmiNTA/bXtBYl+lXWtmF5vZRDObuHTp0m0OdEs6dYIZMzR5kIjULtW6ktrdH3H3IncvatGiRdY+p1s3WLEC5s/f8rkiIjVFNhPEQqBt2nabxL5sX1vpjjgilu+8k6sIRESqXjYTxASgo5m1N7P6wNnAyApeOxo4ycx2TVROn5TYlxP77QdNmihBiEjtkrUE4e4bgSuIH/YZwPPuPs3MbjOzvgBmdoiZLQD6Aw+b2bTEtV8BvyKSzATgtsS+nKhbFw4/HP71r1xFICJS9Qqy+ebu/grwSql9N6etTyAeH2W69gngiWzGtzX23x/GjIHiYqhTrWtuREQqRj91FdSuHaxfD198ketIRESqhhJEBe25Zyw1squI1BZKEBWUTBAatE9EagsliApSghCR2kYJooIaN4ZmzfSISURqDyWIrbDnnvD557mOQkSkaihBbIUjj4S33oIlS3IdiYhI9ilBbIVLLommrr/9bfSHEBGpyZQgtkLnzvDjH8P998dLRKQmU4LYSo8/Dj16wJAhuY5ERCS7lCC2khmcfTZ89JEmERKRmk0JYhv07RvLsWNzGoaISFYpQWyDPfeEHXaAmTNzHYmISPYoQWyDOnVgn32UIESkZlOC2Eb77qsEISI1mxLENtp33xh2Y926XEciIpIdShDbaN99o7OcWjKJSE2lBLGNjjsOCgrgz3/OdSQiItmhBLGNWrWCM8+MjnOrVuU6GhGRyqcEsR2uvRZWroQ//SnXkYiIVD4liO1wyCFw8snw+9+rslpEah4liO30s5/B8uVwxRXQvj08+WSuIxIRqRxKENvphBOiZ/Vjj0Wz19Gjcx2RiEjlKMh1ANVdnTrRkmnOHHjuOTV7FZGaQyWISnD00XD++dC1a/Sunjcv1xGJiGw/JYhK1KkTrF0L7dpppFcRqf6UICpRp06p9VdfzV0cIiKVQQmiEh1wABx4YKy/9VZOQxER2W5KEJWoSROYPBkGDYLx42HJklxHJCKy7ZQgsuC886J1009+kutIRES2nRJEFnTpArfeCi+9FCUKEZHqSAkiSy65JKYlffDBXEciIrJtlCCyZJdd4H/+B/7yF1i4MNfRiIhsvawmCDPrbWYzzWy2mQ3KcLyBmQ1LHH/fzNol9tczs6fMbKqZzTCzG7IZZ7bccANs2gR33AGTJmlAPxGpXraYIMyspZk9bmajEttdzOyiClxXF3gA6AN0Ac4xsy6lTrsIWOHuewP3Ancm9vcHGrh7d+Bg4H+TyaM6ad8ezj0XHnkEiorgN7/JdUQiIhVXkRLEEGA00Cqx/SkwsALXHQrMdvfP3X09MBToV+qcfsBTifXhQE8zM8CBHc2sANgBWA98U4HPzDtXXAHr18f6o4/Chg25jUdEpKIqkiCau/vzQDGAu28ENlXgutbA/LTtBYl9Gc9JvO9KoBmRLNYAi4H/AHe7+1elP8DMLjaziWY2cenSpRUIqeoVFcWjpoED4YsvYMSIXEckIlIxFUkQa8ysGfG/esysB/FDnk2HEkmoFdAeuNbM9ip9krs/4u5F7l7UokWLLIe07X77W7j77hgWvH9/+P734dtvcx2ViEj5KpIgrgFGAh3M7N/An4ErK3DdQqBt2nabxL6M5yQeJ+0MLAd+CPzT3Te4+xLg30BRBT4zb9WtCz/9aay/+CLcfHNu4xER2ZItJgh3nwwcAxwO/C/Q1d0/qsB7TwA6mll7M6sPnE0kmnQjgfMT62cBb7q7E4+Vjgcwsx2BHkC1n2nh+uvhlVfg4ovhnnvg3XdzHZGISNksfo/LOcHsR5n2u/uft/jmZicD9wF1gSfc/Tdmdhsw0d1HmllD4C/AgcBXwNnu/rmZNQaeJFo/GfCku/+uvM8qKiryiRMnbimkvLBqFXTrBi1bxphNIiK5YmaT3D3jE5qKJIg/pm02BHoCk939rMoLcftVpwQB8MAD0cLpvPPg1FPhrLNi/CYRkaq0XQkiw5vtAgx1996VEFulqW4J4ptvoHVrWL06tv/wBzjzTGjVqvzrREQqU3kJYlv+z7qGaFkk22GnnWDCBPj002gKe9VVkTCWLct1ZCIioSI9qV8ys5GJ18vATODF7IdW83XqBB07xvwRSW+8kbt4RETSFVTgnLvT1jcC89x9QZbiqZXOPBOWL4cOHeC112DAgNj/4YdRmV23bm7jE5HaaYsJwt01eWYVaNoUevaEYcPghBNg//1jCtNnn4Vzzsl1dCJSG5X5iMnMVpnZNxleq8ysWo6LlO/uuisG+LvhBnjnndinCYdEJFfKTBDu3sTdd8rwauLuO1VlkLXFXnvBBRfA3LmpMZs+/jiXEYlIbVaROggAzGw3oh8EAO7+n6xEVMsddlgs//GPWE6blrtYRKR2q0grpr5mNguYA7wFzAVGZTmuWuugg1LrhYUwfz68/npsL10avbBFRKpCRfpB/IoYC+lTd29P9KR+L6tR1WINGsQMdLffHvNHAJx4Ijz2GOy2GxxzTG7jE5HaoyKPmDa4+3Izq2Nmddx9jJndl+3AarNf/CKW7jGg3+DBqZFgP/ggZ2GJSC1TkRLE14nB8/4FPGNm9xO9qSXLzKJO4k9/Krn/669zEo6I1DIVSRBjiHkafgb8E/gMOC2bQUlJHTvCRRdBo0ax/a9/wdq1uY1JRGq+iiSIAuBVYCzQBBjm7suzGZRs7rHH4LPPYr1vX2jeXPNJiEh2VWTCoFvdvStwOVAIvGVmr2c9MtnM7rvHZENXXx09ry+7DDZVZHZwEZFtUOF+EMAS4AtiStDdshOObMnDD8dyv/3gxz+GiRPhe9/LbUwiUjNVpB/EZWY2FngDaAb81N33y3ZgUr4+fWL5lkbKEpEsqUgdRFtgoLt3dffB7j4920HJlrVsCZ07w9NPq1WTiGRHReogbnD3KVUQi2yl446DqVPhiCOiz4SISGXSLMjV2C23xBDh06dHBzolCRGpTEoQ1dhuu8UjJoCDD4Zzz4VLL4WFC3Mbl4jUDEoQ1dzuu8NZZ8X6c8/BQw/B9dfH9n/+E01hR2loRRHZBlvTzFXy1F//Ct99BwMHxvDgzz0Hhx4Kv/xl7J89O9XqSUSkolSCqCF22CH6SNx/f2xffz20agXf/z689150qHv33ShRFBfnNlYRqR6UIGqYrl2hbl3YuDGGCT/99JhD4s03o7XTgw9qljoRqRgliBqmYUPo1CnWi4pSM9T165c6JznftYhIeZQgaqD9949lURF06BD1Ed99F01iCwvh3//ObXwiUj0oQdRAp5wCXbrEywwefzzqI26+GY48EkaPhjlz4Lzz4LXXch2tiOQr8xrSu6qoqMgnTpyY6zDy3pQpcPzxsPPOMHdu7KshfwIisg3MbJK7F2U6phJELXPAAdGZLpkcAL74IlfRiEg+U4KohY47ruT23XfH8r77YvgOERFQR7la6fDDoV496NEjRoS9996owL766jg+eHDUXYhI7aYEUQs1agS33hrNYY8/Hl5+GU4+OXV8/nzYY4/cxSci+SGrj5jMrLeZzTSz2WY2KMPxBmY2LHH8fTNrl3ZsPzMbZ2bTzGyqmTXMZqy1zQ03wBlnRGX1kCHRX+LCC+PYlCm5jExE8kXWEoSZ1QUeAPoAXYBzzKxLqdMuAla4+97AvcCdiWsLgKeBSxLzYR8LbMhWrLXdiSfGMBz33x+PlraUIKZOVcsnkdogmyWIQ4HZ7v65u68HhgL9Sp3TD3gqsT4c6GlmBpwEfOTuHwK4+3J335TFWAVo3DgeO735JowdC2vWbH7OpEkxH/brr1d5eCJSxbKZIFoD89O2FyT2ZTzH3TcCK4l5r/cB3MxGm9lkM/t5pg8ws4vNbKKZTVy6dGmlf4Ha6Mc/jnmujzsOTjsN1q8veXzSpFhqPCeRmi9fm7kWAEcC5yaWZ5hZz9Inufsj7l7k7kUtWrSo6hhrpJ/+FHbaCdq1gzFjYnC/Rx6Bt9+O41OnxvKzz3IWoohUkWy2YloItE3bbpPYl+mcBYl6h52B5URp4213XwZgZq8ABwFvZDFeAXbZBebNiyTRoUPMMQExnPj06amSw+ef5ypCEakq2SxBTAA6mll7M6sPnA2MLHXOSOD8xPpZwJseY3+MBrqbWaNE4jgGmJ7FWCXNLrtAnToxlwTAL34Ry+9/P+omoGQJYsMGaNIkxnwSkZojawkiUadwBfFjPwN43t2nmdltZtY3cdrjQDMzmw1cAwxKXLsCuIdIMlOAye7+j2zFKpndeisMHw633w5Dh0bJAqBlyxjsb1Oi2cCiRbB6dQzhISI1hwbrkwpbty7GcBozJpLBwIFwzz0wblxMRgRq/ipS3ZQ3WJ96UkuFNWgA++4LTZvCP/4RYzfNmgWLF+c6MhHJBiUI2WotWkTv6+bNI1Gk++abqOCGeARVp47GdRKprvK1mavkuWbNYhKi0j79NLVev36qoltEqh8lCNlm77wD111Xct/kyVEPsXYtFBfD3/8e+0t3uBOR/KcEIdusfXu45JLUdps28OqrMWx4kyap/ePHx/aMGVUfo4hsO9VByHZp3z613qtX5r4QL70UJYh//St6aG/cWDKBiEh+UglCtkudOjByZPSwPu202NeyZclzXn01lh98EEnkhBMikdxxB6xYUfZ7FxdnJ2YRqRiVIGS7JRNDly5RSb333pE4ksaPj+WQIVE3kb6vQYPUTHbpRo2KSYw++SSa1opI1VMJQiqNGXTsGMv33ksN0QExrtPatbDPPiWvmT49c0nhlVdi+be/ZS9eESmfEoRkxfe+B2edFet16sCNN8bsdSNHRvPYHj3g6KPhscdi7Kf5iYHhk62fmjeP7UmTQCO5i+SGEoRkTWFhLA8+OOaZWLYsHhe9/DI8/TR07hzHV62C//f/4Nxzo/L6yivhq6/i2AsvxARFmzRdlEiVUx2EZE3LltCoUVRKAxQk/toOPDCW6Y+b7rgD6taNWe0efLDksS++iPkojjsutW/t2ui1vdtu2f0OIrWZShCSNQUFMGEC3HRT5uN77RXLU06Bvn1jXux586Kj3cyZ8Rhq9eqYi2L48JLXXnJJJKBly7L7HURqMyUIyaouXaIUkUm/fjG39UsvwYgRcOihUU/ROjExbfPmsOOOkUBeeCEeM61YEQnkqcRM5n36wGuvVc13EaltlCAkZ8ygZ8/NB/Pr0CGWyYrq/v3hyy+j70TTplEBnjxv4kT41a9ie80auPxyuPjisj/zz3/evDQiIpkpQUjeST56SiaIU06Jx0z/+7+xPWFCLP/+d7j2Wnj//aiTePTRqOx+9FFYWHpy24Sbb4Zbbon1WbM0f4VIeZQgJO8ke2LvuGNqmWwy27o1vPUW/OEP0LVrNJVdvz6SxiefpN6jTRu47LKY7e6EE2K5alXUcUyfHlOn7rNPqr+FiGxOCULyTnI+idWrU/uS05l27BhJ4cor49HUkUfG/n//O0oEBx2UuuYvf4Gf/xzeeAP++tdIDEn33x/LMWOy9z1EqjslCMk7mZrB9ugBDz0ETz5Z8tymTWMCozlzIkF06RLNZSESzDPPxHq9ejFeVFJyGPLf/x7OO0/jPolkogQheadPnxir6ac/Te0zizqIdu02P79Nm0gO8+dHCWPq1Bg5Nn1Co0cfhZ/8JNbT+1NAdNrLVGfxzDPwgx8oeUjtpQQheemQQyo+VWmbNtGRDiJBdO4cj57Gjo0KboApU2L5s5/BD3+4+XvMmgW//GU8inKPQQd/8pPYHjlye7+NSPWkBCHVXps2qaE40kd+7dgxhvXYf//YPuwwuO++1DSoTz8Nc+fG+rRp8LvfwZ/+FElh331TI8/+4Q/w7LPxGKu0NWvikdjYsRWLtWdP+M1vtvILiuSIhtqQaq9Nm9R6166bH999d/jww1T/iqZNU81bi4uhYcPorLdxY7SGeuihktdPmBCV2c2bbz5w4LRpUTp5/XU49tjy43SHceNSlfAi+U4lCKn2kgmiSZOYX6K0XXeNZTJBpKtTJ+avSPbG/u67ki2bjj8+1Zpq2bLN6yM+/TSWn3225ThXr473X7Jky+eK5AMlCKn2kkNzlJ7JLin5qCjZAa+0jh0333fAAbFM9r9ImjIlShpJM2fGsiIJ4osvYpkshcyaBf/3f+qsJ/lLCUKqveSw4r17Zz6+Zk0smzXLfPyii2K5005RB/HJJ1Gn8I9/lHxsVLduDPvRrFn0u4CtK0F8+WUskyWIU06BX/86VQ8ikm9UByHVXpcu8YN++OGZj99yS3SSK+v4KafAm2/G46n0c04+Gdati8dQXbrEeXfeGYmiX7/opzFuXJz71VcxkGDycVYmyQSxcmU0yZ01K7bnzoX27bfmG4tUDZUgpEY45pjoDJfJEUfAggXl/3gfd1zmBNKgQTSbPfjgGMfpkUdigEBIjQmVnPjonXdi6I677so8wVEyQQDssUdqfc6ceGzVuXNqnCiRfKAShMgWvPZaDFneqFGq896778LXX6c64x1+OJxxBtSvHxXRM2fGmFH/939w/vlxTnqCSDdnTtRtfPIJ3HYbnHZaNJ1N9gjPpLg4SjaZXHttDDly7rnb8m1FUlSCENmCwsKYpyLdPvvE/BVt2sTro4+iDqRRo3gU9cQTUS9xwQXx6AlSldRJPXvGDHpz5pRsOXXVVTHZ0gUXZI7n2mujye2mTbB4cXQKTA4jsm5d9Nt4/PHK+OZS25nXkCYURUVFPjFZ9hfJAff4gZ4wIQYUhPhf/hlnwHPPQa9e8N57UcJInn/ssVHSAOjUKeonFi9OvefSpalhzzduhA8+iMQEMHRo1K3cdlvM+f3EEzBpEhQVxeO05csr3htdai8zm+TuRZmOqQQhUknMotPdEUfEI6e77oof7xdegNNPj1JCstlssr4k+eNfWBitqdIHKIQomST/Dzd0aCo5QHToe+yx1LGVK2Hy5NhesSIqwjOZODGGSBfZEiUIkUpWp040g73+ejjzzNj3yiuRNJ56KoYgf//92H/HHanBAq+7LpUgevWK5aRJ0SfjuutKjkYLUfJYtChGo/3uu6jHmDQpdfyFF2Dw4FQ/EIgBCA85BB5+uOR7/fznMHr01n3Phx+uWPNeqcbcPWsvoDcwE5gNDMpwvAEwLHH8faBdqeN7AKuB67b0WQcffLCL5JsNG9zr13cH9yuv3PL5v/tdnHvnne4tW7o3bhzb4N6wYWr9pJNS62PGxLJPH/eCAvfjj3dv1Ch1fMSIVCy77Rb7Lroo9q1e7b52rbuZe8+eqThGjXK///6y4/z663ifyy/f5lsjeQKY6GX8rmatBGFmdYEHgD5AF+AcM+tS6rSLgBXuvjdwL3BnqeP3AKOyFaNIthUURN0CwH77bfn85GCD3brFIIOrV8fjpyZNSpYEBgyIZWEhHHVUPLIaNSqGQ3/hhRju/MYb45z334+xov74x1QnvWnTYuDCxo2jCa97jIi7alUcv//+uL6sKspk577x4yt4I6RayuYjpkOB2e7+ubuvB4YC/Uqd0w94KrE+HOhpFtVqZnY6MAeYlsUYRbKuW7dYViRB9OkTlc0nnRSjy15/fUySlKz0TurXL/poHHZYNIdNdrQ7/njYZZdo5vrrX0dz2d/+Fk48Ea65Js754Q+jbuPVV2N7WuJf2IYN8OCDcPbZURm+enX0H8lk3rxYTpkSFfNSM2UzQbQG0qvJFiT2ZTzH3TcCK4FmZtYY+AVwa3kfYGYXm9lEM5u4tPQwmyJ5okePKAFkGmm2tIKCaJFUUBAJ5a67oj5i773j+I9+FAmjWbOoT7jttpLXp1diQ2ogw/TPPuEE+PbbKHGkM4s5MYYNS/XZSJ+mNV0yQWzYECPlbo2lSyPxPfBA+eddc03U2Uju5Gsl9WDgXndfXd5J7v6Iuxe5e1GLFi2qJjKRrXTppTB7Nuy447a/x//8TywvuSTVP+LMM1M//MkRZ4tKNVa8/PKo5B4zBn71q2jtdNRRqeN9+qTWTzxx8x7gyQRxzz3QvXsqccybl+qol6xwr6hnnoG774YrroBvvsl8zrp1MS5W6cp0qVrZ7Em9EGibtt0msS/TOQvMrADYGVgOfA84y8zuAnYBis1srbv/KYvximRFQQHsttv2vUdRUfSeLqtfwzPPRL1B6VJKr16pFlE33bT5dT/4QZQkmjaNgQiTj52SZsyIIUSuvTa2n3giWmPNnRuj4H7zTcl6iFmzorS0++7RsuqddyLxuMeyf/+SLa3mzYvEU9rUqVE6mTw5muTWr5/5e7tHf49kc2GpZGXVXm/vi0g+nwPtgfrAh0DXUudcDjyUWD8beD7D+wxGrZhEKtWLL7q3b+++YoX7Dju4779/rP/4x+4tWkQLpaOOcm/a1P3MM6N1VLJV1A47uHfuHC2p+vVz32efeM/Vq1PnLFzofuGFsf7JJ+7vvhvrXbrEtcnWVC+9lDm+hx5Kvdf48e4rV5Y8d9Mm9zvucL/55jjn00+zertqNHLRismjTuEKYDQwI/HjP83MbjOzvonTHifqHGYD1wCDshWPiKScfjp8/nlUaB9wQAwUuMsuUUKYOjWGOn/00SgFvPBCDB+SbI313XdRsthvP/je92LI8xUrUp32IOoy3ngj1n/0o9RAiNOnx7XJ/iFlDXU+aVJq8qf33ovK89NOg//8J/Z98AEMGpSqg0kfqkQqUVmZo7q9VIIQ2TZLl0a/hkzGjXPv3dv97bfdTzgh9b96cF+wwH3s2Fh/7jn3PfZwP/JI98suc2/QoOS54H7qqan1V1+Nc667ruTnrV/v/vHH7m3auJ9yinuTJu5XXeV+9tlx3R13REwPPljyvc87L3P8330X38/dff5899NOc581q/LuXU1AOSWInP+wV9ZLCUIku+bPd3/00UgGQ4bEvg0bUo+kwP3ll92nT988OdxwQzwWWr7cfepU9+LieDR11lmxf/x49zVr3C++OHXN3/4Wj6TOOCOW6e930UUlt1u3jvcprW9f9513jpg6d45z7767Sm/bf82ZE4/tvvwyN59fFiUIEcma5I/1UUfFD797lCLSf8CHDt38usMOi2OlSybgvu++kXxOOsm9e3f3unVLHu/WbfNrGjRwHzw49f4TJsR+M/eddkqd17lz9Dpfu3b7vve4cZG4lizZuvtUXg/1XCgvQeRrM1cRqSZuuSWa0I4alWpl9cc/Rm/uZF3D/vtvfl3//rF8/fXUvr32in4SH30Urb/ato06kdLNb9PHpbriivjcdeti7Kni4tg/bFi0fnrppej3cdhh0U9kxoyI9Z//jA6JF164eY/xVatSc4+7x1DtS5ZE3Ujy3MceizqVF16IsbS25OuvY7lhw5bPTbdkCbz88tZdU1mUIERku7RtG01o0/t51KkT81Qceii0bBlNYku7+mr4+99jvVGj+CH/8MNospps1to2raF8stNf0hFHxLJ9+2hqmxwpNzmi7dSpqaliJ0+OH9kuaYP93HVXTAb15JNR4e0ezXRPPjnmJ7/00jjvuediBsDevaO58SWXRPIYMSKOX3ppxFY6yaRbuzYaBUDM/7E1rr02Kui3tr9JZVCCEJGsueaaaOVU1ux4xx6bSiannhpjQ6VLTxBTp8aouEnJhNC2bfxwJ3tmP/dcLD/+ONXHonv36OvRsGHq+nffjeUxx0SyuPDCKFEk5xkfOzbm5hg1Kv7X/8EHsf+pp6LUsGxZyVg/+ST6oTzzTGwnSz3LlkWSTF6fTBQVlRw/q3Sv+SpR1rOn6vZSHYRI9XTffTEibSajR8dz+z32iO0NG6JPRmFh9LU4/PBoTZV0/vlx/o03xvKuuzK/3xtvpOokiovdBw2K9R13dH//fffbbtu8jgNK7t9tN/dnn031EenVK1XH8cQTUTk+YUK01Epe07Jlqt9IWYqLY+TfG290X7XKvWvXuLagwP3bb+OcZcvcf/lL93XrtuGGl4IqqUWkOkq2iLrhhtS+bt2i814m69a5DxiQ+kF+5ZXNz9mwIZbvvBPNd5MWLIgOee7RKS89MbRoET/4a9ZEBTe433RTnLt+feq8evVKXtehQySS/v3j3J//PIZ/37gxc/y//GXJ648+OpadOsVy1Kj4vO9/P7W9vZQgRKTa+uCDkk1Y584tv+VQcbH7vfe6H3KI+1dfbdtnzp+f+pE+5hj32bPdFy1Kvf/48SX/9967d5w7fHiUbpItsZLvkWy5NGxYbPfpE623rr8+Pmv16ji+ww6pax59NLV+++2p9csvT5337LPb9v3SKUGIiGyF4uLoqPf00xU7f82a1I/8Z5+5X3tt/O8++aM+dmzqfX//ey9RSigoiCa6U6e6N2sWnQSTJZtksnn77ZLXJF+lH6Fti/IShCqpRURKMYtWT+eeW7HzGzVKteLaa68YrfbII1Mj3iab+ZpF661k81+IFlHr1sWIucuXw8CBqRF3x42Dq66K1mAjR0Ylerqy5uuoLNkczVVEpNZq3DjGr1qzJsa5SjKDv/41WicVFkLfvtHi6ckn43h6k+A994xReiGaunbpEnODdOgQY1UpQYiIVFM33RSDG5ZmFk1fX3wxShfjxqU6w+2zT9nv16EDDBkCPXvCRRdlP0FYPIKq/oqKinzixIm5DkNEZKsVF8MOO8TcF2vXpkayLc9FF0Vv8Ir04i6PmU1y96JMx1SCEBHJsTp1YNGiGAakIskBovf24sXRia9evSzFlZ23FRGRrdGsWVRsV1T37tGW6bHHUvNkVDYlCBGRaujEE2N52WUxrWxykMLKpAQhIlIN7bwztGgR67ffnmpSW5lUByEiUk0NHx6DDp5+enbeXwlCRKSaOvroeGWLHjGJiEhGShAiIpKREoSIiGSkBCEiIhkpQYiISEZKECIikpEShIiIZKQEISIiGdWY4b7NbCkwbzveojmwrJLCyTbFmh2KNTsUa3ZUVqx7unuLTAdqTILYXmY2sawx0fONYs0OxZodijU7qiJWPWISEZGMlCBERCQjJYiUR3IdwFZQrNmhWLNDsWZH1mNVHYSIiGSkEoSIiGSkBCEiIhnV+gRhZr3NbKaZzTazQbmOpzQzm2tmU81siplNTOxramavmdmsxHLXHMb3hJktMbOP0/ZljM/CHxL3+iMzOygPYh1sZgsT93eKmZ2cduyGRKwzzaxXFcbZ1szGmNl0M5tmZj9L7M+7+1pOrHl3XxOf3dDMxpvZh4l4b03sb29m7yfiGmZm9RP7GyS2ZyeOt8uDWIeY2Zy0e3tAYn/l/x24e619AXWBz4C9gPrAh0CXXMdVKsa5QPNS++4CBiXWBwF35jC+o4GDgI+3FB9wMjAKMKAH8H4exDoYuC7DuV0Sfw8NgPaJv5O6VRRnIXBQYr0J8Gkinry7r+XEmnf3NfH5BjROrNcD3k/cs+eBsxP7HwIuTaxfBjyUWD8bGJYHsQ4BzspwfqX/HdT2EsShwGx3/9zd1wNDgX45jqki+gFPJdafAk7PVSDu/jbwVandZcXXD/izh/eAXcyssEoCpcxYy9IPGOru69x9DjCb+HvJOndf7O6TE+urgBlAa/LwvpYTa1lydl8BEvdodWKzXuLlwPHA8MT+0vc2ec+HAz3NzHIca1kq/e+gtieI1sD8tO0FlP/HnQsOvGpmk8zs4sS+lu6+OLH+BdAyN6GVqaz48vV+X5Eokj+R9rguL2JNPNI4kPjfY17f11KxQp7eVzOra2ZTgCXAa0Qp5mt335ghpv/Gmzi+EmiWq1jdPXlvf5O4t/eaWYPSsSZs972t7QmiOjjS3Q8C+gCXm1mJKco9ypZ521Y53+MDHgQ6AAcAi4Hf5zSaNGbWGHgBGOju36Qfy7f7miHWvL2v7r7J3Q8A2hCll065jahspWM1s27ADUTMhwBNgV9k6/Nre4JYCLRN226T2Jc33H1hYrkEeJH4g/4yWXRMLJfkLsKMyoov7+63u3+Z+EdYDDxK6nFHTmM1s3rED+4z7v63xO68vK+ZYs3X+5rO3b8GxgCHEY9jCjLE9N94E8d3BpZXbaQlYu2deKzn7r4OeJIs3tvaniAmAB0TLRjqE5VQI3Mc03+Z2Y5m1iS5DpwEfEzEeH7itPOBEbmJsExlxTcS+FGitUUPYGXaI5OcKPWM9gzi/kLEenaiFUt7oCMwvopiMuBxYIa735N2KO/ua1mx5uN9TcTVwsx2SazvAJxI1JuMAc5KnFb63ibv+VnAm4nSW65i/STtPwlG1JWk39vK/TvIdk18vr+Imv9PieeQN+Y6nlKx7UW0+PgQmJaMj3gG+gYwC3gdaJrDGJ8jHiFsIJ55XlRWfETrigcS93oqUJQHsf4lEctHiX9ghWnn35iIdSbQpwrjPJJ4fPQRMCXxOjkf72s5sebdfU189n7AB4m4PgZuTuzfi0hUs4G/Ag0S+xsmtmcnju+VB7G+mbi3HwNPk2rpVOl/BxpqQ0REMqrtj5hERKQMShAiIpKREoSIiGSkBCEiIhkpQYiISEZKECKVzMwGmlmjXMchsr3UzFWkkpnZXKIN+rJcxyKyPVSCENkOid7u/0iM2f+xmd0CtALGmNmYxDknmdk4M5tsZn9NjFuUnOvjLov5Psab2d65/C4ipSlBiGyf3sAid9/f3bsB9wGLgOPc/Tgzaw7cBJzgMejiROCatOtXunt34E+Ja0XyhhKEyPaZCpxoZnea2VHuvrLU8R7EJDn/TgzbfD6wZ9rx59KWh2U7WJGtUbDlU0SkLO7+aWJqx5OBX5vZG6VOMWIc/3PKeosy1kVyTiUIke1gZq2Ab939aeB3xJSmq4jpNwHeA45I1i8k6iz2SXuLAWnLcVUTtUjFqAQhsn26A78zs2JilNhLiUdF/zSzRYl6iAuA59Jm/rqJGEEYYFcz+whYB5RVyhDJCTVzFckRNYeVfKdHTCIikpFKECIikpFKECIikpEShIiIZKQEISIiGSlBiIhIRkoQIiKS0f8HyZqskw1tn6EAAAAASUVORK5CYII=\n",
267 | "text/plain": [
268 | ""
269 | ]
270 | },
271 | "metadata": {
272 | "needs_background": "light"
273 | },
274 | "output_type": "display_data"
275 | }
276 | ],
277 | "source": [
278 | "''' Step 4. 在训练集上进行训练 '''\n",
279 | "\n",
280 | "\n",
281 | "def MNIST_trains(net):\n",
282 | " # 选择 SGD 随机梯度下降算法作为优化方法,导入网络参数、学习率以及动量\n",
283 | " optimizer = optim.SGD(net.parameters(), lr = learning_rate, momentum = momentum)\n",
284 | "\n",
285 | " train_loss = []\n",
286 | "\n",
287 | " for epoch in range(n_epochs):\n",
288 | " for batch_idx, (x, y) in enumerate(train_loader):\n",
289 | " # 将数据 x 打平\n",
290 | " # x: [b, 1, 28, 28] -> [b, 784]\n",
291 | " x = x.view(x.size(0), 28 * 28).to(device)\n",
292 | " # 经过神经网络 [b, 784] -> [b, 10]\n",
293 | " out = net(x).to(device)\n",
294 | " # 将数据的真实标签 y 转换为 one hot 向量\n",
295 | " y_one_hot = one_hot(y).to(device)\n",
296 | " # 计算 网络预测值 out 与 真实标签 y 的 mse 均方差\n",
297 | " # loss = mse(out, y_one_hot)\n",
298 | " loss = F.mse_loss(out, y_one_hot)\n",
299 | " # zero grad 清空历史梯度数据\n",
300 | " optimizer.zero_grad()\n",
301 | " # 进行反向传播,计算当前梯度\n",
302 | " loss.backward()\n",
303 | " # 根据当前梯度更新网络参数\n",
304 | " # w' = w - lr * grad\n",
305 | " optimizer.step()\n",
306 | " # 保存当前的损失函数值\n",
307 | " train_loss.append(loss.item())\n",
308 | " # 每 10 步 输出一次数据查看训练情况\n",
309 | " if batch_idx % 10 == 0:\n",
310 | " print(f\"epoch:{epoch}, iteration:{batch_idx}, loss:{loss.item()}\")\n",
311 | " # 绘制损失函数图像\n",
312 | " # [w1, b1, w2, b2, w3, b3]\n",
313 | " plot_curve(train_loss)\n",
314 | " \n",
315 | "MNIST_trains(net)"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 7,
321 | "metadata": {},
322 | "outputs": [
323 | {
324 | "name": "stdout",
325 | "output_type": "stream",
326 | "text": [
327 | "test_acc: 0.8927\n"
328 | ]
329 | }
330 | ],
331 | "source": [
332 | "''' Step 5. 在测试集中进行测试 '''\n",
333 | "\n",
334 | "\n",
335 | "def MNIST_tests(net):\n",
336 | " # 在测试集中预测正确的总数\n",
337 | " total_correct = 0\n",
338 | " # 迭代所有测试数据\n",
339 | " for x, y in test_loader:\n",
340 | " # 将图片 x 打平\n",
341 | " x = x.view(x.size(0), 28 * 28).to(device)\n",
342 | " # 经过已经训练好的神经网络 net\n",
343 | " out = net(x).to(device)\n",
344 | " # 预测值 pred: argmax 返回指定维度最大值的索引\n",
345 | " # out [b, 10] -> pred [b]\n",
346 | " pred = out.argmax(dim = 1).to(device)\n",
347 | " # 计算预测值等于真实标签的样本数量\n",
348 | " correct = pred.eq(y.to(device)).sum().float().item()\n",
349 | " # 计算预测正确样本的总数\n",
350 | " total_correct += correct\n",
351 | " # 总样本数即为测试集的长度\n",
352 | " total_num = len(test_loader.dataset)\n",
353 | " # 计算正确率\n",
354 | " acc = total_correct / total_num\n",
355 | " # 输出测试正确率 acc\n",
356 | " print(\"test_acc:\", acc)\n",
357 | " \n",
358 | "MNIST_tests(net)"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 8,
364 | "metadata": {},
365 | "outputs": [
366 | {
367 | "data": {
368 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAELCAYAAAD+9XA2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAYc0lEQVR4nO3dfZAU1bnH8d+BICDeKBJUlAgBb1AkvAlKCMaXICiCooBQkFzRimhKolUoJIq5vpcJ3jIxiWCSKgrFRFOCWiiIUClepFBLuIqKEC8YXiyXhBVBlpcIy7l/zNj2aXeGndkz/bL7/VR1eZ49vT3PDsd59nT3njbWWgEA0FDNkk4AANA4UFAAAF5QUAAAXlBQAABeUFAAAF5QUAAAXlBQAABeZKqgGGO2GGMGezjORGPMqhK/Z70xpia0HTbGvNjQXBCPhMfO/xhj/s8Ys9cYs9EY818NzQPxSXjsXGOMWW2M2W+MWd7QHCrta0knkBXW2rO/aBtjjKQPJT2bXEbIkH2SRkj6QFJ/SYuNMZustauTTQsZsEvSbySdKeniZFOpB2ttJjZJcyUdkXRAUo2kafmvD5C0WtJuSeskXRj6nonKffDvlfQPSRMknSXpoKTa/HF2l5HLBfljtkn6fWHL1tjJH3uBpNuSfl/YsjN2JP1Y0vKk34+j5pl0AiW+qVskDQ7Fp0n6RNIw5U7fXZKP20tqI+kzSd3y+3aQdHboH3xV5NjjJb1TzzxmS5qT9PvBlsmx01pSlaRLk35P2LIzdrJSUDJ1DaUOP5S0yFq7yFp7xFq7VNIa5f6hpdxvFj2MMa2ttVXW2vWFDmSt/Yu1tufRXtAYc6yk0ZLmNDx9JCj2sZP3uHK/0b7SkOSRqKTGTuplvaB0kjTGGLP7i03SIEkdrLX7JI2VdJOkKmPMQmPMmR5e82rlzmuu8HAsJCf2sWOMeVhSD0nX2PyvncikJD53MiFrBSX6P+F2SXOttSeEtjbW2l9KkrX2FWvtJcpNOzdK+lOB45TiWklP8oGQOYmOHWPMvZIukzTEWvtZeT8CEpKGz51MyFpB+aekLqH4KUkjjDFDjTHNjTGtjDEXGmM6GmNONsZcaYxpI+nfyl0IOxI6TkdjzDGlvLgxpqOkiyQ90fAfBTFLbOwYY+5Q7lz5YGvtJ35+HMQoybHT3BjTSrk7cpvlX6uFnx+rApK+iFPKJulKSduUu7Pi9vzXzlPu9NMuSTslLZR0unK/HayQtCe//3JJ3fPfc0x+v12SqvNfmyBp/VFe/w5Jryb9PrBla+wo95vpFx8uX2x3Jv2esGVi7EzMj5/wNifp96TQZvJJAwDQIFk75QUASCkKCgDACwoKAMALCgoAwAsKCgDAi5JWGzbGcEtYCllrTdI5FMO4Sa1qa237pJMohrGTWnWOHWYoQNO1NekEkFl1jh0KCgDACwoKAMALCgoAwAsKCgDACwoKAMALCgoAwAsKCgDACwoKAMCLkv5SHmgMbr/9didu3bq1E/fs2TNojx49uuixZs2aFbRfe+01p2/u3LnlpghkEjMUAIAXFBQAgBcUFACAFyU9U56VP9OJ1YaP7q9//WvQPtp1kXJt3rzZiQcPHuzE27Ztq8jrNsBaa22/pJMoJg1jJw7f/va3nXjjxo1OfOuttwbt3/3ud7HkdBR1jh1mKAAALygoAAAvuG0YjVL4FJdU2mmu8OmGV155xenr0qWLE48YMSJod+3a1embMGGCEz/00EP1zgFNS58+fZz4yJEjTvzRRx/FmU7ZmKEAALygoAAAvKCgAAC84BoKGoV+/dw7GK+66qqC+65fv96Jr7jiCieurq4O2jU1NU7fMccc48Svv/560O7Vq5fT165duyIZA1/q3bu3E+/bt8+Jn3/++RizKR8zFACAFxQUAIAXqTjlFb6l84YbbnD6Pv74Yyc+ePBg0P7zn//s9O3YscOJN23a5CtFpFyHDh2c2Bh38YDwaa6hQ4c6fVVVVfV+ndtuu82Ju3fvXnDfhQsX1vu4aHp69OgRtCdPnuz0ZXWlamYoAAAvKCgAAC8oKAAAL1JxDWXGjBlBu3PnzvX+vhtvvNGJ9+7d68TR20PjEF4iIfxzSdKaNWviTqfJePHFF534jDPOcOLw2Ni1a1fZrzNu3DgnbtGiRdnHQtN25plnBu02bdo4fdGlg7KCGQoAwAsKCgDACwoKAMCLVFxDCf/tSc+ePZ2+DRs2OPFZZ50VtPv27ev0XXjhhU48YMCAoL19+3an75vf/Ga98zt8+LAT79y5M2hH//4hLPqEPq6hxGfr1q1ejjN16lQnjj5ZL+yNN94oGgNh06ZNC9rR8ZrVzwpmKAAALygoAAAvUnHK629/+1ud7bosXry4YF/btm2dOLyC59q1a52+/v371zu/8HIvkvTBBx8E7egpuRNPPDFob968ud6vgfQYPnx40L7vvvucvuhqw//617+C9h133OH07d+/vwLZIauifxIRXiE7/JkifXW14axghgIA8IKCAgDwgoICAPAiFddQfPn000+deNmyZQX3Pdq1mmJGjRoVtKPXbd59992gndXlE5q68Lnt6DWTqPC/8YoVKyqWE7LvggsuKNgX/lOELGOGAgDwgoICAPCCggIA8KJRXUOplJNOOsmJZ86cGbSbNXNrcvjvFhqyTDri88ILLzjxkCFDCu775JNPOvFdd91ViZTQCH3nO98p2Bd91EVWMUMBAHhBQQEAeMEpr3q4+eabnbh9+/ZBO3qr8t///vdYckL5oitEDxw40IlbtmwZtKurq52+Bx54wIlramo8Z4fGIrzauSRdd911TvzWW28F7aVLl8aSU6UxQwEAeEFBAQB4QUEBAHjBNZQ6fO9733Pin//85wX3HTlypBO/9957lUgJHs2fP9+J27VrV3Dfp556yol5JAHqa/DgwU4cfrSF5D6KI/qIjKxihgIA8IKCAgDwgoICAPCCayh1GDZsmBO3aNHCicNL37/22mux5ISGueKKK4J23759i+67fPnyoH333XdXKiU0cr169XJia60Tz5s3L850YsEMBQDgBQUFAOAFp7zyWrduHbQvvfRSp+/zzz934vBpkEOHDlU2MZQleivwnXfeGbSjpzCj3n777aDN0iooxSmnnBK0zz//fKcvuizT888/H0tOcWKGAgDwgoICAPCCggIA8IJrKHlTp04N2n369HH6wkskSNLq1atjyQnlu+2225y4f//+BfeNPrGRW4VRrokTJwbt6JNeX3755ZiziR8zFACAFxQUAIAXFBQAgBdN9hrK5Zdf7sS/+MUvgvZnn33m9N13332x5AR/pkyZUu99J0+e7MT87QnK1alTp4J90ceFN0bMUAAAXlBQAABeNJlTXtGlOH772986cfPmzYP2okWLnL7XX3+9cokhcdEn6ZW7nM6ePXuKHie85Mvxxx9f8DgnnHCCE5dy+q62ttaJf/aznwXt/fv31/s4KM/w4cML9r344osxZpIMZigAAC8oKAAALygoAAAvGvU1lPB1kejyKd/61recePPmzUE7fAsxGr933nnHy3GeffZZJ66qqnLik08+OWiPHTvWy2sezY4dO4L2gw8+GMtrNiWDBg1y4vDy9U0RMxQAgBcUFACAF436lFfXrl2D9jnnnFN03/CtmeHTX8im6K3fV155ZcVfc8yYMWV/7+HDh4P2kSNHiu67YMGCoL1mzZqi+7766qtl54Sju+qqq5w4fJr9rbfecvpWrlwZS05JYoYCAPCCggIA8IKCAgDwolFdQ4mu9LlkyZKC+4af0ChJL730UkVyQjKuvvpqJ542bVrQDi+BcjRnn322E5dyu+/s2bOdeMuWLQX3nT9/ftDeuHFjvV8D8Tr22GOdeNiwYQX3nTdvnhNHl8VpjJihAAC8oKAAALygoAAAvGhU11AmTZrkxKeffnrBfVesWOHE1tqK5IR0mDFjhpfjjB8/3stxkE3RRxJEn8IY/huhRx99NJac0oQZCgDACwoKAMCLTJ/yiq70+dOf/jShTAA0BdFTXgMHDkwok3RihgIA8IKCAgDwgoICAPAi09dQzj//fCc+7rjjCu4bXZK+pqamIjkBQFPFDAUA4AUFBQDgBQUFAOBFpq+hHM26deuC9g9+8AOnb9euXXGnAwCNGjMUAIAXFBQAgBemlFV2jTEsyZtC1lqTdA7FMG5Sa621tl/SSRTD2EmtOscOMxQAgBcUFACAFxQUAIAXpd42XC1payUSQdk6JZ1APTBu0omxg3LVOXZKuigPAEAhnPICAHhBQQEAeEFBAQB4QUEBAHhBQQEAeEFBAQB4QUEBAHhBQQEAeEFBAQB4QUEBAHhBQQEAeEFBAQB4QUEBAHhBQQEAeJGpgmKM2WKMGezhOBONMatK/J6WxpjZxpjPjDE7jDFTGpoH4pPk2Al974nGmJ3lfj+SkfDnzjXGmNXGmP3GmOUNzaHSSn3AVlN2j6T/VO7BMqdIWmaMed9auzjRrJAlv5K0QRn7RQ6J2iXpN5LOlHRxsqnUg7U2E5ukuZKOSDogqUbStPzXB0haLWm3pHWSLgx9z0RJH0raK+kfkiZIOkvSQUm1+ePsrufrfyxpSCi+X9IzSb8vbOkfO/njDZT0mqTrJK1K+j1hy87YyR/zx5KWJ/1+HDXPpBMo8U3dImlwKD5N0ieShin3W98l+bi9pDaSPpPULb9vB0lnh/7BV0WOPV7SOwVet60kK+nk0NdGS3o36feELd1jJ9/fXNL/Sjqnru9nS/eW5NgJ7ZeJgpL1qfcPJS2y1i6y1h6x1i6VtEa5f2gp95tFD2NMa2ttlbV2faEDWWv/Yq3tWaD7uPx/94S+tkfSfzQwfyQnrrEjSbdIesNau9Zb9khSnGMnU7JeUDpJGmOM2f3FJmmQpA7W2n2Sxkq6SVKVMWahMebMMl+nJv/fr4e+9nXlprTIpljGjjHmVOUKynRPeSN5cX3uZE7WCoqNxNslzbXWnhDa2lhrfylJ1tpXrLWXKDft3CjpTwWOU/xFrf1UUpWkXqEv95JU8DcPpE4iY0fSufljvG+M2SHpUUnn5u8UbF72T4M4JTV2MidrBeWfkrqE4qckjTDGDDXGNDfGtDLGXGiM6WiMOdkYc6Uxpo2kfys3yzgSOk5HY8wxJbz2k5LuMsa0zf/GcYOkOQ3+iRCXpMbOy5I6S+qd3/5b0luSeltraxv6QyEWiX3ufHF85e7IbZZ/rRZ+fqwKSPoiTimbpCslbVPuzorb8187T9IK5W6v2ylpoaTTlfvtYIVy1zp2S1ouqXv+e47J77dLUnX+axMkrS/y2i0lzVbugts/JU1J+v1gy8bYieQxUVyUz9SW8OfOROVmNuFtTtLvSaHN5JMGAKBBsnbKCwCQUhQUAIAXFBQAgBcUFACAFxQUAIAXJa02bIzhlrAUstaapHMohnGTWtXW2vZJJ1EMYye16hw7zFCApmtr0gkgs+ocOxQUAIAXFBQAgBcUFACAFxQUAIAXFBQAgBcUFACAFxQUAIAXFBQAgBcUFACAFxQUAIAXFBQAgBcUFACAFyWtNpw1bdq0CdoPP/yw03fjjTc68dq1a4P2mDFjnL6tW1lDDwCOhhkKAMALCgoAwItGfcqrQ4cOQfuGG25w+o4cOeLE55xzTtAePny40/fYY49VIDskpW/fvk783HPPOXHnzp0rnsOQIUOceMOGDUF7+/btFX99pMuIESOceMGCBU48efLkoP344487fbW1tZVLrETMUAAAXlBQAABeUFAAAF40qmso7du3d+InnngioUyQZkOHDnXili1bxp5D9Jz59ddfH7THjRsXdzpIQLt27YL2zJkzi+77+9//PmjPnj3b6Ttw4IDfxBqAGQoAwAsKCgDAi0yf8rrllluceOTIkU587rnnlnXc73//+07crJlbd9etWxe0V65cWdZrIF5f+9qXQ33YsGEJZpITXplBkqZMmRK0wys8SNK+fftiyQnxCn/OdOzYsei+Tz/9dNA+ePBgxXJqKGYoAAAvKCgAAC8oKAAALzJ9DeXXv/61E0eXUynX1VdfXTQOrz48duxYpy96bhzpcNFFFwXt7373u07fjBkz4k5Hbdu2deLu3bsH7WOPPdbp4xpK4xC9PX369On1/t65c+cGbWutt5x8Y4YCAPCCggIA8IKCAgDwwpRyPs4Yk/jJu0WLFgXtyy67zOlryDWUTz75JGjX1NQ4fZ06dar3cZo3b152DuWy1prYX7QESYybHj16OPHy5cuDdvjfWnIfXSB99d+/EsL5SNKgQYOCdvixC5K0c+fOSqWx1lrbr1IH9yENnzm+9OvnvtVvvvlmwX0PHz7sxC1atKhITg1Q59hhhgIA8IKCAgDwIvW3DV9wwQVO3K1bt6AdPcVVyimv6FPPlixZErT37Nnj9F188cVOXOx2v5/85CdBe9asWfXOB37dddddThxezuTSSy91+uI4xSVJJ554YtCOjmtft7wjvUaNGlXvfcOfR1nCDAUA4AUFBQDgBQUFAOBF6q6hdO7c2YmfeeYZJ/7GN75R72OFl0iZP3++03fvvfc68f79++t1HEmaNGlS0I4+JTK8jEerVq2cvvBT1yTp0KFDBV8TpRk9erQTR5eo37RpU9Bes2ZNLDlFha+9Ra+ZhG8j3r17d0wZIU7Rx2KEff75505cyrIsacIMBQDgBQUFAOAFBQUA4EXqrqGEH9UqlXbNZMWKFU48bty4oF1dXV12TtFrKA899FDQfuSRR5y+8NLj0WXRFyxY4MSbN28uOye4xowZ48TRJeBnzpwZZzqSvno9cMKECUG7trbW6XvggQeCNtfWGoeBAwcWjcOijyh4++23K5FSxTFDAQB4QUEBAHiRulNepYje/nn99dc7cUNOcxUTPnUVPo0hSf3796/Ia+Krjj/++KA9YMCAovsmsQxO+PZyyT19u2HDBqdv2bJlseSE+JTyWdBYlmlihgIA8IKCAgDwgoICAPAi9ddQmjUrXPPOO++8GDP5kjFfPiAxml+xfO+55x4n/tGPfuQ1r6amZcuWQfu0005z+p5++um40/mKrl27Fux77733YswESYg+oTEqvMQO11AAAAihoAAAvKCgAAC8SN01lJtuusmJ0/ho1BEjRgTtPn36OH3hfKO5R6+hoGH27t0btKNLVfTs2dOJw4/f3bVrV0XyOemkk5w4uqR+2KpVqyqSA5I1aNCgoD1+/Pii+4YfNf7RRx9VLKc4MUMBAHhBQQEAeJG6U17h00lJiT6FsXv37k5855131us4O3fudGJWkfXrwIEDQTu6cvOoUaOceOHChUE7ukJ0KXr06OHEXbp0CdrR1YWttQWPk8ZTuWi4du3aBe1if0IgSUuXLq10OrFjhgIA8IKCAgDwgoICAPAidddQ0mD69OlOfPPNN9f7e7ds2RK0r732Wqdv27ZtDcoLhd19991OHF4eR5Iuv/zyoN2QZVmij0QIXycp5emic+bMKTsHpFexW8XDS61I0h/+8IcKZxM/ZigAAC8oKAAALygoAAAvuIaSt2jRoqDdrVu3so/z/vvvB22W14jPxo0bnfiaa65x4t69ewftM844o+zXmTdvXsG+J554womjj4cOC/8NDbKrY8eOTlxsuZXo8irRR5g3BsxQAABeUFAAAF6k7pRX9HbPYssXXHbZZUWP9cc//jFon3rqqUX3Db9OQ5bFSMPSMfiq8GrE0ZWJffnwww/rvW90CRee4JhNAwcOdOJin1cvvPBChbNJHjMUAIAXFBQAgBcUFACAF6m7hjJr1iwnnjFjRsF9X3rpJScudu2jlOsipez7+OOP13tfNG7R63/ROIxrJo1DeLn6qOgyPY8++mil00kcMxQAgBcUFACAF6k75fXcc8858dSpU504+jTFSog+aXHDhg1OPGnSpKBdVVVV8XyQDdEnNBZ7YiMah6FDhxbsi64uvmfPnkqnkzhmKAAALygoAAAvKCgAAC9Sdw1l69atTjxu3DgnHjlyZNC+9dZbK5LDgw8+6MSPPfZYRV4HjUurVq2K9rPCcPa1aNHCibt27Vpw34MHDzrxoUOHKpJTmjBDAQB4QUEBAHhBQQEAeJG6ayhRK1euLBgvWbLE6Qv/fYjkLiW/YMECpy+8tL3kLpMRfuoiUF/XXXedE+/evduJ77///hizQSVEl2WKPnUx/FiCTZs2xZJTmjBDAQB4QUEBAHiR+lNexSxevLhoDMTpzTffdOJHHnnEiZctWxZnOqiA2tpaJ54+fboTh5fbWbt2bSw5pQkzFACAFxQUAIAXFBQAgBemlCW2jTGsx51C1trCjwZMAcZNaq211vZLOoliGDupVefYYYYCAPCCggIA8IKCAgDwgoICAPCCggIA8IKCAgDwgoICAPCCggIA8IKCAgDwgoICAPCi1OXrqyVtrUQiKFunpBOoB8ZNOjF2UK46x05Ja3kBAFAIp7wAAF5QUAAAXlBQAABeUFAAAF5QUAAAXlBQAABeUFAAAF5QUAAAXlBQAABe/D/tHPBeDqaJOQAAAABJRU5ErkJggg==\n",
369 | "text/plain": [
370 | ""
371 | ]
372 | },
373 | "metadata": {},
374 | "output_type": "display_data"
375 | }
376 | ],
377 | "source": [
378 | "''' Step 6。 展示样本数据 '''\n",
379 | "\n",
380 | "\n",
381 | "def show_test_sample_image(net):\n",
382 | " x, y = next(iter(test_loader))\n",
383 | " out = net(x.view(x.size(0), 28 * 28).to(device)).to(device)\n",
384 | " pred = out.argmax(dim = 1).to(device)\n",
385 | " plot_image(x, pred, 'test')\n",
386 | " \n",
387 | "show_test_sample_image(net)"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 9,
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "data": {
397 | "text/plain": [
398 | "'\\ndef solve():\\n show_sample_image()\\n\\n net = Net()\\n\\n MNIST_trains(net)\\n MNIST_tests(net)\\n\\n show_test_sample_image(net)\\n\\n\\nif __name__ == \"__main__\":\\n solve()\\n'"
399 | ]
400 | },
401 | "execution_count": 9,
402 | "metadata": {},
403 | "output_type": "execute_result"
404 | }
405 | ],
406 | "source": [
407 | "'''\n",
408 | "def solve():\n",
409 | " show_sample_image()\n",
410 | "\n",
411 | " net = Net()\n",
412 | "\n",
413 | " MNIST_trains(net)\n",
414 | " MNIST_tests(net)\n",
415 | "\n",
416 | " show_test_sample_image(net)\n",
417 | "\n",
418 | "\n",
419 | "if __name__ == \"__main__\":\n",
420 | " solve()\n",
421 | "'''"
422 | ]
423 | }
424 | ],
425 | "metadata": {
426 | "kernelspec": {
427 | "display_name": "Python [conda env:.conda-pytorch] *",
428 | "language": "python",
429 | "name": "conda-env-.conda-pytorch-py"
430 | },
431 | "language_info": {
432 | "codemirror_mode": {
433 | "name": "ipython",
434 | "version": 3
435 | },
436 | "file_extension": ".py",
437 | "mimetype": "text/x-python",
438 | "name": "python",
439 | "nbconvert_exporter": "python",
440 | "pygments_lexer": "ipython3",
441 | "version": "3.8.5"
442 | }
443 | },
444 | "nbformat": 4,
445 | "nbformat_minor": 4
446 | }
447 |
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/PyTorch实现/Jupyter Notebook/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from matplotlib import pyplot as plt
3 |
4 | device = torch.device("cuda" if torch.cuda.is_available() else "gpu")
5 |
6 | def plot_curve(data):
7 | fig = plt.figure()
8 | plt.plot(range(len(data)), data, color = 'blue')
9 | plt.legend(["value"], loc = 'upper right')
10 | plt.xlabel('step')
11 | plt.ylabel('value')
12 | plt.show()
13 |
14 |
15 | def plot_image(img, label, name):
16 | fig = plt.figure()
17 | for i in range(6):
18 | plt.subplot(2, 3, i + 1)
19 | plt.tight_layout()
20 | plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap = 'gray', interpolation = 'none')
21 | plt.title("{}: {}".format(name, label[i].item()))
22 | plt.xticks([])
23 | plt.yticks([])
24 | plt.show()
25 |
26 |
27 | def one_hot(label, depth = 10):
28 | out = torch.zeros(label.size(0), depth)
29 | idx = torch.LongTensor(label).view(-1, 1)
30 | out.scatter_(dim = 1, index = idx, value = 1)
31 | return out
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/PyTorch实现/PyCharm/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from matplotlib import pyplot as plt
3 |
4 | device = torch.device( "cuda" if torch.cuda.is_available() else "gpu" )
5 |
6 |
7 | def plot_curve(data):
8 | fig = plt.figure()
9 | plt.plot( range( len( data ) ), data, color='blue' )
10 | plt.legend( ["value"], loc='upper right' )
11 | plt.xlabel( 'step' )
12 | plt.ylabel( 'value' )
13 | plt.show()
14 |
15 |
16 | def plot_image(img, label, name):
17 | fig = plt.figure()
18 | for i in range( 6 ):
19 | plt.subplot( 2, 3, i + 1 )
20 | plt.tight_layout()
21 | plt.imshow( img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none' )
22 | plt.title( "{}: {}".format( name, label[i].item() ) )
23 | plt.xticks( [] )
24 | plt.yticks( [] )
25 | plt.show()
26 |
27 |
28 | def one_hot(label, depth=10):
29 | out = torch.zeros( label.size( 0 ), depth )
30 | idx = torch.LongTensor( label ).view( -1, 1 )
31 | out.scatter_( dim=1, index=idx, value=1 )
32 | return out
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/PyTorch实现/PyCharm/手写数字识别-pytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import torch # 导入 pytorch
5 | import sklearn
6 | import torchvision # 导入视觉库
7 | import numpy as np
8 | import pandas as pd
9 | from torch import nn # 导入网络层子库
10 | from torch import optim # 导入优化器
11 | from torchsummary import summary # 从 torchsummary 工具包中导入 summary 函数
12 | from torch.nn import functional as F # 导入网络层函数子库
13 | from matplotlib import pyplot as plt
14 | from utils import one_hot, plot_curve, plot_image
15 |
16 | print(sys.version_info)
17 | for module in torch, torchvision, np, pd, sklearn:
18 | print(module.__name__, module.__version__)
19 |
20 |
21 | '''' 超参数与初始化 '''
22 | batch_size = 512 # 批大小
23 | n_epochs = 3
24 | # 学习率
25 | learning_rate = 0.01
26 | # 动量
27 | momentum = 0.9
28 | # 使用 summary 函数之前,需要使用 device 来指定网络在 GPU 还是 CPU 运行
29 | device = torch.device("cuda" if torch.cuda.is_available() else "gpu")
30 | ''' Hyperparameters ans initialize '''
31 |
32 |
33 | ''' Step 1 下载训练集和测试集,对数据进行预处理 '''
34 |
35 | # 训练数据集,从网络下载 MNIST数据集 保存至 mnist_data 文件夹中
36 | # 创建 DataLoader 对象 (iterable, 类似 list,可用 iter() 进行访问),方便批量训练,将数据标准化在 0 附近并随机打散
37 | train_loader = torch.utils.data.DataLoader(
38 | torchvision.datasets.MNIST('mnist_data', train = True, download = True,
39 | # 图片预处理
40 | transform = torchvision.transforms.Compose([
41 | # 转换为张量
42 | torchvision.transforms.ToTensor(),
43 | # 标准化
44 | torchvision.transforms.Normalize(
45 | (0.1307,), (0.3081,))
46 | ])),
47 | batch_size = batch_size, shuffle = True)
48 |
49 |
50 | test_loader = torch.utils.data.DataLoader(
51 | torchvision.datasets.MNIST('mnist_data/', train = False, download = True,
52 | transform = torchvision.transforms.Compose([
53 | torchvision.transforms.ToTensor(),
54 | torchvision.transforms.Normalize(
55 | (0.1307,), (0.3081,)) # 使用训练集的均值和方差
56 | ])),
57 | batch_size = batch_size, shuffle = False)
58 |
59 |
60 | ''' Step 2. 展示样本数据 '''
61 |
62 |
63 | def show_sample_image():
64 | # 使用 iter() 从 DataLoader 中取出 迭代器, next() 选取下一个迭代器
65 | x, y = next(iter(train_loader))
66 | # 输出数据的 shape,以及输入图片的最小最大强度值
67 | print(x.shape, y.shape, x.min(), x.max())
68 | # 使用自己封装的 polt_image() 函数对图片进行展示
69 | plot_image(x, y, 'image sample')
70 |
71 |
72 | ''' Step 3。 搭建网络模型 '''
73 |
74 |
75 | class Net(nn.Module):
76 | # 网络初始化
77 | def __init__(self):
78 | super(Net, self).__init__()
79 | # y = wx + b
80 | # 三层全连接层神经网络
81 | self.fc1 = nn.Linear(28 * 28, 256)
82 | self.fc2 = nn.Linear(256, 64)
83 | self.fc3 = nn.Linear(64, 10)
84 |
85 | # 定义神经网络前向传播逻辑
86 | def forward(self, x):
87 | # x : [b, 1, 28, 28]
88 | # h1 = relu(w1x + b1)
89 | x = F.relu(self.fc1(x))
90 | # h2 = relu(w2x + b2)
91 | x = F.relu(self.fc2(x))
92 | # h3 = w3h2 + b3
93 | x = self.fc3(x)
94 | # 直接返回向量 [b, 10], 通过 argmax 即可得到分类预测值
95 | return x
96 | '''
97 | 也可直接将向量经过 softmax 函数得到分类预测值
98 | return F.log_softmax(x, dim = 1)
99 | '''
100 |
101 |
102 | ''' Step 4. 在训练集上进行训练 '''
103 |
104 |
105 | def MNIST_trains(net):
106 | # 选择 SGD 随机梯度下降算法作为优化方法,导入网络参数、学习率以及动量
107 | optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
108 |
109 | train_loss = []
110 |
111 | for epoch in range(n_epochs):
112 | for batch_idx, (x, y) in enumerate(train_loader):
113 | # 将数据 x 打平
114 | # x: [b, 1, 28, 28] -> [b, 784]
115 | x = x.view(x.size(0), 28 * 28).to(device)
116 | # 经过神经网络 [b, 784] -> [b, 10]
117 | out = net(x).to(device)
118 | # 将数据的真实标签 y 转换为 one hot 向量
119 | y_one_hot = one_hot(y).to(device)
120 | # 计算 网络预测值 out 与 真实标签 y 的 mse 均方差
121 | # loss = mse(out, y_one_hot)
122 | loss = F.mse_loss(out, y_one_hot)
123 | # zero grad 清空历史梯度数据
124 | optimizer.zero_grad()
125 | # 进行反向传播,计算当前梯度
126 | loss.backward()
127 | # 根据当前梯度更新网络参数
128 | # w' = w - lr * grad
129 | optimizer.step()
130 | # 保存当前的损失函数值
131 | train_loss.append(loss.item())
132 | # 每 10 步 输出一次数据查看训练情况
133 | if batch_idx % 10 == 0:
134 | print(f"epoch:{epoch}, iteration:{batch_idx}, loss:{loss.item()}")
135 | # 绘制损失函数图像
136 | # [w1, b1, w2, b2, w3, b3]
137 | plot_curve(train_loss)
138 |
139 |
140 | ''' Step 5. 在测试集中进行测试 '''
141 |
142 |
143 | def MNIST_tests(net):
144 | # 在测试集中预测正确的总数
145 | total_correct = 0
146 | # 迭代所有测试数据
147 | for x, y in test_loader:
148 | # 将图片 x 打平
149 | x = x.view(x.size(0), 28 * 28).to(device)
150 | # 经过已经训练好的神经网络 net
151 | out = net(x).to(device)
152 | # 预测值 pred: argmax 返回指定维度最大值的索引
153 | # out [b, 10] -> pred [b]
154 | pred = out.argmax(dim=1).to(device)
155 | # 计算预测值等于真实标签的样本数量
156 | correct = pred.eq(y.to(device)).sum().float().item()
157 | # 计算预测正确样本的总数
158 | total_correct += correct
159 | # 总样本数即为测试集的长度
160 | total_num = len(test_loader.dataset)
161 | # 计算正确率
162 | acc = total_correct / total_num
163 | # 输出测试正确率 acc
164 | print("test_acc:", acc)
165 |
166 |
167 | ''' Step 6。 展示样本数据 '''
168 |
169 |
170 | def show_test_sample_image(net):
171 | x, y = next(iter(test_loader))
172 | out = net(x.view(x.size(0), 28 * 28).to(device)).to(device)
173 | pred = out.argmax(dim=1).to(device)
174 | plot_image(x, pred, 'test')
175 |
176 |
177 | def solve():
178 | show_sample_image()
179 |
180 | print(torch.cuda.is_available())
181 | net = Net().to(device)
182 | # 所有的张量都需要进行 `.to(device)`
183 | print(net)
184 | summary(net, (1, 28 * 28))
185 | # summary(your_model, input_size=(channels, H, W))
186 | # input_size 要求符合模型的输入要求, 用来进行前向传播
187 |
188 | MNIST_trains(net)
189 | MNIST_tests(net)
190 |
191 | show_test_sample_image(net)
192 |
193 |
194 | if __name__ == "__main__":
195 | solve()
196 |
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/TensorFlow2.0实现/Jupyter Notebook/MNIST数据集的前向传播训练误差曲线.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 03 章 分类问题与信息论基础/全部源码(已更新)/TensorFlow2.0实现/Jupyter Notebook/MNIST数据集的前向传播训练误差曲线.png
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/TensorFlow2.0实现/Jupyter Notebook/第 03 章 TensorFlow2.0 手写数字图片实战 - 神经网络.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "67527f3f",
6 | "metadata": {},
7 | "source": [
8 | "$❑\\, \\, $ **Step 0 引入文件并设置图像参数**"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "e9d66ec1",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%matplotlib inline\n",
19 | "# “在 jupyter notebook 在线使用 matplotlib”,为 IPython 的内置 magic 函数,在 Pycharm 中不被支持\n",
20 | "import os\n",
21 | "import matplotlib.pyplot as plt\n",
22 | "import tensorflow as tf\n",
23 | "import tensorflow.keras.datasets as datasets\n",
24 | "\n",
25 | "plt.rcParams['font.size'] = 16\n",
26 | "plt.rcParams['font.family'] = ['STKaiti']\n",
27 | "plt.rcParams['axes.unicode_minus'] = False"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "44cd7f0c",
33 | "metadata": {},
34 | "source": [
35 | "$❑\\, \\, $ **Step 1 导入数据并对数据进行处理**\n",
36 | "\n",
37 | " 利用 TensorFlow 自动在线下载 MNIST 数据集,并转换为 Numpy 数组格式。"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "id": "afd312bd",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "def load_data() :\n",
48 | " # 加载 MNIST 数据集 元组tuple: (x, y), (x_val, y_val)\n",
49 | " (x, y), (x_val, y_val) = datasets.mnist.load_data()\n",
50 | " # 将 x 转换为浮点张量,并从 0 ~ 255 缩放到 [0, 1.] - 1 -> [-1, 1] 即缩放到 -1 ~ 1\n",
51 | " x = tf.convert_to_tensor(x, dtype = tf.float32) / 255. - 1\n",
52 | " # 转换为整形张量\n",
53 | " y = tf.convert_to_tensor(y, dtype = tf.int32)\n",
54 | " # one-hot 编码\n",
55 | " y = tf.one_hot(y, depth = 10)\n",
56 | " # 改变视图, [b, 28, 28] => [b, 28*28]\n",
57 | " x = tf.reshape(x, (-1, 28 * 28))\n",
58 | " # 构建数据集对象\n",
59 | " train_dataset = tf.data.Dataset.from_tensor_slices((x, y))\n",
60 | " # 批量训练\n",
61 | " train_dataset = train_dataset.batch(200)\n",
62 | " return train_dataset"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "22707823",
68 | "metadata": {},
69 | "source": [
70 | "\n",
71 | " TensorFlow 中的 `load_data()` 函数返回两个**元组** (tuple) 对象,第一个是训练集,第二个是测试集,每个 tuple 的第一个元素是多个训练图片数据 $X$ ,第二个元素是训练图片对应的类别数字 $Y$。其中训练集 $X$ 的大小为 $(60000,28,28)$ ,代表了 $60000$ 个样本,每个样本由 $28$ 行、$28$ 列构成,由于是灰度图片,故没有 RGB 通道;训练集 $Y$ 的大小为 $(60000)$,代表了这 $60000$ 个样本的标签数字,每个样本标签用一个 $0\\sim 9$ 的数字表示,测试集同理。\n",
72 | "\n",
73 | " 从 TensorFlow 中加载的 MNIST 数据图片,数值的范围在 $[0,255]$ 之间。在机器学习中间,一般希望数据的范围在 $0$ 周围小范围内分布。我们可以通过预处理步骤,我们把 $[0,255]$ 像素范围**归一化**(Normalize)到 $[0,1.]$ 区间,再缩放到 $[−1,1]$ 区间,从而有利于模型的训练。\n",
74 | "\n",
75 | " 每一张图片的计算流程是通用的,我们在计算的过程中可以一次进行多张图片的计算,充分利用 CPU 或 GPU 的并行计算能力。一张图片我们用 shape 为 $[h, w]$ 的矩阵来表示,对于多张图片来说,我们在前面添加一个**数量维度** (Dimension),使用 shape 为 $[b, h, w]$ 的张量来表示,其中的 $b$ 代表了 batch size(**批量**。多张彩色图片可以使用 shape 为 $[b, h, w, c]$ 的张量来表示,其中的 $c$ 表示通道数量(Channel),彩色图片$c = 3$(R、G、B)。通过 TensorFlow 的Dataset 对象可以方便完成模型的批量训练,只需要调用 `batch()` 函数即可构建带 `batch` 功能的数据集对象。\n"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "id": "4478bfdb",
81 | "metadata": {},
82 | "source": [
83 | "$❑\\, \\, $ **Step 2 网络搭建**\n",
84 | "\n",
85 | "\n",
86 | " 对于第一层模型来说,他接受的输入 $𝒙 ∈ \\mathbb R^{784}$ ,输出𝒉𝟏 ∈ $\\mathbb R^{256}$ 设计为长度为 $256$ 的向量,我们不需要显式地编写 $\\boldsymbol{h}_{1}=\\operatorname{ReLU}\\left(\\boldsymbol{W}_{1} \\boldsymbol{x}+\\boldsymbol{b}_{1}\\right)$ 的计算逻辑,在 TensorFlow 中通过一行代码即可实现:\n",
87 | "\n",
88 | "\n",
89 | "```python\n",
90 | "layers.Dense(256, activation = 'relu') \n",
91 | "```\n",
92 | "\n",
93 | " 使用 TensorFlow 的 Sequential 容器可以非常方便地搭建多层的网络。对于 3 层网络,我们可以通过\n",
94 | "\n",
95 | "```python\n",
96 | "keras.sequential([\n",
97 | " layers.Dense(256, activation = 'relu'),\n",
98 | " layers.Dense(128, activation = 'relu'),\n",
99 | " layers.Dense(10)])\n",
100 | "```\n",
101 | "\n",
102 | "\n",
103 | " 快速完成 $3$ 层网络的搭建,第 $1$ 层的输出节点数设计为 $256$,第 $2$ 层设计为 $128$,输出层节点数设计为 $10$。直接调用这个模型对象 `model(x)` 就可以返回模型最后一层的输出 。\n",
104 | "\n",
105 | " 为了能让大家理解更多的细节,我们这里不使用上面的框架,手动实现经过 3 层神经网络。\n",
106 | "\n",
107 | "对神经网络参数初始化:"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 3,
113 | "id": "96f0eb34",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "def init_paramaters() :\n",
118 | " # 每层的张量需要被优化,使用 Variable 类型,并使用截断的正太分布初始化权值张量\n",
119 | " # 偏置向量初始化为 0 即可\n",
120 | " # 第一层参数\n",
121 | " w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev = 0.1))\n",
122 | " b1 = tf.Variable(tf.zeros([256]))\n",
123 | " # 第二层参数\n",
124 | " w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev = 0.1))\n",
125 | " b2 = tf.Variable(tf.zeros([128]))\n",
126 | " # 第三层参数\n",
127 | " w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev = 0.1))\n",
128 | " b3 = tf.Variable(tf.zeros([10]))\n",
129 | " return w1, b1, w2, b2, w3, b3"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "id": "99a242c4",
135 | "metadata": {},
136 | "source": [
137 | "$❑\\, \\, $ **Step 3 模型训练**\n",
138 | "\n",
139 | " 得到模型输出 $\\boldsymbol{o}$ 后,通过 MSE 损失函数计算当前的误差 $\\mathcal L$:\n",
140 | "\n",
141 | "```python\n",
142 | "with tf.GradientTape() as tape:#构建梯度记录环境\n",
143 | " #打平,[b,28,28] =>[b,784]\n",
144 | " x=tf.reshape(x,(-1,28*28))\n",
145 | " #step1. 得到模型输出 output\n",
146 | " # [b,784] =>[b,10]\n",
147 | " out=model(x)\n",
148 | "```\n",
149 | "\n",
150 | "**手动实现代码:**\n",
151 | "\n",
152 | "\n",
153 | "```python\n",
154 | " with tf.GradientTape() as tape :#构建梯度记录环境\n",
155 | " # 第一层计算, [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b,256] + [b, 256]\n",
156 | " h1 = x @ w1 + tf.broadcast_to(b1, (x.shape[0], 256))\n",
157 | " # 通过激活函数 relu\n",
158 | " h1 = tf.nn.relu(h1)\n",
159 | " # 第二层计算, [b, 256] => [b, 128]\n",
160 | " h2 = h1 @ w2 + b2\n",
161 | " h2 = tf.nn.relu(h2)\n",
162 | " # 输出层计算, [b, 128] => [b, 10]\n",
163 | " out = h2 @ w3 + b3\n",
164 | "```\n"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "id": "00432bee",
170 | "metadata": {},
171 | "source": [
172 | "$❑\\, \\, $ **Step 4 梯度优化**\n",
173 | "\n",
174 | " 利用 TensorFlow 提供的自动求导函数 `tape.gradient(loss, model.trainable_variables)` 求出模型中所有的梯度信息 $\\dfrac{\\partial L}{\\partial \\theta}$ , $\\theta ∈ \\{\\boldsymbol W_1, 𝒃_𝟏,\\boldsymbol W_2, 𝒃_𝟐,\\boldsymbol W_3, 𝒃_𝟑\\}$:\n",
175 | "\n",
176 | "```python \n",
177 | " grads = tape.gradient(loss, model.trainable_variables) \n",
178 | "```\n",
179 | "\n",
180 | " 计算获得的梯度结果使用 grads 变量保存。再使用 optimizers 对象自动按着梯度更新法则\n",
181 | "$$\n",
182 | "\\theta^{\\prime}=\\theta-\\eta \\times \\frac{\\partial \\mathcal{L}}{\\partial \\theta}\n",
183 | "$$\n",
184 | "\n",
185 | "\n",
186 | "去更新模型的参数 $\\theta$。\n",
187 | "\n",
188 | "```python\n",
189 | "grads = tape.gradient(loss, model.trainable_variables)\n",
190 | " # w' = w - lr * grad,更新网络参数\n",
191 | "optimizer.apply_gradients(zip(grads, model.trainable_variables)) \n",
192 | "```\n",
193 | "\n",
194 | " 循环迭代多次后,就可以利用学好的模型 $𝑓_{\\theta}$ 去预测未知的图片的类别概率分布。 \n",
195 | "\n",
196 | "\n",
197 | "**手动实现梯度更新代码如下:**"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 4,
203 | "id": "36967ffb",
204 | "metadata": {},
205 | "outputs": [],
206 | "source": [
207 | "def train_epoch(epoch, train_dataset, w1, b1, w2, b2, w3, b3, lr = 0.001) :\n",
208 | " for step, (x, y) in enumerate(train_dataset) :\n",
209 | " with tf.GradientTape() as tape :\n",
210 | " # 第一层计算, [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b,256] + [b, 256]\n",
211 | " h1 = x @ w1 + tf.broadcast_to(b1, (x.shape[0], 256))\n",
212 | " # 通过激活函数 relu\n",
213 | " h1 = tf.nn.relu(h1)\n",
214 | " # 第二层计算, [b, 256] => [b, 128]\n",
215 | " h2 = h1 @ w2 + b2\n",
216 | " h2 = tf.nn.relu(h2)\n",
217 | " # 输出层计算, [b, 128] => [b, 10]\n",
218 | " out = h2 @ w3 + b3\n",
219 | "\n",
220 | " # 计算网络输出与标签之间的均方差, mse = mean(sum(y - out) ^ 2)\n",
221 | " # [b, 10]\n",
222 | " loss = tf.square(y - out)\n",
223 | " # 误差标量, mean: scalar\n",
224 | " loss = tf.reduce_mean(loss) \n",
225 | " # 自动梯度,需要求梯度的张量有[w1, b1, w2, b2, w3, b3]\n",
226 | " grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])\n",
227 | "\n",
228 | " # 梯度更新, assign_sub 将当前值减去参数值,原地更新\n",
229 | " w1.assign_sub(lr * grads[0])\n",
230 | " b1.assign_sub(lr * grads[1])\n",
231 | " w2.assign_sub(lr * grads[2])\n",
232 | " b2.assign_sub(lr * grads[3])\n",
233 | " w3.assign_sub(lr * grads[4])\n",
234 | " b3.assign_sub(lr * grads[5])\n",
235 | "\n",
236 | " if step % 100 == 0 :\n",
237 | " print(f\"epoch:{epoch}, iteration:{step}, loss:{loss.numpy()}\") \n",
238 | " \n",
239 | " return loss.numpy()"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "id": "95f47773",
245 | "metadata": {},
246 | "source": [
247 | "整体的代码思路如下:\n",
248 | "\n",
249 | " 我们首先创建每个非线性层的 $𝑾$ 和 $𝒃$ 张量参数,然后把 $\\boldsymbol x$ 的 `shape=[b, 28, 28]` 转成向量 `[b, 784]` ,然后计算三层神经网络,每层使用 ReLU 激活函数,然后与 `one_hot` 编码的 $\\boldsymbol y$ 一起计算均方差,利用 `tape.gradient()` 函数自动求梯度\n",
250 | "\n",
251 | "\n",
252 | "\n",
253 | "\n",
254 | "$$\n",
255 | "\\theta^{\\prime}=\\theta-\\eta \\cdot \\frac{\\partial \\mathcal{L}}{\\partial \\theta}\n",
256 | "$$\n",
257 | "\n",
258 | "\n",
259 | " 使用 `assign_sub()` 函数按照上述梯度下降算法更新网络参数(assign_sub()将自身减去给定的参数值,实现参数的原地 (In-place) 更新操作),最后使用 `matplotlib` 绘制图像输出即可。 \n",
260 | "\n"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 5,
266 | "id": "d6703b58",
267 | "metadata": {},
268 | "outputs": [
269 | {
270 | "name": "stdout",
271 | "output_type": "stream",
272 | "text": [
273 | "epoch:0, iteration:0, loss:2.8238346576690674\n",
274 | "epoch:0, iteration:100, loss:0.12282778322696686\n",
275 | "epoch:0, iteration:200, loss:0.09311732649803162\n",
276 | "epoch:1, iteration:0, loss:0.08501865714788437\n",
277 | "epoch:1, iteration:100, loss:0.08542951196432114\n",
278 | "epoch:1, iteration:200, loss:0.0726892277598381\n",
279 | "epoch:2, iteration:0, loss:0.0702851265668869\n",
280 | "epoch:2, iteration:100, loss:0.07322590798139572\n",
281 | "epoch:2, iteration:200, loss:0.06422527879476547\n",
282 | "epoch:3, iteration:0, loss:0.06337807327508926\n",
283 | "epoch:3, iteration:100, loss:0.06658075749874115\n",
284 | "epoch:3, iteration:200, loss:0.05931859463453293\n",
285 | "epoch:4, iteration:0, loss:0.05899645760655403\n",
286 | "epoch:4, iteration:100, loss:0.062211062759160995\n",
287 | "epoch:4, iteration:200, loss:0.05583002790808678\n",
288 | "epoch:5, iteration:0, loss:0.05573083460330963\n",
289 | "epoch:5, iteration:100, loss:0.058985237032175064\n",
290 | "epoch:5, iteration:200, loss:0.053309373557567596\n",
291 | "epoch:6, iteration:0, loss:0.053154319524765015\n",
292 | "epoch:6, iteration:100, loss:0.0565737783908844\n",
293 | "epoch:6, iteration:200, loss:0.051321178674697876\n",
294 | "epoch:7, iteration:0, loss:0.05108541622757912\n",
295 | "epoch:7, iteration:100, loss:0.054654691368341446\n",
296 | "epoch:7, iteration:200, loss:0.04969591274857521\n",
297 | "epoch:8, iteration:0, loss:0.0493333600461483\n",
298 | "epoch:8, iteration:100, loss:0.05307555943727493\n",
299 | "epoch:8, iteration:200, loss:0.04830818995833397\n",
300 | "epoch:9, iteration:0, loss:0.047856513410806656\n",
301 | "epoch:9, iteration:100, loss:0.051712967455387115\n",
302 | "epoch:9, iteration:200, loss:0.047130417078733444\n",
303 | "epoch:10, iteration:0, loss:0.046589870005846024\n",
304 | "epoch:10, iteration:100, loss:0.05050584673881531\n",
305 | "epoch:10, iteration:200, loss:0.046077027916908264\n",
306 | "epoch:11, iteration:0, loss:0.04547039046883583\n",
307 | "epoch:11, iteration:100, loss:0.0494525171816349\n",
308 | "epoch:11, iteration:200, loss:0.04513201490044594\n",
309 | "epoch:12, iteration:0, loss:0.04447377845644951\n",
310 | "epoch:12, iteration:100, loss:0.04850050434470177\n",
311 | "epoch:12, iteration:200, loss:0.04429565370082855\n",
312 | "epoch:13, iteration:0, loss:0.043561432510614395\n",
313 | "epoch:13, iteration:100, loss:0.04763760045170784\n",
314 | "epoch:13, iteration:200, loss:0.04352030158042908\n",
315 | "epoch:14, iteration:0, loss:0.04273266717791557\n",
316 | "epoch:14, iteration:100, loss:0.04685059189796448\n",
317 | "epoch:14, iteration:200, loss:0.04281007871031761\n",
318 | "epoch:15, iteration:0, loss:0.04196896404027939\n",
319 | "epoch:15, iteration:100, loss:0.04612402245402336\n",
320 | "epoch:15, iteration:200, loss:0.042143866419792175\n",
321 | "epoch:16, iteration:0, loss:0.04124370589852333\n",
322 | "epoch:16, iteration:100, loss:0.045436661690473557\n",
323 | "epoch:16, iteration:200, loss:0.04152608662843704\n",
324 | "epoch:17, iteration:0, loss:0.04054683819413185\n",
325 | "epoch:17, iteration:100, loss:0.04479134455323219\n",
326 | "epoch:17, iteration:200, loss:0.04094816371798515\n",
327 | "epoch:18, iteration:0, loss:0.039887987077236176\n",
328 | "epoch:18, iteration:100, loss:0.044195547699928284\n",
329 | "epoch:18, iteration:200, loss:0.040402255952358246\n",
330 | "epoch:19, iteration:0, loss:0.039282456040382385\n",
331 | "epoch:19, iteration:100, loss:0.04364131763577461\n",
332 | "epoch:19, iteration:200, loss:0.03987015783786774\n",
333 | "epoch:20, iteration:0, loss:0.03870239108800888\n",
334 | "epoch:20, iteration:100, loss:0.04311355948448181\n",
335 | "epoch:20, iteration:200, loss:0.0393809899687767\n",
336 | "epoch:21, iteration:0, loss:0.03816535696387291\n",
337 | "epoch:21, iteration:100, loss:0.04260074719786644\n",
338 | "epoch:21, iteration:200, loss:0.038916587829589844\n",
339 | "epoch:22, iteration:0, loss:0.03764738887548447\n",
340 | "epoch:22, iteration:100, loss:0.0421011820435524\n",
341 | "epoch:22, iteration:200, loss:0.03846241906285286\n",
342 | "epoch:23, iteration:0, loss:0.03715534508228302\n",
343 | "epoch:23, iteration:100, loss:0.04162450134754181\n",
344 | "epoch:23, iteration:200, loss:0.038025762885808945\n",
345 | "epoch:24, iteration:0, loss:0.03669610247015953\n",
346 | "epoch:24, iteration:100, loss:0.04115061089396477\n",
347 | "epoch:24, iteration:200, loss:0.037617530673742294\n",
348 | "epoch:25, iteration:0, loss:0.036250464618206024\n",
349 | "epoch:25, iteration:100, loss:0.04068863391876221\n",
350 | "epoch:25, iteration:200, loss:0.03720514848828316\n",
351 | "epoch:26, iteration:0, loss:0.0358189232647419\n",
352 | "epoch:26, iteration:100, loss:0.04023948684334755\n",
353 | "epoch:26, iteration:200, loss:0.03680950030684471\n",
354 | "epoch:27, iteration:0, loss:0.03541094437241554\n",
355 | "epoch:27, iteration:100, loss:0.039796702563762665\n",
356 | "epoch:27, iteration:200, loss:0.036432769149541855\n",
357 | "epoch:28, iteration:0, loss:0.035001132637262344\n",
358 | "epoch:28, iteration:100, loss:0.0393792949616909\n",
359 | "epoch:28, iteration:200, loss:0.0360533632338047\n",
360 | "epoch:29, iteration:0, loss:0.03459516167640686\n",
361 | "epoch:29, iteration:100, loss:0.038972556591033936\n",
362 | "epoch:29, iteration:200, loss:0.03568875044584274\n",
363 | "epoch:30, iteration:0, loss:0.034208156168460846\n",
364 | "epoch:30, iteration:100, loss:0.03857466205954552\n",
365 | "epoch:30, iteration:200, loss:0.03533806651830673\n",
366 | "epoch:31, iteration:0, loss:0.03383728861808777\n",
367 | "epoch:31, iteration:100, loss:0.03818388655781746\n",
368 | "epoch:31, iteration:200, loss:0.03499047830700874\n",
369 | "epoch:32, iteration:0, loss:0.03347790986299515\n",
370 | "epoch:32, iteration:100, loss:0.03780962899327278\n",
371 | "epoch:32, iteration:200, loss:0.03465589880943298\n",
372 | "epoch:33, iteration:0, loss:0.0331435389816761\n",
373 | "epoch:33, iteration:100, loss:0.03744875639677048\n",
374 | "epoch:33, iteration:200, loss:0.03432944416999817\n",
375 | "epoch:34, iteration:0, loss:0.03281198441982269\n",
376 | "epoch:34, iteration:100, loss:0.03710091859102249\n",
377 | "epoch:34, iteration:200, loss:0.03401472419500351\n",
378 | "epoch:35, iteration:0, loss:0.03248513862490654\n",
379 | "epoch:35, iteration:100, loss:0.03677191212773323\n",
380 | "epoch:35, iteration:200, loss:0.0337025411427021\n",
381 | "epoch:36, iteration:0, loss:0.03216780349612236\n",
382 | "epoch:36, iteration:100, loss:0.03644610941410065\n",
383 | "epoch:36, iteration:200, loss:0.03340121731162071\n",
384 | "epoch:37, iteration:0, loss:0.03186160326004028\n",
385 | "epoch:37, iteration:100, loss:0.03613165020942688\n",
386 | "epoch:37, iteration:200, loss:0.03310254588723183\n",
387 | "epoch:38, iteration:0, loss:0.03155995160341263\n",
388 | "epoch:38, iteration:100, loss:0.03582020476460457\n",
389 | "epoch:38, iteration:200, loss:0.03279929608106613\n",
390 | "epoch:39, iteration:0, loss:0.031268708407878876\n",
391 | "epoch:39, iteration:100, loss:0.03551224246621132\n",
392 | "epoch:39, iteration:200, loss:0.03250761330127716\n",
393 | "epoch:40, iteration:0, loss:0.03098423406481743\n",
394 | "epoch:40, iteration:100, loss:0.035210851579904556\n",
395 | "epoch:40, iteration:200, loss:0.032223351299762726\n",
396 | "epoch:41, iteration:0, loss:0.030709076672792435\n",
397 | "epoch:41, iteration:100, loss:0.03492467850446701\n",
398 | "epoch:41, iteration:200, loss:0.03194170445203781\n",
399 | "epoch:42, iteration:0, loss:0.03045233152806759\n",
400 | "epoch:42, iteration:100, loss:0.03464159741997719\n",
401 | "epoch:42, iteration:200, loss:0.031668927520513535\n",
402 | "epoch:43, iteration:0, loss:0.030199026688933372\n",
403 | "epoch:43, iteration:100, loss:0.03436439484357834\n",
404 | "epoch:43, iteration:200, loss:0.03139331936836243\n",
405 | "epoch:44, iteration:0, loss:0.02995467558503151\n",
406 | "epoch:44, iteration:100, loss:0.03409324958920479\n",
407 | "epoch:44, iteration:200, loss:0.031120551750063896\n",
408 | "epoch:45, iteration:0, loss:0.029712168499827385\n",
409 | "epoch:45, iteration:100, loss:0.03382962942123413\n",
410 | "epoch:45, iteration:200, loss:0.03084247186779976\n",
411 | "epoch:46, iteration:0, loss:0.029475683346390724\n",
412 | "epoch:46, iteration:100, loss:0.03357085585594177\n",
413 | "epoch:46, iteration:200, loss:0.030582088977098465\n",
414 | "epoch:47, iteration:0, loss:0.029252037405967712\n",
415 | "epoch:47, iteration:100, loss:0.033312439918518066\n",
416 | "epoch:47, iteration:200, loss:0.030324475839734077\n",
417 | "epoch:48, iteration:0, loss:0.02903318777680397\n",
418 | "epoch:48, iteration:100, loss:0.033065315335989\n",
419 | "epoch:48, iteration:200, loss:0.030078813433647156\n",
420 | "epoch:49, iteration:0, loss:0.028811004012823105\n",
421 | "epoch:49, iteration:100, loss:0.03282175958156586\n",
422 | "epoch:49, iteration:200, loss:0.02983343042433262\n"
423 | ]
424 | },
425 | {
426 | "data": {
427 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEQCAYAAABxzUkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnpElEQVR4nO3deZicVZ328e+vlq7elySdPSwJJGHTAYKGLTIKDAqjLIJE5CVBQZQJGJRNFFnHeZmwGAYSAwwgDpuDGtBB2VTgRQIhgIMgsiQQQkg6JJ1Oel9+7x9V3amuru50V1d1dVfdn+vKZdV5TnedE0nffc55znnM3REREUlFINsNEBGRkUshIiIiKVOIiIhIyhQiIiKSMoWIiIikLJTtBgylMWPG+G677ZbtZoiIjCgvvfTSJnevTnYtr0Jkt912Y+XKldluhojIiGJm7/V2TdNZIiKSMoWIiIikTCEiIiIpU4iIiEjKFCIiIpIyhYiIiKQsr27xFZGRp6mpiZqaGpqammhra8t2c3JKOBxm7NixlJeXp/w9FCL98J07tlDfHD0yv7zIuH5+VZZbJJIftm7dyoYNG6iurmb8+PGEQiHMLNvNygnuTmNjI+vWrQNIOUg0ndUPnQECUNeo56+IDJVNmzYxefJkqqqqCIfDCpA0MjOKi4uZNGkSGzduTPn7KEREZNhqaWmhqKgo283IaUVFRbS2tqb89QqRfigvsqSvRSTzNPrIrMH+/SpE+uH6+VVUlhiHzCzQeoiISByFSD9VlQbYsr0j280QERlWFCL9VFWiEBERSaQQ6aeq0gC19QoREUmf7du393n9zDPPZNmyZUmvPfnkkxx++OHU1dV1Kz/77LP58Y9/nLY27oxCpJ+qSgI0tUJji27xFZHBa29v58wzz2TmzJnU1dVRU1PTo8706dN57bXXkn79Sy+9xP77709TU1O38jVr1nDCCSdkpM3JKET6qao0+lelKS0RSYdgMMiDDz7ID3/4Q5qbm3nwwQdZuHAh9fX1PPjggwAUFhay33779fja1tZWVq9ezU033cS5557LvffeS0dHBy+88ALr16/n8ssvZ5dddmHmzJm0tLRktB/asd5P8SEycVQwy60RyW/3P1vP2k3tKX/9Ox+10R77fTAYgGnjB/6jcMqYIKceVpJyGwCam5s57bTTADjwwANZuXIlxcXFLFy4kJNOOolAIJB0n8wdd9zBRRddxAsvvEB9fT3HHHMM9957L4sXL2bOnDkceOCBXHDBBXzyk5+koKBgUG3cGY1E+qmyJBYiWhcRGfHaO5K/Hirvv/8+t9xyC2PHjuX0008HoKysrOtYl0996lMEg0HMrMc+jqVLl3LBBRdw4okncvHFF/PQQw8xatQo5s6dS2lpKbNnz6ampobZs2cPyUZNjUT6qStENJ0lknWDHQGcdevmbu8vPD71AwgHavny5dxyyy3MmzePz33ucyxYsACIbvoLhaI/kjt/+CcLkZNPPpmvfvWrPProoxx11FEUFRVRW1vLeeedx5o1a5g/fz7z588fsv5oJNJP4aBRVmQaiYjkgGyeQvGlL32Jxx57jK9+9au0tbWx7777AtDQ0EBZWRkAgUD0R7O7d73uNGrUKC6++GKWLl3Keeedx9KlS4HoibwXXXQRp5xyCqtWreKQQw6hurqaJUuWZLQ/Qz4SsWisngM4EAGecfdVSepFgG8DDUAZ8JC7r467fhhwOLAVmAHc4O7vZbLt2nAokhuGy8kTra2tRCIRHnjgAQoLC5k8eTJA15H3HR09f96YGSUlJVx33XUcdNBBAFx11VVMmzaN/fbbj912242//vWvfPe73x2SPmRjOut84Fl3XwlgZkvM7BJ335pQ72rgJnf/MBY8d5vZfHdvN7Nq4F/c/dTY99gVuBP4bCYbXlUSYFOdQkREBq+jo4MPPviAU089lauuuorbb7+dk08+mYaGBtavXw/Q6/NTFi1axNq1a1m8eDHNzc388Ic/5OWXX+aAAw7ghRdeoLm5mZdffpni4mJmzJiR0X4M6XSWmQWBuZ0BErMCmJdQrwrY390/BHB3B94FjotVORx4v7N+bASyW8YaHlNVGtB0loikxW233cabb77Jueeey+jRo3n11Ve7FsM7RxiNjY3dbtHduHEjl112GQcffDBf+cpX2Hvvvbnwwgt5+OGH+dnPfgZEp8WampqYOnUqxx57LLfccktG+zHUayJziE5PxXsTmJtQdjzwQR/1PgT+ycyKAcxsMvBWWluaRFVJgIZmp7lVGw5FZHBKS0tZtGgRc+bM4aKLLuLWW28FotNVixYtAmDdunXddrWPHj2a2tpaFi5cyHPPPceRRx4JwDXXXMNpp51GR0cHGzZsYPv27VRUVHD//ffz05/+NKP9GOrprKnA5oSyzbHyftdz9+fN7A3gRTP7V2BX4LT0N7e7zr0itfUdjKvUXhERSV3n/pAXX3yRa6+9lkmTJvWos2bNGvbZZ5+u98FgMOnIYtmyZey///5ceeWV3HDDDSxcuBCAWbNm8fjjj2eoB1FDPRIZCyRO8rUCY8ws0I96Y+PenwOsBi4hGiC7J/tAMzvbzFaa2cpkxwoMRNeGQ01piUiaHHTQQUkDBODII4/s15rG/vvvD8D3v/99Zs6c2RVQAOPGjUtPQ3sx1CORjUDi9skwsMndOxLqlSaptxHAzEYDvyE67bUROBd4wsxmuvv6+C9y92XAMoBZs2YNah5Ke0VEZCgN9A6rcDjMihUrMtSa5IZ6JPIuMCahrIroiGIg9U4EXnH3DR71H8BTwElpbm/3BihERES6GeoQeQaoSJi6mg7cl1BvOTAtoSy+XilQl3D9baA5Te1MKhI2iiPacCgi0mlIQ8Td24A7gKPjij8N3GVmx5rZzbF6m4HnzWyvuHrTgEdir58ADu28ENtH8om46xmjh1OJiOyQjc2Gi4EFZrYL0XWOO9291symA4fF1bsUON/MDiU68rjS3dsB3P1/zewGM1tEdAQyEbjM3T/KdOO1V0RkaLl7j/OjJH2i2/BSN+QhEts4uDhJ+Y3AjXHvm4Hr+vg+vwR+mYk29qWqNMB7Ncl3kYpIehUUFNDY2EhxcXG2m5KzGhsbCYfDKX+9DmAcoKqSANsandZ2bTgUybQxY8bwwQcfsHnzZlpbWwf9W7Ps4O40NDSwbt06xo4du/Mv6IWOgh+gzr0iW+s7GFOuDYcimVRRUUEkEqGmpoaPP/6417OkJDXhcJhx48ZRXp76UfgKkQGqins4lUJEJPMKCwuZMmVKtpshvdB01gBVlkYX+HSHloiIQmTAtOFQRGQHhcgAFRUYkbDOzxIRAYXIgJmZNhyKiMQoRFKgDYciIlEKkRToWesiIlEKkRRUlQTY2uC0d2jjk4jkN4VICqpKA7jD1gaFiIjkN4VICjpv863VuoiI5DmFSAoqS7VXREQEFCIp0YZDEZEohUgKSguNUFAbDkVEFCIp0IZDEZEohUiKtOFQREQhkjKNREREFCIpqyoNUFvfQYeetCYieUwhkqKq0gDtHbC9USEiIvlLIZKi+CcciojkK4VIiiq1V0RERCGSqqpSjURERBQiKSovMoIBjUREJL8pRFIUCBgVxbrNV0Tym0JkELThUETynUJkELThUETynUJkEDo3HLo2HIpInlKIDEJVaYCWNmhoVoiISH5SiAyCnisiIvlOITIIlSUGaK+IiOSvULYbMJJ1bjhc/NvtQHTvyPXzq7LZJBGRIaWRyCB0Hn3SqU6HMYpInlGIDEIwYNlugohIVilEBikSNyFYXqRQEZH8ohAZpG9/vgyA848r1XqIiOQdhcggTR0fImDw1vq2bDdFRGTIKUQGqTBsTBkT5G2FiIjkIYVIGuw5IcTqDW20tuvuLBHJLwqRNNhzQpjWdni/RqMREckvCpE02GNC9BYtrYuISL5RiKRBeXGAcZUBrYuISN5RiKTJHuNDvL2+jQ4dCy8ieUQhkiZ7TghT3+ys39Ke7aaIiAwZhUia7Dkxui6iKS0RyScKkTSpLg9QXmQKERHJKwqRNDEz9pwY0h1aIpJXFCJptMf4MB9v62DzNq2LiEh+UIik0Z6d+0U+0mhERPKDQiSNJo8JEglrcV1E8odCJI2CAWPauJBCRETyhkIkzfacGGbdx+00NHdkuykiIhmnEEmzPSaEcOAdrYuISB5QiKTZ7mNDBAM6jFFE8oNCJM0iYWOX6qBCRETyQs6EiJkVZbsNnfYcH2bNhjZa23QYo4jkttBQf6CZGXAO4EAEeMbdVyWpFwG+DTQAZcBD7r46Sb1DgeOBXwAvZK7l/bfnxBCPvQpratrYc0I4280REcmYIQ8R4HzgWXdfCWBmS8zsEnffmlDvauAmd/8wFjx3m9l8d+/aDm5m3wOmA99292EzfzRtfPSv9bpfbQOgvMi4fn5VNpskIpIRQzqdZWZBYG5ngMSsAOYl1KsC9nf3DwHc3YF3gePi6syLvT9nOAUIQFlR97/WukZNa4lIbhrqNZE5RKen4r0JzE0oOx74oLd6ZlYN3Eh0BKINGSIiWTLUITIV2JxQtjlWPpB6XwdeBD40s3lmdomZFST7QDM728xWmtnKmpqawbV+AEoLret1eZH1UVNEZORKOUTMrDiFLxsLJE49tQJjzCzQj3pjY6+PAh5z91p3vwsIArcm+0B3X+bus9x9VnV1dQpNTs2NZ1Yxc1KI6vIAi+ZVDtnniogMpV5DJLaY3Zc5ZvYrM/vzAD5vI5A4YggDmxKmpXqrtzH2ejLwTty1u4EzehuNZMshMyPU1HVoz4iI5Ky+RiLbzGyZmX0m2UV3/x3wZWDTAD7vXWBMQlkVkHjr7s7qbaV72zcSvdNs/ADaknEHTC0gEobn/tac7aaIiGREXyHyXXc/293/BGBmXzezv5nZZWY2EyB2u+3vBvB5zwAVCVNX04H7EuotB6YllMXXex3YM+5aBdAOfDSAtmRcJGzMmlbAS++00NyqO7REJPf0ORKJf+PudwDXu/u17v63uEut/f2w2K24dwBHxxV/GrjLzI41s5tj9TYDz5vZXnH1pgGPxF4vBr4aF0ZHAD9395b+tmWoHDwjQlMrvPzusGuaiMig9bXZMNmaSLJ5mYHeerQYWGBmuxBd57jT3WvNbDpwWFy9S4HzYzvSS4ErOzcauvsqM7seWGRmrwEzgAUDbMeQ2HNiiDHlAZ57s5nZMyLZbo6ISFr1FSL9nX8Z0DxNbOPg4iTlNxLd+9H5vhm4ro/vc/dAPjdbAmYcPCPCb15s5ONt7YwuC2a7SSIiadNXiPyzmSVePyzJTVvHAsvS2qocc/CMAh55sZHn32zh2FnD5pxIEZFB6ytEvhL7k+gbCe+1YrwT1eVBpk8M8dybzXzhwEJ2fve0iMjI0NfC+gXuHtjZH2D+UDV2JDtkRoSNWzv0xEMRySl9hcjj/fwef01HQ3LdgdMKKAjBn9/UXVoikjt6DRF3TxoOZjbazL5oZnPMrMDdX8pc83JHYYFx4LQCXny7hRY9rEpEckRfx54cYWa3mtlhcWX7An8DbgfOAx4xs8mZb2ZuOHhGhMYW55XVGo2ISG7oa2H9AmChu78DXc8CuY/owYifcve1ZlYBfBe4POMtzQEzJoUw4LbH67nt8Xo9rEpERry+QuSlzgCJWQDsQ/ShUmsB3H2rmQ2ro0aGs4BZt1vZ9LAqERnp+lpYr+t8YWajgR8CT7n7Awn1ds1Ew0REZPjrK0QCZnZYLEB+TvR4k2/GV4itkRyawfblnPgHVBUNq4PrRUQGrq8QWQocCTwRe/+5uPWRMWZ2JfA8AzvFN+9dP7+KpedUMaEqQHlxgLZ2TWmJyMiVNETMLAyc4u5XuPv+7v55d3+587q7b3L3HxE9gn3rELU1ZwQDxkkHF7OhtoNn39CzRkRk5Eq6sO7urWa2wMwc6EhWJ8aAM4CbM9G4XPaJXcNMnxji4RcbmT09QmGBjkIRkZGnr7uz9gJuAjbT+3HvBkxIc5vygpnx5YOL+deH6vj9K4186VOpPLJeRCS7+gqRccBJwESizzNf7u6NiZXM7JIMtS3n7T4uxKxpBTz2ShOf2aeQypK+lqhERIafvo49qXP3O939WuAF4FuxR+MekVD15xlsX847YXYR7R3wyIs98llEZNjrayTSxd3fBW4AiN32+wOgBXg44VG5MkBjK4IcsU+Ep15r5shPFDJhlB5aJSIjRyrzJ2/Hvm4B8IqZXZjeJuWfY2cV4Q6X37+Vs27dzHfv3JLtJomI9Eu/RiIAZnYkcA7wz8D7wH8QfT76xgy1LW+UFXXPch2HIiIjRZ8hYmajiD506mxgN+Bh4Fh3fyKuzhh335TJRoqIyPDU22bDgJn9DPgA+Bbwn8AUdz85PkBijs5wG/NC/HEowQDayS4iI0Jvmw07zOxk4JdEn3DowDFJng0eBL4G3JvJRuaDziPhX3y7mWWP1fPwi42cOFt7R0RkeOtrOuv7REOkL0b0fC1Jk4P2iPDG2jZ+t6qJvSaH2WtyONtNEhHpVdIQiT2A6n/c/b2dfQMz0ygkzb5yWDFvf9TKHU9s5/JTKigv1iZEERmekv50cvd2d3+zP9/A3X+b3iZJJGycfXQp9c3OnU/V0+FaHxGR4Um/4g5Tk0eHOOXQYl57v5UnXm3KdnNERJLq9z4RGXpH7BPhjbWt/OK5Rn7xXPRYFD2XXUSGE41EhjEz44x/LOlWpo2IIjKcKESGuZJC/V8kIsOXfkKNAPEbEQ34aEt79hojIhLHPI/u/Jk1a5avXLky281I2Udb2rnu13UEA3DRCeVUl+vEXxHJPDN7yd1nJbumkcgIMr4qyAX/XEZLG9zw8Da2bO/rycUiIpmnu7NGmMljQnznuDJueLiOS+6ppSM2kNRdWyKSDRqJjEC7jwux4NiyrgAB3bUlItmhEBmhpk/UmVoikn0KkREs/q4tgGffaM5SS0QkX2lNZATrXANpaO7gp7/fzt1/qGdjbTvHzy4i0PPYfhGRtNNIJAcURwIsOLaMOXtHePTlJpY9tp2WNq2RiEjmaSSSI0JB42ufKWZcZYBfPNfIS+9sAXTXlohklkYiOcTMOPofirqV1TU6+bShVESGlkIkDyz+7Xa2Nmhjooikn0IkB8XftVUYhjfXtXLF/Vt5ZXVLFlslIrlIayI5KHEN5MPN7dz+xHZueXR7V5nWSkQkHTQSyQMTRwW59KTybmVaKxGRdFCI5IlwsOe+kUXLt7Fuc1sWWiMiuUIhkkcS10rWfdzO1Q/W8d/PNdDUqlGJiAyc1kTySOIayLbGDh76cwO/f6WJ37/SBGitREQGRiORPFZWFGDeZ0u7ldU1Oi+81UyH1ktEpB8UItLDbY/Xc/WDdbyyukWL7yLSJ01nCeVF1vU8kvIi45RDi3n4xUZueXQ7AUMPvhKRXilEJGkwHDitgD+/2czP/tjQVVbX6LS2OeGQTggWkShNZ0lSoaBx+N6FPcovvqeW36xsZHuTjlEREY1EZCfip7qKI8au1SGWv9DI8hcau9XRNJdIflKISJ+ShcO6j9u44oG6rvd1jc6Kt5o5cGoBoSSbGkUkdylEZMAmje75n83tj9fzQFEDh+8VYc4+EUaXBbPQMhEZakMeImZmwDmAAxHgGXdflaReBPg20ACUAQ+5++pevuev3P2EzLVaEiXe0TX/cyX88bVm/mdVE/+zKrpxsaggOpJJduSKiOQGG+p9AGb2HeBZd18Ze78EuMTdtybUuw64yd0/jAXP3cB8d29PqHcy8KC77/Qn1axZs3zlypVp6okkc9atm7u9L4kYn55ewGF7RZgyRgNfkZHIzF5y91nJrg3pv2ozCwJz3f2muOIVwDzgJ3H1qoD93f1DAHd3M3sXOA5YHlevGpie+ZZLqvaeEubp15t56n+bu8pKC40bz9RCvEguGOpbfOcQnZ6K9yYwN6HseOCDftQ7H7gxXY2TwYs/5LG8yDj76FIWnVHZrc72JmfRr+t4+vUm6nWrsMiINtTzC1OBzQllm2PlA6pnZl8EnnT3huhslwwHye7mKins+btKbUMH9/yxgXviNjOWFRk36FZhkRFlqEciY4HEB1i0AmPMLNCPemMBzKwc+JS7/2FnH2hmZ5vZSjNbWVNTk3rLZVASRyhXz63gByd3f1DWtkZnye+28eJbzTqaXmSEGOqRyEagIKEsDGxy946EeqVJ6m2Mvb4UuKE/H+juy4BlEF1YH2iDJT2SjVB2re75n987H7Wx6t1WCkL1dDi0xW6j0IZGkeFpqEci7wJjEsqqgMRbd3utF1tMPxF4xMyeN7PnAWKvF2WgzZJBiSOU686o5MLjyzh0r0hXgEB0Q+Nzf2vWGorIMDPUI5FngAozC8SNPKYD9yXUWw6clVA2HbjP3WuAGfEXzMzdfXYmGiyZlWx0MX1imOkTw/wh7o4ugDufqicY+7WnPfZfj0YoItk1pCMRd28D7gCOjiv+NHCXmR1rZjfH6m0GnjezveLqTQMeGbLGStYljlK+f1I5R36ysCtAIDpCeXRVIx9taU/yHUQk07Kx2dCABUAT0XWOFe6+0swWAv/H3feP1YsQvYV3M9H1kYfd/d1evqdrs2H+SNzQ2EnPPhHJjGGz2RCiGweBxUnKbyRuz4e7NwPX9fN76h7fPJJ45MplXy7n1TWt3PtM92ef3P74dj6xW5h9poST3mYsIoOncyhkxEk2wvjH/YLdQgTgr2tbWfFWS7eykkLjxvmVaG+RSHro1zPJGYlrKNfPq+SSE8u61alvci65Zyv3/LGeV1a3aD+KyCBpJCI5I9kIZdr4cI+yXaqDrHirmadf7373V0nEuPFMjVJEBkIhIjkvcQ3l3M+X0druvL2+jRse3tZVr77ZufDuWvaeEuaV1S00tuz4ei3SiySnEJGclywAwkFjr8k9RynTJ4b5y5rWrgCB6CL962tb2WNCiIKQRiki8RQiktcSRylnH11KR4fzzaVbutW78ZFthILgro2OIvEUIpLXkoVAIGDdwqWsyDjzcyW8vraNx19t6qpX1+j85DfbmDEpxIyJYXapDhIMaKQi+UUhIpJEsnDZd5eCbiEC8PG2dh76cyvQ2FUWCcN3jitjt7EhQno0sOQ4hYjIACROf101t5KtDR18767arjrNrfB/f7WNglB06qtz+kvPS5FcpBARGYBkI5SK4p7brb51TClvfdjKE3/ZcRvxtkbnml9sZc8JIfaYEGaPCaGkXysykihERNIgcYRywNQCDpha0C1EAArDxtOvNycph4tOKGfSqCABravICKIQEUmD3u7SSgyX7x1fTlu78/6mdn78UF1XvaZWuOrBOiLh6IO4uqbACo0bztQUmAxfChGRDEoWLqGgMXVcz396Xz+yhHc+auOPr8VNgTU5l9xTy9RxIf7yXgvNrdFy3V4sw4UmZEWyJPGsr9nTI5w2p6RHvV2rg7y9vq0rQCB6e/H9z9Sz4u/N1GxtZ6gf6SDSSSMRkSzp7xTYt46JHiKZ+ByVZ95o5smEpz9GQnDOMaXsPjak4+9lSChERIaZ/obLdWdUsm5zO1c/uGNtpbkNfvKb7UD3h3SVRIx/P6OSsI5tkTRTiIiMEMnCZZcxPf8JX/DFMt7d0MavV+zYAFnf7Cy4fQuTRwfZfWyI5//eTJPWVyQNNN4VGeES11b2mhzm2AOLetQ7+pOFFBVYtwCB6PrKA8/W8/zfm/motp0Ora/IAGgkIjLC9Xf668SDiwHocOebS7ofMPn06820JOxdKQjB//nHEnatDjG2IkBAz1mRJBQiIjmqt3AJmCVdX1m/pZ0rH9ixvtLSBrc/Xt/j64sK4NKTKhhXEdDGSFGIiOSjZAEzeXTPHweXn1LOezXt3P2HHWHS2AKX37eVSAja4s4GK4kYi+ZV6tDJPKMQEZEuiSOUKWNCTBkT6hYiAPM/W8J7NW08FXeLcX2z8y+3bWFCVZANte20tkfLdfBkblOIiEiX/q6vHDIzwiEzI91CBKKL92s/bueDj9u7yrY1RnfdTxkTZMroIE/8pUmPHs4hChER2amBLt4nboycOi7E2k1tvLq6lfh7v+oanXv+WM/k0UEmjQ4yaVRQmyRHGIWIiKSsv+Fy9tGlADS3Rqe84q18p4WnX+9+W3EwAEd+spDJo6LhMr4qSFhrLcOSQkRE0q63cImEe94ZtmheJbX1zkU/q+2q194BT77aRFtH968PBeHzBxQxaVSQiaOCjK0I6JHEWaYQEZEhlSxgqkp7BsHNZ1WxcWsHP7p/a1dZWzv85sVGErdDdobLxKogE0YFGFcR1F1iQ0QhIiLDQuIIJRQ0Jo4K9qh381lVfFTbzjW/2LGnpbdwgeimydOPKGFiVZBxlUEiYYVLOilERGRY6O/6SiRs7Frd80fXzWdVsaG2nfVb2rn9iR23JLe0wR1P9Nw0WRCC0+aUMHFUdM2lUOGSEoWIiAxrAwmXXapD7FId6hYiAFecWs5HWzpY+vvtXWUtbXDnU8nD5ZRDixlfGWRCVZCyIsN05EuvFCIiMiL1tb8kMWAmjQoxaVTPelfPreDDLe0s+V33cPn5nxp61A0H4dhZRYyvDDK+MsDYiqCO1kchIiI5qL+jl/FV0amsRP92egUfbengpt9s6yprbafb8fqdggGYs3eEcZXRu8XGVQYZXZY/d40pREQkb/Q3XEaXBRldlmRR/xvRdZdr/nvHon57Bzz/9xYaW3ou64cCcMR+EcZVBLtCpqo0t05EVoiISN7rb7gUFhi7ju35Y/MnX69kW6OzYWs71/1qx+ilrQOe/mszLW3d6wcM9ts1zNiKIM+83jSiHxCmEBER6UV/w8XMKC82yot7Htly81lV1NY7F8dtpuxw2FTXwetrW7sOqoToMTCX/ryW6vIAb69v67pWEjF+fHolRQXDbwSjEBERGaCBLOoHzBiVZDPlFadWJH1A2NRxIWq2tncLl/pm57zbt1BaaDS0OB2xnfyFBbDgC2VUlwepKLGsTJMpRERE0qi/oxdI/oCws46KnjOWeIjlibOLqKnr4JnXd5yc3NQC//7r6PRZOBhdn+mILc1EwvCNI0sZUxZgTHmQwgyNYhQiIiJDoLdw6W/ofP6AIoBuIQJw/nGlbKrroGZrB4+92tRV3twKtzy6nUTpXndRiIiIDEP9DZd9dynouhYfIgDfP6mcTdvaWfbYjk2VnV+bLgoREZERZCDrMbuPC7H7uFC3EEk3hYiISI4YyHpMuihERERyXCb3nug5lCIikjKFiIiIpEwhIiIiKVOIiIhIyhQiIiKSMoWIiIikzNzTu3txODOzGuC9FL98DLApjc0ZSfK17+p3flG/e7eru1cnu5BXITIYZrbS3Wdlux3ZkK99V7/zi/qdGk1niYhIyhQiIiKSMoVI/y3LdgOyKF/7rn7nF/U7BVoTERGRlGkkIiIiKVOIiIhIynQU/E6YmQHnAA5EgGfcfVV2W5UZZhYA7gD+5O53xZV/Avgs0AQEgSXu3pGVRmaAmRUCi4CjgDBwrbvfEbuW632/AvgKUAVc4+7/ESvP6X53MrMpwIXufl7sfU7328wS1y8Od/dnB9Vvd9efPv4A3wFmxb1fAlRku10Z6GchcCPwCjAvrrwUWBr3/h+Ai7Pd3jT3/crYPyaATwK1wJG53nfgVOAzsdeHAm1AZa73O+Hv4LfAXbHXOd9v4Cyig4cQEEpHvzWd1QczCwJz3X1lXPEKYF52WpRR/wT8K9EQifc14KXON+7+CnCCmRWQA8wsAmxw92cA3P1V4L+AE8jxvgPPu/ufYq+fA14G6sj9fgNgZl8DVscV5UO/W929rfNPrGxQ/VaI9G0O0JBQ9iYwNwttySh3X+7uNUkuzQX+nlC2lejQNxe0Af+ZUOZE/3/P6b67+5q4tycDZ3h0CiOn+w1gZhOACUD8L4g53+9eDKrfCpG+TQU2J5RtjpXni5z+O3D3dndvSij+DHAfOd53iK75mdmVwL8DZ8V++8z5fgPnA4sTyvKh35ea2XYze8vMzoqVDarfWljv21iiv6nGawXGmFnAc2jBrQ+9/R2MzUJbMs7MTgV+6+6rzCwf+j4ZWArcCvw/oj88crrfZnYS8Dt3b47eN9Mlp/sdcw3wFPBpYImZfcwg+60Q6dtGIHFeMAxsypMAgd7/DjZmoS0ZZWa7Ex3CnxMryvm+u/vaztdmthg4nRzut5lVArPd/cIkl3O2353c/Z7Yy1/G7jz9JoPst0Kkb+8SPSY5XhXdF+NyXV78HZhZFdEpjn+J+wUhL/oeZy1QRG73ey5wlJk9H3tfDZTH3udyv5N5HLiWQfZbayJ9ewaoiO2f6DSd6Hx5vrgP2DehrBJ4cuibkhmxfSI/Ai5z95ZYWTl50PcEewBPk8P9dvcl7v4P7j7b3WcDVxOdvpxNDve7F6OAdxhkvxUifYjdAncHcHRc8aeBu7LSoOz4OXBY5xsz2xdY3vnDNkcsAm5093ro2mB6LTncdzOrNrND4t6XA18mx/u9EzndbzM73sziRxzfBf6NQfZbBzDuROwHygKiOznDwIqEfSM5IdbP04kuvL1F9Lfy52PX9gO+AHxMdO70p+7enq22plNskfU+ut/KHQY+cPcZudp3M5tGtN+vAs8DU4C73X117HpO9juRmc0DjnD3ebH3OdtvM/sBcAqwHNhG9GSKFbFrKfdbISIiIinTdJaIiKRMISIiIilTiIiISMoUIiIikjKFiIiIpEwhIiIiKVOIiIhIyhQiIikws33NbKGZuZk9YGZfi/tzjpnVxx5qlqnPD5rZAjNrNbPdMvU5IjujzYYigxB7ZvV8j3smfaz8dnf/xhB8/ntEH3G7JtOfJZKMRiIimfFfQ/Q5+i1QskohIpIB7v6Hztdm9plstkUkkxQiImlmZt+I/W+Fmd0PXG1m3zKz5Wb2gZk9GnvGd2f9PczsajM7xczOM7PrY8fTx1//TzP7sZmtjq3BxD+Sr9LMfmJmr5vZb80sMnS9lXynNRGRQYitidwBPBsrKgC+5+7TY9f/iejjZ7/g7m+YWTHwKLDV3b9oZqOInqL7KXevjX3N14FPu/vZsUD4X+Cb7v4HM5sEnAVc6e5uZmuAXwCXAe1EnwFxm7sP1XSa5Dk92VBk8J6NX1g3s3Fx15qB99z9DQB3bzCzHwFPmlkYmA+s6gyQmHuIPv/6B8AMoLBzeszd1wFXJHz+T+MepvU0MDWNfRPpk6azRNLvoZ1cf43ov71yoj/w18dfjAXCZmBPYAJQu5Pv1xb3ugPI2K3FIokUIiJp5u6vA5jZ54HSJFVKgQ3u/jFJnm8dG6GMBtbErk+NTYOJDDsKEZEMiG00/CawnWggxDsTuDL2+m7gIDMri7t+BnBvbOrqJeCvwCIzC8S+94lmVp3J9ov0l9ZERFJgZrOAfWJvT4q/m4rov6tjgM4f9FVmdj6wEdgV2OzuSwDcfZOZnQhcamargBKiU1hnx667mX0FuAt418zeBZYAtbFHu44GLjKzG2OfexTQZmYPu/uqzPReZAfdnSWSQWZ2BHCFux+Rge8ddPf22O2+5u4d6f4MkZ3RSERkhHL39tj/Otq5LlmiNREREUmZQkQkQ2IbCY8DZprZ6WY2NtttEkk3rYmIZEjsVt0CoJHo3g3r3BQokisUIiIikjJNZ4mISMoUIiIikjKFiIiIpEwhIiIiKfv/r0K1FPW5V2UAAAAASUVORK5CYII=\n",
428 | "text/plain": [
429 | ""
430 | ]
431 | },
432 | "metadata": {
433 | "needs_background": "light"
434 | },
435 | "output_type": "display_data"
436 | }
437 | ],
438 | "source": [
439 | "def train(epochs) :\n",
440 | " losses = []\n",
441 | " train_dataset = load_data()\n",
442 | " w1, b1, w2, b2, w3, b3 = init_paramaters()\n",
443 | " for epoch in range(epochs) :\n",
444 | " loss = train_epoch(epoch, train_dataset, w1, b1, w2, b2, w3, b3, lr = 0.01)\n",
445 | " losses.append(loss)\n",
446 | " x = [i for i in range(0, epochs)]\n",
447 | " # 绘制曲线\n",
448 | " plt.plot(x, losses, color = 'cornflowerblue', marker = 's', label = '训练', markersize='3')\n",
449 | " plt.xlabel('Epoch')\n",
450 | " plt.ylabel('MSE')\n",
451 | " plt.legend()\n",
452 | " plt.savefig('MNIST数据集的前向传播训练误差曲线.png')\n",
453 | " plt.show()\n",
454 | " plt.close()\n",
455 | "\n",
456 | "if __name__ == '__main__' :\n",
457 | "\t# x 轴 0 ~ 50\n",
458 | " train(epochs = 50)"
459 | ]
460 | }
461 | ],
462 | "metadata": {
463 | "kernelspec": {
464 | "display_name": "Python 3 (ipykernel)",
465 | "language": "python",
466 | "name": "python3"
467 | },
468 | "language_info": {
469 | "codemirror_mode": {
470 | "name": "ipython",
471 | "version": 3
472 | },
473 | "file_extension": ".py",
474 | "mimetype": "text/x-python",
475 | "name": "python",
476 | "nbconvert_exporter": "python",
477 | "pygments_lexer": "ipython3",
478 | "version": "3.7.11"
479 | }
480 | },
481 | "nbformat": 4,
482 | "nbformat_minor": 5
483 | }
484 |
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/TensorFlow2.0实现/PyCharm/MNIST数据集的前向传播训练误差曲线.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 03 章 分类问题与信息论基础/全部源码(已更新)/TensorFlow2.0实现/PyCharm/MNIST数据集的前向传播训练误差曲线.png
--------------------------------------------------------------------------------
/第 03 章 分类问题与信息论基础/全部源码(已更新)/TensorFlow2.0实现/PyCharm/第 03 章 TensorFlow2.0 手写数字图片实战 - 神经网络.py:
--------------------------------------------------------------------------------
1 | # 1. 导入模块,并设置图像参数
2 | import os
3 | import matplotlib.pyplot as plt
4 | import tensorflow as tf
5 | import tensorflow.keras.datasets as datasets
6 |
7 | plt.rcParams['font.size'] = 4
8 | plt.rcParams['font.family'] = ['STKaiti']
9 | plt.rcParams['axes.unicode_minus'] = False
10 |
11 |
12 | # 2. 导入数据,并对数据进行简单处理
13 |
14 | def load_data() :
15 | # 加载 MNIST 数据集 元组tuple: (x, y), (x_val, y_val)
16 | (x, y), (x_val, y_val) = datasets.mnist.load_data()
17 | # 将 x 转换为浮点张量,并从 0 ~ 255 缩放到 [0, 1.] - 1 -> [-1, 1] 即缩放到 -1 ~ 1
18 | x = tf.convert_to_tensor(x, dtype = tf.float32) / 255. - 1
19 | # 转换为整形张量
20 | y = tf.convert_to_tensor(y, dtype = tf.int32)
21 | # one-hot 编码
22 | y = tf.one_hot(y, depth = 10)
23 | # 改变视图, [b, 28, 28] => [b, 28*28]
24 | x = tf.reshape(x, (-1, 28 * 28))
25 | # 构建数据集对象
26 | train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
27 | # 批量训练
28 | train_dataset = train_dataset.batch(200)
29 | return train_dataset
30 |
31 |
32 | # 3. 对神经网络参数进行初始化
33 |
34 | def init_paramaters() :
35 | # 每层的张量需要被优化,使用 Variable 类型,并使用截断的正太分布初始化权值张量
36 | # 偏置向量初始化为 0 即可
37 | # 第一层参数
38 | w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev = 0.1))
39 | b1 = tf.Variable(tf.zeros([256]))
40 | # 第二层参数
41 | w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev = 0.1))
42 | b2 = tf.Variable(tf.zeros([128]))
43 | # 第三层参数
44 | w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev = 0.1))
45 | b3 = tf.Variable(tf.zeros([10]))
46 | return w1, b1, w2, b2, w3, b3
47 |
48 |
49 | # 4. 模型训练
50 |
51 | def train_epoch(epoch, train_dataset, w1, b1, w2, b2, w3, b3, lr=0.001):
52 | for step, (x, y) in enumerate( train_dataset ):
53 | with tf.GradientTape() as tape:
54 | # 第一层计算, [b, 784]@[784, 256] + [256] => [b, 256] + [256] => [b,256] + [b, 256]
55 | h1 = x @ w1 + tf.broadcast_to( b1, (x.shape[0], 256) )
56 | # 通过激活函数 relu
57 | h1 = tf.nn.relu( h1 )
58 | # 第二层计算, [b, 256] => [b, 128]
59 | h2 = h1 @ w2 + b2
60 | h2 = tf.nn.relu( h2 )
61 | # 输出层计算, [b, 128] => [b, 10]
62 | out = h2 @ w3 + b3
63 |
64 | # 计算网络输出与标签之间的均方差, mse = mean(sum(y - out) ^ 2)
65 | # [b, 10]
66 | loss = tf.square( y - out )
67 | # 误差标量, mean: scalar
68 | loss = tf.reduce_mean( loss )
69 | # 自动梯度,需要求梯度的张量有[w1, b1, w2, b2, w3, b3]
70 | grads = tape.gradient( loss, [w1, b1, w2, b2, w3, b3] )
71 |
72 | # 梯度更新, assign_sub 将当前值减去参数值,原地更新
73 | w1.assign_sub( lr * grads[0] )
74 | b1.assign_sub( lr * grads[1] )
75 | w2.assign_sub( lr * grads[2] )
76 | b2.assign_sub( lr * grads[3] )
77 | w3.assign_sub( lr * grads[4] )
78 | b3.assign_sub( lr * grads[5] )
79 |
80 | if step % 100 == 0:
81 | print( f"epoch:{epoch}, iteration:{step}, loss:{loss.numpy()}" )
82 |
83 | return loss.numpy()
84 |
85 |
86 | def train(epochs) :
87 | losses = []
88 | train_dataset = load_data()
89 | w1, b1, w2, b2, w3, b3 = init_paramaters()
90 | for epoch in range(epochs) :
91 | loss = train_epoch(epoch, train_dataset, w1, b1, w2, b2, w3, b3, lr = 0.01)
92 | losses.append(loss)
93 | x = [i for i in range(0, epochs)]
94 | # 绘制曲线
95 | plt.plot(x, losses, color = 'cornflowerblue', marker = 's', label = '训练', markersize='5')
96 | plt.xlabel('Epoch')
97 | plt.ylabel('MSE')
98 | plt.legend()
99 | plt.savefig('MNIST数据集的前向传播训练误差曲线.png')
100 | plt.show()
101 | plt.close()
102 |
103 | if __name__ == '__main__' :
104 | # x 轴 0 ~ 50
105 | train(epochs = 50)
--------------------------------------------------------------------------------
/第 04 章 TensorFlow2.0从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 04 章 TensorFlow2.0从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 04 章 TensorFlow2.0从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 04 章 TensorFlow2.0从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 05 章 PyTorch 从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 05 章 PyTorch 从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 05 章 PyTorch 从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 05 章 PyTorch 从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 06 章 神经网络与反向传播算法/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 06 章 神经网络与反向传播算法/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 06 章 神经网络与反向传播算法/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 06 章 神经网络与反向传播算法/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 07 章 过拟合、优化算法与参数优化/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 07 章 过拟合、优化算法与参数优化/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 07 章 过拟合、优化算法与参数优化/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 07 章 过拟合、优化算法与参数优化/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 08 章 卷积神经网络 (CNN) 从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 08 章 卷积神经网络 (CNN) 从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 08 章 卷积神经网络 (CNN) 从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 08 章 卷积神经网络 (CNN) 从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 09 章 循环神经网络 (RNN) 从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 09 章 循环神经网络 (RNN) 从入门到升天/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 09 章 循环神经网络 (RNN) 从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 09 章 循环神经网络 (RNN) 从入门到升天/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 10 章 注意力机制与Transformer/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 10 章 注意力机制与Transformer/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 10 章 注意力机制与Transformer/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 10 章 注意力机制与Transformer/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 11 章 图神经网络(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 11 章 图神经网络(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 11 章 图神经网络(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 11 章 图神经网络(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 12 章 自编码器(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 12 章 自编码器(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 12 章 自编码器(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 12 章 自编码器(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 13 章 生成对抗网络(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 13 章 生成对抗网络(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 13 章 生成对抗网络(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 13 章 生成对抗网络(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 14 章 强化学习(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 14 章 强化学习(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 14 章 强化学习(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 14 章 强化学习(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 15 章 元学习(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 15 章 元学习(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 15 章 元学习(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 15 章 元学习(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 16 章 对抗攻击与防御(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 16 章 对抗攻击与防御(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 16 章 对抗攻击与防御(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 16 章 对抗攻击与防御(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 17 章 迁移学习(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 17 章 迁移学习(万字综述)/PDF文件(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------
/第 17 章 迁移学习(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanfansann/fanfan-deep-learning-note/64f27dc1e146505171b4f59e3610be7cd6b91d23/第 17 章 迁移学习(万字综述)/全部源码(待更)/正文很快更新哟^q^.pdf
--------------------------------------------------------------------------------