├── __init__.py
├── mnist
│   ├── best_model.pt
│   ├── make_mnist_rot.py
│   ├── mnist.py
│   ├── download_mnist.py
│   └── mnist_test.py
├── README.md
├── utils.py
├── layers_2D.py
└── layers_3D.py

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/mnist/best_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COGMAR/RotEqNet/HEAD/mnist/best_model.pt
--------------------------------------------------------------------------------
/mnist/make_mnist_rot.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from mnist import random_rotation, loadMnist



def makeMnistRot():
    """
    Make MNIST-rot from MNIST.
    Pool all training and test samples from MNIST and select 10000 for train,
    2000 for val and 50000 for test. Apply a random rotation to each image.

    Store in numpy files for fast reading.
    """
    np.random.seed(0)

    # Get all samples
    all_samples = loadMnist('train') + loadMnist('test')

    # Empty arrays
    train_data = np.zeros([28, 28, 10000])
    train_label = np.zeros([10000])
    val_data = np.zeros([28, 28, 2000])
    val_label = np.zeros([2000])
    test_data = np.zeros([28, 28, 50000])
    test_label = np.zeros([50000])

    i = 0
    for j in range(10000):
        sample = all_samples[i]
        train_data[:, :, j] = random_rotation(sample[0])
        train_label[j] = sample[1]
        i += 1

    for j in range(2000):
        sample = all_samples[i]
        val_data[:, :, j] = random_rotation(sample[0])
        val_label[j] = sample[1]
        i += 1

    for j in range(50000):
        sample = all_samples[i]
        test_data[:, :, j] = random_rotation(sample[0])
        test_label[j] = sample[1]
        i += 1


    try:
        os.mkdir('mnist_rot/')
    except OSError:  # directory already exists
        pass
    np.save('mnist_rot/train_data', train_data)
    np.save('mnist_rot/train_label', train_label)
    np.save('mnist_rot/val_data', val_data)
    np.save('mnist_rot/val_label', val_label)
    np.save('mnist_rot/test_data', test_data)
    np.save('mnist_rot/test_label', test_label)

if __name__ == '__main__':
    makeMnistRot()
--------------------------------------------------------------------------------
/mnist/mnist.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import scipy.misc
import sys
sys.path.append('../')
from utils import getGrid, rotate_grid_2D

def loadMnist(mode):
    print('Loading MNIST', mode, 'images')
    # mode = 'train' / 'test'
    mnist_folder = 'mnist/'  # written by download_mnist.py

    with open(mnist_folder + mode + '-labels.csv') as f:
        path_and_labels = f.readlines()

    samples = []
    for entry in path_and_labels:
        path = entry.split(',')[0]
        label = int(entry.split(',')[1])
        img = scipy.misc.imread(mnist_folder + path)

        samples.append([img, label])
    return samples


def linear_interpolation_2D(input_array, indices, outside_val=0, boundary_correction=True):
    # http://stackoverflow.com/questions/6427276/3d-interpolation-of-numpy-arrays-without-scipy
    output = np.empty(indices[0].shape)
    ind_0 = indices[0, :]
    ind_1 = indices[1, :]

    N0, N1 = input_array.shape

    x0_0 = ind_0.astype(np.integer)
    x1_0 = ind_1.astype(np.integer)
    x0_1 = x0_0 + 1
    x1_1 = x1_0 + 1

    # Check if inds are beyond array boundary:
    if boundary_correction:
        # put all samples outside the data array to outside_val
        inds_out_of_range = (x0_0 < 0) | (x0_1 < 0) | (x1_0 < 0) | (x1_1 < 0) | \
                            (x0_0 >= N0) | (x0_1 >= N0) | (x1_0 >= N1) | (x1_1 >= N1)

        x0_0[inds_out_of_range] = 0
        x1_0[inds_out_of_range] = 0
        x0_1[inds_out_of_range] = 0
        x1_1[inds_out_of_range] = 0

    w0 = ind_0 - x0_0
    w1 = ind_1 - x1_0
    # Replace by this...
    # input_array.take(np.array([x0_0, x1_0, x2_0]))
    output = (input_array[x0_0, x1_0] * (1 - w0) * (1 - w1) +
              input_array[x0_1, x1_0] * w0 * (1 - w1) +
              input_array[x0_0, x1_1] * (1 - w0) * w1 +
              input_array[x0_1, x1_1] * w0 * w1)


    if boundary_correction:
        output[inds_out_of_range] = outside_val

    return output

def loadMnistRot():
    def load_and_make_list(mode):
        data = np.load('mnist_rot/' + mode + '_data.npy')
        lbls = np.load('mnist_rot/' + mode + '_label.npy')
        data = np.split(data, data.shape[2], 2)
        lbls = np.split(lbls, lbls.shape[0], 0)

        return zip(data, lbls)

    train = load_and_make_list('train')
    val = load_and_make_list('val')
    test = load_and_make_list('test')
    return train, val, test

def random_rotation(data):
    rot = np.random.rand() * 360  # Random rotation
    grid = getGrid([28, 28])
    grid = rotate_grid_2D(grid, rot)
    grid += 13.5
    data = linear_interpolation_2D(data, grid)
    data = np.reshape(data, [28, 28])
    data = data / float(np.max(data))
    return data.astype('float32')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Rotational Equivariant Vector Field Networks (RotEqNet) for PyTorch

This is a PyTorch implementation of the method proposed in:
Rotation equivariant vector field networks, ICCV 2017,
Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia.

https://arxiv.org/abs/1612.09346


The original MATLAB implementation can be found at:

https://github.com/dmarcosg/RotEqNet

The goal of this code is to provide an implementation of the new network layers proposed in the paper. In addition, we try to reproduce the results on the MNIST-rot dataset to verify the implementation.
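Note that the vector-field layers pass activations between them as a tuple of two tensors, `(u, v)`, rather than as a single feature map; `Vector2Magnitude` converts the field back to a conventional tensor. A minimal sketch of this convention (the batch shape is chosen just for illustration):

```python
import torch
from torch.autograd import Variable
from layers_2D import RotConv, Vector2Magnitude

conv = RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1)
x = Variable(torch.randn(4, 1, 28, 28))  # a batch of single-channel images

u, v = conv(x)                   # mode=1 takes a plain tensor and returns a (u, v) vector field
p = Vector2Magnitude()((u, v))   # sqrt(u**2 + v**2): back to a plain 4x6x28x28 tensor
print(p.size())
```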




### Example usage
```python
from __future__ import division
from layers_2D import RotConv, VectorMaxPool, VectorBatchNorm, Vector2Magnitude, VectorUpsampling
from torch import nn


class MnistNet(nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()

        self.main = nn.Sequential(
            RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1),  # The first RotConv must have mode=1
            VectorMaxPool(2),
            VectorBatchNorm(6),

            RotConv(6, 16, [9, 9], 1, 9 // 2, n_angles=17, mode=2),  # The following RotConvs have mode=2 (their input is a vector field)
            VectorMaxPool(2),
            VectorBatchNorm(16),

            RotConv(16, 32, [9, 9], 1, 1, n_angles=17, mode=2),
            Vector2Magnitude(),  # This converts the vector field to a conventional multichannel feature map

            nn.Conv2d(32, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.7),
            nn.Conv2d(128, 10, 1),
        )

    def forward(self, x):
        x = self.main(x)
        return x
```


### Dependencies
The following python packages are required:

```
torch
numpy
scipy
```
(The MNIST download script additionally uses tensorflow and six.)

To download and set up the MNIST-rot dataset, cd into the mnist folder and run:
```
python download_mnist.py
python make_mnist_rot.py
```
To run the MNIST test:
```
python mnist_test.py
```
## Results from the MNIST-rot test
The MNIST experiment in the original paper was obtained by:
- training on 10000 images from the MNIST-rot dataset, with random rotations applied as augmentation
- validating on 2000 images from the MNIST-rot dataset
- testing on 50000 images from the MNIST-rot dataset, with test-time augmentation as described in the paper

Using this implementation, we obtain a test error rate of 1.2%, while the original paper reports 1.1%.

### Known issues:
- The interpolation of filters (apply_transform in utils.py) sometimes causes "CUDA runtime error 59". This error disappears when we use "torch.gather" to collect the samples, but that worsens the best test error rate to ~3%.


### Contact
Anders U. Waldeland
Norwegian Computing Center
anders@nr.no




--------------------------------------------------------------------------------
/mnist/download_mnist.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

"""
From:
https://gist.github.com/ischlag/41d15424e7989b936c1609b53edd1390
"""

import gzip
import os
import sys
import time

from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
from scipy.misc import imsave
import tensorflow as tf
import numpy as np
import csv

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
WORK_DIRECTORY = 'raw_data'
IMAGE_SIZE = 28
NUM_CHANNELS = 1
PIXEL_DEPTH = 255
NUM_LABELS = 10

def maybe_download(filename):
    """Download the data from Yann's website, unless it's already here."""
    if not tf.gfile.Exists(WORK_DIRECTORY):
        tf.gfile.MakeDirs(WORK_DIRECTORY)
    filepath = os.path.join(WORK_DIRECTORY, filename)
    if not tf.gfile.Exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        with tf.gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath


def extract_data(filename, num_images):
    """Extract the images into a 4D tensor [image index, y, x, channels].
    Values are left in [0, 255] (the rescaling to [-0.5, 0.5] is commented out).
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(16)
        buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images)
        data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
        #data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
        data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1)
        return data


def extract_labels(filename, num_images):
    """Extract the labels into a vector of int64 label IDs."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(8)
        buf = bytestream.read(1 * num_images)
        labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
        return labels

train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

# Extract it into np arrays.
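# The images are re-saved below as individual JPEG files plus CSV label files,
# which is the on-disk layout that loadMnist() in mnist.py expects.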
train_data = extract_data(train_data_filename, 60000)
train_labels = extract_labels(train_labels_filename, 60000)
test_data = extract_data(test_data_filename, 10000)
test_labels = extract_labels(test_labels_filename, 10000)

if not os.path.isdir("mnist/train-images"):
    os.makedirs("mnist/train-images")

if not os.path.isdir("mnist/test-images"):
    os.makedirs("mnist/test-images")

# process train data
with open("mnist/train-labels.csv", 'wb') as csvFile:
    writer = csv.writer(csvFile, delimiter=',', quotechar='"')
    for i in range(len(train_data)):
        imsave("mnist/train-images/" + str(i) + ".jpg", train_data[i][:, :, 0])
        writer.writerow(["train-images/" + str(i) + ".jpg", train_labels[i]])

# repeat for test data
with open("mnist/test-labels.csv", 'wb') as csvFile:
    writer = csv.writer(csvFile, delimiter=',', quotechar='"')
    for i in range(len(test_data)):
        imsave("mnist/test-images/" + str(i) + ".jpg", test_data[i][:, :, 0])
        writer.writerow(["test-images/" + str(i) + ".jpg", test_labels[i]])
--------------------------------------------------------------------------------
/mnist/mnist_test.py:
--------------------------------------------------------------------------------
from __future__ import division, print_function
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from torch.autograd import Variable
import random
from mnist import loadMnistRot, random_rotation, linear_interpolation_2D

import sys
sys.path.append('../')  # Import layers and utils from the parent folder
from layers_2D import *
from utils import getGrid

__author__ = "Anders U. Waldeland"
__email__ = "anders@nr.no"

"""
A reproduction of the MNIST-classification network described in:
Rotation equivariant vector field networks (ICCV 2017)
Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia
https://arxiv.org/abs/1612.09346
https://github.com/dmarcosg/RotEqNet
"""


if __name__ == '__main__':

    # Define network
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()

            self.main = nn.Sequential(
                RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1),
                VectorMaxPool(2),
                VectorBatchNorm(6),

                RotConv(6, 16, [9, 9], 1, 9 // 2, n_angles=17, mode=2),
                VectorMaxPool(2),
                VectorBatchNorm(16),

                RotConv(16, 32, [9, 9], 1, 1, n_angles=17, mode=2),
                Vector2Magnitude(),

                nn.Conv2d(32, 128, 1),  # FC1
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Dropout2d(0.7),
                nn.Conv2d(128, 10, 1),  # FC2
            )

        def forward(self, x):
            x = self.main(x)
            x = x.view(x.size()[0], x.size()[1])

            return x


    gpu_no = 0  # Set to False for cpu-version

    # Setup net, loss function, optimizer and hyperparameters
    net = Net()
    criterion = nn.CrossEntropyLoss()
    if type(gpu_no) == int:
        net.cuda(gpu_no)

    if True:  # Current best setup using this implementation - error rate of 1.2%
        start_lr = 0.01
        batch_size = 128
        optimizer = optim.Adam(net.parameters(), lr=start_lr)  # , weight_decay=0.01)
        use_test_time_augmentation = True
        use_train_time_augmentation = True

    if False:  # From the paper, using the MATLAB implementation - reported error rate of 1.4%
        start_lr = 0.1
        batch_size = 600
        optimizer = optim.SGD(net.parameters(), lr=start_lr, weight_decay=0.01)
        use_test_time_augmentation = True
        use_train_time_augmentation = True


    def rotate_im(im, theta):
        grid = getGrid([28, 28])
        grid = rotate_grid_2D(grid, theta)
        grid += 13.5
        data = linear_interpolation_2D(im, grid)
        data = np.reshape(data, [28, 28])
        return data.astype('float32')


    def test(model, dataset, mode):
        """ Return test-accuracy for a dataset """
        model.eval()

        true = []
        pred = []
        for batch_no in range(len(dataset) // batch_size):
            data, labels = getBatch(dataset, mode)

            # Run the same samples with different orientations through the network and average the outputs
            if use_test_time_augmentation and mode == 'test':
                data = data.cpu()
                original_data = data.clone().data.cpu().numpy()

                out = None
                rotations = [0, 15, 30, 45, 60, 75, 90]

                for rotation in rotations:

                    for i in range(batch_size):
                        im = original_data[i, :, :, :].squeeze()
                        im = rotate_im(im, rotation)
                        im = im.reshape([1, 1, 28, 28])
                        im = torch.FloatTensor(im)
                        data[i, :, :, :] = im

                    if type(gpu_no) == int:
                        data = data.cuda(gpu_no)

                    if out is None:
                        out = F.softmax(model(data), dim=1)
                    else:
                        out += F.softmax(model(data), dim=1)

                out /= len(rotations)

            # Only run once
            else:
                out = F.softmax(model(data), dim=1)

            _, c = torch.max(out, 1)
            true.append(labels.data.cpu().numpy())
            pred.append(c.data.cpu().numpy())
        true = np.concatenate(true, 0)
        pred = np.concatenate(pred, 0)
        acc = np.average(pred == true)
        return acc

    def getBatch(dataset, mode):
        """ Collect a batch of samples from list """

        # Make batch
        data = []
        labels = []
        for sample_no in range(batch_size):
            tmp = dataset.pop()  # Get top element and remove it from the list
            img = tmp[0].astype('float32').squeeze()

            # Train-time random rotation
            if mode == 'train' and use_train_time_augmentation:
                img = random_rotation(img)

            data.append(np.expand_dims(np.expand_dims(img, 0), 0))
            labels.append(tmp[1].squeeze())
        data = np.concatenate(data, 0)
        labels = np.array(labels, 'int32')

        data = Variable(torch.from_numpy(data))
        labels = Variable(torch.from_numpy(labels).long())

        if type(gpu_no) == int:
            data = data.cuda(gpu_no)
            labels = labels.cuda(gpu_no)

        return data, labels

    def adjust_learning_rate(optimizer, epoch):
        """ Gradually decay the learning rate """
        if epoch == 20:
            lr = start_lr / 10
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if epoch == 40:
            lr = start_lr / 100
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        if epoch == 60:
            lr = start_lr / 1000
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Load datasets
    train_set, val_set, test_set = loadMnistRot()

    best_acc = 0
    for epoch_no in range(90):

        # Random order for each epoch
        train_set_for_epoch = train_set[:]  # Make a copy
        random.shuffle(train_set_for_epoch)  # Shuffle the copy

        # Training
        net.train()
        for batch_no in range(len(train_set) // batch_size):

            # Train
            optimizer.zero_grad()

            data, labels = getBatch(train_set_for_epoch, 'train')
            out = net(data)
            loss = criterion(out, labels)
            _, c = torch.max(out, 1)
            loss.backward()

            optimizer.step()

            # Print training-acc
            if batch_no % 10 == 0:
                print('Train', 'epoch:', epoch_no, ' batch:', batch_no, ' loss:', loss.data.cpu().numpy()[0], ' acc:', np.average((c == labels).data.cpu().numpy()))


        # Validation
        acc = test(net, val_set[:], 'val')
        print('Val', 'epoch:', epoch_no, ' acc:', acc)

        # Save model if better than previous
        if acc > best_acc:
            torch.save(net.state_dict(), 'best_model.pt')
            best_acc = acc
            print('Model saved')

        adjust_learning_rate(optimizer, epoch_no)

    # Finally test on the test-set with the best model
    net.load_state_dict(torch.load('best_model.pt'))
    print('Test', 'acc:', test(net, test_set[:], 'test'))



--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from __future__ import division, print_function
from scipy.linalg import expm, norm
import collections
import itertools
import numpy as np
from torch.autograd import Variable
import torch

def ntuple(n):
    """ Ensure that input has the correct number of elements """
    def parse(x):
        if isinstance(x, collections.Iterable):
            return x
        return tuple(itertools.repeat(x, n))
    return parse

def getGrid(siz):
    """ Returns a grid with coordinates from -siz[0]/2 : siz[0]/2, -siz[1]/2 : siz[1]/2, ... """
    space = [np.linspace(-(N / 2), (N / 2), N) for N in siz]
    mesh = np.meshgrid(*space, indexing='ij')
    mesh = [np.expand_dims(ax.ravel(), 0) for ax in mesh]

    return np.concatenate(mesh)

def rotate_grid_2D(grid, theta):
    """ Rotate grid """
    theta = np.deg2rad(theta)

    x0 = grid[0, :] * np.cos(theta) - grid[1, :] * np.sin(theta)
    x1 = grid[0, :] * np.sin(theta) + grid[1, :] * np.cos(theta)

    grid[0, :] = x0
    grid[1, :] = x1
    return grid

def rotate_grid_3D(theta, axis, grid):
    """ Rotate grid around the given axis """
    theta = np.deg2rad(theta)
    axis = np.array(axis)
    rot_mat = expm(np.cross(np.eye(3), axis / norm(axis) * theta))
    rot_mat = np.expand_dims(rot_mat, 2)
    grid = np.transpose(np.expand_dims(grid, 2), [0, 2, 1])

    return np.einsum('ijk,jik->ik', rot_mat, grid)


def get_filter_rotation_transforms(kernel_dims, angles):
    """ Return the interpolation variables needed to transform a filter by a given number of degrees """

    dim = len(kernel_dims)

    # Make grid (centered around filter-center)
    grid = getGrid(kernel_dims)

    # Rotate grid
    if dim == 2:
        grid = rotate_grid_2D(grid, angles)
    elif dim == 3:
        grid = rotate_grid_3D(angles[0], [1, 0, 0], grid)
        grid = rotate_grid_3D(angles[1], [0, 0, 1], grid)


    # Radius of filter
    radius = np.min((np.array(kernel_dims) - 1) / 2.)

    # Mask out samples outside the circle
    radius = np.expand_dims(radius, -1)
    dist_to_center = np.sqrt(np.sum(grid ** 2, axis=0))
    mask = dist_to_center >= radius + .0001
    mask = 1 - mask

    # Move grid to center
    grid += radius

    return compute_interpolation_grids(grid, kernel_dims, mask)

def compute_interpolation_grids(grid, kernel_dims, mask):

    #######################################################
    # The following part is part of nd-linear interpolation

    # Add a small eps to the grid so that floor and ceil operations become more stable
    grid += 0.000000001

    # Make a list where each element represents a dimension
    grid = [grid[i, :] for i in range(grid.shape[0])]

    # Get left and right index (integers)
    inds_0 = [ind.astype(np.integer) for ind in grid]
    inds_1 = [ind + 1 for ind in inds_0]

    # Get weights
    weights = [float_ind - int_ind for float_ind, int_ind in zip(grid, inds_0)]

    # Special case for when ind_1 == size (while ind_0 == size - 1)
    # In that case we select ind_0
    ind_1_out_of_bounds = np.logical_or.reduce([ind == siz for ind, siz in zip(inds_1, kernel_dims)])
    for i in range(len(inds_1)):
        inds_1[i][ind_1_out_of_bounds] = 0


    # Get samples that are out of bounds or outside the mask
    inds_out_of_bounds = np.logical_or.reduce([ind < 0 for ind in itertools.chain(inds_0, inds_1)] +
                                              [ind >= siz for ind, siz in zip(inds_0, kernel_dims)] +
                                              [ind >= siz for ind, siz in zip(inds_1, kernel_dims)] +
                                              [(1 - mask).astype('bool')]
                                              )


    # Point these samples at the upper-left corner (their contribution is removed by the mask)
    for i in range(len(inds_0)):
        inds_0[i][inds_out_of_bounds] = 0
        inds_1[i][inds_out_of_bounds] = 0

    # Reshape
    inds_0 = [np.reshape(ind, [1, 1] + kernel_dims) for ind in inds_0]
    inds_1 = [np.reshape(ind, [1, 1] + kernel_dims) for ind in inds_1]
    weights = [np.reshape(weight, [1, 1] + kernel_dims) for weight in weights]

    # Make pytorch-tensors of the interpolation variables
    inds_0 = [Variable(torch.LongTensor(ind)) for ind in inds_0]
    inds_1 = [Variable(torch.LongTensor(ind)) for ind in inds_1]
    weights = [Variable(torch.FloatTensor(weight)) for weight in weights]

    # Make the mask a pytorch tensor
    mask = mask.reshape(kernel_dims)
    mask = mask.astype('float32')
    mask = np.expand_dims(mask, 0)
    mask = np.expand_dims(mask, 0)
    mask = torch.FloatTensor(mask)

    # Uncomment for nearest interpolation (for debugging)
    #inds_1 = [ind*0 for ind in inds_1]
    #weights = [weight*0 for weight in weights]

    return inds_0, inds_1, weights, mask

def apply_transform(filter, interp_vars, filters_size, old_bilinear_interpolation=True):
    """ Apply a transform specified by the interpolation variables to a filter """

    dim = 2 if len(filter.size()) == 4 else 3

    if dim == 2:

        if old_bilinear_interpolation:
            [x0_0, x1_0], [x0_1, x1_1], [w0, w1] = interp_vars
            rotated_filter = (filter[:, :, x0_0, x1_0] * (1 - w0) * (1 - w1) +
                              filter[:, :, x0_1, x1_0] * w0 * (1 - w1) +
                              filter[:, :, x0_0, x1_1] * (1 - w0) * w1 +
                              filter[:, :, x0_1, x1_1] * w0 * w1)
        else:

            # Expand dimensions to fit the filter
            interp_vars = [[inner_el.expand_as(filter) for inner_el in outer_el] for outer_el in interp_vars]

            [x0_0, x1_0], [x0_1, x1_1], [w0, w1] = interp_vars

            a = torch.gather(torch.gather(filter, 2, x0_0), 3, x1_0) * (1 - w0) * (1 - w1)
            b = torch.gather(torch.gather(filter, 2, x0_1), 3, x1_0) * w0 * (1 - w1)
            c = torch.gather(torch.gather(filter, 2, x0_0), 3, x1_1) * (1 - w0) * w1
            d = torch.gather(torch.gather(filter, 2, x0_1), 3, x1_1) * w0 * w1
            rotated_filter = a + b + c + d

        rotated_filter = rotated_filter.view(filter.size()[0], filter.size()[1], filters_size[0], filters_size[1])

    elif dim == 3:
        [x0_0, x1_0, x2_0], [x0_1, x1_1, x2_1], [w0, w1, w2] = interp_vars

        rotated_filter = (filter[:, :, x0_0, x1_0, x2_0] * (1 - w0) * (1 - w1) * (1 - w2) +
                          filter[:, :, x0_1, x1_0, x2_0] * w0 * (1 - w1) * (1 - w2) +
                          filter[:, :, x0_0, x1_1, x2_0] * (1 - w0) * w1 * (1 - w2) +
                          filter[:, :, x0_1, x1_1, x2_0] * w0 * w1 * (1 - w2) +
                          filter[:, :, x0_0, x1_0, x2_1] * (1 - w0) * (1 - w1) * w2 +
                          filter[:, :, x0_1, x1_0, x2_1] * w0 * (1 - w1) * w2 +
                          filter[:, :, x0_0, x1_1, x2_1] * (1 - w0) * w1 * w2 +
                          filter[:, :, x0_1, x1_1, x2_1] * w0 * w1 * w2)

        rotated_filter = rotated_filter.view(filter.size()[0], filter.size()[1], filters_size[0], filters_size[1], filters_size[2])

    return rotated_filter



if __name__ == '__main__':
    """ Test rotation of filter """
    ks = [9, 9]  # Kernel size

    w = Variable(torch.ones([1, 1] + ks))
    #w[:,:,4,:] = 5
    w[:, :, :, 4] = 5
    #w[:,:,0,0] = -1


    print(w)
    for angle in [0, 90, 45, 180, 65, 10]:
        print(angle, 'degrees')
        interp_vars = get_filter_rotation_transforms(ks, angle)
        mask = Variable(interp_vars[-1])
        old = apply_transform(w, interp_vars[:-1], ks, old_bilinear_interpolation=True) * mask
        new = apply_transform(w, interp_vars[:-1], ks, old_bilinear_interpolation=False) * mask
        print(old)
        print('Difference', torch.sum(new - old))


--------------------------------------------------------------------------------
/layers_2D.py:
--------------------------------------------------------------------------------
"""
This code is a PyTorch implementation of the method proposed in:
Rotation equivariant vector field networks (ICCV 2017)
Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia
https://arxiv.org/abs/1612.09346
https://github.com/dmarcosg/RotEqNet (original code)
"""
from __future__ import division
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
import math
from utils import *

class RotConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, n_angles=8, mode=1):
        super(RotConv, self).__init__()

        kernel_size = ntuple(2)(kernel_size)
        stride = ntuple(2)(stride)
        padding = ntuple(2)(padding)
        dilation = ntuple(2)(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        self.mode = mode

        # Angles
        self.angles = np.linspace(0, 360, n_angles, endpoint=False)
        self.angle_tensors = []

        # Get interpolation variables
        self.interp_vars = []
        for angle in self.angles:
            out = get_filter_rotation_transforms(list(self.kernel_size), angle)
            self.interp_vars.append(out[:-1])
            self.mask = out[-1]

            self.angle_tensors.append(Variable(torch.FloatTensor(np.array([angle / 180. * np.pi]))))

        self.weight1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
        # If the input is a vector field, we have two filters (one for each component)
        if self.mode == 2:
            self.weight2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight1.data.uniform_(-stdv, stdv)
        if self.mode == 2:
            self.weight2.data.uniform_(-stdv, stdv)

    def mask_filters(self):
        self.weight1.data[self.mask.expand_as(self.weight1) == 0] = 1e-8
        if self.mode == 2:
            self.weight2.data[self.mask.expand_as(self.weight1) == 0] = 1e-8

    def _apply(self, func):
        # This is called whenever the user calls model.cuda()
        # We intercept it to replace tensors and variables with cuda-versions
        self.mask = func(self.mask)
        self.interp_vars = [[[func(el2) for el2 in el1] for el1 in el0] for el0 in self.interp_vars]
        self.angle_tensors = [func(el) for el in self.angle_tensors]

        return super(RotConv, self)._apply(func)


    def forward(self, input):
        # Uncomment this to turn on filter-masking
        # Todo: fix broken convergence when filter-masking is on
        #self.mask_filters()

        if self.mode == 1:
            outputs = []

            # Loop through the different filter-transformations
            for ind, interp_vars in enumerate(self.interp_vars):
                # Apply rotation
                weight = apply_transform(self.weight1, interp_vars, self.kernel_size)

                # Do convolution
                out = F.conv2d(input, weight, None, self.stride, self.padding, self.dilation)
                outputs.append(out.unsqueeze(-1))

        if self.mode == 2:
            u = input[0]
            v = input[1]

            outputs = []
            # Loop through the different filter-transformations
            for ind, interp_vars in enumerate(self.interp_vars):
                angle = self.angle_tensors[ind]
                # Apply rotation
                wu = apply_transform(self.weight1, interp_vars, self.kernel_size)
                wv = apply_transform(self.weight2, interp_vars, self.kernel_size)

                # Do convolution for u
                wru = torch.cos(angle) * wu - torch.sin(angle) * wv
                u_out = F.conv2d(u, wru, None, self.stride, self.padding, self.dilation)

                # Do convolution for v
                wrv = torch.sin(angle) * wu + torch.cos(angle) * wv
                v_out = F.conv2d(v, wrv, None, self.stride, self.padding, self.dilation)

                # Sum the two component responses (the response of the rotated vector filter)
                outputs.append((u_out + v_out).unsqueeze(-1))


        # Get the maximum direction (Orientation Pooling)
        strength, max_ind = torch.max(torch.cat(outputs, -1), -1)

        # Convert back from the polar representation
        angle_map = max_ind.float() * (360. / len(self.angles) / 180. * np.pi)
        u = F.relu(strength) * torch.cos(angle_map)
        v = F.relu(strength) * torch.sin(angle_map)


        return u, v

class VectorMaxPool(nn.Module):
    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False):
        super(VectorMaxPool, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.ceil_mode = ceil_mode

    def forward(self, input):
        # Assuming input is a vector field
        u = input[0]
        v = input[1]

        # Magnitude
        p = torch.sqrt(v ** 2 + u ** 2)
        # Max pool
        _, max_inds = F.max_pool2d(p, self.kernel_size, self.stride,
                                   self.padding, self.dilation, self.ceil_mode,
                                   return_indices=True)
        # Reshape to please pytorch
        s1 = u.size()
        s2 = max_inds.size()

        max_inds = max_inds.view(s1[0], s1[1], s2[2] * s2[3])

        u = u.view(s1[0], s1[1], s1[2] * s1[3])
        v = v.view(s1[0], s1[1], s1[2] * s1[3])

        # Select u/v components according to the max pool on magnitude
        u = torch.gather(u, 2, max_inds)
        v = torch.gather(v, 2, max_inds)

        # Reshape back
        u = u.view(s1[0], s1[1], s2[2], s2[3])
        v = v.view(s1[0], s1[1], s2[2], s2[3])

        return u, v

class Vector2Magnitude(nn.Module):
    def __init__(self):
        super(Vector2Magnitude, self).__init__()

    def forward(self, input):
        u = input[0]
        v = input[1]

        p = torch.sqrt(v ** 2 + u ** 2)
        return p

class Vector2Angle(nn.Module):
    def __init__(self):
        super(Vector2Angle, self).__init__()

    def forward(self, input):
        u = input[0]
        v = input[1]

        # u = p*cos(angle), v = p*sin(angle) => angle = atan2(v, u)
        angle = torch.atan2(v, u)

        return angle

class VectorBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.5, affine=True):

        super(VectorBatchNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum

        if self.affine:
            self.weight = Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        self.reset_parameters()


    def reset_parameters(self):
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()

    def forward(self, input):
        """
        Based on https://github.com/lberrada/bn.pytorch
        """
        if self.training:
            # Compute std
            std = self.std(input)

            alpha = self.weight / (std + self.eps)

            # update running variance
            self.running_var *= (1. - self.momentum)
            self.running_var += self.momentum * std.data ** 2
            # compute output
            u = input[0] * alpha
            v = input[1] * alpha

        else:
            alpha = self.weight.data / torch.sqrt(self.running_var + self.eps)

            # compute output
            u = input[0] * Variable(alpha)
            v = input[1] * Variable(alpha)
        return u, v

    def std(self, input):
        u = input[0]
        v = input[1]

        # Vector to magnitude
        p = torch.sqrt(u ** 2 + v ** 2)

        # We want to normalize the vector magnitudes,
        # therefore we omit the mean (var = (p - p.mean())**2),
        # since we do not want to move the center of the vectors.
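        # That is, std = sqrt(mean(p ** 2)), taken over the batch and both
        # spatial dimensions, separately for each channel: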
        var = p ** 2
        var = torch.mean(var, 0, keepdim=True)
        var = torch.mean(var, 2, keepdim=True)
        var = torch.mean(var, 3, keepdim=True)
        std = torch.sqrt(var)

        return std

class VectorUpsampling(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='bilinear'):
        super(VectorUpsampling, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        # Assuming input is a vector field
        u = input[0]
        v = input[1]

        u = F.upsample(u, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
        v = F.upsample(v, size=self.size, scale_factor=self.scale_factor, mode=self.mode)


        return u, v
--------------------------------------------------------------------------------
/layers_3D.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
__author__ = "Anders U. Waldeland"
__email__ = "anders@nr.no"


"""
This code is a 3D extension of the 2D method proposed in:
Rotation equivariant vector field networks (ICCV 2017)
Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia
https://arxiv.org/abs/1612.09346
https://github.com/dmarcosg/RotEqNet

We use the spherical coordinate system (see https://en.wikipedia.org/wiki/Spherical_coordinate_system)
with coordinates (r/radius, theta/inclination, phi/azimuth). The 3D vector field has the cartesian coordinates (x,y,z)
but we denote them with (u,v,w) in correspondence with the original paper.
"""

import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
import math
from utils import *




class RotConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, n_inclination=8, n_azimuth=4, mode=1):
        super(RotConv, self).__init__()

        kernel_size = ntuple(3)(kernel_size)
        stride = ntuple(3)(stride)
        padding = ntuple(3)(padding)
        dilation = ntuple(3)(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        self.mode = mode

        # If the input is a vector field we have three filters (one for each component)
        self.weight1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
        if self.mode == 2:
            self.weight2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
            self.weight3 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))


        # Angles (inclination and azimuth)
        self.thetas = np.linspace(0, 180, n_inclination, endpoint=False)
        self.phis = np.linspace(0, 360, n_azimuth, endpoint=False)
        self.theta_tensors = []
        self.phi_tensors = []

        # Get interpolation variables
        self.interp_vars = []
        for theta in self.thetas:
            for phi in self.phis:

                out = get_filter_rotation_transforms(list(self.kernel_size), [theta, phi])
                self.interp_vars.append(out[:-1])  # keep the mask separate, as in the 2D version
                self.mask = out[-1]

                self.theta_tensors.append(Variable(torch.FloatTensor([theta / 180. * np.pi])))
                self.phi_tensors.append(Variable(torch.FloatTensor([phi / 180. * np.pi])))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight1.data.uniform_(-stdv, stdv)
        if self.mode == 2:
            self.weight2.data.uniform_(-stdv, stdv)
            self.weight3.data.uniform_(-stdv, stdv)

    def _apply(self, l):
        # We need to replace tensors and variables with cuda-versions
        # This is most likely not the nicest way to do this but it works...
        self.mask = l(self.mask)
        self.interp_vars = [[[l(el2) for el2 in el1] for el1 in el0] for el0 in self.interp_vars]
        self.theta_tensors = [l(el) for el in self.theta_tensors]
        self.phi_tensors = [l(el) for el in self.phi_tensors]

        return super(RotConv, self)._apply(l)


    def forward(self, input):

        if self.mode == 1:
            outputs = []

            # Loop through the different filter-transformations
            for ind, interp_vars in enumerate(self.interp_vars):
                # Apply rotation
                weight = apply_transform(self.weight1, interp_vars, self.kernel_size)

                # Do convolution
                out = F.conv3d(input, weight, None, self.stride, self.padding, self.dilation)
                outputs.append(out.unsqueeze(-1))

            # Get the maximum direction (Orientation Pooling)
            strength, max_ind = torch.max(torch.cat(outputs, -1), -1)

            # Convert to spherical coordinates
            # (decompose the flat filter index into inclination/azimuth indices;
            # the filters were generated with theta as the outer loop and phi as the inner loop)
            theta = torch.floor(max_ind.float() / len(self.phis)) * (180. / len(self.thetas) / 180. * np.pi)
            phi = torch.remainder(max_ind.float(), len(self.phis)) * (360. / len(self.phis) / 180. * np.pi)

            u = F.relu(strength) * torch.sin(theta) * torch.cos(phi)
            v = F.relu(strength) * torch.sin(theta) * torch.sin(phi)
            w = F.relu(strength) * torch.cos(theta)


        if self.mode == 2:
            u = input[0]
            v = input[1]
            w = input[2]

            output_u = []
            output_v = []
            output_w = []
            output_p = []  # magnitude of the field

            # Loop through the different filter-transformations
            for ind, interp_vars in enumerate(self.interp_vars):
                theta = self.theta_tensors[ind]
                phi = self.phi_tensors[ind]

                # Apply rotation
                wu = apply_transform(self.weight1, interp_vars, self.kernel_size)
                wv = apply_transform(self.weight2, interp_vars, self.kernel_size)
                ww = apply_transform(self.weight3, interp_vars, self.kernel_size)


                # Do convolution for u
                wru = None  # TODO: decompose filters
                u_out = F.conv3d(u, wru, None, self.stride, self.padding, self.dilation)
                output_u.append(u_out.unsqueeze(-1))

                # Do convolution for v
                wrv = None  # TODO: decompose filters
                v_out = F.conv3d(v, wrv, None, self.stride, self.padding, self.dilation)
                output_v.append(v_out.unsqueeze(-1))

                # Do convolution for w
                wrw = None  # TODO: decompose filters
                w_out = F.conv3d(w, wrw, None, self.stride, self.padding, self.dilation)
                output_w.append(w_out.unsqueeze(-1))

                # Compute magnitude (p)
                output_p.append(torch.sqrt(v_out ** 2 + u_out ** 2 + w_out ** 2).unsqueeze(-1))



            # Get the maximum direction (Orientation Pooling)
            strength, max_ind = torch.max(torch.cat(output_p, -1), -1)

            # Select the u,v,w for the maximum orientation
            u = torch.cat(output_u, -1)
            v = torch.cat(output_v, -1)
            w = torch.cat(output_w, -1)

            u = torch.gather(u, -1, max_ind.unsqueeze(-1))[:, :, :, :, :, 0]
            v = torch.gather(v, -1, max_ind.unsqueeze(-1))[:, :, :, :, :, 0]
            w = torch.gather(w, -1, max_ind.unsqueeze(-1))[:, :, :, :, :, 0]

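            # At this point u/v/w hold, at every position, the component
            # responses of the orientation with the strongest magnitude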
        return u, v, w

class VectorMaxPool(nn.Module):
    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False):
        super(VectorMaxPool, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.ceil_mode = ceil_mode

    def forward(self, input):
        # Assuming input is a vector field
        u = input[0]
        v = input[1]
        w = input[2]

        # Magnitude
        p = torch.sqrt(v ** 2 + u ** 2 + w ** 2)
        # Max pool
        _, max_inds = F.max_pool3d(p, self.kernel_size, self.stride,
                                   self.padding, self.dilation, self.ceil_mode,
                                   return_indices=True)
        # Reshape to please pytorch
        s1 = u.size()
        s2 = max_inds.size()

        max_inds = max_inds.view(s1[0], s1[1], s2[2] * s2[3] * s2[4])

        u = u.view(s1[0], s1[1], s1[2] * s1[3] * s1[4])
        v = v.view(s1[0], s1[1], s1[2] * s1[3] * s1[4])
        w = w.view(s1[0], s1[1], s1[2] * s1[3] * s1[4])

        # Select u/v/w components according to the max pool on magnitude
        u = torch.gather(u, 2, max_inds)
        v = torch.gather(v, 2, max_inds)
        w = torch.gather(w, 2, max_inds)

        # Reshape back
        u = u.view(s1[0], s1[1], s2[2], s2[3], s2[4])
        v = v.view(s1[0], s1[1], s2[2], s2[3], s2[4])
        w = w.view(s1[0], s1[1], s2[2], s2[3], s2[4])


        return u, v, w

class Vector2Magnitude(nn.Module):
    def __init__(self):
        super(Vector2Magnitude, self).__init__()

    def forward(self, input):
        u = input[0]
        v = input[1]
        w = input[2]

        p = torch.sqrt(v ** 2 + u ** 2 + w ** 2)
        return p

class VectorBatchNorm(nn.Module):
    def __init__(self):
        super(VectorBatchNorm, self).__init__()

    def forward(self, input):
        if input[0].size()[0] > 1:
            u = input[0]
            v = input[1]
            w = input[2]


            p = torch.sqrt(v ** 2 + u ** 2 + w ** 2)

            # Mean
            mu = torch.mean(p, 0, keepdim=True)
            mu = torch.mean(mu, 2, keepdim=True)
            mu = torch.mean(mu, 3, keepdim=True)
            mu = torch.mean(mu, 4, keepdim=True)

            # Variance (averaged, so the scale does not grow with the tensor size)
            var = (mu - p) ** 2
            var = torch.mean(var, 0, keepdim=True)
            var = torch.mean(var, 2, keepdim=True)
            var = torch.mean(var, 3, keepdim=True)
            var = torch.mean(var, 4, keepdim=True)
            std = torch.sqrt(var)

            eps = 0.00001
            std = std + eps


            return u / std, v / std, w / std
        else:
            return input


class VectorUpsampling(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='trilinear'):
        super(VectorUpsampling, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        # Assuming input is a vector field
        u = input[0]
        v = input[1]
        w = input[2]

        u = F.upsample(u, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
        v = F.upsample(v, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
        w = F.upsample(w, size=self.size, scale_factor=self.scale_factor, mode=self.mode)


        return u, v, w
--------------------------------------------------------------------------------