├── .gitignore ├── LICENSE ├── README.md ├── dataloader └── __init__.py ├── models ├── hexrunet.py └── unfold_nn.py ├── notebook └── icosahedron_grid.md ├── train.py └── utils ├── geometry_helper.py └── projection_helper.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *.lock 4 | *.pth 5 | 6 | .ipynb_checkpoints/ 7 | .idea/ 8 | 9 | checkpoints/ 10 | raw_data/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 matsuren 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## HexRUNet PyTorch 2 | An unofficial PyTorch implementation of the ICCV 2019 paper ["Orientation-Aware Semantic Segmentation on Icosahedron Spheres"](http://openaccess.thecvf.com/content_ICCV_2019/html/Zhang_Orientation-Aware_Semantic_Segmentation_on_Icosahedron_Spheres_ICCV_2019_paper.html). Only HexRUNet-C for Omni-MNIST is implemented right now. 3 | 4 | ## Requirements 5 | Python 3.6 or later is required. 6 | 7 | Python libraries: 8 | - PyTorch >= 1.3.1 9 | - torchvision 10 | - tensorboard 11 | - tqdm 12 | - [igl](https://libigl.github.io/libigl-python-bindings/) 13 | 14 | 15 | ## Training 16 | Run the following command to train with randomly rotated training data and evaluate with randomly rotated test data. 17 | ```bash 18 | python train.py --train_rot --test_rot 19 | ``` 20 | You can change the parameters via command-line arguments (see the `-h` option for details). 21 | 22 | ## Results 23 | Here are the results of this repository. The accuracy of the last epoch (30th epoch) is reported. 24 | 25 | Omni-MNIST HexRUNet-C accuracy (%) 26 | || N/N | N/R | R/R | 27 | ----|----|----|---- 28 | |This repository | 99.15 | 69.62 | 98.36 29 | |Paper| 99.45 | 29.84 | 97.05 30 | 31 | - `N/N`: Non-rotated training and test data 32 | - `N/R`: Non-rotated training data and randomly rotated test data 33 | - `R/R`: Randomly rotated training and test data 34 | 35 | As can be observed here, the `N/R` accuracy of this repository is much higher than the one reported in the original paper. I guess this is because the implementations of the image-to-sphere projection and of the rotation are different (my implementation of the projection is based on [ChiWeiHsiao/SphereNet-pytorch](https://github.com/ChiWeiHsiao/SphereNet-pytorch)). 
# dataloader/__init__.py
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

from utils.geometry_helper import get_icosahedron, get_unfold_imgcoord
from utils.projection_helper import img2ERP, erp2sphere


# --------------------
# Dataset
class UnfoldIcoDataset(Dataset):
    """Unfolded Icosahedron dataset.

    Wraps an ``(image, label)`` dataset. For every item the image is turned
    into an equirectangular (ERP) image with ``img2ERP``, sampled at the
    icosahedron vertices with ``erp2sphere`` and finally gathered into five
    unfolded images through the index maps of ``get_unfold_imgcoord``.

    Parameters
    ----------
    dataset :
        Wrapped dataset; ``dataset[idx]`` must return ``(image, label)`` and
        the dataset must expose ``classes`` if the ``classes`` property is used.
    erp_shape : tuple
        Shape handed to ``img2ERP`` as ``outshape`` (e.g. ``(60, 120)``).
    level : int
        Icosahedron subdivision level.
    rotate : bool
        If True, a new random rotation is drawn for every item.
    transform : callable, optional
        Called with the finished sample dict; its return value is returned.
    debug : bool
        If True, the intermediate ``'erp_img'`` and ``'ico_img'`` are added
        to the sample.

    Examples
    --------
    >>> from torchvision import datasets
    >>> root_dataset = datasets.MNIST(root='raw_data', train=True, download=True)
    >>> dataset = UnfoldIcoDataset(root_dataset, erp_shape=(60, 120), level=4, debug=True)
    >>> sample = dataset[0]

    >>> # show equirectangular image
    >>> import matplotlib.pyplot as plt
    >>> plt.imshow(sample['erp_img'])

    >>> # show image on icosahedron
    >>> from meshplot import plot
    >>> plot(dataset.vertices, dataset.faces, sample['ico_img'])

    >>> # show unfolded images
    >>> fig, ax = plt.subplots(1, 5)
    >>> _ = [ax[i].imshow(sample[(i+3)%5]) for i in range(5)]
    """

    def __init__(self, dataset, erp_shape, level, rotate=False, transform=None, debug=False):
        self.dataset = dataset
        self.transform = transform
        self.erp_shape = erp_shape
        self.level = level
        self.rotate = rotate

        # Geometry depends only on the level, so it is computed once here.
        self.vertices, self.faces = get_icosahedron(level)

        # img_coord[i] is an index array into the per-vertex values that
        # yields the i-th unfolded image.
        self.img_coord = get_unfold_imgcoord(level)

        self.debug = debug

    def __len__(self):
        """Return the number of items of the wrapped dataset."""
        return len(self.dataset)

    def get_erp_image(self, idx, v_rot=0, h_rot=0, erp_shape=None):
        """Return the ERP image of item ``idx``.

        ``v_rot`` / ``h_rot`` are forwarded to ``img2ERP`` unchanged;
        ``erp_shape`` falls back to the shape given at construction time.
        """
        if erp_shape is None:
            erp_shape = self.erp_shape
        img = np.array(self.dataset[idx][0])
        erp_img = img2ERP(img, v_rot=v_rot, h_rot=h_rot, outshape=erp_shape)
        return erp_img

    @property
    def classes(self):
        """Class names of the wrapped dataset."""
        return self.dataset.classes

    def __getitem__(self, idx):
        """Build the sample dict for item ``idx``.

        Keys: ``'label'``, the integers ``0``..``4`` (unfolded images) and,
        when ``debug`` is set, ``'erp_img'`` and ``'ico_img'``.
        """
        sample = {}

        sample['label'] = self.dataset[idx][1]

        if self.rotate:
            # presumably degrees (full horizontal turn, half vertical turn)
            # -- confirm against img2ERP
            h_rot = np.random.uniform(-180, 180)
            v_rot = np.random.uniform(-90, 90)
        else:
            v_rot = 0
            h_rot = 0

        erp_img = self.get_erp_image(idx, v_rot=v_rot, h_rot=h_rot)
        ico_img = erp2sphere(erp_img, self.vertices)

        # unfolded images
        for i in range(5):
            sample[i] = ico_img[self.img_coord[i]]

        # debug
        if self.debug:
            sample['erp_img'] = erp_img
            sample['ico_img'] = ico_img

        if self.transform:
            sample = self.transform(sample)

        return sample


# --------------------
# Custom transform
class ToTensor(object):
    """Apply ``torchvision.transforms.ToTensor`` to the five unfolded images.

    The sample dict is modified in place and returned; other keys are left
    untouched.
    """

    def __init__(self):
        self.ToTensor = transforms.ToTensor()

    def __call__(self, sample):
        for i in range(5):
            sample[i] = self.ToTensor(sample[i])
        return sample


class Normalize(object):
    """Apply ``torchvision.transforms.Normalize(mean, std)`` to the five
    unfolded images.

    The sample dict is modified in place and returned; other keys are left
    untouched.
    """

    def __init__(self, mean, std):
        self.normalizer = transforms.Normalize(mean, std)

    def __call__(self, sample):
        for i in range(5):
            sample[i] = self.normalizer(sample[i])
        return sample


