├── .gitignore
├── README.md
├── data
│   └── flgc.png
├── models
│   └── mobilenet_v2.py
├── modules
│   └── flgc.py
├── print_changes_in_assignment_in_groups.py
├── requirements.txt
├── test.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*_checkpoints
.idea
data/cifar-10*
*__pycache__*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Code for the paper "Fully Learnable Group Convolution for Acceleration of Deep Neural Networks", CVPR 2019

## Description

Implementation of the paper ["Fully Learnable Group Convolution for Acceleration of Deep Neural Networks"](https://arxiv.org/pdf/1904.00346.pdf):

![Flgc vs standard convolution](data/flgc.png)

Given a group number, the paper proposes to learn which input channels form each group and which filters work with each group. This is achieved with the proposed fully learnable group convolution (FLGC) layer.

## Results on CIFAR-10

| Method                    | Number of groups | MFLOPS | Accuracy, % | Model |
|:--------------------------|:-----------------|:-------|:------------|:------|
| MobileNet V2              | N/A              | 94.9   | 94.43       | N/A   |
| MobileNet V2-FLGC (paper) | 8                | 76     | 93.09       | N/A   |
| MobileNet V2-FLGC (ours)  | 8                | 62.6   | 93.7        | [Google Drive](https://drive.google.com/file/d/1RXFS9VQmcXvW7698UI4lWmDWRlTEAqjt/view?usp=sharing) |

## Implementation Notes

As an important note (and a major drawback for practical usage), there is no built-in layer that supports grouping with a custom (non-uniform) split of input channels and filters. So, to actually see faster network inference, one has to implement such a layer.

The follow-up paper ["Differentiable Learning-to-Group Channels via Groupable Convolutional Neural Networks"](https://arxiv.org/pdf/1908.05867.pdf) extends this work by making the group number learnable as well. Despite the improvement in theoretical model complexity, the lack of such an inference-optimized layer makes it hard to apply in practice.

### Training on CIFAR-10 with standard and fully learnable group convolutions

We have obtained similar accuracy with standard and fully learnable group convolutions for MobileNet V2. Accuracy can vary by up to 0.5% from run to run, which hides the benefit of fully learnable grouping in these experiments. Experiments at a larger scale (ImageNet) may show the full potential of this approach.

### Model Architecture

Since there is no official MobileNet V2 for CIFAR-10 (and the authors provide no pre-trained model), there may be differences between network architectures. We use the one from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar).

### FLOPS measurement

FLOPS were measured with [ptflops](https://github.com/sovrasov/flops-counter.pytorch).
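### FLGC layer usage

`Flgc2d` from `modules/flgc.py` mirrors the `nn.Conv2d` argument layout, so it can be used as a drop-in replacement. A minimal sketch, matching the demo at the bottom of `modules/flgc.py`:

```python
import torch

from modules.flgc import Flgc2d

# Same arguments as nn.Conv2d, plus a learnable split of channels into `groups` groups.
conv = Flgc2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, groups=4)
x = torch.randn(4, 3, 7, 7)
out = conv(x)
print(out.shape)  # torch.Size([4, 16, 7, 7])
```

Training and evaluation are started with `python train.py` and `python test.py --checkpoint-path <path>`; see the argument parsers in those scripts for the available options.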

--------------------------------------------------------------------------------
/data/flgc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniil-Osokin/fully-learnable-group-convolution.pytorch/98aa61c42d2474df4aec0a7db94535558f5cb7d6/data/flgc.png

--------------------------------------------------------------------------------
/models/mobilenet_v2.py:
--------------------------------------------------------------------------------
# mostly borrowed from torchvision/models/mobilenet.py

from torch import nn

from modules.flgc import Flgc2d


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by divisor.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: original channel number
    :param divisor: the divisor the result has to be a multiple of
    :param min_value: lower bound for the result (defaults to divisor)
    :return: rounded channel number
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, groups_in_1x1, use_flgc=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        # without conv+bn in the shortcut it gives a slightly lower metric value
        self.shortcut = nn.Sequential()
        if self.stride == 1 and inp != oup:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),  # TODO: add group conv here
                nn.BatchNorm2d(oup),
            )

        pointwise_conv = nn.Conv2d
        if use_flgc:
            pointwise_conv = Flgc2d
        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            pointwise_conv(hidden_dim, oup, 1, 1, 0, bias=False, groups=groups_in_1x1),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x) + self.shortcut(x) if self.stride == 1 else self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, num_classes=10, width_mult=1.0, inverted_residual_setting=None, round_nearest=8,
                 groups_in_1x1=1, use_flgc=False):
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            groups_in_1x1 (int): Number of groups in the pointwise (1x1) group convolutions
            use_flgc (bool): Use fully learnable group convolutions instead of standard ones
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
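        # Each row of inverted_residual_setting below is (t, c, n, s):
        # t = expansion ratio inside the inverted residual block, c = output
        # channels, n = number of block repeats, s = stride of the first repeat
        # (this CIFAR-10 variant keeps stride 1 in the early stages, where the
        # ImageNet architecture uses stride 2).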
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 1],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be a non-empty "
                             "list of 4-element lists, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=1)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t,
                                      groups_in_1x1=groups_in_1x1, use_flgc=use_flgc))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Linear(self.last_channel, num_classes)  # with dropout it gives a slightly lower metric value

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x

--------------------------------------------------------------------------------
/modules/flgc.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Binarize(torch.autograd.Function):
    @staticmethod
    def forward(context, probs):
        # One-hot binarization: 1 for the most probable group of each channel, 0 elsewhere.
        binarized = (probs == torch.max(probs, dim=1, keepdim=True)[0]).float()
        context.save_for_backward(binarized)
        return binarized

    @staticmethod
    def backward(context, gradient_output):
        # Straight-through-style estimator: pass gradients only through the winning
        # entries. Do not modify gradient_output in-place: autograd may reuse it.
        binarized, = context.saved_tensors
        return gradient_output * binarized


class Flgc2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=8, bias=True):
        super().__init__()
        self.in_channels_in_group_assignment_map = nn.Parameter(torch.Tensor(in_channels, groups))
        nn.init.normal_(self.in_channels_in_group_assignment_map)
        self.out_channels_in_group_assignment_map = nn.Parameter(torch.Tensor(out_channels, groups))
        nn.init.normal_(self.out_channels_in_group_assignment_map)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, 1, bias)
        self.binarize = Binarize.apply

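    # forward() builds a binary (out_channels, in_channels) mask from the two
    # learned assignment maps: filter i may use input channel j only if both are
    # assigned to the same group. The mask multiplies the dense weight, so the
    # layer trains like a group convolution with a learnable channel split.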
    def forward(self, x):
        mask = torch.mm(self.binarize(torch.softmax(self.out_channels_in_group_assignment_map, dim=1)),
                        torch.t(self.binarize(torch.softmax(self.in_channels_in_group_assignment_map, dim=1))))
        return nn.functional.conv2d(x, self.conv.weight * mask[:, :, None, None], self.conv.bias,
                                    self.conv.stride, self.conv.padding, self.conv.dilation)


if __name__ == '__main__':
    x = torch.randn(4, 3, 7, 7)
    conv = Flgc2d(3, 16, 3, padding=1, groups=4)
    out = conv(x)
    print(out.shape)

--------------------------------------------------------------------------------
/print_changes_in_assignment_in_groups.py:
--------------------------------------------------------------------------------
import argparse
import glob
import os

import numpy as np
import torch

from modules.flgc import Flgc2d


def get_epoch(checkpoint_path):
    return int(checkpoint_path.split('net_epoch_')[1].split('.pth')[0])


def get_flgc_layer(net, needed_layer_id):
    layer_id = -1
    for module in net.modules():
        if isinstance(module, Flgc2d):
            layer_id += 1
            if layer_id == needed_layer_id:
                return module
    raise IndexError('No flgc layer with such id, possible range: [{}, {}]'.format(0, layer_id))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Tracks changes in the assignment of input channels/filters to groups')
    parser.add_argument('--checkpoints-dir', type=str, required=True, help='path to the directory with model checkpoints')
    parser.add_argument('--layer-id', type=int, default=0, help='flgc layer id to print statistics for')
    parser.add_argument('--track-filters', action='store_true', default=False, help='track changes in assignment of filters to groups')
    args = parser.parse_args()

    checkpoint_paths = glob.glob(os.path.join(args.checkpoints_dir, 'net_epoch_*.pth'))
    checkpoint_paths = sorted(checkpoint_paths, key=lambda path: get_epoch(path))

    net = torch.load(checkpoint_paths[0], map_location='cpu')['net']
    module = get_flgc_layer(net, args.layer_id)
    rows_names = ['{:^13}'.format('epoch {}'.format(get_epoch(checkpoint_paths[0])))]
    assignment_map_to_check = module.out_channels_in_group_assignment_map if args.track_filters else module.in_channels_in_group_assignment_map
    previous_assignment_map = module.binarize(torch.softmax(assignment_map_to_check, dim=1)).detach().cpu().numpy()
    channels_per_group = np.sum(previous_assignment_map, axis=0).astype(np.int32)
    cols_names = ['{:^13}'.format('group {}'.format(group_id)) for group_id in range(len(channels_per_group))]
    cols_names = ['{:^13}'.format('')] + cols_names + ['{:^13}'.format('Total diff.')]
    stats = [[] for _ in range(len(channels_per_group) + 1)]
    for group_id, channels_num in enumerate(channels_per_group):
        stats[group_id].append('{:^13}'.format('{:+d}/{:+d}/{:d}'.format(0, 0, channels_num)))
    stats[-1].append('{:^13}'.format(0))

    for checkpoint_path in checkpoint_paths[1:]:
        rows_names.append('{:^13}'.format('epoch {}'.format(get_epoch(checkpoint_path))))
        net = torch.load(checkpoint_path, map_location='cpu')['net']
        module = get_flgc_layer(net, args.layer_id)
        assignment_map_to_check = module.out_channels_in_group_assignment_map if args.track_filters else module.in_channels_in_group_assignment_map
        assignment_map = module.binarize(torch.softmax(assignment_map_to_check, dim=1)).detach().cpu().numpy()
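        # Each printed cell is "deleted/added/total": channels that left the
        # group, channels that joined it, and the group size at this checkpoint.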
        assignment_map_diff = assignment_map - previous_assignment_map
        deleted_channels_per_group = assignment_map_diff.copy()
        deleted_channels_per_group[deleted_channels_per_group > 0] = 0
        deleted_channels_per_group = np.sum(deleted_channels_per_group, axis=0).astype(np.int32)
        added_channels_per_group = assignment_map_diff.copy()
        added_channels_per_group[added_channels_per_group < 0] = 0
        added_channels_per_group = np.sum(added_channels_per_group, axis=0).astype(np.int32)
        channels_per_group = np.sum(assignment_map, axis=0).astype(np.int32)
        for group_id, channels_num in enumerate(channels_per_group):
            stats[group_id].append('{:^13}'.format('{:+d}/{:+d}/{:d}'.format(deleted_channels_per_group[group_id],
                                                                             added_channels_per_group[group_id], channels_num)))
        stats[-1].append('{:^13}'.format(np.sum(np.abs(assignment_map_diff).astype(np.int32))))
        previous_assignment_map = assignment_map

    for col_name in cols_names:
        print(col_name, end='')
    print('\n')
    for row_id, row_name in enumerate(rows_names):
        print(row_name, end='')
        for group_stat in stats:
            print(group_stat[row_id], end='')
        print('\n')

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.0
torchvision
numpy

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse

import torch
import torchvision
import torchvision.transforms as transforms

from models.mobilenet_v2 import MobileNetV2


def test(net, dataloader):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    accuracy = correct / total * 100
    return accuracy


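# Example invocation (the checkpoint path is whatever train.py produced, e.g.
# "best_checkpoint.pth" inside the experiment's *_checkpoints folder):
#   python test.py --checkpoint-path default_checkpoints/best_checkpoint.pth --num-groups 8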
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path', type=str, default=None, help='path to the checkpoint to evaluate')
    parser.add_argument('--num-groups', type=int, default=8, help='group number in group convolutions')
    parser.add_argument('--use-standard-group-convolutions', action='store_true', default=False,
                        help='use standard group convolutions instead of fully learnable')
    args = parser.parse_args()

    net = MobileNetV2(groups_in_1x1=args.num_groups, use_flgc=(not args.use_standard_group_convolutions))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False, num_workers=2)

    print('Testing...')
    accuracy = test(net, dataloader)
    print('Accuracy: {}%'.format(accuracy))

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from models.mobilenet_v2 import MobileNetV2
from test import test


def train(checkpoint_path, num_groups, use_standard_group_convolutions, checkpoints_folder, num_epochs_to_dump_net):
    net = MobileNetV2(groups_in_1x1=num_groups, use_flgc=(not use_standard_group_convolutions))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=4e-5)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250, 350])
    best_accuracy = 0
    start_epoch = 0
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        best_accuracy = checkpoint['best_accuracy']
        start_epoch = checkpoint['last_epoch'] + 1

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    for epoch_id in range(start_epoch, 400):
        print('Epoch: {}'.format(epoch_id))
        if num_epochs_to_dump_net is not None:
            if epoch_id % num_epochs_to_dump_net == 0:
                torch.save({'net': net}, os.path.join(checkpoints_folder, 'net_epoch_{}.pth'.format(epoch_id)))
        net.train()
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 100 == 0:
                print('Batch: {}, loss: {}'.format(batch_idx + 1, loss.item()))

        print('Testing...')
        accuracy = test(net, testloader)
        print('Accuracy: {}%'.format(accuracy))
        lr_scheduler.step()

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({'state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                        'best_accuracy': best_accuracy,
                        'last_epoch': epoch_id},
                       os.path.join(checkpoints_folder, 'best_checkpoint.pth'))
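        # The rolling "last" checkpoint is rewritten every epoch so an interrupted
        # run can be resumed from it via --checkpoint-path.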
        torch.save({'state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    'best_accuracy': best_accuracy,
                    'last_epoch': epoch_id},
                   os.path.join(checkpoints_folder, 'last_checkpoint.pth'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Trains MobileNet V2 with different types of group convolutions\n'
                    'in place of 1x1 convolutions on the CIFAR-10 dataset. It compares\n'
                    'fully learnable group convolution and standard group convolution.')
    parser.add_argument('--checkpoint-path', type=str, default=None, help='checkpoint path to continue training from')
    parser.add_argument('--num-groups', type=int, default=8, help='group number in group convolutions')
    parser.add_argument('--use-standard-group-convolutions', action='store_true', default=False,
                        help='use standard group convolutions instead of fully learnable')
    parser.add_argument('--experiment-name', type=str, default='default',
                        help='experiment name to create folder for checkpoints')
    parser.add_argument('--num-epochs-to-dump-net', type=int, default=None, help='number of epochs between network dumps for further analysis')
    args = parser.parse_args()

    checkpoints_folder = '{}_checkpoints'.format(args.experiment_name)
    if not os.path.exists(checkpoints_folder):
        os.makedirs(checkpoints_folder)

    train(args.checkpoint_path, args.num_groups, args.use_standard_group_convolutions, checkpoints_folder, args.num_epochs_to_dump_net)
--------------------------------------------------------------------------------