├── LICENSE
├── README.md
├── figures
├── result.png
└── weightnet.png
├── hubconf.py
├── inference.py
├── shufflenet_v2.py
├── test.py
├── train.py
└── weightnet.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 megvii-model
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WeightNet
2 | This repository provides the MegEngine implementation of "[WeightNet: Revisiting the Design Space of Weight Networks](https://arxiv.org/pdf/2007.11823.pdf)".
3 |
4 |
5 |
6 |
7 | ## Requirement
8 | - MegEngine 0.5.1 (https://github.com/MegEngine/MegEngine)
9 |
10 |
11 | ## Citation
12 | If you use these models in your research, please cite:
13 |
14 |
15 | @inproceedings{ma2020weightnet,
16 | title={WeightNet: Revisiting the Design Space of Weight Networks},
17 | author={Ma, Ningning and Zhang, Xiangyu and Huang, Jiawei and Sun, Jian},
18 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
19 | year={2020}
20 | }
21 |
22 | ## Usage
23 | Train:
24 | ```
25 | python3 train.py --data=/path/to/imagenet
26 | ```
27 |
28 | Eval:
29 | ```
30 | python3 test.py --data=/path/to/imagenet --model /path/to/model --ngpus 1
31 | ```
32 |
33 | Inference:
34 | ```
35 | python3 inference.py --model /path/to/model --image /path/to/image.jpg
36 | ```
37 |
38 |
39 | ## Trained Models
40 | - OneDrive download: [Link](https://1drv.ms/u/s!AgaP37NGYuEXhVa4o5xbveef89Ba?e=Xvg6Vo)
41 |
42 | ## Results
43 |
44 |
45 |
46 |
47 |
48 | - Comparison under the same #Params and the same FLOPs.
49 |
50 |
51 | | Model | #Params. | FLOPs | Top-1 err. |
52 | |---------------------|----------|-------|------------|
53 | | ShuffleNetV2 (0.5×) | 1.4M | 41M | 39.7 |
54 | | + WeightNet (1×) | 1.5M | 41M | **36.7** |
55 | | ShuffleNetV2 (1.0×) | 2.2M | 138M | 30.9 |
56 | | + WeightNet (1×) | 2.4M | 139M | **28.8** |
57 | | ShuffleNetV2 (1.5×) | 3.5M | 299M | 27.4 |
58 | | + WeightNet (1×) | 3.9M | 301M | **25.6** |
59 | | ShuffleNetV2 (2.0×) | 5.5M | 557M | 25.5 |
60 | | + WeightNet (1×) | 6.1M | 562M | **24.1** |
61 |
62 |
63 | - Comparison under the same FLOPs.
64 |
65 |
66 | | Model | #Params. | FLOPs | Top-1 err. |
67 | |---------------------|----------|-------|------------|
68 | | ShuffleNetV2 (0.5×) | 1.4M | 41M | 39.7 |
69 | | + WeightNet (8×) | 2.7M | 42M | **34.0** |
70 | | ShuffleNetV2 (1.0×) | 2.2M | 138M | 30.9 |
71 | | + WeightNet (4×) | 5.1M | 141M | **27.6** |
72 | | ShuffleNetV2 (1.5×) | 3.5M | 299M | 27.4 |
73 | | + WeightNet (4×) | 9.6M | 307M | **25.0** |
74 | | ShuffleNetV2 (2.0×) | 5.5M | 557M | 25.5 |
75 | | + WeightNet (4×) | 18.1M | 573M | **23.5** |
76 |
--------------------------------------------------------------------------------
/figures/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/megvii-model/WeightNet/669b5f4c0c46fd30cd0fedf5e5a63161e9e94bcc/figures/result.png
--------------------------------------------------------------------------------
/figures/weightnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/megvii-model/WeightNet/669b5f4c0c46fd30cd0fedf5e5a63161e9e94bcc/figures/weightnet.png
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | from shufflenet_v2 import (
2 | shufflenet_v2_x0_5,
3 | shufflenet_v2_x1_0,
4 | shufflenet_v2_x1_5,
5 | shufflenet_v2_x2_0,
6 | )
7 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3 | #
4 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
5 | #
6 | # Unless required by applicable law or agreed to in writing,
7 | # software distributed under the License is distributed on an
8 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | import argparse
10 | import json
11 |
12 | import cv2
13 | import megengine as mge
14 | import megengine.data.transform as T
15 | import megengine.functional as F
16 | import megengine.jit as jit
17 | import numpy as np
18 |
19 | import shufflenet_v2 as M
20 |
21 |
def main():
    """Classify one image and print the five most probable ImageNet classes.

    Command-line options:
        -a/--arch:  name of a model factory defined in ``shufflenet_v2``.
        -m/--model: path to a weights file. When omitted, the factory is
            called with ``pretrained=True`` instead.
        -i/--image: path of the image to classify. When omitted, a sample
            image path inside the MegEngine Models assets is used.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image at the chosen path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="shufflenet_v2_x1_0", type=str)
    parser.add_argument("-m", "--model", default=None, type=str)
    parser.add_argument("-i", "--image", default=None, type=str)
    args = parser.parse_args()

    # Only ask for pretrained weights when no explicit weights file was given.
    model = getattr(M, args.arch)(pretrained=(args.model is None))
    if args.model:
        state_dict = mge.load(args.model)
        model.load_state_dict(state_dict)

    if args.image is None:
        path = "../../../assets/cat.jpg"  # please find the files in https://github.com/MegEngine/Models/tree/master/official/assets
    else:
        path = args.image
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an unrelated error inside the transform pipeline.
        raise FileNotFoundError("cannot read image: {}".format(path))

    transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToMode("CHW"),
        ]
    )

    @jit.trace(symbolic=True)
    def infer_func(processed_img):
        # Traced forward pass in eval mode; returns softmax probabilities.
        model.eval()
        logits = model(processed_img)
        probs = F.softmax(logits)
        return probs

    # Add a leading batch axis of size 1 to the transformed image.
    processed_img = transform.apply(image)[np.newaxis, :]
    probs = infer_func(processed_img)

    top_probs, classes = F.top_k(probs, k=5, descending=True)

    with open("../../../assets/imagenet_class_info.json") as fp:  # please find the files in https://github.com/MegEngine/Models/tree/master/official/assets
        imagenet_class_index = json.load(fp)

    for rank, (prob, classid) in enumerate(
        zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1))
    ):
        print(
            "{}: class = {:20s} with probability = {:4.1f} %".format(
                rank, imagenet_class_index[str(classid)][1], 100 * prob
            )
        )
71 |
72 |
73 | if __name__ == "__main__":
74 | main()
75 |
--------------------------------------------------------------------------------
/shufflenet_v2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # MIT License
3 | #
4 | # Copyright (c) 2019 Megvii Technology
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | #
24 | # ------------------------------------------------------------------------------
25 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
26 | #
27 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
28 | #
29 | # Unless required by applicable law or agreed to in writing,
30 | # software distributed under the License is distributed on an
31 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32 | #
33 | # This file has been modified by Megvii ("Megvii Modifications").
34 | # All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
35 | # ------------------------------------------------------------------------------
36 | import megengine.functional as F
37 | import megengine.hub as hub
38 | import megengine.module as M
39 |
40 | from weightnet import WeightNet, WeightNet_DW
41 |
class ShuffleV2Block(M.Module):
    """ShuffleNetV2 unit whose convolutions are ``WeightNet`` / ``WeightNet_DW`` layers.

    Args:
        inp: channel count fed to the main branch.
        oup: channel count of the concatenated block output; the main branch
            produces ``oup - inp`` of them.
        mid_channels: channel count between the first and last main-branch layers.
        ksize: kernel size handed to the ``WeightNet_DW`` layers.
        stride: 1 (input is split in two by ``channel_shuffle``) or 2 (both
            branches receive the whole input and a projection branch is built).
    """

    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.ksize = ksize
        self.pad = ksize // 2
        self.inp = inp

        # The projection branch contributes ``inp`` channels to the concat,
        # so the main branch only has to produce the remainder.
        main_out = oup - inp

        # 1x1 conv applied to the globally pooled input; its output is passed
        # as the second argument of every WeightNet layer in ``forward``.
        self.reduce = M.Conv2d(inp, max(16, inp // 16), 1, 1, 0, bias=True)

        self.wnet1 = WeightNet(inp, mid_channels, 1, 1)
        self.bn1 = M.BatchNorm2d(mid_channels)

        self.wnet2 = WeightNet_DW(mid_channels, ksize, stride)
        self.bn2 = M.BatchNorm2d(mid_channels)

        self.wnet3 = WeightNet(mid_channels, main_out, 1, 1)
        self.bn3 = M.BatchNorm2d(main_out)

        if stride == 2:
            self.wnet_proj_1 = WeightNet_DW(inp, ksize, stride)
            self.bn_proj_1 = M.BatchNorm2d(inp)

            self.wnet_proj_2 = WeightNet(inp, inp, 1, 1)
            self.bn_proj_2 = M.BatchNorm2d(inp)

    def forward(self, old_x):
        """Run both branches and concatenate them along axis 1."""
        if self.stride == 1:
            x_proj, x = self.channel_shuffle(old_x)
        elif self.stride == 2:
            x_proj = x = old_x

        # Average over axes 2 and 3 (kept as size-1 axes), then reduce.
        pooled = x.mean(axis=2, keepdims=True)
        pooled = pooled.mean(axis=3, keepdims=True)
        x_gap = self.reduce(pooled)

        # Main branch: wnet1 -> bn -> relu -> wnet2 -> bn -> wnet3 -> bn -> relu.
        x = F.relu(self.bn1(self.wnet1(x, x_gap)))
        x = self.bn2(self.wnet2(x, x_gap))
        x = F.relu(self.bn3(self.wnet3(x, x_gap)))

        if self.stride == 2:
            # Projection branch reuses the same x_gap as the main branch.
            x_proj = self.bn_proj_1(self.wnet_proj_1(x_proj, x_gap))
            x_proj = self.bn_proj_2(self.wnet_proj_2(x_proj, x_gap))
            x_proj = F.relu(x_proj)

        return F.concat((x_proj, x), 1)

    def channel_shuffle(self, x):
        """Split ``x`` into its even-indexed and odd-indexed channels.

        ``x.shape`` is unpacked as (batch, channels, height, width); each
        returned tensor has ``channels // 2`` channels.
        """
        batchsize, num_channels, height, width = x.shape
        # assert (num_channels % 4 == 0)
        pairs = x.reshape(batchsize * num_channels // 2, 2, height * width)
        # Bring the "which element of the pair" axis to the front.
        swapped = pairs.dimshuffle(1, 0, 2)
        halves = swapped.reshape(2, -1, num_channels // 2, height, width)
        return halves[0], halves[1]
112 |
113 |
class ShuffleNetV2(M.Module):
    """ShuffleNetV2 classifier assembled from ``ShuffleV2Block`` units.

    Layout: 3x3 stride-2 conv + max-pool, three stages of blocks, a 1x1
    ``conv_last``, a 7x7 average pool and a linear classifier.

    Args:
        input_size: accepted by the constructor but not read anywhere in
            this class.
        num_classes: output width of the final linear layer.
        model_size: one of ``"0.5x"``, ``"1.0x"``, ``"1.5x"``, ``"2.0x"``;
            anything else raises ``NotImplementedError``.
    """

    def __init__(self, input_size=224, num_classes=1000, model_size="1.5x"):
        super().__init__()

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # We reduce the width slightly here to make WeightNet's FLOPs comparable to the baselines.
        width_table = (
            ("0.5x", (-1, 24, 48, 96, 192, 1024)),
            ("1.0x", (-1, 24, 112, 224, 448, 1024)),
            ("1.5x", (-1, 24, 176, 352, 704, 1024)),
            ("2.0x", (-1, 24, 248, 496, 992, 1024)),
        )
        widths = next(
            (channels for size, channels in width_table if model_size == size),
            None,
        )
        if widths is None:
            raise NotImplementedError
        self.stage_out_channels = list(widths)

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = M.Sequential(
            M.Conv2d(3, input_channel, 3, 2, 1, bias=True),
            M.BatchNorm2d(input_channel),
            M.ReLU(),
        )

        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.features = []
        for stage_idx, num_repeat in enumerate(self.stage_repeats):
            output_channel = self.stage_out_channels[stage_idx + 2]

            for block_idx in range(num_repeat):
                # The first block of a stage uses stride 2 and sees the full
                # input; later blocks use stride 1 and only process half of
                # the channels in their main branch.
                opens_stage = block_idx == 0
                self.features.append(
                    ShuffleV2Block(
                        input_channel if opens_stage else input_channel // 2,
                        output_channel,
                        mid_channels=output_channel // 2,
                        ksize=3,
                        stride=2 if opens_stage else 1,
                    )
                )
                input_channel = output_channel

        self.features = M.Sequential(*self.features)

        last_channel = self.stage_out_channels[-1]
        self.conv_last = M.Sequential(
            M.Conv2d(input_channel, last_channel, 1, 1, 0, bias=True),
            M.BatchNorm2d(last_channel),
            M.ReLU(),
        )
        # NOTE(review): the fixed 7x7 window presumably matches a 224x224
        # input after the network's downsampling -- confirm before feeding
        # other resolutions.
        self.globalpool = M.AvgPool2d(7)
        if self.model_size == "2.0x":
            self.dropout = M.Dropout(0.2)
        self.classifier = M.Sequential(M.Linear(last_channel, num_classes, bias=True))
        self._initialize_weights()

    def forward(self, x):
        """Return the classifier output for the input batch ``x``."""
        x = self.maxpool(self.first_conv(x))
        x = self.conv_last(self.features(x))

        x = self.globalpool(x)
        if self.model_size == "2.0x":
            # The dropout layer only exists for the 2.0x model.
            x = self.dropout(x)
        flat = x.reshape(-1, self.stage_out_channels[-1])
        return self.classifier(flat)

    def _initialize_weights(self):
        """Initialise every Conv2d, BatchNorm and Linear sub-module in place."""
        for name, module in self.named_modules():
            if isinstance(module, M.Conv2d):
                # Convs whose module name contains "first" get a fixed std;
                # the others scale it by the second weight dimension.
                std = 0.01 if "first" in name else 1.0 / module.weight.shape[1]
                M.init.normal_(module.weight, 0, std)
                if module.bias is not None:
                    M.init.fill_(module.bias, 0)
            elif isinstance(module, (M.BatchNorm2d, M.BatchNorm1d)):
                M.init.fill_(module.weight, 1)
                if module.bias is not None:
                    M.init.fill_(module.bias, 0.0001)
                M.init.fill_(module.running_mean, 0)
            elif isinstance(module, M.Linear):
                M.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    M.init.fill_(module.bias, 0)
210 |
211 |
@hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_2.0x_wnet1x_M2G2.model")
def shufflenet_v2_x2_0(num_classes=1000):
    """Build the ``"2.0x"`` ShuffleNetV2 with ``num_classes`` outputs."""
    net = ShuffleNetV2(model_size="2.0x", num_classes=num_classes)
    return net
215 |
216 |
@hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_1.5x_wnet1x_M2G2.model")
def shufflenet_v2_x1_5(num_classes=1000):
    """Build the ``"1.5x"`` ShuffleNetV2 with ``num_classes`` outputs."""
    net = ShuffleNetV2(model_size="1.5x", num_classes=num_classes)
    return net
220 |
221 |
@hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_1.0x_wnet1x_M2G2.model")
def shufflenet_v2_x1_0(num_classes=1000):
    """Build the ``"1.0x"`` ShuffleNetV2 with ``num_classes`` outputs."""
    net = ShuffleNetV2(model_size="1.0x", num_classes=num_classes)
    return net
225 |
226 |
@hub.pretrained("https://data.megengine.org.cn/models/weights/wnet/snetv2_0.5x_wnet1x_M2G2.model")
def shufflenet_v2_x0_5(num_classes=1000):
    """Build the ``"0.5x"`` ShuffleNetV2 with ``num_classes`` outputs."""
    net = ShuffleNetV2(model_size="0.5x", num_classes=num_classes)
    return net
230 |
231 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
3 | #
4 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
5 | #
6 | # Unless required by applicable law or agreed to in writing,
7 | # software distributed under the License is distributed on an
8 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | import argparse
10 | import multiprocessing as mp
11 | import time
12 |
13 | import megengine as mge
14 | import megengine.data as data
15 | import megengine.data.transform as T
16 | import megengine.distributed as dist
17 | import megengine.functional as F
18 | import megengine.jit as jit
19 |
20 | import shufflenet_v2 as M
21 |
22 | logger = mge.get_logger(__name__)
23 |
24 |
def main():
    """Parse the command line and run evaluation on one or several devices.

    With more than one device, one ``worker`` process is spawned per rank and
    all of them are joined; otherwise ``worker`` runs in the current process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="shufflenet_v2_x1_0", type=str)
    parser.add_argument("-d", "--data", default=None, type=str)
    parser.add_argument("-m", "--model", default=None, type=str)

    parser.add_argument("-n", "--ngpus", default=None, type=int)
    parser.add_argument("-w", "--workers", default=4, type=int)
    parser.add_argument("--report-freq", default=50, type=int)
    args = parser.parse_args()

    # An explicit --ngpus wins; otherwise use every visible GPU.
    if args.ngpus is None:
        world_size = mge.get_device_count("gpu")
    else:
        world_size = args.ngpus

    if not world_size > 1:
        worker(0, 1, args)
        return

    # Distributed evaluation: dispatch one sub-process per rank.
    mp.set_start_method("spawn")
    children = []
    for rank in range(world_size):
        child = mp.Process(target=worker, args=(rank, world_size, args))
        child.start()
        children.append(child)

    for child in children:
        child.join()
51 |
52 |
def worker(rank, world_size, args):
    """Evaluate the selected model on the ImageNet validation split.

    Args:
        rank: index of this process; also used as the device index when
            ``world_size > 1``.
        world_size: total number of evaluation processes.
        args: parsed command-line namespace produced by ``main``.
    """
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    # Fall back to pretrained weights only when no weights file was given.
    build_model = getattr(M, args.arch)
    model = build_model(pretrained=(args.model is None))
    if args.model:
        logger.info("load weights from %s", args.model)
        weights = mge.load(args.model)
        model.load_state_dict(weights)

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        # Traced eval-mode forward pass returning loss, top-1 and top-5 accuracy.
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            world = dist.get_world_size()
            loss = dist.all_reduce_sum(loss, "valid_loss") / world
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / world
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / world
        return loss, acc1, acc5

    logger.info("preparing dataset..")
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_transform = T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToMode("CHW"),
        ]
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=valid_transform,
        num_workers=args.workers,
    )

    valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("Valid %.3f / %.3f", valid_acc, valid_acc5)
    # Accuracies come back in percent; report them as error fractions.
    top1_err = 1 - valid_acc / 100
    top5_err = 1 - valid_acc5 / 100
    logger.info("TOTAL TEST: loss=%f,\tTop-1 err = %f,\tTop-5 err = %f", valid_loss, top1_err, top5_err)
