├── README.md
├── cifar10
│   ├── models
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── train.py
│   └── utils.py
└── mnist
    └── mnist.py

/README.md:
--------------------------------------------------------------------------------
# pytorch-gconv-experiments

Experiments with [Group Equivariant Convolutional Networks (T. S. Cohen, M. Welling, 2016)](https://arxiv.org/abs/1602.07576) implemented in PyTorch.

# Installation

Install [GrouPy](https://github.com/adambielski/GrouPy) with PyTorch support.

# MNIST

A modified [MNIST PyTorch example](https://github.com/pytorch/examples/tree/master/mnist) that validates my implementation of G-convolutions in PyTorch.

```
cd mnist
python mnist.py
```

This simple example uses p4 group convolutions and plane group spatial max pooling.

# CIFAR-10

Experiments with a ResNet implementation based on the [kuangliu repository](https://github.com/kuangliu/pytorch-cifar) for CIFAR-10 with PyTorch. Training uses online data augmentation with random translations and horizontal flips.

All planar convolutions were replaced with p4m group convolutions. The number of filters in each convolutional layer was reduced by a factor of sqrt(8) to keep a similar number of parameters (following [Group Equivariant Convolutional Networks](https://arxiv.org/abs/1602.07576), Section 8.2).

To train the ResNet18 network, run:

```
cd cifar10
python train.py --n_epochs 120 --checkpoint resnet18_p4m --lr=0.01
```

The learning rate is multiplied by 0.1 after 50 and 100 epochs.

After 120 epochs, the network achieves **94.22%** accuracy on the test set, compared to the 93.02% reported for planar convolutions [here](https://github.com/kuangliu/pytorch-cifar).
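
For reference, the reduced channel widths hard-coded in `cifar10/models/resnet.py` (23/45/91/181 instead of the standard 64/128/256/512) follow from the sqrt(8) rule above, since every p4m filter carries 8 transformation channels (4 rotations × 2 mirror states). A quick sketch of the arithmetic:

```
import math
print([round(c / math.sqrt(8)) for c in (64, 128, 256, 512)])  # [23, 45, 91, 181]
```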
--------------------------------------------------------------------------------
/cifar10/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adambielski/pytorch-gconv-experiments/193b391672b50917ead7e63f509a19f651af5e18/cifar10/models/__init__.py
--------------------------------------------------------------------------------
/cifar10/models/resnet.py:
--------------------------------------------------------------------------------
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable
from groupy.gconv.pytorch_gconv import P4MConvZ2, P4MConvP4M


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = P4MConvP4M(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = P4MConvP4M(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                P4MConvP4M(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = P4MConvP4M(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = P4MConvP4M(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = P4MConvP4M(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                P4MConvP4M(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 23

        self.conv1 = P4MConvZ2(3, 23, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(23)
        self.layer1 = self._make_layer(block, 23, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 45, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 91, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 181, num_blocks[3], stride=2)
        self.linear = nn.Linear(181*8*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        outs = out.size()
        out = out.view(outs[0], outs[1]*outs[2], outs[3], outs[4])
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
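
# Note on shapes: each P4MConvZ2/P4MConvP4M above returns a 5D activation of shape
# (N, C, 8, H, W), where the third axis indexes the 8 roto-reflections of p4m.
# ResNet.forward folds that axis into the channel dimension with view() before the
# planar F.avg_pool2d, which is why the classifier is
# nn.Linear(181*8*block.expansion, num_classes). nn.BatchNorm2d is applied directly
# to the 5D activations, normalizing each channel over the group and spatial axes
# together; newer PyTorch releases are stricter about the input rank here and may
# require BatchNorm3d or a reshape instead.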


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


def test():
    net = ResNet18()
    y = net(Variable(torch.randn(1,3,32,32)))
    print(y.size())

# test()
--------------------------------------------------------------------------------
/cifar10/train.py:
--------------------------------------------------------------------------------
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models.resnet import *
from utils import progress_bar
from torch.autograd import Variable


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--n_epochs', default=350, type=int)
parser.add_argument('--checkpoint', required=True)
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
n_epochs = args.n_epochs
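
# Note: the argparse defaults above (lr=0.1, 350 epochs) are not the settings behind
# the README's reported 94.22%; that run used
#   python train.py --n_epochs 120 --checkpoint resnet18_p4m --lr=0.01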

# Data
means = (0.4914, 0.4822, 0.4465)
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(means, (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir(args.checkpoint), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(os.path.join(args.checkpoint, 'ckpt.t7'))
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = ResNet18()

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir(args.checkpoint):
            os.mkdir(args.checkpoint)
        torch.save(state, os.path.join(args.checkpoint, 'ckpt.t7'))
        best_acc = acc


milestones = [50, 100, 140]
scheduler = MultiStepLR(optimizer, milestones, gamma=0.1)
for epoch in range(start_epoch):
    scheduler.step()

for epoch in range(start_epoch, n_epochs):
    scheduler.step()
    train(epoch)
    test(epoch)
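
# Note: with milestones=[50, 100, 140] and gamma=0.1, the learning rate is multiplied
# by 0.1 at epochs 50 and 100, as stated in the README (the 140 milestone is never
# reached in the 120-epoch run). The first scheduler loop above only fast-forwards the
# schedule to start_epoch when resuming from a checkpoint.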
--------------------------------------------------------------------------------
/cifar10/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar that mimics xlua.progress.
'''
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant(m.bias, 0)


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
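
# Usage sketch (mirrors how train.py drives the bar; get_mean_and_std and init_params
# are not used by train.py and are kept as general helpers):
#   for batch_idx, (inputs, targets) in enumerate(loader):
#       ...
#       progress_bar(batch_idx, len(loader), 'Loss: %.3f' % (running_loss / (batch_idx + 1)))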
--------------------------------------------------------------------------------
/mnist/mnist.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from groupy.gconv.pytorch_gconv.splitgconv2d import P4ConvZ2, P4ConvP4
from groupy.gconv.pytorch_gconv.pooling import plane_group_spatial_max_pooling


# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)
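
# Shape bookkeeping for the Net below, assuming GrouPy's p4 convolutions return 5D
# activations of shape (N, C, 4, H, W) with the 4 rotation channels on the third axis:
#   input (N, 1, 28, 28) -> conv1 (N, 10, 4, 26, 26) -> conv2 (N, 10, 4, 24, 24)
#   -> pool (N, 10, 4, 12, 12) -> conv3 (N, 20, 4, 10, 10) -> conv4 (N, 20, 4, 8, 8)
#   -> pool (N, 20, 4, 4, 4) -> flatten (N, 1280), hence fc1 = nn.Linear(4*4*20*4, 50).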


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = P4ConvZ2(1, 10, kernel_size=3)
        self.conv2 = P4ConvP4(10, 10, kernel_size=3)
        self.conv3 = P4ConvP4(10, 20, kernel_size=3)
        self.conv4 = P4ConvP4(20, 20, kernel_size=3)
        self.fc1 = nn.Linear(4*4*20*4, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = plane_group_spatial_max_pooling(x, 2, 2)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = plane_group_spatial_max_pooling(x, 2, 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)


model = Net()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data[0]))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).data[0]  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()
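
# Note: Variable(...), volatile=True, loss.data[0] and size_average=False follow the
# pre-0.4 PyTorch API that this repository targets; on current PyTorch the equivalents
# would be plain tensors, torch.no_grad(), loss.item() and reduction='sum'.
--------------------------------------------------------------------------------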