├── 01_linear_regression.ipynb
├── 02_MNIST_classification.ipynb
├── 03_cnn_image_classification.ipynb
├── 04_LSTM_sentiment_analysis.ipynb
├── 06_GAN_image_generation.ipynb
├── 08_few_shot_learning.ipynb
├── 101_mosaic_video.ipynb
├── 102_Wav2Lip_Inference.ipynb
├── 201_native_rag.ipynb
├── README.en.md
├── README.md
└── asserts
    ├── images
    │   ├── 01_01.png
    │   ├── 01_02.png
    │   ├── 06_01.png
    │   ├── 101_01.png
    │   ├── 101_02.png
    │   ├── 101_03.png
    │   ├── 101_04.png
    │   ├── 101_05.png
    │   └── 101_06.png
    └── mp4
        └── kunkun.mp4
/01_linear_regression.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "PyTorch Hands-On for Beginners (1): Implementing Linear Regression"
7 | ],
8 | "metadata": {
9 | "collapsed": false,
10 | "pycharm": {
11 | "name": "#%% md\n"
12 | }
13 | }
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "collapsed": true,
19 | "pycharm": {
20 | "name": "#%% md\n"
21 | }
22 | },
23 | "source": [
24 | "# Topics covered\n",
25 | "[Basic usage of PyTorch nn.Module](https://blog.csdn.net/zhaohongfei_358/article/details/122797244)\n",
26 | "\n",
27 | "[Basic usage of PyTorch nn.Linear](https://blog.csdn.net/zhaohongfei_358/article/details/122797190)"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "source": [
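33 | "# Expressing linear regression as a neural network\n",
34 | "\n",
35 | "Linear regression can be viewed as a very simple neural network. Take simple linear regression with a single feature as an example:\n",
36 | "\n",
37 | "$$\n",
38 | "y = w \\cdot x + b\n",
39 | "$$\n",
40 | "\n",
41 | "It can be drawn as the neural network below:\n",
42 | "\n",
43 | " \n",
44 | "\n",
45 | "If x is generalized to a vector, i.e. $x=(x_1, x_2, ... , x_n)$, the corresponding neural network is:\n",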
46 | "\n",
47 | " "
48 | ],
49 | "metadata": {
50 | "collapsed": false,
51 | "pycharm": {
52 | "name": "#%% md\n"
53 | }
54 | }
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "source": [
59 | "# PyTorch implementation\n",
60 | "\n",
61 | "## Simple linear regression in PyTorch"
62 | ],
63 | "metadata": {
64 | "collapsed": false,
65 | "pycharm": {
66 | "name": "#%% md\n"
67 | }
68 | }
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 2,
73 | "outputs": [],
74 | "source": [
75 | "import torch\n",
76 | "import matplotlib.pyplot as plt"
77 | ],
78 | "metadata": {
79 | "collapsed": false,
80 | "pycharm": {
81 | "name": "#%%\n"
82 | }
83 | }
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "source": [
88 | "First, generate some synthetic data:"
89 | ],
90 | "metadata": {
91 | "collapsed": false,
92 | "pycharm": {
93 | "name": "#%% md\n"
94 | }
95 | }
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 3,
100 | "outputs": [
101 | {
102 | "data": {
103 | "text/plain": "tensor([[7.2543],\n [0.8824],\n [8.1629]])"
104 | },
105 | "execution_count": 3,
106 | "metadata": {},
107 | "output_type": "execute_result"
108 | }
109 | ],
110 | "source": [
111 | "X = torch.rand(100, 1) * 10  # 100 rows, 1 column; values drawn uniformly from [0, 10)\n",
112 | "X[:3]"
113 | ],
114 | "metadata": {
115 | "collapsed": false,
116 | "pycharm": {
117 | "name": "#%%\n"
118 | }
119 | }
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 5,
124 | "outputs": [
125 | {
126 | "data": {
127 | "text/plain": "tensor([[33.6710],\n [13.3356],\n [36.9041]])"
128 | },
129 | "execution_count": 5,
130 | "metadata": {},
131 | "output_type": "execute_result"
132 | }
133 | ],
134 | "source": [
135 | "y = 3 * X + 10 + torch.randn(100, 1) * 3  # corresponding y values plus Gaussian noise (std 3); y is also 100 rows by 1 column\n",
136 | "y[:3]"
137 | ],
138 | "metadata": {
139 | "collapsed": false,
140 | "pycharm": {
141 | "name": "#%%\n"
142 | }
143 | }
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "source": [
148 | "Plot the generated data as a scatter plot to see what it looks like:"
149 | ],
150 | "metadata": {
151 | "collapsed": false,
152 | "pycharm": {
153 | "name": "#%% md\n"
154 | }
155 | }
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 6,
160 | "outputs": [
161 | {
162 | "data": {
163 | "text/plain": "",
164 | "image/png": "\n"
165 | },
166 | "metadata": {
167 | "needs_background": "light"
168 | },
169 | "output_type": "display_data"
170 | }
171 | ],
172 | "source": [
173 | "plt.scatter(X.numpy(), y.numpy())\n",
174 | "plt.show()"
175 | ],
176 | "metadata": {
177 | "collapsed": false,
178 | "pycharm": {
179 | "name": "#%%\n"
180 | }
181 | }
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "source": [
186 | "Next, define the linear regression model:"
187 | ],
188 | "metadata": {
189 | "collapsed": false,
190 | "pycharm": {
191 | "name": "#%% md\n"
192 | }
193 | }
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 9,
198 | "outputs": [],
199 | "source": [
200 | "class LinearRegression(torch.nn.Module):\n",
201 | "    \"\"\"\n",
202 | "    The model subclasses `torch.nn.Module`; in PyTorch, every model inherits from this class.\n",
203 | "    \"\"\"\n",
204 | "\n",
205 | "    def __init__(self):\n",
206 | "        super().__init__()  # initialize the Module base class\n",
207 | "\n",
208 | "        \"\"\"\n",
209 | "        Define the first (and only) layer of the network, a linear layer. Its three main arguments:\n",
210 | "        in_features: number of input neurons\n",
211 | "        out_features: number of output neurons\n",
212 | "        bias: whether to include a bias term\n",
213 | "\n",
214 | "        For more on torch.nn.Linear, see: https://pytorch.org/docs/stable/nn.html#linear-layers\n",
215 | "        \"\"\"\n",
216 | "        self.linear = torch.nn.Linear(in_features=1, out_features=1, bias=True)\n",
217 | "\n",
218 | "    def forward(self, x):\n",
219 | "        \"\"\"\n",
220 | "        Forward pass: compute the network's output.\n",
221 | "        \"\"\"\n",
222 | "        predict = self.linear(x)\n",
223 | "        return predict"
224 | ],
225 | "metadata": {
226 | "collapsed": false,
227 | "pycharm": {
228 | "name": "#%%\n"
229 | }
230 | }
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "source": [
235 | "The model is now defined. Instantiate it:"
236 | ],
237 | "metadata": {
238 | "collapsed": false,
239 | "pycharm": {
240 | "name": "#%% md\n"
241 | }
242 | }
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 10,
247 | "outputs": [],
248 | "source": [
249 | "model = LinearRegression()  # instantiate the model"
250 | ],
251 | "metadata": {
252 | "collapsed": false,
253 | "pycharm": {
254 | "name": "#%%\n"
255 | }
256 | }
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "source": [
261 | "Define the optimizer; here we use stochastic gradient descent (SGD):"
262 | ],
263 | "metadata": {
264 | "collapsed": false,
265 | "pycharm": {
266 | "name": "#%% md\n"
267 | }
268 | }
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 11,
273 | "outputs": [
274 | {
275 | "name": "stdout",
276 | "output_type": "stream",
277 | "text": [
278 | "Parameter containing:\n",
279 | "tensor([[-0.7801]], requires_grad=True)\n",
280 | "Parameter containing:\n",
281 | "tensor([0.3026], requires_grad=True)\n"
282 | ]
283 | }
284 | ],
285 | "source": [
286 | "\"\"\"\n",
287 | "torch.optim.SGD takes several important arguments:\n",
288 | "- params: the model parameters to optimize\n",
289 | "- lr: the learning rate\n",
290 | "\"\"\"\n",
291 | "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n",
292 | "\n",
293 | "# Inspect the model parameters\n",
294 | "for param in model.parameters():  # the model has several parameters, so model.parameters() returns an iterable\n",
295 | "    print(param)"
296 | ],
297 | "metadata": {
298 | "collapsed": false,
299 | "pycharm": {
300 | "name": "#%%\n"
301 | }
302 | }
303 | },
304 | {
305 | "cell_type": "markdown",
306 | "source": [
307 | "Define the loss function; here we use mean squared error (MSE):"
308 | ],
309 | "metadata": {
310 | "collapsed": false,
311 | "pycharm": {
312 | "name": "#%% md\n"
313 | }
314 | }
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 12,
319 | "outputs": [],
320 | "source": [
321 | "loss_function = torch.nn.MSELoss()"
322 | ],
323 | "metadata": {
324 | "collapsed": false,
325 | "pycharm": {
326 | "name": "#%%\n"
327 | }
328 | }
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "source": [
333 | "Now the model can be trained:"
334 | ],
335 | "metadata": {
336 | "collapsed": false,
337 | "pycharm": {
338 | "name": "#%% md\n"
339 | }
340 | }
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 14,
345 | "outputs": [],
346 | "source": [
347 | "for epoch in range(10000):  # train for 10,000 iterations\n",
348 | "    \"\"\"\n",
349 | "    1. Feed X to the model; this invokes forward() and computes the predicted y for every x.\n",
350 | "       X.shape and predict_y.shape are both (100, 1).\n",
351 | "    \"\"\"\n",
352 | "    predict_y = model(X)\n",
353 | "\n",
354 | "    \"\"\"\n",
355 | "    2. Compute the loss with the loss function\n",
356 | "    \"\"\"\n",
357 | "    loss = loss_function(predict_y, y)\n",
358 | "\n",
359 | "    \"\"\"\n",
360 | "    3. Backpropagate to compute the gradients\n",
361 | "    \"\"\"\n",
362 | "    loss.backward()\n",
363 | "\n",
364 | "    \"\"\"\n",
365 | "    4. Update the weights\n",
366 | "    \"\"\"\n",
367 | "    optimizer.step()\n",
368 | "\n",
369 | "    \"\"\"\n",
370 | "    5. Clear the gradients; otherwise they accumulate and distort the next iteration\n",
371 | "    \"\"\"\n",
372 | "    optimizer.zero_grad()"
373 | ],
374 | "metadata": {
375 | "collapsed": false,
376 | "pycharm": {
377 | "name": "#%%\n"
378 | }
379 | }
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "source": [
384 | "Inspect the final parameters; they are close to the true values used to generate the data (w = 3, b = 10):"
385 | ],
386 | "metadata": {
387 | "collapsed": false,
388 | "pycharm": {
389 | "name": "#%% md\n"
390 | }
391 | }
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 16,
396 | "outputs": [
397 | {
398 | "name": "stdout",
399 | "output_type": "stream",
400 | "text": [
401 | "Parameter containing:\n",
402 | "tensor([[3.0832]], requires_grad=True)\n",
403 | "Parameter containing:\n",
404 | "tensor([9.7287], requires_grad=True)\n"
405 | ]
406 | }
407 | ],
408 | "source": [
409 | "for param in model.parameters():  # the model has several parameters, so model.parameters() returns an iterable\n",
410 | "    print(param)"
411 | ],
412 | "metadata": {
413 | "collapsed": false,
414 | "pycharm": {
415 | "name": "#%%\n"
416 | }
417 | }
418 | },
419 | {
420 | "cell_type": "markdown",
421 | "source": [
422 | "Plot the data again together with the fitted line to see the final result:"
423 | ],
424 | "metadata": {
425 | "collapsed": false,
426 | "pycharm": {
427 | "name": "#%% md\n"
428 | }
429 | }
430 | },
431 | {
432 | "cell_type": "code",
433 | "execution_count": 17,
434 | "outputs": [
435 | {
436 | "data": {
437 | "text/plain": "",
438 | "image/png": "\n"
439 | },
440 | "metadata": {
441 | "needs_background": "light"
442 | },
443 | "output_type": "display_data"
444 | }
445 | ],
446 | "source": [
447 | "plt.scatter(X.numpy(), y.numpy())\n",
448 | "plt.plot(X.numpy(), model(X).detach().numpy(), color='red')\n",
449 | "plt.show()"
450 | ],
451 | "metadata": {
452 | "collapsed": false,
453 | "pycharm": {
454 | "name": "#%%\n"
455 | }
456 | }
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": null,
461 | "outputs": [],
462 | "source": [],
463 | "metadata": {
464 | "collapsed": false,
465 | "pycharm": {
466 | "name": "#%%\n"
467 | }
468 | }
469 | }
470 | ],
471 | "metadata": {
472 | "kernelspec": {
473 | "display_name": "Python 3",
474 | "language": "python",
475 | "name": "python3"
476 | },
477 | "language_info": {
478 | "codemirror_mode": {
479 | "name": "ipython",
480 | "version": 2
481 | },
482 | "file_extension": ".py",
483 | "mimetype": "text/x-python",
484 | "name": "python",
485 | "nbconvert_exporter": "python",
486 | "pygments_lexer": "ipython2",
487 | "version": "2.7.6"
488 | }
489 | },
490 | "nbformat": 4,
491 | "nbformat_minor": 0
492 | }
--------------------------------------------------------------------------------
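The five steps of the training loop in `01_linear_regression.ipynb` (forward pass, loss, backward pass, weight update, gradient reset) can be sketched without PyTorch to show what each step computes. This is a minimal illustration in plain Python, not how PyTorch is implemented internally: the gradients of the MSE loss are written out by hand, and the names `xs`, `ys`, `mse`, the learning rate `lr = 1e-2`, and the 5000 iterations are assumptions chosen for this sketch rather than values from the notebook.

```python
import random

random.seed(0)

# Synthetic data following y = 3x + 10 plus Gaussian noise (std 3), as in the notebook.
xs = [random.uniform(0, 10) for _ in range(100)]
ys = [3 * x + 10 + random.gauss(0, 3) for x in xs]

w, b = 0.0, 0.0  # parameters, playing the role of model.parameters()
lr = 1e-2        # learning rate (assumed value for this sketch)
n = len(xs)


def mse(w, b):
    """Mean squared error of the line y = w*x + b on the data."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n


initial_loss = mse(w, b)

for epoch in range(5000):
    # 1. Forward pass: predictions for every x.
    preds = [w * x + b for x in xs]

    # 2 + 3. Loss is mean((pred - y)^2); "backward" computes its gradients:
    #        dL/dw = mean(2 * (pred - y) * x), dL/db = mean(2 * (pred - y)).
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n

    # 4. Update step: what optimizer.step() does for plain SGD.
    w -= lr * grad_w
    b -= lr * grad_b

    # 5. No explicit reset is needed here because the gradients are recomputed
    #    from scratch each iteration; PyTorch accumulates into .grad instead,
    #    which is why the notebook calls optimizer.zero_grad().

final_loss = mse(w, b)
print(f"w = {w:.3f}, b = {b:.3f}, loss {initial_loss:.2f} -> {final_loss:.2f}")
```

Because every iteration uses all 100 points, this is full-batch gradient descent; the notebook's loop behaves the same way even though the optimizer class is named SGD.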
/02_MNIST_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "# PyTorch Hands-On for Beginners (2): MNIST Handwritten Digit Recognition with a BP Neural Network"
7 | ],
8 | "metadata": {
9 | "collapsed": false,
10 | "pycharm": {
11 | "name": "#%% md\n"
12 | }
13 | }
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "collapsed": true,
19 | "pycharm": {
20 | "name": "#%% md\n"
21 | }
22 | },
23 | "source": [
24 | "# Topics covered\n",
25 | "\n",
26 | "[Basic usage of PyTorch nn.Module](https://blog.csdn.net/zhaohongfei_358/article/details/122797244)\n",
27 | "\n",
28 | "[Basic usage of PyTorch nn.Linear](https://blog.csdn.net/zhaohongfei_358/article/details/122797190)\n",
29 | "\n",
30 | "[Basic usage of torchvision Transforms](https://blog.csdn.net/zhaohongfei_358/article/details/122799782)\n",
31 | "\n",
32 | "[Basic usage of DataLoader in PyTorch](https://blog.csdn.net/zhaohongfei_358/article/details/122742656)\n",
33 | "\n",
34 | "[NLLLoss and CrossEntropyLoss in PyTorch explained](https://blog.csdn.net/qq_22210253/article/details/85229988)\n",
35 | "\n",
36 | "[How to choose the number of layers and hidden neurons in a neural network](https://zhuanlan.zhihu.com/p/100419971)"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "source": [
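42 | "# Overview\n",
43 | "\n",
44 | "This article uses a BP (backpropagation) neural network, i.e. a plain fully connected network, to classify the MNIST handwritten digit dataset. Let's get started.\n",
45 | "\n",
46 | "The environment used in this article:\n",
47 | "\n",
48 | "```\n",
49 | "python==3.8.5\n",
50 | "torch==1.10.2\n",
51 | "torchvision==0.11.3\n",
52 | "matplotlib==3.2.2\n",
53 | "```\n",
54 | "\n",
55 | "First, import the required packages:"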
56 | ],
57 | "metadata": {
58 | "collapsed": false,
59 | "pycharm": {
60 | "name": "#%% md\n"
61 | }
62 | }
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "outputs": [],
68 | "source": [
69 | "import os\n",
70 | "import torch\n",
71 | "import matplotlib.pyplot as plt\n",
72 | "from time import time\n",
73 | "from torchvision import datasets, transforms\n",
74 | "from torch import nn, optim"
75 | ],
76 | "metadata": {
77 | "collapsed": false,
78 | "pycharm": {
79 | "name": "#%%\n"
80 | }
81 | }
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "source": [
86 | "Define a transform object, which specifies how each image in the dataset is preprocessed:"
87 | ],
88 | "metadata": {
89 | "collapsed": false,
90 | "pycharm": {
91 | "name": "#%% md\n"
92 | }
93 | }
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 4,
98 | "outputs": [],
99 | "source": [
100 | "transform = transforms.Compose([transforms.ToTensor(),\n",
101 | " transforms.Normalize((0.5,), (0.5,)),])"
102 | ],
103 | "metadata": {
104 | "collapsed": false,
105 | "pycharm": {
106 | "name": "#%%\n"
107 | }
108 | }
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "source": [
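113 | "Download and load the training set using the dataset API provided by torchvision. If the download fails, you can fetch the data from this [Baidu Netdisk link](https://pan.baidu.com/s/1NmxIlPhaeKSz_kFwCOn6rA?pwd=6hfa) and extract it."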
114 | ],
115 | "metadata": {
116 | "collapsed": false,
117 | "pycharm": {
118 | "name": "#%% md\n"
119 | }
120 | }
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 6,
125 | "outputs": [
126 | {
127 | "data": {
128 | "text/plain": "Dataset MNIST\n Number of datapoints: 60000\n Root location: train_set\n Split: Train\n StandardTransform\nTransform: Compose(\n ToTensor()\n Normalize(mean=(0.5,), std=(0.5,))\n )"
129 | },
130 | "execution_count": 6,
131 | "metadata": {},
132 | "output_type": "execute_result"
133 | }
134 | ],
135 | "source": [
136 | "train_set = datasets.MNIST('train_set',  # directory to download into\n",
137 | "                           download=not os.path.exists('train_set'),  # download only if the directory does not exist yet\n",
138 | "                           train=True,  # load the training split\n",
139 | "                           transform=transform  # transform applied to each image\n",
140 | "                           )\n",
141 | "train_set"
142 | ],
143 | "metadata": {
144 | "collapsed": false,
145 | "pycharm": {
146 | "name": "#%%\n"
147 | }
148 | }
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "source": [
153 | "Once the download finishes, you can see that the training set contains 60,000 samples. Next, download the test set:"
154 | ],
155 | "metadata": {
156 | "collapsed": false,
157 | "pycharm": {
158 | "name": "#%% md\n"
159 | }
160 | }
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 7,
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
171 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to test_set\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
172 | ]
173 | },
174 | {
175 | "data": {
176 | "text/plain": " 0%| | 0/9912422 [00:00, ?it/s]",
177 | "application/vnd.jupyter.widget-view+json": {
178 | "version_major": 2,
179 | "version_minor": 0,
180 | "model_id": "b1a7e1c62c8b45479ef7003bd82b6c67"
181 | }
182 | },
183 | "metadata": {},
184 | "output_type": "display_data"
185 | },
186 | {
187 | "name": "stdout",
188 | "output_type": "stream",
189 | "text": [
190 | "Extracting test_set\\MNIST\\raw\\train-images-idx3-ubyte.gz to test_set\\MNIST\\raw\n",
191 | "\n",
192 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
193 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
194 | ]
195 | },
196 | {
197 | "data": {
198 | "text/plain": " 0%| | 0/28881 [00:00, ?it/s]",
199 | "application/vnd.jupyter.widget-view+json": {
200 | "version_major": 2,
201 | "version_minor": 0,
202 | "model_id": "1f24ed72115e43a0bf2401eb904345e7"
203 | }
204 | },
205 | "metadata": {},
206 | "output_type": "display_data"
207 | },
208 | {
209 | "name": "stdout",
210 | "output_type": "stream",
211 | "text": [
212 | "Extracting test_set\\MNIST\\raw\\train-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\n",
213 | "\n",
214 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
215 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to test_set\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
216 | ]
217 | },
218 | {
219 | "data": {
220 | "text/plain": " 0%| | 0/1648877 [00:00, ?it/s]",
221 | "application/vnd.jupyter.widget-view+json": {
222 | "version_major": 2,
223 | "version_minor": 0,
224 | "model_id": "60b31908556446d2b9352c3cee8c556e"
225 | }
226 | },
227 | "metadata": {},
228 | "output_type": "display_data"
229 | },
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "Extracting test_set\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to test_set\\MNIST\\raw\n",
235 | "\n",
236 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
237 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
238 | ]
239 | },
240 | {
241 | "data": {
242 | "text/plain": " 0%| | 0/4542 [00:00, ?it/s]",
243 | "application/vnd.jupyter.widget-view+json": {
244 | "version_major": 2,
245 | "version_minor": 0,
246 | "model_id": "f96fb86f6acd476e9f2dd6916882da5a"
247 | }
248 | },
249 | "metadata": {},
250 | "output_type": "display_data"
251 | },
252 | {
253 | "name": "stdout",
254 | "output_type": "stream",
255 | "text": [
256 | "Extracting test_set\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\n",
257 | "\n"
258 | ]
259 | },
260 | {
261 | "data": {
262 | "text/plain": "Dataset MNIST\n Number of datapoints: 10000\n Root location: test_set\n Split: Test\n StandardTransform\nTransform: Compose(\n ToTensor()\n Normalize(mean=(0.5,), std=(0.5,))\n )"
263 | },
264 | "execution_count": 7,
265 | "metadata": {},
266 | "output_type": "execute_result"
267 | }
268 | ],
269 | "source": [
270 | "test_set = datasets.MNIST('test_set',\n",
271 | " download=not os.path.exists('test_set'),\n",
272 | " train=False,\n",
273 | " transform=transform\n",
274 | " )\n",
275 | "test_set"
276 | ],
277 | "metadata": {
278 | "collapsed": false,
279 | "pycharm": {
280 | "name": "#%%\n"
281 | }
282 | }
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "source": [
287 | "The test set contains 10,000 samples.\n",
288 | "\n",
289 | "Next, build DataLoader objects for the training and test sets:"
290 | ],
291 | "metadata": {
292 | "collapsed": false,
293 | "pycharm": {
294 | "name": "#%% md\n"
295 | }
296 | }
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 8,
301 | "outputs": [
302 | {
303 | "name": "stdout",
304 | "output_type": "stream",
305 | "text": [
306 | "torch.Size([64, 1, 28, 28])\n",
307 | "torch.Size([64])\n"
308 | ]
309 | }
310 | ],
311 | "source": [
312 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
313 | "test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)\n",
314 | "\n",
315 | "dataiter = iter(train_loader)\n",
316 | "images, labels = next(dataiter)  # the built-in next() works across PyTorch versions\n",
317 | "\n",
318 | "print(images.shape)\n",
319 | "print(labels.shape)"
320 | ],
321 | "metadata": {
322 | "collapsed": false,
323 | "pycharm": {
324 | "name": "#%%\n"
325 | }
326 | }
327 | },
328 | {
329 | "cell_type": "markdown",
330 | "source": [
331 | "Above, the data is split into batches of 64 images; each image has a single channel (grayscale) and is 28x28 pixels. Let's plot one:"
332 | ],
333 | "metadata": {
334 | "collapsed": false,
335 | "pycharm": {
336 | "name": "#%% md\n"
337 | }
338 | }
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 9,
343 | "outputs": [
344 | {
345 | "data": {
346 | "text/plain": "",
347 | "image/png": "\n"
348 | },
349 | "metadata": {
350 | "needs_background": "light"
351 | },
352 | "output_type": "display_data"
353 | }
354 | ],
355 | "source": [
356 | "plt.imshow(images[0].numpy().squeeze(), cmap='gray_r');"
357 | ],
358 | "metadata": {
359 | "collapsed": false,
360 | "pycharm": {
361 | "name": "#%%\n"
362 | }
363 | }
364 | },
365 | {
366 | "cell_type": "markdown",
367 | "source": [
368 | "That completes the preparation.\n",
369 | "\n",
370 | "---\n",
371 | "\n",
372 | "Now define the neural network:"
373 | ],
374 | "metadata": {
375 | "collapsed": false,
376 | "pycharm": {
377 | "name": "#%% md\n"
378 | }
379 | }
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": 10,
384 | "outputs": [],
385 | "source": [
386 | "class NeuralNetwork(nn.Module):\n",
387 | "\n",
388 | "    def __init__(self):\n",
389 | "        super().__init__()\n",
390 | "\n",
391 | "        \"\"\"\n",
392 | "        Define the first linear layer.\n",
393 | "        Input: a flattened image (28x28 = 784 values);\n",
394 | "        output: the input to the first hidden layer, of size 128.\n",
395 | "        \"\"\"\n",
396 | "        self.linear1 = nn.Linear(28 * 28, 128)\n",
397 | "        # ReLU activation for the first hidden layer\n",
398 | "        self.relu1 = nn.ReLU()\n",
399 | "        \"\"\"\n",
400 | "        Define the second linear layer.\n",
401 | "        Input: the output of the first hidden layer;\n",
402 | "        output: the input to the second hidden layer, of size 64.\n",
403 | "        \"\"\"\n",
404 | "        self.linear2 = nn.Linear(128, 64)\n",
405 | "        # ReLU activation for the second hidden layer\n",
406 | "        self.relu2 = nn.ReLU()\n",
407 | "        \"\"\"\n",
408 | "        Define the third linear layer.\n",
409 | "        Input: the output of the second hidden layer;\n",
410 | "        output: the output layer, of size 10 (one value per digit).\n",
411 | "        \"\"\"\n",
412 | "        self.linear3 = nn.Linear(64, 10)\n",
413 | "        # Normalize the final output with log-softmax\n",
414 | "        self.softmax = nn.LogSoftmax(dim=1)\n",
415 | "\n",
416 | "        # The layers above can equivalently be written with nn.Sequential (shown for illustration only; forward() below does not use it):\n",
417 | "        self.model = nn.Sequential(nn.Linear(28 * 28, 128),\n",
418 | "                                   nn.ReLU(),\n",
419 | "                                   nn.Linear(128, 64),\n",
420 | "                                   nn.ReLU(),\n",
421 | "                                   nn.Linear(64, 10),\n",
422 | "                                   nn.LogSoftmax(dim=1)\n",
423 | "                                   )\n",
424 | "\n",
425 | "    def forward(self, x):\n",
426 | "        \"\"\"\n",
427 | "        Forward pass of the network.\n",
428 | "        x: image batch of shape (64, 1, 28, 28)\n",
429 | "        \"\"\"\n",
430 | "        # Flatten x to shape (64, 784)\n",
431 | "        x = x.view(x.shape[0], -1)\n",
432 | "\n",
433 | "        # Run the forward pass layer by layer\n",
434 | "        x = self.linear1(x)\n",
435 | "        x = self.relu1(x)\n",
436 | "        x = self.linear2(x)\n",
437 | "        x = self.relu2(x)\n",
438 | "        x = self.linear3(x)\n",
439 | "        x = self.softmax(x)\n",
440 | "\n",
441 | "        # The sequence above could be replaced by x = self.model(x).\n",
442 | "\n",
443 | "        return x"
444 | ],
445 | "metadata": {
446 | "collapsed": false,
447 | "pycharm": {
448 | "name": "#%%\n"
449 | }
450 | }
451 | },
452 | {
453 | "cell_type": "code",
454 | "execution_count": 11,
455 | "outputs": [],
456 | "source": [
457 | "model = NeuralNetwork()"
458 | ],
459 | "metadata": {
460 | "collapsed": false,
461 | "pycharm": {
462 | "name": "#%%\n"
463 | }
464 | }
465 | },
466 | {
467 | "cell_type": "markdown",
468 | "source": [
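469 | "With the network defined, define the loss function. Here we use the **negative log-likelihood** loss (`NLLLoss`, [negative log likelihood loss](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html)), which is commonly used for classification. It expects log-probabilities as input, which is why the network ends with `LogSoftmax`. [See this article for details](https://blog.csdn.net/qq_22210253/article/details/85229988)."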
470 | ],
471 | "metadata": {
472 | "collapsed": false,
473 | "pycharm": {
474 | "name": "#%% md\n"
475 | }
476 | }
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": 12,
481 | "outputs": [],
482 | "source": [
483 | "criterion = nn.NLLLoss()"
484 | ],
485 | "metadata": {
486 | "collapsed": false,
487 | "pycharm": {
488 | "name": "#%%\n"
489 | }
490 | }
491 | },
492 | {
493 | "cell_type": "markdown",
494 | "source": [
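495 | "Next, define the optimizer. Here we use stochastic gradient descent with a learning rate of 0.003 and momentum set to 0.9 (the default in `torch.optim.SGD` is 0). Momentum accumulates past gradients to speed up convergence and damp oscillations; it is not a regularizer against overfitting."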
496 | ],
497 | "metadata": {
498 | "collapsed": false,
499 | "pycharm": {
500 | "name": "#%% md\n"
501 | }
502 | }
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": 13,
507 | "outputs": [],
508 | "source": [
509 | "optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)"
510 | ],
511 | "metadata": {
512 | "collapsed": false,
513 | "pycharm": {
514 | "name": "#%%\n"
515 | }
516 | }
517 | },
518 | {
519 | "cell_type": "markdown",
520 | "source": [
521 | "Everything is ready; start training:"
522 | ],
523 | "metadata": {
524 | "collapsed": false,
525 | "pycharm": {
526 | "name": "#%% md\n"
527 | }
528 | }
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": 14,
533 | "outputs": [
534 | {
535 | "name": "stdout",
536 | "output_type": "stream",
537 | "text": [
538 | "Epoch 0 - Training loss: 0.6294137474252726\n",
539 | "Epoch 1 - Training loss: 0.27885234054884933\n",
540 | "Epoch 2 - Training loss: 0.2180362274207032\n",
541 | "Epoch 3 - Training loss: 0.17646600610650043\n",
542 | "Epoch 4 - Training loss: 0.14901786734228895\n",
543 | "Epoch 5 - Training loss: 0.12897429347081957\n",
544 | "Epoch 6 - Training loss: 0.11274210547309504\n",
545 | "Epoch 7 - Training loss: 0.10064082649717135\n",
546 | "Epoch 8 - Training loss: 0.09091206552532277\n",
547 | "Epoch 9 - Training loss: 0.08191311532861865\n",
548 | "Epoch 10 - Training loss: 0.07508732156971021\n",
549 | "Epoch 11 - Training loss: 0.07009464516681728\n",
550 | "Epoch 12 - Training loss: 0.0649078527074764\n",
551 | "Epoch 13 - Training loss: 0.06004000982112769\n",
552 | "Epoch 14 - Training loss: 0.054164604703361575\n",
553 | "\n",
554 | "Training Time (in minutes) = 0.9925608317057292\n"
555 | ]
556 | }
557 | ],
558 | "source": [
559 | "time0 = time()  # record the start time\n",
560 | "epochs = 15  # train for 15 epochs\n",
561 | "for e in range(epochs):\n",
562 | "    running_loss = 0  # accumulated loss for this epoch\n",
563 | "    for images, labels in train_loader:\n",
564 | "        # Forward pass to get the predictions\n",
565 | "        output = model(images)\n",
566 | "\n",
567 | "        # Compute the loss\n",
568 | "        loss = criterion(output, labels)\n",
569 | "\n",
570 | "        # Backpropagate\n",
571 | "        loss.backward()\n",
572 | "\n",
573 | "        # Update the weights\n",
574 | "        optimizer.step()\n",
575 | "\n",
576 | "        # Clear the gradients\n",
577 | "        optimizer.zero_grad()\n",
578 | "\n",
579 | "        # Accumulate the loss\n",
580 | "        running_loss += loss.item()\n",
581 | "\n",
582 | "    # After each epoch, print the average loss per batch\n",
583 | "    print(\"Epoch {} - Training loss: {}\".format(e, running_loss/len(train_loader)))\n",
584 | "\n",
585 | "# Print the total training time\n",
586 | "print(\"\\nTraining Time (in minutes) =\",(time()-time0)/60)"
587 | ],
588 | "metadata": {
589 | "collapsed": false,
590 | "pycharm": {
591 | "name": "#%%\n"
592 | }
593 | }
594 | },
595 | {
596 | "cell_type": "markdown",
597 | "source": [
598 | "On my machine, training took about one minute in the run shown above. The loss decreases steadily from epoch to epoch."
599 | ],
600 | "metadata": {
601 | "collapsed": false,
602 | "pycharm": {
603 | "name": "#%% md\n"
604 | }
605 | }
606 | },
607 | {
608 | "cell_type": "markdown",
609 | "source": [
610 | "Next, evaluate the model:"
611 | ],
612 | "metadata": {
613 | "collapsed": false,
614 | "pycharm": {
615 | "name": "#%% md\n"
616 | }
617 | }
618 | },
619 | {
620 | "cell_type": "code",
621 | "execution_count": 15,
622 | "outputs": [
623 | {
624 | "name": "stdout",
625 | "output_type": "stream",
626 | "text": [
627 | "Number Of Images Tested = 10000\n",
628 | "\n",
629 | "Model Accuracy = 0.97\n"
630 | ]
631 | }
632 | ],
633 | "source": [
634 | "correct_count, all_count = 0, 0\n",
635 | "model.eval()  # switch the model to evaluation mode\n",
636 | "\n",
637 | "# Load images from test_loader batch by batch\n",
638 | "for images, labels in test_loader:\n",
639 | "    # Check each image in the batch\n",
640 | "    for i in range(len(labels)):\n",
641 | "        logps = model(images[i])  # forward pass; returns log-probabilities of shape (1, 10)\n",
642 | "        probab = list(logps.detach().numpy()[0])  # log-probabilities of the 10 digits; [0] selects the single image passed in this call\n",
643 | "        pred_label = probab.index(max(probab))  # the index of the largest value is the predicted digit\n",
644 | "        true_label = labels.numpy()[i]\n",
645 | "        if true_label == pred_label:  # check whether the prediction is correct\n",
646 | "            correct_count += 1\n",
647 | "        all_count += 1\n",
648 | "\n",
649 | "print(\"Number Of Images Tested =\", all_count)\n",
650 | "print(\"\\nModel Accuracy =\", (correct_count/all_count))"
651 | ],
652 | "metadata": {
653 | "collapsed": false,
654 | "pycharm": {
655 | "name": "#%%\n"
656 | }
657 | }
658 | },
659 | {
660 | "cell_type": "markdown",
661 | "source": [
662 | "In the end, this training run reaches an accuracy of about 97% on the test set."
663 | ],
664 | "metadata": {
665 | "collapsed": false,
666 | "pycharm": {
667 | "name": "#%% md\n",
668 | "is_executing": true
669 | }
670 | }
671 | },
672 | {
673 | "cell_type": "markdown",
674 | "source": [
675 | "# References\n",
676 | "\n",
677 | "[Handwritten Digit Recognition Using PyTorch — Intro To Neural Networks](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627): https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627"
678 | ],
679 | "metadata": {
680 | "collapsed": false,
681 | "pycharm": {
682 | "name": "#%% md\n"
683 | }
684 | }
685 | },
686 | {
687 | "cell_type": "code",
688 | "execution_count": null,
689 | "outputs": [],
690 | "source": [],
691 | "metadata": {
692 | "collapsed": false,
693 | "pycharm": {
694 | "name": "#%%\n"
695 | }
696 | }
697 | }
698 | ],
699 | "metadata": {
700 | "kernelspec": {
701 | "display_name": "Python 3",
702 | "language": "python",
703 | "name": "python3"
704 | },
705 | "language_info": {
706 | "codemirror_mode": {
707 | "name": "ipython",
708 | "version": 2
709 | },
710 | "file_extension": ".py",
711 | "mimetype": "text/x-python",
712 | "name": "python",
713 | "nbconvert_exporter": "python",
714 | "pygments_lexer": "ipython2",
715 | "version": "2.7.6"
716 | }
717 | },
718 | "nbformat": 4,
719 | "nbformat_minor": 0
720 | }
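
Note on the evaluation cell above: it predicts a digit by taking the index of the largest model output, then divides the number of matches by the number of images tested. The following is a minimal pure-Python sketch of that bookkeeping. The helper names `argmax` and `accuracy` and the toy scores are illustrative assumptions and do not appear in the notebook.

```python
def argmax(scores):
    """Return the index of the largest score (the predicted class)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(batch_scores, labels):
    """Fraction of samples whose highest-scoring class equals the true label."""
    correct = sum(1 for scores, label in zip(batch_scores, labels)
                  if argmax(scores) == label)
    return correct / len(labels)


# Toy data (hypothetical): three samples, two classes each.
toy_scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
toy_labels = [1, 0, 0]
print(accuracy(toy_scores, toy_labels))
```

The same idea applies to a whole batch at once; the per-image loop in the notebook is simply the most explicit way to write it.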
--------------------------------------------------------------------------------
/03_cnn_image_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "# PyTorch Hands-On Introduction (3): Object Classification with a Simple CNN"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "source": [
18 | "# Topics covered in this article\n",
19 | "[Basic usage of PyTorch Vision Transforms](https://blog.csdn.net/zhaohongfei_358/article/details/122799782)\n",
20 | "\n",
21 | "\n",
22 | "[Basic usage of PyTorch nn.Linear](https://blog.csdn.net/zhaohongfei_358/article/details/122797190)\n",
24 | "\n",
25 | "[Official nn.Conv2d documentation](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)"
26 | ],
27 | "metadata": {
28 | "collapsed": false,
29 | "pycharm": {
30 | "name": "#%% md\n"
31 | }
32 | }
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "source": [
37 | "# What this article covers\n",
38 | "\n",
39 | "We train a classifier on the CIFAR10 dataset (10 classes, 32x32 images) using a simple, hand-built CNN.\n",
40 | "\n",
41 | "The environment used in this article:\n",
42 | "\n",
43 | "```\n",
44 | "torch==1.10.2\n",
45 | "```"
46 | ],
47 | "metadata": {
48 | "collapsed": false,
49 | "pycharm": {
50 | "name": "#%% md\n"
51 | }
52 | }
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "source": [
57 | "# Data preprocessing"
58 | ],
59 | "metadata": {
60 | "collapsed": false,
61 | "pycharm": {
62 | "name": "#%% md\n"
63 | }
64 | }
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "source": [
69 | "First, import the required packages:"
70 | ],
71 | "metadata": {
72 | "collapsed": false,
73 | "pycharm": {
74 | "name": "#%% md\n"
75 | }
76 | }
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 1,
81 | "outputs": [],
82 | "source": [
83 | "import torch\n",
84 | "import torchvision\n",
85 | "import torchvision.transforms as transforms\n",
86 | "import torch.nn as nn\n",
87 | "import torch.optim as optim\n",
88 | "\n",
89 | "import matplotlib.pyplot as plt\n",
90 | "import numpy as np"
91 | ],
92 | "metadata": {
93 | "collapsed": false,
94 | "pycharm": {
95 | "name": "#%%\n"
96 | }
97 | }
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "source": [
102 | "Define the transform, which specifies how the images are preprocessed, and the batch_size:"
103 | ],
104 | "metadata": {
105 | "collapsed": false,
106 | "pycharm": {
107 | "name": "#%% md\n"
108 | }
109 | }
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 2,
114 | "outputs": [],
115 | "source": [
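116 | "# If you run out of memory, reduce batch_size\n",
117 | "# In general, a larger batch_size gives more stable gradient estimates and faster epochs (but bigger is not always better)\n",
118 | "batch_size = 16\n",
119 | "\n",
120 | "transform = transforms.Compose(\n",
121 | "    [transforms.ToTensor(),  # convert the image to a Tensor with values in [0, 1]\n",
122 | "     # Normalize the image. The first argument is the mean and the second is the std (standard deviation); each has three 0.5 values, one per RGB channel.\n",
123 | "     # Each channel is mapped as (x - 0.5) / 0.5, so pixel values go from [0, 1] to [-1, 1].\n",
124 | "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])"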
125 | ],
126 | "metadata": {
127 | "collapsed": false,
128 | "pycharm": {
129 | "name": "#%%\n"
130 | }
131 | }
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "source": [
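136 | "Next, prepare the dataset. We use the official CIFAR10 dataset directly. The code below sets `download=True`, so torchvision downloads it automatically. If your connection is slow, you can instead download it from [Baidu Netdisk](https://pan.baidu.com/s/120cx2BA2hZ2nB_qaBgekdA?pwd=6aof), extract it into the `data` directory, and set `download=False`."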
137 | ],
138 | "metadata": {
139 | "collapsed": false,
140 | "pycharm": {
141 | "name": "#%% md\n"
142 | }
143 | }
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 4,
148 | "outputs": [
149 | {
150 | "name": "stdout",
151 | "output_type": "stream",
152 | "text": [
153 | "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\\cifar-10-python.tar.gz\n"
154 | ]
155 | },
156 | {
157 | "data": {
158 | "text/plain": " 0%| | 0/170498071 [00:00, ?it/s]",
159 | "application/vnd.jupyter.widget-view+json": {
160 | "version_major": 2,
161 | "version_minor": 0,
162 | "model_id": "411ce37db9e4485cb517c952f8c9abf5"
163 | }
164 | },
165 | "metadata": {},
166 | "output_type": "display_data"
167 | },
168 | {
169 | "name": "stdout",
170 | "output_type": "stream",
171 | "text": [
172 | "Extracting ./data\\cifar-10-python.tar.gz to ./data\n",
173 | "Files already downloaded and verified\n"
174 | ]
175 | }
176 | ],
177 | "source": [
178 | "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
179 | " download=True, transform=transform)\n",
180 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
181 | " shuffle=True)\n",
182 | "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
183 | " download=True, transform=transform)\n",
184 | "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
185 | " shuffle=False)\n",
186 | "\n",
187 | "# CIFAR10 has 10 classes in total\n",
188 | "classes = ('plane', 'car', 'bird', 'cat',\n",
189 | " 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
190 | ],
191 | "metadata": {
192 | "collapsed": false,
193 | "pycharm": {
194 | "name": "#%%\n"
195 | }
196 | }
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "source": [
201 | "Next, display a few samples to take a look:"
202 | ],
203 | "metadata": {
204 | "collapsed": false,
205 | "pycharm": {
206 | "name": "#%% md\n"
207 | }
208 | }
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 5,
213 | "outputs": [],
214 | "source": [
215 | "def imshow(img):\n",
216 | "    # the images were normalized earlier, so undo that before displaying\n",
217 | "    img = img / 2 + 0.5  # unnormalize\n",
218 | "    # convert to numpy.ndarray, since plt does not accept tensors\n",
219 | "    npimg = img.numpy()\n",
220 | "    # plt.imshow expects images of shape (h, w, c),\n",
221 | "    # i.e. channels last, while the incoming img is (c, h, w),\n",
222 | "    # so use transpose to reorder the axes\n",
223 | "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
224 | "    plt.show()"
225 | ],
226 | "metadata": {
227 | "collapsed": false,
228 | "pycharm": {
229 | "name": "#%%\n"
230 | }
231 | }
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 6,
236 | "outputs": [
237 | {
238 | "data": {
239 | "text/plain": "",
240 | "image/png": "\n"
241 | },
242 | "metadata": {
243 | "needs_background": "light"
244 | },
245 | "output_type": "display_data"
246 | },
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "car car plane truck horse horse dog horse car plane dog bird frog deer bird horse\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "# get a batch of images and their labels\n",
257 | "dataiter = iter(trainloader)\n",
258 | "\"\"\"\n",
259 | "images is a tensor of shape (16, 3, 32, 32),\n",
260 | "    where 16 is the batch_size (16 images), 3 is the number of RGB channels, and each image is 32x32.\n",
261 | "labels is also a tensor, of shape (16,), holding the label of each image.\n",
262 | "\"\"\"\n",
263 | "images, labels = next(dataiter)\n",
264 | "\n",
265 | "\"\"\"\n",
266 | "make_grid arranges several images into a single grid image.\n",
267 | "    nrow=8 sets the number of images per row, i.e. the number of grid columns.\n",
268 | "As the output below shows, make_grid tiles the 16 images\n",
269 | "into one large 2x8 image, which is convenient\n",
270 | "for display.\n",
271 | "\"\"\"\n",
272 | "imshow(torchvision.utils.make_grid(images, nrow=8))\n",
273 | "# print labels\n",
274 | "print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))"
275 | ],
276 | "metadata": {
277 | "collapsed": false,
278 | "pycharm": {
279 | "name": "#%%\n"
280 | }
281 | }
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "source": [
286 | "# Defining the CNN classification model"
287 | ],
288 | "metadata": {
289 | "collapsed": false,
290 | "pycharm": {
291 | "name": "#%% md\n"
292 | }
293 | }
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": 8,
298 | "outputs": [],
299 | "source": [
300 | "class Net(nn.Module):\n",
301 | "    def __init__(self):\n",
302 | "        super().__init__()\n",
303 | "        \"\"\"\n",
304 | "        Define the layers. Official nn.Conv2d documentation: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html\n",
305 | "        The most important nn.Conv2d parameters are:\n",
306 | "            in_channels: number of input channels\n",
307 | "            out_channels: number of output channels\n",
308 | "            kernel_size: size of the convolution kernel\n",
309 | "            stride: step size, 1 by default\n",
310 | "            padding: padding, 0 by default (no padding)\n",
311 | "\n",
312 | "        Note: the \"2d\" in Conv2d means the data is two-dimensional;\n",
313 | "            an image, for example, is 2D data (height x width). Likewise, Conv1d\n",
314 | "            is for 1D data such as text and signals, and Conv3d\n",
315 | "            is for 3D data such as video.\n",
316 | "        \"\"\"\n",
317 | "        self.classifier = nn.Sequential(\n",
318 | "            nn.Conv2d(3, 6, 5),\n",
319 | "            # activation function\n",
320 | "            nn.ReLU(),\n",
321 | "            # downsample with MaxPool\n",
322 | "            nn.MaxPool2d(2, 2),\n",
323 | "            nn.Conv2d(6, 16, 5),\n",
324 | "            nn.ReLU(),\n",
325 | "            nn.MaxPool2d(2, 2),\n",
326 | "            # after the convolutions, flatten the data,\n",
327 | "            # i.e. reshape the tensor from (batch_size, c, h, w) to (batch_size, c*h*w) so it can be fed to the fully connected layers\n",
328 | "            nn.Flatten(),\n",
329 | "            # finally, the fully connected layers\n",
330 | "            # for how the input size 16 * 5 * 5 is computed, see: https://blog.csdn.net/zhaohongfei_358/article/details/123269313\n",
331 | "            nn.Linear(16 * 5 * 5, 120),\n",
332 | "            nn.ReLU(),\n",
333 | "            nn.Linear(120, 84),\n",
334 | "            nn.ReLU(),\n",
335 | "            nn.Linear(84, 10)\n",
336 | "            # note that Softmax is not applied here, and should not be:\n",
337 | "            # CrossEntropyLoss already applies (log-)Softmax internally,\n",
338 | "            # so adding one here would apply it twice and hurt training\n",
339 | "        )\n",
340 | "\n",
341 | "    def forward(self, x):\n",
342 | "        return self.classifier(x)"
343 | ],
344 | "metadata": {
345 | "collapsed": false,
346 | "pycharm": {
347 | "name": "#%%\n"
348 | }
349 | }
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 9,
354 | "outputs": [],
355 | "source": [
356 | "net = Net()\n",
357 | "# use CrossEntropyLoss as the loss function; it is the usual choice for multi-class classification\n",
358 | "criterion = nn.CrossEntropyLoss()\n",
359 | "# use plain SGD (with momentum) as the optimizer\n",
360 | "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
361 | ],
362 | "metadata": {
363 | "collapsed": false,
364 | "pycharm": {
365 | "name": "#%%\n"
366 | }
367 | }
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "source": [
372 | "# Training the network"
373 | ],
374 | "metadata": {
375 | "collapsed": false,
376 | "pycharm": {
377 | "name": "#%% md\n"
378 | }
379 | }
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "source": [
384 | "Start training. Since the network is small, we train directly on the CPU."
385 | ],
386 | "metadata": {
387 | "collapsed": false,
388 | "pycharm": {
389 | "name": "#%% md\n"
390 | }
391 | }
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": 10,
396 | "outputs": [
397 | {
398 | "name": "stdout",
399 | "output_type": "stream",
400 | "text": [
401 | "[1, 2000] loss: 2.113\n",
402 | "[2, 2000] loss: 1.567\n",
403 | "[3, 2000] loss: 1.406\n",
404 | "[4, 2000] loss: 1.280\n",
405 | "[5, 2000] loss: 1.194\n",
406 | "[6, 2000] loss: 1.121\n",
407 | "[7, 2000] loss: 1.063\n",
408 | "[8, 2000] loss: 1.017\n",
409 | "[9, 2000] loss: 0.966\n",
410 | "[10, 2000] loss: 0.924\n",
411 | "[11, 2000] loss: 0.885\n",
412 | "[12, 2000] loss: 0.848\n",
413 | "[13, 2000] loss: 0.825\n",
414 | "[14, 2000] loss: 0.787\n",
415 | "[15, 2000] loss: 0.758\n",
416 | "[16, 2000] loss: 0.735\n",
417 | "[17, 2000] loss: 0.698\n",
418 | "[18, 2000] loss: 0.679\n",
419 | "[19, 2000] loss: 0.655\n",
420 | "[20, 2000] loss: 0.630\n",
421 | "Finished Training\n"
422 | ]
423 | }
424 | ],
425 | "source": [
426 | "# one pass over all training samples is called one epoch\n",
427 | "# for simplicity, train for only 20 epochs here\n",
428 | "epochs = 20\n",
429 | "\n",
430 | "for epoch in range(epochs):\n",
431 | "    # track the running loss\n",
432 | "    running_loss = 0.0\n",
433 | "    for i, data in enumerate(trainloader):\n",
434 | "        # trainloader yields a tuple: the first item is the batch of images, the second the corresponding labels\n",
435 | "        inputs, labels = data\n",
436 | "\n",
437 | "        # clear the gradients from the previous step\n",
438 | "        optimizer.zero_grad()\n",
439 | "\n",
440 | "        # forward pass\n",
441 | "        outputs = net(inputs)\n",
442 | "        # compute the loss\n",
443 | "        loss = criterion(outputs, labels)\n",
444 | "        # backpropagate\n",
445 | "        loss.backward()\n",
446 | "        # update the parameters\n",
447 | "        optimizer.step()\n",
448 | "\n",
449 | "        # accumulate the loss and print the average every 2000 batches\n",
450 | "        running_loss += loss.item()\n",
451 | "        if i % 2000 == 1999:\n",
452 | "            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
453 | "            running_loss = 0.0\n",
454 | "\n",
455 | "print('Finished Training')"
456 | ],
457 | "metadata": {
458 | "collapsed": false,
459 | "pycharm": {
460 | "name": "#%%\n"
461 | }
462 | }
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "source": [
467 | "After 20 epochs the loss has dropped to 0.630. The loss is still decreasing, so you can try training for longer."
468 | ],
469 | "metadata": {
470 | "collapsed": false,
471 | "pycharm": {
472 | "name": "#%% md\n"
473 | }
474 | }
475 | },
476 | {
477 | "cell_type": "markdown",
478 | "source": [
479 | "# Testing the model"
480 | ],
481 | "metadata": {
482 | "collapsed": false,
483 | "pycharm": {
484 | "name": "#%% md\n"
485 | }
486 | }
487 | },
488 | {
489 | "cell_type": "markdown",
490 | "source": [
491 | "Measure the model's overall accuracy:"
492 | ],
493 | "metadata": {
494 | "collapsed": false,
495 | "pycharm": {
496 | "name": "#%% md\n"
497 | }
498 | }
510 | },
511 | {
512 | "cell_type": "code",
513 | "execution_count": 12,
514 | "outputs": [
515 | {
516 | "name": "stdout",
517 | "output_type": "stream",
518 | "text": [
519 | "Accuracy of the network on the 10000 test images: 64 %\n"
520 | ]
521 | }
522 | ],
523 | "source": [
524 | "correct = 0  # number of correct predictions\n",
525 | "total = 0  # total number of samples\n",
526 | "\n",
527 | "# torch.no_grad disables gradient computation, which is not needed for evaluation\n",
528 | "with torch.no_grad():\n",
529 | "    for data in testloader:\n",
530 | "        images, labels = data\n",
531 | "        # forward pass\n",
532 | "        outputs = net(images)\n",
533 | "        \"\"\"\n",
534 | "        outputs.shape is (16, 10): batch_size 16 and 10 classes.\n",
535 | "        outputs holds the per-class scores of these 16 images (not passed through Softmax),\n",
536 | "        so we use torch.max to find the largest one.\n",
537 | "        torch.max takes two arguments: the tensor and dim (the dimension).\n",
538 | "            Passing 1 takes the maximum along the class dimension.\n",
539 | "        torch.max returns two outputs, values and indices:\n",
540 | "            values are the maximum values themselves,\n",
541 | "            indices are the positions of those maxima.\n",
542 | "        We only need the indices here, so the first return value is ignored.\n",
543 | "        \"\"\"\n",
544 | "        _, predicted = torch.max(outputs, 1)\n",
545 | "        # update the total count\n",
546 | "        total += labels.size(0)\n",
547 | "        # update the number of correct predictions\n",
548 | "        correct += (predicted == labels).sum().item()\n",
549 | "\n",
550 | "print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')"
551 | ],
552 | "metadata": {
553 | "collapsed": false,
554 | "pycharm": {
555 | "name": "#%%\n"
556 | }
557 | }
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "source": [
562 | "Next, compute the accuracy for each class:"
563 | ],
564 | "metadata": {
565 | "collapsed": false,
566 | "pycharm": {
567 | "name": "#%% md\n"
568 | }
569 | }
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": 13,
574 | "outputs": [
575 | {
576 | "name": "stdout",
577 | "output_type": "stream",
578 | "text": [
579 | "Accuracy for class: plane is 80.0 %\n",
580 | "Accuracy for class: car is 70.5 %\n",
581 | "Accuracy for class: bird is 53.7 %\n",
582 | "Accuracy for class: cat is 46.3 %\n",
583 | "Accuracy for class: deer is 55.6 %\n",
584 | "Accuracy for class: dog is 50.8 %\n",
585 | "Accuracy for class: frog is 77.8 %\n",
586 | "Accuracy for class: horse is 67.4 %\n",
587 | "Accuracy for class: ship is 73.6 %\n",
588 | "Accuracy for class: truck is 70.9 %\n"
589 | ]
590 | }
591 | ],
592 | "source": [
593 | "# count the correct and total predictions for each class\n",
594 | "correct_pred = {classname: 0 for classname in classes}\n",
595 | "total_pred = {classname: 0 for classname in classes}\n",
596 | "\n",
597 | "# again no gradients needed\n",
598 | "with torch.no_grad():\n",
599 | "    for data in testloader:\n",
600 | "        images, labels = data\n",
601 | "        outputs = net(images)\n",
602 | "        _, predictions = torch.max(outputs, 1)\n",
603 | "        # collect the correct predictions for each class\n",
604 | "        for label, prediction in zip(labels, predictions):\n",
605 | "            if label == prediction:\n",
606 | "                correct_pred[classes[label]] += 1\n",
607 | "            total_pred[classes[label]] += 1\n",
608 | "\n",
609 | "\n",
610 | "# print accuracy for each class\n",
611 | "for classname, correct_count in correct_pred.items():\n",
612 | "    accuracy = 100 * float(correct_count) / total_pred[classname]\n",
613 | "    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')"
614 | ],
615 | "metadata": {
616 | "collapsed": false,
617 | "pycharm": {
618 | "name": "#%%\n"
619 | }
620 | }
621 | },
622 | {
623 | "cell_type": "markdown",
624 | "source": [
625 | "# References\n",
626 | "\n",
627 | "[Official PyTorch CNN classification tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html): https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html"
628 | ],
629 | "metadata": {
630 | "collapsed": false,
631 | "pycharm": {
632 | "name": "#%% md\n"
633 | }
634 | }
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "outputs": [],
640 | "source": [],
641 | "metadata": {
642 | "collapsed": false,
643 | "pycharm": {
644 | "name": "#%%\n"
645 | }
646 | }
647 | }
648 | ],
649 | "metadata": {
650 | "kernelspec": {
651 | "display_name": "Python 3",
652 | "language": "python",
653 | "name": "python3"
654 | },
655 | "language_info": {
656 | "codemirror_mode": {
657 | "name": "ipython",
658 | "version": 2
659 | },
660 | "file_extension": ".py",
661 | "mimetype": "text/x-python",
662 | "name": "python",
663 | "nbconvert_exporter": "python",
664 | "pygments_lexer": "ipython2",
665 | "version": "2.7.6"
666 | }
667 | },
668 | "nbformat": 4,
669 | "nbformat_minor": 0
670 | }
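
Note on the model above: the first fully connected layer takes `16 * 5 * 5` inputs. That number follows from the standard output-size formula for convolution and pooling layers, `out = floor((in + 2 * padding - kernel) / stride) + 1`. The sketch below walks a 32x32 CIFAR10 image through the layers of `Net` in plain Python. The helper names `conv_out` and `flattened_features` are illustrative assumptions and are not part of the notebook or of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (square input, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(in_size=32):
    """Number of features reaching the first nn.Linear layer of Net."""
    size = conv_out(in_size, kernel=5)          # nn.Conv2d(3, 6, 5):   32 -> 28
    size = conv_out(size, kernel=2, stride=2)   # nn.MaxPool2d(2, 2):   28 -> 14
    size = conv_out(size, kernel=5)             # nn.Conv2d(6, 16, 5):  14 -> 10
    size = conv_out(size, kernel=2, stride=2)   # nn.MaxPool2d(2, 2):   10 -> 5
    channels = 16                               # out_channels of the last conv layer
    return channels * size * size


print(flattened_features())  # 16 * 5 * 5 = 400
```

If the input resolution or any kernel size changes, recomputing this value tells you what to pass as the first argument of the first `nn.Linear`.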
--------------------------------------------------------------------------------
/04_LSTM_sentiment_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | },
10 | "id": "mMulrrXe6_0i"
11 | },
12 | "source": [
13 | "# PyTorch Hands-On Introduction (4): Sentiment Analysis of Text with an LSTM"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "source": [
19 | "# Topics covered in this article\n",
20 | "\n",
21 | "[Basic usage of PyTorch nn.Module](https://blog.csdn.net/zhaohongfei_358/article/details/122797244)\n",
22 | "\n",
23 | "[Basic usage of PyTorch nn.Linear](https://blog.csdn.net/zhaohongfei_358/article/details/122797190)\n",
24 | "\n",
25 | "[Basic usage of DataLoader in PyTorch](https://blog.csdn.net/zhaohongfei_358/article/details/122742656)\n",
26 | "\n",
27 | "[Basic usage of PyTorch nn.Embedding](https://blog.csdn.net/zhaohongfei_358/article/details/122809709)\n",
28 | "\n",
29 | "[torch.nn.utils.clip_grad_norm_ explained: usage and principles](https://blog.csdn.net/zhaohongfei_358/article/details/122820992)"
30 | ],
31 | "metadata": {
32 | "collapsed": false,
33 | "pycharm": {
34 | "name": "#%% md\n"
35 | },
36 | "id": "TyKcUptm6_0k"
37 | }
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "source": [
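42 | "# What this article covers\n",
43 | "\n",
44 | "This article is based on the code from [Long Short-Term Memory: From Zero to Hero with PyTorch](https://blog.floydhub.com/long-short-term-memory-from-zero-to-hero-with-pytorch/), with some modifications and added comments. That article explains LSTMs in detail; if you are not familiar with LSTMs, read it first.\n",
45 | "\n",
46 | "We use the Amazon reviews dataset to train a classifier that predicts the sentiment of a piece of text.\n",
47 | "\n",
48 | "The dataset is available here:\n",
49 | "\n",
50 | "```\n",
51 | "Link: https://pan.baidu.com/s/1cK-scxLIliTsOPF-6byucQ\n",
52 | "Extraction code: yqbq\n",
53 | "```"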
54 | ],
55 | "metadata": {
56 | "collapsed": false,
57 | "pycharm": {
58 | "name": "#%% md\n"
59 | },
60 | "id": "YTCs253y6_0l"
61 | }
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "source": [
66 | "# Data preprocessing\n",
67 | "\n",
68 | "First, import the packages we need:"
69 | ],
70 | "metadata": {
71 | "collapsed": false,
72 | "pycharm": {
73 | "name": "#%% md\n"
74 | },
75 | "id": "rv0IV3qr6_0l"
76 | }
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 1,
81 | "outputs": [],
82 | "source": [
83 | "import bz2  # for reading bz2-compressed files\n",
84 | "from collections import Counter  # for counting word frequencies\n",
85 | "import re  # regular expressions\n",
86 | "import nltk  # text preprocessing\n",
87 | "import numpy as np"
88 | ],
89 | "metadata": {
90 | "pycharm": {
91 | "name": "#%%\n"
92 | },
93 | "id": "dcJoLvsE6_0m"
94 | }
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "source": [
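99 | "Extract the dataset archive. It contains two files: `train.ft.txt.bz2` and `test.ft.txt.bz2`. The cells below download the archive and unzip it into the current working directory, which is where the loading code expects the files.\n",
100 | "\n",
101 | "After extraction, read the training and test data:"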
102 | ],
103 | "metadata": {
104 | "collapsed": false,
105 | "pycharm": {
106 | "name": "#%% md\n"
107 | },
108 | "id": "RhSZFZBQ6_0m"
109 | }
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 2,
114 | "outputs": [
115 | {
116 | "output_type": "stream",
117 | "name": "stdout",
118 | "text": [
119 | "/usr/local/lib/python3.7/dist-packages/gdown/cli.py:131: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
120 | " category=FutureWarning,\n",
121 | "Downloading...\n",
122 | "From: https://drive.google.com/uc?id=1BKma83La6Cx3m1rWmnQScHjTWj-z65NP\n",
123 | "To: /content/data.zip\n",
124 | "100% 517M/517M [00:05<00:00, 90.6MB/s]\n"
125 | ]
126 | }
127 | ],
128 | "source": [
129 | "!gdown --id '1BKma83La6Cx3m1rWmnQScHjTWj-z65NP' --output data.zip"
130 | ],
131 | "metadata": {
132 | "pycharm": {
133 | "name": "#%%\n"
134 | },
135 | "id": "87YmVQWf6_0n",
136 | "outputId": "d07d229e-4760-48cb-a97f-b126d80eaf96",
137 | "colab": {
138 | "base_uri": "https://localhost:8080/"
139 | }
140 | }
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 3,
145 | "outputs": [
146 | {
147 | "output_type": "stream",
148 | "name": "stdout",
149 | "text": [
150 | "Archive: data.zip\n",
151 | " inflating: test.ft.txt.bz2 \n",
152 | " inflating: train.ft.txt.bz2 \n"
153 | ]
154 | }
155 | ],
156 | "source": [
157 | "!unzip data.zip"
158 | ],
159 | "metadata": {
160 | "pycharm": {
161 | "name": "#%%\n"
162 | },
163 | "id": "l6pid1tY6_0o",
164 | "outputId": "4e23b4b6-1e5f-44d8-b09e-f9a10cf4af6e",
165 | "colab": {
166 | "base_uri": "https://localhost:8080/"
167 | }
168 | }
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 5,
173 | "outputs": [
174 | {
175 | "output_type": "stream",
176 | "name": "stdout",
177 | "text": [
178 | "b'__label__2 Stuning even for the non-gamer: This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. game music! I have played the game Chrono Cross but out of all of the games I have ever played it has the best music! It backs away from crude keyboarding and takes a fresher step with grate guitars and soulful orchestras. It would impress anyone who cares to listen! ^_^\\n'\n"
179 | ]
180 | }
181 | ],
182 | "source": [
183 | "train_file = bz2.BZ2File('train.ft.txt.bz2')\n",
184 | "test_file = bz2.BZ2File('test.ft.txt.bz2')\n",
185 | "train_file = train_file.readlines()\n",
186 | "test_file = test_file.readlines()\n",
187 | "print(train_file[0])"
188 | ],
189 | "metadata": {
190 | "pycharm": {
191 | "name": "#%%\n"
192 | },
193 | "id": "SAsgBBHj6_0o",
194 | "outputId": "359aba04-75cc-4cb9-dec2-cf8d9cd059d9",
195 | "colab": {
196 | "base_uri": "https://localhost:8080/"
197 | }
198 | }
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "source": [
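203 | "As the printed sample above shows, each record consists of two parts, a *Label* and the *Data*, where:\n",
204 | "\n",
205 | "- `__label__1` marks a negative review, which we later encode as 0\n",
206 | "- `__label__2` marks a positive review, which we later encode as 1\n",
207 | "\n",
208 | "Because the dataset is very large, we use only 1,000,000 records, split *8:2*: 800,000 from the training file and 200,000 from the test file:"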
209 | ],
210 | "metadata": {
211 | "collapsed": false,
212 | "pycharm": {
213 | "name": "#%% md\n"
214 | },
215 | "id": "_BxmKkli6_0o"
216 | }
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 6,
221 | "outputs": [],
222 | "source": [
223 | "num_train = 800000\n",
224 | "num_test = 200000\n",
225 | "\n",
226 | "train_file = [x.decode('utf-8') for x in train_file[:num_train]]\n",
227 | "test_file = [x.decode('utf-8') for x in test_file[:num_test]]"
228 | ],
229 | "metadata": {
230 | "pycharm": {
231 | "name": "#%%\n"
232 | },
233 | "id": "uL3MhKtc6_0p"
234 | }
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "source": [
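239 | "> `decode('utf-8')` is needed because the lines are read as bytes, as the `b''` prefix in the output above shows\n",
240 | "\n",
241 | "In the source file the label and the text are stored on the same line, so we split them apart:"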
242 | ],
243 | "metadata": {
244 | "collapsed": false,
245 | "pycharm": {
246 | "name": "#%% md\n"
247 | },
248 | "id": "4utoGSoz6_0p"
249 | }
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 7,
254 | "outputs": [],
255 | "source": [
256 | "# encode __label__1 as 0 (negative review) and __label__2 as 1 (positive review)\n",
257 | "train_labels = [0 if x.split(' ')[0] == '__label__1' else 1 for x in train_file]\n",
258 | "test_labels = [0 if x.split(' ')[0] == '__label__1' else 1 for x in test_file]\n",
259 | "\n",
260 | "\"\"\"\n",
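261 | "`split(' ', 1)[1]`: split the label from the text and keep the text part\n",
262 | "`[:-1]`: drop the last character (the trailing newline)\n",
263 | "`lower()`: convert to lowercase, since case matters little for sentiment and would only complicate the encoding\n",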
264 | "\"\"\"\n",
265 | "train_sentences = [x.split(' ', 1)[1][:-1].lower() for x in train_file]\n",
266 | "test_sentences = [x.split(' ', 1)[1][:-1].lower() for x in test_file]"
267 | ],
268 | "metadata": {
269 | "pycharm": {
270 | "name": "#%%\n"
271 | },
272 | "id": "2JlI_ZXE6_0p"
273 | }
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "source": [
278 | "After splitting, apply some simple data cleaning.\n",
279 | "\n",
280 | "Specific numbers contribute little to sentiment classification, so every digit is replaced with 0:"
281 | ],
282 | "metadata": {
283 | "collapsed": false,
284 | "pycharm": {
285 | "name": "#%% md\n"
286 | },
287 | "id": "Nk04Vc-y6_0p"
288 | }
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 8,
293 | "outputs": [],
294 | "source": [
295 | "for i in range(len(train_sentences)):\n",
296 | "    train_sentences[i] = re.sub(r'\\d', '0', train_sentences[i])\n",
297 | "\n",
298 | "for i in range(len(test_sentences)):\n",
299 | "    test_sentences[i] = re.sub(r'\\d', '0', test_sentences[i])"
300 | ],
301 | "metadata": {
302 | "pycharm": {
303 | "name": "#%%\n"
304 | },
305 | "id": "07AVnZC16_0p"
306 | }
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "source": [
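311 | "Some samples in the dataset contain website addresses, for example: `Welcome to our website: www.pohabo.com`. The address would interfere with later processing, so such samples are rewritten uniformly as: `Welcome to our website: `:"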
312 | ],
313 | "metadata": {
314 | "collapsed": false,
315 | "pycharm": {
316 | "name": "#%% md\n"
317 | },
318 | "id": "abM6rblh6_0q"
319 | }
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 9,
324 | "outputs": [],
325 | "source": [
326 | "for i in range(len(train_sentences)):\n",
327 | "    if 'www.' in train_sentences[i] or 'http:' in train_sentences[i] or 'https:' in train_sentences[i] or '.com' in train_sentences[i]:\n",
328 | "        train_sentences[i] = re.sub(r\"([^ ]+(?<=\\.[a-z]{3}))\", \"\", train_sentences[i])\n",
329 | "\n",
330 | "for i in range(len(test_sentences)):\n",
331 | "    if 'www.' in test_sentences[i] or 'http:' in test_sentences[i] or 'https:' in test_sentences[i] or '.com' in test_sentences[i]:\n",
332 | "        test_sentences[i] = re.sub(r\"([^ ]+(?<=\\.[a-z]{3}))\", \"\", test_sentences[i])"
333 | ],
334 | "metadata": {
335 | "pycharm": {
336 | "name": "#%%\n"
337 | },
338 | "id": "oiHNzDE56_0q"
339 | }
340 | },
341 | {
342 | "cell_type": "markdown",
343 | "source": [
344 | "After cleaning, we **tokenize the text** and **discard words that appear only once**, since they carry little information:"
345 | ],
346 | "metadata": {
347 | "collapsed": false,
348 | "pycharm": {
349 | "name": "#%% md\n"
350 | },
351 | "id": "pVoDvGR16_0q"
352 | }
353 | },
354 | {
355 | "cell_type": "code",
356 | "source": [
357 | "nltk.download('punkt')  # `punkt` must be downloaded before using nltk.word_tokenize"
358 | ],
359 | "metadata": {
360 | "id": "0SerO5AE7-pw",
361 | "outputId": "6337ca5e-754a-410a-bdee-8bbda9f3e0b5",
362 | "colab": {
363 | "base_uri": "https://localhost:8080/"
364 | }
365 | },
366 | "execution_count": 12,
367 | "outputs": [
368 | {
369 | "output_type": "stream",
370 | "name": "stderr",
371 | "text": [
372 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
373 | "[nltk_data] Unzipping tokenizers/punkt.zip.\n"
374 | ]
375 | },
376 | {
377 | "output_type": "execute_result",
378 | "data": {
379 | "text/plain": [
380 | "True"
381 | ]
382 | },
383 | "metadata": {},
384 | "execution_count": 12
385 | }
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 13,
391 | "outputs": [
392 | {
393 | "output_type": "stream",
394 | "name": "stdout",
395 | "text": [
396 | "0.0% done\n",
397 | "25.0% done\n",
398 | "50.0% done\n",
399 | "75.0% done\n",
400 | "100% done\n"
401 | ]
402 | }
403 | ],
404 | "source": [
405 | "words = Counter()  # counts how many times each word appears\n",
406 | "for i, sentence in enumerate(train_sentences):\n",
407 | "    words_list = nltk.word_tokenize(sentence)  # tokenize the sentence\n",
408 | "    words.update(words_list)  # update the word counts\n",
409 | "    train_sentences[i] = words_list  # store the token list back in place of the sentence\n",
410 | "\n",
411 | "    if i % 20000 == 0:  # print progress every 20,000 sentences\n",
412 | "        print(str((i*100)/num_train) + \"% done\")\n",
413 | "print(\"100% done\")"
414 | ],
415 | "metadata": {
416 | "pycharm": {
417 | "name": "#%%\n"
418 | },
419 | "id": "uUIDMkDt6_0q",
420 | "outputId": "b3db9b70-b558-4f14-9278-e8ef4f6a3872",
421 | "colab": {
422 | "base_uri": "https://localhost:8080/"
423 | }
424 | }
425 | },
426 | {
427 | "cell_type": "markdown",
428 | "source": [
429 | "Remove the words that appear only once:"
430 | ],
431 | "metadata": {
432 | "collapsed": false,
433 | "pycharm": {
434 | "name": "#%% md\n"
435 | },
436 | "id": "iWS8bm1g6_0q"
437 | }
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 14,
442 | "outputs": [],
443 | "source": [
444 | "words = {k:v for k,v in words.items() if v>1}"
445 | ],
446 | "metadata": {
447 | "pycharm": {
448 | "name": "#%%\n"
449 | },
450 | "id": "AeW4L89X6_0r"
451 | }
452 | },
453 | {
454 | "cell_type": "markdown",
455 | "source": [
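456 | "Sort the words by frequency in descending order and convert them to a list, **which serves as our vocabulary**; **the word encoding below is based on this vocabulary**:"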
457 | ],
458 | "metadata": {
459 | "collapsed": false,
460 | "pycharm": {
461 | "name": "#%% md\n"
462 | },
463 | "id": "F_twWbLr6_0r"
464 | }
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": 15,
469 | "outputs": [
470 | {
471 | "output_type": "stream",
472 | "name": "stdout",
473 | "text": [
474 | "['.', 'the', ',', 'i', 'and', 'a', 'to', 'it', 'of', 'this']\n"
475 | ]
476 | }
477 | ],
478 | "source": [
479 | "words = sorted(words, key=words.get,reverse=True)\n",
480 | "print(words[:10])  # print the 10 most frequent words"
481 | ],
482 | "metadata": {
483 | "pycharm": {
484 | "name": "#%%\n"
485 | },
486 | "id": "7xGQFepO6_0r",
487 | "outputId": "2e1e1eb0-9f5d-41c4-a2ef-784e9d849b30",
488 | "colab": {
489 | "base_uri": "https://localhost:8080/"
490 | }
491 | }
492 | },
493 | {
494 | "cell_type": "markdown",
495 | "source": [
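496 | "Add one special token to the vocabulary:\n",
497 | "\n",
498 | "- `_PAD`: the padding token. All sentences are later fixed to the same length: sentences that are too long are truncated, and sentences that are too short are padded with this token."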
499 | ],
500 | "metadata": {
501 | "collapsed": false,
502 | "pycharm": {
503 | "name": "#%% md\n"
504 | },
505 | "id": "R-kILnqA6_0r"
506 | }
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 16,
511 | "outputs": [],
512 | "source": [
513 | "words = ['_PAD'] + words"
514 | ],
515 | "metadata": {
516 | "pycharm": {
517 | "name": "#%%\n"
518 | },
519 | "id": "mkBOi5Np6_0r"
520 | }
521 | },
522 | {
523 | "cell_type": "markdown",
524 | "source": [
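525 | "With the vocabulary ready, **encode the words**, i.e. **map each word to a number**. Here each word's index in the vocabulary list is used directly as its code:"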
526 | ],
527 | "metadata": {
528 | "collapsed": false,
529 | "pycharm": {
530 | "name": "#%% md\n"
531 | },
532 | "id": "n35FzP_M6_0r"
533 | }
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 17,
538 | "outputs": [],
539 | "source": [
540 | "word2idx = {o:i for i,o in enumerate(words)}\n",
541 | "idx2word = {i:o for i,o in enumerate(words)}"
542 | ],
543 | "metadata": {
544 | "pycharm": {
545 | "name": "#%%\n"
546 | },
547 | "id": "NBWVkmM_6_0r"
548 | }
549 | },
550 | {
551 | "cell_type": "markdown",
552 | "source": [
553 | "With the mapping dictionaries ready, the words stored in `train_sentences` can be converted to numbers:"
554 | ],
555 | "metadata": {
556 | "collapsed": false,
557 | "pycharm": {
558 | "name": "#%% md\n"
559 | },
560 | "id": "FieIwn2u6_0r"
561 | }
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 18,
566 | "outputs": [],
567 | "source": [
568 | "for i, sentence in enumerate(train_sentences):\n",
569 | "    train_sentences[i] = [word2idx[word] if word in word2idx else 0 for word in sentence]\n",
570 | "\n",
571 | "for i, sentence in enumerate(test_sentences):\n",
572 | "    test_sentences[i] = [word2idx[word.lower()] if word.lower() in word2idx else 0 for word in nltk.word_tokenize(sentence)]"
573 | ],
574 | "metadata": {
575 | "pycharm": {
576 | "name": "#%%\n"
577 | },
578 | "id": "mOCna9Q56_0r"
579 | }
580 | },
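The vocabulary and encoding steps above can be sketched end to end on a toy corpus. This is a minimal illustration, not the notebook's pipeline: `toy_sentences` is made-up data, and plain `str.split()` stands in for `nltk.word_tokenize` so that the sketch has no dependencies.

```python
from collections import Counter

# Toy corpus (assumption): stands in for the real training sentences.
toy_sentences = ["the film is great", "the film is boring", "great acting"]
# Assumption: whitespace splitting replaces nltk.word_tokenize.
tokenized = [s.split() for s in toy_sentences]

# 1. Count word frequencies.
counts = Counter()
for tokens in tokenized:
    counts.update(tokens)

# 2. Drop words that appear only once.
counts = {w: c for w, c in counts.items() if c > 1}

# 3. Sort by frequency (descending) and reserve index 0 for _PAD.
vocab = ["_PAD"] + sorted(counts, key=counts.get, reverse=True)

# 4. Build the lookup tables and encode; unknown words fall back to 0.
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for i, w in enumerate(vocab)}
encoded = [[word2idx.get(w, 0) for w in tokens] for tokens in tokenized]

print(vocab)    # ['_PAD', 'the', 'film', 'is', 'great']
print(encoded)  # [[1, 2, 3, 4], [1, 2, 3, 0], [4, 0]]
```

Note that "boring" and "acting" occur once, are dropped from the vocabulary, and therefore share code 0 with `_PAD`.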
581 | {
582 | "cell_type": "markdown",
583 | "source": [
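584 | "> The `else 0` above means: if a word is not in the vocabulary, it is encoded as 0, the same index as `_PAD`.\n",
585 | "\n",
586 | "To simplify building the model, all sentences are fixed to the same length, 200 here. Sentences that are too short are padded with `0` (`_PAD`) at the front, and sentences that are too long are truncated at the end:"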
587 | ],
588 | "metadata": {
589 | "collapsed": false,
590 | "pycharm": {
591 | "name": "#%% md\n"
592 | },
593 | "id": "G95_HFbD6_0s"
594 | }
595 | },
596 | {
597 | "cell_type": "code",
598 | "execution_count": 19,
599 | "outputs": [],
600 | "source": [
601 | "def pad_input(sentences, seq_len):\n",
602 | "    \"\"\"\n",
603 | "    Fix each sentence to length `seq_len`: truncate long ones at the end, left-pad short ones with 0.\n",
604 | "    \"\"\"\n",
605 | "    features = np.zeros((len(sentences), seq_len), dtype=int)\n",
606 | "    for ii, review in enumerate(sentences):\n",
607 | "        if len(review) != 0:\n",
608 | "            features[ii, -len(review):] = np.array(review)[:seq_len]\n",
609 | "    return features\n",
610 | "\n",
611 | "# Fix the sentence length of both the training and test sets\n",
612 | "train_sentences = pad_input(train_sentences, 200)\n",
613 | "test_sentences = pad_input(test_sentences, 200)"
614 | ],
615 | "metadata": {
616 | "pycharm": {
617 | "name": "#%%\n"
618 | },
619 | "id": "chNfmNcv6_0s"
620 | }
621 | },
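To make the padding rule concrete, here is a dependency-free sketch of the same behavior on plain lists. `pad_left` is a hypothetical helper written for this illustration. It is not the notebook's `pad_input`, though it is meant to produce the same rows.

```python
def pad_left(encoded, seq_len, pad_id=0):
    """Left-pad short sequences with pad_id; keep the first seq_len items of long ones."""
    fixed = []
    for seq in encoded:
        kept = seq[:seq_len]                      # truncate at the end
        padding = [pad_id] * (seq_len - len(kept))
        fixed.append(padding + kept)              # pad at the front
    return fixed

print(pad_left([[5, 6], [1, 2, 3, 4, 5, 6], []], seq_len=4))
# [[0, 0, 5, 6], [1, 2, 3, 4], [0, 0, 0, 0]]
```

Padding at the front keeps the real words next to the final time step, which is the step whose output the model reads.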
622 | {
623 | "cell_type": "markdown",
624 | "source": [
625 | "Besides fixing the length, the function above also converts the data to NumPy arrays. The labels need the same conversion:"
626 | ],
627 | "metadata": {
628 | "collapsed": false,
629 | "pycharm": {
630 | "name": "#%% md\n"
631 | },
632 | "id": "nqRUKb6s6_0s"
633 | }
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 20,
638 | "outputs": [],
639 | "source": [
640 | "train_labels = np.array(train_labels)\n",
641 | "test_labels = np.array(test_labels)"
642 | ],
643 | "metadata": {
644 | "pycharm": {
645 | "name": "#%%\n"
646 | },
647 | "id": "01OAgkzX6_0s"
648 | }
649 | },
650 | {
651 | "cell_type": "markdown",
652 | "source": [
653 | "Data preprocessing is now essentially done. Next, it is PyTorch's turn."
654 | ],
655 | "metadata": {
656 | "collapsed": false,
657 | "pycharm": {
658 | "name": "#%% md\n"
659 | },
660 | "id": "-SKNTwas6_0s"
661 | }
662 | },
663 | {
664 | "cell_type": "markdown",
665 | "source": [
666 | "# Building the Model\n",
667 | "\n",
668 | "First, import the PyTorch packages we need:"
669 | ],
670 | "metadata": {
671 | "collapsed": false,
672 | "pycharm": {
673 | "name": "#%% md\n"
674 | },
675 | "id": "bJWix0sq6_0s"
676 | }
677 | },
678 | {
679 | "cell_type": "code",
680 | "execution_count": 21,
681 | "outputs": [],
682 | "source": [
683 | "import torch\n",
684 | "from torch.utils.data import TensorDataset, DataLoader\n",
685 | "import torch.nn as nn"
686 | ],
687 | "metadata": {
688 | "pycharm": {
689 | "name": "#%%\n"
690 | },
691 | "id": "CSdCaxjs6_0s"
692 | }
693 | },
694 | {
695 | "cell_type": "markdown",
696 | "source": [
697 | "Build the DataLoaders for the training and test sets, and **set the batch size to 200**:"
698 | ],
699 | "metadata": {
700 | "collapsed": false,
701 | "pycharm": {
702 | "name": "#%% md\n"
703 | },
704 | "id": "ngzQUMlL6_0s"
705 | }
706 | },
707 | {
708 | "cell_type": "code",
709 | "execution_count": 22,
710 | "outputs": [],
711 | "source": [
712 | "batch_size = 200\n",
713 | "\n",
714 | "train_data = TensorDataset(torch.from_numpy(train_sentences), torch.from_numpy(train_labels))\n",
715 | "test_data = TensorDataset(torch.from_numpy(test_sentences), torch.from_numpy(test_labels))\n",
716 | "\n",
717 | "train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)\n",
718 | "test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size)"
719 | ],
720 | "metadata": {
721 | "pycharm": {
722 | "name": "#%%\n"
723 | },
724 | "id": "rB8exua36_0s"
725 | }
726 | },
727 | {
728 | "cell_type": "markdown",
729 | "source": [
730 | "If a GPU is available, use it to speed up computation:"
731 | ],
732 | "metadata": {
733 | "collapsed": false,
734 | "pycharm": {
735 | "name": "#%% md\n"
736 | },
737 | "id": "vR-y2-by6_0s"
738 | }
739 | },
740 | {
741 | "cell_type": "code",
742 | "execution_count": 23,
743 | "outputs": [],
744 | "source": [
745 | "device = torch.device('cuda') if torch.cuda.is_available() else torch.device(\"cpu\")"
746 | ],
747 | "metadata": {
748 | "pycharm": {
749 | "name": "#%%\n"
750 | },
751 | "id": "-ltzuAyQ6_0s"
752 | }
753 | },
754 | {
755 | "cell_type": "code",
756 | "execution_count": 24,
757 | "outputs": [],
758 | "source": [
759 | "class SentimentNet(nn.Module):\n",
760 | "    def __init__(self, vocab_size):\n",
761 | "        super(SentimentNet, self).__init__()\n",
762 | "        self.n_layers = n_layers = 2  # Number of LSTM layers\n",
763 | "        self.hidden_dim = hidden_dim = 512  # Hidden-state size: the LSTM outputs 512-dimensional hidden states\n",
764 | "        embedding_dim = 400  # Each word is embedded as a 400-dimensional vector\n",
765 | "        drop_prob = 0.5  # Dropout probability\n",
766 | "\n",
767 | "        # Embedding layer: maps word indices to vectors. See: https://blog.csdn.net/zhaohongfei_358/article/details/122809709\n",
768 | "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
769 | "\n",
770 | "        self.lstm = nn.LSTM(embedding_dim,  # Input feature size\n",
771 | "                            hidden_dim,  # Size of the LSTM hidden state\n",
772 | "                            n_layers,  # Number of LSTM layers\n",
773 | "                            dropout=drop_prob,\n",
774 | "                            batch_first=True  # The first dimension of the input is batch_size\n",
775 | "                            )\n",
776 | "\n",
777 | "\n",
778 | "\n",
779 | "        # Fully connected layer applied after the LSTM\n",
780 | "        self.fc = nn.Linear(in_features=hidden_dim,  # The LSTM output is the input of the linear layer\n",
781 | "                            out_features=1  # Sentiment is binary (0 or 1), so a single output suffices\n",
782 | "                            )\n",
783 | "        self.sigmoid = nn.Sigmoid()  # Squash the linear output into (0, 1)\n",
784 | "\n",
785 | "        # Dropout applied before the final fully connected layer\n",
786 | "        self.dropout = nn.Dropout(drop_prob)\n",
787 | "\n",
788 | "    def forward(self, x, hidden):\n",
789 | "        \"\"\"\n",
790 | "        x: the input batch, of size (batch_size, 200), where 200 is the sentence length\n",
791 | "        hidden: the previous hidden state and cell state, as a tuple (h, c),\n",
792 | "                where h and c both have size (n_layers, batch_size, hidden_dim), i.e. (2, 200, 512)\n",
793 | "        \"\"\"\n",
794 | "        # Inputs arrive in batches, so the first dimension is the batch size\n",
795 | "        batch_size = x.size(0)\n",
796 | "\n",
797 | "        # Cast x to LongTensor so it can be used as indices for nn.Embedding\n",
798 | "        x = x.long()\n",
799 | "\n",
800 | "        # Embed x: its size goes from (batch_size, 200) to (batch_size, 200, embedding_dim)\n",
801 | "        embeds = self.embedding(x)\n",
802 | "\n",
803 | "        # Feed the embeddings and the previous hidden state to the LSTM; get the outputs and the new (hidden_state, cell_state)\n",
804 | "        # lstm_out has size (batch_size, 200, 512): the LSTM consumes one word per time step, so it emits one output per word\n",
805 | "        # hidden is a tuple (hidden_state, cell_state), each of size (2, batch_size, 512); 2 because the LSTM has two layers. Only the state after the last time step is returned, so there is no 200 dimension here\n",
806 | "        lstm_out, hidden = self.lstm(embeds, hidden)\n",
807 | "\n",
808 | "        # Next comes the fully connected layer, so reshape to (batch_size * 200, hidden_dim).\n",
809 | "        # It is batch_size * 200 = 40000 because every word's output goes through the fully connected layer.\n",
810 | "        # In other words, the fully connected layer sees a batch of 40000\n",
811 | "        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)\n",
812 | "\n",
813 | "        # Apply dropout before the fully connected layer\n",
814 | "        out = self.dropout(lstm_out)\n",
815 | "\n",
816 | "        # Feed the result to the fully connected layer\n",
817 | "        # Its output has size (40000, 1)\n",
818 | "        out = self.fc(out)\n",
819 | "\n",
820 | "        # Apply the sigmoid\n",
821 | "        out = self.sigmoid(out)\n",
822 | "\n",
823 | "        # Reshape the output to (batch_size, 200), i.e. one output per word\n",
824 | "        out = out.view(batch_size, -1)\n",
825 | "\n",
826 | "        # Keep only the output of the last word,\n",
827 | "        # so out now has size (batch_size,)\n",
828 | "        out = out[:,-1]\n",
829 | "\n",
830 | "        # Return the output together with the new (h, c)\n",
831 | "        return out, hidden\n",
832 | "\n",
833 | "    def init_hidden(self, batch_size):\n",
834 | "        \"\"\"\n",
835 | "        Initialize the hidden state: the first call to the LSTM has no previous state, so create one.\n",
836 | "        The initialization strategy here is all zeros.\n",
837 | "        It is a tuple because the LSTM takes two states: the hidden state and the cell state.\n",
838 | "        \"\"\"\n",
839 | "        hidden = (torch.zeros(self.n_layers, batch_size, self.hidden_dim).to(device),\n",
840 | "                  torch.zeros(self.n_layers, batch_size, self.hidden_dim).to(device)\n",
841 | "                  )\n",
842 | "        return hidden"
843 | ],
844 | "metadata": {
845 | "pycharm": {
846 | "name": "#%%\n"
847 | },
848 | "id": "MvOKKm2a6_0t"
849 | }
850 | },
851 | {
852 | "cell_type": "markdown",
853 | "source": [
854 | "With the model defined, create a model instance:"
855 | ],
856 | "metadata": {
857 | "collapsed": false,
858 | "pycharm": {
859 | "name": "#%% md\n"
860 | },
861 | "id": "XMh1JBZL6_0t"
862 | }
863 | },
864 | {
865 | "cell_type": "code",
866 | "execution_count": 25,
867 | "outputs": [
868 | {
869 | "output_type": "execute_result",
870 | "data": {
871 | "text/plain": [
872 | "SentimentNet(\n",
873 | " (embedding): Embedding(221604, 400)\n",
874 | " (lstm): LSTM(400, 512, num_layers=2, batch_first=True, dropout=0.5)\n",
875 | " (fc): Linear(in_features=512, out_features=1, bias=True)\n",
876 | " (sigmoid): Sigmoid()\n",
877 | " (dropout): Dropout(p=0.5, inplace=False)\n",
878 | ")"
879 | ]
880 | },
881 | "metadata": {},
882 | "execution_count": 25
883 | }
884 | ],
885 | "source": [
886 | "model = SentimentNet(len(words))\n",
887 | "model.to(device)"
888 | ],
889 | "metadata": {
890 | "pycharm": {
891 | "name": "#%%\n"
892 | },
893 | "id": "pqZmoWar6_0t",
894 | "outputId": "12839ec8-c4b3-4d86-c4b7-aff71dc7c066",
895 | "colab": {
896 | "base_uri": "https://localhost:8080/"
897 | }
898 | }
899 | },
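As a rough sanity check on the model summary printed above, the number of trainable parameters can be worked out by hand. The sketch below assumes PyTorch's standard `nn.LSTM` parameter layout (per layer: `weight_ih`, `weight_hh`, and two bias vectors, each covering the four gates). `lstm_layer_params` is a hypothetical helper for this illustration, and the sizes are read from the summary.

```python
def lstm_layer_params(input_size, hidden_size):
    # Per layer: weight_ih (4h x input), weight_hh (4h x h), bias_ih (4h), bias_hh (4h).
    # The factor 4 covers the input, forget, cell, and output gates.
    return 4 * hidden_size * (input_size + hidden_size) + 2 * 4 * hidden_size

# Sizes taken from the printed summary: Embedding(221604, 400), LSTM(400, 512, num_layers=2).
vocab_size, embedding_dim, hidden_dim = 221604, 400, 512

embedding = vocab_size * embedding_dim
lstm = (lstm_layer_params(embedding_dim, hidden_dim)     # layer 1 reads embeddings
        + lstm_layer_params(hidden_dim, hidden_dim))     # layer 2 reads layer 1's output
fc = hidden_dim * 1 + 1                                  # weights plus one bias

total = embedding + lstm + fc
print(embedding, lstm, fc, total)  # 88641600 3973120 513 92615233
```

Under these assumptions the embedding table accounts for roughly 96% of the parameters, which is why pruning rare words from the vocabulary shrinks the model so effectively.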
900 | {
901 | "cell_type": "markdown",
902 | "source": [
903 | "Next, define the loss function. Since this is a binary classification problem, use **binary cross-entropy (BCE)**:"
904 | ],
905 | "metadata": {
906 | "collapsed": false,
907 | "pycharm": {
908 | "name": "#%% md\n"
909 | },
910 | "id": "sb3eb6bZ6_0t"
911 | }
912 | },
913 | {
914 | "cell_type": "code",
915 | "execution_count": 26,
916 | "outputs": [],
917 | "source": [
918 | "criterion = nn.BCELoss()"
919 | ],
920 | "metadata": {
921 | "pycharm": {
922 | "name": "#%%\n"
923 | },
924 | "id": "Wg2L1kOQ6_0t"
925 | }
926 | },
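For intuition about what this loss measures, binary cross-entropy can be computed by hand. The sketch below is a plain-Python illustration of the mean BCE formula, -[y·log(p) + (1-y)·log(1-p)] averaged over the batch. It is not PyTorch's implementation, and the probabilities are made-up values.

```python
import math

def bce(probs, labels):
    """Mean binary cross-entropy over (predicted probability, true label) pairs."""
    losses = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
              for p, y in zip(probs, labels)]
    return sum(losses) / len(losses)

confident_right = bce([0.9, 0.1], [1, 0])  # predictions agree with the labels
confident_wrong = bce([0.1, 0.9], [1, 0])  # predictions contradict the labels

print(round(confident_right, 4), round(confident_wrong, 4))  # 0.1054 2.3026
```

Confident wrong predictions are penalized far more heavily than confident correct ones are rewarded, which is what pushes the sigmoid output toward the true label.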
927 | {
928 | "cell_type": "markdown",
929 | "source": [
930 | "Use the Adam optimizer:"
931 | ],
932 | "metadata": {
933 | "collapsed": false,
934 | "pycharm": {
935 | "name": "#%% md\n"
936 | },
937 | "id": "bvvd1SYa6_0t"
938 | }
939 | },
940 | {
941 | "cell_type": "code",
942 | "execution_count": 27,
943 | "outputs": [],
944 | "source": [
945 | "lr = 0.005\n",
946 | "optimizer = torch.optim.Adam(model.parameters(), lr=lr)"
947 | ],
948 | "metadata": {
949 | "pycharm": {
950 | "name": "#%%\n"
951 | },
952 | "id": "_6P6YNYM6_0t"
953 | }
954 | },
955 | {
956 | "cell_type": "markdown",
957 | "source": [
958 | "Next, define the training code:"
959 | ],
960 | "metadata": {
961 | "collapsed": false,
962 | "pycharm": {
963 | "name": "#%% md\n"
964 | },
965 | "id": "sY7N__vE6_0t"
966 | }
967 | },
968 | {
969 | "cell_type": "code",
970 | "execution_count": 28,
971 | "outputs": [
972 | {
973 | "output_type": "stream",
974 | "name": "stdout",
975 | "text": [
976 | "Epoch: 1/2... Step: 1000... Loss: 0.268714...\n",
977 | "Epoch: 1/2... Step: 2000... Loss: 0.187919...\n",
978 | "Epoch: 1/2... Step: 3000... Loss: 0.215379...\n",
979 | "Epoch: 1/2... Step: 4000... Loss: 0.195820...\n",
980 | "Epoch: 2/2... Step: 5000... Loss: 0.130096...\n",
981 | "Epoch: 2/2... Step: 6000... Loss: 0.110538...\n",
982 | "Epoch: 2/2... Step: 7000... Loss: 0.198314...\n",
983 | "Epoch: 2/2... Step: 8000... Loss: 0.233867...\n"
984 | ]
985 | }
986 | ],
987 | "source": [
988 | "epochs = 2  # Train for two epochs\n",
989 | "counter = 0  # Number of training steps so far\n",
990 | "print_every = 1000  # Print the current status every 1000 steps\n",
991 | "\n",
992 | "for i in range(epochs):\n",
993 | "    h = model.init_hidden(batch_size)  # Initialize the first hidden state\n",
994 | "\n",
995 | "    for inputs, labels in train_loader:  # Get one batch of inputs and labels from train_loader\n",
996 | "        counter += 1  # One more training step\n",
997 | "\n",
998 | "        # Detach the previous hidden state from its computation graph so that gradients\n",
999 | "        # do not flow back into earlier batches; h is a tuple (h, c), so len(h) == 2\n",
1000 | "        h = tuple([e.data for e in h])\n",
1001 | "\n",
1002 | "        # Move the data to the selected device (the GPU if available)\n",
1003 | "        inputs, labels = inputs.to(device), labels.to(device)\n",
1004 | "\n",
1005 | "        # Clear the model's gradients\n",
1006 | "        model.zero_grad()\n",
1007 | "\n",
1008 | "        # Forward pass: feed this batch and the hidden state to the model,\n",
1009 | "        # and get the output and the new hidden state\n",
1010 | "        output, h = model(inputs, h)\n",
1011 | "\n",
1012 | "        # Compute the loss from the predictions and the true labels\n",
1013 | "        loss = criterion(output, labels.float())\n",
1014 | "\n",
1015 | "        # Backward pass\n",
1016 | "        loss.backward()\n",
1017 | "\n",
1018 | "        # Clip the gradients to keep them from exploding\n",
1019 | "        # See: https://blog.csdn.net/zhaohongfei_358/article/details/122820992\n",
1020 | "        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)\n",
1021 | "\n",
1022 | "        # Update the weights\n",
1023 | "        optimizer.step()\n",
1024 | "\n",
1025 | "        # Print the current status at regular intervals\n",
1026 | "        if counter%print_every == 0:\n",
1027 | "            print(\"Epoch: {}/{}...\".format(i+1, epochs),\n",
1028 | "                  \"Step: {}...\".format(counter),\n",
1029 | "                  \"Loss: {:.6f}...\".format(loss.item()))"
1030 | ],
1031 | "metadata": {
1032 | "pycharm": {
1033 | "name": "#%%\n"
1034 | },
1035 | "id": "Gzc82i5j6_0t",
1036 | "outputId": "3c74bd52-f2f9-4e76-84bc-f667605b86dd",
1037 | "colab": {
1038 | "base_uri": "https://localhost:8080/"
1039 | }
1040 | }
1041 | },
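The gradient-clipping step in the loop above rescales all gradients together when their combined norm exceeds `max_norm`. The sketch below illustrates that idea on plain lists. `clip_by_global_norm` is a hypothetical helper, not the PyTorch function, and the small `eps` is an assumption added to avoid division by zero.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale every gradient by one shared factor so the combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + eps))   # never scale gradients up
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm

# Two "parameter tensors"; combined norm = sqrt(9 + 16 + 144) = 13.
grads = [[3.0, 4.0], [12.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=5)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

print(norm_before, round(norm_after, 4))  # 13.0 5.0
```

Because a single factor is applied to every gradient, the update direction is preserved and only its length is capped. Gradients whose norm is already below `max_norm` pass through unchanged.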
1042 | {
1043 | "cell_type": "markdown",
1044 | "source": [
1045 | "> If this raises `RuntimeError: CUDA out of memory. Tried to allocate ...`, reduce `batch_size` or clear the GPU cache (`torch.cuda.empty_cache()`)."
1046 | ],
1047 | "metadata": {
1048 | "collapsed": false,
1049 | "pycharm": {
1050 | "name": "#%% md\n"
1051 | },
1052 | "id": "bDVxnKpc6_0t"
1053 | }
1054 | },
1055 | {
1056 | "cell_type": "markdown",
1057 | "source": [
1058 | "After training for a while, evaluate the model's performance:"
1059 | ],
1060 | "metadata": {
1061 | "collapsed": false,
1062 | "pycharm": {
1063 | "name": "#%% md\n"
1064 | },
1065 | "id": "AZSIMcDN6_0t"
1066 | }
1067 | },
1068 | {
1069 | "cell_type": "code",
1070 | "execution_count": 29,
1071 | "outputs": [
1072 | {
1073 | "output_type": "stream",
1074 | "name": "stdout",
1075 | "text": [
1076 | "Test loss: 0.202\n",
1077 | "Test accuracy: 92.487%\n"
1078 | ]
1079 | }
1080 | ],
1081 | "source": [
1082 | "test_losses = []  # Loss of each test batch\n",
1083 | "num_correct = 0  # Number of correct predictions\n",
1084 | "h = model.init_hidden(batch_size)  # Initialize hidden_state and cell_state\n",
1085 | "model.eval()  # Switch the model to evaluation mode\n",
1086 | "\n",
1087 | "# Evaluate the model\n",
1088 | "for inputs, labels in test_loader:\n",
1089 | "    h = tuple([each.data for each in h])\n",
1090 | "    inputs, labels = inputs.to(device), labels.to(device)\n",
1091 | "    output, h = model(inputs, h)\n",
1092 | "    test_loss = criterion(output.squeeze(), labels.float())\n",
1093 | "    test_losses.append(test_loss.item())\n",
1094 | "    pred = torch.round(output.squeeze())  # Round the model output to 0 or 1\n",
1095 | "    correct_tensor = pred.eq(labels.float().view_as(pred))  # Mark the correct predictions\n",
1096 | "    correct = np.squeeze(correct_tensor.cpu().numpy())\n",
1097 | "    num_correct += np.sum(correct)\n",
1098 | "\n",
1099 | "print(\"Test loss: {:.3f}\".format(np.mean(test_losses)))\n",
1100 | "test_acc = num_correct/len(test_loader.dataset)\n",
1101 | "print(\"Test accuracy: {:.3f}%\".format(test_acc*100))"
1102 | ],
1103 | "metadata": {
1104 | "pycharm": {
1105 | "name": "#%%\n"
1106 | },
1107 | "id": "fa8y_0t86_0t",
1108 | "outputId": "4d613e01-71d7-4347-fcfc-d926143678ae",
1109 | "colab": {
1110 | "base_uri": "https://localhost:8080/"
1111 | }
1112 | }
1113 | },
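The accuracy bookkeeping in the evaluation loop boils down to turning probabilities into 0/1 predictions and counting matches. Here is a minimal plain-Python sketch with made-up probabilities. `accuracy` is a hypothetical helper, and it uses an explicit `>= 0.5` threshold in place of the loop's `torch.round`.

```python
def accuracy(probs, labels, threshold=0.5):
    """Fraction of examples whose thresholded prediction matches the label."""
    preds = [1 if p >= threshold else 0 for p in probs]
    correct = sum(int(pred == y) for pred, y in zip(preds, labels))
    return correct / len(labels)

# Sigmoid outputs for four reviews, and their true labels (1 = positive).
print(accuracy([0.92, 0.13, 0.55, 0.48], [1, 0, 0, 0]))  # 0.75
```

The third review is predicted positive (0.55) but labeled negative, so three of four predictions are correct.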
1114 | {
1115 | "cell_type": "markdown",
1116 | "source": [
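1117 | "After training, the model reaches over 90% accuracy on the test set (92.487% in the run above).\n",
1118 | "\n",
1119 | "Let's try it out: define a `predict(sentence)` function that takes a sentence and prints the predicted sentiment:"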
1120 | ],
1121 | "metadata": {
1122 | "collapsed": false,
1123 | "pycharm": {
1124 | "name": "#%% md\n"
1125 | },
1126 | "id": "lT3opySB6_0u"
1127 | }
1128 | },
1129 | {
1130 | "cell_type": "code",
1131 | "execution_count": 30,
1132 | "outputs": [],
1133 | "source": [
1134 | "def predict(sentence):\n",
1135 | "    # Tokenize the sentence and convert the words to indices\n",
1136 | "    sentences = [[word2idx[word.lower()] if word.lower() in word2idx else 0 for word in nltk.word_tokenize(sentence)]]\n",
1137 | "\n",
1138 | "    # Fix the sentence length to 200\n",
1139 | "    sentences = pad_input(sentences, 200)\n",
1140 | "\n",
1141 | "    # Convert to a tensor and move it to the selected device\n",
1142 | "    sentences = torch.Tensor(sentences).long().to(device)\n",
1143 | "\n",
1144 | "    # Initialize the hidden state\n",
1145 | "    h = (torch.Tensor(2, 1, 512).zero_().to(device),\n",
1146 | "         torch.Tensor(2, 1, 512).zero_().to(device))\n",
1147 | "    h = tuple([each.data for each in h])\n",
1148 | "\n",
1149 | "    # Predict\n",
1150 | "    if model(sentences, h)[0] >= 0.5:\n",
1151 | "        print(\"positive\")\n",
1152 | "    else:\n",
1153 | "        print(\"negative\")"
1154 | ],
1155 | "metadata": {
1156 | "pycharm": {
1157 | "name": "#%%\n"
1158 | },
1159 | "id": "q56TaJZO6_0u"
1160 | }
1161 | },
1162 | {
1163 | "cell_type": "code",
1164 | "execution_count": 31,
1165 | "outputs": [
1166 | {
1167 | "output_type": "stream",
1168 | "name": "stdout",
1169 | "text": [
1170 | "negative\n",
1171 | "negative\n"
1172 | ]
1173 | }
1174 | ],
1175 | "source": [
1176 | "predict(\"The film is so boring\")\n",
1177 | "predict(\"The actor is too ugly.\")"
1178 | ],
1179 | "metadata": {
1180 | "pycharm": {
1181 | "name": "#%%\n"
1182 | },
1183 | "id": "qil4MtPi6_0u",
1184 | "outputId": "4df7aab6-0d64-493f-ee03-374f30856f44",
1185 | "colab": {
1186 | "base_uri": "https://localhost:8080/"
1187 | }
1188 | }
1189 | },
1190 | {
1191 | "cell_type": "markdown",
1192 | "source": [
1193 | "# References\n",
1194 | "\n",
1195 | "[Long Short-Term Memory: From Zero to Hero with PyTorch](https://blog.floydhub.com/long-short-term-memory-from-zero-to-hero-with-pytorch/): https://blog.floydhub.com/long-short-term-memory-from-zero-to-hero-with-pytorch/"
1196 | ],
1197 | "metadata": {
1198 | "collapsed": false,
1199 | "pycharm": {
1200 | "name": "#%% md\n"
1201 | },
1202 | "id": "1fQ0sKb26_0u"
1203 | }
1204 | },
1205 | {
1206 | "cell_type": "code",
1207 | "execution_count": null,
1208 | "outputs": [],
1209 | "source": [
1210 | ""
1211 | ],
1212 | "metadata": {
1213 | "pycharm": {
1214 | "name": "#%%\n"
1215 | },
1216 | "id": "gL4V6ekz6_0u"
1217 | }
1218 | }
1219 | ],
1220 | "metadata": {
1221 | "kernelspec": {
1222 | "display_name": "Python 3",
1223 | "language": "python",
1224 | "name": "python3"
1225 | },
1226 | "language_info": {
1227 | "codemirror_mode": {
1228 | "name": "ipython",
1229 | "version": 2
1230 | },
1231 | "file_extension": ".py",
1232 | "mimetype": "text/x-python",
1233 | "name": "python",
1234 | "nbconvert_exporter": "python",
1235 | "pygments_lexer": "ipython2",
1236 | "version": "2.7.6"
1237 | },
1238 | "colab": {
1239 | "provenance": [],
1240 | "machine_shape": "hm"
1241 | },
1242 | "accelerator": "GPU",
1243 | "gpuClass": "standard"
1244 | },
1245 | "nbformat": 4,
1246 | "nbformat_minor": 0
1247 | }
--------------------------------------------------------------------------------
/101_mosaic_video.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "code",
19 | "source": [
20 | "# Install libraries\n",
21 | "!pip install ultralytics\n",
22 | "!pip install ffmpeg-python"
23 | ],
24 | "metadata": {
25 | "id": "WJTaMsEd09Vt"
26 | },
27 | "execution_count": null,
28 | "outputs": []
29 | },
30 | {
31 | "cell_type": "code",
32 | "source": [
33 | "# Download model and video.\n",
34 | "!gdown 1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb\n",
35 | "!wget https://github.com/iioSnail/pytorch_deep_learning_examples/raw/refs/heads/main/asserts/mp4/kunkun.mp4"
36 | ],
37 | "metadata": {
38 | "id": "eXGr2Fl1Wpdc"
39 | },
40 | "execution_count": null,
41 | "outputs": []
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {
47 | "id": "x3QIKG890rxR"
48 | },
49 | "outputs": [],
50 | "source": [
51 | "import ffmpeg\n",
52 | "import cv2\n",
53 | "from numpy import ndarray\n",
54 | "from ultralytics import YOLO\n",
55 | "from tqdm import tqdm"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "source": [
61 | "# Apply mosaic to an image\n",
62 | "def mosaic_image(model, image:ndarray, mosaic_scale = 10) -> ndarray:\n",
63 | "    results = model(image, verbose=False)\n",
64 | "    # results[0].boxes holds the detected faces; .xyxy gives each box as (x1, y1, x2, y2)\n",
65 | "\n",
66 | "    boxes = results[0].boxes.xyxy\n",
67 | "    for i in range(len(boxes)):\n",
68 | "        x1, y1, x2, y2 = boxes[i].int()\n",
69 | "        roi = image[y1:y2, x1:x2]\n",
70 | "\n",
71 | "        h, w = roi.shape[:2]\n",
72 | "        small_roi = cv2.resize(roi, (max(1, w // mosaic_scale), max(1, h // mosaic_scale)), interpolation=cv2.INTER_LINEAR)  # Shrink; at least 1 px per side\n",
73 | "        mosaic_roi = cv2.resize(small_roi, (w, h), interpolation=cv2.INTER_NEAREST)  # Enlarge back with blocky pixels\n",
74 | "        image[y1:y2, x1:x2] = mosaic_roi\n",
75 | "\n",
76 | "    return image"
77 | ],
78 | "metadata": {
79 | "id": "ys1gOBSBWY6C"
80 | },
81 | "execution_count": 4,
82 | "outputs": []
83 | },
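The mosaic effect itself can be illustrated without OpenCV. The sketch below swaps in a different technique from the notebook's shrink-then-enlarge `cv2.resize` calls: it replaces each tile of a small grayscale grid with the tile's mean value. `pixelate` and the sample `image` are hypothetical, written only for this illustration.

```python
def pixelate(gray, block):
    """Replace every block x block tile of a 2-D grayscale grid with its (integer) mean."""
    h, w = len(gray), len(gray[0])
    out = [row[:] for row in gray]               # copy so the input stays untouched
    for top in range(0, h, block):
        for left in range(0, w, block):
            ys = range(top, min(top + block, h))
            xs = range(left, min(left + block, w))
            tile = [gray[y][x] for y in ys for x in xs]
            mean = sum(tile) // len(tile)
            for y in ys:
                for x in xs:
                    out[y][x] = mean
    return out

image = [[0, 10, 100, 110],
         [20, 30, 120, 130],
         [200, 210, 50, 60],
         [220, 230, 70, 80]]

print(pixelate(image, 2))
# [[15, 15, 115, 115], [15, 15, 115, 115], [215, 215, 65, 65], [215, 215, 65, 65]]
```

A larger `block` discards more detail, playing the same role as a larger `mosaic_scale` in the notebook.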
84 | {
85 | "cell_type": "code",
86 | "source": [
87 | "# Define filepaths.\n",
88 | "input_video = \"kunkun.mp4\"\n",
89 | "tmp_audio = \"tmp.wav\"\n",
90 | "tmp_video = \"tmp_kunkun.mp4\"\n",
91 | "output_video = \"mosaic_kunkun.mp4\"\n",
92 | "\n",
93 | "model = YOLO(\"yolov8n-face.pt\")"
94 | ],
95 | "metadata": {
96 | "id": "js-8koFq04mb"
97 | },
98 | "execution_count": 5,
99 | "outputs": []
100 | },
101 | {
102 | "cell_type": "code",
103 | "source": [
104 | "# Extract audio from the video.\n",
105 | "ffmpeg.input(input_video).output(tmp_audio, format='wav').run(overwrite_output=True)"
106 | ],
107 | "metadata": {
108 | "colab": {
109 | "base_uri": "https://localhost:8080/"
110 | },
111 | "id": "Q270pQcA1JiA",
112 | "outputId": "9669c593-23b2-437b-9701-47845588cca9"
113 | },
114 | "execution_count": 6,
115 | "outputs": [
116 | {
117 | "output_type": "execute_result",
118 | "data": {
119 | "text/plain": [
120 | "(None, None)"
121 | ]
122 | },
123 | "metadata": {},
124 | "execution_count": 6
125 | }
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "source": [
131 | "# Apply the mosaic frame by frame and write the mosaicked video.\n",
132 | "cap = cv2.VideoCapture(input_video)\n",
133 | "if not cap.isOpened():\n",
134 | "    print(\"Error: Could not open video file.\")\n",
135 | "    exit(0)\n",
136 | "\n",
137 | "width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
138 | "height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
139 | "fps = cap.get(cv2.CAP_PROP_FPS)\n",
140 | "n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
141 | "\n",
142 | "fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n",
143 | "out = cv2.VideoWriter(tmp_video, fourcc, fps, (width, height))\n",
144 | "\n",
145 | "pro_bar = tqdm(total=n_frames)\n",
146 | "while True:\n",
147 | "    ret, frame = cap.read()\n",
148 | "\n",
149 | "    if not ret:\n",
150 | "        break\n",
151 | "\n",
152 | "    frame = mosaic_image(model, frame)\n",
153 | "    out.write(frame)\n",
154 | "\n",
155 | "    pro_bar.update(1)\n",
156 | "\n",
157 | "cap.release()\n",
158 | "out.release()\n",
159 | "pro_bar.close()"
160 | ],
161 | "metadata": {
162 | "id": "WkKbG2uw1mnJ"
163 | },
164 | "execution_count": null,
165 | "outputs": []
166 | },
167 | {
168 | "cell_type": "code",
169 | "source": [
170 | "# Merge the video and audio.\n",
171 | "video_stream = ffmpeg.input(tmp_video)\n",
172 | "audio_stream = ffmpeg.input(tmp_audio)\n",
173 | "ffmpeg.output(video_stream, audio_stream, output_video, vcodec=\"copy\", acodec='aac').run(overwrite_output=True)"
174 | ],
175 | "metadata": {
176 | "id": "GiYXjgvQ1p8F"
177 | },
178 | "execution_count": null,
179 | "outputs": []
180 | },
181 | {
182 | "cell_type": "code",
183 | "source": [
184 | "# Show the result video.\n",
185 | "from IPython.display import HTML\n",
186 | "from base64 import b64encode\n",
187 | "import os\n",
188 | "\n",
189 | "# Compressed video path\n",
190 | "compressed_path = \"./compressed.mp4\"\n",
191 | "os.system(f\"ffmpeg -i {output_video} -vcodec libx264 {compressed_path}\")\n",
192 | "\n",
193 | "# Show video\n",
194 | "mp4 = open(compressed_path,'rb').read()\n",
195 | "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
196 | "HTML(\"\"\"\n",
197 | "<video controls>\n",
198 | "    <source src=\"%s\" type=\"video/mp4\">\n",
199 | "</video>\n",
200 | "\"\"\" % data_url)"
201 | ],
202 | "metadata": {
203 | "id": "AabemWblYA50"
204 | },
205 | "execution_count": null,
206 | "outputs": []
207 | }
208 | ]
209 | }
--------------------------------------------------------------------------------
/README.en.md:
--------------------------------------------------------------------------------
1 | 中文 | English
2 |
3 | # Python Deep Learning Examples
4 |
5 | During my journey of learning deep learning, I encountered numerous challenges. For me, the two biggest obstacles were:
6 |
7 | 1. Most tutorials focus heavily on theory, but without hands-on practice, it’s hard to truly grasp the concepts.
8 | 2. Practical projects on GitHub are often too comprehensive, including many aspects unrelated to understanding theory, such as parallel computing and model optimization. Furthermore, many projects are provided as `.py` files instead of Jupyter notebooks, which makes it even harder for beginners like me to read and understand.
9 |
10 | To address these two issues, I built this project with a focus on using the simplest possible code to demonstrate deep learning concepts in practice. The goal is to help others understand the theory through hands-on experience.
11 |
12 | # Project Structure
13 |
14 | | Series Name | Project Title | Blog Links | Google Colab |
15 | |--|--|--|--|
16 | | PyTorch Beginner's Tutorial | 01. Implementing Linear Regression | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-01.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/01_linear_regression.ipynb) |
17 | || 02. Using a BP Neural Network to Recognize MNIST Handwritten Digits | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-02.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/02_MNIST_classification.ipynb) |
18 | || 03. Object Classification with a Simple CNN | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-03.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/03_cnn_image_classification.ipynb) |
19 | || 04. Sentiment Analysis of Text Using LSTM | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-04.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/04_LSTM_sentiment_analysis.ipynb) |
20 | || 05. Machine Translation Using nn.Transformer (English to Chinese) | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-05.html) | [Open In Colab](https://github.com/iioSnail/chaotic-transformer-tutorials/blob/master/en_to_zh_demo.ipynb) |
21 | || 06. Using GAN to Generate Simple Anime Character Avatar | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-06.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/06_GAN_image_generation.ipynb) |
22 | || 07. Textual Metaphor Binary Classification with BERT | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-07.html) | [Open In Colab](https://github.com/iioSnail/chaotic-transformer-tutorials/blob/master/bert_classification_demo.ipynb) |
23 | || 08. Few-shot Learning for Image Classification | [Blogger](https://iiosnail.blogspot.com/2024/11/pytorch-en-08.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/08_few_shot_learning.ipynb) |
24 | | PyTorch Applications | 01. Video Face Mosaic Processing with YOLO | [Blogger](https://iiosnail.blogspot.com/2024/12/mosaic-en.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/101_mosaic_video.ipynb) |
25 | || 02. Generate Lip-syncing Video with Wav2Lip | TODO | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/102_Wav2Lip_Inference.ipynb) |
26 | | LLM Applications | 01. A Native RAG Example with LangChain | [Blogger](https://iiosnail.blogspot.com/2025/04/native-rag.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/201_native_rag.ipynb) |
27 |
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 中文 | English
2 |
3 | # Python项目实战
4 |
5 | 在学习深度学习的过程中有许多坎坷,对我来说,最大的困难有两点:
6 |
7 | 1. 大多的教学都是偏向理论,但如果不结合实践,很难吃透理论。
8 | 2. Github上的实战项目都是大而全的,里面包含了许多与理解理论无关的东西,例如并行计算、模型优化等等,而且很多都是py文件,并非jupyter,这样对于我这种菜鸟来说,增加了阅读难度。
9 |
10 | 针对这两个痛点,我构建了这个项目,致力于使用最简单的代码,来进行深度学习实战,帮助大家理解理论。
11 |
12 | # Project Index
13 |
14 | | Series | Project | Blog | Google Colab |
15 | |--|--|--|--|
16 | | PyTorch Basics in Practice | 01. Implementing Linear Regression | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/121418622), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch1.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/01_linear_regression.ipynb) |
17 | || 02. MNIST Handwritten Digit Recognition with a BP Neural Network | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/122800647), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch2.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/02_MNIST_classification.ipynb) |
18 | || 03. Object Classification with a Simple CNN | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/125020186), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch3.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/03_cnn_image_classification.ipynb) |
19 | || 04. Text Sentiment Analysis with LSTM | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/122838743), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch4.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/04_LSTM_sentiment_analysis.ipynb) |
20 | || 05. Machine Translation (English to Chinese) with nn.Transformer | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/126175328), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch5.html) | [Open In Colab](https://github.com/iioSnail/chaotic-transformer-tutorials/blob/master/en_to_zh_demo.ipynb) |
21 | || 06. Generating Simple Anime Character Avatars with a GAN | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/125675557), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch6.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/06_GAN_image_generation.ipynb) |
22 | || 07. Binary Text Metaphor Classification with BERT (a Kaggle beginner problem) | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/126426855), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch7.html) | [Open In Colab](https://github.com/iioSnail/chaotic-transformer-tutorials/blob/master/bert_classification_demo.ipynb) |
23 | || 08. Image Classification with Few-shot Learning | [CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/126453857), [Blogger](https://iiosnail.blogspot.com/2024/10/pytorch8.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/08_few_shot_learning.ipynb) |
24 | | PyTorch Applications in Practice | 01. YOLO-Based Face Mosaic Processing for Video | [CSDN](https://iio-snail.blog.csdn.net/article/details/144245634), [Blogger](https://iiosnail.blogspot.com/2024/12/mosaic.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/101_mosaic_video.ipynb) |
25 | || 02. Lip-Synced Video Generation with Wav2Lip | [CSDN](https://iio-snail.blog.csdn.net/article/details/146425716) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/102_Wav2Lip_Inference.ipynb) |
26 | | LLM Applications in Practice | 01. A Simple Native RAG Example with LangChain | [CSDN](https://iio-snail.blog.csdn.net/article/details/147148936), [Blogger](https://iiosnail.blogspot.com/2025/04/native-rag.html) | [Open In Colab](https://colab.research.google.com/github/iioSnail/pytorch_deep_learning_examples/blob/main/201_native_rag.ipynb) |
27 |
--------------------------------------------------------------------------------
/asserts/images/01_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/01_01.png
--------------------------------------------------------------------------------
/asserts/images/01_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/01_02.png
--------------------------------------------------------------------------------
/asserts/images/06_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/06_01.png
--------------------------------------------------------------------------------
/asserts/images/101_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/101_01.png
--------------------------------------------------------------------------------
/asserts/images/101_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/101_02.png
--------------------------------------------------------------------------------
/asserts/images/101_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/101_03.png
--------------------------------------------------------------------------------
/asserts/images/101_04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/101_04.png
--------------------------------------------------------------------------------
/asserts/images/101_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/101_05.png
--------------------------------------------------------------------------------
/asserts/images/101_06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/images/101_06.png
--------------------------------------------------------------------------------
/asserts/mp4/kunkun.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iioSnail/pytorch_deep_learning_examples/d6ed972dcf00e7a38296e46ef2d295002f61751b/asserts/mp4/kunkun.mp4
--------------------------------------------------------------------------------