├── LICENSE ├── README.md └── ONNX2.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 mahdieslaminet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MNIST_ONNX_Classification 2 | 3 | This repository contains a Python project for training a simple neural network on the MNIST dataset using PyTorch, exporting the model to ONNX format, and performing image classification using ONNX Runtime. The project includes steps to upload and utilize the ONNX model stored on Google Drive. 
## Requirements

Make sure to install the following libraries before running the project:

```bash
pip install torch torchvision onnx onnxruntime requests matplotlib numpy
```

Both scripts mount Google Drive through `google.colab`, so they are meant to be run in Google Colab.

## Project Structure

- `ONNX2.ipynb`: Colab notebook that runs both steps (training/export and classification).
- `train_and_save_model.py`: Script for training a simple neural network on the MNIST dataset and saving the model in ONNX format to Google Drive (listed in section 1 below).
- `classify_images.py`: Script for loading the ONNX model from Google Drive and classifying images from the MNIST dataset (listed in section 2 below).

## Usage

### 1. Training and Saving the Model

This script trains a simple neural network on the MNIST dataset and saves the trained model in ONNX format to your Google Drive.

```python
import onnx
import torch
import torch.nn as nn
import torch.optim as optim
from google.colab import drive
from torchvision import datasets, transforms

ONNX_MODEL_PATH = "/content/drive/My Drive/simple_model.onnx"


class SimpleModel(nn.Module):
    """Simple fully connected classifier for 28x28 MNIST digits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten every image to a 784-element vector
        x = torch.relu(self.fc1(x))
        return self.fc2(x)  # raw class scores; the loss applies the softmax


def train_model(model, train_loader, epochs=5, lr=0.01):
    """Train `model` with SGD and cross-entropy loss and return it.

    Prints the mean loss of each epoch (not just the loss of the last
    mini-batch, which is noisy and can rise even while training improves).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        print(f'Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1):.4f}')
    return model


def save_model_to_onnx(model, onnx_model_path=ONNX_MODEL_PATH):
    """Export `model` to ONNX at `onnx_model_path` and validate the file."""
    model.eval()  # export in inference mode
    dummy_input = torch.randn(1, 1, 28, 28)  # Example input for the model
    torch.onnx.export(
        model, dummy_input, onnx_model_path,
        input_names=['input'], output_names=['output'],
        # Keep the batch dimension dynamic so the exported model is not
        # locked to the batch size of the dummy input.
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )
    onnx.checker.check_model(onnx.load(onnx_model_path))  # fail early on a malformed export
    print(f"Model saved to {onnx_model_path}")


if __name__ == "__main__":
    drive.mount('/content/drive', force_remount=True)

    # Prepare the MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

    # Train for 5 epochs, then convert the trained model to ONNX on Google Drive
    model = train_model(SimpleModel(), train_loader, epochs=5)
    save_model_to_onnx(model)
```

### 2. Classifying Images

This script loads the ONNX model from Google Drive and performs classification on images from the MNIST dataset.

```python
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import torch
from google.colab import drive
from torchvision import datasets, transforms

ONNX_MODEL_PATH = "/content/drive/My Drive/simple_model.onnx"
NUM_IMAGES = 5


def load_model(onnx_model_path=ONNX_MODEL_PATH):
    """Load the ONNX model and return an ONNX Runtime inference session."""
    return ort.InferenceSession(onnx_model_path)


def classify_images(ort_session, images):
    """Return the predicted digit (int) for every image in `images`.

    Images are fed one at a time with shape (1, 1, 28, 28), so this also
    works with a model that was exported with a fixed batch size of 1.
    """
    input_name = ort_session.get_inputs()[0].name
    predictions = []
    for image in images:
        ort_inputs = {input_name: image.numpy().reshape(1, 1, 28, 28).astype(np.float32)}
        ort_outs = ort_session.run(None, ort_inputs)
        predictions.append(int(np.argmax(ort_outs[0])))
    return predictions


def display_results(images, labels, predictions):
    """Plot the images in one row with their true label, prediction and verdict."""
    plt.figure(figsize=(10, 5))

    for i, (image, label, pred_label) in enumerate(zip(images, labels, predictions)):
        plt.subplot(1, len(predictions), i + 1)
        plt.imshow(image.numpy().squeeze(), cmap='gray')
        plt.title(f'Label: {label}\nPred: {pred_label}')
        # Hide only the ticks: plt.axis('off') would also hide the xlabel below.
        plt.xticks([])
        plt.yticks([])

        # Show result of classification
        if label == pred_label:
            plt.xlabel('Correct', color='green')
        else:
            plt.xlabel('Wrong', color='red')

    plt.show()


if __name__ == "__main__":
    # Mount Google Drive
    drive.mount('/content/drive', force_remount=True)

    ort_session = load_model()

    # Prepare the MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)

    # Select a few random images from the dataset
    samples = list(islice(test_loader, NUM_IMAGES))
    images = torch.cat([img for img, _ in samples])
    labels = torch.cat([lbl for _, lbl in samples]).tolist()

    predictions = classify_images(ort_session, images)
    display_results(images, labels, predictions)
```

