├── requirements.txt
├── model.pt
├── Train.jpg
├── images
│   ├── Architecture.jpg
│   └── PaperErrorRate.png
├── .gitignore
├── viewModel.py
├── README.md
├── imports
│   └── ParametersManager.py
├── DownloadUnzipData.py
├── LeNet-5_GPU.py
└── LICENSE

/requirements.txt:
--------------------------------------------------------------------------------
tqdm
requests
numpy
matplotlib

--------------------------------------------------------------------------------
/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnyHaze/LeNet5-MNIST-Pytorch/HEAD/model.pt

--------------------------------------------------------------------------------
/Train.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnyHaze/LeNet5-MNIST-Pytorch/HEAD/Train.jpg

--------------------------------------------------------------------------------
/images/Architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnyHaze/LeNet5-MNIST-Pytorch/HEAD/images/Architecture.jpg

--------------------------------------------------------------------------------
/images/PaperErrorRate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SunnyHaze/LeNet5-MNIST-Pytorch/HEAD/images/PaperErrorRate.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
CNN.py
__pycache__
test.py
*-idx3-ubyte*
*-idx1-ubyte*
*.idx3-ubyte*
*.idx1-ubyte*
*.npy
ReadData.py

--------------------------------------------------------------------------------
/viewModel.py:
--------------------------------------------------------------------------------
from imports.ParametersManager import *
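from matplotlib import pyplot as plt
import torch

# Load onto the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
parManager = ParametersManager(device)
parManager.loadFromFile('model.pt')
# Print the final accuracy on both datasets to the terminal
parManager.show()

# Plot the accuracy over the training done so far
plt.figure(figsize=(10,7))
plt.plot(range(parManager.EpochDone),parManager.TrainACC,marker='*' ,color='r',label='Train')
plt.plot(range(parManager.EpochDone),parManager.TestACC,marker='*' ,color='b',label='Test')

plt.xlabel('Epochs')
plt.ylabel('ACC')
plt.legend()
plt.title("LeNet-5 on MNIST")

plt.show()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Reproducing LeNet-5 on the MNIST Dataset with PyTorch
![Powered by](https://img.shields.io/badge/Based_on-Pytorch-blue?logo=pytorch)
![GitHub repo size](https://img.shields.io/github/repo-size/SunnyHaze/LeNet5-Pytorch?logo=hack%20the%20box)
![GitHub](https://img.shields.io/github/license/Sunnyhaze/LeNet5-Pytorch?logo=license)

This repository builds the classic LeNet-5 network with PyTorch, using the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), and provides a pre-trained model together with its results.

>MNIST is a classic dataset for handwritten digit recognition.

It also provides an automated script that downloads, decompresses, and restructures the raw dataset, so you can try out the training process yourself.
## About LeNet-5
LeNet-5 is the convolutional neural network that Yann LeCun proposed back in 1998. It is a true classic: a highly efficient convolutional network for handwritten character recognition.

Paper: [Gradient-based learning applied to document recognition](https://ieeexplore.ieee.org/abstract/document/726791)
> Cited more than 40,000 times, which says a lot about its influence.

LeNet-5 is small, yet it contains the basic building blocks of deep learning for image recognition: convolutional layers, pooling layers (at the time still called subsampling layers; the name "pooling layer" came into use only after AlexNet), and fully connected layers. It laid the groundwork for later deep learning models.

+ Network architecture

![](/images/Architecture.jpg)

Many detailed explanations of the architecture have already been written elsewhere, and the code here includes brief explanatory comments as well, so please refer to those.
## About this repository
### Directory structure
- [LeNet-5_GPU.py](LeNet-5_GPU.py) is the main script containing the model and the training logic.
- [/imports/ParametersManager.py](imports/ParametersManager.py) is the "controller" that saves the parameters recorded during training to a file. The main model script imports this module.
- [DownloadUnzipData.py](DownloadUnzipData.py) is the script that automatically downloads the raw dataset and unpacks it to files.
- [model.pt](model.pt)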
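is a model file pre-trained to over 98% accuracy; it has been trained for 31 epochs so far.
- [viewModel.py](viewModel.py) is a script that shows the accuracy of the current model directly; it only runs if `model.pt` is present in the directory.

### Usage
- To view the current model's performance and its accuracy over the course of training, just run `viewModel.py`.
- The training results that have already been exported can also be seen in `Train.jpg`.
- To keep this repository small, the dataset is not uploaded, so first run `DownloadUnzipData.py` to download and unpack the raw dataset.
- Downloading and unpacking may be slow, but it is worth it: it greatly improves efficiency when training the model.
- By default the model trains for 30 epochs with a batch size of 10; feel free to adjust these.
> The raw data is stored in the dataset as "one byte per value", so it is very compact, whereas the floating-point numbers used for computation take 32 bits (4 bytes), which is also the data type most GPUs support best.
>If the data is not converted to the 4-byte format ahead of time, the CPU has to convert 1-byte values into 4-byte values over and over while the data is being read,
>which repeatedly wastes a great deal of compute: the CPU sits at 100% load while the GPU stays near 0% most of the time.
>Once the data has been unpacked in advance, whole blocks of it can be moved straight into GPU memory, which greatly speeds up computation.
>The dataset becomes considerably larger, but the run time drops sharply! This is the main job of the reader classes defined in `DownloadUnzipData.py`:
>[read the highly compressed byte files and save them in the floating-point form that GPUs prefer]
(For how the bytes are organised, see the notes at the bottom of http://yann.lecun.com/exdb/mnist/)
- To train a model from scratch, first rename the `model.pt` in the root directory or copy it somewhere else; running the `LeNet-5_GPU.py` script will then train a new model.


### Evaluation
Overall, the accuracy matches the `Error Rate` reported in the paper. Occasionally, however, the accuracy never improves after training starts. The likely cause is an unlucky random initialisation that prevents convergence; in that case, rerun the script to reinitialise the weights.
+ `Accuracy` of my model

![](Train.jpg)

+ The `Error Rate` from the original paper, with explanations for a few small issues

![](images/PaperErrorRate.png)

--------------------------------------------------------------------------------
/imports/ParametersManager.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
# A class that manages the parameters recorded while training a model
class ParametersManager():
    def __init__(self,device) -> None:
        self.device = device
        # Tracked data
        self.EpochDone = 0       # number of epochs completed
        self.LearningRate = []   # learning rate of each epoch
        self.TrainACC = []       # training-set accuracy
        self.TestACC = []        # test-set accuracy
        self.loss = []           # loss
        self.state_dict = 0      # model weights
        self.datas = {}
    # Pack the tracked data into one dictionary
    def pack(self):
        self.datas = {
            'EpochDone' : self.EpochDone,        # number of epochs completed
            'LearningRate' : self.LearningRate,  # learning rate of each epoch
            'TrainACC' : self.TrainACC,          # training-set accuracy
            'TestACC' : self.TestACC,            # test-set accuracy
            'loss' : self.loss,                  # loss
            'state_dict' :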
                self.state_dict,                 # model weights
        }
    # Unpack the dictionary back into attributes
    def unpack(self):
        self.EpochDone = self.datas['EpochDone']
        self.LearningRate = self.datas['LearningRate']
        self.TestACC = self.datas['TestACC']
        self.TrainACC = self.datas['TrainACC']
        self.loss = self.datas['loss']
        self.state_dict = self.datas['state_dict']
    # Fetch the parameters from the model in the calling script
    def loadModelParameters(self, model:nn.Module):
        self.state_dict = model.state_dict()

    # Push the stored parameters into the model in the calling script
    def setModelParameters(self, model:nn.Module):
        model.load_state_dict(self.state_dict)

    # Record the results of one epoch from the calling script
    def oneEpochDone(self, LastLearningRate, LastTrainACC, lastTestACC, lastLoss):
        self.EpochDone += 1
        self.LearningRate.append(LastLearningRate)
        self.TrainACC.append(LastTrainACC)
        self.TestACC.append(lastTestACC)
        self.loss.append(lastLoss)

    # Save the data to a file
    def saveToFile(self, path):
        self.pack()
        torch.save(self.datas, path)
        print('===successfully saved model!===')

    # Read the data from a file
    def loadFromFile(self, path):
        self.datas = torch.load(path,map_location=torch.device(self.device))
        self.unpack()
        print('===Load model successfully!===')
    # Show the data of the currently stored model
    def show(self):
        print('===' * 10 +
        '''\nThis model has been trained for {} epochs \n
        Current training-set accuracy: {:.3f}% \n
        Current test-set accuracy: {:.3f}% \n'''.format(self.EpochDone, self.TrainACC[-1] * 100, self.TestACC[-1] * 100),'===' * 10)

--------------------------------------------------------------------------------
/DownloadUnzipData.py:
--------------------------------------------------------------------------------
import struct
import matplotlib.pyplot as plt
import numpy as np
import requests , gzip
from tqdm import tqdm
# Names of the files to download
fileNames = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz'
]
# Dictionary that gives the raw file names short, readable keys
rawDataName = {
    "trainX" : "train-images-idx3-ubyte",
    "trainY" : "train-labels-idx1-ubyte",
    "testX" : "t10k-images-idx3-ubyte",
    "testY" : "t10k-labels-idx1-ubyte"
}
# Download one file and decompress it
def downLoadAFile(filename:str):
    baseUrl = 'http://yann.lecun.com/exdb/mnist/'
    tmpUrl = baseUrl + filename
    print('Downloading {} ...'.format(filename))
    file = requests.get(url=tmpUrl)
    file.raise_for_status()  # stop early rather than saving an error page as a .gz file
    with open(filename, 'wb+') as f:
        f.write(file.content)
    print('Download finished!')
    cutFilename = filename.replace('.gz','')
    print('Decompressing gzip file {} ...'.format(filename))
    with gzip.open(filename, 'rb') as gFile, open(cutFilename, 'wb+') as f:
        f.write(gFile.read())
    print('Decompression finished')
    print('===Dataset download complete===')

# =================The part below builds the dataset that is actually used================
# The raw data is stored in the dataset as "one byte per value", so it is very compact,
# whereas the floating-point numbers used for computation take 32 bits (4 bytes), which is also the data type most GPUs support best.
# If the data is not converted to the 4-byte format ahead of time, the CPU has to convert 1-byte values into 4-byte values over and over while the data is being read.
# This repeatedly wastes a great deal of compute: the CPU sits at 100% load while the GPU stays near 0% most of the time.
# Once the data has been unpacked in advance, whole blocks of it can be moved straight into GPU memory, which greatly speeds up computation.

# The dataset becomes considerably larger, but the run time drops sharply! This is the main job of the classes defined below:
# [read the highly compressed byte files and save them in the form that GPUs prefer]
# (For how the bytes are organised, see the notes at the bottom of http://yann.lecun.com/exdb/mnist/)
class imgReader:
    def __init__(self,PATH) -> None:
        self.path = PATH
        with open(self.path, 'rb') as f:
            self.buff = f.read()
        # Number of items, read from the header; see the MNIST site for the byte layout
        self.size = struct.unpack(">i",self.buff[4:8])[0]  # 4-byte big-endian integer, tutorial: https://www.liaoxuefeng.com/wiki/1016959663602400/1017685387246080
        # Read the row and column counts as well, mostly for completeness
        self.numberOfRows = struct.unpack(">i",self.buff[8:12])[0]
        self.numberOfCols = struct.unpack(">i",self.buff[12:16])[0]
    # Best avoided: it is very slow, and the code below never actually calls it
    def returnWholeArray(self) -> dict:
        data = {}
        data["imgs"] = []
        data["size"] = self.size
        for i in range(data["size"]):
            # "Magic number": each image is a plain 28x28 block, 784 bytes, after the 16-byte header
            tmpMatrix = struct.unpack_from('>784B', self.buff, 16 + i * 784)
            pic = np.array(tmpMatrix)
            pic = pic.reshape(28,28)
            data["imgs"].append(pic)
        return data
    # Overload the [] operator
    def __getitem__(self,index):
        if isinstance(index, int):
            assert 0 <= index < self.size, "Index out of range! Should be less than {}.".format(self.size)
            # Skip the 16-byte header, then jump to the requested 784-byte image
            offset = 16 + index * 784
            tmpMatrix = struct.unpack_from('>784B', self.buff, offset)
            return np.array(tmpMatrix).reshape(28,28)
        if isinstance(index, slice):
            # slice.indices() fills in missing bounds, so reader[:] covers every item
            return [self[i] for i in range(*index.indices(self.size))]
    # Overload len()
    def __len__(self):
        return self.size

# Class for reading labels
class labelReader:
    def __init__(self,PATH) -> None:
        self.path = PATH
        with open(self.path, 'rb') as f:
            self.buff = f.read()
        # Number of items, read from the header; see the MNIST site for the byte layout
        self.size = struct.unpack(">i",self.buff[4:8])[0]
    # Best avoided: it is very slow
    def returnWholeArray(self) -> dict:
        data = {}
        data["labels"] = []
        data["size"] = self.size
        for i in range(data["size"]):
            # Each label is a single byte after the 8-byte header
            label = struct.unpack_from('>B', self.buff, 8 + i)[0]
            data["labels"].append(label)
        return data
    # Overload the [] operator
    def __getitem__(self,index):
        if isinstance(index, int):
            assert 0 <= index < self.size, "Index {} out of range! Should be less than {}.".format(index,self.size)
            label = struct.unpack_from('>B', self.buff, index + 8)
            return label[0]
        if isinstance(index, slice):
            # slice.indices() fills in missing bounds, so reader[:] covers every label
            return tuple(self.buff[8 + i] for i in range(*index.indices(self.size)))
    # Overload len()
    def __len__(self):
        return self.size
# Decode every item of a Reader from 1-byte values to 4-byte values, then dump them to the file named fileName
def saveAsNpy(reader,fileName):
    with open(fileName,'wb') as f:
        data = []
        for i in tqdm(range(len(reader))):
            data.append(reader[i])
        # float32 so the saved values really use the 4-byte format described above
        data = np.array(data, dtype=np.float32)
        print('Done!')
        np.save(f,data)
# Main entry point
if __name__ == "__main__":
    # Download the MNIST dataset from http://yann.lecun.com/exdb/mnist/ and decompress it
    for i in fileNames:
        downLoadAFile(i)
    # Set up the Readers that parse the formatted byte files
    trainXreader = imgReader(rawDataName["trainX"])
    testXreader = imgReader(rawDataName["testX"])
    testYreader = labelReader(rawDataName["testY"])
    trainYreader = labelReader(rawDataName["trainY"])

    print('===Converting training images, {} in total...==='.format(len(trainXreader)))
    saveAsNpy(trainXreader,'TrainImg.npy')
    print('===Converting test images, {} in total...==='.format(len(testXreader)))
    saveAsNpy(testXreader,'TestImg.npy')
    print('===Converting test labels, {} in total...==='.format(len(testYreader)))
    saveAsNpy(testYreader,"TestLabel.npy")
    print('===Converting training labels, {} in total...==='.format(len(trainYreader)))
    saveAsNpy(trainYreader,"TrainLabel.npy")

    print('=====Congratulations! All done, you can start training the model now!======')


--------------------------------------------------------------------------------
/LeNet-5_GPU.py:
--------------------------------------------------------------------------------
import os
from matplotlib import pyplot as plt  # plt is used for the accuracy plot at the end of the script
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
from torch import nn
from imports.ParametersManager import *

# Hyperparameters
BatchSize = 10
LEARNINGRATE = 0.005
epochNums = 30
SaveModelEveryNEpoch = 2  # save the model every this many epochs
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialise the data readers, which are accessed by index
# trainXreader = imgReader(rawDataName["trainX"])
# testXreader = imgReader(rawDataName["testX"])
# trainYreader = labelReader(rawDataName["trainY"])
# testYreader = labelReader(rawDataName["testY"])

# Wrap the data in a Dataset first, then pass it to a DataLoader for sampling
class MyDataset(Dataset):
    def __init__(self,SetType) -> None:
        # File names match the ones written by DownloadUnzipData.py (TrainImg.npy, TestImg.npy)
        with open(SetType + 'Img.npy','rb') as f:
            # Add the channel dimension once: (N, 28, 28) -> (N, 1, 28, 28)
            self.images = torch.tensor(np.load(f, allow_pickle=True), dtype=torch.float32).unsqueeze(1)
        with open(SetType + 'Label.npy','rb') as f:
            tmp = np.load(f, allow_pickle=True)
            print(tmp)
            # One-hot encode the labels
            self.labels=[]
            for num in tmp:
                self.labels.append([1 if x == num else 0 for x in range(10)])
            self.labels = torch.tensor(self.labels, dtype=torch.float32)
    def __getitem__(self, index):
        return self.images[index], self.labels[index]
    def __len__(self):
        return len(self.labels)

# Network definition
class LeNet_5(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1,6,kernel_size=5,padding=2),  # C1: single input channel (adapted from a three-channel version); padding=2 keeps the 28x28 size
            nn.ReLU(),
            nn.MaxPool2d(2,2),  # S2
            nn.Conv2d(6,16,5),  # C3: in the paper C3 is only partially connected to S2, which saves some computation. Modern CNNs such as AlexNet and ResNet connect every channel, and this implementation makes the same simplification.
            nn.ReLU(),
            nn.MaxPool2d(2,2)   # S4
        )
        # The feature maps are then flattened and passed through the fully connected layers
        self.layer2 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),  # C5
            nn.ReLU(),
            nn.Linear(120, 84),          # F6
            nn.ReLU(),
            nn.Linear(84,10),            # Output: the paper uses Gaussian connections; a fully connected layer is used here for convenience
        )
    def forward(self,x):
        x = self.layer1(x)           # convolutional part
        x = x.view(-1,16 * 5 * 5)    # flatten into vectors for the fully connected part
        x = self.layer2(x)           # fully connected part
        return x

# Accuracy function
def accuracy(output , label):
    rightNum = torch.sum(torch.max(label,1)[1].eq(torch.max(output,1)[1]))
    return rightNum / len(label)

if __name__ == "__main__":
    # Instantiate the model
    model = LeNet_5()
    # Load the parameters if a partially trained model exists
    parManager = ParametersManager(device)
    if os.path.exists("./model.pt"):
        parManager.loadFromFile('./model.pt')
        parManager.setModelParameters(model)
    else:
        print('===No pre-trained model found!===')

    model.to(device)  # follow the detected device instead of assuming CUDA
    criterion = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNINGRATE, momentum=0.9)

    # Build the training set
    TrainDataset = MyDataset('Train')
    # Build the test set
    TestDataset = MyDataset('Test')
    # Build the training-set loader
    TrainLoader = DataLoader(TrainDataset,num_workers=8, pin_memory=True, batch_size=BatchSize, sampler= torch.utils.data.sampler.SubsetRandomSampler(range(len(TrainDataset))))
    # Build the test-set loader
    TestLoader = DataLoader(TestDataset,num_workers=8, pin_memory=True, batch_size=BatchSize, sampler= torch.utils.data.sampler.SubsetRandomSampler(range(len(TestDataset))))

    print('len(TrainLoader):{}'.format(len(TrainLoader)))

    # # Helper that checks the batches are split correctly: the images are shown in two rows, in row order, matching the printed labels one to one
    # def testLoader():
    #     inputs, classes = next(iter(TrainLoader))
    #     print(inputs.shape)
    #     print(classes.shape)
    #     print(classes)  # inspect the labels
    #     for i in range(len(inputs)):
    #         plt.subplot(2,5,i+1)
    #         plt.imshow(inputs[i][0],cmap="gray")
    #     plt.show()

    # testLoader()

    TrainACC = []
    TestACC = []
    GlobalLoss = []
    # Report progress about 20 times per epoch
    reportEvery = max(1, len(TrainLoader) // 20)
    for epoch in range(epochNums):
        print("===Starting epoch {} of this run == overall epoch {}===".format(epoch, parManager.EpochDone))

        # Collect training statistics
        epochAccuracy = []
        epochLoss = []
        model.train()  # back to training mode after the previous evaluation pass
        #=============Actual training loop=================
        for batch_id, (inputs,label) in enumerate(TrainLoader):
            inputs, label = inputs.to(device), label.to(device)
            # Reset the gradients to zero first
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output,label)
            loss.backward()
            optimizer.step()
            epochAccuracy.append(accuracy(output,label).item())
            epochLoss.append(loss.item())  # take the plain number out of the tensor
            if batch_id % reportEvery == 0:
                print("    Progress [{}/{}], epoch accuracy so far: {:.2f}%, Loss: {:.6f}".format(batch_id,len(TrainLoader), np.mean(epochAccuracy) * 100, loss.item()))
        #==============End of this training epoch==============
        # Collect the training-set accuracy
        TrainACC.append(np.mean(epochAccuracy))
        GlobalLoss.append(np.mean(epochLoss))
        # ==========Run one evaluation pass on the test set============
        localTestACC = []
        model.eval()  # switch to evaluation mode
        with torch.no_grad():  # context manager: no gradients are tracked inside this block
            for inputs, label in TestLoader:
                inputs, label = inputs.to(device), label.to(device)
                output = model(inputs)
                localTestACC.append(accuracy(output,label).item())
        # ==========End of the evaluation pass================
        # Collect the test-set accuracy
        TestACC.append(np.mean(localTestACC))
        print("Epoch finished. Training-set accuracy: {:.3f}%, test-set accuracy: {:.3f}%".format(TrainACC[-1] * 100, TestACC[-1] * 100))
        # Store the results in the parameter manager
        parManager.oneEpochDone(LEARNINGRATE,TrainACC[-1],TestACC[-1],GlobalLoss[-1])
        # Periodically save the results to a file
        if epoch == epochNums - 1 or epoch % SaveModelEveryNEpoch == 0:
            parManager.loadModelParameters(model)
            parManager.saveToFile('./model.pt')

    # Show the results after this training run
    parManager.show()
    # Plot
    plt.figure(figsize=(10,7))
    plt.plot(range(parManager.EpochDone),parManager.TrainACC,marker='*' ,color='r',label='Train')
    plt.plot(range(parManager.EpochDone),parManager.TestACC,marker='*' ,color='b',label='Test')

    plt.xlabel('Epochs')
    plt.ylabel('ACC')
    plt.legend()
    plt.title("LeNet-5 on MNIST")

    plt.savefig('Train.jpg')
    plt.show()

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship.
For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License.
Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format.
We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
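
The README and the comments in `DownloadUnzipData.py` describe the raw MNIST files as a short header followed by one byte per pixel, which the script decodes into wider values before training. The snippet below is a minimal sketch of that decoding step, run on a tiny synthetic buffer instead of the real download. The function name `parse_idx3_images` and the 2x2 test images are hypothetical and only serve the illustration; the magic number 2051 and the 16-byte header layout are taken from the MNIST format description, not from this repository.

```python
import struct


def parse_idx3_images(buf: bytes):
    """Decode an IDX3 image buffer into lists of float pixel values.

    Hypothetical helper for illustration; the repository's own reader is imgReader.
    """
    # Header: four big-endian 32-bit integers (magic, image count, rows, columns)
    magic, count, rows, cols = struct.unpack_from(">IIII", buf, 0)
    if magic != 2051:
        raise ValueError("not an IDX3 image buffer, magic number is {}".format(magic))
    pixels_per_image = rows * cols
    images = []
    for i in range(count):
        # Pixel data starts right after the 16-byte header, one unsigned byte per pixel
        offset = 16 + i * pixels_per_image
        raw = struct.unpack_from(">{}B".format(pixels_per_image), buf, offset)
        # Widen each 1-byte value to a float, as the conversion script does before saving
        images.append([float(p) for p in raw])
    return count, rows, cols, images


# Synthetic buffer: a header announcing two 2x2 images, followed by 8 pixel bytes
header = struct.pack(">IIII", 2051, 2, 2, 2)
body = bytes([0, 255, 128, 64, 1, 2, 3, 4])
count, rows, cols, images = parse_idx3_images(header + body)
print(count, rows, cols)  # 2 2 2
print(images[1])          # [1.0, 2.0, 3.0, 4.0]
```

Passing an explicit offset to `struct.unpack_from` reads straight from the original buffer, so no copy of the remaining data is made for each image.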