├── LICENSE
├── convert.py
├── distributed_train.sh
├── figs
│   ├── accuracy_to_latency.png
│   ├── accuracy_to_latency_iphone12cpu.png
│   └── repghost_bottleneck.png
├── infotool
│   ├── __init__.py
│   ├── fx_profile.c
│   ├── fx_profile.py
│   ├── helper.py
│   ├── profile.py
│   ├── rnn_hooks.py
│   └── vision
│       ├── __init__.py
│       ├── basic_hooks.py
│       ├── counter.py
│       ├── efficientnet.py
│       └── onnx_counter.py
├── model
│   ├── __init__.py
│   └── repghost.py
├── readme.md
├── requirements.txt
├── tools.py
├── train.py
├── train.sh
├── validate.py
└── work_dirs
    └── train
        ├── readme.md
        ├── repghostnet_0_58x_60M_68.94
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        ├── repghostnet_0_5x_43M_66.95
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        ├── repghostnet_0_8x_96M_72.24
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        ├── repghostnet_1_0x_142M_74.22
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        ├── repghostnet_1_11x_170M_75.07
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        ├── repghostnet_1_3x_231M_76.37
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        ├── repghostnet_1_5x_301M_77.45
        │   ├── args.yaml
        │   ├── eval.log
        │   └── train.log
        └── repghostnet_2_0x_516M_78.81
            ├── args.yaml
            ├── eval.log
            └── train.log
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ChengpengChen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | # @Author : chengpeng.chen
2 | # @Email : chencp@live.com
3 | """
4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong.
5 | https://arxiv.org/abs/2211.06088
6 | """
7 | import argparse
8 | import os
9 | import importlib
10 | import torch
11 | import torch.nn.parallel
12 | import torch.optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | from model.repghost import repghost_model_convert
16 |
# Command-line interface of the conversion script. ``--model`` names the model
# as "<module under model/>.<factory function>", e.g. "repghost.repghostnet_0_5x"
# (the package ships model/repghost.py).
parser = argparse.ArgumentParser(description='RepGhost Conversion for Inference')
parser.add_argument('load', metavar='LOAD', help='path to the training-time weights file to convert')
parser.add_argument('save', metavar='SAVE', help='path where the converted (inference) weights are saved')
parser.add_argument('-m', '--model', metavar='model', default='repghost.repghostnet_0_5x',
                    help='model to build, as "<module under model/>.<factory function>"')
parser.add_argument('--ema-model', '--ema', action='store_true', help='to load the ema model')
parser.add_argument('--sanity_check', '-c', action='store_true', help='to check the outputs of the models')
23 |
24 |
def convert():
    """Build a RepGhost model, optionally load weights, re-parameterize it and save it.

    Options come from the module-level ``parser``:

    * ``--model`` ("<module>.<factory>") selects ``model.<module>.<factory>()``.
    * ``load``: if this file exists, its weights are loaded. The EMA weights
      (``state_dict_ema``) are used when ``--ema`` is given and present,
      otherwise ``state_dict``; a file that holds a bare state dict is also accepted.
    * ``save`` is passed as ``save_path`` to ``repghost_model_convert``.
    * ``--sanity_check`` compares the outputs of the training-time and
      converted models on random input and prints the squared-error sum.
    """
    args = parser.parse_args()

    model_parts = args.model.split('.')
    m = importlib.import_module(f"model.{model_parts[0]}")
    train_model = getattr(m, model_parts[1])()
    train_model.eval()

    if os.path.isfile(args.load):
        print("=> loading checkpoint '{}'".format(args.load))
        checkpoint = torch.load(args.load, map_location='cpu')
        if args.ema_model and 'state_dict_ema' in checkpoint:
            state_dict = checkpoint['state_dict_ema']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            # NOTE(review): assumes the file itself is a bare state dict — confirm
            # for any other checkpoint layouts in use.
            state_dict = checkpoint

        try:
            train_model.load_state_dict(state_dict)
        except RuntimeError:
            # Key mismatch: retry with a leading 'module.' stripped from every key
            # (e.g. weights saved from a DataParallel/DDP wrapper). Only the prefix
            # is removed so names that merely contain 'module.' stay intact.
            prefix = 'module.'
            stripped = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
            train_model.load_state_dict(stripped)
    else:
        print("=> no checkpoint found at '{}'".format(args.load))

    infer_model = repghost_model_convert(train_model, save_path=args.save)
    print("=> saved checkpoint to '{}'".format(args.save))

    if args.sanity_check:
        data = torch.randn(5, 3, 224, 224)
        # Inference only: no autograd graph is needed for the comparison.
        with torch.no_grad():
            out = train_model(data)
            out2 = infer_model(data)
        print('=> The diff is', ((out - out2) ** 2).sum())
57 |
58 |
if __name__ == "__main__":
    # Script entry point; every option is read from the command line.
    convert()
61 |
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch train.py on NUM_PROC local processes via torch.distributed.launch.
# Usage: ./distributed_train.sh NUM_PROC [train.py args...]
# The rendezvous port defaults to 2345; override it with MASTER_PORT=... when
# several jobs share a machine.
if [ "$#" -lt 1 ]; then
    echo "Usage: $0 NUM_PROC [train.py args...]" >&2
    exit 1
fi
NUM_PROC=$1
shift
python3 -m torch.distributed.launch --nproc_per_node="$NUM_PROC" --master_port="${MASTER_PORT:-2345}" train.py "$@"
--------------------------------------------------------------------------------
/figs/accuracy_to_latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/figs/accuracy_to_latency.png
--------------------------------------------------------------------------------
/figs/accuracy_to_latency_iphone12cpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/figs/accuracy_to_latency_iphone12cpu.png
--------------------------------------------------------------------------------
/figs/repghost_bottleneck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/figs/repghost_bottleneck.png
--------------------------------------------------------------------------------
/infotool/__init__.py:
--------------------------------------------------------------------------------
1 | from .helper import clever_format
2 | from .profile import profile, profile_origin
3 | import torch
4 |
# Double precision for op/parameter counters (infotool/profile.py defines the same value).
default_dtype = torch.float64
6 |
--------------------------------------------------------------------------------
/infotool/fx_profile.c:
--------------------------------------------------------------------------------
1 | #error Do not use this file, it is the result of a failed Cython compilation.
2 |
--------------------------------------------------------------------------------
/infotool/fx_profile.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch as th
3 | import torch.nn as nn
4 | from distutils.version import LooseVersion
5 |
import logging

# Warn early on old PyTorch: the torch.fx import further down in this module
# needs PyTorch >= 1.8.
if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
    logging.warning(
        "torch.fx requires version higher than 1.8.0. "
        f"But You are using an old version PyTorch {torch.__version__}. "
    )
11 |
12 |
def count_clamp(input_shapes, output_shapes, *args, **kwargs):
    """FLOPs rule for ``clamp``: treated as free, always 0.

    Extra positional/keyword arguments are accepted and ignored so the rule has
    the same call signature as the other ``count_map`` entries, which may be
    called with the traced call's arguments appended.
    """
    return 0
15 |
16 |
def count_mul(input_shapes, output_shapes, *args, **kwargs):
    """FLOPs rule for element-wise binary ops (``mul``, ``truediv``): one op per output element.

    ``fx_profile`` calls call_function rules as
    ``rule(input_shapes, output_shapes, *node.args, **node.kwargs)``, so the
    operands of ``a * b`` / ``a / b`` arrive as extra arguments; they are
    accepted and ignored.
    """
    return output_shapes[0].numel()
20 |
21 |
def count_matmul(input_shapes, output_shapes):
    """Multiply-accumulate count of a matrix product.

    Every output element reduces over the last dimension of the first input,
    so the cost is ``input_shapes[0][-1] * output_shapes[0].numel()``.
    """
    reduce_dim = input_shapes[0][-1]
    produced = output_shapes[0].numel()
    return reduce_dim * produced
28 |
29 |
def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
    """FLOPs rule for functional ``linear``: the matmul multiply-accumulates only.

    The bias addition is not counted, matching ``count_nn_linear`` for
    ``nn.Linear`` modules. The traced call's positional/keyword arguments are
    accepted and ignored.
    """
    return count_matmul(input_shapes, output_shapes)
35 |
36 |
37 | from .vision.counter import counter_conv
38 |
39 |
def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
    """FLOPs rule for functional ``conv2d`` calls found by torch.fx.

    Args:
        input_shapes: shapes of the traced tensor arguments in call order; the
            first two are the input and the kernel (a third, the bias, is ignored).
        output_shapes: list whose first entry is the output shape.
        *args / **kwargs: the traced call's arguments, laid out as
            ``conv2d(input, weight, bias, stride, padding, dilation, groups)``.
            Only ``groups`` is read; it defaults to 1 when not supplied.

    Returns:
        int: operation count from ``counter_conv`` (bias not counted).

    Raises:
        ValueError: if fewer than two input shapes (input and kernel) are known.
    """
    if len(input_shapes) < 2:
        raise ValueError(
            "conv2d needs the input and kernel shapes, got %d shape(s)" % len(input_shapes)
        )
    x_shape, k_shape = input_shapes[0], input_shapes[1]
    out_shape = output_shapes[0]

    # ``groups`` is conv2d's 7th positional argument; a traced call may also pass
    # it by keyword or rely on its default.
    if len(args) > 6:
        groups = args[6]
    else:
        groups = kwargs.get("groups", 1)

    kernel_parameters = k_shape[2:].numel()  # spatial kernel taps, kh * kw
    bias_op = 0  # TODO: bias is not counted yet
    in_channel = x_shape[1]

    total_ops = counter_conv(
        bias_op, kernel_parameters, out_shape.numel(), in_channel, groups
    ).item()
    return int(total_ops)
56 |
57 |
def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
    """FLOPs rule for ``nn.Linear`` modules: matmul multiply-accumulates, bias ignored."""
    macs = count_matmul(input_shapes, output_shapes)
    return macs
60 |
61 |
def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
    """Rule for operations treated as free; ignores every argument and returns 0."""
    free = 0
    return free
64 |
65 |
def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
    """FLOPs rule for ``nn.Conv2d`` modules; the formula itself lives in ``counter_conv``.

    Passes the bias flag (1/0), the spatial kernel size, the output element count,
    ``in_channels`` and ``groups`` to ``counter_conv`` and returns the result as int.
    """
    first_output = output_shapes[0]
    has_bias = module.bias is not None
    ops = counter_conv(
        1 if has_bias else 0,
        module.weight.shape[2:].numel(),  # kh * kw
        first_output.numel(),
        module.in_channels,
        module.groups,
    )
    return int(ops.item())
77 |
78 |
def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
    """FLOPs rule for ``nn.BatchNorm2d``: two ops per output element.

    y = (x - mean) / sqrt(var + eps) * weight + bias is counted as one
    multiply and one add per element.
    """
    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
    elements = output_shapes[0].numel()
    return elements * 2
85 |
86 |
# nn.Module types whose cost is counted as zero.
zero_ops = (
    nn.ReLU,
    nn.ReLU6,
    nn.Dropout,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
)

# Rule table consulted by ``fx_profile``:
#   * nn.Module types  -> rule(module, input_shapes, output_shapes)
#   * string keys      -> rule(input_shapes, output_shapes, *args, **kwargs) for
#                         call_function nodes, rule(input_shapes, output_shapes)
#                         for call_method nodes.
# call_function keys are ``str(node.target)`` cut at the first "at" and stripped
# of "<"/">", which is why e.g. "built-in method fl" appears (presumably
# torch.flatten truncated that way).
count_map = {
    nn.Linear: count_nn_linear,
    nn.Conv2d: count_nn_conv2d,
    nn.BatchNorm2d: count_nn_bn2d,
    "function linear": count_fn_linear,
    "clamp": count_clamp,
    "built-in function add": count_zero_ops,
    "built-in method fl": count_zero_ops,
    "built-in method conv2d of type object": count_fn_conv2d,
    "built-in function mul": count_mul,
    "built-in function truediv": count_mul,
}
count_map.update(dict.fromkeys(zero_ops, count_zero_ops))

# Operators seen by ``fx_profile`` that have no rule in ``count_map``.
missing_maps = {}
113 |
114 | from torch.fx import symbolic_trace
115 | from torch.fx.passes.shape_prop import ShapeProp
116 | from .helper import prGreen, prRed, prYellow
117 |
118 |
def null_print(*args, **kwargs):
    """Stand-in for ``print`` that discards everything (used when not verbose)."""
    return None
121 |
122 |
def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
    """Count FLOPs of ``mod`` on ``input`` by tracing it with torch.fx.

    ``mod`` is traced with ``symbolic_trace`` and ``ShapeProp`` records each
    node's output shape. Each node is then costed with a rule from ``count_map``:
    call_function/call_method nodes are looked up by a normalised form of their
    target, call_module nodes by the submodule's type. Nodes without a rule
    add nothing; their keys are printed in red and listed at the end.

    Args:
        mod: module to profile; must be traceable by ``torch.fx.symbolic_trace``.
        input: example input passed to ``ShapeProp.propagate``.
        verbose: also print per-node shapes, running FLOP totals and, for
            unsupported modules, their type and weight shape.

    Returns:
        The sum of all per-node FLOP counts.
    """
    gm: torch.fx.GraphModule = symbolic_trace(mod)
    ShapeProp(gm).propagate(input)

    fprint = print if verbose else null_print

    v_maps = {}  # node name -> output shape, used to look up argument shapes
    missing = {}  # operators without a rule, for this call only
    total_flops = 0

    for node in gm.graph.nodes:
        fprint(
            f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}"
        )
        node_flops = None

        input_shapes = []
        output_shapes = []
        fprint("input_shape:", end="\t")
        for arg in node.args:
            if str(arg) not in v_maps:
                continue
            fprint(f"{v_maps[str(arg)]}", end="\t")
            input_shapes.append(v_maps[str(arg)])
        fprint()

        # Nodes that do not produce a single tensor (e.g. ``x.size(0)``) carry
        # no usable ``tensor_meta`` shape; they simply get no output shape.
        out_shape = getattr(node.meta.get("tensor_meta"), "shape", None)
        fprint(f"output_shape:\t{out_shape}")
        if out_shape is not None:
            output_shapes.append(out_shape)

        if node.op in ["output", "placeholder"]:
            node_flops = 0
        elif node.op == "call_function":
            # torch internal functions; the key format must match count_map
            key = (
                str(node.target)
                .split("at")[0]
                .replace("<", "")
                .replace(">", "")
                .strip()
            )
            if key in count_map:
                node_flops = count_map[key](
                    input_shapes, output_shapes, *node.args, **node.kwargs
                )
            else:
                missing[key] = (node.op, key)
                prRed(f"|{key}| is missing")
        elif node.op == "call_method":
            # tensor methods such as ``x.clamp(...)``
            key = str(node.target)
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes)
            else:
                missing[key] = (node.op, key)
                prRed(f"{key} is missing")
        elif node.op == "call_module":
            # torch.nn modules kept as leaves by the tracer
            m = mod.get_submodule(node.target)
            key = type(m)
            fprint(key, key in count_map)
            if key in count_map:
                node_flops = count_map[key](m, input_shapes, output_shapes)
            else:
                missing[key] = (node.op,)
                prRed(f"{key} is missing")
                fprint("module type:", key)
                weight = getattr(m, "weight", None)
                if isinstance(m, zero_ops) or not isinstance(weight, torch.Tensor):
                    fprint("weight_shape: None")
                else:
                    fprint(f"weight_shape: {weight.shape}")

        if out_shape is not None:
            v_maps[str(node.name)] = out_shape
        if node_flops is not None:
            total_flops += node_flops
        if verbose:
            prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
        fprint("==" * 40)

    # Keep the module-level registry up to date for callers that inspect it,
    # but only report what this call actually hit.
    missing_maps.update(missing)
    if missing:
        from pprint import pprint

        print("Missing operators: ")
        pprint(missing)
    return total_flops
213 |
214 |
if __name__ == "__main__":
    # Demo: two Linear branches, a clamp, an add and a division, traced by fx.

    class MyOP(nn.Module):
        def forward(self, input):
            return input / 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 3)
            self.linear2 = torch.nn.Linear(5, 3)
            self.myop = MyOP()

        def forward(self, x):
            left = self.linear1(x)
            right = self.linear2(x).clamp(min=0.0, max=1.0)
            return self.myop(left + right)

    print(fx_profile(MyModule(), th.randn(20, 5), verbose=False))
237 |
--------------------------------------------------------------------------------
/infotool/helper.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Iterable
2 |
# ANSI SGR suffixes; the full escape is "\033[" + code.
COLOR_RED = "91m"
COLOR_GREEN = "92m"
COLOR_YELLOW = "93m"


def colorful_print(fn_print, color=COLOR_RED):
    """Wrap a ``print``-like function so its output is shown in an ANSI colour.

    Args:
        fn_print: function with ``print``'s calling convention.
        color: ANSI SGR suffix, e.g. ``COLOR_RED``.

    Returns:
        A function taking the same arguments as ``fn_print``. The colour and
        reset escapes are written to the same ``file`` as the message (stdout
        when no ``file`` is given), and the reset is written even if
        ``fn_print`` raises, so the terminal colour never leaks.
    """
    def actual_call(*args, **kwargs):
        stream = kwargs.get("file")  # None means sys.stdout for print()
        print(f"\033[{color}", end="", file=stream)
        try:
            fn_print(*args, **kwargs)
        finally:
            print("\033[00m", end="", file=stream)

    return actual_call


