├── __init__.py
├── mnist
│   ├── best_model.pt
│   ├── make_mnist_rot.py
│   ├── mnist.py
│   ├── download_mnist.py
│   └── mnist_test.py
├── README.md
├── utils.py
├── layers_2D.py
└── layers_3D.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mnist/best_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COGMAR/RotEqNet/HEAD/mnist/best_model.pt
--------------------------------------------------------------------------------
/mnist/make_mnist_rot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from mnist import random_rotation, loadMnist
4 |
5 |
6 |
7 | def makeMnistRot():
8 |     """
9 |     Make MNIST-rot from MNIST.
10 |     Pool all MNIST training and test samples, then take 10000 for train,
11 |     2000 for val and 50000 for test. Apply a random rotation to each image.
12 |
13 |     Store the splits in numpy files for fast reading.
14 |
15 |     """
16 | np.random.seed(0)
17 |
18 | #Get all samples
19 | all_samples = loadMnist('train') + loadMnist('test')
20 |
21 |     # Split sequentially: first 10000 train, then 2000 val, then 50000 test
22 |
23 | #Empty arrays
24 | train_data = np.zeros([28,28,10000])
25 | train_label = np.zeros([10000])
26 | val_data = np.zeros([28,28,2000])
27 | val_label = np.zeros([2000])
28 | test_data = np.zeros([28,28,50000])
29 | test_label = np.zeros([50000])
30 |
31 | i = 0
32 | for j in range(10000):
33 |         sample = all_samples[i]
34 | train_data[:, :, j] = random_rotation(sample[0])
35 | train_label[j] = sample[1]
36 | i += 1
37 |
38 | for j in range(2000):
39 | sample = all_samples[i]
40 | val_data[:, :, j] = random_rotation(sample[0])
41 | val_label[j] = sample[1]
42 | i += 1
43 |
44 | for j in range(50000):
45 | sample = all_samples[i]
46 | test_data[:, :, j] = random_rotation(sample[0])
47 | test_label[j] = sample[1]
48 | i += 1
49 |
50 |
51 |     try:
52 |         os.mkdir('mnist_rot/')
53 |     except OSError:
54 |         pass  # directory already exists
55 |     np.save('mnist_rot/train_data', train_data)
56 | np.save('mnist_rot/train_label', train_label)
57 | np.save('mnist_rot/val_data', val_data)
58 | np.save('mnist_rot/val_label', val_label)
59 | np.save('mnist_rot/test_data', test_data)
60 | np.save('mnist_rot/test_label', test_label)
61 |
62 | if __name__ == '__main__':
63 | makeMnistRot()
--------------------------------------------------------------------------------
/mnist/mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc
3 | import sys
4 | sys.path.append('../')
5 | from utils import getGrid, rotate_grid_2D
6 |
7 | def loadMnist(mode):
8 |     print('Loading MNIST %s images' % mode)
9 |     # mode is 'train' or 'test'
10 |     mnist_folder = 'mnist/'  # created by download_mnist.py
11 |
12 |     with open(mnist_folder + mode + '-labels.csv') as f:
13 | path_and_labels = f.readlines()
14 |
15 |     samples = []
16 | for entry in path_and_labels:
17 | path = entry.split(',')[0]
18 | label = int(entry.split(',')[1])
19 | img = scipy.misc.imread(mnist_folder + path)
20 |
21 | samples.append([img, label])
22 | return samples
23 |
24 |
25 | def linear_interpolation_2D(input_array, indices, outside_val=0, boundary_correction=True):
26 | # http://stackoverflow.com/questions/6427276/3d-interpolation-of-numpy-arrays-without-scipy
27 |     # indices: 2 x N array of fractional (row, col) coordinates to sample at
28 | ind_0 = indices[0,:]
29 | ind_1 = indices[1,:]
30 |
31 | N0, N1 = input_array.shape
32 |
33 |     x0_0 = ind_0.astype(int)
34 |     x1_0 = ind_1.astype(int)
35 | x0_1 = x0_0 + 1
36 | x1_1 = x1_0 + 1
37 |
38 | # Check if inds are beyond array boundary:
39 | if boundary_correction:
40 | # put all samples outside datacube to 0
41 | inds_out_of_range = (x0_0 < 0) | (x0_1 < 0) | (x1_0 < 0) | (x1_1 < 0) | \
42 | (x0_0 >= N0) | (x0_1 >= N0) | (x1_0 >= N1) | (x1_1 >= N1)
43 |
44 | x0_0[inds_out_of_range] = 0
45 | x1_0[inds_out_of_range] = 0
46 | x0_1[inds_out_of_range] = 0
47 | x1_1[inds_out_of_range] = 0
48 |
49 | w0 = ind_0 - x0_0
50 | w1 = ind_1 - x1_0
51 |     # Bilinear interpolation: weighted sum of the four neighbouring pixels
52 |
53 | output = (input_array[x0_0, x1_0] * (1 - w0) * (1 - w1) +
54 | input_array[x0_1, x1_0] * w0 * (1 - w1) +
55 | input_array[x0_0, x1_1] * (1 - w0) * w1 +
56 | input_array[x0_1, x1_1] * w0 * w1 )
57 |
58 |
59 | if boundary_correction:
60 | output[inds_out_of_range] = 0
61 |
62 | return output
63 |
64 | def loadMnistRot():
65 | def load_and_make_list(mode):
66 | data = np.load('mnist_rot/' + mode + '_data.npy')
67 | lbls = np.load('mnist_rot/' + mode + '_label.npy')
68 | data = np.split(data, data.shape[2],2)
69 | lbls = np.split(lbls, lbls.shape[0],0)
70 |
71 |         return list(zip(data, lbls))
72 |
73 | train = load_and_make_list('train')
74 | val = load_and_make_list('val')
75 | test = load_and_make_list('test')
76 | return train, val, test
77 |
78 | def random_rotation(data):
79 | rot = np.random.rand() * 360 # Random rotation
80 | grid = getGrid([28, 28])
81 | grid = rotate_grid_2D(grid, rot)
82 | grid += 13.5
83 | data = linear_interpolation_2D(data, grid)
84 | data = np.reshape(data, [28, 28])
85 | data = data / float(np.max(data))
86 | return data.astype('float32')
87 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rotational Equivariant Vector Field Networks (RotEqNet) for PyTorch
2 |
3 | This is a PyTorch implementation of the method proposed in:
4 | Rotation equivariant vector field networks, ICCV 2017,
5 | Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia.
6 |
7 | https://arxiv.org/abs/1612.09346
8 |
9 |
10 | The original MATLAB implementation can be found at:
11 |
12 | https://github.com/dmarcosg/RotEqNet
13 |
14 | The goal of this code is to provide an implementation of the new network layers proposed in the paper. In addition, we try to reproduce the results on the MNIST-rot dataset to verify the implementation.
15 |
16 |
17 |
18 |
19 | ### Example usage
20 | ```python
21 | from __future__ import division
22 | from layers_2D import RotConv, VectorMaxPool, VectorBatchNorm, Vector2Magnitude, VectorUpsampling
23 | from torch import nn
24 |
25 |
26 | class MnistNet(nn.Module):
27 | def __init__(self):
28 | super(MnistNet, self).__init__()
29 |
30 | self.main = nn.Sequential(
31 | RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1), #The first RotConv must have mode=1
32 | VectorMaxPool(2),
33 | VectorBatchNorm(6),
34 |
35 |             RotConv(6, 16, [9, 9], 1, 9 // 2, n_angles=17, mode=2), #The next RotConv has mode=2 (since the input is a vector field)
36 | VectorMaxPool(2),
37 | VectorBatchNorm(16),
38 |
39 | RotConv(16, 32, [9, 9], 1, 1, n_angles=17, mode=2),
40 |             Vector2Magnitude(), #This layer converts the vector field to a conventional multichannel feature map
41 |
42 | nn.Conv2d(32, 128, 1),
43 | nn.BatchNorm2d(128),
44 | nn.ReLU(),
45 | nn.Dropout2d(0.7),
46 | nn.Conv2d(128, 10, 1),
47 |
48 | )
49 |
50 | def forward(self,x):
51 | x = self.main(x)
52 | return x
53 | ```
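
A minimal smoke test for the network above. The random batch is made up for illustration, and the `Variable` wrapper matches the PyTorch version this repository targets:

```python
import torch
from torch.autograd import Variable

net = MnistNet()
x = Variable(torch.randn(8, 1, 28, 28))  # assumed input: a batch of 8 single-channel 28x28 images
out = net(x)
print(out.size())  # [8, 10, 1, 1] -- squeeze the spatial dimensions before applying a loss
```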
54 |
55 |
56 | ### Dependencies
57 | The following python packages are required:
58 |
59 | ```
60 | torch
61 | numpy
62 | scipy
63 | ```
64 | To download and set up the MNIST-rot dataset, cd into the mnist folder and run:
65 | ```
66 | python download_mnist.py
67 | python make_mnist_rot.py
68 | ```
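
Once both scripts have finished, `loadMnistRot()` in `mnist/mnist.py` returns the three splits as lists of (image, label) pairs. A quick sanity check might look like this (run it from the mnist folder, since the loader uses relative paths):

```python
from mnist import loadMnistRot

train, val, test = loadMnistRot()
print(len(train), len(val), len(test))  # 10000 2000 50000
img, lbl = train[0]
print(img.shape, int(lbl))              # (28, 28, 1) and the digit label
```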
69 | To run the MNIST-test:
70 | ```
71 | python mnist_test.py
72 | ```
73 | ## Results from the MNIST-rot test
74 | The MNIST-rot experiment in the original paper was set up as follows:
75 | - training on 10 000 images from the MNIST-rot dataset, applying random rotations as augmentation
76 | - validating on 2000 images from the MNIST-rot dataset
77 | - testing on 50 000 images from the MNIST-rot dataset, with test-time augmentation as described in the paper
78 |
79 | Using this implementation, we obtain a test error rate of 1.2%, while the original paper reports 1.1%.
80 |
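The test-time augmentation lives in `test()` in `mnist_test.py`: each batch is pushed through the network at several orientations and the softmax outputs are averaged. Condensed, the idea is the sketch below, where `rotate_batch` is a hypothetical stand-in for the per-image `rotate_im` loop in the real code:

```python
rotations = [0, 15, 30, 45, 60, 75, 90]
out = None
for rotation in rotations:
    rotated = rotate_batch(data, rotation)  # hypothetical helper; the real code rotates image by image
    probs = F.softmax(net(rotated), dim=1)
    out = probs if out is None else out + probs
out /= len(rotations)        # average the probabilities over orientations
_, pred = torch.max(out, 1)  # class with the highest average probability
```
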
81 | ### Known issues:
82 | - The interpolation of filters (apply_transform in utils.py) sometimes causes "CUDA runtime error 59". This error disappears when we use "torch.gather" to collect the samples, but doing so increases the best test error rate to ~3%.
83 |
84 |
85 | ### Contact
86 | Anders U. Waldeland
87 | Norwegian Computing Center
88 | anders@nr.no
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/mnist/download_mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | """
6 | From:
7 | https://gist.github.com/ischlag/41d15424e7989b936c1609b53edd1390
8 | """
9 |
10 | import gzip
11 | import os
12 | import sys
13 | import time
14 |
15 | from six.moves import urllib
16 | from six.moves import xrange # pylint: disable=redefined-builtin
17 | from scipy.misc import imsave
18 | import tensorflow as tf
19 | import numpy as np
20 | import csv
21 |
22 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
23 | WORK_DIRECTORY = 'raw_data'
24 | IMAGE_SIZE = 28
25 | NUM_CHANNELS = 1
26 | PIXEL_DEPTH = 255
27 | NUM_LABELS = 10
28 |
29 | def maybe_download(filename):
30 | """Download the data from Yann's website, unless it's already here."""
31 | if not tf.gfile.Exists(WORK_DIRECTORY):
32 | tf.gfile.MakeDirs(WORK_DIRECTORY)
33 | filepath = os.path.join(WORK_DIRECTORY, filename)
34 | if not tf.gfile.Exists(filepath):
35 | filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
36 | with tf.gfile.GFile(filepath) as f:
37 | size = f.size()
38 | print('Successfully downloaded', filename, size, 'bytes.')
39 | return filepath
40 |
41 |
42 | def extract_data(filename, num_images):
43 | """Extract the images into a 4D tensor [image index, y, x, channels].
44 | Values are rescaled from [0, 255] down to [-0.5, 0.5].
45 | """
46 | print('Extracting', filename)
47 | with gzip.open(filename) as bytestream:
48 | bytestream.read(16)
49 | buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images)
50 | data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
51 | #data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
52 | data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1)
53 | return data
54 |
55 |
56 | def extract_labels(filename, num_images):
57 | """Extract the labels into a vector of int64 label IDs."""
58 | print('Extracting', filename)
59 | with gzip.open(filename) as bytestream:
60 | bytestream.read(8)
61 | buf = bytestream.read(1 * num_images)
62 | labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
63 | return labels
64 |
65 | train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
66 | train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
67 | test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
68 | test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
69 |
70 | # Extract it into np arrays.
71 | train_data = extract_data(train_data_filename, 60000)
72 | train_labels = extract_labels(train_labels_filename, 60000)
73 | test_data = extract_data(test_data_filename, 10000)
74 | test_labels = extract_labels(test_labels_filename, 10000)
75 |
76 | if not os.path.isdir("mnist/train-images"):
77 | os.makedirs("mnist/train-images")
78 |
79 | if not os.path.isdir("mnist/test-images"):
80 | os.makedirs("mnist/test-images")
81 |
82 | # process train data
83 | with open("mnist/train-labels.csv", 'wb') as csvFile:
84 | writer = csv.writer(csvFile, delimiter=',', quotechar='"')
85 | for i in range(len(train_data)):
86 | imsave("mnist/train-images/" + str(i) + ".jpg", train_data[i][:,:,0])
87 | writer.writerow(["train-images/" + str(i) + ".jpg", train_labels[i]])
88 |
89 | # repeat for test data
90 | with open("mnist/test-labels.csv", 'wb') as csvFile:
91 | writer = csv.writer(csvFile, delimiter=',', quotechar='"')
92 | for i in range(len(test_data)):
93 | imsave("mnist/test-images/" + str(i) + ".jpg", test_data[i][:,:,0])
94 | writer.writerow(["test-images/" + str(i) + ".jpg", test_labels[i]])
--------------------------------------------------------------------------------
/mnist/mnist_test.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from torch import optim
6 | import numpy as np
7 | from torch.autograd import Variable
8 | import random
9 | from mnist import loadMnistRot, random_rotation, linear_interpolation_2D
10 |
11 | import sys
12 | sys.path.append('../') #Make layers_2D and utils importable from the parent folder
13 | from layers_2D import *
14 | from utils import getGrid
15 |
16 |
17 | __author__ = "Anders U. Waldeland"
18 | __email__ = "anders@nr.no"
19 |
20 | """
21 | A reproduction of the MNIST-classification network described in:
22 | Rotation equivariant vector field networks (ICCV 2017)
23 | Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia
24 | https://arxiv.org/abs/1612.09346
25 | https://github.com/dmarcosg/RotEqNet
26 | """
27 |
28 |
29 | if __name__ == '__main__':
30 |
31 | # Define network
32 | class Net(nn.Module):
33 | def __init__(self):
34 | super(Net, self).__init__()
35 |
36 | self.main = nn.Sequential(
37 |
38 | RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1),
39 | VectorMaxPool(2),
40 | VectorBatchNorm(6),
41 |
42 | RotConv(6, 16, [9, 9], 1, 9 // 2, n_angles=17, mode=2),
43 | VectorMaxPool(2),
44 | VectorBatchNorm(16),
45 |
46 | RotConv(16, 32, [9, 9], 1, 1, n_angles=17, mode=2),
47 | Vector2Magnitude(),
48 |
49 | nn.Conv2d(32, 128, 1), # FC1
50 | nn.BatchNorm2d(128),
51 | nn.ReLU(),
52 | nn.Dropout2d(0.7),
53 | nn.Conv2d(128, 10, 1), # FC2
54 |
55 | )
56 |
57 | def forward(self, x):
58 | x = self.main(x)
59 | x = x.view(x.size()[0], x.size()[1])
60 |
61 | return x
62 |
63 |
64 | gpu_no = 0 # Set to False for cpu-version
65 |
66 | #Setup net, loss function, optimizer and hyper parameters
67 | net = Net()
68 | criterion = nn.CrossEntropyLoss()
69 | if type(gpu_no) == int:
70 | net.cuda(gpu_no)
71 |
72 | if True: #Current best setup using this implementation - error rate of 1.2%
73 | start_lr = 0.01
74 | batch_size = 128
75 | optimizer = optim.Adam(net.parameters(), lr=start_lr) # , weight_decay=0.01)
76 | use_test_time_augmentation = True
77 | use_train_time_augmentation = True
78 |
79 | if False: #From paper using MATLAB implementation - reported error rate of 1.4%
80 | start_lr = 0.1
81 | batch_size = 600
82 | optimizer = optim.SGD(net.parameters(), lr=start_lr, weight_decay=0.01)
83 | use_test_time_augmentation = True
84 | use_train_time_augmentation = True
85 |
86 |
87 | def rotate_im(im, theta):
88 | grid = getGrid([28, 28])
89 | grid = rotate_grid_2D(grid, theta)
90 | grid += 13.5
91 | data = linear_interpolation_2D(im, grid)
92 | data = np.reshape(data, [28, 28])
93 | return data.astype('float32')
94 |
95 |
96 | def test(model, dataset, mode):
97 |         """ Return test-accuracy for a dataset """
98 | model.eval()
99 |
100 | true = []
101 | pred = []
102 |         for batch_no in range(len(dataset) // batch_size):
103 | data, labels = getBatch(dataset, mode)
104 |
105 | #Run same sample with different orientations through network and average output
106 | if use_test_time_augmentation and mode == 'test':
107 | data = data.cpu()
108 | original_data = data.clone().data.cpu().numpy()
109 |
110 | out = None
111 |                 rotations = [0, 15, 30, 45, 60, 75, 90]
112 |
113 | for rotation in rotations:
114 |
115 | for i in range(batch_size):
116 | im = original_data[i,:,:,:].squeeze()
117 | im = rotate_im(im, rotation)
118 | im = im.reshape([1, 1, 28, 28])
119 | im = torch.FloatTensor(im)
120 | data[i,:,:,:] = im
121 |
122 | if type(gpu_no) == int:
123 | data = data.cuda(gpu_no)
124 |
125 |                     if out is None:
126 |                         out = F.softmax(model(data), dim=1)
127 |                     else:
128 |                         out += F.softmax(model(data), dim=1)
129 |
130 | out /= len(rotations)
131 |
132 | #Only run once
133 | else:
134 | out = F.softmax(model(data),dim=1)
135 |
136 |             # out now holds class probabilities; the prediction is the argmax
137 | _, c = torch.max(out, 1)
138 | true.append(labels.data.cpu().numpy())
139 | pred.append(c.data.cpu().numpy())
140 | true = np.concatenate(true, 0)
141 | pred = np.concatenate(pred, 0)
142 | acc = np.average(pred == true)
143 | return acc
144 |
145 | def getBatch(dataset, mode):
146 | """ Collect a batch of samples from list """
147 |
148 | # Make batch
149 | data = []
150 | labels = []
151 | for sample_no in range(batch_size):
152 | tmp = dataset.pop() # Get top element and remove from list
153 | img = tmp[0].astype('float32').squeeze()
154 |
155 | # Train-time random rotation
156 | if mode == 'train' and use_train_time_augmentation:
157 | img = random_rotation(img)
158 |
159 | data.append(np.expand_dims(np.expand_dims(img, 0), 0))
160 | labels.append(tmp[1].squeeze())
161 | data = np.concatenate(data, 0)
162 | labels = np.array(labels, 'int32')
163 |
164 | data = Variable(torch.from_numpy(data))
165 | labels = Variable(torch.from_numpy(labels).long())
166 |
167 | if type(gpu_no) == int:
168 | data = data.cuda(gpu_no)
169 | labels = labels.cuda(gpu_no)
170 |
171 | return data, labels
172 |
173 | def adjust_learning_rate(optimizer, epoch):
174 | """Gradually decay learning rate"""
175 | if epoch == 20:
176 | lr = start_lr / 10
177 | for param_group in optimizer.param_groups:
178 | param_group['lr'] = lr
179 | if epoch == 40:
180 | lr = start_lr / 100
181 | for param_group in optimizer.param_groups:
182 | param_group['lr'] = lr
183 | if epoch == 60:
184 |             lr = start_lr / 1000
185 | for param_group in optimizer.param_groups:
186 | param_group['lr'] = lr
187 |
188 | #Load datasets
189 | train_set, val_set, test_set = loadMnistRot()
190 |
191 | best_acc = 0
192 | for epoch_no in range(90):
193 |
194 | #Random order for each epoch
195 | train_set_for_epoch = train_set[:] #Make a copy
196 | random.shuffle(train_set_for_epoch) #Shuffle the copy
197 |
198 | #Training
199 | net.train()
200 |         for batch_no in range(len(train_set)//batch_size):
201 |
202 | # Train
203 | optimizer.zero_grad()
204 |
205 | data, labels = getBatch(train_set_for_epoch, 'train')
206 | out = net( data )
207 | loss = criterion( out,labels )
208 | _, c = torch.max(out, 1)
209 | loss.backward()
210 |
211 | optimizer.step()
212 |
213 | #Print training-acc
214 | if batch_no%10 == 0:
215 | print('Train', 'epoch:', epoch_no, ' batch:', batch_no, ' loss:', loss.data.cpu().numpy()[0], ' acc:', np.average((c == labels).data.cpu().numpy()))
216 |
217 |
218 | #Validation
219 | acc = test(net, val_set[:], 'val')
220 | print('Val', 'epoch:', epoch_no, ' acc:', acc)
221 |
222 | #Save model if better than previous
223 | if acc > best_acc:
224 | torch.save(net.state_dict(), 'best_model.pt')
225 | best_acc = acc
226 | print('Model saved')
227 |
228 | adjust_learning_rate(optimizer, epoch_no)
229 |
230 | # Finally test on test-set with the best model
231 | net.load_state_dict(torch.load('best_model.pt'))
232 | print('Test', 'acc:', test(net, test_set[:], 'test'))
233 |
234 |
235 |
236 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 | from scipy.linalg import expm, norm
3 | import collections
4 | import itertools
5 | import numpy as np
6 | from torch.autograd import Variable
7 | import torch
8 |
9 | def ntuple(n):
10 | """ Ensure that input has the correct number of elements """
11 | def parse(x):
12 | if isinstance(x, collections.Iterable):
13 | return x
14 | return tuple(itertools.repeat(x, n))
15 | return parse
16 |
17 | def getGrid(siz):
18 | """ Returns grid with coordinates from -siz[0]/2 : siz[0]/2, -siz[1]/2 : siz[1]/2, ...."""
19 | space = [np.linspace( -(N/2), (N/2), N ) for N in siz]
20 | mesh = np.meshgrid( *space, indexing='ij' )
21 | mesh = [np.expand_dims( ax.ravel(), 0) for ax in mesh]
22 |
23 | return np.concatenate(mesh)
24 |
25 | def rotate_grid_2D(grid, theta):
26 | """ Rotate grid """
27 | theta = np.deg2rad(theta)
28 |
29 | x0 = grid[0, :] * np.cos(theta) - grid[1, :] * np.sin(theta)
30 | x1 = grid[0, :] * np.sin(theta) + grid[1, :] * np.cos(theta)
31 |
32 | grid[0, :] = x0
33 | grid[1, :] = x1
34 | return grid
35 |
36 | def rotate_grid_3D(theta, axis, grid):
37 | """ Rotate grid """
38 | theta = np.deg2rad(theta)
39 | axis = np.array(axis)
40 | rot_mat = expm(np.cross(np.eye(3), axis / norm(axis) * theta))
41 |     rot_mat = np.expand_dims(rot_mat, 2)
42 | grid = np.transpose( np.expand_dims(grid,2), [0,2,1])
43 |
44 | return np.einsum('ijk,jik->ik',rot_mat,grid)
45 |
46 |
47 | def get_filter_rotation_transforms(kernel_dims, angles):
48 | """ Return the interpolation variables needed to transform a filter by a given number of degrees """
49 |
50 | dim = len(kernel_dims)
51 |
52 | # Make grid (centered around filter-center)
53 | grid = getGrid(kernel_dims)
54 |
55 | # Rotate grid
56 | if dim == 2:
57 | grid = rotate_grid_2D(grid, angles)
58 | elif dim == 3:
59 | grid = rotate_grid_3D(angles[0], [1, 0, 0], grid)
60 | grid = rotate_grid_3D(angles[1], [0, 0, 1], grid)
61 |
62 |
63 | # Radius of filter
64 | radius = np.min((np.array(kernel_dims)-1) / 2.)
65 |
66 | #Mask out samples outside circle
67 | radius = np.expand_dims(radius,-1)
68 | dist_to_center = np.sqrt(np.sum(grid**2,axis=0))
69 |     mask = dist_to_center >= radius + .0001
70 |     mask = 1 - mask  # 1 inside the filter circle, 0 outside
71 |
72 | # Move grid to center
73 | grid += radius
74 |
75 | return compute_interpolation_grids(grid, kernel_dims, mask)
76 |
77 | def compute_interpolation_grids(grid, kernel_dims, mask):
78 |
79 | #######################################################
80 | # The following part is part of nd-linear interpolation
81 |
82 | #Add a small eps to grid so that floor and ceil operations become more stable
83 | grid += 0.000000001
84 |
85 | # Make list where each element represents a dimension
86 | grid = [grid[i, :] for i in range(grid.shape[0])]
87 |
88 | # Get left and right index (integers)
89 |     inds_0 = [ind.astype(int) for ind in grid]
90 | inds_1 = [ind + 1 for ind in inds_0]
91 |
92 | # Get weights
93 | weights = [float_ind - int_ind for float_ind, int_ind in zip(grid, inds_0)]
94 |
95 |     # Special case for when ind_1 == siz (i.e. ind_0 == siz - 1):
96 |     # the weight for ind_1 is then ~0, so we can safely point it at index 0
97 | ind_1_out_of_bounds = np.logical_or.reduce([ind == siz for ind, siz in zip(inds_1, kernel_dims)])
98 | for i in range(len(inds_1)):
99 | inds_1[i][ind_1_out_of_bounds] = 0
100 |
101 |
102 | # Get samples that are out of bounds or outside mask
103 |     inds_out_of_bounds = np.logical_or.reduce([ind < 0 for ind in itertools.chain(inds_0, inds_1)] +
104 |                                               [ind >= siz for ind, siz in zip(inds_0, kernel_dims)] +
105 |                                               [ind >= siz for ind, siz in zip(inds_1, kernel_dims)] +
106 |                                               [(1 - mask).astype('bool')]
107 |                                               )
108 |
109 |
110 |     # Point these samples at the upper-left corner; their contribution is zeroed by the mask
111 | for i in range(len(inds_0)):
112 | inds_0[i][inds_out_of_bounds] = 0
113 | inds_1[i][inds_out_of_bounds] = 0
114 |
115 | #Reshape
116 | inds_0 = [np.reshape(ind,[1,1]+kernel_dims) for ind in inds_0]
117 | inds_1 = [np.reshape(ind,[1,1]+kernel_dims) for ind in inds_1]
118 | weights = [np.reshape(weight,[1,1]+kernel_dims)for weight in weights]
119 |
120 | #Make pytorch-tensors of the interpolation variables
121 | inds_0 = [Variable(torch.LongTensor(ind)) for ind in inds_0]
122 | inds_1 = [Variable(torch.LongTensor(ind)) for ind in inds_1]
123 | weights = [Variable(torch.FloatTensor(weight)) for weight in weights]
124 |
125 | #Make mask pytorch tensor
126 | mask = mask.reshape(kernel_dims)
127 | mask = mask.astype('float32')
128 | mask = np.expand_dims(mask, 0)
129 | mask = np.expand_dims(mask, 0)
130 | mask = torch.FloatTensor(mask)
131 |
132 | # Uncomment for nearest interpolation (for debugging)
133 | #inds_1 = [ind*0 for ind in inds_1]
134 | #weights = [weight*0 for weight in weights]
135 |
136 | return inds_0, inds_1, weights, mask
137 |
138 | def apply_transform(filter, interp_vars, filters_size, old_bilinear_interpolation=True):
139 | """ Apply a transform specified by the interpolation_variables to a filter """
140 |
141 | dim = 2 if len(filter.size())==4 else 3
142 |
143 | if dim == 2:
144 |
145 |
146 | if old_bilinear_interpolation:
147 | [x0_0, x1_0], [x0_1, x1_1], [w0, w1] = interp_vars
148 | rotated_filter = (filter[:, :, x0_0, x1_0] * (1 - w0) * (1 - w1) +
149 | filter[:, :, x0_1, x1_0] * w0 * (1 - w1) +
150 | filter[:, :, x0_0, x1_1] * (1 - w0) * w1 +
151 | filter[:, :, x0_1, x1_1] * w0 * w1)
152 | else:
153 |
154 |             # Expand dimensions to fit the filter
155 | interp_vars = [[inner_el.expand_as(filter) for inner_el in outer_el] for outer_el in interp_vars]
156 |
157 | [x0_0, x1_0], [x0_1, x1_1], [w0, w1] = interp_vars
158 |
159 | a = torch.gather(torch.gather(filter, 2, x0_0), 3, x1_0) * (1 - w0) * (1 - w1)
160 |             b = torch.gather(torch.gather(filter, 2, x0_1), 3, x1_0) * w0 * (1 - w1)
161 |             c = torch.gather(torch.gather(filter, 2, x0_0), 3, x1_1) * (1 - w0) * w1
162 |             d = torch.gather(torch.gather(filter, 2, x0_1), 3, x1_1) * w0 * w1
163 | rotated_filter = a+b+c+d
164 |
165 | rotated_filter = rotated_filter.view(filter.size()[0],filter.size()[1],filters_size[0],filters_size[1])
166 |
167 | elif dim == 3:
168 | [x0_0, x1_0, x2_0], [x0_1, x1_1, x2_1], [w0, w1, w2] = interp_vars
169 |
170 |         rotated_filter = (filter[:, :, x0_0, x1_0, x2_0] * (1 - w0) * (1 - w1) * (1 - w2) +
171 |                           filter[:, :, x0_1, x1_0, x2_0] * w0 * (1 - w1) * (1 - w2) +
172 |                           filter[:, :, x0_0, x1_1, x2_0] * (1 - w0) * w1 * (1 - w2) +
173 |                           filter[:, :, x0_1, x1_1, x2_0] * w0 * w1 * (1 - w2) +
174 |                           filter[:, :, x0_0, x1_0, x2_1] * (1 - w0) * (1 - w1) * w2 +
175 |                           filter[:, :, x0_1, x1_0, x2_1] * w0 * (1 - w1) * w2 +
176 |                           filter[:, :, x0_0, x1_1, x2_1] * (1 - w0) * w1 * w2 +
177 |                           filter[:, :, x0_1, x1_1, x2_1] * w0 * w1 * w2)
178 |
179 | rotated_filter = rotated_filter.view(filter.size()[0], filter.size()[1], filters_size[0], filters_size[1], filters_size[2])
180 |
181 | return rotated_filter
182 |
183 |
184 |
185 | if __name__ == '__main__':
186 | """ Test rotation of filter """
187 | import torch.nn as nn
188 | from torch.nn import functional as F
189 | from torch.nn.parameter import Parameter
190 | import math
191 | from utils import *
192 |
193 | ks = [9,9] #Kernel size
194 | angle = 45
195 | interp_vars = get_filter_rotation_transforms(ks, angle)
196 |
197 | w = Variable(torch.ones([1,1]+ks))
198 | #w[:,:,4,:] = 5
199 | w[:, :, :, 4] = 5
200 | #w[:,:,0,0] = -1
201 |
202 |
203 | print(w)
204 |     for angle in [0, 90, 45, 180, 65, 10]:
205 |         print(angle, 'degrees')
206 |         interp_vars = get_filter_rotation_transforms(ks, angle)[:-1]
207 |         mask = Variable(get_filter_rotation_transforms(ks, angle)[-1])
208 |         old = apply_transform(w, interp_vars, ks, old_bilinear_interpolation=True) * mask
209 |         new = apply_transform(w, interp_vars, ks, old_bilinear_interpolation=False) * mask
210 |         print(old)
211 |         print('Difference', torch.sum(new - old))
210 |
--------------------------------------------------------------------------------
/layers_2D.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is a PyTorch implementation of the method proposed in:
3 | Rotation equivariant vector field networks (ICCV 2017)
4 | Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia
5 | https://arxiv.org/abs/1612.09346
6 | https://github.com/dmarcosg/RotEqNet (original code)
7 | """
8 | from __future__ import division
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from torch.nn.parameter import Parameter
12 | import math
13 | from utils import *
14 |
15 | class RotConv(nn.Module):
16 |
17 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
18 | padding=0, dilation=1, n_angles = 8, mode=1):
19 | super(RotConv, self).__init__()
20 |
21 | kernel_size = ntuple(2)(kernel_size)
22 | stride = ntuple(2)(stride)
23 | padding = ntuple(2)(padding)
24 | dilation = ntuple(2)(dilation)
25 |
26 | self.in_channels = in_channels
27 | self.out_channels = out_channels
28 | self.kernel_size = kernel_size
29 | self.stride = stride
30 | self.padding = padding
31 | self.dilation = dilation
32 |
33 | self.mode = mode
34 |
35 | #Angles
36 | self.angles = np.linspace(0,360,n_angles, endpoint=False)
37 | self.angle_tensors = []
38 |
39 | #Get interpolation variables
40 | self.interp_vars = []
41 | for angle in self.angles:
42 | out = get_filter_rotation_transforms(list(self.kernel_size), angle)
43 | self.interp_vars.append(out[:-1])
44 | self.mask = out[-1]
45 |
46 | self.angle_tensors.append( Variable(torch.FloatTensor( np.array([angle/ 180. * np.pi]) )) )
47 |
48 | self.weight1 = Parameter(torch.Tensor( out_channels, in_channels , *kernel_size))
49 | #If input is vector field, we have two filters (one for each component)
50 | if self.mode == 2:
51 | self.weight2 = Parameter(torch.Tensor( out_channels, in_channels, *kernel_size))
52 |
53 | self.reset_parameters()
54 |
55 | def reset_parameters(self):
56 | n = self.in_channels
57 | for k in self.kernel_size:
58 | n *= k
59 | stdv = 1. / math.sqrt(n)
60 | self.weight1.data.uniform_(-stdv, stdv)
61 | if self.mode == 2:
62 | self.weight2.data.uniform_(-stdv, stdv)
63 |
64 |     def mask_filters(self):
65 |         self.weight1.data[self.mask.expand_as(self.weight1) == 0] = 1e-8
66 |         if self.mode == 2:
67 |             self.weight2.data[self.mask.expand_as(self.weight2) == 0] = 1e-8
68 |
69 |     def _apply(self, func):
70 |         # This is called whenever the user calls model.cuda()
71 |         # We intercept it to replace tensors and variables with cuda-versions
72 | self.mask = func(self.mask)
73 | self.interp_vars = [[[func(el2) for el2 in el1] for el1 in el0] for el0 in self.interp_vars]
74 | self.angle_tensors = [func(el) for el in self.angle_tensors]
75 |
76 | return super(RotConv, self)._apply(func)
77 |
78 |
79 | def forward(self,input):
80 | #Uncomment this to turn on filter-masking
81 | #Todo: fix broken convergence when filter-masking is on
82 | #self.mask_filters()
83 |
84 | if self.mode == 1:
85 | outputs = []
86 |
87 | #Loop through the different filter-transformations
88 | for ind, interp_vars in enumerate(self.interp_vars):
89 | #Apply rotation
90 | weight = apply_transform(self.weight1, interp_vars, self.kernel_size)
91 |
92 | #Do convolution
93 | out = F.conv2d(input, weight, None, self.stride, self.padding, self.dilation)
94 | outputs.append(out.unsqueeze(-1))
95 |
96 | if self.mode == 2:
97 | u = input[0]
98 | v = input[1]
99 |
100 | outputs = []
101 | # Loop through the different filter-transformations
102 | for ind, interp_vars in enumerate(self.interp_vars):
103 | angle = self.angle_tensors[ind]
104 | # Apply rotation
105 | wu = apply_transform(self.weight1, interp_vars, self.kernel_size)
106 | wv = apply_transform(self.weight2, interp_vars, self.kernel_size)
107 |
108 | # Do convolution for u
109 |                 wru = torch.cos(angle) * wu - torch.sin(angle) * wv
110 | u_out = F.conv2d(u, wru, None, self.stride, self.padding, self.dilation)
111 |
112 | # Do convolution for v
113 | wrv = torch.sin(angle) * wu + torch.cos(angle) * wv
114 | v_out = F.conv2d(v, wrv, None, self.stride, self.padding, self.dilation)
115 |
116 |                 #Sum the two component responses to get the scalar response for this orientation
117 |                 outputs.append((u_out + v_out).unsqueeze(-1))
118 |
119 |
120 | # Get the maximum direction (Orientation Pooling)
121 | strength, max_ind = torch.max(torch.cat(outputs, -1), -1)
122 |
123 |         # Convert back from polar representation
124 | angle_map = max_ind.float() * (360. / len(self.angles) / 180. * np.pi)
125 | u = F.relu(strength) * torch.cos(angle_map)
126 | v = F.relu(strength) * torch.sin(angle_map)
127 |
128 |
129 | return u, v
130 |
131 | class VectorMaxPool(nn.Module):
132 | def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
133 | ceil_mode=False):
134 | super(VectorMaxPool, self).__init__()
135 | self.kernel_size = kernel_size
136 | self.stride = stride or kernel_size
137 | self.padding = padding
138 | self.dilation = dilation
139 | self.ceil_mode = ceil_mode
140 |
141 | def forward(self,input):
142 | #Assuming input is vector field
143 | u = input[0]
144 | v = input[1]
145 |
146 | #Magnitude
147 | p = torch.sqrt( v**2 + u**2)
148 | #Max pool
149 | _, max_inds = F.max_pool2d(p, self.kernel_size, self.stride,
150 | self.padding, self.dilation, self.ceil_mode,
151 | return_indices=True)
152 | #Reshape to please pytorch
153 | s1 = u.size()
154 | s2 = max_inds.size()
155 |
156 | max_inds = max_inds.view(s1[0], s1[1], s2[2] * s2[3])
157 |
158 | u = u.view(s1[0], s1[1], s1[2] * s1[3])
159 | v = v.view(s1[0], s1[1], s1[2] * s1[3])
160 |
161 | #Select u/v components according to max pool on magnitude
162 | u = torch.gather(u, 2, max_inds)
163 | v = torch.gather(v, 2, max_inds)
164 |
165 | #Reshape back
166 | u = u.view(s1[0], s1[1], s2[2], s2[3])
167 | v = v.view(s1[0], s1[1], s2[2], s2[3])
168 |
169 | return u,v
170 |
171 | class Vector2Magnitude(nn.Module):
172 | def __init__(self):
173 | super(Vector2Magnitude, self).__init__()
174 |
175 | def forward(self, input):
176 | u = input[0]
177 | v = input[1]
178 |
179 | p = torch.sqrt(v ** 2 + u ** 2)
180 | return p
181 |
182 | class Vector2Angle(nn.Module):
183 | def __init__(self):
184 | super(Vector2Angle, self).__init__()
185 |
186 | def forward(self, input):
187 | u = input[0]
188 | v = input[1]
189 |
190 |         angle = torch.atan2(v, u)  # since u = p*cos(angle), v = p*sin(angle)
191 |
192 | return angle
193 |
194 | class VectorBatchNorm(nn.Module):
195 | def __init__(self, num_features, eps=1e-5, momentum=0.5, affine=True):
196 |
197 | super(VectorBatchNorm, self).__init__()
198 | self.num_features = num_features
199 | self.affine = affine
200 | self.eps = eps
201 | self.momentum = momentum
202 |
203 | if self.affine:
204 | self.weight = Parameter(torch.Tensor(1,num_features,1,1))
205 | else:
206 | self.register_parameter('weight', None)
207 | self.register_buffer('running_var', torch.ones(1,num_features,1,1))
208 | self.reset_parameters()
209 |
210 |
211 | def reset_parameters(self):
212 | self.running_var.fill_(1)
213 | if self.affine:
214 | self.weight.data.uniform_()
215 |
216 | def forward(self, input):
217 | """
218 | Based on https://github.com/lberrada/bn.pytorch
219 | """
220 | if self.training:
221 | #Compute std
222 | std = self.std(input)
223 |
224 | alpha = self.weight / (std + self.eps)
225 |
226 | # update running variance
227 | self.running_var *= (1. - self.momentum)
228 | self.running_var += self.momentum * std.data ** 2
229 | # compute output
230 | u = input[0] * alpha
231 | v = input[1] * alpha
232 |
233 | else:
234 | alpha = self.weight.data / torch.sqrt(self.running_var + self.eps)
235 |
236 | # compute output
237 | u = input[0] * Variable(alpha)
238 | v = input[1] * Variable(alpha)
239 | return u,v
240 |
241 | def std(self, input):
242 | u = input[0]
243 | v = input[1]
244 |
245 | #Vector to magnitude
246 | p = torch.sqrt(u ** 2 + v ** 2)
247 |
248 |         #We want to normalize the vector magnitudes,
249 |         #therefore we omit the mean (var = (p - p.mean())**2),
250 |         #since we do not want to move the center of the vectors.
251 |
252 |         var = p**2
253 | var = torch.mean(var, 0, keepdim=True)
254 | var = torch.mean(var, 2, keepdim=True)
255 | var = torch.mean(var, 3, keepdim=True)
256 | std = torch.sqrt(var)
257 |
258 | return std
259 |
260 | class VectorUpsampling(nn.Module):
261 | def __init__(self, size=None, scale_factor=None, mode = 'bilinear'):
262 | super(VectorUpsampling, self).__init__()
263 | self.size = size
264 | self.scale_factor = scale_factor
265 | self.mode = mode
266 |
267 | def forward(self, input):
268 | # Assuming input is vector field
269 | u = input[0]
270 | v = input[1]
271 |
272 | u = F.upsample(u, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
273 | v = F.upsample(v, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
274 |
275 |
276 | return u, v
277 |
--------------------------------------------------------------------------------
/layers_3D.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | __author__ = "Anders U. Waldeland"
3 | __email__ = "anders@nr.no"
4 |
5 |
6 | """
7 | This code is a 3D extension of the 2D method proposed in:
8 | Rotation equivariant vector field networks (ICCV 2017)
9 | Diego Marcos, Michele Volpi, Nikos Komodakis, Devis Tuia
10 | https://arxiv.org/abs/1612.09346
11 | https://github.com/dmarcosg/RotEqNet
12 |
13 | We use the spherical coordinate system (see https://en.wikipedia.org/wiki/Spherical_coordinate_system) with
14 | coordinates (r/radius, theta/inclination, phi/azimuth). The 3D vector field has cartesian coordinates (x, y, z),
15 | denoted (u, v, w) in correspondence with the original paper: u = r*sin(theta)*cos(phi), v = r*sin(theta)*sin(phi), w = r*cos(theta).
16 | """
17 |
18 | import torch.nn as nn
19 | from torch.nn import functional as F
20 | from torch.nn.parameter import Parameter
21 | import math
22 | from utils import *
23 |
24 |
25 |
26 |
27 |
28 | class RotConv(nn.Module):
29 |
30 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
31 | padding=0, dilation=1, n_inclination = 8, n_azimuth = 4, mode=1):
32 | super(RotConv, self).__init__()
33 |
34 | kernel_size = ntuple(3)(kernel_size)
35 | stride = ntuple(3)(stride)
36 | padding = ntuple(3)(padding)
37 | dilation = ntuple(3)(dilation)
38 |
39 | self.in_channels = in_channels
40 | self.out_channels = out_channels
41 | self.kernel_size = kernel_size
42 | self.stride = stride
43 | self.padding = padding
44 | self.dilation = dilation
45 |
46 | self.mode = mode
47 |
48 | #If input is vector field we have two filters (one for each component)
49 | self.weight1 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
50 | if self.mode == 2:
51 | self.weight2 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
52 | self.weight3 = Parameter(torch.Tensor(out_channels, in_channels, *kernel_size))
53 |
54 |
55 |         #Angles (inclination and azimuth)
56 | self.thetas = np.linspace(0, 180, n_inclination, endpoint=False)
57 | self.phis = np.linspace(0, 360, n_azimuth, endpoint=False)
58 | self.theta_tensors = []
59 | self.phi_tensors = []
60 |
61 | #Get interpolation variables
62 | self.interp_vars = []
63 | for theta in self.thetas:
64 | for phi in self.phis:
65 |                 out = get_filter_rotation_transforms(list(self.kernel_size), [theta, phi])
66 |                 self.interp_vars.append(out[:-1])  # interpolation variables without the mask
67 |                 self.mask = out[-1]
68 | self.theta_tensors.append(Variable(torch.FloatTensor([theta / 180. * np.pi])))
69 |
70 | self.phi_tensors.append(Variable(torch.FloatTensor([phi / 180. * np.pi])))
71 |
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | n = self.in_channels
76 | for k in self.kernel_size:
77 | n *= k
78 | stdv = 1. / math.sqrt(n)
79 | self.weight1.data.uniform_(-stdv, stdv)
80 | if self.mode == 2:
81 | self.weight2.data.uniform_(-stdv, stdv)
82 | self.weight3.data.uniform_(-stdv, stdv)
83 |
84 |     def _apply(self, l):
85 |         # This is called whenever the user calls model.cuda()
86 |         # We intercept it to replace tensors and variables with cuda-versions
87 |         self.interp_vars = [[[l(el2) for el2 in el1] for el1 in el0] for el0 in self.interp_vars]
88 |         self.theta_tensors = [l(el) for el in self.theta_tensors]
89 |         self.phi_tensors = [l(el) for el in self.phi_tensors]
90 |         self.mask = l(self.mask)
91 |         return super(RotConv, self)._apply(l)
92 |
93 |
94 | def forward(self,input):
95 |
96 | if self.mode == 1:
97 | outputs = []
98 |
99 | #Loop through the different filter-transformations
100 | for ind, interp_vars in enumerate(self.interp_vars):
101 | #Apply rotation
102 | weight = apply_transform(self.weight1, interp_vars, self.kernel_size)
103 |
104 | #Do convolution
105 | out = F.conv3d(input, weight, None, self.stride, self.padding, self.dilation)
106 | outputs.append(out.unsqueeze(-1))
107 |
108 | #Get the maximum direction (Orientation Pooling)
109 | strength, max_ind = torch.max(torch.cat(outputs,-1),-1)
110 |
111 |             # Convert the flat orientation index back to (theta, phi) angles
112 |             theta_ind = max_ind / len(self.phis)  # integer division
113 |             theta = theta_ind.float() * (np.pi / len(self.thetas))
114 |             phi = (max_ind - theta_ind * len(self.phis)).float() * (2 * np.pi / len(self.phis))
115 | u = F.relu(strength) * torch.sin(theta) * torch.cos(phi)
116 | v = F.relu(strength) * torch.sin(theta) * torch.sin(phi)
117 | w = F.relu(strength) * torch.cos(theta)
118 |
119 |
120 | if self.mode == 2:
121 | u = input[0]
122 | v = input[1]
123 | w = input[2]
124 |
125 | output_u = []
126 | output_v = []
127 | output_w = []
128 | output_p = [] #magnitude of field
129 |
130 | # Loop through the different filter-transformations
131 | for ind, interp_vars in enumerate(self.interp_vars):
132 | theta = self.theta_tensors[ind]
133 | phi = self.phi_tensors[ind]
134 |
135 | # Apply rotation
136 | wu = apply_transform(self.weight1, interp_vars, self.kernel_size)
137 | wv = apply_transform(self.weight2, interp_vars, self.kernel_size)
138 | ww = apply_transform(self.weight3, interp_vars, self.kernel_size)
139 |
140 |
141 | # Do convolution for u
142 | wru = None#TODO: decompose filters
143 | u_out = F.conv3d(u, wru, None, self.stride, self.padding, self.dilation)
144 | output_u.append(u_out.unsqueeze(-1) )
145 |
146 | # Do convolution for v
147 | wrv = None # TODO: decompose filters
148 | v_out = F.conv3d(v, wrv, None, self.stride, self.padding, self.dilation)
149 | output_v.append(v_out.unsqueeze(-1) )
150 |
151 | # Do convolution for w
152 | wrw = None # TODO: decompose filters
153 | w_out = F.conv3d(w, wrw, None, self.stride, self.padding, self.dilation)
154 | output_w.append(w_out.unsqueeze(-1))
155 |
156 | #Compute magnitude (p)
157 | output_p.append( torch.sqrt( v_out**2 + u_out**2 + w_out**2).unsqueeze(-1) )
158 |
159 |
160 |
161 | # Get the maximum direction (Orientation Pooling)
162 | strength, max_ind = torch.max(torch.cat(output_p, -1), -1)
163 |
164 | # Select the u,v for the maximum orientation
165 | u = torch.cat(output_u, -1)
166 | v = torch.cat(output_v, -1)
167 | w = torch.cat(output_w, -1)
168 |
169 | u = torch.gather(u, -1, max_ind.unsqueeze(-1))[:, :, :, :, :, 0]
170 | v = torch.gather(v, -1, max_ind.unsqueeze(-1))[:, :, :, :, :, 0]
171 | w = torch.gather(w, -1, max_ind.unsqueeze(-1))[:, :, :, :, :, 0]
172 |
173 | return u, v, w
174 |
175 | class VectorMaxPool(nn.Module):
176 | def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
177 | ceil_mode=False):
178 | super(VectorMaxPool, self).__init__()
179 | self.kernel_size = kernel_size
180 | self.stride = stride or kernel_size
181 | self.padding = padding
182 | self.dilation = dilation
183 | self.ceil_mode = ceil_mode
184 |
185 | def forward(self,input):
186 | #Assuming input is vector field
187 | u = input[0]
188 | v = input[1]
189 | w = input[2]
190 |
191 | #Magnitude
192 | p = torch.sqrt( v**2 + u**2 + w**2)
193 | #Max pool
194 | _, max_inds = F.max_pool3d(p, self.kernel_size, self.stride,
195 | self.padding, self.dilation, self.ceil_mode,
196 | return_indices=True)
197 | #Reshape to please pytorch
198 | s1 = u.size()
199 | s2 = max_inds.size()
200 |
201 | max_inds = max_inds.view(s1[0], s1[1], s2[2] * s2[3] * s2[4])
202 |
203 | u = u.view(s1[0], s1[1], s1[2] * s1[3] * s1[4])
204 | v = v.view(s1[0], s1[1], s1[2] * s1[3] * s1[4])
205 | w = w.view(s1[0], s1[1], s1[2] * s1[3] * s1[4])
206 |
207 | #Select u/v components according to max pool on magnitude
208 | u = torch.gather(u, 2, max_inds)
209 | v = torch.gather(v, 2, max_inds)
210 | w = torch.gather(w, 2, max_inds)
211 |
212 | #Reshape back
213 |         u = u.view(s1[0], s1[1], s2[2], s2[3], s2[4])
214 |         v = v.view(s1[0], s1[1], s2[2], s2[3], s2[4])
215 |         w = w.view(s1[0], s1[1], s2[2], s2[3], s2[4])
216 |
217 |
218 | return u,v,w
219 |
220 | class Vector2Magnitude(nn.Module):
221 | def __init__(self):
222 | super(Vector2Magnitude, self).__init__()
223 |
224 | def forward(self, input):
225 | u = input[0]
226 | v = input[1]
227 | w = input[2]
228 |
229 | p = torch.sqrt(v ** 2 + u ** 2 + w ** 2)
230 | return p
231 |
232 | class VectorBatchNorm(nn.Module):
233 | def __init__(self):
234 | super(VectorBatchNorm, self).__init__()
235 |
236 | def forward(self, input):
237 | if input[0].size()[0] > 1:
238 | u = input[0]
239 | v = input[1]
240 | w = input[2]
241 |
242 |
243 | p = torch.sqrt(v ** 2 + u ** 2 + w ** 2)
244 |
245 | #Mean
246 | mu = torch.mean(p, 0, keepdim=True)
247 | mu = torch.mean(mu, 2, keepdim=True)
248 | mu = torch.mean(mu, 3, keepdim=True)
249 | mu = torch.mean(mu, 4, keepdim=True)
250 |
251 |             #Variance (of the magnitudes, averaged over batch and spatial dims)
252 |             var = (mu - p)**2
253 |             var = torch.mean(var, 0, keepdim=True)
254 |             var = torch.mean(var, 2, keepdim=True)
255 |             var = torch.mean(var, 3, keepdim=True)
256 |             var = torch.mean(var, 4, keepdim=True)
257 |             std = torch.sqrt(var)
258 |
259 | eps = 0.00001
260 | std = std + eps
261 |
262 |
263 | return u/std, v/std , w/std
264 | else:
265 | return input
266 |
267 |
268 | class VectorUpsampling(nn.Module):
269 | def __init__(self, size=None, scale_factor=None, mode = 'trilinear'):
270 | super(VectorUpsampling, self).__init__()
271 | self.size = size
272 | self.scale_factor = scale_factor
273 | self.mode = mode
274 |
275 | def forward(self, input):
276 | # Assuming input is vector field
277 | u = input[0]
278 | v = input[1]
279 | w = input[2]
280 |
281 | u = F.upsample(u, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
282 | v = F.upsample(v, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
283 | w = F.upsample(w, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
284 |
285 |
286 | return u, v, w
287 |
--------------------------------------------------------------------------------