├── .gitignore
├── README.md
├── examples
    ├── alexnet.png
    ├── googlenet.png
    ├── inception_v3.png
    ├── mobilenet_v2.png
    ├── resnet18.png
    ├── shufflenet_v2_x1_0.png
    ├── squeezenet1_0.png
    └── vgg16.png
├── summary_example.ipynb
├── test.py
├── transform_example.ipynb
├── transformers
    ├── __init__.py
    ├── quantize.py
    ├── torchTransformer.py
    └── utils.py
└── visualize_example.ipynb


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
.ipynb_checkpoints
*.py[cod]
*$py.class

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTransformer

## summary
This repository implements a `summary` function similar to Keras's `summary()`:

```
import torch
import torch.nn as nn
from transformers.torchTransformer import TorchTransformer

model = nn.Sequential(
    nn.Conv2d(3, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)

model.eval()

transformer = TorchTransformer()
input_tensor = torch.randn([1, 3, 224, 224])
transformer.summary(model, input_tensor=input_tensor)

##########################################################################################
Index| Layer (type)         | Bottoms        Output Shape           Param #
---------------------------------------------------------------------------
    1| Data                 |                [(1, 3, 224, 224)]           0
---------------------------------------------------------------------------
    2| Conv2d_1             | Data           [(1, 20, 220, 220)]       1500
---------------------------------------------------------------------------
    3| ReLU_2               | Conv2d_1       [(1, 20, 220, 220)]          0
---------------------------------------------------------------------------
    4| Conv2d_3             | ReLU_2         [(1, 64, 216, 216)]      32000
---------------------------------------------------------------------------
    5| ReLU_4               | Conv2d_3       [(1, 64, 216, 216)]          0
---------------------------------------------------------------------------
==================================================================================
Total Trainable params: 33500
Total Non-Trainable params: 0
Total params: 33500
```

Another example is in [summary_example.ipynb](summary_example.ipynb).

## visualize
`visualize` draws the model architecture using [graphviz](https://graphviz.readthedocs.io/en/stable/) and [pydot](https://pypi.org/project/pydot/), for example AlexNet from torchvision:
```
import torchvision.models as models
from transformers.torchTransformer import TorchTransformer

model = models.alexnet()
model.eval()
transformer = TorchTransformer()
# graph_size controls the size of the output graph. graphviz does not
# automatically fit the image to the model's layers, so a deep model may
# render too small to read; increase graph_size for a larger,
# higher-resolution image.
transformer.visualize(model, save_name="example", graph_size=80)
```

An example notebook is [visualize_example.ipynb](visualize_example.ipynb), and more example images are in [examples](/examples).

## transform layers
You can register the layer types you want to transform. Register them with the transformer first; it then transforms (replaces) the layers of the registered types.
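As a rough illustration of what transforming registered layer types involves, here is a minimal sketch in plain PyTorch. It assumes that "transforming" means swapping each module of a registered type for a new module; `replace_layers` and the `registry` dict are hypothetical names for this sketch, not PyTransformer's API (transform_example.ipynb shows the library's actual usage).

```python
import torch
import torch.nn as nn


def replace_layers(module: nn.Module, registry: dict) -> nn.Module:
    """Hypothetical helper (not PyTransformer's API): recursively swap every
    child whose exact type is a key in `registry` for factory(old_child)."""
    for name, child in module.named_children():
        factory = registry.get(type(child))
        if factory is not None:
            # nn.Module.__setattr__ re-registers the new submodule under the same name
            setattr(module, name, factory(child))
        else:
            # Not registered: descend into containers such as nn.Sequential
            replace_layers(child, registry)
    return module


model = nn.Sequential(
    nn.Conv2d(3, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
)

# Hypothetical registration: every nn.ReLU becomes nn.LeakyReLU(0.1)
registry = {nn.ReLU: lambda old: nn.LeakyReLU(0.1)}
replace_layers(model, registry)

out = model(torch.randn(1, 3, 32, 32))
print(model)
print(out.shape)
```

Matching on `type(child)` instead of `isinstance` keeps the swap exact, so subclasses of a registered type are left untouched; that is a design choice of this sketch, not necessarily how the library behaves.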

An example is in [transform_example.ipynb](transform_example.ipynb).

## Note
Avoid models whose layers take too many inputs, because graphviz may be slow to generate the image (e.g. densenet161 in torchvision 0.4.0 may hang while generating the PNG).

## TODO
- [x] support registration (replacement) for custom layer types
- [ ] support replacing a specified layer in the model with a specified layer
- [x] activation size calculation for supported layers
- [x] network summary output as in Keras
- [x] model graph visualization
- [ ] replace multiple modules with a single module
- [ ] conditional module replacement
- [ ] add additional modules to the forward graph

--------------------------------------------------------------------------------
/examples/alexnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/alexnet.png
--------------------------------------------------------------------------------
/examples/googlenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/googlenet.png
--------------------------------------------------------------------------------
/examples/inception_v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/inception_v3.png
--------------------------------------------------------------------------------
/examples/mobilenet_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/mobilenet_v2.png
-------------------------------------------------------------------------------- /examples/resnet18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/resnet18.png -------------------------------------------------------------------------------- /examples/shufflenet_v2_x1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/shufflenet_v2_x1_0.png -------------------------------------------------------------------------------- /examples/squeezenet1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/squeezenet1_0.png -------------------------------------------------------------------------------- /examples/vgg16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/vgg16.png -------------------------------------------------------------------------------- /summary_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torchvision\n", 12 | "import torchvision.models as models\n", 13 | "\n", 14 | "from transformers.torchTransformer import TorchTransformer" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "scrolled": false 22 | }, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/plain": [ 27 | "ResNet(\n", 28 | " 
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 29 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 30 | " (relu): ReLU(inplace=True)\n", 31 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", 32 | " (layer1): Sequential(\n", 33 | " (0): BasicBlock(\n", 34 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 35 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 36 | " (relu): ReLU(inplace=True)\n", 37 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 38 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 39 | " )\n", 40 | " (1): BasicBlock(\n", 41 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 42 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 43 | " (relu): ReLU(inplace=True)\n", 44 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 45 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 46 | " )\n", 47 | " )\n", 48 | " (layer2): Sequential(\n", 49 | " (0): BasicBlock(\n", 50 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 51 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 52 | " (relu): ReLU(inplace=True)\n", 53 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 54 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 55 | " (downsample): Sequential(\n", 56 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 57 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", 58 | " )\n", 59 | " )\n", 60 | " (1): BasicBlock(\n", 61 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 62 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 63 | " (relu): ReLU(inplace=True)\n", 64 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 65 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 66 | " )\n", 67 | " )\n", 68 | " (layer3): Sequential(\n", 69 | " (0): BasicBlock(\n", 70 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 71 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 72 | " (relu): ReLU(inplace=True)\n", 73 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 74 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 75 | " (downsample): Sequential(\n", 76 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 77 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 78 | " )\n", 79 | " )\n", 80 | " (1): BasicBlock(\n", 81 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 82 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 83 | " (relu): ReLU(inplace=True)\n", 84 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 85 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 86 | " )\n", 87 | " )\n", 88 | " (layer4): Sequential(\n", 89 | " (0): BasicBlock(\n", 90 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 91 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", 92 | " (relu): ReLU(inplace=True)\n", 93 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 94 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 95 | " (downsample): Sequential(\n", 96 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 97 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 98 | " )\n", 99 | " )\n", 100 | " (1): BasicBlock(\n", 101 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 102 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 103 | " (relu): ReLU(inplace=True)\n", 104 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 105 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 106 | " )\n", 107 | " )\n", 108 | " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", 109 | " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", 110 | ")" 111 | ] 112 | }, 113 | "execution_count": 2, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "model = models.__dict__[\"resnet18\"]()\n", 120 | "model.eval()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 3, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "transformer = TorchTransformer()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "## summary(cpu)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 4, 142 | "metadata": { 143 | "scrolled": false 144 | }, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | 
"##########################################################################################\n", 151 | "Index| Layer (type) | Bottoms Output Shape Param # \n", 152 | "---------------------------------------------------------------------------\n", 153 | " 0| Data | [(1, 3, 224, 224)] 0 \n", 154 | "---------------------------------------------------------------------------\n", 155 | " 1| Conv2d_1 | Data [(1, 64, 112, 112)] 9408 \n", 156 | "---------------------------------------------------------------------------\n", 157 | " 2| BatchNorm2d_2 | Conv2d_1 [(1, 64, 112, 112)] 64 \n", 158 | "---------------------------------------------------------------------------\n", 159 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n", 160 | "---------------------------------------------------------------------------\n", 161 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n", 162 | "---------------------------------------------------------------------------\n", 163 | " 5| Conv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n", 164 | "---------------------------------------------------------------------------\n", 165 | " 6| BatchNorm2d_6 | Conv2d_5 [(1, 64, 56, 56)] 64 \n", 166 | "---------------------------------------------------------------------------\n", 167 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n", 168 | "---------------------------------------------------------------------------\n", 169 | " 8| Conv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n", 170 | "---------------------------------------------------------------------------\n", 171 | " 9| BatchNorm2d_9 | Conv2d_8 [(1, 64, 56, 56)] 64 \n", 172 | "---------------------------------------------------------------------------\n", 173 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n", 174 | " | | MaxPool2d_4 \n", 175 | "---------------------------------------------------------------------------\n", 176 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 \n", 177 | 
"---------------------------------------------------------------------------\n", 178 | " 12| Conv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n", 179 | "---------------------------------------------------------------------------\n", 180 | " 13| BatchNorm2d_13 | Conv2d_12 [(1, 64, 56, 56)] 64 \n", 181 | "---------------------------------------------------------------------------\n", 182 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n", 183 | "---------------------------------------------------------------------------\n", 184 | " 15| Conv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n", 185 | "---------------------------------------------------------------------------\n", 186 | " 16| BatchNorm2d_16 | Conv2d_15 [(1, 64, 56, 56)] 64 \n", 187 | "---------------------------------------------------------------------------\n", 188 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n", 189 | " | | ReLU_11 \n", 190 | "---------------------------------------------------------------------------\n", 191 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n", 192 | "---------------------------------------------------------------------------\n", 193 | " 19| Conv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n", 194 | "---------------------------------------------------------------------------\n", 195 | " 20| BatchNorm2d_20 | Conv2d_19 [(1, 128, 28, 28)] 128 \n", 196 | "---------------------------------------------------------------------------\n", 197 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n", 198 | "---------------------------------------------------------------------------\n", 199 | " 22| Conv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n", 200 | "---------------------------------------------------------------------------\n", 201 | " 23| BatchNorm2d_23 | Conv2d_22 [(1, 128, 28, 28)] 128 \n", 202 | "---------------------------------------------------------------------------\n", 203 | " 24| Conv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n", 204 | 
"---------------------------------------------------------------------------\n", 205 | " 25| BatchNorm2d_25 | Conv2d_24 [(1, 128, 28, 28)] 128 \n", 206 | "---------------------------------------------------------------------------\n", 207 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n", 208 | " | | BatchNorm2d_25 \n", 209 | "---------------------------------------------------------------------------\n", 210 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n", 211 | "---------------------------------------------------------------------------\n", 212 | " 28| Conv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n", 213 | "---------------------------------------------------------------------------\n", 214 | " 29| BatchNorm2d_29 | Conv2d_28 [(1, 128, 28, 28)] 128 \n", 215 | "---------------------------------------------------------------------------\n", 216 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n", 217 | "---------------------------------------------------------------------------\n", 218 | " 31| Conv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n", 219 | "---------------------------------------------------------------------------\n", 220 | " 32| BatchNorm2d_32 | Conv2d_31 [(1, 128, 28, 28)] 128 \n", 221 | "---------------------------------------------------------------------------\n", 222 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n", 223 | " | | ReLU_27 \n", 224 | "---------------------------------------------------------------------------\n", 225 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n", 226 | "---------------------------------------------------------------------------\n", 227 | " 35| Conv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n", 228 | "---------------------------------------------------------------------------\n", 229 | " 36| BatchNorm2d_36 | Conv2d_35 [(1, 256, 14, 14)] 256 \n", 230 | "---------------------------------------------------------------------------\n", 231 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 14)] 0 
\n", 232 | "---------------------------------------------------------------------------\n", 233 | " 38| Conv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n", 234 | "---------------------------------------------------------------------------\n", 235 | " 39| BatchNorm2d_39 | Conv2d_38 [(1, 256, 14, 14)] 256 \n", 236 | "---------------------------------------------------------------------------\n", 237 | " 40| Conv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n", 238 | "---------------------------------------------------------------------------\n", 239 | " 41| BatchNorm2d_41 | Conv2d_40 [(1, 256, 14, 14)] 256 \n", 240 | "---------------------------------------------------------------------------\n", 241 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n", 242 | " | | BatchNorm2d_41 \n", 243 | "---------------------------------------------------------------------------\n", 244 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n", 245 | "---------------------------------------------------------------------------\n", 246 | " 44| Conv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n", 247 | "---------------------------------------------------------------------------\n", 248 | " 45| BatchNorm2d_45 | Conv2d_44 [(1, 256, 14, 14)] 256 \n", 249 | "---------------------------------------------------------------------------\n", 250 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n", 251 | "---------------------------------------------------------------------------\n", 252 | " 47| Conv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n", 253 | "---------------------------------------------------------------------------\n", 254 | " 48| BatchNorm2d_48 | Conv2d_47 [(1, 256, 14, 14)] 256 \n", 255 | "---------------------------------------------------------------------------\n", 256 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n", 257 | " | | ReLU_43 \n", 258 | "---------------------------------------------------------------------------\n", 259 | " 50| ReLU_50 | iadd_49 [(1, 256, 
14, 14)] 0 \n", 260 | "---------------------------------------------------------------------------\n", 261 | " 51| Conv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n", 262 | "---------------------------------------------------------------------------\n", 263 | " 52| BatchNorm2d_52 | Conv2d_51 [(1, 512, 7, 7)] 512 \n", 264 | "---------------------------------------------------------------------------\n", 265 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n", 266 | "---------------------------------------------------------------------------\n", 267 | " 54| Conv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n", 268 | "---------------------------------------------------------------------------\n", 269 | " 55| BatchNorm2d_55 | Conv2d_54 [(1, 512, 7, 7)] 512 \n", 270 | "---------------------------------------------------------------------------\n", 271 | " 56| Conv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n", 272 | "---------------------------------------------------------------------------\n", 273 | " 57| BatchNorm2d_57 | Conv2d_56 [(1, 512, 7, 7)] 512 \n", 274 | "---------------------------------------------------------------------------\n", 275 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n", 276 | " | | BatchNorm2d_57 \n", 277 | "---------------------------------------------------------------------------\n", 278 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n", 279 | "---------------------------------------------------------------------------\n", 280 | " 60| Conv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n", 281 | "---------------------------------------------------------------------------\n", 282 | " 61| BatchNorm2d_61 | Conv2d_60 [(1, 512, 7, 7)] 512 \n", 283 | "---------------------------------------------------------------------------\n", 284 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n", 285 | "---------------------------------------------------------------------------\n", 286 | " 63| Conv2d_63 | ReLU_62 [(1, 512, 7, 7)] 2359296 \n", 287 | 
"---------------------------------------------------------------------------\n", 288 | " 64| BatchNorm2d_64 | Conv2d_63 [(1, 512, 7, 7)] 512 \n", 289 | "---------------------------------------------------------------------------\n", 290 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n", 291 | " | | ReLU_59 \n", 292 | "---------------------------------------------------------------------------\n", 293 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n", 294 | "---------------------------------------------------------------------------\n", 295 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n", 296 | "---------------------------------------------------------------------------\n", 297 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n", 298 | "---------------------------------------------------------------------------\n", 299 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n", 300 | "---------------------------------------------------------------------------\n", 301 | "==================================================================================\n", 302 | "Total Trainable params: 11683712 \n", 303 | "Total Non-Trainable params: 0 \n", 304 | "Total params: 11683712 \n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "input_tensor = torch.randn([1, 3, 224, 224])\n", 310 | "transformer.summary(model, input_tensor = input_tensor)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "## summary(gpu)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 5, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "model.cuda()\n", 327 | "input_tensor = input_tensor.cuda()" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 6, 333 | "metadata": { 334 | "scrolled": false 335 | }, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | 
"##########################################################################################\n", 342 | "Index| Layer (type) | Bottoms Output Shape Param # \n", 343 | "---------------------------------------------------------------------------\n", 344 | " 0| Data | [(1, 3, 224, 224)] 0 \n", 345 | "---------------------------------------------------------------------------\n", 346 | " 1| Conv2d_1 | Data [(1, 64, 112, 112)] 9408 \n", 347 | "---------------------------------------------------------------------------\n", 348 | " 2| BatchNorm2d_2 | Conv2d_1 [(1, 64, 112, 112)] 64 \n", 349 | "---------------------------------------------------------------------------\n", 350 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n", 351 | "---------------------------------------------------------------------------\n", 352 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n", 353 | "---------------------------------------------------------------------------\n", 354 | " 5| Conv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n", 355 | "---------------------------------------------------------------------------\n", 356 | " 6| BatchNorm2d_6 | Conv2d_5 [(1, 64, 56, 56)] 64 \n", 357 | "---------------------------------------------------------------------------\n", 358 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n", 359 | "---------------------------------------------------------------------------\n", 360 | " 8| Conv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n", 361 | "---------------------------------------------------------------------------\n", 362 | " 9| BatchNorm2d_9 | Conv2d_8 [(1, 64, 56, 56)] 64 \n", 363 | "---------------------------------------------------------------------------\n", 364 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n", 365 | " | | MaxPool2d_4 \n", 366 | "---------------------------------------------------------------------------\n", 367 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 \n", 368 | 
"---------------------------------------------------------------------------\n", 369 | " 12| Conv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n", 370 | "---------------------------------------------------------------------------\n", 371 | " 13| BatchNorm2d_13 | Conv2d_12 [(1, 64, 56, 56)] 64 \n", 372 | "---------------------------------------------------------------------------\n", 373 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n", 374 | "---------------------------------------------------------------------------\n", 375 | " 15| Conv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n", 376 | "---------------------------------------------------------------------------\n", 377 | " 16| BatchNorm2d_16 | Conv2d_15 [(1, 64, 56, 56)] 64 \n", 378 | "---------------------------------------------------------------------------\n", 379 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n", 380 | " | | ReLU_11 \n", 381 | "---------------------------------------------------------------------------\n", 382 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n", 383 | "---------------------------------------------------------------------------\n", 384 | " 19| Conv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n", 385 | "---------------------------------------------------------------------------\n", 386 | " 20| BatchNorm2d_20 | Conv2d_19 [(1, 128, 28, 28)] 128 \n", 387 | "---------------------------------------------------------------------------\n", 388 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n", 389 | "---------------------------------------------------------------------------\n", 390 | " 22| Conv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n", 391 | "---------------------------------------------------------------------------\n", 392 | " 23| BatchNorm2d_23 | Conv2d_22 [(1, 128, 28, 28)] 128 \n", 393 | "---------------------------------------------------------------------------\n", 394 | " 24| Conv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n", 395 | 
"---------------------------------------------------------------------------\n", 396 | " 25| BatchNorm2d_25 | Conv2d_24 [(1, 128, 28, 28)] 128 \n", 397 | "---------------------------------------------------------------------------\n", 398 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n", 399 | " | | BatchNorm2d_25 \n", 400 | "---------------------------------------------------------------------------\n", 401 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n", 402 | "---------------------------------------------------------------------------\n", 403 | " 28| Conv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n", 404 | "---------------------------------------------------------------------------\n", 405 | " 29| BatchNorm2d_29 | Conv2d_28 [(1, 128, 28, 28)] 128 \n", 406 | "---------------------------------------------------------------------------\n", 407 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n", 408 | "---------------------------------------------------------------------------\n", 409 | " 31| Conv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n", 410 | "---------------------------------------------------------------------------\n", 411 | " 32| BatchNorm2d_32 | Conv2d_31 [(1, 128, 28, 28)] 128 \n", 412 | "---------------------------------------------------------------------------\n", 413 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n", 414 | " | | ReLU_27 \n", 415 | "---------------------------------------------------------------------------\n", 416 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n", 417 | "---------------------------------------------------------------------------\n", 418 | " 35| Conv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n", 419 | "---------------------------------------------------------------------------\n", 420 | " 36| BatchNorm2d_36 | Conv2d_35 [(1, 256, 14, 14)] 256 \n", 421 | "---------------------------------------------------------------------------\n", 422 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 14)] 0 
\n", 423 | "---------------------------------------------------------------------------\n", 424 | " 38| Conv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n", 425 | "---------------------------------------------------------------------------\n", 426 | " 39| BatchNorm2d_39 | Conv2d_38 [(1, 256, 14, 14)] 256 \n", 427 | "---------------------------------------------------------------------------\n", 428 | " 40| Conv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n", 429 | "---------------------------------------------------------------------------\n", 430 | " 41| BatchNorm2d_41 | Conv2d_40 [(1, 256, 14, 14)] 256 \n", 431 | "---------------------------------------------------------------------------\n", 432 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n", 433 | " | | BatchNorm2d_41 \n", 434 | "---------------------------------------------------------------------------\n", 435 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n", 436 | "---------------------------------------------------------------------------\n", 437 | " 44| Conv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n", 438 | "---------------------------------------------------------------------------\n", 439 | " 45| BatchNorm2d_45 | Conv2d_44 [(1, 256, 14, 14)] 256 \n", 440 | "---------------------------------------------------------------------------\n", 441 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n", 442 | "---------------------------------------------------------------------------\n", 443 | " 47| Conv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n", 444 | "---------------------------------------------------------------------------\n", 445 | " 48| BatchNorm2d_48 | Conv2d_47 [(1, 256, 14, 14)] 256 \n", 446 | "---------------------------------------------------------------------------\n", 447 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n", 448 | " | | ReLU_43 \n", 449 | "---------------------------------------------------------------------------\n", 450 | " 50| ReLU_50 | iadd_49 [(1, 256, 
14, 14)] 0 \n", 451 | "---------------------------------------------------------------------------\n", 452 | " 51| Conv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n", 453 | "---------------------------------------------------------------------------\n", 454 | " 52| BatchNorm2d_52 | Conv2d_51 [(1, 512, 7, 7)] 512 \n", 455 | "---------------------------------------------------------------------------\n", 456 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n", 457 | "---------------------------------------------------------------------------\n", 458 | " 54| Conv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n", 459 | "---------------------------------------------------------------------------\n", 460 | " 55| BatchNorm2d_55 | Conv2d_54 [(1, 512, 7, 7)] 512 \n", 461 | "---------------------------------------------------------------------------\n", 462 | " 56| Conv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n", 463 | "---------------------------------------------------------------------------\n", 464 | " 57| BatchNorm2d_57 | Conv2d_56 [(1, 512, 7, 7)] 512 \n", 465 | "---------------------------------------------------------------------------\n", 466 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n", 467 | " | | BatchNorm2d_57 \n", 468 | "---------------------------------------------------------------------------\n", 469 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n", 470 | "---------------------------------------------------------------------------\n", 471 | " 60| Conv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n", 472 | "---------------------------------------------------------------------------\n", 473 | " 61| BatchNorm2d_61 | Conv2d_60 [(1, 512, 7, 7)] 512 \n", 474 | "---------------------------------------------------------------------------\n", 475 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n", 476 | "---------------------------------------------------------------------------\n", 477 | " 63| Conv2d_63 | ReLU_62 [(1, 512, 7, 7)] 2359296 \n", 478 | 
"---------------------------------------------------------------------------\n", 479 | " 64| BatchNorm2d_64 | Conv2d_63 [(1, 512, 7, 7)] 512 \n", 480 | "---------------------------------------------------------------------------\n", 481 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n", 482 | " | | ReLU_59 \n", 483 | "---------------------------------------------------------------------------\n", 484 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n", 485 | "---------------------------------------------------------------------------\n", 486 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n", 487 | "---------------------------------------------------------------------------\n", 488 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n", 489 | "---------------------------------------------------------------------------\n", 490 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n", 491 | "---------------------------------------------------------------------------\n", 492 | "==================================================================================\n", 493 | "Total Trainable params: 11683712 \n", 494 | "Total Non-Trainable params: 0 \n", 495 | "Total params: 11683712 \n" 496 | ] 497 | } 498 | ], 499 | "source": [ 500 | "transformer.summary(model, input_tensor = input_tensor)" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [] 509 | } 510 | ], 511 | "metadata": { 512 | "kernelspec": { 513 | "display_name": "Python 3", 514 | "language": "python", 515 | "name": "python3" 516 | }, 517 | "language_info": { 518 | "codemirror_mode": { 519 | "name": "ipython", 520 | "version": 3 521 | }, 522 | "file_extension": ".py", 523 | "mimetype": "text/x-python", 524 | "name": "python", 525 | "nbconvert_exporter": "python", 526 | "pygments_lexer": "ipython3", 527 | "version": "3.6.9" 528 | } 529 | }, 530 | "nbformat": 4, 531 | "nbformat_minor": 2 532 | } 
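The Param # column in the ResNet-18 summary above can be checked by hand. The sketch below uses hypothetical helpers that are not part of PyTransformer. It rebuilds the ResNet-18 total from per-layer formulas and assumes torchvision's layout: bias-free convolutions, two BasicBlocks per stage, a 1x1 projection shortcut in the first block of stages 2 to 4, and a 1000-class `fc`. With biases excluded it reproduces the printed total. With the BatchNorm2d and Linear biases included it gives a larger figure, which suggests the recorded summary counts only `weight` tensors.

```python
# Hypothetical helpers (not part of PyTransformer): parameter counts for the
# layer types that appear in the ResNet-18 summary above.
def conv2d_params(c_in, c_out, k, bias=False):
    # weight has shape (c_out, c_in, k, k); torchvision's ResNet convs use bias=False
    return c_out * c_in * k * k + (c_out if bias else 0)


def bn2d_params(c, with_bias=True):
    # an affine BatchNorm2d learns a per-channel weight (gamma) and bias (beta)
    return c * (2 if with_bias else 1)


def linear_params(n_in, n_out, with_bias=True):
    return n_out * n_in + (n_out if with_bias else 0)


def resnet18_params(with_bias=True):
    """Sum ResNet-18 parameters stage by stage (BasicBlock, [2, 2, 2, 2])."""
    total = conv2d_params(3, 64, 7) + bn2d_params(64, with_bias)  # stem: conv1 + bn1
    c_in = 64
    for c_out in (64, 128, 256, 512):
        for block in range(2):
            first_in = c_in if block == 0 else c_out
            total += conv2d_params(first_in, c_out, 3) + bn2d_params(c_out, with_bias)
            total += conv2d_params(c_out, c_out, 3) + bn2d_params(c_out, with_bias)
            if block == 0 and c_in != c_out:
                # 1x1 projection shortcut ("downsample") in the first block of stages 2-4
                total += conv2d_params(c_in, c_out, 1) + bn2d_params(c_out, with_bias)
        c_in = c_out
    return total + linear_params(512, 1000, with_bias)


print(resnet18_params(with_bias=False))  # 11683712: weights only, the total printed above
print(resnet18_params(with_bias=True))   # 11689512: weights plus BatchNorm2d/Linear biases
```

The same helpers also reproduce individual rows. For example, `conv2d_params(128, 256, 3)` gives the 294912 listed for Conv2d_35, and `conv2d_params(128, 256, 1)` gives the 32768 listed for the 1x1 shortcut Conv2d_40.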
533 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import torchvision.models as models 5 | import sys 6 | from transformers.torchTransformer import TorchTransformer 7 | from transformers.quantize import QConv2d 8 | model = models.__dict__["resnet18"]() 9 | model.cuda() 10 | model = model.eval() 11 | 12 | transformer = TorchTransformer() 13 | transformer.register(nn.Conv2d, QConv2d) 14 | model = transformer.trans_layers(model) 15 | print(model) 16 | sys.exit()  # stop after printing the swapped model; remove this line to also run the summary below 17 | 18 | 19 | input_tensor = torch.randn([1, 3, 224, 224]) 20 | input_tensor = input_tensor.cuda() 21 | net = transformer.summary(model, input_tensor=input_tensor) 22 | # transformer.visualize(model, input_tensor = input_tensor, save_name= "example", graph_size = 80) -------------------------------------------------------------------------------- /transform_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torchvision\n", 12 | "import torchvision.models as models\n", 13 | "\n", 14 | "from transformers.torchTransformer import TorchTransformer\n", 15 | "from transformers.quantize import QConv2d" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "model = models.__dict__[\"resnet18\"]()\n", 25 | "model.cuda()\n", 26 | "model = model.eval()\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Register layers to be transformed" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | {
42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "register \n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "transformer = TorchTransformer()\n", 51 | "transformer.register(nn.Conv2d, QConv2d)\n", 52 | "model = transformer.trans_layers(model)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": { 59 | "scrolled": true 60 | }, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "ResNet(\n", 66 | " (conv1): QConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 67 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 68 | " (relu): ReLU(inplace=True)\n", 69 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", 70 | " (layer1): Sequential(\n", 71 | " (0): BasicBlock(\n", 72 | " (conv1): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 73 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 74 | " (relu): ReLU(inplace=True)\n", 75 | " (conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 76 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 77 | " )\n", 78 | " (1): BasicBlock(\n", 79 | " (conv1): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 81 | " (relu): ReLU(inplace=True)\n", 82 | " (conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 83 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 84 | " )\n", 85 | " )\n", 86 | " (layer2): Sequential(\n", 87 | " (0): BasicBlock(\n", 88 | " (conv1): QConv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 89 | " (bn1): BatchNorm2d(128, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (relu): ReLU(inplace=True)\n", 91 | " (conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 92 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 93 | " (downsample): Sequential(\n", 94 | " (0): QConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 95 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 96 | " )\n", 97 | " )\n", 98 | " (1): BasicBlock(\n", 99 | " (conv1): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 100 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 101 | " (relu): ReLU(inplace=True)\n", 102 | " (conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 103 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 104 | " )\n", 105 | " )\n", 106 | " (layer3): Sequential(\n", 107 | " (0): BasicBlock(\n", 108 | " (conv1): QConv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 109 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 110 | " (relu): ReLU(inplace=True)\n", 111 | " (conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 112 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 113 | " (downsample): Sequential(\n", 114 | " (0): QConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 115 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 116 | " )\n", 117 | " )\n", 118 | " (1): BasicBlock(\n", 119 | " (conv1): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 120 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", 121 | " (relu): ReLU(inplace=True)\n", 122 | " (conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 123 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 124 | " )\n", 125 | " )\n", 126 | " (layer4): Sequential(\n", 127 | " (0): BasicBlock(\n", 128 | " (conv1): QConv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 129 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 130 | " (relu): ReLU(inplace=True)\n", 131 | " (conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 132 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 133 | " (downsample): Sequential(\n", 134 | " (0): QConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 135 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 136 | " )\n", 137 | " )\n", 138 | " (1): BasicBlock(\n", 139 | " (conv1): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 140 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 141 | " (relu): ReLU(inplace=True)\n", 142 | " (conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 143 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 144 | " )\n", 145 | " )\n", 146 | " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", 147 | " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", 148 | ")" 149 | ] 150 | }, 151 | "execution_count": 4, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "model" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": 
"stdout", 167 | "output_type": "stream", 168 | "text": [ 169 | "##########################################################################################\n", 170 | "Index| Layer (type) | Bottoms Output Shape Param # \n", 171 | "---------------------------------------------------------------------------\n", 172 | " 0| Data | [(1, 3, 224, 224)] 0 \n", 173 | "---------------------------------------------------------------------------\n", 174 | " 1| QConv2d_1 | Data [(1, 64, 112, 112)] 9408 \n", 175 | "---------------------------------------------------------------------------\n", 176 | " 2| BatchNorm2d_2 | QConv2d_1 [(1, 64, 112, 112)] 64 \n", 177 | "---------------------------------------------------------------------------\n", 178 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n", 179 | "---------------------------------------------------------------------------\n", 180 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n", 181 | "---------------------------------------------------------------------------\n", 182 | " 5| QConv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n", 183 | "---------------------------------------------------------------------------\n", 184 | " 6| BatchNorm2d_6 | QConv2d_5 [(1, 64, 56, 56)] 64 \n", 185 | "---------------------------------------------------------------------------\n", 186 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n", 187 | "---------------------------------------------------------------------------\n", 188 | " 8| QConv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n", 189 | "---------------------------------------------------------------------------\n", 190 | " 9| BatchNorm2d_9 | QConv2d_8 [(1, 64, 56, 56)] 64 \n", 191 | "---------------------------------------------------------------------------\n", 192 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n", 193 | " | | MaxPool2d_4 \n", 194 | "---------------------------------------------------------------------------\n", 195 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 
\n", 196 | "---------------------------------------------------------------------------\n", 197 | " 12| QConv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n", 198 | "---------------------------------------------------------------------------\n", 199 | " 13| BatchNorm2d_13 | QConv2d_12 [(1, 64, 56, 56)] 64 \n", 200 | "---------------------------------------------------------------------------\n", 201 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n", 202 | "---------------------------------------------------------------------------\n", 203 | " 15| QConv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n", 204 | "---------------------------------------------------------------------------\n", 205 | " 16| BatchNorm2d_16 | QConv2d_15 [(1, 64, 56, 56)] 64 \n", 206 | "---------------------------------------------------------------------------\n", 207 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n", 208 | " | | ReLU_11 \n", 209 | "---------------------------------------------------------------------------\n", 210 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n", 211 | "---------------------------------------------------------------------------\n", 212 | " 19| QConv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n", 213 | "---------------------------------------------------------------------------\n", 214 | " 20| BatchNorm2d_20 | QConv2d_19 [(1, 128, 28, 28)] 128 \n", 215 | "---------------------------------------------------------------------------\n", 216 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n", 217 | "---------------------------------------------------------------------------\n", 218 | " 22| QConv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n", 219 | "---------------------------------------------------------------------------\n", 220 | " 23| BatchNorm2d_23 | QConv2d_22 [(1, 128, 28, 28)] 128 \n", 221 | "---------------------------------------------------------------------------\n", 222 | " 24| QConv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n", 223 | 
"---------------------------------------------------------------------------\n", 224 | " 25| BatchNorm2d_25 | QConv2d_24 [(1, 128, 28, 28)] 128 \n", 225 | "---------------------------------------------------------------------------\n", 226 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n", 227 | " | | BatchNorm2d_25 \n", 228 | "---------------------------------------------------------------------------\n", 229 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n", 230 | "---------------------------------------------------------------------------\n", 231 | " 28| QConv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n", 232 | "---------------------------------------------------------------------------\n", 233 | " 29| BatchNorm2d_29 | QConv2d_28 [(1, 128, 28, 28)] 128 \n", 234 | "---------------------------------------------------------------------------\n", 235 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n", 236 | "---------------------------------------------------------------------------\n", 237 | " 31| QConv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n", 238 | "---------------------------------------------------------------------------\n", 239 | " 32| BatchNorm2d_32 | QConv2d_31 [(1, 128, 28, 28)] 128 \n", 240 | "---------------------------------------------------------------------------\n", 241 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n", 242 | " | | ReLU_27 \n", 243 | "---------------------------------------------------------------------------\n", 244 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n", 245 | "---------------------------------------------------------------------------\n", 246 | " 35| QConv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n", 247 | "---------------------------------------------------------------------------\n", 248 | " 36| BatchNorm2d_36 | QConv2d_35 [(1, 256, 14, 14)] 256 \n", 249 | "---------------------------------------------------------------------------\n", 250 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 
14)] 0 \n", 251 | "---------------------------------------------------------------------------\n", 252 | " 38| QConv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n", 253 | "---------------------------------------------------------------------------\n", 254 | " 39| BatchNorm2d_39 | QConv2d_38 [(1, 256, 14, 14)] 256 \n", 255 | "---------------------------------------------------------------------------\n", 256 | " 40| QConv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n", 257 | "---------------------------------------------------------------------------\n", 258 | " 41| BatchNorm2d_41 | QConv2d_40 [(1, 256, 14, 14)] 256 \n", 259 | "---------------------------------------------------------------------------\n", 260 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n", 261 | " | | BatchNorm2d_41 \n", 262 | "---------------------------------------------------------------------------\n", 263 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n", 264 | "---------------------------------------------------------------------------\n", 265 | " 44| QConv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n", 266 | "---------------------------------------------------------------------------\n", 267 | " 45| BatchNorm2d_45 | QConv2d_44 [(1, 256, 14, 14)] 256 \n", 268 | "---------------------------------------------------------------------------\n", 269 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n", 270 | "---------------------------------------------------------------------------\n", 271 | " 47| QConv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n", 272 | "---------------------------------------------------------------------------\n", 273 | " 48| BatchNorm2d_48 | QConv2d_47 [(1, 256, 14, 14)] 256 \n", 274 | "---------------------------------------------------------------------------\n", 275 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n", 276 | " | | ReLU_43 \n", 277 | "---------------------------------------------------------------------------\n", 278 | " 50| ReLU_50 | 
iadd_49 [(1, 256, 14, 14)] 0 \n", 279 | "---------------------------------------------------------------------------\n", 280 | " 51| QConv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n", 281 | "---------------------------------------------------------------------------\n", 282 | " 52| BatchNorm2d_52 | QConv2d_51 [(1, 512, 7, 7)] 512 \n", 283 | "---------------------------------------------------------------------------\n", 284 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n", 285 | "---------------------------------------------------------------------------\n", 286 | " 54| QConv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n", 287 | "---------------------------------------------------------------------------\n", 288 | " 55| BatchNorm2d_55 | QConv2d_54 [(1, 512, 7, 7)] 512 \n", 289 | "---------------------------------------------------------------------------\n", 290 | " 56| QConv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n", 291 | "---------------------------------------------------------------------------\n", 292 | " 57| BatchNorm2d_57 | QConv2d_56 [(1, 512, 7, 7)] 512 \n", 293 | "---------------------------------------------------------------------------\n", 294 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n", 295 | " | | BatchNorm2d_57 \n", 296 | "---------------------------------------------------------------------------\n", 297 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n", 298 | "---------------------------------------------------------------------------\n", 299 | " 60| QConv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n", 300 | "---------------------------------------------------------------------------\n", 301 | " 61| BatchNorm2d_61 | QConv2d_60 [(1, 512, 7, 7)] 512 \n", 302 | "---------------------------------------------------------------------------\n", 303 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n", 304 | "---------------------------------------------------------------------------\n", 305 | " 63| QConv2d_63 | ReLU_62 [(1, 512, 7, 7)] 
2359296 \n", 306 | "---------------------------------------------------------------------------\n", 307 | " 64| BatchNorm2d_64 | QConv2d_63 [(1, 512, 7, 7)] 512 \n", 308 | "---------------------------------------------------------------------------\n", 309 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n", 310 | " | | ReLU_59 \n", 311 | "---------------------------------------------------------------------------\n", 312 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n", 313 | "---------------------------------------------------------------------------\n", 314 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n", 315 | "---------------------------------------------------------------------------\n", 316 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n", 317 | "---------------------------------------------------------------------------\n", 318 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n", 319 | "---------------------------------------------------------------------------\n", 320 | "==================================================================================\n", 321 | "Total Trainable params: 11683712 \n", 322 | "Total Non-Trainable params: 0 \n", 323 | "Total params: 11683712 \n" 324 | ] 325 | } 326 | ], 327 | "source": [ 328 | "input_tensor = torch.randn([1, 3, 224, 224]).cuda()\n", 329 | "model = model.cuda()\n", 330 | "transformer.summary(model, input_tensor = input_tensor)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [] 339 | } 340 | ], 341 | "metadata": { 342 | "kernelspec": { 343 | "display_name": "Python 3", 344 | "language": "python", 345 | "name": "python3" 346 | }, 347 | "language_info": { 348 | "codemirror_mode": { 349 | "name": "ipython", 350 | "version": 3 351 | }, 352 | "file_extension": ".py", 353 | "mimetype": "text/x-python", 354 | "name": "python", 355 | "nbconvert_exporter": "python", 356 | 
"pygments_lexer": "ipython3", 357 | "version": "3.6.9" 358 | } 359 | }, 360 | "nbformat": 4, 361 | "nbformat_minor": 2 362 | } 363 | -------------------------------------------------------------------------------- /transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/transformers/__init__.py -------------------------------------------------------------------------------- /transformers/quantize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module Source: 3 | https://github.com/eladhoffer/quantized.pytorch 4 | """ 5 | 6 | import torch 7 | from torch.autograd.function import InplaceFunction, Function 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import math 11 | 12 | 13 | class UniformQuantize(InplaceFunction): 14 | @staticmethod 15 | def forward(ctx, input, num_bits=8, min_value=None, max_value=None, inplace=False, symmetric=False, num_chunks=None): 16 | num_chunks = input.shape[0] if num_chunks is None else num_chunks 17 | if min_value is None or max_value is None: 18 | B = input.shape[0] 19 | y = input.view(B // num_chunks, -1) 20 | 21 | if min_value is None: 22 | min_value = y.min(-1)[0].mean(-1) # C 23 | #min_value = float(input.view(input.size(0), -1).min(-1)[0].mean()) 24 | 25 | if max_value is None: 26 | #max_value = float(input.view(input.size(0), -1).max(-1)[0].mean()) 27 | max_value = y.max(-1)[0].mean(-1) # C 28 | 29 | ctx.inplace = inplace 30 | ctx.num_bits = num_bits 31 | ctx.min_value = min_value 32 | ctx.max_value = max_value 33 | 34 | if ctx.inplace: 35 | ctx.mark_dirty(input) 36 | output = input 37 | 38 | else: 39 | output = input.clone() 40 | 41 | if symmetric: 42 | qmin = -2.
** (num_bits - 1) 43 | qmax = 2 ** (num_bits - 1) - 1 44 | max_value = torch.max(torch.abs(max_value), torch.abs(min_value)) 45 | min_value = 0. 46 | 47 | else: 48 | qmin = 0. 49 | qmax = 2. ** num_bits - 1. 50 | 51 | scale = (max_value - min_value) / (qmax - qmin) 52 | scale = max(scale, 1e-8) 53 | 54 | output.add_(-min_value).div_(scale) 55 | 56 | output.clamp_(qmin, qmax).round_() # quantize 57 | 58 | output.mul_(scale).add_(min_value) # dequantize 59 | 60 | return output 61 | 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | # straight-through estimator 66 | grad_input = grad_output 67 | return grad_input, None, None, None, None, None, None 68 | 69 | 70 | def quantize(x, num_bits=8, min_value=None, max_value=None, inplace=False, symmetric=False, num_chunks=None): 71 | return UniformQuantize().apply(x, num_bits, min_value, max_value, inplace, symmetric, num_chunks) 72 | 73 | 74 | class QuantMeasure(nn.Module): 75 | """Fake-quantize the input to num_bits bits: in training mode use the batch min/max and update their running (EMA) estimates; in eval mode use the running estimates.""" 76 | 77 | def __init__(self, num_bits=8, momentum=0.1): 78 | super(QuantMeasure, self).__init__() 79 | self.register_buffer('running_min', torch.zeros(1)) 80 | self.register_buffer('running_max', torch.zeros(1)) 81 | self.momentum = momentum 82 | self.num_bits = num_bits 83 | 84 | 85 | def forward(self, input): 86 | if self.training: 87 | min_value = input.detach().view(input.size(0), -1).min(-1)[0].mean() 88 | max_value = input.detach().view(input.size(0), -1).max(-1)[0].mean() 89 | self.running_min.mul_(1 - self.momentum).add_(min_value * (self.momentum)) 90 | self.running_max.mul_(1 - self.momentum).add_(max_value * (self.momentum)) 91 | 92 | else: 93 | min_value = self.running_min 94 | max_value = self.running_max 95 | 96 | return quantize(input, self.num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16) 97 | 98 | 99 | class QConv2d(nn.Conv2d): 100 | """Conv2d whose weight and bias are fake-quantized to num_bits_weight bits before F.conv2d.""" 101 | 102 | def __init__(self, in_channels, out_channels, kernel_size, 103 | stride=1,
padding=0, dilation=1, groups=1, bias=True, num_bits=8, num_bits_weight=None): 104 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, 105 | stride, padding, dilation, groups, bias) 106 | self.num_bits = num_bits 107 | self.num_bits_weight = num_bits_weight or num_bits 108 | 109 | 110 | def forward(self, input): 111 | qweight = quantize(self.weight, num_bits=self.num_bits_weight, 112 | min_value=float(self.weight.min()), 113 | max_value=float(self.weight.max())) 114 | if self.bias is not None: 115 | qbias = quantize(self.bias, num_bits=self.num_bits_weight) 116 | else: 117 | qbias = None 118 | 119 | output = F.conv2d(input, qweight, qbias, self.stride, 120 | self.padding, self.dilation, self.groups) 121 | 122 | return output 123 | 124 | 125 | class QLinear(nn.Linear): 126 | """Linear layer whose weight and bias are fake-quantized to num_bits_weight bits before F.linear.""" 127 | 128 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False): 129 | super(QLinear, self).__init__(in_features, out_features, bias) 130 | self.num_bits = num_bits 131 | self.num_bits_weight = num_bits_weight or num_bits 132 | 133 | 134 | def forward(self, input): 135 | qweight = quantize(self.weight, num_bits=self.num_bits_weight, 136 | min_value=float(self.weight.min()), 137 | max_value=float(self.weight.max())) 138 | if self.bias is not None: 139 | qbias = quantize(self.bias, num_bits=self.num_bits_weight) 140 | else: 141 | qbias = None 142 | 143 | output = F.linear(input, qweight, qbias) 144 | 145 | return output 146 | -------------------------------------------------------------------------------- /transformers/torchTransformer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import types 4 | import inspect 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import pydot 12 | from graphviz import Digraph 13 | 14 | 
from .utils import _ReplaceFunc, Log, UnitLayer 15 | 16 | 17 | 18 | class TorchTransformer(nn.Module): 19 | """! 20 | This class handles layer swapping, summary, and visualization of the input model. 21 | """ 22 | def __init__(self): 23 | super(TorchTransformer, self).__init__() 24 | 25 | self._register_dict = OrderedDict() 26 | self.log = Log() 27 | self._raw_TrochFuncs = OrderedDict() 28 | self._raw_TrochFunctionals = OrderedDict() 29 | 30 | # register a class to be transformed 31 | def register(self, origin_class, target_class): 32 | """! 33 | This function registers which class should be transformed into the target class. 34 | """ 35 | print("register", origin_class, target_class) 36 | 37 | self._register_dict[origin_class] = target_class 38 | 39 | pass 40 | 41 | def trans_layers(self, model, update = True): 42 | """! 43 | This function transforms the model's layers according to the register dictionary. 44 | 45 | @param model: input model to transform 46 | 47 | @param update: default is True; whether to copy the parameters from the original layer or not. 48 | Note that it will update matched parameters only. The current implementation ignores this flag and always copies the original layer's attributes.
49 | 50 | @return transformed model 51 | """ 52 | # print("trans layer") 53 | if len(self._register_dict) == 0: 54 | print("No layer to swap") 55 | print("Please use register( {origin_layer}, {target_layer} ) to register layer") 56 | return model 57 | else: 58 | for module_name in model._modules: 59 | # has children 60 | if len(model._modules[module_name]._modules) > 0: 61 | self.trans_layers(model._modules[module_name], update) 62 | else: 63 | if type(getattr(model, module_name)) in self._register_dict: 64 | # use inspect.signature to know args and kwargs of __init__ 65 | _sig = inspect.signature(type(getattr(model, module_name))) 66 | _kwargs = {} 67 | for key in _sig.parameters: 68 | if _sig.parameters[key].default == inspect.Parameter.empty: #args 69 | # assign args 70 | # default values should be handled more properly, unknown data type might be an issue 71 | if 'kernel' in key: 72 | # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=3) 73 | value = 3 74 | elif 'channel' in key: 75 | # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=32) 76 | value = 32 77 | else: 78 | # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=None) 79 | value = None 80 | 81 | _kwargs[key] = value 82 | 83 | _attr_dict = getattr(model, module_name).__dict__ 84 | _layer_new = self._register_dict[type(getattr(model, module_name))](**_kwargs) # placeholder values for required args only; the real attributes are copied below 85 | _layer_new.__dict__.update(_attr_dict) 86 | 87 | setattr(model, module_name, _layer_new) 88 | return model 89 | 90 | 91 | 92 | 93 | def summary(self, model = None, input_tensor = None): 94 | """!
95 | This function acts like the Keras summary() function 96 | 97 | @param model: input model to summarize 98 | 99 | @param input_tensor: input data of the model to forward 100 | 101 | """ 102 | # input_tensor = torch.randn([1, 3, 224, 224]) 103 | # input_tensor = input_tensor.cuda() 104 | 105 | 106 | self._build_graph(model, input_tensor) 107 | 108 | # get dicts and variables 109 | model_graph = self.log.getGraph() 110 | bottoms_graph = self.log.getBottoms() 111 | output_shape_graph = self.log.getOutShapes() 112 | # store top names for bottoms 113 | topNames = OrderedDict() 114 | totoal_trainable_params = 0 115 | total_params = 0 116 | # loop graph 117 | print("##########################################################################################") 118 | line_title = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format("Index","Layer (type)", "Bottoms","Output Shape", "Param #") 119 | print(line_title) 120 | print("---------------------------------------------------------------------------") 121 | 122 | 123 | for layer_index, key in enumerate(model_graph): 124 | 125 | # data layer 126 | if bottoms_graph[key] is None: 127 | # Layer information 128 | layer = model_graph[key] 129 | layer_type = layer.__class__.__name__ 130 | if layer_type == "str": 131 | layer_type = key 132 | else: 133 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index) 134 | 135 | topNames[key] = layer_type 136 | 137 | # Layer Output shape 138 | output_shape = "[{}]".format(tuple(output_shape_graph[key])) 139 | 140 | # Layer Params 141 | param_weight_num = 0 142 | if hasattr(layer, "weight") and hasattr(layer.weight, "size"): 143 | param_weight_num += torch.prod(torch.LongTensor(list(layer.weight.size()))) 144 | if layer.weight.requires_grad: 145 | totoal_trainable_params += param_weight_num 146 | if hasattr(layer, "bias") and hasattr(layer.bias, "size"): 147 | param_bias_num = torch.prod(torch.LongTensor(list(layer.bias.size()))) 148 | param_weight_num += param_bias_num 149 | if layer.bias.requires_grad:
150 | totoal_trainable_params += param_bias_num 151 | total_params += param_weight_num 152 | 153 | new_layer = "{:5}| {:<15} | {:<15} {:<25} {:<15}".format(layer_index, layer_type, "", output_shape, param_weight_num) 154 | print(new_layer) 155 | 156 | else: 157 | # Layer Information 158 | layer = model_graph[key] 159 | layer_type = layer.__class__.__name__ 160 | 161 | # add, sub, mul...,etc. (custom string) 162 | if layer_type == "str": 163 | # the key should be XXX_{idx_prevent_duplicate} 164 | tmp_key = key.rsplit("_", 1) 165 | tmp_key[-1] = str(layer_index) 166 | tmp_key = "_".join(tmp_key) 167 | layer_type = tmp_key 168 | else: 169 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index) 170 | 171 | topNames[key] = layer_type 172 | 173 | # Layer Bottoms 174 | bottoms = [] 175 | for b_key in bottoms_graph[key]: 176 | bottom = topNames[b_key] 177 | bottoms.append(bottom) 178 | 179 | # Layer Output Shape 180 | if key in output_shape_graph: 181 | output_shape = "[{}]".format(tuple(output_shape_graph[key])) 182 | else: 183 | output_shape = "None" 184 | 185 | # Layer Params 186 | param_weight_num = 0 187 | if hasattr(layer, "weight") and hasattr(layer.weight, "size"): 188 | param_weight_num += torch.prod(torch.LongTensor(list(layer.weight.size()))) 189 | if layer.weight.requires_grad: 190 | totoal_trainable_params += param_weight_num 191 | if hasattr(layer, "bias") and hasattr(layer.bias, "size"): 192 | param_bias_num = torch.prod(torch.LongTensor(list(layer.bias.size()))) 193 | param_weight_num += param_bias_num 194 | if layer.bias.requires_grad: 195 | totoal_trainable_params += param_bias_num 196 | total_params += param_weight_num 197 | # Print one bottom per line 198 | for idx, b in enumerate(bottoms): 199 | # for additional bottoms, print only the bottom name 200 | if idx == 0: 201 | new_layer = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format(layer_index, layer_type, b, output_shape, param_weight_num) 202 | else: 203 | new_layer = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format("", "",
b, "", "") 204 | print(new_layer) 205 | print("---------------------------------------------------------------------------") 206 | 207 | 208 | # total information 209 | print("==================================================================================") 210 | print("Total Trainable params: {} ".format(totoal_trainable_params)) 211 | print("Total Non-Trainable params: {} ".format(total_params - totoal_trainable_params)) 212 | print("Total params: {} ".format(total_params)) 213 | 214 | # del model_graph, bottoms_graph, output_shape_graph, topNames 215 | # return model 216 | 217 | def visualize(self, model = None, input_tensor = None, save_name = None, graph_size = 30): 218 | """! 219 | This function visualizes the model architecture 220 | 221 | @param model: input model to visualize 222 | 223 | @param input_tensor: input data of the model to forward 224 | 225 | @param save_name: if save_name is not None, the graph is saved as '{save_name}.png' 226 | 227 | @param graph_size: graph size for graphviz; increase it to raise the resolution of the output graph 228 | 229 | @return dot, graphviz's Digraph element 230 | """ 231 | # input_tensor = torch.randn([1, 3, 224, 224]) 232 | # model_graph = self.log.getGraph() 233 | 234 | # if no model is given 235 | if model is None: 236 | # check whether this transformer has its own modules to use 237 | if len(self._modules) > 0: 238 | self._build_graph(self, input_tensor) 239 | else: 240 | raise ValueError("Please input model to visualize") 241 | else: 242 | self._build_graph(model, input_tensor) 243 | 244 | # graph 245 | node_attr = dict(style='filled', 246 | shape='box', 247 | align='left', 248 | fontsize='30', 249 | ranksep='0.1', 250 | height='0.2') 251 | 252 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="{},{}".format(graph_size, graph_size))) 253 | 254 | # get dicts and variables 255 | model_graph = self.log.getGraph() 256 | bottoms_graph = self.log.getBottoms() 257 | output_shape_graph = self.log.getOutShapes() 258 | topNames = OrderedDict() 259 | 260 | for
layer_index, key in enumerate(model_graph): 261 | # Input Data layer 262 | if bottoms_graph[key] is None: 263 | layer = model_graph[key] 264 | layer_type = layer.__class__.__name__ 265 | # the Data layer is stored as a plain string 266 | if layer_type == "str": 267 | layer_type = key 268 | else: 269 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index) 270 | 271 | 272 | topNames[key] = layer_type 273 | output_shape = "[{}]".format(tuple(output_shape_graph[key])) 274 | layer_type = layer_type + "\nShape: " + output_shape 275 | 276 | dot.node(str(key), layer_type, fillcolor='orange') 277 | else: 278 | # Layer Information 279 | layer = model_graph[key] 280 | layer_type = layer.__class__.__name__ 281 | # add, sub, mul...,etc. (custom string) 282 | if layer_type == "str": 283 | # the key should be XXX_{idx_prevent_duplicate} 284 | tmp_key = key.rsplit("_", 1) 285 | tmp_key[-1] = str(layer_index) 286 | tmp_key = "_".join(tmp_key) 287 | layer_type = tmp_key 288 | else: 289 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index) 290 | 291 | topNames[key] = layer_type 292 | # layer_type = layer_type 293 | # print("Layer: {}".format(layer_type)) 294 | # print("Key: {}".format(key)) 295 | # add bottoms 296 | 297 | layer_type = layer_type + "\nBottoms: " 298 | for b_key in bottoms_graph[key]: 299 | layer_type = layer_type + topNames[b_key] + "\n" 300 | 301 | output_shape = "[{}]".format(tuple(output_shape_graph[key])) if key in output_shape_graph else "None" 302 | layer_type = layer_type + "Shape: " + output_shape 303 | 304 | dot.node(str(key), layer_type, fillcolor='orange') 305 | # link bottoms 306 | # print("Bottoms: ") 307 | for bot_key in bottoms_graph[key]: 308 | # print(bot_key) 309 | dot.edge(str(bot_key), str(key)) 310 | 311 | # return graph 312 | if save_name is not None: 313 | (graph,) = pydot.graph_from_dot_data(dot.source) 314 | graph.write_png(save_name + ".png" ) 315 | return dot 316 | 317 | def _build_graph(self,
model, input_tensor = None): 318 | 319 | if input_tensor is None: 320 | raise ValueError("Please set input tensor") 321 | 322 | # reset log 323 | self.log = Log() 324 | # add Data input 325 | self.log.setTensor(input_tensor) 326 | 327 | 328 | tmp_model = self._trans_unit(copy.deepcopy(model)) 329 | 330 | for f in dir(torch): 331 | 332 | # if private function, pass 333 | if f.startswith("_") or "tensor" == f: 334 | continue 335 | if isinstance(getattr(torch, f) ,types.BuiltinMethodType) or isinstance(getattr(torch, f) ,types.BuiltinFunctionType): 336 | self._raw_TrochFuncs[f] = getattr(torch, f) 337 | setattr(torch, f, _ReplaceFunc(getattr(torch,f), self._torchFunctions)) 338 | 339 | for f in dir(F): 340 | # if private function, pass 341 | if f.startswith("_"): 342 | continue 343 | 344 | if isinstance(getattr(F, f) ,types.BuiltinMethodType) or isinstance(getattr(F, f) ,types.BuiltinFunctionType) or isinstance(getattr(F, f) ,types.FunctionType): 345 | self._raw_TrochFunctionals[f] = getattr(F, f) 346 | setattr(F, f, _ReplaceFunc(getattr(F,f), self._torchFunctionals)) 347 | 348 | 349 | self.log = tmp_model.forward(self.log) 350 | 351 | # reset back 352 | for f in self._raw_TrochFuncs: 353 | setattr(torch, f, self._raw_TrochFuncs[f]) 354 | 355 | for f in self._raw_TrochFunctionals: 356 | setattr(F, f, self._raw_TrochFunctionals[f]) 357 | 358 | del tmp_model 359 | 360 | def _trans_unit(self, model): 361 | # print("TRNS_UNIT") 362 | for module_name in model._modules: 363 | # has children 364 | if len(model._modules[module_name]._modules) > 0: 365 | self._trans_unit(model._modules[module_name]) 366 | else: 367 | unitlayer = UnitLayer(getattr(model, module_name)) 368 | setattr(model, module_name, unitlayer) 369 | 370 | return model 371 | 372 | def _torchFunctions(self, raw_func, *args, **kwargs): 373 | """! 
374 | The replaced torch function (e.g. torch.{function}) will go here 375 | """ 376 | # print("Torch function") 377 | function_name = raw_func.__name__ 378 | 379 | # a torch function may have no input 380 | # so check first 381 | 382 | if len(args) > 0: 383 | logs = args[0] 384 | cur_args = args[1:] 385 | elif len(kwargs) > 0: 386 | 387 | return raw_func(**kwargs) 388 | else: 389 | return raw_func() 390 | 391 | # check whether the inputs are plain tensors (user code or an internal torch call) or Logs 392 | is_tensor_in = False 393 | # tensor input 394 | # multiple tensor inputs 395 | if isinstance(logs, (tuple, list)) and len(logs) > 0 and isinstance(logs[0], torch.Tensor): 396 | cur_inputs = logs 397 | is_tensor_in = True 398 | return raw_func(*args, **kwargs) 399 | # single tensor input 400 | elif (type(logs) == torch.Tensor): 401 | 402 | cur_inputs = logs 403 | is_tensor_in = True 404 | # print(*args) 405 | # print(**kwargs) 406 | return raw_func(*args, **kwargs) 407 | elif (type(logs) == nn.Parameter): 408 | cur_inputs = logs 409 | is_tensor_in = True 410 | return raw_func(*args, **kwargs) 411 | # log input 412 | else: 413 | # multiple inputs 414 | bottoms = [] 415 | cur_inputs = [] 416 | 417 | if isinstance(logs, tuple) or isinstance(logs, list): 418 | # the original input log may be reused as another layer's input 419 | # e.g. densenet in torchvision 0.4.0 420 | cur_log = copy.deepcopy(logs[0]) 421 | for log in logs: 422 | cur_inputs.append(log.cur_tensor) 423 | # print(log.cur_tensor.size()) 424 | bottoms.append(log.cur_id) 425 | # merge graph information 426 | cur_log.graph.update(log.graph) 427 | cur_log.bottoms.update(log.bottoms) 428 | cur_log.output_shape.update(log.output_shape) 429 | cur_inputs = tuple(cur_inputs) 430 | # one input 431 | else: 432 | # print(args) 433 | # print(kwargs) 434 | cur_log = logs 435 | cur_inputs = cur_log.cur_tensor 436 | bottoms.append(cur_log.cur_id) 437 | 438 | # replace the logs with their tensors to compute the output tensor 439 | args = list(args) 440 | args[0] = cur_inputs 441 | args = tuple(args) 442 | # send into origin
functions 443 | #out_tensor = raw_func(*args, **kwargs).clone().detach() 444 | out_tensor = raw_func(*args, **kwargs) 445 | out_tensor = out_tensor.clone() if isinstance(out_tensor, torch.Tensor) else out_tensor 446 | # if called with plain tensors, just return the output tensor 447 | if is_tensor_in: 448 | return out_tensor 449 | 450 | # most multi-input operations produce one output 451 | # most multi-output operations take one input 452 | # and some operations change the shape 453 | # store these types of operation as a layer 454 | if isinstance(logs, tuple) or isinstance(logs, list) or isinstance(out_tensor, tuple) or (logs.cur_tensor.size() != out_tensor.size()): 455 | layer_name = "torch.{}_{}".format(function_name, len(cur_log.graph)) 456 | cur_log.graph[layer_name] = layer_name 457 | cur_log.bottoms[layer_name] = bottoms 458 | cur_log.cur_id = layer_name 459 | 460 | # multiple outputs 461 | if not isinstance(out_tensor , torch.Tensor): 462 | # print("multi output") 463 | out_logs = [] 464 | for t in out_tensor: 465 | out_log = copy.deepcopy(cur_log) 466 | out_log.setTensor(t) 467 | out_logs.append(out_log) 468 | 469 | # a single-element output such as (out, ) is unwrapped to one log 470 | if len(out_logs) == 1: 471 | out_logs = out_logs[0] 472 | return out_logs 473 | 474 | else: 475 | cur_log.setTensor(out_tensor) 476 | return cur_log 477 | 478 | # torch.functionals 479 | def _torchFunctionals(self, raw_func, *args, **kwargs): 480 | """!
481 | The replaced torch.functional function (e.g. F.{function}) will go here 482 | """ 483 | # print("Functional") 484 | function_name = raw_func.__name__ 485 | # print(raw_func.__name__) 486 | 487 | # every functional takes an input except affine_grid, so call affine_grid directly 488 | if function_name == "affine_grid": 489 | return raw_func(*args, **kwargs) 490 | else: 491 | logs = args[0] 492 | cur_args = args[1:] 493 | 494 | # check whether the inputs are plain tensors (user code or an internal torch call) or Logs 495 | is_tensor_in = False 496 | # tensor input 497 | if (len(logs) > 1) and (type(logs[0]) == torch.Tensor): 498 | # print(logs[0].size(), logs[1].size()) 499 | cur_inputs = logs 500 | is_tensor_in = True 501 | out = raw_func(*args, **kwargs) 502 | # print("Functional return : {}".format(out.size())) 503 | return out 504 | 505 | elif (len(logs) ==1) and (type(logs) == torch.Tensor): 506 | cur_inputs = logs 507 | is_tensor_in = True 508 | out = raw_func(*args, **kwargs) 509 | # print("Functional return : {}".format(out.size())) 510 | return out 511 | 512 | # log input 513 | else: 514 | # multiple inputs 515 | bottoms = [] 516 | cur_inputs = [] 517 | if len(logs) > 1: 518 | cur_log = logs[0] 519 | for log in logs: 520 | cur_inputs.append(log.cur_tensor) 521 | bottoms.append(log.cur_id) 522 | # merge graph information 523 | cur_log.graph.update(log.graph) 524 | cur_log.bottoms.update(log.bottoms) 525 | cur_log.output_shape.update(log.output_shape) 526 | cur_inputs = tuple(cur_inputs) 527 | # one input 528 | else: 529 | cur_log = logs 530 | cur_inputs = cur_log.cur_tensor 531 | bottoms.append(cur_log.cur_id) 532 | 533 | 534 | 535 | # replace the logs with their tensors to compute the output tensor 536 | args = list(args) 537 | args[0] = cur_inputs 538 | args = tuple(args) 539 | # send into origin functions 540 | #out_tensor = raw_func(*args, **kwargs).clone().detach() 541 | out_tensor = raw_func(*args, **kwargs) 542 | out_tensor = out_tensor.clone() if isinstance(out_tensor, torch.Tensor) else out_tensor 543 | # if called with plain tensors, just return the output tensor 544 | if is_tensor_in: 545 | return out_tensor 546 | 547
| # if the input is a log and raw_func is a Python function, store it as a layer 548 | if isinstance(raw_func, types.FunctionType): 549 | # combine several addresses in the name to prevent duplicate names 550 | layer_name = "F.{}_{}{}{}".format(function_name, id(out_tensor), id(args), id(kwargs)) 551 | # generate a new name while it is still a duplicate 552 | while layer_name in cur_log.graph: 553 | #if layer_name in cur_log.graph: 554 | # tmp_list = [] 555 | # tmp_list.append(out_tensor) 556 | # tmp_tensor = copy.deepcopy(tmp_list[-1]) 557 | # tmp_tensor = tmp_list[-1].clone() 558 | tmp_tensor = torch.tensor([0]) 559 | 560 | # the timestamp suffix makes a further collision unlikely 561 | # layer_name = layer_name.split('.')[0] + "F" + ".{}_{}{}{}".format(function_name, id(tmp_tensor), id(args), id(kwargs)) 562 | layer_name = "F.{}_{}{}{}{}".format(function_name, id(tmp_tensor), id(args), id(kwargs), int((time.time()*100000)%1000000)) 563 | 564 | cur_log.graph[layer_name] = layer_name 565 | cur_log.bottoms[layer_name] = bottoms 566 | cur_log.cur_id = layer_name 567 | 568 | # multiple outputs 569 | # if len(out_tensor) > 1: 570 | if not isinstance(out_tensor, torch.Tensor): 571 | out_logs = [] 572 | for t in out_tensor: 573 | out_log = copy.deepcopy(cur_log) 574 | out_log.setTensor(t) 575 | out_logs.append(out_log) 576 | 577 | return out_logs 578 | else: 579 | cur_log.setTensor(out_tensor) 580 | return cur_log 581 | 582 | 583 | -------------------------------------------------------------------------------- /transformers/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | 7 | class _ReplaceFunc(object): 8 | """! 9 | This class replaces a torch function with a self-defined function, 10 | in order to collect the layer information of the torch model.
11 | """ 12 | def __init__(self, ori_func, replace_func, **kwargs): 13 | self.torch_func = ori_func 14 | self.replace_func = replace_func 15 | 16 | def __call__(self, *args, **kwargs): 17 | out = self.replace_func(self.torch_func, *args, **kwargs) 18 | return out 19 | 20 | 21 | class Log(object): 22 | """! 23 | This class is used as a log that replaces the input tensor and stores all the graph information 24 | """ 25 | def __init__(self): 26 | self.graph = OrderedDict() 27 | self.bottoms = OrderedDict() 28 | self.output_shape = OrderedDict() 29 | self.cur_tensor = None 30 | self.cur_id = None 31 | self.tmp_list = None 32 | self.log_init() 33 | 34 | def __len__(self): 35 | """! 36 | A Log always has length 1, so it is treated as a single input 37 | """ 38 | return 1 39 | 40 | def __copy__(self): 41 | """! 42 | copy: create a new Log that shares the graph dictionaries but holds a clone of the current tensor 43 | """ 44 | copy_paster = Log() 45 | copy_paster.__dict__.update(self.__dict__) 46 | copy_paster.cur_tensor = self.cur_tensor.clone() 47 | return copy_paster 48 | 49 | def __deepcopy__(self, memo): 50 | """! 51 | deepcopy: create a new Log that shares the graph dictionaries but holds a clone of the current tensor 52 | """ 53 | copy_paster = Log() 54 | copy_paster.__dict__.update(self.__dict__) 55 | copy_paster.cur_tensor = self.cur_tensor.clone() 56 | return copy_paster 57 | 58 | def reset(self): 59 | """ 60 | This function resets all attributes of the log. 61 | """ 62 | self.graph = OrderedDict() 63 | self.bottoms = OrderedDict() 64 | self.output_shape = OrderedDict() 65 | self.cur_tensor = None 66 | self.cur_id = None 67 | self.tmp_list = [] 68 | self.log_init() 69 | 70 | 71 | # add the data input layer to the log 72 | def log_init(self): 73 | """ 74 | Initialize the log attributes and set the Data layer as the first layer 75 | """ 76 | layer_id = "Data" 77 | self.graph[layer_id] = layer_id 78 | self.bottoms[layer_id] = None 79 | self.output_shape[layer_id] = "" 80 | self.cur_id = layer_id 81 | self.tmp_list = [] 82 | 83 | 84 | # for a general layer (should it have only one input?) 85 | def putLayer(self, layer): 86 | """!
87 | Put a general layer's information into the log 88 | """ 89 | # force a distinct address id (prevents reusing the same defined layer more than once, e.g. bottleneck in torchvision) 90 | # tmp_layer = copy.deepcopy(layer) 91 | layer_id = id(layer) 92 | self.tmp_list.append(layer) 93 | layer_id = id(self.tmp_list[-1]) 94 | if layer_id in self.graph: 95 | tmp_layer = copy.deepcopy(layer) 96 | self.tmp_list.append(tmp_layer) 97 | # layer_id = id(self.tmp_list[-1]) 98 | layer_id = id(tmp_layer) 99 | 100 | self.graph[layer_id] = layer 101 | self.bottoms[layer_id] = [self.cur_id] 102 | self.cur_id = layer_id 103 | # del layer, tmp_layer, layer_id 104 | 105 | def getGraph(self): 106 | """! 107 | This function gets the layer graph from the log 108 | """ 109 | return self.graph 110 | 111 | def getBottoms(self): 112 | """! 113 | This function gets the layer bottoms from the log 114 | """ 115 | return self.bottoms 116 | 117 | def getOutShapes(self): 118 | """! 119 | This function gets the layer output shapes from the log 120 | """ 121 | return self.output_shape 122 | 123 | def getTensor(self): 124 | """! 125 | This function gets the current (output) tensor of the log 126 | """ 127 | return self.cur_tensor 128 | 129 | def setTensor(self, tensor): 130 | """! 131 | This function sets the log's current tensor 132 | and also records its shape as the current layer's output shape 133 | """ 134 | self.cur_tensor = tensor 135 | if tensor is not None: 136 | self.output_shape[self.cur_id] = self.cur_tensor.size() 137 | else: 138 | self.output_shape[self.cur_id] = None 139 | 140 | 141 | # handle tensor operations (e.g. tensor.view) 142 | def __getattr__(self, name): 143 | """!
144 | This function handles all tensor operations by applying them to the current tensor 145 | """ 146 | if name == "__deepcopy__" or name == "__setstate__": 147 | return object.__getattribute__(self, name) 148 | # accessing .data returns cur_tensor.data 149 | elif name == "data": 150 | return self.cur_tensor.data 151 | 152 | elif hasattr(self.cur_tensor, name): 153 | def wrapper(*args, **kwargs): 154 | func = self.cur_tensor.__getattribute__(name) 155 | out_tensor = func(*args, **kwargs) 156 | 157 | if not isinstance(out_tensor, torch.Tensor): 158 | out_logs = [] 159 | for t in out_tensor: 160 | out_log = copy.deepcopy(self) 161 | out_log.setTensor(t) 162 | out_logs.append(out_log) 163 | 164 | return out_logs 165 | else: 166 | self.cur_tensor = out_tensor 167 | self.output_shape[self.cur_id] = out_tensor.size() 168 | 169 | return self 170 | # print(wrapper) 171 | return wrapper 172 | 173 | # return self 174 | 175 | 176 | else: 177 | return object.__getattribute__(self, name) 178 | 179 | 180 | def __add__(self, other): 181 | """! 182 | Log addition 183 | """ 184 | #print("add") 185 | # merge other branch 186 | self.graph.update(other.graph) 187 | self.bottoms.update(other.bottoms) 188 | self.output_shape.update(other.output_shape) 189 | layer_name = "add_{}".format(len(self.graph)) 190 | self.graph[layer_name] = layer_name 191 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 192 | self.output_shape[layer_name] = self.cur_tensor.size() 193 | self.cur_id = layer_name 194 | # save memory 195 | del other 196 | 197 | return self 198 | 199 | 200 | def __iadd__(self, other): 201 | """!
202 | Log in-place addition 203 | """ 204 | #print("iadd") 205 | # merge other branch 206 | self.graph.update(other.graph) 207 | self.bottoms.update(other.bottoms) 208 | self.output_shape.update(other.output_shape) 209 | layer_name = "iadd_{}".format(len(self.graph)) 210 | self.graph[layer_name] = layer_name 211 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 212 | self.output_shape[layer_name] = self.cur_tensor.size() 213 | self.cur_id = layer_name 214 | # save memory 215 | del other 216 | return self 217 | 218 | 219 | def __sub__(self, other): 220 | """! 221 | Log subtraction 222 | """ 223 | #print("sub") 224 | # merge other branch 225 | self.graph.update(other.graph) 226 | self.bottoms.update(other.bottoms) 227 | self.output_shape.update(other.output_shape) 228 | layer_name = "sub_{}".format(len(self.graph)) 229 | self.graph[layer_name] = layer_name 230 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 231 | self.output_shape[layer_name] = self.cur_tensor.size() 232 | self.cur_id = layer_name 233 | # save memory 234 | del other 235 | return self 236 | 237 | 238 | def __isub__(self, other): 239 | """! 240 | Log in-place subtraction 241 | """ 242 | #print("isub") 243 | # merge other branch 244 | self.graph.update(other.graph) 245 | self.bottoms.update(other.bottoms) 246 | self.output_shape.update(other.output_shape) 247 | layer_name = "isub_{}".format(len(self.graph)) 248 | self.graph[layer_name] = layer_name 249 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 250 | self.output_shape[layer_name] = self.cur_tensor.size() 251 | self.cur_id = layer_name 252 | # save memory 253 | del other 254 | return self 255 | 256 | 257 | def __mul__(self, other): 258 | """!
259 | Log multiplication 260 | """ 261 | #print("mul") 262 | # merge other branch 263 | self.graph.update(other.graph) 264 | self.bottoms.update(other.bottoms) 265 | self.output_shape.update(other.output_shape) 266 | layer_name = "mul_{}".format(len(self.graph)) 267 | self.graph[layer_name] = layer_name 268 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 269 | self.output_shape[layer_name] = self.cur_tensor.size() 270 | self.cur_id = layer_name 271 | # save memory 272 | del other 273 | return self 274 | 275 | 276 | def __imul__(self, other): 277 | """! 278 | Log in-place multiplication 279 | """ 280 | #print("imul") 281 | # merge other branch 282 | self.graph.update(other.graph) 283 | self.bottoms.update(other.bottoms) 284 | self.output_shape.update(other.output_shape) 285 | layer_name = "imul_{}".format(len(self.graph)) 286 | self.graph[layer_name] = layer_name 287 | self.bottoms[layer_name] = [self.cur_id, other.cur_id] 288 | self.output_shape[layer_name] = self.cur_tensor.size() 289 | self.cur_id = layer_name 290 | # save memory 291 | del other 292 | return self 293 | 294 | 295 | def size(self, dim=None): 296 | """! 297 | This function returns the size of the current tensor, optionally along the given dim 298 | 299 | @param dim: default None; if given, returns tensor.size(dim) 300 | 301 | @return the tensor size (along dim if given) 302 | """ 303 | return self.cur_tensor.size(dim) if dim is not None else self.cur_tensor.size() 304 | 305 | 306 | 307 | class UnitLayer(nn.Module): 308 | """! 309 | This class wraps an original layer: its forward pass runs that layer and records it in the log 310 | """ 311 | def __init__(self, ori_layer): 312 | super(UnitLayer, self).__init__() 313 | self.origin_layer = ori_layer 314 | 315 | 316 | def setOrigin(self, ori_layer): 317 | self.origin_layer = ori_layer 318 | 319 | 320 | # a general layer should have only one input?
321 | def forward(self, log, *args): 322 | # prevent overwrite log for other forward flow 323 | cur_log = copy.deepcopy(log) 324 | # print(cur_log) 325 | cur_log.putLayer(self.origin_layer) 326 | 327 | # print(log.cur_tensor) 328 | log_tensor = log.getTensor() 329 | # out_tensor = self.origin_layer(log_tensor).clone().detach() 330 | out_tensor = self.origin_layer(log_tensor).clone() 331 | cur_log.setTensor(out_tensor) 332 | 333 | return cur_log -------------------------------------------------------------------------------- /visualize_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torchvision\n", 12 | "import torchvision.models as models\n", 13 | "\n", 14 | "from transformers.torchTransformer import TorchTransformer" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "scrolled": true 22 | }, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/plain": [ 27 | "AlexNet(\n", 28 | " (features): Sequential(\n", 29 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", 30 | " (1): ReLU(inplace=True)\n", 31 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 32 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 33 | " (4): ReLU(inplace=True)\n", 34 | " (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 35 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 36 | " (7): ReLU(inplace=True)\n", 37 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 38 | " (9): ReLU(inplace=True)\n", 39 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 40 | " (11): ReLU(inplace=True)\n", 41 | " 
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 42 | " )\n", 43 | " (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n", 44 | " (classifier): Sequential(\n", 45 | " (0): Dropout(p=0.5, inplace=False)\n", 46 | " (1): Linear(in_features=9216, out_features=4096, bias=True)\n", 47 | " (2): ReLU(inplace=True)\n", 48 | " (3): Dropout(p=0.5, inplace=False)\n", 49 | " (4): Linear(in_features=4096, out_features=4096, bias=True)\n", 50 | " (5): ReLU(inplace=True)\n", 51 | " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", 52 | " )\n", 53 | ")" 54 | ] 55 | }, 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "model_name = \"alexnet\"\n", 63 | "model = models.__dict__[model_name]()\n", 64 | "model.eval()" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "input_tensor = torch.randn([1, 3, 224, 224])" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "## visualization" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "### without saving image" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "transformer = TorchTransformer()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "dot = transformer.visualize(model, input_tensor = input_tensor)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": { 112 | "scrolled": true 113 | }, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "image/svg+xml": [ 118 | "\n", 119 | "\n", 121 | "\n", 123 | "\n", 124 | "\n", 126 | "\n", 127 | "%3\n", 128 | "\n", 129 | "\n", 130 | "\n", 131 | "Data\n", 132 | "\n", 133 | "Data\n", 
134 | "Shape: [(1, 3, 224, 224)]\n", 135 | "\n", 136 | "\n", 137 | "\n", 138 | "140437396772176\n", 139 | "\n", 140 | "Conv2d_1\n", 141 | "Bottoms: Data\n", 142 | "Shape: [(1, 64, 55, 55)]\n", 143 | "\n", 144 | "\n", 145 | "\n", 146 | "Data->140437396772176\n", 147 | "\n", 148 | "\n", 149 | "\n", 150 | "\n", 151 | "\n", 152 | "140437396773744\n", 153 | "\n", 154 | "ReLU_2\n", 155 | "Bottoms: Conv2d_1\n", 156 | "Shape: [(1, 64, 55, 55)]\n", 157 | "\n", 158 | "\n", 159 | "\n", 160 | "140437396772176->140437396773744\n", 161 | "\n", 162 | "\n", 163 | "\n", 164 | "\n", 165 | "\n", 166 | "140437394718280\n", 167 | "\n", 168 | "MaxPool2d_3\n", 169 | "Bottoms: ReLU_2\n", 170 | "Shape: [(1, 64, 27, 27)]\n", 171 | "\n", 172 | "\n", 173 | "\n", 174 | "140437396773744->140437394718280\n", 175 | "\n", 176 | "\n", 177 | "\n", 178 | "\n", 179 | "\n", 180 | "140437394761432\n", 181 | "\n", 182 | "Conv2d_4\n", 183 | "Bottoms: MaxPool2d_3\n", 184 | "Shape: [(1, 192, 27, 27)]\n", 185 | "\n", 186 | "\n", 187 | "\n", 188 | "140437394718280->140437394761432\n", 189 | "\n", 190 | "\n", 191 | "\n", 192 | "\n", 193 | "\n", 194 | "140437394760760\n", 195 | "\n", 196 | "ReLU_5\n", 197 | "Bottoms: Conv2d_4\n", 198 | "Shape: [(1, 192, 27, 27)]\n", 199 | "\n", 200 | "\n", 201 | "\n", 202 | "140437394761432->140437394760760\n", 203 | "\n", 204 | "\n", 205 | "\n", 206 | "\n", 207 | "\n", 208 | "140437394813616\n", 209 | "\n", 210 | "MaxPool2d_6\n", 211 | "Bottoms: ReLU_5\n", 212 | "Shape: [(1, 192, 13, 13)]\n", 213 | "\n", 214 | "\n", 215 | "\n", 216 | "140437394760760->140437394813616\n", 217 | "\n", 218 | "\n", 219 | "\n", 220 | "\n", 221 | "\n", 222 | "140437394834040\n", 223 | "\n", 224 | "Conv2d_7\n", 225 | "Bottoms: MaxPool2d_6\n", 226 | "Shape: [(1, 384, 13, 13)]\n", 227 | "\n", 228 | "\n", 229 | "\n", 230 | "140437394813616->140437394834040\n", 231 | "\n", 232 | "\n", 233 | "\n", 234 | "\n", 235 | "\n", 236 | "140437394857768\n", 237 | "\n", 238 | "ReLU_8\n", 239 | "Bottoms: Conv2d_7\n", 
240 | "[Graphviz SVG output, condensed: the rendered AlexNet graph continues ReLU_8 (1, 384, 13, 13) → Conv2d_9 (1, 256, 13, 13) → ReLU_10 (1, 256, 13, 13) → Conv2d_11 (1, 256, 13, 13) → ReLU_12 (1, 256, 13, 13) → MaxPool2d_13 (1, 256, 6, 6) → AdaptiveAvgPool2d_14 (1, 256, 6, 6) → torch.flatten_15 (1, 9216) → Dropout_16 (1, 9216) → Linear_17 (1, 4096) → ReLU_18 (1, 4096) → Dropout_19 (1, 4096) → Linear_20 (1, 4096) → ReLU_21 (1, 4096) → Linear_22 (1, 1000); each node is labelled with its layer name, its bottoms, and its output shape]\n" 446 | ], 447 | "text/plain": [ 448 | "" 449 | ] 450 | }, 451 | "execution_count": 6, 452 | "metadata": {}, 453 |
"output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "dot" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "### save image" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 7, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "model.cuda()\n", 474 | "input_tensor = input_tensor.cuda()" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 8, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "dot = transformer.visualize(model, input_tensor=input_tensor, save_name=\"examples/{}\".format(model_name), graph_size=80)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [] 492 | } 493 | ], 494 | "metadata": { 495 | "kernelspec": { 496 | "display_name": "Python 3", 497 | "language": "python", 498 | "name": "python3" 499 | }, 500 | "language_info": { 501 | "codemirror_mode": { 502 | "name": "ipython", 503 | "version": 3 504 | }, 505 | "file_extension": ".py", 506 | "mimetype": "text/x-python", 507 | "name": "python", 508 | "nbconvert_exporter": "python", 509 | "pygments_lexer": "ipython3", 510 | "version": "3.6.9" 511 | } 512 | }, 513 | "nbformat": 4, 514 | "nbformat_minor": 2 515 | } 516 | --------------------------------------------------------------------------------
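The output shapes that the visualization reports for AlexNet (for example `(1, 256, 6, 6)` after `MaxPool2d_13` and `(1, 9216)` after `torch.flatten_15`) can be checked by hand. Below is a minimal sketch, assuming torchvision's AlexNet hyperparameters for the `features` block (kernel, stride, padding) and PyTorch's output-size formula for `Conv2d`/`MaxPool2d` with dilation 1 and `ceil_mode=False`. The helpers `out_size` and `trace_shapes` are hypothetical names for illustration and are not part of this repository.

```python
# Sketch: reproduce the spatial sizes shown in the AlexNet graph using the
# Conv2d / MaxPool2d output-size formula (dilation=1, ceil_mode=False).


def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output height/width of a Conv2d or MaxPool2d layer along one axis."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (layer type, out_channels or None for pooling, kernel, stride, padding)
# ASSUMPTION: matches torchvision's AlexNet `features` block. ReLU layers are
# omitted because they do not change the tensor shape.
ALEXNET_FEATURES = [
    ("Conv2d", 64, 11, 4, 2),
    ("MaxPool2d", None, 3, 2, 0),
    ("Conv2d", 192, 5, 1, 2),
    ("MaxPool2d", None, 3, 2, 0),
    ("Conv2d", 384, 3, 1, 1),
    ("Conv2d", 256, 3, 1, 1),
    ("Conv2d", 256, 3, 1, 1),
    ("MaxPool2d", None, 3, 2, 0),
]


def trace_shapes(size: int = 224, channels: int = 3):
    """Return per-layer (name, NCHW shape) pairs and the flattened feature count."""
    shapes = []
    for name, out_ch, kernel, stride, padding in ALEXNET_FEATURES:
        size = out_size(size, kernel, stride, padding)
        if out_ch is not None:  # pooling keeps the channel count
            channels = out_ch
        shapes.append((name, (1, channels, size, size)))
    # AdaptiveAvgPool2d((6, 6)) is a no-op here because the map is already 6x6.
    return shapes, channels * size * size


if __name__ == "__main__":
    shapes, flat = trace_shapes()
    for name, shape in shapes:
        print(f"{name:10s} {shape}")
    print("flatten   ", (1, flat))
```

The last `MaxPool2d` turns 13×13 into 6×6, so flattening yields 256 × 6 × 6 = 9216 features, which is the input width of `Linear_17` in the graph.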