├── .gitignore
├── LICENSE
├── README.md
├── convert_to_onnx.py
├── debug_demo
│   ├── debug.html
│   └── onnx_model.onnx
├── full_demo
│   ├── index.html
│   ├── onnx_model.onnx
│   ├── script.js
│   └── style.css
├── inference_mnist_model.py
├── inputs_batch_preview.png
├── onnx_model.onnx
├── preview_dataset.py
├── pytorch_model.pt
└── train_mnist_model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Elliot Waite
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Run PyTorch models in the browser using ONNX.js
2 |
3 | Run PyTorch models in the browser with JavaScript by first converting your PyTorch model into the ONNX format and then loading that ONNX model in your website or app using ONNX.js. In the video tutorial below, I take you through this process using the demo example of a handwritten digit recognition model trained on the MNIST dataset.
4 |
5 | ### Tutorial
6 | https://www.youtube.com/watch?v=Vs730jsRgO8
7 |
8 | [](https://www.youtube.com/watch?v=Vs730jsRgO8)
9 |
10 | ### Live Demo and Code Sandbox
11 |
12 | * [Live demo](https://vgzep.csb.app/)
13 |
14 | * [Code sandbox](https://codesandbox.io/s/pytorch-to-javascript-with-onnx-vgzep)
15 |
16 | Note: The model used in this demo is not very accurate; it will often
17 | [misclassify
18 | digits](https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js/issues/1).
19 | It's only meant to be used as a proof of concept. It's the same model that was
20 | used in [PyTorch's MNIST
21 | example](https://github.com/pytorch/examples/blob/main/mnist/main.py).
22 | You can find more accurate image classification models here: [Papers With Code -
23 | Image Classification](https://paperswithcode.com/task/image-classification)
24 |
25 | ### The files in this repo (and a description of what they do)
26 | ```
27 | ├── debug_demo
28 | │ ├── debug.html (A debug test to make sure the generated ONNX model works.
29 | │ │ Uses ONNX.js to load and run the generated ONNX model.)
30 | │ │
31 | │ └── onnx_model.onnx (A copy of the generated ONNX model that will be loaded
32 | │ for debugging.)
33 | │
34 | ├── full_demo
35 | │ ├── index.html (The full demo's HTML code.)
36 | │ │
37 | │ ├── onnx_model.onnx (A copy of the generated ONNX model. Used by script.js.)
38 | │ │
39 | │   ├── script.js (The full demo's JS code. Loads the onnx_model.onnx and
40 | │ │ predicts the drawn numbers.)
41 | │ │
42 | │ └── style.css (The full demo's CSS.)
43 | │
44 | ├── convert_to_onnx.py (Converts a trained PyTorch model into an ONNX model.)
45 | │
46 | ├── inference_mnist_model.py (The PyTorch model description. Used by
47 | │ convert_to_onnx.py to generate the ONNX model.)
48 | │
49 | ├── inputs_batch_preview.png (A preview of a batch of augmented input data.
50 | │                            Generated by preview_dataset.py.)
51 | │
52 | ├── onnx_model.onnx (The ONNX model generated by convert_to_onnx.py.)
53 | │
54 | ├── preview_dataset.py (For testing out different types of data augmentation.)
55 | │
56 | ├── pytorch_model.pt (The trained PyTorch model parameters. Generated by
57 | │                        train_mnist_model.py and used by convert_to_onnx.py to
58 | │ generate the ONNX model.)
59 | │
60 | └── train_mnist_model.py (Trains the PyTorch model and saves the trained
61 | parameters as pytorch_model.pt.)
62 | ```
63 |
64 | ### The benefits of running a model in the browser:
65 | * Faster inference times with smaller models.
66 | * Easy to host and scale (only static files).
67 | * Offline support.
68 | * User privacy (can keep the data on the device).
69 |
70 | ### The benefits of using a backend server:
71 | * Faster load times (don't have to download the model).
72 | * Faster and more consistent inference times with larger models (can take advantage of GPUs or other accelerators).
73 | * Model privacy (don't have to share your model if you want to keep it private).
74 |
75 | ## License
76 |
77 | [MIT](LICENSE)
78 |
--------------------------------------------------------------------------------
/convert_to_onnx.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from inference_mnist_model import Net
4 |
5 |
6 | def main():
7 | pytorch_model = Net()
8 | pytorch_model.load_state_dict(torch.load('pytorch_model.pt'))
9 | pytorch_model.eval()
10 | dummy_input = torch.zeros(280 * 280 * 4)
11 | torch.onnx.export(pytorch_model, dummy_input, 'onnx_model.onnx', verbose=True)
12 |
13 |
14 | if __name__ == '__main__':
15 | main()
16 |
--------------------------------------------------------------------------------
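Note the dummy input above: a flat vector of 280 × 280 × 4 values, i.e. the raw RGBA byte buffer of the demo's 280 × 280 drawing canvas, so the browser can hand the canvas pixels to the model directly. The preprocessing that reduces this buffer to the 28 × 28 grayscale input an MNIST classifier expects is baked into the model graph (in inference_mnist_model.py, not shown in this excerpt). As a rough NumPy sketch of one plausible reduction, assuming the pen strokes live in the alpha channel:

```python
import numpy as np

# A flat RGBA buffer, as read from a 280x280 <canvas> (4 values per pixel).
flat = np.zeros(280 * 280 * 4, dtype=np.float32)

# Reshape to (height, width, channels) and keep only the alpha channel,
# which carries the pen strokes on a transparent canvas.
rgba = flat.reshape(280, 280, 4)
alpha = rgba[:, :, 3]

# Downsample 280x280 -> 28x28 by averaging each 10x10 block of pixels.
small = alpha.reshape(28, 10, 28, 10).mean(axis=(1, 3))

print(small.shape)  # (28, 28)
```

Averaging 10 × 10 blocks is just one reasonable downsampling scheme; the repo's model may do this differently inside the graph.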
/debug_demo/debug.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
15 |
16 | The output of this debug demo is logged to the JavaScript
17 | console. To view the output, open your browser's developer
18 | tools window, and look under the "Console" tab.
19 |