├── .gitignore
├── torchsummary
│   ├── __init__.py
│   ├── tests
│   │   ├── test_models
│   │   │   └── test_model.py
│   │   └── unit_tests
│   │       └── torchsummary_test.py
│   └── torchsummary.py
├── setup.py
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*.pyc
.vscode/

--------------------------------------------------------------------------------
/torchsummary/__init__.py:
--------------------------------------------------------------------------------
from .torchsummary import summary, summary_string

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name="torchsummary",
    version="1.5.1",
    description="Model summary in PyTorch similar to `model.summary()` in Keras",
    url="https://github.com/sksq96/pytorch-summary",
    author="Shubham Chandel @sksq96",
    author_email="shubham.zeez@gmail.com",
    packages=["torchsummary"],
)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Shubham Chandel

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/torchsummary/tests/test_models/test_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleInputNet(nn.Module):
    def __init__(self):
        super(SingleInputNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d(0.3)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

class MultipleInputNet(nn.Module):
    def __init__(self):
        super(MultipleInputNet, self).__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        # concatenate the two branches along the batch dimension
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)

class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super(MultipleInputNetDifferentDtypes, self).__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        # cast x2 to FloatTensor so both branches share a dtype
        x2 = x2.type(torch.FloatTensor)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)

--------------------------------------------------------------------------------
/torchsummary/tests/unit_tests/torchsummary_test.py:
--------------------------------------------------------------------------------
import unittest
from torchsummary import summary, summary_string
from torchsummary.tests.test_models.test_model import (
    SingleInputNet, MultipleInputNet, MultipleInputNetDifferentDtypes)
import torch

gpu_if_available = "cuda:0" if torch.cuda.is_available() else "cpu"

class torchsummaryTests(unittest.TestCase):
    def test_single_input(self):
        model = SingleInputNet()
        input_size = (1, 28, 28)
        total_params, trainable_params = summary(model, input_size, device="cpu")
        self.assertEqual(total_params, 21840)
        self.assertEqual(trainable_params, 21840)

    def test_multiple_input(self):
        model = MultipleInputNet()
        input1 = (1, 300)
        input2 = (1, 300)
        total_params, trainable_params = summary(
            model, [input1, input2], device="cpu")
        self.assertEqual(total_params, 31120)
        self.assertEqual(trainable_params, 31120)

    def test_single_layer_network(self):
        model = torch.nn.Linear(2, 5)
        input_size = (1, 2)
        total_params, trainable_params = summary(model, input_size, device="cpu")
        self.assertEqual(total_params, 15)
        self.assertEqual(trainable_params, 15)

    def test_single_layer_network_on_gpu(self):
        model = torch.nn.Linear(2, 5)
        if torch.cuda.is_available():
            model.cuda()
        input_size = (1, 2)
        total_params, trainable_params = summary(
            model, input_size, device=gpu_if_available)
        self.assertEqual(total_params, 15)
        self.assertEqual(trainable_params, 15)

    def test_multiple_input_types(self):
        model = MultipleInputNetDifferentDtypes()
        input1 = (1, 300)
        input2 = (1, 300)
        dtypes = [torch.FloatTensor, torch.LongTensor]
        total_params, trainable_params = summary(
            model, [input1, input2], device="cpu", dtypes=dtypes)
        self.assertEqual(total_params, 31120)
        self.assertEqual(trainable_params, 31120)


class torchsummarystringTests(unittest.TestCase):
    def test_single_input(self):
        model = SingleInputNet()
        input_size = (1, 28, 28)
        result, (total_params, trainable_params) = summary_string(
            model, input_size, device="cpu")
        self.assertEqual(type(result), str)
        self.assertEqual(total_params, 21840)
        self.assertEqual(trainable_params, 21840)


if __name__ == '__main__':
    unittest.main(buffer=True)

--------------------------------------------------------------------------------
/torchsummary/torchsummary.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    result, params_info = summary_string(
        model, input_size, batch_size, device, dtypes)
    print(result)

    return params_info


def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
    # normalize a single input size given as a plain tuple, so that
    # len(input_size) counts inputs rather than dimensions below
    if isinstance(input_size, tuple):
        input_size = [input_size]

    if dtypes is None:
        dtypes = [torch.FloatTensor] * len(input_size)

    summary_str = ''

    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        # hook every module except bare containers
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
        ):
            hooks.append(module.register_forward_hook(hook))

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype).to(device=device)
         for in_size, dtype in zip(input_size, dtypes)]

    # create properties
    summary = OrderedDict()
    hooks = []

    # register the hook on every submodule
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

"----------------------------------------------------------------" + "\n" 78 | line_new = "{:>20} {:>25} {:>15}".format( 79 | "Layer (type)", "Output Shape", "Param #") 80 | summary_str += line_new + "\n" 81 | summary_str += "================================================================" + "\n" 82 | total_params = 0 83 | total_output = 0 84 | trainable_params = 0 85 | for layer in summary: 86 | # input_shape, output_shape, trainable, nb_params 87 | line_new = "{:>20} {:>25} {:>15}".format( 88 | layer, 89 | str(summary[layer]["output_shape"]), 90 | "{0:,}".format(summary[layer]["nb_params"]), 91 | ) 92 | total_params += summary[layer]["nb_params"] 93 | 94 | total_output += np.prod(summary[layer]["output_shape"]) 95 | if "trainable" in summary[layer]: 96 | if summary[layer]["trainable"] == True: 97 | trainable_params += summary[layer]["nb_params"] 98 | summary_str += line_new + "\n" 99 | 100 | # assume 4 bytes/number (float on cuda). 101 | total_input_size = abs(np.prod(sum(input_size, ())) 102 | * batch_size * 4. / (1024 ** 2.)) 103 | total_output_size = abs(2. * total_output * 4. / 104 | (1024 ** 2.)) # x2 for gradients 105 | total_params_size = abs(total_params * 4. / (1024 ** 2.)) 106 | total_size = total_params_size + total_output_size + total_input_size 107 | 108 | summary_str += "================================================================" + "\n" 109 | summary_str += "Total params: {0:,}".format(total_params) + "\n" 110 | summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" 111 | summary_str += "Non-trainable params: {0:,}".format(total_params - 112 | trainable_params) + "\n" 113 | summary_str += "----------------------------------------------------------------" + "\n" 114 | summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" 115 | summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" 116 | summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" 117 | summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" 118 | summary_str += "----------------------------------------------------------------" + "\n" 119 | # return summary 120 | return summary_str, (total_params, trainable_params) 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Use the new and updated [torchinfo](https://github.com/TylerYep/torchinfo). 2 | 3 | ## Keras style `model.summary()` in PyTorch 4 | [![PyPI version](https://badge.fury.io/py/torchsummary.svg)](https://badge.fury.io/py/torchsummary) 5 | 6 | Keras has a neat API to view the visualization of the model which is very helpful while debugging your network. Here is a barebone code to try and mimic the same in PyTorch. The aim is to provide information complementary to, what is not provided by `print(your_model)` in PyTorch. 7 | 8 | ### Usage 9 | 10 | - `pip install torchsummary` or 11 | - `git clone https://github.com/sksq96/pytorch-summary` 12 | 13 | ```python 14 | from torchsummary import summary 15 | summary(your_model, input_size=(channels, H, W)) 16 | ``` 17 | 18 | - Note that the `input_size` is required to make a forward pass through the network. 

### Examples

#### CNN for MNIST

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # PyTorch v0.4.0
model = Net().to(device)

summary(model, (1, 28, 28))
```

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 10, 24, 24]             260
            Conv2d-2             [-1, 20, 8, 8]           5,020
         Dropout2d-3             [-1, 20, 8, 8]               0
            Linear-4                   [-1, 50]          16,050
            Linear-5                   [-1, 10]             510
================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
----------------------------------------------------------------
```

#### VGG16

```python
import torch
from torchvision import models
from torchsummary import summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = models.vgg16().to(device)

summary(vgg, (3, 224, 224))
```

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]     102,764,544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]      16,781,312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]       4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------
```

#### Multiple Inputs

```python
import torch
import torch.nn as nn
from torchsummary import summary

class SimpleConv(nn.Module):
    def __init__(self):
        super(SimpleConv, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, x, y):
        x1 = self.features(x)
        x2 = self.features(y)
        return x1, x2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleConv().to(device)

summary(model, [(1, 16, 16), (1, 28, 28)])
```

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 1, 16, 16]              10
              ReLU-2            [-1, 1, 16, 16]               0
            Conv2d-3            [-1, 1, 28, 28]              10
              ReLU-4            [-1, 1, 28, 28]               0
================================================================
Total params: 20
Trainable params: 20
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.77
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.78
----------------------------------------------------------------
```
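
#### Inputs with different dtypes

`summary` also accepts a `dtypes` argument with one entry per input, for models whose forward pass expects tensors of different types; this mirrors the package's own unit tests. A minimal sketch, assuming a hypothetical two-input `model` whose second input is a `LongTensor` (token indices, say):

```python
import torch
from torchsummary import summary

# one dtype per input size, in the same order as the inputs
dtypes = [torch.FloatTensor, torch.LongTensor]
summary(model, [(1, 300), (1, 300)], device="cpu", dtypes=dtypes)
```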

### References

- The idea for this package was sparked by [this PyTorch issue](https://github.com/pytorch/pytorch/issues/2001).
- Thanks to @ncullen93 and @HTLife.
- Model size estimation by @jacobkimmel ([details here](https://github.com/sksq96/pytorch-summary/pull/21)).

### License

`pytorch-summary` is MIT-licensed.

--------------------------------------------------------------------------------