prRed = colorful_print(print, color=COLOR_RED)
prGreen = colorful_print(print, color=COLOR_GREEN)
prYellow = colorful_print(print, color=COLOR_YELLOW)
17 |
18 | # def prRed(skk):
19 | # print("\033[91m{}\033[00m".format(skk))
20 |
21 | # def prGreen(skk):
22 | # print("\033[92m{}\033[00m".format(skk))
23 |
24 | # def prYellow(skk):
25 | # print("\033[93m{}\033[00m".format(skk))
26 |
27 |
def clever_format(nums, format="%.2f"):
    """Render counts with a T/G/M/K suffix (or "B" below one thousand).

    Args:
        nums: a single number or an iterable of numbers.
        format: %-style format applied to the scaled value.

    Returns:
        A string for a single number, otherwise a tuple of strings. Each value
        is divided by the largest unit whose magnitude it reaches, so exact
        boundaries scale up (1e6 -> "1.00M", 1000 -> "1.00K").
    """
    if not isinstance(nums, Iterable):
        nums = [nums]

    units = ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K"))
    clever_nums = []
    for num in nums:
        for scale, suffix in units:
            if abs(num) >= scale:
                clever_nums.append(format % (num / scale) + suffix)
                break
        else:
            clever_nums.append(format % num + "B")

    return clever_nums[0] if len(clever_nums) == 1 else tuple(clever_nums)
48 |
49 |
if __name__ == "__main__":
    # Visual check of the three colour helpers, in red, green, yellow order.
    for colour_fn in (prRed, prGreen, prYellow):
        colour_fn("hello", "world")
--------------------------------------------------------------------------------
/infotool/profile.py:
--------------------------------------------------------------------------------
1 | from distutils.version import LooseVersion
2 |
3 | from .vision.basic_hooks import *
4 | from .rnn_hooks import *
5 |
6 |
7 | # logger = logging.getLogger(__name__)
8 | # logger.setLevel(logging.INFO)
9 |
10 | from .helper import prGreen, prRed, prYellow
11 |
if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
    logging.warning(
        "You are using an old version PyTorch {version}, which THOP does NOT support.".format(
            version=torch.__version__
        )
    )

default_dtype = torch.float64

# Module type -> forward hook accumulating ``m.total_ops``. Only (transposed)
# convolutions and nn.Linear are registered here (plus SyncBatchNorm below).
# Rules for other layers (norms, PReLU/ReLU/LeakyReLU, Softmax, pooling,
# Upsample, Dropout, Sequential, RNN/GRU/LSTM cells and layers) exist in the
# hook modules but are intentionally left out of this table; such modules
# contribute zero ops unless passed via ``custom_ops``.
register_hooks = {
    conv_type: count_convNd
    for conv_type in (
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )
}
register_hooks[nn.Linear] = count_linear

if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
    register_hooks[nn.SyncBatchNorm] = count_bn
69 |
70 |
def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
    """Count ops and parameters of ``model`` by hooking its leaf modules (THOP-style).

    Every leaf module gets ``total_ops``/``total_params`` buffers (its parameter
    count is stored immediately). Leaves whose type has a rule in ``custom_ops``
    (checked first) or ``register_hooks`` get that rule as a forward hook.
    ``model(*inputs)`` runs once in eval mode under ``torch.no_grad()`` and the
    buffers of all leaves are summed. The model's training flag, hooks and
    temporary buffers are restored/removed afterwards, also if the forward
    pass raises.

    Args:
        model: module to profile.
        inputs: tuple of positional arguments for ``model``.
        custom_ops: optional ``{module type: hook}`` overriding ``register_hooks``.
        verbose: print which rule is used for each module type.
        report_missing: warn about module types without a rule (implies ``verbose``).

    Returns:
        ``(total_ops, total_params)`` as Python numbers.
    """
    handler_collection = []
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        verbose = True

    def add_hooks(m):
        if len(list(m.children())) > 0:
            return  # only leaf modules are counted

        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(
                "Either .total_ops or .total_params is already defined in %s. "
                "Be careful, it might change your code's behavior." % str(m)
            )

        m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype))
        m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype))

        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and report_missing:
                prRed(
                    "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
                    % m_type
                )

        if fn is not None:
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        types_collection.add(m_type)

    training = model.training
    try:
        model.eval()
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        total_ops = 0
        total_params = 0
        for m in model.modules():
            if len(list(m.children())) > 0:  # skip for non-leaf module
                continue
            total_ops += m.total_ops
            total_params += m.total_params

        total_ops = total_ops.item()
        total_params = total_params.item()
    finally:
        # Restore the model even if tracing/forward failed, so no hooks or
        # counter buffers are left behind.
        model.train(training)
        for handler in handler_collection:
            handler.remove()
        for m in model.modules():
            if len(list(m.children())) > 0:
                continue
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)

    return total_ops, total_params
154 |
155 |
def profile(
    model: nn.Module,
    inputs,
    custom_ops=None,
    verbose=True,
    ret_layer_info=False,
    report_missing=False,
):
    """Count ops and parameters of ``model`` with per-module forward hooks (THOP-style).

    Every module receives ``total_ops``/``total_params`` buffers. Modules whose
    type has a rule in ``custom_ops`` (checked first) or ``register_hooks`` also
    get that rule and ``count_parameters`` as forward hooks. ``model(*inputs)``
    runs once in eval mode under ``torch.no_grad()``. Counts are then aggregated
    depth-first: a hooked child (other than Sequential/ModuleList) reports its
    own counters, any other child is summed from its own children.

    The model is always restored afterwards: training flag, hooks, and the
    temporary buffers on *every* module, so they never end up in
    ``state_dict()``. This also happens when the forward pass raises.

    Args:
        model: module to profile.
        inputs: tuple of positional arguments for ``model``.
        custom_ops: optional ``{module type: hook}`` overriding ``register_hooks``.
        verbose: print which rule is used for each module type.
        ret_layer_info: also return the per-child breakdown.
        report_missing: warn about module types without a rule (implies ``verbose``).

    Returns:
        ``(total_ops, total_params)``, or ``(total_ops, total_params, layer_info)``
        when ``ret_layer_info`` is set. ``layer_info`` maps each child name to
        ``(ops, params, layer_info_of_that_child)``.
    """
    handler_collection = {}
    types_collection = set()
    if custom_ops is None:
        custom_ops = {}
    if report_missing:
        # overwrite `verbose` option when enable report_missing
        verbose = True

    def add_hooks(m: nn.Module):
        # Every module gets counters so ``dfs_count`` can read them on any node.
        m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
        m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

        m_type = type(m)

        fn = None
        if m_type in custom_ops:
            # if defined both op maps, use custom_ops to overwrite.
            fn = custom_ops[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
        elif m_type in register_hooks:
            fn = register_hooks[m_type]
            if m_type not in types_collection and verbose:
                print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
        else:
            if m_type not in types_collection and report_missing:
                prRed(
                    "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
                    % m_type
                )

        if fn is not None:
            handler_collection[m] = (
                m.register_forward_hook(fn),
                m.register_forward_hook(count_parameters),
            )
        types_collection.add(m_type)

    def dfs_count(module: nn.Module, prefix="\t"):
        """Return ``(ops, params, per-child dict)`` for ``module``."""
        total_ops, total_params = module.total_ops.item(), 0
        ret_dict = {}
        for n, m in module.named_children():
            next_dict = {}
            if m in handler_collection and not isinstance(
                m, (nn.Sequential, nn.ModuleList)
            ):
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
            ret_dict[n] = (m_ops, m_params, next_dict)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params, ret_dict

    prev_training_status = model.training
    try:
        model.eval()
        model.apply(add_hooks)

        with torch.no_grad():
            model(*inputs)

        total_ops, total_params, ret_dict = dfs_count(model)
    finally:
        # reset model to original status
        model.train(prev_training_status)
        for op_handler, params_handler in handler_collection.values():
            op_handler.remove()
            params_handler.remove()
        # Buffers were registered on every module, not only on hooked ones.
        for m in model.modules():
            m._buffers.pop("total_ops", None)
            m._buffers.pop("total_params", None)

    if ret_layer_info:
        return total_ops, total_params, ret_dict
    return total_ops, total_params
247 |
--------------------------------------------------------------------------------
/infotool/rnn_hooks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import PackedSequence
4 |
5 |
6 | def _count_rnn_cell(input_size, hidden_size, bias=True):
7 | # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
8 | total_ops = hidden_size * (input_size + hidden_size) + hidden_size
9 | if bias:
10 | total_ops += hidden_size * 2
11 |
12 | return total_ops
13 |
14 |
def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
    """Forward hook for ``nn.RNNCell``: per-sample cell ops times the batch size."""
    per_sample = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)
    n_batch = x[0].size(0)
    m.total_ops += torch.DoubleTensor([int(per_sample * n_batch)])
22 |
23 |
24 | def _count_gru_cell(input_size, hidden_size, bias=True):
25 | total_ops = 0
26 | # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
27 | # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
28 | state_ops = (hidden_size + input_size) * hidden_size + hidden_size
29 | if bias:
30 | state_ops += hidden_size * 2
31 | total_ops += state_ops * 2
32 |
33 | # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
34 | total_ops += (hidden_size + input_size) * hidden_size + hidden_size
35 | if bias:
36 | total_ops += hidden_size * 2
37 | # r hadamard : r * (~)
38 | total_ops += hidden_size
39 |
40 | # h' = (1 - z) * n + z * h
41 | # hadamard hadamard add
42 | total_ops += hidden_size * 3
43 |
44 | return total_ops
45 |
46 |
def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
    """Forward hook for ``nn.GRUCell``: per-sample cell ops times the batch size."""
    per_sample = _count_gru_cell(m.input_size, m.hidden_size, m.bias)
    n_batch = x[0].size(0)
    m.total_ops += torch.DoubleTensor([int(per_sample * n_batch)])
54 |
55 |
56 | def _count_lstm_cell(input_size, hidden_size, bias=True):
57 | total_ops = 0
58 |
59 | # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
60 | # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
61 | # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
62 | # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
63 | state_ops = (input_size + hidden_size) * hidden_size + hidden_size
64 | if bias:
65 | state_ops += hidden_size * 2
66 | total_ops += state_ops * 4
67 |
68 | # c' = f * c + i * g \\
69 | # hadamard hadamard add
70 | total_ops += hidden_size * 3
71 |
72 | # h' = o * \tanh(c') \\
73 | total_ops += hidden_size
74 |
75 | return total_ops
76 |
77 |
def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
    """Forward hook for ``nn.LSTMCell``: per-sample cell ops times the batch size."""
    per_sample = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)
    n_batch = x[0].size(0)
    m.total_ops += torch.DoubleTensor([int(per_sample * n_batch)])
85 |
86 |
def _batch_and_steps(m, x):
    """Return ``(batch_size, num_steps)`` of the input seen by a recurrent-layer hook.

    ``x`` is the hook's input tuple; ``x[0]`` is either a PackedSequence or a
    padded tensor laid out according to ``m.batch_first``.
    """
    seq = x[0]
    if isinstance(seq, PackedSequence):
        # largest per-step batch, and one entry of batch_sizes per time step
        return torch.max(seq.batch_sizes), seq.batch_sizes.size(0)
    if m.batch_first:
        return seq.size(0), seq.size(1)
    return seq.size(1), seq.size(0)


def _count_recurrent_layer(m, x, cell_ops):
    """Add the ops of a stacked (optionally bidirectional) recurrent layer to ``m.total_ops``.

    ``cell_ops(input_size, hidden_size, bias)`` returns the per-sample ops of one
    cell step. Layers after the first take an input of width ``hidden_size``, or
    ``2 * hidden_size`` when bidirectional. Each direction of each layer is
    counted once per time step and per batch element.
    """
    batch_size, num_steps = _batch_and_steps(m, x)
    directions = 2 if m.bidirectional else 1
    hidden_size = m.hidden_size

    total_ops = cell_ops(m.input_size, hidden_size, m.bias) * directions
    for _ in range(m.num_layers - 1):
        total_ops += cell_ops(hidden_size * directions, hidden_size, m.bias) * directions

    # time unroll
    total_ops *= num_steps
    # batch_size
    total_ops *= batch_size

    m.total_ops += torch.DoubleTensor([int(total_ops)])


def count_rnn(m: nn.RNN, x, y):
    """Forward hook for ``nn.RNN`` layers."""
    _count_recurrent_layer(m, x, _count_rnn_cell)


def count_gru(m: nn.GRU, x, y):
    """Forward hook for ``nn.GRU`` layers."""
    _count_recurrent_layer(m, x, _count_gru_cell)


def count_lstm(m: nn.LSTM, x, y):
    """Forward hook for ``nn.LSTM`` layers."""
    _count_recurrent_layer(m, x, _count_lstm_cell)
196 |
--------------------------------------------------------------------------------
/infotool/vision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengpengChen/RepGhost/3c7d87e22c36b75507afc458f089940155308c3f/infotool/vision/__init__.py
--------------------------------------------------------------------------------
/infotool/vision/basic_hooks.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from .counter import (
4 | counter_parameters,
5 | counter_conv,
6 | counter_norm,
7 | counter_relu,
8 | counter_softmax,
9 | counter_avgpool,
10 | counter_adap_avg,
11 | counter_zero_ops,
12 | counter_upsample,
13 | counter_linear,
14 | )
15 | import torch
16 | import torch.nn as nn
17 | from torch.nn.modules.conv import _ConvNd
18 |
# NOTE(review): not referenced by the hooks visible in this file — confirm
# before removing.
multiply_adds = 1
20 |
21 |
def count_parameters(m, x, y):
    """Forward hook storing the parameter count of ``m`` in ``m.total_params[0]``.

    The count is delegated to ``counter_parameters(m.parameters())``; the hook's
    input ``x`` and output ``y`` are unused.
    """
    m.total_params[0] = counter_parameters(m.parameters())
27 |
28 |
def zero_ops(m, x, y):
    """Forward hook for layers treated as free: adds ``counter_zero_ops()`` to ``m.total_ops``."""
    cost = counter_zero_ops()
    m.total_ops += cost
31 |
32 |
def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Forward hook for (transposed) N-d convolutions.

    Adds ``counter_conv(bias_ops, kernel_ops, y.nelement(), m.in_channels, m.groups)``
    to ``m.total_ops``. ``kernel_ops`` is the number of spatial kernel taps and
    ``bias_ops`` is 1 when the layer has a bias, else 0.
    """
    # spatial kernel taps, e.g. Kh x Kw; read from the shape, no tensor needed
    kernel_ops = m.weight.size()[2:].numel()
    bias_ops = 1 if m.bias is not None else 0

    # N x Cout x H x W x (Cin x Kw x Kh + bias)
    m.total_ops += counter_conv(
        bias_ops,
        kernel_ops,
        y.nelement(),
        m.in_channels,
        m.groups,
    )
47 |
48 |
def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Alternative conv hook: (N x spatial output) x (all weights + all biases).

    Every output position (batch and spatial dims, excluding Cout) uses each
    kernel weight once; with a bias, each of the Cout bias terms is added once.
    Works for layers created with ``bias=False``.
    """
    # N x H x W (exclude Cout)
    output_size = (y.size()[:1] + y.size()[2:]).numel()
    # Cout x Cin/groups x Kw x Kh, plus Cout when a bias exists
    kernel_ops = m.weight.nelement()
    if m.bias is not None:
        kernel_ops += m.bias.nelement()
    m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)])
62 |
63 |
def count_bn(m, x, y):
    """Forward hook for BatchNorm: counts via ``counter_norm`` in eval mode only."""
    inp = x[0]
    if m.training:
        return
    m.total_ops += counter_norm(inp.numel())
68 |
69 |
def count_ln(m, x, y):
    """Forward hook for LayerNorm: counts via ``counter_norm`` in eval mode only."""
    inp = x[0]
    if m.training:
        return
    m.total_ops += counter_norm(inp.numel())
74 |
75 |
def count_in(m, x, y):
    """Forward hook for InstanceNorm: counts via ``counter_norm`` in eval mode only."""
    inp = x[0]
    if m.training:
        return
    m.total_ops += counter_norm(inp.numel())
80 |
81 |
def count_prelu(m, x, y):
    """Forward hook for PReLU: one op per input element, counted in eval mode only."""
    element_count = x[0].numel()
    if m.training:
        return
    m.total_ops += counter_relu(element_count)
88 |
89 |
def count_relu(m, x, y):
    """Forward hook for ReLU-style activations: one op per input element."""
    element_count = x[0].numel()
    m.total_ops += counter_relu(element_count)
96 |
97 |
def count_softmax(m, x, y):
    """Forward hook for Softmax: cost per row of length ``size(m.dim)``."""
    inp = x[0]
    features = inp.size()[m.dim]
    rows = inp.numel() // features
    m.total_ops += counter_softmax(rows, features)
104 |
105 |
def count_avgpool(m, x, y):
    """Forward hook for fixed-window average pooling: one op per output element."""
    output_elements = y.numel()
    m.total_ops += counter_avgpool(output_elements)