102 |
def infer(model, data_queue, args):
    """Run ``model`` on every batch of ``data_queue`` and average the metrics.

    Args:
        model: callable taking ``(image, label)`` and returning
            ``(loss, acc1, acc5)``.
        data_queue: iterable yielding ``(image, label)`` batches.
        args: namespace providing ``report_freq``.

    Returns:
        Tuple ``(average loss, average top-1, average top-5)``, with the
        accuracies scaled by 100.
    """
    loss_meter = AverageMeter("Loss")
    top1_meter = AverageMeter("Acc@1")
    top5_meter = AverageMeter("Acc@5")
    time_meter = AverageMeter("Time")

    tick = time.time()
    for step, batch in enumerate(data_queue):
        image, label = batch
        batch_size = image.shape[0]
        image = image.astype("float32")  # convert np.uint8 to float32
        label = label.astype("int32")

        loss, acc1, acc5 = model(image, label)

        loss_meter.update(loss.numpy()[0], batch_size)
        top1_meter.update(100 * acc1.numpy()[0], batch_size)
        top5_meter.update(100 * acc5.numpy()[0], batch_size)
        time_meter.update(time.time() - tick)
        tick = time.time()

        # Only rank 0 reports, and only every ``report_freq`` steps.
        if step % args.report_freq != 0 or dist.get_rank() != 0:
            continue
        logger.info(
            "Step %d, %s %s %s %s",
            step,
            loss_meter,
            top1_meter,
            top5_meter,
            time_meter,
        )

    return loss_meter.avg, top1_meter.avg, top5_meter.avg
