├── s2cnn
│   ├── utils
│   │   ├── __init__.py
│   │   ├── cuda.py
│   │   ├── complex.py
│   │   └── decorator.py
│   ├── soft
│   │   ├── calculus.pdf
│   │   ├── __init__.py
│   │   ├── so3_integrate.py
│   │   ├── s2_conv.py
│   │   ├── so3_rotation.py
│   │   ├── so3_conv.py
│   │   ├── s2_fft.py
│   │   └── so3_fft.py
│   ├── __init__.py
│   ├── s2_grid.py
│   ├── so3_grid.py
│   ├── s2_ft.py
│   ├── so3_ft.py
│   ├── so3_mm.py
│   └── s2_mm.py
├── .gitignore
├── examples
│   ├── equivariance_plot
│   │   ├── fig.jpeg
│   │   ├── earth128.jpg
│   │   └── main.py
│   ├── shrec17
│   │   ├── .gitignore
│   │   ├── model.py
│   │   ├── model_original.py
│   │   ├── README.md
│   │   ├── test.py
│   │   ├── train.py
│   │   └── dataset.py
│   ├── molecules
│   │   ├── run_all_experiments.sh
│   │   ├── README.md
│   │   ├── baseline_model.py
│   │   ├── utils.py
│   │   ├── s2cnn_model.py
│   │   ├── datagen.py
│   │   └── run_experiment.py
│   ├── fast_ft
│   │   └── main.py
│   ├── equivariance_error
│   │   └── main.py
│   └── mnist
│       ├── README.md
│       ├── run_classic.py
│       ├── run.py
│       └── gendata.py
├── setup.py
├── LICENSE
├── FAQ.md
├── tests
│   └── so3_fft.py
└── README.md

/s2cnn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | *.eggs
3 | __pycache__
4 | build
5 | dist
6 | *cache
7 | 
--------------------------------------------------------------------------------
/s2cnn/soft/calculus.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonkhler/s2cnn/HEAD/s2cnn/soft/calculus.pdf
--------------------------------------------------------------------------------
/examples/equivariance_plot/fig.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonkhler/s2cnn/HEAD/examples/equivariance_plot/fig.jpeg
--------------------------------------------------------------------------------
/examples/equivariance_plot/earth128.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jonkhler/s2cnn/HEAD/examples/equivariance_plot/earth128.jpg
--------------------------------------------------------------------------------
/examples/shrec17/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | evaluator
3 | *.ipynb
4 | *.zip
5 | *.pkl
6 | test_perturbed
7 | val_perturbed
8 | test_normal
9 | val_normal
10 | 
--------------------------------------------------------------------------------
/s2cnn/soft/__init__.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101,W0401
2 | from .s2_conv import S2Convolution
3 | from .so3_conv import SO3Convolution
4 | from .so3_integrate import so3_integrate
5 | from .so3_rotation import so3_rotation
6 | 
--------------------------------------------------------------------------------
/examples/molecules/run_all_experiments.sh:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | runid=0;
3 | for run in 0 1 2 3 4;
4 | do
5 |     for strat in 0 1 2 3 4;
6 |     do
7 |         echo "starting run $run for strat $strat"
8 |         python3 run_experiment.py --test_strat ${strat} > logs/default_settings_strat_${strat}_run_${run}.txt
9 |     done;
10 | done;
11 | 
12 | 
--------------------------------------------------------------------------------
/s2cnn/__init__.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101,W0401
2 | from .s2_ft import s2_rft
3 | from .so3_ft import so3_rft
4 | from .s2_grid import s2_near_identity_grid, s2_equatorial_grid, s2_soft_grid
5 | from .so3_grid import so3_near_identity_grid, so3_equatorial_grid, so3_soft_grid
6 | from .s2_mm import s2_mm
7 | from .so3_mm import so3_mm
8 | from .soft import *
9 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #pylint: disable=C
2 | import os
3 | from setuptools import setup, find_packages
4 | 
5 | setup(
6 |     name='s2cnn',
7 |     version="1.0.0",
8 |     author="Mario Geiger, Taco Cohen, Jonas Koehler",
9 |     description=("SO(3) equivariant CNNs for PyTorch."),
10 |     license="MIT",
11 |     keywords="so3 equivariant cnn pytorch",
12 |     url="https://github.com/AMLab-Amsterdam/s2cnn",
13 |     long_description=open(os.path.join(os.path.dirname(__file__), "README.md"), encoding='utf-8').read(),
14 |     packages=find_packages(exclude=["build"]),
15 | )
16 | 
--------------------------------------------------------------------------------
/examples/molecules/README.md:
--------------------------------------------------------------------------------
1 | # Molecule Example
2 | 
3 | ## Get dataset
4 | 
5 | Download the original QM7 Matlab file
6 | 
7 | ```bash
8 | wget http://quantum-machine.org/data/qm7.mat
9 | ```
10 | 
11 | Run the preprocessing script
12 | ```bash
13 | python3 datagen.py
14 | ```
15 | 
16 | ## Run experiments
17 | 
18 | ```bash
19 | ./run_all_experiments.sh
20 | ```
21 | ### Remark about results
22 | 
23 | This version is not the exact architecture explained in the paper but a much simpler one, which uses only a fraction of the parameters, runs faster, is more stable and produces much better results. When run correctly, it should produce an RMSE in the ~5 regime.
24 | 
25 | If you have any questions about this experiment feel free to contact jonas (at) argmin (dot) xyz.
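
For reference, here is a minimal sketch of what the downloaded `qm7.mat` contains. The key names below follow the public QM7 release from quantum-machine.org and are an assumption here — verify them against your own download before relying on them. Also note that `run_all_experiments.sh` pipes its output into a `logs/` directory, which has to exist.

```python
# Minimal sketch: inspect the raw QM7 file before running datagen.py.
# Assumes the key layout of the public QM7 release (quantum-machine.org).
import scipy.io

qm7 = scipy.io.loadmat("qm7.mat")
print(qm7["X"].shape)  # Coulomb matrices, one 23x23 matrix per molecule
print(qm7["Z"].shape)  # nuclear charges, up to 23 atoms per molecule
print(qm7["R"].shape)  # Cartesian coordinates of the atoms
print(qm7["T"].shape)  # atomization energies (the regression targets)
print(qm7["P"].shape)  # the 5 predefined cross-validation splits ("strats")
```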
26 | 
--------------------------------------------------------------------------------
/s2cnn/utils/cuda.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101
2 | from collections import namedtuple
3 | from cupy.cuda import function  # pylint: disable=E0401
4 | from pynvrtc.compiler import Program  # pylint: disable=E0401
5 | 
6 | 
7 | CUDA_NUM_THREADS = 1024
8 | CUDA_MAX_GRID_DIM = 2**16 - 1
9 | 
10 | 
11 | def get_blocks(n, num_threads):
12 |     n_per_instance = (n + num_threads * CUDA_MAX_GRID_DIM - 1) // (num_threads * CUDA_MAX_GRID_DIM)
13 |     return (n + num_threads * n_per_instance - 1) // (num_threads * n_per_instance)
14 | 
15 | 
16 | Stream = namedtuple('Stream', ['ptr'])
17 | 
18 | 
19 | def compile_kernel(kernel, filename, function_name):
20 |     program = Program(kernel, filename)
21 |     ptx = program.compile()
22 | 
23 |     m = function.Module()
24 |     m.load(bytes(ptx.encode()))
25 | 
26 |     f = m.get_function(function_name)
27 |     return f
28 | 
--------------------------------------------------------------------------------
/examples/fast_ft/main.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C,R,E1101,E1102,W0621
2 | '''
3 | Compare so3_ft with so3_fft
4 | '''
5 | import torch
6 | 
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 | 
9 | b_in, b_out = 6, 6  # bandwidth
10 | # random input data to be Fourier transformed
11 | x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, dtype=torch.float, device=device)  # [beta, alpha, gamma]
12 | 
13 | 
14 | # Fast version
15 | from s2cnn.soft.so3_fft import so3_rfft
16 | 
17 | y1 = so3_rfft(x, b_out=b_out)
18 | 
19 | 
20 | # Equivalent version but using the naive implementation
21 | from s2cnn import so3_rft, so3_soft_grid
22 | import lie_learn.spaces.S3 as S3
23 | 
24 | # so3_ft computes a non-weighted Fourier transform
25 | weights = torch.tensor(S3.quadrature_weights(b_in), dtype=torch.float, device=device)
26 | x = torch.einsum("bac,b->bac", (x, weights))
27 | 
28 | y2 = so3_rft(x.view(-1), b_out, so3_soft_grid(b_in))
29 | 
30 | 
31 | # Compare values
32 | assert (y1 - y2).abs().max().item() < 1e-4 * y1.abs().mean().item()
33 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Taco Cohen, Mario Geiger, Jonas Köhler
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/s2cnn/soft/so3_integrate.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101
2 | import torch
3 | from functools import lru_cache
4 | from s2cnn.utils.decorator import show_running
5 | 
6 | 
7 | def so3_integrate(x):
8 |     """
9 |     Integrate a signal on SO(3) using the Haar measure
10 | 
11 |     :param x: [..., beta, alpha, gamma] (..., 2b, 2b, 2b)
12 |     :return y: [...] (...)
13 |     """
14 |     assert x.size(-1) == x.size(-2)
15 |     assert x.size(-2) == x.size(-3)
16 | 
17 |     b = x.size(-1) // 2
18 | 
19 |     w = _setup_so3_integrate(b, device_type=x.device.type, device_index=x.device.index)  # [beta]
20 | 
21 |     x = torch.sum(x, dim=-1).squeeze(-1)  # [..., beta, alpha]
22 |     x = torch.sum(x, dim=-1).squeeze(-1)  # [..., beta]
23 | 
24 |     sz = x.size()
25 |     x = x.view(-1, 2 * b)
26 |     w = w.view(2 * b, 1)
27 |     x = torch.mm(x, w).squeeze(-1)
28 |     x = x.view(*sz[:-1])
29 |     return x
30 | 
31 | 
32 | @lru_cache(maxsize=32)
33 | @show_running
34 | def _setup_so3_integrate(b, device_type, device_index):
35 |     import lie_learn.spaces.S3 as S3
36 | 
37 |     return torch.tensor(S3.quadrature_weights(b), dtype=torch.float32, device=torch.device(device_type, device_index))  # (2b) [beta]  # pylint: disable=E1102
38 | 
--------------------------------------------------------------------------------
/s2cnn/utils/complex.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C,R,E1101
2 | import torch
3 | 
4 | 
5 | def as_complex(x):
6 |     """
7 |     In pytorch, a complex array is represented as a real array with an extra length-2 axis at the end.
8 |     This function takes a real-valued array x and adds a complex axis where the real part is set to x and the imaginary part is set to 0.
9 | """ 10 | imaginary = torch.zeros_like(x) 11 | z = torch.stack((x, imaginary), dim=x.ndimension()) 12 | return z 13 | 14 | 15 | def complex_mm(x, y, conj_x=False, conj_y=False): 16 | ''' 17 | :param x: [i, k, complex] (M, K, 2) 18 | :param y: [k, j, complex] (K, N, 2) 19 | :return: [i, j, complex] (M, N, 2) 20 | ''' 21 | xr = x[:, :, 0] 22 | xi = x[:, :, 1] 23 | 24 | yr = y[:, :, 0] 25 | yi = y[:, :, 1] 26 | 27 | if not conj_x and not conj_y: 28 | zr = torch.mm(xr, yr) - torch.mm(xi, yi) 29 | zi = torch.mm(xr, yi) + torch.mm(xi, yr) 30 | if conj_x and not conj_y: 31 | zr = torch.mm(xr, yr) + torch.mm(xi, yi) 32 | zi = torch.mm(xr, yi) - torch.mm(xi, yr) 33 | if not conj_x and conj_y: 34 | zr = torch.mm(xr, yr) + torch.mm(xi, yi) 35 | zi = torch.mm(xi, yr) - torch.mm(xr, yi) 36 | if conj_x and conj_y: 37 | zr = torch.mm(xr, yr) - torch.mm(xi, yi) 38 | zi = - torch.mm(xr, yi) - torch.mm(xi, yr) 39 | 40 | return torch.stack((zr, zi), 2) 41 | -------------------------------------------------------------------------------- /s2cnn/s2_grid.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R,C,E1101 2 | import numpy as np 3 | 4 | 5 | def s2_near_identity_grid(max_beta=np.pi / 8, n_alpha=8, n_beta=3): 6 | ''' 7 | :return: rings around the north pole 8 | size of the kernel = n_alpha * n_beta 9 | ''' 10 | beta = np.arange(start=1, stop=n_beta + 1, dtype=np.float) * max_beta / n_beta 11 | alpha = np.linspace(start=0, stop=2 * np.pi, num=n_alpha, endpoint=False) 12 | B, A = np.meshgrid(beta, alpha, indexing='ij') 13 | B = B.flatten() 14 | A = A.flatten() 15 | grid = np.stack((B, A), axis=1) 16 | return tuple(tuple(ba) for ba in grid) 17 | 18 | 19 | def s2_equatorial_grid(max_beta=0, n_alpha=32, n_beta=1): 20 | ''' 21 | :return: rings around the equator 22 | size of the kernel = n_alpha * n_beta 23 | ''' 24 | beta = np.linspace(start=np.pi/2 - max_beta, stop=np.pi/2 + max_beta, num=n_beta, endpoint=True) 25 | alpha = np.linspace(start=0, stop=2 * np.pi, num=n_alpha, endpoint=False) 26 | B, A = np.meshgrid(beta, alpha, indexing='ij') 27 | B = B.flatten() 28 | A = A.flatten() 29 | grid = np.stack((B, A), axis=1) 30 | return tuple(tuple(ba) for ba in grid) 31 | 32 | 33 | def s2_soft_grid(b): 34 | beta = (np.arange(2 * b) + 0.5) / (2 * b) * np.pi 35 | alpha = np.linspace(start=0, stop=2 * np.pi, num=2 * b, endpoint=False) 36 | B, A = np.meshgrid(beta, alpha, indexing='ij') 37 | B = B.flatten() 38 | A = A.flatten() 39 | grid = np.stack((B, A), axis=1) 40 | return tuple(tuple(ba) for ba in grid) 41 | -------------------------------------------------------------------------------- /examples/equivariance_error/main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Aim: see that L_R Phi(x) = Phi(L_R x) 3 | 4 | Where Phi is a composition of a S^2 convolution and a SO(3) convolution 5 | 6 | For simplicity, R is a rotation around the Z axis. 
7 | ''' 8 | 9 | #pylint: disable=C,R,E1101,W0621 10 | import torch 11 | 12 | from s2cnn import s2_equatorial_grid, S2Convolution 13 | from s2cnn import so3_equatorial_grid, SO3Convolution 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | # Define the two convolutions 18 | s2_grid = s2_equatorial_grid(max_beta=0, n_alpha=64, n_beta=1) 19 | s2_conv = S2Convolution(nfeature_in=12, nfeature_out=15, b_in=64, b_out=32, grid=s2_grid) 20 | s2_conv.to(device) 21 | 22 | so3_grid = so3_equatorial_grid(max_beta=0, max_gamma=0, n_alpha=64, n_beta=1, n_gamma=1) 23 | so3_conv = SO3Convolution(nfeature_in=15, nfeature_out=21, b_in=32, b_out=24, grid=so3_grid) 24 | so3_conv.to(device) 25 | 26 | def phi(x): 27 | x = s2_conv(x) 28 | x = torch.nn.functional.relu(x) 29 | x = so3_conv(x) 30 | return x 31 | 32 | def rot(x, angle): 33 | # rotate the signal around the Z axis 34 | n = round(x.size(3) * angle / 360) 35 | return torch.cat([x[:, :, :, n:], x[:, :, :, :n]], dim=3) 36 | 37 | # Create random input 38 | x = torch.randn(1, 12, 128, 128).to(device) # [batch, feature, beta, alpha] 39 | 40 | y = phi(x) 41 | y1 = rot(phi(x), angle=45) 42 | y2 = phi(rot(x, angle=45)) 43 | 44 | relative_error = torch.std(y1.data - y2.data) / torch.std(y.data) 45 | 46 | print('relative error = {}'.format(relative_error)) 47 | -------------------------------------------------------------------------------- /examples/molecules/baseline_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | CHARGES = [0., 1., 6., 7., 8., 16.] 10 | 11 | 12 | class BaselineRegressor(nn.Module): 13 | '''Very simple baseline model only utilizing the frequency and charge 14 | of atoms in a molecule.''' 15 | 16 | def __init__(self): 17 | super(BaselineRegressor, self).__init__() 18 | 19 | # num atoms and atom types 20 | self.n_atoms = 23 21 | self.n_types = 6 22 | 23 | # number of hidden units in the output regression 24 | self.n_hidden = 50 25 | self.n_hidden2 = 20 26 | 27 | self.W_h = nn.Linear(self.n_types, self.n_hidden) 28 | self.W_h2 = nn.Linear(self.n_hidden, self.n_hidden2) 29 | self.W_t = nn.Linear(self.n_hidden2, 1) 30 | 31 | def forward(self, x): 32 | ''' 33 | x: [batch, n_atoms, n_types, beta, alpha] 34 | types: [batch, n_atoms, n_types] 35 | ''' 36 | 37 | # get charge 38 | z = torch.autograd.Variable( 39 | torch.from_numpy(np.array(CHARGES)) 40 | ).view(1, -1).float().cuda() 41 | 42 | # get atom frequency per molecule 43 | x = torch.sum(x, dim=1) 44 | x[:, 0] = 0 45 | 46 | # multiply frequency by charge 47 | # TODO: concatenate instead? 
48 |         z = z.expand_as(x)
49 |         x = x * z
50 | 
51 |         # simple transform
52 |         x = self.W_h(x)
53 |         x = F.relu(x)
54 |         x = self.W_h2(x)
55 |         x = F.relu(x)
56 |         x = self.W_t(x)
57 | 
58 |         return x
59 | 
--------------------------------------------------------------------------------
/examples/mnist/README.md:
--------------------------------------------------------------------------------
1 | # Spherical MNIST example
2 | 
3 | ## Generate the spherical MNIST data set
4 | 
5 | - __NR__: non rotated
6 | - __R__: randomly rotated
7 | 
8 | ##### train: __NR__ - test: __NR__
9 | ```bash
10 | python3 gendata.py --no_rotate_train --no_rotate_test
11 | ```
12 | 
13 | ##### train: __R__ - test: __R__
14 | ```bash
15 | python3 gendata.py
16 | ```
17 | 
18 | ##### train: __NR__ - test: __R__
19 | ```bash
20 | python3 gendata.py --no_rotate_train
21 | ```
22 | 
23 | This will generate a `s2_mnist.gz` in the same folder containing the compressed generated dataset.
24 | 
25 | To get more information about other params for the data generation (noise magnitude, number of images having the same random rotations etc.):
26 | ```bash
27 | python3 gendata.py --help
28 | ```
29 | 
30 | ## Run the models
31 | 
32 | (Apologies for the ugly global constants regarding hyperparams. I will add some nice argparse at a later point. - jonas)
33 | 
34 | ### Simple 2D CNN
35 | 
36 | ```bash
37 | python3 run_classic.py
38 | ```
39 | 
40 | ### Run S2CNNs
41 | 
42 | To run the original S2CNN architecture reported in the paper simply call
43 | ```bash
44 | python3 run.py
45 | ```
46 | or
47 | ```bash
48 | python3 run.py --network=original
49 | ```
50 | 
51 | An improved model can be selected by calling
52 | ```bash
53 | python3 run.py --network=deep
54 | ```
55 | This architecture served as the baseline for the Icosahedral CNN [[1]](https://arxiv.org/pdf/1902.04615.pdf) (in the baseline run of [[1]](https://arxiv.org/pdf/1902.04615.pdf) slightly different hyperparameters, like the bandwidth, learning rate decay and batch size, were used).
56 | It achieves an accuracy of ~99.2%.
57 | 
58 | 
59 | ## References
60 | 
61 | [1] Taco S. Cohen, Maurice Weiler, Berkay Kicanaoglu, Max Welling,
62 | [Gauge Equivariant Convolutional Networks and the Icosahedral CNN](https://arxiv.org/pdf/1902.04615.pdf).
63 | International Conference on Machine Learning (ICML), 2019.
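
For orientation, here is a minimal sketch of how the generated `s2_mnist.gz` can be read back, mirroring what `load_data` in `run_classic.py` does (a gzipped pickle holding `train` and `test` splits):

```python
# Minimal sketch of reading the generated dataset, mirroring load_data()
# in run_classic.py: a gzipped pickle with "train" and "test" splits.
import gzip
import pickle

with gzip.open("s2_mnist.gz", "rb") as f:
    dataset = pickle.load(f)

train_images = dataset["train"]["images"]  # one spherical projection per digit
train_labels = dataset["train"]["labels"]
print(train_images.shape, train_labels.shape)
```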
64 | 
--------------------------------------------------------------------------------
/examples/shrec17/model.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101,R,C
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from s2cnn import s2_equatorial_grid, S2Convolution, so3_equatorial_grid, SO3Convolution, so3_integrate
6 | 
7 | class Model(nn.Module):
8 |     def __init__(self, nclasses):
9 |         super().__init__()
10 | 
11 |         self.features = [6, 100, 100, nclasses]
12 |         self.bandwidths = [64, 16, 10]
13 | 
14 |         assert len(self.bandwidths) == len(self.features) - 1
15 | 
16 |         sequence = []
17 | 
18 |         # S2 layer
19 |         grid = s2_equatorial_grid(max_beta=0, n_alpha=2 * self.bandwidths[0], n_beta=1)
20 |         sequence.append(S2Convolution(self.features[0], self.features[1], self.bandwidths[0], self.bandwidths[1], grid))
21 | 
22 |         # SO3 layers
23 |         for l in range(1, len(self.features) - 2):
24 |             nfeature_in = self.features[l]
25 |             nfeature_out = self.features[l + 1]
26 |             b_in = self.bandwidths[l]
27 |             b_out = self.bandwidths[l + 1]
28 | 
29 |             sequence.append(nn.BatchNorm3d(nfeature_in, affine=True))
30 |             sequence.append(nn.ReLU())
31 |             grid = so3_equatorial_grid(max_beta=0, max_gamma=0, n_alpha=2 * b_in, n_beta=1, n_gamma=1)
32 |             sequence.append(SO3Convolution(nfeature_in, nfeature_out, b_in, b_out, grid))
33 | 
34 |         sequence.append(nn.BatchNorm3d(self.features[-2], affine=True))
35 |         sequence.append(nn.ReLU())
36 | 
37 |         self.sequential = nn.Sequential(*sequence)
38 | 
39 |         # Output layer
40 |         output_features = self.features[-2]
41 |         self.out_layer = nn.Linear(output_features, self.features[-1])
42 | 
43 |     def forward(self, x):  # pylint: disable=W0221
44 |         x = self.sequential(x)  # [batch, feature, beta, alpha, gamma]
45 |         x = so3_integrate(x)  # [batch, feature]
46 | 
47 |         x = self.out_layer(x)
48 |         return F.log_softmax(x, dim=1)
49 | 
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | > Why is there only a very little difference when evaluating on the non-rotated input and the rotated input?
2 | 
3 | Our formulas are equivariant, but in our implementation we discretized them, which causes this little difference.
4 | 
5 | > The Spherical CNN takes lots of memory. Now we plan to build a larger model and train on multiple GPUs. Is this possible?
6 | 
7 | This is due to the 3-dimensional internal representations.
8 | The last attempt to support multi-GPU training was made in [#8](https://github.com/jonas-koehler/s2cnn/issues/8).
9 | You can also look at other architectures that consume less memory, like [1902.04615](https://arxiv.org/abs/1902.04615).
10 | 
11 | > 1. You provided two options: identity and equator. Why is it that ‘identity’ is localised while ‘equator’ has non-local support?
12 | 
13 | These are two choices for the shape of the filters. In 2d images the filters (also called kernels) are usually 3x3 squares, but they can also be rectangular. In theory all shapes are allowed. In our framework we convolve over the sphere S2 and the group SO3; in these spaces as well we can choose the shape of the kernel as we wish. We took a localized one (similar to the 3x3 square of 2d images) that we place w.l.o.g. at the north pole. The second shape we tried is a ring around the equator, non-local because it wraps around the entire space.
14 | 
15 | > 2. Take the identity one for example: the sampling grid is chosen to be close to the north pole. Does this mean it only sees the input data around the north pole?
16 | 
17 | No, the kernel is rotated into all possible orientations, like the 3x3 kernel on 2d images, which is translated over the entire image.
18 | 
19 | > 3. For the choice of max_alpha: by default it is 2*pi. Is it still meaningful to use a smaller value for max_alpha, e.g. pi or pi/2?
20 | 
21 | For the identity shape you will get a pie-slice shape. For the equatorial kernel you will get a portion of a ring. Again, all shapes are allowed and it is not intuitive to me which one is better than another.
22 | 
23 | 
--------------------------------------------------------------------------------
/examples/shrec17/model_original.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101,R,C
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from s2cnn import s2_equatorial_grid, S2Convolution, so3_equatorial_grid, SO3Convolution
6 | 
7 | class Model(nn.Module):
8 |     def __init__(self, nclasses):
9 |         super().__init__()
10 | 
11 |         self.features = [6, 50, 70, 350, nclasses]
12 |         self.bandwidths = [128, 32, 22, 7]
13 | 
14 |         assert len(self.bandwidths) == len(self.features) - 1
15 | 
16 |         sequence = []
17 | 
18 |         # S2 layer
19 |         grid = s2_equatorial_grid(max_beta=0, n_alpha=2 * self.bandwidths[0], n_beta=1)
20 |         sequence.append(S2Convolution(self.features[0], self.features[1], self.bandwidths[0], self.bandwidths[1], grid))
21 | 
22 |         # SO3 layers
23 |         for l in range(1, len(self.features) - 2):
24 |             nfeature_in = self.features[l]
25 |             nfeature_out = self.features[l + 1]
26 |             b_in = self.bandwidths[l]
27 |             b_out = self.bandwidths[l + 1]
28 | 
29 |             sequence.append(nn.BatchNorm3d(nfeature_in, affine=True))
30 |             sequence.append(nn.ReLU())
31 |             grid = so3_equatorial_grid(max_beta=0, max_gamma=0, n_alpha=2 * b_in, n_beta=1, n_gamma=1)
32 |             sequence.append(SO3Convolution(nfeature_in, nfeature_out, b_in, b_out, grid))
33 | 
34 |         sequence.append(nn.BatchNorm3d(self.features[-2], affine=True))
35 |         sequence.append(nn.ReLU())
36 | 
37 |         self.sequential = nn.Sequential(*sequence)
38 | 
39 |         # Output layer
40 |         self.out_layer = nn.Sequential(
41 |             nn.BatchNorm1d(self.features[-2], affine=False),
42 |             nn.Linear(self.features[-2], self.features[-1])
43 |         )
44 | 
45 |     def forward(self, x):  # pylint: disable=W0221
46 |         x = self.sequential(x)  # [batch, feature, beta, alpha, gamma]
47 |         x = x.view(x.size(0), x.size(1), -1).max(-1)[0]  # [batch, feature]
48 | 
49 |         x = self.out_layer(x)
50 |         return F.log_softmax(x, dim=1)
--------------------------------------------------------------------------------
/s2cnn/soft/s2_conv.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C,R,E1101
2 | import math
3 | import torch
4 | from torch.nn.parameter import Parameter
5 | from torch.nn.modules import Module
6 | 
7 | from .s2_fft import S2_fft_real
8 | from .so3_fft import SO3_ifft_real
9 | from s2cnn import s2_mm
10 | from s2cnn import s2_rft
11 | 
12 | 
13 | class S2Convolution(Module):
14 |     def __init__(self, nfeature_in, nfeature_out, b_in, b_out, grid):
15 |         '''
16 |         :param nfeature_in: number of input features
17 |         :param nfeature_out: number of output features
18 |         :param b_in: input bandwidth (precision of the input SOFT grid)
19 |         :param b_out: output bandwidth
20 |         :param grid: points of the sphere defining the kernel, tuple of (alpha, beta)'s
21 |         '''
22 |         super(S2Convolution, self).__init__()
23 |         self.nfeature_in = nfeature_in
24 |         self.nfeature_out = nfeature_out
25 |         self.b_in = b_in
26 |         self.b_out = b_out
27 |         self.grid = grid
28 |         self.kernel = Parameter(torch.empty(nfeature_in, nfeature_out, len(grid)).uniform_(-1, 1))
29 |         self.scaling = 1. / math.sqrt(len(self.grid) * self.nfeature_in * (self.b_out ** 4.) / (self.b_in ** 2.))
30 |         self.bias = Parameter(torch.zeros(1, nfeature_out, 1, 1, 1))
31 | 
32 |     def forward(self, x):  # pylint: disable=W
33 |         '''
34 |         :x: [batch, feature_in, beta, alpha]
35 |         :return: [batch, feature_out, beta, alpha, gamma]
36 |         '''
37 |         assert x.size(1) == self.nfeature_in
38 |         assert x.size(2) == 2 * self.b_in
39 |         assert x.size(3) == 2 * self.b_in
40 |         x = S2_fft_real.apply(x, self.b_out)  # [l * m, batch, feature_in, complex]
41 |         y = s2_rft(self.kernel * self.scaling, self.b_out, self.grid)  # [l * m, feature_in, feature_out, complex]
42 |         z = s2_mm(x, y)  # [l * m * n, batch, feature_out, complex]
43 |         z = SO3_ifft_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]
44 | 
45 |         z = z + self.bias
46 | 
47 |         return z
48 | 
--------------------------------------------------------------------------------
/examples/shrec17/README.md:
--------------------------------------------------------------------------------
1 | # SHREC17
2 | 
3 | ## Dependencies
4 | 
5 | To load the SHREC17 dataset we use the library `trimesh`; to make it work efficiently you will also need to install `pyembree`.
6 | (See on [the trimesh github](https://github.com/mikedh/trimesh/blob/master/docs/install.rst) how to install them)
7 | 
8 | ## Training
9 | ```
10 | python train.py --model_path model.py --log_dir my_run --dataset train --batch_size 32 --learning_rate 0.5 --augmentation 5
11 | ```
12 | 
13 | ## Validation
14 | ```
15 | python test.py --log_dir my_run --dataset val --batch_size 32 --augmentation 5
16 | cat my_run/summary.csv | grep micro
17 | ```
18 | 
19 | ## Original Code
20 | The files here are not the original code used for the article. They have been recoded from scratch.
21 | These files are simpler and clearer to read.
22 | The model has also been simplified.
23 | These simplifications give *very similar results*.
24 | 
25 | ### Model
26 | The results reported in the article were produced by the model `model_original.py`.
27 | The file `model.py` is a simplification of the original model we used; here is the list of the differences:
28 | - `model.py` has one less layer than `model_original.py` (2 hidden layers vs 3 hidden layers)
29 | - `model.py` has 1M fewer parameters than `model_original.py` (1.4M vs 400k)
30 | - `model.py` takes as input a bandwidth of 64 instead of 128 (smaller input images on the sphere)
31 | - To get rid of the spatial dimensions `model.py` computes the integral instead of taking the maximum (see the sketch at the end of this README)
32 | 
33 | ### Training
34 | Originally (for the article) we used Adam with a complicated learning rate schedule.
35 | Here `train.py` uses SGD with momentum, and the learning rate is divided by 10 every 100 epochs.
36 | 
37 | ### Data Augmentation
38 | Originally the data augmentation (rotation and translation of the 3D model before the raytracing) was fixed to 7 augmentations;
39 | we did one translation in each direction in 3D: `(0,0,0), (+1,0,0), (-1,0,0), (0,+1,0), (0,-1,0), (0,0,+1), (0,0,-1)`.
40 | Now, to make things more flexible and easy, the code in `dataset.py` performs random translations so that the number of augmentations can be set arbitrarily.
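
As referenced in the Model section above, the integral-versus-maximum difference amounts to a one-line change in the forward pass; here is a schematic sketch comparing the two pooling strategies, taken from the forward passes of `model.py` and `model_original.py`:

```python
# Schematic comparison of the two ways the spatial dimensions are removed.
# x has shape [batch, feature, beta, alpha, gamma].
from s2cnn import so3_integrate

def pool_integral(x):
    # model.py: rotation-invariant integral over SO(3) (Haar measure)
    return so3_integrate(x)  # [batch, feature]

def pool_max(x):
    # model_original.py: maximum over all SO(3) grid points
    return x.view(x.size(0), x.size(1), -1).max(-1)[0]  # [batch, feature]
```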
41 | 
--------------------------------------------------------------------------------
/s2cnn/soft/so3_rotation.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=C,R,E1101
2 | import torch
3 | import numpy as np
4 | 
5 | from .so3_fft import SO3_fft_real, SO3_ifft_real
6 | from s2cnn.utils.complex import complex_mm
7 | from functools import lru_cache
8 | from s2cnn.utils.decorator import cached_dirpklgz
9 | 
10 | 
11 | def so3_rotation(x, alpha, beta, gamma):
12 |     '''
13 |     :param x: [..., beta, alpha, gamma] (..., 2b, 2b, 2b)
14 |     '''
15 |     b = x.size()[-1] // 2
16 |     x_size = x.size()
17 | 
18 |     Us = _setup_so3_rotation(b, alpha, beta, gamma, device_type=x.device.type, device_index=x.device.index)
19 | 
20 |     # fourier transform
21 |     x = SO3_fft_real.apply(x)  # [l * m * n, ..., complex]
22 | 
23 |     # rotated spectrum
24 |     Fz_list = []
25 |     begin = 0
26 |     for l in range(b):
27 |         L = 2 * l + 1
28 |         size = L ** 2
29 | 
30 |         Fx = x[begin:begin+size]
31 |         Fx = Fx.view(L, -1, 2)  # [m, n * batch, complex]
32 | 
33 |         U = Us[l].view(L, L, 2)  # [m, n, complex]
34 | 
35 |         Fz = complex_mm(U, Fx, conj_x=True)  # [m, n * batch, complex]
36 | 
37 |         Fz = Fz.view(size, -1, 2)  # [m * n, batch, complex]
38 |         Fz_list.append(Fz)
39 | 
40 |         begin += size
41 | 
42 |     Fz = torch.cat(Fz_list, 0)  # [l * m * n, batch, complex]
43 |     z = SO3_ifft_real.apply(Fz)
44 | 
45 |     z = z.contiguous()
46 |     z = z.view(*x_size)
47 | 
48 |     return z
49 | 
50 | 
51 | @cached_dirpklgz("cache/setup_so3_rotation")
52 | def __setup_so3_rotation(b, alpha, beta, gamma):
53 |     from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
54 | 
55 |     Us = [wigner_D_matrix(l, alpha, beta, gamma,
56 |                           field='complex', normalization='quantum', order='centered', condon_shortley='cs')
57 |           for l in range(b)]
58 |     # Us[l][m, n] = exp(i m alpha) d^l_mn(beta) exp(i n gamma)
59 | 
60 |     Us = [Us[l].astype(np.complex64).view(np.float32).reshape((2 * l + 1, 2 * l + 1, 2)) for l in range(b)]
61 | 
62 |     return Us
63 | 
64 | 
65 | @lru_cache(maxsize=32)
66 | def _setup_so3_rotation(b, alpha, beta, gamma, device_type, device_index):
67 |     Us = __setup_so3_rotation(b, alpha, beta, gamma)
68 | 
69 |     # convert to torch Tensor
70 |     Us = [torch.tensor(U, dtype=torch.float32, device=torch.device(device_type, device_index)) for U in Us]  # pylint: disable=E1102
71 | 
72 |     return Us
73 | 
--------------------------------------------------------------------------------
/s2cnn/so3_grid.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101
2 | import numpy as np
3 | import warnings
4 | 
5 | 
6 | def so3_near_identity_grid(max_beta=np.pi / 8, max_gamma=2*np.pi, n_alpha=8, n_beta=3, n_gamma=None):
7 |     '''
8 |     :return: rings of rotations around the identity, all points (rotations) in
9 |     a ring are at the same distance from the identity
10 |     size of the kernel = n_alpha * n_beta * n_gamma
11 |     '''
12 |     if n_gamma is None:
13 |         n_gamma = n_alpha  # similar to regular representations
14 |     beta = np.arange(start=1, stop=n_beta + 1, dtype=np.float64) * max_beta / n_beta
15 |     alpha = np.linspace(start=0, stop=2 * np.pi, num=n_alpha, endpoint=False)
16 |     pre_gamma = np.linspace(start=-max_gamma, stop=max_gamma, num=n_gamma, endpoint=True)
17 |     B, A, preC = np.meshgrid(beta, alpha, pre_gamma, indexing='ij')
18 |     C = preC - A
19 |     B = B.flatten()
20 |     A = A.flatten()
21 |     C = C.flatten()
22 |     grid = np.stack((B, A, C), axis=1)
23 |     if sum(grid[:, 0] == 0) > 1:
24 |         warnings.warn("Gimbal lock: beta takes the value 0 in the grid")
25 |     return tuple(tuple(bac) for bac in grid)
26 | 
27 | 
28 | def so3_equatorial_grid(max_beta=0, max_gamma=np.pi / 8, n_alpha=32, n_beta=1, n_gamma=2):
29 |     '''
30 |     :return: rings of rotations around the equator.
31 |     size of the kernel = n_alpha * n_beta * n_gamma
32 |     '''
33 |     beta = np.linspace(start=np.pi/2 - max_beta, stop=np.pi/2 + max_beta, num=n_beta, endpoint=True)
34 |     alpha = np.linspace(start=0, stop=2 * np.pi, num=n_alpha, endpoint=False)
35 |     gamma = np.linspace(start=-max_gamma, stop=max_gamma, num=n_gamma, endpoint=True)
36 |     B, A, C = np.meshgrid(beta, alpha, gamma, indexing='ij')
37 |     B = B.flatten()
38 |     A = A.flatten()
39 |     C = C.flatten()
40 |     grid = np.stack((B, A, C), axis=1)
41 |     if sum(grid[:, 0] == 0) > 1:
42 |         warnings.warn("Gimbal lock: beta takes the value 0 in the grid")
43 |     return tuple(tuple(bac) for bac in grid)
44 | 
45 | 
46 | def so3_soft_grid(b):
47 |     beta = (np.arange(2 * b) + 0.5) / (2 * b) * np.pi
48 |     alpha = gamma = np.linspace(start=0, stop=2 * np.pi, num=2 * b, endpoint=False)
49 |     B, A, C = np.meshgrid(beta, alpha, gamma, indexing='ij')
50 |     B = B.flatten()
51 |     A = A.flatten()
52 |     C = C.flatten()
53 |     grid = np.stack((B, A, C), axis=1)
54 |     return tuple(tuple(bac) for bac in grid)
55 | 
--------------------------------------------------------------------------------
/s2cnn/s2_ft.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101
2 | import torch
3 | import numpy as np
4 | from functools import lru_cache
5 | from s2cnn.utils.decorator import cached_dirpklgz
6 | 
7 | 
8 | def s2_rft(x, b, grid):
9 |     """
10 |     Real Fourier Transform
11 |     :param x: [..., beta_alpha]
12 |     :param b: output bandwidth signal
13 |     :param grid: tuple of (beta, alpha) tuples
14 |     :return: [l * m, ..., complex]
15 |     """
16 |     # F is the Fourier matrix
17 |     F = _setup_s2_ft(b, grid, device_type=x.device.type, device_index=x.device.index)  # [beta_alpha, l * m, complex]
18 | 
19 |     assert x.size(-1) == F.size(0)
20 | 
21 |     sz = x.size()
22 |     x = torch.einsum("ia,afc->fic", (x.view(-1, x.size(-1)), F.clone()))  # [l * m, ..., complex]
23 |     x = x.view(-1, *sz[:-1], 2)
24 |     return x
25 | 
26 | 
27 | @cached_dirpklgz("cache/setup_s2_ft")
28 | def __setup_s2_ft(b, grid):
29 |     from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
30 | 
31 |     # Note: optionally get quadrature weights for the chosen grid and use them to weigh the D matrices below.
32 |     # This is optional because we can also view the filter coefficients as having absorbed the weights already.
33 | 
34 |     # Sample the Wigner-D functions on the local grid
35 |     n_spatial = len(grid)
36 |     n_spectral = np.sum([(2 * l + 1) for l in range(b)])
37 |     F = np.zeros((n_spatial, n_spectral), dtype=complex)
38 |     for i, (beta, alpha) in enumerate(grid):
39 |         Dmats = [(2 * b) * wigner_D_matrix(l, alpha, beta, 0,
40 |                                            field='complex', normalization='quantum', order='centered', condon_shortley='cs')
41 |                  .conj()
42 |                  for l in range(b)]
43 |         F[i] = np.hstack([Dmats[l][:, l] for l in range(b)])
44 | 
45 |     # F is a complex matrix of shape (n_spatial, n_spectral)
46 |     # If we view it as float, we get a real matrix of shape (n_spatial, 2 * n_spectral)
47 |     # In the so3_local_ft, we will multiply a batch of real (..., n_spatial) vectors x with this matrix F as xF.
48 |     # The result is a (..., 2 * n_spectral) array that can be interpreted as a batch of complex vectors.
49 |     F = F.view('float').reshape((-1, n_spectral, 2))
50 |     return F
51 | 
52 | 
53 | @lru_cache(maxsize=32)
54 | def _setup_s2_ft(b, grid, device_type, device_index):
55 |     F = __setup_s2_ft(b, grid)
56 | 
57 |     # convert to torch Tensor
58 |     F = torch.tensor(F.astype(np.float32), dtype=torch.float32, device=torch.device(device_type, device_index))  # pylint: disable=E1102
59 | 
60 |     return F
61 | 
--------------------------------------------------------------------------------
/s2cnn/so3_ft.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=R,C,E1101
2 | import torch
3 | import numpy as np
4 | from functools import lru_cache
5 | from s2cnn.utils.decorator import cached_dirpklgz
6 | 
7 | 
8 | def so3_rft(x, b, grid):
9 |     """
10 |     Real Fourier Transform
11 |     :param x: [..., beta_alpha_gamma]
12 |     :param b: output bandwidth signal
13 |     :param grid: tuple of (beta, alpha, gamma) tuples
14 |     :return: [l * m * n, ..., complex]
15 |     """
16 |     # F is the Fourier matrix
17 |     F = _setup_so3_ft(b, grid, device_type=x.device.type, device_index=x.device.index)  # [beta_alpha_gamma, l * m * n, complex]
18 | 
19 |     assert x.size(-1) == F.size(0)
20 | 
21 |     sz = x.size()
22 |     x = torch.einsum("ia,afc->fic", (x.view(-1, x.size(-1)), F.clone()))  # [l * m * n, ..., complex]
23 |     x = x.view(-1, *sz[:-1], 2)
24 |     return x
25 | 
26 | 
27 | @cached_dirpklgz("cache/setup_so3_ft")
28 | def __setup_so3_ft(b, grid):
29 |     from lie_learn.representations.SO3.wigner_d import wigner_D_matrix
30 | 
31 |     # Note: optionally get quadrature weights for the chosen grid and use them to weigh the D matrices below.
32 |     # This is optional because we can also view the filter coefficients as having absorbed the weights already.
33 |     # The weights depend on the spacing between the points of the grid
34 |     # Only the coefficient sin(beta) can be added without requiring knowledge of the spacings
35 | 
36 |     # Sample the Wigner-D functions on the local grid
37 |     n_spatial = len(grid)
38 |     n_spectral = np.sum([(2 * l + 1) ** 2 for l in range(b)])
39 |     F = np.zeros((n_spatial, n_spectral), dtype=complex)
40 |     for i, (beta, alpha, gamma) in enumerate(grid):
41 |         Dmats = [wigner_D_matrix(l, alpha, beta, gamma,
42 |                                  field='complex', normalization='quantum', order='centered', condon_shortley='cs')
43 |                  .conj()
44 |                  for l in range(b)]
45 |         F[i] = np.hstack([Dl.flatten() for Dl in Dmats])
46 | 
47 |     # F is a complex matrix of shape (n_spatial, n_spectral)
48 |     # If we view it as float, we get a real matrix of shape (n_spatial, 2 * n_spectral)
49 |     # In the so3_local_ft, we will multiply a batch of real (..., n_spatial) vectors x with this matrix F as xF.
50 |     # The result is a (..., 2 * n_spectral) array that can be interpreted as a batch of complex vectors.
51 | F = F.view('float').reshape((-1, n_spectral, 2)) 52 | return F 53 | 54 | 55 | @lru_cache(maxsize=32) 56 | def _setup_so3_ft(b, grid, device_type, device_index): 57 | F = __setup_so3_ft(b, grid) 58 | 59 | # convert to torch Tensor 60 | F = torch.tensor(F.astype(np.float32), dtype=torch.float32, device=torch.device(device_type, device_index)) # pylint: disable=E1102 61 | 62 | return F 63 | -------------------------------------------------------------------------------- /examples/equivariance_plot/main.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C,R,E1101,E1102 2 | import numpy as np 3 | import matplotlib 4 | matplotlib.use('agg') 5 | import matplotlib.pyplot as plt 6 | from matplotlib.pyplot import imread 7 | 8 | import torch 9 | from s2cnn import S2Convolution, SO3Convolution, so3_rotation 10 | from s2cnn import s2_near_identity_grid, so3_near_identity_grid 11 | 12 | 13 | def s2_rotation(x, a, b, c): 14 | x = so3_rotation(x.view(*x.size(), 1).expand(*x.size(), x.size(-1)), a, b, c) 15 | return x[..., 0] 16 | 17 | 18 | def plot(x, text, normalize=False): 19 | assert x.size(0) == 1 20 | assert x.size(1) in [1, 3] 21 | x = x[0] 22 | if x.dim() == 4: 23 | x = x[..., 0] 24 | 25 | nch = x.size(0) 26 | is_rgb = (nch == 3) 27 | 28 | if normalize: 29 | x = x - x.view(nch, -1).mean(-1).view(nch, 1, 1) 30 | x = 0.4 * x / x.view(nch, -1).std(-1).view(nch, 1, 1) 31 | 32 | x = x.detach().cpu().numpy() 33 | x = x.transpose((1, 2, 0)).clip(0, 1) 34 | 35 | print(x.shape) 36 | if is_rgb: 37 | plt.imshow(x) 38 | else: 39 | plt.imshow(x[:, :, 0], cmap='gray') 40 | plt.axis("off") 41 | 42 | plt.text(0.5, 0.5, text, 43 | horizontalalignment='center', 44 | verticalalignment='center', 45 | transform=plt.gca().transAxes, 46 | color='white', fontsize=20) 47 | 48 | 49 | def main(): 50 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 51 | 52 | # load image 53 | x = imread("earth128.jpg").astype(np.float32).transpose((2, 0, 1)) / 255 54 | b = 64 55 | x = torch.tensor(x, dtype=torch.float, device=device) 56 | x = x.view(1, 3, 2 * b, 2 * b) 57 | 58 | # equivariant transformation 59 | s2_grid = s2_near_identity_grid(max_beta=0.2, n_alpha=12, n_beta=1) 60 | s2_conv = S2Convolution(3, 50, b_in=b, b_out=b, grid=s2_grid) 61 | s2_conv.to(device) 62 | 63 | so3_grid = so3_near_identity_grid(max_beta=0.2, n_alpha=12, n_beta=1) 64 | so3_conv = SO3Convolution(50, 1, b_in=b, b_out=b, grid=so3_grid) 65 | so3_conv.to(device) 66 | 67 | def phi(x): 68 | x = s2_conv(x) 69 | x = torch.nn.functional.softplus(x) 70 | x = so3_conv(x) 71 | return x 72 | 73 | # test equivariance 74 | abc = (0.5, 1, 0) # rotation angles 75 | 76 | y1 = phi(s2_rotation(x, *abc)) 77 | y2 = so3_rotation(phi(x), *abc) 78 | print((y1 - y2).std().item(), y1.std().item()) 79 | 80 | plt.figure(figsize=(12, 8)) 81 | 82 | plt.subplot(2, 3, 1) 83 | plot(x, "x : signal on the sphere") 84 | 85 | plt.subplot(2, 3, 2) 86 | plot(phi(x), "phi(x) : convolutions", True) 87 | 88 | plt.subplot(2, 3, 3) 89 | plot(so3_rotation(phi(x), *abc), "R(phi(x))", True) 90 | 91 | plt.subplot(2, 3, 4) 92 | plot(s2_rotation(x, *abc), "R(x) : rotation using fft") 93 | 94 | plt.subplot(2, 3, 5) 95 | plot(phi(s2_rotation(x, *abc)), "phi(R(x))", True) 96 | 97 | plt.tight_layout() 98 | plt.savefig("fig.jpeg") 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /s2cnn/utils/decorator.py: 
-------------------------------------------------------------------------------- 1 | # pylint: disable=R,C,E1101 2 | import threading 3 | import time 4 | from functools import wraps 5 | from functools import lru_cache 6 | import pickle 7 | import gzip 8 | import os 9 | import sys 10 | 11 | 12 | class WaitPrint(threading.Thread): 13 | def __init__(self, t, message): 14 | super().__init__() 15 | self.t = t 16 | self.message = message 17 | self.running = True 18 | 19 | def stop(self): 20 | self.running = False 21 | 22 | def run(self): 23 | for _ in range(int(self.t // 0.1)): 24 | time.sleep(0.1) 25 | if not self.running: 26 | return 27 | print(self.message, end="") 28 | 29 | 30 | def show_running(func): 31 | @wraps(func) 32 | def g(*args, **kargs): 33 | x = WaitPrint( 34 | 2, 35 | "{}({})... ".format( 36 | func.__name__, 37 | ", ".join( 38 | [repr(x) for x in args] + 39 | ["{}={}".format(key, repr(value)) for key, value in kargs.items()] 40 | ) 41 | ) 42 | ) 43 | x.start() 44 | t = time.perf_counter() 45 | r = func(*args, **kargs) 46 | if x.is_alive(): 47 | x.stop() 48 | else: 49 | print("done in {:.0f} seconds".format(time.perf_counter() - t)) 50 | return r 51 | return g 52 | 53 | 54 | def cached_dirpklgz(dirname): 55 | ''' 56 | Cache a function with a directory 57 | ''' 58 | def decorator(func): 59 | ''' 60 | The actual decorator 61 | ''' 62 | @lru_cache(maxsize=None) 63 | @wraps(func) 64 | def wrapper(*args): 65 | ''' 66 | The wrapper of the function 67 | ''' 68 | try: 69 | os.makedirs(dirname) 70 | except FileExistsError: 71 | pass 72 | 73 | indexfile = os.path.join(dirname, "index.pkl") 74 | 75 | try: 76 | with open(indexfile, "rb") as file: 77 | index = pickle.load(file) 78 | except FileNotFoundError: 79 | index = {} 80 | 81 | try: 82 | filename = index[args] 83 | except KeyError: 84 | index[args] = filename = "{}.pkl.gz".format(len(index)) 85 | with open(indexfile, "wb") as file: 86 | pickle.dump(index, file) 87 | 88 | filepath = os.path.join(dirname, filename) 89 | 90 | try: 91 | with gzip.open(filepath, "rb") as file: 92 | print("load {}... ".format(filename), end="") 93 | result = pickle.load(file) 94 | except FileNotFoundError: 95 | print("compute {}... ".format(filename), end="") 96 | sys.stdout.flush() 97 | result = func(*args) 98 | print("save {}... 
".format(filename), end="") 99 | with gzip.open(filepath, "wb") as file: 100 | pickle.dump(result, file) 101 | print("done") 102 | return result 103 | return wrapper 104 | return decorator 105 | -------------------------------------------------------------------------------- /s2cnn/soft/so3_conv.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C,R,E1101 2 | import math 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | from torch.nn.modules import Module 6 | 7 | from .so3_fft import SO3_fft_real, SO3_ifft_real 8 | from s2cnn import so3_mm 9 | from s2cnn import so3_rft 10 | 11 | 12 | class SO3Convolution(Module): 13 | def __init__(self, nfeature_in, nfeature_out, b_in, b_out, grid): 14 | ''' 15 | :param nfeature_in: number of input fearures 16 | :param nfeature_out: number of output features 17 | :param b_in: input bandwidth (precision of the input SOFT grid) 18 | :param b_out: output bandwidth 19 | :param grid: points of the SO(3) group defining the kernel, tuple of (alpha, beta, gamma)'s 20 | ''' 21 | super(SO3Convolution, self).__init__() 22 | self.nfeature_in = nfeature_in 23 | self.nfeature_out = nfeature_out 24 | self.b_in = b_in 25 | self.b_out = b_out 26 | self.grid = grid 27 | self.kernel = Parameter(torch.empty(nfeature_in, nfeature_out, len(grid)).uniform_(-1, 1)) 28 | self.bias = Parameter(torch.zeros(1, nfeature_out, 1, 1, 1)) 29 | 30 | # When useing ADAM optimizer, the variance of each componant of the gradient 31 | # is normalized by ADAM around 1. 32 | # Then it is suited to have parameters of order one. 33 | # Therefore the scaling, needed for the proper forward propagation, is done "outside" of the parameters 34 | self.scaling = 1. / math.sqrt(len(self.grid) * self.nfeature_in * (self.b_out ** 3.) 
/ (self.b_in ** 3.)) 35 | 36 | def forward(self, x): # pylint: disable=W 37 | ''' 38 | :x: [batch, feature_in, beta, alpha, gamma] 39 | :return: [batch, feature_out, beta, alpha, gamma] 40 | ''' 41 | assert x.size(1) == self.nfeature_in 42 | assert x.size(2) == 2 * self.b_in 43 | assert x.size(3) == 2 * self.b_in 44 | assert x.size(4) == 2 * self.b_in 45 | 46 | x = SO3_fft_real.apply(x, self.b_out) # [l * m * n, batch, feature_in, complex] 47 | y = so3_rft(self.kernel * self.scaling, self.b_out, self.grid) # [l * m * n, feature_in, feature_out, complex] 48 | assert x.size(0) == y.size(0) 49 | assert x.size(2) == y.size(1) 50 | z = so3_mm(x, y) # [l * m * n, batch, feature_out, complex] 51 | assert z.size(0) == x.size(0) 52 | assert z.size(1) == x.size(1) 53 | assert z.size(2) == y.size(2) 54 | z = SO3_ifft_real.apply(z) # [batch, feature_out, beta, alpha, gamma] 55 | 56 | z = z + self.bias 57 | 58 | return z 59 | 60 | 61 | class SO3Shortcut(Module): 62 | ''' 63 | Useful for ResNet 64 | ''' 65 | 66 | def __init__(self, nfeature_in, nfeature_out, b_in, b_out): 67 | super(SO3Shortcut, self).__init__() 68 | assert b_out <= b_in 69 | 70 | if (nfeature_in != nfeature_out) or (b_in != b_out): 71 | self.conv = SO3Convolution( 72 | nfeature_in=nfeature_in, nfeature_out=nfeature_out, b_in=b_in, b_out=b_out, 73 | grid=((0, 0, 0), )) 74 | else: 75 | self.conv = None 76 | 77 | def forward(self, x): # pylint: disable=W 78 | ''' 79 | :x: [batch, feature_in, beta, alpha, gamma] 80 | :return: [batch, feature_out, beta, alpha, gamma] 81 | ''' 82 | if self.conv is not None: 83 | return self.conv(x) 84 | else: 85 | return x 86 | -------------------------------------------------------------------------------- /examples/molecules/utils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import torch 3 | import numpy as np 4 | import joblib 5 | 6 | 7 | class IndexBatcher: 8 | 9 | def __init__(self, indices, n_batch, cuda=None): 10 | self.indices = indices.astype(np.int64) 11 | self.n_batch = n_batch 12 | self.pos = 0 13 | self.cuda = cuda 14 | self.internal_indices = np.arange(len(indices)).astype(np.int64) 15 | np.random.shuffle(self.internal_indices) 16 | 17 | def __iter__(self): 18 | return self 19 | 20 | def reset(self): 21 | self.pos = 0 22 | np.random.shuffle(self.internal_indices) 23 | 24 | def __next__(self): 25 | start = self.pos 26 | end = np.minimum(self.pos + self.n_batch, len(self.indices)) 27 | self.pos += self.n_batch 28 | if self.pos >= len(self.indices): 29 | self.reset() 30 | raise StopIteration 31 | tensor = torch.LongTensor( 32 | self.indices[self.internal_indices[start:end]]) 33 | if self.cuda is not None: 34 | tensor.cuda(self.cuda) 35 | return tensor 36 | 37 | def num_iterations(self): 38 | return len(self.indices) // self.n_batch 39 | 40 | next = __next__ 41 | 42 | 43 | def to_one_hot(x, n): 44 | x_ = torch.unsqueeze(x, 2) 45 | dims = (*x.size(), n) 46 | one_hot = torch.FloatTensor(*dims).zero_() 47 | one_hot.scatter_(2, x_, 1) 48 | return one_hot 49 | 50 | 51 | def load_data(path, test_strat_id=None, cuda=None): 52 | ''' 53 | Loads the data 54 | 55 | path: path to the molecule .gz 56 | batch_size: size of a mini batch 57 | test_strat_id: id of strat being used as test set 58 | ''' 59 | data = joblib.load(path) 60 | 61 | # map charges to type indices 62 | # TODO refactor to individual function 63 | # TODO make less reliant on individual data dict structure 64 | type_remap = 
-np.ones(int(data["features"]["atom_types"].max())+1)
65 |     unique_types = np.unique(data["features"]["atom_types"]).astype(int)
66 |     type_remap[unique_types] = np.arange(len(unique_types))
67 |     data["features"]["atom_types"] = type_remap[
68 |         data["features"]["atom_types"].astype(int)]
69 | 
70 |     # wrap features as torch tensors
71 |     data["features"]["geometry"] = torch.FloatTensor(
72 |         data["features"]["geometry"].astype(np.float32))
73 |     data["features"]["atom_types"] = torch.LongTensor(
74 |         data["features"]["atom_types"].astype(np.int64))
75 |     data["targets"] = torch.from_numpy(data["targets"])
76 | 
77 |     if cuda is not None:
78 |         data["features"]["geometry"].cuda(cuda)
79 |         data["features"]["atom_types"].cuda(cuda)
80 |         data["targets"].cuda(cuda)
81 | 
82 |     train = np.ndarray((0))
83 |     test = np.ndarray((0))
84 | 
85 |     # split in train and test set according to used strat
86 |     # TODO this should be solved in a less ugly/ad-hoc fashion!
87 |     if not test_strat_id:
88 |         test_strat_id = np.random.randint(len(data["strats"]))
89 |     for i in range(len(data["strats"])):
90 |         if i != test_strat_id:
91 |             train = np.concatenate((train, data["strats"][i]))
92 |         else:
93 |             test = np.concatenate((test, data["strats"][i]))
94 | 
95 |     return data, train, test
96 | 
97 | 
98 | def exp_lr_scheduler(optimizer, epoch, init_lr=5e-3, lr_decay_epoch=40):
99 |     """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
100 |     lr = init_lr * (0.1**(epoch // lr_decay_epoch))
101 |     if epoch % lr_decay_epoch == 0:
102 |         print('LR is set to {}'.format(lr))
103 |     for param_group in optimizer.param_groups:
104 |         param_group['lr'] = lr
105 |     return optimizer
106 | 
107 | 
108 | def count_params(model):
109 |     return sum([np.prod(p.size())
110 |                 for p in model.parameters()
111 |                 if p.requires_grad])
112 | 
--------------------------------------------------------------------------------
/examples/mnist/run_classic.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=E1101,R,C
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 | import torch.utils.data as data_utils
7 | import gzip
8 | import pickle
9 | 
10 | from torch.autograd import Variable
11 | 
12 | MNIST_PATH = "s2_mnist.gz"
13 | 
14 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 | 
16 | NUM_EPOCHS = 20
17 | BATCH_SIZE = 32
18 | LEARNING_RATE = 5e-4
19 | 
20 | 
21 | class ConvNet(nn.Module):
22 | 
23 |     def __init__(self):
24 |         super().__init__()
25 | 
26 |         f1 = 32
27 |         f2 = 64
28 | 
29 |         self.feature_layer = nn.Sequential(
30 |             torch.nn.Conv2d(1, f1, kernel_size=5, stride=3),
31 |             torch.nn.ReLU(),
32 |             torch.nn.Conv2d(f1, f2, kernel_size=5, stride=3),
33 |             torch.nn.ReLU()
34 |         )
35 |         self.out_layer = torch.nn.Linear(f2 * 5**2, 10)
36 | 
37 |     def forward(self, x):
38 |         x = self.feature_layer(x)
39 |         x = x.view(x.shape[0], -1)
40 |         x = self.out_layer(x)
41 |         return x
42 | 
43 | 
44 | def load_data(path, batch_size):
45 | 
46 |     with gzip.open(path, 'rb') as f:
47 |         dataset = pickle.load(f)
48 | 
49 |     train_data = torch.from_numpy(
50 |         dataset["train"]["images"][:, None, :, :].astype(np.float32))
51 |     train_labels = torch.from_numpy(
52 |         dataset["train"]["labels"].astype(np.int64))
53 | 
54 |     # TODO normalize dataset
55 |     # mean = train_data.mean()
56 |     # stdv = train_data.std()
57 | 
58 |     train_dataset = data_utils.TensorDataset(train_data, train_labels)
59 |     train_loader = data_utils.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True) 60 | 61 | test_data = torch.from_numpy( 62 | dataset["test"]["images"][:, None, :, :].astype(np.float32)) 63 | test_labels = torch.from_numpy( 64 | dataset["test"]["labels"].astype(np.int64)) 65 | 66 | test_dataset = data_utils.TensorDataset(test_data, test_labels) 67 | test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 68 | 69 | return train_loader, test_loader, train_dataset, test_dataset 70 | 71 | 72 | def main(): 73 | 74 | train_loader, test_loader, train_dataset, _ = load_data( 75 | MNIST_PATH, BATCH_SIZE) 76 | 77 | classifier = ConvNet() 78 | classifier.to(DEVICE) 79 | 80 | print("#params", sum([x.numel() for x in classifier.parameters()])) 81 | 82 | 83 | criterion = nn.CrossEntropyLoss() 84 | criterion = criterion.to(DEVICE) 85 | 86 | optimizer = torch.optim.Adam( 87 | classifier.parameters(), 88 | lr=LEARNING_RATE) 89 | 90 | for epoch in range(NUM_EPOCHS): 91 | for i, (images, labels) in enumerate(train_loader): 92 | classifier.train() 93 | 94 | images = images.to(DEVICE) 95 | labels = labels.to(DEVICE) 96 | 97 | optimizer.zero_grad() 98 | outputs = classifier(images) 99 | loss = criterion(outputs, labels) 100 | loss.backward() 101 | 102 | optimizer.step() 103 | 104 | print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format( 105 | epoch+1, NUM_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, 106 | loss.item()), end="") 107 | print("") 108 | correct = 0 109 | total = 0 110 | for i, (images, labels) in enumerate(test_loader): 111 | classifier.eval() 112 | 113 | with torch.no_grad(): 114 | images = images.to(DEVICE) 115 | labels = labels.to(DEVICE) 116 | 117 | outputs = classifier(images) 118 | _, predicted = torch.max(outputs, 1) 119 | total += labels.size(0) 120 | correct += (predicted == labels).long().sum().item() 121 | 122 | print('Test Accuracy: {0}'.format(100 * correct / total)) 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /examples/shrec17/test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import os 3 | import numpy as np 4 | import shutil 5 | import requests 6 | import zipfile 7 | from dataset import Shrec17, CacheNPY, ToMesh, ProjectOnSphere 8 | from subprocess import check_output 9 | import torch 10 | import torchvision 11 | import types 12 | import importlib.machinery 13 | 14 | 15 | class KeepName: 16 | def __init__(self, transform): 17 | self.transform = transform 18 | 19 | def __call__(self, file_name): 20 | return file_name, self.transform(file_name) 21 | 22 | 23 | def main(log_dir, augmentation, dataset, batch_size, num_workers): 24 | print(check_output(["nodejs", "--version"]).decode("utf-8")) 25 | 26 | torch.backends.cudnn.benchmark = True 27 | 28 | # Increasing `repeat` will generate more cached files 29 | transform = torchvision.transforms.Compose([ 30 | CacheNPY(prefix="b64_", repeat=augmentation, pick_randomly=False, transform=torchvision.transforms.Compose( 31 | [ 32 | ToMesh(random_rotations=True, random_translation=0.1), 33 | ProjectOnSphere(bandwidth=64) 34 | ] 35 | )), 36 | lambda xs: torch.stack([torch.FloatTensor(x) for x in xs]) 37 | ]) 38 | transform = KeepName(transform) 39 | 40 | test_set = Shrec17("data", dataset, perturbed=True, download=True, transform=transform) 41 | 42 | loader = importlib.machinery.SourceFileLoader('model', os.path.join(log_dir, "model.py")) 43 | mod = types.ModuleType(loader.name) 
44 | loader.exec_module(mod) 45 | 46 | model = mod.Model(55) 47 | model.cuda() 48 | 49 | model.load_state_dict(torch.load(os.path.join(log_dir, "state.pkl"))) 50 | 51 | resdir = os.path.join(log_dir, dataset + "_perturbed") 52 | if os.path.isdir(resdir): 53 | shutil.rmtree(resdir) 54 | os.mkdir(resdir) 55 | 56 | predictions = [] 57 | ids = [] 58 | 59 | loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, drop_last=False) 60 | 61 | for batch_idx, data in enumerate(loader): 62 | model.eval() 63 | 64 | if dataset != "test": 65 | data = data[0] 66 | 67 | file_names, data = data 68 | batch_size, rep = data.size()[:2] 69 | data = data.view(-1, *data.size()[2:]) 70 | 71 | data = data.cuda() 72 | with torch.no_grad(): 73 | pred = model(data).data 74 | pred = pred.view(batch_size, rep, -1) 75 | pred = pred.sum(1) 76 | 77 | predictions.append(pred.cpu().numpy()) 78 | ids.extend([x.split("/")[-1].split(".")[0] for x in file_names]) 79 | 80 | print("[{}/{}] ".format(batch_idx, len(loader))) 81 | 82 | predictions = np.concatenate(predictions) 83 | 84 | predictions_class = np.argmax(predictions, axis=1) 85 | 86 | for i in range(len(ids)): 87 | if i % 100 == 0: 88 | print("{}/{} ".format(i, len(ids)), end="\r") 89 | idfile = os.path.join(resdir, ids[i]) 90 | 91 | retrieved = [(predictions[j, predictions_class[j]], ids[j]) for j in range(len(ids)) if predictions_class[j] == predictions_class[i]] 92 | retrieved = sorted(retrieved, reverse=True) 93 | retrieved = [i for _, i in retrieved] 94 | 95 | with open(idfile, "w") as f: 96 | f.write("\n".join(retrieved)) 97 | 98 | url = "https://shapenet.cs.stanford.edu/shrec17/code/evaluator.zip" 99 | file_path = "evaluator.zip" 100 | 101 | r = requests.get(url, stream=True) 102 | with open(file_path, 'wb') as f: 103 | for chunk in r.iter_content(chunk_size=16 * 1024 ** 2): 104 | if chunk: # filter out keep-alive new chunks 105 | f.write(chunk) 106 | f.flush() 107 | 108 | zip_ref = zipfile.ZipFile(file_path, 'r') 109 | zip_ref.extractall(".") 110 | zip_ref.close() 111 | 112 | print(check_output(["nodejs", "evaluate.js", os.path.join("..", log_dir) + "/"], cwd="evaluator").decode("utf-8")) 113 | shutil.copy2(os.path.join("evaluator", log_dir + ".summary.csv"), os.path.join(log_dir, "summary.csv")) 114 | 115 | if __name__ == "__main__": 116 | import argparse 117 | 118 | parser = argparse.ArgumentParser() 119 | 120 | parser.add_argument("--log_dir", type=str, required=True) 121 | parser.add_argument("--augmentation", type=int, default=1, 122 | help="Generate multiple image with random rotations and translations") 123 | parser.add_argument("--dataset", choices={"test", "val", "train"}, default="val") 124 | parser.add_argument("--batch_size", type=int, default=32) 125 | parser.add_argument("--num_workers", type=int, default=1) 126 | 127 | args = parser.parse_args() 128 | 129 | main(**args.__dict__) 130 | -------------------------------------------------------------------------------- /tests/so3_fft.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C,R,E1101,E1102,W0621 2 | ''' 3 | Compare so3_ft with so3_fft 4 | ''' 5 | import torch 6 | from functools import partial 7 | 8 | def test_so3_rfft(b_in, b_out, device): 9 | x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, dtype=torch.float, device=device) # [beta, alpha, gamma] 10 | 11 | from s2cnn.soft.so3_fft import so3_rfft 12 | y1 = so3_rfft(x, b_out=b_out) 13 | 14 | from s2cnn import so3_rft, 
so3_soft_grid 15 | import lie_learn.spaces.S3 as S3 16 | 17 | # so3_ft computes a non weighted Fourier transform 18 | weights = torch.tensor(S3.quadrature_weights(b_in), dtype=torch.float, device=device) 19 | x2 = torch.einsum("bac,b->bac", (x, weights)) 20 | 21 | y2 = so3_rft(x2.view(-1), b_out, so3_soft_grid(b_in)) 22 | assert (y1 - y2).abs().max().item() < 1e-4 * y1.abs().mean().item() 23 | 24 | test_so3_rfft(7, 5, torch.device("cpu")) 25 | # test_so3_rfft(5, 7, torch.device("cpu")) # so3_rft introduce aliasing 26 | 27 | if torch.cuda.is_available(): 28 | test_so3_rfft(7, 5, torch.device("cuda:0")) 29 | # test_so3_rfft(5, 7, torch.device("cuda:0")) # so3_rft introduce aliasing 30 | 31 | 32 | 33 | 34 | def test_inverse(f, g, b_in, b_out, device, complex): 35 | if complex: 36 | x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, 2, dtype=torch.float, device=device) # [beta, alpha, gamma] 37 | else: 38 | x = torch.randn(2 * b_in, 2 * b_in, 2 * b_in, dtype=torch.float, device=device) # [beta, alpha, gamma] 39 | 40 | x = g(f(x, b_out=b_out), b_out=b_in) 41 | 42 | y = g(f(x, b_out=b_out), b_out=b_in) 43 | 44 | assert (x - y).abs().max().item() < 1e-4 * y.abs().mean().item() 45 | 46 | 47 | def test_inverse2(f, g, b_in, b_out, device): 48 | x = torch.randn(b_in * (4 * b_in**2 - 1) // 3, 2, dtype=torch.float, device=device) # [beta, alpha, gamma] 49 | 50 | x = g(f(x, b_out=b_out), b_out=b_in) 51 | 52 | y = g(f(x, b_out=b_out), b_out=b_in) 53 | 54 | assert (x - y).abs().max().item() < 1e-4 * y.abs().mean().item() 55 | 56 | 57 | from s2cnn.soft.so3_fft import so3_fft, so3_ifft 58 | test_inverse(so3_fft, so3_ifft, 7, 7, torch.device("cpu"), True) 59 | test_inverse(so3_fft, so3_ifft, 5, 4, torch.device("cpu"), True) 60 | test_inverse(so3_fft, so3_ifft, 7, 4, torch.device("cpu"), True) 61 | 62 | test_inverse2(so3_ifft, so3_fft, 7, 7, torch.device("cpu")) 63 | test_inverse2(so3_ifft, so3_fft, 5, 4, torch.device("cpu")) 64 | test_inverse2(so3_ifft, so3_fft, 4, 7, torch.device("cpu")) 65 | 66 | if torch.cuda.is_available(): 67 | test_inverse(so3_fft, so3_ifft, 7, 7, torch.device("cuda:0"), True) 68 | test_inverse(so3_fft, so3_ifft, 7, 5, torch.device("cuda:0"), True) 69 | test_inverse(so3_fft, so3_ifft, 4, 6, torch.device("cuda:0"), True) 70 | 71 | test_inverse2(so3_ifft, so3_fft, 7, 7, torch.device("cuda:0")) 72 | test_inverse2(so3_ifft, so3_fft, 5, 4, torch.device("cuda:0")) 73 | test_inverse2(so3_ifft, so3_fft, 4, 7, torch.device("cuda:0")) 74 | 75 | from s2cnn.soft.so3_fft import so3_rfft, so3_rifft 76 | test_inverse(so3_rfft, so3_rifft, 7, 7, torch.device("cpu"), False) 77 | test_inverse(so3_rfft, so3_rifft, 5, 4, torch.device("cpu"), False) 78 | test_inverse(so3_rfft, so3_rifft, 4, 6, torch.device("cpu"), False) 79 | 80 | test_inverse2(so3_rifft, so3_rfft, 7, 7, torch.device("cpu")) 81 | test_inverse2(so3_rifft, so3_rfft, 5, 4, torch.device("cpu")) 82 | test_inverse2(so3_rifft, so3_rfft, 4, 7, torch.device("cpu")) 83 | 84 | if torch.cuda.is_available(): 85 | test_inverse(so3_rfft, so3_rifft, 7, 7, torch.device("cuda:0"), False) 86 | test_inverse(so3_rfft, so3_rifft, 5, 4, torch.device("cuda:0"), False) 87 | test_inverse(so3_rfft, so3_rifft, 4, 6, torch.device("cuda:0"), False) 88 | 89 | test_inverse2(so3_rifft, so3_rfft, 7, 7, torch.device("cuda:0")) 90 | test_inverse2(so3_rifft, so3_rfft, 5, 4, torch.device("cuda:0")) 91 | test_inverse2(so3_rifft, so3_rfft, 4, 7, torch.device("cuda:0")) 92 | 93 | 94 | 95 | def compare_cpu_gpu(f, x): 96 | z1 = f(x.cpu()) 97 | z2 = f(x.cuda()).cpu() 98 | 99 | q = (z1 - 
z2).abs().max().item() / z1.std().item() 100 | assert q < 1e-4 101 | 102 | for b_in, b_out in [(2, 9), (6, 6), (9, 2), (10, 11), (11, 10)]: 103 | x = torch.rand(2 * b_in, 2 * b_in, 2 * b_in, 2) # [..., beta, alpha, gamma, complex] 104 | compare_cpu_gpu(partial(so3_fft, b_out=b_out), x) 105 | 106 | x = torch.rand(2 * b_in, 2 * b_in, 2 * b_in) # [..., beta, alpha, gamma] 107 | compare_cpu_gpu(partial(so3_rfft, b_out=b_out), x) 108 | 109 | x = torch.rand(b_in * (4 * b_in**2 - 1) // 3, 2) # [l * m * n, ..., complex] 110 | compare_cpu_gpu(partial(so3_ifft, b_out=b_out), x) 111 | 112 | x = torch.rand(2 * b_in, 2 * b_in, 2 * b_in) # [..., beta, alpha, gamma] 113 | x = so3_rfft(x) 114 | compare_cpu_gpu(partial(so3_rifft, b_out=b_out), x) 115 | 116 | -------------------------------------------------------------------------------- /examples/molecules/s2cnn_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from s2cnn.soft.so3_conv import SO3Convolution 6 | from s2cnn.soft.s2_conv import S2Convolution 7 | from s2cnn.soft.so3_integrate import so3_integrate 8 | from s2cnn.so3_grid import so3_near_identity_grid 9 | from s2cnn.s2_grid import s2_near_identity_grid 10 | 11 | 12 | nonlinearity = F.relu 13 | AFFINE = True 14 | 15 | 16 | class S2Block(nn.Module): 17 | """ simple s2 convolution block """ 18 | 19 | def __init__(self, b_in, b_out, f_in, f_out): 20 | """ b_in/b_out: bandwidth of input/output signals 21 | f_in/f_out: filters in input/output signals """ 22 | 23 | super(S2Block, self).__init__() 24 | 25 | self.grid_s2 = s2_near_identity_grid( 26 | n_alpha=2*b_in, n_beta=2) 27 | 28 | self.cnn = S2Convolution( 29 | nfeature_in=f_in, 30 | nfeature_out=f_out, 31 | b_in=b_in, 32 | b_out=b_out, 33 | grid=self.grid_s2) 34 | 35 | self.bn = nn.BatchNorm3d(f_out, affine=AFFINE) 36 | 37 | def forward(self, x): 38 | x = self.cnn(x) 39 | x = self.bn(x) 40 | x = nonlinearity(x) 41 | return x 42 | 43 | 44 | class So3Block(nn.Module): 45 | """ simple so3 convolution block """ 46 | 47 | def __init__(self, b_in, b_out, f_in, f_out): 48 | """ b_in/b_out: bandwidth of input/output signals 49 | f_in/f_out: filters in input/output signals """ 50 | 51 | super(So3Block, self).__init__() 52 | 53 | self.grid_so3 = so3_near_identity_grid( 54 | n_alpha=2*b_in, n_beta=2, n_gamma=2) 55 | 56 | self.cnn = SO3Convolution( 57 | nfeature_in=f_in, 58 | nfeature_out=f_out, 59 | b_in=b_in, 60 | b_out=b_out, 61 | grid=self.grid_so3) 62 | 63 | self.bn = nn.BatchNorm3d(f_out, affine=AFFINE) 64 | 65 | def forward(self, x): 66 | x = self.cnn(x) 67 | x = self.bn(x) 68 | x = nonlinearity(x) 69 | return x 70 | 71 | 72 | class DeepSet(nn.Module): 73 | """ deep set block """ 74 | 75 | def __init__(self, f, h1, h_latent, h2, n_objs): 76 | """ f: input filters 77 | h1, h2: hidden units for encoder/decoder mlps 78 | h_latent: dimensions 79 | n_objs: of objects to aggregate in latent space """ 80 | 81 | super(DeepSet, self).__init__() 82 | self.f = f 83 | self.h1 = h1 84 | self.h3 = h2 85 | self.n_objs = n_objs 86 | 87 | # encoder 88 | self.emb_h = nn.Linear(f, h1) 89 | self.emb_rep = nn.Linear(h1, h_latent) 90 | 91 | # decoder 92 | self.proj_h = nn.Linear(h_latent, h2) 93 | self.proj = nn.Linear(h2, 1) 94 | 95 | self.bn1 = nn.BatchNorm1d(h1, affine=AFFINE) 96 | self.bn2 = nn.BatchNorm1d(h_latent, affine=AFFINE) 97 | self.bn3 = nn.BatchNorm1d(h2, affine=AFFINE) 98 | 99 | def forward(self, x, mask): 
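# x: (batch * atoms, features) per-atom descriptors pooled by so3_integrate,
# mask: (batch, atoms, 1) marks real atoms (1.0) vs. NULL padding (0.0)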
100 | 
101 |         # encode atoms
102 |         x = self.emb_h(x)
103 |         x = self.bn1(x)
104 |         x = nonlinearity(x)
105 |         x = self.emb_rep(x)
106 |         x = self.bn2(x)
107 |         x = nonlinearity(x)
108 | 
109 |         # reshape (batch * atoms, features) -> (batch, atoms, features)
110 |         n, h_latent = x.size()
111 |         x = x.view(n // self.n_objs, self.n_objs, h_latent)
112 | 
113 |         # sum over latent atoms, filter out NULL atoms with mask
114 |         x = torch.sum(x * mask, dim=1)
115 | 
116 |         # decode to final energy
117 |         x = self.proj_h(x)
118 |         x = self.bn3(x)
119 |         x = nonlinearity(x)
120 |         x = self.proj(x)
121 | 
122 |         return x
123 | 
124 | 
125 | class S2CNNRegressor(nn.Module):
126 |     """ approximate energy using spherical representations """
127 | 
128 |     def __init__(self):
129 |         super(S2CNNRegressor, self).__init__()
130 | 
131 |         # number of atoms in a molecule
132 |         n_objs = 23
133 | 
134 |         self.blocks = [
135 |             S2Block(b_in=10, f_in=5, b_out=8, f_out=8),
136 |             So3Block(b_in=8, b_out=6, f_in=8, f_out=16),
137 |             So3Block(b_in=6, b_out=4, f_in=16, f_out=32),
138 |             So3Block(b_in=4, b_out=2, f_in=32, f_out=64),
139 |         ]
140 | 
141 |         # TODO: replace with nn.Sequential or similar
142 |         for i, block in enumerate(self.blocks):
143 |             setattr(self, "block{0}".format(i), block)
144 | 
145 |         self.ds = DeepSet(64, 256, 64, 512, n_objs)
146 | 
147 |     def forward(self, x, atom_types):
148 | 
149 |         n_batch, n_atoms, n_features, bandwidth, _ = x.size()
150 | 
151 |         # compute mask of atoms which are present
152 |         # this avoids having to learn representations for NULL atoms
153 |         mask = (atom_types > 0).view(n_batch, n_atoms, 1).float()
154 | 
155 |         # push atoms to batch dimension
156 |         x = x.view(n_batch * n_atoms, n_features, bandwidth, bandwidth)
157 | 
158 |         # propagate through convolutions
159 |         for block in self.blocks:
160 |             x = block(x)
161 | 
162 |         # integrate over SO(3)
163 |         x = so3_integrate(x)
164 | 
165 |         # combine atom representations to final energy
166 |         y = self.ds(x, mask)
167 | 
168 |         return y
169 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | :warning: :warning: This code is old and does not support the latest versions of PyTorch, in particular since the change in the fft interface! :warning: :warning:
2 | 
3 | # Spherical CNNs
4 | ## Equivariant CNNs for the sphere and SO(3) implemented in PyTorch
5 | 
6 | ![Equivariance](https://github.com/jonas-koehler/s2cnn/raw/master/examples/equivariance_plot/fig.jpeg)
7 | 
8 | ## Overview
9 | This library contains a PyTorch implementation of the rotation-equivariant CNNs for spherical signals (e.g. omnidirectional images, signals on the globe) as presented in [[1]](https://arxiv.org/abs/1801.10130). Equivariant networks for the plane are available [here](https://github.com/tscohen/GrouPy).
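A minimal sketch of the core API, adapted from `examples/mnist/run.py` (the bandwidths and feature counts below are illustrative, not prescriptive):

```python
import torch
from s2cnn import S2Convolution, SO3Convolution, so3_integrate
from s2cnn import s2_near_identity_grid, so3_near_identity_grid

b = 30                               # input bandwidth
x = torch.randn(4, 1, 2 * b, 2 * b)  # [batch, feature, beta, alpha] signal on S^2

# an S2 convolution lifts the spherical signal to a signal on SO(3)
conv1 = S2Convolution(nfeature_in=1, nfeature_out=8, b_in=b, b_out=10,
                      grid=s2_near_identity_grid())
# subsequent layers convolve on SO(3)
conv2 = SO3Convolution(nfeature_in=8, nfeature_out=16, b_in=10, b_out=6,
                       grid=so3_near_identity_grid())

y = conv2(conv1(x))   # [batch, 16, 12, 12, 12] signal on SO(3)
y = so3_integrate(y)  # [batch, 16] rotation-invariant features
```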
10 | 
11 | ## Dependencies
12 | 
13 | * __PyTorch__: http://pytorch.org/ (>= 0.4.0)
14 | * __cupy__: https://github.com/cupy/cupy
15 | * __lie_learn__: https://github.com/AMLab-Amsterdam/lie_learn
16 | * __pynvrtc__: https://github.com/NVIDIA/pynvrtc
17 | 
18 | (commands to install all the dependencies in a new conda environment)
19 | ```bash
20 | conda create --name cuda9 python=3.6
21 | conda activate cuda9
22 | 
23 | # s2cnn deps
24 | #conda install pytorch torchvision cuda90 -c pytorch # get correct command line at http://pytorch.org/
25 | conda install -c anaconda cupy
26 | pip install pynvrtc joblib
27 | 
28 | # lie_learn deps
29 | conda install -c anaconda cython
30 | conda install -c anaconda requests
31 | 
32 | # shrec17 example dep
33 | conda install -c anaconda scipy
34 | conda install -c conda-forge rtree shapely
35 | conda install -c conda-forge pyembree
36 | pip install "trimesh[easy]"
37 | ```
38 | 
39 | ## Installation
40 | 
41 | To install, run
42 | 
43 | ```bash
44 | $ python setup.py install
45 | ```
46 | 
47 | ## Usage
48 | Please have a look at the [examples](examples).
49 | 
50 | Please cite [[1]](https://arxiv.org/abs/1801.10130) in your work when using this library in your experiments.
51 | 
52 | 
53 | ## Design choices for Spherical CNN Architectures
54 | 
55 | Spherical CNNs come with different choices of grids and grid hyperparameters which, at first glance, are not obviously related to those of conventional CNNs.
56 | The `s2_near_identity_grid` and `so3_near_identity_grid` are the preferred choices since they correspond to spatially localized kernels, defined at the north pole and rotated over the sphere via the action of SO(3).
57 | In contrast, `s2_equatorial_grid` and `so3_equatorial_grid` define line-like (or ring-like) kernels around the equator.
58 | 
59 | To clarify the possible parameter choices for `s2_near_identity_grid`:
60 | #### max_beta:
61 | Sets the size of the kernel as an angle measured from the north pole.
62 | Conventional CNNs on flat space usually use a fixed kernel size but pool the signal spatially.
63 | This spatial pooling gives the kernels in later layers an effectively increased field of view.
64 | One can emulate pooling by a factor of 2 in spherical CNNs by halving the signal bandwidth and doubling `max_beta`.
65 | #### n_beta:
66 | Number of rings of the kernel around the north pole, equally spaced in
67 | [β=0, β=`max_beta`].
68 | The choice `n_beta=1` corresponds to a small 3x3 kernel in `conv2d` since in both cases the resulting kernel consists of one central pixel and one ring around the center.
69 | #### n_alpha:
70 | Gives the number of learned parameters of the rings around the pole.
71 | By default these values are equally spaced in azimuth.
72 | A sensible number of values depends on the bandwidth and `max_beta`, since a higher resolution or larger spatial extent allows sampling finer kernels without producing aliased results.
73 | In practice this value is typically set to a constant, low value like 6 or 8.
74 | A reduced bandwidth of the signal is thereby counteracted by an increased `max_beta` to emulate spatial pooling.
75 | 
76 | The `so3_near_identity_grid` has two additional parameters `max_gamma` and `n_gamma`.
77 | SO(3) can be seen as a (principal) fiber bundle SO(3)→S² with the sphere S² as base space and fiber SO(2) attached to each point.
78 | The additional parameters control the grid on the fiber in the following way:
79 | #### max_gamma:
80 | The kernel spans the fiber SO(2) over γ∈[0, `max_gamma`].
81 | The fiber SO(2) encodes the kernel responses for every sampled orientation at a given position on the sphere.
82 | Setting `max_gamma`<2π results in the kernel not seeing the responses of all kernel orientations simultaneously, which is generally unfavorable.
83 | Steerable CNNs [[3]](https://arxiv.org/abs/1803.10743) always use `max_gamma`=2π.
84 | #### n_gamma:
85 | Number of learned parameters on the fiber.
86 | Typically set equal to `n_alpha`, i.e. to a low value like 6 or 8.
87 | 
88 | See the deep model of the MNIST example for how these parameters can be adapted over the layers.
89 | 
90 | 
91 | 
92 | ## Feedback
93 | For questions and comments, feel free to contact us: **geiger.mario (gmail)**, taco.cohen (gmail), jonas (argmin.xyz).
94 | 
95 | 
96 | ## License
97 | MIT
98 | 
99 | ## References
100 | 
101 | [1] Taco S. Cohen, Mario Geiger, Jonas Köhler, Max Welling,
102 | [Spherical CNNs](https://arxiv.org/abs/1801.10130).
103 | International Conference on Learning Representations (ICLR), 2018.
104 | 
105 | [2] Taco S. Cohen, Mario Geiger, Jonas Köhler, Max Welling,
106 | [Convolutional Networks for Spherical Signals](https://arxiv.org/abs/1709.04893).
107 | ICML Workshop on Principled Approaches to Deep Learning, 2017.
108 | 
109 | [3] Taco S. Cohen, Mario Geiger, Maurice Weiler,
110 | [Intertwiners between Induced Representations (with applications to the theory of equivariant neural networks)](https://arxiv.org/abs/1803.10743),
111 | ArXiv preprint 1803.10743, 2018.
112 | 
-------------------------------------------------------------------------------- /examples/shrec17/train.py: --------------------------------------------------------------------------------
1 | # pylint: disable=E1101,R,C,W1202
2 | import torch
3 | import torch.nn.functional as F
4 | import torchvision
5 | 
6 | import os
7 | import shutil
8 | import time
9 | import logging
10 | import copy
11 | import types
12 | import importlib.machinery
13 | 
14 | from dataset import Shrec17, CacheNPY, ToMesh, ProjectOnSphere
15 | 
16 | 
17 | def main(log_dir, model_path, augmentation, dataset, batch_size, learning_rate, num_workers):
18 |     arguments = copy.deepcopy(locals())
19 | 
20 |     os.mkdir(log_dir)
21 |     shutil.copy2(__file__, os.path.join(log_dir, "script.py"))
22 |     shutil.copy2(model_path, os.path.join(log_dir, "model.py"))
23 | 
24 |     logger = logging.getLogger("train")
25 |     logger.setLevel(logging.DEBUG)
26 |     logger.handlers = []
27 |     ch = logging.StreamHandler()
28 |     logger.addHandler(ch)
29 |     fh = logging.FileHandler(os.path.join(log_dir, "log.txt"))
30 |     logger.addHandler(fh)
31 | 
32 |     logger.info("%s", repr(arguments))
33 | 
34 |     torch.backends.cudnn.benchmark = True
35 | 
36 |     # Load the model
37 |     loader = importlib.machinery.SourceFileLoader('model', os.path.join(log_dir, "model.py"))
38 |     mod = types.ModuleType(loader.name)
39 |     loader.exec_module(mod)
40 | 
41 |     model = mod.Model(55)
42 |     model.cuda()
43 | 
44 |     logger.info("{} parameters in total".format(sum(x.numel() for x in model.parameters())))
45 |     logger.info("{} parameters in the last layer".format(sum(x.numel() for x in model.out_layer.parameters())))
46 | 
47 |     bw = model.bandwidths[0]
48 | 
49 |     # Load the dataset
50 |     # Increasing `repeat` will generate more cached files
51 |     transform = CacheNPY(prefix="b{}_".format(bw), repeat=augmentation, transform=torchvision.transforms.Compose(
52 |         [
53 |             ToMesh(random_rotations=True, random_translation=0.1),
54 |             ProjectOnSphere(bandwidth=bw)
55 |         ]
56 |     ))
57 | 
58 |     def target_transform(x):
59 |         classes = 
['02691156', '02747177', '02773838', '02801938', '02808440', '02818832', '02828884', '02843684', '02871439', '02876657',
60 |                    '02880940', '02924116', '02933112', '02942699', '02946921', '02954340', '02958343', '02992529', '03001627', '03046257',
61 |                    '03085013', '03207941', '03211117', '03261776', '03325088', '03337140', '03467517', '03513137', '03593526', '03624134',
62 |                    '03636649', '03642806', '03691459', '03710193', '03759954', '03761084', '03790512', '03797390', '03928116', '03938244',
63 |                    '03948459', '03991062', '04004475', '04074963', '04090263', '04099429', '04225987', '04256520', '04330267', '04379243',
64 |                    '04401088', '04460130', '04468005', '04530566', '04554684']
65 |         return classes.index(x[0])
66 | 
67 |     train_set = Shrec17("data", dataset, perturbed=True, download=True, transform=transform, target_transform=target_transform)
68 | 
69 |     train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True)
70 | 
71 |     optimizer = torch.optim.SGD(model.parameters(), lr=0, momentum=0.9)
72 | 
73 |     def train_step(data, target):
74 |         model.train()
75 |         data, target = data.cuda(), target.cuda()
76 | 
77 |         prediction = model(data)
78 |         loss = F.nll_loss(prediction, target)
79 | 
80 |         optimizer.zero_grad()
81 |         loss.backward()
82 |         optimizer.step()
83 | 
84 |         correct = prediction.data.max(1)[1].eq(target.data).long().cpu().sum()
85 | 
86 |         return loss.item(), correct.item()
87 | 
88 |     def get_learning_rate(epoch):
89 |         limits = [100, 200]
90 |         lrs = [1, 0.1, 0.01]
91 |         assert len(lrs) == len(limits) + 1
92 |         for lim, lr in zip(limits, lrs):
93 |             if epoch < lim:
94 |                 return lr * learning_rate
95 |         return lrs[-1] * learning_rate
96 | 
97 |     for epoch in range(300):
98 | 
99 |         lr = get_learning_rate(epoch)
100 |         logger.info("learning rate = {} and batch size = {}".format(lr, train_loader.batch_size))
101 |         for p in optimizer.param_groups:
102 |             p['lr'] = lr
103 | 
104 |         total_loss = 0
105 |         total_correct = 0
106 |         time_before_load = time.perf_counter()
107 |         for batch_idx, (data, target) in enumerate(train_loader):
108 |             time_after_load = time.perf_counter()
109 |             time_before_step = time.perf_counter()
110 |             loss, correct = train_step(data, target)
111 | 
112 |             total_loss += loss
113 |             total_correct += correct
114 | 
115 |             logger.info("[{}:{}/{}] LOSS={:.2} <LOSS>={:.2} ACC={:.2} <ACC>={:.2} time={:.2}+{:.2}".format(
116 |                 epoch, batch_idx, len(train_loader),
117 |                 loss, total_loss / (batch_idx + 1),
118 |                 correct / len(data), total_correct / len(data) / (batch_idx + 1),
119 |                 time_after_load - time_before_load,
120 |                 time.perf_counter() - time_before_step))
121 |             time_before_load = time.perf_counter()
122 | 
123 |         torch.save(model.state_dict(), os.path.join(log_dir, "state.pkl"))
124 | 
125 | 
126 | if __name__ == "__main__":
127 |     import argparse
128 | 
129 |     parser = argparse.ArgumentParser()
130 | 
131 |     parser.add_argument("--log_dir", type=str, required=True)
132 |     parser.add_argument("--model_path", type=str, required=True)
133 |     parser.add_argument("--augmentation", type=int, default=1,
134 |                         help="Generate multiple images with random rotations and translations")
135 |     parser.add_argument("--dataset", choices={"test", "val", "train"}, default="train")
136 |     parser.add_argument("--batch_size", type=int, default=32)
137 |     parser.add_argument("--num_workers", type=int, default=1)
138 |     parser.add_argument("--learning_rate", type=float, default=0.5)
139 | 
140 |     args = parser.parse_args()
141 | 
142 |     main(**args.__dict__)
143 | 
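For reference, a standalone sketch of the step schedule that `get_learning_rate` above implements, assuming the default `--learning_rate 0.5`:

```python
def step_lr(epoch, learning_rate=0.5):
    # same logic as get_learning_rate in train.py:
    # multipliers [1, 0.1, 0.01] switching at epochs 100 and 200
    limits = [100, 200]
    lrs = [1, 0.1, 0.01]
    for lim, lr in zip(limits, lrs):
        if epoch < lim:
            return lr * learning_rate
    return lrs[-1] * learning_rate

assert step_lr(0) == 0.5      # epochs 0-99
assert step_lr(150) == 0.05   # epochs 100-199
assert step_lr(250) == 0.005  # epochs 200-299
```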
-------------------------------------------------------------------------------- /examples/molecules/datagen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import scipy.io as spio 4 | from scipy.spatial import distance as spdist 5 | import joblib 6 | import lie_learn.spaces.S2 as S2 7 | 8 | 9 | MAX_NUM_ATOMS_PER_MOLECULE = 23 10 | NUM_ATOM_TYPES = 5 11 | 12 | 13 | def get_raw_data(path): 14 | """ load data from matlab file """ 15 | raw = spio.loadmat(path) 16 | coordinates = raw["R"] 17 | charges = raw["Z"] 18 | energies = raw["T"] 19 | strat_ids = raw["P"] 20 | return coordinates, charges, energies, strat_ids 21 | 22 | 23 | def get_projection_grid(bandwidth, grid_type="Driscoll-Healy"): 24 | theta, phi = S2.meshgrid(b=bandwidth, grid_type=grid_type) 25 | x_ = np.sin(theta) * np.cos(phi) 26 | y_ = np.sin(theta) * np.sin(phi) 27 | z_ = np.cos(theta) 28 | return np.array((x_, y_, z_)) 29 | 30 | 31 | def compute_features_for_molecule(unique_charges, mol_coords, mol_charges, 32 | atom_grids, grid_bandwidth, min_atom_dist=1): 33 | # output features 34 | num_atoms = len(mol_coords) 35 | mol_features = np.ndarray((len(unique_charges), num_atoms, 36 | 2*grid_bandwidth, 2*grid_bandwidth)) 37 | 38 | # for each possible interacting atom type 39 | # compute one feature map of interaction 40 | for i, probe_charge in enumerate(unique_charges): 41 | features = np.sum(compute_coloumb_forces(atom_grids, probe_charge, 42 | mol_coords, mol_charges, 43 | min_atom_dist), 44 | axis=0) 45 | mol_features[i] = features.reshape(num_atoms, 2*grid_bandwidth, 46 | 2*grid_bandwidth) 47 | return mol_features 48 | 49 | 50 | def compute_coloumb_forces(atom_grids, probe_charge, mol_coords, mol_charges, 51 | min_atom_dist=1): 52 | inv_sq_distance = compute_inv_sq_distances(atom_grids, probe_charge, 53 | mol_coords, mol_charges, 54 | min_atom_dist) 55 | f = inv_sq_distance * probe_charge * mol_charges.reshape(1, -1, 1) 56 | return f 57 | 58 | 59 | def compute_inv_sq_distances(atom_grids, probe_charge, mol_coords, mol_charges, 60 | min_atom_dist=1): 61 | num_atoms = len(mol_coords) 62 | 63 | # all positions of atoms having the probe charge 64 | probe_coords = mol_coords[mol_charges == probe_charge] 65 | 66 | # if there is no atom with the probe charge we are done 67 | if len(probe_coords) == 0: 68 | return np.zeros((1, num_atoms, atom_grids.shape[0]//num_atoms)) 69 | 70 | # spatial distance between z atoms and grid points 71 | distances = spdist.cdist(probe_coords, atom_grids).reshape( 72 | len(probe_coords), num_atoms, -1) 73 | 74 | # for each atom in the molecule set the distance 75 | # to the grid point around that atom to 0 76 | nonzero = (distances - distances.min(axis=1).reshape( 77 | len(probe_coords), 1, -1)) > 0 78 | 79 | distances = nonzero * 1 / distances**2 80 | 81 | return distances 82 | 83 | 84 | def get_min_distance(coords, charges): 85 | num_molecules = len(coords) 86 | distances = [] 87 | for mol_idx in range(num_molecules): 88 | non_null_atoms = charges[mol_idx] != 0 89 | atom_coords = coords[mol_idx][non_null_atoms] 90 | pairwise_atom_distances = spdist.pdist(atom_coords) 91 | distances.append(pairwise_atom_distances.min()) 92 | # TODO remove unncesseary O(N) traversal by accumulator 93 | min_distance = np.min(distances) 94 | return min_distance 95 | 96 | 97 | def generate_dataset(coordinates, charges, energies, strat_ids, 98 | grid_bandwidth): 99 | num_molecules = len(coordinates) 100 | 101 | data = {} 102 | 
data["features"] = { 103 | "geometry": np.zeros((num_molecules, MAX_NUM_ATOMS_PER_MOLECULE, 104 | NUM_ATOM_TYPES, 2*grid_bandwidth, 105 | 2*grid_bandwidth)), 106 | "atom_types": np.zeros((num_molecules, MAX_NUM_ATOMS_PER_MOLECULE,)), 107 | "num_atoms": np.zeros((num_molecules,)) 108 | } 109 | data["targets"] = energies.T 110 | data["strats"] = strat_ids 111 | 112 | unique_charges = np.sort(np.unique(charges)) 113 | 114 | # compute minimum distance in the data set 115 | min_atom_dist = get_min_distance(coordinates, charges) 116 | 117 | # get grid for bandwidth 118 | grid = (get_projection_grid(grid_bandwidth) 119 | * min_atom_dist).reshape(3, -1).T 120 | 121 | for mol_idx in range(num_molecules): 122 | 123 | print("\rprocessing molecule {0}/{1}".format( 124 | mol_idx+1, num_molecules), end="") 125 | 126 | non_null_atoms = charges[mol_idx] != 0 127 | num_non_null_atoms = sum(non_null_atoms) 128 | 129 | # get position and types of non NULL atoms 130 | mol_coords = coordinates[mol_idx][non_null_atoms] 131 | mol_charges = charges[mol_idx][non_null_atoms] 132 | 133 | # copy grid around each atom 134 | atom_grids = (np.stack((grid,)*mol_coords.shape[0], 0) 135 | + mol_coords[:, np.newaxis, :]).reshape(-1, 3) 136 | 137 | mol_features = compute_features_for_molecule( 138 | unique_charges[1:], mol_coords, mol_charges, atom_grids, 139 | grid_bandwidth, min_atom_dist) 140 | 141 | # transpose features to right shape 142 | axes = np.arange(len(mol_features.shape)) 143 | axes[0:2] = (1, 0) 144 | mol_features = np.transpose(mol_features, axes) 145 | 146 | # copy to data dict 147 | data["features"]["geometry"][ 148 | mol_idx, :num_non_null_atoms, ...] = mol_features 149 | data["features"]["atom_types"][mol_idx][ 150 | :num_non_null_atoms] = mol_charges 151 | data["features"]["num_atoms"] = non_null_atoms 152 | 153 | print("") 154 | return data 155 | 156 | 157 | def main(): 158 | 159 | parser = argparse.ArgumentParser() 160 | 161 | parser.add_argument("--bandwidth", 162 | help="the bandwidth of the S2 signal", 163 | type=int, 164 | default=10, 165 | required=False) 166 | parser.add_argument("--data_file", 167 | help="file for saving the data output (.gz file)", 168 | type=str, 169 | default="qm7.mat", 170 | required=False) 171 | parser.add_argument("--output_file", 172 | help="file for saving the data output (.gz file)", 173 | type=str, 174 | default="data.joblib", 175 | required=False) 176 | 177 | args = parser.parse_args() 178 | 179 | raw_data = get_raw_data(args.data_file) 180 | 181 | data = generate_dataset(*raw_data, args.bandwidth) 182 | 183 | print("save to file") 184 | joblib.dump(data, args.output_file) 185 | 186 | 187 | if __name__ == '__main__': 188 | main() 189 | -------------------------------------------------------------------------------- /examples/mnist/run.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import numpy as np 3 | import torch.nn as nn 4 | from s2cnn import SO3Convolution 5 | from s2cnn import S2Convolution 6 | from s2cnn import so3_integrate 7 | from s2cnn import so3_near_identity_grid 8 | from s2cnn import s2_near_identity_grid 9 | import torch.nn.functional as F 10 | import torch 11 | import torch.utils.data as data_utils 12 | import gzip 13 | import pickle 14 | import numpy as np 15 | from torch.autograd import Variable 16 | import argparse 17 | 18 | MNIST_PATH = "s2_mnist.gz" 19 | 20 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | NUM_EPOCHS = 20 23 | BATCH_SIZE = 
32 24 | LEARNING_RATE = 5e-3 25 | 26 | 27 | def load_data(path, batch_size): 28 | 29 | with gzip.open(path, 'rb') as f: 30 | dataset = pickle.load(f) 31 | 32 | train_data = torch.from_numpy( 33 | dataset["train"]["images"][:, None, :, :].astype(np.float32)) 34 | train_labels = torch.from_numpy( 35 | dataset["train"]["labels"].astype(np.int64)) 36 | 37 | # TODO normalize dataset 38 | # mean = train_data.mean() 39 | # stdv = train_data.std() 40 | 41 | train_dataset = data_utils.TensorDataset(train_data, train_labels) 42 | train_loader = data_utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 43 | 44 | test_data = torch.from_numpy( 45 | dataset["test"]["images"][:, None, :, :].astype(np.float32)) 46 | test_labels = torch.from_numpy( 47 | dataset["test"]["labels"].astype(np.int64)) 48 | 49 | test_dataset = data_utils.TensorDataset(test_data, test_labels) 50 | test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 51 | 52 | return train_loader, test_loader, train_dataset, test_dataset 53 | 54 | 55 | class S2ConvNet_original(nn.Module): 56 | 57 | def __init__(self): 58 | super(S2ConvNet_original, self).__init__() 59 | 60 | f1 = 20 61 | f2 = 40 62 | f_output = 10 63 | 64 | b_in = 30 65 | b_l1 = 10 66 | b_l2 = 6 67 | 68 | grid_s2 = s2_near_identity_grid() 69 | grid_so3 = so3_near_identity_grid() 70 | 71 | self.conv1 = S2Convolution( 72 | nfeature_in=1, 73 | nfeature_out=f1, 74 | b_in=b_in, 75 | b_out=b_l1, 76 | grid=grid_s2) 77 | 78 | self.conv2 = SO3Convolution( 79 | nfeature_in=f1, 80 | nfeature_out=f2, 81 | b_in=b_l1, 82 | b_out=b_l2, 83 | grid=grid_so3) 84 | 85 | self.out_layer = nn.Linear(f2, f_output) 86 | 87 | def forward(self, x): 88 | 89 | x = self.conv1(x) 90 | x = F.relu(x) 91 | x = self.conv2(x) 92 | x = F.relu(x) 93 | 94 | x = so3_integrate(x) 95 | 96 | x = self.out_layer(x) 97 | 98 | return x 99 | 100 | 101 | class S2ConvNet_deep(nn.Module): 102 | 103 | def __init__(self, bandwidth=30): 104 | super(S2ConvNet_deep, self).__init__() 105 | 106 | grid_s2 = s2_near_identity_grid(n_alpha=6, max_beta=np.pi/16, n_beta=1) 107 | grid_so3_1 = so3_near_identity_grid(n_alpha=6, max_beta=np.pi/16, n_beta=1, max_gamma=2*np.pi, n_gamma=6) 108 | grid_so3_2 = so3_near_identity_grid(n_alpha=6, max_beta=np.pi/ 8, n_beta=1, max_gamma=2*np.pi, n_gamma=6) 109 | grid_so3_3 = so3_near_identity_grid(n_alpha=6, max_beta=np.pi/ 4, n_beta=1, max_gamma=2*np.pi, n_gamma=6) 110 | grid_so3_4 = so3_near_identity_grid(n_alpha=6, max_beta=np.pi/ 2, n_beta=1, max_gamma=2*np.pi, n_gamma=6) 111 | 112 | self.convolutional = nn.Sequential( 113 | S2Convolution( 114 | nfeature_in = 1, 115 | nfeature_out = 8, 116 | b_in = bandwidth, 117 | b_out = bandwidth, 118 | grid=grid_s2), 119 | nn.ReLU(inplace=False), 120 | SO3Convolution( 121 | nfeature_in = 8, 122 | nfeature_out = 16, 123 | b_in = bandwidth, 124 | b_out = bandwidth//2, 125 | grid=grid_so3_1), 126 | nn.ReLU(inplace=False), 127 | 128 | SO3Convolution( 129 | nfeature_in = 16, 130 | nfeature_out = 16, 131 | b_in = bandwidth//2, 132 | b_out = bandwidth//2, 133 | grid=grid_so3_2), 134 | nn.ReLU(inplace=False), 135 | SO3Convolution( 136 | nfeature_in = 16, 137 | nfeature_out = 24, 138 | b_in = bandwidth//2, 139 | b_out = bandwidth//4, 140 | grid=grid_so3_2), 141 | nn.ReLU(inplace=False), 142 | 143 | SO3Convolution( 144 | nfeature_in = 24, 145 | nfeature_out = 24, 146 | b_in = bandwidth//4, 147 | b_out = bandwidth//4, 148 | grid=grid_so3_3), 149 | nn.ReLU(inplace=False), 150 | SO3Convolution( 151 | nfeature_in = 24, 152 | 
nfeature_out = 32, 153 | b_in = bandwidth//4, 154 | b_out = bandwidth//8, 155 | grid=grid_so3_3), 156 | nn.ReLU(inplace=False), 157 | 158 | SO3Convolution( 159 | nfeature_in = 32, 160 | nfeature_out = 64, 161 | b_in = bandwidth//8, 162 | b_out = bandwidth//8, 163 | grid=grid_so3_4), 164 | nn.ReLU(inplace=False) 165 | ) 166 | 167 | self.linear = nn.Sequential( 168 | # linear 1 169 | nn.BatchNorm1d(64), 170 | nn.Linear(in_features=64,out_features=64), 171 | nn.ReLU(inplace=False), 172 | # linear 2 173 | nn.BatchNorm1d(64), 174 | nn.Linear(in_features=64, out_features=32), 175 | nn.ReLU(inplace=False), 176 | # linear 3 177 | nn.BatchNorm1d(32), 178 | nn.Linear(in_features=32, out_features=10) 179 | ) 180 | 181 | def forward(self, x): 182 | x = self.convolutional(x) 183 | x = so3_integrate(x) 184 | x = self.linear(x) 185 | return x 186 | 187 | 188 | 189 | def main(network): 190 | 191 | train_loader, test_loader, train_dataset, _ = load_data( 192 | MNIST_PATH, BATCH_SIZE) 193 | 194 | if network == 'original': 195 | classifier = S2ConvNet_original() 196 | elif network == 'deep': 197 | classifier = S2ConvNet_deep() 198 | else: 199 | raise ValueError('Unknown network architecture') 200 | classifier.to(DEVICE) 201 | 202 | print("#params", sum(x.numel() for x in classifier.parameters())) 203 | 204 | criterion = nn.CrossEntropyLoss() 205 | criterion = criterion.to(DEVICE) 206 | 207 | optimizer = torch.optim.Adam( 208 | classifier.parameters(), 209 | lr=LEARNING_RATE) 210 | 211 | for epoch in range(NUM_EPOCHS): 212 | for i, (images, labels) in enumerate(train_loader): 213 | classifier.train() 214 | 215 | images = images.to(DEVICE) 216 | labels = labels.to(DEVICE) 217 | 218 | optimizer.zero_grad() 219 | outputs = classifier(images) 220 | loss = criterion(outputs, labels) 221 | loss.backward() 222 | 223 | optimizer.step() 224 | 225 | print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format( 226 | epoch+1, NUM_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, 227 | loss.item()), end="") 228 | print("") 229 | correct = 0 230 | total = 0 231 | for images, labels in test_loader: 232 | 233 | classifier.eval() 234 | 235 | with torch.no_grad(): 236 | images = images.to(DEVICE) 237 | labels = labels.to(DEVICE) 238 | 239 | outputs = classifier(images) 240 | _, predicted = torch.max(outputs, 1) 241 | total += labels.size(0) 242 | correct += (predicted == labels).long().sum().item() 243 | 244 | print('Test Accuracy: {0}'.format(100 * correct / total)) 245 | 246 | 247 | if __name__ == '__main__': 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument("--network", 250 | help="network architecture to use", 251 | default='original', 252 | choices=['original', 'deep']) 253 | args = parser.parse_args() 254 | 255 | main(args.network) 256 | -------------------------------------------------------------------------------- /examples/molecules/run_experiment.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | from s2cnn_model import S2CNNRegressor 7 | from baseline_model import BaselineRegressor 8 | from utils import load_data, IndexBatcher, to_one_hot, exp_lr_scheduler, \ 9 | count_params 10 | import numpy as np 11 | 12 | 13 | OPTIMIZER = torch.optim.Adam 14 | 15 | NUM_ATOM = 23 16 | NUM_ATOM_TYPES = 6 17 | 18 | 19 | def eval_batch_mlp(mlp, data, batch_idxs, criterion, device_id=0): 20 | """ evaluate a batch for the baseline mlp """ 21 | 
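# the baseline is geometry-free: it regresses the target energy from the
# one-hot encoded atom types alone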
atom_types = to_one_hot(data["features"]["atom_types"][batch_idxs, ...], 22 | NUM_ATOM_TYPES) 23 | targets = data["targets"][batch_idxs, ...] 24 | 25 | atom_types = Variable(atom_types) 26 | targets = Variable(targets) 27 | 28 | if torch.cuda.is_available(): 29 | atom_types = atom_types.cuda(device_id) 30 | targets = targets.cuda(device_id) 31 | 32 | outputs = mlp(atom_types) 33 | loss = criterion(outputs, targets) 34 | return loss 35 | 36 | 37 | def eval_batch_s2cnn(mlp, s2cnn, data, batch_idxs, criterion, device_id=0): 38 | """ evaluate a batch for the s2cnn """ 39 | geometry = data["features"]["geometry"][batch_idxs, ...] 40 | atom_types = data["features"]["atom_types"][batch_idxs, ...] 41 | atom_types_one_hot = to_one_hot(atom_types, NUM_ATOM_TYPES) 42 | targets = data["targets"][batch_idxs, ...] 43 | 44 | geometry = Variable(geometry) 45 | atom_types = Variable(atom_types) 46 | atom_types_one_hot = Variable(atom_types_one_hot) 47 | targets = Variable(targets) 48 | 49 | if torch.cuda.is_available(): 50 | atom_types_one_hot = atom_types_one_hot.cuda(device_id) 51 | geometry = geometry.cuda(device_id) 52 | atom_types = atom_types.cuda(device_id) 53 | targets = targets.cuda(device_id) 54 | 55 | outputs = mlp(atom_types_one_hot) 56 | outputs += s2cnn(geometry, atom_types) 57 | 58 | loss = criterion(outputs, targets) 59 | 60 | return loss 61 | 62 | 63 | def train_baseline(mlp, data, train_batches, test_batches, num_epochs, 64 | learning_rate_mlp, device_id=0): 65 | """ train the baseline model """ 66 | optim = OPTIMIZER(mlp.parameters(), lr=learning_rate_mlp) 67 | criterion = nn.MSELoss() 68 | if torch.cuda.is_available(): 69 | criterion = criterion.cuda(device_id) 70 | for epoch in range(num_epochs): 71 | train_losses = [] 72 | print("training") 73 | for iteration, batch_idxs in enumerate(train_batches): 74 | mlp.train() 75 | optim.zero_grad() 76 | loss = eval_batch_mlp(mlp, data, batch_idxs, criterion, device_id) 77 | loss.backward() 78 | optim.step() 79 | train_losses.append(loss.item()) 80 | print("\riteration {}/{}".format( 81 | iteration+1, train_batches.num_iterations()), end="") 82 | print() 83 | test_losses = [] 84 | print("evaluating") 85 | for iteration, batch_idxs in enumerate(test_batches): 86 | mlp.eval() 87 | loss = eval_batch_mlp(mlp, data, batch_idxs, criterion) 88 | test_losses.append(loss.item()) 89 | print("\riteration {}/{}".format( 90 | iteration+1, test_batches.num_iterations()), end="") 91 | print() 92 | train_loss = np.sqrt(np.mean(train_losses)) 93 | test_loss = np.sqrt(np.mean(test_losses)) 94 | print("epoch {}/{} - avg train loss: {}, test loss: {}".format( 95 | epoch+1, num_epochs, train_loss, test_loss)) 96 | return train_loss, test_loss 97 | 98 | 99 | def train_s2cnn(mlp, s2cnn, data, train_batches, test_batches, num_epochs, 100 | init_learning_rate_s2cnn, learning_rate_decay_epochs, 101 | device_id=0): 102 | """ train the s2cnn keeping the baseline frozen """ 103 | optim = OPTIMIZER(s2cnn.parameters(), lr=init_learning_rate_s2cnn) 104 | criterion = nn.MSELoss() 105 | if torch.cuda.is_available(): 106 | criterion = criterion.cuda(device_id) 107 | for epoch in range(num_epochs): 108 | optim = exp_lr_scheduler(optim, epoch, 109 | init_lr=init_learning_rate_s2cnn, 110 | lr_decay_epoch=learning_rate_decay_epochs) 111 | train_losses = [] 112 | print("training") 113 | for iteration, batch_idxs in enumerate(train_batches): 114 | s2cnn.train() 115 | mlp.eval() 116 | optim.zero_grad() 117 | loss = eval_batch_s2cnn(mlp, s2cnn, data, batch_idxs, criterion) 118 | 
loss.backward()
119 |             optim.step()
120 |             train_losses.append(loss.item())
121 |             print("\riteration {}/{} - batch loss: {}".format(
122 |                 iteration+1, train_batches.num_iterations(),
123 |                 np.sqrt(train_losses[-1])), end="")
124 |         print()
125 |         test_losses = []
126 |         print("evaluating")
127 |         for iteration, batch_idxs in enumerate(test_batches):
128 |             s2cnn.eval()
129 |             mlp.eval()
130 |             loss = eval_batch_s2cnn(mlp, s2cnn, data, batch_idxs, criterion)
131 |             test_losses.append(loss.item())
132 |             print("\riteration {}/{} - batch loss: {}".format(
133 |                 iteration+1, test_batches.num_iterations(),
134 |                 np.sqrt(test_losses[-1])), end="")
135 |         print()
136 |         train_loss = np.sqrt(np.mean(train_losses))
137 |         test_loss = np.sqrt(np.mean(test_losses))
138 |         print("epoch {}/{} - avg train loss: {}, test loss: {}".format(
139 |             epoch+1, num_epochs, train_loss, test_loss))
140 |     return train_loss, test_loss
141 | 
142 | 
143 | def main():
144 | 
145 |     parser = argparse.ArgumentParser()
146 | 
147 |     parser.add_argument("--data_path",
148 |                         type=str,
149 |                         default="data.joblib")
150 |     parser.add_argument("--test_strat",
151 |                         type=int,
152 |                         default=0)
153 |     parser.add_argument("--device_id",
154 |                         type=int,
155 |                         default=0)
156 |     parser.add_argument("--num_epochs_s2cnn",
157 |                         type=int,
158 |                         default=30)
159 |     parser.add_argument("--num_epochs_mlp",
160 |                         type=int,
161 |                         default=30)
162 |     parser.add_argument("--batch_size_s2cnn",
163 |                         type=int,
164 |                         default=32)
165 |     parser.add_argument("--batch_size_mlp",
166 |                         type=int,
167 |                         default=32)
168 |     parser.add_argument("--init_learning_rate_s2cnn",
169 |                         type=float,
170 |                         default=1e-3)
171 |     parser.add_argument("--learning_rate_mlp",
172 |                         type=float,
173 |                         default=1e-3)
174 |     parser.add_argument("--learning_rate_decay_epochs",
175 |                         type=int,
176 |                         default=10)
177 | 
178 |     args = parser.parse_args()
179 | 
180 |     torch.cuda.set_device(args.device_id)
181 | 
182 |     print("evaluating on {}".format(args.test_strat))
183 | 
184 |     print("loading data...", end="")
185 |     data, train_idxs, test_idxs = load_data(args.data_path, args.test_strat,
186 |                                             cuda=args.device_id)
187 |     print("done!")
188 | 
189 |     mlp = BaselineRegressor()
190 |     s2cnn = S2CNNRegressor()
191 | 
192 |     if torch.cuda.is_available():
193 |         for model in [mlp, s2cnn]:
194 |             model.cuda(args.device_id)
195 | 
196 |     print("training baseline model")
197 |     print("mlp #params: {}".format(count_params(mlp)))
198 |     train_baseline(mlp, data,
199 |                    IndexBatcher(train_idxs, args.batch_size_mlp,
200 |                                 cuda=args.device_id),
201 |                    IndexBatcher(test_idxs, args.batch_size_mlp,
202 |                                 cuda=args.device_id),
203 |                    args.num_epochs_mlp, args.learning_rate_mlp, args.device_id)
204 | 
205 |     print("training residual s2cnn model")
206 |     print("s2cnn #params: {}".format(count_params(s2cnn)))
207 |     train_s2cnn(mlp, s2cnn, data,
208 |                 IndexBatcher(train_idxs, args.batch_size_s2cnn,
209 |                              cuda=args.device_id),
210 |                 IndexBatcher(test_idxs, args.batch_size_s2cnn,
211 |                              cuda=args.device_id),
212 |                 args.num_epochs_s2cnn, args.init_learning_rate_s2cnn,
213 |                 args.learning_rate_decay_epochs, args.device_id)
214 | 
215 | 
216 | if __name__ == '__main__':
217 |     main()
218 | 
-------------------------------------------------------------------------------- /examples/mnist/gendata.py: --------------------------------------------------------------------------------
1 | '''Module to generate the spherical mnist data set'''
2 | 
3 | import gzip
4 | import pickle
5 | import numpy as np
6 | import argparse
7 | import lie_learn.spaces.S2 as S2
8 | from
torchvision import datasets 9 | 10 | 11 | NORTHPOLE_EPSILON = 1e-3 12 | 13 | 14 | def rand_rotation_matrix(deflection=1.0, randnums=None): 15 | """ 16 | Creates a random rotation matrix. 17 | 18 | deflection: the magnitude of the rotation. For 0, no rotation; for 1, competely random 19 | rotation. Small deflection => small perturbation. 20 | randnums: 3 random numbers in the range [0, 1]. If `None`, they will be auto-generated. 21 | 22 | # http://blog.lostinmyterminal.com/python/2015/05/12/random-rotation-matrix.html 23 | """ 24 | 25 | if randnums is None: 26 | randnums = np.random.uniform(size=(3,)) 27 | 28 | theta, phi, z = randnums 29 | 30 | theta = theta * 2.0*deflection*np.pi # Rotation about the pole (Z). 31 | phi = phi * 2.0*np.pi # For direction of pole deflection. 32 | z = z * 2.0*deflection # For magnitude of pole deflection. 33 | 34 | # Compute a vector V used for distributing points over the sphere 35 | # via the reflection I - V Transpose(V). This formulation of V 36 | # will guarantee that if x[1] and x[2] are uniformly distributed, 37 | # the reflected points will be uniform on the sphere. Note that V 38 | # has length sqrt(2) to eliminate the 2 in the Householder matrix. 39 | 40 | r = np.sqrt(z) 41 | V = ( 42 | np.sin(phi) * r, 43 | np.cos(phi) * r, 44 | np.sqrt(2.0 - z) 45 | ) 46 | 47 | st = np.sin(theta) 48 | ct = np.cos(theta) 49 | 50 | R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) 51 | 52 | # Construct the rotation matrix ( V Transpose(V) - I ) R. 53 | 54 | M = (np.outer(V, V) - np.eye(3)).dot(R) 55 | return M 56 | 57 | 58 | def rotate_grid(rot, grid): 59 | x, y, z = grid 60 | xyz = np.array((x, y, z)) 61 | x_r, y_r, z_r = np.einsum('ij,jab->iab', rot, xyz) 62 | return x_r, y_r, z_r 63 | 64 | 65 | def get_projection_grid(b, grid_type="Driscoll-Healy"): 66 | ''' returns the spherical grid in euclidean 67 | coordinates, where the sphere's center is moved 68 | to (0, 0, 1)''' 69 | theta, phi = S2.meshgrid(b=b, grid_type=grid_type) 70 | x_ = np.sin(theta) * np.cos(phi) 71 | y_ = np.sin(theta) * np.sin(phi) 72 | z_ = np.cos(theta) 73 | return x_, y_, z_ 74 | 75 | 76 | def project_sphere_on_xy_plane(grid, projection_origin): 77 | ''' returns xy coordinates on the plane 78 | obtained from projecting each point of 79 | the spherical grid along the ray from 80 | the projection origin through the sphere ''' 81 | 82 | sx, sy, sz = projection_origin 83 | x, y, z = grid 84 | z = z.copy() + 1 85 | 86 | t = -z / (z - sz) 87 | qx = t * (x - sx) + x 88 | qy = t * (y - sy) + y 89 | 90 | xmin = 1/2 * (-1 - sx) + -1 91 | ymin = 1/2 * (-1 - sy) + -1 92 | 93 | # ensure that plane projection 94 | # ends up on southern hemisphere 95 | rx = (qx - xmin) / (2 * np.abs(xmin)) 96 | ry = (qy - ymin) / (2 * np.abs(ymin)) 97 | 98 | return rx, ry 99 | 100 | 101 | def sample_within_bounds(signal, x, y, bounds): 102 | ''' ''' 103 | xmin, xmax, ymin, ymax = bounds 104 | 105 | idxs = (xmin <= x) & (x < xmax) & (ymin <= y) & (y < ymax) 106 | 107 | if len(signal.shape) > 2: 108 | sample = np.zeros((signal.shape[0], x.shape[0], x.shape[1])) 109 | sample[:, idxs] = signal[:, x[idxs], y[idxs]] 110 | else: 111 | sample = np.zeros((x.shape[0], x.shape[1])) 112 | sample[idxs] = signal[x[idxs], y[idxs]] 113 | return sample 114 | 115 | 116 | def sample_bilinear(signal, rx, ry): 117 | ''' ''' 118 | 119 | signal_dim_x = signal.shape[1] 120 | signal_dim_y = signal.shape[2] 121 | 122 | rx *= signal_dim_x 123 | ry *= signal_dim_y 124 | 125 | # discretize sample position 126 | ix = rx.astype(int) 127 | iy = 
ry.astype(int) 128 | 129 | # obtain four sample coordinates 130 | ix0 = ix 131 | iy0 = iy 132 | ix1 = ix + 1 133 | iy1 = iy + 1 134 | 135 | bounds = (0, signal_dim_x, 0, signal_dim_y) 136 | 137 | # sample signal at each four positions 138 | signal_00 = sample_within_bounds(signal, ix0, iy0, bounds) 139 | signal_10 = sample_within_bounds(signal, ix1, iy0, bounds) 140 | signal_01 = sample_within_bounds(signal, ix0, iy1, bounds) 141 | signal_11 = sample_within_bounds(signal, ix1, iy1, bounds) 142 | 143 | # linear interpolation in x-direction 144 | fx1 = (ix1-rx) * signal_00 + (rx-ix0) * signal_10 145 | fx2 = (ix1-rx) * signal_01 + (rx-ix0) * signal_11 146 | 147 | # linear interpolation in y-direction 148 | return (iy1 - ry) * fx1 + (ry - iy0) * fx2 149 | 150 | 151 | def project_2d_on_sphere(signal, grid, projection_origin=None): 152 | ''' ''' 153 | if projection_origin is None: 154 | projection_origin = (0, 0, 2 + NORTHPOLE_EPSILON) 155 | 156 | rx, ry = project_sphere_on_xy_plane(grid, projection_origin) 157 | sample = sample_bilinear(signal, rx, ry) 158 | 159 | # ensure that only south hemisphere gets projected 160 | sample *= (grid[2] <= 1).astype(np.float64) 161 | 162 | # rescale signal to [0,1] 163 | sample_min = sample.min(axis=(1, 2)).reshape(-1, 1, 1) 164 | sample_max = sample.max(axis=(1, 2)).reshape(-1, 1, 1) 165 | 166 | sample = (sample - sample_min) / (sample_max - sample_min) 167 | sample *= 255 168 | sample = sample.astype(np.uint8) 169 | 170 | return sample 171 | 172 | 173 | def main(): 174 | ''' ''' 175 | parser = argparse.ArgumentParser() 176 | 177 | parser.add_argument("--bandwidth", 178 | help="the bandwidth of the S2 signal", 179 | type=int, 180 | default=30, 181 | required=False) 182 | parser.add_argument("--noise", 183 | help="the rotational noise applied on the sphere", 184 | type=float, 185 | default=1.0, 186 | required=False) 187 | parser.add_argument("--chunk_size", 188 | help="size of image chunk with same rotation", 189 | type=int, 190 | default=500, 191 | required=False) 192 | parser.add_argument("--mnist_data_folder", 193 | help="folder for saving the mnist data", 194 | type=str, 195 | default="MNIST_data", 196 | required=False) 197 | parser.add_argument("--output_file", 198 | help="file for saving the data output (.gz file)", 199 | type=str, 200 | default="s2_mnist.gz", 201 | required=False) 202 | parser.add_argument("--no_rotate_train", 203 | help="do not rotate train set", 204 | dest='no_rotate_train', action='store_true') 205 | parser.add_argument("--no_rotate_test", 206 | help="do not rotate test set", 207 | dest='no_rotate_test', action='store_true') 208 | 209 | args = parser.parse_args() 210 | 211 | print("getting mnist data") 212 | trainset = datasets.MNIST(root=args.mnist_data_folder, train=True, download=True) 213 | testset = datasets.MNIST(root=args.mnist_data_folder, train=False, download=True) 214 | mnist_train = {} 215 | mnist_train['images'] = trainset.train_data.numpy() 216 | mnist_train['labels'] = trainset.train_labels.numpy() 217 | mnist_test = {} 218 | mnist_test['images'] = testset.test_data.numpy() 219 | mnist_test['labels'] = testset.test_labels.numpy() 220 | 221 | grid = get_projection_grid(b=args.bandwidth) 222 | 223 | # result 224 | dataset = {} 225 | 226 | no_rotate = {"train": args.no_rotate_train, "test": args.no_rotate_test} 227 | 228 | for label, data in zip(["train", "test"], [mnist_train, mnist_test]): 229 | 230 | print("projecting {0} data set".format(label)) 231 | current = 0 232 | signals = data['images'].reshape(-1, 28, 
28).astype(np.float64) 233 | n_signals = signals.shape[0] 234 | projections = np.ndarray( 235 | (signals.shape[0], 2 * args.bandwidth, 2 * args.bandwidth), 236 | dtype=np.uint8) 237 | 238 | while current < n_signals: 239 | 240 | if not no_rotate[label]: 241 | rot = rand_rotation_matrix(deflection=args.noise) 242 | rotated_grid = rotate_grid(rot, grid) 243 | else: 244 | rotated_grid = grid 245 | 246 | idxs = np.arange(current, min(n_signals, 247 | current + args.chunk_size)) 248 | chunk = signals[idxs] 249 | projections[idxs] = project_2d_on_sphere(chunk, rotated_grid) 250 | current += args.chunk_size 251 | print("\r{0}/{1}".format(current, n_signals), end="") 252 | print("") 253 | dataset[label] = { 254 | 'images': projections, 255 | 'labels': data['labels'] 256 | } 257 | print("writing pickle") 258 | with gzip.open(args.output_file, 'wb') as f: 259 | pickle.dump(dataset, f) 260 | print("done") 261 | 262 | 263 | if __name__ == '__main__': 264 | main() 265 | -------------------------------------------------------------------------------- /s2cnn/soft/s2_fft.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R,C,E1101,E1102 2 | from functools import lru_cache 3 | import torch 4 | import torch.cuda 5 | from string import Template 6 | from s2cnn.utils.decorator import cached_dirpklgz 7 | import torch.fft 8 | 9 | 10 | # inspired by https://gist.github.com/szagoruyko/89f83b6f5f4833d3c8adf81ee49f22a8 11 | 12 | 13 | def s2_fft(x, for_grad=False, b_out=None): 14 | ''' 15 | :param x: [..., beta, alpha, complex] 16 | :return: [l * m, ..., complex] 17 | ''' 18 | assert x.size(-1) == 2 19 | b_in = x.size(-2) // 2 20 | assert x.size(-2) == 2 * b_in 21 | assert x.size(-3) == 2 * b_in 22 | if b_out is None: 23 | b_out = b_in 24 | assert b_out <= b_in 25 | batch_size = x.size()[:-3] 26 | 27 | x = x.view(-1, 2 * b_in, 2 * b_in, 2) # [batch, beta, alpha, complex] 28 | 29 | ''' 30 | :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2) 31 | :return: [l * m, batch, complex] (b_out**2, nbatch, 2) 32 | ''' 33 | nspec = b_out ** 2 34 | nbatch = x.size(0) 35 | 36 | wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device) 37 | wigner = wigner.view(2 * b_in, -1) # [beta, l * m] (2 * b_in, nspec) 38 | 39 | x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x))) # [batch, beta, m, complex] 40 | 41 | output = x.new_empty((nspec, nbatch, 2)) 42 | if x.is_cuda and x.dtype == torch.float32: 43 | import s2cnn.utils.cuda as cuda_utils 44 | cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=x.device.index) 45 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 46 | cuda_kernel(block=(1024, 1, 1), 47 | grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1), 48 | args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()], 49 | stream=stream) 50 | # [l * m, batch, complex] 51 | else: 52 | for l in range(b_out): 53 | s = slice(l ** 2, l ** 2 + 2 * l + 1) 54 | xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1] 55 | output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx)) 56 | 57 | output = output.view(-1, *batch_size, 2) # [l * m, ..., complex] (nspec, ..., 2) 58 | return output 59 | 60 | 61 | def s2_ifft(x, for_grad=False, b_out=None): 62 | ''' 63 | :param x: [l * m, ..., complex] 64 | ''' 65 | assert x.size(-1) == 2 66 | nspec = x.size(0) 67 | b_in = round(nspec ** 0.5) 68 | assert nspec == b_in ** 2 69 | if b_out 
is None: 70 | b_out = b_in 71 | assert b_out >= b_in 72 | batch_size = x.size()[1:-1] 73 | 74 | x = x.view(nspec, -1, 2) # [l * m, batch, complex] (nspec, nbatch, 2) 75 | 76 | ''' 77 | :param x: [l * m, batch, complex] (b_in**2, nbatch, 2) 78 | :return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2) 79 | ''' 80 | nbatch = x.size(1) 81 | 82 | wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device) 83 | wigner = wigner.view(2 * b_out, -1) # [beta, l * m] (2 * b_out, nspec) 84 | 85 | if x.is_cuda and x.dtype == torch.float32: 86 | import s2cnn.utils.cuda as cuda_utils 87 | cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index) 88 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 89 | output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2)) 90 | cuda_kernel(block=(1024, 1, 1), 91 | grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1), 92 | args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()], 93 | stream=stream) 94 | # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2) 95 | else: 96 | output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2)) 97 | for l in range(b_in): 98 | s = slice(l ** 2, l ** 2 + 2 * l + 1) 99 | out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s])) 100 | output[:, :, :l + 1] += out[:, :, -l - 1:] 101 | if l > 0: 102 | output[:, :, -l:] += out[:, :, :l] 103 | 104 | output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * output.size(-2) # [batch, beta, alpha, complex] 105 | output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2) 106 | return output 107 | 108 | 109 | @lru_cache(maxsize=32) 110 | def _setup_wigner(b, nl, weighted, device): 111 | dss = _setup_s2_fft(b, nl, weighted) 112 | dss = torch.tensor(dss, dtype=torch.float32, device=device) # [beta, l * m] # pylint: disable=E1102 113 | return dss.contiguous() 114 | 115 | 116 | @cached_dirpklgz("cache/setup_s2_fft") 117 | def _setup_s2_fft(b, nl, weighted): 118 | from lie_learn.representations.SO3.wigner_d import wigner_d_matrix 119 | import lie_learn.spaces.S3 as S3 120 | import numpy as np 121 | import logging 122 | 123 | betas = (np.arange(2 * b) + 0.5) / (2 * b) * np.pi 124 | w = S3.quadrature_weights(b) * 2 * b 125 | assert len(w) == len(betas) 126 | 127 | logging.getLogger("trainer").info("Compute Wigner (only columns): b=%d nbeta=%d nl=%d nspec=%d", b, len(betas), nl, 128 | nl ** 2) 129 | 130 | dss = [] 131 | for b, beta in enumerate(betas): 132 | ds = [] 133 | for l in range(nl): 134 | d = wigner_d_matrix(l, beta, 135 | field='complex', normalization='quantum', order='centered', condon_shortley='cs') 136 | d = d[:, l] # d[m=:, n=0] 137 | 138 | if weighted: 139 | d *= w[b] 140 | else: 141 | d *= 2 * l + 1 142 | 143 | ds.append(d) # [m] 144 | dss.append(np.concatenate(ds)) # [l * m] 145 | 146 | dss = np.stack(dss) # [beta, l * m] 147 | return dss 148 | 149 | 150 | @lru_cache(maxsize=32) 151 | def _setup_s2fft_cuda_kernel(b, nspec, nbatch, device=0): 152 | kernel = Template(''' 153 | #define COMPUTE_LM(s) \ 154 | int l = sqrtf(s); \ 155 | int m = (s - l * l) - l; 156 | 157 | #define MOD(i, n) (((i) + (n)) % (n)) 158 | 159 | extern "C" 160 | __global__ void main_(const float* in, const float* wig, float* out) { 161 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ${nspec} * ${nbatch}; index += blockDim.x * gridDim.x) { 162 | int i = index % ${nbatch}; // batch index 163 | int s = index / ${nbatch}; // spectral index 164 | 165 | // compute s -> (l,m) 166 
| COMPUTE_LM(s) 167 | 168 | float out_re = 0.0; 169 | float out_im = 0.0; 170 | for (int beta = 0; beta < 2 * ${b}; ++beta) { 171 | float in_re = in[((i * 2 * ${b} + beta) * 2 * ${b} + MOD(m, 2 * ${b})) * 2 + 0]; 172 | float in_im = in[((i * 2 * ${b} + beta) * 2 * ${b} + MOD(m, 2 * ${b})) * 2 + 1]; 173 | float w = wig[beta * ${nspec} + s]; 174 | 175 | out_re += w * in_re; 176 | out_im += w * in_im; 177 | } 178 | out[index * 2 + 0] = out_re; 179 | out[index * 2 + 1] = out_im; 180 | } 181 | } 182 | ''').substitute({'b': b, 'nbatch': nbatch, 'nspec': nspec}) 183 | 184 | import s2cnn.utils.cuda as cuda_utils 185 | return cuda_utils.compile_kernel(kernel, 's2fft.cu', 'main_') 186 | 187 | 188 | @lru_cache(maxsize=32) 189 | def _setup_s2ifft_cuda_kernel(b, nl, nbatch, device=0): 190 | kernel = Template(''' 191 | extern "C" 192 | __global__ void main_(const float* in, const float* wig, float* out) { 193 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ${nbatch} * 2 * ${b} * 2 * ${b}; index += blockDim.x * gridDim.x) { 194 | int i = index / (2 * ${b} * 2 * ${b}); // batch index 195 | int beta = (index / (2 * ${b})) % (2 * ${b}); 196 | int m = index % (2 * ${b}); 197 | 198 | // from 0,1,2, 3, 4 or 0,1,2, 3, 4, 5 199 | // to 0,1,2,-2,-1 or 0,1,2,-3,-2,-1 200 | int mm = m <= (2 * ${b} - 1) / 2 ? m : m - 2 * ${b}; 201 | 202 | float out_re = 0.0; 203 | float out_im = 0.0; 204 | 205 | for (int l = abs(mm); l < ${nl}; ++l) { 206 | int s = l * l + (l + mm); 207 | 208 | float in_re = in[(s * ${nbatch} + i) * 2 + 0]; 209 | float in_im = in[(s * ${nbatch} + i) * 2 + 1]; 210 | float w = wig[beta * ${nspec} + s]; 211 | 212 | out_re += in_re * w; 213 | out_im += in_im * w; 214 | } 215 | 216 | out[index * 2 + 0] = out_re; 217 | out[index * 2 + 1] = out_im; 218 | } 219 | } 220 | ''').substitute({'b': b, 'nbatch': nbatch, 'nl': nl, 'nspec': nl ** 2}) 221 | 222 | import s2cnn.utils.cuda as cuda_utils 223 | return cuda_utils.compile_kernel(kernel, 's2ifft.cu', 'main_') 224 | 225 | 226 | class S2_fft_real(torch.autograd.Function): 227 | @staticmethod 228 | def forward(ctx, x, b_out=None): # pylint: disable=W 229 | from s2cnn.utils.complex import as_complex 230 | ctx.b_out = b_out 231 | ctx.b_in = x.size(-1) // 2 232 | return s2_fft(as_complex(x), b_out=ctx.b_out) 233 | 234 | @staticmethod 235 | def backward(ctx, grad_output): # pylint: disable=W 236 | return s2_ifft(grad_output, for_grad=True, b_out=ctx.b_in)[..., 0], None 237 | 238 | 239 | class S2_ifft_real(torch.autograd.Function): 240 | @staticmethod 241 | def forward(ctx, x, b_out=None): # pylint: disable=W 242 | nspec = x.size(0) 243 | ctx.b_out = b_out 244 | ctx.b_in = round(nspec ** 0.5) 245 | return s2_ifft(x, b_out=ctx.b_out)[..., 0] 246 | 247 | @staticmethod 248 | def backward(ctx, grad_output): # pylint: disable=W 249 | from s2cnn.utils.complex import as_complex 250 | return s2_fft(as_complex(grad_output), for_grad=True, b_out=ctx.b_in), None 251 | 252 | 253 | def test_s2fft_cuda_cpu(): 254 | x = torch.rand(1, 2, 12, 12, 2) # [..., beta, alpha, complex] 255 | z1 = s2_fft(x, b_out=5) 256 | z2 = s2_fft(x.cuda(), b_out=5).cpu() 257 | q = (z1 - z2).abs().max().item() / z1.std().item() 258 | print(q) 259 | assert q < 1e-4 260 | 261 | 262 | def test_s2ifft_cuda_cpu(): 263 | x = torch.rand(12 ** 2, 10, 2) # [l * m, ..., complex] 264 | z1 = s2_ifft(x, b_out=13) 265 | z2 = s2_ifft(x.cuda(), b_out=13).cpu() 266 | q = (z1 - z2).abs().max().item() / z1.std().item() 267 | print(q) 268 | assert q < 1e-4 269 | 270 | 271 | if __name__ == "__main__": 272 | 
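# The checks below compare the CUDA kernels against the reference CPU (einsum)
# paths on random inputs; q is the maximum absolute deviation normalised by the
# standard deviation of the CPU result, so the 1e-4 tolerance is scale-free.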
test_s2fft_cuda_cpu() 273 | test_s2ifft_cuda_cpu() 274 | -------------------------------------------------------------------------------- /s2cnn/so3_mm.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R,C,E1101 2 | import math 3 | from functools import lru_cache 4 | import torch 5 | import torch.cuda 6 | 7 | 8 | def so3_mm(x, y): 9 | ''' 10 | :param x: [l * m * n, batch, feature_in, complex] 11 | :param y: [l * m * n, feature_in, feature_out, complex] 12 | :return: [l * m * n, batch, feature_out, complex] 13 | ''' 14 | from s2cnn.utils.complex import complex_mm 15 | import math 16 | 17 | assert y.size(3) == 2 18 | assert x.size(3) == 2 19 | nbatch = x.size(1) 20 | nfeature_in = x.size(2) 21 | nfeature_out = y.size(2) 22 | assert y.size(1) == nfeature_in 23 | nspec = x.size(0) 24 | assert y.size(0) == nspec 25 | nl = math.ceil((3 / 4 * nspec) ** (1 / 3)) 26 | assert nspec == nl * (4 * nl ** 2 - 1) // 3 27 | 28 | if x.is_cuda: 29 | return _cuda_SO3_mm.apply(x, y) 30 | 31 | Fz_list = [] 32 | begin = 0 33 | for l in range(nl): 34 | L = 2 * l + 1 35 | size = L ** 2 36 | 37 | Fx = x[begin:begin + size] # [m * n, batch, feature_in, complex] 38 | Fy = y[begin:begin + size] # [m * n, feature_in, feature_out, complex] 39 | 40 | Fx = Fx.view(L, L, nbatch, nfeature_in, 2) # [m, n, batch, feature_in, complex] 41 | Fx = Fx.transpose(0, 1) # [n, m, batch, feature_in, complex] 42 | Fx = Fx.transpose(0, 2) # [batch, m, n, feature_in, complex] 43 | Fx = Fx.transpose(2, 3) # [batch, m, feature_in, n, complex] 44 | Fx = Fx.contiguous() 45 | Fx = Fx.view(nbatch * L, nfeature_in * L, 2) # [batch * m, feature_in * n, complex] 46 | 47 | Fy = Fy.view(L, L, nfeature_in, nfeature_out, 2) # [m, n, feature_in, feature_out, complex] 48 | Fy = Fy.transpose(0, 2) # [feature_in, n, m, feature_out, complex] 49 | Fy = Fy.contiguous() 50 | Fy = Fy.view(nfeature_in * L, L * nfeature_out, 2) # [feature_in * n, m * feature_out, complex] 51 | 52 | Fz = complex_mm(Fx, Fy, conj_y=True) # [batch * m_x, m_y * feature_out, complex] m_x -> m, m_y -> n 53 | Fz = Fz.view(nbatch, L * L, nfeature_out, 2) # [batch, m * n, feature_out, complex] 54 | Fz = Fz.transpose(0, 1) # [m * n, batch, feature_out, complex] 55 | 56 | Fz_list.append(Fz) 57 | 58 | begin += size 59 | 60 | z = torch.cat(Fz_list, 0) # [l * m * n, batch, feature_out, complex] 61 | return z 62 | 63 | 64 | class _cuda_SO3_mm(torch.autograd.Function): 65 | @staticmethod 66 | def forward(ctx, x, y): # pylint: disable=W 67 | ''' 68 | :param x: [l * m * n, batch, feature_in, complex] 69 | :param y: [l * m * n, feature_in, feature_out, complex] 70 | :return: [l * m * n, batch, feature_out, complex] 71 | ''' 72 | assert x.is_cuda and x.dtype == torch.float32 73 | assert y.is_cuda and y.dtype == torch.float32 74 | assert y.size(3) == 2 75 | assert x.size(3) == 2 76 | nbatch = x.size(1) 77 | nfeature_in = x.size(2) 78 | nfeature_out = y.size(2) 79 | assert y.size(1) == nfeature_in 80 | nspec = x.size(0) 81 | assert y.size(0) == nspec 82 | nl = round((3 / 4 * nspec) ** (1 / 3)) 83 | assert nspec == nl * (4 * nl ** 2 - 1) // 3 84 | 85 | ctx.save_for_backward(x, y) 86 | device = torch.cuda.current_device() 87 | cuda_kernel = _setup_so3mm_cuda_kernel(nl=nl, ni=nbatch, nj=nfeature_out, nk=nfeature_in, conj_y=True, 88 | trans_y_spec=True, device=device) 89 | 90 | output = x.new_empty((nspec, nbatch, nfeature_out, 2)) 91 | cuda_kernel(x, y, output) # [l * m * n, batch, feature_out, complex] 92 | 93 | return output 94 | 95 | 
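# Backward of the per-degree contraction computed in forward,
#   z[l m n, i, j] = sum_{p, k} x[l m p, i, k] * conj(y[l n p, k, j]),
# read backwards through each index:
#   grad_x[l m p, i, k] = sum_{n, j} grad_z[l m n, i, j] * y[l n p, k, j]
#   grad_y[l n p, k, j] = sum_{m, i} conj(grad_z[l m n, i, j]) * x[l m p, i, k]
# The two kernel configurations below encode exactly these permutations
# (trans_*) and conjugations (conj_x).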
@staticmethod 96 | def backward(ctx, gradz): # pylint: disable=W 97 | x, y = ctx.saved_tensors 98 | nspec = x.size(0) 99 | nbatch = x.size(1) 100 | nfeature_in = x.size(2) 101 | nfeature_out = y.size(2) 102 | 103 | nl = round((3 / 4 * nspec) ** (1 / 3)) 104 | assert nspec == nl * (4 * nl ** 2 - 1) // 3 105 | 106 | gradx = grady = None 107 | 108 | device = torch.cuda.current_device() 109 | if ctx.needs_input_grad[0]: 110 | gradx_cuda_kernel = _setup_so3mm_cuda_kernel(nl=nl, ni=nbatch, nj=nfeature_in, nk=nfeature_out, 111 | trans_y_feature=True, device=device) 112 | gradx = gradz.new_empty((nspec, nbatch, nfeature_in, 2)) 113 | gradx_cuda_kernel(gradz, y, gradx) 114 | 115 | if ctx.needs_input_grad[1]: 116 | grady_cuda_kernel = _setup_so3mm_cuda_kernel(nl=nl, ni=nfeature_out, nj=nfeature_in, nk=nbatch, 117 | trans_out_feature=True, conj_x=True, trans_x_spec=True, 118 | trans_x_feature=True, device=device) 119 | grady = gradz.new_empty((nspec, nfeature_in, nfeature_out, 2)) 120 | grady_cuda_kernel(gradz, x, grady) 121 | 122 | return gradx, grady 123 | 124 | 125 | @lru_cache(maxsize=32) 126 | def _setup_so3mm_cuda_kernel(nl, ni, nj, nk, 127 | conj_x=False, conj_y=False, 128 | trans_x_spec=False, trans_x_feature=False, 129 | trans_y_spec=False, trans_y_feature=False, 130 | trans_out_feature=False, device=0): 131 | ''' 132 | return a function that computes 133 | out[l*m*n, i, j] = sum_k sum_p x[l*m*p, i, k] y[l*p*n, k, j] 134 | where out, x, y are complex valued 135 | 136 | if conj_x is set to True, x is conjugated 137 | if conj_y is set to True, y is conjugated 138 | if trans_x_spec is set to True, m and p are permuted in x[...] 139 | if trans_y_spec is set to True, p and n are permuted in y[...] 140 | if trans_x_feature is set to True, i and k are permuted in x[...] 141 | if trans_y_feature is set to True, k and j are permuted in y[...] 142 | if trans_out_feature is set to True, i and j are permuted in out[...]
143 | ''' 144 | 145 | kernel = ''' 146 | #define NI {} 147 | #define NJ {} 148 | #define NK {} 149 | '''.format(ni, nj, nk) 150 | 151 | if not trans_x_spec and not trans_x_feature: 152 | kernel += '#define INDEX_X (((L0 + m * L + p) * NI + i) * NK + k)\n' 153 | if not trans_x_spec and trans_x_feature: 154 | kernel += '#define INDEX_X (((L0 + m * L + p) * NK + k) * NI + i)\n' 155 | if trans_x_spec and not trans_x_feature: 156 | kernel += '#define INDEX_X (((L0 + p * L + m) * NI + i) * NK + k)\n' 157 | if trans_x_spec and trans_x_feature: 158 | kernel += '#define INDEX_X (((L0 + p * L + m) * NK + k) * NI + i)\n' 159 | 160 | if not trans_y_spec and not trans_y_feature: 161 | kernel += '#define INDEX_Y (((L0 + p * L + n) * NK + k) * NJ + j)\n' 162 | if not trans_y_spec and trans_y_feature: 163 | kernel += '#define INDEX_Y (((L0 + p * L + n) * NJ + j) * NK + k)\n' 164 | if trans_y_spec and not trans_y_feature: 165 | kernel += '#define INDEX_Y (((L0 + n * L + p) * NK + k) * NJ + j)\n' 166 | if trans_y_spec and trans_y_feature: 167 | kernel += '#define INDEX_Y (((L0 + n * L + p) * NJ + j) * NK + k)\n' 168 | 169 | if not trans_out_feature: 170 | kernel += '#define INDEX_OUT (((L0 + m * L + n) * NI + i) * NJ + j)\n' 171 | if trans_out_feature: 172 | kernel += '#define INDEX_OUT (((L0 + m * L + n) * NJ + j) * NI + i)\n' 173 | 174 | kernel += ''' 175 | #define CONJ_X {} 176 | #define CONJ_Y {} 177 | '''.format("x_im = -x_im;" if conj_x else ";", "y_im = -y_im;" if conj_y else ";") 178 | 179 | kernel += ''' 180 | #define CEIL_DIV(x, y) (((x) + (y) - 1) / (y)) 181 | 182 | extern "C" 183 | __global__ void main_(const float* in_x, const float* in_y, float* out) 184 | { 185 | // start of thread independent code 186 | int l = blockIdx.z; 187 | int L = 2 * l + 1; 188 | int L0 = (4 * l*l - 1) * l / 3; 189 | 190 | if (blockIdx.y * 32 >= L * NI || blockIdx.x * 32 >= L * NJ) { 191 | return; 192 | } 193 | 194 | int ntile = CEIL_DIV(L * NK, 32); 195 | // end of thread independent code 196 | 197 | int mi = blockIdx.y * 32 + threadIdx.y; 198 | int m = mi / NI; 199 | int i = mi % NI; 200 | int nj = blockIdx.x * 32 + threadIdx.x; 201 | int n = nj / NJ; 202 | int j = nj % NJ; 203 | 204 | float sum_re = 0.0; 205 | float sum_im = 0.0; 206 | 207 | for (int tile = 0; tile < ntile; ++tile) { 208 | __shared__ float tileX[2][32][32]; 209 | __shared__ float tileY[2][32][32]; 210 | 211 | int pk = tile * 32 + threadIdx.x; 212 | int p = pk / NK; 213 | int k = pk % NK; 214 | int index = INDEX_X * 2; 215 | tileX[0][threadIdx.y][threadIdx.x] = m < L && p < L ? in_x[index + 0] : 0.0; 216 | tileX[1][threadIdx.y][threadIdx.x] = m < L && p < L ? in_x[index + 1] : 0.0; 217 | 218 | pk = tile * 32 + threadIdx.y; 219 | p = pk / NK; 220 | k = pk % NK; 221 | index = INDEX_Y * 2; 222 | tileY[0][threadIdx.y][threadIdx.x] = p < L && n < L ? in_y[index + 0] : 0.0; 223 | tileY[1][threadIdx.y][threadIdx.x] = p < L && n < L ? 
in_y[index + 1] : 0.0; 224 | 225 | __syncthreads(); 226 | 227 | for (int any = 0; any < 32; ++any) { 228 | float x_re = tileX[0][threadIdx.y][any]; 229 | float x_im = tileX[1][threadIdx.y][any]; 230 | float y_re = tileY[0][any][threadIdx.x]; 231 | float y_im = tileY[1][any][threadIdx.x]; 232 | 233 | CONJ_X 234 | CONJ_Y 235 | 236 | sum_re += x_re * y_re - x_im * y_im; 237 | sum_im += x_re * y_im + x_im * y_re; 238 | } 239 | 240 | __syncthreads(); 241 | } 242 | 243 | if (m < L && n < L) { 244 | int index = INDEX_OUT * 2; 245 | out[index + 0] = sum_re; 246 | out[index + 1] = sum_im; 247 | } 248 | } 249 | ''' 250 | import s2cnn.utils.cuda as cuda_utils 251 | kernel = cuda_utils.compile_kernel(kernel, 'so3_mm.cu', 'main_') 252 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 253 | 254 | def fun(x, y, output): 255 | assert output.is_contiguous() 256 | kernel(block=(32, 32, 1), 257 | grid=(math.ceil((2 * nl - 1) * nj / 32), math.ceil((2 * nl - 1) * ni / 32), nl), 258 | args=[x.contiguous().data_ptr(), y.contiguous().data_ptr(), output.data_ptr()], 259 | stream=stream) 260 | 261 | return fun 262 | 263 | 264 | def test_compare_cuda_cpu(): 265 | x = torch.rand(1+9+25+49, 2, 3, 2) # [l * m * n, batch, feature_in, complex] 266 | y = torch.rand(1+9+25+49, 3, 5, 2) # [l * m * n, feature_in, feature_out, complex] 267 | z1 = so3_mm(x, y) 268 | z2 = so3_mm(x.cuda(), y.cuda()).cpu() 269 | q = (z1 - z2).abs().max().item() / z1.std().item() 270 | print(q) 271 | assert q < 1e-4 272 | 273 | 274 | if __name__ == "__main__": 275 | test_compare_cuda_cpu() 276 | -------------------------------------------------------------------------------- /examples/shrec17/dataset.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,R,C 2 | import csv 3 | import glob 4 | import os 5 | import re 6 | import numpy as np 7 | import torch 8 | import torch.utils.data 9 | import trimesh 10 | import logging 11 | 12 | logging.getLogger('pyembree').disabled = True 13 | 14 | 15 | def rotmat(a, b, c, hom_coord=False): # apply to mesh using mesh.apply_transform(rotmat(a,b,c, True)) 16 | """ 17 | Create a rotation matrix with an optional fourth homogeneous coordinate 18 | 19 | :param a, b, c: ZYZ-Euler angles 20 | """ 21 | def z(a): 22 | return np.array([[np.cos(a), np.sin(a), 0, 0], 23 | [-np.sin(a), np.cos(a), 0, 0], 24 | [0, 0, 1, 0], 25 | [0, 0, 0, 1]]) 26 | 27 | def y(a): 28 | return np.array([[np.cos(a), 0, np.sin(a), 0], 29 | [0, 1, 0, 0], 30 | [-np.sin(a), 0, np.cos(a), 0], 31 | [0, 0, 0, 1]]) 32 | 33 | r = z(a).dot(y(b)).dot(z(c)) # pylint: disable=E1101 34 | if hom_coord: 35 | return r 36 | else: 37 | return r[:3, :3] 38 | 39 | 40 | def make_sgrid(b, alpha, beta, gamma): 41 | from lie_learn.spaces import S2 42 | 43 | theta, phi = S2.meshgrid(b=b, grid_type='SOFT') 44 | sgrid = S2.change_coordinates(np.c_[theta[..., None], phi[..., None]], p_from='S', p_to='C') 45 | sgrid = sgrid.reshape((-1, 3)) 46 | 47 | R = rotmat(alpha, beta, gamma, hom_coord=False) 48 | sgrid = np.einsum('ij,nj->ni', R, sgrid) 49 | 50 | return sgrid 51 | 52 | 53 | def render_model(mesh, sgrid): 54 | 55 | # Cast rays 56 | # triangle_indices = mesh.ray.intersects_first(ray_origins=sgrid, ray_directions=-sgrid) 57 | index_tri, index_ray, loc = mesh.ray.intersects_id( 58 | ray_origins=sgrid, ray_directions=-sgrid, multiple_hits=False, return_locations=True) 59 | loc = loc.reshape((-1, 3)) # fix bug if loc is empty 60 | 61 | # Each ray is in 1-to-1 correspondence with a grid point. 
Find the position of these points 62 | grid_hits = sgrid[index_ray] 63 | grid_hits_normalized = grid_hits / np.linalg.norm(grid_hits, axis=1, keepdims=True) 64 | 65 | # Compute the distance from the grid points to the intersection points 66 | dist = np.linalg.norm(grid_hits - loc, axis=-1) 67 | 68 | # For each intersection, look up the normal of the triangle that was hit 69 | normals = mesh.face_normals[index_tri] 70 | normalized_normals = normals / np.linalg.norm(normals, axis=1, keepdims=True) 71 | 72 | # Construct spherical images 73 | dist_im = np.ones(sgrid.shape[0]) 74 | dist_im[index_ray] = dist 75 | # dist_im = dist_im.reshape(theta.shape) 76 | 77 | # shaded_im = np.zeros(sgrid.shape[0]) 78 | # shaded_im[index_ray] = normals.dot(light_dir) 79 | # shaded_im = shaded_im.reshape(theta.shape) + 0.4 80 | 81 | n_dot_ray_im = np.zeros(sgrid.shape[0]) 82 | # n_dot_ray_im[index_ray] = np.abs(np.einsum("ij,ij->i", normals, grid_hits_normalized)) 83 | n_dot_ray_im[index_ray] = np.einsum("ij,ij->i", normalized_normals, grid_hits_normalized) 84 | 85 | nx, ny, nz = normalized_normals[:, 0], normalized_normals[:, 1], normalized_normals[:, 2] 86 | gx, gy, gz = grid_hits_normalized[:, 0], grid_hits_normalized[:, 1], grid_hits_normalized[:, 2] 87 | wedge_norm = np.sqrt((nx * gy - ny * gx) ** 2 + (nx * gz - nz * gx) ** 2 + (ny * gz - nz * gy) ** 2) 88 | n_wedge_ray_im = np.zeros(sgrid.shape[0]) 89 | n_wedge_ray_im[index_ray] = wedge_norm 90 | 91 | # Combine channels to construct final image 92 | # im = dist_im.reshape((1,) + dist_im.shape) 93 | im = np.stack((dist_im, n_dot_ray_im, n_wedge_ray_im), axis=0) 94 | 95 | return im 96 | 97 | 98 | def rnd_rot(): 99 | a = np.random.rand() * 2 * np.pi 100 | z = np.random.rand() * 2 - 1 101 | c = np.random.rand() * 2 * np.pi 102 | rot = rotmat(a, np.arccos(z), c, True) 103 | return rot 104 | 105 | 106 | class ToMesh: 107 | def __init__(self, random_rotations=False, random_translation=0): 108 | self.rot = random_rotations 109 | self.tr = random_translation 110 | 111 | def __call__(self, path): 112 | mesh = trimesh.load_mesh(path) 113 | mesh.remove_degenerate_faces() 114 | mesh.fix_normals() 115 | mesh.fill_holes() 116 | mesh.remove_duplicate_faces() 117 | mesh.remove_infinite_values() 118 | mesh.remove_unreferenced_vertices() 119 | 120 | mesh.apply_translation(-mesh.centroid) 121 | 122 | r = np.max(np.linalg.norm(mesh.vertices, axis=-1)) 123 | mesh.apply_scale(1 / r) 124 | 125 | if self.tr > 0: 126 | tr = np.random.rand() * self.tr 127 | rot = rnd_rot() 128 | mesh.apply_transform(rot) 129 | mesh.apply_translation([tr, 0, 0]) 130 | 131 | if not self.rot: 132 | mesh.apply_transform(rot.T) 133 | 134 | if self.rot: 135 | mesh.apply_transform(rnd_rot()) 136 | 137 | r = np.max(np.linalg.norm(mesh.vertices, axis=-1)) 138 | mesh.apply_scale(0.99 / r) 139 | 140 | return mesh 141 | 142 | def __repr__(self): 143 | return self.__class__.__name__ + '(rotation={0}, translation={1})'.format(self.rot, self.tr) 144 | 145 | 146 | class ProjectOnSphere: 147 | def __init__(self, bandwidth): 148 | self.bandwidth = bandwidth 149 | self.sgrid = make_sgrid(bandwidth, alpha=0, beta=0, gamma=0) 150 | 151 | def __call__(self, mesh): 152 | im = render_model(mesh, self.sgrid) 153 | im = im.reshape(3, 2 * self.bandwidth, 2 * self.bandwidth) 154 | 155 | from scipy.spatial import QhullError # scipy.spatial.qhull is private and was removed in recent SciPy 156 | try: 157 | convex_hull = mesh.convex_hull 158 | except QhullError: 159 | convex_hull = mesh 160 | 161 | hull_im = render_model(convex_hull, self.sgrid) 162 | hull_im = 
hull_im.reshape(3, 2 * self.bandwidth, 2 * self.bandwidth) 163 | 164 | im = np.concatenate([im, hull_im], axis=0) 165 | assert len(im) == 6 166 | 167 | im[0] -= 0.75 168 | im[0] /= 0.26 169 | im[1] -= 0.59 170 | im[1] /= 0.50 171 | im[2] -= 0.54 172 | im[2] /= 0.29 173 | im[3] -= 0.52 174 | im[3] /= 0.19 175 | im[4] -= 0.80 176 | im[4] /= 0.18 177 | im[5] -= 0.51 178 | im[5] /= 0.25 179 | 180 | im = im.astype(np.float32) # pylint: disable=E1101 181 | 182 | return im 183 | 184 | def __repr__(self): 185 | return self.__class__.__name__ + '(bandwidth={0})'.format(self.bandwidth) 186 | 187 | 188 | class CacheNPY: 189 | def __init__(self, prefix, repeat, transform, pick_randomly=True): 190 | self.transform = transform 191 | self.prefix = prefix 192 | self.repeat = repeat 193 | self.pick_randomly = pick_randomly 194 | 195 | def check_trans(self, file_path): 196 | print("transform {}...".format(file_path)) 197 | try: 198 | return self.transform(file_path) 199 | except Exception: 200 | print("Exception during transform of {}".format(file_path)) 201 | raise 202 | 203 | def __call__(self, file_path): 204 | head, tail = os.path.split(file_path) 205 | root, _ = os.path.splitext(tail) 206 | npy_path = os.path.join(head, self.prefix + root + '_{0}.npy') 207 | 208 | exists = [os.path.exists(npy_path.format(i)) for i in range(self.repeat)] 209 | 210 | if self.pick_randomly and all(exists): 211 | i = np.random.randint(self.repeat) 212 | try: return np.load(npy_path.format(i)) 213 | except OSError: exists[i] = False 214 | 215 | if self.pick_randomly: 216 | img = self.check_trans(file_path) 217 | np.save(npy_path.format(exists.index(False)), img) 218 | 219 | return img 220 | 221 | output = [] 222 | for i in range(self.repeat): 223 | try: 224 | img = np.load(npy_path.format(i)) 225 | except (OSError, FileNotFoundError): 226 | img = self.check_trans(file_path) 227 | np.save(npy_path.format(i), img) 228 | output.append(img) 229 | 230 | return output 231 | 232 | def __repr__(self): 233 | return self.__class__.__name__ + '(prefix={0}, transform={1})'.format(self.prefix, self.transform) 234 | 235 | 236 | class Shrec17(torch.utils.data.Dataset): 237 | ''' 238 | Download SHREC17 and provide the content of its obj files (fixed so that they parse) 239 | ''' 240 | 241 | url_data = 'http://3dvision.princeton.edu/ms/shrec17-data/{}.zip' 242 | url_label = 'http://3dvision.princeton.edu/ms/shrec17-data/{}.csv' 243 | 244 | def __init__(self, root, dataset, perturbed=True, download=False, transform=None, target_transform=None): 245 | self.root = os.path.expanduser(root) 246 | 247 | if dataset not in ["train", "test", "val"]: 248 | raise ValueError("Invalid dataset") 249 | 250 | self.dir = os.path.join(self.root, dataset + ("_perturbed" if perturbed else "")) 251 | self.transform = transform 252 | self.target_transform = target_transform 253 | 254 | if download: 255 | self.download(dataset, perturbed) 256 | 257 | if not self._check_exists(): 258 | raise RuntimeError('Dataset not found.'
+ 259 | ' You can use download=True to download it') 260 | 261 | self.files = sorted(glob.glob(os.path.join(self.dir, '*.obj'))) 262 | if dataset != "test": 263 | with open(os.path.join(self.root, dataset + ".csv"), 'rt') as f: 264 | reader = csv.reader(f) 265 | self.labels = {} 266 | for row in [x for x in reader][1:]: 267 | self.labels[row[0]] = (row[1], row[2]) 268 | else: 269 | self.labels = None 270 | 271 | def __getitem__(self, index): 272 | img = f = self.files[index] 273 | 274 | if self.transform is not None: 275 | img = self.transform(img) 276 | 277 | if self.labels is not None: 278 | i = os.path.splitext(os.path.basename(f))[0] 279 | target = self.labels[i] 280 | 281 | if self.target_transform is not None: 282 | target = self.target_transform(target) 283 | 284 | return img, target 285 | else: 286 | return img 287 | 288 | def __len__(self): 289 | return len(self.files) 290 | 291 | def _check_exists(self): 292 | files = glob.glob(os.path.join(self.dir, "*.obj")) 293 | 294 | return len(files) > 0 295 | 296 | def _download(self, url): 297 | import requests 298 | 299 | filename = url.split('/')[-1] 300 | file_path = os.path.join(self.root, filename) 301 | 302 | if os.path.exists(file_path): 303 | return file_path 304 | 305 | print('Downloading ' + url) 306 | 307 | r = requests.get(url, stream=True) 308 | with open(file_path, 'wb') as f: 309 | for chunk in r.iter_content(chunk_size=16 * 1024 ** 2): 310 | if chunk: # filter out keep-alive new chunks 311 | f.write(chunk) 312 | f.flush() 313 | 314 | return file_path 315 | 316 | def _unzip(self, file_path): 317 | import zipfile 318 | 319 | if os.path.exists(self.dir): 320 | return 321 | 322 | print('Unzip ' + file_path) 323 | 324 | zip_ref = zipfile.ZipFile(file_path, 'r') 325 | zip_ref.extractall(self.root) 326 | zip_ref.close() 327 | os.unlink(file_path) 328 | 329 | def _fix(self): 330 | print("Fix obj files") 331 | 332 | r = re.compile(r'f (\d+)[/\d]* (\d+)[/\d]* (\d+)[/\d]*') 333 | 334 | path = os.path.join(self.dir, "*.obj") 335 | files = sorted(glob.glob(path)) 336 | 337 | c = 0 338 | for i, f in enumerate(files): 339 | with open(f, "rt") as x: 340 | y = x.read() 341 | yy = r.sub(r"f \1 \2 \3", y) 342 | if y != yy: 343 | c += 1 344 | with open(f, "wt") as x: 345 | x.write(yy) 346 | print("{}/{} {} fixed ".format(i + 1, len(files), c), end="\r") 347 | 348 | def download(self, dataset, perturbed): 349 | 350 | if self._check_exists(): 351 | return 352 | 353 | # download files 354 | try: 355 | os.makedirs(self.root) 356 | except FileExistsError: 357 | # the root directory may already exist; FileExistsError is the OSError 358 | # subclass raised for EEXIST (os.errno no longer exists on Python 3.8+) 359 | pass 360 | 361 | 362 | url = self.url_data.format(dataset + ("_perturbed" if perturbed else "")) 363 | file_path = self._download(url) 364 | self._unzip(file_path) 365 | self._fix() 366 | 367 | if dataset != "test": 368 | url = self.url_label.format(dataset) 369 | self._download(url) 370 | 371 | print('Done!') 372 | -------------------------------------------------------------------------------- /s2cnn/s2_mm.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R,C,E1101 2 | from functools import lru_cache 3 | import torch 4 | import torch.cuda 5 | from string import Template 6 | 7 | # TODO simplify the cuda code like it was done in SO3_mm using only one code for the kernel 8 | 9 | 10 | def s2_mm(x, y): 11 | ''' 12 | :param x: [l * m, batch, feature_in, complex] 13 | :param y: [l * m, feature_in, feature_out, complex] 14 | :return: [l * m * n, batch, feature_out, complex] 15 | ''' 16 | 
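# For every degree l the code below pairs each S2 mode m of x with each S2
# mode n of y,
#   z[l m n, i, j] = sum_k x[l m, i, k] * conj(y[l n, k, j]),
# so a degree-l block of the output holds (2l+1)**2 spectral entries while
# each input block holds only 2l+1; the CPU path realises this as one complex
# matrix multiplication per degree (complex_mm with conj_y=True).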
from s2cnn.utils.complex import complex_mm 17 | 18 | assert y.size(3) == 2 19 | assert x.size(3) == 2 20 | nbatch = x.size(1) 21 | nfeature_in = x.size(2) 22 | nfeature_out = y.size(2) 23 | assert y.size(1) == nfeature_in 24 | nspec = x.size(0) 25 | assert y.size(0) == nspec 26 | 27 | if x.is_cuda: 28 | return _cuda_S2_mm.apply(x, y) 29 | 30 | nl = round(nspec**0.5) 31 | 32 | Fz_list = [] 33 | begin = 0 34 | for l in range(nl): 35 | L = 2 * l + 1 36 | size = L 37 | 38 | Fx = x[begin:begin+size] # [m, batch, feature_in, complex] 39 | Fy = y[begin:begin+size] # [m, feature_in, feature_out, complex] 40 | 41 | Fx = Fx.view(L * nbatch, nfeature_in, 2) # [m * batch, feature_in, complex] 42 | 43 | Fy = Fy.transpose(0, 1) # [feature_in, m, feature_out, complex] 44 | Fy = Fy.contiguous() 45 | Fy = Fy.view(nfeature_in, L * nfeature_out, 2) # [feature_in, m * feature_out, complex] 46 | 47 | Fz = complex_mm(Fx, Fy, conj_y=True) # [m_x * batch, m_y * feature_out, complex] m_x -> m, m_y -> n 48 | Fz = Fz.view(L, nbatch, L, nfeature_out, 2) # [m, batch, n, feature_out, complex] 49 | Fz = Fz.transpose(1, 2) # [m, n, batch, feature_out, complex] 50 | Fz = Fz.contiguous() 51 | Fz = Fz.view(L * L, nbatch, nfeature_out, 2) # [m * n, batch, feature_out, complex] 52 | 53 | Fz_list.append(Fz) 54 | 55 | begin += size 56 | 57 | z = torch.cat(Fz_list, 0) # [l * m * n, batch, feature_out, complex] 58 | return z 59 | 60 | 61 | class _cuda_S2_mm(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, x, y): # pylint: disable=W 64 | ctx.save_for_backward(x, y) 65 | return _cuda_s2_mm(x, y) 66 | 67 | @staticmethod 68 | def backward(ctx, gradz): # pylint: disable=W 69 | import s2cnn.utils.cuda as cuda_utils 70 | x, y = ctx.saved_tensors 71 | nl = round(x.size(0) ** 0.5) 72 | nbatch = x.size(1) 73 | nfeature_in = x.size(2) 74 | nfeature_out = y.size(2) 75 | nspec = (4 * nl ** 2 - 1) * nl // 3 76 | device = torch.cuda.current_device() 77 | 78 | gradx_cuda_kernel = _setup_s2mm_gradx_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl, nfeature_in=nfeature_in, 79 | nfeature_out=nfeature_out, device=device) 80 | grady_cuda_kernel = _setup_s2mm_grady_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl, nfeature_in=nfeature_in, 81 | nfeature_out=nfeature_out, device=device) 82 | 83 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 84 | 85 | gradx = grady = None 86 | 87 | if ctx.needs_input_grad[0]: 88 | gradx = gradz.new_empty((nl ** 2, nbatch, nfeature_in, 2)) 89 | gradx_cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1), 90 | grid=(cuda_utils.get_blocks(nl ** 2 * nbatch * nfeature_in, 1024), 1, 1), 91 | args=[gradz.contiguous().data_ptr(), y.contiguous().data_ptr(), gradx.data_ptr()], 92 | stream=stream) 93 | 94 | if ctx.needs_input_grad[1]: 95 | grady = gradz.new_empty((nl ** 2, nfeature_in, nfeature_out, 2)) 96 | grady_cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1), 97 | grid=(cuda_utils.get_blocks(nl ** 2 * nfeature_in * nfeature_out, 1024), 1, 1), 98 | args=[gradz.contiguous().data_ptr(), x.contiguous().data_ptr(), grady.data_ptr()], 99 | stream=stream) 100 | 101 | return gradx, grady 102 | 103 | 104 | def _cuda_s2_mm(x, y): 105 | ''' 106 | :param x: [l * m, batch, feature_in, complex] 107 | :param y: [l * m, feature_in, feature_out, complex] 108 | :return: [l * m * n, batch, feature_out, complex] 109 | ''' 110 | import s2cnn.utils.cuda as cuda_utils 111 | assert x.is_cuda and x.dtype == torch.float32 112 | assert y.is_cuda and y.dtype == torch.float32 113 | assert y.size(3) == 2 114 | 
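# shape contract shared by both inputs: the trailing dimension of size 2
# holds the interleaved (real, imaginary) parts of each complex entry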
assert x.size(3) == 2 115 | nbatch = x.size(1) 116 | nfeature_in = x.size(2) 117 | nfeature_out = y.size(2) 118 | assert y.size(1) == nfeature_in 119 | assert y.size(0) == x.size(0) 120 | nl = round(x.size(0) ** 0.5) 121 | nspec = (4 * nl ** 2 - 1) * nl // 3 122 | assert x.size(0) == nl ** 2 123 | assert y.size(0) == nl ** 2 124 | 125 | device = torch.cuda.current_device() 126 | cuda_kernel = _setup_s2mm_cuda_kernel(nbatch=nbatch, nspec=nspec, nfeature_in=nfeature_in, 127 | nfeature_out=nfeature_out, device=device) 128 | 129 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 130 | output = x.new_empty((nspec, nbatch, nfeature_out, 2)) 131 | cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1), 132 | grid=(cuda_utils.get_blocks(nspec * nbatch * nfeature_out, 1024), 1, 1), 133 | args=[x.contiguous().data_ptr(), y.contiguous().data_ptr(), output.data_ptr()], 134 | stream=stream) 135 | # [l * m * n, batch, feature_out, complex] 136 | 137 | return output 138 | 139 | 140 | @lru_cache(maxsize=32) 141 | def _setup_s2mm_cuda_kernel(nbatch, nspec, nfeature_in, nfeature_out, device=0): 142 | kernel = Template(''' 143 | #define COMPUTE_LMN(s) \ 144 | int l = powf(3.0/4.0 * s, 1.0/3.0) - 0.5; \ 145 | int L = l * (4 * l * l - 1) / 3; \ 146 | int rest = s - L; \ 147 | if (rest >= (2 * l + 1) * (2 * l + 1)) { \ 148 | ++l; \ 149 | L = l * (4 * l * l - 1) / 3; \ 150 | rest = s - L; \ 151 | } \ 152 | int m = rest / (2 * l + 1) - l; \ 153 | int n = rest % (2 * l + 1) - l; 154 | 155 | #define EXTRACT(i1, i2, n2, i3, n3) \ 156 | int i1 = index; \ 157 | int i3 = i1 % (n3); i1 /= n3; \ 158 | int i2 = i1 % (n2); i1 /= n2; 159 | 160 | #define CONTRACT1(s1, i2, n2, i3, n3) \ 161 | ( ( (l * l + (l + (s1))) * (n2) + (i2) ) * (n3) + (i3) ) 162 | 163 | #define CONTRACT2(s1, s2, i2, n2, i3, n3) \ 164 | ( ( (L + (l + (s1)) * (2 * l + 1) + (l + (s2))) * (n2) + (i2) ) * (n3) + (i3) ) 165 | 166 | extern "C" 167 | __global__ void main_(const float* in_x, const float* in_y, float* out) { 168 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ${nspec} * ${nbatch} * ${nfeature_out}; index += blockDim.x * gridDim.x) { 169 | EXTRACT(s, i, ${nbatch}, f_out, ${nfeature_out}) 170 | 171 | // compute s -> (l,m,n) 172 | COMPUTE_LMN(s) 173 | 174 | float out_re = 0.0; 175 | float out_im = 0.0; 176 | 177 | for (int f_in = 0; f_in < ${nfeature_in}; ++f_in) { 178 | float x_re = in_x[CONTRACT1(m, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 0]; 179 | float x_im = in_x[CONTRACT1(m, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 1]; 180 | float y_re = in_y[CONTRACT1(n, f_in, ${nfeature_in}, f_out, ${nfeature_out}) * 2 + 0]; 181 | float y_im = in_y[CONTRACT1(n, f_in, ${nfeature_in}, f_out, ${nfeature_out}) * 2 + 1]; 182 | 183 | // x times y conjugate 184 | out_re += x_re * y_re + x_im * y_im; 185 | out_im += x_im * y_re - x_re * y_im; 186 | } 187 | 188 | out[index * 2 + 0] = out_re; 189 | out[index * 2 + 1] = out_im; 190 | } 191 | } 192 | ''').substitute({'nbatch': nbatch, 193 | 'nspec': nspec, 194 | 'nfeature_in': nfeature_in, 195 | 'nfeature_out': nfeature_out}) 196 | 197 | import s2cnn.utils.cuda as cuda_utils 198 | return cuda_utils.compile_kernel(kernel, 's2mm.cu', 'main_') 199 | 200 | 201 | @lru_cache(maxsize=32) 202 | def _setup_s2mm_gradx_cuda_kernel(nbatch, nspec, nl, nfeature_in, nfeature_out, device=0): 203 | kernel = Template(''' 204 | #define COMPUTE_LM(s) \ 205 | int l = sqrtf(s); \ 206 | int L = (4 * l * l - 1) * l / 3; \ 207 | int m = s - l * l - l; 208 | 209 | #define EXTRACT(i1, i2, n2, i3, n3) 
\ 210 | int i1 = index; \ 211 | int i3 = i1 % (n3); i1 /= n3; \ 212 | int i2 = i1 % (n2); i1 /= n2; 213 | 214 | #define CONTRACT1(s1, i2, n2, i3, n3) \ 215 | ( ( (l * l + (l + (s1))) * (n2) + (i2) ) * (n3) + (i3) ) 216 | 217 | #define CONTRACT2(s1, s2, i2, n2, i3, n3) \ 218 | ( ( (L + (l + (s1)) * (2 * l + 1) + (l + (s2))) * (n2) + (i2) ) * (n3) + (i3) ) 219 | 220 | extern "C" 221 | __global__ void main_(const float* grad_z, const float* y, float* grad_x) { 222 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (${nl} * ${nl}) * ${nbatch} * ${nfeature_in}; index += blockDim.x * gridDim.x) { 223 | EXTRACT(s, i, ${nbatch}, f_in, ${nfeature_in}) 224 | 225 | // compute s -> (l,m) 226 | COMPUTE_LM(s) 227 | 228 | float out_re = 0.0; 229 | float out_im = 0.0; 230 | 231 | for (int f_out = 0; f_out < ${nfeature_out}; ++f_out) { 232 | for (int k = -l; k <= l; ++k) { 233 | float grad_z_re = grad_z[CONTRACT2(m, k, i, ${nbatch}, f_out, ${nfeature_out}) * 2 + 0]; 234 | float grad_z_im = grad_z[CONTRACT2(m, k, i, ${nbatch}, f_out, ${nfeature_out}) * 2 + 1]; 235 | float y_re = y[CONTRACT1(k, f_in, ${nfeature_in}, f_out, ${nfeature_out}) * 2 + 0]; 236 | float y_im = y[CONTRACT1(k, f_in, ${nfeature_in}, f_out, ${nfeature_out}) * 2 + 1]; 237 | 238 | // grad_z times y 239 | out_re += grad_z_re * y_re - grad_z_im * y_im; 240 | out_im += grad_z_re * y_im + grad_z_im * y_re; 241 | } 242 | } 243 | 244 | grad_x[index * 2 + 0] = out_re; 245 | grad_x[index * 2 + 1] = out_im; 246 | } 247 | } 248 | ''').substitute({'nbatch': nbatch, 249 | 'nspec': nspec, 250 | 'nl': nl, 251 | 'nfeature_in': nfeature_in, 252 | 'nfeature_out': nfeature_out}) 253 | 254 | import s2cnn.utils.cuda as cuda_utils 255 | return cuda_utils.compile_kernel(kernel, 's2mm_gradx.cu', 'main_') 256 | 257 | 258 | @lru_cache(maxsize=32) 259 | def _setup_s2mm_grady_cuda_kernel(nbatch, nspec, nl, nfeature_in, nfeature_out, device=0): 260 | kernel = Template(''' 261 | #define COMPUTE_LM(s) \ 262 | int l = powf(s, 0.5); \ 263 | int L = (4 * l * l - 1) * l / 3; \ 264 | int m = s - l * l - l; 265 | 266 | #define EXTRACT(i1, i2, n2, i3, n3) \ 267 | int i1 = index; \ 268 | int i3 = i1 % (n3); i1 /= n3; \ 269 | int i2 = i1 % (n2); i1 /= n2; 270 | 271 | #define CONTRACT1(s1, i2, n2, i3, n3) \ 272 | ( ( (l * l + (l + (s1))) * (n2) + (i2) ) * (n3) + (i3) ) 273 | 274 | #define CONTRACT2(s1, s2, i2, n2, i3, n3) \ 275 | ( ( (L + (l + (s1)) * (2 * l + 1) + (l + (s2))) * (n2) + (i2) ) * (n3) + (i3) ) 276 | 277 | extern "C" 278 | __global__ void main_(const float* grad_z, const float* x, float* grad_y) { 279 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (${nl} * ${nl}) * ${nfeature_in} * ${nfeature_out}; index += blockDim.x * gridDim.x) { 280 | EXTRACT(s, f_in, ${nfeature_in}, f_out, ${nfeature_out}) 281 | 282 | // compute s -> (l,m) 283 | COMPUTE_LM(s) 284 | 285 | float out_re = 0.0; 286 | float out_im = 0.0; 287 | 288 | for (int i = 0; i < ${nbatch}; ++i) { 289 | for (int k = -l; k <= l; ++k) { 290 | float grad_z_re = grad_z[CONTRACT2(k, m, i, ${nbatch}, f_out, ${nfeature_out}) * 2 + 0]; 291 | float grad_z_im = grad_z[CONTRACT2(k, m, i, ${nbatch}, f_out, ${nfeature_out}) * 2 + 1]; 292 | float x_re = x[CONTRACT1(k, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 0]; 293 | float x_im = x[CONTRACT1(k, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 1]; 294 | 295 | // conjugate grad_z times x 296 | out_re += grad_z_re * x_re + grad_z_im * x_im; 297 | out_im += grad_z_re * x_im - grad_z_im * x_re; 298 | } 299 | } 300 | 301 | grad_y[index * 2 + 0] = 
out_re; 302 | grad_y[index * 2 + 1] = out_im; 303 | } 304 | } 305 | ''').substitute({'nbatch': nbatch, 306 | 'nspec': nspec, 307 | 'nl': nl, 308 | 'nfeature_in': nfeature_in, 309 | 'nfeature_out': nfeature_out}) 310 | 311 | import s2cnn.utils.cuda as cuda_utils 312 | return cuda_utils.compile_kernel(kernel, 's2mm_grady.cu', 'main_') 313 | 314 | 315 | def test_compare_cuda_cpu(): 316 | x = torch.rand(1+3+5+7, 2, 3, 2) # [l * m, batch, feature_in, complex] 317 | y = torch.rand(1+3+5+7, 3, 5, 2) # [l * m, feature_in, feature_out, complex] 318 | z1 = s2_mm(x, y) 319 | z2 = s2_mm(x.cuda(), y.cuda()).cpu() 320 | q = (z1 - z2).abs().max().item() / z1.std().item() 321 | print(q) 322 | assert q < 1e-4 323 | 324 | 325 | if __name__ == "__main__": 326 | test_compare_cuda_cpu() 327 | -------------------------------------------------------------------------------- /s2cnn/soft/so3_fft.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R,C,E1101,E1102 2 | import math 3 | from functools import lru_cache 4 | import torch 5 | import torch.cuda 6 | from s2cnn.utils.decorator import cached_dirpklgz 7 | 8 | 9 | # inspired by https://gist.github.com/szagoruyko/89f83b6f5f4833d3c8adf81ee49f22a8 10 | 11 | 12 | def so3_fft(x, for_grad=False, b_out=None): 13 | ''' 14 | :param x: [..., beta, alpha, gamma, complex] 15 | :return: [l * m * n, ..., complex] 16 | ''' 17 | assert x.size(-1) == 2, x.size() 18 | b_in = x.size(-2) // 2 19 | assert x.size(-2) == 2 * b_in 20 | assert x.size(-3) == 2 * b_in 21 | assert x.size(-4) == 2 * b_in 22 | if b_out is None: 23 | b_out = b_in 24 | batch_size = x.size()[:-4] 25 | 26 | x = x.view(-1, 2 * b_in, 2 * b_in, 2 * b_in, 2) # [batch, beta, alpha, gamma, complex] 27 | 28 | ''' 29 | :param x: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_in, 2 b_in, 2 b_in, 2) 30 | :return: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2) 31 | ''' 32 | nspec = b_out * (4 * b_out ** 2 - 1) // 3 33 | nbatch = x.size(0) 34 | 35 | wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device) # [beta, l * m * n] 36 | 37 | x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x),dim=[2,3])) # [batch, beta, m, n, complex] 38 | 39 | output = x.new_empty((nspec, nbatch, 2)) 40 | if x.is_cuda and x.dtype == torch.float32: 41 | cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=False, device=x.device.index) 42 | cuda_kernel(x, wigner, output) # [l * m * n, batch, complex] 43 | else: 44 | if b_in < b_out: 45 | output.fill_(0) 46 | for l in range(b_out): 47 | s = slice(l * (4 * l ** 2 - 1) // 3, l * (4 * l ** 2 - 1) // 3 + (2 * l + 1) ** 2) 48 | l1 = min(l, b_in - 1) # if b_out > b_in, consider high frequencies as null 49 | 50 | xx = x.new_zeros((x.size(0), x.size(1), 2 * l + 1, 2 * l + 1, 2)) 51 | xx[:, :, l: l + l1 + 1, l: l + l1 + 1] = x[:, :, :l1 + 1, :l1 + 1] 52 | if l1 > 0: 53 | xx[:, :, l - l1:l, l: l + l1 + 1] = x[:, :, -l1:, :l1 + 1] 54 | xx[:, :, l: l + l1 + 1, l - l1:l] = x[:, :, :l1 + 1, -l1:] 55 | xx[:, :, l - l1:l, l - l1:l] = x[:, :, -l1:, -l1:] 56 | 57 | out = torch.einsum("bmn,zbmnc->mnzc", (wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1), xx)) 58 | output[s] = out.view((2 * l + 1) ** 2, -1, 2) 59 | 60 | output = output.view(-1, *batch_size, 2) # [l * m * n, ..., complex] 61 | return output 62 | 63 | 64 | def so3_rfft(x, for_grad=False, b_out=None): 65 | ''' 66 | :param x: [..., beta, alpha, gamma] 67 | :return: [l * m * n, ..., complex] 68 | ''' 69 | b_in = 
x.size(-1) // 2 70 | assert x.size(-1) == 2 * b_in 71 | assert x.size(-2) == 2 * b_in 72 | assert x.size(-3) == 2 * b_in 73 | if b_out is None: 74 | b_out = b_in 75 | batch_size = x.size()[:-3] 76 | 77 | x = x.contiguous().view(-1, 2 * b_in, 2 * b_in, 2 * b_in) # [batch, beta, alpha, gamma] 78 | 79 | ''' 80 | :param x: [batch, beta, alpha, gamma] (nbatch, 2 b_in, 2 b_in, 2 b_in) 81 | :return: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2) 82 | ''' 83 | nspec = b_out * (4 * b_out ** 2 - 1) // 3 84 | nbatch = x.size(0) 85 | 86 | wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device) 87 | 88 | output = x.new_empty((nspec, nbatch, 2)) 89 | if x.is_cuda and x.dtype == torch.float32: 90 | x = torch.view_as_real(torch.fft.rfftn(x, dim=[2,3])) # [batch, beta, m, n, complex] 91 | cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=True, device=x.device.index) 92 | cuda_kernel(x, wigner, output) 93 | else: 94 | x = torch.view_as_real(torch.fft.fftn(x, dim=[2, 3])) # full two-sided spectrum of the real signal; rfftn would reject a complex input and only return half of the modes used below 95 | if b_in < b_out: 96 | output.fill_(0) 97 | for l in range(b_out): 98 | s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2) 99 | l1 = min(l, b_in - 1) # if b_out > b_in, consider high frequencies as null 100 | 101 | xx = x.new_zeros((x.size(0), x.size(1), 2 * l + 1, 2 * l + 1, 2)) 102 | xx[:, :, l: l + l1 + 1, l: l + l1 + 1] = x[:, :, :l1 + 1, :l1 + 1] 103 | if l1 > 0: 104 | xx[:, :, l - l1:l, l: l + l1 + 1] = x[:, :, -l1:, :l1 + 1] 105 | xx[:, :, l: l + l1 + 1, l - l1:l] = x[:, :, :l1 + 1, -l1:] 106 | xx[:, :, l - l1:l, l - l1:l] = x[:, :, -l1:, -l1:] 107 | 108 | out = torch.einsum("bmn,zbmnc->mnzc", (wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1), xx)) 109 | output[s] = out.view((2 * l + 1) ** 2, -1, 2) 110 | 111 | output = output.view(-1, *batch_size, 2) # [l * m * n, ..., complex] 112 | return output 113 | 114 | 115 | def so3_ifft(x, for_grad=False, b_out=None): 116 | ''' 117 | :param x: [l * m * n, ..., complex] 118 | ''' 119 | assert x.size(-1) == 2 120 | nspec = x.size(0) 121 | b_in = round((3 / 4 * nspec) ** (1 / 3)) 122 | assert nspec == b_in * (4 * b_in ** 2 - 1) // 3 123 | if b_out is None: 124 | b_out = b_in 125 | batch_size = x.size()[1:-1] 126 | 127 | x = x.view(nspec, -1, 2) # [l * m * n, batch, complex] (nspec, nbatch, 2) 128 | 129 | ''' 130 | :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2) 131 | :return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2) 132 | ''' 133 | nbatch = x.size(1) 134 | 135 | wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device) # [beta, l * m * n] (2 * b_out, nspec) 136 | 137 | output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2)) 138 | if x.is_cuda and x.dtype == torch.float32: 139 | cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index) 140 | cuda_kernel(x, wigner, output) # [batch, beta, m, n, complex] 141 | else: 142 | output.fill_(0) 143 | for l in range(min(b_in, b_out)): 144 | s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2) 145 | out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1))) 146 | l1 = min(l, b_out - 1) # if b_out < b_in 147 | output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1] 148 | if l > 0: 149 | output[:, :, -l1:, :l1 + 1] += 
out[:, :, l - l1: l, l: l + l1 + 1] 150 | output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l] 151 | output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l] 152 | 153 | output = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(output), dim=[2,3])) * output.size(-2) ** 2 # [batch, beta, alpha, gamma, complex] 154 | output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out, 2) 155 | return output 156 | 157 | 158 | def so3_rifft(x, for_grad=False, b_out=None): 159 | ''' 160 | :param x: [l * m * n, ..., complex] 161 | ''' 162 | assert x.size(-1) == 2 163 | nspec = x.size(0) 164 | b_in = round((3 / 4 * nspec) ** (1 / 3)) 165 | assert nspec == b_in * (4 * b_in ** 2 - 1) // 3 166 | if b_out is None: 167 | b_out = b_in 168 | batch_size = x.size()[1:-1] 169 | 170 | x = x.view(nspec, -1, 2) # [l * m * n, batch, complex] (nspec, nbatch, 2) 171 | 172 | ''' 173 | :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2) 174 | :return: [batch, beta, alpha, gamma] (nbatch, 2 b_out, 2 b_out, 2 b_out) 175 | ''' 176 | nbatch = x.size(1) 177 | 178 | wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device) # [beta, l * m * n] (2 * b_out, nspec) 179 | 180 | output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2)) 181 | if x.is_cuda and x.dtype == torch.float32: 182 | cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=True, device=x.device.index) 183 | cuda_kernel(x, wigner, output) # [batch, beta, m, n, complex] 184 | else: 185 | # TODO can be optimized knowing that the output is real, like in _setup_so3ifft_cuda_kernel(real_output=True) 186 | output.fill_(0) 187 | for l in range(min(b_in, b_out)): 188 | s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2) 189 | out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1))) 190 | l1 = min(l, b_out - 1) # if b_out < b_in 191 | output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1] 192 | if l > 0: 193 | output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1] 194 | output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l] 195 | output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l] 196 | 197 | output = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(output), dim=[2,3])) * output.size(-2) ** 2 # [batch, beta, alpha, gamma, complex] 198 | output = output[..., 0] # [batch, beta, alpha, gamma] 199 | output = output.contiguous() 200 | 201 | output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out) 202 | return output 203 | 204 | 205 | @lru_cache(maxsize=32) 206 | def _setup_wigner(b, nl, weighted, device): 207 | dss = _setup_so3_fft(b, nl, weighted) 208 | dss = torch.tensor(dss, dtype=torch.float32, device=device) # [beta, l * m * n] # pylint: disable=E1102 209 | return dss.contiguous() 210 | 211 | 212 | @cached_dirpklgz("cache/setup_so3_fft") 213 | def _setup_so3_fft(b, nl, weighted): 214 | from lie_learn.representations.SO3.wigner_d import wigner_d_matrix 215 | import lie_learn.spaces.S3 as S3 216 | import numpy as np 217 | import logging 218 | 219 | betas = (np.arange(2 * b) + 0.5) / (2 * b) * np.pi 220 | w = S3.quadrature_weights(b) 221 | assert len(w) == len(betas) 222 | 223 | logging.getLogger("trainer").info("Compute Wigner: b=%d nbeta=%d nl=%d nspec=%d", b, len(betas), nl, nl * (4 * nl ** 2 - 1) // 3) 224 | 225 | dss = [] 226 | for b, beta in enumerate(betas): 227 | ds = [] 228 | for l in range(nl): 229 | d = 
wigner_d_matrix(l, beta, 230 | field='complex', normalization='quantum', order='centered', condon_shortley='cs') 231 | d = d.reshape(((2 * l + 1) ** 2,)) 232 | 233 | if weighted: 234 | d *= w[b] 235 | else: 236 | d *= 2 * l + 1 237 | 238 | # d # [m * n] 239 | ds.append(d) 240 | ds = np.concatenate(ds) # [l * m * n] 241 | dss.append(ds) 242 | dss = np.stack(dss) # [beta, l * m * n] 243 | return dss 244 | 245 | 246 | @lru_cache(maxsize=32) 247 | def _setup_so3fft_cuda_kernel(b_in, b_out, nbatch, real_input, device=0): 248 | kernel = ''' 249 | #define B_IN {} 250 | #define B_OUT {} 251 | #define NSPEC {} 252 | #define NBATCH {} 253 | '''.format(b_in, b_out, b_out * (4 * b_out ** 2 - 1) // 3, nbatch) 254 | 255 | if real_input: 256 | kernel += ''' 257 | #define REAL_IN 258 | ''' 259 | 260 | kernel += ''' 261 | #define MOD(i, n) (((i) + (n)) % (n)) 262 | #define MAX(x, y) ((x) < (y) ? (y) : (x)) 263 | #define CEIL_DIV(x, y) (((x) + (y) - 1) / (y)) 264 | 265 | extern "C" 266 | __global__ void main_(const float* in, const float* wig, float* out) 267 | { 268 | // blockIdx = (l, batch, mn) 269 | // blockDim = (32, 32, 1) 270 | // threadIdx = (sub l, sub batch, 0) 271 | // gridDim = (b / 32, nbatch / 32, (2b-1)**2) 272 | int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1); 273 | int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1); 274 | 275 | int l_min = MAX(abs(m), abs(n)); 276 | 277 | if (blockIdx.x * 32 + 31 < l_min) { 278 | // for blocks fully out of l-range 279 | return; // note: this return does not depend on threadIdx 280 | } 281 | 282 | #ifdef REAL_IN 283 | if (n < 0 || (n == 0 && m < 0)) { 284 | return; // note: this return does not depend on threadIdx 285 | } 286 | #endif 287 | 288 | int batch = blockIdx.y * 32 + threadIdx.y; 289 | int l = blockIdx.x * 32 + threadIdx.x; 290 | 291 | int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n); 292 | 293 | float sum_re = 0.0; 294 | float sum_im = 0.0; 295 | 296 | for (int tile = 0; tile < CEIL_DIV(2 * B_IN, 32); ++tile) { 297 | __shared__ float tileA[32][32][2]; 298 | __shared__ float tileB[32][32]; 299 | 300 | int beta = tile * 32 + threadIdx.x; 301 | #ifdef REAL_IN 302 | // `in` shape is (NBATCH, 2 * B_IN, 2 * B_IN, B_IN + 1, 2) 303 | // http://www.fftw.org/fftw3_doc/Multi_002dDimensional-DFTs-of-Real-Data.html 304 | int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * (B_IN + 1) + n) * 2; 305 | #else 306 | int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * 2*B_IN + MOD(n, 2*B_IN)) * 2; 307 | #endif 308 | tileA[threadIdx.y][threadIdx.x][0] = beta < 2*B_IN && batch < NBATCH && m < B_IN && n < B_IN && m > -B_IN && n > -B_IN ? in[i + 0] : 0.0; 309 | tileA[threadIdx.y][threadIdx.x][1] = beta < 2*B_IN && batch < NBATCH && m < B_IN && n < B_IN && m > -B_IN && n > -B_IN ? in[i + 1] : 0.0; 310 | // add constraints to m and n to remove aliasing (when b_out > b_in) 311 | 312 | beta = tile * 32 + threadIdx.y; 313 | tileB[threadIdx.y][threadIdx.x] = beta < 2*B_IN && l_min <= l && l < B_OUT ? 
wig[beta * NSPEC + lmn] : 0.0; 314 | 315 | __syncthreads(); 316 | 317 | for (int beta = 0; beta < 32; ++beta) { 318 | sum_re += tileA[threadIdx.y][beta][0] * tileB[beta][threadIdx.x]; 319 | sum_im += tileA[threadIdx.y][beta][1] * tileB[beta][threadIdx.x]; 320 | } 321 | 322 | __syncthreads(); 323 | } 324 | 325 | // About this if: some blocks are used to compute but not to save the results 326 | if (l_min <= l && l < B_OUT && batch < NBATCH) { 327 | out[(lmn * NBATCH + batch) * 2 + 0] = sum_re; 328 | out[(lmn * NBATCH + batch) * 2 + 1] = sum_im; 329 | 330 | #ifdef REAL_IN 331 | lmn = (4 * l*l - 1) * l / 3 + (l-m) * (2 * l + 1) + (l-n); 332 | float fudge = (m - n) % 2 == 0 ? 1.0 : -1.0; 333 | out[(lmn * NBATCH + batch) * 2 + 0] = fudge * sum_re; 334 | out[(lmn * NBATCH + batch) * 2 + 1] = -fudge * sum_im; 335 | #endif 336 | } 337 | } 338 | ''' 339 | import s2cnn.utils.cuda as cuda_utils 340 | kernel = cuda_utils.compile_kernel(kernel, 'so3fft.cu', 'main_') 341 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 342 | 343 | def fun(x, wigner, output): 344 | assert output.is_contiguous() 345 | kernel(block=(32, 32, 1), 346 | grid=(math.ceil(b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1) ** 2), 347 | args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()], 348 | stream=stream) 349 | 350 | return fun 351 | 352 | 353 | @lru_cache(maxsize=32) 354 | def _setup_so3ifft_cuda_kernel(b_in, b_out, nbatch, real_output, device=0): 355 | kernel = ''' 356 | #define B_IN {} 357 | #define B_OUT {} 358 | #define NSPEC {} 359 | #define NBATCH {} 360 | '''.format(b_in, b_out, b_in * (4 * b_in ** 2 - 1) // 3, nbatch) 361 | 362 | if real_output: 363 | kernel += ''' 364 | #define REAL_OUT 365 | ''' 366 | 367 | kernel += ''' 368 | #define MOD(i, n) (((i) + (n)) % (n)) 369 | #define MAX(x, y) ((x) < (y) ? (y) : (x)) 370 | #define MIN(x, y) ((x) < (y) ? (x) : (y)) 371 | #define CEIL_DIV(x, y) (((x) + (y) - 1) / (y)) 372 | 373 | extern "C" 374 | __global__ void main_(const float* in, const float* wig, float* out) 375 | { 376 | int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1); 377 | int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1); 378 | 379 | #ifdef REAL_OUT 380 | if (n < 0 || (n == 0 && m < 0)) { 381 | return; // note: this return does not depend on threadIdx 382 | } 383 | #endif 384 | 385 | int l_min = MAX(abs(m), abs(n)); 386 | 387 | int batch = blockIdx.y * 32 + threadIdx.y; 388 | 389 | float sum_re = 0.0; 390 | float sum_im = 0.0; 391 | 392 | // will not calculate when l > min(b_in, b_out)-1 393 | for (int tile = 0; tile < CEIL_DIV(MIN(B_IN, B_OUT) - l_min, 32); ++tile) { 394 | __shared__ float tileA[2][32][32]; 395 | __shared__ float tileB[32][32+1]; 396 | 397 | int l = l_min + tile * 32 + threadIdx.x; 398 | int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n); 399 | int i = (lmn * NBATCH + batch) * 2; 400 | tileA[0][threadIdx.y][threadIdx.x] = l < MIN(B_IN, B_OUT) && batch < NBATCH && m < B_OUT && n < B_OUT && m > -B_OUT && n > -B_OUT ? in[i + 0] : 0.0; 401 | tileA[1][threadIdx.y][threadIdx.x] = l < MIN(B_IN, B_OUT) && batch < NBATCH && m < B_OUT && n < B_OUT && m > -B_OUT && n > -B_OUT ? in[i + 1] : 0.0; 402 | // add constraints to m and n to remove aliasing (when b_out < b_in) 403 | 404 | int beta = blockIdx.x * 32 + threadIdx.y; 405 | tileB[threadIdx.x][threadIdx.y] = l < MIN(B_IN, B_OUT) && beta < 2*B_OUT ? 
wig[beta * NSPEC + lmn] : 0.0; 406 | 407 | __syncthreads(); 408 | 409 | for (int l = 0; l < 32; ++l) { 410 | sum_re += tileA[0][threadIdx.y][l] * tileB[l][threadIdx.x]; 411 | sum_im += tileA[1][threadIdx.y][l] * tileB[l][threadIdx.x]; 412 | } 413 | 414 | __syncthreads(); 415 | } 416 | 417 | int beta = blockIdx.x * 32 + threadIdx.x; 418 | 419 | if (beta < 2*B_OUT && batch < NBATCH) { 420 | int i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(m, 2*B_OUT)) * 2*B_OUT + MOD(n, 2*B_OUT)) * 2; 421 | out[i + 0] = sum_re; 422 | out[i + 1] = sum_im; 423 | 424 | #ifdef REAL_OUT 425 | i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(-m, 2*B_OUT)) * 2*B_OUT + MOD(-n, 2*B_OUT)) * 2; 426 | out[i + 0] = sum_re; 427 | out[i + 1] = -sum_im; 428 | #endif 429 | } 430 | } 431 | ''' 432 | import s2cnn.utils.cuda as cuda_utils 433 | kernel = cuda_utils.compile_kernel(kernel, 'so3ifft.cu', 'main_') 434 | stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream) 435 | 436 | def fun(x, wigner, output): 437 | output[:] = 0 438 | kernel(block=(32, 32, 1), 439 | grid=(math.ceil(2 * b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1) ** 2), 440 | args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()], 441 | stream=stream) 442 | 443 | return fun 444 | 445 | 446 | class SO3_fft_real(torch.autograd.Function): 447 | @staticmethod 448 | def forward(ctx, x, b_out=None): # pylint: disable=W 449 | ctx.b_out = b_out 450 | ctx.b_in = x.size(-1) // 2 451 | return so3_rfft(x, b_out=ctx.b_out) 452 | 453 | @staticmethod 454 | def backward(ctx, grad_output): # pylint: disable=W 455 | # ifft of grad_output is not necessarily real, therefore we cannot use rifft 456 | return so3_ifft(grad_output, for_grad=True, b_out=ctx.b_in)[..., 0], None 457 | 458 | 459 | class SO3_ifft_real(torch.autograd.Function): 460 | @staticmethod 461 | def forward(ctx, x, b_out=None): # pylint: disable=W 462 | nspec = x.size(0) 463 | ctx.b_out = b_out 464 | ctx.b_in = round((3 / 4 * nspec) ** (1 / 3)) 465 | return so3_rifft(x, b_out=ctx.b_out) 466 | 467 | @staticmethod 468 | def backward(ctx, grad_output): # pylint: disable=W 469 | return so3_rfft(grad_output, for_grad=True, b_out=ctx.b_in), None 470 | --------------------------------------------------------------------------------