├── README.md ├── demo.py ├── demo_stamp.py ├── model.py ├── test_img ├── 1.png ├── 2.png └── 3.png └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # MNIST-pytorch 2 | 3 | A simple handwriting network for handwritten digit string recognition with Pytorch. 4 | 5 | 对于邮编图片的数字识别,选用深度学习分类的方法对手写邮编数字进行识别。首先使用Opencv对邮政编码中的数字图片进行提取,形成多个手写数字的图片。卷积网络选用CNN卷积神经网络和ResNet卷积神经网络,基于Pytorch框架在MNIST数据集上进行手写数字识别的训练,损失函数loss选用交叉熵损失,优化器选用Adam,训练完成使用模型对手写数字进行识别并进行可视化实现。最终识别的效果良好。 6 | 7 | 1.train.py对网络进行训练,可选取CNN和ResNet 8 | 9 | 2.model.py包含了手写的CNN和ResNet网络 10 | 11 | 3.demo.py为MNIST数据集测试可视化 12 | 13 | 4.demo_stamp.py为针对邮票字符串的手写数字识别demo 14 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.datasets as datasets 4 | from model import ConvNet, ResNetMNIST 5 | import torchvision.transforms as transforms 6 | import matplotlib.pyplot as plt 7 | 8 | def predict(): 9 | 10 | # Device configuration 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # trained on RTX3080ti 12 | 13 | # Choose the network 14 | # Network = 'CNN' 15 | Network = 'ResNet' 16 | 17 | num_classes = 10 18 | 19 | test_dataset = datasets.MNIST(root='./', train=False, download=True, transform=transforms.ToTensor()) 20 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False) 21 | 22 | # Choose the network 23 | if Network == 'CNN': 24 | model = ConvNet(num_classes).to(device) 25 | model.load_state_dict(torch.load('./CNN_20.ckpt')) 26 | elif Network == 'ResNet': 27 | model = ResNetMNIST(num_classes).to(device) 28 | model.load_state_dict(torch.load('./ResNet_20.ckpt')) 29 | else: 30 | print('Choose wrong network!') 31 | 32 | model.eval() 33 | with torch.no_grad(): 34 | correct = 0 35 | total = 0 36 | for images, labels in test_loader: 37 | images = images.to(device) 38 | labels = labels.to(device) 39 | outputs = model(images) 40 | _, predicted = torch.max(outputs.data, 1) 41 | total += labels.size(0) 42 | correct += (predicted == labels).sum().item() 43 | plt.ion() 44 | plt.imshow(images.cpu().numpy().squeeze(),cmap='gray') 45 | plt.title("Prediction: {} GT: {}".format(predicted.cpu().numpy()[0], labels.cpu().numpy()[0])) 46 | plt.pause(0.5) 47 | plt.close() 48 | print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 49 | 50 | if __name__ == '__main__': 51 | predict() -------------------------------------------------------------------------------- /demo_stamp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import ConvNet, ResNetMNIST 3 | import glob 4 | import cv2 5 | import numpy as np 6 | 7 | img_size = 28 8 | kernel_connect = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]], np.uint8) 9 | ans = [] # 保存图片数组 10 | 11 | def split_digits(s, prefix_name): 12 | s = np.rot90(s) # 使图片逆时针旋转90° 13 | # show(s) 14 | s_copy = cv2.dilate(s, kernel_connect, iterations=1) 15 | s_copy2 = s_copy.copy() 16 | contours, hierarchy = cv2.findContours(s_copy2, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) # 该函数可以检测出图片中物品的轮廓 17 | # contours:list结构,列表中每个元素代表一个边沿信息。每个元素是(x, 1, 2)的三维向量,x表示该条边沿里共有多少个像素点,第三维的那个“2”表示每个点的横、纵坐标; 18 | # hierarchy:返回类型是(x, 4)的二维ndarray。x和contours里的x是一样的意思。如果输入选择cv2.RETR_TREE,则以树形结构组织输出,hierarchy的四列分别对应下一个轮廓编号、上一个轮廓编号、父轮廓编号、子轮廓编号,该值为负数表示没有对应项。 19 | 20 | # for it in contours: 21 | # print(it) 22 | # print("##########################") 23 | 24 | idx = 0 25 | for contour in contours: 26 | idx = idx + 1 27 | [x, y, w, h] = cv2.boundingRect(contour) # 当得到对象轮廓后,可用boundingRect()得到包覆此轮廓的最小正矩形, 28 | # show(cv2.boundingRect(contour)) 29 | digit = s_copy[y:y + h, x:x + w] 30 | # show(digit) 31 | pad_len = (h - w) // 2 32 | # print(pad_len) 33 | if pad_len > 0: 34 | digit = cv2.copyMakeBorder(digit, 0, 0, pad_len, pad_len, cv2.BORDER_CONSTANT,value=0) 35 | elif pad_len < 0: 36 | digit = cv2.copyMakeBorder(digit, -pad_len, -pad_len, 0, 0, cv2.BORDER_CONSTANT, value=0) 37 | 38 | pad = digit.shape[0] // 4 # 避免数字与边框直接相连,留出4个像素左右。 39 | digit = cv2.copyMakeBorder(digit, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=0) 40 | digit = cv2.resize(digit, (img_size, img_size), interpolation=cv2.INTER_AREA) # 把图片缩放至28*28 41 | digit = np.rot90(digit, 3) # 逆时针旋转270°将原本图片旋转为原来的水平方向 42 | # show(digit) 43 | cv2.imwrite(prefix_name + str(idx) + '.jpg', digit) 44 | ans.append(digit) 45 | 46 | def predict(): 47 | 48 | # Device configuration 49 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # trained on RTX3080ti 50 | 51 | # Choose the network 52 | # Network = 'CNN' 53 | Network = 'ResNet' 54 | 55 | num_classes = 10 56 | 57 | # Choose the network 58 | if Network == 'CNN': 59 | model = ConvNet(num_classes).to(device) 60 | model.load_state_dict(torch.load('./CNN_20.ckpt')) 61 | elif Network == 'ResNet': 62 | model = ResNetMNIST(num_classes).to(device) 63 | model.load_state_dict(torch.load('./ResNet_20.ckpt')) 64 | else: 65 | print('Choose wrong network!') 66 | 67 | img_list = glob.glob('./test_img/*.png') 68 | model.eval() 69 | with torch.no_grad(): 70 | for i, image in enumerate(img_list): 71 | img0 = cv2.imread(image) 72 | img = cv2.cvtColor(img0, cv2.COLOR_BGR2GRAY) 73 | ret, thresh_img = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY_INV) 74 | # cv2.imshow('fig', thresh_img) 75 | # cv2.waitKey(0) 76 | split_digits(thresh_img, str(i+1)+"/split_") 77 | num_list = [] 78 | for inp in glob.glob('./'+str(i+1)+'/*.jpg'): 79 | input = cv2.imread(inp) 80 | input = cv2.cvtColor(input, cv2.COLOR_BGR2GRAY)/255. 81 | input = torch.Tensor(input).to(device).unsqueeze(0).unsqueeze(0) 82 | output = model(input) 83 | # print(output) 84 | # a=a 85 | _, predicted = torch.max(output.data, 1) 86 | predicted = predicted.cpu().numpy()[0] 87 | num_list.append(predicted) 88 | cv2.imshow(str(num_list), img0) 89 | cv2.waitKey(0) 90 | # plt.ion() 91 | # plt.imshow(images.cpu().numpy().squeeze(),cmap='gray') 92 | # plt.title("Prediction: {}".format(predicted.cpu().numpy()[0], labels.cpu().numpy()[0])) 93 | # print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 94 | 95 | if __name__ == '__main__': 96 | predict() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models import resnet18 3 | 4 | # Convolutional neural network (two convolutional layers) 5 | class ConvNet(nn.Module): 6 | def __init__(self, num_classes=10): 7 | super(ConvNet, self).__init__() 8 | self.layer1 = nn.Sequential( 9 | nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), 10 | nn.BatchNorm2d(16), 11 | nn.ReLU(), 12 | nn.MaxPool2d(kernel_size=2, stride=2)) 13 | self.layer2 = nn.Sequential( 14 | nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), 15 | nn.BatchNorm2d(32), 16 | nn.ReLU(), 17 | nn.MaxPool2d(kernel_size=2, stride=2)) 18 | self.fc = nn.Linear(7*7*32, num_classes) 19 | def forward(self, x): 20 | out = self.layer1(x) 21 | out = self.layer2(out) 22 | out = out.reshape(out.size(0), -1) 23 | out = self.fc(out) 24 | return out 25 | 26 | #ResNet for MNIST 27 | class ResNetMNIST(nn.Module): 28 | def __init__(self, num_classes=10): 29 | super().__init__() 30 | self.model = resnet18(num_classes=num_classes) 31 | self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 32 | self.loss = nn.CrossEntropyLoss() 33 | def forward(self, x): 34 | return self.model(x) 35 | -------------------------------------------------------------------------------- /test_img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeorgeWuzy/MNIST-pytorch/0a51169762e7d6c3faf45ee3b0aa4782c4313481/test_img/1.png -------------------------------------------------------------------------------- /test_img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeorgeWuzy/MNIST-pytorch/0a51169762e7d6c3faf45ee3b0aa4782c4313481/test_img/2.png -------------------------------------------------------------------------------- /test_img/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeorgeWuzy/MNIST-pytorch/0a51169762e7d6c3faf45ee3b0aa4782c4313481/test_img/3.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | from model import ConvNet, ResNetMNIST 6 | 7 | def train(): 8 | # Device configuration 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # trained on RTX3080ti 10 | 11 | # Choose the network 12 | # Network = 'CNN' 13 | Network = 'ResNet' 14 | 15 | # Hyper parameters 16 | num_epochs = 1 # CNN —> 5:99.13%; 10:99.16%; 20:99.16% 17 | # ResNet -> 5:98.85%; 10:99.17%; 20:99.17% 18 | num_classes = 10 19 | batch_size = 100 20 | learning_rate = 0.001 21 | 22 | # MNIST dataset 23 | train_dataset = datasets.MNIST(root='./', train=True, download=True, transform=transforms.ToTensor()) 24 | test_dataset = datasets.MNIST(root='./', train=False, download=True, transform=transforms.ToTensor()) 25 | 26 | # Data loader 27 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) 28 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) 29 | 30 | # Choose the network 31 | if Network == 'CNN': 32 | model = ConvNet(num_classes).to(device) 33 | elif Network == 'ResNet': 34 | model = ResNetMNIST(num_classes).to(device) 35 | else: 36 | print('Choose wrong network!') 37 | 38 | # Loss and optimizer 39 | criterion = nn.CrossEntropyLoss() 40 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 41 | 42 | # Train the model 43 | total_step = len(train_loader) 44 | for epoch in range(num_epochs): 45 | for i, (images, labels) in enumerate(train_loader): 46 | images = images.to(device) 47 | labels = labels.to(device) 48 | # Forward pass 49 | outputs = model(images) 50 | loss = criterion(outputs, labels) 51 | # Backward and optimize 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | if (i+1) % 100 == 0: 56 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 57 | .format(epoch+1, num_epochs, i+1, total_step, loss.item())) 58 | 59 | # Test the model 60 | model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance) 61 | with torch.no_grad(): 62 | correct = 0 63 | total = 0 64 | for images, labels in test_loader: 65 | images = images.to(device) 66 | labels = labels.to(device) 67 | outputs = model(images) 68 | _, predicted = torch.max(outputs.data, 1) 69 | total += labels.size(0) 70 | correct += (predicted == labels).sum().item() 71 | print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 72 | 73 | # Save the model checkpoint 74 | if Network == 'CNN': 75 | torch.save(model.state_dict(), 'CNN_'+str(epoch+1)+'.ckpt') 76 | elif Network == 'ResNet': 77 | torch.save(model.state_dict(), 'ResNet_'+str(epoch+1)+'.ckpt') 78 | 79 | if __name__ == '__main__': 80 | train() --------------------------------------------------------------------------------