├── .gitignore
├── pytorch_model_summary
│   ├── __init__.py
│   ├── hierarchical_summary.py
│   └── model_summary.py
├── setup.py
├── LICENSE
├── examples
│   ├── CNN.py
│   └── Transformer.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | build/
3 | dist/
4 | *.egg-info/
5 | __pycache__/
6 | 
--------------------------------------------------------------------------------
/pytorch_model_summary/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_model_summary.model_summary import *
2 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | with open('README.md', encoding='utf-8') as f:
4 |     long_description = f.read()
5 | 
6 | setup_info = dict(
7 |     name='pytorch_model_summary',
8 |     version='0.1.2',
9 |     author='Alison Marczewski',
10 |     author_email='alison.marczewski@gmail.com',
11 |     url='https://github.com/amarczew/pytorch_model_summary',
12 |     description='It is a Keras style model.summary() implementation for PyTorch',
13 |     long_description_content_type='text/markdown',  # This is important!
14 |     long_description=long_description,
15 |     license='MIT',
16 |     install_requires=['tqdm', 'torch', 'numpy'],
17 |     keywords='pytorch model summary model.summary() keras',
18 |     packages=['pytorch_model_summary'],
19 |     python_requires='>=3.6'
20 | )
21 | 
22 | setup(**setup_info)
23 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Alison Marczewski
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /examples/CNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from pytorch_model_summary import summary 6 | 7 | 8 | class Net(nn.Module): 9 | def __init__(self): 10 | super(Net, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 12 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 13 | self.conv2_drop = nn.Dropout2d() 14 | self.fc1 = nn.Linear(320, 50) 15 | self.fc2 = nn.Linear(50, 10) 16 | 17 | def forward(self, x): 18 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 19 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 20 | x = x.view(-1, 320) 21 | x = F.relu(self.fc1(x)) 22 | x = F.dropout(x, training=self.training) 23 | x = self.fc2(x) 24 | return F.log_softmax(x, dim=1) 25 | 26 | 27 | # show input shape 28 | print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True)) 29 | 30 | # show output shape 31 | print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False)) 32 | 33 | # show output shape and hierarchical view of net 34 | print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True)) 35 | -------------------------------------------------------------------------------- /pytorch_model_summary/hierarchical_summary.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a fork from: https://github.com/graykode/modelsummary 3 | """ 4 | 5 | from functools import reduce 6 | 7 | from torch.nn.modules.module import _addindent 8 | 9 | 10 | def hierarchical_summary(model, print_summary=False): 11 | 12 | def repr(model): 13 | # We treat the extra repr like the sub-module, one item per line 14 | extra_lines = [] 15 | extra_repr = model.extra_repr() 16 | # empty string will be split into list [''] 17 | if extra_repr: 18 | extra_lines = extra_repr.split('\n') 19 | child_lines = [] 20 | total_params = 0 21 | for key, module in model._modules.items(): 22 | if module is None: 23 | continue 24 | mod_str, num_params = repr(module) 25 | mod_str = _addindent(mod_str, 2) 26 | child_lines.append('(' + key + '): ' + mod_str) 27 | total_params += num_params 28 | lines = extra_lines + child_lines 29 | 30 | for name, p in model._parameters.items(): 31 | if p is not None: 32 | total_params += reduce(lambda x, y: x * y, p.shape) 33 | 34 | main_str = model._get_name() + '(' 35 | if lines: 36 | # simple one-liner info, which most builtin Modules will use 37 | if len(extra_lines) == 1 and not child_lines: 38 | main_str += extra_lines[0] 39 | else: 40 | main_str += '\n ' + '\n '.join(lines) + '\n' 41 | 42 | main_str += ')' 43 | main_str += ', {:,} params'.format(total_params) 44 | return main_str, total_params 45 | 46 | string, count = repr(model) 47 | 48 | # Building hierarchical output 49 | _pad = int(max(max(len(_) for _ in string.split('\n')) - 20, 0) / 2) 50 | lines = list() 51 | lines.append('=' * _pad + ' Hierarchical Summary ' + '=' * _pad + '\n') 52 | lines.append(string) 53 | lines.append('\n\n' + '=' * (_pad * 2 + 22) + '\n') 54 | 55 | str_summary = '\n'.join(lines) 56 | if print_summary: 57 | print(str_summary) 58 | 59 | return str_summary, count 60 | -------------------------------------------------------------------------------- /pytorch_model_summary/model_summary.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a fork from: 
https://github.com/graykode/modelsummary 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from collections import OrderedDict 8 | from pytorch_model_summary.hierarchical_summary import hierarchical_summary 9 | 10 | 11 | def summary(model, *inputs, batch_size=-1, show_input=False, show_hierarchical=False, 12 | print_summary=False, max_depth=1, show_parent_layers=False): 13 | 14 | max_depth = max_depth if max_depth is not None else 9999 15 | 16 | def build_module_tree(module): 17 | 18 | def _in(module, id_parent, depth): 19 | for c in module.children(): 20 | # ModuleList and Sequential do not count as valid layers 21 | if isinstance(c, (nn.ModuleList, nn.Sequential)): 22 | _in(c, id_parent, depth) 23 | else: 24 | _module_name = str(c.__class__).split(".")[-1].split("'")[0] 25 | _parent_layers = f'{module_summary[id_parent].get("parent_layers")}' \ 26 | f'{"/" if module_summary[id_parent].get("parent_layers") != "" else ""}' \ 27 | f'{module_summary[id_parent]["module_name"]}' 28 | 29 | module_summary[id(c)] = {'module_name': _module_name, 'parent_layers': _parent_layers, 30 | 'id_parent': id_parent, 'depth': depth, 'n_children': 0} 31 | 32 | module_summary[id_parent]['n_children'] += 1 33 | 34 | _in(c, id(c), depth+1) 35 | 36 | # Defining summary for the main module 37 | module_summary[id(module)] = {'module_name': str(module.__class__).split(".")[-1].split("'")[0], 38 | 'parent_layers': '', 'id_parent': None, 'depth': 0, 'n_children': 0} 39 | 40 | _in(module, id_parent=id(module), depth=1) 41 | 42 | # Defining layers that will be printed 43 | for k, v in module_summary.items(): 44 | module_summary[k]['show'] = v['depth'] == max_depth or (v['depth'] < max_depth and v['n_children'] == 0) 45 | 46 | def register_hook(module): 47 | 48 | def shapes(x): 49 | _lst = list() 50 | 51 | def _shapes(_): 52 | if isinstance(_, torch.Tensor): 53 | _lst.append(list(_.size())) 54 | elif isinstance(_, (tuple, list)): 55 | for _x in _: 56 | _shapes(_x) 57 | else: 58 | # TODO: decide what to do when there is an input which is not a tensor 59 | raise Exception('Object not supported') 60 | 61 | _shapes(x) 62 | 63 | return _lst 64 | 65 | def hook(module, input, output=None): 66 | if id(module) in module_mapped: 67 | return 68 | 69 | module_mapped.add(id(module)) 70 | module_name = module_summary.get(id(module)).get('module_name') 71 | module_idx = len(summary) 72 | 73 | m_key = "%s-%i" % (module_name, module_idx + 1) 74 | summary[m_key] = OrderedDict() 75 | summary[m_key]['parent_layers'] = module_summary.get(id(module)).get('parent_layers') 76 | 77 | summary[m_key]["input_shape"] = shapes(input) if len(input) != 0 else input 78 | 79 | if show_input is False and output is not None: 80 | summary[m_key]["output_shape"] = shapes(output) 81 | 82 | params = 0 83 | params_trainable = 0 84 | trainable = False 85 | for m in module.parameters(): 86 | _params = torch.prod(torch.LongTensor(list(m.size()))) 87 | params += _params 88 | params_trainable += _params if m.requires_grad else 0 89 | # if any parameter is trainable, then module is trainable 90 | trainable = trainable or m.requires_grad 91 | 92 | summary[m_key]["nb_params"] = params 93 | summary[m_key]["nb_params_trainable"] = params_trainable 94 | summary[m_key]["trainable"] = trainable 95 | 96 | _map_module = module_summary.get(id(module), None) 97 | if _map_module is not None and _map_module.get('show'): 98 | if show_input is True: 99 | hooks.append(module.register_forward_pre_hook(hook)) 100 | else: 101 | hooks.append(module.register_forward_hook(hook)) 
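    # Summary construction happens in three steps:
    #   1. build_module_tree() records each module's name, parent path, depth and number
    #      of children, and marks which modules will appear in the table ('show').
    #   2. register_hook() attaches a forward pre-hook when show_input=True (capturing
    #      input shapes) or a regular forward hook otherwise (capturing output shapes);
    #      the hook also counts total and trainable parameters per module.
    #   3. A single forward pass in eval mode fills `summary`, and the hooks are removed.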
102 | 103 | # create properties 104 | summary = OrderedDict() 105 | module_summary = dict() 106 | module_mapped = set() 107 | hooks = [] 108 | 109 | # register id of parent modules 110 | build_module_tree(model) 111 | 112 | # register hook 113 | model.apply(register_hook) 114 | 115 | model_training = model.training 116 | 117 | model.eval() 118 | model(*inputs) 119 | 120 | if model_training: 121 | model.train() 122 | 123 | # remove these hooks 124 | for h in hooks: 125 | h.remove() 126 | 127 | # params to format output - dynamic width 128 | _key_shape = 'input_shape' if show_input else 'output_shape' 129 | _len_str_parent = max([len(v['parent_layers']) for v in summary.values()] + [13]) + 3 130 | _len_str_layer = max([len(layer) for layer in summary.keys()] + [15]) + 3 131 | _len_str_shapes = max([len(', '.join([str(_) for _ in summary[layer][_key_shape]])) for layer in summary] + [15]) + 3 132 | _len_line = 35 + _len_str_parent * int(show_parent_layers) + _len_str_layer + _len_str_shapes 133 | fmt = ("{:>%d} " % _len_str_parent if show_parent_layers else "") + "{:>%d} {:>%d} {:>15} {:>15}" % (_len_str_layer, _len_str_shapes) 134 | 135 | """ starting to build output text """ 136 | 137 | # Table header 138 | lines = list() 139 | lines.append('-' * _len_line) 140 | _fmt_args = ("Parent Layers",) if show_parent_layers else () 141 | _fmt_args += ("Layer (type)", f'{"Input" if show_input else "Output"} Shape', "Param #", "Tr. Param #") 142 | lines.append(fmt.format(*_fmt_args)) 143 | lines.append('=' * _len_line) 144 | 145 | total_params = 0 146 | trainable_params = 0 147 | for layer in summary: 148 | # Table content (for each layer) 149 | _fmt_args = (summary[layer]["parent_layers"], ) if show_parent_layers else () 150 | _fmt_args += (layer, 151 | ", ".join([str(_) for _ in summary[layer][_key_shape]]), 152 | "{0:,}".format(summary[layer]["nb_params"]), 153 | "{0:,}".format(summary[layer]["nb_params_trainable"])) 154 | line_new = fmt.format(*_fmt_args) 155 | lines.append(line_new) 156 | 157 | total_params += summary[layer]["nb_params"] 158 | trainable_params += summary[layer]["nb_params_trainable"] 159 | 160 | # Table footer 161 | lines.append('=' * _len_line) 162 | lines.append("Total params: {0:,}".format(total_params)) 163 | lines.append("Trainable params: {0:,}".format(trainable_params)) 164 | lines.append("Non-trainable params: {0:,}".format(total_params - trainable_params)) 165 | if batch_size != -1: 166 | lines.append("Batch size: {0:,}".format(batch_size)) 167 | lines.append('-' * _len_line) 168 | 169 | if show_hierarchical: 170 | h_summary, _ = hierarchical_summary(model, print_summary=False) 171 | lines.append('\n') 172 | lines.append(h_summary) 173 | 174 | str_summary = '\n'.join(lines) 175 | if print_summary: 176 | print(str_summary) 177 | 178 | return str_summary 179 | -------------------------------------------------------------------------------- /examples/Transformer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Tae Hwan Jung(Jeff Jung) @graykode 3 | Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch 4 | https://github.com/JayParks/transformer 5 | ''' 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | 11 | from pytorch_model_summary import summary 12 | 13 | # S: Symbol that shows starting of decoding input 14 | # E: Symbol that shows starting of decoding output 15 | # P: Symbol that will fill in blank sequence if current batch data 
size is short than time steps 16 | sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E'] 17 | 18 | # Transformer Parameters 19 | src_vocab = {'PAD' : 0} 20 | for i, w in enumerate(sentences[0].split()): 21 | src_vocab[w] = i+1 22 | src_vocab_size = len(src_vocab) 23 | 24 | tgt_vocab = {'PAD' : 0} 25 | number_dict = {0 : 'PAD'} 26 | for i, w in enumerate(set((sentences[1]+' '+sentences[2]).split())): 27 | tgt_vocab[w] = i+1 28 | number_dict[i+1] = w 29 | tgt_vocab_size = len(tgt_vocab) 30 | 31 | src_len = tgt_len= 5 32 | 33 | d_model = 512 # Embedding Size 34 | d_ff = 2048 # FeedForward dimension 35 | d_k = d_v = 64 # dimension of K(=Q), V 36 | n_layers = 6 # number of Encoder of Decoder Layer 37 | n_heads = 8 # number of heads in Multi-Head Attention 38 | 39 | def make_batch(sentences): 40 | input_batch = [[src_vocab[n] for n in sentences[0].split()]] 41 | output_batch = [[tgt_vocab[n] for n in sentences[1].split()]] 42 | target_batch = [[tgt_vocab[n] for n in sentences[2].split()]] 43 | return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch)) 44 | 45 | def get_sinusoid_encoding_table(n_position, d_model): 46 | def cal_angle(position, hid_idx): 47 | return position / np.power(10000, 2 * (hid_idx // 2) / d_model) 48 | def get_posi_angle_vec(position): 49 | return [cal_angle(position, hid_j) for hid_j in range(d_model)] 50 | 51 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 52 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) 53 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 54 | return torch.FloatTensor(sinusoid_table) 55 | 56 | def get_attn_pad_mask(seq_q, seq_k): 57 | batch_size, len_q = seq_q.size() 58 | batch_size, len_k = seq_k.size() 59 | pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) 60 | return pad_attn_mask.expand(batch_size, len_q, len_k) 61 | 62 | def get_attn_subsequent_mask(seq): 63 | attn_shape = [seq.size(0), seq.size(1), seq.size(1)] 64 | subsequent_mask = np.triu(np.ones(attn_shape), k=1) 65 | subsequent_mask = torch.from_numpy(subsequent_mask).byte() 66 | return subsequent_mask 67 | 68 | class ScaledDotProductAttention(nn.Module): 69 | def __init__(self): 70 | super(ScaledDotProductAttention, self).__init__() 71 | 72 | def forward(self, Q, K, V, attn_mask): 73 | scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) 74 | scores.masked_fill_(attn_mask, -1e9) 75 | attn = nn.Softmax(dim=-1)(scores) 76 | context = torch.matmul(attn, V) 77 | return context, attn 78 | 79 | class MultiHeadAttention(nn.Module): 80 | def __init__(self): 81 | super(MultiHeadAttention, self).__init__() 82 | self.W_Q = nn.Linear(d_model, d_k * n_heads) 83 | self.W_K = nn.Linear(d_model, d_k * n_heads) 84 | self.W_V = nn.Linear(d_model, d_v * n_heads) 85 | def forward(self, Q, K, V, attn_mask): 86 | residual, batch_size = Q, Q.size(0) 87 | q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) 88 | k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2) 89 | v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2) 90 | 91 | attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) 92 | 93 | context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask) 94 | context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) 95 | output = nn.Linear(n_heads * d_v, d_model)(context) 96 | return nn.LayerNorm(d_model)(output + residual), attn 97 | 98 | class 
PoswiseFeedForwardNet(nn.Module): 99 | def __init__(self): 100 | super(PoswiseFeedForwardNet, self).__init__() 101 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 102 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 103 | 104 | def forward(self, inputs): 105 | residual = inputs 106 | output = nn.ReLU()(self.conv1(inputs.transpose(1, 2))) 107 | output = self.conv2(output).transpose(1, 2) 108 | return nn.LayerNorm(d_model)(output + residual) 109 | 110 | class EncoderLayer(nn.Module): 111 | def __init__(self): 112 | super(EncoderLayer, self).__init__() 113 | self.enc_self_attn = MultiHeadAttention() 114 | self.pos_ffn = PoswiseFeedForwardNet() 115 | 116 | def forward(self, enc_inputs, enc_self_attn_mask): 117 | enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) 118 | enc_outputs = self.pos_ffn(enc_outputs) 119 | return enc_outputs, attn 120 | 121 | class DecoderLayer(nn.Module): 122 | def __init__(self): 123 | super(DecoderLayer, self).__init__() 124 | self.dec_self_attn = MultiHeadAttention() 125 | self.dec_enc_attn = MultiHeadAttention() 126 | self.pos_ffn = PoswiseFeedForwardNet() 127 | 128 | def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask): 129 | dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask) 130 | dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) 131 | dec_outputs = self.pos_ffn(dec_outputs) 132 | return dec_outputs, dec_self_attn, dec_enc_attn 133 | 134 | class Encoder(nn.Module): 135 | def __init__(self): 136 | super(Encoder, self).__init__() 137 | self.src_emb = nn.Embedding(src_vocab_size, d_model) 138 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_len+1 , d_model),freeze=True) 139 | self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)]) 140 | 141 | def forward(self, enc_inputs): 142 | enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]])) 143 | enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) 144 | enc_self_attns = [] 145 | for layer in self.layers: 146 | enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask) 147 | enc_self_attns.append(enc_self_attn) 148 | return enc_outputs, enc_self_attns 149 | 150 | class Decoder(nn.Module): 151 | def __init__(self): 152 | super(Decoder, self).__init__() 153 | self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model) 154 | self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1 , d_model),freeze=True) 155 | self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) 156 | 157 | def forward(self, dec_inputs, enc_inputs, enc_outputs): 158 | dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]])) 159 | dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) 160 | dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs) 161 | dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) 162 | 163 | dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) 164 | 165 | dec_self_attns, dec_enc_attns = [], [] 166 | for layer in self.layers: 167 | dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask) 168 | dec_self_attns.append(dec_self_attn) 169 | dec_enc_attns.append(dec_enc_attn) 170 | return dec_outputs, dec_self_attns, 
dec_enc_attns 171 | 172 | class Transformer(nn.Module): 173 | def __init__(self): 174 | super(Transformer, self).__init__() 175 | self.encoder = Encoder() 176 | self.decoder = Decoder() 177 | self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False) 178 | 179 | def forward(self, enc_inputs, dec_inputs): 180 | enc_outputs, enc_self_attns = self.encoder(enc_inputs) 181 | dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs) 182 | dec_logits = self.projection(dec_outputs) 183 | return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns 184 | 185 | 186 | model = Transformer() 187 | enc_inputs, dec_inputs, target_batch = make_batch(sentences) 188 | 189 | # show input shape 190 | summary(model, enc_inputs, dec_inputs, show_input=True, print_summary=True) 191 | 192 | # show output shape and batch_size in table. In addition, also hierarchical summary version 193 | summary(model, enc_inputs, dec_inputs, batch_size=1, show_hierarchical=True, print_summary=True) 194 | 195 | # show layers until depth 2 196 | summary(model, enc_inputs, dec_inputs, max_depth=2, print_summary=True) 197 | 198 | # show deepest layers 199 | summary(model, enc_inputs, dec_inputs, max_depth=None, print_summary=True) 200 | 201 | # show layers until depth 3 and add column with parent layers 202 | summary(model, enc_inputs, dec_inputs, max_depth=3, show_parent_layers=True, print_summary=True) 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch Model Summary -- Keras style `model.summary()` for PyTorch 2 | [![PyPI version fury.io](https://badge.fury.io/py/pytorch-model-summary.svg)](https://pypi.python.org/pypi/pytorch-model-summary/) 3 | [![Downloads](https://pepy.tech/badge/pytorch-model-summary)](https://pepy.tech/project/pytorch-model-summary) 4 | 5 | 6 | It is a Keras style model.summary() implementation for PyTorch 7 | 8 | This is an Improved PyTorch library of [modelsummary](https://github.com/graykode/modelsummary). Like in `modelsummary`, **It does not care with number of Input parameter!** 9 | 10 | ### Improvements: 11 | - For user defined pytorch layers, now `summary` can show layers inside it 12 | - some assumptions: when is an user defined layer, if any weight/params/bias is trainable, then it is assumed that this layer is trainable (but only trainable params are counted in Tr. Params #) 13 | - Adding column counting only trainable parameters (it makes sense when there are user defined layers) 14 | - Showing all input/output shapes, instead of showing only the first one 15 | - example: LSTM layer return a Tensor and a tuple (Tensor, Tensor), then output_shape has three set of values 16 | - Printing: table width defined dynamically 17 | - Adding option to add hierarchical summary in output 18 | - Adding batch_size value (when provided) in table footer 19 | - fix bugs 20 | 21 | ### Parameters 22 | **Default values have keras behavior** 23 | ```python 24 | summary(model, *inputs, batch_size=-1, show_input=False, show_hierarchical=False, 25 | print_summary=False, max_depth=1, show_parent_layers=False): 26 | ``` 27 | 28 | - `model`: pytorch model object 29 | - `*inputs`: ... 30 | - `batch_size`: if provided, it is printed in summary table 31 | - `show_input`: show input shape. Otherwise, output shape for each layer. 
21 | ### Parameters
22 | **Default values mimic Keras behavior**
23 | ```python
24 | summary(model, *inputs, batch_size=-1, show_input=False, show_hierarchical=False,
25 |         print_summary=False, max_depth=1, show_parent_layers=False)
26 | ```
27 | 
28 | - `model`: PyTorch model object
29 | - `*inputs`: the inputs forwarded to the model's `forward` call (any number of positional arguments)
30 | - `batch_size`: if provided, it is printed in the summary table footer
31 | - `show_input`: show the input shape of each layer; otherwise, show the output shape **(Default: False)**
32 | - `show_hierarchical`: in addition to the summary table, return a hierarchical view of the model **(Default: False)**
33 | - `print_summary`: when True, the summary is printed inside the call, so there is no need to call `print` on the result **(Default: False)**
34 | - `max_depth`: how deep to go inside user-defined layers when showing them; `None` means no depth limit **(Default: 1)**
35 | - `show_parent_layers`: adds a column showing the path of parent layers down to the current layer **(Default: False)**
36 | 
37 | 
38 | 
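As a minimal sketch of the `*inputs` parameter, the hypothetical `TwoInputNet` below (not part of this package) shows that any number of positional inputs can be passed; they are simply forwarded to the model:

```python
import torch
import torch.nn as nn

from pytorch_model_summary import summary


class TwoInputNet(nn.Module):
    # illustrative module only, not part of pytorch_model_summary
    def __init__(self):
        super(TwoInputNet, self).__init__()
        self.fc_a = nn.Linear(8, 4)
        self.fc_b = nn.Linear(16, 4)

    def forward(self, a, b):
        # every positional argument passed to summary() reaches forward() unchanged
        return torch.cat([self.fc_a(a), self.fc_b(b)], dim=1)


print(summary(TwoInputNet(), torch.zeros((1, 8)), torch.zeros((1, 16)), show_input=True))
```

A complete end-to-end example follows: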
39 | ```python
40 | import torch
41 | import torch.nn as nn
42 | import torch.nn.functional as F
43 | 
44 | from pytorch_model_summary import summary
45 | 
46 | 
47 | class Net(nn.Module):
48 |     def __init__(self):
49 |         super(Net, self).__init__()
50 |         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
51 |         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
52 |         self.conv2_drop = nn.Dropout2d()
53 |         self.fc1 = nn.Linear(320, 50)
54 |         self.fc2 = nn.Linear(50, 10)
55 | 
56 |     def forward(self, x):
57 |         x = F.relu(F.max_pool2d(self.conv1(x), 2))
58 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
59 |         x = x.view(-1, 320)
60 |         x = F.relu(self.fc1(x))
61 |         x = F.dropout(x, training=self.training)
62 |         x = self.fc2(x)
63 |         return F.log_softmax(x, dim=1)
64 | 
65 | 
66 | # show input shape
67 | print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True))
68 | 
69 | # show output shape
70 | print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False))
71 | 
72 | # show output shape and hierarchical view of net
73 | print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))
74 | 
75 | ```
76 | 
77 | ```
78 | -----------------------------------------------------------------------
79 | Layer (type) Input Shape Param # Tr. Param #
80 | =======================================================================
81 | Conv2d-1 [1, 1, 28, 28] 260 260
82 | Conv2d-2 [1, 10, 12, 12] 5,020 5,020
83 | Dropout2d-3 [1, 20, 8, 8] 0 0
84 | Linear-4 [1, 320] 16,050 16,050
85 | Linear-5 [1, 50] 510 510
86 | =======================================================================
87 | Total params: 21,840
88 | Trainable params: 21,840
89 | Non-trainable params: 0
90 | -----------------------------------------------------------------------
91 | ```
92 | ```
93 | -----------------------------------------------------------------------
94 | Layer (type) Output Shape Param # Tr. Param #
95 | =======================================================================
96 | Conv2d-1 [1, 10, 24, 24] 260 260
97 | Conv2d-2 [1, 20, 8, 8] 5,020 5,020
98 | Dropout2d-3 [1, 20, 8, 8] 0 0
99 | Linear-4 [1, 50] 16,050 16,050
100 | Linear-5 [1, 10] 510 510
101 | =======================================================================
102 | Total params: 21,840
103 | Trainable params: 21,840
104 | Non-trainable params: 0
105 | -----------------------------------------------------------------------
106 | ```
107 | ```
108 | -----------------------------------------------------------------------
109 | Layer (type) Output Shape Param # Tr. Param #
110 | =======================================================================
111 | Conv2d-1 [1, 10, 24, 24] 260 260
112 | Conv2d-2 [1, 20, 8, 8] 5,020 5,020
113 | Dropout2d-3 [1, 20, 8, 8] 0 0
114 | Linear-4 [1, 50] 16,050 16,050
115 | Linear-5 [1, 10] 510 510
116 | =======================================================================
117 | Total params: 21,840
118 | Trainable params: 21,840
119 | Non-trainable params: 0
120 | -----------------------------------------------------------------------
121 | =========================== Hierarchical Summary ===========================
122 | Net(
123 |   (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1)), 260 params
124 |   (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1)), 5,020 params
125 |   (conv2_drop): Dropout2d(p=0.5), 0 params
126 |   (fc1): Linear(in_features=320, out_features=50, bias=True), 16,050 params
127 |   (fc2): Linear(in_features=50, out_features=10, bias=True), 510 params
128 | ), 21,840 params
129 | ============================================================================
130 | ```
131 | 
132 | 
133 | ## Quick Start
134 | 
135 | Just install it with **pip**
136 | 
137 | `pip install pytorch-model-summary` and
138 | ```python
139 | from pytorch_model_summary import summary
140 | ```
141 | or
142 | ```python
143 | import pytorch_model_summary as pms
144 | pms.summary([params])
145 | ```
146 | to avoid naming conflicts with other methods in your code
147 | 
148 | You can use the library as shown above. If you want to see more detail, please see the examples below.
149 | 
150 | ## Examples using different sets of parameters
151 | 
152 | Outputs from the [Transformer Model Example](https://github.com/amarczew/pytorch_model_summary/blob/master/examples/Transformer.py) based on the [Attention Is All You Need paper (2017)](https://arxiv.org/abs/1706.03762)
153 | 
154 | 1) showing **input shape**
155 | ```python
156 | # show input shape
157 | pms.summary(model, enc_inputs, dec_inputs, show_input=True, print_summary=True)
158 | ```
159 | ```
160 | -----------------------------------------------------------------------------------
161 | Layer (type) Input Shape Param # Tr. Param #
162 | ===================================================================================
163 | Encoder-1 [1, 5] 17,332,224 17,329,152
164 | Decoder-2 [1, 5], [1, 5], [1, 5, 512] 22,060,544 22,057,472
165 | Linear-3 [1, 5, 512] 3,584 3,584
166 | ===================================================================================
167 | Total params: 39,396,352
168 | Trainable params: 39,390,208
169 | Non-trainable params: 6,144
170 | -----------------------------------------------------------------------------------
171 | ```
172 | 
173 | 2) showing **output shape** and **batch_size** in the table, plus the **hierarchical summary** version
174 | ```python
175 | # show output shape and batch_size in table. In addition, also hierarchical summary version
176 | pms.summary(model, enc_inputs, dec_inputs, batch_size=1, show_hierarchical=True, print_summary=True)
177 | ```
178 | ```
179 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
180 | Layer (type) Output Shape Param # Tr.
Param # 181 | =========================================================================================================================================================================================================================================== 182 | Encoder-1 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 17,332,224 17,329,152 183 | Decoder-2 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 22,060,544 22,057,472 184 | Linear-3 [1, 5, 7] 3,584 3,584 185 | =========================================================================================================================================================================================================================================== 186 | Total params: 39,396,352 187 | Trainable params: 39,390,208 188 | Non-trainable params: 6,144 189 | Batch size: 1 190 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 191 | 192 | 193 | ================================ Hierarchical Summary ================================ 194 | 195 | Transformer( 196 | (encoder): Encoder( 197 | (src_emb): Embedding(6, 512), 3,072 params 198 | (pos_emb): Embedding(6, 512), 3,072 params 199 | (layers): ModuleList( 200 | (0): EncoderLayer( 201 | (enc_self_attn): MultiHeadAttention( 202 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 203 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 204 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 205 | ), 787,968 params 206 | (pos_ffn): PoswiseFeedForwardNet( 207 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 208 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 209 | ), 2,099,712 params 210 | ), 2,887,680 params 211 | (1): EncoderLayer( 212 | (enc_self_attn): MultiHeadAttention( 213 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 214 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 215 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 216 | ), 787,968 params 217 | (pos_ffn): PoswiseFeedForwardNet( 218 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 219 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 220 | ), 2,099,712 params 221 | ), 2,887,680 params 222 | (2): EncoderLayer( 223 | (enc_self_attn): MultiHeadAttention( 224 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 225 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 226 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 227 | ), 787,968 params 228 | (pos_ffn): PoswiseFeedForwardNet( 229 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 230 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 231 | ), 2,099,712 params 232 | ), 2,887,680 params 233 | (3): EncoderLayer( 234 | (enc_self_attn): MultiHeadAttention( 235 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 236 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 237 | (W_V): 
Linear(in_features=512, out_features=512, bias=True), 262,656 params 238 | ), 787,968 params 239 | (pos_ffn): PoswiseFeedForwardNet( 240 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 241 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 242 | ), 2,099,712 params 243 | ), 2,887,680 params 244 | (4): EncoderLayer( 245 | (enc_self_attn): MultiHeadAttention( 246 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 247 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 248 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 249 | ), 787,968 params 250 | (pos_ffn): PoswiseFeedForwardNet( 251 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 252 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 253 | ), 2,099,712 params 254 | ), 2,887,680 params 255 | (5): EncoderLayer( 256 | (enc_self_attn): MultiHeadAttention( 257 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 258 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 259 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 260 | ), 787,968 params 261 | (pos_ffn): PoswiseFeedForwardNet( 262 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 263 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 264 | ), 2,099,712 params 265 | ), 2,887,680 params 266 | ), 17,326,080 params 267 | ), 17,332,224 params 268 | (decoder): Decoder( 269 | (tgt_emb): Embedding(7, 512), 3,584 params 270 | (pos_emb): Embedding(6, 512), 3,072 params 271 | (layers): ModuleList( 272 | (0): DecoderLayer( 273 | (dec_self_attn): MultiHeadAttention( 274 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 275 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 276 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 277 | ), 787,968 params 278 | (dec_enc_attn): MultiHeadAttention( 279 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 280 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 281 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 282 | ), 787,968 params 283 | (pos_ffn): PoswiseFeedForwardNet( 284 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 285 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 286 | ), 2,099,712 params 287 | ), 3,675,648 params 288 | (1): DecoderLayer( 289 | (dec_self_attn): MultiHeadAttention( 290 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 291 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 292 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 293 | ), 787,968 params 294 | (dec_enc_attn): MultiHeadAttention( 295 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 296 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 297 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 298 | ), 787,968 params 299 | (pos_ffn): PoswiseFeedForwardNet( 300 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 301 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 302 | ), 2,099,712 params 303 | ), 3,675,648 params 
304 | (2): DecoderLayer( 305 | (dec_self_attn): MultiHeadAttention( 306 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 307 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 308 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 309 | ), 787,968 params 310 | (dec_enc_attn): MultiHeadAttention( 311 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 312 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 313 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 314 | ), 787,968 params 315 | (pos_ffn): PoswiseFeedForwardNet( 316 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 317 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 318 | ), 2,099,712 params 319 | ), 3,675,648 params 320 | (3): DecoderLayer( 321 | (dec_self_attn): MultiHeadAttention( 322 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 323 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 324 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 325 | ), 787,968 params 326 | (dec_enc_attn): MultiHeadAttention( 327 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 328 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 329 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 330 | ), 787,968 params 331 | (pos_ffn): PoswiseFeedForwardNet( 332 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 333 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 334 | ), 2,099,712 params 335 | ), 3,675,648 params 336 | (4): DecoderLayer( 337 | (dec_self_attn): MultiHeadAttention( 338 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 339 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 340 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 341 | ), 787,968 params 342 | (dec_enc_attn): MultiHeadAttention( 343 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 344 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 345 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 346 | ), 787,968 params 347 | (pos_ffn): PoswiseFeedForwardNet( 348 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 349 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params 350 | ), 2,099,712 params 351 | ), 3,675,648 params 352 | (5): DecoderLayer( 353 | (dec_self_attn): MultiHeadAttention( 354 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 355 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 356 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 357 | ), 787,968 params 358 | (dec_enc_attn): MultiHeadAttention( 359 | (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params 360 | (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params 361 | (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params 362 | ), 787,968 params 363 | (pos_ffn): PoswiseFeedForwardNet( 364 | (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params 365 | (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 
1,049,088 params 366 | ), 2,099,712 params 367 | ), 3,675,648 params 368 | ), 22,053,888 params 369 | ), 22,060,544 params 370 | (projection): Linear(in_features=512, out_features=7, bias=False), 3,584 params 371 | ), 39,396,352 params 372 | 373 | 374 | ====================================================================================== 375 | ``` 376 | 377 | 3) showing **layers until depth 2** 378 | ```python 379 | # show layers until depth 2 380 | pms.summary(model, enc_inputs, dec_inputs, max_depth=2, print_summary=True) 381 | ``` 382 | ``` 383 | ----------------------------------------------------------------------------------------------- 384 | Layer (type) Output Shape Param # Tr. Param # 385 | =============================================================================================== 386 | Embedding-1 [1, 5, 512] 3,072 3,072 387 | Embedding-2 [1, 5, 512] 3,072 0 388 | EncoderLayer-3 [1, 5, 512], [1, 8, 5, 5] 2,887,680 2,887,680 389 | EncoderLayer-4 [1, 5, 512], [1, 8, 5, 5] 2,887,680 2,887,680 390 | EncoderLayer-5 [1, 5, 512], [1, 8, 5, 5] 2,887,680 2,887,680 391 | EncoderLayer-6 [1, 5, 512], [1, 8, 5, 5] 2,887,680 2,887,680 392 | EncoderLayer-7 [1, 5, 512], [1, 8, 5, 5] 2,887,680 2,887,680 393 | EncoderLayer-8 [1, 5, 512], [1, 8, 5, 5] 2,887,680 2,887,680 394 | Embedding-9 [1, 5, 512] 3,584 3,584 395 | Embedding-10 [1, 5, 512] 3,072 0 396 | DecoderLayer-11 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648 3,675,648 397 | DecoderLayer-12 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648 3,675,648 398 | DecoderLayer-13 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648 3,675,648 399 | DecoderLayer-14 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648 3,675,648 400 | DecoderLayer-15 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648 3,675,648 401 | DecoderLayer-16 [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648 3,675,648 402 | Linear-17 [1, 5, 7] 3,584 3,584 403 | =============================================================================================== 404 | Total params: 39,396,352 405 | Trainable params: 39,390,208 406 | Non-trainable params: 6,144 407 | ----------------------------------------------------------------------------------------------- 408 | ``` 409 | 410 | 4) showing **deepest layers** 411 | ```python 412 | # show deepest layers 413 | pms.summary(model, enc_inputs, dec_inputs, max_depth=None, print_summary=True) 414 | ``` 415 | ``` 416 | ----------------------------------------------------------------------- 417 | Layer (type) Output Shape Param # Tr. 
Param # 418 | ======================================================================= 419 | Embedding-1 [1, 5, 512] 3,072 3,072 420 | Embedding-2 [1, 5, 512] 3,072 0 421 | Linear-3 [1, 5, 512] 262,656 262,656 422 | Linear-4 [1, 5, 512] 262,656 262,656 423 | Linear-5 [1, 5, 512] 262,656 262,656 424 | Conv1d-6 [1, 2048, 5] 1,050,624 1,050,624 425 | Conv1d-7 [1, 512, 5] 1,049,088 1,049,088 426 | Linear-8 [1, 5, 512] 262,656 262,656 427 | Linear-9 [1, 5, 512] 262,656 262,656 428 | Linear-10 [1, 5, 512] 262,656 262,656 429 | Conv1d-11 [1, 2048, 5] 1,050,624 1,050,624 430 | Conv1d-12 [1, 512, 5] 1,049,088 1,049,088 431 | Linear-13 [1, 5, 512] 262,656 262,656 432 | Linear-14 [1, 5, 512] 262,656 262,656 433 | Linear-15 [1, 5, 512] 262,656 262,656 434 | Conv1d-16 [1, 2048, 5] 1,050,624 1,050,624 435 | Conv1d-17 [1, 512, 5] 1,049,088 1,049,088 436 | Linear-18 [1, 5, 512] 262,656 262,656 437 | Linear-19 [1, 5, 512] 262,656 262,656 438 | Linear-20 [1, 5, 512] 262,656 262,656 439 | Conv1d-21 [1, 2048, 5] 1,050,624 1,050,624 440 | Conv1d-22 [1, 512, 5] 1,049,088 1,049,088 441 | Linear-23 [1, 5, 512] 262,656 262,656 442 | Linear-24 [1, 5, 512] 262,656 262,656 443 | Linear-25 [1, 5, 512] 262,656 262,656 444 | Conv1d-26 [1, 2048, 5] 1,050,624 1,050,624 445 | Conv1d-27 [1, 512, 5] 1,049,088 1,049,088 446 | Linear-28 [1, 5, 512] 262,656 262,656 447 | Linear-29 [1, 5, 512] 262,656 262,656 448 | Linear-30 [1, 5, 512] 262,656 262,656 449 | Conv1d-31 [1, 2048, 5] 1,050,624 1,050,624 450 | Conv1d-32 [1, 512, 5] 1,049,088 1,049,088 451 | Embedding-33 [1, 5, 512] 3,584 3,584 452 | Embedding-34 [1, 5, 512] 3,072 0 453 | Linear-35 [1, 5, 512] 262,656 262,656 454 | Linear-36 [1, 5, 512] 262,656 262,656 455 | Linear-37 [1, 5, 512] 262,656 262,656 456 | Linear-38 [1, 5, 512] 262,656 262,656 457 | Linear-39 [1, 5, 512] 262,656 262,656 458 | Linear-40 [1, 5, 512] 262,656 262,656 459 | Conv1d-41 [1, 2048, 5] 1,050,624 1,050,624 460 | Conv1d-42 [1, 512, 5] 1,049,088 1,049,088 461 | Linear-43 [1, 5, 512] 262,656 262,656 462 | Linear-44 [1, 5, 512] 262,656 262,656 463 | Linear-45 [1, 5, 512] 262,656 262,656 464 | Linear-46 [1, 5, 512] 262,656 262,656 465 | Linear-47 [1, 5, 512] 262,656 262,656 466 | Linear-48 [1, 5, 512] 262,656 262,656 467 | Conv1d-49 [1, 2048, 5] 1,050,624 1,050,624 468 | Conv1d-50 [1, 512, 5] 1,049,088 1,049,088 469 | Linear-51 [1, 5, 512] 262,656 262,656 470 | Linear-52 [1, 5, 512] 262,656 262,656 471 | Linear-53 [1, 5, 512] 262,656 262,656 472 | Linear-54 [1, 5, 512] 262,656 262,656 473 | Linear-55 [1, 5, 512] 262,656 262,656 474 | Linear-56 [1, 5, 512] 262,656 262,656 475 | Conv1d-57 [1, 2048, 5] 1,050,624 1,050,624 476 | Conv1d-58 [1, 512, 5] 1,049,088 1,049,088 477 | Linear-59 [1, 5, 512] 262,656 262,656 478 | Linear-60 [1, 5, 512] 262,656 262,656 479 | Linear-61 [1, 5, 512] 262,656 262,656 480 | Linear-62 [1, 5, 512] 262,656 262,656 481 | Linear-63 [1, 5, 512] 262,656 262,656 482 | Linear-64 [1, 5, 512] 262,656 262,656 483 | Conv1d-65 [1, 2048, 5] 1,050,624 1,050,624 484 | Conv1d-66 [1, 512, 5] 1,049,088 1,049,088 485 | Linear-67 [1, 5, 512] 262,656 262,656 486 | Linear-68 [1, 5, 512] 262,656 262,656 487 | Linear-69 [1, 5, 512] 262,656 262,656 488 | Linear-70 [1, 5, 512] 262,656 262,656 489 | Linear-71 [1, 5, 512] 262,656 262,656 490 | Linear-72 [1, 5, 512] 262,656 262,656 491 | Conv1d-73 [1, 2048, 5] 1,050,624 1,050,624 492 | Conv1d-74 [1, 512, 5] 1,049,088 1,049,088 493 | Linear-75 [1, 5, 512] 262,656 262,656 494 | Linear-76 [1, 5, 512] 262,656 262,656 495 | Linear-77 [1, 5, 512] 262,656 262,656 
496 | Linear-78 [1, 5, 512] 262,656 262,656 497 | Linear-79 [1, 5, 512] 262,656 262,656 498 | Linear-80 [1, 5, 512] 262,656 262,656 499 | Conv1d-81 [1, 2048, 5] 1,050,624 1,050,624 500 | Conv1d-82 [1, 512, 5] 1,049,088 1,049,088 501 | Linear-83 [1, 5, 7] 3,584 3,584 502 | ======================================================================= 503 | Total params: 39,396,352 504 | Trainable params: 39,390,208 505 | Non-trainable params: 6,144 506 | ----------------------------------------------------------------------- 507 | ``` 508 | 509 | 5) showing **layers until depth 3** and adding column with **parent layers** 510 | ```python 511 | # show layers until depth 3 and add column with parent layers 512 | pms.summary(model, enc_inputs, dec_inputs, max_depth=3, show_parent_layers=True, print_summary=True) 513 | ``` 514 | ``` 515 | ----------------------------------------------------------------------------------------------------------------------------- 516 | Parent Layers Layer (type) Output Shape Param # Tr. Param # 517 | ============================================================================================================================= 518 | Transformer/Encoder Embedding-1 [1, 5, 512] 3,072 3,072 519 | Transformer/Encoder Embedding-2 [1, 5, 512] 3,072 0 520 | Transformer/Encoder/EncoderLayer MultiHeadAttention-3 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 521 | Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-4 [1, 5, 512] 2,099,712 2,099,712 522 | Transformer/Encoder/EncoderLayer MultiHeadAttention-5 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 523 | Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-6 [1, 5, 512] 2,099,712 2,099,712 524 | Transformer/Encoder/EncoderLayer MultiHeadAttention-7 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 525 | Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-8 [1, 5, 512] 2,099,712 2,099,712 526 | Transformer/Encoder/EncoderLayer MultiHeadAttention-9 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 527 | Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-10 [1, 5, 512] 2,099,712 2,099,712 528 | Transformer/Encoder/EncoderLayer MultiHeadAttention-11 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 529 | Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-12 [1, 5, 512] 2,099,712 2,099,712 530 | Transformer/Encoder/EncoderLayer MultiHeadAttention-13 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 531 | Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-14 [1, 5, 512] 2,099,712 2,099,712 532 | Transformer/Decoder Embedding-15 [1, 5, 512] 3,584 3,584 533 | Transformer/Decoder Embedding-16 [1, 5, 512] 3,072 0 534 | Transformer/Decoder/DecoderLayer MultiHeadAttention-17 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 535 | Transformer/Decoder/DecoderLayer MultiHeadAttention-18 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 536 | Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-19 [1, 5, 512] 2,099,712 2,099,712 537 | Transformer/Decoder/DecoderLayer MultiHeadAttention-20 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 538 | Transformer/Decoder/DecoderLayer MultiHeadAttention-21 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 539 | Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-22 [1, 5, 512] 2,099,712 2,099,712 540 | Transformer/Decoder/DecoderLayer MultiHeadAttention-23 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 541 | Transformer/Decoder/DecoderLayer MultiHeadAttention-24 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 542 | Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-25 [1, 5, 512] 2,099,712 2,099,712 543 | Transformer/Decoder/DecoderLayer 
MultiHeadAttention-26 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 544 | Transformer/Decoder/DecoderLayer MultiHeadAttention-27 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 545 | Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-28 [1, 5, 512] 2,099,712 2,099,712 546 | Transformer/Decoder/DecoderLayer MultiHeadAttention-29 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 547 | Transformer/Decoder/DecoderLayer MultiHeadAttention-30 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 548 | Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-31 [1, 5, 512] 2,099,712 2,099,712 549 | Transformer/Decoder/DecoderLayer MultiHeadAttention-32 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 550 | Transformer/Decoder/DecoderLayer MultiHeadAttention-33 [1, 5, 512], [1, 8, 5, 5] 787,968 787,968 551 | Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-34 [1, 5, 512] 2,099,712 2,099,712 552 | Transformer Linear-35 [1, 5, 7] 3,584 3,584 553 | ============================================================================================================================= 554 | Total params: 39,396,352 555 | Trainable params: 39,390,208 556 | Non-trainable params: 6,144 557 | ----------------------------------------------------------------------------------------------------------------------------- 558 | ``` 559 | 560 | 561 | ## Reference 562 | 563 | ```python 564 | code_reference = { 'https://github.com/graykode/modelsummary', 565 | 'https://github.com/pytorch/pytorch/issues/2001', 566 | 'https://gist.github.com/HTLife/b6640af9d6e7d765411f8aa9aa94b837', 567 | 'https://github.com/sksq96/pytorch-summary', 568 | 'Inspired by https://github.com/sksq96/pytorch-summary'} 569 | ``` 570 | --------------------------------------------------------------------------------