├── LICENSE ├── README.md ├── figures ├── result.png └── weightnet.png ├── hubconf.py ├── inference.py ├── shufflenet_v2.py ├── test.py ├── train.py └── weightnet.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 megvii-model 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeightNet 2 | This repository provides a MegEngine implementation of "[WeightNet: Revisiting the Design Space of Weight Networks](https://arxiv.org/pdf/2007.11823.pdf)".
3 | 4 | 5 | 6 | 7 | ## Requirements 8 | - MegEngine 0.5.1 (https://github.com/MegEngine/MegEngine) 9 | 10 | 11 | ## Citation 12 | If you use these models in your research, please cite: 13 | 14 | 15 | @inproceedings{ma2020weightnet, 16 | title={WeightNet: Revisiting the Design Space of Weight Networks}, 17 | author={Ma, Ningning and Zhang, Xiangyu and Huang, Jiawei and Sun, Jian}, 18 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 19 | year={2020} 20 | } 21 | 22 | ## Usage 23 | Train: 24 | ``` 25 | python3 train.py --data=/path/to/imagenet 26 | ``` 27 | 28 | Eval: 29 | ``` 30 | python3 test.py --data=/path/to/imagenet --model /path/to/model --ngpus 1 31 | ``` 32 | 33 | Inference: 34 | ``` 35 | python3 inference.py --model /path/to/model --image /path/to/image.jpg 36 | ``` 37 | 38 | 39 | ## Trained Models 40 | - OneDrive download: [Link](https://1drv.ms/u/s!AgaP37NGYuEXhVa4o5xbveef89Ba?e=Xvg6Vo) 41 | 42 | ## Results 43 | 44 | 45 | 46 | 47 | 48 | - Comparison under the same #Params and the same FLOPs. 49 | 50 | 51 | | Model | #Params. | FLOPs | Top-1 err. | 52 | |---------------------|----------|-------|------------| 53 | | ShuffleNetV2 (0.5×) | 1.4M | 41M | 39.7 | 54 | | + WeightNet (1×) | 1.5M | 41M | **36.7** | 55 | | ShuffleNetV2 (1.0×) | 2.2M | 138M | 30.9 | 56 | | + WeightNet (1×) | 2.4M | 139M | **28.8** | 57 | | ShuffleNetV2 (1.5×) | 3.5M | 299M | 27.4 | 58 | | + WeightNet (1×) | 3.9M | 301M | **25.6** | 59 | | ShuffleNetV2 (2.0×) | 5.5M | 557M | 25.5 | 60 | | + WeightNet (1×) | 6.1M | 562M | **24.1** | 61 | 62 | 63 | - Comparison under the same FLOPs. 64 | 65 | 66 | | Model | #Params. | FLOPs | Top-1 err.
| 67 | |---------------------|----------|-------|------------| 68 | | ShuffleNetV2 (0.5×) | 1.4M | 41M | 39.7 | 69 | | + WeightNet (8×) | 2.7M | 42M | **34.0** | 70 | | ShuffleNetV2 (1.0×) | 2.2M | 138M | 30.9 | 71 | | + WeightNet (4×) | 5.1M | 141M | **27.6** | 72 | | ShuffleNetV2 (1.5×) | 3.5M | 299M | 27.4 | 73 | | + WeightNet (4×) | 9.6M | 307M | **25.0** | 74 | | ShuffleNetV2 (2.0×) | 5.5M | 557M | 25.5 | 75 | | + WeightNet (4×) | 18.1M | 573M | **23.5** | 76 | -------------------------------------------------------------------------------- /figures/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-model/WeightNet/669b5f4c0c46fd30cd0fedf5e5a63161e9e94bcc/figures/result.png -------------------------------------------------------------------------------- /figures/weightnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-model/WeightNet/669b5f4c0c46fd30cd0fedf5e5a63161e9e94bcc/figures/weightnet.png -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | from shufflenet_v2 import ( 2 | shufflenet_v2_x0_5, 3 | shufflenet_v2_x1_0, 4 | shufflenet_v2_x1_5, 5 | shufflenet_v2_x2_0, 6 | ) 7 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") 3 | # 4 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 5 | # 6 | # Unless required by applicable law or agreed to in writing, 7 | # software distributed under the License is distributed on an 8 | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
import argparse
import json

import cv2
import megengine as mge
import megengine.data.transform as T
import megengine.functional as F
import megengine.jit as jit
import numpy as np

import shufflenet_v2 as M

