├── README.md
├── pytorch_tools.py
└── tensorflow_tools.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# model-tools
Tools for computing model parameters and FLOPs.

- [Caffe](https://simochen.github.io/netscope)
- [PyTorch](https://github.com/simochen/model-tools/blob/master/pytorch_tools.py)
- [TensorFlow](https://github.com/simochen/model-tools/blob/master/tensorflow_tools.py)
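
For PyTorch, a minimal usage sketch (the helpers below are the ones defined in `pytorch_tools.py`; both fall back to `torchvision.models.alexnet()` when no model is passed):

```python
import torchvision
from pytorch_tools import print_model_param_nums, print_model_param_flops

model = torchvision.models.resnet18()
print_model_param_nums(model)                         # -> " + Number of params: 11.69M"
print_model_param_flops(model, input_res=[224, 224])  # -> " + Number of FLOPs: ...G"
```

The script also exposes its functions on the command line via `fire`, e.g. `python pytorch_tools.py print_model_param_flops`.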

--------------------------------------------------------------------------------
/pytorch_tools.py:
--------------------------------------------------------------------------------
# -*- coding:utf8 -*-
import torch
import torchvision

import torch.nn as nn
from torch.autograd import Variable

import numpy as np


def print_model_param_nums(model=None):
    if model is None:
        model = torchvision.models.alexnet()
    total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()])
    print(' + Number of params: %.2fM' % (total / 1e6))


def print_model_param_flops(model=None, input_res=[224, 224], multiply_adds=True):

    prods = {}
    def save_hook(name):
        def hook_per(self, input, output):
            prods[name] = np.prod(input[0].shape)
        return hook_per

    list_1 = []
    def simple_hook(self, input, output):
        list_1.append(np.prod(input[0].shape))

    list_2 = {}
    def simple_hook2(self, input, output):
        list_2['names'] = np.prod(input[0].shape)

    list_conv = []
    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        # multiply-accumulates per output element; a MAC counts as 2 FLOPs when multiply_adds is True
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels // self.groups)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size

        list_conv.append(flops)

    list_linear = []
    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1

        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement() if self.bias is not None else 0

        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn = []
    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement() * 2)

    list_relu = []
    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []
    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        # kernel_size may be an int or a (kh, kw) tuple
        kh, kw = (self.kernel_size, self.kernel_size) if isinstance(self.kernel_size, int) else self.kernel_size[:2]
        kernel_ops = kh * kw
        bias_ops = 0
        params = 0
        flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size

        list_pooling.append(flops)

    list_upsample = []
    # For bilinear upsample
    def upsample_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        flops = output_height * output_width * output_channels * batch_size * 12
        list_upsample.append(flops)

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d):
                net.register_forward_hook(conv_hook)
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
            if isinstance(net, torch.nn.BatchNorm2d):
                net.register_forward_hook(bn_hook)
            if isinstance(net, torch.nn.ReLU):
                net.register_forward_hook(relu_hook)
            if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
                net.register_forward_hook(pooling_hook)
            if isinstance(net, torch.nn.Upsample):
                net.register_forward_hook(upsample_hook)
            return
        for c in childrens:
            foo(c)

    if model is None:
        model = torchvision.models.alexnet()
    foo(model)
    input = Variable(torch.rand(3, input_res[1], input_res[0]).unsqueeze(0), requires_grad=True)
    out = model(input)

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample))

    print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9))


def print_forward(model=None):
    if model is None:
        model = torchvision.models.resnet18()
    select_layer = model.layer1[0].conv1

    grads = {}
    def save_grad(name):
        def hook(self, input, output):
            grads[name] = input
        return hook

    select_layer.register_forward_hook(save_grad('select_layer'))

    input = Variable(torch.rand(3, 224, 224).unsqueeze(0), requires_grad=True)
    out = model(input)
    # print(grads['select_layer'])
    print(grads)


def print_value():
    grads = {}
    def save_grad(name):
        def hook(grad):
            grads[name] = grad
        return hook

    x = Variable(torch.randn(1, 1), requires_grad=True)
    y = 3 * x
    z = y ** 2

    # Here, save_grad('y') returns a hook (a function) that closes over the name 'y'
    y.register_hook(save_grad('y'))
    z.register_hook(save_grad('z'))
    z.backward()
    print('HW')
    print("grads['y']: {}".format(grads['y']))
    print(grads['z'])


def print_layers_num(model=None):
    if model is None:
        model = torchvision.models.resnet18()
    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d):
                print(' ')
                # can be used to count the number of layers of each type
                # net.register_backward_hook(print)
            return 1
        count = 0
        for c in childrens:
            count += foo(c)
        return count
    print(foo(model))


def check_summary(model=None):
    def torch_summarize(model, show_weights=True, show_parameters=True):
        """Summarizes torch model by showing trainable parameters and weights."""
        from torch.nn.modules.module import _addindent

        tmpstr = model.__class__.__name__ + ' (\n'
        for key, module in model._modules.items():
            # if it contains layers let call it recursively to get params and weights
            if type(module) in [
                torch.nn.modules.container.Container,
                torch.nn.modules.container.Sequential
            ]:
                modstr = torch_summarize(module)
            else:
                modstr = module.__repr__()
            modstr = _addindent(modstr, 2)

            params = sum([np.prod(p.size()) for p in module.parameters()])
            weights = tuple([tuple(p.size()) for p in module.parameters()])

            tmpstr += ' (' + key + '): ' + modstr
            if show_weights:
                tmpstr += ', weights={}'.format(weights)
            if show_parameters:
                tmpstr += ', parameters={}'.format(params)
            tmpstr += '\n'

        tmpstr = tmpstr + ')'
        return tmpstr

    # Test
    if model is None:
        model = torchvision.models.alexnet()
    print(torch_summarize(model))


# https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
# summarize a torch model like in keras, showing parameters and output shape
def show_summary(model=None):
    from collections import OrderedDict
    import pandas as pd
    import numpy as np

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F
    from torch import nn


    def get_names_dict(model):
        """
        Recursive walk to get names including path
        """
        names = {}
        def _get_names(module, parent_name=''):
            for key, module in module.named_children():
                name = parent_name + '.' + key if parent_name else key
                names[name] = module
                if isinstance(module, torch.nn.Module):
                    _get_names(module, parent_name=name)
        _get_names(model)
        return names


    def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
        """
        Summarizes torch model by showing trainable parameters and weights.

        author: wassname
        url: https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
        license: MIT

        Modified from:
        - https://github.com/pytorch/pytorch/issues/2001#issuecomment-313735757
        - https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7/

        Usage:
            import torchvision.models as models
            model = models.alexnet()
            df = torch_summarize_df(input_size=(3, 224, 224), model=model)
            print(df)

            #    name         class_name  input_shape        output_shape      num_params
            # 1  features=>0  Conv2d      (-1, 3, 224, 224)  (-1, 64, 55, 55)  23296  # (3*11*11+1)*64
            # 2  features=>1  ReLU        (-1, 64, 55, 55)   (-1, 64, 55, 55)  0
            # ...
        """

        def register_hook(module):
            def hook(module, input, output):
                name = ''
                for key, item in names.items():
                    if item == module:
                        name = key

                class_name = str(module.__class__).split('.')[-1].split("'")[0]
                module_idx = len(summary)

                m_key = module_idx + 1

                summary[m_key] = OrderedDict()
                summary[m_key]['name'] = name
                summary[m_key]['class_name'] = class_name
                if input_shape:
                    summary[m_key]['input_shape'] = (-1, ) + tuple(input[0].size())[1:]
                summary[m_key]['output_shape'] = (-1, ) + tuple(output.size())[1:]
                if weights:
                    summary[m_key]['weights'] = list([tuple(p.size()) for p in module.parameters()])

                # summary[m_key]['trainable'] = any([p.requires_grad for p in module.parameters()])
                if nb_trainable:
                    params_trainable = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters() if p.requires_grad])
                    summary[m_key]['nb_trainable'] = params_trainable
                params = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters()])
                summary[m_key]['nb_params'] = params

            if not isinstance(module, nn.Sequential) and \
               not isinstance(module, nn.ModuleList) and \
               not (module == model):
                hooks.append(module.register_forward_hook(hook))

        # Names are stored in parent and path+name is unique not the name
        names = get_names_dict(model)

        # check if there are multiple inputs to the network
        if isinstance(input_size[0], (list, tuple)):
            x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
        else:
            x = Variable(torch.rand(1, *input_size))

        if next(model.parameters()).is_cuda:
            x = x.cuda()

        # create properties
        summary = OrderedDict()
        hooks = []

        # register hook
        model.apply(register_hook)

        # make a forward pass
        model(x)

        # remove these hooks
        for h in hooks:
            h.remove()

        # make dataframe
        df_summary = pd.DataFrame.from_dict(summary, orient='index')

        return df_summary


    # Test on alexnet
    if model is None:
        model = torchvision.models.alexnet()
    df = torch_summarize_df(input_size=(3, 224, 224), model=model)
    print(df)

    # # Output
    #     name           class_name  input_shape        output_shape       num_params
    # 1   features=>0    Conv2d      (-1, 3, 224, 224)  (-1, 64, 55, 55)   23296  # nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
    # 2   features=>1    ReLU        (-1, 64, 55, 55)   (-1, 64, 55, 55)   0
    # 3   features=>2    MaxPool2d   (-1, 64, 55, 55)   (-1, 64, 27, 27)   0
    # 4   features=>3    Conv2d      (-1, 64, 27, 27)   (-1, 192, 27, 27)  307392
    # 5   features=>4    ReLU        (-1, 192, 27, 27)  (-1, 192, 27, 27)  0
    # 6   features=>5    MaxPool2d   (-1, 192, 27, 27)  (-1, 192, 13, 13)  0
    # 7   features=>6    Conv2d      (-1, 192, 13, 13)  (-1, 384, 13, 13)  663936
    # 8   features=>7    ReLU        (-1, 384, 13, 13)  (-1, 384, 13, 13)  0
    # 9   features=>8    Conv2d      (-1, 384, 13, 13)  (-1, 256, 13, 13)  884992
    # 10  features=>9    ReLU        (-1, 256, 13, 13)  (-1, 256, 13, 13)  0
    # 11  features=>10   Conv2d      (-1, 256, 13, 13)  (-1, 256, 13, 13)  590080
    # 12  features=>11   ReLU        (-1, 256, 13, 13)  (-1, 256, 13, 13)  0
    # 13  features=>12   MaxPool2d   (-1, 256, 13, 13)  (-1, 256, 6, 6)    0
    # 14  classifier=>0  Dropout     (-1, 9216)         (-1, 9216)         0
    # 15  classifier=>1  Linear      (-1, 9216)         (-1, 4096)         37752832
    # 16  classifier=>2  ReLU        (-1, 4096)         (-1, 4096)         0
    # 17  classifier=>3  Dropout     (-1, 4096)         (-1, 4096)         0
    # 18  classifier=>4  Linear      (-1, 4096)         (-1, 4096)         16781312
    # 19  classifier=>5  ReLU        (-1, 4096)         (-1, 4096)         0
    # 20  classifier=>6  Linear      (-1, 4096)         (-1, 1000)         4097000


def show_save_tensor(model=None):
    import torch
    import torchvision
    import matplotlib.pyplot as plt

    def vis_tensor(tensor, ch=0, all_kernels=False, nrow=8, padding=2):
        '''
        ch: channel for visualization
        allkernels: all kernels for visualization
        '''
        n, c, h, w = tensor.shape
        if all_kernels:
            tensor = tensor.view(n * c, -1, w, h)
        elif c != 3:
            tensor = tensor[:, ch, :, :].unsqueeze(dim=1)

        rows = np.min((tensor.shape[0] // nrow + 1, 64))
        grid = torchvision.utils.make_grid(tensor, nrow=nrow, normalize=True, padding=padding)
        # plt.figure(figsize=(nrow, rows))
        plt.imshow(grid.numpy().transpose((1, 2, 0)))  # CHW -> HWC

    def save_tensor(tensor, filename, ch=0, all_kernels=False, nrow=8, padding=2):
        n, c, h, w = tensor.shape
        if all_kernels:
            tensor = tensor.view(n * c, -1, w, h)
        elif c != 3:
            tensor = tensor[:, ch, :, :].unsqueeze(dim=1)
        torchvision.utils.save_image(tensor, filename, nrow=nrow, normalize=True, padding=padding)

    if model is None:
        model = torchvision.models.resnet18(pretrained=True)
    mm = model.double()
    filters = mm.modules
    body_model = [i for i in mm.children()][0]
    # layer1 = body_model[0]
    layer1 = body_model
    tensor = layer1.weight.data.clone()
    vis_tensor(tensor)
    save_tensor(tensor, 'test.png')

    plt.axis('off')
    plt.ioff()
    plt.show()


def print_autograd_graph(model=None):
    from graphviz import Digraph

    def make_dot(var, params=None):
        """ Produces Graphviz representation of PyTorch autograd graph

        Blue nodes are the Variables that require grad, orange are Tensors
        saved for backward in torch.autograd.Function

        Args:
            var: output Variable
            params: dict of (name, Variable) to add names to node that
                require grad (TODO: make optional)
        """
        if params is not None:
            # assert all(isinstance(p, Variable) for p in params.values())
            param_map = {id(v): k for k, v in params.items()}

        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
        seen = set()

        def size_to_str(size):
            return '(' + (', ').join(['%d' % v for v in size]) + ')'

        def add_nodes(var):
            if var not in seen:
                if torch.is_tensor(var):
                    dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
                elif hasattr(var, 'variable'):
                    u = var.variable
                    # name = param_map[id(u)] if params is not None else ''
                    # node_name = '%s\n %s' % (name, size_to_str(u.size()))
                    node_name = '%s\n %s' % (param_map.get(id(u.data)), size_to_str(u.size()))
                    dot.node(str(id(var)), node_name, fillcolor='lightblue')
                else:
                    dot.node(str(id(var)), str(type(var).__name__))
                seen.add(var)
                if hasattr(var, 'next_functions'):
                    for u in var.next_functions:
                        if u[0] is not None:
                            dot.edge(str(id(u[0])), str(id(var)))
                            add_nodes(u[0])
                if hasattr(var, 'saved_tensors'):
                    for t in var.saved_tensors:
                        dot.edge(str(id(t)), str(id(var)))
                        add_nodes(t)
        add_nodes(var.grad_fn)
        return dot


    torch.manual_seed(1)
    inputs = torch.randn(1, 3, 224, 224)
    if model is None:
        model = torchvision.models.resnet18(pretrained=False)
    y = model(Variable(inputs))
    # print(y)

    g = make_dot(y, params=model.state_dict())
    g.view()
    # g


if __name__ == '__main__':
    import fire
    fire.Fire()

--------------------------------------------------------------------------------
/tensorflow_tools.py:
--------------------------------------------------------------------------------
# -*- coding:utf8 -*-
import tensorflow as tf

run_metadata = tf.RunMetadata()

# # Count the parameters of the whole network
# # using tf.trainable_variables
# def count_model_params():
#     total_parameters = 0
#     for variable in tf.trainable_variables():
#         # shape is an array of tf.Dimension
#         shape = variable.get_shape()
#         variable_parameters = 1
#         for dim in shape:
#             variable_parameters *= dim.value
#         total_parameters += variable_parameters
#     print(' + Number of params: %.2fM' % (total_parameters / 1e6))

# def print_model_params():
#     import numpy as np
#     num_params = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
#     print(' + Number of params: %.2fM' % (num_params / 1e6))

with tf.Session(graph=tf.Graph()) as sess:

    opt = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt)

    opt = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
    param_count = tf.profiler.profile(sess.graph, run_meta=run_metadata, cmd='op', options=opt)

    print(' + Number of FLOPs: %.4fG' % (flops.total_float_ops / 1e9))
    print(' + Number of params: %.4fM' % (param_count.total_parameters / 1e6))
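
# # NOTE: the session above is opened on a fresh, empty graph, so the profiler
# # reports zero FLOPs/params as written. A minimal sketch (not part of the
# # original script; TF 1.x API) of profiling a graph that actually contains a
# # model, e.g. a single conv layer:
#
# with tf.Graph().as_default() as g:
#     x = tf.placeholder(tf.float32, shape=[1, 224, 224, 3])
#     y = tf.layers.conv2d(x, filters=64, kernel_size=7, strides=2, padding='same')
#
#     opt = tf.profiler.ProfileOptionBuilder.float_operation()
#     flops = tf.profiler.profile(g, cmd='op', options=opt)
#     print(' + Number of FLOPs: %.4fG' % (flops.total_float_ops / 1e9))
--------------------------------------------------------------------------------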