├── .gitignore
├── README.md
├── example.py
├── setup.py
└── spherenet
    ├── __init__.py
    ├── dataset.py
    └── sphere_cnn.py


/.gitignore:
--------------------------------------------------------------------------------
# Compiled python modules.
*.pyc

# Setuptools distribution folder.
/dist/

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SphereNet-pytorch
This is an unofficial implementation of the ECCV 2018 [paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Benjamin_Coors_SphereNet_Learning_Spherical_ECCV_2018_paper.pdf) **"SphereNet: Learning Spherical Representations for Detection and Classification in Omnidirectional Images"**.

Currently only the 3x3 versions of SphereNet's Conv2D and MaxPool2D are implemented. For now only `mode=bilinear` is supported; `mode=nearest` has to wait until PyTorch 1.0. You can replace any model's CNN layers with SphereNet's: they are implemented so that you can directly load pretrained weights into SphereNet's CNN.

## Requirements
- python3
- pytorch>=0.4.1
- numpy
- scipy

## Installation
Copy the spherenet directory into your project.
If you want to install it as a package so that you can import spherenet everywhere:
```
cd $YOUR_CLONED_SPHERENET_DIR
pip install .
```

## Example
``` python
import torch
from spherenet import SphereConv2D, SphereMaxPool2D

conv1 = SphereConv2D(1, 32, stride=1)
pool1 = SphereMaxPool2D(stride=2)

# toy example
img = torch.randn(1, 1, 60, 60)  # (batch, channel, height, width)
out = conv1(img)                 # (1, 32, 60, 60)
out = pool1(out)                 # (1, 32, 30, 30)
```
- To apply SphereNet to your trained model, simply replace `nn.Conv2d` with `SphereConv2D`, and replace `nn.MaxPool2d` with `SphereMaxPool2D`. They should work well with `load_state_dict`.
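- Loading pretrained perspective-image weights works with a plain `load_state_dict` call, because `SphereConv2D` keeps the same `weight`/`bias` parameter shapes as a 3x3 `nn.Conv2d`. A minimal sketch (the perspective model below is made up for illustration):
``` python
from torch import nn
from spherenet import SphereConv2D, SphereMaxPool2D

class PerspectiveNet(nn.Module):      # hypothetical pretrained model
    def __init__(self):
        super(PerspectiveNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

class OmniNet(nn.Module):             # same layer names, spherical layers
    def __init__(self):
        super(OmniNet, self).__init__()
        self.conv1 = SphereConv2D(1, 32, stride=1)
        self.pool1 = SphereMaxPool2D(stride=2)

omni = OmniNet()
omni.load_state_dict(PerspectiveNet().state_dict())  # names and shapes line up
```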

## Results
- Classification on OmniMNIST data (`spherenet.OmniMNIST`, `spherenet.OmniFashionMNIST`)
- Reproducing the OmniMNIST result:

  | Method | Test Error (%) |
  | ------------------- |:--------------:|
  | SphereNet (paper)   | 5.59 |
  | SphereNet (ours)    | 5.77 |
  | EquirectCNN (paper) | 9.61 |
  | EquirectCNN (ours)  | 9.63 |
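- `example.py` trains SphereNet and the conventional-CNN baseline side by side; a run along these lines (just the script's defaults spelled out; the exact hyperparameters behind the table above are not recorded in the repo) should approximate the "ours" rows:
```
python example.py --data MNIST --batch-size 128 --epochs 100 --optimizer adam --lr 1e-4
```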

## References
- [paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Benjamin_Coors_SphereNet_Learning_Spherical_ECCV_2018_paper.pdf)
  - Benjamin Coors, Alexandru Paul Condurache, Andreas Geiger
  - ECCV 2018
```
@inproceedings{coors2018spherenet,
  title={SphereNet: Learning Spherical Representations for Detection and Classification in Omnidirectional Images},
  author={Coors, Benjamin and Condurache, Alexandru Paul and Geiger, Andreas},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  pages={518--533},
  year={2018}
}
```
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import argparse
from spherenet import OmniMNIST, OmniFashionMNIST
from spherenet import SphereConv2D, SphereMaxPool2D
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class SphereNet(nn.Module):
    def __init__(self):
        super(SphereNet, self).__init__()
        self.conv1 = SphereConv2D(1, 32, stride=1)
        self.pool1 = SphereMaxPool2D(stride=2)
        self.conv2 = SphereConv2D(32, 64, stride=1)
        self.pool2 = SphereMaxPool2D(stride=2)

        self.fc = nn.Linear(14400, 10)

    def forward(self, x):
        x = F.relu(self.pool1(self.conv1(x)))
        x = F.relu(self.pool2(self.conv2(x)))
        x = x.view(-1, 14400)  # flatten, (B, C, H, W) -> (B, C*H*W)
        x = self.fc(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc = nn.Linear(64*13*13, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64*13*13)  # flatten, (B, C, H, W) -> (B, C*H*W)
        x = self.fc(x)
        return x
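
# Shape bookkeeping for the two fc layers above (assuming the default 60x60
# omnidirectional input):
# - SphereNet: SphereConv2D samples a 3x3 spherical neighborhood for every
#   input pixel, so spatial size is preserved, and each SphereMaxPool2D with
#   stride=2 halves it: 60x60 -> 60x60 -> 30x30 -> 30x30 -> 15x15,
#   hence nn.Linear(64 * 15 * 15, 10) = nn.Linear(14400, 10).
# - Net: an unpadded 3x3 conv shrinks each side by 2 and max_pool2d(x, 2)
#   halves it (floored): 60 -> 58 -> 29 -> 27 -> 13, hence 64*13*13 = 10816.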

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if data.dim() == 3:
            data = data.unsqueeze(1)  # (B, H, W) -> (B, C, H, W)
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if data.dim() == 3:
                data = data.unsqueeze(1)  # (B, H, W) -> (B, C, H, W)
            output = model(data)
            test_loss += F.cross_entropy(output, target).item()  # accumulate mean batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data', type=str, default='MNIST',
                        help='dataset for training, options={"FashionMNIST", "MNIST"}')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimizer, options={"adam", "sgd"}')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed')
    parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='how many epochs to wait before saving model weights')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device('cuda' if use_cuda else 'cpu')

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    np.random.seed(args.seed)
    if args.data == 'FashionMNIST':
        train_dataset = OmniFashionMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, img_std=255, train=True)
        test_dataset = OmniFashionMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, img_std=255, train=False, fix_aug=True)
    elif args.data == 'MNIST':
        train_dataset = OmniMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, train=True)
        test_dataset = OmniMNIST(fov=120, flip=True, h_rotate=True, v_rotate=True, train=False, fix_aug=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    # Train
    sphere_model = SphereNet().to(device)
    model = Net().to(device)
    if args.optimizer == 'adam':
        sphere_optimizer = torch.optim.Adam(sphere_model.parameters(), lr=args.lr)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'sgd':
        sphere_optimizer = torch.optim.SGD(sphere_model.parameters(), lr=args.lr, momentum=args.momentum)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        # SphereCNN
        print('{} Sphere CNN {}'.format('='*10, '='*10))
        train(args, sphere_model, device, train_loader, sphere_optimizer, epoch)
        test(args, sphere_model, device, test_loader)
        if epoch % args.save_interval == 0:
            torch.save(sphere_model.state_dict(), 'sphere_model.pkl')

        # Conventional CNN
        print('{} Conventional CNN {}'.format('='*10, '='*10))
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        if epoch % args.save_interval == 0:
            torch.save(model.state_dict(), 'model.pkl')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(name='spherenet',
      version='0.1',
      description='Pytorch implementation of ECCV 2018 paper: '
                  'SphereNet: Learning Spherical Representations '
                  'for Detection and Classification in Omnidirectional Images',
      url='https://github.com/ChiWeiHsiao/SphereNet-pytorch',
      author='Chi-Wei Hsiao, Cheng Sun',
      author_email='kiwi010379@gmail.com, chengsun@gapp.nthu.edu.tw',
      license='MIT',
      packages=['spherenet'],
      zip_safe=False)
--------------------------------------------------------------------------------
/spherenet/__init__.py:
--------------------------------------------------------------------------------
from .sphere_cnn import SphereConv2D
from .sphere_cnn import SphereMaxPool2D
from .dataset import OmniMNIST
from .dataset import OmniFashionMNIST
--------------------------------------------------------------------------------
/spherenet/dataset.py:
--------------------------------------------------------------------------------
# Mathematical
import numpy as np
from scipy.ndimage.interpolation import map_coordinates

# Pytorch
import torch
from torch.utils import data
from torchvision import datasets


def genuv(h, w):
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    u = (u + 0.5) * 2 * np.pi / w - np.pi
    v = (v + 0.5) * np.pi / h - np.pi / 2
    return np.stack([u, v], axis=-1)


def uv2xyz(uv):
    sin_u = np.sin(uv[..., 0])
    cos_u = np.cos(uv[..., 0])
    sin_v = np.sin(uv[..., 1])
    cos_v = np.cos(uv[..., 1])
    return np.stack([
        cos_v * cos_u,
        cos_v * sin_u,
        sin_v
    ], axis=-1)


def xyz2uv(xyz):
    c = np.sqrt((xyz[..., :2] ** 2).sum(-1))
    u = np.arctan2(xyz[..., 1], xyz[..., 0])
    v = np.arctan2(xyz[..., 2], c)
    return np.stack([u, v], axis=-1)
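
# Orientation for the helpers above: genuv(h, w) yields pixel-centre spherical
# coordinates for an h x w equirectangular image, with longitude u in (-pi, pi)
# and latitude v in (-pi/2, pi/2); uv2xyz maps them onto the unit sphere and
# xyz2uv inverts it, e.g.
#
#   uv = genuv(60, 60)
#   assert np.allclose(xyz2uv(uv2xyz(uv)), uv)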

def uv2img_idx(uv, h, w, u_fov, v_fov, v_c=0):
    assert 0 < u_fov < np.pi
    assert 0 < v_fov < np.pi
    assert -np.pi < v_c < np.pi

    xyz = uv2xyz(uv.astype(np.float64))
    Ry = np.array([
        [np.cos(v_c), 0, -np.sin(v_c)],
        [0, 1, 0],
        [np.sin(v_c), 0, np.cos(v_c)],
    ])
    xyz_rot = xyz @ Ry.T  # rotate every unit vector about the y-axis by v_c
    uv_rot = xyz2uv(xyz_rot)

    u = uv_rot[..., 0]
    v = uv_rot[..., 1]

    # gnomonic projection onto the tangent plane, then tangent plane -> pixel
    x = np.tan(u)
    y = np.tan(v) / np.cos(u)
    x = x * w / (2 * np.tan(u_fov / 2)) + w / 2
    y = y * h / (2 * np.tan(v_fov / 2)) + h / 2

    invalid = (u < -u_fov / 2) | (u > u_fov / 2) |\
              (v < -v_fov / 2) | (v > v_fov / 2)
    x[invalid] = -100
    y[invalid] = -100  # map_coordinates treats these as out of bounds

    return np.stack([y, x], axis=0)


class OmniDataset(data.Dataset):
    def __init__(self, dataset, fov=120, outshape=(60, 60),
                 flip=False, h_rotate=False, v_rotate=False,
                 img_mean=None, img_std=None, fix_aug=False):
        '''
        Convert a classification dataset to an omnidirectional version.
        @dataset  dataset with the same interface as torch.utils.data.Dataset;
                  indexing yields a (PIL image, label) pair
        '''
        self.dataset = dataset
        self.fov = fov
        self.outshape = outshape
        self.flip = flip
        self.h_rotate = h_rotate
        self.v_rotate = v_rotate
        self.img_mean = img_mean
        self.img_std = img_std

        self.aug = None
        if fix_aug:
            self.aug = [
                {
                    'flip': np.random.randint(2) == 0,
                    'h_rotate': np.random.randint(outshape[1]),
                    'v_rotate': np.random.uniform(-np.pi/2, np.pi/2),
                }
                for _ in range(len(self.dataset))
            ]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = np.array(self.dataset[idx][0], np.float32)
        h, w = img.shape[:2]
        uv = genuv(*self.outshape)
        fov = self.fov * np.pi / 180

        if self.v_rotate:
            if self.aug is not None:
                v_c = self.aug[idx]['v_rotate']
            else:
                v_c = np.random.uniform(-np.pi/2, np.pi/2)
            img_idx = uv2img_idx(uv, h, w, fov, fov, v_c)
        else:
            img_idx = uv2img_idx(uv, h, w, fov, fov, 0)
        x = map_coordinates(img, img_idx, order=1)

        # Random flip
        if self.aug is not None:
            if self.aug[idx]['flip']:
                x = np.flip(x, axis=1)
        elif self.flip and np.random.randint(2) == 0:
            x = np.flip(x, axis=1)

        # Random horizontal rotate
        if self.h_rotate:
            if self.aug is not None:
                dx = self.aug[idx]['h_rotate']
            else:
                dx = np.random.randint(x.shape[1])
            x = np.roll(x, dx, axis=1)

        # Normalize image
        if self.img_mean is not None:
            x = x - self.img_mean
        if self.img_std is not None:
            x = x / self.img_std

        return torch.FloatTensor(x.copy()), self.dataset[idx][1]
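
# Sketch: OmniDataset can wrap any dataset with the interface described in its
# docstring, not just the MNIST variants below, e.g.
#
#   base = datasets.MNIST('datas/MNIST', train=True, download=True)
#   omni = OmniDataset(base, fov=120, flip=True, h_rotate=True)
#   img, label = omni[0]   # img: FloatTensor of shape (60, 60)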

class OmniMNIST(OmniDataset):
    def __init__(self, root='datas/MNIST', train=True,
                 download=True, *args, **kwargs):
        '''
        Omnidirectional MNIST
        @root      (str)  root directory storing the dataset
        @train     (bool) train or test split
        @download  (bool) whether to download if data do not exist
        '''
        self.MNIST = datasets.MNIST(root, train=train, download=download)
        super(OmniMNIST, self).__init__(self.MNIST, *args, **kwargs)


class OmniFashionMNIST(OmniDataset):
    def __init__(self, root='datas/FashionMNIST', train=True,
                 download=True, *args, **kwargs):
        '''
        Omnidirectional FashionMNIST
        @root      (str)  root directory storing the dataset
        @train     (bool) train or test split
        @download  (bool) whether to download if data do not exist
        '''
        self.FashionMNIST = datasets.FashionMNIST(root, train=train, download=download)
        super(OmniFashionMNIST, self).__init__(self.FashionMNIST, *args, **kwargs)


if __name__ == '__main__':

    import os
    import argparse
    from PIL import Image

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--idx', nargs='+', required=True,
                        help='image indices to demo')
    parser.add_argument('--out_dir', default='datas/demo',
                        help='directory to output demo image')
    parser.add_argument('--dataset', default='OmniMNIST',
                        choices=['OmniMNIST', 'OmniFashionMNIST'],
                        help='which dataset to use')

    parser.add_argument('--fov', type=int, default=120,
                        help='fov of the tangent plane')
    parser.add_argument('--flip', action='store_true',
                        help='whether to apply random flip')
    parser.add_argument('--h_rotate', action='store_true',
                        help='whether to apply random panorama horizontal rotation')
    parser.add_argument('--v_rotate', action='store_true',
                        help='whether to apply random panorama vertical rotation')
    parser.add_argument('--fix_aug', action='store_true',
                        help='whether to fix the random augmentation per sample')
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    if args.dataset == 'OmniMNIST':
        dataset = OmniMNIST(fov=args.fov, flip=args.flip,
                            h_rotate=args.h_rotate, v_rotate=args.v_rotate,
                            fix_aug=args.fix_aug)
    elif args.dataset == 'OmniFashionMNIST':
        dataset = OmniFashionMNIST(fov=args.fov, flip=args.flip,
                                   h_rotate=args.h_rotate, v_rotate=args.v_rotate,
                                   fix_aug=args.fix_aug)

    for idx in args.idx:
        idx = int(idx)
        path = os.path.join(args.out_dir, '%d.png' % idx)
        x, label = dataset[idx]

        print(path, label)
        Image.fromarray(x.numpy().astype(np.uint8)).save(path)
--------------------------------------------------------------------------------
/spherenet/sphere_cnn.py:
--------------------------------------------------------------------------------
import numpy as np
from numpy import sin, cos, tan, pi, arcsin, arctan
from functools import lru_cache
import torch
from torch import nn
from torch.nn.parameter import Parameter


# Calculate kernels of SphereCNN
@lru_cache(None)
def get_xy(delta_phi, delta_theta):
    return np.array([
        [
            (-tan(delta_theta), 1/cos(delta_theta)*tan(delta_phi)),
            (0, tan(delta_phi)),
            (tan(delta_theta), 1/cos(delta_theta)*tan(delta_phi)),
        ],
        [
            (-tan(delta_theta), 0),
            (1, 1),
            (tan(delta_theta), 0),
        ],
        [
            (-tan(delta_theta), -1/cos(delta_theta)*tan(delta_phi)),
            (0, -tan(delta_phi)),
            (tan(delta_theta), -1/cos(delta_theta)*tan(delta_phi)),
        ]
    ])
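
# The mapping used by cal_index below: get_xy gives the tangent-plane
# coordinates (x, y) of the 3x3 kernel taps around the tangent point (its
# centre entry is a placeholder that cal_index overwrites), and the taps are
# back-projected onto the sphere with the standard inverse gnomonic
# projection. With rho = sqrt(x^2 + y^2) and v = arctan(rho):
#
#   phi'   = arcsin(cos(v)*sin(phi) + y*sin(v)*cos(phi)/rho)
#   theta' = theta + arctan(x*sin(v) / (rho*cos(phi)*cos(v) - y*sin(phi)*sin(v)))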

@lru_cache(None)
def cal_index(h, w, img_r, img_c):
    '''
    Calculate the kernel sampling pattern.
    Only supports 3x3 filters.
    Returns 9 locations: (3, 3, 2).
    '''
    # pixel -> rad
    phi = -((img_r+0.5)/h*pi - pi/2)
    theta = (img_c+0.5)/w*2*pi - pi

    delta_phi = pi/h
    delta_theta = 2*pi/w

    xys = get_xy(delta_phi, delta_theta)
    x = xys[..., 0]
    y = xys[..., 1]
    rho = np.sqrt(x**2 + y**2)
    v = arctan(rho)
    new_phi = arcsin(cos(v)*sin(phi) + y*sin(v)*cos(phi)/rho)
    new_theta = theta + arctan(x*sin(v) / (rho*cos(phi)*cos(v) - y*sin(phi)*sin(v)))
    # rad -> pixel
    new_r = (-new_phi+pi/2)*h/pi - 0.5
    new_c = (new_theta+pi)*w/2/pi - 0.5
    # wrap indices that fall outside the image: in an equirectangular image the
    # leftmost and rightmost pixels are adjacent
    new_c = (new_c + w) % w
    new_result = np.stack([new_r, new_c], axis=-1)
    # the centre tap divides by rho == 0 above (nan); pin it to the source pixel
    new_result[1, 1] = (img_r, img_c)
    return new_result


@lru_cache(None)
def _gen_filters_coordinates(h, w, stride):
    co = np.array([[cal_index(h, w, i, j) for j in range(0, w, stride)] for i in range(0, h, stride)])
    return np.ascontiguousarray(co.transpose([4, 0, 1, 2, 3]))


def gen_filters_coordinates(h, w, stride=1):
    '''
    Return an np array of kernel locations, shape (2, H/stride, W/stride, 3, 3).
    '''
    assert(isinstance(h, int) and isinstance(w, int))
    return _gen_filters_coordinates(h, w, stride).copy()


def gen_grid_coordinates(h, w, stride=1):
    # normalize pixel coordinates to [-1, 1] and reorder to (x, y), as expected
    # by nn.functional.grid_sample
    coordinates = gen_filters_coordinates(h, w, stride).copy()
    coordinates[0] = (coordinates[0] * 2 / h) - 1
    coordinates[1] = (coordinates[1] * 2 / w) - 1
    coordinates = coordinates[::-1]
    coordinates = coordinates.transpose(1, 3, 2, 4, 0)
    sz = coordinates.shape
    coordinates = coordinates.reshape(1, sz[0]*sz[1], sz[2]*sz[3], sz[4])

    return coordinates.copy()


class SphereConv2D(nn.Module):
    ''' SphereConv2D
    Note that this layer only supports 3x3 filters.
    '''
    def __init__(self, in_c, out_c, stride=1, bias=True, mode='bilinear'):
        super(SphereConv2D, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.stride = stride
        self.mode = mode
        self.weight = Parameter(torch.Tensor(out_c, in_c, 3, 3))
        if bias:
            self.bias = Parameter(torch.Tensor(out_c))
        else:
            self.register_parameter('bias', None)
        self.grid_shape = None
        self.grid = None

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        # rebuild the sampling grid lazily, once per input spatial size
        if self.grid_shape is None or self.grid_shape != tuple(x.shape[2:4]):
            self.grid_shape = tuple(x.shape[2:4])
            coordinates = gen_grid_coordinates(x.shape[2], x.shape[3], self.stride)
            with torch.no_grad():
                self.grid = torch.FloatTensor(coordinates).to(x.device)
                self.grid.requires_grad = True

        with torch.no_grad():
            grid = self.grid.repeat(x.shape[0], 1, 1, 1)

        # grid_sample expands every output location into its 3x3 spherical
        # neighborhood, so a stride-3 convolution applies the kernel once per location
        x = nn.functional.grid_sample(x, grid, mode=self.mode)
        x = nn.functional.conv2d(x, self.weight, self.bias, stride=3)
        return x


class SphereMaxPool2D(nn.Module):
    ''' SphereMaxPool2D
    Note that this layer only supports 3x3 filters.
    '''
    def __init__(self, stride=1, mode='bilinear'):
        super(SphereMaxPool2D, self).__init__()
        self.stride = stride
        self.mode = mode
        self.grid_shape = None
        self.grid = None
        self.pool = nn.MaxPool2d(kernel_size=3, stride=3)

    def forward(self, x):
        # rebuild the sampling grid lazily, once per input spatial size
        if self.grid_shape is None or self.grid_shape != tuple(x.shape[2:4]):
            self.grid_shape = tuple(x.shape[2:4])
            coordinates = gen_grid_coordinates(x.shape[2], x.shape[3], self.stride)
            with torch.no_grad():
                self.grid = torch.FloatTensor(coordinates).to(x.device)
                self.grid.requires_grad = True

        with torch.no_grad():
            grid = self.grid.repeat(x.shape[0], 1, 1, 1)

        return self.pool(nn.functional.grid_sample(x, grid, mode=self.mode))
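
# Sketch: a single SphereConv2D stores its parameters as `weight`
# (out_c, in_c, 3, 3) and `bias` (out_c), matching a 3x3 nn.Conv2d, so
# layer-level weight transfer is a plain load_state_dict call, e.g.
#
#   conv = nn.Conv2d(3, 5, kernel_size=3)
#   sphere_conv = SphereConv2D(3, 5)
#   sphere_conv.load_state_dict(conv.state_dict())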

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    # test conv
    cnn = SphereConv2D(3, 5, 1)
    out = cnn(torch.randn(2, 3, 10, 10))
    print('SphereConv2D(3, 5, 1) output shape: ', out.size())
    # test pool
    # create a sample image with a vertical and a horizontal gradient
    h, w = 100, 200
    img = np.ones([h, w, 3])
    for r in range(h):
        for c in range(w):
            img[r, c, 0] = img[r, c, 0] - r/h
            img[r, c, 1] = img[r, c, 1] - c/w
    plt.imsave('demo_original.png', img)
    img = img.transpose([2, 0, 1])
    img = np.expand_dims(img, 0)  # (B, C, H, W)
    # pool with stride 1
    pool = SphereMaxPool2D(1)
    out = pool(torch.from_numpy(img).float())
    out = np.squeeze(out.numpy(), 0).transpose([1, 2, 0])
    plt.imsave('demo_pool_1.png', out)
    print('Save image after pooling with stride 1: demo_pool_1.png')
    # pool with stride 3
    pool = SphereMaxPool2D(3)
    out = pool(torch.from_numpy(img).float())
    out = np.squeeze(out.numpy(), 0).transpose([1, 2, 0])
    plt.imsave('demo_pool_3.png', out)
    print('Save image after pooling with stride 3: demo_pool_3.png')
--------------------------------------------------------------------------------