112 |
113 |
def count_adap_avgpool(m, x, y):
    """Forward hook for adaptive average pooling.

    The pooling window is approximated per spatial dim as
    ``input_size // output_size``; ``counter_adap_avg`` then charges the window
    size plus one division per output element.
    """
    in_spatial = torch.DoubleTensor(list(x[0].shape[2:]))
    out_spatial = torch.DoubleTensor(list(y.shape[2:]))
    window = torch.prod(in_spatial // out_spatial)
    m.total_ops += counter_adap_avg(window, y.numel())
121 |
122 |
# TODO: verify the accuracy
def count_upsample(m, x, y):
    """Forward hook for ``nn.Upsample``: cost per output element depending on mode.

    "nearest" and unsupported modes add zero ops. The hook always returns
    ``None``: a forward hook that returns a value replaces the module's output,
    so returning the counter tensor would corrupt the forward pass.
    """
    if m.mode not in ("nearest", "linear", "bilinear", "bicubic"):  # "trilinear"
        logging.warning("mode %s is not implemented yet, take it a zero op", m.mode)
        m.total_ops += counter_zero_ops()
        return

    if m.mode == "nearest":
        m.total_ops += counter_zero_ops()
        return

    m.total_ops += counter_upsample(m.mode, y.nelement())
139 |
140 |
# nn.Linear
def count_linear(m, x, y):
    """Forward hook for ``nn.Linear``: ``in_features`` multiplies per output element.

    Additions (including the bias) are not counted.
    """
    multiplies_per_output = m.in_features
    m.total_ops += counter_linear(multiplies_per_output, y.numel())
150 |
--------------------------------------------------------------------------------
/infotool/vision/counter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
def counter_parameters(para_list):
    """Return the total element count of the tensors in ``para_list``.

    The result is always a 1-element ``torch.DoubleTensor``, also for an empty
    iterable, so callers get the same type as from every other counter here.
    """
    total = sum(p.nelement() for p in para_list)
    return torch.DoubleTensor([total])
10 |
11 |
def counter_zero_ops():
    """Return a 1-element double tensor holding 0 (the cost of a free op)."""
    zero = 0
    return torch.DoubleTensor([zero])
14 |
15 |
def counter_conv(bias, kernel_size, output_size, in_channel, group):
    """Return ``output_size * (in_channel / group * kernel_size + bias)``.

    All arguments are plain numbers; the result is a 1-element double tensor.
    """
    per_output = in_channel / group * kernel_size + bias
    return torch.DoubleTensor([output_size * per_output])
19 |
20 |
def counter_norm(input_size):
    """Normalization cost: two ops per element. ``input_size`` is a plain number."""
    ops_per_element = 2
    return torch.DoubleTensor([ops_per_element * input_size])
24 |
25 |
def counter_relu(input_size: torch.Tensor):
    """Activation cost: one op per element; ``input_size`` is converted with ``int()``."""
    count = int(input_size)
    return torch.DoubleTensor([count])
28 |
29 |
def counter_softmax(batch_size, nfeatures):
    """Softmax cost for ``batch_size`` rows of ``nfeatures`` values each.

    Per row: ``nfeatures`` exponentials, ``nfeatures - 1`` additions and
    ``nfeatures`` divisions.
    """
    exps, adds, divs = nfeatures, nfeatures - 1, nfeatures
    per_row = exps + adds + divs
    return torch.DoubleTensor([int(batch_size * per_row)])
36 |
37 |
def counter_avgpool(input_size):
    """Average-pool cost: one op per counted element (converted with ``int()``)."""
    count = int(input_size)
    return torch.DoubleTensor([count])
40 |
41 |
def counter_adap_avg(kernel_size, output_size):
    """Adaptive average-pool cost: (window additions + one division) per output element."""
    ops_per_output = kernel_size + 1
    return torch.DoubleTensor([int(ops_per_output * output_size)])
46 |
47 |
def counter_upsample(mode: str, output_size):
    """Upsampling cost: ``output_size`` times a per-mode multiplier.

    Modes without a multiplier (e.g. "nearest") cost one op per output element.
    """
    multipliers = {
        "linear": 5,
        "bilinear": 11,
        # solve A: 128 muls + 96 adds = 224; solve p: 16 muls + 12 adds + 4 muls + 3 adds = 35
        "bicubic": 224 + 35,
        "trilinear": 13 * 2 + 5,
    }
    total_ops = output_size
    factor = multipliers.get(mode)
    if factor is not None:
        total_ops *= factor
    return torch.DoubleTensor([int(total_ops)])
61 |
62 |
def counter_linear(in_feature, num_elements):
    """Linear-layer cost: ``in_feature`` multiplies for each of ``num_elements`` outputs."""
    total = in_feature * num_elements
    return torch.DoubleTensor([int(total)])
65 |
66 |
def counter_matmul(input_size, output_size):
    """Matmul cost: every input element times the output's last dimension."""
    input_elements = np.prod(np.array(input_size))
    output_columns = np.array(output_size)[-1]
    return input_elements * output_columns
71 |
72 |
def counter_mul(input_size):
    """Elementwise multiply cost: one op per element, i.e. ``input_size`` unchanged."""
    ops = input_size
    return ops
75 |
76 |
def counter_pow(input_size):
    """Elementwise power cost: one op per element, i.e. ``input_size`` unchanged."""
    ops = input_size
    return ops
79 |
80 |
def counter_sqrt(input_size):
    """Elementwise square-root cost: one op per element, i.e. ``input_size`` unchanged."""
    ops = input_size
    return ops
83 |
84 |
def counter_div(input_size):
    """Elementwise division cost: one op per element, i.e. ``input_size`` unchanged."""
    ops = input_size
    return ops
87 |
--------------------------------------------------------------------------------
/infotool/vision/efficientnet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.modules.conv import _ConvNd
7 |
8 | from efficientnet_pytorch.utils import Conv2dDynamicSamePadding, Conv2dStaticSamePadding
9 |
# Hook table for EfficientNet-specific layers; this module leaves it empty.
register_hooks = dict()
11 |
--------------------------------------------------------------------------------
/infotool/vision/onnx_counter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from onnx import numpy_helper
4 | from thop.vision.basic_hooks import zero_ops
5 | from .counter import (
6 | counter_matmul,
7 | counter_zero_ops,
8 | counter_conv,
9 | counter_mul,
10 | counter_norm,
11 | counter_pow,
12 | counter_sqrt,
13 | counter_div,
14 | counter_softmax,
15 | counter_avgpool,
16 | )
17 |
18 |
def onnx_counter_matmul(diction, node):
    """MatMul: output shape is input 0's leading dims followed by input 1's last dim.

    Returns ``(macs, output_shape, output_name)``; ``diction`` maps tensor names
    to shapes.
    """
    lhs_shape = diction[node.input[0]]
    rhs_shape = diction[node.input[1]]
    out_size = np.append(lhs_shape[0:-1], rhs_shape[-1])
    result_name = node.output[0]
    return counter_matmul(lhs_shape, out_size[-2:]), out_size, result_name
28 |
29 |
def onnx_counter_add(diction, node):
    """Add: counted as free; output shape is the operand shape with more dims.

    On a tie, input 1's shape is used.
    """
    shape0 = diction[node.input[0]]
    shape1 = diction[node.input[1]]
    input1_wins = np.array(shape1).size >= np.array(shape0).size
    out_size = shape1 if input1_wins else shape0
    result_name = node.output[0]
    return counter_zero_ops(), out_size, result_name
40 |
41 |
def onnx_counter_conv(diction, node):
    """ONNX Conv: infer the output shape and count ops via ``counter_conv``.

    ``diction`` maps tensor names to shapes. Optional attributes fall back to
    the ONNX defaults: ``kernel_shape`` from the weight's spatial dims, stride 1,
    zero padding, dilation 1, group 1. The padding per axis is begin + end pad.
    ``auto_pad`` is not supported.

    Returns ``(macs, output_shape, output_name)``.
    """
    dim_weight = diction[node.input[1]]  # (Cout, Cin / group, k1, k2, ...)
    # A third input is the bias.
    dim_bias = 1 if len(node.input) >= 3 else 0

    dim_kernel = list(dim_weight[2:])
    dim_stride = dim_pad = dim_dil = None
    group = 1
    for attr in node.attribute:
        if attr.name == "kernel_shape":
            dim_kernel = list(attr.ints)
        elif attr.name == "strides":
            dim_stride = list(attr.ints)
        elif attr.name == "pads":
            dim_pad = list(attr.ints)
        elif attr.name == "dilations":
            dim_dil = list(attr.ints)
        elif attr.name == "group":
            group = attr.i

    n_spatial = len(dim_kernel)
    dim_stride = dim_stride or [1] * n_spatial
    dim_dil = dim_dil or [1] * n_spatial
    # ONNX pads layout: [x1_begin, x2_begin, ..., x1_end, x2_end].
    dim_pad = dim_pad or [0] * (2 * n_spatial)

    dim_input = diction[node.input[0]]
    output_size = np.append(dim_input[0 : -n_spatial - 1], dim_weight[0])
    hw = np.array(dim_input[-n_spatial:])  # copy; never write into diction's shape
    for i in range(n_spatial):
        padded = hw[i] + dim_pad[i] + dim_pad[i + n_spatial]
        effective_kernel = dim_dil[i] * (dim_kernel[i] - 1) + 1
        hw[i] = (padded - effective_kernel) // dim_stride[i] + 1
    output_size = np.append(output_size, hw)

    # The ONNX weight's dim 1 is already Cin / group, so pass group=1 to avoid
    # counter_conv dividing by the group count a second time (which would
    # undercount depthwise convs by a factor of their channel count).
    macs = counter_conv(
        dim_bias, np.prod(dim_kernel), np.prod(output_size), dim_weight[1], 1
    )
    return macs, output_size, node.output[0]
87 |
88 |
def onnx_counter_constant(diction, node):
    """Constant: counted as free; its shape is recorded as ``[1]``."""
    result_name = node.output[0]
    return counter_zero_ops(), [1], result_name
96 |
97 |
def onnx_counter_mul(diction, node):
    """Mul: one op per output element; the output shape is the broadcast shape.

    Shapes are right-aligned and size-1 dims stretched (NumPy/ONNX
    multidirectional broadcasting), so ``[1, C, H, W] * [1, C, 1, 1]`` yields
    ``[1, C, H, W]`` regardless of operand order. ``diction`` maps tensor names
    to shapes. Returns ``(macs, output_shape, output_name)``.
    """
    shape_a = [int(d) for d in diction[node.input[0]]]
    shape_b = [int(d) for d in diction[node.input[1]]]
    rank = max(len(shape_a), len(shape_b))
    shape_a = [1] * (rank - len(shape_a)) + shape_a
    shape_b = [1] * (rank - len(shape_b)) + shape_b
    output_size = np.array(
        [b if a == 1 else a for a, b in zip(shape_a, shape_b)], dtype=np.int64
    )
    macs = counter_mul(np.prod(output_size))
    return macs, output_size, node.output[0]
107 |
108 |
def onnx_counter_bn(diction, node):
    """BatchNormalization: ``counter_norm`` over all input elements; shape unchanged."""
    shape = diction[node.input[0]]
    cost = counter_norm(np.prod(shape))
    return cost, shape, node.output[0]
115 |
116 |
def onnx_counter_relu(diction, node):
    """Relu: counted as free; output shape equals the input shape."""
    shape = diction[node.input[0]]
    return counter_zero_ops(), shape, node.output[0]
126 |
127 |
def onnx_counter_reducemean(diction, node):
    """ReduceMean: counted as free; infer the reduced output shape.

    Follows the ONNX attribute defaults: ``keepdims`` is 1 when absent, and a
    missing ``axes`` attribute reduces over every axis. With ``keepdims=1`` the
    reduced axes become size 1; otherwise they are removed.
    Returns ``(macs, output_shape, output_name)``.
    """
    keep_dim = 1
    dim_axis = None
    for attr in node.attribute:
        # Exact names: a substring test would also match "noop_with_empty_axes".
        if attr.name == "axes":
            dim_axis = np.array(attr.ints, dtype=int)
        elif attr.name == "keepdims":
            keep_dim = attr.i

    output_size = np.array(diction[node.input[0]])  # fresh copy of the input shape
    if dim_axis is None:
        # NOTE(review): from opset 18 ``axes`` is an input tensor whose values are
        # not available here; all axes are reduced in that case.
        dim_axis = np.arange(output_size.size)

    if keep_dim == 1:
        output_size[dim_axis] = 1
    else:
        output_size = np.delete(output_size, dim_axis)
    return counter_zero_ops(), output_size, node.output[0]
145 |
146 |
def onnx_counter_sub(diction, node):
    """Sub: counted as free; the output shape is the broadcast of both inputs.

    Shapes are right-aligned and size-1 dims stretched, so ``x - mean`` with a
    keepdims mean yields ``x``'s shape whichever operand comes first.
    Returns ``(macs, output_shape, output_name)``.
    """
    shape_a = [int(d) for d in diction[node.input[0]]]
    shape_b = [int(d) for d in diction[node.input[1]]]
    rank = max(len(shape_a), len(shape_b))
    shape_a = [1] * (rank - len(shape_a)) + shape_a
    shape_b = [1] * (rank - len(shape_b)) + shape_b
    output_size = np.array(
        [b if a == 1 else a for a, b in zip(shape_a, shape_b)], dtype=np.int64
    )
    return counter_zero_ops(), output_size, node.output[0]
153 |
154 |
def onnx_counter_pow(diction, node):
    """Pow: one op per element of the operand with more dims (input 1 on a tie).

    The output shape is taken from that same operand.
    """
    base_shape = diction[node.input[0]]
    exp_shape = diction[node.input[1]]
    if np.array(exp_shape).size >= np.array(base_shape).size:
        chosen = exp_shape
    else:
        chosen = base_shape
    return counter_pow(np.prod(chosen)), chosen, node.output[0]
164 |
165 |
def onnx_counter_sqrt(diction, node):
    """Sqrt: one op per input element; output shape equals the input shape."""
    shape = diction[node.input[0]]
    cost = counter_sqrt(np.prod(shape))
    return cost, shape, node.output[0]
172 |
173 |
def onnx_counter_div(diction, node):
    """Div: one op per element of the operand with more dims (input 1 on a tie).

    The output shape is taken from that same operand.
    """
    numer_shape = diction[node.input[0]]
    denom_shape = diction[node.input[1]]
    if np.array(denom_shape).size >= np.array(numer_shape).size:
        chosen = denom_shape
    else:
        chosen = numer_shape
    return counter_div(np.prod(chosen)), chosen, node.output[0]
183 |
184 |
def onnx_counter_instance(diction, node):
    """InstanceNormalization: ``counter_norm`` over all input elements; shape unchanged."""
    shape = diction[node.input[0]]
    cost = counter_norm(np.prod(shape))
    return cost, shape, node.output[0]
191 |
192 |
def onnx_counter_softmax(diction, node):
    """Softmax: cost of ``batch_size`` rows of ``nfeatures`` along the softmax axis.

    ``counter_softmax`` takes ``(batch_size, nfeatures)`` in that order.
    Returns ``(macs, output_shape, output_name)``; the shape is unchanged.
    """
    input_size = diction[node.input[0]]
    # NOTE(review): the ONNX default axis is -1 from opset 13 (1 before); confirm
    # the exporter's opset if the attribute can be missing.
    dim = -1
    for attr in node.attribute:
        if attr.name == "axis":
            dim = attr.i
    nfeatures = input_size[dim]
    batch_size = np.prod(input_size) // nfeatures
    macs = counter_softmax(batch_size, nfeatures)
    return macs, input_size, node.output[0]
202 |
203 |
def onnx_counter_pad(diction, node):
    """Pad: counted as free; the input shape is passed through unchanged.

    TODO: account for the padding amounts (given as a constant input) in the
    output shape.
    """
    shape = diction[node.input[0]]
    return counter_zero_ops(), shape, node.output[0]
216 |
217 |
def onnx_counter_averagepool(diction, node):
    """AveragePool: count ops on the input element count and infer the output shape.

    Missing ``strides``/``pads``/``dilations`` fall back to the ONNX defaults
    (1 / 0 / 1); the padding per axis is begin + end pad. ``ceil_mode`` and
    ``auto_pad`` are not supported (floor rounding). The shape stored in
    ``diction`` for the input is not modified.
    Returns ``(macs, output_shape, output_name)``.
    """
    # TODO add support of ceil_mode and floor
    macs = counter_avgpool(np.prod(diction[node.input[0]]))
    dim_kernel = dim_stride = dim_pad = dim_dil = None
    for attr in node.attribute:
        if attr.name == "kernel_shape":
            dim_kernel = list(attr.ints)
        elif attr.name == "strides":
            dim_stride = list(attr.ints)
        elif attr.name == "pads":
            dim_pad = list(attr.ints)
        elif attr.name == "dilations":
            dim_dil = list(attr.ints)

    n_spatial = len(dim_kernel)  # kernel_shape is a required attribute
    dim_stride = dim_stride or [1] * n_spatial
    dim_dil = dim_dil or [1] * n_spatial
    dim_pad = dim_pad or [0] * (2 * n_spatial)

    dim_input = diction[node.input[0]]
    # np.array copies: slicing a numpy shape gives a view, and writing into it
    # would corrupt diction's record of the input shape.
    hw = np.array(dim_input[-n_spatial:])
    for i in range(n_spatial):
        padded = hw[i] + dim_pad[i] + dim_pad[i + n_spatial]
        effective_kernel = dim_dil[i] * (dim_kernel[i] - 1) + 1
        hw[i] = (padded - effective_kernel) // dim_stride[i] + 1
    output_size = np.append(dim_input[0:-n_spatial], hw)
    return macs, output_size, node.output[0]
246 |
247 |
def onnx_counter_flatten(diction, node):
    """Flatten: counted as free; output is ``[prod(shape[:axis]), prod(shape[axis:])]``.

    ``axis`` defaults to 1 (ONNX default) when absent and may be negative.
    Returns ``(macs, output_shape, output_name)``.
    """
    input_size = diction[node.input[0]]
    axis = 1
    for attr in node.attribute:
        if attr.name == "axis":
            axis = attr.i
    output_size = np.array(
        [int(np.prod(input_size[:axis])), int(np.prod(input_size[axis:]))]
    )
    return counter_zero_ops(), output_size, node.output[0]
257 |
258 |
def onnx_counter_gemm(diction, node):
    """Gemm (``Y = alpha * A' * B' + beta * C``): count MACs and infer the output shape.

    Honors ``transA``/``transB`` (ONNX default 0). With ``transB=1`` B is stored
    as ``(N, K)``, otherwise as ``(K, N)``; ``N`` is the number of output
    columns. MACs are ``numel(A) * N`` plus ``N`` for the bias term.
    Returns ``(macs, output_shape, output_name)``.
    """
    input_size = diction[node.input[0]]
    dim_weight = diction[node.input[1]]
    trans_a = 0
    trans_b = 0
    for attr in node.attribute:
        if attr.name == "transA":
            trans_a = attr.i
        elif attr.name == "transB":
            trans_b = attr.i

    out_features = dim_weight[0] if trans_b else dim_weight[1]
    macs = np.prod(input_size) * out_features + out_features
    # Output rows are A's last dim when A is transposed, its leading dim otherwise.
    leading = input_size[-1:] if trans_a else input_size[0:-1]
    output_size = np.append(leading, out_features)
    return macs, output_size, node.output[0]
270 |
271 |
def onnx_counter_maxpool(diction, node):
    """MaxPool: counted as free; infer the output shape.

    Missing ``strides``/``pads``/``dilations`` fall back to the ONNX defaults
    (1 / 0 / 1); the padding per axis is begin + end pad. ``ceil_mode``,
    ``auto_pad`` and ``storage_order`` are not supported (floor rounding). The
    shape stored in ``diction`` for the input is not modified.
    Returns ``(macs, output_shape, output_name)``.
    """
    # TODO add support of ceil_mode and floor
    macs = counter_zero_ops()
    dim_kernel = dim_stride = dim_pad = dim_dil = None
    for attr in node.attribute:
        if attr.name == "kernel_shape":
            dim_kernel = list(attr.ints)
        elif attr.name == "strides":
            dim_stride = list(attr.ints)
        elif attr.name == "pads":
            dim_pad = list(attr.ints)
        elif attr.name == "dilations":
            dim_dil = list(attr.ints)

    n_spatial = len(dim_kernel)  # kernel_shape is a required attribute
    dim_stride = dim_stride or [1] * n_spatial
    dim_dil = dim_dil or [1] * n_spatial
    dim_pad = dim_pad or [0] * (2 * n_spatial)

    dim_input = diction[node.input[0]]
    # np.array copies: slicing a numpy shape gives a view, and writing into it
    # would corrupt diction's record of the input shape.
    hw = np.array(dim_input[-n_spatial:])
    for i in range(n_spatial):
        padded = hw[i] + dim_pad[i] + dim_pad[i + n_spatial]
        effective_kernel = dim_dil[i] * (dim_kernel[i] - 1) + 1
        hw[i] = (padded - effective_kernel) // dim_stride[i] + 1
    output_size = np.append(dim_input[0:-n_spatial], hw)
    return macs, output_size, node.output[0]
301 |
302 |
def onnx_counter_globalaveragepool(diction, node):
    """GlobalAveragePool: counted as free; every spatial dim collapses to 1.

    For an ``[N, C, D1, D2, ...]`` input the output shape is ``[N, C, 1, 1, ...]``.
    Downstream nodes (e.g. a 1x1 Conv head) then see the pooled size rather
    than the full feature-map size.
    Returns ``(macs, output_shape, output_name)``.
    """
    output_size = np.array(diction[node.input[0]])  # copy of the input shape
    output_size[2:] = 1
    return counter_zero_ops(), output_size, node.output[0]
309 |
310 |
def onnx_counter_concat(diction, node):
    """Concat: counted as free; the concat-axis size is the sum over all inputs.

    The other dims are taken from the first input. The shape stored in
    ``diction`` for the first input is copied, not modified.
    Returns ``(macs, output_shape, output_name)``.
    """
    axis = node.attribute[0].i  # Concat's only attribute is the required ``axis``
    output_size = np.array(diction[node.input[0]])
    output_size[axis] = sum(diction[name][axis] for name in node.input)
    return counter_zero_ops(), output_size, node.output[0]
323 |
324 |
def onnx_counter_clip(diction, node):
    """Clip: counted as free; output shape equals the input shape."""
    shape = diction[node.input[0]]
    return counter_zero_ops(), shape, node.output[0]
331 |
332 |
# Dispatch table: ONNX op_type -> counter(diction, node) returning
# (macs, output_shape, output_name).
onnx_operators: dict = {
    # linear algebra / arithmetic
    "MatMul": onnx_counter_matmul,
    "Add": onnx_counter_add,
    "Conv": onnx_counter_conv,
    "Mul": onnx_counter_mul,
    "Constant": onnx_counter_constant,
    "BatchNormalization": onnx_counter_bn,
    "Relu": onnx_counter_relu,
    "ReduceMean": onnx_counter_reducemean,
    "Sub": onnx_counter_sub,
    "Pow": onnx_counter_pow,
    "Sqrt": onnx_counter_sqrt,
    "Div": onnx_counter_div,
    "InstanceNormalization": onnx_counter_instance,
    "Softmax": onnx_counter_softmax,
    # shape / pooling ops
    "Pad": onnx_counter_pad,
    "AveragePool": onnx_counter_averagepool,
    "MaxPool": onnx_counter_maxpool,
    "Flatten": onnx_counter_flatten,
    "Gemm": onnx_counter_gemm,
    "GlobalAveragePool": onnx_counter_globalaveragepool,
    "Concat": onnx_counter_concat,
    "Clip": onnx_counter_clip,
    None: None,
}
358 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .repghost import *
2 |
--------------------------------------------------------------------------------
/model/repghost.py:
--------------------------------------------------------------------------------
1 | # @Author : chengpeng.chen
2 | # @Email : chencp@live.com
3 | """
4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong.
5 | https://arxiv.org/abs/2211.06088
6 | """
7 | import copy
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
# Public model constructors (re-exported by ``from .repghost import *``).
__all__ = [
    # 0.5x and its ablation variants
    'repghostnet_0_5x',
    'repghostnet_repid_0_5x',
    'repghostnet_norep_0_5x',
    'repghostnet_wo_0_5x',
    # other width multipliers
    'repghostnet_0_58x',
    'repghostnet_0_8x',
    'repghostnet_1_0x',
    'repghostnet_1_11x',
    'repghostnet_1_3x',
    'repghostnet_1_5x',
    'repghostnet_2_0x',
    # generic builder
    'repghostnet',
]
28 |
29 |
30 | def _make_divisible(v, divisor, min_value=None):
31 | """
32 | This function is taken from the original tf repo.
33 | It ensures that all layers have a channel number that is divisible by 8
34 | It can be seen here:
35 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
36 | """
37 | if min_value is None:
38 | min_value = divisor
39 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
40 | # Make sure that round down does not go down by more than 10%.
41 | if new_v < 0.9 * v:
42 | new_v += divisor
43 | return new_v
44 |
45 |
def hard_sigmoid(x, inplace: bool = False):
    """Piecewise-linear sigmoid ``relu6(x + 3) / 6``.

    With ``inplace=True`` the computation overwrites and returns ``x`` itself.
    """
    if not inplace:
        return F.relu6(x + 3.0) / 6.0
    return x.add_(3.0).clamp_(0.0, 6.0).div_(6.0)
51 |
52 |
class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation gate.

    Global average pool -> 1x1 conv (reduce) -> activation -> 1x1 conv (expand),
    then the input is multiplied channel-wise by ``gate_fn`` of that result.
    The reduced width is ``(reduced_base_chs or in_chs) * se_ratio`` rounded by
    ``_make_divisible`` with ``divisor``.
    """

    def __init__(
        self,
        in_chs,
        se_ratio=0.25,
        reduced_base_chs=None,
        act_layer=nn.ReLU,
        gate_fn=hard_sigmoid,
        divisor=4,
        **_,
    ):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        base_chs = reduced_base_chs or in_chs
        reduced_chs = _make_divisible(base_chs * se_ratio, divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        squeezed = self.avg_pool(x)
        squeezed = self.act1(self.conv_reduce(squeezed))
        excitation = self.conv_expand(squeezed)
        return x * self.gate_fn(excitation)
81 |
82 |
class ConvBnAct(nn.Module):
    """Conv2d (no bias, padding ``kernel_size // 2``) -> BatchNorm2d -> activation."""

    def __init__(self, in_chs, out_chs, kernel_size, stride=1, act_layer=nn.ReLU):
        super(ConvBnAct, self).__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_chs, out_chs, kernel_size, stride, padding, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_chs)
        self.act1 = act_layer(inplace=True)

    def forward(self, x):
        for layer in (self.conv, self.bn1, self.act1):
            x = layer(x)
        return x
97 |
98 |
class RepGhostModule(nn.Module):
    """RepGhost module: primary conv plus a re-parameterizable cheap depthwise conv.

    Training-time structure (``deploy=False``)::

        x1  = primary_conv(x)                # conv -> BN -> (ReLU if relu)
        out = cheap_operation(x1)            # depthwise conv -> BN
            + sum(fusion_bn[i](fusion_conv[i](x1)))
        out = relu(out)                      # ReLU if relu, identity otherwise

    The fusion branches are an identity + BatchNorm (``reparam_bn``) and/or a
    plain identity (``reparam_identity``). ``switch_to_deploy`` folds the cheap
    operation's BN and all fusion branches into one depthwise conv with bias,
    using the BN running statistics.
    """

    def __init__(
        self, inp, oup, kernel_size=1, dw_size=3, stride=1, relu=True, deploy=False, reparam_bn=True, reparam_identity=False
    ):
        super(RepGhostModule, self).__init__()
        # The cheap operation keeps the primary conv's channel count.
        init_channels = oup
        new_channels = oup
        self.deploy = deploy

        self.primary_conv = nn.Sequential(
            nn.Conv2d(
                inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False,
            ),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )
        # Parallel training-time branches, folded into cheap_operation on deploy.
        fusion_conv = []
        fusion_bn = []
        if not deploy and reparam_bn:
            fusion_conv.append(nn.Identity())
            fusion_bn.append(nn.BatchNorm2d(init_channels))
        if not deploy and reparam_identity:
            fusion_conv.append(nn.Identity())
            fusion_bn.append(nn.Identity())

        self.fusion_conv = nn.Sequential(*fusion_conv)
        self.fusion_bn = nn.Sequential(*fusion_bn)

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(
                init_channels,
                new_channels,
                dw_size,
                1,
                dw_size // 2,
                groups=init_channels,
                bias=deploy,
            ),
            nn.BatchNorm2d(new_channels) if not deploy else nn.Sequential(),
        )
        if deploy:
            self.cheap_operation = self.cheap_operation[0]
        if relu:
            self.relu = nn.ReLU(inplace=False)
        else:
            self.relu = nn.Sequential()

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        for conv, bn in zip(self.fusion_conv, self.fusion_bn):
            x2 = x2 + bn(conv(x1))
        return self.relu(x2)

    def get_equivalent_kernel_bias(self):
        """Return ``(kernel, bias)`` of one depthwise conv equal to the cheap
        operation plus all fusion branches (BN running statistics are used)."""
        kernel, bias = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
        # Branch kernels are 1x1; zero-pad them to the cheap conv's spatial size
        # so any odd square dw_size works, not only 3x3.
        pad = (kernel.shape[-1] - 1) // 2
        for conv, bn in zip(self.fusion_conv, self.fusion_bn):
            branch_kernel, branch_bias = self._fuse_bn_tensor(conv, bn, kernel.shape[0], kernel.device)
            kernel = kernel + self._pad_1x1_to_3x3_tensor(branch_kernel, pad)
            bias = bias + branch_bias
        return kernel, bias

    @staticmethod
    def _pad_1x1_to_3x3_tensor(kernel1x1, pad=1):
        """Zero-pad a 1x1 kernel by ``pad`` on every side (default 1 -> 3x3).

        ``None`` maps to ``0`` so it can be added to a kernel as a no-op.
        """
        if kernel1x1 is None:
            return 0
        return torch.nn.functional.pad(kernel1x1, [pad, pad, pad, pad])

    @staticmethod
    def _fuse_bn_tensor(conv, bn, in_channels=None, device=None):
        """Fold ``bn`` (BatchNorm2d or Identity) into ``conv`` (Conv2d or Identity).

        For an Identity ``conv``, a depthwise 1x1 identity kernel with
        ``in_channels`` channels is built on ``device``. Returns ``(kernel, bias)``.
        """
        in_channels = in_channels if in_channels else bn.running_mean.shape[0]
        device = device if device else bn.weight.device
        if isinstance(conv, nn.Conv2d):
            kernel = conv.weight
            assert conv.bias is None
        else:
            assert isinstance(conv, nn.Identity)
            # Depthwise identity: a single 1 per channel.
            kernel = torch.ones(in_channels, 1, 1, 1, dtype=torch.float32, device=device)

        if isinstance(bn, nn.BatchNorm2d):
            running_mean = bn.running_mean
            running_var = bn.running_var
            gamma = bn.weight
            beta = bn.bias
            eps = bn.eps
            std = (running_var + eps).sqrt()
            t = (gamma / std).reshape(-1, 1, 1, 1)
            return kernel * t, beta - running_mean * gamma / std
        assert isinstance(bn, nn.Identity)
        return kernel, torch.zeros(in_channels).to(kernel.device)

    def switch_to_deploy(self):
        """Replace the cheap operation and fusion branches by one biased depthwise conv.

        No-op once deployed. It also runs when there are no fusion branches
        (``reparam_bn=False`` and ``reparam_identity=False``): the cheap
        operation's BN is still folded, so the result has the same structure as
        a module constructed with ``deploy=True``.
        """
        if self.deploy:
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        dw_conv = self.cheap_operation[0]
        self.cheap_operation = nn.Conv2d(in_channels=dw_conv.in_channels,
                                         out_channels=dw_conv.out_channels,
                                         kernel_size=dw_conv.kernel_size,
                                         stride=dw_conv.stride,
                                         padding=dw_conv.padding,
                                         dilation=dw_conv.dilation,
                                         groups=dw_conv.groups,
                                         bias=True)
        self.cheap_operation.weight.data = kernel.detach()
        self.cheap_operation.bias.data = bias.detach()
        self.__delattr__('fusion_conv')
        self.__delattr__('fusion_bn')
        self.fusion_conv = []
        self.fusion_bn = []
        self.deploy = True