134 |
class AverageMeter:
    """Track the latest value and the weighted running mean of a metric."""

    def __init__(self, name, fmt=":.3f"):
        # ``fmt`` is a format-spec suffix (including the leading colon)
        # applied to both the latest value and the average in ``__str__``.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Expands to e.g. "{name} {val:.3f} ({avg:.3f})" before formatting.
        template = "{{name}} {{val{0}}} ({{avg{0}}})".format(self.fmt)
        return template.format(**vars(self))
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # MIT License
3 | #
4 | # Copyright (c) 2019 Megvii Technology
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 |
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 |
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | #
24 | # ------------------------------------------------------------------------------
25 | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
26 | #
27 | # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
28 | #
29 | # Unless required by applicable law or agreed to in writing,
30 | # software distributed under the License is distributed on an
31 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32 | #
33 | # This file has been modified by Megvii ("Megvii Modifications").
34 | # All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
35 | # ------------------------------------------------------------------------------
36 | import argparse
37 | import multiprocessing as mp
38 | import os
39 | import time
40 |
41 | import megengine as mge
42 | import megengine.data as data
43 | import megengine.data.transform as T
44 | import megengine.distributed as dist
45 | import megengine.functional as F
46 | import megengine.jit as jit
47 | import megengine.optimizer as optim
48 |
49 | import shufflenet_v2 as M
50 |
51 | logger = mge.get_logger(__name__)
52 |
53 |
def main():
    """Parse the command line, prepare the output directory and start training.

    The checkpoint/log directory is ``<save>/<arch>``. With more than one
    device the learning rate is multiplied by the device count and one
    ``worker`` process is spawned per rank; otherwise ``worker`` runs in the
    current process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--arch", default="shufflenet_v2_x0_5", type=str)
    parser.add_argument("-d", "--data", default=None, type=str)
    parser.add_argument("-s", "--save", default="./models", type=str)
    parser.add_argument("-m", "--model", default=None, type=str)

    parser.add_argument("-b", "--batch-size", default=128, type=int)
    parser.add_argument("--learning-rate", default=0.0625, type=float)
    parser.add_argument("--momentum", default=0.9, type=float)
    parser.add_argument("--weight-decay", default=4e-5, type=float)
    parser.add_argument("--steps", default=300000, type=int)

    parser.add_argument("-n", "--ngpus", default=None, type=int)
    parser.add_argument("-w", "--workers", default=4, type=int)
    parser.add_argument("--report-freq", default=50, type=int)
    args = parser.parse_args()

    save_dir = os.path.join(args.save, args.arch)
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the same directory in between.
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    # An explicit --ngpus wins; otherwise use every visible GPU.
    world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus

    if world_size > 1:
        # scale learning rate by number of gpus
        args.learning_rate *= world_size
        # start distributed training, dispatch sub-processes
        mp.set_start_method("spawn")
        processes = []
        for rank in range(world_size):
            p = mp.Process(target=worker, args=(rank, world_size, args))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
    else:
        worker(0, 1, args)
94 |
95 |
def get_parameters(model):
    """Split the trainable parameters into weight-decay and no-decay groups.

    A parameter is decayed when its name contains ``"weight"`` and it has
    more than one dimension; every other parameter goes into a group with
    ``weight_decay=0.0``.

    Returns:
        A two-element list of optimizer parameter-group dicts, decayed
        group first.
    """
    decayed, undecayed = [], []
    for name, param in model.named_parameters(requires_grad=True):
        takes_decay = "weight" in name and len(param.shape) > 1
        (decayed if takes_decay else undecayed).append(param)

    # Every parameter of the model must have landed in exactly one group.
    total = len(list(model.parameters()))
    assert total == len(decayed) + len(undecayed)

    return [
        dict(params=decayed),
        dict(params=undecayed, weight_decay=0.0),
    ]
114 |
115 |
def worker(rank, world_size, args):
    """Train one model replica, checkpointing and validating periodically.

    Args:
        rank: index of this process; also used as the device index when
            ``world_size > 1``.
        world_size: total number of training processes.
        args: parsed command-line namespace produced by ``main``.

    When ``args.model`` is given, its weights are loaded and training resumes
    from the step number encoded in the file name (``<prefix>-<step>.<ext>``).
    """
    # pylint: disable=too-many-statements
    save_dir = os.path.join(args.save, args.arch)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    model = getattr(M, args.arch)()
    step_start = 0
    if args.model:
        logger.info("load weights from %s", args.model)
        model.load_state_dict(mge.load(args.model))
        # Parse the step from the file name only: splitting the whole path
        # picked up the wrong piece whenever a directory contained "-".
        checkpoint_name = os.path.basename(args.model)
        step_start = int(checkpoint_name.split("-")[1].split(".")[0])

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        # Traced train-mode forward + backward; the parameter update itself
        # is done by the caller through optimizer.step().
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        # Traced eval-mode forward pass returning loss, top-1 and top-5 accuracy.
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(data.RandomSampler(
        train_dataset, batch_size=args.batch_size, drop_last=True
    ))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(
            [
                T.RandomResizedCrop(224),
                T.RandomHorizontalFlip(),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )
    # next() needs an iterator. NOTE(review): DataLoader appears to be an
    # iterable without __next__ in MegEngine -- confirm. iter() is harmless
    # if it already is an iterator.
    train_iter = iter(train_queue)

    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    # Keep ``step`` defined for the final save/log even if the loop body
    # never runs (resuming from a step beyond args.steps).
    step = step_start
    t = time.time()
    for step in range(step_start, args.steps + 1):
        # Linear learning rate decay
        decay = 1 - float(step) / args.steps if step < args.steps else 0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_iter)
        time_data = time.time() - t
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        time_iter = time.time() - t
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            # Report the latest batch values, read from the meters directly
            # instead of re-parsing their string representation.
            logger.info(
                "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f",
                step,
                args.learning_rate * decay,
                float(objs.val),
                1 - float(top1.val) / 100,
                1 - float(top5.val) / 100,
                time_data,
                time_iter - time_data,
                time_iter * (args.steps - step) / 3600,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0 and step != 0:
            logger.info("SAVING %06d", step)
            mge.save(
                model.state_dict(),
                os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)),
            )
        if step % 50000 == 0 and step != 0:
            # Runs on every rank: valid_func contains collective all-reduce
            # calls when training is distributed.
            valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f", step, valid_loss, 1-valid_acc/100, 1-valid_acc5/100)

    # Only rank 0 writes the final checkpoint, matching the periodic saves
    # and avoiding several processes writing the same file at once.
    if rank == 0:
        mge.save(
            model.state_dict(), os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step))
        )
    valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST Iter %06d: loss=%f,\tTop-1 err = %f,\tTop-5 err = %f", step, valid_loss, 1-valid_acc/100, 1-valid_acc5/100)
272 |
273 |
def infer(model, data_queue, args):
    """Evaluate ``model`` over ``data_queue`` and return the averaged metrics.

    Args:
        model: callable taking ``(image, label)`` and returning
            ``(loss, acc1, acc5)``.
        data_queue: iterable yielding ``(image, label)`` batches.
        args: namespace providing ``report_freq``.

    Returns:
        Tuple ``(average loss, average top-1, average top-5)``, with the
        accuracies scaled by 100.
    """
    meters = {
        "loss": AverageMeter("Loss"),
        "top1": AverageMeter("Acc@1"),
        "top5": AverageMeter("Acc@5"),
        "time": AverageMeter("Time"),
    }

    started = time.time()
    for step, (image, label) in enumerate(data_queue):
        num_samples = image.shape[0]
        image = image.astype("float32")  # convert np.uint8 to float32
        label = label.astype("int32")

        loss, acc1, acc5 = model(image, label)

        meters["loss"].update(loss.numpy()[0], num_samples)
        meters["top1"].update(100 * acc1.numpy()[0], num_samples)
        meters["top5"].update(100 * acc5.numpy()[0], num_samples)
        meters["time"].update(time.time() - started)
        started = time.time()

        # Progress is logged by rank 0 only, every ``report_freq`` steps.
        should_report = step % args.report_freq == 0 and dist.get_rank() == 0
        if should_report:
            logger.info(
                "Step %d, %s %s %s %s",
                step,
                meters["loss"],
                meters["top1"],
                meters["top5"],
                meters["time"],
            )

    return meters["loss"].avg, meters["top1"].avg, meters["top5"].avg
305 |
306 |
307 |
class AverageMeter:
    """Track the latest value and the running weighted average of a metric.

    Args:
        name: label printed in front of the numbers by ``__str__``.
        fmt: format spec (including the leading colon) applied to both the
            latest value and the average, e.g. ``":.3f"``.
    """

    def __init__(self, name, fmt=":.3f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        While the accumulated count is still zero (for example a first update
        with ``n=0``) the average stays at 0 instead of raising
        ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` with ``fmt`` applied to both numbers."""
        # Callers split this string on whitespace to recover the numbers,
        # so the layout must stay exactly as it is.
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
331 |
332 |
# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
335 |
--------------------------------------------------------------------------------
/weightnet.py:
--------------------------------------------------------------------------------
1 | import megengine.functional as F
2 | import megengine.module as M
3 |
class WeightNet(M.Module):
    r"""Convolution whose kernel is generated from a side input by two fc layers.

    Both fc layers are 1x1 convolutions. ``wn_fc1`` maps the
    ``max(16, inp // 16)``-channel input ``x_gap`` to ``M * oup`` channels.
    After a sigmoid, the grouped ``wn_fc2`` (``G * oup`` groups) emits
    ``oup * inp * ksize * ksize`` values that are reshaped into the kernel.

    M and G (both fixed to 2 here) control the amount of parameters.

    Args:
        inp: input channels of the generated convolution.
        oup: output channels of the generated convolution.
        ksize: kernel size; the padding used is ``ksize // 2``.
        stride: stride of the generated convolution.
    """

    def __init__(self, inp, oup, ksize, stride):
        super().__init__()

        self.M = 2
        self.G = 2

        self.inp, self.oup = inp, oup
        self.ksize, self.stride = ksize, stride
        self.pad = ksize // 2

        gap_channels = max(16, inp // 16)
        hidden_channels = self.M * oup
        kernel_numel = oup * inp * ksize * ksize

        self.wn_fc1 = M.Conv2d(gap_channels, hidden_channels, 1, 1, 0, groups=1, bias=True)
        self.sigmoid = M.Sigmoid()
        self.wn_fc2 = M.Conv2d(
            hidden_channels, kernel_numel, 1, 1, 0, groups=self.G * oup, bias=False
        )

    def forward(self, x, x_gap):
        """Generate a kernel from ``x_gap`` and convolve ``x`` with it.

        NOTE(review): presumably ``x`` is (N, inp, H, W) and ``x_gap`` is a
        pooled (N, max(16, inp // 16), 1, 1) descriptor -- confirm with caller.
        """
        kernel = self.wn_fc2(self.sigmoid(self.wn_fc1(x_gap)))

        if x.shape[0] == 1:  # case of batch size = 1
            kernel = kernel.reshape(self.oup, self.inp, self.ksize, self.ksize)
            return F.conv2d(x, weight=kernel, stride=self.stride, padding=self.pad)

        # Fold the batch into the channel axis and run one grouped convolution
        # whose group count is the leading dimension of the reshaped kernel.
        height, width = x.shape[2], x.shape[3]
        x = x.reshape(1, -1, height, width)
        kernel = kernel.reshape(-1, self.oup, self.inp, self.ksize, self.ksize)
        out = F.conv2d(
            x,
            weight=kernel,
            stride=self.stride,
            padding=self.pad,
            groups=kernel.shape[0],
        )
        return out.reshape(-1, self.oup, out.shape[2], out.shape[3])
46 |
class WeightNet_DW(M.Module):
    r"""WeightNet applied, in a grouped manner, to a depthwise convolution.

    Both fc layers are 1x1 convolutions. ``wn_fc1`` maps the
    ``max(16, inp // 16)``-channel input ``x_gap`` to ``M // G * inp``
    channels. After a sigmoid, the grouped ``wn_fc2`` (``inp`` groups) emits
    ``inp * ksize * ksize`` values that are reshaped into the depthwise kernel.

    M and G are both fixed to 2 here.

    Args:
        inp: number of channels of the depthwise convolution.
        ksize: kernel size; the padding used is ``ksize // 2``.
        stride: stride of the generated convolution.
    """

    def __init__(self, inp, ksize, stride):
        super().__init__()

        self.M = 2
        self.G = 2

        self.inp = inp
        self.ksize, self.stride = ksize, stride
        self.pad = ksize // 2

        gap_channels = max(16, inp // 16)
        hidden_channels = self.M // self.G * inp
        kernel_numel = inp * ksize * ksize

        self.wn_fc1 = M.Conv2d(gap_channels, hidden_channels, 1, 1, 0, groups=1, bias=True)
        self.sigmoid = M.Sigmoid()
        self.wn_fc2 = M.Conv2d(hidden_channels, kernel_numel, 1, 1, 0, groups=inp, bias=False)

    def forward(self, x, x_gap):
        """Generate a depthwise kernel from ``x_gap`` and convolve ``x`` with it.

        NOTE(review): presumably ``x`` is (N, inp, H, W) and ``x_gap`` is a
        pooled (N, max(16, inp // 16), 1, 1) descriptor -- confirm with caller.
        """
        kernel = self.wn_fc2(self.sigmoid(self.wn_fc1(x_gap)))

        # Fold the batch into the channel axis; each resulting channel forms
        # its own group with a single ksize x ksize filter.
        height, width = x.shape[2], x.shape[3]
        x = x.reshape(1, -1, height, width)
        kernel = kernel.reshape(-1, 1, 1, self.ksize, self.ksize)
        out = F.conv2d(
            x,
            weight=kernel,
            stride=self.stride,
            padding=self.pad,
            groups=kernel.shape[0],
        )
        return out.reshape(-1, self.inp, out.shape[2], out.shape[3])
81 |
--------------------------------------------------------------------------------