├── logo.png ├── tests ├── benchmarks │ ├── results │ │ ├── p2m0 │ │ │ ├── memory │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ │ └── speedup │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ ├── p3m0 │ │ │ ├── memory │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ │ └── speedup │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ ├── fc_p2m0 │ │ │ ├── memory │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ │ └── speedup │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ ├── fc_p3m0 │ │ │ ├── memory │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ │ └── speedup │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ ├── layer_2d │ │ │ ├── memory │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ │ └── speedup │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ ├── layer_3d │ │ │ ├── memory │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ │ └── speedup │ │ │ │ ├── fwd.png │ │ │ │ └── fwd_bwd.png │ │ ├── clifford_mlp_fwd_bwd_runtime.png │ │ └── clifford_mlp_runtime_scaling.png │ ├── p2m0.py │ ├── p3m0.py │ ├── fc_p2m0.py │ ├── fc_p3m0.py │ ├── layer_2d.py │ └── layer_3d.py ├── p2m0.py ├── p3m0.py ├── fc_p2m0.py ├── fc_p3m0.py ├── cue_baseline.py ├── baselines.py └── utils.py ├── ops ├── __init__.py ├── p2m0.py ├── fc_p2m0.py ├── p3m0.py └── fc_p3m0.py ├── modules ├── layer.py └── baseline.py └── README.md /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/logo.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p2m0/memory/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/memory/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p3m0/memory/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/memory/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p2m0/memory/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/memory/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p3m0/memory/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/memory/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p2m0/speedup/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/speedup/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p3m0/speedup/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/speedup/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p2m0/speedup/fwd.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/speedup/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p3m0/speedup/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/speedup/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_2d/memory/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/memory/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_2d/speedup/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/speedup/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_3d/memory/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/memory/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_3d/speedup/fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/speedup/fwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p2m0/memory/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/memory/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p2m0/speedup/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/speedup/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p3m0/memory/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/memory/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/p3m0/speedup/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/speedup/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p2m0/memory/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/memory/fwd_bwd.png -------------------------------------------------------------------------------- 
/tests/benchmarks/results/fc_p2m0/speedup/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/speedup/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p3m0/memory/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/memory/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/fc_p3m0/speedup/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/speedup/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_2d/memory/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/memory/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_3d/memory/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/memory/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_2d/speedup/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/speedup/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/layer_3d/speedup/fwd_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/speedup/fwd_bwd.png -------------------------------------------------------------------------------- /tests/benchmarks/results/clifford_mlp_fwd_bwd_runtime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/clifford_mlp_fwd_bwd_runtime.png -------------------------------------------------------------------------------- /tests/benchmarks/results/clifford_mlp_runtime_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/clifford_mlp_runtime_scaling.png -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .p2m0 import fused_gelu_sgp_norm_2d 2 | from .p3m0 import fused_gelu_sgp_norm_3d 3 | from .fc_p2m0 import fused_gelu_fcgp_norm_2d 4 | from .fc_p3m0 import fused_gelu_fcgp_norm_3d 5 | 6 | from .p2m0 import NUM_PRODUCT_WEIGHTS as P2M0_NUM_PRODUCT_WEIGHTS 7 | from .p3m0 import NUM_PRODUCT_WEIGHTS as P3M0_NUM_PRODUCT_WEIGHTS 8 | 9 | from .p2m0 import NUM_GRADES as P2M0_NUM_GRADES 10 | from .p3m0 import NUM_GRADES as 
P3M0_NUM_GRADES -------------------------------------------------------------------------------- /tests/p2m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_float32_matmul_precision('medium') 4 | 5 | from ops.p2m0 import fused_gelu_sgp_norm_2d 6 | from tests.baselines import gelu_sgp_norm_2d_torch 7 | from tests.utils import run_correctness_test, run_benchmark 8 | 9 | 10 | if __name__ == "__main__": 11 | assert torch.cuda.is_available() 12 | 13 | rep = 1000 14 | batch_size = 4096 15 | num_features = 512 16 | 17 | x = torch.randn(4, batch_size, num_features).cuda().contiguous() 18 | y = torch.randn(4, batch_size, num_features).cuda().contiguous() 19 | weight = torch.randn(num_features, 10).cuda().contiguous() 20 | 21 | run_correctness_test(fused_gelu_sgp_norm_2d, gelu_sgp_norm_2d_torch, {'x': x, 'y': y, 'weight': weight}) 22 | run_benchmark(fused_gelu_sgp_norm_2d, gelu_sgp_norm_2d_torch, (x, y, weight), rep, verbose=True) -------------------------------------------------------------------------------- /tests/p3m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_float32_matmul_precision('medium') 4 | 5 | from ops.p3m0 import fused_gelu_sgp_norm_3d 6 | from tests.baselines import gelu_sgp_norm_3d_torch 7 | from tests.utils import run_correctness_test, run_benchmark 8 | 9 | 10 | if __name__ == "__main__": 11 | assert torch.cuda.is_available() 12 | 13 | rep = 1000 14 | batch_size = 4096 15 | num_features = 512 16 | 17 | x = torch.randn(8, batch_size, num_features).cuda().contiguous() 18 | y = torch.randn(8, batch_size, num_features).cuda().contiguous() 19 | weight = torch.randn(num_features, 20).cuda().contiguous() 20 | 21 | run_correctness_test(fused_gelu_sgp_norm_3d, gelu_sgp_norm_3d_torch, {'x': x, 'y': y, 'weight': weight}) 22 | run_benchmark(fused_gelu_sgp_norm_3d, gelu_sgp_norm_3d_torch, (x, y, weight), rep, verbose=True) -------------------------------------------------------------------------------- /tests/fc_p2m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_float32_matmul_precision('medium') 4 | 5 | from ops.fc_p2m0 import fused_gelu_fcgp_norm_2d 6 | from tests.baselines import gelu_fcgp_norm_2d_torch 7 | from tests.utils import run_correctness_test, run_benchmark 8 | 9 | 10 | if __name__ == "__main__": 11 | assert torch.cuda.is_available() 12 | 13 | rep = 1000 14 | batch_size = 4096 15 | num_features = 512 16 | 17 | x = torch.randn(4, batch_size, num_features).cuda().contiguous() 18 | y = torch.randn(4, batch_size, num_features).cuda().contiguous() 19 | weight = torch.randn(10, num_features, num_features).cuda().contiguous() 20 | 21 | run_correctness_test(fused_gelu_fcgp_norm_2d, gelu_fcgp_norm_2d_torch, {'x': x, 'y': y, 'weight': weight}) 22 | run_benchmark(fused_gelu_fcgp_norm_2d, gelu_fcgp_norm_2d_torch, (x, y, weight), rep, verbose=True) -------------------------------------------------------------------------------- /tests/fc_p3m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_float32_matmul_precision('medium') 4 | 5 | from ops.fc_p3m0 import fused_gelu_fcgp_norm_3d 6 | from tests.baselines import gelu_fcgp_norm_3d_torch 7 | from tests.utils import run_correctness_test, run_benchmark 8 | 9 | 10 | if __name__ == "__main__": 11 | assert torch.cuda.is_available() 12 | 13 | rep = 1000 14 | 
batch_size = 4096 15 | num_features = 512 16 | 17 | x = torch.randn(8, batch_size, num_features).cuda().contiguous() 18 | y = torch.randn(8, batch_size, num_features).cuda().contiguous() 19 | weight = torch.randn(20, num_features, num_features).cuda().contiguous() 20 | 21 | run_correctness_test(fused_gelu_fcgp_norm_3d, gelu_fcgp_norm_3d_torch, {'x': x, 'y': y, 'weight': weight}) 22 | run_benchmark(fused_gelu_fcgp_norm_3d, gelu_fcgp_norm_3d_torch, (x, y, weight), rep, verbose=True) -------------------------------------------------------------------------------- /tests/cue_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cuequivariance as cue 3 | import cuequivariance_torch as cuet 4 | 5 | 6 | def mvgelu(x): 7 | """Apply GELU activation gated by scalar component.""" 8 | # x has shape (B, N * 8) 9 | b = x.shape[0] 10 | x = x.view(b, -1, 8) # (B, N, 8) 11 | s = x[..., 0:1] # scalar part 12 | gate = 0.5 * (1 + torch.erf(s * 0.7071067811865475)) 13 | y = gate * x 14 | y = y.view(b, -1) # (B, N * 8) 15 | return y 16 | 17 | 18 | def initialize_linear(N: int): 19 | """ 20 | Initialize MLP with linear layer + weighted GP + GELU + BatchNorm. 21 | Related: https://github.com/NVIDIA/cuEquivariance/issues/194 22 | """ 23 | irreps = cue.Irreps("O3", f"{N}x0e + {N}x1o + {N}x1e + {N}x0o") 24 | 25 | ep_weighted = cue.descriptors.fully_connected_tensor_product( 26 | irreps.set_mul(1), 27 | irreps.set_mul(1), 28 | irreps.set_mul(1) 29 | ) 30 | 31 | [(_, stp_weighted)] = ep_weighted.polynomial.operations 32 | stp_weighted = stp_weighted.append_modes_to_all_operands("n", dict(n=N)) 33 | p_weighted = cue.SegmentedPolynomial.eval_last_operand(stp_weighted) 34 | 35 | weighted_gp = cuet.SegmentedPolynomial(p_weighted, method="uniform_1d").cuda() 36 | linear = cuet.Linear(irreps_in=irreps, irreps_out=irreps, layout=cue.ir_mul, layout_in=cue.ir_mul, layout_out=cue.ir_mul).cuda() 37 | norm = cuet.layers.BatchNorm(irreps=irreps, layout=cue.ir_mul).cuda() 38 | 39 | @torch.compile 40 | def mlp(x, w): 41 | y = linear(x) 42 | x = mvgelu(x) 43 | y = mvgelu(y) 44 | [x] = weighted_gp([w, x, y]) 45 | x = norm(x) 46 | return x 47 | 48 | return mlp -------------------------------------------------------------------------------- /tests/benchmarks/p2m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_float32_matmul_precision('medium') 3 | torch._dynamo.config.cache_size_limit = 512 4 | 5 | from ops.p2m0 import fused_gelu_sgp_norm_2d 6 | from tests.baselines import gelu_sgp_norm_2d_torch 7 | from tests.utils import plot_heatmap, print_results_table, run_sweep 8 | 9 | 10 | def setup_benchmark(batch_size, num_features): 11 | x = torch.randn(4, batch_size, num_features).cuda().contiguous() 12 | y = torch.randn(4, batch_size, num_features).cuda().contiguous() 13 | weight = torch.randn(num_features, 10).cuda().contiguous() 14 | return x, y, weight 15 | 16 | 17 | if __name__ == "__main__": 18 | assert torch.cuda.is_available() 19 | 20 | path = "tests/benchmarks/results/p2m0" 21 | 22 | results = run_sweep( 23 | fused_gelu_sgp_norm_2d, 24 | gelu_sgp_norm_2d_torch, 25 | setup_benchmark, 26 | batch_sizes=[1024, 2048, 4096, 8192], 27 | num_features_list=[128, 256, 512, 1024], 28 | rep=200 29 | ) 30 | 31 | print_results_table(results, "p2m0") 32 | 33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(2,0)', 34 | path + '/speedup/fwd.png') 35 | plot_heatmap(results, 
'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(2,0)', 36 | path + '/speedup/fwd_bwd.png') 37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)', 38 | path + '/memory/fwd.png', invert_cmap=True) 39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)', 40 | path + '/memory/fwd_bwd.png', invert_cmap=True) 41 | -------------------------------------------------------------------------------- /tests/benchmarks/p3m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_float32_matmul_precision('medium') 3 | torch._dynamo.config.cache_size_limit = 512 4 | 5 | from ops.p3m0 import fused_gelu_sgp_norm_3d 6 | from tests.baselines import gelu_sgp_norm_3d_torch 7 | from tests.utils import plot_heatmap, print_results_table, run_sweep 8 | 9 | 10 | def setup_benchmark(batch_size, num_features): 11 | x = torch.randn(8, batch_size, num_features).cuda().contiguous() 12 | y = torch.randn(8, batch_size, num_features).cuda().contiguous() 13 | weight = torch.randn(num_features, 20).cuda().contiguous() 14 | return x, y, weight 15 | 16 | 17 | if __name__ == "__main__": 18 | assert torch.cuda.is_available() 19 | 20 | path = "tests/benchmarks/results/p3m0" 21 | 22 | results = run_sweep( 23 | fused_gelu_sgp_norm_3d, 24 | gelu_sgp_norm_3d_torch, 25 | setup_benchmark, 26 | batch_sizes=[1024, 2048, 4096, 8192], 27 | num_features_list=[128, 256, 512, 1024], 28 | rep=200 29 | ) 30 | 31 | print_results_table(results, "p3m0") 32 | 33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(3,0)', 34 | path + '/speedup/fwd.png') 35 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(3,0)', 36 | path + '/speedup/fwd_bwd.png') 37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)', 38 | path + '/memory/fwd.png', invert_cmap=True) 39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)', 40 | path + '/memory/fwd_bwd.png', invert_cmap=True) 41 | -------------------------------------------------------------------------------- /tests/benchmarks/fc_p2m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_float32_matmul_precision('medium') 3 | torch._dynamo.config.cache_size_limit = 512 4 | 5 | from ops.fc_p2m0 import fused_gelu_fcgp_norm_2d 6 | from tests.baselines import gelu_fcgp_norm_2d_torch 7 | from tests.utils import plot_heatmap, print_results_table, run_sweep 8 | 9 | 10 | def setup_benchmark(batch_size, num_features): 11 | x = torch.randn(4, batch_size, num_features).cuda().contiguous() 12 | y = torch.randn(4, batch_size, num_features).cuda().contiguous() 13 | weight = torch.randn(10, num_features, num_features).cuda().contiguous() 14 | return x, y, weight 15 | 16 | 17 | if __name__ == "__main__": 18 | assert torch.cuda.is_available() 19 | 20 | path = "tests/benchmarks/results/fc_p2m0" 21 | 22 | results = run_sweep( 23 | fused_gelu_fcgp_norm_2d, 24 | gelu_fcgp_norm_2d_torch, 25 | setup_benchmark, 26 | batch_sizes=[1024, 2048, 4096, 8192], 27 | num_features_list=[128, 256, 512, 1024], 28 | rep=200 29 | ) 30 | 31 | print_results_table(results, "fc_p2m0") 32 | 33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(2,0)', 34 | path + '/speedup/fwd.png') 35 | 
plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(2,0)', 36 | path + '/speedup/fwd_bwd.png') 37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)', 38 | path + '/memory/fwd.png', invert_cmap=True) 39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)', 40 | path + '/memory/fwd_bwd.png', invert_cmap=True) 41 | -------------------------------------------------------------------------------- /tests/benchmarks/fc_p3m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_float32_matmul_precision('medium') 3 | torch._dynamo.config.cache_size_limit = 512 4 | 5 | from ops.fc_p3m0 import fused_gelu_fcgp_norm_3d 6 | from tests.baselines import gelu_fcgp_norm_3d_torch 7 | from tests.utils import plot_heatmap, print_results_table, run_sweep 8 | 9 | 10 | def setup_benchmark(batch_size, num_features): 11 | x = torch.randn(8, batch_size, num_features).cuda().contiguous() 12 | y = torch.randn(8, batch_size, num_features).cuda().contiguous() 13 | weight = torch.randn(20, num_features, num_features).cuda().contiguous() 14 | return x, y, weight 15 | 16 | 17 | if __name__ == "__main__": 18 | assert torch.cuda.is_available() 19 | 20 | path = "tests/benchmarks/results/fc_p3m0" 21 | 22 | results = run_sweep( 23 | fused_gelu_fcgp_norm_3d, 24 | gelu_fcgp_norm_3d_torch, 25 | setup_benchmark, 26 | batch_sizes=[1024, 2048, 4096, 8192], 27 | num_features_list=[128, 256, 512, 1024], 28 | rep=200 29 | ) 30 | 31 | print_results_table(results, "fc_p3m0") 32 | 33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(3,0)', 34 | path + '/speedup/fwd.png') 35 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(3,0)', 36 | path + '/speedup/fwd_bwd.png') 37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)', 38 | path + '/memory/fwd.png', invert_cmap=True) 39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)', 40 | path + '/memory/fwd_bwd.png', invert_cmap=True) 41 | -------------------------------------------------------------------------------- /tests/benchmarks/layer_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_float32_matmul_precision('medium') 3 | torch._dynamo.config.cache_size_limit = 512 4 | 5 | from modules.layer import Layer 6 | from modules.baseline import Layer as BaselineLayer 7 | from tests.utils import plot_heatmap, print_results_table, run_single_benchmark 8 | 9 | 10 | def setup_benchmark(batch_size, num_features): 11 | x = torch.randn(4, batch_size, num_features).cuda().contiguous() 12 | return (x,) 13 | 14 | 15 | def create_layers(num_features: int): 16 | layer_fused = Layer(num_features, dims=2, normalize=True, use_fc=False).float().cuda() 17 | layer_torch = BaselineLayer(num_features, dims=2, normalize=True, use_fc=False).float().cuda() 18 | return layer_fused, layer_torch 19 | 20 | 21 | if __name__ == "__main__": 22 | assert torch.cuda.is_available() 23 | 24 | rep = 1000 25 | warmup = 500 26 | batch_sizes=[1024, 2048, 4096, 8192] 27 | num_features_list=[128, 256, 512, 1024] 28 | path = "tests/benchmarks/results/layer_2d" 29 | 30 | results = [] 31 | 32 | print("Running benchmark sweep...") 33 | print(f"Batch sizes: {batch_sizes}") 34 | print(f"Num 
features: {num_features_list}") 35 | 36 | for batch_size in batch_sizes: 37 | for num_features in num_features_list: 38 | print(f"Running batch_size={batch_size}, num_features={num_features}...", end=" ") 39 | triton_fn, torch_fn = create_layers(num_features) 40 | triton_fn = torch.compile(triton_fn) 41 | torch_fn = torch.compile(torch_fn) 42 | 43 | result = run_single_benchmark( 44 | triton_fn, torch_fn, setup_benchmark, batch_size, 45 | num_features, rep, warmup, verify_correctness=False, return_mode='mean' 46 | ) 47 | results.append(result) 48 | 49 | fwd_msg = (f"Fwd: {result['speedup_fwd']:.2f}x" 50 | if result['speedup_fwd'] else "Fwd: OOM") 51 | bwd_msg = (f"Fwd+Bwd: {result['speedup_fwd_bwd']:.2f}x" 52 | if result['speedup_fwd_bwd'] else "Fwd+Bwd: OOM") 53 | print(f"{fwd_msg}, {bwd_msg}") 54 | 55 | print_results_table(results, "layer_2d") 56 | 57 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(2,0)', 58 | path + '/speedup/fwd.png') 59 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(2,0)', 60 | path + '/speedup/fwd_bwd.png') 61 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)', 62 | path + '/memory/fwd.png', invert_cmap=True) 63 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)', 64 | path + '/memory/fwd_bwd.png', invert_cmap=True) -------------------------------------------------------------------------------- /tests/benchmarks/layer_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.set_float32_matmul_precision('medium') 3 | torch._dynamo.config.cache_size_limit = 512 4 | 5 | from modules.layer import Layer 6 | from modules.baseline import Layer as BaselineLayer 7 | from tests.utils import plot_heatmap, print_results_table, run_single_benchmark 8 | 9 | 10 | def setup_benchmark(batch_size, num_features): 11 | x = torch.randn(8, batch_size, num_features).cuda().contiguous() 12 | return (x,) 13 | 14 | 15 | def create_layers(num_features: int): 16 | layer_fused = Layer(num_features, dims=3, normalize=True, use_fc=False).float().cuda() 17 | layer_torch = BaselineLayer(num_features, dims=3, normalize=True, use_fc=False).float().cuda() 18 | return layer_fused, layer_torch 19 | 20 | 21 | if __name__ == "__main__": 22 | assert torch.cuda.is_available() 23 | 24 | rep = 1000 25 | warmup = 500 26 | batch_sizes=[1024, 2048, 4096, 8192] 27 | num_features_list=[128, 256, 512, 1024] 28 | path = "tests/benchmarks/results/layer_3d" 29 | 30 | results = [] 31 | 32 | print("Running benchmark sweep...") 33 | print(f"Batch sizes: {batch_sizes}") 34 | print(f"Num features: {num_features_list}") 35 | 36 | for batch_size in batch_sizes: 37 | for num_features in num_features_list: 38 | print(f"Running batch_size={batch_size}, num_features={num_features}...", end=" ") 39 | triton_fn, torch_fn = create_layers(num_features) 40 | triton_fn = torch.compile(triton_fn) 41 | torch_fn = torch.compile(torch_fn) 42 | 43 | result = run_single_benchmark( 44 | triton_fn, torch_fn, setup_benchmark, batch_size, 45 | num_features, rep, warmup, verify_correctness=False, return_mode='mean' 46 | ) 47 | results.append(result) 48 | 49 | fwd_msg = (f"Fwd: {result['speedup_fwd']:.2f}x" 50 | if result['speedup_fwd'] else "Fwd: OOM") 51 | bwd_msg = (f"Fwd+Bwd: {result['speedup_fwd_bwd']:.2f}x" 52 | if result['speedup_fwd_bwd'] else "Fwd+Bwd: OOM") 53 | print(f"{fwd_msg}, {bwd_msg}") 
54 | 55 | print_results_table(results, "layer_3d") 56 | 57 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(3,0)', 58 | path + '/speedup/fwd.png') 59 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(3,0)', 60 | path + '/speedup/fwd_bwd.png') 61 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)', 62 | path + '/memory/fwd.png', invert_cmap=True) 63 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)', 64 | path + '/memory/fwd_bwd.png', invert_cmap=True) -------------------------------------------------------------------------------- /modules/layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from ops import fused_gelu_sgp_norm_2d, fused_gelu_sgp_norm_3d, fused_gelu_fcgp_norm_2d, fused_gelu_fcgp_norm_3d, P2M0_NUM_PRODUCT_WEIGHTS, P3M0_NUM_PRODUCT_WEIGHTS, P2M0_NUM_GRADES, P3M0_NUM_GRADES 5 | 6 | _FUSED_OPS = { 7 | (2, False): fused_gelu_sgp_norm_2d, 8 | (2, True): fused_gelu_fcgp_norm_2d, 9 | (3, False): fused_gelu_sgp_norm_3d, 10 | (3, True): fused_gelu_fcgp_norm_3d, 11 | } 12 | 13 | _CONFIG = { 14 | 2: { 15 | 'num_product_weights': P2M0_NUM_PRODUCT_WEIGHTS, 16 | 'num_grades': P2M0_NUM_GRADES, 17 | 'weight_expansion': torch.tensor([0, 1, 1, 2], dtype=torch.long), 18 | }, 19 | 3: { 20 | 'num_product_weights': P3M0_NUM_PRODUCT_WEIGHTS, 21 | 'num_grades': P3M0_NUM_GRADES, 22 | 'weight_expansion': torch.tensor([0, 1, 1, 1, 2, 2, 2, 3], dtype=torch.long), 23 | } 24 | } 25 | 26 | 27 | class Layer(torch.nn.Module): 28 | """ 29 | Linear layer: grade-wise linear + weighted GP + GELU + LayerNorm. 30 | Efficient implementation of https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/blob/8482b06b71712dcea2841ebe567d37e7f8432d27/models/nbody_cggnn.py#L47 31 | 32 | Args: 33 | n_features: number of features. 34 | dims: 2 or 3, dimension of the space. 35 | normalize: whether to apply layer normalization at the end. 36 | use_fc: whether to use fully connected GP weights. 37 | True: weight has shape (n_features, n_features, num_product_weights). 38 | False: weight has shape (n_features, num_product_weights). 
39 | """ 40 | def __init__(self, n_features, dims, normalize=True, use_fc=False): 41 | super().__init__() 42 | 43 | if dims not in _CONFIG: 44 | raise ValueError(f"Unsupported dims: {dims}") 45 | 46 | config = _CONFIG[dims] 47 | self.normalize = normalize 48 | self.fused_op = _FUSED_OPS[(dims, use_fc)] 49 | 50 | if use_fc: 51 | gp_weight_shape = (config['num_product_weights'], n_features, n_features) 52 | else: 53 | gp_weight_shape = (n_features, config['num_product_weights']) 54 | 55 | self.gp_weight = torch.nn.Parameter(torch.empty(gp_weight_shape)) 56 | 57 | linear_weight_shape = (config['num_grades'], n_features, n_features) 58 | self.linear_weight = torch.nn.Parameter(torch.empty(linear_weight_shape)) 59 | self.linear_bias = torch.nn.Parameter(torch.zeros(1, 1, n_features)) 60 | 61 | self.register_buffer("weight_expansion", config['weight_expansion']) 62 | 63 | torch.nn.init.normal_(self.gp_weight, std=1 / (math.sqrt(dims + 1))) 64 | torch.nn.init.normal_(self.linear_weight, std=1 / math.sqrt(n_features) if use_fc else 1 / math.sqrt(n_features * (dims + 1))) 65 | 66 | def forward(self, x): 67 | y = torch.bmm(x, self.linear_weight[self.weight_expansion]) + self.linear_bias 68 | return self.fused_op(x, y, self.gp_weight, self.normalize) -------------------------------------------------------------------------------- /modules/baseline.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from ops import P2M0_NUM_PRODUCT_WEIGHTS, P3M0_NUM_PRODUCT_WEIGHTS, P2M0_NUM_GRADES, P3M0_NUM_GRADES 5 | from tests.baselines import gelu_sgp_norm_2d_torch, gelu_sgp_norm_3d_torch, gelu_fcgp_norm_2d_torch, gelu_fcgp_norm_3d_torch 6 | 7 | _FUSED_OPS = { 8 | (2, False): gelu_sgp_norm_2d_torch, 9 | (2, True): gelu_fcgp_norm_2d_torch, 10 | (3, False): gelu_sgp_norm_3d_torch, 11 | (3, True): gelu_fcgp_norm_3d_torch, 12 | } 13 | 14 | _CONFIG = { 15 | 2: { 16 | 'num_product_weights': P2M0_NUM_PRODUCT_WEIGHTS, 17 | 'num_grades': P2M0_NUM_GRADES, 18 | 'weight_expansion': torch.tensor([0, 1, 1, 2], dtype=torch.long), 19 | }, 20 | 3: { 21 | 'num_product_weights': P3M0_NUM_PRODUCT_WEIGHTS, 22 | 'num_grades': P3M0_NUM_GRADES, 23 | 'weight_expansion': torch.tensor([0, 1, 1, 1, 2, 2, 2, 3], dtype=torch.long), 24 | } 25 | } 26 | 27 | 28 | class Layer(torch.nn.Module): 29 | """ 30 | Linear layer: grade-wise linear + weighted GP + GELU + LayerNorm. 31 | Metric-specific implementation of https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/blob/8482b06b71712dcea2841ebe567d37e7f8432d27/models/nbody_cggnn.py#L47 32 | 33 | Args: 34 | n_features: number of features. 35 | dims: 2 or 3, dimension of the space. 36 | normalize: whether to apply layer normalization at the end. 37 | use_fc: whether to use fully connected GP weights. 38 | True: weight has shape (n_features, n_features, num_product_weights). 39 | False: weight has shape (n_features, num_product_weights). 
40 | """ 41 | def __init__(self, n_features, dims, normalize=True, use_fc=False): 42 | super().__init__() 43 | 44 | if dims not in _CONFIG: 45 | raise ValueError(f"Unsupported dims: {dims}") 46 | 47 | config = _CONFIG[dims] 48 | self.normalize = normalize 49 | self.fused_op = _FUSED_OPS[(dims, use_fc)] 50 | 51 | if use_fc: 52 | gp_weight_shape = (config['num_product_weights'], n_features, n_features) 53 | else: 54 | gp_weight_shape = (n_features, config['num_product_weights']) 55 | 56 | self.gp_weight = torch.nn.Parameter(torch.empty(gp_weight_shape)) 57 | 58 | linear_weight_shape = (config['num_grades'], n_features, n_features) 59 | self.linear_weight = torch.nn.Parameter(torch.empty(linear_weight_shape)) 60 | self.linear_bias = torch.nn.Parameter(torch.zeros(1, 1, n_features)) 61 | 62 | self.register_buffer("weight_expansion", config['weight_expansion']) 63 | 64 | torch.nn.init.normal_(self.gp_weight, std=1 / (math.sqrt(dims + 1))) 65 | torch.nn.init.normal_(self.linear_weight, std=1 / math.sqrt(n_features) if use_fc else 1 / math.sqrt(n_features * (dims + 1))) 66 | 67 | def forward(self, x): 68 | y = torch.bmm(x, self.linear_weight[self.weight_expansion]) + self.linear_bias 69 | return self.fused_op(x, y, self.gp_weight, self.normalize) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Flash Clifford 4 | `flash-clifford` provides efficient Triton-based implementations of Clifford algebra-based models. 5 | 6 | ![Flash Clifford Logo](logo.png) 7 | 8 | 9 |
10 | 11 | ## $O(n)$-Equivariant operators 12 | The list of currently implemented $O(2)$- and $O(3)$-equivariant operators: 13 | - `fused_gelu_sgp_norm_nd`: multivector GELU $\rightarrow$ weighted geometric product $\rightarrow$ (optionally) multivector RMSNorm 14 | - `fused_gelu_fcgp_norm_nd`: multivector GELU $\rightarrow$ fully connected geometric product $\rightarrow$ (optionally) multivector RMSNorm 15 | - `linear layer`: multivector linear $\rightarrow$ `fused_gelu_sgp_norm_nd` 16 | 17 | Suggestions for additional operators are welcome :) 18 | 19 | ## Primer on Clifford Algebra 20 | Clifford algebra is tightly connected to the Euclidean group $E(n)$: its elements are **multivectors**, stacks of graded basis components (scalar, vector, etc.), and these components correspond 1:1 to irreducible representations of $O(n)$. For example, in 3D: 21 | 22 | | Grades | Irreps | 23 | |------------------|--------| 24 | | 0 (scalar) | 0e | 25 | | 1 (vector) | 1o | 26 | | 2 (bivector) | 1e | 27 | | 3 (trivector) | 0o | 28 | 29 | The **geometric product** is a bilinear operation that takes two multivectors and returns a multivector, mixing information between grades equivariantly ([Ruhe et al., 2023](https://arxiv.org/abs/2305.11141)). It can be viewed as a restricted tensor product: the key difference is that the geometric product does not generate higher-order representations (e.g., frequency 2). While this may seem to limit expressivity, the fixed and simple structure of the geometric product admits very efficient computation and fits in roughly 1k LOC (compared to 10k LOC in [cuEquivariance](https://github.com/NVIDIA/cuEquivariance)). Empirically, Clifford algebra-based neural networks achieve state-of-the-art performance on [N-body tasks](https://arxiv.org/abs/2305.11141) and [jet tagging](https://arxiv.org/abs/2405.14806), yet they are often dismissed as inefficient, which is exactly what this repo aims to address :) 30 | 31 | ## Performance 32 | ### Our approach 33 | The baseline approach of [Ruhe et al., 2023](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks) implements the geometric product as a dense einsum `bni, mnijk, bnk -> bmj`, where `mnijk` is a tensor (the Cayley table) encoding how element `i` of the first multivector interacts with element `k` of the second to produce element `j` of the output multivector. This gives a single out-of-the-box implementation for any metric space, which is definitely cool, but it is slow because the Cayley table is extremely sparse (85% zeros in 2D, 95% in 3D). We therefore improve performance mainly by hardcoding the rules of the geometric product, eliminating the wasted multiply-adds. The second source of optimization is kernel fusion: the activation function, which typically precedes the geometric product, and the normalization are fused into a single kernel together with the product. Finally, a significant speedup comes from switching to the `(MV_DIM, BATCH_SIZE, NUM_FEATURES)` memory layout, which allows expressing the linear layer as a batched matmul.
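For intuition, here is a minimal, self-contained sketch (not code from this repository) contrasting the dense-einsum formulation above with the batched-matmul view of the grade-wise linear layer enabled by the `(MV_DIM, BATCH_SIZE, NUM_FEATURES)` layout. The tensor names and shapes below are assumptions chosen to match the einsum string; the random `w_cayley` merely stands in for the (mostly zero) weighted Cayley table.

```python
import torch

MV_DIM = 4                       # Cl(2,0): scalar, two vector components, pseudoscalar
BATCH, N_IN, N_OUT = 8, 16, 16   # tiny sizes, for illustration only

# Dense-einsum geometric product in the baseline style. Assumed shapes:
#   x, y     : (batch, n_features_in, MV_DIM)
#   w_cayley : (n_features_out, n_features_in, MV_DIM, MV_DIM, MV_DIM)
# In the real Cayley table ~85% (2D) / ~95% (3D) of entries are zero,
# yet the einsum still pays for every multiply-add.
x = torch.randn(BATCH, N_IN, MV_DIM)
y = torch.randn(BATCH, N_IN, MV_DIM)
w_cayley = torch.randn(N_OUT, N_IN, MV_DIM, MV_DIM, MV_DIM)
out = torch.einsum('bni,mnijk,bnk->bmj', x, w_cayley, y)  # (BATCH, N_OUT, MV_DIM)

# With the (MV_DIM, BATCH, N_FEATURES) layout, a grade-wise linear layer becomes
# a single batched matmul over multivector components (in the actual layer the
# matrix is shared across components of the same grade).
x_t = torch.randn(MV_DIM, BATCH, N_IN)
w_lin = torch.randn(MV_DIM, N_IN, N_OUT)
y_t = torch.bmm(x_t, w_lin)  # (MV_DIM, BATCH, N_OUT)
```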
34 | 35 | ### Benchmarking 36 | To demonstrate performance improvements, we benchmark the following linear layer: 37 | ``` 38 | input: multivector features x, GP weights w 39 | 1) y = MVLinear(x) 40 | 2) x = MVGELU(x) 41 | 3) y = MVGELU(y) 42 | 4) o = weighted_gp(x, y, w) 43 | 5) o = MVRMSNorm(o) 44 | ``` 45 | which is a primitive of a [Clifford Algebra MLP](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/blob/8482b06b71712dcea2841ebe567d37e7f8432d27/models/nbody_cggnn.py#L47). 46 | We compare against `cuEquivariance` (using the correspondence between multivectors and irreps) and the baseline CEGNN implementation, achieving significant improvements: 47 | 48 |
49 | ![MLP forward time](tests/benchmarks/results/clifford_mlp_fwd_bwd_runtime.png) 50 | ![MLP forward time](tests/benchmarks/results/clifford_mlp_runtime_scaling.png) 51 |
52 | 53 | 54 | 55 | ## Requirements 56 | The following requirements must be satisfied: 57 | - PyTorch 58 | - Triton >= 3.0 59 | 60 | If you want to run tests, additionally: 61 | - NumPy 62 | - matplotlib 63 | 64 | 65 | ## Usage 66 | ```python 67 | import torch 68 | from modules.layer import Layer 69 | 70 | # Input: multivectors in 3D of shape (8, batch, features) 71 | x = torch.randn(8, 4096, 512).cuda() 72 | 73 | # Linear layer: grade-wise linear + weighted GP 74 | layer = Layer(512, dims=3, normalize=True, use_fc=False).cuda() 75 | 76 | output = layer(x) 77 | ``` 78 | 79 | ## Benchmarking 80 | 81 | Run benchmarks (runtime + memory) with: 82 | ```bash 83 | python -m tests.benchmarks.layer_3d 84 | ``` 85 | This generates a heatmap comparison against a torch-compiled baseline implementation; precomputed results (obtained on an RTX 4500) are available in [results](tests/benchmarks/results). 86 | 87 | ## Testing 88 | 89 | To verify correctness against a PyTorch baseline: 90 | ```bash 91 | python -m tests.p3m0 92 | ``` 93 | This checks both the forward and backward (gradient) passes and also measures runtime and memory consumption. 94 | 95 | ## Citation 96 | If you find this repository helpful, please cite our work: 97 | ``` 98 | @software{flashclifford2025, 99 | title = {Flash Clifford: Hardware-Efficient Implementation of Clifford Algebra Neural Networks}, 100 | author = {Zhdanov, Maksim}, 101 | url = {https://github.com/maxxxzdn/flash-clifford}, 102 | year = {2025} 103 | } 104 | ``` -------------------------------------------------------------------------------- /tests/baselines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | ### Activations ### 5 | 6 | def mv_gelu(x): 7 | """Apply GELU activation gated by scalar component.""" 8 | scalar = x[[0]] 9 | gate = 0.5 * (1 + torch.erf(scalar * 0.7071067811865475)) 10 | return x * gate 11 | 12 | 13 | ### Norms ### 14 | 15 | def mv_rmsnorm_2d(x, eps=1e-6): 16 | """RMS normalization for Cl(2,0) (scalar, vector, pseudoscalar).""" 17 | scalar = x[[0]] 18 | vector = x[[1, 2]] 19 | pseudoscalar = x[[3]] 20 | 21 | scalar_rms = (scalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 22 | scalar = scalar / scalar_rms 23 | 24 | vector_norm = vector.norm(dim=0, keepdim=True) 25 | vector_rms = (vector_norm.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 26 | vector = vector / vector_rms 27 | 28 | pseudoscalar_rms = (pseudoscalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 29 | pseudoscalar = pseudoscalar / pseudoscalar_rms 30 | 31 | return torch.cat([scalar, vector, pseudoscalar], dim=0) 32 | 33 | 34 | def mv_rmsnorm_3d(x, eps=1e-6): 35 | """RMS normalization for Cl(3,0) (scalar, vector, bivector, pseudoscalar).""" 36 | scalar = x[[0]] 37 | vector = x[[1, 2, 3]] 38 | bivector = x[[4, 5, 6]] 39 | pseudoscalar = x[[7]] 40 | 41 | scalar_rms = (scalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 42 | scalar = scalar / scalar_rms 43 | 44 | vector_norm = vector.norm(dim=0, keepdim=True) 45 | vector_rms = (vector_norm.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 46 | vector = vector / vector_rms 47 | 48 | bivector_norm = bivector.norm(dim=0, keepdim=True) 49 | bivector_rms = (bivector_norm.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 50 | bivector = bivector / bivector_rms 51 | 52 | pseudoscalar_rms = (pseudoscalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt() 53 | pseudoscalar = pseudoscalar / pseudoscalar_rms 54 | 55 | return torch.cat([scalar,
vector, bivector, pseudoscalar], dim=0) 56 | 57 | 58 | ### Geometric Product Layers ### 59 | 60 | def sgp_2d(x, y, weight): 61 | """Weighted geometric product in Cl(2,0).""" 62 | x0, x1, x2, x3 = x[0], x[1], x[2], x[3] 63 | y0, y1, y2, y3 = y[0], y[1], y[2], y[3] 64 | 65 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9 = weight.T 66 | 67 | o0 = w0 * x0 * y0 + w3 * (x1 * y1 + x2 * y2) - w7 * x3 * y3 68 | o1 = w1 * x0 * y1 + w4 * x1 * y0 - w5 * x2 * y3 + w8 * x3 * y2 69 | o2 = w1 * x0 * y2 + w5 * x1 * y3 + w4 * x2 * y0 - w8 * x3 * y1 70 | o3 = w2 * x0 * y3 + w6 * (x1 * y2 - x2 * y1) + w9 * x3 * y0 71 | 72 | return torch.stack([o0, o1, o2, o3], dim=0) 73 | 74 | 75 | def sgp_3d(x, y, weight): 76 | """Weighted geometric product in Cl(3,0).""" 77 | x0, x1, x2, x3, x4, x5, x6, x7 = x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7] 78 | y0, y1, y2, y3, y4, y5, y6, y7 = y[0], y[1], y[2], y[3], y[4], y[5], y[6], y[7] 79 | 80 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19 = weight.T 81 | 82 | o0 = (w0*x0*y0 + w4 * (x1*y1 + x2*y2 + x3*y3) - w10 * (x4*y4 + x5*y5 + x6*y6) - w16*x7*y7) 83 | o1 = (w1*x0*y1 + w5*x1*y0 - w6 * (x2*y4 + x3*y5) + w11 * (x4*y2 + x5*y3) - w12*x6*y7 - w17*x7*y6) 84 | o2 = (w1*x0*y2 + w6*x1*y4 + w5*x2*y0 - w6*x3*y6 - w11*x4*y1 + w12*x5*y7 + w11*x6*y3 + w17*x7*y5) 85 | o3 = (w1*x0*y3 + w6 * (x1*y5 + x2*y6) + w5*x3*y0 - w12*x4*y7 - w11 * (x5*y1 + x6*y2) - w17*x7*y4) 86 | o4 = (w2*x0*y4 + w7*x1*y2 - w7*x2*y1 + w8*x3*y7 + w13*x4*y0 - w14*x5*y6 + w14*x6*y5 + w18*x7*y3) 87 | o5 = (w2*x0*y5 + w7*x1*y3 - w8*x2*y7 - w7*x3*y1 + w14*x4*y6 + w13*x5*y0 - w14*x6*y4 - w18*x7*y2) 88 | o6 = (w2*x0*y6 + w8*x1*y7 + w7*x2*y3 - w7*x3*y2 - w14*x4*y5 + w14*x5*y4 + w13*x6*y0 + w18*x7*y1) 89 | o7 = (w3*x0*y7 + w9*x1*y6 - w9*x2*y5 + w9*x3*y4 + w15*x4*y3 - w15*x5*y2 + w15*x6*y1 + w19*x7*y0) 90 | 91 | return torch.stack([o0, o1, o2, o3, o4, o5, o6, o7], dim=0) 92 | 93 | 94 | def fcgp_2d(x, y, weight): 95 | """Fully connected geometric product in Cl(2,0).""" 96 | x0, x1, x2, x3 = x[0], x[1], x[2], x[3] 97 | y0, y1, y2, y3 = y[0], y[1], y[2], y[3] 98 | 99 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9 = weight 100 | 101 | o0 = (x0 * y0) @ w0 + (x1 * y1 + x2 * y2) @ w3 - (x3 * y3) @ w7 102 | o1 = (x0 * y1) @ w1 + (x1 * y0) @ w4 - (x2 * y3) @ w5 + (x3 * y2) @ w8 103 | o2 = (x0 * y2) @ w1 + (x1 * y3) @ w5 + (x2 * y0) @ w4 - (x3 * y1) @ w8 104 | o3 = x0 * y3 @ w2 + (x1 * y2 - x2 * y1) @ w6 + (x3 * y0) @ w9 105 | 106 | return torch.stack([o0, o1, o2, o3], dim=0) 107 | 108 | 109 | def fcgp_3d(x, y, weight): 110 | """Fully connected geometric product in Cl(3,0).""" 111 | x0, x1, x2, x3, x4, x5, x6, x7 = x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7] 112 | y0, y1, y2, y3, y4, y5, y6, y7 = y[0], y[1], y[2], y[3], y[4], y[5], y[6], y[7] 113 | 114 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19 = weight 115 | 116 | o0 = (x0 * y0) @ w0 + (x1 * y1 + x2 * y2 + x3 * y3) @ w4 - (x4 * y4 + x5 * y5 + x6 * y6) @ w10 - (x7 * y7) @ w16 117 | o1 = (x0 * y1) @ w1 + (x1 * y0) @ w5 - (x2 * y4 + x3 * y5) @ w6 + (x4 * y2 + x5 * y3) @ w11 - (x6 * y7) @ w12 - (x7 * y6) @ w17 118 | o2 = (x0 * y2) @ w1 + (x1 * y4) @ w6 + (x2 * y0) @ w5 - (x3 * y6) @ w6 - (x4 * y1) @ w11 + (x5 * y7) @ w12 + (x6 * y3) @ w11 + (x7 * y5) @ w17 119 | o3 = (x0 * y3) @ w1 + (x1 * y5 + x2 * y6) @ w6 + (x3 * y0) @ w5 - (x4 * y7) @ w12 - (x5 * y1 + x6 * y2) @ w11 - (x7 * y4) @ w17 120 | o4 = (x0 * y4) @ w2 + (x1 * y2 - x2 * y1) @ w7 + (x3 * y7) @ w8 + (x4 * y0) @ w13 - (x5 * y6) @ w14 + (x6 * y5) @ 
w14 + (x7 * y3) @ w18 121 | o5 = (x0 * y5) @ w2 + (x1 * y3) @ w7 - (x2 * y7) @ w8 - (x3 * y1) @ w7 + (x4 * y6) @ w14 + (x5 * y0) @ w13 - (x6 * y4) @ w14 - (x7 * y2) @ w18 122 | o6 = (x0 * y6) @ w2 + (x1 * y7) @ w8 + (x2 * y3) @ w7 - (x3 * y2) @ w7 - (x4 * y5) @ w14 + (x5 * y4) @ w14 + (x6 * y0) @ w13 + (x7 * y1) @ w18 123 | o7 = (x0 * y7) @ w3 + (x1 * y6 - x2 * y5 + x3 * y4) @ w9 + (x4 * y3 - x5 * y2 + x6 * y1) @ w15 + (x7 * y0) @ w19 124 | 125 | return torch.stack([o0, o1, o2, o3, o4, o5, o6, o7], dim=0) 126 | 127 | 128 | ### Baseline implementations ### 129 | 130 | @torch.compile 131 | def gelu_sgp_norm_2d_torch(x, y, weight, normalize=True): 132 | """Geometric product layer with GELU activation and RMS normalization in Cl(2,0).""" 133 | x = mv_gelu(x) 134 | y = mv_gelu(y) 135 | o = sgp_2d(x, y, weight) 136 | if normalize: 137 | o = mv_rmsnorm_2d(o) 138 | return o 139 | 140 | 141 | @torch.compile 142 | def gelu_sgp_norm_3d_torch(x, y, weight, normalize=True): 143 | """Geometric product layer with GELU activation and RMS normalization in Cl(3,0).""" 144 | x = mv_gelu(x) 145 | y = mv_gelu(y) 146 | o = sgp_3d(x, y, weight) 147 | if normalize: 148 | o = mv_rmsnorm_3d(o) 149 | return o 150 | 151 | 152 | @torch.compile 153 | def gelu_fcgp_norm_2d_torch(x, y, weight, normalize=True): 154 | """Fully connected geometric product layer with GELU activation and RMS normalization in Cl(2,0).""" 155 | x = mv_gelu(x) 156 | y = mv_gelu(y) 157 | o = fcgp_2d(x, y, weight) 158 | if normalize: 159 | o = mv_rmsnorm_2d(o) 160 | return o 161 | 162 | 163 | @torch.compile 164 | def gelu_fcgp_norm_3d_torch(x, y, weight, normalize=True): 165 | """Fully connected geometric product layer with GELU activation and RMS normalization in Cl(3,0).""" 166 | x = mv_gelu(x) 167 | y = mv_gelu(y) 168 | o = fcgp_3d(x, y, weight) 169 | if normalize: 170 | o = mv_rmsnorm_3d(o) 171 | return o 172 | 173 | 174 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | import triton 7 | 8 | from contextlib import contextmanager 9 | 10 | 11 | @contextmanager 12 | def measure_memory(): 13 | torch.cuda.reset_peak_memory_stats() 14 | torch.cuda.empty_cache() 15 | torch.cuda.synchronize() 16 | 17 | peak = [None] 18 | yield peak 19 | 20 | torch.cuda.synchronize() 21 | peak[0] = torch.cuda.max_memory_allocated() / 1024**2 22 | 23 | 24 | def print_benchmark_results(avg_time_fused, avg_time_torch, mem_fused, mem_torch, title=""): 25 | header_length_pad = max(0, 35 - len(title) - 4) 26 | print(f"\n┌─ {title} " + "─"*header_length_pad + "┐") 27 | print("│") 28 | print(f"│ Runtime:") 29 | print(f"│ Fused Kernel : {avg_time_fused:>8.2f} ms") 30 | print(f"│ PyTorch : {avg_time_torch:>8.2f} ms") 31 | print(f"│ Speedup : {avg_time_torch / avg_time_fused:>8.2f}×") 32 | print("│") 33 | print(f"│ Memory Usage:") 34 | print(f"│ Fused Kernel : {mem_fused:>8.2f} MB") 35 | print(f"│ PyTorch : {mem_torch:>8.2f} MB") 36 | print(f"│ Memory Ratio : {mem_fused / mem_torch:>8.2f}×") 37 | print("│") 38 | print("└" + "─"*34 + "┘\n") 39 | 40 | 41 | def run_benchmark(triton_fn, torch_fn, args, rep, warmup=200, verbose=True, return_mode='mean'): 42 | """Run forward and forward+backward benchmarks.""" 43 | # Forward-only benchmark 44 | avg_time_fused = triton.testing.do_bench(lambda: triton_fn(*args), warmup, rep, return_mode=return_mode) 45 | 
avg_time_torch = triton.testing.do_bench(lambda: torch_fn(*args), warmup, rep, return_mode=return_mode) 46 | with measure_memory() as mem_fused_fwd: _ = triton_fn(*args) 47 | with measure_memory() as mem_torch_fwd: _ = torch_fn(*args) 48 | 49 | # Forward + backward benchmark 50 | args = [arg.clone().detach().requires_grad_(True) for arg in args] 51 | avg_time_fused_bwd = triton.testing.do_bench(lambda: triton_fn(*args).sum().backward(), warmup, rep, return_mode=return_mode) 52 | avg_time_torch_bwd = triton.testing.do_bench(lambda: torch_fn(*args).sum().backward(), warmup, rep, return_mode=return_mode) 53 | with measure_memory() as mem_fused_bwd: _ = triton_fn(*args) 54 | with measure_memory() as mem_torch_bwd: _ = torch_fn(*args) 55 | 56 | if verbose: 57 | print_benchmark_results( 58 | avg_time_fused, avg_time_torch, mem_fused_fwd[0], mem_torch_fwd[0], 59 | title="Forward Pass" 60 | ) 61 | print_benchmark_results( 62 | avg_time_fused_bwd, avg_time_torch_bwd, mem_fused_bwd[0], mem_torch_bwd[0], 63 | title="Forward + Backward Pass" 64 | ) 65 | 66 | return avg_time_fused, avg_time_torch, mem_fused_fwd[0], mem_torch_fwd[0], avg_time_fused_bwd, avg_time_torch_bwd, mem_fused_bwd[0], mem_torch_bwd[0] 67 | 68 | 69 | def run_correctness_test(triton_fn, torch_fn, kwargs): 70 | """Run forward and backward correctness test.""" 71 | # Forward correctness check 72 | out_triton = triton_fn(**kwargs) 73 | out_torch = torch_fn(**kwargs) 74 | 75 | max_diff = (out_torch - out_triton).abs().max().item() 76 | is_correct = torch.allclose(out_torch, out_triton, atol=1e-5) 77 | check_mark = " ✔" if is_correct else " ✘" 78 | print(f"Max absolute difference (fwd): {max_diff:.1e}{check_mark}") 79 | 80 | # Backward correctness check: separate leaf tensors per implementation so gradients do not accumulate into the same buffers 81 | kwargs_torch = {k: v.clone().detach().requires_grad_(True) for k, v in kwargs.items()} 82 | kwargs_triton = {k: v.clone().detach().requires_grad_(True) for k, v in kwargs.items()} 83 | out_torch = torch_fn(**kwargs_torch) 84 | out_triton = triton_fn(**kwargs_triton) 85 | 86 | grad_output = torch.randn_like(out_torch).contiguous() 87 | out_torch.backward(grad_output) 88 | out_triton.backward(grad_output) 89 | 90 | for name in kwargs: 91 | grad_diff = (kwargs_torch[name].grad - kwargs_triton[name].grad).abs().max().item() 92 | grad_correct = torch.allclose(kwargs_torch[name].grad, kwargs_triton[name].grad, atol=1e-2) 93 | print(f"grad {name} max diff: {grad_diff:.1e}" + (" ✔" if grad_correct else " ✘")) 94 | 95 | 96 | def plot_heatmap(results, metric_key, title, save_path, cmap='RdYlGn', invert_cmap=False): 97 | """Heatmap visualization of benchmark results.""" 98 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 99 | 100 | batch_sizes = sorted(set(r['batch_size'] for r in results)) 101 | num_features_list = sorted(set(r['num_features'] for r in results)) 102 | 103 | matrix = np.zeros((len(batch_sizes), len(num_features_list))) 104 | for r in results: 105 | i = batch_sizes.index(r['batch_size']) 106 | j = num_features_list.index(r['num_features']) 107 | matrix[i, j] = r[metric_key] if r[metric_key] is not None else 0 108 | 109 | fig, ax = plt.subplots(figsize=(10, 8)) 110 | cmap_name = f'{cmap}_r' if invert_cmap else cmap 111 | im = ax.imshow(matrix, cmap=cmap_name, aspect='auto') 112 | 113 | ax.set_xticks(np.arange(len(num_features_list))) 114 | ax.set_yticks(np.arange(len(batch_sizes))) 115 | ax.set_xticklabels(num_features_list) 116 | ax.set_yticklabels(batch_sizes) 117 | 118 | ax.set_xlabel('Number of Features', fontsize=12) 119 | ax.set_ylabel('Batch Size', fontsize=12) 120 | ax.set_title(title, fontsize=14, pad=20) 121 | 122 | cbar = plt.colorbar(im, ax=ax) 123 | cbar_label = 'Speedup (x)' if 'speedup' in metric_key else
'Memory Ratio (x)' 124 | cbar.set_label(cbar_label, fontsize=12) 125 | 126 | # Add text annotations 127 | for i in range(len(batch_sizes)): 128 | for j in range(len(num_features_list)): 129 | value = matrix[i, j] 130 | text_str = f'{value:.2f}x' if value > 0 else 'OOM' 131 | ax.text(j, i, text_str, ha="center", va="center", color="black", fontsize=10) 132 | 133 | plt.tight_layout() 134 | plt.savefig(save_path, dpi=300, bbox_inches='tight') 135 | print(f"Heatmap saved to: {save_path}") 136 | plt.close() 137 | 138 | 139 | def print_results_table(results, title): 140 | """Print results as a formatted table.""" 141 | separator = "=" * 140 142 | divider = "-" * 140 143 | 144 | # Forward pass results 145 | print(f"\n{separator}") 146 | print(f"FORWARD PASS RESULTS - {title}") 147 | print(separator) 148 | print(f"{'Batch':<8} {'Features':<10} {'Fused (ms)':<12} {'Torch (ms)':<12} " 149 | f"{'Speedup':<10} {'Max Diff':<12} {'Correct':<8}") 150 | print(divider) 151 | 152 | for r in results: 153 | speedup_str = f"{r['speedup_fwd']:.2f}x" if r['speedup_fwd'] else "N/A" 154 | correct_mark = '✔' if r['is_correct'] else '✘' 155 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {r['time_fwd_fused']:<12.2f} " 156 | f"{r['time_fwd_torch']:<12.2f} {speedup_str:<10} {r['max_diff']:<12.1e} " 157 | f"{correct_mark:<8}") 158 | 159 | print(separator) 160 | 161 | # Forward + backward pass results 162 | print(f"\n{separator}") 163 | print(f"FORWARD + BACKWARD PASS RESULTS - {title}") 164 | print(separator) 165 | print(f"{'Batch':<8} {'Features':<10} {'Fused (ms)':<12} {'Torch (ms)':<12} {'Speedup':<10}") 166 | print(divider) 167 | 168 | for r in results: 169 | fused_time = f"{r['time_fwd_bwd_fused']:.2f}" if r['time_fwd_bwd_fused'] else "OOM" 170 | torch_time = f"{r['time_fwd_bwd_torch']:.2f}" if r['time_fwd_bwd_torch'] else "OOM" 171 | speedup = f"{r['speedup_fwd_bwd']:.2f}x" if r['speedup_fwd_bwd'] else "N/A" 172 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {fused_time:<12} " 173 | f"{torch_time:<12} {speedup:<10}") 174 | 175 | print(separator) 176 | 177 | # Forward memory usage 178 | print(f"\n{separator}") 179 | print(f"FORWARD MEMORY USAGE - {title}") 180 | print(separator) 181 | print(f"{'Batch':<8} {'Features':<10} {'Fused (MB)':<12} {'Torch (MB)':<12} {'Ratio':<10}") 182 | print(divider) 183 | 184 | for r in results: 185 | fused_mem = f"{r['mem_fwd_fused']:.2f}" if r['mem_fwd_fused'] else "OOM" 186 | torch_mem = f"{r['mem_fwd_torch']:.2f}" if r['mem_fwd_torch'] else "OOM" 187 | ratio = f"{r['mem_ratio_fwd']:.2f}x" if r['mem_ratio_fwd'] else "N/A" 188 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {fused_mem:<12} " 189 | f"{torch_mem:<12} {ratio:<10}") 190 | 191 | print(separator) 192 | 193 | # Forward + backward memory usage 194 | print(f"\n{separator}") 195 | print(f"FORWARD + BACKWARD MEMORY USAGE - {title}") 196 | print(separator) 197 | print(f"{'Batch':<8} {'Features':<10} {'Fused (MB)':<12} {'Torch (MB)':<12} {'Ratio':<10}") 198 | print(divider) 199 | 200 | for r in results: 201 | fused_mem = f"{r['mem_fwd_bwd_fused']:.2f}" if r['mem_fwd_bwd_fused'] else "OOM" 202 | torch_mem = f"{r['mem_fwd_bwd_torch']:.2f}" if r['mem_fwd_bwd_torch'] else "OOM" 203 | ratio = f"{r['mem_ratio_fwd_bwd']:.2f}x" if r['mem_ratio_fwd_bwd'] else "N/A" 204 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {fused_mem:<12} " 205 | f"{torch_mem:<12} {ratio:<10}") 206 | 207 | print(separator) 208 | 209 | 210 | def run_single_benchmark(triton_fn, torch_fn, setup_fn, batch_size, num_features, rep, 
warmup=200, verify_correctness=True, return_mode='mean'): 211 | """Run a single benchmark configuration.""" 212 | args = setup_fn(batch_size, num_features) 213 | 214 | out_triton = triton_fn(*args) 215 | out_torch = torch_fn(*args) 216 | 217 | if verify_correctness: 218 | is_correct = torch.allclose(out_torch, out_triton, atol=1e-5) 219 | max_diff = (out_torch - out_triton).abs().max().item() 220 | else: 221 | is_correct = False 222 | max_diff = 1e10 223 | 224 | time_fwd_fused, time_fwd_torch, mem_fwd_fused, mem_fwd_torch, \ 225 | time_fwd_bwd_fused, time_fwd_bwd_torch, mem_fwd_bwd_fused, mem_fwd_bwd_torch = \ 226 | run_benchmark( 227 | triton_fn, torch_fn, args, rep, 228 | warmup=warmup, verbose=False, return_mode=return_mode 229 | ) 230 | 231 | return { 232 | 'batch_size': batch_size, 233 | 'num_features': num_features, 234 | 'time_fwd_fused': time_fwd_fused, 235 | 'time_fwd_torch': time_fwd_torch, 236 | 'speedup_fwd': time_fwd_torch / time_fwd_fused if time_fwd_fused else None, 237 | 'time_fwd_bwd_fused': time_fwd_bwd_fused, 238 | 'time_fwd_bwd_torch': time_fwd_bwd_torch, 239 | 'speedup_fwd_bwd': time_fwd_bwd_torch / time_fwd_bwd_fused if time_fwd_bwd_fused else None, 240 | 'mem_fwd_fused': mem_fwd_fused, 241 | 'mem_fwd_torch': mem_fwd_torch, 242 | 'mem_ratio_fwd': mem_fwd_fused / mem_fwd_torch if mem_fwd_torch else None, 243 | 'mem_fwd_bwd_fused': mem_fwd_bwd_fused, 244 | 'mem_fwd_bwd_torch': mem_fwd_bwd_torch, 245 | 'mem_ratio_fwd_bwd': mem_fwd_bwd_fused / mem_fwd_bwd_torch if mem_fwd_bwd_torch else None, 246 | 'max_diff': max_diff, 247 | 'is_correct': is_correct, 248 | } 249 | 250 | 251 | def run_sweep(triton_fn, torch_fn, setup_fn, 252 | batch_sizes=[1024, 2048, 4096, 8192], 253 | num_features_list=[128, 256, 512, 1024], 254 | rep=1000): 255 | """Run benchmark sweep across batch sizes and feature dimensions.""" 256 | results = [] 257 | 258 | print("Running benchmark sweep...") 259 | print(f"Batch sizes: {batch_sizes}") 260 | print(f"Num features: {num_features_list}") 261 | 262 | for batch_size in batch_sizes: 263 | for num_features in num_features_list: 264 | print(f"Running batch_size={batch_size}, num_features={num_features}...", end=" ") 265 | result = run_single_benchmark(triton_fn, torch_fn, setup_fn, batch_size, num_features, rep) 266 | results.append(result) 267 | 268 | fwd_msg = (f"Fwd: {result['speedup_fwd']:.2f}x" 269 | if result['speedup_fwd'] else "Fwd: OOM") 270 | bwd_msg = (f"Fwd+Bwd: {result['speedup_fwd_bwd']:.2f}x" 271 | if result['speedup_fwd_bwd'] else "Fwd+Bwd: OOM") 272 | print(f"{fwd_msg}, {bwd_msg}") 273 | 274 | return results -------------------------------------------------------------------------------- /ops/p2m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | MV_DIM = 4 6 | NUM_GRADES = 3 7 | NUM_PRODUCT_WEIGHTS = 10 8 | EPS = 1e-6 9 | 10 | # tuned at RTX 4500 11 | DEFAULT_BATCH_BLOCK = 4 12 | DEFAULT_FEATURE_BLOCK = 128 13 | DEFAULT_NUM_WARPS = 16 14 | DEFAULT_NUM_STAGES = 1 15 | 16 | 17 | @triton.jit 18 | def compute_gelu_gate(x): 19 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))""" 20 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype) 21 | 22 | 23 | @triton.jit 24 | def compute_gelu_gate_grad(x): 25 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)""" 26 | return 0.3989422804 * tl.exp(-0.5 * x * x) 27 | 28 | 29 | @triton.jit 30 | def gelu_wgp_norm_kernel_fwd( 31 | x_ptr, 32 | 
y_ptr, 33 | output_ptr, 34 | weights_ptr, 35 | pnorm_ptr, 36 | NORMALIZE: tl.constexpr, 37 | batch_size: tl.constexpr, 38 | n_features: tl.constexpr, 39 | BATCH_BLOCK: tl.constexpr, 40 | FEATURE_BLOCK: tl.constexpr, 41 | NUM_PRODUCT_WEIGHTS: tl.constexpr, 42 | ): 43 | """ 44 | Apply GELU non-linearity to inputs, compute weighted geometric product, 45 | and accumulate squared norms for grade-wise RMSNorm. 46 | """ 47 | batch_block_id = tl.program_id(axis=0) 48 | thread_block_id = tl.program_id(axis=1) 49 | 50 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 51 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 52 | 53 | batch_mask = batch_ids < batch_size 54 | feature_mask = feature_ids < n_features 55 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 56 | 57 | stride_component = batch_size * n_features 58 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 59 | 60 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS 61 | 62 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask) 63 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask) 64 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask) 65 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask) 66 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask) 67 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask) 68 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask) 69 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask) 70 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask) 71 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask) 72 | 73 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 74 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 75 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 76 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 77 | 78 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 79 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 80 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 81 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 82 | 83 | # Apply GELU gate 84 | gate_x = compute_gelu_gate(x0) 85 | gate_y = compute_gelu_gate(y0) 86 | 87 | x0 = x0 * gate_x 88 | x1 = x1 * gate_x 89 | x2 = x2 * gate_x 90 | x3 = x3 * gate_x 91 | 92 | y0 = y0 * gate_y 93 | y1 = y1 * gate_y 94 | y2 = y2 * gate_y 95 | y3 = y3 * gate_y 96 | 97 | # Compute geometric product 98 | o0 = w0*x0*y0 + w3*(x1*y1 + x2*y2) - w7*x3*y3 99 | o1 = w1*x0*y1 + w4*x1*y0 - w5*x2*y3 + w8*x3*y2 100 | o2 = w1*x0*y2 + w5*x1*y3 + w4*x2*y0 - w8*x3*y1 101 | o3 = w2*x0*y3 + w6*(x1*y2 - x2*y1) + w9*x3*y0 102 | 103 | if NORMALIZE: 104 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features 105 | pn_vector = tl.sum(o1*o1 + o2*o2, axis=1) / n_features 106 | pn_pseudo = tl.sum(o3 * o3, axis=1) / n_features 107 | 108 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask) 109 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask) 110 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask) 111 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_pseudo, mask=batch_mask) 112 | 113 | tl.store(output_ptr + 0 * stride_component 
+ base_offset, o0, mask=batch_feature_mask) 114 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask) 115 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask) 116 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask) 117 | 118 | 119 | @triton.jit 120 | def normalize_with_sqrt_kernel( 121 | output_ptr, 122 | pnorm_ptr, 123 | batch_size: tl.constexpr, 124 | n_features: tl.constexpr, 125 | BATCH_BLOCK: tl.constexpr, 126 | FEATURE_BLOCK: tl.constexpr, 127 | MV_DIM: tl.constexpr, 128 | EPS: tl.constexpr, 129 | ): 130 | """Normalize the output by dividing each grade with root of its accumulated norm.""" 131 | batch_block_id = tl.program_id(axis=0) 132 | thread_block_id = tl.program_id(axis=1) 133 | 134 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 135 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 136 | 137 | batch_mask = batch_ids < batch_size 138 | feature_mask = feature_ids < n_features 139 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None] 140 | 141 | component_ids = tl.arange(0, MV_DIM)[None, None, :] 142 | 143 | feature_offset = (component_ids * batch_size * n_features + 144 | batch_ids[:, None, None] * n_features + 145 | feature_ids[None, :, None]) 146 | 147 | norm_indices = component_ids * batch_size + batch_ids[:, None, None] 148 | 149 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None]) 150 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask) 151 | 152 | norm = tl.sqrt(pnorm + EPS) 153 | mv_normalized = mv / norm 154 | 155 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask) 156 | 157 | 158 | def gelu_geometric_product_norm_fwd( 159 | x: torch.Tensor, 160 | y: torch.Tensor, 161 | weight: torch.Tensor, 162 | normalize: bool, 163 | ) -> torch.Tensor: 164 | """Fused operation: GELU non-linearity, weighted geometric product, and grade-wise RMSNorm.""" 165 | assert x.shape == y.shape 166 | assert x.shape[0] == MV_DIM 167 | assert x.shape[2] == weight.shape[0] 168 | assert weight.shape[1] == NUM_PRODUCT_WEIGHTS 169 | 170 | _, B, N = x.shape 171 | 172 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 173 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 174 | 175 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 176 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 177 | 178 | output = torch.empty_like(x) 179 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize 180 | else torch.zeros((1,), device=x.device, dtype=x.dtype)) 181 | 182 | grid = (num_blocks_batch, num_blocks_features) 183 | 184 | gelu_wgp_norm_kernel_fwd[grid]( 185 | x, 186 | y, 187 | output, 188 | weight, 189 | partial_norm, 190 | normalize, 191 | B, 192 | N, 193 | BATCH_BLOCK, 194 | FEATURE_BLOCK, 195 | NUM_PRODUCT_WEIGHTS, 196 | num_warps=DEFAULT_NUM_WARPS, 197 | num_stages=DEFAULT_NUM_STAGES, 198 | ) 199 | 200 | if normalize: 201 | normalize_with_sqrt_kernel[grid]( 202 | output, 203 | partial_norm, 204 | B, 205 | N, 206 | BATCH_BLOCK, 207 | FEATURE_BLOCK, 208 | MV_DIM, 209 | EPS, 210 | num_warps=DEFAULT_NUM_WARPS, 211 | num_stages=DEFAULT_NUM_STAGES, 212 | ) 213 | 214 | return output, partial_norm 215 | 216 | 217 | @triton.jit 218 | def grad_o_dot_o_kernel( 219 | dot_ptr, 220 | pnorm_ptr, 221 | output_ptr, 222 | grad_output_ptr, 223 | batch_size: tl.constexpr, 224 | n_features: tl.constexpr, 225 | BATCH_BLOCK: tl.constexpr, 226 | FEATURE_BLOCK: 
tl.constexpr, 227 | EPS: tl.constexpr, 228 | ): 229 | """Compute the dot product of grad_output and output for each grade, accumulate across all features.""" 230 | batch_block_id = tl.program_id(axis=0) 231 | thread_block_id = tl.program_id(axis=1) 232 | 233 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 234 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 235 | 236 | batch_mask = batch_ids < batch_size 237 | feature_mask = feature_ids < n_features 238 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 239 | 240 | stride_component = batch_size * n_features 241 | offset = batch_ids[:, None] * n_features + feature_ids[None, :] 242 | 243 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 244 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 245 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 246 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 247 | 248 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 249 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 250 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None] 251 | 252 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 253 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 254 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 255 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 256 | 257 | rms_scalar = tl.sqrt(pn_scalar + EPS) 258 | rms_vector = tl.sqrt(pn_vector + EPS) 259 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 260 | 261 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1) 262 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2), axis=1) 263 | dot_pseudo = tl.sum(rms_pseudo * go3 * o3, axis=1) 264 | 265 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask) 266 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask) 267 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_pseudo, mask=batch_mask) 268 | 269 | 270 | @triton.jit 271 | def gelu_wgp_norm_kernel_bwd( 272 | x_ptr, 273 | y_ptr, 274 | output_ptr, 275 | weights_ptr, 276 | dot_ptr, 277 | pnorm_ptr, 278 | grad_output_ptr, 279 | grad_x_ptr, 280 | grad_y_ptr, 281 | grad_weight_ptr, 282 | NORMALIZE: tl.constexpr, 283 | batch_size: tl.constexpr, 284 | n_features: tl.constexpr, 285 | BATCH_BLOCK: tl.constexpr, 286 | FEATURE_BLOCK: tl.constexpr, 287 | NUM_PRODUCT_WEIGHTS: tl.constexpr, 288 | EPS: tl.constexpr, 289 | ): 290 | """Compute gradients w.r.t. 
inputs and weights of the fused operation.""" 291 | batch_block_id = tl.program_id(axis=0) 292 | thread_block_id = tl.program_id(axis=1) 293 | 294 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 295 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 296 | 297 | batch_mask = batch_ids < batch_size 298 | feature_mask = feature_ids < n_features 299 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 300 | 301 | stride_component = batch_size * n_features 302 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 303 | 304 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS 305 | block_offset = batch_block_id * n_features * NUM_PRODUCT_WEIGHTS 306 | 307 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask) 308 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask) 309 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask) 310 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask) 311 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask) 312 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask) 313 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask) 314 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask) 315 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask) 316 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask) 317 | 318 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 319 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 320 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 321 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 322 | 323 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 324 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 325 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 326 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 327 | 328 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 329 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 330 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 331 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 332 | 333 | if NORMALIZE: 334 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 335 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 336 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 337 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 338 | 339 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 340 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 341 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None] 342 | 343 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 344 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 345 | dot_pseudo = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None] 346 | 347 | 
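# NOTE (added comment, describing the existing accumulation scheme): each batch
# block writes its partial per-feature weight gradients into its own slice of
# grad_weight_ptr (selected via block_offset), so no atomic adds are needed here;
# the Python wrapper later reduces the partial sums with
# grad_weight = torch.sum(grad_weight, dim=0).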
rms_scalar = tl.sqrt(pn_scalar + EPS) 348 | rms_vector = tl.sqrt(pn_vector + EPS) 349 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 350 | 351 | go0 = go0/rms_scalar - o0 * dot_scalar / (n_features*rms_scalar*rms_scalar) 352 | go1 = go1/rms_vector - o1 * dot_vector / (n_features*rms_vector*rms_vector) 353 | go2 = go2/rms_vector - o2 * dot_vector / (n_features*rms_vector*rms_vector) 354 | go3 = go3/rms_pseudo - o3 * dot_pseudo / (n_features*rms_pseudo*rms_pseudo) 355 | 356 | # weighted geometric product backward 357 | gate_x = compute_gelu_gate(x0_raw) 358 | gate_y = compute_gelu_gate(y0_raw) 359 | 360 | x0 = x0_raw * gate_x 361 | x1 = x1_raw * gate_x 362 | x2 = x2_raw * gate_x 363 | x3 = x3_raw * gate_x 364 | 365 | y0 = y0_raw * gate_y 366 | y1 = y1_raw * gate_y 367 | y2 = y2_raw * gate_y 368 | y3 = y3_raw * gate_y 369 | 370 | tmp0 = go0*w0 371 | tmp1 = go1*w1 372 | tmp2 = go2*w1 373 | tmp3 = go0*w3 374 | tmp4 = w4*y0 375 | tmp5 = w5*y3 376 | tmp6 = go3*w6 377 | tmp7 = go0*w7 378 | tmp8 = go2*w8 379 | tmp9 = go1*x1 380 | tmp10 = go2*x2 381 | tmp11 = go1*x2 382 | 383 | x_grad_0 = (go3*w2*y3 + tmp0*y0 + tmp1*y1 + tmp2*y2) 384 | x_grad_1 = (go1*tmp4 + go2*tmp5 + tmp3*y1 + tmp6*y2) 385 | x_grad_2 = (-go1*tmp5 + go2*tmp4 + tmp3*y2 - tmp6*y1) 386 | x_grad_3 = (go1*w8*y2 + go3*w9*y0 - tmp7*y3 - tmp8*y1) 387 | 388 | y_grad_0 = (go3*w9*x3 + tmp0*x0 + tmp10*w4 + tmp9*w4) 389 | y_grad_1 = (tmp1*x0 + tmp3*x1 - tmp6*x2 - tmp8*x3) 390 | y_grad_2 = (go1*w8*x3 + tmp2*x0 + tmp3*x2 + tmp6*x1) 391 | y_grad_3 = (go2*w5*x1 + go3*w2*x0 - tmp11*w5 - tmp7*x3) 392 | 393 | w_grad_0 = tl.sum(go0*x0*y0, axis=0) 394 | w_grad_1 = tl.sum(go1*x0*y1 + go2*x0*y2, axis=0) 395 | w_grad_2 = tl.sum(go3*x0*y3, axis=0) 396 | w_grad_3 = tl.sum(go0*(x1*y1 + x2*y2), axis=0) 397 | w_grad_4 = tl.sum(tmp10*y0 + tmp9*y0, axis=0) 398 | w_grad_5 = tl.sum(go2*x1*y3 - tmp11*y3, axis=0) 399 | w_grad_6 = tl.sum(go3*(x1*y2 - x2*y1), axis=0) 400 | w_grad_7 = tl.sum(-go0*x3*y3, axis=0) 401 | w_grad_8 = tl.sum(go1*x3*y2 - go2*x3*y1, axis=0) 402 | w_grad_9 = tl.sum(go3*x3*y0, axis=0) 403 | 404 | # GELU gate gradients 405 | dgate_x = compute_gelu_gate_grad(x0_raw) 406 | dgate_y = compute_gelu_gate_grad(y0_raw) 407 | 408 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3) 409 | x_grad_1 = gate_x * x_grad_1 410 | x_grad_2 = gate_x * x_grad_2 411 | x_grad_3 = gate_x * x_grad_3 412 | 413 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3) 414 | y_grad_1 = gate_y * y_grad_1 415 | y_grad_2 = gate_y * y_grad_2 416 | y_grad_3 = gate_y * y_grad_3 417 | 418 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask) 419 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask) 420 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask) 421 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask) 422 | 423 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask) 424 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask) 425 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask) 426 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask) 427 | 428 | tl.store(grad_weight_ptr + block_offset + weight_offset + 0, w_grad_0, mask=feature_mask) 
429 | tl.store(grad_weight_ptr + block_offset + weight_offset + 1, w_grad_1, mask=feature_mask) 430 | tl.store(grad_weight_ptr + block_offset + weight_offset + 2, w_grad_2, mask=feature_mask) 431 | tl.store(grad_weight_ptr + block_offset + weight_offset + 3, w_grad_3, mask=feature_mask) 432 | tl.store(grad_weight_ptr + block_offset + weight_offset + 4, w_grad_4, mask=feature_mask) 433 | tl.store(grad_weight_ptr + block_offset + weight_offset + 5, w_grad_5, mask=feature_mask) 434 | tl.store(grad_weight_ptr + block_offset + weight_offset + 6, w_grad_6, mask=feature_mask) 435 | tl.store(grad_weight_ptr + block_offset + weight_offset + 7, w_grad_7, mask=feature_mask) 436 | tl.store(grad_weight_ptr + block_offset + weight_offset + 8, w_grad_8, mask=feature_mask) 437 | tl.store(grad_weight_ptr + block_offset + weight_offset + 9, w_grad_9, mask=feature_mask) 438 | 439 | 440 | def gelu_geometric_product_norm_bwd( 441 | x: torch.Tensor, 442 | y: torch.Tensor, 443 | weight: torch.Tensor, 444 | o: torch.Tensor, 445 | partial_norm: torch.Tensor, 446 | grad_output: torch.Tensor, 447 | normalize: bool, 448 | ) -> torch.Tensor: 449 | """Backward pass for the fused operation.""" 450 | _, B, N = x.shape 451 | 452 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 453 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 454 | 455 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 456 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 457 | 458 | grad_x = torch.zeros_like(x) 459 | grad_y = torch.zeros_like(y) 460 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0)) 461 | grad_weight = torch.zeros((num_blocks_batch, N, NUM_PRODUCT_WEIGHTS), device=x.device, dtype=weight.dtype) 462 | 463 | grid = (num_blocks_batch, num_blocks_features) 464 | 465 | if normalize: 466 | grad_o_dot_o_kernel[grid]( 467 | dot, 468 | partial_norm, 469 | o, 470 | grad_output, 471 | B, 472 | N, 473 | BATCH_BLOCK, 474 | FEATURE_BLOCK, 475 | EPS, 476 | num_warps=DEFAULT_NUM_WARPS, 477 | num_stages=DEFAULT_NUM_STAGES, 478 | ) 479 | 480 | gelu_wgp_norm_kernel_bwd[grid]( 481 | x, 482 | y, 483 | o, 484 | weight, 485 | dot, 486 | partial_norm, 487 | grad_output, 488 | grad_x, 489 | grad_y, 490 | grad_weight, 491 | normalize, 492 | B, 493 | N, 494 | BATCH_BLOCK, 495 | FEATURE_BLOCK, 496 | NUM_PRODUCT_WEIGHTS, 497 | EPS, 498 | num_warps=DEFAULT_NUM_WARPS, 499 | num_stages=DEFAULT_NUM_STAGES, 500 | ) 501 | 502 | grad_weight = torch.sum(grad_weight, dim=0) 503 | 504 | return grad_x, grad_y, grad_weight 505 | 506 | 507 | class WeightedGeluGeometricProductNorm2D(torch.autograd.Function): 508 | 509 | @staticmethod 510 | @torch.amp.custom_fwd(device_type="cuda") 511 | def forward(ctx, x, y, weight, normalize): 512 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous() 513 | 514 | ctx.dtype = x.dtype 515 | ctx.normalize = normalize 516 | 517 | o, partial_norm = gelu_geometric_product_norm_fwd( 518 | x, 519 | y, 520 | weight, 521 | normalize, 522 | ) 523 | 524 | ctx.save_for_backward(x, y, weight, o, partial_norm) 525 | 526 | return o.to(x.dtype) 527 | 528 | @staticmethod 529 | @torch.amp.custom_bwd(device_type="cuda") 530 | def backward(ctx, grad_output): 531 | grad_output = grad_output.contiguous() 532 | 533 | x, y, weight, o, partial_norm = ctx.saved_tensors 534 | 535 | grad_x, grad_y, grad_weight = gelu_geometric_product_norm_bwd( 536 | x, 537 | y, 538 | weight, 539 | o, 540 | partial_norm, 541 | grad_output, 542 | ctx.normalize, 543 | ) 544 | 545 | return grad_x, grad_y, grad_weight, 
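    Example (illustrative sketch, not from the original documentation; assumes a
    CUDA device, since the op is implemented with Triton kernels):

        >>> B, N = 1024, 128
        >>> x = torch.randn(4, B, N, device="cuda")
        >>> y = torch.randn(4, B, N, device="cuda")
        >>> weight = torch.randn(N, 10, device="cuda")
        >>> out = fused_gelu_sgp_norm_2d(x, y, weight)  # shape (4, B, N)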
None, None, None, None 546 | 547 | 548 | def fused_gelu_sgp_norm_2d(x, y, weight, normalize=True): 549 | """ 550 | Fused operation that applies GELU non-linearity to two multivector inputs, 551 | then computes their weighted geometric product, and applies RMSNorm. 552 | 553 | Clifford algebra is assumed to be Cl(2,0). 554 | 555 | Args: 556 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 557 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 558 | weight (torch.Tensor): Weight tensor of shape (N, NUM_PRODUCT_WEIGHTS), one weight per geometric product component. 559 | normalize (bool): Whether to apply RMSNorm after the geometric product. 560 | 561 | Returns: 562 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation. 563 | """ 564 | return WeightedGeluGeometricProductNorm2D.apply(x, y, weight, normalize) -------------------------------------------------------------------------------- /ops/fc_p2m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | MV_DIM = 4 6 | NUM_GRADES = 3 7 | NUM_PRODUCT_WEIGHTS = 10 8 | WEIGHT_EXPANSION = [0, 3, 7, 1, 4, 5, 8, 1, 5, 4, 8, 2, 6, 9] 9 | EPS = 1e-6 10 | 11 | # tuned at RTX 4500 12 | DEFAULT_BATCH_BLOCK = 4 13 | DEFAULT_FEATURE_BLOCK = 128 14 | DEFAULT_NUM_WARPS = 16 15 | DEFAULT_NUM_STAGES = 1 16 | 17 | 18 | @triton.jit 19 | def compute_gelu_gate(x): 20 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))""" 21 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype) 22 | 23 | 24 | @triton.jit 25 | def compute_gelu_gate_grad(x): 26 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)""" 27 | return 0.3989422804 * tl.exp(-0.5 * x * x) 28 | 29 | 30 | @triton.jit 31 | def gelu_pairwise_kernel_fwd( 32 | x_ptr, 33 | y_ptr, 34 | pairwise_ptr, 35 | batch_size: tl.constexpr, 36 | n_features: tl.constexpr, 37 | BATCH_BLOCK: tl.constexpr, 38 | FEATURE_BLOCK: tl.constexpr, 39 | ): 40 | """ 41 | Apply GELU non-linearity to inputs and compute pairwise products for geometric product. 
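    The 14 products p0..p13 are laid out to match WEIGHT_EXPANSION, so a single
    batched matmul with weight[expansion_indices] produces the transformed terms
    that assemble_kernel later sums into the output components.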
42 | """ 43 | batch_block_id = tl.program_id(axis=0) 44 | thread_block_id = tl.program_id(axis=1) 45 | 46 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 47 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 48 | 49 | batch_mask = batch_ids < batch_size 50 | feature_mask = feature_ids < n_features 51 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 52 | 53 | stride_component = batch_size * n_features 54 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 55 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 56 | 57 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 58 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 59 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 60 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 61 | 62 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 63 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 64 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 65 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 66 | 67 | gate_x = compute_gelu_gate(x0) 68 | gate_y = compute_gelu_gate(y0) 69 | 70 | x0 = x0 * gate_x 71 | x1 = x1 * gate_x 72 | x2 = x2 * gate_x 73 | x3 = x3 * gate_x 74 | 75 | y0 = y0 * gate_y 76 | y1 = y1 * gate_y 77 | y2 = y2 * gate_y 78 | y3 = y3 * gate_y 79 | 80 | p0 = x0 * y0 81 | p1 = x1*y1 + x2*y2 82 | p2 = -x3 * y3 83 | p3 = x0 * y1 84 | p4 = x1 * y0 85 | p5 = x2 * y3 86 | p6 = x3 * y2 87 | p7 = x0 * y2 88 | p8 = x1 * y3 89 | p9 = x2 * y0 90 | p10 = -x3 * y1 91 | p11 = x0 * y3 92 | p12 = x1*y2 - x2*y1 93 | p13 = x3 * y0 94 | 95 | tl.store(pairwise_ptr + 0*batch_size*n_features + pairwise_offset, p0, mask=batch_feature_mask) 96 | tl.store(pairwise_ptr + 1*batch_size*n_features + pairwise_offset, p1, mask=batch_feature_mask) 97 | tl.store(pairwise_ptr + 2*batch_size*n_features + pairwise_offset, p2, mask=batch_feature_mask) 98 | tl.store(pairwise_ptr + 3*batch_size*n_features + pairwise_offset, p3, mask=batch_feature_mask) 99 | tl.store(pairwise_ptr + 4*batch_size*n_features + pairwise_offset, p4, mask=batch_feature_mask) 100 | tl.store(pairwise_ptr + 5*batch_size*n_features + pairwise_offset, p5, mask=batch_feature_mask) 101 | tl.store(pairwise_ptr + 6*batch_size*n_features + pairwise_offset, p6, mask=batch_feature_mask) 102 | tl.store(pairwise_ptr + 7*batch_size*n_features + pairwise_offset, p7, mask=batch_feature_mask) 103 | tl.store(pairwise_ptr + 8*batch_size*n_features + pairwise_offset, p8, mask=batch_feature_mask) 104 | tl.store(pairwise_ptr + 9*batch_size*n_features + pairwise_offset, p9, mask=batch_feature_mask) 105 | tl.store(pairwise_ptr + 10*batch_size*n_features + pairwise_offset, p10, mask=batch_feature_mask) 106 | tl.store(pairwise_ptr + 11*batch_size*n_features + pairwise_offset, p11, mask=batch_feature_mask) 107 | tl.store(pairwise_ptr + 12*batch_size*n_features + pairwise_offset, p12, mask=batch_feature_mask) 108 | tl.store(pairwise_ptr + 13*batch_size*n_features + pairwise_offset, p13, mask=batch_feature_mask) 109 | 110 | 111 | @triton.jit 112 | def assemble_kernel( 113 | transformed_ptr, 114 | pnorm_ptr, 115 | output_ptr, 116 | NORMALIZE: tl.constexpr, 117 | batch_size: tl.constexpr, 118 | n_features: tl.constexpr, 119 | BATCH_BLOCK: tl.constexpr, 120 | 
FEATURE_BLOCK: tl.constexpr, 121 | ): 122 | """Gather linearly transformed pairwise products and compute the geometric product.""" 123 | batch_block_id = tl.program_id(axis=0) 124 | thread_block_id = tl.program_id(axis=1) 125 | 126 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 127 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 128 | 129 | batch_mask = batch_ids < batch_size 130 | feature_mask = feature_ids < n_features 131 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 132 | 133 | stride_component = batch_size * n_features 134 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 135 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 136 | 137 | t0 = tl.load(transformed_ptr + 0*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 138 | t1 = tl.load(transformed_ptr + 1*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 139 | t2 = tl.load(transformed_ptr + 2*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 140 | t3 = tl.load(transformed_ptr + 3*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 141 | t4 = tl.load(transformed_ptr + 4*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 142 | t5 = tl.load(transformed_ptr + 5*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 143 | t6 = tl.load(transformed_ptr + 6*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 144 | t7 = tl.load(transformed_ptr + 7*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 145 | t8 = tl.load(transformed_ptr + 8*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 146 | t9 = tl.load(transformed_ptr + 9*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 147 | t10 = tl.load(transformed_ptr + 10*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 148 | t11 = tl.load(transformed_ptr + 11*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 149 | t12 = tl.load(transformed_ptr + 12*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 150 | t13 = tl.load(transformed_ptr + 13*batch_size*n_features + transformed_offset, mask=batch_feature_mask) 151 | 152 | o0 = t0 + t1 + t2 153 | o1 = t3 + t4 - t5 + t6 154 | o2 = t7 + t8 + t9 + t10 155 | o3 = t11 + t12 + t13 156 | 157 | if NORMALIZE: 158 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features 159 | pn_vector = tl.sum(o1*o1 + o2*o2, axis=1) / n_features 160 | pn_pseudo = tl.sum(o3 * o3, axis=1) / n_features 161 | 162 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask) 163 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask) 164 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask) 165 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_pseudo, mask=batch_mask) 166 | 167 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask) 168 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask) 169 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask) 170 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask) 171 | 172 | 173 | @triton.jit 174 | def normalize_with_sqrt_kernel( 175 | output_ptr, 176 | pnorm_ptr, 177 | batch_size: tl.constexpr, 178 | n_features: tl.constexpr, 179 | BATCH_BLOCK: tl.constexpr, 180 | FEATURE_BLOCK: 
tl.constexpr, 181 | MV_DIM: tl.constexpr, 182 | EPS: tl.constexpr, 183 | ): 184 | """Normalize the output by dividing each grade with root of its accumulated norm.""" 185 | batch_block_id = tl.program_id(axis=0) 186 | thread_block_id = tl.program_id(axis=1) 187 | 188 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 189 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 190 | 191 | batch_mask = batch_ids < batch_size 192 | feature_mask = feature_ids < n_features 193 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None] 194 | 195 | component_ids = tl.arange(0, MV_DIM)[None, None, :] 196 | 197 | feature_offset = (component_ids * batch_size * n_features + 198 | batch_ids[:, None, None] * n_features + 199 | feature_ids[None, :, None]) 200 | 201 | norm_indices = component_ids * batch_size + batch_ids[:, None, None] 202 | 203 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None]) 204 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask) 205 | 206 | norm = tl.sqrt(pnorm + EPS) 207 | mv_normalized = mv / norm 208 | 209 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask) 210 | 211 | 212 | def gelu_fc_geometric_product_norm_fwd( 213 | x: torch.Tensor, 214 | y: torch.Tensor, 215 | weight: torch.Tensor, 216 | expansion_indices: torch.Tensor, 217 | normalize: bool, 218 | ) -> torch.Tensor: 219 | """Fused operation: GELU non-linearity, fully connected geometric product, and grade-wise RMSNorm.""" 220 | assert x.shape == y.shape 221 | assert x.shape[0] == MV_DIM 222 | assert x.shape[2] == weight.shape[1] == weight.shape[2] 223 | assert weight.shape[0] == NUM_PRODUCT_WEIGHTS 224 | 225 | _, B, N = x.shape 226 | 227 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 228 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 229 | 230 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 231 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 232 | 233 | pairwise = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype) 234 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize else torch.zeros((1,), device=x.device, dtype=x.dtype)) 235 | output = torch.empty_like(x) 236 | 237 | grid = (num_blocks_batch, num_blocks_features) 238 | 239 | gelu_pairwise_kernel_fwd[grid]( 240 | x, 241 | y, 242 | pairwise, 243 | B, 244 | N, 245 | BATCH_BLOCK, 246 | FEATURE_BLOCK, 247 | num_warps=DEFAULT_NUM_WARPS, 248 | num_stages=DEFAULT_NUM_STAGES, 249 | ) 250 | 251 | transformed = torch.bmm(pairwise, weight[expansion_indices]) 252 | 253 | assemble_kernel[grid]( 254 | transformed, 255 | partial_norm, 256 | output, 257 | normalize, 258 | B, 259 | N, 260 | BATCH_BLOCK, 261 | FEATURE_BLOCK, 262 | num_warps=DEFAULT_NUM_WARPS, 263 | num_stages=DEFAULT_NUM_STAGES, 264 | ) 265 | 266 | if normalize: 267 | normalize_with_sqrt_kernel[grid]( 268 | output, 269 | partial_norm, 270 | B, 271 | N, 272 | BATCH_BLOCK, 273 | FEATURE_BLOCK, 274 | MV_DIM, 275 | EPS, 276 | num_warps=DEFAULT_NUM_WARPS, 277 | num_stages=DEFAULT_NUM_STAGES, 278 | ) 279 | 280 | return output, pairwise, partial_norm 281 | 282 | 283 | @triton.jit 284 | def grad_o_dot_o_kernel( 285 | dot_ptr, 286 | pnorm_ptr, 287 | output_ptr, 288 | grad_output_ptr, 289 | batch_size: tl.constexpr, 290 | n_features: tl.constexpr, 291 | BATCH_BLOCK: tl.constexpr, 292 | FEATURE_BLOCK: tl.constexpr, 293 | EPS: tl.constexpr, 294 | ): 295 | """Compute the dot product of grad_output and output for each grade, accumulate across all 
features.""" 296 | batch_block_id = tl.program_id(axis=0) 297 | thread_block_id = tl.program_id(axis=1) 298 | 299 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 300 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 301 | 302 | batch_mask = batch_ids < batch_size 303 | feature_mask = feature_ids < n_features 304 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 305 | 306 | stride_component = batch_size * n_features 307 | offset = batch_ids[:, None] * n_features + feature_ids[None, :] 308 | 309 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 310 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 311 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 312 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 313 | 314 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 315 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 316 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None] 317 | 318 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 319 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 320 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 321 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 322 | 323 | rms_scalar = tl.sqrt(pn_scalar + EPS) 324 | rms_vector = tl.sqrt(pn_vector + EPS) 325 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 326 | 327 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1) 328 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2), axis=1) 329 | dot_pseudo = tl.sum(rms_pseudo * go3 * o3, axis=1) 330 | 331 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask) 332 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask) 333 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_pseudo, mask=batch_mask) 334 | 335 | 336 | @triton.jit 337 | def disassemble_kernel( 338 | grad_output_ptr, 339 | output_ptr, 340 | dot_ptr, 341 | grad_transformed_ptr, 342 | pnorm_ptr, 343 | NORMALIZE: tl.constexpr, 344 | batch_size: tl.constexpr, 345 | n_features: tl.constexpr, 346 | BATCH_BLOCK: tl.constexpr, 347 | FEATURE_BLOCK: tl.constexpr, 348 | EPS: tl.constexpr, 349 | ): 350 | """ 351 | Gather linearly transformed pairwise products and compute the geometric product. 
352 | """ 353 | batch_block_id = tl.program_id(axis=0) 354 | thread_block_id = tl.program_id(axis=1) 355 | 356 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 357 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 358 | 359 | batch_mask = batch_ids < batch_size 360 | feature_mask = feature_ids < n_features 361 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 362 | 363 | stride_component = batch_size * n_features 364 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 365 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 366 | 367 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 368 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 369 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 370 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 371 | 372 | if NORMALIZE: 373 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 374 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 375 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 376 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 377 | 378 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 379 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 380 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None] 381 | 382 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 383 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 384 | dot_pseudo = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None] 385 | 386 | rms_scalar = tl.sqrt(pn_scalar + EPS) 387 | rms_vector = tl.sqrt(pn_vector + EPS) 388 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 389 | 390 | go0 = go0/rms_scalar - o0 * dot_scalar / (n_features*rms_scalar*rms_scalar) 391 | go1 = go1/rms_vector - o1 * dot_vector / (n_features*rms_vector*rms_vector) 392 | go2 = go2/rms_vector - o2 * dot_vector / (n_features*rms_vector*rms_vector) 393 | go3 = go3/rms_pseudo - o3 * dot_pseudo / (n_features*rms_pseudo*rms_pseudo) 394 | 395 | tl.store(grad_transformed_ptr + 0*batch_size*n_features + transformed_offset, go0, mask=batch_feature_mask) 396 | tl.store(grad_transformed_ptr + 1*batch_size*n_features + transformed_offset, go0, mask=batch_feature_mask) 397 | tl.store(grad_transformed_ptr + 2*batch_size*n_features + transformed_offset, go0, mask=batch_feature_mask) 398 | tl.store(grad_transformed_ptr + 3*batch_size*n_features + transformed_offset, go1, mask=batch_feature_mask) 399 | tl.store(grad_transformed_ptr + 4*batch_size*n_features + transformed_offset, go1, mask=batch_feature_mask) 400 | tl.store(grad_transformed_ptr + 5*batch_size*n_features + transformed_offset, -go1, mask=batch_feature_mask) 401 | tl.store(grad_transformed_ptr + 6*batch_size*n_features + transformed_offset, go1, mask=batch_feature_mask) 402 | tl.store(grad_transformed_ptr + 7*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask) 403 | tl.store(grad_transformed_ptr + 8*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask) 404 | tl.store(grad_transformed_ptr + 9*batch_size*n_features + 
transformed_offset, go2, mask=batch_feature_mask) 405 | tl.store(grad_transformed_ptr + 10*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask) 406 | tl.store(grad_transformed_ptr + 11*batch_size*n_features + transformed_offset, go3, mask=batch_feature_mask) 407 | tl.store(grad_transformed_ptr + 12*batch_size*n_features + transformed_offset, go3, mask=batch_feature_mask) 408 | tl.store(grad_transformed_ptr + 13*batch_size*n_features + transformed_offset, go3, mask=batch_feature_mask) 409 | 410 | 411 | @triton.jit 412 | def gelu_pairwise_kernel_bwd( 413 | x_ptr, 414 | y_ptr, 415 | grad_pairwise_ptr, 416 | grad_x_ptr, 417 | grad_y_ptr, 418 | batch_size: tl.constexpr, 419 | n_features: tl.constexpr, 420 | BATCH_BLOCK: tl.constexpr, 421 | FEATURE_BLOCK: tl.constexpr, 422 | ): 423 | """ 424 | Apply GELU non-linearity to inputs and compute required pairwise products. 425 | """ 426 | batch_block_id = tl.program_id(axis=0) 427 | thread_block_id = tl.program_id(axis=1) 428 | 429 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 430 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 431 | 432 | batch_mask = batch_ids < batch_size 433 | feature_mask = feature_ids < n_features 434 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 435 | 436 | stride_component = batch_size * n_features 437 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 438 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 439 | 440 | gp0 = tl.load(grad_pairwise_ptr + 0*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 441 | gp1 = tl.load(grad_pairwise_ptr + 1*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 442 | gp2 = tl.load(grad_pairwise_ptr + 2*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 443 | gp3 = tl.load(grad_pairwise_ptr + 3*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 444 | gp4 = tl.load(grad_pairwise_ptr + 4*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 445 | gp5 = tl.load(grad_pairwise_ptr + 5*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 446 | gp6 = tl.load(grad_pairwise_ptr + 6*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 447 | gp7 = tl.load(grad_pairwise_ptr + 7*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 448 | gp8 = tl.load(grad_pairwise_ptr + 8*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 449 | gp9 = tl.load(grad_pairwise_ptr + 9*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 450 | gp10 = tl.load(grad_pairwise_ptr + 10*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 451 | gp11 = tl.load(grad_pairwise_ptr + 11*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 452 | gp12 = tl.load(grad_pairwise_ptr + 12*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 453 | gp13 = tl.load(grad_pairwise_ptr + 13*batch_size*n_features + pairwise_offset, mask=batch_feature_mask) 454 | 455 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 456 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 457 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 458 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 459 | 460 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 461 | y1_raw = 
tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 462 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 463 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 464 | 465 | # collect gradients from pairwise products 466 | gate_x = compute_gelu_gate(x0_raw) 467 | gate_y = compute_gelu_gate(y0_raw) 468 | 469 | x0 = x0_raw * gate_x 470 | x1 = x1_raw * gate_x 471 | x2 = x2_raw * gate_x 472 | x3 = x3_raw * gate_x 473 | 474 | y0 = y0_raw * gate_y 475 | y1 = y1_raw * gate_y 476 | y2 = y2_raw * gate_y 477 | y3 = y3_raw * gate_y 478 | 479 | x_grad_0 = gp0*y0 + gp3*y1 + gp7*y2 + gp11*y3 480 | x_grad_1 = gp1*y1 + gp4*y0 + gp8*y3 + gp12*y2 481 | x_grad_2 = gp1*y2 + gp5*y3 + gp9*y0 - gp12*y1 482 | x_grad_3 = -gp2 * y3 + gp6*y2 - gp10*y1 + gp13*y0 483 | 484 | y_grad_0 = gp0*x0 + gp4*x1 + gp9*x2 + gp13*x3 485 | y_grad_1 = gp1*x1 + gp3*x0 - gp10*x3 - gp12*x2 486 | y_grad_2 = gp1*x2 + gp6*x3 + gp7*x0 + gp12*x1 487 | y_grad_3 = -gp2 * x3 + gp5*x2 + gp8*x1 + gp11*x0 488 | 489 | # GELU gate gradients 490 | dgate_x = compute_gelu_gate_grad(x0_raw) 491 | dgate_y = compute_gelu_gate_grad(y0_raw) 492 | 493 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3) 494 | x_grad_1 = gate_x * x_grad_1 495 | x_grad_2 = gate_x * x_grad_2 496 | x_grad_3 = gate_x * x_grad_3 497 | 498 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3) 499 | y_grad_1 = gate_y * y_grad_1 500 | y_grad_2 = gate_y * y_grad_2 501 | y_grad_3 = gate_y * y_grad_3 502 | 503 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask) 504 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask) 505 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask) 506 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask) 507 | 508 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask) 509 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask) 510 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask) 511 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask) 512 | 513 | 514 | def gelu_fc_geometric_product_norm_bwd( 515 | x: torch.Tensor, 516 | y: torch.Tensor, 517 | weight: torch.Tensor, 518 | o: torch.Tensor, 519 | pairwise: torch.Tensor, 520 | partial_norm: torch.Tensor, 521 | grad_output: torch.Tensor, 522 | expansion_indices: torch.Tensor, 523 | normalize: bool, 524 | ) -> torch.Tensor: 525 | """Backward pass for the fused operation.""" 526 | _, B, N = x.shape 527 | 528 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 529 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 530 | 531 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 532 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 533 | 534 | grad_x = torch.zeros_like(x) 535 | grad_y = torch.zeros_like(y) 536 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0)) 537 | grad_weight = torch.zeros((NUM_PRODUCT_WEIGHTS, N, N), device=x.device, dtype=weight.dtype) 538 | grad_transformed = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype) 539 | 540 | grid = (num_blocks_batch, num_blocks_features) 541 | 542 | if 
normalize: 543 | grad_o_dot_o_kernel[grid]( 544 | dot, 545 | partial_norm, 546 | o, 547 | grad_output, 548 | B, 549 | N, 550 | BATCH_BLOCK, 551 | FEATURE_BLOCK, 552 | EPS, 553 | num_warps=DEFAULT_NUM_WARPS, 554 | num_stages=DEFAULT_NUM_STAGES, 555 | ) 556 | 557 | disassemble_kernel[grid]( 558 | grad_output, 559 | o, 560 | dot, 561 | grad_transformed, 562 | partial_norm, 563 | normalize, 564 | B, 565 | N, 566 | BATCH_BLOCK, 567 | FEATURE_BLOCK, 568 | EPS, 569 | num_warps=DEFAULT_NUM_WARPS, 570 | num_stages=DEFAULT_NUM_STAGES, 571 | ) 572 | 573 | grad_pairwise = torch.bmm(grad_transformed, weight[expansion_indices].transpose(-2, -1)) 574 | 575 | grad_weight.index_add_(0, expansion_indices, torch.bmm(pairwise.transpose(-2, -1), grad_transformed)) 576 | 577 | gelu_pairwise_kernel_bwd[grid]( 578 | x, 579 | y, 580 | grad_pairwise, 581 | grad_x, 582 | grad_y, 583 | B, 584 | N, 585 | BATCH_BLOCK, 586 | FEATURE_BLOCK, 587 | num_warps=DEFAULT_NUM_WARPS, 588 | num_stages=DEFAULT_NUM_STAGES, 589 | ) 590 | 591 | return grad_x, grad_y, grad_weight 592 | 593 | 594 | class FullyConnectedGeluGeometricProductNorm2D(torch.autograd.Function): 595 | 596 | @staticmethod 597 | @torch.amp.custom_fwd(device_type="cuda") 598 | def forward(ctx, x, y, weight, normalize): 599 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous() 600 | 601 | ctx.dtype = x.dtype 602 | ctx.normalize = normalize 603 | 604 | expansion_indices = torch.tensor(WEIGHT_EXPANSION, device=x.device) 605 | 606 | o, pairwise, partial_norm = gelu_fc_geometric_product_norm_fwd( 607 | x, 608 | y, 609 | weight, 610 | expansion_indices, 611 | normalize, 612 | ) 613 | 614 | ctx.save_for_backward(x, y, weight, o, pairwise, partial_norm, expansion_indices) 615 | 616 | return o.to(x.dtype) 617 | 618 | @staticmethod 619 | @torch.amp.custom_bwd(device_type="cuda") 620 | def backward(ctx, grad_output): 621 | grad_output = grad_output.contiguous() 622 | 623 | x, y, weight, o, pairwise, partial_norm, expansion_indices = ctx.saved_tensors 624 | 625 | grad_x, grad_y, grad_weight = gelu_fc_geometric_product_norm_bwd( 626 | x, 627 | y, 628 | weight, 629 | o, 630 | pairwise, 631 | partial_norm, 632 | grad_output, 633 | expansion_indices, 634 | ctx.normalize, 635 | ) 636 | 637 | return grad_x, grad_y, grad_weight, None, None, None, None 638 | 639 | 640 | def fused_gelu_fcgp_norm_2d(x, y, weight, normalize=True): 641 | """ 642 | Fused operation that applies GELU non-linearity to two multivector inputs, 643 | then computes their fully connected geometric product, and applies RMSNorm. 644 | 645 | Clifford algebra is assumed to be Cl(2,0). 646 | 647 | Args: 648 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 649 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 650 | weight (torch.Tensor): Weight tensor of shape (NUM_PRODUCT_WEIGHTS, N, N), one weight per geometric product component. 651 | normalize (bool): Whether to apply RMSNorm after the geometric product. 652 | 653 | Returns: 654 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation. 
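    Example (illustrative sketch, not from the original documentation; assumes a
    CUDA device, since the op is implemented with Triton kernels):

        >>> B, N = 1024, 128
        >>> x = torch.randn(4, B, N, device="cuda")
        >>> y = torch.randn(4, B, N, device="cuda")
        >>> weight = torch.randn(10, N, N, device="cuda")
        >>> out = fused_gelu_fcgp_norm_2d(x, y, weight)  # shape (4, B, N)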
655 | """ 656 | return FullyConnectedGeluGeometricProductNorm2D.apply(x, y, weight, normalize) 657 | -------------------------------------------------------------------------------- /ops/p3m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | MV_DIM = 8 6 | NUM_GRADES = 4 7 | NUM_PRODUCT_WEIGHTS = 20 8 | EPS = 1e-6 9 | 10 | # tuned at RTX 4500 11 | DEFAULT_BATCH_BLOCK = 4 12 | DEFAULT_FEATURE_BLOCK = 128 13 | DEFAULT_NUM_WARPS = 16 14 | DEFAULT_NUM_STAGES = 1 15 | 16 | 17 | @triton.jit 18 | def compute_gelu_gate(x): 19 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))""" 20 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype) 21 | 22 | 23 | @triton.jit 24 | def compute_gelu_gate_grad(x): 25 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)""" 26 | return 0.3989422804 * tl.exp(-0.5 * x * x) 27 | 28 | 29 | @triton.jit 30 | def gelu_wgp_norm_kernel_fwd( 31 | x_ptr, 32 | y_ptr, 33 | output_ptr, 34 | weights_ptr, 35 | pnorm_ptr, 36 | NORMALIZE: tl.constexpr, 37 | batch_size: tl.constexpr, 38 | n_features: tl.constexpr, 39 | BATCH_BLOCK: tl.constexpr, 40 | FEATURE_BLOCK: tl.constexpr, 41 | NUM_PRODUCT_WEIGHTS: tl.constexpr, 42 | ): 43 | """ 44 | Apply GELU non-linearity to inputs, compute weighted geometric product, 45 | and accumulate squared norms for grade-wise RMSNorm. 46 | """ 47 | batch_block_id = tl.program_id(axis=0) 48 | thread_block_id = tl.program_id(axis=1) 49 | 50 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 51 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 52 | 53 | batch_mask = batch_ids < batch_size 54 | feature_mask = feature_ids < n_features 55 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 56 | 57 | stride_component = batch_size * n_features 58 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 59 | 60 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS 61 | 62 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask) 63 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask) 64 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask) 65 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask) 66 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask) 67 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask) 68 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask) 69 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask) 70 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask) 71 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask) 72 | w10 = tl.load(weights_ptr + weight_offset + 10, mask=feature_mask) 73 | w11 = tl.load(weights_ptr + weight_offset + 11, mask=feature_mask) 74 | w12 = tl.load(weights_ptr + weight_offset + 12, mask=feature_mask) 75 | w13 = tl.load(weights_ptr + weight_offset + 13, mask=feature_mask) 76 | w14 = tl.load(weights_ptr + weight_offset + 14, mask=feature_mask) 77 | w15 = tl.load(weights_ptr + weight_offset + 15, mask=feature_mask) 78 | w16 = tl.load(weights_ptr + weight_offset + 16, mask=feature_mask) 79 | w17 = tl.load(weights_ptr + weight_offset + 17, mask=feature_mask) 80 | w18 = tl.load(weights_ptr + weight_offset + 18, mask=feature_mask) 81 | w19 = tl.load(weights_ptr + weight_offset + 19, mask=feature_mask) 82 | 83 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, 
mask=batch_feature_mask) 84 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 85 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 86 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 87 | x4 = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 88 | x5 = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 89 | x6 = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 90 | x7 = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 91 | 92 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 93 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 94 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 95 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 96 | y4 = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 97 | y5 = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 98 | y6 = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 99 | y7 = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 100 | 101 | # Apply GELU gate 102 | gate_x = compute_gelu_gate(x0) 103 | gate_y = compute_gelu_gate(y0) 104 | 105 | x0 = x0 * gate_x 106 | x1 = x1 * gate_x 107 | x2 = x2 * gate_x 108 | x3 = x3 * gate_x 109 | x4 = x4 * gate_x 110 | x5 = x5 * gate_x 111 | x6 = x6 * gate_x 112 | x7 = x7 * gate_x 113 | 114 | y0 = y0 * gate_y 115 | y1 = y1 * gate_y 116 | y2 = y2 * gate_y 117 | y3 = y3 * gate_y 118 | y4 = y4 * gate_y 119 | y5 = y5 * gate_y 120 | y6 = y6 * gate_y 121 | y7 = y7 * gate_y 122 | 123 | # Compute geometric product 124 | o0 = (w0*x0*y0 + w4 * (x1*y1 + x2*y2 + x3*y3) - w10 * (x4*y4 + x5*y5 + x6*y6) - w16*x7*y7) 125 | o1 = (w1*x0*y1 + w5*x1*y0 - w6 * (x2*y4 + x3*y5) + w11 * (x4*y2 + x5*y3) - w12*x6*y7 - w17*x7*y6) 126 | o2 = (w1*x0*y2 + w6*x1*y4 + w5*x2*y0 - w6*x3*y6 - w11*x4*y1 + w12*x5*y7 + w11*x6*y3 + w17*x7*y5) 127 | o3 = (w1*x0*y3 + w6 * (x1*y5 + x2*y6) + w5*x3*y0 - w12*x4*y7 - w11 * (x5*y1 + x6*y2) - w17*x7*y4) 128 | o4 = (w2*x0*y4 + w7*x1*y2 - w7*x2*y1 + w8*x3*y7 + w13*x4*y0 - w14*x5*y6 + w14*x6*y5 + w18*x7*y3) 129 | o5 = (w2*x0*y5 + w7*x1*y3 - w8*x2*y7 - w7*x3*y1 + w14*x4*y6 + w13*x5*y0 - w14*x6*y4 - w18*x7*y2) 130 | o6 = (w2*x0*y6 + w8*x1*y7 + w7*x2*y3 - w7*x3*y2 - w14*x4*y5 + w14*x5*y4 + w13*x6*y0 + w18*x7*y1) 131 | o7 = (w3*x0*y7 + w9*x1*y6 - w9*x2*y5 + w9*x3*y4 + w15*x4*y3 - w15*x5*y2 + w15*x6*y1 + w19*x7*y0) 132 | 133 | if NORMALIZE: 134 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features 135 | pn_vector = tl.sum(o1*o1 + o2*o2 + o3*o3, axis=1) / n_features 136 | pn_bivect = tl.sum(o4*o4 + o5*o5 + o6*o6, axis=1) / n_features 137 | pn_pseudo = tl.sum(o7 * o7, axis=1) / n_features 138 | 139 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask) 140 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask) 141 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask) 142 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_vector, mask=batch_mask) 143 | tl.atomic_add(pnorm_ptr + 4*batch_size + batch_ids, pn_bivect, mask=batch_mask) 144 | tl.atomic_add(pnorm_ptr + 5*batch_size + batch_ids, pn_bivect, mask=batch_mask) 145 | tl.atomic_add(pnorm_ptr + 6*batch_size + batch_ids, pn_bivect, 
mask=batch_mask) 146 | tl.atomic_add(pnorm_ptr + 7*batch_size + batch_ids, pn_pseudo, mask=batch_mask) 147 | 148 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask) 149 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask) 150 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask) 151 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask) 152 | tl.store(output_ptr + 4 * stride_component + base_offset, o4, mask=batch_feature_mask) 153 | tl.store(output_ptr + 5 * stride_component + base_offset, o5, mask=batch_feature_mask) 154 | tl.store(output_ptr + 6 * stride_component + base_offset, o6, mask=batch_feature_mask) 155 | tl.store(output_ptr + 7 * stride_component + base_offset, o7, mask=batch_feature_mask) 156 | 157 | 158 | @triton.jit 159 | def normalize_with_sqrt_kernel( 160 | output_ptr, 161 | pnorm_ptr, 162 | batch_size: tl.constexpr, 163 | n_features: tl.constexpr, 164 | BATCH_BLOCK: tl.constexpr, 165 | FEATURE_BLOCK: tl.constexpr, 166 | MV_DIM: tl.constexpr, 167 | EPS: tl.constexpr, 168 | ): 169 | """Normalize the output by dividing each grade with root of its accumulated norm.""" 170 | batch_block_id = tl.program_id(axis=0) 171 | thread_block_id = tl.program_id(axis=1) 172 | 173 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 174 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 175 | 176 | batch_mask = batch_ids < batch_size 177 | feature_mask = feature_ids < n_features 178 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None] 179 | 180 | component_ids = tl.arange(0, MV_DIM)[None, None, :] 181 | 182 | feature_offset = (component_ids * batch_size * n_features + 183 | batch_ids[:, None, None] * n_features + 184 | feature_ids[None, :, None]) 185 | 186 | norm_indices = component_ids * batch_size + batch_ids[:, None, None] 187 | 188 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None]) 189 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask) 190 | 191 | norm = tl.sqrt(pnorm + EPS) 192 | mv_normalized = mv / norm 193 | 194 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask) 195 | 196 | 197 | def gelu_geometric_product_norm_fwd( 198 | x: torch.Tensor, 199 | y: torch.Tensor, 200 | weight: torch.Tensor, 201 | normalize: bool, 202 | ) -> torch.Tensor: 203 | """Fused operation: GELU non-linearity, weighted geometric product, and grade-wise RMSNorm.""" 204 | assert x.shape == y.shape 205 | assert x.shape[0] == MV_DIM 206 | assert x.shape[2] == weight.shape[0] 207 | assert weight.shape[1] == NUM_PRODUCT_WEIGHTS 208 | 209 | _, B, N = x.shape 210 | 211 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 212 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 213 | 214 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 215 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 216 | 217 | output = torch.empty_like(x) 218 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize 219 | else torch.zeros((1,), device=x.device, dtype=x.dtype)) 220 | 221 | grid = (num_blocks_batch, num_blocks_features) 222 | 223 | gelu_wgp_norm_kernel_fwd[grid]( 224 | x, 225 | y, 226 | output, 227 | weight, 228 | partial_norm, 229 | normalize, 230 | B, 231 | N, 232 | BATCH_BLOCK, 233 | FEATURE_BLOCK, 234 | NUM_PRODUCT_WEIGHTS, 235 | num_warps=DEFAULT_NUM_WARPS, 236 | num_stages=DEFAULT_NUM_STAGES, 237 | ) 238 | 239 | if normalize: 240 
| normalize_with_sqrt_kernel[grid]( 241 | output, 242 | partial_norm, 243 | B, 244 | N, 245 | BATCH_BLOCK, 246 | FEATURE_BLOCK, 247 | MV_DIM, 248 | EPS, 249 | num_warps=DEFAULT_NUM_WARPS, 250 | num_stages=DEFAULT_NUM_STAGES, 251 | ) 252 | 253 | return output, partial_norm 254 | 255 | 256 | @triton.jit 257 | def grad_o_dot_o_kernel( 258 | dot_ptr, 259 | pnorm_ptr, 260 | output_ptr, 261 | grad_output_ptr, 262 | batch_size: tl.constexpr, 263 | n_features: tl.constexpr, 264 | BATCH_BLOCK: tl.constexpr, 265 | FEATURE_BLOCK: tl.constexpr, 266 | EPS: tl.constexpr, 267 | ): 268 | """Compute the dot product of grad_output and output for each grade, accumulate across all features.""" 269 | batch_block_id = tl.program_id(axis=0) 270 | thread_block_id = tl.program_id(axis=1) 271 | 272 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 273 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 274 | 275 | batch_mask = batch_ids < batch_size 276 | feature_mask = feature_ids < n_features 277 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 278 | 279 | stride_component = batch_size * n_features 280 | offset = batch_ids[:, None] * n_features + feature_ids[None, :] 281 | 282 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 283 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 284 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 285 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 286 | go4 = tl.load(grad_output_ptr + 4 * stride_component + offset, mask=batch_feature_mask) 287 | go5 = tl.load(grad_output_ptr + 5 * stride_component + offset, mask=batch_feature_mask) 288 | go6 = tl.load(grad_output_ptr + 6 * stride_component + offset, mask=batch_feature_mask) 289 | go7 = tl.load(grad_output_ptr + 7 * stride_component + offset, mask=batch_feature_mask) 290 | 291 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 292 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 293 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None] 294 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None] 295 | 296 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 297 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 298 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 299 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 300 | o4 = tl.load(output_ptr + 4 * stride_component + offset, mask=batch_feature_mask) 301 | o5 = tl.load(output_ptr + 5 * stride_component + offset, mask=batch_feature_mask) 302 | o6 = tl.load(output_ptr + 6 * stride_component + offset, mask=batch_feature_mask) 303 | o7 = tl.load(output_ptr + 7 * stride_component + offset, mask=batch_feature_mask) 304 | 305 | rms_scalar = tl.sqrt(pn_scalar + EPS) 306 | rms_vector = tl.sqrt(pn_vector + EPS) 307 | rms_bivect = tl.sqrt(pn_bivect + EPS) 308 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 309 | 310 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1) 311 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2 + go3*o3), axis=1) 312 | dot_bivect = tl.sum(rms_bivect * (go4*o4 + go5*o5 + go6*o6), axis=1) 313 | dot_pseudo = tl.sum(rms_pseudo * go7 * o7, axis=1) 314 | 315 | tl.atomic_add(dot_ptr + 
0*batch_size + batch_ids, dot_scalar, mask=batch_mask) 316 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask) 317 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_bivect, mask=batch_mask) 318 | tl.atomic_add(dot_ptr + 3*batch_size + batch_ids, dot_pseudo, mask=batch_mask) 319 | 320 | 321 | @triton.jit 322 | def gelu_wgp_norm_kernel_bwd( 323 | x_ptr, 324 | y_ptr, 325 | output_ptr, 326 | weights_ptr, 327 | dot_ptr, 328 | pnorm_ptr, 329 | grad_output_ptr, 330 | grad_x_ptr, 331 | grad_y_ptr, 332 | grad_weight_ptr, 333 | NORMALIZE: tl.constexpr, 334 | batch_size: tl.constexpr, 335 | n_features: tl.constexpr, 336 | BATCH_BLOCK: tl.constexpr, 337 | FEATURE_BLOCK: tl.constexpr, 338 | MV_DIM: tl.constexpr, 339 | NUM_GRADES: tl.constexpr, 340 | NUM_PRODUCT_WEIGHTS: tl.constexpr, 341 | EPS: tl.constexpr, 342 | ): 343 | """Compute gradients w.r.t. inputs and weights of the fused operation.""" 344 | batch_block_id = tl.program_id(axis=0) 345 | thread_block_id = tl.program_id(axis=1) 346 | 347 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 348 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 349 | 350 | batch_mask = batch_ids < batch_size 351 | feature_mask = feature_ids < n_features 352 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 353 | 354 | stride_component = batch_size * n_features 355 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 356 | 357 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS 358 | block_offset = batch_block_id * n_features * NUM_PRODUCT_WEIGHTS 359 | 360 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask) 361 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask) 362 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask) 363 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask) 364 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask) 365 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask) 366 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask) 367 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask) 368 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask) 369 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask) 370 | w10 = tl.load(weights_ptr + weight_offset + 10, mask=feature_mask) 371 | w11 = tl.load(weights_ptr + weight_offset + 11, mask=feature_mask) 372 | w12 = tl.load(weights_ptr + weight_offset + 12, mask=feature_mask) 373 | w13 = tl.load(weights_ptr + weight_offset + 13, mask=feature_mask) 374 | w14 = tl.load(weights_ptr + weight_offset + 14, mask=feature_mask) 375 | w15 = tl.load(weights_ptr + weight_offset + 15, mask=feature_mask) 376 | w16 = tl.load(weights_ptr + weight_offset + 16, mask=feature_mask) 377 | w17 = tl.load(weights_ptr + weight_offset + 17, mask=feature_mask) 378 | w18 = tl.load(weights_ptr + weight_offset + 18, mask=feature_mask) 379 | w19 = tl.load(weights_ptr + weight_offset + 19, mask=feature_mask) 380 | 381 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 382 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 383 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 384 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 385 | x4_raw = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 386 | x5_raw = tl.load(x_ptr + 5 
* stride_component + base_offset, mask=batch_feature_mask) 387 | x6_raw = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 388 | x7_raw = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 389 | 390 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 391 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 392 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 393 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 394 | y4_raw = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 395 | y5_raw = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 396 | y6_raw = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 397 | y7_raw = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 398 | 399 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 400 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 401 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 402 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 403 | go4 = tl.load(grad_output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 404 | go5 = tl.load(grad_output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 405 | go6 = tl.load(grad_output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 406 | go7 = tl.load(grad_output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 407 | 408 | if NORMALIZE: 409 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 410 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 411 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 412 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 413 | o4 = tl.load(output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 414 | o5 = tl.load(output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 415 | o6 = tl.load(output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 416 | o7 = tl.load(output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 417 | 418 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 419 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 420 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None] 421 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None] 422 | 423 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 424 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 425 | dot_bivect = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None] 426 | dot_pseudo = tl.load(dot_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None] 427 | 428 | rms_scalar = tl.sqrt(pn_scalar + EPS) 429 | rms_vector = tl.sqrt(pn_vector + EPS) 430 | rms_bivect = tl.sqrt(pn_bivect + EPS) 431 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 432 | 433 | go0 = go0/rms_scalar - o0 * dot_scalar / (n_features*rms_scalar*rms_scalar) 434 
| go1 = go1/rms_vector - o1 * dot_vector / (n_features*rms_vector*rms_vector) 435 | go2 = go2/rms_vector - o2 * dot_vector / (n_features*rms_vector*rms_vector) 436 | go3 = go3/rms_vector - o3 * dot_vector / (n_features*rms_vector*rms_vector) 437 | go4 = go4/rms_bivect - o4 * dot_bivect / (n_features*rms_bivect*rms_bivect) 438 | go5 = go5/rms_bivect - o5 * dot_bivect / (n_features*rms_bivect*rms_bivect) 439 | go6 = go6/rms_bivect - o6 * dot_bivect / (n_features*rms_bivect*rms_bivect) 440 | go7 = go7/rms_pseudo - o7 * dot_pseudo / (n_features*rms_pseudo*rms_pseudo) 441 | 442 | # weighted geometric product backward 443 | gate_x = compute_gelu_gate(x0_raw) 444 | gate_y = compute_gelu_gate(y0_raw) 445 | 446 | x0 = x0_raw * gate_x 447 | x1 = x1_raw * gate_x 448 | x2 = x2_raw * gate_x 449 | x3 = x3_raw * gate_x 450 | x4 = x4_raw * gate_x 451 | x5 = x5_raw * gate_x 452 | x6 = x6_raw * gate_x 453 | x7 = x7_raw * gate_x 454 | 455 | y0 = y0_raw * gate_y 456 | y1 = y1_raw * gate_y 457 | y2 = y2_raw * gate_y 458 | y3 = y3_raw * gate_y 459 | y4 = y4_raw * gate_y 460 | y5 = y5_raw * gate_y 461 | y6 = y6_raw * gate_y 462 | y7 = y7_raw * gate_y 463 | 464 | tmp0 = go0 * w0 465 | tmp1 = go7 * w3 466 | tmp2 = go1 * y1 467 | tmp3 = go2 * y2 468 | tmp4 = go3 * y3 469 | tmp5 = go4 * y4 470 | tmp6 = go5 * y5 471 | tmp7 = go6 * y6 472 | tmp8 = go0 * w4 473 | tmp9 = w5 * y0 474 | tmp10 = w8 * y7 475 | tmp11 = go7 * w9 476 | tmp12 = go1 * w6 477 | tmp13 = w13 * y0 478 | tmp14 = go7 * w15 479 | tmp15 = go0 * w10 480 | tmp16 = w12 * y7 481 | tmp17 = go3 * w11 482 | tmp18 = go7 * w19 483 | tmp19 = go0 * w16 484 | tmp20 = go4 * y3 485 | tmp21 = go6 * y1 486 | tmp22 = go5 * y2 487 | tmp23 = go1 * y6 488 | tmp24 = go3 * y4 489 | tmp25 = go4 * x4 490 | tmp26 = go5 * x5 491 | tmp27 = go6 * x6 492 | tmp28 = go1 * x1 493 | tmp29 = go2 * x2 494 | tmp30 = go3 * x3 495 | tmp31 = w1 * x0 496 | tmp32 = w18 * x7 497 | tmp33 = w2 * x0 498 | tmp34 = w17 * x7 499 | tmp35 = go4 * x3 500 | tmp36 = go6 * x1 501 | tmp37 = go5 * x2 502 | tmp38 = go1 * x6 503 | tmp39 = go3 * x4 504 | 505 | x_grad_0 = (tmp0*y0 + tmp1*y7 + w1 * (tmp2+tmp3+tmp4) + w2 * (tmp5+tmp6+tmp7)) 506 | x_grad_1 = (go1*tmp9 + go6*tmp10 + tmp11*y6 + tmp8*y1 + w6 * (go2*y4 + go3*y5) + w7 * (go4*y2 + go5*y3)) 507 | x_grad_2 = (go2*tmp9 + go3*w6*y6 - go5*tmp10 - tmp11*y5 - tmp12*y4 + tmp8*y2 + w7 * (-go4 * y1 + go6*y3)) 508 | x_grad_3 = (go3*tmp9 + go4*tmp10 + tmp11*y4 + tmp8*y3 - w6 * (go1*y5 + go2*y6) - w7 * (go5*y1 + go6*y2)) 509 | x_grad_4 = (-go3 * tmp16 + go4*tmp13 + tmp14*y3 - tmp15*y4 + w11 * (go1*y2 - go2*y1) + w14 * (go5*y6 - go6*y5)) 510 | x_grad_5 = (go1*w11*y3 + go2*w12*y7 - go4*w14*y6 + go5*w13*y0 + go6*w14*y4 - tmp14*y2 - tmp15*y5 - tmp17*y1) 511 | x_grad_6 = (-go1 * tmp16 + go6*tmp13 + tmp14*y1 - tmp15*y6 + w11 * (go2*y3 - go3*y2) + w14 * (go4*y5 - go5*y4)) 512 | x_grad_7 = (tmp18*y0 - tmp19*y7 + w17 * (go2*y5 - tmp23 - tmp24) + w18 * (tmp20+tmp21-tmp22)) 513 | 514 | y_grad_0 = (tmp0*x0 + tmp18*x7 + w13 * (tmp25+tmp26+tmp27) + w5 * (tmp28+tmp29+tmp30)) 515 | y_grad_1 = (go1*tmp31 + go6*tmp32 + tmp14*x6 + tmp8*x1 - w11 * (go2*x4 + go3*x5) - w7 * (go4*x2 + go5*x3)) 516 | y_grad_2 = (go1*w11*x4 + go2*tmp31 + go4*w7*x1 - go5*tmp32 - go6*w7*x3 - tmp14*x5 - tmp17*x6 + tmp8*x2) 517 | y_grad_3 = (go3*tmp31 + go4*tmp32 + tmp14*x4 + tmp8*x3 + w11 * (go1*x5 + go2*x6) + w7 * (go5*x1 + go6*x2)) 518 | y_grad_4 = (-go3 * tmp34 + go4*tmp33 + tmp11*x3 - tmp15*x4 + w14 * (-go5 * x6 + go6*x5) + w6 * (-go1 * x2 + go2*x1)) 519 | y_grad_5 = (go2*w17*x7 + go3*w6*x1 + go4*w14*x6 + 
go5*w2*x0 - go6*w14*x4 - tmp11*x2 - tmp12*x3 - tmp15*x5) 520 | y_grad_6 = (-go1 * tmp34 + go6*tmp33 + tmp11*x1 - tmp15*x6 + w14 * (-go4 * x5 + go5*x4) + w6 * (-go2 * x3 + go3*x2)) 521 | y_grad_7 = (tmp1*x0 - tmp19*x7 + w12 * (go2*x5 - tmp38 - tmp39) + w8 * (tmp35+tmp36-tmp37)) 522 | 523 | w_grad_0 = tl.sum(go0 * x0 * y0, axis=0) 524 | w_grad_1 = tl.sum(tmp2*x0 + tmp3*x0 + tmp4*x0, axis=0) 525 | w_grad_2 = tl.sum(tmp5*x0 + tmp6*x0 + tmp7*x0, axis=0) 526 | w_grad_3 = tl.sum(go7 * x0 * y7, axis=0) 527 | w_grad_4 = tl.sum(go0 * (x1*y1 + x2*y2 + x3*y3), axis=0) 528 | w_grad_5 = tl.sum(tmp28*y0 + tmp29*y0 + tmp30*y0, axis=0) 529 | w_grad_6 = tl.sum(go1 * (-x2 * y4 - x3*y5) + go2 * (x1*y4 - x3*y6) + go3 * (x1*y5 + x2*y6), axis=0) 530 | w_grad_7 = tl.sum(go4 * (x1*y2 - x2*y1) + go5 * (x1*y3 - x3*y1) + go6 * (x2*y3 - x3*y2), axis=0) 531 | w_grad_8 = tl.sum(tmp35*y7 + tmp36*y7 - tmp37*y7, axis=0) 532 | w_grad_9 = tl.sum(go7 * (x1*y6 - x2*y5 + x3*y4), axis=0) 533 | w_grad_10 = tl.sum(go0 * (-x4 * y4 - x5*y5 - x6*y6), axis=0) 534 | w_grad_11 = tl.sum(go1 * (x4*y2 + x5*y3) + go2 * (-x4 * y1 + x6*y3) + go3 * (-x5 * y1 - x6*y2), axis=0) 535 | w_grad_12 = tl.sum(go2*x5*y7 - tmp38*y7 - tmp39*y7, axis=0) 536 | w_grad_13 = tl.sum(tmp25*y0 + tmp26*y0 + tmp27*y0, axis=0) 537 | w_grad_14 = tl.sum(go4 * (-x5 * y6 + x6*y5) + go5 * (x4*y6 - x6*y4) + go6 * (-x4 * y5 + x5*y4), axis=0) 538 | w_grad_15 = tl.sum(go7 * (x4*y3 - x5*y2 + x6*y1), axis=0) 539 | w_grad_16 = tl.sum(-go0 * x7 * y7, axis=0) 540 | w_grad_17 = tl.sum(go2*x7*y5 - tmp23*x7 - tmp24*x7, axis=0) 541 | w_grad_18 = tl.sum(tmp20*x7 + tmp21*x7 - tmp22*x7, axis=0) 542 | w_grad_19 = tl.sum(go7 * x7 * y0, axis=0) 543 | 544 | # GELU gate gradients 545 | dgate_x = compute_gelu_gate_grad(x0_raw) 546 | dgate_y = compute_gelu_gate_grad(y0_raw) 547 | 548 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3 + 549 | x4_raw*x_grad_4 + x5_raw*x_grad_5 + x6_raw*x_grad_6 + 550 | x7_raw*x_grad_7) 551 | x_grad_1 = gate_x * x_grad_1 552 | x_grad_2 = gate_x * x_grad_2 553 | x_grad_3 = gate_x * x_grad_3 554 | x_grad_4 = gate_x * x_grad_4 555 | x_grad_5 = gate_x * x_grad_5 556 | x_grad_6 = gate_x * x_grad_6 557 | x_grad_7 = gate_x * x_grad_7 558 | 559 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3 + 560 | y4_raw*y_grad_4 + y5_raw*y_grad_5 + y6_raw*y_grad_6 + 561 | y7_raw*y_grad_7) 562 | y_grad_1 = gate_y * y_grad_1 563 | y_grad_2 = gate_y * y_grad_2 564 | y_grad_3 = gate_y * y_grad_3 565 | y_grad_4 = gate_y * y_grad_4 566 | y_grad_5 = gate_y * y_grad_5 567 | y_grad_6 = gate_y * y_grad_6 568 | y_grad_7 = gate_y * y_grad_7 569 | 570 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask) 571 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask) 572 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask) 573 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask) 574 | tl.store(grad_x_ptr + 4 * stride_component + base_offset, x_grad_4, mask=batch_feature_mask) 575 | tl.store(grad_x_ptr + 5 * stride_component + base_offset, x_grad_5, mask=batch_feature_mask) 576 | tl.store(grad_x_ptr + 6 * stride_component + base_offset, x_grad_6, mask=batch_feature_mask) 577 | tl.store(grad_x_ptr + 7 * stride_component + base_offset, x_grad_7, mask=batch_feature_mask) 578 | 579 | tl.store(grad_y_ptr + 0 * 
stride_component + base_offset, y_grad_0, mask=batch_feature_mask) 580 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask) 581 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask) 582 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask) 583 | tl.store(grad_y_ptr + 4 * stride_component + base_offset, y_grad_4, mask=batch_feature_mask) 584 | tl.store(grad_y_ptr + 5 * stride_component + base_offset, y_grad_5, mask=batch_feature_mask) 585 | tl.store(grad_y_ptr + 6 * stride_component + base_offset, y_grad_6, mask=batch_feature_mask) 586 | tl.store(grad_y_ptr + 7 * stride_component + base_offset, y_grad_7, mask=batch_feature_mask) 587 | 588 | tl.store(grad_weight_ptr + block_offset + weight_offset + 0, w_grad_0, mask=feature_mask) 589 | tl.store(grad_weight_ptr + block_offset + weight_offset + 1, w_grad_1, mask=feature_mask) 590 | tl.store(grad_weight_ptr + block_offset + weight_offset + 2, w_grad_2, mask=feature_mask) 591 | tl.store(grad_weight_ptr + block_offset + weight_offset + 3, w_grad_3, mask=feature_mask) 592 | tl.store(grad_weight_ptr + block_offset + weight_offset + 4, w_grad_4, mask=feature_mask) 593 | tl.store(grad_weight_ptr + block_offset + weight_offset + 5, w_grad_5, mask=feature_mask) 594 | tl.store(grad_weight_ptr + block_offset + weight_offset + 6, w_grad_6, mask=feature_mask) 595 | tl.store(grad_weight_ptr + block_offset + weight_offset + 7, w_grad_7, mask=feature_mask) 596 | tl.store(grad_weight_ptr + block_offset + weight_offset + 8, w_grad_8, mask=feature_mask) 597 | tl.store(grad_weight_ptr + block_offset + weight_offset + 9, w_grad_9, mask=feature_mask) 598 | tl.store(grad_weight_ptr + block_offset + weight_offset + 10, w_grad_10, mask=feature_mask) 599 | tl.store(grad_weight_ptr + block_offset + weight_offset + 11, w_grad_11, mask=feature_mask) 600 | tl.store(grad_weight_ptr + block_offset + weight_offset + 12, w_grad_12, mask=feature_mask) 601 | tl.store(grad_weight_ptr + block_offset + weight_offset + 13, w_grad_13, mask=feature_mask) 602 | tl.store(grad_weight_ptr + block_offset + weight_offset + 14, w_grad_14, mask=feature_mask) 603 | tl.store(grad_weight_ptr + block_offset + weight_offset + 15, w_grad_15, mask=feature_mask) 604 | tl.store(grad_weight_ptr + block_offset + weight_offset + 16, w_grad_16, mask=feature_mask) 605 | tl.store(grad_weight_ptr + block_offset + weight_offset + 17, w_grad_17, mask=feature_mask) 606 | tl.store(grad_weight_ptr + block_offset + weight_offset + 18, w_grad_18, mask=feature_mask) 607 | tl.store(grad_weight_ptr + block_offset + weight_offset + 19, w_grad_19, mask=feature_mask) 608 | 609 | 610 | def gelu_geometric_product_norm_bwd( 611 | x: torch.Tensor, 612 | y: torch.Tensor, 613 | weight: torch.Tensor, 614 | o: torch.Tensor, 615 | partial_norm: torch.Tensor, 616 | grad_output: torch.Tensor, 617 | normalize: bool, 618 | ) -> torch.Tensor: 619 | """Backward pass for the fused operation.""" 620 | _, B, N = x.shape 621 | 622 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 623 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 624 | 625 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 626 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 627 | 628 | grad_x = torch.zeros_like(x) 629 | grad_y = torch.zeros_like(y) 630 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0)) 631 | grad_weight = torch.zeros((num_blocks_batch, N, NUM_PRODUCT_WEIGHTS), device=x.device, 
dtype=weight.dtype) 632 | 633 | grid = (num_blocks_batch, num_blocks_features) 634 | 635 | if normalize: 636 | grad_o_dot_o_kernel[grid]( 637 | dot, 638 | partial_norm, 639 | o, 640 | grad_output, 641 | B, 642 | N, 643 | BATCH_BLOCK, 644 | FEATURE_BLOCK, 645 | EPS, 646 | num_warps=DEFAULT_NUM_WARPS, 647 | num_stages=DEFAULT_NUM_STAGES, 648 | ) 649 | 650 | gelu_wgp_norm_kernel_bwd[grid]( 651 | x, 652 | y, 653 | o, 654 | weight, 655 | dot, 656 | partial_norm, 657 | grad_output, 658 | grad_x, 659 | grad_y, 660 | grad_weight, 661 | normalize, 662 | B, 663 | N, 664 | BATCH_BLOCK, 665 | FEATURE_BLOCK, 666 | MV_DIM, 667 | NUM_GRADES, 668 | NUM_PRODUCT_WEIGHTS, 669 | EPS, 670 | num_warps=DEFAULT_NUM_WARPS, 671 | num_stages=DEFAULT_NUM_STAGES, 672 | ) 673 | 674 | grad_weight = torch.sum(grad_weight, dim=0) 675 | 676 | return grad_x, grad_y, grad_weight 677 | 678 | 679 | class WeightedGeluGeometricProductNorm3D(torch.autograd.Function): 680 | 681 | @staticmethod 682 | @torch.amp.custom_fwd(device_type="cuda") 683 | def forward(ctx, x, y, weight, normalize): 684 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous() 685 | 686 | ctx.dtype = x.dtype 687 | ctx.normalize = normalize 688 | 689 | o, partial_norm = gelu_geometric_product_norm_fwd( 690 | x, 691 | y, 692 | weight, 693 | normalize, 694 | ) 695 | 696 | ctx.save_for_backward(x, y, weight, o, partial_norm) 697 | 698 | return o.to(x.dtype) 699 | 700 | @staticmethod 701 | @torch.amp.custom_bwd(device_type="cuda") 702 | def backward(ctx, grad_output): 703 | grad_output = grad_output.contiguous() 704 | 705 | x, y, weight, o, partial_norm = ctx.saved_tensors 706 | 707 | grad_x, grad_y, grad_weight = gelu_geometric_product_norm_bwd( 708 | x, 709 | y, 710 | weight, 711 | o, 712 | partial_norm, 713 | grad_output, 714 | ctx.normalize, 715 | ) 716 | 717 | return grad_x, grad_y, grad_weight, None, None, None, None 718 | 719 | 720 | def fused_gelu_sgp_norm_3d(x, y, weight, normalize=True): 721 | """ 722 | Fused operation that applies GELU non-linearity to two multivector inputs, 723 | then computes their weighted geometric product, and applies RMSNorm. 724 | 725 | Clifford algebra is assumed to be Cl(3,0). 726 | 727 | Args: 728 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 729 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 730 | weight (torch.Tensor): Weight tensor of shape (N, NUM_PRODUCT_WEIGHTS), one weight per geometric product component. 731 | normalize (bool): Whether to apply RMSNorm after the geometric product. 732 | 733 | Returns: 734 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation. 
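
    Example (illustrative; the sizes below are arbitrary and contiguous CUDA float32 inputs are assumed):
        >>> B, N = 32, 64
        >>> x = torch.randn(8, B, N, device="cuda")   # MV_DIM = 8 for Cl(3,0)
        >>> y = torch.randn(8, B, N, device="cuda")
        >>> w = torch.randn(N, 20, device="cuda")     # NUM_PRODUCT_WEIGHTS = 20
        >>> out = fused_gelu_sgp_norm_3d(x, y, w, normalize=True)
        >>> out.shape
        torch.Size([8, 32, 64])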
735 | """ 736 | return WeightedGeluGeometricProductNorm3D.apply(x, y, weight, normalize) -------------------------------------------------------------------------------- /ops/fc_p3m0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | MV_DIM = 8 6 | NUM_GRADES = 4 7 | NUM_PRODUCT_WEIGHTS = 20 8 | WEIGHT_EXPANSION = [0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 16, 17, 17, 17, 18, 18, 18, 19] 9 | EPS = 1e-6 10 | 11 | # tuned at RTX 4500 12 | DEFAULT_BATCH_BLOCK = 4 13 | DEFAULT_FEATURE_BLOCK = 128 14 | DEFAULT_NUM_WARPS = 16 15 | DEFAULT_NUM_STAGES = 1 16 | 17 | 18 | @triton.jit 19 | def compute_gelu_gate(x): 20 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))""" 21 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype) 22 | 23 | 24 | @triton.jit 25 | def compute_gelu_gate_grad(x): 26 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)""" 27 | return 0.3989422804 * tl.exp(-0.5 * x * x) 28 | 29 | 30 | @triton.jit 31 | def gelu_pairwise_kernel_fwd( 32 | x_ptr, 33 | y_ptr, 34 | pairwise_ptr, 35 | batch_size: tl.constexpr, 36 | n_features: tl.constexpr, 37 | BATCH_BLOCK: tl.constexpr, 38 | FEATURE_BLOCK: tl.constexpr, 39 | ): 40 | """Apply GELU non-linearity to inputs and compute pairwise products for geometric product.""" 41 | batch_block_id = tl.program_id(axis=0) 42 | thread_block_id = tl.program_id(axis=1) 43 | 44 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 45 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 46 | 47 | batch_mask = batch_ids < batch_size 48 | feature_mask = feature_ids < n_features 49 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 50 | 51 | stride_component = batch_size * n_features 52 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 53 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 54 | 55 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 56 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 57 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 58 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 59 | x4 = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 60 | x5 = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 61 | x6 = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 62 | x7 = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 63 | 64 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 65 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 66 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 67 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 68 | y4 = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 69 | y5 = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 70 | y6 = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 71 | y7 = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 72 | 73 | gate_x = compute_gelu_gate(x0) 74 | gate_y = 
compute_gelu_gate(y0) 75 | 76 | x0 = x0 * gate_x 77 | x1 = x1 * gate_x 78 | x2 = x2 * gate_x 79 | x3 = x3 * gate_x 80 | x4 = x4 * gate_x 81 | x5 = x5 * gate_x 82 | x6 = x6 * gate_x 83 | x7 = x7 * gate_x 84 | 85 | y0 = y0 * gate_y 86 | y1 = y1 * gate_y 87 | y2 = y2 * gate_y 88 | y3 = y3 * gate_y 89 | y4 = y4 * gate_y 90 | y5 = y5 * gate_y 91 | y6 = y6 * gate_y 92 | y7 = y7 * gate_y 93 | 94 | p0 = x0*y0 95 | p1 = x0*y1 96 | p2 = x0*y2 97 | p3 = x0*y3 98 | p4 = x0*y4 99 | p5 = x0*y5 100 | p6 = x0*y6 101 | p7 = x0*y7 102 | p8 = x1*y1 + x2*y2 + x3*y3 103 | p9 = x1*y0 104 | p10 = x2*y0 105 | p11 = x3*y0 106 | p12 = -x2*y4 - x3*y5 107 | p13 = x1*y4 - x3*y6 108 | p14 = x1*y5 + x2*y6 109 | p15 = x1*y2 - x2*y1 110 | p16 = x1*y3 - x3*y1 111 | p17 = x2*y3 - x3*y2 112 | p18 = x3*y7 113 | p19 = -x2*y7 114 | p20 = x1*y7 115 | p21 = x1*y6 - x2*y5 + x3*y4 116 | p22 = -x4*y4 - x5*y5 - x6*y6 117 | p23 = x4*y2 + x5*y3 118 | p24 = -x4*y1 + x6*y3 119 | p25 = -x5*y1 - x6*y2 120 | p26 = -x6*y7 121 | p27 = x5*y7 122 | p28 = -x4*y7 123 | p29 = x4*y0 124 | p30 = x5*y0 125 | p31 = x6*y0 126 | p32 = -x5*y6 + x6*y5 127 | p33 = x4*y6 - x6*y4 128 | p34 = -x4*y5 + x5*y4 129 | p35 = x4*y3 - x5*y2 + x6*y1 130 | p36 = -x7*y7 131 | p37 = -x7*y6 132 | p38 = x7*y5 133 | p39 = -x7*y4 134 | p40 = x7*y3 135 | p41 = -x7*y2 136 | p42 = x7*y1 137 | p43 = x7*y0 138 | 139 | tl.store(pairwise_ptr + 0*batch_size*n_features + pairwise_offset, p0, mask=batch_feature_mask) 140 | tl.store(pairwise_ptr + 1*batch_size*n_features + pairwise_offset, p1, mask=batch_feature_mask) 141 | tl.store(pairwise_ptr + 2*batch_size*n_features + pairwise_offset, p2, mask=batch_feature_mask) 142 | tl.store(pairwise_ptr + 3*batch_size*n_features + pairwise_offset, p3, mask=batch_feature_mask) 143 | tl.store(pairwise_ptr + 4*batch_size*n_features + pairwise_offset, p4, mask=batch_feature_mask) 144 | tl.store(pairwise_ptr + 5*batch_size*n_features + pairwise_offset, p5, mask=batch_feature_mask) 145 | tl.store(pairwise_ptr + 6*batch_size*n_features + pairwise_offset, p6, mask=batch_feature_mask) 146 | tl.store(pairwise_ptr + 7*batch_size*n_features + pairwise_offset, p7, mask=batch_feature_mask) 147 | tl.store(pairwise_ptr + 8*batch_size*n_features + pairwise_offset, p8, mask=batch_feature_mask) 148 | tl.store(pairwise_ptr + 9*batch_size*n_features + pairwise_offset, p9, mask=batch_feature_mask) 149 | tl.store(pairwise_ptr + 10*batch_size*n_features + pairwise_offset, p10, mask=batch_feature_mask) 150 | tl.store(pairwise_ptr + 11*batch_size*n_features + pairwise_offset, p11, mask=batch_feature_mask) 151 | tl.store(pairwise_ptr + 12*batch_size*n_features + pairwise_offset, p12, mask=batch_feature_mask) 152 | tl.store(pairwise_ptr + 13*batch_size*n_features + pairwise_offset, p13, mask=batch_feature_mask) 153 | tl.store(pairwise_ptr + 14*batch_size*n_features + pairwise_offset, p14, mask=batch_feature_mask) 154 | tl.store(pairwise_ptr + 15*batch_size*n_features + pairwise_offset, p15, mask=batch_feature_mask) 155 | tl.store(pairwise_ptr + 16*batch_size*n_features + pairwise_offset, p16, mask=batch_feature_mask) 156 | tl.store(pairwise_ptr + 17*batch_size*n_features + pairwise_offset, p17, mask=batch_feature_mask) 157 | tl.store(pairwise_ptr + 18*batch_size*n_features + pairwise_offset, p18, mask=batch_feature_mask) 158 | tl.store(pairwise_ptr + 19*batch_size*n_features + pairwise_offset, p19, mask=batch_feature_mask) 159 | tl.store(pairwise_ptr + 20*batch_size*n_features + pairwise_offset, p20, mask=batch_feature_mask) 160 | tl.store(pairwise_ptr + 
21*batch_size*n_features + pairwise_offset, p21, mask=batch_feature_mask) 161 | tl.store(pairwise_ptr + 22*batch_size*n_features + pairwise_offset, p22, mask=batch_feature_mask) 162 | tl.store(pairwise_ptr + 23*batch_size*n_features + pairwise_offset, p23, mask=batch_feature_mask) 163 | tl.store(pairwise_ptr + 24*batch_size*n_features + pairwise_offset, p24, mask=batch_feature_mask) 164 | tl.store(pairwise_ptr + 25*batch_size*n_features + pairwise_offset, p25, mask=batch_feature_mask) 165 | tl.store(pairwise_ptr + 26*batch_size*n_features + pairwise_offset, p26, mask=batch_feature_mask) 166 | tl.store(pairwise_ptr + 27*batch_size*n_features + pairwise_offset, p27, mask=batch_feature_mask) 167 | tl.store(pairwise_ptr + 28*batch_size*n_features + pairwise_offset, p28, mask=batch_feature_mask) 168 | tl.store(pairwise_ptr + 29*batch_size*n_features + pairwise_offset, p29, mask=batch_feature_mask) 169 | tl.store(pairwise_ptr + 30*batch_size*n_features + pairwise_offset, p30, mask=batch_feature_mask) 170 | tl.store(pairwise_ptr + 31*batch_size*n_features + pairwise_offset, p31, mask=batch_feature_mask) 171 | tl.store(pairwise_ptr + 32*batch_size*n_features + pairwise_offset, p32, mask=batch_feature_mask) 172 | tl.store(pairwise_ptr + 33*batch_size*n_features + pairwise_offset, p33, mask=batch_feature_mask) 173 | tl.store(pairwise_ptr + 34*batch_size*n_features + pairwise_offset, p34, mask=batch_feature_mask) 174 | tl.store(pairwise_ptr + 35*batch_size*n_features + pairwise_offset, p35, mask=batch_feature_mask) 175 | tl.store(pairwise_ptr + 36*batch_size*n_features + pairwise_offset, p36, mask=batch_feature_mask) 176 | tl.store(pairwise_ptr + 37*batch_size*n_features + pairwise_offset, p37, mask=batch_feature_mask) 177 | tl.store(pairwise_ptr + 38*batch_size*n_features + pairwise_offset, p38, mask=batch_feature_mask) 178 | tl.store(pairwise_ptr + 39*batch_size*n_features + pairwise_offset, p39, mask=batch_feature_mask) 179 | tl.store(pairwise_ptr + 40*batch_size*n_features + pairwise_offset, p40, mask=batch_feature_mask) 180 | tl.store(pairwise_ptr + 41*batch_size*n_features + pairwise_offset, p41, mask=batch_feature_mask) 181 | tl.store(pairwise_ptr + 42*batch_size*n_features + pairwise_offset, p42, mask=batch_feature_mask) 182 | tl.store(pairwise_ptr + 43*batch_size*n_features + pairwise_offset, p43, mask=batch_feature_mask) 183 | 184 | 185 | @triton.jit 186 | def assemble_kernel( 187 | transformed_ptr, 188 | pnorm_ptr, 189 | output_ptr, 190 | NORMALIZE: tl.constexpr, 191 | batch_size: tl.constexpr, 192 | n_features: tl.constexpr, 193 | BATCH_BLOCK: tl.constexpr, 194 | FEATURE_BLOCK: tl.constexpr, 195 | ): 196 | """ 197 | Gather linearly transformed pairwise products and compute the geometric product. 
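    Each output component is the sum of the transformed slots that feed it (for example
    o0 = t0 + t8 + t22 + t36 for the scalar part and o7 = t7 + t21 + t35 + t43 for the
    pseudoscalar). When NORMALIZE is set, the squared grade norms are additionally
    accumulated into pnorm_ptr via atomic adds, to be consumed by normalize_with_sqrt_kernel.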
198 | """ 199 | batch_block_id = tl.program_id(axis=0) 200 | thread_block_id = tl.program_id(axis=1) 201 | 202 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 203 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 204 | 205 | batch_mask = batch_ids < batch_size 206 | feature_mask = feature_ids < n_features 207 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 208 | 209 | stride_component = batch_size * n_features 210 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 211 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 212 | 213 | t0 = tl.load(transformed_ptr + 0 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 214 | t1 = tl.load(transformed_ptr + 1 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 215 | t2 = tl.load(transformed_ptr + 2 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 216 | t3 = tl.load(transformed_ptr + 3 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 217 | t4 = tl.load(transformed_ptr + 4 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 218 | t5 = tl.load(transformed_ptr + 5 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 219 | t6 = tl.load(transformed_ptr + 6 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 220 | t7 = tl.load(transformed_ptr + 7 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 221 | t8 = tl.load(transformed_ptr + 8 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 222 | t9 = tl.load(transformed_ptr + 9 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 223 | t10 = tl.load(transformed_ptr + 10 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 224 | t11 = tl.load(transformed_ptr + 11 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 225 | t12 = tl.load(transformed_ptr + 12 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 226 | t13 = tl.load(transformed_ptr + 13 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 227 | t14 = tl.load(transformed_ptr + 14 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 228 | t15 = tl.load(transformed_ptr + 15 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 229 | t16 = tl.load(transformed_ptr + 16 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 230 | t17 = tl.load(transformed_ptr + 17 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 231 | t18 = tl.load(transformed_ptr + 18 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 232 | t19 = tl.load(transformed_ptr + 19 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 233 | t20 = tl.load(transformed_ptr + 20 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 234 | t21 = tl.load(transformed_ptr + 21 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 235 | t22 = tl.load(transformed_ptr + 22 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 236 | t23 = tl.load(transformed_ptr + 23 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 237 | t24 = tl.load(transformed_ptr + 24 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 238 | t25 = tl.load(transformed_ptr + 25 * batch_size * n_features + 
transformed_offset, mask=batch_feature_mask) 239 | t26 = tl.load(transformed_ptr + 26 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 240 | t27 = tl.load(transformed_ptr + 27 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 241 | t28 = tl.load(transformed_ptr + 28 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 242 | t29 = tl.load(transformed_ptr + 29 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 243 | t30 = tl.load(transformed_ptr + 30 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 244 | t31 = tl.load(transformed_ptr + 31 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 245 | t32 = tl.load(transformed_ptr + 32 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 246 | t33 = tl.load(transformed_ptr + 33 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 247 | t34 = tl.load(transformed_ptr + 34 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 248 | t35 = tl.load(transformed_ptr + 35 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 249 | t36 = tl.load(transformed_ptr + 36 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 250 | t37 = tl.load(transformed_ptr + 37 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 251 | t38 = tl.load(transformed_ptr + 38 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 252 | t39 = tl.load(transformed_ptr + 39 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 253 | t40 = tl.load(transformed_ptr + 40 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 254 | t41 = tl.load(transformed_ptr + 41 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 255 | t42 = tl.load(transformed_ptr + 42 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 256 | t43 = tl.load(transformed_ptr + 43 * batch_size * n_features + transformed_offset, mask=batch_feature_mask) 257 | 258 | o0 = t0 + t8 + t22 + t36 259 | o1 = t1 + t9 + t12 + t23 + t26 + t37 260 | o2 = t2 + t10 + t13 + t24 + t27 + t38 261 | o3 = t3 + t11 + t14 + t25 + t28 + t39 262 | o4 = t4 + t15 + t18 + t29 + t32 + t40 263 | o5 = t5 + t16 + t19 + t30 + t33 + t41 264 | o6 = t6 + t17 + t20 + t31 + t34 + t42 265 | o7 = t7 + t21 + t35 + t43 266 | 267 | if NORMALIZE: 268 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features 269 | pn_vector = tl.sum(o1 * o1 + o2 * o2 + o3 * o3, axis=1) / n_features 270 | pn_bivect = tl.sum(o4 * o4 + o5 * o5 + o6 * o6, axis=1) / n_features 271 | pn_pseudo = tl.sum(o7 * o7, axis=1) / n_features 272 | 273 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask) 274 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask) 275 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask) 276 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_vector, mask=batch_mask) 277 | tl.atomic_add(pnorm_ptr + 4*batch_size + batch_ids, pn_bivect, mask=batch_mask) 278 | tl.atomic_add(pnorm_ptr + 5*batch_size + batch_ids, pn_bivect, mask=batch_mask) 279 | tl.atomic_add(pnorm_ptr + 6*batch_size + batch_ids, pn_bivect, mask=batch_mask) 280 | tl.atomic_add(pnorm_ptr + 7*batch_size + batch_ids, pn_pseudo, mask=batch_mask) 281 | 282 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask) 283 | tl.store(output_ptr + 1 * 
stride_component + base_offset, o1, mask=batch_feature_mask) 284 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask) 285 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask) 286 | tl.store(output_ptr + 4 * stride_component + base_offset, o4, mask=batch_feature_mask) 287 | tl.store(output_ptr + 5 * stride_component + base_offset, o5, mask=batch_feature_mask) 288 | tl.store(output_ptr + 6 * stride_component + base_offset, o6, mask=batch_feature_mask) 289 | tl.store(output_ptr + 7 * stride_component + base_offset, o7, mask=batch_feature_mask) 290 | 291 | 292 | @triton.jit 293 | def normalize_with_sqrt_kernel( 294 | output_ptr, 295 | pnorm_ptr, 296 | batch_size: tl.constexpr, 297 | n_features: tl.constexpr, 298 | BATCH_BLOCK: tl.constexpr, 299 | FEATURE_BLOCK: tl.constexpr, 300 | MV_DIM: tl.constexpr, 301 | EPS: tl.constexpr, 302 | ): 303 | """Normalize the output by dividing each grade with root of its accumulated norm.""" 304 | batch_block_id = tl.program_id(axis=0) 305 | thread_block_id = tl.program_id(axis=1) 306 | 307 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 308 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 309 | 310 | batch_mask = batch_ids < batch_size 311 | feature_mask = feature_ids < n_features 312 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None] 313 | 314 | component_ids = tl.arange(0, MV_DIM)[None, None, :] 315 | 316 | feature_offset = (component_ids * batch_size * n_features + 317 | batch_ids[:, None, None] * n_features + 318 | feature_ids[None, :, None]) 319 | 320 | norm_indices = component_ids * batch_size + batch_ids[:, None, None] 321 | 322 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None]) 323 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask) 324 | 325 | norm = tl.sqrt(pnorm + EPS) 326 | mv_normalized = mv / norm 327 | 328 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask) 329 | 330 | 331 | def gelu_fc_geometric_product_norm_fwd( 332 | x: torch.Tensor, 333 | y: torch.Tensor, 334 | weight: torch.Tensor, 335 | expansion_indices: torch.Tensor, 336 | normalize: bool, 337 | ) -> torch.Tensor: 338 | """Fused operation: GELU non-linearity, fully connected geometric product, and grade-wise RMSNorm.""" 339 | assert x.shape == y.shape 340 | assert x.shape[0] == MV_DIM 341 | assert x.shape[2] == weight.shape[1] == weight.shape[2] 342 | assert weight.shape[0] == NUM_PRODUCT_WEIGHTS 343 | 344 | _, B, N = x.shape 345 | 346 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B) 347 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N) 348 | 349 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK) 350 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK) 351 | 352 | pairwise = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype) 353 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize else torch.zeros((1,), device=x.device, dtype=x.dtype)) 354 | output = torch.empty_like(x) 355 | 356 | grid = (num_blocks_batch, num_blocks_features) 357 | 358 | gelu_pairwise_kernel_fwd[grid]( 359 | x, 360 | y, 361 | pairwise, 362 | B, 363 | N, 364 | BATCH_BLOCK, 365 | FEATURE_BLOCK, 366 | num_warps=DEFAULT_NUM_WARPS, 367 | num_stages=DEFAULT_NUM_STAGES, 368 | ) 369 | 370 | transformed = torch.bmm(pairwise, weight[expansion_indices]) 371 | 372 | assemble_kernel[grid]( 373 | transformed, 374 | partial_norm, 375 | output, 376 | normalize, 377 | B, 
378 | N, 379 | BATCH_BLOCK, 380 | FEATURE_BLOCK, 381 | num_warps=DEFAULT_NUM_WARPS, 382 | num_stages=DEFAULT_NUM_STAGES, 383 | ) 384 | 385 | if normalize: 386 | normalize_with_sqrt_kernel[grid]( 387 | output, 388 | partial_norm, 389 | B, 390 | N, 391 | BATCH_BLOCK, 392 | FEATURE_BLOCK, 393 | MV_DIM, 394 | EPS, 395 | num_warps=DEFAULT_NUM_WARPS, 396 | num_stages=DEFAULT_NUM_STAGES, 397 | ) 398 | 399 | return output, pairwise, partial_norm 400 | 401 | 402 | @triton.jit 403 | def grad_o_dot_o_kernel( 404 | dot_ptr, 405 | pnorm_ptr, 406 | output_ptr, 407 | grad_output_ptr, 408 | batch_size: tl.constexpr, 409 | n_features: tl.constexpr, 410 | BATCH_BLOCK: tl.constexpr, 411 | FEATURE_BLOCK: tl.constexpr, 412 | EPS: tl.constexpr, 413 | ): 414 | """Compute the dot product of grad_output and output for each grade, accumulate across all features.""" 415 | batch_block_id = tl.program_id(axis=0) 416 | thread_block_id = tl.program_id(axis=1) 417 | 418 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 419 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 420 | 421 | batch_mask = batch_ids < batch_size 422 | feature_mask = feature_ids < n_features 423 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 424 | 425 | stride_component = batch_size * n_features 426 | offset = batch_ids[:, None] * n_features + feature_ids[None, :] 427 | 428 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 429 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 430 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 431 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 432 | go4 = tl.load(grad_output_ptr + 4 * stride_component + offset, mask=batch_feature_mask) 433 | go5 = tl.load(grad_output_ptr + 5 * stride_component + offset, mask=batch_feature_mask) 434 | go6 = tl.load(grad_output_ptr + 6 * stride_component + offset, mask=batch_feature_mask) 435 | go7 = tl.load(grad_output_ptr + 7 * stride_component + offset, mask=batch_feature_mask) 436 | 437 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 438 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 439 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None] 440 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None] 441 | 442 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask) 443 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask) 444 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask) 445 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask) 446 | o4 = tl.load(output_ptr + 4 * stride_component + offset, mask=batch_feature_mask) 447 | o5 = tl.load(output_ptr + 5 * stride_component + offset, mask=batch_feature_mask) 448 | o6 = tl.load(output_ptr + 6 * stride_component + offset, mask=batch_feature_mask) 449 | o7 = tl.load(output_ptr + 7 * stride_component + offset, mask=batch_feature_mask) 450 | 451 | rms_scalar = tl.sqrt(pn_scalar + EPS) 452 | rms_vector = tl.sqrt(pn_vector + EPS) 453 | rms_bivect = tl.sqrt(pn_bivect + EPS) 454 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 455 | 456 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1) 457 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2 + go3*o3), axis=1) 458 | 
dot_bivect = tl.sum(rms_bivect * (go4*o4 + go5*o5 + go6*o6), axis=1) 459 | dot_pseudo = tl.sum(rms_pseudo * go7 * o7, axis=1) 460 | 461 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask) 462 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask) 463 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_bivect, mask=batch_mask) 464 | tl.atomic_add(dot_ptr + 3*batch_size + batch_ids, dot_pseudo, mask=batch_mask) 465 | 466 | 467 | @triton.jit 468 | def disassemble_kernel( 469 | grad_output_ptr, 470 | output_ptr, 471 | dot_ptr, 472 | grad_transformed_ptr, 473 | pnorm_ptr, 474 | NORMALIZE: tl.constexpr, 475 | batch_size: tl.constexpr, 476 | n_features: tl.constexpr, 477 | BATCH_BLOCK: tl.constexpr, 478 | FEATURE_BLOCK: tl.constexpr, 479 | EPS: tl.constexpr, 480 | ): 481 | """Scatter grad_output into gradients w.r.t. the 44 transformed pairwise products, applying the grade-wise RMSNorm backward correction when NORMALIZE is set.""" 482 | batch_block_id = tl.program_id(axis=0) 483 | thread_block_id = tl.program_id(axis=1) 484 | 485 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 486 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 487 | 488 | batch_mask = batch_ids < batch_size 489 | feature_mask = feature_ids < n_features 490 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 491 | 492 | stride_component = batch_size * n_features 493 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 494 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 495 | 496 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 497 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 498 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 499 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 500 | go4 = tl.load(grad_output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 501 | go5 = tl.load(grad_output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 502 | go6 = tl.load(grad_output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 503 | go7 = tl.load(grad_output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 504 | 505 | if NORMALIZE: 506 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 507 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 508 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 509 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 510 | o4 = tl.load(output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 511 | o5 = tl.load(output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 512 | o6 = tl.load(output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 513 | o7 = tl.load(output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 514 | 515 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None] 516 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 517 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None] 518 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None] 519 | 520 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids,
mask=batch_mask)[:, None] 521 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None] 522 | dot_bivect = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None] 523 | dot_pseudo = tl.load(dot_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None] 524 | 525 | rms_scalar = tl.sqrt(pn_scalar + EPS) 526 | rms_vector = tl.sqrt(pn_vector + EPS) 527 | rms_bivect = tl.sqrt(pn_bivect + EPS) 528 | rms_pseudo = tl.sqrt(pn_pseudo + EPS) 529 | 530 | go0 = go0 / rms_scalar - o0 * dot_scalar / (n_features * rms_scalar * rms_scalar) 531 | go1 = go1 / rms_vector - o1 * dot_vector / (n_features * rms_vector * rms_vector) 532 | go2 = go2 / rms_vector - o2 * dot_vector / (n_features * rms_vector * rms_vector) 533 | go3 = go3 / rms_vector - o3 * dot_vector / (n_features * rms_vector * rms_vector) 534 | go4 = go4 / rms_bivect - o4 * dot_bivect / (n_features * rms_bivect * rms_bivect) 535 | go5 = go5 / rms_bivect - o5 * dot_bivect / (n_features * rms_bivect * rms_bivect) 536 | go6 = go6 / rms_bivect - o6 * dot_bivect / (n_features * rms_bivect * rms_bivect) 537 | go7 = go7 / rms_pseudo - o7 * dot_pseudo / (n_features * rms_pseudo * rms_pseudo) 538 | 539 | tl.store(grad_transformed_ptr + 0 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask) 540 | tl.store(grad_transformed_ptr + 1 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask) 541 | tl.store(grad_transformed_ptr + 2 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask) 542 | tl.store(grad_transformed_ptr + 3 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask) 543 | tl.store(grad_transformed_ptr + 4 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask) 544 | tl.store(grad_transformed_ptr + 5 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask) 545 | tl.store(grad_transformed_ptr + 6 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask) 546 | tl.store(grad_transformed_ptr + 7 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask) 547 | tl.store(grad_transformed_ptr + 8 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask) 548 | tl.store(grad_transformed_ptr + 9 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask) 549 | tl.store(grad_transformed_ptr + 10 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask) 550 | tl.store(grad_transformed_ptr + 11 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask) 551 | tl.store(grad_transformed_ptr + 12 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask) 552 | tl.store(grad_transformed_ptr + 13 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask) 553 | tl.store(grad_transformed_ptr + 14 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask) 554 | tl.store(grad_transformed_ptr + 15 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask) 555 | tl.store(grad_transformed_ptr + 16 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask) 556 | tl.store(grad_transformed_ptr + 17 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask) 557 | tl.store(grad_transformed_ptr + 18 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask) 558 | tl.store(grad_transformed_ptr + 19 * batch_size * n_features + transformed_offset, go5, 
mask=batch_feature_mask) 559 | tl.store(grad_transformed_ptr + 20 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask) 560 | tl.store(grad_transformed_ptr + 21 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask) 561 | tl.store(grad_transformed_ptr + 22 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask) 562 | tl.store(grad_transformed_ptr + 23 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask) 563 | tl.store(grad_transformed_ptr + 24 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask) 564 | tl.store(grad_transformed_ptr + 25 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask) 565 | tl.store(grad_transformed_ptr + 26 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask) 566 | tl.store(grad_transformed_ptr + 27 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask) 567 | tl.store(grad_transformed_ptr + 28 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask) 568 | tl.store(grad_transformed_ptr + 29 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask) 569 | tl.store(grad_transformed_ptr + 30 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask) 570 | tl.store(grad_transformed_ptr + 31 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask) 571 | tl.store(grad_transformed_ptr + 32 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask) 572 | tl.store(grad_transformed_ptr + 33 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask) 573 | tl.store(grad_transformed_ptr + 34 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask) 574 | tl.store(grad_transformed_ptr + 35 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask) 575 | tl.store(grad_transformed_ptr + 36 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask) 576 | tl.store(grad_transformed_ptr + 37 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask) 577 | tl.store(grad_transformed_ptr + 38 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask) 578 | tl.store(grad_transformed_ptr + 39 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask) 579 | tl.store(grad_transformed_ptr + 40 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask) 580 | tl.store(grad_transformed_ptr + 41 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask) 581 | tl.store(grad_transformed_ptr + 42 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask) 582 | tl.store(grad_transformed_ptr + 43 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask) 583 | 584 | 585 | @triton.jit 586 | def gelu_pairwise_kernel_bwd( 587 | x_ptr, 588 | y_ptr, 589 | grad_pairwise_ptr, 590 | grad_x_ptr, 591 | grad_y_ptr, 592 | batch_size: tl.constexpr, 593 | n_features: tl.constexpr, 594 | BATCH_BLOCK: tl.constexpr, 595 | FEATURE_BLOCK: tl.constexpr, 596 | ): 597 | """ 598 | Accumulate gradients w.r.t. the raw inputs x and y from the gradients of the 44 pairwise products (backward of gelu_pairwise_kernel_fwd).
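    Each grad_pairwise slot corresponds to one pairwise product written by
    gelu_pairwise_kernel_fwd. The slot gradients are combined with the GELU-gated values
    of the opposite operand, and the gate derivative is folded into the gradient of the
    scalar component (the only component the gate depends on), so grad_x_ptr and
    grad_y_ptr receive gradients w.r.t. the raw, pre-GELU inputs.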
599 | """ 600 | batch_block_id = tl.program_id(axis=0) 601 | thread_block_id = tl.program_id(axis=1) 602 | 603 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK) 604 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK) 605 | 606 | batch_mask = batch_ids < batch_size 607 | feature_mask = feature_ids < n_features 608 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :] 609 | 610 | stride_component = batch_size * n_features 611 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 612 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :] 613 | 614 | gp0 = tl.load(grad_pairwise_ptr + 0 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 615 | gp1 = tl.load(grad_pairwise_ptr + 1 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 616 | gp2 = tl.load(grad_pairwise_ptr + 2 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 617 | gp3 = tl.load(grad_pairwise_ptr + 3 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 618 | gp4 = tl.load(grad_pairwise_ptr + 4 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 619 | gp5 = tl.load(grad_pairwise_ptr + 5 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 620 | gp6 = tl.load(grad_pairwise_ptr + 6 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 621 | gp7 = tl.load(grad_pairwise_ptr + 7 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 622 | gp8 = tl.load(grad_pairwise_ptr + 8 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 623 | gp9 = tl.load(grad_pairwise_ptr + 9 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 624 | gp10 = tl.load(grad_pairwise_ptr + 10 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 625 | gp11 = tl.load(grad_pairwise_ptr + 11 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 626 | gp12 = tl.load(grad_pairwise_ptr + 12 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 627 | gp13 = tl.load(grad_pairwise_ptr + 13 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 628 | gp14 = tl.load(grad_pairwise_ptr + 14 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 629 | gp15 = tl.load(grad_pairwise_ptr + 15 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 630 | gp16 = tl.load(grad_pairwise_ptr + 16 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 631 | gp17 = tl.load(grad_pairwise_ptr + 17 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 632 | gp18 = tl.load(grad_pairwise_ptr + 18 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 633 | gp19 = tl.load(grad_pairwise_ptr + 19 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 634 | gp20 = tl.load(grad_pairwise_ptr + 20 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 635 | gp21 = tl.load(grad_pairwise_ptr + 21 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 636 | gp22 = tl.load(grad_pairwise_ptr + 22 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 637 | gp23 = tl.load(grad_pairwise_ptr + 23 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 638 | gp24 = tl.load(grad_pairwise_ptr + 24 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 639 | gp25 = tl.load(grad_pairwise_ptr + 25 * batch_size * n_features + 
pairwise_offset, mask=batch_feature_mask) 640 | gp26 = tl.load(grad_pairwise_ptr + 26 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 641 | gp27 = tl.load(grad_pairwise_ptr + 27 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 642 | gp28 = tl.load(grad_pairwise_ptr + 28 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 643 | gp29 = tl.load(grad_pairwise_ptr + 29 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 644 | gp30 = tl.load(grad_pairwise_ptr + 30 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 645 | gp31 = tl.load(grad_pairwise_ptr + 31 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 646 | gp32 = tl.load(grad_pairwise_ptr + 32 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 647 | gp33 = tl.load(grad_pairwise_ptr + 33 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 648 | gp34 = tl.load(grad_pairwise_ptr + 34 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 649 | gp35 = tl.load(grad_pairwise_ptr + 35 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 650 | gp36 = tl.load(grad_pairwise_ptr + 36 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 651 | gp37 = tl.load(grad_pairwise_ptr + 37 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 652 | gp38 = tl.load(grad_pairwise_ptr + 38 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 653 | gp39 = tl.load(grad_pairwise_ptr + 39 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 654 | gp40 = tl.load(grad_pairwise_ptr + 40 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 655 | gp41 = tl.load(grad_pairwise_ptr + 41 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 656 | gp42 = tl.load(grad_pairwise_ptr + 42 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 657 | gp43 = tl.load(grad_pairwise_ptr + 43 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask) 658 | 659 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 660 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 661 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 662 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 663 | x4_raw = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 664 | x5_raw = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 665 | x6_raw = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 666 | x7_raw = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask) 667 | 668 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask) 669 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask) 670 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask) 671 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask) 672 | y4_raw = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask) 673 | y5_raw = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask) 674 | y6_raw = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask) 675 | y7_raw = tl.load(y_ptr + 7 * stride_component + base_offset, 
mask=batch_feature_mask) 676 | 677 | # collect gradients from pairwise products 678 | gate_x = compute_gelu_gate(x0_raw) 679 | gate_y = compute_gelu_gate(y0_raw) 680 | 681 | x0 = x0_raw * gate_x 682 | x1 = x1_raw * gate_x 683 | x2 = x2_raw * gate_x 684 | x3 = x3_raw * gate_x 685 | x4 = x4_raw * gate_x 686 | x5 = x5_raw * gate_x 687 | x6 = x6_raw * gate_x 688 | x7 = x7_raw * gate_x 689 | 690 | y0 = y0_raw * gate_y 691 | y1 = y1_raw * gate_y 692 | y2 = y2_raw * gate_y 693 | y3 = y3_raw * gate_y 694 | y4 = y4_raw * gate_y 695 | y5 = y5_raw * gate_y 696 | y6 = y6_raw * gate_y 697 | y7 = y7_raw * gate_y 698 | 699 | x_grad_0 = gp0*y0 + gp1*y1 + gp2*y2 + gp3*y3 + gp4*y4 + gp5*y5 + gp6*y6 + gp7*y7 700 | x_grad_1 = gp8*y1 + gp9*y0 + gp13*y4 + gp14*y5 + gp15*y2 + gp16*y3 + gp20*y7 + gp21*y6 701 | x_grad_2 = gp8*y2 + gp10*y0 - gp12*y4 + gp14*y6 - gp15*y1 + gp17*y3 - gp19*y7 - gp21*y5 702 | x_grad_3 = gp8*y3 + gp11*y0 - gp12*y5 - gp13*y6 - gp16*y1 - gp17*y2 + gp18*y7 + gp21*y4 703 | x_grad_4 = -gp22*y4 + gp23*y2 - gp24*y1 - gp28*y7 + gp29*y0 + gp33*y6 - gp34*y5 + gp35*y3 704 | x_grad_5 = -gp22*y5 + gp23*y3 - gp25*y1 + gp27*y7 + gp30*y0 - gp32*y6 + gp34*y4 - gp35*y2 705 | x_grad_6 = -gp22*y6 + gp24*y3 - gp25*y2 - gp26*y7 + gp31*y0 + gp32*y5 - gp33*y4 + gp35*y1 706 | x_grad_7 = -gp36*y7 - gp37*y6 + gp38*y5 - gp39*y4 + gp40*y3 - gp41*y2 + gp42*y1 + gp43*y0 707 | 708 | y_grad_0 = gp0*x0 + gp9*x1 + gp10*x2 + gp11*x3 + gp29*x4 + gp30*x5 + gp31*x6 + gp43*x7 709 | y_grad_1 = gp1*x0 + gp8*x1 - gp15*x2 - gp16*x3 - gp24*x4 - gp25*x5 + gp35*x6 + gp42*x7 710 | y_grad_2 = gp2*x0 + gp8*x2 + gp15*x1 - gp17*x3 + gp23*x4 - gp35*x5 - gp25*x6 - gp41*x7 711 | y_grad_3 = gp3*x0 + gp8*x3 + gp16*x1 + gp17*x2 + gp23*x5 + gp24*x6 + gp35*x4 + gp40*x7 712 | y_grad_4 = gp4*x0 + gp13*x1 - gp12*x2 + gp21*x3 - gp22*x4 + gp34*x5 - gp33*x6 - gp39*x7 713 | y_grad_5 = gp5*x0 + gp14*x1 - gp21*x2 - gp12*x3 - gp22*x5 - gp34*x4 + gp32*x6 + gp38*x7 714 | y_grad_6 = gp6*x0 - gp13*x3 + gp14*x2 + gp21*x1 - gp22*x6 - gp32*x5 + gp33*x4 - gp37*x7 715 | y_grad_7 = gp7*x0 + gp20*x1 - gp19*x2 + gp18*x3 - gp28*x4 + gp27*x5 - gp26*x6 - gp36*x7 716 | 717 | # GELU gate gradients 718 | dgate_x = compute_gelu_gate_grad(x0_raw) 719 | dgate_y = compute_gelu_gate_grad(y0_raw) 720 | 721 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3 + 722 | x4_raw*x_grad_4 + x5_raw*x_grad_5 + x6_raw*x_grad_6 + 723 | x7_raw*x_grad_7) 724 | x_grad_1 = gate_x * x_grad_1 725 | x_grad_2 = gate_x * x_grad_2 726 | x_grad_3 = gate_x * x_grad_3 727 | x_grad_4 = gate_x * x_grad_4 728 | x_grad_5 = gate_x * x_grad_5 729 | x_grad_6 = gate_x * x_grad_6 730 | x_grad_7 = gate_x * x_grad_7 731 | 732 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3 + 733 | y4_raw*y_grad_4 + y5_raw*y_grad_5 + y6_raw*y_grad_6 + 734 | y7_raw*y_grad_7) 735 | y_grad_1 = gate_y * y_grad_1 736 | y_grad_2 = gate_y * y_grad_2 737 | y_grad_3 = gate_y * y_grad_3 738 | y_grad_4 = gate_y * y_grad_4 739 | y_grad_5 = gate_y * y_grad_5 740 | y_grad_6 = gate_y * y_grad_6 741 | y_grad_7 = gate_y * y_grad_7 742 | 743 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask) 744 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask) 745 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask) 746 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, 
mask=batch_feature_mask)
747 | tl.store(grad_x_ptr + 4 * stride_component + base_offset, x_grad_4, mask=batch_feature_mask)
748 | tl.store(grad_x_ptr + 5 * stride_component + base_offset, x_grad_5, mask=batch_feature_mask)
749 | tl.store(grad_x_ptr + 6 * stride_component + base_offset, x_grad_6, mask=batch_feature_mask)
750 | tl.store(grad_x_ptr + 7 * stride_component + base_offset, x_grad_7, mask=batch_feature_mask)
751 | 
752 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask)
753 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask)
754 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask)
755 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask)
756 | tl.store(grad_y_ptr + 4 * stride_component + base_offset, y_grad_4, mask=batch_feature_mask)
757 | tl.store(grad_y_ptr + 5 * stride_component + base_offset, y_grad_5, mask=batch_feature_mask)
758 | tl.store(grad_y_ptr + 6 * stride_component + base_offset, y_grad_6, mask=batch_feature_mask)
759 | tl.store(grad_y_ptr + 7 * stride_component + base_offset, y_grad_7, mask=batch_feature_mask)
760 | 
761 | 
762 | def gelu_fc_geometric_product_norm_bwd(
763 | x: torch.Tensor,
764 | y: torch.Tensor,
765 | weight: torch.Tensor,
766 | o: torch.Tensor,
767 | pairwise: torch.Tensor,
768 | partial_norm: torch.Tensor,
769 | grad_output: torch.Tensor,
770 | expansion_indices: torch.Tensor,
771 | normalize: bool,
772 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773 | """Backward pass for the fused operation; returns gradients w.r.t. x, y, and weight."""
774 | _, B, N = x.shape
775 | 
776 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
777 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
778 | 
779 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
780 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
781 | 
782 | grad_x = torch.zeros_like(x)
783 | grad_y = torch.zeros_like(y)
784 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0))
785 | grad_weight = torch.zeros((NUM_PRODUCT_WEIGHTS, N, N), device=x.device, dtype=weight.dtype)
786 | grad_transformed = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype)
787 | 
788 | grid = (num_blocks_batch, num_blocks_features)
789 | 
790 | if normalize:
791 | grad_o_dot_o_kernel[grid](
792 | dot,
793 | partial_norm,
794 | o,
795 | grad_output,
796 | B,
797 | N,
798 | BATCH_BLOCK,
799 | FEATURE_BLOCK,
800 | EPS,
801 | num_warps=DEFAULT_NUM_WARPS,
802 | num_stages=DEFAULT_NUM_STAGES,
803 | )
804 | 
805 | disassemble_kernel[grid](
806 | grad_output,
807 | o,
808 | dot,
809 | grad_transformed,
810 | partial_norm,
811 | normalize,
812 | B,
813 | N,
814 | BATCH_BLOCK,
815 | FEATURE_BLOCK,
816 | EPS,
817 | num_warps=DEFAULT_NUM_WARPS,
818 | num_stages=DEFAULT_NUM_STAGES,
819 | )
820 | 
821 | grad_pairwise = torch.bmm(grad_transformed, weight[expansion_indices].transpose(-2, -1))
822 | 
823 | grad_weight.index_add_(0, expansion_indices, torch.bmm(pairwise.transpose(-2, -1), grad_transformed))
824 | 
825 | gelu_pairwise_kernel_bwd[grid](
826 | x,
827 | y,
828 | grad_pairwise,
829 | grad_x,
830 | grad_y,
831 | B,
832 | N,
833 | BATCH_BLOCK,
834 | FEATURE_BLOCK,
835 | num_warps=DEFAULT_NUM_WARPS,
836 | num_stages=DEFAULT_NUM_STAGES,
837 | )
838 | 
839 | return grad_x, grad_y, grad_weight
840 | 
841 | 
842 | class FullyConnectedGeluGeometricProductNorm3D(torch.autograd.Function):
843 | 
844 | @staticmethod
845 | @torch.amp.custom_fwd(device_type="cuda")
846 | def forward(ctx, x,
y, weight, normalize): 847 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous() 848 | 849 | ctx.dtype = x.dtype 850 | ctx.normalize = normalize 851 | 852 | expansion_indices = torch.tensor(WEIGHT_EXPANSION, device=x.device) 853 | 854 | o, pairwise, partial_norm = gelu_fc_geometric_product_norm_fwd( 855 | x, 856 | y, 857 | weight, 858 | expansion_indices, 859 | normalize, 860 | ) 861 | 862 | ctx.save_for_backward(x, y, weight, o, pairwise, partial_norm, expansion_indices) 863 | 864 | return o.to(x.dtype) 865 | 866 | @staticmethod 867 | @torch.amp.custom_bwd(device_type="cuda") 868 | def backward(ctx, grad_output): 869 | grad_output = grad_output.contiguous() 870 | 871 | x, y, weight, o, pairwise, partial_norm, expansion_indices = ctx.saved_tensors 872 | 873 | grad_x, grad_y, grad_weight = gelu_fc_geometric_product_norm_bwd( 874 | x, 875 | y, 876 | weight, 877 | o, 878 | pairwise, 879 | partial_norm, 880 | grad_output, 881 | expansion_indices, 882 | ctx.normalize, 883 | ) 884 | 885 | return grad_x, grad_y, grad_weight, None, None, None, None 886 | 887 | 888 | def fused_gelu_fcgp_norm_3d(x, y, weight, normalize=True): 889 | """ 890 | Fused operation that applies GELU non-linearity to two multivector inputs, 891 | then computes their fully connected geometric product, and applies RMSNorm. 892 | 893 | Clifford algebra is assumed to be Cl(3,0). 894 | 895 | Args: 896 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 897 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N). 898 | weight (torch.Tensor): Weight tensor of shape (NUM_PRODUCT_WEIGHTS, N, N), one weight per geometric product component. 899 | normalize (bool): Whether to apply RMSNorm after the geometric product. 900 | 901 | Returns: 902 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation. 903 | """ 904 | return FullyConnectedGeluGeometricProductNorm3D.apply(x, y, weight, normalize) --------------------------------------------------------------------------------
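Usage sketch (not part of the repository): the snippet below shows how the public entry point fused_gelu_fcgp_norm_3d defined above might be called and differentiated. The import path ops.fc_p3m0, the batch size, and the channel count are illustrative assumptions; NUM_PRODUCT_WEIGHTS is the module-level constant referenced in the backward code above, and the multivector dimension is 8 for Cl(3,0) (1 scalar + 3 vector + 3 bivector + 1 pseudoscalar components).

import torch

# Hypothetical import path; adjust to wherever this module lives in the package.
from ops.fc_p3m0 import fused_gelu_fcgp_norm_3d, NUM_PRODUCT_WEIGHTS

MV_DIM = 8        # multivector components of Cl(3,0)
B, N = 4096, 64   # illustrative batch size and number of multivector channels

# Inputs must be contiguous CUDA tensors of shape (MV_DIM, B, N);
# the weight has one (N, N) matrix per geometric product component.
x = torch.randn(MV_DIM, B, N, device="cuda", requires_grad=True)
y = torch.randn(MV_DIM, B, N, device="cuda", requires_grad=True)
weight = torch.randn(NUM_PRODUCT_WEIGHTS, N, N, device="cuda", requires_grad=True)

out = fused_gelu_fcgp_norm_3d(x, y, weight, normalize=True)  # shape (MV_DIM, B, N)
out.sum().backward()  # exercises the Triton backward kernels defined above
print(out.shape, x.grad.shape, weight.grad.shape)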