213 |
214 |
class RepGhostBottleneck(nn.Module):
    """RepGhost bottleneck w/ optional SE

    ghost1 (expansion, ReLU) -> [strided depthwise conv + BN when stride > 1]
    -> [squeeze-excite when se_ratio > 0] -> ghost2 (projection, no ReLU),
    plus a residual. The residual is the identity when ``in_chs == out_chs``
    and ``stride == 1``; otherwise it is depthwise conv + BN + 1x1 conv + BN.
    ``shortcut=False`` drops only the identity residual.
    """

    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        dw_kernel_size=3,
        stride=1,
        se_ratio=0.0,
        shortcut=True,
        reparam=True,
        reparam_bn=True,
        reparam_identity=False,
        deploy=False,
    ):
        super(RepGhostBottleneck, self).__init__()
        self.stride = stride
        self.enable_shortcut = shortcut
        self.in_chs = in_chs
        self.out_chs = out_chs
        bn_branch = reparam and reparam_bn
        identity_branch = reparam and reparam_identity
        dw_padding = (dw_kernel_size - 1) // 2

        # Point-wise expansion
        self.ghost1 = RepGhostModule(
            in_chs,
            mid_chs,
            relu=True,
            reparam_bn=bn_branch,
            reparam_identity=identity_branch,
            deploy=deploy,
        )

        # Depth-wise downsampling
        if stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs,
                mid_chs,
                dw_kernel_size,
                stride=stride,
                padding=dw_padding,
                groups=mid_chs,
                bias=False,
            )
            self.bn_dw = nn.BatchNorm2d(mid_chs)

        # Squeeze-and-excitation
        use_se = se_ratio is not None and se_ratio > 0.0
        self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) if use_se else None

        # Point-wise linear projection
        self.ghost2 = RepGhostModule(
            mid_chs,
            out_chs,
            relu=False,
            reparam_bn=bn_branch,
            reparam_identity=identity_branch,
            deploy=deploy,
        )

        # Residual path
        if in_chs == out_chs and stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_chs,
                    in_chs,
                    dw_kernel_size,
                    stride=stride,
                    padding=dw_padding,
                    groups=in_chs,
                    bias=False,
                ),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        identity = x
        out = self.ghost1(x)
        if self.stride > 1:
            out = self.bn_dw(self.conv_dw(out))
        if self.se is not None:
            out = self.se(out)
        out = self.ghost2(out)

        drop_residual = (
            not self.enable_shortcut
            and self.in_chs == self.out_chs
            and self.stride == 1
        )
        if drop_residual:
            return out
        return out + self.shortcut(identity)
