├── .gitignore ├── LICENSE ├── README.md ├── cuda2cpu └── README.md ├── linear2conv └── spatial_model.ipynb ├── model_graph ├── README.md ├── model_def.py └── visualize.py ├── profiler ├── README.md └── profile.py ├── pytorch2thnets └── thexport.py └── torch2pytorch └── th2pyth.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .DS_Store 3 | *.t7 4 | *.sw* 5 | *.net 6 | *.*~ 7 | *.pth 8 | *.pth.tar 9 | *.gv 10 | *.pdf 11 | *.svg 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 e-Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Toolbox 2 | 3 | This repository contains several tools useful for pytorch users. 4 | 5 | + [Cuda to cpu](cuda2cpu): get rid of dataparallel while loading cuda model into cpu 6 | + [Pytorch to thnets](pytorch2thnets): convert pytorch model into thnets representation, tested to work with pytorch version 0.2.0+08b4770 7 | + [Spatial model](linear2conv): convert `Linear` layer into `Conv2D` 8 | + [Torch to pytorch](torch2pytorch): convert torch model into pytorch model y 9 | + [Visualize model](model_graph): generate model graph 10 | -------------------------------------------------------------------------------- /cuda2cpu/README.md: -------------------------------------------------------------------------------- 1 | # Load and convert to GPU model to CPU: 2 | 3 | ``` 4 | # load model and convert model to cpu: 5 | import torch 6 | from torchvision import models 7 | import torch.nn.parallel 8 | model = models.AlexNet() 9 | checkpoint = torch.load('model_best.pth.tar') 10 | state_dict = checkpoint['state_dict'] 11 | print('loaded state dict:', state_dict.keys()) 12 | 13 | print('\nIn state dict keys there is an extra word inserted by model parallel: "module.". 
We remove it here:') 14 | from collections import OrderedDict 15 | new_state_dict = OrderedDict() 16 | 17 | for k, v in state_dict.items(): 18 | name = k[0:9] + k[16:] # remove `module.` 19 | if k[0] == 'f': 20 | new_state_dict[name] = v 21 | else: 22 | new_state_dict[k] = v 23 | 24 | model.load_state_dict(new_state_dict) 25 | model.cpu() 26 | 27 | print('Now see converted state dict:') 28 | print(new_state_dict.keys()) 29 | 30 | # saving model: 31 | model_dict={} 32 | model_dict['model_def']=model 33 | model_dict['weights']=model.state_dict() 34 | torch.save(model_dict, 'model_cpu.pth') 35 | ``` 36 | 37 | See: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3 38 | 39 | And: https://discuss.pytorch.org/t/loading-weights-for-cpu-model-while-trained-on-gpu/1032 40 | 41 | 42 | # example load 43 | ``` 44 | model_dict = torch.load('model_cpu.pth') 45 | model = model_dict['model_def'] 46 | model.load_state_dict( model_dict['weights'] ) 47 | ``` 48 | -------------------------------------------------------------------------------- /linear2conv/spatial_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Take in a model with linear layer which say works for image of resolution 3x224x224\n", 12 | "# Convert it into a spatial model which can now work on any image size\n", 13 | "\n", 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import torchvision.models as models\n", 17 | "from torch.autograd import Variable" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "('features', Sequential (\n", 30 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), 
padding=(2, 2))\n", 31 | " (1): ReLU (inplace)\n", 32 | " (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))\n", 33 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 34 | " (4): ReLU (inplace)\n", 35 | " (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))\n", 36 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 37 | " (7): ReLU (inplace)\n", 38 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 39 | " (9): ReLU (inplace)\n", 40 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 41 | " (11): ReLU (inplace)\n", 42 | " (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))\n", 43 | "))\n", 44 | "('classifier', Sequential (\n", 45 | " (0): Dropout (p = 0.5)\n", 46 | " (1): Linear (9216 -> 4096)\n", 47 | " (2): ReLU (inplace)\n", 48 | " (3): Dropout (p = 0.5)\n", 49 | " (4): Linear (4096 -> 4096)\n", 50 | " (5): ReLU (inplace)\n", 51 | " (6): Linear (4096 -> 1000)\n", 52 | "))\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "m1 = models.alexnet(pretrained=True) # pretrained alexnet model\n", 58 | "m1.eval()\n", 59 | "for (name, layer) in m1._modules.items():\n", 60 | " #iteration over outer layers\n", 61 | " print((name, layer))" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": { 68 | "collapsed": true 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "class ModelDef(nn.Module):\n", 73 | "\n", 74 | " def __init__(self, num_classes=1000):\n", 75 | " super(ModelDef, self).__init__()\n", 76 | " self.features = nn.Sequential(\n", 77 | " nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n", 78 | " nn.ReLU(inplace=True), \n", 79 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 80 | " nn.Conv2d(64, 192, kernel_size=5, padding=2),\n", 81 | " nn.ReLU(inplace=True),\n", 82 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 83 | " nn.Conv2d(192, 384, kernel_size=3, padding=1),\n", 84 | " 
nn.ReLU(inplace=True),\n", 85 | " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", 86 | " nn.ReLU(inplace=True),\n", 87 | " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", 88 | " nn.ReLU(inplace=True),\n", 89 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 90 | " )\n", 91 | " self.classifier = nn.Sequential(\n", 92 | " nn.Dropout(),\n", 93 | " nn.Conv2d(256, 4096, kernel_size=6),\n", 94 | " nn.ReLU(inplace=True),\n", 95 | " nn.Dropout(),\n", 96 | " nn.Conv2d(4096, 4096, kernel_size=1),\n", 97 | " nn.ReLU(inplace=True),\n", 98 | " nn.Conv2d(4096, num_classes, kernel_size=1)\n", 99 | " ) \n", 100 | " \n", 101 | " def forward(self, x):\n", 102 | " x = self.features(x)\n", 103 | " x = self.classifier(x)\n", 104 | " return x" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "('features', Sequential (\n", 117 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", 118 | " (1): ReLU (inplace)\n", 119 | " (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))\n", 120 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 121 | " (4): ReLU (inplace)\n", 122 | " (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))\n", 123 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 124 | " (7): ReLU (inplace)\n", 125 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 126 | " (9): ReLU (inplace)\n", 127 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 128 | " (11): ReLU (inplace)\n", 129 | " (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))\n", 130 | "))\n", 131 | "('classifier', Sequential (\n", 132 | " (0): Dropout (p = 0.5)\n", 133 | " (1): Conv2d(256, 4096, kernel_size=(6, 6), stride=(1, 1))\n", 134 | " (2): ReLU (inplace)\n", 135 | " (3): Dropout (p = 
0.5)\n", 136 | " (4): Conv2d(4096, 4096, kernel_size=(1, 1), stride=(1, 1))\n", 137 | " (5): ReLU (inplace)\n", 138 | " (6): Conv2d(4096, 1000, kernel_size=(1, 1), stride=(1, 1))\n", 139 | "))\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "m2 = ModelDef()\n", 145 | "m2.eval()\n", 146 | "for (name, layer) in m2._modules.items():\n", 147 | " #iteration over outer layers\n", 148 | " print((name, layer))" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "Variable containing:\n", 161 | "1.00000e-02 *\n", 162 | " -0.7129 -0.3940 0.5958 -0.7427 -0.9551 -1.0343\n", 163 | " 1.0163 1.0155 -0.0219 -0.8092 0.4064 -0.3752\n", 164 | " 0.9636 0.7135 -0.9375 0.8450 0.1269 0.6673\n", 165 | " 0.8593 1.0271 0.6331 1.0400 0.4612 -0.1184\n", 166 | " -0.5954 -0.3229 -0.4763 -0.1278 0.4813 -0.6584\n", 167 | " 0.1439 -1.0080 -0.8007 0.0180 -0.1923 -0.1401\n", 168 | "[torch.FloatTensor of size 6x6]\n", 169 | "\n", 170 | "Variable containing:\n", 171 | "1.00000e-02 *\n", 172 | " -0.9219\n", 173 | " -0.0723\n", 174 | " -1.0670\n", 175 | " -2.1309\n", 176 | " -0.8962\n", 177 | " -2.0367\n", 178 | " -1.0963\n", 179 | " -1.1400\n", 180 | " -1.0789\n", 181 | " -0.9955\n", 182 | " -0.5506\n", 183 | " -0.5041\n", 184 | " -2.1386\n", 185 | " -0.7569\n", 186 | " -2.2493\n", 187 | " -0.7488\n", 188 | " -1.8975\n", 189 | " -1.0458\n", 190 | " -1.7070\n", 191 | " -0.2994\n", 192 | " -0.1663\n", 193 | " -0.0909\n", 194 | " -0.8320\n", 195 | " -0.3444\n", 196 | " -1.1933\n", 197 | " 0.3273\n", 198 | " 0.2165\n", 199 | " 0.2178\n", 200 | " -0.1997\n", 201 | " -0.5213\n", 202 | " -0.4225\n", 203 | " -0.1865\n", 204 | " -0.8689\n", 205 | " -0.6569\n", 206 | " -0.3979\n", 207 | " -0.3722\n", 208 | "[torch.FloatTensor of size 36]\n", 209 | "\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "print(m2._modules['classifier'][1].weight[1][1])\n", 215 
| "print(m1._modules['classifier'][1].weight[1][36:72]) # Weights of the models are different at this point" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 6, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "Variable containing:\n", 228 | "-8.7606e+28\n", 229 | "-2.1060e+27\n", 230 | "-1.2613e+28\n", 231 | "-1.2498e+29\n", 232 | "-4.9750e+28\n", 233 | "[torch.FloatTensor of size 5]\n", 234 | "\n", 235 | "Variable containing:\n", 236 | "-2.6316e+26\n", 237 | "-2.3292e+26\n", 238 | "-1.7740e+26\n", 239 | " 3.5361e+26\n", 240 | "-2.8206e+25\n", 241 | "[torch.FloatTensor of size 5]\n", 242 | "\n" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "x = Variable(torch.FloatTensor(1, 3, 224, 224))\n", 248 | "y1 = m1(x)\n", 249 | "y2 = m2(x)\n", 250 | "print(y1[0, :5])\n", 251 | "print(y2[0, :5, 0, 0]) # Different output values; as expected!!!" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 7, 257 | "metadata": { 258 | "collapsed": true 259 | }, 260 | "outputs": [], 261 | "source": [ 262 | "for i, j in zip(m1.modules(), m2.modules()):\n", 263 | " if not list(i.children()):\n", 264 | " if isinstance(i, nn.Linear): # copy weights of linear layer into conv2d\n", 265 | " j.weight.data = i.weight.data.view(j.weight.size())\n", 266 | " j.bias.data = i.bias.data\n", 267 | " else:\n", 268 | " if len(i.state_dict()) > 0: # relu and dropout do not have anything in their state_dict\n", 269 | " j.weight.data = i.weight.data\n", 270 | " j.bias.data = i.bias.data" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 8, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "Variable containing:\n", 283 | "1.00000e-02 *\n", 284 | " -0.3969 -0.8168 -0.3132 0.0382 -1.2532 -0.7787\n", 285 | " -0.5779 -0.6933 0.1444 0.3889 -0.5300 0.1078\n", 286 | " -0.0160 -0.5410 
0.3189 -0.1015 -0.3006 -0.1682\n", 287 | " -0.8105 0.4949 -0.0498 0.6025 -0.7505 -0.4757\n", 288 | " -0.8852 -0.7535 -0.7075 0.5752 0.2680 -1.7264\n", 289 | " 0.3389 -0.7997 0.4491 1.4019 0.5940 0.2137\n", 290 | "[torch.FloatTensor of size 6x6]\n", 291 | "\n", 292 | "Variable containing:\n", 293 | "1.00000e-02 *\n", 294 | " -0.3969\n", 295 | " -0.8168\n", 296 | " -0.3132\n", 297 | " 0.0382\n", 298 | " -1.2532\n", 299 | " -0.7787\n", 300 | " -0.5779\n", 301 | " -0.6933\n", 302 | " 0.1444\n", 303 | " 0.3889\n", 304 | " -0.5300\n", 305 | " 0.1078\n", 306 | " -0.0160\n", 307 | " -0.5410\n", 308 | " 0.3189\n", 309 | " -0.1015\n", 310 | " -0.3006\n", 311 | " -0.1682\n", 312 | " -0.8105\n", 313 | " 0.4949\n", 314 | " -0.0498\n", 315 | " 0.6025\n", 316 | " -0.7505\n", 317 | " -0.4757\n", 318 | " -0.8852\n", 319 | " -0.7535\n", 320 | " -0.7075\n", 321 | " 0.5752\n", 322 | " 0.2680\n", 323 | " -1.7264\n", 324 | " 0.3389\n", 325 | " -0.7997\n", 326 | " 0.4491\n", 327 | " 1.4019\n", 328 | " 0.5940\n", 329 | " 0.2137\n", 330 | "[torch.FloatTensor of size 36]\n", 331 | "\n" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "print(m2._modules['classifier'][1].weight[0][1])\n", 337 | "print(m1._modules['classifier'][1].weight[0][36:72]) # Weights of both the models are now exactly the same" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 9, 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "Variable containing:\n", 350 | "-8.7606e+28\n", 351 | "-2.1060e+27\n", 352 | "-1.2613e+28\n", 353 | "-1.2498e+29\n", 354 | "-4.9750e+28\n", 355 | "[torch.FloatTensor of size 5]\n", 356 | "\n", 357 | "Variable containing:\n", 358 | "-8.7606e+28\n", 359 | "-2.1060e+27\n", 360 | "-1.2613e+28\n", 361 | "-1.2498e+29\n", 362 | "-4.9750e+28\n", 363 | "[torch.FloatTensor of size 5]\n", 364 | "\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "y1 = m1(x)\n", 370 | "y2 = m2(x)\n", 371 | 
"print(y1[0][:5])\n", 372 | "print(y2[0, :5, 0, 0]) # Same output values as expected!!!" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 10, 378 | "metadata": {}, 379 | "outputs": [ 380 | { 381 | "name": "stdout", 382 | "output_type": "stream", 383 | "text": [ 384 | "torch.Size([1, 1000, 3, 3])\n" 385 | ] 386 | } 387 | ], 388 | "source": [ 389 | "x1 = Variable(torch.FloatTensor(1, 3, 300, 300))\n", 390 | "y2 = m2(x1)\n", 391 | "print(y2.size()) # Now the network is capable of giving spatial output" 392 | ] 393 | } 394 | ], 395 | "metadata": { 396 | "kernelspec": { 397 | "display_name": "Python 3", 398 | "language": "python", 399 | "name": "python3" 400 | }, 401 | "language_info": { 402 | "codemirror_mode": { 403 | "name": "ipython", 404 | "version": 3 405 | }, 406 | "file_extension": ".py", 407 | "mimetype": "text/x-python", 408 | "name": "python", 409 | "nbconvert_exporter": "python", 410 | "pygments_lexer": "ipython3", 411 | "version": "3.5.2" 412 | } 413 | }, 414 | "nbformat": 4, 415 | "nbformat_minor": 2 416 | } 417 | -------------------------------------------------------------------------------- /model_graph/README.md: -------------------------------------------------------------------------------- 1 | # Generate Model Graph 2 | 3 | This script uses the backward pass of a model to generate the model graph and displays it using graphviz. 4 | You can load model definition from torchvision or your own model definition. 5 | In order to load your own model definition, modify `model_def.py` file. 
6 | 7 | + From torchvision: 8 | 9 | ``` 10 | python3 visualize.py --from_zoo --model resent18 --detailed 11 | ``` 12 | 13 | + Custom model definition: 14 | 15 | ``` 16 | python3 visualize.py --detailed 17 | ``` 18 | 19 | Adapted from: https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py 20 | 21 | Also see: https://discuss.pytorch.org/t/print-autograd-graph/692 22 | -------------------------------------------------------------------------------- /model_graph/model_def.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as f 3 | 4 | 5 | class ModelDef(nn.Module): 6 | 7 | def __init__(self, num_classes=1000): 8 | super(ModelDef, self).__init__() 9 | self.conv1=nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0) 10 | self.bn=nn.BatchNorm2d(3) 11 | self.fc=nn.Linear(3*224*224,1000) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | 15 | def forward(self, x): 16 | x = self.conv1(x) 17 | x = self.conv1(x) 18 | residual = x 19 | x = self.bn(x) 20 | x=self.relu(x) 21 | x+=residual 22 | x = x.view(-1, 3*224*224) 23 | x = self.fc(x) 24 | 25 | return x 26 | -------------------------------------------------------------------------------- /model_graph/visualize.py: -------------------------------------------------------------------------------- 1 | # e-Lab Model Visualization Script 2 | # 3 | # Abhishek Chaurasia 4 | 5 | import torch 6 | import torchvision.models as models 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | from graphviz import Digraph 10 | from argparse import ArgumentParser 11 | 12 | 13 | parser = ArgumentParser(description='e-Lab Model Visualization Script') 14 | _ = parser.add_argument 15 | _('--model', type=str, default='alexnet', help='model definition') 16 | _('--from_zoo', action='store_true', help='load from vision or your own model') 17 | _('--detailed', action='store_true', help='detailed blocks or not') 18 | 19 | args = 
parser.parse_args() 20 | 21 | 22 | def make_dot(var): 23 | node_attr = dict(style='filled', 24 | shape='box', 25 | align='left', 26 | fontsize='12', 27 | ranksep='0.1', 28 | height='0.2') 29 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"), format='svg') 30 | seen = set() 31 | 32 | module_att = ('kernel_size', 'stride', 'padding', 'dilation') 33 | detailed = args.detailed # Select False to see concise version 34 | 35 | 36 | def size_to_str(size): 37 | return '('+(', ').join(['%d'% v for v in size])+')' 38 | 39 | 40 | def get_details(module): 41 | fill_color = None 42 | label = str(type(module).__name__)[:-8] 43 | 44 | if detailed: # Show all kernel, stride, etc. values 45 | if(str(type(module).__name__)=="ConvNdBackward"): 46 | kernel = module.next_functions[1][0].variable 47 | label = label + \ 48 | '\\nkernel_size=' + size_to_str(kernel.size()) 49 | fill_color = 'orange' 50 | 51 | for attribute in module_att: 52 | if(hasattr(module, attribute)): 53 | label = label + '\\n' + attribute + \ 54 | '=' + str(getattr(module, attribute)) 55 | 56 | if fill_color == None: 57 | fill_color = 'lightblue' 58 | 59 | if(str(type(module).__name__)=="AddmmBackward"): # Linear layer 60 | label = label + ' ' + size_to_str(module.saved_tensors[1].size()) 61 | fill_color = 'orange' 62 | 63 | return label, fill_color 64 | 65 | 66 | def graph_gen(module): 67 | if module not in seen: 68 | seen.add(module) 69 | label, fill_color = get_details(module) 70 | dot.node(str(id(module)), label, fillcolor=fill_color) 71 | if hasattr(module, 'next_functions'): # only the main branch of graph has next_function 72 | for child in module.next_functions: 73 | if child[0]: # ignore variables 74 | if not hasattr(child[0], 'variable'): # eliminate accumulated grad 75 | if not (str(type(child[0]).__name__)=="TransposeBackward"): 76 | dot.edge(str(id(child[0])), str(id(module))) 77 | graph_gen(child[0]) 78 | 79 | 80 | graph_gen(var.grad_fn) 81 | return dot 82 | 83 | 84 | if args.from_zoo: 
85 | model = getattr(models, args.model)() 86 | else: 87 | from model_def import ModelDef 88 | model = ModelDef() 89 | 90 | x = torch.randn(1,3,224, 224) 91 | y = model(Variable(x)) 92 | 93 | g = make_dot(y) 94 | g.view() 95 | -------------------------------------------------------------------------------- /profiler/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-profiler 2 | Computes #parameters and #ops for a given network. 3 | 4 | > Layers without a counting function are ignored (a "Not implemented" message is printed for them). 5 | 6 | ``` 7 | usage: profile.py [-h] model input_size [input_size ...] 8 | 9 | pytorch model profiler 10 | 11 | positional arguments: 12 | model model to profile 13 | input_size input size to the network 14 | 15 | optional arguments: 16 | -h, --help show this help message and exit 17 | ``` 18 | 19 | To profile a model with custom layers, implement a ```count_*``` function like the ones in ```profile.py``` and write a small script like the following: 20 | 21 | ``` 22 | from profile import profile 23 | 24 | model = YourModel() 25 | 26 | custom_ops = {YourLayer: count_your_layer, ... 
} 27 | 28 | num_ops, num_params = profile(model, input_size, custom_ops) 29 | ``` 30 | -------------------------------------------------------------------------------- /profiler/profile.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def count_conv2d(m, x, y): 7 | x = x[0] 8 | 9 | cin = m.in_channels // m.groups 10 | cout = m.out_channels // m.groups 11 | kh, kw = m.kernel_size 12 | batch_size = x.size()[0] 13 | 14 | # ops per output element 15 | kernel_mul = kh * kw * cin 16 | kernel_add = kh * kw * cin - 1 17 | bias_ops = 1 if m.bias is not None else 0 18 | ops = kernel_mul + kernel_add + bias_ops 19 | 20 | # total ops 21 | num_out_elements = y.numel() 22 | total_ops = num_out_elements * ops 23 | 24 | # incase same conv is used multiple times 25 | m.total_ops += torch.Tensor([int(total_ops)]) 26 | 27 | def count_bn2d(m, x, y): 28 | x = x[0] 29 | 30 | nelements = x.numel() 31 | total_sub = nelements 32 | total_div = nelements 33 | total_ops = total_sub + total_div 34 | 35 | m.total_ops += torch.Tensor([int(total_ops)]) 36 | 37 | def count_relu(m, x, y): 38 | x = x[0] 39 | 40 | nelements = x.numel() 41 | total_ops = nelements 42 | 43 | m.total_ops += torch.Tensor([int(total_ops)]) 44 | 45 | def count_softmax(m, x, y): 46 | x = x[0] 47 | 48 | batch_size, nfeatures = x.size() 49 | 50 | total_exp = nfeatures 51 | total_add = nfeatures - 1 52 | total_div = nfeatures 53 | total_ops = batch_size * (total_exp + total_add + total_div) 54 | 55 | m.total_ops += torch.Tensor([int(total_ops)]) 56 | 57 | def count_maxpool(m, x, y): 58 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) - 1 59 | num_elements = y.numel() 60 | total_ops = kernel_ops * num_elements 61 | 62 | m.total_ops += torch.Tensor([int(total_ops)]) 63 | 64 | def count_avgpool(m, x, y): 65 | total_add = torch.prod(torch.Tensor([m.kernel_size])) - 1 66 | total_div = 1 67 | kernel_ops = total_add + 
total_div 68 | num_elements = y.numel() 69 | total_ops = kernel_ops * num_elements 70 | 71 | m.total_ops += torch.Tensor([int(total_ops)]) 72 | 73 | def count_linear(m, x, y): 74 | # per output element 75 | total_mul = m.in_features 76 | total_add = m.in_features - 1 77 | num_elements = y.numel() 78 | total_ops = (total_mul + total_add) * num_elements 79 | 80 | m.total_ops += torch.Tensor([int(total_ops)]) 81 | 82 | def profile(model, input_size, custom_ops = {}): 83 | 84 | model.eval() 85 | 86 | def add_hooks(m): 87 | if len(list(m.children())) > 0: return 88 | m.register_buffer('total_ops', torch.zeros(1)) 89 | m.register_buffer('total_params', torch.zeros(1)) 90 | 91 | for p in m.parameters(): 92 | m.total_params += torch.Tensor([p.numel()]) 93 | 94 | if isinstance(m, nn.Conv2d): 95 | m.register_forward_hook(count_conv2d) 96 | elif isinstance(m, nn.BatchNorm2d): 97 | m.register_forward_hook(count_bn2d) 98 | elif isinstance(m, nn.ReLU): 99 | m.register_forward_hook(count_relu) 100 | elif isinstance(m, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)): 101 | m.register_forward_hook(count_maxpool) 102 | elif isinstance(m, (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)): 103 | m.register_forward_hook(count_avgpool) 104 | elif isinstance(m, nn.Linear): 105 | m.register_forward_hook(count_linear) 106 | elif isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)): 107 | pass 108 | else: 109 | print("Not implemented for ", m) 110 | 111 | model.apply(add_hooks) 112 | 113 | x = torch.zeros(input_size) 114 | model(x) 115 | 116 | total_ops = 0 117 | total_params = 0 118 | for m in model.modules(): 119 | if len(list(m.children())) > 0: continue 120 | total_ops += m.total_ops 121 | total_params += m.total_params 122 | total_ops = total_ops 123 | total_params = total_params 124 | 125 | return total_ops, total_params 126 | 127 | def main(args): 128 | model = torch.load(args.model) 129 | total_ops, total_params = profile(model, args.input_size) 130 | print("#Ops: %f GOps"%(total_ops/1e9)) 
131 | print("#Parameters: %f M"%(total_params/1e6)) 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser(description="pytorch model profiler") 135 | parser.add_argument("model", help="model to profile") 136 | parser.add_argument("input_size", nargs='+', type=int, 137 | help="input size to the network") 138 | args = parser.parse_args() 139 | main(args) 140 | -------------------------------------------------------------------------------- /pytorch2thnets/thexport.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from ctypes import * 4 | from torchvision import models 5 | from torch.autograd import Variable as V 6 | 7 | def writetensor(f, name, t): 8 | f.write(bytearray(c_int8(3))) #Tensor id 9 | f.write(str.encode(name)) #Name 10 | f.write(bytearray(c_int8(0))) #Name terminator 11 | f.write(bytearray(c_int32(t.dim()))) #Number of dimensions 12 | for i in range(t.dim()): 13 | f.write(bytearray(c_int32(t.size(i)))) #Individual dimensions 14 | f.flush() 15 | t.contiguous().storage()._write_file(f) 16 | f.flush() 17 | 18 | def writeint(f, name, v): 19 | f.write(bytearray(c_int8(4))) #int32param id 20 | f.write(str.encode(name)) #Param name 21 | f.write(bytearray(c_int8(0))) #Param name terminator 22 | f.write(bytearray(c_int32(v))) #Data 23 | 24 | def writefloat(f, name, v): 25 | f.write(bytearray(c_int8(5))) #floatparam id 26 | f.write(str.encode(name)) #Param name 27 | f.write(bytearray(c_int8(0))) #Param name terminator 28 | f.write(bytearray(c_float(v))) #Data 29 | 30 | def writeintvect(f, name, v): 31 | f.write(bytearray(c_int8(6))) #32tupleparam id 32 | f.write(str.encode(name)) #Param name 33 | f.write(bytearray(c_int8(0))) #Param name terminator 34 | f.write(bytearray(c_int32(len(v)))) #Tuple elements 35 | for i in range(len(v)): 36 | f.write(bytearray(c_int32(v[i]))) 37 | 38 | def writefunctionid(f, id): 39 | f.write(bytearray(c_int8(7))) #Function id 40 | 
f.write(bytearray(c_int32(id))) #Data 41 | 42 | def check_layer_class(obj): 43 | if (str(obj.__class__)==""): 44 | return (True,'torch.nn._functions.linear.Linear') 45 | elif (str(obj.__class__)==""): 46 | return (True,'torch.nn._functions.thnn.auto.Threshold') 47 | elif (str(obj.__class__)==""): 48 | return (True,'torch.nn._functions.dropout.Dropout') 49 | elif (str(obj.__class__)==""): 50 | return (True,'torch.autograd._functions.tensor.View') 51 | elif (str(obj.__class__)==""): 52 | return (True,'torch.nn._functions.thnn.pooling.MaxPool2d') 53 | elif (str(obj.__class__)==""): 54 | return (True,'torch.nn._functions.conv.ConvNd') 55 | elif (str(obj.__class__)==""): 56 | return (True,'torch.nn._functions.thnn.pooling.AvgPool2d') 57 | elif (str(obj.__class__)==""): 58 | obj.inplace=True 59 | return (True,'torch.autograd._functions.basic_ops.Add') 60 | elif (str(obj.__class__)==""): 61 | return (True,'torch.nn._functions.batchnorm.BatchNorm') 62 | elif (str(obj.__class__)==""): 63 | return (True,'torch.autograd._functions.tensor.Concat') 64 | return (False,'') 65 | 66 | def check_parameter_class(obj): 67 | if (str(obj.__class__)==""): 68 | return (True,obj.variable) 69 | return (False,'') 70 | 71 | def check_if_linear_weight(obj): 72 | if (str(obj.__class__)==""): 73 | return (True,obj.next_functions[0][0].variable) 74 | return (False,'') 75 | 76 | class Exporter: 77 | def __init__(self, f): 78 | self.f = f 79 | self.output_id = 0 80 | self.objects = {} 81 | #Write 24 bytes header 82 | f.write(str.encode('PyTorch Graph Dump 1.00')) 83 | f.write(bytearray(c_int8(0))) #String terminator 84 | def end(self): 85 | self.f.write(bytearray(c_int8(0))) #End of function id 86 | def input(self): 87 | self.f.write(bytearray(c_int8(1))) #Input id 88 | def function(self, name, obj): 89 | self.f.write(bytearray(c_int8(2))) #Function id 90 | self.objects[obj] = self.output_id 91 | self.f.write(bytearray(c_int32(self.output_id))) #Unique ID of the output of this function 92 | 
self.output_id = self.output_id + 1 93 | self.f.write(str.encode(name)) #Function name 94 | self.f.write(bytearray(c_int8(0))) #Function name terminator 95 | if hasattr(obj, 'inplace'): 96 | writeint(self.f, 'inplace', obj.inplace) 97 | if hasattr(obj, 'ceil_mode'): 98 | writeint(self.f, 'ceil_mode', obj.ceil_mode) 99 | if hasattr(obj, 'kernel_size'): 100 | writeintvect(self.f, 'kernel_size', obj.kernel_size) 101 | if hasattr(obj, 'new_sizes'): 102 | writeintvect(self.f, 'sizes', obj.new_sizes) 103 | if hasattr(obj, 'stride'): 104 | writeintvect(self.f, 'stride', obj.stride) 105 | if hasattr(obj, 'padding'): 106 | writeintvect(self.f, 'padding', obj.padding) 107 | if hasattr(obj, 'eps'): 108 | writefloat(self.f, 'eps', obj.eps) 109 | #if hasattr(obj, 'threshold'): 110 | # writefloat(self.f, 'threshold', obj.threshold) 111 | #if hasattr(obj, 'value'): 112 | # writefloat(self.f, 'value', obj.value) 113 | if hasattr(obj, 'running_mean'): 114 | writetensor(self.f, 'running_mean', obj.running_mean) 115 | if hasattr(obj, 'running_var'): 116 | writetensor(self.f, 'running_var', obj.running_var) 117 | if hasattr(obj, 'dim'): 118 | writeint(self.f, 'dim', obj.dim) 119 | def tensor(self, t): 120 | writetensor(self.f, '', t.data) 121 | def write(self, obj): 122 | self.function(check_layer_class(obj)[1], obj) 123 | for o in obj.next_functions: 124 | if check_layer_class(o[0])[0]: 125 | if o[0] in self.objects: 126 | writefunctionid(self.f, self.objects[o[0]]) 127 | else: 128 | self.write(o[0]) 129 | if obj.next_functions[0][0] is None: 130 | self.input() 131 | for o in obj.next_functions: 132 | (check,param)=check_if_linear_weight(o[0]) 133 | if check: 134 | self.tensor(param) 135 | for o in obj.next_functions: 136 | (check,param)=check_parameter_class(o[0]) 137 | if check: 138 | self.tensor(param) 139 | self.end() 140 | 141 | def save(path, output): 142 | with open(path, mode='wb') as f: 143 | e = Exporter(f) 144 | e.write(output.grad_fn) 145 | 146 | 
#model=models.densenet201(pretrained=True).eval() 147 | #out=model(V(torch.FloatTensor(1,3,227,227))) 148 | #save('model.net',out) 149 | -------------------------------------------------------------------------------- /torch2pytorch/th2pyth.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Convert torch model into pytorch model\n", 12 | "\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.legacy.nn as nn1\n", 16 | "from torch.utils.serialization import load_lua\n", 17 | "from torch.autograd import Variable" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": { 24 | "collapsed": true 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "# load torch model\n", 29 | "\n", 30 | "nn1.SpatialConvolutionMM = nn1.SpatialConvolution #load_lua does not recognize SpatialConvolutionMM\n", 31 | "\n", 32 | "m1 = load_lua('/Workspace/model.net')\n", 33 | "m1.evaluate()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": { 40 | "collapsed": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "def patch(m):\n", 45 | " s = str(type(m))\n", 46 | " s = s[str.rfind(s, '.')+1:-2]\n", 47 | " if s == 'Padding' and hasattr(m, 'nInputDim') and m.nInputDim == 3:\n", 48 | " m.dim = m.dim + 1\n", 49 | " if s == 'View' and len(m.size) == 1:\n", 50 | " m.size = torch.Size([1,m.size[0]])\n", 51 | " if hasattr(m, 'modules'):\n", 52 | " for m in m.modules:\n", 53 | " patch(m)\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "nn.Sequential {\n", 66 | " [input -> (0) -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) 
-> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> output]\n", 67 | " (0): nn.SpatialConvolution(3 -> 64, 11x11, 4, 4, 2, 2)\n", 68 | " (1): nn.ReLU\n", 69 | " (2): nn.SpatialMaxPooling(3x3, 2, 2)\n", 70 | " (3): nn.SpatialConvolution(64 -> 192, 5x5, 1, 1, 2, 2)\n", 71 | " (4): nn.ReLU\n", 72 | " (5): nn.SpatialMaxPooling(3x3, 2, 2)\n", 73 | " (6): nn.SpatialConvolution(192 -> 384, 3x3, 1, 1, 1, 1)\n", 74 | " (7): nn.ReLU\n", 75 | " (8): nn.SpatialConvolution(384 -> 256, 3x3, 1, 1, 1, 1)\n", 76 | " (9): nn.ReLU\n", 77 | " (10): nn.SpatialConvolution(256 -> 256, 3x3, 1, 1, 1, 1)\n", 78 | " (11): nn.ReLU\n", 79 | " (12): nn.SpatialMaxPooling(3x3, 2, 2)\n", 80 | " (13): nn.View(1, 9216)\n", 81 | " (14): nn.Linear(9216 -> 4096)\n", 82 | " (15): nn.ReLU\n", 83 | " (16): nn.Linear(4096 -> 4096)\n", 84 | " (17): nn.ReLU\n", 85 | " (18): nn.Linear(4096 -> 46)\n", 86 | " (19): nn.SoftMax\n", 87 | "}\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "patch(m1)\n", 93 | "print(m1)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": { 100 | "collapsed": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "class ModelDef(nn.Module):\n", 105 | "\n", 106 | " def __init__(self, num_classes=46):\n", 107 | " super(ModelDef, self).__init__()\n", 108 | " self.features = nn.Sequential(\n", 109 | " nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n", 110 | " nn.ReLU(inplace=True), \n", 111 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 112 | " nn.Conv2d(64, 192, kernel_size=5, padding=2),\n", 113 | " nn.ReLU(inplace=True),\n", 114 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 115 | " nn.Conv2d(192, 384, kernel_size=3, padding=1),\n", 116 | " nn.ReLU(inplace=True),\n", 117 | " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", 118 | " nn.ReLU(inplace=True),\n", 119 | " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", 120 | " nn.ReLU(inplace=True),\n", 121 | " nn.MaxPool2d(kernel_size=3, stride=2),\n", 122 | " )\n", 123 
| " self.classifier = nn.Sequential(\n", 124 | " nn.Linear(9216, 4096),\n", 125 | " nn.ReLU(inplace=True),\n", 126 | " nn.Linear(4096, 4096),\n", 127 | " nn.ReLU(inplace=True),\n", 128 | " nn.Linear(4096, num_classes)\n", 129 | " ) \n", 130 | " \n", 131 | " def forward(self, x):\n", 132 | " x = self.features(x)\n", 133 | " x = x.view(x.size(0), 256 * 6 * 6)\n", 134 | " x = self.classifier(x)\n", 135 | " return x\n" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "metadata": { 142 | "collapsed": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "m2 = ModelDef()\n", 147 | "m2.eval()\n", 148 | "m = nn.Softmax()" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 7, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "1.3432990044748294e-06\t0.022018155083060265\n", 161 | "0.0008719050674699247\t0.022039148956537247\n", 162 | "0.007889222353696823\t0.021509507670998573\n", 163 | "0.0004464764497242868\t0.021576479077339172\n", 164 | "0.0003359577094670385\t0.021639961749315262\n", 165 | "1.7650879672800879e-12\t0.02152116224169731\n", 166 | "0.00433177687227726\t0.021732622757554054\n", 167 | "3.5346817139902953e-10\t0.02132371813058853\n", 168 | "5.946105777020674e-17\t0.022074732929468155\n", 169 | "7.42116170772157e-18\t0.021620875224471092\n", 170 | "2.153080686184694e-06\t0.021924156695604324\n", 171 | "0.010166157968342304\t0.02187509462237358\n", 172 | "0.14177590608596802\t0.021905595436692238\n", 173 | "4.358494152256753e-06\t0.022067096084356308\n", 174 | "2.4869538736118643e-18\t0.02138614095747471\n", 175 | "0.008647882379591465\t0.021801117807626724\n", 176 | "8.150720376409737e-11\t0.02167879231274128\n", 177 | "8.782055260780908e-07\t0.021408220753073692\n", 178 | "0.012028225697577\t0.02191108837723732\n", 179 | "0.010035747662186623\t0.02156299166381359\n", 180 | "0.11148055642843246\t0.02182559110224247\n", 
181 | "7.24380515748635e-05\t0.02188880741596222\n", 182 | "0.000648920948151499\t0.021966079249978065\n", 183 | "0.00041553800110705197\t0.02156689204275608\n", 184 | "0.5863264799118042\t0.021324940025806427\n", 185 | "1.4560652061845758e-06\t0.021738141775131226\n", 186 | "0.0003181264619342983\t0.021697165444493294\n", 187 | "9.683155076345429e-05\t0.02181481383740902\n", 188 | "1.7417216113813083e-08\t0.02183741331100464\n", 189 | "0.00019580854859668761\t0.02213718183338642\n", 190 | "0.0004190478066448122\t0.02181345596909523\n", 191 | "0.02167375758290291\t0.02184818871319294\n", 192 | "5.7447582548775245e-06\t0.022094057872891426\n", 193 | "0.0026216977275907993\t0.021630844101309776\n", 194 | "1.707065530354157e-05\t0.021400123834609985\n", 195 | "1.1819829559556183e-09\t0.021656261757016182\n", 196 | "0.00038011331344023347\t0.02194521389901638\n", 197 | "0.05634373053908348\t0.021586017683148384\n", 198 | "0.010638452135026455\t0.021794524043798447\n", 199 | "5.027295173931634e-06\t0.02164366841316223\n", 200 | "0.0005301411729305983\t0.021417489275336266\n", 201 | "0.002515652682632208\t0.021755939349532127\n", 202 | "0.00010807059879880399\t0.021748697385191917\n", 203 | "0.008504015393555164\t0.02184322662651539\n", 204 | "3.978097538492875e-06\t0.021836012601852417\n", 205 | "0.00013932630827184767\t0.02161259762942791\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "x1 = torch.randn(1, 3, 224, 224)\n", 211 | "x1_var = Variable(x1)\n", 212 | "y1 = m1.forward(x1)\n", 213 | "y2 = m(m2(x1_var))\n", 214 | "# Output of both network will be different; which is obvious!!!\n", 215 | "for i in range(len(y1[0])):\n", 216 | " print(str(y1[0][i]) + '\\t' + str(y2.data[0][i]))" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 8, 222 | "metadata": { 223 | "collapsed": true 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "# copy weights from torch model into pytorch model\n", 228 | "j = 0\n", 229 | "for i in m2.modules():\n", 
230 | " if not list(i.children()):\n", 231 | " if len(i.state_dict()) > 0:\n", 232 | " i.weight.data = m1.modules[j].weight\n", 233 | " i.bias.data = m1.modules[j].bias\n", 234 | " \n", 235 | " j += 1\n", 236 | " if j == 13: # Ignore nn.View\n", 237 | " j += 1\n" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 9, 243 | "metadata": {}, 244 | "outputs": [ 245 | { 246 | "data": { 247 | "text/plain": [ 248 | "Parameter containing:\n", 249 | "-0.0000 -0.0000 -0.0000 ... -0.0001 0.0000 0.0000\n", 250 | " 0.0000 0.0000 0.0000 ... 0.0000 -0.0000 -0.0000\n", 251 | " 0.0002 -0.0001 0.0001 ... 0.0000 -0.0002 0.0169\n", 252 | " ... ⋱ ... \n", 253 | " 0.0000 -0.0000 -0.0000 ... 0.0000 -0.0000 0.0000\n", 254 | "-0.0000 0.0000 0.0000 ... 0.0001 -0.0000 -0.0000\n", 255 | " 0.0000 -0.0000 0.0000 ... -0.0000 -0.0000 -0.0000\n", 256 | "[torch.FloatTensor of size 4096x9216]" 257 | ] 258 | }, 259 | "execution_count": 9, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | } 263 | ], 264 | "source": [ 265 | "m2._modules['classifier'][0].weight" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 10, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "\n", 277 | "-0.0000 -0.0000 -0.0000 ... -0.0001 0.0000 0.0000\n", 278 | " 0.0000 0.0000 0.0000 ... 0.0000 -0.0000 -0.0000\n", 279 | " 0.0002 -0.0001 0.0001 ... 0.0000 -0.0002 0.0169\n", 280 | " ... ⋱ ... \n", 281 | " 0.0000 -0.0000 -0.0000 ... 0.0000 -0.0000 0.0000\n", 282 | "-0.0000 0.0000 0.0000 ... 0.0001 -0.0000 -0.0000\n", 283 | " 0.0000 -0.0000 0.0000 ... 
-0.0000 -0.0000 -0.0000\n", 284 | "[torch.FloatTensor of size 4096x9216]" 285 | ] 286 | }, 287 | "execution_count": 10, 288 | "metadata": {}, 289 | "output_type": "execute_result" 290 | } 291 | ], 292 | "source": [ 293 | "m1.modules[14].weight # both weights should match" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 11, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "1.3432990044748294e-06\t1.3432990044748294e-06\n", 306 | "0.0008719050674699247\t0.0008719050674699247\n", 307 | "0.007889222353696823\t0.007889222353696823\n", 308 | "0.0004464764497242868\t0.0004464764497242868\n", 309 | "0.0003359577094670385\t0.0003359577094670385\n", 310 | "1.7650879672800879e-12\t1.7650879672800879e-12\n", 311 | "0.00433177687227726\t0.00433177687227726\n", 312 | "3.5346817139902953e-10\t3.5346817139902953e-10\n", 313 | "5.946105777020674e-17\t5.946105777020674e-17\n", 314 | "7.42116170772157e-18\t7.42116170772157e-18\n", 315 | "2.153080686184694e-06\t2.153080686184694e-06\n", 316 | "0.010166157968342304\t0.010166157968342304\n", 317 | "0.14177590608596802\t0.14177590608596802\n", 318 | "4.358494152256753e-06\t4.358494152256753e-06\n", 319 | "2.4869538736118643e-18\t2.4869538736118643e-18\n", 320 | "0.008647882379591465\t0.008647882379591465\n", 321 | "8.150720376409737e-11\t8.150720376409737e-11\n", 322 | "8.782055260780908e-07\t8.782055260780908e-07\n", 323 | "0.012028225697577\t0.012028225697577\n", 324 | "0.010035747662186623\t0.010035747662186623\n", 325 | "0.11148055642843246\t0.11148055642843246\n", 326 | "7.24380515748635e-05\t7.24380515748635e-05\n", 327 | "0.000648920948151499\t0.000648920948151499\n", 328 | "0.00041553800110705197\t0.00041553800110705197\n", 329 | "0.5863264799118042\t0.5863264799118042\n", 330 | "1.4560652061845758e-06\t1.4560652061845758e-06\n", 331 | "0.0003181264619342983\t0.0003181264619342983\n", 332 | 
"9.683155076345429e-05\t9.683155076345429e-05\n", 333 | "1.7417216113813083e-08\t1.7417216113813083e-08\n", 334 | "0.00019580854859668761\t0.00019580854859668761\n", 335 | "0.0004190478066448122\t0.0004190478066448122\n", 336 | "0.02167375758290291\t0.02167375758290291\n", 337 | "5.7447582548775245e-06\t5.7447582548775245e-06\n", 338 | "0.0026216977275907993\t0.0026216977275907993\n", 339 | "1.707065530354157e-05\t1.707065530354157e-05\n", 340 | "1.1819829559556183e-09\t1.1819829559556183e-09\n", 341 | "0.00038011331344023347\t0.00038011331344023347\n", 342 | "0.05634373053908348\t0.05634373053908348\n", 343 | "0.010638452135026455\t0.010638452135026455\n", 344 | "5.027295173931634e-06\t5.027295173931634e-06\n", 345 | "0.0005301411729305983\t0.0005301411729305983\n", 346 | "0.002515652682632208\t0.002515652682632208\n", 347 | "0.00010807059879880399\t0.00010807059879880399\n", 348 | "0.008504015393555164\t0.008504015393555164\n", 349 | "3.978097538492875e-06\t3.978097538492875e-06\n", 350 | "0.00013932630827184767\t0.00013932630827184767\n" 351 | ] 352 | } 353 | ], 354 | "source": [ 355 | "y1 = m1.forward(x1)\n", 356 | "y2 = m(m2(x1_var))\n", 357 | "# Output of both networks are same because they now have the same weights\n", 358 | "for i in range(len(y1[0])):\n", 359 | " print(str(y1[0][i]) + '\\t' + str(y2.data[0][i]))" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 12, 365 | "metadata": { 366 | "collapsed": true 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "# Conversion of torch to pytorch model complete\n", 371 | "# Time to save the new model\n", 372 | "torch.save(m2.state_dict(), '/Workspace/pytorch_model.pth.tar')" 373 | ] 374 | } 375 | ], 376 | "metadata": { 377 | "kernelspec": { 378 | "display_name": "Python 3", 379 | "language": "python", 380 | "name": "python3" 381 | }, 382 | "language_info": { 383 | "codemirror_mode": { 384 | "name": "ipython", 385 | "version": 3 386 | }, 387 | "file_extension": ".py", 388 | 
"mimetype": "text/x-python", 389 | "name": "python", 390 | "nbconvert_exporter": "python", 391 | "pygments_lexer": "ipython3", 392 | "version": "3.5.2" 393 | } 394 | }, 395 | "nbformat": 4, 396 | "nbformat_minor": 2 397 | } 398 | --------------------------------------------------------------------------------