├── logo.png
├── tests
│ ├── benchmarks
│ │ ├── results
│ │ │ ├── p2m0
│ │ │ │ ├── memory
│ │ │ │ │ ├── fwd.png
│ │ │ │ │ └── fwd_bwd.png
│ │ │ │ └── speedup
│ │ │ │   ├── fwd.png
│ │ │ │   └── fwd_bwd.png
│ │ │ ├── p3m0
│ │ │ │ ├── memory
│ │ │ │ │ ├── fwd.png
│ │ │ │ │ └── fwd_bwd.png
│ │ │ │ └── speedup
│ │ │ │   ├── fwd.png
│ │ │ │   └── fwd_bwd.png
│ │ │ ├── fc_p2m0
│ │ │ │ ├── memory
│ │ │ │ │ ├── fwd.png
│ │ │ │ │ └── fwd_bwd.png
│ │ │ │ └── speedup
│ │ │ │   ├── fwd.png
│ │ │ │   └── fwd_bwd.png
│ │ │ ├── fc_p3m0
│ │ │ │ ├── memory
│ │ │ │ │ ├── fwd.png
│ │ │ │ │ └── fwd_bwd.png
│ │ │ │ └── speedup
│ │ │ │   ├── fwd.png
│ │ │ │   └── fwd_bwd.png
│ │ │ ├── layer_2d
│ │ │ │ ├── memory
│ │ │ │ │ ├── fwd.png
│ │ │ │ │ └── fwd_bwd.png
│ │ │ │ └── speedup
│ │ │ │   ├── fwd.png
│ │ │ │   └── fwd_bwd.png
│ │ │ ├── layer_3d
│ │ │ │ ├── memory
│ │ │ │ │ ├── fwd.png
│ │ │ │ │ └── fwd_bwd.png
│ │ │ │ └── speedup
│ │ │ │   ├── fwd.png
│ │ │ │   └── fwd_bwd.png
│ │ │ ├── clifford_mlp_fwd_bwd_runtime.png
│ │ │ └── clifford_mlp_runtime_scaling.png
│ │ ├── p2m0.py
│ │ ├── p3m0.py
│ │ ├── fc_p2m0.py
│ │ ├── fc_p3m0.py
│ │ ├── layer_2d.py
│ │ └── layer_3d.py
│ ├── p2m0.py
│ ├── p3m0.py
│ ├── fc_p2m0.py
│ ├── fc_p3m0.py
│ ├── cue_baseline.py
│ ├── baselines.py
│ └── utils.py
├── ops
│ ├── __init__.py
│ ├── p2m0.py
│ ├── fc_p2m0.py
│ ├── p3m0.py
│ └── fc_p3m0.py
├── modules
│ ├── layer.py
│ └── baseline.py
└── README.md
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/logo.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p2m0/memory/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/memory/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p3m0/memory/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/memory/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p2m0/memory/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/memory/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p3m0/memory/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/memory/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p2m0/speedup/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/speedup/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p3m0/speedup/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/speedup/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p2m0/speedup/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/speedup/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p3m0/speedup/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/speedup/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_2d/memory/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/memory/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_2d/speedup/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/speedup/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_3d/memory/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/memory/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_3d/speedup/fwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/speedup/fwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p2m0/memory/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/memory/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p2m0/speedup/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p2m0/speedup/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p3m0/memory/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/memory/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/p3m0/speedup/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/p3m0/speedup/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p2m0/memory/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/memory/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p2m0/speedup/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p2m0/speedup/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p3m0/memory/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/memory/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/fc_p3m0/speedup/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/fc_p3m0/speedup/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_2d/memory/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/memory/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_3d/memory/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/memory/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_2d/speedup/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_2d/speedup/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/layer_3d/speedup/fwd_bwd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/layer_3d/speedup/fwd_bwd.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/clifford_mlp_fwd_bwd_runtime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/clifford_mlp_fwd_bwd_runtime.png
--------------------------------------------------------------------------------
/tests/benchmarks/results/clifford_mlp_runtime_scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maxxxzdn/flash-clifford/HEAD/tests/benchmarks/results/clifford_mlp_runtime_scaling.png
--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .p2m0 import fused_gelu_sgp_norm_2d
2 | from .p3m0 import fused_gelu_sgp_norm_3d
3 | from .fc_p2m0 import fused_gelu_fcgp_norm_2d
4 | from .fc_p3m0 import fused_gelu_fcgp_norm_3d
5 |
6 | from .p2m0 import NUM_PRODUCT_WEIGHTS as P2M0_NUM_PRODUCT_WEIGHTS
7 | from .p3m0 import NUM_PRODUCT_WEIGHTS as P3M0_NUM_PRODUCT_WEIGHTS
8 |
9 | from .p2m0 import NUM_GRADES as P2M0_NUM_GRADES
10 | from .p3m0 import NUM_GRADES as P3M0_NUM_GRADES
--------------------------------------------------------------------------------
/tests/p2m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_float32_matmul_precision('medium')
4 |
5 | from ops.p2m0 import fused_gelu_sgp_norm_2d
6 | from tests.baselines import gelu_sgp_norm_2d_torch
7 | from tests.utils import run_correctness_test, run_benchmark
8 |
9 |
10 | if __name__ == "__main__":
11 | assert torch.cuda.is_available()
12 |
13 | rep = 1000
14 | batch_size = 4096
15 | num_features = 512
16 |
17 | x = torch.randn(4, batch_size, num_features).cuda().contiguous()
18 | y = torch.randn(4, batch_size, num_features).cuda().contiguous()
19 | weight = torch.randn(num_features, 10).cuda().contiguous()
20 |
21 | run_correctness_test(fused_gelu_sgp_norm_2d, gelu_sgp_norm_2d_torch, {'x': x, 'y': y, 'weight': weight})
22 | run_benchmark(fused_gelu_sgp_norm_2d, gelu_sgp_norm_2d_torch, (x, y, weight), rep, verbose=True)
--------------------------------------------------------------------------------
/tests/p3m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_float32_matmul_precision('medium')
4 |
5 | from ops.p3m0 import fused_gelu_sgp_norm_3d
6 | from tests.baselines import gelu_sgp_norm_3d_torch
7 | from tests.utils import run_correctness_test, run_benchmark
8 |
9 |
10 | if __name__ == "__main__":
11 | assert torch.cuda.is_available()
12 |
13 | rep = 1000
14 | batch_size = 4096
15 | num_features = 512
16 |
17 | x = torch.randn(8, batch_size, num_features).cuda().contiguous()
18 | y = torch.randn(8, batch_size, num_features).cuda().contiguous()
19 | weight = torch.randn(num_features, 20).cuda().contiguous()
20 |
21 | run_correctness_test(fused_gelu_sgp_norm_3d, gelu_sgp_norm_3d_torch, {'x': x, 'y': y, 'weight': weight})
22 | run_benchmark(fused_gelu_sgp_norm_3d, gelu_sgp_norm_3d_torch, (x, y, weight), rep, verbose=True)
--------------------------------------------------------------------------------
/tests/fc_p2m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_float32_matmul_precision('medium')
4 |
5 | from ops.fc_p2m0 import fused_gelu_fcgp_norm_2d
6 | from tests.baselines import gelu_fcgp_norm_2d_torch
7 | from tests.utils import run_correctness_test, run_benchmark
8 |
9 |
10 | if __name__ == "__main__":
11 | assert torch.cuda.is_available()
12 |
13 | rep = 1000
14 | batch_size = 4096
15 | num_features = 512
16 |
17 | x = torch.randn(4, batch_size, num_features).cuda().contiguous()
18 | y = torch.randn(4, batch_size, num_features).cuda().contiguous()
19 | weight = torch.randn(10, num_features, num_features).cuda().contiguous()
20 |
21 | run_correctness_test(fused_gelu_fcgp_norm_2d, gelu_fcgp_norm_2d_torch, {'x': x, 'y': y, 'weight': weight})
22 | run_benchmark(fused_gelu_fcgp_norm_2d, gelu_fcgp_norm_2d_torch, (x, y, weight), rep, verbose=True)
--------------------------------------------------------------------------------
/tests/fc_p3m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_float32_matmul_precision('medium')
4 |
5 | from ops.fc_p3m0 import fused_gelu_fcgp_norm_3d
6 | from tests.baselines import gelu_fcgp_norm_3d_torch
7 | from tests.utils import run_correctness_test, run_benchmark
8 |
9 |
10 | if __name__ == "__main__":
11 | assert torch.cuda.is_available()
12 |
13 | rep = 1000
14 | batch_size = 4096
15 | num_features = 512
16 |
17 | x = torch.randn(8, batch_size, num_features).cuda().contiguous()
18 | y = torch.randn(8, batch_size, num_features).cuda().contiguous()
19 | weight = torch.randn(20, num_features, num_features).cuda().contiguous()
20 |
21 | run_correctness_test(fused_gelu_fcgp_norm_3d, gelu_fcgp_norm_3d_torch, {'x': x, 'y': y, 'weight': weight})
22 | run_benchmark(fused_gelu_fcgp_norm_3d, gelu_fcgp_norm_3d_torch, (x, y, weight), rep, verbose=True)
--------------------------------------------------------------------------------
/tests/cue_baseline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cuequivariance as cue
3 | import cuequivariance_torch as cuet
4 |
5 |
6 | def mvgelu(x):
7 | """Apply GELU activation gated by scalar component."""
8 | # x has shape (B, N * 8)
9 | b = x.shape[0]
10 | x = x.view(b, -1, 8) # (B, N, 8)
11 | s = x[..., 0:1] # scalar part
12 | gate = 0.5 * (1 + torch.erf(s * 0.7071067811865475))
13 | y = gate * x
14 | y = y.view(b, -1) # (B, N * 8)
15 | return y
16 |
17 |
18 | def initialize_linear(N: int):
19 | """
20 | Initialize MLP with linear layer + weighted GP + GELU + BatchNorm.
21 | Related: https://github.com/NVIDIA/cuEquivariance/issues/194
22 | """
23 | irreps = cue.Irreps("O3", f"{N}x0e + {N}x1o + {N}x1e + {N}x0o")
24 |
25 | ep_weighted = cue.descriptors.fully_connected_tensor_product(
26 | irreps.set_mul(1),
27 | irreps.set_mul(1),
28 | irreps.set_mul(1)
29 | )
30 |
31 | [(_, stp_weighted)] = ep_weighted.polynomial.operations
32 | stp_weighted = stp_weighted.append_modes_to_all_operands("n", dict(n=N))
33 | p_weighted = cue.SegmentedPolynomial.eval_last_operand(stp_weighted)
34 |
35 | weighted_gp = cuet.SegmentedPolynomial(p_weighted, method="uniform_1d").cuda()
36 | linear = cuet.Linear(irreps_in=irreps, irreps_out=irreps, layout=cue.ir_mul, layout_in=cue.ir_mul, layout_out=cue.ir_mul).cuda()
37 | norm = cuet.layers.BatchNorm(irreps=irreps, layout=cue.ir_mul).cuda()
38 |
39 | @torch.compile
40 | def mlp(x, w):
41 | y = linear(x)
42 | x = mvgelu(x)
43 | y = mvgelu(y)
44 | [x] = weighted_gp([w, x, y])
45 | x = norm(x)
46 | return x
47 |
48 | return mlp
--------------------------------------------------------------------------------
/tests/benchmarks/p2m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_float32_matmul_precision('medium')
3 | torch._dynamo.config.cache_size_limit = 512
4 |
5 | from ops.p2m0 import fused_gelu_sgp_norm_2d
6 | from tests.baselines import gelu_sgp_norm_2d_torch
7 | from tests.utils import plot_heatmap, print_results_table, run_sweep
8 |
9 |
10 | def setup_benchmark(batch_size, num_features):
11 | x = torch.randn(4, batch_size, num_features).cuda().contiguous()
12 | y = torch.randn(4, batch_size, num_features).cuda().contiguous()
13 | weight = torch.randn(num_features, 10).cuda().contiguous()
14 | return x, y, weight
15 |
16 |
17 | if __name__ == "__main__":
18 | assert torch.cuda.is_available()
19 |
20 | path = "tests/benchmarks/results/p2m0"
21 |
22 | results = run_sweep(
23 | fused_gelu_sgp_norm_2d,
24 | gelu_sgp_norm_2d_torch,
25 | setup_benchmark,
26 | batch_sizes=[1024, 2048, 4096, 8192],
27 | num_features_list=[128, 256, 512, 1024],
28 | rep=200
29 | )
30 |
31 | print_results_table(results, "p2m0")
32 |
33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(2,0)',
34 | path + '/speedup/fwd.png')
35 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(2,0)',
36 | path + '/speedup/fwd_bwd.png')
37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)',
38 | path + '/memory/fwd.png', invert_cmap=True)
39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)',
40 | path + '/memory/fwd_bwd.png', invert_cmap=True)
41 |
--------------------------------------------------------------------------------
/tests/benchmarks/p3m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_float32_matmul_precision('medium')
3 | torch._dynamo.config.cache_size_limit = 512
4 |
5 | from ops.p3m0 import fused_gelu_sgp_norm_3d
6 | from tests.baselines import gelu_sgp_norm_3d_torch
7 | from tests.utils import plot_heatmap, print_results_table, run_sweep
8 |
9 |
10 | def setup_benchmark(batch_size, num_features):
11 | x = torch.randn(8, batch_size, num_features).cuda().contiguous()
12 | y = torch.randn(8, batch_size, num_features).cuda().contiguous()
13 | weight = torch.randn(num_features, 20).cuda().contiguous()
14 | return x, y, weight
15 |
16 |
17 | if __name__ == "__main__":
18 | assert torch.cuda.is_available()
19 |
20 | path = "tests/benchmarks/results/p3m0"
21 |
22 | results = run_sweep(
23 | fused_gelu_sgp_norm_3d,
24 | gelu_sgp_norm_3d_torch,
25 | setup_benchmark,
26 | batch_sizes=[1024, 2048, 4096, 8192],
27 | num_features_list=[128, 256, 512, 1024],
28 | rep=200
29 | )
30 |
31 | print_results_table(results, "p3m0")
32 |
33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(3,0)',
34 | path + '/speedup/fwd.png')
35 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(3,0)',
36 | path + '/speedup/fwd_bwd.png')
37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)',
38 | path + '/memory/fwd.png', invert_cmap=True)
39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)',
40 | path + '/memory/fwd_bwd.png', invert_cmap=True)
41 |
--------------------------------------------------------------------------------
/tests/benchmarks/fc_p2m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_float32_matmul_precision('medium')
3 | torch._dynamo.config.cache_size_limit = 512
4 |
5 | from ops.fc_p2m0 import fused_gelu_fcgp_norm_2d
6 | from tests.baselines import gelu_fcgp_norm_2d_torch
7 | from tests.utils import plot_heatmap, print_results_table, run_sweep
8 |
9 |
10 | def setup_benchmark(batch_size, num_features):
11 | x = torch.randn(4, batch_size, num_features).cuda().contiguous()
12 | y = torch.randn(4, batch_size, num_features).cuda().contiguous()
13 | weight = torch.randn(10, num_features, num_features).cuda().contiguous()
14 | return x, y, weight
15 |
16 |
17 | if __name__ == "__main__":
18 | assert torch.cuda.is_available()
19 |
20 | path = "tests/benchmarks/results/fc_p2m0"
21 |
22 | results = run_sweep(
23 | fused_gelu_fcgp_norm_2d,
24 | gelu_fcgp_norm_2d_torch,
25 | setup_benchmark,
26 | batch_sizes=[1024, 2048, 4096, 8192],
27 | num_features_list=[128, 256, 512, 1024],
28 | rep=200
29 | )
30 |
31 | print_results_table(results, "fc_p2m0")
32 |
33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(2,0)',
34 | path + '/speedup/fwd.png')
35 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(2,0)',
36 | path + '/speedup/fwd_bwd.png')
37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)',
38 | path + '/memory/fwd.png', invert_cmap=True)
39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)',
40 | path + '/memory/fwd_bwd.png', invert_cmap=True)
41 |
--------------------------------------------------------------------------------
/tests/benchmarks/fc_p3m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_float32_matmul_precision('medium')
3 | torch._dynamo.config.cache_size_limit = 512
4 |
5 | from ops.fc_p3m0 import fused_gelu_fcgp_norm_3d
6 | from tests.baselines import gelu_fcgp_norm_3d_torch
7 | from tests.utils import plot_heatmap, print_results_table, run_sweep
8 |
9 |
10 | def setup_benchmark(batch_size, num_features):
11 | x = torch.randn(8, batch_size, num_features).cuda().contiguous()
12 | y = torch.randn(8, batch_size, num_features).cuda().contiguous()
13 | weight = torch.randn(20, num_features, num_features).cuda().contiguous()
14 | return x, y, weight
15 |
16 |
17 | if __name__ == "__main__":
18 | assert torch.cuda.is_available()
19 |
20 | path = "tests/benchmarks/results/fc_p3m0"
21 |
22 | results = run_sweep(
23 | fused_gelu_fcgp_norm_3d,
24 | gelu_fcgp_norm_3d_torch,
25 | setup_benchmark,
26 | batch_sizes=[1024, 2048, 4096, 8192],
27 | num_features_list=[128, 256, 512, 1024],
28 | rep=200
29 | )
30 |
31 | print_results_table(results, "fc_p3m0")
32 |
33 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(3,0)',
34 | path + '/speedup/fwd.png')
35 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(3,0)',
36 | path + '/speedup/fwd_bwd.png')
37 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)',
38 | path + '/memory/fwd.png', invert_cmap=True)
39 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)',
40 | path + '/memory/fwd_bwd.png', invert_cmap=True)
41 |
--------------------------------------------------------------------------------
/tests/benchmarks/layer_2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_float32_matmul_precision('medium')
3 | torch._dynamo.config.cache_size_limit = 512
4 |
5 | from modules.layer import Layer
6 | from modules.baseline import Layer as BaselineLayer
7 | from tests.utils import plot_heatmap, print_results_table, run_single_benchmark
8 |
9 |
10 | def setup_benchmark(batch_size, num_features):
11 | x = torch.randn(4, batch_size, num_features).cuda().contiguous()
12 | return (x,)
13 |
14 |
15 | def create_layers(num_features: int):
16 | layer_fused = Layer(num_features, dims=2, normalize=True, use_fc=False).float().cuda()
17 | layer_torch = BaselineLayer(num_features, dims=2, normalize=True, use_fc=False).float().cuda()
18 | return layer_fused, layer_torch
19 |
20 |
21 | if __name__ == "__main__":
22 | assert torch.cuda.is_available()
23 |
24 | rep = 1000
25 | warmup = 500
26 | batch_sizes=[1024, 2048, 4096, 8192]
27 | num_features_list=[128, 256, 512, 1024]
28 | path = "tests/benchmarks/results/layer_2d"
29 |
30 | results = []
31 |
32 | print("Running benchmark sweep...")
33 | print(f"Batch sizes: {batch_sizes}")
34 | print(f"Num features: {num_features_list}")
35 |
36 | for batch_size in batch_sizes:
37 | for num_features in num_features_list:
38 | print(f"Running batch_size={batch_size}, num_features={num_features}...", end=" ")
39 | triton_fn, torch_fn = create_layers(num_features)
40 | triton_fn = torch.compile(triton_fn)
41 | torch_fn = torch.compile(torch_fn)
42 |
43 | result = run_single_benchmark(
44 | triton_fn, torch_fn, setup_benchmark, batch_size,
45 | num_features, rep, warmup, verify_correctness=False, return_mode='mean'
46 | )
47 | results.append(result)
48 |
49 | fwd_msg = (f"Fwd: {result['speedup_fwd']:.2f}x"
50 | if result['speedup_fwd'] else "Fwd: OOM")
51 | bwd_msg = (f"Fwd+Bwd: {result['speedup_fwd_bwd']:.2f}x"
52 | if result['speedup_fwd_bwd'] else "Fwd+Bwd: OOM")
53 | print(f"{fwd_msg}, {bwd_msg}")
54 |
55 | print_results_table(results, "layer_2d")
56 |
57 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(2,0)',
58 | path + '/speedup/fwd.png')
59 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(2,0)',
60 | path + '/speedup/fwd_bwd.png')
61 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)',
62 | path + '/memory/fwd.png', invert_cmap=True)
63 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(2,0)',
64 | path + '/memory/fwd_bwd.png', invert_cmap=True)
--------------------------------------------------------------------------------
/tests/benchmarks/layer_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.set_float32_matmul_precision('medium')
3 | torch._dynamo.config.cache_size_limit = 512
4 |
5 | from modules.layer import Layer
6 | from modules.baseline import Layer as BaselineLayer
7 | from tests.utils import plot_heatmap, print_results_table, run_single_benchmark
8 |
9 |
10 | def setup_benchmark(batch_size, num_features):
11 | x = torch.randn(8, batch_size, num_features).cuda().contiguous()
12 | return (x,)
13 |
14 |
15 | def create_layers(num_features: int):
16 | layer_fused = Layer(num_features, dims=3, normalize=True, use_fc=False).float().cuda()
17 | layer_torch = BaselineLayer(num_features, dims=3, normalize=True, use_fc=False).float().cuda()
18 | return layer_fused, layer_torch
19 |
20 |
21 | if __name__ == "__main__":
22 | assert torch.cuda.is_available()
23 |
24 | rep = 1000
25 | warmup = 500
26 | batch_sizes=[1024, 2048, 4096, 8192]
27 | num_features_list=[128, 256, 512, 1024]
28 | path = "tests/benchmarks/results/layer_3d"
29 |
30 | results = []
31 |
32 | print("Running benchmark sweep...")
33 | print(f"Batch sizes: {batch_sizes}")
34 | print(f"Num features: {num_features_list}")
35 |
36 | for batch_size in batch_sizes:
37 | for num_features in num_features_list:
38 | print(f"Running batch_size={batch_size}, num_features={num_features}...", end=" ")
39 | triton_fn, torch_fn = create_layers(num_features)
40 | triton_fn = torch.compile(triton_fn)
41 | torch_fn = torch.compile(torch_fn)
42 |
43 | result = run_single_benchmark(
44 | triton_fn, torch_fn, setup_benchmark, batch_size,
45 | num_features, rep, warmup, verify_correctness=False, return_mode='mean'
46 | )
47 | results.append(result)
48 |
49 | fwd_msg = (f"Fwd: {result['speedup_fwd']:.2f}x"
50 | if result['speedup_fwd'] else "Fwd: OOM")
51 | bwd_msg = (f"Fwd+Bwd: {result['speedup_fwd_bwd']:.2f}x"
52 | if result['speedup_fwd_bwd'] else "Fwd+Bwd: OOM")
53 | print(f"{fwd_msg}, {bwd_msg}")
54 |
55 | print_results_table(results, "layer_3d")
56 |
57 | plot_heatmap(results, 'speedup_fwd', 'Forward Pass Speedup: Triton vs PyTorch\nCl(3,0)',
58 | path + '/speedup/fwd.png')
59 | plot_heatmap(results, 'speedup_fwd_bwd', 'Forward + Backward Pass Speedup: Triton vs PyTorch\nCl(3,0)',
60 | path + '/speedup/fwd_bwd.png')
61 | plot_heatmap(results, 'mem_ratio_fwd', 'Forward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)',
62 | path + '/memory/fwd.png', invert_cmap=True)
63 | plot_heatmap(results, 'mem_ratio_fwd_bwd', 'Forward + Backward Pass Memory Ratio: Fused / PyTorch\nCl(3,0)',
64 | path + '/memory/fwd_bwd.png', invert_cmap=True)
--------------------------------------------------------------------------------
/modules/layer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from ops import fused_gelu_sgp_norm_2d, fused_gelu_sgp_norm_3d, fused_gelu_fcgp_norm_2d, fused_gelu_fcgp_norm_3d, P2M0_NUM_PRODUCT_WEIGHTS, P3M0_NUM_PRODUCT_WEIGHTS, P2M0_NUM_GRADES, P3M0_NUM_GRADES
5 |
6 | _FUSED_OPS = {
7 | (2, False): fused_gelu_sgp_norm_2d,
8 | (2, True): fused_gelu_fcgp_norm_2d,
9 | (3, False): fused_gelu_sgp_norm_3d,
10 | (3, True): fused_gelu_fcgp_norm_3d,
11 | }
12 |
13 | _CONFIG = {
14 | 2: {
15 | 'num_product_weights': P2M0_NUM_PRODUCT_WEIGHTS,
16 | 'num_grades': P2M0_NUM_GRADES,
17 | 'weight_expansion': torch.tensor([0, 1, 1, 2], dtype=torch.long),
18 | },
19 | 3: {
20 | 'num_product_weights': P3M0_NUM_PRODUCT_WEIGHTS,
21 | 'num_grades': P3M0_NUM_GRADES,
22 | 'weight_expansion': torch.tensor([0, 1, 1, 1, 2, 2, 2, 3], dtype=torch.long),
23 | }
24 | }
25 |
26 |
27 | class Layer(torch.nn.Module):
28 | """
29 |     Linear layer: grade-wise linear + multivector GELU + weighted GP + multivector RMSNorm.
30 | Efficient implementation of https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/blob/8482b06b71712dcea2841ebe567d37e7f8432d27/models/nbody_cggnn.py#L47
31 |
32 | Args:
33 | n_features: number of features.
34 | dims: 2 or 3, dimension of the space.
35 | normalize: whether to apply layer normalization at the end.
36 | use_fc: whether to use fully connected GP weights.
37 |             True: weight has shape (num_product_weights, n_features, n_features).
38 | False: weight has shape (n_features, num_product_weights).
39 | """
40 | def __init__(self, n_features, dims, normalize=True, use_fc=False):
41 | super().__init__()
42 |
43 | if dims not in _CONFIG:
44 | raise ValueError(f"Unsupported dims: {dims}")
45 |
46 | config = _CONFIG[dims]
47 | self.normalize = normalize
48 | self.fused_op = _FUSED_OPS[(dims, use_fc)]
49 |
50 | if use_fc:
51 | gp_weight_shape = (config['num_product_weights'], n_features, n_features)
52 | else:
53 | gp_weight_shape = (n_features, config['num_product_weights'])
54 |
55 | self.gp_weight = torch.nn.Parameter(torch.empty(gp_weight_shape))
56 |
57 | linear_weight_shape = (config['num_grades'], n_features, n_features)
58 | self.linear_weight = torch.nn.Parameter(torch.empty(linear_weight_shape))
59 | self.linear_bias = torch.nn.Parameter(torch.zeros(1, 1, n_features))
60 |
61 | self.register_buffer("weight_expansion", config['weight_expansion'])
62 |
63 | torch.nn.init.normal_(self.gp_weight, std=1 / (math.sqrt(dims + 1)))
64 | torch.nn.init.normal_(self.linear_weight, std=1 / math.sqrt(n_features) if use_fc else 1 / math.sqrt(n_features * (dims + 1)))
65 |
66 | def forward(self, x):
67 | y = torch.bmm(x, self.linear_weight[self.weight_expansion]) + self.linear_bias
68 | return self.fused_op(x, y, self.gp_weight, self.normalize)
--------------------------------------------------------------------------------
/modules/baseline.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from ops import P2M0_NUM_PRODUCT_WEIGHTS, P3M0_NUM_PRODUCT_WEIGHTS, P2M0_NUM_GRADES, P3M0_NUM_GRADES
5 | from tests.baselines import gelu_sgp_norm_2d_torch, gelu_sgp_norm_3d_torch, gelu_fcgp_norm_2d_torch, gelu_fcgp_norm_3d_torch
6 |
7 | _FUSED_OPS = {
8 | (2, False): gelu_sgp_norm_2d_torch,
9 | (2, True): gelu_fcgp_norm_2d_torch,
10 | (3, False): gelu_sgp_norm_3d_torch,
11 | (3, True): gelu_fcgp_norm_3d_torch,
12 | }
13 |
14 | _CONFIG = {
15 | 2: {
16 | 'num_product_weights': P2M0_NUM_PRODUCT_WEIGHTS,
17 | 'num_grades': P2M0_NUM_GRADES,
18 | 'weight_expansion': torch.tensor([0, 1, 1, 2], dtype=torch.long),
19 | },
20 | 3: {
21 | 'num_product_weights': P3M0_NUM_PRODUCT_WEIGHTS,
22 | 'num_grades': P3M0_NUM_GRADES,
23 | 'weight_expansion': torch.tensor([0, 1, 1, 1, 2, 2, 2, 3], dtype=torch.long),
24 | }
25 | }
26 |
27 |
28 | class Layer(torch.nn.Module):
29 | """
30 |     Linear layer: grade-wise linear + multivector GELU + weighted GP + multivector RMSNorm.
31 | Metric-specific implementation of https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/blob/8482b06b71712dcea2841ebe567d37e7f8432d27/models/nbody_cggnn.py#L47
32 |
33 | Args:
34 | n_features: number of features.
35 | dims: 2 or 3, dimension of the space.
36 | normalize: whether to apply layer normalization at the end.
37 | use_fc: whether to use fully connected GP weights.
38 |             True: weight has shape (num_product_weights, n_features, n_features).
39 | False: weight has shape (n_features, num_product_weights).
40 | """
41 | def __init__(self, n_features, dims, normalize=True, use_fc=False):
42 | super().__init__()
43 |
44 | if dims not in _CONFIG:
45 | raise ValueError(f"Unsupported dims: {dims}")
46 |
47 | config = _CONFIG[dims]
48 | self.normalize = normalize
49 | self.fused_op = _FUSED_OPS[(dims, use_fc)]
50 |
51 | if use_fc:
52 | gp_weight_shape = (config['num_product_weights'], n_features, n_features)
53 | else:
54 | gp_weight_shape = (n_features, config['num_product_weights'])
55 |
56 | self.gp_weight = torch.nn.Parameter(torch.empty(gp_weight_shape))
57 |
58 | linear_weight_shape = (config['num_grades'], n_features, n_features)
59 | self.linear_weight = torch.nn.Parameter(torch.empty(linear_weight_shape))
60 | self.linear_bias = torch.nn.Parameter(torch.zeros(1, 1, n_features))
61 |
62 | self.register_buffer("weight_expansion", config['weight_expansion'])
63 |
64 | torch.nn.init.normal_(self.gp_weight, std=1 / (math.sqrt(dims + 1)))
65 | torch.nn.init.normal_(self.linear_weight, std=1 / math.sqrt(n_features) if use_fc else 1 / math.sqrt(n_features * (dims + 1)))
66 |
67 | def forward(self, x):
68 | y = torch.bmm(x, self.linear_weight[self.weight_expansion]) + self.linear_bias
69 | return self.fused_op(x, y, self.gp_weight, self.normalize)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Flash Clifford
4 | `flash-clifford` provides efficient Triton-based implementations of Clifford algebra-based models.
5 |
6 |
7 |

8 |
9 |
10 |
11 | ## $O(n)$-Equivariant operators
12 | The list of currently implemented $O(2)$- and $O(3)$-equivariant operators:
13 | - `fused_gelu_sgp_norm_nd`: multivector GELU $\rightarrow$ weighted geometric product $\rightarrow$ (optionally) multivector RMSNorm
14 | - `fused_gelu_fcgp_norm_nd`: multivector GELU $\rightarrow$ fully connected geometric product $\rightarrow$ (optionally) multivector RMSNorm
15 | - `linear layer`: multivector linear $\rightarrow$ `fused_gelu_sgp_norm_nd`
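For a quick feel of the API, the fused ops act directly on raw multivector tensors; below is a minimal sketch for the 3D weighted-GP op, with shapes mirrored from `tests/p3m0.py`:

```python
import torch
from ops import fused_gelu_sgp_norm_3d

# Cl(3,0) multivectors: 8 components per feature, layout (MV_DIM, BATCH, FEATURES)
x = torch.randn(8, 4096, 512).cuda().contiguous()
y = torch.randn(8, 4096, 512).cuda().contiguous()
# one weight per feature for each of the 20 product weights (NUM_PRODUCT_WEIGHTS in 3D)
weight = torch.randn(512, 20).cuda().contiguous()

# multivector GELU -> weighted GP -> (optional) RMSNorm, output shape (8, 4096, 512)
out = fused_gelu_sgp_norm_3d(x, y, weight)
```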
16 |
17 | Any suggestions for different operators are welcome :)
18 |
19 | ## Primer on Clifford Algebra
20 | Clifford algebra is tightly connected to the Euclidean group $E(n)$: its elements are **multivectors**, stacks of graded basis components (scalar, vector, etc.) that correspond 1:1 to irreducible representations of $O(n)$. For example, in 3D:
21 |
22 | | Grades | Irreps |
23 | |------------------|--------|
24 | | 0 (scalar) | 0e |
25 | | 1 (vector) | 1o |
26 | | 2 (bivector) | 1e |
27 | | 3 (trivector) | 0o |
28 |
29 | The **geometric product** is a bilinear operation that takes two multivectors and returns a multivector, mixing information between grades equivariantly ([Ruhe et al., 2023](https://arxiv.org/abs/2305.11141)). It can be viewed as a restricted tensor product, with the key difference that the geometric product does not generate higher-order representations (e.g., frequency 2). While this may limit expressivity, the fixed and simple structure of the geometric product admits very efficient computation and fits in roughly 1k LOC (compared to ~10k LOC in [cuEquivariance](https://github.com/NVIDIA/cuEquivariance)). Empirically, Clifford algebra-based neural networks achieve state-of-the-art performance on [N-body tasks](https://arxiv.org/abs/2305.11141) and [jet tagging](https://arxiv.org/abs/2405.14806), yet they are often dismissed as inefficient, which is exactly what this repo aims to address :).
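Throughout this repo, a multivector feature map is stored as a single tensor whose leading dimension enumerates these components; a small sketch of the Cl(3,0) layout and the scalar-gated multivector GELU, mirroring `tests/baselines.py`:

```python
import torch

x = torch.randn(8, 4096, 512)  # Cl(3,0): 8 components x batch x features

scalar       = x[[0]]        # grade 0 -> 0e
vector       = x[[1, 2, 3]]  # grade 1 -> 1o
bivector     = x[[4, 5, 6]]  # grade 2 -> 1e
pseudoscalar = x[[7]]        # grade 3 -> 0o

# multivector GELU: every component is gated by Phi(scalar), cf. mv_gelu in tests/baselines.py
gate = 0.5 * (1 + torch.erf(scalar * 0.7071067811865475))
x = x * gate
```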
30 |
31 | ## Performance
32 | ### Our approach
33 | The baseline approach taken in [Ruhe et al., 2023](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks) implements the geometric product via a dense einsum `bni, mnijk, bnk -> bmj`, where `mnijk` is a tensor (the Cayley table) encoding how element `i` of the first multivector interacts with element `k` of the second to produce element `j` of the output. This gives a single out-of-the-box implementation for any metric space, which is definitely cool, but it suffers in performance because the Cayley table is extremely sparse (85% zeros in 2D, 95% in 3D). We therefore improve performance primarily by hardcoding the rules of the geometric product, eliminating the wasted work. The second source of optimization is kernel fusion: the activation that typically precedes the geometric product and the normalization that follows it are folded into a single kernel. Finally, a significant speedup comes from switching to the `(MV_DIM, BATCH_SIZE, NUM_FEATURES)` memory layout, which lets the grade-wise linear layer be expressed as a batched matmul (see the sketch below).
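For reference, with the `(MV_DIM, BATCH_SIZE, NUM_FEATURES)` layout the grade-wise linear becomes a single batched matmul; a sketch assuming the 3D configuration from `modules/layer.py` (the `weight_expansion` index maps each of the 8 components to its grade):

```python
import torch

batch, features = 4096, 512
x = torch.randn(8, batch, features)                        # (MV_DIM, BATCH_SIZE, NUM_FEATURES)

linear_weight = torch.randn(4, features, features)         # one matrix per grade in Cl(3,0)
weight_expansion = torch.tensor([0, 1, 1, 1, 2, 2, 2, 3])  # component -> grade

# grade-wise linear as one batched matmul over the 8 multivector components
y = torch.bmm(x, linear_weight[weight_expansion])          # (8, batch, features)
```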
34 |
35 | ### Benchmarking
36 | To demonstrate performance improvements, we benchmark the following linear layer:
37 | ```
38 | input: multivector features x, GP weights w
39 | 1) y = MVLinear(x)
40 | 2) x = MVGELU(x)
41 | 3) y = MVGELU(y)
42 | 4) o = weighted_gp(x, y, w)
43 | 5) o = MVRMSNorm(o)
44 | ```
45 | which is a primitive of a [Clifford Algebra MLP](https://github.com/DavidRuhe/clifford-group-equivariant-neural-networks/blob/8482b06b71712dcea2841ebe567d37e7f8432d27/models/nbody_cggnn.py#L47).
46 | We compare against `cuEquivariance` (using the correspondence between multivectors and irreps) and the baseline CGENN implementation, achieving significant improvements:
47 |
48 |
49 |

50 |

51 |
52 |
53 |
54 |
55 | ## Requirements
56 | The following requirements must be satisfied:
57 | - PyTorch
58 | - Triton >= 3.0
59 |
60 | If you want to run tests, additionally:
61 | - NumPy
62 | - matplotlib
63 |
64 |
65 | ## Usage
66 | ```python
67 | import torch
68 | from modules.layer import Layer
69 |
70 | # Input: multivectors in 3D of shape (8, batch, features)
71 | x = torch.randn(8, 4096, 512).cuda()
72 |
73 | # Linear layer: grade-wise linear + weighted GP
74 | layer = Layer(512, dims=3, normalize=True, use_fc=False).cuda()
75 |
76 | output = layer(x)
77 | ```
78 |
79 | ## Benchmarking
80 |
81 | Run benchmarks (runtime + memory) with:
82 | ```bash
83 | python -m tests.benchmarks.layer_3d
84 | ```
85 | This will generate a heatmap comparison against a torch-compiled implementation; the resulting plots can also be found in [results](tests/benchmarks/results) (measured on an RTX 4500).
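The per-operator sweeps live in the same package and are run the same way, e.g.:
```bash
python -m tests.benchmarks.p2m0      # weighted GP, Cl(2,0)
python -m tests.benchmarks.fc_p3m0   # fully connected GP, Cl(3,0)
```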
86 |
87 | ## Testing
88 |
89 | To verify correctness against a PyTorch baseline:
90 | ```bash
91 | python -m tests.p3m0
92 | ```
93 | which will check both forward and backward (gradient) passes as well as measure the runtime and memory consumption.
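Under the hood this uses the helpers from `tests/utils.py`; a minimal sketch of checking a single op directly, with shapes taken from `tests/p3m0.py`:

```python
import torch
from ops.p3m0 import fused_gelu_sgp_norm_3d
from tests.baselines import gelu_sgp_norm_3d_torch
from tests.utils import run_correctness_test

x = torch.randn(8, 4096, 512).cuda().contiguous()
y = torch.randn(8, 4096, 512).cuda().contiguous()
weight = torch.randn(512, 20).cuda().contiguous()

# compares outputs and input gradients of the Triton op against the compiled PyTorch baseline
run_correctness_test(fused_gelu_sgp_norm_3d, gelu_sgp_norm_3d_torch, {'x': x, 'y': y, 'weight': weight})
```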
94 |
95 | ## Citation
96 | If you find this repository helpful, please cite our work:
97 | ```
98 | @software{flashclifford2025,
99 | title = {Flash Clifford: Hardware-Efficient Implementation of Clifford Algebra Neural Networks},
100 | author = {Zhdanov, Maksim},
101 | url = {https://github.com/maxxxzdn/flash-clifford},
102 | year = {2025}
103 | }
104 | ```
--------------------------------------------------------------------------------
/tests/baselines.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | ### Activations ###
5 |
6 | def mv_gelu(x):
7 | """Apply GELU activation gated by scalar component."""
8 | scalar = x[[0]]
9 | gate = 0.5 * (1 + torch.erf(scalar * 0.7071067811865475))
10 | return x * gate
11 |
12 |
13 | ### Norms ###
14 |
15 | def mv_rmsnorm_2d(x, eps=1e-6):
16 | """RMS normalization for Cl(2,0) (scalar, vector, pseudoscalar)."""
17 | scalar = x[[0]]
18 | vector = x[[1, 2]]
19 | pseudoscalar = x[[3]]
20 |
21 | scalar_rms = (scalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
22 | scalar = scalar / scalar_rms
23 |
24 | vector_norm = vector.norm(dim=0, keepdim=True)
25 | vector_rms = (vector_norm.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
26 | vector = vector / vector_rms
27 |
28 | pseudoscalar_rms = (pseudoscalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
29 | pseudoscalar = pseudoscalar / pseudoscalar_rms
30 |
31 | return torch.cat([scalar, vector, pseudoscalar], dim=0)
32 |
33 |
34 | def mv_rmsnorm_3d(x, eps=1e-6):
35 | """RMS normalization for Cl(3,0) (scalar, vector, bivector, pseudoscalar)."""
36 | scalar = x[[0]]
37 | vector = x[[1, 2, 3]]
38 | bivector = x[[4, 5, 6]]
39 | pseudoscalar = x[[7]]
40 |
41 | scalar_rms = (scalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
42 | scalar = scalar / scalar_rms
43 |
44 | vector_norm = vector.norm(dim=0, keepdim=True)
45 | vector_rms = (vector_norm.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
46 | vector = vector / vector_rms
47 |
48 | bivector_norm = bivector.norm(dim=0, keepdim=True)
49 | bivector_rms = (bivector_norm.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
50 | bivector = bivector / bivector_rms
51 |
52 | pseudoscalar_rms = (pseudoscalar.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
53 | pseudoscalar = pseudoscalar / pseudoscalar_rms
54 |
55 | return torch.cat([scalar, vector, bivector, pseudoscalar], dim=0)
56 |
57 |
58 | ### Geometric Product Layers ###
59 |
60 | def sgp_2d(x, y, weight):
61 | """Weighted geometric product in Cl(2,0)."""
62 | x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
63 | y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
64 |
65 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9 = weight.T
66 |
67 | o0 = w0 * x0 * y0 + w3 * (x1 * y1 + x2 * y2) - w7 * x3 * y3
68 | o1 = w1 * x0 * y1 + w4 * x1 * y0 - w5 * x2 * y3 + w8 * x3 * y2
69 | o2 = w1 * x0 * y2 + w5 * x1 * y3 + w4 * x2 * y0 - w8 * x3 * y1
70 | o3 = w2 * x0 * y3 + w6 * (x1 * y2 - x2 * y1) + w9 * x3 * y0
71 |
72 | return torch.stack([o0, o1, o2, o3], dim=0)
73 |
74 |
75 | def sgp_3d(x, y, weight):
76 | """Weighted geometric product in Cl(3,0)."""
77 | x0, x1, x2, x3, x4, x5, x6, x7 = x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]
78 | y0, y1, y2, y3, y4, y5, y6, y7 = y[0], y[1], y[2], y[3], y[4], y[5], y[6], y[7]
79 |
80 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19 = weight.T
81 |
82 | o0 = (w0*x0*y0 + w4 * (x1*y1 + x2*y2 + x3*y3) - w10 * (x4*y4 + x5*y5 + x6*y6) - w16*x7*y7)
83 | o1 = (w1*x0*y1 + w5*x1*y0 - w6 * (x2*y4 + x3*y5) + w11 * (x4*y2 + x5*y3) - w12*x6*y7 - w17*x7*y6)
84 | o2 = (w1*x0*y2 + w6*x1*y4 + w5*x2*y0 - w6*x3*y6 - w11*x4*y1 + w12*x5*y7 + w11*x6*y3 + w17*x7*y5)
85 | o3 = (w1*x0*y3 + w6 * (x1*y5 + x2*y6) + w5*x3*y0 - w12*x4*y7 - w11 * (x5*y1 + x6*y2) - w17*x7*y4)
86 | o4 = (w2*x0*y4 + w7*x1*y2 - w7*x2*y1 + w8*x3*y7 + w13*x4*y0 - w14*x5*y6 + w14*x6*y5 + w18*x7*y3)
87 | o5 = (w2*x0*y5 + w7*x1*y3 - w8*x2*y7 - w7*x3*y1 + w14*x4*y6 + w13*x5*y0 - w14*x6*y4 - w18*x7*y2)
88 | o6 = (w2*x0*y6 + w8*x1*y7 + w7*x2*y3 - w7*x3*y2 - w14*x4*y5 + w14*x5*y4 + w13*x6*y0 + w18*x7*y1)
89 | o7 = (w3*x0*y7 + w9*x1*y6 - w9*x2*y5 + w9*x3*y4 + w15*x4*y3 - w15*x5*y2 + w15*x6*y1 + w19*x7*y0)
90 |
91 | return torch.stack([o0, o1, o2, o3, o4, o5, o6, o7], dim=0)
92 |
93 |
94 | def fcgp_2d(x, y, weight):
95 | """Fully connected geometric product in Cl(2,0)."""
96 | x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
97 | y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
98 |
99 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9 = weight
100 |
101 | o0 = (x0 * y0) @ w0 + (x1 * y1 + x2 * y2) @ w3 - (x3 * y3) @ w7
102 | o1 = (x0 * y1) @ w1 + (x1 * y0) @ w4 - (x2 * y3) @ w5 + (x3 * y2) @ w8
103 | o2 = (x0 * y2) @ w1 + (x1 * y3) @ w5 + (x2 * y0) @ w4 - (x3 * y1) @ w8
104 |     o3 = (x0 * y3) @ w2 + (x1 * y2 - x2 * y1) @ w6 + (x3 * y0) @ w9
105 |
106 | return torch.stack([o0, o1, o2, o3], dim=0)
107 |
108 |
109 | def fcgp_3d(x, y, weight):
110 | """Fully connected geometric product in Cl(3,0)."""
111 | x0, x1, x2, x3, x4, x5, x6, x7 = x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]
112 | y0, y1, y2, y3, y4, y5, y6, y7 = y[0], y[1], y[2], y[3], y[4], y[5], y[6], y[7]
113 |
114 | w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19 = weight
115 |
116 | o0 = (x0 * y0) @ w0 + (x1 * y1 + x2 * y2 + x3 * y3) @ w4 - (x4 * y4 + x5 * y5 + x6 * y6) @ w10 - (x7 * y7) @ w16
117 | o1 = (x0 * y1) @ w1 + (x1 * y0) @ w5 - (x2 * y4 + x3 * y5) @ w6 + (x4 * y2 + x5 * y3) @ w11 - (x6 * y7) @ w12 - (x7 * y6) @ w17
118 | o2 = (x0 * y2) @ w1 + (x1 * y4) @ w6 + (x2 * y0) @ w5 - (x3 * y6) @ w6 - (x4 * y1) @ w11 + (x5 * y7) @ w12 + (x6 * y3) @ w11 + (x7 * y5) @ w17
119 | o3 = (x0 * y3) @ w1 + (x1 * y5 + x2 * y6) @ w6 + (x3 * y0) @ w5 - (x4 * y7) @ w12 - (x5 * y1 + x6 * y2) @ w11 - (x7 * y4) @ w17
120 | o4 = (x0 * y4) @ w2 + (x1 * y2 - x2 * y1) @ w7 + (x3 * y7) @ w8 + (x4 * y0) @ w13 - (x5 * y6) @ w14 + (x6 * y5) @ w14 + (x7 * y3) @ w18
121 | o5 = (x0 * y5) @ w2 + (x1 * y3) @ w7 - (x2 * y7) @ w8 - (x3 * y1) @ w7 + (x4 * y6) @ w14 + (x5 * y0) @ w13 - (x6 * y4) @ w14 - (x7 * y2) @ w18
122 | o6 = (x0 * y6) @ w2 + (x1 * y7) @ w8 + (x2 * y3) @ w7 - (x3 * y2) @ w7 - (x4 * y5) @ w14 + (x5 * y4) @ w14 + (x6 * y0) @ w13 + (x7 * y1) @ w18
123 | o7 = (x0 * y7) @ w3 + (x1 * y6 - x2 * y5 + x3 * y4) @ w9 + (x4 * y3 - x5 * y2 + x6 * y1) @ w15 + (x7 * y0) @ w19
124 |
125 | return torch.stack([o0, o1, o2, o3, o4, o5, o6, o7], dim=0)
126 |
127 |
128 | ### Baseline implementations ###
129 |
130 | @torch.compile
131 | def gelu_sgp_norm_2d_torch(x, y, weight, normalize=True):
132 | """Geometric product layer with GELU activation and RMS normalization in Cl(2,0)."""
133 | x = mv_gelu(x)
134 | y = mv_gelu(y)
135 | o = sgp_2d(x, y, weight)
136 | if normalize:
137 | o = mv_rmsnorm_2d(o)
138 | return o
139 |
140 |
141 | @torch.compile
142 | def gelu_sgp_norm_3d_torch(x, y, weight, normalize=True):
143 | """Geometric product layer with GELU activation and RMS normalization in Cl(3,0)."""
144 | x = mv_gelu(x)
145 | y = mv_gelu(y)
146 | o = sgp_3d(x, y, weight)
147 | if normalize:
148 | o = mv_rmsnorm_3d(o)
149 | return o
150 |
151 |
152 | @torch.compile
153 | def gelu_fcgp_norm_2d_torch(x, y, weight, normalize=True):
154 | """Fully connected geometric product layer with GELU activation and RMS normalization in Cl(2,0)."""
155 | x = mv_gelu(x)
156 | y = mv_gelu(y)
157 | o = fcgp_2d(x, y, weight)
158 | if normalize:
159 | o = mv_rmsnorm_2d(o)
160 | return o
161 |
162 |
163 | @torch.compile
164 | def gelu_fcgp_norm_3d_torch(x, y, weight, normalize=True):
165 | """Fully connected geometric product layer with GELU activation and RMS normalization in Cl(3,0)."""
166 | x = mv_gelu(x)
167 | y = mv_gelu(y)
168 | o = fcgp_3d(x, y, weight)
169 | if normalize:
170 | o = mv_rmsnorm_3d(o)
171 | return o
172 |
173 |
174 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 | import triton
7 |
8 | from contextlib import contextmanager
9 |
10 |
11 | @contextmanager
12 | def measure_memory():
13 | torch.cuda.reset_peak_memory_stats()
14 | torch.cuda.empty_cache()
15 | torch.cuda.synchronize()
16 |
17 | peak = [None]
18 | yield peak
19 |
20 | torch.cuda.synchronize()
21 | peak[0] = torch.cuda.max_memory_allocated() / 1024**2
22 |
23 |
24 | def print_benchmark_results(avg_time_fused, avg_time_torch, mem_fused, mem_torch, title=""):
25 | header_length_pad = max(0, 35 - len(title) - 4)
26 | print(f"\n┌─ {title} " + "─"*header_length_pad + "┐")
27 | print("│")
28 | print(f"│ Runtime:")
29 | print(f"│ Fused Kernel : {avg_time_fused:>8.2f} ms")
30 | print(f"│ PyTorch : {avg_time_torch:>8.2f} ms")
31 | print(f"│ Speedup : {avg_time_torch / avg_time_fused:>8.2f}×")
32 | print("│")
33 | print(f"│ Memory Usage:")
34 | print(f"│ Fused Kernel : {mem_fused:>8.2f} MB")
35 | print(f"│ PyTorch : {mem_torch:>8.2f} MB")
36 | print(f"│ Memory Ratio : {mem_fused / mem_torch:>8.2f}×")
37 | print("│")
38 | print("└" + "─"*34 + "┘\n")
39 |
40 |
41 | def run_benchmark(triton_fn, torch_fn, args, rep, warmup=200, verbose=True, return_mode='mean'):
42 | """Run forward and forward+backward benchmarks."""
43 | # Forward-only benchmark
44 | avg_time_fused = triton.testing.do_bench(lambda: triton_fn(*args), warmup, rep, return_mode=return_mode)
45 | avg_time_torch = triton.testing.do_bench(lambda: torch_fn(*args), warmup, rep, return_mode=return_mode)
46 | with measure_memory() as mem_fused_fwd: _ = triton_fn(*args)
47 | with measure_memory() as mem_torch_fwd: _ = torch_fn(*args)
48 |
49 | # Forward + backward benchmark
50 | args = [arg.clone().detach().requires_grad_(True) for arg in args]
51 | avg_time_fused_bwd = triton.testing.do_bench(lambda: triton_fn(*args).sum().backward(), warmup, rep, return_mode=return_mode)
52 | avg_time_torch_bwd = triton.testing.do_bench(lambda: torch_fn(*args).sum().backward(), warmup, rep, return_mode=return_mode)
53 |     with measure_memory() as mem_fused_bwd: triton_fn(*args).sum().backward()
54 |     with measure_memory() as mem_torch_bwd: torch_fn(*args).sum().backward()
55 |
56 | if verbose:
57 | print_benchmark_results(
58 | avg_time_fused, avg_time_torch, mem_fused_fwd[0], mem_torch_fwd[0],
59 | title="Forward Pass"
60 | )
61 | print_benchmark_results(
62 | avg_time_fused_bwd, avg_time_torch_bwd, mem_fused_bwd[0], mem_torch_bwd[0],
63 | title="Forward + Backward Pass"
64 | )
65 |
66 | return avg_time_fused, avg_time_torch, mem_fused_fwd[0], mem_torch_fwd[0], avg_time_fused_bwd, avg_time_torch_bwd, mem_fused_bwd[0], mem_torch_bwd[0]
67 |
68 |
69 | def run_correctness_test(triton_fn, torch_fn, kwargs):
70 | """Run forward and backward correctness test."""
71 | # Forward correctness check
72 | out_triton = triton_fn(**kwargs)
73 | out_torch = torch_fn(**kwargs)
74 |
75 | max_diff = (out_torch - out_triton).abs().max().item()
76 | is_correct = torch.allclose(out_torch, out_triton, atol=1e-5)
77 | check_mark = " ✔" if is_correct else " ✘"
78 | print(f"Max absolute difference (fwd): {max_diff:.1e}{check_mark}")
79 |
80 |     # Backward correctness check: compare gradients of the two implementations
81 |     kwargs_torch = {k: v.clone().detach().requires_grad_(True) for k, v in kwargs.items()}
82 |     kwargs_triton = {k: v.clone().detach().requires_grad_(True) for k, v in kwargs.items()}
83 | 
84 |     out_torch = torch_fn(**kwargs_torch)
85 |     out_triton = triton_fn(**kwargs_triton)
86 |     grad_output = torch.randn_like(out_torch).contiguous()
87 |     out_torch.backward(grad_output)
88 |     out_triton.backward(grad_output)
89 | 
90 |     for name in kwargs:
91 |         grad_diff = (kwargs_torch[name].grad - kwargs_triton[name].grad).abs().max().item()
92 |         grad_correct = torch.allclose(kwargs_torch[name].grad, kwargs_triton[name].grad, atol=1e-2)
93 |         print(f"grad {name} max diff: {grad_diff:.1e}" + (" ✔" if grad_correct else " ✘"))
94 |
95 |
96 | def plot_heatmap(results, metric_key, title, save_path, cmap='RdYlGn', invert_cmap=False):
97 | """Heatmap visualization of benchmark results."""
98 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
99 |
100 | batch_sizes = sorted(set(r['batch_size'] for r in results))
101 | num_features_list = sorted(set(r['num_features'] for r in results))
102 |
103 | matrix = np.zeros((len(batch_sizes), len(num_features_list)))
104 | for r in results:
105 | i = batch_sizes.index(r['batch_size'])
106 | j = num_features_list.index(r['num_features'])
107 | matrix[i, j] = r[metric_key] if r[metric_key] is not None else 0
108 |
109 | fig, ax = plt.subplots(figsize=(10, 8))
110 | cmap_name = f'{cmap}_r' if invert_cmap else cmap
111 | im = ax.imshow(matrix, cmap=cmap_name, aspect='auto')
112 |
113 | ax.set_xticks(np.arange(len(num_features_list)))
114 | ax.set_yticks(np.arange(len(batch_sizes)))
115 | ax.set_xticklabels(num_features_list)
116 | ax.set_yticklabels(batch_sizes)
117 |
118 | ax.set_xlabel('Number of Features', fontsize=12)
119 | ax.set_ylabel('Batch Size', fontsize=12)
120 | ax.set_title(title, fontsize=14, pad=20)
121 |
122 | cbar = plt.colorbar(im, ax=ax)
123 | cbar_label = 'Speedup (x)' if 'speedup' in metric_key else 'Memory Ratio (x)'
124 | cbar.set_label(cbar_label, fontsize=12)
125 |
126 | # Add text annotations
127 | for i in range(len(batch_sizes)):
128 | for j in range(len(num_features_list)):
129 | value = matrix[i, j]
130 | text_str = f'{value:.2f}x' if value > 0 else 'OOM'
131 | ax.text(j, i, text_str, ha="center", va="center", color="black", fontsize=10)
132 |
133 | plt.tight_layout()
134 | plt.savefig(save_path, dpi=300, bbox_inches='tight')
135 | print(f"Heatmap saved to: {save_path}")
136 | plt.close()
137 |
138 |
139 | def print_results_table(results, title):
140 | """Print results as a formatted table."""
141 | separator = "=" * 140
142 | divider = "-" * 140
143 |
144 | # Forward pass results
145 | print(f"\n{separator}")
146 | print(f"FORWARD PASS RESULTS - {title}")
147 | print(separator)
148 | print(f"{'Batch':<8} {'Features':<10} {'Fused (ms)':<12} {'Torch (ms)':<12} "
149 | f"{'Speedup':<10} {'Max Diff':<12} {'Correct':<8}")
150 | print(divider)
151 |
152 | for r in results:
153 | speedup_str = f"{r['speedup_fwd']:.2f}x" if r['speedup_fwd'] else "N/A"
154 | correct_mark = '✔' if r['is_correct'] else '✘'
155 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {r['time_fwd_fused']:<12.2f} "
156 | f"{r['time_fwd_torch']:<12.2f} {speedup_str:<10} {r['max_diff']:<12.1e} "
157 | f"{correct_mark:<8}")
158 |
159 | print(separator)
160 |
161 | # Forward + backward pass results
162 | print(f"\n{separator}")
163 | print(f"FORWARD + BACKWARD PASS RESULTS - {title}")
164 | print(separator)
165 | print(f"{'Batch':<8} {'Features':<10} {'Fused (ms)':<12} {'Torch (ms)':<12} {'Speedup':<10}")
166 | print(divider)
167 |
168 | for r in results:
169 | fused_time = f"{r['time_fwd_bwd_fused']:.2f}" if r['time_fwd_bwd_fused'] else "OOM"
170 | torch_time = f"{r['time_fwd_bwd_torch']:.2f}" if r['time_fwd_bwd_torch'] else "OOM"
171 | speedup = f"{r['speedup_fwd_bwd']:.2f}x" if r['speedup_fwd_bwd'] else "N/A"
172 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {fused_time:<12} "
173 | f"{torch_time:<12} {speedup:<10}")
174 |
175 | print(separator)
176 |
177 | # Forward memory usage
178 | print(f"\n{separator}")
179 | print(f"FORWARD MEMORY USAGE - {title}")
180 | print(separator)
181 | print(f"{'Batch':<8} {'Features':<10} {'Fused (MB)':<12} {'Torch (MB)':<12} {'Ratio':<10}")
182 | print(divider)
183 |
184 | for r in results:
185 | fused_mem = f"{r['mem_fwd_fused']:.2f}" if r['mem_fwd_fused'] else "OOM"
186 | torch_mem = f"{r['mem_fwd_torch']:.2f}" if r['mem_fwd_torch'] else "OOM"
187 | ratio = f"{r['mem_ratio_fwd']:.2f}x" if r['mem_ratio_fwd'] else "N/A"
188 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {fused_mem:<12} "
189 | f"{torch_mem:<12} {ratio:<10}")
190 |
191 | print(separator)
192 |
193 | # Forward + backward memory usage
194 | print(f"\n{separator}")
195 | print(f"FORWARD + BACKWARD MEMORY USAGE - {title}")
196 | print(separator)
197 | print(f"{'Batch':<8} {'Features':<10} {'Fused (MB)':<12} {'Torch (MB)':<12} {'Ratio':<10}")
198 | print(divider)
199 |
200 | for r in results:
201 | fused_mem = f"{r['mem_fwd_bwd_fused']:.2f}" if r['mem_fwd_bwd_fused'] else "OOM"
202 | torch_mem = f"{r['mem_fwd_bwd_torch']:.2f}" if r['mem_fwd_bwd_torch'] else "OOM"
203 | ratio = f"{r['mem_ratio_fwd_bwd']:.2f}x" if r['mem_ratio_fwd_bwd'] else "N/A"
204 | print(f"{r['batch_size']:<8} {r['num_features']:<10} {fused_mem:<12} "
205 | f"{torch_mem:<12} {ratio:<10}")
206 |
207 | print(separator)
208 |
209 |
210 | def run_single_benchmark(triton_fn, torch_fn, setup_fn, batch_size, num_features, rep, warmup=200, verify_correctness=True, return_mode='mean'):
211 | """Run a single benchmark configuration."""
212 | args = setup_fn(batch_size, num_features)
213 |
214 | out_triton = triton_fn(*args)
215 | out_torch = torch_fn(*args)
216 |
217 | if verify_correctness:
218 | is_correct = torch.allclose(out_torch, out_triton, atol=1e-5)
219 | max_diff = (out_torch - out_triton).abs().max().item()
220 | else:
221 |         is_correct = False  # not checked when verification is skipped
222 |         max_diff = 1e10  # sentinel value so table formatting still works
223 |
224 | time_fwd_fused, time_fwd_torch, mem_fwd_fused, mem_fwd_torch, \
225 | time_fwd_bwd_fused, time_fwd_bwd_torch, mem_fwd_bwd_fused, mem_fwd_bwd_torch = \
226 | run_benchmark(
227 | triton_fn, torch_fn, args, rep,
228 | warmup=warmup, verbose=False, return_mode=return_mode
229 | )
230 |
231 | return {
232 | 'batch_size': batch_size,
233 | 'num_features': num_features,
234 | 'time_fwd_fused': time_fwd_fused,
235 | 'time_fwd_torch': time_fwd_torch,
236 | 'speedup_fwd': time_fwd_torch / time_fwd_fused if time_fwd_fused else None,
237 | 'time_fwd_bwd_fused': time_fwd_bwd_fused,
238 | 'time_fwd_bwd_torch': time_fwd_bwd_torch,
239 | 'speedup_fwd_bwd': time_fwd_bwd_torch / time_fwd_bwd_fused if time_fwd_bwd_fused else None,
240 | 'mem_fwd_fused': mem_fwd_fused,
241 | 'mem_fwd_torch': mem_fwd_torch,
242 | 'mem_ratio_fwd': mem_fwd_fused / mem_fwd_torch if mem_fwd_torch else None,
243 | 'mem_fwd_bwd_fused': mem_fwd_bwd_fused,
244 | 'mem_fwd_bwd_torch': mem_fwd_bwd_torch,
245 | 'mem_ratio_fwd_bwd': mem_fwd_bwd_fused / mem_fwd_bwd_torch if mem_fwd_bwd_torch else None,
246 | 'max_diff': max_diff,
247 | 'is_correct': is_correct,
248 | }
249 |
250 |
251 | def run_sweep(triton_fn, torch_fn, setup_fn,
252 | batch_sizes=[1024, 2048, 4096, 8192],
253 | num_features_list=[128, 256, 512, 1024],
254 | rep=1000):
255 | """Run benchmark sweep across batch sizes and feature dimensions."""
256 | results = []
257 |
258 | print("Running benchmark sweep...")
259 | print(f"Batch sizes: {batch_sizes}")
260 | print(f"Num features: {num_features_list}")
261 |
262 | for batch_size in batch_sizes:
263 | for num_features in num_features_list:
264 | print(f"Running batch_size={batch_size}, num_features={num_features}...", end=" ")
265 | result = run_single_benchmark(triton_fn, torch_fn, setup_fn, batch_size, num_features, rep)
266 | results.append(result)
267 |
268 | fwd_msg = (f"Fwd: {result['speedup_fwd']:.2f}x"
269 | if result['speedup_fwd'] else "Fwd: OOM")
270 | bwd_msg = (f"Fwd+Bwd: {result['speedup_fwd_bwd']:.2f}x"
271 | if result['speedup_fwd_bwd'] else "Fwd+Bwd: OOM")
272 | print(f"{fwd_msg}, {bwd_msg}")
273 |
274 | return results
--------------------------------------------------------------------------------
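A minimal sketch (not part of the repository) of wiring the sweep utilities above to the fused Cl(2,0) op defined in ops/p2m0.py below. It assumes the repo root is on PYTHONPATH and a CUDA device is available; `torch_reference` is a placeholder for whichever eager-mode baseline is being compared against.

import torch
from ops.p2m0 import fused_gelu_sgp_norm_2d, MV_DIM, NUM_PRODUCT_WEIGHTS

def setup_p2m0(batch_size, num_features):
    # inputs follow the (MV_DIM, B, N) layout expected by the kernels
    x = torch.randn(MV_DIM, batch_size, num_features, device="cuda", requires_grad=True)
    y = torch.randn(MV_DIM, batch_size, num_features, device="cuda", requires_grad=True)
    w = torch.randn(num_features, NUM_PRODUCT_WEIGHTS, device="cuda", requires_grad=True)
    return x, y, w, True  # normalize=True

# results = run_sweep(fused_gelu_sgp_norm_2d, torch_reference, setup_p2m0)
# print_results_table(results, "Cl(2,0) weighted geometric product")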
/ops/p2m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | MV_DIM = 4
6 | NUM_GRADES = 3
7 | NUM_PRODUCT_WEIGHTS = 10
8 | EPS = 1e-6
9 |
10 | # tuned on an RTX 4500
11 | DEFAULT_BATCH_BLOCK = 4
12 | DEFAULT_FEATURE_BLOCK = 128
13 | DEFAULT_NUM_WARPS = 16
14 | DEFAULT_NUM_STAGES = 1
15 |
16 |
17 | @triton.jit
18 | def compute_gelu_gate(x):
19 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))"""
20 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype)
21 |
22 |
23 | @triton.jit
24 | def compute_gelu_gate_grad(x):
25 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)"""
26 | return 0.3989422804 * tl.exp(-0.5 * x * x)
27 |
28 |
29 | @triton.jit
30 | def gelu_wgp_norm_kernel_fwd(
31 | x_ptr,
32 | y_ptr,
33 | output_ptr,
34 | weights_ptr,
35 | pnorm_ptr,
36 | NORMALIZE: tl.constexpr,
37 | batch_size: tl.constexpr,
38 | n_features: tl.constexpr,
39 | BATCH_BLOCK: tl.constexpr,
40 | FEATURE_BLOCK: tl.constexpr,
41 | NUM_PRODUCT_WEIGHTS: tl.constexpr,
42 | ):
43 | """
44 | Apply GELU non-linearity to inputs, compute weighted geometric product,
45 | and accumulate squared norms for grade-wise RMSNorm.
46 | """
47 | batch_block_id = tl.program_id(axis=0)
48 | thread_block_id = tl.program_id(axis=1)
49 |
50 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
51 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
52 |
53 | batch_mask = batch_ids < batch_size
54 | feature_mask = feature_ids < n_features
55 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
56 |
57 | stride_component = batch_size * n_features
58 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
59 |
60 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS
61 |
62 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask)
63 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask)
64 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask)
65 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask)
66 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask)
67 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask)
68 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask)
69 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask)
70 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask)
71 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask)
72 |
73 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
74 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
75 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
76 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
77 |
78 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
79 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
80 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
81 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
82 |
83 | # Apply GELU gate
84 | gate_x = compute_gelu_gate(x0)
85 | gate_y = compute_gelu_gate(y0)
86 |
87 | x0 = x0 * gate_x
88 | x1 = x1 * gate_x
89 | x2 = x2 * gate_x
90 | x3 = x3 * gate_x
91 |
92 | y0 = y0 * gate_y
93 | y1 = y1 * gate_y
94 | y2 = y2 * gate_y
95 | y3 = y3 * gate_y
96 |
97 | # Compute geometric product
98 | o0 = w0*x0*y0 + w3*(x1*y1 + x2*y2) - w7*x3*y3
99 | o1 = w1*x0*y1 + w4*x1*y0 - w5*x2*y3 + w8*x3*y2
100 | o2 = w1*x0*y2 + w5*x1*y3 + w4*x2*y0 - w8*x3*y1
101 | o3 = w2*x0*y3 + w6*(x1*y2 - x2*y1) + w9*x3*y0
102 |
103 | if NORMALIZE:
104 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features
105 | pn_vector = tl.sum(o1*o1 + o2*o2, axis=1) / n_features
106 | pn_pseudo = tl.sum(o3 * o3, axis=1) / n_features
107 |
108 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask)
109 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask)
110 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask)
111 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_pseudo, mask=batch_mask)
112 |
113 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask)
114 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask)
115 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask)
116 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask)
117 |
118 |
119 | @triton.jit
120 | def normalize_with_sqrt_kernel(
121 | output_ptr,
122 | pnorm_ptr,
123 | batch_size: tl.constexpr,
124 | n_features: tl.constexpr,
125 | BATCH_BLOCK: tl.constexpr,
126 | FEATURE_BLOCK: tl.constexpr,
127 | MV_DIM: tl.constexpr,
128 | EPS: tl.constexpr,
129 | ):
130 |     """Normalize the output by dividing each grade by the square root of its accumulated norm."""
131 | batch_block_id = tl.program_id(axis=0)
132 | thread_block_id = tl.program_id(axis=1)
133 |
134 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
135 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
136 |
137 | batch_mask = batch_ids < batch_size
138 | feature_mask = feature_ids < n_features
139 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None]
140 |
141 | component_ids = tl.arange(0, MV_DIM)[None, None, :]
142 |
143 | feature_offset = (component_ids * batch_size * n_features +
144 | batch_ids[:, None, None] * n_features +
145 | feature_ids[None, :, None])
146 |
147 | norm_indices = component_ids * batch_size + batch_ids[:, None, None]
148 |
149 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None])
150 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask)
151 |
152 | norm = tl.sqrt(pnorm + EPS)
153 | mv_normalized = mv / norm
154 |
155 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask)
156 |
157 |
158 | def gelu_geometric_product_norm_fwd(
159 | x: torch.Tensor,
160 | y: torch.Tensor,
161 | weight: torch.Tensor,
162 | normalize: bool,
163 | ) -> tuple[torch.Tensor, torch.Tensor]:
164 | """Fused operation: GELU non-linearity, weighted geometric product, and grade-wise RMSNorm."""
165 | assert x.shape == y.shape
166 | assert x.shape[0] == MV_DIM
167 | assert x.shape[2] == weight.shape[0]
168 | assert weight.shape[1] == NUM_PRODUCT_WEIGHTS
169 |
170 | _, B, N = x.shape
171 |
172 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
173 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
174 |
175 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
176 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
177 |
178 | output = torch.empty_like(x)
179 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize
180 | else torch.zeros((1,), device=x.device, dtype=x.dtype))
181 |
182 | grid = (num_blocks_batch, num_blocks_features)
183 |
184 | gelu_wgp_norm_kernel_fwd[grid](
185 | x,
186 | y,
187 | output,
188 | weight,
189 | partial_norm,
190 | normalize,
191 | B,
192 | N,
193 | BATCH_BLOCK,
194 | FEATURE_BLOCK,
195 | NUM_PRODUCT_WEIGHTS,
196 | num_warps=DEFAULT_NUM_WARPS,
197 | num_stages=DEFAULT_NUM_STAGES,
198 | )
199 |
200 | if normalize:
201 | normalize_with_sqrt_kernel[grid](
202 | output,
203 | partial_norm,
204 | B,
205 | N,
206 | BATCH_BLOCK,
207 | FEATURE_BLOCK,
208 | MV_DIM,
209 | EPS,
210 | num_warps=DEFAULT_NUM_WARPS,
211 | num_stages=DEFAULT_NUM_STAGES,
212 | )
213 |
214 | return output, partial_norm
215 |
216 |
217 | @triton.jit
218 | def grad_o_dot_o_kernel(
219 | dot_ptr,
220 | pnorm_ptr,
221 | output_ptr,
222 | grad_output_ptr,
223 | batch_size: tl.constexpr,
224 | n_features: tl.constexpr,
225 | BATCH_BLOCK: tl.constexpr,
226 | FEATURE_BLOCK: tl.constexpr,
227 | EPS: tl.constexpr,
228 | ):
229 | """Compute the dot product of grad_output and output for each grade, accumulate across all features."""
230 | batch_block_id = tl.program_id(axis=0)
231 | thread_block_id = tl.program_id(axis=1)
232 |
233 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
234 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
235 |
236 | batch_mask = batch_ids < batch_size
237 | feature_mask = feature_ids < n_features
238 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
239 |
240 | stride_component = batch_size * n_features
241 | offset = batch_ids[:, None] * n_features + feature_ids[None, :]
242 |
243 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
244 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
245 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
246 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
247 |
248 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
249 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
250 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None]
251 |
252 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
253 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
254 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
255 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
256 |
257 | rms_scalar = tl.sqrt(pn_scalar + EPS)
258 | rms_vector = tl.sqrt(pn_vector + EPS)
259 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
260 |
261 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1)
262 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2), axis=1)
263 | dot_pseudo = tl.sum(rms_pseudo * go3 * o3, axis=1)
264 |
265 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask)
266 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask)
267 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_pseudo, mask=batch_mask)
268 |
269 |
270 | @triton.jit
271 | def gelu_wgp_norm_kernel_bwd(
272 | x_ptr,
273 | y_ptr,
274 | output_ptr,
275 | weights_ptr,
276 | dot_ptr,
277 | pnorm_ptr,
278 | grad_output_ptr,
279 | grad_x_ptr,
280 | grad_y_ptr,
281 | grad_weight_ptr,
282 | NORMALIZE: tl.constexpr,
283 | batch_size: tl.constexpr,
284 | n_features: tl.constexpr,
285 | BATCH_BLOCK: tl.constexpr,
286 | FEATURE_BLOCK: tl.constexpr,
287 | NUM_PRODUCT_WEIGHTS: tl.constexpr,
288 | EPS: tl.constexpr,
289 | ):
290 | """Compute gradients w.r.t. inputs and weights of the fused operation."""
291 | batch_block_id = tl.program_id(axis=0)
292 | thread_block_id = tl.program_id(axis=1)
293 |
294 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
295 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
296 |
297 | batch_mask = batch_ids < batch_size
298 | feature_mask = feature_ids < n_features
299 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
300 |
301 | stride_component = batch_size * n_features
302 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
303 |
304 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS
305 | block_offset = batch_block_id * n_features * NUM_PRODUCT_WEIGHTS
306 |
307 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask)
308 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask)
309 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask)
310 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask)
311 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask)
312 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask)
313 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask)
314 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask)
315 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask)
316 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask)
317 |
318 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
319 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
320 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
321 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
322 |
323 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
324 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
325 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
326 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
327 |
328 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
329 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
330 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
331 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
332 |
333 | if NORMALIZE:
334 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
335 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
336 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
337 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
338 |
339 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
340 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
341 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None]
342 |
343 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
344 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
345 | dot_pseudo = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None]
346 |
347 | rms_scalar = tl.sqrt(pn_scalar + EPS)
348 | rms_vector = tl.sqrt(pn_vector + EPS)
349 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
350 |
351 | go0 = go0/rms_scalar - o0 * dot_scalar / (n_features*rms_scalar*rms_scalar)
352 | go1 = go1/rms_vector - o1 * dot_vector / (n_features*rms_vector*rms_vector)
353 | go2 = go2/rms_vector - o2 * dot_vector / (n_features*rms_vector*rms_vector)
354 | go3 = go3/rms_pseudo - o3 * dot_pseudo / (n_features*rms_pseudo*rms_pseudo)
355 |
356 | # weighted geometric product backward
357 | gate_x = compute_gelu_gate(x0_raw)
358 | gate_y = compute_gelu_gate(y0_raw)
359 |
360 | x0 = x0_raw * gate_x
361 | x1 = x1_raw * gate_x
362 | x2 = x2_raw * gate_x
363 | x3 = x3_raw * gate_x
364 |
365 | y0 = y0_raw * gate_y
366 | y1 = y1_raw * gate_y
367 | y2 = y2_raw * gate_y
368 | y3 = y3_raw * gate_y
369 |
370 | tmp0 = go0*w0
371 | tmp1 = go1*w1
372 | tmp2 = go2*w1
373 | tmp3 = go0*w3
374 | tmp4 = w4*y0
375 | tmp5 = w5*y3
376 | tmp6 = go3*w6
377 | tmp7 = go0*w7
378 | tmp8 = go2*w8
379 | tmp9 = go1*x1
380 | tmp10 = go2*x2
381 | tmp11 = go1*x2
382 |
383 | x_grad_0 = (go3*w2*y3 + tmp0*y0 + tmp1*y1 + tmp2*y2)
384 | x_grad_1 = (go1*tmp4 + go2*tmp5 + tmp3*y1 + tmp6*y2)
385 | x_grad_2 = (-go1*tmp5 + go2*tmp4 + tmp3*y2 - tmp6*y1)
386 | x_grad_3 = (go1*w8*y2 + go3*w9*y0 - tmp7*y3 - tmp8*y1)
387 |
388 | y_grad_0 = (go3*w9*x3 + tmp0*x0 + tmp10*w4 + tmp9*w4)
389 | y_grad_1 = (tmp1*x0 + tmp3*x1 - tmp6*x2 - tmp8*x3)
390 | y_grad_2 = (go1*w8*x3 + tmp2*x0 + tmp3*x2 + tmp6*x1)
391 | y_grad_3 = (go2*w5*x1 + go3*w2*x0 - tmp11*w5 - tmp7*x3)
392 |
393 | w_grad_0 = tl.sum(go0*x0*y0, axis=0)
394 | w_grad_1 = tl.sum(go1*x0*y1 + go2*x0*y2, axis=0)
395 | w_grad_2 = tl.sum(go3*x0*y3, axis=0)
396 | w_grad_3 = tl.sum(go0*(x1*y1 + x2*y2), axis=0)
397 | w_grad_4 = tl.sum(tmp10*y0 + tmp9*y0, axis=0)
398 | w_grad_5 = tl.sum(go2*x1*y3 - tmp11*y3, axis=0)
399 | w_grad_6 = tl.sum(go3*(x1*y2 - x2*y1), axis=0)
400 | w_grad_7 = tl.sum(-go0*x3*y3, axis=0)
401 | w_grad_8 = tl.sum(go1*x3*y2 - go2*x3*y1, axis=0)
402 | w_grad_9 = tl.sum(go3*x3*y0, axis=0)
403 |
404 | # GELU gate gradients
405 | dgate_x = compute_gelu_gate_grad(x0_raw)
406 | dgate_y = compute_gelu_gate_grad(y0_raw)
407 |
408 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3)
409 | x_grad_1 = gate_x * x_grad_1
410 | x_grad_2 = gate_x * x_grad_2
411 | x_grad_3 = gate_x * x_grad_3
412 |
413 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3)
414 | y_grad_1 = gate_y * y_grad_1
415 | y_grad_2 = gate_y * y_grad_2
416 | y_grad_3 = gate_y * y_grad_3
417 |
418 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask)
419 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask)
420 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask)
421 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask)
422 |
423 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask)
424 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask)
425 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask)
426 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask)
427 |
428 | tl.store(grad_weight_ptr + block_offset + weight_offset + 0, w_grad_0, mask=feature_mask)
429 | tl.store(grad_weight_ptr + block_offset + weight_offset + 1, w_grad_1, mask=feature_mask)
430 | tl.store(grad_weight_ptr + block_offset + weight_offset + 2, w_grad_2, mask=feature_mask)
431 | tl.store(grad_weight_ptr + block_offset + weight_offset + 3, w_grad_3, mask=feature_mask)
432 | tl.store(grad_weight_ptr + block_offset + weight_offset + 4, w_grad_4, mask=feature_mask)
433 | tl.store(grad_weight_ptr + block_offset + weight_offset + 5, w_grad_5, mask=feature_mask)
434 | tl.store(grad_weight_ptr + block_offset + weight_offset + 6, w_grad_6, mask=feature_mask)
435 | tl.store(grad_weight_ptr + block_offset + weight_offset + 7, w_grad_7, mask=feature_mask)
436 | tl.store(grad_weight_ptr + block_offset + weight_offset + 8, w_grad_8, mask=feature_mask)
437 | tl.store(grad_weight_ptr + block_offset + weight_offset + 9, w_grad_9, mask=feature_mask)
438 |
439 |
440 | def gelu_geometric_product_norm_bwd(
441 | x: torch.Tensor,
442 | y: torch.Tensor,
443 | weight: torch.Tensor,
444 | o: torch.Tensor,
445 | partial_norm: torch.Tensor,
446 | grad_output: torch.Tensor,
447 | normalize: bool,
448 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
449 | """Backward pass for the fused operation."""
450 | _, B, N = x.shape
451 |
452 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
453 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
454 |
455 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
456 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
457 |
458 | grad_x = torch.zeros_like(x)
459 | grad_y = torch.zeros_like(y)
460 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0))
461 | grad_weight = torch.zeros((num_blocks_batch, N, NUM_PRODUCT_WEIGHTS), device=x.device, dtype=weight.dtype)
462 |
463 | grid = (num_blocks_batch, num_blocks_features)
464 |
465 | if normalize:
466 | grad_o_dot_o_kernel[grid](
467 | dot,
468 | partial_norm,
469 | o,
470 | grad_output,
471 | B,
472 | N,
473 | BATCH_BLOCK,
474 | FEATURE_BLOCK,
475 | EPS,
476 | num_warps=DEFAULT_NUM_WARPS,
477 | num_stages=DEFAULT_NUM_STAGES,
478 | )
479 |
480 | gelu_wgp_norm_kernel_bwd[grid](
481 | x,
482 | y,
483 | o,
484 | weight,
485 | dot,
486 | partial_norm,
487 | grad_output,
488 | grad_x,
489 | grad_y,
490 | grad_weight,
491 | normalize,
492 | B,
493 | N,
494 | BATCH_BLOCK,
495 | FEATURE_BLOCK,
496 | NUM_PRODUCT_WEIGHTS,
497 | EPS,
498 | num_warps=DEFAULT_NUM_WARPS,
499 | num_stages=DEFAULT_NUM_STAGES,
500 | )
501 |
502 | grad_weight = torch.sum(grad_weight, dim=0)
503 |
504 | return grad_x, grad_y, grad_weight
505 |
506 |
507 | class WeightedGeluGeometricProductNorm2D(torch.autograd.Function):
508 |
509 | @staticmethod
510 | @torch.amp.custom_fwd(device_type="cuda")
511 | def forward(ctx, x, y, weight, normalize):
512 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous()
513 |
514 | ctx.dtype = x.dtype
515 | ctx.normalize = normalize
516 |
517 | o, partial_norm = gelu_geometric_product_norm_fwd(
518 | x,
519 | y,
520 | weight,
521 | normalize,
522 | )
523 |
524 | ctx.save_for_backward(x, y, weight, o, partial_norm)
525 |
526 | return o.to(x.dtype)
527 |
528 | @staticmethod
529 | @torch.amp.custom_bwd(device_type="cuda")
530 | def backward(ctx, grad_output):
531 | grad_output = grad_output.contiguous()
532 |
533 | x, y, weight, o, partial_norm = ctx.saved_tensors
534 |
535 | grad_x, grad_y, grad_weight = gelu_geometric_product_norm_bwd(
536 | x,
537 | y,
538 | weight,
539 | o,
540 | partial_norm,
541 | grad_output,
542 | ctx.normalize,
543 | )
544 |
545 | return grad_x, grad_y, grad_weight, None, None, None, None
546 |
547 |
548 | def fused_gelu_sgp_norm_2d(x, y, weight, normalize=True):
549 | """
550 | Fused operation that applies GELU non-linearity to two multivector inputs,
551 | then computes their weighted geometric product, and applies RMSNorm.
552 |
553 | Clifford algebra is assumed to be Cl(2,0).
554 |
555 | Args:
556 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
557 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
558 |         weight (torch.Tensor): Weight tensor of shape (N, NUM_PRODUCT_WEIGHTS), one weight per geometric product term for each feature.
559 | normalize (bool): Whether to apply RMSNorm after the geometric product.
560 |
561 | Returns:
562 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation.
563 | """
564 | return WeightedGeluGeometricProductNorm2D.apply(x, y, weight, normalize)
--------------------------------------------------------------------------------
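For readability, here is a plain-PyTorch transcription of what the forward path above computes (GELU gate followed by the weighted Cl(2,0) geometric product; the grade-wise RMSNorm is omitted). This is a reading aid only, not taken from the repository, and may differ from the baselines used in the benchmarks.

import torch

def weighted_gp_2d_reference(x, y, w):
    """x, y: (4, B, N) multivectors; w: (N, 10) product weights."""
    x = x * (0.5 * (1 + torch.erf(x[0] * 0.7071067811865475)))  # gate all components by Phi(x0)
    y = y * (0.5 * (1 + torch.erf(y[0] * 0.7071067811865475)))
    x0, x1, x2, x3 = x
    y0, y1, y2, y3 = y
    w = w.T  # (10, N), broadcasts over the batch dimension
    o0 = w[0]*x0*y0 + w[3]*(x1*y1 + x2*y2) - w[7]*x3*y3
    o1 = w[1]*x0*y1 + w[4]*x1*y0 - w[5]*x2*y3 + w[8]*x3*y2
    o2 = w[1]*x0*y2 + w[5]*x1*y3 + w[4]*x2*y0 - w[8]*x3*y1
    o3 = w[2]*x0*y3 + w[6]*(x1*y2 - x2*y1) + w[9]*x3*y0
    return torch.stack([o0, o1, o2, o3])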
/ops/fc_p2m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | MV_DIM = 4
6 | NUM_GRADES = 3
7 | NUM_PRODUCT_WEIGHTS = 10
8 | WEIGHT_EXPANSION = [0, 3, 7, 1, 4, 5, 8, 1, 5, 4, 8, 2, 6, 9]
9 | EPS = 1e-6
10 |
11 | # tuned on an RTX 4500
12 | DEFAULT_BATCH_BLOCK = 4
13 | DEFAULT_FEATURE_BLOCK = 128
14 | DEFAULT_NUM_WARPS = 16
15 | DEFAULT_NUM_STAGES = 1
16 |
17 |
18 | @triton.jit
19 | def compute_gelu_gate(x):
20 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))"""
21 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype)
22 |
23 |
24 | @triton.jit
25 | def compute_gelu_gate_grad(x):
26 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)"""
27 | return 0.3989422804 * tl.exp(-0.5 * x * x)
28 |
29 |
30 | @triton.jit
31 | def gelu_pairwise_kernel_fwd(
32 | x_ptr,
33 | y_ptr,
34 | pairwise_ptr,
35 | batch_size: tl.constexpr,
36 | n_features: tl.constexpr,
37 | BATCH_BLOCK: tl.constexpr,
38 | FEATURE_BLOCK: tl.constexpr,
39 | ):
40 | """
41 |     Apply GELU non-linearity to inputs and compute the pairwise products needed for the geometric product.
42 | """
43 | batch_block_id = tl.program_id(axis=0)
44 | thread_block_id = tl.program_id(axis=1)
45 |
46 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
47 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
48 |
49 | batch_mask = batch_ids < batch_size
50 | feature_mask = feature_ids < n_features
51 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
52 |
53 | stride_component = batch_size * n_features
54 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
55 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
56 |
57 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
58 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
59 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
60 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
61 |
62 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
63 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
64 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
65 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
66 |
67 | gate_x = compute_gelu_gate(x0)
68 | gate_y = compute_gelu_gate(y0)
69 |
70 | x0 = x0 * gate_x
71 | x1 = x1 * gate_x
72 | x2 = x2 * gate_x
73 | x3 = x3 * gate_x
74 |
75 | y0 = y0 * gate_y
76 | y1 = y1 * gate_y
77 | y2 = y2 * gate_y
78 | y3 = y3 * gate_y
79 |
80 | p0 = x0 * y0
81 | p1 = x1*y1 + x2*y2
82 | p2 = -x3 * y3
83 | p3 = x0 * y1
84 | p4 = x1 * y0
85 | p5 = x2 * y3
86 | p6 = x3 * y2
87 | p7 = x0 * y2
88 | p8 = x1 * y3
89 | p9 = x2 * y0
90 | p10 = -x3 * y1
91 | p11 = x0 * y3
92 | p12 = x1*y2 - x2*y1
93 | p13 = x3 * y0
94 |
95 | tl.store(pairwise_ptr + 0*batch_size*n_features + pairwise_offset, p0, mask=batch_feature_mask)
96 | tl.store(pairwise_ptr + 1*batch_size*n_features + pairwise_offset, p1, mask=batch_feature_mask)
97 | tl.store(pairwise_ptr + 2*batch_size*n_features + pairwise_offset, p2, mask=batch_feature_mask)
98 | tl.store(pairwise_ptr + 3*batch_size*n_features + pairwise_offset, p3, mask=batch_feature_mask)
99 | tl.store(pairwise_ptr + 4*batch_size*n_features + pairwise_offset, p4, mask=batch_feature_mask)
100 | tl.store(pairwise_ptr + 5*batch_size*n_features + pairwise_offset, p5, mask=batch_feature_mask)
101 | tl.store(pairwise_ptr + 6*batch_size*n_features + pairwise_offset, p6, mask=batch_feature_mask)
102 | tl.store(pairwise_ptr + 7*batch_size*n_features + pairwise_offset, p7, mask=batch_feature_mask)
103 | tl.store(pairwise_ptr + 8*batch_size*n_features + pairwise_offset, p8, mask=batch_feature_mask)
104 | tl.store(pairwise_ptr + 9*batch_size*n_features + pairwise_offset, p9, mask=batch_feature_mask)
105 | tl.store(pairwise_ptr + 10*batch_size*n_features + pairwise_offset, p10, mask=batch_feature_mask)
106 | tl.store(pairwise_ptr + 11*batch_size*n_features + pairwise_offset, p11, mask=batch_feature_mask)
107 | tl.store(pairwise_ptr + 12*batch_size*n_features + pairwise_offset, p12, mask=batch_feature_mask)
108 | tl.store(pairwise_ptr + 13*batch_size*n_features + pairwise_offset, p13, mask=batch_feature_mask)
109 |
110 |
111 | @triton.jit
112 | def assemble_kernel(
113 | transformed_ptr,
114 | pnorm_ptr,
115 | output_ptr,
116 | NORMALIZE: tl.constexpr,
117 | batch_size: tl.constexpr,
118 | n_features: tl.constexpr,
119 | BATCH_BLOCK: tl.constexpr,
120 | FEATURE_BLOCK: tl.constexpr,
121 | ):
122 | """Gather linearly transformed pairwise products and compute the geometric product."""
123 | batch_block_id = tl.program_id(axis=0)
124 | thread_block_id = tl.program_id(axis=1)
125 |
126 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
127 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
128 |
129 | batch_mask = batch_ids < batch_size
130 | feature_mask = feature_ids < n_features
131 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
132 |
133 | stride_component = batch_size * n_features
134 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
135 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
136 |
137 | t0 = tl.load(transformed_ptr + 0*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
138 | t1 = tl.load(transformed_ptr + 1*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
139 | t2 = tl.load(transformed_ptr + 2*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
140 | t3 = tl.load(transformed_ptr + 3*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
141 | t4 = tl.load(transformed_ptr + 4*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
142 | t5 = tl.load(transformed_ptr + 5*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
143 | t6 = tl.load(transformed_ptr + 6*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
144 | t7 = tl.load(transformed_ptr + 7*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
145 | t8 = tl.load(transformed_ptr + 8*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
146 | t9 = tl.load(transformed_ptr + 9*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
147 | t10 = tl.load(transformed_ptr + 10*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
148 | t11 = tl.load(transformed_ptr + 11*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
149 | t12 = tl.load(transformed_ptr + 12*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
150 | t13 = tl.load(transformed_ptr + 13*batch_size*n_features + transformed_offset, mask=batch_feature_mask)
151 |
152 | o0 = t0 + t1 + t2
153 | o1 = t3 + t4 - t5 + t6
154 | o2 = t7 + t8 + t9 + t10
155 | o3 = t11 + t12 + t13
156 |
157 | if NORMALIZE:
158 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features
159 | pn_vector = tl.sum(o1*o1 + o2*o2, axis=1) / n_features
160 | pn_pseudo = tl.sum(o3 * o3, axis=1) / n_features
161 |
162 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask)
163 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask)
164 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask)
165 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_pseudo, mask=batch_mask)
166 |
167 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask)
168 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask)
169 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask)
170 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask)
171 |
172 |
173 | @triton.jit
174 | def normalize_with_sqrt_kernel(
175 | output_ptr,
176 | pnorm_ptr,
177 | batch_size: tl.constexpr,
178 | n_features: tl.constexpr,
179 | BATCH_BLOCK: tl.constexpr,
180 | FEATURE_BLOCK: tl.constexpr,
181 | MV_DIM: tl.constexpr,
182 | EPS: tl.constexpr,
183 | ):
184 |     """Normalize the output by dividing each grade by the square root of its accumulated norm."""
185 | batch_block_id = tl.program_id(axis=0)
186 | thread_block_id = tl.program_id(axis=1)
187 |
188 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
189 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
190 |
191 | batch_mask = batch_ids < batch_size
192 | feature_mask = feature_ids < n_features
193 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None]
194 |
195 | component_ids = tl.arange(0, MV_DIM)[None, None, :]
196 |
197 | feature_offset = (component_ids * batch_size * n_features +
198 | batch_ids[:, None, None] * n_features +
199 | feature_ids[None, :, None])
200 |
201 | norm_indices = component_ids * batch_size + batch_ids[:, None, None]
202 |
203 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None])
204 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask)
205 |
206 | norm = tl.sqrt(pnorm + EPS)
207 | mv_normalized = mv / norm
208 |
209 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask)
210 |
211 |
212 | def gelu_fc_geometric_product_norm_fwd(
213 | x: torch.Tensor,
214 | y: torch.Tensor,
215 | weight: torch.Tensor,
216 | expansion_indices: torch.Tensor,
217 | normalize: bool,
218 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
219 | """Fused operation: GELU non-linearity, fully connected geometric product, and grade-wise RMSNorm."""
220 | assert x.shape == y.shape
221 | assert x.shape[0] == MV_DIM
222 | assert x.shape[2] == weight.shape[1] == weight.shape[2]
223 | assert weight.shape[0] == NUM_PRODUCT_WEIGHTS
224 |
225 | _, B, N = x.shape
226 |
227 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
228 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
229 |
230 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
231 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
232 |
233 | pairwise = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype)
234 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize else torch.zeros((1,), device=x.device, dtype=x.dtype))
235 | output = torch.empty_like(x)
236 |
237 | grid = (num_blocks_batch, num_blocks_features)
238 |
239 | gelu_pairwise_kernel_fwd[grid](
240 | x,
241 | y,
242 | pairwise,
243 | B,
244 | N,
245 | BATCH_BLOCK,
246 | FEATURE_BLOCK,
247 | num_warps=DEFAULT_NUM_WARPS,
248 | num_stages=DEFAULT_NUM_STAGES,
249 | )
250 |
251 | transformed = torch.bmm(pairwise, weight[expansion_indices])
252 |
253 | assemble_kernel[grid](
254 | transformed,
255 | partial_norm,
256 | output,
257 | normalize,
258 | B,
259 | N,
260 | BATCH_BLOCK,
261 | FEATURE_BLOCK,
262 | num_warps=DEFAULT_NUM_WARPS,
263 | num_stages=DEFAULT_NUM_STAGES,
264 | )
265 |
266 | if normalize:
267 | normalize_with_sqrt_kernel[grid](
268 | output,
269 | partial_norm,
270 | B,
271 | N,
272 | BATCH_BLOCK,
273 | FEATURE_BLOCK,
274 | MV_DIM,
275 | EPS,
276 | num_warps=DEFAULT_NUM_WARPS,
277 | num_stages=DEFAULT_NUM_STAGES,
278 | )
279 |
280 | return output, pairwise, partial_norm
281 |
282 |
283 | @triton.jit
284 | def grad_o_dot_o_kernel(
285 | dot_ptr,
286 | pnorm_ptr,
287 | output_ptr,
288 | grad_output_ptr,
289 | batch_size: tl.constexpr,
290 | n_features: tl.constexpr,
291 | BATCH_BLOCK: tl.constexpr,
292 | FEATURE_BLOCK: tl.constexpr,
293 | EPS: tl.constexpr,
294 | ):
295 | """Compute the dot product of grad_output and output for each grade, accumulate across all features."""
296 | batch_block_id = tl.program_id(axis=0)
297 | thread_block_id = tl.program_id(axis=1)
298 |
299 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
300 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
301 |
302 | batch_mask = batch_ids < batch_size
303 | feature_mask = feature_ids < n_features
304 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
305 |
306 | stride_component = batch_size * n_features
307 | offset = batch_ids[:, None] * n_features + feature_ids[None, :]
308 |
309 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
310 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
311 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
312 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
313 |
314 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
315 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
316 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None]
317 |
318 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
319 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
320 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
321 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
322 |
323 | rms_scalar = tl.sqrt(pn_scalar + EPS)
324 | rms_vector = tl.sqrt(pn_vector + EPS)
325 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
326 |
327 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1)
328 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2), axis=1)
329 | dot_pseudo = tl.sum(rms_pseudo * go3 * o3, axis=1)
330 |
331 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask)
332 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask)
333 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_pseudo, mask=batch_mask)
334 |
335 |
336 | @triton.jit
337 | def disassemble_kernel(
338 | grad_output_ptr,
339 | output_ptr,
340 | dot_ptr,
341 | grad_transformed_ptr,
342 | pnorm_ptr,
343 | NORMALIZE: tl.constexpr,
344 | batch_size: tl.constexpr,
345 | n_features: tl.constexpr,
346 | BATCH_BLOCK: tl.constexpr,
347 | FEATURE_BLOCK: tl.constexpr,
348 | EPS: tl.constexpr,
349 | ):
350 |     """
351 |     Apply the grade-wise RMSNorm backward to grad_output (when NORMALIZE is set) and scatter the result onto the linearly transformed pairwise products; the backward of assemble_kernel.
352 |     """
353 | batch_block_id = tl.program_id(axis=0)
354 | thread_block_id = tl.program_id(axis=1)
355 |
356 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
357 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
358 |
359 | batch_mask = batch_ids < batch_size
360 | feature_mask = feature_ids < n_features
361 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
362 |
363 | stride_component = batch_size * n_features
364 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
365 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
366 |
367 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
368 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
369 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
370 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
371 |
372 | if NORMALIZE:
373 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
374 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
375 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
376 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
377 |
378 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
379 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
380 | pn_pseudo = tl.load(pnorm_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None]
381 |
382 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
383 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
384 | dot_pseudo = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None]
385 |
386 | rms_scalar = tl.sqrt(pn_scalar + EPS)
387 | rms_vector = tl.sqrt(pn_vector + EPS)
388 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
389 |
390 | go0 = go0/rms_scalar - o0 * dot_scalar / (n_features*rms_scalar*rms_scalar)
391 | go1 = go1/rms_vector - o1 * dot_vector / (n_features*rms_vector*rms_vector)
392 | go2 = go2/rms_vector - o2 * dot_vector / (n_features*rms_vector*rms_vector)
393 | go3 = go3/rms_pseudo - o3 * dot_pseudo / (n_features*rms_pseudo*rms_pseudo)
394 |
395 | tl.store(grad_transformed_ptr + 0*batch_size*n_features + transformed_offset, go0, mask=batch_feature_mask)
396 | tl.store(grad_transformed_ptr + 1*batch_size*n_features + transformed_offset, go0, mask=batch_feature_mask)
397 | tl.store(grad_transformed_ptr + 2*batch_size*n_features + transformed_offset, go0, mask=batch_feature_mask)
398 | tl.store(grad_transformed_ptr + 3*batch_size*n_features + transformed_offset, go1, mask=batch_feature_mask)
399 | tl.store(grad_transformed_ptr + 4*batch_size*n_features + transformed_offset, go1, mask=batch_feature_mask)
400 | tl.store(grad_transformed_ptr + 5*batch_size*n_features + transformed_offset, -go1, mask=batch_feature_mask)
401 | tl.store(grad_transformed_ptr + 6*batch_size*n_features + transformed_offset, go1, mask=batch_feature_mask)
402 | tl.store(grad_transformed_ptr + 7*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask)
403 | tl.store(grad_transformed_ptr + 8*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask)
404 | tl.store(grad_transformed_ptr + 9*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask)
405 | tl.store(grad_transformed_ptr + 10*batch_size*n_features + transformed_offset, go2, mask=batch_feature_mask)
406 | tl.store(grad_transformed_ptr + 11*batch_size*n_features + transformed_offset, go3, mask=batch_feature_mask)
407 | tl.store(grad_transformed_ptr + 12*batch_size*n_features + transformed_offset, go3, mask=batch_feature_mask)
408 | tl.store(grad_transformed_ptr + 13*batch_size*n_features + transformed_offset, go3, mask=batch_feature_mask)
409 |
410 |
411 | @triton.jit
412 | def gelu_pairwise_kernel_bwd(
413 | x_ptr,
414 | y_ptr,
415 | grad_pairwise_ptr,
416 | grad_x_ptr,
417 | grad_y_ptr,
418 | batch_size: tl.constexpr,
419 | n_features: tl.constexpr,
420 | BATCH_BLOCK: tl.constexpr,
421 | FEATURE_BLOCK: tl.constexpr,
422 | ):
423 |     """
424 |     Propagate gradients from the pairwise products back to the raw inputs, accounting for the GELU gate; the backward of gelu_pairwise_kernel_fwd.
425 |     """
426 | batch_block_id = tl.program_id(axis=0)
427 | thread_block_id = tl.program_id(axis=1)
428 |
429 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
430 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
431 |
432 | batch_mask = batch_ids < batch_size
433 | feature_mask = feature_ids < n_features
434 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
435 |
436 | stride_component = batch_size * n_features
437 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
438 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
439 |
440 | gp0 = tl.load(grad_pairwise_ptr + 0*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
441 | gp1 = tl.load(grad_pairwise_ptr + 1*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
442 | gp2 = tl.load(grad_pairwise_ptr + 2*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
443 | gp3 = tl.load(grad_pairwise_ptr + 3*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
444 | gp4 = tl.load(grad_pairwise_ptr + 4*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
445 | gp5 = tl.load(grad_pairwise_ptr + 5*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
446 | gp6 = tl.load(grad_pairwise_ptr + 6*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
447 | gp7 = tl.load(grad_pairwise_ptr + 7*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
448 | gp8 = tl.load(grad_pairwise_ptr + 8*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
449 | gp9 = tl.load(grad_pairwise_ptr + 9*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
450 | gp10 = tl.load(grad_pairwise_ptr + 10*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
451 | gp11 = tl.load(grad_pairwise_ptr + 11*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
452 | gp12 = tl.load(grad_pairwise_ptr + 12*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
453 | gp13 = tl.load(grad_pairwise_ptr + 13*batch_size*n_features + pairwise_offset, mask=batch_feature_mask)
454 |
455 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
456 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
457 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
458 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
459 |
460 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
461 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
462 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
463 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
464 |
465 | # collect gradients from pairwise products
466 | gate_x = compute_gelu_gate(x0_raw)
467 | gate_y = compute_gelu_gate(y0_raw)
468 |
469 | x0 = x0_raw * gate_x
470 | x1 = x1_raw * gate_x
471 | x2 = x2_raw * gate_x
472 | x3 = x3_raw * gate_x
473 |
474 | y0 = y0_raw * gate_y
475 | y1 = y1_raw * gate_y
476 | y2 = y2_raw * gate_y
477 | y3 = y3_raw * gate_y
478 |
479 | x_grad_0 = gp0*y0 + gp3*y1 + gp7*y2 + gp11*y3
480 | x_grad_1 = gp1*y1 + gp4*y0 + gp8*y3 + gp12*y2
481 | x_grad_2 = gp1*y2 + gp5*y3 + gp9*y0 - gp12*y1
482 | x_grad_3 = -gp2 * y3 + gp6*y2 - gp10*y1 + gp13*y0
483 |
484 | y_grad_0 = gp0*x0 + gp4*x1 + gp9*x2 + gp13*x3
485 | y_grad_1 = gp1*x1 + gp3*x0 - gp10*x3 - gp12*x2
486 | y_grad_2 = gp1*x2 + gp6*x3 + gp7*x0 + gp12*x1
487 | y_grad_3 = -gp2 * x3 + gp5*x2 + gp8*x1 + gp11*x0
488 |
489 | # GELU gate gradients
490 | dgate_x = compute_gelu_gate_grad(x0_raw)
491 | dgate_y = compute_gelu_gate_grad(y0_raw)
492 |
493 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3)
494 | x_grad_1 = gate_x * x_grad_1
495 | x_grad_2 = gate_x * x_grad_2
496 | x_grad_3 = gate_x * x_grad_3
497 |
498 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3)
499 | y_grad_1 = gate_y * y_grad_1
500 | y_grad_2 = gate_y * y_grad_2
501 | y_grad_3 = gate_y * y_grad_3
502 |
503 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask)
504 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask)
505 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask)
506 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask)
507 |
508 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask)
509 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask)
510 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask)
511 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask)
512 |
513 |
514 | def gelu_fc_geometric_product_norm_bwd(
515 | x: torch.Tensor,
516 | y: torch.Tensor,
517 | weight: torch.Tensor,
518 | o: torch.Tensor,
519 | pairwise: torch.Tensor,
520 | partial_norm: torch.Tensor,
521 | grad_output: torch.Tensor,
522 | expansion_indices: torch.Tensor,
523 | normalize: bool,
524 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
525 | """Backward pass for the fused operation."""
526 | _, B, N = x.shape
527 |
528 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
529 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
530 |
531 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
532 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
533 |
534 | grad_x = torch.zeros_like(x)
535 | grad_y = torch.zeros_like(y)
536 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0))
537 | grad_weight = torch.zeros((NUM_PRODUCT_WEIGHTS, N, N), device=x.device, dtype=weight.dtype)
538 | grad_transformed = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype)
539 |
540 | grid = (num_blocks_batch, num_blocks_features)
541 |
542 | if normalize:
543 | grad_o_dot_o_kernel[grid](
544 | dot,
545 | partial_norm,
546 | o,
547 | grad_output,
548 | B,
549 | N,
550 | BATCH_BLOCK,
551 | FEATURE_BLOCK,
552 | EPS,
553 | num_warps=DEFAULT_NUM_WARPS,
554 | num_stages=DEFAULT_NUM_STAGES,
555 | )
556 |
557 | disassemble_kernel[grid](
558 | grad_output,
559 | o,
560 | dot,
561 | grad_transformed,
562 | partial_norm,
563 | normalize,
564 | B,
565 | N,
566 | BATCH_BLOCK,
567 | FEATURE_BLOCK,
568 | EPS,
569 | num_warps=DEFAULT_NUM_WARPS,
570 | num_stages=DEFAULT_NUM_STAGES,
571 | )
572 |
573 | grad_pairwise = torch.bmm(grad_transformed, weight[expansion_indices].transpose(-2, -1))
574 |
575 | grad_weight.index_add_(0, expansion_indices, torch.bmm(pairwise.transpose(-2, -1), grad_transformed))
576 |
577 | gelu_pairwise_kernel_bwd[grid](
578 | x,
579 | y,
580 | grad_pairwise,
581 | grad_x,
582 | grad_y,
583 | B,
584 | N,
585 | BATCH_BLOCK,
586 | FEATURE_BLOCK,
587 | num_warps=DEFAULT_NUM_WARPS,
588 | num_stages=DEFAULT_NUM_STAGES,
589 | )
590 |
591 | return grad_x, grad_y, grad_weight
592 |
593 |
594 | class FullyConnectedGeluGeometricProductNorm2D(torch.autograd.Function):
595 |
596 | @staticmethod
597 | @torch.amp.custom_fwd(device_type="cuda")
598 | def forward(ctx, x, y, weight, normalize):
599 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous()
600 |
601 | ctx.dtype = x.dtype
602 | ctx.normalize = normalize
603 |
604 | expansion_indices = torch.tensor(WEIGHT_EXPANSION, device=x.device)
605 |
606 | o, pairwise, partial_norm = gelu_fc_geometric_product_norm_fwd(
607 | x,
608 | y,
609 | weight,
610 | expansion_indices,
611 | normalize,
612 | )
613 |
614 | ctx.save_for_backward(x, y, weight, o, pairwise, partial_norm, expansion_indices)
615 |
616 | return o.to(x.dtype)
617 |
618 | @staticmethod
619 | @torch.amp.custom_bwd(device_type="cuda")
620 | def backward(ctx, grad_output):
621 | grad_output = grad_output.contiguous()
622 |
623 | x, y, weight, o, pairwise, partial_norm, expansion_indices = ctx.saved_tensors
624 |
625 | grad_x, grad_y, grad_weight = gelu_fc_geometric_product_norm_bwd(
626 | x,
627 | y,
628 | weight,
629 | o,
630 | pairwise,
631 | partial_norm,
632 | grad_output,
633 | expansion_indices,
634 | ctx.normalize,
635 | )
636 |
637 | return grad_x, grad_y, grad_weight, None, None, None, None
638 |
639 |
640 | def fused_gelu_fcgp_norm_2d(x, y, weight, normalize=True):
641 | """
642 | Fused operation that applies GELU non-linearity to two multivector inputs,
643 | then computes their fully connected geometric product, and applies RMSNorm.
644 |
645 | Clifford algebra is assumed to be Cl(2,0).
646 |
647 | Args:
648 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
649 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
650 |         weight (torch.Tensor): Weight tensor of shape (NUM_PRODUCT_WEIGHTS, N, N), one N x N feature-mixing matrix per geometric product term.
651 | normalize (bool): Whether to apply RMSNorm after the geometric product.
652 |
653 | Returns:
654 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation.
655 | """
656 | return FullyConnectedGeluGeometricProductNorm2D.apply(x, y, weight, normalize)
657 |
--------------------------------------------------------------------------------
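A minimal usage sketch for the fully connected variant (assuming a CUDA device; the only difference from fused_gelu_sgp_norm_2d is the weight, which holds one N x N feature-mixing matrix per product term):

import torch
from ops.fc_p2m0 import fused_gelu_fcgp_norm_2d

B, N = 1024, 256
x = torch.randn(4, B, N, device="cuda", requires_grad=True)   # (MV_DIM, B, N)
y = torch.randn(4, B, N, device="cuda", requires_grad=True)
w = torch.randn(10, N, N, device="cuda", requires_grad=True)  # (NUM_PRODUCT_WEIGHTS, N, N)

out = fused_gelu_fcgp_norm_2d(x, y, w, normalize=True)        # (4, B, N)
out.sum().backward()                                          # fills x.grad, y.grad, w.grad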
/ops/p3m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | MV_DIM = 8
6 | NUM_GRADES = 4
7 | NUM_PRODUCT_WEIGHTS = 20
8 | EPS = 1e-6
9 |
10 | # tuned on an RTX 4500
11 | DEFAULT_BATCH_BLOCK = 4
12 | DEFAULT_FEATURE_BLOCK = 128
13 | DEFAULT_NUM_WARPS = 16
14 | DEFAULT_NUM_STAGES = 1
15 |
16 |
17 | @triton.jit
18 | def compute_gelu_gate(x):
19 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))"""
20 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype)
21 |
22 |
23 | @triton.jit
24 | def compute_gelu_gate_grad(x):
25 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)"""
26 | return 0.3989422804 * tl.exp(-0.5 * x * x)
27 |
28 |
29 | @triton.jit
30 | def gelu_wgp_norm_kernel_fwd(
31 | x_ptr,
32 | y_ptr,
33 | output_ptr,
34 | weights_ptr,
35 | pnorm_ptr,
36 | NORMALIZE: tl.constexpr,
37 | batch_size: tl.constexpr,
38 | n_features: tl.constexpr,
39 | BATCH_BLOCK: tl.constexpr,
40 | FEATURE_BLOCK: tl.constexpr,
41 | NUM_PRODUCT_WEIGHTS: tl.constexpr,
42 | ):
43 | """
44 | Apply GELU non-linearity to inputs, compute weighted geometric product,
45 | and accumulate squared norms for grade-wise RMSNorm.
46 | """
47 | batch_block_id = tl.program_id(axis=0)
48 | thread_block_id = tl.program_id(axis=1)
49 |
50 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
51 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
52 |
53 | batch_mask = batch_ids < batch_size
54 | feature_mask = feature_ids < n_features
55 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
56 |
57 | stride_component = batch_size * n_features
58 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
59 |
60 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS
61 |
62 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask)
63 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask)
64 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask)
65 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask)
66 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask)
67 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask)
68 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask)
69 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask)
70 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask)
71 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask)
72 | w10 = tl.load(weights_ptr + weight_offset + 10, mask=feature_mask)
73 | w11 = tl.load(weights_ptr + weight_offset + 11, mask=feature_mask)
74 | w12 = tl.load(weights_ptr + weight_offset + 12, mask=feature_mask)
75 | w13 = tl.load(weights_ptr + weight_offset + 13, mask=feature_mask)
76 | w14 = tl.load(weights_ptr + weight_offset + 14, mask=feature_mask)
77 | w15 = tl.load(weights_ptr + weight_offset + 15, mask=feature_mask)
78 | w16 = tl.load(weights_ptr + weight_offset + 16, mask=feature_mask)
79 | w17 = tl.load(weights_ptr + weight_offset + 17, mask=feature_mask)
80 | w18 = tl.load(weights_ptr + weight_offset + 18, mask=feature_mask)
81 | w19 = tl.load(weights_ptr + weight_offset + 19, mask=feature_mask)
82 |
83 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
84 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
85 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
86 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
87 | x4 = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
88 | x5 = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
89 | x6 = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
90 | x7 = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
91 |
92 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
93 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
94 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
95 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
96 | y4 = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
97 | y5 = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
98 | y6 = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
99 | y7 = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
100 |
101 | # Apply GELU gate
102 | gate_x = compute_gelu_gate(x0)
103 | gate_y = compute_gelu_gate(y0)
104 |
105 | x0 = x0 * gate_x
106 | x1 = x1 * gate_x
107 | x2 = x2 * gate_x
108 | x3 = x3 * gate_x
109 | x4 = x4 * gate_x
110 | x5 = x5 * gate_x
111 | x6 = x6 * gate_x
112 | x7 = x7 * gate_x
113 |
114 | y0 = y0 * gate_y
115 | y1 = y1 * gate_y
116 | y2 = y2 * gate_y
117 | y3 = y3 * gate_y
118 | y4 = y4 * gate_y
119 | y5 = y5 * gate_y
120 | y6 = y6 * gate_y
121 | y7 = y7 * gate_y
122 |
123 | # Compute geometric product
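    # Component ordering follows the grade grouping used for the norms below:
    # 0 = scalar, 1-3 = vector, 4-6 = bivector, 7 = pseudoscalar of Cl(3,0).
    # Each of the 20 weights w0-w19 scales one (grade of x, grade of y) -> output-grade path.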
124 | o0 = (w0*x0*y0 + w4 * (x1*y1 + x2*y2 + x3*y3) - w10 * (x4*y4 + x5*y5 + x6*y6) - w16*x7*y7)
125 | o1 = (w1*x0*y1 + w5*x1*y0 - w6 * (x2*y4 + x3*y5) + w11 * (x4*y2 + x5*y3) - w12*x6*y7 - w17*x7*y6)
126 | o2 = (w1*x0*y2 + w6*x1*y4 + w5*x2*y0 - w6*x3*y6 - w11*x4*y1 + w12*x5*y7 + w11*x6*y3 + w17*x7*y5)
127 | o3 = (w1*x0*y3 + w6 * (x1*y5 + x2*y6) + w5*x3*y0 - w12*x4*y7 - w11 * (x5*y1 + x6*y2) - w17*x7*y4)
128 | o4 = (w2*x0*y4 + w7*x1*y2 - w7*x2*y1 + w8*x3*y7 + w13*x4*y0 - w14*x5*y6 + w14*x6*y5 + w18*x7*y3)
129 | o5 = (w2*x0*y5 + w7*x1*y3 - w8*x2*y7 - w7*x3*y1 + w14*x4*y6 + w13*x5*y0 - w14*x6*y4 - w18*x7*y2)
130 | o6 = (w2*x0*y6 + w8*x1*y7 + w7*x2*y3 - w7*x3*y2 - w14*x4*y5 + w14*x5*y4 + w13*x6*y0 + w18*x7*y1)
131 | o7 = (w3*x0*y7 + w9*x1*y6 - w9*x2*y5 + w9*x3*y4 + w15*x4*y3 - w15*x5*y2 + w15*x6*y1 + w19*x7*y0)
132 |
133 | if NORMALIZE:
134 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features
135 | pn_vector = tl.sum(o1*o1 + o2*o2 + o3*o3, axis=1) / n_features
136 | pn_bivect = tl.sum(o4*o4 + o5*o5 + o6*o6, axis=1) / n_features
137 | pn_pseudo = tl.sum(o7 * o7, axis=1) / n_features
138 |
139 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask)
140 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask)
141 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask)
142 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_vector, mask=batch_mask)
143 | tl.atomic_add(pnorm_ptr + 4*batch_size + batch_ids, pn_bivect, mask=batch_mask)
144 | tl.atomic_add(pnorm_ptr + 5*batch_size + batch_ids, pn_bivect, mask=batch_mask)
145 | tl.atomic_add(pnorm_ptr + 6*batch_size + batch_ids, pn_bivect, mask=batch_mask)
146 | tl.atomic_add(pnorm_ptr + 7*batch_size + batch_ids, pn_pseudo, mask=batch_mask)
147 |
148 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask)
149 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask)
150 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask)
151 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask)
152 | tl.store(output_ptr + 4 * stride_component + base_offset, o4, mask=batch_feature_mask)
153 | tl.store(output_ptr + 5 * stride_component + base_offset, o5, mask=batch_feature_mask)
154 | tl.store(output_ptr + 6 * stride_component + base_offset, o6, mask=batch_feature_mask)
155 | tl.store(output_ptr + 7 * stride_component + base_offset, o7, mask=batch_feature_mask)
156 |
157 |
158 | @triton.jit
159 | def normalize_with_sqrt_kernel(
160 | output_ptr,
161 | pnorm_ptr,
162 | batch_size: tl.constexpr,
163 | n_features: tl.constexpr,
164 | BATCH_BLOCK: tl.constexpr,
165 | FEATURE_BLOCK: tl.constexpr,
166 | MV_DIM: tl.constexpr,
167 | EPS: tl.constexpr,
168 | ):
169 | """Normalize the output by dividing each grade with root of its accumulated norm."""
170 | batch_block_id = tl.program_id(axis=0)
171 | thread_block_id = tl.program_id(axis=1)
172 |
173 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
174 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
175 |
176 | batch_mask = batch_ids < batch_size
177 | feature_mask = feature_ids < n_features
178 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None]
179 |
180 | component_ids = tl.arange(0, MV_DIM)[None, None, :]
181 |
182 | feature_offset = (component_ids * batch_size * n_features +
183 | batch_ids[:, None, None] * n_features +
184 | feature_ids[None, :, None])
185 |
186 | norm_indices = component_ids * batch_size + batch_ids[:, None, None]
187 |
188 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None])
189 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask)
190 |
191 | norm = tl.sqrt(pnorm + EPS)
192 | mv_normalized = mv / norm
193 |
194 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask)
195 |
196 |
197 | def gelu_geometric_product_norm_fwd(
198 | x: torch.Tensor,
199 | y: torch.Tensor,
200 | weight: torch.Tensor,
201 | normalize: bool,
202 | ) -> tuple[torch.Tensor, torch.Tensor]:
203 | """Fused operation: GELU non-linearity, weighted geometric product, and grade-wise RMSNorm."""
204 | assert x.shape == y.shape
205 | assert x.shape[0] == MV_DIM
206 | assert x.shape[2] == weight.shape[0]
207 | assert weight.shape[1] == NUM_PRODUCT_WEIGHTS
208 |
209 | _, B, N = x.shape
210 |
211 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
212 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
213 |
214 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
215 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
216 |
217 | output = torch.empty_like(x)
218 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize
219 | else torch.zeros((1,), device=x.device, dtype=x.dtype))
220 |
221 | grid = (num_blocks_batch, num_blocks_features)
222 |
223 | gelu_wgp_norm_kernel_fwd[grid](
224 | x,
225 | y,
226 | output,
227 | weight,
228 | partial_norm,
229 | normalize,
230 | B,
231 | N,
232 | BATCH_BLOCK,
233 | FEATURE_BLOCK,
234 | NUM_PRODUCT_WEIGHTS,
235 | num_warps=DEFAULT_NUM_WARPS,
236 | num_stages=DEFAULT_NUM_STAGES,
237 | )
238 |
239 | if normalize:
240 | normalize_with_sqrt_kernel[grid](
241 | output,
242 | partial_norm,
243 | B,
244 | N,
245 | BATCH_BLOCK,
246 | FEATURE_BLOCK,
247 | MV_DIM,
248 | EPS,
249 | num_warps=DEFAULT_NUM_WARPS,
250 | num_stages=DEFAULT_NUM_STAGES,
251 | )
252 |
253 | return output, partial_norm
254 |
255 |
256 | @triton.jit
257 | def grad_o_dot_o_kernel(
258 | dot_ptr,
259 | pnorm_ptr,
260 | output_ptr,
261 | grad_output_ptr,
262 | batch_size: tl.constexpr,
263 | n_features: tl.constexpr,
264 | BATCH_BLOCK: tl.constexpr,
265 | FEATURE_BLOCK: tl.constexpr,
266 | EPS: tl.constexpr,
267 | ):
268 | """Compute the dot product of grad_output and output for each grade, accumulate across all features."""
269 | batch_block_id = tl.program_id(axis=0)
270 | thread_block_id = tl.program_id(axis=1)
271 |
272 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
273 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
274 |
275 | batch_mask = batch_ids < batch_size
276 | feature_mask = feature_ids < n_features
277 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
278 |
279 | stride_component = batch_size * n_features
280 | offset = batch_ids[:, None] * n_features + feature_ids[None, :]
281 |
282 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
283 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
284 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
285 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
286 | go4 = tl.load(grad_output_ptr + 4 * stride_component + offset, mask=batch_feature_mask)
287 | go5 = tl.load(grad_output_ptr + 5 * stride_component + offset, mask=batch_feature_mask)
288 | go6 = tl.load(grad_output_ptr + 6 * stride_component + offset, mask=batch_feature_mask)
289 | go7 = tl.load(grad_output_ptr + 7 * stride_component + offset, mask=batch_feature_mask)
290 |
291 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
292 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
293 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None]
294 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None]
295 |
296 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
297 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
298 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
299 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
300 | o4 = tl.load(output_ptr + 4 * stride_component + offset, mask=batch_feature_mask)
301 | o5 = tl.load(output_ptr + 5 * stride_component + offset, mask=batch_feature_mask)
302 | o6 = tl.load(output_ptr + 6 * stride_component + offset, mask=batch_feature_mask)
303 | o7 = tl.load(output_ptr + 7 * stride_component + offset, mask=batch_feature_mask)
304 |
305 | rms_scalar = tl.sqrt(pn_scalar + EPS)
306 | rms_vector = tl.sqrt(pn_vector + EPS)
307 | rms_bivect = tl.sqrt(pn_bivect + EPS)
308 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
309 |
310 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1)
311 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2 + go3*o3), axis=1)
312 | dot_bivect = tl.sum(rms_bivect * (go4*o4 + go5*o5 + go6*o6), axis=1)
313 | dot_pseudo = tl.sum(rms_pseudo * go7 * o7, axis=1)
314 |
315 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask)
316 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask)
317 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_bivect, mask=batch_mask)
318 | tl.atomic_add(dot_ptr + 3*batch_size + batch_ids, dot_pseudo, mask=batch_mask)
319 |
320 |
321 | @triton.jit
322 | def gelu_wgp_norm_kernel_bwd(
323 | x_ptr,
324 | y_ptr,
325 | output_ptr,
326 | weights_ptr,
327 | dot_ptr,
328 | pnorm_ptr,
329 | grad_output_ptr,
330 | grad_x_ptr,
331 | grad_y_ptr,
332 | grad_weight_ptr,
333 | NORMALIZE: tl.constexpr,
334 | batch_size: tl.constexpr,
335 | n_features: tl.constexpr,
336 | BATCH_BLOCK: tl.constexpr,
337 | FEATURE_BLOCK: tl.constexpr,
338 | MV_DIM: tl.constexpr,
339 | NUM_GRADES: tl.constexpr,
340 | NUM_PRODUCT_WEIGHTS: tl.constexpr,
341 | EPS: tl.constexpr,
342 | ):
343 | """Compute gradients w.r.t. inputs and weights of the fused operation."""
344 | batch_block_id = tl.program_id(axis=0)
345 | thread_block_id = tl.program_id(axis=1)
346 |
347 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
348 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
349 |
350 | batch_mask = batch_ids < batch_size
351 | feature_mask = feature_ids < n_features
352 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
353 |
354 | stride_component = batch_size * n_features
355 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
356 |
357 | weight_offset = feature_ids * NUM_PRODUCT_WEIGHTS
358 | block_offset = batch_block_id * n_features * NUM_PRODUCT_WEIGHTS
359 |
360 | w0 = tl.load(weights_ptr + weight_offset + 0, mask=feature_mask)
361 | w1 = tl.load(weights_ptr + weight_offset + 1, mask=feature_mask)
362 | w2 = tl.load(weights_ptr + weight_offset + 2, mask=feature_mask)
363 | w3 = tl.load(weights_ptr + weight_offset + 3, mask=feature_mask)
364 | w4 = tl.load(weights_ptr + weight_offset + 4, mask=feature_mask)
365 | w5 = tl.load(weights_ptr + weight_offset + 5, mask=feature_mask)
366 | w6 = tl.load(weights_ptr + weight_offset + 6, mask=feature_mask)
367 | w7 = tl.load(weights_ptr + weight_offset + 7, mask=feature_mask)
368 | w8 = tl.load(weights_ptr + weight_offset + 8, mask=feature_mask)
369 | w9 = tl.load(weights_ptr + weight_offset + 9, mask=feature_mask)
370 | w10 = tl.load(weights_ptr + weight_offset + 10, mask=feature_mask)
371 | w11 = tl.load(weights_ptr + weight_offset + 11, mask=feature_mask)
372 | w12 = tl.load(weights_ptr + weight_offset + 12, mask=feature_mask)
373 | w13 = tl.load(weights_ptr + weight_offset + 13, mask=feature_mask)
374 | w14 = tl.load(weights_ptr + weight_offset + 14, mask=feature_mask)
375 | w15 = tl.load(weights_ptr + weight_offset + 15, mask=feature_mask)
376 | w16 = tl.load(weights_ptr + weight_offset + 16, mask=feature_mask)
377 | w17 = tl.load(weights_ptr + weight_offset + 17, mask=feature_mask)
378 | w18 = tl.load(weights_ptr + weight_offset + 18, mask=feature_mask)
379 | w19 = tl.load(weights_ptr + weight_offset + 19, mask=feature_mask)
380 |
381 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
382 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
383 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
384 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
385 | x4_raw = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
386 | x5_raw = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
387 | x6_raw = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
388 | x7_raw = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
389 |
390 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
391 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
392 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
393 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
394 | y4_raw = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
395 | y5_raw = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
396 | y6_raw = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
397 | y7_raw = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
398 |
399 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
400 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
401 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
402 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
403 | go4 = tl.load(grad_output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
404 | go5 = tl.load(grad_output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
405 | go6 = tl.load(grad_output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
406 | go7 = tl.load(grad_output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
407 |
408 | if NORMALIZE:
409 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
410 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
411 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
412 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
413 | o4 = tl.load(output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
414 | o5 = tl.load(output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
415 | o6 = tl.load(output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
416 | o7 = tl.load(output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
417 |
418 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
419 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
420 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None]
421 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None]
422 |
423 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
424 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
425 | dot_bivect = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None]
426 | dot_pseudo = tl.load(dot_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None]
427 |
428 | rms_scalar = tl.sqrt(pn_scalar + EPS)
429 | rms_vector = tl.sqrt(pn_vector + EPS)
430 | rms_bivect = tl.sqrt(pn_bivect + EPS)
431 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
432 |
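        # RMSNorm backward per grade: with o_hat = o / rms, rms = sqrt(mean_j o_j^2 + EPS) and
        # dot = sum_j go_j * o_j (accumulated as rms * go * o_hat in grad_o_dot_o_kernel),
        # the gradient w.r.t. the pre-norm output is go / rms - o_hat * dot / (n_features * rms^2).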
433 | go0 = go0/rms_scalar - o0 * dot_scalar / (n_features*rms_scalar*rms_scalar)
434 | go1 = go1/rms_vector - o1 * dot_vector / (n_features*rms_vector*rms_vector)
435 | go2 = go2/rms_vector - o2 * dot_vector / (n_features*rms_vector*rms_vector)
436 | go3 = go3/rms_vector - o3 * dot_vector / (n_features*rms_vector*rms_vector)
437 | go4 = go4/rms_bivect - o4 * dot_bivect / (n_features*rms_bivect*rms_bivect)
438 | go5 = go5/rms_bivect - o5 * dot_bivect / (n_features*rms_bivect*rms_bivect)
439 | go6 = go6/rms_bivect - o6 * dot_bivect / (n_features*rms_bivect*rms_bivect)
440 | go7 = go7/rms_pseudo - o7 * dot_pseudo / (n_features*rms_pseudo*rms_pseudo)
441 |
442 | # weighted geometric product backward
443 | gate_x = compute_gelu_gate(x0_raw)
444 | gate_y = compute_gelu_gate(y0_raw)
445 |
446 | x0 = x0_raw * gate_x
447 | x1 = x1_raw * gate_x
448 | x2 = x2_raw * gate_x
449 | x3 = x3_raw * gate_x
450 | x4 = x4_raw * gate_x
451 | x5 = x5_raw * gate_x
452 | x6 = x6_raw * gate_x
453 | x7 = x7_raw * gate_x
454 |
455 | y0 = y0_raw * gate_y
456 | y1 = y1_raw * gate_y
457 | y2 = y2_raw * gate_y
458 | y3 = y3_raw * gate_y
459 | y4 = y4_raw * gate_y
460 | y5 = y5_raw * gate_y
461 | y6 = y6_raw * gate_y
462 | y7 = y7_raw * gate_y
463 |
464 | tmp0 = go0 * w0
465 | tmp1 = go7 * w3
466 | tmp2 = go1 * y1
467 | tmp3 = go2 * y2
468 | tmp4 = go3 * y3
469 | tmp5 = go4 * y4
470 | tmp6 = go5 * y5
471 | tmp7 = go6 * y6
472 | tmp8 = go0 * w4
473 | tmp9 = w5 * y0
474 | tmp10 = w8 * y7
475 | tmp11 = go7 * w9
476 | tmp12 = go1 * w6
477 | tmp13 = w13 * y0
478 | tmp14 = go7 * w15
479 | tmp15 = go0 * w10
480 | tmp16 = w12 * y7
481 | tmp17 = go3 * w11
482 | tmp18 = go7 * w19
483 | tmp19 = go0 * w16
484 | tmp20 = go4 * y3
485 | tmp21 = go6 * y1
486 | tmp22 = go5 * y2
487 | tmp23 = go1 * y6
488 | tmp24 = go3 * y4
489 | tmp25 = go4 * x4
490 | tmp26 = go5 * x5
491 | tmp27 = go6 * x6
492 | tmp28 = go1 * x1
493 | tmp29 = go2 * x2
494 | tmp30 = go3 * x3
495 | tmp31 = w1 * x0
496 | tmp32 = w18 * x7
497 | tmp33 = w2 * x0
498 | tmp34 = w17 * x7
499 | tmp35 = go4 * x3
500 | tmp36 = go6 * x1
501 | tmp37 = go5 * x2
502 | tmp38 = go1 * x6
503 | tmp39 = go3 * x4
504 |
505 | x_grad_0 = (tmp0*y0 + tmp1*y7 + w1 * (tmp2+tmp3+tmp4) + w2 * (tmp5+tmp6+tmp7))
506 | x_grad_1 = (go1*tmp9 + go6*tmp10 + tmp11*y6 + tmp8*y1 + w6 * (go2*y4 + go3*y5) + w7 * (go4*y2 + go5*y3))
507 | x_grad_2 = (go2*tmp9 + go3*w6*y6 - go5*tmp10 - tmp11*y5 - tmp12*y4 + tmp8*y2 + w7 * (-go4 * y1 + go6*y3))
508 | x_grad_3 = (go3*tmp9 + go4*tmp10 + tmp11*y4 + tmp8*y3 - w6 * (go1*y5 + go2*y6) - w7 * (go5*y1 + go6*y2))
509 | x_grad_4 = (-go3 * tmp16 + go4*tmp13 + tmp14*y3 - tmp15*y4 + w11 * (go1*y2 - go2*y1) + w14 * (go5*y6 - go6*y5))
510 | x_grad_5 = (go1*w11*y3 + go2*w12*y7 - go4*w14*y6 + go5*w13*y0 + go6*w14*y4 - tmp14*y2 - tmp15*y5 - tmp17*y1)
511 | x_grad_6 = (-go1 * tmp16 + go6*tmp13 + tmp14*y1 - tmp15*y6 + w11 * (go2*y3 - go3*y2) + w14 * (go4*y5 - go5*y4))
512 | x_grad_7 = (tmp18*y0 - tmp19*y7 + w17 * (go2*y5 - tmp23 - tmp24) + w18 * (tmp20+tmp21-tmp22))
513 |
514 | y_grad_0 = (tmp0*x0 + tmp18*x7 + w13 * (tmp25+tmp26+tmp27) + w5 * (tmp28+tmp29+tmp30))
515 | y_grad_1 = (go1*tmp31 + go6*tmp32 + tmp14*x6 + tmp8*x1 - w11 * (go2*x4 + go3*x5) - w7 * (go4*x2 + go5*x3))
516 | y_grad_2 = (go1*w11*x4 + go2*tmp31 + go4*w7*x1 - go5*tmp32 - go6*w7*x3 - tmp14*x5 - tmp17*x6 + tmp8*x2)
517 | y_grad_3 = (go3*tmp31 + go4*tmp32 + tmp14*x4 + tmp8*x3 + w11 * (go1*x5 + go2*x6) + w7 * (go5*x1 + go6*x2))
518 | y_grad_4 = (-go3 * tmp34 + go4*tmp33 + tmp11*x3 - tmp15*x4 + w14 * (-go5 * x6 + go6*x5) + w6 * (-go1 * x2 + go2*x1))
519 | y_grad_5 = (go2*w17*x7 + go3*w6*x1 + go4*w14*x6 + go5*w2*x0 - go6*w14*x4 - tmp11*x2 - tmp12*x3 - tmp15*x5)
520 | y_grad_6 = (-go1 * tmp34 + go6*tmp33 + tmp11*x1 - tmp15*x6 + w14 * (-go4 * x5 + go5*x4) + w6 * (-go2 * x3 + go3*x2))
521 | y_grad_7 = (tmp1*x0 - tmp19*x7 + w12 * (go2*x5 - tmp38 - tmp39) + w8 * (tmp35+tmp36-tmp37))
522 |
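    # Weight gradients are reduced over this block's batch rows (axis=0) and stored per batch
    # block; the host-side wrapper sums the blocks afterwards via grad_weight.sum(dim=0).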
523 | w_grad_0 = tl.sum(go0 * x0 * y0, axis=0)
524 | w_grad_1 = tl.sum(tmp2*x0 + tmp3*x0 + tmp4*x0, axis=0)
525 | w_grad_2 = tl.sum(tmp5*x0 + tmp6*x0 + tmp7*x0, axis=0)
526 | w_grad_3 = tl.sum(go7 * x0 * y7, axis=0)
527 | w_grad_4 = tl.sum(go0 * (x1*y1 + x2*y2 + x3*y3), axis=0)
528 | w_grad_5 = tl.sum(tmp28*y0 + tmp29*y0 + tmp30*y0, axis=0)
529 | w_grad_6 = tl.sum(go1 * (-x2 * y4 - x3*y5) + go2 * (x1*y4 - x3*y6) + go3 * (x1*y5 + x2*y6), axis=0)
530 | w_grad_7 = tl.sum(go4 * (x1*y2 - x2*y1) + go5 * (x1*y3 - x3*y1) + go6 * (x2*y3 - x3*y2), axis=0)
531 | w_grad_8 = tl.sum(tmp35*y7 + tmp36*y7 - tmp37*y7, axis=0)
532 | w_grad_9 = tl.sum(go7 * (x1*y6 - x2*y5 + x3*y4), axis=0)
533 | w_grad_10 = tl.sum(go0 * (-x4 * y4 - x5*y5 - x6*y6), axis=0)
534 | w_grad_11 = tl.sum(go1 * (x4*y2 + x5*y3) + go2 * (-x4 * y1 + x6*y3) + go3 * (-x5 * y1 - x6*y2), axis=0)
535 | w_grad_12 = tl.sum(go2*x5*y7 - tmp38*y7 - tmp39*y7, axis=0)
536 | w_grad_13 = tl.sum(tmp25*y0 + tmp26*y0 + tmp27*y0, axis=0)
537 | w_grad_14 = tl.sum(go4 * (-x5 * y6 + x6*y5) + go5 * (x4*y6 - x6*y4) + go6 * (-x4 * y5 + x5*y4), axis=0)
538 | w_grad_15 = tl.sum(go7 * (x4*y3 - x5*y2 + x6*y1), axis=0)
539 | w_grad_16 = tl.sum(-go0 * x7 * y7, axis=0)
540 | w_grad_17 = tl.sum(go2*x7*y5 - tmp23*x7 - tmp24*x7, axis=0)
541 | w_grad_18 = tl.sum(tmp20*x7 + tmp21*x7 - tmp22*x7, axis=0)
542 | w_grad_19 = tl.sum(go7 * x7 * y0, axis=0)
543 |
544 | # GELU gate gradients
545 | dgate_x = compute_gelu_gate_grad(x0_raw)
546 | dgate_y = compute_gelu_gate_grad(y0_raw)
547 |
548 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3 +
549 | x4_raw*x_grad_4 + x5_raw*x_grad_5 + x6_raw*x_grad_6 +
550 | x7_raw*x_grad_7)
551 | x_grad_1 = gate_x * x_grad_1
552 | x_grad_2 = gate_x * x_grad_2
553 | x_grad_3 = gate_x * x_grad_3
554 | x_grad_4 = gate_x * x_grad_4
555 | x_grad_5 = gate_x * x_grad_5
556 | x_grad_6 = gate_x * x_grad_6
557 | x_grad_7 = gate_x * x_grad_7
558 |
559 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3 +
560 | y4_raw*y_grad_4 + y5_raw*y_grad_5 + y6_raw*y_grad_6 +
561 | y7_raw*y_grad_7)
562 | y_grad_1 = gate_y * y_grad_1
563 | y_grad_2 = gate_y * y_grad_2
564 | y_grad_3 = gate_y * y_grad_3
565 | y_grad_4 = gate_y * y_grad_4
566 | y_grad_5 = gate_y * y_grad_5
567 | y_grad_6 = gate_y * y_grad_6
568 | y_grad_7 = gate_y * y_grad_7
569 |
570 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask)
571 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask)
572 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask)
573 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask)
574 | tl.store(grad_x_ptr + 4 * stride_component + base_offset, x_grad_4, mask=batch_feature_mask)
575 | tl.store(grad_x_ptr + 5 * stride_component + base_offset, x_grad_5, mask=batch_feature_mask)
576 | tl.store(grad_x_ptr + 6 * stride_component + base_offset, x_grad_6, mask=batch_feature_mask)
577 | tl.store(grad_x_ptr + 7 * stride_component + base_offset, x_grad_7, mask=batch_feature_mask)
578 |
579 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask)
580 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask)
581 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask)
582 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask)
583 | tl.store(grad_y_ptr + 4 * stride_component + base_offset, y_grad_4, mask=batch_feature_mask)
584 | tl.store(grad_y_ptr + 5 * stride_component + base_offset, y_grad_5, mask=batch_feature_mask)
585 | tl.store(grad_y_ptr + 6 * stride_component + base_offset, y_grad_6, mask=batch_feature_mask)
586 | tl.store(grad_y_ptr + 7 * stride_component + base_offset, y_grad_7, mask=batch_feature_mask)
587 |
588 | tl.store(grad_weight_ptr + block_offset + weight_offset + 0, w_grad_0, mask=feature_mask)
589 | tl.store(grad_weight_ptr + block_offset + weight_offset + 1, w_grad_1, mask=feature_mask)
590 | tl.store(grad_weight_ptr + block_offset + weight_offset + 2, w_grad_2, mask=feature_mask)
591 | tl.store(grad_weight_ptr + block_offset + weight_offset + 3, w_grad_3, mask=feature_mask)
592 | tl.store(grad_weight_ptr + block_offset + weight_offset + 4, w_grad_4, mask=feature_mask)
593 | tl.store(grad_weight_ptr + block_offset + weight_offset + 5, w_grad_5, mask=feature_mask)
594 | tl.store(grad_weight_ptr + block_offset + weight_offset + 6, w_grad_6, mask=feature_mask)
595 | tl.store(grad_weight_ptr + block_offset + weight_offset + 7, w_grad_7, mask=feature_mask)
596 | tl.store(grad_weight_ptr + block_offset + weight_offset + 8, w_grad_8, mask=feature_mask)
597 | tl.store(grad_weight_ptr + block_offset + weight_offset + 9, w_grad_9, mask=feature_mask)
598 | tl.store(grad_weight_ptr + block_offset + weight_offset + 10, w_grad_10, mask=feature_mask)
599 | tl.store(grad_weight_ptr + block_offset + weight_offset + 11, w_grad_11, mask=feature_mask)
600 | tl.store(grad_weight_ptr + block_offset + weight_offset + 12, w_grad_12, mask=feature_mask)
601 | tl.store(grad_weight_ptr + block_offset + weight_offset + 13, w_grad_13, mask=feature_mask)
602 | tl.store(grad_weight_ptr + block_offset + weight_offset + 14, w_grad_14, mask=feature_mask)
603 | tl.store(grad_weight_ptr + block_offset + weight_offset + 15, w_grad_15, mask=feature_mask)
604 | tl.store(grad_weight_ptr + block_offset + weight_offset + 16, w_grad_16, mask=feature_mask)
605 | tl.store(grad_weight_ptr + block_offset + weight_offset + 17, w_grad_17, mask=feature_mask)
606 | tl.store(grad_weight_ptr + block_offset + weight_offset + 18, w_grad_18, mask=feature_mask)
607 | tl.store(grad_weight_ptr + block_offset + weight_offset + 19, w_grad_19, mask=feature_mask)
608 |
609 |
610 | def gelu_geometric_product_norm_bwd(
611 | x: torch.Tensor,
612 | y: torch.Tensor,
613 | weight: torch.Tensor,
614 | o: torch.Tensor,
615 | partial_norm: torch.Tensor,
616 | grad_output: torch.Tensor,
617 | normalize: bool,
618 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
619 | """Backward pass for the fused operation."""
620 | _, B, N = x.shape
621 |
622 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
623 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
624 |
625 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
626 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
627 |
628 | grad_x = torch.zeros_like(x)
629 | grad_y = torch.zeros_like(y)
630 |     dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.zeros((1,), device=x.device, dtype=x.dtype))
631 | grad_weight = torch.zeros((num_blocks_batch, N, NUM_PRODUCT_WEIGHTS), device=x.device, dtype=weight.dtype)
632 |
633 | grid = (num_blocks_batch, num_blocks_features)
634 |
635 | if normalize:
636 | grad_o_dot_o_kernel[grid](
637 | dot,
638 | partial_norm,
639 | o,
640 | grad_output,
641 | B,
642 | N,
643 | BATCH_BLOCK,
644 | FEATURE_BLOCK,
645 | EPS,
646 | num_warps=DEFAULT_NUM_WARPS,
647 | num_stages=DEFAULT_NUM_STAGES,
648 | )
649 |
650 | gelu_wgp_norm_kernel_bwd[grid](
651 | x,
652 | y,
653 | o,
654 | weight,
655 | dot,
656 | partial_norm,
657 | grad_output,
658 | grad_x,
659 | grad_y,
660 | grad_weight,
661 | normalize,
662 | B,
663 | N,
664 | BATCH_BLOCK,
665 | FEATURE_BLOCK,
666 | MV_DIM,
667 | NUM_GRADES,
668 | NUM_PRODUCT_WEIGHTS,
669 | EPS,
670 | num_warps=DEFAULT_NUM_WARPS,
671 | num_stages=DEFAULT_NUM_STAGES,
672 | )
673 |
674 | grad_weight = torch.sum(grad_weight, dim=0)
675 |
676 | return grad_x, grad_y, grad_weight
677 |
678 |
679 | class WeightedGeluGeometricProductNorm3D(torch.autograd.Function):
680 |
681 | @staticmethod
682 | @torch.amp.custom_fwd(device_type="cuda")
683 | def forward(ctx, x, y, weight, normalize):
684 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous()
685 |
686 | ctx.dtype = x.dtype
687 | ctx.normalize = normalize
688 |
689 | o, partial_norm = gelu_geometric_product_norm_fwd(
690 | x,
691 | y,
692 | weight,
693 | normalize,
694 | )
695 |
696 | ctx.save_for_backward(x, y, weight, o, partial_norm)
697 |
698 | return o.to(x.dtype)
699 |
700 | @staticmethod
701 | @torch.amp.custom_bwd(device_type="cuda")
702 | def backward(ctx, grad_output):
703 | grad_output = grad_output.contiguous()
704 |
705 | x, y, weight, o, partial_norm = ctx.saved_tensors
706 |
707 | grad_x, grad_y, grad_weight = gelu_geometric_product_norm_bwd(
708 | x,
709 | y,
710 | weight,
711 | o,
712 | partial_norm,
713 | grad_output,
714 | ctx.normalize,
715 | )
716 |
717 | return grad_x, grad_y, grad_weight, None, None, None, None
718 |
719 |
720 | def fused_gelu_sgp_norm_3d(x, y, weight, normalize=True):
721 | """
722 |     Fused operation that applies a GELU non-linearity to two multivector inputs,
723 |     then computes their weighted geometric product, and applies grade-wise RMSNorm.
724 |
725 | Clifford algebra is assumed to be Cl(3,0).
726 |
727 | Args:
728 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
729 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
730 |         weight (torch.Tensor): Weight tensor of shape (N, NUM_PRODUCT_WEIGHTS), one weight per grade-to-grade product path for each feature.
731 | normalize (bool): Whether to apply RMSNorm after the geometric product.
732 |
733 | Returns:
734 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation.
735 | """
736 | return WeightedGeluGeometricProductNorm3D.apply(x, y, weight, normalize)
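

if __name__ == "__main__":
    # Minimal usage sketch (illustrative shapes only; a CUDA device is assumed).
    x = torch.randn(MV_DIM, 32, 256, device="cuda", requires_grad=True)  # (MV_DIM, B, N)
    y = torch.randn(MV_DIM, 32, 256, device="cuda", requires_grad=True)
    weight = torch.randn(256, NUM_PRODUCT_WEIGHTS, device="cuda", requires_grad=True)  # (N, 20)
    out = fused_gelu_sgp_norm_3d(x, y, weight, normalize=True)  # (MV_DIM, B, N)
    out.sum().backward()  # fills x.grad, y.grad, weight.grad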
--------------------------------------------------------------------------------
/ops/fc_p3m0.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import triton
3 | import triton.language as tl
4 |
5 | MV_DIM = 8
6 | NUM_GRADES = 4
7 | NUM_PRODUCT_WEIGHTS = 20
8 | WEIGHT_EXPANSION = [0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 16, 17, 17, 17, 18, 18, 18, 19]
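# Maps each of the 44 pairwise product terms (p0-p43 below) to one of the 20 distinct product
# weights; the forward pass uses it to expand the (20, N, N) weight tensor to 44 slices before the bmm.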
9 | EPS = 1e-6
10 |
11 | # tuned on RTX 4500
12 | DEFAULT_BATCH_BLOCK = 4
13 | DEFAULT_FEATURE_BLOCK = 128
14 | DEFAULT_NUM_WARPS = 16
15 | DEFAULT_NUM_STAGES = 1
16 |
17 |
18 | @triton.jit
19 | def compute_gelu_gate(x):
20 | """Compute the GELU gate Φ(x) := 0.5 * (1 + erf(x / sqrt(2)))"""
21 | return 0.5 * (1 + tl.erf(x.to(tl.float32) * 0.7071067811865475)).to(x.dtype)
22 |
23 |
24 | @triton.jit
25 | def compute_gelu_gate_grad(x):
26 | """Compute the gradient of the GELU gate = 1/sqrt(2pi) * exp(-x^2/2)"""
27 | return 0.3989422804 * tl.exp(-0.5 * x * x)
28 |
29 |
30 | @triton.jit
31 | def gelu_pairwise_kernel_fwd(
32 | x_ptr,
33 | y_ptr,
34 | pairwise_ptr,
35 | batch_size: tl.constexpr,
36 | n_features: tl.constexpr,
37 | BATCH_BLOCK: tl.constexpr,
38 | FEATURE_BLOCK: tl.constexpr,
39 | ):
40 | """Apply GELU non-linearity to inputs and compute pairwise products for geometric product."""
41 | batch_block_id = tl.program_id(axis=0)
42 | thread_block_id = tl.program_id(axis=1)
43 |
44 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
45 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
46 |
47 | batch_mask = batch_ids < batch_size
48 | feature_mask = feature_ids < n_features
49 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
50 |
51 | stride_component = batch_size * n_features
52 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
53 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
54 |
55 | x0 = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
56 | x1 = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
57 | x2 = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
58 | x3 = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
59 | x4 = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
60 | x5 = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
61 | x6 = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
62 | x7 = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
63 |
64 | y0 = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
65 | y1 = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
66 | y2 = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
67 | y3 = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
68 | y4 = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
69 | y5 = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
70 | y6 = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
71 | y7 = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
72 |
73 | gate_x = compute_gelu_gate(x0)
74 | gate_y = compute_gelu_gate(y0)
75 |
76 | x0 = x0 * gate_x
77 | x1 = x1 * gate_x
78 | x2 = x2 * gate_x
79 | x3 = x3 * gate_x
80 | x4 = x4 * gate_x
81 | x5 = x5 * gate_x
82 | x6 = x6 * gate_x
83 | x7 = x7 * gate_x
84 |
85 | y0 = y0 * gate_y
86 | y1 = y1 * gate_y
87 | y2 = y2 * gate_y
88 | y3 = y3 * gate_y
89 | y4 = y4 * gate_y
90 | y5 = y5 * gate_y
91 | y6 = y6 * gate_y
92 | y7 = y7 * gate_y
93 |
94 | p0 = x0*y0
95 | p1 = x0*y1
96 | p2 = x0*y2
97 | p3 = x0*y3
98 | p4 = x0*y4
99 | p5 = x0*y5
100 | p6 = x0*y6
101 | p7 = x0*y7
102 | p8 = x1*y1 + x2*y2 + x3*y3
103 | p9 = x1*y0
104 | p10 = x2*y0
105 | p11 = x3*y0
106 | p12 = -x2*y4 - x3*y5
107 | p13 = x1*y4 - x3*y6
108 | p14 = x1*y5 + x2*y6
109 | p15 = x1*y2 - x2*y1
110 | p16 = x1*y3 - x3*y1
111 | p17 = x2*y3 - x3*y2
112 | p18 = x3*y7
113 | p19 = -x2*y7
114 | p20 = x1*y7
115 | p21 = x1*y6 - x2*y5 + x3*y4
116 | p22 = -x4*y4 - x5*y5 - x6*y6
117 | p23 = x4*y2 + x5*y3
118 | p24 = -x4*y1 + x6*y3
119 | p25 = -x5*y1 - x6*y2
120 | p26 = -x6*y7
121 | p27 = x5*y7
122 | p28 = -x4*y7
123 | p29 = x4*y0
124 | p30 = x5*y0
125 | p31 = x6*y0
126 | p32 = -x5*y6 + x6*y5
127 | p33 = x4*y6 - x6*y4
128 | p34 = -x4*y5 + x5*y4
129 | p35 = x4*y3 - x5*y2 + x6*y1
130 | p36 = -x7*y7
131 | p37 = -x7*y6
132 | p38 = x7*y5
133 | p39 = -x7*y4
134 | p40 = x7*y3
135 | p41 = -x7*y2
136 | p42 = x7*y1
137 | p43 = x7*y0
138 |
139 | tl.store(pairwise_ptr + 0*batch_size*n_features + pairwise_offset, p0, mask=batch_feature_mask)
140 | tl.store(pairwise_ptr + 1*batch_size*n_features + pairwise_offset, p1, mask=batch_feature_mask)
141 | tl.store(pairwise_ptr + 2*batch_size*n_features + pairwise_offset, p2, mask=batch_feature_mask)
142 | tl.store(pairwise_ptr + 3*batch_size*n_features + pairwise_offset, p3, mask=batch_feature_mask)
143 | tl.store(pairwise_ptr + 4*batch_size*n_features + pairwise_offset, p4, mask=batch_feature_mask)
144 | tl.store(pairwise_ptr + 5*batch_size*n_features + pairwise_offset, p5, mask=batch_feature_mask)
145 | tl.store(pairwise_ptr + 6*batch_size*n_features + pairwise_offset, p6, mask=batch_feature_mask)
146 | tl.store(pairwise_ptr + 7*batch_size*n_features + pairwise_offset, p7, mask=batch_feature_mask)
147 | tl.store(pairwise_ptr + 8*batch_size*n_features + pairwise_offset, p8, mask=batch_feature_mask)
148 | tl.store(pairwise_ptr + 9*batch_size*n_features + pairwise_offset, p9, mask=batch_feature_mask)
149 | tl.store(pairwise_ptr + 10*batch_size*n_features + pairwise_offset, p10, mask=batch_feature_mask)
150 | tl.store(pairwise_ptr + 11*batch_size*n_features + pairwise_offset, p11, mask=batch_feature_mask)
151 | tl.store(pairwise_ptr + 12*batch_size*n_features + pairwise_offset, p12, mask=batch_feature_mask)
152 | tl.store(pairwise_ptr + 13*batch_size*n_features + pairwise_offset, p13, mask=batch_feature_mask)
153 | tl.store(pairwise_ptr + 14*batch_size*n_features + pairwise_offset, p14, mask=batch_feature_mask)
154 | tl.store(pairwise_ptr + 15*batch_size*n_features + pairwise_offset, p15, mask=batch_feature_mask)
155 | tl.store(pairwise_ptr + 16*batch_size*n_features + pairwise_offset, p16, mask=batch_feature_mask)
156 | tl.store(pairwise_ptr + 17*batch_size*n_features + pairwise_offset, p17, mask=batch_feature_mask)
157 | tl.store(pairwise_ptr + 18*batch_size*n_features + pairwise_offset, p18, mask=batch_feature_mask)
158 | tl.store(pairwise_ptr + 19*batch_size*n_features + pairwise_offset, p19, mask=batch_feature_mask)
159 | tl.store(pairwise_ptr + 20*batch_size*n_features + pairwise_offset, p20, mask=batch_feature_mask)
160 | tl.store(pairwise_ptr + 21*batch_size*n_features + pairwise_offset, p21, mask=batch_feature_mask)
161 | tl.store(pairwise_ptr + 22*batch_size*n_features + pairwise_offset, p22, mask=batch_feature_mask)
162 | tl.store(pairwise_ptr + 23*batch_size*n_features + pairwise_offset, p23, mask=batch_feature_mask)
163 | tl.store(pairwise_ptr + 24*batch_size*n_features + pairwise_offset, p24, mask=batch_feature_mask)
164 | tl.store(pairwise_ptr + 25*batch_size*n_features + pairwise_offset, p25, mask=batch_feature_mask)
165 | tl.store(pairwise_ptr + 26*batch_size*n_features + pairwise_offset, p26, mask=batch_feature_mask)
166 | tl.store(pairwise_ptr + 27*batch_size*n_features + pairwise_offset, p27, mask=batch_feature_mask)
167 | tl.store(pairwise_ptr + 28*batch_size*n_features + pairwise_offset, p28, mask=batch_feature_mask)
168 | tl.store(pairwise_ptr + 29*batch_size*n_features + pairwise_offset, p29, mask=batch_feature_mask)
169 | tl.store(pairwise_ptr + 30*batch_size*n_features + pairwise_offset, p30, mask=batch_feature_mask)
170 | tl.store(pairwise_ptr + 31*batch_size*n_features + pairwise_offset, p31, mask=batch_feature_mask)
171 | tl.store(pairwise_ptr + 32*batch_size*n_features + pairwise_offset, p32, mask=batch_feature_mask)
172 | tl.store(pairwise_ptr + 33*batch_size*n_features + pairwise_offset, p33, mask=batch_feature_mask)
173 | tl.store(pairwise_ptr + 34*batch_size*n_features + pairwise_offset, p34, mask=batch_feature_mask)
174 | tl.store(pairwise_ptr + 35*batch_size*n_features + pairwise_offset, p35, mask=batch_feature_mask)
175 | tl.store(pairwise_ptr + 36*batch_size*n_features + pairwise_offset, p36, mask=batch_feature_mask)
176 | tl.store(pairwise_ptr + 37*batch_size*n_features + pairwise_offset, p37, mask=batch_feature_mask)
177 | tl.store(pairwise_ptr + 38*batch_size*n_features + pairwise_offset, p38, mask=batch_feature_mask)
178 | tl.store(pairwise_ptr + 39*batch_size*n_features + pairwise_offset, p39, mask=batch_feature_mask)
179 | tl.store(pairwise_ptr + 40*batch_size*n_features + pairwise_offset, p40, mask=batch_feature_mask)
180 | tl.store(pairwise_ptr + 41*batch_size*n_features + pairwise_offset, p41, mask=batch_feature_mask)
181 | tl.store(pairwise_ptr + 42*batch_size*n_features + pairwise_offset, p42, mask=batch_feature_mask)
182 | tl.store(pairwise_ptr + 43*batch_size*n_features + pairwise_offset, p43, mask=batch_feature_mask)
183 |
184 |
185 | @triton.jit
186 | def assemble_kernel(
187 | transformed_ptr,
188 | pnorm_ptr,
189 | output_ptr,
190 | NORMALIZE: tl.constexpr,
191 | batch_size: tl.constexpr,
192 | n_features: tl.constexpr,
193 | BATCH_BLOCK: tl.constexpr,
194 | FEATURE_BLOCK: tl.constexpr,
195 | ):
196 | """
197 |     Gather the linearly transformed pairwise products into geometric product components and accumulate squared norms for grade-wise RMSNorm.
198 | """
199 | batch_block_id = tl.program_id(axis=0)
200 | thread_block_id = tl.program_id(axis=1)
201 |
202 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
203 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
204 |
205 | batch_mask = batch_ids < batch_size
206 | feature_mask = feature_ids < n_features
207 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
208 |
209 | stride_component = batch_size * n_features
210 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
211 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
212 |
213 | t0 = tl.load(transformed_ptr + 0 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
214 | t1 = tl.load(transformed_ptr + 1 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
215 | t2 = tl.load(transformed_ptr + 2 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
216 | t3 = tl.load(transformed_ptr + 3 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
217 | t4 = tl.load(transformed_ptr + 4 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
218 | t5 = tl.load(transformed_ptr + 5 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
219 | t6 = tl.load(transformed_ptr + 6 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
220 | t7 = tl.load(transformed_ptr + 7 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
221 | t8 = tl.load(transformed_ptr + 8 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
222 | t9 = tl.load(transformed_ptr + 9 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
223 | t10 = tl.load(transformed_ptr + 10 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
224 | t11 = tl.load(transformed_ptr + 11 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
225 | t12 = tl.load(transformed_ptr + 12 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
226 | t13 = tl.load(transformed_ptr + 13 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
227 | t14 = tl.load(transformed_ptr + 14 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
228 | t15 = tl.load(transformed_ptr + 15 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
229 | t16 = tl.load(transformed_ptr + 16 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
230 | t17 = tl.load(transformed_ptr + 17 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
231 | t18 = tl.load(transformed_ptr + 18 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
232 | t19 = tl.load(transformed_ptr + 19 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
233 | t20 = tl.load(transformed_ptr + 20 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
234 | t21 = tl.load(transformed_ptr + 21 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
235 | t22 = tl.load(transformed_ptr + 22 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
236 | t23 = tl.load(transformed_ptr + 23 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
237 | t24 = tl.load(transformed_ptr + 24 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
238 | t25 = tl.load(transformed_ptr + 25 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
239 | t26 = tl.load(transformed_ptr + 26 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
240 | t27 = tl.load(transformed_ptr + 27 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
241 | t28 = tl.load(transformed_ptr + 28 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
242 | t29 = tl.load(transformed_ptr + 29 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
243 | t30 = tl.load(transformed_ptr + 30 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
244 | t31 = tl.load(transformed_ptr + 31 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
245 | t32 = tl.load(transformed_ptr + 32 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
246 | t33 = tl.load(transformed_ptr + 33 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
247 | t34 = tl.load(transformed_ptr + 34 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
248 | t35 = tl.load(transformed_ptr + 35 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
249 | t36 = tl.load(transformed_ptr + 36 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
250 | t37 = tl.load(transformed_ptr + 37 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
251 | t38 = tl.load(transformed_ptr + 38 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
252 | t39 = tl.load(transformed_ptr + 39 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
253 | t40 = tl.load(transformed_ptr + 40 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
254 | t41 = tl.load(transformed_ptr + 41 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
255 | t42 = tl.load(transformed_ptr + 42 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
256 | t43 = tl.load(transformed_ptr + 43 * batch_size * n_features + transformed_offset, mask=batch_feature_mask)
257 |
258 | o0 = t0 + t8 + t22 + t36
259 | o1 = t1 + t9 + t12 + t23 + t26 + t37
260 | o2 = t2 + t10 + t13 + t24 + t27 + t38
261 | o3 = t3 + t11 + t14 + t25 + t28 + t39
262 | o4 = t4 + t15 + t18 + t29 + t32 + t40
263 | o5 = t5 + t16 + t19 + t30 + t33 + t41
264 | o6 = t6 + t17 + t20 + t31 + t34 + t42
265 | o7 = t7 + t21 + t35 + t43
266 |
267 | if NORMALIZE:
268 | pn_scalar = tl.sum(o0 * o0, axis=1) / n_features
269 | pn_vector = tl.sum(o1 * o1 + o2 * o2 + o3 * o3, axis=1) / n_features
270 | pn_bivect = tl.sum(o4 * o4 + o5 * o5 + o6 * o6, axis=1) / n_features
271 | pn_pseudo = tl.sum(o7 * o7, axis=1) / n_features
272 |
273 | tl.atomic_add(pnorm_ptr + 0*batch_size + batch_ids, pn_scalar, mask=batch_mask)
274 | tl.atomic_add(pnorm_ptr + 1*batch_size + batch_ids, pn_vector, mask=batch_mask)
275 | tl.atomic_add(pnorm_ptr + 2*batch_size + batch_ids, pn_vector, mask=batch_mask)
276 | tl.atomic_add(pnorm_ptr + 3*batch_size + batch_ids, pn_vector, mask=batch_mask)
277 | tl.atomic_add(pnorm_ptr + 4*batch_size + batch_ids, pn_bivect, mask=batch_mask)
278 | tl.atomic_add(pnorm_ptr + 5*batch_size + batch_ids, pn_bivect, mask=batch_mask)
279 | tl.atomic_add(pnorm_ptr + 6*batch_size + batch_ids, pn_bivect, mask=batch_mask)
280 | tl.atomic_add(pnorm_ptr + 7*batch_size + batch_ids, pn_pseudo, mask=batch_mask)
281 |
282 | tl.store(output_ptr + 0 * stride_component + base_offset, o0, mask=batch_feature_mask)
283 | tl.store(output_ptr + 1 * stride_component + base_offset, o1, mask=batch_feature_mask)
284 | tl.store(output_ptr + 2 * stride_component + base_offset, o2, mask=batch_feature_mask)
285 | tl.store(output_ptr + 3 * stride_component + base_offset, o3, mask=batch_feature_mask)
286 | tl.store(output_ptr + 4 * stride_component + base_offset, o4, mask=batch_feature_mask)
287 | tl.store(output_ptr + 5 * stride_component + base_offset, o5, mask=batch_feature_mask)
288 | tl.store(output_ptr + 6 * stride_component + base_offset, o6, mask=batch_feature_mask)
289 | tl.store(output_ptr + 7 * stride_component + base_offset, o7, mask=batch_feature_mask)
290 |
291 |
292 | @triton.jit
293 | def normalize_with_sqrt_kernel(
294 | output_ptr,
295 | pnorm_ptr,
296 | batch_size: tl.constexpr,
297 | n_features: tl.constexpr,
298 | BATCH_BLOCK: tl.constexpr,
299 | FEATURE_BLOCK: tl.constexpr,
300 | MV_DIM: tl.constexpr,
301 | EPS: tl.constexpr,
302 | ):
303 | """Normalize the output by dividing each grade with root of its accumulated norm."""
304 | batch_block_id = tl.program_id(axis=0)
305 | thread_block_id = tl.program_id(axis=1)
306 |
307 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
308 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
309 |
310 | batch_mask = batch_ids < batch_size
311 | feature_mask = feature_ids < n_features
312 | batch_feature_mask = batch_mask[:, None, None] & feature_mask[None, :, None]
313 |
314 | component_ids = tl.arange(0, MV_DIM)[None, None, :]
315 |
316 | feature_offset = (component_ids * batch_size * n_features +
317 | batch_ids[:, None, None] * n_features +
318 | feature_ids[None, :, None])
319 |
320 | norm_indices = component_ids * batch_size + batch_ids[:, None, None]
321 |
322 | pnorm = tl.load(pnorm_ptr + norm_indices, mask=batch_mask[:, None, None])
323 | mv = tl.load(output_ptr + feature_offset, mask=batch_feature_mask)
324 |
325 | norm = tl.sqrt(pnorm + EPS)
326 | mv_normalized = mv / norm
327 |
328 | tl.store(output_ptr + feature_offset, mv_normalized, mask=batch_feature_mask)
329 |
330 |
331 | def gelu_fc_geometric_product_norm_fwd(
332 | x: torch.Tensor,
333 | y: torch.Tensor,
334 | weight: torch.Tensor,
335 | expansion_indices: torch.Tensor,
336 | normalize: bool,
337 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
338 | """Fused operation: GELU non-linearity, fully connected geometric product, and grade-wise RMSNorm."""
339 | assert x.shape == y.shape
340 | assert x.shape[0] == MV_DIM
341 | assert x.shape[2] == weight.shape[1] == weight.shape[2]
342 | assert weight.shape[0] == NUM_PRODUCT_WEIGHTS
343 |
344 | _, B, N = x.shape
345 |
346 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
347 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
348 |
349 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
350 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
351 |
352 | pairwise = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype)
353 | partial_norm = (torch.zeros((MV_DIM, B), device=x.device, dtype=x.dtype) if normalize else torch.zeros((1,), device=x.device, dtype=x.dtype))
354 | output = torch.empty_like(x)
355 |
356 | grid = (num_blocks_batch, num_blocks_features)
357 |
358 | gelu_pairwise_kernel_fwd[grid](
359 | x,
360 | y,
361 | pairwise,
362 | B,
363 | N,
364 | BATCH_BLOCK,
365 | FEATURE_BLOCK,
366 | num_warps=DEFAULT_NUM_WARPS,
367 | num_stages=DEFAULT_NUM_STAGES,
368 | )
369 |
370 | transformed = torch.bmm(pairwise, weight[expansion_indices])
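    # pairwise: (44, B, N), weight[expansion_indices]: (44, N, N) -> transformed: (44, B, N);
    # each pairwise term is mixed across features by its own N x N matrix (fully connected product).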
371 |
372 | assemble_kernel[grid](
373 | transformed,
374 | partial_norm,
375 | output,
376 | normalize,
377 | B,
378 | N,
379 | BATCH_BLOCK,
380 | FEATURE_BLOCK,
381 | num_warps=DEFAULT_NUM_WARPS,
382 | num_stages=DEFAULT_NUM_STAGES,
383 | )
384 |
385 | if normalize:
386 | normalize_with_sqrt_kernel[grid](
387 | output,
388 | partial_norm,
389 | B,
390 | N,
391 | BATCH_BLOCK,
392 | FEATURE_BLOCK,
393 | MV_DIM,
394 | EPS,
395 | num_warps=DEFAULT_NUM_WARPS,
396 | num_stages=DEFAULT_NUM_STAGES,
397 | )
398 |
399 | return output, pairwise, partial_norm
400 |
401 |
402 | @triton.jit
403 | def grad_o_dot_o_kernel(
404 | dot_ptr,
405 | pnorm_ptr,
406 | output_ptr,
407 | grad_output_ptr,
408 | batch_size: tl.constexpr,
409 | n_features: tl.constexpr,
410 | BATCH_BLOCK: tl.constexpr,
411 | FEATURE_BLOCK: tl.constexpr,
412 | EPS: tl.constexpr,
413 | ):
414 | """Compute the dot product of grad_output and output for each grade, accumulate across all features."""
415 | batch_block_id = tl.program_id(axis=0)
416 | thread_block_id = tl.program_id(axis=1)
417 |
418 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
419 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
420 |
421 | batch_mask = batch_ids < batch_size
422 | feature_mask = feature_ids < n_features
423 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
424 |
425 | stride_component = batch_size * n_features
426 | offset = batch_ids[:, None] * n_features + feature_ids[None, :]
427 |
428 | go0 = tl.load(grad_output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
429 | go1 = tl.load(grad_output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
430 | go2 = tl.load(grad_output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
431 | go3 = tl.load(grad_output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
432 | go4 = tl.load(grad_output_ptr + 4 * stride_component + offset, mask=batch_feature_mask)
433 | go5 = tl.load(grad_output_ptr + 5 * stride_component + offset, mask=batch_feature_mask)
434 | go6 = tl.load(grad_output_ptr + 6 * stride_component + offset, mask=batch_feature_mask)
435 | go7 = tl.load(grad_output_ptr + 7 * stride_component + offset, mask=batch_feature_mask)
436 |
437 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
438 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
439 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None]
440 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None]
441 |
442 | o0 = tl.load(output_ptr + 0 * stride_component + offset, mask=batch_feature_mask)
443 | o1 = tl.load(output_ptr + 1 * stride_component + offset, mask=batch_feature_mask)
444 | o2 = tl.load(output_ptr + 2 * stride_component + offset, mask=batch_feature_mask)
445 | o3 = tl.load(output_ptr + 3 * stride_component + offset, mask=batch_feature_mask)
446 | o4 = tl.load(output_ptr + 4 * stride_component + offset, mask=batch_feature_mask)
447 | o5 = tl.load(output_ptr + 5 * stride_component + offset, mask=batch_feature_mask)
448 | o6 = tl.load(output_ptr + 6 * stride_component + offset, mask=batch_feature_mask)
449 | o7 = tl.load(output_ptr + 7 * stride_component + offset, mask=batch_feature_mask)
450 |
451 | rms_scalar = tl.sqrt(pn_scalar + EPS)
452 | rms_vector = tl.sqrt(pn_vector + EPS)
453 | rms_bivect = tl.sqrt(pn_bivect + EPS)
454 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
455 |
456 | dot_scalar = tl.sum(rms_scalar * go0 * o0, axis=1)
457 | dot_vector = tl.sum(rms_vector * (go1*o1 + go2*o2 + go3*o3), axis=1)
458 | dot_bivect = tl.sum(rms_bivect * (go4*o4 + go5*o5 + go6*o6), axis=1)
459 | dot_pseudo = tl.sum(rms_pseudo * go7 * o7, axis=1)
460 |
461 | tl.atomic_add(dot_ptr + 0*batch_size + batch_ids, dot_scalar, mask=batch_mask)
462 | tl.atomic_add(dot_ptr + 1*batch_size + batch_ids, dot_vector, mask=batch_mask)
463 | tl.atomic_add(dot_ptr + 2*batch_size + batch_ids, dot_bivect, mask=batch_mask)
464 | tl.atomic_add(dot_ptr + 3*batch_size + batch_ids, dot_pseudo, mask=batch_mask)
465 |
466 |
467 | @triton.jit
468 | def disassemble_kernel(
469 | grad_output_ptr,
470 | output_ptr,
471 | dot_ptr,
472 | grad_transformed_ptr,
473 | pnorm_ptr,
474 | NORMALIZE: tl.constexpr,
475 | batch_size: tl.constexpr,
476 | n_features: tl.constexpr,
477 | BATCH_BLOCK: tl.constexpr,
478 | FEATURE_BLOCK: tl.constexpr,
479 | EPS: tl.constexpr,
480 | ):
481 | """Gather linearly transformed pairwise products and compute the geometric product."""
482 | batch_block_id = tl.program_id(axis=0)
483 | thread_block_id = tl.program_id(axis=1)
484 |
485 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
486 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
487 |
488 | batch_mask = batch_ids < batch_size
489 | feature_mask = feature_ids < n_features
490 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
491 |
492 | stride_component = batch_size * n_features
493 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
494 | transformed_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
495 |
496 | go0 = tl.load(grad_output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
497 | go1 = tl.load(grad_output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
498 | go2 = tl.load(grad_output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
499 | go3 = tl.load(grad_output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
500 | go4 = tl.load(grad_output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
501 | go5 = tl.load(grad_output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
502 | go6 = tl.load(grad_output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
503 | go7 = tl.load(grad_output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
504 |
505 | if NORMALIZE:
506 | o0 = tl.load(output_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
507 | o1 = tl.load(output_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
508 | o2 = tl.load(output_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
509 | o3 = tl.load(output_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
510 | o4 = tl.load(output_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
511 | o5 = tl.load(output_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
512 | o6 = tl.load(output_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
513 | o7 = tl.load(output_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
514 |
515 | pn_scalar = tl.load(pnorm_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
516 | pn_vector = tl.load(pnorm_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
517 | pn_bivect = tl.load(pnorm_ptr + 4*batch_size + batch_ids, mask=batch_mask)[:, None]
518 | pn_pseudo = tl.load(pnorm_ptr + 7*batch_size + batch_ids, mask=batch_mask)[:, None]
519 |
520 | dot_scalar = tl.load(dot_ptr + 0*batch_size + batch_ids, mask=batch_mask)[:, None]
521 | dot_vector = tl.load(dot_ptr + 1*batch_size + batch_ids, mask=batch_mask)[:, None]
522 | dot_bivect = tl.load(dot_ptr + 2*batch_size + batch_ids, mask=batch_mask)[:, None]
523 | dot_pseudo = tl.load(dot_ptr + 3*batch_size + batch_ids, mask=batch_mask)[:, None]
524 |
525 | rms_scalar = tl.sqrt(pn_scalar + EPS)
526 | rms_vector = tl.sqrt(pn_vector + EPS)
527 | rms_bivect = tl.sqrt(pn_bivect + EPS)
528 | rms_pseudo = tl.sqrt(pn_pseudo + EPS)
529 |
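# RMSNorm backward: with o = t / rms and dot = sum_features(go * t), the gradient with
# respect to the pre-normalization product t is go / rms - o * dot / (n_features * rms^2),
# applied independently per grade below.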
530 | go0 = go0 / rms_scalar - o0 * dot_scalar / (n_features * rms_scalar * rms_scalar)
531 | go1 = go1 / rms_vector - o1 * dot_vector / (n_features * rms_vector * rms_vector)
532 | go2 = go2 / rms_vector - o2 * dot_vector / (n_features * rms_vector * rms_vector)
533 | go3 = go3 / rms_vector - o3 * dot_vector / (n_features * rms_vector * rms_vector)
534 | go4 = go4 / rms_bivect - o4 * dot_bivect / (n_features * rms_bivect * rms_bivect)
535 | go5 = go5 / rms_bivect - o5 * dot_bivect / (n_features * rms_bivect * rms_bivect)
536 | go6 = go6 / rms_bivect - o6 * dot_bivect / (n_features * rms_bivect * rms_bivect)
537 | go7 = go7 / rms_pseudo - o7 * dot_pseudo / (n_features * rms_pseudo * rms_pseudo)
538 |
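# Scatter the 8 component gradients into the 44 expansion slots so that each slot lines
# up with the (component, weight) pairing used in the forward pass; the host code then
# multiplies by the transposed weights to obtain per-pairwise-product gradients.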
539 | tl.store(grad_transformed_ptr + 0 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask)
540 | tl.store(grad_transformed_ptr + 1 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask)
541 | tl.store(grad_transformed_ptr + 2 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask)
542 | tl.store(grad_transformed_ptr + 3 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask)
543 | tl.store(grad_transformed_ptr + 4 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask)
544 | tl.store(grad_transformed_ptr + 5 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask)
545 | tl.store(grad_transformed_ptr + 6 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask)
546 | tl.store(grad_transformed_ptr + 7 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask)
547 | tl.store(grad_transformed_ptr + 8 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask)
548 | tl.store(grad_transformed_ptr + 9 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask)
549 | tl.store(grad_transformed_ptr + 10 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask)
550 | tl.store(grad_transformed_ptr + 11 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask)
551 | tl.store(grad_transformed_ptr + 12 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask)
552 | tl.store(grad_transformed_ptr + 13 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask)
553 | tl.store(grad_transformed_ptr + 14 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask)
554 | tl.store(grad_transformed_ptr + 15 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask)
555 | tl.store(grad_transformed_ptr + 16 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask)
556 | tl.store(grad_transformed_ptr + 17 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask)
557 | tl.store(grad_transformed_ptr + 18 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask)
558 | tl.store(grad_transformed_ptr + 19 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask)
559 | tl.store(grad_transformed_ptr + 20 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask)
560 | tl.store(grad_transformed_ptr + 21 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask)
561 | tl.store(grad_transformed_ptr + 22 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask)
562 | tl.store(grad_transformed_ptr + 23 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask)
563 | tl.store(grad_transformed_ptr + 24 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask)
564 | tl.store(grad_transformed_ptr + 25 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask)
565 | tl.store(grad_transformed_ptr + 26 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask)
566 | tl.store(grad_transformed_ptr + 27 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask)
567 | tl.store(grad_transformed_ptr + 28 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask)
568 | tl.store(grad_transformed_ptr + 29 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask)
569 | tl.store(grad_transformed_ptr + 30 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask)
570 | tl.store(grad_transformed_ptr + 31 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask)
571 | tl.store(grad_transformed_ptr + 32 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask)
572 | tl.store(grad_transformed_ptr + 33 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask)
573 | tl.store(grad_transformed_ptr + 34 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask)
574 | tl.store(grad_transformed_ptr + 35 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask)
575 | tl.store(grad_transformed_ptr + 36 * batch_size * n_features + transformed_offset, go0, mask=batch_feature_mask)
576 | tl.store(grad_transformed_ptr + 37 * batch_size * n_features + transformed_offset, go1, mask=batch_feature_mask)
577 | tl.store(grad_transformed_ptr + 38 * batch_size * n_features + transformed_offset, go2, mask=batch_feature_mask)
578 | tl.store(grad_transformed_ptr + 39 * batch_size * n_features + transformed_offset, go3, mask=batch_feature_mask)
579 | tl.store(grad_transformed_ptr + 40 * batch_size * n_features + transformed_offset, go4, mask=batch_feature_mask)
580 | tl.store(grad_transformed_ptr + 41 * batch_size * n_features + transformed_offset, go5, mask=batch_feature_mask)
581 | tl.store(grad_transformed_ptr + 42 * batch_size * n_features + transformed_offset, go6, mask=batch_feature_mask)
582 | tl.store(grad_transformed_ptr + 43 * batch_size * n_features + transformed_offset, go7, mask=batch_feature_mask)
583 |
584 |
585 | @triton.jit
586 | def gelu_pairwise_kernel_bwd(
587 | x_ptr,
588 | y_ptr,
589 | grad_pairwise_ptr,
590 | grad_x_ptr,
591 | grad_y_ptr,
592 | batch_size: tl.constexpr,
593 | n_features: tl.constexpr,
594 | BATCH_BLOCK: tl.constexpr,
595 | FEATURE_BLOCK: tl.constexpr,
596 | ):
597 | """
598 | Apply GELU non-linearity to inputs and compute required pairwise products.
599 | """
600 | batch_block_id = tl.program_id(axis=0)
601 | thread_block_id = tl.program_id(axis=1)
602 |
603 | batch_ids = batch_block_id*BATCH_BLOCK + tl.arange(0, BATCH_BLOCK)
604 | feature_ids = thread_block_id*FEATURE_BLOCK + tl.arange(0, FEATURE_BLOCK)
605 |
606 | batch_mask = batch_ids < batch_size
607 | feature_mask = feature_ids < n_features
608 | batch_feature_mask = batch_mask[:, None] & feature_mask[None, :]
609 |
610 | stride_component = batch_size * n_features
611 | base_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
612 | pairwise_offset = batch_ids[:, None] * n_features + feature_ids[None, :]
613 |
614 | gp0 = tl.load(grad_pairwise_ptr + 0 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
615 | gp1 = tl.load(grad_pairwise_ptr + 1 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
616 | gp2 = tl.load(grad_pairwise_ptr + 2 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
617 | gp3 = tl.load(grad_pairwise_ptr + 3 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
618 | gp4 = tl.load(grad_pairwise_ptr + 4 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
619 | gp5 = tl.load(grad_pairwise_ptr + 5 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
620 | gp6 = tl.load(grad_pairwise_ptr + 6 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
621 | gp7 = tl.load(grad_pairwise_ptr + 7 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
622 | gp8 = tl.load(grad_pairwise_ptr + 8 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
623 | gp9 = tl.load(grad_pairwise_ptr + 9 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
624 | gp10 = tl.load(grad_pairwise_ptr + 10 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
625 | gp11 = tl.load(grad_pairwise_ptr + 11 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
626 | gp12 = tl.load(grad_pairwise_ptr + 12 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
627 | gp13 = tl.load(grad_pairwise_ptr + 13 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
628 | gp14 = tl.load(grad_pairwise_ptr + 14 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
629 | gp15 = tl.load(grad_pairwise_ptr + 15 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
630 | gp16 = tl.load(grad_pairwise_ptr + 16 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
631 | gp17 = tl.load(grad_pairwise_ptr + 17 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
632 | gp18 = tl.load(grad_pairwise_ptr + 18 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
633 | gp19 = tl.load(grad_pairwise_ptr + 19 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
634 | gp20 = tl.load(grad_pairwise_ptr + 20 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
635 | gp21 = tl.load(grad_pairwise_ptr + 21 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
636 | gp22 = tl.load(grad_pairwise_ptr + 22 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
637 | gp23 = tl.load(grad_pairwise_ptr + 23 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
638 | gp24 = tl.load(grad_pairwise_ptr + 24 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
639 | gp25 = tl.load(grad_pairwise_ptr + 25 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
640 | gp26 = tl.load(grad_pairwise_ptr + 26 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
641 | gp27 = tl.load(grad_pairwise_ptr + 27 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
642 | gp28 = tl.load(grad_pairwise_ptr + 28 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
643 | gp29 = tl.load(grad_pairwise_ptr + 29 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
644 | gp30 = tl.load(grad_pairwise_ptr + 30 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
645 | gp31 = tl.load(grad_pairwise_ptr + 31 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
646 | gp32 = tl.load(grad_pairwise_ptr + 32 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
647 | gp33 = tl.load(grad_pairwise_ptr + 33 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
648 | gp34 = tl.load(grad_pairwise_ptr + 34 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
649 | gp35 = tl.load(grad_pairwise_ptr + 35 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
650 | gp36 = tl.load(grad_pairwise_ptr + 36 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
651 | gp37 = tl.load(grad_pairwise_ptr + 37 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
652 | gp38 = tl.load(grad_pairwise_ptr + 38 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
653 | gp39 = tl.load(grad_pairwise_ptr + 39 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
654 | gp40 = tl.load(grad_pairwise_ptr + 40 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
655 | gp41 = tl.load(grad_pairwise_ptr + 41 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
656 | gp42 = tl.load(grad_pairwise_ptr + 42 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
657 | gp43 = tl.load(grad_pairwise_ptr + 43 * batch_size * n_features + pairwise_offset, mask=batch_feature_mask)
658 |
659 | x0_raw = tl.load(x_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
660 | x1_raw = tl.load(x_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
661 | x2_raw = tl.load(x_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
662 | x3_raw = tl.load(x_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
663 | x4_raw = tl.load(x_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
664 | x5_raw = tl.load(x_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
665 | x6_raw = tl.load(x_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
666 | x7_raw = tl.load(x_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
667 |
668 | y0_raw = tl.load(y_ptr + 0 * stride_component + base_offset, mask=batch_feature_mask)
669 | y1_raw = tl.load(y_ptr + 1 * stride_component + base_offset, mask=batch_feature_mask)
670 | y2_raw = tl.load(y_ptr + 2 * stride_component + base_offset, mask=batch_feature_mask)
671 | y3_raw = tl.load(y_ptr + 3 * stride_component + base_offset, mask=batch_feature_mask)
672 | y4_raw = tl.load(y_ptr + 4 * stride_component + base_offset, mask=batch_feature_mask)
673 | y5_raw = tl.load(y_ptr + 5 * stride_component + base_offset, mask=batch_feature_mask)
674 | y6_raw = tl.load(y_ptr + 6 * stride_component + base_offset, mask=batch_feature_mask)
675 | y7_raw = tl.load(y_ptr + 7 * stride_component + base_offset, mask=batch_feature_mask)
676 |
677 | # collect gradients from pairwise products
678 | gate_x = compute_gelu_gate(x0_raw)
679 | gate_y = compute_gelu_gate(y0_raw)
680 |
681 | x0 = x0_raw * gate_x
682 | x1 = x1_raw * gate_x
683 | x2 = x2_raw * gate_x
684 | x3 = x3_raw * gate_x
685 | x4 = x4_raw * gate_x
686 | x5 = x5_raw * gate_x
687 | x6 = x6_raw * gate_x
688 | x7 = x7_raw * gate_x
689 |
690 | y0 = y0_raw * gate_y
691 | y1 = y1_raw * gate_y
692 | y2 = y2_raw * gate_y
693 | y3 = y3_raw * gate_y
694 | y4 = y4_raw * gate_y
695 | y5 = y5_raw * gate_y
696 | y6 = y6_raw * gate_y
697 | y7 = y7_raw * gate_y
698 |
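# Adjoint of the pairwise-product map: each gp_s is the gradient of a forward product
# x_i * y_j, so it is routed back to x_i (weighted by y_j) and to y_j (weighted by x_i)
# with the same Cl(3,0) sign pattern as the forward kernel.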
699 | x_grad_0 = gp0*y0 + gp1*y1 + gp2*y2 + gp3*y3 + gp4*y4 + gp5*y5 + gp6*y6 + gp7*y7
700 | x_grad_1 = gp8*y1 + gp9*y0 + gp13*y4 + gp14*y5 + gp15*y2 + gp16*y3 + gp20*y7 + gp21*y6
701 | x_grad_2 = gp8*y2 + gp10*y0 - gp12*y4 + gp14*y6 - gp15*y1 + gp17*y3 - gp19*y7 - gp21*y5
702 | x_grad_3 = gp8*y3 + gp11*y0 - gp12*y5 - gp13*y6 - gp16*y1 - gp17*y2 + gp18*y7 + gp21*y4
703 | x_grad_4 = -gp22*y4 + gp23*y2 - gp24*y1 - gp28*y7 + gp29*y0 + gp33*y6 - gp34*y5 + gp35*y3
704 | x_grad_5 = -gp22*y5 + gp23*y3 - gp25*y1 + gp27*y7 + gp30*y0 - gp32*y6 + gp34*y4 - gp35*y2
705 | x_grad_6 = -gp22*y6 + gp24*y3 - gp25*y2 - gp26*y7 + gp31*y0 + gp32*y5 - gp33*y4 + gp35*y1
706 | x_grad_7 = -gp36*y7 - gp37*y6 + gp38*y5 - gp39*y4 + gp40*y3 - gp41*y2 + gp42*y1 + gp43*y0
707 |
708 | y_grad_0 = gp0*x0 + gp9*x1 + gp10*x2 + gp11*x3 + gp29*x4 + gp30*x5 + gp31*x6 + gp43*x7
709 | y_grad_1 = gp1*x0 + gp8*x1 - gp15*x2 - gp16*x3 - gp24*x4 - gp25*x5 + gp35*x6 + gp42*x7
710 | y_grad_2 = gp2*x0 + gp8*x2 + gp15*x1 - gp17*x3 + gp23*x4 - gp35*x5 - gp25*x6 - gp41*x7
711 | y_grad_3 = gp3*x0 + gp8*x3 + gp16*x1 + gp17*x2 + gp23*x5 + gp24*x6 + gp35*x4 + gp40*x7
712 | y_grad_4 = gp4*x0 + gp13*x1 - gp12*x2 + gp21*x3 - gp22*x4 + gp34*x5 - gp33*x6 - gp39*x7
713 | y_grad_5 = gp5*x0 + gp14*x1 - gp21*x2 - gp12*x3 - gp22*x5 - gp34*x4 + gp32*x6 + gp38*x7
714 | y_grad_6 = gp6*x0 - gp13*x3 + gp14*x2 + gp21*x1 - gp22*x6 - gp32*x5 + gp33*x4 - gp37*x7
715 | y_grad_7 = gp7*x0 + gp20*x1 - gp19*x2 + gp18*x3 - gp28*x4 + gp27*x5 - gp26*x6 - gp36*x7
716 |
717 | # GELU gate gradients
718 | dgate_x = compute_gelu_gate_grad(x0_raw)
719 | dgate_y = compute_gelu_gate_grad(y0_raw)
720 |
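# Chain rule through the gate: every component x_k equals x_k_raw * gate(x0_raw), so the
# scalar input x0_raw receives gate_x * x_grad_0 plus dgate_x times the sum of
# x_k_raw * x_grad_k over all components (x_grad_k here are still the gradients with
# respect to the gated components), while the non-scalar components are simply rescaled
# by the gate; y is treated the same way.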
721 | x_grad_0 = (gate_x + x0_raw*dgate_x) * x_grad_0 + dgate_x * (x1_raw*x_grad_1 + x2_raw*x_grad_2 + x3_raw*x_grad_3 +
722 | x4_raw*x_grad_4 + x5_raw*x_grad_5 + x6_raw*x_grad_6 +
723 | x7_raw*x_grad_7)
724 | x_grad_1 = gate_x * x_grad_1
725 | x_grad_2 = gate_x * x_grad_2
726 | x_grad_3 = gate_x * x_grad_3
727 | x_grad_4 = gate_x * x_grad_4
728 | x_grad_5 = gate_x * x_grad_5
729 | x_grad_6 = gate_x * x_grad_6
730 | x_grad_7 = gate_x * x_grad_7
731 |
732 | y_grad_0 = (gate_y + y0_raw*dgate_y) * y_grad_0 + dgate_y * (y1_raw*y_grad_1 + y2_raw*y_grad_2 + y3_raw*y_grad_3 +
733 | y4_raw*y_grad_4 + y5_raw*y_grad_5 + y6_raw*y_grad_6 +
734 | y7_raw*y_grad_7)
735 | y_grad_1 = gate_y * y_grad_1
736 | y_grad_2 = gate_y * y_grad_2
737 | y_grad_3 = gate_y * y_grad_3
738 | y_grad_4 = gate_y * y_grad_4
739 | y_grad_5 = gate_y * y_grad_5
740 | y_grad_6 = gate_y * y_grad_6
741 | y_grad_7 = gate_y * y_grad_7
742 |
743 | tl.store(grad_x_ptr + 0 * stride_component + base_offset, x_grad_0, mask=batch_feature_mask)
744 | tl.store(grad_x_ptr + 1 * stride_component + base_offset, x_grad_1, mask=batch_feature_mask)
745 | tl.store(grad_x_ptr + 2 * stride_component + base_offset, x_grad_2, mask=batch_feature_mask)
746 | tl.store(grad_x_ptr + 3 * stride_component + base_offset, x_grad_3, mask=batch_feature_mask)
747 | tl.store(grad_x_ptr + 4 * stride_component + base_offset, x_grad_4, mask=batch_feature_mask)
748 | tl.store(grad_x_ptr + 5 * stride_component + base_offset, x_grad_5, mask=batch_feature_mask)
749 | tl.store(grad_x_ptr + 6 * stride_component + base_offset, x_grad_6, mask=batch_feature_mask)
750 | tl.store(grad_x_ptr + 7 * stride_component + base_offset, x_grad_7, mask=batch_feature_mask)
751 |
752 | tl.store(grad_y_ptr + 0 * stride_component + base_offset, y_grad_0, mask=batch_feature_mask)
753 | tl.store(grad_y_ptr + 1 * stride_component + base_offset, y_grad_1, mask=batch_feature_mask)
754 | tl.store(grad_y_ptr + 2 * stride_component + base_offset, y_grad_2, mask=batch_feature_mask)
755 | tl.store(grad_y_ptr + 3 * stride_component + base_offset, y_grad_3, mask=batch_feature_mask)
756 | tl.store(grad_y_ptr + 4 * stride_component + base_offset, y_grad_4, mask=batch_feature_mask)
757 | tl.store(grad_y_ptr + 5 * stride_component + base_offset, y_grad_5, mask=batch_feature_mask)
758 | tl.store(grad_y_ptr + 6 * stride_component + base_offset, y_grad_6, mask=batch_feature_mask)
759 | tl.store(grad_y_ptr + 7 * stride_component + base_offset, y_grad_7, mask=batch_feature_mask)
760 |
761 |
762 | def gelu_fc_geometric_product_norm_bwd(
763 | x: torch.Tensor,
764 | y: torch.Tensor,
765 | weight: torch.Tensor,
766 | o: torch.Tensor,
767 | pairwise: torch.Tensor,
768 | partial_norm: torch.Tensor,
769 | grad_output: torch.Tensor,
770 | expansion_indices: torch.Tensor,
771 | normalize: bool,
772 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773 | """Backward pass for the fused GELU + fully connected geometric product + RMSNorm operation; returns (grad_x, grad_y, grad_weight)."""
774 | _, B, N = x.shape
775 |
776 | BATCH_BLOCK = min(DEFAULT_BATCH_BLOCK, B)
777 | FEATURE_BLOCK = min(DEFAULT_FEATURE_BLOCK, N)
778 |
779 | num_blocks_batch = triton.cdiv(B, BATCH_BLOCK)
780 | num_blocks_features = triton.cdiv(N, FEATURE_BLOCK)
781 |
782 | grad_x = torch.zeros_like(x)
783 | grad_y = torch.zeros_like(y)
784 | dot = (torch.zeros((NUM_GRADES, B), device=x.device, dtype=x.dtype) if normalize else torch.empty(0, device=x.device, dtype=x.dtype))
785 | grad_weight = torch.zeros((NUM_PRODUCT_WEIGHTS, N, N), device=x.device, dtype=weight.dtype)
786 | grad_transformed = torch.empty((len(WEIGHT_EXPANSION), B, N), device=x.device, dtype=x.dtype)
787 |
788 | grid = (num_blocks_batch, num_blocks_features)
789 |
790 | if normalize:
791 | grad_o_dot_o_kernel[grid](
792 | dot,
793 | partial_norm,
794 | o,
795 | grad_output,
796 | B,
797 | N,
798 | BATCH_BLOCK,
799 | FEATURE_BLOCK,
800 | EPS,
801 | num_warps=DEFAULT_NUM_WARPS,
802 | num_stages=DEFAULT_NUM_STAGES,
803 | )
804 |
805 | disassemble_kernel[grid](
806 | grad_output,
807 | o,
808 | dot,
809 | grad_transformed,
810 | partial_norm,
811 | normalize,
812 | B,
813 | N,
814 | BATCH_BLOCK,
815 | FEATURE_BLOCK,
816 | EPS,
817 | num_warps=DEFAULT_NUM_WARPS,
818 | num_stages=DEFAULT_NUM_STAGES,
819 | )
820 |
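# Matmul adjoints of the forward transform transformed_s = pairwise_s @ weight[e_s]
# (e_s = expansion_indices[s]): grad_pairwise_s = grad_transformed_s @ weight[e_s]^T,
# and grad_weight[e_s] accumulates pairwise_s^T @ grad_transformed_s via index_add_.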
821 | grad_pairwise = torch.bmm(grad_transformed, weight[expansion_indices].transpose(-2, -1))
822 |
823 | grad_weight.index_add_(0, expansion_indices, torch.bmm(pairwise.transpose(-2, -1), grad_transformed))
824 |
825 | gelu_pairwise_kernel_bwd[grid](
826 | x,
827 | y,
828 | grad_pairwise,
829 | grad_x,
830 | grad_y,
831 | B,
832 | N,
833 | BATCH_BLOCK,
834 | FEATURE_BLOCK,
835 | num_warps=DEFAULT_NUM_WARPS,
836 | num_stages=DEFAULT_NUM_STAGES,
837 | )
838 |
839 | return grad_x, grad_y, grad_weight
840 |
841 |
842 | class FullyConnectedGeluGeometricProductNorm3D(torch.autograd.Function):
843 |
844 | @staticmethod
845 | @torch.amp.custom_fwd(device_type="cuda")
846 | def forward(ctx, x, y, weight, normalize):
847 | assert x.is_contiguous() and y.is_contiguous() and weight.is_contiguous()
848 |
849 | ctx.dtype = x.dtype
850 | ctx.normalize = normalize
851 |
852 | expansion_indices = torch.tensor(WEIGHT_EXPANSION, device=x.device)
853 |
854 | o, pairwise, partial_norm = gelu_fc_geometric_product_norm_fwd(
855 | x,
856 | y,
857 | weight,
858 | expansion_indices,
859 | normalize,
860 | )
861 |
862 | ctx.save_for_backward(x, y, weight, o, pairwise, partial_norm, expansion_indices)
863 |
864 | return o.to(x.dtype)
865 |
866 | @staticmethod
867 | @torch.amp.custom_bwd(device_type="cuda")
868 | def backward(ctx, grad_output):
869 | grad_output = grad_output.contiguous()
870 |
871 | x, y, weight, o, pairwise, partial_norm, expansion_indices = ctx.saved_tensors
872 |
873 | grad_x, grad_y, grad_weight = gelu_fc_geometric_product_norm_bwd(
874 | x,
875 | y,
876 | weight,
877 | o,
878 | pairwise,
879 | partial_norm,
880 | grad_output,
881 | expansion_indices,
882 | ctx.normalize,
883 | )
884 |
885 | return grad_x, grad_y, grad_weight, None
886 |
887 |
888 | def fused_gelu_fcgp_norm_3d(x, y, weight, normalize=True):
889 | """
890 | Fused operation that applies a GELU non-linearity to two multivector inputs,
891 | then computes their fully connected geometric product, and optionally applies a grade-wise RMSNorm.
892 |
893 | Clifford algebra is assumed to be Cl(3,0).
894 |
895 | Args:
896 | x (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
897 | y (torch.Tensor): Input tensor of shape (MV_DIM, B, N).
898 | weight (torch.Tensor): Weight tensor of shape (NUM_PRODUCT_WEIGHTS, N, N), one weight per geometric product component.
899 | normalize (bool): Whether to apply RMSNorm after the geometric product.
900 |
901 | Returns:
902 | torch.Tensor: Output tensor of shape (MV_DIM, B, N) after applying the fused operation.
903 | """
904 | return FullyConnectedGeluGeometricProductNorm3D.apply(x, y, weight, normalize)
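# Usage sketch (illustrative only; shapes follow the docstring above, with MV_DIM = 8
# for Cl(3,0) and an arbitrary batch size / feature width):
#
#   x = torch.randn(8, 1024, 64, device="cuda").contiguous()
#   y = torch.randn(8, 1024, 64, device="cuda").contiguous()
#   w = torch.randn(NUM_PRODUCT_WEIGHTS, 64, 64, device="cuda").contiguous()
#   out = fused_gelu_fcgp_norm_3d(x, y, w, normalize=True)  # shape (8, 1024, 64)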
--------------------------------------------------------------------------------