if __name__ == '__main__':
    # Compute the mean / std of the unfolded training images (values scaled
    # to [0, 1]); the result is meant to be fed to ``Normalize``.
    from torchvision import datasets
    from tqdm import tqdm

    level = 4
    erp_shape = (60, 120)
    root_dataset = datasets.MNIST(root='../raw_data', train=True, download=True)
    tmpset = UnfoldIcoDataset(root_dataset, erp_shape, level)
    imgs = []
    for it in tqdm(tmpset):
        unfold_imgs = [it[i] / 255 for i in range(5)]
        unfold_imgs = np.concatenate(unfold_imgs, axis=1)
        imgs.append(unfold_imgs)

    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the dtype the
    # old alias resolved to.
    imgs = np.array(imgs, dtype=np.float64)
    print(f'total len:{len(imgs)}, mean:{imgs.mean():.4f}, std:{imgs.std():.4f}')
    # total len:60000, mean:0.0645, std:0.2116
# models/hexrunet.py
import torch
import torch.nn as nn

from models.unfold_nn import HexConv2d, UnfoldReLU, UnfoldBatchNorm2d, UnfoldMaxPool2d, UnfoldConv2d


class ResBlock(nn.Module):
    """Residual bottleneck block operating on five unfolded charts.

    Main branch: 1x1 conv + 2x max-pool, hexagonal conv at ``level - 1``,
    1x1 conv (each followed by batch norm). Shortcut branch: 1x1 conv +
    2x max-pool + batch norm. Both branches are summed chart by chart and
    passed through a ReLU.
    """

    def __init__(self, in_channels, out_channels, level, bias=False):
        super().__init__()
        # The pooling in conv1 halves the resolution, hence ``level - 1``
        # for the hexagonal convolution that follows.
        self.conv1 = nn.Sequential(
            UnfoldConv2d(in_channels, out_channels, 1, bias=bias),
            UnfoldMaxPool2d(2),
            UnfoldBatchNorm2d(out_channels),
            UnfoldReLU(),
        )
        self.conv2 = nn.Sequential(
            HexConv2d(out_channels, out_channels, level - 1, 1, bias=bias),
            UnfoldBatchNorm2d(out_channels),
            UnfoldReLU(),
        )
        self.conv3 = nn.Sequential(
            UnfoldConv2d(out_channels, out_channels, 1, bias=bias),
            UnfoldBatchNorm2d(out_channels),
        )

        self.downsample = nn.Sequential(
            UnfoldConv2d(in_channels, out_channels, 1, bias=bias),
            UnfoldMaxPool2d(2),
            UnfoldBatchNorm2d(out_channels),
        )

        self.relu = UnfoldReLU()

    def forward(self, x):
        """Return the block output as a list of five tensors."""
        main = self.conv3(self.conv2(self.conv1(x)))
        shortcut = self.downsample(x)

        for chart in range(5):
            main[chart] += shortcut[chart]

        return self.relu(main)


