├── README.md ├── deform_conv.py ├── demo.py └── test_against_mxnet.ipynb /README.md: -------------------------------------------------------------------------------- 1 | ## PyTorch Implementation of Deformable Convolution 2 | This repository implements the deformable convolution architecture proposed in this paper: 3 | [*Jifeng Dai, Haozhi Qi, Yuwen Xiong, Yi Li, Guodong Zhang, Han Hu and Yichen Wei. Deformable Convolutional Networks. arXiv preprint arXiv:1703.06211, 2017.*](https://arxiv.org/abs/1703.06211) 4 | 5 | ### Usage 6 | * The deformable convolution module, i.e., *DeformConv2D*, is defined in `deform_conv.py`. 7 | * A simple demo is shown in `demo.py`; it is easy to integrate the *DeformConv2D* module into your own networks. 8 | 9 | ### TODO 10 | - [x] Memory-efficient implementation. 11 | - [x] Test against MXNet's official implementation. 12 | - [ ] Visualize offsets 13 | - [ ] Demo for an R-FCN implementation 14 | 15 | ### Notes 16 | * Although some implementations already exist, such as [PyTorch](https://github.com/oeway/pytorch-deform-conv)/[TensorFlow](https://github.com/felixlaumon/deform-conv), they seem to have some problems, as discussed [here](https://github.com/felixlaumon/deform-conv/issues/4). 17 | * In my opinion, the *DeformConv2D* module is better added on top of higher-level features, where the offsets can be learned better. More experiments are needed to validate this conjecture. 18 | * This repo has been verified by comparing it with the official MXNet implementation, as shown in `test_against_mxnet.ipynb`. 
19 | 20 | ### Requirements 21 | * [PyTorch-v0.3.0](http://pytorch.org/docs/0.3.0/) 22 | -------------------------------------------------------------------------------- /deform_conv.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable, Function 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | 6 | 7 | class DeformConv2D(nn.Module): 8 | def __init__(self, inc, outc, kernel_size=3, padding=1, bias=None): 9 | super(DeformConv2D, self).__init__() 10 | self.kernel_size = kernel_size 11 | self.padding = padding 12 | self.zero_padding = nn.ZeroPad2d(padding) 13 | self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) 14 | 15 | def forward(self, x, offset): 16 | dtype = offset.data.type() 17 | ks = self.kernel_size 18 | N = offset.size(1) // 2 19 | 20 | # Change offset's order from [x1, x2, ..., y1, y2, ...] to [x1, y1, x2, y2, ...] 21 | # Codes below are written to make sure same results of MXNet implementation. 22 | # You can remove them, and it won't influence the module's performance. 
23 | offsets_index = Variable(torch.cat([torch.arange(0, 2*N, 2), torch.arange(1, 2*N+1, 2)]), requires_grad=False).type_as(x).long() 24 | offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*offset.size()) 25 | offset = torch.gather(offset, dim=1, index=offsets_index) 26 | # ------------------------------------------------------------------------ 27 | 28 | if self.padding: 29 | x = self.zero_padding(x) 30 | 31 | # (b, 2N, h, w) 32 | p = self._get_p(offset, dtype) 33 | 34 | # (b, h, w, 2N) 35 | p = p.contiguous().permute(0, 2, 3, 1) 36 | q_lt = Variable(p.data, requires_grad=False).floor() 37 | q_rb = q_lt + 1 38 | 39 | q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() 40 | q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() 41 | q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1) 42 | q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1) 43 | 44 | # (b, h, w, N) 45 | mask = torch.cat([p[..., :N].lt(self.padding)+p[..., :N].gt(x.size(2)-1-self.padding), 46 | p[..., N:].lt(self.padding)+p[..., N:].gt(x.size(3)-1-self.padding)], dim=-1).type_as(p) 47 | mask = mask.detach() 48 | floor_p = p - (p - torch.floor(p)) 49 | p = p*(1-mask) + floor_p*mask 50 | p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) 51 | 52 | # bilinear kernel (b, h, w, N) 53 | g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) 54 | g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) 55 | g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) 56 | g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) 57 | 58 | # (b, c, h, w, N) 59 | x_q_lt = self._get_x_q(x, q_lt, N) 60 | x_q_rb 
= self._get_x_q(x, q_rb, N) 61 | x_q_lb = self._get_x_q(x, q_lb, N) 62 | x_q_rt = self._get_x_q(x, q_rt, N) 63 | 64 | # (b, c, h, w, N) 65 | x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ 66 | g_rb.unsqueeze(dim=1) * x_q_rb + \ 67 | g_lb.unsqueeze(dim=1) * x_q_lb + \ 68 | g_rt.unsqueeze(dim=1) * x_q_rt 69 | 70 | x_offset = self._reshape_x_offset(x_offset, ks) 71 | out = self.conv_kernel(x_offset) 72 | 73 | return out 74 | 75 | def _get_p_n(self, N, dtype): 76 | p_n_x, p_n_y = np.meshgrid(range(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), 77 | range(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), indexing='ij') 78 | # (2N, 1) 79 | p_n = np.concatenate((p_n_x.flatten(), p_n_y.flatten())) 80 | p_n = np.reshape(p_n, (1, 2*N, 1, 1)) 81 | p_n = Variable(torch.from_numpy(p_n).type(dtype), requires_grad=False) 82 | 83 | return p_n 84 | 85 | @staticmethod 86 | def _get_p_0(h, w, N, dtype): 87 | p_0_x, p_0_y = np.meshgrid(range(1, h+1), range(1, w+1), indexing='ij') 88 | p_0_x = p_0_x.flatten().reshape(1, 1, h, w).repeat(N, axis=1) 89 | p_0_y = p_0_y.flatten().reshape(1, 1, h, w).repeat(N, axis=1) 90 | p_0 = np.concatenate((p_0_x, p_0_y), axis=1) 91 | p_0 = Variable(torch.from_numpy(p_0).type(dtype), requires_grad=False) 92 | 93 | return p_0 94 | 95 | def _get_p(self, offset, dtype): 96 | N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) 97 | 98 | # (1, 2N, 1, 1) 99 | p_n = self._get_p_n(N, dtype) 100 | # (1, 2N, h, w) 101 | p_0 = self._get_p_0(h, w, N, dtype) 102 | p = p_0 + p_n + offset 103 | return p 104 | 105 | def _get_x_q(self, x, q, N): 106 | b, h, w, _ = q.size() 107 | padded_w = x.size(3) 108 | c = x.size(1) 109 | # (b, c, h*w) 110 | x = x.contiguous().view(b, c, -1) 111 | 112 | # (b, h, w, N) 113 | index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y 114 | # (b, c, h*w*N) 115 | index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) 116 | 117 | x_offset = x.gather(dim=-1, 
index=index).contiguous().view(b, c, h, w, N) 118 | 119 | return x_offset 120 | 121 | @staticmethod 122 | def _reshape_x_offset(x_offset, ks): 123 | b, c, h, w, N = x_offset.size() 124 | x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1) 125 | x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks) 126 | 127 | return x_offset -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | from deform_conv import DeformConv2D 10 | 11 | from time import time 12 | 13 | # Training settings 14 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 15 | parser.add_argument('--batch-size', type=int, default=32, metavar='N', 16 | help='input batch size for training (default: 32)') 17 | parser.add_argument('--test-batch-size', type=int, default=32, metavar='N', 18 | help='input batch size for testing (default: 32)') 19 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 20 | help='number of epochs to train (default: 10)') 21 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 22 | help='learning rate (default: 0.01)') 23 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 24 | help='SGD momentum (default: 0.5)') 25 | parser.add_argument('--no-cuda', action='store_true', default=False, 26 | help='disables CUDA training') 27 | parser.add_argument('--seed', type=int, default=1, metavar='S', 28 | help='random seed (default: 1)') 29 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 30 | help='how many batches to wait before logging training status') 31 | 
args = parser.parse_args() 32 | args.cuda = not args.no_cuda and torch.cuda.is_available() 33 | 34 | torch.manual_seed(args.seed) 35 | if args.cuda: 36 | torch.cuda.manual_seed(args.seed) 37 | 38 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 39 | train_loader = torch.utils.data.DataLoader( 40 | datasets.MNIST('./MNIST', train=True, download=True, 41 | transform=transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.1307,), (0.3081,)) 44 | ])), 45 | batch_size=args.batch_size, shuffle=True, **kwargs) 46 | test_loader = torch.utils.data.DataLoader( 47 | datasets.MNIST('./MNIST', train=False, download=True, 48 | transform=transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.1307,), (0.3081,)) 51 | ])), 52 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 53 | 54 | 55 | class DeformNet(nn.Module): 56 | def __init__(self): 57 | super(DeformNet, self).__init__() 58 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) 59 | self.bn1 = nn.BatchNorm2d(32) 60 | 61 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 62 | self.bn2 = nn.BatchNorm2d(64) 63 | 64 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 65 | self.bn3 = nn.BatchNorm2d(128) 66 | 67 | self.offsets = nn.Conv2d(128, 18, kernel_size=3, padding=1) 68 | self.conv4 = DeformConv2D(128, 128, kernel_size=3, padding=1) 69 | self.bn4 = nn.BatchNorm2d(128) 70 | 71 | self.classifier = nn.Linear(128, 10) 72 | 73 | def forward(self, x): 74 | # convs 75 | x = F.relu(self.conv1(x)) 76 | x = self.bn1(x) 77 | x = F.relu(self.conv2(x)) 78 | x = self.bn2(x) 79 | x = F.relu(self.conv3(x)) 80 | x = self.bn3(x) 81 | # deformable convolution 82 | offsets = self.offsets(x) 83 | x = F.relu(self.conv4(x, offsets)) 84 | x = self.bn4(x) 85 | 86 | x = F.avg_pool2d(x, kernel_size=28, stride=1).view(x.size(0), -1) 87 | x = self.classifier(x) 88 | 89 | return F.log_softmax(x, dim=1) 90 | 91 | 92 | class PlainNet(nn.Module): 93 | def 
__init__(self): 94 | super(PlainNet, self).__init__() 95 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) 96 | self.bn1 = nn.BatchNorm2d(32) 97 | 98 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 99 | self.bn2 = nn.BatchNorm2d(64) 100 | 101 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 102 | self.bn3 = nn.BatchNorm2d(128) 103 | 104 | self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 105 | self.bn4 = nn.BatchNorm2d(128) 106 | 107 | self.classifier = nn.Linear(128, 10) 108 | 109 | def forward(self, x): 110 | # convs 111 | x = F.relu(self.conv1(x)) 112 | x = self.bn1(x) 113 | x = F.relu(self.conv2(x)) 114 | x = self.bn2(x) 115 | x = F.relu(self.conv3(x)) 116 | x = self.bn3(x) 117 | x = F.relu(self.conv4(x)) 118 | x = self.bn4(x) 119 | 120 | x = F.avg_pool2d(x, kernel_size=28, stride=1).view(x.size(0), -1) 121 | x = self.classifier(x) 122 | 123 | return F.log_softmax(x, dim=1) 124 | 125 | model = DeformNet() 126 | 127 | 128 | def init_weights(m): 129 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 130 | nn.init.xavier_uniform(m.weight, gain=nn.init.calculate_gain('relu')) 131 | if m.bias is not None: 132 | m.bias.data = torch.FloatTensor(m.bias.shape[0]).zero_() 133 | 134 | 135 | def init_conv_offset(m): 136 | m.weight.data = torch.zeros_like(m.weight.data) 137 | if m.bias is not None: 138 | m.bias.data = torch.FloatTensor(m.bias.shape[0]).zero_() 139 | 140 | 141 | model.apply(init_weights) 142 | model.offsets.apply(init_conv_offset) 143 | 144 | if args.cuda: 145 | model.cuda() 146 | 147 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 148 | 149 | 150 | def train(epoch): 151 | model.train() 152 | for batch_idx, (data, target) in enumerate(train_loader): 153 | data, target = Variable(data), Variable(target) 154 | if args.cuda: 155 | data, target = data.cuda(), target.cuda() 156 | optimizer.zero_grad() 157 | output = model(data) 158 | loss = F.nll_loss(output, target) 159 | 
loss.backward() 160 | optimizer.step() 161 | if batch_idx % args.log_interval == 0: 162 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 163 | epoch, batch_idx * len(data), len(train_loader.dataset), 164 | 100. * batch_idx / len(train_loader), loss.data[0])) 165 | 166 | def test(): 167 | model.eval() 168 | test_loss = 0 169 | correct = 0 170 | for data, target in test_loader: 171 | if args.cuda: 172 | data, target = data.cuda(), target.cuda() 173 | data, target = Variable(data, volatile=True), Variable(target) 174 | output = model(data) 175 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 176 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 177 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 178 | 179 | test_loss /= len(test_loader.dataset) 180 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 181 | test_loss, correct, len(test_loader.dataset), 182 | 100. 
* correct / len(test_loader.dataset))) 183 | 184 | 185 | for epoch in range(1, args.epochs + 1): 186 | since = time() 187 | train(epoch) 188 | iter = time() - since 189 | print("Spends {}s for each training epoch".format(iter/args.epochs)) 190 | test() 191 | -------------------------------------------------------------------------------- /test_against_mxnet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "*Note: MXNet supplys an official implementation of deformable convolution, here tests this repo against the MXNet's implementation.*" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import torch\n", 18 | "import numpy as np\n", 19 | "import mxnet as mx\n", 20 | "from torch import nn\n", 21 | "from time import time\n", 22 | "from pprint import pprint\n", 23 | "from torch.autograd import Variable\n", 24 | "from mxnet.initializer import Initializer\n", 25 | "from deform_conv import DeformConv2D" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "#### Set up parameters." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Using gpu0\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "bs, inC, ouC, H, W = 1, 1, 1, 4, 5\n", 50 | "kH, kW = 3, 3\n", 51 | "padding = 1\n", 52 | "\n", 53 | "# ---------------------------------------\n", 54 | "use_gpu = torch.cuda.is_available()\n", 55 | "gpu_device = 0\n", 56 | "if use_gpu:\n", 57 | " os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_device)\n", 58 | " print(\"Using gpu{}\".format(os.getenv(\"CUDA_VISIBLE_DEVICES\")))\n", 59 | "# ---------------------------------------\n", 60 | "raw_inputs = np.random.rand(bs, inC, H, W).astype(np.float32)\n", 61 | "raw_labels = np.random.rand(bs, ouC, (H+2*padding-2)//1, (W+2*padding-2)//1).astype(np.float32)\n", 62 | "# weights for conv offsets.\n", 63 | "offset_weights = np.random.rand(18, inC, 3, 3).astype(np.float32)\n", 64 | "# weights for deformable convolution.\n", 65 | "conv_weights = np.random.rand(ouC, inC, 3, 3).astype(np.float32)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "\n", 78 | "inputs:\n", 79 | "array([[[[ 0.56075788, 0.56448251, 0.38643569, 0.13775933, 0.92719644],\n", 80 | " [ 0.18066591, 0.24222445, 0.29689947, 0.54874617, 0.95001829],\n", 81 | " [ 0.61031544, 0.84815538, 0.27238497, 0.53376287, 0.93240666],\n", 82 | " [ 0.54890364, 0.60794067, 0.1237376 , 0.16012843, 0.82202536]]]], dtype=float32)\n", 83 | "\n", 84 | "labels:\n", 85 | "array([[[[ 0.68657815, 0.10542295, 0.04489666, 0.64058632, 0.52095002],\n", 86 | " [ 0.73867059, 0.91901845, 0.80943078, 0.2182935 , 0.02595145],\n", 87 | " [ 0.31954384, 0.80359656, 0.53808153, 0.46827996, 0.90268624],\n", 88 | " [ 0.84400773, 0.5750683 , 0.55033565, 0.11278367, 0.47512576]]]], dtype=float32)\n", 89 | 
"\n", 90 | "conv weights:\n", 91 | "array([[[[ 0.89284414, 0.98574871, 0.94764489],\n", 92 | " [ 0.69642198, 0.84854221, 0.98900223],\n", 93 | " [ 0.82735974, 0.4257046 , 0.59915102]]]], dtype=float32)\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "print('\\ninputs:')\n", 99 | "pprint(raw_inputs)\n", 100 | "print('\\nlabels:')\n", 101 | "pprint(raw_labels)\n", 102 | "print('\\nconv weights:')\n", 103 | "pprint(conv_weights)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "### Set up models of PyTorch&MXNet" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "#### Set PyTorch model." 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "class TestModel(nn.Module):\n", 127 | " def __init__(self):\n", 128 | " super(TestModel, self).__init__()\n", 129 | " self.conv_offset = nn.Conv2d(in_channels=inC, out_channels=18, kernel_size=3, padding=padding, bias=None)\n", 130 | " self.deform_conv = DeformConv2D(inc=inC, outc=ouC, padding=padding)\n", 131 | "\n", 132 | " def forward(self, x):\n", 133 | " offsets = self.conv_offset(x)\n", 134 | " out = self.deform_conv(x, offsets)\n", 135 | " return out" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 5, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "model = TestModel()\n", 145 | "\n", 146 | "pt_inputs = Variable(torch.from_numpy(raw_inputs).cuda(), requires_grad=True)\n", 147 | "pt_labels = Variable(torch.from_numpy(raw_labels).cuda(), requires_grad=False)\n", 148 | "\n", 149 | "optimizer = torch.optim.SGD([{'params': model.parameters()}], lr=1e-1)\n", 150 | "loss_fn = torch.nn.MSELoss(reduce=True)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "#### Init weights." 
158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 6, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "Conv2d (1, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)" 169 | ] 170 | }, 171 | "execution_count": 6, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | } 175 | ], 176 | "source": [ 177 | "def init_weights(m):\n", 178 | " if isinstance(m, torch.nn.Conv2d):\n", 179 | " m.weight.data = torch.from_numpy(conv_weights)\n", 180 | " if m.bias is not None:\n", 181 | " m.bias.data = torch.FloatTensor(m.bias.shape[0]).zero_()\n", 182 | "\n", 183 | "def init_offsets_weights(m):\n", 184 | " if isinstance(m, torch.nn.Conv2d):\n", 185 | " m.weight.data = torch.from_numpy(offset_weights)\n", 186 | " if m.bias is not None:\n", 187 | " m.bias.data = torch.FloatTensor(m.bias.shape[0]).zero_()\n", 188 | "\n", 189 | "model.deform_conv.apply(init_weights)\n", 190 | "model.conv_offset.apply(init_offsets_weights)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 7, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "if use_gpu:\n", 200 | " model.cuda()" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "#### Set MXNet model." 
208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 8, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "# trainiter\n", 217 | "train_iter = mx.io.NDArrayIter(raw_inputs, raw_labels, 1, shuffle=True, data_name='data', label_name='label')\n", 218 | "\n", 219 | "# # symbol\n", 220 | "inputs = mx.symbol.Variable('data')\n", 221 | "labels = mx.symbol.Variable('label')\n", 222 | "offsets = mx.symbol.Convolution(data=inputs, kernel=(3, 3), pad=(padding, padding), num_filter=18, name='offset', no_bias=True)\n", 223 | "net = mx.symbol.contrib.DeformableConvolution(data=inputs, offset=offsets, kernel=(3, 3), pad=(padding, padding), num_filter=ouC, name='deform', no_bias=True)\n", 224 | "outputs = mx.symbol.MakeLoss(data=mx.symbol.mean((net-labels)**2))" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 9, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "mod = mx.mod.Module(symbol=outputs,\n", 234 | " context=mx.gpu(),\n", 235 | " data_names=['data'],\n", 236 | " label_names=['label'])\n", 237 | "\n", 238 | "mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)\n", 239 | "mod.init_params(initializer=mx.initializer.Load({'deform_weight': mx.nd.array(conv_weights),\n", 240 | " 'offset_weight': mx.nd.array(offset_weights)}))\n", 241 | "mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1),))" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "### Inference" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "#### PyTorch" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 10, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stdout", 265 | "output_type": "stream", 266 | "text": [ 267 | "Variable containing:\n", 268 | "(0 ,0 ,.,.) 
= \n", 269 | " 1.1062 2.6770 3.6377 2.7072 1.3894\n", 270 | " 2.7560 3.1664 3.4931 0.8367 0.5725\n", 271 | " 2.1425 1.1672 2.4347 1.5723 0.0000\n", 272 | " 1.5282 0.7295 1.6247 2.0247 1.5443\n", 273 | "[torch.cuda.FloatTensor of size 1x1x4x5 (GPU 0)]\n", 274 | "\n" 275 | ] 276 | } 277 | ], 278 | "source": [ 279 | "output = model(pt_inputs)\n", 280 | "pprint(output)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "#### MXNet" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 11, 293 | "metadata": {}, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "\n", 300 | "[[[[ 1.10622048 2.67702818 3.63765717 2.70719433 1.38943076]\n", 301 | " [ 2.75602031 3.16641665 3.49313188 0.83671159 0.57247651]\n", 302 | " [ 2.14249563 1.16722107 2.43471932 1.57227027 0. ]\n", 303 | " [ 1.52822971 0.72951925 1.62470031 2.02470279 1.54425097]]]]\n", 304 | "\n" 305 | ] 306 | } 307 | ], 308 | "source": [ 309 | "mx_inputs = mx.nd.array(raw_inputs, ctx=mx.gpu())\n", 310 | "conv_weights = mx.nd.array(conv_weights, ctx=mx.gpu())\n", 311 | "offset_weights = mx.nd.array(offset_weights, ctx=mx.gpu())\n", 312 | "offset = mx.ndarray.Convolution(data=mx_inputs, weight=offset_weights, kernel=(3, 3), pad=(padding, padding), num_filter=18, name='offset', no_bias=True)\n", 313 | "outputs = mx.ndarray.contrib.DeformableConvolution(data=mx_inputs, offset=offset, weight=conv_weights, kernel=(3, 3), pad=(padding, padding), num_filter=ouC, name='deform', no_bias=True)\n", 314 | "pprint(outputs)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "### Train" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "#### PyTorch" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 12, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "for i 
in range(100):\n", 338 | " output = model(pt_inputs)\n", 339 | " loss = loss_fn(output, pt_labels)\n", 340 | " optimizer.zero_grad()\n", 341 | " loss.backward()\n", 342 | " optimizer.step()" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 13, 348 | "metadata": {}, 349 | "outputs": [ 350 | { 351 | "name": "stdout", 352 | "output_type": "stream", 353 | "text": [ 354 | "Variable containing:\n", 355 | "(0 ,0 ,.,.) = \n", 356 | " 0.1410 0.3813 0.3941 0.7022 0.3028\n", 357 | " 0.4378 0.7259 0.8587 0.3047 -0.0288\n", 358 | " 0.5056 0.3441 0.6582 0.5480 0.0000\n", 359 | " 0.4662 0.2311 0.4464 0.5114 0.5383\n", 360 | "[torch.cuda.FloatTensor of size 1x1x4x5 (GPU 0)]\n", 361 | "\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "pprint(model(pt_inputs))" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "#### MXNet" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 14, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "for i in range(100):\n", 383 | " train_iter.reset()\n", 384 | " for batch in train_iter:\n", 385 | " # get outputs\n", 386 | "# infer_outputs = mx.mod.Module(symbol=net,\n", 387 | "# context=mx.gpu(),\n", 388 | "# data_names=['data'])\n", 389 | "# infer_outputs.bind(data_shapes=train_iter.provide_data)\n", 390 | "# infer_outputs.set_params(arg_params=mod.get_params()[0], aux_params=mod.get_params()[1], allow_extra=True)\n", 391 | "# outputs_value = infer_outputs.predict(train_iter)\n", 392 | "\n", 393 | " mod.forward(batch, is_train=True) # compute predictions\n", 394 | " mod.backward() # compute gradients\n", 395 | " mod.update() # update parameters" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 15, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "\n", 408 | "[[[[ 0.14098766 0.38126716 0.39405173 0.70216376 0.30278912]\n", 409 
| " [ 0.43777964 0.72585207 0.8587358 0.30472821 -0.02884051]\n", 410 | " [ 0.50558764 0.34406984 0.65824842 0.54796678 0. ]\n", 411 | " [ 0.46616155 0.23111115 0.44640523 0.51136774 0.53833312]]]]\n", 412 | "\n" 413 | ] 414 | } 415 | ], 416 | "source": [ 417 | "mx_inputs = mx.nd.array(raw_inputs, ctx=mx.gpu())\n", 418 | "mx_labels = mx.nd.array(raw_labels, ctx=mx.gpu())\n", 419 | "conv_weights = mod.get_params()[0]['deform_weight'].as_in_context(mx.gpu())\n", 420 | "offset_weights = mod.get_params()[0]['offset_weight'].as_in_context(mx.gpu())\n", 421 | "offset = mx.ndarray.Convolution(data=mx_inputs, weight=offset_weights, kernel=(3, 3), pad=(padding, padding), num_filter=18, name='offset', no_bias=True)\n", 422 | "outputs = mx.ndarray.contrib.DeformableConvolution(data=mx_inputs, offset=offset, weight=conv_weights, kernel=(3, 3), pad=(padding, padding), num_filter=ouC, name='deform', no_bias=True)\n", 423 | "pprint(outputs)" 424 | ] 425 | } 426 | ], 427 | "metadata": { 428 | "kernelspec": { 429 | "display_name": "Python 3", 430 | "language": "python", 431 | "name": "python3" 432 | }, 433 | "language_info": { 434 | "codemirror_mode": { 435 | "name": "ipython", 436 | "version": 3 437 | }, 438 | "file_extension": ".py", 439 | "mimetype": "text/x-python", 440 | "name": "python", 441 | "nbconvert_exporter": "python", 442 | "pygments_lexer": "ipython3", 443 | "version": "3.5.2" 444 | } 445 | }, 446 | "nbformat": 4, 447 | "nbformat_minor": 2 448 | } 449 | --------------------------------------------------------------------------------