├── Convert_PyTorch_model_to_TensorFlow.ipynb ├── LICENSE └── README.md /Convert_PyTorch_model_to_TensorFlow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Convert PyTorch model to TensorFlow.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "oYYLVKHguJcp" 22 | }, 23 | "source": [ 24 | "# Open Neural Network Exchange [ONNX]" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "0W503l0z3yhl" 31 | }, 32 | "source": [ 33 | "### Installing ONNX and other required libraries" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "metadata": { 39 | "colab": { 40 | "base_uri": "https://localhost:8080/" 41 | }, 42 | "id": "vHcTO5Jpk5Bv", 43 | "outputId": "2d14cd14-efc7-40e8-8981-ab6aab868031" 44 | }, 45 | "source": [ 46 | "!pip install onnx" 47 | ], 48 | "execution_count": null, 49 | "outputs": [ 50 | { 51 | "output_type": "stream", 52 | "text": [ 53 | "Requirement already satisfied: onnx in /usr/local/lib/python3.7/dist-packages (1.8.1)\n", 54 | "Requirement already satisfied: protobuf in /usr/local/lib/python3.7/dist-packages (from onnx) (3.12.4)\n", 55 | "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.7/dist-packages (from onnx) (3.7.4.3)\n", 56 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from onnx) (1.15.0)\n", 57 | "Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.7/dist-packages (from onnx) (1.19.5)\n", 58 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from protobuf->onnx) (54.0.0)\n" 59 | ], 60 | "name": "stdout" 61 | } 62 | ] 63 | 
}, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "colab": { 68 | "base_uri": "https://localhost:8080/" 69 | }, 70 | "id": "XltrK_M8xZAg", 71 | "outputId": "5052d02a-6d9a-458e-f3a3-2001b2e83c47" 72 | }, 73 | "source": [ 74 | "!pip install tensorflow-addons\n", 75 | "!git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow && pip install -e ." 76 | ], 77 | "execution_count": null, 78 | "outputs": [ 79 | { 80 | "output_type": "stream", 81 | "text": [ 82 | "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.7/dist-packages (0.12.1)\n", 83 | "Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow-addons) (2.7.1)\n", 84 | "fatal: destination path 'onnx-tensorflow' already exists and is not an empty directory.\n" 85 | ], 86 | "name": "stdout" 87 | } 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "id": "7V6MzGtIkxb5" 94 | }, 95 | "source": [ 96 | "### Restart Runtime before continuing" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "metadata": { 102 | "colab": { 103 | "base_uri": "https://localhost:8080/" 104 | }, 105 | "id": "3BYiz5J2lGEv", 106 | "outputId": "683abf76-c077-46fd-b9dc-fb2e550219b3" 107 | }, 108 | "source": [ 109 | "!pip install torchvision" 110 | ], 111 | "execution_count": null, 112 | "outputs": [ 113 | { 114 | "output_type": "stream", 115 | "text": [ 116 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (0.9.0+cu101)\n", 117 | "Requirement already satisfied: torch==1.8.0 in /usr/local/lib/python3.7/dist-packages (from torchvision) (1.8.0+cu101)\n", 118 | "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from torchvision) (7.0.0)\n", 119 | "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision) (1.19.5)\n", 120 | "Requirement already satisfied: typing-extensions in 
/usr/local/lib/python3.7/dist-packages (from torch==1.8.0->torchvision) (3.7.4.3)\n" 121 | ], 122 | "name": "stdout" 123 | } 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "xLEW1az535Z7" 130 | }, 131 | "source": [ 132 | "###Import required libraries and classes" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "metadata": { 138 | "id": "_EH_Sd64nE9p" 139 | }, 140 | "source": [ 141 | "import torch\n", 142 | "import torch.nn as nn\n", 143 | "import torch.nn.functional as F\n", 144 | "import torch.optim as optim\n", 145 | "from torchvision import datasets, transforms\n", 146 | "from torch.autograd import Variable\n", 147 | "\n", 148 | "import onnx\n", 149 | "from onnx_tf.backend import prepare" 150 | ], 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "dFXB_WEC357Z" 158 | }, 159 | "source": [ 160 | "### Define the model" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "eU20lt1fwBdj" 167 | }, 168 | "source": [ 169 | "class Net(nn.Module):\n", 170 | " def __init__(self):\n", 171 | " super(Net, self).__init__()\n", 172 | " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", 173 | " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", 174 | " self.conv2_drop = nn.Dropout2d()\n", 175 | " self.fc1 = nn.Linear(320, 50)\n", 176 | " self.fc2 = nn.Linear(50, 10)\n", 177 | "\n", 178 | " def forward(self, x):\n", 179 | " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", 180 | " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", 181 | " x = x.view(-1, 320)\n", 182 | " x = F.relu(self.fc1(x))\n", 183 | " x = F.dropout(x, training=self.training)\n", 184 | " x = self.fc2(x)\n", 185 | " return F.log_softmax(x, dim=1)" 186 | ], 187 | "execution_count": null, 188 | "outputs": [] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": { 193 | "id": "QnLPogE636pr" 194 | }, 195 | "source": [ 196 | "### Create 
the train and test methods" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "metadata": { 202 | "id": "IC3YcJ2DxTBu" 203 | }, 204 | "source": [ 205 | "def train(model, device, train_loader, optimizer, epoch):\n", 206 | " model.train()\n", 207 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 208 | " data, target = data.to(device), target.to(device)\n", 209 | " optimizer.zero_grad()\n", 210 | " output = model(data)\n", 211 | " loss = F.nll_loss(output, target)\n", 212 | " loss.backward()\n", 213 | " optimizer.step()\n", 214 | " if batch_idx % 1000 == 0:\n", 215 | " print('Train Epoch: {} \\tLoss: {:.6f}'.format(\n", 216 | " epoch, loss.item()))\n", 217 | "\n", 218 | "def test(model, device, test_loader):\n", 219 | " model.eval()\n", 220 | " test_loss = 0\n", 221 | " correct = 0\n", 222 | " with torch.no_grad():\n", 223 | " for data, target in test_loader:\n", 224 | " data, target = data.to(device), target.to(device)\n", 225 | " output = model(data)\n", 226 | " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", 227 | " pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 228 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 229 | "\n", 230 | " test_loss /= len(test_loader.dataset)\n", 231 | " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 232 | " test_loss, correct, len(test_loader.dataset),\n", 233 | " 100. 
* correct / len(test_loader.dataset)))\n", 234 | " \n" 235 | ], 236 | "execution_count": null, 237 | "outputs": [] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": { 242 | "id": "Pn7NuNjg37UQ" 243 | }, 244 | "source": [ 245 | "### Download the datasets, normalize them and train the model" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "metadata": { 251 | "colab": { 252 | "base_uri": "https://localhost:8080/" 253 | }, 254 | "id": "DSu-horR2kz8", 255 | "outputId": "3d7a393e-1488-4289-a7d0-e4019d1518a6" 256 | }, 257 | "source": [ 258 | "train_loader = torch.utils.data.DataLoader(\n", 259 | " datasets.MNIST('../data', train=True, download=True,\n", 260 | " transform=transforms.Compose([\n", 261 | " transforms.ToTensor(),\n", 262 | " transforms.Normalize((0.1307,), (0.3081,))\n", 263 | " ])),\n", 264 | " batch_size=64, shuffle=True)\n", 265 | "\n", 266 | "test_loader = torch.utils.data.DataLoader(\n", 267 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 268 | " transforms.ToTensor(),\n", 269 | " transforms.Normalize((0.1307,), (0.3081,))\n", 270 | " ])),\n", 271 | " batch_size=1000, shuffle=True)\n", 272 | "\n", 273 | "\n", 274 | "torch.manual_seed(1)\n", 275 | "device = torch.device(\"cuda\")\n", 276 | "\n", 277 | "model = Net().to(device)\n", 278 | "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n", 279 | " \n", 280 | "for epoch in range(21):\n", 281 | " train(model, device, train_loader, optimizer, epoch)\n", 282 | " test(model, device, test_loader)\n" 283 | ], 284 | "execution_count": null, 285 | "outputs": [ 286 | { 287 | "output_type": "stream", 288 | "text": [ 289 | "\n", 290 | "Test set: Average loss: 0.0763, Accuracy: 9753/10000 (98%)\n", 291 | "\n", 292 | "Train Epoch: 5 \tLoss: 0.140425\n", 293 | "\n", 294 | "Test set: Average loss: 0.0645, Accuracy: 9788/10000 (98%)\n", 295 | "\n", 296 | "Train Epoch: 6 \tLoss: 0.137073\n", 297 | "\n", 298 | "Test set: Average loss: 0.0611, Accuracy: 
9812/10000 (98%)\n", 299 | "\n", 300 | "Train Epoch: 7 \tLoss: 0.343908\n", 301 | "\n", 302 | "Test set: Average loss: 0.0549, Accuracy: 9820/10000 (98%)\n", 303 | "\n", 304 | "Train Epoch: 8 \tLoss: 0.117892\n", 305 | "\n", 306 | "Test set: Average loss: 0.0526, Accuracy: 9828/10000 (98%)\n", 307 | "\n", 308 | "Train Epoch: 9 \tLoss: 0.211196\n", 309 | "\n", 310 | "Test set: Average loss: 0.0542, Accuracy: 9834/10000 (98%)\n", 311 | "\n", 312 | "Train Epoch: 10 \tLoss: 0.159183\n", 313 | "\n", 314 | "Test set: Average loss: 0.0492, Accuracy: 9852/10000 (99%)\n", 315 | "\n", 316 | "Train Epoch: 11 \tLoss: 0.137455\n", 317 | "\n", 318 | "Test set: Average loss: 0.0465, Accuracy: 9857/10000 (99%)\n", 319 | "\n", 320 | "Train Epoch: 12 \tLoss: 0.183048\n", 321 | "\n", 322 | "Test set: Average loss: 0.0455, Accuracy: 9868/10000 (99%)\n", 323 | "\n", 324 | "Train Epoch: 13 \tLoss: 0.193880\n", 325 | "\n", 326 | "Test set: Average loss: 0.0448, Accuracy: 9865/10000 (99%)\n", 327 | "\n", 328 | "Train Epoch: 14 \tLoss: 0.028933\n", 329 | "\n", 330 | "Test set: Average loss: 0.0445, Accuracy: 9877/10000 (99%)\n", 331 | "\n", 332 | "Train Epoch: 15 \tLoss: 0.113090\n", 333 | "\n", 334 | "Test set: Average loss: 0.0422, Accuracy: 9871/10000 (99%)\n", 335 | "\n", 336 | "Train Epoch: 16 \tLoss: 0.091770\n", 337 | "\n", 338 | "Test set: Average loss: 0.0434, Accuracy: 9871/10000 (99%)\n", 339 | "\n", 340 | "Train Epoch: 17 \tLoss: 0.086220\n", 341 | "\n", 342 | "Test set: Average loss: 0.0401, Accuracy: 9883/10000 (99%)\n", 343 | "\n", 344 | "Train Epoch: 18 \tLoss: 0.076834\n", 345 | "\n", 346 | "Test set: Average loss: 0.0381, Accuracy: 9884/10000 (99%)\n", 347 | "\n", 348 | "Train Epoch: 19 \tLoss: 0.196543\n", 349 | "\n", 350 | "Test set: Average loss: 0.0380, Accuracy: 9877/10000 (99%)\n", 351 | "\n", 352 | "Train Epoch: 20 \tLoss: 0.082292\n", 353 | "\n", 354 | "Test set: Average loss: 0.0405, Accuracy: 9881/10000 (99%)\n", 355 | "\n", 356 | "Train Epoch: 0 \tLoss: 
2.377307\n", 357 | "\n", 358 | "Test set: Average loss: 0.1973, Accuracy: 9403/10000 (94%)\n", 359 | "\n", 360 | "Train Epoch: 1 \tLoss: 0.450073\n", 361 | "\n", 362 | "Test set: Average loss: 0.1318, Accuracy: 9581/10000 (96%)\n", 363 | "\n", 364 | "Train Epoch: 2 \tLoss: 0.514039\n", 365 | "\n", 366 | "Test set: Average loss: 0.0985, Accuracy: 9685/10000 (97%)\n", 367 | "\n", 368 | "Train Epoch: 3 \tLoss: 0.171799\n", 369 | "\n", 370 | "Test set: Average loss: 0.0833, Accuracy: 9730/10000 (97%)\n", 371 | "\n", 372 | "Train Epoch: 4 \tLoss: 0.157794\n", 373 | "\n", 374 | "Test set: Average loss: 0.0766, Accuracy: 9758/10000 (98%)\n", 375 | "\n", 376 | "Train Epoch: 5 \tLoss: 0.156471\n", 377 | "\n", 378 | "Test set: Average loss: 0.0638, Accuracy: 9793/10000 (98%)\n", 379 | "\n", 380 | "Train Epoch: 6 \tLoss: 0.143752\n", 381 | "\n", 382 | "Test set: Average loss: 0.0617, Accuracy: 9811/10000 (98%)\n", 383 | "\n", 384 | "Train Epoch: 7 \tLoss: 0.295539\n", 385 | "\n", 386 | "Test set: Average loss: 0.0545, Accuracy: 9831/10000 (98%)\n", 387 | "\n", 388 | "Train Epoch: 8 \tLoss: 0.105239\n", 389 | "\n", 390 | "Test set: Average loss: 0.0528, Accuracy: 9837/10000 (98%)\n", 391 | "\n", 392 | "Train Epoch: 9 \tLoss: 0.181777\n", 393 | "\n", 394 | "Test set: Average loss: 0.0539, Accuracy: 9832/10000 (98%)\n", 395 | "\n", 396 | "Train Epoch: 10 \tLoss: 0.133470\n", 397 | "\n", 398 | "Test set: Average loss: 0.0491, Accuracy: 9862/10000 (99%)\n", 399 | "\n", 400 | "Train Epoch: 11 \tLoss: 0.131326\n", 401 | "\n", 402 | "Test set: Average loss: 0.0466, Accuracy: 9857/10000 (99%)\n", 403 | "\n", 404 | "Train Epoch: 12 \tLoss: 0.186313\n", 405 | "\n", 406 | "Test set: Average loss: 0.0462, Accuracy: 9866/10000 (99%)\n", 407 | "\n", 408 | "Train Epoch: 13 \tLoss: 0.206141\n", 409 | "\n", 410 | "Test set: Average loss: 0.0445, Accuracy: 9863/10000 (99%)\n", 411 | "\n", 412 | "Train Epoch: 14 \tLoss: 0.034316\n", 413 | "\n", 414 | "Test set: Average loss: 0.0451, Accuracy: 
9881/10000 (99%)\n", 415 | "\n", 416 | "Train Epoch: 15 \tLoss: 0.116052\n", 417 | "\n", 418 | "Test set: Average loss: 0.0427, Accuracy: 9876/10000 (99%)\n", 419 | "\n", 420 | "Train Epoch: 16 \tLoss: 0.091477\n", 421 | "\n", 422 | "Test set: Average loss: 0.0432, Accuracy: 9872/10000 (99%)\n", 423 | "\n", 424 | "Train Epoch: 17 \tLoss: 0.073883\n", 425 | "\n", 426 | "Test set: Average loss: 0.0400, Accuracy: 9885/10000 (99%)\n", 427 | "\n", 428 | "Train Epoch: 18 \tLoss: 0.090850\n", 429 | "\n", 430 | "Test set: Average loss: 0.0379, Accuracy: 9882/10000 (99%)\n", 431 | "\n", 432 | "Train Epoch: 19 \tLoss: 0.201131\n", 433 | "\n", 434 | "Test set: Average loss: 0.0372, Accuracy: 9892/10000 (99%)\n", 435 | "\n", 436 | "Train Epoch: 20 \tLoss: 0.093583\n", 437 | "\n", 438 | "Test set: Average loss: 0.0403, Accuracy: 9885/10000 (99%)\n", 439 | "\n" 440 | ], 441 | "name": "stdout" 442 | } 443 | ] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "metadata": { 448 | "id": "vxpveW_938Og" 449 | }, 450 | "source": [ 451 | "### Save the PyTorch model" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "metadata": { 457 | "id": "tgNoaZ5zwFa2" 458 | }, 459 | "source": [ 460 | "torch.save(model.state_dict(), 'mnist.pth')" 461 | ], 462 | "execution_count": null, 463 | "outputs": [] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": { 468 | "id": "IDfDo3HZ38pK" 469 | }, 470 | "source": [ 471 | "### Load the saved PyTorch model and export it as an ONNX file" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "id": "x_gYxh35zWjr" 478 | }, 479 | "source": [ 480 | "trained_model = Net()\n", 481 | "trained_model.load_state_dict(torch.load('mnist.pth'))\n", 482 | "\n", 483 | "dummy_input = Variable(torch.randn(1, 1, 28, 28)) \n", 484 | "torch.onnx.export(trained_model, dummy_input, \"mnist.onnx\")" 485 | ], 486 | "execution_count": null, 487 | "outputs": [] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": { 492 | 
"id": "675zGIJ_5x3O" 493 | }, 494 | "source": [ 495 | "### Load the ONNX file and import it into Tensorflow" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "metadata": { 501 | "id": "KjdEU-496qUd" 502 | }, 503 | "source": [ 504 | "# Load the ONNX file\n", 505 | "model = onnx.load('mnist.onnx')\n", 506 | "\n", 507 | "# Import the ONNX model to Tensorflow\n", 508 | "tf_rep = prepare(model)" 509 | ], 510 | "execution_count": null, 511 | "outputs": [] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": { 516 | "id": "Q2Pfaw1b5yeE" 517 | }, 518 | "source": [ 519 | "### Run and test the Tensorflow model" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "metadata": { 525 | "colab": { 526 | "base_uri": "https://localhost:8080/", 527 | "height": 160 528 | }, 529 | "id": "YHlLIajoxcSn", 530 | "outputId": "cb2de210-234a-4b2e-9e41-66fff53522bb" 531 | }, 532 | "source": [ 533 | "import numpy as np\n", 534 | "from IPython.display import display\n", 535 | "from PIL import Image\n", 536 | "print('Image 1:')\n", 537 | "img = Image.open('/content/img1.png').resize((28, 28)).convert('L')\n", 538 | "display(img)\n", 539 | "output = tf_rep.run(np.asarray(img, dtype=np.float32)[np.newaxis, np.newaxis, :, :])\n", 540 | "print('The digit is classified as ', np.argmax(output))\n", 541 | "print('------------------------------------------------------------------------------')\n", 542 | "print('Image 2:')\n", 543 | "img = Image.open('/content/img2.png').resize((28, 28)).convert('L')\n", 544 | "display(img)\n", 545 | "output = tf_rep.run(np.asarray(img, dtype=np.float32)[np.newaxis, np.newaxis, :, :])\n", 546 | "print('The digit is classified as ', np.argmax(output))" 547 | ], 548 | "execution_count": null, 549 | "outputs": [ 550 | { 551 | "output_type": "stream", 552 | "text": [ 553 | "Image 1:\n" 554 | ], 555 | "name": "stdout" 556 | }, 557 | { 558 | "output_type": "display_data", 559 | "data": { 560 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAMFGlDQ1BJQ0MgUHJvZmlsZQAAeJyVlwdUk8kWx+crKYSEFoiAlNCbIL1K7x3pYCMkAUKJIRBU7OiigmsXERQVXQFRdC2ArBULFhYBe31RRGVlXSxgQeVNEkCf+/a88+ac+fLLnXvv/Gcy82UGAEV7lkCQjSoBkMPPF0YH+jATk5KZJDFAAA4osJqy2HkC76ioMPCPZegW9IbluqUk1z/7/deizOHmsQFAoiCncvLYOZCPAoBrsgXCfAAIndBuMCdfIOF3kFWFUCAARLKE02WsJeFUGVtLfWKjfSH7AUCmsljCdAAUJPmZBex0mEdBANmaz+HxIe+E7MHOYHEgiyFPysmZDVmRCtk09bs86f+RM3U8J4uVPs6ysUgL2Y+XJ8hmzfs/p+N/l5xs0Vgf+rBSM4RB0ZIxw3mrzZodKmGoHTnBT42IhKwC+RKPI/WX8L0MUVDcqH8/O88XzhlgAIACDssvFDKcS5QhyorzHmVbllAaC/3RCF5+cOwopwpnR4/mRwu4ef4xY5zBDQ4bzbmSnx0xxlVpvIBgyHCloUcLM2ITZDrR8wW8+AjICpA787JiQkf9HxVm+EaM+QhF0RLNhpDfpQkDomU+mHpO3ti4MCs2S6pBHbJXfkZskCwWS+TmJYaNaeNw/fxlGjAOlx83qhmDq8snejS2WJAdNeqPVXGzA6Nl84wdyiuIGYvtzocLTDYP2JNMVkiUTD82JMiPipVpw3EQBnyBH2ACEaypYDbIBLyO/qZ++E3WEgBYQAjSARdYjlrGIhKkLXz4jAGF4E9IXJA3HucjbeWCAmj/Mm6VPS1BmrS1QBqRBZ5BzsE1cQ/cDQ+DTy9YbXFn3GUsjqk41ivRn+hHDCIGEM3GdbCh6mxYhYD3d9u3SMIzQhfhCeEmQUy4C0JhKxeOWaKQPz6yePBUmmX0+yxekfAH5UwQDsQwLmB0dKkwum/MBzeGqh1wH9wd6ofacQauCSxxezgSb9wTjs0BWr9XKBpX8W0uf+xPou/7MY7aFcwVHEZVpI7r9x33+jGL73dzxIGfoT96YiuxI1gbdha7jJ3AmgATO401Y+3YSQmPr4Sn0pUw1lu0VFsWzMMb87Gut+6z/vy33lmjCoTS3xvkc+fmSzaE72zBPCEvPSOf6Q3fyFxmMJ9tNYlpa23jCIDk/S57fbxlSN/bCOPKN1vuGQBcSqAx/ZuNZQDA8WcA0Ie+2QzewO21DoCTnWyRsEBmwyUPAvzXUIQ7QwPoAANgCsdkCxyBG/AC/iAERIJYkARmwlnPADlQ9RywACwFxaAUrAObQQXYAXaDWnAAHAZN4AQ4Cy6Cq6AT3AT34droBS/BABgCwwiCkBAaQkc0EF3ECLFAbBFnxAPxR8KQaCQJSUHSET4iQhYgy5BSZANSgexC6pBfkePIWeQy0oXcRR4jfcgb5BOKoVRUFdVGjdHJqDPqjYaisegMNB3NRQvR5egatBytRvejjehZ9Cp6ExWjL9FBDGDyGAPTwywxZ8wXi8SSsTRMiC3CSrAyrBprwFrgb30dE2P92EeciNNxJm4J12cQHoez8Vx8Eb4ar8Br8Ub8PH4df4wP4F8JNIIWwYLgSggmJBLSCXMIxYQywl7CMcIFuKN6CUNEIpFBNCE6wb2ZRMwkzieuJm4nHiSeIXYRe4iDJBJJg2RBcidFklikfFIxaStpP+k0qZvUS/pAlifrkm3JAeRkMp9cRC4j7yOfIneTn5OH5ZTkjORc5SLlOHLz5NbK7ZFrkbsm1ys3TFGmmFDcKbGUTMpSSjmlgXKB8oDyVl5eXl/eRX6qPE9+iXy5/CH5S/KP5T9SVajmVF/qdKqIuoZaQz1DvUt9S6PRjGletGRaPm0NrY52jvaI9kGBrmClEKzAUVisUKnQqNCt8EpRTtFI0VtxpmKhYpniEcVriv1KckrGSr5KLKVFSpVKx5VuKw0q05V
tlCOVc5RXK+9Tvqz8QoWkYqzir8JRWa6yW+WcSg8doxvQfels+jL6HvoFeq8qUdVENVg1U7VU9YBqh+qAmoqavVq82ly1SrWTamIGxjBmBDOyGWsZhxm3GJ8maE/wnsCdsGpCw4TuCe/VJ6p7qXPVS9QPqt9U/6TB1PDXyNJYr9Gk8VAT1zTXnKo5R7NK84Jm/0TViW4T2RNLJh6eeE8L1TLXitaar7Vbq11rUFtHO1BboL1V+5x2vw5Dx0snU2eTzimdPl26rocuT3eT7mndP5hqTG9mNrOceZ45oKelF6Qn0tul16E3rG+iH6dfpH9Q/6EBxcDZIM1gk0GrwYChrmG44QLDesN7RnJGzkYZRluM2ozeG5sYJxivMG4yfmGibhJsUmhSb/LAlGbqaZprWm16w4xo5myWZbbdrNMcNXcwzzCvNL9mgVo4WvAstlt0TSJMcpnEn1Q96bYl1dLbssCy3vKxFcMqzKrIqsnq1WTDycmT109um/zV2sE623qP9X0bFZsQmyKbFps3tua2bNtK2xt2NLsAu8V2zXav7S3sufZV9ncc6A7hDiscWh2+ODo5Ch0bHPucDJ1SnLY53XZWdY5yXu18yYXg4uOy2OWEy0dXR9d818Ouf7lZumW57XN7McVkCnfKnik97vruLPdd7mIPpkeKx04PsaeeJ8uz2vOJl4EXx2uv13NvM+9M7/3er3ysfYQ+x3ze+7r6LvQ944f5BfqV+HX4q/jH+Vf4PwrQD0gPqA8YCHQInB94JogQFBq0Puh2sHYwO7gueCDEKWRhyPlQamhMaEXokzDzMGFYSzgaHhK+MfxBhFEEP6IpEkQGR26MfBhlEpUb9dtU4tSoqZVTn0XbRC+Ibouhx8yK2RczFOsTuzb2fpxpnCiuNV4xfnp8Xfz7BL+EDQnixMmJCxOvJmkm8ZKak0nJ8cl7kwen+U/bPK13usP04um3ZpjMmDvj8kzNmdkzT85SnMWadSSFkJKQsi/lMyuSVc0aTA1O3ZY6wPZlb2G/5HhxNnH6uO7cDdznae5pG9JepLunb0zvy/DMKMvo5/nyKnivM4Myd2S+z4rMqskayU7IPphDzknJOc5X4Wfxz8/WmT13dpfAQlAsEOe65m7OHRCGCvfmIXkz8przVeFRp11kKvpJ9LjAo6Cy4MOc+DlH5irP5c9tn2c+b9W854UBhb/Mx+ez57cu0FuwdMHjhd4Ldy1CFqUual1ssHj54t4lgUtql1KWZi39vci6aEPRu2UJy1qWay9fsrznp8Cf6osVioXFt1e4rdixEl/JW9mxym7V1lVfSzglV0qtS8tKP69mr77ys83P5T+PrElb07HWcW3VOuI6/rpb6z3X125Q3lC4oWdj+MbGTcxNJZvebZ61+XKZfdmOLZQtoi3i8rDy5q2GW9dt/VyRUXGz0qfy4Datbau2vd/O2d5d5VXVsEN7R+mOTzt5O+/sCtzVWG1cXbabuLtg97M98XvafnH+pW6v5t7SvV9q+DXi2uja83VOdXX7tPatrUfrRfV9+6fv7zzgd6C5wbJh10HGwdJD4JDo0B+/pvx663Do4dYjzkcajhod3XaMfqykEWmc1zjQlNEkbk5q7joecry1xa3l2G9Wv9Wc0DtReVLt5NpTlFPLT42cLjw9eEZwpv9s+tme1lmt988lnrtxfur5jguhFy5dDLh4rs277fQl90snLrtePn7F+UrTVcerje0O7cd+d/j9WIdjR+M1p2vNnS6dLV1Tuk51e3afve53/eKN4BtXb0bc7LoVd+vO7em3xXc4d17czb77+l7BveH7Sx4QHpQ8VHpY9kjrUfW/zP51UOwoPvnY73H7k5gn93vYPS+f5j393Lv8Ge1Z2XPd53UvbF+c6Avo6/xj2h+9LwUvh/uL/1T+c9sr01dH//L6q30gcaD3tfD1yJvVbzXe1ryzf9c6GDX4aChnaPh9yQeND7UfnT+2fUr49Hx4zmfS5/IvZl9avoZ+fTC
SMzIiYAlZ0qMABiualgbAmxoAaEnw7ADvcRQF2f1LWhDZnVFK4J9YdkeTFnhyqfECIG4JAGHwjFIFqxFkKvyUHL9jvQBqZzdeR0temp2tLBcV3mIIH0ZG3moDQGoB4ItwZGR4+8jIlz1Q7F0AzuTK7n2SQoRn/J1mEupop4Afy78B56FrHGvPucAAAAFgSURBVHicYzzHw8iAHfz/wsKnjEOOgeEuE04pBgYG8iVZULn/vr05x+bGhk3y993VZ+58EHwfA/PAnf9Q8O/7uWItXg4FJUbv5xCROwjJ1zUaXAyMjieXMupdhUrCjf23Z+5zBlbzLoN3DP//o9v5++gXAQWTTKOv9//D3QCXZMtS0lSR5Gb4+RyLTkZNFVaGPz8Z3l5hYGFBl2RgYP37eM0X5g83+J1VMCR/3z879xjL739/LFMwdH7c3vRZ0F5y01uG57vluFAD4V6PoOjEK982KrFIcMvMuPf9529EIDzOZJfI//Bjlxa77YQEadGw+UuOwSVf5gkYLnp4Y5KWXOLlv6+3ZSjxSHfBJP918Qs0bG4xZ9Ve8ub/////PywvmfkBJvnHkYFZUlhII2DD7/9IABK2jHZv37Kq2lhYiKLGLuMdZQYGhp8XbvFZC6Mnw7uIKMMEd2iV+j7dxp0dAAPb5ugh/EDeAAAAAElFTkSuQmCC\n", 561 | "text/plain": [ 562 | "" 563 | ] 564 | }, 565 | "metadata": { 566 | "tags": [] 567 | } 568 | }, 569 | { 570 | "output_type": "stream", 571 | "text": [ 572 | "The digit is classified as 2\n", 573 | "------------------------------------------------------------------------------\n", 574 | "Image 2:\n" 575 | ], 576 | "name": "stdout" 577 | }, 578 | { 579 | "output_type": "display_data", 580 | "data": { 581 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAB3klEQVR4nH3TvWsUQRQA8Pd2ZvfW291bch/JLUlMI4YrjB8oikUO0wiCWAkRtJcIgpDC1to/QUhhYyFqIYiK6U6IhUUinIrmQDxByUWW5D73dt6zyO7tHqivmcf8YJh57w0S/Du0/9hfETFKZLwBAMAHq2IJnMZ+OxBZA1n1w05z/0jZ0kaInQcPB97iuUx7+0WL93runWsWR4hqY+2z+vLOQBh2Fch81tMYAICIiHevZIWGgLpd8KZnqls7PUVEJAEA1YfNAUhnIme7J6uTjlWE1IWGnzog5u4uGaYQpowkxsH7HmbKFdugQ0ZCEVJzKML6Dfeoc+F0wRWJEhH7l3VNCISMWfTO1HYpDiAiDtdnAYQ96c0WpF46u51G4vb9SmFu+VXje+28pRlXf3EKiVobj+qtgDn8eL0opx8POI08DJiZiOnbTcte9ceQiOPkbVmcis496CePXofoCPWTx/qpNEQGBOg1V3YgJ9Lo141S3sFu4K89+Rrk77kJotpc+e1cuiWev2m0/FBWl3QeVYi4cQzExInqQsmQmji+NV6E/tOKgTJj6qiZCy8DTiOR/3rxsJvRzan5i+v9UW0xGmq19+NZbX/q9nw2l3QlRgTuDljLiXQ/MfkO8egm8QfuAiaY1DMQLQAAAABJRU5ErkJggg==\n", 582 | "text/plain": [ 583 | "" 584 | ] 585 | }, 586 | "metadata": { 587 | "tags": [] 588 | } 589 | }, 590 | { 591 | "output_type": "stream", 592 | "text": [ 593 
| "The digit is classified as 5\n" 594 | ], 595 | "name": "stdout" 596 | } 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "metadata": { 602 | "colab": { 603 | "base_uri": "https://localhost:8080/" 604 | }, 605 | "id": "5aFZScgO-MI2", 606 | "outputId": "7107da8c-197a-4662-e8fa-5793220b135a" 607 | }, 608 | "source": [ 609 | "tf_rep.export_graph('mnist.pb')" 610 | ], 611 | "execution_count": null, 612 | "outputs": [ 613 | { 614 | "output_type": "stream", 615 | "text": [ 616 | "WARNING:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.\n", 617 | "WARNING:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.\n" 618 | ], 619 | "name": "stderr" 620 | }, 621 | { 622 | "output_type": "stream", 623 | "text": [ 624 | "INFO:tensorflow:Assets written to: mnist.pb/assets\n" 625 | ], 626 | "name": "stdout" 627 | }, 628 | { 629 | "output_type": "stream", 630 | "text": [ 631 | "INFO:tensorflow:Assets written to: mnist.pb/assets\n" 632 | ], 633 | "name": "stderr" 634 | } 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "metadata": { 640 | "id": "NOw8u6fNlFdo" 641 | }, 642 | "source": [ 643 | "" 644 | ], 645 | "execution_count": null, 646 | "outputs": [] 647 | } 648 | ] 649 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Data Magic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convert PyTorch model to TensorFlow 2 | 3 | I have used ONNX [Open Neural Network Exchange] to convert the PyTorch model to TensorFlow. 4 | 5 | ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers. 6 | 7 | You can find more details about ONNX here: https://onnx.ai/ 8 | --------------------------------------------------------------------------------
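A worthwhile final step after any such conversion is a numerical parity check: feed the same input to the original PyTorch model and to the imported TensorFlow representation (`tf_rep.run` in the notebook) and confirm the outputs agree within a small tolerance, not just that the argmax matches. The sketch below illustrates the idea without requiring either framework: two independent implementations of the same tiny dense layer plus log-softmax (hypothetical stand-ins for the two framework graphs, with made-up weights `W` and `b`) share one set of parameters, and their outputs are compared the way one would compare `trained_model(x)` against `tf_rep.run(x)` for the real MNIST model.

```python
import math

# Hypothetical shared parameters -- in a real parity check these come
# from the trained model's state_dict, not hand-picked values.
W = [[0.2, -0.5], [0.1, 0.4], [-0.3, 0.7]]  # 3 classes, 2 input features
b = [0.05, -0.1, 0.2]

def log_softmax(zs):
    # Numerically stable log-softmax: subtract the log-sum-exp.
    m = max(zs)
    lse = m + math.log(sum(math.exp(z - m) for z in zs))
    return [z - lse for z in zs]

def forward_a(x):
    # "Original framework" side: comprehension-style matmul + bias.
    z = [sum(W[i][j] * x[j] for j in range(len(x))) + b[i] for i in range(len(W))]
    return log_softmax(z)

def forward_b(x):
    # "Converted framework" side: same math via a different code path.
    z = []
    for row, bias in zip(W, b):
        acc = bias
        for w, xj in zip(row, x):
            acc += w * xj
        z.append(acc)
    return log_softmax(z)

x = [1.5, -2.0]
ya, yb = forward_a(x), forward_b(x)
max_abs_diff = max(abs(a - c) for a, c in zip(ya, yb))
print('max abs diff:', max_abs_diff)  # expect ~0, well under 1e-6
print('same predicted class:', ya.index(max(ya)) == yb.index(max(yb)))
```

With real frameworks the two paths accumulate floating-point error differently, so a tolerance like `1e-5` on the element-wise difference is the usual acceptance bar rather than exact equality.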