├── README.md
├── main.py
├── shufflenet.py
├── shufflenet_v2.py
└── train_log_20180218_1530.out

/README.md:
--------------------------------------------------------------------------------
# ShuffleNet with PyTorch

**Note:** This project is a PyTorch implementation of [ShuffleNet](https://arxiv.org/abs/1707.01083).

### Performance

Trained on ImageNet with groups=3, the model reaches Prec@1 67.898% and Prec@5 87.994%. Training used batch_size=256 and an initial learning_rate=0.1, decayed by a factor of 10 every 30 epochs.

### Training on ImageNet

```bash
python main.py -b 256 $Imagenetdir
```
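### Evaluation

`main.py` also exposes `--resume` and `-e/--evaluate` flags, so a saved checkpoint can be evaluated on the validation set; for example (the checkpoint filenames below are the ones written by `save_checkpoint` in `main.py`):

```bash
python main.py -e --resume model_best.pth.tar $Imagenetdir
```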
License: MIT

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from shufflenet import shufflenet
import math

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', default='shufflenet', type=str)
parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                    help='number of data loading workers (default: 12)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=4e-5, type=float,
                    metavar='W', help='weight decay (default: 4e-5)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')

best_prec1 = 0


def weight_init(m):
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
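
# Note: weight_init is He (Kaiming) initialization -- conv weights are drawn
# from N(0, 2/n) with n = k_h * k_w * out_channels (the fan-out), which keeps
# the activation variance roughly constant through ReLU layers; BatchNorm
# scales start at 1 and shifts at 0.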

def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = False  # single-process training; the distributed branches below are unused

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    model = shufflenet()
    model.cuda()
    model.apply(weight_init)
    # create model
    """
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    """
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # non_blocking=True overlaps the host-to-GPU copy with compute,
        # since the loaders use pin_memory=True
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # inference only: no_grad() disables autograd bookkeeping
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
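
# With the default lr=0.1 and epochs=90 this gives the step schedule from the
# README: epochs 0-29 train at 0.1, epochs 30-59 at 0.01, epochs 60-89 at 0.001.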

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
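
# Worked example: with topk=(1, 5) and a batch of 2 samples, pred holds each
# sample's 5 highest-scoring class indices (transposed to shape 5 x 2); a
# sample counts toward Prec@5 if its label appears anywhere in its column, so
# a batch where one sample is top-1 correct and both are top-5 correct
# returns [50.0, 100.0].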

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/shufflenet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class ShufflenetUnit(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, flag=False):
        super(ShufflenetUnit, self).__init__()
        self.downsample = downsample
        group_num = 3
        self.flag = flag
        # flag marks the first unit of stage2: its 24-channel input is too
        # thin to split into groups, so the first 1x1 conv stays ungrouped
        if self.flag:
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, groups=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, groups=group_num, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, groups=group_num, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

    def _shuffle(self, features, g):
        # channel shuffle: permute channels so the next grouped conv sees
        # channels from every group
        channels = features.size(1)
        index = torch.arange(channels).view(-1, g).t().contiguous().view(-1)
        return features[:, index.to(features.device)]

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if not self.flag:
            out = self._shuffle(out, 3)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # stride-2 units concatenate with the pooled shortcut instead of adding
            residual = self.downsample(x)
            out = torch.cat((out, residual), 1)
        else:
            out += residual
        out = self.relu(out)

        return out
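
# Worked example of _shuffle with 6 channels and g=3: torch.arange(6).view(-1, 3)
# gives [[0, 1, 2], [3, 4, 5]], and transposing then flattening yields the index
# [0, 3, 1, 4, 2, 5] -- channels that the previous grouped conv computed in
# separate groups become neighbours, so information flows across groups.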

class ShuffleNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(ShuffleNet, self).__init__()
        self.inplanes = 24
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=24, kernel_size=3,
                               padding=1, stride=2, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.stage2 = self._make_layer(block, 240, layers[0], True)
        self.stage3 = self._make_layer(block, 480, layers[1], False)
        self.stage4 = self._make_layer(block, 960, layers[2], False)

        self.globalpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.fc = nn.Linear(960, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, flag):
        downsample = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        )

        # integer division: nn.Conv2d rejects float channel counts
        inner_plane = (planes - self.inplanes) // 4
        layers = []
        layers.append(block(self.inplanes, inner_plane, 2, downsample, flag=flag))
        self.inplanes = planes
        for i in range(blocks):
            layers.append(block(planes, planes // 4))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = self.globalpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def shufflenet():
    model = ShuffleNet(ShufflenetUnit, [3, 7, 3])
    return model


if __name__ == "__main__":
    model = shufflenet()
    print(model)

--------------------------------------------------------------------------------
/shufflenet_v2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math


def conv3x3(in_channels, out_channels, stride, padding=1, groups=1):
    """3x3 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding,
                     groups=groups,
                     bias=False)


def conv1x1(in_channels, out_channels, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=0,
                     bias=False)


class ShufflenetUnit(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ShufflenetUnit, self).__init__()
        self.downsample = downsample

        if not self.downsample:
            # without downsampling the input is channel-split, so each
            # branch only sees half the channels
            inplanes = inplanes // 2
            planes = planes // 2

        self.conv1x1_1 = conv1x1(in_channels=inplanes, out_channels=planes)
        self.conv1x1_1_bn = nn.BatchNorm2d(planes)

        self.dwconv3x3 = conv3x3(in_channels=planes, out_channels=planes, stride=stride, groups=planes)
        self.dwconv3x3_bn = nn.BatchNorm2d(planes)

        self.conv1x1_2 = conv1x1(in_channels=planes, out_channels=planes)
        self.conv1x1_2_bn = nn.BatchNorm2d(planes)

        self.relu = nn.ReLU(inplace=True)

    def _channel_split(self, features, ratio=0.5):
        """ratio: c'/c, defaults to an even 50/50 split"""
        size = features.size()[1]
        split_idx = int(size * ratio)
        return features[:, :split_idx], features[:, split_idx:]

    def _channel_shuffle(self, features, g=2):
        channels = features.size(1)
        index = torch.arange(channels).view(-1, g).t().contiguous().view(-1)
        return features[:, index.to(features.device)]

    def forward(self, x):
        if self.downsample:
            # both branches consume the full input; nothing below modifies
            # x in place, so the two references can safely alias
            x1 = x
            x2 = x
        else:
            x1, x2 = self._channel_split(x)

        # ----right branch: 1x1 -> depthwise 3x3 -> 1x1----
        x2 = self.conv1x1_1(x2)
        x2 = self.conv1x1_1_bn(x2)
        x2 = self.relu(x2)

        x2 = self.dwconv3x3(x2)
        x2 = self.dwconv3x3_bn(x2)

        x2 = self.conv1x1_2(x2)
        x2 = self.conv1x1_2_bn(x2)
        x2 = self.relu(x2)

        # ----left branch: identity, or the strided downsample path----
        if self.downsample:
            x1 = self.downsample(x1)

        x = torch.cat([x1, x2], 1)
        x = self._channel_shuffle(x)
        return x
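
# Worked example for the "1x" model's stage2 (116 channels): a stride-1 unit
# splits the input into two 58-channel halves, pushes one half through the
# 1x1 -> depthwise 3x3 -> 1x1 branch, concatenates back to 116 channels, and
# shuffles with g=2 so the untouched half mixes into the processed half
# before the next unit.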

class ShuffleNet(nn.Module):
    def __init__(self, feature_dim, layers_num, num_classes=1000):
        super(ShuffleNet, self).__init__()
        dim1, dim2, dim3, dim4, dim5 = feature_dim
        self.conv1 = conv3x3(in_channels=3, out_channels=dim1,
                             stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(dim1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stage2 = self._make_layer(dim1, dim2, layers_num[0])
        self.stage3 = self._make_layer(dim2, dim3, layers_num[1])
        self.stage4 = self._make_layer(dim3, dim4, layers_num[2])

        self.conv5 = conv1x1(in_channels=dim4, out_channels=dim5)
        self.globalpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.fc = nn.Linear(dim5, num_classes)

        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        """

    def _make_layer(self, dim1, dim2, blocks_num):
        half_channel = dim2 // 2
        downsample = nn.Sequential(
            conv3x3(in_channels=dim1, out_channels=dim1, stride=2, padding=1, groups=dim1),
            nn.BatchNorm2d(dim1),
            conv1x1(in_channels=dim1, out_channels=half_channel),
            nn.BatchNorm2d(half_channel),
            nn.ReLU(inplace=True)
        )

        layers = []
        layers.append(ShufflenetUnit(dim1, half_channel, stride=2, downsample=downsample))
        for i in range(blocks_num):
            layers.append(ShufflenetUnit(dim2, dim2, stride=1))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = self.conv5(x)
        x = self.globalpool(x)

        # flatten with a dynamic channel count so the "2x" variant
        # (dim5=2048) works as well as the 1024-channel variants
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


features = {
    "0.5x": [24, 48, 96, 192, 1024],
    "1x": [24, 116, 232, 464, 1024],
    "1.5x": [24, 176, 352, 704, 1024],
    "2x": [24, 244, 488, 976, 2048]
}


def shufflenet():
    model = ShuffleNet(features["1x"], [3, 7, 3])
    return model


if __name__ == "__main__":
    model = shufflenet().cuda()
    print(model)
    x = torch.rand(1, 3, 224, 224).cuda()
    x = model(x)
--------------------------------------------------------------------------------