├── 10-自动求导.ipynb ├── 100-几种Pytorch常用的并行训练加速方法.ipynb ├── 101-人脸识别模型(快速上手)-face_recognition.ipynb ├── 11-图像分类数据集.ipynb ├── 12-线性模型.ipynb ├── 13-线性回归从零实现.ipynb ├── 14-线性回归简洁版实现.ipynb ├── 15-Softmax回归从零开始实现.ipynb ├── 16-Softmax回归简洁版实现.ipynb ├── 17-多层感知机.ipynb ├── 18-多层感知机的从零开始实现.ipynb ├── 19-多层感知机的简洁实现.ipynb ├── 20-模型选择、欠拟合和过拟合.ipynb ├── 21-权重衰减.ipynb ├── 22-Dropout.ipynb ├── 23-正向传播、反向传播和计算图.ipynb ├── 24-数值稳定性和模型初始化.ipynb ├── 25-环境和分布偏移.ipynb ├── 26-深度学习动手实践:预测房价.ipynb ├── 27-深度学习结构-层和块的概念.ipynb ├── 28-深度学习计算-参数管理.ipynb ├── 29-深度学习计算-如何自定义层&文件读写.ipynb ├── 30-循环神经网络RNN-序列模型-自然语言处理-pytorch.ipynb ├── 31-循环神经网络RNN-文本预处理-自然语言处理-pytorch.ipynb ├── 32-循环神经网络RNN-语言模型和数据集-自然语言处理-pytorch.ipynb ├── 33-循环神经网络RNN-隐藏状态-自然语言处理-pytorch.ipynb ├── 34-循环神经网络RNN-从零实现-自然语言处理-pytorch.ipynb ├── 35-循环神经网络RNN-简洁版实现-自然语言处理-pytorch.ipynb ├── 37-循环神经网络RNN-反向传播数学细节-自然语言处理-pytorch.ipynb ├── 38-GRU原理和代码实现-自然语言处理-pytorch.ipynb ├── 39-LSTM原理和代码实现-自然语言处理-pytorch.ipynb ├── 40-深度循环神经网络原理和实现-自然语言处理-pytorch.ipynb ├── 41-双向循环神经网络原理和实现-自然语言处理-pytorch.ipynb ├── 42-机器翻译和数据处理-自然语言处理-pytorch.ipynb ├── 43-seq2seq-序列到序列-自然语言处理-pytorch.ipynb ├── 44-注意力提示-注意力机制-自然语言处理-pytorch.ipynb ├── 45-注意力机制-Watson核回归-自然语言处理-pytorch.ipynb ├── 46-注意力机制-注意力评分函数-自然语言处理-pytorch.ipynb ├── 47-注意力机制-Bahdanau注意力-自然语言处理-pytorch.ipynb ├── 48-注意力机制-多头注意力-自然语言处理-pytorch.ipynb ├── 49-注意力机制-自注意力(self-attention)和位置编码-自然语言处理-pytorch.ipynb ├── 50-注意力机制-Transformer-自然语言处理-pytorch.ipynb ├── 51-优化算法与深度学习-pytorch.ipynb ├── 52-优化算法-凸函数(convexity)与深度学习-pytorch.ipynb ├── 53-优化算法-梯度下降-深度学习-pytorch.ipynb ├── 54-优化算法-随机梯度下降-深度学习-pytorch.ipynb ├── 55-优化算法-小批量随机梯度下降-深度学习-pytorch.ipynb ├── 56-优化算法-动量法-深度学习-pytorch.ipynb ├── 57-优化算法-AdaGrad算法-深度学习-pytorch.ipynb ├── 58-优化算法-RMSProp算法-深度学习-pytorch.ipynb ├── 59-优化算法-Adadelta算法-深度学习-pytorch.ipynb ├── 60-优化算法-Adam算法-深度学习-pytorch.ipynb ├── 61-优化算法-学习率调度器-深度学习-pytorch.ipynb ├── 62-计算性能-命令编程和符号编程-深度学习-pytorch.ipynb ├── 63-计算性能-异步与并行-深度学习-pytorch.ipynb ├── 64-计算性能-多GPU训练-深度学习-pytorch.ipynb ├── 65-计算性能-几行代码实现多GPU训练-深度学习-pytorch.ipynb ├── 66-NLP-Word2vec-预训练-深度学习-pytorch.ipynb ├── 67-NLP-SkipGram改进(负采样和霍夫曼树)-预训练-深度学习-pytorch.ipynb ├── 68-NLP-word2vector训练之数据集整理-预训练-深度学习-pytorch.ipynb ├── 69-NLP-word2vector训练环节代码操作-预训练-深度学习-pytorch.ipynb ├── 7-线性代数.ipynb ├── 70-NLP-GloVe(Global Vectors)全局向量的词嵌入-预训练-深度学习-pytorch.ipynb ├── 71-NLP-子词嵌入(fastText)原理和代码实现-预训练-深度学习-pytorch.ipynb ├── 72-NLP-模型应用判断词的相似性和类比(代码实现)-预训练-深度学习-pytorch.ipynb ├── 73-NLP-BERT-预训练-深度学习-pytorch.ipynb ├── 74-NLP-BERT训练之数据集处理-预训练-深度学习-pytorch.ipynb ├── 75-NLP-BERT训练环节(代码实现)-预训练-深度学习-pytorch.ipynb ├── 76-NLP-情感分析数据预处理-项目实操-深度学习-pytorch.ipynb ├── 77-NLP-情感分析模型训练-项目实操-深度学习-pytorch.ipynb ├── 78-NLP-自然语言推断数据预处理-项目实操-深度学习-pytorch.ipynb ├── 79-NLP-自然语言推断模型训练-项目实操-深度学习-pytorch.ipynb ├── 8-概率.ipynb ├── 80-NLP-自然语言推断模型训练(BERT)-项目实操-深度学习-pytorch.ipynb ├── 81-NLP-GPT2(原理详解)-预训练-深度学习-pytorch.ipynb ├── 82-NLP-生成任务GPT2数据预处理-项目实操-深度学习-pytorch.ipynb ├── 83-卷积计算和互相关计算-卷积神经网络-pytorch.ipynb ├── 84-Padding和Stride-卷积神经网络.ipynb ├── 85-Channel多输入和输出通道-卷积神经网络.ipynb ├── 86-池化层-卷积神经网络.ipynb ├── 87-LeNet-卷积神经网络.ipynb ├── 88-AlexNet-卷积神经网络.ipynb ├── 89-VGG-卷积神经网络.ipynb ├── 9-微积分.ipynb ├── 90-NiN-卷积神经网络.ipynb ├── 91-GoogLeNet实现-卷积神经网络.ipynb ├── 92-批量归一化实现-卷积神经网络.ipynb ├── 93-ResNet(残差网络)-卷积神经网络.ipynb ├── 94-DenseNet(稠密连接网络)-卷积神经网络.ipynb ├── 95-GAN-生成对抗网络.ipynb ├── 96-Pytorch常用代码汇总-持续更新.ipynb ├── 97-NLP-语义相似度模型(SBERT)-深度学习-pytorch.ipynb ├── 98-数据处理(比pandas快几倍)-pytorch.ipynb ├── 99-模拟数据神器-Faker.ipynb ├── README.md └── dltools.py /10-自动求导.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "d060ea48", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "id": "e19b0422", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "x = torch.arange(4.0)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "297f935f", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "data": { 31 | "text/plain": [ 32 | "tensor([0., 1., 2., 3.])" 33 | ] 34 | }, 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "output_type": "execute_result" 38 | } 39 | ], 40 | "source": [ 41 | "x" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "id": "08011121", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "tensor([0., 1., 2., 3.], requires_grad=True)" 54 | ] 55 | }, 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "x.requires_grad_(True)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "id": "716a9921", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "x.grad" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 12, 78 | "id": "11c70ea3", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "y =2 * torch.dot(x,x)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 13, 88 | "id": "5d7cdd33", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "tensor([ 0., 4., 8., 12.])" 95 | ] 96 | }, 97 | "execution_count": 13, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "y.backward()\n", 104 | "x.grad" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 16, 110 | "id": "ebeccd6a", 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "tensor([0., 0., 0., 0.])" 117 | ] 118 | }, 119 | "execution_count": 16, 120 | "metadata": {}, 121 | "output_type": "execute_result" 122 | } 123 | ], 124 | "source": [ 125 | "x.grad.zero_()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 15, 131 | "id": "b40e1158", 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "data": { 136 | "text/plain": [ 137 | "tensor([True, True, True, True])" 138 | ] 139 | }, 140 | "execution_count": 15, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "x.grad == 4 * x" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 17, 152 | "id": "a94d91a5", 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "tensor([0., 0., 0., 0.])" 159 | ] 160 | }, 161 | "execution_count": 17, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "x.grad" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 18, 173 | "id": "c1e802b3", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "y = x.sum()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 19, 183 | "id": "d8dac36f", 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | "tensor([0., 1., 2., 3.], requires_grad=True)" 190 | ] 191 | }, 192 
| "execution_count": 19, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "x" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 21, 204 | "id": "02468790", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "y.backward()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 22, 214 | "id": "04bd9b71", 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "tensor([1., 1., 1., 1.])" 221 | ] 222 | }, 223 | "execution_count": 22, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "x.grad" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 23, 235 | "id": "cf20120a", 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "tensor([0., 0., 0., 0.])" 242 | ] 243 | }, 244 | "execution_count": 23, 245 | "metadata": {}, 246 | "output_type": "execute_result" 247 | } 248 | ], 249 | "source": [ 250 | "x.grad.zero_()" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 7, 256 | "id": "fbc74b6d", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "y = x * x" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 25, 266 | "id": "5d8759b6", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "y.backward(torch.ones(len(x)))" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 26, 276 | "id": "dd705249", 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "data": { 281 | "text/plain": [ 282 | "tensor([0., 2., 4., 6.])" 283 | ] 284 | }, 285 | "execution_count": 26, 286 | "metadata": {}, 287 | "output_type": "execute_result" 288 | } 289 | ], 290 | "source": [ 291 | "x.grad" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 27, 297 | "id": "426f76e4", 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "data": { 302 | "text/plain": [ 303 | "tensor([0., 1., 2., 3.], requires_grad=True)" 304 | ] 305 | }, 306 | "execution_count": 27, 307 | "metadata": {}, 308 | "output_type": "execute_result" 309 | } 310 | ], 311 | "source": [ 312 | "x" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 28, 318 | "id": "1bc33ea0", 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "text/plain": [ 324 | "tensor([0., 0., 0., 0.])" 325 | ] 326 | }, 327 | "execution_count": 28, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | } 331 | ], 332 | "source": [ 333 | "x.grad.zero_()" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 29, 339 | "id": "26c87216", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "y = x * x" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 30, 349 | "id": "12e2212a", 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "u = y.detach()" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 31, 359 | "id": "b66bbad9", 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "z = u * x" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 32, 369 | "id": "132da71d", 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "z.sum().backward()" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 33, 379 | "id": "ac9c95ed", 380 | "metadata": {}, 381 | "outputs": [ 382 | { 383 | 
"data": { 384 | "text/plain": [ 385 | "tensor([True, True, True, True])" 386 | ] 387 | }, 388 | "execution_count": 33, 389 | "metadata": {}, 390 | "output_type": "execute_result" 391 | } 392 | ], 393 | "source": [ 394 | "x.grad == u" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 34, 400 | "id": "b2099868", 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "data": { 405 | "text/plain": [ 406 | "tensor([0., 1., 4., 9.])" 407 | ] 408 | }, 409 | "execution_count": 34, 410 | "metadata": {}, 411 | "output_type": "execute_result" 412 | } 413 | ], 414 | "source": [ 415 | "x.grad" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 35, 421 | "id": "18364630", 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "tensor([0., 1., 2., 3.], requires_grad=True)" 428 | ] 429 | }, 430 | "execution_count": 35, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "x" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 36, 442 | "id": "761ff5e3", 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "data": { 447 | "text/plain": [ 448 | "tensor([0., 0., 0., 0.])" 449 | ] 450 | }, 451 | "execution_count": 36, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "x.grad.zero_()" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 37, 463 | "id": "46ee7ffe", 464 | "metadata": {}, 465 | "outputs": [ 466 | { 467 | "data": { 468 | "text/plain": [ 469 | "tensor([True, True, True, True])" 470 | ] 471 | }, 472 | "execution_count": 37, 473 | "metadata": {}, 474 | "output_type": "execute_result" 475 | } 476 | ], 477 | "source": [ 478 | "y.sum().backward()\n", 479 | "x.grad == 2 * x" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 14, 485 | "id": "67d40e52", 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "a = torch.Tensor([[1,2], [3,4]])\n", 490 | "b = torch.Tensor([[1,2,3,4], [1,2,3,4]])" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 17, 496 | "id": "7ba71066", 497 | "metadata": {}, 498 | "outputs": [ 499 | { 500 | "ename": "RuntimeError", 501 | "evalue": "The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1", 502 | "output_type": "error", 503 | "traceback": [ 504 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 505 | "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", 506 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 507 | "\u001b[1;31mRuntimeError\u001b[0m: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 1" 508 | ] 509 | } 510 | ], 511 | "source": [ 512 | "torch.mul(a,b)" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "id": "3052bf8e", 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [] 522 | } 523 | ], 524 | "metadata": { 525 | "kernelspec": { 526 | "display_name": "Python 3", 527 | "language": "python", 528 | "name": "python3" 529 | }, 530 | "language_info": { 531 | "codemirror_mode": { 532 | "name": "ipython", 533 | "version": 3 534 | }, 
535 | "file_extension": ".py", 536 | "mimetype": "text/x-python", 537 | "name": "python", 538 | "nbconvert_exporter": "python", 539 | "pygments_lexer": "ipython3", 540 | "version": "3.6.5" 541 | } 542 | }, 543 | "nbformat": 4, 544 | "nbformat_minor": 5 545 | } 546 | -------------------------------------------------------------------------------- /11-图像分类数据集.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "fe3e6df5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%config Completer.use_jedi = False\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "9d75727f", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import torch\n", 22 | "import torchvision\n", 23 | "from torch.utils import data\n", 24 | "from torchvision import transforms\n", 25 | "from d2l import torch as d2l\n", 26 | "\n", 27 | "d2l.use_svg_display()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "id": "17c24781", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# 读取数据集\n", 38 | "# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式\n", 39 | "# 并除以255使得所有像素的数值均在0到1之间\n", 40 | "trans = transforms.ToTensor()\n", 41 | "mnist_train = torchvision.datasets.FashionMNIST(root=\"../data\", train=True,transform=trans,download=True)\n", 42 | "mnist_test = torchvision.datasets.FashionMNIST(root=\"../data\", train=False,transform=trans, download=True)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 6, 48 | "id": "47f473f3", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "data": { 53 | "text/plain": [ 54 | "(60000, 10000)" 55 | ] 56 | }, 57 | "execution_count": 6, 58 | "metadata": {}, 59 | "output_type": "execute_result" 60 | } 61 | ], 62 | "source": [ 63 | "len(mnist_train), len(mnist_test)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 9, 69 | "id": "42a92d81", 70 | "metadata": { 71 | "scrolled": false 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "\"\"\"\n", 76 | "Fashion-MNIST中包含的10个类别分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)\n", 77 | "coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。\n", 78 | "以下函数用于在数字标签索引及其文本名称之间进行转换。\n", 79 | "\"\"\"\n", 80 | "def get_fashion_mnist_labels(labels): \n", 81 | " \"\"\"返回Fashion-MNIST数据集的文本标签。\"\"\"\n", 82 | " text_labels = [\n", 83 | " 't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt',\n", 84 | " 'sneaker', 'bag', 'ankle boot']\n", 85 | " return [text_labels[int(i)] for i in labels]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 8, 91 | "id": "a1d550de", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# 创建一个函数来可视化这些样本。\n", 96 | "def show_images(imgs, num_rows, num_cols, titles=None, scale=1.1):\n", 97 | " \"\"\"Plot a list of images.\"\"\"\n", 98 | " figsize = (num_cols * scale, num_rows * scale)\n", 99 | " _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)\n", 100 | " axes = axes.flatten()\n", 101 | " for i, (ax, img) in enumerate(zip(axes, imgs)):\n", 102 | " if torch.is_tensor(img):\n", 103 | " # 图片张量\n", 104 | " ax.imshow(img.numpy())\n", 105 | " else:\n", 106 | " # PIL图片\n", 107 | " ax.imshow(img)\n", 108 | " ax.axes.get_xaxis().set_visible(False)\n", 109 | " ax.axes.get_yaxis().set_visible(False)\n", 110 | " if titles:\n", 111 | " ax.set_title(titles[i])\n", 112 | " return axes" 113 | ] 114 | }, 
115 | { 116 | "cell_type": "code", 117 | "execution_count": 18, 118 | "id": "04711e83", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "tensor([5, 8, 4, 2, 0, 9])" 125 | ] 126 | }, 127 | "execution_count": 18, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "# 以下是训练数据集中前几个样本的图像及其相应的标签(文本形式)\n", 134 | "X, y = next(iter(data.DataLoader(mnist_train,shuffle=True, batch_size=6)))\n", 135 | "#show_images(X.reshape(6, 28, 28), 2, 3, titles=get_fashion_mnist_labels(y));\n", 136 | "y" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 19, 142 | "id": "9159f1fe", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# 读取小批量\n", 147 | "batch_size = 256\n", 148 | "\n", 149 | "def get_dataloader_workers(): #@save\n", 150 | " \"\"\"使用4个进程来读取数据。\"\"\"\n", 151 | " return 4\n", 152 | "\n", 153 | "train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=4)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 20, 159 | "id": "43eae658", 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "'6.74 sec'" 166 | ] 167 | }, 168 | "execution_count": 20, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "# 看一下读取数据的时间\n", 175 | "timer = d2l.Timer()\n", 176 | "for X, y in train_iter:\n", 177 | " continue\n", 178 | "f'{timer.stop():.2f} sec'" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 28, 184 | "id": "c2e20065", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | " # 完整的读取例子\n", 189 | "def load_data_fashion_mnist(batch_size, resize=None): \n", 190 | " \"\"\"下载Fashion-MNIST数据集,然后将其加载到内存中。\"\"\"\n", 191 | " trans = [transforms.ToTensor()]\n", 192 | " if resize:\n", 193 | " trans.insert(0, transforms.Resize(resize))\n", 194 | " trans = transforms.Compose(trans)\n", 195 | " # 训练集\n", 196 | " train_dataset = torchvision.datasets.FashionMNIST(root=\"../data\",train=True,transform=trans,download=True)\n", 197 | " train_dataloader = data.DataLoader(train_dataset, batch_size, shuffle=True,num_workers=get_dataloader_workers())\n", 198 | " # 测试集\n", 199 | " test_dataset = torchvision.datasets.FashionMNIST(root=\"../data\",train=False,transform=trans,download=True)\n", 200 | " test_dataloader = data.DataLoader(test_dataset, batch_size, shuffle=False,num_workers=get_dataloader_workers())\n", 201 | " return (train_dataloader,test_dataloader)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 30, 207 | "id": "2abd49f8", 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "torch.Size([6, 1, 28, 28]) torch.float32 torch.Size([6]) torch.int64\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "train_iter, test_iter = load_data_fashion_mnist(6, resize=None)\n", 220 | "for X, y in train_iter:\n", 221 | " print(X.shape, X.dtype, y.shape, y.dtype)\n", 222 | " break" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 31, 228 | "id": "7c16b409", 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "tensor([1, 1, 4, 5, 3, 5])" 235 | ] 236 | }, 237 | "execution_count": 31, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "y" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 
249 | "id": "525a4aaa", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [] 253 | } 254 | ], 255 | "metadata": { 256 | "kernelspec": { 257 | "display_name": "Python 3", 258 | "language": "python", 259 | "name": "python3" 260 | }, 261 | "language_info": { 262 | "codemirror_mode": { 263 | "name": "ipython", 264 | "version": 3 265 | }, 266 | "file_extension": ".py", 267 | "mimetype": "text/x-python", 268 | "name": "python", 269 | "nbconvert_exporter": "python", 270 | "pygments_lexer": "ipython3", 271 | "version": "3.8.5" 272 | } 273 | }, 274 | "nbformat": 4, 275 | "nbformat_minor": 5 276 | } 277 | -------------------------------------------------------------------------------- /14-线性回归简洁版实现.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "0e0e1ebc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%config Completer.use_jedi = False\n", 11 | "%matplotlib inline\n", 12 | "import random\n", 13 | "import torch\n", 14 | "from torch.utils import data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 5, 20 | "id": "24f1e21d", 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "tensor([-0.8812, 0.4406]) tensor([0.9518])\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "## 生成数据集与手动版\n", 33 | "# 使用线性模型参数 w=[2,−3.4]⊤ 、 b=4.2 和噪声项 ϵ 生成数据集及其标签:y=Xw+b+ϵ.\n", 34 | "# 可以将 ϵ 视为捕获特征和标签时的潜在观测误差。在这里我们认为标准假设成立,即 ϵ 服从均值为0的正态分布。\n", 35 | "# 为了简化问题,我们将标准差设为0.01。下面的代码生成合成数据集。\n", 36 | "\n", 37 | "def synthetic_data(w, b, num_examples): #@save\n", 38 | " \"\"\"生成 y = Xw + b + 噪声。\"\"\"\n", 39 | " X = torch.normal(0, 1, (num_examples, len(w))) \n", 40 | " y = torch.matmul(X, w) + b # 1000 X 1\n", 41 | " y += torch.normal(0, 0.01, y.shape)\n", 42 | " return X, y.reshape((-1, 1))\n", 43 | "\n", 44 | "true_w = torch.tensor([2, -3.4])\n", 45 | "true_b = 4.2\n", 46 | "features, labels = synthetic_data(true_w, true_b, 1000)\n", 47 | "print(features[1],labels[1])" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 6, 53 | "id": "f2298245", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# 读取数据集\n", 58 | "def load_array(data_arrays, batch_size, is_train=True): #@save\n", 59 | " \"\"\"构造一个PyTorch数据迭代器。\"\"\"\n", 60 | " dataset = data.TensorDataset(*data_arrays)\n", 61 | " return data.DataLoader(dataset, batch_size, shuffle=is_train)\n", 62 | "\n", 63 | "batch_size = 10\n", 64 | "data_iter = load_array((features, labels), batch_size)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 7, 70 | "id": "d58cb7a4", 71 | "metadata": { 72 | "scrolled": true 73 | }, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "[tensor([[-1.4016, -1.6029],\n", 79 | " [-1.0687, -1.6633],\n", 80 | " [ 0.6922, -0.0588],\n", 81 | " [-0.9198, -1.3350],\n", 82 | " [ 2.0441, 0.7863],\n", 83 | " [-0.2757, 0.2868],\n", 84 | " [ 0.3032, 1.5384],\n", 85 | " [-0.9962, -2.2553],\n", 86 | " [-0.0897, 0.0033],\n", 87 | " [ 2.2038, -1.5245]]),\n", 88 | " tensor([[ 6.8442],\n", 89 | " [ 7.7281],\n", 90 | " [ 5.7971],\n", 91 | " [ 6.9003],\n", 92 | " [ 5.6190],\n", 93 | " [ 2.6939],\n", 94 | " [-0.4086],\n", 95 | " [ 9.8762],\n", 96 | " [ 4.0031],\n", 97 | " [13.7947]])]" 98 | ] 99 | }, 100 | "execution_count": 7, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "# 查看迭代取数是否正常\n", 107 | 
"next(iter(data_iter))" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 8, 113 | "id": "bcff7e85", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# 定义模型\n", 118 | "# `nn` 是神经网络的缩写\n", 119 | "from torch import nn\n", 120 | "\n", 121 | "net = nn.Sequential(nn.Linear(2, 1))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 13, 127 | "id": "b43ed0e9", 128 | "metadata": { 129 | "scrolled": true 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "Parameter containing:\n", 136 | "tensor([0.], requires_grad=True)" 137 | ] 138 | }, 139 | "execution_count": 13, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "# 初始化模型参数\n", 146 | "net[0].weight.data.normal_(0, 0.01)\n", 147 | "net[0].bias.data.fill_(0)\n", 148 | "net[0].bias" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 14, 154 | "id": "b5d2b487", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "# 定义损失函数\n", 159 | "# 计算均方误差使用的是MSELoss类,也称为平方 L2 范数。默认情况下,它返回所有样本损失的平均值。\n", 160 | "loss = nn.MSELoss()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 15, 166 | "id": "ff18ca9e", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# 定义优化算法\n", 171 | "# 我们要指定优化的参数(可通过 net.parameters() 从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置 lr值。\n", 172 | "trainer = torch.optim.SGD(net.parameters(), lr=0.03)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 16, 178 | "id": "b38eebd9", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "epoch 1, loss 0.000189\n", 186 | "epoch 2, loss 0.000100\n", 187 | "epoch 3, loss 0.000100\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "# 开始训练\n", 193 | "# 通过调用 net(X) 生成预测并计算损失 l(正向传播)。\n", 194 | "# 通过进行反向传播来计算梯度。\n", 195 | "# 通过调用优化器来更新模型参数。\n", 196 | "num_epochs = 3\n", 197 | "for epoch in range(num_epochs):\n", 198 | " for X, y in data_iter:\n", 199 | " l = loss(net(X), y)\n", 200 | " trainer.zero_grad() # 梯度清零\n", 201 | " l.backward() # 反向传播\n", 202 | " trainer.step() # 更新阐述\n", 203 | " l = loss(net(features), labels)\n", 204 | " print(f'epoch {epoch + 1}, loss {l:f}')" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 17, 210 | "id": "cba7323a", 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "name": "stdout", 215 | "output_type": "stream", 216 | "text": [ 217 | "w的估计误差: tensor([ 0.0006, -0.0002])\n", 218 | "b的估计误差: tensor([0.0007])\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "# 和真实的参数比较误差\n", 224 | "w = net[0].weight.data\n", 225 | "print('w的估计误差:', true_w - w.reshape(true_w.shape))\n", 226 | "b = net[0].bias.data\n", 227 | "print('b的估计误差:', true_b - b)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "1d455f67", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [] 237 | } 238 | ], 239 | "metadata": { 240 | "kernelspec": { 241 | "display_name": "Python 3", 242 | "language": "python", 243 | "name": "python3" 244 | }, 245 | "language_info": { 246 | "codemirror_mode": { 247 | "name": "ipython", 248 | "version": 3 249 | }, 250 | "file_extension": ".py", 251 | "mimetype": "text/x-python", 252 | "name": "python", 253 | "nbconvert_exporter": "python", 254 | "pygments_lexer": "ipython3", 255 | "version": "3.8.5" 256 | } 257 | }, 258 | "nbformat": 4, 259 | "nbformat_minor": 5 260 | } 261 
| -------------------------------------------------------------------------------- /28-深度学习计算-参数管理.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d85fc0f9", 6 | "metadata": {}, 7 | "source": [ 8 | "1. 参数访问\n", 9 | " * 目标参数:每个参数都表示为参数(parameter)类的一个实例。要对参数执行任何操作,首先我们需要访问底层的数值。\n", 10 | " * 一次性访问所有参数\n", 11 | " * 从嵌套块收集参数\n", 12 | "2. 参数初始化\n", 13 | " * 内置初始化\n", 14 | " * 自定义初始化\n", 15 | "3. 参数绑定\n", 16 | "4. 小结" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "17141bf0", 23 | "metadata": { 24 | "scrolled": true 25 | }, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "text/plain": [ 30 | "tensor([[0.4296],\n", 31 | " [0.4028]], grad_fn=)" 32 | ] 33 | }, 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "output_type": "execute_result" 37 | } 38 | ], 39 | "source": [ 40 | "# 单隐藏层的多层感知机。\n", 41 | "import torch\n", 42 | "from torch import nn\n", 43 | "\n", 44 | "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n", 45 | "X = torch.rand(size=(2, 4))\n", 46 | "net(X)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "id": "f9609809", 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "OrderedDict([('weight', tensor([[ 0.1570, 0.1388, 0.0965, -0.0047, -0.3511, 0.3259, 0.1751, 0.2166]])), ('bias', tensor([0.3295]))])\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "# 参数访问\n", 65 | "print(net[2].state_dict())" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "id": "c4317390", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "\n", 79 | "Parameter containing:\n", 80 | "tensor([0.3295], requires_grad=True)\n", 81 | "tensor([0.3295])\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "# 下面的代码从第二个神经网络层提取偏置,提取后返回的是一个参数类实例,并进一步访问该参数的值。\n", 87 | "# 参数是复合的对象,包含值、梯度和额外信息\n", 88 | "print(type(net[2].bias))\n", 89 | "print(net[2].bias)\n", 90 | "print(net[2].bias.data)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "id": "8ee89127", 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "True" 103 | ] 104 | }, 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "output_type": "execute_result" 108 | } 109 | ], 110 | "source": [ 111 | "# 除了值之外,我们还可以访问每个参数的梯度,由于没有调用反向传播,所以参数的梯度处于初始状态。\n", 112 | "net[2].weight.grad == None" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 9, 118 | "id": "c06da1ee", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n", 126 | "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n" 127 | ] 128 | }, 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "tensor([0.3295])" 133 | ] 134 | }, 135 | "execution_count": 9, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "# 一次性访问所有参数\n", 142 | "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n", 143 | "print(*[(name, param.shape) for name, param in net.named_parameters()])\n", 144 | "net.state_dict()['2.bias'].data" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 10, 150 | "id": "4c47114e", 
151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "tensor([[0.3266],\n", 157 | " [0.3266]], grad_fn=)" 158 | ] 159 | }, 160 | "execution_count": 10, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "# 从嵌套块收集参数\n", 167 | "\n", 168 | "def block1():\n", 169 | " return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n", 170 | " nn.Linear(8, 4), nn.ReLU())\n", 171 | "\n", 172 | "def block2():\n", 173 | " net = nn.Sequential()\n", 174 | " for i in range(4):\n", 175 | " # 在这里嵌套\n", 176 | " net.add_module(f'block {i}', block1())\n", 177 | " return net\n", 178 | "\n", 179 | "rgnet = nn.Sequential(block2(), nn.Linear(4, 1))\n", 180 | "rgnet(X)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 11, 186 | "id": "e7aeec8a", 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "name": "stdout", 191 | "output_type": "stream", 192 | "text": [ 193 | "Sequential(\n", 194 | " (0): Sequential(\n", 195 | " (block 0): Sequential(\n", 196 | " (0): Linear(in_features=4, out_features=8, bias=True)\n", 197 | " (1): ReLU()\n", 198 | " (2): Linear(in_features=8, out_features=4, bias=True)\n", 199 | " (3): ReLU()\n", 200 | " )\n", 201 | " (block 1): Sequential(\n", 202 | " (0): Linear(in_features=4, out_features=8, bias=True)\n", 203 | " (1): ReLU()\n", 204 | " (2): Linear(in_features=8, out_features=4, bias=True)\n", 205 | " (3): ReLU()\n", 206 | " )\n", 207 | " (block 2): Sequential(\n", 208 | " (0): Linear(in_features=4, out_features=8, bias=True)\n", 209 | " (1): ReLU()\n", 210 | " (2): Linear(in_features=8, out_features=4, bias=True)\n", 211 | " (3): ReLU()\n", 212 | " )\n", 213 | " (block 3): Sequential(\n", 214 | " (0): Linear(in_features=4, out_features=8, bias=True)\n", 215 | " (1): ReLU()\n", 216 | " (2): Linear(in_features=8, out_features=4, bias=True)\n", 217 | " (3): ReLU()\n", 218 | " )\n", 219 | " )\n", 220 | " (1): Linear(in_features=4, out_features=1, bias=True)\n", 221 | ")\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "print(rgnet)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 12, 232 | "id": "6d7fad2e", 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "tensor([-0.1088, 0.0520, -0.2078, -0.0831, -0.1747, -0.2374, -0.2178, -0.4653])" 239 | ] 240 | }, 241 | "execution_count": 12, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "rgnet[0][1][0].bias.data" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 13, 253 | "id": "74071331", 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "data": { 258 | "text/plain": [ 259 | "(tensor([ 0.0147, 0.0044, -0.0047, -0.0027]), tensor(0.))" 260 | ] 261 | }, 262 | "execution_count": 13, 263 | "metadata": {}, 264 | "output_type": "execute_result" 265 | } 266 | ], 267 | "source": [ 268 | "# 参数初始化\n", 269 | "# 高斯分布\n", 270 | "def init_normal(m):\n", 271 | " if type(m) == nn.Linear:\n", 272 | " nn.init.normal_(m.weight, mean=0, std=0.01)\n", 273 | " nn.init.zeros_(m.bias)\n", 274 | "net.apply(init_normal)\n", 275 | "net[0].weight.data[0], net[0].bias.data[0]" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 14, 281 | "id": "6b716fd7", 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "(tensor([1., 1., 1., 1.]), tensor(0.))" 288 | ] 289 | }, 290 | "execution_count": 14, 291 | "metadata": {}, 292 | "output_type": 
"execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "# 常数\n", 297 | "def init_constant(m):\n", 298 | " if type(m) == nn.Linear:\n", 299 | " nn.init.constant_(m.weight, 1)\n", 300 | " nn.init.zeros_(m.bias)\n", 301 | "net.apply(init_constant)\n", 302 | "net[0].weight.data[0], net[0].bias.data[0]" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 15, 308 | "id": "51f4b1ff", 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "tensor([-0.5934, 0.0626, -0.3637, 0.2423])\n", 316 | "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n" 317 | ] 318 | } 319 | ], 320 | "source": [ 321 | "# 不同块采用不同的初始化\n", 322 | "def xavier(m):\n", 323 | " if type(m) == nn.Linear:\n", 324 | " nn.init.xavier_uniform_(m.weight)\n", 325 | "def init_42(m):\n", 326 | " if type(m) == nn.Linear:\n", 327 | " nn.init.constant_(m.weight, 42)\n", 328 | "\n", 329 | "net[0].apply(xavier)\n", 330 | "net[2].apply(init_42)\n", 331 | "print(net[0].weight.data[0])\n", 332 | "print(net[2].weight.data)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 16, 338 | "id": "1e776439", 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "Init weight torch.Size([8, 4])\n", 346 | "Init weight torch.Size([1, 8])\n" 347 | ] 348 | }, 349 | { 350 | "data": { 351 | "text/plain": [ 352 | "Parameter containing:\n", 353 | "tensor([[-5.4169, 6.9345, -5.3274, 6.7716],\n", 354 | " [ 8.4636, 0.2145, 6.9587, 1.5237],\n", 355 | " [-5.5604, 8.0154, -4.5076, -9.5305],\n", 356 | " [ 9.0585, -6.1241, 8.2844, -8.7465],\n", 357 | " [ 6.6681, -3.4718, -7.3977, 9.0384],\n", 358 | " [ 1.4992, 2.2234, 6.6491, 4.0744],\n", 359 | " [-3.2752, 8.6998, -5.5834, 3.9857],\n", 360 | " [ 4.8742, 8.1861, -3.1011, 4.0511]], requires_grad=True)" 361 | ] 362 | }, 363 | "execution_count": 16, 364 | "metadata": {}, 365 | "output_type": "execute_result" 366 | } 367 | ], 368 | "source": [ 369 | "# 自定义参数初始化\n", 370 | "def my_init(m):\n", 371 | " if type(m) == nn.Linear:\n", 372 | " print(\"Init\", *[(name, param.shape)\n", 373 | " for name, param in m.named_parameters()][0])\n", 374 | " nn.init.uniform_(m.weight, -10, 10)\n", 375 | " m.weight.data *= m.weight.data.abs() >= 5\n", 376 | "\n", 377 | "net.apply(my_init)\n", 378 | "net[0].weight" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 17, 384 | "id": "ea07392c", 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "text/plain": [ 390 | "tensor([42.0000, 7.9345, -4.3274, 7.7716])" 391 | ] 392 | }, 393 | "execution_count": 17, 394 | "metadata": {}, 395 | "output_type": "execute_result" 396 | } 397 | ], 398 | "source": [ 399 | "# 任意赋值\n", 400 | "net[0].weight.data[:] += 1\n", 401 | "net[0].weight.data[0, 0] = 42\n", 402 | "net[0].weight.data[0]" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 18, 408 | "id": "e2668016", 409 | "metadata": {}, 410 | "outputs": [ 411 | { 412 | "name": "stdout", 413 | "output_type": "stream", 414 | "text": [ 415 | "tensor([True, True, True, True, True, True, True, True])\n", 416 | "tensor([True, True, True, True, True, True, True, True])\n" 417 | ] 418 | } 419 | ], 420 | "source": [ 421 | "# 我们需要给共享层一个名称,以便可以引用它的参数。\n", 422 | "shared = nn.Linear(8, 8)\n", 423 | "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n", 424 | " shared, nn.ReLU(),\n", 425 | " shared, nn.ReLU(),\n", 426 | " nn.Linear(8, 1))\n", 427 | "net(X)\n", 428 | 
"# 检查参数是否相同\n", 429 | "print(net[2].weight.data[0] == net[4].weight.data[0])\n", 430 | "net[2].weight.data[0, 0] = 100\n", 431 | "# 确保它们实际上是同一个对象,而不只是有相同的值。\n", 432 | "print(net[2].weight.data[0] == net[4].weight.data[0])" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 19, 438 | "id": "4a50cf4c", 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "tensor([ 1.0000e+02, 8.9864e-03, 2.7858e-01, 1.6640e-01, -7.8594e-02,\n", 445 | " 4.4907e-02, -4.4838e-02, 2.4383e-01])" 446 | ] 447 | }, 448 | "execution_count": 19, 449 | "metadata": {}, 450 | "output_type": "execute_result" 451 | } 452 | ], 453 | "source": [ 454 | "net[2].weight.data[0]" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "id": "d762103b", 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [] 464 | } 465 | ], 466 | "metadata": { 467 | "kernelspec": { 468 | "display_name": "Python 3", 469 | "language": "python", 470 | "name": "python3" 471 | }, 472 | "language_info": { 473 | "codemirror_mode": { 474 | "name": "ipython", 475 | "version": 3 476 | }, 477 | "file_extension": ".py", 478 | "mimetype": "text/x-python", 479 | "name": "python", 480 | "nbconvert_exporter": "python", 481 | "pygments_lexer": "ipython3", 482 | "version": "3.6.5" 483 | } 484 | }, 485 | "nbformat": 4, 486 | "nbformat_minor": 5 487 | } 488 | -------------------------------------------------------------------------------- /29-深度学习计算-如何自定义层&文件读写.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "12eeb556", 6 | "metadata": {}, 7 | "source": [ 8 | "1.自定义层:\n", 9 | " * 带参数的层\n", 10 | " * 不带参数的层\n", 11 | "\n", 12 | "***\n", 13 | "2.文件读写\n", 14 | " * 加载和保存张量\n", 15 | " * 加载和保存模型参数" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "id": "25cb322f", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# 不带参数的层\n", 26 | "import torch\n", 27 | "import torch.nn.functional as F\n", 28 | "from torch import nn\n", 29 | "\n", 30 | "\n", 31 | "class CenteredLayer(nn.Module):\n", 32 | " def __init__(self):\n", 33 | " super().__init__()\n", 34 | "\n", 35 | " def forward(self, X):\n", 36 | " return X - X.mean()" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "242c20f4", 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "tensor([-2., -1., 0., 1., 2.])" 49 | ] 50 | }, 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "output_type": "execute_result" 54 | } 55 | ], 56 | "source": [ 57 | "# 验证一下\n", 58 | "layer = CenteredLayer()\n", 59 | "layer(torch.FloatTensor([1, 2, 3, 4, 5]))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": "73d25719", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# 现在,我们可以将层作为组件合并到构建更复杂的模型中。\n", 70 | "net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 6, 76 | "id": "f4ed1c96", 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "torch.Size([4, 128])" 83 | ] 84 | }, 85 | "execution_count": 6, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "Y = net(torch.rand(4, 8))\n", 92 | "Y.mean()\n", 93 | "Y.shape" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 8, 99 | "id": "c72434e4", 100 | "metadata": 
{}, 101 | "outputs": [], 102 | "source": [ 103 | "# 带参数的层\n", 104 | "\n", 105 | "class MyLinear(nn.Module):\n", 106 | " def __init__(self, in_units, units):\n", 107 | " super().__init__()\n", 108 | " self.weight = nn.Parameter(torch.randn(in_units, units))\n", 109 | " self.bias = nn.Parameter(torch.randn(units,))\n", 110 | " def forward(self, X):\n", 111 | " linear = torch.matmul(X, self.weight.data) + self.bias.data\n", 112 | " return F.relu(linear)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 9, 118 | "id": "82b4a3b9", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "Parameter containing:\n", 125 | "tensor([[-0.6468, -0.3709, 0.8210],\n", 126 | " [ 0.0276, -0.4635, -0.7936],\n", 127 | " [-0.2383, -0.1799, -0.1329],\n", 128 | " [ 0.0879, 0.2967, 3.2110],\n", 129 | " [-2.8279, 1.3314, -0.0485]], requires_grad=True)" 130 | ] 131 | }, 132 | "execution_count": 9, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "# 实例化\n", 139 | "linear = MyLinear(5, 3)\n", 140 | "linear.weight" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 10, 146 | "id": "6b42b380", 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/plain": [ 152 | "torch.Size([2, 1])" 153 | ] 154 | }, 155 | "execution_count": 10, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "# 我们还可以使用自定义层构建模型\n", 162 | "net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))\n", 163 | "net(torch.rand(2, 64))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 12, 169 | "id": "3a1c2472", 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "data": { 174 | "text/plain": [ 175 | "tensor([0, 1, 2, 3])" 176 | ] 177 | }, 178 | "execution_count": 12, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "# 文件读写部分:\n", 185 | "# 对于单个张量,我们可以直接调用load和save函数分别读写\n", 186 | "import torch\n", 187 | "from torch import nn\n", 188 | "from torch.nn import functional as F\n", 189 | "\n", 190 | "x = torch.arange(4)\n", 191 | "torch.save(x, 'x-file')" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 13, 197 | "id": "045c9895", 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "tensor([0, 1, 2, 3])" 204 | ] 205 | }, 206 | "execution_count": 13, 207 | "metadata": {}, 208 | "output_type": "execute_result" 209 | } 210 | ], 211 | "source": [ 212 | "x2 = torch.load('x-file')\n", 213 | "x2" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 14, 219 | "id": "591d9c69", 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "text/plain": [ 225 | "(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))" 226 | ] 227 | }, 228 | "execution_count": 14, 229 | "metadata": {}, 230 | "output_type": "execute_result" 231 | } 232 | ], 233 | "source": [ 234 | "# 我们可以存储一个张量列表,然后把它们读回内存。\n", 235 | "y = torch.zeros(4)\n", 236 | "torch.save([x, y],'x-files')\n", 237 | "x2, y2 = torch.load('x-files')\n", 238 | "(x2, y2)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 15, 244 | "id": "29f2998c", 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "data": { 249 | "text/plain": [ 250 | "{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}" 251 | ] 252 | }, 253 | "execution_count": 15, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | 
], 258 | "source": [ 259 | "# 写入或读取从字符串映射到张量的字典\n", 260 | "mydict = {'x': x, 'y': y}\n", 261 | "torch.save(mydict, 'mydict')\n", 262 | "mydict2 = torch.load('mydict')\n", 263 | "mydict2" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 17, 269 | "id": "69b8d45c", 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "OrderedDict([('hidden.weight',\n", 276 | " tensor([[ 0.1414, 0.1394, 0.1123, ..., 0.1445, 0.1443, 0.2183],\n", 277 | " [ 0.0397, 0.2163, -0.1119, ..., 0.0144, 0.0573, -0.1374],\n", 278 | " [ 0.1329, -0.0665, -0.2027, ..., 0.0693, 0.1316, -0.1639],\n", 279 | " ...,\n", 280 | " [ 0.0669, -0.1595, 0.1582, ..., 0.0316, -0.1544, 0.1534],\n", 281 | " [-0.1199, 0.1117, 0.1042, ..., -0.0875, -0.0185, -0.0969],\n", 282 | " [-0.1175, -0.0018, 0.0047, ..., -0.2157, 0.1471, 0.0033]])),\n", 283 | " ('hidden.bias',\n", 284 | " tensor([ 0.0686, 0.1704, 0.0100, 0.1763, 0.1053, -0.0825, 0.0187, 0.0017,\n", 285 | " -0.2133, 0.1802, -0.1089, 0.1716, 0.0216, -0.1982, -0.0934, -0.0977,\n", 286 | " 0.1516, 0.0691, -0.2066, -0.0339, -0.0680, 0.2016, -0.0111, 0.1381,\n", 287 | " 0.0830, 0.0088, 0.1984, -0.0877, -0.0884, -0.2093, -0.1989, -0.2088,\n", 288 | " -0.0362, 0.1941, 0.0824, 0.0912, 0.1913, 0.1653, 0.1852, -0.0561,\n", 289 | " 0.0992, -0.0857, 0.0842, -0.1262, -0.1677, -0.0236, 0.1475, 0.1886,\n", 290 | " -0.0716, 0.0504, -0.1921, -0.0882, -0.0303, -0.1296, 0.2102, 0.1779,\n", 291 | " -0.2156, -0.1268, -0.2170, 0.0702, -0.0656, 0.0396, 0.0533, 0.2062,\n", 292 | " 0.1820, -0.0791, -0.1638, -0.2180, -0.1191, -0.1587, -0.0725, -0.0901,\n", 293 | " 0.0290, 0.1881, -0.1311, 0.0324, 0.1025, -0.2228, 0.0470, 0.1864,\n", 294 | " -0.0686, -0.2147, -0.0773, 0.1345, 0.0316, -0.0673, 0.0061, -0.0039,\n", 295 | " 0.1939, 0.1900, -0.0081, -0.0110, -0.0998, -0.2198, -0.1159, 0.1391,\n", 296 | " 0.1631, -0.0938, 0.0127, 0.1769, -0.2053, 0.0540, -0.1181, -0.0141,\n", 297 | " 0.1192, -0.0749, 0.1035, 0.0070, -0.0835, -0.2140, 0.2180, -0.1413,\n", 298 | " -0.1306, 0.1772, -0.1210, -0.0738, 0.1260, 0.0345, -0.0911, 0.0177,\n", 299 | " -0.1896, -0.1422, -0.0106, -0.0936, -0.0953, -0.2027, 0.1347, 0.2020,\n", 300 | " -0.1160, 0.2137, -0.1686, 0.0163, 0.0870, 0.1017, 0.0910, -0.1271,\n", 301 | " -0.0722, 0.1092, 0.0824, -0.0955, 0.1210, 0.1356, 0.0086, -0.2097,\n", 302 | " 0.2085, 0.0815, 0.1185, -0.0035, -0.1677, 0.0928, -0.0599, -0.2013,\n", 303 | " -0.0907, -0.1654, -0.1917, 0.0461, 0.1752, 0.1340, 0.0025, 0.1013,\n", 304 | " -0.0381, -0.0580, 0.0540, 0.1223, 0.1068, -0.1808, 0.1712, -0.2162,\n", 305 | " -0.1422, 0.2177, 0.1342, 0.0055, -0.0034, 0.1852, 0.0241, -0.1438,\n", 306 | " -0.1633, -0.1194, -0.2102, 0.1450, -0.2133, 0.1080, -0.2008, 0.2186,\n", 307 | " 0.0241, 0.0665, -0.0456, 0.0371, 0.1790, 0.0127, 0.1510, 0.1503,\n", 308 | " -0.1534, -0.0615, 0.1624, 0.0918, -0.1880, 0.0707, -0.0007, -0.0897,\n", 309 | " -0.0829, 0.0962, 0.1218, 0.1907, 0.0055, 0.1026, -0.1315, -0.1823,\n", 310 | " 0.0135, -0.1919, -0.0811, 0.0832, 0.1782, 0.0487, 0.0697, -0.1536,\n", 311 | " 0.1893, -0.0524, -0.1449, 0.0261, -0.1176, -0.0968, 0.1735, 0.1918,\n", 312 | " 0.1784, 0.2165, 0.0873, -0.1330, -0.1380, -0.1949, -0.0187, -0.0207,\n", 313 | " 0.1511, -0.0124, 0.0396, 0.0375, -0.0960, -0.0563, -0.0737, -0.2202,\n", 314 | " -0.0640, -0.2038, -0.0145, 0.1390, 0.1226, 0.1472, 0.0604, 0.0227,\n", 315 | " 0.2071, -0.2038, 0.0259, 0.0368, 0.0904, 0.0943, 0.0976, -0.2188])),\n", 316 | " ('output.weight',\n", 317 | " tensor([[-0.0408, 0.0305, 0.0432, ..., 0.0460, 
-0.0030, 0.0523],\n", 318 | " [ 0.0409, -0.0272, -0.0037, ..., 0.0370, 0.0170, 0.0018],\n", 319 | " [ 0.0430, -0.0135, 0.0283, ..., -0.0423, 0.0270, -0.0506],\n", 320 | " ...,\n", 321 | " [-0.0311, -0.0108, -0.0436, ..., 0.0370, -0.0097, 0.0209],\n", 322 | " [-0.0129, -0.0448, 0.0184, ..., 0.0207, 0.0233, -0.0390],\n", 323 | " [-0.0362, 0.0457, -0.0259, ..., 0.0594, 0.0383, 0.0596]])),\n", 324 | " ('output.bias',\n", 325 | " tensor([-0.0553, 0.0492, 0.0273, -0.0185, 0.0195, 0.0541, 0.0337, -0.0548,\n", 326 | " -0.0348, 0.0356]))])" 327 | ] 328 | }, 329 | "execution_count": 17, 330 | "metadata": {}, 331 | "output_type": "execute_result" 332 | } 333 | ], 334 | "source": [ 335 | "# 加载和保存模型参数\n", 336 | "class MLP(nn.Module):\n", 337 | " def __init__(self):\n", 338 | " super().__init__()\n", 339 | " self.hidden = nn.Linear(20, 256)\n", 340 | " self.output = nn.Linear(256, 10)\n", 341 | "\n", 342 | " def forward(self, x):\n", 343 | " return self.output(F.relu(self.hidden(x)))\n", 344 | "\n", 345 | "net = MLP()\n", 346 | "X = torch.randn(size=(2, 20))\n", 347 | "Y = net(X)\n", 348 | "net.state_dict()" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 18, 354 | "id": "f364c035", 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "data": { 359 | "text/plain": [ 360 | "MLP(\n", 361 | " (hidden): Linear(in_features=20, out_features=256, bias=True)\n", 362 | " (output): Linear(in_features=256, out_features=10, bias=True)\n", 363 | ")" 364 | ] 365 | }, 366 | "execution_count": 18, 367 | "metadata": {}, 368 | "output_type": "execute_result" 369 | } 370 | ], 371 | "source": [ 372 | "# 我们将模型的参数存储为一个叫做“mlp.params”的文件。\n", 373 | "torch.save(net.state_dict(), 'mlp.params')\n", 374 | "clone = MLP()\n", 375 | "clone.load_state_dict(torch.load('mlp.params'))\n", 376 | "clone.eval()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 19, 382 | "id": "c947a5e9", 383 | "metadata": {}, 384 | "outputs": [ 385 | { 386 | "data": { 387 | "text/plain": [ 388 | "tensor([[True, True, True, True, True, True, True, True, True, True],\n", 389 | " [True, True, True, True, True, True, True, True, True, True]])" 390 | ] 391 | }, 392 | "execution_count": 19, 393 | "metadata": {}, 394 | "output_type": "execute_result" 395 | } 396 | ], 397 | "source": [ 398 | "Y_clone = clone(X)\n", 399 | "Y_clone == Y" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "id": "6ae2ac4e", 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [] 409 | } 410 | ], 411 | "metadata": { 412 | "kernelspec": { 413 | "display_name": "Python 3", 414 | "language": "python", 415 | "name": "python3" 416 | }, 417 | "language_info": { 418 | "codemirror_mode": { 419 | "name": "ipython", 420 | "version": 3 421 | }, 422 | "file_extension": ".py", 423 | "mimetype": "text/x-python", 424 | "name": "python", 425 | "nbconvert_exporter": "python", 426 | "pygments_lexer": "ipython3", 427 | "version": "3.6.5" 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 5 432 | } 433 | -------------------------------------------------------------------------------- /31-循环神经网络RNN-文本预处理-自然语言处理-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5e73ae99", 6 | "metadata": {}, 7 | "source": [ 8 | "# 文本预处理\n", 9 | "\n", 10 | "* 将文本作为字符串加载到内存中。\n", 11 | "\n", 12 | "* 将字符串拆分为词元(如单词和字符)。\n", 13 | "\n", 14 | "* 建立一个词汇表,将拆分的词元映射到数字索引。\n", 15 | "\n", 16 | "* 
将文本转换为数字索引序列,方便模型操作。\n", 17 | "\n", 18 | "1. 读取数据集\n", 19 | "2. 词元化\n", 20 | "3. 词汇表:语料、、pad、bos、eos\n", 21 | "4. 整合所有功能\n", 22 | "5. 小结\n", 23 | "6. 练习\n", 24 | "\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "id": "7385f991", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import collections\n", 35 | "import re" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "id": "c7388789", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "# text lines: 3221\n", 49 | "the time machine by h g wells\n", 50 | "twinkled and his usually pale face was flushed and animated the\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "# 读取数据集\n", 56 | "def read_time_machine(): \n", 57 | " \"\"\"Load the time machine dataset into a list of text lines.\"\"\"\n", 58 | " with open(\"article.txt\", 'r') as f:\n", 59 | " lines = f.readlines()\n", 60 | " return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]\n", 61 | "\n", 62 | "lines = read_time_machine()\n", 63 | "print(f'# text lines: {len(lines)}')\n", 64 | "print(lines[0])\n", 65 | "print(lines[10])" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "id": "c0a62512", 72 | "metadata": { 73 | "scrolled": false 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "['t', 'h', 'e', ' ', 't', 'i', 'm', 'e', ' ', 'm', 'a', 'c', 'h', 'i', 'n', 'e', ' ', 'b', 'y', ' ', 'h', ' ', 'g', ' ', 'w', 'e', 'l', 'l', 's']\n", 81 | "[]\n", 82 | "[]\n", 83 | "[]\n", 84 | "[]\n", 85 | "['i']\n", 86 | "[]\n", 87 | "[]\n", 88 | "['t', 'h', 'e', ' ', 't', 'i', 'm', 'e', ' ', 't', 'r', 'a', 'v', 'e', 'l', 'l', 'e', 'r', ' ', 'f', 'o', 'r', ' ', 's', 'o', ' ', 'i', 't', ' ', 'w', 'i', 'l', 'l', ' ', 'b', 'e', ' ', 'c', 'o', 'n', 'v', 'e', 'n', 'i', 'e', 'n', 't', ' ', 't', 'o', ' ', 's', 'p', 'e', 'a', 'k', ' ', 'o', 'f', ' ', 'h', 'i', 'm']\n", 89 | "['w', 'a', 's', ' ', 'e', 'x', 'p', 'o', 'u', 'n', 'd', 'i', 'n', 'g', ' ', 'a', ' ', 'r', 'e', 'c', 'o', 'n', 'd', 'i', 't', 'e', ' ', 'm', 'a', 't', 't', 'e', 'r', ' ', 't', 'o', ' ', 'u', 's', ' ', 'h', 'i', 's', ' ', 'g', 'r', 'e', 'y', ' ', 'e', 'y', 'e', 's', ' ', 's', 'h', 'o', 'n', 'e', ' ', 'a', 'n', 'd']\n", 90 | "['t', 'w', 'i', 'n', 'k', 'l', 'e', 'd', ' ', 'a', 'n', 'd', ' ', 'h', 'i', 's', ' ', 'u', 's', 'u', 'a', 'l', 'l', 'y', ' ', 'p', 'a', 'l', 'e', ' ', 'f', 'a', 'c', 'e', ' ', 'w', 'a', 's', ' ', 'f', 'l', 'u', 's', 'h', 'e', 'd', ' ', 'a', 'n', 'd', ' ', 'a', 'n', 'i', 'm', 'a', 't', 'e', 'd', ' ', 't', 'h', 'e']\n" 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "# 词元化\n", 96 | "def tokenize(lines, token='word'):\n", 97 | " \"\"\"将文本行拆分为单词或字符词元。\"\"\"\n", 98 | " if token == 'word':\n", 99 | " return [line.split() for line in lines]\n", 100 | " elif token == 'char':\n", 101 | " return [list(line) for line in lines]\n", 102 | " else:\n", 103 | " print('错误:未知词元类型:' + token)\n", 104 | "\n", 105 | "tokens = tokenize(lines,token='char')\n", 106 | "for i in range(11):\n", 107 | " print(tokens[i])\n" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 7, 113 | "id": "af07b892", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# 构建词汇表\n", 118 | "class Vocab: \n", 119 | " \"\"\"文本词汇表\"\"\"\n", 120 | " def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):\n", 121 | " if tokens is None:\n", 122 | " tokens = []\n", 123 | " if 
reserved_tokens is None:\n", 124 | " reserved_tokens = []\n", 125 | " # 按出现频率排序\n", 126 | " counter = count_corpus(tokens)\n", 127 | " self.token_freqs = sorted(counter.items(), key=lambda x: x[1],\n", 128 | " reverse=True)\n", 129 | " # 未知词元的索引为0\n", 130 | " self.unk, uniq_tokens = 0, [''] + reserved_tokens\n", 131 | " uniq_tokens += [token for token, freq in self.token_freqs\n", 132 | " if freq >= min_freq and token not in uniq_tokens]\n", 133 | " self.idx_to_token, self.token_to_idx = [], dict()\n", 134 | " for token in uniq_tokens:\n", 135 | " self.idx_to_token.append(token)\n", 136 | " self.token_to_idx[token] = len(self.idx_to_token) - 1\n", 137 | "\n", 138 | " def __len__(self):\n", 139 | " return len(self.idx_to_token)\n", 140 | "\n", 141 | " def __getitem__(self, tokens):\n", 142 | " if not isinstance(tokens, (list, tuple)):\n", 143 | " return self.token_to_idx.get(tokens, self.unk)\n", 144 | " return [self.__getitem__(token) for token in tokens]\n", 145 | "\n", 146 | " def to_tokens(self, indices):\n", 147 | " if not isinstance(indices, (list, tuple)):\n", 148 | " return self.idx_to_token[indices]\n", 149 | " return [self.idx_to_token[index] for index in indices]\n", 150 | "\n", 151 | "def count_corpus(tokens): \n", 152 | " \"\"\"统计词元的频率。\"\"\"\n", 153 | " # 这里的 `tokens` 是 1D 列表或 2D 列表\n", 154 | " if len(tokens) == 0 or isinstance(tokens[0], list):\n", 155 | " # 将词元列表展平成使用词元填充的一个列表\n", 156 | " tokens = [token for line in tokens for token in line]\n", 157 | " return collections.Counter(tokens)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 8, 163 | "id": "8e2f1c2a", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "[('', 0), (' ', 1), ('e', 2), ('t', 3), ('a', 4), ('i', 5), ('n', 6), ('o', 7), ('s', 8), ('h', 9)]\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "vocab = Vocab(tokens)\n", 176 | "print(list(vocab.token_to_idx.items())[:10])" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 9, 182 | "id": "c7cb4bd7", 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "words: ['t', 'h', 'e', ' ', 't', 'i', 'm', 'e', ' ', 'm', 'a', 'c', 'h', 'i', 'n', 'e', ' ', 'b', 'y', ' ', 'h', ' ', 'g', ' ', 'w', 'e', 'l', 'l', 's']\n", 190 | "indices: [3, 9, 2, 1, 3, 5, 13, 2, 1, 13, 4, 15, 9, 5, 6, 2, 1, 21, 19, 1, 9, 1, 18, 1, 17, 2, 12, 12, 8]\n", 191 | "words: ['t', 'w', 'i', 'n', 'k', 'l', 'e', 'd', ' ', 'a', 'n', 'd', ' ', 'h', 'i', 's', ' ', 'u', 's', 'u', 'a', 'l', 'l', 'y', ' ', 'p', 'a', 'l', 'e', ' ', 'f', 'a', 'c', 'e', ' ', 'w', 'a', 's', ' ', 'f', 'l', 'u', 's', 'h', 'e', 'd', ' ', 'a', 'n', 'd', ' ', 'a', 'n', 'i', 'm', 'a', 't', 'e', 'd', ' ', 't', 'h', 'e']\n", 192 | "indices: [3, 17, 5, 6, 23, 12, 2, 11, 1, 4, 6, 11, 1, 9, 5, 8, 1, 14, 8, 14, 4, 12, 12, 19, 1, 20, 4, 12, 2, 1, 16, 4, 15, 2, 1, 17, 4, 8, 1, 16, 12, 14, 8, 9, 2, 11, 1, 4, 6, 11, 1, 4, 6, 5, 13, 4, 3, 2, 11, 1, 3, 9, 2]\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "for i in [0, 10]:\n", 198 | " print('words:', tokens[i])\n", 199 | " print('indices:', vocab[tokens[i]])\n", 200 | " " 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 10, 206 | "id": "6fcd23d5", 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "data": { 211 | "text/plain": [ 212 | "(170580, 28)" 213 | ] 214 | }, 215 | "execution_count": 10, 216 | "metadata": {}, 217 | "output_type": "execute_result" 218 
| } 219 | ], 220 | "source": [ 221 | "# 整合所有功能\n", 222 | "def load_corpus_time_machine(max_tokens=-1):\n", 223 | " \"\"\"返回时光机器数据集的词元索引列表和词汇表。\"\"\"\n", 224 | " lines = read_time_machine()\n", 225 | " tokens = tokenize(lines, 'char')\n", 226 | " vocab = Vocab(tokens)\n", 227 | " # 因为时光机器数据集中的每个文本行不一定是一个句子或一个段落,\n", 228 | " # 所以将所有文本行展平到一个列表中\n", 229 | " corpus = [vocab[token] for line in tokens for token in line]\n", 230 | " if max_tokens > 0:\n", 231 | " corpus = corpus[:max_tokens]\n", 232 | " return corpus, vocab\n", 233 | "\n", 234 | "corpus, vocab = load_corpus_time_machine()\n", 235 | "len(corpus), len(vocab)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "8386ae85", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.6.5" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 5 268 | } 269 | -------------------------------------------------------------------------------- /62-计算性能-命令编程和符号编程-深度学习-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "fa5f87f7", 6 | "metadata": {}, 7 | "source": [ 8 | "1. 符号式编程\n", 9 | "2. 混合式编程:torchscript\n", 10 | "3. Sequential的混合式编程\n", 11 | " * 通过混合式编程加速\n", 12 | " * 序列化\n", 13 | "4. 小结" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 3, 19 | "id": "61128b5a", 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "10\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "# 命令式编程\n", 32 | "def add(a, b):\n", 33 | " return a + b\n", 34 | "\n", 35 | "def fancy_func(a, b, c, d):\n", 36 | " e = add(a, b)\n", 37 | " f = add(c, d)\n", 38 | " g = add(e, f)\n", 39 | " return g\n", 40 | "\n", 41 | "print(fancy_func(1, 2, 3, 4))" 42 | ] 43 | }, 44 | { 45 | "attachments": { 46 | "image.png": { 47 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWkAAACtCAYAAABydY5YAAAgAElEQVR4Ae2dbcweVZnHgVZCX59HVop9fdpiX6HP0/1Aodv2eYioQKUvrASDbItml8QgFBTRxpftflhWjFbYIGiQAgkf2Bi7qxs+GKVFFNSEbPiiiRtJGlFJUJTEDyaSTWbzm3o9nA4z9z0vZ2bOzP3/cGfmnjlzzsx1rut3Xec65577rDfeeCPSRzKQDkgHpANh6sBZ6pgwO0b9on6RDkgH0AFBWiMJjaSkA9KBgHVAkA64cxRJKZKSDkgHBGlBWlGUdEA6ELAOCNIBd46iKEVR0gHpgCAtSCuKkg5IBwLWAUE64M5RFKUoSjogHRCkBWlFUdIB6UDAOiBIB9w5iqIURUkHpAOCtCCtKEo6IB0IWAcE6YA7R1GUoijpgHRAkBakFUVJB6QDAeuAIB1w5yiKUhQlHZAOCNKCtKIo6YB0IGAdEKQD7hxFUYqipAPSAUFakFYUJR2QDgSsA4J0wJ2jKEpRlHRAOiBIC9KKoqQD0oGAdUCQDrhzFEUpipIOSAcEaUFaUZR0QDoQsA4I0gF0zgsvvBA9/fTTsx++j2oE5cqB/V/+8peSxV91Y1Rl8bvf/W7WNkw/ODYqNiJItwDp48ePR4cOHYpWTqyOzjrrrMwP5ylH+T4qJIbGsx08eDAaGxvPlAMy2jI5Fcuirw4MWRw7dizas3dvtDiHLI4ePRr1VRY4I55vz569A3UCvbh2z764bJ8dmCDdIKQxQoPR2k1T0Y23HYnueeKZ+PPf/xtFfJ584fXZY3sO3hFRDmVcNbE6NuI+wBogff7zn5+VxZZtM9E/fea+2edOyuIzD/5XhCwuXD4xK4vvf//7vXBcyAJHbM76siv3zcri/m+/GOsE8njk5KlYPsji3dfdPCuL6emZqC+yALQ4bGSxcNFYhCzu+MJj8XOnyYJzyGLBorH4mr179/Zy5CVINwDpb33rWzFkUT5gg8EZiPJsKY8yGqy7HEEhCxwVRoiTwinlkYGVwVgxXmQBoLosCxwVUTOQAThFZYGDx8EhCwAF8LvqxM1R4YiRhfV33i3Oa8lfnTig77Iskn0oSNcMaYZtGBFgKQrnpIJy/eXvOQ2oRx55pHMGaVESjqookJKyAFCMMoAc4E8qdsjfAYjJooyjSpMFcMP5dc1pIYvJqanYaZeBc1IW1EEAQHqsLykQQbpGSJshMpRPKlOV7xg24Cf6CBlGdm8YIlEvxuPDEF3Z2QijK04LWQAQomd3CO8+U5l9nN4l22Zip9UVWeBQVq5aHTtbn7IgmFm78bQD75rTMptxt4J0TZAG0ECJYVgZoxt2DbAD1LfffnvwoAbQvqHkygcniCxCh5MBmhFA1VGV+/zuvjmt0PPURLlE/ms2TlUeVbnPb/s4LUavjLS6DmpBugZIAwugURegTREN1CHDyZyVz0jJnt/d2ugiZINk5QbOqmqqx33utH1AHTKcXGdVtyxwAqySok03Ou3SviDtGdJEMADad4ojzRg5Rn6X9kKEkzkr3ymOLFkAJ6KzEHORjHgYWdXtrJAN4GO4T1olRDgxyYks6hpNuPqBLMjXz8zMCNJd8kx13itL5ZjccxWl7n1m+HdNh6WEwAFgEuHW/fxWvxnkgQMHgjJInEaTjht5IAtAyAqSOvW9aN0WxDDxa/1W9xbHiPxDHnEOkqMiaY+RtEWOTUQIrmKj8ChhSHlI4AAkgIV7r3XvWwoopGgap0E0V/ezJ+vHQZL2CCmaJphggjN5r3V/Z5RFADUIhqGeE6Q9QholQBnqVri0+pkkCUUJgQJwaDKKdmXCxNyePXuCMEjSUDjQuucn3Oe3fRwkOfBQomkLYppI+ZgMbEvg1NVoWpD2BGkzxqajaFNCG9KFEEGaMTYdRZssLJoOIYIkF91GFG2yYM6CibMQokRGFNsbTgWaHNgSQF0biPMu0h+CtCdIt22MKCFR05e//OXWDZIolsjeNZAi+0y6VnF2OAeiphB+5MLkHaAs8vw+y5rzDmFimdGVrwl19KNoXSE5b0HaE3iLCLJtY8SwiRR27ppuHdJVjRHAVgUVqxvankC0CcM2Uh2u/EJw3jZhWMX5us8EcIvmtkNy3kXYokjaE9B95B1RoipKbGuFiyiA77IGprKz91zH2lYMskruElnsaNlhGZhcuJTZr6oXwKxth0Ve3IfzNftgdFJmzmPysplgcvR5bU+Q9gjpsmDCCLdsuyKGEwa1ZNlEKVjbKo+8nV9HOQMTz1QGSAxh127aGkdJwBq5lKkHA+Ynx3U8Y946LTdf5v65Bhle/p79sV5UkQVAa9thAenJCqs6iJyxC+zDPmXsjWtDmUjNq0eCtAdIG5jKKA3GCJjc/Br53DJRgkG6zQkzk0VZMJGycfPZwKmMXC3/mNcQ6igHDKpMGgJXH7JAl/5uZ7tpMOZsAGQZvSB6BtA2srK+LRMIIM8QUoJF9E2Q9gjpMkpjEdOd9z4e3Xj7v8SfBYvHSy3ZMkgDyiJK4LMsk3VVhrWuMSIbDNt1YHmN3GTh89mK1hVHj5eVAxPPOd/Tr/KA9OLFY63pBHKbnp4uDWn6313aSt+iJ3l1wS2HLARpD9Aragxtl7fld2UiPqIEoEzUxPU2G2+5N1fBhu2HAKYqkTROLgn4spE0hk1dbeoGkF64eLwUTOj/pCyG9X/W+RAiaWRRNpJOjiyJpN0RRtZzpx3nHrq2DE+RtCengkGVgTTXuMoLXMpGCawiaBtMBukqTsZGJMimrCwA08VbJluFdNVRBX1pskAvXD1JA1DWsRBW/QDpizZvLeWwCGD48HzIY8ny1aVGV1yPDLmXNp130bYFaY+QxsNnGUrWcWDGsHbvzXdGV/79h2MlKmuMIQxryYeXdVjcP0Dh+ZFHMvWRJcO04xh128Nac1gG2rT7HHSMZ2DitKoseLdL26s7WL+PXgx63qxz5qyRA7qBXpQJiKj/XZu3CtJFvURfygMEN2+WpXBpxzFilI5Uh+2nlRt2jCFgCD+HXrFyotTEJw6L52eLPMrCDTlduGJ16+/aNodVxnlbXyOHqrIAjqw0adPWLCVok3/2fHm3rk5UqQNZtDlnU6YPFEl7iqSJFMrmH/Mq6rByIRgjSkjUxjK6Yfdb13kMOhRjrOK8fcjHUmAhvC6grPP2IQfqIGW0qOUJVEHaE3DLCLLqjziqKqIZY5vL70xulosFllWfq8z1GGPbqxlMFm077zh91HJu3mTBMjzSDWX61Mc1k5dd0Xrax2RRZKtI2iPYd+2ajrZcVu7HF1WVEAVkmVORzq+z7MpVE6XTP1VkQYqEEQ1AqPP58taN8yZ6I99e5bnKXGsjirZTHSYry9G38TN5UkahjK5MHnm3grRHSJsSohBljKrsNba4P4QhrSme/dqu6WgaGALFEEYUJgtbilclx15GN4iicZZ2HyFsSYUxX1DmeapcQ/DU9kRyWfkL0h4hTScQzTKka8ogaeedK1YHOYwDENvfs78xg8QhLFo8HtzsPQ6D9IstI6sCm7zXzq63b3
nCMAkmSwuSksr7LFXLWRDTtQlDk50g7RnSKCEGyXK6qsqV53qWaNFeSJGjKZfN6DcBJ5zVRZu2Rlu2TAYpCxtZVFnpkUcfKIOzIuUTwkof0wV3y8iC1EMTI06cFY47lPSXK4e8+4K0Z0gjeINT3XlIHAFDe9rL2+FNl2sKTryIKFRnZTI3OJVdQpYH0jir+AVVgTorkwVpDxxJ3bKgDeaKrN0ubgXpGiCNIhicWICfx7iKlgHQoSy5G6b4RDHcax1RJFAC0KE7K5ORwamOyTMDNM4qpPkJe3Z3y8hvy+RkbaAG/qQBLwncWbkyydoXpGuCNAI3UAMRDKgoiNPKUw8pDqAUwj+PZClW8rhFkT7TQAzrSXF0BdAmE0CN07rls/d70Qn0BCgRNQKlkEdWJgO2gJr3aJCO8OnAcYDIItTUlyuDPPuCdI2QpgMwGCIbvDpvuksDb95jN9/1hWjOnLnR0mXLOmOIrhLitObNnx+dN29BpXwkjoo3BmLcGGLoUaMrA9s3B87SySq5WWRx6RXXxtDfNT0dZD7enjlrayOt5WvWl3qPutkPTttGmDjCrPa6dlyQrhnSKAQQseiJl8wA67yRNeUoD+SJvs4///zo/e9/fycV8C9/+Uu0Zs2aaMmSJfGzAKjPPvTt3I4LIyT6RBZEz0TnIU6Y5oUAqw1YFka/ApeissBRETEuWLAwriOE/7fM++xuuZMnT8b3//bz/2ZWFkVy1ZQlrYgcWVGEA3Tr7/q+IN0ApE1JgDUz7igTH9Zu2juk/+2JH0R8MFQ7xgSQlQXyXP/UU0/Fxx5//PHOKeLdd98dLViwIDp16lT8/gQiP3s+UkL23ElZYICkNSgLnJFFl+Fs+mBb0lbkZ3k+oGuywCGZLOx948jCHDYjNHNUhw8fjs4777zopZde6pxebN++Pbrqqqvi+wawgNZkgfNCL9JkwTmTBdd01UmZHmRtBekGIW2dAGAwTGCzY+d0DB6DlYGIvzviPOWSQLrllluiFStWRK+//npnDPK5556LDe/+++8/4555NgwT58W/h7hycGXBkLhLOXjr6yJbnDCgQRa8ZjUpC959QeSNLNLyzhs3bow++MEPniHfIu23UfaLX/xi/JzJ5+E7smBlRposOIYsKJO8to3nqLNNQboFSFft0FdffTW64IILoo997GOdMcgrrrgi4lP12XX9G5kyPH78eAy8J598MrNMSPL7xS9+Ec2dOzf63Oc+14n7bUt2gnQHIY2yPPzww7FBfu973wtewe+77/S/pDz//PPB32tbhuirXUZfa9eujf785z8HL+vrr78+uuSSS4K/T199U7YeQbqjkKbDd+/eHW3bti1oJSf/TB6afHRZJdV12dFzUjYvv/xytHjx4ugTn/hE0PJ+4okn4iDjO9/5TtD3mZRvG98F6Q5D+sUXX4wV/Z577glW0T/0oQ9F69evD/b+2jC6utt84IEHYr149tlng5T7n/70p2jVqlXRRz7ykSDvr+7+KVq/IN1hSNPZR44cic4+++zo5z//eXAK/81vfjOGBduiiqny+aPnNFldeeWV0a5du4KU+6FDh+KlpK+88kqQ95cmzzaPCdIdhzTKs3Xr1mj//v1BKTxrotetWxfddNNNQd1Xm8bWZNs//elPYwcZ2rK0EydOxPf10EMPSS9yskeQzimoJg2saFshrp1210QXfR6VrxZFm/xCXDvtrom2+9R2cH8L0j2ANErO2unly5cHsXY6a020jHGwMdYhn02bNgWzdvree++No+i+r2v23Y+CdE8gbWunb7311taHkVoT3TyMs8AQytpprYkurxOCdE8gjZGGsHZaa6LLG2MWaKseP3jwYOtrpz/wgQ9oTXRJ1gjSJQVX1XDqup6105deemkr0bTWRIcHaPSs7bXTWhNdTS8E6Z5Bus2101oTXc0Y63Lc1NvW2mmtia6uE4J0zyCNQbJ2mpfz/OxnP4sj6jonangpEG1qTXR1Y6wT0tSdXDttfee7Xeq1l4JpTXR1vRCkewhpjI6107x3GiMB2L4Nkfp4H/L8+fPj12VqTXR1Y6yjj9w6be00PxnfsWNHdPnll9eiF7w+denSpdFXvvIVrYn2wBdB2oMQXUMIZZ83iwHnc845J97W8Xf2/FCCt5jRBp+uvH0tlD5q+j6IbvkVInrBr1Trct44AOrmwx881BWxNy2/ttoTpHsI6WnnZfoYCiCt45dnvHFtzpw5Zxg9UVRbyqx2s6N5QEl0a/C0bR2pMF6oZfXjvPne93eB16l7gnQPIU3UjGG87W1vi40FkAJU34rEi5PMGHEE/HBCUVM2KH3Lv2h9OGrrL9v6/qsp+t/qti26Zznqoves8m9EgnQPIY1iYxQMOw3UGzZs8App6jcjZCtDDBfOLuiInBcuXBide+65cYqKf3lxz1fdJ2JGH4ig582b17v/G6wqnzLXC9I9hbQpg/0TM4Zjx9wtv0gjRbF3795oZmbmjA8/guBc2pCYaJ06mTj0HY2596d9//DHwe7cuTPuv7TJQ84fO3Ys7vukTvCdyeijR4+mjprQF/SC1Eqa3qg/i/enIN1zSGMURDekP2zykC1Qxpj4TE5ORuSxATpGxod9jvHeX8qMj49HQNvSGZQhOpchFje6UEBlQLX7AcxTU1OzekH/83+LphNsGTFxfGxsLC43MTERnwfs1GPX2HerW9vyeiJIjwCkMRDg+qUvfSmOlIEuxkQEbNAdZESUIZ9pwAbW/IGoDLG84Q2Sd5PncOC8NhTYohdAOO8kH84eZw6wceJAXKMq/zohSI8IpBmeGpwtoi4DA4wQWGOUeY25TDu6xr+xJ2WKk7URFXDO47CTdfCdegA0sCYSL1tPWt06ponD1Dxt3xSDyBdA+1yGh1FTJ/Dvm7xG4XkAKUAFrFWctisr6iR1hgNXGsyfk1Uk3fNI2gyxjqjXlnThBFxj1b4/A61DlgAUkAJU31EvUTV5bBx4HTpXhzxCr1OQ7jGkgSeRUp1RDekPDFK5yLDBbCACojhuAM2+Hfe9ZaSliNqPTgjSPYW0RblNRDPkIwG1r2Gzb2CovjdhwRI6HHedgDZ54wiYkGyiLWuzj1tBuoeQZggLNH3moIcpP5GTDPJNGA6TVxvnzXHXObJynws4M8mMbrjHtV9MTwTpHkKaaIkopkljwCCJ0Iiqm2xXbeUzePqH9EPT/cNITqOsfH2UpcuCdM8gTcqhLaMAAIBAw9tqRpllrFWO0zdNpTmS98mafAKH5HF9z6cngnTPIM3MOkbRhgFYNN1kmqWN5+xim21E0SYnCxyaSrNYu33ZCtI9gzTG2CYkyT+25ST6YpS+n8Mg6Xu5XZH7JIpvUy+L3GtoZQXpIZAGOGmrFkzx3Q6lbJWhPmkKtz7bx7jygM/uqU1jtCV5VeRgzx3q1uScdn+kFdy+oi8Y3aSVLXPMdAT55q2Xn24zgVemPfcaq8d9Pvf8oH0573ypjTQZCtKeIE2UUHVSxgwwraMwkGFrkWm/6QnD5L0Cj7Zy4sl7qet7EUgDNJ/DfFdH6O9hOoEMuAf0p
6o8aLus87UJxKr3MIrXjyykUW57Paf702aUEOVnooO3gqHgGCXKgbHxAxGuA8quwRCpuAqMUrr1u+eyFI367DpeB+leQ0Q2DMAYYpkox+7HfXZXJnY+75bn6OLQlufnuel7+o6+sGdG/tb36Ifb99ZnnHf7IK3P0Dvqp2weeHNP1i7tuO1Sf54IuWo+2p69Sj3m2FydNtlqOzjKHklIAxCAh+LwYd+gwhCSoRnH2WIU7KNclu+1a8xgMDYXoBhiVv2DFJL6rG2MnV+GueUxSAzGPebuA2iud4/l3ef5aA858DzUUzb64tkBWd62Qynn9r0BEXkP6nt0gXwr5fmwb44SGbhyYN/0Ah3Js9qCPqdPaIf7M50zmVHfMNhzDfdm1xTd8vzoA+1zH0Wvpzxy5D7KXl+mzb5cM5KQRqkNdiggkQoGxD6KxNY6GCNBsTAqMz7OmRGzz7UosF1D3cn688Azq22rl/YHKTnnXSjYdXm2gMC9R2SQJ0pLq7vKfaTV19QxZGt9T//htKzvgaHdB/pDX/GdfjcHz3c3kkYO6I1dZ7pk39327Ji75bzbLvdk7Vq5ZBt23N1yDXW5x4ruoxvucxa9nvI+7qNMu12/ZmQhzXCWyBglxxCAG4qcZgQc53wSgFY2eQ5DT9ZPO8OUxeqzclzjGhffk/dgZdkOO++WTe5zLXJg636S5fJ8r3IfeeqvqwyOFzAT4QJfgyoydx0Y7Vtf8axuH1EHxyiTPGfX5L1/ty67JllHHllzjXuPVleRrcmiyDXJsj7uI1nnKHwfSUgnFZvoB0O0SMWiKRQAcKHgyUjajaa41gyTazBwN1WQPJ+lWCix23bSMJJGn6yH8+59JM8P+p6sG0AMG0Zn1Wf/1pF1PsTjyB35u89s8h/U98lI2u1rZOpG0sDfrZ9zNuJKkwl656a87B7dssk23HO2z3O592HH827T2s17rZUze6nqLKy+UdqOPKRRGndCBKMj/YHxYHDm/VFUjIyJJc4x+cM5lAUFdIel1MHEH+esfoMn37MMhvpsQom23Tqpi/YHGbULiKJKzLVE/9TP/QGoQW0Nqr8qFAbVXdc5AxH9wz79nNX3yMn6nvIGc9u3vkamSWeN3iBXnKDJeJBOUIZ7QcfQDWvX5MB5F/x23N1ShntxjxXZ5/7smYpc55aljuS9u+e1nz15OJKQRqkBKYqH8mIwZkwYKENbO8dxMwL3uuRQFIByLcrG1uqnLhTUhssAMEvhOW71Ut7qo04MG2MbpMzkDHE4g8oMOocsuAfuvSyguc7gNqitEM+Z7JEB/YQ8zaEO6nu7DrnR16ZLXJN0tK6MTa9og2vTZOLqo7Vj5fLoBGV5HtM/u7bI1mRR5JpkWWSJjSSP63s2nE02Iwlpe3ifW4yPz7A6MUwz4mFl3fNcY8Bwj7v71A0gzfjdc03tyxjPNDoAOaw/gG0ZiKJvw3SCfq/qvH3oDk4oyxH5qL/PdQjSQ37MkrfziXgwSDf6TbsWgxxWJnkd5fMqONF2GSeQbLPs96pRW9l2Q70OQA/rO/q3jE4g6zzP3fbohmcjeMjjUPI8z6iVEaQ9QRrFKQPgPApXpF4AzcRdnnp9lzEYyBjPjKaHRdJl+gHwIe+81+K8y0TreesfVA59ANJFHdGgOkfpnCDtEdIhKA6GS+4vT+rF9/0CgWF5c99tqr4zHUKWPAyURcCeVVeR44CZeZI2R3dF7jfEsoJ0zyCNkgFoDKPJyMVm75ncClHRdU+nJxBZXdKkLNBFd1K9ybb70pYg3UNIA2cMg6ViTSjqqVOnovXr12euWmniHtTG8IjaHGlToyzSPKQ5mmqvrzogSPcQ0iirGUgTw8zVq1fHxvjAAw804hT6aox1P9cf//jHeEkg4Pzud79ba1+RVmE0N2zStO5n7kP9gnRPIY1yWh6yzok8fmBB1H7jjTfGoP7Upz5Vq/H3wejaeIaf/OQn0ebNm+M5g/e+970xQOuY0OTZGMnxS0nWiDeZcmtDrk20KUj3GNIokIEamPpUKIyP/CaAtjz0gw8+GIN637590e9//3uv7fm891Gr67HHHovmzJkTAedf/epXcb8Q4RLp+nbggJ96AXTTk5R97VdBuueQRnGBKDC1nyRXVWZym0RKrORIRmNPP/10RPpj48aN0XPPPSdQt6xfhw8fjh3nbbfd9pa+IBVG6iP57vKy+sHP1wF0nt8LlG1jFK8TpFs2oqaUDpgCVYySqLrMMJTICNBTxyBD/PWvfx1dddVV0dlnnx3/cUJTz6h23pw8/MMf/hBdd911cV8NmisgksaBA1cgW0aGOG3W5qMXTcyBlLnHLl8jSI8IpE1JMUqDNas/MMxBw1LgThkiZ4yQYSxGafUN2hKhcY3y1G/Cc5C8fJ378Y9/PJt/ztNXOGxWYNBXwBonfvz48YGOnPP0r8GZNfKD9MjXs41iPYL0iEEaJccogTV5SaIojNM+RMoWLdsxoI4RWu65iKE89NBDcd04BOWp64e15Z/f9773RS+//HIuZ2r9CWR5zwejJOt7toAYnTBHbedw2ETOyZSX1aetn/4WpEcQ0knjIdriQzRlH4yVYz4M8MSJE9GaNWuiDRs2RD/60Y8KgSN5r/qebfiWf/aRcsCR0/84ZtMJtjh3jitqzu4H3zoqSAvSjUDzN7/5TXT11VfHERqG7luRR7k+N//81a9+VbLtmU0L0j3r0NBhdccdd8SgvvvuuwUTD7pn+WdSEkS4ofe/7q94BC5IezAUKV4xxXPz1GVWmUjep+X96KOPRuecc05UJv8sGRbT2TblJUgL0q1EX8pTV4PEpz/96XhE4iP/3CaA1PZwPRCkBelWII1xKk893ECTEHvttdei/fv3x4BW/rm4/JLy7MJ3QVqQbg3SZiDKU+eDzfPPPx9t2rQpXhKn/HM+mZmOdXkrSAvSrUMaA/ra174WR4es3Vae+q0AcvPP/KKzy9DRvb+1fwfJRJAWpIMx+JMnT0Zr166N3039wx/+MJj7GmRATZyz/DO/8GuiPbVRDKJ1y0uQFqSDMvzf/va30TXXXBNH1d/4xjeCure6jTFZv5t/5g2DyfP6HhZM6+oPQVqQDtL477zzzhjUn/zkJ4O8v7oM0uq1/DNvFOTNgnZc29EAs9vPgrQgHSwA3Dz1q6++Gux9ugblY//YsWPx+mfeJKj88+hBOalDgrQgHTT8Ri1PzRsDeYGR8s+Cs8FakBakg4Y0ijoKeWreEMg/2gBo5Z8FaAM0W0FakA4e0qawfc1T8w82/JON8s+Cs+m6uxWkBenOQBrF/frXvx5Hm9dee23Uhzw1+Wf+wUb5ZwHaBbO7L0gL0p2CNMr7zDPPRBdddFG0bt266Nlnn+3c/ZsBKv8sMJsuDNoK0oJ0JyH3yiuvRLt3746j6ocffrhTz6D8s+A8CMrJc4K0IN0pwCUV+OMf/3gM6rvuumv2Oar8tJyfpLMumQ+pCNuv8g81/MuN/fWYm3/mTYDJ59F3ATypA4K0IN15ULh56htuuCGGNn/1lFT2rO/2p6rLV53+x2v7D7+07e49
w/+8120H6C9cuDBasGBBdOTIkTj/zD/U8AZAt5z2BecsHRCkBelewII89ZIlS2JAG1yH/Q8f0e3Kv4J5ybsmo8nrD0X77j8R3fDI/0S3/uD/Zj//+NRr8fF3Hz4WrdmxN5q38PSf9/Kv2sPa4H3Pc+bMiX+cwn199KMf7YW8s4Ci4/6djSAtSPcCGqQjDM5s586dG/EP5WnQoOzO6Zm4/IarD0YH/uOlWSC7cB60D7DHl05Ei8bG4z9qzWrHvSf+RSXrntKu1zH/wOuiTAVpQToVZF1T5unp6Ri6wNkFY/K9y0TPgJXImah5EKMIXGoAAAbsSURBVIjznLv0w/8ct7dn7963vGJ1+/btsxE00TT3NX/+/Nn8dNdkrPttx2kI0oJ0LyANQEg9MEnHD0MM1MuWLZt9Pv6lnONEz3kAnLcM6ZHzFo5FF2+ZmgW1tWVgPnDggOAsW5vVxSIOT5CW4pRSnCJK1kZZA/bmzZvjdIRBc8dtR70C2kBOyoToHFDT9qpVqyKBuZ3Isw19q7NNQVqQ7iWkXaPhJU1EtEwMGlTr2AJqIuqrr7mm9zJ15av9ep2RIC1I9xooLIFbPDYer8qoA8zJOkl94BCKLAEU5OqFXNflK0gL0r2GNCmHty+diFhGlwRqXd9JqQDqYcvzug4P3X8zzkWQFqR7C2kgCSxZLlcXkLPqHXvnRJyTFsiaAVmf5SxIC9K9hTRR9Pg7JxoHNODGMSiaFqB9OA9BWpDuLaRZD13Xao6sCNo9ziQiSwJ9GKrqGF3gC9KCdG6IZE2GkVZgiZsLEn4OzaSde6zIPr8KpD3qLVMXP2Ihki3za0IXtFX2WY998eRUaRkUkZfK9hfigrQgnRsiQC8NBgCRX/zZOX7VB1jte5kt64ypA1hTf9H6KH/huslWUh0G9mv+9XjsKKo4qzKy0zX9ArYgLUjHMAWEvJYzaeBEyXY8CWkAahB1IT05OfmWlQ1WNll/1nfacuGWVmfWtRzfsWu68i8LWU5XJRLnWp4D2Q66V53rF1R996cgPeKQJuolauUdzIBwfHx8FrDkU+2cvRvDFJDvlLfrDNLAmONWDtBOTU3FkTZl+LjwtXLuljLAjS31cY7UR1a6xb3W9i/bMR3xXg2LaotuSVW846LJeH01++cuGCu1jE+QFoBNJ8tuBekRhzRpATfSA4zkgQGpG81Shu8oGucpZ0pHHfYdkLqpCVZYuHDlnPvd6khurS07TvvWhh0btGXSsCykuY5XkhrY+aXiogvLrRLhOTR5KFAP0tVh5wTpEYc0CkI0DTh5PzKRNPtpUDRwJkHrlk2eI6qmXuq0NvLA1toyBaYN7s2+D9tyfVlIA2T3DXlJaBu882zPW5T9KtNhz6Dzgjs6IEiPOKSJdIEm0R6pBfZ9Qpp0CXUDWftYCmMQhNIgTV2DrnHPbbh4sjSkadsFMOmOskv5qCu58sW9T+0LxMN0QJAecUgDEYMmKY6JiYkY0pbusJ82A1gDJ9CZmZmZBSZQt+iY/UHpDoBNmaGKmVhJQvvWxrBrOX/5zvI5aSJpVmYAansXhxtZuwAfto/MuPc896wyAnaaDgjSIw5pImnAfOjQoXjCzyJplAUYj42Nxf8mwnGDNOe4jglB/mmElIYBFOC7E4dAnu9WlmjYnAKwzgK22xbtDSqbpti7pqfPyCsPg6l7HiAzUbh860xcB/fins+7b4AXpAXfNB3Ne0yQHnFIoyhAMwskQDbrHNcZcF2FA8QWgdvxtDaI1rMgbdfZNq1OO5e2JWKft2i8MFx5ERNwNRATUbPKw74X2ZIiWbh4TFG0bKySDgjSUqBKCpQGSCYiibTTzrnHiNTTIO+WYR8n4aZQkufTvuMkiICLpikANFH0Nff8Z/x5x7umZlMfRQBN2eVbp6Ob/mG4HNLuX8cUfZsOCNKC9FCYmrIU2QJVIuUi12SVLVvXipWrSr3oH7CzosP+PbwonClPRI6T0KShYJul13mPC9KCtBeQ5lW4JsuRSpm/aLzSrwbLAJprgPyixWPeHFWTclNbYTkWQVqQ7i2kgc3ylasq/zy8KKjt5+CKosOCXVedjyAtSPca0oCStIM7GVgUukXLr925LyLV0lUo6L7Dci6CtCDde5jwfhF++dfEX2iR5sAp5JkQFQzDgmGo/SFIC9K9hzQTmJsumYwuWDdVK6jt31iU5hB8fQJfkBakew9pDIYleUzk1QVqA3TRpYI+jVl19dM5CNKC9EhA2kBNRE3qw2eOeuPVN8cpjrw/zBFM+wnTuvpVkBakRwbSGBGpD3LU5I2Ba5U8Neupx5eujn9VqBSHwCtIC6YjBdO6FN7qBarLVq6Kfzo+df0duSNroH7l4UejlX87E4OeX1YmfwJvbWgrcPvQAUXSgv/Iwt/eHcJaaiLrty9dHUfXrNDgs//fT8ZAtu8GZsoSjWe908SHYaoOAd50QJAWpEcW0mYEbFkyx0uZdu6ajrbtmI5TGMCYD++m3r5zOn4fCe8l8fVzd7d97QvKWTogSAvSgrR0QDoQsA4I0gF3TpZn1XFFXdKB0dEBQVqQVhQlHZAOBKwDgnTAnaNoaXSiJfW1+jpLBwRpQVpRlHRAOhCwDgjSAXdOlmfVcUVd0oHR0QFBWpBWFCUdkA4ErAOCdMCdo2hpdKIl9bX6OksHBGlBWlGUdEA6ELAO/D/0zlIyYzL/lAAAAABJRU5ErkJggg==" 48 | } 49 | }, 50 | "cell_type": "markdown", 51 | "id": "f2059548", 52 | "metadata": {}, 53 | "source": [ 54 | "计算流程:\n", 55 | "![image.png](attachment:image.png)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "id": "45dd8d2d", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "\n", 69 | "def add(a, b):\n", 70 | " return a + b\n", 71 | "\n", 72 | "def fancy_func(a, b, c, d):\n", 73 | " e = add(a, b)\n", 74 | " f = add(c, d)\n", 75 | " g = add(e, f)\n", 76 | " return g\n", 77 | "print(fancy_func(1, 2, 3, 4))\n", 78 | "10\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "# 符号式编程\n", 84 | "def add_():\n", 85 | " return '''\n", 86 | "def add(a, b):\n", 87 | " return a + b\n", 88 | "'''\n", 89 | "\n", 90 | "def fancy_func_():\n", 
91 | " return '''\n", 92 | "def fancy_func(a, b, c, d):\n", 93 | " e = add(a, b)\n", 94 | " f = add(c, d)\n", 95 | " g = add(e, f)\n", 96 | " return g\n", 97 | "'''\n", 98 | "\n", 99 | "def evoke_():\n", 100 | " return add_() + fancy_func_() + 'print(fancy_func(1, 2, 3, 4))'\n", 101 | "\n", 102 | "prog = evoke_()\n", 103 | "print(prog)\n", 104 | "y = compile(prog, '', 'exec')\n", 105 | "exec(y)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 14, 111 | "id": "d28c32e4", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "tensor([[ 0.0461, -0.2136]], grad_fn=)" 118 | ] 119 | }, 120 | "execution_count": 14, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "# 多层感知机\n", 127 | "import torch\n", 128 | "from torch import nn\n", 129 | "import dltools\n", 130 | "\n", 131 | "\n", 132 | "# 生产网络的工厂模式\n", 133 | "def get_net():\n", 134 | " net = nn.Sequential(nn.Linear(512, 256),\n", 135 | " nn.ReLU(),\n", 136 | " nn.Linear(256, 128),\n", 137 | " nn.ReLU(),\n", 138 | " nn.Linear(128, 2))\n", 139 | " return net\n", 140 | "\n", 141 | "x = torch.randn(size=(1, 512))\n", 142 | "net = get_net()\n", 143 | "net(x)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 15, 149 | "id": "b8d06a6a", 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "tensor([[ 0.0461, -0.2136]], grad_fn=)" 156 | ] 157 | }, 158 | "execution_count": 15, 159 | "metadata": {}, 160 | "output_type": "execute_result" 161 | } 162 | ], 163 | "source": [ 164 | "net = torch.jit.script(net)\n", 165 | "net(x)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 16, 171 | "id": "c6a26b49", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# 通过混合式编程加速\n", 176 | "# 时间测试方法\n", 177 | "class Benchmark:\n", 178 | " \"\"\"用于测量运行时间\"\"\"\n", 179 | " def __init__(self, description='Done'):\n", 180 | " self.description = description\n", 181 | "\n", 182 | " def __enter__(self):\n", 183 | " self.timer = dltools.Timer()\n", 184 | " return self\n", 185 | "\n", 186 | " def __exit__(self, *args):\n", 187 | " print(f'{self.description}: {self.timer.stop():.4f} sec')" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 17, 193 | "id": "73d3ef65", 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "无torchscript: 0.2520 sec\n", 201 | "有torchscript: 0.1780 sec\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "net = get_net()\n", 207 | "with Benchmark('无torchscript'):\n", 208 | " for i in range(1000): net(x)\n", 209 | "\n", 210 | "net = torch.jit.script(net)\n", 211 | "with Benchmark('有torchscript'):\n", 212 | " for i in range(1000): net(x)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 18, 218 | "id": "135270af", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "# 序列化模型\n", 223 | "net.save('my_mlp')" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "6dbe0141", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [] 233 | } 234 | ], 235 | "metadata": { 236 | "kernelspec": { 237 | "display_name": "Python 3", 238 | "language": "python", 239 | "name": "python3" 240 | }, 241 | "language_info": { 242 | "codemirror_mode": { 243 | "name": "ipython", 244 | "version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 
| "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.6.5" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 5 256 | } 257 | -------------------------------------------------------------------------------- /63-计算性能-异步与并行-深度学习-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c2ffe63a-8e1d-457d-aa3b-09a3e76aa67d", 6 | "metadata": {}, 7 | "source": [ 8 | "1. 异步计算\n", 9 | "2. 自动并行" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 4, 15 | "id": "c135000a-3a0a-4129-8901-f7d780de63fa", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "import subprocess\n", 21 | "import numpy\n", 22 | "import torch\n", 23 | "from torch import nn\n", 24 | "import dltools" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 5, 30 | "id": "b1634ccd-3dbc-4111-acab-60249ffc8fe2", 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "numpy: 2.8189 sec\n", 38 | "torch: 0.0508 sec\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "# GPU计算\n", 44 | "device = dltools.try_gpu()\n", 45 | "a = torch.randn(size=(1000, 1000), device=device)\n", 46 | "b = torch.mm(a, a)\n", 47 | "\n", 48 | "with dltools.Benchmark('numpy'):\n", 49 | " for _ in range(10):\n", 50 | " a = numpy.random.normal(size=(1000, 1000))\n", 51 | " b = numpy.dot(a, a)\n", 52 | "\n", 53 | "with dltools.Benchmark('torch'):\n", 54 | " for _ in range(10):\n", 55 | " a = torch.randn(size=(1000, 1000), device=device)\n", 56 | " b = torch.mm(a, a)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 9, 62 | "id": "c3650baf-62f6-435d-a72b-6618577da206", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Done: 0.1034 sec\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "with dltools.Benchmark():\n", 75 | " for _ in range(10):\n", 76 | " a = torch.randn(size=(1000, 1000), device=device)\n", 77 | " b = torch.mm(a, a)\n", 78 | " torch.cuda.synchronize(device)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 8, 84 | "id": "ee361df3-46c6-4c4a-932c-e6761c0fbf4b", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "tensor([[3., 3.]], device='cuda:0')" 91 | ] 92 | }, 93 | "execution_count": 8, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "x = torch.ones((1, 2), device=device)\n", 100 | "y = torch.ones((1, 2), device=device)\n", 101 | "z = x * y + 2\n", 102 | "z" 103 | ] 104 | }, 105 | { 106 | "attachments": { 107 | "12d8339f-8994-4e71-af46-ed03573b7a2c.png": { 108 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAO4AAADRCAYAAADLwo5aAAAbAUlEQVR4Ae2dbagd1XrHIyaam5yXXWtzk5vknPiSNsZkHxXqrSbZ53IVvFZPrFJEocYLtQglyVXwg1yxLXhb+8FY/dCCaNpCqC1oqEg/KKbtLUIQpBa/pVgpCCJItfhNQmHKb5LnsLLce595WTN71tr/D8PaM7PWmplnPb/nedbLzF53/vz5TJtkIB2ISwfWqcHiajC1l9oLHRC4ijgUcUWoAwI3wkaT15XXFbgCVx43Qh0QuBE2mjyuPK7AFbjyuBHqgMCNsNHkceVxBa7AlceNUAcEboSNJo8rjytwBa48boQ6IHAjbDR5XHlcgStw5XEj1AGBG2GjyePK4wpcgSuPG6EOCNwIG00eVx5X4ApcedwIdUDgRtho8rjyuAJX4MrjRqgDAjfCRpPHlccVuAJXHjdCHRC4ETaaPK48rsAVuPK4EeqAwG2w0d58883sk08+uQQM9r/88stLjo3zoOQ/c+bMJWXee++97MMPPyxcx7j6Yz332muvfef5kUuZ5yE/snXLDKvXPd+V3wK3IXABbmVl5RKlALZer5cVVbBjx45lCwsLeT3z8/OZKRXg+3V3RaHauA8M4rPPPnuJbJHNunXrLjk26l6Q39LSUtbv93M5ImMzhNTj1z2qnkkeF7gNgfvII49kKJg1LgqBggBgEXBRJPJbefYpa/tAbSDbsWlJAc6iFlJAQzZFwSU/7WPyYt81hMjdj5Qsb1fSqQMXhcfr0dDHjx9fbTwDYXFxMT/vngM0LDRlDh8+vNqowGTHKWcgoUwuZDQ2ikH+wWBQCFzqIL8pCvuuYnIOBbbzk065H2TDPSILM1ocR7Ymd865hgtoOAYsyJzn5FmQpbXT8vLyqiw47kJH/dTB9Vz5jJMHZVwwqZN2sTLUx2b7XUynClwawywrCsJvayAazhSCRkUJSNlQKhqbBnQbmfKuglrdbh6/0YuC65fj3uz+7Bz3aIpuxyaVYkReeOGFXEbIDMMFoGzcJzLh3sgDqPw2Odkz8HzWHu6zUcaOk8eu4z6rXcc9VuQ313YNDWWoC4NcpPyk8kwVuD40NJBZWv8cish5lASLzyCGbaZUeBEa/cSJE6tg05AomSma37D+dfzzw/ZRVjc8tDxV6rKyIVOUH5m4dSIbZOAD5e5z/3hZkyty5Bj18LzI/eTJk5d4x1HP7Nbr3se439w3gPoGcdjzjKtnEuemClw8Jw1sguY3x9j3FcL2UT5+G4yW4lUoh9eg4fEwZqUtj13HTa1e99io3yjQkSNH8uvz289Xpi6/bMj9YdCYDPxz7j73j+wsr6XcG8+L0SSKwShwjuOjntmtt8izEUHRXlavX8Y3RP75Se9PFbgogYVsCJ7fFt76CmH75DEgKYNCWf8XxbIQmnMWHqIMvhW3hrZ6bX9cSt5R9VDON0Tj6mr6HIpuxoxrce/IxwfK3Uf2LjjIEnkjY/e423+lXuueuM/k1useH/ab+6T/7OqCm4+6/DEK93wXfk8VuDQIyo5SAJ+r+D5Q7j5hGwMvhHKEbwYT4SBQc5z6LJxFAfk9rIHdejmP8lCPn5fjwMD13M3yodxdUi5kah6MKMFk4QPl7ltfGNkhQ2AyKF2ZI3uTEddxoTZ5uPXaMcq4htWO035cy5Wr1U8eZG9tbGW6lk4VuAgfZcET0Piuh6CBgcEayN+nMSlDanlIUTSOU6dbHqPg1m9l/HrJY17f8pByHGX0N8vTReXi2UxGJgtSnsHue9i+tYefz467MkcuyNbqs9Svl+OUc+u0vNynL1eO2Xk/MrPjXUqnDty2hI/SoMRrXQ/wzcusldc9j0caZhjcPKn+xhsOA9J/XvKVlRH5R0VLfv2T3Be4DS3AoFGx3GUVp4gymJcvkjfFPHjXYVFKiGcFdtf7hqiziToEboPgNtFgqlNvBqEDAlfgrhnOy1h0z1gIXIErcCPUAYEbYaPJA3bPA7bdJgJX4MrjRqgDAjfCRmvbuut63fPwAlfgyuNGqAMCN8JGkwfsngdsu00ErsCVx41QBwRuhI3WtnXX9brn4QWuwJXHjVAHBG6EjSYP2D0P2HabCNyA4PJCgX2GpYmXC9pWji5dj4X/JlteMujSvU3iXgRuRXBRHr6HxEvj8/O9/KV3Xnz3N86Rh7xSuGKeEqPHi/UrKxe+GunL1PYXFnflHzA4ffr01IEscEuCC3x8scGUZ9+ty9ljP/+L7E9P/Wu+vfYv/52x2T7nyGP5gVgADwcYrzpYviCrzbPz2Y/vfzT72fN/syrLv//wf7OX3vqP1f2VIz/Lrr1hKZctBhLYJ+H9JnFNgVsCXF6Mn5vvZSgVCoUivf2fWaGNvJT5/vbFvI4iL9lPQiEmcU0MGQYN47b/1uXs53/5j4VkarLHUD589I/z8njhKh8mmMRz17mmwC0ALorFd49m5uZzBSkDrCmXpZRFyWZm57PBYHnqvS9edufirmzr9sXcsJmcqqQAjJfGAKRuGAXuGuDS39rfX8q9LGFaFYUaVoa68Nz7+ktRfHGhjncYVRbPSARDV6KOMfTlS2QDvCl3SwTuGHDxtHgD+lEhFcsUjTqv3cNfm/SmzvPyzSjgwkOaPEKmGEaiGiKlUYYj5uMCdwy4/aWlxqA1JTV48erTMmhFeIynbQpaky3wYhzcT6/GDKt77wJ3BLg0NhY7ZHhsCuWn5h0I7dzGSfE3xgkjRaTRRBTjy9bC5tQGrATuEHAtjCs7uukrTZl9po/wDkU+Oxoz0AwaYRDbgNbkz7RRat0RgTsE3EOD5Xxawhq+rZRBGq4dM5jj7h1vS4jMqHpbMuU6GAmMRUojzQLXA9e8LR6wTeXiWuZ1UwvrDGa6H0z7tOltrQ0xFhiNVMYRBK4HLh/EZhTZGrztlAUI3IMpe0op4SorydqWKdfDWNAVcf/OJGbZClwPXKzypJQLBePa3EPMSjXs3hlJBpw2BvtGGYYf3nFf9nuJGEWB64DbBeVi9U+Kg1T0L1nuOQqqNo6nZBQFrgMu/w4HNG0o0bhrcA8pDaTggQ8NBhkeb9xzN33OxhBSeOVS4DrgAkud/i0LCph6QAEJCTdVnAemn5sauAcODiqPJhOFIEsX7C0/WCwddls/N4UpN4HrgcuUjKsgZX6bgpFunuvlo8Rlylte7uHokD+7HtZ3jOXYjf2lyuAiF6IQG41mhLjqqqtUuiEC1wGXv26sAy4KhkIBLSt2DMSy6YX53EFSA1QAU2f+FpkQ6gIv3hbjWFau5Oc+6BLFYvBG3afAdcAlPA0BrusdqigX95Da+tob9/drgYtBxBjSFaljAORxHYUfZRViOw641+29qZIlB1AUCuhQsjrKlWQf91D1Pq7JFrnS17WQuaxRxEsL3ATBZXKehi2rEOTHGxDCoVTW162qYNxDaoNTjCrXiWZsRLhOF8Tq0KhyYvDWmcdFKdx+F6PKVcClHOByL7FFLOPul37lzFyvklHEMPLCxzV76q1oy+dx5+aTkKv6uJ7x2bFzsVaYW8Vbu2VQrt
lElMsFuapRxPj92alfZlu276o8Sm/yvfaGm5JZTipwPXBZJ1ynn2tKUjW9LiHlcsHl986FxdV57qLyIQIhxK4TInMtDACRTCovcAhcD9wuvB2UwgIBH1r26bcTLlfpQhQFfVQ+Bgvn5ub1dtCwhknlGPO5eL5RStDU8f4Pf5QdOpTW/K2rE/n7uBe/lNmUDIfVa6PJKczfmjzlcT2Pi2AYdSSsor85TBGaOMa1UhyUMkWzFHhm53qllyvWkflv3fk7eZhu95BCKnCHgEvDsgACkOhj1VGaImW5Bsqc2qKLUYDs7/fzcYQ2QmYziKn0bU2mAncEuAiIkJk+mTvNUwTEMnlQXq6Rcohsymaphcz7b/1Ro0aRKSSMb0ohsslQ4I4BFwXbt7+fbd2xqxHPi6dlBJtrpPJJFVOstVKmh5j2IoxtwvMyCk0Uk+rXRATuGHBRPoAitMMrhvzq49FfvJpdceXGqYTWoDZ4e7+6JXv+7/4tmPd9+Nif5J42VWiRn8BdA1xTMpSAsOuOB35ay0PgXaiDulIN40xmRdKnnnoql8OmTTPZHzzzUi14Wb3GyDxyTeXbUqNkKHALgosAUYYdC4u598Wql+n7kpcyeG5WZ1EX85ooWWrLG0cpm3/83XffzZ8fOZgs6JY88ed/W8o40uUwY3jw0CD5b1MjR4FbAlwERuiMkjGZD3Tf37ErO/zoEzmULM1zNzwI58hDXspQ1u3PDgaDbHFx8ZJjvoKnuP/1119nu3fvzu6///7VtcNMw1lkg7xYooixA2RXrs/81Vv5cWDFEJKXVVmpjRyPa3eBWxJcV5iscGIK58ChQa48KJC/cY48o1ZDAfHCwkKyf07lysv9/fDDD2c7duzIPv/881Vw7TwyAUIgZuDOlyn7GEG8KyPGKbztY89eNBW4NcAdJmSUzvWow/L4x2wBforTFv6zsv/iiy/mML711lvfgXZYfjs2jYDas/upwA0Mri/govtAiydJvb/7/vvv58/5zDPPlIK2qBynJZ/A7Qi4KBwLPnq9dP4mw4fo22+/zZaWlrK77rpL0NbUO4FbU4C+ctbZt/7u8nKaf/z12GOPZVddddVU9knr6MWwsgK3Q+DSQNbfZfR5WIPFeuyVV17JQ+TXX389qeeaVHsI3I6BiyJYf3fUSPSklKXqdT/66KNs/fr12ZNPPiloA+mbwA0kyKpKPapcSv3d2267LTt48KCgDahrAjegMEdBWOV4Kv3d48ePZ1deeWX28ccfC9yAuiZwAwqzCqDjysTe3z116lTer0193fC4NmzqnMDtMLg0eqz93XPnzmXz8/PZ448/Lk/bgI4J3AaEGtrKxtjfvfPOO7NbbrlF0DakXwK3IcGGhDe2/u7TTz+dh8hnz54VuA3pl8BtSLAhwaWuWPq7p0+fzqF9+eWXBW2DuiVwGxRuaHi73t/97LPPsm3btiX7uZjQ7VmnPoEbEbg0NP3drr6/e/jw4WzPnj3ZN998I2/bsF4J3IYFXMeqDitr/V0gGXZ+Useee+65PEQ+c+ZMp+5rUvJo+roCNzJwUQiWQvIKIKFz0wpSpP533nknv5/nn3++E/dT5J5jzyNwIwQXpbNvNE36/d2vvvoqu+6667IHHnhA0LaoSwK3RWGHtvJ8r4r3WwmfqdvS0NcZV99DDz2U7dy5M/viiy8Ebou6JHBbFPY4AKqcA1RWJx05ciT/RtPs7Gwr77r2+/18eurEiRN5iPz2228L2pb1SOC2LPAqgI4rw0fV3I+pNb0u2K63cePG/LqpvTc8TtZdOidwIwaX/u3evXuzyy+/PIeIlOmiJhWML1ZeccUVq8YCbz+JEL3JZ4yhboEbMbh4102bNmUbNmxYBWlmZqZRcFlg4Xp4fsvrnm9U5sMMicCNGFwalE+W8qK6C9O4kWa8I3OtbCdPnsyhY5miHRumJHaMa9l18LqbN2+eqo+Qmxy6kArcyME1JWJOF+8LWP78LsAxkMSiDQNvVMpXJgl/gdoPgW3JJWUJyf3zdi9Km/fAAjcRcIEFQG+//faMaSL2AYsvUAAao8/8MwDhNfmGwcXCDuAESsoAMcBb3nvuuUdetiP6InA70hAGR4gU+AAO8Ph7kyojzUBvizxYGw3UDEzJyzbvTYvogMBNEFxCXTwm4NUFDe9sHriKASiihMpT3hgI3ITABVJWUhEWh/60q3lfjIJAKw9aaJkJ3ITA5R8QCI3HjSrXUSA8rnnyOvWobH3wBW4i4OIJ8bRNQWuwGbwKm+vDZzKtkgrcBMC1ZYikVZSgbBkGqRj4GjU6XbY+5S9vBARu5ODSrwWitlcv8aJBqn9OFoMhEbiRgwuwhMh1R4/LKqu9zB96EKzsfUxrfoEbMbiT8rYGCws95HXLh7kmvzqpwI0YXOvbtu1tTeEmfX27j2lMBW5JcBlNbWsQaC2FZAmjLW9cK28T5zEYTA9phLl9rytwS4DLaCrTLqz/ZWsChjJ1Mijlv1BQpnyIvKyqwoCEqEt1FDcAAvfiYnz3s6J4EnffFIrpD0ZT2YZNhfgDNeQZls/qq5vi7fxr1q2zbHkGxybp9cvebyr5Be5FcFlxZCEwXgTv6jcykKConBsGjFsHZamnqTCShRZdARfP78tK+8W9ZxVZCdyLoTIgMq1C6IlHrTLgA9QWNuJpqa9KoxQpw/0CbpX7LFJ/0Tw2QFU0v/KFAVrgOn1cwAOGqssGXVgxAAZxE8pq4DZRd5k6u3IfZe45hbwC1wGXvpp53aqNa+ExXruqAShybQOmyWsUuQ8MFMauSF7lCeNtkaPAvQiuDbKY16wKBH1ae0unSUW1qRgAbvI6a9Vtclsrn86HgxZZCtyL/z2Lp7URYLyI+w8BZZWurtcuej08nQ2oFS0TOh8DdRpVDgtlkTYSuE6oXERgRfIwumxGoEj+qnkIx5vsRxe5Lz5rg9ctkld5wgEucAOCS5jMlxTbgonIoOxUDP+sN2r74IMPSgFId6LOYJ5Arg6ywA0ILv1NvE9bUzQGTpl+7quvvpqN2sqCy7M2OeUlsEeDLXADgjsJRaN/OYk3dDBOk3gPeBIy7uI1BW7k4Nq0UBmvG0IRzdu2FV2EuOeU6hC4kYOLMjJ3XGcUvKxCW4hOH7tsWeUfHf6WkY3ATQBcvB59TQbGyjR+lbxci5FkTQGFAbBKG1BG4CYALg2JFwTeJr97DLR49qpruasqqcp910gI3ETARbmZjmJ6pgl4DVqMA0ZCMH0XpjZlInATAhfFYSUVcDHSDGwhlAlQGUFuev11iHudljoEbmLgoriAxuqtutM1gI/3xovTpw1lCKYFriafU+AmCC4KA2Q2ZcNgEv/eV3QZJl//4NM8gI8BaOpjAE0qdup1C9xEwTXFBVaWYBI+4zmBmNFnoH7iiSdWf7NPeA2s5ANYjsnLTrYva+3opwI3cXDdBieEtrd5CH2vvvrqHNI9e/bkoTDHmJst6pnduvW7XcAF7hSB68P1xhtv5OCeO3cuyCCWX7/2m4NZ4ArcTOA2B1hTxkvgClyBG6EOCNwIGy2UFVeoHJ+ntbYXu
AJXHjdCHRC4ETaaWd26qTyuPK5GJSM0AAJX4ApcgSsdaFEHFCq3KOy6oW3o8vK48riythEaAIErcAWuwJUOtKgDCpVbFHboULduffK48riythEaAIErcAWuwJUOtKgDCpVbFHbd0DZ0eXlceVxZ24gMgP1VyoMPPpi/1nf06NH8pXm9hxsPyPK4EQEXyuPyQj1fuVi/fn122WWX5enmzZtlfCPSBYEbUWOFApd6+OoF8LJt2LAh/zJGyPpVV7PeW+BOKbh8ogaPa/DihQVbs7CFlK/AnVJw6c8atHjfkEqlupo3AAJ3SsEFrnvvvTeHV3/e1TxooY2ZwJ1icO0vS/QJVoGrkCsyQ8A3l0N7A9XXvCGQx40MtLpQMAjFvxrwsXO2u+++e/X3yZMn9U3lSPRB4EbSUHWA5Y/A+A+g2fkL/1LAoNSW6/vZD5YGq9v81sXVwartC4v5X5BoQUbznrNquwrchMFlhdSBQ8s5kDtvHmQHjp7IHvmH/8r+8Jf/N3J78LV/z/q/ezzburuflwN4Adw9gAVuguAy2GT/srf9pkF230v/PBLUcRD/+OmTmXliwuuq3kHlwoMvcBMDF2hv7C9lG2fmKwPrw4ynJrzGGAjC8BBWkanATQhcBp7m5nt5/5WQ1wewzv7dvzidG4P9/SX9g18HdEbgdqARqlhcvwye9ob9S9mvXd/Pfv+f/icotAY8xgBP/og878QjD4GbCLgHB8t5f7QpaF14CZuZSvKNh/bbC6MFbgLgAtH3Zuaz0OGxweqnDFoBL6PWgrU9WF1ZC9zIwWWqZnaul0/1+IA1uc9o9e2HlgXuhPRH4E5I8K71rPObJYu9bYuN9GnHgc98MF6X9c517l9lq3lsgRsxuPZqHqHrOMiaOvcbPzmS7esvCdwJ6JDAnYDQQ3kZe7unKTDXqpcpIryuVlZV85p19EDgNgguYC0vL+cbC/jrNNSwsvesrGTXHDg8EW9rUDM9pPd5BW5w5R6m8CGOsdhh3HusLOzv9/sZ+dj4Hbo/yEsDkwqTDVwMx70rK9G0W4i270Id8rgVPe5gMBg7HcIUjQsqXokyIRudMLXqOmQGl37zp390yValLuq49UDY5wopo1TrErgNgesrzMrKStBFC8yh1gEXSA1cvOYVm6vNA1MHnt9/Xu03Gz4L3JLgAou/reVJ8byEyuNC67KKbuBayFo1ZdHGlTPVQ24MAPIoe//KXw9sgVsSXFO4tUJly3fs2LHg0FJ3CHBZHjm7dVetfrLArQeg6UnZVOA2CC6LI0J7WmtgA7fqMkegvfr6pfyl+aremnKEyvK47cMrcCuCS/g7bv6Swaler5edPn06O3PmTL4xumzghUgBpsqAEsDx2RoWUNSB1sDV4JTADarYIeCoWgeDUYTT7kbYXLW+YeUAt8p0EF7a/d4Uv/lcTRWIgf/gIY0qD2ufJo/J41b0uE02StG6f/veDizAmO1pAcYEdEjgnj+ff66UbyoN2z799NOgXrIolEXyMTf8vdleJU9Zxbv6ZWxgalyXochzKE/5UFvgnj+fnT17duTWZXC78JLBjp0LnTVsKRsEgTuBMCekQl14rW9X615Xr/WV95Ih213gRg7upLwug1Lb5W0nFm0I3MjBxYozWk1fd62Pnft91Kr79jofL1KE9CKqq7gXF7gJgJt/S3l/P9uye6mxLzwa5Ewl5UZCfxY2UaMlcBMAF08FvDNz89k1B+9rrL/Laqvetl3ZAc3bThRa2lvgJgIujcnKLOBlKWPoz7TiaTfO9rK9+8O+LKHwuHh47MpK4CYErsF7w75+9ivbdgX7XKt9jpUvboR8w8lVRP0uB7DATQxcAAAullqyJHLPTx6tPGjFAoudN1/4t7/QyzUFajlQfXkJ3ATBtUbmRQimbAzgO57+6zVDaEamDx59cRVY+rOhX46w+1NaHV6BmzC4BgYAs64ZgNm2/vpStv2m5Xwgi9fy+L3j5uU8vOY8/WQWdvDqoNWhtDpkTchO4E4BuKY4hNDMvfLKIRvelI2+qx0TrN0C1NrOTwXuFIHrN77244B0WDsJXIGrcDhCHRC4ETbaMAusY/F6zyptJ3AFrjxuhDogcCNstCoWWmXS8sgCV+DK40aoAwI3wkaT90zLe1ZpT4ErcOVxI9SB/wfVMEEjVq41uwAAAABJRU5ErkJggg==" 109 | } 110 | }, 111 | "cell_type": "markdown", 112 | "id": "f08387c9-0437-41e0-bca6-971dd834d1b4", 113 | "metadata": {}, 114 | "source": [ 115 | "![image.png](attachment:12d8339f-8994-4e71-af46-ed03573b7a2c.png)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 10, 121 | "id": "23dc5513-d225-4f78-aa56-78c0bff66220", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "# 基于多GPU的并行计算\n", 126 | "devices = dltools.try_all_gpus()\n", 127 | "def run(x):\n", 128 | " return [x.mm(x) for _ in range(50)]\n", 129 | "\n", 130 | "x_gpu1 = torch.rand(size=(4000, 4000), device=devices[0])\n", 131 | "x_gpu2 = torch.rand(size=(4000, 4000), device=devices[1])\n" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "f01b2b76-6567-4ee4-b4d2-858b5dc91e2f", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "run(x_gpu1)\n", 142 | "run(x_gpu2) \n", 143 | "torch.cuda.synchronize(devices[0])\n", 144 | "torch.cuda.synchronize(devices[1])\n", 145 | "\n", 146 | "with d2l.Benchmark('GPU1 time'):\n", 147 | " run(x_gpu1)\n", 148 | " torch.cuda.synchronize(devices[0])\n", 149 | "\n", 150 | "with d2l.Benchmark('GPU2 time'):\n", 151 | " run(x_gpu2)\n", 152 | " torch.cuda.synchronize(devices[1])" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "72f57e0c", 
158 | "metadata": {}, 159 | "source": [ 160 | "GPU1 time: 0.4872 sec\n", 161 | "GPU2 time: 0.5076 sec" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "251fc1a8-31e5-48dc-bf4e-999de56d1015", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "with dltools.Benchmark('GPU1 & GPU2'):\n", 172 | " run(x_gpu1)\n", 173 | " run(x_gpu2)\n", 174 | " torch.cuda.synchronize()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "2d1de176-3ebb-4522-bc98-54f4807de993", 180 | "metadata": {}, 181 | "source": [ 182 | "GPU1 & GPU2: 0.4918 sec" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "5142520e-558d-443f-a1ff-f3a2d00a9ab3", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "# 并行计算与通信\n", 193 | "def copy_to_cpu(x, non_blocking=False):\n", 194 | " return [y.to('cpu', non_blocking=non_blocking) for y in x]\n", 195 | "\n", 196 | "with dltools.Benchmark('在GPU1上运行'):\n", 197 | " y = run(x_gpu1)\n", 198 | " torch.cuda.synchronize()\n", 199 | "\n", 200 | "with dltools.Benchmark('复制到CPU'):\n", 201 | " y_cpu = copy_to_cpu(y)\n", 202 | " torch.cuda.synchronize()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "b7e66bae-1840-43b6-85dd-e881c2f8655f", 208 | "metadata": {}, 209 | "source": [ 210 | "在GPU1上运行: 0.4904 sec\n", 211 | "复制到CPU: 2.3745 sec" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "f3531457-0753-449b-aadf-f1450896ac63", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "with dltools.Benchmark('在GPU1上运行并复制到CPU'):\n", 222 | " y = run(x_gpu1)\n", 223 | " y_cpu = copy_to_cpu(y, True)\n", 224 | " torch.cuda.synchronize()" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "6723889d-5e36-4c16-8dc7-d370f562b4d3", 230 | "metadata": {}, 231 | "source": [ 232 | "在GPU1上运行并复制到CPU: 1.7353 sec" 233 | ] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "Python 3", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.6.5" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 5 257 | } 258 | -------------------------------------------------------------------------------- /7-线性代数.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "64ed6990", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 5, 16 | "id": "e479cbd8", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "\n", 22 | "os.makedirs(os.path.join('..', 'data'), exist_ok=True)\n", 23 | "data_file = os.path.join('..', 'data', 'house_tiny.csv')\n", 24 | "with open(data_file, 'w') as f:\n", 25 | " f.write('NumRooms,Alley,Price\\n') # 列名\n", 26 | " f.write('NA,Pave,127500\\n') # 每行表示一个数据样本\n", 27 | " f.write('2,NA,106000\\n')\n", 28 | " f.write('4,NA,178100\\n')\n", 29 | " f.write('NA,NA,140000\\n')" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 6, 35 | "id": "fd1f153c", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | 
"name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | " NumRooms Alley Price\n", 43 | "0 NaN Pave 127500\n", 44 | "1 2.0 NaN 106000\n", 45 | "2 4.0 NaN 178100\n", 46 | "3 NaN NaN 140000\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "import pandas as pd\n", 52 | "data = pd.read_csv(data_file)\n", 53 | "print(data)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 7, 59 | "id": "d529f492", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "inputs,outputs = data.iloc[:, 0:2], data.iloc[:, 2]\n", 64 | "inputs = inputs.fillna(inputs.mean())" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 8, 70 | "id": "8315bc8e", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | " NumRooms Alley\n", 78 | "0 3.0 Pave\n", 79 | "1 2.0 NaN\n", 80 | "2 4.0 NaN\n", 81 | "3 3.0 NaN\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "print(inputs)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 9, 92 | "id": "54d2c89f", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "inputs = pd.get_dummies(inputs, dummy_na=True)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 11, 102 | "id": "27088b37", 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/html": [ 108 | "
\n", 109 | "\n", 122 | "\n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | "
NumRoomsAlley_PaveAlley_nan
03.010
12.001
24.001
33.001
\n", 158 | "
" 159 | ], 160 | "text/plain": [ 161 | " NumRooms Alley_Pave Alley_nan\n", 162 | "0 3.0 1 0\n", 163 | "1 2.0 0 1\n", 164 | "2 4.0 0 1\n", 165 | "3 3.0 0 1" 166 | ] 167 | }, 168 | "execution_count": 11, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "inputs" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 14, 180 | "id": "78adde0b", 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "(tensor([[3., 1., 0.],\n", 187 | " [2., 0., 1.],\n", 188 | " [4., 0., 1.],\n", 189 | " [3., 0., 1.]], dtype=torch.float64),\n", 190 | " tensor([127500, 106000, 178100, 140000]))" 191 | ] 192 | }, 193 | "execution_count": 14, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "import torch\n", 200 | "X,y = torch.tensor(inputs.values),torch.tensor(outputs.values)\n", 201 | "X,y" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 15, 207 | "id": "a59d0fb5", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "x = torch.tensor([3.0])" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 16, 217 | "id": "15fac9ee", 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "data": { 222 | "text/plain": [ 223 | "(tensor([5.]), tensor([6.]), tensor([1.5000]), tensor([9.]))" 224 | ] 225 | }, 226 | "execution_count": 16, 227 | "metadata": {}, 228 | "output_type": "execute_result" 229 | } 230 | ], 231 | "source": [ 232 | "y = torch.tensor([2.0])\n", 233 | "x+y,x*y,x/y,x**y" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 17, 239 | "id": "230cc79c", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "tensor([0, 1, 2, 3])" 246 | ] 247 | }, 248 | "execution_count": 17, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "x = torch.arange(4)\n", 255 | "x" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 18, 261 | "id": "ec66d584", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "4" 268 | ] 269 | }, 270 | "execution_count": 18, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "len(x)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 19, 282 | "id": "ceceb2d6", 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "data": { 287 | "text/plain": [ 288 | "torch.Size([4])" 289 | ] 290 | }, 291 | "execution_count": 19, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "x.shape" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 20, 303 | "id": "963e1ce3", 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "data": { 308 | "text/plain": [ 309 | "tensor([[ 0, 1, 2, 3],\n", 310 | " [ 4, 5, 6, 7],\n", 311 | " [ 8, 9, 10, 11],\n", 312 | " [12, 13, 14, 15],\n", 313 | " [16, 17, 18, 19]])" 314 | ] 315 | }, 316 | "execution_count": 20, 317 | "metadata": {}, 318 | "output_type": "execute_result" 319 | } 320 | ], 321 | "source": [ 322 | "A = torch.arange(20).reshape(5,4)\n", 323 | "A" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 21, 329 | "id": "adc91a5b", 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "text/plain": [ 335 | "tensor([[ 0, 4, 8, 12, 16],\n", 336 | " [ 1, 5, 9, 13, 17],\n", 337 | " [ 2, 6, 10, 
14, 18],\n", 338 | " [ 3, 7, 11, 15, 19]])" 339 | ] 340 | }, 341 | "execution_count": 21, 342 | "metadata": {}, 343 | "output_type": "execute_result" 344 | } 345 | ], 346 | "source": [ 347 | "A.T" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 22, 353 | "id": "9a0b5294", 354 | "metadata": {}, 355 | "outputs": [ 356 | { 357 | "data": { 358 | "text/plain": [ 359 | "tensor([[[ 0, 1, 2, 3],\n", 360 | " [ 4, 5, 6, 7],\n", 361 | " [ 8, 9, 10, 11]],\n", 362 | "\n", 363 | " [[12, 13, 14, 15],\n", 364 | " [16, 17, 18, 19],\n", 365 | " [20, 21, 22, 23]]])" 366 | ] 367 | }, 368 | "execution_count": 22, 369 | "metadata": {}, 370 | "output_type": "execute_result" 371 | } 372 | ], 373 | "source": [ 374 | "X = torch.arange(24).reshape(2,3,4)\n", 375 | "X" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 24, 381 | "id": "b1156b1b", 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "X = torch.arange(24).reshape(4,6)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 25, 391 | "id": "4f2877c9", 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "data": { 396 | "text/plain": [ 397 | "tensor([[ 0, 1, 2, 3, 4, 5],\n", 398 | " [ 6, 7, 8, 9, 10, 11],\n", 399 | " [12, 13, 14, 15, 16, 17],\n", 400 | " [18, 19, 20, 21, 22, 23]])" 401 | ] 402 | }, 403 | "execution_count": 25, 404 | "metadata": {}, 405 | "output_type": "execute_result" 406 | } 407 | ], 408 | "source": [ 409 | "X" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 26, 415 | "id": "9bfc9af9", 416 | "metadata": {}, 417 | "outputs": [ 418 | { 419 | "data": { 420 | "text/plain": [ 421 | "tensor([[ 0, 2, 4, 6, 8, 10],\n", 422 | " [12, 14, 16, 18, 20, 22],\n", 423 | " [24, 26, 28, 30, 32, 34],\n", 424 | " [36, 38, 40, 42, 44, 46]])" 425 | ] 426 | }, 427 | "execution_count": 26, 428 | "metadata": {}, 429 | "output_type": "execute_result" 430 | } 431 | ], 432 | "source": [ 433 | "X * 2" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 27, 439 | "id": "7966c3ae", 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "data": { 444 | "text/plain": [ 445 | "tensor(276)" 446 | ] 447 | }, 448 | "execution_count": 27, 449 | "metadata": {}, 450 | "output_type": "execute_result" 451 | } 452 | ], 453 | "source": [ 454 | "X.sum()" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 28, 460 | "id": "315e0a41", 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "(tensor([36, 40, 44, 48, 52, 56]), torch.Size([6]))" 467 | ] 468 | }, 469 | "execution_count": 28, 470 | "metadata": {}, 471 | "output_type": "execute_result" 472 | } 473 | ], 474 | "source": [ 475 | "X_sum_axis0 = X.sum(axis=0)\n", 476 | "X_sum_axis0,X_sum_axis0.shape" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 30, 482 | "id": "26666dba", 483 | "metadata": {}, 484 | "outputs": [ 485 | { 486 | "data": { 487 | "text/plain": [ 488 | "(tensor([ 15, 51, 87, 123]), torch.Size([4]))" 489 | ] 490 | }, 491 | "execution_count": 30, 492 | "metadata": {}, 493 | "output_type": "execute_result" 494 | } 495 | ], 496 | "source": [ 497 | "X_sum_axis1 = X.sum(axis=1)\n", 498 | "X_sum_axis1,X_sum_axis0.shape" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 31, 504 | "id": "cc946cf4", 505 | "metadata": {}, 506 | "outputs": [ 507 | { 508 | "data": { 509 | "text/plain": [ 510 | "tensor([[ 15],\n", 511 | " [ 51],\n", 512 | " [ 87],\n", 513 | " 
[123]])" 514 | ] 515 | }, 516 | "execution_count": 31, 517 | "metadata": {}, 518 | "output_type": "execute_result" 519 | } 520 | ], 521 | "source": [ 522 | "sum_X = X.sum(axis=1,keepdims=True)\n", 523 | "sum_X" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 32, 529 | "id": "d89e9f66", 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "tensor([[ 0, 1, 2, 3, 4, 5],\n", 536 | " [ 6, 8, 10, 12, 14, 16],\n", 537 | " [18, 21, 24, 27, 30, 33],\n", 538 | " [36, 40, 44, 48, 52, 56]])" 539 | ] 540 | }, 541 | "execution_count": 32, 542 | "metadata": {}, 543 | "output_type": "execute_result" 544 | } 545 | ], 546 | "source": [ 547 | "X.cumsum(axis=0)" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 33, 553 | "id": "e71ebed6", 554 | "metadata": {}, 555 | "outputs": [ 556 | { 557 | "data": { 558 | "text/plain": [ 559 | "tensor([[ 0, 1, 2, 3, 4, 5],\n", 560 | " [ 6, 7, 8, 9, 10, 11],\n", 561 | " [12, 13, 14, 15, 16, 17],\n", 562 | " [18, 19, 20, 21, 22, 23]])" 563 | ] 564 | }, 565 | "execution_count": 33, 566 | "metadata": {}, 567 | "output_type": "execute_result" 568 | } 569 | ], 570 | "source": [ 571 | "X" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 34, 577 | "id": "726f6dc1", 578 | "metadata": {}, 579 | "outputs": [ 580 | { 581 | "data": { 582 | "text/plain": [ 583 | "tensor([[ 0, 1, 2, 3, 4, 5],\n", 584 | " [ 6, 8, 10, 12, 14, 16],\n", 585 | " [18, 21, 24, 27, 30, 33],\n", 586 | " [36, 40, 44, 48, 52, 56]])" 587 | ] 588 | }, 589 | "execution_count": 34, 590 | "metadata": {}, 591 | "output_type": "execute_result" 592 | } 593 | ], 594 | "source": [ 595 | "X.cumsum(axis=0)" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": 35, 601 | "id": "668c1124", 602 | "metadata": {}, 603 | "outputs": [ 604 | { 605 | "data": { 606 | "text/plain": [ 607 | "torch.Size([4, 6])" 608 | ] 609 | }, 610 | "execution_count": 35, 611 | "metadata": {}, 612 | "output_type": "execute_result" 613 | } 614 | ], 615 | "source": [ 616 | "X.shape" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 43, 622 | "id": "19e878ef", 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "Y = torch.ones(6,2, dtype = torch.float32)" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": 44, 632 | "id": "b5599949", 633 | "metadata": {}, 634 | "outputs": [ 635 | { 636 | "data": { 637 | "text/plain": [ 638 | "tensor([[1., 1.],\n", 639 | " [1., 1.],\n", 640 | " [1., 1.],\n", 641 | " [1., 1.],\n", 642 | " [1., 1.],\n", 643 | " [1., 1.]])" 644 | ] 645 | }, 646 | "execution_count": 44, 647 | "metadata": {}, 648 | "output_type": "execute_result" 649 | } 650 | ], 651 | "source": [ 652 | "Y" 653 | ] 654 | }, 655 | { 656 | "cell_type": "code", 657 | "execution_count": 45, 658 | "id": "2314ea5d", 659 | "metadata": {}, 660 | "outputs": [ 661 | { 662 | "ename": "RuntimeError", 663 | "evalue": "1D tensors expected, but got 2D and 2D tensors", 664 | "output_type": "error", 665 | "traceback": [ 666 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 667 | "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", 668 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m 
\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mY\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 669 | "\u001b[1;31mRuntimeError\u001b[0m: 1D tensors expected, but got 2D and 2D tensors" 670 | ] 671 | } 672 | ], 673 | "source": [ 674 | "torch.dot(X,Y)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 46, 680 | "id": "358b284f", 681 | "metadata": {}, 682 | "outputs": [ 683 | { 684 | "data": { 685 | "text/plain": [ 686 | "torch.Size([6, 2])" 687 | ] 688 | }, 689 | "execution_count": 46, 690 | "metadata": {}, 691 | "output_type": "execute_result" 692 | } 693 | ], 694 | "source": [ 695 | "Y.shape\n", 696 | "\n" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 66, 702 | "id": "f800ed37", 703 | "metadata": {}, 704 | "outputs": [ 705 | { 706 | "data": { 707 | "text/plain": [ 708 | "(tensor([1., 1., 1., 1.]), tensor([1., 2., 3., 4.]))" 709 | ] 710 | }, 711 | "execution_count": 66, 712 | "metadata": {}, 713 | "output_type": "execute_result" 714 | } 715 | ], 716 | "source": [ 717 | "x,y = torch.ones(4,dtype = torch.float32),torch.tensor([1,2,3,4],dtype = torch.float32)\n", 718 | "x,y" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": 62, 724 | "id": "95230594", 725 | "metadata": {}, 726 | "outputs": [ 727 | { 728 | "data": { 729 | "text/plain": [ 730 | "tensor([2., 3., 4., 5.])" 731 | ] 732 | }, 733 | "execution_count": 62, 734 | "metadata": {}, 735 | "output_type": "execute_result" 736 | } 737 | ], 738 | "source": [ 739 | "x + y" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": 63, 745 | "id": "2fcafef1", 746 | "metadata": {}, 747 | "outputs": [ 748 | { 749 | "data": { 750 | "text/plain": [ 751 | "tensor([1., 2., 3., 4.])" 752 | ] 753 | }, 754 | "execution_count": 63, 755 | "metadata": {}, 756 | "output_type": "execute_result" 757 | } 758 | ], 759 | "source": [ 760 | "x * y" 761 | ] 762 | }, 763 | { 764 | "cell_type": "raw", 765 | "id": "b362a39e", 766 | "metadata": {}, 767 | "source": [ 768 | "torch.sum(x*y)" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": 64, 774 | "id": "47c6fd1d", 775 | "metadata": {}, 776 | "outputs": [ 777 | { 778 | "data": { 779 | "text/plain": [ 780 | "tensor(10.)" 781 | ] 782 | }, 783 | "execution_count": 64, 784 | "metadata": {}, 785 | "output_type": "execute_result" 786 | } 787 | ], 788 | "source": [ 789 | "torch.sum(x*y)" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 67, 795 | "id": "f0c3c775", 796 | "metadata": {}, 797 | "outputs": [ 798 | { 799 | "data": { 800 | "text/plain": [ 801 | "tensor(10.)" 802 | ] 803 | }, 804 | "execution_count": 67, 805 | "metadata": {}, 806 | "output_type": "execute_result" 807 | } 808 | ], 809 | "source": [ 810 | "torch.dot(x,y)" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": 68, 816 | "id": "e6e030e9", 817 | "metadata": {}, 818 | "outputs": [ 819 | { 820 | "data": { 821 | "text/plain": [ 822 | "tensor([[ 0, 1, 2, 3, 4, 5],\n", 823 | " [ 6, 7, 8, 9, 10, 11],\n", 824 | " [12, 13, 14, 15, 16, 17],\n", 825 | " [18, 19, 20, 21, 22, 23]])" 826 | ] 827 | }, 828 | "execution_count": 68, 829 | "metadata": {}, 830 | "output_type": "execute_result" 831 | } 832 | ], 833 | "source": [ 834 | "X" 835 | ] 836 | }, 837 | { 838 | "cell_type": "code", 839 | "execution_count": 74, 840 | "id": "804e31cf", 841 | "metadata": {}, 
842 | "outputs": [], 843 | "source": [ 844 | "B = torch.ones(6,2,dtype = torch.long)" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": 70, 850 | "id": "7a01fead", 851 | "metadata": {}, 852 | "outputs": [ 853 | { 854 | "data": { 855 | "text/plain": [ 856 | "tensor([[1., 1.],\n", 857 | " [1., 1.],\n", 858 | " [1., 1.],\n", 859 | " [1., 1.],\n", 860 | " [1., 1.],\n", 861 | " [1., 1.]])" 862 | ] 863 | }, 864 | "execution_count": 70, 865 | "metadata": {}, 866 | "output_type": "execute_result" 867 | } 868 | ], 869 | "source": [ 870 | "B" 871 | ] 872 | }, 873 | { 874 | "cell_type": "code", 875 | "execution_count": 75, 876 | "id": "510dc9c5", 877 | "metadata": {}, 878 | "outputs": [ 879 | { 880 | "data": { 881 | "text/plain": [ 882 | "tensor([[ 15, 15],\n", 883 | " [ 51, 51],\n", 884 | " [ 87, 87],\n", 885 | " [123, 123]])" 886 | ] 887 | }, 888 | "execution_count": 75, 889 | "metadata": {}, 890 | "output_type": "execute_result" 891 | } 892 | ], 893 | "source": [ 894 | "torch.mm(X,B)" 895 | ] 896 | }, 897 | { 898 | "cell_type": "code", 899 | "execution_count": 76, 900 | "id": "f7c0900a", 901 | "metadata": {}, 902 | "outputs": [ 903 | { 904 | "data": { 905 | "text/plain": [ 906 | "tensor(5.)" 907 | ] 908 | }, 909 | "execution_count": 76, 910 | "metadata": {}, 911 | "output_type": "execute_result" 912 | } 913 | ], 914 | "source": [ 915 | "u = torch.tensor([3.0,-4.0])\n", 916 | "torch.norm(u)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": 77, 922 | "id": "e75113cf", 923 | "metadata": {}, 924 | "outputs": [ 925 | { 926 | "data": { 927 | "text/plain": [ 928 | "tensor(7.)" 929 | ] 930 | }, 931 | "execution_count": 77, 932 | "metadata": {}, 933 | "output_type": "execute_result" 934 | } 935 | ], 936 | "source": [ 937 | "torch.abs(u).sum()" 938 | ] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": 1, 943 | "id": "2aaba756", 944 | "metadata": {}, 945 | "outputs": [], 946 | "source": [ 947 | "import pandas as pd" 948 | ] 949 | }, 950 | { 951 | "cell_type": "code", 952 | "execution_count": 2, 953 | "id": "bb84481d", 954 | "metadata": {}, 955 | "outputs": [], 956 | "source": [ 957 | "%config Completer.use_jedi = False" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": null, 963 | "id": "2a3f5ce3", 964 | "metadata": {}, 965 | "outputs": [], 966 | "source": [ 967 | "pd." 
968 | ] 969 | } 970 | ], 971 | "metadata": { 972 | "kernelspec": { 973 | "display_name": "Python 3", 974 | "language": "python", 975 | "name": "python3" 976 | }, 977 | "language_info": { 978 | "codemirror_mode": { 979 | "name": "ipython", 980 | "version": 3 981 | }, 982 | "file_extension": ".py", 983 | "mimetype": "text/x-python", 984 | "name": "python", 985 | "nbconvert_exporter": "python", 986 | "pygments_lexer": "ipython3", 987 | "version": "3.8.5" 988 | } 989 | }, 990 | "nbformat": 4, 991 | "nbformat_minor": 5 992 | } 993 | -------------------------------------------------------------------------------- /74-NLP-BERT训练之数据集处理-预训练-深度学习-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "id": "a02706cb", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import random\n", 12 | "import torch\n", 13 | "import dltools" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 9, 19 | "id": "a16faf51", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "def _read_wiki(data_dir):\n", 24 | " file_name = os.path.join(data_dir, 'wiki.train.tokens')\n", 25 | " with open(file_name, 'r',encoding=\"utf-8\") as f:\n", 26 | " lines = f.readlines()\n", 27 | " # 大写字母转换为小写字母\n", 28 | " paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) >= 2]\n", 29 | " random.shuffle(paragraphs)\n", 30 | " return paragraphs" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 34, 36 | "id": "f13a408d", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# 生成下一句预测任务的数据\n", 41 | "def _get_next_sentence(sentence, next_sentence, paragraphs):\n", 42 | " if random.random() < 0.5:\n", 43 | " is_next = True\n", 44 | " else:\n", 45 | " # paragraphs是三重列表的嵌套\n", 46 | " next_sentence = random.choice(random.choice(paragraphs))\n", 47 | " is_next = False\n", 48 | " return sentence, next_sentence, is_next" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 35, 54 | "id": "4db70d6d", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#@save\n", 59 | "def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):\n", 60 | " nsp_data_from_paragraph = []\n", 61 | " for i in range(len(paragraph) - 1):\n", 62 | " tokens_a, tokens_b, is_next = _get_next_sentence(\n", 63 | " paragraph[i], paragraph[i + 1], paragraphs)\n", 64 | " # 考虑1个''词元和2个''词元\n", 65 | " if len(tokens_a) + len(tokens_b) + 3 > max_len:\n", 66 | " continue\n", 67 | " tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)\n", 68 | " nsp_data_from_paragraph.append((tokens, segments, is_next))\n", 69 | " return nsp_data_from_paragraph" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 36, 75 | "id": "040bc80d", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# 生成遮蔽语言模型任务的数据\n", 80 | "def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,\n", 81 | " vocab):\n", 82 | " # 为遮蔽语言模型的输入创建新的词元副本,其中输入可能包含替换的“”或随机词元\n", 83 | " mlm_input_tokens = [token for token in tokens]\n", 84 | " pred_positions_and_labels = []\n", 85 | " # 打乱后用于在遮蔽语言模型任务中获取15%的随机词元进行预测\n", 86 | " random.shuffle(candidate_pred_positions)\n", 87 | " for mlm_pred_position in candidate_pred_positions:\n", 88 | " if len(pred_positions_and_labels) >= num_mlm_preds:\n", 89 | " break\n", 90 | " masked_token = None\n", 91 | " # 80%的时间:将词替换为“”词元\n", 92 | " if 
random.random() < 0.8:\n", 93 | " masked_token = ''\n", 94 | " else:\n", 95 | " # 10%的时间:保持词不变\n", 96 | " if random.random() < 0.5:\n", 97 | " masked_token = tokens[mlm_pred_position]\n", 98 | " # 10%的时间:用随机词替换该词\n", 99 | " else:\n", 100 | " masked_token = random.choice(vocab.idx_to_token)\n", 101 | " mlm_input_tokens[mlm_pred_position] = masked_token\n", 102 | " pred_positions_and_labels.append(\n", 103 | " (mlm_pred_position, tokens[mlm_pred_position]))\n", 104 | " return mlm_input_tokens, pred_positions_and_labels" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 37, 110 | "id": "de98eb80", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# \n", 115 | "def _get_mlm_data_from_tokens(tokens, vocab):\n", 116 | " candidate_pred_positions = []\n", 117 | " # tokens是一个字符串列表\n", 118 | " for i, token in enumerate(tokens):\n", 119 | " # 在遮蔽语言模型任务中不会预测特殊词元\n", 120 | " if token in ['', '']:\n", 121 | " continue\n", 122 | " candidate_pred_positions.append(i)\n", 123 | " # 遮蔽语言模型任务中预测15%的随机词元\n", 124 | " num_mlm_preds = max(1, round(len(tokens) * 0.15))\n", 125 | " mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(\n", 126 | " tokens, candidate_pred_positions, num_mlm_preds, vocab)\n", 127 | " pred_positions_and_labels = sorted(pred_positions_and_labels,\n", 128 | " key=lambda x: x[0])\n", 129 | " pred_positions = [v[0] for v in pred_positions_and_labels]\n", 130 | " mlm_pred_labels = [v[1] for v in pred_positions_and_labels]\n", 131 | " return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 38, 137 | "id": "6cb40b9d", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "# 将文本转换为预训练数据集\n", 142 | "def _pad_bert_inputs(examples, max_len, vocab):\n", 143 | " max_num_mlm_preds = round(max_len * 0.15)\n", 144 | " all_token_ids, all_segments, valid_lens, = [], [], []\n", 145 | " all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []\n", 146 | " nsp_labels = []\n", 147 | " for (token_ids, pred_positions, mlm_pred_label_ids, segments,\n", 148 | " is_next) in examples:\n", 149 | " all_token_ids.append(torch.tensor(token_ids + [vocab['']] * (\n", 150 | " max_len - len(token_ids)), dtype=torch.long))\n", 151 | " all_segments.append(torch.tensor(segments + [0] * (\n", 152 | " max_len - len(segments)), dtype=torch.long))\n", 153 | " # valid_lens不包括''的计数\n", 154 | " valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))\n", 155 | " all_pred_positions.append(torch.tensor(pred_positions + [0] * (\n", 156 | " max_num_mlm_preds - len(pred_positions)), dtype=torch.long))\n", 157 | " # 填充词元的预测将通过乘以0权重在损失中过滤掉\n", 158 | " all_mlm_weights.append(\n", 159 | " torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (\n", 160 | " max_num_mlm_preds - len(pred_positions)),\n", 161 | " dtype=torch.float32))\n", 162 | " all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (\n", 163 | " max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))\n", 164 | " nsp_labels.append(torch.tensor(is_next, dtype=torch.long))\n", 165 | " return (all_token_ids, all_segments, valid_lens, all_pred_positions,\n", 166 | " all_mlm_weights, all_mlm_labels, nsp_labels)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 39, 172 | "id": "3f4738e0", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "#@save\n", 177 | "class _WikiTextDataset(torch.utils.data.Dataset):\n", 178 | " def __init__(self, 
paragraphs, max_len):\n", 179 | " # 输入paragraphs[i]是代表段落的句子字符串列表;\n", 180 | " # 而输出paragraphs[i]是代表段落的句子列表,其中每个句子都是词元列表\n", 181 | " paragraphs = [dltools.tokenize(\n", 182 | " paragraph, token='word') for paragraph in paragraphs]\n", 183 | " sentences = [sentence for paragraph in paragraphs\n", 184 | " for sentence in paragraph]\n", 185 | " self.vocab = dltools.Vocab(sentences, min_freq=5, reserved_tokens=[\n", 186 | " '', '', '', ''])\n", 187 | " # 获取下一句子预测任务的数据\n", 188 | " examples = []\n", 189 | " for paragraph in paragraphs:\n", 190 | " examples.extend(_get_nsp_data_from_paragraph(\n", 191 | " paragraph, paragraphs, self.vocab, max_len))\n", 192 | " # 获取遮蔽语言模型任务的数据\n", 193 | " examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)\n", 194 | " + (segments, is_next))\n", 195 | " for tokens, segments, is_next in examples]\n", 196 | " # 填充输入\n", 197 | " (self.all_token_ids, self.all_segments, self.valid_lens,\n", 198 | " self.all_pred_positions, self.all_mlm_weights,\n", 199 | " self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(\n", 200 | " examples, max_len, self.vocab)\n", 201 | "\n", 202 | " def __getitem__(self, idx):\n", 203 | " return (self.all_token_ids[idx], self.all_segments[idx],\n", 204 | " self.valid_lens[idx], self.all_pred_positions[idx],\n", 205 | " self.all_mlm_weights[idx], self.all_mlm_labels[idx],\n", 206 | " self.nsp_labels[idx])\n", 207 | "\n", 208 | " def __len__(self):\n", 209 | " return len(self.all_token_ids)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 40, 215 | "id": "1cf21293", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "#@save\n", 220 | "def load_data_wiki(batch_size, max_len):\n", 221 | " \"\"\"加载WikiText-2数据集\"\"\"\n", 222 | " num_workers = dltools.get_dataloader_workers()\n", 223 | " data_dir = \"E:/data/wikitext-2/\"\n", 224 | " paragraphs = _read_wiki(data_dir)\n", 225 | " train_set = _WikiTextDataset(paragraphs, max_len)\n", 226 | " train_iter = torch.utils.data.DataLoader(train_set, batch_size,\n", 227 | " shuffle=True, num_workers=num_workers)\n", 228 | " return train_iter, train_set.vocab" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 41, 234 | "id": "e041e570", 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "batch_size, max_len = 512, 64\n", 247 | "train_iter, vocab = load_data_wiki(batch_size, max_len)\n", 248 | "\n", 249 | "for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,\n", 250 | " mlm_Y, nsp_y) in train_iter:\n", 251 | " print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,\n", 252 | " pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,\n", 253 | " nsp_y.shape)\n", 254 | " break" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "344a00ae", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "len(vocab)" 265 | ] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3", 271 | "language": "python", 272 | "name": "python3" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": 
"python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.6.5" 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 5 289 | } 290 | -------------------------------------------------------------------------------- /78-NLP-自然语言推断数据预处理-项目实操-深度学习-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bb36f9c5-1dc0-422d-ae2c-bcf4dcd372eb", 6 | "metadata": {}, 7 | "source": [ 8 | "### 自然语言推断与数据集\n", 9 | "* 文本序列之间的逻辑关系:\n", 10 | " * 蕴涵(entailment):假设可以从前提中推断出来。\n", 11 | " * 矛盾(contradiction):假设的否定可以从前提中推断出来。\n", 12 | " * 中性(neutral):所有其他情况。\n", 13 | "\n", 14 | "* 例子\n", 15 | " * entailment:前提:Tom正在看Kavin的深度学习分享 假设:Tom正在B站学习\n", 16 | " * contradiction: 前提:Tom正在看kavin的深度学习分享 假设:Tom正在操场打篮球\n", 17 | " * neutral:前提:两个人在机场拥抱 假设:正在送行,两个人即将分离\n", 18 | "\n", 19 | "* 斯坦福自然语言推断语料库(Stanford Natural Language Inference,SNLI)数据下载位置:https://nlp.stanford.edu/projects/snli/snli_1.0.zip 数据越90M左右" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 14, 25 | "id": "d85f4177-dc47-4ac5-8945-56be966a18d1", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import os\n", 30 | "import re\n", 31 | "import torch\n", 32 | "from torch import nn\n", 33 | "import dltools" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 4, 39 | "id": "2dee8698-b3b7-43e2-a37e-e2fed1152375", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# 读取数据集\n", 44 | "data_dir = \"E:/data/snli_1.0/\"\n", 45 | "def read_snli(data_dir, is_train):\n", 46 | " \"\"\"将SNLI数据集解析为前提、假设和标签\"\"\"\n", 47 | " def extract_text(s):\n", 48 | " # 删除我们不会使用的信息\n", 49 | " s = re.sub('\\\\(', '', s)\n", 50 | " s = re.sub('\\\\)', '', s)\n", 51 | " # 用一个空格替换两个或多个连续的空格\n", 52 | " s = re.sub('\\\\s{2,}', ' ', s)\n", 53 | " return s.strip()\n", 54 | " label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}\n", 55 | " file_name = os.path.join(data_dir, 'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')\n", 56 | " with open(file_name, 'r') as f:\n", 57 | " rows = [row.split('\\t') for row in f.readlines()[1:]]\n", 58 | " premises = [extract_text(row[1]) for row in rows if row[0] in label_set]\n", 59 | " hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]\n", 60 | " labels = [label_set[row[0]] for row in rows if row[0] in label_set]\n", 61 | " return premises, hypotheses, labels" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 5, 67 | "id": "bbcc6288-4be2-4a80-980b-3bc6f7e96e34", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "前提: A person on a horse jumps over a broken down airplane .\n", 75 | "假设: A person is training his horse for a competition .\n", 76 | "标签: 2\n", 77 | "前提: A person on a horse jumps over a broken down airplane .\n", 78 | "假设: A person is at a diner , ordering an omelette .\n", 79 | "标签: 1\n", 80 | "前提: A person on a horse jumps over a broken down airplane .\n", 81 | "假设: A person is outdoors , on a horse .\n", 82 | "标签: 0\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "train_data = read_snli(data_dir, is_train=True)\n", 88 | "for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):\n", 89 | " print('前提:', x0)\n", 90 | " print('假设:', x1)\n", 91 | " print('标签:', y)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 6, 97 | "id": "5939f65d-6164-4fd5-a630-ef4b7d721795", 98 | "metadata": {}, 99 | 
"outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "[183416, 183187, 182764]\n", 105 | "[3368, 3237, 3219]\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "# 训练集约55万对,测试集约1万对\n", 111 | "test_data = read_snli(data_dir, is_train=False)\n", 112 | "for data in [train_data, test_data]:\n", 113 | " print([[row for row in data[2]].count(i) for i in range(3)])" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "id": "7837ec4b-b305-426c-a93c-de0d90aaee04", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# 加载数据集Dataset\n", 124 | "class SNLIDataset(torch.utils.data.Dataset):\n", 125 | " \"\"\"用于加载SNLI数据集的自定义数据集\"\"\"\n", 126 | " def __init__(self, dataset, num_steps, vocab=None):\n", 127 | " self.num_steps = num_steps\n", 128 | " all_premise_tokens = dltools.tokenize(dataset[0])\n", 129 | " all_hypothesis_tokens = dltools.tokenize(dataset[1])\n", 130 | " if vocab is None:\n", 131 | " self.vocab = dltools.Vocab(all_premise_tokens + all_hypothesis_tokens, min_freq=5, reserved_tokens=[''])\n", 132 | " else:\n", 133 | " self.vocab = vocab\n", 134 | " self.premises = self._pad(all_premise_tokens)\n", 135 | " self.hypotheses = self._pad(all_hypothesis_tokens)\n", 136 | " self.labels = torch.tensor(dataset[2])\n", 137 | " print('read ' + str(len(self.premises)) + ' examples')\n", 138 | "\n", 139 | " def _pad(self, lines):\n", 140 | " return torch.tensor([dltools.truncate_pad(self.vocab[line], self.num_steps, self.vocab['']) for line in lines])\n", 141 | "\n", 142 | " def __getitem__(self, idx):\n", 143 | " return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]\n", 144 | "\n", 145 | " def __len__(self):\n", 146 | " return len(self.premises)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 10, 152 | "id": "087ac68a-d30b-4e9a-bd54-3234649da779", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# 整合代码\n", 157 | "def load_data_snli(batch_size, num_steps=50):\n", 158 | " \"\"\"下载SNLI数据集并返回数据迭代器和词表\"\"\"\n", 159 | " num_workers = dltools.get_dataloader_workers()\n", 160 | " data_dir = \"E:/data/snli_1.0/\"\n", 161 | " train_data = read_snli(data_dir, True)\n", 162 | " test_data = read_snli(data_dir, False)\n", 163 | " train_set = SNLIDataset(train_data, num_steps)\n", 164 | " test_set = SNLIDataset(test_data, num_steps, train_set.vocab)\n", 165 | " train_iter = torch.utils.data.DataLoader(train_set, batch_size,shuffle=True,num_workers=num_workers)\n", 166 | " test_iter = torch.utils.data.DataLoader(test_set, batch_size,shuffle=False,num_workers=num_workers)\n", 167 | " return train_iter, test_iter, train_set.vocab" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 11, 173 | "id": "2b127319-f20a-4026-a614-ca77455d6d52", 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "read 549367 examples\n", 181 | "read 9824 examples\n" 182 | ] 183 | }, 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "18678" 188 | ] 189 | }, 190 | "execution_count": 11, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "train_iter, test_iter, vocab = load_data_snli(128, 50)\n", 197 | "len(vocab)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 12, 203 | "id": "863f060a-6e62-470b-b035-dcbba5e7a0e1", 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": 
"stream", 209 | "text": [ 210 | "torch.Size([128, 50])\n", 211 | "torch.Size([128, 50])\n", 212 | "torch.Size([128])\n" 213 | ] 214 | } 215 | ], 216 | "source": [ 217 | "for X, Y in train_iter:\n", 218 | " print(X[0].shape)\n", 219 | " print(X[1].shape)\n", 220 | " print(Y.shape)\n", 221 | " break" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "22ddff0f-b26b-4f72-8a08-98abc2975c80", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.6.5" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /82-NLP-生成任务GPT2数据预处理-项目实操-深度学习-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1c897996-60af-4524-b843-a8707894f455", 6 | "metadata": {}, 7 | "source": [ 8 | "### 基于GPT2的文本生成项目-数据预处理\n", 9 | "项目地址:https://github.com/yangjianxin1/GPT2-chitchat" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "id": "a1414ef7-6aa2-475b-94ee-86d113e08123", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from transformers import BertTokenizerFast\n", 20 | "import argparse\n", 21 | "import pickle\n", 22 | "from tqdm import tqdm\n", 23 | "import logging\n", 24 | "import numpy as np" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "id": "2ce94007-a849-42cf-b5b3-d9599e57c2e2", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def create_logger(log_path):\n", 35 | " \"\"\"\n", 36 | " 将日志输出到日志文件和控制台\n", 37 | " \"\"\"\n", 38 | " logger = logging.getLogger(__name__)\n", 39 | " logger.setLevel(logging.INFO)\n", 40 | " formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')\n", 41 | " # 创建一个handler,用于写入日志文件\n", 42 | " file_handler = logging.FileHandler(filename=log_path)\n", 43 | " file_handler.setFormatter(formatter)\n", 44 | " file_handler.setLevel(logging.INFO)\n", 45 | " logger.addHandler(file_handler)\n", 46 | "\n", 47 | " # 创建一个handler,用于将日志输出到控制台\n", 48 | " console = logging.StreamHandler()\n", 49 | " console.setLevel(logging.DEBUG)\n", 50 | " console.setFormatter(formatter)\n", 51 | " logger.addHandler(console)\n", 52 | "\n", 53 | " return logger" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 10, 59 | "id": "b6457910-16d9-4f28-8b98-db72e50a00ef", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "def preprocess():\n", 64 | " \"\"\"\n", 65 | " 对原始语料进行tokenize,将每段对话处理成如下形式:\"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]\"\n", 66 | " \"\"\"\n", 67 | " # 设置参数\n", 68 | " parser = argparse.ArgumentParser()\n", 69 | " parser.add_argument('--vocab_path', default='E:/Fangwork/Dltools/GPT2-chitchat-master/vocab/vocab.txt', type=str, required=False,help='词表路径')\n", 70 | " parser.add_argument('--log_path', default='E:/Fangwork/Dltools/GPT2-chitchat-master/data/preprocess.log', type=str, required=False, help='训练日志存放位置')\n", 71 | " parser.add_argument('--train_path', 
default='E:/Fangwork/Dltools/GPT2-chitchat-master/data/train.txt', type=str, required=False, help='训练日志存放位置')\n", 72 | " parser.add_argument('--save_path', default='E:/Fangwork/Dltools/GPT2-chitchat-master/data/train_51w.pkl', type=str, required=False, help='tokenize的训练数据集')\n", 73 | " args =parser.parse_known_args()[0]\n", 74 | "\n", 75 | " # 初始化日志对象\n", 76 | " logger = create_logger(args.log_path)\n", 77 | "\n", 78 | " # 初始化tokenizer\n", 79 | " tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token=\"[SEP]\", pad_token=\"[PAD]\", cls_token=\"[CLS]\")\n", 80 | " sep_id = tokenizer.sep_token_id\n", 81 | " cls_id = tokenizer.cls_token_id\n", 82 | " logger.info(\"preprocessing data,data path:{}, save path:{}\".format(args.train_path, args.save_path))\n", 83 | "\n", 84 | " # 读取训练数据集\n", 85 | " with open(args.train_path, 'rb') as f:\n", 86 | " data = f.read().decode(\"utf-8\")\n", 87 | "\n", 88 | " # 需要区分linux和windows环境下的换行符\n", 89 | " if \"\\r\\n\" in data:\n", 90 | " train_data = data.split(\"\\r\\n\\r\\n\")\n", 91 | " else:\n", 92 | " train_data = data.split(\"\\n\\n\")\n", 93 | " logger.info(\"there are {} dialogue in dataset\".format(len(train_data)))\n", 94 | "\n", 95 | " # 开始进行tokenize\n", 96 | " # 保存所有的对话数据,每条数据的格式为:\"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]\"\n", 97 | " dialogue_len = [] # 记录所有对话tokenize之后的长度,用于统计中位数与均值\n", 98 | " dialogue_list = []\n", 99 | " with open(args.save_path, \"w\", encoding=\"utf-8\") as f:\n", 100 | " for index, dialogue in enumerate(tqdm(train_data)):\n", 101 | " if \"\\r\\n\" in data:\n", 102 | " utterances = dialogue.split(\"\\r\\n\")\n", 103 | " else:\n", 104 | " utterances = dialogue.split(\"\\n\")\n", 105 | "\n", 106 | " input_ids = [cls_id] # 每个dialogue以[CLS]开头\n", 107 | " for utterance in utterances:\n", 108 | " input_ids += tokenizer.encode(utterance, add_special_tokens=False)\n", 109 | " input_ids.append(sep_id) # 每个utterance之后添加[SEP],表示utterance结束\n", 110 | " dialogue_len.append(len(input_ids))\n", 111 | " dialogue_list.append(input_ids)\n", 112 | " len_mean = np.mean(dialogue_len)\n", 113 | " len_median = np.median(dialogue_len)\n", 114 | " len_max = np.max(dialogue_len)\n", 115 | " with open(args.save_path, \"wb\") as f:\n", 116 | " pickle.dump(dialogue_list, f)\n", 117 | " logger.info(\"finish preprocessing data,the result is stored in {}\".format(args.save_path))\n", 118 | " logger.info(\"mean of dialogue len:{},median of dialogue len:{},max len:{}\".format(len_mean, len_median, len_max))\n" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 11, 124 | "id": "5bfc9f74-55ad-4751-8f07-da8a6221dae0", 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stderr", 129 | "output_type": "stream", 130 | "text": [ 131 | "2022-05-12 16:09:55,454 - INFO - preprocessing data,data path:E:/Fangwork/Dltools/GPT2-chitchat-master/data/train.txt, save path:E:/Fangwork/Dltools/GPT2-chitchat-master/data/train_51w.pkl\n", 132 | "2022-05-12 16:09:55,471 - INFO - there are 9831 dialogue in dataset\n", 133 | "100%|████████████████████████████████████████████████████████████████████████████| 9831/9831 [00:07<00:00, 1354.13it/s]\n", 134 | "2022-05-12 16:10:02,794 - INFO - finish preprocessing data,the result is stored in E:/Fangwork/Dltools/GPT2-chitchat-master/data/train_51w.pkl\n", 135 | "2022-05-12 16:10:02,795 - INFO - mean of dialogue len:45.780998881090426,median of dialogue len:36.0,max len:1089\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "preprocess()" 141 | ] 142 | }, 143 | { 144 | 
"cell_type": "code", 145 | "execution_count": null, 146 | "id": "0030e263-91c8-4e4b-bb9c-4f84bcce2b21", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "Python 3", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.6.5" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 5 173 | } 174 | -------------------------------------------------------------------------------- /83-卷积计算和互相关计算-卷积神经网络-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": { 5 | "bf3f324e-fb62-4915-a83e-baf3dd96c468.png": { 6 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAARUAAABvCAYAAADCIFTxAAAR/UlEQVR4Ae2dvY/fSBnHR7m8XLLZ241AaSIlGyiOAzZZCcRRJErE0SHIQsU1JMWVoKROkzS8dElDRZETSNcmoqJB2eIQ0lEEKoSElBP/QCRKdMkPfby/Z3d24rFnvB7ntzvfkUZjj+3H9ndmPn484xfnFKSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIgQOswKZz7k5w/KvOuefOuVtBvj/7sGe5v66myylw3Tl3PzBP+f09yNPstAo8adndx865Ry35YRblGZZpuM7Czm/MT/K2c+6ac47KSNxyzrGMikmlbQtWcVlP4c0p4EOFi8C9eeSi8GA+3XVxeHNHfjj3vOacu+mcezpvU9Y+AAqgoCyY9gNlSJuzEELFX2brLGzKCX4+P1lOBK8FkOC5AJkb8+WxSolgB+qEF7Ykhh0Y5cKVjwoMTCgPKihXScqQaaJV7GF70VY5ChggQojQrmhTRMqH5VyYCbbNfLaBj++pHKg2Zl4JJwVhOVEqoJGSlHluddoC28zmorQtV15ZBaysKDfKkECeVULyuFAoTKsA0EB7IuCnHVlq08xb2bCelRlHau3PjtpfZnkLm3JiXNE4CU4MSJj7zDIqKJHptsC2AKft/rFtfeWNrwBAeTG/6lF+3PZYGXJ1tKvh+HuWxTYF/As17YpA+7EyYTnl5F+oaXuUIR4nkTIk2jzLDkygQj6bnwApkLCrHnQkcsImTnhiVFoC6yGUwrQKUEEpAyqolQHlxRWQikhK9CvwtEdY397QHyBYu7KyIZ+ysvZEecVC6KnE1lvofPNUTBAIyUkTTYTwBMzDIZ915K2ECpWft8pHSnngVQJ3KjRXR6BDnxmpwrQK+O2GdkSbAvSW70MF8FB2FqxcbZ5lB+rCQIWjs4/og6ILKpxkWFmBit0jmhhKyypg9+5UQgJeJuVC+VCuwMWvrPPVlEygAG0JkJAaQEjboGL5dlghVNjGbNg6C5tSKTlY3DJOHigwz0mREk0EOwkqK+5dWFm5QlqFtnWVllWAsvDLx8qE1MrV8soeiaybAn7boX1ZWZBP+7DyonwsWDuz+QMNFf8kOFk7YfLtpP08BAIosY5blptwZltpWQX88gHsdAhykaAsgA7lyLTCtAr45WJtyVKOJJzmTsE6c62D1uZZ5q8/7ZkM3BtkRAQTghMAHERcaLutwb3uq6AsZz2FaRSwMgMgAD0EPuXBFVNhWgWsXNgr7clugwz0fpnYcmuDYWrbTnsG2psUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACFSvANzEeex+9tW8rjJX+1Tn378L2/1nQvn1Nq+IqolOXAnkK8C2F2Te/c61IvPi1y/xSo4htjnlpeaWxf+3atVmJeOHChcZ+nqRaWwrUrUADlT/+azYrEX/5+6dNoyxhG5uABWiVCvfu3RNU6m4fOvsBCggqHUQSVAbUKG1SvQKCiqBSfSOQAOMqIKgIKuPWKFmrXgFBRVCpvhFIgHEVEFQElXFrlKxVr0AyVD7524vZR3cfzO7+9nETU0Z0ckd/fvfn580+Umznjv48e/Zstrm5Obt+/frs8ePHHSjZXaSO2urbhwQYoEAyVBi+BRI05u/9+GYSWHKgYkBhP2ND5cWLF7Pbt2/PAMujR4+aYeLnz5/v0iMyJagMqFHapHoFkqBCg+dBM2vseCvvf3BjZ97ywzQHKmzL+iWgAkwAiwUelHv69KnNRlNBpfr2IQEGKJAElbCxM7/0zuqBgUpIjcuXL++BTLjc5g8JVCjjawXjTwva5rh5laSGwN8ES5bTD5xzxJL74ByaX4/2wgHPxPcgUj2Q1PXMwwnhZfmxlGPKfaIWrwVYpIRDABV+YWlPBRdJjx47XsSuHfex48f/UgNR5r+LLaqlaVow5ZfGaVAJGzvz/u1QrNEvGlS4BXrw4EEKT5p1DgFUKOTmtpKyGDvSt/bOmS+/HNuu2cP+yaXlzyqByhbvytm5j51ie2lpqbnt59Z/7Li+vv7SOZcOFUZ+qJwGj9u/ftR01tp8LEUYf7vYepbP+r5HZPmxNMdTCYHi97HEKHNYoBLTb7/5H/783mz1S2df7tdObHvs1wSVnLof0yyWj+2VlZVYVd93/tWrV7/IggoHSgFzG8Q0nbQPnzzbgUzsRHKhgv2z59Z67dr+cqBy48aN2f3793diyi2QoNL9oqmgMqoPtVUdVGjIQAIvhdEga9hdaQ5U8IZYn5gCLPabChW8ktDdk6fSDYyucrVlgoqgYi7OIE/FKlJOmgOVHLu2bipU7MRzU3kq3eARVAQVa1OCiinRkwoqgsqo2Og2Vuftj3kIqak8le5aNMHSZvQntby4
rf3VH7aSbm2xOcRTSb11NvvqqN0L9jb9UroM8Oq7OmrpDtja2mqeOg+vtXQdsIwY6zaQpxKqFpmvyVOhr4x3u6i0dMTTx9UHoxyoYPeHP7udNbqHfUFlGyr0MwIGNLFyoYysrPpGZLugAihu3rzZvMoCeHilxQJAsU+1MtgRC4JKTJkgvxaoUDn9546ouADGKm8szYEKNriiUrlj9sJ8QWWvl4IeRNOJMvLneRYlNsjRBRXehzMPhHfieAzEArBJCYJKikqzWfPkLQJPcJtSahdJtz9c5fzGntr4qdA5z6mk2rVGg315KrtgQQ8fIjwc6M9ThrGLQRdUwuZgt0mAZl7/Z7du3doBT7g+84JKmyotebV4KlTMECopzwuxnaAy2vWgt6MWvX2IAJD196/veC6Uob/c4EyaChVud6j3fsB74X05PhsSCz5UeBmsOTAObuz4lfc2GtItLa/OSsQjR95q7HOyJeLa2pqRerSaM7GhJE+FihhCBVfar5Rt02wnqIxWotlQ4bYVb4WIt3n23IWdB1TD8kqBCp5JCBSDiHktNh+mr0GF++kS8e1TS02j5IRKxJNLy4193LUS8cSJE1VAJbwtSf20haAyGlAwlA0VHxyUIVDx8/zpFKjwXpz1rYTQYJ4O21jwoZL06QP/4HKmOVE8oZxtctZFKOyXCrXc/qA5FdJGfBilsVcyuspDUFkMqFButIVYJy1l2AcVRnyePHmyM3TM1xG57bGPmQGbrk5bQSWRQjVBhQqJC02MdfaFgMmBCkPK2GUIlItNaKttHvvqqN3uqLUheYBv8AAmwJ88lrdpaHldUGH0x4aNLQUmeC78pZPl/ghRW/MRVNpUacmrCSpW+XLSHKjk2LV1BZXdkR/TZGjaBZWWqp+dJagkSiaodFdqQWXa25+hQGE7QSXx384IpT6VzoqdNPoztLIKKp3a5y7s7agdWk6CSiJQTChBpbPuCiqd8izUwvqgQkfQd7+/2TzLkjIyQKMfMvpD5xMdTylUzvVU6L2251lS/v2j2x/d/kyInfqgQs89DR64pA4TD4EKowPAYmyo0JsNJOxjTSndKoKKoCKopLSUAY/pAxO/kXc9ZOOvlwsVhjPxgkpABS/lzp07DVTSZKrn3R+/zHKm1acyKnLq81SssvFIcInbH8begRCxBFQYc+cBH5685eWolDCBp8IvNDZGrZp7jTV9KnwjpUT84Ce3+Jr+qxK2sYn9k0un/7H3lIrM8W+h2P9wpvrv0Ke81lJKS2yfPn165+E2+z7KWOmlS5deOed+Q+lkPVFLw6e/g9ufvodtgFCOp2IPXJWCikGE//4AFtK+UBAqVNQt59ytOVSeFmkqzn1CWZWMbx09VtT+0WPH/1tIG98s5UB5tEWWTRH+U7KcJrKNfnlQMW8FsOD62nwsTYUKQPnRzTuzD39xv7k68YYs0zG7lp/bUWsQwWOhb6UvFIQK2vM3t4+dc3grpULR0R/0/+o3vqVfdIxTeltd79b01dW+5djmYloq7PvhN4CScguUChW8HtYlAhj7qZLBI5buByoL4qlszsH++Tj18jUrgsprkrRm4I3gLbbFqTyVuqBCRy2eBI0/Z8g3FSo+NNhm7D4VALK5udm8HMW0/7m8LnIX9lSo3TR6bkFLBUGllLLj260LKjR6RmbwUOxlJh8EsekhUAFaqfvI8VR4IYrO2hQPxUAzAVTGr5p7LQoqe/VY5Ln6oBIDR1f+EKh02QuX5UDFQJGTCirdz6mgv/pURuOUoBI28LZ5QWW0CjfUkDyVocpNv52g0gaRME9Qmb5mBnvMggrlxXMSYTnG5nM9lVz73G7X9D2VvtEfRizD23fy/GdN7KNKoUeeOvrDx5hCG237De3ve/QnVsnCfEElaOLTzyZDxR4TABQ8KBWWZdt8DlQY1QMSvIZBbLMX5gkq202Xhs7gAn2CwMEGGnh/jXmLXc9fpUKF9bjtJ7Bfpol8+NryQ6AwL6i0qdKSh4jzh4amx8E4e0yCCo8H+A8z8hoGF4SwkYfzqVCh8z20H9pqmxdUtislngMNnIDXYD/1YtoPACEWUqDCYAbAMnj4XpG/37Z9CCptqrTk1QKVsEEDi5QRuFSo+Pbtk5V+XmxaUHm9UsZGMPseleiDCtsDDuq8QcXfe2y/to6gYkr0pDVCBa8CWMQaup+fCxU8Ijw/3unx7cSmBZXdCoqnQsNGv7ZGj4fR9TmPPqhgmxBCxd+vrbN7VLtTPlSa//58/dtXX5WIa+9eakQoYRubby8tz44cOTI7c+bMqxLx1KlTVdz++I2avg/A4ufFpnOhgh1sp95eCSq7jdam8CjaHrenz6MrdEEFWNjtVQgVsxnbry1/DSrssETkRCFrCdvYXF7e/u+PCTF2yj5q6FMxaHBrknLbY+sPgQrbAouUPhtBxZrs3pR6aRBgCX0uXb/PYB22aYORLWM5ka/nE60z2N9zuF9/mQ+V5i1lf+GY09yj0ShLBU6ypH0gVQtUAMpHdx/uvHrPvMEjluZAxe+oZfQnxRsSVLZbDgDxIRICBE+j69YHK7SVGFT89mkXZvL8/TId7tffTlDx1eiYrgUqeCcAwo8pL4ymQgWgcMsDJICVD5gYsMyj0XMqswYYXf/fafMqwmqdChX7xw/bAyrbr3+LFNpmXlBpU6UlrxaodDXsrmWpUOmy0bVMnkpLpRyYlQqVgeYFlVThBBW9+zPO40JJVg7XY/qpjSx3PfWpJFWmkislPfzW5S10LZOnMmrRCSopgBFURq10Q4wJKkNUezPb1AsV68hh3Lov5ECFnmX7Jw8ptx59IWf0h+Pl2Ol86nqIx9+nbn90+zMhX+qDCkNKGxsbvUNXfqNMhQrr0dBJifRmA4C+kAMVerEtYD8FioKKoCKoWKvpTgeN/tAQU4au/F2nQsUfg2d7XpgK83y7Np0DFf8NTrwiQaUbGF19KbZMfSqjIqc+T4WHwIAKV+/URp8KFYMEKTCxtzD9/LbpHKjg+QAWjr/vQSHb12HxVEr9S+biexuzcxfffVnK/vZ/f5Y/G7XpLq6xLe4E/G+jjDmN7cL//Xk5/+by9i86rBHFUq7qNEjzHgwssfUtfwhUrM/GbHSlOVDh2PFQOI8F+plY6SrOl+DtqeADmR49euzT0iItiH1+13Igy8g7bgYG0qACHNbW1nbaN/Orq6s787GJIVBJ9YLYZw5UWBewEHknKaXP5hB4KgvSXnQYNSmQ/O6P76kACxppX8iFCg0+9dYnFyocvwWAAjD6gqBSU1PQuY6lQDJUaGA2FMt0Sr9ELlRo7LaPvgbP8hxPBVhZ5yx9QxxbXxBUxqpmslOTAslQoQHSKGn04YdxY40zFyqAKtU2+8yBCuvbMyqp+xBUamoKOtexFMiCSgwesfxcqMTsxPJzoRKzE8sXVMaqZrJTkwKCSowo80/rzXu2a6oTOlcpsC8FBBVBZV8VSBtLgVABQUVQCeuE5qXAvhQQVASVfVUgbSwFQgUEFUElrBOalwL7UkBQEVT2VYG0sRQIFWigcuXKlf+ViOvr618welL
CNjZXVlZelbR//vx5XpLifQwFKSAFEhVYm79ZyItAJeJD59yfCtnmeLH/pKB99sFLeQpSQApIASkgBaTA1Ar8H22dQODuS8W4AAAAAElFTkSuQmCC" 7 | } 8 | }, 9 | "cell_type": "markdown", 10 | "id": "e6adbfb5-7a54-4445-be83-081ac236e435", 11 | "metadata": {}, 12 | "source": [ 13 | "## 从全连接层到卷积\n", 14 | "问题:MLP应对高维的数据计算开销巨大\n", 15 | "* 1. 二维互相关运算
\n", 16 | "![image.png](attachment:bf3f324e-fb62-4915-a83e-baf3dd96c468.png)\n", 17 | "
\n", 18 | "\\begin{split}0\\times0+1\\times1+3\\times2+4\\times3=19,\\\\\n", 19 | "1\\times0+2\\times1+4\\times2+5\\times3=25,\\\\\n", 20 | "3\\times0+4\\times1+6\\times2+7\\times3=37,\\\\\n", 21 | "4\\times0+5\\times1+7\\times2+8\\times3=43.\\\\\\end{split}\n", 22 | "\n", 23 | "* 2. 二维卷积层\n", 24 | "* 3. 图像中物体边缘检测\n", 25 | "* 4. 通过数据学习核数组\n", 26 | "* 5. feature map 和 receptive field\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "id": "5c235dc3-7a71-4b3c-b321-434008a23e5d", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# 二维互相关运算\n", 37 | "import torch\n", 38 | "from torch import nn\n", 39 | "import dltools\n", 40 | "\n", 41 | "def corr2d(X, K): #@save\n", 42 | " \"\"\"计算二维互相关运算\"\"\"\n", 43 | " h, w = K.shape\n", 44 | " Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))\n", 45 | " for i in range(Y.shape[0]):\n", 46 | " for j in range(Y.shape[1]):\n", 47 | " Y[i, j] = (X[i:i + h, j:j + w] * K).sum()\n", 48 | " return Y" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "id": "0a055df9-9ea4-4f3f-b327-2f0181102a54", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "tensor([[19., 25.],\n", 61 | " [37., 43.]])" 62 | ] 63 | }, 64 | "execution_count": 2, 65 | "metadata": {}, 66 | "output_type": "execute_result" 67 | } 68 | ], 69 | "source": [ 70 | "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", 71 | "K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])\n", 72 | "corr2d(X, K)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 4, 78 | "id": "ac6bbbc7-bbad-444b-9054-36a44dc931b6", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# 二维卷积层\n", 83 | "class Conv2D(nn.Module):\n", 84 | " def __init__(self, kernel_size):\n", 85 | " super().__init__()\n", 86 | " self.weight = nn.Parameter(torch.rand(kernel_size))\n", 87 | " self.bias = nn.Parameter(torch.zeros(1))\n", 88 | "\n", 89 | " def forward(self, x):\n", 90 | " return corr2d(x, self.weight) + self.bias" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 11, 96 | "id": "117377db-70e7-4d7b-9195-44b6f1253c84", 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n", 103 | " [1., 1., 0., 0., 0., 0., 1., 1.],\n", 104 | " [1., 1., 0., 0., 0., 0., 1., 1.],\n", 105 | " [1., 1., 0., 0., 0., 0., 1., 1.],\n", 106 | " [1., 1., 0., 0., 0., 0., 1., 1.],\n", 107 | " [1., 1., 0., 0., 0., 0., 1., 1.]])" 108 | ] 109 | }, 110 | "execution_count": 11, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "# 边缘检测\n", 117 | "X = torch.ones((6, 8))\n", 118 | "X[:, 2:6] = 0\n", 119 | "X" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 12, 125 | "id": "e3cc8afb-9ff2-4446-ae7c-258323260af0", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "K = torch.tensor([[1.0, -1.0]])" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 13, 135 | "id": "e25aaddc-312b-4fbe-b3a7-66394fdd7bb7", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "tensor([[ 0., 1., 0., 0., 0., -1., 0.],\n", 142 | " [ 0., 1., 0., 0., 0., -1., 0.],\n", 143 | " [ 0., 1., 0., 0., 0., -1., 0.],\n", 144 | " [ 0., 1., 0., 0., 0., -1., 0.],\n", 145 | " [ 0., 1., 0., 0., 0., -1., 0.],\n", 146 | " [ 0., 1., 0., 0., 0., -1., 0.]])" 147 | ] 148 | }, 149 | "execution_count": 13, 150 | 
"metadata": {}, 151 | "output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "Y = corr2d(X, K)\n", 156 | "Y" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 14, 162 | "id": "6467058c-8b8f-47fc-9d83-7d81a2d3c31a", 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "tensor([[0., 0., 0., 0., 0.],\n", 169 | " [0., 0., 0., 0., 0.],\n", 170 | " [0., 0., 0., 0., 0.],\n", 171 | " [0., 0., 0., 0., 0.],\n", 172 | " [0., 0., 0., 0., 0.],\n", 173 | " [0., 0., 0., 0., 0.],\n", 174 | " [0., 0., 0., 0., 0.],\n", 175 | " [0., 0., 0., 0., 0.]])" 176 | ] 177 | }, 178 | "execution_count": 14, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "corr2d(X.t(), K)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "id": "548cd656-004c-4bb3-82e4-57393d07ef36", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "epoch 2, loss 6.492\n", 198 | "epoch 4, loss 1.090\n", 199 | "epoch 6, loss 0.183\n", 200 | "epoch 8, loss 0.031\n", 201 | "epoch 10, loss 0.005\n" 202 | ] 203 | } 204 | ], 205 | "source": [ 206 | "# 卷积核\n", 207 | "# 构造一个二维卷积层,它具有1个输出通道和形状为(1,2)的卷积核\n", 208 | "conv2d = nn.Conv2d(1,1, kernel_size=(1, 2), bias=False)\n", 209 | "\n", 210 | "# 这个二维卷积层使用四维输入和输出格式(批量大小、通道、高度、宽度),\n", 211 | "# 其中批量大小和通道数都为1\n", 212 | "X = X.reshape((1, 1, 6, 8))\n", 213 | "Y = Y.reshape((1, 1, 6, 7))\n", 214 | "lr = 3e-2 # 学习率\n", 215 | "\n", 216 | "for i in range(10):\n", 217 | " Y_hat = conv2d(X)\n", 218 | " l = (Y_hat - Y) ** 2\n", 219 | " conv2d.zero_grad()\n", 220 | " l.sum().backward()\n", 221 | " # 迭代卷积核\n", 222 | " conv2d.weight.data[:] -= lr * conv2d.weight.grad\n", 223 | " if (i + 1) % 2 == 0:\n", 224 | " print(f'epoch {i+1}, loss {l.sum():.3f}')" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 10, 230 | "id": "a13bf7de-7af6-42eb-b5d8-54a23b0a8ab9", 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "tensor([[ 0.9880, -0.9855]])" 237 | ] 238 | }, 239 | "execution_count": 10, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "conv2d.weight.data.reshape((1, 2))" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "b0446144-e9cb-4111-87f5-7e1ce39b2024", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.6.5" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 5 278 | } 279 | -------------------------------------------------------------------------------- /85-Channel多输入和输出通道-卷积神经网络.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": { 5 | "c4539663-dfc2-40eb-bacc-f1ce9fc046c8.png": { 6 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWgAAACHCAYAAAA7gMM2AAASjUlEQVR4Ae2dwY8cRxWH27Pr9UqxsMU6SEjGCQdQ4rXkWAjw7mU5IFkKkWxuIA6JRODilZIrlpD3AlfvFXEgggPHRJEs5YJ2Dyi3iBy4cDL/QATJLQFl0K89z3TPVvdUT/Xr6e75Wmr1dFX1m5rfq/fN25ru2ixjQwEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQAEUQIGFCvwgy7KjuVaXsyz7d5Zlb8yVF0+PF9QX2/K6/wqcBrqIjwOirEnRK1mWyf/FTVz4OFBebGOvNZ7EFrYEBeSEP2RZ9jDLsoMsy17MskxOkLiqe6cGwuYsnJDggB5daoCW309m+9+yLHtSOFcd2/gVkJ/FBLGhyAWNhyouzCd684Cejl+29j+hMuR/zjJoZdFvz74hdZRj7s7q36p4awFd17MNVwEFloJRINbx3izz0diwIFMb7fpSZhu/AsYFZdDigs4FZyVs4oJ2jY1HBSkAdEGMtl4Ws2DBVg7RN6ScYruEV3koOHXNoqmQtvqKHR8F5EP5WH7UUefaFICWMb83K+OwPgpoKsPGg1ig18aE4rmNF6s3hezL3c7JoE2JBkd9MypbNsEltrIo7arTuYJUr0ObXStnsA1XAflYAaRxoE2ZkjImjYN3Z2NkVsVhDRQQbC171mttGhtFLihhUxvb1K44Hfav2Riy6TIAbUo1OEp0m2fUUcBVcEpsQdfAbYE7b1rTG3KUXTNfz/kwFFCg2VSXpjgEbI0B+VXZs/2lFforahifkF42UWCeC0rQLKMucqEuMVOdwb3Je9M2oICBuPgtWHRECNBymn2DGtADpinquQL6K0kQlr8FYPlaR80vWplArS9vtvVSwLigT60pMGXDRS7otW3F1yrTeRHQ4kXVX+Jmg2NAAQWofq3VblMWJrA5aB7QyrAsezaTgjVBbGoM52g+tgBTEFk2rU9hf1UN5xPR0zYUKHJBY8LGh47zY0bvNz+FMQ9oXTPPkTb6OWobyowkpHYFov68NQfYUd+CRWGVXWk6pPjtKJFUrj+DBG+24Sggv2mTv7Xp3HbLojVO2NZHgXkuaDzYl3aRCzZmpAyAdhwf9o1YhLGJXywTfPVDQNWfKpZZC/Rsw1LA/K1ey7/ys/1FpHMgPSx/ttFb44Js2fiwY7FMrwVou8FAR40fe8ZC55oeKSZ6bfRvbWyYIwzGcoJl1MqWDciaxliUIevblmAe3tCxwJOPtetP3OIWKivW83p8ChgX9Mk0PjQGVCYG2E0F9qkFaPGjahcTALSpxREFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFUAAFziig5f7s3597HF8duP35pTLPCEhBrQLe4wv/1MpP5ZAV0HqrWnt1sPtkMnHt+2Qy+XTIDl5x393H1zn8s2IX8/aeCuQB9Js/nkw99h/97H4Oz5OTk6nH/ujRo07sezpg5LZdx9ebD576f+Qa8vHWWIE8gN7/x3Tqsf/8V08DaOq0CfrK/r02s7/G4yP1o7uOLyUVs7/+UvvJ9SjQSwVcAwhA99LnXXbKdXwB6C5dyXutQgHXAALQq3Bpr97TdXwB6F75ms44KOAaQADawWPDMuk6vgD0sAYDvW2ugGsAAejmDhnZFa7jC0CPbLTwcc4o4BpAAPqM3utW4Dq+ALTLcHp7yLfdqu/et95ubGy876J8wKhrAAHogOLrVeQ6vgC0y2A62t3d/cLjtljZvHnz5vT69etfetl//fXXp5cuXXK1v729/aGL8gGjrgEEoAOKr1eR6/gC0C6D6Wh/f/9zr1tXDw4Opnt7e1962X/48OF0Z2fH1T6AjvSe3acc2bxxM7PvEgbrYRRAD8/PALqGFPoCSAV0k7UP3tKczW//dBq9//p3j6Pbvjp7kvD09HQauz9+/Di6rT1JGGtb7Z48eVIjf7kKQAfp4jq+fv+XJ9EPTZFBB/2TWgigyxgonaUCOs9YhjzJ7z3Bf/ny5ZLgdScA+kysu4+v575yGUCfkb3TAgBdA4VWAK3MwmMfy9oaNfqXqgD0GTDkgPYYW7Jpa2vELjuga2bJyJmOUrC0AgC6RIHySSuAjh3gTduN5Ue/suTVZwD6TJB3MqccOy4B9Bn/tFEAoKuRMAXQHSx+VKN/qQpAn4l3AH1GktEVAOgSBconABpA9zniAXSfvdNO3wB0mcmlMwANoNsJMx8rANpH1z5ZBdAlJJdPADSA7lOwzvcFQM8rMr5zAF1mcukMQAPoPoc8gO6zd8J9Y22NBf/S7+7duyUI150AaAAdDrN+lALofvihSS+Orn1r9wuvWyO/+dLN6Y0bN1z+pZ1+pNfaGnp2Qa89dtnX4+exG4AG0E2Cr+u2ALprxdPf7+ilW/ufx9662LTdje8dTPf392P51ridgHjlypXG18VeIPsAOlItu60tsnnjZk3tW/v0GBmNBQA9PFcC6BpStAFo17UP1nFtjePj42nMfnh4uA5PqrmOr2XW1ohdG+bNB8fr4J/UrwQA7QjoPGNhbY3MAvHMccm1Nc7YqdJ4tjZIapD09Xr38bXk2hrR/jk3mXzWV3F70i8A7Q1orwn+dV1bI/bHBlstryeB5tGNHNBe4ythbQ31K3Z/0UOYEdkE0N6AbjpxH9t+XdfWqPFXqWoN5qBzQMeOl6btbK2M2Ous/Yjg2IePAqBLUV0+SZ2Ddg0gAF121vwZgJ5GL/0ZgrABN1QXKrP2faDaiPoAoOcDu3AOoHt421zBP7UvATSAHgGoAXRNlANoAN3nGHf9C80y4lC2HCqz9n0WbIB9A9AAOqyAZaDh2vTSpvaXbT/AoIztMoCOVWq47QB0DWrIoMmg+xzaALrP3mmnb+6A3tvbq0FgWpUAurOzk2ak5mrZ39rair6tU7fjbmxsfGKucQ0gfiSs8dx0mj/7P/J/qeQ6vmzKIjSdESqz9jb4ObaigDugh/6o961bt6LX+dDaHRcuXPjIPOMaQAAaQOsLKATLNsoMuLG2rL0Nfo6tKACga8KcKQ6mOFqJMicjrgmAARdAO3kvziyABtBhBZr+KBe2Ul1q9mPW1VAbW1uj2mK5xuzHxcEgW3UCaNbWWOnYANDlsC6dkUF3kEFXraMRKt/Y2Cg5qO4EQLdzH3TID1VlrK3ROswBdE2QA+gOAC2Qxuy2tkaNv0pVALo1QCtTj91ZW6NdRgPoUlSXTwB0B4AuS159ZsCtblGusfbtxkuvrHUyxdGrT7x+nQHQ5bAunQFoAN1nJADoPnunnb4B6BKSyycAGkC3E2Y+VgC0j659sgqgy0wunQFoAN2nYJ3vC4CeV2R85wC6hOTyCY
AG0H0OeQDdZ++00zcAXWZy6QxAA+h2wszHCoD20bVPVj+ouqVxKOW6Ndazr88//3wJ2nUnAvr29vaH5mDXAOJR7zpXsBZH7BOAVe3sSUIbzBxXosDx7u7uf2NuQ12mzcsvvzzd3d2Nus11Gfta++LSpUuu9m/fvl0PgkItgC6Iseil3Qa3qJ3VL9t+JWHVzZu6JgAAuhsnLniXo/39/c8tBto+HhwcTIe+WJI+Q+wGoGOVKqw2F3vJAAB9L8sy7V1tADpdaWn4sGJ/I918sgUAXQOIVuagf/HgeOqx3/7hj/O5ndi1LJq2s7Uvml4X297s1+hfqjJAn56eTmN29WM295UcJQsMXM6y7J0ZnAXod7MsU5n3lgM6dq2Mpu3efPBMv4Msy7z2VT95qPeXjqH9FW8HRtgH0CUKlE9SAf225+R4F7Ynk4lBzuW45Noa0X2Z9T8iDpKbKNCPsiyTz7uCTg7oLsaB23ucm3yarHyaAfmq6ssHQJd52PhMAL1y5Urj62IvSAV0HkDK/Dz2+/fv56DysC2btvaFt/1YZ6gfAsXe/UdR+/V7T/VJi9+oq5UtvzfLoPVn8ccdZdDqnCASyv7aKMsTjNcenUw9dvlxBv4okZ0aSSd9sYZ2pjhig7Oi3SAAXdH35GIDaLKhCgMGxIrq5OKm9q39L0+m05hdUOkYAAroPgR1WyzLE4wYrZdpswL/tKVLl3aY4qghTSsZdI39pCoAXQ9qAJDMEQCdLGGyAQBdQ0kA3cMHVWKzNQCdDAcAnSxhsgEADaDDCtiUQrg2vbSpfWsPoJODPtYAgI5Vyq8dgK5BDRk0GbRf6PXfMoBevY8ANIAOK2AZa7g2vbSpfWtPBt0ZNQB0Z1JXvpE7oPf29tKDucKCMtydnZ2K2vRi2d/a2rKbAaKOm5ubn5ja+QBP70bYAj8S8iOhDTSnI4B2EraBWXdAD/1R751r34667Va3dV797p3puc2tj0x/AB3+bslLLSOuaVKqsvZk0Da83I8A2l3ihW8AoEsUKJ8og/7GrYOo227Fje+8/nC6cX5uNbuyyfbOyKDJoBeGd1oDAJ2mXxtXA+gaZALoDn4kjFlXQ21sbQ0y6KS415OHVY82z5e/pQd9Xnt0Gr3/9M9PorMZboOM8iOABtBhBWxKIVybXmr2Z0/7RU3wTzY2AEBUXAcb5RlxE72btt26eBn/BKVfuhBA16CGDLqDDLrp2hpk0EsHew5oZa4eu62tgX+W9k/oQgANoMMKWIYbrk0vNftNA7pp+9CoX9OyTuaU8U+rowtA16CGDLqDDLppQDdt32q4DNsYgB6e/wA0gA4rYBluuDa91Ow3BW7T9sOLSbceA2g3ad0MA+ga1JBBk0G7Rd4KDAPoFYie+JYAGkCHFbAMN1ybXmr2m2bETdsnBsiYLgfQw/MmgK5BDRk0GfTwQrq6xwC6Wpu+1nzQ9FbHvrXf3NyMuoV22X4/99WvRd/ayZOENd9281Vk0J0zAUB3LnnyGx6/8MIL/9FTwx77tWvXplevXnWxrf7euXNnevHiRVf7X7/+fQA9D9c2zgF0cvA2NQCgmyq2+vZMcdTAhikOpjhWH6Lt9QBAt6dlV5YAtDegY9eaaNru8PAwn9tpel1se1v7IrZ903Zmv+mPfrFrQ+wdHtvcV1fBtIr36d3aGvin1WEAoB0Bnf/b+mUnv/tw3WQyMci5HJdcWyO6L+cmk89aDZd+GcszYs9xsuTaGtH+yc6N2j9tjBYA7QjoPIA8Jvdl8969e3kgeNm/f/9+bj92rYym7a7fe2q/aQadZZl0jd1fbCNKemojH18e62rIZsLaGrG+Ubsx+6eNYQOgvQFdYz+pSmBW9uS1Nf0RLxa01k4QUP/tfNHR2rcx6kdiIwf0It2WrTe9Y6+39iPRti8fA0DXAI4fCRsANDaQrZ0FtJ0vOlr7vkROD/oBoHvgBOcuAGgAHVaADNo59NLNA+h0DftuAUCH8ZSXkkGTQfc5gAF0n73TTt8ANIAOK0AG3U6EOVoB0I7i9sQ0gA7jiQwaQPckRKu7AaCrtRlLjTug9/b2ahCYVqUpiJ2dnTQjNVfL/tbWVn6zQeztpuc2z//VBkceQDX2k6q4i8NkXtsjgB6/67sA9JdJIKq5eAZoV/sXLlz4e4PbbhUzz27tBNAn08rb6OyujEV3b1i9tR9/TEZ/QgAdLdVgGwLoBV8A29vbHy7rXQANoJcdOzHXAegYlYbdBkAD6LACXc1Bs3bD0gTpBND4Z2n/tHEhgA7jKS/VFAoZdE0WbNMPyxwLUxbRk/wjX1ujaUB3AujYH1/ydqyt0dSHi9oDaAAdVqCrDHrZCf5FI3sN6rsCtN4ndn/2A8wa6N/FRwTQYTyRQXcI6C4G+hjfoytAj1G7oXwmAA2gwwoA6N7HMIDuvYuSOwigw3gigwbQycHlbQBAeyu8evsAGkCHFQDQq4/OBT0A0AsEGkE1gA7jiQwaQPc+vAF0712U3EEADaDDCgDo5ODyNgCgvRVevf2jRrc5Zln0Latd2d3c3HTt0/nz/19bo6m78gAK4y+9dCxrcTQVlfbPFADQz6QY7Qvdtig/e+0/ybJM+1Dtq99L39qpi9NJXGEBQI82KGM/WD6+lnlIKOaawoNEsf2hHQoMSgEAXfMUIgBIHssAOllCDKyzAnkAnZ6eTj32w8PDfG7Hw7ZsHh8f5/Zj12Jo2m7v8Kn9dR4giZ89H19NdY9tj38SvcPlvVcgD6CuJuM93mcymbhO8LO2RtIY9h9frK2R5CAu7r8CrzhOwCtAXx24/aUn+Pvv+k566D2+8E8nbuRNulbgf9AzuafIqPpKAAAAAElFTkSuQmCC" 7 | }, 8 | "d6ff1699-f041-4865-9c55-ca4a7195765c.png": { 9 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAcQAAADDCAYAAAAcLWETAAAgAElEQVR4Ae2dXYge2ZnfS90zI3W3Wt2jkVpjaW1pZmfs+bTajL5GzHiExzAhsS3FGDxDwBKswYFkkC693gv1TbK+kyBkY8iFRALO3knkKjdhFDaQsGvQQm6yWYOWhQ27V5Ml2ZA1blX4lfrpPV1dH8+prvOq6n3/BcWpj1OnTv3rnOdXz6lTp7JMkxSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACoxQgfUsyy5V5PtGlmXsq5uuteyvO07b964A2pcn7pem9AqovqTXuOoMlPlTpR1XamxXKZpWsyyrsg/ox9w2oX2VzWk7bpT7Eap8sXe3BPyzikJoF4lhYP+qbVA4MQUebJ0JA/HB1nwvy7LrwfrEMjNjJ1J9eTo3nDLPgzv2xsr8zSzLKPe2LltUfW/QxWyGxcDmsw1bX7b/xLltEbMs29iabRP7mpwlize6kAv7X1vCIA7CIBAzExf9xw0Xz9PF/a24CtIrwP2g8j/cCtGfwso94D5ZwSXU1L8Cqi/9a9qWogGQMs8DH1C0co7tsmXCsgfZlvYs7Mem47gwY+MpwwZDe4DAdpShGALUNDa92FfVqmj7Rx0iBgUJcTCsXDwXbDMCsmzilS82bwBmOa7W96YAhZb7QeEODQD3x+6hpwlkb7mY7aNVXyZ7/ynXlHXKPA/qZojZZs2AKvPN9wSbgW5oSUgZRsvQxrMeQpF9NqE1s03ss/tg26Yi5KIQi4s1waz5wVxiE7Hqgs1DNI+yKo629asAhfZR8IBiBZn7dadUqPs9s1JTfXk6ZYCHcR68zWCb3cE2XVUrVetNMYBh59HSWpqw9TaZtrbOAwhN0syfbc22jv2ZSiACQ4woF2zGFPEQzkIuPHw6MMEIMcYUSntqDvdpOY0C5iHizVO47T5xD1hmG3E09a+A6kv/mnpSxFjTPIqRxt5wH2xmO9vM2HvSm6U42G+am83Oo6XZDPQ0+862cEJPm9lnx9g2NJ+6iYtCEMBG8wNPXgjHk4MJYIKVL57tJiIiywiXFep/nSc7dEb3sHMH+gNEYAgoWdfUvwKqL/1r6kmR8s5MubbWKOoCBp3QPEZPWrMWB5uArcZeoB16oaX1RTD7zjab0DkEnkHQ9rNvqm0MF2xARDSDHdtNMBPDQgww4tpUXrftCvtTANhRGK3wYgjw7rlnFPzQYPR3VqVUVkD1paxIunXKNLaJMo8twubw0G72hnXASPnXVK8A2vEwbbadEA0JKc9mU0jBtLbU2M9sk6Vh61MTcpEYVNrnzbNgnbZiRGG/CRZeNAa47BFSMBFYBTNUqt9l7gWTFV4MhW1jWR2ctgRKFKi+JBK2IVl76DYjTXk374Tlz0vGuiGpmdxlD8289+PBAe3QEhtv7wIp12ZTEMm0NsHYz2wTuofxbfvoQytsXCwXyRx6iLbNxAB2PJ0Rp2pCbEGxSpl+t4WFEY+RhxgeUOxp2e5rv2dVaqar6svky0JopLFDGHTzEgkNkpPP2bDPiH1gxpabHTct0c32hTaF5fCbZmy+vUqjqZV9YfxhK9Ahd2EFZ5nZwIgQJmSVZ1g+nUGxvF3r/SlghZGQmcJuE0abgq4pnQKqL+m0rUvZyjktIZRvQpsAJNvCemD7FD5RoAxEPEc0QzvKc2gz0Bpbz/aqmX3EmdoJscInCYwqQjGzj5DJwq3V2sAbrzYB7WhUwJ6GuWdVU932qrjaFq+A6ku8Zns9gjJvNqoqLdmcKlX+bhvaUW6Zylpaed7aXbQ0NenJPmstsWMUSgEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkgBSQAlJACkgBKSAFpIAUkAJSQApIASkgBaSAFJACUkAKSAEpIAWkwOgUOLVv39xfZ1mWp5rn5uYfp0qbdOfm5pKmPz8//+eju6vKsBSQAlJACkQrwP/c8k/+6c0k8zf+4dUCtKnS/9p7HxXp37x5M08xX736JP/RquoAKSAFpIAUGJ0CBRD//Z/keYr5n/2bzwpgpUibNH/rt28V6eeJps8+e5L/0d1VZVgKSAEpIAWiFRAQG2AqIEaXJx0gBaSAFBitAgKigDjawquMSwEpIAX6VEBAFBD7LE9KSwpIASkwWgUERAFxtIVXGZcCUkAK9KmAgCgg9lmelJYUkAJSYLQKRAGRXqNvn7/k7pEa08uUuGsnThW9Rj/87jXXOWJ6mT58+DC/fPlyfvr06fzGjRsNGPy7XepUM9pyrYxLASkgBaIVcAPx53/0eQEpvlv0fkYRA8Rv/+B6kS7nWTtxMr/+u3dazxMDxDt37myT7uTJk/mjR4+21+sWBMTo8qQDpIAUkAKjVcANRINgCiD+6//4KDfgch77kN/OWRfGADGE3gcffBCu1i4LiKMt18q4FJACUiBagUEAsQy88x9ezoFkeXt5PRaIeIVXrlwZcpPpjSzLPks4/2mWZQ8Tpk/eKVOapIAUkAKjU2BwQLx9/2H+w5/caoUhcIwFIu8R7927V7xHDJtQ61zESXuIzx1Y+C9Hj3/p8Ztn3t9MMS8sLeeLyyt5irRJc2FpeTPLso3R1QJlWApIASmw9TTvgo95ZymaTC1tmk1pLrX1tjAWiAY+oMjYp23T0wBizPW36VPe/5XTFx6/de4Dt77l49vWX1u/+P8ERNkVKSAFxqrAYDxEYBh6hniKbQa4KxDxDofqIQqIY61KyrcUkAJjVyAKiNZrlLANVuy3+G1xgeHLr68Xn3TwWQczf8poOy4GiHSkuXv3bn7//v381q1bbc5hsV8eYtyg7/IQx24OlH8pMNsKRAGxDVDl/V4glo/zrscAkfeHAO7zzz93wZBIAqKAONvmQVcvBWZLgZkBopuCQUQBUUCcLXOgq5UCs62AgBgAsLwoIAqIs20edPVSYLYUEBDLFAzWhw5EmqR/8i/vtb5rtSbomF6mjBTEu1yG0QsHTbC0qkK9Q5wt46GrlQLTpoCAGACwvDhkIDJwAQMYxPRK9QIRyAJEgEvnJs5TBcDyNgFx2syDrkcKzJYCAmKZgsH6kIEIjIBhCiACQoMd4PV+eyogzpbx0NVKgWlTQEAMAFhenFUgGgwJ+R5UHuK0VXtdjxSQAlUKCIhlCgbrAmJeDJbgGVcWeMpDrKpi2iYFpMBYFPiY5jDNrRpM5H4ylmlME2iqJlPzEGM77QiIEykmOokUkAKJFPgWMGR8yxTzS6+dLmBrhrvv8O0L3yjSZxSaFDM/E956WEgk/85khwREmkrDHqyenqYC4s
77qTUpIAXGpYCaTIMm0vLi0JtM+alyjEfp7WUKDPnk4pNPN4r5O1dvFD1OzXOsCwXEcVV+5VYKSIGdCgiIZQoG60MGIh4bTZphj9A6UNn2GCBa2hZaGk2hgLizcmlNCkiBcSkgIAYALC8OGYhNYKrb5wVi3fFt2wXEcVV+5VYKSIGdCgiIZQoG6wKihm7bWV20JgWkwDQrICAGACwvCogC4jRXfl2bFJACOxUQEMsUDNYFRAFxZ3XRmhSQAtOsQDQQ6czhHbmEDhl8ttD27sn202PSfhDMWJq2vS6M+R8inONfiBsbG/mdO3dy/o/YNu0BiA+6FJrYzy7qdKnbrneIXe6KjpECUmBWFIgGItDim8U6oxtujwEio6F4IBimHwNEYLi+vl789LcNhLa/AxCvZVl2N8syQlt2lyUB0S2VIkoBKSAFelcgCogACw8xBRD5qwLeId+8eT4CB4wxQLx+/XrOHDN1ACI3aD3LsvtbQIy6YQJilFyKLAWkgBToVQE3EMORS1IA0dIn7bUTp1xQjAHiyspKfvPmzQKKJ0+edHmKHYCIV3hjawaMeIvu6bn9B/6Ma7dm477DxYOH8qXl1aTpx16zWxxFlAJSQAokVsANxB/+5NZ2M2kKIIZNoaQfni/cFy57gfjo0ZNfGNFsynTv3r381KlTrc5iByDa7YoCoR00Nz//6Nnn9ueLyytJ5mzfvnz+mWeTpE2e5+bmeWfc6dpNA4VSQApIgaelgAuIwOnl19e3PYulQ0+8jBBOVcsx7xDD42ma9QxJ5gUi5KNzTziV18N9trwHIHa6n6mbTJdXX3js0TW8FzHL+jC/022f2EFXsiy7mXD+d1mW/asRp482NOtoml0FXEAsG8XUHiJGG5iWz1tejwEiA3XjKTLhKdJs2jYJiPrsYmpMw8Li8h8eXjvxq6+cPv83Keb9C4uPF5aWN1OkTZpLBw8+Pn78+Ob58+f/JsW8uLj4qyzLNqbmhutCuigwCCDSw3RpeSW//tO7+e/83n2XdwgcY4BIM+nVq1eLzy3oXMN62yQgCohdKtUgj1k8uPKLlM0Dr7z5zqb3Sbn8ZOtZf+HI2iadAFJN586d+98C4iCL7iQz1QmInvJLnJgmU6BIfO/PaEk/BojUIzxDIGfvEtvqloAoIE6yMiY9l4DYXN0FxKTFbyyJDwaIXsiG8WKB2Fwjdu8VEAXEsVTk1nwKiLsreLhFQGwtQrMQQUAMK0VpWUAUEKfGCAiIpdpdWhUQp6ao7+VCBMRSvQhXhwxE/mbPd4offvea65tNPOsuvUzp8cs3oqFnXresXqZ7qYqJjxUQw6q9e1lATFwAx5G8gLi7amxvGSoQAZT1wv32D67nzHWQCrfHApH3ufSTsHOFaVUtC4gDrvQC4na9rlwQEAdceCeXNQGxsnY82ThUIIZD2+HBeQYxAGCxQCRd7ycwpC8gTq7iRp8pFog8BXmfhLj5Mb1Mad745NON/Ic/ue16kiP9mF6mVFxG8b99+3ZD9d65S0CMLlLTeICAuLNa7FgbKhCxD8x4cBe+eSVJkykwBLwC4pRU+xggAsK1Eyfd3z/FABEYUrgIX3rttLt5wwtEfmNjv7O5fPlyfuvWrR2Vum5FQJySgr63yxAQ6ypInhefaDCizd4k9h8dM1KNwQqbwrtEg2RT6PUQwyZZAdF//wYdMwaIFCLaymO+W/R6iIDQCing9X676AVi+E0V3y0KiIMulkPL3MdbBr8Y2myMy3Nzc0nzPj9fjM85kfsWA0SzKYCRQQ1svSn0ArGwhZ9uFK1a1nHH03qmJtOJFJNuJxkKEMMCSqHy/lzVC0R7wAWMMb+3GamHmNT4TQgIeGVDmb7FNX/wwQdJZoZLS5n+Sy+9VKTPg2CKmZFttsrERO5XFyBiX7wP2V4gYqds5rdctHB5BkwQECdSTLZPEjU26TPP7f+L8EmHd3hN88rho/lvvvlOY5zw+MNHv/C4aGZtSdeOef/vf8x7x/zvffyPXedYWFx6jKHi3WDb/OMf/7gwaAcPHiwMg0GyKRwrEKmcVln7DjF+/+Af/ZOk6WdZNiQgFk2mTeVkL/vsHdxe0mg6ltYQ7lmqyfK/bYESL3iBSLlnmDtgSAsUHWvCB++6ZS8Qw+PVZJr4pndNPnZs0lffPrv52tcuFk9PgLFt3r+wlPO/sLZ4tv/LXz2ff/n0eXf8gweX8xMnfiN/7+sfuOazZ8/lFy9edD+58783PERCzzRWIGIMwgrb5zLGldFP+kwzTGvL2xAQPQXUEWdWgcg7PkDl9dysDHYBIp5h2KvV0qoK5SF2pVuH42KbQKtuWNO2V946U8CtKc5e9h1Ze9HtvTlswa4oeJM0G3lG8edgAXH3KBwC4q5itacN5mHtKZGGg2cViF3tUBcgxpxrYEA8tfUAas3eScP5+WeSpl+6lhvZ6IF49FhyIPLOgx6nnklAFBC3mm89xaVTHAEx7snf22QaA6kw7owBsXgdkPIVi72y4T3r6uEXNinvqec33njjb4ufMqQG4qtvn03rISYG4nvvvVf83sZruQREAVFAbK4tBvQ4rHWPLSB2167iyAKIKV+x2MMGzdeHj6xtNpemfvZeuHDh/04GiF89lxSIR49NpsnUK7uAKCAKiM21RUDcXUcMAlXhwJpMBcSqm+Td9qqAWPGQNbhNUaMJce+9HQKIG/sOMfbpc+s9wCg71dAcf+nSpe3Z84mPAaUZO0/2MsAEadLkTxh+X1t3fMw7RNLnHTvpE3rSt/xPqhbIQ+xVaQHRC7+qeAJir4UxVWJuIAIrPouJgZYXiPTyo7cxo4QsHVp1/w1grEB89OjRjj/M87d5YNQ2GVDa4rGfTmE2AS1P+jFApLOZQdC+WbTz1YWW/1SFuZyugFhWZE/rAqKBjq7CMd/ncFwsEDG0McY2pskUgxM7Nqn1Mq2r3OXts9BkCrRi7pEXiOGIRvxpgJfqVvaawrECsVx28BaBZNtkQGmLx/5w2EE8RI5tm7xAxDtEe5tIOwSwbS+Hlv89meWIgwXECLHaowqIGCNgaKPB8CTvNVYxQMTIxo596gWiPX0TMtKHp2mKiiwg7n7fkQqIYTMsD17eMjYtQPTAhDJpQCmDpmodwPIN7ZUrV3Z4o1VxbZsXiKSN9uYhki/P97qW/3bb208MgLh24lQxNilDpvU97z+wmKdMf/Hgoc2iw0c/cuw1FQERIGKc6GprT+o0a3mGHYoBImljbENPwc5XF8YA0So8FdJrfEYKxPUsyz7bNz//n1zz3Fz+8utfcxsKBlpYPfKiO/6+ubn8CydfdcfHYDE4g9cLnQYg2gObldGm0IDSFMf2ASse/ijH6+vr2/Cy/VWhF4gci1cLbB88eFAse+qV5X+vljni+BtZlj1IOP8yy7I/Tpg+eQdEQ5gERANVaKAAl2doo6EAMaz4VEiakjzTSIG4sbC0vMk9SjHjxceMPhSbB9L/jZdf2374qnsYsu3TAERvcyll1oDiKb+Uc2uG5RzMbVMMEEmLplNm0ubYtsnyPwTrrjxEK+ACIs6SeeLWs
mj1lX04PbQCha1Ctt9Cz2cXlLewU5rZdR4Er127VuzjNVnbFP3ZRbkjBUbO48nFfofoTddEe6HDd4i8/KcCe6Z33313W9S2sU/Zf/z48V9nWfaz6GLW7wEbX/7quf9jGvUdct9XjxxzAyv2/KRPOfAeNw1A9HhWVl4NKLbeFIYjLHkfBGOByPlJm1cR1nzalCfLf79FXqlNSAEXEKnDOFDMvGKzuowTRd1uAqHFJY2m7xCt9YPyxEynMXvgo1WEB0HsPHWg7UEtGohcRNlDJMOW+bowdui2eCCuRY1Ug0AI553OnDmTLy4uPj569OimZyZulmW/P6HCWXeaqQAiFcnTLD92INJcykOadzKgeOLb5xAYB87BsW1TDBAxSvxAGwPkgSHntvzXFV5tH7QCrUAEdrxT5cftYf1lmd7j4bY6brDdA8SwLFNuqUtMocPj6f0cDcSqd4gh+esu7MuR3yHGAtH7DtFEMsFY91TgsTaZpvYQV15I6yEeWFgqOljVlatw+9iBCKw8ZbGo6QFQbL0tJH0g5D1HDBBj0rV8CoiDBl5b5lqBCPCAGbY8/HyKbazTF+UJT243OlRtQLTyZGFdKwsPa6Hdt/hhGA1E4GdtwSzTqSY0SnXLQwEiTwy0NVuz540bN1xPywLi7l6mFNTnj7zouv915aJpu1WmpjjhvrEDMayYnmUDiidulzgxQOySvuW/zfJq/yAVaAViWDeBH94i2wAkn1Ox/MSLPLmjOTU8juUYIPLQZ82lYZnkIRAgtk3RQCSDgJA24LaXoeGFxQCR9BGBc3jdaq+HaE1H5j57m6gERAGxg1kqjEZbJey634DS9fi24wTEDnd8dg6JAiIs4IGVEIcK+258KK/bdgtjgBg2l4blm+2elpFOQLSMxoQxQIxJ1+J6gRiKFLMsIAqIHWydgNhQyQzoHXTVIU9fgWggWmsi3qJ5iNhvgIhzZba8HMYAsaq5NIQhUGwCo4DYUGHDXdMORJouKKjm+ZcLZdU6BTWmyZTOWOGTYVWa4Tbi0rwSbmtaVpNpWGL3viwP8elTZ8A5aAUi9feTTzeK+ottoeWP+outAY6EzG11nHSaeplaSae51D63sG00k/JazF6RlfdbPAsFRFOiJZx2IFoBpcBSQK3wNgGIguoFIk3fVIq2wh+eT0BsLpTmYTXH6r5XQBwwjp5+1lqBaJ1qeMgGfGHdZh174Hnt5gUi/UOoE+HEejiHvU7DeLY8NUA8spb+B8E27FWTy23CDmQsU9dnFxTKEFQUVM9waTFApDLgIYbnCStI1bKAaKWpOhQQnz4VZjgHrUCsqtNdtnmBWF1L4rZODIixI9XECpcaiGfPns1/9KMfudUdExDL4PGCS0BsNYeF0XAXmsiIAmKr/oqQTgEBMRZSYfzYkWrCYz3LR9bS/iD4/fffj/poemxAtE9p0Bog2gvwJu0FxFZrUxiNrXebRS87LWdVOrQKqQiDU0BAbDKObftS9zJ94WjcSDWRD+M5Q7dduHBhx8gHTWmMCYjld3sxHuLK4bUd7waayoE3XUsD4PJhfiREGMB5KNNqlmXkZyPR/NMsy36eKG3ynDp9znFtKDdL+YhSYLqBuHDw0C8ZhJVeQSnmw2vH8wOLy/lb5y4lmZeXD+344N56FfUVfvGLX8y/973vFb+1oTdT2zQmIPLymzFqQxABI1uvC4mzGjFSTRcgMrj3b/32Ldc8wF6mURZGkaXAiBSYbiA+u//AXy0trz5eOXx00znnKy+sFYM7M8Bz3VzEeeFYvn9hKX9u/4EcT/Erp8+3zq++fa74qbAnLnEOHDiQr66u5kePHs2PHTvWOhNvbW2tNZ7FWVpayr///e8Xv8/xjIE6JiACPDrRACzrBl3uFVYFxdgmUzrvvPz6eito7VykH9MJR0AckTlVVseuwHQDcfHgyi88XoEZq9gw9TvESX2Yz+cXbePh4T2ODYjcT4Do6QZt9z4GiHihpG+zpdEUCohjt5kzlX+af4HErEwCYpPxatuX+h3ipIDo/bVND0A81UMFc3120Xbv6vbHALEujabtAuKs2NapuE4B8U92j2bVVL+9+7ADng/z215jefZP7LOL1EBM3cv04sWLRZOp5/1hDx4iHTEAIuGVreUuVkNA9KvW9Cd1fyqKOasKDBmIdFyqK99dOzUVHiK/dvrn//ZB0vnD717LVw8fefzgwYM89fzGG2/8Ksuy21nqJtPUQJyUh+h5yugBiBRSQEglA4r0VuwyCYhdVNMxUsCnAPX0s635UZZlD4P1rqDxnfnpx3p/39xcbA/wzvHnn3mm87GRvdQ5z38QEFtIN+Gh2wAhQASGt2fZQ/zK+rvuTjh77FSD5jdr5qdvfpSDoSswZA8Rb66ubHd971l4iAwOkXrij0Rra2ubqc9D+lPTZDplHuJ6lmXMVDKaTgfrIcZ8h+h9Z2DxeHdwcOXwpIBIBa+bh26Mlb+nr8CQgWj9EarKN/u6TAKiGSp6CvKtIu3Htq0t7NJkynnondiWNvtjgMh7QPs+ccD/QwSCXQurFfDkTaYrEd8heu5jGCe2084ePUTTTKEU6KLAkIHY5XrajhEQMVb2c2BChvtilJPQiNUtxwLRPhYHinVphttjgBj+HNjzCQWu9ISbTNsKo3f/BIDoH6kmvF+eZQHRe5sVbwAK7KUlZwDZj86CgIgRCz/YxmClAqINJ9Y3EPlTxaVLl4reSjFt0iMF4s8WlpbzN8983T+f/Xr+pnNeO34yL9J3xi/y4Y179uvF6DkHV553PRBRNgfoIZ7at2/urzu81LdraQ3n5uYfp01/Lmn68/Pzfx5tinXAEBQQEMOnesAY/vU43Fe1HOMhAkMbMaVvIPI/LOC2srKSr6+vN/49OQTmWIE4N/9MvrS8kmSem5vP9+2bS5I2eX72uf0FcKvKU9W2AQKxMBo8OKaYGV2Ia06RNml+7b2PivTDFpU+l69efZL/IVh35SFaAQHRjBCwosJgtAhte1PoHamGpliDIMN22XJT2uyL/f0TniIVsu0PygbFc+fO5adOnSq8SzzMtnl5efnXWZb9LLqY9XtA0iZT7g8POm33put+ytbiwUM5Y+x6flg8VCB2vf6246gbXHNbvK77GUOW9FNN9vuqfou8UpuQAgJiuWLxng8olrdXrb/y1hnXuJQ89dqHnox7SccdT8eaLn+7AIp4ip6J/yEeP3788dmzZzc98+LiIk1Nvz+hwll3muRA5D+XVfe7j20A8fkjL7rTFxD7HSlEQKyrVtq+1SN71x/qQ1ta9SP1qm12DPuq9tMq0eWzi6q07Fx14Z4/uwj/ktBkBL1NphhBm0nbBp1uSpt9MZ1qQjFoCvVMI20yTQ/EtwXEBvNYPEW3ld2u++UhNiivXakVqPUQ8fyr7OX169eLcaABXNUPEmitq/qusQmIdIqk5S6c7PwnT57c8VoMQNIqyH7CKmBGA5HBn635igp5/XfvuJ7gvUAMjUNMk6kXiNwIE/3WrVuugboRu+oGhzehvNzDWKZ9FOjkQEzdZDryv10IiOWKEaxTD7e8
+j7KutKYrAK1QOQWAz9AZhN2FwjZVB4TGrABMLPNFo+wCYjsLzfrY9dtArKWD5btqwLOU/W6rBMQ+dwCGDKHAGta7gJEYOtpLuW8XiAiiAlDBxvvJCDubo6bxDtEAXG37lbP5CFOlgA62w4FGoEIhAxE2Fhsbggq7Kl5iXwXDqDY1gcQQ88vzEcZnOV18hkNRKuMsWEXIMacwwtELwDL8QTE3YZZQNxhIKpW5CGWK1KwjvGTh1hVbEaxLQqIZdjhLRowLSzHsaLC/qZ3iFVgs2M5D8DF+SnHY70MYAHRlGsJpx2IeBv05mSmWVB/B8oAABB2SURBVNzzMBIDRJrZL3zzSpG+t2WBd8mz4iFaz20+OfJ+2xvrIdo99txb4sR0qsGw8C4HI3Pt2rWW2vRkt4A4CvDVZXJPQLQmVbxG8+j6BiIQNC9UQHRVSX+kaQYizdL27SchRs3TVO0FIsbeIAsYvR2xZgmIaGnv5nlVwLW3gSsGiNwD0uPetqVr+2OAiIFjwrjxLsgMUVMNExDrWDOK7VFApMk0LBPYU2BIWbFP2FZXV4tOMGE8yk8XD5FyaJ6nlUHKfjiV19k3MQ/R+x2iVcbY8Ojai7sECC9+r8vvvfde/tFHH7mTGVOnGjPEpvlLr512vR+OAaKlDWh5B23rTeGsABFNQlCx7vmcKQaIpnN4HttWF3qBSJOUPeVTQTBEZWNUVXEExFGAry6TUUAEcmGnGkAYlhnKR18eIumG7yvxDtlmfUc4F2WP9fI0PUA8lhaI3CxuqD0Jl4Usr48JiGWDCBDL26rWvUC0Y/FS+NlnGcC2vxzOChABWxmAHnANBYjlso+hAZJt01MA4v3gf4X2H8M+w/+ZZdkfJDzHvT38+aYObF231wIRADEK2JUrV3aUAx6S6NQIrMpeIGWlCxBJj7py9+7dbcDSg5VtNpMuE1DEhvOT4boyOjEgTkunGkT2VPaxAhEj6/2UJhaIpM13pRh/a8IrQzBcnyUgrp04teMhhHIWalG1PEQgYgzDp/MmKE4YiIUB52GPctv3TLrcsxe/9Ju9p01eLf2tD+K7QqzP42qB2HTPy15hU1zbB0ibOtVYvD7CqQFi7NBtseLxlMGN4emDitw2jRGINNV53l2ZcaaiMgKRrXtD72AL5OXAwtL2k5498bWEVNShTIXRaNMFjzkEIA8L4Xrd8UMDIsbO01RqdedpABHN6vTcy3a7FzQz7yWdumMt/bED0e59TPhUgLhw8NAv6WHIfw5TzIfXjufMKdImzcWlpaixRu1FrjfkhS+95winEYgY4bB3o8eD2wLi47pKXLcd0HmaTYm3sLT83xt+3AtwyvNQYEg+XEBEJzoa8UDCMh2QPO9ZzUjW6Vy13QNaO877DhHjBgxDzxBPsW0SEHd/ymTal0O71wJiW6na2/5tD/HAgQN/dfDQ6uPDR45uOuec8UPxzDwzT/rPPPOsKy7pHT6yVsyetMnH0tISY43mjDl6/vz51pl4DNjtiUs80n/zzTfdY5+OyUMEfuHDEO/5PJ7iFhA3y5W3vF54np9uFM2kgND7lxTysLh06I+GRLjIvLiBiC54zmgFDA2OZS3DdTOS4bamZYtP2BTP9nmBCAx5ZxQ+XIYdKOpMlIAoINaVjXD7U/EQV1ZWfhHT3BFm2LMMgOzlpid+bJwXX0zfqQZ4ejUaExAxvhjJcI7wEFuBaN6n1zM0gzxLQLRr9sKK+MSN8fjsHN7QC8TYumrxBUQB0cpCUziVQDxz5syogXjx4sXi/WHTjQv3jQmIXgNZjuf1EMvHeddnEYhebYgnIEb564W3HvPA0eVe6B1iaAX7WZ5KIKb2EI8dO+b23rrcJr5D9HqHpC8g+p9+6wyPgNisoYAoIEYp0G/k4gHD05+ii70Nj5lKIPIebuxNpgLiTgMtD7HVwrjfIdY9FDRtFxBb9Q8jRHmIfHrEHHY089wLr4fIawQ6AzalGe6ze61ONSEq+1/e7lST+h2igBjWzYktJ//90ytvnWl9hxhW7JhleYg7H0DK2pmRLG/va31W3yFS7uxbXELW2zS1e+EBIjAsf2rjTV9A7B+CYYoCYqhGw7J9h9gQZccuNZk2G/M2A8B+AbFZQzPCHi27xJlVIDJwhPXyBV7lUYSqtLR74QGiHR/TIcrSHxoQy72Lw57GfS0zaPyhQ4d29GLuK+1yOocOHdrMsmwji/UQ6WrNcDkMncNy2+T1EBkFZmNjY8fsGRkmppcp7d7M3rxzbQLibuOsJtNWT19Npg2GYYi9TAFgGVSsGyANZuXQgDVDQKTw382y7MEE5l9kWfbfJnAeu5ZLUUDko1tGbPF8fGv1wQtE+y0I7+uYGQTWM3mBSJ5tPFJAWzXAa9X5BEQBsRV/uyMIiFWVaWvbEIFoYAuBBxDZHm4rL9txMwbE3SV+WrZ4PUS8QVzYGBhS/r1ADL1NgOX5wJf0vUBkNA0DIscBds80ViAuLC3/+s0zX99MMdOUtLC0nKdImzTXjp98vDDuD/M/xphqbtVgEmbU3amG+xUCj3V5iJO4RQM6hxeIAAWvjWbNGzduuMHoBWIIJ85Fs6Zn8gIR4K6srOS3b98uPFAv2EcKxPUsyxjh35oC+g4n0ZRxY0DVJDYr38KY0rScYrYBn3nXmmJ++8I3CphT9lPMwV8JYnXtEt8NRHS1oQUBoef/nfIQu9ySAR/jBSJNjNbMCEyAi+cdX5fvEKmE3inmO0QgS755IRx6pE3nevfdd4uXuuX3m3XrJ06c+Nvi5eyA77myllwBNZk2VKohNpniGTKWrA0vyIOG9TgNvcbyciwQ7V1lm+dp57H0B9SpJnnleaon8AIRSFGQbWK96t9Wtt/Cd955J+o7xJjmUs6xtrbm+nA+TJfmWKDomd55553N55577i+Wl5f/0DMvLi7+1yzLrjzVm6qTP20FBMSGyjVUIAIhPERAaJ6igakuNGB53yES32bgWJeubbf0BcQJVWkvEIFI+IE6QPQ0awKUGI+P5lIPaK2+8b+sMF+2vRxaZx3bjqfo8RIH8hnFhEqDTtOTAgKiVbSKcMhANBB5QwOWF4jedC2epS8g9lQz25LxAhEPK+yIwrIHKLFA9KZr9cwLROAddtTxQlpAbCtB2l+hgIBoFbQiFBB399w2AJZDAbGidqXc5AUi5ZqCbB1ePO8POSYWiDHeIel7gUhcoGgeqDf/AmLK0je1aQuIVLiaSUAUEAdb82OAWFO+GzfHArExsYqdMUCsOLx1k4A42KI75IwJiA01S0AUEAdbeQXEhpo7nL9XDLb8KGOVCriBSOcNfgxMl//vXL3R2smCJjVrRis3rzWt04GD8zTFsX28D+OzEe/E+3kbCsvTwrMHIPL5UOxU3As0s+vrM7R7oXeIsbdloPEFxOZqLw9xoAV32NlyAzHs2s93b57u+GaEYww7nxHwTaTnmBgg8urBA8GwlnUA4rWt4cIIbdlbAgREr1KKl0UN3RYWau+ymkxVymZQATcQQ0B5gRULRKCLh+hNPwaIdFTDO2SwDk8nO+xGByBShGywCYA
YMwmIMWrNelx5iM1ol4c46zWk0/VHARGv8MI3ryRpMqVJlg/OAW8KIDJIB53V6LXN0I4eKHYAIhBk5CJmwMjg0t6puBcvv76ev33+Uu8z6dK8/IUvvdJ72uTX0s+yjNGPNKVWQEAUEFOXsRlMPwqIBi3eI4ZNqKH3GC7HeIjhj25TADGsPUCRXtxtUwcgWhGKAaEdU9yLA4tLxe+cGIe3z/m5/QcKIO5fWOw1Xcsj+Qa4AqLdzsShgNhcfeUhJi6A05l8FBANdnhyvOuz9brQC0RgGHpGS4dWCy+mLl3bHtNkGtYe3iV6BsnYAxC7lJbiXqCZXV+foWmVKn271/owv8ut73AMQOQdQN3YnHvdfvz48cc0pew1nbrjn3/++ccp86+xSTsUKh3SCYh4h317iKHxT+0hAkNg1zYJiPrsYrAm4tlnn/0XnjE6u8ZZXFz8y/379/+Prse3HXfgwIG/PHjw4J+2xeu6X2OTDrboDjljbiACqes/vZv/zu/dz8PmzRBk5WXzGsrb29b7BiI9TBkCkR+G379/3+UdAksBUUAccuVV3qSAFOhXATcQeX8I4DwDPRvwugLRjm8LrRmwzdNjP1AEcN6RnwREPwy5T3av1WTabwVValJACkxOATcQ2+BUtd+MZNW+PrbFANEDzXIceYh+KNq9FhAnV3l1JikgBfpVQEAsUzBYFxAFxH6rm1KTAlJgyAoIiAEAy4tDBSLN1h9+91rRE/eTTzdcvVLNm8aTa/POu6QvD3HI1Vx5kwJSwKOAgFimYLA+VCB++wfXi6HzeK/LMHqeTk4xQLQRgwAj6XOeNogKiJ7qpjhSQAoMWQEBMQBgeXGoQAwBxfegnm9CY4AYdpxioPWYcWv1DnHI1V15kwJSoEkBAbFMwWB9qEAMvTW8RRvyLtxeXo4Boh2L5+mBLfHlITZVM+2TAlJgDAoIiAEAy4tDByKeHEA0gDWFsUDEKwSGDM0WMwiDPMQxVHvlUQpIgSoFPt4af9LGoRxdODc3lzTP8/PzpD+JqXg48XR6MfDhwYXNm7a9KowFoqWB9+n5P6U8xEkUEZ1DCkiBlArwZ4LiDxAMeN33fPr06QJWfadr6b300ktF+gzFlmLml1FbDwwp74GlHQXEEIZAsQ2MXYEI6DzNpgKi3UaFUkAKjFWBwgiXmwr7Wrcmx77SK6fDHysAVqrJ8j+hm+sGIs2k37l6I+eTC2aPBxcDRNI3TxUYtsEWb1JAnFAp0WmkgBRIpoCA2EDToQIR+IRz2OvUmjrLYQwQ7W8mns46dh4BMVkdVcJSQApMSAEBcYRANAjFhDFAjEnX4gqIE6qxOo0UkALJFBAQBURXL1UDX10oICaro0pYCkiBCSkgIAqIAuKEKptOIwWkwLAVEBAFRAFx2HVUuZMCUmBCCkQB8fPPP883NjbyO3fu5A8fPmxAyZNd1imlNWKe53zicOnSpe35+vXrrYfF9DIlv6RJ3gm5lrbJ8j/Je0HTY12z5F626x3ihO6iTiMFpMBoFXADEYCsr68XP9ltA4ntN6DYel3IT3vv3bu3vZtlYNc2xQCRbxdtAoqe9C3/E7q7xb0QECektk4jBaSAFCgp4AYiXpXHazPoEBpQwm2eZbxFz5/tY4B4+fLlbQhyHeStbbL8lzRLtSogplJW6UoBKSAFHAq4gbiyslKMBgNMTp48GQWUNvCU94feXHlfuB4DRADLNVy5cmWHNxqmV15+GkB8+fX14j+Hb5+/1Gv4hS+9UgxikCp90t0a1YcypUkKSAEpMDoFXEAEJhg7e+9Gk+apU6fK/Ni1bkDZtaNhg7e5lCRigEjegTmwpenXrqUhK9se7gTv6t0syx4knH+ZZdl/Tpg++V+doF46lRSQAlKgNwVcQAQa5SHSyutVYOkCRG9zaSwQaTK1ZljOwdw2Wf57U1sJSQEpIAWkwGAVcAORgboNKHhXNJu2TQaUtnjhfm9zKcfEeIhhfskXgGybLP+DvXvKmBSQAlJACvSmgBuINGXiVdnnC6y3TQaUtni2nzT5a4V3igEiPUtJG6gTkre2yfLfm9pKSApIASkgBQargBuIwAPPEEh43r8R34DSBh7bD6y8aXNMDBCJT/pd8j/Yu6eMSQEpIAWkQG8KRAHRwOUNY4HoTdfixQLRjvOGlv/e1FZCUkAKSAEpMFgFBMQGOgqIgy23ypgUkAJSoHcFBEQBsfdCpQSlgBSQAmNUQEAUEMdYbpVnKSAFpEDvCgiIAmLvhUoJSgEpIAXGqICAKCCOsdwqz1JACkiB3hUQEAXE3guVEpQCUkAKjFGBAohbgzLb4MwKs6yswRjvrfIsBaSAFJACEQowEPONLMs2Es0/zbLs54nSJs+p0+cc1yL0VFQpMBEF/j+GdQLU1BRMXwAAAABJRU5ErkJggg==" 10 | } 11 | }, 12 | "cell_type": "markdown", 13 | "id": "c6653299-6f89-4a2b-be05-1b434e3e7895", 14 | "metadata": {}, 15 | "source": [ 16 | "### 多输入通道\n", 17 | "![image.png](attachment:d6ff1699-f041-4865-9c55-ca4a7195765c.png)\n", 18 | "$$(1\\times1+2\\times2+4\\times3+5\\times4)+(0\\times0+1\\times1+3\\times2+4\\times3)=56$$\n", 19 | "\n", 20 | "### 多输出通道\n", 21 | "\n", 22 | "\n", 23 | "### 卷积层\n", 24 | "![image.png](attachment:c4539663-dfc2-40eb-bacc-f1ce9fc046c8.png)\n", 25 | "\n", 26 | "\n", 27 | "\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 28, 33 | "id": "2f29fabd-8095-4d90-8c20-2768c5ab8b2b", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import torch\n", 38 | "import dltools\n", 39 | "\n", 40 | "# 多输入通道\n", 41 | "def corr2d_multi_in(X, K):\n", 42 | " # 先遍历“X”和“K”的第0个维度(通道维度),再把它们加在一起\n", 43 | " return sum(dltools.corr2d(x, k) for x, k in zip(X, K))" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 29, 49 | "id": "6b293461-dd02-4ed1-8a32-7fb3b271728b", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | 
"tensor([[ 56., 72.],\n", 56 | " [104., 120.]])" 57 | ] 58 | }, 59 | "execution_count": 29, 60 | "metadata": {}, 61 | "output_type": "execute_result" 62 | } 63 | ], 64 | "source": [ 65 | "X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],\n", 66 | " [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])\n", 67 | "K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])\n", 68 | "\n", 69 | "corr2d_multi_in(X, K)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 30, 75 | "id": "e68b7f4a-fbbb-4d15-bfbd-4ae8b27c37ee", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# 多输出通道\n", 80 | "def corr2d_multi_in_out(X, K):\n", 81 | " # 迭代“K”的第0个维度,每次都对输入“X”执行互相关运算。\n", 82 | " # 最后将所有结果都叠加在一起\n", 83 | " return torch.stack([corr2d_multi_in(X, k) for k in K], 0)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 31, 89 | "id": "722942ce-f37b-4773-880c-e41a9d6ce8c3", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "text/plain": [ 95 | "torch.Size([3, 2, 2, 2])" 96 | ] 97 | }, 98 | "execution_count": 31, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "K = torch.stack((K, K + 1, K + 2), 0)\n", 105 | "K.shape" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 32, 111 | "id": "c2321447-846e-4b01-92bc-b5922ea4e110", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "tensor([[[ 56., 72.],\n", 118 | " [104., 120.]],\n", 119 | "\n", 120 | " [[ 76., 100.],\n", 121 | " [148., 172.]],\n", 122 | "\n", 123 | " [[ 96., 128.],\n", 124 | " [192., 224.]]])" 125 | ] 126 | }, 127 | "execution_count": 32, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "corr2d_multi_in_out(X, K)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 36, 139 | "id": "26b8b628-f7e9-4820-a3ac-f9e7ee1cc778", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# 卷积层\n", 144 | "def corr2d_multi_in_out_1x1(X, K):\n", 145 | " c_i, h, w = X.shape\n", 146 | " c_o = K.shape[0]\n", 147 | " X = X.reshape((c_i, h * w))\n", 148 | " K = K.reshape((c_o, c_i))\n", 149 | " # 全连接层中的矩阵乘法\n", 150 | " Y = torch.matmul(K, X)\n", 151 | " return Y.reshape((c_o, h, w))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 39, 157 | "id": "807f3bc7-79a5-4150-978b-b7f64ffdfff9", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "X = torch.normal(0, 1, (3, 3, 3))\n", 162 | "K = torch.normal(0, 1, (2, 3, 1, 1))\n", 163 | "\n", 164 | "Y1 = corr2d_multi_in_out_1x1(X, K)\n", 165 | "Y2 = corr2d_multi_in_out(X, K)\n", 166 | "assert float(torch.abs(Y1 - Y2).sum()) < 1e-6" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "87b8b116-b57c-4eef-9870-0802224d8c38", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.6.5" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 5 199 | } 200 | 
-------------------------------------------------------------------------------- /86-池化层-卷积神经网络.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": { 5 | "bbbefdf2-760d-45e0-948c-323410b6ad27.png": { 6 | "image/png": "<此处省略 base64 图像数据:二维最大池化(2×2 窗口)运算示意图>" 7 | } 8 | }, 9 | "cell_type": "markdown", 10 | "id": "df5cec76-1a44-4d1b-94a5-093c3ebee1d5", 11 | "metadata": {}, 12 | "source": [ 13 | "\n", 14 | "* 1. 二维最大池化层和平均池化层<br>
\n", 15 | "![image.png](attachment:bbbefdf2-760d-45e0-948c-323410b6ad27.png)
\n", 16 | "$\\begin{split}\\max(0,1,3,4)=4,\\\\\n", 17 | "\\max(1,2,4,5)=5,\\\\\n", 18 | "\\max(3,4,6,7)=7,\\\\\n", 19 | "\\max(4,5,7,8)=8.\\\\\\end{split}$\n", 20 | "* 2. 填充和步幅\n", 21 | "* 3. 多通道\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "id": "5c7f75e0-4f53-4138-bca6-c726e465421b", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import torch\n", 32 | "from torch import nn\n", 33 | "import dltools\n", 34 | "\n", 35 | "def pool2d(X, pool_size, mode='max'):\n", 36 | " p_h, p_w = pool_size\n", 37 | " Y = torch.zeros((X.shape[0] - p_h + 1, X.shape[1] - p_w + 1))\n", 38 | " for i in range(Y.shape[0]):\n", 39 | " for j in range(Y.shape[1]):\n", 40 | " if mode == 'max':\n", 41 | " Y[i, j] = X[i: i + p_h, j: j + p_w].max()\n", 42 | " elif mode == 'avg':\n", 43 | " Y[i, j] = X[i: i + p_h, j: j + p_w].mean()\n", 44 | " return Y" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "7934ab0f-7428-466d-81aa-1fd99807cb77", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "text/plain": [ 56 | "tensor([[4., 5.],\n", 57 | " [7., 8.]])" 58 | ] 59 | }, 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "X = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])\n", 67 | "pool2d(X, (2, 2))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "ee416947-c82c-4116-8e3c-d3384d480ad1", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "tensor([[2., 3.],\n", 80 | " [5., 6.]])" 81 | ] 82 | }, 83 | "execution_count": 3, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "pool2d(X, (2, 2), 'avg')" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "id": "fe01ae21-7791-44de-bc7d-13481b239b4f", 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "tensor([[[[ 0., 1., 2., 3.],\n", 102 | " [ 4., 5., 6., 7.],\n", 103 | " [ 8., 9., 10., 11.],\n", 104 | " [12., 13., 14., 15.]]]])" 105 | ] 106 | }, 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "# padding and stride\n", 114 | "X = torch.arange(16, dtype=torch.float32).reshape((1, 1, 4, 4))\n", 115 | "X" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "id": "5f68e80c-bc5a-4074-9503-d082ca412dd0", 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "tensor([[[[10.]]]])" 128 | ] 129 | }, 130 | "execution_count": 5, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "pool2d = nn.MaxPool2d(3)\n", 137 | "pool2d(X)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "id": "36923e6a-8de8-4d3b-b74c-9b2656352227", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "tensor([[[[ 5., 7.],\n", 150 | " [13., 15.]]]])" 151 | ] 152 | }, 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "output_type": "execute_result" 156 | } 157 | ], 158 | "source": [ 159 | "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n", 160 | "pool2d(X)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "id": "60dcddb2-e408-47fb-a3bb-bb0cde94224c", 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "text/plain": [ 
172 | "tensor([[[[ 5., 7.],\n", 173 | " [13., 15.]]]])" 174 | ] 175 | }, 176 | "execution_count": 7, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "# stride=(2, 3):窗口将每次向右滑动三个元素位置,或者向下滑动两个元素位置\n", 183 | "# padding=(0, 1):表示在上下两个方向填充零行0,在左右两个方向各填充一列0\n", 184 | "pool2d = nn.MaxPool2d((2, 3), stride=(2, 3), padding=(0, 1)) \n", 185 | "pool2d(X)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 8, 191 | "id": "76010bee-9b4a-40e9-901d-68814d8e6d82", 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "tensor([[[[ 0., 1., 2., 3.],\n", 198 | " [ 4., 5., 6., 7.],\n", 199 | " [ 8., 9., 10., 11.],\n", 200 | " [12., 13., 14., 15.]],\n", 201 | "\n", 202 | " [[ 1., 2., 3., 4.],\n", 203 | " [ 5., 6., 7., 8.],\n", 204 | " [ 9., 10., 11., 12.],\n", 205 | " [13., 14., 15., 16.]]]])" 206 | ] 207 | }, 208 | "execution_count": 8, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "# 多个通道\n", 215 | "X = torch.cat((X, X + 1), 1)\n", 216 | "X\n" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 9, 222 | "id": "7f27d44e-7792-43ed-bc46-c32f2c35088e", 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "data": { 227 | "text/plain": [ 228 | "tensor([[[[ 5., 7.],\n", 229 | " [13., 15.]],\n", 230 | "\n", 231 | " [[ 6., 8.],\n", 232 | " [14., 16.]]]])" 233 | ] 234 | }, 235 | "execution_count": 9, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "pool2d = nn.MaxPool2d(3, padding=1, stride=2)\n", 242 | "pool2d(X)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "id": "87245102-02cb-43bb-ae49-663d3791526c", 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python 3", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.6.5" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 5 275 | } 276 | -------------------------------------------------------------------------------- /98-数据处理(比pandas快几倍)-pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "54f883cc-4aec-4f69-b602-7956a59dffba", 6 | "metadata": {}, 7 | "source": [ 8 | "# Polars(据说比Pandas快不少)\n", 9 | "* 安装方法:pip3 install -U polars[pyarrow]\n", 10 | "* 用户指南: https://pola-rs.github.io/polars-book/\n", 11 | "* h2oai's db基准测试结果: https://h2oai.github.io/db-benchmark/\n", 12 | "* Python文档: https://pola-rs.github.io/polars/py-polars/html/reference/index.html\n", 13 | "* 用户指南: https://pola-rs.github.io/polars-book/user-guide/index.html" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "id": "91bdf7a2-1aeb-468f-808b-c9f07bda3322", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "import polars as pl\n", 25 | "import timeit\n", 26 | "df = pl.DataFrame(\n", 27 | " {\n", 28 | " \"A\": [1, 2, 3, 4, 5],\n", 29 | " \"fruits\": [\"banana\", \"banana\", \"apple\", \"apple\", \"banana\"],\n", 30 | " \"B\": [5, 4, 3, 2, 
1],\n", 31 | " \"cars\": [\"beetle\", \"audi\", \"beetle\", \"beetle\", \"beetle\"],\n", 32 | " }\n", 33 | " )" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "fc239aaa-b7be-4a12-864e-a88950a91b44", 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/html": [ 45 | "
\n", 46 | "\n", 59 | "\n", 60 | "\n", 61 | "\n", 62 | "\n", 65 | "\n", 68 | "\n", 71 | "\n", 74 | "\n", 75 | "\n", 76 | "\n", 79 | "\n", 82 | "\n", 85 | "\n", 88 | "\n", 89 | "\n", 90 | "\n", 91 | "\n", 92 | "\n", 95 | "\n", 98 | "\n", 101 | "\n", 104 | "\n", 105 | "\n", 106 | "\n", 109 | "\n", 112 | "\n", 115 | "\n", 118 | "\n", 119 | "\n", 120 | "\n", 123 | "\n", 126 | "\n", 129 | "\n", 132 | "\n", 133 | "\n", 134 | "\n", 137 | "\n", 140 | "\n", 143 | "\n", 146 | "\n", 147 | "\n", 148 | "\n", 151 | "\n", 154 | "\n", 157 | "\n", 160 | "\n", 161 | "\n", 162 | "
\n", 63 | "A\n", 64 | "\n", 66 | "fruits\n", 67 | "\n", 69 | "B\n", 70 | "\n", 72 | "cars\n", 73 | "
\n", 77 | "i64\n", 78 | "\n", 80 | "str\n", 81 | "\n", 83 | "i64\n", 84 | "\n", 86 | "str\n", 87 | "
\n", 93 | "1\n", 94 | "\n", 96 | "\"banana\"\n", 97 | "\n", 99 | "5\n", 100 | "\n", 102 | "\"beetle\"\n", 103 | "
\n", 107 | "2\n", 108 | "\n", 110 | "\"banana\"\n", 111 | "\n", 113 | "4\n", 114 | "\n", 116 | "\"audi\"\n", 117 | "
\n", 121 | "3\n", 122 | "\n", 124 | "\"apple\"\n", 125 | "\n", 127 | "3\n", 128 | "\n", 130 | "\"beetle\"\n", 131 | "
\n", 135 | "4\n", 136 | "\n", 138 | "\"apple\"\n", 139 | "\n", 141 | "2\n", 142 | "\n", 144 | "\"beetle\"\n", 145 | "
\n", 149 | "5\n", 150 | "\n", 152 | "\"banana\"\n", 153 | "\n", 155 | "1\n", 156 | "\n", 158 | "\"beetle\"\n", 159 | "
\n", 163 | "
" 164 | ], 165 | "text/plain": [ 166 | "shape: (5, 4)\n", 167 | "┌─────┬────────┬─────┬────────┐\n", 168 | "│ A ┆ fruits ┆ B ┆ cars │\n", 169 | "│ --- ┆ --- ┆ --- ┆ --- │\n", 170 | "│ i64 ┆ str ┆ i64 ┆ str │\n", 171 | "╞═════╪════════╪═════╪════════╡\n", 172 | "│ 1 ┆ banana ┆ 5 ┆ beetle │\n", 173 | "├╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤\n", 174 | "│ 2 ┆ banana ┆ 4 ┆ audi │\n", 175 | "├╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤\n", 176 | "│ 3 ┆ apple ┆ 3 ┆ beetle │\n", 177 | "├╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤\n", 178 | "│ 4 ┆ apple ┆ 2 ┆ beetle │\n", 179 | "├╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤\n", 180 | "│ 5 ┆ banana ┆ 1 ┆ beetle │\n", 181 | "└─────┴────────┴─────┴────────┘" 182 | ] 183 | }, 184 | "execution_count": 2, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "df" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 3, 196 | "id": "a397971b-cad3-4c81-9388-36f8d6c020bf", 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/html": [ 202 | "
\n", 203 | "\n", 216 | "\n", 217 | "\n", 218 | "\n", 219 | "\n", 222 | "\n", 225 | "\n", 228 | "\n", 231 | "\n", 234 | "\n", 237 | "\n", 240 | "\n", 243 | "\n", 244 | "\n", 245 | "\n", 248 | "\n", 251 | "\n", 254 | "\n", 257 | "\n", 260 | "\n", 263 | "\n", 266 | "\n", 269 | "\n", 270 | "\n", 271 | "\n", 272 | "\n", 273 | "\n", 276 | "\n", 279 | "\n", 282 | "\n", 285 | "\n", 288 | "\n", 291 | "\n", 294 | "\n", 297 | "\n", 298 | "\n", 299 | "\n", 302 | "\n", 305 | "\n", 308 | "\n", 311 | "\n", 314 | "\n", 317 | "\n", 320 | "\n", 323 | "\n", 324 | "\n", 325 | "\n", 328 | "\n", 331 | "\n", 334 | "\n", 337 | "\n", 340 | "\n", 343 | "\n", 346 | "\n", 349 | "\n", 350 | "\n", 351 | "\n", 354 | "\n", 357 | "\n", 360 | "\n", 363 | "\n", 366 | "\n", 369 | "\n", 372 | "\n", 375 | "\n", 376 | "\n", 377 | "\n", 380 | "\n", 383 | "\n", 386 | "\n", 389 | "\n", 392 | "\n", 395 | "\n", 398 | "\n", 401 | "\n", 402 | "\n", 403 | "
\n", 220 | "fruits\n", 221 | "\n", 223 | "cars\n", 224 | "\n", 226 | "literal_string_fruits\n", 227 | "\n", 229 | "B\n", 230 | "\n", 232 | "sum_A_by_cars\n", 233 | "\n", 235 | "sum_A_by_fruits\n", 236 | "\n", 238 | "rev_A_by_fruits\n", 239 | "\n", 241 | "sort_A_by_B_by_fruits\n", 242 | "
\n", 246 | "str\n", 247 | "\n", 249 | "str\n", 250 | "\n", 252 | "str\n", 253 | "\n", 255 | "i64\n", 256 | "\n", 258 | "i64\n", 259 | "\n", 261 | "i64\n", 262 | "\n", 264 | "list\n", 265 | "\n", 267 | "list\n", 268 | "
\n", 274 | "\"apple\"\n", 275 | "\n", 277 | "\"beetle\"\n", 278 | "\n", 280 | "\"fruits\"\n", 281 | "\n", 283 | "11\n", 284 | "\n", 286 | "4\n", 287 | "\n", 289 | "7\n", 290 | "\n", 292 | "[4, 3]\n", 293 | "\n", 295 | "[4, 3]\n", 296 | "
\n", 300 | "\"apple\"\n", 301 | "\n", 303 | "\"beetle\"\n", 304 | "\n", 306 | "\"fruits\"\n", 307 | "\n", 309 | "11\n", 310 | "\n", 312 | "4\n", 313 | "\n", 315 | "7\n", 316 | "\n", 318 | "[4, 3]\n", 319 | "\n", 321 | "[4, 3]\n", 322 | "
\n", 326 | "\"banana\"\n", 327 | "\n", 329 | "\"beetle\"\n", 330 | "\n", 332 | "\"fruits\"\n", 333 | "\n", 335 | "11\n", 336 | "\n", 338 | "4\n", 339 | "\n", 341 | "8\n", 342 | "\n", 344 | "[5, 2, 1]\n", 345 | "\n", 347 | "[5, 2, 1]\n", 348 | "
\n", 352 | "\"banana\"\n", 353 | "\n", 355 | "\"audi\"\n", 356 | "\n", 358 | "\"fruits\"\n", 359 | "\n", 361 | "11\n", 362 | "\n", 364 | "2\n", 365 | "\n", 367 | "8\n", 368 | "\n", 370 | "[5, 2, 1]\n", 371 | "\n", 373 | "[5, 2, 1]\n", 374 | "
\n", 378 | "\"banana\"\n", 379 | "\n", 381 | "\"beetle\"\n", 382 | "\n", 384 | "\"fruits\"\n", 385 | "\n", 387 | "11\n", 388 | "\n", 390 | "4\n", 391 | "\n", 393 | "8\n", 394 | "\n", 396 | "[5, 2, 1]\n", 397 | "\n", 399 | "[5, 2, 1]\n", 400 | "
\n", 404 | "
" 405 | ], 406 | "text/plain": [ 407 | "shape: (5, 8)\n", 408 | "┌────────┬────────┬────────┬─────┬───┬───┬───────────┬───────────┐\n", 409 | "│ fruits ┆ cars ┆ litera ┆ B ┆ s ┆ s ┆ rev_A_by_ ┆ sort_A_by │\n", 410 | "│ --- ┆ --- ┆ l_stri ┆ --- ┆ u ┆ u ┆ fruits ┆ _B_by_fru │\n", 411 | "│ str ┆ str ┆ ng_fru ┆ i64 ┆ m ┆ m ┆ --- ┆ its │\n", 412 | "│ ┆ ┆ its ┆ ┆ _ ┆ _ ┆ list ┆ --- │\n", 413 | "│ ┆ ┆ --- ┆ ┆ A ┆ A ┆ [i64] ┆ list │\n", 414 | "│ ┆ ┆ str ┆ ┆ _ ┆ _ ┆ ┆ [i64] │\n", 415 | "│ ┆ ┆ ┆ ┆ b ┆ b ┆ ┆ │\n", 416 | "│ ┆ ┆ ┆ ┆ y ┆ y ┆ ┆ │\n", 417 | "│ ┆ ┆ ┆ ┆ _ ┆ _ ┆ ┆ │\n", 418 | "│ ┆ ┆ ┆ ┆ c ┆ f ┆ ┆ │\n", 419 | "│ ┆ ┆ ┆ ┆ a ┆ r ┆ ┆ │\n", 420 | "│ ┆ ┆ ┆ ┆ r ┆ u ┆ ┆ │\n", 421 | "│ ┆ ┆ ┆ ┆ s ┆ i ┆ ┆ │\n", 422 | "│ ┆ ┆ ┆ ┆ - ┆ t ┆ ┆ │\n", 423 | "│ ┆ ┆ ┆ ┆ - ┆ s ┆ ┆ │\n", 424 | "│ ┆ ┆ ┆ ┆ - ┆ - ┆ ┆ │\n", 425 | "│ ┆ ┆ ┆ ┆ i ┆ - ┆ ┆ │\n", 426 | "│ ┆ ┆ ┆ ┆ 6 ┆ - ┆ ┆ │\n", 427 | "│ ┆ ┆ ┆ ┆ 4 ┆ i ┆ ┆ │\n", 428 | "│ ┆ ┆ ┆ ┆ ┆ 6 ┆ ┆ │\n", 429 | "│ ┆ ┆ ┆ ┆ ┆ 4 ┆ ┆ │\n", 430 | "╞════════╪════════╪════════╪═════╪═══╪═══╪═══════════╪═══════════╡\n", 431 | "│ apple ┆ beetle ┆ fruits ┆ 11 ┆ 4 ┆ 7 ┆ [4, 3] ┆ [4, 3] │\n", 432 | "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌┼╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤\n", 433 | "│ apple ┆ beetle ┆ fruits ┆ 11 ┆ 4 ┆ 7 ┆ [4, 3] ┆ [4, 3] │\n", 434 | "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌┼╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤\n", 435 | "│ banana ┆ beetle ┆ fruits ┆ 11 ┆ 4 ┆ 8 ┆ [5, 2, 1] ┆ [5, 2, 1] │\n", 436 | "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌┼╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤\n", 437 | "│ banana ┆ audi ┆ fruits ┆ 11 ┆ 2 ┆ 8 ┆ [5, 2, 1] ┆ [5, 2, 1] │\n", 438 | "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌┼╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┤\n", 439 | "│ banana ┆ beetle ┆ fruits ┆ 11 ┆ 4 ┆ 8 ┆ [5, 2, 1] ┆ [5, 2, 1] │\n", 440 | "└────────┴────────┴────────┴─────┴───┴───┴───────────┴───────────┘" 441 | ] 442 | }, 443 | "execution_count": 3, 444 | "metadata": {}, 445 | "output_type": "execute_result" 446 | } 447 | ], 448 | "source": [ 449 | "df.sort(\"fruits\").select(\n", 450 | " [\n", 451 | " \"fruits\",\n", 452 | " \"cars\",\n", 453 | " pl.lit(\"fruits\").alias(\"literal_string_fruits\"),\n", 454 | " pl.col(\"B\").filter(pl.col(\"cars\") == \"beetle\").sum(),\n", 455 | " pl.col(\"A\").filter(pl.col(\"B\") > 2).sum().over(\"cars\").alias(\"sum_A_by_cars\"), # groups by \"cars\"\n", 456 | " pl.col(\"A\").sum().over(\"fruits\").alias(\"sum_A_by_fruits\"), # groups by \"fruits\"\n", 457 | " pl.col(\"A\").reverse().over(\"fruits\").alias(\"rev_A_by_fruits\"), # groups by \"fruits\n", 458 | " pl.col(\"A\").sort_by(\"B\").over(\"fruits\").alias(\"sort_A_by_B_by_fruits\"), # groups by \"fruits\"\n", 459 | " ]\n", 460 | ")" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 7, 466 | "id": "178b2936-3345-48a8-b0f7-396380eaee63", 467 | "metadata": {}, 468 | "outputs": [ 469 | { 470 | "name": "stdout", 471 | "output_type": "stream", 472 | "text": [ 473 | "Pandas_time: 0.2836948000000348\n", 474 | "Polars_time: 0.0953109999999242\n", 475 | "Pandas_time1: 0.343936900000017\n", 476 | "Polars_time1: 0.042977000000064436\n" 477 | ] 478 | } 479 | ], 480 | "source": [ 481 | "# 读取时间对比\n", 482 | "start_df = timeit.default_timer()\n", 483 | "df = pd.read_csv(\"E:/data/a.csv\")\n", 484 | "df = df.sort_values(\"current\", ascending=False).head()\n", 485 | "stop_df = timeit.default_timer()\n", 486 | "print('Pandas_time: ', stop_df - start_df)\n", 487 | "\n", 488 | "start_pl = timeit.default_timer()\n", 489 | "data = pl.read_csv(\"E:/data/a.csv\")\n", 490 | "data.sort(by=\"current\", reverse=True).head()\n", 491 | 
"stop_pl = timeit.default_timer()\n", 492 | "print('Polars_time: ', stop_pl - start_pl)\n", 493 | "\n", 494 | "# 纵向拼接时间对比\n", 495 | "start_df1 = timeit.default_timer()\n", 496 | "df_1 = pd.read_csv('E:/data/a.csv')\n", 497 | "df_2 = pd.read_csv('E:/data/b.csv')\n", 498 | "df_1.append(df_2, ignore_index=True)\n", 499 | "stop_df1 = timeit.default_timer()\n", 500 | "print('Pandas_time1: ', stop_df1 - start_df1)\n", 501 | "\n", 502 | "start_pl1 = timeit.default_timer()\n", 503 | "pl_1 = pl.read_csv('E:/data/a.csv')\n", 504 | "pl_2 = pl.read_csv('E:/data/b.csv')\n", 505 | "pl_1.vstack(pl_2)\n", 506 | "stop_pl1 = timeit.default_timer()\n", 507 | "print('Polars_time1: ', stop_pl1 - start_pl1)\n" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "id": "c3a84ee9-5f03-4b03-8de1-f57924b4849c", 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [] 517 | } 518 | ], 519 | "metadata": { 520 | "kernelspec": { 521 | "display_name": "Python 3", 522 | "language": "python", 523 | "name": "python3" 524 | }, 525 | "language_info": { 526 | "codemirror_mode": { 527 | "name": "ipython", 528 | "version": 3 529 | }, 530 | "file_extension": ".py", 531 | "mimetype": "text/x-python", 532 | "name": "python", 533 | "nbconvert_exporter": "python", 534 | "pygments_lexer": "ipython3", 535 | "version": "3.6.5" 536 | } 537 | }, 538 | "nbformat": 4, 539 | "nbformat_minor": 5 540 | } 541 | -------------------------------------------------------------------------------- /99-模拟数据神器-Faker.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c244bb06-30db-4df8-990b-17db9968e3d6", 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "source": [ 10 | "# 生成模拟数据Faker\n", 11 | "\n", 12 | "* 安装方法:pip install Faker\n", 13 | "* 作者:https://blog.csdn.net/mall_lucy/article/details/108655317\n", 14 | "* 常见的语言选项:\n", 15 | " * zh_CN - Chinese (China Mainland)\n", 16 | " * zh_TW - Chinese (China Taiwan)\n", 17 | " * en_US - English (United States)\n", 18 | "* 地理信息:\n", 19 | " * city_suffix():市,县\n", 20 | " * country():国家\n", 21 | " * country_code():国家编码\n", 22 | " * district():区\n", 23 | " * geo_coordinate():地理坐标\n", 24 | " * latitude():地理坐标(纬度)\n", 25 | " * longitude():地理坐标(经度)\n", 26 | " * postcode():邮编\n", 27 | " * province():省份\n", 28 | " * address():详细地址\n", 29 | " * street_address():街道地址\n", 30 | " * street_name():街道名\n", 31 | " * street_suffix():街、路\n", 32 | "* 基础信息\n", 33 | " * ssn():生成身份证号\n", 34 | " * bs():随机公司服务名\n", 35 | " * company():随机公司名(长)\n", 36 | " * company_prefix():随机公司名(短)\n", 37 | " * company_suffix():公司性质,如'信息有限公司'\n", 38 | " * credit_card_expire():随机信用卡到期日,如'03/30'\n", 39 | " * credit_card_full():生成完整信用卡信息\n", 40 | " * credit_card_number():信用卡号\n", 41 | " * credit_card_provider():信用卡类型\n", 42 | " * credit_card_security_code():信用卡安全码\n", 43 | " * job():随机职位\n", 44 | " * first_name_female():女性名\n", 45 | " * first_name_male():男性名\n", 46 | " * name():随机生成全名\n", 47 | " * name_female():男性全名\n", 48 | " * name_male():女性全名\n", 49 | " * phone_number():随机生成手机号\n", 50 | " * phonenumber_prefix():随机生成手机号段,如139\n", 51 | "* 邮箱信息\n", 52 | " * ascii_company_email():随机ASCII公司邮箱名\n", 53 | " * ascii_email():随机ASCII邮箱:\n", 54 | " * company_email():公司邮箱\n", 55 | " * email():普通邮箱\n", 56 | " * safe_email():安全邮箱\n", 57 | "\n", 58 | "* 网络类基础信息\n", 59 | " * domain_name():生成域名\n", 60 | " * domain_word():域词(即,不包含后缀)\n", 61 | " * ipv4():随机IP4地址\n", 62 | " * ipv6():随机IP6地址\n", 63 | " * mac_address():随机MAC地址\n", 64 | 
" * tld():网址域名后缀(.com,.net.cn,等等,不包括.)\n", 65 | " * uri():随机URI地址\n", 66 | " * uri_extension():网址文件后缀\n", 67 | " * uri_page():网址文件(不包含后缀)\n", 68 | " * uri_path():网址文件路径(不包含文件名)\n", 69 | " * url():随机URL地址\n", 70 | " * user_name():随机用户名\n", 71 | " * image_url():随机URL地址\n", 72 | "* 浏览器信息\n", 73 | " * chrome():随机生成Chrome的浏览器user_agent信息\n", 74 | " * firefox():随机生成FireFox的浏览器user_agent信息\n", 75 | " * internet_explorer():随机生成IE的浏览器user_agent信息\n", 76 | " * opera():随机生成Opera的浏览器user_agent信息\n", 77 | " * safari():随机生成Safari的浏览器user_agent信息\n", 78 | " * linux_platform_token():随机Linux信息\n", 79 | " * user_agent():随机user_agent信息\n", 80 | " \n", 81 | "* 数字信息\n", 82 | " * numerify():三位随机数字\n", 83 | " * random_digit():0~9随机数\n", 84 | " * random_digit_not_null():1~9的随机数\n", 85 | " * random_int():随机数字,默认0~9999,可以通过设置min,max来设置\n", 86 | " * random_number():随机数字,参数digits设置生成的数字位数\n", 87 | " * pyfloat():随机Float数字\n", 88 | " * pyint():随机Int数字(参考random_int()参数)\n", 89 | " * pydecimal():随机Decimal数字(参考pyfloat参数)\n", 90 | "\n", 91 | "* 文本类信息\n", 92 | " * pystr():随机字符串\n", 93 | " * random_element():随机字母\n", 94 | " * random_letter():随机字母\n", 95 | " * paragraph():随机生成一个段落\n", 96 | " * paragraphs():随机生成多个段落\n", 97 | " * sentence():随机生成一句话\n", 98 | " * sentences():随机生成多句话,与段落类似\n", 99 | " * text():随机生成一篇文章\n", 100 | " * word():随机生成词语\n", 101 | " * words():随机生成多个词语,用法与段落,句子,类似\n", 102 | " * binary():随机生成二进制编码\n", 103 | " * boolean():True/False\n", 104 | " * language_code():随机生成两位语言编码\n", 105 | " * locale():随机生成语言/国际 信息\n", 106 | " * md5():随机生成MD5\n", 107 | " * null_boolean():NULL/True/False\n", 108 | " * password():随机生成密码,可选参数:length:密码长度;special_chars:是否能使用特殊字符;digits:是否包含数字;upper_case:是否包含大写字母; * * lower_case:是否包含小写字母\n", 109 | " * sha1():随机SHA1\n", 110 | " * sha256():随机SHA256\n", 111 | " * uuid4():随机UUID\n", 112 | "\n", 113 | "* 时间类信息\n", 114 | " * date():随机日期\n", 115 | " * date_between():随机生成指定范围内日期,参数:start_date,end_date\n", 116 | " * date_between_dates():随机生成指定范围内日期,用法同上\n", 117 | " * date_object():随机生产从1970-1-1到指定日期的随机日期。\n", 118 | " * date_time():随机生成指定时间(1970年1月1日至今)\n", 119 | " * date_time_ad():生成公元1年到现在的随机时间\n", 120 | " * date_time_between():用法同dates\n", 121 | " * future_date():未来日期\n", 122 | " * future_datetime():未来时间\n", 123 | " * month():随机月份\n", 124 | " * month_name():随机月份(英文)\n", 125 | " * past_date():随机生成已经过去的日期\n", 126 | " * past_datetime():随机生成已经过去的时间\n", 127 | " * time():随机24小时时间\n", 128 | " * timedelta():随机获取时间差\n", 129 | " * time_object():随机24小时时间,time对象\n", 130 | " * time_series():随机TimeSeries对象\n", 131 | " * timezone():随机时区\n", 132 | " * unix_time():随机Unix时间\n", 133 | " * year():随机年份\n", 134 | "\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 18, 140 | "id": "a6e7c540-81c0-4379-8249-51984de32596", 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stdout", 145 | "output_type": "stream", 146 | "text": [ 147 | "左桂香\n", 148 | "湖南省杰市清浦兴安盟街b座 886606\n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "from faker import Faker\n", 154 | "fk = Faker(locale='zh_CN')\n", 155 | "#简体中文:zh_CN\n", 156 | "#繁体中文:zh_TW\n", 157 | "#美国英文:en_US\n", 158 | "#英国英文:en_GB\n", 159 | "#德文:de_DE\n", 160 | "#日文:ja_JP\n", 161 | "#韩文:ko_KR\n", 162 | "#法文:fr_FR\n", 163 | "\n", 164 | "print(fk.name())\n", 165 | "print(fk.address())\n" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 19, 171 | "id": "bcd5dbfe-f8cc-4715-8c2a-239b94a555d0", 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 
178 | "青海省张家港县平山米路p座 248302 ktan@example.com 华泰通安传媒有限公司 na01@jiefeng.cn 寻呼员/话务员 608\n", 179 | "#######################\n", 180 | "{'username': 'yinming', 'name': '李颖', 'sex': 'F', 'address': '贵州省六安市高明银川路E座 919545', 'mail': 'pingyuan@hotmail.com', 'birthdate': datetime.date(1950, 8, 19)}\n", 181 | "#######################\n", 182 | "{'job': '个人业务部门经理/主管', 'company': '佳禾科技有限公司', 'ssn': '623021197604282742', 'residence': '福建省太原县清浦陈路P座 620691', 'current_location': (Decimal('-75.988008'), Decimal('-162.637549')), 'blood_group': 'AB-', 'website': ['http://www.pingxiulan.cn/', 'https://hf.cn/', 'http://www.wei.net/'], 'username': 'wangyan', 'name': '王楠', 'sex': 'F', 'address': '内蒙古自治区岩县滨城太原路S座 308025', 'mail': 'pingwei@yahoo.com', 'birthdate': datetime.date(1960, 4, 13)}\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "# 基础信息数据例子\n", 188 | "#1、地址\n", 189 | "addr = fk.address()\n", 190 | "#2、邮箱\n", 191 | "email = fk.email()\n", 192 | "#3、公司名称\n", 193 | "company = fk.company()\n", 194 | "#4、公司邮箱\n", 195 | "company_email = fk.company_email()\n", 196 | "#5、工作岗位\n", 197 | "job = fk.job()\n", 198 | "#6、简单个人信息\n", 199 | "per_s = fk.simple_profile()\n", 200 | "#7、更多个人信息\n", 201 | "per = fk.profile()\n", 202 | " \n", 203 | "#随机数 :\n", 204 | "res = fk.random_int(min=100,max=999)\n", 205 | "\n", 206 | "print(addr,email,company,company_email,job,res)\n", 207 | "print(\"#######################\")\n", 208 | "print(per_s)\n", 209 | "print(\"#######################\")\n", 210 | "print(per)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "ad00dc89-14ac-44c2-9501-5c629eb8aae1", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.6.5" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 5 243 | } 244 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deeplearning-and-Coding 2 | 各位朋友大家好,我是kavin:
3 | 学习深度学习不如边写代码边学习,实际操作一遍才能理解数据的变换过程,参数的训练过程,这里整合了B站的Jupyter代码,可以结合着B站的视频边看边练,希望能对大家有帮助。 4 | 另外我这里参考了沐神的教材,他也有相关视频大家可以去看,我觉得自己不太一样的地方就是侧重于coding部分,边讲边执行代码,甚至在线debug数据的变换过程。<br>
5 | 这里是我的B站:里边有配套的视频,欢迎观看。https://space.bilibili.com/99089023
6 | 项目用到的数据集中放在了这里:链接:https://pan.baidu.com/s/185V7-wK1ZG2Nsrx8mwZE2Q?pwd=bg55 提取码:bg55
7 | 项目持续更新中,感谢大家的关注,谢谢。 8 | --------------------------------------------------------------------------------