├── .gitignore ├── LICENSE ├── README.md ├── deformable-learned-offset-filtered.gif ├── models └── deform_cnn.png ├── scaled_mnist.py ├── tests └── test_deform_conv.py └── torch_deform_conv ├── __init__.py ├── cnn.py ├── deform_conv.py ├── layers.py ├── mnist.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | #### joe made this: http://goel.io/joe 2 | 3 | #####=== OSX ===##### 4 | .DS_Store 5 | .AppleDouble 6 | .LSOverride 7 | 8 | # Icon must end with two \r 9 | Icon 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # Files that might appear in the root of a volume 15 | .DocumentRevisions-V100 16 | .fseventsd 17 | .Spotlight-V100 18 | .TemporaryItems 19 | .Trashes 20 | .VolumeIcon.icns 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | #####=== Python ===##### 30 | 31 | # Byte-compiled / optimized / DLL files 32 | __pycache__/ 33 | *.py[cod] 34 | *$py.class 35 | 36 | # C extensions 37 | *.so 38 | 39 | # Distribution / packaging 40 | .Python 41 | env/ 42 | build/ 43 | develop-eggs/ 44 | dist/ 45 | downloads/ 46 | eggs/ 47 | .eggs/ 48 | lib/ 49 | lib64/ 50 | parts/ 51 | sdist/ 52 | var/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *,cover 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | #####=== IPythonNotebook ===##### 91 | # Temporary data 92 | .ipynb_checkpoints/ 93 | 94 | 95 | logs/ 96 | notebooks/ 97 | tags 98 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Felix Lau 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of Deformable Convolution 2 | 3 | **!!!Warning: There is some issues in this implementation and this repo is not maintained any more, please consider using for example: [TORCHVISION.OPS.DEFORM_CONV](https://pytorch.org/vision/stable/_modules/torchvision/ops/deform_conv.html)** 4 | 5 | 6 | 7 | * By Wei OUYANG @ Institut Pasteur 8 | * Thanks to Felix Lau's Keras/TensorFlow implementation: ~~https://github.com/felixlaumon/deform-conv~~ (https://github.com/kastnerkyle/deform-conv) 9 | 10 | ### TODO List 11 | - [x] implement offsets mapping in pytorch 12 | - [x] all tests passed 13 | - [x] deformable convolution module 14 | - [x] Fine-tuning the deformable convolution modules 15 | - [x] scaled mnist demo 16 | - [x] improve speed with cached grid array 17 | - [ ] use MNIST dataset from pytorch (instead of Keras) 18 | - [ ] support input image with different width and height 19 | - [ ] benchmark with tensorflow implementation 20 | 21 | ## Deformable Convolutional Networks 22 | > Dai, Jifeng, Haozhi Qi, Yuwen Xiong, Yi Li, Guodong Zhang, Han Hu, and Yichen 23 | Wei. 2017. “Deformable Convolutional Networks.” arXiv [cs.CV]. arXiv. 24 | http://arxiv.org/abs/1703.06211 25 | 26 | The following animation is generated by Felix Lau (with his tensorflow implementation): 27 | 28 | ![](deformable-learned-offset-filtered.gif) 29 | 30 | Also Check out Felix Lau's summary of the paper: https://medium.com/@phelixlau/notes-on-deformable-convolutional-networks-baaabbc11cf3 31 | -------------------------------------------------------------------------------- /deformable-learned-offset-filtered.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oeway/pytorch-deform-conv/d61d3aa4da20880c524193a50f6e9b44b921a938/deformable-learned-offset-filtered.gif -------------------------------------------------------------------------------- /models/deform_cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oeway/pytorch-deform-conv/d61d3aa4da20880c524193a50f6e9b44b921a938/models/deform_cnn.png -------------------------------------------------------------------------------- /scaled_mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | # %env CUDA_VISIBLE_DEVICES=0 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | 10 | from torch_deform_conv.layers import ConvOffset2D 11 | from torch_deform_conv.cnn import get_cnn, get_deform_cnn 12 | from torch_deform_conv.mnist import get_gen 13 | from torch_deform_conv.utils import transfer_weights 14 | 15 | batch_size = 32 16 | n_train = 60000 17 | n_test = 10000 18 | steps_per_epoch = int(np.ceil(n_train / batch_size)) 19 | validation_steps = int(np.ceil(n_test / batch_size)) 20 | 21 | train_gen = get_gen( 22 | 'train', batch_size=batch_size, 23 | scale=(1.0, 1.0), translate=0.0, 24 | shuffle=True 25 | ) 26 | test_gen = get_gen( 27 | 'test', batch_size=batch_size, 28 | scale=(1.0, 1.0), translate=0.0, 29 | shuffle=False 30 | ) 31 | train_scaled_gen = get_gen( 32 | 'train', batch_size=batch_size, 33 | scale=(1.0, 2.5), translate=0.2, 34 | shuffle=True 35 | ) 36 | test_scaled_gen = get_gen( 37 | 'test', batch_size=batch_size, 38 | scale=(1.0, 2.5), translate=0.2, 39 | shuffle=False 40 | ) 41 | 42 | 43 | def train(model, generator, batch_num, epoch): 44 | model.train() 45 | for batch_idx in range(batch_num): 46 | data, target = next(generator) 47 | data, target = torch.from_numpy(data), torch.from_numpy(target) 48 | # convert BHWC to BCHW 49 | data = data.permute(0, 3, 1, 2) 50 | data, target = data.float().cuda(), target.long().cuda() 51 | 52 | data, target = Variable(data), Variable(target) 53 | optimizer.zero_grad() 54 | output = model(data) 55 | loss = F.cross_entropy(output, target) 56 | loss.backward() 57 | optimizer.step() 58 | 59 | print('Train Epoch: {}\tLoss: {:.6f}'.format(epoch, loss.data[0])) 60 | 61 | def test(model, generator, batch_num, epoch): 62 | model.eval() 63 | test_loss = 0 64 | correct = 0 65 | for batch_idx in range(batch_num): 66 | data, target = next(generator) 67 | data, target = torch.from_numpy(data), torch.from_numpy(target) 68 | # convert BHWC to BCHW 69 | data = data.permute(0, 3, 1, 2) 70 | data, target = data.float().cuda(), target.long().cuda() 71 | 72 | data, target = Variable(data), Variable(target) 73 | output = model(data) 74 | test_loss += F.cross_entropy(output, target).data[0] 75 | pred = output.data.max(1)[1] # get the index of the max log-probability 76 | correct += pred.eq(target.data).cpu().sum() 77 | 78 | test_loss /= batch_num# loss function already averages over batch size 79 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( 80 | test_loss, correct, n_test, 100. * correct / n_test)) 81 | 82 | 83 | # --- 84 | # Normal CNN 85 | 86 | 87 | model = get_cnn() 88 | model = model.cuda() 89 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 90 | for epoch in range(10): 91 | test(model, test_gen, validation_steps, epoch) 92 | train(model, train_gen, steps_per_epoch, epoch) 93 | 94 | 95 | torch.save(model, 'models/cnn.th') 96 | 97 | # --- 98 | # Evaluate normal CNN 99 | 100 | print('Evaluate normal CNN') 101 | model_cnn = torch.load('models/cnn.th') 102 | 103 | test(model_cnn, test_gen, validation_steps, epoch) 104 | # 99.27% 105 | test(model_cnn, test_scaled_gen, validation_steps, epoch) 106 | # 58.83% 107 | 108 | # --- 109 | # Deformable CNN 110 | 111 | print('Finetune deformable CNN (ConvOffset2D and BatchNorm)') 112 | model = get_deform_cnn(trainable=False) 113 | model = model.cuda() 114 | transfer_weights(model_cnn, model) 115 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 116 | for epoch in range(20): 117 | test(model, test_scaled_gen, validation_steps, epoch) 118 | train(model, train_scaled_gen, steps_per_epoch, epoch) 119 | 120 | 121 | torch.save(model, 'models/deform_cnn.th') 122 | 123 | # --- 124 | # Evaluate deformable CNN 125 | 126 | print('Evaluate deformable CNN') 127 | model = torch.load('models/deform_cnn.th') 128 | 129 | test(model, test_gen, validation_steps, epoch) 130 | # xx% 131 | test(model, test_scaled_gen, validation_steps, epoch) 132 | # xx% -------------------------------------------------------------------------------- /tests/test_deform_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from scipy.ndimage.interpolation import map_coordinates 5 | 6 | from torch_deform_conv.deform_conv import ( 7 | th_map_coordinates, 8 | sp_batch_map_coordinates, th_batch_map_coordinates, 9 | sp_batch_map_offsets, th_batch_map_offsets 10 | ) 11 | 12 | 13 | def test_th_map_coordinates(): 14 | np.random.seed(42) 15 | input = np.random.random((100, 100)) 16 | coords = (np.random.random((200, 2)) * 99) 17 | 18 | sp_mapped_vals = map_coordinates(input, coords.T, order=1) 19 | th_mapped_vals = th_map_coordinates( 20 | Variable(torch.from_numpy(input)), Variable(torch.from_numpy(coords)) 21 | ) 22 | assert np.allclose(sp_mapped_vals, th_mapped_vals.data.numpy(), atol=1e-5) 23 | 24 | 25 | def test_th_batch_map_coordinates(): 26 | np.random.seed(42) 27 | input = np.random.random((4, 100, 100)) 28 | coords = (np.random.random((4, 200, 2)) * 99) 29 | 30 | sp_mapped_vals = sp_batch_map_coordinates(input, coords) 31 | th_mapped_vals = th_batch_map_coordinates( 32 | Variable(torch.from_numpy(input)), Variable(torch.from_numpy(coords)) 33 | ) 34 | assert np.allclose(sp_mapped_vals, th_mapped_vals.data.numpy(), atol=1e-5) 35 | 36 | 37 | def test_th_batch_map_offsets(): 38 | np.random.seed(42) 39 | input = np.random.random((4, 100, 100)) 40 | offsets = (np.random.random((4, 100, 100, 2)) * 2) 41 | 42 | sp_mapped_vals = sp_batch_map_offsets(input, offsets) 43 | th_mapped_vals = th_batch_map_offsets( 44 | Variable(torch.from_numpy(input)), Variable(torch.from_numpy(offsets)) 45 | ) 46 | assert np.allclose(sp_mapped_vals, th_mapped_vals.data.numpy(), atol=1e-5) 47 | 48 | 49 | def test_th_batch_map_offsets_grad(): 50 | np.random.seed(42) 51 | input = np.random.random((4, 100, 100)) 52 | offsets = (np.random.random((4, 100, 100, 2)) * 2) 53 | 54 | input = Variable(torch.from_numpy(input), requires_grad=True) 55 | offsets = Variable(torch.from_numpy(offsets), requires_grad=True) 56 | 57 | th_mapped_vals = th_batch_map_offsets(input, offsets) 58 | e = torch.from_numpy(np.random.random((4, 100, 100))) 59 | th_mapped_vals.backward(e) 60 | assert not np.allclose(input.grad.data.numpy(), 0) 61 | assert not np.allclose(offsets.grad.data.numpy(), 0) 62 | -------------------------------------------------------------------------------- /torch_deform_conv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oeway/pytorch-deform-conv/d61d3aa4da20880c524193a50f6e9b44b921a938/torch_deform_conv/__init__.py -------------------------------------------------------------------------------- /torch_deform_conv/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from torch_deform_conv.layers import ConvOffset2D 7 | 8 | class ConvNet(nn.Module): 9 | def __init__(self): 10 | super(ConvNet, self).__init__() 11 | 12 | # conv11 13 | self.conv11 = nn.Conv2d(1, 32, 3, padding=1) 14 | self.bn11 = nn.BatchNorm2d(32) 15 | 16 | # conv12 17 | self.conv12 = nn.Conv2d(32, 64, 3, padding=1, stride=2) 18 | self.bn12 = nn.BatchNorm2d(64) 19 | 20 | # conv21 21 | self.conv21 = nn.Conv2d(64, 128, 3, padding= 1) 22 | self.bn21 = nn.BatchNorm2d(128) 23 | 24 | # conv22 25 | self.conv22 = nn.Conv2d(128, 128, 3, padding=1, stride=2) 26 | self.bn22 = nn.BatchNorm2d(128) 27 | 28 | # out 29 | self.fc = nn.Linear(128, 10) 30 | 31 | def forward(self, x): 32 | x = F.relu(self.conv11(x)) 33 | x = self.bn11(x) 34 | 35 | x = F.relu(self.conv12(x)) 36 | x = self.bn12(x) 37 | 38 | x = F.relu(self.conv21(x)) 39 | x = self.bn21(x) 40 | 41 | x = F.relu(self.conv22(x)) 42 | x = self.bn22(x) 43 | 44 | x = F.avg_pool2d(x, kernel_size=[x.size(2), x.size(3)]) 45 | x = self.fc(x.view(x.size()[:2]))# 46 | x = F.softmax(x) 47 | return x 48 | 49 | class DeformConvNet(nn.Module): 50 | def __init__(self): 51 | super(DeformConvNet, self).__init__() 52 | 53 | # conv11 54 | self.conv11 = nn.Conv2d(1, 32, 3, padding=1) 55 | self.bn11 = nn.BatchNorm2d(32) 56 | 57 | # conv12 58 | self.offset12 = ConvOffset2D(32) 59 | self.conv12 = nn.Conv2d(32, 64, 3, padding=1, stride=2) 60 | self.bn12 = nn.BatchNorm2d(64) 61 | 62 | # conv21 63 | self.offset21 = ConvOffset2D(64) 64 | self.conv21 = nn.Conv2d(64, 128, 3, padding= 1) 65 | self.bn21 = nn.BatchNorm2d(128) 66 | 67 | # conv22 68 | self.offset22 = ConvOffset2D(128) 69 | self.conv22 = nn.Conv2d(128, 128, 3, padding=1, stride=2) 70 | self.bn22 = nn.BatchNorm2d(128) 71 | 72 | # out 73 | self.fc = nn.Linear(128, 10) 74 | 75 | def forward(self, x): 76 | x = F.relu(self.conv11(x)) 77 | x = self.bn11(x) 78 | 79 | x = self.offset12(x) 80 | x = F.relu(self.conv12(x)) 81 | x = self.bn12(x) 82 | 83 | x = self.offset21(x) 84 | x = F.relu(self.conv21(x)) 85 | x = self.bn21(x) 86 | 87 | x = self.offset22(x) 88 | x = F.relu(self.conv22(x)) 89 | x = self.bn22(x) 90 | 91 | x = F.avg_pool2d(x, kernel_size=[x.size(2), x.size(3)]) 92 | x = self.fc(x.view(x.size()[:2])) 93 | x = F.softmax(x) 94 | return x 95 | 96 | def freeze(self, module_classes): 97 | ''' 98 | freeze modules for finetuning 99 | ''' 100 | for k, m in self._modules.items(): 101 | if any([type(m) == mc for mc in module_classes]): 102 | for param in m.parameters(): 103 | param.requires_grad = False 104 | 105 | def unfreeze(self, module_classes): 106 | ''' 107 | unfreeze modules 108 | ''' 109 | for k, m in self._modules.items(): 110 | if any([isinstance(m, mc) for mc in module_classes]): 111 | for param in m.parameters(): 112 | param.requires_grad = True 113 | 114 | def parameters(self): 115 | return filter(lambda p: p.requires_grad, super(DeformConvNet, self).parameters()) 116 | 117 | def get_cnn(): 118 | return ConvNet() 119 | 120 | def get_deform_cnn(trainable=True, freeze_filter=[nn.Conv2d, nn.Linear]): 121 | model = DeformConvNet() 122 | if not trainable: 123 | model.freeze(freeze_filter) 124 | return model 125 | -------------------------------------------------------------------------------- /torch_deform_conv/deform_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | from scipy.ndimage.interpolation import map_coordinates as sp_map_coordinates 8 | 9 | 10 | def th_flatten(a): 11 | """Flatten tensor""" 12 | return a.contiguous().view(a.nelement()) 13 | 14 | 15 | def th_repeat(a, repeats, axis=0): 16 | """Torch version of np.repeat for 1D""" 17 | assert len(a.size()) == 1 18 | return th_flatten(torch.transpose(a.repeat(repeats, 1), 0, 1)) 19 | 20 | 21 | def np_repeat_2d(a, repeats): 22 | """Tensorflow version of np.repeat for 2D""" 23 | 24 | assert len(a.shape) == 2 25 | a = np.expand_dims(a, 0) 26 | a = np.tile(a, [repeats, 1, 1]) 27 | return a 28 | 29 | 30 | def th_gather_2d(input, coords): 31 | inds = coords[:, 0]*input.size(1) + coords[:, 1] 32 | x = torch.index_select(th_flatten(input), 0, inds) 33 | return x.view(coords.size(0)) 34 | 35 | 36 | def th_map_coordinates(input, coords, order=1): 37 | """Tensorflow verion of scipy.ndimage.map_coordinates 38 | Note that coords is transposed and only 2D is supported 39 | Parameters 40 | ---------- 41 | input : tf.Tensor. shape = (s, s) 42 | coords : tf.Tensor. shape = (n_points, 2) 43 | """ 44 | 45 | assert order == 1 46 | input_size = input.size(0) 47 | 48 | coords = torch.clamp(coords, 0, input_size - 1) 49 | coords_lt = coords.floor().long() 50 | coords_rb = coords.ceil().long() 51 | coords_lb = torch.stack([coords_lt[:, 0], coords_rb[:, 1]], 1) 52 | coords_rt = torch.stack([coords_rb[:, 0], coords_lt[:, 1]], 1) 53 | 54 | vals_lt = th_gather_2d(input, coords_lt.detach()) 55 | vals_rb = th_gather_2d(input, coords_rb.detach()) 56 | vals_lb = th_gather_2d(input, coords_lb.detach()) 57 | vals_rt = th_gather_2d(input, coords_rt.detach()) 58 | 59 | coords_offset_lt = coords - coords_lt.type(coords.data.type()) 60 | 61 | vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0] 62 | vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0] 63 | mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1] 64 | return mapped_vals 65 | 66 | 67 | def sp_batch_map_coordinates(inputs, coords): 68 | """Reference implementation for batch_map_coordinates""" 69 | # coords = coords.clip(0, inputs.shape[1] - 1) 70 | 71 | assert (coords.shape[2] == 2) 72 | height = coords[:,:,0].clip(0, inputs.shape[1] - 1) 73 | width = coords[:,:,1].clip(0, inputs.shape[2] - 1) 74 | np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(width, axis=2)), 2) 75 | 76 | mapped_vals = np.array([ 77 | sp_map_coordinates(input, coord.T, mode='nearest', order=1) 78 | for input, coord in zip(inputs, coords) 79 | ]) 80 | return mapped_vals 81 | 82 | 83 | def th_batch_map_coordinates(input, coords, order=1): 84 | """Batch version of th_map_coordinates 85 | Only supports 2D feature maps 86 | Parameters 87 | ---------- 88 | input : tf.Tensor. shape = (b, s, s) 89 | coords : tf.Tensor. shape = (b, n_points, 2) 90 | Returns 91 | ------- 92 | tf.Tensor. shape = (b, s, s) 93 | """ 94 | 95 | batch_size = input.size(0) 96 | input_height = input.size(1) 97 | input_width = input.size(2) 98 | 99 | n_coords = coords.size(1) 100 | 101 | # coords = torch.clamp(coords, 0, input_size - 1) 102 | 103 | coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2) 104 | 105 | assert (coords.size(1) == n_coords) 106 | 107 | coords_lt = coords.floor().long() 108 | coords_rb = coords.ceil().long() 109 | coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2) 110 | coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2) 111 | idx = th_repeat(torch.arange(0, batch_size), n_coords).long() 112 | idx = Variable(idx, requires_grad=False) 113 | if input.is_cuda: 114 | idx = idx.cuda() 115 | 116 | def _get_vals_by_coords(input, coords): 117 | indices = torch.stack([ 118 | idx, th_flatten(coords[..., 0]), th_flatten(coords[..., 1]) 119 | ], 1) 120 | inds = indices[:, 0]*input.size(1)*input.size(2)+ indices[:, 1]*input.size(2) + indices[:, 2] 121 | vals = th_flatten(input).index_select(0, inds) 122 | vals = vals.view(batch_size, n_coords) 123 | return vals 124 | 125 | vals_lt = _get_vals_by_coords(input, coords_lt.detach()) 126 | vals_rb = _get_vals_by_coords(input, coords_rb.detach()) 127 | vals_lb = _get_vals_by_coords(input, coords_lb.detach()) 128 | vals_rt = _get_vals_by_coords(input, coords_rt.detach()) 129 | 130 | coords_offset_lt = coords - coords_lt.type(coords.data.type()) 131 | vals_t = coords_offset_lt[..., 0]*(vals_rt - vals_lt) + vals_lt 132 | vals_b = coords_offset_lt[..., 0]*(vals_rb - vals_lb) + vals_lb 133 | mapped_vals = coords_offset_lt[..., 1]* (vals_b - vals_t) + vals_t 134 | return mapped_vals 135 | 136 | 137 | def sp_batch_map_offsets(input, offsets): 138 | """Reference implementation for tf_batch_map_offsets""" 139 | 140 | batch_size = input.shape[0] 141 | input_height = input.shape[1] 142 | input_width = input.shape[2] 143 | 144 | offsets = offsets.reshape(batch_size, -1, 2) 145 | grid = np.stack(np.mgrid[:input_height, :input_width], -1).reshape(-1, 2) 146 | grid = np.repeat([grid], batch_size, axis=0) 147 | coords = offsets + grid 148 | # coords = coords.clip(0, input_size - 1) 149 | 150 | mapped_vals = sp_batch_map_coordinates(input, coords) 151 | return mapped_vals 152 | 153 | 154 | def th_generate_grid(batch_size, input_height, input_width, dtype, cuda): 155 | grid = np.meshgrid( 156 | range(input_height), range(input_width), indexing='ij' 157 | ) 158 | grid = np.stack(grid, axis=-1) 159 | grid = grid.reshape(-1, 2) 160 | 161 | grid = np_repeat_2d(grid, batch_size) 162 | grid = torch.from_numpy(grid).type(dtype) 163 | if cuda: 164 | grid = grid.cuda() 165 | return Variable(grid, requires_grad=False) 166 | 167 | 168 | def th_batch_map_offsets(input, offsets, grid=None, order=1): 169 | """Batch map offsets into input 170 | Parameters 171 | --------- 172 | input : torch.Tensor. shape = (b, s, s) 173 | offsets: torch.Tensor. shape = (b, s, s, 2) 174 | Returns 175 | ------- 176 | torch.Tensor. shape = (b, s, s) 177 | """ 178 | batch_size = input.size(0) 179 | input_height = input.size(1) 180 | input_width = input.size(2) 181 | 182 | offsets = offsets.view(batch_size, -1, 2) 183 | if grid is None: 184 | grid = th_generate_grid(batch_size, input_height, input_width, offsets.data.type(), offsets.data.is_cuda) 185 | 186 | coords = offsets + grid 187 | 188 | mapped_vals = th_batch_map_coordinates(input, coords) 189 | return mapped_vals 190 | -------------------------------------------------------------------------------- /torch_deform_conv/layers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import numpy as np 7 | from torch_deform_conv.deform_conv import th_batch_map_offsets, th_generate_grid 8 | 9 | 10 | class ConvOffset2D(nn.Conv2d): 11 | """ConvOffset2D 12 | 13 | Convolutional layer responsible for learning the 2D offsets and output the 14 | deformed feature map using bilinear interpolation 15 | 16 | Note that this layer does not perform convolution on the deformed feature 17 | map. See get_deform_cnn in cnn.py for usage 18 | """ 19 | def __init__(self, filters, init_normal_stddev=0.01, **kwargs): 20 | """Init 21 | 22 | Parameters 23 | ---------- 24 | filters : int 25 | Number of channel of the input feature map 26 | init_normal_stddev : float 27 | Normal kernel initialization 28 | **kwargs: 29 | Pass to superclass. See Con2d layer in pytorch 30 | """ 31 | self.filters = filters 32 | self._grid_param = None 33 | super(ConvOffset2D, self).__init__(self.filters, self.filters*2, 3, padding=1, bias=False, **kwargs) 34 | self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev)) 35 | 36 | def forward(self, x): 37 | """Return the deformed featured map""" 38 | x_shape = x.size() 39 | offsets = super(ConvOffset2D, self).forward(x) 40 | 41 | # offsets: (b*c, h, w, 2) 42 | offsets = self._to_bc_h_w_2(offsets, x_shape) 43 | 44 | # x: (b*c, h, w) 45 | x = self._to_bc_h_w(x, x_shape) 46 | 47 | # X_offset: (b*c, h, w) 48 | x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(self,x)) 49 | 50 | # x_offset: (b, h, w, c) 51 | x_offset = self._to_b_c_h_w(x_offset, x_shape) 52 | 53 | return x_offset 54 | 55 | @staticmethod 56 | def _get_grid(self, x): 57 | batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2) 58 | dtype, cuda = x.data.type(), x.data.is_cuda 59 | if self._grid_param == (batch_size, input_height, input_width, dtype, cuda): 60 | return self._grid 61 | self._grid_param = (batch_size, input_height, input_width, dtype, cuda) 62 | self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda) 63 | return self._grid 64 | 65 | @staticmethod 66 | def _init_weights(weights, std): 67 | fan_out = weights.size(0) 68 | fan_in = weights.size(1) * weights.size(2) * weights.size(3) 69 | w = np.random.normal(0.0, std, (fan_out, fan_in)) 70 | return torch.from_numpy(w.reshape(weights.size())) 71 | 72 | @staticmethod 73 | def _to_bc_h_w_2(x, x_shape): 74 | """(b, 2c, h, w) -> (b*c, h, w, 2)""" 75 | x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]), 2) 76 | return x 77 | 78 | @staticmethod 79 | def _to_bc_h_w(x, x_shape): 80 | """(b, c, h, w) -> (b*c, h, w)""" 81 | x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) 82 | return x 83 | 84 | @staticmethod 85 | def _to_b_c_h_w(x, x_shape): 86 | """(b*c, h, w) -> (b, c, h, w)""" 87 | x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])) 88 | return x 89 | -------------------------------------------------------------------------------- /torch_deform_conv/mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | from keras.datasets import mnist 4 | from keras.preprocessing.image import ImageDataGenerator 5 | 6 | 7 | def get_mnist_dataset(): 8 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 9 | X_train = X_train.astype('float32') / 255 10 | X_test = X_test.astype('float32') / 255 11 | X_train = X_train[..., None] 12 | X_test = X_test[..., None] 13 | Y_train = y_train 14 | Y_test = y_test 15 | 16 | return (X_train, Y_train), (X_test, Y_test) 17 | 18 | 19 | def get_gen(set_name, batch_size, translate, scale, 20 | shuffle=True): 21 | if set_name == 'train': 22 | (X, Y), _ = get_mnist_dataset() 23 | elif set_name == 'test': 24 | _, (X, Y) = get_mnist_dataset() 25 | 26 | image_gen = ImageDataGenerator( 27 | zoom_range=scale, 28 | width_shift_range=translate, 29 | height_shift_range=translate 30 | ) 31 | gen = image_gen.flow(X, Y, batch_size=batch_size, shuffle=shuffle) 32 | return gen 33 | -------------------------------------------------------------------------------- /torch_deform_conv/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | def transfer_weights(model_from, model_to): 4 | wf = copy.deepcopy(model_from.state_dict()) 5 | wt = model_to.state_dict() 6 | for k in wt.keys(): 7 | if not k in wf: 8 | wf[k] = wt[k] 9 | model_to.load_state_dict(wf) 10 | --------------------------------------------------------------------------------