# Default asset locations; please find the files in
# https://github.com/MegEngine/Models/tree/master/official/assets
DEFAULT_IMAGE = "../../../assets/cat.jpg"
DEFAULT_CLASS_INFO = "../../../assets/imagenet_class_info.json"


def main():
    """Classify one image and print the top-5 predictions.

    The architecture is chosen by ``--arch``.  Hub-pretrained weights are used
    unless ``--model`` points at a saved state dict.  ``--image`` and
    ``--class-info`` default to the MegEngine Models asset files.

    Raises:
        FileNotFoundError: if the image cannot be read by OpenCV.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="shufflenet_v2_x1_0", type=str)
    parser.add_argument("-m", "--model", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    parser.add_argument(
        "--class-info",
        default=DEFAULT_CLASS_INFO,
        type=str,
        help="JSON file indexed by the class id as a string",
    )
    args = parser.parse_args()

    # Only fall back to the hub-pretrained weights when no checkpoint is given.
    model = getattr(M, args.arch)(pretrained=(args.model is None))
    if args.model:
        state_dict = mge.load(args.model)
        model.load_state_dict(state_dict)

    path = DEFAULT_IMAGE if args.image is None else args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError("cannot read image: {}".format(path))

    transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToMode("CHW"),
        ]
    )

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis of size 1.
    processed_img = transform.apply(image)[np.newaxis, :]
    # Feed float32 like test.py / train.py do, rather than the raw uint8 pixels.
    processed_img = processed_img.astype("float32")
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    with open(args.class_info, encoding="utf-8") as fp:
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
        zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))
    ):
        # Element [1] of each class-info entry is printed as the class label.
        print(
            "{}: class = {:20s} with probability = {:4.1f} %".format(
                rank, imagenet_class_index[str(classid)][1], 100 * prob
            )
        )


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /shufflenet_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2019 Megvii Technology 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # 24 | # ------------------------------------------------------------------------------ 25 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") 26 | # 27 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 28 | # 29 | # Unless required by applicable law or agreed to in writing, 30 | # software distributed under the License is distributed on an 31 | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | # 33 | # This file has been modified by Megvii ("Megvii Modifications"). 
34 | # All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. 35 | # ------------------------------------------------------------------------------ 36 | import megengine.functional as F 37 | import megengine.hub as hub 38 | import megengine.module as M 39 | 40 | from weightnet import WeightNet, WeightNet_DW 41 | 42 | class ShuffleV2Block(M.Module): 43 | def __init__(self, inp, oup, mid_channels, *, ksize, stride): 44 | super(ShuffleV2Block, self).__init__() 45 | self.stride = stride 46 | assert stride in [1, 2] 47 | 48 | self.mid_channels = mid_channels 49 | self.ksize = ksize 50 | pad = ksize // 2 51 | self.pad = pad 52 | self.inp = inp 53 | 54 | outputs = oup - inp 55 | 56 | 57 | self.reduce = M.Conv2d(inp, max(16, inp//16), 1, 1, 0, bias=True) 58 | 59 | self.wnet1 = WeightNet(inp, mid_channels, 1, 1) 60 | self.bn1 = M.BatchNorm2d(mid_channels) 61 | 62 | self.wnet2 = WeightNet_DW(mid_channels, ksize, stride) 63 | self.bn2 =M.BatchNorm2d(mid_channels) 64 | 65 | self.wnet3 = WeightNet(mid_channels, outputs, 1, 1) 66 | self.bn3 = M.BatchNorm2d(outputs) 67 | 68 | if stride == 2: 69 | self.wnet_proj_1 = WeightNet_DW(inp, ksize, stride) 70 | self.bn_proj_1 = M.BatchNorm2d(inp) 71 | 72 | self.wnet_proj_2 = WeightNet(inp, inp, 1, 1) 73 | self.bn_proj_2 = M.BatchNorm2d(inp) 74 | 75 | def forward(self, old_x): 76 | if self.stride == 1: 77 | x_proj, x = self.channel_shuffle(old_x) 78 | elif self.stride == 2: 79 | x_proj = old_x 80 | x = old_x 81 | 82 | x_gap = x.mean(axis=2,keepdims=True).mean(axis=3,keepdims=True) 83 | x_gap = self.reduce(x_gap) 84 | 85 | x = self.wnet1(x, x_gap) 86 | x = self.bn1(x) 87 | x = F.relu(x) 88 | 89 | x = self.wnet2(x, x_gap) 90 | x = self.bn2(x) 91 | 92 | x = self.wnet3(x, x_gap) 93 | x = self.bn3(x) 94 | x = F.relu(x) 95 | 96 | if self.stride == 2: 97 | x_proj = self.wnet_proj_1(x_proj, x_gap) 98 | x_proj = self.bn_proj_1(x_proj) 99 | x_proj = self.wnet_proj_2(x_proj, x_gap) 100 | x_proj = self.bn_proj_2(x_proj) 
101 | x_proj = F.relu(x_proj) 102 | 103 | return F.concat((x_proj, x), 1) 104 | 105 | def channel_shuffle(self, x): 106 | batchsize, num_channels, height, width = x.shape 107 | # assert (num_channels % 4 == 0) 108 | x = x.reshape(batchsize * num_channels // 2, 2, height * width) 109 | x = x.dimshuffle(1, 0, 2) 110 | x = x.reshape(2, -1, num_channels // 2, height, width) 111 | return x[0], x[1] 112 | 113 | 114 | class ShuffleNetV2(M.Module): 115 | def __init__(self, input_size=224, num_classes=1000, model_size="1.5x"): 116 | super(ShuffleNetV2, self).__init__() 117 | 118 | self.stage_repeats = [4, 8, 4] 119 | self.model_size = model_size 120 | # We reduce the width slightly here to make WeightNet's FLOPs comparable to the baselines. 121 | if model_size == "0.5x": 122 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 123 | elif model_size == "1.0x": 124 | self.stage_out_channels = [-1, 24, 112, 224, 448, 1024] 125 | elif model_size == "1.5x": 126 | self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] 127 | elif model_size == "2.0x": 128 | self.stage_out_channels = [-1, 24, 248, 496, 992, 1024] 129 | else: 130 | raise NotImplementedError 131 | 132 | # building first layer 133 | input_channel = self.stage_out_channels[1] 134 | self.first_conv = M.Sequential( 135 | M.Conv2d(3, input_channel, 3, 2, 1, bias=True), M.BatchNorm2d(input_channel), M.ReLU(), 136 | ) 137 | 138 | self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1) 139 | 140 | self.features = [] 141 | for idxstage in range(len(self.stage_repeats)): 142 | numrepeat = self.stage_repeats[idxstage] 143 | output_channel = self.stage_out_channels[idxstage + 2] 144 | 145 | for i in range(numrepeat): 146 | if i == 0: 147 | self.features.append( 148 | ShuffleV2Block( 149 | input_channel, output_channel, mid_channels=output_channel // 2, ksize=3, stride=2, 150 | ) 151 | ) 152 | else: 153 | self.features.append( 154 | ShuffleV2Block( 155 | input_channel // 2, output_channel, mid_channels=output_channel 
// 2, ksize=3, stride=1, 156 | ) 157 | ) 158 | 159 | input_channel = output_channel 160 | 161 | self.features = M.Sequential(*self.features) 162 | 163 | self.conv_last = M.Sequential( 164 | M.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=True), 165 | M.BatchNorm2d(self.stage_out_channels[-1]), 166 | M.ReLU(), 167 | ) 168 | self.globalpool = M.AvgPool2d(7) 169 | if self.model_size == "2.0x": 170 | self.dropout = M.Dropout(0.2) 171 | self.classifier = M.Sequential(M.Linear(self.stage_out_channels[-1], num_classes, bias=True)) 172 | self._initialize_weights() 173 | 174 | def forward(self, x): 175 | x = self.first_conv(x) 176 | x = self.maxpool(x) 177 | x = self.features(x) 178 | x = self.conv_last(x) 179 | 180 | x = self.globalpool(x) 181 | if self.model_size == "2.0x": 182 | x = self.dropout(x) 183 | x = x.reshape(-1, self.stage_out_channels[-1]) 184 | x = self.classifier(x) 185 | return x 186 | 187 | def _initialize_weights(self): 188 | for name, m in self.named_modules(): 189 | if isinstance(m, M.Conv2d): 190 | if "first" in name: 191 | M.init.normal_(m.weight, 0, 0.01) 192 | else: 193 | M.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) 194 | if m.bias is not None: 195 | M.init.fill_(m.bias, 0) 196 | elif isinstance(m, M.BatchNorm2d): 197 | M.init.fill_(m.weight, 1) 198 | if m.bias is not None: 199 | M.init.fill_(m.bias, 0.0001) 200 | M.init.fill_(m.running_mean, 0) 201 | elif isinstance(m, M.BatchNorm1d): 202 | M.init.fill_(m.weight, 1) 203 | if m.bias is not None: 204 | M.init.fill_(m.bias, 0.0001) 205 | M.init.fill_(m.running_mean, 0) 206 | elif isinstance(m, M.Linear): 207 | M.init.normal_(m.weight, 0, 0.01) 208 | if m.bias is not None: 209 | M.init.fill_(m.bias, 0) 210 | 211 | 212 | @hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_2.0x_wnet1x_M2G2.model") 213 | def shufflenet_v2_x2_0(num_classes=1000): 214 | return ShuffleNetV2(num_classes=num_classes, model_size="2.0x") 215 | 216 | 217 | 
@hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_1.5x_wnet1x_M2G2.model") 218 | def shufflenet_v2_x1_5(num_classes=1000): 219 | return ShuffleNetV2(num_classes=num_classes, model_size="1.5x") 220 | 221 | 222 | @hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_1.0x_wnet1x_M2G2.model") 223 | def shufflenet_v2_x1_0(num_classes=1000): 224 | return ShuffleNetV2(num_classes=num_classes, model_size="1.0x") 225 | 226 | 227 | @hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_0.5x_wnet1x_M2G2.model") 228 | def shufflenet_v2_x0_5(num_classes=1000): 229 | return ShuffleNetV2(num_classes=num_classes, model_size="0.5x") 230 | 231 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") 3 | # 4 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 5 | # 6 | # Unless required by applicable law or agreed to in writing, 7 | # software distributed under the License is distributed on an 8 | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | import argparse 10 | import multiprocessing as mp 11 | import time 12 | 13 | import megengine as mge 14 | import megengine.data as data 15 | import megengine.data.transform as T 16 | import megengine.distributed as dist 17 | import megengine.functional as F 18 | import megengine.jit as jit 19 | 20 | import shufflenet_v2 as M 21 | 22 | logger = mge.get_logger(__name__) 23 | 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("-a", "--arch", default="shufflenet_v2_x1_0", type=str) 28 | parser.add_argument("-d", "--data", default=None, type=str) 29 | parser.add_argument("-m", "--model", default=None, type=str) 30 | 31 | parser.add_argument("-n", "--ngpus", default=None, type=int) 32 | parser.add_argument("-w", "--workers", default=4, type=int) 33 | parser.add_argument("--report-freq", default=50, type=int) 34 | args = parser.parse_args() 35 | 36 | world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus 37 | 38 | if world_size > 1: 39 | # start distributed training, dispatch sub-processes 40 | mp.set_start_method("spawn") 41 | processes = [] 42 | for rank in range(world_size): 43 | p = mp.Process(target=worker, args=(rank, world_size, args)) 44 | p.start() 45 | processes.append(p) 46 | 47 | for p in processes: 48 | p.join() 49 | else: 50 | worker(0, 1, args) 51 | 52 | 53 | def worker(rank, world_size, args): 54 | if world_size > 1: 55 | # Initialize distributed process group 56 | logger.info("init distributed process group {} / {}".format(rank, world_size)) 57 | dist.init_process_group( 58 | master_ip="localhost", 59 | master_port=23456, 60 | world_size=world_size, 61 | rank=rank, 62 | dev=rank, 63 | ) 64 | 65 | model = getattr(M, args.arch)(pretrained=(args.model is None)) 66 | if args.model: 67 | logger.info("load weights from %s", args.model) 68 | model.load_state_dict(mge.load(args.model)) 69 | 70 | @jit.trace(symbolic=True) 71 | def valid_func(image, label): 72 | model.eval() 73 | logits = model(image) 
74 | loss = F.cross_entropy_with_softmax(logits, label) 75 | acc1, acc5 = F.accuracy(logits, label, (1, 5)) 76 | if dist.is_distributed(): # all_reduce_mean 77 | loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size() 78 | acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size() 79 | acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size() 80 | return loss, acc1, acc5 81 | 82 | logger.info("preparing dataset..") 83 | valid_dataset = data.dataset.ImageNet(args.data, train=False) 84 | valid_sampler = data.SequentialSampler( 85 | valid_dataset, batch_size=100, drop_last=False 86 | ) 87 | valid_queue = data.DataLoader( 88 | valid_dataset, 89 | sampler=valid_sampler, 90 | transform=T.Compose( 91 | [ 92 | T.Resize(256), 93 | T.CenterCrop(224), 94 | T.ToMode("CHW"), 95 | ] 96 | ), 97 | num_workers=args.workers, 98 | ) 99 | _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) 100 | logger.info("Valid %.3f / %.3f", valid_acc, valid_acc5) 101 | logger.info("TOTAL TEST: loss=%f,\tTop-1 err = %f,\tTop-5 err = %f", _, 1-valid_acc/100, 1-valid_acc5/100) 102 | 103 | def infer(model, data_queue, args): 104 | objs = AverageMeter("Loss") 105 | top1 = AverageMeter("Acc@1") 106 | top5 = AverageMeter("Acc@5") 107 | total_time = AverageMeter("Time") 108 | 109 | t = time.time() 110 | for step, (image, label) in enumerate(data_queue): 111 | n = image.shape[0] 112 | image = image.astype("float32") # convert np.uint8 to float32 113 | label = label.astype("int32") 114 | 115 | loss, acc1, acc5 = model(image, label) 116 | 117 | objs.update(loss.numpy()[0], n) 118 | top1.update(100 * acc1.numpy()[0], n) 119 | top5.update(100 * acc5.numpy()[0], n) 120 | total_time.update(time.time() - t) 121 | t = time.time() 122 | 123 | if step % args.report_freq == 0 and dist.get_rank() == 0: 124 | logger.info( 125 | "Step %d, %s %s %s %s", 126 | step, 127 | objs, 128 | top1, 129 | top5, 130 | total_time, 131 | ) 132 | 133 | return objs.avg, top1.avg, 
top5.avg 134 | 135 | class AverageMeter: 136 | """Computes and stores the average and current value""" 137 | 138 | def __init__(self, name, fmt=":.3f"): 139 | self.name = name 140 | self.fmt = fmt 141 | self.reset() 142 | 143 | def reset(self): 144 | self.val = 0 145 | self.avg = 0 146 | self.sum = 0 147 | self.count = 0 148 | 149 | def update(self, val, n=1): 150 | self.val = val 151 | self.sum += val * n 152 | self.count += n 153 | self.avg = self.sum / self.count 154 | 155 | def __str__(self): 156 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 157 | return fmtstr.format(**self.__dict__) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2019 Megvii Technology 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # 24 | # ------------------------------------------------------------------------------ 25 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") 26 | # 27 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 28 | # 29 | # Unless required by applicable law or agreed to in writing, 30 | # software distributed under the License is distributed on an 31 | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | # 33 | # This file has been modified by Megvii ("Megvii Modifications"). 34 | # All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. 35 | # ------------------------------------------------------------------------------ 36 | import argparse 37 | import multiprocessing as mp 38 | import os 39 | import time 40 | 41 | import megengine as mge 42 | import megengine.data as data 43 | import megengine.data.transform as T 44 | import megengine.distributed as dist 45 | import megengine.functional as F 46 | import megengine.jit as jit 47 | import megengine.optimizer as optim 48 | 49 | import shufflenet_v2 as M 50 | 51 | logger = mge.get_logger(__name__) 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("-a", "--arch", default="shufflenet_v2_x0_5", type=str) 57 | parser.add_argument("-d", "--data", default=None, type=str) 58 | parser.add_argument("-s", "--save", default="./models", type=str) 59 | parser.add_argument("-m", "--model", default=None, type=str) 60 | 61 | parser.add_argument("-b", "--batch-size", default=128, type=int) 62 | parser.add_argument("--learning-rate", default=0.0625, type=float) 63 | parser.add_argument("--momentum", 
default=0.9, type=float) 64 | parser.add_argument("--weight-decay", default=4e-5, type=float) 65 | parser.add_argument("--steps", default=300000, type=int) 66 | 67 | parser.add_argument("-n", "--ngpus", default=None, type=int) 68 | parser.add_argument("-w", "--workers", default=4, type=int) 69 | parser.add_argument("--report-freq", default=50, type=int) 70 | args = parser.parse_args() 71 | 72 | save_dir = os.path.join(args.save, args.arch) 73 | if not os.path.exists(save_dir): 74 | os.makedirs(save_dir) 75 | mge.set_log_file(os.path.join(save_dir, "log.txt")) 76 | 77 | world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus 78 | 79 | if world_size > 1: 80 | # scale learning rate by number of gpus 81 | args.learning_rate *= world_size 82 | # start distributed training, dispatch sub-processes 83 | mp.set_start_method("spawn") 84 | processes = [] 85 | for rank in range(world_size): 86 | p = mp.Process(target=worker, args=(rank, world_size, args)) 87 | p.start() 88 | processes.append(p) 89 | 90 | for p in processes: 91 | p.join() 92 | else: 93 | worker(0, 1, args) 94 | 95 | 96 | def get_parameters(model): 97 | group_no_weight_decay = [] 98 | group_weight_decay = [] 99 | for pname, p in model.named_parameters(requires_grad=True): 100 | if pname.find("weight") >= 0 and len(p.shape) > 1: 101 | # print("include ", pname, p.shape) 102 | group_weight_decay.append(p) 103 | else: 104 | # print("not include ", pname, p.shape) 105 | group_no_weight_decay.append(p) 106 | assert len(list(model.parameters())) == len(group_weight_decay) + len( 107 | group_no_weight_decay 108 | ) 109 | groups = [ 110 | dict(params=group_weight_decay), 111 | dict(params=group_no_weight_decay, weight_decay=0.0), 112 | ] 113 | return groups 114 | 115 | 116 | def worker(rank, world_size, args): 117 | # pylint: disable=too-many-statements 118 | mge.set_log_file(os.path.join(args.save, args.arch, "log.txt")) 119 | 120 | if world_size > 1: 121 | # Initialize distributed process group 
122 | logger.info("init distributed process group {} / {}".format(rank, world_size)) 123 | dist.init_process_group( 124 | master_ip="localhost", 125 | master_port=23456, 126 | world_size=world_size, 127 | rank=rank, 128 | dev=rank, 129 | ) 130 | 131 | save_dir = os.path.join(args.save, args.arch) 132 | 133 | model = getattr(M, args.arch)() 134 | step_start = 0 135 | if args.model: 136 | logger.info("load weights from %s", args.model) 137 | model.load_state_dict(mge.load(args.model)) 138 | step_start = int(args.model.split("-")[1].split(".")[0]) 139 | 140 | optimizer = optim.SGD( 141 | get_parameters(model), 142 | lr=args.learning_rate, 143 | momentum=args.momentum, 144 | weight_decay=args.weight_decay, 145 | ) 146 | 147 | # Define train and valid graph 148 | @jit.trace(symbolic=True) 149 | def train_func(image, label): 150 | model.train() 151 | logits = model(image) 152 | loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) 153 | acc1, acc5 = F.accuracy(logits, label, (1, 5)) 154 | optimizer.backward(loss) # compute gradients 155 | if dist.is_distributed(): # all_reduce_mean 156 | loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size() 157 | acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size() 158 | acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size() 159 | return loss, acc1, acc5 160 | 161 | @jit.trace(symbolic=True) 162 | def valid_func(image, label): 163 | model.eval() 164 | logits = model(image) 165 | loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) 166 | acc1, acc5 = F.accuracy(logits, label, (1, 5)) 167 | if dist.is_distributed(): # all_reduce_mean 168 | loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size() 169 | acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size() 170 | acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size() 171 | return loss, acc1, acc5 172 | 173 | # Build train and valid datasets 174 | 
logger.info("preparing dataset..") 175 | train_dataset = data.dataset.ImageNet(args.data, train=True) 176 | train_sampler = data.Infinite(data.RandomSampler( 177 | train_dataset, batch_size=args.batch_size, drop_last=True 178 | )) 179 | train_queue = data.DataLoader( 180 | train_dataset, 181 | sampler=train_sampler, 182 | transform=T.Compose( 183 | [ 184 | T.RandomResizedCrop(224), 185 | T.RandomHorizontalFlip(), 186 | T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 187 | T.ToMode("CHW"), 188 | ] 189 | ), 190 | num_workers=args.workers, 191 | ) 192 | 193 | valid_dataset = data.dataset.ImageNet(args.data, train=False) 194 | valid_sampler = data.SequentialSampler( 195 | valid_dataset, batch_size=100, drop_last=False 196 | ) 197 | valid_queue = data.DataLoader( 198 | valid_dataset, 199 | sampler=valid_sampler, 200 | transform=T.Compose( 201 | [ 202 | T.Resize(256), 203 | T.CenterCrop(224), 204 | T.ToMode("CHW"), 205 | ] 206 | ), 207 | num_workers=args.workers, 208 | ) 209 | 210 | # Start training 211 | objs = AverageMeter("Loss") 212 | top1 = AverageMeter("Acc@1") 213 | top5 = AverageMeter("Acc@5") 214 | total_time = AverageMeter("Time") 215 | 216 | t = time.time() 217 | for step in range(step_start, args.steps + 1): 218 | # Linear learning rate decay 219 | decay = 1.0 220 | decay = 1 - float(step) / args.steps if step < args.steps else 0 221 | for param_group in optimizer.param_groups: 222 | param_group["lr"] = args.learning_rate * decay 223 | 224 | image, label = next(train_queue) 225 | time_data=time.time()-t 226 | image = image.astype("float32") 227 | label = label.astype("int32") 228 | 229 | n = image.shape[0] 230 | 231 | optimizer.zero_grad() 232 | loss, acc1, acc5 = train_func(image, label) 233 | optimizer.step() 234 | 235 | top1.update(100 * acc1.numpy()[0], n) 236 | top5.update(100 * acc5.numpy()[0], n) 237 | objs.update(loss.numpy()[0], n) 238 | total_time.update(time.time() - t) 239 | time_iter=time.time()-t 240 | t = time.time() 241 | if step 
% args.report_freq == 0 and rank == 0: 242 | logger.info( 243 | "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f", 244 | step, 245 | args.learning_rate * decay, 246 | float(objs.__str__().split()[1]), 247 | 1-float(top1.__str__().split()[1])/100, 248 | 1-float(top5.__str__().split()[1])/100, 249 | time_data, 250 | time_iter - time_data, 251 | time_iter * (args.steps - step) / 3600, 252 | ) 253 | objs.reset() 254 | top1.reset() 255 | top5.reset() 256 | total_time.reset() 257 | if step % 10000 == 0 and rank == 0 and step != 0: 258 | logger.info("SAVING %06d", step) 259 | mge.save( 260 | model.state_dict(), 261 | os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)), 262 | ) 263 | if step % 50000 == 0 and step != 0: 264 | _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) 265 | logger.info("TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f", step, _, 1-valid_acc/100, 1-valid_acc5/100) 266 | 267 | mge.save( 268 | model.state_dict(), os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)) 269 | ) 270 | _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) 271 | logger.info("TEST Iter %06d: loss=%f,\tTop-1 err = %f,\tTop-5 err = %f", step, _, 1-valid_acc/100, 1-valid_acc5/100) 272 | 273 | 274 | def infer(model, data_queue, args): 275 | objs = AverageMeter("Loss") 276 | top1 = AverageMeter("Acc@1") 277 | top5 = AverageMeter("Acc@5") 278 | total_time = AverageMeter("Time") 279 | 280 | t = time.time() 281 | for step, (image, label) in enumerate(data_queue): 282 | n = image.shape[0] 283 | image = image.astype("float32") # convert np.uint8 to float32 284 | label = label.astype("int32") 285 | 286 | loss, acc1, acc5 = model(image, label) 287 | 288 | objs.update(loss.numpy()[0], n) 289 | top1.update(100 * acc1.numpy()[0], n) 290 | top5.update(100 * acc5.numpy()[0], n) 291 | total_time.update(time.time() - t) 292 | t = time.time() 293 | 294 | if 
step % args.report_freq == 0 and dist.get_rank() == 0: 295 | logger.info( 296 | "Step %d, %s %s %s %s", 297 | step, 298 | objs, 299 | top1, 300 | top5, 301 | total_time, 302 | ) 303 | 304 | return objs.avg, top1.avg, top5.avg 305 | 306 | 307 | 308 | class AverageMeter: 309 | """Track the most recent value and the running, count-weighted average of a metric.""" 310 | 311 | def __init__(self, name, fmt=":.3f"): 312 | self.name = name 313 | self.fmt = fmt # format spec applied to val and avg in __str__ 314 | self.reset() 315 | 316 | def reset(self): # zero all statistics (the train loop calls this after each report) 317 | self.val = 0 318 | self.avg = 0 319 | self.sum = 0 320 | self.count = 0 321 | 322 | def update(self, val, n=1): # record val as observed n times (callers here pass n = image.shape[0]) 323 | self.val = val 324 | self.sum += val * n 325 | self.count += n 326 | self.avg = self.sum / self.count # NOTE(review): raises ZeroDivisionError if the first update uses n=0 327 | 328 | def __str__(self): 329 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" # e.g. 'Loss 1.234 (1.500)' with the default fmt; the train loop parses split()[1] of this string, so keep this layout 330 | return fmtstr.format(**self.__dict__) 331 | 332 | 333 | if __name__ == "__main__": 334 | main() 335 | -------------------------------------------------------------------------------- /weightnet.py: -------------------------------------------------------------------------------- 1 | import megengine.functional as F 2 | import megengine.module as M 3 | 4 | class WeightNet(M.Module): 5 | r"""Applies WeightNet to a standard convolution. 6 | 7 | The grouped fc layer directly generates the convolutional kernel: 8 | this layer has M*oup inputs, G*oup groups and oup*inp*ksize*ksize outputs. 9 | 10 | M and G (both set to 2 in __init__) control the number of parameters. 
forward(x, x_gap) expects x with inp channels and x_gap with max(16, inp//16) channels and the same batch size (presumably a globally pooled (N, C, 1, 1) tensor; confirm against the caller), and returns a tensor with oup channels.
11 | """ 12 | 13 | def __init__(self, inp, oup, ksize, stride): 14 | super().__init__() 15 | 16 | self.M = 2 17 | self.G = 2 18 | 19 | self.pad = ksize // 2 # keeps H and W unchanged for odd ksize when stride is 1 20 | inp_gap = max(16, inp//16) # channel count of x_gap, i.e. the input of wn_fc1 21 | self.inp = inp 22 | self.oup = oup 23 | self.ksize = ksize 24 | self.stride = stride 25 | 26 | self.wn_fc1 = M.Conv2d(inp_gap, self.M*oup, 1, 1, 0, groups=1, bias=True) 27 | self.sigmoid = M.Sigmoid() 28 | self.wn_fc2 = M.Conv2d(self.M*oup, oup*inp*ksize*ksize, 1, 1, 0, groups=self.G*oup, bias=False) # grouped 1x1 conv emitting oup*inp*ksize*ksize kernel values per sample 29 | 30 | 31 | def forward(self, x, x_gap): 32 | x_w = self.wn_fc1(x_gap) 33 | x_w = self.sigmoid(x_w) 34 | x_w = self.wn_fc2(x_w) 35 | 36 | if x.shape[0] == 1: # case of batch size = 1: use the generated kernel directly as the conv weight 37 | x_w = x_w.reshape(self.oup, self.inp, self.ksize, self.ksize) 38 | x = F.conv2d(x, weight=x_w, stride=self.stride, padding=self.pad) 39 | return x 40 | 41 | x = x.reshape(1, -1, x.shape[2], x.shape[3]) # fold the batch into channels: (1, N*inp, H, W) 42 | x_w = x_w.reshape(-1, self.oup, self.inp, self.ksize, self.ksize) # one (oup, inp, ksize, ksize) kernel per sample 43 | x = F.conv2d(x, weight=x_w, stride=self.stride, padding=self.pad, groups=x_w.shape[0]) # one group per sample, so each sample is convolved with its own kernel 44 | x = x.reshape(-1, self.oup, x.shape[2], x.shape[3]) # unfold back to (N, oup, H', W') 45 | return x 46 | 47 | class WeightNet_DW(M.Module): 48 | r"""Applies WeightNet to a depthwise convolution, illustrating the grouping scheme used in that case. 49 | 50 | The grouped fc layer directly generates the convolutional kernel; this grouping has fewer parameters while achieving comparable results. 51 | This layer has M//G*inp inputs, inp groups and inp*ksize*ksize outputs. 
52 | forward(x, x_gap) expects x with inp channels and x_gap with max(16, inp//16) channels and the same batch size, and returns a tensor with inp channels. 53 | """ 54 | def __init__(self, inp, ksize, stride): 55 | super().__init__() 56 | 57 | self.M = 2 58 | self.G = 2 59 | 60 | self.pad = ksize // 2 # keeps H and W unchanged for odd ksize when stride is 1 61 | inp_gap = max(16, inp//16) # channel count of x_gap, i.e. the input of wn_fc1 62 | self.inp = inp 63 | self.ksize = ksize 64 | self.stride = stride 65 | 66 | self.wn_fc1 = M.Conv2d(inp_gap, self.M//self.G*inp, 1, 1, 0, groups=1, bias=True) 67 | self.sigmoid = M.Sigmoid() 68 | self.wn_fc2 = M.Conv2d(self.M//self.G*inp, inp*ksize*ksize, 1, 1, 0, groups=inp, bias=False) # groups=inp: each channel's ksize*ksize kernel comes from its own M//G hidden units 69 | 70 | 71 | def forward(self, x, x_gap): 72 | x_w = self.wn_fc1(x_gap) 73 | x_w = self.sigmoid(x_w) 74 | x_w = self.wn_fc2(x_w) 75 | 76 | x = x.reshape(1, -1, x.shape[2], x.shape[3]) # fold the batch into channels: (1, N*inp, H, W) 77 | x_w = x_w.reshape(-1, 1, 1, self.ksize, self.ksize) # one (1, 1, ksize, ksize) kernel per (sample, channel) pair 78 | x = F.conv2d(x, weight=x_w, stride=self.stride, padding=self.pad, groups=x_w.shape[0]) # depthwise: one group per (sample, channel) 79 | x = x.reshape(-1, self.inp, x.shape[2], x.shape[3]) # unfold back to (N, inp, H', W') 80 | return x 81 | --------------------------------------------------------------------------------