├── LICENSE ├── convert.py ├── distributed_train.sh ├── figs ├── accuracy_to_latency.png ├── accuracy_to_latency_iphone12cpu.png └── repghost_bottleneck.png ├── infotool ├── __init__.py ├── fx_profile.c ├── fx_profile.py ├── helper.py ├── profile.py ├── rnn_hooks.py └── vision │ ├── __init__.py │ ├── basic_hooks.py │ ├── counter.py │ ├── efficientnet.py │ └── onnx_counter.py ├── model ├── __init__.py └── repghost.py ├── readme.md ├── requirements.txt ├── tools.py ├── train.py ├── train.sh ├── validate.py └── work_dirs └── train ├── readme.md ├── repghostnet_0_58x_60M_68.94 ├── args.yaml ├── eval.log └── train.log ├── repghostnet_0_5x_43M_66.95 ├── args.yaml ├── eval.log └── train.log ├── repghostnet_0_8x_96M_72.24 ├── args.yaml ├── eval.log └── train.log ├── repghostnet_1_0x_142M_74.22 ├── args.yaml ├── eval.log └── train.log ├── repghostnet_1_11x_170M_75.07 ├── args.yaml ├── eval.log └── train.log ├── repghostnet_1_3x_231M_76.37 ├── args.yaml ├── eval.log └── train.log ├── repghostnet_1_5x_301M_77.45 ├── args.yaml ├── eval.log └── train.log └── repghostnet_2_0x_516M_78.81 ├── args.yaml ├── eval.log └── train.log /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ChengpengChen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # @Author : chengpeng.chen 2 | # @Email : chencp@live.com 3 | """ 4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong. 5 | https://arxiv.org/abs/2211.06088 6 | """ 7 | import argparse 8 | import os 9 | import importlib 10 | import torch 11 | import torch.nn.parallel 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | from model.repghost import repghost_model_convert 16 | 17 | parser = argparse.ArgumentParser(description='RepGhost Conversion for Inference') 18 | parser.add_argument('load', metavar='LOAD', help='path to the weights file') 19 | parser.add_argument('save', metavar='SAVE', help='path to the weights file') 20 | parser.add_argument('-m', '--model', metavar='model', default='repghot.repghostnet_0_5x') 21 | parser.add_argument('--ema-model', '--ema', action='store_true', help='to load the ema model') 22 | parser.add_argument('--sanity_check', '-c', action='store_true', help='to check the outputs of the models') 23 | 24 | 25 | def convert(): 26 | args = parser.parse_args() 27 | 28 | m = importlib.import_module(f"model.{args.model.split('.')[0]}") 29 | train_model = getattr(m, args.model.split('.')[1])() 30 | train_model.eval() 31 | 32 | if 
os.path.isfile(args.load): 33 | print("=> loading checkpoint '{}'".format(args.load)) 34 | checkpoint = torch.load(args.load, map_location='cpu') 35 | if args.ema_model and 'state_dict_ema' in checkpoint: 36 | checkpoint = checkpoint['state_dict_ema'] 37 | else: 38 | checkpoint = checkpoint['state_dict'] 39 | 40 | try: 41 | train_model.load_state_dict(checkpoint) 42 | except Exception as e: 43 | ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()} # strip the names 44 | # print(ckpt.keys()) 45 | train_model.load_state_dict(ckpt) 46 | else: 47 | print("=> no checkpoint found at '{}'".format(args.load)) 48 | 49 | infer_model = repghost_model_convert(train_model, save_path=args.save) 50 | print("=> saved checkpoint to '{}'".format(args.save)) 51 | 52 | if args.sanity_check: 53 | data = torch.randn(5, 3, 224, 224) 54 | out = train_model(data) 55 | out2 = infer_model(data) 56 | print('=> The diff is', ((out - out2) ** 2).sum()) 57 | 58 | 59 | if __name__ == '__main__': 60 | convert() 61 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC --master_port=2345 train.py "$@" -------------------------------------------------------------------------------- /figs/accuracy_to_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/figs/accuracy_to_latency.png -------------------------------------------------------------------------------- /figs/accuracy_to_latency_iphone12cpu.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/figs/accuracy_to_latency_iphone12cpu.png -------------------------------------------------------------------------------- /figs/repghost_bottleneck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/figs/repghost_bottleneck.png -------------------------------------------------------------------------------- /infotool/__init__.py: -------------------------------------------------------------------------------- 1 | from .helper import clever_format 2 | from .profile import profile, profile_origin 3 | import torch 4 | 5 | default_dtype = torch.float64 6 | -------------------------------------------------------------------------------- /infotool/fx_profile.c: -------------------------------------------------------------------------------- 1 | #error Do not use this file, it is the result of a failed Cython compilation. 2 | -------------------------------------------------------------------------------- /infotool/fx_profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as th 3 | import torch.nn as nn 4 | from distutils.version import LooseVersion 5 | 6 | if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): 7 | logging.warning( 8 | f"torch.fx requires version higher than 1.8.0. " 9 | f"But You are using an old version PyTorch {torch.__version__}. 
" 10 | ) 11 | 12 | 13 | def count_clamp(input_shapes, output_shapes): 14 | return 0 15 | 16 | 17 | def count_mul(input_shapes, output_shapes): 18 | # element-wise 19 | return output_shapes[0].numel() 20 | 21 | 22 | def count_matmul(input_shapes, output_shapes): 23 | in_shape = input_shapes[0] 24 | out_shape = output_shapes[0] 25 | in_features = in_shape[-1] 26 | num_elements = out_shape.numel() 27 | return in_features * num_elements 28 | 29 | 30 | def count_fn_linear(input_shapes, output_shapes, *args, **kwargs): 31 | mul_flops = count_matmul(input_shapes, output_shapes) 32 | if "bias" in kwargs: 33 | add_flops = output_shapes[0].numel() 34 | return mul_flops 35 | 36 | 37 | from .vision.counter import counter_conv 38 | 39 | 40 | def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs): 41 | inputs, weight, bias, stride, padding, dilation, groups = args 42 | if len(input_shapes) == 2: 43 | x_shape, k_shape = input_shapes 44 | elif len(input_shapes) == 3: 45 | x_shape, k_shape, b_shape = input_shapes 46 | out_shape = output_shapes[0] 47 | 48 | kernel_parameters = k_shape[2:].numel() 49 | bias_op = 0 # check it later 50 | in_channel = x_shape[1] 51 | 52 | total_ops = counter_conv( 53 | bias_op, kernel_parameters, out_shape.numel(), in_channel, groups 54 | ).item() 55 | return int(total_ops) 56 | 57 | 58 | def count_nn_linear(module: nn.Module, input_shapes, output_shapes): 59 | return count_matmul(input_shapes, output_shapes) 60 | 61 | 62 | def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs): 63 | return 0 64 | 65 | 66 | def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes): 67 | bias_op = 1 if module.bias is not None else 0 68 | out_shape = output_shapes[0] 69 | 70 | in_channel = module.in_channels 71 | groups = module.groups 72 | kernel_ops = module.weight.shape[2:].numel() 73 | total_ops = counter_conv( 74 | bias_op, kernel_ops, out_shape.numel(), in_channel, groups 75 | ).item() 76 | return int(total_ops) 77 | 
78 | 79 | def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes): 80 | assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output" 81 | y = output_shapes[0] 82 | # y = (x - mean) / \sqrt{var + e} * weight + bias 83 | total_ops = 2 * y.numel() 84 | return total_ops 85 | 86 | 87 | zero_ops = ( 88 | nn.ReLU, 89 | nn.ReLU6, 90 | nn.Dropout, 91 | nn.MaxPool2d, 92 | nn.AvgPool2d, 93 | nn.AdaptiveAvgPool2d, 94 | ) 95 | 96 | count_map = { 97 | nn.Linear: count_nn_linear, 98 | nn.Conv2d: count_nn_conv2d, 99 | nn.BatchNorm2d: count_nn_bn2d, 100 | "function linear": count_fn_linear, 101 | "clamp": count_clamp, 102 | "built-in function add": count_zero_ops, 103 | "built-in method fl": count_zero_ops, 104 | "built-in method conv2d of type object": count_fn_conv2d, 105 | "built-in function mul": count_mul, 106 | "built-in function truediv": count_mul, 107 | } 108 | 109 | for k in zero_ops: 110 | count_map[k] = count_zero_ops 111 | 112 | missing_maps = {} 113 | 114 | from torch.fx import symbolic_trace 115 | from torch.fx.passes.shape_prop import ShapeProp 116 | from .helper import prGreen, prRed, prYellow 117 | 118 | 119 | def null_print(*args, **kwargs): 120 | return 121 | 122 | 123 | def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): 124 | gm: torch.fx.GraphModule = symbolic_trace(mod) 125 | g = gm.graph 126 | ShapeProp(gm).propagate(input) 127 | 128 | fprint = null_print 129 | if verbose: 130 | fprint = print 131 | 132 | v_maps = {} 133 | total_flops = 0 134 | 135 | for node in gm.graph.nodes: 136 | # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}") 137 | fprint( 138 | f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}" 139 | ) 140 | # node_op_type = str(node.target).split(".")[-1] 141 | node_flops = None 142 | 143 | input_shapes = [] 144 | output_shapes = [] 145 | fprint("input_shape:", end="\t") 146 | for arg in node.args: 147 | if 
str(arg) not in v_maps: 148 | continue 149 | fprint(f"{v_maps[str(arg)]}", end="\t") 150 | input_shapes.append(v_maps[str(arg)]) 151 | fprint() 152 | fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}") 153 | output_shapes.append(node.meta["tensor_meta"].shape) 154 | 155 | if node.op in ["output", "placeholder"]: 156 | node_flops = 0 157 | elif node.op == "call_function": 158 | # torch internal functions 159 | key = ( 160 | str(node.target) 161 | .split("at")[0] 162 | .replace("<", "") 163 | .replace(">", "") 164 | .strip() 165 | ) 166 | if key in count_map: 167 | node_flops = count_map[key]( 168 | input_shapes, output_shapes, *node.args, **node.kwargs 169 | ) 170 | else: 171 | missing_maps[key] = (node.op, key) 172 | prRed(f"|{key}| is missing") 173 | elif node.op == "call_method": 174 | # torch internal functions 175 | # fprint(str(node.target) in count_map, str(node.target), count_map.keys()) 176 | key = str(node.target) 177 | if key in count_map: 178 | node_flops = count_map[key](input_shapes, output_shapes) 179 | else: 180 | missing_maps[key] = (node.op, key) 181 | prRed(f"{key} is missing") 182 | elif node.op == "call_module": 183 | # torch.nn modules 184 | # m = getattr(mod, node.target, None) 185 | m = mod.get_submodule(node.target) 186 | key = type(m) 187 | fprint(type(m), type(m) in count_map) 188 | if type(m) in count_map: 189 | node_flops = count_map[type(m)](m, input_shapes, output_shapes) 190 | else: 191 | missing_maps[key] = (node.op,) 192 | prRed(f"{key} is missing") 193 | print("module type:", type(m)) 194 | if isinstance(m, zero_ops): 195 | print(f"weight_shape: None") 196 | else: 197 | print(type(m)) 198 | print( 199 | f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}" 200 | ) 201 | 202 | v_maps[str(node.name)] = node.meta["tensor_meta"].shape 203 | if node_flops is not None: 204 | total_flops += node_flops 205 | prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}") 206 | fprint("==" * 40) 207 | 208 | 
if len(missing_maps.keys()) > 0: 209 | from pprint import pprint 210 | print("Missing operators: ") 211 | pprint(missing_maps) 212 | return total_flops 213 | 214 | 215 | if __name__ == "__main__": 216 | 217 | class MyOP(nn.Module): 218 | def forward(self, input): 219 | return input / 1 220 | 221 | class MyModule(torch.nn.Module): 222 | def __init__(self): 223 | super().__init__() 224 | self.linear1 = torch.nn.Linear(5, 3) 225 | self.linear2 = torch.nn.Linear(5, 3) 226 | self.myop = MyOP() 227 | 228 | def forward(self, x): 229 | out1 = self.linear1(x) 230 | out2 = self.linear2(x).clamp(min=0.0, max=1.0) 231 | return self.myop(out1 + out2) 232 | 233 | net = MyModule() 234 | data = th.randn(20, 5) 235 | flops = fx_profile(net, data, verbose=False) 236 | print(flops) 237 | -------------------------------------------------------------------------------- /infotool/helper.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | COLOR_RED = "91m" 4 | COLOR_GREEN = "92m" 5 | COLOR_YELLOW = "93m" 6 | 7 | def colorful_print(fn_print, color=COLOR_RED): 8 | def actual_call(*args, **kwargs): 9 | print(f"\033[{color}", end="") 10 | fn_print(*args, **kwargs) 11 | print("\033[00m", end="") 12 | return actual_call 13 | 14 | prRed = colorful_print(print, color=COLOR_RED) 15 | prGreen = colorful_print(print, color=COLOR_GREEN) 16 | prYellow = colorful_print(print, color=COLOR_YELLOW) 17 | 18 | # def prRed(skk): 19 | # print("\033[91m{}\033[00m".format(skk)) 20 | 21 | # def prGreen(skk): 22 | # print("\033[92m{}\033[00m".format(skk)) 23 | 24 | # def prYellow(skk): 25 | # print("\033[93m{}\033[00m".format(skk)) 26 | 27 | 28 | def clever_format(nums, format="%.2f"): 29 | if not isinstance(nums, Iterable): 30 | nums = [nums] 31 | clever_nums = [] 32 | 33 | for num in nums: 34 | if num > 1e12: 35 | clever_nums.append(format % (num / 1e12) + "T") 36 | elif num > 1e9: 37 | clever_nums.append(format % (num / 1e9) + 
"G") 38 | elif num > 1e6: 39 | clever_nums.append(format % (num / 1e6) + "M") 40 | elif num > 1e3: 41 | clever_nums.append(format % (num / 1e3) + "K") 42 | else: 43 | clever_nums.append(format % num + "B") 44 | 45 | clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,) 46 | 47 | return clever_nums 48 | 49 | 50 | if __name__ == "__main__": 51 | prRed("hello", "world") 52 | prGreen("hello", "world") 53 | prYellow("hello", "world") -------------------------------------------------------------------------------- /infotool/profile.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | 3 | from .vision.basic_hooks import * 4 | from .rnn_hooks import * 5 | 6 | 7 | # logger = logging.getLogger(__name__) 8 | # logger.setLevel(logging.INFO) 9 | 10 | from .helper import prGreen, prRed, prYellow 11 | 12 | if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): 13 | logging.warning( 14 | "You are using an old version PyTorch {version}, which THOP does NOT support.".format( 15 | version=torch.__version__ 16 | ) 17 | ) 18 | 19 | default_dtype = torch.float64 20 | 21 | register_hooks = { 22 | # nn.ZeroPad2d: zero_ops, # padding does not involve any multiplication. 
23 | nn.Conv1d: count_convNd, 24 | nn.Conv2d: count_convNd, 25 | nn.Conv3d: count_convNd, 26 | nn.ConvTranspose1d: count_convNd, 27 | nn.ConvTranspose2d: count_convNd, 28 | nn.ConvTranspose3d: count_convNd, 29 | # nn.BatchNorm1d: count_bn, 30 | # nn.BatchNorm2d: count_bn, 31 | # nn.BatchNorm3d: count_bn, 32 | # nn.LayerNorm: count_ln, 33 | # nn.InstanceNorm1d: count_in, 34 | # nn.InstanceNorm2d: count_in, 35 | # nn.InstanceNorm3d: count_in, 36 | # nn.PReLU: count_prelu, 37 | # nn.Softmax: count_softmax, 38 | # nn.ReLU: zero_ops, 39 | # nn.ReLU6: zero_ops, 40 | # nn.LeakyReLU: count_relu, 41 | # nn.MaxPool1d: zero_ops, 42 | # nn.MaxPool2d: zero_ops, 43 | # nn.MaxPool3d: zero_ops, 44 | # nn.AdaptiveMaxPool1d: zero_ops, 45 | # nn.AdaptiveMaxPool2d: zero_ops, 46 | # nn.AdaptiveMaxPool3d: zero_ops, 47 | # nn.AvgPool1d: count_avgpool, 48 | # nn.AvgPool2d: count_avgpool, 49 | # nn.AvgPool3d: count_avgpool, 50 | # nn.AdaptiveAvgPool1d: count_adap_avgpool, 51 | # nn.AdaptiveAvgPool2d: count_adap_avgpool, 52 | # nn.AdaptiveAvgPool3d: count_adap_avgpool, 53 | nn.Linear: count_linear, 54 | # nn.Dropout: zero_ops, 55 | # nn.Upsample: count_upsample, 56 | # nn.UpsamplingBilinear2d: count_upsample, 57 | # nn.UpsamplingNearest2d: count_upsample, 58 | # nn.RNNCell: count_rnn_cell, 59 | # nn.GRUCell: count_gru_cell, 60 | # nn.LSTMCell: count_lstm_cell, 61 | # nn.RNN: count_rnn, 62 | # nn.GRU: count_gru, 63 | # nn.LSTM: count_lstm, 64 | # nn.Sequential: zero_ops, 65 | } 66 | 67 | if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): 68 | register_hooks.update({nn.SyncBatchNorm: count_bn}) 69 | 70 | 71 | def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False): 72 | handler_collection = [] 73 | types_collection = set() 74 | if custom_ops is None: 75 | custom_ops = {} 76 | if report_missing: 77 | verbose = True 78 | 79 | def add_hooks(m): 80 | if len(list(m.children())) > 0: 81 | return 82 | 83 | if hasattr(m, "total_ops") or hasattr(m, 
"total_params"): 84 | logging.warning( 85 | "Either .total_ops or .total_params is already defined in %s. " 86 | "Be careful, it might change your code's behavior." % str(m) 87 | ) 88 | 89 | m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype)) 90 | m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype)) 91 | 92 | for p in m.parameters(): 93 | m.total_params += torch.DoubleTensor([p.numel()]) 94 | 95 | m_type = type(m) 96 | 97 | fn = None 98 | if ( 99 | m_type in custom_ops 100 | ): # if defined both op maps, use custom_ops to overwrite. 101 | fn = custom_ops[m_type] 102 | if m_type not in types_collection and verbose: 103 | print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type)) 104 | elif m_type in register_hooks: 105 | fn = register_hooks[m_type] 106 | if m_type not in types_collection and verbose: 107 | print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) 108 | else: 109 | if m_type not in types_collection and report_missing: 110 | prRed( 111 | "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 
112 | % m_type 113 | ) 114 | 115 | if fn is not None: 116 | handler = m.register_forward_hook(fn) 117 | handler_collection.append(handler) 118 | types_collection.add(m_type) 119 | 120 | training = model.training 121 | 122 | model.eval() 123 | model.apply(add_hooks) 124 | 125 | with torch.no_grad(): 126 | model(*inputs) 127 | 128 | total_ops = 0 129 | total_params = 0 130 | for m in model.modules(): 131 | if len(list(m.children())) > 0: # skip for non-leaf module 132 | continue 133 | total_ops += m.total_ops 134 | total_params += m.total_params 135 | 136 | total_ops = total_ops.item() 137 | total_params = total_params.item() 138 | 139 | # reset model to original status 140 | model.train(training) 141 | for handler in handler_collection: 142 | handler.remove() 143 | 144 | # remove temporal buffers 145 | for n, m in model.named_modules(): 146 | if len(list(m.children())) > 0: 147 | continue 148 | if "total_ops" in m._buffers: 149 | m._buffers.pop("total_ops") 150 | if "total_params" in m._buffers: 151 | m._buffers.pop("total_params") 152 | 153 | return total_ops, total_params 154 | 155 | 156 | def profile( 157 | model: nn.Module, 158 | inputs, 159 | custom_ops=None, 160 | verbose=True, 161 | ret_layer_info=False, 162 | report_missing=False, 163 | ): 164 | handler_collection = {} 165 | types_collection = set() 166 | if custom_ops is None: 167 | custom_ops = {} 168 | if report_missing: 169 | # overwrite `verbose` option when enable report_missing 170 | verbose = True 171 | 172 | def add_hooks(m: nn.Module): 173 | m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64)) 174 | m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64)) 175 | 176 | # for p in m.parameters(): 177 | # m.total_params += torch.DoubleTensor([p.numel()]) 178 | 179 | m_type = type(m) 180 | 181 | fn = None 182 | if m_type in custom_ops: 183 | # if defined both op maps, use custom_ops to overwrite. 
184 | fn = custom_ops[m_type] 185 | if m_type not in types_collection and verbose: 186 | print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type)) 187 | elif m_type in register_hooks: 188 | fn = register_hooks[m_type] 189 | if m_type not in types_collection and verbose: 190 | print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) 191 | else: 192 | if m_type not in types_collection and report_missing: 193 | prRed( 194 | "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 195 | % m_type 196 | ) 197 | 198 | if fn is not None: 199 | handler_collection[m] = ( 200 | m.register_forward_hook(fn), 201 | m.register_forward_hook(count_parameters), 202 | ) 203 | types_collection.add(m_type) 204 | 205 | prev_training_status = model.training 206 | 207 | model.eval() 208 | model.apply(add_hooks) 209 | 210 | with torch.no_grad(): 211 | model(*inputs) 212 | 213 | def dfs_count(module: nn.Module, prefix="\t") -> (int, int): 214 | total_ops, total_params = module.total_ops.item(), 0 215 | ret_dict = {} 216 | for n, m in module.named_children(): 217 | # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"): # and len(list(m.children())) > 0: 218 | # m_ops, m_params = dfs_count(m, prefix=prefix + "\t") 219 | # else: 220 | # m_ops, m_params = m.total_ops, m.total_params 221 | next_dict = {} 222 | if m in handler_collection and not isinstance( 223 | m, (nn.Sequential, nn.ModuleList) 224 | ): 225 | m_ops, m_params = m.total_ops.item(), m.total_params.item() 226 | else: 227 | m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t") 228 | ret_dict[n] = (m_ops, m_params, next_dict) 229 | total_ops += m_ops 230 | total_params += m_params 231 | # print(prefix, module._get_name(), (total_ops, total_params)) 232 | return total_ops, total_params, ret_dict 233 | 234 | total_ops, total_params, ret_dict = dfs_count(model) 235 | 236 | # reset model to original status 237 | model.train(prev_training_status) 238 | for m, (op_handler, 
params_handler) in handler_collection.items(): 239 | op_handler.remove() 240 | params_handler.remove() 241 | m._buffers.pop("total_ops") 242 | m._buffers.pop("total_params") 243 | 244 | if ret_layer_info: 245 | return total_ops, total_params, ret_dict 246 | return total_ops, total_params 247 | -------------------------------------------------------------------------------- /infotool/rnn_hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import PackedSequence 4 | 5 | 6 | def _count_rnn_cell(input_size, hidden_size, bias=True): 7 | # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) 8 | total_ops = hidden_size * (input_size + hidden_size) + hidden_size 9 | if bias: 10 | total_ops += hidden_size * 2 11 | 12 | return total_ops 13 | 14 | 15 | def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor): 16 | total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias) 17 | 18 | batch_size = x[0].size(0) 19 | total_ops *= batch_size 20 | 21 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 22 | 23 | 24 | def _count_gru_cell(input_size, hidden_size, bias=True): 25 | total_ops = 0 26 | # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ 27 | # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ 28 | state_ops = (hidden_size + input_size) * hidden_size + hidden_size 29 | if bias: 30 | state_ops += hidden_size * 2 31 | total_ops += state_ops * 2 32 | 33 | # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ 34 | total_ops += (hidden_size + input_size) * hidden_size + hidden_size 35 | if bias: 36 | total_ops += hidden_size * 2 37 | # r hadamard : r * (~) 38 | total_ops += hidden_size 39 | 40 | # h' = (1 - z) * n + z * h 41 | # hadamard hadamard add 42 | total_ops += hidden_size * 3 43 | 44 | return total_ops 45 | 46 | 47 | def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor): 48 | total_ops = _count_gru_cell(m.input_size, 
m.hidden_size, m.bias) 49 | 50 | batch_size = x[0].size(0) 51 | total_ops *= batch_size 52 | 53 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 54 | 55 | 56 | def _count_lstm_cell(input_size, hidden_size, bias=True): 57 | total_ops = 0 58 | 59 | # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ 60 | # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ 61 | # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ 62 | # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ 63 | state_ops = (input_size + hidden_size) * hidden_size + hidden_size 64 | if bias: 65 | state_ops += hidden_size * 2 66 | total_ops += state_ops * 4 67 | 68 | # c' = f * c + i * g \\ 69 | # hadamard hadamard add 70 | total_ops += hidden_size * 3 71 | 72 | # h' = o * \tanh(c') \\ 73 | total_ops += hidden_size 74 | 75 | return total_ops 76 | 77 | 78 | def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor): 79 | total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias) 80 | 81 | batch_size = x[0].size(0) 82 | total_ops *= batch_size 83 | 84 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 85 | 86 | 87 | def count_rnn(m: nn.RNN, x, y): 88 | bias = m.bias 89 | input_size = m.input_size 90 | hidden_size = m.hidden_size 91 | num_layers = m.num_layers 92 | 93 | if isinstance(x[0], PackedSequence): 94 | batch_size = torch.max(x[0].batch_sizes) 95 | num_steps = x[0].batch_sizes.size(0) 96 | else: 97 | if m.batch_first: 98 | batch_size = x[0].size(0) 99 | num_steps = x[0].size(1) 100 | else: 101 | batch_size = x[0].size(1) 102 | num_steps = x[0].size(0) 103 | 104 | total_ops = 0 105 | if m.bidirectional: 106 | total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2 107 | else: 108 | total_ops += _count_rnn_cell(input_size, hidden_size, bias) 109 | 110 | for i in range(num_layers - 1): 111 | if m.bidirectional: 112 | total_ops += _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2 113 | else: 114 | total_ops += _count_rnn_cell(hidden_size, hidden_size, bias) 
115 | 116 | # time unroll 117 | total_ops *= num_steps 118 | # batch_size 119 | total_ops *= batch_size 120 | 121 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 122 | 123 | 124 | def count_gru(m: nn.GRU, x, y): 125 | bias = m.bias 126 | input_size = m.input_size 127 | hidden_size = m.hidden_size 128 | num_layers = m.num_layers 129 | 130 | if isinstance(x[0], PackedSequence): 131 | batch_size = torch.max(x[0].batch_sizes) 132 | num_steps = x[0].batch_sizes.size(0) 133 | else: 134 | if m.batch_first: 135 | batch_size = x[0].size(0) 136 | num_steps = x[0].size(1) 137 | else: 138 | batch_size = x[0].size(1) 139 | num_steps = x[0].size(0) 140 | 141 | total_ops = 0 142 | if m.bidirectional: 143 | total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2 144 | else: 145 | total_ops += _count_gru_cell(input_size, hidden_size, bias) 146 | 147 | for i in range(num_layers - 1): 148 | if m.bidirectional: 149 | total_ops += _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2 150 | else: 151 | total_ops += _count_gru_cell(hidden_size, hidden_size, bias) 152 | 153 | # time unroll 154 | total_ops *= num_steps 155 | # batch_size 156 | total_ops *= batch_size 157 | 158 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 159 | 160 | 161 | def count_lstm(m: nn.LSTM, x, y): 162 | bias = m.bias 163 | input_size = m.input_size 164 | hidden_size = m.hidden_size 165 | num_layers = m.num_layers 166 | 167 | if isinstance(x[0], PackedSequence): 168 | batch_size = torch.max(x[0].batch_sizes) 169 | num_steps = x[0].batch_sizes.size(0) 170 | else: 171 | if m.batch_first: 172 | batch_size = x[0].size(0) 173 | num_steps = x[0].size(1) 174 | else: 175 | batch_size = x[0].size(1) 176 | num_steps = x[0].size(0) 177 | 178 | total_ops = 0 179 | if m.bidirectional: 180 | total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2 181 | else: 182 | total_ops += _count_lstm_cell(input_size, hidden_size, bias) 183 | 184 | for i in range(num_layers - 1): 185 | if m.bidirectional: 186 
| total_ops += _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2 187 | else: 188 | total_ops += _count_lstm_cell(hidden_size, hidden_size, bias) 189 | 190 | # time unroll 191 | total_ops *= num_steps 192 | # batch_size 193 | total_ops *= batch_size 194 | 195 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 196 | -------------------------------------------------------------------------------- /infotool/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/infotool/vision/__init__.py -------------------------------------------------------------------------------- /infotool/vision/basic_hooks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from .counter import ( 4 | counter_parameters, 5 | counter_conv, 6 | counter_norm, 7 | counter_relu, 8 | counter_softmax, 9 | counter_avgpool, 10 | counter_adap_avg, 11 | counter_zero_ops, 12 | counter_upsample, 13 | counter_linear, 14 | ) 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn.modules.conv import _ConvNd 18 | 19 | multiply_adds = 1 20 | 21 | 22 | def count_parameters(m, x, y): 23 | total_params = 0 24 | for p in m.parameters(): 25 | total_params += torch.DoubleTensor([p.numel()]) 26 | m.total_params[0] = counter_parameters(m.parameters()) 27 | 28 | 29 | def zero_ops(m, x, y): 30 | m.total_ops += counter_zero_ops() 31 | 32 | 33 | def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): 34 | x = x[0] 35 | 36 | kernel_ops = torch.zeros(m.weight.size()[2:]).numel() # Kw x Kh 37 | bias_ops = 1 if m.bias is not None else 0 38 | 39 | # N x Cout x H x W x (Cin x Kw x Kh + bias) 40 | m.total_ops += counter_conv( 41 | bias_ops, 42 | torch.zeros(m.weight.size()[2:]).numel(), 43 | y.nelement(), 44 | m.in_channels, 45 | m.groups, 46 | ) 47 | 48 | 49 | def count_convNd_ver2(m: 
_ConvNd, x: (torch.Tensor,), y: torch.Tensor): 50 | x = x[0] 51 | 52 | # N x H x W (exclude Cout) 53 | output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel() 54 | # # Cout x Cin x Kw x Kh 55 | # kernel_ops = m.weight.nelement() 56 | # if m.bias is not None: 57 | # # Cout x 1 58 | # kernel_ops += + m.bias.nelement() 59 | # # x N x H x W x Cout x (Cin x Kw x Kh + bias) 60 | # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)]) 61 | m.total_ops += counter_conv(m.bias.nelement(), m.weight.nelement(), output_size) 62 | 63 | 64 | def count_bn(m, x, y): 65 | x = x[0] 66 | if not m.training: 67 | m.total_ops += counter_norm(x.numel()) 68 | 69 | 70 | def count_ln(m, x, y): 71 | x = x[0] 72 | if not m.training: 73 | m.total_ops += counter_norm(x.numel()) 74 | 75 | 76 | def count_in(m, x, y): 77 | x = x[0] 78 | if not m.training: 79 | m.total_ops += counter_norm(x.numel()) 80 | 81 | 82 | def count_prelu(m, x, y): 83 | x = x[0] 84 | 85 | nelements = x.numel() 86 | if not m.training: 87 | m.total_ops += counter_relu(nelements) 88 | 89 | 90 | def count_relu(m, x, y): 91 | x = x[0] 92 | 93 | nelements = x.numel() 94 | 95 | m.total_ops += counter_relu(nelements) 96 | 97 | 98 | def count_softmax(m, x, y): 99 | x = x[0] 100 | nfeatures = x.size()[m.dim] 101 | batch_size = x.numel() // nfeatures 102 | 103 | m.total_ops += counter_softmax(batch_size, nfeatures) 104 | 105 | 106 | def count_avgpool(m, x, y): 107 | # total_add = torch.prod(torch.Tensor([m.kernel_size])) 108 | # total_div = 1 109 | # kernel_ops = total_add + total_div 110 | num_elements = y.numel() 111 | m.total_ops += counter_avgpool(num_elements) 112 | 113 | 114 | def count_adap_avgpool(m, x, y): 115 | kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor( 116 | [*(y.shape[2:])] 117 | ) 118 | total_add = torch.prod(kernel) 119 | num_elements = y.numel() 120 | m.total_ops += counter_adap_avg(total_add, num_elements) 121 | 122 | 123 | # TODO: verify the accuracy 124 | def 
count_upsample(m, x, y): 125 | if m.mode not in ( 126 | "nearest", 127 | "linear", 128 | "bilinear", 129 | "bicubic", 130 | ): # "trilinear" 131 | logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode) 132 | return counter_zero_ops() 133 | 134 | if m.mode == "nearest": 135 | return counter_zero_ops() 136 | 137 | x = x[0] 138 | m.total_ops += counter_upsample(m.mode, y.nelement()) 139 | 140 | 141 | # nn.Linear 142 | def count_linear(m, x, y): 143 | # per output element 144 | total_mul = m.in_features 145 | # total_add = m.in_features - 1 146 | # total_add += 1 if m.bias is not None else 0 147 | num_elements = y.numel() 148 | 149 | m.total_ops += counter_linear(total_mul, num_elements) 150 | -------------------------------------------------------------------------------- /infotool/vision/counter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def counter_parameters(para_list): 6 | total_params = 0 7 | for p in para_list: 8 | total_params += torch.DoubleTensor([p.nelement()]) 9 | return total_params 10 | 11 | 12 | def counter_zero_ops(): 13 | return torch.DoubleTensor([int(0)]) 14 | 15 | 16 | def counter_conv(bias, kernel_size, output_size, in_channel, group): 17 | """inputs are all numbers!""" 18 | return torch.DoubleTensor([output_size * (in_channel / group * kernel_size + bias)]) 19 | 20 | 21 | def counter_norm(input_size): 22 | """input is a number not a array or tensor""" 23 | return torch.DoubleTensor([2 * input_size]) 24 | 25 | 26 | def counter_relu(input_size: torch.Tensor): 27 | return torch.DoubleTensor([int(input_size)]) 28 | 29 | 30 | def counter_softmax(batch_size, nfeatures): 31 | total_exp = nfeatures 32 | total_add = nfeatures - 1 33 | total_div = nfeatures 34 | total_ops = batch_size * (total_exp + total_add + total_div) 35 | return torch.DoubleTensor([int(total_ops)]) 36 | 37 | 38 | def counter_avgpool(input_size): 39 | return 
torch.DoubleTensor([int(input_size)]) 40 | 41 | 42 | def counter_adap_avg(kernel_size, output_size): 43 | total_div = 1 44 | kernel_op = kernel_size + total_div 45 | return torch.DoubleTensor([int(kernel_op * output_size)]) 46 | 47 | 48 | def counter_upsample(mode: str, output_size): 49 | total_ops = output_size 50 | if mode == "linear": 51 | total_ops *= 5 52 | elif mode == "bilinear": 53 | total_ops *= 11 54 | elif mode == "bicubic": 55 | ops_solve_A = 224 # 128 muls + 96 adds 56 | ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds 57 | total_ops *= ops_solve_A + ops_solve_p 58 | elif mode == "trilinear": 59 | total_ops *= 13 * 2 + 5 60 | return torch.DoubleTensor([int(total_ops)]) 61 | 62 | 63 | def counter_linear(in_feature, num_elements): 64 | return torch.DoubleTensor([int(in_feature * num_elements)]) 65 | 66 | 67 | def counter_matmul(input_size, output_size): 68 | input_size = np.array(input_size) 69 | output_size = np.array(output_size) 70 | return np.prod(input_size) * output_size[-1] 71 | 72 | 73 | def counter_mul(input_size): 74 | return input_size 75 | 76 | 77 | def counter_pow(input_size): 78 | return input_size 79 | 80 | 81 | def counter_sqrt(input_size): 82 | return input_size 83 | 84 | 85 | def counter_div(input_size): 86 | return input_size 87 | -------------------------------------------------------------------------------- /infotool/vision/efficientnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.modules.conv import _ConvNd 7 | 8 | from efficientnet_pytorch.utils import Conv2dDynamicSamePadding, Conv2dStaticSamePadding 9 | 10 | register_hooks = {} 11 | -------------------------------------------------------------------------------- /infotool/vision/onnx_counter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from onnx 
import numpy_helper 4 | from thop.vision.basic_hooks import zero_ops 5 | from .counter import ( 6 | counter_matmul, 7 | counter_zero_ops, 8 | counter_conv, 9 | counter_mul, 10 | counter_norm, 11 | counter_pow, 12 | counter_sqrt, 13 | counter_div, 14 | counter_softmax, 15 | counter_avgpool, 16 | ) 17 | 18 | 19 | def onnx_counter_matmul(diction, node): 20 | input1 = node.input[0] 21 | input2 = node.input[1] 22 | input1_dim = diction[input1] 23 | input2_dim = diction[input2] 24 | out_size = np.append(input1_dim[0:-1], input2_dim[-1]) 25 | output_name = node.output[0] 26 | macs = counter_matmul(input1_dim, out_size[-2:]) 27 | return macs, out_size, output_name 28 | 29 | 30 | def onnx_counter_add(diction, node): 31 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 32 | out_size = diction[node.input[1]] 33 | else: 34 | out_size = diction[node.input[0]] 35 | output_name = node.output[0] 36 | macs = counter_zero_ops() 37 | # if '140' in diction: 38 | # print(diction['140'],output_name) 39 | return macs, out_size, output_name 40 | 41 | 42 | def onnx_counter_conv(diction, node): 43 | # print(node) 44 | # bias,kernelsize,outputsize 45 | dim_bias = 0 46 | input_count = 0 47 | for i in node.input: 48 | input_count += 1 49 | if input_count == 3: 50 | dim_bias = 1 51 | dim_weight = diction[node.input[1]] 52 | else: 53 | dim_weight = diction[node.input[1]] 54 | for attr in node.attribute: 55 | # print(attr) 56 | if attr.name == "kernel_shape": 57 | dim_kernel = attr.ints # kw,kh 58 | if attr.name == "strides": 59 | dim_stride = attr.ints 60 | if attr.name == "pads": 61 | dim_pad = attr.ints 62 | if attr.name == "dilations": 63 | dim_dil = attr.ints 64 | if attr.name == "group": 65 | group = attr.i 66 | # print(dim_dil) 67 | dim_input = diction[node.input[0]] 68 | output_size = np.append( 69 | dim_input[0 : -np.array(dim_kernel).size - 1], dim_weight[0] 70 | ) 71 | hw = np.array(dim_input[-np.array(dim_kernel).size :]) 72 | for i in 
range(hw.size): 73 | hw[i] = int( 74 | (hw[i] + 2 * dim_pad[i] - dim_dil[i] * (dim_kernel[i] - 1) - 1) 75 | / dim_stride[i] 76 | + 1 77 | ) 78 | output_size = np.append(output_size, hw) 79 | macs = counter_conv( 80 | dim_bias, np.prod(dim_kernel), np.prod(output_size), dim_weight[1], group 81 | ) 82 | output_name = node.output[0] 83 | 84 | # if '140' in diction: 85 | # print("conv",diction['140'],output_name) 86 | return macs, output_size, output_name 87 | 88 | 89 | def onnx_counter_constant(diction, node): 90 | # print("constant",node) 91 | macs = counter_zero_ops() 92 | output_name = node.output[0] 93 | output_size = [1] 94 | # print(macs, output_size, output_name) 95 | return macs, output_size, output_name 96 | 97 | 98 | def onnx_counter_mul(diction, node): 99 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 100 | input_size = diction[node.input[1]] 101 | else: 102 | input_size = diction[node.input[0]] 103 | macs = counter_mul(np.prod(input_size)) 104 | output_size = diction[node.input[0]] 105 | output_name = node.output[0] 106 | return macs, output_size, output_name 107 | 108 | 109 | def onnx_counter_bn(diction, node): 110 | input_size = diction[node.input[0]] 111 | macs = counter_norm(np.prod(input_size)) 112 | output_name = node.output[0] 113 | output_size = input_size 114 | return macs, output_size, output_name 115 | 116 | 117 | def onnx_counter_relu(diction, node): 118 | input_size = diction[node.input[0]] 119 | macs = counter_zero_ops() 120 | output_name = node.output[0] 121 | output_size = input_size 122 | # print(macs, output_size, output_name) 123 | # if '140' in diction: 124 | # print("relu",diction['140'],output_name) 125 | return macs, output_size, output_name 126 | 127 | 128 | def onnx_counter_reducemean(diction, node): 129 | keep_dim = 0 130 | for attr in node.attribute: 131 | if "axes" in attr.name: 132 | dim_axis = np.array(attr.ints) 133 | elif "keepdims" in attr.name: 134 | keep_dim = attr.i 135 | 136 | 
input_size = diction[node.input[0]] 137 | macs = counter_zero_ops() 138 | output_name = node.output[0] 139 | if keep_dim == 1: 140 | output_size = input_size 141 | else: 142 | output_size = np.delete(input_size, dim_axis) 143 | # output_size = input_size 144 | return macs, output_size, output_name 145 | 146 | 147 | def onnx_counter_sub(diction, node): 148 | input_size = diction[node.input[0]] 149 | macs = counter_zero_ops() 150 | output_name = node.output[0] 151 | output_size = input_size 152 | return macs, output_size, output_name 153 | 154 | 155 | def onnx_counter_pow(diction, node): 156 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 157 | input_size = diction[node.input[1]] 158 | else: 159 | input_size = diction[node.input[0]] 160 | macs = counter_pow(np.prod(input_size)) 161 | output_name = node.output[0] 162 | output_size = input_size 163 | return macs, output_size, output_name 164 | 165 | 166 | def onnx_counter_sqrt(diction, node): 167 | input_size = diction[node.input[0]] 168 | macs = counter_sqrt(np.prod(input_size)) 169 | output_name = node.output[0] 170 | output_size = input_size 171 | return macs, output_size, output_name 172 | 173 | 174 | def onnx_counter_div(diction, node): 175 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 176 | input_size = diction[node.input[1]] 177 | else: 178 | input_size = diction[node.input[0]] 179 | macs = counter_div(np.prod(input_size)) 180 | output_name = node.output[0] 181 | output_size = input_size 182 | return macs, output_size, output_name 183 | 184 | 185 | def onnx_counter_instance(diction, node): 186 | input_size = diction[node.input[0]] 187 | macs = counter_norm(np.prod(input_size)) 188 | output_name = node.output[0] 189 | output_size = input_size 190 | return macs, output_size, output_name 191 | 192 | 193 | def onnx_counter_softmax(diction, node): 194 | input_size = diction[node.input[0]] 195 | dim = node.attribute[0].i 196 | nfeatures = 
input_size[dim] 197 | batch_size = np.prod(input_size) / nfeatures 198 | macs = counter_softmax(nfeatures, batch_size) 199 | output_name = node.output[0] 200 | output_size = input_size 201 | return macs, output_size, output_name 202 | 203 | 204 | def onnx_counter_pad(diction, node): 205 | # # TODO add constant name and output real vector 206 | # if 207 | # if (np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size): 208 | # input_size = diction[node.input[1]] 209 | # else: 210 | # input_size = diction[node.input[0]] 211 | input_size = diction[node.input[0]] 212 | macs = counter_zero_ops() 213 | output_name = node.output[0] 214 | output_size = input_size 215 | return macs, output_size, output_name 216 | 217 | 218 | def onnx_counter_averagepool(diction, node): 219 | # TODO add support of ceil_mode and floor 220 | macs = counter_avgpool(np.prod(diction[node.input[0]])) 221 | output_name = node.output[0] 222 | dim_pad = None 223 | for attr in node.attribute: 224 | # print(attr) 225 | if attr.name == "kernel_shape": 226 | dim_kernel = attr.ints # kw,kh 227 | elif attr.name == "strides": 228 | dim_stride = attr.ints 229 | elif attr.name == "pads": 230 | dim_pad = attr.ints 231 | elif attr.name == "dilations": 232 | dim_dil = attr.ints 233 | # print(dim_dil) 234 | dim_input = diction[node.input[0]] 235 | hw = dim_input[-np.array(dim_kernel).size :] 236 | if dim_pad is not None: 237 | for i in range(hw.size): 238 | hw[i] = int((hw[i] + 2 * dim_pad[i] - dim_kernel[i]) / dim_stride[i] + 1) 239 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 240 | else: 241 | for i in range(hw.size): 242 | hw[i] = int((hw[i] - dim_kernel[i]) / dim_stride[i] + 1) 243 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 244 | # print(macs, output_size, output_name) 245 | return macs, output_size, output_name 246 | 247 | 248 | def onnx_counter_flatten(diction, node): 249 | # print(node) 250 | macs = counter_zero_ops() 251 | 
output_name = node.output[0] 252 | axis = node.attribute[0].i 253 | input_size = diction[node.input[0]] 254 | output_size = np.append(input_size[axis - 1], np.prod(input_size[axis:])) 255 | # print("flatten",output_size) 256 | return macs, output_size, output_name 257 | 258 | 259 | def onnx_counter_gemm(diction, node): 260 | # print(node) 261 | # Compute Y = alpha * A' * B' + beta * C 262 | input_size = diction[node.input[0]] 263 | dim_weight = diction[node.input[1]] 264 | # print(input_size,dim_weight) 265 | macs = np.prod(input_size) * dim_weight[1] + dim_weight[0] 266 | output_size = np.append(input_size[0:-1], dim_weight[0]) 267 | output_name = node.output[0] 268 | return macs, output_size, output_name 269 | pass 270 | 271 | 272 | def onnx_counter_maxpool(diction, node): 273 | # TODO add support of ceil_mode and floor 274 | # print(node) 275 | macs = counter_zero_ops() 276 | output_name = node.output[0] 277 | dim_pad = None 278 | for attr in node.attribute: 279 | # print(attr) 280 | if attr.name == "kernel_shape": 281 | dim_kernel = attr.ints # kw,kh 282 | elif attr.name == "strides": 283 | dim_stride = attr.ints 284 | elif attr.name == "pads": 285 | dim_pad = attr.ints 286 | elif attr.name == "dilations": 287 | dim_dil = attr.ints 288 | # print(dim_dil) 289 | dim_input = diction[node.input[0]] 290 | hw = dim_input[-np.array(dim_kernel).size :] 291 | if dim_pad is not None: 292 | for i in range(hw.size): 293 | hw[i] = int((hw[i] + 2 * dim_pad[i] - dim_kernel[i]) / dim_stride[i] + 1) 294 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 295 | else: 296 | for i in range(hw.size): 297 | hw[i] = int((hw[i] - dim_kernel[i]) / dim_stride[i] + 1) 298 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 299 | # print(macs, output_size, output_name) 300 | return macs, output_size, output_name 301 | 302 | 303 | def onnx_counter_globalaveragepool(diction, node): 304 | macs = counter_zero_ops() 305 | output_name = node.output[0] 
306 | input_size = diction[node.input[0]] 307 | output_size = input_size 308 | return macs, output_size, output_name 309 | 310 | 311 | def onnx_counter_concat(diction, node): 312 | # print(node) 313 | # print(diction[node.input[0]]) 314 | axis = node.attribute[0].i 315 | input_size = diction[node.input[0]] 316 | for i in node.input: 317 | dim_concat = diction[i][axis] 318 | output_size = input_size 319 | output_size[axis] = dim_concat 320 | output_name = node.output[0] 321 | macs = counter_zero_ops() 322 | return macs, output_size, output_name 323 | 324 | 325 | def onnx_counter_clip(diction, node): 326 | macs = counter_zero_ops() 327 | output_name = node.output[0] 328 | input_size = diction[node.input[0]] 329 | output_size = input_size 330 | return macs, output_size, output_name 331 | 332 | 333 | onnx_operators = { 334 | "MatMul": onnx_counter_matmul, 335 | "Add": onnx_counter_add, 336 | "Conv": onnx_counter_conv, 337 | "Mul": onnx_counter_mul, 338 | "Constant": onnx_counter_constant, 339 | "BatchNormalization": onnx_counter_bn, 340 | "Relu": onnx_counter_relu, 341 | "ReduceMean": onnx_counter_reducemean, 342 | "Sub": onnx_counter_sub, 343 | "Pow": onnx_counter_pow, 344 | "Sqrt": onnx_counter_sqrt, 345 | "Div": onnx_counter_div, 346 | "InstanceNormalization": onnx_counter_instance, 347 | "Softmax": onnx_counter_softmax, 348 | "Pad": onnx_counter_pad, 349 | "AveragePool": onnx_counter_averagepool, 350 | "MaxPool": onnx_counter_maxpool, 351 | "Flatten": onnx_counter_flatten, 352 | "Gemm": onnx_counter_gemm, 353 | "GlobalAveragePool": onnx_counter_globalaveragepool, 354 | "Concat": onnx_counter_concat, 355 | "Clip": onnx_counter_clip, 356 | None: None, 357 | } 358 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .repghost import * 2 | -------------------------------------------------------------------------------- 
/model/repghost.py: -------------------------------------------------------------------------------- 1 | # @Author : chengpeng.chen 2 | # @Email : chencp@live.com 3 | """ 4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong. 5 | https://arxiv.org/abs/2211.06088 6 | """ 7 | import copy 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | __all__ = [ 15 | 'repghostnet_0_5x', 16 | 'repghostnet_repid_0_5x', 17 | 'repghostnet_norep_0_5x', 18 | 'repghostnet_wo_0_5x', 19 | 'repghostnet_0_58x', 20 | 'repghostnet_0_8x', 21 | 'repghostnet_1_0x', 22 | 'repghostnet_1_11x', 23 | 'repghostnet_1_3x', 24 | 'repghostnet_1_5x', 25 | 'repghostnet_2_0x', 26 | 'repghostnet', 27 | ] 28 | 29 | 30 | def _make_divisible(v, divisor, min_value=None): 31 | """ 32 | This function is taken from the original tf repo. 33 | It ensures that all layers have a channel number that is divisible by `divisor` (every call site in this file passes 4) 34 | It can be seen here: 35 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 36 | """ 37 | if min_value is None: 38 | min_value = divisor 39 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # round to the nearest multiple of divisor, never below min_value 40 | # Make sure that round down does not go down by more than 10%.
41 | if new_v < 0.9 * v: 42 | new_v += divisor 43 | return new_v 44 | 45 | 46 | def hard_sigmoid(x, inplace: bool = False): # relu6(x + 3) / 6 47 | if inplace: 48 | return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0) # mutates x in place 49 | else: 50 | return F.relu6(x + 3.0) / 6.0 51 | 52 | 53 | class SqueezeExcite(nn.Module): 54 | def __init__( 55 | self, 56 | in_chs, 57 | se_ratio=0.25, 58 | reduced_base_chs=None, 59 | act_layer=nn.ReLU, 60 | gate_fn=hard_sigmoid, 61 | divisor=4, 62 | **_, 63 | ): 64 | super(SqueezeExcite, self).__init__() 65 | self.gate_fn = gate_fn 66 | reduced_chs = _make_divisible( 67 | (reduced_base_chs or in_chs) * se_ratio, divisor, 68 | ) 69 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 70 | self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) 71 | self.act1 = act_layer(inplace=True) 72 | self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) 73 | 74 | def forward(self, x): 75 | x_se = self.avg_pool(x) 76 | x_se = self.conv_reduce(x_se) 77 | x_se = self.act1(x_se) 78 | x_se = self.conv_expand(x_se) 79 | x = x * self.gate_fn(x_se) 80 | return x 81 | 82 | 83 | class ConvBnAct(nn.Module): 84 | def __init__(self, in_chs, out_chs, kernel_size, stride=1, act_layer=nn.ReLU): 85 | super(ConvBnAct, self).__init__() 86 | self.conv = nn.Conv2d( 87 | in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False, 88 | ) 89 | self.bn1 = nn.BatchNorm2d(out_chs) 90 | self.act1 = act_layer(inplace=True) 91 | 92 | def forward(self, x): 93 | x = self.conv(x) 94 | x = self.bn1(x) 95 | x = self.act1(x) 96 | return x 97 | 98 | 99 | class RepGhostModule(nn.Module): 100 | def __init__( 101 | self, inp, oup, kernel_size=1, dw_size=3, stride=1, relu=True, deploy=False, reparam_bn=True, reparam_identity=False 102 | ): 103 | super(RepGhostModule, self).__init__() 104 | init_channels = oup 105 | new_channels = oup 106 | self.deploy = deploy 107 | 108 | self.primary_conv = nn.Sequential( 109 | nn.Conv2d( 110 | inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False, 111 | ),
112 | nn.BatchNorm2d(init_channels), 113 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 114 | ) 115 | fusion_conv = [] 116 | fusion_bn = [] 117 | if not deploy and reparam_bn: 118 | fusion_conv.append(nn.Identity()) 119 | fusion_bn.append(nn.BatchNorm2d(init_channels)) 120 | if not deploy and reparam_identity: 121 | fusion_conv.append(nn.Identity()) 122 | fusion_bn.append(nn.Identity()) 123 | 124 | self.fusion_conv = nn.Sequential(*fusion_conv) 125 | self.fusion_bn = nn.Sequential(*fusion_bn) 126 | 127 | self.cheap_operation = nn.Sequential( 128 | nn.Conv2d( 129 | init_channels, 130 | new_channels, 131 | dw_size, 132 | 1, 133 | dw_size // 2, 134 | groups=init_channels, 135 | bias=deploy, 136 | ), 137 | nn.BatchNorm2d(new_channels) if not deploy else nn.Sequential(), 138 | # nn.ReLU(inplace=True) if relu else nn.Sequential(), 139 | ) 140 | if deploy: 141 | self.cheap_operation = self.cheap_operation[0] 142 | if relu: 143 | self.relu = nn.ReLU(inplace=False) 144 | else: 145 | self.relu = nn.Sequential() 146 | 147 | def forward(self, x): 148 | x1 = self.primary_conv(x) 149 | x2 = self.cheap_operation(x1) 150 | for conv, bn in zip(self.fusion_conv, self.fusion_bn): # extra re-param branches (Identity+BN and/or plain identity) summed onto the cheap op 151 | x2 = x2 + bn(conv(x1)) 152 | return self.relu(x2) # activation is applied after the branch sum 153 | 154 | def get_equivalent_kernel_bias(self): 155 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1]) # depthwise conv + its BN 156 | for conv, bn in zip(self.fusion_conv, self.fusion_bn): 157 | kernel, bias = self._fuse_bn_tensor(conv, bn, kernel3x3.shape[0], kernel3x3.device) 158 | kernel3x3 += self._pad_1x1_to_3x3_tensor(kernel) 159 | bias3x3 += bias 160 | return kernel3x3, bias3x3 161 | 162 | @staticmethod 163 | def _pad_1x1_to_3x3_tensor(kernel1x1): 164 | if kernel1x1 is None: 165 | return 0 166 | else: 167 | return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) # NOTE(review): pads to 3x3 only, so fusion assumes dw_size == 3 (the default) - confirm before using other dw sizes 168 | 169 | @staticmethod 170 | def _fuse_bn_tensor(conv, bn, in_channels=None, device=None): 171 | in_channels = in_channels if in_channels else
bn.running_mean.shape[0] 172 | device = device if device else bn.weight.device 173 | if isinstance(conv, nn.Conv2d): 174 | kernel = conv.weight 175 | assert conv.bias is None 176 | else: 177 | assert isinstance(conv, nn.Identity) 178 | kernel_value = np.zeros((in_channels, 1, 1, 1), dtype=np.float32) # identity as a depthwise kernel: a single 1x1 tap of 1 per channel 179 | for i in range(in_channels): 180 | kernel_value[i, 0, 0, 0] = 1 181 | kernel = torch.from_numpy(kernel_value).to(device) 182 | 183 | if isinstance(bn, nn.BatchNorm2d): 184 | running_mean = bn.running_mean 185 | running_var = bn.running_var 186 | gamma = bn.weight 187 | beta = bn.bias 188 | eps = bn.eps 189 | std = (running_var + eps).sqrt() 190 | t = (gamma / std).reshape(-1, 1, 1, 1) 191 | return kernel * t, beta - running_mean * gamma / std # fold BN: W * gamma/std, beta - mean * gamma/std 192 | assert isinstance(bn, nn.Identity) 193 | return kernel, torch.zeros(in_channels).to(kernel.device) 194 | 195 | def switch_to_deploy(self): 196 | if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0: # no extra branches (reparam disabled or already converted): cheap_operation is left unchanged 197 | return 198 | kernel, bias = self.get_equivalent_kernel_bias() 199 | self.cheap_operation = nn.Conv2d(in_channels=self.cheap_operation[0].in_channels, 200 | out_channels=self.cheap_operation[0].out_channels, 201 | kernel_size=self.cheap_operation[0].kernel_size, 202 | padding=self.cheap_operation[0].padding, 203 | dilation=self.cheap_operation[0].dilation, 204 | groups=self.cheap_operation[0].groups, 205 | bias=True) 206 | self.cheap_operation.weight.data = kernel 207 | self.cheap_operation.bias.data = bias 208 | self.__delattr__('fusion_conv') 209 | self.__delattr__('fusion_bn') 210 | self.fusion_conv = [] # forward() still zips over these; empty lists make the extra branches a no-op 211 | self.fusion_bn = [] 212 | self.deploy = True 213 | 214 | 215 | class RepGhostBottleneck(nn.Module): 216 | """RepGhost bottleneck w/ optional SE""" 217 | 218 | def __init__( 219 | self, 220 | in_chs, 221 | mid_chs, 222 | out_chs, 223 | dw_kernel_size=3, 224 | stride=1, 225 | se_ratio=0.0, 226 | shortcut=True, 227 | reparam=True, 228 | reparam_bn=True, 229 | reparam_identity=False, 230 | deploy=False, 231
| ): 232 | super(RepGhostBottleneck, self).__init__() 233 | has_se = se_ratio is not None and se_ratio > 0.0 234 | self.stride = stride 235 | self.enable_shortcut = shortcut 236 | self.in_chs = in_chs 237 | self.out_chs = out_chs 238 | 239 | # Point-wise expansion 240 | self.ghost1 = RepGhostModule( 241 | in_chs, 242 | mid_chs, 243 | relu=True, 244 | reparam_bn=reparam and reparam_bn, 245 | reparam_identity=reparam and reparam_identity, 246 | deploy=deploy, 247 | ) 248 | 249 | # Depth-wise convolution 250 | if self.stride > 1: 251 | self.conv_dw = nn.Conv2d( 252 | mid_chs, 253 | mid_chs, 254 | dw_kernel_size, 255 | stride=stride, 256 | padding=(dw_kernel_size - 1) // 2, 257 | groups=mid_chs, 258 | bias=False, 259 | ) 260 | self.bn_dw = nn.BatchNorm2d(mid_chs) 261 | 262 | # Squeeze-and-excitation 263 | if has_se: 264 | self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) 265 | else: 266 | self.se = None 267 | 268 | # Point-wise linear projection 269 | self.ghost2 = RepGhostModule( 270 | mid_chs, 271 | out_chs, 272 | relu=False, 273 | reparam_bn=reparam and reparam_bn, 274 | reparam_identity=reparam and reparam_identity, 275 | deploy=deploy, 276 | ) 277 | 278 | # shortcut 279 | if in_chs == out_chs and self.stride == 1: # identity shortcut when shapes match; otherwise depthwise conv + 1x1 projection 280 | self.shortcut = nn.Sequential() 281 | else: 282 | self.shortcut = nn.Sequential( 283 | nn.Conv2d( 284 | in_chs, 285 | in_chs, 286 | dw_kernel_size, 287 | stride=stride, 288 | padding=(dw_kernel_size - 1) // 2, 289 | groups=in_chs, 290 | bias=False, 291 | ), 292 | nn.BatchNorm2d(in_chs), 293 | nn.Conv2d( 294 | in_chs, out_chs, 1, stride=1, 295 | padding=0, bias=False, 296 | ), 297 | nn.BatchNorm2d(out_chs), 298 | ) 299 | 300 | def forward(self, x): 301 | residual = x 302 | 303 | # 1st repghost bottleneck 304 | x1 = self.ghost1(x) 305 | 306 | # Depth-wise convolution 307 | if self.stride > 1: 308 | x = self.conv_dw(x1) 309 | x = self.bn_dw(x) 310 | else: 311 | x = x1 312 | 313 | # Squeeze-and-excitation 314 | if self.se is not None: 315 | x =
self.se(x) 316 | 317 | # 2nd repghost bottleneck 318 | x = self.ghost2(x) 319 | if not self.enable_shortcut and self.in_chs == self.out_chs and self.stride == 1: # shortcut=False drops only identity shortcuts; projection shortcuts are always added 320 | return x 321 | return x + self.shortcut(residual) 322 | 323 | 324 | class RepGhostNet(nn.Module): 325 | def __init__( 326 | self, 327 | cfgs, 328 | num_classes=1000, 329 | width=1.0, 330 | dropout=0.2, 331 | shortcut=True, 332 | reparam=True, 333 | reparam_bn=True, 334 | reparam_identity=False, 335 | deploy=False, 336 | ): 337 | super(RepGhostNet, self).__init__() 338 | # setting of inverted residual blocks 339 | self.cfgs = cfgs 340 | self.dropout = dropout 341 | self.num_classes = num_classes 342 | 343 | # building first layer 344 | output_channel = _make_divisible(16 * width, 4) 345 | self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) 346 | self.bn1 = nn.BatchNorm2d(output_channel) 347 | self.act1 = nn.ReLU(inplace=True) 348 | input_channel = output_channel 349 | 350 | # building inverted residual blocks 351 | stages = [] 352 | block = RepGhostBottleneck 353 | for cfg in self.cfgs: 354 | layers = [] 355 | for k, exp_size, c, se_ratio, s in cfg: 356 | output_channel = _make_divisible(c * width, 4) 357 | hidden_channel = _make_divisible(exp_size * width, 4) 358 | layers.append( 359 | block( 360 | input_channel, 361 | hidden_channel, 362 | output_channel, 363 | k, 364 | s, 365 | se_ratio=se_ratio, 366 | shortcut=shortcut, 367 | reparam=reparam, 368 | reparam_bn=reparam_bn, 369 | reparam_identity=reparam_identity, 370 | deploy=deploy 371 | ), 372 | ) 373 | input_channel = output_channel 374 | stages.append(nn.Sequential(*layers)) 375 | 376 | output_channel = _make_divisible(exp_size * width * 2, 4) # exp_size is the last value left over from the cfg loop above 377 | stages.append( 378 | nn.Sequential( 379 | ConvBnAct(input_channel, output_channel, 1), 380 | ), 381 | ) 382 | input_channel = output_channel 383 | 384 | self.blocks = nn.Sequential(*stages) 385 | 386 | # building last several layers 387 | output_channel = 1280 388 | self.global_pool =
nn.AdaptiveAvgPool2d((1, 1)) 389 | self.conv_head = nn.Conv2d( 390 | input_channel, output_channel, 1, 1, 0, bias=True, 391 | ) 392 | self.act2 = nn.ReLU(inplace=True) 393 | self.classifier = nn.Linear(output_channel, num_classes) 394 | 395 | def forward(self, x): 396 | x = self.conv_stem(x) 397 | x = self.bn1(x) 398 | x = self.act1(x) 399 | x = self.blocks(x) 400 | x = self.global_pool(x) 401 | x = self.conv_head(x) 402 | x = self.act2(x) 403 | x = x.view(x.size(0), -1) 404 | if self.dropout > 0.0: 405 | x = F.dropout(x, p=self.dropout, training=self.training) 406 | x = self.classifier(x) 407 | return x 408 | 409 | def convert_to_deploy(self): 410 | repghost_model_convert(self, do_copy=False) # converts this model in place (do_copy=False) 411 | 412 | 413 | def repghost_model_convert(model:torch.nn.Module, save_path=None, do_copy=True): 414 | """ 415 | Call switch_to_deploy() on every submodule that defines it (optionally on a deep copy), optionally save the state_dict to save_path, and return the model. Adapted from https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py 416 | """ 417 | if do_copy: 418 | model = copy.deepcopy(model) 419 | for module in model.modules(): 420 | if hasattr(module, 'switch_to_deploy'): 421 | module.switch_to_deploy() 422 | if save_path is not None: 423 | torch.save(model.state_dict(), save_path) 424 | return model 425 | 426 | 427 | def repghostnet(enable_se=True, **kwargs): 428 | """ 429 | Constructs a RepGhostNet model; enable_se=False sets every SE ratio to 0, other kwargs are passed to RepGhostNet 430 | """ 431 | cfgs = [ 432 | # k, t, c, SE, s (dw kernel size, hidden/expansion channels, output channels, SE ratio, stride) 433 | # stage1 434 | [[3, 8, 16, 0, 1]], 435 | # stage2 436 | [[3, 24, 24, 0, 2]], 437 | [[3, 36, 24, 0, 1]], 438 | # stage3 439 | [[5, 36, 40, 0.25 if enable_se else 0, 2]], 440 | [[5, 60, 40, 0.25 if enable_se else 0, 1]], 441 | # stage4 442 | [[3, 120, 80, 0, 2]], 443 | [ 444 | [3, 100, 80, 0, 1], 445 | [3, 120, 80, 0, 1], 446 | [3, 120, 80, 0, 1], 447 | [3, 240, 112, 0.25 if enable_se else 0, 1], 448 | [3, 336, 112, 0.25 if enable_se else 0, 1], 449 | ], 450 | # stage5 451 | [[5, 336, 160, 0.25 if enable_se else 0, 2]], 452 | [ 453 | [5, 480, 160, 0, 1], 454 | [5, 480, 160, 0.25 if enable_se else 0, 1], 455 | [5, 480, 160, 0, 1], 456 | [5, 480, 160, 0.25 if
enable_se else 0, 1], 457 | ], 458 | ] 459 | 460 | return RepGhostNet(cfgs, **kwargs) 461 | 462 | 463 | def repghostnet_0_5x(**kwargs): 464 | return repghostnet(width=0.5, **kwargs) 465 | 466 | 467 | def repghostnet_repid_0_5x(**kwargs): 468 | return repghostnet(width=0.5, reparam_bn=False, reparam_identity=True, **kwargs) # plain identity re-param branch instead of the BN branch 469 | 470 | 471 | def repghostnet_norep_0_5x(**kwargs): 472 | return repghostnet(width=0.5, reparam=False, **kwargs) # no re-parameterization branches 473 | 474 | 475 | def repghostnet_wo_0_5x(**kwargs): 476 | return repghostnet(width=0.5, shortcut=False, **kwargs) # identity shortcuts disabled 477 | 478 | 479 | def repghostnet_0_58x(**kwargs): 480 | return repghostnet(width=0.58, **kwargs) 481 | 482 | 483 | def repghostnet_0_8x(**kwargs): 484 | return repghostnet(width=0.8, **kwargs) 485 | 486 | 487 | def repghostnet_1_0x(**kwargs): 488 | return repghostnet(width=1.0, **kwargs) 489 | 490 | 491 | def repghostnet_1_11x(**kwargs): 492 | return repghostnet(width=1.11, **kwargs) 493 | 494 | 495 | def repghostnet_1_3x(**kwargs): 496 | return repghostnet(width=1.3, **kwargs) 497 | 498 | 499 | def repghostnet_1_5x(**kwargs): 500 | return repghostnet(width=1.5, **kwargs) 501 | 502 | 503 | def repghostnet_2_0x(**kwargs): 504 | return repghostnet(width=2.0, **kwargs) 505 | 506 | 507 | if __name__ == "__main__": 508 | model = repghostnet_norep_0_5x().eval() 509 | print(model) 510 | input = torch.randn(1, 3, 224, 224) 511 | # y = model(input) 512 | # print(y.size()) 513 | 514 | import sys 515 | 516 | sys.path.append("../") 517 | from tools import cal_flops_params 518 | 519 | flops, params = cal_flops_params(model, input_size=input.shape) 520 | 521 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization 2 | 3 | The official pytorch implementation of the paper **[RepGhost: A Hardware-Efficient Ghost Module via
Re-parameterization](https://arxiv.org/abs/2211.06088)** 4 | 5 | #### Chengpeng Chen\*, Zichao Guo\*, Haien Zeng, Pengfei Xiong, Jian Dong 6 | 7 | >Feature reuse has been a key technique in lightweight convolutional neural network (CNN) design. Current methods usually utilize a concatenation operator to keep large channel numbers cheaply (thus large network capacity) by reusing feature maps from other layers. Although concatenation is parameter- and FLOPs-free, its computational cost on hardware devices is non-negligible. To address this, this paper provides a new perspective to realize feature reuse via the structural re-parameterization technique. A novel hardware-efficient RepGhost module is proposed for implicit feature reuse via re-parameterization, instead of using a concatenation operator. Based on the RepGhost module, we develop our efficient RepGhost bottleneck and RepGhostNet. Experiments on ImageNet and COCO benchmarks demonstrate that the proposed RepGhostNet is much more effective and efficient than GhostNet and MobileNetV3 on mobile devices. In particular, our RepGhostNet surpasses GhostNet 0.5x by 2.5% Top-1 accuracy on the ImageNet dataset with fewer parameters and comparable latency on an ARM-based mobile phone. 8 | 9 |

10 | 11 |

12 | 13 | ```text 14 | python 3.9.12 15 | pytorch 1.11.0 16 | cuda 11.3 17 | timm 0.6.7 18 | ``` 19 | 20 | ```bash 21 | git clone https://github.com/ChengpengChen/RepGhost 22 | cd RepGhost 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ### Training 27 | ```bash 28 | bash distributed_train.sh 8 --model repghost.repghostnet_0_5x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ --data_dir {path_to_imagenet_dir} 29 | ``` 30 | 31 | ### Validation 32 | ```bash 33 | python3 -m torch.distributed.launch --nproc_per_node=8 --master_port=2340 validate.py -b 32 --model-ema --model {model} --resume {checkpoint_path} --data_dir {path_to_imagenet_dir} 34 | ``` 35 | 36 | ### Convert a training-time RepGhost into a fast-inference one 37 | See ```convert.py``` for a conversion example. You can also convert a RepGhostNet model for fast inference via: 38 | 39 | ```python 40 | model.convert_to_deploy() 41 | ``` 42 | 43 | ### Results and Pre-trained Models 44 | 45 |

46 | 47 |

48 | 49 | | RepGhostNet | Params(M) | FLOPs(M) | Latency(ms) | Top-1 Acc.(%) | Top-5 Acc.(%) | checkpoints | logs | 50 | |:------------|:----------|:---------|:------------|:--------------|:--------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------| 51 | | 0.5x | 2.3 | 43 | 25.1 | 66.9 | 86.9 | [gdrive](https://drive.google.com/file/d/16AGg-kSscFXDpXPZ3cJpYwqeZbUlUoyr/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1s-tuS8JoHVoCVHWuUiHUFw?pwd=qttp) | [log](./work_dirs/train/repghostnet_0_5x_43M_66.95/train.log) | 52 | | 0.58x | 2.5 | 60 | 31.9 | 68.9 | 88.4 | [gdrive](https://drive.google.com/file/d/1L6ccPjfnCMt5YK-pNFDfqGYvJyTRyZPR/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1bnVk2ILONqPEbmTQahZ0Og?pwd=tiyw) | [log](./work_dirs/train/repghostnet_0_58x_60M_68.94/train.log) | 53 | | 0.8x | 3.3 | 96 | 44.5 | 72.2 | 90.5 | [gdrive](https://drive.google.com/file/d/13gmUpwiJF_O05f3-3UeEyKD57veL5cG-/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1L_EJ0CnQeGpd0QBoOiY7oQ?pwd=rkd8) | [log](./work_dirs/train/repghostnet_0_8x_96M_72.24/train.log) | 54 | | 1.0x | 4.1 | 142 | 62.2 | 74.2 | 91.5 | [gdrive](https://drive.google.com/file/d/1gzfGln60urfY38elpPHVTyv9b94ukn5o/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1CEwuBLV05z7zrVbBrku59w?pwd=z4s7) | [log](./work_dirs/train/repghostnet_1_0x_142M_74.22/train.log) | 55 | | 1.11x | 4.5 | 170 | 71.5 | 75.1 | 92.2 | [gdrive](https://drive.google.com/file/d/14Lk4pKWIUFk1Mb53ooy_GsZbhMmz3iVE/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1Lb54Jiqyt0Jc6X4F_tUYnw?pwd=dwcb) | [log](./work_dirs/train/repghostnet_1_11x_170M_75.07/train.log) | 56 | | 1.3x | 5.5 | 231 | 92.9 | 76.4 | 92.9 | [gdrive](https://drive.google.com/file/d/1dNHpX2JyiuTcDmmyvr8gnAI9t8RM-Nui/view?usp=share_link) \ 
[百度网盘](https://pan.baidu.com/s/19x_OUgxRDvwh2g4E9gN12Q?pwd=uux6) | [log](./work_dirs/train/repghostnet_1_3x_231M_76.37/train.log) | 57 | | 1.5x | 6.6 | 301 | 116.9 | 77.5 | 93.5 | [gdrive](https://drive.google.com/file/d/1TWAY654Dz8zcwhDBDN6QDWhV7as30P8e/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/15UWOMRQN5vw99QbgiFWMRw?pwd=3uqq) | [log](./work_dirs/train/repghostnet_1_5x_301M_77.45/train.log) | 58 | | 2.0x | 9.8 | 516 | 190.0 | 78.8 | 94.3 | [gdrive](https://drive.google.com/file/d/12k00eWCXhKxx_fq3ewDhCNX08ftJ-iyP/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1YbtYvIBt3tTqCzvbcjJuBw?pwd=nq1r) | [log](./work_dirs/train/repghostnet_2_0x_516M_78.81/train.log) | 59 | 60 | #### Parameters and FLOPs 61 | We calculate parameters and FLOPs using a modified [thop](https://github.com/Lyken17/pytorch-OpCounter) in ```tools.py```. It only counts convolutional and fully-connected layers, excluding BN. To use it in your code: 62 | 63 | ```python 64 | from tools import cal_flops_params 65 | flops, params = cal_flops_params(model, input_size=(1, 3, 224, 224)) 66 | ``` 67 | 68 | #### Latency 69 | We first export our PyTorch model to ONNX, then use [MNN](https://github.com/alibaba/MNN) to convert it to MNN format, and finally evaluate its latency on an ARM-based mobile phone. 70 | 71 | #### Comparisons to MobileOne on iPhone12 CPU 72 | We compare RepGhostNet to [MobileOne](https://arxiv.org/abs/2206.04040) on iPhone12 CPU based on [ModelBench](https://github.com/apple/ml-mobileone/tree/main/ModelBench). The result is shown below. 73 | The MobileOne series models are $\mu0$, $\mu1$, $\mu2$, S0, S1, S2, S3 and S4. 74 | 75 |

76 | 77 |

78 | 79 | 80 | ### Citations 81 | If RepGhostNet helps your research or work, please consider citing: 82 | 83 | ```bibtex 84 | @article{chen2022repghost, 85 | title={RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization}, 86 | author={Chen, Chengpeng and Guo, Zichao and Zeng, Haien and Xiong, Pengfei and Dong, Jian}, 87 | journal={arXiv preprint arXiv:2211.06088}, 88 | year={2022} 89 | } 90 | ``` 91 | 92 | ### Contact 93 | 94 | If you have any questions, please contact chencp@live.com. 95 | 96 | --- 97 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.6.7 2 | pyyaml 3 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | # @Author : chengpeng.chen 2 | # @Email : chencp@live.com 3 | """ 4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong.
5 | https://arxiv.org/abs/2211.06088 6 | """ 7 | import torch 8 | 9 | from infotool.profile import profile_origin 10 | from infotool.helper import clever_format 11 | 12 | import copy 13 | 14 | def convert_syncbn_to_bn(module: torch.nn.Module) -> torch.nn.Module: # Recursively replace every SyncBatchNorm in `module` with a BatchNorm2d that reuses the same parameter and running-stat tensors; returns the (possibly new) root module. 15 | module_output = module 16 | if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm): # NOTE(review): always builds BatchNorm2d, whatever input rank the SyncBatchNorm was used with 17 | module_output = torch.nn.BatchNorm2d( 18 | module.num_features, 19 | module.eps, 20 | module.momentum, 21 | module.affine, 22 | module.track_running_stats, 23 | ) 24 | if module.affine: 25 | with torch.no_grad(): # weight/bias (and the running stats below) are reassigned by reference, not copied 26 | module_output.weight = module.weight 27 | module_output.bias = module.bias 28 | module_output.running_mean = module.running_mean 29 | module_output.running_var = module.running_var 30 | module_output.num_batches_tracked = module.num_batches_tracked 31 | if hasattr(module, "qconfig"): 32 | module_output.qconfig = module.qconfig 33 | 34 | for name, child in module.named_children(): # recurse; add_module under an existing name replaces that child 35 | module_output.add_module( 36 | name, convert_syncbn_to_bn(child) 37 | ) 38 | del module # only unbinds the local name; callers keep their own reference 39 | return module_output 40 | 41 | 42 | def cal_flops_params(original_model, input_size): # Print and return (flops, params) from infotool's profile_origin, run on a SyncBN-free deep copy of original_model (the caller's model is not modified) with an all-zeros input. 43 | model = copy.deepcopy(original_model) 44 | model = convert_syncbn_to_bn(model) 45 | input_size = list(input_size) 46 | assert len(input_size) in [3, 4] # 4 dims: leading dim is forced to 1 below; 3 dims: a leading dim of 1 is inserted. NOTE(review): asserts are stripped under -O 47 | if len(input_size) == 4: 48 | if input_size[0] != 1: 49 | print('modify batchsize of input_size from {} to 1'.format(input_size[0])) 50 | input_size[0] = 1 51 | 52 | if len(input_size) == 3: 53 | input_size.insert(0, 1) 54 | 55 | flops, params = profile_origin(model, inputs=(torch.zeros(input_size), )) 56 | 57 | print('flops = {}, params = {}'.format(flops, params)) 58 | print('flops = {}, params = {}'.format(clever_format(flops), clever_format(params))) # same counts in human-readable form 59 | 60 | return flops, params 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # @Author : chengpeng.chen 2 | # @Email :
chencp@live.com 3 | """ 4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong. 5 | https://arxiv.org/abs/2211.06088 6 | """ 7 | #!/usr/bin/env python3 8 | import argparse 9 | import time 10 | import yaml 11 | import os 12 | import logging 13 | from collections import OrderedDict 14 | from contextlib import suppress 15 | from datetime import datetime 16 | import importlib 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torchvision.utils 21 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 22 | 23 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 24 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ 25 | convert_splitbn_model, model_parameters 26 | from timm.utils import * 27 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy 28 | from timm.optim import create_optimizer_v2, optimizer_kwargs 29 | from timm.scheduler import create_scheduler 30 | from timm.utils import ApexScaler, NativeScaler 31 | 32 | try: 33 | from apex import amp 34 | from apex.parallel import DistributedDataParallel as ApexDDP 35 | from apex.parallel import convert_syncbn_model 36 | 37 | has_apex = True 38 | except ImportError: 39 | has_apex = False 40 | 41 | has_native_amp = False 42 | try: 43 | if getattr(torch.cuda.amp, 'autocast') is not None: 44 | has_native_amp = True 45 | except AttributeError: 46 | pass 47 | 48 | try: 49 | import wandb 50 | 51 | has_wandb = True 52 | except ImportError: 53 | has_wandb = False 54 | 55 | torch.backends.cudnn.benchmark = True 56 | _logger = logging.getLogger('train') 57 | 58 | # The first arg parser parses out only the --config argument, this argument is used to 59 | # load a yaml file containing key-values that override the defaults for the main parser below 60 | config_parser = parser = 
argparse.ArgumentParser( 61 | description='Training Config', add_help=False) 62 | parser.add_argument( 63 | '-c', 64 | '--config', 65 | default='', 66 | type=str, 67 | metavar='FILE', 68 | help='YAML config file specifying default arguments') 69 | 70 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 71 | 72 | # Dataset / Model parameters 73 | parser.add_argument( 74 | '--data_dir', 75 | metavar='DIR', 76 | default='/disk2/datasets/imagenet', 77 | help='path to dataset') 78 | parser.add_argument( 79 | '--dataset', 80 | '-d', 81 | metavar='NAME', 82 | default='', 83 | help='dataset type (default: ImageFolder/ImageTar if empty)') 84 | parser.add_argument( 85 | '--train-split', 86 | metavar='NAME', 87 | default='train', 88 | help='dataset train split (default: train)') 89 | parser.add_argument( 90 | '--val-split', 91 | metavar='NAME', 92 | default='val', 93 | help='dataset validation split (default: val)') 94 | parser.add_argument( 95 | '--model', 96 | default='', 97 | type=str, 98 | metavar='MODEL', 99 | help='Name of model to train (default: "countception"') 100 | parser.add_argument( 101 | '--pretrained', 102 | action='store_true', 103 | default=False, 104 | help='Start with pretrained version of specified network (if avail)') 105 | parser.add_argument( 106 | '--initial-checkpoint', 107 | default='', 108 | type=str, 109 | metavar='PATH', 110 | help='Initialize model from this checkpoint (default: none)') 111 | parser.add_argument( 112 | '--resume', 113 | default='', 114 | type=str, 115 | metavar='PATH', 116 | help='Resume full model and optimizer state from checkpoint (default: none)' 117 | ) 118 | parser.add_argument( 119 | '--no-resume-opt', 120 | action='store_true', 121 | default=False, 122 | help='prevent resume of optimizer state when resuming model') 123 | parser.add_argument( 124 | '--num-classes', 125 | type=int, 126 | default=1000, 127 | metavar='N', 128 | help='number of label classes (Model default if None)') 129 | 
parser.add_argument( 130 | '--gp', 131 | default=None, 132 | type=str, 133 | metavar='POOL', 134 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.' 135 | ) 136 | parser.add_argument( 137 | '--img-size', 138 | type=int, 139 | default=None, 140 | metavar='N', 141 | help='Image patch size (default: None => model default)') 142 | parser.add_argument( 143 | '--input-size', 144 | default=None, 145 | nargs=3, 146 | type=int, 147 | metavar='N N N', 148 | help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty' 149 | ) 150 | parser.add_argument( 151 | '--crop-pct', 152 | default=None, 153 | type=float, 154 | metavar='N', 155 | help='Input image center crop percent (for validation only)') 156 | parser.add_argument( 157 | '--mean', 158 | type=float, 159 | nargs='+', 160 | default=None, 161 | metavar='MEAN', 162 | help='Override mean pixel value of dataset') 163 | parser.add_argument( 164 | '--std', 165 | type=float, 166 | nargs='+', 167 | default=None, 168 | metavar='STD', 169 | help='Override std deviation of of dataset') 170 | parser.add_argument( 171 | '--interpolation', 172 | default='', 173 | type=str, 174 | metavar='NAME', 175 | help='Image resize interpolation type (overrides model)') 176 | parser.add_argument( 177 | '-b', 178 | '--batch-size', 179 | type=int, 180 | default=32, 181 | metavar='N', 182 | help='input batch size for training (default: 32)') 183 | parser.add_argument( 184 | '-vb', 185 | '--validation-batch-size-multiplier', 186 | type=int, 187 | default=1, 188 | metavar='N', 189 | help='ratio of validation batch size to training batch size (default: 1)') 190 | 191 | # Optimizer parameters 192 | parser.add_argument( 193 | '--opt', 194 | default='sgd', 195 | type=str, 196 | metavar='OPTIMIZER', 197 | help='Optimizer (default: "sgd"') 198 | parser.add_argument( 199 | '--opt-eps', 200 | default=None, 201 | type=float, 202 | metavar='EPSILON', 203 | help='Optimizer Epsilon 
(default: None, use opt default)') 204 | parser.add_argument( 205 | '--opt-betas', 206 | default=None, 207 | type=float, 208 | nargs='+', 209 | metavar='BETA', 210 | help='Optimizer Betas (default: None, use opt default)') 211 | parser.add_argument( 212 | '--momentum', 213 | type=float, 214 | default=0.9, 215 | metavar='M', 216 | help='Optimizer momentum (default: 0.9)') 217 | parser.add_argument( 218 | '--weight-decay', 219 | type=float, 220 | default=0.0001, 221 | help='weight decay (default: 0.0001)') 222 | parser.add_argument( 223 | '--clip-grad', 224 | type=float, 225 | default=None, 226 | metavar='NORM', 227 | help='Clip gradient norm (default: None, no clipping)') 228 | parser.add_argument( 229 | '--clip-mode', 230 | type=str, 231 | default='norm', 232 | help='Gradient clipping mode. One of ("norm", "value", "agc")') 233 | 234 | # Learning rate schedule parameters 235 | parser.add_argument( 236 | '--sched', 237 | default='step', 238 | type=str, 239 | metavar='SCHEDULER', 240 | help='LR scheduler (default: "step"') 241 | parser.add_argument( 242 | '--lr', 243 | type=float, 244 | default=0.01, 245 | metavar='LR', 246 | help='learning rate (default: 0.01)') 247 | parser.add_argument( 248 | '--lr-noise', 249 | type=float, 250 | nargs='+', 251 | default=None, 252 | metavar='pct, pct', 253 | help='learning rate noise on/off epoch percentages') 254 | parser.add_argument( 255 | '--lr-noise-pct', 256 | type=float, 257 | default=0.67, 258 | metavar='PERCENT', 259 | help='learning rate noise limit percent (default: 0.67)') 260 | parser.add_argument( 261 | '--lr-noise-std', 262 | type=float, 263 | default=1.0, 264 | metavar='STDDEV', 265 | help='learning rate noise std-dev (default: 1.0)') 266 | parser.add_argument( 267 | '--lr-cycle-mul', 268 | type=float, 269 | default=1.0, 270 | metavar='MULT', 271 | help='learning rate cycle len multiplier (default: 1.0)') 272 | parser.add_argument( 273 | '--lr-cycle-limit', 274 | type=int, 275 | default=1, 276 | metavar='N', 277 | 
help='learning rate cycle limit') 278 | parser.add_argument( 279 | '--warmup-lr', 280 | type=float, 281 | default=0.0001, 282 | metavar='LR', 283 | help='warmup learning rate (default: 0.0001)') 284 | parser.add_argument( 285 | '--min-lr', 286 | type=float, 287 | default=1e-5, 288 | metavar='LR', 289 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 290 | parser.add_argument( 291 | '--epochs', 292 | type=int, 293 | default=200, 294 | metavar='N', 295 | help='number of epochs to train (default: 2)') 296 | parser.add_argument( 297 | '--epoch-repeats', 298 | type=float, 299 | default=0., 300 | metavar='N', 301 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).' 302 | ) 303 | parser.add_argument( 304 | '--start-epoch', 305 | default=None, 306 | type=int, 307 | metavar='N', 308 | help='manual epoch number (useful on restarts)') 309 | parser.add_argument( 310 | '--decay-epochs', 311 | type=float, 312 | default=30, 313 | metavar='N', 314 | help='epoch interval to decay LR') 315 | parser.add_argument( 316 | '--warmup-epochs', 317 | type=int, 318 | default=3, 319 | metavar='N', 320 | help='epochs to warmup LR, if scheduler supports') 321 | parser.add_argument( 322 | '--cooldown-epochs', 323 | type=int, 324 | default=0, 325 | metavar='N', 326 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 327 | parser.add_argument( 328 | '--patience-epochs', 329 | type=int, 330 | default=10, 331 | metavar='N', 332 | help='patience epochs for Plateau LR scheduler (default: 10') 333 | parser.add_argument( 334 | '--decay-rate', 335 | '--dr', 336 | type=float, 337 | default=0.1, 338 | metavar='RATE', 339 | help='LR decay rate (default: 0.1)') 340 | 341 | # Augmentation & regularization parameters 342 | parser.add_argument( 343 | '--no-aug', 344 | action='store_true', 345 | default=False, 346 | help='Disable all training augmentation, override other train aug args') 347 | parser.add_argument( 348 | '--scale', 349 | 
type=float, 350 | nargs='+', 351 | default=[0.08, 1.0], 352 | metavar='PCT', 353 | help='Random resize scale (default: 0.08 1.0)') 354 | parser.add_argument( 355 | '--ratio', 356 | type=float, 357 | nargs='+', 358 | default=[3. / 4., 4. / 3.], 359 | metavar='RATIO', 360 | help='Random resize aspect ratio (default: 0.75 1.33)') 361 | parser.add_argument( 362 | '--hflip', 363 | type=float, 364 | default=0.5, 365 | help='Horizontal flip training aug probability') 366 | parser.add_argument( 367 | '--vflip', 368 | type=float, 369 | default=0., 370 | help='Vertical flip training aug probability') 371 | parser.add_argument( 372 | '--color-jitter', 373 | type=float, 374 | default=0.4, 375 | metavar='PCT', 376 | help='Color jitter factor (default: 0.4)') 377 | parser.add_argument( 378 | '--aa', 379 | type=str, 380 | default=None, 381 | metavar='NAME', 382 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 383 | parser.add_argument( 384 | '--aug-splits', 385 | type=int, 386 | default=0, 387 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 388 | parser.add_argument( 389 | '--jsd', 390 | action='store_true', 391 | default=False, 392 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 393 | parser.add_argument( 394 | '--reprob', 395 | type=float, 396 | default=0., 397 | metavar='PCT', 398 | help='Random erase prob (default: 0.)') 399 | parser.add_argument( 400 | '--remode', 401 | type=str, 402 | default='const', 403 | help='Random erase mode (default: "const")') 404 | parser.add_argument( 405 | '--recount', type=int, default=1, help='Random erase count (default: 1)') 406 | parser.add_argument( 407 | '--resplit', 408 | action='store_true', 409 | default=False, 410 | help='Do not random erase first (clean) augmentation split') 411 | parser.add_argument( 412 | '--mixup', 413 | type=float, 414 | default=0.0, 415 | help='mixup alpha, mixup enabled if > 0. 
(default: 0.)') 416 | parser.add_argument( 417 | '--cutmix', 418 | type=float, 419 | default=0.0, 420 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 421 | parser.add_argument( 422 | '--cutmix-minmax', 423 | type=float, 424 | nargs='+', 425 | default=None, 426 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)' 427 | ) 428 | parser.add_argument( 429 | '--mixup-prob', 430 | type=float, 431 | default=1.0, 432 | help='Probability of performing mixup or cutmix when either/both is enabled' 433 | ) 434 | parser.add_argument( 435 | '--mixup-switch-prob', 436 | type=float, 437 | default=0.5, 438 | help='Probability of switching to cutmix when both mixup and cutmix enabled' 439 | ) 440 | parser.add_argument( 441 | '--mixup-mode', 442 | type=str, 443 | default='batch', 444 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 445 | parser.add_argument( 446 | '--mixup-off-epoch', 447 | default=0, 448 | type=int, 449 | metavar='N', 450 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 451 | parser.add_argument( 452 | '--smoothing', 453 | type=float, 454 | default=0.1, 455 | help='Label smoothing (default: 0.1)') 456 | parser.add_argument( 457 | '--train-interpolation', 458 | type=str, 459 | default='random', 460 | help='Training interpolation (random, bilinear, bicubic default: "random")') 461 | parser.add_argument( 462 | '--drop', 463 | type=float, 464 | default=0.2, 465 | metavar='PCT', 466 | help='Dropout rate (default: 0.)') 467 | parser.add_argument( 468 | '--drop-connect', 469 | type=float, 470 | default=None, 471 | metavar='PCT', 472 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 473 | parser.add_argument( 474 | '--drop-path', 475 | type=float, 476 | default=None, 477 | metavar='PCT', 478 | help='Drop path rate (default: None)') 479 | parser.add_argument( 480 | '--drop-block', 481 | type=float, 482 | default=None, 483 | metavar='PCT', 484 | help='Drop block 
rate (default: None)') 485 | 486 | # Batch norm parameters (only works with gen_efficientnet based models currently) 487 | parser.add_argument( 488 | '--bn-tf', 489 | action='store_true', 490 | default=False, 491 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)' 492 | ) 493 | parser.add_argument( 494 | '--bn-momentum', 495 | type=float, 496 | default=None, 497 | help='BatchNorm momentum override (if not None)') 498 | parser.add_argument( 499 | '--bn-eps', 500 | type=float, 501 | default=None, 502 | help='BatchNorm epsilon override (if not None)') 503 | parser.add_argument( 504 | '--sync-bn', 505 | action='store_true', 506 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 507 | parser.add_argument( 508 | '--dist-bn', 509 | type=str, 510 | default='', 511 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")' 512 | ) 513 | parser.add_argument( 514 | '--split-bn', 515 | action='store_true', 516 | help='Enable separate BN layers per augmentation split.') 517 | 518 | # Model Exponential Moving Average 519 | parser.add_argument( 520 | '--model-ema', 521 | action='store_true', 522 | default=False, 523 | help='Enable tracking moving average of model weights') 524 | parser.add_argument( 525 | '--model-ema-force-cpu', 526 | action='store_true', 527 | default=False, 528 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.' 
529 | ) 530 | parser.add_argument( 531 | '--model-ema-decay', 532 | type=float, 533 | default=0.9998, 534 | help='decay factor for model weights moving average (default: 0.9998)') 535 | 536 | # Misc 537 | parser.add_argument( 538 | '--seed', 539 | type=int, 540 | default=42, 541 | metavar='S', 542 | help='random seed (default: 42)') 543 | parser.add_argument( 544 | '--log-interval', 545 | type=int, 546 | default=50, 547 | metavar='N', 548 | help='how many batches to wait before logging training status') 549 | parser.add_argument( 550 | '--recovery-interval', 551 | type=int, 552 | default=0, 553 | metavar='N', 554 | help='how many batches to wait before writing recovery checkpoint') 555 | parser.add_argument( 556 | '--checkpoint-hist', 557 | type=int, 558 | default=5, 559 | metavar='N', 560 | help='number of checkpoints to keep (default: 10)') 561 | parser.add_argument( 562 | '-j', 563 | '--workers', 564 | type=int, 565 | default=4, 566 | metavar='N', 567 | help='how many training processes to use (default: 1)') 568 | parser.add_argument( 569 | '--save-images', 570 | action='store_true', 571 | default=False, 572 | help='save images of input bathes every log interval for debugging') 573 | parser.add_argument( 574 | '--amp', 575 | action='store_true', 576 | default=False, 577 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 578 | parser.add_argument( 579 | '--apex-amp', 580 | action='store_true', 581 | default=False, 582 | help='Use NVIDIA Apex AMP mixed precision') 583 | parser.add_argument( 584 | '--native-amp', 585 | action='store_true', 586 | default=False, 587 | help='Use Native Torch AMP mixed precision') 588 | parser.add_argument( 589 | '--channels-last', 590 | action='store_true', 591 | default=False, 592 | help='Use channels_last memory layout') 593 | parser.add_argument( 594 | '--pin-mem', 595 | action='store_true', 596 | default=False, 597 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.' 
598 | ) 599 | parser.add_argument( 600 | '--no-prefetcher', 601 | action='store_true', 602 | default=False, 603 | help='disable fast prefetcher') 604 | parser.add_argument( 605 | '--output', 606 | default='', 607 | type=str, 608 | metavar='PATH', 609 | help='path to output folder (default: none, current dir)') 610 | parser.add_argument( 611 | '--experiment', 612 | default='', 613 | type=str, 614 | metavar='NAME', 615 | help='name of train experiment, name of sub-folder for output') 616 | parser.add_argument( 617 | '--eval-metric', 618 | default='top1', 619 | type=str, 620 | metavar='EVAL_METRIC', 621 | help='Best metric (default: "top1"') 622 | parser.add_argument( 623 | '--tta', 624 | type=int, 625 | default=0, 626 | metavar='N', 627 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)' 628 | ) 629 | parser.add_argument("--local_rank", default=0, type=int) 630 | parser.add_argument( 631 | '--use-multi-epochs-loader', 632 | action='store_true', 633 | default=False, 634 | help='use the multi-epochs-loader to save time at the beginning of every epoch' 635 | ) 636 | parser.add_argument( 637 | '--torchscript', 638 | dest='torchscript', 639 | action='store_true', 640 | help='convert model torchscript for inference') 641 | parser.add_argument( 642 | '--log-wandb', 643 | action='store_true', 644 | default=False, 645 | help='log training and validation metrics to wandb') 646 | 647 | 648 | def _parse_args(): 649 | # Do we have a config file to parse? 650 | args_config, remaining = config_parser.parse_known_args() 651 | if args_config.config: 652 | with open(args_config.config, 'r') as f: 653 | cfg = yaml.safe_load(f) 654 | parser.set_defaults(**cfg) 655 | 656 | # The main arg parser parses the rest of the args, the usual 657 | # defaults will have been overridden if config file specified. 
658 | args = parser.parse_args(remaining) 659 | 660 | # Cache the args as a text string to save them in the output dir later 661 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 662 | return args, args_text 663 | 664 | 665 | def main(): 666 | setup_default_logging() 667 | args, args_text = _parse_args() 668 | 669 | if args.log_wandb: 670 | if has_wandb: 671 | wandb.init(project=args.experiment, config=args) 672 | else: 673 | _logger.warning( 674 | "You've requested to log metrics to wandb but package not found. " 675 | "Metrics not being logged to wandb, try `pip install wandb`") 676 | 677 | args.prefetcher = not args.no_prefetcher 678 | args.distributed = False 679 | if 'WORLD_SIZE' in os.environ: 680 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 681 | args.device = 'cuda:0' 682 | args.world_size = 1 683 | args.rank = 0 # global rank 684 | if args.distributed: 685 | args.device = 'cuda:%d' % args.local_rank 686 | torch.cuda.set_device(args.local_rank) 687 | torch.distributed.init_process_group( 688 | backend='nccl', init_method='env://') 689 | args.world_size = torch.distributed.get_world_size() 690 | args.rank = torch.distributed.get_rank() 691 | _logger.info( 692 | 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 
693 | % (args.rank, args.world_size)) 694 | else: 695 | _logger.info('Training with a single process on 1 GPUs.') 696 | assert args.rank >= 0 697 | 698 | # resolve AMP arguments based on PyTorch / Apex availability 699 | use_amp = None 700 | if args.amp: 701 | # `--amp` chooses native amp before apex (APEX ver not actively maintained) 702 | if has_native_amp: 703 | args.native_amp = True 704 | elif has_apex: 705 | args.apex_amp = True 706 | if args.apex_amp and has_apex: 707 | use_amp = 'apex' 708 | elif args.native_amp and has_native_amp: 709 | use_amp = 'native' 710 | elif args.apex_amp or args.native_amp: 711 | _logger.warning( 712 | "Neither APEX or native Torch AMP is available, using float32. " 713 | "Install NVIDA apex or upgrade to PyTorch 1.6") 714 | 715 | random_seed(args.seed, args.rank) 716 | 717 | # model = create_model( 718 | # args.model, 719 | # pretrained=args.pretrained, 720 | # num_classes=args.num_classes, 721 | # drop_rate=args.drop, 722 | # drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 723 | # drop_path_rate=args.drop_path, 724 | # drop_block_rate=args.drop_block, 725 | # global_pool=args.gp, 726 | # bn_tf=args.bn_tf, 727 | # bn_momentum=args.bn_momentum, 728 | # bn_eps=args.bn_eps, 729 | # scriptable=args.torchscript, 730 | # checkpoint_path=args.initial_checkpoint) 731 | 732 | m = importlib.import_module(f"model.{args.model.split('.')[0]}") 733 | model = getattr(m, args.model.split('.')[1])(dropout=args.drop) 734 | 735 | if args.num_classes is None: 736 | assert hasattr( 737 | model, 'num_classes' 738 | ), 'Model must have `num_classes` attr if not set on cmd line/config.' 
739 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 740 | 741 | if args.local_rank == 0: 742 | _logger.info( 743 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}' 744 | ) 745 | 746 | data_config = resolve_data_config( 747 | vars(args), model=model, verbose=args.local_rank == 0) 748 | 749 | # setup augmentation batch splits for contrastive loss or split bn 750 | num_aug_splits = 0 751 | if args.aug_splits > 0: 752 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 753 | num_aug_splits = args.aug_splits 754 | 755 | # enable split bn (separate bn stats per batch-portion) 756 | if args.split_bn: 757 | assert num_aug_splits > 1 or args.resplit 758 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 759 | 760 | # move model to GPU, enable channels last layout if set 761 | model.cuda() 762 | if args.channels_last: 763 | model = model.to(memory_format=torch.channels_last) 764 | 765 | # setup synchronized BatchNorm for distributed training 766 | if args.distributed and args.sync_bn: 767 | assert not args.split_bn 768 | if has_apex and use_amp != 'native': 769 | # Apex SyncBN preferred unless native amp is activated 770 | model = convert_syncbn_model(model) 771 | else: 772 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 773 | if args.local_rank == 0: 774 | _logger.info( 775 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 776 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' 
777 | ) 778 | 779 | if args.torchscript: 780 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 781 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 782 | model = torch.jit.script(model) 783 | 784 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) 785 | 786 | # setup automatic mixed-precision (AMP) loss scaling and op casting 787 | amp_autocast = suppress # do nothing 788 | loss_scaler = None 789 | if use_amp == 'apex': 790 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 791 | loss_scaler = ApexScaler() 792 | if args.local_rank == 0: 793 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 794 | elif use_amp == 'native': 795 | amp_autocast = torch.cuda.amp.autocast 796 | loss_scaler = NativeScaler() 797 | if args.local_rank == 0: 798 | _logger.info( 799 | 'Using native Torch AMP. Training in mixed precision.') 800 | else: 801 | if args.local_rank == 0: 802 | _logger.info('AMP not enabled. 
Training in float32.') 803 | 804 | # optionally resume from a checkpoint 805 | resume_epoch = None 806 | if args.resume: 807 | resume_epoch = resume_checkpoint( 808 | model, 809 | args.resume, 810 | optimizer=None if args.no_resume_opt else optimizer, 811 | loss_scaler=None if args.no_resume_opt else loss_scaler, 812 | log_info=args.local_rank == 0) 813 | 814 | # setup exponential moving average of model weights, SWA could be used here too 815 | model_ema = None 816 | if args.model_ema: 817 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 818 | model_ema = ModelEmaV2( 819 | model, 820 | decay=args.model_ema_decay, 821 | device='cpu' if args.model_ema_force_cpu else None) 822 | if args.resume: 823 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 824 | 825 | # setup distributed training 826 | if args.distributed: 827 | if has_apex and use_amp != 'native': 828 | # Apex DDP preferred unless native amp is activated 829 | if args.local_rank == 0: 830 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 831 | model = ApexDDP(model, delay_allreduce=True) 832 | else: 833 | if args.local_rank == 0: 834 | _logger.info("Using native Torch DistributedDataParallel.") 835 | model = NativeDDP( 836 | model, device_ids=[args.local_rank 837 | ]) # can use device str in Torch >= 1.1 838 | # NOTE: EMA model does not need to be wrapped by DDP 839 | 840 | # setup learning rate schedule and starting epoch 841 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 842 | start_epoch = 0 843 | if args.start_epoch is not None: 844 | # a specified start_epoch will always override the resume epoch 845 | start_epoch = args.start_epoch 846 | elif resume_epoch is not None: 847 | start_epoch = resume_epoch 848 | if lr_scheduler is not None and start_epoch > 0: 849 | lr_scheduler.step(start_epoch) 850 | 851 | if args.local_rank == 0: 852 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 853 | 854 | # create 
the train and eval datasets 855 | dataset_train = create_dataset( 856 | args.dataset, 857 | root=args.data_dir, 858 | split=args.train_split, 859 | is_training=True, 860 | batch_size=args.batch_size, 861 | repeats=args.epoch_repeats) 862 | dataset_eval = create_dataset( 863 | args.dataset, 864 | root=args.data_dir, 865 | split=args.val_split, 866 | is_training=False, 867 | batch_size=args.batch_size) 868 | 869 | # setup mixup / cutmix 870 | collate_fn = None 871 | mixup_fn = None 872 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 873 | if mixup_active: 874 | mixup_args = dict( 875 | mixup_alpha=args.mixup, 876 | cutmix_alpha=args.cutmix, 877 | cutmix_minmax=args.cutmix_minmax, 878 | prob=args.mixup_prob, 879 | switch_prob=args.mixup_switch_prob, 880 | mode=args.mixup_mode, 881 | label_smoothing=args.smoothing, 882 | num_classes=args.num_classes) 883 | if args.prefetcher: 884 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 885 | collate_fn = FastCollateMixup(**mixup_args) 886 | else: 887 | mixup_fn = Mixup(**mixup_args) 888 | 889 | # wrap dataset in AugMix helper 890 | if num_aug_splits > 1: 891 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 892 | 893 | # create data loaders w/ augmentation pipeiine 894 | train_interpolation = args.train_interpolation 895 | if args.no_aug or not train_interpolation: 896 | train_interpolation = data_config['interpolation'] 897 | loader_train = create_loader( 898 | dataset_train, 899 | input_size=data_config['input_size'], 900 | batch_size=args.batch_size, 901 | is_training=True, 902 | use_prefetcher=args.prefetcher, 903 | no_aug=args.no_aug, 904 | re_prob=args.reprob, 905 | re_mode=args.remode, 906 | re_count=args.recount, 907 | re_split=args.resplit, 908 | scale=args.scale, 909 | ratio=args.ratio, 910 | hflip=args.hflip, 911 | vflip=args.vflip, 912 | color_jitter=args.color_jitter, 913 | auto_augment=args.aa, 914 | 
num_aug_splits=num_aug_splits, 915 | interpolation=train_interpolation, 916 | mean=data_config['mean'], 917 | std=data_config['std'], 918 | num_workers=args.workers, 919 | distributed=args.distributed, 920 | collate_fn=collate_fn, 921 | pin_memory=args.pin_mem, 922 | use_multi_epochs_loader=args.use_multi_epochs_loader) 923 | 924 | loader_eval = create_loader( 925 | dataset_eval, 926 | input_size=data_config['input_size'], 927 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 928 | is_training=False, 929 | use_prefetcher=args.prefetcher, 930 | interpolation=data_config['interpolation'], 931 | mean=data_config['mean'], 932 | std=data_config['std'], 933 | num_workers=args.workers, 934 | distributed=args.distributed, 935 | crop_pct=data_config['crop_pct'], 936 | pin_memory=args.pin_mem, ) 937 | 938 | # setup loss function 939 | if args.jsd: 940 | assert num_aug_splits > 1 # JSD only valid with aug splits set 941 | train_loss_fn = JsdCrossEntropy( 942 | num_splits=num_aug_splits, smoothing=args.smoothing).cuda() 943 | elif mixup_active: 944 | # smoothing is handled with mixup target transform 945 | train_loss_fn = SoftTargetCrossEntropy().cuda() 946 | elif args.smoothing: 947 | train_loss_fn = LabelSmoothingCrossEntropy( 948 | smoothing=args.smoothing).cuda() 949 | else: 950 | train_loss_fn = nn.CrossEntropyLoss().cuda() 951 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 952 | 953 | # setup checkpoint saver and eval metric tracking 954 | eval_metric = args.eval_metric 955 | best_metric = None 956 | best_epoch = None 957 | saver = None 958 | output_dir = None 959 | if args.rank == 0: 960 | if args.experiment: 961 | exp_name = args.experiment 962 | else: 963 | exp_name = '-'.join([ 964 | datetime.now().strftime("%Y%m%d-%H%M%S"), 965 | safe_model_name(args.model), str(data_config['input_size'][-1]) 966 | ]) 967 | output_dir = get_outdir(args.output 968 | if args.output else './output/train', exp_name) 969 | handler = 
logging.FileHandler(os.path.join(output_dir, 'train.log')) 970 | _logger.addHandler(handler) 971 | decreasing = True if eval_metric == 'loss' else False 972 | saver = CheckpointSaver( 973 | model=model, 974 | optimizer=optimizer, 975 | args=args, 976 | model_ema=model_ema, 977 | amp_scaler=loss_scaler, 978 | checkpoint_dir=output_dir, 979 | recovery_dir=output_dir, 980 | decreasing=decreasing, 981 | max_history=args.checkpoint_hist) 982 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 983 | f.write(args_text) 984 | 985 | try: 986 | for epoch in range(start_epoch, num_epochs): 987 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 988 | loader_train.sampler.set_epoch(epoch) 989 | 990 | train_metrics = train_one_epoch( 991 | epoch, 992 | model, 993 | loader_train, 994 | optimizer, 995 | train_loss_fn, 996 | args, 997 | lr_scheduler=lr_scheduler, 998 | saver=saver, 999 | output_dir=output_dir, 1000 | amp_autocast=amp_autocast, 1001 | loss_scaler=loss_scaler, 1002 | model_ema=model_ema, 1003 | mixup_fn=mixup_fn) 1004 | 1005 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 1006 | if args.local_rank == 0: 1007 | _logger.info( 1008 | "Distributing BatchNorm running means and vars") 1009 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 1010 | 1011 | eval_metrics = validate( 1012 | model, 1013 | loader_eval, 1014 | validate_loss_fn, 1015 | args, 1016 | amp_autocast=amp_autocast) 1017 | 1018 | if model_ema is not None and not args.model_ema_force_cpu: 1019 | if args.distributed and args.dist_bn in ('broadcast', 'reduce' 1020 | ): 1021 | distribute_bn(model_ema, args.world_size, 1022 | args.dist_bn == 'reduce') 1023 | ema_eval_metrics = validate( 1024 | model_ema.module, 1025 | loader_eval, 1026 | validate_loss_fn, 1027 | args, 1028 | amp_autocast=amp_autocast, 1029 | log_suffix=' (EMA)') 1030 | eval_metrics = ema_eval_metrics 1031 | 1032 | if lr_scheduler is not None: 1033 | # step LR for next epoch 1034 | 
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 1035 | 1036 | if output_dir is not None: 1037 | update_summary( 1038 | epoch, 1039 | train_metrics, 1040 | eval_metrics, 1041 | os.path.join(output_dir, 'summary.csv'), 1042 | write_header=best_metric is None, 1043 | log_wandb=args.log_wandb and has_wandb) 1044 | 1045 | if saver is not None: 1046 | # save proper checkpoint with eval metric 1047 | save_metric = eval_metrics[eval_metric] 1048 | best_metric, best_epoch = saver.save_checkpoint( 1049 | epoch, metric=save_metric) 1050 | 1051 | except KeyboardInterrupt: 1052 | pass 1053 | if best_metric is not None: 1054 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, 1055 | best_epoch)) 1056 | 1057 | 1058 | def train_one_epoch(epoch, 1059 | model, 1060 | loader, 1061 | optimizer, 1062 | loss_fn, 1063 | args, 1064 | lr_scheduler=None, 1065 | saver=None, 1066 | output_dir=None, 1067 | amp_autocast=suppress, 1068 | loss_scaler=None, 1069 | model_ema=None, 1070 | mixup_fn=None): 1071 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 1072 | if args.prefetcher and loader.mixup_enabled: 1073 | loader.mixup_enabled = False 1074 | elif mixup_fn is not None: 1075 | mixup_fn.mixup_enabled = False 1076 | 1077 | second_order = hasattr(optimizer, 1078 | 'is_second_order') and optimizer.is_second_order 1079 | batch_time_m = AverageMeter() 1080 | data_time_m = AverageMeter() 1081 | losses_m = AverageMeter() 1082 | top1_m = AverageMeter() 1083 | top5_m = AverageMeter() 1084 | 1085 | model.train() 1086 | 1087 | end = time.time() 1088 | last_idx = len(loader) - 1 1089 | num_updates = epoch * len(loader) 1090 | 1091 | for batch_idx, (input, target) in enumerate(loader): 1092 | last_batch = batch_idx == last_idx 1093 | data_time_m.update(time.time() - end) 1094 | if not args.prefetcher: 1095 | input, target = input.cuda(), target.cuda() 1096 | if mixup_fn is not None: 1097 | input, target = mixup_fn(input, target) 1098 | 1099 | if args.channels_last: 
1100 | input = input.contiguous(memory_format=torch.channels_last) 1101 | 1102 | with amp_autocast(): 1103 | output = model(input) 1104 | loss = loss_fn(output, target) 1105 | if len(target.shape) > 1 and target.shape[1] > 1: 1106 | label = torch.argmax(target, dim=1) 1107 | acc1, acc5 = accuracy(output, label, topk=(1, 5)) 1108 | else: 1109 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 1110 | 1111 | if not args.distributed: 1112 | losses_m.update(loss.item(), input.size(0)) 1113 | top1_m.update(acc1.item(), output.size(0)) 1114 | top5_m.update(acc5.item(), output.size(0)) 1115 | 1116 | optimizer.zero_grad() 1117 | if loss_scaler is not None: 1118 | loss_scaler( 1119 | loss, 1120 | optimizer, 1121 | clip_grad=args.clip_grad, 1122 | clip_mode=args.clip_mode, 1123 | parameters=model_parameters( 1124 | model, exclude_head='agc' in args.clip_mode), 1125 | create_graph=second_order) 1126 | else: 1127 | loss.backward(create_graph=second_order) 1128 | if args.clip_grad is not None: 1129 | dispatch_clip_grad( 1130 | model_parameters( 1131 | model, exclude_head='agc' in args.clip_mode), 1132 | value=args.clip_grad, 1133 | mode=args.clip_mode) 1134 | optimizer.step() 1135 | 1136 | if model_ema is not None: 1137 | model_ema.update(model) 1138 | 1139 | torch.cuda.synchronize() 1140 | num_updates += 1 1141 | batch_time_m.update(time.time() - end) 1142 | if last_batch or batch_idx % args.log_interval == 0: 1143 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 1144 | lr = sum(lrl) / len(lrl) 1145 | 1146 | if args.distributed: 1147 | reduced_loss = reduce_tensor(loss.data, args.world_size) 1148 | losses_m.update(reduced_loss.item(), input.size(0)) 1149 | acc1 = reduce_tensor(acc1, args.world_size) 1150 | acc5 = reduce_tensor(acc5, args.world_size) 1151 | top1_m.update(acc1.item(), output.size(0)) 1152 | top5_m.update(acc5.item(), output.size(0)) 1153 | 1154 | if args.local_rank == 0: 1155 | now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 1156 
| _logger.info( 1157 | '{} ' 1158 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 1159 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 1160 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 1161 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) ' 1162 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 1163 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 1164 | 'LR: {lr:.3e} ' 1165 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 1166 | now, 1167 | epoch, 1168 | batch_idx, 1169 | len(loader), 1170 | 100. * batch_idx / last_idx, 1171 | loss=losses_m, 1172 | top1=top1_m, 1173 | top5=top5_m, 1174 | batch_time=batch_time_m, 1175 | rate=input.size( 1176 | 0) * args.world_size / batch_time_m.val, 1177 | rate_avg=input.size(0) * args.world_size / 1178 | batch_time_m.avg, 1179 | lr=lr, 1180 | data_time=data_time_m)) 1181 | 1182 | if args.save_images and output_dir: 1183 | torchvision.utils.save_image( 1184 | input, 1185 | os.path.join(output_dir, 1186 | 'train-batch-%d.jpg' % batch_idx), 1187 | padding=0, 1188 | normalize=True) 1189 | 1190 | if saver is not None and args.recovery_interval and ( 1191 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 1192 | saver.save_recovery(epoch, batch_idx=batch_idx) 1193 | 1194 | if lr_scheduler is not None: 1195 | lr_scheduler.step_update( 1196 | num_updates=num_updates, metric=losses_m.avg) 1197 | 1198 | end = time.time() 1199 | # end for 1200 | 1201 | if hasattr(optimizer, 'sync_lookahead'): 1202 | optimizer.sync_lookahead() 1203 | 1204 | return OrderedDict([('loss', losses_m.avg)]) 1205 | 1206 | 1207 | def validate(model, 1208 | loader, 1209 | loss_fn, 1210 | args, 1211 | amp_autocast=suppress, 1212 | log_suffix=''): 1213 | batch_time_m = AverageMeter() 1214 | losses_m = AverageMeter() 1215 | top1_m = AverageMeter() 1216 | top5_m = AverageMeter() 1217 | 1218 | model.eval() 1219 | 1220 | end = time.time() 1221 | last_idx = len(loader) - 1 1222 | with torch.no_grad(): 1223 | for batch_idx, (input, target) in enumerate(loader): 
1224 | last_batch = batch_idx == last_idx 1225 | if not args.prefetcher: 1226 | input = input.cuda() 1227 | target = target.cuda() 1228 | if args.channels_last: 1229 | input = input.contiguous(memory_format=torch.channels_last) 1230 | 1231 | with amp_autocast(): 1232 | output = model(input) 1233 | if isinstance(output, (tuple, list)): 1234 | output = output[0] 1235 | 1236 | # augmentation reduction 1237 | reduce_factor = args.tta 1238 | if reduce_factor > 1: 1239 | output = output.unfold(0, reduce_factor, reduce_factor).mean( 1240 | dim=2) 1241 | target = target[0:target.size(0):reduce_factor] 1242 | 1243 | loss = loss_fn(output, target) 1244 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 1245 | 1246 | if args.distributed: 1247 | reduced_loss = reduce_tensor(loss.data, args.world_size) 1248 | acc1 = reduce_tensor(acc1, args.world_size) 1249 | acc5 = reduce_tensor(acc5, args.world_size) 1250 | else: 1251 | reduced_loss = loss.data 1252 | 1253 | torch.cuda.synchronize() 1254 | 1255 | losses_m.update(reduced_loss.item(), input.size(0)) 1256 | top1_m.update(acc1.item(), output.size(0)) 1257 | top5_m.update(acc5.item(), output.size(0)) 1258 | 1259 | batch_time_m.update(time.time() - end) 1260 | end = time.time() 1261 | if args.local_rank == 0 and (last_batch or 1262 | batch_idx % args.log_interval == 0): 1263 | log_name = 'Test' + log_suffix 1264 | _logger.info( 1265 | '{0}: [{1:>4d}/{2}] ' 1266 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 1267 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 1268 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 1269 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 1270 | log_name, 1271 | batch_idx, 1272 | last_idx, 1273 | batch_time=batch_time_m, 1274 | loss=losses_m, 1275 | top1=top1_m, 1276 | top5=top5_m)) 1277 | 1278 | metrics = OrderedDict( 1279 | [('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 1280 | 1281 | return metrics 1282 | 1283 | 1284 | if __name__ == '__main__': 1285 | main() 1286 | 
-------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # Default 2 | # For small models 3 | # bash distributed_train.sh 8 --model repghost.repghostnet_0_5x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ 4 | # bash distributed_train.sh 8 --model repghost.repghostnet_0_58x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ 5 | # bash distributed_train.sh 8 --model repghost.repghostnet_0_8x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ 6 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_0x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ 7 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_11x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ 8 | 9 | 10 | # For large models 11 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_3x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --output 
work_dirs/train/ 12 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_5x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --output work_dirs/train/ 13 | # bash distributed_train.sh 8 --model repghost.repghostnet_2_0x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --output work_dirs/train/ 14 | 15 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | # @Author : chengpeng.chen 2 | # @Email : chencp@live.com 3 | """ 4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong. 
5 | https://arxiv.org/abs/2211.06088 6 | """ 7 | #!/usr/bin/env python3 8 | import argparse 9 | import time 10 | import yaml 11 | import os 12 | import logging 13 | from collections import OrderedDict 14 | from contextlib import suppress 15 | import importlib 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 20 | 21 | from timm.data import create_dataset, create_loader, resolve_data_config 22 | from timm.models import safe_model_name, load_checkpoint, \ 23 | convert_splitbn_model 24 | from timm.utils import * 25 | 26 | 27 | try: 28 | from apex import amp 29 | from apex.parallel import DistributedDataParallel as ApexDDP 30 | from apex.parallel import convert_syncbn_model 31 | 32 | has_apex = True 33 | except ImportError: 34 | has_apex = False 35 | 36 | has_native_amp = False 37 | try: 38 | if getattr(torch.cuda.amp, 'autocast') is not None: 39 | has_native_amp = True 40 | except AttributeError: 41 | pass 42 | 43 | try: 44 | import wandb 45 | 46 | has_wandb = True 47 | except ImportError: 48 | has_wandb = False 49 | 50 | torch.backends.cudnn.benchmark = True 51 | _logger = logging.getLogger('validate') 52 | 53 | # The first arg parser parses out only the --config argument, this argument is used to 54 | # load a yaml file containing key-values that override the defaults for the main parser below 55 | config_parser = parser = argparse.ArgumentParser( 56 | description='Training Config', add_help=False) 57 | parser.add_argument( 58 | '-c', 59 | '--config', 60 | default='', 61 | type=str, 62 | metavar='FILE', 63 | help='YAML config file specifying default arguments') 64 | 65 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Test') 66 | 67 | # Dataset / Model parameters 68 | parser.add_argument( 69 | '--data_dir', 70 | metavar='DIR', 71 | default='/disk2/datasets/imagenet', 72 | help='path to dataset') 73 | parser.add_argument( 74 | '--dataset', 75 | '-d', 76 | metavar='NAME', 77 | 
default='', 78 | help='dataset type (default: ImageFolder/ImageTar if empty)') 79 | parser.add_argument( 80 | '--train-split', 81 | metavar='NAME', 82 | default='train', 83 | help='dataset train split (default: train)') 84 | parser.add_argument( 85 | '--val-split', 86 | metavar='NAME', 87 | default='val', 88 | help='dataset validation split (default: val)') 89 | parser.add_argument( 90 | '--model', 91 | default='', 92 | type=str, 93 | metavar='MODEL', 94 | help='Name of model to train (default: "countception"') 95 | parser.add_argument( 96 | '--resume', 97 | default='', 98 | type=str, 99 | metavar='PATH', 100 | help='Resume full model and optimizer state from checkpoint (default: none)' 101 | ) 102 | parser.add_argument( 103 | '--num-classes', 104 | type=int, 105 | default=1000, 106 | metavar='N', 107 | help='number of label classes (Model default if None)') 108 | parser.add_argument( 109 | '--gp', 110 | default=None, 111 | type=str, 112 | metavar='POOL', 113 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.' 114 | ) 115 | parser.add_argument( 116 | '--img-size', 117 | type=int, 118 | default=None, 119 | metavar='N', 120 | help='Image patch size (default: None => model default)') 121 | parser.add_argument( 122 | '--input-size', 123 | default=None, 124 | nargs=3, 125 | type=int, 126 | metavar='N N N', 127 | help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty' 128 | ) 129 | parser.add_argument( 130 | '--crop-pct', 131 | default=None, 132 | type=float, 133 | metavar='N', 134 | help='Input image center crop percent (for validation only)') 135 | parser.add_argument( 136 | '--mean', 137 | type=float, 138 | nargs='+', 139 | default=None, 140 | metavar='MEAN', 141 | help='Override mean pixel value of dataset') 142 | parser.add_argument( 143 | '--std', 144 | type=float, 145 | nargs='+', 146 | default=None, 147 | metavar='STD', 148 | help='Override std deviation of of dataset') 149 | parser.add_argument( 150 | '--interpolation', 151 | default='', 152 | type=str, 153 | metavar='NAME', 154 | help='Image resize interpolation type (overrides model)') 155 | parser.add_argument( 156 | '-b', 157 | '--batch-size', 158 | type=int, 159 | default=32, 160 | metavar='N', 161 | help='input batch size for training (default: 32)') 162 | parser.add_argument( 163 | '-vb', 164 | '--validation-batch-size-multiplier', 165 | type=int, 166 | default=1, 167 | metavar='N', 168 | help='ratio of validation batch size to training batch size (default: 1)') 169 | parser.add_argument( 170 | '--aug-splits', 171 | type=int, 172 | default=0, 173 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 174 | parser.add_argument( 175 | '--drop', 176 | type=float, 177 | default=0.2, 178 | metavar='PCT', 179 | help='Dropout rate (default: 0.)') 180 | parser.add_argument( 181 | '--drop-connect', 182 | type=float, 183 | default=None, 184 | metavar='PCT', 185 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 186 | parser.add_argument( 187 | '--drop-path', 188 | type=float, 189 | default=None, 190 | metavar='PCT', 191 | help='Drop path rate (default: None)') 192 | parser.add_argument( 193 | '--drop-block', 194 | type=float, 195 | default=None, 196 | metavar='PCT', 197 | help='Drop block rate (default: None)') 198 | 199 | # Batch norm parameters (only works with gen_efficientnet based 
models currently) 200 | parser.add_argument( 201 | '--bn-tf', 202 | action='store_true', 203 | default=False, 204 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)' 205 | ) 206 | parser.add_argument( 207 | '--bn-momentum', 208 | type=float, 209 | default=None, 210 | help='BatchNorm momentum override (if not None)') 211 | parser.add_argument( 212 | '--bn-eps', 213 | type=float, 214 | default=None, 215 | help='BatchNorm epsilon override (if not None)') 216 | parser.add_argument( 217 | '--sync-bn', 218 | action='store_true', 219 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 220 | parser.add_argument( 221 | '--dist-bn', 222 | type=str, 223 | default='', 224 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")' 225 | ) 226 | parser.add_argument( 227 | '--split-bn', 228 | action='store_true', 229 | help='Enable separate BN layers per augmentation split.') 230 | 231 | # Model Exponential Moving Average 232 | parser.add_argument( 233 | '--model-ema', 234 | action='store_true', 235 | default=False, 236 | help='Enable tracking moving average of model weights') 237 | parser.add_argument( 238 | '--model-ema-force-cpu', 239 | action='store_true', 240 | default=False, 241 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.' 
242 | ) 243 | parser.add_argument( 244 | '--model-ema-decay', 245 | type=float, 246 | default=0.9998, 247 | help='decay factor for model weights moving average (default: 0.9998)') 248 | 249 | # Misc 250 | parser.add_argument( 251 | '--seed', 252 | type=int, 253 | default=42, 254 | metavar='S', 255 | help='random seed (default: 42)') 256 | parser.add_argument( 257 | '--log-interval', 258 | type=int, 259 | default=50, 260 | metavar='N', 261 | help='how many batches to wait before logging training status') 262 | parser.add_argument( 263 | '--recovery-interval', 264 | type=int, 265 | default=0, 266 | metavar='N', 267 | help='how many batches to wait before writing recovery checkpoint') 268 | parser.add_argument( 269 | '--checkpoint-hist', 270 | type=int, 271 | default=10, 272 | metavar='N', 273 | help='number of checkpoints to keep (default: 10)') 274 | parser.add_argument( 275 | '-j', 276 | '--workers', 277 | type=int, 278 | default=4, 279 | metavar='N', 280 | help='how many training processes to use (default: 1)') 281 | parser.add_argument( 282 | '--save-images', 283 | action='store_true', 284 | default=False, 285 | help='save images of input bathes every log interval for debugging') 286 | parser.add_argument( 287 | '--amp', 288 | action='store_true', 289 | default=False, 290 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 291 | parser.add_argument( 292 | '--apex-amp', 293 | action='store_true', 294 | default=False, 295 | help='Use NVIDIA Apex AMP mixed precision') 296 | parser.add_argument( 297 | '--native-amp', 298 | action='store_true', 299 | default=False, 300 | help='Use Native Torch AMP mixed precision') 301 | parser.add_argument( 302 | '--channels-last', 303 | action='store_true', 304 | default=False, 305 | help='Use channels_last memory layout') 306 | parser.add_argument( 307 | '--pin-mem', 308 | action='store_true', 309 | default=False, 310 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.' 
311 | ) 312 | parser.add_argument( 313 | '--no-prefetcher', 314 | action='store_true', 315 | default=False, 316 | help='disable fast prefetcher') 317 | parser.add_argument( 318 | '--output', 319 | default='./', 320 | type=str, 321 | metavar='PATH', 322 | help='path to output folder (default: none, current dir)') 323 | parser.add_argument( 324 | '--experiment', 325 | default='', 326 | type=str, 327 | metavar='NAME', 328 | help='name of train experiment, name of sub-folder for output') 329 | parser.add_argument( 330 | '--eval-metric', 331 | default='top1', 332 | type=str, 333 | metavar='EVAL_METRIC', 334 | help='Best metric (default: "top1"') 335 | parser.add_argument( 336 | '--tta', 337 | type=int, 338 | default=0, 339 | metavar='N', 340 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)' 341 | ) 342 | parser.add_argument("--local_rank", default=0, type=int) 343 | parser.add_argument( 344 | '--use-multi-epochs-loader', 345 | action='store_true', 346 | default=False, 347 | help='use the multi-epochs-loader to save time at the beginning of every epoch' 348 | ) 349 | parser.add_argument( 350 | '--torchscript', 351 | dest='torchscript', 352 | action='store_true', 353 | help='convert model torchscript for inference') 354 | parser.add_argument( 355 | '--log-wandb', 356 | action='store_true', 357 | default=False, 358 | help='log training and validation metrics to wandb') 359 | 360 | 361 | def _parse_args(): 362 | # Do we have a config file to parse? 363 | args_config, remaining = config_parser.parse_known_args() 364 | if args_config.config: 365 | with open(args_config.config, 'r') as f: 366 | cfg = yaml.safe_load(f) 367 | parser.set_defaults(**cfg) 368 | 369 | # The main arg parser parses the rest of the args, the usual 370 | # defaults will have been overridden if config file specified. 
371 | args = parser.parse_args(remaining) 372 | 373 | # Cache the args as a text string to save them in the output dir later 374 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 375 | return args, args_text 376 | 377 | 378 | def main(): 379 | setup_default_logging() 380 | args, args_text = _parse_args() 381 | 382 | if args.log_wandb: 383 | if has_wandb: 384 | wandb.init(project=args.experiment, config=args) 385 | else: 386 | _logger.warning( 387 | "You've requested to log metrics to wandb but package not found. " 388 | "Metrics not being logged to wandb, try `pip install wandb`") 389 | 390 | args.prefetcher = not args.no_prefetcher 391 | args.distributed = False 392 | if 'WORLD_SIZE' in os.environ: 393 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 394 | args.device = 'cuda:0' 395 | args.world_size = 1 396 | args.rank = 0 # global rank 397 | if args.distributed: 398 | args.device = 'cuda:%d' % args.local_rank 399 | torch.cuda.set_device(args.local_rank) 400 | torch.distributed.init_process_group( 401 | backend='nccl', init_method='env://') 402 | args.world_size = torch.distributed.get_world_size() 403 | args.rank = torch.distributed.get_rank() 404 | _logger.info( 405 | 'Testing in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 
406 | % (args.rank, args.world_size)) 407 | else: 408 | _logger.info('Testing with a single process on 1 GPUs.') 409 | assert args.rank >= 0 410 | 411 | # resolve AMP arguments based on PyTorch / Apex availability 412 | use_amp = None 413 | if args.amp: 414 | # `--amp` chooses native amp before apex (APEX ver not actively maintained) 415 | if has_native_amp: 416 | args.native_amp = True 417 | elif has_apex: 418 | args.apex_amp = True 419 | if args.apex_amp and has_apex: 420 | use_amp = 'apex' 421 | elif args.native_amp and has_native_amp: 422 | use_amp = 'native' 423 | elif args.apex_amp or args.native_amp: 424 | _logger.warning( 425 | "Neither APEX or native Torch AMP is available, using float32. " 426 | "Install NVIDA apex or upgrade to PyTorch 1.6") 427 | 428 | random_seed(args.seed, args.rank) 429 | 430 | m = importlib.import_module(f"model.{args.model.split('.')[0]}") 431 | model = getattr(m, args.model.split('.')[1])(dropout=args.drop) 432 | 433 | if args.num_classes is None: 434 | assert hasattr( 435 | model, 'num_classes' 436 | ), 'Model must have `num_classes` attr if not set on cmd line/config.' 
437 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 438 | 439 | if args.local_rank == 0: 440 | _logger.info( 441 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}' 442 | ) 443 | 444 | data_config = resolve_data_config( 445 | vars(args), model=model, verbose=args.local_rank == 0) 446 | 447 | # setup augmentation batch splits for contrastive loss or split bn 448 | num_aug_splits = 0 449 | if args.aug_splits > 0: 450 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 451 | num_aug_splits = args.aug_splits 452 | 453 | # enable split bn (separate bn stats per batch-portion) 454 | if args.split_bn: 455 | assert num_aug_splits > 1 or args.resplit 456 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 457 | 458 | # move model to GPU, enable channels last layout if set 459 | model.cuda() 460 | if args.channels_last: 461 | model = model.to(memory_format=torch.channels_last) 462 | 463 | # setup synchronized BatchNorm for distributed training 464 | if args.distributed and args.sync_bn: 465 | assert not args.split_bn 466 | if has_apex and use_amp != 'native': 467 | # Apex SyncBN preferred unless native amp is activated 468 | model = convert_syncbn_model(model) 469 | else: 470 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 471 | if args.local_rank == 0: 472 | _logger.info( 473 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 474 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' 
475 | ) 476 | 477 | if args.torchscript: 478 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 479 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 480 | model = torch.jit.script(model) 481 | 482 | amp_autocast = suppress # do nothing 483 | 484 | load_checkpoint(model, args.resume, use_ema=False) 485 | 486 | # setup exponential moving average of model weights, SWA could be used here too 487 | model_ema = None 488 | if args.model_ema: 489 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 490 | model_ema = ModelEmaV2( 491 | model, 492 | decay=args.model_ema_decay, 493 | device='cpu' if args.model_ema_force_cpu else None) 494 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 495 | 496 | # setup distributed training 497 | if args.distributed: 498 | if has_apex and use_amp != 'native': 499 | # Apex DDP preferred unless native amp is activated 500 | if args.local_rank == 0: 501 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 502 | model = ApexDDP(model, delay_allreduce=True) 503 | else: 504 | if args.local_rank == 0: 505 | _logger.info("Using native Torch DistributedDataParallel.") 506 | model = NativeDDP( 507 | model, device_ids=[args.local_rank 508 | ]) # can use device str in Torch >= 1.1 509 | # NOTE: EMA model does not need to be wrapped by DDP 510 | 511 | dataset_eval = create_dataset( 512 | args.dataset, 513 | root=args.data_dir, 514 | split=args.val_split, 515 | is_training=False, 516 | batch_size=args.batch_size) 517 | 518 | loader_eval = create_loader( 519 | dataset_eval, 520 | input_size=data_config['input_size'], 521 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 522 | is_training=False, 523 | use_prefetcher=args.prefetcher, 524 | interpolation=data_config['interpolation'], 525 | mean=data_config['mean'], 526 | std=data_config['std'], 527 | num_workers=args.workers, 528 | distributed=args.distributed, 
529 | crop_pct=data_config['crop_pct'], 530 | pin_memory=args.pin_mem, ) 531 | 532 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 533 | 534 | eval_metrics = validate( 535 | model, 536 | loader_eval, 537 | validate_loss_fn, 538 | args, 539 | amp_autocast=amp_autocast) 540 | if model_ema is not None and not args.model_ema_force_cpu: 541 | if args.distributed and args.dist_bn in ('broadcast', 'reduce' 542 | ): 543 | distribute_bn(model_ema, args.world_size, 544 | args.dist_bn == 'reduce') 545 | ema_eval_metrics = validate( 546 | model_ema.module, 547 | loader_eval, 548 | validate_loss_fn, 549 | args, 550 | amp_autocast=amp_autocast, 551 | log_suffix=' (EMA)') 552 | eval_metrics = ema_eval_metrics 553 | 554 | if args.rank == 0: 555 | print('eval_metrics = {}'.format(eval_metrics)) 556 | time.sleep(3) 557 | 558 | 559 | def validate(model, 560 | loader, 561 | loss_fn, 562 | args, 563 | amp_autocast=suppress, 564 | log_suffix=''): 565 | batch_time_m = AverageMeter() 566 | losses_m = AverageMeter() 567 | top1_m = AverageMeter() 568 | top5_m = AverageMeter() 569 | 570 | model.eval() 571 | 572 | end = time.time() 573 | last_idx = len(loader) - 1 574 | with torch.no_grad(): 575 | for batch_idx, (input, target) in enumerate(loader): 576 | last_batch = batch_idx == last_idx 577 | if not args.prefetcher: 578 | input = input.cuda() 579 | target = target.cuda() 580 | if args.channels_last: 581 | input = input.contiguous(memory_format=torch.channels_last) 582 | 583 | with amp_autocast(): 584 | output = model(input) 585 | if isinstance(output, (tuple, list)): 586 | output = output[0] 587 | 588 | # augmentation reduction 589 | reduce_factor = args.tta 590 | if reduce_factor > 1: 591 | output = output.unfold(0, reduce_factor, reduce_factor).mean( 592 | dim=2) 593 | target = target[0:target.size(0):reduce_factor] 594 | 595 | loss = loss_fn(output, target) 596 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 597 | 598 | if args.distributed: 599 | reduced_loss = 
reduce_tensor(loss.data, args.world_size) 600 | acc1 = reduce_tensor(acc1, args.world_size) 601 | acc5 = reduce_tensor(acc5, args.world_size) 602 | else: 603 | reduced_loss = loss.data 604 | 605 | torch.cuda.synchronize() 606 | 607 | losses_m.update(reduced_loss.item(), input.size(0)) 608 | top1_m.update(acc1.item(), output.size(0)) 609 | top5_m.update(acc5.item(), output.size(0)) 610 | 611 | batch_time_m.update(time.time() - end) 612 | end = time.time() 613 | if args.local_rank == 0 and (last_batch or 614 | batch_idx % args.log_interval == 0): 615 | log_name = 'Test' + log_suffix 616 | _logger.info( 617 | '{0}: [{1:>4d}/{2}] ' 618 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 619 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 620 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 621 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 622 | log_name, 623 | batch_idx, 624 | last_idx, 625 | batch_time=batch_time_m, 626 | loss=losses_m, 627 | top1=top1_m, 628 | top5=top5_m)) 629 | 630 | metrics = OrderedDict( 631 | [('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 632 | 633 | return metrics 634 | 635 | 636 | if __name__ == '__main__': 637 | main() 638 | -------------------------------------------------------------------------------- /work_dirs/train/readme.md: -------------------------------------------------------------------------------- 1 | We report the results based on the released models in [GoogleDrive](https://drive.google.com/drive/folders/1aL5UkhXgevyoQDo_cLmmd-DUfZcAFRXu?usp=share_link) and [百度网盘](https://pan.baidu.com/s/1yz7IdlagAL8LMf_NvjrbZw?pwd=qy7c), as in ```eval.log```. 2 | 3 | Note that the results differ to that in ```train.log``` slightly, this is because of the BN statistics. 4 | During training, the models are evaluated using 8 GPUs, in which each GPU has its own BN statistics, 5 | and we only save one model in GPU-0 as checkpoint, i.e., our released model. 
6 | 7 | 8 | | RepGhostNet | Params(M) | FLOPs(M) | Latency(ms) | Top-1 Acc.(%) | Top-5 Acc.(%) | checkpoints | logs | 9 | |:------------|:----------|:---------|:------------|:--------------|:--------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------| 10 | | 0.5x | 2.3 | 43 | 25.1 | 66.9 | 86.9 | [gdrive](https://drive.google.com/file/d/16AGg-kSscFXDpXPZ3cJpYwqeZbUlUoyr/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1s-tuS8JoHVoCVHWuUiHUFw?pwd=qttp) | [log](./repghostnet_0_5x_43M_66.95/train.log) | 11 | | 0.58x | 2.5 | 60 | 31.9 | 68.9 | 88.4 | [gdrive](https://drive.google.com/file/d/1L6ccPjfnCMt5YK-pNFDfqGYvJyTRyZPR/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1bnVk2ILONqPEbmTQahZ0Og?pwd=tiyw) | [log](./repghostnet_0_58x_60M_68.94/train.log) | 12 | | 0.8x | 3.3 | 96 | 44.5 | 72.2 | 90.5 | [gdrive](https://drive.google.com/file/d/13gmUpwiJF_O05f3-3UeEyKD57veL5cG-/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1L_EJ0CnQeGpd0QBoOiY7oQ?pwd=rkd8) | [log](./repghostnet_0_8x_96M_72.24/train.log) | 13 | | 1.0x | 4.1 | 142 | 62.2 | 74.2 | 91.5 | [gdrive](https://drive.google.com/file/d/1gzfGln60urfY38elpPHVTyv9b94ukn5o/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1CEwuBLV05z7zrVbBrku59w?pwd=z4s7) | [log](./repghostnet_1_0x_142M_74.22/train.log) | 14 | | 1.11x | 4.5 | 170 | 71.5 | 75.1 | 92.2 | [gdrive](https://drive.google.com/file/d/14Lk4pKWIUFk1Mb53ooy_GsZbhMmz3iVE/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1Lb54Jiqyt0Jc6X4F_tUYnw?pwd=dwcb) | [log](./repghostnet_1_11x_170M_75.07/train.log) | 15 | | 1.3x | 5.5 | 231 | 92.9 | 76.4 | 92.9 | [gdrive](https://drive.google.com/file/d/1dNHpX2JyiuTcDmmyvr8gnAI9t8RM-Nui/view?usp=share_link) \
[百度网盘](https://pan.baidu.com/s/19x_OUgxRDvwh2g4E9gN12Q?pwd=uux6) | [log](./repghostnet_1_3x_231M_76.37/train.log) | 16 | | 1.5x | 6.6 | 301 | 116.9 | 77.5 | 93.5 | [gdrive](https://drive.google.com/file/d/1TWAY654Dz8zcwhDBDN6QDWhV7as30P8e/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/15UWOMRQN5vw99QbgiFWMRw?pwd=3uqq) | [log](./repghostnet_1_5x_301M_77.45/train.log) | 17 | | 2.0x | 9.8 | 516 | 190.0 | 78.8 | 94.3 | [gdrive](https://drive.google.com/file/d/12k00eWCXhKxx_fq3ewDhCNX08ftJ-iyP/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1YbtYvIBt3tTqCzvbcjJuBw?pwd=nq1r) | [log](./repghostnet_2_0x_516M_78.81/train.log) | 18 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_0_58x_60M_68.94/args.yaml: -------------------------------------------------------------------------------- 1 | aa: null 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 |
mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_0_58x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train/ 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_0_58x_60M_68.94/eval.log: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_0_58x --resume work_dirs/train/repghostnet_0_58x_60M_68.94/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 1.3653137989139557), ('top1', 68.944), ('top5', 88.422)]) 3 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_0_5x_43M_66.95/args.yaml: -------------------------------------------------------------------------------- 1 | aa: null 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | 
clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_0_5x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | 
-------------------------------------------------------------------------------- /work_dirs/train/repghostnet_0_5x_43M_66.95/eval.log: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_0_5x --resume work_dirs/train/repghostnet_0_5x_43M_66.95/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 1.466918696756363), ('top1', 66.946), ('top5', 86.928)]) 3 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_0_8x_96M_72.24/args.yaml: -------------------------------------------------------------------------------- 1 | aa: null 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_0_8x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | 
no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train/ 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_0_8x_96M_72.24/eval.log: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_0_8x --resume work_dirs/train/repghostnet_0_8x_96M_72.24/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 1.2117804213762284), ('top1', 72.238), ('top5', 90.486)]) 3 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_0x_142M_74.22/args.yaml: -------------------------------------------------------------------------------- 1 | aa: null 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | 
decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_1_0x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train/ 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_0x_142M_74.22/eval.log: -------------------------------------------------------------------------------- 1 | # 
CUDA_VISIBLE_DEVICE=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_1_0x --resume work_dirs/train/repghostnet_1_0x_142M_74.22/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 1.117270246348381), ('top1', 74.218), ('top5', 91.548)]) 3 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_11x_170M_75.07/args.yaml: -------------------------------------------------------------------------------- 1 | aa: null 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_1_11x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train/ 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 
| ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_11x_170M_75.07/eval.log: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICE=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_1_11x --resume work_dirs/train/repghostnet_1_11x_170M_75.07/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 1.0745367737579345), ('top1', 75.066), ('top5', 92.186)]) 3 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_3x_231M_76.37/args.yaml: -------------------------------------------------------------------------------- 1 | aa: rand-m9-mstd0.5 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: 
'' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_1_3x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train/ 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_3x_231M_76.37/eval.log: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_1_3x --resume work_dirs/train/repghostnet_1_3x_231M_76.37/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 1.0206603574037552), 
('top1', 76.374), ('top5', 92.898)]) 3 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_5x_301M_77.45/args.yaml: -------------------------------------------------------------------------------- 1 | aa: rand-m9-mstd0.5 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_1_5x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 
1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_1_5x_301M_77.45/eval.log: -------------------------------------------------------------------------------- 1 | eval_metrics = OrderedDict([('loss', 0.9776309717464448), ('top1', 77.454), ('top5', 93.5)]) 2 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_2_0x_516M_78.81/args.yaml: -------------------------------------------------------------------------------- 1 | aa: rand-m9-mstd0.5 2 | amp: true 3 | apex_amp: false 4 | aug_splits: 0 5 | batch_size: 128 6 | bn_eps: null 7 | bn_momentum: null 8 | bn_tf: false 9 | channels_last: false 10 | checkpoint_hist: 5 11 | clip_grad: null 12 | clip_mode: norm 13 | color_jitter: 0.4 14 | cooldown_epochs: 0 15 | crop_pct: null 16 | cutmix: 0.0 17 | cutmix_minmax: null 18 | data_dir: /disk2/datasets/imagenet 19 | dataset: '' 20 | decay_epochs: 30 21 | decay_rate: 0.1 22 | dist_bn: '' 23 | drop: 0.2 24 | drop_block: null 25 | drop_connect: null 26 | drop_path: null 27 | epoch_repeats: 0.0 28 | epochs: 300 29 | eval_metric: top1 30 | experiment: '' 31 | gp: null 32 | hflip: 0.5 33 | img_size: null 34 | initial_checkpoint: '' 35 | input_size: null 36 | interpolation: '' 37 | jsd: false 38 | local_rank: 0 39 | log_interval: 50 40 | log_wandb: false 41 | lr: 0.6 42 | lr_cycle_limit: 1 43 | lr_cycle_mul: 1.0 44 | lr_noise: null 45 | lr_noise_pct: 0.67 46 | lr_noise_std: 1.0 47 | mean: null 48 | min_lr: 1.0e-05 49 | mixup: 0.0 50 | 
mixup_mode: batch 51 | mixup_off_epoch: 0 52 | mixup_prob: 1.0 53 | mixup_switch_prob: 0.5 54 | model: repghost.repghostnet_2_0x 55 | model_ema: true 56 | model_ema_decay: 0.9999 57 | model_ema_force_cpu: false 58 | momentum: 0.9 59 | native_amp: false 60 | no_aug: false 61 | no_prefetcher: false 62 | no_resume_opt: false 63 | num_classes: 1000 64 | opt: sgd 65 | opt_betas: null 66 | opt_eps: null 67 | output: work_dirs/train 68 | patience_epochs: 10 69 | pin_mem: false 70 | pretrained: false 71 | ratio: 72 | - 0.75 73 | - 1.3333333333333333 74 | recount: 1 75 | recovery_interval: 0 76 | remode: pixel 77 | reprob: 0.2 78 | resplit: false 79 | resume: '' 80 | save_images: false 81 | scale: 82 | - 0.08 83 | - 1.0 84 | sched: cosine 85 | seed: 42 86 | smoothing: 0.1 87 | split_bn: false 88 | start_epoch: null 89 | std: null 90 | sync_bn: false 91 | torchscript: false 92 | train_interpolation: random 93 | train_split: train 94 | tta: 0 95 | use_multi_epochs_loader: false 96 | val_split: val 97 | validation_batch_size_multiplier: 1 98 | vflip: 0.0 99 | warmup_epochs: 5 100 | warmup_lr: 0.0001 101 | weight_decay: 1.0e-05 102 | workers: 7 103 | -------------------------------------------------------------------------------- /work_dirs/train/repghostnet_2_0x_516M_78.81/eval.log: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_2_0x --resume work_dirs/train/repghostnet_2_0x_516M_78.81/model_best.pth.tar 2 | eval_metrics = OrderedDict([('loss', 0.897983897819519), ('top1', 78.806), ('top5', 94.326)]) 3 | --------------------------------------------------------------------------------