# -*- coding: utf-8 -*-
"""
Created on Sat May 9 14:37:08 2020
Classify the MNIST dataset with a CNN.
@author:
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    """Two conv + two fully-connected layers; returns log-probabilities over 10 classes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=100, kernel_size=3,
                               stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(in_channels=100, out_channels=50, kernel_size=3,
                               stride=1, padding=0, bias=False)
        self.fc1 = nn.Linear(in_features=5 * 5 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Map an (N, 1, 28, 28) batch to (N, 10) log-probabilities."""
        x = F.relu(self.conv1(x))  # (N, 100, 26, 26)
        x = F.max_pool2d(x, 2, 2)  # (N, 100, 13, 13)
        x = F.relu(self.conv2(x))  # (N, 50, 11, 11)
        x = F.max_pool2d(x, 2, 2)  # (N, 50, 5, 5)
        # flatten(x, 1) keeps the batch dimension intact; view(-1, 1250) would
        # silently regroup rows into a different batch size if the input's
        # spatial size were wrong, instead of failing at fc1.
        x = torch.flatten(x, 1)    # (N, 1250)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, train_loader, optimizer, epoch):
    """Run one training epoch.

    Args:
        model: network returning log-probabilities (paired with nll_loss).
        train_loader: iterable of (data, label) batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index; used for logging only.

    Returns:
        Mean per-batch loss for the epoch, or None if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data, label in train_loader:
        optimizer.zero_grad()
        pred = model(data)
        loss = F.nll_loss(pred, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Previously this path raised UnboundLocalError on `loss`.
        print("epoch:", epoch, "no training batches")
        return None
    mean_loss = total_loss / num_batches
    print("epoch:", epoch, "loss:", mean_loss)
    return mean_loss


def test(model, test_loader):
    """Evaluate classification accuracy on ``test_loader``.

    Args:
        model: network whose argmax over dim 1 is the predicted class.
        test_loader: DataLoader of (data, target) batches; must expose ``.dataset``.

    Returns:
        Accuracy in percent, or None if the dataset is empty.
    """
    model.eval()
    correct = 0
    with torch.no_grad():  # evaluation does not need gradients
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    total = len(test_loader.dataset)
    if total == 0:
        print("acc: n/a (empty test set)")
        return None
    acc = 100 * correct / total
    print("acc:", acc)
    return acc


# `test` is an evaluation helper, not a unit test; stop pytest from collecting
# it when this module is star-imported into a test file.
test.__test__ = False


# Run configuration.
batch_size = 64
lr = 1e-2
momentum = 0.5
epochs = 1
isTrain = False  # train from scratch, then save the state_dict to CHECKPOINT_PATH
isTest = True    # load the state_dict from CHECKPOINT_PATH and report accuracy
CHECKPOINT_PATH = "mnist_cnn0.pt"
DATA_DIR = "./mnist_data"


def _build_loaders(batch_size, data_dir=DATA_DIR):
    """Build the MNIST train/test DataLoaders with a shared normalization transform."""
    # Imported here so Net/train/test remain importable without torchvision.
    from torchvision import datasets, transforms

    # MNIST images are [1, 28, 28]; Normalize takes per-channel (mean,) and
    # (std,) tuples. 0.1307 / 0.3081 are presumably the MNIST training-set
    # statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transform),
        batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def main():
    """Optionally train and save a checkpoint, then optionally reload it and evaluate."""
    torch.manual_seed(100)
    train_loader, test_loader = _build_loaders(batch_size)
    model = Net()

    if isTrain:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
        for epoch in range(epochs):
            train(model, train_loader, optimizer, epoch)
            test(model, test_loader)
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    if isTest:
        # map_location lets a checkpoint saved from a GPU run load into this CPU model.
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
        test(model, test_loader)


if __name__ == "__main__":
    main()