├── LICENSE ├── README.md ├── examples ├── CNN.py └── Transformer.py ├── modelsummary ├── __init__.py ├── hierarchicalsummary.py └── modelsummary.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tae Hwan Jung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## modelsummary (PyTorch Model Summary) 2 | 3 | > Keras-style model.summary() in PyTorch, based on [torchsummary](https://github.com/sksq96/pytorch-summary) 4 | 5 | This is a PyTorch library for model visualization, an improved tool built on [torchsummary](https://github.com/sksq96/pytorch-summary) and [torchsummaryX](https://github.com/nmhkahn/torchsummaryX). It was inspired by [torchsummary](https://github.com/sksq96/pytorch-summary), and the code it draws on is credited in the Reference section.
**It does not care about the number of input parameters!** 6 | 7 | ```python 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from modelsummary import summary 13 | 14 | class Net(nn.Module): 15 | def __init__(self): 16 | super(Net, self).__init__() 17 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 18 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 19 | self.conv2_drop = nn.Dropout2d() 20 | self.fc1 = nn.Linear(320, 50) 21 | self.fc2 = nn.Linear(50, 10) 22 | 23 | def forward(self, x): 24 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 25 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 26 | x = x.view(-1, 320) 27 | x = F.relu(self.fc1(x)) 28 | x = F.dropout(x, training=self.training) 29 | x = self.fc2(x) 30 | return F.log_softmax(x, dim=1) 31 | 32 | # show input shape 33 | summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True) 34 | 35 | # show output shape 36 | summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False) 37 | ``` 38 | 39 | ``` 40 | ----------------------------------------------------------------------- 41 | Layer (type) Input Shape Param # 42 | ======================================================================= 43 | Conv2d-1 [-1, 1, 28, 28] 260 44 | Conv2d-2 [-1, 10, 12, 12] 5,020 45 | Dropout2d-3 [-1, 20, 8, 8] 0 46 | Linear-4 [-1, 320] 16,050 47 | Linear-5 [-1, 50] 510 48 | ======================================================================= 49 | Total params: 21,840 50 | Trainable params: 21,840 51 | Non-trainable params: 0 52 | ----------------------------------------------------------------------- 53 | 54 | ----------------------------------------------------------------------- 55 | Layer (type) Output Shape Param # 56 | ======================================================================= 57 | Conv2d-1 [-1, 10, 24, 24] 260 58 | Conv2d-2 [-1, 20, 8, 8] 5,020 59 | Dropout2d-3 [-1, 20, 8, 8] 0 60 | Linear-4 [-1, 50] 16,050 61 | Linear-5 [-1, 10] 510 62 | ======================================================================= 63 | Total params: 21,840 64 | Trainable params: 21,840 65 | Non-trainable params: 0 66 | ----------------------------------------------------------------------- 67 | ``` 68 | 69 | 70 | 71 | ## Quick Start 72 | 73 | Just install `modelsummary` with pip: 74 | 75 | `pip install modelsummary`, then `from modelsummary import summary` 76 |
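Because `summary` takes a variadic `*inputs`, a model whose forward method takes several tensors works exactly the same way. Here is a minimal sketch (this `TwoInputNet` module is a hypothetical example for illustration, not part of this repository):

```python
import torch
import torch.nn as nn

from modelsummary import summary

class TwoInputNet(nn.Module):
    def __init__(self):
        super(TwoInputNet, self).__init__()
        # two independent branches, one per input tensor
        self.fc_a = nn.Linear(8, 4)
        self.fc_b = nn.Linear(8, 4)

    def forward(self, a, b):
        # process each input separately, then concatenate the results
        return torch.cat([self.fc_a(a), self.fc_b(b)], dim=1)

# pass both input tensors positionally, in the order forward expects them
summary(TwoInputNet(), torch.zeros(1, 8), torch.zeros(1, 8), show_input=True)
```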
77 | You can use this library as shown below. For more detail, please see the example code. 78 | 79 | ``` 80 | from modelsummary import summary 81 | 82 | model = your_model_name() 83 | 84 | # show input shape 85 | summary(model, (input tensor you want), show_input=True) 86 | 87 | # show output shape 88 | summary(model, (input tensor you want), show_input=False) 89 | 90 | # show hierarchical struct 91 | summary(model, (input tensor you want), show_hierarchical=True) 92 | ``` 93 | 94 | 95 | 96 | The summary function has the following signature: `def summary(model, *inputs, batch_size=-1, show_input=True, show_hierarchical=False)` 97 | 98 | #### Options 99 | 100 | - model : your model instance 101 | - *inputs : your input tensors (variadic; pass as many tensors as your model's forward takes) 102 | - batch_size : `-1` leaves the batch dimension unspecified, like `None` in a Keras shape 103 | - show_input : show input shapes, **if this parameter is False, output shapes are shown instead**, **default : True** 104 | - show_hierarchical : also print the hierarchical module structure, **default : False** 105 | 106 | 107 | 108 | ## Result 109 | 110 | Run the example using the Transformer model from the paper [Attention Is All You Need (2017)](https://arxiv.org/abs/1706.03762) 111 | 112 | 1) showing input shape 113 | 114 | ``` 115 | # show input shape 116 | summary(model, enc_inputs, dec_inputs, show_input=True) 117 | 118 | ----------------------------------------------------------------------- 119 | Layer (type) Input Shape Param # 120 | ======================================================================= 121 | Encoder-1 [-1, 5] 0 122 | Embedding-2 [-1, 5] 3,072 123 | Embedding-3 [-1, 5] 3,072 124 | EncoderLayer-4 [-1, 5, 512] 0 125 | MultiHeadAttention-5 [-1, 5, 512] 0 126 | Linear-6 [-1, 5, 512] 262,656 127 | Linear-7 [-1, 5, 512] 262,656 128 | Linear-8 [-1, 5, 512] 262,656 129 | PoswiseFeedForwardNet-9 [-1, 5, 512] 0 130 | Conv1d-10 [-1, 512, 5] 1,050,624 131 | Conv1d-11 [-1, 2048, 5] 1,049,088 132 | EncoderLayer-12 [-1, 5, 512] 0 133 | MultiHeadAttention-13 [-1, 5, 512] 0 134 | Linear-14 [-1, 5, 512] 262,656 135 | Linear-15 [-1, 5, 512] 262,656 136 | Linear-16 [-1, 5, 512] 262,656 137 | PoswiseFeedForwardNet-17 [-1, 5, 512] 0 138 | Conv1d-18 [-1, 512, 5] 1,050,624 139 | Conv1d-19 [-1, 2048, 5] 1,049,088 140 | EncoderLayer-20 [-1, 5, 512] 0 141 | MultiHeadAttention-21 [-1, 5, 512] 0 142 | Linear-22 [-1, 5, 512] 262,656 143 | Linear-23 [-1, 5, 512] 262,656 144 | Linear-24 [-1, 5, 512] 262,656 145 | PoswiseFeedForwardNet-25 [-1, 5, 512] 0 146 | Conv1d-26 [-1, 512, 5] 1,050,624 147 | Conv1d-27 [-1, 2048, 5] 1,049,088 148 | EncoderLayer-28 [-1, 5, 512] 0 149 | MultiHeadAttention-29 [-1, 5, 512] 0 150 | Linear-30 [-1, 5, 512] 262,656 151 | Linear-31 [-1, 5, 512] 262,656 152 | Linear-32 [-1, 5, 512] 262,656 153 | PoswiseFeedForwardNet-33 [-1, 5, 512] 0 154 | Conv1d-34 [-1, 512, 5] 1,050,624 155 | Conv1d-35 [-1, 2048, 5] 1,049,088 156 | EncoderLayer-36 [-1, 5, 512] 0 157 | MultiHeadAttention-37 [-1, 5, 512] 0 158 | Linear-38 [-1, 5, 512] 262,656 159 | Linear-39 [-1, 5, 512] 262,656 160 | Linear-40 [-1, 5, 512] 262,656 161 | PoswiseFeedForwardNet-41 [-1, 5, 512] 0 162 | Conv1d-42 [-1, 512, 5] 1,050,624 163 | Conv1d-43 [-1, 2048, 5] 1,049,088 164 | EncoderLayer-44 [-1, 5, 512] 0 165 | MultiHeadAttention-45 [-1, 5, 512] 0 166 | Linear-46 [-1, 5, 512] 262,656 167 | Linear-47 [-1, 5, 512] 262,656 168 | Linear-48 [-1, 5, 512] 262,656 169 | PoswiseFeedForwardNet-49 [-1, 5, 512] 0 170 | Conv1d-50 [-1, 512, 5] 1,050,624 171 | Conv1d-51 [-1, 2048, 5] 1,049,088 172 | Decoder-52 [-1, 5] 0 173 | Embedding-53 [-1, 5] 3,584 174 | Embedding-54 [-1, 5] 3,072 175 | DecoderLayer-55 [-1, 5, 512] 0 176 | MultiHeadAttention-56 [-1, 5,
512] 0 177 | Linear-57 [-1, 5, 512] 262,656 178 | Linear-58 [-1, 5, 512] 262,656 179 | Linear-59 [-1, 5, 512] 262,656 180 | MultiHeadAttention-60 [-1, 5, 512] 0 181 | Linear-61 [-1, 5, 512] 262,656 182 | Linear-62 [-1, 5, 512] 262,656 183 | Linear-63 [-1, 5, 512] 262,656 184 | PoswiseFeedForwardNet-64 [-1, 5, 512] 0 185 | Conv1d-65 [-1, 512, 5] 1,050,624 186 | Conv1d-66 [-1, 2048, 5] 1,049,088 187 | DecoderLayer-67 [-1, 5, 512] 0 188 | MultiHeadAttention-68 [-1, 5, 512] 0 189 | Linear-69 [-1, 5, 512] 262,656 190 | Linear-70 [-1, 5, 512] 262,656 191 | Linear-71 [-1, 5, 512] 262,656 192 | MultiHeadAttention-72 [-1, 5, 512] 0 193 | Linear-73 [-1, 5, 512] 262,656 194 | Linear-74 [-1, 5, 512] 262,656 195 | Linear-75 [-1, 5, 512] 262,656 196 | PoswiseFeedForwardNet-76 [-1, 5, 512] 0 197 | Conv1d-77 [-1, 512, 5] 1,050,624 198 | Conv1d-78 [-1, 2048, 5] 1,049,088 199 | DecoderLayer-79 [-1, 5, 512] 0 200 | MultiHeadAttention-80 [-1, 5, 512] 0 201 | Linear-81 [-1, 5, 512] 262,656 202 | Linear-82 [-1, 5, 512] 262,656 203 | Linear-83 [-1, 5, 512] 262,656 204 | MultiHeadAttention-84 [-1, 5, 512] 0 205 | Linear-85 [-1, 5, 512] 262,656 206 | Linear-86 [-1, 5, 512] 262,656 207 | Linear-87 [-1, 5, 512] 262,656 208 | PoswiseFeedForwardNet-88 [-1, 5, 512] 0 209 | Conv1d-89 [-1, 512, 5] 1,050,624 210 | Conv1d-90 [-1, 2048, 5] 1,049,088 211 | DecoderLayer-91 [-1, 5, 512] 0 212 | MultiHeadAttention-92 [-1, 5, 512] 0 213 | Linear-93 [-1, 5, 512] 262,656 214 | Linear-94 [-1, 5, 512] 262,656 215 | Linear-95 [-1, 5, 512] 262,656 216 | MultiHeadAttention-96 [-1, 5, 512] 0 217 | Linear-97 [-1, 5, 512] 262,656 218 | Linear-98 [-1, 5, 512] 262,656 219 | Linear-99 [-1, 5, 512] 262,656 220 | PoswiseFeedForwardNet-100 [-1, 5, 512] 0 221 | Conv1d-101 [-1, 512, 5] 1,050,624 222 | Conv1d-102 [-1, 2048, 5] 1,049,088 223 | DecoderLayer-103 [-1, 5, 512] 0 224 | MultiHeadAttention-104 [-1, 5, 512] 0 225 | Linear-105 [-1, 5, 512] 262,656 226 | Linear-106 [-1, 5, 512] 262,656 227 | Linear-107 [-1, 5, 512] 262,656 228 | MultiHeadAttention-108 [-1, 5, 512] 0 229 | Linear-109 [-1, 5, 512] 262,656 230 | Linear-110 [-1, 5, 512] 262,656 231 | Linear-111 [-1, 5, 512] 262,656 232 | PoswiseFeedForwardNet-112 [-1, 5, 512] 0 233 | Conv1d-113 [-1, 512, 5] 1,050,624 234 | Conv1d-114 [-1, 2048, 5] 1,049,088 235 | DecoderLayer-115 [-1, 5, 512] 0 236 | MultiHeadAttention-116 [-1, 5, 512] 0 237 | Linear-117 [-1, 5, 512] 262,656 238 | Linear-118 [-1, 5, 512] 262,656 239 | Linear-119 [-1, 5, 512] 262,656 240 | MultiHeadAttention-120 [-1, 5, 512] 0 241 | Linear-121 [-1, 5, 512] 262,656 242 | Linear-122 [-1, 5, 512] 262,656 243 | Linear-123 [-1, 5, 512] 262,656 244 | PoswiseFeedForwardNet-124 [-1, 5, 512] 0 245 | Conv1d-125 [-1, 512, 5] 1,050,624 246 | Conv1d-126 [-1, 2048, 5] 1,049,088 247 | Linear-127 [-1, 5, 512] 3,584 248 | ======================================================================= 249 | Total params: 39,396,352 250 | Trainable params: 39,390,208 251 | Non-trainable params: 6,144 252 | ``` 253 | 254 | 2) showing output shape 255 | 256 | ``` 257 | # show output shape 258 | summary(model, enc_inputs, dec_inputs, show_input=False) 259 | 260 | ----------------------------------------------------------------------- 261 | Layer (type) Output Shape Param # 262 | ======================================================================= 263 | Embedding-1 [-1, 5, 512] 3,072 264 | Embedding-2 [-1, 5, 512] 3,072 265 | Linear-3 [-1, 5, 512] 262,656 266 | Linear-4 [-1, 5, 512] 262,656 267 | Linear-5 [-1, 5, 512] 262,656 268 | MultiHeadAttention-6 [-1, 
8, 5, 5] 0 269 | Conv1d-7 [-1, 2048, 5] 1,050,624 270 | Conv1d-8 [-1, 512, 5] 1,049,088 271 | PoswiseFeedForwardNet-9 [-1, 5, 512] 0 272 | EncoderLayer-10 [-1, 8, 5, 5] 0 273 | Linear-11 [-1, 5, 512] 262,656 274 | Linear-12 [-1, 5, 512] 262,656 275 | Linear-13 [-1, 5, 512] 262,656 276 | MultiHeadAttention-14 [-1, 8, 5, 5] 0 277 | Conv1d-15 [-1, 2048, 5] 1,050,624 278 | Conv1d-16 [-1, 512, 5] 1,049,088 279 | PoswiseFeedForwardNet-17 [-1, 5, 512] 0 280 | EncoderLayer-18 [-1, 8, 5, 5] 0 281 | Linear-19 [-1, 5, 512] 262,656 282 | Linear-20 [-1, 5, 512] 262,656 283 | Linear-21 [-1, 5, 512] 262,656 284 | MultiHeadAttention-22 [-1, 8, 5, 5] 0 285 | Conv1d-23 [-1, 2048, 5] 1,050,624 286 | Conv1d-24 [-1, 512, 5] 1,049,088 287 | PoswiseFeedForwardNet-25 [-1, 5, 512] 0 288 | EncoderLayer-26 [-1, 8, 5, 5] 0 289 | Linear-27 [-1, 5, 512] 262,656 290 | Linear-28 [-1, 5, 512] 262,656 291 | Linear-29 [-1, 5, 512] 262,656 292 | MultiHeadAttention-30 [-1, 8, 5, 5] 0 293 | Conv1d-31 [-1, 2048, 5] 1,050,624 294 | Conv1d-32 [-1, 512, 5] 1,049,088 295 | PoswiseFeedForwardNet-33 [-1, 5, 512] 0 296 | EncoderLayer-34 [-1, 8, 5, 5] 0 297 | Linear-35 [-1, 5, 512] 262,656 298 | Linear-36 [-1, 5, 512] 262,656 299 | Linear-37 [-1, 5, 512] 262,656 300 | MultiHeadAttention-38 [-1, 8, 5, 5] 0 301 | Conv1d-39 [-1, 2048, 5] 1,050,624 302 | Conv1d-40 [-1, 512, 5] 1,049,088 303 | PoswiseFeedForwardNet-41 [-1, 5, 512] 0 304 | EncoderLayer-42 [-1, 8, 5, 5] 0 305 | Linear-43 [-1, 5, 512] 262,656 306 | Linear-44 [-1, 5, 512] 262,656 307 | Linear-45 [-1, 5, 512] 262,656 308 | MultiHeadAttention-46 [-1, 8, 5, 5] 0 309 | Conv1d-47 [-1, 2048, 5] 1,050,624 310 | Conv1d-48 [-1, 512, 5] 1,049,088 311 | PoswiseFeedForwardNet-49 [-1, 5, 512] 0 312 | EncoderLayer-50 [-1, 8, 5, 5] 0 313 | Encoder-51 [-1, 8, 5, 5] 0 314 | Embedding-52 [-1, 5, 512] 3,584 315 | Embedding-53 [-1, 5, 512] 3,072 316 | Linear-54 [-1, 5, 512] 262,656 317 | Linear-55 [-1, 5, 512] 262,656 318 | Linear-56 [-1, 5, 512] 262,656 319 | MultiHeadAttention-57 [-1, 8, 5, 5] 0 320 | Linear-58 [-1, 5, 512] 262,656 321 | Linear-59 [-1, 5, 512] 262,656 322 | Linear-60 [-1, 5, 512] 262,656 323 | MultiHeadAttention-61 [-1, 8, 5, 5] 0 324 | Conv1d-62 [-1, 2048, 5] 1,050,624 325 | Conv1d-63 [-1, 512, 5] 1,049,088 326 | PoswiseFeedForwardNet-64 [-1, 5, 512] 0 327 | DecoderLayer-65 [-1, 8, 5, 5] 0 328 | Linear-66 [-1, 5, 512] 262,656 329 | Linear-67 [-1, 5, 512] 262,656 330 | Linear-68 [-1, 5, 512] 262,656 331 | MultiHeadAttention-69 [-1, 8, 5, 5] 0 332 | Linear-70 [-1, 5, 512] 262,656 333 | Linear-71 [-1, 5, 512] 262,656 334 | Linear-72 [-1, 5, 512] 262,656 335 | MultiHeadAttention-73 [-1, 8, 5, 5] 0 336 | Conv1d-74 [-1, 2048, 5] 1,050,624 337 | Conv1d-75 [-1, 512, 5] 1,049,088 338 | PoswiseFeedForwardNet-76 [-1, 5, 512] 0 339 | DecoderLayer-77 [-1, 8, 5, 5] 0 340 | Linear-78 [-1, 5, 512] 262,656 341 | Linear-79 [-1, 5, 512] 262,656 342 | Linear-80 [-1, 5, 512] 262,656 343 | MultiHeadAttention-81 [-1, 8, 5, 5] 0 344 | Linear-82 [-1, 5, 512] 262,656 345 | Linear-83 [-1, 5, 512] 262,656 346 | Linear-84 [-1, 5, 512] 262,656 347 | MultiHeadAttention-85 [-1, 8, 5, 5] 0 348 | Conv1d-86 [-1, 2048, 5] 1,050,624 349 | Conv1d-87 [-1, 512, 5] 1,049,088 350 | PoswiseFeedForwardNet-88 [-1, 5, 512] 0 351 | DecoderLayer-89 [-1, 8, 5, 5] 0 352 | Linear-90 [-1, 5, 512] 262,656 353 | Linear-91 [-1, 5, 512] 262,656 354 | Linear-92 [-1, 5, 512] 262,656 355 | MultiHeadAttention-93 [-1, 8, 5, 5] 0 356 | Linear-94 [-1, 5, 512] 262,656 357 | Linear-95 [-1, 5, 512] 262,656 358 | Linear-96 [-1, 5, 512] 262,656 
359 | MultiHeadAttention-97 [-1, 8, 5, 5] 0 360 | Conv1d-98 [-1, 2048, 5] 1,050,624 361 | Conv1d-99 [-1, 512, 5] 1,049,088 362 | PoswiseFeedForwardNet-100 [-1, 5, 512] 0 363 | DecoderLayer-101 [-1, 8, 5, 5] 0 364 | Linear-102 [-1, 5, 512] 262,656 365 | Linear-103 [-1, 5, 512] 262,656 366 | Linear-104 [-1, 5, 512] 262,656 367 | MultiHeadAttention-105 [-1, 8, 5, 5] 0 368 | Linear-106 [-1, 5, 512] 262,656 369 | Linear-107 [-1, 5, 512] 262,656 370 | Linear-108 [-1, 5, 512] 262,656 371 | MultiHeadAttention-109 [-1, 8, 5, 5] 0 372 | Conv1d-110 [-1, 2048, 5] 1,050,624 373 | Conv1d-111 [-1, 512, 5] 1,049,088 374 | PoswiseFeedForwardNet-112 [-1, 5, 512] 0 375 | DecoderLayer-113 [-1, 8, 5, 5] 0 376 | Linear-114 [-1, 5, 512] 262,656 377 | Linear-115 [-1, 5, 512] 262,656 378 | Linear-116 [-1, 5, 512] 262,656 379 | MultiHeadAttention-117 [-1, 8, 5, 5] 0 380 | Linear-118 [-1, 5, 512] 262,656 381 | Linear-119 [-1, 5, 512] 262,656 382 | Linear-120 [-1, 5, 512] 262,656 383 | MultiHeadAttention-121 [-1, 8, 5, 5] 0 384 | Conv1d-122 [-1, 2048, 5] 1,050,624 385 | Conv1d-123 [-1, 512, 5] 1,049,088 386 | PoswiseFeedForwardNet-124 [-1, 5, 512] 0 387 | DecoderLayer-125 [-1, 8, 5, 5] 0 388 | Decoder-126 [-1, 8, 5, 5] 0 389 | Linear-127 [-1, 5, 7] 3,584 390 | ======================================================================= 391 | Total params: 39,396,352 392 | Trainable params: 39,390,208 393 | Non-trainable params: 6,144 394 | ----------------------------------------------------------------------- 395 | ``` 396 | 397 | 3) showing hierarchical summary 398 | 399 | ``` 400 | Transformer( 401 | (encoder): Encoder( 402 | (src_emb): Embedding(6, 512), 3,072 params 403 | (pos_emb): Embedding(6, 512), 3,072 params 404 | (layers): ModuleList( 405 | (0): EncoderLayer( 406 | (enc_self_attn): MultiHeadAttention( 407 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 408 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 409 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 410 | ), 787,968 params 411 | (pos_ffn): PoswiseFeedForwardNet( 412 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 413 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 414 | ), 2,099,712 params 415 | ), 2,887,680 params 416 | (1): EncoderLayer( 417 | (enc_self_attn): MultiHeadAttention( 418 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 419 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 420 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 421 | ), 787,968 params 422 | (pos_ffn): PoswiseFeedForwardNet( 423 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 424 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 425 | ), 2,099,712 params 426 | ), 2,887,680 params 427 | (2): EncoderLayer( 428 | (enc_self_attn): MultiHeadAttention( 429 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 430 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 431 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 432 | ), 787,968 params 433 | (pos_ffn): PoswiseFeedForwardNet( 434 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 435 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 436 | ), 2,099,712 params 437 | ), 2,887,680 params 438 | (3): EncoderLayer( 439 
| (enc_self_attn): MultiHeadAttention( 440 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 441 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 442 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 443 | ), 787,968 params 444 | (pos_ffn): PoswiseFeedForwardNet( 445 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 446 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 447 | ), 2,099,712 params 448 | ), 2,887,680 params 449 | (4): EncoderLayer( 450 | (enc_self_attn): MultiHeadAttention( 451 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 452 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 453 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 454 | ), 787,968 params 455 | (pos_ffn): PoswiseFeedForwardNet( 456 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 457 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 458 | ), 2,099,712 params 459 | ), 2,887,680 params 460 | (5): EncoderLayer( 461 | (enc_self_attn): MultiHeadAttention( 462 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 463 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 464 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 465 | ), 787,968 params 466 | (pos_ffn): PoswiseFeedForwardNet( 467 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 468 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 469 | ), 2,099,712 params 470 | ), 2,887,680 params 471 | ), 17,326,080 params 472 | ), 17,332,224 params 473 | (decoder): Decoder( 474 | (tgt_emb): Embedding(7, 512), 3,584 params 475 | (pos_emb): Embedding(6, 512), 3,072 params 476 | (layers): ModuleList( 477 | (0): DecoderLayer( 478 | (dec_self_attn): MultiHeadAttention( 479 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 480 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 481 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 482 | ), 787,968 params 483 | (dec_enc_attn): MultiHeadAttention( 484 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 485 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 486 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 487 | ), 787,968 params 488 | (pos_ffn): PoswiseFeedForwardNet( 489 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 490 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 491 | ), 2,099,712 params 492 | ), 3,675,648 params 493 | (1): DecoderLayer( 494 | (dec_self_attn): MultiHeadAttention( 495 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 496 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 497 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 498 | ), 787,968 params 499 | (dec_enc_attn): MultiHeadAttention( 500 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 501 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 502 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 503 | ), 787,968 params 504 | (pos_ffn): PoswiseFeedForwardNet( 
505 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 506 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 507 | ), 2,099,712 params 508 | ), 3,675,648 params 509 | (2): DecoderLayer( 510 | (dec_self_attn): MultiHeadAttention( 511 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 512 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 513 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 514 | ), 787,968 params 515 | (dec_enc_attn): MultiHeadAttention( 516 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 517 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 518 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 519 | ), 787,968 params 520 | (pos_ffn): PoswiseFeedForwardNet( 521 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 522 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 523 | ), 2,099,712 params 524 | ), 3,675,648 params 525 | (3): DecoderLayer( 526 | (dec_self_attn): MultiHeadAttention( 527 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 528 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 529 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 530 | ), 787,968 params 531 | (dec_enc_attn): MultiHeadAttention( 532 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 533 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 534 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 535 | ), 787,968 params 536 | (pos_ffn): PoswiseFeedForwardNet( 537 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 538 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 539 | ), 2,099,712 params 540 | ), 3,675,648 params 541 | (4): DecoderLayer( 542 | (dec_self_attn): MultiHeadAttention( 543 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 544 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 545 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 546 | ), 787,968 params 547 | (dec_enc_attn): MultiHeadAttention( 548 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 549 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 550 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 551 | ), 787,968 params 552 | (pos_ffn): PoswiseFeedForwardNet( 553 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 554 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 555 | ), 2,099,712 params 556 | ), 3,675,648 params 557 | (5): DecoderLayer( 558 | (dec_self_attn): MultiHeadAttention( 559 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 560 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 561 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 562 | ), 787,968 params 563 | (dec_enc_attn): MultiHeadAttention( 564 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 565 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 566 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 
params 567 | ), 787,968 params 568 | (pos_ffn): PoswiseFeedForwardNet( 569 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 570 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 571 | ), 2,099,712 params 572 | ), 3,675,648 params 573 | ), 22,053,888 params 574 | ), 22,060,544 params 575 | (projection): Linear(in_features=512, out_features=7, bias=False), 3,584 params 576 | ), 39,396,352 params 577 | 578 | ``` 579 | 580 | 581 | 582 | ## Reference 583 | 584 | ```python 585 | code_reference = { 'https://github.com/pytorch/pytorch/issues/2001', 586 | 'https://gist.github.com/HTLife/b6640af9d6e7d765411f8aa9aa94b837', 587 | 'https://github.com/sksq96/pytorch-summary', 588 | 'Inspired by https://github.com/sksq96/pytorch-summary'} 589 | ``` -------------------------------------------------------------------------------- /examples/CNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modelsummary import summary 6 | 7 | class Net(nn.Module): 8 | def __init__(self): 9 | super(Net, self).__init__() 10 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 11 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 12 | self.conv2_drop = nn.Dropout2d() 13 | self.fc1 = nn.Linear(320, 50) 14 | self.fc2 = nn.Linear(50, 10) 15 | 16 | def forward(self, x): 17 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 18 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 19 | x = x.view(-1, 320) 20 | x = F.relu(self.fc1(x)) 21 | x = F.dropout(x, training=self.training) 22 | x = self.fc2(x) 23 | return F.log_softmax(x, dim=1) 24 | 25 | # show input shape 26 | summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True) 27 | 28 | # show output shape 29 | summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False) -------------------------------------------------------------------------------- /examples/Transformer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Tae Hwan Jung(Jeff Jung) @graykode 3 | Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch 4 | https://github.com/JayParks/transformer 5 | ''' 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | 11 | from modelsummary import summary 12 | 13 | # S: Symbol that marks the start of the decoding input 14 | # E: Symbol that marks the end of the decoding output 15 | # P: Symbol that pads the sequence when the current batch is shorter than the number of time steps 16 | sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E'] 17 | 18 | # Transformer Parameters 19 | src_vocab = {'PAD' : 0} 20 | for i, w in enumerate(sentences[0].split()): 21 | src_vocab[w] = i+1 22 | src_vocab_size = len(src_vocab) 23 | 24 | tgt_vocab = {'PAD' : 0} 25 | number_dict = {0 : 'PAD'} 26 | for i, w in enumerate(set((sentences[1]+' '+sentences[2]).split())): 27 | tgt_vocab[w] = i+1 28 | number_dict[i+1] = w 29 | tgt_vocab_size = len(tgt_vocab) 30 | 31 | src_len = tgt_len = 5 32 | 33 | d_model = 512 # Embedding Size 34 | d_ff = 2048 # FeedForward dimension 35 | d_k = d_v = 64 # dimension of K(=Q), V 36 | n_layers = 6 # number of Encoder and Decoder layers 37 | n_heads = 8 # number of heads in Multi-Head Attention 38 | 39 | def make_batch(sentences): 40 | input_batch = [[src_vocab[n] for n in sentences[0].split()]] 41 | output_batch = [[tgt_vocab[n] for n in
sentences[1].split()]] 42 | target_batch = [[tgt_vocab[n] for n in sentences[2].split()]] 43 | return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch)) 44 | 45 | def get_sinusoid_encoding_table(n_position, d_model): 46 | def cal_angle(position, hid_idx): 47 | return position / np.power(10000, 2 * (hid_idx // 2) / d_model) 48 | def get_posi_angle_vec(position): 49 | return [cal_angle(position, hid_j) for hid_j in range(d_model)] 50 | 51 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 52 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) 53 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 54 | return torch.FloatTensor(sinusoid_table) 55 | 56 | def get_attn_pad_mask(seq_q, seq_k): 57 | batch_size, len_q = seq_q.size() 58 | batch_size, len_k = seq_k.size() 59 | pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) 60 | return pad_attn_mask.expand(batch_size, len_q, len_k) 61 | 62 | def get_attn_subsequent_mask(seq): 63 | attn_shape = [seq.size(0), seq.size(1), seq.size(1)] 64 | subsequent_mask = np.triu(np.ones(attn_shape), k=1) 65 | subsequent_mask = torch.from_numpy(subsequent_mask).byte() 66 | return subsequent_mask 67 | 68 | class ScaledDotProductAttention(nn.Module): 69 | def __init__(self): 70 | super(ScaledDotProductAttention, self).__init__() 71 | 72 | def forward(self, Q, K, V, attn_mask): 73 | scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) 74 | scores.masked_fill_(attn_mask, -1e9) 75 | attn = nn.Softmax(dim=-1)(scores) 76 | context = torch.matmul(attn, V) 77 | return context, attn 78 | 79 | class MultiHeadAttention(nn.Module): 80 | def __init__(self): 81 | super(MultiHeadAttention, self).__init__() 82 | self.W_Q = nn.Linear(d_model, d_k * n_heads) 83 | self.W_K = nn.Linear(d_model, d_k * n_heads) 84 | self.W_V = nn.Linear(d_model, d_v * n_heads) 85 | def forward(self, Q, K, V, attn_mask): 86 | residual, batch_size = Q, Q.size(0) 87 | q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) 88 | k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2) 89 | v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2) 90 | 91 | attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) 92 | 93 | context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask) 94 | context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) 95 | output = nn.Linear(n_heads * d_v, d_model)(context) 96 | return nn.LayerNorm(d_model)(output + residual), attn 97 | 98 | class PoswiseFeedForwardNet(nn.Module): 99 | def __init__(self): 100 | super(PoswiseFeedForwardNet, self).__init__() 101 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 102 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 103 | 104 | def forward(self, inputs): 105 | residual = inputs 106 | output = nn.ReLU()(self.conv1(inputs.transpose(1, 2))) 107 | output = self.conv2(output).transpose(1, 2) 108 | return nn.LayerNorm(d_model)(output + residual) 109 | 110 | class EncoderLayer(nn.Module): 111 | def __init__(self): 112 | super(EncoderLayer, self).__init__() 113 | self.enc_self_attn = MultiHeadAttention() 114 | self.pos_ffn = PoswiseFeedForwardNet() 115 | 116 | def forward(self, enc_inputs, enc_self_attn_mask): 117 | enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) 118 | enc_outputs = self.pos_ffn(enc_outputs) 119 | return enc_outputs, attn 120 | 
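# A DecoderLayer applies masked self-attention over the decoder inputs, then
# encoder-decoder attention (dec_enc_attn) over the encoder outputs, and
# finally the position-wise feed-forward net.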
121 | class DecoderLayer(nn.Module): 122 | def __init__(self): 123 | super(DecoderLayer, self).__init__() 124 | self.dec_self_attn = MultiHeadAttention() 125 | self.dec_enc_attn = MultiHeadAttention() 126 | self.pos_ffn = PoswiseFeedForwardNet() 127 | 128 | def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask): 129 | dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask) 130 | dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) 131 | dec_outputs = self.pos_ffn(dec_outputs) 132 | return dec_outputs, dec_self_attn, dec_enc_attn 133 | 134 | class Encoder(nn.Module): 135 | def __init__(self): 136 | super(Encoder, self).__init__() 137 | self.src_emb = nn.Embedding(src_vocab_size, d_model) 138 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_len+1 , d_model),freeze=True) 139 | self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)]) 140 | 141 | def forward(self, enc_inputs): 142 | enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]])) 143 | enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) 144 | enc_self_attns = [] 145 | for layer in self.layers: 146 | enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask) 147 | enc_self_attns.append(enc_self_attn) 148 | return enc_outputs, enc_self_attns 149 | 150 | class Decoder(nn.Module): 151 | def __init__(self): 152 | super(Decoder, self).__init__() 153 | self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model) 154 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1 , d_model),freeze=True) 155 | self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) 156 | 157 | def forward(self, dec_inputs, enc_inputs, enc_outputs): 158 | dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]])) 159 | dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) 160 | dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs) 161 | dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) 162 | 163 | dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) 164 | 165 | dec_self_attns, dec_enc_attns = [], [] 166 | for layer in self.layers: 167 | dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask) 168 | dec_self_attns.append(dec_self_attn) 169 | dec_enc_attns.append(dec_enc_attn) 170 | return dec_outputs, dec_self_attns, dec_enc_attns 171 | 172 | class Transformer(nn.Module): 173 | def __init__(self): 174 | super(Transformer, self).__init__() 175 | self.encoder = Encoder() 176 | self.decoder = Decoder() 177 | self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False) 178 | 179 | def forward(self, enc_inputs, dec_inputs): 180 | enc_outputs, enc_self_attns = self.encoder(enc_inputs) 181 | dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs) 182 | dec_logits = self.projection(dec_outputs) 183 | return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns 184 | 185 | model = Transformer() 186 | enc_inputs, dec_inputs, target_batch = make_batch(sentences) 187 | 188 | # show input shape 189 | summary(model, enc_inputs, dec_inputs, show_input=True) 190 | 191 | # show output shape 192 | summary(model, enc_inputs, dec_inputs, show_input=False) 193 | 194 | # show hierarchical struct 195 | 
summary(model, enc_inputs, dec_inputs, show_hierarchical=True) -------------------------------------------------------------------------------- /modelsummary/__init__.py: -------------------------------------------------------------------------------- 1 | from modelsummary.modelsummary import * -------------------------------------------------------------------------------- /modelsummary/hierarchicalsummary.py: -------------------------------------------------------------------------------- 1 | """ 2 | code reference : https://github.com/pytorch/pytorch/issues/2001, 3 | """ 4 | from functools import reduce 5 | 6 | from torch.nn.modules.module import _addindent 7 | 8 | def hierarchicalsummary(model): 9 | def repr(model): 10 | # We treat the extra repr like the sub-module, one item per line 11 | extra_lines = [] 12 | extra_repr = model.extra_repr() 13 | # empty string will be split into list [''] 14 | if extra_repr: 15 | extra_lines = extra_repr.split('\n') 16 | child_lines = [] 17 | total_params = 0 18 | for key, module in model._modules.items(): 19 | mod_str, num_params = repr(module) 20 | mod_str = _addindent(mod_str, 2) 21 | child_lines.append('(' + key + '): ' + mod_str) 22 | total_params += num_params 23 | lines = extra_lines + child_lines 24 | 25 | for name, p in model._parameters.items(): 26 | if p is not None: 27 | total_params += reduce(lambda x, y: x * y, p.shape) 28 | 29 | main_str = model._get_name() + '(' 30 | if lines: 31 | # simple one-liner info, which most builtin Modules will use 32 | if len(extra_lines) == 1 and not child_lines: 33 | main_str += extra_lines[0] 34 | else: 35 | main_str += '\n ' + '\n '.join(lines) + '\n' 36 | 37 | main_str += ')' 38 | main_str += ', {:,} params'.format(total_params) 39 | return main_str, total_params 40 | 41 | string, count = repr(model) 42 | print(string) 43 | return count -------------------------------------------------------------------------------- /modelsummary/modelsummary.py: -------------------------------------------------------------------------------- 1 | """ 2 | code by Tae Hwan Jung(Jeff Jung) @graykode 3 | code reference : https://github.com/pytorch/pytorch/issues/2001, 4 | https://gist.github.com/HTLife/b6640af9d6e7d765411f8aa9aa94b837, 5 | https://github.com/sksq96/pytorch-summary 6 | Inspired by https://github.com/sksq96/pytorch-summary 7 | But the 'torchsummary' module only works with vision network models, so I fixed it!
8 | """ 9 | 10 | import torch 11 | import numpy as np 12 | import torch.nn as nn 13 | from collections import OrderedDict 14 | from modelsummary.hierarchicalsummary import hierarchicalsummary 15 | 16 | def summary(model, *inputs, batch_size=-1, show_input=True, show_hierarchical=False): 17 | if show_hierarchical is True: 18 | hierarchicalsummary(model) 19 | 20 | def register_hook(module): 21 | 22 | def hook(module, input, output=None): 23 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 24 | module_idx = len(summary) 25 | 26 | m_key = "%s-%i" % (class_name, module_idx + 1) 27 | summary[m_key] = OrderedDict() 28 | 29 | if len(input) != 0 : 30 | summary[m_key]["input_shape"] = list(input[0].size()) 31 | summary[m_key]["input_shape"][0] = batch_size 32 | else: 33 | summary[m_key]["input_shape"] = input 34 | 35 | if show_input is False and output is not None: 36 | if isinstance(output, (list, tuple)): 37 | for out in output: 38 | if isinstance(out, torch.Tensor): 39 | summary[m_key]["output_shape"] = [ 40 | [-1] + list(out.size())[1:] 41 | ][0] 42 | else: 43 | summary[m_key]["output_shape"] = [ 44 | [-1] + list(out[0].size())[1:] 45 | ][0] 46 | else: 47 | summary[m_key]["output_shape"] = list(output.size()) 48 | summary[m_key]["output_shape"][0] = batch_size 49 | 50 | params = 0 51 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 52 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 53 | summary[m_key]["trainable"] = module.weight.requires_grad 54 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 55 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 56 | summary[m_key]["nb_params"] = params 57 | 58 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model)): 59 | if show_input is True: 60 | hooks.append(module.register_forward_pre_hook(hook)) 61 | else: 62 | hooks.append(module.register_forward_hook(hook)) 63 | 64 | # create properties 65 | summary = OrderedDict() 66 | hooks = [] 67 | 68 | # register hook 69 | model.apply(register_hook) 70 | model(*inputs) 71 | 72 | # remove these hooks 73 | for h in hooks: 74 | h.remove() 75 | 76 | if show_hierarchical is False: 77 | print("-----------------------------------------------------------------------") 78 | if show_input is True: 79 | line_new = "{:>25} {:>25} {:>15}".format("Layer (type)", "Input Shape", "Param #") 80 | else: 81 | line_new = "{:>25} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 82 | if show_hierarchical is False: 83 | print(line_new) 84 | print("=======================================================================") 85 | 86 | total_params = 0 87 | total_output = 0 88 | trainable_params = 0 89 | for layer in summary: 90 | # input_shape, output_shape, trainable, nb_params 91 | if show_input is True: 92 | line_new = "{:>25} {:>25} {:>15}".format( 93 | layer, 94 | str(summary[layer]["input_shape"]), 95 | "{0:,}".format(summary[layer]["nb_params"]), 96 | ) 97 | else: 98 | line_new = "{:>25} {:>25} {:>15}".format( 99 | layer, 100 | str(summary[layer]["output_shape"]), 101 | "{0:,}".format(summary[layer]["nb_params"]), 102 | ) 103 | 104 | total_params += summary[layer]["nb_params"] 105 | if show_input is True: 106 | total_output += np.prod(summary[layer]["input_shape"]) 107 | else: 108 | total_output += np.prod(summary[layer]["output_shape"]) 109 | if "trainable" in summary[layer]: 110 | if summary[layer]["trainable"] == True: 111 | trainable_params += summary[layer]["nb_params"] 112 
| 113 | if show_hierarchical is False: 114 | print(line_new) 115 | 116 | print("=======================================================================") 117 | print("Total params: {0:,}".format(total_params)) 118 | print("Trainable params: {0:,}".format(trainable_params)) 119 | print("Non-trainable params: {0:,}".format(total_params - trainable_params)) 120 | print("-----------------------------------------------------------------------") -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open('README.md') as f: 4 | long_description = f.read() 5 | 6 | setup_info = dict( 7 | name='modelsummary', 8 | version='1.1.6', 9 | author='Tae Hwan Jung(@graykode)', 10 | author_email='nlkey2022@gmail.com', 11 | url='https://github.com/graykode/modelsummary', 12 | description='All Model summary in PyTorch similar to `model.summary()` in Keras', 13 | long_description=long_description, 14 | long_description_content_type='text/markdown', # This is important! 15 | license='MIT', 16 | install_requires=[ 'tqdm', 'torch', 'numpy'], 17 | keywords='pytorch model summary model.summary()', 18 | packages=["modelsummary"], 19 | ) 20 | 21 | setup(**setup_info) --------------------------------------------------------------------------------