322 |
323 |
class RepGhostNet(nn.Module):
    """RepGhostNet image classifier.

    stem (3x3 conv, stride 2, BN, ReLU) -> RepGhost bottleneck stages from
    ``cfgs`` -> 1x1 ConvBnAct -> global average pool -> 1x1 conv head (1280
    channels) -> ReLU -> dropout -> linear classifier.

    ``cfgs`` is a list of stages; each stage is a list of
    ``[k, exp_size, c, se_ratio, s]`` rows. Channel counts are scaled by
    ``width`` and rounded to multiples of 4.
    """

    def __init__(
        self,
        cfgs,
        num_classes=1000,
        width=1.0,
        dropout=0.2,
        shortcut=True,
        reparam=True,
        reparam_bn=True,
        reparam_identity=False,
        deploy=False,
    ):
        super(RepGhostNet, self).__init__()
        self.cfgs = cfgs
        self.dropout = dropout
        self.num_classes = num_classes

        # Stem
        stem_chs = _make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(3, stem_chs, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)

        # Bottleneck stages
        cur_chs = stem_chs
        stages = []
        for stage_cfg in self.cfgs:
            stage_layers = []
            for k, exp_size, c, se_ratio, s in stage_cfg:
                out_chs = _make_divisible(c * width, 4)
                mid_chs = _make_divisible(exp_size * width, 4)
                stage_layers.append(
                    RepGhostBottleneck(
                        cur_chs,
                        mid_chs,
                        out_chs,
                        k,
                        s,
                        se_ratio=se_ratio,
                        shortcut=shortcut,
                        reparam=reparam,
                        reparam_bn=reparam_bn,
                        reparam_identity=reparam_identity,
                        deploy=deploy
                    ),
                )
                cur_chs = out_chs
            stages.append(nn.Sequential(*stage_layers))

        # ``exp_size`` still holds the last block's expansion size here.
        expand_chs = _make_divisible(exp_size * width * 2, 4)
        stages.append(nn.Sequential(ConvBnAct(cur_chs, expand_chs, 1)))
        cur_chs = expand_chs

        self.blocks = nn.Sequential(*stages)

        # Head
        head_chs = 1280
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_head = nn.Conv2d(cur_chs, head_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.classifier = nn.Linear(head_chs, num_classes)

    def forward(self, x):
        x = self.act1(self.bn1(self.conv_stem(x)))
        x = self.blocks(x)
        x = self.act2(self.conv_head(self.global_pool(x)))
        x = x.view(x.size(0), -1)
        if self.dropout > 0.0:
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.classifier(x)

    def convert_to_deploy(self):
        """Fold all re-parameterizable branches of this model in place."""
        repghost_model_convert(self, do_copy=False)
411 |
412 |
def repghost_model_convert(model:torch.nn.Module, save_path=None, do_copy=True):
    """Switch every re-parameterizable submodule of ``model`` to its deploy form.

    ``switch_to_deploy()`` is called on each module (``model`` included) that
    defines it. Adapted from https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py

    Args:
        model: network to convert.
        save_path: if given, the converted ``state_dict`` is saved there.
        do_copy: convert a deep copy instead of ``model`` itself.

    Returns:
        The converted model (the copy when ``do_copy`` is true).
    """
    target = copy.deepcopy(model) if do_copy else model
    for submodule in target.modules():
        if not hasattr(submodule, 'switch_to_deploy'):
            continue
        submodule.switch_to_deploy()
    if save_path is not None:
        torch.save(target.state_dict(), save_path)
    return target
425 |
426 |
def repghostnet(enable_se=True, **kwargs):
    """Construct a RepGhostNet with the standard stage layout.

    Args:
        enable_se: when False, every squeeze-and-excitation ratio in the table is 0.
        **kwargs: forwarded to ``RepGhostNet`` (e.g. ``width``, ``num_classes``).
    """
    se = 0.25 if enable_se else 0
    cfgs = [
        # k, t, c, SE, s
        # stage1
        [[3, 8, 16, 0, 1]],
        # stage2
        [[3, 24, 24, 0, 2]],
        [[3, 36, 24, 0, 1]],
        # stage3
        [[5, 36, 40, se, 2]],
        [[5, 60, 40, se, 1]],
        # stage4
        [[3, 120, 80, 0, 2]],
        [
            [3, 100, 80, 0, 1],
            [3, 120, 80, 0, 1],
            [3, 120, 80, 0, 1],
            [3, 240, 112, se, 1],
            [3, 336, 112, se, 1],
        ],
        # stage5
        [[5, 336, 160, se, 2]],
        [
            [5, 480, 160, 0, 1],
            [5, 480, 160, se, 1],
            [5, 480, 160, 0, 1],
            [5, 480, 160, se, 1],
        ],
    ]
    return RepGhostNet(cfgs, **kwargs)
461 |
462 |
def repghostnet_0_5x(**kwargs):
    """RepGhostNet with width multiplier 0.5."""
    width = 0.5
    return repghostnet(width=width, **kwargs)
465 |
466 |
def repghostnet_repid_0_5x(**kwargs):
    """0.5x RepGhostNet using the identity re-param branch instead of the BN branch."""
    width = 0.5
    return repghostnet(width=width, reparam_bn=False, reparam_identity=True, **kwargs)
469 |
470 |
def repghostnet_norep_0_5x(**kwargs):
    """0.5x RepGhostNet without any re-parameterization branches."""
    width = 0.5
    return repghostnet(width=width, reparam=False, **kwargs)
473 |
474 |
def repghostnet_wo_0_5x(**kwargs):
    """0.5x RepGhostNet with identity shortcuts disabled (``shortcut=False``)."""
    width = 0.5
    return repghostnet(width=width, shortcut=False, **kwargs)
477 |
478 |
def repghostnet_0_58x(**kwargs):
    """RepGhostNet with width multiplier 0.58."""
    width = 0.58
    return repghostnet(width=width, **kwargs)
481 |
482 |
def repghostnet_0_8x(**kwargs):
    """RepGhostNet with width multiplier 0.8."""
    width = 0.8
    return repghostnet(width=width, **kwargs)
485 |
486 |
def repghostnet_1_0x(**kwargs):
    """RepGhostNet with width multiplier 1.0."""
    width = 1.0
    return repghostnet(width=width, **kwargs)
489 |
490 |
def repghostnet_1_11x(**kwargs):
    """RepGhostNet with width multiplier 1.11."""
    width = 1.11
    return repghostnet(width=width, **kwargs)
493 |
494 |
def repghostnet_1_3x(**kwargs):
    """RepGhostNet with width multiplier 1.3."""
    width = 1.3
    return repghostnet(width=width, **kwargs)
497 |
498 |
def repghostnet_1_5x(**kwargs):
    """RepGhostNet with width multiplier 1.5."""
    width = 1.5
    return repghostnet(width=width, **kwargs)
501 |
502 |
def repghostnet_2_0x(**kwargs):
    """RepGhostNet with width multiplier 2.0."""
    width = 2.0
    return repghostnet(width=width, **kwargs)
505 |
506 |
if __name__ == "__main__":
    # Smoke test: build the non-re-parameterized 0.5x model, print it and
    # report its FLOPs / parameter count.
    import os
    import sys

    model = repghostnet_norep_0_5x().eval()
    print(model)
    dummy_input = torch.randn(1, 3, 224, 224)

    # tools.py sits at the repository root, one directory above this file.
    # Resolve it from this file's location instead of the current working
    # directory so the script works no matter where it is launched from.
    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
    from tools import cal_flops_params

    flops, params = cal_flops_params(model, input_size=dummy_input.shape)
520 |
521 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization
2 |
3 | The official pytorch implementation of the paper **[RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization](https://arxiv.org/abs/2211.06088)**
4 |
5 | #### Chengpeng Chen\*, Zichao Guo\*, Haien Zeng, Pengfei Xiong, Jian Dong
6 |
7 | >Feature reuse has been a key technique in light-weight convolutional neural networks (CNNs) design. Current methods usually utilize a concatenation operator to keep large channel numbers cheaply (thus large network capacity) by reusing feature maps from other layers. Although concatenation is parameters- and FLOPs-free, its computational cost on hardware devices is non-negligible. To address this, this paper provides a new perspective to realize feature reuse via structural re-parameterization technique. A novel hardware-efficient RepGhost module is proposed for implicit feature reuse via re-parameterization, instead of using concatenation operator. Based on the RepGhost module, we develop our efficient RepGhost bottleneck and RepGhostNet. Experiments on ImageNet and COCO benchmarks demonstrate that the proposed RepGhostNet is much more effective and efficient than GhostNet and MobileNetV3 on mobile devices. Specially, our RepGhostNet surpasses GhostNet 0.5x by 2.5% Top-1 accuracy on ImageNet dataset with less parameters and comparable latency on an ARM-based mobile phone.
8 |
9 |
10 |
11 |
12 |
13 | ```python
14 | python 3.9.12
15 | pytorch 1.11.0
16 | cuda 11.3
17 | timm 0.6.7
18 | ```
19 |
20 | ```bash
21 | git clone https://github.com/ChengpengChen/RepGhost
22 | cd RepGhost
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ### Training
27 | ```bash
28 | bash distributed_train.sh 8 --model repghost.repghostnet_0_5x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/ --data_dir {path_to_imagenet_dir}
29 | ```
30 |
31 | ### Validation
32 | ```bash
33 | python3 -m torch.distributed.launch --nproc_per_node=8 --master_port=2340 validate.py -b 32 --model-ema --model {model} --resume {checkpoint_path} --data_dir {path_to_imagenet_dir}
34 | ```
35 |
36 | ### Convert a training-time RepGhost into a fast-inference one
37 | See the conversion example in ```convert.py```. You can also convert a RepGhostNet model for fast inference via:
38 |
39 | ```python
40 | model.convert_to_deploy()
41 | ```
42 |
43 | ### Results and Pre-trained Models
44 |
45 |
46 |
47 |
48 |
49 | | RepGhostNet | Params(M) | FLOPs(M) | Latency(ms) | Top-1 Acc.(%) | Top-5 Acc.(%) | checkpoints | logs |
50 | |:------------|:----------|:---------|:------------|:--------------|:--------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------|
51 | | 0.5x | 2.3 | 43 | 25.1 | 66.9 | 86.9 | [gdrive](https://drive.google.com/file/d/16AGg-kSscFXDpXPZ3cJpYwqeZbUlUoyr/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1s-tuS8JoHVoCVHWuUiHUFw?pwd=qttp) | [log](./work_dirs/train/repghostnet_0_5x_43M_66.95/train.log) |
52 | | 0.58x | 2.5 | 60 | 31.9 | 68.9 | 88.4 | [gdrive](https://drive.google.com/file/d/1L6ccPjfnCMt5YK-pNFDfqGYvJyTRyZPR/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1bnVk2ILONqPEbmTQahZ0Og?pwd=tiyw) | [log](./work_dirs/train/repghostnet_0_58x_60M_68.94/train.log) |
53 | | 0.8x | 3.3 | 96 | 44.5 | 72.2 | 90.5 | [gdrive](https://drive.google.com/file/d/13gmUpwiJF_O05f3-3UeEyKD57veL5cG-/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1L_EJ0CnQeGpd0QBoOiY7oQ?pwd=rkd8) | [log](./work_dirs/train/repghostnet_0_8x_96M_72.24/train.log) |
54 | | 1.0x | 4.1 | 142 | 62.2 | 74.2 | 91.5 | [gdrive](https://drive.google.com/file/d/1gzfGln60urfY38elpPHVTyv9b94ukn5o/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1CEwuBLV05z7zrVbBrku59w?pwd=z4s7) | [log](./work_dirs/train/repghostnet_1_0x_142M_74.22/train.log) |
55 | | 1.11x | 4.5 | 170 | 71.5 | 75.1 | 92.2 | [gdrive](https://drive.google.com/file/d/14Lk4pKWIUFk1Mb53ooy_GsZbhMmz3iVE/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1Lb54Jiqyt0Jc6X4F_tUYnw?pwd=dwcb) | [log](./work_dirs/train/repghostnet_1_11x_170M_75.07/train.log) |
56 | | 1.3x | 5.5 | 231 | 92.9 | 76.4 | 92.9 | [gdrive](https://drive.google.com/file/d/1dNHpX2JyiuTcDmmyvr8gnAI9t8RM-Nui/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/19x_OUgxRDvwh2g4E9gN12Q?pwd=uux6) | [log](./work_dirs/train/repghostnet_1_3x_231M_76.37/train.log) |
57 | | 1.5x | 6.6 | 301 | 116.9 | 77.5 | 93.5 | [gdrive](https://drive.google.com/file/d/1TWAY654Dz8zcwhDBDN6QDWhV7as30P8e/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/15UWOMRQN5vw99QbgiFWMRw?pwd=3uqq) | [log](./work_dirs/train/repghostnet_1_5x_301M_77.45/train.log) |
58 | | 2.0x | 9.8 | 516 | 190.0 | 78.8 | 94.3 | [gdrive](https://drive.google.com/file/d/12k00eWCXhKxx_fq3ewDhCNX08ftJ-iyP/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1YbtYvIBt3tTqCzvbcjJuBw?pwd=nq1r) | [log](./work_dirs/train/repghostnet_2_0x_516M_78.81/train.log) |
59 |
60 | #### Parameters and FLOPs
61 | We calculate parameters and FLOPs using a modified [thop](https://github.com/Lyken17/pytorch-OpCounter) in ```tools.py```. It only counts convolutional and fully-connected layers, excluding BN. To use it in your code:
62 |
63 | ```python
64 | from tools import cal_flops_params
65 | flops, params = cal_flops_params(model, input_size=(1, 3, 224, 224))
66 | ```
67 |
68 | #### Latency
69 | We first export our PyTorch model to ONNX, then use [MNN](https://github.com/alibaba/MNN) to convert it to MNN format, and finally evaluate its latency on an ARM-based mobile phone.
70 |
71 | #### Comparisons to MobileOne on iPhone12 CPU
72 | We compare RepGhostNet to [MobileOne](https://arxiv.org/abs/2206.04040) on iPhone12 CPU based on [ModelBench](https://github.com/apple/ml-mobileone/tree/main/ModelBench). The result is shown below.
73 | The MobileOne series models are $\mu0$, $\mu1$, $\mu2$, S0, S1, S2, S3 and S4.
74 |
75 |
76 |
77 |
78 |
79 |
80 | ### Citations
81 | If RepGhostNet helps your research or work, please consider citing:
82 |
83 | ```
84 | @article{chen2022repghost,
85 | title={RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization},
86 | author={Chen, Chengpeng and Guo, Zichao and Zeng, Haien and Xiong, Pengfei and Dong, Jian},
87 | journal={arXiv preprint arXiv:2211.06088},
88 | year={2022}
89 | }
90 | ```
91 |
92 | ### Contact
93 |
94 | If you have any questions, please contact chencp@live.com.
95 |
96 | ---
97 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.6.7
2 | pyyaml
3 |
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | # @Author : chengpeng.chen
2 | # @Email : chencp@live.com
3 | """
4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong.
5 | https://arxiv.org/abs/2211.06088
6 | """
7 | import torch
8 |
9 | from infotool.profile import profile_origin
10 | from infotool.helper import clever_format
11 |
12 | import copy
13 |
def convert_syncbn_to_bn(module):
    """Recursively replace ``SyncBatchNorm`` layers with plain ``BatchNorm2d``.

    Profiling runs in a single process, where ``SyncBatchNorm`` is not usable,
    so every synchronized BN is swapped for an equivalent ``BatchNorm2d``. The
    new layer shares the original's affine parameters, copies its running
    statistics, its ``qconfig`` (if any) and its train/eval mode.

    Args:
        module: any ``torch.nn.Module``.

    Returns:
        The converted module. Containers are updated in place and returned
        as-is; a top-level ``SyncBatchNorm`` is returned as a new ``BatchNorm2d``.
    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                # Share the learned scale/shift rather than copying them.
                module_output.weight = module.weight
                module_output.bias = module.bias
        # Running stats exist independently of ``affine``; they are None when
        # ``track_running_stats`` is False, which is a valid buffer value.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly constructed BatchNorm2d starts in training mode; keep the
        # source layer's mode so an eval-mode model stays in eval mode.
        module_output.training = module.training
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    for name, child in module.named_children():
        module_output.add_module(name, convert_syncbn_to_bn(child))
    return module_output
40 |
41 |
def cal_flops_params(original_model, input_size):
    """Count FLOPs and parameters of ``original_model`` for one dummy input.

    The model is deep-copied, so the caller's model is untouched. Any
    ``SyncBatchNorm`` layers are swapped for ``BatchNorm2d`` before the copy is
    profiled with ``profile_origin``. Results are printed both raw and through
    ``clever_format``.

    Args:
        original_model: the ``torch.nn.Module`` to profile.
        input_size: input shape, either ``(C, H, W)`` or ``(N, C, H, W)``. The
            batch dimension is always forced to 1.

    Returns:
        ``(flops, params)`` exactly as returned by ``profile_origin``.
    """
    model = convert_syncbn_to_bn(copy.deepcopy(original_model))

    input_size = list(input_size)
    assert len(input_size) in [3, 4], \
        'input_size must be (C, H, W) or (N, C, H, W), got {}'.format(input_size)
    if len(input_size) == 4:
        if input_size[0] != 1:
            print('modify batchsize of input_size from {} to 1'.format(input_size[0]))
            input_size[0] = 1
    else:
        input_size.insert(0, 1)

    # Build the dummy input on the device holding the weights: a CPU tensor fed
    # to a CUDA model fails with a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.zeros(input_size, device=device)

    flops, params = profile_origin(model, inputs=(dummy_input, ))

    print('flops = {}, params = {}'.format(flops, params))
    print('flops = {}, params = {}'.format(clever_format(flops), clever_format(params)))

    return flops, params
61 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # @Author : chengpeng.chen
2 | # @Email : chencp@live.com
3 | """
4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong.
5 | https://arxiv.org/abs/2211.06088
6 | """
7 | #!/usr/bin/env python3
8 | import argparse
9 | import time
10 | import yaml
11 | import os
12 | import logging
13 | from collections import OrderedDict
14 | from contextlib import suppress
15 | from datetime import datetime
16 | import importlib
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torchvision.utils
21 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
22 |
23 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
24 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
25 | convert_splitbn_model, model_parameters
26 | from timm.utils import *
27 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
28 | from timm.optim import create_optimizer_v2, optimizer_kwargs
29 | from timm.scheduler import create_scheduler
30 | from timm.utils import ApexScaler, NativeScaler
31 |
# Optional NVIDIA Apex: enables Apex AMP, Apex DDP and Apex SyncBN when present.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
except ImportError:
    has_apex = False
else:
    has_apex = True

# Native AMP requires torch.cuda.amp.autocast; older PyTorch builds may lack
# either the ``amp`` submodule or ``autocast`` itself.
try:
    has_native_amp = torch.cuda.amp.autocast is not None
except AttributeError:
    has_native_amp = False

# Optional Weights & Biases logging.
try:
    import wandb
except ImportError:
    has_wandb = False
else:
    has_wandb = True

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
57 |
# Two-stage argument parsing. ``config_parser`` only recognizes -c/--config;
# ``_parse_args`` loads the YAML file it names and applies it as defaults on the
# main ``parser`` below, so flags given on the command line still take precedence.
config_parser = argparse.ArgumentParser(
    description='Training Config', add_help=False)
config_parser.add_argument(
    '-c', '--config', default='', type=str, metavar='FILE',
    help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset / Model parameters
parser.add_argument(
    '--data_dir', metavar='DIR', default='/disk2/datasets/imagenet',
    help='path to dataset')
parser.add_argument(
    '--dataset', '-d', metavar='NAME', default='',
    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument(
    '--train-split', metavar='NAME', default='train',
    help='dataset train split (default: train)')
parser.add_argument(
    '--val-split', metavar='NAME', default='val',
    help='dataset validation split (default: val)')
# main() imports ``model.<module>`` and calls ``<function>`` from it.
parser.add_argument(
    '--model', default='', type=str, metavar='MODEL',
    help='Model to train, given as "<module>.<function>" inside the model '
         'package (e.g. "repghost.repghostnet_0_5x")')
parser.add_argument(
    '--pretrained', action='store_true', default=False,
    help='Start with pretrained version of specified network (if avail)')
parser.add_argument(
    '--initial-checkpoint', default='', type=str, metavar='PATH',
    help='Initialize model from this checkpoint (default: none)')
parser.add_argument(
    '--resume', default='', type=str, metavar='PATH',
    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument(
    '--no-resume-opt', action='store_true', default=False,
    help='prevent resume of optimizer state when resuming model')
parser.add_argument(
    '--num-classes', type=int, default=1000, metavar='N',
    help='number of label classes (Model default if None)')
parser.add_argument(
    '--gp', default=None, type=str, metavar='POOL',
    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument(
    '--img-size', type=int, default=None, metavar='N',
    help='Image patch size (default: None => model default)')
parser.add_argument(
    '--input-size', default=None, nargs=3, type=int, metavar='N N N',
    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument(
    '--crop-pct', default=None, type=float, metavar='N',
    help='Input image center crop percent (for validation only)')
parser.add_argument(
    '--mean', type=float, nargs='+', default=None, metavar='MEAN',
    help='Override mean pixel value of dataset')
parser.add_argument(
    '--std', type=float, nargs='+', default=None, metavar='STD',
    help='Override std deviation of dataset')
parser.add_argument(
    '--interpolation', default='', type=str, metavar='NAME',
    help='Image resize interpolation type (overrides model)')
parser.add_argument(
    '-b', '--batch-size', type=int, default=32, metavar='N',
    help='input batch size for training (default: 32)')
parser.add_argument(
    '-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
    help='ratio of validation batch size to training batch size (default: 1)')

# Optimizer parameters
parser.add_argument(
    '--opt', default='sgd', type=str, metavar='OPTIMIZER',
    help='Optimizer (default: "sgd")')
parser.add_argument(
    '--opt-eps', default=None, type=float, metavar='EPSILON',
    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument(
    '--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument(
    '--momentum', type=float, default=0.9, metavar='M',
    help='Optimizer momentum (default: 0.9)')
parser.add_argument(
    '--weight-decay', type=float, default=0.0001,
    help='weight decay (default: 0.0001)')
parser.add_argument(
    '--clip-grad', type=float, default=None, metavar='NORM',
    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument(
    '--clip-mode', type=str, default='norm',
    help='Gradient clipping mode. One of ("norm", "value", "agc")')

# Learning rate schedule parameters
parser.add_argument(
    '--sched', default='step', type=str, metavar='SCHEDULER',
    help='LR scheduler (default: "step")')
parser.add_argument(
    '--lr', type=float, default=0.01, metavar='LR',
    help='learning rate (default: 0.01)')
parser.add_argument(
    '--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
    help='learning rate noise on/off epoch percentages')
parser.add_argument(
    '--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument(
    '--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument(
    '--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument(
    '--lr-cycle-limit', type=int, default=1, metavar='N',
    help='learning rate cycle limit')
parser.add_argument(
    '--warmup-lr', type=float, default=0.0001, metavar='LR',
    help='warmup learning rate (default: 0.0001)')
parser.add_argument(
    '--min-lr', type=float, default=1e-5, metavar='LR',
    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument(
    '--epochs', type=int, default=200, metavar='N',
    help='number of epochs to train (default: 200)')
parser.add_argument(
    '--epoch-repeats', type=float, default=0., metavar='N',
    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument(
    '--start-epoch', default=None, type=int, metavar='N',
    help='manual epoch number (useful on restarts)')
parser.add_argument(
    '--decay-epochs', type=float, default=30, metavar='N',
    help='epoch interval to decay LR')
parser.add_argument(
    '--warmup-epochs', type=int, default=3, metavar='N',
    help='epochs to warmup LR, if scheduler supports')
parser.add_argument(
    '--cooldown-epochs', type=int, default=0, metavar='N',
    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument(
    '--patience-epochs', type=int, default=10, metavar='N',
    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument(
    '--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
    help='LR decay rate (default: 0.1)')

# Augmentation & regularization parameters
parser.add_argument(
    '--no-aug', action='store_true', default=False,
    help='Disable all training augmentation, override other train aug args')
parser.add_argument(
    '--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument(
    '--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument(
    '--hflip', type=float, default=0.5,
    help='Horizontal flip training aug probability')
parser.add_argument(
    '--vflip', type=float, default=0.,
    help='Vertical flip training aug probability')
parser.add_argument(
    '--color-jitter', type=float, default=0.4, metavar='PCT',
    help='Color jitter factor (default: 0.4)')
parser.add_argument(
    '--aa', type=str, default=None, metavar='NAME',
    help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument(
    '--aug-splits', type=int, default=0,
    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument(
    '--jsd', action='store_true', default=False,
    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument(
    '--reprob', type=float, default=0., metavar='PCT',
    help='Random erase prob (default: 0.)')
parser.add_argument(
    '--remode', type=str, default='const',
    help='Random erase mode (default: "const")')
parser.add_argument(
    '--recount', type=int, default=1,
    help='Random erase count (default: 1)')
parser.add_argument(
    '--resplit', action='store_true', default=False,
    help='Do not random erase first (clean) augmentation split')
parser.add_argument(
    '--mixup', type=float, default=0.0,
    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument(
    '--cutmix', type=float, default=0.0,
    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument(
    '--cutmix-minmax', type=float, nargs='+', default=None,
    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument(
    '--mixup-prob', type=float, default=1.0,
    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument(
    '--mixup-switch-prob', type=float, default=0.5,
    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument(
    '--mixup-mode', type=str, default='batch',
    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument(
    '--mixup-off-epoch', default=0, type=int, metavar='N',
    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument(
    '--smoothing', type=float, default=0.1,
    help='Label smoothing (default: 0.1)')
parser.add_argument(
    '--train-interpolation', type=str, default='random',
    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument(
    '--drop', type=float, default=0.2, metavar='PCT',
    help='Dropout rate (default: 0.2)')
parser.add_argument(
    '--drop-connect', type=float, default=None, metavar='PCT',
    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument(
    '--drop-path', type=float, default=None, metavar='PCT',
    help='Drop path rate (default: None)')
parser.add_argument(
    '--drop-block', type=float, default=None, metavar='PCT',
    help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument(
    '--bn-tf', action='store_true', default=False,
    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument(
    '--bn-momentum', type=float, default=None,
    help='BatchNorm momentum override (if not None)')
parser.add_argument(
    '--bn-eps', type=float, default=None,
    help='BatchNorm epsilon override (if not None)')
parser.add_argument(
    '--sync-bn', action='store_true',
    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument(
    '--dist-bn', type=str, default='',
    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument(
    '--split-bn', action='store_true',
    help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
parser.add_argument(
    '--model-ema', action='store_true', default=False,
    help='Enable tracking moving average of model weights')
parser.add_argument(
    '--model-ema-force-cpu', action='store_true', default=False,
    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument(
    '--model-ema-decay', type=float, default=0.9998,
    help='decay factor for model weights moving average (default: 0.9998)')

# Misc
parser.add_argument(
    '--seed', type=int, default=42, metavar='S',
    help='random seed (default: 42)')
parser.add_argument(
    '--log-interval', type=int, default=50, metavar='N',
    help='how many batches to wait before logging training status')
parser.add_argument(
    '--recovery-interval', type=int, default=0, metavar='N',
    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument(
    '--checkpoint-hist', type=int, default=5, metavar='N',
    help='number of checkpoints to keep (default: 5)')
parser.add_argument(
    '-j', '--workers', type=int, default=4, metavar='N',
    help='how many training processes to use (default: 4)')
parser.add_argument(
    '--save-images', action='store_true', default=False,
    help='save images of input batches every log interval for debugging')
parser.add_argument(
    '--amp', action='store_true', default=False,
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument(
    '--apex-amp', action='store_true', default=False,
    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument(
    '--native-amp', action='store_true', default=False,
    help='Use Native Torch AMP mixed precision')
parser.add_argument(
    '--channels-last', action='store_true', default=False,
    help='Use channels_last memory layout')
parser.add_argument(
    '--pin-mem', action='store_true', default=False,
    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument(
    '--no-prefetcher', action='store_true', default=False,
    help='disable fast prefetcher')
parser.add_argument(
    '--output', default='', type=str, metavar='PATH',
    help='path to output folder (default: none, current dir)')
parser.add_argument(
    '--experiment', default='', type=str, metavar='NAME',
    help='name of train experiment, name of sub-folder for output')
parser.add_argument(
    '--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
    help='Best metric (default: "top1")')
parser.add_argument(
    '--tta', type=int, default=0, metavar='N',
    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
# torch.distributed.launch passes --local_rank before PyTorch 2.0 and
# --local-rank from 2.0 on; accept both spellings (dest stays ``local_rank``).
parser.add_argument(
    '--local_rank', '--local-rank', default=0, type=int,
    help='local process rank passed by the distributed launcher (default: 0)')
parser.add_argument(
    '--use-multi-epochs-loader', action='store_true', default=False,
    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument(
    '--torchscript', dest='torchscript', action='store_true',
    help='convert model torchscript for inference')
parser.add_argument(
    '--log-wandb', action='store_true', default=False,
    help='log training and validation metrics to wandb')
646 |
647 |
def _parse_args():
    """Parse the command line, optionally seeded from a YAML config file.

    ``-c/--config`` is consumed first by ``config_parser``. If it names a file,
    that file's top-level mapping is applied as defaults on ``parser`` before the
    remaining command-line arguments are parsed, so explicit flags win.

    Returns:
        ``(args, args_text)``: the parsed ``argparse.Namespace`` and a YAML dump
        of it, cached so it can be saved to the output directory later.

    Raises:
        TypeError: if the config file's top level is not a mapping.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML document loads as None; treat it as "no overrides"
        # instead of failing on ``**None``.
        if cfg is None:
            cfg = {}
        elif not isinstance(cfg, dict):
            raise TypeError(
                'config file {} must contain a mapping of argument names to values, '
                'got {}'.format(args_config.config, type(cfg).__name__))
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
663 |
664 |
665 | def main():
666 | setup_default_logging()
667 | args, args_text = _parse_args()
668 |
669 | if args.log_wandb:
670 | if has_wandb:
671 | wandb.init(project=args.experiment, config=args)
672 | else:
673 | _logger.warning(
674 | "You've requested to log metrics to wandb but package not found. "
675 | "Metrics not being logged to wandb, try `pip install wandb`")
676 |
677 | args.prefetcher = not args.no_prefetcher
678 | args.distributed = False
679 | if 'WORLD_SIZE' in os.environ:
680 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
681 | args.device = 'cuda:0'
682 | args.world_size = 1
683 | args.rank = 0 # global rank
684 | if args.distributed:
685 | args.device = 'cuda:%d' % args.local_rank
686 | torch.cuda.set_device(args.local_rank)
687 | torch.distributed.init_process_group(
688 | backend='nccl', init_method='env://')
689 | args.world_size = torch.distributed.get_world_size()
690 | args.rank = torch.distributed.get_rank()
691 | _logger.info(
692 | 'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
693 | % (args.rank, args.world_size))
694 | else:
695 | _logger.info('Training with a single process on 1 GPUs.')
696 | assert args.rank >= 0
697 |
698 | # resolve AMP arguments based on PyTorch / Apex availability
699 | use_amp = None
700 | if args.amp:
701 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
702 | if has_native_amp:
703 | args.native_amp = True
704 | elif has_apex:
705 | args.apex_amp = True
706 | if args.apex_amp and has_apex:
707 | use_amp = 'apex'
708 | elif args.native_amp and has_native_amp:
709 | use_amp = 'native'
710 | elif args.apex_amp or args.native_amp:
711 | _logger.warning(
712 | "Neither APEX or native Torch AMP is available, using float32. "
713 | "Install NVIDA apex or upgrade to PyTorch 1.6")
714 |
715 | random_seed(args.seed, args.rank)
716 |
717 | # model = create_model(
718 | # args.model,
719 | # pretrained=args.pretrained,
720 | # num_classes=args.num_classes,
721 | # drop_rate=args.drop,
722 | # drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
723 | # drop_path_rate=args.drop_path,
724 | # drop_block_rate=args.drop_block,
725 | # global_pool=args.gp,
726 | # bn_tf=args.bn_tf,
727 | # bn_momentum=args.bn_momentum,
728 | # bn_eps=args.bn_eps,
729 | # scriptable=args.torchscript,
730 | # checkpoint_path=args.initial_checkpoint)
731 |
732 | m = importlib.import_module(f"model.{args.model.split('.')[0]}")
733 | model = getattr(m, args.model.split('.')[1])(dropout=args.drop)
734 |
735 | if args.num_classes is None:
736 | assert hasattr(
737 | model, 'num_classes'
738 | ), 'Model must have `num_classes` attr if not set on cmd line/config.'
739 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
740 |
741 | if args.local_rank == 0:
742 | _logger.info(
743 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
744 | )
745 |
746 | data_config = resolve_data_config(
747 | vars(args), model=model, verbose=args.local_rank == 0)
748 |
749 | # setup augmentation batch splits for contrastive loss or split bn
750 | num_aug_splits = 0
751 | if args.aug_splits > 0:
752 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
753 | num_aug_splits = args.aug_splits
754 |
755 | # enable split bn (separate bn stats per batch-portion)
756 | if args.split_bn:
757 | assert num_aug_splits > 1 or args.resplit
758 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
759 |
760 | # move model to GPU, enable channels last layout if set
761 | model.cuda()
762 | if args.channels_last:
763 | model = model.to(memory_format=torch.channels_last)
764 |
765 | # setup synchronized BatchNorm for distributed training
766 | if args.distributed and args.sync_bn:
767 | assert not args.split_bn
768 | if has_apex and use_amp != 'native':
769 | # Apex SyncBN preferred unless native amp is activated
770 | model = convert_syncbn_model(model)
771 | else:
772 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
773 | if args.local_rank == 0:
774 | _logger.info(
775 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
776 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
777 | )
778 |
779 | if args.torchscript:
780 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
781 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
782 | model = torch.jit.script(model)
783 |
784 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
785 |
786 | # setup automatic mixed-precision (AMP) loss scaling and op casting
787 | amp_autocast = suppress # do nothing
788 | loss_scaler = None
789 | if use_amp == 'apex':
790 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
791 | loss_scaler = ApexScaler()
792 | if args.local_rank == 0:
793 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
794 | elif use_amp == 'native':
795 | amp_autocast = torch.cuda.amp.autocast
796 | loss_scaler = NativeScaler()
797 | if args.local_rank == 0:
798 | _logger.info(
799 | 'Using native Torch AMP. Training in mixed precision.')
800 | else:
801 | if args.local_rank == 0:
802 | _logger.info('AMP not enabled. Training in float32.')
803 |
804 | # optionally resume from a checkpoint
805 | resume_epoch = None
806 | if args.resume:
807 | resume_epoch = resume_checkpoint(
808 | model,
809 | args.resume,
810 | optimizer=None if args.no_resume_opt else optimizer,
811 | loss_scaler=None if args.no_resume_opt else loss_scaler,
812 | log_info=args.local_rank == 0)
813 |
814 | # setup exponential moving average of model weights, SWA could be used here too
815 | model_ema = None
816 | if args.model_ema:
817 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
818 | model_ema = ModelEmaV2(
819 | model,
820 | decay=args.model_ema_decay,
821 | device='cpu' if args.model_ema_force_cpu else None)
822 | if args.resume:
823 | load_checkpoint(model_ema.module, args.resume, use_ema=True)
824 |
825 | # setup distributed training
826 | if args.distributed:
827 | if has_apex and use_amp != 'native':
828 | # Apex DDP preferred unless native amp is activated
829 | if args.local_rank == 0:
830 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
831 | model = ApexDDP(model, delay_allreduce=True)
832 | else:
833 | if args.local_rank == 0:
834 | _logger.info("Using native Torch DistributedDataParallel.")
835 | model = NativeDDP(
836 | model, device_ids=[args.local_rank
837 | ]) # can use device str in Torch >= 1.1
838 | # NOTE: EMA model does not need to be wrapped by DDP
839 |
840 | # setup learning rate schedule and starting epoch
841 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
842 | start_epoch = 0
843 | if args.start_epoch is not None:
844 | # a specified start_epoch will always override the resume epoch
845 | start_epoch = args.start_epoch
846 | elif resume_epoch is not None:
847 | start_epoch = resume_epoch
848 | if lr_scheduler is not None and start_epoch > 0:
849 | lr_scheduler.step(start_epoch)
850 |
851 | if args.local_rank == 0:
852 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
853 |
854 | # create the train and eval datasets
855 | dataset_train = create_dataset(
856 | args.dataset,
857 | root=args.data_dir,
858 | split=args.train_split,
859 | is_training=True,
860 | batch_size=args.batch_size,
861 | repeats=args.epoch_repeats)
862 | dataset_eval = create_dataset(
863 | args.dataset,
864 | root=args.data_dir,
865 | split=args.val_split,
866 | is_training=False,
867 | batch_size=args.batch_size)
868 |
869 | # setup mixup / cutmix
870 | collate_fn = None
871 | mixup_fn = None
872 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
873 | if mixup_active:
874 | mixup_args = dict(
875 | mixup_alpha=args.mixup,
876 | cutmix_alpha=args.cutmix,
877 | cutmix_minmax=args.cutmix_minmax,
878 | prob=args.mixup_prob,
879 | switch_prob=args.mixup_switch_prob,
880 | mode=args.mixup_mode,
881 | label_smoothing=args.smoothing,
882 | num_classes=args.num_classes)
883 | if args.prefetcher:
884 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
885 | collate_fn = FastCollateMixup(**mixup_args)
886 | else:
887 | mixup_fn = Mixup(**mixup_args)
888 |
889 | # wrap dataset in AugMix helper
890 | if num_aug_splits > 1:
891 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
892 |
893 | # create data loaders w/ augmentation pipeiine
894 | train_interpolation = args.train_interpolation
895 | if args.no_aug or not train_interpolation:
896 | train_interpolation = data_config['interpolation']
897 | loader_train = create_loader(
898 | dataset_train,
899 | input_size=data_config['input_size'],
900 | batch_size=args.batch_size,
901 | is_training=True,
902 | use_prefetcher=args.prefetcher,
903 | no_aug=args.no_aug,
904 | re_prob=args.reprob,
905 | re_mode=args.remode,
906 | re_count=args.recount,
907 | re_split=args.resplit,
908 | scale=args.scale,
909 | ratio=args.ratio,
910 | hflip=args.hflip,
911 | vflip=args.vflip,
912 | color_jitter=args.color_jitter,
913 | auto_augment=args.aa,
914 | num_aug_splits=num_aug_splits,
915 | interpolation=train_interpolation,
916 | mean=data_config['mean'],
917 | std=data_config['std'],
918 | num_workers=args.workers,
919 | distributed=args.distributed,
920 | collate_fn=collate_fn,
921 | pin_memory=args.pin_mem,
922 | use_multi_epochs_loader=args.use_multi_epochs_loader)
923 |
924 | loader_eval = create_loader(
925 | dataset_eval,
926 | input_size=data_config['input_size'],
927 | batch_size=args.validation_batch_size_multiplier * args.batch_size,
928 | is_training=False,
929 | use_prefetcher=args.prefetcher,
930 | interpolation=data_config['interpolation'],
931 | mean=data_config['mean'],
932 | std=data_config['std'],
933 | num_workers=args.workers,
934 | distributed=args.distributed,
935 | crop_pct=data_config['crop_pct'],
936 | pin_memory=args.pin_mem, )
937 |
938 | # setup loss function
939 | if args.jsd:
940 | assert num_aug_splits > 1 # JSD only valid with aug splits set
941 | train_loss_fn = JsdCrossEntropy(
942 | num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
943 | elif mixup_active:
944 | # smoothing is handled with mixup target transform
945 | train_loss_fn = SoftTargetCrossEntropy().cuda()
946 | elif args.smoothing:
947 | train_loss_fn = LabelSmoothingCrossEntropy(
948 | smoothing=args.smoothing).cuda()
949 | else:
950 | train_loss_fn = nn.CrossEntropyLoss().cuda()
951 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
952 |
953 | # setup checkpoint saver and eval metric tracking
954 | eval_metric = args.eval_metric
955 | best_metric = None
956 | best_epoch = None
957 | saver = None
958 | output_dir = None
959 | if args.rank == 0:
960 | if args.experiment:
961 | exp_name = args.experiment
962 | else:
963 | exp_name = '-'.join([
964 | datetime.now().strftime("%Y%m%d-%H%M%S"),
965 | safe_model_name(args.model), str(data_config['input_size'][-1])
966 | ])
967 | output_dir = get_outdir(args.output
968 | if args.output else './output/train', exp_name)
969 | handler = logging.FileHandler(os.path.join(output_dir, 'train.log'))
970 | _logger.addHandler(handler)
971 | decreasing = True if eval_metric == 'loss' else False
972 | saver = CheckpointSaver(
973 | model=model,
974 | optimizer=optimizer,
975 | args=args,
976 | model_ema=model_ema,
977 | amp_scaler=loss_scaler,
978 | checkpoint_dir=output_dir,
979 | recovery_dir=output_dir,
980 | decreasing=decreasing,
981 | max_history=args.checkpoint_hist)
982 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
983 | f.write(args_text)
984 |
985 | try:
986 | for epoch in range(start_epoch, num_epochs):
987 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
988 | loader_train.sampler.set_epoch(epoch)
989 |
990 | train_metrics = train_one_epoch(
991 | epoch,
992 | model,
993 | loader_train,
994 | optimizer,
995 | train_loss_fn,
996 | args,
997 | lr_scheduler=lr_scheduler,
998 | saver=saver,
999 | output_dir=output_dir,
1000 | amp_autocast=amp_autocast,
1001 | loss_scaler=loss_scaler,
1002 | model_ema=model_ema,
1003 | mixup_fn=mixup_fn)
1004 |
1005 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
1006 | if args.local_rank == 0:
1007 | _logger.info(
1008 | "Distributing BatchNorm running means and vars")
1009 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
1010 |
1011 | eval_metrics = validate(
1012 | model,
1013 | loader_eval,
1014 | validate_loss_fn,
1015 | args,
1016 | amp_autocast=amp_autocast)
1017 |
1018 | if model_ema is not None and not args.model_ema_force_cpu:
1019 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'
1020 | ):
1021 | distribute_bn(model_ema, args.world_size,
1022 | args.dist_bn == 'reduce')
1023 | ema_eval_metrics = validate(
1024 | model_ema.module,
1025 | loader_eval,
1026 | validate_loss_fn,
1027 | args,
1028 | amp_autocast=amp_autocast,
1029 | log_suffix=' (EMA)')
1030 | eval_metrics = ema_eval_metrics
1031 |
1032 | if lr_scheduler is not None:
1033 | # step LR for next epoch
1034 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
1035 |
1036 | if output_dir is not None:
1037 | update_summary(
1038 | epoch,
1039 | train_metrics,
1040 | eval_metrics,
1041 | os.path.join(output_dir, 'summary.csv'),
1042 | write_header=best_metric is None,
1043 | log_wandb=args.log_wandb and has_wandb)
1044 |
1045 | if saver is not None:
1046 | # save proper checkpoint with eval metric
1047 | save_metric = eval_metrics[eval_metric]
1048 | best_metric, best_epoch = saver.save_checkpoint(
1049 | epoch, metric=save_metric)
1050 |
1051 | except KeyboardInterrupt:
1052 | pass
1053 | if best_metric is not None:
1054 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric,
1055 | best_epoch))
1056 |
1057 |
def train_one_epoch(epoch,
                    model,
                    loader,
                    optimizer,
                    loss_fn,
                    args,
                    lr_scheduler=None,
                    saver=None,
                    output_dir=None,
                    amp_autocast=suppress,
                    loss_scaler=None,
                    model_ema=None,
                    mixup_fn=None):
    """Run one training epoch over ``loader`` and return the averaged loss.

    Args:
        epoch: zero-based epoch index; used for logging, for the
            ``args.mixup_off_epoch`` cut-off and to derive the global
            update counter passed to ``lr_scheduler.step_update``.
        model: network being trained (possibly wrapped for DDP).
        loader: training data loader yielding ``(input, target)`` batches.
            Batches are moved to the GPU here only when ``args.prefetcher``
            is off; with the prefetcher the loader must expose
            ``mixup_enabled``.
        optimizer: optimizer stepped once per batch.
        loss_fn: criterion applied to ``(output, target)``.
        args: parsed command-line namespace (reads prefetcher, channels_last,
            distributed, world_size, local_rank, clip_grad, clip_mode,
            log_interval, recovery_interval, save_images, mixup_off_epoch).
        lr_scheduler: optional scheduler; ``step_update`` is called after
            every batch with the running loss average.
        saver: optional checkpoint saver used for recovery checkpoints.
        output_dir: directory for debug images (only with ``args.save_images``).
        amp_autocast: context-manager factory wrapping the forward pass.
        loss_scaler: optional AMP loss scaler; when given it performs the
            backward pass, optional clipping and the optimizer step.
        model_ema: optional EMA wrapper updated after every optimizer step.
        mixup_fn: optional callable applied to ``(input, target)`` when the
            prefetcher is disabled.

    Returns:
        OrderedDict with a single ``'loss'`` entry. In distributed mode the
        loss meter is only updated on logging steps, so the average covers
        those steps only.
    """
    # Switch mixup/cutmix off once the configured epoch has been reached.
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer,
                           'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    # Global optimizer-step counter consumed by lr_scheduler.step_update.
    num_updates = epoch * len(loader)

    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)

        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(input)
            loss = loss_fn(output, target)
        # Mixup/label smoothing produce soft (2-D) targets; recover hard
        # labels via argmax so top-k accuracy can still be computed.
        if len(target.shape) > 1 and target.shape[1] > 1:
            label = torch.argmax(target, dim=1)
            acc1, acc5 = accuracy(output, label, topk=(1, 5))
        else:
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss,
                optimizer,
                clip_grad=args.clip_grad,
                clip_mode=args.clip_mode,
                parameters=model_parameters(
                    model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(
                        model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad,
                    mode=args.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                # Collective reductions: every rank reaches this branch on
                # the same batches because the condition is rank-independent.
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
                top1_m.update(acc1.item(), output.size(0))
                top5_m.update(acc5.item(), output.size(0))

            if args.local_rank == 0:
                # A single-batch loader has last_idx == 0; report it as 100%
                # complete instead of dividing by zero.
                progress = 100. * batch_idx / last_idx if last_idx > 0 else 100.
                now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                _logger.info(
                    '{} '
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        now,
                        epoch,
                        batch_idx,
                        len(loader),
                        progress,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size /
                        batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size /
                        batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir,
                                     'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(
                num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
    # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
1205 |
1206 |
def validate(model,
             loader,
             loss_fn,
             args,
             amp_autocast=suppress,
             log_suffix=''):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        loader: evaluation data loader yielding ``(input, target)`` batches.
        loss_fn: criterion applied to ``(output, target)``.
        args: parsed namespace (reads prefetcher, channels_last, tta,
            distributed, world_size, local_rank, log_interval).
        amp_autocast: context-manager factory wrapping the forward pass.
        log_suffix: text appended to the ``'Test'`` log prefix
            (e.g. ``' (EMA)'`` when evaluating the EMA weights).

    Returns:
        OrderedDict with averaged ``'loss'``, ``'top1'`` and ``'top5'``.
    """
    time_meter = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    model.eval()

    tic = time.time()
    final_idx = len(loader) - 1
    with torch.no_grad():
        for step, (images, labels) in enumerate(loader):
            is_final = step == final_idx
            if not args.prefetcher:
                images = images.cuda()
                labels = labels.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                logits = model(images)
            if isinstance(logits, (tuple, list)):
                # Keep only the primary output when the model returns several.
                logits = logits[0]

            # Test-time augmentation: average each run of `tta` consecutive
            # rows along dim 0 and keep one label per averaged group.
            tta_factor = args.tta
            if tta_factor > 1:
                logits = logits.unfold(0, tta_factor, tta_factor).mean(dim=2)
                labels = labels[0:labels.size(0):tta_factor]

            loss = loss_fn(logits, labels)
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

            loss_value = loss.data
            if args.distributed:
                loss_value = reduce_tensor(loss_value, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)

            torch.cuda.synchronize()

            loss_meter.update(loss_value.item(), images.size(0))
            acc1_meter.update(acc1.item(), logits.size(0))
            acc5_meter.update(acc5.item(), logits.size(0))

            time_meter.update(time.time() - tic)
            tic = time.time()

            # Only the local rank-0 process writes progress lines.
            if args.local_rank != 0:
                continue
            if is_final or step % args.log_interval == 0:
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        'Test' + log_suffix,
                        step,
                        final_idx,
                        batch_time=time_meter,
                        loss=loss_meter,
                        top1=acc1_meter,
                        top5=acc5_meter))

    return OrderedDict([
        ('loss', loss_meter.avg),
        ('top1', acc1_meter.avg),
        ('top5', acc5_meter.avg),
    ])
1282 |
1283 |
if __name__ == "__main__":
    # Script entry point: run training only when executed directly.
    main()
1286 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | # Default
2 | # For small models
3 | # bash distributed_train.sh 8 --model repghost.repghostnet_0_5x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/
4 | # bash distributed_train.sh 8 --model repghost.repghostnet_0_58x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/
5 | # bash distributed_train.sh 8 --model repghost.repghostnet_0_8x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/
6 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_0x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/
7 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_11x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --remode pixel --reprob 0.2 --output work_dirs/train/
8 |
9 |
10 | # For large models
11 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_3x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --output work_dirs/train/
12 | # bash distributed_train.sh 8 --model repghost.repghostnet_1_5x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --output work_dirs/train/
13 | # bash distributed_train.sh 8 --model repghost.repghostnet_2_0x -b 128 --lr 0.6 --sched cosine --epochs 300 --opt sgd -j 7 --warmup-epochs 5 --warmup-lr 1e-4 --weight-decay 1e-5 --drop 0.2 --amp --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --output work_dirs/train/
14 |
15 |
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | # @Author : chengpeng.chen
2 | # @Email : chencp@live.com
3 | """
4 | RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization By Chengpeng Chen, Zichao Guo, Haien Zeng, Pengfei Xiong, and Jian Dong.
5 | https://arxiv.org/abs/2211.06088
6 | """
7 | #!/usr/bin/env python3
8 | import argparse
9 | import time
10 | import yaml
11 | import os
12 | import logging
13 | from collections import OrderedDict
14 | from contextlib import suppress
15 | import importlib
16 |
17 | import torch
18 | import torch.nn as nn
19 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
20 |
21 | from timm.data import create_dataset, create_loader, resolve_data_config
22 | from timm.models import safe_model_name, load_checkpoint, \
23 | convert_splitbn_model
24 | from timm.utils import *
25 |
26 |
# Optional NVIDIA Apex support (mixed precision, DDP wrapper, SyncBN).
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
except ImportError:
    has_apex = False
else:
    has_apex = True

# Native AMP is usable when torch.cuda.amp exposes a non-None `autocast`.
try:
    has_native_amp = getattr(torch.cuda.amp, 'autocast') is not None
except AttributeError:
    has_native_amp = False

# Optional Weights & Biases logging.
try:
    import wandb
except ImportError:
    has_wandb = False
else:
    has_wandb = True

# Let cuDNN benchmark convolution algorithms for the fixed input sizes.
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
52 |
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = argparse.ArgumentParser(
    description='Training Config', add_help=False)
config_parser.add_argument(
    '-c', '--config', default='', type=str, metavar='FILE',
    help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')

# Dataset / Model parameters
parser.add_argument(
    '--data_dir', metavar='DIR', default='/disk2/datasets/imagenet',
    help='path to dataset')
parser.add_argument(
    '--dataset', '-d', metavar='NAME', default='',
    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument(
    '--train-split', metavar='NAME', default='train',
    help='dataset train split (default: train)')
parser.add_argument(
    '--val-split', metavar='NAME', default='val',
    help='dataset validation split (default: val)')
parser.add_argument(
    '--model', default='', type=str, metavar='MODEL',
    help='Name of model to evaluate (default: "")')
parser.add_argument(
    '--resume', default='', type=str, metavar='PATH',
    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument(
    '--num-classes', type=int, default=1000, metavar='N',
    help='number of label classes (Model default if None)')
parser.add_argument(
    '--gp', default=None, type=str, metavar='POOL',
    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument(
    '--img-size', type=int, default=None, metavar='N',
    help='Image patch size (default: None => model default)')
parser.add_argument(
    '--input-size', default=None, nargs=3, type=int, metavar='N N N',
    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument(
    '--crop-pct', default=None, type=float, metavar='N',
    help='Input image center crop percent (for validation only)')
parser.add_argument(
    '--mean', type=float, nargs='+', default=None, metavar='MEAN',
    help='Override mean pixel value of dataset')
parser.add_argument(
    '--std', type=float, nargs='+', default=None, metavar='STD',
    help='Override std deviation of dataset')
parser.add_argument(
    '--interpolation', default='', type=str, metavar='NAME',
    help='Image resize interpolation type (overrides model)')
parser.add_argument(
    '-b', '--batch-size', type=int, default=32, metavar='N',
    help='input batch size for training (default: 32)')
parser.add_argument(
    '-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
    help='ratio of validation batch size to training batch size (default: 1)')
parser.add_argument(
    '--aug-splits', type=int, default=0,
    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument(
    '--drop', type=float, default=0.2, metavar='PCT',
    help='Dropout rate (default: 0.2)')
parser.add_argument(
    '--drop-connect', type=float, default=None, metavar='PCT',
    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument(
    '--drop-path', type=float, default=None, metavar='PCT',
    help='Drop path rate (default: None)')
parser.add_argument(
    '--drop-block', type=float, default=None, metavar='PCT',
    help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument(
    '--bn-tf', action='store_true', default=False,
    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument(
    '--bn-momentum', type=float, default=None,
    help='BatchNorm momentum override (if not None)')
parser.add_argument(
    '--bn-eps', type=float, default=None,
    help='BatchNorm epsilon override (if not None)')
parser.add_argument(
    '--sync-bn', action='store_true',
    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument(
    '--dist-bn', type=str, default='',
    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument(
    '--split-bn', action='store_true',
    help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
parser.add_argument(
    '--model-ema', action='store_true', default=False,
    help='Enable tracking moving average of model weights')
parser.add_argument(
    '--model-ema-force-cpu', action='store_true', default=False,
    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument(
    '--model-ema-decay', type=float, default=0.9998,
    help='decay factor for model weights moving average (default: 0.9998)')

# Misc
parser.add_argument(
    '--seed', type=int, default=42, metavar='S',
    help='random seed (default: 42)')
parser.add_argument(
    '--log-interval', type=int, default=50, metavar='N',
    help='how many batches to wait before logging training status')
parser.add_argument(
    '--recovery-interval', type=int, default=0, metavar='N',
    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument(
    '--checkpoint-hist', type=int, default=10, metavar='N',
    help='number of checkpoints to keep (default: 10)')
parser.add_argument(
    '-j', '--workers', type=int, default=4, metavar='N',
    help='how many data loading workers to use (default: 4)')
parser.add_argument(
    '--save-images', action='store_true', default=False,
    help='save images of input batches every log interval for debugging')
parser.add_argument(
    '--amp', action='store_true', default=False,
    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument(
    '--apex-amp', action='store_true', default=False,
    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument(
    '--native-amp', action='store_true', default=False,
    help='Use Native Torch AMP mixed precision')
parser.add_argument(
    '--channels-last', action='store_true', default=False,
    help='Use channels_last memory layout')
parser.add_argument(
    '--pin-mem', action='store_true', default=False,
    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument(
    '--no-prefetcher', action='store_true', default=False,
    help='disable fast prefetcher')
parser.add_argument(
    '--output', default='./', type=str, metavar='PATH',
    help='path to output folder (default: ./, current dir)')
parser.add_argument(
    '--experiment', default='', type=str, metavar='NAME',
    help='name of train experiment, name of sub-folder for output')
parser.add_argument(
    '--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
    help='Best metric (default: "top1")')
parser.add_argument(
    '--tta', type=int, default=0, metavar='N',
    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument(
    '--use-multi-epochs-loader', action='store_true', default=False,
    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument(
    '--torchscript', dest='torchscript', action='store_true',
    help='convert model to TorchScript for inference')
parser.add_argument(
    '--log-wandb', action='store_true', default=False,
    help='log training and validation metrics to wandb')
359 |
360 |
def _parse_args():
    """Parse command-line arguments, optionally seeding defaults from YAML.

    A first pass with ``config_parser`` extracts only ``-c/--config``. If a
    config file is given, its top-level key/value pairs become defaults of the
    main ``parser``, so flags given explicitly on the command line still win.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed namespace and
        ``args_text`` is a YAML dump of it, cached so it can be written to the
        output directory later.

    Raises:
        TypeError: if the config file does not contain a YAML mapping.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r', encoding='utf-8') as f:
            # An empty YAML file loads as None; treat it as "no overrides"
            # rather than crashing in set_defaults(**None).
            cfg = yaml.safe_load(f) or {}
        if not isinstance(cfg, dict):
            raise TypeError(
                'Config file {} must contain a mapping of argument defaults, '
                'got {}'.format(args_config.config, type(cfg).__name__))
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
376 |
377 |
def main():
    """Evaluate a checkpoint of a network from the local ``model`` package.

    Steps:
      1. Parse arguments (command line plus an optional YAML config).
      2. Build the model named by ``args.model`` and load its weights from
         ``args.resume``.
      3. If ``args.model_ema`` is set, build an EMA copy that is loaded with
         the checkpoint's EMA weights.
      4. Run ``validate`` on ``args.val_split`` of ``args.dataset``.
      5. Print the final metrics on global rank 0.

    When the EMA copy is evaluated (``args.model_ema`` set and
    ``args.model_ema_force_cpu`` not set), its metrics replace those of the
    plain model in the printed result.
    """
    setup_default_logging()
    # args_text (YAML dump of the args) is not used by this script.
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    # Distributed mode is enabled only when the environment reports WORLD_SIZE > 1.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        # One process per GPU: bind this process to GPU `local_rank` and join
        # the NCCL process group configured through environment variables.
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info(
            'Testing in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        _logger.info('Testing with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    # NOTE(review): use_amp only influences the SyncBN/DDP implementation
    # choice and the torchscript assertion below. amp_autocast is always
    # `suppress`, so --amp / native / apex AMP do not change the precision
    # of the evaluation forward pass in this script.
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning(
            "Neither APEX or native Torch AMP is available, using float32. "
            "Install NVIDA apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    # args.model has the form "<module>.<factory>", e.g.
    # "repghost.repghostnet_1_0x": import model.<module> and call <factory>.
    # Only the dropout rate is forwarded to the factory.
    m = importlib.import_module(f"model.{args.model.split('.')[0]}")
    model = getattr(m, args.model.split('.')[1])(dropout=args.drop)

    # num_classes is not passed to the factory; it is only read back from
    # the model when not supplied on the command line / config.
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}'
        )

    # Resolve input size / interpolation / mean / std / crop_pct from the
    # args and the model.
    data_config = resolve_data_config(
        vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
            )

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    amp_autocast = suppress  # do nothing

    # Load the regular (non-EMA) weights into the model.
    load_checkpoint(model, args.resume, use_ema=False)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        # NOTE(review): here the EMA copy is created after the SyncBN conversion
        # (and torchscripting) above, contrary to the comment; it is then
        # loaded with the checkpoint's EMA weights.
        model_ema = ModelEmaV2(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else None)
        load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    dataset_eval = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.val_split,
        is_training=False,
        batch_size=args.batch_size)

    # Evaluation batch size is batch_size scaled by the validation multiplier.
    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem, )

    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    eval_metrics = validate(
        model,
        loader_eval,
        validate_loss_fn,
        args,
        amp_autocast=amp_autocast)
    # The EMA model is evaluated only when it lives on the GPU; its metrics
    # then overwrite the plain model's metrics in the printed result.
    if model_ema is not None and not args.model_ema_force_cpu:
        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
            distribute_bn(model_ema, args.world_size,
                          args.dist_bn == 'reduce')
        ema_eval_metrics = validate(
            model_ema.module,
            loader_eval,
            validate_loss_fn,
            args,
            amp_autocast=amp_autocast,
            log_suffix=' (EMA)')
        eval_metrics = ema_eval_metrics

    if args.rank == 0:
        print('eval_metrics = {}'.format(eval_metrics))
    # NOTE(review): short pause before exiting; the dump lost indentation, so
    # confirm whether this belongs inside the rank-0 branch. Purpose is not
    # evident from this file.
    time.sleep(3)
557 |
558 |
def validate(model,
             loader,
             loss_fn,
             args,
             amp_autocast=suppress,
             log_suffix=''):
    """Evaluate ``model`` over ``loader`` and return averaged loss / accuracy.

    Args:
        model: network to evaluate; it is switched to ``eval()`` mode.
        loader: iterable yielding ``(images, target)`` batches.
        loss_fn: criterion called as ``loss_fn(output, target)``.
        args: namespace. Fields read: ``prefetcher``, ``channels_last``,
            ``tta``, ``distributed``, ``world_size``, ``local_rank`` and
            ``log_interval``.
        amp_autocast: context-manager factory entered around the forward pass.
            The default ``suppress`` applies no autocast.
        log_suffix: text appended to the ``'Test'`` label in log lines.

    Returns:
        OrderedDict with keys ``'loss'``, ``'top1'`` and ``'top5'``, holding
        the ``.avg`` of the corresponding AverageMeter after the last batch.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (images, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                # Without the prefetcher, the batch is moved to the GPU here.
                images = images.cuda()
                target = target.cuda()
            if args.channels_last:
                images = images.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(images)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction: average the predictions of every
            # `reduce_factor` consecutive rows and keep one target per group.
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.detach(), args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.detach()

            torch.cuda.synchronize()

            # Weight all meters by the number of predictions actually scored.
            # After TTA reduction this differs from the number of input rows,
            # so using it everywhere keeps loss and accuracy weighted alike.
            num_scored = output.size(0)
            losses_m.update(reduced_loss.item(), num_scored)
            top1_m.update(acc1.item(), num_scored)
            top5_m.update(acc5.item(), num_scored)

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or
                                         batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name,
                        batch_idx,
                        last_idx,
                        batch_time=batch_time_m,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m))

    metrics = OrderedDict(
        [('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics
634 |
635 |
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
638 |
--------------------------------------------------------------------------------
/work_dirs/train/readme.md:
--------------------------------------------------------------------------------
1 | We report the results of the released models, which are available on [GoogleDrive](https://drive.google.com/drive/folders/1aL5UkhXgevyoQDo_cLmmd-DUfZcAFRXu?usp=share_link) and [百度网盘](https://pan.baidu.com/s/1yz7IdlagAL8LMf_NvjrbZw?pwd=qy7c) (Baidu Netdisk); the numbers are taken from each model's ```eval.log```.
2 |
3 | Note that these results differ slightly from those in ```train.log``` because of the BN statistics.
4 | During training, the models are evaluated on 8 GPUs, each of which keeps its own BN statistics,
5 | and only the model on GPU 0 is saved as the checkpoint, i.e., the released model.
6 |
7 |
8 | | RepGhostNet | Params(M) | FLOPs(M) | Latency(ms) | Top-1 Acc.(%) | Top-5 Acc.(%) | checkpoints | logs |
9 | |:------------|:----------|:---------|:------------|:--------------|:--------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------|
10 | | 0.5x | 2.3 | 43 | 25.1 | 66.9 | 86.9 | [gdrive](https://drive.google.com/file/d/16AGg-kSscFXDpXPZ3cJpYwqeZbUlUoyr/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1s-tuS8JoHVoCVHWuUiHUFw?pwd=qttp) | [log](./repghostnet_0_5x_43M_66.95/train.log) |
11 | | 0.58x | 2.5 | 60 | 31.9 | 68.9 | 88.4 | [gdrive](https://drive.google.com/file/d/1L6ccPjfnCMt5YK-pNFDfqGYvJyTRyZPR/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1bnVk2ILONqPEbmTQahZ0Og?pwd=tiyw) | [log](./repghostnet_0_58x_60M_68.94/train.log) |
12 | | 0.8x | 3.3 | 96 | 44.5 | 72.2 | 90.5 | [gdrive](https://drive.google.com/file/d/13gmUpwiJF_O05f3-3UeEyKD57veL5cG-/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1L_EJ0CnQeGpd0QBoOiY7oQ?pwd=rkd8) | [log](./repghostnet_0_8x_96M_72.24/train.log) |
13 | | 1.0x | 4.1 | 142 | 62.2 | 74.2 | 91.5 | [gdrive](https://drive.google.com/file/d/1gzfGln60urfY38elpPHVTyv9b94ukn5o/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1CEwuBLV05z7zrVbBrku59w?pwd=z4s7) | [log](./repghostnet_1_0x_142M_74.22/train.log) |
14 | | 1.11x | 4.5 | 170 | 71.5 | 75.1 | 92.2 | [gdrive](https://drive.google.com/file/d/14Lk4pKWIUFk1Mb53ooy_GsZbhMmz3iVE/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1Lb54Jiqyt0Jc6X4F_tUYnw?pwd=dwcb) | [log](./repghostnet_1_11x_170M_75.07/train.log) |
15 | | 1.3x | 5.5 | 231 | 92.9 | 76.4 | 92.9 | [gdrive](https://drive.google.com/file/d/1dNHpX2JyiuTcDmmyvr8gnAI9t8RM-Nui/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/19x_OUgxRDvwh2g4E9gN12Q?pwd=uux6) | [log](./repghostnet_1_3x_231M_76.37/train.log) |
16 | | 1.5x | 6.6 | 301 | 116.9 | 77.5 | 93.5 | [gdrive](https://drive.google.com/file/d/1TWAY654Dz8zcwhDBDN6QDWhV7as30P8e/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/15UWOMRQN5vw99QbgiFWMRw?pwd=3uqq) | [log](./repghostnet_1_5x_301M_77.45/train.log) |
17 | | 2.0x | 9.8 | 516 | 190.0 | 78.8 | 94.3 | [gdrive](https://drive.google.com/file/d/12k00eWCXhKxx_fq3ewDhCNX08ftJ-iyP/view?usp=share_link) \ [百度网盘](https://pan.baidu.com/s/1YbtYvIBt3tTqCzvbcjJuBw?pwd=nq1r) | [log](./repghostnet_2_0x_516M_78.81/train.log) |
18 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_0_58x_60M_68.94/args.yaml:
--------------------------------------------------------------------------------
1 | aa: null
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_0_58x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train/
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_0_58x_60M_68.94/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_0_58x --resume work_dirs/train/repghostnet_0_58x_60M_68.94/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 1.3653137989139557), ('top1', 68.944), ('top5', 88.422)])
3 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_0_5x_43M_66.95/args.yaml:
--------------------------------------------------------------------------------
1 | aa: null
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_0_5x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_0_5x_43M_66.95/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_0_5x --resume work_dirs/train/repghostnet_0_5x_43M_66.95/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 1.466918696756363), ('top1', 66.946), ('top5', 86.928)])
3 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_0_8x_96M_72.24/args.yaml:
--------------------------------------------------------------------------------
1 | aa: null
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_0_8x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train/
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_0_8x_96M_72.24/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=0 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_0_8x --resume work_dirs/train/repghostnet_0_8x_96M_72.24/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 1.2117804213762284), ('top1', 72.238), ('top5', 90.486)])
3 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_0x_142M_74.22/args.yaml:
--------------------------------------------------------------------------------
1 | aa: null
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_1_0x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train/
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_0x_142M_74.22/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_1_0x --resume work_dirs/train/repghostnet_1_0x_142M_74.22/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 1.117270246348381), ('top1', 74.218), ('top5', 91.548)])
3 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_11x_170M_75.07/args.yaml:
--------------------------------------------------------------------------------
1 | aa: null
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_1_11x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train/
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_11x_170M_75.07/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=1 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_1_11x --resume work_dirs/train/repghostnet_1_11x_170M_75.07/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 1.0745367737579345), ('top1', 75.066), ('top5', 92.186)])
3 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_3x_231M_76.37/args.yaml:
--------------------------------------------------------------------------------
1 | aa: rand-m9-mstd0.5
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_1_3x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train/
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_3x_231M_76.37/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=0 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_1_3x --resume work_dirs/train/repghostnet_1_3x_231M_76.37/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 1.0206603574037552), ('top1', 76.374), ('top5', 92.898)])
3 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_5x_301M_77.45/args.yaml:
--------------------------------------------------------------------------------
1 | aa: rand-m9-mstd0.5
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_1_5x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_1_5x_301M_77.45/eval.log:
--------------------------------------------------------------------------------
1 | eval_metrics = OrderedDict([('loss', 0.9776309717464448), ('top1', 77.454), ('top5', 93.5)])
2 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_2_0x_516M_78.81/args.yaml:
--------------------------------------------------------------------------------
1 | aa: rand-m9-mstd0.5
2 | amp: true
3 | apex_amp: false
4 | aug_splits: 0
5 | batch_size: 128
6 | bn_eps: null
7 | bn_momentum: null
8 | bn_tf: false
9 | channels_last: false
10 | checkpoint_hist: 5
11 | clip_grad: null
12 | clip_mode: norm
13 | color_jitter: 0.4
14 | cooldown_epochs: 0
15 | crop_pct: null
16 | cutmix: 0.0
17 | cutmix_minmax: null
18 | data_dir: /disk2/datasets/imagenet
19 | dataset: ''
20 | decay_epochs: 30
21 | decay_rate: 0.1
22 | dist_bn: ''
23 | drop: 0.2
24 | drop_block: null
25 | drop_connect: null
26 | drop_path: null
27 | epoch_repeats: 0.0
28 | epochs: 300
29 | eval_metric: top1
30 | experiment: ''
31 | gp: null
32 | hflip: 0.5
33 | img_size: null
34 | initial_checkpoint: ''
35 | input_size: null
36 | interpolation: ''
37 | jsd: false
38 | local_rank: 0
39 | log_interval: 50
40 | log_wandb: false
41 | lr: 0.6
42 | lr_cycle_limit: 1
43 | lr_cycle_mul: 1.0
44 | lr_noise: null
45 | lr_noise_pct: 0.67
46 | lr_noise_std: 1.0
47 | mean: null
48 | min_lr: 1.0e-05
49 | mixup: 0.0
50 | mixup_mode: batch
51 | mixup_off_epoch: 0
52 | mixup_prob: 1.0
53 | mixup_switch_prob: 0.5
54 | model: repghost.repghostnet_2_0x
55 | model_ema: true
56 | model_ema_decay: 0.9999
57 | model_ema_force_cpu: false
58 | momentum: 0.9
59 | native_amp: false
60 | no_aug: false
61 | no_prefetcher: false
62 | no_resume_opt: false
63 | num_classes: 1000
64 | opt: sgd
65 | opt_betas: null
66 | opt_eps: null
67 | output: work_dirs/train
68 | patience_epochs: 10
69 | pin_mem: false
70 | pretrained: false
71 | ratio:
72 | - 0.75
73 | - 1.3333333333333333
74 | recount: 1
75 | recovery_interval: 0
76 | remode: pixel
77 | reprob: 0.2
78 | resplit: false
79 | resume: ''
80 | save_images: false
81 | scale:
82 | - 0.08
83 | - 1.0
84 | sched: cosine
85 | seed: 42
86 | smoothing: 0.1
87 | split_bn: false
88 | start_epoch: null
89 | std: null
90 | sync_bn: false
91 | torchscript: false
92 | train_interpolation: random
93 | train_split: train
94 | tta: 0
95 | use_multi_epochs_loader: false
96 | val_split: val
97 | validation_batch_size_multiplier: 1
98 | vflip: 0.0
99 | warmup_epochs: 5
100 | warmup_lr: 0.0001
101 | weight_decay: 1.0e-05
102 | workers: 7
103 |
--------------------------------------------------------------------------------
/work_dirs/train/repghostnet_2_0x_516M_78.81/eval.log:
--------------------------------------------------------------------------------
1 | # CUDA_VISIBLE_DEVICES=0 python3 validate.py -b 32 --model-ema --model repghost.repghostnet_2_0x --resume work_dirs/train/repghostnet_2_0x_516M_78.81/model_best.pth.tar
2 | eval_metrics = OrderedDict([('loss', 0.897983897819519), ('top1', 78.806), ('top5', 94.326)])
3 |
--------------------------------------------------------------------------------