├── .gitignore ├── README.md ├── lenet.py ├── requirements.txt └── run.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | .*.swp 4 | *.onnx 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LeNet-5 2 | 3 | This implements a slightly modified LeNet-5 [LeCun et al., 1998a] and achieves an accuracy of ~99% on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). 4 | 5 | 6 | ![Epoch Train Loss visualization](https://i.imgur.com/h4h7CrF.gif) 7 | 8 | ## Setup 9 | 10 | Install all dependencies using the following command: 11 | 12 | ``` 13 | $ pip install -r requirements.txt 14 | ``` 15 | 16 | ## Usage 17 | 18 | Start the `visdom` server for visualization 19 | 20 | ``` 21 | $ python -m visdom.server 22 | ``` 23 | 24 | Start the training procedure 25 | 26 | ``` 27 | $ python run.py 28 | ``` 29 | 30 | See the live epoch train loss graph at [`http://localhost:8097`](http://localhost:8097). 31 | 32 | The trained model will be exported as ONNX to `lenet.onnx`. The `lenet.onnx` file can be viewed with [Netron](https://www.electronjs.org/apps/netron). 33 | 34 | ## References 35 | 36 | [[1](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf)] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE, 86(11):2278-2324, November 1998. 
# ----------------------------------------------------------------------------
# /lenet.py:
# ----------------------------------------------------------------------------
# Building blocks and top-level module of a slightly modified LeNet-5.

from collections import OrderedDict

import torch.nn as nn


class C1(nn.Module):
    """First convolutional stage: 5x5 conv (1 -> 6 channels), ReLU, 2x2 max-pool.

    For the 1x32x32 input documented on ``LeNet5`` this yields 6x14x14.
    """

    def __init__(self):
        super().__init__()

        self.c1 = nn.Sequential(OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
            ('relu1', nn.ReLU()),
            ('s1', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
        ]))

    def forward(self, img):
        """Apply conv -> ReLU -> max-pool to ``img`` and return the result."""
        return self.c1(img)


class C2(nn.Module):
    """Second convolutional stage: 5x5 conv (6 -> 16 channels), ReLU, 2x2 max-pool."""

    def __init__(self):
        super().__init__()

        self.c2 = nn.Sequential(OrderedDict([
            ('c2', nn.Conv2d(6, 16, kernel_size=(5, 5))),
            ('relu2', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
        ]))

    def forward(self, img):
        """Apply conv -> ReLU -> max-pool to ``img`` and return the result."""
        return self.c2(img)


class C3(nn.Module):
    """Third convolutional stage: 5x5 conv (16 -> 120 channels) followed by ReLU."""

    def __init__(self):
        super().__init__()

        self.c3 = nn.Sequential(OrderedDict([
            ('c3', nn.Conv2d(16, 120, kernel_size=(5, 5))),
            ('relu3', nn.ReLU()),
        ]))

    def forward(self, img):
        """Apply conv -> ReLU to ``img`` and return the result."""
        return self.c3(img)


class F4(nn.Module):
    """Fully connected stage: linear (120 -> 84) followed by ReLU."""

    def __init__(self):
        super().__init__()

        self.f4 = nn.Sequential(OrderedDict([
            ('f4', nn.Linear(120, 84)),
            ('relu4', nn.ReLU()),
        ]))

    def forward(self, img):
        """Apply linear -> ReLU to ``img`` and return the result."""
        return self.f4(img)


class F5(nn.Module):
    """Output stage: linear (84 -> 10) followed by log-softmax over the last dim.

    The output is therefore a vector of log-probabilities, not raw logits.
    """

    def __init__(self):
        super().__init__()

        self.f5 = nn.Sequential(OrderedDict([
            ('f5', nn.Linear(84, 10)),
            # The key says "sig5" although the layer is a LogSoftmax; the key
            # is kept unchanged so existing submodule names stay stable.
            ('sig5', nn.LogSoftmax(dim=-1)),
        ]))

    def forward(self, img):
        """Return per-class log-probabilities for ``img``."""
        return self.f5(img)


class LeNet5(nn.Module):
    """
    Input - 1x32x32
    Output - 10

    Differs from a plain C1 -> C2 -> C3 -> F4 -> F5 chain in that the C1
    features pass through two independent ``C2`` branches whose outputs are
    summed element-wise before ``C3``.
    """

    def __init__(self):
        super().__init__()

        self.c1 = C1()
        self.c2_1 = C2()
        self.c2_2 = C2()
        self.c3 = C3()
        self.f4 = F4()
        self.f5 = F5()

    def forward(self, img):
        """Return log-probabilities of shape ``(img.size(0), 10)``."""
        features = self.c1(img)

        # Two parallel C2 branches with separate weights, merged by addition.
        merged = self.c2_2(features) + self.c2_1(features)

        output = self.c3(merged)
        # Flatten everything but the batch dimension for the linear layers.
        output = output.view(img.size(0), -1)
        output = self.f4(output)
        return self.f5(output)


# ----------------------------------------------------------------------------
# /requirements.txt:
# ----------------------------------------------------------------------------
# numpy>=1.17.0
# torch>=1.4.0
# torchvision>=0.4.0
# visdom>=0.1.6
# Pillow==6.2.0
# onnx==1.6.0


# ----------------------------------------------------------------------------
# /run.py:
# ----------------------------------------------------------------------------
# Trains LeNet5 on MNIST, plots the per-batch loss in visdom, and exports the
# model to ONNX after every epoch.

import onnx
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import MNIST
import visdom

from lenet import LeNet5

viz = visdom.Visdom()

# Both splits use the same preprocessing: resize to the 32x32 input size
# documented on LeNet5, then convert to a tensor.
mnist_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

data_train = MNIST('./data/mnist',
                   download=True,
                   transform=mnist_transform)
data_test = MNIST('./data/mnist',
                  train=False,
                  download=True,
                  transform=mnist_transform)
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

net = LeNet5()
# LeNet5 already ends in LogSoftmax, so the matching loss is NLLLoss.
# CrossEntropyLoss would apply log-softmax a second time; on log-probabilities
# that is a no-op numerically but redundant work.
criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3)