class HexRUNet_C(nn.Module):
    """ HexRUNet-C proposed in [1].

    A hexagonal convolution stem at level 4, two residual blocks
    (16 -> 64 -> 256 channels) and a linear classifier with 10 outputs.

    References
    ----------
    [1] orientation-aware semantic segmentation on icosahedron spheres, ICCV2019

    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            HexConv2d(in_channels, 16, level=4, stride=1),
            UnfoldReLU(),
            UnfoldBatchNorm2d(16),
        )
        self.block1 = ResBlock(16, 64, level=4)
        self.block2 = ResBlock(64, 256, level=3)
        self.fc = nn.Linear(256, 10)

    def forward(self, batch):
        """Return class scores for ``batch`` (the five unfolded charts)."""
        feats = self.block2(self.block1(self.conv1(batch)))

        # Global max pooling over every position of every chart.
        # NOTE: the original author was not sure this matches the paper.
        b, c, h, w = feats[0].shape
        joined = torch.cat([feats[i] for i in range(5)], dim=3)
        pooled = joined.view(b, c, -1).max(dim=2)[0]

        return self.fc(pooled)


# models/unfold_nn.py
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn import init
from torch.nn.modules.utils import _pair

from utils.geometry_helper import unfold_padding, get_weight_alpha


class HexConv2d(nn.Module):
    """ Hexagonal convolution proposed in [1].

    Seven trainable taps per (out, in) channel pair are arranged into two
    3x3 kernels. Both kernels are applied to every padded chart and the two
    responses are blended with the precomputed weight ``alpha``.

    References
    ----------
    [1] orientation-aware semantic segmentation on icosahedron spheres, ICCV2019

    """

    def __init__(self, in_channels, out_channels, level, stride=1, bias=False):
        super().__init__()
        if stride not in (1, 2):
            raise ValueError("stride must be 1 or 2")
        # A stride of 2 produces output at the next coarser level.
        outlevel = level - 1 if stride == 2 else level

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = _pair(stride)

        # weight: 7 taps per channel pair
        taps = torch.Tensor(out_channels, in_channels, 7)
        nn.init.kaiming_uniform_(taps, mode='fan_in', nonlinearity='relu')
        self.weight = Parameter(taps)

        # Tap index stored in each used cell of the 3x3 kernels (row-major);
        # the top-left and bottom-right cells are zero-filled later:
        #   kernel 1            kernel 2
        #   [-, 5, 4]           [-, 0, 5]
        #   [0, 6, 3]           [1, 6, 4]
        #   [1, 2, -]           [2, 3, -]
        first_layout = [5, 4, 0, 6, 3, 1, 2]
        second_layout = [0, 5, 1, 6, 4, 2, 3]
        self.register_buffer('w1_index', torch.tensor(first_layout))
        self.register_buffer('w2_index', torch.tensor(second_layout))

        # Blending weight between the two kernels at the output level.
        self.register_buffer('alpha', torch.from_numpy(get_weight_alpha(outlevel)).float())

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_channels))
            bound = 1 / math.sqrt(in_channels * 9)
            init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, kernel_size=(hexagonal)'
                f', stride={self.stride}')

    def get_hex_weight(self):
        """Return the two ``(out, in, 3, 3)`` kernels built from the 7 taps."""
        out_ch, in_ch = self.weight.shape[:2]
        kernels = []
        for layout in (self.w1_index, self.w2_index):
            ordered = torch.index_select(self.weight, -1, layout)
            # One zero in front and one behind turn 7 taps into 9 cells.
            kernels.append(F.pad(ordered, (1, 1)).view(out_ch, in_ch, 3, 3))
        return kernels[0], kernels[1]

    def forward(self, input):
        """Convolve the five charts and return a list of five tensors."""
        padded = unfold_padding(input)
        kernel_a, kernel_b = self.get_hex_weight()

        blended = []
        for i in range(5):
            resp_a = F.conv2d(padded[i], kernel_a, self.bias, self.stride)
            resp_b = F.conv2d(padded[i], kernel_b, self.bias, self.stride)
            blended.append(self.alpha * resp_a + (1 - self.alpha) * resp_b)
        return blended


class UnfoldReLU(nn.ReLU):
    """ReLU applied to each of the five charts."""

    def forward(self, x):
        relu = super().forward
        return [relu(x[i]) for i in range(5)]


class UnfoldBatchNorm2d(nn.BatchNorm3d):
    """Batch norm shared by the five charts.

    The charts are stacked along a new third axis and normalised by
    ``BatchNorm3d`` in one call, then split back into a list of five.
    """

    def forward(self, x):
        normed = super().forward(torch.stack(x, dim=2))
        return [normed[:, :, i] for i in range(5)]


class UnfoldConv2d(nn.Conv2d):
    """``nn.Conv2d`` applied to each of the five charts with shared weights."""

    def forward(self, x):
        conv = super().forward
        return [conv(x[i]) for i in range(5)]


class UnfoldMaxPool2d(nn.MaxPool2d):
    """``nn.MaxPool2d`` applied to each of the five charts."""

    def forward(self, x):
        pool = super().forward
        return [pool(x[i]) for i in range(5)]


class UnfoldUpsample(nn.Module):
    """Bilinear upsampling of the five charts.

    Each chart is padded with ``unfold_padding(..., only_NE=True)``,
    interpolated to ``(2 * h - 1, 2 * w - 1)`` of the padded size and then
    cropped by dropping the first row and the last column.
    """

    def __init__(self):
        super().__init__()
        self.up = partial(F.interpolate, mode='bilinear', align_corners=True)

    def forward(self, x):
        x = unfold_padding(x, only_NE=True)
        # Size of the padded charts.
        h, w = x[0].shape[-2:]
        target = (2 * h - 1, 2 * w - 1)

        for i in range(5):
            x[i] = self.up(x[i], target)[..., 1:, :-1]

        return x


class UnfoldAvgPool2d(nn.AvgPool2d):
    """``nn.AvgPool2d`` applied to each of the five charts."""

    def forward(self, x):
        pool = super().forward
        return [pool(x[i]) for i in range(5)]
--------------------------------------------------------------------------------
/notebook/icosahedron_grid.md:
--------------------------------------------------------------------------------
```python
import open3d as o3d
```

```python
import igl
import numpy as np
import math
from numpy.linalg import norm
import matplotlib.pyplot as plt
%matplotlib inline

from meshplot import plot, subplot, interact
from functools import partial

from functools import lru_cache
import torch
```

# Icosahedron
$L_{edge} \sin{\frac{2\pi}{5}} = r$

```python

```

```python
def get_base_icosahedron():
    """Return ``(vertices, faces)`` of the base icosahedron.

    ``vertices`` is a (12, 3) array scaled so that every vertex has radius
    1.0 and rotated so that v[0] = (0, -1, 0) and v[1] lies on the yz-plane;
    ``faces`` is a (20, 3) array of vertex indices.
    """
    t = (1.0 + 5.0 ** .5) / 2.0
    corners = np.array([
        [-1, t, 0], [1, t, 0], [0, 1, t], [-t, 0, 1],
        [-t, 0, -1], [0, 1, -t], [t, 0, -1], [t, 0, 1],
        [0, -1, t], [-1, -t, 0], [0, -1, -t], [1, -t, 0],
    ])
    triangles = np.array([
        [0, 2, 1], [0, 3, 2], [0, 4, 3], [0, 5, 4], [0, 1, 5],
        [1, 7, 6], [1, 2, 7], [2, 8, 7], [2, 3, 8], [3, 9, 8],
        [3, 4, 9], [4, 10, 9], [4, 5, 10], [5, 6, 10], [5, 1, 6],
        [6, 7, 11], [7, 8, 11], [8, 9, 11], [9, 10, 11], [10, 6, 11],
    ])

    # The raw coordinates have edge length 2, i.e. radius 2*sin(2*pi/5);
    # dividing by that value puts every vertex on the unit sphere.
    corners = corners / (np.sin(2*np.pi/5)*2)

    # Orthonormal frame with -v[0] as the y axis and the x axis
    # perpendicular to both v[0] and v[1].
    y_axis = -corners[0]
    x_axis = np.cross(y_axis, corners[1])
    x_axis /= np.linalg.norm(x_axis)
    z_axis = np.cross(x_axis, y_axis)
    frame = np.stack([x_axis, y_axis, z_axis])
    return corners.dot(frame.T), triangles

def subdivision(v,f,level=1):
    """Subdivide the mesh ``level`` times, re-projecting onto the unit sphere."""
    done = 0
    while done < level:
        # split every face, then push the new vertices back to radius 1
        v, f = igl.upsample(v, f)
        v /= np.linalg.norm(v, axis=1)[:,np.newaxis]
        done += 1
    return v,f

@lru_cache(maxsize=12)
def get_icosahedron(level=0):
    """Return the cached icosahedron of the given subdivision level."""
    if level != 0:
        # build on top of the (cached) coarser level
        coarse_v, coarse_f = get_icosahedron(level-1)
        return subdivision(coarse_v, coarse_f, 1)
    return get_base_icosahedron()
```

```python

``` 71 | 72 | ```python 73 | level = 1 74 | v, f = get_icosahedron(level) 75 | 76 | len(v) 77 | ``` 78 | 79 | ```python 80 | plot(v, f) 81 | ``` 82 | 83 | ```python 84 | def drawAxis(T=np.eye(4) ,scale=0.3, colors=['r', 'g', 'b']): 85 | tvec = T[:3,3] 86 | R = T[:3,:3] 87 | start = tvec.flatten()[np.newaxis].repeat(3, axis=0) 88 | end = start + scale * R.T 89 | for s, e, c in zip(start, end, colors): 90 | ax.plot([s[0], e[0]], [s[1], e[1]], [s[2], e[2]], c=c) 91 | 92 | %matplotlib notebook 93 | ``` 94 | 95 | ```python 96 | # This import registers the 3D projection, but is otherwise unused. 97 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import 98 | 99 | fig = plt.figure() 100 | ax = fig.add_subplot(111, projection='3d') 101 | 102 | for i,it in enumerate(v): 103 | ax.scatter(it[0], it[1], it[2], marker='d') 104 | ax.text(it[0], it[1], it[2], f'{i}') 105 | 106 | # Origin 107 | drawAxis(scale=1) 108 | 109 | # drawAxis(T,scale=1) 110 | ax.set_xlabel('X Label') 111 | ax.set_ylabel('Y Label') 112 | ax.set_zlabel('Z Label') 113 | 114 | plt.show() 115 | 116 | ``` 117 | 118 | ## Unfolding 119 | 120 | ```python 121 | %matplotlib inline 122 | ``` 123 | 124 | ```python 125 | def get_base_unfold(): 126 | v, f = get_base_icosahedron() 127 | unfold_v = {i:[] for i in range(12)} 128 | 129 | # edge length 130 | l = 1/np.sin(2*np.pi/5) 131 | # height 132 | h = 3**0.5*l/2 133 | 134 | # v0 135 | for i in range(5): 136 | unfold_v[0].append([i*l, 0]) 137 | 138 | # v1 139 | for _ in range(5): 140 | unfold_v[1].append([-0.5*l, h]) 141 | unfold_v[1][1] = [-0.5*l + 5*l, h] 142 | unfold_v[1][4] = [-0.5*l + 5*l, h] 143 | 144 | # v2-v5 145 | for i in range(2, 6): 146 | for _ in range(5): 147 | unfold_v[i].append([(0.5 + i - 2)*l, h]) 148 | 149 | # v6 150 | for _ in range(5): 151 | unfold_v[6].append([-l, 2*h]) 152 | unfold_v[6][1] = [-l + 5*l, 2*h] 153 | unfold_v[6][2] = [-l + 5*l, 2*h] 154 | unfold_v[6][4] = [-l + 5*l, 2*h] 155 | 156 | # v7-v10 157 | for i in range(7, 11): 
158 | for _ in range(5): 159 | unfold_v[i].append([(i - 7)*l, 2*h]) 160 | 161 | # v11 162 | for i in range(5): 163 | unfold_v[11].append([(-0.5 + i)*l, 3*h]) 164 | 165 | # to numpy 166 | for i in range(len(unfold_v)): 167 | unfold_v[i] = np.array(unfold_v[i]) 168 | return unfold_v, f 169 | ``` 170 | 171 | ```python 172 | class UnfoldVertex(object): 173 | def __init__(self, unfold_v): 174 | self.unfold_v = unfold_v 175 | self.reset() 176 | 177 | def __getitem__(self, item): 178 | pos = self.unfold_v[item][self.cnt[item]] 179 | self.cnt[item] += 1 180 | return pos 181 | 182 | def reset(self): 183 | self.cnt = {key:0 for key in self.unfold_v.keys()} 184 | 185 | 186 | class VertexIdxManager(object): 187 | def __init__(self, unfold_v): 188 | self.reg_v = {} 189 | self.next_v_index = len(unfold_v) 190 | 191 | def get_next(self, a, b): 192 | if a>b: 193 | a,b = b,a 194 | key = f'{a},{b}' 195 | if key not in self.reg_v: 196 | self.reg_v[key] = self.next_v_index 197 | self.next_v_index += 1 198 | return self.reg_v[key] 199 | 200 | from copy import copy 201 | def unfold_subdivision(unfold_v, faces): 202 | v_idx_manager = VertexIdxManager(unfold_v) 203 | 204 | new_faces = [] 205 | new_unfold = copy(unfold_v) 206 | v_obj = UnfoldVertex(unfold_v) 207 | for (a, b, c) in faces: 208 | a_pos = v_obj[a] 209 | b_pos = v_obj[b] 210 | c_pos = v_obj[c] 211 | 212 | new_a= v_idx_manager.get_next(a, b) 213 | new_b= v_idx_manager.get_next(b, c) 214 | new_c= v_idx_manager.get_next(c, a) 215 | 216 | new_a_pos = (a_pos+b_pos)/2 217 | new_b_pos = (b_pos+c_pos)/2 218 | new_c_pos = (c_pos+a_pos)/2 219 | 220 | # new faces 221 | new_faces.append([a, new_a, new_c]) 222 | new_faces.append([b, new_b, new_a]) 223 | new_faces.append([new_a, new_b, new_c]) 224 | new_faces.append([new_b, c, new_c]) 225 | 226 | # new vertex 227 | indices = [new_a, new_b, new_c] 228 | poses = [new_a_pos, new_b_pos, new_c_pos] 229 | for (idx, pos) in zip(indices, poses): 230 | if idx not in new_unfold: 231 | new_unfold[idx] 
= [] 232 | for _ in range(3): 233 | new_unfold[idx].append(pos) 234 | return new_unfold, new_faces 235 | 236 | 237 | @lru_cache(maxsize=12) 238 | def get_unfold_icosahedron(level=0): 239 | if level == 0: 240 | unfold_v, f = get_base_unfold() 241 | return unfold_v, f 242 | # require subdivision 243 | unfold_v, f = get_unfold_icosahedron(level-1) 244 | unfold_v, f = unfold_subdivision(unfold_v, f) 245 | return unfold_v, f 246 | ``` 247 | 248 | ```python 249 | base_unfold_v, base_f = get_unfold_icosahedron(0) 250 | level = 1 251 | new_unfold, new_faces = get_unfold_icosahedron(level) 252 | vertices, _ = get_icosahedron(level) 253 | ``` 254 | 255 | ```python 256 | # draw base icosahedron 257 | v_obj = UnfoldVertex(base_unfold_v) # vertex object 258 | v_obj.reset() 259 | fig = plt.figure() 260 | ax = fig.add_subplot(111) 261 | ax.set_aspect(aspect=1) 262 | for i, it in base_unfold_v.items(): 263 | for xy in it: 264 | ax.plot(xy[0], -xy[1], 'bo') 265 | ax.text(xy[0]+0.05, -xy[1]+0.05, f'{i}') 266 | # draw lines 267 | for (a,b,c) in base_f: 268 | lines = [] 269 | a_pos = v_obj[a] 270 | b_pos = v_obj[b] 271 | c_pos = v_obj[c] 272 | lines += [a_pos, b_pos, c_pos, a_pos] 273 | lines = np.array(lines) 274 | plt.plot(lines[:, 0], -lines[:, 1], 'g') 275 | ``` 276 | 277 | ```python 278 | fig = plt.figure(figsize=(14,14)) 279 | ax = fig.add_subplot(111) 280 | ax.set_aspect(aspect=1) 281 | for i, it in new_unfold.items(): 282 | for xy in it: 283 | ax.plot(xy[0], -xy[1], 'bo', markersize=3) 284 | ax.text(xy[0]+0.02, -xy[1]+0.02, f'{i}', fontsize=9) 285 | 286 | # draw lines 287 | v_obj = UnfoldVertex(base_unfold_v) 288 | for (a,b,c) in base_f: 289 | lines = [] 290 | a_pos = v_obj[a] 291 | b_pos = v_obj[b] 292 | c_pos = v_obj[c] 293 | lines += [a_pos, b_pos, c_pos, a_pos] 294 | lines = np.array(lines) 295 | plt.plot(lines[:, 0], -lines[:, 1], 'g') 296 | ``` 297 | 298 | ## Distored grid 299 | 300 | ```python 301 | 302 | def distort_grid(unfold_v): 303 | np_round = partial(np.round, 
decimals=9) 304 | 305 | # calculate transform matrix 306 | new_x = unfold_v[2][0]-unfold_v[0][0] 307 | edge_len = np.linalg.norm(new_x) 308 | new_x /= edge_len 309 | new_y = np.cross([0,0,1], np.append(new_x, 0))[:2] 310 | R = np.stack([new_x, new_y]) 311 | 312 | a = unfold_v[2][0]-unfold_v[0][0] 313 | b = unfold_v[1][0]-unfold_v[0][0] 314 | skew = np.eye(2) 315 | skew[0, 1] = -1/np.tan(np.arccos(a.dot(b)/norm(a)/norm(b))) 316 | skew[0]/=norm(skew[0]) 317 | 318 | T = skew.dot(R) 319 | # scale adjust 320 | scale = np.linalg.det(skew)*edge_len 321 | T /=scale 322 | 323 | # to numpy array for efficient computation 324 | # np_round to alleviate numerical error when sorting 325 | distort_unfold = copy(unfold_v) 326 | five_neighbor = [distort_unfold[i] for i in range(12)] 327 | five_neighbor = np.array(five_neighbor) 328 | # Transform 329 | five_neighbor = np_round(five_neighbor.dot(T.T)) 330 | 331 | # the same procedure for six_neighbor if len(unfold_v) > 12 332 | if len(unfold_v)>12: 333 | six_neighbor = [distort_unfold[i] for i in range(12, len(unfold_v))] 334 | six_neighbor = np.array(six_neighbor) 335 | six_neighbor = np_round(six_neighbor.dot(T.T)) 336 | 337 | # to original shape 338 | distort_unfold = {} 339 | cnt = 0 340 | for it in five_neighbor: 341 | distort_unfold[cnt] = it 342 | cnt+=1 343 | if len(unfold_v)>12: 344 | for it in six_neighbor: 345 | distort_unfold[cnt] = it 346 | cnt+=1 347 | return distort_unfold 348 | ``` 349 | 350 | ```python 351 | draw_base = distort_grid(base_unfold_v) 352 | distort_unfold = distort_grid(new_unfold) 353 | 354 | ``` 355 | 356 | ```python 357 | fig = plt.figure(figsize=(14,14)) 358 | ax = fig.add_subplot(111) 359 | ax.set_aspect(aspect=1) 360 | for i, it in distort_unfold.items(): 361 | for xy in it: 362 | ax.plot(xy[0], -xy[1], 'bo', markersize=3) 363 | ax.text(xy[0]+0.02, -xy[1]+0.02, f'{i}', fontsize=9) 364 | 365 | # draw lines 366 | v_obj = UnfoldVertex(draw_base) 367 | for (a,b,c) in base_f: 368 | lines = [] 369 | 
a_pos = v_obj[a] 370 | b_pos = v_obj[b] 371 | c_pos = v_obj[c] 372 | lines += [a_pos, b_pos, c_pos, a_pos] 373 | lines = np.array(lines) 374 | plt.plot(lines[:, 0], -lines[:, 1], 'g') 375 | ``` 376 | 377 | ## To image coordinate 378 | 379 | ```python 380 | import math 381 | 382 | def get_rect_idxs(x, y): 383 | rect_idxs = [] 384 | for i in range(5): 385 | x_min = i 386 | x_max = x_min+1 387 | y_min = -i 388 | y_max = y_min+2 389 | if x_min<=x<=x_max and y_min<=y<=y_max: 390 | rect_idxs.append(i) 391 | return rect_idxs 392 | 393 | 394 | def distort_unfold_to_imgcoord(distort_unfold, drop_NE=True): 395 | """ 396 | Parameters 397 | ---------- 398 | distort_unfold : 399 | distorted unfold 400 | drop_NE : bool 401 | drop north and east as in [1] 402 | 403 | References 404 | ---------- 405 | [1] orientation-aware semantic segmentation on icosahedron spheres, ICCV2019 406 | 407 | """ 408 | vertex_num = len(distort_unfold) 409 | level = round(math.log((vertex_num-2)//10, 4)) 410 | 411 | width = 2**level+1 412 | height = 2*width - 1 413 | 414 | unfold_pts_set = set() # (vertex_id, x, y) 415 | 416 | # remove duplicate 417 | for key, arr in distort_unfold.items(): 418 | for val in arr: 419 | unfold_pts_set.add((key, val[0], val[1])) 420 | 421 | # sort 422 | unfold_pts_set = sorted(unfold_pts_set, key=lambda x : (x[1], x[2])) 423 | 424 | # to image coorinate 425 | img_coord = {} 426 | for (vertex_id, x, y) in unfold_pts_set: 427 | rect_idxs = get_rect_idxs(x, y) 428 | for key in rect_idxs: 429 | if key not in img_coord: 430 | img_coord[key] = [] 431 | img_coord[key].append(vertex_id) 432 | 433 | # to numpy 434 | for key in img_coord: 435 | img_coord[key] = np.array(img_coord[key]).reshape(width, height).T 436 | 437 | if drop_NE: 438 | # orientation-aware semantic segmentation on icosahedron spheres form 439 | for key in img_coord: 440 | img_coord[key] = img_coord[key][1:,:-1] 441 | 442 | return img_coord 443 | ``` 444 | 445 | ```python 446 | img_coord = 
distort_unfold_to_imgcoord(distort_unfold) 447 | ``` 448 | 449 | ## unfold_padding 450 | 451 | 452 | 453 | ```python 454 | 455 | import torch 456 | import torch.nn.functional as F 457 | def unfold_padding(arr, cval=0, only_NE=False): 458 | """ 459 | Parameters 460 | ---------- 461 | arr : dict {0-4: array} 462 | array 463 | cval : int or float 464 | initial padding value 465 | 466 | only_NE : bool 467 | drop north and east as in [1] 468 | 469 | References 470 | ---------- 471 | [1] orientation-aware semantic segmentation on icosahedron spheres, ICCV2019 472 | """ 473 | is_ndarray = False 474 | if isinstance(arr[0], np.ndarray): 475 | is_ndarray = True 476 | # to torch tensor 477 | arr = copy(arr) 478 | for i in range(5): 479 | arr[i] = torch.from_numpy(arr[i].copy()) 480 | if arr[i].ndim == 3: 481 | # H x W x C -> C x H x W 482 | arr[i] = arr[i].permute(2, 0, 1) 483 | elif arr[i].ndim == 2: 484 | # H x W -> 1 x H x W 485 | arr[i] = arr[i].unsqueeze(0) 486 | # Add batch dimension 487 | arr[i] = arr[i].unsqueeze(0) 488 | 489 | h, w = arr[0].size(2), arr[0].size(3) 490 | 491 | arr_with_pad = [] 492 | pad_w = (0,1,1,0) if only_NE else (1,1,1,1) 493 | for key in range(5): 494 | arr_with_pad.append(F.pad(arr[key], pad_w, value=cval)) 495 | 496 | for key in range(5): 497 | tgt = (key + 1) % 5 498 | # north 499 | arr_with_pad[key][..., 0, pad_w[0]+1:] = arr[tgt][..., :w, 0] 500 | # east 501 | arr_with_pad[key][..., 1:w+1, -1] = arr[tgt][..., w:, 0] 502 | arr_with_pad[key][..., w+1:-1-pad_w[3], -1] = arr[tgt][..., -1, 1:] 503 | 504 | if not only_NE: 505 | tgt = (key - 1) % 5 506 | # some Indices look like shifted but if you check the connectivility, it's fine 507 | # west 508 | arr_with_pad[key][..., 1:w+1,0] = arr[tgt][..., 0, :] 509 | arr_with_pad[key][..., w+1:-1,0] = arr[tgt][..., :w, -1] 510 | # south 511 | arr_with_pad[key][..., -1,:-1] = arr[tgt][..., w-1:, -1] 512 | 513 | if is_ndarray: 514 | for i in range(5): 515 | arr_with_pad[i] = arr_with_pad[i].permute(0, 2, 
3, 1).squeeze().numpy() 516 | 517 | return arr_with_pad 518 | ``` 519 | 520 | ```python 521 | tmp = unfold_padding(img_coord) 522 | tmp[0] 523 | ``` 524 | 525 | ## Calculate weight 526 | Five elements are symmetric, therefore calculation for only one element is enough 527 | 528 | ```python 529 | def calculate_weight(img_coord, vertices): 530 | """ calculate weight alpha 531 | 532 | References 533 | ---------- 534 | [1] orientation-aware semantic segmentation on icosahedron spheres, ICCV2019 535 | """ 536 | arr_with_pad = unfold_padding(img_coord, cval=0, only_NE=False) 537 | 538 | key = 0 539 | vi_idx = arr_with_pad[key][1:-1,1:-1] 540 | vn1_idx = arr_with_pad[key][1:-1, 0:-2] 541 | vn6_idx = arr_with_pad[key][0:-2,1:-1] 542 | 543 | vi = vertices[vi_idx] 544 | vn1 = vertices[vn1_idx] 545 | vn6 = vertices[vn6_idx] 546 | 547 | # vector from vi to north pole 548 | north_pole = np.array([0, -1, 0]) 549 | to_north_pole = north_pole - vi 550 | 551 | # unit vector from vi to neighbor 552 | vn1vi = vn1 - vi 553 | vn1vi /= norm(vn1vi, axis=-1, keepdims=True) 554 | vn6vi = vn6 - vi 555 | vn6vi /= norm(vn6vi, axis=-1, keepdims=True) 556 | 557 | # face normal 558 | face_n = np.cross(vn1vi, vn6vi) 559 | face_n /= norm(face_n, axis=-1, keepdims=True) 560 | 561 | # to north pole on tangent plane 562 | proj_vec = np.sum(to_north_pole*face_n, axis=-1, keepdims=True)*face_n 563 | np_tangent_plane = to_north_pole - proj_vec 564 | np_tangent_plane /=norm(np_tangent_plane, axis=-1, keepdims=True) 565 | 566 | # calculate cost 567 | psi = np.arccos(np.sum(vn1vi*np_tangent_plane, axis=-1)) 568 | tmp_vals = np.sum(vn6vi*np_tangent_plane, axis=-1) 569 | tmp_vals[0, 0] = np.clip(tmp_vals[0, 0], -1, 1) # make sure value is between -1 and 1 570 | phi = np.arccos(tmp_vals) 571 | weight = phi/(psi+phi) 572 | 573 | return weight 574 | ``` 575 | 576 | ```python 577 | weight = calculate_weight(img_coord, vertices) 578 | ``` 579 | 580 | ```python 581 | %matplotlib inline 582 | ``` 583 | 584 | 
```python 585 | plt.figure() 586 | plt.imshow(weight) 587 | ``` 588 | 589 | ```python 590 | img_coord[0] 591 | ``` 592 | 593 | # Project image 594 | 595 | ```python 596 | import matplotlib.pyplot as plt 597 | from torchvision import datasets 598 | from projection_helper import img2ERP, erp2sphere 599 | 600 | outshape = (60, 120) 601 | print("getting mnist data") 602 | trainset = datasets.MNIST(root='raw_data', train=True, download=True) 603 | trainset = datasets.CIFAR10(root='raw_data', train=True, download=True) 604 | ``` 605 | 606 | ## Equirectangular projection 607 | 608 | ```python 609 | idx = 5 610 | h_rot = np.random.uniform(-180, 180) 611 | v_rot = np.random.uniform(-90, 90) 612 | h_rot = 0 613 | v_rot = 0 614 | print(f'Rotate horizontal:{h_rot:.1f} deg, vertical {v_rot:.1f} deg') 615 | img = np.array(trainset[idx][0]) 616 | label_str = trainset.classes[trainset[idx][1]] 617 | print(label_str) 618 | erp_img = img2ERP(img, v_rot=v_rot, h_rot=h_rot, outshape=outshape) 619 | ``` 620 | 621 | ```python 622 | plt.imshow(img) 623 | ``` 624 | 625 | ```python 626 | plt.imshow(erp_img) 627 | ``` 628 | 629 | ```python 630 | 631 | ``` 632 | 633 | ## Project on icosahedron 634 | 635 | ```python 636 | @lru_cache(maxsize=12) 637 | def get_unfold_imgcoord(level=0): 638 | unfold_v, new_faces = get_unfold_icosahedron(level) 639 | distort_unfold = distort_grid(unfold_v) 640 | img_coord = distort_unfold_to_imgcoord(distort_unfold) 641 | return img_coord 642 | 643 | @lru_cache(maxsize=12) 644 | def get_weight_alpha(level=0): 645 | v, f = get_icosahedron(level) 646 | img_coord = get_unfold_imgcoord(level) 647 | weight = calculate_weight(img_coord, v) 648 | return weight 649 | ``` 650 | 651 | ```python 652 | # icosahedron 653 | level = 5 654 | v, f = get_icosahedron(level) 655 | 656 | # unfold 657 | new_unfold, new_faces = get_unfold_icosahedron(level) 658 | 659 | img_coord = get_unfold_imgcoord(level) 660 | alpha = get_weight_alpha(level) 661 | 662 | # unfold_v, new_faces = 
get_unfold_icosahedron(level) 663 | # distort_unfold = distort_grid(unfold_v) 664 | # img_coord = distort_unfold_to_imgcoord(distort_unfold) 665 | # weight = calculate_weight(img_coord, v) 666 | ``` 667 | 668 | ```python 669 | # img_coord[0] 670 | ``` 671 | 672 | ```python 673 | plt.imshow(alpha) 674 | ``` 675 | 676 | ```python 677 | out = erp2sphere(erp_img, v)/255 678 | # plot(v, f, out) 679 | if out.ndim == 1: 680 | color = out[:, np.newaxis].repeat(3, axis=1) 681 | else: 682 | color = out 683 | ``` 684 | 685 | ```python 686 | # Open3d 687 | pcd = o3d.geometry.PointCloud() 688 | pcd.points = o3d.utility.Vector3dVector(v) 689 | pcd.colors = o3d.utility.Vector3dVector(color) 690 | origin = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6) 691 | o3d.visualization.draw_geometries([pcd, origin]) 692 | ``` 693 | 694 | ## Unfold projection 695 | 696 | ```python 697 | proj_imgs = [color[img_coord[i]] for i in range(5)] 698 | ``` 699 | 700 | ```python 701 | fig, ax = plt.subplots(1, 5) 702 | 703 | for i in range(5): 704 | proj_id = (i+3)%5 705 | ax[i].set_title(f'{proj_id}') 706 | ax[i].imshow(proj_imgs[proj_id]) 707 | ax[i].set_yticks([], []) 708 | ``` 709 | 710 | ```python 711 | proj_imgs_with_pad = unfold_padding(proj_imgs) 712 | fig, ax = plt.subplots(1, 5) 713 | 714 | for i in range(5): 715 | proj_id = (i+3)%5 716 | ax[i].set_title(f'{proj_id}') 717 | ax[i].imshow(proj_imgs_with_pad[proj_id], vmin=0, vmax=1.0) 718 | ax[i].set_yticks([], []) 719 | ``` 720 | 721 | ```python 722 | proj_imgs[0].shape 723 | ``` 724 | 725 | # Plot icosphere points on ERP 726 | 727 | ```python 728 | from projection_helper import xyz2uv, uv2img_idx 729 | ``` 730 | 731 | ```python 732 | outshape = (80, 160) 733 | dst = np.zeros(outshape) 734 | plt.imshow(dst) 735 | ``` 736 | 737 | ```python 738 | v, f = get_base_icosahedron() 739 | v, f = subdivision(v, f, 3) 740 | ``` 741 | 742 | ```python 743 | uv = xyz2uv(v) 744 | img_idx = uv2img_idx(uv, dst) 745 | img_idx = np.round(img_idx) 
746 | for y, x in zip(*img_idx): 747 | dst[int(y), int(x)] = 1 748 | plt.imshow(dst) 749 | ``` 750 | 751 | ```python 752 | 753 | ``` 754 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from collections import OrderedDict 5 | from datetime import datetime 6 | from os.path import join 7 | 8 | import numpy as np 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torchvision import datasets 15 | from torchvision import transforms 16 | from tqdm import tqdm 17 | 18 | from dataloader import UnfoldIcoDataset, ToTensor, Normalize 19 | from models.hexrunet import HexRUNet_C 20 | 21 | parser = argparse.ArgumentParser(description='Training for OmniMNIST', 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | 24 | parser.add_argument('--epochs', default=30, type=int, metavar='N', help='total epochs') 25 | parser.add_argument('--pretrained', default=None, metavar='PATH', 26 | help='path to pre-trained model') 27 | parser.add_argument('--level', default=4, type=int, metavar='N', help='max level for icosahedron') 28 | parser.add_argument('-b', '--batch-size', default=15, type=int, metavar='N', help='mini-batch size') 29 | parser.add_argument('-j', '--workers', default=6, type=int, metavar='N', help='number of data loading workers') 30 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='initial learning rate') 31 | parser.add_argument('--arch', default='hexrunet', type=str, help='architecture name for log folder') 32 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', help='tensorboard log interval') 33 | parser.add_argument('--train_rot', action='store_true', help='rotate image for 
def main():
    """Train HexRUNet-C on Omni-MNIST and evaluate it after every epoch.

    Options are read from the module-level ``parser``. Each epoch runs one
    optimisation pass over the training set, writes a checkpoint into the log
    folder, then runs one pass over the test set. Losses and accuracies are
    logged to TensorBoard. With ``--pretrained`` the model, optimizer and
    scheduler states are restored and training resumes from the stored epoch.
    """
    args = parser.parse_args()
    print('Arguments:')
    print(json.dumps(vars(args), indent=1))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type != 'cpu':
        cudnn.benchmark = True
    print("device:", device)

    print('=> setting data loader')
    erp_shape = (60, 120)
    transform = transforms.Compose([ToTensor(), Normalize((0.0645,), (0.2116,))])
    trainset = UnfoldIcoDataset(datasets.MNIST(root='raw_data', train=True, download=True),
                                erp_shape, args.level, rotate=args.train_rot, transform=transform)
    testset = UnfoldIcoDataset(datasets.MNIST(root='raw_data', train=False, download=True),
                               erp_shape, args.level, rotate=args.test_rot, transform=transform)
    train_loader = DataLoader(trainset, args.batch_size, shuffle=True, num_workers=args.workers)
    test_loader = DataLoader(testset, args.batch_size, shuffle=False, num_workers=args.workers)

    print('=> setting model')
    start_epoch = 0
    model = HexRUNet_C(1)
    total_params = sum(param.numel() for param in model.parameters())
    print(f"Total model parameters: {total_params:,}.")
    model = model.to(device)

    # Loss function
    print('=> setting loss function')
    criterion = nn.CrossEntropyLoss()

    # setup solver scheduler
    print('=> setting optimizer')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print('=> setting scheduler')
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    if args.pretrained:
        # map_location lets a checkpoint written on a GPU machine be resumed
        # on a CPU-only one (and vice versa) instead of failing while
        # deserialising the stored tensors.
        checkpoint = torch.load(args.pretrained, map_location=device)
        print("=> using pre-trained weights")
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print("=> Resume training from epoch {}".format(start_epoch))

    timestamp = datetime.now().strftime("%m%d-%H%M")
    log_folder = join('checkpoints', f'{args.arch}_{timestamp}')
    print(f'=> create log folder: {log_folder}')
    os.makedirs(log_folder, exist_ok=True)
    with open(join(log_folder, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=1)
    writer = SummaryWriter(log_dir=log_folder)
    writer.add_text('args', json.dumps(vars(args), indent=1))

    # try/finally so the TensorBoard event file is flushed and closed even
    # when training is interrupted or raises.
    try:
        for epoch in range(start_epoch, args.epochs):

            # --------------------------
            # training
            # --------------------------
            model.train()
            losses = []
            pbar = tqdm(train_loader)
            total = 0
            correct = 0
            mode = 'train'
            for idx, batch in enumerate(pbar):
                # to cuda
                for key in batch.keys():
                    batch[key] = batch[key].to(device)
                outputs = model(batch)
                labels = batch['label']

                # Loss function
                loss = criterion(outputs, labels)
                losses.append(loss.item())

                # update parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # accuracy
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # update progress bar
                display = OrderedDict(mode=f'{mode}', epoch=f"{epoch:>2}", loss=f"{losses[-1]:.4f}")
                pbar.set_postfix(display)

                # tensorboard log
                if idx % args.log_interval == 0:
                    niter = epoch * len(train_loader) + idx
                    writer.add_scalar(f'{mode}/loss', loss.item(), niter)

            # End of one epoch
            scheduler.step()
            ave_loss = sum(losses) / len(losses)
            ave_acc = 100 * correct / total
            writer.add_scalar(f'{mode}/loss_ave', ave_loss, epoch)
            writer.add_scalar(f'{mode}/acc_ave', ave_acc, epoch)

            print(f"Epoch:{epoch}, Train Loss average:{ave_loss:.4f}, Accuracy average:{ave_acc:.2f}")

            # 'epoch' stores the next epoch to run, which is what resuming reads back
            save_data = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'ave_loss': ave_loss,
            }
            torch.save(save_data, join(log_folder, f'checkpoints_{epoch}.pth'))

            # --------------------------
            # evaluation
            # --------------------------
            model.eval()
            losses = []
            pbar = tqdm(test_loader)
            total = 0
            correct = 0
            mode = 'test'
            with torch.no_grad():
                for idx, batch in enumerate(pbar):
                    # to cuda
                    for key in batch.keys():
                        batch[key] = batch[key].to(device)
                    outputs = model(batch)
                    labels = batch['label']

                    # Loss function
                    loss = criterion(outputs, labels)
                    losses.append(loss.item())

                    # accuracy
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    # update progress bar
                    display = OrderedDict(mode=f'{mode}', epoch=f"{epoch:>2}", loss=f"{losses[-1]:.4f}")
                    pbar.set_postfix(display)

                    # tensorboard log
                    if idx % args.log_interval == 0:
                        niter = epoch * len(test_loader) + idx
                        writer.add_scalar(f'{mode}/loss', loss.item(), niter)

            # End of one epoch
            ave_loss = sum(losses) / len(losses)
            ave_acc = 100 * correct / total
            writer.add_scalar(f'{mode}/loss_ave', ave_loss, epoch)
            writer.add_scalar(f'{mode}/acc_ave', ave_acc, epoch)

            print(f"Epoch:{epoch}, Test Loss average:{ave_loss:.4f}, Accuracy average:{ave_acc:.2f}")
    finally:
        writer.close()
    print("Finish")
def get_base_icosahedron():
    """Build the level-0 icosahedron on the unit sphere.

    Returns
    -------
    vertices : ndarray, shape (12, 3)
        unit-length vertex positions, rotated so that vertex 0 is (0, -1, 0)
        and vertex 1 lies on the yz-plane
    faces : ndarray, shape (20, 3)
        vertex indices of each triangle
    """
    golden = (1.0 + 5.0 ** .5) / 2.0
    raw_vertices = [
        [-1, golden, 0], [1, golden, 0], [0, 1, golden], [-golden, 0, 1],
        [-golden, 0, -1], [0, 1, -golden], [golden, 0, -1], [golden, 0, 1],
        [0, -1, golden], [-1, -golden, 0], [0, -1, -golden], [1, -golden, 0],
    ]
    raw_faces = [
        [0, 2, 1], [0, 3, 2], [0, 4, 3], [0, 5, 4], [0, 1, 5],
        [1, 7, 6], [1, 2, 7], [2, 8, 7], [2, 3, 8], [3, 9, 8],
        [3, 4, 9], [4, 10, 9], [4, 5, 10], [5, 6, 10], [5, 1, 6],
        [6, 7, 11], [7, 8, 11], [8, 9, 11], [9, 10, 11], [10, 6, 11],
    ]

    # dividing by this constant gives every vertex radius 1.0
    radius = np.sin(2 * np.pi / 5) * 2
    vertices = np.array(raw_vertices) / radius
    faces = np.array(raw_faces)

    # Orthonormal frame that sends v[0] to (0, -1, 0) and v[1] onto the yz-plane
    y_axis = -vertices[0]
    x_axis = np.cross(y_axis, vertices[1])
    x_axis = x_axis / np.linalg.norm(x_axis)
    z_axis = np.cross(x_axis, y_axis)
    rotation = np.stack([x_axis, y_axis, z_axis])

    return vertices.dot(rotation.T), faces
class VertexIdxManager(object):
    """Hand out indices for the vertices created on edges during subdivision.

    Every undirected edge ``(a, b)`` gets exactly one new index; asking again
    for the same edge, in either orientation, returns the index issued first.
    """

    def __init__(self, unfold_v):
        # new ids continue right after the ids that already exist
        self.next_v_index = len(unfold_v)
        # maps 'small,large' edge key -> issued vertex index
        self.reg_v = {}

    def get_next(self, a, b):
        """Return the new-vertex index of edge ``a``-``b``, allocating it on first use."""
        lo, hi = (b, a) if a > b else (a, b)
        key = f'{lo},{hi}'
        try:
            return self.reg_v[key]
        except KeyError:
            issued = self.reg_v[key] = self.next_v_index
            self.next_v_index += 1
            return issued
def get_rect_idxs(x, y):
    """Return the indices of the five chart rectangles that contain (x, y).

    Rectangle ``i`` spans ``i <= x <= i + 1`` and ``-i <= y <= 2 - i`` with
    both borders included, so a point lying on a shared border is reported
    for every rectangle it touches.
    """
    return [
        i for i in range(5)
        if i <= x <= i + 1 and -i <= y <= 2 - i
    ]
def unfold_padding(arr, cval=0, only_NE=False):
    """Pad each of the five unfolded charts with values taken from its neighbours.

    Parameters
    ----------
    arr : dict {0-4: array}
        the five charts, indexable by 0-4. numpy arrays (H x W or H x W x C)
        are converted to torch tensors internally; any other input is used as
        is and must already support ``.size(2)`` / ``.size(3)``
        (presumably N x C x H x W tensors - TODO confirm against callers)
    cval : int or float
        initial padding value; it remains in the padded cells that are not
        overwritten with neighbour values

    only_NE : bool
        if True, pad only the north (first row) and east (last column) borders
        as in [1]; if False, pad all four borders

    Returns
    -------
    arr_with_pad : list
        five padded charts. If the input charts were numpy arrays the result
        is converted back to numpy (channels last, size-1 axes squeezed out)

    References
    ----------
    [1] orientation-aware semantic segmentation on icosahedron spheres, ICCV2019
    """
    is_ndarray = False
    if isinstance(arr[0], np.ndarray):
        is_ndarray = True
        # to torch tensor (shallow-copy the container so the caller's one is not rebound)
        arr = copy(arr)
        for i in range(5):
            arr[i] = torch.from_numpy(arr[i].copy())
            if arr[i].ndim == 3:
                # H x W x C -> C x H x W
                arr[i] = arr[i].permute(2, 0, 1)
            elif arr[i].ndim == 2:
                # H x W -> 1 x H x W
                arr[i] = arr[i].unsqueeze(0)
            # Add batch dimension
            arr[i] = arr[i].unsqueeze(0)

    h, w = arr[0].size(2), arr[0].size(3)

    arr_with_pad = []
    # F.pad order for the last two dims: (left, right, top, bottom)
    pad_w = (0, 1, 1, 0) if only_NE else (1, 1, 1, 1)
    for key in range(5):
        arr_with_pad.append(F.pad(arr[key], pad_w, value=cval))

    # NOTE(review): the slice lengths below only match when h == 2 * w - confirm
    for key in range(5):
        # north and east borders are filled from the next chart
        tgt = (key + 1) % 5
        # north
        arr_with_pad[key][..., 0, pad_w[0] + 1:] = arr[tgt][..., :w, 0]
        # east
        arr_with_pad[key][..., 1:w + 1, -1] = arr[tgt][..., w:, 0]
        arr_with_pad[key][..., w + 1:-1 - pad_w[3], -1] = arr[tgt][..., -1, 1:]

        if not only_NE:
            # west and south borders are filled from the previous chart
            tgt = (key - 1) % 5
            # some indices look shifted, but they follow the connectivity of the charts
            # west
            arr_with_pad[key][..., 1:w + 1, 0] = arr[tgt][..., 0, :]
            arr_with_pad[key][..., w + 1:-1, 0] = arr[tgt][..., :w, -1]
            # south
            arr_with_pad[key][..., -1, :-1] = arr[tgt][..., w - 1:, -1]

    if is_ndarray:
        # back to channels-last numpy, dropping the batch (and any size-1) axis
        for i in range(5):
            arr_with_pad[i] = arr_with_pad[i].permute(0, 2, 3, 1).squeeze().numpy()

    return arr_with_pad
@lru_cache(maxsize=360)
def genuv(h, w, v_rot=0):
    """Return the (u, v) angles of every pixel centre of an h x w ERP grid.

    Parameters
    ----------
    h, w : int
        grid height and width in pixels
    v_rot : float
        vertical rotation in radians, within [-pi/2, pi/2]; applied as a
        rotation about the x axis of the unit-sphere coordinates

    Returns
    -------
    uv : ndarray, shape (h, w, 2)
        ``uv[..., 0]`` is u (spanning -pi..pi along the width) and
        ``uv[..., 1]`` is v (spanning -pi/2..pi/2 along the height).
        The array is read-only: results are cached, so every call with the
        same arguments receives this very object.
    """
    assert -np.pi / 2 <= v_rot <= np.pi / 2
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    u = (u + 0.5) * 2 * np.pi / w - np.pi
    v = (v + 0.5) * np.pi / h - np.pi / 2
    uv = np.stack([u, v], axis=-1)

    if v_rot != 0:
        # rotate on the sphere, then go back to angles
        xyz = uv2xyz(uv.astype(np.float64))
        cos_r, sin_r = np.cos(v_rot), np.sin(v_rot)
        xyz_rot = xyz.copy()
        xyz_rot[..., 1] = cos_r * xyz[..., 1] + sin_r * xyz[..., 2]
        xyz_rot[..., 2] = -sin_r * xyz[..., 1] + cos_r * xyz[..., 2]
        uv = xyz2uv(xyz_rot)

    # lru_cache returns this same object on every hit; freezing it turns an
    # accidental in-place edit by a caller into an immediate error instead of
    # silently corrupting all later results.
    uv.flags.writeable = False
    return uv
def remap(img, img_idx, cval=(0, 0, 0), method="linear"):
    """Sample ``img`` at the fractional pixel positions ``img_idx``.

    Parameters
    ----------
    img : ndarray
        grayscale (H x W) or color (H x W x C) image
    img_idx : ndarray
        sampling coordinates passed to ``map_coordinates``; the first axis
        holds (row, column) and the remaining axes give the output shape
    cval : sequence
        fill value per channel for positions outside the image; only
        ``cval[0]`` is used for grayscale input
    method : str
        "linear" for linear interpolation, anything else selects nearest

    Returns
    -------
    ndarray of shape ``img_idx.shape[1:]`` (grayscale) or
    ``img_idx.shape[1:] + (C,)`` with the dtype of ``img`` (color)

    Raises
    ------
    AssertionError
        if ``img`` is neither 2- nor 3-dimensional
    """
    # interpolation order: 1 = linear, 0 = nearest
    order = 1 if method == "linear" else 0

    if img.ndim == 2:
        # grayscale
        return map_coordinates(img, img_idx, order=order, cval=cval[0])

    if img.ndim == 3:
        # color: interpolate every channel separately
        n_channels = img.shape[2]
        out = np.zeros([*img_idx.shape[1:], n_channels], dtype=img.dtype)
        for ch in range(n_channels):
            out[..., ch] = map_coordinates(img[..., ch], img_idx, order=order, cval=cval[ch])
        return out

    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with -O.
    raise AssertionError('img.ndim should be 2 (grayscale) or 3 (color)')
def erp2sphere(erp_img, V, method="linear"):
    """Sample an equirectangular image at points given on the sphere.

    Parameters
    ----------
    erp_img: equirectangular projection image
    V: array of spherical coordinates of shape (n_vertex, 3)
    method: interpolation method. "linear" or "nearest"

    Returns
    -------
    the image values looked up at every point of ``V``
    """
    # xyz -> (u, v) angles -> fractional pixel positions -> interpolated values
    sample_idx = uv2img_idx(xyz2uv(V), erp_img)
    return remap(erp_img, sample_idx, method=method)