├── .gitignore
├── README.md
├── demo
    ├── equirectangular.png
    └── equirectangular_earth.png
├── requirement.txt
└── src
    ├── demo.py
    ├── demo_maxPool.py
    └── spherenet
        ├── GridGenerator.py
        ├── SphereConv2d.py
        ├── SphereMaxPool2d.py
        └── __init__.py


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# PyCharm
/.idea

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# sphereConv-pytorch

Fast and Simple Spherical Convolution PyTorch code 🌏

This code is an unofficial implementation of "SphereNet: Learning Spherical Representations for Detection and Classification in Omnidirectional Images" (ECCV 2018), and an upgraded version of [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch).

This code supports spherical kernel sampling on equirectangular images!

I wrote the code to be `numpy`-friendly and `torch`-friendly. 😉

- [x] `numpy`-friendly
- [x] `torch`-friendly
- [x] Supports all kernel shapes (e.g. `3x3`, `2x2`, `3x4`, ...)
- [x] Super Fast! 👍
- [ ] Omnidirectional Dataset
  (If you want an omnidirectional dataset, use this repo: [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch))

## Demo Result

![demo](https://i.imgur.com/CWews2K.png)

The spherical kernel can wrap around the left and right edges of the image. The left side of the input is brighter, so after max pooling the right edge of the result image is bright as well!


## Quick Start

Before you start, install `pytorch`! (This code also runs on CPU.)

```
cd src
python demo.py
python demo_maxPool.py
```

## Code Detail

### class `GridGenerator`

This class generates a spherical sampling grid on an equirectangular image.

``` python
gridGenerator = GridGenerator(h, w, self.kernel_size, self.stride)
LonLatSamplingPattern = gridGenerator.createSamplingPattern()
```

This code only uses `numpy` and is written in a `numpy`-friendly way! However, because it is so `numpy`-friendly, you may find the flow of the code hard to follow 😢.

I added some comments to the code that explain how the shape of each array changes. Good Luck 🤞.


### class `SphereConv2d`

This is an implementation of spherical convolution. This class inherits from `nn.Conv2d`, so you can use it as a replacement for `nn.Conv2d`.

``` python
cnn = SphereConv2d(3, 5, kernel_size=3, stride=1)
out = cnn(torch.randn(2, 3, 10, 10))
```

This code supports various kernel shapes: `(3x3, 2x2, 3x8, ...)`.

You can test this with the OmniMNIST dataset from [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch). I've tested with it, and got similar or improved results!

### class `SphereMaxPool2d`

This is an implementation of spherical max pooling. This class inherits from `nn.MaxPool2d`, so you can use it as a replacement for `nn.MaxPool2d`.

``` python
pool = SphereMaxPool2d(kernel_size=3, stride=3)
out = pool(torch.from_numpy(img).float())
```

This code also supports various pooling shapes!

Likewise, you can test this with the OmniMNIST dataset from [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch).

## Further Reading

- Some formulas are inspired by Paul Bourke's work. [link](http://paulbourke.net/dome/dualfish2sphere/)
- If you want to rotate an equirectangular image, see my implementation! [BlueHorn07/pyEquirectRotate](https://github.com/BlueHorn07/pyEquirectRotate)
- If you want more awesome omnidirectional Python code, I recommend this repository!
  - [sunset1995/py360convert](https://github.com/sunset1995/py360convert)
  - I've forked `py360convert` and added `p2e` (perspective2equirectangular). [BlueHorn07/py360convert](https://github.com/BlueHorn07/py360convert), [`p2e`](https://github.com/BlueHorn07/py360convert#p2ep_img-fov_deg-u_deg-v_deg-out_hw-in_rot_deg0)

--------------------------------------------------------------------------------
/demo/equirectangular.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueHorn07/sphereConv-pytorch/7c29d731d199f563f4c22e18c5d082030f53a046/demo/equirectangular.png

--------------------------------------------------------------------------------
/demo/equirectangular_earth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueHorn07/sphereConv-pytorch/7c29d731d199f563f4c22e18c5d082030f53a046/demo/equirectangular_earth.png

--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
numpy
opencv-python
matplotlib

--------------------------------------------------------------------------------
/src/demo.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
from spherenet import SphereMaxPool2d, SphereConv2d

import numpy as np
import torch

if __name__ == '__main__':
    # This demo code originates from https://github.com/ChiWeiHsiao/SphereNet-pytorch

    # SphereConv2d
    cnn = SphereConv2d(3, 5, kernel_size=3, stride=1)
    out = cnn(torch.randn(2, 3, 10, 10))
    print('SphereConv2d(3, 5, 1) output shape: ', out.size())

    # SphereMaxPool2d
    # Build a gradient image: channel 0 fades from top to bottom,
    # channel 1 fades from left to right.
    h, w = 100, 200
    img = np.ones([h, w, 3])
    for r in range(h):
        for c in range(w):
            img[r, c, 0] = img[r, c, 0] - r/h
            img[r, c, 1] = img[r, c, 1] - c/w
    plt.imsave('demo_original.png', img)
    img = img.transpose([2, 0, 1])  # (H, W, C) -> (C, H, W)
    img = np.expand_dims(img, 0)  # (B, C, H, W)

    # pool
    pool = SphereMaxPool2d(kernel_size=3, stride=3)
    out = pool(torch.from_numpy(img).float())

    out = np.squeeze(out.numpy(), 0).transpose([1, 2, 0])  # back to (H, W, C)
    plt.imsave('demo_pool_3x3.png', out)

--------------------------------------------------------------------------------
/src/demo_maxPool.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import torch

from spherenet import SphereMaxPool2d

if __name__ == '__main__':
    img = cv2.imread("../demo/equirectangular_earth.png")

    img = img.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)
    img = np.expand_dims(img, 0)  # (B, C, H, W)

    spherePool = SphereMaxPool2d(3, stride=3)
    out = spherePool(torch.from_numpy(img).float())
    out = np.squeeze(out.numpy(), 0).transpose((1, 2, 0))  # back to (H, W, C)
    cv2.imwrite("sphere_maxPooled.jpg", out)
    cv2.waitKey()
    cv2.destroyAllWindows()

--------------------------------------------------------------------------------
/src/spherenet/GridGenerator.py:
--------------------------------------------------------------------------------
import numpy as np


class GridGenerator:
    def __init__(self, height: int, width: int, kernel_size, stride=1):
        self.height = height
        self.width = width
        self.kernel_size = kernel_size  # (Kh, Kw)
        self.stride = stride  # (stride_h, stride_w)

    def createSamplingPattern(self):
        """
        :return: (1, H*Kh, W*Kw, (Lat, Lon)) sampling pattern
        """
        kerX, kerY = self.createKernel()  # (Kh, Kw)

        # create some values used in generating the lat/lon sampling pattern
        rho = np.sqrt(kerX ** 2 + kerY ** 2)
        Kh, Kw = self.kernel_size
        # when the value of rho at center is zero, some lat values explode to `nan`.
        if Kh % 2 and Kw % 2:
            rho[Kh // 2][Kw // 2] = 1e-8

        nu = np.arctan(rho)
        cos_nu = np.cos(nu)
        sin_nu = np.sin(nu)

        stride_h, stride_w = self.stride
        h_range = np.arange(0, self.height, stride_h)
        w_range = np.arange(0, self.width, stride_w)

        lat_range = ((h_range / self.height) - 0.5) * np.pi
        lon_range = ((w_range / self.width) - 0.5) * (2 * np.pi)

        # generate latitude sampling pattern
        lat = np.array([
            np.arcsin(cos_nu * np.sin(_lat) + kerY * sin_nu * np.cos(_lat) / rho) for _lat in lat_range
        ])  # (H, Kh, Kw)

        lat = np.array([lat for _ in lon_range])  # (W, H, Kh, Kw)
        lat = lat.transpose((1, 0, 2, 3))  # (H, W, Kh, Kw)

        # generate longitude sampling pattern
        lon = np.array([
            np.arctan(kerX * sin_nu / (rho * np.cos(_lat) * cos_nu - kerY * np.sin(_lat) * sin_nu)) for _lat in lat_range
        ])  # (H, Kh, Kw)

        lon = np.array([lon + _lon for _lon in lon_range])  # (W, H, Kh, Kw)
        lon = lon.transpose((1, 0, 2, 3))  # (H, W, Kh, Kw)

        # (radian) -> (index of pixel)
        lat = (lat / np.pi + 0.5) * self.height
        lon = ((lon / (2 * np.pi) + 0.5) * self.width) % self.width

        LatLon = np.stack((lat, lon))  # (2, H, W, Kh, Kw) = ((lat, lon), H, W, Kh, Kw)
        LatLon = LatLon.transpose((1, 3, 2, 4, 0))  # (H, Kh, W, Kw, 2) = (H, Kh, W, Kw, (lat, lon))

        H, Kh, W, Kw, d = LatLon.shape
        LatLon = LatLon.reshape((1, H * Kh, W * Kw, d))  # (1, H*Kh, W*Kw, 2)

        return LatLon

    def createKernel(self):
        """
        :return: (Ky, Kx) kernel pattern
        """
        Kh, Kw = self.kernel_size

        delta_lat = np.pi / self.height
        delta_lon = 2 * np.pi / self.width

        # integer offsets around the kernel center; for even sizes the
        # entry at index K // 2 (offset 0) is dropped
        range_x = np.arange(-(Kw // 2), Kw // 2 + 1)
        if not Kw % 2:
            range_x = np.delete(range_x, Kw // 2)

        range_y = np.arange(-(Kh // 2), Kh // 2 + 1)
        if not Kh % 2:
            range_y = np.delete(range_y, Kh // 2)

        kerX = np.tan(range_x * delta_lon)
        kerY = np.tan(range_y * delta_lat) / np.cos(range_y * delta_lon)

        return np.meshgrid(kerX, kerY)  # (Kh, Kw)

--------------------------------------------------------------------------------
/src/spherenet/SphereConv2d.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from .GridGenerator import GridGenerator


class SphereConv2d(nn.Conv2d):
    """
    kernel_size: (H, W)
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size=(3, 3),
                 stride=1, padding=0, dilation=1,
                 groups: int = 1, bias: bool = True, padding_mode: str = 'zeros'):
        super(SphereConv2d, self).__init__(
            in_channels, out_channels, kernel_size,
            stride, padding, dilation, groups, bias, padding_mode)
        self.grid_shape = None
        self.grid = None

    def genSamplingPattern(self, h, w):
        gridGenerator = GridGenerator(h, w, self.kernel_size, self.stride)
        # last axis is ordered (lat, lon), in pixel units
        LonLatSamplingPattern = gridGenerator.createSamplingPattern()

        # generate grid to use `F.grid_sample`
        lat_grid = (LonLatSamplingPattern[:, :, :, 0] / h) * 2 - 1
        lon_grid = (LonLatSamplingPattern[:, :, :, 1] / w) * 2 - 1

        # `F.grid_sample` expects (x, y) order, i.e. (lon, lat)
        grid = np.stack((lon_grid, lat_grid), axis=-1)
        with torch.no_grad():
            self.grid = torch.FloatTensor(grid)
            self.grid.requires_grad = False

    def forward(self, x):
        # Generate Sampling Pattern
        B, C, H, W = x.shape

        # regenerate the grid only when the input size changes
        if (self.grid_shape is None) or (self.grid_shape != (H, W)):
            self.grid_shape = (H, W)
            self.genSamplingPattern(H, W)

        with torch.no_grad():
            grid = self.grid.repeat((B, 1, 1, 1)).to(x.device)  # (B, H*Kh, W*Kw, 2)
            grid.requires_grad = False

        x = F.grid_sample(x, grid, align_corners=True, mode='nearest')  # (B, in_c, H*Kh, W*Kw)

        # self.weight -> (out_c, in_c, Kh, Kw)
        x = F.conv2d(x, self.weight, self.bias, stride=self.kernel_size)

        return x  # (B, out_c, H/stride_h, W/stride_w)

--------------------------------------------------------------------------------
/src/spherenet/SphereMaxPool2d.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from .GridGenerator import GridGenerator


class SphereMaxPool2d(nn.MaxPool2d):
    """
    kernel_size: (H, W)
    """

    def __init__(self, kernel_size=(3, 3), stride=1, padding=0, dilation=1,
                 return_indices: bool = False, ceil_mode: bool = False):
        super(SphereMaxPool2d, self).__init__(
            kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        # `GridGenerator` unpacks both values as pairs, so expand ints here
        if isinstance(kernel_size, int):
            self.kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            self.stride = (stride, stride)

        self.grid_shape = None
        self.grid = None

    def genSamplingPattern(self, h, w):
        gridGenerator = GridGenerator(h, w, self.kernel_size, self.stride)
        # last axis is ordered (lat, lon), in pixel units
        LonLatSamplingPattern = gridGenerator.createSamplingPattern()

        # generate grid to use `F.grid_sample`
        lat_grid = (LonLatSamplingPattern[:, :, :, 0] / h) * 2 - 1
        lon_grid = (LonLatSamplingPattern[:, :, :, 1] / w) * 2 - 1

        # `F.grid_sample` expects (x, y) order, i.e. (lon, lat)
        grid = np.stack((lon_grid, lat_grid), axis=-1)

        with torch.no_grad():
            self.grid = torch.FloatTensor(grid)
            self.grid.requires_grad = False

    def forward(self, x):
        # Generate Sampling Pattern
        B, C, H, W = x.shape

        # regenerate the grid only when the input size changes
        if (self.grid_shape is None) or (self.grid_shape != (H, W)):
            self.grid_shape = (H, W)
            self.genSamplingPattern(H, W)

        with torch.no_grad():
            grid = self.grid.repeat((B, 1, 1, 1)).to(x.device)  # (B, H*Kh, W*Kw, 2)
            grid.requires_grad = False

        x = F.grid_sample(x, grid, align_corners=False, mode='bilinear')  # (B, C, H*Kh, W*Kw)

        x = F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.kernel_size)

        return x  # (B, C, H/stride_h, W/stride_w)

--------------------------------------------------------------------------------
/src/spherenet/__init__.py:
--------------------------------------------------------------------------------
from .GridGenerator import GridGenerator
from .SphereConv2d import SphereConv2d
from .SphereMaxPool2d import SphereMaxPool2d

--------------------------------------------------------------------------------
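The latitude/longitude expressions in `GridGenerator.createSamplingPattern` can be checked in isolation with a scalar version. The sketch below is not part of the repository: `inverse_gnomonic` is a hypothetical helper name, it uses only the standard `math` module, and it returns early at the kernel center, whereas the repository instead sets `rho` to `1e-8` there. It assumes `(x, y)` is a tangent-plane offset like one entry of `kerX`/`kerY`.

```python
import math


def inverse_gnomonic(x, y, lat0, lon0):
    """Map a tangent-plane offset (x, y) to (lat, lon) in radians.

    (lat0, lon0) is the point where the tangent plane touches the sphere.
    Hypothetical helper for illustration; mirrors the formulas used in
    GridGenerator.createSamplingPattern for a single kernel tap.
    """
    rho = math.hypot(x, y)
    if rho == 0.0:
        # The kernel center maps to the tangent point itself
        # (assumption: early return instead of the 1e-8 workaround).
        return lat0, lon0

    nu = math.atan(rho)
    lat = math.asin(
        math.cos(nu) * math.sin(lat0)
        + y * math.sin(nu) * math.cos(lat0) / rho
    )
    lon = lon0 + math.atan(
        x * math.sin(nu)
        / (rho * math.cos(lat0) * math.cos(nu)
           - y * math.sin(lat0) * math.sin(nu))
    )
    return lat, lon


# A tap one degree east of a kernel centered on the equator.
one_degree = math.pi / 180
print(inverse_gnomonic(math.tan(one_degree), 0.0, 0.0, 0.0))
```

On the equator, an offset of `tan(d)` along one axis should land `d` radians away along that axis, which makes a convenient sanity check before trusting the vectorized `numpy` version.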