# Handle of the visdom window holding the loss trace; None until first drawn.
cur_batch_win = None
cur_batch_win_opts = {
    'title': 'Epoch Loss Trace',
    'xlabel': 'Batch Number',
    'ylabel': 'Loss',
    'width': 1200,
    'height': 600,
}


def train(epoch):
    """Run one training epoch over ``data_train_loader``.

    Prints the loss every 10th batch and, while the visdom server is
    reachable, redraws the loss trace for the current epoch after each batch.

    Args:
        epoch: Epoch number, used only in the progress message.
    """
    global cur_batch_win
    net.train()
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()

        output = net(images)
        loss = criterion(output, labels)

        # Extract the Python scalar once; it feeds both the log and the plot.
        loss_value = loss.item()
        loss_list.append(loss_value)
        batch_list.append(i + 1)

        if i % 10 == 0:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss_value))

        # Update Visualization
        if viz.check_connection():
            cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
                                     win=cur_batch_win, name='current_batch_loss',
                                     update=(None if cur_batch_win is None else 'replace'),
                                     opts=cur_batch_win_opts)

        loss.backward()
        optimizer.step()


def test():
    """Evaluate ``net`` on ``data_test_loader`` and print loss and accuracy.

    The reported loss is the per-sample average over the whole test set.
    """
    net.eval()
    total_correct = 0
    total_loss = 0.0
    # No gradients are needed for evaluation; this also keeps the running
    # totals from holding on to autograd graphs.
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)
            # criterion returns the batch *mean*; weight it by the batch size
            # so that dividing by the dataset size yields a true average.
            total_loss += criterion(output, labels).item() * labels.size(0)
            pred = output.max(1)[1]
            total_correct += pred.eq(labels.view_as(pred)).sum().item()

    num_samples = len(data_test)
    avg_loss = total_loss / num_samples
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, float(total_correct) / num_samples))


def train_and_test(epoch):
    """Train for one epoch, evaluate, then export and validate ``lenet.onnx``.

    Args:
        epoch: Epoch number, forwarded to ``train``.
    """
    train(epoch)
    test()

    # test() leaves the network in eval mode, which is the mode exported here.
    dummy_input = torch.randn(1, 1, 32, 32, requires_grad=True)
    torch.onnx.export(net, dummy_input, "lenet.onnx")

    onnx_model = onnx.load("lenet.onnx")
    onnx.checker.check_model(onnx_model)


def main(num_epochs=15):
    """Run ``train_and_test`` for epochs ``1..num_epochs`` (15 by default)."""
    for e in range(1, num_epochs + 1):
        train_and_test(e)


if __name__ == '__main__':
    main()