├── model.pth ├── utils.py ├── LICENSE ├── model.py ├── README.md ├── test.py ├── train.py └── demo.py /model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KevinLikesCodingMC/PyTorch-MNIST-Tutorial/HEAD/model.pth -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | MNIST Dataset Loader 3 | Load CSV files 4 | """ 5 | 6 | import torch 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | class MNISTDataset(Dataset): 11 | def __init__(self, csv_file): 12 | data = np.loadtxt(csv_file, delimiter=",", dtype=np.float32) 13 | self.labels = data[:, 0].astype(np.long) 14 | self.images = data[:, 1:] / 255.0 15 | def __len__(self): 16 | return len(self.labels) 17 | def __getitem__(self, idx): 18 | image = self.images[idx] 19 | label = self.labels[idx] 20 | # np arr -> torch tenser 21 | image_tensor = torch.tensor(image, dtype=torch.float32) 22 | label_tensor = torch.tensor(label, dtype=torch.long) 23 | return image_tensor, label_tensor 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Keroshi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | NeuralNetwork Model 3 | Convolutional Neural Network: 4 | 28 x 28 5 | Conv -> 32 * 28 x 28 6 | MaxPool -> 32 * 14 x 14 7 | Conv -> 64 * 14 x 14 8 | MaxPool -> 64 * 7 x 7 9 | Flatten -> 3136 10 | Fully-connected -> 256 11 | Fully-connected -> 10 12 | """ 13 | 14 | from torch import nn 15 | 16 | class NeuralNetwork(nn.Module): 17 | def __init__(self): 18 | super(NeuralNetwork, self).__init__() 19 | # Convolutional Layer 20 | self.conv = nn.Sequential( 21 | # Conv 28 x 28 -> 32 * 28 x 28 22 | nn.Conv2d(1, 32, 3, 1, 1), 23 | nn.ReLU(), 24 | # MaxPool 28 x 28 -> 14 x 14 25 | nn.MaxPool2d(2, 2), 26 | # Conv 32 * 14 x 14 -> 64 * 7 x 7 27 | nn.Conv2d(32, 64, 3, 1, 1), 28 | nn.ReLU(), 29 | # MaxPool 14 x 14 -> 7 x 7 30 | nn.MaxPool2d(2, 2), 31 | ) 32 | # Fully-connected Layer 33 | self.fc = nn.Sequential( 34 | # Flatten 64 * 7 x 7 -> 3136 35 | nn.Flatten(), 36 | # Fully-connected 3136 -> 256 37 | nn.Linear(64 * 7 * 7, 256), 38 | nn.ReLU(), 39 | nn.Dropout(0.5), 40 | # Fully-connected 256 -> 10 41 | nn.Linear(256, 10) 42 | ) 43 | def forward(self, x): 44 | # 768 -> 28 x 28 45 | x = x.view(- 1, 1, 28, 28) 46 | x = self.conv(x) 47 | x = self.fc(x) 48 | return x 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-MNIST-Tutorial 2 | 3 | ## PyTorch MNIST 手写数字识别教程 4 | 5 | [![Python 3.8+](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/) 6 | [![PyTorch 1.9+](https://img.shields.io/badge/PyTorch-1.9+-red.svg)](https://pytorch.org/) 7 | [![License MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 8 | [![GitHub stars](https://img.shields.io/github/stars/KevinLikesCodingMC/PyTorch-MNIST-Tutorial.svg?style=social)](https://github.com/KevinLikesCodingMC/PyTorch-MNIST-Tutorial) 9 | 10 | # Quick Start 11 | 12 | ## 1. Clone the repository 13 | 14 | ```bash 15 | git clone https://github.com/KevinLikesCodingMC/PyTorch-MNIST-Tutorial.git 16 | ``` 17 | ## 2. Install requirements 18 | 19 | First download the [MNIST dataset in CSV format](https://github.com/phoebetronic/mnist) and place `mnist_train.csv` and `mnist_test.csv` in the root directory: 20 | 21 | ``` 22 | PyTorch-MNIST-Tutorial/ 23 | ├── LICENSE 24 | ├── model.pth 25 | ├── model.py 26 | ├── README.md 27 | ├── test.py 28 | ├── train.py 29 | ├── utils.py 30 | ├── demo.py 31 | ├── mnist_train.csv 32 | └── mnist_test.csv 33 | ``` 34 | 35 | Then install these packages: 36 | 37 | ```bash 38 | pip install numpy matplotlib pygame 39 | ``` 40 | 41 | ## 3. Install PyTorch 42 | 43 | Visit [PyTorch official website](https://pytorch.org/get-started/locally/). 44 | 45 | ## 4. Training and Testing 46 | 47 | Run `train.py` to train and generate model `model.pth`. 48 | 49 | Run `test.py` to test the model `model.pth`. 50 | 51 | ## 5. Visual Demo 52 | 53 | Run `demo.py` to watch the visual demo. 54 | 55 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the model 3 | """ 4 | 5 | def main(): 6 | # print logs 7 | import datetime 8 | def log(message): 9 | timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 10 | print('[%s] %s' % (timestamp, message)) 11 | 12 | # Load Pytorch 13 | log("Initializing PyTorch...") 14 | import torch 15 | from torch.utils.data import DataLoader 16 | log("PyTorch initialization complete") 17 | log(f"PyTorch version: {torch.__version__}") 18 | 19 | # Load libs 20 | log("Loading libraries...") 21 | from utils import MNISTDataset 22 | from model import NeuralNetwork 23 | log("Libraries loaded") 24 | 25 | # Check devices 26 | log("Checking available compute devices...") 27 | if torch.cuda.is_available(): 28 | device = torch.device("cuda") 29 | log(f"GPU detected: {torch.cuda.get_device_name(0)}") 30 | log(f"Available GPU count: {torch.cuda.device_count()}") 31 | log(f"Selected device: GPU (cuda:0) - {torch.cuda.get_device_name(0)}") 32 | log(f"CUDA version: {torch.version.cuda}") 33 | log(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") 34 | else: 35 | device = torch.device("cpu") 36 | log("No GPU available, using CPU for computation") 37 | log(f"Selected device: CPU") 38 | 39 | # Load MNIST test data 40 | log("Loading MNIST...") 41 | test_dataset = MNISTDataset('mnist_test.csv') 42 | test_loader = DataLoader( 43 | test_dataset, 44 | batch_size=128, 45 | shuffle=True, 46 | ) 47 | log("MNIST loaded") 48 | 49 | # Load model 50 | log(f"Loading model...") 51 | model = NeuralNetwork() 52 | checkpoint = torch.load('model.pth', map_location=device) 53 | model.load_state_dict(checkpoint['model_state_dict']) 54 | model = model.to(device) 55 | model.eval() 56 | log(f"model loaded") 57 | 58 | # testing 59 | correct = 0 60 | total = 0 61 | with torch.no_grad(): 62 | for images, labels in test_loader: 63 | images = images.to(device) 64 | labels = labels.to(device) 65 | outputs = model(images) 66 | _, predicted = torch.max(outputs, 1) 67 | total += labels.size(0) 68 | correct += (predicted == labels).sum().item() 69 | 70 | # show result 71 | accuracy = 100.0 * correct / total 72 | log(f"Accuracy: {accuracy:.2f}% ({correct}/{total})") 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train the models 3 | """ 4 | 5 | def main(): 6 | # print logs 7 | import datetime 8 | def log(message): 9 | timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 10 | print('[%s] %s' % (timestamp, message)) 11 | 12 | # Load Pytorch 13 | log("Initializing PyTorch...") 14 | import torch 15 | from torch import nn 16 | from torch import optim 17 | from torch.utils.data import DataLoader 18 | log("PyTorch initialization complete") 19 | log(f"PyTorch version: {torch.__version__}") 20 | 21 | # Load libs 22 | log("Loading libraries...") 23 | import matplotlib.pyplot as plt 24 | from utils import MNISTDataset 25 | from model import NeuralNetwork 26 | log("Libraries loaded") 27 | 28 | # Check devices 29 | log("Checking available compute devices...") 30 | if torch.cuda.is_available(): 31 | device = torch.device("cuda") 32 | log(f"GPU detected: {torch.cuda.get_device_name(0)}") 33 | log(f"Available GPU count: {torch.cuda.device_count()}") 34 | log(f"Selected device: GPU (cuda:0) - {torch.cuda.get_device_name(0)}") 35 | log(f"CUDA version: {torch.version.cuda}") 36 | log(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") 37 | else: 38 | device = torch.device("cpu") 39 | log("No GPU available, using CPU for computation") 40 | log(f"Selected device: CPU") 41 | 42 | # Load MNIST train data 43 | log("Loading MNIST...") 44 | train_dataset = MNISTDataset('mnist_train.csv') 45 | train_loader = DataLoader( 46 | train_dataset, 47 | batch_size=128, 48 | shuffle=True, 49 | ) 50 | log("MNIST loaded") 51 | 52 | # move model 53 | log(f"Moving model to {device}...") 54 | model = NeuralNetwork().to(device) 55 | log(f"Model successfully moved to {device}") 56 | 57 | # setup model 58 | criterion = nn.CrossEntropyLoss() 59 | optimizer = optim.Adam(model.parameters(), lr=0.001) 60 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 61 | log("Model setup complete") 62 | 63 | losses = [] 64 | accuracies = [] 65 | 66 | # training 67 | log("Starting training...") 68 | 69 | # 10 epoches 70 | for epoch in range(10): 71 | model.train() 72 | running_loss = 0.0 73 | total = 0 74 | correct = 0 75 | for batch_idx, (images, labels) in enumerate(train_loader): 76 | # load data 77 | images = images.to(device, non_blocking=True) 78 | labels = labels.to(device, non_blocking=True) 79 | # calc loss 80 | outputs = model(images) 81 | loss = criterion(outputs, labels) 82 | # backward 83 | optimizer.zero_grad() 84 | loss.backward() 85 | optimizer.step() 86 | # collect 87 | running_loss += loss.item() 88 | total += labels.size(0) 89 | _, predicted = torch.max(outputs.data, 1) 90 | correct += (predicted == labels).sum().item() 91 | scheduler.step() 92 | # collect 93 | epoch_loss = running_loss / len(train_loader) 94 | epoch_acc = 100 * correct / total 95 | losses.append(epoch_loss) 96 | accuracies.append(epoch_acc) 97 | log('Loss: {:.4f}, Accuracy: {:.4f}'.format(epoch_loss, epoch_acc)) 98 | 99 | log("Training complete") 100 | 101 | # save model 102 | log("Saving model...") 103 | torch.save({ 104 | 'model_state_dict': model.state_dict(), 105 | 'losses': losses, 106 | 'accuracies': accuracies 107 | }, 'model.pth') 108 | log("Model saved") 109 | 110 | # show result 111 | plt.figure(figsize=(15, 6)) 112 | 113 | plt.subplot(1, 2, 1) 114 | plt.plot(losses, 'b-', linewidth=2) 115 | plt.title('Training Loss') 116 | plt.xlabel('Epoch') 117 | plt.ylabel('Loss') 118 | plt.grid(True) 119 | 120 | plt.subplot(1, 2, 2) 121 | plt.plot(accuracies, 'r-', linewidth=2) 122 | plt.title('Training Accuracy') 123 | plt.xlabel('Epoch') 124 | plt.ylabel('Accuracy (%)') 125 | plt.grid(True) 126 | 127 | plt.show() 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visual Demo of Handwritten Digit Recognition using CNN 3 | Press Left Button to draw 4 | Press key R to reset canvas 5 | """ 6 | 7 | import datetime 8 | 9 | # print logs 10 | def log(message): 11 | timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 12 | print('[%s] %s' % (timestamp, message)) 13 | 14 | # Load Pytorch 15 | log("Initializing PyTorch...") 16 | import torch 17 | import torch.nn.functional as F 18 | log("PyTorch initialization complete") 19 | log(f"PyTorch version: {torch.__version__}") 20 | 21 | # Load libs 22 | log("Loading libraries...") 23 | import pygame.freetype 24 | import numpy as np 25 | import sys 26 | from model import NeuralNetwork 27 | log("Libraries loaded") 28 | 29 | # Check devices 30 | log("Checking available compute devices...") 31 | if torch.cuda.is_available(): 32 | device = torch.device("cuda") 33 | log(f"GPU detected: {torch.cuda.get_device_name(0)}") 34 | log(f"Available GPU count: {torch.cuda.device_count()}") 35 | log(f"Selected device: GPU (cuda:0) - {torch.cuda.get_device_name(0)}") 36 | log(f"CUDA version: {torch.version.cuda}") 37 | log(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") 38 | else: 39 | device = torch.device("cpu") 40 | log("No GPU available, using CPU for computation") 41 | log(f"Selected device: CPU") 42 | 43 | # move model 44 | log(f"Loading model...") 45 | model = NeuralNetwork() 46 | checkpoint = torch.load('model.pth', map_location=device) 47 | model.load_state_dict(checkpoint['model_state_dict']) 48 | model = model.to(device) 49 | model.eval() 50 | log(f"model loaded") 51 | 52 | # init pygame 53 | log(f"Initializing...") 54 | pygame.init() 55 | screen = pygame.display.set_mode((1500, 800)) 56 | pygame.display.set_caption('Visual Demo') 57 | 58 | # pygame colors 59 | BLACK = (0, 0, 0) 60 | WHITE = (255, 255, 255) 61 | LIGHT_GREY = (200, 200, 200) 62 | DARK_GREY = (50, 50, 50) 63 | 64 | # 600 x 600 canvas 65 | canvas = np.zeros((600, 600)) 66 | image = np.zeros((28, 28)) 67 | scl = 600 / 28 68 | 69 | # Shrink 70 | def shrink(canvas): 71 | canvas_tensor = torch.from_numpy(canvas).float() 72 | canvas_tensor = canvas_tensor.unsqueeze(0).unsqueeze(0) 73 | # blur 74 | kernel = [ 75 | [1, 2, 1], 76 | [2, 4, 2], 77 | [1, 2, 1], 78 | ] 79 | kernel = torch.tensor(kernel, dtype=torch.float32) / 16.0 80 | kernel = kernel.unsqueeze(0).unsqueeze(0) 81 | blurred = F.conv2d(canvas_tensor, kernel, padding=1) 82 | # 600 x 600 -> 28 x 28 83 | scaled = F.interpolate(blurred, size=(28, 28), mode='bilinear', align_corners=False) 84 | result = scaled.squeeze().numpy() 85 | return result 86 | 87 | # recognition 88 | def recognize(image): 89 | image_torch = torch.from_numpy(image).float().unsqueeze(0).unsqueeze(0) 90 | image_torch = image_torch.to(device) 91 | with torch.no_grad(): 92 | output = model(image_torch) 93 | output = torch.softmax(output, dim=1) 94 | output = output.cpu().numpy() 95 | return output 96 | 97 | 98 | # drawing state 99 | drawing = False 100 | 101 | last_call_time = 0 102 | interval = 50 103 | 104 | # use the default font 105 | font = pygame.freetype.Font(None, 40) 106 | 107 | # CNN output 108 | output = recognize(image) 109 | 110 | log("Initialization complete") 111 | 112 | while True: 113 | # events 114 | for event in pygame.event.get(): 115 | if event.type == pygame.QUIT: 116 | pygame.quit() 117 | sys.exit() 118 | # switch drawing state 119 | elif event.type == pygame.MOUSEBUTTONDOWN: 120 | if event.button == 1: 121 | drawing = True 122 | elif event.type == pygame.MOUSEBUTTONUP: 123 | if event.button == 1: 124 | drawing = False 125 | # reset canvas 126 | elif event.type == pygame.KEYDOWN: 127 | if event.key == pygame.K_r: 128 | canvas = np.zeros((600, 600)) 129 | image = np.zeros((28, 28)) 130 | output = recognize(image) 131 | 132 | # clear screen 133 | screen.fill(WHITE) 134 | 135 | if drawing: 136 | mouse_x, mouse_y = pygame.mouse.get_pos() 137 | mouse_x -= 100 138 | mouse_y -= 100 139 | r = 22 140 | # drawing a circle 141 | for y in range(max(0, mouse_y - r), min(mouse_y + r, 600)): 142 | for x in range(max(0, mouse_x - r), min(mouse_x + r, 600)): 143 | if (mouse_x - x) ** 2 + (mouse_y - y) ** 2 <= r ** 2: 144 | canvas[y, x] = 1 145 | # 600 x 600 -> 28 x 28 146 | image = shrink(canvas) 147 | # 28 x 28 -> 10 148 | output = recognize(image) 149 | 150 | # draw the 28 x 28 image 151 | for y in range(28): 152 | for x in range(28): 153 | val = int(image[y][x] * 255) 154 | val = max(0, min(255, val)) 155 | clr = (val, val, val) 156 | rect_x = int(x * scl + 100) 157 | rect_y = int(y * scl + 100) 158 | rect_size = int(scl) + 1 159 | pygame.draw.rect(screen, clr, (rect_x, rect_y, rect_size, rect_size)) 160 | 161 | # draw inner border 162 | pygame.draw.rect(screen, DARK_GREY, (250, 200, 300, 400), 3) 163 | # drawing outer border 164 | pygame.draw.rect(screen, LIGHT_GREY, (100, 100, 600, 600), 5) 165 | 166 | # draw number bars 167 | for number in range(10): 168 | font.render_to(screen, (800 + number * 60, 650), str(number), BLACK) 169 | rect_x = 795 + number * 60 170 | rect_height = 500 * float(output[0][number]) 171 | rect_y = 620 - rect_height 172 | pygame.draw.rect(screen, BLACK, (rect_x, rect_y, 30, rect_height)) 173 | 174 | # update screen 175 | pygame.display.update() 176 | --------------------------------------------------------------------------------