├── .gitignore
├── README.md
├── demo
│   ├── equirectangular.png
│   └── equirectangular_earth.png
├── requirement.txt
└── src
    ├── demo.py
    ├── demo_maxPool.py
    └── spherenet
        ├── GridGenerator.py
        ├── SphereConv2d.py
        ├── SphereMaxPool2d.py
        └── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # PyCharm
132 | /.idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sphereConv-pytorch
2 |
3 | Fast and Simple Spherical Convolution PyTorch code 🌏
4 |
5 | This code is an unofficial implementation of "SphereNet: Learning Spherical Representations for Detection and Classification in Omnidirectional Images" (ECCV 2018) and an upgraded version of [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch).
6 |
7 | This code supports spherical kernel sampling on equirectangular images!
8 |
9 | I wrote the code to be `numpy`-friendly and `torch`-friendly. 😉
10 |
11 | - [x] `numpy`-friendly
12 | - [x] `torch`-friendly
13 | - [x] Supports kernels of any shape (e.g. `3x3`, `2x2`, `3x4`, ...)
14 | - [x] Super fast: the sampling grid is computed once per input size and then cached 👍
15 | - [ ] Omnidirectional dataset (if you need one, use [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch))
16 |
17 | ## Demo Result
18 |
19 | 
20 |
21 | The spherical kernel can cross the left and right edges of the image, and the left side of the input is brighter. Therefore, after max pooling, the right side of the result is also bright!
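A planar, one-dimensional sketch of the wrap-around. This is my simplification for illustration: it ignores the spherical distortion of the kernel and only models the `% self.width` that `GridGenerator` applies to longitude indices. `wrap_max_pool_1d` is a hypothetical helper, not part of this repo.

```python
import numpy as np


def wrap_max_pool_1d(row: np.ndarray, k: int = 3, stride: int = 1) -> np.ndarray:
    """Max-pool a 1-D row whose two ends are joined, as longitude is on a sphere.

    Planar simplification: each window holds k consecutive pixels centred on the
    output position, and indices past either edge wrap around via modulo.
    """
    w = row.shape[0]
    centers = np.arange(0, w, stride)                 # window centres
    offsets = np.arange(k) - k // 2                   # e.g. [-1, 0, 1] for k=3
    idx = (centers[:, None] + offsets[None, :]) % w   # wrapped column indices
    return row[idx].max(axis=1)


row = np.array([9, 1, 1, 1, 1, 1])  # bright pixel at the left edge
print(wrap_max_pool_1d(row))        # [9 9 1 1 1 9]
```

The bright pixel at the left edge also appears at the right end of the output, because the last window wraps around to column 0.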
22 |
23 |
24 | ## Quick Start
25 |
26 | Before you start, install `pytorch` and the packages listed in `requirement.txt` (the code also runs on CPU).
27 |
28 | ```
29 | cd src
30 | python demo.py
31 | python demo_maxPool.py
32 | ```
33 |
34 | ## Code Detail
35 |
36 | ### class `GridGenerator`
37 |
38 | This class generates the spherical sampling grid for an equirectangular image.
39 |
40 | ``` python
41 | gridGenerator = GridGenerator(h, w, kernel_size=(3, 3), stride=(1, 1))  # (Kh, Kw), (stride_h, stride_w)
42 | LonLatSamplingPattern = gridGenerator.createSamplingPattern()  # (1, h*Kh, w*Kw, (lat, lon)) in pixel units
43 | ```
44 |
45 | This code uses only `numpy` and is written in an array-oriented, `numpy`-friendly style! However, because it leans so heavily on `numpy` array tricks, the flow of the code may be hard to follow 😢.
46 |
47 | I've added comments to the code explaining how the array shapes change at each step. Good luck 🤞.
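To make the flow easier to follow, here is a minimal single-centre sketch of the math from the SphereNet paper: kernel taps are laid out on the plane tangent to the sphere, then mapped back with the inverse gnomonic projection. It is an illustration under assumptions, not the repo's code: `kernel_lat_lon` is a hypothetical helper, it handles odd square kernels only, returns radians rather than pixel indices, and uses `np.arctan2` (instead of the `np.arctan` in `createSamplingPattern`) to keep the longitude offset in the correct quadrant. Small details may differ from `GridGenerator`.

```python
import numpy as np


def kernel_lat_lon(lat0: float, lon0: float, height: int, width: int, k: int = 3):
    """(lat, lon) in radians of a k x k kernel centred at (lat0, lon0); odd k assumed."""
    d_lat, d_lon = np.pi / height, 2 * np.pi / width      # angular size of one pixel
    r = np.arange(k) - k // 2                             # [-1, 0, 1] for k=3
    dlon, dlat = np.meshgrid(r * d_lon, r * d_lat)        # (k, k) angular offsets per tap
    x = np.tan(dlon)                                      # tangent-plane coordinates
    y = np.tan(dlat) / np.cos(dlon)
    rho = np.hypot(x, y)
    rho[k // 2, k // 2] = 1e-8                            # avoid 0/0 at the centre tap
    nu = np.arctan(rho)
    # inverse gnomonic projection back onto the sphere
    lat = np.arcsin(np.cos(nu) * np.sin(lat0) + y * np.sin(nu) * np.cos(lat0) / rho)
    lon = lon0 + np.arctan2(x * np.sin(nu),
                            rho * np.cos(lat0) * np.cos(nu) - y * np.sin(lat0) * np.sin(nu))
    return lat, lon


lat, lon = kernel_lat_lon(0.0, 0.0, height=100, width=200)          # on the equator
lat_p, lon_p = kernel_lat_lon(1.4, 0.0, height=100, width=200)      # about 80 degrees north
print(lon.max() - lon.min(), lon_p.max() - lon_p.min())
```

At the equator, neighbouring taps land exactly one pixel apart; near the poles the same kernel spans a much wider range of longitudes, matching the horizontal stretching of an equirectangular image there.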
48 |
49 |
50 | ### class `SphereConv2d`
51 |
52 | This is an implementation of spherical convolution. This class inherits from `nn.Conv2d`, so you can use it in place of `nn.Conv2d`.
53 |
54 | ``` python
55 | cnn = SphereConv2d(3, 5, kernel_size=3, stride=1)
56 | out = cnn(torch.randn(2, 3, 10, 10))
57 | ```
58 |
59 | This code supports various kernel shapes: `(3x3, 2x2, 3x8, ...)`.
60 |
61 | You can test this with the OmniMNIST dataset from [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch). I tested it on this dataset and got similar or better results!
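`SphereConv2d.forward` first uses `F.grid_sample` to gather every kernel's taps into an `(H*Kh, W*Kw)` image, then runs `F.conv2d` with `stride=self.kernel_size`, so each output pixel sees exactly one tile. Below is a planar `numpy` sketch of that expand-then-convolve trick under simplifying assumptions: a regular square grid stands in for the spherical one, both axes wrap around for brevity, there is a single channel, and the kernel is odd. `conv_via_expansion` is a hypothetical name. The check compares it against an ordinary wrap-around cross-correlation built from `np.roll`.

```python
import numpy as np


def conv_via_expansion(img: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """img: (H, W), weight: (K, K) with odd K. Returns (H, W) at stride 1."""
    H, W = img.shape
    K = weight.shape[0]
    off = np.arange(K) - K // 2
    rows = (np.arange(H)[:, None] + off[None, :]) % H   # (H, K) wrapped row indices
    cols = (np.arange(W)[:, None] + off[None, :]) % W   # (W, K) wrapped column indices
    # step 1 (plays the role of F.grid_sample): expanded[h*K + i, w*K + j] = img[rows[h, i], cols[w, j]]
    expanded = img[rows.reshape(-1)[:, None], cols.reshape(-1)[None, :]]  # (H*K, W*K)
    # step 2 (plays the role of F.conv2d with stride=K): one K x K tile per output pixel
    tiles = expanded.reshape(H, K, W, K)                 # (H, Kh, W, Kw)
    return np.einsum('hiwj,ij->hw', tiles, weight)


rng = np.random.default_rng(0)
img = rng.standard_normal((5, 7))
weight = rng.standard_normal((3, 3))
# reference: a sum of shifted copies, i.e. plain cross-correlation with wrap-around
ref = sum(weight[i, j] * np.roll(img, (1 - i, 1 - j), axis=(0, 1))
          for i in range(3) for j in range(3))
print(np.allclose(conv_via_expansion(img, weight), ref))  # True
```

Replacing the regular offsets with spherical sampling positions changes only step 1, which is why the convolution itself can stay a standard `F.conv2d`.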
62 |
63 | ### class `SphereMaxPool2d`
64 |
65 | This is an implementation of spherical max pooling. This class inherits from `nn.MaxPool2d`, so you can use it in place of `nn.MaxPool2d`.
66 |
67 | ``` python
68 | pool = SphereMaxPool2d(kernel_size=3, stride=3)
69 | out = pool(torch.from_numpy(img).float())  # img: numpy array of shape (B, C, H, W)
70 | ```
71 |
72 | It also supports pooling windows of various shapes!
73 |
74 | Likewise, you can test this with the OmniMNIST dataset from [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch).
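Because each output pixel owns its own `Kh x Kw` tile after `F.grid_sample`, the final `F.max_pool2d(..., stride=self.kernel_size)` reduces to a block-wise max, which is also why non-square pooling windows need no special handling. A minimal `numpy` equivalent of that last step, as a sketch (`block_max` is a hypothetical helper, not part of this repo):

```python
import numpy as np


def block_max(expanded: np.ndarray, kh: int, kw: int) -> np.ndarray:
    """Max over non-overlapping kh x kw tiles of an (H*kh, W*kw) array -> (H, W)."""
    hk, wk = expanded.shape
    # (H, Kh, W, Kw): the same tile layout GridGenerator produces before flattening
    tiles = expanded.reshape(hk // kh, kh, wk // kw, kw)
    return tiles.max(axis=(1, 3))


expanded = np.arange(24).reshape(4, 6)  # a 2 x 2 arrangement of 2x3 tiles
print(block_max(expanded, kh=2, kw=3))
# [[ 8 11]
#  [20 23]]
```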
75 |
76 | ## Further Reading
77 |
78 | - Some formulas are inspired by Paul Bourke's work. [link](http://paulbourke.net/dome/dualfish2sphere/)
79 | - If you want to rotate an equirectangular image, see my implementation: [BlueHorn07/pyEquirectRotate](https://github.com/BlueHorn07/pyEquirectRotate)
80 | - For more omnidirectional-image Python code, I recommend this repository:
81 |   - [sunset1995/py360convert](https://github.com/sunset1995/py360convert)
82 |   - I've forked `py360convert` and added `p2e` (perspective to equirectangular): [BlueHorn07/py360convert](https://github.com/BlueHorn07/py360convert), [`p2e`](https://github.com/BlueHorn07/py360convert#p2ep_img-fov_deg-u_deg-v_deg-out_hw-in_rot_deg0)
83 |
--------------------------------------------------------------------------------
/demo/equirectangular.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueHorn07/sphereConv-pytorch/7c29d731d199f563f4c22e18c5d082030f53a046/demo/equirectangular.png
--------------------------------------------------------------------------------
/demo/equirectangular_earth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BlueHorn07/sphereConv-pytorch/7c29d731d199f563f4c22e18c5d082030f53a046/demo/equirectangular_earth.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | matplotlib
--------------------------------------------------------------------------------
/src/demo.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from spherenet import SphereMaxPool2d, SphereConv2d
3 |
4 | import numpy as np
5 | import torch
6 |
7 | if __name__ == '__main__':
8 |     """
9 |     This demo code is adapted from "https://github.com/ChiWeiHsiao/SphereNet-pytorch"
10 |     """
11 |
12 |     # SphereConv2d
13 |     cnn = SphereConv2d(3, 5, kernel_size=3, stride=1)
14 |     out = cnn(torch.randn(2, 3, 10, 10))
15 |     print('SphereConv2d(3, 5, 1) output shape: ', out.size())
16 |
17 |     # SphereMaxPool2d
18 |     h, w = 100, 200
19 |     img = np.ones([h, w, 3])
20 |     for r in range(h):
21 |         for c in range(w):
22 |             img[r, c, 0] = img[r, c, 0] - r/h
23 |             img[r, c, 1] = img[r, c, 1] - c/w
24 |     plt.imsave('demo_original.png', img)
25 |     img = img.transpose([2, 0, 1])
26 |     img = np.expand_dims(img, 0)  # (B, C, H, W)
27 |
28 |     # pool
29 |     pool = SphereMaxPool2d(kernel_size=3, stride=3)
30 |     out = pool(torch.from_numpy(img).float())
31 |
32 |     out = np.squeeze(out.numpy(), 0).transpose([1, 2, 0])
33 |     plt.imsave('demo_pool_3x3.png', np.clip(out, 0, 1))  # float RGB values must lie in [0, 1]
34 |
35 |
36 |
--------------------------------------------------------------------------------
/src/demo_maxPool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 |
5 | from spherenet import SphereMaxPool2d
6 |
7 | if __name__ == '__main__':
8 |     img = cv2.imread("../demo/equirectangular_earth.png")  # (H, W, 3), uint8, BGR
9 |     if img is None:
10 |         raise FileNotFoundError("image not found: run this script from the `src` directory")
11 |
12 |     img = img.transpose((2, 0, 1))  # (C, H, W)
13 |     img = np.expand_dims(img, 0)  # (B, C, H, W)
14 |
15 |     spherePool = SphereMaxPool2d(3, stride=3)
16 |     out = spherePool(torch.from_numpy(img).float())
17 |     out = np.squeeze(out.numpy(), 0).transpose((1, 2, 0))
18 |     # JPEG output needs 8-bit data
19 |     cv2.imwrite("sphere_maxPooled.jpg", np.clip(out, 0, 255).astype(np.uint8))
20 |
--------------------------------------------------------------------------------
/src/spherenet/GridGenerator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class GridGenerator:
5 |     def __init__(self, height: int, width: int, kernel_size, stride=1):
6 |         self.height = height
7 |         self.width = width
8 |         # accept an int or an (H, W) tuple, like nn.Conv2d does
9 |         self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)  # (Kh, Kw)
10 |         self.stride = (stride, stride) if isinstance(stride, int) else tuple(stride)  # (H, W)
11 |
12 |     def createSamplingPattern(self):
13 |         """
14 |         :return: (1, H*Kh, W*Kw, (Lat, Lon)) sampling pattern
15 |         """
16 |         kerX, kerY = self.createKernel()  # (Kh, Kw)
17 |
18 |         # create some values used to generate the lat/lon sampling pattern
19 |         rho = np.sqrt(kerX ** 2 + kerY ** 2)
20 |         Kh, Kw = self.kernel_size
21 |         # when rho is zero at the kernel center, some lat values explode to `nan`.
22 |         if Kh % 2 and Kw % 2:
23 |             rho[Kh // 2][Kw // 2] = 1e-8
24 |
25 |         nu = np.arctan(rho)
26 |         cos_nu = np.cos(nu)
27 |         sin_nu = np.sin(nu)
28 |
29 |         stride_h, stride_w = self.stride
30 |         h_range = np.arange(0, self.height, stride_h)
31 |         w_range = np.arange(0, self.width, stride_w)
32 |
33 |         lat_range = ((h_range / self.height) - 0.5) * np.pi
34 |         lon_range = ((w_range / self.width) - 0.5) * (2 * np.pi)
35 |
36 |         # generate latitude sampling pattern
37 |         lat = np.array([
38 |             np.arcsin(cos_nu * np.sin(_lat) + kerY * sin_nu * np.cos(_lat) / rho) for _lat in lat_range
39 |         ])  # (H, Kh, Kw)
40 |
41 |         lat = np.array([lat for _ in lon_range])  # (W, H, Kh, Kw)
42 |         lat = lat.transpose((1, 0, 2, 3))  # (H, W, Kh, Kw)
43 |
44 |         # generate longitude sampling pattern
45 |         lon = np.array([
46 |             np.arctan(kerX * sin_nu / (rho * np.cos(_lat) * cos_nu - kerY * np.sin(_lat) * sin_nu)) for _lat in lat_range
47 |         ])  # (H, Kh, Kw)
48 |
49 |         lon = np.array([lon + _lon for _lon in lon_range])  # (W, H, Kh, Kw)
50 |         lon = lon.transpose((1, 0, 2, 3))  # (H, W, Kh, Kw)
51 |
52 |         # (radian) -> (index of pixel); longitude wraps around the image sides
53 |         lat = (lat / np.pi + 0.5) * self.height
54 |         lon = ((lon / (2 * np.pi) + 0.5) * self.width) % self.width
55 |
56 |         LatLon = np.stack((lat, lon))  # (2, H, W, Kh, Kw) = ((lat, lon), H, W, Kh, Kw)
57 |         LatLon = LatLon.transpose((1, 3, 2, 4, 0))  # (H, Kh, W, Kw, 2) = (H, Kh, W, Kw, (lat, lon))
58 |
59 |         H, Kh, W, Kw, d = LatLon.shape
60 |         LatLon = LatLon.reshape((1, H * Kh, W * Kw, d))  # (1, H*Kh, W*Kw, 2)
61 |
62 |         return LatLon
63 |
64 |     def createKernel(self):
65 |         """
66 |         :return: (kerX, kerY) tangent-plane kernel pattern, each of shape (Kh, Kw)
67 |         """
68 |         Kh, Kw = self.kernel_size
69 |
70 |         delta_lat = np.pi / self.height
71 |         delta_lon = 2 * np.pi / self.width
72 |
73 |         range_x = np.arange(-(Kw // 2), Kw // 2 + 1)
74 |         if not Kw % 2:
75 |             range_x = np.delete(range_x, Kw // 2)
76 |
77 |         range_y = np.arange(-(Kh // 2), Kh // 2 + 1)
78 |         if not Kh % 2:
79 |             range_y = np.delete(range_y, Kh // 2)
80 |
81 |         # angular offset of every kernel tap, each of shape (Kh, Kw)
82 |         d_lon, d_lat = np.meshgrid(range_x * delta_lon, range_y * delta_lat)
83 |
84 |         # SphereNet tangent-plane offsets: x = tan(d_lon), y = tan(d_lat) / cos(d_lon)
85 |         kerX = np.tan(d_lon)
86 |         kerY = np.tan(d_lat) / np.cos(d_lon)
87 |
88 |         return kerX, kerY  # (Kh, Kw) each
89 |
--------------------------------------------------------------------------------
/src/spherenet/SphereConv2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | from .GridGenerator import GridGenerator
7 |
8 |
9 | class SphereConv2d(nn.Conv2d):
10 |     """
11 |     kernel_size: (H, W)
12 |     """
13 |
14 |     def __init__(self, in_channels: int, out_channels: int, kernel_size=(3, 3),
15 |                  stride=1, padding=0, dilation=1,
16 |                  groups: int = 1, bias: bool = True, padding_mode: str = 'zeros'):
17 |         super(SphereConv2d, self).__init__(
18 |             in_channels, out_channels, kernel_size,
19 |             stride, padding, dilation, groups, bias, padding_mode)
20 |         self.grid_shape = None
21 |         self.grid = None
22 |
23 |     def genSamplingPattern(self, h, w):
24 |         gridGenerator = GridGenerator(h, w, self.kernel_size, self.stride)
25 |         LonLatSamplingPattern = gridGenerator.createSamplingPattern()
26 |
27 |         # generate grid to use `F.grid_sample`
28 |         lat_grid = (LonLatSamplingPattern[:, :, :, 0] / h) * 2 - 1
29 |         lon_grid = (LonLatSamplingPattern[:, :, :, 1] / w) * 2 - 1
30 |
31 |         grid = np.stack((lon_grid, lat_grid), axis=-1)  # grid_sample expects (x, y) = (lon, lat)
32 |         with torch.no_grad():
33 |             self.grid = torch.FloatTensor(grid)
34 |             self.grid.requires_grad = False
35 |
36 |     def forward(self, x):
37 |         # Generate Sampling Pattern
38 |         B, C, H, W = x.shape
39 |
40 |         if (self.grid_shape is None) or (self.grid_shape != (H, W)):
41 |             self.grid_shape = (H, W)
42 |             self.genSamplingPattern(H, W)
43 |
44 |         with torch.no_grad():
45 |             grid = self.grid.repeat((B, 1, 1, 1)).to(device=x.device, dtype=x.dtype)  # (B, H*Kh, W*Kw, 2), same device/dtype as x
46 |             grid.requires_grad = False
47 |
48 |         x = F.grid_sample(x, grid, align_corners=True, mode='nearest')  # (B, in_c, H*Kh, W*Kw)
49 |
50 |         # self.weight -> (out_c, in_c / groups, Kh, Kw); stride = kernel size, so each output sees one Kh x Kw tile
51 |         x = F.conv2d(x, self.weight, self.bias, stride=self.kernel_size, groups=self.groups)
52 |
53 |         return x  # (B, out_c, H/stride_h, W/stride_w)
54 |
--------------------------------------------------------------------------------
/src/spherenet/SphereMaxPool2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from .GridGenerator import GridGenerator
7 |
8 |
9 | class SphereMaxPool2d(nn.MaxPool2d):
10 |     """
11 |     kernel_size: (H, W)
12 |     """
13 |
14 |     def __init__(self, kernel_size=(3, 3), stride=1, padding=0, dilation=1,
15 |                  return_indices: bool = False, ceil_mode: bool = False):
16 |         super(SphereMaxPool2d, self).__init__(
17 |             kernel_size, stride, padding, dilation, return_indices, ceil_mode)
18 |         if isinstance(self.kernel_size, int):
19 |             self.kernel_size = (self.kernel_size, self.kernel_size)
20 |         if isinstance(self.stride, int):  # also covers stride=None, which nn.MaxPool2d replaces with kernel_size
21 |             self.stride = (self.stride, self.stride)
22 |
23 |         self.grid_shape = None
24 |         self.grid = None
25 |
26 |     def genSamplingPattern(self, h, w):
27 |         gridGenerator = GridGenerator(h, w, self.kernel_size, self.stride)
28 |         LonLatSamplingPattern = gridGenerator.createSamplingPattern()
29 |
30 |         # generate grid to use `F.grid_sample`
31 |         lat_grid = (LonLatSamplingPattern[:, :, :, 0] / h) * 2 - 1
32 |         lon_grid = (LonLatSamplingPattern[:, :, :, 1] / w) * 2 - 1
33 |
34 |         grid = np.stack((lon_grid, lat_grid), axis=-1)  # grid_sample expects (x, y) = (lon, lat)
35 |
36 |         with torch.no_grad():
37 |             self.grid = torch.FloatTensor(grid)
38 |             self.grid.requires_grad = False
39 |
40 |     def forward(self, x):
41 |         # Generate Sampling Pattern
42 |         B, C, H, W = x.shape
43 |
44 |         if (self.grid_shape is None) or (self.grid_shape != (H, W)):
45 |             self.grid_shape = (H, W)
46 |             self.genSamplingPattern(H, W)
47 |
48 |         with torch.no_grad():
49 |             grid = self.grid.repeat((B, 1, 1, 1)).to(device=x.device, dtype=x.dtype)  # (B, H*Kh, W*Kw, 2), same device/dtype as x
50 |             grid.requires_grad = False
51 |
52 |         x = F.grid_sample(x, grid, align_corners=False, mode='bilinear')  # (B, C, H*Kh, W*Kw)
53 |         # stride = kernel size, so each output pixel pools exactly one Kh x Kw tile
54 |         x = F.max_pool2d(x, kernel_size=self.kernel_size, stride=self.kernel_size)
55 |
56 |         return x  # (B, C, H/stride_h, W/stride_w)
57 |
--------------------------------------------------------------------------------
/src/spherenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .GridGenerator import GridGenerator
2 | from .SphereConv2d import SphereConv2d
3 | from .SphereMaxPool2d import SphereMaxPool2d
4 |
--------------------------------------------------------------------------------