├── Pytorch_Image_Classification ├── CIFAR-10-NET │ ├── Pytorch CIFAR-10分类(AlexNet).ipynb │ ├── Pytorch CIFAR-10分类(DenseNet)GPU内存不足未通过.ipynb │ ├── Pytorch CIFAR-10分类(GoogLeNet)GPU内存不足未通过.ipynb │ ├── Pytorch CIFAR-10分类(LeNet5).ipynb │ ├── Pytorch CIFAR-10分类(MobileNet(v1、v2)).ipynb │ ├── Pytorch CIFAR-10分类(ResNet18).ipynb │ ├── Pytorch CIFAR-10分类(VGGNet16).ipynb │ └── Pytorch CIFAR-10分类(自定义网络).ipynb ├── Pytorch 两种自定义数据集.ipynb └── Pytorch 自定义数据集+迁移学习一.ipynb └── README.md /Pytorch_Image_Classification/CIFAR-10-NET/Pytorch CIFAR-10分类(DenseNet)GPU内存不足未通过.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "@[toc](目录)\n", 8 | "\n", 9 | "**Pytorch一般有以下几个流程**\n", 10 | "\n", 11 | "\n", 12 | "### 1.数据读取\n", 13 | "\n", 14 | "### 2.数据处理\n", 15 | "\n", 16 | "### 3.搭建网络\n", 17 | "\n", 18 | "### 4.模型训练\n", 19 | "\n", 20 | "### 5.模型上线\n", 21 | "\n", 22 | "\n", 23 | "主要包括以上几个阶段,每个阶段又可以细分,后面再说\n" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### 1.数据读取\n", 31 | "CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( arplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 \n", 32 | "\n", 33 | "与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:\n", 34 | "\n", 35 | " • CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。\n", 36 | " • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。\n", 37 | " • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "首先使用`torchvision`加载和归一化我们的训练数据和测试数据。\n", 45 | "\n", 46 | " a、`torchvision`这个东西,实现了常用的一些深度学习的相关的图像数据的加载功能,比如cifar10、Imagenet、Mnist等等的,保存在`torchvision.datasets`模块中。\n", 47 | " \n", 48 | " b、同时,也封装了一些处理数据的方法。保存在`torchvision.transforms`模块中\n", 49 | " \n", 50 | " c、还封装了一些模型和工具封装在相应模型中,比如`torchvision.models`当中就包含了AlexNet,VGG,ResNet,SqueezeNet等模型。" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 1, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "#首先导入torch,torchvision模块\n", 60 | "\n", 61 | "import torch\n", 62 | "import torchvision\n", 63 | "import torchvision.datasets as datasets\n", 64 | "import torchvision.transforms as transforms" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "**由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor**\n", 72 | "\n", 73 | "首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换\n", 74 | "\n", 75 | "`transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`前面的(0.5,0.5,0.5) 是 R G B 三个通道上的均值, 后面(0.5, 0.5, 0.5)是三个通道的标准差,注意通道顺序是 R G B ,用过opencv的同学应该知道openCV读出来的图像是 BRG顺序。这两个tuple数据是用来对RGB 图像做归一化的,如其名称 Normalize 所示这里都取0.5只是一个近似的操作,实际上其均值和方差并不是这么多,但是就这个示例而言 影响可不计。精确值是通过分别计算R,G,B三个通道的数据算出来的。\n" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) " 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Files already downloaded and verified\n" 97 | ] 98 | } 99 | ], 100 
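上面提到,Normalize 里的 (0.5, 0.5, 0.5) 只是一个近似值,精确的均值和标准差需要分别对 R、G、B 三个通道统计。下面是一个简单的示意脚本(假设数据仍然放在上文的 D:/CIFAR-10 目录),算出来的结果可以用来替换 Normalize 的两组参数:

```python
# 示意:用训练集的原始 uint8 数据估算三个通道的真实均值和标准差
import numpy as np
import torchvision.datasets as datasets

raw_trainset = datasets.CIFAR10(root='D:/CIFAR-10', train=True, download=True)
data = raw_trainset.data.astype(np.float32) / 255.0   # (50000, 32, 32, 3),先缩放到 [0, 1]
mean = data.mean(axis=(0, 1, 2))                       # 在 N、H、W 三个维度上求均值,每个通道得到一个数
std = data.std(axis=(0, 1, 2))
print('mean:', mean, 'std:', std)                      # 把这两组数填进 transforms.Normalize 即可
```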
| "source": [ 101 | "# datasets.CIFAR10( )也是封装好了的,就在我前面提到的torchvision.datasets块中\n", 102 | "trainset = datasets.CIFAR10(root='D:/CIFAR-10', train=True,download=True, transform=transform)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | " `trainloader`其实是一个比较重要的东西,我们后面就是通过`trainloader`把数据传入网络,当然这里的`trainloader`其实是个变量名,可以随便取,重点是他是由后面的`torch.utils.data.DataLoader()`定义的,这个东西来源于`torch.utils.data`模块" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 4, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": { 125 | "scrolled": true 126 | }, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "Files already downloaded and verified\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "# 对于测试集的操作和训练集一样,我就不赘述了\n", 138 | "testset = torchvision.datasets.CIFAR10(root='D:/CIFAR-10', train=False,download=True, transform=transform)\n", 139 | "testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)\n", 140 | " # 类别信息也是需要我们给定的\n", 141 | "classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/plain": [ 152 | "['airplane',\n", 153 | " 'automobile',\n", 154 | " 'bird',\n", 155 | " 'cat',\n", 156 | " 'deer',\n", 157 | " 'dog',\n", 158 | " 'frog',\n", 159 | " 'horse',\n", 160 | " 'ship',\n", 161 | " 'truck']" 162 | ] 163 | }, 164 | "execution_count": 6, 165 | "metadata": {}, 166 | "output_type": "execute_result" 167 | } 168 | ], 169 | "source": [ 170 | "trainset.classes" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 7, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "{'airplane': 0,\n", 182 | " 'automobile': 1,\n", 183 | " 'bird': 2,\n", 184 | " 'cat': 3,\n", 185 | " 'deer': 4,\n", 186 | " 'dog': 5,\n", 187 | " 'frog': 6,\n", 188 | " 'horse': 7,\n", 189 | " 'ship': 8,\n", 190 | " 'truck': 9}" 191 | ] 192 | }, 193 | "execution_count": 7, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "trainset.class_to_idx" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "### 2. 
查看数据(格式,大小,形状)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 8, 212 | "metadata": { 213 | "scrolled": true 214 | }, 215 | "outputs": [ 216 | { 217 | "data": { 218 | "text/plain": [ 219 | "array([[[[ 59, 62, 63],\n", 220 | " [ 43, 46, 45],\n", 221 | " [ 50, 48, 43],\n", 222 | " ...,\n", 223 | " [158, 132, 108],\n", 224 | " [152, 125, 102],\n", 225 | " [148, 124, 103]],\n", 226 | "\n", 227 | " [[ 16, 20, 20],\n", 228 | " [ 0, 0, 0],\n", 229 | " [ 18, 8, 0],\n", 230 | " ...,\n", 231 | " [123, 88, 55],\n", 232 | " [119, 83, 50],\n", 233 | " [122, 87, 57]],\n", 234 | "\n", 235 | " [[ 25, 24, 21],\n", 236 | " [ 16, 7, 0],\n", 237 | " [ 49, 27, 8],\n", 238 | " ...,\n", 239 | " [118, 84, 50],\n", 240 | " [120, 84, 50],\n", 241 | " [109, 73, 42]],\n", 242 | "\n", 243 | " ...,\n", 244 | "\n", 245 | " [[208, 170, 96],\n", 246 | " [201, 153, 34],\n", 247 | " [198, 161, 26],\n", 248 | " ...,\n", 249 | " [160, 133, 70],\n", 250 | " [ 56, 31, 7],\n", 251 | " [ 53, 34, 20]],\n", 252 | "\n", 253 | " [[180, 139, 96],\n", 254 | " [173, 123, 42],\n", 255 | " [186, 144, 30],\n", 256 | " ...,\n", 257 | " [184, 148, 94],\n", 258 | " [ 97, 62, 34],\n", 259 | " [ 83, 53, 34]],\n", 260 | "\n", 261 | " [[177, 144, 116],\n", 262 | " [168, 129, 94],\n", 263 | " [179, 142, 87],\n", 264 | " ...,\n", 265 | " [216, 184, 140],\n", 266 | " [151, 118, 84],\n", 267 | " [123, 92, 72]]],\n", 268 | "\n", 269 | "\n", 270 | " [[[154, 177, 187],\n", 271 | " [126, 137, 136],\n", 272 | " [105, 104, 95],\n", 273 | " ...,\n", 274 | " [ 91, 95, 71],\n", 275 | " [ 87, 90, 71],\n", 276 | " [ 79, 81, 70]],\n", 277 | "\n", 278 | " [[140, 160, 169],\n", 279 | " [145, 153, 154],\n", 280 | " [125, 125, 118],\n", 281 | " ...,\n", 282 | " [ 96, 99, 78],\n", 283 | " [ 77, 80, 62],\n", 284 | " [ 71, 73, 61]],\n", 285 | "\n", 286 | " [[140, 155, 164],\n", 287 | " [139, 146, 149],\n", 288 | " [115, 115, 112],\n", 289 | " ...,\n", 290 | " [ 79, 82, 64],\n", 291 | " [ 68, 70, 55],\n", 292 | " [ 67, 69, 55]],\n", 293 | "\n", 294 | " ...,\n", 295 | "\n", 296 | " [[175, 167, 166],\n", 297 | " [156, 154, 160],\n", 298 | " [154, 160, 170],\n", 299 | " ...,\n", 300 | " [ 42, 34, 36],\n", 301 | " [ 61, 53, 57],\n", 302 | " [ 93, 83, 91]],\n", 303 | "\n", 304 | " [[165, 154, 128],\n", 305 | " [156, 152, 130],\n", 306 | " [159, 161, 142],\n", 307 | " ...,\n", 308 | " [103, 93, 96],\n", 309 | " [123, 114, 120],\n", 310 | " [131, 121, 131]],\n", 311 | "\n", 312 | " [[163, 148, 120],\n", 313 | " [158, 148, 122],\n", 314 | " [163, 156, 133],\n", 315 | " ...,\n", 316 | " [143, 133, 139],\n", 317 | " [143, 134, 142],\n", 318 | " [143, 133, 144]]],\n", 319 | "\n", 320 | "\n", 321 | " [[[255, 255, 255],\n", 322 | " [253, 253, 253],\n", 323 | " [253, 253, 253],\n", 324 | " ...,\n", 325 | " [253, 253, 253],\n", 326 | " [253, 253, 253],\n", 327 | " [253, 253, 253]],\n", 328 | "\n", 329 | " [[255, 255, 255],\n", 330 | " [255, 255, 255],\n", 331 | " [255, 255, 255],\n", 332 | " ...,\n", 333 | " [255, 255, 255],\n", 334 | " [255, 255, 255],\n", 335 | " [255, 255, 255]],\n", 336 | "\n", 337 | " [[255, 255, 255],\n", 338 | " [254, 254, 254],\n", 339 | " [254, 254, 254],\n", 340 | " ...,\n", 341 | " [254, 254, 254],\n", 342 | " [254, 254, 254],\n", 343 | " [254, 254, 254]],\n", 344 | "\n", 345 | " ...,\n", 346 | "\n", 347 | " [[113, 120, 112],\n", 348 | " [111, 118, 111],\n", 349 | " [105, 112, 106],\n", 350 | " ...,\n", 351 | " [ 72, 81, 80],\n", 352 | " [ 72, 80, 79],\n", 353 | " [ 72, 80, 79]],\n", 354 | "\n", 355 | " [[111, 118, 110],\n", 
356 | " [104, 111, 104],\n", 357 | " [ 99, 106, 98],\n", 358 | " ...,\n", 359 | " [ 68, 75, 73],\n", 360 | " [ 70, 76, 75],\n", 361 | " [ 78, 84, 82]],\n", 362 | "\n", 363 | " [[106, 113, 105],\n", 364 | " [ 99, 106, 98],\n", 365 | " [ 95, 102, 94],\n", 366 | " ...,\n", 367 | " [ 78, 85, 83],\n", 368 | " [ 79, 85, 83],\n", 369 | " [ 80, 86, 84]]],\n", 370 | "\n", 371 | "\n", 372 | " ...,\n", 373 | "\n", 374 | "\n", 375 | " [[[ 35, 178, 235],\n", 376 | " [ 40, 176, 239],\n", 377 | " [ 42, 176, 241],\n", 378 | " ...,\n", 379 | " [ 99, 177, 219],\n", 380 | " [ 79, 147, 197],\n", 381 | " [ 89, 148, 189]],\n", 382 | "\n", 383 | " [[ 57, 182, 234],\n", 384 | " [ 44, 184, 250],\n", 385 | " [ 50, 183, 240],\n", 386 | " ...,\n", 387 | " [156, 182, 200],\n", 388 | " [141, 177, 206],\n", 389 | " [116, 149, 175]],\n", 390 | "\n", 391 | " [[ 98, 197, 237],\n", 392 | " [ 64, 189, 252],\n", 393 | " [ 69, 192, 245],\n", 394 | " ...,\n", 395 | " [188, 195, 206],\n", 396 | " [119, 135, 147],\n", 397 | " [ 61, 79, 90]],\n", 398 | "\n", 399 | " ...,\n", 400 | "\n", 401 | " [[ 73, 79, 77],\n", 402 | " [ 53, 63, 68],\n", 403 | " [ 54, 68, 80],\n", 404 | " ...,\n", 405 | " [ 17, 40, 64],\n", 406 | " [ 21, 36, 51],\n", 407 | " [ 33, 48, 49]],\n", 408 | "\n", 409 | " [[ 61, 68, 75],\n", 410 | " [ 55, 70, 86],\n", 411 | " [ 57, 79, 103],\n", 412 | " ...,\n", 413 | " [ 24, 48, 72],\n", 414 | " [ 17, 35, 53],\n", 415 | " [ 7, 23, 32]],\n", 416 | "\n", 417 | " [[ 44, 56, 73],\n", 418 | " [ 46, 66, 88],\n", 419 | " [ 49, 77, 105],\n", 420 | " ...,\n", 421 | " [ 27, 52, 77],\n", 422 | " [ 21, 43, 66],\n", 423 | " [ 12, 31, 50]]],\n", 424 | "\n", 425 | "\n", 426 | " [[[189, 211, 240],\n", 427 | " [186, 208, 236],\n", 428 | " [185, 207, 235],\n", 429 | " ...,\n", 430 | " [175, 195, 224],\n", 431 | " [172, 194, 222],\n", 432 | " [169, 194, 220]],\n", 433 | "\n", 434 | " [[194, 210, 239],\n", 435 | " [191, 207, 236],\n", 436 | " [190, 206, 235],\n", 437 | " ...,\n", 438 | " [173, 192, 220],\n", 439 | " [171, 191, 218],\n", 440 | " [167, 190, 216]],\n", 441 | "\n", 442 | " [[208, 219, 244],\n", 443 | " [205, 216, 240],\n", 444 | " [204, 215, 239],\n", 445 | " ...,\n", 446 | " [175, 191, 217],\n", 447 | " [172, 190, 216],\n", 448 | " [169, 191, 215]],\n", 449 | "\n", 450 | " ...,\n", 451 | "\n", 452 | " [[207, 199, 181],\n", 453 | " [203, 195, 175],\n", 454 | " [203, 196, 173],\n", 455 | " ...,\n", 456 | " [135, 132, 127],\n", 457 | " [162, 158, 150],\n", 458 | " [168, 163, 151]],\n", 459 | "\n", 460 | " [[198, 190, 170],\n", 461 | " [189, 181, 159],\n", 462 | " [180, 172, 147],\n", 463 | " ...,\n", 464 | " [178, 171, 160],\n", 465 | " [175, 169, 156],\n", 466 | " [175, 169, 154]],\n", 467 | "\n", 468 | " [[198, 189, 173],\n", 469 | " [189, 181, 162],\n", 470 | " [178, 170, 149],\n", 471 | " ...,\n", 472 | " [195, 184, 169],\n", 473 | " [196, 189, 171],\n", 474 | " [195, 190, 171]]],\n", 475 | "\n", 476 | "\n", 477 | " [[[229, 229, 239],\n", 478 | " [236, 237, 247],\n", 479 | " [234, 236, 247],\n", 480 | " ...,\n", 481 | " [217, 219, 233],\n", 482 | " [221, 223, 234],\n", 483 | " [222, 223, 233]],\n", 484 | "\n", 485 | " [[222, 221, 229],\n", 486 | " [239, 239, 249],\n", 487 | " [233, 234, 246],\n", 488 | " ...,\n", 489 | " [223, 223, 236],\n", 490 | " [227, 228, 238],\n", 491 | " [210, 211, 220]],\n", 492 | "\n", 493 | " [[213, 206, 211],\n", 494 | " [234, 232, 239],\n", 495 | " [231, 233, 244],\n", 496 | " ...,\n", 497 | " [220, 220, 232],\n", 498 | " [220, 219, 232],\n", 499 | " [202, 203, 215]],\n", 500 | "\n", 501 | " 
...,\n", 502 | "\n", 503 | " [[150, 143, 135],\n", 504 | " [140, 135, 127],\n", 505 | " [132, 127, 120],\n", 506 | " ...,\n", 507 | " [224, 222, 218],\n", 508 | " [230, 228, 225],\n", 509 | " [241, 241, 238]],\n", 510 | "\n", 511 | " [[137, 132, 126],\n", 512 | " [130, 127, 120],\n", 513 | " [125, 121, 115],\n", 514 | " ...,\n", 515 | " [181, 180, 178],\n", 516 | " [202, 201, 198],\n", 517 | " [212, 211, 207]],\n", 518 | "\n", 519 | " [[122, 119, 114],\n", 520 | " [118, 116, 110],\n", 521 | " [120, 116, 111],\n", 522 | " ...,\n", 523 | " [179, 177, 173],\n", 524 | " [164, 164, 162],\n", 525 | " [163, 163, 161]]]], dtype=uint8)" 526 | ] 527 | }, 528 | "execution_count": 8, 529 | "metadata": {}, 530 | "output_type": "execute_result" 531 | } 532 | ], 533 | "source": [ 534 | "trainset.data" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 9, 540 | "metadata": {}, 541 | "outputs": [ 542 | { 543 | "data": { 544 | "text/plain": [ 545 | "(50000, 32, 32, 3)" 546 | ] 547 | }, 548 | "execution_count": 9, 549 | "metadata": {}, 550 | "output_type": "execute_result" 551 | } 552 | ], 553 | "source": [ 554 | "trainset.data.shape#50000是图片数量,32x32是图片大小,3是通道数量RGB" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 10, 560 | "metadata": { 561 | "scrolled": false 562 | }, 563 | "outputs": [ 564 | { 565 | "data": { 566 | "text/plain": [ 567 | "168" 568 | ] 569 | }, 570 | "execution_count": 10, 571 | "metadata": {}, 572 | "output_type": "execute_result" 573 | } 574 | ], 575 | "source": [ 576 | "trainset.data[10000][31][31][2]" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 11, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "name": "stdout", 586 | "output_type": "stream", 587 | "text": [ 588 | "\n", 589 | "\n" 590 | ] 591 | } 592 | ], 593 | "source": [ 594 | "#查看数据类型\n", 595 | "print(type(trainset.data))\n", 596 | "print(type(trainset))" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 12, 602 | "metadata": {}, 603 | "outputs": [ 604 | { 605 | "data": { 606 | "text/plain": [ 607 | "tensor(0.3176)" 608 | ] 609 | }, 610 | "execution_count": 12, 611 | "metadata": {}, 612 | "output_type": "execute_result" 613 | } 614 | ], 615 | "source": [ 616 | "trainset[10000][0][2][31][31]#10000为图片数量,0表示取前面的数据,2表示通道数RGB,32*32表示图片大小" 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "metadata": {}, 622 | "source": [ 623 | "### 总结:\n", 624 | " \n", 625 | " `trainset.data.shape`是标准的numpy.ndarray类型,其中50000是图片数量,32x32是图片大小,3是通道数量RGB;\n", 626 | " `trainset`是标准的??类型,其中50000为图片数量,0表示取前面的数据,2表示3通道数RGB,32*32表示图片大小\n", 627 | " \n", 628 | " ### 3. 
查看图片" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 13, 634 | "metadata": { 635 | "scrolled": true 636 | }, 637 | "outputs": [ 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "" 642 | ] 643 | }, 644 | "execution_count": 13, 645 | "metadata": {}, 646 | "output_type": "execute_result" 647 | }, 648 | { 649 | "data": { 650 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAegklEQVR4nO2da4yc53Xf/+ed++4s98rlLi8SKYlUZMux7NCKEheGk7SBagSQDSSGjcLVByMMihiogfSD4AK1C/SDU9Q2/KFwQVdClML1pbFdC4VQ2xXiOA4CxbSrOyWLoijelrvLy95ndudy+mFGKKU8/2eX3N1Z2s//BxDcfc4+73vmmffMO/P855xj7g4hxK8+2U47IIToDQp2IRJBwS5EIijYhUgEBbsQiaBgFyIR8puZbGYPAvgygByA/+run4/9fZZlnsvngrZCnrtiaAfHy4XwsQCgUilSW39fhdpqqw1quzK3EBzPR3wvFbgtM6O2mCAaXStyTG/zI5bLfD2aEWl2tbF2w34UCwU6p5DntmKxTG1ra01qW1haojZGqVSitlaTn2ultkJt+Rx/zorF8HXc11elc0rF/uD4hYvncO3a1eDi33Swm1kOwH8G8M8AnAfwUzN7wt1fYnNy+RxGx0eCtonxUXquQms5OH5kf/hYAPDr9+yjtvt/413U9sKpKWp7/H/+IDg+Pr6bzjk0MUZtlSJ/QWq1wi9wnfONUxt7IWjXeGDefQ9fj6uRQDp14Ty15YrhwL197ySds2c4so6H7qG20+dmqe2pv/374Hg78iJ2+M5D1HbtyhVqe/bZZ6ltbJRf3wf2DQbHf+O9v0XnHDr4QHD8Dz/6IJ2zmbfx9wM45e6n3X0NwDcAPLSJ4wkhtpHNBPs+AOeu+/18d0wIcQuymc/soc8F/+i9kZkdA3AMALKc9gOF2Ck2E33nARy47vf9AC6+/Y/c/bi7H3X3o1mmYBdip9hM9P0UwGEzO2RmRQAfA/DE1rglhNhqbvptvLs3zexTAL6PjvT2mLu/GJuTGVDJh19f8mhxJ4lssVznO9aLK9xWq/NzZRE/DuweCo5PjoXHAWA4ssTtGt/pvjjHd5gLfVxynLg9vG2yUOKP6xerV6nt6lyN2mpEEgWAvUNh2WhwMCInRWTKUpHbYp8Oc1lYAmyucYm11eJr1YzYYhmkMSl1ZCR8/Yzv3UPnrBXCD9q5mrs5nd3dnwTw5GaOIYToDfoQLUQiKNiFSAQFuxCJoGAXIhEU7EIkwqZ242+UzAyVYlgbiGVeWRZ2c7bOX6tOnOaS0eszP6M2b9WpbX4xLDXtKnEJLVfgx1tbW6W2wWo4OQIAqjmeQHP59XPBcSdrCACFjCcUZbPc/9wKt1X7w2s1kOOZbf39fdRWr/Hnc1c1nAEGAP2VcLZcfZX7HiUir7FMv/Vsy2Qd62sRma+fZDcGv9jaQXd2IRJBwS5EIijYhUgEBbsQiaBgFyIRerob7+5oNUk9uT5evml4z53B8UguA7IS36F94TVePujqpbPU1liZD45PXwzXpgOAsQHux9Awr6t22zBPGMl4/gmytbCxmY/UklvkO7j1Ol/ktSK/V8yuhXfP+65M0zkDg7wsVW2F13ezSJoJS2zKIhkjMZu3b65uYOy2Wq+FVZlGjSsGo5PhOnkZSfxZxwUhxK8SCnYhEkHBLkQiKNiFSAQFuxCJoGAXIhF6Kr2ZZcjlwpJBtbqLT8zCNddW61zyqi1co7a2czkpV+KJGs1GeLmWoq2VeL24XaROGwA08zwpZCk3zOcN7Q+OV/q5BIhqRPJqc8kLTf64G41wss7aGr+/zC/zrjWxGnQ1IvMBQK0etmXG/bCYhha1RQrARRJUimUmo/GTzV0Ody5qNfm1rTu7EImgYBciERTsQiSCgl2IRFCwC5EICnYhEmFT0puZnQGwCKAFoOnuR2N/7w40Sbm2K7M8G6o2HW6FlEUktFJEBRlo85pxE3v3UtvSUljyqpGsJQCYmAjLKgAweWiA2oYjGWCzBV6vb7UZbiVU5C6iXOStptoDXA4r1CJyWH943mKb318uz3Mny2WeIbiyyq+D5dXwc53PuMQaV95i1shFZ3ytFmphXy5HshEn+6M5dkG2Qmf/HXe/vAXHEUJsI3obL0QibDbYHcAPzOxnZnZsKxwSQmwPm30b/353v2hm4wB+aGYvu/uPr/+D7ovAMQDI5/hXR4UQ28um7uzufrH7/wyA7wK4P/A3x939qLsfzSnYhdgxbjrYzazfzAbe/BnA7wN4YascE0JsLZt5G78HwHe7bW3yAP67u//v2IRcLsPArnA2VyvSdmlpMZx55atLdM5IpNCjRR52LVxTEgCwWmeSHT9ecReX3kpN/pj3LPDHtnI3z4j7u5k3guNZg8uNvxnJONw/xf3IXaEm+ERYRls03taqL5IRtxZpXxXpQoWmh49ZyGL3uYiEFmn/FBfD+LvapXY4+/G1aZ7N1zcYXt9Wm1cjvelgd/fTAN59s/OFEL1F0psQiaBgFyIRFOxCJIKCXYhEULALkQg9LThZqlRw5J33Bm3zl3kuzcylS8Hx/Xt5f7hqlctT52a4vra8tMxtC2FbdYDLa/U1nqE2vRQplFiYobbzy1zGuZALHzM3yn2czXhm264zF7jtasT/PeHsOx8ZpXNyAzwTrVLmxTnrda69tdn9LNITLSaitZ1LWx6R5WJy3iJZ/rMzvKDq6GDYttbg/unOLkQiKNiFSAQFuxCJoGAXIhEU7EIkQk9341vuWGiGdwsXFvku+PhoeGd33759dE6uwHd2Ly/zpJA2eB20/btvC88p8R3QqcUWtS2Oj1DbqUG+ez5XD7dWAoCRXaSGXomrE6drPNll6dfC7aQAYLzAa78tNsOKwfgy3znPLb5EbZPD/FztSF27xlr4ubESnxOrM+dt/nzGaEd26seGwtfBvtvuonOGdoUTivI5HtK6swuRCAp2IRJBwS5EIijYhUgEBbsQiaBgFyIReiq9uTvqzfC3/ku7eB209mpYhjp9gSfP1BtcqinkuAwy3s+XZITIYfMlnphyuXGV2qZeP0VtgwM88WOwf47achPhJJ++g3fTOcXhsKQIAEtDk9T2PK5R28RqWPIaKnAJ8MwLr1Hb/DKXRPcwuRFAjshy+Ug7pmgiTDs
iy0XktVjaTe1aOOnpXJ0nKE0Mvof4oEQYIZJHwS5EIijYhUgEBbsQiaBgFyIRFOxCJMK60puZPQbgDwDMuPu93bERAN8EcBDAGQAfdXeuw3RpNNYwPRWuaVaKZOsM9odbOUUUL2R5bszn+WtcKaKRrNTD2WFXl3itsEN7eZbUgTu5vLb7tj3UNjq4m9qWr4TlyItXfkLnzLUmqG2gwm17ylxqGiStrX5xgWc3Ttx5O7WNFoeprTXPJSrWS9Qi5eJi8lqrxTMmo0Suq7nL4ZiYvcKzAI8cORAcb0b828id/S8APPi2sUcAPOXuhwE81f1dCHELs26wd/utv/2bIQ8BeLz78+MAPrzFfgkhtpib/cy+x92nAKD7//jWuSSE2A62/euyZnYMwDEAKBR59RghxPZys3f2aTObBIDu/7Sjgbsfd/ej7n40V+jpV/GFENdxs8H+BICHuz8/DOB7W+OOEGK72Ij09nUAHwQwZmbnAXwWwOcBfMvMPgngLIA/2sjJmo0mrkyHWznlIu14msPhjLjJcb5V0FfiHxlW6lwOW6jzbDnzsKxRHODHKx3g8tr8Pt4aan6Ir8dUibddOnxvOIPtfQWeVbg0d4XasHKemqoFXjDzyZ+eC46fnObZawff9wC13XXoHmq78vwvqG36bPj5pG2hALTbPHOsSQqmAkBEsYtmveUsfF01Gyt0DlOWLXKidYPd3T9OTL+33lwhxK2DvkEnRCIo2IVIBAW7EImgYBciERTsQiRCT7/lks/nMToczqLyiNxRrgwExxtepnP6y7yw4cgAf9gF436sLIV1jdrwFPfjNq6F5Ed4P7dF59lhi22e5XWa9DabzbikuHuCZ5SN93N5bWmRF/wcrYefmyPjXG6ca/HHfD4ilw5GjlmdJNfISuQ+F8uI4ya0YoeMFIIskW+W3nGQFwKdvXg2ON5c49Km7uxCJIKCXYhEULALkQgKdiESQcEuRCIo2IVIhJ5Kb6Mju/Ev/8WxG56Xy5HXpFiWURapRhkxFfP8oLVwDUW8vPg3dM6uvbPU1urnclKxEJYbAaAQyeirlorh44HLWi3ncs1anmftXWnxbLm9d4f9n8/xTK4LL56gtsZF3jNvZGQ/tZVHw8VKByJr721etNEjsqdFhLnFpXCxUgDIk2u1VObS8muvvhgcX10lFyl0ZxciGRTsQiSCgl2IRFCwC5EICnYhEqGnu/GtVhML86QQbWRnPSO78RYpuGXGX8eakZPl8nyXdm45vNO5UuI7xUN5noDS5KfCyAivr1chiUEAUPLw4x6u8Bp0fWVe064JXoMul/Ed/qX5oeB4u8UvudoKVy58kK9HocqPOb8c3iG/vMKTeA5EeoA1SR1CIJ7M1WjyecVS2H+LHI/v/PNrW3d2IRJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJMJG2j89BuAPAMy4+73dsc8B+GMAb2oln3H3J9c9mzvarbAU5ZHeOU7kJDifY8Zt7UgzHmtx2/zcXHB8pcolqHJkhbOMa2+tZS6Hwbj0VsyHE2EQqeHWZusLIJ/jiTC3jfI6f/O50eD4ygpPCJkOlyfszIsUfxuvcnlz+Wo4yWSqxev/tZauUVvW4klD7ajsFalFmBHpLSJHtxrk+YzExEbu7H8B4MHA+Jfc/b7uv/UDXQixo6wb7O7+YwD8WyNCiF8KNvOZ/VNm9pyZPWZmvBaxEOKW4GaD/SsA7gRwH4ApAF9gf2hmx8zshJmdWFnhn22FENvLTQW7u0+7e8s7le+/CuD+yN8ed/ej7n60ry9cNUQIsf3cVLCb2eR1v34EwAtb444QYrvYiPT2dQAfBDBmZucBfBbAB83sPnRSbM4A+JONnKzVamJuLly3zKMyGst6i5ws48crRrKJkOOSzKWr08HxlUjdr772bmrbO8LnLUXW49oyyRwE0CCLkh/YQ+fkWY0/AKVskNqG+rlW1lcMv4t7Y+kcnbO0FpY2AeDiRZ6ltn98L7VVyuHWUI3IczYdyVAb7OMFDIsVfu14JMORXcjlMj8ey27LsoiMGjla55DuHw8MP7rePCHErYW+QSdEIijYhUgEBbsQiaBgFyIRFOxCJEJPC05muQz9VfLFmqj0RjS2SFFJi0gQZXAdZM14a6XFi2FJ5tIMz+Q6/waXaoZ28eWv8s4/6N/Nv52ckTWJZV15ROFpFbn/NecSZrMdnlfN8Uy5UoM/Z5ev8m9fnppZoLY794az9hpZxPe+cMYeAIzt4VLqrgovmDl3+nVqQ0aebCMZjOBSdSRRTnd2IVJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJEJPpTeDoUAKIkaUN368mPQGLhlZxrOaapGMuPlm2Laa8aKMr51bobbJSZ7fPzQcK8C5Rm25Yljaaue5lreS8SKKmXHJa355kdoqCPd6u73Ki2WWV/nzshqRDi8scB9z+XC2XCHjfsxlPEPQ7zlKbX37zlDbwhTvmWfkuoplgt4MurMLkQgKdiESQcEuRCIo2IVIBAW7EInQ0914AGCb3bH2T4xYi6csYmtFkiCWmnz3HP3hjJFKke/enp26wG1n+Wtto8Z3yCuRtkvtgfAxWwW+071cn6e2Zpv3Bxmv8vp0xUZ49zxb5klItsYfc9/QCLUtRVp2nbsYrteXcSEBWYWfa/7lV6htrMJrAw4N8+cMi6T9U6TIIrNFyzJGbEKIXyEU7EIkgoJdiERQsAuRCAp2IRJBwS5EImyk/dMBAH8JYAJAG8Bxd/+ymY0A+CaAg+i0gPqou1+LHcsBOKlbdlNf+o/Ia+1IQksukiSzVKtRW60dbhlUrO7ix5viksupVxrUZk2e7NK/yuWwoXD+CUr9vLZeIc9bIeUjbYvaS1wqqzXD95E3XuUJITNTXA+zPi7zVZp8HUeKYT/OTJ+lc6pl/rgKOb72MyuXqC2/HJEVSdLQaqQmX6EQfj43W4OuCeDP3P0eAA8A+FMzeweARwA85e6HATzV/V0IcYuybrC7+5S7/7z78yKAkwD2AXgIwOPdP3scwIe3y0khxOa5oc/sZnYQwHsAPA1gj7tPAZ0XBADjW+2cEGLr2HCwm1kVwLcBfNrdeaHufzzvmJmdMLMTtZXIV1GFENvKhoLdzAroBPrX3P073eFpM5vs2icBBL8Y7O7H3f2oux+t9EW+HyyE2FbWDXbrfOP+UQAn3f2L15meAPBw9+eHAXxv690TQmwVG8l6ez+ATwB43sye6Y59BsDnAXzLzD4J4CyAP1rvQJllKObDtb+i0hvr/hTNCuJZXrkCf9hrkQyqfCVcx83zXNZqRV5PT786R22FFpcHD98zQW1NC0uOhUX+EaqU5zJlPtIOq93gctLccrgl1uUzXLqyNpcwByIZdpOjFWobJ9mU52qv0TlLrXDdOgB45RJfx0qdX3OHivyxVUvhbMpWLDxJey33yPXLj/bmZP8JeObc7603Xwhxa6Bv0AmRCAp2IRJBwS5EIijYhUgEBbsQidDTgpNtd6y1wxlKsSw1JgVYxmWGXKTy3prxVK7lJrcNjIS/EVwY5K2EfI7La4
MZb/909qUpanv95eeo7cjdYR9vH+fSz+5d3I9qX1gWAoD2aqRg5kpY8ipG5LViMdIaKpJZWCGSKABcfe1ccHz1algaBIB2OdyiDADqsey1gTFqm9h/J7WNlMPtw6JqNJGd8xEZWHd2IRJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJEJPpTd3x+pauLhhVHpjfa0iWW9Zxl/H1po8O2lxjdsKfbuD441Iply1MExtv3nvHdyPMZ559f2/+T61/d1PXg+OvzTI5amhAV5nYKCPz8vHpKFW+PlcbHFps3mA98wbq/B1LNZ5ocrzZ8O99hZnuSSai8iNA3u4j0fuOkJt+ydv4+erha85i2hvuVw4dHM5ni2pO7sQiaBgFyIRFOxCJIKCXYhEULALkQg93Y2HGSwXTjLIRxrXFLLwDmO7zXd2LeMPbZ4oAgAwX+OJDvWZ8A55g7S0AoDxSC25+QVez2z2Kt+Nz5d4zbVCM3y+VoHPOXdtmdowy6uGW0RBgYefm0aFJ2ocPMITYfqKXBVor3IfFxvhNlq53bymXSyLqrSLr+PAGPe/nfEWVbtIIk8l8jyzXfe8duOFEAp2IRJBwS5EIijYhUgEBbsQiaBgFyIR1pXezOwAgL8EMAGgDeC4u3/ZzD4H4I8BzHb/9DPu/mTsWO1WG0ukLVCjxRNQagjLOEvNsKwCAI02l08uzHGpZnatRm1OJLtiRKqZbnIp79wrL/FzLfDkDnO+Vp4Lr1V7jctklYhcU3d+ibQKsVZZYYl1eGKSzmkWeALK1GUuRY4P8HlD+w4Gx1vlK3ROPpJENTbK68y9cuo0te2+453cNjAUHM8KvBaeETkakeSwjejsTQB/5u4/N7MBAD8zsx92bV9y9/+0gWMIIXaYjfR6mwIw1f150cxOAti33Y4JIbaWG/rMbmYHAbwHwNPdoU+Z2XNm9piZ8YRjIcSOs+FgN7MqgG8D+LS7LwD4CoA7AdyHzp3/C2TeMTM7YWYn6jX+9VAhxPayoWA3swI6gf41d/8OALj7tLu33L0N4KsA7g/Ndffj7n7U3Y+WK7wiihBie1k32K1T++lRACfd/YvXjV+/rfoRAC9svXtCiK1iI7vx7wfwCQDPm9kz3bHPAPi4md0HwAGcAfAn6x2oXK7gnb/2rqDtSoNnXp2cOR8cn5mbDY4DwFqTZ9Fda/Ost1yeSxcD5XAWUqvO5bXyIN/KmNjL65LtjkheHpEpFy0sR3pErrt8jUuRc5H7Qb3EJa+RyXCttnccOUznTF3mLZnOnA9fAwCwy3hrqH0j4b3k1ctcYu3P8Qy7qoVbNQHA5Xl+PXqJz9tz4EBwvGh87Vmbp0KRy3Ub2Y3/CcLt1qKauhDi1kLfoBMiERTsQiSCgl2IRFCwC5EICnYhEqGnBSer1X789m8/ELStksKAAPBbq2GpbCGSoVZr8AJ/a5FsuflF3haoTtpGlSr9dE61ygsbVo0XX7RlLg/Wa9zm5fAxFxr824unL12itvnI/eDSIs/MGxwKP+6xPr4eZ66cobYDu3gxx3dN7OW2g+GWTIX3fYDOqUQyx4p9vAhkO5L9ONbP5cHJgbCtXOTXR7kS9qOPFK8EdGcXIhkU7EIkgoJdiERQsAuRCAp2IRJBwS5EIvRUejM4ChYuiFgp82ydPcNhacJy/LXKncsghYhE0m5zya7RJP3LIufKgjlEHSzjNt7FDrBINlRGXr9bkb54KzGZMnI/aKxyCRNrYVs7xwtffvDgXdRmkX56B4nMBwDj/eGsw0qJX2+FjGdMNiJrhRKXyvKxtnjkuZm9xoti/u3/+fvg+MLcVTpHd3YhEkHBLkQiKNiFSAQFuxCJoGAXIhEU7EIkQk+lt3q9jldeDheh7YtkExUKYUnDIj25CvlI4T1yPADIRY7Jivzl83wZc5E+aohIaIVIxhPzAwCMiHaZczmpatxWLHD/c308E42piq0WlwCbkUKa89e4pFRs8IKfeQvbcuCP+eRLJ6ntRz/6EbUduuMQtR2+ixfaXG2EH/f0LC9geY2sR2x9dWcXIhEU7EIkgoJdiERQsAuRCAp2IRJh3d14MysD+DGAUvfv/8rdP2tmhwB8A8AIgJ8D+IS7RzIjAMDhJKFhaYm3/on4Rm1ZxneRY/OixyQ76/nIjntkwz1OZPc8pkIw/9fW+E63RV7z+6u8bVF/P6+9x/xgSR8A0CTJMwBw9Qrfme6P1F2rLYfr5LXaPDPl0uxlahsa201tFlGApq9cozaQhK5SZH3vffd9wfFKpHnqRi7FVQC/6+7vRqc984Nm9gCAPwfwJXc/DOAagE9u4FhCiB1i3WD3Dm/edgvdfw7gdwH8VXf8cQAf3hYPhRBbwkb7s+e6HVxnAPwQwGsA5vz/twY9DyDcLlMIcUuwoWB395a73wdgP4D7AdwT+rPQXDM7ZmYnzOzEYqTOuBBie7mh7SN3nwPwIwAPABgyszc3+PYDuEjmHHf3o+5+dGCAf71SCLG9rBvsZrbbzIa6P1cA/FMAJwH8NYA/7P7ZwwC+t11OCiE2z0YSYSYBPG5mOXReHL7l7v/LzF4C8A0z+w8A/i+AR9c7UJZlqJC2Ne2IFMLgIllcQmPyX3ciNfGScfx43oqdK+ZGxBiR5ZipEKnXF5MpW5Ekk8V5LpWx5zP2nMG5LFeO1HezyPovLcwHx9t8CTE4xFs1jY6OUFvsumrHnjOSlBObk8vIuSJJTesGu7s/B+A9gfHT6Hx+F0L8EqBv0AmRCAp2IRJBwS5EIijYhUgEBbsQiWAe2d7f8pOZzQJ4o/vrGACeXtQ75MdbkR9v5ZfNj9vdPZia19Ngf8uJzU64+9EdObn8kB8J+qG38UIkgoJdiETYyWA/voPnvh758Vbkx1v5lfFjxz6zCyF6i97GC5EIOxLsZvagmb1iZqfM7JGd8KHrxxkze97MnjGzEz0872NmNmNmL1w3NmJmPzSzV7v/D++QH58zswvdNXnGzD7UAz8OmNlfm9lJM3vRzP51d7ynaxLxo6drYmZlM/sHM3u268e/744fMrOnu+vxTTPjFS5DuHtP/wHIoVPW6g4ARQDPAnhHr/3o+nIGwNgOnPcDAN4L4IXrxv4jgEe6Pz8C4M93yI/PAfg3PV6PSQDv7f48AOAXAN7R6zWJ+NHTNUEnE7ja/bkA4Gl0CsZ8C8DHuuP/BcC/upHj7sSd/X4Ap9z9tHdKT38DwEM74MeO4e4/BvD2znwPoVO4E+hRAU/iR89x9yl3/3n350V0iqPsQ4/XJOJHT/EOW17kdSeCfR+Ac9f9vpPFKh3AD8zsZ2Z2bId8eJM97j4FdC46AOM76MunzOy57tv8bf84cT1mdhCd+glPYwfX5G1+AD1ek+0o8roTwR4qVrJTksD73f29AP45gD81sw/skB+3El8BcCc6PQKmAHyhVyc2syqAbwP4tLsv9Oq8G/Cj52vimyjyytiJYD8P4MB1v9NilduNu1/s/j8D4LvY2co702Y2CQDd/2d2wgl3n+5eaG0AX0WP1sTMCugE2Nfc/Tvd4Z6vSciPnVqT7rlvuMgrY
yeC/acADnd3FosAPgbgiV47YWb9Zjbw5s8Afh/AC/FZ28oT6BTuBHawgOebwdXlI+jBmlin4N6jAE66+xevM/V0TZgfvV6TbSvy2qsdxrftNn4InZ3O1wD82x3y4Q50lIBnAbzYSz8AfB2dt4MNdN7pfBLAKICnALza/X9kh/z4bwCeB/AcOsE22QM//gk6b0mfA/BM99+Her0mET96uiYAfh2dIq7PofPC8u+uu2b/AcApAP8DQOlGjqtv0AmRCPoGnRCJoGAXIhEU7EIkgoJdiERQsAuRCAp2IRJBwS5EIijYhUiE/wdumdOUbugNxwAAAABJRU5ErkJggg==\n", 651 | "text/plain": [ 652 | "
" 653 | ] 654 | }, 655 | "metadata": { 656 | "needs_background": "light" 657 | }, 658 | "output_type": "display_data" 659 | } 660 | ], 661 | "source": [ 662 | "import numpy as np\n", 663 | "import matplotlib.pyplot as plt\n", 664 | "plt.imshow(trainset.data[19])" 665 | ] 666 | }, 667 | { 668 | "cell_type": "markdown", 669 | "metadata": {}, 670 | "source": [ 671 | "### np.ndarray转为torch.Tensor\n", 672 | "\n", 673 | "在深度学习中,原始图像需要转换为深度学习框架自定义的数据格式,在pytorch中,需要转为`torch.Tensor`。\n", 674 | "pytorch提供了`torch.Tensor` 与`numpy.ndarray`转换为接口:\n", 675 | "\n", 676 | "方法名|作用\n", 677 | "--|--\n", 678 | "`torch.from_numpy(xxx)`|`numpy.ndarray`转为torch.Tensor\n", 679 | "`tensor1.numpy()`|获取tensor1对象的numpy格式数据\n", 680 | "\n", 681 | "`torch.Tensor` 高维矩阵的表示: N x C x H x W\n", 682 | "\n", 683 | "`numpy.ndarray` 高维矩阵的表示:N x H x W x C\n", 684 | "\n", 685 | "因此在两者转换的时候需要使用`numpy.transpose( )` 方法 。\n" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 14, 691 | "metadata": { 692 | "scrolled": true 693 | }, 694 | "outputs": [ 695 | { 696 | "data": { 697 | "text/plain": [ 698 | "torch.Size([3, 32, 32])" 699 | ] 700 | }, 701 | "execution_count": 14, 702 | "metadata": {}, 703 | "output_type": "execute_result" 704 | } 705 | ], 706 | "source": [ 707 | "# numpy image: H x W x C\n", 708 | "# torch image: C x H x W\n", 709 | "# np.transpose( xxx, (2, 0, 1)) # 将 H x W x C 转化为 C x H x W\n", 710 | "tensor_skimage = torch.from_numpy(np.transpose(trainset.data[19], (2, 0, 1)))\n", 711 | "tensor_skimage.size()" 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": {}, 717 | "source": [ 718 | "### torch.Tensor转numpy.ndarray" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": 15, 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "# np.transpose( xxx, (2, 0, 1)) # 将 C x H x W 转化为 H x W x C\n", 728 | "img_skimage_2 = np.transpose(trainset[19][0].numpy(), (1,2,0))" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": 16, 734 | "metadata": {}, 735 | "outputs": [ 736 | { 737 | "name": "stderr", 738 | "output_type": "stream", 739 | "text": [ 740 | "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" 741 | ] 742 | }, 743 | { 744 | "data": { 745 | "text/plain": [ 746 | "" 747 | ] 748 | }, 749 | "execution_count": 16, 750 | "metadata": {}, 751 | "output_type": "execute_result" 752 | }, 753 | { 754 | "data": { 755 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAWzUlEQVR4nO3df2xcVXYH8O8ZTyaO4xjHGMfOL5xAYKE0hNSNWEC7aWhXKUUFVrsUpLJUpZvdapGKtPyRUqmwalXtVgXEHyuq0KSbXbH8KD8jiigphQVUlSWwIQlJCCEkxtiJE2xjjJNMxnP6x7xIDvvO8Xh+vLG5348Uxb7H983185x5nnd87xVVBRF9+aVqPQAiSgaTnSgQTHaiQDDZiQLBZCcKBJOdKBDpcjqLyFoADwCoA/BvqvrjCb4+sTpfnRNrO8uODQzbsZNTpErpvUJbMW/ozbPts3UynzdjI8edo0p88+wGe/SzZtabsaY5c83Y8OfHzdixYwNmzDJzlj2OkydO2h1LLGOLca7OaW0x+8yZMy+2vf9oL4aHh2KPWHKyi0gdgJ8C+CMAPQDeFJEtqrq71GNWUpMT+6uv2bFHX7Bj758qeTgVNduJNRjtOafPn65oNGP7R06YsdfecZ74xjNr+fJZZpflSy82Y2vWfNOM/ff/2k+5hzb+woxZFp9/nhk7sG+fGRs7WdoTJGOcqz/71p+YfdasviO2/c6//XOzTzm/xq8CsF9VD6hqFsCjAK4r43hEVEXlJPsCAB+N+7wnaiOiKaic9+xx7wt+602LiKwDsK6MxyGiCign2XsALBr3+UIAvV/8IlXdAGADkOwNOiI6Uzm/xr8JYJmILBGRDICbAGypzLCIqNJKvrKrak5EbgfwXyhUujap6rsVG1mZBp3YwJAdK6sWmZDPnFhTR3z7Eed7fm74UzN29GBRQ/otZ3fGty9e3G72aWpwSm+Ndsy6m12qbDZrxsaylS/JnH9+fH3lkq7lZp+RhvhvOp8y6ngo87mtqs8DeL6cYxBRMvgXdESBYLITBYLJThQIJjtRIJjsRIGYDpWmivvpa7UeQXm8GX0f903+eA0HnaBX53O098e3z89YU3WAtrZWMzY0YM9eW9jeVvS4ylaFPwvrP/Z5bPvQiF0CzLfFX6fHnMfhlZ0oEEx2okAw2YkCwWQnCgSTnSgQQd6Nn+68H5p3N9YyXOIdd8/ukfj21n07zT7zF9vLUg0cO2bGUrDXyStFOp/sNXDIKDSMOjO2LlgZv/DajDq7VsMrO1EgmOxEgWCyEwWCyU4UCCY7USCY7ESBYOltGnL2YbHZS5MB7fFbCQEA+o6U8mhQowY4MmKXhrr7jXod/DXoBkYmv8WTJ1XZSt6EGpvj29NpeyAH974d237yxKjZh1d2okAw2YkCwWQnCgSTnSgQTHaiQDDZiQJRVulNRA6isErZGICcqnZVYlA0gZlOzKrLOWunDY445bVZzmMdd2JnxTf35uyn3N7uYTPW3GzUpwAcG7bLTaVIuPKGwYEZse17e+1r8co2Y5TOz7kSdfY/UFV7/iERTQn8NZ4oEOUmuwJ4UUTeEpF1lRgQEVVHub/GX6mqvSLSBmCriOxV1VfHf0H0IsAXAqIaK+vKrqq90f/9AJ4GsCrmazaoahdv3hHVVsnJLiKzRWTO6Y8BfAPArkoNjIgqq5xf4+cBeFpETh/nl6r6QkVGRVjixD78ihO0Xm69lSjbndj7dsjbhipvHLMvZc/Zax2xn44jaXvW2zF7XcYSJXzfWuNP1os77Nl8rYvjy5TZMfsHXXKyq+oBAJeW2p+IksXSG1EgmOxEgWCyEwWCyU4UCCY7USC44OQU9aEX7HdimUm2AyU/C9x95d4z2s+1u2TmN5ixlma7Pjg0VOnaW9Lz3uId3dVjxrYtjo99fjxr9uGVnSgQTHaiQDDZiQLBZCcKBJOdKBC8Gz8deTefFxrtTU4fb/ckb7aLffMcOGG0O5WEt3r/04wtWWqvM5dz1rUrRT7pu/Fzz49tXnLVWrNL58LFse17ZtiLBvLKThQIJjtRIJjsRIFgshMFgslOFAgmO1EgWHqbjrxtl6w1437f6bPUmZ3S6W2tdNQOWTs5NTi1vNfsqTUfvvey3W/2IjtWkoRLb4PxCwd++NyI2WXF4r+Mbc/nT5l9eGUnCgSTnSgQTHaiQDDZiQLBZCcKBJOdKBATlt5EZBOAawH0q+olUVsLgMcAdAI4COBGVR2s3jAD5JTKzo2f8AQA6N8X337caAcAZA/ZsRanX7MTs2a9veWsXPc7zvEaz7Jj3XaJCp87xzTkckmvQWeUMNUubXZ3XxHbns1aJ764K/vPAHxxrt16AC+p6jIAL0WfE9EUNmGyR/utf3HG83UANkcfbwZwfYXHRUQVVup79nmq2gcA0f9tlRsSEVVD1f9cVkTWAVhX7cchIl+pV/YjItIBANH/5mJDqrpBVbtUtavExyKiCig12bcAuDX6+FYAz1ZmOERULcWU3h4BsBpAq4j0ALgbwI8BPC4itwHoBvDtag7yS+tSJ9Zphw45i0fOvSm+fZWzOOThg844jtmhdueYv3rIOablKnv23dfXfNOM7fvlc2as79nJV4Rzudyk+ySt3tjOS8TuM2Gyq+rNRujqIsZERFME/4KOKBBMdqJAMNmJAsFkJwoEk50oEFxwstoucmLxW3wVeBOvnMrQoDEB7FfOT3ruCjt2ifOH0Id77Zi5H90Op0/W3gju/4ayZmzxJc7UvFeM9k+dGlXSk95KsHvb67HtJ0btGYC8shMFgslOFAgmO1EgmOxEgWCyEwWCyU4UCFHVxB4snWnQ5rYLJt0vkzHqRk6JJJU2pgUBgBNqrLcPOmCUk472vmEf8EY75K7v48wogzPrzY0ZZjtluQs67dj2vXYsbYzj1AvOQJ5xYkud2PnOXm89PfHtQ/bzfl7uHDN25ENnf7spQlVj64q8shMFgslOFAgmO1EgmOxEgWCyEwUi0YkwY6eO45OPd02+oxjDTDmvVSnvW3Nu49fbEy7w2an49hnOQ9U7MeehvEkydc68jybjW1vq9Gl1tnGyNxMCrCIJAJzsNgLe9+xt1eQtC9funOR+4667s7bekaapf8e9FLyyEwWCyU4UCCY7USCY7ESBYLITBYLJThSIYrZ/2gTgWgD9qnpJ1HYPgO8COF2juEtVny/uIccmP0o1+niHKuFhAABGda3kPl7pzTv79nJs7kt0o/V41ppwAHJeJdKZNHSVM6ep2+j3vlPywoVOzCm9zWgfNWOn9tfFB7LOE+SwM45prJgr+88ArI1pv19VV0T/ikx0IqqVCZNdVV8FMJDAWIioisp5z367iOwQkU0iMrdiIyKiqig12R8EcB6AFQD6ANxrfaGIrBORbSKyrcTHIqIKKCnZVfWIqo6pah7AQwBWOV+7QVW7VLWr1EESUflKSnYR6Rj36Q0ASpjdQkRJKqb09giA1QBaRaQHwN0AVovICgAK4CCA71VxjNPaLKdk1OnMbDvslMMGnbLcR8
bLd/18u49XXmtyniGdbWebsdbG+AX23j+8xz6gvXMR8LEdOrXd2Yeq2Zju9/knzoN9OU2Y7Kp6c0zzxiqMhYiqiH9BRxQIJjtRIJjsRIFgshMFgslOFIhEF5wM0fFXneBCO9TuLALZdrEdSxsv3+7ym86WUdlGOzaQt+uKJ3JGPS9jzEIDgNESpyrucrYw6zLGEbtBUsSbqTi/w4590Od0rD1e2YkCwWQnCgSTnSgQTHaiQDDZiQLBZCcKhKg6ZYtKP5hIcg82DdTdYsc6l9qxfKsdazRKZXmnnOSV19JOLOfMUrO2lssMzzL7vPad4/YBvaqcfUh7z7y9Tp/6RXbs+9+3Yz2v2LFHtjoPWFmqGltY5JWdKBBMdqJAMNmJAsFkJwoEk50oEJwIU0Njr9uxD5xtOWY72y7ljLXmsg12n35nayh11tBb1m7HGo0dmdL9Jd5x92Sd2M4SjnfKOfnPbLFjLVN73VVe2YkCwWQnCgSTnSgQTHaiQDDZiQLBZCcKxIQTYURkEYCfA2gHkAewQVUfEJEWAI8B6ERhC6gbVXVwgmNxIkyxnKXOzMkdQOGnEWNO/G5MAPxJMnmnrNXubBvVfCK+/TfP233g7AxVshlG+6kqPNYUUc5EmByAH6rqRQAuB/ADEbkYwHoAL6nqMgAvRZ8T0RQ1YbKrap+qvh19/BkKr78LAFwHYHP0ZZsBXF+tQRJR+Sb1nl1EOgFcBuANAPNUtQ8ovCAAcH5RJKJaK/rPZUWkEcCTAO5Q1WERb+HtM/qtA7CutOERUaUUdWUXkRkoJPrDqvpU1HxERDqieAeA2F3DVXWDqnapalclBkxEpZkw2aVwCd8IYI+q3jcutAXArdHHtwJ4tvLDI6JKKab0dhWA11CYP5SPmu9C4X374wAWA+gG8G1VdaYLASJ1ai8YljfaPd5rlVH7ASAz7VqTnvy8hHEk7Fwn1hnfPMt5w9bklN7qnVOcN2a2AcBI7O95wGAps9AmssSJWbP2PqrCOKYIq/Q24Xt2VX0d9s5YV5czKCJKDv+CjigQTHaiQDDZiQLBZCcKBJOdKBA12P7JurGf5IQ4ayoUUNJ0KG/7IWd9xSnjbCfkbDWVG7Zjo8fi20+VOtvMO8cXOz/Pt6bD9LZzSuhjXac/geopbv9EFDImO1EgmOxEgWCyEwWCyU4UCCY7USBqsNfbVFhzssLlGG+vMc9Mp2R0MsGS0SdOyIklqmWeHRs6ktw4SlR30VfNWGYgfoZmKm/PBM1k4qcqDvePmH14ZScKBJOdKBBMdqJAMNmJAsFkJwpEDe7Gl6KUyTN1TmysjLFU8HBJ3nGf7hqb7djwFLkbf+5ZZmgsbS/Y19YS/721NLWYfTKZ+L23dn26z+zDKztRIJjsRIFgshMFgslOFAgmO1EgmOxEgZiw9CYiiwD8HEA7Cns0bVDVB0TkHgDfBXA0+tK7VPX56gyzlMkzFS6vUW29954d65giiwAe+tSOLbAX7GuY3xnbnm5oNPuk0vGlN4hdci6mzp4D8ENVfVtE5gB4S0S2RrH7VfVfijgGEdVYMXu99QHoiz7+TET2AFhQ7YERUWVN6j27iHQCuAyFHVwB4HYR2SEim0RkboXHRkQVVHSyi0gjgCcB3KGqwwAeBHAegBUoXPnvNfqtE5FtIrKtAuMlohIVlewiMgOFRH9YVZ8CAFU9oqpjqpoH8BCAVXF9VXWDqnapalelBk1EkzdhsouIANgIYI+q3jeuvWPcl90AYFflh0dElVLM3fgrAdwCYKeIbI/a7gJws4isQKEudhDA9yY6UKa+ER1Lfi82dmi03+54aE8Rw6Rp4au/a8f2HrZjg0ft2GGnvDYVdhsDgKZ2M7T8iiti2xtTdnrW1zfEtu/fuTW2HSjubvzriD9lVaqpE1E18C/oiALBZCcKBJOdKBBMdqJAMNmJApHogpPz5s3DnXfeERsbHrW3rekdHopt7xkZMPsMjNoL/I2csB+ru/egGRs6Eb9NT1NLm9mnvX2xHUvFl08AINUf/z0DwNCAHcs3xx+zZ/SY2efo9u1mzH2K9PbasU7j+261zwcGdzrjcKywy3mXrb42tr0ha1/nWlJ2rLHVXgQyl7H7faVtoRlbOT8+1txoPz+aW+LH8T9bHjT78MpOFAgmO1EgmOxEgWCyEwWCyU4UCCY7USASLb2lRNGQysbGWprtxfWWL40vTaQy9vDzeft1rMEpkeRydslu9ET82Eedx0o7r6eptB2Lf6SonzMbKm38SLM5+4jHvDKl8xQZHbZLmBiJj+UyObNL9+q1ZiyVt/uttsp8AC5pWxrb3tJkP98a0nkzNuqcKzTZpbJ6e/jIGz+b3Qfsfdv+af19se09Bw+YfXhlJwoEk50oEEx2okAw2YkCwWQnCgSTnSgQiZbehgYHsOWZR2Njrc5sooaG+JJGKm0Pv6HeLq3UG8cDgIxzTGuRv/r6evt4GWNPLgBwSmgNzownaxwAkDKKdum8XU5qT9mxxgZ7/JnW+WbMqipms3YJ8EQ2flYhAHQf2G/GGkftfdTqU/GxDOzv+aknnjJj9/zoH8zYmqu/bsauWXuNGRsejf++d+zebfY5YJyPbPak2YdXdqJAMNmJAsFkJwoEk50oEEx2okBMeDdeROoBvApgZvT1T6jq3SKyBMCjAFoAvA3gFlX15m8AUOSNCQ2HDztb/xhSzlph6bR9F9nr5x7TuLNe79xxd264+5y7514Vwhr/yIh9pzvlPA3a2u1ti9ra7LX3rHFYkz4A4IQxeQYA9u+z70y3tTSbsYH++HXysjl7Zsr23XvNWOeFF5qxlFMB2rHPnqACY0JXk3N+b/rOX8S2b9z872afYq7sJwGsUdVLUdieea2IXA7gJwDuV9VlAAYB3FbEsYioRiZMdi04/ZI7I/qnANYAeCJq3wzg+qqMkIgqotj92euiHVz7AWwF8AGAIVU9/btQD4AF1RkiEVVCUcmuqmOqugLAQgCrAFwU92VxfUVknYhsE5Ft3l9PEVF1TepuvKoOAXgFwOUAmkXk9J2FhQBi74So6gZV7VLVLvdPR4moqiZMdhE5R0Sao49nAfhDAHsAvAzgW9GX3Qrg2WoNkojKV0xhqAPAZhGpQ+HF4XFVfU5EdgN4VET+EcBvAGyc8MHSabQY29bknFKIxXul8kpoVvkv6miG7CXj7OPls95jecNwgk5Zzgo1OOv1eWXKrDPJpLfbLpVZP0/36pK33+Y1O+u7pZzzf7inO7Y9Z59CLO60t2q64ILzzZj3vMp5PzNjUo7XJ5O2zm/su2kARSS7qu4AcFlM+wEU3r8T0TTAv6AjCgSTnSgQTHaiQDDZiQLBZCcKhKjat+or/mAiRwEcij5tBXAssQe3cRxn4jjONN3Gca6qnhMXSDTZz3hgkW2q2lWTB+c4OI4Ax8Ff44kCwWQnCkQtk31DDR97PI7jTBzHmb4046jZe3YiShZ/jScKRE2SXUTWish7IrJfRNbXYgzROA6KyE4R2S4i2xJ83E0i0i8iu8a1tYjIVhF5P/p/bo3GcY+IfBydk+0iY
u9bVLlxLBKRl0Vkj4i8KyJ/E7Unek6ccSR6TkSkXkR+LSLvROP4UdS+RETeiM7HYyIyuQUiVDXRfwDqUFjWaimADIB3AFyc9DiisRwE0FqDx/0agJUAdo1r+2cA66OP1wP4SY3GcQ+AOxM+Hx0AVkYfzwGwD8DFSZ8TZxyJnhMAAqAx+ngGgDdQWDDmcQA3Re3/CuCvJ3PcWlzZVwHYr6oHtLD09KMArqvBOGpGVV8FMPCF5utQWLgTSGgBT2MciVPVPlV9O/r4MxQWR1mAhM+JM45EaUHFF3mtRbIvAPDRuM9ruVilAnhRRN4SkXU1GsNp81S1Dyg86QDYi4ZX3+0isiP6Nb/qbyfGE5FOFNZPeAM1PCdfGAeQ8DmpxiKvtUh2iWmrVUngSlVdCeCPAfxARL5Wo3FMJQ8COA+FPQL6ANyb1AOLSCOAJwHcoar2EjnJjyPxc6JlLPJqqUWy9wBYNO5zc7HKalPV3uj/fgBPo7Yr7xwRkQ4AiP7vr8UgVPVI9ETLA3gICZ0TEZmBQoI9rKqnN0hP/JzEjaNW5yR67Ekv8mqpRbK/CWBZdGcxA+AmAFuSHoSIzBaROac/BvANALv8XlW1BYWFO4EaLuB5OrkiNyCBcyIigsIahntU9b5xoUTPiTWOpM9J1RZ5TeoO4xfuNl6Dwp3ODwD8XY3GsBSFSsA7AN5NchwAHkHh18FTKPymcxuAswG8BOD96P+WGo3jFwB2AtiBQrJ1JDCOq1D4lXQHgO3Rv2uSPifOOBI9JwCWo7CI6w4UXlj+ftxz9tcA9gP4DwAzJ3Nc/gUdUSD4F3REgWCyEwWCyU4UCCY7USCY7ESBYLITBYLJThQIJjtRIP4ft5LQwVArFHAAAAAASUVORK5CYII=\n", 756 | "text/plain": [ 757 | "
" 758 | ] 759 | }, 760 | "metadata": { 761 | "needs_background": "light" 762 | }, 763 | "output_type": "display_data" 764 | } 765 | ], 766 | "source": [ 767 | "plt.imshow(img_skimage_2)" 768 | ] 769 | }, 770 | { 771 | "cell_type": "markdown", 772 | "metadata": {}, 773 | "source": [ 774 | "### 4.定义网络(DenseNet)\n", 775 | "\n", 776 | "前面ResNet通过前层与后层的“短路连接”(Shortcuts),加强了前后层之间的信息流通,在一定程度上缓解了梯度消失现象,从而**可以将神经网络搭建得很深**。更进一步,DenseNet最大化了这种前后层信息交流,通过建立**前面所有层与后面层的密集连接**,实现了特征在通道维度上的复用,使其可以在参数与计算量更少的情况下实现比ResNet更优的性能。\n", 777 | "\n", 778 | "DenseNet的网络架构如下图所示,网络由多个Dense Block与中间的卷积池化组成,核心就在Dense Block中。Dense Block中的黑点代表一个卷积层,其中的多条黑线代表数据的流动,每一层的输入由前面的所有卷积层的输出组成。注意这里使用了通道拼接(Concatnate)操作,而非ResNet的逐元素相加操作。\n", 779 | "\n", 780 | "\"结构\"\n", 781 | "\n", 782 | "\n", 783 | "\n", 784 | " \n", 785 | " 具体的Block实现细节如下图所示,每一个Block由若干个Bottleneck的卷积层组成,对应上面图中的黑点。Bottleneck由BN、ReLU、1×1卷积、BN、ReLU、3×3卷积的顺序构成,也被称为DenseNet-B结构。其中1x1 Conv得到 4k 个特征图它起到的作用是降低特征数量,从而提升计算效率。\n", 786 | " \n", 787 | " \"DenseNet的Block结构\"\n", 788 | "\n", 789 | "\n", 790 | "关于Block,有以下4个细节需要注意:\n", 791 | "\n", 792 | " 1.每一个Bottleneck输出的特征通道数是相同的,例如这里的32。同时可以看到,经过Concatnate操作后的通道数是按32的增长量增加的,因此这个32也被称为GrowthRate。\n", 793 | " \n", 794 | " 2.这里1×1卷积的作用是固定输出通道数,达到降维的作用。当几十个Bottleneck相连接时,Concatnate后的通道数会增加到上千,如果不增加1×1的卷积来降维,后续3×3卷积所需的参数量会急剧增加。1×1卷积的通道数通常是GrowthRate的4倍。\n", 795 | " \n", 796 | " 3.上图中的特征传递方式是直接将前面所有层的特征Concatnate后传到下一层,这种方式与具体代码实现的方式是一致的。\n", 797 | " \n", 798 | " 4. Block采用了激活函数在前、卷积层在后的顺序,这与一般的网络上是不同的。\n", 799 | " \n", 800 | " \n", 801 | "DenseNet的结构有如下两个特性:\n", 802 | "\n", 803 | " 1.神经网络一般需要使用池化等操作缩小特征图尺寸来提取语义特征,而Dense Block需要保持每一个Block内的特征图尺寸一致来直接进行Concatnate操作,因此DenseNet被分成了多个Block。Block的数量一般为4。\n", 804 | " \n", 805 | " 2.两个相邻的Dense Block之间的部分被称为Transition层,对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。具体包括BN、ReLU、1×1卷积(Conv)、2×2平均池化操作。1×1卷积的作用是降维,起到压缩模型的作用,而平均池化则是降低特征图的尺寸。\n", 806 | " \n", 807 | " 3.Transition层可以起到压缩模型的作用。假定Transition的上接DenseBlock得到的特征图channels数为m ,Transition层可以产生θm个特征(通过卷积层),其中0 <θ≤1是压缩系数(compression rate)。当 θ=1时,特征个数经过Transition层没有变化,即无压缩,而当压缩系数小于1时,这种结构称为DenseNet-C,原论文中使用θ≤0.5 。对于使用bottleneck层的DenseBlock结构和压缩系数小于1的Transition组合结构称为DenseNet-BC。\n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | "**数据集的不同**\n", 812 | "\n", 813 | "DenseNet共在三个图像分类数据集(CIFAR,SVHN和ImageNet)上进行测试。对于CIFAR,SVHN两个数据集,其输入图片大小为 32x32,所使用的DenseNet在进入第一个DenseBlock之前,首先进行进行一次3x3卷积(stride=1),每个DenseBlock里面的层数相同。最后的DenseBlock之后是一个global AvgPooling层,然后送入一个softmax分类器。注意,在DenseNet中,所有的3x3卷积均采用padding=1的方式以保证特征图大小维持不变。\n", 814 | " \n", 815 | " \"CIFAR,SVHN结构\"\n", 816 | " \n", 817 | "

注意:这是我根据自己的理解画的图(原文中没有),如有错误还望指正
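结合上面对 GrowthRate 和通道拼接(Concatnate)的说明,可以用一小段示意代码粗略推算第一个 Dense Block 内通道数的变化(growth_rate=32、Block 内 6 个 Bottleneck 只是沿用后文 densenet121 的配置):

```python
# 示意:growth_rate=32 时,第一个 Dense Block 内通道数按 32 递增
growth_rate = 32
channels = 2 * growth_rate            # 进入第一个 Block 之前,3x3 卷积输出 2k = 64 个通道
for i in range(6):                    # densenet121 的第一个 Dense Block 有 6 个 Bottleneck
    channels += growth_rate           # 每个 Bottleneck 固定输出 32 个通道,拼接后累加
    print('bottleneck %d 之后: %d 个通道' % (i, channels))
print('Transition 压缩(theta=0.5)之后: %d 个通道' % int(0.5 * channels))   # 256 -> 128
```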

\n", 818 | " \n", 819 | " \n", 820 | " 对于ImageNet数据集,图片输入大小为224x224,网络结构采用包含4个DenseBlock的DenseNet-BC,其首先是一个stride=2的7x7卷积层,然后是一个stride=2的3x3 MaxPooling层,后面才进入DenseBlock。\n", 821 | " \n", 822 | " \"ImageNet数据集\"\n", 823 | "\n", 824 | "

这张图出自原论文,可以放心参考
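为了和上面 CIFAR 版的 3x3 卷积入口做个对照,下面是 ImageNet 版 DenseNet 入口部分的一个简单示意(其中 padding 取 3 和 1 是我按常见实现补的假设,不属于本文后面针对 CIFAR-10 的代码):

```python
import torch
import torch.nn as nn

# 示意:输入 224x224 时,先 7x7 卷积(stride=2)再 3x3 最大池化(stride=2),之后才进入 DenseBlock
stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),   # 224 -> 112
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),                    # 112 -> 56
)
x = torch.randn(1, 3, 224, 224)
print(stem(x).shape)    # torch.Size([1, 64, 56, 56])
```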

" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "execution_count": 17, 830 | "metadata": {}, 831 | "outputs": [], 832 | "source": [ 833 | "import torch\n", 834 | "import torch.nn as nn\n", 835 | "\n", 836 | "#实现一个Bottleneck的类,初始化需要输入通道数与GrowthRate这两个参数\n", 837 | "class Bottleneck(nn.Module):\n", 838 | " def __init__(self, in_channels, growth_rate):\n", 839 | " super(Bottleneck, self).__init__()\n", 840 | " #通常1×1卷积的通道数为GrowthRate的4倍\n", 841 | " inner_channel = 4 * growth_rate\n", 842 | "\n", 843 | " #Bottleneck由BN、ReLU、1×1卷积、BN、ReLU、3×3卷积的顺序构成\n", 844 | " self.bottle_neck = nn.Sequential(\n", 845 | " nn.BatchNorm2d(in_channels),\n", 846 | " nn.ReLU(inplace=True),\n", 847 | " nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False),\n", 848 | " nn.BatchNorm2d(inner_channel),\n", 849 | " nn.ReLU(inplace=True),\n", 850 | " nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False)\n", 851 | " )\n", 852 | " \n", 853 | " def forward(self, x):\n", 854 | " # 将输入x同计算的结果out进行通道拼接\n", 855 | " return torch.cat([x, self.bottle_neck(x)], 1)\n", 856 | " \n", 857 | "#Transition层,具体包括BN、ReLU、1×1卷积(Conv)、2×2平均池化操作。\n", 858 | "#1×1卷积的作用是降维,起到压缩模型的作用,而平均池化则是降低特征图的尺寸。\n", 859 | "class Transition(nn.Module):\n", 860 | " def __init__(self, in_channels, out_channels):\n", 861 | " super(Transition, self).__init__()\n", 862 | " self.down_sample = nn.Sequential(\n", 863 | " nn.BatchNorm2d(in_channels),\n", 864 | " nn.ReLU(inplace=True),\n", 865 | " nn.Conv2d(in_channels, out_channels, 1, bias=False),\n", 866 | " nn.AvgPool2d(2, stride=2)\n", 867 | " )\n", 868 | "\n", 869 | " def forward(self, x):\n", 870 | " return self.down_sample(x)\n", 871 | " \n", 872 | " \n", 873 | "\n", 874 | "#DesneNet-BC\n", 875 | "#B 代表 bottleneck layer(BN-RELU-CONV(1x1)-BN-RELU-CONV(3x3))\n", 876 | "#C代表压缩系数(0<=theta<=1)\n", 877 | "class DenseNet(nn.Module):\n", 878 | " def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_class=100):\n", 879 | " super(DenseNet, self).__init__()\n", 880 | " self.growth_rate = growth_rate\n", 881 | " \n", 882 | " inner_channels = 2 * growth_rate\n", 883 | "\n", 884 | " self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False) \n", 885 | "\n", 886 | " self.features = nn.Sequential()\n", 887 | "\n", 888 | " for index in range(len(nblocks) - 1):\n", 889 | " self.features.add_module(\"dense_block_layer_{}\".format(index), self._make_dense_layers(block, inner_channels, nblocks[index]))\n", 890 | " inner_channels += growth_rate * nblocks[index]\n", 891 | " out_channels = int(reduction * inner_channels) # int() will automatic floor the value\n", 892 | " self.features.add_module(\"transition_layer_{}\".format(index), Transition(inner_channels, out_channels))\n", 893 | " inner_channels = out_channels\n", 894 | "\n", 895 | " self.features.add_module(\"dense_block{}\".format(len(nblocks) - 1), self._make_dense_layers(block, inner_channels, nblocks[len(nblocks)-1]))\n", 896 | " inner_channels += growth_rate * nblocks[len(nblocks) - 1]\n", 897 | " self.features.add_module('bn', nn.BatchNorm2d(inner_channels))\n", 898 | " self.features.add_module('relu', nn.ReLU(inplace=True))\n", 899 | "\n", 900 | " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", 901 | "\n", 902 | " self.linear = nn.Linear(inner_channels, num_class)\n", 903 | "\n", 904 | " def forward(self, x):\n", 905 | " output = self.conv1(x)\n", 906 | " output = self.features(output)\n", 907 | " output = self.avgpool(output)\n", 908 | " output = output.view(output.size()[0], -1)\n", 909 | " 
output = self.linear(output)\n", 910 | " return output\n", 911 | "\n", 912 | " def _make_dense_layers(self, block, in_channels, nblocks):\n", 913 | " dense_block = nn.Sequential()\n", 914 | " for index in range(nblocks):\n", 915 | " dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate))\n", 916 | " in_channels += self.growth_rate\n", 917 | " return dense_block\n", 918 | "\n", 919 | "\n", 920 | "\n", 921 | "def densenet121():\n", 922 | " return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, reduction=0.5, num_class=10)\n", 923 | "\n", 924 | "def densenet169():\n", 925 | " return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, reduction=0.5, num_class=10)\n", 926 | "\n", 927 | "def densenet201():\n", 928 | " return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, reduction=0.5, num_class=10)\n", 929 | "\n", 930 | "def densenet161():\n", 931 | " return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, reduction=0.5, num_class=10)" 932 | ] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "execution_count": 18, 937 | "metadata": { 938 | "scrolled": true 939 | }, 940 | "outputs": [ 941 | { 942 | "name": "stdout", 943 | "output_type": "stream", 944 | "text": [ 945 | "DenseNet(\n", 946 | " (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 947 | " (features): Sequential(\n", 948 | " (dense_block_layer_0): Sequential(\n", 949 | " (bottle_neck_layer_0): Bottleneck(\n", 950 | " (bottle_neck): Sequential(\n", 951 | " (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 952 | " (1): ReLU(inplace=True)\n", 953 | " (2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 954 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 955 | " (4): ReLU(inplace=True)\n", 956 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 957 | " )\n", 958 | " )\n", 959 | " (bottle_neck_layer_1): Bottleneck(\n", 960 | " (bottle_neck): Sequential(\n", 961 | " (0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 962 | " (1): ReLU(inplace=True)\n", 963 | " (2): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 964 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 965 | " (4): ReLU(inplace=True)\n", 966 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 967 | " )\n", 968 | " )\n", 969 | " (bottle_neck_layer_2): Bottleneck(\n", 970 | " (bottle_neck): Sequential(\n", 971 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 972 | " (1): ReLU(inplace=True)\n", 973 | " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 974 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 975 | " (4): ReLU(inplace=True)\n", 976 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 977 | " )\n", 978 | " )\n", 979 | " (bottle_neck_layer_3): Bottleneck(\n", 980 | " (bottle_neck): Sequential(\n", 981 | " (0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 982 | " (1): ReLU(inplace=True)\n", 983 | " (2): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 984 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 985 | " (4): ReLU(inplace=True)\n", 986 | " (5): 
Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 987 | " )\n", 988 | " )\n", 989 | " (bottle_neck_layer_4): Bottleneck(\n", 990 | " (bottle_neck): Sequential(\n", 991 | " (0): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 992 | " (1): ReLU(inplace=True)\n", 993 | " (2): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 994 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 995 | " (4): ReLU(inplace=True)\n", 996 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 997 | " )\n", 998 | " )\n", 999 | " (bottle_neck_layer_5): Bottleneck(\n", 1000 | " (bottle_neck): Sequential(\n", 1001 | " (0): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1002 | " (1): ReLU(inplace=True)\n", 1003 | " (2): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1004 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1005 | " (4): ReLU(inplace=True)\n", 1006 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1007 | " )\n", 1008 | " )\n", 1009 | " )\n", 1010 | " (transition_layer_0): Transition(\n", 1011 | " (down_sample): Sequential(\n", 1012 | " (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1013 | " (1): ReLU(inplace=True)\n", 1014 | " (2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1015 | " (3): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", 1016 | " )\n", 1017 | " )\n", 1018 | " (dense_block_layer_1): Sequential(\n", 1019 | " (bottle_neck_layer_0): Bottleneck(\n", 1020 | " (bottle_neck): Sequential(\n", 1021 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1022 | " (1): ReLU(inplace=True)\n", 1023 | " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1024 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1025 | " (4): ReLU(inplace=True)\n", 1026 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1027 | " )\n", 1028 | " )\n", 1029 | " (bottle_neck_layer_1): Bottleneck(\n", 1030 | " (bottle_neck): Sequential(\n", 1031 | " (0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1032 | " (1): ReLU(inplace=True)\n", 1033 | " (2): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1034 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1035 | " (4): ReLU(inplace=True)\n", 1036 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1037 | " )\n", 1038 | " )\n", 1039 | " (bottle_neck_layer_2): Bottleneck(\n", 1040 | " (bottle_neck): Sequential(\n", 1041 | " (0): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1042 | " (1): ReLU(inplace=True)\n", 1043 | " (2): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1044 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1045 | " (4): ReLU(inplace=True)\n", 1046 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1047 | " )\n", 1048 | " )\n", 1049 | " (bottle_neck_layer_3): Bottleneck(\n", 1050 | " (bottle_neck): Sequential(\n", 1051 | " (0): BatchNorm2d(224, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", 1052 | " (1): ReLU(inplace=True)\n", 1053 | " (2): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1054 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1055 | " (4): ReLU(inplace=True)\n", 1056 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1057 | " )\n", 1058 | " )\n", 1059 | " (bottle_neck_layer_4): Bottleneck(\n", 1060 | " (bottle_neck): Sequential(\n", 1061 | " (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1062 | " (1): ReLU(inplace=True)\n", 1063 | " (2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1064 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1065 | " (4): ReLU(inplace=True)\n", 1066 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1067 | " )\n", 1068 | " )\n", 1069 | " (bottle_neck_layer_5): Bottleneck(\n", 1070 | " (bottle_neck): Sequential(\n", 1071 | " (0): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1072 | " (1): ReLU(inplace=True)\n", 1073 | " (2): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1074 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1075 | " (4): ReLU(inplace=True)\n", 1076 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1077 | " )\n", 1078 | " )\n", 1079 | " (bottle_neck_layer_6): Bottleneck(\n", 1080 | " (bottle_neck): Sequential(\n", 1081 | " (0): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1082 | " (1): ReLU(inplace=True)\n", 1083 | " (2): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1084 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1085 | " (4): ReLU(inplace=True)\n", 1086 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1087 | " )\n", 1088 | " )\n", 1089 | " (bottle_neck_layer_7): Bottleneck(\n", 1090 | " (bottle_neck): Sequential(\n", 1091 | " (0): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1092 | " (1): ReLU(inplace=True)\n", 1093 | " (2): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1094 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1095 | " (4): ReLU(inplace=True)\n", 1096 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1097 | " )\n", 1098 | " )\n", 1099 | " (bottle_neck_layer_8): Bottleneck(\n", 1100 | " (bottle_neck): Sequential(\n", 1101 | " (0): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1102 | " (1): ReLU(inplace=True)\n", 1103 | " (2): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1104 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1105 | " (4): ReLU(inplace=True)\n", 1106 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1107 | " )\n", 1108 | " )\n", 1109 | " (bottle_neck_layer_9): Bottleneck(\n", 1110 | " (bottle_neck): Sequential(\n", 1111 | " (0): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1112 | " (1): ReLU(inplace=True)\n", 1113 | " (2): Conv2d(416, 128, kernel_size=(1, 1), 
stride=(1, 1), bias=False)\n", 1114 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1115 | " (4): ReLU(inplace=True)\n", 1116 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1117 | " )\n", 1118 | " )\n", 1119 | " (bottle_neck_layer_10): Bottleneck(\n", 1120 | " (bottle_neck): Sequential(\n", 1121 | " (0): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1122 | " (1): ReLU(inplace=True)\n", 1123 | " (2): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1124 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1125 | " (4): ReLU(inplace=True)\n", 1126 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1127 | " )\n", 1128 | " )\n", 1129 | " (bottle_neck_layer_11): Bottleneck(\n", 1130 | " (bottle_neck): Sequential(\n", 1131 | " (0): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1132 | " (1): ReLU(inplace=True)\n", 1133 | " (2): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1134 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1135 | " (4): ReLU(inplace=True)\n", 1136 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1137 | " )\n", 1138 | " )\n", 1139 | " )\n", 1140 | " (transition_layer_1): Transition(\n", 1141 | " (down_sample): Sequential(\n", 1142 | " (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1143 | " (1): ReLU(inplace=True)\n", 1144 | " (2): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1145 | " (3): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", 1146 | " )\n", 1147 | " )\n", 1148 | " (dense_block_layer_2): Sequential(\n", 1149 | " (bottle_neck_layer_0): Bottleneck(\n", 1150 | " (bottle_neck): Sequential(\n", 1151 | " (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1152 | " (1): ReLU(inplace=True)\n", 1153 | " (2): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1154 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1155 | " (4): ReLU(inplace=True)\n", 1156 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1157 | " )\n", 1158 | " )\n", 1159 | " (bottle_neck_layer_1): Bottleneck(\n", 1160 | " (bottle_neck): Sequential(\n", 1161 | " (0): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1162 | " (1): ReLU(inplace=True)\n", 1163 | " (2): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1164 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1165 | " (4): ReLU(inplace=True)\n", 1166 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1167 | " )\n", 1168 | " )\n", 1169 | " (bottle_neck_layer_2): Bottleneck(\n", 1170 | " (bottle_neck): Sequential(\n", 1171 | " (0): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1172 | " (1): ReLU(inplace=True)\n", 1173 | " (2): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1174 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1175 | " (4): ReLU(inplace=True)\n", 1176 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 
1), bias=False)\n", 1177 | " )\n", 1178 | " )\n", 1179 | " (bottle_neck_layer_3): Bottleneck(\n", 1180 | " (bottle_neck): Sequential(\n", 1181 | " (0): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1182 | " (1): ReLU(inplace=True)\n", 1183 | " (2): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1184 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1185 | " (4): ReLU(inplace=True)\n", 1186 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1187 | " )\n", 1188 | " )\n", 1189 | " (bottle_neck_layer_4): Bottleneck(\n", 1190 | " (bottle_neck): Sequential(\n", 1191 | " (0): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1192 | " (1): ReLU(inplace=True)\n", 1193 | " (2): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1194 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1195 | " (4): ReLU(inplace=True)\n", 1196 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1197 | " )\n", 1198 | " )\n", 1199 | " (bottle_neck_layer_5): Bottleneck(\n", 1200 | " (bottle_neck): Sequential(\n", 1201 | " (0): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1202 | " (1): ReLU(inplace=True)\n", 1203 | " (2): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1204 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1205 | " (4): ReLU(inplace=True)\n", 1206 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1207 | " )\n", 1208 | " )\n", 1209 | " (bottle_neck_layer_6): Bottleneck(\n", 1210 | " (bottle_neck): Sequential(\n", 1211 | " (0): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1212 | " (1): ReLU(inplace=True)\n", 1213 | " (2): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1214 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1215 | " (4): ReLU(inplace=True)\n", 1216 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1217 | " )\n", 1218 | " )\n", 1219 | " (bottle_neck_layer_7): Bottleneck(\n", 1220 | " (bottle_neck): Sequential(\n", 1221 | " (0): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1222 | " (1): ReLU(inplace=True)\n", 1223 | " (2): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1224 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1225 | " (4): ReLU(inplace=True)\n", 1226 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1227 | " )\n", 1228 | " )\n", 1229 | " (bottle_neck_layer_8): Bottleneck(\n", 1230 | " (bottle_neck): Sequential(\n", 1231 | " (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1232 | " (1): ReLU(inplace=True)\n", 1233 | " (2): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1234 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1235 | " (4): ReLU(inplace=True)\n", 1236 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1237 | " )\n", 1238 | " )\n", 1239 | " (bottle_neck_layer_9): Bottleneck(\n", 1240 | " (bottle_neck): Sequential(\n", 
1241 | " (0): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1242 | " (1): ReLU(inplace=True)\n", 1243 | " (2): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1244 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1245 | " (4): ReLU(inplace=True)\n", 1246 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1247 | " )\n", 1248 | " )\n", 1249 | " (bottle_neck_layer_10): Bottleneck(\n", 1250 | " (bottle_neck): Sequential(\n", 1251 | " (0): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1252 | " (1): ReLU(inplace=True)\n", 1253 | " (2): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1254 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1255 | " (4): ReLU(inplace=True)\n", 1256 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1257 | " )\n", 1258 | " )\n", 1259 | " (bottle_neck_layer_11): Bottleneck(\n", 1260 | " (bottle_neck): Sequential(\n", 1261 | " (0): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1262 | " (1): ReLU(inplace=True)\n", 1263 | " (2): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1264 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1265 | " (4): ReLU(inplace=True)\n", 1266 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1267 | " )\n", 1268 | " )\n", 1269 | " (bottle_neck_layer_12): Bottleneck(\n", 1270 | " (bottle_neck): Sequential(\n", 1271 | " (0): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1272 | " (1): ReLU(inplace=True)\n", 1273 | " (2): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1274 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1275 | " (4): ReLU(inplace=True)\n", 1276 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1277 | " )\n", 1278 | " )\n", 1279 | " (bottle_neck_layer_13): Bottleneck(\n", 1280 | " (bottle_neck): Sequential(\n", 1281 | " (0): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1282 | " (1): ReLU(inplace=True)\n", 1283 | " (2): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1284 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1285 | " (4): ReLU(inplace=True)\n", 1286 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1287 | " )\n", 1288 | " )\n", 1289 | " (bottle_neck_layer_14): Bottleneck(\n", 1290 | " (bottle_neck): Sequential(\n", 1291 | " (0): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1292 | " (1): ReLU(inplace=True)\n", 1293 | " (2): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1294 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1295 | " (4): ReLU(inplace=True)\n", 1296 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1297 | " )\n", 1298 | " )\n", 1299 | " (bottle_neck_layer_15): Bottleneck(\n", 1300 | " (bottle_neck): Sequential(\n", 1301 | " (0): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1302 | " (1): 
ReLU(inplace=True)\n", 1303 | " (2): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1304 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1305 | " (4): ReLU(inplace=True)\n", 1306 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1307 | " )\n", 1308 | " )\n", 1309 | " (bottle_neck_layer_16): Bottleneck(\n", 1310 | " (bottle_neck): Sequential(\n", 1311 | " (0): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1312 | " (1): ReLU(inplace=True)\n", 1313 | " (2): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1314 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1315 | " (4): ReLU(inplace=True)\n", 1316 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1317 | " )\n", 1318 | " )\n", 1319 | " (bottle_neck_layer_17): Bottleneck(\n", 1320 | " (bottle_neck): Sequential(\n", 1321 | " (0): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1322 | " (1): ReLU(inplace=True)\n", 1323 | " (2): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1324 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1325 | " (4): ReLU(inplace=True)\n", 1326 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1327 | " )\n", 1328 | " )\n", 1329 | " (bottle_neck_layer_18): Bottleneck(\n", 1330 | " (bottle_neck): Sequential(\n", 1331 | " (0): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1332 | " (1): ReLU(inplace=True)\n", 1333 | " (2): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1334 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1335 | " (4): ReLU(inplace=True)\n", 1336 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1337 | " )\n", 1338 | " )\n", 1339 | " (bottle_neck_layer_19): Bottleneck(\n", 1340 | " (bottle_neck): Sequential(\n", 1341 | " (0): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1342 | " (1): ReLU(inplace=True)\n", 1343 | " (2): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1344 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1345 | " (4): ReLU(inplace=True)\n", 1346 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1347 | " )\n", 1348 | " )\n", 1349 | " (bottle_neck_layer_20): Bottleneck(\n", 1350 | " (bottle_neck): Sequential(\n", 1351 | " (0): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1352 | " (1): ReLU(inplace=True)\n", 1353 | " (2): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1354 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1355 | " (4): ReLU(inplace=True)\n", 1356 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1357 | " )\n", 1358 | " )\n", 1359 | " (bottle_neck_layer_21): Bottleneck(\n", 1360 | " (bottle_neck): Sequential(\n", 1361 | " (0): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1362 | " (1): ReLU(inplace=True)\n", 1363 | " (2): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1364 | " (3): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1365 | " (4): ReLU(inplace=True)\n", 1366 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1367 | " )\n", 1368 | " )\n", 1369 | " (bottle_neck_layer_22): Bottleneck(\n", 1370 | " (bottle_neck): Sequential(\n", 1371 | " (0): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1372 | " (1): ReLU(inplace=True)\n", 1373 | " (2): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1374 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1375 | " (4): ReLU(inplace=True)\n", 1376 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1377 | " )\n", 1378 | " )\n", 1379 | " (bottle_neck_layer_23): Bottleneck(\n", 1380 | " (bottle_neck): Sequential(\n", 1381 | " (0): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1382 | " (1): ReLU(inplace=True)\n", 1383 | " (2): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1384 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1385 | " (4): ReLU(inplace=True)\n", 1386 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1387 | " )\n", 1388 | " )\n", 1389 | " )\n", 1390 | " (transition_layer_2): Transition(\n", 1391 | " (down_sample): Sequential(\n", 1392 | " (0): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1393 | " (1): ReLU(inplace=True)\n", 1394 | " (2): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1395 | " (3): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", 1396 | " )\n", 1397 | " )\n", 1398 | " (dense_block3): Sequential(\n", 1399 | " (bottle_neck_layer_0): Bottleneck(\n", 1400 | " (bottle_neck): Sequential(\n", 1401 | " (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1402 | " (1): ReLU(inplace=True)\n", 1403 | " (2): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1404 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1405 | " (4): ReLU(inplace=True)\n", 1406 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1407 | " )\n", 1408 | " )\n", 1409 | " (bottle_neck_layer_1): Bottleneck(\n", 1410 | " (bottle_neck): Sequential(\n", 1411 | " (0): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1412 | " (1): ReLU(inplace=True)\n", 1413 | " (2): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1414 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1415 | " (4): ReLU(inplace=True)\n", 1416 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1417 | " )\n", 1418 | " )\n", 1419 | " (bottle_neck_layer_2): Bottleneck(\n", 1420 | " (bottle_neck): Sequential(\n", 1421 | " (0): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1422 | " (1): ReLU(inplace=True)\n", 1423 | " (2): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1424 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1425 | " (4): ReLU(inplace=True)\n", 1426 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1427 | " )\n", 1428 | " )\n", 
1429 | " (bottle_neck_layer_3): Bottleneck(\n", 1430 | " (bottle_neck): Sequential(\n", 1431 | " (0): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1432 | " (1): ReLU(inplace=True)\n", 1433 | " (2): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1434 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1435 | " (4): ReLU(inplace=True)\n", 1436 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1437 | " )\n", 1438 | " )\n", 1439 | " (bottle_neck_layer_4): Bottleneck(\n", 1440 | " (bottle_neck): Sequential(\n", 1441 | " (0): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1442 | " (1): ReLU(inplace=True)\n", 1443 | " (2): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1444 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1445 | " (4): ReLU(inplace=True)\n", 1446 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1447 | " )\n", 1448 | " )\n", 1449 | " (bottle_neck_layer_5): Bottleneck(\n", 1450 | " (bottle_neck): Sequential(\n", 1451 | " (0): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1452 | " (1): ReLU(inplace=True)\n", 1453 | " (2): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1454 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1455 | " (4): ReLU(inplace=True)\n", 1456 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1457 | " )\n", 1458 | " )\n", 1459 | " (bottle_neck_layer_6): Bottleneck(\n", 1460 | " (bottle_neck): Sequential(\n", 1461 | " (0): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1462 | " (1): ReLU(inplace=True)\n", 1463 | " (2): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1464 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1465 | " (4): ReLU(inplace=True)\n", 1466 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1467 | " )\n", 1468 | " )\n", 1469 | " (bottle_neck_layer_7): Bottleneck(\n", 1470 | " (bottle_neck): Sequential(\n", 1471 | " (0): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1472 | " (1): ReLU(inplace=True)\n", 1473 | " (2): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1474 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1475 | " (4): ReLU(inplace=True)\n", 1476 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1477 | " )\n", 1478 | " )\n", 1479 | " (bottle_neck_layer_8): Bottleneck(\n", 1480 | " (bottle_neck): Sequential(\n", 1481 | " (0): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1482 | " (1): ReLU(inplace=True)\n", 1483 | " (2): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1484 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1485 | " (4): ReLU(inplace=True)\n", 1486 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1487 | " )\n", 1488 | " )\n", 1489 | " (bottle_neck_layer_9): Bottleneck(\n", 1490 | " (bottle_neck): Sequential(\n", 1491 | " (0): BatchNorm2d(800, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", 1492 | " (1): ReLU(inplace=True)\n", 1493 | " (2): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1494 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1495 | " (4): ReLU(inplace=True)\n", 1496 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1497 | " )\n", 1498 | " )\n", 1499 | " (bottle_neck_layer_10): Bottleneck(\n", 1500 | " (bottle_neck): Sequential(\n", 1501 | " (0): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1502 | " (1): ReLU(inplace=True)\n", 1503 | " (2): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1504 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1505 | " (4): ReLU(inplace=True)\n", 1506 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1507 | " )\n", 1508 | " )\n", 1509 | " (bottle_neck_layer_11): Bottleneck(\n", 1510 | " (bottle_neck): Sequential(\n", 1511 | " (0): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1512 | " (1): ReLU(inplace=True)\n", 1513 | " (2): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1514 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1515 | " (4): ReLU(inplace=True)\n", 1516 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1517 | " )\n", 1518 | " )\n", 1519 | " (bottle_neck_layer_12): Bottleneck(\n", 1520 | " (bottle_neck): Sequential(\n", 1521 | " (0): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1522 | " (1): ReLU(inplace=True)\n", 1523 | " (2): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1524 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1525 | " (4): ReLU(inplace=True)\n", 1526 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1527 | " )\n", 1528 | " )\n", 1529 | " (bottle_neck_layer_13): Bottleneck(\n", 1530 | " (bottle_neck): Sequential(\n", 1531 | " (0): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1532 | " (1): ReLU(inplace=True)\n", 1533 | " (2): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1534 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1535 | " (4): ReLU(inplace=True)\n", 1536 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1537 | " )\n", 1538 | " )\n", 1539 | " (bottle_neck_layer_14): Bottleneck(\n", 1540 | " (bottle_neck): Sequential(\n", 1541 | " (0): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1542 | " (1): ReLU(inplace=True)\n", 1543 | " (2): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 1544 | " (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1545 | " (4): ReLU(inplace=True)\n", 1546 | " (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 1547 | " )\n", 1548 | " )\n", 1549 | " (bottle_neck_layer_15): Bottleneck(\n", 1550 | " (bottle_neck): Sequential(\n", 1551 | " (0): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 1552 | " (1): ReLU(inplace=True)\n", 1553 | " (2): Conv2d(992, 128, 
kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
1554 |       "          (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1555 |       "          (4): ReLU(inplace=True)\n",
1556 |       "          (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
1557 |       "        )\n",
1558 |       "      )\n",
1559 |       "    )\n",
1560 |       "    (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
1561 |       "    (relu): ReLU(inplace=True)\n",
1562 |       "  )\n",
1563 |       "  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
1564 |       "  (linear): Linear(in_features=1024, out_features=10, bias=True)\n",
1565 |       ")\n"
1566 |      ]
1567 |     }
1568 |    ],
1569 |    "source": [
1570 |     "# Build DenseNet-121 and move it to the GPU; this is the model that triggers the CUDA out-of-memory error in the training cell below.\n",
1571 |     "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
1572 |     "net = densenet121().to(device)\n",
1573 |     "print(net)"
1574 |    ]
1575 |   },
1576 |   {
1577 |    "cell_type": "markdown",
1578 |    "metadata": {},
1579 |    "source": [
1580 |     "### 5. Defining the loss function and optimizer\n",
1581 |     "    PyTorch wraps the commonly used optimization methods in torch.optim; every optimizer inherits from the base class optim.Optimizer\n",
1582 |     "    \n",
1583 |     "    Loss functions are provided by the neural-network toolbox nn, which contains many of them\n"
1584 |    ]
1585 |   },
1586 |   {
1587 |    "cell_type": "code",
1588 |    "execution_count": 19,
1589 |    "metadata": {},
1590 |    "outputs": [],
1591 |    "source": [
1592 |     "import torch.optim as optim\n",
1593 |     "# use the cross-entropy loss from the neural-network toolbox nn\n",
1594 |     "loss_function = nn.CrossEntropyLoss() \n",
1595 |     "# use SGD (stochastic gradient descent) with learning rate 0.01 and momentum 0.9\n",
1596 |     "optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) "
1597 |    ]
1598 |   },
1599 |   {
1600 |    "cell_type": "markdown",
1601 |    "metadata": {},
1602 |    "source": [
1603 |     "### 6. Training"
1604 |    ]
1605 |   },
1606 |   {
1607 |    "cell_type": "code",
1608 |    "execution_count": 20,
1609 |    "metadata": {
1610 |     "scrolled": true
1611 |    },
1612 |    "outputs": [
1613 |     {
1614 |      "ename": "RuntimeError",
1615 |      "evalue": "CUDA out of memory. 
Tried to allocate 16.00 MiB (GPU 0; 2.94 GiB total capacity; 2.52 GiB already allocated; 12.50 MiB free; 2.53 GiB reserved in total by PyTorch)", 1616 | "output_type": "error", 1617 | "traceback": [ 1618 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 1619 | "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", 1620 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0mimages\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mlogits\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mloss_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1621 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 532\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 533\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1622 | "\u001b[1;32m\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 72\u001b[0m 
\u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 73\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 74\u001b[1;33m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 75\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mavgpool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0moutput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1623 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 532\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 533\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1624 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1625 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 532\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 533\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1626 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1627 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in 
\u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 532\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 533\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1628 | "\u001b[1;32m\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[1;31m# 将输入x同计算的结果out进行通道拼接\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbottle_neck\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[1;31m#Transition层,具体包括BN、ReLU、1×1卷积(Conv)、2×2平均池化操作。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1629 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 532\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 533\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1630 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1631 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 532\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 533\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1632 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 105\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrunning_mean\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrunning_var\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 106\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrack_running_stats\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 107\u001b[1;33m exponential_average_factor, self.eps)\n\u001b[0m\u001b[0;32m 108\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1633 | "\u001b[1;32mD:\\Anaconda3\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[1;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[0;32m 1668\u001b[0m return torch.batch_norm(\n\u001b[0;32m 1669\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1670\u001b[1;33m \u001b[0mtraining\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meps\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1671\u001b[0m )\n\u001b[0;32m 1672\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1634 | "\u001b[1;31mRuntimeError\u001b[0m: CUDA out of memory. 
Tried to allocate 16.00 MiB (GPU 0; 2.94 GiB total capacity; 2.52 GiB already allocated; 12.50 MiB free; 2.53 GiB reserved in total by PyTorch)"
1635 |     ]
1636 |    }
1637 |   ],
1638 |   "source": [
1639 |    "best_acc = 0.0\n",
1640 |    "save_path = 'D:/CIFAR-10/model/DenseNet121.pth'\n",
1641 |    "for epoch in range(50):\n",
1642 |    "    # train\n",
1643 |    "    net.train()\n",
1644 |    "    running_loss = 0.0\n",
1645 |    "    for step, data in enumerate(trainloader, start=0):\n",
1646 |    "        images, labels = data\n",
1647 |    "        optimizer.zero_grad()\n",
1648 |    "        logits = net(images.to(device))\n",
1649 |    "        loss = loss_function(logits, labels.to(device))\n",
1650 |    "        loss.backward()\n",
1651 |    "        optimizer.step()\n",
1652 |    "\n",
1653 |    "        # print statistics\n",
1654 |    "        running_loss += loss.item()\n",
1655 |    "        # print the training progress\n",
1656 |    "        rate = (step+1)/len(trainloader)\n",
1657 |    "        a = \"*\" * int(rate * 50)\n",
1658 |    "        b = \".\" * int((1 - rate) * 50)\n",
1659 |    "        print(\"\\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}\".format(int(rate*100), a, b, loss), end=\"\")\n",
1660 |    "    print()\n",
1661 |    "\n",
1662 |    "    # validate\n",
1663 |    "    net.eval()\n",
1664 |    "    acc = 0.0  # accumulate the number of correct predictions per epoch\n",
1665 |    "    with torch.no_grad():\n",
1666 |    "        for val_data in testloader:\n",
1667 |    "            val_images, val_labels = val_data\n",
1668 |    "            outputs = net(val_images.to(device))  # in eval mode the model only returns the final output layer\n",
1669 |    "            # loss = loss_function(outputs, test_labels)\n",
1670 |    "            predict_y = torch.max(outputs, dim=1)[1]\n",
1671 |    "            acc += (predict_y == val_labels.to(device)).sum().item()\n",
1672 |    "        val_accurate = acc / len(testset)\n",
1673 |    "        if val_accurate > best_acc:\n",
1674 |    "            best_acc = val_accurate\n",
1675 |    "            torch.save(net.state_dict(), save_path)\n",
1676 |    "        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %\n",
1677 |    "              (epoch + 1, running_loss / step, val_accurate))\n",
1678 |    "\n",
1679 |    "print('Finished Training')"
1680 |   ]
1681 |  },
1682 |  {
1683 |   "cell_type": "code",
1684 |   "execution_count": null,
1685 |   "metadata": {},
1686 |   "outputs": [],
1687 |   "source": [
1688 |    "from torch.autograd import Variable\n",
1689 |    "# two lists that accumulate, per class, the number of correct predictions and the number of test images, initialised to 0\n",
1690 |    "class_correct = list(0. for i in range(10))\n",
1691 |    "class_total = list(0. for i in range(10))\n",
1692 |    "with torch.no_grad():\n",
1693 |    "    \n",
1694 |    "    for data in testloader:\n",
1695 |    "        net.eval()\n",
1696 |    "        images, labels = data\n",
1697 |    "        images=Variable(images).cuda()\n",
1698 |    "        labels=Variable(labels).cuda()\n",
1699 |    "        outputs = net(images)\n",
1700 |    "\n",
1701 |    "        _, predicted = torch.max(outputs.data, 1)\n",
1702 |    "        # within one batch, mark 1 where the prediction equals the label, otherwise 0\n",
1703 |    "        c = (predicted == labels).squeeze()\n",
1704 |    "        for i in range(16):  # loop over the 16 images of each batch (matches the DataLoader batch size)\n",
1705 |    "            label = labels[i]  # accumulate the counts for the corresponding class\n",
1706 |    "            class_correct[label] += c[i].item()\n",
1707 |    "            class_total[label] += 1\n",
1708 |    "    \n",
1709 |    "    \n",
1710 |    "for i in range(10):\n",
1711 |    "    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))"
1712 |   ]
1713 |  },
1714 |  {
1715 |   "cell_type": "markdown",
1716 |   "metadata": {},
1717 |   "source": [
1718 |    "### 8. Saving the model"
1719 |   ]
1720 |  },
1721 |  {
1722 |   "cell_type": "code",
1723 |   "execution_count": null,
1724 |   "metadata": {},
1725 |   "outputs": [],
1726 |   "source": [
1727 |    "# torch.save(net, 'D:/CIFAR-10/model/DenseNet.pth')"
1728 |   ]
1729 |  },
1730 |  {
1731 |   "cell_type": "markdown",
1732 |   "metadata": {},
1733 |   "source": [
1734 |    "### 9. Prediction"
1735 |   ]
1736 |  },
1737 |  {
1738 |   "cell_type": "code",
1739 |   "execution_count": null,
1740 |   "metadata": {},
1741 |   "outputs": [],
1742 |   "source": [
1743 |    "import torch\n",
1744 |    "from PIL import Image\n",
1745 |    "from torch.autograd import Variable\n",
1746 |    "import torch.nn.functional as F\n",
1747 |    "from torchvision import datasets, transforms\n",
1748 |    "import numpy as np\n",
1749 |    " \n",
1750 |    "classes = ('plane', 'car', 'bird', 'cat',\n",
1751 |    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
1752 |    "\n",
1753 |    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1754 |    "\n",
1755 |    "\n",
1756 |    "model = densenet121()  # must be the same DenseNet-121 architecture as the saved checkpoint (not MobileNetV2)\n",
1757 |    "\n",
1758 |    "model.load_state_dict(torch.load('D:/CIFAR-10/model/DenseNet121.pth'))  # load the saved weights\n",
1759 |    "model = model.to(device)\n",
1760 |    "model.eval()  # switch the model to evaluation (test) mode\n",
1761 |    "# read the image to be classified\n",
1762 |    "img = Image.open(\"D:/CIFAR-10/airplan.jpg\").convert('RGB')  # load the image"
1763 |   ]
1764 |  },
1765 |  {
1766 |   "cell_type": "code",
1767 |   "execution_count": null,
1768 |   "metadata": {},
1769 |   "outputs": [],
1770 |   "source": [
1771 |    "img"
1772 |   ]
1773 |  },
1774 |  {
1775 |   "cell_type": "code",
1776 |   "execution_count": null,
1777 |   "metadata": {},
1778 |   "outputs": [],
1779 |   "source": [
1780 |    "trans = transforms.Compose([transforms.Resize((32,32)),\n",
1781 |    "                            transforms.ToTensor(),\n",
1782 |    "                            transforms.Normalize(mean=(0.5, 0.5, 0.5), \n",
1783 |    "                                                 std=(0.5, 0.5, 0.5)),\n",
1784 |    "                           ])\n",
1785 |    " \n",
1786 |    "img = trans(img)\n",
1787 |    "img = img.to(device)\n",
1788 |    "# add a batch dimension: the saved model expects 4-D input [batch_size, channels, height, width], while a single image is only 3-D [channels, height, width]\n",
1789 |    "img = img.unsqueeze(0) \n",
1790 |    "# after unsqueeze the shape is [1, 3, 32, 32]\n",
1791 |    "output = model(img)\n",
1792 |    "prob = F.softmax(output, dim=1)  # prob holds the probabilities of the 10 classes\n",
1793 |    "print(\"probabilities\", prob)\n",
1794 |    "value, predicted = torch.max(output.data, 1)\n",
1795 |    "print(\"class index\", predicted.item())\n",
1796 |    "print(value)\n",
1797 |    "pred_class = classes[predicted.item()]\n",
1798 |    "print(\"class name\", pred_class)\n",
1799 |    " \n",
1800 |    " \n",
1801 |    " \n",
1802 |    "    # prob = F.softmax(output, dim=1)\n",
1803 |    "    # prob = Variable(prob)\n",
1804 |    "    # prob = prob.cpu().numpy()  # parameters trained on the GPU live on the GPU; move them back to the CPU and then to numpy before displaying\n",
1805 |    "    # print(prob)  # prob holds the probabilities of the 10 classes\n",
1806 |    "    # pred = np.argmax(prob)  # pick the class with the highest probability\n",
1807 |    "    # # print(pred)\n",
1808 |    "    # # print(pred.item())\n",
1809 |    "    # pred_class = classes[pred]\n",
1810 |    "    # print(pred_class)"
1811 |   ]
1812 |  },
1813 |  {
1814 |   "cell_type": "code",
1815 |   "execution_count": null,
1816 |   "metadata": {},
1817 |   "outputs": [],
1818 |   "source": []
1819 |  }
1820 | ],
1821 | "metadata": {
1822 |  "kernelspec": {
1823 |   "display_name": "Python 3",
1824 |   "language": "python",
1825 |   "name": "python3"
1826 |  },
1827 |  "language_info": {
1828 |   "codemirror_mode": {
1829 |    "name": "ipython",
1830 |    "version": 3
1831 |   },
1832 |   "file_extension": ".py",
1833 |   "mimetype": "text/x-python",
1834 |   "name": "python",
1835 |   "nbconvert_exporter": "python",
1836 |   "pygments_lexer": "ipython3",
1837 |   "version": "3.7.6"
1838 |  }
1839 | },
1840 | "nbformat": 4,
1841 | "nbformat_minor": 4
1842 | }
1843 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch_Image_Classification
2 | LeNet5→AlexNet→VGGNet→GoogleNet→ResNet→MobileNet→DenseNet
3 | 
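## Note on the DenseNet notebook: "CUDA out of memory"

DenseNet-121 keeps the concatenated feature maps of every layer in a dense block alive, so on the roughly 3 GB GPU used here training stops with the RuntimeError recorded in the DenseNet notebook. The sketch below is not part of the notebooks: it assumes torchvision's built-in `densenet121` (the `memory_efficient` flag requires a reasonably recent torchvision release) and PyTorch >= 1.6 for `torch.cuda.amp`, and it combines three common ways to shrink the memory footprint: a smaller batch size, gradient-checkpointed dense layers, and mixed-precision training.

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                        download=True, transform=transform)
# 1) Smaller batches mean fewer activations have to be kept alive at once.
trainloader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)

# 2) memory_efficient=True recomputes the concatenated features during the
#    backward pass (gradient checkpointing) instead of caching all of them.
net = torchvision.models.densenet121(num_classes=10, memory_efficient=True).to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# 3) Mixed precision (PyTorch >= 1.6) roughly halves activation memory on the GPU.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

net.train()
for images, labels in trainloader:  # one epoch shown; wrap in an epoch loop as in the notebooks
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
        loss = loss_function(net(images), labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

If the model still does not fit, `torch.utils.checkpoint` can wrap whole dense blocks of a hand-written implementation, trading extra compute per step for lower peak memory.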
--------------------------------------------------------------------------------