## Functions

### `train_and_save_model.py`

- `SimpleModel`: Defines a simple fully connected neural network.
- `train_model`: Trains the model on the MNIST dataset.
- `save_model_to_onnx`: Converts the trained model to ONNX format and saves it to Google Drive.

### `classify_images.py`

- `load_model`: Loads the ONNX model from Google Drive.
- `classify_images`: Classifies images from the MNIST dataset using the ONNX model.
- `display_results`: Displays images along with classification results.

## File Storage

- The trained ONNX model is saved in your Google Drive at the path: `/content/drive/My Drive/simple_model.onnx`.

## References

- [PyTorch Documentation](https://pytorch.org/docs/stable/index.html)
- [ONNX Documentation](https://onnx.ai/)
- [ONNX Runtime Documentation](https://onnxruntime.ai/)

## Similar Projects

- [PyTorch to ONNX Conversion](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html)
- [ONNX Runtime GitHub](https://github.com/microsoft/onnxruntime)

Feel free to clone this repository and use it for your own projects.
Contributions and suggestions are welcome! 171 | -------------------------------------------------------------------------------- /ONNX2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "mount_file_id": "1ZWr1G0NtcOiy9tu1y9IdjCTcnbx-aFdo", 8 | "authorship_tag": "ABX9TyMiJ6Hx7oQJsCXtDb0QFuSE", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "source": [ 33 | "pip install torch onnx onnxruntime\n" 34 | ], 35 | "metadata": { 36 | "colab": { 37 | "base_uri": "https://localhost:8080/" 38 | }, 39 | "id": "78lWrV_1fWQa", 40 | "outputId": "49e36e95-b5f4-44fc-8d42-b1e0eb66ed03" 41 | }, 42 | "execution_count": null, 43 | "outputs": [ 44 | { 45 | "output_type": "stream", 46 | "name": "stdout", 47 | "text": [ 48 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n", 49 | "Collecting onnx\n", 50 | " Downloading onnx-1.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)\n", 51 | "Collecting onnxruntime\n", 52 | " Downloading onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.3 kB)\n", 53 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n", 54 | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", 55 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", 56 | "Requirement already satisfied: 
networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", 57 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", 58 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", 59 | "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n", 60 | " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", 61 | "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n", 62 | " Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", 63 | "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n", 64 | " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", 65 | "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n", 66 | " Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", 67 | "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n", 68 | " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", 69 | "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n", 70 | " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", 71 | "Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n", 72 | " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", 73 | "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n", 74 | " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", 75 | "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n", 76 | " Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", 77 | "Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n", 78 | " Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n", 79 | "Collecting 
nvidia-nvtx-cu12==12.1.105 (from torch)\n", 80 | " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n", 81 | "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n", 82 | "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n", 83 | " Using cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 84 | "Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from onnx) (1.26.4)\n", 85 | "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n", 86 | "Collecting coloredlogs (from onnxruntime)\n", 87 | " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)\n", 88 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.3.25)\n", 89 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.1)\n", 90 | "Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n", 91 | " Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)\n", 92 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", 93 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", 94 | "Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", 95 | "Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", 96 | "Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", 97 | "Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", 98 | "Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", 99 | "Using cached 
nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", 100 | "Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", 101 | "Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", 102 | "Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", 103 | "Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n", 104 | "Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", 105 | "Downloading onnx-1.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)\n", 106 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.9/15.9 MB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 107 | "\u001b[?25hDownloading onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)\n", 108 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.8/6.8 MB\u001b[0m \u001b[31m81.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 109 | "\u001b[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n", 110 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 111 | "\u001b[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n", 112 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 113 | "\u001b[?25hUsing cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl (19.7 MB)\n", 114 | "Installing collected packages: onnx, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, humanfriendly, 
nvidia-cusparse-cu12, nvidia-cudnn-cu12, coloredlogs, onnxruntime, nvidia-cusolver-cu12\n", 115 | "Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.20 nvidia-nvtx-cu12-12.1.105 onnx-1.16.2 onnxruntime-1.18.1\n" 116 | ] 117 | } 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "source": [ 123 | "pip install torch torchvision onnx onnxruntime\n" 124 | ], 125 | "metadata": { 126 | "colab": { 127 | "base_uri": "https://localhost:8080/" 128 | }, 129 | "id": "vSulfPATfnHn", 130 | "outputId": "92a0411b-1fb2-469f-f4eb-9c82419008e8" 131 | }, 132 | "execution_count": null, 133 | "outputs": [ 134 | { 135 | "output_type": "stream", 136 | "name": "stdout", 137 | "text": [ 138 | "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n", 139 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.18.1+cu121)\n", 140 | "Requirement already satisfied: onnx in /usr/local/lib/python3.10/dist-packages (1.16.2)\n", 141 | "Requirement already satisfied: onnxruntime in /usr/local/lib/python3.10/dist-packages (1.18.1)\n", 142 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n", 143 | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", 144 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", 145 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", 146 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages 
(from torch) (3.1.4)\n", 147 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", 148 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", 149 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", 150 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", 151 | "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n", 152 | "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n", 153 | "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n", 154 | "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n", 155 | "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n", 156 | "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n", 157 | "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.20.5)\n", 158 | "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", 159 | "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n", 160 | "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.20)\n", 161 | "Requirement already 
satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n", 162 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n", 163 | "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n", 164 | "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (15.0.1)\n", 165 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.3.25)\n", 166 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.1)\n", 167 | "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->onnxruntime) (10.0)\n", 168 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", 169 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" 170 | ] 171 | } 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "colab": { 179 | "base_uri": "https://localhost:8080/" 180 | }, 181 | "id": "2Vx1rkfuZp0-", 182 | "outputId": "957db4f4-352f-49bb-8274-b13e6d0374bb" 183 | }, 184 | "outputs": [ 185 | { 186 | "output_type": "stream", 187 | "name": "stdout", 188 | "text": [ 189 | "Mounted at /content/drive\n", 190 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 191 | "Failed to download (trying next):\n", 192 | "HTTP Error 403: Forbidden\n", 193 | "\n", 194 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", 195 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" 196 | ] 197 | }, 198 | { 199 | 
"output_type": "stream", 200 | "name": "stderr", 201 | "text": [ 202 | "100%|██████████| 9912422/9912422 [00:10<00:00, 968941.41it/s] \n" 203 | ] 204 | }, 205 | { 206 | "output_type": "stream", 207 | "name": "stdout", 208 | "text": [ 209 | "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", 210 | "\n", 211 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 212 | "Failed to download (trying next):\n", 213 | "HTTP Error 403: Forbidden\n", 214 | "\n", 215 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", 216 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" 217 | ] 218 | }, 219 | { 220 | "output_type": "stream", 221 | "name": "stderr", 222 | "text": [ 223 | "100%|██████████| 28881/28881 [00:00<00:00, 152313.89it/s]\n" 224 | ] 225 | }, 226 | { 227 | "output_type": "stream", 228 | "name": "stdout", 229 | "text": [ 230 | "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", 231 | "\n", 232 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", 233 | "Failed to download (trying next):\n", 234 | "HTTP Error 403: Forbidden\n", 235 | "\n", 236 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", 237 | "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" 238 | ] 239 | }, 240 | { 241 | "output_type": "stream", 242 | "name": "stderr", 243 | "text": [ 244 | "100%|██████████| 1648877/1648877 [00:01<00:00, 1445960.74it/s]\n" 245 | ] 246 | }, 247 | { 248 | "output_type": "stream", 249 | "name": "stdout", 250 | "text": [ 251 | "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", 252 | "\n", 253 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", 254 | "Failed to download (trying 
next):\n",
            "HTTP Error 403: Forbidden\n",
            "\n",
            "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
            "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 4542/4542 [00:00<00:00, 4477210.05it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n",
            "\n",
            "Epoch 1, Loss: 0.3986712098121643\n",
            "Epoch 2, Loss: 0.30953291058540344\n",
            "Epoch 3, Loss: 0.4542768895626068\n",
            "Epoch 4, Loss: 0.19958028197288513\n",
            "Epoch 5, Loss: 0.16435885429382324\n",
            "Model saved to /content/drive/My Drive/simple_model.onnx\n"
          ]
        }
      ],
      "source": [
        "\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive', force_remount=True)\n",
        "\n",
        "\n",
        "# Train the model and convert it to ONNX\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torchvision import datasets, transforms\n",
        "import onnx\n",
        "import onnxruntime as ort\n",
        "import numpy as np\n",
        "\n",
        "# Define a simple model\n",
        "class SimpleModel(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(28 * 28, 128)\n",
        "        self.fc2 = nn.Linear(128, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x.view(-1, 28 * 28)\n",
        "        x = torch.relu(self.fc1(x))\n",
        "        x = self.fc2(x)\n",
        "        return x\n",
        "\n",
        "# Prepare the data\n",
        "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
        "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
        "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)\n",
        "\n",
        "# Build the model, define the loss function and the optimizer\n",
        "model = SimpleModel()\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.SGD(model.parameters(), lr=0.01)\n",
        "\n",
        "# Train the model\n",
        "model.train()\n",
        "for epoch in range(5):  # train for 5 epochs\n",
        "    running_loss = 0.0\n",
        "    for images, labels in train_loader:\n",
        "        outputs = model(images)\n",
        "        loss = criterion(outputs, labels)\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        running_loss += loss.item()\n",
        "\n",
        "    # Report the mean loss of the epoch, not only the last mini-batch\n",
        "    print(f'Epoch {epoch+1}, Loss: {running_loss / len(train_loader):.4f}')\n",
        "\n",
        "# Convert the model to ONNX format\n",
        "model.eval()  # export in inference mode\n",
        "dummy_input = torch.randn(1, 1, 28, 28)  # sample input for the model\n",
        "onnx_model_path = \"/content/drive/My Drive/simple_model.onnx\"\n",
        "torch.onnx.export(\n",
        "    model, dummy_input, onnx_model_path,\n",
        "    input_names=['input'], output_names=['output'],\n",
        "    # keep the batch dimension dynamic so the model is not locked to batch size 1\n",
        "    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},\n",
        ")\n",
        "onnx.checker.check_model(onnx.load(onnx_model_path))  # fail early on a malformed export\n",
        "print(f\"Model saved to {onnx_model_path}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import onnxruntime as ort\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from torchvision import datasets, transforms\n",
        "import torch\n",
        "from google.colab import drive\n",
        "\n",
        "# Connect to Google Drive\n",
        "drive.mount('/content/drive', force_remount=True)\n",
        "\n",
        "# Path of the ONNX model on Google Drive\n",
        "onnx_model_path = \"/content/drive/My Drive/simple_model.onnx\"\n",
        "\n",
        "# Load the ONNX model\n",
        "ort_session = ort.InferenceSession(onnx_model_path)\n",
        "input_name = ort_session.get_inputs()[0].name\n",
        "\n",
        "# Prepare the data (MNIST dataset)\n",
        "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
        "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
        "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)\n",
        "\n",
        "# Select a few images from the dataset\n",
        "num_images = 5\n",
        "images, labels = [], []\n",
        "for i, (img, lbl) in enumerate(test_loader):\n",
        "    if i >= num_images:\n",
        "        break\n",
        "    images.append(img)\n",
        "    labels.append(lbl)\n",
        "\n",
        "images = torch.cat(images)\n",
        "labels = torch.cat(labels)\n",
        "\n",
        "# Display the images and perform classification\n",
        "plt.figure(figsize=(10, 5))\n",
        "\n",
        "for i in range(num_images):\n",
        "    image = images[i].numpy().squeeze()\n",
        "    label = labels[i].item()\n",
        "\n",
        "    # Prepare the input for the ONNX model\n",
        "    ort_inputs = {input_name: images[i].numpy().reshape(1, 1, 28, 28).astype(np.float32)}\n",
        "    ort_outs = ort_session.run(None, ort_inputs)\n",
        "    pred_label = int(np.argmax(ort_outs[0]))\n",
        "\n",
        "    # Display the image\n",
        "    plt.subplot(1, num_images, i+1)\n",
        "    plt.imshow(image, cmap='gray')\n",
        "    plt.title(f'Label: {label}\\nPred: {pred_label}')\n",
        "    # Hide only the ticks: plt.axis('off') would also hide the xlabel below\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "\n",
        "    # Show whether the prediction is correct or wrong\n",
        "    if label == pred_label:\n",
        "        plt.xlabel('Correct', color='green')\n",
        "    else:\n",
        "        plt.xlabel('Wrong', color='red')\n",
        "\n",
        "plt.show()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 189
        },
        "id": "EfFAKGIwZuQp",
        "outputId": "893b0a05-fe42-4a70-cd3d-4113d6a8a181"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "
" 424 | ], 425 | "image/png": "\n" 426 | }, 427 | "metadata": {} 428 | } 429 | ] 430 | } 431 | ] 432 | } --------------------------------------------------------------------------------