├── .gitignore ├── LICENSE ├── README.md ├── convert_to_onnx.py ├── debug_demo ├── debug.html └── onnx_model.onnx ├── full_demo ├── index.html ├── onnx_model.onnx ├── script.js └── style.css ├── inference_mnist_model.py ├── inputs_batch_preview.png ├── onnx_model.onnx ├── preview_dataset.py ├── pytorch_model.pt └── train_mnist_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Elliot Waite 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Run PyTorch models in the browser using ONNX.js 2 | 3 | Run PyTorch models in the browser with JavaScript by first converting your PyTorch model into the ONNX format and then loading that ONNX model in your website or app using ONNX.js. In the video tutorial below, I take you through this process using the demo example of a handwritten digit recognition model trained on the MNIST dataset. 4 | 5 | ### Tutorial 6 | https://www.youtube.com/watch?v=Vs730jsRgO8 7 | 8 | [](https://www.youtube.com/watch?v=Vs730jsRgO8) 9 | 10 | ### Live Demo and Code Sandbox 11 | 12 | * [Live demo](https://vgzep.csb.app/) 13 | 14 | * [Code sandbox](https://codesandbox.io/s/pytorch-to-javascript-with-onnx-vgzep) 15 | 16 | Note: The model used in this demo is not very accurate; it will often 17 | [misclassify 18 | digits](https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js/issues/1). 19 | It's only meant to be used as a proof of concept. It uses the same architecture as the 20 | model in [PyTorch's MNIST 21 | example](https://github.com/pytorch/examples/blob/main/mnist/main.py). 22 | You can find more accurate image classification models here: [Papers With Code - 23 | Image Classification](https://paperswithcode.com/task/image-classification) 24 | 25 | ### The files in this repo (and a description of what they do) 26 | ``` 27 | ├── debug_demo 28 | │ ├── debug.html (A debug test to make sure the generated ONNX model works. 29 | │ │ Uses ONNX.js to load and run the generated ONNX model.) 30 | │ │ 31 | │ └── onnx_model.onnx (A copy of the generated ONNX model that will be loaded 32 | │ for debugging.) 33 | │ 34 | ├── full_demo 35 | │ ├── index.html (The full demo's HTML code.) 36 | │ │ 37 | │ ├── onnx_model.onnx (A copy of the generated ONNX model. Used by script.js.)
38 | │ │ 39 | │ ├── script.js (The full demo's JS code. Loads onnx_model.onnx and 40 | │ │ predicts the drawn digit.) 41 | │ │ 42 | │ └── style.css (The full demo's CSS.) 43 | │ 44 | ├── convert_to_onnx.py (Converts a trained PyTorch model into an ONNX model.) 45 | │ 46 | ├── inference_mnist_model.py (The PyTorch model description. Used by 47 | │ convert_to_onnx.py to generate the ONNX model.) 48 | │ 49 | ├── inputs_batch_preview.png (A preview of a batch of augmented input data. 50 | │ Generated by preview_dataset.py.) 51 | │ 52 | ├── onnx_model.onnx (The ONNX model generated by convert_to_onnx.py.) 53 | │ 54 | ├── preview_dataset.py (For testing out different types of data augmentation.) 55 | │ 56 | ├── pytorch_model.pt (The trained PyTorch model parameters. Generated by 57 | │ train_mnist_model.py and used by convert_to_onnx.py to 58 | │ generate the ONNX model.) 59 | │ 60 | └── train_mnist_model.py (Trains the PyTorch model and saves the trained 61 | parameters as pytorch_model.pt.) 62 | ``` 63 | 64 | ### The benefits of running a model in the browser: 65 | * Faster inference with smaller models (no network round trip). 66 | * Easy to host and scale (only static files). 67 | * Offline support. 68 | * User privacy (can keep the data on the device). 69 | 70 | ### The benefits of using a backend server: 71 | * Faster load times (don't have to download the model). 72 | * Faster and more consistent inference times with larger models (can take advantage of GPUs or other accelerators). 73 | * Model privacy (don't have to share your model if you want to keep it private).
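To put the "download the model" trade-off in numbers, here is a rough estimate of the size of the `Net` in `inference_mnist_model.py`. It is a sketch under stated assumptions: every parameter is counted as a 4-byte float32, and the ONNX graph and metadata overhead is ignored, so the result approximates, rather than equals, the size of `onnx_model.onnx`. It also shows where the `9216` in `nn.Linear(9216, 128)` comes from. The helper names `valid_out` and `count_params` are illustrative only.

```python
# Rough size estimate for the Net in inference_mnist_model.py.
# Assumptions: float32 parameters (4 bytes each); ONNX graph/metadata overhead ignored.


def valid_out(size, kernel):
    """Output side length of an unpadded, stride-1 convolution."""
    return size - kernel + 1


def count_params():
    conv1 = 1 * 32 * 3 * 3 + 32    # nn.Conv2d(1, 32, 3, 1): weights + biases
    conv2 = 32 * 64 * 3 * 3 + 64   # nn.Conv2d(32, 64, 3, 1)
    # 28 -> 26 (conv1) -> 24 (conv2) -> 12 (F.max_pool2d(x, 2))
    side = valid_out(valid_out(28, 3), 3) // 2
    flat = 64 * side * side        # input features of fc1
    fc1 = flat * 128 + 128         # nn.Linear(flat, 128)
    fc2 = 128 * 10 + 10            # nn.Linear(128, 10)
    return flat, conv1 + conv2 + fc1 + fc2, fc1


flat, total, fc1 = count_params()
print(f"fc1 input features: {flat}")  # 9216, matching nn.Linear(9216, 128)
print(f"parameters: {total:,}")
print(f"fc1 share: {fc1 / total:.0%}")
print(f"approx. float32 size: {total * 4 / 1e6:.1f} MB")
```

By this estimate the model has about 1.2 million parameters (roughly 4.8 MB as float32), and `fc1` alone holds about 98% of them, so shrinking that layer would likely be the most direct way to cut the browser's download size.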
74 | 75 | ## License 76 | 77 | [MIT](LICENSE) 78 | -------------------------------------------------------------------------------- /convert_to_onnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from inference_mnist_model import Net 4 | 5 | 6 | def main(): 7 | pytorch_model = Net() 8 | pytorch_model.load_state_dict(torch.load('pytorch_model.pt', map_location='cpu'))  # Load onto the CPU, even if the weights were saved from a GPU. 9 | pytorch_model.eval()  # Disable dropout before exporting. 10 | dummy_input = torch.zeros(280 * 280 * 4)  # Same shape as the flat RGBA canvas data sent by script.js. 11 | torch.onnx.export(pytorch_model, dummy_input, 'onnx_model.onnx', verbose=True) 12 | 13 | 14 | if __name__ == '__main__': 15 | main() 16 | -------------------------------------------------------------------------------- /debug_demo/debug.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 15 |

16 | The output of this debug demo is logged to the JavaScript 17 | console. To view the output, open your browser's developer 18 | tools window, and look under the "Console" tab. 19 |

20 | 21 | -------------------------------------------------------------------------------- /debug_demo/onnx_model.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliotwaite/pytorch-to-javascript-with-onnx-js/9c5f85145faea4e9be7e8e0b4c2effbbe0a1177e/debug_demo/onnx_model.onnx -------------------------------------------------------------------------------- /full_demo/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 14 | 15 |
CLEAR
16 | 17 |
18 |
19 |
20 |
21 |
22 |
0
23 |
24 | 25 |
26 |
27 |
28 |
29 |
1
30 |
31 | 32 |
33 |
34 |
35 |
36 |
2
37 |
38 | 39 |
40 |
41 |
42 |
43 |
3
44 |
45 | 46 |
47 |
48 |
49 |
50 |
4
51 |
52 | 53 |
54 |
55 |
56 |
57 |
5
58 |
59 | 60 |
61 |
62 |
63 |
64 |
6
65 |
66 | 67 |
68 |
69 |
70 |
71 |
7
72 |
73 | 74 |
75 |
76 |
77 |
78 |
8
79 |
80 | 81 |
82 |
83 |
84 |
85 |
9
86 |
87 |
88 |
89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /full_demo/onnx_model.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliotwaite/pytorch-to-javascript-with-onnx-js/9c5f85145faea4e9be7e8e0b4c2effbbe0a1177e/full_demo/onnx_model.onnx -------------------------------------------------------------------------------- /full_demo/script.js: -------------------------------------------------------------------------------- 1 | const CANVAS_SIZE = 280; 2 | const CANVAS_SCALE = 0.5; // Displayed size / drawing size (140px / 280px, see .canvas in style.css). 3 | 4 | const canvas = document.getElementById("canvas"); 5 | const ctx = canvas.getContext("2d"); 6 | const clearButton = document.getElementById("clear-button"); 7 | 8 | let isMouseDown = false; 9 | let hasIntroText = true; 10 | let lastX = 0; 11 | let lastY = 0; 12 | 13 | // Load our model. 14 | const sess = new onnx.InferenceSession(); 15 | const loadingModelPromise = sess.loadModel("./onnx_model.onnx"); 16 | 17 | // Set the text and line styles, and show "Loading..." until the model has loaded. 18 | ctx.lineWidth = 28; 19 | ctx.lineJoin = "round"; 20 | ctx.font = "28px sans-serif"; 21 | ctx.textAlign = "center"; 22 | ctx.textBaseline = "middle"; 23 | ctx.fillStyle = "#212121"; 24 | ctx.fillText("Loading...", CANVAS_SIZE / 2, CANVAS_SIZE / 2); 25 | 26 | // Set the line color for the canvas. 27 | ctx.strokeStyle = "#212121"; 28 | 29 | function clearCanvas() { 30 | ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE); 31 | for (let i = 0; i < 10; i++) { 32 | const element = document.getElementById(`prediction-${i}`); 33 | element.className = "prediction-col"; 34 | element.children[0].children[0].style.height = "0"; 35 | } 36 | } 37 | 38 | function drawLine(fromX, fromY, toX, toY) { 39 | // Draws a line from (fromX, fromY) to (toX, toY).
40 | ctx.beginPath(); 41 | ctx.moveTo(fromX, fromY); 42 | ctx.lineTo(toX, toY); 43 | ctx.closePath(); 44 | ctx.stroke(); 45 | updatePredictions(); 46 | } 47 | 48 | async function updatePredictions() { 49 | // Get the predictions for the canvas data. 50 | const imgData = ctx.getImageData(0, 0, CANVAS_SIZE, CANVAS_SIZE); 51 | const input = new onnx.Tensor(new Float32Array(imgData.data), "float32"); // Flat RGBA bytes (280 * 280 * 4), matching the dummy input in convert_to_onnx.py. 52 | 53 | const outputMap = await sess.run([input]); 54 | const outputTensor = outputMap.values().next().value; 55 | const predictions = outputTensor.data; 56 | const maxPrediction = Math.max(...predictions); 57 | 58 | for (let i = 0; i < predictions.length; i++) { 59 | const element = document.getElementById(`prediction-${i}`); 60 | element.children[0].children[0].style.height = `${predictions[i] * 100}%`; 61 | element.className = 62 | predictions[i] === maxPrediction 63 | ? "prediction-col top-prediction" 64 | : "prediction-col"; 65 | } 66 | } 67 | 68 | function canvasMouseDown(event) { 69 | isMouseDown = true; 70 | if (hasIntroText) { 71 | clearCanvas(); 72 | hasIntroText = false; 73 | } 74 | const x = event.offsetX / CANVAS_SCALE; 75 | const y = event.offsetY / CANVAS_SCALE; 76 | 77 | // To draw a dot on the mouse down event, we set lastX and lastY to be 78 | // slightly offset from x and y, and then we call `canvasMouseMove(event)`, 79 | // which draws a line from (lastX, lastY) to (x, y) that shows up as a 80 | // dot because the difference between those points is so small. However, 81 | // if the points were the same, nothing would be drawn, which is why the 82 | // 0.001 offset is added.
83 | lastX = x + 0.001; 84 | lastY = y + 0.001; 85 | canvasMouseMove(event); 86 | } 87 | 88 | function canvasMouseMove(event) { 89 | const x = event.offsetX / CANVAS_SCALE; 90 | const y = event.offsetY / CANVAS_SCALE; 91 | if (isMouseDown) { 92 | drawLine(lastX, lastY, x, y); 93 | } 94 | lastX = x; 95 | lastY = y; 96 | } 97 | 98 | function bodyMouseUp() { 99 | isMouseDown = false; 100 | } 101 | 102 | function bodyMouseOut(event) { 103 | // We won't be able to detect a MouseUp event if the mouse has moved 104 | // outside the window, so when the mouse leaves the window, we set 105 | // `isMouseDown` to false automatically. This prevents lines from 106 | // continuing to be drawn when the mouse returns to the canvas after 107 | // having been released outside the window. 108 | if (!event.relatedTarget || event.relatedTarget.nodeName === "HTML") { 109 | isMouseDown = false; 110 | } 111 | } 112 | 113 | loadingModelPromise.then(() => { 114 | canvas.addEventListener("mousedown", canvasMouseDown); 115 | canvas.addEventListener("mousemove", canvasMouseMove); 116 | document.body.addEventListener("mouseup", bodyMouseUp); 117 | document.body.addEventListener("mouseout", bodyMouseOut); 118 | clearButton.addEventListener("mousedown", clearCanvas); 119 | 120 | ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE); 121 | ctx.fillText("Draw a number here!", CANVAS_SIZE / 2, CANVAS_SIZE / 2); 122 | }); 123 | -------------------------------------------------------------------------------- /full_demo/style.css: -------------------------------------------------------------------------------- 1 | *, 2 | *:before, 3 | *:after { 4 | box-sizing: inherit; 5 | } 6 | 7 | html { 8 | -webkit-font-smoothing: antialiased; 9 | -moz-osx-font-smoothing: grayscale; 10 | box-sizing: border-box; 11 | } 12 | 13 | body { 14 | align-items: center; 15 | background: #fafafa; 16 | color: #212121; 17 | display: flex; 18 | font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, 19 | Arial,
sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol"; 20 | justify-content: center; 21 | margin: 0; 22 | } 23 | 24 | .elevation { 25 | box-shadow: 0 3px 1px -2px rgba(0, 0, 0, 0.2), 0 2px 2px 0 rgba(0, 0, 0, 0.14), 26 | 0 1px 5px 0 rgba(0, 0, 0, 0.12); 27 | } 28 | 29 | .card { 30 | background: #fff; 31 | border-radius: 4px; 32 | padding: 16px; 33 | } 34 | 35 | .canvas { 36 | border-radius: 4px; 37 | height: 140px; 38 | width: 140px; 39 | } 40 | 41 | .button { 42 | background-color: #fff; 43 | border-radius: 4px; 44 | box-shadow: 0 3px 1px -2px rgba(0, 0, 0, 0.2), 0 2px 2px 0 rgba(0, 0, 0, 0.14), 45 | 0 1px 5px 0 rgba(0, 0, 0, 0.12), inset 0 0 0 rgba(0, 0, 0, 0.3); 46 | cursor: pointer; 47 | font-size: 14px; 48 | font-weight: 500; 49 | letter-spacing: 1.25px; 50 | line-height: 36px; 51 | margin: 16px 0; 52 | text-align: center; 53 | transition: box-shadow 0.2s cubic-bezier(0.4, 0, 0.2, 1); 54 | user-select: none; 55 | width: 140px; 56 | } 57 | 58 | .button:hover { 59 | background: #f5f5f5; 60 | } 61 | 62 | .button:active { 63 | box-shadow: 0 0 rgba(0, 0, 0, 0.2), 0 0 rgba(0, 0, 0, 0.14), 64 | 0 0 rgba(0, 0, 0, 0.12), inset 0 0 2px rgba(0, 0, 0, 0.3); 65 | transition: box-shadow 0.05s cubic-bezier(0.4, 0, 0.2, 1); 66 | } 67 | 68 | .predictions { 69 | display: flex; 70 | } 71 | 72 | .prediction-col { 73 | padding: 0 2px; 74 | } 75 | 76 | .prediction-bar-container { 77 | background: #f5f5f5; 78 | height: 140px; 79 | width: 10px; 80 | position: relative; 81 | } 82 | 83 | .prediction-bar { 84 | background: #e0e0e0; 85 | bottom: 0; 86 | position: absolute; 87 | width: 100%; 88 | } 89 | 90 | .prediction-number { 91 | color: #bdbdbd; 92 | font-size: 14px; 93 | text-align: center; 94 | } 95 | 96 | .top-prediction .prediction-bar { 97 | background: #00f0ff; 98 | } 99 | 100 | .top-prediction .prediction-number { 101 | color: #00f0ff; 102 | font-weight: bold; 103 | } 104 | -------------------------------------------------------------------------------- 
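The first six lines of `Net.forward` in `inference_mnist_model.py` turn raw canvas pixels into an MNIST-style input. As a hedged illustration, here is the same pipeline written in plain Python with no PyTorch dependency. It assumes the input is the flat RGBA byte array that `script.js` reads with `ctx.getImageData(0, 0, 280, 280).data`; the names `preprocess`, `CANVAS_SIZE`, and `POOL` are hypothetical, chosen only for this sketch.

```python
# Sketch of the preprocessing in inference_mnist_model.Net.forward, in plain Python.
# Assumption: `rgba` is the flat RGBA byte list from ctx.getImageData(...).data.
# Names such as preprocess, CANVAS_SIZE, and POOL are illustrative only.

MEAN = 0.1307
STANDARD_DEVIATION = 0.3081
CANVAS_SIZE = 280  # canvas width and height in pixels
POOL = 10          # 280 / 10 = 28, the MNIST image size


def preprocess(rgba):
    """Convert a flat 280*280*4 RGBA list into a 28x28 grid of normalized floats."""
    if len(rgba) != CANVAS_SIZE * CANVAS_SIZE * 4:
        raise ValueError("expected 280 * 280 * 4 values")
    # Keep every 4th value starting at index 3: the alpha channel,
    # which plays the role of torch.narrow(x, dim=2, start=3, length=1).
    alpha = rgba[3::4]
    out_size = CANVAS_SIZE // POOL
    image = []
    for row in range(out_size):
        out_row = []
        for col in range(out_size):
            # Average one non-overlapping 10x10 block, like F.avg_pool2d(x, 10, stride=10).
            total = 0
            for dy in range(POOL):
                start = (row * POOL + dy) * CANVAS_SIZE + col * POOL
                total += sum(alpha[start:start + POOL])
            value = total / (POOL * POOL) / 255  # scale bytes to [0, 1]
            out_row.append((value - MEAN) / STANDARD_DEVIATION)
        image.append(out_row)
    return image


# Example: a transparent canvas with one fully opaque 10x10 block in the top-left corner.
pixels = [0] * (CANVAS_SIZE * CANVAS_SIZE * 4)
for y in range(POOL):
    for x in range(POOL):
        pixels[(y * CANVAS_SIZE + x) * 4 + 3] = 255  # set alpha only
grid = preprocess(pixels)
print(len(grid), len(grid[0]))  # 28 28
print(grid[0][0], grid[0][1])   # normalized "ink" value vs. normalized background value
```

Because `clearCanvas` uses `clearRect`, untouched pixels stay fully transparent, so the alpha channel works as an ink-intensity map. Doing the pooling and normalization inside the exported model, rather than in JavaScript, keeps `script.js` simple: it passes the raw canvas bytes straight to `sess.run`.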
/inference_mnist_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | MEAN = 0.1307 6 | STANDARD_DEVIATION = 0.3081 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 13 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 14 | self.dropout1 = nn.Dropout2d(0.25) 15 | self.dropout2 = nn.Dropout2d(0.5) 16 | self.fc1 = nn.Linear(9216, 128) 17 | self.fc2 = nn.Linear(128, 10) 18 | 19 | def forward(self, x): 20 | x = x.reshape(280, 280, 4)  # Flat canvas RGBA bytes -> (height, width, channel). 21 | x = torch.narrow(x, dim=2, start=3, length=1)  # Keep only the alpha channel (drawn strokes are opaque). 22 | x = x.reshape(1, 1, 280, 280)  # (batch, channel, height, width). 23 | x = F.avg_pool2d(x, 10, stride=10)  # Downsample 280x280 -> 28x28, the MNIST image size. 24 | x = x / 255  # Scale bytes to [0, 1], like transforms.ToTensor(). 25 | x = (x - MEAN) / STANDARD_DEVIATION  # Same normalization as transforms.Normalize in training. 26 | 27 | x = self.conv1(x) 28 | x = F.relu(x) 29 | x = self.conv2(x) 30 | x = F.max_pool2d(x, 2) 31 | x = self.dropout1(x) 32 | x = torch.flatten(x, 1) 33 | x = self.fc1(x) 34 | x = F.relu(x) 35 | x = self.dropout2(x) 36 | x = self.fc2(x) 37 | output = F.softmax(x, dim=1)  # Probabilities; the training Net uses log_softmax instead. 38 | return output 39 | -------------------------------------------------------------------------------- /inputs_batch_preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliotwaite/pytorch-to-javascript-with-onnx-js/9c5f85145faea4e9be7e8e0b4c2effbbe0a1177e/inputs_batch_preview.png -------------------------------------------------------------------------------- /onnx_model.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliotwaite/pytorch-to-javascript-with-onnx-js/9c5f85145faea4e9be7e8e0b4c2effbbe0a1177e/onnx_model.onnx -------------------------------------------------------------------------------- /preview_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5
| def main(): 6 | train_loader = torch.utils.data.DataLoader( 7 | torchvision.datasets.MNIST( 8 | 'data', train=True, download=True, 9 | transform=torchvision.transforms.Compose([ 10 | 11 | # torchvision.transforms.RandomAffine( 12 | # degrees=30), 13 | 14 | # torchvision.transforms.RandomAffine( 15 | # degrees=0, translate=(0.0, 0.5)), 16 | 17 | # torchvision.transforms.RandomAffine( 18 | # degrees=0, translate=(0.5, 0.5)), 19 | 20 | # torchvision.transforms.RandomAffine( 21 | # degrees=0, scale=(0.25, 1)), 22 | 23 | # torchvision.transforms.RandomAffine( 24 | # degrees=0, shear=(-30, 30, -30, 30)), 25 | 26 | torchvision.transforms.RandomAffine( 27 | degrees=30, translate=(0.5, 0.5), scale=(0.25, 1), 28 | shear=(-30, 30, -30, 30)), 29 | 30 | torchvision.transforms.ToTensor(), 31 | ])), 32 | batch_size=800) 33 | inputs_batch, labels_batch = next(iter(train_loader)) 34 | grid = torchvision.utils.make_grid(inputs_batch, nrow=40, pad_value=1) 35 | torchvision.utils.save_image(grid, 'inputs_batch_preview.png') 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /pytorch_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elliotwaite/pytorch-to-javascript-with-onnx-js/9c5f85145faea4e9be7e8e0b4c2effbbe0a1177e/pytorch_model.pt -------------------------------------------------------------------------------- /train_mnist_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from PyTorch's MNIST example (with only a few changes): 3 | https://github.com/pytorch/examples/blob/master/mnist/main.py 4 | """ 5 | from __future__ import print_function 6 | import argparse 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torchvision import datasets, transforms 12 | from torch.optim.lr_scheduler 
import StepLR 13 | 14 | 15 | class Net(nn.Module): 16 | def __init__(self): 17 | super(Net, self).__init__() 18 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 19 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 20 | self.dropout1 = nn.Dropout2d(0.25) 21 | self.dropout2 = nn.Dropout2d(0.5) 22 | self.fc1 = nn.Linear(9216, 128) 23 | self.fc2 = nn.Linear(128, 10) 24 | 25 | def forward(self, x): 26 | x = self.conv1(x) 27 | x = F.relu(x) 28 | x = self.conv2(x) 29 | x = F.max_pool2d(x, 2) 30 | x = self.dropout1(x) 31 | x = torch.flatten(x, 1) 32 | x = self.fc1(x) 33 | x = F.relu(x) 34 | x = self.dropout2(x) 35 | x = self.fc2(x) 36 | output = F.log_softmax(x, dim=1) 37 | return output 38 | 39 | 40 | def train(args, model, device, train_loader, optimizer, epoch): 41 | model.train() 42 | for batch_idx, (data, target) in enumerate(train_loader): 43 | data, target = data.to(device), target.to(device) 44 | optimizer.zero_grad() 45 | output = model(data) 46 | loss = F.nll_loss(output, target) 47 | loss.backward() 48 | optimizer.step() 49 | if batch_idx % args.log_interval == 0: 50 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 51 | epoch, batch_idx * len(data), len(train_loader.dataset), 52 | 100. * batch_idx / len(train_loader), loss.item())) 53 | 54 | 55 | def test(args, model, device, test_loader): 56 | model.eval() 57 | test_loss = 0 58 | correct = 0 59 | with torch.no_grad(): 60 | for data, target in test_loader: 61 | data, target = data.to(device), target.to(device) 62 | output = model(data) 63 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 64 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 65 | correct += pred.eq(target.view_as(pred)).sum().item() 66 | 67 | test_loss /= len(test_loader.dataset) 68 | 69 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 70 | test_loss, correct, len(test_loader.dataset), 71 | 100. 
* correct / len(test_loader.dataset))) 72 | 73 | 74 | def main(): 75 | # Training settings 76 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 77 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 78 | help='input batch size for training (default: 64)') 79 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 80 | help='input batch size for testing (default: 1000)') 81 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 82 | help='number of epochs to train (default: 14)') 83 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 84 | help='learning rate (default: 1.0)') 85 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 86 | help='Learning rate step gamma (default: 0.7)') 87 | parser.add_argument('--no-cuda', action='store_true', default=False, 88 | help='disables CUDA training') 89 | parser.add_argument('--seed', type=int, default=1, metavar='S', 90 | help='random seed (default: 1)') 91 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 92 | help='how many batches to wait before logging training status') 93 | 94 | parser.add_argument('--save-model', action='store_true', default=False, 95 | help='For Saving the current Model') 96 | args = parser.parse_args() 97 | use_cuda = not args.no_cuda and torch.cuda.is_available() 98 | 99 | torch.manual_seed(args.seed) 100 | 101 | device = torch.device("cuda" if use_cuda else "cpu") 102 | 103 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 104 | train_loader = torch.utils.data.DataLoader( 105 | datasets.MNIST('data', train=True, download=True, 106 | transform=transforms.Compose([ 107 | # Add random transformations to the image. 
108 | transforms.RandomAffine( 109 | degrees=30, translate=(0.5, 0.5), scale=(0.25, 1), 110 | shear=(-30, 30, -30, 30)), 111 | 112 | transforms.ToTensor(), 113 | transforms.Normalize((0.1307,), (0.3081,)) 114 | ])), 115 | batch_size=args.batch_size, shuffle=True, **kwargs) 116 | test_loader = torch.utils.data.DataLoader( 117 | datasets.MNIST('data', train=False, transform=transforms.Compose([ 118 | transforms.ToTensor(), 119 | transforms.Normalize((0.1307,), (0.3081,)) 120 | ])), 121 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 122 | 123 | model = Net().to(device) 124 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 125 | 126 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 127 | for epoch in range(1, args.epochs + 1): 128 | train(args, model, device, train_loader, optimizer, epoch) 129 | test(args, model, device, test_loader) 130 | scheduler.step() 131 | 132 | torch.save(model.state_dict(), "pytorch_model.pt") 133 | 134 | 135 | if __name__ == '__main__': 136 | main() --------------------------------------------------------------------------------
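`train_mnist_model.py` ends its `Net` with `F.log_softmax` (paired with `F.nll_loss`), while `inference_mnist_model.py` ends with `F.softmax`. Both classes define the same layers, so the saved `state_dict` loads into either one, and the swap only changes the output: softmax gives probabilities that sum to 1, which `script.js` can draw directly as bar heights (`predictions[i] * 100`%). The stdlib-only sketch below uses made-up logits (an assumption, not real model output) to check that exponentiating log_softmax recovers softmax, so both versions pick the same digit.

```python
import math


def softmax(logits):
    """Probabilities that sum to 1 (what inference_mnist_model.Net returns)."""
    m = max(logits)  # subtracting the max avoids overflow and does not change the result
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    """Log-probabilities (what train_mnist_model.Net returns for F.nll_loss)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


# Hypothetical fc2 outputs (logits) for the digits 0-9.
logits = [0.5, -1.2, 3.1, 0.0, 2.4, -0.7, 1.1, 0.3, -2.0, 0.9]
probs = softmax(logits)
log_probs = log_softmax(logits)

# exp(log_softmax) recovers softmax, so both outputs rank the digits identically.
assert all(abs(math.exp(lp) - p) < 1e-12 for lp, p in zip(log_probs, probs))
best = max(range(len(probs)), key=lambda i: probs[i])
print(best)        # 2
print(sum(probs))  # approximately 1.0, so each value maps to a percentage bar height
```

Keeping log_softmax for training pairs it with `F.nll_loss` in a numerically stable way, while softmax at export time saves the browser from having to exponentiate the outputs itself.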