├── .gitignore
├── README.md
├── example.py
├── setup.py
└── spherenet
├── __init__.py
├── dataset.py
└── sphere_cnn.py
/.gitignore:
--------------------------------------------------------------------------------
# Compiled python modules.
*.pyc

# Setuptools distribution folder.
/dist/

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SphereNet-pytorch
This is an unofficial implementation of the ECCV 2018 [paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Benjamin_Coors_SphereNet_Learning_Spherical_ECCV_2018_paper.pdf) **"SphereNet: Learning Spherical Representations for Detection and Classification in Omnidirectional Images"**.

Currently only SphereNet's 3x3 Conv2D and MaxPool2D are implemented. For now only `mode=bilinear` is supported; `mode=nearest` has to wait until PyTorch 1.0. You can replace any model's Conv2D/MaxPool2D layers with SphereNet's versions; they are implemented such that pretrained weights load directly into SphereNet's layers.

## Requirements
- python3
- pytorch>=0.4.1
- numpy
- scipy

## Installation
Copy the `spherenet` directory into your project. If you want to install it as a package so that you can import `spherenet` everywhere:
```
cd $YOUR_CLONED_SPHERENET_DIR
pip install .
```

## Example
``` python
import torch
from spherenet import SphereConv2D, SphereMaxPool2D

conv1 = SphereConv2D(1, 32, stride=1)
pool1 = SphereMaxPool2D(stride=2)

# toy example
img = torch.randn(1, 1, 60, 60)  # (batch, channel, height, width)
out = conv1(img)  # (1, 32, 60, 60)
out = pool1(out)  # (1, 32, 30, 30)
```
- To apply SphereNet to your trained model, simply replace `nn.Conv2d` with `SphereConv2D` and `nn.MaxPool2d` with `SphereMaxPool2D`. They should work well with `load_state_dict`, for example:
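A minimal sketch (the `MyNet` module and checkpoint path are hypothetical placeholders), assuming the original model used 3x3 convolutions:

``` python
import torch
from torch import nn
from spherenet import SphereConv2D, SphereMaxPool2D

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = SphereConv2D(1, 32, stride=1)  # was: nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool1 = SphereMaxPool2D(stride=2)      # was: nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        return self.pool1(self.conv1(x))

model = MyNet()
# Weights trained with the original nn.Conv2d version load directly,
# because SphereConv2D keeps the same (out_c, in_c, 3, 3) weight shape.
model.load_state_dict(torch.load('pretrained.pkl'))  # hypothetical checkpoint
```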
## Results
- Classification on OmniMNIST data (`spherenet.OmniMNIST`, `spherenet.OmniFashionMNIST`)
- Reproduced OmniMNIST results:

  | Method | Test Error (%) |
  | ------------------- |:--------------:|
  | SphereNet (paper)   | 5.59 |
  | SphereNet (ours)    | 5.77 |
  | EquirectCNN (paper) | 9.61 |
  | EquirectCNN (ours)  | 9.63 |
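The table above can be reproduced with the training script in this repo; below is a run with its default hyperparameters (a sketch, the exact settings behind the reported numbers are not recorded here):

```
python example.py --data MNIST --optimizer adam --lr 1e-4
```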

## References
- [paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Benjamin_Coors_SphereNet_Learning_Spherical_ECCV_2018_paper.pdf)
  - Benjamin Coors, Alexandru Paul Condurache, Andreas Geiger
  - ECCV 2018
```
@inproceedings{coors2018spherenet,
  title={SphereNet: Learning Spherical Representations for Detection and Classification in Omnidirectional Images},
  author={Coors, Benjamin and Condurache, Alexandru Paul and Geiger, Andreas},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  pages={518--533},
  year={2018}
}
```
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import argparse
from spherenet import OmniMNIST, OmniFashionMNIST
from spherenet import SphereConv2D, SphereMaxPool2D
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class SphereNet(nn.Module):
    def __init__(self):
        super(SphereNet, self).__init__()
        self.conv1 = SphereConv2D(1, 32, stride=1)
        self.pool1 = SphereMaxPool2D(stride=2)
        self.conv2 = SphereConv2D(32, 64, stride=1)
        self.pool2 = SphereMaxPool2D(stride=2)

        # input 60x60 -> pool1 30x30 -> pool2 15x15, so 64*15*15 = 14400
        self.fc = nn.Linear(14400, 10)

    def forward(self, x):
        x = F.relu(self.pool1(self.conv1(x)))
        x = F.relu(self.pool2(self.conv2(x)))
        x = x.view(-1, 14400)  # flatten, (B, C, H, W) -> (B, C*H*W)
        x = self.fc(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        # input 60x60 -> conv1 58x58 -> pool 29x29 -> conv2 27x27 -> pool 13x13
        self.fc = nn.Linear(64*13*13, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64*13*13)  # flatten, (B, C, H, W) -> (B, C*H*W)
        x = self.fc(x)
        return x


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if data.dim() == 3:
            data = data.unsqueeze(1)  # (B, H, W) -> (B, C, H, W)
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if data.dim() == 3:
                data = data.unsqueeze(1)  # (B, H, W) -> (B, C, H, W)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data', type=str, default='MNIST',
                        choices=['MNIST', 'FashionMNIST'],
                        help='dataset for training')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='optimizer')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed')
    parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='how many epochs to wait before saving model weights')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if use_cuda else 'cpu')

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    np.random.seed(args.seed)
    if args.data == 'FashionMNIST':
        train_dataset = OmniFashionMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, img_std=255, train=True)
        test_dataset = OmniFashionMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, img_std=255, train=False, fix_aug=True)
    else:  # MNIST
        train_dataset = OmniMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, train=True)
        test_dataset = OmniMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, train=False, fix_aug=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    # Train
    sphere_model = SphereNet().to(device)
    model = Net().to(device)
    if args.optimizer == 'adam':
        sphere_optimizer = torch.optim.Adam(sphere_model.parameters(), lr=args.lr)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:  # sgd
        sphere_optimizer = torch.optim.SGD(sphere_model.parameters(), lr=args.lr, momentum=args.momentum)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        # SphereCNN
        print('{} Sphere CNN {}'.format('='*10, '='*10))
        train(args, sphere_model, device, train_loader, sphere_optimizer, epoch)
        test(args, sphere_model, device, test_loader)
        if epoch % args.save_interval == 0:
            torch.save(sphere_model.state_dict(), 'sphere_model.pkl')

        # Conventional CNN
        print('{} Conventional CNN {}'.format('='*10, '='*10))
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        if epoch % args.save_interval == 0:
            torch.save(model.state_dict(), 'model.pkl')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(name='spherenet',
      version='0.1',
      description='Pytorch implementation of ECCV 2018 paper: '
                  'SphereNet: Learning Spherical Representations '
                  'for Detection and Classification in Omnidirectional Images',
      url='https://github.com/ChiWeiHsiao/SphereNet-pytorch',
      author='Chi-Wei Hsiao, Cheng Sun',
      author_email='kiwi010379@gmail.com, chengsun@gapp.nthu.edu.tw',
      license='MIT',
      packages=['spherenet'],
      zip_safe=False)
--------------------------------------------------------------------------------
/spherenet/__init__.py:
--------------------------------------------------------------------------------
from .sphere_cnn import SphereConv2D
from .sphere_cnn import SphereMaxPool2D
from .dataset import OmniMNIST
from .dataset import OmniFashionMNIST
--------------------------------------------------------------------------------
/spherenet/dataset.py:
--------------------------------------------------------------------------------
# Mathematical
import numpy as np
from scipy.ndimage import map_coordinates

# Pytorch
import torch
from torch.utils import data
from torchvision import datasets


def genuv(h, w):
    # (u, v) spherical coordinates of each pixel center of an equirectangular grid
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    u = (u + 0.5) * 2 * np.pi / w - np.pi
    v = (v + 0.5) * np.pi / h - np.pi / 2
    return np.stack([u, v], axis=-1)


def uv2xyz(uv):
    sin_u = np.sin(uv[..., 0])
    cos_u = np.cos(uv[..., 0])
    sin_v = np.sin(uv[..., 1])
    cos_v = np.cos(uv[..., 1])
    return np.stack([
        cos_v * cos_u,
        cos_v * sin_u,
        sin_v
    ], axis=-1)


def xyz2uv(xyz):
    c = np.sqrt((xyz[..., :2] ** 2).sum(-1))
    u = np.arctan2(xyz[..., 1], xyz[..., 0])
    v = np.arctan2(xyz[..., 2], c)
    return np.stack([u, v], axis=-1)


def uv2img_idx(uv, h, w, u_fov, v_fov, v_c=0):
    assert 0 < u_fov < np.pi
    assert 0 < v_fov < np.pi
    assert -np.pi < v_c < np.pi

    xyz = uv2xyz(uv.astype(np.float64))
    # rotate about the y-axis by v_c
    Ry = np.array([
        [np.cos(v_c), 0, -np.sin(v_c)],
        [0, 1, 0],
        [np.sin(v_c), 0, np.cos(v_c)],
    ])
    xyz_rot = xyz @ Ry.T
    uv_rot = xyz2uv(xyz_rot)

    u = uv_rot[..., 0]
    v = uv_rot[..., 1]

    # gnomonic (perspective) projection onto the tangent plane
    x = np.tan(u)
    y = np.tan(v) / np.cos(u)
    x = x * w / (2 * np.tan(u_fov / 2)) + w / 2
    y = y * h / (2 * np.tan(v_fov / 2)) + h / 2

    invalid = (u < -u_fov / 2) | (u > u_fov / 2) |\
              (v < -v_fov / 2) | (v > v_fov / 2)
    x[invalid] = -100
    y[invalid] = -100

    return np.stack([y, x], axis=0)


class OmniDataset(data.Dataset):
    def __init__(self, dataset, fov=120, outshape=(60, 60),
                 flip=False, h_rotate=False, v_rotate=False,
                 img_mean=None, img_std=None, fix_aug=False):
        '''
        Convert a classification dataset to an omnidirectional version
        @dataset  dataset with the same interface as torch.utils.data.Dataset,
                  yielding (PIL image, label) when indexed
        @fov      field of view of the tangent plane, in degrees
        '''
        self.dataset = dataset
        self.fov = fov
        self.outshape = outshape
        self.flip = flip
        self.h_rotate = h_rotate
        self.v_rotate = v_rotate
        self.img_mean = img_mean
        self.img_std = img_std

        self.aug = None
        if fix_aug:
            # pre-sample one augmentation per example so it stays fixed across epochs
            self.aug = [
                {
                    'flip': np.random.randint(2) == 0,
                    'h_rotate': np.random.randint(outshape[1]),
                    'v_rotate': np.random.uniform(-np.pi/2, np.pi/2),
                }
                for _ in range(len(self.dataset))
            ]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = np.array(self.dataset[idx][0], np.float32)
        h, w = img.shape[:2]
        uv = genuv(*self.outshape)
        fov = self.fov * np.pi / 180

        if self.v_rotate:
            if self.aug is not None:
                v_c = self.aug[idx]['v_rotate']
            else:
                v_c = np.random.uniform(-np.pi/2, np.pi/2)
            img_idx = uv2img_idx(uv, h, w, fov, fov, v_c)
        else:
            img_idx = uv2img_idx(uv, h, w, fov, fov, 0)
        x = map_coordinates(img, img_idx, order=1)

        # Random flip
        if self.aug is not None:
            if self.aug[idx]['flip']:
                x = np.flip(x, axis=1)
        elif self.flip and np.random.randint(2) == 0:
            x = np.flip(x, axis=1)

        # Random horizontal rotate
        if self.h_rotate:
            if self.aug is not None:
                dx = self.aug[idx]['h_rotate']
            else:
                dx = np.random.randint(x.shape[1])
            x = np.roll(x, dx, axis=1)

        # Normalize image
        if self.img_mean is not None:
            x = x - self.img_mean
        if self.img_std is not None:
            x = x / self.img_std

        return torch.FloatTensor(x.copy()), self.dataset[idx][1]


class OmniMNIST(OmniDataset):
    def __init__(self, root='datas/MNIST', train=True,
                 download=True, *args, **kwargs):
        '''
        Omnidirectional MNIST
        @root      (str)  root directory storing the dataset
        @train     (bool) train or test split
        @download  (bool) whether to download if data do not exist
        '''
        self.MNIST = datasets.MNIST(root, train=train, download=download)
        super(OmniMNIST, self).__init__(self.MNIST, *args, **kwargs)


class OmniFashionMNIST(OmniDataset):
    def __init__(self, root='datas/FashionMNIST', train=True,
                 download=True, *args, **kwargs):
        '''
        Omnidirectional FashionMNIST
        @root      (str)  root directory storing the dataset
        @train     (bool) train or test split
        @download  (bool) whether to download if data do not exist
        '''
        self.FashionMNIST = datasets.FashionMNIST(root, train=train, download=download)
        super(OmniFashionMNIST, self).__init__(self.FashionMNIST, *args, **kwargs)


if __name__ == '__main__':

    import os
    import argparse
    from PIL import Image

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--idx', nargs='+', type=int, required=True,
                        help='image indices to demo')
    parser.add_argument('--out_dir', default='datas/demo',
                        help='directory to output demo image')
    parser.add_argument('--dataset', default='OmniMNIST',
                        choices=['OmniMNIST', 'OmniFashionMNIST'],
                        help='which dataset to use')

    parser.add_argument('--fov', type=int, default=120,
                        help='fov of the tangent plane')
    parser.add_argument('--flip', action='store_true',
                        help='whether to apply random flip')
    parser.add_argument('--h_rotate', action='store_true',
                        help='whether to apply random panorama horizontal rotation')
    parser.add_argument('--v_rotate', action='store_true',
                        help='whether to apply random panorama vertical rotation')
    parser.add_argument('--fix_aug', action='store_true',
                        help='whether to fix the random augmentation per sample')
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    if args.dataset == 'OmniMNIST':
        dataset = OmniMNIST(fov=args.fov, flip=args.flip,
                            h_rotate=args.h_rotate, v_rotate=args.v_rotate,
                            fix_aug=args.fix_aug)
    elif args.dataset == 'OmniFashionMNIST':
        dataset = OmniFashionMNIST(fov=args.fov, flip=args.flip,
                                   h_rotate=args.h_rotate, v_rotate=args.v_rotate,
                                   fix_aug=args.fix_aug)

    for idx in args.idx:
        path = os.path.join(args.out_dir, '%d.png' % idx)
        x, label = dataset[idx]

        print(path, label)
        Image.fromarray(x.numpy().astype(np.uint8)).save(path)
--------------------------------------------------------------------------------
/spherenet/sphere_cnn.py:
--------------------------------------------------------------------------------
import numpy as np
from numpy import sin, cos, tan, pi, arcsin, arctan
from functools import lru_cache
import torch
from torch import nn
from torch.nn.parameter import Parameter


# Calculate kernels of SphereCNN
@lru_cache(None)
def get_xy(delta_phi, delta_theta):
    # Tangent-plane offsets of the 3x3 kernel sampling locations.
    # The center entry is a dummy nonzero value: (0, 0) would make rho
    # zero in cal_index and divide by zero; cal_index overwrites it anyway.
    return np.array([
        [
            (-tan(delta_theta), 1/cos(delta_theta)*tan(delta_phi)),
            (0, tan(delta_phi)),
            (tan(delta_theta), 1/cos(delta_theta)*tan(delta_phi)),
        ],
        [
            (-tan(delta_theta), 0),
            (1, 1),
            (tan(delta_theta), 0),
        ],
        [
            (-tan(delta_theta), -1/cos(delta_theta)*tan(delta_phi)),
            (0, -tan(delta_phi)),
            (tan(delta_theta), -1/cos(delta_theta)*tan(delta_phi)),
        ]
    ])


@lru_cache(None)
def cal_index(h, w, img_r, img_c):
    '''
    Calculate the kernel sampling pattern
    (only 3x3 filters are supported)
    return 9 locations: (3, 3, 2)
    '''
    # pixel -> rad
    phi = -((img_r+0.5)/h*pi - pi/2)
    theta = (img_c+0.5)/w*2*pi - pi

    delta_phi = pi/h
    delta_theta = 2*pi/w

    xys = get_xy(delta_phi, delta_theta)
    x = xys[..., 0]
    y = xys[..., 1]
    # inverse gnomonic projection: tangent-plane (x, y) -> sphere (phi, theta)
    rho = np.sqrt(x**2 + y**2)
    v = arctan(rho)
    new_phi = arcsin(cos(v)*sin(phi) + y*sin(v)*cos(phi)/rho)
    new_theta = theta + arctan(x*sin(v) / (rho*cos(phi)*cos(v) - y*sin(phi)*sin(v)))
    # rad -> pixel
    new_r = (-new_phi+pi/2)*h/pi - 0.5
    new_c = (new_theta+pi)*w/2/pi - 0.5
    # wrap indices out of the image: the leftmost and rightmost pixels
    # of an equirectangular image are adjacent
    new_c = (new_c + w) % w
    new_result = np.stack([new_r, new_c], axis=-1)
    new_result[1, 1] = (img_r, img_c)  # the kernel center samples the pixel itself
    return new_result


@lru_cache(None)
def _gen_filters_coordinates(h, w, stride):
    co = np.array([[cal_index(h, w, i, j) for j in range(0, w, stride)] for i in range(0, h, stride)])
    return np.ascontiguousarray(co.transpose([4, 0, 1, 2, 3]))


def gen_filters_coordinates(h, w, stride=1):
    '''
    return np array of kernel locations (2, H/stride, W/stride, 3, 3)
    '''
    assert isinstance(h, int) and isinstance(w, int)
    return _gen_filters_coordinates(h, w, stride).copy()


def gen_grid_coordinates(h, w, stride=1):
    # normalize sampling locations to [-1, 1] and reshape to the
    # (1, H*3, W*3, 2) grid layout expected by nn.functional.grid_sample
    coordinates = gen_filters_coordinates(h, w, stride).copy()
    coordinates[0] = (coordinates[0] * 2 / h) - 1
    coordinates[1] = (coordinates[1] * 2 / w) - 1
    coordinates = coordinates[::-1]
    coordinates = coordinates.transpose(1, 3, 2, 4, 0)
    sz = coordinates.shape
    coordinates = coordinates.reshape(1, sz[0]*sz[1], sz[2]*sz[3], sz[4])

    return coordinates.copy()


class SphereConv2D(nn.Module):
    ''' SphereConv2D
    Note that this layer only supports 3x3 filters.
    Each output pixel is computed by sampling its 9 spherical neighbors
    with grid_sample (giving a 3H x 3W feature map), then applying a
    regular 3x3 convolution with stride 3.
    '''
    def __init__(self, in_c, out_c, stride=1, bias=True, mode='bilinear'):
        super(SphereConv2D, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.stride = stride
        self.mode = mode
        self.weight = Parameter(torch.Tensor(out_c, in_c, 3, 3))
        if bias:
            self.bias = Parameter(torch.Tensor(out_c))
        else:
            self.register_parameter('bias', None)
        self.grid_shape = None
        self.grid = None

        self.reset_parameters()

    def reset_parameters(self):
        # same default weight init as nn.Conv2d
        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        # (re)build the cached sampling grid when the input spatial size changes
        if self.grid_shape is None or self.grid_shape != tuple(x.shape[2:4]):
            self.grid_shape = tuple(x.shape[2:4])
            coordinates = gen_grid_coordinates(x.shape[2], x.shape[3], self.stride)
            with torch.no_grad():
                self.grid = torch.FloatTensor(coordinates).to(x.device)
                self.grid.requires_grad = True

        with torch.no_grad():
            grid = self.grid.repeat(x.shape[0], 1, 1, 1)

        x = nn.functional.grid_sample(x, grid, mode=self.mode)
        x = nn.functional.conv2d(x, self.weight, self.bias, stride=3)
        return x


class SphereMaxPool2D(nn.Module):
    ''' SphereMaxPool2D
    Note that this layer only supports 3x3 filters.
    Analogous to SphereConv2D: sample the 9 spherical neighbors of each
    pixel with grid_sample, then max-pool with kernel size and stride 3.
    '''
    def __init__(self, stride=1, mode='bilinear'):
        super(SphereMaxPool2D, self).__init__()
        self.stride = stride
        self.mode = mode
        self.grid_shape = None
        self.grid = None
        self.pool = nn.MaxPool2d(kernel_size=3, stride=3)

    def forward(self, x):
        if self.grid_shape is None or self.grid_shape != tuple(x.shape[2:4]):
            self.grid_shape = tuple(x.shape[2:4])
            coordinates = gen_grid_coordinates(x.shape[2], x.shape[3], self.stride)
            with torch.no_grad():
                self.grid = torch.FloatTensor(coordinates).to(x.device)
                self.grid.requires_grad = True

        with torch.no_grad():
            grid = self.grid.repeat(x.shape[0], 1, 1, 1)

        return self.pool(nn.functional.grid_sample(x, grid, mode=self.mode))


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    # test cnn
    cnn = SphereConv2D(3, 5, 1)
    out = cnn(torch.randn(2, 3, 10, 10))
    print('SphereConv2D(3, 5, 1) output shape: ', out.size())
    # test pool
    # create sample image: red fades top-to-bottom, green fades left-to-right
    h, w = 100, 200
    img = np.ones([h, w, 3])
    for r in range(h):
        for c in range(w):
            img[r, c, 0] = img[r, c, 0] - r/h
            img[r, c, 1] = img[r, c, 1] - c/w
    plt.imsave('demo_original.png', img)
    img = img.transpose([2, 0, 1])
    img = np.expand_dims(img, 0)  # (B, C, H, W)
    # pool with stride 1
    pool = SphereMaxPool2D(1)
    out = pool(torch.from_numpy(img).float())
    out = np.squeeze(out.numpy(), 0).transpose([1, 2, 0])
    plt.imsave('demo_pool_1.png', out)
    print('Save image after pooling with stride 1: demo_pool_1.png')
    # pool with stride 3
    pool = SphereMaxPool2D(3)
    out = pool(torch.from_numpy(img).float())
    out = np.squeeze(out.numpy(), 0).transpose([1, 2, 0])
    plt.imsave('demo_pool_3.png', out)
    print('Save image after pooling with stride 3: demo_pool_3.png')
--------------------------------------------------------------------------------