├── .gitignore
├── README.rst
├── benchmarks
│   ├── benchmark_contiguous_cpu.out
│   ├── benchmark_contiguous_cpu.py
│   ├── benchmark_contiguous_gpu.out
│   ├── benchmark_contiguous_gpu.py
│   ├── benchmark_real_data.py
│   └── old
│       ├── benchmark_contiguous.out
│       ├── benchmark_contiguous.py
│       ├── benchmark_contiguous_tmp.out
│       └── benchmark_contiguous_tmp.py
├── ci
│   ├── docker
│   │   └── Dockerfile.build
│   └── pipeline.yml
├── data
│   ├── molecular_crystals.xyz
│   └── random-methane-10k.extxyz
├── docs
│   ├── CG.png
│   ├── Makefile
│   ├── benchmarks.rst
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── reference_guide.rst
│   └── tutorial.ipynb
├── pyproject.toml
├── setup.py
├── sparse_accumulation
│   ├── __init__.py
│   ├── clebsch_gordan.py
│   ├── cpu_extension
│   │   ├── __init__.py
│   │   ├── sparse_accumulation_active_dim_first.cpp
│   │   ├── sparse_accumulation_active_dim_first.py
│   │   ├── sparse_accumulation_active_dim_last.cpp
│   │   ├── sparse_accumulation_active_dim_last.py
│   │   ├── sparse_accumulation_active_dim_middle.cpp
│   │   └── sparse_accumulation_active_dim_middle.py
│   ├── cuda_extension
│   │   ├── jit.py
│   │   ├── sparse_accumulation_cuda.cpp
│   │   ├── sparse_accumulation_cuda_kernel.cu
│   │   └── sparse_accumulation_cuda_kernel2D.cu
│   ├── other_operations.py
│   ├── reference_implementations.py
│   └── unified_operation.py
├── tests
│   ├── test_cpp_contiguous.py
│   └── test_cpp_jit_cuda.py
└── update_docs.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .*
2 | build/
3 | dist/
4 | __pycache__/
5 |
6 | *.egg-info/
7 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | .. inclusion-marker-preambule-start
2 |
3 | .. role:: bash(code)
4 |    :language: bash
5 |
6 | Sparse accumulation
7 | ===================
8 |
9 |
10 | This package contains significantly optimized CPU and GPU PyTorch extensions for the operation we call sparse accumulation. This operation takes two input arrays, X_1 and X_2, and produces an output array, given a transformation rule defined by the one-dimensional arrays m_1, m_2, mu, and C. The functional form is best explained by the following pseudocode:
11 |
12 | .. code-block:: python
13 |
14 |     output = torch.zeros([..., output_size])
15 |     for index in range(m_1.shape[0]):
16 |         output[..., mu[index]] += X_1[..., m_1[index]] * X_2[..., m_2[index]] * C[index]
17 |
18 |
19 |
20 | This operation is required for SO(3) equivariant neural networks and other machine learning models. The fundamental building block of such methods is the so-called Clebsch-Gordan iteration given by:
21 |
22 | .. image:: https://raw.githubusercontent.com/lab-cosmo/sparse_accumulation/docs_update/docs/CG.png
23 |
24 | where :math:`C_{m_1, m_2, \mu}^{l_1, l_2, l_{output}}` are the Clebsch-Gordan coefficients. These coefficients are sparse; in particular, for the complex-valued version the only non-zero values are those with :math:`m_1 + m_2 = \mu`. For the real-valued version, the sparsity pattern is more complicated, but still only a small fraction of the entries is non-zero. Thus, it makes sense to store only the non-zero values in a one-dimensional array. In this case, one needs to provide additional index arrays with the corresponding :math:`m_1`, :math:`m_2`, and :math:`\mu` values. With such a data organization, the CG iteration reduces to the sparse accumulation operation defined above.
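As a concrete illustration, the same operation can be written as a short, self-contained PyTorch snippet based on ``index_add_``. The sizes and index arrays below are made up for the example and do not correspond to a real Clebsch-Gordan transformation rule:

.. code-block:: python

    import torch

    # toy transformation rule with 4 non-zero entries, active dimension last
    batch_size, n_features, input_size, output_size = 4, 7, 5, 5
    X_1 = torch.randn(batch_size, n_features, input_size)
    X_2 = torch.randn(batch_size, n_features, input_size)
    m_1 = torch.tensor([0, 1, 2, 3])         # positions read from X_1
    m_2 = torch.tensor([0, 2, 1, 4])         # positions read from X_2
    mu = torch.tensor([0, 1, 1, 3])          # positions written in the output
    C = torch.tensor([0.5, 1.0, -1.0, 2.0])  # the non-zero coefficients

    # gather all contributions at once, then scatter-add them over the last dimension
    contributions = X_1[..., m_1] * X_2[..., m_2] * C
    output = torch.zeros(batch_size, n_features, output_size)
    output.index_add_(2, mu, contributions)

This mirrors the ``index_add_``-based reference implementation benchmarked below; the C++ and CUDA extensions compute the same operation but, as discussed below, are faster and more memory efficient.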
25 | 26 | Our benchmarks show that our custom PyTorch extension while being memory efficient, is significantly faster compared to all alternative implementations we were able to come up with, including dense matrix multiplication (with a lot of zeros inside due to sparsity of CG coefficients), sparse matrix multiplication using PyTorch `sparse engine `_ and the one relying on PyTorch `index_add `_. 27 | 28 | [todo] benchmark against e3nn 29 | 30 | All the benchmarks measurements and reference implementations details can be found in the [todo: structure results into the table] `benchmarks section `_. 31 | 32 | ++++++++++++ 33 | Installation 34 | ++++++++++++ 35 | 36 | :bash:`python3 -m pip install .` 37 | 38 | ++++++++++++ 39 | Tests 40 | ++++++++++++ 41 | 42 | gpu tests: 43 | :bash:`python3 -m pytest tests/test_cpp_jit_cuda.py -vrA` 44 | 45 | cpu tests: 46 | :bash:`python3 -m pytest tests/test_cpp_contiguous.py` 47 | 48 | .. inclusion-marker-preambule-end 49 | 50 | +++++++++++++ 51 | Documentation 52 | +++++++++++++ 53 | 54 | Documentation can be found `here `_ 55 | -------------------------------------------------------------------------------- /benchmarks/benchmark_contiguous_cpu.out: -------------------------------------------------------------------------------- 1 | L_MAX=8; BATCH_SIZE=1000; N_FEATURES=2000 2 | preparing real life transformation rule 3 | transformation rule is computed 4 | ************* 5 | CPU BENCHMARKS 6 | Running on 72 threads 7 | ************* 8 | 9 | ***forward*** 10 | 11 | python loops; active dim 0; forward; cpu: 0.287917349073622 12 | torch index_add_; active dim 0; forward; cpu: 0.2882208559248183 13 | cpp; active dim 0; forward; cpu: 0.06418042712741429 14 | 15 | python loops; active dim 1; forward; cpu: 0.10099238819546169 16 | torch index_add_; active dim 1; forward; cpu: 0.19646917449103463 17 | cpp; active dim 1; forward; cpu: 0.015409390131632486 18 | 19 | python loops; active dim 2; forward; cpu: 1.13313627243042 20 | torch index_add_; active dim 2; forward; cpu: 0.9349785645802816 21 | cpp; active dim 2; forward; cpu 0.029056257671780057 22 | 23 | ***backward*** 24 | 25 | python loops; active dim 0; backward; cpu 8.56085040834215 26 | torch index_add_; active dim 0; backward; cpu 0.8768206967247857 27 | cpp; active dim 0; backward; cpu 0.14745905664232042 28 | 29 | python loops; active dim 1; backward; cpu 12.528574811087715 30 | torch index_add_; active dim 1; backward; cpu 1.3579767015245225 31 | cpp; active dim 1; backward; cpu 0.11550368203057183 32 | 33 | python loops; active dim 2; backward; cpu 1.43605547481113 34 | torch index_add_; active dim 2; backward; cpu 1.3703345987531874 35 | cpp; active dim 2; backward; cpu 0.05493460761176215 36 | -------------------------------------------------------------------------------- /benchmarks/benchmark_contiguous_cpu.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | from sparse_accumulation.clebsch_gordan import get_real_clebsch_gordan, ClebschGordan 5 | from sparse_accumulation.reference_implementations import sparse_accumulation_loops, sparse_accumulation_index_add 6 | from sparse_accumulation.cpu_extension import sparse_accumulation_active_dim_last, sparse_accumulation_active_dim_first, sparse_accumulation_active_dim_middle 7 | import numpy as np 8 | from torch.utils import cpp_extension 9 | 10 | L_MAX = 8 11 | BATCH_SIZE = 1000 12 | N_FEATURES = 2000 13 | print(f"L_MAX={L_MAX}; BATCH_SIZE={BATCH_SIZE}; 
N_FEATURES={N_FEATURES}") 14 | print("preparing real life transformation rule") 15 | 16 | clebsch = ClebschGordan(L_MAX).precomputed_ 17 | indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX) 18 | 19 | m1_aligned, m2_aligned = [], [] 20 | multipliers, mu_aligned = [], [] 21 | for mu in range(0, 2 * L_MAX + 1): 22 | for el in indices[mu]: 23 | m1, m2, multiplier = el 24 | m1_aligned.append(m1) 25 | m2_aligned.append(m2) 26 | multipliers.append(multiplier) 27 | mu_aligned.append(mu) 28 | m1_aligned = torch.LongTensor(m1_aligned) 29 | m2_aligned = torch.LongTensor(m2_aligned) 30 | mu_aligned = torch.LongTensor(mu_aligned) 31 | multipliers = torch.FloatTensor(multipliers) 32 | 33 | indices = np.argsort(mu_aligned) 34 | 35 | m1_aligned = m1_aligned[indices] 36 | m2_aligned = m2_aligned[indices] 37 | mu_aligned = mu_aligned[indices] 38 | multipliers = multipliers[indices] 39 | 40 | print("transformation rule is computed") 41 | 42 | def get_input(BATCH_SIZE, N_FEATURES, active_dim, device): 43 | if active_dim == 0: 44 | X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device = device) 45 | X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device = device) 46 | 47 | if active_dim == 1: 48 | X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device = device) 49 | X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device = device) 50 | 51 | if active_dim == 2: 52 | X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device = device) 53 | X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device = device) 54 | 55 | 56 | if (active_dim != 0) and (active_dim != 2) and (active_dim != 1): 57 | raise ValueError("active dim should be one of 0, 1, 2") 58 | 59 | return X1, X2 60 | 61 | def benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 62 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cpu') 63 | times = [] 64 | 65 | for _ in range(n_trials): 66 | begin = time.time() 67 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 68 | times.append(time.time() - begin) 69 | return times 70 | 71 | 72 | def benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 73 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cpu') 74 | 75 | X1.requires_grad = True 76 | X2.requires_grad = True 77 | times = [] 78 | for _ in range(n_trials): 79 | begin = time.time() 80 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 81 | output.backward(gradient=torch.ones_like(output)) 82 | times.append(time.time() - begin) 83 | return np.array(times) 84 | 85 | 86 | def get_func_fixed_dim(func, active_dim): 87 | def func_fixed_dim(*args): 88 | return func(*args, active_dim = active_dim) 89 | return func_fixed_dim 90 | 91 | 92 | print("*************") 93 | print("CPU BENCHMARKS") 94 | print(f"Running on {torch.get_num_threads()} threads") 95 | print("*************") 96 | 97 | m1_aligned = m1_aligned.cpu() 98 | m2_aligned = m2_aligned.cpu() 99 | mu_aligned = mu_aligned.cpu() 100 | multipliers = multipliers.cpu() 101 | 102 | print() 103 | print("***forward***") 104 | print() 105 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 0, 106 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 107 | print("python loops; active dim 0; forward; cpu: ", np.mean(times[1:])) 108 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 0, 109 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 110 | print("torch index_add_; active 
dim 0; forward; cpu: ", np.mean(times[1:])) 111 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 0, 112 | sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 113 | print("cpp; active dim 0; forward; cpu: ", np.mean(times[1:])) 114 | 115 | print() 116 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 1, 117 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 118 | print("python loops; active dim 1; forward; cpu: ", np.mean(times[1:])) 119 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 1, 120 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 121 | print("torch index_add_; active dim 1; forward; cpu: ", np.mean(times[1:])) 122 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 1, 123 | sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 124 | print("cpp; active dim 1; forward; cpu: ", np.mean(times[1:])) 125 | 126 | print() 127 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_loops, 2), 10) 128 | print("python loops; active dim 2; forward; cpu:", np.mean(times[1:])) 129 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 130 | print("torch index_add_; active dim 2; forward; cpu: ", np.mean(times[1:])) 131 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation_active_dim_last.SparseAccumulationActiveDimLast.apply, 10) 132 | print("cpp; active dim 2; forward; cpu ", np.mean(times[1:])) 133 | 134 | print() 135 | print("***backward***") 136 | print() 137 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 0, 138 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 139 | print("python loops; active dim 0; backward; cpu ", np.mean(times[1:])) 140 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 0, 141 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 142 | print("torch index_add_; active dim 0; backward; cpu ", np.mean(times[1:])) 143 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 0, 144 | sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 145 | print("cpp; active dim 0; backward; cpu ", np.mean(times[1:])) 146 | 147 | print() 148 | 149 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 1, 150 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 151 | print("python loops; active dim 1; backward; cpu ", np.mean(times[1:])) 152 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 1, 153 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 154 | print("torch index_add_; active dim 1; backward; cpu ", np.mean(times[1:])) 155 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 1, 156 | sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 157 | print("cpp; active dim 1; backward; cpu ", np.mean(times[1:])) 158 | 159 | 160 | print() 161 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 2, 162 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 163 | print("python loops; active dim 2; backward; cpu ", np.mean(times[1:])) 164 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 2, 165 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 166 | print("torch index_add_; active dim 2; backward; cpu ", np.mean(times[1:])) 167 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation_active_dim_last.SparseAccumulationActiveDimLast.apply, 10) 168 | print("cpp; active dim 2; backward; cpu ", np.mean(times[1:])) 169 
| -------------------------------------------------------------------------------- /benchmarks/benchmark_contiguous_gpu.out: -------------------------------------------------------------------------------- 1 | ninja: no work to do. 2 | L_MAX=8; BATCH_SIZE=1000; N_FEATURES=2000 3 | preparing real life transformation rule 4 | transformation rule is computed 5 | 6 | ************* 7 | GPU benchmarks 8 | ************* 9 | 10 | ***forward*** 11 | 12 | python loops; active dim 0; forward; cuda: 0.06442029486762153 13 | torch index_add_; active dim 0; forward; cuda: 0.04863210678100585 14 | 15 | python loops; active dim 1; forward; cuda: 0.07500541941324869 16 | torch index_add_; active dim 1; forward; cuda: 0.04478669526841905 17 | 18 | python loops; active dim 2; forward; cuda: 0.3839471096462673 19 | torch index_add_; active dim 2; forward; cuda: 0.04732361221313477 20 | CUDA kernel; active dim 2; forward; cuda: 0.002660088883505928 21 | 22 | ***backward*** 23 | 24 | python loops; active dim 0; backward; cuda: 1.1864166802300347 25 | torch index_add_; active dim 0; backward; cuda: 0.09936237504747178 26 | 27 | python loops; active dim 1; backward; cuda: 1.5248256022135418 28 | torch index_add_; active dim 1; backward; cuda: 0.09908599090576171 29 | 30 | python loops; active dim 2; backward; cuda: 0.7663369886610244 31 | torch index_add_; active dim 2; backward; cuda: 0.7663484971788194 32 | CUDA kernel; active dim 2; backward; cuda: 0.0068883590698242195 33 | -------------------------------------------------------------------------------- /benchmarks/benchmark_contiguous_gpu.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | from sparse_accumulation.clebsch_gordan import get_real_clebsch_gordan, ClebschGordan 5 | from sparse_accumulation.reference_implementations import sparse_accumulation_loops, sparse_accumulation_index_add 6 | from sparse_accumulation.reference_implementations import sparse_accumulation_matrix_multiply, sparse_accumulation_sparse_matrix_multiply, sparse_accumulation_sparse_matrix_multiply_optimized 7 | from e3nn import o3 8 | from sparse_accumulation.reference_implementations import get_transformation, get_transformation_sparse 9 | 10 | import numpy as np 11 | from torch.utils import cpp_extension 12 | 13 | cpp_extension.load( 14 | name="sparse_accumulation_cuda", 15 | sources=["../sparse_accumulation/cuda_extension/sparse_accumulation_cuda_kernel2D.cu"], 16 | is_python_module=False, 17 | extra_cuda_cflags=None, 18 | verbose=True, 19 | ) 20 | 21 | import sparse_accumulation_cuda 22 | 23 | L_MAX = 8 24 | BATCH_SIZE = 2000 25 | N_FEATURES = 300 26 | print(f"L_MAX={L_MAX}; BATCH_SIZE={BATCH_SIZE}; N_FEATURES={N_FEATURES}") 27 | print("preparing real life transformation rule") 28 | 29 | clebsch = ClebschGordan(L_MAX).precomputed_ 30 | indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX) 31 | 32 | m1_aligned, m2_aligned = [], [] 33 | multipliers, mu_aligned = [], [] 34 | for mu in range(0, 2 * L_MAX + 1): 35 | for el in indices[mu]: 36 | m1, m2, multiplier = el 37 | m1_aligned.append(m1) 38 | m2_aligned.append(m2) 39 | multipliers.append(multiplier) 40 | mu_aligned.append(mu) 41 | m1_aligned = torch.LongTensor(m1_aligned) 42 | m2_aligned = torch.LongTensor(m2_aligned) 43 | mu_aligned = torch.LongTensor(mu_aligned) 44 | multipliers = torch.FloatTensor(multipliers) 45 | 46 | indices = np.argsort(mu_aligned) 47 | 48 | m1_aligned = m1_aligned[indices] 49 | m2_aligned = 
m2_aligned[indices] 50 | mu_aligned = mu_aligned[indices] 51 | multipliers = multipliers[indices] 52 | 53 | transformation = get_transformation(mu_aligned, 2 * L_MAX + 1, 2 * L_MAX + 1, 2 * L_MAX + 1, 54 | m1_aligned, m2_aligned, multipliers) 55 | transformation_sparse = get_transformation_sparse(mu_aligned, 2 * L_MAX + 1, 2 * L_MAX + 1, 2 * L_MAX + 1, 56 | m1_aligned, m2_aligned, multipliers) 57 | 58 | num_total = BATCH_SIZE * N_FEATURES 59 | e3nn_transformation = o3.ElementwiseTensorProduct(f"{num_total}x{L_MAX}e", f"{num_total}x{L_MAX}e", [f"{L_MAX}e"]) 60 | print("transformation rule is computed") 61 | 62 | 63 | 64 | def get_input(BATCH_SIZE, N_FEATURES, active_dim, device): 65 | if active_dim == 0: 66 | X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device = device) 67 | X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device = device) 68 | 69 | if active_dim == 1: 70 | X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device = device) 71 | X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device = device) 72 | 73 | if active_dim == 2: 74 | X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device = device) 75 | X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device = device) 76 | 77 | 78 | if (active_dim != 0) and (active_dim != 2) and (active_dim != 1): 79 | raise ValueError("active dim should be one of 0, 1, 2") 80 | 81 | return X1, X2 82 | 83 | def benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 84 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cpu') 85 | times = [] 86 | 87 | for _ in range(n_trials): 88 | begin = time.time() 89 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 90 | times.append(time.time() - begin) 91 | return times 92 | 93 | 94 | 95 | def benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 96 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cuda') 97 | times = [] 98 | torch.cuda.synchronize('cuda') 99 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 100 | enable_timing=True) 101 | 102 | for _ in range(n_trials): 103 | starter.record() 104 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 105 | ender.record() 106 | torch.cuda.synchronize('cuda') 107 | delta_time = starter.elapsed_time(ender) 108 | times.append(delta_time / 1000.0) 109 | return times 110 | 111 | 112 | def benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 113 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cpu') 114 | 115 | X1.requires_grad = True 116 | X2.requires_grad = True 117 | times = [] 118 | for _ in range(n_trials): 119 | begin = time.time() 120 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 121 | output.backward(gradient=torch.ones_like(output)) 122 | times.append(time.time() - begin) 123 | return np.array(times) 124 | 125 | 126 | def benchmark_forward_e3nn_gpu(BATCH_SIZE, N_FEATURES, n_trials): 127 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, 2, "cuda") 128 | X1 = X1.reshape(-1) 129 | X2 = X2.reshape(-1) 130 | 131 | times = [] 132 | torch.cuda.synchronize("cuda") 133 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 134 | enable_timing=True 135 | ) 136 | 137 | for _ in range(n_trials): 138 | starter.record() 139 | output = e3nn_transformation(X1, X2) 140 | ender.record() 141 | torch.cuda.synchronize("cuda") 142 | delta_time = starter.elapsed_time(ender) 
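# torch.cuda.Event.elapsed_time() reports milliseconds; the division below converts it to seconds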
143 | times.append(delta_time / 1000.0) 144 | return times 145 | 146 | 147 | 148 | def benchmark_forward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 149 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 150 | times = [] 151 | torch.cuda.synchronize("cuda") 152 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 153 | enable_timing=True 154 | ) 155 | 156 | for _ in range(n_trials): 157 | starter.record() 158 | output = sparse_accumulation_matrix_multiply( 159 | X1, X2, transformation 160 | ) 161 | ender.record() 162 | torch.cuda.synchronize("cuda") 163 | delta_time = starter.elapsed_time(ender) 164 | times.append(delta_time / 1000.0) 165 | return times 166 | 167 | def benchmark_forward_sparse_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 168 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 169 | times = [] 170 | torch.cuda.synchronize("cuda") 171 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 172 | enable_timing=True 173 | ) 174 | 175 | for _ in range(n_trials): 176 | starter.record() 177 | output = sparse_accumulation_sparse_matrix_multiply( 178 | X1, X2, transformation_sparse 179 | ) 180 | ender.record() 181 | torch.cuda.synchronize("cuda") 182 | delta_time = starter.elapsed_time(ender) 183 | times.append(delta_time / 1000.0) 184 | #print(output.shape) 185 | return times 186 | 187 | def benchmark_forward_sparse_matrix_multiply_optimized_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 188 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 189 | times = [] 190 | torch.cuda.synchronize("cuda") 191 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 192 | enable_timing=True 193 | ) 194 | 195 | for _ in range(n_trials): 196 | starter.record() 197 | output = sparse_accumulation_sparse_matrix_multiply_optimized( 198 | X1, X2, transformation_sparse 199 | ) 200 | ender.record() 201 | torch.cuda.synchronize("cuda") 202 | delta_time = starter.elapsed_time(ender) 203 | times.append(delta_time / 1000.0) 204 | #print(output.shape) 205 | return times 206 | 207 | 208 | def benchmark_backward_e3nn_gpu(BATCH_SIZE, N_FEATURES, n_trials): 209 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, 2, "cuda") 210 | X1 = X1.reshape(-1) 211 | X2 = X2.reshape(-1) 212 | X1.requires_grad = True 213 | X2.requires_grad = True 214 | 215 | times = [] 216 | torch.cuda.synchronize("cuda") 217 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 218 | enable_timing=True 219 | ) 220 | 221 | for _ in range(n_trials): 222 | starter.record() 223 | 224 | output = e3nn_transformation(X1, X2) 225 | torch.cuda.synchronize("cuda") 226 | starter.record() 227 | output.backward(gradient=torch.ones_like(output)) 228 | 229 | ender.record() 230 | torch.cuda.synchronize("cuda") 231 | delta_time = starter.elapsed_time(ender) 232 | times.append(delta_time / 1000.0) 233 | 234 | return times 235 | 236 | 237 | def benchmark_backward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 238 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 239 | 240 | X1.requires_grad = True 241 | X2.requires_grad = True 242 | times = [] 243 | 244 | torch.cuda.synchronize("cuda") 245 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 246 | enable_timing=True 247 | ) 248 | 249 | for _ in range(n_trials): 250 | output = sparse_accumulation_matrix_multiply( 251 | X1, X2, transformation 252 | ) 253 | torch.cuda.synchronize("cuda") 254 | 
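# only the backward pass is timed here: the forward pass and a synchronize happen before recording starts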
starter.record() 255 | output.backward(gradient=torch.ones_like(output)) 256 | 257 | ender.record() 258 | torch.cuda.synchronize("cuda") 259 | delta_time = starter.elapsed_time(ender) 260 | times.append(delta_time / 1000.0) 261 | return np.array(times) 262 | 263 | def benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 264 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cuda') 265 | 266 | X1.requires_grad = True 267 | X2.requires_grad = True 268 | times = [] 269 | 270 | torch.cuda.synchronize('cuda') 271 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 272 | enable_timing=True) 273 | 274 | for _ in range(n_trials): 275 | starter.record() 276 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 277 | output.backward(gradient=torch.ones_like(output)) 278 | ender.record() 279 | torch.cuda.synchronize('cuda') 280 | delta_time = starter.elapsed_time(ender) 281 | times.append(delta_time / 1000.0) 282 | return np.array(times) 283 | 284 | 285 | def benchmark_backward_gpu_cuda(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 286 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cuda') 287 | 288 | X1.requires_grad = True 289 | X2.requires_grad = True 290 | times = [] 291 | 292 | torch.cuda.synchronize('cuda') 293 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 294 | enable_timing=True) 295 | 296 | for _ in range(n_trials): 297 | starter.record() 298 | output = function(torch.ones((X1.size()[0], X1.size()[1], 2 * L_MAX + 1),dtype=X1.dtype,device='cuda'),X1, X2, mu_aligned, m1_aligned, m2_aligned, multipliers) 299 | #output.backward(gradient=torch.ones_like(output)) 300 | ender.record() 301 | torch.cuda.synchronize('cuda') 302 | delta_time = starter.elapsed_time(ender) 303 | times.append(delta_time / 1000.0) 304 | return np.array(times) 305 | 306 | 307 | def get_func_fixed_dim(func, active_dim): 308 | def func_fixed_dim(*args): 309 | return func(*args, active_dim = active_dim) 310 | return func_fixed_dim 311 | 312 | 313 | print() 314 | print("*************") 315 | print("GPU benchmarks") 316 | print("*************") 317 | 318 | m1_aligned = m1_aligned.cuda() 319 | m2_aligned = m2_aligned.cuda() 320 | mu_aligned = mu_aligned.cuda() 321 | multipliers = multipliers.cuda() 322 | transformation = transformation.cuda() 323 | transformation_sparse = transformation_sparse.cuda() 324 | e3nn_transformation = e3nn_transformation.cuda() 325 | print() 326 | print("***forward***") 327 | print() 328 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 0, 329 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 330 | print("python loops; active dim 0; forward; cuda: ", np.mean(times[1:])) 331 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 0, 332 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 333 | print("torch index_add_; active dim 0; forward; cuda: ", np.mean(times[1:])) 334 | 335 | print() 336 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 1, 337 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 338 | print("python loops; active dim 1; forward; cuda: ", np.mean(times[1:])) 339 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 1, 340 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 341 | print("torch index_add_; active dim 1; forward; cuda: ", np.mean(times[1:])) 342 | 343 | print() 344 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 345 | get_func_fixed_dim(sparse_accumulation_loops, 2), 10) 346 | 
print("python loops; active dim 2; forward; cuda: ", np.mean(times[1:])) 347 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 348 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 349 | print("torch index_add_; active dim 2; forward; cuda: ", np.mean(times[1:])) 350 | 351 | #print("torch index_add_; active dim 2; forward; cuda: ", np.mean(times[1:])) 352 | times = benchmark_forward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, 2, 10) 353 | print("dense matrix multiply: ", np.mean(times[1:])) 354 | 355 | times = benchmark_forward_sparse_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, 2, 10) 356 | print("sparse matrix multiply; active dim 2; forward; cuda: ", np.mean(times[1:])) 357 | 358 | times = benchmark_forward_sparse_matrix_multiply_optimized_gpu(BATCH_SIZE, N_FEATURES, 0, 10) 359 | print("sparse matrix optimized multiply; active dim 0; forward; cuda: ", np.mean(times[1:])) 360 | 361 | times = benchmark_forward_e3nn_gpu(BATCH_SIZE, N_FEATURES, 10) 362 | print("e3nn: ", np.mean(times[1:])) 363 | 364 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 365 | sparse_accumulation_cuda.forward, 10) 366 | print("CUDA kernel; active dim 2; forward; cuda: ", np.mean(times[1:])) 367 | 368 | 369 | print() 370 | print("***backward***") 371 | print() 372 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 0, 373 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 374 | print("python loops; active dim 0; backward; cuda: ", np.mean(times[1:])) 375 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 0, 376 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 377 | print("torch index_add_; active dim 0; backward; cuda: ", np.mean(times[1:])) 378 | 379 | print() 380 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 1, 381 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 382 | print("python loops; active dim 1; backward; cuda: ", np.mean(times[1:])) 383 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 1, 384 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 385 | print("torch index_add_; active dim 1; backward; cuda: ", np.mean(times[1:])) 386 | 387 | print() 388 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 2, 389 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 390 | print("python loops; active dim 2; backward; cuda: ", np.mean(times[1:])) 391 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 2, 392 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 393 | print("torch index_add_; active dim 2; backward; cuda: ", np.mean(times[1:])) 394 | 395 | times = benchmark_backward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, 2, 10) 396 | print("dense matrix multiply: ", np.mean(times[1:])) 397 | 398 | 399 | times = benchmark_backward_e3nn_gpu(BATCH_SIZE, N_FEATURES, 10) 400 | print("e3nn backward: ", np.mean(times[1:])) 401 | 402 | 403 | times = benchmark_backward_gpu_cuda(BATCH_SIZE, N_FEATURES, 2, 404 | sparse_accumulation_cuda.backward, 10) 405 | print("CUDA kernel; active dim 2; backward; cuda: ", np.mean(times[1:])) 406 | 407 | 408 | -------------------------------------------------------------------------------- /benchmarks/benchmark_real_data.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import ase.io 4 | import rascaline 5 | import torch 6 | import torch.utils.cpp_extension 7 | from equistore import TensorBlock, TensorMap 8 | 9 | from clebsch_gordan import ClebschGordan, get_real_clebsch_gordan 10 | from 
sparse_accumulation_plain_torch import sparse_accumulation_index_add 11 | 12 | torch.utils.cpp_extension.load( 13 | name="sparse_accumulation_cuda", 14 | sources=["cuda_optimized/sparse_accumulation_cuda_kernel2D.cu"], 15 | is_python_module=False, 16 | extra_cuda_cflags=None, 17 | verbose=True, 18 | ) 19 | 20 | L_MAX = 8 21 | HYPERS = { 22 | "cutoff": 3.5, 23 | "max_radial": 10, 24 | "max_angular": L_MAX - 1, 25 | "atomic_gaussian_width": 0.3, 26 | "radial_basis": {"Gto": {}}, 27 | "center_atom_weight": 1.0, 28 | "cutoff_function": {"ShiftedCosine": {"width": 0.5}}, 29 | } 30 | 31 | 32 | def descriptor_to_torch(descriptor, device, dtype): 33 | blocks = [] 34 | for _, block in descriptor: 35 | new_block = TensorBlock( 36 | values=torch.tensor(block.values, device=device, dtype=dtype), 37 | samples=block.samples, 38 | components=block.components, 39 | properties=block.properties, 40 | ) 41 | 42 | for parameter in block.gradients_list(): 43 | gradient = block.gradient(parameter) 44 | new_block.add_gradient( 45 | parameter, 46 | data=torch.tensor(block.values, device=device, dtype=dtype), 47 | samples=gradient.samples, 48 | components=gradient.components, 49 | ) 50 | 51 | blocks.append(new_block) 52 | 53 | return TensorMap(descriptor.keys, blocks) 54 | 55 | 56 | def descriptor_to_torch_list(descriptor, device, dtype): 57 | data = [] 58 | for (l,), block in descriptor: 59 | data.append(((l,), torch.tensor(block.values, device=device, dtype=dtype))) 60 | return data 61 | 62 | 63 | def descriptor_to_torch_list_active_dim_last(descriptor, device, dtype): 64 | data = [] 65 | for (l,), block in descriptor: 66 | values = torch.tensor(block.values, device=device, dtype=dtype) 67 | data.append(((l,), values.swapaxes(1, 2).contiguous())) 68 | return data 69 | 70 | 71 | def create_real_data(path, selection): 72 | frames = ase.io.read(path, selection) 73 | 74 | calculator = rascaline.SphericalExpansion(**HYPERS) 75 | descriptor = calculator.compute(frames) 76 | descriptor.keys_to_samples("species_center") 77 | descriptor.keys_to_properties("species_neighbor") 78 | return descriptor 79 | 80 | 81 | def generate_cg_multipliers(device, dtype): 82 | precomputed = {} 83 | clebsch = ClebschGordan(L_MAX).precomputed_ 84 | 85 | for l1 in range(L_MAX): 86 | for l2 in range(L_MAX): 87 | for l in range(L_MAX): 88 | indices = get_real_clebsch_gordan(clebsch[l1, l2, l], l1, l2, l) 89 | m1_aligned = [] 90 | m2_aligned = [] 91 | mu_aligned = [] 92 | multipliers = [] 93 | 94 | for mu in range(0, 2 * l + 1): 95 | for el in indices[mu]: 96 | m1, m2, multiplier = el 97 | m1_aligned.append(m1) 98 | m2_aligned.append(m2) 99 | multipliers.append(multiplier) 100 | mu_aligned.append(mu) 101 | 102 | m1 = torch.tensor(m1_aligned, device=device, dtype=torch.int64) 103 | m2 = torch.tensor(m2_aligned, device=device, dtype=torch.int64) 104 | mu = torch.tensor(mu_aligned, device=device, dtype=torch.int64) 105 | 106 | multipliers = torch.tensor(multipliers, device=device, dtype=dtype) 107 | 108 | precomputed[(l1, l2, l)] = (m1, m2, mu, multipliers) 109 | 110 | return precomputed 111 | 112 | 113 | def run_cg_combine(function, x1, x2, precomputed_cg): 114 | for (l1,), spx1 in x1: 115 | for (l2,), spx2 in x2: 116 | for l in range(L_MAX): 117 | m1, m2, mu, multipliers = precomputed_cg[(l1, l2, l)] 118 | output = function(spx1, spx2, mu, 2 * l + 1, m1, m2, multipliers) 119 | 120 | 121 | def bench_cg_combine(function, x1, x2, precomputed_cg, n_iters=10): 122 | run_cg_combine(function, x1, x2, precomputed_cg) 123 | 124 | start = time.time() 125 
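# the untimed call above acts as a warm-up before the measured loop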
| for _ in range(n_iters): 126 | run_cg_combine(function, x1, x2, precomputed_cg) 127 | 128 | elapsed = time.time() - start 129 | return elapsed / n_iters 130 | 131 | 132 | if __name__ == "__main__": 133 | device = "cuda" 134 | dtype = torch.float64 135 | 136 | print(f"\n\nrunning on {device=}, {dtype=}") 137 | 138 | # descriptor = create_real_data("molecular_crystals.xyz", ":30") 139 | descriptor = create_real_data("random-methane-10k.extxyz", ":300") 140 | 141 | precomputed_cg = generate_cg_multipliers(device, dtype) 142 | 143 | print(f"done loading data\n") 144 | 145 | def index_add_impl(x1, x2, mu, size, m1, m2, multipliers): 146 | return sparse_accumulation_index_add( 147 | x1, x2, mu, size, m1, m2, multipliers, active_dim=1 148 | ) 149 | 150 | x = descriptor_to_torch_list(descriptor, device, dtype) 151 | timing = bench_cg_combine(index_add_impl, x, x, precomputed_cg) 152 | print(f"index_add: {1e3 * timing:.5} ms") 153 | 154 | x = descriptor_to_torch_list_active_dim_last(descriptor, device, dtype) 155 | timing = bench_cg_combine( 156 | torch.ops.sparse_accumulation_cuda.forward, x, x, precomputed_cg 157 | ) 158 | print(f"Custom CUDA: {1e3 * timing:.5} ms") 159 | -------------------------------------------------------------------------------- /benchmarks/old/benchmark_contiguous.out: -------------------------------------------------------------------------------- 1 | Using /home/pozdn/.cache/torch_extensions/py38_cu111 as PyTorch extensions root... 2 | Detected CUDA files, patching ldflags 3 | Emitting ninja build file /home/pozdn/.cache/torch_extensions/py38_cu111/sparse_accumulation_cuda/build.ninja... 4 | Building extension module sparse_accumulation_cuda... 5 | Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N) 6 | ninja: no work to do. 7 | Loading extension module sparse_accumulation_cuda... 
8 | L_MAX=8; BATCH_SIZE=1000; N_FEATURES=200 9 | preparing real life transformation rule 10 | transformation rule is computed 11 | ************* 12 | CUDA 13 | ************* 14 | ***forward*** 15 | python loops; active dim 0; forward; cuda: 0.02626611942715115 16 | torch index_add_; active dim 0; forward; cuda: 0.0067368142869737415 17 | 18 | python loops; active dim 1; forward; cuda: 0.028479398939344615 19 | torch index_add_; active dim 1; forward; cuda: 0.005996817800733778 20 | 21 | python loops first; active dim 2; forward; cuda: 0.09641836547851562 22 | torch index_add_ first; active dim 2; forward; cuda: 0.006406335830688476 23 | CUDA kernel first; active dim 2; forward; cuda: 0.0005303679704666138 24 | 25 | python loops; active dim 2; forward; cuda: 0.09575971052381728 26 | torch index_add_; active dim 2; forward; cuda: 0.006397336800893148 27 | CUDA kernel; active dim 2; forward; cuda: 0.0004992959996064504 28 | ***backward*** 29 | python loops; active dim 0; backward; cuda: 0.2310748816596137 30 | torch index_add_; active dim 0; backward; cuda: 0.015111829439798989 31 | 32 | python loops; active dim 1; backward; cuda: 0.3074806416829427 33 | torch index_add_; active dim 1; backward; cuda: 0.015122151056925456 34 | 35 | python loops; active dim 2; backward; cuda: 0.10060705057779948 36 | torch index_add_; active dim 2; backward; cuda: 0.10064567142062716 37 | CUDA kernel; active dim 2; backward; cuda: 0.001021127118004693 38 | ************* 39 | CPU 40 | ************* 41 | ***forward*** 42 | python loops; active dim 0; forward; cpu: 0.10445965660942926 43 | torch index_add_; active dim 0; forward; cpu: 0.4440735975901286 44 | cpp; active dim 0; forward; cpu: 0.04300083054436578 45 | 46 | python loops; active dim 1; forward; cpu: 0.13819887903001574 47 | torch index_add_; active dim 1; forward; cpu: 0.4437105390760634 48 | cpp; active dim 1; forward; cpu: 0.0679466724395752 49 | 50 | python loops; active dim 2; forward; cpu: 1.0245144102308485 51 | torch index_add_; active dim 2; forward; cpu: 1.4955637719896104 52 | cpp; active dim 2; forward; cpu 0.13358383708530003 53 | ***backward*** 54 | python loops; active dim 0; backward; cpu 6.009502304924859 55 | torch index_add_; active dim 0; backward; cpu 0.8905059761471219 56 | cpp; active dim 0; backward; cpu 0.1339370674557156 57 | 58 | python loops; active dim 1; backward; cpu 8.616596115960014 59 | torch index_add_; active dim 1; backward; cpu 0.9371486769782172 60 | cpp; active dim 1; backward; cpu 0.16515795389811197 61 | 62 | python loops; active dim 2; backward; cpu 3.100877947277493 63 | torch index_add_; active dim 2; backward; cpu 3.103803899553087 64 | cpp; active dim 2; backward; cpu 0.2753954463534885 65 | -------------------------------------------------------------------------------- /benchmarks/old/benchmark_contiguous.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | torch.set_num_threads(1) 4 | from clebsch_gordan import get_real_clebsch_gordan, ClebschGordan 5 | from sparse_accumulation_plain_torch import sparse_accumulation_loops, sparse_accumulation_index_add 6 | from sparse_accumulation_cpu import sparse_accumulation_active_dim_last, sparse_accumulation_active_dim_first, sparse_accumulation_active_dim_middle 7 | import numpy as np 8 | from torch.utils import cpp_extension 9 | 10 | cpp_extension.load( 11 | name="sparse_accumulation_cuda", 12 | sources=["cuda_optimized/sparse_accumulation_cuda_kernel2D.cu"], 13 | is_python_module=False, 14 | 
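# JIT-compiles the CUDA kernel with ninja; with is_python_module=False the operators are exposed as torch.ops.sparse_accumulation_cuda rather than as a Python module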
extra_cuda_cflags=None, 15 | verbose=True, 16 | ) 17 | 18 | L_MAX = 8 19 | BATCH_SIZE = 1000 20 | N_FEATURES = 200 21 | print(F"{L_MAX=}; {BATCH_SIZE=}; {N_FEATURES=}") 22 | print("preparing real life transformation rule") 23 | 24 | clebsch = ClebschGordan(L_MAX).precomputed_ 25 | indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX) 26 | 27 | m1_aligned, m2_aligned = [], [] 28 | multipliers, mu_aligned = [], [] 29 | for mu in range(0, 2 * L_MAX + 1): 30 | for el in indices[mu]: 31 | m1, m2, multiplier = el 32 | m1_aligned.append(m1) 33 | m2_aligned.append(m2) 34 | multipliers.append(multiplier) 35 | mu_aligned.append(mu) 36 | m1_aligned = torch.LongTensor(m1_aligned) 37 | m2_aligned = torch.LongTensor(m2_aligned) 38 | mu_aligned = torch.LongTensor(mu_aligned) 39 | multipliers = torch.FloatTensor(multipliers) 40 | 41 | indices = np.argsort(mu_aligned) 42 | 43 | m1_aligned = m1_aligned[indices] 44 | m2_aligned = m2_aligned[indices] 45 | mu_aligned = mu_aligned[indices] 46 | multipliers = multipliers[indices] 47 | 48 | print("transformation rule is computed") 49 | 50 | def get_input(BATCH_SIZE, N_FEATURES, active_dim, device): 51 | if active_dim == 0: 52 | X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device = device) 53 | X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device = device) 54 | 55 | if active_dim == 1: 56 | X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device = device) 57 | X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device = device) 58 | 59 | if active_dim == 2: 60 | X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device = device) 61 | X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device = device) 62 | 63 | 64 | if (active_dim != 0) and (active_dim != 2) and (active_dim != 1): 65 | raise ValueError("active dim should be one of 0, 1, 2") 66 | 67 | return X1, X2 68 | 69 | def benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 70 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cpu') 71 | times = [] 72 | 73 | for _ in range(n_trials): 74 | begin = time.time() 75 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 76 | times.append(time.time() - begin) 77 | return times 78 | 79 | 80 | def benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 81 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cuda') 82 | times = [] 83 | torch.cuda.synchronize('cuda') 84 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 85 | enable_timing=True) 86 | 87 | for _ in range(n_trials): 88 | starter.record() 89 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 90 | ender.record() 91 | torch.cuda.synchronize('cuda') 92 | delta_time = starter.elapsed_time(ender) 93 | times.append(delta_time / 1000.0) 94 | return times 95 | 96 | 97 | def benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 98 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cpu') 99 | 100 | X1.requires_grad = True 101 | X2.requires_grad = True 102 | times = [] 103 | for _ in range(n_trials): 104 | begin = time.time() 105 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 106 | output.backward(gradient=torch.ones_like(output)) 107 | times.append(time.time() - begin) 108 | return np.array(times) 109 | 110 | def benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 111 | X1, X2 = 
get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cuda') 112 | 113 | X1.requires_grad = True 114 | X2.requires_grad = True 115 | times = [] 116 | 117 | torch.cuda.synchronize('cuda') 118 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 119 | enable_timing=True) 120 | 121 | for _ in range(n_trials): 122 | starter.record() 123 | output = function(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 124 | output.backward(gradient=torch.ones_like(output)) 125 | ender.record() 126 | torch.cuda.synchronize('cuda') 127 | delta_time = starter.elapsed_time(ender) 128 | times.append(delta_time / 1000.0) 129 | return np.array(times) 130 | 131 | 132 | def benchmark_backward_gpu_cuda(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 133 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, 'cuda') 134 | 135 | X1.requires_grad = True 136 | X2.requires_grad = True 137 | times = [] 138 | 139 | torch.cuda.synchronize('cuda') 140 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 141 | enable_timing=True) 142 | 143 | for _ in range(n_trials): 144 | starter.record() 145 | output = function(torch.ones((X1.size()[0], X1.size()[1], 2 * L_MAX + 1),dtype=X1.dtype,device='cuda'),X1, X2, mu_aligned, m1_aligned, m2_aligned, multipliers) 146 | #output.backward(gradient=torch.ones_like(output)) 147 | ender.record() 148 | torch.cuda.synchronize('cuda') 149 | delta_time = starter.elapsed_time(ender) 150 | times.append(delta_time / 1000.0) 151 | return np.array(times) 152 | 153 | 154 | # In[5]: 155 | 156 | 157 | def get_func_fixed_dim(func, active_dim): 158 | def func_fixed_dim(*args): 159 | return func(*args, active_dim = active_dim) 160 | return func_fixed_dim 161 | 162 | 163 | print("*************") 164 | print("CUDA") 165 | print("*************") 166 | 167 | m1_aligned = m1_aligned.cuda() 168 | m2_aligned = m2_aligned.cuda() 169 | mu_aligned = mu_aligned.cuda() 170 | multipliers = multipliers.cuda() 171 | 172 | print("***forward***") 173 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 0, 174 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 175 | print("python loops; active dim 0; forward; cuda: ", np.mean(times[1:])) 176 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 0, 177 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 178 | print("torch index_add_; active dim 0; forward; cuda: ", np.mean(times[1:])) 179 | '''times = benchmark_forward(BATCH_SIZE, N_FEATURES, 0, 180 | sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 181 | print("cpp; active dim 0; forward: ", np.mean(times[1:]))''' 182 | 183 | print() 184 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 1, 185 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 186 | print("python loops; active dim 1; forward; cuda: ", np.mean(times[1:])) 187 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 1, 188 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 189 | print("torch index_add_; active dim 1; forward; cuda: ", np.mean(times[1:])) 190 | '''times = benchmark_forward(BATCH_SIZE, N_FEATURES, 1, 191 | sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 192 | print("cpp; active dim 1; forward: ", np.mean(times[1:]))''' 193 | 194 | 195 | print() 196 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 197 | get_func_fixed_dim(sparse_accumulation_loops, 2), 1) 198 | print("python loops first; active dim 2; forward; cuda: ", np.mean(times)) 199 | times = 
benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 200 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 1) 201 | print("torch index_add_ first; active dim 2; forward; cuda: ", np.mean(times)) 202 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 203 | torch.ops.sparse_accumulation_cuda.forward, 1) 204 | print("CUDA kernel first; active dim 2; forward; cuda: ", np.mean(times)) 205 | '''times = benchmark_forward(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation.SparseAccumulation.apply, 10) 206 | print("cpp; active dim 2; forward: ", np.mean(times[1:]))''' 207 | 208 | print() 209 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 210 | get_func_fixed_dim(sparse_accumulation_loops, 2), 10) 211 | print("python loops; active dim 2; forward; cuda: ", np.mean(times[1:])) 212 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 213 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 214 | print("torch index_add_; active dim 2; forward; cuda: ", np.mean(times[1:])) 215 | times = benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, 2, 216 | torch.ops.sparse_accumulation_cuda.forward, 10) 217 | print("CUDA kernel; active dim 2; forward; cuda: ", np.mean(times[1:])) 218 | '''times = benchmark_forward(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation.SparseAccumulation.apply, 10) 219 | print("cpp; active dim 2; forward: ", np.mean(times[1:]))''' 220 | 221 | # In[10]: 222 | 223 | print("***backward***") 224 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 0, 225 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 226 | print("python loops; active dim 0; backward; cuda: ", np.mean(times[1:])) 227 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 0, 228 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 229 | print("torch index_add_; active dim 0; backward; cuda: ", np.mean(times[1:])) 230 | '''times = benchmark_backward(BATCH_SIZE, N_FEATURES, 0, 231 | sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 232 | print("cpp; active dim 0; backward: ", np.mean(times[1:]))''' 233 | 234 | print() 235 | 236 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 1, 237 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 238 | print("python loops; active dim 1; backward; cuda: ", np.mean(times[1:])) 239 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 1, 240 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 241 | print("torch index_add_; active dim 1; backward; cuda: ", np.mean(times[1:])) 242 | '''times = benchmark_backward(BATCH_SIZE, N_FEATURES, 1, 243 | sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 244 | print("cpp; active dim 1; backward: ", np.mean(times[1:]))''' 245 | 246 | 247 | print() 248 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 2, 249 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 250 | print("python loops; active dim 2; backward; cuda: ", np.mean(times[1:])) 251 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 2, 252 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 253 | print("torch index_add_; active dim 2; backward; cuda: ", np.mean(times[1:])) 254 | times = benchmark_backward_gpu_cuda(BATCH_SIZE, N_FEATURES, 2, 255 | torch.ops.sparse_accumulation_cuda.backward, 10) 256 | print("CUDA kernel; active dim 2; backward; cuda: ", np.mean(times[1:])) 257 | '''times = benchmark_backward(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation.SparseAccumulation.apply, 10) 258 | print("cpp; active dim 2; backward: ", 
np.mean(times[1:]))''' 259 | 260 | 261 | print("*************") 262 | print("CPU") 263 | print("*************") 264 | 265 | m1_aligned = m1_aligned.cpu() 266 | m2_aligned = m2_aligned.cpu() 267 | mu_aligned = mu_aligned.cpu() 268 | multipliers = multipliers.cpu() 269 | print("***forward***") 270 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 0, 271 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 272 | print("python loops; active dim 0; forward; cpu: ", np.mean(times[1:])) 273 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 0, 274 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 275 | print("torch index_add_; active dim 0; forward; cpu: ", np.mean(times[1:])) 276 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 0, 277 | sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 278 | print("cpp; active dim 0; forward; cpu: ", np.mean(times[1:])) 279 | 280 | print() 281 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 1, 282 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 283 | print("python loops; active dim 1; forward; cpu: ", np.mean(times[1:])) 284 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 1, 285 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 286 | print("torch index_add_; active dim 1; forward; cpu: ", np.mean(times[1:])) 287 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 1, 288 | sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 289 | print("cpp; active dim 1; forward; cpu: ", np.mean(times[1:])) 290 | 291 | print() 292 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_loops, 2), 10) 293 | print("python loops; active dim 2; forward; cpu:", np.mean(times[1:])) 294 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 295 | print("torch index_add_; active dim 2; forward; cpu: ", np.mean(times[1:])) 296 | times = benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation_active_dim_last.SparseAccumulationActiveDimLast.apply, 10) 297 | print("cpp; active dim 2; forward; cpu ", np.mean(times[1:])) 298 | print("***backward***") 299 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 0, 300 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 301 | print("python loops; active dim 0; backward; cpu ", np.mean(times[1:])) 302 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 0, 303 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 304 | print("torch index_add_; active dim 0; backward; cpu ", np.mean(times[1:])) 305 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 0, 306 | sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 307 | print("cpp; active dim 0; backward; cpu ", np.mean(times[1:])) 308 | 309 | print() 310 | 311 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 1, 312 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 313 | print("python loops; active dim 1; backward; cpu ", np.mean(times[1:])) 314 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 1, 315 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 316 | print("torch index_add_; active dim 1; backward; cpu ", np.mean(times[1:])) 317 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 1, 318 | sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 319 | print("cpp; active dim 1; backward; cpu ", np.mean(times[1:])) 320 | 321 | 322 | print() 323 | 
times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 2, 324 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 325 | print("python loops; active dim 2; backward; cpu ", np.mean(times[1:])) 326 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 2, 327 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 328 | print("torch index_add_; active dim 2; backward; cpu ", np.mean(times[1:])) 329 | times = benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, 2, sparse_accumulation_active_dim_last.SparseAccumulationActiveDimLast.apply, 10) 330 | print("cpp; active dim 2; backward; cpu ", np.mean(times[1:])) -------------------------------------------------------------------------------- /benchmarks/old/benchmark_contiguous_tmp.out: -------------------------------------------------------------------------------- 1 | Using /home/pozdn/.cache/torch_extensions/py38_cu111 as PyTorch extensions root... 2 | Detected CUDA files, patching ldflags 3 | Emitting ninja build file /home/pozdn/.cache/torch_extensions/py38_cu111/sparse_accumulation_cuda/build.ninja... 4 | Building extension module sparse_accumulation_cuda... 5 | Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N) 6 | ninja: no work to do. 7 | Loading extension module sparse_accumulation_cuda... 8 | transformation rule is computed 9 | L_MAX=5; BATCH_SIZE=10000; N_FEATURES=200; sparse dim length = 11; sparse indices length = 126 10 | preparing real life transformation rule 11 | ************* 12 | CUDA 13 | ************* 14 | ***forward*** 15 | python loops; active dim 0; forward; cuda: 0.025215061399671765 16 | torch index_add_; active dim 0; forward; cuda: 0.026948707580566406 17 | 18 | python loops; active dim 1; forward; cuda: 0.03350805706448025 19 | torch index_add_; active dim 1; forward; cuda: 0.02625405163235135 20 | 21 | python loops; active dim 2; forward; cuda: 0.19782067362467448 22 | torch index_add_; active dim 2; forward; cuda: 0.02156950018141005 23 | dense matrix multiply: 0.007511555565728082 24 | sparse matrix multiply; active dim 2; forward; cuda: 0.17568610127766926 25 | sparse matrix optimized multiply; active dim 0; forward; cuda: 0.2286965043809679 26 | CUDA kernel; active dim 2; forward; cuda: 0.0018692764573627048 27 | ***backward*** 28 | python loops; active dim 0; backward; cuda: 0.4620384724934896 29 | torch index_add_; active dim 0; backward; cuda: 0.03216430155436198 30 | 31 | python loops; active dim 1; backward; cuda: 0.5971531100802951 32 | torch index_add_; active dim 1; backward; cuda: 0.035279142591688364 33 | 34 | python loops; active dim 2; backward; cuda: 0.309046383327908 35 | torch index_add_; active dim 2; backward; cuda: 0.3086910502115885 36 | CUDA kernel; active dim 2; backward; cuda: 0.004007651567459107 37 | dense matrix multiply: 0.023251587549845378 38 | -------------------------------------------------------------------------------- /benchmarks/old/benchmark_contiguous_tmp.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils import cpp_extension 6 | 7 | from sparse_accumulation_cpu import sparse_accumulation_active_dim_last, sparse_accumulation_active_dim_first, sparse_accumulation_active_dim_middle 8 | from clebsch_gordan import ClebschGordan, get_real_clebsch_gordan 9 | from sparse_accumulation_plain_torch import ( 10 | sparse_accumulation_index_add, 11 | sparse_accumulation_loops, 12 | get_transformation, 13 
| get_transformation_sparse, 14 | sparse_accumulation_matrix_multiply, 15 | sparse_accumulation_sparse_matrix_multiply, 16 | sparse_accumulation_sparse_matrix_multiply_optimized 17 | ) 18 | 19 | cpp_extension.load( 20 | name="sparse_accumulation_cuda", 21 | sources=["cuda_optimized/sparse_accumulation_cuda_kernel2D.cu"], 22 | is_python_module=False, 23 | extra_cuda_cflags=None, 24 | verbose=True, 25 | ) 26 | 27 | L_MAX = 5 28 | BATCH_SIZE = 10000 29 | N_FEATURES = 200 30 | 31 | clebsch = ClebschGordan(L_MAX).precomputed_ 32 | indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX) 33 | 34 | m1_aligned, m2_aligned = [], [] 35 | multipliers, mu_aligned = [], [] 36 | for mu in range(0, 2 * L_MAX + 1): 37 | for el in indices[mu]: 38 | m1, m2, multiplier = el 39 | m1_aligned.append(m1) 40 | m2_aligned.append(m2) 41 | multipliers.append(multiplier) 42 | mu_aligned.append(mu) 43 | m1_aligned = torch.LongTensor(m1_aligned) 44 | m2_aligned = torch.LongTensor(m2_aligned) 45 | mu_aligned = torch.LongTensor(mu_aligned) 46 | multipliers = torch.FloatTensor(multipliers) 47 | 48 | indices = np.argsort(mu_aligned) 49 | 50 | m1_aligned = m1_aligned[indices] 51 | m2_aligned = m2_aligned[indices] 52 | mu_aligned = mu_aligned[indices] 53 | multipliers = multipliers[indices] 54 | 55 | 56 | transformation = get_transformation(mu_aligned, 2 * L_MAX + 1, 2 * L_MAX + 1, 2 * L_MAX + 1, 57 | m1_aligned, m2_aligned, multipliers) 58 | transformation_sparse = get_transformation_sparse(mu_aligned, 2 * L_MAX + 1, 2 * L_MAX + 1, 2 * L_MAX + 1, 59 | m1_aligned, m2_aligned, multipliers) 60 | 61 | print("transformation rule is computed") 62 | 63 | 64 | print( 65 | f"{L_MAX=}; {BATCH_SIZE=}; {N_FEATURES=}; sparse dim length = {2 * L_MAX + 1}; sparse indices length = {multipliers.shape[0]}" 66 | ) 67 | print("preparing real life transformation rule") 68 | 69 | USE_FLOAT64 = False 70 | if USE_FLOAT64: 71 | multipliers = multipliers.to(dtype=torch.float64) 72 | torch.set_default_dtype(torch.float64) 73 | 74 | 75 | def get_input(BATCH_SIZE, N_FEATURES, active_dim, device): 76 | if active_dim == 0: 77 | X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device=device) 78 | X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES, device=device) 79 | 80 | if active_dim == 1: 81 | X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device=device) 82 | X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES, device=device) 83 | 84 | if active_dim == 2: 85 | X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device=device) 86 | X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1, device=device) 87 | 88 | if (active_dim != 0) and (active_dim != 2) and (active_dim != 1): 89 | raise ValueError("active dim should be one of 0, 1, 2") 90 | 91 | return X1, X2 92 | 93 | 94 | def benchmark_forward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 95 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cpu") 96 | times = [] 97 | 98 | for _ in range(n_trials): 99 | begin = time.time() 100 | output = function( 101 | X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers 102 | ) 103 | times.append(time.time() - begin) 104 | return times 105 | 106 | 107 | def benchmark_forward_gpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 108 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 109 | times = [] 110 | torch.cuda.synchronize("cuda") 111 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 112 | enable_timing=True 113 | ) 114 | 
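    # starter/ender are CUDA events, so the measured interval is GPU time in
    # milliseconds; dividing by 1000.0 below stores seconds in `times`.
    # The callers report np.mean(times[1:]), discarding the first trial as warm-up.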
115 | for _ in range(n_trials): 116 | starter.record() 117 | output = function( 118 | X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers 119 | ) 120 | ender.record() 121 | torch.cuda.synchronize("cuda") 122 | delta_time = starter.elapsed_time(ender) 123 | times.append(delta_time / 1000.0) 124 | return times 125 | 126 | def benchmark_forward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 127 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 128 | times = [] 129 | torch.cuda.synchronize("cuda") 130 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 131 | enable_timing=True 132 | ) 133 | 134 | for _ in range(n_trials): 135 | starter.record() 136 | output = sparse_accumulation_matrix_multiply( 137 | X1, X2, transformation 138 | ) 139 | ender.record() 140 | torch.cuda.synchronize("cuda") 141 | delta_time = starter.elapsed_time(ender) 142 | times.append(delta_time / 1000.0) 143 | return times 144 | 145 | def benchmark_forward_sparse_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 146 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 147 | times = [] 148 | torch.cuda.synchronize("cuda") 149 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 150 | enable_timing=True 151 | ) 152 | 153 | for _ in range(n_trials): 154 | starter.record() 155 | output = sparse_accumulation_sparse_matrix_multiply( 156 | X1, X2, transformation_sparse 157 | ) 158 | ender.record() 159 | torch.cuda.synchronize("cuda") 160 | delta_time = starter.elapsed_time(ender) 161 | times.append(delta_time / 1000.0) 162 | #print(output.shape) 163 | return times 164 | 165 | def benchmark_forward_sparse_matrix_multiply_optimized_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 166 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 167 | times = [] 168 | torch.cuda.synchronize("cuda") 169 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 170 | enable_timing=True 171 | ) 172 | 173 | for _ in range(n_trials): 174 | starter.record() 175 | output = sparse_accumulation_sparse_matrix_multiply_optimized( 176 | X1, X2, transformation_sparse 177 | ) 178 | ender.record() 179 | torch.cuda.synchronize("cuda") 180 | delta_time = starter.elapsed_time(ender) 181 | times.append(delta_time / 1000.0) 182 | #print(output.shape) 183 | return times 184 | 185 | 186 | def benchmark_backward_cpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 187 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cpu") 188 | 189 | X1.requires_grad = True 190 | X2.requires_grad = True 191 | times = [] 192 | for _ in range(n_trials): 193 | output = function( 194 | X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers 195 | ) 196 | begin = time.time() 197 | output.backward(gradient=torch.ones_like(output)) 198 | times.append(time.time() - begin) 199 | return np.array(times) 200 | 201 | 202 | def benchmark_backward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, active_dim, n_trials): 203 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 204 | 205 | X1.requires_grad = True 206 | X2.requires_grad = True 207 | times = [] 208 | 209 | torch.cuda.synchronize("cuda") 210 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 211 | enable_timing=True 212 | ) 213 | 214 | for _ in range(n_trials): 215 | output = sparse_accumulation_matrix_multiply( 216 | X1, X2, transformation 217 | ) 218 | torch.cuda.synchronize("cuda") 219 | starter.record() 220 | 
output.backward(gradient=torch.ones_like(output)) 221 | 222 | ender.record() 223 | torch.cuda.synchronize("cuda") 224 | delta_time = starter.elapsed_time(ender) 225 | times.append(delta_time / 1000.0) 226 | return np.array(times) 227 | 228 | def benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 229 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 230 | 231 | X1.requires_grad = True 232 | X2.requires_grad = True 233 | times = [] 234 | 235 | torch.cuda.synchronize("cuda") 236 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 237 | enable_timing=True 238 | ) 239 | 240 | for _ in range(n_trials): 241 | output = function( 242 | X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers 243 | ) 244 | torch.cuda.synchronize("cuda") 245 | starter.record() 246 | output.backward(gradient=torch.ones_like(output)) 247 | 248 | ender.record() 249 | torch.cuda.synchronize("cuda") 250 | delta_time = starter.elapsed_time(ender) 251 | times.append(delta_time / 1000.0) 252 | return np.array(times) 253 | 254 | 255 | def benchmark_backward_gpu_cuda(BATCH_SIZE, N_FEATURES, active_dim, function, n_trials): 256 | X1, X2 = get_input(BATCH_SIZE, N_FEATURES, active_dim, "cuda") 257 | 258 | X1.requires_grad = True 259 | X2.requires_grad = True 260 | times = [] 261 | 262 | torch.cuda.synchronize("cuda") 263 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event( 264 | enable_timing=True 265 | ) 266 | 267 | for _ in range(n_trials): 268 | starter.record() 269 | output = function( 270 | torch.ones( 271 | (X1.size()[0], X1.size()[1], 2 * L_MAX + 1), 272 | dtype=X1.dtype, 273 | device="cuda", 274 | ), 275 | X1, 276 | X2, 277 | mu_aligned, 278 | m1_aligned, 279 | m2_aligned, 280 | multipliers, 281 | ) 282 | # output.backward(gradient=torch.ones_like(output)) 283 | ender.record() 284 | torch.cuda.synchronize("cuda") 285 | delta_time = starter.elapsed_time(ender) 286 | times.append(delta_time / 1000.0) 287 | return np.array(times) 288 | 289 | 290 | def get_func_fixed_dim(func, active_dim): 291 | def func_fixed_dim(*args): 292 | return func(*args, active_dim=active_dim) 293 | 294 | return func_fixed_dim 295 | 296 | 297 | print("*************") 298 | print("CUDA") 299 | print("*************") 300 | 301 | m1_aligned = m1_aligned.cuda() 302 | m2_aligned = m2_aligned.cuda() 303 | mu_aligned = mu_aligned.cuda() 304 | multipliers = multipliers.cuda() 305 | transformation = transformation.cuda() 306 | transformation_sparse = transformation_sparse.cuda() 307 | print("***forward***") 308 | times = benchmark_forward_gpu( 309 | BATCH_SIZE, N_FEATURES, 0, get_func_fixed_dim(sparse_accumulation_loops, 0), 10 310 | ) 311 | print("python loops; active dim 0; forward; cuda: ", np.mean(times[1:])) 312 | times = benchmark_forward_gpu( 313 | BATCH_SIZE, N_FEATURES, 0, get_func_fixed_dim(sparse_accumulation_index_add, 0), 10 314 | ) 315 | print("torch index_add_; active dim 0; forward; cuda: ", np.mean(times[1:])) 316 | 317 | 318 | print() 319 | times = benchmark_forward_gpu( 320 | BATCH_SIZE, N_FEATURES, 1, get_func_fixed_dim(sparse_accumulation_loops, 1), 10 321 | ) 322 | print("python loops; active dim 1; forward; cuda: ", np.mean(times[1:])) 323 | times = benchmark_forward_gpu( 324 | BATCH_SIZE, N_FEATURES, 1, get_func_fixed_dim(sparse_accumulation_index_add, 1), 10 325 | ) 326 | print("torch index_add_; active dim 1; forward; cuda: ", np.mean(times[1:])) 327 | 328 | 329 | print() 330 | times = benchmark_forward_gpu( 331 | BATCH_SIZE, 
N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_loops, 2), 10 332 | ) 333 | print("python loops; active dim 2; forward; cuda: ", np.mean(times[1:])) 334 | times = benchmark_forward_gpu( 335 | BATCH_SIZE, N_FEATURES, 2, get_func_fixed_dim(sparse_accumulation_index_add, 2), 10 336 | ) 337 | print("torch index_add_; active dim 2; forward; cuda: ", np.mean(times[1:])) 338 | times = benchmark_forward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, 2, 10) 339 | print("dense matrix multiply: ", np.mean(times[1:])) 340 | 341 | times = benchmark_forward_sparse_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, 2, 10) 342 | print("sparse matrix multiply; active dim 2; forward; cuda: ", np.mean(times[1:])) 343 | 344 | times = benchmark_forward_sparse_matrix_multiply_optimized_gpu(BATCH_SIZE, N_FEATURES, 0, 10) 345 | print("sparse matrix optimized multiply; active dim 0; forward; cuda: ", np.mean(times[1:])) 346 | 347 | 348 | times = benchmark_forward_gpu( 349 | BATCH_SIZE, N_FEATURES, 2, torch.ops.sparse_accumulation_cuda.forward, 10 350 | ) 351 | print("CUDA kernel; active dim 2; forward; cuda: ", np.mean(times[1:])) 352 | # times = benchmark_forward_gpu( 353 | # BATCH_SIZE, N_FEATURES, 2, torch.ops.sparse_accumulation_cuda.forward_grpwrites, 10 354 | # ) 355 | # print("CUDA kernel grpwrites; active dim 2; forward; cuda: ", np.mean(times[1:])) 356 | 357 | print("***backward***") 358 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 0, 359 | get_func_fixed_dim(sparse_accumulation_loops, 0), 10) 360 | print("python loops; active dim 0; backward; cuda: ", np.mean(times[1:])) 361 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 0, 362 | get_func_fixed_dim(sparse_accumulation_index_add, 0), 10) 363 | print("torch index_add_; active dim 0; backward; cuda: ", np.mean(times[1:])) 364 | #times = benchmark_backward(BATCH_SIZE, N_FEATURES, 0, 365 | # sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply, 10) 366 | #print("cpp; active dim 0; backward: ", np.mean(times[1:])) 367 | 368 | print() 369 | 370 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 1, 371 | get_func_fixed_dim(sparse_accumulation_loops, 1), 10) 372 | print("python loops; active dim 1; backward; cuda: ", np.mean(times[1:])) 373 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 1, 374 | get_func_fixed_dim(sparse_accumulation_index_add, 1), 10) 375 | print("torch index_add_; active dim 1; backward; cuda: ", np.mean(times[1:])) 376 | #times = benchmark_backward(BATCH_SIZE, N_FEATURES, 1, 377 | # sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply, 10) 378 | #print("cpp; active dim 1; backward: ", np.mean(times[1:])) 379 | 380 | 381 | print() 382 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 2, 383 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 384 | print("python loops; active dim 2; backward; cuda: ", np.mean(times[1:])) 385 | times = benchmark_backward_gpu(BATCH_SIZE, N_FEATURES, 2, 386 | get_func_fixed_dim(sparse_accumulation_index_add, 2), 10) 387 | print("torch index_add_; active dim 2; backward; cuda: ", np.mean(times[1:])) 388 | times = benchmark_backward_gpu_cuda(BATCH_SIZE, N_FEATURES, 2, 389 | torch.ops.sparse_accumulation_cuda.backward, 10) 390 | print("CUDA kernel; active dim 2; backward; cuda: ", np.mean(times[1:])) 391 | 392 | times = benchmark_backward_matrix_multiply_gpu(BATCH_SIZE, N_FEATURES, 2, 10) 393 | print("dense matrix multiply: ", np.mean(times[1:])) 394 | 395 | 396 | #times = benchmark_backward(BATCH_SIZE, N_FEATURES, 2, 
sparse_accumulation.SparseAccumulation.apply, 10) 397 | #print("cpp; active dim 2; backward: ", np.mean(times[1:])) 398 | -------------------------------------------------------------------------------- /ci/docker/Dockerfile.build: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.03-py3 2 | 3 | COPY . /sparse_accumulation 4 | 5 | WORKDIR /sparse_accumulation 6 | RUN pip install sympy -------------------------------------------------------------------------------- /ci/pipeline.yml: -------------------------------------------------------------------------------- 1 | include: 2 | - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml' 3 | 4 | stages: 5 | - build 6 | - test 7 | 8 | variables: 9 | PERSIST_IMAGE_NAME: $CSCS_REGISTRY_PATH/sparse_accumulation:$CI_COMMIT_REF_NAME 10 | 11 | build_job: 12 | stage: build 13 | extends: .container-builder 14 | variables: 15 | DOCKERFILE: ci/docker/Dockerfile.build 16 | 17 | test_job: 18 | stage: test 19 | extends: .container-runner-daint-gpu 20 | image: $PERSIST_IMAGE_NAME 21 | script: 22 | - cd /sparse_accumulation 23 | - python3 -m pip install . --user 24 | - python3 -m pytest tests/test_cpp_contiguous.py 25 | variables: 26 | SLURM_JOB_NUM_NODES: 1 27 | SLURM_PARTITION: normal 28 | SLURM_NTASKS: 1 -------------------------------------------------------------------------------- /docs/CG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-cosmo/sparse_accumulation/04068ed2b84f51fce3b094333bca18f0f369412c/docs/CG.png -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = ../../build/ 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/benchmarks.rst: -------------------------------------------------------------------------------- 1 | Benchmarks 2 | ========== 3 | 4 | CPU benchmarks 5 | ---------------------- 6 | 7 | .. include:: ../benchmarks/benchmark_contiguous_cpu.out 8 | :literal: 9 | 10 | GPU benchmarks 11 | ----------------------- 12 | 13 | .. include:: ../benchmarks/benchmark_contiguous_gpu.out 14 | :literal: 15 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('.')) 16 | sys.path.insert(0, os.path.abspath('../sparse_accumulation')) 17 | 18 | import sphinx_rtd_theme 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'sparse accumulation' 22 | copyright = 'q' 23 | author = 'q' 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | 32 | 33 | 34 | extensions = [ 35 | "sphinx_rtd_theme", "nbsphinx", "sphinxcontrib.napoleon", "sphinx_togglebutton" 36 | ] 37 | 38 | # Add any paths that contain templates here, relative to this directory. 39 | templates_path = ['_templates'] 40 | 41 | # List of patterns, relative to source directory, that match files and 42 | # directories to ignore when looking for source files. 43 | # This pattern also affects html_static_path and html_extra_path. 44 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 45 | 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = 'sphinx_rtd_theme' 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 57 | html_static_path = ['_static'] 58 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. q 2 | 3 | .. include:: ../README.rst 4 | :start-after: inclusion-marker-preambule-start 5 | :end-before: inclusion-marker-preambule-end 6 | 7 | .. toctree:: 8 | :glob: 9 | :maxdepth: 1 10 | :caption: Tutorial 11 | 12 | tutorial 13 | 14 | 15 | .. toctree:: 16 | :glob: 17 | :maxdepth: 1 18 | :caption: Benchmarks 19 | 20 | benchmarks 21 | 22 | 23 | .. toctree:: 24 | :glob: 25 | :maxdepth: 1 26 | :caption: Reference Guide 27 | 28 | reference_guide -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR= ../../build/ 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/reference_guide.rst: -------------------------------------------------------------------------------- 1 | accumulate 2 | ========== 3 | 4 | .. automodule:: sparse_accumulation 5 | :noindex: 6 | 7 | .. autofunction:: accumulate 8 | 9 | .. autofunction:: accumulate_active_dim_middle 10 | 11 | .. autofunction:: accumulate_active_dim_first 12 | 13 | get_cg_transformation_rule 14 | ========================== 15 | 16 | .. automodule:: sparse_accumulation 17 | :noindex: 18 | 19 | .. autofunction:: get_cg_transformation_rule 20 | 21 | CGCalculatorSingle 22 | ================== 23 | 24 | .. automodule:: sparse_accumulation 25 | :noindex: 26 | 27 | .. autoclass:: CGCalculatorSingle 28 | :members: 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "507700e8", 6 | "metadata": {}, 7 | "source": [ 8 | "# Sparse accumulation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "bfb72c10", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import numpy as np\n", 20 | "from sparse_accumulation import accumulate" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "a9059993", 26 | "metadata": {}, 27 | "source": [ 28 | "## Preparing a dummy data" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "d17e5b89", 34 | "metadata": {}, 35 | "source": [ 36 | "Let's prepare some dummy data to play with:" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "fedcfb51", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "X1 = torch.randn(10, 20, 3)\n", 47 | "X2 = torch.randn(10, 20, 4)\n", 48 | "\n", 49 | "m1 = torch.LongTensor([0, 1, 1, 2])\n", 50 | "m2 = torch.LongTensor([0, 0, 3, 1])\n", 51 | "mu = torch.LongTensor([0, 3, 1, 2])\n", 52 | "\n", 53 | "C = torch.FloatTensor([0.17, 0.23, 0.4, -0.9])\n", 54 | "output_size = 42" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "91f8a702", 60 | "metadata": {}, 61 | "source": [ 62 | "**Important** sparse accumulation operation requires mu tensor to be sorted to work correctly.\n", 63 | "\n", 64 | "It is very clear that the result of the sparse accumulation operation doesn't change for the simultaneous permutation of all the tensors m1, m2, mu, and C since the result of the summation doesn't depend on the order of the terms. 
Thus, it is always reachable to have mu tensor to be sorted, and one can achieve this as simply as:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "d73214b2", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "indices = np.argsort(mu)\n", 75 | "\n", 76 | "m1 = m1[indices]\n", 77 | "m2 = m2[indices]\n", 78 | "mu = mu[indices]\n", 79 | "C = C[indices]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "480604ea", 85 | "metadata": {}, 86 | "source": [ 87 | "## `accumulate` function" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "217e437b", 93 | "metadata": {}, 94 | "source": [ 95 | "The main function which does sparse accumulation operation is called `accumulate`. It can be invoked like this:" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "70fbe0e3", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "torch.Size([10, 20, 42]) cpu\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "output = accumulate(X1, X2, mu, output_size, m1, m2, C)\n", 114 | "print(output.shape, output.device)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "id": "62309235", 120 | "metadata": {}, 121 | "source": [ 122 | "Since the input tensors are located on cpu, the pytorch cpu extension was invoked internally. " 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "f9ebc863", 128 | "metadata": {}, 129 | "source": [ 130 | "Now let's move our dummy data to the gpu:" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "id": "87ff70ef", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "X1_cuda = X1.cuda()\n", 141 | "X2_cuda = X2.cuda()\n", 142 | "m1_cuda = m1.cuda()\n", 143 | "m2_cuda = m2.cuda()\n", 144 | "mu_cuda = mu.cuda()\n", 145 | "C_cuda = C.cuda()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "181f0299", 151 | "metadata": {}, 152 | "source": [ 153 | "The call is exactly the same:" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 6, 159 | "id": "b646ba32", 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stdout", 164 | "output_type": "stream", 165 | "text": [ 166 | "torch.Size([10, 20, 42]) cuda:0\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "output = accumulate(X1_cuda, X2_cuda, mu_cuda, 42, m1_cuda, m2_cuda, C_cuda)\n", 172 | "print(output.shape, output.device)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "4bf7822c", 178 | "metadata": {}, 179 | "source": [ 180 | "This time our cuda kernel was invoked interenally since the input tensors are located on gpu" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "id": "cf5d4700", 186 | "metadata": {}, 187 | "source": [ 188 | "# [optional] Clebsch-Gordan iteration" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "id": "76567b2e", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "from sparse_accumulation import get_cg_transformation_rule" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "id": "90013ea9", 204 | "metadata": {}, 205 | "source": [ 206 | "## precomputing Clebsch-Gordan transformation rule" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "9f89ad82", 212 | "metadata": {}, 213 | "source": [ 214 | "If we want the sparse accumulation operation to do the actual Clebsch-Gordan iteration we need to 
precompute the corresponding transformation rule and populate the arrays ``m1``, ``m2``, ``mu`` and ``C`` with the actual Clebsch-Gordan coefficients. " 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 8, 220 | "id": "423cb660", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "torch.Size([93]) torch.Size([93]) torch.Size([93]) torch.Size([93])\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "l1 = 3\n", 233 | "l2 = 4\n", 234 | "l_output = 5\n", 235 | "m1, m2, mu, C = get_cg_transformation_rule(l1, l2, l_output)\n", 236 | "print(m1.shape, m2.shape, mu.shape, C.shape)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "4a219bc6", 242 | "metadata": {}, 243 | "source": [ 244 | "The mentioned above sorting operation is not required now since it has been already performed inside `get_cg_transformation_rule`" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "id": "4af5b470", 250 | "metadata": {}, 251 | "source": [ 252 | "Now, given this transformation rule, sparse accumulation operation performs actual CG iteration, producing \n", 253 | "covariant vectors with l = l_output given covariant vectors with l = l1 and l = l2:" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 9, 259 | "id": "61d89904", 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "torch.Size([10, 20, 11])\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "X1 = torch.randn(10, 20, 2 * l1 + 1)\n", 272 | "X2 = torch.randn(10, 20, 2 * l2 + 1)\n", 273 | "output = accumulate(X1, X2, mu, 2 * l_output + 1, m1, m2, C)\n", 274 | "print(output.shape)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "9a1e81a9", 280 | "metadata": {}, 281 | "source": [ 282 | "## Clebsch-Gordan Calculator" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "id": "b01a9cfc", 288 | "metadata": {}, 289 | "source": [ 290 | "It makes sense to wrap up the mentioned steps into the class, where the CG transformation rule is computed during initialization, and next is used in the forward method. 
We provide such a class called `CGCalculatorSingle`" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 10, 296 | "id": "21ea96c8", 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "from sparse_accumulation import CGCalculatorSingle" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 11, 306 | "id": "29baa90b", 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "torch.Size([10, 20, 11]) cpu\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "calc = CGCalculatorSingle(l1, l2, l_output)\n", 319 | "output = calc(X1, X2)\n", 320 | "print(output.shape, output.device)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "id": "4becad64", 326 | "metadata": {}, 327 | "source": [ 328 | "This class supports convenient reallocation to gpu:" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 12, 334 | "id": "22bb4fdf", 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "torch.Size([10, 20, 11]) cuda:0\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "calc = calc.cuda()\n", 347 | "output = calc(X1.cuda(), X2.cuda())\n", 348 | "print(output.shape, output.device)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "id": "03b662ec", 354 | "metadata": {}, 355 | "source": [ 356 | "All the tensors constituting the transformation rule (m1, m2, mu, and C) are stored as buffers, not the parameters, so they will not be optimized." 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "id": "297b2fbb", 362 | "metadata": {}, 363 | "source": [ 364 | "[todo] add raise Value error to accumulate function if the size of shared memory is insufficient; mention it here.\n", 365 | "[todo] add fallback to alternative (select which one) implementation to the CGCalculatorSingle if the size of shared memory is insufficient; mention it here." 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "5c97342a", 371 | "metadata": {}, 372 | "source": [ 373 | "# Outdated" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "id": "a69963ab", 379 | "metadata": {}, 380 | "source": [ 381 | "The goal is to compute this for all $\\mu$:\n", 382 | "\n", 383 | "$\\text{Output}[:, :, \\mu] = \\sum\\limits_{m_1, m_2} \\text{X_1}[:, :, m_1] * \\text{X_2}[:, :, m_2] * C_{m_1, m_2, \\mu}$\n", 384 | "\n", 385 | "This is the subpart of the Clebsch-Gordan iteration for fixed l1, l2, and l. The first two dimensions are the \"dense\" ones, so the same operation is performed for all the indices in the first two dimensions. \n", 386 | "\n", 387 | "Since Clebsch-Gordan coefficients are very sparse, it is worthwhile to align them into a 1-dimensional tensor containing only non-zero values, but in this case, we need to supply this tensor with supplementary indices tensors telling us what are the corresponding m1, m2, and $\\mu$ indices. 
" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "id": "c1d56ce1", 393 | "metadata": {}, 394 | "source": [ 395 | "Reference slow python implementation is as simple as this:" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 13, 401 | "id": "f1d8e38f", 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "def sparse_accumulation_loops(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers):\n", 406 | " device = X1.device #all tensors must be on the same device and blah, blah, blah \n", 407 | " dtype = X1.dtype \n", 408 | " \n", 409 | " output = torch.zeros([X1.shape[0], X2.shape[1], output_size], device = device,dtype=dtype)\n", 410 | " for index in range(idx_output.shape[0]): \n", 411 | " output[:, :, idx_output[index]] += X1[:, :, idx_1[index]] * X2[:, :, idx_2[index]] * multipliers[index]\n", 412 | " return output" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "id": "7f27d3b2", 418 | "metadata": {}, 419 | "source": [ 420 | "Here multipliers are the values of Clebsch-Gordan coefficients, idx_1 is the tensor containing corresponding m1 indices, idx_2 is the tensor containing corresponding m2 indices, and idx_output is the tensor containing $\\mu$ indices. output_size is just a single integer, the desired length of the output (2 * l + 1). \n", 421 | "\n", 422 | "So the loops go over all the terms, for all $\\mu$, m1, and m2 with non-zero clebsch-gordan coefficients, and the current contribution is added to the output array to the proper place defined by $\\mu$ which is stored in the idx_output\n", 423 | "\n", 424 | "The first two dense dimensions are introduced, keeping in mind batch and feature dimensions. If you need just 1, it is possible to introduce a dummy dimension of size 1 ^^. \n", 425 | "\n", 426 | "\n", 427 | "The transformation itself, i.e., Clebsch-Gordan coefficients, can be precomputed once at the very beginning. 
This repo among the other things contains the code for this:" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 14, 433 | "id": "41d0ab81", 434 | "metadata": {}, 435 | "outputs": [ 436 | { 437 | "name": "stdout", 438 | "output_type": "stream", 439 | "text": [ 440 | "L_MAX: \n", 441 | "multipliers shape: torch.Size([126])\n", 442 | "m1_aligned shape: torch.Size([126])\n", 443 | "m2_aligned shape: torch.Size([126])\n", 444 | "multipliers shape: torch.Size([126])\n" 445 | ] 446 | } 447 | ], 448 | "source": [ 449 | "from sparse_accumulation.clebsch_gordan import ClebschGordan, get_real_clebsch_gordan\n", 450 | "L_MAX = 5\n", 451 | "clebsch = ClebschGordan(L_MAX).precomputed_\n", 452 | "indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX)\n", 453 | "\n", 454 | "m1_aligned, m2_aligned = [], []\n", 455 | "multipliers, mu_aligned = [], []\n", 456 | "for mu in range(0, 2 * L_MAX + 1):\n", 457 | " for el in indices[mu]:\n", 458 | " m1, m2, multiplier = el\n", 459 | " m1_aligned.append(m1)\n", 460 | " m2_aligned.append(m2)\n", 461 | " multipliers.append(multiplier)\n", 462 | " mu_aligned.append(mu)\n", 463 | "m1_aligned = torch.LongTensor(m1_aligned)\n", 464 | "m2_aligned = torch.LongTensor(m2_aligned)\n", 465 | "mu_aligned = torch.LongTensor(mu_aligned)\n", 466 | "multipliers = torch.FloatTensor(multipliers)\n", 467 | "\n", 468 | "indices = np.argsort(mu_aligned)\n", 469 | "\n", 470 | "m1_aligned = m1_aligned[indices].cuda()\n", 471 | "m2_aligned = m2_aligned[indices].cuda()\n", 472 | "mu_aligned = mu_aligned[indices].cuda()\n", 473 | "multipliers = multipliers[indices].cuda()\n", 474 | "\n", 475 | "print(\"L_MAX: \")\n", 476 | "print(\"multipliers shape: \", multipliers.shape)\n", 477 | "print(\"m1_aligned shape: \", m1_aligned.shape)\n", 478 | "print(\"m2_aligned shape: \", m2_aligned.shape)\n", 479 | "print(\"multipliers shape: \", multipliers.shape)" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 | "id": "b15db4ef", 485 | "metadata": {}, 486 | "source": [ 487 | "This is a simple wrapper on sympy package, and the definition of the real clebsch-gordan coefficients is consistent with librascal real spherical harmonics, nice, wigner iterations, and rascaline\n", 488 | "\n", 489 | "Now we can do the Clebsch-Gordan iteration:" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 15, 495 | "id": "7ee0c258", 496 | "metadata": {}, 497 | "outputs": [ 498 | { 499 | "name": "stdout", 500 | "output_type": "stream", 501 | "text": [ 502 | "torch.Size([100, 17, 11])\n" 503 | ] 504 | } 505 | ], 506 | "source": [ 507 | "X1 = torch.randn(100, 17, 2 * L_MAX + 1).cuda()\n", 508 | "X2 = torch.randn(100, 17, 2 * L_MAX + 1).cuda()\n", 509 | "\n", 510 | "output_loops = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers)\n", 511 | "print(output_loops.shape)" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "id": "e57d475e", 517 | "metadata": {}, 518 | "source": [ 519 | "You can take a look at the benchmarks files .py along with their output .out to get an idea 1) how to benchmark this properly with gpu synchronization and 2) the speed of this operation compared to a naive implementation" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "id": "5aa8c1ef", 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [] 529 | } 530 | ], 531 | "metadata": { 532 | "kernelspec": { 533 | "display_name": "Python 3 (ipykernel)", 534 | 
"language": "python", 535 | "name": "python3" 536 | }, 537 | "language_info": { 538 | "codemirror_mode": { 539 | "name": "ipython", 540 | "version": 3 541 | }, 542 | "file_extension": ".py", 543 | "mimetype": "text/x-python", 544 | "name": "python", 545 | "nbconvert_exporter": "python", 546 | "pygments_lexer": "ipython3", 547 | "version": "3.9.12" 548 | } 549 | }, 550 | "nbformat": 4, 551 | "nbformat_minor": 5 552 | } 553 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "torch"] 3 | 4 | [project] 5 | name = "sparse_accumulation" 6 | version = "0.0.0" 7 | description = "A package that contains significantly optimized CPU and GPU PyTorch extensions for the sparse accumulation operation" 8 | authors = [ 9 | {name = "Sergey Pozdnyakov"}, 10 | {name = "Davide Tisi"}, 11 | {name = "Prashanth Kanduri"}, 12 | {name = "Filippo Bigi"}, 13 | {name = "Henrique Mendonça"}, 14 | {name = "Guillaume Fraux"}, 15 | ] 16 | readme = "README.rst" 17 | requires-python = ">=3.8" 18 | classifiers = [ 19 | "Development Status :: 2 - Pre-Alpha", 20 | "Environment :: Console", 21 | "Intended Audience :: Science/Research", 22 | "Natural Language :: English", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: C++", 25 | "Topic :: Scientific/Engineering", 26 | ] 27 | dependencies = [ 28 | "numpy", 29 | "torch", 30 | ] 31 | 32 | 33 | [project.optional-dependencies] 34 | test = [ 35 | "pytest", 36 | "os", 37 | ] 38 | 39 | 40 | [project.urls] 41 | homepage = "https://lab-cosmo.github.io/sparse_accumulation/" 42 | documentation = "https://lab-cosmo.github.io/sparse_accumulation/" 43 | repository = "https://github.com/lab-cosmo/sparse_accumulation" 44 | 45 | 46 | [tool.setuptools.packages.find] 47 | where = ["sparse_accumulation"] 48 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension, find_packages 2 | from torch.utils import cpp_extension 3 | from torch import cuda 4 | 5 | if cuda.is_available(): 6 | ext_cuda = cpp_extension.CUDAExtension("sparse_accumulation_cuda", 7 | ["sparse_accumulation/cuda_extension/sparse_accumulation_cuda_kernel2D.cu"]) 8 | 9 | ext_first = cpp_extension.CppExtension('sparse_accumulation_active_dim_first_cpp', 10 | ['sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_first.cpp'], 11 | extra_compile_args=['-fopenmp']) 12 | 13 | ext_middle = cpp_extension.CppExtension('sparse_accumulation_active_dim_middle_cpp', 14 | ['sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_middle.cpp'], 15 | extra_compile_args=['-fopenmp']) 16 | 17 | ext_last = cpp_extension.CppExtension('sparse_accumulation_active_dim_last_cpp', 18 | ['sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_last.cpp'], 19 | extra_compile_args=['-fopenmp']) 20 | 21 | ext_modules = [ext_first, ext_middle, ext_last] 22 | if cuda.is_available(): ext_modules.append(ext_cuda) 23 | 24 | setup(name='sparse_accumulation', 25 | packages = find_packages(), 26 | ext_modules = ext_modules, 27 | cmdclass={'build_ext': cpp_extension.BuildExtension}) 28 | 29 | -------------------------------------------------------------------------------- /sparse_accumulation/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .unified_operation import accumulate 2 | from .clebsch_gordan import get_cg_transformation_rule, CGCalculatorSingle 3 | from .other_operations import accumulate_active_dim_middle, accumulate_active_dim_first 4 | -------------------------------------------------------------------------------- /sparse_accumulation/clebsch_gordan.py: -------------------------------------------------------------------------------- 1 | from sympy import S 2 | from sympy.physics.wigner import clebsch_gordan 3 | 4 | try: 5 | import wigners 6 | except ImportError: 7 | wigners = None 8 | 9 | import numpy as np 10 | import torch 11 | from .unified_operation import accumulate 12 | 13 | def _compute_cg(l1, l2, l, m1, m2): 14 | if wigners is None: 15 | # use sympy 16 | return float(clebsch_gordan(S(l1), S(l2), S(l), S(m1), S(m2), S(m1 + m2))) 17 | else: 18 | if abs(m1) > l1 or abs(m2) > l2 or abs(m1 + m2) > l: 19 | return 0.0 20 | return wigners.clebsch_gordan(l1, m1, l2, m2, l, m1 + m2) 21 | 22 | 23 | class ClebschGordan: 24 | def __init__(self, l_max): 25 | self.l_max_ = l_max 26 | self.precomputed_ = np.zeros( 27 | [l_max + 1, l_max + 1, l_max + 1, 2 * l_max + 1, 2 * l_max + 1] 28 | ) 29 | 30 | for l1 in range(l_max + 1): 31 | for l2 in range(l_max + 1): 32 | for l in range(l_max + 1): 33 | for m1 in range(-l_max, l_max + 1): 34 | for m2 in range(-l_max, l_max + 1): 35 | now = _compute_cg(l1, l2, l, m1, m2) 36 | self.precomputed_[l1, l2, l, m1 + l1, m2 + l2] = now 37 | 38 | class PartialClebschGordan: 39 | def __init__(self, l1, l2, l_output): 40 | self.l1 = l1 41 | self.l2 = l2 42 | self.l_output = l_output 43 | 44 | self.values = np.zeros([2 * l1 + 1, 2 * l2 + 1]) 45 | for m1 in range(-l1, l1 + 1): 46 | for m2 in range(-l2, l2 + 1): 47 | self.values[m1 + l1, m2 + l2] = _compute_cg(l1, l2, l_output, m1, m2) 48 | 49 | def _multiply(first, second, multiplier): 50 | return [first[0], second[0], first[1] * second[1] * multiplier] 51 | 52 | 53 | def _multiply_sequence(sequence, multiplier): 54 | result = [] 55 | 56 | for el in sequence: 57 | # print(el) 58 | # print(len(el)) 59 | result.append([el[0], el[1], el[2] * multiplier]) 60 | return result 61 | 62 | 63 | def _get_conversion(l, m): 64 | if m < 0: 65 | X_re = [abs(m) + l, 1.0 / np.sqrt(2)] 66 | X_im = [m + l, -1.0 / np.sqrt(2)] 67 | if m == 0: 68 | X_re = [l, 1.0] 69 | X_im = [l, 0.0] 70 | if m > 0: 71 | if m % 2 == 0: 72 | X_re = [m + l, 1.0 / np.sqrt(2)] 73 | X_im = [-m + l, 1.0 / np.sqrt(2)] 74 | else: 75 | X_re = [m + l, -1.0 / np.sqrt(2)] 76 | X_im = [-m + l, -1.0 / np.sqrt(2)] 77 | return X_re, X_im 78 | 79 | 80 | def _compress(sequence, epsilon=1e-15): 81 | result = [] 82 | for i in range(len(sequence)): 83 | m1, m2, multiplier = sequence[i][0], sequence[i][1], sequence[i][2] 84 | already = False 85 | for j in range(len(result)): 86 | if (m1 == result[j][0]) and (m2 == result[j][1]): 87 | already = True 88 | break 89 | 90 | if not already: 91 | multiplier = 0.0 92 | for j in range(i, len(sequence)): 93 | if (m1 == sequence[j][0]) and (m2 == sequence[j][1]): 94 | multiplier += sequence[j][2] 95 | if np.abs(multiplier) > epsilon: 96 | result.append([m1, m2, multiplier]) 97 | # print(len(sequence), '->', len(result)) 98 | return result 99 | 100 | 101 | def get_real_clebsch_gordan(clebsch, l1, l2, lambd): 102 | result = [[] for _ in range(2 * lambd + 1)] 103 | for mu in range(0, lambd + 1): 104 | real_now = [] 105 | imag_now = [] 106 | for m2 in range(max(-l2, mu - l1), 
min(l2, mu + l1) + 1): 107 | m1 = mu - m2 108 | X1_re, X1_im = _get_conversion(l1, m1) 109 | X2_re, X2_im = _get_conversion(l2, m2) 110 | 111 | real_now.append(_multiply(X1_re, X2_re, clebsch[m1 + l1, m2 + l2])) 112 | real_now.append(_multiply(X1_im, X2_im, -clebsch[m1 + l1, m2 + l2])) 113 | 114 | imag_now.append(_multiply(X1_re, X2_im, clebsch[m1 + l1, m2 + l2])) 115 | imag_now.append(_multiply(X1_im, X2_re, clebsch[m1 + l1, m2 + l2])) 116 | # print(real_now) 117 | if (l1 + l2 - lambd) % 2 == 1: 118 | imag_now, real_now = real_now, _multiply_sequence(imag_now, -1) 119 | if mu > 0: 120 | if mu % 2 == 0: 121 | result[mu + lambd] = _multiply_sequence(real_now, np.sqrt(2)) 122 | result[-mu + lambd] = _multiply_sequence(imag_now, np.sqrt(2)) 123 | else: 124 | result[mu + lambd] = _multiply_sequence(real_now, -np.sqrt(2)) 125 | result[-mu + lambd] = _multiply_sequence(imag_now, -np.sqrt(2)) 126 | else: 127 | result[lambd] = real_now 128 | 129 | for i in range(len(result)): 130 | result[i] = _compress(result[i]) 131 | return result 132 | 133 | def check_l_consistency(l1, l2, l_output): 134 | if (l_output < abs(l1 - l2)) or (l_output > (l1 + l2)): 135 | raise ValueError("l_output must be in between |l1 - l2| and (l1 + l2)") 136 | 137 | def get_cg_transformation_rule(l1, l2, l_output, dtype = torch.float32, device = "cpu"): 138 | check_l_consistency(l1, l2, l_output) 139 | 140 | clebsch = PartialClebschGordan(l1, l2, l_output).values 141 | indices = get_real_clebsch_gordan(clebsch, l1, l2, l_output) 142 | 143 | m1_aligned, m2_aligned = [], [] 144 | multipliers, mu_aligned = [], [] 145 | for mu in range(2 * l_output + 1): 146 | for el in indices[mu]: 147 | m1, m2, multiplier = el 148 | m1_aligned.append(m1) 149 | m2_aligned.append(m2) 150 | multipliers.append(multiplier * 1.0) 151 | mu_aligned.append(mu) 152 | m1_aligned = torch.tensor(m1_aligned, dtype=torch.int64, device=device) 153 | m2_aligned = torch.tensor(m2_aligned, dtype=torch.int64, device=device) 154 | mu_aligned = torch.tensor(mu_aligned, dtype=torch.int64, device=device) 155 | multipliers = torch.tensor(multipliers, dtype=dtype, device=device) 156 | 157 | indices = np.argsort(mu_aligned) 158 | 159 | m1_aligned = m1_aligned[indices] 160 | m2_aligned = m2_aligned[indices] 161 | mu_aligned = mu_aligned[indices] 162 | multipliers = multipliers[indices] 163 | 164 | return m1_aligned, m2_aligned, mu_aligned, multipliers 165 | 166 | class CGCalculatorSingle(torch.nn.Module): 167 | def __init__(self, l1, l2, l_output, dtype = torch.float32): 168 | super(CGCalculatorSingle, self).__init__() 169 | check_l_consistency(l1, l2, l_output) 170 | self.l1 = l1 171 | self.l2 = l2 172 | self.l_output = l_output 173 | m1, m2, mu, C = get_cg_transformation_rule(l1, l2, l_output, dtype = dtype) 174 | self.register_buffer('m1', m1) 175 | self.register_buffer('m2', m2) 176 | self.register_buffer('mu', mu) 177 | self.register_buffer('C', C) 178 | 179 | def forward(self, X1, X2): 180 | return accumulate(X1, X2, self.mu, 2 * self.l_output + 1, self.m1, self.m2, self.C) 181 | -------------------------------------------------------------------------------- /sparse_accumulation/cpu_extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-cosmo/sparse_accumulation/04068ed2b84f51fce3b094333bca18f0f369412c/sparse_accumulation/cpu_extension/__init__.py -------------------------------------------------------------------------------- 
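For orientation only: the next file, sparse_accumulation_active_dim_first.cpp, implements the sparse accumulation forward and backward passes for the memory layout where the sparse ("active") dimension comes first. The pure-PyTorch sketch below is illustrative and not part of the package (the function names are made up); it mirrors the accumulation rules that the contiguous C++ kernels compute, which may help when reading the pointer arithmetic that follows.

import torch

def reference_forward_active_dim_first(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers):
    # X1 and X2 have shape [active, first, second]; the sparse index runs over dim 0.
    output = torch.zeros(output_size, X1.shape[1], X1.shape[2], dtype=X1.dtype, device=X1.device)
    for k in range(idx_output.shape[0]):
        output[idx_output[k]] += X1[idx_1[k]] * X2[idx_2[k]] * multipliers[k]
    return output

def reference_backward_active_dim_first(d_output, X1, X2, idx_output, idx_1, idx_2, multipliers):
    # Same accumulation pattern as the contiguous backward kernel:
    # d_X1[m1] += d_output[mu] * C * X2[m2], and symmetrically for d_X2.
    d_X1 = torch.zeros_like(X1)
    d_X2 = torch.zeros_like(X2)
    for k in range(idx_output.shape[0]):
        grad = d_output[idx_output[k]] * multipliers[k]
        d_X1[idx_1[k]] += grad * X2[idx_2[k]]
        d_X2[idx_2[k]] += grad * X1[idx_1[k]]
    return d_X1, d_X2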
/sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_first.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | using namespace torch::indexing; 6 | 7 | 8 | template 9 | void _sparse_accumulation_active_dim_first_contiguous_backward( 10 | torch::Tensor d_X1, 11 | torch::Tensor d_X2, 12 | torch::Tensor d_output, 13 | torch::Tensor X1, 14 | torch::Tensor X2, 15 | torch::Tensor idx_output, 16 | torch::Tensor idx_1, 17 | torch::Tensor idx_2, 18 | torch::Tensor multipliers) { 19 | 20 | scalar_t* d_X1_ptr = d_X1.data_ptr(); 21 | scalar_t* d_X2_ptr = d_X2.data_ptr(); 22 | scalar_t* d_output_ptr = d_output.data_ptr(); 23 | 24 | scalar_t* X1_ptr = X1.data_ptr(); 25 | scalar_t* X2_ptr = X2.data_ptr(); 26 | 27 | scalar_t* multipliers_ptr = multipliers.data_ptr(); 28 | long* idx_1_ptr = idx_1.data_ptr(); 29 | long* idx_2_ptr = idx_2.data_ptr(); 30 | long* idx_output_ptr = idx_output.data_ptr(); 31 | 32 | long active_size = idx_output.sizes()[0]; 33 | long first_size = X1.sizes()[1]; 34 | long second_size = X1.sizes()[2]; 35 | long inner_size = first_size * second_size; 36 | 37 | for (int index = 0; index < active_size; ++index) { 38 | long shift_active_x1 = idx_1_ptr[index] * inner_size; 39 | long shift_active_x2 = idx_2_ptr[index] * inner_size; 40 | long shift_active_output = idx_output_ptr[index] * inner_size; 41 | scalar_t multiplier = multipliers_ptr[index]; 42 | #pragma omp parallel for 43 | for (int index_first = 0; index_first < first_size; ++index_first) { 44 | // #pragma omp parallel for // This makes little difference 45 | for (int index_second = 0; index_second < second_size; ++index_second) { 46 | long shift_local = index_first * second_size + index_second; 47 | scalar_t grad = d_output_ptr[shift_active_output + shift_local] * multiplier; 48 | d_X1_ptr[shift_active_x1 + shift_local] += grad * X2_ptr[shift_active_x2 + shift_local]; 49 | d_X2_ptr[shift_active_x2 + shift_local] += grad * X1_ptr[shift_active_x1 + shift_local]; 50 | } 51 | } 52 | } 53 | 54 | } 55 | 56 | 57 | template 58 | void _sparse_accumulation_active_dim_first_contiguous_forward( 59 | torch::Tensor output, 60 | torch::Tensor X1, 61 | torch::Tensor X2, 62 | torch::Tensor idx_output, 63 | int output_size, 64 | torch::Tensor idx_1, 65 | torch::Tensor idx_2, 66 | torch::Tensor multipliers) { 67 | 68 | scalar_t* X1_ptr = X1.data_ptr(); 69 | scalar_t* X2_ptr = X2.data_ptr(); 70 | scalar_t* output_ptr = output.data_ptr(); 71 | scalar_t* multipliers_ptr = multipliers.data_ptr(); 72 | long* idx_1_ptr = idx_1.data_ptr(); 73 | long* idx_2_ptr = idx_2.data_ptr(); 74 | long* idx_output_ptr = idx_output.data_ptr(); 75 | 76 | long active_size = idx_output.sizes()[0]; 77 | long first_size = X1.sizes()[1]; 78 | long second_size = X1.sizes()[2]; 79 | long inner_size = first_size * second_size; 80 | 81 | for (int index = 0; index < active_size; ++index) { 82 | long shift_active_x1 = idx_1_ptr[index] * inner_size; 83 | long shift_active_x2 = idx_2_ptr[index] * inner_size; 84 | long shift_active_output = idx_output_ptr[index] * inner_size; 85 | scalar_t third = multipliers_ptr[index]; 86 | 87 | #pragma omp parallel for 88 | for (int index_first = 0; index_first < first_size; ++index_first) { 89 | // #pragma omp parallel for // This makes little difference 90 | for (int index_second = 0; index_second < second_size; ++index_second) { 91 | long shift_local = index_first * second_size + index_second; 92 | output_ptr[shift_active_output + shift_local] += 
X1_ptr[shift_active_x1 + shift_local] * X2_ptr[shift_active_x2 + shift_local] * third; 93 | } 94 | } 95 | } 96 | 97 | } 98 | 99 | 100 | template 101 | void _sparse_accumulation_active_dim_first_forward( 102 | torch::Tensor output, 103 | torch::Tensor X1, 104 | torch::Tensor X2, 105 | torch::Tensor idx_output, 106 | int output_size, 107 | torch::Tensor idx_1, 108 | torch::Tensor idx_2, 109 | torch::Tensor multipliers) { 110 | 111 | 112 | auto X1_a = X1.accessor(); 113 | auto X2_a = X2.accessor(); 114 | auto multipliers_a = multipliers.accessor(); 115 | 116 | auto output_a = output.accessor(); 117 | 118 | auto idx_1_a = idx_1.accessor(); 119 | auto idx_2_a = idx_2.accessor(); 120 | auto idx_output_a = idx_output.accessor(); 121 | 122 | for (int index = 0; index < idx_output_a.size(0); ++index) { 123 | for (int index_first = 0; index_first < output.size(1); ++index_first){ 124 | for (int index_second = 0; index_second < output.size(2); ++index_second) { 125 | auto first = X1_a[idx_1_a[index]][index_first][index_second]; 126 | auto second = X2_a[idx_2_a[index]][index_first][index_second]; 127 | auto third = multipliers_a[index]; 128 | auto contribution = first * second * third; 129 | output_a[idx_output_a[index]][index_first][index_second] += contribution; 130 | } 131 | } 132 | } 133 | 134 | } 135 | 136 | 137 | template 138 | void _sparse_accumulation_active_dim_first_backward( 139 | torch::Tensor d_X1, 140 | torch::Tensor d_X2, 141 | torch::Tensor d_output, 142 | torch::Tensor X1, 143 | torch::Tensor X2, 144 | torch::Tensor idx_output, 145 | torch::Tensor idx_1, 146 | torch::Tensor idx_2, 147 | torch::Tensor multipliers) { 148 | 149 | 150 | auto X1_a = X1.accessor(); 151 | auto X2_a = X2.accessor(); 152 | auto multipliers_a = multipliers.accessor(); 153 | 154 | auto d_output_a = d_output.accessor(); 155 | 156 | auto idx_1_a = idx_1.accessor(); 157 | auto idx_2_a = idx_2.accessor(); 158 | auto idx_output_a = idx_output.accessor(); 159 | 160 | auto d_X1_a = d_X1.accessor(); 161 | auto d_X2_a = d_X2.accessor(); 162 | 163 | for (int index = 0; index < idx_output_a.size(0); ++index) { 164 | for (int index_first = 0; index_first < d_output_a.size(1); ++index_first){ 165 | for (int index_second = 0; index_second < d_output_a.size(2); ++index_second) { 166 | auto from_X1 = X1_a[idx_1_a[index]][index_first][index_second]; 167 | auto from_X2 = X2_a[idx_2_a[index]][index_first][index_second]; 168 | auto multiplier = multipliers_a[index]; 169 | auto grad = d_output_a[idx_output_a[index]][index_first][index_second]; 170 | 171 | auto to_X1 = multiplier * grad * from_X2; 172 | auto to_X2 = multiplier * grad * from_X1; 173 | 174 | d_X1_a[idx_1_a[index]][index_first][index_second] += to_X1; 175 | d_X2_a[idx_2_a[index]][index_first][index_second] += to_X2; 176 | } 177 | } 178 | } 179 | 180 | } 181 | 182 | 183 | std::vector sparse_accumulation_active_dim_first_contiguous_backward( 184 | torch::Tensor d_output, 185 | torch::Tensor X1, 186 | torch::Tensor X2, 187 | torch::Tensor idx_output, 188 | torch::Tensor idx_1, 189 | torch::Tensor idx_2, 190 | torch::Tensor multipliers 191 | ) { 192 | 193 | auto d_X1 = torch::zeros_like(X1); 194 | auto d_X2 = torch::zeros_like(X2); 195 | 196 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_first_contiguous_backward", ([&] { 197 | _sparse_accumulation_active_dim_first_contiguous_backward( 198 | d_X1, 199 | d_X2, 200 | d_output, 201 | X1, 202 | X2, 203 | idx_output, 204 | idx_1, 205 | idx_2, 206 | multipliers 207 | ); 208 | })); 209 | 210 | return 
{d_X1, d_X2}; 211 | } 212 | 213 | 214 | torch::Tensor sparse_accumulation_active_dim_first_contiguous_forward( 215 | torch::Tensor X1, 216 | torch::Tensor X2, 217 | torch::Tensor idx_output, 218 | int output_size, 219 | torch::Tensor idx_1, 220 | torch::Tensor idx_2, 221 | torch::Tensor multipliers 222 | ) { 223 | 224 | auto output = torch::zeros({output_size, X1.sizes()[1], X1.sizes()[2]}, X1.options()); 225 | 226 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_first_contiguous_forward", ([&] { 227 | _sparse_accumulation_active_dim_first_contiguous_forward( 228 | output, 229 | X1, 230 | X2, 231 | idx_output, 232 | output_size, 233 | idx_1, 234 | idx_2, 235 | multipliers 236 | ); 237 | })); 238 | 239 | return output; 240 | } 241 | 242 | 243 | torch::Tensor sparse_accumulation_active_dim_first_forward( 244 | torch::Tensor X1, 245 | torch::Tensor X2, 246 | torch::Tensor idx_output, 247 | int output_size, 248 | torch::Tensor idx_1, 249 | torch::Tensor idx_2, 250 | torch::Tensor multipliers 251 | ) { 252 | 253 | auto output = torch::zeros({output_size, X1.sizes()[1], X1.sizes()[2]}, X1.options()); 254 | 255 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_first_contiguous_forward", ([&] { 256 | _sparse_accumulation_active_dim_first_forward( 257 | output, 258 | X1, 259 | X2, 260 | idx_output, 261 | output_size, 262 | idx_1, 263 | idx_2, 264 | multipliers 265 | ); 266 | })); 267 | 268 | return output; 269 | } 270 | 271 | 272 | std::vector sparse_accumulation_active_dim_first_backward( 273 | torch::Tensor d_output, 274 | torch::Tensor X1, 275 | torch::Tensor X2, 276 | torch::Tensor idx_output, 277 | torch::Tensor idx_1, 278 | torch::Tensor idx_2, 279 | torch::Tensor multipliers 280 | ) { 281 | 282 | auto d_X1 = torch::zeros_like(X1); 283 | auto d_X2 = torch::zeros_like(X2); 284 | 285 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_first_backward", ([&] { 286 | _sparse_accumulation_active_dim_first_backward( 287 | d_X1, 288 | d_X2, 289 | d_output, 290 | X1, 291 | X2, 292 | idx_output, 293 | idx_1, 294 | idx_2, 295 | multipliers 296 | ); 297 | })); 298 | 299 | return {d_X1, d_X2}; 300 | } 301 | 302 | 303 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 304 | m.def("forward", &sparse_accumulation_active_dim_first_forward, "sparse accumulation active dim first forward"); 305 | m.def("forward_contiguous", &sparse_accumulation_active_dim_first_contiguous_forward, "sparse accumulation active dim first contiguous forward"); 306 | m.def("backward", &sparse_accumulation_active_dim_first_backward, "sparse accumulation active dim first backward"); 307 | m.def("backward_contiguous", &sparse_accumulation_active_dim_first_contiguous_backward, "sparse accumulation active dim first contiguous backward"); 308 | } 309 | -------------------------------------------------------------------------------- /sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_first.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sparse_accumulation_active_dim_first_cpp 3 | 4 | class SparseAccumulationActiveDimFirst(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 7 | all_contiguous = X1.is_contiguous() and X2.is_contiguous() and idx_output.is_contiguous() and idx_1.is_contiguous() and idx_2.is_contiguous() and multipliers.is_contiguous() 8 | if all_contiguous: 9 | output = 
sparse_accumulation_active_dim_first_cpp.forward_contiguous(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 10 | else: 11 | output = sparse_accumulation_active_dim_first_cpp.forward(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 12 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 13 | return output 14 | 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 19 | all_contiguous = X1.is_contiguous() and X2.is_contiguous() and idx_output.is_contiguous() and idx_1.is_contiguous() and idx_2.is_contiguous() and multipliers.is_contiguous() 20 | if all_contiguous: 21 | d_X1, d_X2 = sparse_accumulation_active_dim_first_cpp.backward_contiguous(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 22 | else: 23 | d_X1, d_X2 = sparse_accumulation_active_dim_first_cpp.backward(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 24 | return d_X1, d_X2, None, None, None, None, None -------------------------------------------------------------------------------- /sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_last.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | using namespace torch::indexing; 6 | 7 | 8 | template 9 | void _sparse_accumulation_active_dim_last_contiguous_backward( 10 | torch::Tensor d_X1, 11 | torch::Tensor d_X2, 12 | torch::Tensor d_output, 13 | torch::Tensor X1, 14 | torch::Tensor X2, 15 | torch::Tensor idx_output, 16 | torch::Tensor idx_1, 17 | torch::Tensor idx_2, 18 | torch::Tensor multipliers) { 19 | 20 | scalar_t* d_X1_ptr = d_X1.data_ptr(); 21 | scalar_t* d_X2_ptr = d_X2.data_ptr(); 22 | scalar_t* d_output_ptr = d_output.data_ptr(); 23 | 24 | scalar_t* X1_ptr = X1.data_ptr(); 25 | scalar_t* X2_ptr = X2.data_ptr(); 26 | 27 | scalar_t* multipliers_ptr = multipliers.data_ptr(); 28 | long* idx_1_ptr = idx_1.data_ptr(); 29 | long* idx_2_ptr = idx_2.data_ptr(); 30 | long* idx_output_ptr = idx_output.data_ptr(); 31 | 32 | long active_size = idx_output.sizes()[0]; 33 | long first_size = X1.sizes()[0]; 34 | long second_size = X1.sizes()[1]; 35 | 36 | long output_active_dim = d_output.sizes()[2]; 37 | long X1_active_dim = X1.sizes()[2]; 38 | long X2_active_dim = X2.sizes()[2]; 39 | 40 | #pragma omp parallel for 41 | for (int index_first = 0; index_first < first_size; ++index_first){ 42 | // #pragma omp parallel for // This makes little difference 43 | for (int index_second = 0; index_second < second_size; ++index_second) { 44 | long shift_number = index_first * second_size + index_second; 45 | long shift_output = shift_number * output_active_dim; 46 | long shift_X1 = shift_number * X1_active_dim; 47 | long shift_X2 = shift_number * X2_active_dim; 48 | for (int index = 0; index < active_size; ++index) { 49 | scalar_t grad = d_output_ptr[shift_output + idx_output_ptr[index]] * multipliers_ptr[index]; 50 | d_X1_ptr[shift_X1 + idx_1_ptr[index]] += grad * X2_ptr[shift_X2 + idx_2_ptr[index]]; 51 | d_X2_ptr[shift_X2 + idx_2_ptr[index]] += grad * X1_ptr[shift_X1 + idx_1_ptr[index]]; 52 | } 53 | shift_output += output_active_dim; 54 | shift_X1 += X1_active_dim; 55 | shift_X2 += X2_active_dim; 56 | } 57 | } 58 | 59 | } 60 | 61 | 62 | template 63 | void _sparse_accumulation_active_dim_last_contiguous_forward( 64 | torch::Tensor output, 65 | torch::Tensor X1, 66 | torch::Tensor X2, 67 | torch::Tensor idx_output, 68 | int output_size, 69 | torch::Tensor idx_1, 70 | 
torch::Tensor idx_2, 71 | torch::Tensor multipliers) { 72 | 73 | scalar_t* X1_ptr = X1.data_ptr(); 74 | scalar_t* X2_ptr = X2.data_ptr(); 75 | scalar_t* output_ptr = output.data_ptr(); 76 | scalar_t* multipliers_ptr = multipliers.data_ptr(); 77 | long* idx_1_ptr = idx_1.data_ptr(); 78 | long* idx_2_ptr = idx_2.data_ptr(); 79 | long* idx_output_ptr = idx_output.data_ptr(); 80 | 81 | long active_size = idx_output.sizes()[0]; 82 | long first_size = X1.sizes()[0]; 83 | long second_size = X1.sizes()[1]; 84 | 85 | long output_active_dim = output_size; 86 | long X1_active_dim = X1.sizes()[2]; 87 | long X2_active_dim = X2.sizes()[2]; 88 | 89 | #pragma omp parallel for 90 | for (int index_first = 0; index_first < first_size; ++index_first){ 91 | // #pragma omp parallel for // This makes little difference 92 | for (int index_second = 0; index_second < second_size; ++index_second) { 93 | long shift_number = index_first * second_size + index_second; 94 | long shift_output = shift_number * output_active_dim; 95 | long shift_X1 = shift_number * X1_active_dim; 96 | long shift_X2 = shift_number * X2_active_dim; 97 | for (int index = 0; index < active_size; ++index) { 98 | output_ptr[shift_output + idx_output_ptr[index]] += multipliers_ptr[index] * X1_ptr[shift_X1 + idx_1_ptr[index]] * X2_ptr[shift_X2 + idx_2_ptr[index]]; 99 | } 100 | } 101 | } 102 | 103 | } 104 | 105 | 106 | template 107 | void _sparse_accumulation_active_dim_last_forward( 108 | torch::Tensor output, 109 | torch::Tensor X1, 110 | torch::Tensor X2, 111 | torch::Tensor idx_output, 112 | int output_size, 113 | torch::Tensor idx_1, 114 | torch::Tensor idx_2, 115 | torch::Tensor multipliers) { 116 | 117 | 118 | auto X1_a = X1.accessor(); 119 | auto X2_a = X2.accessor(); 120 | auto multipliers_a = multipliers.accessor(); 121 | 122 | auto output_a = output.accessor(); 123 | 124 | auto idx_1_a = idx_1.accessor(); 125 | auto idx_2_a = idx_2.accessor(); 126 | auto idx_output_a = idx_output.accessor(); 127 | 128 | for (int index_first = 0; index_first < output.size(0); ++index_first){ 129 | for (int index_second = 0; index_second < output.size(1); ++index_second) { 130 | for (int index = 0; index < idx_output_a.size(0); ++index) { 131 | auto first = X1_a[index_first][index_second][idx_1_a[index]]; 132 | auto second = X2_a[index_first][index_second][idx_2_a[index]]; 133 | auto third = multipliers_a[index]; 134 | 135 | auto contribution = first * second * third; 136 | output_a[index_first][index_second][idx_output_a[index]] += contribution; 137 | } 138 | } 139 | } 140 | 141 | } 142 | 143 | 144 | template 145 | void _sparse_accumulation_active_dim_last_backward( 146 | torch::Tensor d_X1, 147 | torch::Tensor d_X2, 148 | torch::Tensor d_output, 149 | torch::Tensor X1, 150 | torch::Tensor X2, 151 | torch::Tensor idx_output, 152 | torch::Tensor idx_1, 153 | torch::Tensor idx_2, 154 | torch::Tensor multipliers) { 155 | 156 | 157 | auto X1_a = X1.accessor(); 158 | auto X2_a = X2.accessor(); 159 | auto multipliers_a = multipliers.accessor(); 160 | 161 | auto d_output_a = d_output.accessor(); 162 | 163 | auto idx_1_a = idx_1.accessor(); 164 | auto idx_2_a = idx_2.accessor(); 165 | auto idx_output_a = idx_output.accessor(); 166 | 167 | auto d_X1_a = d_X1.accessor(); 168 | auto d_X2_a = d_X2.accessor(); 169 | 170 | for (int index_first = 0; index_first < d_output_a.size(0); ++index_first){ 171 | for (int index_second = 0; index_second < d_output_a.size(1); ++index_second) { 172 | for (int index = 0; index < idx_output_a.size(0); ++index) { 173 | auto from_X1 = 
X1_a[index_first][index_second][idx_1_a[index]]; 174 | auto from_X2 = X2_a[index_first][index_second][idx_2_a[index]]; 175 | auto multiplier = multipliers_a[index]; 176 | auto grad = d_output_a[index_first][index_second][idx_output_a[index]]; 177 | 178 | auto to_X1 = multiplier * grad * from_X2; 179 | auto to_X2 = multiplier * grad * from_X1; 180 | 181 | d_X1_a[index_first][index_second][idx_1_a[index]] += to_X1; 182 | d_X2_a[index_first][index_second][idx_2_a[index]] += to_X2; 183 | } 184 | } 185 | } 186 | 187 | } 188 | 189 | 190 | std::vector sparse_accumulation_active_dim_last_contiguous_backward( 191 | torch::Tensor d_output, 192 | torch::Tensor X1, 193 | torch::Tensor X2, 194 | torch::Tensor idx_output, 195 | torch::Tensor idx_1, 196 | torch::Tensor idx_2, 197 | torch::Tensor multipliers 198 | ) { 199 | 200 | auto d_X1 = torch::zeros_like(X1); 201 | auto d_X2 = torch::zeros_like(X2); 202 | 203 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_last_contiguous_backward", ([&] { 204 | _sparse_accumulation_active_dim_last_contiguous_backward( 205 | d_X1, 206 | d_X2, 207 | d_output, 208 | X1, 209 | X2, 210 | idx_output, 211 | idx_1, 212 | idx_2, 213 | multipliers 214 | ); 215 | })); 216 | 217 | return {d_X1, d_X2}; 218 | } 219 | 220 | 221 | torch::Tensor sparse_accumulation_active_dim_last_contiguous_forward( 222 | torch::Tensor X1, 223 | torch::Tensor X2, 224 | torch::Tensor idx_output, 225 | int output_size, 226 | torch::Tensor idx_1, 227 | torch::Tensor idx_2, 228 | torch::Tensor multipliers 229 | ) { 230 | 231 | auto output = torch::zeros({X1.sizes()[0], X1.sizes()[1], output_size}, X1.options()); 232 | 233 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_last_contiguous_forward", ([&] { 234 | _sparse_accumulation_active_dim_last_contiguous_forward( 235 | output, 236 | X1, 237 | X2, 238 | idx_output, 239 | output_size, 240 | idx_1, 241 | idx_2, 242 | multipliers 243 | ); 244 | })); 245 | 246 | return output; 247 | } 248 | 249 | 250 | std::vector sparse_accumulation_active_dim_last_backward( 251 | torch::Tensor d_output, 252 | torch::Tensor X1, 253 | torch::Tensor X2, 254 | torch::Tensor idx_output, 255 | torch::Tensor idx_1, 256 | torch::Tensor idx_2, 257 | torch::Tensor multipliers 258 | ) { 259 | 260 | auto d_X1 = torch::zeros_like(X1); 261 | auto d_X2 = torch::zeros_like(X2); 262 | 263 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_last_backward", ([&] { 264 | _sparse_accumulation_active_dim_last_backward( 265 | d_X1, 266 | d_X2, 267 | d_output, 268 | X1, 269 | X2, 270 | idx_output, 271 | idx_1, 272 | idx_2, 273 | multipliers 274 | ); 275 | })); 276 | 277 | return {d_X1, d_X2}; 278 | } 279 | 280 | 281 | torch::Tensor sparse_accumulation_active_dim_last_forward( 282 | torch::Tensor X1, 283 | torch::Tensor X2, 284 | torch::Tensor idx_output, 285 | int output_size, 286 | torch::Tensor idx_1, 287 | torch::Tensor idx_2, 288 | torch::Tensor multipliers 289 | ) { 290 | 291 | auto output = torch::zeros({X1.sizes()[0], X1.sizes()[1], output_size}, X1.options()); 292 | 293 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_last_forward", ([&] { 294 | _sparse_accumulation_active_dim_last_forward( 295 | output, 296 | X1, 297 | X2, 298 | idx_output, 299 | output_size, 300 | idx_1, 301 | idx_2, 302 | multipliers 303 | ); 304 | })); 305 | 306 | return output; 307 | } 308 | 309 | 310 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 311 | m.def("forward", &sparse_accumulation_active_dim_last_forward, "sparse 
accumulation active dim last forward"); 312 | m.def("forward_contiguous", &sparse_accumulation_active_dim_last_contiguous_forward, "sparse accumulation active dim last contiguous forward"); 313 | m.def("backward", &sparse_accumulation_active_dim_last_backward, "sparse accumulation active dim last backward"); 314 | m.def("backward_contiguous", &sparse_accumulation_active_dim_last_contiguous_backward, "sparse accumulation active dim last contiguous backward"); 315 | } 316 | 317 | 318 | -------------------------------------------------------------------------------- /sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_last.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sparse_accumulation_active_dim_last_cpp 3 | 4 | class SparseAccumulationActiveDimLast(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 7 | all_contiguous = X1.is_contiguous() and X2.is_contiguous() and idx_output.is_contiguous() and idx_1.is_contiguous() and idx_2.is_contiguous() and multipliers.is_contiguous() 8 | if all_contiguous: 9 | output = sparse_accumulation_active_dim_last_cpp.forward_contiguous(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 10 | else: 11 | output = sparse_accumulation_active_dim_last_cpp.forward(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 12 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 13 | return output 14 | 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 19 | all_contiguous = X1.is_contiguous() and X2.is_contiguous() and idx_output.is_contiguous() and idx_1.is_contiguous() and idx_2.is_contiguous() and multipliers.is_contiguous() 20 | if all_contiguous: 21 | d_X1, d_X2 = sparse_accumulation_active_dim_last_cpp.backward_contiguous(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 22 | else: 23 | d_X1, d_X2 = sparse_accumulation_active_dim_last_cpp.backward(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 24 | return d_X1, d_X2, None, None, None, None, None -------------------------------------------------------------------------------- /sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_middle.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | using namespace torch::indexing; 6 | 7 | 8 | template 9 | void _sparse_accumulation_active_dim_middle_contiguous_backward( 10 | torch::Tensor d_X1, 11 | torch::Tensor d_X2, 12 | torch::Tensor d_output, 13 | torch::Tensor X1, 14 | torch::Tensor X2, 15 | torch::Tensor idx_output, 16 | torch::Tensor idx_1, 17 | torch::Tensor idx_2, 18 | torch::Tensor multipliers) { 19 | 20 | scalar_t* d_X1_ptr = d_X1.data_ptr(); 21 | scalar_t* d_X2_ptr = d_X2.data_ptr(); 22 | scalar_t* d_output_ptr = d_output.data_ptr(); 23 | 24 | scalar_t* X1_ptr = X1.data_ptr(); 25 | scalar_t* X2_ptr = X2.data_ptr(); 26 | 27 | scalar_t* multipliers_ptr = multipliers.data_ptr(); 28 | long* idx_1_ptr = idx_1.data_ptr(); 29 | long* idx_2_ptr = idx_2.data_ptr(); 30 | long* idx_output_ptr = idx_output.data_ptr(); 31 | 32 | long active_size = idx_output.sizes()[0]; 33 | long first_size = X1.sizes()[0]; 34 | long second_size = X1.sizes()[2]; 35 | 36 | long output_active_dim = d_output.sizes()[1]; 37 | long X1_active_dim = X1.sizes()[1]; 38 | long X2_active_dim = X2.sizes()[1]; 39 | 40 | long 
X1_inner_dimensions = X1_active_dim * second_size; 41 | long X2_inner_dimensions = X2_active_dim * second_size; 42 | long output_inner_dimensions = output_active_dim * second_size; 43 | 44 | #pragma omp parallel for 45 | for (int index_first = 0; index_first < first_size; ++index_first) { 46 | 47 | long shift_X1_first = X1_inner_dimensions * index_first; 48 | long shift_X2_first = X2_inner_dimensions * index_first; 49 | long shift_output_first = output_inner_dimensions * index_first; 50 | 51 | for (int index = 0; index < active_size; ++index) { 52 | 53 | long shift_output = idx_output_ptr[index] * second_size + shift_output_first; 54 | long shift_X1 = idx_1_ptr[index] * second_size + shift_X1_first; 55 | long shift_X2 = idx_2_ptr[index] * second_size + shift_X2_first; 56 | 57 | scalar_t multiplier = multipliers_ptr[index]; 58 | 59 | // #pragma omp parallel for // In most cases, this slows the code down significantly 60 | for (int index_second = 0; index_second < second_size; ++index_second) { 61 | 62 | scalar_t grad = d_output_ptr[shift_output + index_second] * multiplier; 63 | d_X1_ptr[shift_X1 + index_second] += grad * X2_ptr[shift_X2 + index_second]; 64 | d_X2_ptr[shift_X2 + index_second] += grad * X1_ptr[shift_X1 + index_second]; 65 | } 66 | } 67 | } 68 | 69 | } 70 | 71 | 72 | template 73 | void _sparse_accumulation_active_dim_middle_contiguous_forward( 74 | torch::Tensor output, 75 | torch::Tensor X1, 76 | torch::Tensor X2, 77 | torch::Tensor idx_output, 78 | int output_size, 79 | torch::Tensor idx_1, 80 | torch::Tensor idx_2, 81 | torch::Tensor multipliers) { 82 | 83 | scalar_t* X1_ptr = X1.data_ptr(); 84 | scalar_t* X2_ptr = X2.data_ptr(); 85 | scalar_t* output_ptr = output.data_ptr(); 86 | scalar_t* multipliers_ptr = multipliers.data_ptr(); 87 | long* idx_1_ptr = idx_1.data_ptr(); 88 | long* idx_2_ptr = idx_2.data_ptr(); 89 | long* idx_output_ptr = idx_output.data_ptr(); 90 | 91 | long active_size = idx_output.sizes()[0]; 92 | long first_size = X1.sizes()[0]; 93 | long second_size = X1.sizes()[2]; 94 | 95 | long output_active_dim = output.sizes()[1]; 96 | long X1_active_dim = X1.sizes()[1]; 97 | long X2_active_dim = X2.sizes()[1]; 98 | 99 | long X1_inner_dimensions = X1_active_dim * second_size; 100 | long X2_inner_dimensions = X2_active_dim * second_size; 101 | long output_inner_dimensions = output_active_dim * second_size; 102 | 103 | #pragma omp parallel for 104 | for (int index_first = 0; index_first < first_size; ++index_first) { 105 | 106 | long shift_X1_first = X1_inner_dimensions * index_first; 107 | long shift_X2_first = X2_inner_dimensions * index_first; 108 | long shift_output_first = output_inner_dimensions * index_first; 109 | 110 | for (int index = 0; index < active_size; ++index) { 111 | 112 | long shift_output = idx_output_ptr[index] * second_size + shift_output_first; 113 | long shift_X1 = idx_1_ptr[index] * second_size + shift_X1_first; 114 | long shift_X2 = idx_2_ptr[index] * second_size + shift_X2_first; 115 | 116 | scalar_t multiplier = multipliers_ptr[index]; 117 | // #pragma omp parallel for // In most cases, this slows the code down significantly 118 | for (int index_second = 0; index_second < second_size; ++index_second) { 119 | output_ptr[shift_output + index_second] += X1_ptr[shift_X1 + index_second] * X2_ptr[shift_X2 + index_second] * multiplier; 120 | } 121 | } 122 | } 123 | 124 | } 125 | 126 | 127 | std::vector sparse_accumulation_active_dim_middle_contiguous_backward( 128 | torch::Tensor d_output, 129 | torch::Tensor X1, 130 | torch::Tensor X2, 131 | 
torch::Tensor idx_output, 132 | torch::Tensor idx_1, 133 | torch::Tensor idx_2, 134 | torch::Tensor multipliers 135 | ) { 136 | 137 | auto d_X1 = torch::zeros_like(X1); 138 | auto d_X2 = torch::zeros_like(X2); 139 | 140 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_middle_contiguous_backward", ([&] { 141 | _sparse_accumulation_active_dim_middle_contiguous_backward( 142 | d_X1, 143 | d_X2, 144 | d_output, 145 | X1, 146 | X2, 147 | idx_output, 148 | idx_1, 149 | idx_2, 150 | multipliers 151 | ); 152 | })); 153 | 154 | return {d_X1, d_X2}; 155 | } 156 | 157 | 158 | torch::Tensor sparse_accumulation_active_dim_middle_contiguous_forward( 159 | torch::Tensor X1, 160 | torch::Tensor X2, 161 | torch::Tensor idx_output, 162 | int output_size, 163 | torch::Tensor idx_1, 164 | torch::Tensor idx_2, 165 | torch::Tensor multipliers 166 | ) { 167 | 168 | auto output = torch::zeros({X1.sizes()[0], output_size, X1.sizes()[2]}, X1.options()); 169 | 170 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_active_dim_middle_contiguous_forward", ([&] { 171 | _sparse_accumulation_active_dim_middle_contiguous_forward( 172 | output, 173 | X1, 174 | X2, 175 | idx_output, 176 | output_size, 177 | idx_1, 178 | idx_2, 179 | multipliers 180 | ); 181 | })); 182 | 183 | return output; 184 | } 185 | 186 | 187 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 188 | m.def("forward_contiguous", &sparse_accumulation_active_dim_middle_contiguous_forward, "sparse accumulation active dim middle contiguous forward"); 189 | m.def("backward_contiguous", &sparse_accumulation_active_dim_middle_contiguous_backward, "sparse accumulation active dim middle contiguous backward"); 190 | } 191 | 192 | 193 | -------------------------------------------------------------------------------- /sparse_accumulation/cpu_extension/sparse_accumulation_active_dim_middle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sparse_accumulation_active_dim_middle_cpp 3 | 4 | class SparseAccumulationActiveDimMiddle(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 7 | all_contiguous = X1.is_contiguous() and X2.is_contiguous() and idx_output.is_contiguous() and idx_1.is_contiguous() and idx_2.is_contiguous() and multipliers.is_contiguous() 8 | if all_contiguous: 9 | output = sparse_accumulation_active_dim_middle_cpp.forward_contiguous(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 10 | else: 11 | output = sparse_accumulation_active_dim_middle_cpp.forward(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 12 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 13 | return output 14 | 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 19 | all_contiguous = X1.is_contiguous() and X2.is_contiguous() and idx_output.is_contiguous() and idx_1.is_contiguous() and idx_2.is_contiguous() and multipliers.is_contiguous() 20 | if all_contiguous: 21 | d_X1, d_X2 = sparse_accumulation_active_dim_middle_cpp.backward_contiguous(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 22 | else: 23 | d_X1, d_X2 = sparse_accumulation_active_dim_middle_cpp.backward(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 24 | return d_X1, d_X2, None, None, None, None, None -------------------------------------------------------------------------------- 
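A minimal usage sketch of the middle-active-dim wrapper defined above (illustrative only: the sizes and the random transformation rule below are made up; real rules are built from the Clebsch-Gordan utilities used in the tests at the end of this document):

import torch
from sparse_accumulation.cpu_extension import sparse_accumulation_active_dim_middle

# made-up sparse transformation rule with 12 non-zero entries; active size 5 on inputs and output
idx_1 = torch.randint(5, (12,))        # active-dim positions in X1
idx_2 = torch.randint(5, (12,))        # active-dim positions in X2
idx_output = torch.randint(5, (12,))   # active-dim positions in the output
multipliers = torch.randn(12)          # the C coefficients of the rule

# the active (sparse) dimension is the middle one: [batch, active, features]
X1 = torch.randn(32, 5, 16, requires_grad=True)
X2 = torch.randn(32, 5, 16, requires_grad=True)

output = sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply(
    X1, X2, idx_output, 5, idx_1, idx_2, multipliers)
output.sum().backward()  # d_X1 and d_X2 come from the contiguous C++ backward above

Note that only the contiguous kernels are exposed for the middle-active-dim extension, so all tensors are kept contiguous here.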
/sparse_accumulation/cuda_extension/jit.py: -------------------------------------------------------------------------------- 1 | from torch.utils.cpp_extension import load 2 | sparse_accumulation_cuda = load( 3 | 'sparse_accumulation_cuda', ['sparse_accumulation_cuda.cpp', 'sparse_accumulation_cuda_kernel.cu'], verbose=True) 4 | help(sparse_accumulation_cuda) -------------------------------------------------------------------------------- /sparse_accumulation/cuda_extension/sparse_accumulation_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | // CUDA forward declarations 6 | 7 | std::vector sparse_accumulation_cuda_forward(torch::Tensor X1, 8 | torch::Tensor X2, 9 | torch::Tensor idx_output, 10 | int64_t output_size, 11 | torch::Tensor idx_1, 12 | torch::Tensor idx_2, 13 | torch::Tensor multipliers); 14 | 15 | std::vector sparse_accumulation_cuda_backward( 16 | torch::Tensor d_output, 17 | torch::Tensor X1, 18 | torch::Tensor X2, 19 | torch::Tensor idx_output, 20 | torch::Tensor idx_1, 21 | torch::Tensor idx_2, 22 | torch::Tensor multipliers 23 | ); 24 | // C++ interface 25 | 26 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 27 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 28 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 29 | 30 | std::vector sparse_accumulation_gpu_forward( 31 | torch::Tensor X1, 32 | torch::Tensor X2, 33 | torch::Tensor idx_output, 34 | int64_t output_size, 35 | torch::Tensor idx_1, 36 | torch::Tensor idx_2, 37 | torch::Tensor multipliers){ 38 | 39 | CHECK_INPUT(X1); 40 | CHECK_INPUT(X2); 41 | CHECK_INPUT(idx_output); 42 | //CHECK_INPUT(output_size); 43 | CHECK_INPUT(idx_1); 44 | CHECK_INPUT(idx_2); 45 | CHECK_INPUT(multipliers); 46 | 47 | return sparse_accumulation_cuda_forward(X1,X2,idx_output,output_size,idx_1,idx_2,multipliers); 48 | } 49 | 50 | std::vector sparse_accumulation_gpu_backward( 51 | torch::Tensor d_output, 52 | torch::Tensor X1, 53 | torch::Tensor X2, 54 | torch::Tensor idx_output, 55 | torch::Tensor idx_1, 56 | torch::Tensor idx_2, 57 | torch::Tensor multipliers 58 | ) { 59 | CHECK_INPUT(d_output); 60 | CHECK_INPUT(X1); 61 | CHECK_INPUT(X2); 62 | CHECK_INPUT(idx_output); 63 | CHECK_INPUT(idx_1); 64 | CHECK_INPUT(idx_2 ); 65 | CHECK_INPUT(multipliers); 66 | 67 | return sparse_accumulation_cuda_backward(d_output,X1,X2,idx_output,idx_1,idx_2,multipliers); 68 | } 69 | 70 | //PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 71 | // m.def("forward", &sparse_accumulation_gpu_forward, "Sparse Accumulation forward (CUDA)"); 72 | // m.def("backward", &sparse_accumulation_gpu_forward, "Sparse Accumulation backward (CUDA)"); 73 | //} 74 | 75 | TORCH_LIBRARY(sparse_accumulation_cuda, m) { 76 | m.def("forward", sparse_accumulation_gpu_forward); 77 | m.def("backward", sparse_accumulation_gpu_backward); 78 | } -------------------------------------------------------------------------------- /sparse_accumulation/cuda_extension/sparse_accumulation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | using namespace torch::indexing; 8 | 9 | template 10 | __global__ void sparse_accumulation_cuda_forward_kernel( 11 | scalar_t* __restrict__ output, 12 | const scalar_t* __restrict__ X1, 13 | const scalar_t* __restrict__ X2, 14 | const int64_t* __restrict__ idx_output, 15 | const int64_t* __restrict__ 
idx_1, 16 | const int64_t* __restrict__ idx_2, 17 | const scalar_t* __restrict__ multipliers, 18 | const int output_size, 19 | const int X1_third_size, 20 | const int X2_third_size, 21 | const int nx, 22 | const int ny, 23 | const int nz ) { 24 | 25 | int i = threadIdx.x + blockDim.x * blockIdx.x ; 26 | int j = threadIdx.y + blockDim.y * blockIdx.y ; 27 | int z = threadIdx.z + blockDim.z * blockIdx.z ; 28 | 29 | //if (i 82 | __global__ void sparse_accumulation_cuda_backward_kernel( 83 | const scalar_t* __restrict__ X1, 84 | const scalar_t* __restrict__ X2, 85 | scalar_t* __restrict__ d_X1, 86 | scalar_t* __restrict__ d_X2 87 | ) { 88 | const int state_size = 100 ; 89 | const int column = blockIdx.x * blockDim.x + threadIdx.x; 90 | const int index = blockIdx.y * state_size + column; 91 | const int gates_row = blockIdx.y * (state_size * 3); 92 | if (column < state_size) { 93 | d_X1[index] = X1[column]; 94 | d_X2[index] = X2[column]; 95 | } 96 | } 97 | 98 | 99 | std::vector sparse_accumulation_cuda_forward( 100 | torch::Tensor X1, 101 | torch::Tensor X2, 102 | torch::Tensor idx_output, 103 | int output_size, 104 | torch::Tensor idx_1, 105 | torch::Tensor idx_2, 106 | torch::Tensor multipliers) 107 | { 108 | //auto output = torch::zeros_like(X1); 109 | auto output = torch::zeros({X1.sizes()[0], X1.sizes()[1], output_size}, 110 | torch::TensorOptions() 111 | .dtype(X1.dtype()) 112 | .device(X1.device())); 113 | 114 | auto X1_third_size = X1.sizes()[2]; 115 | auto X2_third_size = X2.sizes()[2]; 116 | const auto batch_sizex = output.sizes()[0]; 117 | const auto batch_sizey = output.sizes()[1]; 118 | const auto batch_sizez = idx_output.sizes()[0]; 119 | printf("idx_output.sizes()[0] %d \n",idx_output.sizes()[0]); 120 | 121 | auto nx = batch_sizex ; 122 | auto ny = batch_sizey ; 123 | auto nz = batch_sizez ; 124 | auto threads = 124; 125 | //const dim3 blocks((n+threads-1)/threads, batch_size); 126 | //auto blocks = (n+threads-1)/threads; 127 | 128 | //AT_DISPATCH_FLOATING_TYPES(output.type(), "sparse_accumulation_forward_cuda", ([&] { 129 | // sparse_accumulation_cuda_forward_kernel<<>>( 130 | // output.data(), 131 | // X1.data(), 132 | // n1, 133 | // n2, 134 | // ); 135 | //})); 136 | 137 | auto find_num_blocks = [](int x, int bdim) {return (x+bdim-1)/bdim;}; 138 | dim3 block_dim(16, 4,4); 139 | int nbx = find_num_blocks(nx, block_dim.x); 140 | int nby = find_num_blocks(ny, block_dim.y); 141 | int nbz = find_num_blocks(nz, block_dim.z); 142 | dim3 grid_dim(nbx, nby, nbz); 143 | 144 | AT_DISPATCH_FLOATING_TYPES(output.type(), "sparse_accumulation_forward_cuda", ([&] { 145 | sparse_accumulation_cuda_forward_kernel<<>>( 146 | output.data(), 147 | X1.data(), 148 | X2.data(), 149 | idx_output.data(), 150 | idx_1.data(), 151 | idx_2.data(), 152 | multipliers.data(), 153 | output_size, 154 | X1_third_size, 155 | X2_third_size, 156 | nx, 157 | ny, 158 | nz 159 | ); 160 | })); 161 | 162 | return {output}; 163 | } 164 | 165 | std::vector sparse_accumulation_cuda_backward( 166 | torch::Tensor d_output, 167 | torch::Tensor X1, 168 | torch::Tensor X2, 169 | torch::Tensor idx_output, 170 | torch::Tensor idx_1, 171 | torch::Tensor idx_2, 172 | torch::Tensor multipliers) 173 | { 174 | auto d_X1 = torch::zeros_like(X1); 175 | auto d_X2 = torch::zeros_like(X2); 176 | 177 | const auto batch_size = 2; 178 | const auto state_size = 1; 179 | 180 | const int threads = 1024; 181 | const dim3 blocks((state_size + threads - 1) / threads, batch_size); 182 | 183 | AT_DISPATCH_FLOATING_TYPES(d_X1.type(), 
"sparse_accumulation_backward_cuda", ([&] { 184 | sparse_accumulation_cuda_backward_kernel<<>>( 185 | X1.data(), 186 | X2.data(), 187 | d_X1.data(), 188 | d_X2.data() 189 | ); 190 | })); 191 | return {d_X1, d_X2}; 192 | 193 | } 194 | 195 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 196 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 197 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 198 | 199 | std::vector sparse_accumulation_gpu_forward( 200 | torch::Tensor X1, 201 | torch::Tensor X2, 202 | torch::Tensor idx_output, 203 | int64_t output_size, 204 | torch::Tensor idx_1, 205 | torch::Tensor idx_2, 206 | torch::Tensor multipliers){ 207 | 208 | CHECK_INPUT(X1); 209 | CHECK_INPUT(X2); 210 | CHECK_INPUT(idx_output); 211 | //CHECK_INPUT(output_size); 212 | CHECK_INPUT(idx_1); 213 | CHECK_INPUT(idx_2); 214 | CHECK_INPUT(multipliers); 215 | 216 | return sparse_accumulation_cuda_forward(X1,X2,idx_output,output_size,idx_1,idx_2,multipliers); 217 | } 218 | 219 | std::vector sparse_accumulation_gpu_backward( 220 | torch::Tensor d_output, 221 | torch::Tensor X1, 222 | torch::Tensor X2, 223 | torch::Tensor idx_output, 224 | torch::Tensor idx_1, 225 | torch::Tensor idx_2, 226 | torch::Tensor multipliers 227 | ) { 228 | CHECK_INPUT(d_output); 229 | CHECK_INPUT(X1); 230 | CHECK_INPUT(X2); 231 | CHECK_INPUT(idx_output); 232 | CHECK_INPUT(idx_1); 233 | CHECK_INPUT(idx_2 ); 234 | CHECK_INPUT(multipliers); 235 | 236 | return sparse_accumulation_cuda_backward(d_output,X1,X2,idx_output,idx_1,idx_2,multipliers); 237 | } 238 | 239 | //PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 240 | // m.def("forward", &sparse_accumulation_gpu_forward, "Sparse Accumulation forward (CUDA)"); 241 | // m.def("backward", &sparse_accumulation_gpu_forward, "Sparse Accumulation backward (CUDA)"); 242 | //} 243 | 244 | TORCH_LIBRARY(sparse_accumulation_cuda, m) { 245 | m.def("forward", sparse_accumulation_gpu_forward); 246 | m.def("backward", sparse_accumulation_gpu_backward); 247 | } -------------------------------------------------------------------------------- /sparse_accumulation/cuda_extension/sparse_accumulation_cuda_kernel2D.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | using namespace torch::indexing; 8 | 9 | #define BLOCK_SIZE 8 10 | 11 | template 12 | __global__ void sparse_accumulation_cuda_forward_kernel( 13 | scalar_t* __restrict__ output, 14 | const scalar_t* __restrict__ X1, 15 | const scalar_t* __restrict__ X2, 16 | const int64_t* __restrict__ idx_output, 17 | const int64_t* __restrict__ idx_1, 18 | const int64_t* __restrict__ idx_2, 19 | const scalar_t* __restrict__ multipliers, 20 | const int32_t output_size, 21 | const int32_t X1_third_size, 22 | const int32_t X2_third_size, 23 | const int32_t nx, 24 | const int32_t ny, 25 | const int32_t nz 26 | ) { 27 | extern __shared__ char buffer[]; 28 | // offset (in bytes) of the first available slot in the shared memory buffer 29 | size_t offset = 0; 30 | 31 | scalar_t* buffer_X1 = reinterpret_cast(buffer + offset); 32 | offset += BLOCK_SIZE * BLOCK_SIZE * X1_third_size * sizeof(scalar_t); 33 | 34 | scalar_t* buffer_X2 = reinterpret_cast(buffer + offset); 35 | offset += BLOCK_SIZE * BLOCK_SIZE * X2_third_size * sizeof(scalar_t); 36 | 37 | scalar_t* buffer_multipliers = reinterpret_cast(buffer + offset); 38 | offset += nz * sizeof(scalar_t); 39 | 40 | int32_t* buffer_idx_output = 
reinterpret_cast(buffer + offset); 41 | offset += nz * sizeof(int32_t); 42 | 43 | int32_t* buffer_idx_X1 = reinterpret_cast(buffer + offset); 44 | offset += nz * sizeof(int32_t); 45 | 46 | int32_t* buffer_idx_X2 = reinterpret_cast(buffer + offset); 47 | offset += nz * sizeof(int32_t); 48 | 49 | int i = threadIdx.x + blockDim.x * blockIdx.x; 50 | int j = threadIdx.y + blockDim.y * blockIdx.y; 51 | 52 | int single_multipliers_block_size = (nz / (BLOCK_SIZE * BLOCK_SIZE)) + 1; 53 | int total_thread_idx = threadIdx.x * BLOCK_SIZE + threadIdx.y; 54 | int multipliers_pos_from = total_thread_idx * single_multipliers_block_size; 55 | int multipliers_pos_to = (total_thread_idx + 1) * single_multipliers_block_size; 56 | if (multipliers_pos_to > nz) { 57 | multipliers_pos_to = nz; 58 | } 59 | 60 | int delta_now_X1 = j * X1_third_size + i * ny * X1_third_size; 61 | int delta_now_output = j * output_size + i * ny * output_size; 62 | int delta_now_X2 = j * X2_third_size + i * ny * X2_third_size; 63 | 64 | 65 | int delta_buffer_X1 = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * X1_third_size; 66 | int delta_buffer_X2 = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * X2_third_size; 67 | 68 | for (int active_index = multipliers_pos_from; active_index < multipliers_pos_to; ++active_index) { 69 | buffer_multipliers[active_index] = multipliers[active_index]; 70 | buffer_idx_output[active_index] = idx_output[active_index]; 71 | buffer_idx_X1[active_index] = idx_1[active_index]; 72 | buffer_idx_X2[active_index] = idx_2[active_index]; 73 | } 74 | 75 | scalar_t* buffer_X1_final = buffer_X1 + delta_buffer_X1; 76 | scalar_t* buffer_X2_final = buffer_X2 + delta_buffer_X2; 77 | 78 | auto output_final = output + delta_now_output; 79 | auto X1_final = X1 + delta_now_X1; 80 | auto X2_final = X2 + delta_now_X2; 81 | __syncthreads(); 82 | 83 | if (i < nx && j < ny) { 84 | 85 | for (int X1_index = 0; X1_index < X1_third_size; ++X1_index) { 86 | buffer_X1_final[X1_index] = X1_final[X1_index]; 87 | } 88 | 89 | for (int X2_index = 0; X2_index < X2_third_size; ++X2_index) { 90 | buffer_X2_final[X2_index] = X2_final[X2_index]; 91 | } 92 | 93 | int z_output, z_X1, z_X2; 94 | scalar_t now = 0; 95 | int z_old = 0; 96 | for (int z = 0 ; z < nz ; ++z){ 97 | z_output = buffer_idx_output[z]; 98 | if (z_old != z_output) { 99 | output_final[z_old] = now; 100 | now = 0; 101 | z_old = z_output; 102 | } 103 | z_X1 = buffer_idx_X1[z]; 104 | z_X2 = buffer_idx_X2[z]; 105 | now += buffer_X1_final[z_X1] 106 | * buffer_X2_final[z_X2] 107 | * buffer_multipliers[z]; 108 | }; 109 | output_final[z_old] = now; 110 | }; 111 | } 112 | 113 | template 114 | __global__ void sparse_accumulation_cuda_backward_kernel( 115 | scalar_t* __restrict__ d_X1, 116 | scalar_t* __restrict__ d_X2, 117 | const scalar_t* __restrict__ d_output, 118 | const scalar_t* __restrict__ X1, 119 | const scalar_t* __restrict__ X2, 120 | const int64_t* __restrict__ idx_output, 121 | const int64_t* __restrict__ idx_1, 122 | const int64_t* __restrict__ idx_2, 123 | const scalar_t* __restrict__ multipliers, 124 | const int output_size, 125 | const int X1_third_size, 126 | const int X2_third_size, 127 | const int nx, 128 | const int ny, 129 | const int nz 130 | ) { 131 | extern __shared__ char buffer[]; 132 | // offset (in bytes) of the first available slot in the shared memory buffer 133 | size_t offset = 0; 134 | scalar_t* buffer_output = reinterpret_cast(buffer + offset); 135 | offset += BLOCK_SIZE * BLOCK_SIZE * output_size * sizeof(scalar_t); 136 | 137 | scalar_t* buffer_X1 = 
reinterpret_cast(buffer + offset); 138 | offset += BLOCK_SIZE * BLOCK_SIZE * X1_third_size * sizeof(scalar_t); 139 | 140 | scalar_t* buffer_X2 = reinterpret_cast(buffer + offset); 141 | offset += BLOCK_SIZE * BLOCK_SIZE * X2_third_size * sizeof(scalar_t); 142 | 143 | scalar_t* buffer_d_X1 = reinterpret_cast(buffer + offset); 144 | offset += BLOCK_SIZE * BLOCK_SIZE * X1_third_size * sizeof(scalar_t); 145 | 146 | scalar_t* buffer_d_X2 = reinterpret_cast(buffer + offset); 147 | offset += BLOCK_SIZE * BLOCK_SIZE * X2_third_size * sizeof(scalar_t); 148 | 149 | scalar_t* buffer_multipliers = reinterpret_cast(buffer + offset); 150 | offset += nz * sizeof(scalar_t); 151 | 152 | int32_t* buffer_idx_output = reinterpret_cast(buffer + offset); 153 | offset += nz * sizeof(int32_t); 154 | 155 | int32_t* buffer_idx_X1 = reinterpret_cast(buffer + offset); 156 | offset += nz * sizeof(int32_t); 157 | 158 | int32_t* buffer_idx_X2 = reinterpret_cast(buffer + offset); 159 | offset += nz * sizeof(int32_t); 160 | 161 | int i = threadIdx.x + blockDim.x * blockIdx.x; 162 | int j = threadIdx.y + blockDim.y * blockIdx.y; 163 | 164 | int single_multipliers_block_size = (nz / (BLOCK_SIZE * BLOCK_SIZE)) + 1; 165 | int total_thread_idx = threadIdx.x * BLOCK_SIZE + threadIdx.y; 166 | int multipliers_pos_from = total_thread_idx * single_multipliers_block_size; 167 | int multipliers_pos_to = (total_thread_idx + 1) * single_multipliers_block_size; 168 | if (multipliers_pos_to > nz) { 169 | multipliers_pos_to = nz; 170 | } 171 | 172 | int delta_now_X1 = j * X1_third_size + i * ny * X1_third_size; 173 | int delta_now_output = j * output_size + i * ny * output_size; 174 | int delta_now_X2 = j * X2_third_size + i * ny * X2_third_size; 175 | //int delta_now_d_X1 = j * X1_third_size + i * ny * X1_third_size; 176 | //int delta_now_d_X2 = j * X2_third_size + i * ny * X2_third_size; 177 | 178 | int delta_buffer_output = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * output_size; 179 | int delta_buffer_X1 = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * X1_third_size; 180 | int delta_buffer_X2 = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * X2_third_size; 181 | //int delta_buffer_d_X1 = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * X1_third_size; 182 | //int delta_buffer_d_X2 = (BLOCK_SIZE * threadIdx.x + threadIdx.y) * X2_third_size; 183 | 184 | 185 | for (int active_index = multipliers_pos_from; active_index < multipliers_pos_to; ++active_index) { 186 | buffer_multipliers[active_index] = multipliers[active_index]; 187 | buffer_idx_output[active_index] = idx_output[active_index]; 188 | buffer_idx_X1[active_index] = idx_1[active_index]; 189 | buffer_idx_X2[active_index] = idx_2[active_index]; 190 | } 191 | scalar_t* buffer_output_final = buffer_output + delta_buffer_output; 192 | scalar_t* buffer_X1_final = buffer_X1 + delta_buffer_X1; 193 | scalar_t* buffer_X2_final = buffer_X2 + delta_buffer_X2; 194 | scalar_t* buffer_d_X1_final = buffer_d_X1 + delta_buffer_X1; 195 | scalar_t* buffer_d_X2_final = buffer_d_X2 + delta_buffer_X2; 196 | 197 | auto d_output_final = d_output + delta_now_output; 198 | auto X1_final = X1 + delta_now_X1; 199 | auto X2_final = X2 + delta_now_X2; 200 | auto d_X1_final = d_X1 + delta_now_X1; 201 | auto d_X2_final = d_X2 + delta_now_X2; 202 | __syncthreads(); 203 | 204 | if (i < nx && j < ny) { 205 | //printf("in kernel i %d j %d\n",i,j) ; 206 | for (int z_output = 0; z_output < output_size; ++z_output) { 207 | buffer_output_final[z_output] = d_output_final[z_output]; 208 | } 209 | 210 | for (int X1_index = 0; X1_index < 
X1_third_size; ++X1_index) { 211 | buffer_X1_final[X1_index] = X1_final[X1_index]; 212 | buffer_d_X1_final[X1_index] = 0; 213 | } 214 | 215 | for (int X2_index = 0; X2_index < X2_third_size; ++X2_index) { 216 | buffer_X2_final[X2_index] = X2_final[X2_index]; 217 | buffer_d_X2_final[X2_index] = 0; 218 | } 219 | 220 | int z_output, z_X1, z_X2; 221 | // scalar_t now = 0; 222 | // int z_old = 0; 223 | scalar_t grad_multi; 224 | for (int z = 0 ; z < nz ; ++z){ 225 | z_output = buffer_idx_output[z]; 226 | z_X1 = buffer_idx_X1[z]; 227 | z_X2 = buffer_idx_X2[z]; 228 | grad_multi = buffer_output_final[z_output] * buffer_multipliers[z]; 229 | buffer_d_X1_final[z_X1] += grad_multi * buffer_X2_final[z_X2]; 230 | 231 | buffer_d_X2_final[z_X2] += grad_multi * buffer_X1_final[z_X1]; 232 | }; 233 | for (int z = 0; z < X1_third_size; ++z) { 234 | d_X1_final[z] = buffer_d_X1_final[z]; 235 | } 236 | 237 | for (int z = 0; z < X2_third_size; ++z) { 238 | d_X2_final[z] = buffer_d_X2_final[z]; 239 | } 240 | }; 241 | } 242 | 243 | 244 | /* Forward propagation function 245 | * It computes the sparse accumulation assuming that the last 246 | * dimension is the active (sparse) dimension. 247 | * 248 | * The operation can be summarized by this equation: 249 | * 250 | * output_tensor[isample,ifeatures,idx_output[m]] += X1[isample,ifeatures,idx_1[m]] * X2[isample,ifeatures,idx_2[m]] * multipliers[m] 251 | * 252 | * input: 253 | * X1[nsample,nfeatures,m1]: first input tensor 254 | * X2[nsample,nfeatures,m2]: second input tensor 255 | * idx_output: tensor with the indices into the third dimension of output_tensor 256 | * output_size: size of the third dimension of the output tensor 257 | * idx_1: tensor with the indices into the third dimension of X1 258 | * idx_2: tensor with the indices into the third dimension of X2 259 | * multipliers: tensor containing the multipliers for the sparse accumulation 260 | * output: 261 | * output_tensor[nsample,nfeatures,output_size]: output tensor 262 | */ 263 | std::vector<torch::Tensor> sparse_accumulation_cuda_forward( 264 | torch::Tensor X1, 265 | torch::Tensor X2, 266 | torch::Tensor idx_output, 267 | int output_size, 268 | torch::Tensor idx_1, 269 | torch::Tensor idx_2, 270 | torch::Tensor multipliers 271 | ) { 272 | auto output = torch::zeros({X1.sizes()[0], X1.sizes()[1], output_size}, 273 | torch::TensorOptions() 274 | .dtype(X1.dtype()) 275 | .device(X1.device())); 276 | 277 | auto X1_third_size = X1.sizes()[2]; 278 | auto X2_third_size = X2.sizes()[2]; 279 | const auto batch_sizex = output.sizes()[0]; 280 | const auto batch_sizey = output.sizes()[1]; 281 | const auto batch_sizez = idx_output.sizes()[0]; 282 | 283 | auto nx = batch_sizex ; 284 | auto ny = batch_sizey ; 285 | auto nz = batch_sizez ; 286 | 287 | auto find_num_blocks = [](int x, int bdim) {return (x+bdim-1)/bdim;}; 288 | dim3 block_dim(BLOCK_SIZE, BLOCK_SIZE); 289 | int nbx = find_num_blocks(nx, block_dim.x); 290 | int nby = find_num_blocks(ny, block_dim.y); 291 | int nbz = find_num_blocks(nz, block_dim.z); 292 | dim3 grid_dim(nbx, nby); 293 | 294 | 295 | AT_DISPATCH_FLOATING_TYPES(output.type(), "sparse_accumulation_forward_cuda", ([&] { 296 | size_t X1_buf_size = BLOCK_SIZE * BLOCK_SIZE * X1_third_size * sizeof(scalar_t); 297 | size_t X2_buf_size = BLOCK_SIZE * BLOCK_SIZE * X2_third_size * sizeof(scalar_t); 298 | size_t multipliers_size = multipliers.sizes()[0] * sizeof(scalar_t); 299 | size_t index_size = idx_output.sizes()[0] * sizeof(int32_t); 300 | 301 | size_t total_buf_size = X1_buf_size + X2_buf_size + multipliers_size + index_size 
* 3; 302 | 303 | sparse_accumulation_cuda_forward_kernel<<>>( 304 | output.data_ptr(), 305 | X1.data_ptr(), 306 | X2.data_ptr(), 307 | idx_output.data_ptr(), 308 | idx_1.data_ptr(), 309 | idx_2.data_ptr(), 310 | multipliers.data_ptr(), 311 | output_size, 312 | X1_third_size, 313 | X2_third_size, 314 | nx, 315 | ny, 316 | nz 317 | ); 318 | })); 319 | 320 | return {output}; 321 | } 322 | 323 | std::vector sparse_accumulation_cuda_backward( 324 | torch::Tensor d_output, 325 | torch::Tensor X1, 326 | torch::Tensor X2, 327 | torch::Tensor idx_output, 328 | torch::Tensor idx_1, 329 | torch::Tensor idx_2, 330 | torch::Tensor multipliers 331 | ) { 332 | auto d_X1 = torch::zeros_like(X1); 333 | auto d_X2 = torch::zeros_like(X2); 334 | 335 | auto X1_third_size = X1.sizes()[2]; 336 | auto X2_third_size = X2.sizes()[2]; 337 | const auto nx = d_output.sizes()[0] ; 338 | const auto ny = d_output.sizes()[1] ; 339 | const auto output_size = d_output.sizes()[2] ; 340 | const auto nz = idx_output.sizes()[0]; 341 | 342 | auto find_num_blocks = [](int x, int bdim) {return (x+bdim-1)/bdim;}; 343 | dim3 block_dim(BLOCK_SIZE, BLOCK_SIZE); 344 | int nbx = find_num_blocks(nx, block_dim.x); 345 | int nby = find_num_blocks(ny, block_dim.y); 346 | dim3 grid_dim(nbx, nby); 347 | 348 | AT_DISPATCH_FLOATING_TYPES(X1.type(), "sparse_accumulation_backward_cuda", ([&] { 349 | size_t output_buf_size = BLOCK_SIZE * BLOCK_SIZE * output_size * sizeof(scalar_t); 350 | size_t X1_buf_size = BLOCK_SIZE * BLOCK_SIZE * X1_third_size * sizeof(scalar_t); 351 | size_t X2_buf_size = BLOCK_SIZE * BLOCK_SIZE * X2_third_size * sizeof(scalar_t); 352 | size_t multipliers_size = multipliers.sizes()[0] * sizeof(scalar_t); 353 | size_t index_size = idx_output.sizes()[0] * sizeof(int32_t); 354 | 355 | size_t total_buf_size = output_buf_size + 2*X1_buf_size + 2*X2_buf_size + multipliers_size + index_size * 3; 356 | sparse_accumulation_cuda_backward_kernel<<>>( 357 | d_X1.data_ptr(), 358 | d_X2.data_ptr(), 359 | d_output.data_ptr(), 360 | X1.data_ptr(), 361 | X2.data_ptr(), 362 | idx_output.data_ptr(), 363 | idx_1.data_ptr(), 364 | idx_2.data_ptr(), 365 | multipliers.data_ptr(), 366 | output_size, 367 | X1_third_size, 368 | X2_third_size, 369 | nx, 370 | ny, 371 | nz 372 | ); 373 | })); 374 | 375 | return {d_X1, d_X2}; 376 | } 377 | 378 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 379 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 380 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 381 | 382 | std::vector sparse_accumulation_gpu_forward( 383 | torch::Tensor X1, 384 | torch::Tensor X2, 385 | torch::Tensor idx_output, 386 | int64_t output_size, 387 | torch::Tensor idx_1, 388 | torch::Tensor idx_2, 389 | torch::Tensor multipliers 390 | ) { 391 | CHECK_INPUT(X1); 392 | CHECK_INPUT(X2); 393 | CHECK_INPUT(idx_output); 394 | CHECK_INPUT(idx_1); 395 | CHECK_INPUT(idx_2); 396 | CHECK_INPUT(multipliers); 397 | 398 | return sparse_accumulation_cuda_forward(X1,X2,idx_output,output_size,idx_1,idx_2,multipliers); 399 | } 400 | 401 | std::vector sparse_accumulation_gpu_backward( 402 | torch::Tensor d_output, 403 | torch::Tensor X1, 404 | torch::Tensor X2, 405 | torch::Tensor idx_output, 406 | torch::Tensor idx_1, 407 | torch::Tensor idx_2, 408 | torch::Tensor multipliers 409 | ) { 410 | CHECK_INPUT(d_output); 411 | CHECK_INPUT(X1); 412 | CHECK_INPUT(X2); 413 | CHECK_INPUT(idx_output); 414 | CHECK_INPUT(idx_1); 415 | CHECK_INPUT(idx_2 ); 416 | CHECK_INPUT(multipliers); 
417 | 418 | return sparse_accumulation_cuda_backward(d_output,X1,X2,idx_output,idx_1,idx_2,multipliers); 419 | } 420 | 421 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 422 | m.def("forward", &sparse_accumulation_gpu_forward, "Sparse Accumulation forward (CUDA)"); 423 | m.def("backward", &sparse_accumulation_gpu_backward, "Sparse Accumulation backward (CUDA)"); 424 | } 425 | -------------------------------------------------------------------------------- /sparse_accumulation/other_operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sparse_accumulation_active_dim_first_cpp, sparse_accumulation_active_dim_middle_cpp 3 | 4 | 5 | def check_all_contiguous(tensors): 6 | for tensor in tensors: 7 | if not tensor.is_contiguous(): 8 | raise ValueError("all the tensors must be contiguous") 9 | 10 | 11 | def check_all_on_cpu(tensors): 12 | for tensor in tensors: 13 | if str(tensor.device) != 'cpu': 14 | raise ValueError("all the tensors must be on cpu") 15 | 16 | 17 | def check_all_on_cuda(tensors): 18 | for tensor in tensors: 19 | if not tensor.is_cuda: 20 | raise ValueError("all the tensors must be on cuda gpu") 21 | 22 | 23 | def check_all_on_same_device(tensors): 24 | if len(tensors) == 0: 25 | return 26 | device = tensors[0].get_device() 27 | for tensor in tensors: 28 | if tensor.get_device() != device: 29 | raise ValueError("all the tensors must be on the same device") 30 | 31 | def accumulate_active_dim_middle(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 32 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 33 | check_all_on_same_device(tensors) 34 | 35 | if X1.is_cuda: 36 | raise NotImplementedError("active dimensions other than the last do not have a CUDA implementation") 37 | else: 38 | return SparseAccumulationCPUMiddle.apply(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 39 | 40 | 41 | class SparseAccumulationCPUMiddle(torch.autograd.Function): 42 | @staticmethod 43 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 44 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 45 | check_all_on_cpu(tensors) 46 | 47 | # we have implementation for non-contiguous arrays, but it is so slow that I (SP) think that 48 | # it is better to force the user to make the tensors contiguous 49 | check_all_contiguous(tensors) 50 | 51 | output = sparse_accumulation_active_dim_middle_cpp.forward_contiguous(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 52 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 53 | return output 54 | 55 | 56 | @staticmethod 57 | def backward(ctx, grad_output): 58 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 59 | if idx_output.requires_grad or idx_1.requires_grad or idx_2.requires_grad: 60 | raise ValueError("can not compute gradients with respect to tensors with integers") 61 | 62 | if multipliers.requires_grad: 63 | raise ValueError("gradients with respect to multipliers (tensor named C in the documentastion) are not supported") 64 | 65 | check_all_on_cpu(ctx.saved_tensors) 66 | 67 | # we have implementation for non-contiguous arrays, but it is so slow that I (SP) think that 68 | # it is better to force the user to make the tensors contiguous 69 | check_all_contiguous(ctx.saved_tensors) 70 | 71 | d_X1, d_X2 = sparse_accumulation_active_dim_middle_cpp.backward_contiguous(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 72 | 73 | if not X1.requires_grad: 74 | d_X1 = None 75 
| if not X2.requires_grad: 76 | d_X2 = None 77 | 78 | return d_X1, d_X2, None, None, None, None, None 79 | 80 | 81 | def accumulate_active_dim_first(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 82 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 83 | check_all_on_same_device(tensors) 84 | 85 | if X1.is_cuda: 86 | raise NotImplementedError("active dimensions other than the last do not have a CUDA implementation") 87 | else: 88 | return SparseAccumulationCPUFirst.apply(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 89 | 90 | 91 | class SparseAccumulationCPUFirst(torch.autograd.Function): 92 | @staticmethod 93 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 94 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 95 | check_all_on_cpu(tensors) 96 | 97 | # we have implementation for non-contiguous arrays, but it is so slow that I (SP) think that 98 | # it is better to force the user to make the tensors contiguous 99 | check_all_contiguous(tensors) 100 | 101 | output = sparse_accumulation_active_dim_first_cpp.forward_contiguous(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 102 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 103 | return output 104 | 105 | 106 | @staticmethod 107 | def backward(ctx, grad_output): 108 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 109 | if idx_output.requires_grad or idx_1.requires_grad or idx_2.requires_grad: 110 | raise ValueError("can not compute gradients with respect to tensors with integers") 111 | 112 | if multipliers.requires_grad: 113 | raise ValueError("gradients with respect to multipliers (tensor named C in the documentastion) are not supported") 114 | 115 | check_all_on_cpu(ctx.saved_tensors) 116 | 117 | # we have implementation for non-contiguous arrays, but it is so slow that I (SP) think that 118 | # it is better to force the user to make the tensors contiguous 119 | check_all_contiguous(ctx.saved_tensors) 120 | 121 | d_X1, d_X2 = sparse_accumulation_active_dim_first_cpp.backward_contiguous(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 122 | 123 | if not X1.requires_grad: 124 | d_X1 = None 125 | if not X2.requires_grad: 126 | d_X2 = None 127 | 128 | return d_X1, d_X2, None, None, None, None, None 129 | -------------------------------------------------------------------------------- /sparse_accumulation/reference_implementations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def sparse_accumulation_loops(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers, active_dim): 4 | device = X1.device #all tensors must be on the same device and blah, blah, blah 5 | dtype = X1.dtype 6 | 7 | if active_dim == 0: 8 | output = torch.zeros([output_size, X1.shape[1], X2.shape[2]], device = device,dtype=dtype) 9 | for index in range(idx_output.shape[0]): 10 | output[idx_output[index], :, :] += X1[idx_1[index], :, :] * X2[idx_2[index], :, :] * multipliers[index] 11 | return output 12 | 13 | if active_dim == 1: 14 | output = torch.zeros([X1.shape[0], output_size, X2.shape[2]], device = device,dtype=dtype) 15 | for index in range(idx_output.shape[0]): 16 | output[:, idx_output[index], :] += X1[:, idx_1[index], :] * X2[:, idx_2[index], :] * multipliers[index] 17 | return output 18 | 19 | if active_dim == 2: 20 | output = torch.zeros([X1.shape[0], X2.shape[1], output_size], device = device,dtype=dtype) 21 | for index in range(idx_output.shape[0]): 22 | 
output[:, :, idx_output[index]] += X1[:, :, idx_1[index]] * X2[:, :, idx_2[index]] * multipliers[index] 23 | return output 24 | 25 | raise ValueError("active dim should be one of 0, 1, 2") 26 | 27 | 28 | def sparse_accumulation_index_add(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers, active_dim): 29 | device = X1.device #all tensors must be on the same device and blah, blah, blah 30 | dtype = X1.dtype 31 | 32 | if active_dim == 0: 33 | contributions = X1[idx_1, :, :] * X2[idx_2, :, :] * multipliers[:, None, None] 34 | output = torch.zeros([output_size, X1.shape[1], X2.shape[2]], device = device,dtype=dtype) 35 | output.index_add_(0, idx_output, contributions) 36 | return output 37 | 38 | if active_dim == 1: 39 | contributions = X1[:, idx_1, :] * X2[:, idx_2, :] * multipliers[None, :, None] 40 | output = torch.zeros([X1.shape[0], output_size, X2.shape[2]], device = device,dtype=dtype) 41 | output.index_add_(1, idx_output, contributions) 42 | return output 43 | 44 | if active_dim == 2: 45 | contributions = X1[:, :, idx_1] * X2[:, :, idx_2] * multipliers[None, None, :] 46 | output = torch.zeros([X1.shape[0], X2.shape[1], output_size], device = device,dtype=dtype) 47 | output.index_add_(2, idx_output, contributions) 48 | return output 49 | 50 | raise ValueError("active dim should be one of 0, 1, 2") 51 | 52 | def get_transformation(idx_output, output_size, X1_size, X2_size, idx_1, idx_2, multipliers): 53 | transformation = torch.zeros([X1_size, X2_size, output_size]) 54 | for index in range(idx_output.shape[0]): 55 | transformation[idx_1[index], idx_2[index], idx_output[index]] = multipliers[index] 56 | transformation = transformation.reshape([-1, transformation.shape[2]]) 57 | return transformation 58 | 59 | 60 | def get_transformation_sparse(idx_output, output_size, X1_size, X2_size, idx_1, idx_2, multipliers): 61 | i_idx, j_idx, values = [], [], [] 62 | for index in range(idx_output.shape[0]): 63 | i_idx.append(idx_1[index] * X2_size + idx_2[index]) 64 | j_idx.append(idx_output[index]) 65 | values.append(multipliers[index]) 66 | transformation_sparse = torch.sparse_coo_tensor([j_idx, i_idx], values, (output_size, X1_size * X2_size)) 67 | return transformation_sparse 68 | 69 | 70 | def sparse_accumulation_matrix_multiply(X1, X2, transformation): 71 | 72 | initial_0 = X1.shape[0] 73 | initial_1 = X1.shape[1] 74 | X = X1[:, :, :, None] * X2[:, :, None, :] #..., m1, m2 75 | first_dim = X.shape[0] * X.shape[1] 76 | second_dim = X.shape[2] * X.shape[3] 77 | X = X.reshape([first_dim, second_dim]) 78 | output = torch.matmul(X, transformation) 79 | output = output.reshape([initial_0, initial_1, -1]) 80 | return output 81 | 82 | 83 | def sparse_accumulation_sparse_matrix_multiply(X1, X2, transformation_sparse): 84 | 85 | initial_0 = X1.shape[0] 86 | initial_1 = X1.shape[1] 87 | X = X1[:, :, :, None] * X2[:, :, None, :] #..., m1, m2 88 | first_dim = X.shape[0] * X.shape[1] 89 | second_dim = X.shape[2] * X.shape[3] 90 | X = X.reshape([first_dim, second_dim]) 91 | second = X.T 92 | output = torch.matmul(transformation_sparse, second) 93 | output = output.T 94 | output = output.reshape([initial_0, initial_1, -1]) 95 | return output 96 | 97 | def sparse_accumulation_sparse_matrix_multiply_optimized(X1, X2, transformation_sparse): 98 | #initial_1 = X1.shape[1] 99 | #initial_2 = X1.shape[2] 100 | 101 | X = X1[:, None, :, :] * X2[None, :, :, :] 102 | first_dim = X.shape[0] * X.shape[1] 103 | second_dim = X.shape[2] * X.shape[3] 104 | X = X.reshape([first_dim, second_dim]) 105 | output = 
torch.matmul(transformation_sparse, X) 106 | return output -------------------------------------------------------------------------------- /sparse_accumulation/unified_operation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sparse_accumulation_active_dim_last_cpp 3 | try: 4 | import sparse_accumulation_cuda 5 | except ModuleNotFoundError: 6 | pass # the cuda version is not installed 7 | 8 | def check_all_contiguous(tensors): 9 | for tensor in tensors: 10 | if not tensor.is_contiguous(): 11 | raise ValueError("all the tensors must be contiguous") 12 | 13 | 14 | def check_all_on_cpu(tensors): 15 | for tensor in tensors: 16 | if str(tensor.device) != 'cpu': 17 | raise ValueError("all the tensors must be on cpu") 18 | 19 | 20 | def check_all_on_cuda(tensors): 21 | for tensor in tensors: 22 | if not tensor.is_cuda: 23 | raise ValueError("all the tensors must be on cuda gpu") 24 | 25 | 26 | def check_all_on_same_device(tensors): 27 | if len(tensors) == 0: 28 | return 29 | device = tensors[0].get_device() 30 | for tensor in tensors: 31 | if tensor.get_device() != device: 32 | raise ValueError("all the tensors must be on the same device") 33 | 34 | def accumulate(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 35 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 36 | check_all_on_same_device(tensors) 37 | 38 | if X1.is_cuda: 39 | return SparseAccumulationCUDA.apply(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 40 | else: 41 | return SparseAccumulationCPU.apply(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 42 | 43 | 44 | class SparseAccumulationCUDA(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 47 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 48 | check_all_on_cuda(tensors) 49 | check_all_on_same_device(tensors) 50 | check_all_contiguous(tensors) 51 | 52 | output = sparse_accumulation_cuda.forward(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers)[0] 53 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 54 | return output 55 | 56 | 57 | @staticmethod 58 | def backward(ctx, grad_output): 59 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 60 | if idx_output.requires_grad or idx_1.requires_grad or idx_2.requires_grad: 61 | raise ValueError("can not compute gradients with respect to tensors with integers") 62 | 63 | if multipliers.requires_grad: 64 | raise ValueError("gradients with respect to multipliers (tensor named C in the documentastion) are not supported") 65 | 66 | check_all_on_cuda(ctx.saved_tensors) 67 | check_all_contiguous(ctx.saved_tensors) 68 | check_all_on_same_device(ctx.saved_tensors) 69 | 70 | 71 | d_X1, d_X2 = sparse_accumulation_cuda.backward(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 72 | 73 | if not X1.requires_grad: 74 | d_X1 = None 75 | if not X2.requires_grad: 76 | d_X2 = None 77 | 78 | return d_X1, d_X2, None, None, None, None, None 79 | 80 | 81 | class SparseAccumulationCPU(torch.autograd.Function): 82 | @staticmethod 83 | def forward(ctx, X1, X2, idx_output, output_size, idx_1, idx_2, multipliers): 84 | tensors = [X1, X2, idx_output, idx_1, idx_2, multipliers] 85 | check_all_on_cpu(tensors) 86 | 87 | # we have implementation for non-contiguous arrays, but it is so slow that I (SP) think that 88 | # it is better to force the user to make the tensors contiguous 89 | check_all_contiguous(tensors) 90 
| 91 | output = sparse_accumulation_active_dim_last_cpp.forward_contiguous(X1, X2, idx_output, output_size, idx_1, idx_2, multipliers) 92 | ctx.save_for_backward(*[X1, X2, idx_output, idx_1, idx_2, multipliers]) 93 | return output 94 | 95 | 96 | @staticmethod 97 | def backward(ctx, grad_output): 98 | X1, X2, idx_output, idx_1, idx_2, multipliers = ctx.saved_tensors 99 | if idx_output.requires_grad or idx_1.requires_grad or idx_2.requires_grad: 100 | raise ValueError("cannot compute gradients with respect to the integer index tensors") 101 | 102 | if multipliers.requires_grad: 103 | raise ValueError("gradients with respect to multipliers (tensor named C in the documentation) are not supported") 104 | 105 | check_all_on_cpu(ctx.saved_tensors) 106 | 107 | # we have an implementation for non-contiguous arrays, but it is so slow that I (SP) think that 108 | # it is better to force the user to make the tensors contiguous 109 | check_all_contiguous(ctx.saved_tensors) 110 | 111 | d_X1, d_X2 = sparse_accumulation_active_dim_last_cpp.backward_contiguous(grad_output, X1, X2, idx_output, idx_1, idx_2, multipliers) 112 | 113 | if not X1.requires_grad: 114 | d_X1 = None 115 | if not X2.requires_grad: 116 | d_X2 = None 117 | 118 | return d_X1, d_X2, None, None, None, None, None 119 | -------------------------------------------------------------------------------- /tests/test_cpp_contiguous.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sparse_accumulation.clebsch_gordan import get_real_clebsch_gordan, ClebschGordan 3 | from sparse_accumulation.reference_implementations import sparse_accumulation_loops 4 | from sparse_accumulation.cpu_extension import sparse_accumulation_active_dim_first, sparse_accumulation_active_dim_middle 5 | from sparse_accumulation import accumulate 6 | 7 | import numpy as np 8 | 9 | def get_rule(L_MAX): 10 | clebsch = ClebschGordan(L_MAX).precomputed_ 11 | indices = get_real_clebsch_gordan(clebsch[L_MAX, L_MAX, L_MAX], L_MAX, L_MAX, L_MAX) 12 | 13 | m1_aligned, m2_aligned = [], [] 14 | multipliers, mu_aligned = [], [] 15 | for mu in range(0, 2 * L_MAX + 1): 16 | for el in indices[mu]: 17 | m1, m2, multiplier = el 18 | m1_aligned.append(m1) 19 | m2_aligned.append(m2) 20 | multipliers.append(multiplier) 21 | mu_aligned.append(mu) 22 | m1_aligned = torch.LongTensor(m1_aligned) 23 | m2_aligned = torch.LongTensor(m2_aligned) 24 | mu_aligned = torch.LongTensor(mu_aligned) 25 | multipliers = torch.FloatTensor(multipliers) 26 | 27 | indices = np.argsort(mu_aligned) 28 | 29 | m1_aligned = m1_aligned[indices] 30 | m2_aligned = m2_aligned[indices] 31 | mu_aligned = mu_aligned[indices] 32 | multipliers = multipliers[indices] 33 | 34 | return m1_aligned, m2_aligned, mu_aligned, multipliers 35 | 36 | def test_forward(epsilon = 1e-7): 37 | print("Testing forward pass with the active dimension being the last one") 38 | L_MAX = 5 39 | BATCH_SIZE = 1000 40 | N_FEATURES = 100 41 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L_MAX) 42 | 43 | 44 | X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1) 45 | X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1) 46 | 47 | 48 | python_loops_output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers, 49 | active_dim = 2) 50 | cpp_output = accumulate(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 51 | delta = python_loops_output - cpp_output 52 | 53 | relative_error = torch.mean(torch.abs(delta)) / 
torch.mean(torch.abs(python_loops_output)) 54 | assert relative_error < epsilon 55 | 56 | 57 | def test_forward_active_dim_first(epsilon = 1e-7): 58 | print("Testing forward pass with the active dimension being the first one") 59 | L_MAX = 5 60 | BATCH_SIZE = 1000 61 | N_FEATURES = 100 62 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L_MAX) 63 | 64 | 65 | X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES) 66 | X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES) 67 | 68 | 69 | python_loops_output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers, 70 | active_dim = 0) 71 | cpp_output = sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply(X1, X2, mu_aligned, 72 | 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 73 | delta = python_loops_output - cpp_output 74 | 75 | relative_error = torch.mean(torch.abs(delta)) / torch.mean(torch.abs(python_loops_output)) 76 | assert relative_error < epsilon 77 | 78 | def test_forward_active_dim_middle(epsilon = 1e-7): 79 | print("Testing forward pass with the active dimension being the middle one") 80 | L_MAX = 5 81 | BATCH_SIZE = 1000 82 | N_FEATURES = 100 83 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L_MAX) 84 | 85 | 86 | X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES) 87 | X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES) 88 | 89 | 90 | python_loops_output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers, 91 | active_dim = 1) 92 | cpp_output = sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply(X1, X2, mu_aligned, 93 | 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 94 | delta = python_loops_output - cpp_output 95 | 96 | relative_error = torch.mean(torch.abs(delta)) / torch.mean(torch.abs(python_loops_output)) 97 | assert relative_error < epsilon 98 | 99 | 100 | 101 | def get_relative_error(first, second): 102 | delta = first - second 103 | return torch.sum(torch.abs(delta)) / torch.sum(torch.abs(first)) 104 | 105 | 106 | def test_backward(epsilon = 1e-7): 107 | print("Testing backward pass with the active dimension being the last one") 108 | L_MAX = 5 109 | BATCH_SIZE = 1000 110 | N_FEATURES = 100 111 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L_MAX) 112 | 113 | X1 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1) 114 | X2 = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1) 115 | 116 | X1.requires_grad = True 117 | X2.requires_grad = True 118 | python_loops_output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers, 119 | active_dim = 2) 120 | output_grad = torch.randn(*python_loops_output.shape) 121 | python_loops_output.backward(gradient = output_grad) 122 | 123 | X1_grad_python_loops = torch.detach(torch.clone(X1.grad)) 124 | X2_grad_python_loops = torch.detach(torch.clone(X2.grad)) 125 | 126 | X1.grad.zero_() 127 | X2.grad.zero_() 128 | 129 | cpp_output = accumulate(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 130 | cpp_output.backward(gradient = output_grad) 131 | 132 | X1_grad_cpp = torch.detach(torch.clone(X1.grad)) 133 | X2_grad_cpp = torch.detach(torch.clone(X2.grad)) 134 | 135 | assert get_relative_error(X1_grad_python_loops, X1_grad_cpp) < epsilon 136 | assert get_relative_error(X2_grad_python_loops, X2_grad_cpp) < epsilon 137 | 138 | 139 | def test_backward_active_dim_middle(epsilon = 1e-7): 140 | print("Testing backward pass with the active 
dimension being the middle one") 141 | L_MAX = 5 142 | BATCH_SIZE = 1000 143 | N_FEATURES = 100 144 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L_MAX) 145 | 146 | X1 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES) 147 | X2 = torch.randn(BATCH_SIZE, 2 * L_MAX + 1, N_FEATURES) 148 | 149 | X1.requires_grad = True 150 | X2.requires_grad = True 151 | python_loops_output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers, 152 | active_dim = 1) 153 | 154 | output_grad = torch.randn(*python_loops_output.shape) 155 | python_loops_output.backward(gradient = output_grad) 156 | 157 | X1_grad_python_loops = torch.detach(torch.clone(X1.grad)) 158 | X2_grad_python_loops = torch.detach(torch.clone(X2.grad)) 159 | 160 | X1.grad.zero_() 161 | X2.grad.zero_() 162 | 163 | cpp_output = sparse_accumulation_active_dim_middle.SparseAccumulationActiveDimMiddle.apply(X1, X2, mu_aligned, 164 | 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 165 | 166 | cpp_output.backward(gradient = output_grad) 167 | 168 | X1_grad_cpp = torch.detach(torch.clone(X1.grad)) 169 | X2_grad_cpp = torch.detach(torch.clone(X2.grad)) 170 | 171 | assert get_relative_error(X1_grad_python_loops, X1_grad_cpp) < epsilon 172 | assert get_relative_error(X2_grad_python_loops, X2_grad_cpp) < epsilon 173 | 174 | 175 | def test_backward_active_dim_first(epsilon = 1e-7): 176 | print("Testing backward pass with the active dimension being the first one") 177 | L_MAX = 5 178 | BATCH_SIZE = 1000 179 | N_FEATURES = 100 180 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L_MAX) 181 | 182 | X1 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES) 183 | X2 = torch.randn(2 * L_MAX + 1, BATCH_SIZE, N_FEATURES) 184 | 185 | X1.requires_grad = True 186 | X2.requires_grad = True 187 | python_loops_output = sparse_accumulation_loops(X1, X2, mu_aligned, 2 * L_MAX + 1, 188 | m1_aligned, m2_aligned, multipliers, 189 | active_dim = 0) 190 | output_grad = torch.randn(*python_loops_output.shape) 191 | python_loops_output.backward(gradient = output_grad) 192 | 193 | X1_grad_python_loops = torch.detach(torch.clone(X1.grad)) 194 | X2_grad_python_loops = torch.detach(torch.clone(X2.grad)) 195 | 196 | X1.grad.zero_() 197 | X2.grad.zero_() 198 | 199 | cpp_output = sparse_accumulation_active_dim_first.SparseAccumulationActiveDimFirst.apply(X1, X2, mu_aligned, 200 | 2 * L_MAX + 1, m1_aligned, m2_aligned, multipliers) 201 | cpp_output.backward(gradient = output_grad) 202 | 203 | X1_grad_cpp = torch.detach(torch.clone(X1.grad)) 204 | X2_grad_cpp = torch.detach(torch.clone(X2.grad)) 205 | 206 | assert get_relative_error(X1_grad_python_loops, X1_grad_cpp) < epsilon 207 | assert get_relative_error(X2_grad_python_loops, X2_grad_cpp) < epsilon 208 | 209 | -------------------------------------------------------------------------------- /tests/test_cpp_jit_cuda.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from functools import partial 4 | 5 | import torch 6 | from torch.utils import cpp_extension 7 | import numpy as np 8 | from sparse_accumulation.clebsch_gordan import ClebschGordan, get_real_clebsch_gordan 9 | from sparse_accumulation.reference_implementations import sparse_accumulation_loops 10 | from sparse_accumulation import accumulate, get_cg_transformation_rule 11 | 12 | 13 | def get_rule(l1, l2, l_output, dtype=torch.float64, device="cpu"): 14 | cachepath = 
f".cache/clebsch_gordan_l1_{l1}_l2_{l2}_l_output_{l_output}_dtype_{dtype}.pt" 15 | if os.path.isfile(cachepath): 16 | return torch.load(cachepath, map_location=device) 17 | 18 | 19 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_cg_transformation_rule(l1, l2, l_output, dtype = dtype, device = device) 20 | os.makedirs(os.path.dirname(cachepath), exist_ok=True) 21 | torch.save([m1_aligned, m2_aligned, mu_aligned, multipliers], cachepath) 22 | return m1_aligned, m2_aligned, mu_aligned, multipliers 23 | 24 | 25 | @pytest.mark.parametrize("L1", [3, 5, 8]) 26 | @pytest.mark.parametrize("L2", [3, 5, 8]) 27 | @pytest.mark.parametrize("L_OUTPUT", [3, 5, 8]) 28 | @pytest.mark.parametrize("BATCH_SIZE", [1, 20, 200]) 29 | @pytest.mark.parametrize("N_FEATURES", [1, 20, 105]) 30 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 31 | def test_forward(L1, L2, L_OUTPUT, BATCH_SIZE, N_FEATURES, dtype): 32 | if (L_OUTPUT < abs(L1 - L2)) or (L_OUTPUT > L1 + L2): 33 | pytest.skip() 34 | 35 | atol, rtol = (1e-5, 1e-6) if dtype is torch.float32 else (1e-7, 1e-8) 36 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L1, L2, L_OUTPUT, dtype) 37 | 38 | m1_aligned_d = m1_aligned.clone().cuda() 39 | m2_aligned_d = m2_aligned.clone().cuda() 40 | mu_aligned_d = mu_aligned.clone().cuda() 41 | multipliers_d = multipliers.clone().cuda() 42 | 43 | generator = torch.Generator() 44 | generator.manual_seed(30) 45 | X1 = torch.randn( 46 | (BATCH_SIZE, N_FEATURES, 2 * L1 + 1), 47 | generator=generator, 48 | dtype=dtype, 49 | ) 50 | X2 = torch.randn( 51 | (BATCH_SIZE, N_FEATURES, 2 * L2 + 1), 52 | generator=generator, 53 | dtype=dtype, 54 | ) 55 | # X1_d = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1,device="cuda") 56 | # X2_d = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1,device="cuda") 57 | X1_d = X1.clone().cuda() # torch.randn(BATCH_SIZE, N_FEATURES,device="cuda") 58 | X2_d = X2.clone().cuda() # torch.randn(BATCH_SIZE,device="cuda") 59 | 60 | python_loops_output = sparse_accumulation_loops( 61 | X1, 62 | X2, 63 | mu_aligned, 64 | 2 * L_OUTPUT + 1, 65 | m1_aligned, 66 | m2_aligned, 67 | multipliers, 68 | active_dim=2, 69 | ) 70 | 71 | cuda_output = accumulate( 72 | X1_d, 73 | X2_d, 74 | mu_aligned_d, 75 | 2 * L_OUTPUT + 1, 76 | m1_aligned_d, 77 | m2_aligned_d, 78 | multipliers_d, 79 | ) 80 | 81 | cuda_output_cpu = cuda_output.cpu() 82 | delta = python_loops_output - cuda_output_cpu 83 | relative_error = torch.amax(torch.abs(delta / python_loops_output)) 84 | 85 | assert torch.allclose( 86 | python_loops_output, cuda_output_cpu, atol=atol, rtol=rtol 87 | ), f"assertion failed \n {cuda_output=} \n {python_loops_output=}" 88 | 89 | errmax = torch.amax(torch.abs(delta)) 90 | print(f"{errmax=}") 91 | print(f"{torch.amin(torch.abs(cuda_output_cpu))=}") 92 | 93 | assert torch.allclose(python_loops_output, cuda_output_cpu, atol=atol, rtol=rtol) 94 | # print(f"{python_time=} s") 95 | print() 96 | 97 | 98 | @pytest.mark.parametrize("seed", [30, 42]) 99 | @pytest.mark.parametrize("L1", [3, 5, 8]) 100 | @pytest.mark.parametrize("L2", [3, 5, 8]) 101 | @pytest.mark.parametrize("L_OUTPUT", [3, 5, 8]) 102 | @pytest.mark.parametrize("BATCH_SIZE", [1, 20, 200]) 103 | @pytest.mark.parametrize("N_FEATURES", [1, 20, 105]) 104 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 105 | def test_backward(L1, L2, L_OUTPUT, BATCH_SIZE, N_FEATURES, seed, dtype): 106 | if (L_OUTPUT < abs(L1 - L2)) or (L_OUTPUT > L1 + L2): 107 | pytest.skip() 108 | atol, rtol = (1e-5, 1e-6) if dtype is 
torch.float32 else (1e-7, 1e-8) 109 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L1, L2, L_OUTPUT, dtype) 110 | 111 | m1_aligned_d = m1_aligned.clone().cuda() 112 | m2_aligned_d = m2_aligned.clone().cuda() 113 | mu_aligned_d = mu_aligned.clone().cuda() 114 | multipliers_d = multipliers.clone().cuda() 115 | generator = torch.Generator() 116 | generator.manual_seed(seed) 117 | X1 = torch.randn( 118 | (BATCH_SIZE, N_FEATURES, 2 * L1 + 1), 119 | generator=generator, 120 | dtype=dtype, 121 | ) 122 | X2 = torch.randn( 123 | (BATCH_SIZE, N_FEATURES, 2 * L2 + 1), 124 | generator=generator, 125 | dtype=dtype, 126 | ) 127 | # X1_d = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1,device="cuda") 128 | # X2_d = torch.randn(BATCH_SIZE, N_FEATURES, 2 * L_MAX + 1,device="cuda") 129 | X1_d = X1.clone().cuda() # torch.randn(BATCH_SIZE, N_FEATURES,device="cuda") 130 | X2_d = X2.clone().cuda() # torch.randn(BATCH_SIZE,device="cuda") 131 | 132 | X1.requires_grad = True 133 | X2.requires_grad = True 134 | 135 | python_loops_output = sparse_accumulation_loops( 136 | X1, 137 | X2, 138 | mu_aligned, 139 | 2 * L_OUTPUT + 1, 140 | m1_aligned, 141 | m2_aligned, 142 | multipliers, 143 | active_dim=2, 144 | ) 145 | output_grad = torch.randn(*python_loops_output.shape, generator=generator, dtype=dtype) 146 | python_loops_output.backward(gradient=output_grad) 147 | 148 | X1_grad_python_loops = torch.detach(torch.clone(X1.grad)) 149 | X2_grad_python_loops = torch.detach(torch.clone(X2.grad)) 150 | 151 | 152 | output_grad_d = output_grad.clone().cuda() 153 | 154 | 155 | X1_d.requires_grad = True 156 | X2_d.requires_grad = True 157 | cuda_output = accumulate( 158 | X1_d, 159 | X2_d, 160 | mu_aligned_d, 161 | 2 * L_OUTPUT + 1, 162 | m1_aligned_d, 163 | m2_aligned_d, 164 | multipliers_d 165 | ) 166 | 167 | cuda_output.backward(gradient=output_grad_d) 168 | X1_grad_cuda = torch.detach(torch.clone(X1_d.grad)) 169 | X2_grad_cuda = torch.detach(torch.clone(X2_d.grad)) 170 | 171 | 172 | 173 | X1_grad_cuda = X1_grad_cuda.cpu() 174 | X2_grad_cuda = X2_grad_cuda.cpu() 175 | 176 | errmax1 = torch.amax(torch.abs(X1_grad_python_loops - X1_grad_cuda)) 177 | errmax2 = torch.amax(torch.abs(X2_grad_python_loops - X2_grad_cuda)) 178 | print(f"{errmax1=}") 179 | print(f"{torch.amin(torch.abs(X1_grad_cuda))=}") 180 | print(f"{errmax2=}") 181 | print(f"{torch.amin(torch.abs(X2_grad_cuda))=}") 182 | 183 | # print(f"{X1_grad_cuda=}") 184 | # print(f"{X1_grad_python_loops=}") 185 | 186 | # print(f"{X2_grad_cuda=}") 187 | # print(f"{X2_grad_python_loops=}") 188 | # assert torch.allclose(python_loops_output , cuda_output_cpu,atol=atol) 189 | 190 | assert torch.allclose( 191 | X1_grad_python_loops, X1_grad_cuda, atol=atol, rtol=rtol 192 | ) 193 | assert torch.allclose( 194 | X2_grad_python_loops, X2_grad_cuda, atol=atol, rtol=rtol 195 | ) 196 | 197 | 198 | 199 | 200 | @pytest.mark.parametrize("function", [ 201 | partial(sparse_accumulation_loops, active_dim=2), 202 | accumulate 203 | ]) 204 | @pytest.mark.parametrize("L1", [3, 5, 7]) 205 | @pytest.mark.parametrize("L2", [3, 5, 7]) 206 | @pytest.mark.parametrize("L_OUTPUT", [3, 5, 7]) 207 | @pytest.mark.parametrize("BATCH_SIZE", [1, 20, 200]) 208 | @pytest.mark.parametrize("N_FEATURES", [1, 20, 105]) 209 | @pytest.mark.parametrize("dtype", [torch.float64]) 210 | @pytest.mark.parametrize("device", ['cpu', 'cuda']) 211 | def test_backward_gradcheck(function, L1, L2, L_OUTPUT, BATCH_SIZE, N_FEATURES, dtype, device): 212 | if (L_OUTPUT < abs(L1 - L2)) or (L_OUTPUT > L1 + L2): 213 | pytest.skip() 214 | atol, 
rtol = (5e-2, 1e-3) if dtype == torch.float32 else (1e-7, 1e-8) 215 | if device == 'cpu' and function == accumulate: 216 | pytest.skip() 217 | 218 | m1_aligned, m2_aligned, mu_aligned, multipliers = get_rule(L1, L2, L_OUTPUT, dtype, device) 219 | 220 | generator = torch.Generator(device=device) 221 | generator.manual_seed(0xDEADBEEF) 222 | X1 = torch.randn( 223 | (BATCH_SIZE, N_FEATURES, 2 * L1 + 1), 224 | requires_grad=True, 225 | generator=generator, 226 | dtype=dtype, 227 | device=device, 228 | ) 229 | X2 = torch.randn( 230 | (BATCH_SIZE, N_FEATURES, 2 * L2 + 1), 231 | requires_grad=True, 232 | generator=generator, 233 | dtype=dtype, 234 | device=device, 235 | ) 236 | 237 | assert torch.autograd.gradcheck( 238 | function, 239 | (X1, X2, mu_aligned, 2 * L_OUTPUT + 1, m1_aligned, m2_aligned, multipliers), 240 | fast_mode=True, atol=atol, rtol=rtol, 241 | ) 242 | 243 | print("torch.autograd.gradcheck passed\n") 244 | -------------------------------------------------------------------------------- /update_docs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.system("rm -r ../build/*") 4 | os.chdir("./docs") 5 | os.system("make html") 6 | os.chdir("../") 7 | os.system("git checkout -f gh-pages") 8 | os.system("git rm -r *") 9 | os.system("cp -r ../build/html/* .") 10 | with open(".nojekyll", "w") as f: 11 | pass 12 | 13 | os.system("git add *") 14 | os.system("git add .nojekyll") 15 | os.system("git commit -m 'automatic docs build'") 16 | os.system("git push") 17 | os.system("git checkout main") 18 | --------------------------------------------------------------------------------
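The following is a minimal end-to-end usage sketch assembled from the calls exercised in tests/test_cpp_jit_cuda.py above (get_cg_transformation_rule and accumulate). The angular momentum values, batch/feature sizes, and dtype are illustrative assumptions rather than values prescribed by the package; the CPU path is shown, and moving the inputs and the transformation-rule tensors to the GPU with .cuda() dispatches to the CUDA extension instead, exactly as the tests do.

.. code-block:: python

    # Illustrative sketch only: l1, l2, l_output, batch_size, and n_features are assumed values.
    import torch

    from sparse_accumulation import accumulate, get_cg_transformation_rule

    l1, l2, l_output = 3, 4, 5
    batch_size, n_features = 100, 32

    # Transformation rule: index tensors m1, m2, mu and the non-zero coefficients C,
    # returned in this order by get_cg_transformation_rule (see the tests above).
    m1, m2, mu, C = get_cg_transformation_rule(l1, l2, l_output, dtype=torch.float64, device="cpu")

    # Inputs with the active (m) dimension last, matching the layout used in the tests.
    X1 = torch.randn(batch_size, n_features, 2 * l1 + 1, dtype=torch.float64)
    X2 = torch.randn(batch_size, n_features, 2 * l2 + 1, dtype=torch.float64)

    # Sparse accumulation over the last dimension; the output size is 2 * l_output + 1.
    output = accumulate(X1, X2, mu, 2 * l_output + 1, m1, m2, C)
    assert output.shape == (batch_size, n_features, 2 * l_output + 1)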