├── LICENSE
├── README.md
└── ONNX2.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 mahdieslaminet
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MNIST_ONNX_Classification
2 |
3 | This repository contains a Python project (a Google Colab notebook, `ONNX2.ipynb`) for training a simple neural network on the MNIST dataset using PyTorch, exporting the model to ONNX format, and performing image classification using ONNX Runtime. The exported ONNX model is saved to Google Drive and then loaded back from there for inference.
4 |
5 | ## Requirements
6 |
7 | Make sure to install the following libraries before running the project:
8 |
9 | ```bash
10 | pip install torch torchvision onnx onnxruntime requests matplotlib numpy
11 | ```
12 |
13 | ## Project Structure
14 |
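15 | - `ONNX2.ipynb`: Google Colab notebook containing the whole project. Its code cells implement the two steps described below, which this README refers to as `train_and_save_model.py` (train a simple neural network on the MNIST dataset and save the model in ONNX format to Google Drive) and `classify_images.py` (load the ONNX model from Google Drive and classify images from the MNIST dataset).
16 | - The two scripts are not shipped as separate `.py` files; copy the code blocks below into files with those names if you want to run them outside the notebook.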
17 |
18 | ## Usage
19 |
20 | ### 1. Training and Saving the Model
21 |
22 | This script trains a simple neural network on the MNIST dataset and saves the trained model in ONNX format to your Google Drive.
23 |
24 | ```python
25 | from google.colab import drive
26 | drive.mount('/content/drive', force_remount=True)
27 |
28 | import torch
29 | import torch.nn as nn
30 | import torch.optim as optim
31 | from torchvision import datasets, transforms
32 | import onnx
33 |
# Define a simple neural network
class SimpleModel(nn.Module):
    """Two-layer fully connected classifier for 28x28 single-channel images.

    Architecture: flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10).
    ``forward`` returns raw logits (no softmax), which is the form that
    ``nn.CrossEntropyLoss`` consumes during training.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as fc1/fc2 so state_dict keys stay stable.
        self.fc1 = nn.Linear(28 * 28, 128)  # flattened pixels -> hidden units
        self.fc2 = nn.Linear(128, 10)  # hidden units -> one logit per class

    def forward(self, x):
        # Merge all leading dimensions so each row holds one 784-pixel image,
        # e.g. (N, 1, 28, 28) -> (N, 784).
        flat = x.view(-1, 28 * 28)
        hidden = torch.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return logits
46 |
# Prepare the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Initialize the model, loss function, and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_samples = 0
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss returns the batch mean, so weight it by the batch
        # size to get a true per-sample mean over the epoch (the final batch
        # of an epoch can be smaller than 64).
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Report the epoch-average loss; a single (last) batch is too noisy to
    # say anything about how the epoch went.
    print(f'Epoch {epoch+1}, Loss: {running_loss / num_samples:.4f}')

# Convert the trained model to ONNX format and save to Google Drive
model.eval()  # export in inference mode (no dropout/batch-norm here, but it is the safe default)
dummy_input = torch.randn(1, 1, 28, 28)  # Example input used to trace the graph
onnx_model_path = "/content/drive/My Drive/simple_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_model_path,
    input_names=['input'],
    output_names=['output'],
    # Keep the batch dimension symbolic so the exported model accepts any
    # batch size, not only the size of dummy_input.
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)
print(f"Model saved to {onnx_model_path}")
74 | ```
75 |
76 | ### 2. Classifying Images
77 |
78 | This script loads the ONNX model from Google Drive and performs classification on images from the MNIST dataset.
79 |
80 | ```python
81 | import requests
82 | import onnxruntime as ort
83 | import numpy as np
84 | import matplotlib.pyplot as plt
85 | from torchvision import datasets, transforms
86 | import torch
87 | from google.colab import drive
88 |
# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Path to the ONNX model on Google Drive
onnx_model_path = "/content/drive/My Drive/simple_model.onnx"

# Load the ONNX model. Execution providers are passed explicitly because
# onnxruntime builds with more than the default CPU provider (e.g. the GPU
# package) refuse to create a session without them.
ort_session = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
input_name = ort_session.get_inputs()[0].name  # look up once, not per image

# Prepare the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)

# Select a few images from the dataset
num_images = 5
images, labels = [], []
for i, (img, lbl) in enumerate(test_loader):
    if i >= num_images:
        break
    images.append(img)
    labels.append(lbl)

images = torch.cat(images)
labels = torch.cat(labels)

# Display images and perform classification
plt.figure(figsize=(10, 5))

for i in range(len(images)):
    image = images[i].numpy().squeeze()
    label = labels[i].item()

    # Prepare the input for the ONNX model: slicing keeps the batch axis,
    # giving a (1, 1, 28, 28) float32 array.
    ort_inputs = {input_name: images[i:i + 1].numpy().astype(np.float32)}
    ort_outs = ort_session.run(None, ort_inputs)
    pred_label = int(np.argmax(ort_outs[0]))

    # Display the image
    plt.subplot(1, num_images, i + 1)
    plt.imshow(image, cmap='gray')
    plt.title(f'Label: {label}\nPred: {pred_label}')
    # Hide only the ticks: plt.axis('off') would also hide the x-axis label,
    # so the Correct/Wrong verdict set below would never be drawn.
    plt.xticks([])
    plt.yticks([])

    # Show result of classification
    if label == pred_label:
        plt.xlabel('Correct', color='green')
    else:
        plt.xlabel('Wrong', color='red')

plt.show()
139 | ```
140 |
141 | ## Functions
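142 | Apart from the `SimpleModel` class, the names below describe logical steps of each script; in the code they are written as top-level statements rather than separately defined functions.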
143 | ### `train_and_save_model.py`
144 |
145 | - `SimpleModel`: Defines a simple fully connected neural network.
146 | - `train_model`: Trains the model on the MNIST dataset.
147 | - `save_model_to_onnx`: Converts the trained model to ONNX format and saves it to Google Drive.
148 |
149 | ### `classify_images.py`
150 |
151 | - `load_model`: Loads the ONNX model from Google Drive.
152 | - `classify_images`: Classifies images from the MNIST dataset using the ONNX model.
153 | - `display_results`: Displays images along with classification results.
154 |
155 | ## File Storage
156 |
157 | - The trained ONNX model is saved in your Google Drive at the path: `/content/drive/My Drive/simple_model.onnx`.
158 |
159 | ## References
160 |
161 | - [PyTorch Documentation](https://pytorch.org/docs/stable/index.html)
162 | - [ONNX Documentation](https://onnx.ai/)
163 | - [ONNX Runtime Documentation](https://onnxruntime.ai/)
164 |
165 | ## Similar Projects
166 |
167 | - [PyTorch to ONNX Conversion](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html)
168 | - [ONNX Runtime GitHub](https://github.com/microsoft/onnxruntime)
169 |
170 | Feel free to clone this repository and use it for your own projects. Contributions and suggestions are welcome!
171 |
--------------------------------------------------------------------------------
/ONNX2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "mount_file_id": "1ZWr1G0NtcOiy9tu1y9IdjCTcnbx-aFdo",
8 | "authorship_tag": "ABX9TyMiJ6Hx7oQJsCXtDb0QFuSE",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "view-in-github",
24 | "colab_type": "text"
25 | },
26 | "source": [
27 | "
"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "source": [
33 | "pip install torch onnx onnxruntime\n"
34 | ],
35 | "metadata": {
36 | "colab": {
37 | "base_uri": "https://localhost:8080/"
38 | },
39 | "id": "78lWrV_1fWQa",
40 | "outputId": "49e36e95-b5f4-44fc-8d42-b1e0eb66ed03"
41 | },
42 | "execution_count": null,
43 | "outputs": [
44 | {
45 | "output_type": "stream",
46 | "name": "stdout",
47 | "text": [
48 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n",
49 | "Collecting onnx\n",
50 | " Downloading onnx-1.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n",
51 | "Collecting onnxruntime\n",
52 | " Downloading onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.3 kB)\n",
53 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n",
54 | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
55 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n",
56 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
57 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
58 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n",
59 | "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
60 | " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
61 | "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
62 | " Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
63 | "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
64 | " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
65 | "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n",
66 | " Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
67 | "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
68 | " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
69 | "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
70 | " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
71 | "Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
72 | " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
73 | "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
74 | " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
75 | "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
76 | " Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
77 | "Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n",
78 | " Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n",
79 | "Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
80 | " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n",
81 | "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n",
82 | "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
83 | " Using cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
84 | "Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from onnx) (1.26.4)\n",
85 | "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n",
86 | "Collecting coloredlogs (from onnxruntime)\n",
87 | " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n",
88 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.3.25)\n",
89 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.1)\n",
90 | "Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n",
91 | " Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n",
92 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
93 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
94 | "Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
95 | "Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
96 | "Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
97 | "Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
98 | "Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
99 | "Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
100 | "Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
101 | "Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
102 | "Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
103 | "Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n",
104 | "Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
105 | "Downloading onnx-1.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)\n",
106 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.9/15.9 MB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
107 | "\u001b[?25hDownloading onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)\n",
108 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.8/6.8 MB\u001b[0m \u001b[31m81.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
109 | "\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
110 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
111 | "\u001b[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
112 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
113 | "\u001b[?25hUsing cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl (19.7 MB)\n",
114 | "Installing collected packages: onnx, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, humanfriendly, nvidia-cusparse-cu12, nvidia-cudnn-cu12, coloredlogs, onnxruntime, nvidia-cusolver-cu12\n",
115 | "Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.20 nvidia-nvtx-cu12-12.1.105 onnx-1.16.2 onnxruntime-1.18.1\n"
116 | ]
117 | }
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "source": [
123 | "pip install torch torchvision onnx onnxruntime\n"
124 | ],
125 | "metadata": {
126 | "colab": {
127 | "base_uri": "https://localhost:8080/"
128 | },
129 | "id": "vSulfPATfnHn",
130 | "outputId": "92a0411b-1fb2-469f-f4eb-9c82419008e8"
131 | },
132 | "execution_count": null,
133 | "outputs": [
134 | {
135 | "output_type": "stream",
136 | "name": "stdout",
137 | "text": [
138 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n",
139 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.18.1+cu121)\n",
140 | "Requirement already satisfied: onnx in /usr/local/lib/python3.10/dist-packages (1.16.2)\n",
141 | "Requirement already satisfied: onnxruntime in /usr/local/lib/python3.10/dist-packages (1.18.1)\n",
142 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n",
143 | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
144 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n",
145 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
146 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
147 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n",
148 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
149 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
150 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
151 | "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n",
152 | "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n",
153 | "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n",
154 | "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n",
155 | "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n",
156 | "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n",
157 | "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.20.5)\n",
158 | "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
159 | "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n",
160 | "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.20)\n",
161 | "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n",
162 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n",
163 | "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n",
164 | "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (15.0.1)\n",
165 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.3.25)\n",
166 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.1)\n",
167 | "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->onnxruntime) (10.0)\n",
168 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
169 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n"
170 | ]
171 | }
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {
178 | "colab": {
179 | "base_uri": "https://localhost:8080/"
180 | },
181 | "id": "2Vx1rkfuZp0-",
182 | "outputId": "957db4f4-352f-49bb-8274-b13e6d0374bb"
183 | },
184 | "outputs": [
185 | {
186 | "output_type": "stream",
187 | "name": "stdout",
188 | "text": [
189 | "Mounted at /content/drive\n",
190 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
191 | "Failed to download (trying next):\n",
192 | "HTTP Error 403: Forbidden\n",
193 | "\n",
194 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
195 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n"
196 | ]
197 | },
198 | {
199 | "output_type": "stream",
200 | "name": "stderr",
201 | "text": [
202 | "100%|██████████| 9912422/9912422 [00:10<00:00, 968941.41it/s] \n"
203 | ]
204 | },
205 | {
206 | "output_type": "stream",
207 | "name": "stdout",
208 | "text": [
209 | "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
210 | "\n",
211 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
212 | "Failed to download (trying next):\n",
213 | "HTTP Error 403: Forbidden\n",
214 | "\n",
215 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
216 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
217 | ]
218 | },
219 | {
220 | "output_type": "stream",
221 | "name": "stderr",
222 | "text": [
223 | "100%|██████████| 28881/28881 [00:00<00:00, 152313.89it/s]\n"
224 | ]
225 | },
226 | {
227 | "output_type": "stream",
228 | "name": "stdout",
229 | "text": [
230 | "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
231 | "\n",
232 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
233 | "Failed to download (trying next):\n",
234 | "HTTP Error 403: Forbidden\n",
235 | "\n",
236 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
237 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
238 | ]
239 | },
240 | {
241 | "output_type": "stream",
242 | "name": "stderr",
243 | "text": [
244 | "100%|██████████| 1648877/1648877 [00:01<00:00, 1445960.74it/s]\n"
245 | ]
246 | },
247 | {
248 | "output_type": "stream",
249 | "name": "stdout",
250 | "text": [
251 | "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n",
252 | "\n",
253 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
254 | "Failed to download (trying next):\n",
255 | "HTTP Error 403: Forbidden\n",
256 | "\n",
257 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
258 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
259 | ]
260 | },
261 | {
262 | "output_type": "stream",
263 | "name": "stderr",
264 | "text": [
265 | "100%|██████████| 4542/4542 [00:00<00:00, 4477210.05it/s]\n"
266 | ]
267 | },
268 | {
269 | "output_type": "stream",
270 | "name": "stdout",
271 | "text": [
272 | "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
273 | "\n",
274 | "Epoch 1, Loss: 0.3986712098121643\n",
275 | "Epoch 2, Loss: 0.30953291058540344\n",
276 | "Epoch 3, Loss: 0.4542768895626068\n",
277 | "Epoch 4, Loss: 0.19958028197288513\n",
278 | "Epoch 5, Loss: 0.16435885429382324\n",
279 | "Model saved to /content/drive/My Drive/simple_model.onnx\n"
280 | ]
281 | }
282 | ],
283 | "source": [
284 | "\n",
285 | "from google.colab import drive\n",
286 | "drive.mount('/content/drive', force_remount=True)\n",
287 | "\n",
288 | "\n",
289 | "# آموزش و تبدیل مدل به ONNX\n",
290 | "import torch\n",
291 | "import torch.nn as nn\n",
292 | "import torch.optim as optim\n",
293 | "from torchvision import datasets, transforms\n",
294 | "import onnx\n",
295 | "import onnxruntime as ort\n",
296 | "import numpy as np\n",
297 | "\n",
298 | "# تعریف مدل ساده\n",
299 | "class SimpleModel(nn.Module):\n",
300 | " def __init__(self):\n",
301 | " super(SimpleModel, self).__init__()\n",
302 | " self.fc1 = nn.Linear(28 * 28, 128)\n",
303 | " self.fc2 = nn.Linear(128, 10)\n",
304 | "\n",
305 | " def forward(self, x):\n",
306 | " x = x.view(-1, 28 * 28)\n",
307 | " x = torch.relu(self.fc1(x))\n",
308 | " x = self.fc2(x)\n",
309 | " return x\n",
310 | "\n",
311 | "# آمادهسازی دادهها\n",
312 | "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
313 | "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
314 | "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)\n",
315 | "\n",
316 | "# ساخت مدل، تعریف تابع هزینه و بهینهساز\n",
317 | "model = SimpleModel()\n",
318 | "criterion = nn.CrossEntropyLoss()\n",
319 | "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
320 | "\n",
321 | "# آموزش مدل\n",
322 | "for epoch in range(5): # آموزش برای 5 دور\n",
323 | " for images, labels in train_loader:\n",
324 | " outputs = model(images)\n",
325 | " loss = criterion(outputs, labels)\n",
326 | "\n",
327 | " optimizer.zero_grad()\n",
328 | " loss.backward()\n",
329 | " optimizer.step()\n",
330 | "\n",
331 | " print(f'Epoch {epoch+1}, Loss: {loss.item()}')\n",
332 | "\n",
333 | "# تبدیل مدل به فرمت ONNX\n",
334 | "dummy_input = torch.randn(1, 1, 28, 28) # ورودی نمونه برای مدل\n",
335 | "onnx_model_path = \"/content/drive/My Drive/simple_model.onnx\"\n",
336 | "torch.onnx.export(model, dummy_input, onnx_model_path, input_names=['input'], output_names=['output'])\n",
337 | "print(f\"Model saved to {onnx_model_path}\")\n"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "source": [
343 | "import requests\n",
344 | "import onnxruntime as ort\n",
345 | "import numpy as np\n",
346 | "import matplotlib.pyplot as plt\n",
347 | "from torchvision import datasets, transforms\n",
348 | "import torch\n",
349 | "from google.colab import drive\n",
350 | "\n",
351 | "# اتصال به گوگل درایو\n",
352 | "drive.mount('/content/drive', force_remount=True)\n",
353 | "\n",
354 | "# مسیر مدل ONNX در گوگل درایو\n",
355 | "onnx_model_path = \"/content/drive/My Drive/simple_model.onnx\"\n",
356 | "\n",
357 | "# بارگذاری مدل ONNX\n",
358 | "ort_session = ort.InferenceSession(onnx_model_path)\n",
359 | "\n",
360 | "# آمادهسازی دادهها (دیتاست MNIST)\n",
361 | "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
362 | "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
363 | "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)\n",
364 | "\n",
365 | "# انتخاب چند تصویر از دیتاست\n",
366 | "images, labels = [], []\n",
367 | "for i, (img, lbl) in enumerate(test_loader):\n",
368 | " if i >= 5:\n",
369 | " break\n",
370 | " images.append(img)\n",
371 | " labels.append(lbl)\n",
372 | "\n",
373 | "images = torch.cat(images)\n",
374 | "labels = torch.cat(labels)\n",
375 | "\n",
376 | "# نمایش تصاویر و انجام کلاسبندی\n",
377 | "plt.figure(figsize=(10, 5))\n",
378 | "\n",
379 | "for i in range(5):\n",
380 | " image = images[i].numpy().squeeze()\n",
381 | " label = labels[i].item()\n",
382 | "\n",
383 | " # آمادهسازی ورودی برای مدل ONNX\n",
384 | " ort_inputs = {ort_session.get_inputs()[0].name: images[i].numpy().reshape(1, 1, 28, 28).astype(np.float32)}\n",
385 | " ort_outs = ort_session.run(None, ort_inputs)\n",
386 | " pred_label = np.argmax(ort_outs[0])\n",
387 | "\n",
388 | " # نمایش تصویر\n",
389 | " plt.subplot(1, 5, i+1)\n",
390 | " plt.imshow(image, cmap='gray')\n",
391 | " plt.title(f'Label: {label}\\nPred: {pred_label}')\n",
392 | " plt.axis('off')\n",
393 | "\n",
394 | " # نمایش نتیجه درستی یا نادرستی\n",
395 | " if label == pred_label:\n",
396 | " plt.xlabel('Correct', color='green')\n",
397 | " else:\n",
398 | " plt.xlabel('Wrong', color='red')\n",
399 | "\n",
400 | "plt.show()\n"
401 | ],
402 | "metadata": {
403 | "colab": {
404 | "base_uri": "https://localhost:8080/",
405 | "height": 189
406 | },
407 | "id": "EfFAKGIwZuQp",
408 | "outputId": "893b0a05-fe42-4a70-cd3d-4113d6a8a181"
409 | },
410 | "execution_count": null,
411 | "outputs": [
412 | {
413 | "output_type": "stream",
414 | "name": "stdout",
415 | "text": [
416 | "Mounted at /content/drive\n"
417 | ]
418 | },
419 | {
420 | "output_type": "display_data",
421 | "data": {
422 | "text/plain": [
423 | ""
424 | ],
425 | "image/png": "\n"
426 | },
427 | "metadata": {}
428 | }
429 | ]
430 | }
431 | ]
432 | }
--------------------------------------------------------------------------------