├── .gitignore
├── README.md
├── examples
│   ├── alexnet.png
│   ├── googlenet.png
│   ├── inception_v3.png
│   ├── mobilenet_v2.png
│   ├── resnet18.png
│   ├── shufflenet_v2_x1_0.png
│   ├── squeezenet1_0.png
│   └── vgg16.png
├── summary_example.ipynb
├── test.py
├── transform_example.ipynb
├── transformers
│   ├── __init__.py
│   ├── quantize.py
│   ├── torchTransformer.py
│   └── utils.py
└── visualize_example.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | .ipynb_checkpoints
4 | *.py[cod]
5 | *$py.class
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTransformer
2 |
3 |
4 |
5 | ## summary
6 | This repository implements a summary function similar to Keras's `summary()`.
7 |
8 | ```
9 | model = nn.Sequential(
10 | nn.Conv2d(3,20,5),
11 | nn.ReLU(),
12 | nn.Conv2d(20,64,5),
13 | nn.ReLU()
14 | )
15 |
16 | model.eval()
17 |
18 | transformer = TorchTransformer()
19 | input_tensor = torch.randn([1, 3, 224, 224])
20 | net = transformer.summary(model, input_tensor)
21 |
22 | ##########################################################################################
23 | Index| Layer (type) | Bottoms Output Shape Param #
24 | ---------------------------------------------------------------------------
25 | 1| Data | [(1, 3, 224, 224)] 0
26 | ---------------------------------------------------------------------------
27 | 2| Conv2d_1 | Data [(1, 20, 220, 220)] 1500
28 | ---------------------------------------------------------------------------
29 | 3| ReLU_2 | Conv2d_1 [(1, 20, 220, 220)] 0
30 | ---------------------------------------------------------------------------
31 | 4| Conv2d_3 | ReLU_2 [(1, 64, 216, 216)] 32000
32 | ---------------------------------------------------------------------------
33 | 5| ReLU_4 | Conv2d_3 [(1, 64, 216, 216)] 0
34 | ---------------------------------------------------------------------------
35 | ==================================================================================
36 | Total Trainable params: 33500
37 | Total Non-Trainable params: 0
38 | Total params: 33500
39 | ```
40 |
41 | Another example is in [summary_example.ipynb](summary_example.ipynb).
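
The shapes and parameter counts in the table above can be checked by hand. The sketch below is not part of the repository. It assumes stride 1, no padding, and no dilation (the `nn.Conv2d` defaults), and it assumes that the `Param #` column counts weights only, which is what the 1500 and 32000 figures suggest (with biases they would be 1520 and 32064). The helper names are hypothetical.

```python
def conv2d_out_size(in_size, kernel, stride=1, padding=0, dilation=1):
    """Output size of a 2D convolution along one spatial dimension."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_weight_params(in_ch, out_ch, kernel):
    """Number of weights in a square-kernel convolution (bias excluded)."""
    return in_ch * out_ch * kernel * kernel


# Spatial size after each 5x5 convolution on a 224x224 input.
h1 = conv2d_out_size(224, 5)
h2 = conv2d_out_size(h1, 5)

# Weight counts for Conv2d(3, 20, 5) and Conv2d(20, 64, 5).
p1 = conv2d_weight_params(3, 20, 5)
p2 = conv2d_weight_params(20, 64, 5)
total = p1 + p2

print(h1, h2, p1, p2, total)
```

These values line up with the `Output Shape` and `Param #` columns and with the reported total of 33500.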
42 |
43 | ## visualize
44 | Visualize the model architecture using [graphviz](https://graphviz.readthedocs.io/en/stable/) and [pydot](https://pypi.org/project/pydot/).
45 | The output is an image of the network graph.
46 | For example, with AlexNet from torchvision:
47 | ```
48 | model = models.__dict__["alexnet"]()
49 | model.eval()
50 | transformer = TorchTransformer()
51 | transformer.visualize(model, save_name="example", graph_size=80)
52 | # graph_size controls the size of the output graph.
53 | # graphviz does not automatically scale the image to fit the model's layers,
54 | # so a deep model may be rendered too small to read.
55 | # Increase graph_size to enlarge the image and get a higher resolution.
56 | ```
57 |
58 |
59 | An example is in [visualize_example.ipynb](visualize_example.ipynb).
60 | Other example images are in [examples](/examples).
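
graphviz renders graphs described in the DOT language. As a rough illustration of the kind of graph description behind such an image (this is not the repository's implementation), the sketch below builds DOT source by hand from a list of layers and their bottoms, mirroring the `Layer (type)` and `Bottoms` columns of the summary table. The `layers` list and the `to_dot` helper are hypothetical.

```python
def to_dot(layers, graph_name="model"):
    """Build DOT source from (layer_name, [bottom_names]) pairs."""
    lines = [f"digraph {graph_name} {{", "    node [shape=box];"]
    for name, bottoms in layers:
        # Declare the node, then draw one edge per input (bottom).
        lines.append(f'    "{name}";')
        for bottom in bottoms:
            lines.append(f'    "{bottom}" -> "{name}";')
    lines.append("}")
    return "\n".join(lines)


# Hypothetical layer list matching the small Sequential model in the summary section.
layers = [
    ("Data", []),
    ("Conv2d_1", ["Data"]),
    ("ReLU_2", ["Conv2d_1"]),
    ("Conv2d_3", ["ReLU_2"]),
    ("ReLU_4", ["Conv2d_3"]),
]

dot_source = to_dot(layers)
print(dot_source)
```

Saving the printed text to `model.dot` and running `dot -Tpng model.dot -o model.png` should render it, assuming the graphviz command-line tools are installed.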
61 |
62 | ## transform layers
63 | You can register layer types to be transformed.
64 | First register the layer types with the transformer; the transformer will then transform the layers you registered.
65 |
66 | An example is in [transform_example.ipynb](transform_example.ipynb).
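
The register-then-transform idea can be illustrated without PyTorch. The sketch below is a toy, not the `TorchTransformer` API (the notebook shows the real usage): `LayerRegistry`, `Block`, `Conv`, `QuantConv`, and `ReLU` are hypothetical stand-ins. A registry maps a source layer type to a replacement factory, and a recursive walk swaps every layer whose type was registered.

```python
class Conv:
    def __init__(self, in_ch, out_ch):
        self.in_ch, self.out_ch = in_ch, out_ch


class QuantConv(Conv):
    """Stand-in for a replacement layer type."""


class ReLU:
    pass


class Block:
    """Stand-in for a container of layers or nested blocks."""

    def __init__(self, *children):
        self.children = list(children)


class LayerRegistry:
    def __init__(self):
        self._rules = {}

    def register(self, src_type, factory):
        # factory receives the old layer and returns its replacement.
        self._rules[src_type] = factory

    def transform(self, block):
        for i, child in enumerate(block.children):
            if isinstance(child, Block):
                self.transform(child)  # recurse into nested containers
            elif type(child) in self._rules:
                # Exact type match, so already-replaced layers are left alone.
                block.children[i] = self._rules[type(child)](child)
        return block


registry = LayerRegistry()
registry.register(Conv, lambda old: QuantConv(old.in_ch, old.out_ch))

model = Block(Conv(3, 20), ReLU(), Block(Conv(20, 64), ReLU()))
registry.transform(model)

names = [type(c).__name__ for c in model.children]
inner = [type(c).__name__ for c in model.children[2].children]
print(names, inner)
```

Only the registered type is replaced; unregistered layers such as `ReLU` pass through unchanged.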
67 |
68 |
69 |
70 |
71 | ## Note
72 | Models whose layers have many inputs are best avoided, because graphviz may generate the image slowly (e.g. densenet161 in torchvision 0.4.0 may hang when generating the PNG).
73 |
74 | ## TODO
75 | - [x] support registration (replacement) for custom layer types
76 | - [ ] support replacing a specified layer in the model with another specified layer
77 | - [x] activation size calculation for supported layers
78 | - [x] network summary output as in Keras
79 | - [x] model graph visualization
80 | - [ ] replace multiple modules with one module
81 | - [ ] conditional module replacement
82 | - [ ] add additional modules to the forward graph
83 |
--------------------------------------------------------------------------------
/examples/alexnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/alexnet.png
--------------------------------------------------------------------------------
/examples/googlenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/googlenet.png
--------------------------------------------------------------------------------
/examples/inception_v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/inception_v3.png
--------------------------------------------------------------------------------
/examples/mobilenet_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/mobilenet_v2.png
--------------------------------------------------------------------------------
/examples/resnet18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/resnet18.png
--------------------------------------------------------------------------------
/examples/shufflenet_v2_x1_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/shufflenet_v2_x1_0.png
--------------------------------------------------------------------------------
/examples/squeezenet1_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/squeezenet1_0.png
--------------------------------------------------------------------------------
/examples/vgg16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/examples/vgg16.png
--------------------------------------------------------------------------------
/summary_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torchvision\n",
12 | "import torchvision.models as models\n",
13 | "\n",
14 | "from transformers.torchTransformer import TorchTransformer"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {
21 | "scrolled": false
22 | },
23 | "outputs": [
24 | {
25 | "data": {
26 | "text/plain": [
27 | "ResNet(\n",
28 | " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
29 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
30 | " (relu): ReLU(inplace=True)\n",
31 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
32 | " (layer1): Sequential(\n",
33 | " (0): BasicBlock(\n",
34 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
35 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
36 | " (relu): ReLU(inplace=True)\n",
37 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
38 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
39 | " )\n",
40 | " (1): BasicBlock(\n",
41 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
42 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
43 | " (relu): ReLU(inplace=True)\n",
44 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
45 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
46 | " )\n",
47 | " )\n",
48 | " (layer2): Sequential(\n",
49 | " (0): BasicBlock(\n",
50 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
51 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
52 | " (relu): ReLU(inplace=True)\n",
53 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
54 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
55 | " (downsample): Sequential(\n",
56 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
57 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
58 | " )\n",
59 | " )\n",
60 | " (1): BasicBlock(\n",
61 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
62 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
63 | " (relu): ReLU(inplace=True)\n",
64 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
65 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
66 | " )\n",
67 | " )\n",
68 | " (layer3): Sequential(\n",
69 | " (0): BasicBlock(\n",
70 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
71 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
72 | " (relu): ReLU(inplace=True)\n",
73 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
74 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
75 | " (downsample): Sequential(\n",
76 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
77 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
78 | " )\n",
79 | " )\n",
80 | " (1): BasicBlock(\n",
81 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
82 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
83 | " (relu): ReLU(inplace=True)\n",
84 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
85 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
86 | " )\n",
87 | " )\n",
88 | " (layer4): Sequential(\n",
89 | " (0): BasicBlock(\n",
90 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
91 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
92 | " (relu): ReLU(inplace=True)\n",
93 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
94 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
95 | " (downsample): Sequential(\n",
96 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
97 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
98 | " )\n",
99 | " )\n",
100 | " (1): BasicBlock(\n",
101 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
102 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
103 | " (relu): ReLU(inplace=True)\n",
104 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
105 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
106 | " )\n",
107 | " )\n",
108 | " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
109 | " (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
110 | ")"
111 | ]
112 | },
113 | "execution_count": 2,
114 | "metadata": {},
115 | "output_type": "execute_result"
116 | }
117 | ],
118 | "source": [
119 | "model = models.__dict__[\"resnet18\"]()\n",
120 | "model.eval()"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 3,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "transformer = TorchTransformer()"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "## summary(cpu)"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 4,
142 | "metadata": {
143 | "scrolled": false
144 | },
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "##########################################################################################\n",
151 | "Index| Layer (type) | Bottoms Output Shape Param # \n",
152 | "---------------------------------------------------------------------------\n",
153 | " 0| Data | [(1, 3, 224, 224)] 0 \n",
154 | "---------------------------------------------------------------------------\n",
155 | " 1| Conv2d_1 | Data [(1, 64, 112, 112)] 9408 \n",
156 | "---------------------------------------------------------------------------\n",
157 | " 2| BatchNorm2d_2 | Conv2d_1 [(1, 64, 112, 112)] 64 \n",
158 | "---------------------------------------------------------------------------\n",
159 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n",
160 | "---------------------------------------------------------------------------\n",
161 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n",
162 | "---------------------------------------------------------------------------\n",
163 | " 5| Conv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n",
164 | "---------------------------------------------------------------------------\n",
165 | " 6| BatchNorm2d_6 | Conv2d_5 [(1, 64, 56, 56)] 64 \n",
166 | "---------------------------------------------------------------------------\n",
167 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n",
168 | "---------------------------------------------------------------------------\n",
169 | " 8| Conv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n",
170 | "---------------------------------------------------------------------------\n",
171 | " 9| BatchNorm2d_9 | Conv2d_8 [(1, 64, 56, 56)] 64 \n",
172 | "---------------------------------------------------------------------------\n",
173 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n",
174 | " | | MaxPool2d_4 \n",
175 | "---------------------------------------------------------------------------\n",
176 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 \n",
177 | "---------------------------------------------------------------------------\n",
178 | " 12| Conv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n",
179 | "---------------------------------------------------------------------------\n",
180 | " 13| BatchNorm2d_13 | Conv2d_12 [(1, 64, 56, 56)] 64 \n",
181 | "---------------------------------------------------------------------------\n",
182 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n",
183 | "---------------------------------------------------------------------------\n",
184 | " 15| Conv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n",
185 | "---------------------------------------------------------------------------\n",
186 | " 16| BatchNorm2d_16 | Conv2d_15 [(1, 64, 56, 56)] 64 \n",
187 | "---------------------------------------------------------------------------\n",
188 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n",
189 | " | | ReLU_11 \n",
190 | "---------------------------------------------------------------------------\n",
191 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n",
192 | "---------------------------------------------------------------------------\n",
193 | " 19| Conv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n",
194 | "---------------------------------------------------------------------------\n",
195 | " 20| BatchNorm2d_20 | Conv2d_19 [(1, 128, 28, 28)] 128 \n",
196 | "---------------------------------------------------------------------------\n",
197 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n",
198 | "---------------------------------------------------------------------------\n",
199 | " 22| Conv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n",
200 | "---------------------------------------------------------------------------\n",
201 | " 23| BatchNorm2d_23 | Conv2d_22 [(1, 128, 28, 28)] 128 \n",
202 | "---------------------------------------------------------------------------\n",
203 | " 24| Conv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n",
204 | "---------------------------------------------------------------------------\n",
205 | " 25| BatchNorm2d_25 | Conv2d_24 [(1, 128, 28, 28)] 128 \n",
206 | "---------------------------------------------------------------------------\n",
207 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n",
208 | " | | BatchNorm2d_25 \n",
209 | "---------------------------------------------------------------------------\n",
210 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n",
211 | "---------------------------------------------------------------------------\n",
212 | " 28| Conv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n",
213 | "---------------------------------------------------------------------------\n",
214 | " 29| BatchNorm2d_29 | Conv2d_28 [(1, 128, 28, 28)] 128 \n",
215 | "---------------------------------------------------------------------------\n",
216 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n",
217 | "---------------------------------------------------------------------------\n",
218 | " 31| Conv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n",
219 | "---------------------------------------------------------------------------\n",
220 | " 32| BatchNorm2d_32 | Conv2d_31 [(1, 128, 28, 28)] 128 \n",
221 | "---------------------------------------------------------------------------\n",
222 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n",
223 | " | | ReLU_27 \n",
224 | "---------------------------------------------------------------------------\n",
225 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n",
226 | "---------------------------------------------------------------------------\n",
227 | " 35| Conv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n",
228 | "---------------------------------------------------------------------------\n",
229 | " 36| BatchNorm2d_36 | Conv2d_35 [(1, 256, 14, 14)] 256 \n",
230 | "---------------------------------------------------------------------------\n",
231 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 14)] 0 \n",
232 | "---------------------------------------------------------------------------\n",
233 | " 38| Conv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n",
234 | "---------------------------------------------------------------------------\n",
235 | " 39| BatchNorm2d_39 | Conv2d_38 [(1, 256, 14, 14)] 256 \n",
236 | "---------------------------------------------------------------------------\n",
237 | " 40| Conv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n",
238 | "---------------------------------------------------------------------------\n",
239 | " 41| BatchNorm2d_41 | Conv2d_40 [(1, 256, 14, 14)] 256 \n",
240 | "---------------------------------------------------------------------------\n",
241 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n",
242 | " | | BatchNorm2d_41 \n",
243 | "---------------------------------------------------------------------------\n",
244 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n",
245 | "---------------------------------------------------------------------------\n",
246 | " 44| Conv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n",
247 | "---------------------------------------------------------------------------\n",
248 | " 45| BatchNorm2d_45 | Conv2d_44 [(1, 256, 14, 14)] 256 \n",
249 | "---------------------------------------------------------------------------\n",
250 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n",
251 | "---------------------------------------------------------------------------\n",
252 | " 47| Conv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n",
253 | "---------------------------------------------------------------------------\n",
254 | " 48| BatchNorm2d_48 | Conv2d_47 [(1, 256, 14, 14)] 256 \n",
255 | "---------------------------------------------------------------------------\n",
256 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n",
257 | " | | ReLU_43 \n",
258 | "---------------------------------------------------------------------------\n",
259 | " 50| ReLU_50 | iadd_49 [(1, 256, 14, 14)] 0 \n",
260 | "---------------------------------------------------------------------------\n",
261 | " 51| Conv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n",
262 | "---------------------------------------------------------------------------\n",
263 | " 52| BatchNorm2d_52 | Conv2d_51 [(1, 512, 7, 7)] 512 \n",
264 | "---------------------------------------------------------------------------\n",
265 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n",
266 | "---------------------------------------------------------------------------\n",
267 | " 54| Conv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n",
268 | "---------------------------------------------------------------------------\n",
269 | " 55| BatchNorm2d_55 | Conv2d_54 [(1, 512, 7, 7)] 512 \n",
270 | "---------------------------------------------------------------------------\n",
271 | " 56| Conv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n",
272 | "---------------------------------------------------------------------------\n",
273 | " 57| BatchNorm2d_57 | Conv2d_56 [(1, 512, 7, 7)] 512 \n",
274 | "---------------------------------------------------------------------------\n",
275 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n",
276 | " | | BatchNorm2d_57 \n",
277 | "---------------------------------------------------------------------------\n",
278 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n",
279 | "---------------------------------------------------------------------------\n",
280 | " 60| Conv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n",
281 | "---------------------------------------------------------------------------\n",
282 | " 61| BatchNorm2d_61 | Conv2d_60 [(1, 512, 7, 7)] 512 \n",
283 | "---------------------------------------------------------------------------\n",
284 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n",
285 | "---------------------------------------------------------------------------\n",
286 | " 63| Conv2d_63 | ReLU_62 [(1, 512, 7, 7)] 2359296 \n",
287 | "---------------------------------------------------------------------------\n",
288 | " 64| BatchNorm2d_64 | Conv2d_63 [(1, 512, 7, 7)] 512 \n",
289 | "---------------------------------------------------------------------------\n",
290 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n",
291 | " | | ReLU_59 \n",
292 | "---------------------------------------------------------------------------\n",
293 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n",
294 | "---------------------------------------------------------------------------\n",
295 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n",
296 | "---------------------------------------------------------------------------\n",
297 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n",
298 | "---------------------------------------------------------------------------\n",
299 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n",
300 | "---------------------------------------------------------------------------\n",
301 | "==================================================================================\n",
302 | "Total Trainable params: 11683712 \n",
303 | "Total Non-Trainable params: 0 \n",
304 | "Total params: 11683712 \n"
305 | ]
306 | }
307 | ],
308 | "source": [
309 | "input_tensor = torch.randn([1, 3, 224, 224])\n",
310 | "transformer.summary(model, input_tensor = input_tensor)"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {},
316 | "source": [
317 | "## summary(gpu)"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 5,
323 | "metadata": {},
324 | "outputs": [],
325 | "source": [
326 | "model.cuda()\n",
327 | "input_tensor = input_tensor.cuda()"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": 6,
333 | "metadata": {
334 | "scrolled": false
335 | },
336 | "outputs": [
337 | {
338 | "name": "stdout",
339 | "output_type": "stream",
340 | "text": [
341 | "##########################################################################################\n",
342 | "Index| Layer (type) | Bottoms Output Shape Param # \n",
343 | "---------------------------------------------------------------------------\n",
344 | " 0| Data | [(1, 3, 224, 224)] 0 \n",
345 | "---------------------------------------------------------------------------\n",
346 | " 1| Conv2d_1 | Data [(1, 64, 112, 112)] 9408 \n",
347 | "---------------------------------------------------------------------------\n",
348 | " 2| BatchNorm2d_2 | Conv2d_1 [(1, 64, 112, 112)] 64 \n",
349 | "---------------------------------------------------------------------------\n",
350 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n",
351 | "---------------------------------------------------------------------------\n",
352 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n",
353 | "---------------------------------------------------------------------------\n",
354 | " 5| Conv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n",
355 | "---------------------------------------------------------------------------\n",
356 | " 6| BatchNorm2d_6 | Conv2d_5 [(1, 64, 56, 56)] 64 \n",
357 | "---------------------------------------------------------------------------\n",
358 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n",
359 | "---------------------------------------------------------------------------\n",
360 | " 8| Conv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n",
361 | "---------------------------------------------------------------------------\n",
362 | " 9| BatchNorm2d_9 | Conv2d_8 [(1, 64, 56, 56)] 64 \n",
363 | "---------------------------------------------------------------------------\n",
364 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n",
365 | " | | MaxPool2d_4 \n",
366 | "---------------------------------------------------------------------------\n",
367 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 \n",
368 | "---------------------------------------------------------------------------\n",
369 | " 12| Conv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n",
370 | "---------------------------------------------------------------------------\n",
371 | " 13| BatchNorm2d_13 | Conv2d_12 [(1, 64, 56, 56)] 64 \n",
372 | "---------------------------------------------------------------------------\n",
373 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n",
374 | "---------------------------------------------------------------------------\n",
375 | " 15| Conv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n",
376 | "---------------------------------------------------------------------------\n",
377 | " 16| BatchNorm2d_16 | Conv2d_15 [(1, 64, 56, 56)] 64 \n",
378 | "---------------------------------------------------------------------------\n",
379 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n",
380 | " | | ReLU_11 \n",
381 | "---------------------------------------------------------------------------\n",
382 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n",
383 | "---------------------------------------------------------------------------\n",
384 | " 19| Conv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n",
385 | "---------------------------------------------------------------------------\n",
386 | " 20| BatchNorm2d_20 | Conv2d_19 [(1, 128, 28, 28)] 128 \n",
387 | "---------------------------------------------------------------------------\n",
388 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n",
389 | "---------------------------------------------------------------------------\n",
390 | " 22| Conv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n",
391 | "---------------------------------------------------------------------------\n",
392 | " 23| BatchNorm2d_23 | Conv2d_22 [(1, 128, 28, 28)] 128 \n",
393 | "---------------------------------------------------------------------------\n",
394 | " 24| Conv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n",
395 | "---------------------------------------------------------------------------\n",
396 | " 25| BatchNorm2d_25 | Conv2d_24 [(1, 128, 28, 28)] 128 \n",
397 | "---------------------------------------------------------------------------\n",
398 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n",
399 | " | | BatchNorm2d_25 \n",
400 | "---------------------------------------------------------------------------\n",
401 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n",
402 | "---------------------------------------------------------------------------\n",
403 | " 28| Conv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n",
404 | "---------------------------------------------------------------------------\n",
405 | " 29| BatchNorm2d_29 | Conv2d_28 [(1, 128, 28, 28)] 128 \n",
406 | "---------------------------------------------------------------------------\n",
407 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n",
408 | "---------------------------------------------------------------------------\n",
409 | " 31| Conv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n",
410 | "---------------------------------------------------------------------------\n",
411 | " 32| BatchNorm2d_32 | Conv2d_31 [(1, 128, 28, 28)] 128 \n",
412 | "---------------------------------------------------------------------------\n",
413 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n",
414 | " | | ReLU_27 \n",
415 | "---------------------------------------------------------------------------\n",
416 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n",
417 | "---------------------------------------------------------------------------\n",
418 | " 35| Conv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n",
419 | "---------------------------------------------------------------------------\n",
420 | " 36| BatchNorm2d_36 | Conv2d_35 [(1, 256, 14, 14)] 256 \n",
421 | "---------------------------------------------------------------------------\n",
422 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 14)] 0 \n",
423 | "---------------------------------------------------------------------------\n",
424 | " 38| Conv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n",
425 | "---------------------------------------------------------------------------\n",
426 | " 39| BatchNorm2d_39 | Conv2d_38 [(1, 256, 14, 14)] 256 \n",
427 | "---------------------------------------------------------------------------\n",
428 | " 40| Conv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n",
429 | "---------------------------------------------------------------------------\n",
430 | " 41| BatchNorm2d_41 | Conv2d_40 [(1, 256, 14, 14)] 256 \n",
431 | "---------------------------------------------------------------------------\n",
432 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n",
433 | " | | BatchNorm2d_41 \n",
434 | "---------------------------------------------------------------------------\n",
435 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n",
436 | "---------------------------------------------------------------------------\n",
437 | " 44| Conv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n",
438 | "---------------------------------------------------------------------------\n",
439 | " 45| BatchNorm2d_45 | Conv2d_44 [(1, 256, 14, 14)] 256 \n",
440 | "---------------------------------------------------------------------------\n",
441 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n",
442 | "---------------------------------------------------------------------------\n",
443 | " 47| Conv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n",
444 | "---------------------------------------------------------------------------\n",
445 | " 48| BatchNorm2d_48 | Conv2d_47 [(1, 256, 14, 14)] 256 \n",
446 | "---------------------------------------------------------------------------\n",
447 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n",
448 | " | | ReLU_43 \n",
449 | "---------------------------------------------------------------------------\n",
450 | " 50| ReLU_50 | iadd_49 [(1, 256, 14, 14)] 0 \n",
451 | "---------------------------------------------------------------------------\n",
452 | " 51| Conv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n",
453 | "---------------------------------------------------------------------------\n",
454 | " 52| BatchNorm2d_52 | Conv2d_51 [(1, 512, 7, 7)] 512 \n",
455 | "---------------------------------------------------------------------------\n",
456 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n",
457 | "---------------------------------------------------------------------------\n",
458 | " 54| Conv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n",
459 | "---------------------------------------------------------------------------\n",
460 | " 55| BatchNorm2d_55 | Conv2d_54 [(1, 512, 7, 7)] 512 \n",
461 | "---------------------------------------------------------------------------\n",
462 | " 56| Conv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n",
463 | "---------------------------------------------------------------------------\n",
464 | " 57| BatchNorm2d_57 | Conv2d_56 [(1, 512, 7, 7)] 512 \n",
465 | "---------------------------------------------------------------------------\n",
466 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n",
467 | " | | BatchNorm2d_57 \n",
468 | "---------------------------------------------------------------------------\n",
469 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n",
470 | "---------------------------------------------------------------------------\n",
471 | " 60| Conv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n",
472 | "---------------------------------------------------------------------------\n",
473 | " 61| BatchNorm2d_61 | Conv2d_60 [(1, 512, 7, 7)] 512 \n",
474 | "---------------------------------------------------------------------------\n",
475 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n",
476 | "---------------------------------------------------------------------------\n",
477 | " 63| Conv2d_63 | ReLU_62 [(1, 512, 7, 7)] 2359296 \n",
478 | "---------------------------------------------------------------------------\n",
479 | " 64| BatchNorm2d_64 | Conv2d_63 [(1, 512, 7, 7)] 512 \n",
480 | "---------------------------------------------------------------------------\n",
481 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n",
482 | " | | ReLU_59 \n",
483 | "---------------------------------------------------------------------------\n",
484 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n",
485 | "---------------------------------------------------------------------------\n",
486 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n",
487 | "---------------------------------------------------------------------------\n",
488 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n",
489 | "---------------------------------------------------------------------------\n",
490 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n",
491 | "---------------------------------------------------------------------------\n",
492 | "==================================================================================\n",
493 | "Total Trainable params: 11683712 \n",
494 | "Total Non-Trainable params: 0 \n",
495 | "Total params: 11683712 \n"
496 | ]
497 | }
498 | ],
499 | "source": [
500 | "transformer.summary(model, input_tensor = input_tensor)"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": null,
506 | "metadata": {},
507 | "outputs": [],
508 | "source": []
509 | }
510 | ],
511 | "metadata": {
512 | "kernelspec": {
513 | "display_name": "Python 3",
514 | "language": "python",
515 | "name": "python3"
516 | },
517 | "language_info": {
518 | "codemirror_mode": {
519 | "name": "ipython",
520 | "version": 3
521 | },
522 | "file_extension": ".py",
523 | "mimetype": "text/x-python",
524 | "name": "python",
525 | "nbconvert_exporter": "python",
526 | "pygments_lexer": "ipython3",
527 | "version": "3.6.9"
528 | }
529 | },
530 | "nbformat": 4,
531 | "nbformat_minor": 2
532 | }
533 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import copy
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 | import torchvision.models as models
7 | from transformers.torchTransformer import TorchTransformer
8 | from transformers.quantize import QConv2d
9 |
10 | model = models.__dict__["resnet18"]()
11 | model.cuda()
12 | model = model.eval()
13 |
14 | # Swap every registered nn.Conv2d in the model for a QConv2d.
15 | transformer = TorchTransformer()
16 | transformer.register(nn.Conv2d, QConv2d)
17 | model = transformer.trans_layers(model)
18 | print(model)
19 | # Stop after printing the transformed model; remove this line to run the summary below.
20 | sys.exit()
21 |
22 |
23 | input_tensor = torch.randn([1, 3, 224, 224])
24 | input_tensor = input_tensor.cuda()
25 | net = transformer.summary(model, input_tensor=input_tensor)
26 | # transformer.visualize(model, input_tensor = input_tensor, save_name= "example", graph_size = 80)
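The `register` / `trans_layers` calls in this script swap each `nn.Conv2d` for a `QConv2d`. As a rough, torch-free sketch of that idea, the swap can be pictured as a recursive walk that rebuilds each registered leaf with placeholder constructor arguments and then copies the original layer's attributes over. The `Node`, `Conv`, and `QConv` classes and the `swap_layers` helper below are hypothetical stand-ins for illustration only; they are not part of this repository or of PyTorch.

```python
import inspect


class Node:
    """Hypothetical stand-in for nn.Module: a container of named children."""

    def __init__(self):
        self._modules = {}


class Conv(Node):
    """Hypothetical leaf layer with three required constructor arguments."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride


class QConv(Conv):
    """Hypothetical replacement layer that adds one extra option."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, num_bits=8):
        super().__init__(in_channels, out_channels, kernel_size, stride)
        self.num_bits = num_bits


def swap_layers(model, registry):
    """Recursively replace leaves whose exact type is a key of `registry`."""
    for name, child in list(model._modules.items()):
        if child._modules:
            swap_layers(child, registry)  # container: descend into children
        elif type(child) in registry:
            # Required parameters are those without a default value.
            sig = inspect.signature(type(child))
            placeholders = {
                key: 1
                for key, param in sig.parameters.items()
                if param.default is inspect.Parameter.empty
            }
            new_layer = registry[type(child)](**placeholders)
            # Overwrite the placeholders with the original layer's state.
            new_layer.__dict__.update(child.__dict__)
            model._modules[name] = new_layer
    return model


model = Node()
model._modules["conv1"] = Conv(3, 64, 7, stride=2)
block = Node()
block._modules["conv"] = Conv(64, 64, 3)
model._modules["block"] = block

swap_layers(model, {Conv: QConv})
print(type(model._modules["conv1"]).__name__, model._modules["conv1"].stride)
```

Copying `__dict__` keeps the original configuration (and, in a real module, its parameters) while attributes that exist only on the new class, such as `num_bits` here, keep their defaults.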
--------------------------------------------------------------------------------
/transform_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torchvision\n",
12 | "import torchvision.models as models\n",
13 | "\n",
14 | "from transformers.torchTransformer import TorchTransformer\n",
15 | "from transformers.quantize import QConv2d"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "model = models.__dict__[\"resnet18\"]()\n",
25 | "model.cuda()\n",
26 | "model = model.eval()\n"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 |     "## Register layers to be transformed"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 3,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 |       "register <class 'torch.nn.modules.conv.Conv2d'> <class 'transformers.quantize.QConv2d'>\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "transformer = TorchTransformer()\n",
51 | "transformer.register(nn.Conv2d, QConv2d)\n",
52 | "model = transformer.trans_layers(model)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 4,
58 | "metadata": {
59 | "scrolled": true
60 | },
61 | "outputs": [
62 | {
63 | "data": {
64 | "text/plain": [
65 | "ResNet(\n",
66 | " (conv1): QConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
67 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
68 | " (relu): ReLU(inplace=True)\n",
69 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
70 | " (layer1): Sequential(\n",
71 | " (0): BasicBlock(\n",
72 | " (conv1): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
73 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
74 | " (relu): ReLU(inplace=True)\n",
75 | " (conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
76 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
77 | " )\n",
78 | " (1): BasicBlock(\n",
79 | " (conv1): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
81 | " (relu): ReLU(inplace=True)\n",
82 | " (conv2): QConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
83 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
84 | " )\n",
85 | " )\n",
86 | " (layer2): Sequential(\n",
87 | " (0): BasicBlock(\n",
88 | " (conv1): QConv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
89 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
90 | " (relu): ReLU(inplace=True)\n",
91 | " (conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
92 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
93 | " (downsample): Sequential(\n",
94 | " (0): QConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
95 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
96 | " )\n",
97 | " )\n",
98 | " (1): BasicBlock(\n",
99 | " (conv1): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
100 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
101 | " (relu): ReLU(inplace=True)\n",
102 | " (conv2): QConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
103 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
104 | " )\n",
105 | " )\n",
106 | " (layer3): Sequential(\n",
107 | " (0): BasicBlock(\n",
108 | " (conv1): QConv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
109 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
110 | " (relu): ReLU(inplace=True)\n",
111 | " (conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
112 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
113 | " (downsample): Sequential(\n",
114 | " (0): QConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
115 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
116 | " )\n",
117 | " )\n",
118 | " (1): BasicBlock(\n",
119 | " (conv1): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
120 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
121 | " (relu): ReLU(inplace=True)\n",
122 | " (conv2): QConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
123 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
124 | " )\n",
125 | " )\n",
126 | " (layer4): Sequential(\n",
127 | " (0): BasicBlock(\n",
128 | " (conv1): QConv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
129 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
130 | " (relu): ReLU(inplace=True)\n",
131 | " (conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
132 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
133 | " (downsample): Sequential(\n",
134 | " (0): QConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
135 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
136 | " )\n",
137 | " )\n",
138 | " (1): BasicBlock(\n",
139 | " (conv1): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
140 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
141 | " (relu): ReLU(inplace=True)\n",
142 | " (conv2): QConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
143 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
144 | " )\n",
145 | " )\n",
146 | " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
147 | " (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
148 | ")"
149 | ]
150 | },
151 | "execution_count": 4,
152 | "metadata": {},
153 | "output_type": "execute_result"
154 | }
155 | ],
156 | "source": [
157 | "model"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 5,
163 | "metadata": {},
164 | "outputs": [
165 | {
166 | "name": "stdout",
167 | "output_type": "stream",
168 | "text": [
169 | "##########################################################################################\n",
170 | "Index| Layer (type) | Bottoms Output Shape Param # \n",
171 | "---------------------------------------------------------------------------\n",
172 | " 0| Data | [(1, 3, 224, 224)] 0 \n",
173 | "---------------------------------------------------------------------------\n",
174 | " 1| QConv2d_1 | Data [(1, 64, 112, 112)] 9408 \n",
175 | "---------------------------------------------------------------------------\n",
176 | " 2| BatchNorm2d_2 | QConv2d_1 [(1, 64, 112, 112)] 64 \n",
177 | "---------------------------------------------------------------------------\n",
178 | " 3| ReLU_3 | BatchNorm2d_2 [(1, 64, 112, 112)] 0 \n",
179 | "---------------------------------------------------------------------------\n",
180 | " 4| MaxPool2d_4 | ReLU_3 [(1, 64, 56, 56)] 0 \n",
181 | "---------------------------------------------------------------------------\n",
182 | " 5| QConv2d_5 | MaxPool2d_4 [(1, 64, 56, 56)] 36864 \n",
183 | "---------------------------------------------------------------------------\n",
184 | " 6| BatchNorm2d_6 | QConv2d_5 [(1, 64, 56, 56)] 64 \n",
185 | "---------------------------------------------------------------------------\n",
186 | " 7| ReLU_7 | BatchNorm2d_6 [(1, 64, 56, 56)] 0 \n",
187 | "---------------------------------------------------------------------------\n",
188 | " 8| QConv2d_8 | ReLU_7 [(1, 64, 56, 56)] 36864 \n",
189 | "---------------------------------------------------------------------------\n",
190 | " 9| BatchNorm2d_9 | QConv2d_8 [(1, 64, 56, 56)] 64 \n",
191 | "---------------------------------------------------------------------------\n",
192 | " 10| iadd_10 | BatchNorm2d_9 [(1, 64, 56, 56)] 0 \n",
193 | " | | MaxPool2d_4 \n",
194 | "---------------------------------------------------------------------------\n",
195 | " 11| ReLU_11 | iadd_10 [(1, 64, 56, 56)] 0 \n",
196 | "---------------------------------------------------------------------------\n",
197 | " 12| QConv2d_12 | ReLU_11 [(1, 64, 56, 56)] 36864 \n",
198 | "---------------------------------------------------------------------------\n",
199 | " 13| BatchNorm2d_13 | QConv2d_12 [(1, 64, 56, 56)] 64 \n",
200 | "---------------------------------------------------------------------------\n",
201 | " 14| ReLU_14 | BatchNorm2d_13 [(1, 64, 56, 56)] 0 \n",
202 | "---------------------------------------------------------------------------\n",
203 | " 15| QConv2d_15 | ReLU_14 [(1, 64, 56, 56)] 36864 \n",
204 | "---------------------------------------------------------------------------\n",
205 | " 16| BatchNorm2d_16 | QConv2d_15 [(1, 64, 56, 56)] 64 \n",
206 | "---------------------------------------------------------------------------\n",
207 | " 17| iadd_17 | BatchNorm2d_16 [(1, 64, 56, 56)] 0 \n",
208 | " | | ReLU_11 \n",
209 | "---------------------------------------------------------------------------\n",
210 | " 18| ReLU_18 | iadd_17 [(1, 64, 56, 56)] 0 \n",
211 | "---------------------------------------------------------------------------\n",
212 | " 19| QConv2d_19 | ReLU_18 [(1, 128, 28, 28)] 73728 \n",
213 | "---------------------------------------------------------------------------\n",
214 | " 20| BatchNorm2d_20 | QConv2d_19 [(1, 128, 28, 28)] 128 \n",
215 | "---------------------------------------------------------------------------\n",
216 | " 21| ReLU_21 | BatchNorm2d_20 [(1, 128, 28, 28)] 0 \n",
217 | "---------------------------------------------------------------------------\n",
218 | " 22| QConv2d_22 | ReLU_21 [(1, 128, 28, 28)] 147456 \n",
219 | "---------------------------------------------------------------------------\n",
220 | " 23| BatchNorm2d_23 | QConv2d_22 [(1, 128, 28, 28)] 128 \n",
221 | "---------------------------------------------------------------------------\n",
222 | " 24| QConv2d_24 | ReLU_18 [(1, 128, 28, 28)] 8192 \n",
223 | "---------------------------------------------------------------------------\n",
224 | " 25| BatchNorm2d_25 | QConv2d_24 [(1, 128, 28, 28)] 128 \n",
225 | "---------------------------------------------------------------------------\n",
226 | " 26| iadd_26 | BatchNorm2d_23 [(1, 128, 28, 28)] 0 \n",
227 | " | | BatchNorm2d_25 \n",
228 | "---------------------------------------------------------------------------\n",
229 | " 27| ReLU_27 | iadd_26 [(1, 128, 28, 28)] 0 \n",
230 | "---------------------------------------------------------------------------\n",
231 | " 28| QConv2d_28 | ReLU_27 [(1, 128, 28, 28)] 147456 \n",
232 | "---------------------------------------------------------------------------\n",
233 | " 29| BatchNorm2d_29 | QConv2d_28 [(1, 128, 28, 28)] 128 \n",
234 | "---------------------------------------------------------------------------\n",
235 | " 30| ReLU_30 | BatchNorm2d_29 [(1, 128, 28, 28)] 0 \n",
236 | "---------------------------------------------------------------------------\n",
237 | " 31| QConv2d_31 | ReLU_30 [(1, 128, 28, 28)] 147456 \n",
238 | "---------------------------------------------------------------------------\n",
239 | " 32| BatchNorm2d_32 | QConv2d_31 [(1, 128, 28, 28)] 128 \n",
240 | "---------------------------------------------------------------------------\n",
241 | " 33| iadd_33 | BatchNorm2d_32 [(1, 128, 28, 28)] 0 \n",
242 | " | | ReLU_27 \n",
243 | "---------------------------------------------------------------------------\n",
244 | " 34| ReLU_34 | iadd_33 [(1, 128, 28, 28)] 0 \n",
245 | "---------------------------------------------------------------------------\n",
246 | " 35| QConv2d_35 | ReLU_34 [(1, 256, 14, 14)] 294912 \n",
247 | "---------------------------------------------------------------------------\n",
248 | " 36| BatchNorm2d_36 | QConv2d_35 [(1, 256, 14, 14)] 256 \n",
249 | "---------------------------------------------------------------------------\n",
250 | " 37| ReLU_37 | BatchNorm2d_36 [(1, 256, 14, 14)] 0 \n",
251 | "---------------------------------------------------------------------------\n",
252 | " 38| QConv2d_38 | ReLU_37 [(1, 256, 14, 14)] 589824 \n",
253 | "---------------------------------------------------------------------------\n",
254 | " 39| BatchNorm2d_39 | QConv2d_38 [(1, 256, 14, 14)] 256 \n",
255 | "---------------------------------------------------------------------------\n",
256 | " 40| QConv2d_40 | ReLU_34 [(1, 256, 14, 14)] 32768 \n",
257 | "---------------------------------------------------------------------------\n",
258 | " 41| BatchNorm2d_41 | QConv2d_40 [(1, 256, 14, 14)] 256 \n",
259 | "---------------------------------------------------------------------------\n",
260 | " 42| iadd_42 | BatchNorm2d_39 [(1, 256, 14, 14)] 0 \n",
261 | " | | BatchNorm2d_41 \n",
262 | "---------------------------------------------------------------------------\n",
263 | " 43| ReLU_43 | iadd_42 [(1, 256, 14, 14)] 0 \n",
264 | "---------------------------------------------------------------------------\n",
265 | " 44| QConv2d_44 | ReLU_43 [(1, 256, 14, 14)] 589824 \n",
266 | "---------------------------------------------------------------------------\n",
267 | " 45| BatchNorm2d_45 | QConv2d_44 [(1, 256, 14, 14)] 256 \n",
268 | "---------------------------------------------------------------------------\n",
269 | " 46| ReLU_46 | BatchNorm2d_45 [(1, 256, 14, 14)] 0 \n",
270 | "---------------------------------------------------------------------------\n",
271 | " 47| QConv2d_47 | ReLU_46 [(1, 256, 14, 14)] 589824 \n",
272 | "---------------------------------------------------------------------------\n",
273 | " 48| BatchNorm2d_48 | QConv2d_47 [(1, 256, 14, 14)] 256 \n",
274 | "---------------------------------------------------------------------------\n",
275 | " 49| iadd_49 | BatchNorm2d_48 [(1, 256, 14, 14)] 0 \n",
276 | " | | ReLU_43 \n",
277 | "---------------------------------------------------------------------------\n",
278 | " 50| ReLU_50 | iadd_49 [(1, 256, 14, 14)] 0 \n",
279 | "---------------------------------------------------------------------------\n",
280 | " 51| QConv2d_51 | ReLU_50 [(1, 512, 7, 7)] 1179648 \n",
281 | "---------------------------------------------------------------------------\n",
282 | " 52| BatchNorm2d_52 | QConv2d_51 [(1, 512, 7, 7)] 512 \n",
283 | "---------------------------------------------------------------------------\n",
284 | " 53| ReLU_53 | BatchNorm2d_52 [(1, 512, 7, 7)] 0 \n",
285 | "---------------------------------------------------------------------------\n",
286 | " 54| QConv2d_54 | ReLU_53 [(1, 512, 7, 7)] 2359296 \n",
287 | "---------------------------------------------------------------------------\n",
288 | " 55| BatchNorm2d_55 | QConv2d_54 [(1, 512, 7, 7)] 512 \n",
289 | "---------------------------------------------------------------------------\n",
290 | " 56| QConv2d_56 | ReLU_50 [(1, 512, 7, 7)] 131072 \n",
291 | "---------------------------------------------------------------------------\n",
292 | " 57| BatchNorm2d_57 | QConv2d_56 [(1, 512, 7, 7)] 512 \n",
293 | "---------------------------------------------------------------------------\n",
294 | " 58| iadd_58 | BatchNorm2d_55 [(1, 512, 7, 7)] 0 \n",
295 | " | | BatchNorm2d_57 \n",
296 | "---------------------------------------------------------------------------\n",
297 | " 59| ReLU_59 | iadd_58 [(1, 512, 7, 7)] 0 \n",
298 | "---------------------------------------------------------------------------\n",
299 | " 60| QConv2d_60 | ReLU_59 [(1, 512, 7, 7)] 2359296 \n",
300 | "---------------------------------------------------------------------------\n",
301 | " 61| BatchNorm2d_61 | QConv2d_60 [(1, 512, 7, 7)] 512 \n",
302 | "---------------------------------------------------------------------------\n",
303 | " 62| ReLU_62 | BatchNorm2d_61 [(1, 512, 7, 7)] 0 \n",
304 | "---------------------------------------------------------------------------\n",
305 | " 63| QConv2d_63 | ReLU_62 [(1, 512, 7, 7)] 2359296 \n",
306 | "---------------------------------------------------------------------------\n",
307 | " 64| BatchNorm2d_64 | QConv2d_63 [(1, 512, 7, 7)] 512 \n",
308 | "---------------------------------------------------------------------------\n",
309 | " 65| iadd_65 | BatchNorm2d_64 [(1, 512, 7, 7)] 0 \n",
310 | " | | ReLU_59 \n",
311 | "---------------------------------------------------------------------------\n",
312 | " 66| ReLU_66 | iadd_65 [(1, 512, 7, 7)] 0 \n",
313 | "---------------------------------------------------------------------------\n",
314 | " 67| AdaptiveAvgPool2d_67 | ReLU_66 [(1, 512, 1, 1)] 0 \n",
315 | "---------------------------------------------------------------------------\n",
316 | " 68| torch.flatten_68 | AdaptiveAvgPool2d_67 [(1, 512)] 0 \n",
317 | "---------------------------------------------------------------------------\n",
318 | " 69| Linear_69 | torch.flatten_68 [(1, 1000)] 512000 \n",
319 | "---------------------------------------------------------------------------\n",
320 | "==================================================================================\n",
321 | "Total Trainable params: 11683712 \n",
322 | "Total Non-Trainable params: 0 \n",
323 | "Total params: 11683712 \n"
324 | ]
325 | }
326 | ],
327 | "source": [
328 | "input_tensor = torch.randn([1, 3, 224, 224]).cuda()\n",
329 | "model = model.cuda()\n",
330 | "transformer.summary(model, input_tensor = input_tensor)"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "metadata": {},
337 | "outputs": [],
338 | "source": []
339 | }
340 | ],
341 | "metadata": {
342 | "kernelspec": {
343 | "display_name": "Python 3",
344 | "language": "python",
345 | "name": "python3"
346 | },
347 | "language_info": {
348 | "codemirror_mode": {
349 | "name": "ipython",
350 | "version": 3
351 | },
352 | "file_extension": ".py",
353 | "mimetype": "text/x-python",
354 | "name": "python",
355 | "nbconvert_exporter": "python",
356 | "pygments_lexer": "ipython3",
357 | "version": "3.6.9"
358 | }
359 | },
360 | "nbformat": 4,
361 | "nbformat_minor": 2
362 | }
363 |
--------------------------------------------------------------------------------
/transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ricky40403/PyTransformer/22a0a824be0ef7d4dd65312c4b3e190e4cde4fee/transformers/__init__.py
--------------------------------------------------------------------------------
/transformers/quantize.py:
--------------------------------------------------------------------------------
1 | """
2 | Module Source:
3 | https://github.com/eladhoffer/quantized.pytorch
4 | """
5 |
6 | import torch
7 | from torch.autograd.function import InplaceFunction, Function
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import math
11 |
12 |
13 | class UniformQuantize(InplaceFunction):
14 |     @staticmethod
15 |     def forward(ctx, input, num_bits=8, min_value=None, max_value=None, inplace=False, symmetric=False, num_chunks=None):
16 |         num_chunks = input.shape[0] if num_chunks is None else num_chunks
17 |         if min_value is None or max_value is None:
18 |             B = input.shape[0]
19 |             y = input.view(B // num_chunks, -1)
20 |
21 |         if min_value is None:
22 |             min_value = y.min(-1)[0].mean(-1)  # C
23 |             #min_value = float(input.view(input.size(0), -1).min(-1)[0].mean())
24 |
25 |         if max_value is None:
26 |             #max_value = float(input.view(input.size(0), -1).max(-1)[0].mean())
27 |             max_value = y.max(-1)[0].mean(-1)  # C
28 |
29 |         ctx.inplace = inplace
30 |         ctx.num_bits = num_bits
31 |         ctx.min_value = min_value
32 |         ctx.max_value = max_value
33 |
34 |         if ctx.inplace:
35 |             ctx.mark_dirty(input)
36 |             output = input
37 |
38 |         else:
39 |             output = input.clone()
40 |
41 |         if symmetric:
42 |             qmin = -2. ** (num_bits - 1)
43 |             qmax = 2 ** (num_bits - 1) - 1
44 |             max_value = torch.max(torch.abs(max_value), torch.abs(min_value))
45 |             min_value = 0.
46 |
47 |         else:
48 |             qmin = 0.
49 |             qmax = 2. ** num_bits - 1.
50 |
51 |         scale = (max_value - min_value) / (qmax - qmin)
52 |         scale = max(scale, 1e-8)
53 |
54 |         output.add_(-min_value).div_(scale)
55 |
56 |         output.clamp_(qmin, qmax).round_()  # quantize
57 |
58 |         output.mul_(scale).add_(min_value)  # dequantize
59 |
60 |         return output
61 |
62 |
63 |     @staticmethod
64 |     def backward(ctx, grad_output):
65 |         # straight-through estimator
66 |         grad_input = grad_output
67 |         return grad_input, None, None, None, None, None, None
68 |
69 |
70 | def quantize(x, num_bits=8, min_value=None, max_value=None, inplace=False, symmetric=False, num_chunks=None):
71 |     return UniformQuantize.apply(x, num_bits, min_value, max_value, inplace, symmetric, num_chunks)
72 |
73 |
74 | class QuantMeasure(nn.Module):
75 |     """Tracks a running min/max of the input during training and quantizes the input with that range."""
76 |
77 |     def __init__(self, num_bits=8, momentum=0.1):
78 |         super(QuantMeasure, self).__init__()
79 |         self.register_buffer('running_min', torch.zeros(1))
80 |         self.register_buffer('running_max', torch.zeros(1))
81 |         self.momentum = momentum
82 |         self.num_bits = num_bits
83 |
84 |
85 |     def forward(self, input):
86 |         if self.training:
87 |             min_value = input.detach().view(input.size(0), -1).min(-1)[0].mean()
88 |             max_value = input.detach().view(input.size(0), -1).max(-1)[0].mean()
89 |             self.running_min.mul_(1 - self.momentum).add_(min_value * (self.momentum))
90 |             self.running_max.mul_(1 - self.momentum).add_(max_value * (self.momentum))
91 |
92 |         else:
93 |             min_value = self.running_min
94 |             max_value = self.running_max
95 |
96 |         return quantize(input, self.num_bits, min_value=float(min_value), max_value=float(max_value), num_chunks=16)
97 |
98 |
99 | class QConv2d(nn.Conv2d):
100 |     """Conv2d layer that quantizes its weight (and bias, if present) on every forward pass."""
101 |
102 |     def __init__(self, in_channels, out_channels, kernel_size,
103 |                  stride=1, padding=0, dilation=1, groups=1, bias=True, num_bits=8, num_bits_weight=None):
104 |         super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
105 |                                       stride, padding, dilation, groups, bias)
106 |         self.num_bits = num_bits
107 |         self.num_bits_weight = num_bits_weight or num_bits
108 |
109 |
110 |     def forward(self, input):
111 |         qweight = quantize(self.weight, num_bits=self.num_bits_weight,
112 |                            min_value=float(self.weight.min()),
113 |                            max_value=float(self.weight.max()))
114 |         if self.bias is not None:
115 |             qbias = quantize(self.bias, num_bits=self.num_bits_weight)
116 |         else:
117 |             qbias = None
118 |
119 |         output = F.conv2d(input, qweight, qbias, self.stride,
120 |                           self.padding, self.dilation, self.groups)
121 |
122 |         return output
123 |
124 |
125 | class QLinear(nn.Linear):
126 |     """Linear layer that quantizes its weight (and bias, if present) on every forward pass."""
127 |
128 |     def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=None, num_bits_grad=None, biprecision=False):
129 |         super(QLinear, self).__init__(in_features, out_features, bias)
130 |         self.num_bits = num_bits
131 |         self.num_bits_weight = num_bits_weight or num_bits
132 |
133 |
134 |     def forward(self, input):
135 |         qweight = quantize(self.weight, num_bits=self.num_bits_weight,
136 |                            min_value=float(self.weight.min()),
137 |                            max_value=float(self.weight.max()))
138 |         if self.bias is not None:
139 |             qbias = quantize(self.bias, num_bits=self.num_bits_weight)
140 |         else:
141 |             qbias = None
142 |
143 |         output = F.linear(input, qweight, qbias)
144 |
145 |         return output
146 |
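The quantize/dequantize arithmetic in `UniformQuantize.forward` can be checked by hand without PyTorch. The following is a minimal pure-Python sketch of the asymmetric (non-symmetric) branch only, assuming scalar `min_value` / `max_value`; the function name `fake_quantize` is hypothetical and not part of this module.

```python
def fake_quantize(values, num_bits=8, min_value=None, max_value=None):
    """Quantize then dequantize a list of floats on a uniform grid."""
    if min_value is None:
        min_value = min(values)
    if max_value is None:
        max_value = max(values)

    qmin, qmax = 0.0, 2.0 ** num_bits - 1.0
    # Guard against a zero range, mirroring max(scale, 1e-8) in the module.
    scale = max((max_value - min_value) / (qmax - qmin), 1e-8)

    out = []
    for v in values:
        q = (v - min_value) / scale          # map to grid coordinates
        q = min(max(q, qmin), qmax)          # clamp to [qmin, qmax]
        q = round(q)                         # snap to the nearest level
        out.append(q * scale + min_value)    # dequantize back to floats
    return out


# With 2 bits there are 4 levels between the min and max of the data.
print(fake_quantize([0.0, 0.1, 0.4, 1.0], num_bits=2))
```

Every in-range value moves by at most half a step (`scale / 2`), and values outside `[min_value, max_value]` are clamped to the nearest end of the range.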
--------------------------------------------------------------------------------
/transformers/torchTransformer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import copy
3 | import types
4 | import inspect
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | import pydot
12 | from graphviz import Digraph
13 |
14 | from .utils import _ReplaceFunc, Log, UnitLayer
15 |
16 |
17 |
18 | class TorchTransformer(nn.Module):
19 | """!
20 | This class handle layer swap, summary, visualization of the input model
21 | """
22 | def __init__(self):
23 | super(TorchTransformer, self).__init__()
24 |
25 | self._register_dict = OrderedDict()
26 | self.log = Log()
27 | self._raw_TrochFuncs = OrderedDict()
28 | self._raw_TrochFunctionals = OrderedDict()
29 |
30 | # register class to trans
31 | def register(self, origin_class, target_class):
32 | """!
33 | This function register which class should transform to target class.
34 | """
35 | print("register", origin_class, target_class)
36 |
37 | self._register_dict[origin_class] = target_class
38 |
39 | pass
40 |
41 | def trans_layers(self, model, update = True):
42 | """!
43 | This function transform layer by layers in register dictionarys
44 |
45 | @param model: input model to transfer
46 |
47 | @param update: default is True, wether to update the paramter from the orign layer or not.
48 | Note that it will update matched parameters only.
49 |
50 | @return transfered model
51 | """
52 | # print("trans layer")
53 | if len(self._register_dict) == 0:
54 | print("No layer to swap")
55 | print("Please use register( {origin_layer}, {target_layer} ) to register layer")
56 | return model
57 | else:
58 | for module_name in model._modules:
59 | # has children
60 | if len(model._modules[module_name]._modules) > 0:
61 | self.trans_layers(model._modules[module_name], update)
62 | else:
63 | if type(getattr(model, module_name)) in self._register_dict:
64 | # use inspect.signature to know args and kwargs of __init__
65 | _sig = inspect.signature(type(getattr(model, module_name)))
66 | _kwargs = {}
67 | for key in _sig.parameters:
68 | if _sig.parameters[key].default == inspect.Parameter.empty: #args
69 | # assign args
70 | # default values should be handled more properly, unknown data type might be an issue
71 | if 'kernel' in key:
72 | # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=3)
73 | value = 3
74 | elif 'channel' in key:
75 | # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=32)
76 | value = 32
77 | else:
78 | # _sig.parameters[key].replace(default=inspect.Parameter.empty, annotation=None)
79 | value = None
80 |
81 | _kwargs[key] = value
82 |
83 | _attr_dict = getattr(model, module_name).__dict__
84 | _layer_new = self._register_dict[type(getattr(model, module_name))](**_kwargs) # only give positional args
85 | _layer_new.__dict__.update(_attr_dict)
86 |
87 | setattr(model, module_name, _layer_new)
88 | return model
89 |
90 |
91 |
92 |
93 | def summary(self, model = None, input_tensor = None):
94 | """!
95 | This function acts like the Keras summary function
96 |
97 | @param model: input model to summarize
98 |
99 | @param input_tensor: input data of the model to forward
100 |
101 | """
102 | # input_tensor = torch.randn([1, 3, 224, 224])
103 | # input_tensor = input_tensor.cuda()
104 |
105 |
106 | self._build_graph(model, input_tensor)
107 |
108 | # get dicts and variables
109 | model_graph = self.log.getGraph()
110 | bottoms_graph = self.log.getBottoms()
111 | output_shape_graph = self.log.getOutShapes()
112 | # store top names for bottoms
113 | topNames = OrderedDict()
114 | totoal_trainable_params = 0
115 | total_params = 0
116 | # loop graph
117 | print("##########################################################################################")
118 | line_title = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format("Index","Layer (type)", "Bottoms","Output Shape", "Param #")
119 | print(line_title)
120 | print("---------------------------------------------------------------------------")
121 |
122 |
123 | for layer_index, key in enumerate(model_graph):
124 |
125 | # data layer
126 | if bottoms_graph[key] is None:
127 | # Layer information
128 | layer = model_graph[key]
129 | layer_type = layer.__class__.__name__
130 | if layer_type == "str":
131 | layer_type = key
132 | else:
133 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
134 |
135 | topNames[key] = layer_type
136 |
137 | # Layer Output shape
138 | output_shape = "[{}]".format(tuple(output_shape_graph[key]))
139 |
140 | # Layer Params
141 | param_weight_num = 0
142 | if hasattr(layer, "weight") and hasattr(layer.weight, "size"):
143 | param_weight_num += torch.prod(torch.LongTensor(list(layer.weight.size())))
144 | if layer.weight.requires_grad:
145 | totoal_trainable_params += param_weight_num
146 | if hasattr(layer, "bias") and hasattr(layer.bias, "size"):
147 | param_bias_num = torch.prod(torch.LongTensor(list(layer.bias.size())))
148 | param_weight_num += param_bias_num
149 | if layer.bias.requires_grad:
150 | totoal_trainable_params += param_bias_num
151 | total_params += param_weight_num
152 |
153 | new_layer = "{:5}| {:<15} | {:<15} {:<25} {:<15}".format(layer_index, layer_type, "", output_shape, param_weight_num)
154 | print(new_layer)
155 |
156 | else:
157 | # Layer Information
158 | layer = model_graph[key]
159 | layer_type = layer.__class__.__name__
160 |
161 | # add, sub, mul...,etc. (custom string)
162 | if layer_type == "str":
163 | # the key should be XXX_{idx_prevent_duplicate}
164 | tmp_key = key.split("_")
165 | tmp_key[-1] = "_{}".format(layer_index)
166 | tmp_key = "".join(tmp_key)
167 | layer_type = tmp_key
168 | else:
169 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
170 |
171 | topNames[key] = layer_type
172 |
173 | # Layer Bottoms
174 | bottoms = []
175 | for b_key in bottoms_graph[key]:
176 | bottom = topNames[b_key]
177 | bottoms.append(bottom)
178 |
179 | # Layer Output Shape
180 | if key in output_shape_graph:
181 | output_shape = "[{}]".format(tuple(output_shape_graph[key]))
182 | else:
183 | output_shape = "None"
184 |
185 | # Layer Params
186 | param_weight_num = 0
187 | if hasattr(layer, "weight") and hasattr(layer.weight, "size"):
188 | param_weight_num += torch.prod(torch.LongTensor(list(layer.weight.size())))
189 | if layer.weight.requires_grad:
190 | totoal_trainable_params += param_weight_num
191 | if hasattr(layer, "bias") and hasattr(layer.bias, "size"):
192 | param_bias_num = torch.prod(torch.LongTensor(list(layer.bias.size())))
193 | param_weight_num += param_bias_num
194 | if layer.bias.requires_grad:
195 | totoal_trainable_params += param_bias_num
196 | total_params += param_weight_num
197 | # Print (one bottom a line)
198 | for idx, b in enumerate(bottoms):
199 | # if more than one bottom, only print bottom
200 | if idx == 0:
201 | new_layer = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format(layer_index, layer_type, b, output_shape, param_weight_num)
202 | else:
203 | new_layer = "{:>5}| {:<15} | {:<15} {:<25} {:<15}".format("", "", b, "", "")
204 | print(new_layer)
205 | print("---------------------------------------------------------------------------")
206 |
207 |
208 | # total information
209 | print("==================================================================================")
210 | print("Total Trainable params: {} ".format(totoal_trainable_params))
211 | print("Total Non-Trainable params: {} ".format(total_params - totoal_trainable_params))
212 | print("Total params: {} ".format(total_params))
213 |
214 | # del model_graph, bottoms_graph, output_shape_graph, topNames
215 | # return model
216 |
217 | def visualize(self, model = None, input_tensor = None, save_name = None, graph_size = 30):
218 | """!
219 | This function visualizes the model architecture
220 |
221 | @param model: input model to visualize
222 |
223 | @param input_tensor: input data of the model to forward
224 |
225 | @param save_name: if save_name is not None, it will save as '{save_name}.png'
226 |
227 | @param graph_size: graph_size for graphviz, to help increase the resolution of the output graph
228 |
229 | @return dot, graphviz's Digraph element
230 | """
231 | # input_tensor = torch.randn([1, 3, 224, 224])
232 | # model_graph = self.log.getGraph()
233 |
234 | # if graph empty
235 | if model is None:
236 | # check if use self modules
237 | if len(self._modules) > 0:
238 | self._build_graph(self, input_tensor)
239 | else:
240 | raise ValueError("Please input model to visualize")
241 | else:
242 | self._build_graph(model, input_tensor)
243 |
244 | # graph
245 | node_attr = dict(style='filled',
246 | shape='box',
247 | align='left',
248 | fontsize='30',
249 | ranksep='0.1',
250 | height='0.2')
251 |
252 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="{},{}".format(graph_size, graph_size)))
253 |
254 | # get dicts and variables
255 | model_graph = self.log.getGraph()
256 | bottoms_graph = self.log.getBottoms()
257 | output_shape_graph = self.log.getOutShapes()
258 | topNames = OrderedDict()
259 |
260 | for layer_index, key in enumerate(model_graph):
261 | # Input Data layer
262 | if bottoms_graph[key] is None:
263 | layer = model_graph[key]
264 | layer_type = layer.__class__.__name__
265 | # add, sub, mul...,etc. (custom string)
266 | if layer_type == "str":
267 | layer_type = key
268 | else:
269 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
270 |
271 | output_shape = "{}".format(tuple(output_shape_graph[key]))
272 | topNames[key] = layer_type
273 | output_shape = "[{}]".format(tuple(output_shape_graph[key]))
274 | layer_type = layer_type + "\nShape: " + output_shape
275 |
276 | dot.node(str(key), layer_type, fillcolor='orange')
277 | else:
278 | # Layer Information
279 | layer = model_graph[key]
280 | layer_type = layer.__class__.__name__
281 | # add, sub, mul...,etc. (custom string)
282 | if layer_type == "str":
283 | # the key should be XXX_{idx_prevent_duplicate}
284 | tmp_key = key.split("_")
285 | tmp_key[-1] = "_{}".format(layer_index)
286 | tmp_key = "".join(tmp_key)
287 | layer_type = tmp_key
288 | else:
289 | layer_type = layer.__class__.__name__ + "_{}".format(layer_index)
290 |
291 | topNames[key] = layer_type
292 | # layer_type = layer_type
293 | # print("Layer: {}".format(layer_type))
294 | # print("Key: {}".format(key))
295 | # add bottoms
296 |
297 | layer_type = layer_type + "\nBottoms: "
298 | for b_key in bottoms_graph[key]:
299 | layer_type = layer_type + topNames[b_key] + "\n"
300 |
301 | output_shape = "[{}]".format(tuple(output_shape_graph[key]))
302 | layer_type = layer_type + "Shape: " + output_shape
303 |
304 | dot.node(str(key), layer_type, fillcolor='orange')
305 | # link bottoms
306 | # print("Bottoms: ")
307 | for bot_key in bottoms_graph[key]:
308 | # print(bot_key)
309 | dot.edge(str(bot_key), str(key))
310 |
311 | # return graph
312 | if save_name is not None:
313 | (graph,) = pydot.graph_from_dot_data(dot.source)
314 | graph.write_png(save_name + ".png" )
315 | return dot
316 |
317 | def _build_graph(self, model, input_tensor = None):
318 |
319 | if input_tensor is None:
320 | raise ValueError("Please set input tensor")
321 |
322 | # reset log
323 | self.log = Log()
324 | # add Data input
325 | self.log.setTensor(input_tensor)
326 |
327 |
328 | tmp_model = self._trans_unit(copy.deepcopy(model))
329 |
330 | for f in dir(torch):
331 |
332 | # if private function, pass
333 | if f.startswith("_") or "tensor" == f:
334 | continue
335 | if isinstance(getattr(torch, f) ,types.BuiltinMethodType) or isinstance(getattr(torch, f) ,types.BuiltinFunctionType):
336 | self._raw_TrochFuncs[f] = getattr(torch, f)
337 | setattr(torch, f, _ReplaceFunc(getattr(torch,f), self._torchFunctions))
338 |
339 | for f in dir(F):
340 | # if private function, pass
341 | if f.startswith("_"):
342 | continue
343 |
344 | if isinstance(getattr(F, f) ,types.BuiltinMethodType) or isinstance(getattr(F, f) ,types.BuiltinFunctionType) or isinstance(getattr(F, f) ,types.FunctionType):
345 | self._raw_TrochFunctionals[f] = getattr(F, f)
346 | setattr(F, f, _ReplaceFunc(getattr(F,f), self._torchFunctionals))
347 |
348 |
349 | self.log = tmp_model.forward(self.log)
350 |
351 | # reset back
352 | for f in self._raw_TrochFuncs:
353 | setattr(torch, f, self._raw_TrochFuncs[f])
354 |
355 | for f in self._raw_TrochFunctionals:
356 | setattr(F, f, self._raw_TrochFunctionals[f])
357 |
358 | del tmp_model
359 |
360 | def _trans_unit(self, model):
361 | # print("TRNS_UNIT")
362 | for module_name in model._modules:
363 | # has children
364 | if len(model._modules[module_name]._modules) > 0:
365 | self._trans_unit(model._modules[module_name])
366 | else:
367 | unitlayer = UnitLayer(getattr(model, module_name))
368 | setattr(model, module_name, unitlayer)
369 |
370 | return model
371 |
372 | def _torchFunctions(self, raw_func, *args, **kwargs):
373 | """!
374 | Calls to the replaced torch functions (e.g. torch.{function}) are routed here
375 | """
376 | # print("Torch function")
377 | function_name = raw_func.__name__
378 |
379 | # a torch function may be called without positional inputs,
380 | # so check first
381 |
382 | if len(args) > 0:
383 | logs = args[0]
384 | cur_args = args[1:]
385 | elif len(kwargs) > 0:
386 |
387 | return raw_func(**kwargs)
388 | else:
389 | return raw_func()
390 |
391 | # check whether the call comes from user code (Log input) or from inside a torch function (tensor input)
392 | is_tensor_in = False
393 | # tensor input
394 | # multi tensor input
395 | if isinstance(logs, tuple) and (type(logs[0]) == torch.Tensor):
396 | cur_inputs = logs
397 | is_tensor_in = True
398 | return raw_func(*args, **kwargs)
399 | # single tensor input
400 | elif (type(logs) == torch.Tensor):
401 |
402 | cur_inputs = logs
403 | is_tensor_in = True
404 | # print(*args)
405 | # print(**kwargs)
406 | return raw_func(*args, **kwargs)
407 | elif (type(logs) == nn.Parameter):
408 | cur_inputs = logs
409 | is_tensor_in = True
410 | return raw_func(*args, **kwargs)
411 | # log input
412 | else:
413 | # multi inputs
414 | bottoms = []
415 | cur_inputs = []
416 |
417 | if isinstance(logs, tuple) or isinstance(logs, list):
418 | # may use origin input log as others' input
419 | # eg: densenet in torchvision 0.4.0
420 | cur_log = copy.deepcopy(logs[0])
421 | for log in logs:
422 | cur_inputs.append(log.cur_tensor)
423 | # print(log.cur_tensor.size())
424 | bottoms.append(log.cur_id)
425 | # update informations
426 | cur_log.graph.update(log.graph)
427 | cur_log.bottoms.update(log.bottoms)
428 | cur_log.output_shape.update(log.output_shape)
429 | cur_inputs = tuple(cur_inputs)
430 | # one input
431 | else:
432 | # print(args)
433 | # print(kwargs)
434 | cur_log = logs
435 | cur_inputs = cur_log.cur_tensor
436 | bottoms.append(cur_log.cur_id)
437 |
438 | # replace logs to tensor as function inputs to get output tensor
439 | args = list(args)
440 | args[0] = cur_inputs
441 | args = tuple(args)
442 | # send into origin functions
443 | #out_tensor = raw_func(*args, **kwargs).clone().detach()
444 | out_tensor = raw_func(*args, **kwargs).clone()
445 |
446 | # if function call, just return out tensor
447 | if is_tensor_in:
448 | return out_tensor
449 |
450 | # most multi-input operations produce one output,
451 | # most multi-output operations take one input,
452 | # and some operations change the shape;
453 | # store these types of operations as a layer
454 | if isinstance(logs, tuple) or isinstance(logs, list) or isinstance(out_tensor, tuple) or (logs.cur_tensor.size() != out_tensor.size()):
455 | layer_name = "torch.{}_{}".format(function_name, len(cur_log.graph))
456 | cur_log.graph[layer_name] = layer_name
457 | cur_log.bottoms[layer_name] = bottoms
458 | cur_log.cur_id = layer_name
459 |
460 | # multi output
461 | if not isinstance(out_tensor , torch.Tensor):
462 | # print("multi output")
463 | out_logs = []
464 | for t in out_tensor:
465 | out_log = copy.deepcopy(cur_log)
466 | out_log.setTensor(t)
467 | out_logs.append(out_log)
468 |
469 | # sometimes the output is a one-element tuple (out,); unwrap it
470 | if len(out_logs) == 1:
471 | out_logs = out_logs[0]
472 | return out_logs
473 |
474 | else:
475 | cur_log.setTensor(out_tensor)
476 | return cur_log
477 |
478 | # torch.functionals
479 | def _torchFunctionals(self, raw_func, *args, **kwargs):
480 | """!
481 | Calls to the replaced torch.nn.functional functions (e.g. F.{function}) are routed here
482 | """
483 | # print("Functional")
484 | function_name = raw_func.__name__
485 | # print(raw_func.__name__)
486 |
487 | # every functional takes an input except affine_grid
488 | if function_name == "affine_grid":
489 | pass
490 | else:
491 | logs = args[0]
492 | cur_args = args[1:]
493 |
494 | # check whether the call comes from user code (Log input) or from inside a torch function (tensor input)
495 | is_tensor_in = False
496 | # tensor input
497 | if (len(logs) > 1) and (type(logs[0]) == torch.Tensor):
498 | # print(logs[0].size(), logs[1].size())
499 | cur_inputs = logs
500 | is_tensor_in = True
501 | out = raw_func(*args, **kwargs)
502 | # print("Functional return : {}".format(out.size()))
503 | return out
504 |
505 | elif (len(logs) ==1) and (type(logs) == torch.Tensor):
506 | cur_inputs = logs
507 | is_tensor_in = True
508 | out = raw_func(*args, **kwargs)
509 | # print("Functional return : {}".format(out.size()))
510 | return out
511 |
512 | # log input
513 | else:
514 | # multi inputs
515 | bottoms = []
516 | cur_inputs = []
517 | if len(logs) > 1:
518 | cur_log = logs[0]
519 | for log in logs:
520 | cur_inputs.append(log.cur_tensor)
521 | bottoms.append(log.cur_id)
522 | # update informations
523 | cur_log.graph.update(log.graph)
524 | cur_log.bottoms.update(log.bottoms)
525 | cur_log.output_shape.update(log.output_shape)
526 | cur_inputs = tuple(cur_inputs)
527 | # one input
528 | else:
529 | cur_log = logs
530 | cur_inputs = cur_log.cur_tensor
531 | bottoms.append(cur_log.cur_id)
532 |
533 |
534 |
535 | # replace logs to tensor as function inputs to get output tensor
536 | args = list(args)
537 | args[0] = cur_inputs
538 | args = tuple(args)
539 | # send into origin functions
540 | #out_tensor = raw_func(*args, **kwargs).clone().detach()
541 | out_tensor = raw_func(*args, **kwargs).clone()
542 |
543 | # if function call, just return out tensor
544 | if is_tensor_in:
545 | return out_tensor
546 |
547 | # if the input is a log and raw_func is a plain Python function, store it as a layer
548 | if isinstance(raw_func, types.FunctionType):
549 | # combine several object ids in the name to prevent duplicate layer names
550 | layer_name = "F.{}_{}{}{}".format(function_name, id(out_tensor), id(args), id(kwargs))
551 | # generate a new name if it is still a duplicate
552 | while layer_name in cur_log.graph:
553 | #if layer_name in cur_log.graph:
554 | # tmp_list = []
555 | # tmp_list.append(out_tensor)
556 | # tmp_tensor = copy.deepcopy(tmp_list[-1])
557 | # tmp_tensor = tmp_list[-1].clone()
558 | tmp_tensor = torch.tensor([0])
559 |
560 | # should not duplicate again?
561 | # layer_name = layer_name.split('.')[0] + "F" + ".{}_{}{}{}".format(function_name, id(tmp_tensor), id(args), id(kwargs))
562 | layer_name = "F.{}_{}{}{}{}".format(function_name, id(tmp_tensor), id(args), id(kwargs), int((time.time()*100000)%1000000))
563 |
564 | cur_log.graph[layer_name] = layer_name
565 | cur_log.bottoms[layer_name] = bottoms
566 | cur_log.cur_id = layer_name
567 |
568 | # if multi-output
569 | # if len(out_tensor) > 1:
570 | if not isinstance(out_tensor, torch.Tensor):
571 | out_logs = []
572 | for t in out_tensor:
573 | out_log = copy.deepcopy(cur_log)
574 | out_log.setTensor(t)
575 | out_logs.append(out_log)
576 |
577 | return out_logs
578 | else:
579 | cur_log.setTensor(out_tensor)
580 | return cur_log
581 |
582 |
583 |
--------------------------------------------------------------------------------
/transformers/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from collections import OrderedDict
6 |
7 | class _ReplaceFunc(object):
8 | """!
9 | This class replaces torch functions with a self-defined function
10 | in order to collect the layer information of the torch model.
11 | """
12 | def __init__(self, ori_func, replace_func, **kwargs):
13 | self.torch_func = ori_func
14 | self.replace_func = replace_func
15 |
16 | def __call__(self, *args, **kwargs):
17 | out = self.replace_func(self.torch_func, *args, **kwargs)
18 | return out
19 |
20 |
21 | class Log(object):
22 | """!
23 | This class is used as a log that replaces the input tensor and stores all the information
24 | """
25 | def __init__(self):
26 | self.graph = OrderedDict()
27 | self.bottoms = OrderedDict()
28 | self.output_shape = OrderedDict()
29 | self.cur_tensor = None
30 | self.cur_id = None
31 | self.tmp_list = None
32 | self.log_init()
33 |
34 | def __len__(self):
35 | """!
36 | The length of a Log is always one
37 | """
38 | return 1
39 |
40 | def __copy__(self):
41 | """!
42 | copy: create a new Log and assign it a clone of the current tensor
43 | """
44 | copy_paster = Log()
45 | copy_paster.__dict__.update(self.__dict__)
46 | copy_paster.cur_tensor = self.cur_tensor.clone()
47 | return copy_paster
48 |
49 | def __deepcopy__(self, memo):
50 | """!
51 | deepcopy: create a new Log and assign it a clone of the current tensor
52 | """
53 | copy_paster = Log()
54 | copy_paster.__dict__.update(self.__dict__)
55 | copy_paster.cur_tensor = self.cur_tensor.clone()
56 | return copy_paster
57 |
58 | def reset(self):
59 | """
60 | This function resets all attributes in the log.
61 | """
62 | self.graph = OrderedDict()
63 | self.bottoms = OrderedDict()
64 | self.output_shape = OrderedDict()
65 | self.cur_tensor = None
66 | self.cur_id = None
67 | self.tmp_list = []
68 | self.log_init()
69 |
70 |
71 | # add data input layer to log
72 | def log_init(self):
73 | """
74 | Initialize the log attributes and set the Data layer as the first layer
75 | """
76 | layer_id = "Data"
77 | self.graph[layer_id] = layer_id
78 | self.bottoms[layer_id] = None
79 | self.output_shape[layer_id] = ""
80 | self.cur_id = layer_id
81 | self.tmp_list = []
82 |
83 |
84 | # for general layers (should have only one input?)
85 | def putLayer(self, layer):
86 | """!
87 | Put a general layer's information into the log
88 | """
89 | # force a distinct id (handles the same layer instance being used more than once, e.g. Bottleneck in torchvision)
90 | # tmp_layer = copy.deepcopy(layer)
91 | layer_id = id(layer)
92 | self.tmp_list.append(layer)
93 | layer_id = id(self.tmp_list[-1])
94 | if layer_id in self.graph:
95 | tmp_layer = copy.deepcopy(layer)
96 | self.tmp_list.append(tmp_layer)
97 | # layer_id = id(self.tmp_list[-1])
98 | layer_id = id(tmp_layer)
99 |
100 | self.graph[layer_id] = layer
101 | self.bottoms[layer_id] = [self.cur_id]
102 | self.cur_id = layer_id
103 | # del layer, tmp_layer, layer_id
104 |
105 | def getGraph(self):
106 | """!
107 | This function gets the layer graph from the log
108 | """
109 | return self.graph
110 |
111 | def getBottoms(self):
112 | """!
113 | This function gets the layers' bottoms from the log
114 | """
115 | return self.bottoms
116 |
117 | def getOutShapes(self):
118 | """!
119 | This function gets the layers' output shapes from the log
120 | """
121 | return self.output_shape
122 |
123 | def getTensor(self):
124 | """!
125 | This function gets the current tensor (the latest output tensor)
126 | """
127 | return self.cur_tensor
128 |
129 | def setTensor(self, tensor):
130 | """!
131 | This function sets the current tensor
132 | and also updates the output shape of the current layer from that tensor
133 | """
134 | self.cur_tensor = tensor
135 | if tensor is not None:
136 | self.output_shape[self.cur_id] = self.cur_tensor.size()
137 | else:
138 | self.output_shape[self.cur_id] = None
139 |
140 |
141 | # handle tensor operation(eg: tensor.view)
142 | def __getattr__(self, name):
143 | """!
144 | This function handles all tensor operations by forwarding them to the current tensor
145 | """
146 | if name == "__deepcopy__" or name == "__setstate__":
147 | return object.__getattribute__(self, name)
148 | # if get data => get cur_tensor.data
149 | elif name == "data":
150 | return self.cur_tensor.data
151 |
152 | elif hasattr(self.cur_tensor, name):
153 | def wrapper(*args, **kwargs):
154 | func = self.cur_tensor.__getattribute__(name)
155 | out_tensor = func(*args, **kwargs)
156 |
157 | if not isinstance(out_tensor, torch.Tensor):
158 | out_logs = []
159 | for t in out_tensor:
160 | out_log = copy.deepcopy(self)
161 | out_log.setTensor(t)
162 | out_logs.append(out_log)
163 |
164 | return out_logs
165 | else:
166 | self.cur_tensor = out_tensor
167 | self.output_shape[self.cur_id] = out_tensor.size()
168 |
169 | return self
170 | # print(wrapper)
171 | return wrapper
172 |
173 | # return self
174 |
175 |
176 | else:
177 | return object.__getattribute__(self, name)
178 |
179 |
180 | def __add__(self, other):
181 | """!
182 | Log addition
183 | """
184 | #print("add")
185 | # merge other branch
186 | self.graph.update(other.graph)
187 | self.bottoms.update(other.bottoms)
188 | self.output_shape.update(other.output_shape)
189 | layer_name = "add_{}".format(len(self.graph))
190 | self.graph[layer_name] = layer_name
191 | self.bottoms[layer_name] = [self.cur_id, other.cur_id]
192 | self.output_shape[layer_name] = self.cur_tensor.size()
193 | self.cur_id = layer_name
194 | # save memory
195 | del other
196 |
197 | return self
198 |
199 |
200 | def __iadd__(self, other):
201 | """!
202 | Log in-place addition
203 | """
204 | #print("iadd")
205 | # merge other branch
206 | self.graph.update(other.graph)
207 | self.bottoms.update(other.bottoms)
208 | self.output_shape.update(other.output_shape)
209 | layer_name = "iadd_{}".format(len(self.graph))
210 | self.graph[layer_name] = layer_name
211 | self.bottoms[layer_name] = [self.cur_id, other.cur_id]
212 | self.output_shape[layer_name] = self.cur_tensor.size()
213 | self.cur_id = layer_name
214 | # save memory
215 | del other
216 | return self
217 |
218 |
219 | def __sub__(self, other):
220 | """!
221 | Log subtraction
222 | """
223 | #print("sub")
224 | # merge other branch
225 | self.graph.update(other.graph)
226 | self.bottoms.update(other.bottoms)
227 | self.output_shape.update(other.output_shape)
228 | layer_name = "sub_{}".format(len(self.graph))
229 | self.graph[layer_name] = layer_name
230 | self.bottoms[layer_name] = [self.cur_id, other.cur_id]
231 | self.output_shape[layer_name] = self.cur_tensor.size()
232 | self.cur_id = layer_name
233 | # save memory
234 | del other
235 | return self
236 |
237 |
238 | def __isub__(self, other):
239 | """!
240 | Log in-place subtraction
241 | """
242 | #print("isub")
243 | # merge other branch
244 | self.graph.update(other.graph)
245 | self.bottoms.update(other.bottoms)
246 | self.output_shape.update(other.output_shape)
247 | layer_name = "isub_{}".format(len(self.graph))
248 | self.graph[layer_name] = layer_name
249 | self.bottoms[layer_name] = [self.cur_id, other.cur_id]
250 | self.output_shape[layer_name] = self.cur_tensor.size()
251 | self.cur_id = layer_name
252 | # save memory
253 | del other
254 | return self
255 |
256 |
257 | def __mul__(self, other):
258 | """!
259 | Log multiplication
260 | """
261 | #print("mul")
262 | # merge other branch
263 | self.graph.update(other.graph)
264 | self.bottoms.update(other.bottoms)
265 | self.output_shape.update(other.output_shape)
266 | layer_name = "mul_{}".format(len(self.graph))
267 | self.graph[layer_name] = layer_name
268 | self.bottoms[layer_name] = [self.cur_id, other.cur_id]
269 | self.output_shape[layer_name] = self.cur_tensor.size()
270 | self.cur_id = layer_name
271 | # save memory
272 | del other
273 | return self
274 |
275 |
276 | def __imul__(self, other):
277 | """!
278 | Log in-place multiplication
279 | """
280 | #print("imul")
281 | # merge other branch
282 | self.graph.update(other.graph)
283 | self.bottoms.update(other.bottoms)
284 | self.output_shape.update(other.output_shape)
285 | layer_name = "imul_{}".format(len(self.graph))
286 | self.graph[layer_name] = layer_name
287 | self.bottoms[layer_name] = [self.cur_id, other.cur_id]
288 | self.output_shape[layer_name] = self.cur_tensor.size()
289 | self.cur_id = layer_name
290 | # save memory
291 | del other
292 | return self
293 |
294 |
295 | def size(self, dim=None):
296 | """!
297 | This function returns the size of the current tensor, optionally along a given dim
298 |
299 | @param dim: default None; if given, returns tensor.size(dim), otherwise the full size
300 |
301 | @return tensor size
302 | """
303 | return self.cur_tensor.size(dim) if dim is not None else self.cur_tensor.size()
304 |
305 |
306 |
307 | class UnitLayer(nn.Module):
308 | """!
309 | This class is a unit layer that acts like an identity layer
310 | """
311 | def __init__(self, ori_layer):
312 | super(UnitLayer, self).__init__()
313 | self.origin_layer = ori_layer
314 |
315 |
316 | def setOrigin(self, ori_layer):
317 | self.origin_layer = ori_layer
318 |
319 |
320 | # a general layer should have only one input?
321 | def forward(self, log, *args):
322 | # copy the log so that other forward branches are not overwritten
323 | cur_log = copy.deepcopy(log)
324 | # print(cur_log)
325 | cur_log.putLayer(self.origin_layer)
326 |
327 | # print(log.cur_tensor)
328 | log_tensor = log.getTensor()
329 | # out_tensor = self.origin_layer(log_tensor).clone().detach()
330 | out_tensor = self.origin_layer(log_tensor).clone()
331 | cur_log.setTensor(out_tensor)
332 |
333 | return cur_log
--------------------------------------------------------------------------------
/visualize_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torchvision\n",
12 | "import torchvision.models as models\n",
13 | "\n",
14 | "from transformers.torchTransformer import TorchTransformer"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {
21 | "scrolled": true
22 | },
23 | "outputs": [
24 | {
25 | "data": {
26 | "text/plain": [
27 | "AlexNet(\n",
28 | " (features): Sequential(\n",
29 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n",
30 | " (1): ReLU(inplace=True)\n",
31 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
32 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
33 | " (4): ReLU(inplace=True)\n",
34 | " (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
35 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
36 | " (7): ReLU(inplace=True)\n",
37 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
38 | " (9): ReLU(inplace=True)\n",
39 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
40 | " (11): ReLU(inplace=True)\n",
41 | " (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
42 | " )\n",
43 | " (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n",
44 | " (classifier): Sequential(\n",
45 | " (0): Dropout(p=0.5, inplace=False)\n",
46 | " (1): Linear(in_features=9216, out_features=4096, bias=True)\n",
47 | " (2): ReLU(inplace=True)\n",
48 | " (3): Dropout(p=0.5, inplace=False)\n",
49 | " (4): Linear(in_features=4096, out_features=4096, bias=True)\n",
50 | " (5): ReLU(inplace=True)\n",
51 | " (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
52 | " )\n",
53 | ")"
54 | ]
55 | },
56 | "execution_count": 2,
57 | "metadata": {},
58 | "output_type": "execute_result"
59 | }
60 | ],
61 | "source": [
62 | "model_name = \"alexnet\"\n",
63 | "model = models.__dict__[model_name]()\n",
64 | "model.eval()"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "input_tensor = torch.randn([1, 3, 224, 224])"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "## visualization"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "### without saving image"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 4,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "transformer = TorchTransformer()"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 5,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "dot = transformer.visualize(model, input_tensor = input_tensor)"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 6,
111 | "metadata": {
112 | "scrolled": true
113 | },
114 | "outputs": [
115 | {
116 | "data": {
447 | "text/plain": [
448 | ""
449 | ]
450 | },
451 | "execution_count": 6,
452 | "metadata": {},
453 | "output_type": "execute_result"
454 | }
455 | ],
456 | "source": [
457 | "dot"
458 | ]
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {},
463 | "source": [
464 | "### save image"
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 7,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "model.cuda()\n",
474 | "input_tensor = input_tensor.cuda()"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 8,
480 | "metadata": {},
481 | "outputs": [],
482 | "source": [
483 | "dot = transformer.visualize(model, input_tensor = input_tensor, save_name = \"examples/{}\".format(model_name), graph_size = 80)"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": null,
489 | "metadata": {},
490 | "outputs": [],
491 | "source": []
492 | }
493 | ],
494 | "metadata": {
495 | "kernelspec": {
496 | "display_name": "Python 3",
497 | "language": "python",
498 | "name": "python3"
499 | },
500 | "language_info": {
501 | "codemirror_mode": {
502 | "name": "ipython",
503 | "version": 3
504 | },
505 | "file_extension": ".py",
506 | "mimetype": "text/x-python",
507 | "name": "python",
508 | "nbconvert_exporter": "python",
509 | "pygments_lexer": "ipython3",
510 | "version": "3.6.9"
511 | }
512 | },
513 | "nbformat": 4,
514 | "nbformat_minor": 2
515 | }
516 |
--------------------------------------------------------------------------------