├── media ├── QuACK.png ├── dsmem.png ├── tv_layout.png ├── liger-16k-65k.png ├── max_reduction.png ├── our-16k-131k.png ├── warp_storage.png ├── block_reduction.png ├── our-16k-131k-sol.png ├── pytorch-16k-131k.png ├── thread_reduction.png ├── warp_reduction.png ├── cluster_reduction.png ├── liger-16k-65k-ncu.png ├── gpu-memory-hierarchy.png ├── memory-access-hierarchy.png ├── productivity-performance.png ├── combined_kernel_benchmarks_final.png └── our-16k-131k-arithmetic-intensity-white.png ├── benchmarks ├── visual_outputs │ ├── quack_vs_pytorch_speedup.png │ ├── rmsnorm_speedup_comparison.png │ ├── quack_vs_torchcompile_speedup.png │ ├── quack_vs_pytorch_backward_speedup.png │ ├── rmsnorm_backward_speedup_comparison.png │ └── quack_vs_torchcompile_backward_speedup.png ├── benchmark_layernorm.py ├── benchmark_softmax.py ├── benchmark_topk.py ├── benchmark_cross_entropy.py └── benchmark_rmsnorm.py ├── quack ├── __init__.py ├── compile_utils.py ├── sort │ ├── utils.py │ ├── bitonic_sort.py │ └── sorting_networks.py ├── broadcast_utils.py ├── sm100_utils.py ├── mlp.py ├── fast_math.py ├── reduction_base.py ├── gemm_config.py ├── cute_dsl_utils.py ├── sm90_utils.py ├── tensormap_manager.py ├── gemm.py ├── gemm_dact.py ├── utils.py ├── linear.py └── linear_cross_entropy.py ├── .pre-commit-config.yaml ├── CLAUDE.md ├── pyproject.toml ├── .github └── workflows │ └── publish.yaml ├── README.md ├── tests ├── test_linear_cross_entropy.py ├── test_topk.py ├── test_layernorm.py ├── test_softmax.py ├── test_linear.py ├── test_symmetric_gemm.py └── test_linear_varlen_k.py ├── .gitignore ├── docs ├── dsl_control_flow.rst └── limitations.rst └── LICENSE /media/QuACK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/QuACK.png -------------------------------------------------------------------------------- /media/dsmem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/dsmem.png -------------------------------------------------------------------------------- /media/tv_layout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/tv_layout.png -------------------------------------------------------------------------------- /media/liger-16k-65k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/liger-16k-65k.png -------------------------------------------------------------------------------- /media/max_reduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/max_reduction.png -------------------------------------------------------------------------------- /media/our-16k-131k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/our-16k-131k.png -------------------------------------------------------------------------------- /media/warp_storage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/warp_storage.png -------------------------------------------------------------------------------- /media/block_reduction.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/block_reduction.png -------------------------------------------------------------------------------- /media/our-16k-131k-sol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/our-16k-131k-sol.png -------------------------------------------------------------------------------- /media/pytorch-16k-131k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/pytorch-16k-131k.png -------------------------------------------------------------------------------- /media/thread_reduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/thread_reduction.png -------------------------------------------------------------------------------- /media/warp_reduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/warp_reduction.png -------------------------------------------------------------------------------- /media/cluster_reduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/cluster_reduction.png -------------------------------------------------------------------------------- /media/liger-16k-65k-ncu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/liger-16k-65k-ncu.png -------------------------------------------------------------------------------- /media/gpu-memory-hierarchy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/gpu-memory-hierarchy.png -------------------------------------------------------------------------------- /media/memory-access-hierarchy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/memory-access-hierarchy.png -------------------------------------------------------------------------------- /media/productivity-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/productivity-performance.png -------------------------------------------------------------------------------- /media/combined_kernel_benchmarks_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/combined_kernel_benchmarks_final.png -------------------------------------------------------------------------------- /media/our-16k-131k-arithmetic-intensity-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/our-16k-131k-arithmetic-intensity-white.png -------------------------------------------------------------------------------- /benchmarks/visual_outputs/quack_vs_pytorch_speedup.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_pytorch_speedup.png -------------------------------------------------------------------------------- /benchmarks/visual_outputs/rmsnorm_speedup_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/rmsnorm_speedup_comparison.png -------------------------------------------------------------------------------- /benchmarks/visual_outputs/quack_vs_torchcompile_speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_torchcompile_speedup.png -------------------------------------------------------------------------------- /benchmarks/visual_outputs/quack_vs_pytorch_backward_speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_pytorch_backward_speedup.png -------------------------------------------------------------------------------- /benchmarks/visual_outputs/rmsnorm_backward_speedup_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/rmsnorm_backward_speedup_comparison.png -------------------------------------------------------------------------------- /benchmarks/visual_outputs/quack_vs_torchcompile_backward_speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_torchcompile_backward_speedup.png -------------------------------------------------------------------------------- /quack/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.2" 2 | 3 | from quack.rmsnorm import rmsnorm 4 | from quack.softmax import softmax 5 | from quack.cross_entropy import cross_entropy 6 | 7 | __all__ = [ 8 | "rmsnorm", 9 | "softmax", 10 | "cross_entropy", 11 | ] 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.11.13 4 | hooks: 5 | - id: ruff-check 6 | args: [--fix, --exit-non-zero-on-fix] 7 | files: ^(quack|tests)/.*\.py$ 8 | - id: ruff-format 9 | files: ^(quack|tests)/.*\.py$ 10 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # Project convention 2 | For code that uses cute-dsl (as part of function decorated with cute.jit or 3 | cute.kernel), only a subset of Python syntax is supported. 4 | Follow [Control Flow](docs/dsl_control_flow.rst) and [Limitations](docs/limitations.rst). 
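For instance, a minimal hypothetical sketch (not part of the library) of the expected pattern: compile-time branches go through `const_expr` and loops through `cutlass.range`, mirroring `quack/fast_math.py` and `quack/broadcast_utils.py`:

```python
# Hypothetical sketch only: inside cute.jit functions, use const_expr for
# compile-time branches and cutlass.range for loops, not plain Python control flow.
import cutlass
import cutlass.cute as cute
from cutlass import Float32, const_expr


@cute.jit
def scale_fragment(frg: cute.Tensor, factor: Float32) -> None:
    if const_expr(frg.element_type != Float32):  # resolved at compile time
        tmp = cute.make_fragment(frg.shape, Float32)
        tmp.store(frg.load().to(Float32))
        for i in cutlass.range(cute.size(tmp.shape), unroll_full=True):
            tmp[i] = tmp[i] * factor
        frg.store(tmp.load().to(frg.element_type))
    else:
        for i in cutlass.range(cute.size(frg.shape), unroll_full=True):
            frg[i] = frg[i] * factor
```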
5 | # Code style 6 | - Favor concise, self-explanatory code 7 | - Avoid unnecessary comments 8 | - Avoid unnecessary line breaks 9 | - Empty lines should not have any space 10 | -------------------------------------------------------------------------------- /quack/compile_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. 2 | 3 | from typing import Optional 4 | 5 | import cutlass.cute as cute 6 | 7 | 8 | def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]: 9 | if leading_dim < 0: 10 | leading_dim = len(shape) + leading_dim 11 | if dtype is None: 12 | return None 13 | stride = tuple( 14 | cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 15 | for i in range(len(shape)) 16 | ) 17 | return cute.runtime.make_fake_tensor( 18 | dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8 19 | ) 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "quack-kernels" 7 | dynamic = ["version"] 8 | requires-python = ">=3.10" 9 | dependencies = [ 10 | "nvidia-cutlass-dsl==4.3.3", 11 | "torch", 12 | "apache-tvm-ffi>=0.1.5,<0.2", 13 | "torch-c-dlpack-ext", 14 | ] 15 | 16 | [project.optional-dependencies] 17 | dev = [ 18 | "pre-commit", 19 | "ruff", 20 | ] 21 | 22 | [tool.setuptools.packages.find] 23 | exclude = ["tests", "benchmarks"] 24 | 25 | [tool.setuptools.dynamic] 26 | version = {attr = "quack.__version__"} 27 | 28 | [tool.ruff] 29 | line-length = 100 30 | 31 | [tool.ruff.lint] 32 | ignore = [ 33 | "E731", # do not assign a lambda expression, use a def 34 | "E741", # Do not use variables named 'I', 'O', or 'l' 35 | "F841", # local variable is assigned to but never used 36 | ] -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | build-and-publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: '3.9' 18 | 19 | - name: Install build dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install build twine 23 | 24 | - name: Build package 25 | run: python -m build 26 | 27 | - name: Create GitHub Release 28 | uses: softprops/action-gh-release@v2 29 | with: 30 | files: dist/* 31 | generate_release_notes: true 32 | 33 | - name: Publish to PyPI 34 | env: 35 | TWINE_USERNAME: __token__ 36 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 37 | run: | 38 | twine upload dist/* 39 | -------------------------------------------------------------------------------- /quack/sort/utils.py: -------------------------------------------------------------------------------- 1 | import cutlass.cute as cute 2 | from cutlass import Float32, const_expr 3 | 4 | import quack.utils as utils 5 | 6 | 7 | @cute.jit 8 | def compare_and_swap( 9 | arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False 10 | ) -> None: 11 | """Compare and swap elements at indices i and j in ascending or descending order.""" 12 | 
if const_expr(use_selection): 13 | a, b = arr[i], arr[j] 14 | if (a > b) ^ (not ascending): 15 | arr[i] = b 16 | arr[j] = a 17 | # if const_expr(ascending): 18 | # if a > b: 19 | # arr[i] = b 20 | # arr[j] = a 21 | # else: 22 | # if a < b: 23 | # arr[i] = b 24 | # arr[j] = a 25 | else: 26 | min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin 27 | max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax 28 | if const_expr(ascending): 29 | arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j]) 30 | else: 31 | arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j]) 32 | -------------------------------------------------------------------------------- /quack/broadcast_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 2 | from typing import Callable 3 | 4 | import cutlass 5 | import cutlass.cute as cute 6 | from cutlass import Float32, const_expr 7 | 8 | from quack.layout_utils import make_acc_tensor_mn_view 9 | 10 | 11 | @cute.jit 12 | def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None: 13 | if const_expr(tCrC.element_type != Float32): # Convert to f32 14 | tCrC_f32 = cute.make_fragment(tCrC.shape, Float32) 15 | tCrC_f32.store(tCrC.load().to(Float32)) 16 | else: 17 | tCrC_f32 = tCrC 18 | # this happens to work for frgA layout too, not just acc layout 19 | tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32) 20 | if const_expr(is_colvec): 21 | assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec) 22 | for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True): 23 | tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r])) 24 | else: 25 | assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec) 26 | for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True): 27 | tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c])) 28 | if const_expr(tCrC.element_type != Float32): # Convert back to original dtype 29 | tCrC.store(tCrC_f32.load().to(tCrC.element_type)) 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦆 QuACK: A Quirky Assortment of CuTe Kernels 🦆 2 | 3 | Kernels are written in the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html). 4 | 5 | ## Installation 6 | 7 | ``` bash 8 | pip install quack-kernels 9 | ``` 10 | 11 | ## Requirements 12 | 13 | - H100 or B200 GPU 14 | - CUDA toolkit 12.9+ 15 | - Python 3.12 16 | 17 | ## Kernels 🐥 18 | 19 | - 🦆 RMSNorm forward + backward 20 | - 🦆 Softmax forward + backward 21 | - 🦆 Cross entropy forward + backward 22 | - 🦆 Layernorm forward 23 | - 🦆 Hopper gemm + epilogue 24 | - 🦆 Blackwell gemm + epilogue 25 | 26 | ## Usage 27 | 28 | ``` 29 | from quack import rmsnorm, softmax, cross_entropy 30 | ``` 31 | 32 | ## Documentations 33 | 34 | [2025-07-10] We have a comprehensive 35 | [blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels 36 | to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html). 37 | 38 | ## Performance 39 | 40 |
41 |
42 | <!-- performance benchmark figure -->
45 |
46 |
47 | 48 | See our [blogpost](media/2025-07-10-membound-sol.md) for the details. 49 | 50 | ## Development 51 | 52 | To set up the development environment: 53 | 54 | ```bash 55 | pip install -e '.[dev]' 56 | pre-commit install 57 | ``` 58 | -------------------------------------------------------------------------------- /quack/sm100_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 2 | 3 | from typing import Type, Union 4 | 5 | import cutlass.cute as cute 6 | import cutlass.utils.blackwell_helpers as sm100_utils_og 7 | from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode 8 | from cutlass.cutlass_dsl import Numeric, dsl_user_op 9 | 10 | 11 | @dsl_user_op 12 | def make_smem_layout_cpasync_a( 13 | tiled_mma: cute.TiledMma, 14 | mma_tiler_mnk: cute.Tile, 15 | a_dtype: Type[Numeric], 16 | num_stages: int, 17 | *, 18 | loc=None, 19 | ip=None, 20 | ) -> Union[cute.Layout, cute.ComposedLayout]: 21 | """ 22 | :param tiled_mma: The tiled MMA used to partition tensor A 23 | :type tiled_mma: cute.TiledMma 24 | :param mma_tiler_mnk: The MMA tile shape 25 | :type mma_tiler_mnk: cute.cute.Tile 26 | :param a_dtype: The element type for tensor A 27 | :type a_dtype: Type[Numeric] 28 | :param num_stages: The number of pipeline stages for tensor A 29 | :type num_stages: int 30 | 31 | :return: SMEM layout for tensor A 32 | :rtype: Union[cute.Layout, cute.ComposedLayout] 33 | """ 34 | 35 | is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K 36 | a_smem_shape = tiled_mma.partition_shape_A( 37 | cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) 38 | ) 39 | a_smem_shape_mn_k = ( 40 | cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], 41 | cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], 42 | ) 43 | a_smem_layout_atom = sm100_utils_og.make_smem_layout_atom( 44 | sm100_utils_og.get_smem_layout_atom_ab( 45 | tiled_mma.op.a_major_mode, 46 | a_dtype, 47 | a_smem_shape_mn_k, 48 | loc=loc, 49 | ip=ip, 50 | ), 51 | a_dtype, 52 | loc=loc, 53 | ip=ip, 54 | ) 55 | a_smem_layout_staged = cute.tile_to_shape( 56 | a_smem_layout_atom, 57 | cute.append(a_smem_shape_mn_k, num_stages, loc=loc, ip=ip), 58 | order=((1, 0, 2) if not is_k_major else (0, 1, 2)), 59 | loc=loc, 60 | ip=ip, 61 | ) 62 | return a_smem_layout_staged 63 | -------------------------------------------------------------------------------- /quack/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from quack.linear import linear_act_func, act_linear_func 8 | 9 | 10 | def mlp_func(x, weight1, weight2, activation: str, fuse_grad_accum=False, tuned=True): 11 | preact, postact = linear_act_func( 12 | x, 13 | weight1, 14 | activation, 15 | store_preact=torch.is_grad_enabled(), 16 | fuse_grad_accum=fuse_grad_accum, 17 | tuned=tuned, 18 | ) 19 | out = act_linear_func( 20 | preact, 21 | weight2, 22 | postact, 23 | activation=activation, 24 | fuse_grad_accum=fuse_grad_accum, 25 | tuned=tuned, 26 | ) 27 | return out 28 | 29 | 30 | class MLP(nn.Module): 31 | def __init__( 32 | self, 33 | in_features, 34 | hidden_features=None, 35 | out_features=None, 36 | bias1=False, 37 | bias2=False, 38 | activation="gelu", 39 | device=None, 40 | dtype=None, 41 | fuse_grad_accum: bool = False, 42 | tuned: bool = True, 43 | ): 44 | factory_kwargs = {"device": device, "dtype": dtype} 45 | 
super().__init__() 46 | out_features = out_features if out_features is not None else in_features 47 | hidden_features = hidden_features if hidden_features is not None else 4 * in_features 48 | self.activation = activation 49 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) 50 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) 51 | self.fuse_grad_accum = fuse_grad_accum 52 | self.tuned = tuned 53 | 54 | def forward(self, input: Tensor) -> Tensor: 55 | if ( 56 | self.fc1.bias is None 57 | and self.fc2.bias is None 58 | and input.is_cuda 59 | and input.stride(-1) == 1 60 | and self.fc1.in_features % 8 == 0 61 | and self.fc1.out_features % 8 == 0 62 | and self.fc2.out_features % 8 == 0 63 | ): 64 | return mlp_func( 65 | input, 66 | self.fc1.weight, 67 | self.fc2.weight, 68 | activation=self.activation, 69 | fuse_grad_accum=self.fuse_grad_accum, 70 | tuned=self.tuned, 71 | ) 72 | else: 73 | y = self.fc1(input) 74 | return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2]) 75 | -------------------------------------------------------------------------------- /quack/fast_math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 2 | 3 | from typing import Tuple 4 | from dataclasses import dataclass 5 | 6 | import cutlass 7 | import cutlass.cute as cute 8 | from cutlass import Int32, Uint32 9 | from cutlass.cutlass_dsl import T, dsl_user_op 10 | from cutlass._mlir.dialects import llvm 11 | 12 | from quack.cute_dsl_utils import ParamsBase 13 | 14 | 15 | @cute.jit 16 | def clz(x: Int32) -> Int32: 17 | # for i in cutlass.range_constexpr(32): 18 | # if (1 << (31 - i)) & x: 19 | # return Int32(i) 20 | # return Int32(32) 21 | # Early exit is not supported yet 22 | res = Int32(32) 23 | done = False 24 | for i in cutlass.range(32): 25 | if ((1 << (31 - i)) & x) and not done: 26 | res = Int32(i) 27 | done = True 28 | return res 29 | 30 | 31 | def find_log2(x: Int32) -> Int32: 32 | a: Int32 = Int32(31 - clz(x)) 33 | return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. 34 | 35 | 36 | @dsl_user_op 37 | def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: 38 | return Uint32( 39 | llvm.inline_asm( 40 | T.i32(), 41 | [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], 42 | "mul.hi.u32 $0, $1, $2;", 43 | "=r,r,r", 44 | has_side_effects=False, 45 | is_align_stack=False, 46 | asm_dialect=llvm.AsmDialect.AD_ATT, 47 | ) 48 | ) 49 | 50 | 51 | @dataclass 52 | class FastDivmod(ParamsBase): 53 | divisor: Int32 54 | multiplier: Uint32 55 | shift_right: Uint32 56 | 57 | # called by host 58 | @staticmethod 59 | def create(divisor: Int32) -> "FastDivmod": 60 | """Construct the FastDivmod object, in host code. 61 | This precomputes some values based on the divisor and is computationally expensive. 
62 | """ 63 | p = Uint32(31 + find_log2(divisor)) 64 | divisor_u32 = Uint32(divisor) 65 | multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) 66 | shift_right = Uint32(p - 32) 67 | return FastDivmod(divisor, multiplier, shift_right) 68 | 69 | @cute.jit 70 | def div(self, dividend: Int32) -> Int32: 71 | return ( 72 | Int32(umulhi(dividend, self.multiplier) >> self.shift_right) 73 | if self.divisor != 1 74 | else dividend 75 | ) 76 | 77 | def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: 78 | quotient = self.div(dividend) 79 | remainder = dividend - quotient * self.divisor 80 | return quotient, remainder 81 | -------------------------------------------------------------------------------- /benchmarks/benchmark_layernorm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Type 4 | 5 | import torch 6 | from triton.testing import do_bench 7 | 8 | import cutlass 9 | import cutlass.torch as cutlass_torch 10 | from cutlass.cute.runtime import from_dlpack 11 | from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref 12 | import cutlass.cute as cute 13 | 14 | try: 15 | import cudnn 16 | except ImportError: 17 | cudnn = None 18 | 19 | 20 | def run_layernorm( 21 | M, 22 | N, 23 | dtype: Type[cutlass.Numeric], 24 | warmup_iterations=2, 25 | iterations=200, 26 | ): 27 | if not torch.cuda.is_available(): 28 | raise RuntimeError(f"Ampere GPU is required to run this example!") 29 | 30 | print(f"Tensor dimensions: [{M}, {N}]") 31 | print(f"Input and Output Data type: {dtype}") 32 | 33 | torch_dtype = cutlass_torch.dtype(dtype) 34 | 35 | device = "cuda" 36 | x = torch.randn(M, N, device=device, dtype=torch_dtype) 37 | w = torch.randn(N, device=device, dtype=torch.float32) 38 | 39 | print(f"Input tensor shapes:") 40 | print(f"x: {x.shape}, dtype: {x.dtype}") 41 | print(f"w: {w.shape}, dtype: {w.dtype}") 42 | 43 | eps = 1e-6 44 | 45 | print("Executing kernel...") 46 | out, rstd, mean = layernorm(x, w, eps=eps, return_rstd=True, return_mean=True) 47 | 48 | compiled_func_ref = torch.compile(layernorm_ref) 49 | 50 | fn = lambda: layernorm(x, w, eps=eps) 51 | time.sleep(0.5) 52 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 53 | mem_bw = (2 * x.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9 54 | print(f"Kernel execution time: {avg_time:.4f} ms") 55 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 56 | 57 | fn = lambda: compiled_func_ref(x, w, eps=eps) 58 | for _ in range(5): 59 | fn() # warm up 60 | time.sleep(0.5) 61 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 62 | mem_bw_ref = (2 * x.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9 63 | print(f"Ref kernel execution time: {avg_time:.4f} ms") 64 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 65 | 66 | return mem_bw, mem_bw_ref 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser( 71 | description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels" 72 | ) 73 | parser.add_argument("--M", default=32768, type=int) 74 | parser.add_argument("--N", default=16384, type=int) 75 | parser.add_argument("--warmup_iterations", default=10, type=int) 76 | parser.add_argument("--iterations", default=100, type=int) 77 | 78 | args = parser.parse_args() 79 | 80 | run_layernorm( 81 | args.M, 82 | args.N, 83 | dtype=cutlass.BFloat16, 84 | warmup_iterations=args.warmup_iterations, 85 | iterations=args.iterations, 86 | ) 87 | 
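The memory-throughput estimate above counts one read of `x` plus one write of the output, i.e. `2 * x.numel() * dtype.width // 8` bytes. A worked example of that arithmetic, with a purely illustrative runtime:

```python
# Worked example of the mem_bw formula above; the 0.7 ms runtime is an assumed
# placeholder for illustration, not a measured number.
M, N, bytes_per_elem = 32768, 16384, 2          # default shape, bf16 elements
bytes_moved = 2 * M * N * bytes_per_elem        # read x once + write out once
avg_time_ms = 0.7                               # hypothetical kernel time
mem_bw_gb_s = bytes_moved / (avg_time_ms / 1000) / 1e9
print(f"{mem_bw_gb_s:.0f} GB/s")                # ~3068 GB/s
```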
-------------------------------------------------------------------------------- /quack/reduction_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. 2 | 3 | from typing import Type, Tuple, Optional 4 | 5 | import cutlass 6 | import cutlass.cute as cute 7 | from cutlass import Int32, Int64, Float32, const_expr 8 | 9 | import quack.copy_utils as copy_utils 10 | 11 | 12 | class ReductionBase: 13 | def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32): 14 | self.dtype = dtype 15 | self.N = N 16 | self.stage = stage 17 | self.reduction_dtype = reduction_dtype 18 | 19 | def _threads_per_row(self): 20 | raise NotImplementedError() 21 | 22 | def _num_threads(self): 23 | return 128 if self.N <= 16384 else 256 24 | 25 | def _set_cluster_n(self): 26 | self.cluster_n = 1 27 | 28 | def _get_tiled_copy(self, vecsize: int = 1): 29 | assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" 30 | threads_per_row = self._threads_per_row() 31 | num_threads = self._num_threads() 32 | assert num_threads % cute.arch.WARP_SIZE == 0 33 | num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n) 34 | tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row) 35 | tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize) 36 | return tiled_copy, tiler_mn, threads_per_row 37 | 38 | def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int): 39 | num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE 40 | warps_per_row = ( 41 | num_warps 42 | if cute.rank(tv_layout.shape[0]) == 1 43 | else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) 44 | ) 45 | return cute.make_ordered_layout( 46 | (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), 47 | order=(1, 0, 2), 48 | ) 49 | 50 | def _allocate_reduction_buffer_and_mbar( 51 | self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False 52 | ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: 53 | reduction_buffer = smem.allocate_tensor( 54 | self.reduction_dtype, 55 | self._get_reduction_buffer_layout(tv_layout, self.cluster_n), 56 | byte_alignment=8, 57 | ) 58 | if const_expr(self.cluster_n > 1): 59 | mbar_ptr = smem.allocate_array( 60 | Int64, num_elems=self.stage if not is_persistent else self.stage * 2 61 | ) 62 | else: 63 | mbar_ptr = None 64 | return reduction_buffer, mbar_ptr 65 | 66 | @cute.jit 67 | def _initialize_cluster( 68 | self, 69 | tidx: Int32, 70 | mbar_ptr: cute.Pointer, 71 | num_warps: int, 72 | is_persistent: bool = False, 73 | ): 74 | if const_expr(self.cluster_n > 1): 75 | if tidx < self.stage: # Initialize full barrier 76 | cute.arch.mbarrier_init(mbar_ptr + tidx, 1) 77 | if const_expr(is_persistent): # Initialize empty barrier 78 | cute.arch.mbarrier_init( 79 | mbar_ptr + self.stage + tidx, num_warps * self.cluster_n 80 | ) 81 | cute.arch.mbarrier_init_fence() 82 | # Cluster arrive after barrier init 83 | cute.arch.cluster_arrive_relaxed() 84 | -------------------------------------------------------------------------------- /quack/gemm_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025, Fri Dao. 
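# GemmConfig captures one tile-shape / cluster / scheduling choice; get_all_configs
# below enumerates the tuning search space, with device_capacity 9 mapping to
# Hopper (SM90) and 10 to Blackwell (SM100).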
2 | import itertools 3 | from typing import Optional, List, Literal 4 | from functools import partial 5 | from dataclasses import dataclass 6 | 7 | 8 | @dataclass(frozen=True) 9 | class GemmConfig: 10 | tile_m: int = 128 11 | tile_n: int = 192 12 | pingpong: bool = True 13 | cluster_m: int = 2 14 | cluster_n: int = 1 15 | swap_ab: bool = False 16 | # raster_order: int = 1 17 | max_swizzle_size: int = 8 18 | 19 | 20 | def get_all_configs( 21 | device_capacity: Literal[9, 10] = 9, 22 | epilogue: Optional[str] = None, 23 | tune_coop: bool = True, 24 | # tune_raster_order=True, 25 | ) -> List[GemmConfig]: 26 | assert device_capacity in [9, 10] 27 | if device_capacity == 9: 28 | tile_n_vals = [128, 144, 160, 176, 192, 208] 29 | tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [ 30 | (128, 224), 31 | (128, 256), 32 | # (192, 256), # Getting IOT instruction (core dumped) in the bwd 33 | ] 34 | tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)] 35 | if epilogue in ["gated"]: 36 | tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192] 37 | tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0] 38 | elif epilogue in ["lse"]: 39 | tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192] 40 | tile_mn_vals = [] 41 | if tune_coop: 42 | tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals] 43 | tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals] 44 | cluster = [(1, 2), (2, 1)] 45 | # cluster = [(1, 1), (1, 2), (2, 1)] 46 | if epilogue in ["lse"]: 47 | cluster = [(1, 2), (2, 1)] 48 | swap_ab_vals = [False, True] 49 | if epilogue in ["lse", "gated"]: 50 | swap_ab_vals = [False] 51 | # raster_swizzle = ( 52 | # [(0, 1)] 53 | # if not tune_raster_order 54 | # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)] 55 | # ) 56 | return [ 57 | GemmConfig( 58 | tile_m=tile_m, 59 | tile_n=tile_n, 60 | pingpong=pingpong, 61 | cluster_m=cluster_m, 62 | cluster_n=cluster_n, 63 | swap_ab=swap_ab, 64 | # raster_order=raster_order, 65 | # max_swizzle_size=max_swizzle_size, 66 | ) 67 | for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product( 68 | tile_mn_vals, 69 | cluster, 70 | swap_ab_vals, 71 | # raster_swizzle, 72 | ) 73 | ] 74 | elif device_capacity == 10: 75 | tile_n_vals = [128, 160, 192, 224, 256] 76 | tile_n_64_vals = [128, 192, 256] 77 | tile_mn_cluster_vals = ( 78 | [(128, tile_n, (1, 2)) for tile_n in tile_n_vals] 79 | # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals] 80 | + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals] 81 | + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals] 82 | ) 83 | swap_ab_vals = [False, True] 84 | if epilogue in ["lse", "gated"]: 85 | swap_ab_vals = [False] 86 | max_swizzle_size_vals = [4, 8, 16] 87 | GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100 88 | return [ 89 | GemmConfigCls( 90 | tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms 91 | ) 92 | for (m, n, (cm, cn)), sab, ms in itertools.product( 93 | tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals 94 | ) 95 | ] 96 | -------------------------------------------------------------------------------- /tests/test_linear_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 
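# These tests compare the fused chunked_linear_cross_entropy against a float32
# reference and an input-dtype PyTorch reference, requiring the fused loss and
# gradients to stay within a small multiple of the PyTorch-vs-float32 error.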
2 | 3 | import pytest 4 | import torch 5 | 6 | from quack.linear_cross_entropy import ( 7 | chunked_linear_cross_entropy, 8 | linear_cross_entropy_func_ref, 9 | ) 10 | 11 | 12 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 13 | @pytest.mark.parametrize("reduction", ["mean", "sum"]) 14 | @pytest.mark.parametrize("V", [32000, 50264, 128256]) 15 | # @pytest.mark.parametrize("V", [32000]) 16 | @pytest.mark.parametrize("d", [768, 1024]) 17 | # @pytest.mark.parametrize("d", [768]) 18 | @pytest.mark.parametrize("B_L", [8, 16, 24]) 19 | @pytest.mark.parametrize("chunk_size", [16]) 20 | def test_chunked_linear_cross_entropy(B_L, d, V, chunk_size, reduction, input_dtype): 21 | """Test chunked linear cross entropy against reference implementation.""" 22 | device = "cuda" 23 | atol, rtol = 1e-3, 1e-3 24 | torch.random.manual_seed(0) 25 | x = (torch.randn(B_L, d, device=device, dtype=input_dtype) * 0.1).requires_grad_() 26 | weight = (torch.randn(V, d, device=device, dtype=input_dtype) / (d**0.5)).requires_grad_() 27 | target = torch.randint(0, V, (B_L,), device=device, dtype=torch.int64) 28 | x_ref = x.detach().clone().requires_grad_(True) 29 | weight_ref = weight.detach().clone().requires_grad_(True) 30 | x_pt = x.detach().clone().requires_grad_(True) 31 | weight_pt = weight.detach().clone().requires_grad_(True) 32 | loss_ref = linear_cross_entropy_func_ref( 33 | x_ref.float(), weight_ref.float(), None, target, reduction=reduction 34 | ) 35 | loss_pt = linear_cross_entropy_func_ref(x_pt, weight_pt, None, target, reduction=reduction) 36 | # Chunked implementation 37 | loss = chunked_linear_cross_entropy( 38 | x, weight, target, chunk_size=chunk_size, reduction=reduction, tuned=False 39 | ) 40 | assert (loss - loss_ref).abs().max() < 3 * (loss_pt - loss_ref).abs().max() + 1e-5 41 | loss.backward() 42 | loss_ref.backward() 43 | loss_pt.backward() 44 | assert (x.grad - x_ref.grad).abs().max() < 2 * (x_pt.grad - x_ref.grad).abs().max() + 1e-4 45 | assert (weight.grad - weight_ref.grad).abs().max() < 2 * ( 46 | weight_pt.grad - weight_ref.grad 47 | ).abs().max() + 1e-4 48 | 49 | 50 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 51 | @pytest.mark.parametrize("reduction", ["mean", "sum"]) 52 | @pytest.mark.parametrize("chunk_size", [256, 1024]) 53 | def test_chunked_linear_cross_entropy_ignore_index(input_dtype, reduction, chunk_size): 54 | """Test chunked linear cross entropy with ignore_index.""" 55 | device = "cuda" 56 | B_L, d, V = 1024, 512, 2048 57 | ignore_index = V - 1 58 | atol, rtol = 1e-3, 1e-3 59 | torch.random.manual_seed(42) 60 | x = (torch.randn(B_L, d, device=device, dtype=input_dtype) * 0.1).requires_grad_() 61 | weight = (torch.randn(V, d, device=device, dtype=input_dtype) / (d**0.5)).requires_grad_() 62 | target = torch.randint(0, V, (B_L,), device=device, dtype=torch.int64) 63 | x_ref = x.detach().clone().requires_grad_(True) 64 | weight_ref = weight.detach().clone().requires_grad_(True) 65 | x_pt = x.detach().clone().requires_grad_(True) 66 | weight_pt = weight.detach().clone().requires_grad_(True) 67 | # Set some targets to ignore_index 68 | ignore_mask = torch.rand(B_L, device=device) < 0.2 69 | target[ignore_mask] = ignore_index 70 | loss_ref = linear_cross_entropy_func_ref( 71 | x_ref.float(), weight_ref.float(), None, target, reduction=reduction 72 | ) 73 | loss_pt = linear_cross_entropy_func_ref(x_pt, weight_pt, None, target, reduction=reduction) 74 | # Chunked implementation 75 | loss = chunked_linear_cross_entropy( 76 | x, weight, target, 
chunk_size=chunk_size, reduction=reduction, tuned=False 77 | ) 78 | assert (loss - loss_ref).abs().max() < 3 * (loss_pt - loss_ref).abs().max() + 1e-5 79 | loss.backward() 80 | loss_ref.backward() 81 | loss_pt.backward() 82 | assert (x.grad - x_ref.grad).abs().max() < 2 * (x_pt.grad - x_ref.grad).abs().max() + 1e-4 83 | assert (weight.grad - weight_ref.grad).abs().max() < 2 * ( 84 | weight_pt.grad - weight_ref.grad 85 | ).abs().max() + 1e-4 86 | -------------------------------------------------------------------------------- /quack/cute_dsl_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 2 | 3 | from typing import Tuple 4 | from functools import lru_cache 5 | from dataclasses import dataclass, fields 6 | 7 | import torch 8 | 9 | try: 10 | from triton.tools.disasm import extract 11 | except ImportError: 12 | extract = None 13 | 14 | import cutlass 15 | import cutlass.cute as cute 16 | from cutlass import Int32, Int64, Float16, BFloat16, Float32 17 | from cutlass.base_dsl.typing import JitArgument 18 | from cutlass.cutlass_dsl import NumericMeta 19 | 20 | 21 | StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) 22 | 23 | 24 | load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data 25 | cute_compile_og = cute.compile 26 | 27 | 28 | torch2cute_dtype_map = { 29 | torch.float16: Float16, 30 | torch.bfloat16: BFloat16, 31 | torch.float32: Float32, 32 | torch.int32: Int32, 33 | torch.int64: Int64, 34 | } 35 | 36 | 37 | @lru_cache 38 | def get_max_active_clusters(cluster_size): 39 | return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) 40 | 41 | 42 | @lru_cache 43 | def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: 44 | return torch.cuda.get_device_capability(device) 45 | 46 | 47 | @dataclass 48 | class ParamsBase: 49 | def __extract_mlir_values__(self): 50 | all_fields = [getattr(self, field.name) for field in fields(self)] 51 | non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] 52 | values, self._values_pos = [], [] 53 | for obj in non_constexpr_fields: 54 | obj_values = cutlass.extract_mlir_values(obj) 55 | values += obj_values 56 | self._values_pos.append(len(obj_values)) 57 | return values 58 | 59 | def __new_from_mlir_values__(self, values): 60 | all_fields = {field.name: getattr(self, field.name) for field in fields(self)} 61 | constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} 62 | non_constexpr_fields = { 63 | n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) 64 | } 65 | for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): 66 | non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) 67 | values = values[n_items:] 68 | return self.__class__(**non_constexpr_fields, **constexpr_fields) 69 | 70 | 71 | @dataclass 72 | class ArgumentsBase(JitArgument): 73 | def __c_pointers__(self): 74 | all_fields = [getattr(self, field.name) for field in fields(self)] 75 | non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] 76 | c_ptrs = [] 77 | for obj in non_constexpr_fields: 78 | if hasattr(obj, "__c_pointers__"): 79 | c_ptrs.extend(obj.__c_pointers__()) 80 | return c_ptrs 81 | 82 | def __get_mlir_types__(self): 83 | all_fields = [getattr(self, field.name) for field in fields(self)] 84 | non_constexpr_fields = [f for f in all_fields if not isinstance(f, 
StaticTypes)] 85 | types, self._values_pos = [], [] 86 | for obj in non_constexpr_fields: 87 | if hasattr(obj, "__get_mlir_types__"): 88 | obj_types = obj.__get_mlir_types__() 89 | types.extend(obj_types) 90 | self._values_pos.append(len(obj_types)) 91 | else: 92 | self._values_pos.append(0) 93 | return types 94 | 95 | def __new_from_mlir_values__(self, values): 96 | all_fields = {field.name: getattr(self, field.name) for field in fields(self)} 97 | constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} 98 | non_constexpr_fields = { 99 | n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) 100 | } 101 | for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): 102 | non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) 103 | values = values[n_items:] 104 | return self.__class__(**non_constexpr_fields, **constexpr_fields) 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .claude/ 3 | 4 | # Created by https://www.toptal.com/developers/gitignore/api/python 5 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 6 | 7 | ### Python ### 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | 169 | ### Python Patch ### 170 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 171 | poetry.toml 172 | 173 | # ruff 174 | .ruff_cache/ 175 | 176 | # LSP config files 177 | pyrightconfig.json 178 | 179 | # End of https://www.toptal.com/developers/gitignore/api/python 180 | -------------------------------------------------------------------------------- /quack/sm90_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 
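# Hopper (SM90) helpers: shared-memory layout construction, epilogue partitioning,
# and warpgroup MMA wrappers with optional A/B swap and accumulator zero-init.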
2 | 3 | from typing import Type, Union, Optional 4 | 5 | import cutlass 6 | import cutlass.cute as cute 7 | import cutlass.utils.hopper_helpers as sm90_utils_og 8 | from cutlass.cute.nvgpu import warpgroup 9 | from cutlass.cutlass_dsl import Numeric, dsl_user_op 10 | from cutlass import Float32, Int32, Boolean, const_expr 11 | from cutlass.utils import LayoutEnum 12 | 13 | 14 | @dsl_user_op 15 | def make_smem_layout( 16 | dtype: Type[Numeric], 17 | layout: LayoutEnum, 18 | tile: cute.Tile, 19 | stage: Optional[int] = None, 20 | *, 21 | loc=None, 22 | ip=None, 23 | ) -> Union[cute.Layout, cute.ComposedLayout]: 24 | shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip) 25 | major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] 26 | smem_layout_atom = warpgroup.make_smem_layout_atom( 27 | sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), 28 | dtype, 29 | ) 30 | smem_layout_staged = cute.tile_to_shape( 31 | smem_layout_atom, 32 | cute.append(shape, stage) if const_expr(stage is not None) else shape, 33 | order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2), 34 | ) 35 | return smem_layout_staged 36 | 37 | 38 | # For compatibility with blackwell_helpers.py 39 | make_smem_layout_epi = make_smem_layout 40 | 41 | 42 | @dsl_user_op 43 | def partition_for_epilogue( 44 | cT: cute.Tensor, 45 | epi_tile: cute.Tile, 46 | tiled_copy: cute.TiledCopy, 47 | tidx: Int32, 48 | reference_src: bool, # do register tensors reference the src or dst layout of the tiled copy 49 | *, 50 | loc=None, 51 | ip=None, 52 | ) -> cute.Tensor: 53 | thr_copy = tiled_copy.get_slice(tidx) 54 | cT_epi = cute.flat_divide(cT, epi_tile) 55 | # (CPY, CPY_M, CPY_N, EPI_M, EPI_N) 56 | if const_expr(reference_src): 57 | return thr_copy.partition_S(cT_epi, loc=loc, ip=ip) 58 | else: 59 | return thr_copy.partition_D(cT_epi, loc=loc, ip=ip) 60 | 61 | 62 | @cute.jit 63 | def gemm( 64 | tiled_mma: cute.TiledMma, 65 | acc: cute.Tensor, 66 | tCrA: cute.Tensor, 67 | tCrB: cute.Tensor, 68 | zero_init: cutlass.Constexpr[bool] = False, 69 | wg_wait: cutlass.Constexpr[int] = 0, 70 | # A_in_regs: cutlass.Constexpr[bool] = False, 71 | swap_AB: cutlass.Constexpr[bool] = False, 72 | ) -> None: 73 | if const_expr(swap_AB): 74 | gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) 75 | else: 76 | warpgroup.fence() 77 | # We make a new mma_atom since we'll be modifying its attribute (accumulate). 
78 | # Otherwise the compiler complains "operand #0 does not dominate this use" 79 | mma_atom = cute.make_mma_atom(tiled_mma.op) 80 | mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) 81 | for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): 82 | cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) 83 | mma_atom.set(warpgroup.Field.ACCUMULATE, True) 84 | warpgroup.commit_group() 85 | if const_expr(wg_wait >= 0): 86 | warpgroup.wait_group(wg_wait) 87 | 88 | 89 | def gemm_zero_init( 90 | tiled_mma: cute.TiledMma, 91 | shape: cute.Shape, 92 | tCrA: cute.Tensor, 93 | tCrB: cute.Tensor, 94 | A_idx: Optional[Int32] = None, 95 | B_idx: Optional[Int32] = None, 96 | wg_wait: int = -1, 97 | swap_AB: bool = False, 98 | ) -> cute.Tensor: 99 | if const_expr(swap_AB): 100 | return gemm_zero_init( 101 | tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False 102 | ) 103 | else: 104 | acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) 105 | rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] 106 | rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] 107 | gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) 108 | return acc 109 | 110 | 111 | def gemm_w_idx( 112 | tiled_mma: cute.TiledMma, 113 | acc: cute.Tensor, 114 | tCrA: cute.Tensor, 115 | tCrB: cute.Tensor, 116 | zero_init: Boolean, 117 | A_idx: Optional[Int32] = None, 118 | B_idx: Optional[Int32] = None, 119 | wg_wait: int = -1, 120 | swap_AB: bool = False, 121 | ) -> None: 122 | if const_expr(swap_AB): 123 | gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False) 124 | else: 125 | rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] 126 | rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] 127 | gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) 128 | -------------------------------------------------------------------------------- /quack/sort/bitonic_sort.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. 
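# Register-level bitonic sort and top-k for small power-of-two arrays, built on
# compare_and_swap and the fixed-size sorting networks in quack/sort.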
2 | 3 | import math 4 | from typing import Optional 5 | 6 | import cutlass 7 | import cutlass.cute as cute 8 | from cutlass import Int32, Float32, const_expr 9 | 10 | import quack.utils as utils 11 | from quack.sort.utils import compare_and_swap 12 | from quack.sort.sorting_networks import optimal_sort 13 | 14 | 15 | @cute.jit 16 | def bitonic_merge( 17 | arr: cute.Tensor, 18 | n: Optional[cutlass.Constexpr[int]] = None, 19 | start: cutlass.Constexpr[int] = 0, 20 | ascending: cutlass.Constexpr[bool] = True, 21 | ) -> None: 22 | """Merge a bitonic sequence into a sorted sequence using iterative approach.""" 23 | if const_expr(n is None): 24 | n = cute.size(arr.shape) 25 | if const_expr(n > 1): 26 | num_levels = int(math.log2(n)) 27 | assert n == 2**num_levels, "n must be a power of 2" 28 | # This one must be range_constexpr otherwise it's very slow for n = 128 29 | for level in cutlass.range_constexpr(num_levels): 30 | length = n >> level # n // (2^level) 31 | step = length // 2 32 | for i in cutlass.range(n // length, unroll_full=True): 33 | start_i = start + i * length 34 | for j in cutlass.range(step, unroll_full=True): 35 | compare_and_swap(arr, start_i + j, start_i + j + step, ascending) 36 | 37 | 38 | @cute.jit 39 | def bitonic_sort( 40 | arr: cute.Tensor, 41 | n: Optional[cutlass.Constexpr[int]] = None, 42 | start: cutlass.Constexpr[int] = 0, 43 | ascending: cutlass.Constexpr[bool] = True, 44 | ) -> None: 45 | """ 46 | Bitonic sort for small arrays of size N (power of 2, N <= 128). 47 | 48 | Args: 49 | arr: Array to sort 50 | n: Size of array (must be power of 2 and <= 128) 51 | start: Starting index (default 0) 52 | ascending: Sort in ascending order (default True) 53 | """ 54 | if const_expr(n is None): 55 | n = cute.size(arr.shape) 56 | assert n <= 128 57 | if const_expr(n > 1): 58 | if const_expr(n in [2, 4, 8, 16, 32, 64]): 59 | optimal_sort(arr, n, start, ascending) 60 | else: # Fall back to bitonic sort 61 | assert n % 2 == 0 62 | # Sort first half in ascending order 63 | bitonic_sort(arr, n // 2, start, True) 64 | # Sort second half in descending order 65 | bitonic_sort(arr, n // 2, start + n // 2, False) 66 | # Merge the whole sequence 67 | bitonic_merge(arr, n, start, ascending) 68 | 69 | 70 | @cute.jit 71 | def bitonic_topk_merge( 72 | arr0: cute.Tensor, 73 | arr1: cute.Tensor, 74 | k: Optional[cutlass.Constexpr[int]] = None, 75 | start0: cutlass.Constexpr[int] = 0, 76 | start1: cutlass.Constexpr[int] = 0, 77 | ascending: cutlass.Constexpr[bool] = False, 78 | ) -> None: 79 | if const_expr(k is None): 80 | k = cute.size(arr0.shape) 81 | if const_expr(arr0.element_type == Float32): 82 | minmax_fn = utils.fmin if ascending else cute.arch.fmax 83 | else: 84 | minmax_fn = min if ascending else max 85 | # Write the top k elements to the first half of the array 86 | for i in cutlass.range(k, unfoll_full=True): 87 | arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i]) 88 | # Now the 1st half is bitonic, we just need to merge it 89 | bitonic_merge(arr0, k, start0, ascending) 90 | 91 | 92 | @cute.jit 93 | def bitonic_topk( 94 | arr: cute.Tensor, 95 | k: cutlass.Constexpr[int], 96 | ascending: cutlass.Constexpr[bool] = False, 97 | warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, 98 | ) -> cute.Tensor: 99 | """ 100 | Bitonic top-k for small arrays of size N (power of 2, N <= 128). 
101 | 102 | Args: 103 | arr: Array to sort 104 | k: must be power of 2 and <= 128 105 | ascending: Sort in ascending order (default False) 106 | """ 107 | assert arr.element_type in [Float32, Int32] 108 | n = cute.size(arr.shape) 109 | assert k == 1 << int(math.log2(k)), "k must be a power of 2" 110 | assert n % k == 0, "n must be divisible by k" 111 | topk_vals = cute.make_fragment(k, arr.element_type) 112 | for v in cutlass.range(k, unroll_full=True): 113 | topk_vals[v] = arr[v] 114 | bitonic_sort(topk_vals, ascending=ascending) 115 | for i in cutlass.range(1, n // k, unroll_full=True): 116 | other_vals = cute.make_fragment(k, arr.element_type) 117 | for v in cutlass.range(k, unroll_full=True): 118 | other_vals[v] = arr[i * k + v] 119 | bitonic_sort(other_vals, ascending=ascending) 120 | # Merge 2 sorted top-k sequences to get a new top-k sequence 121 | bitonic_topk_merge(topk_vals, other_vals, ascending=ascending) 122 | # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps 123 | # do duplicate work. 124 | for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True): 125 | other_vals = cute.make_fragment(k, arr.element_type) 126 | for v in cutlass.range(k, unroll_full=True): 127 | other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i) 128 | bitonic_topk_merge(topk_vals, other_vals, ascending=ascending) 129 | return topk_vals 130 | -------------------------------------------------------------------------------- /tests/test_topk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. 2 | 3 | import pytest 4 | import torch 5 | 6 | 7 | from quack.topk import topk 8 | 9 | torch._dynamo.config.cache_size_limit = 1024 10 | torch._dynamo.config.accumulated_cache_size_limit = 1024 11 | 12 | 13 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32]) 14 | # @pytest.mark.parametrize("input_dtype", [torch.float32]) 15 | @pytest.mark.parametrize( 16 | "N, k", 17 | [(64, 16), (128, 32), (256, 16), (512, 32), (1024, 32), (4096, 32), (4096, 64), (4096, 128)], 18 | # [(64, 16)], 19 | ) 20 | @pytest.mark.parametrize("M", [1, 37, 199]) 21 | # @pytest.mark.parametrize("M", [1]) 22 | @pytest.mark.parametrize("softmax", [False, True]) 23 | # @pytest.mark.parametrize("softmax", [False]) 24 | @pytest.mark.parametrize("function", [topk, torch.compile(topk, fullgraph=True)]) 25 | # @pytest.mark.parametrize("function", [torch.compile(topk, fullgraph=True)]) 26 | def test_topk(M, N, k, input_dtype, softmax, function): 27 | """Test TopK forward/backward against PyTorch reference implementation.""" 28 | device = "cuda" 29 | # Set tolerance based on dtype 30 | if input_dtype == torch.bfloat16: 31 | atol = 1e-2 32 | rtol = 1e-2 33 | elif input_dtype == torch.float16: 34 | atol = 1e-3 35 | rtol = 1e-3 36 | else: 37 | atol = 1e-3 38 | rtol = 5e-4 39 | 40 | torch.random.manual_seed(0) 41 | # Create input tensors 42 | x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True) 43 | out_val, out_idx = function(x, k, softmax=softmax) 44 | out_val_ref, out_idx_ref = torch.topk(x.detach(), k, dim=-1, largest=True, sorted=True) 45 | if softmax: 46 | out_val_ref = torch.softmax(out_val_ref.float(), dim=-1).to(input_dtype) 47 | 48 | # Check output shape and dtype 49 | assert out_val.shape == (M, k) 50 | assert out_val.dtype == input_dtype 51 | # Check accuracy - values should match the reference 52 | 
torch.testing.assert_close(out_val, out_val_ref, atol=atol, rtol=rtol) 53 | 54 | if not softmax: 55 | # 1. Values should be in descending order 56 | assert torch.all(out_val[:, :-1] >= out_val[:, 1:]), "Some rows not in descending order" 57 | # 2. Values indexed at output indices should match output values 58 | indexed_vals = torch.gather(x, 1, out_idx.long()) 59 | torch.testing.assert_close(indexed_vals, out_val, atol=atol, rtol=rtol) 60 | 61 | # Backward check 62 | dvalues = torch.randn_like(out_val) 63 | out_val.backward(dvalues) 64 | dx_ref = torch.zeros_like(x) 65 | dx_ref.scatter_(1, out_idx.long(), dvalues) 66 | torch.testing.assert_close(x.grad, dx_ref, atol=1e-3, rtol=1e-3) 67 | else: 68 | # For softmax case, check that probabilities sum to 1 69 | torch.testing.assert_close( 70 | out_val.float().sum(dim=-1), 71 | torch.ones(M, device=device, dtype=torch.float32), 72 | atol=1e-2, 73 | rtol=1e-2, 74 | msg="Softmax probabilities don't sum to 1", 75 | ) 76 | dvalues = torch.randn_like(out_val) 77 | out_val.backward(dvalues) 78 | dot = (dvalues.float() * out_val.float()).sum(dim=1, keepdim=True) 79 | grad_topk = out_val.float() * (dvalues.float() - dot) 80 | grad_topk = grad_topk.to(input_dtype) 81 | dx_ref = torch.zeros_like(x) 82 | dx_ref.scatter_(1, out_idx.long(), grad_topk) 83 | torch.testing.assert_close(x.grad, dx_ref, atol=1e-3, rtol=1e-3) 84 | 85 | 86 | # @pytest.mark.parametrize("input_dtype", [torch.float16, torch.float32]) 87 | # def test_topk_extreme_values(input_dtype): 88 | # """Test TopK with extreme input values.""" 89 | # device = "cuda" 90 | # M, N, k = 16, 64, 16 91 | 92 | # # Test with identical values 93 | # x_uniform = torch.full((M, N), 1.0, device=device, dtype=input_dtype) 94 | # out_uniform = topk(x_uniform, k) 95 | # # All output values should be 1.0 96 | # expected = torch.full((M, k), 1.0, device=device, dtype=input_dtype) 97 | # torch.testing.assert_close(out_uniform, expected, atol=1e-3, rtol=1e-3) 98 | 99 | # # Test with large range of values 100 | # x_range = torch.arange(N, dtype=input_dtype, device=device).unsqueeze(0).expand(M, -1) 101 | # out_range = topk(x_range, k) 102 | # # Should get the largest k values in descending order 103 | # expected_range = torch.arange(N-1, N-k-1, -1, dtype=input_dtype, device=device).unsqueeze(0).expand(M, -1) 104 | # torch.testing.assert_close(out_range, expected_range, atol=1e-6, rtol=1e-6) 105 | 106 | 107 | # def test_topk_edge_cases(): 108 | # """Test TopK edge cases.""" 109 | # device = "cuda" 110 | 111 | # # Test k=1 (single maximum) 112 | # M, N = 8, 64 113 | # x = torch.randn(M, N, device=device, dtype=torch.float32) 114 | # out_val = topk(x, 1) 115 | # out_val_ref = torch.max(x, dim=-1, keepdim=True)[0] 116 | # torch.testing.assert_close(out_val, out_val_ref, atol=1e-6, rtol=1e-6) 117 | 118 | # # Test with negative values 119 | # x_neg = torch.randn(M, N, device=device, dtype=torch.float32) - 10.0 120 | # out_neg = topk(x_neg, 8) 121 | # out_ref_neg, _ = torch.topk(x_neg, 8, dim=-1, largest=True, sorted=True) 122 | # torch.testing.assert_close(out_neg, out_ref_neg, atol=1e-6, rtol=1e-6) 123 | -------------------------------------------------------------------------------- /tests/test_layernorm.py: -------------------------------------------------------------------------------- 1 | # tests/test_layernorm.py 2 | 3 | import pytest 4 | import torch 5 | 6 | from quack.rmsnorm import layernorm_fwd, layernorm_ref, layernorm_rstd_ref, layernorm_mean_ref 7 | 8 | 9 | @pytest.mark.parametrize("eps", [1e-5, 1e-6]) 10 | 
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32]) 11 | @pytest.mark.parametrize("M", [1, 37, 199]) 12 | @pytest.mark.parametrize( 13 | "N", [256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144] 14 | ) # , 32768]) 15 | def test_layernorm_forward(M, N, input_dtype, eps): 16 | """Test LayerNorm forward pass against reference implementation.""" 17 | device = "cuda" 18 | 19 | # tolerance depends on precision 20 | if input_dtype == torch.bfloat16: 21 | atol = 1e-2 22 | rtol = 1e-2 23 | elif input_dtype == torch.float16: 24 | atol = 1e-3 25 | rtol = 1e-3 26 | else: 27 | atol = 1e-4 28 | rtol = 1e-4 29 | 30 | torch.random.manual_seed(0) 31 | x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True) 32 | weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True) 33 | 34 | # pure‐PyTorch refs 35 | x_ref = x.detach().clone().requires_grad_() 36 | weight_ref = weight.detach().clone().requires_grad_() 37 | 38 | out, rstd, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=True) 39 | out_ref = layernorm_ref(x_ref, weight_ref, eps=eps) 40 | rstd_ref_val = layernorm_rstd_ref(x_ref, eps=eps) 41 | mean_ref_val = layernorm_mean_ref(x_ref) 42 | 43 | # shapes & dtypes 44 | assert out.shape == x.shape 45 | assert out.dtype == input_dtype 46 | assert rstd.shape == (M,) and rstd.dtype == torch.float32 47 | assert mean.shape == (M,) and mean.dtype == torch.float32 48 | 49 | # numeric check 50 | torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) 51 | torch.testing.assert_close(rstd, rstd_ref_val, atol=6e-4, rtol=6e-4) 52 | torch.testing.assert_close(mean, mean_ref_val, atol=6e-4, rtol=6e-4) 53 | 54 | 55 | @pytest.mark.parametrize("return_rstd", [True, False]) 56 | @pytest.mark.parametrize("return_mean", [True, False]) 57 | def test_layernormnorm_return_rstd_option(return_rstd, return_mean): 58 | """Test that return_rstd option works correctly.""" 59 | device = "cuda" 60 | M, N = 32, 1024 61 | eps = 1e-6 62 | 63 | x = torch.randn(M, N, device=device, dtype=torch.float16) 64 | weight = torch.randn(N, device=device, dtype=torch.float32) 65 | 66 | if return_rstd and return_mean: 67 | out, rstd, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=True) 68 | assert out.shape == (M, N) 69 | assert rstd.shape == (M,) 70 | assert rstd.dtype == torch.float32 71 | assert mean.shape == (M,) 72 | assert mean.dtype == torch.float32 73 | elif return_rstd and not return_mean: 74 | out, rstd = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=False) 75 | assert out.shape == (M, N) 76 | assert rstd.shape == (M,) 77 | assert rstd.dtype == torch.float32 78 | elif not return_rstd and return_mean: 79 | out, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=False, return_mean=True) 80 | assert out.shape == (M, N) 81 | assert mean.shape == (M,) 82 | assert mean.dtype == torch.float32 83 | else: 84 | out = layernorm_fwd(x, weight, eps=eps, return_rstd=False, return_mean=False) 85 | assert out.shape == (M, N) 86 | assert isinstance(out, torch.Tensor) 87 | 88 | 89 | def test_layernorm_input_validation(): 90 | """Test input validation and error handling.""" 91 | device = "cuda" 92 | 93 | # Test 3D input (should fail) 94 | x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16) 95 | weight = torch.randn(1024, device=device, dtype=torch.float32) 96 | 97 | with pytest.raises(AssertionError, match="Input must be 2D"): 98 | layernorm_fwd(x_3d, weight) 99 | 100 | # 
Test weight dimension mismatch 101 | x = torch.randn(32, 1024, device=device, dtype=torch.float16) 102 | weight_wrong = torch.randn(512, device=device, dtype=torch.float32) 103 | 104 | with pytest.raises(ValueError, match="Mismatched mW.shape[0]*"): 105 | layernorm_fwd(x, weight_wrong) 106 | 107 | # Test CPU tensors (should fail) 108 | x_cpu = torch.randn(32, 1024, dtype=torch.float16) 109 | weight_cpu = torch.randn(1024, dtype=torch.float32) 110 | 111 | # with pytest.raises(AssertionError, match="Tensors must be on CUDA device"): 112 | # With torch.library custom op, this now fails with NotImplementedError 113 | with pytest.raises(NotImplementedError): 114 | layernorm_fwd(x_cpu, weight_cpu) 115 | 116 | # Test unsupported dtype 117 | x = torch.randn(32, 1024, device=device, dtype=torch.float64) 118 | weight = torch.randn(1024, device=device, dtype=torch.float32) 119 | 120 | with pytest.raises(AssertionError, match="Unsupported dtype"): 121 | layernorm_fwd(x, weight) 122 | 123 | # Test wrong weight dtype 124 | x = torch.randn(32, 1024, device=device, dtype=torch.float16) 125 | weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float16) 126 | 127 | with pytest.raises(AssertionError, match="Weight must be float32"): 128 | layernorm_fwd(x, weight_wrong_dtype) 129 | -------------------------------------------------------------------------------- /quack/tensormap_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 2 | 3 | from typing import Tuple 4 | from dataclasses import dataclass 5 | 6 | import cutlass 7 | import cutlass.cute as cute 8 | from cutlass.cutlass_dsl import Boolean, const_expr, Int32 9 | from cutlass.utils import TensorMapUpdateMode, TensorMapManager 10 | from cutlass._mlir.dialects import llvm 11 | 12 | 13 | @dataclass(frozen=True) 14 | class TensorMapManagerSm90(TensorMapManager): 15 | """ 16 | We have to subclass cutlass.utils.TensorMapManager bc it takes in warp_id and only 17 | perform the operation if warp_id matches the current warp. 18 | But for Hopper pingpong gemm we want to call it with warp_id 0 and 4. 19 | So we take in a boolean `is_manager_warp` to determine whether to perform the operation or not. 
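    Illustrative sketch (not an API guarantee): with the pingpong schedule, each math
    warpgroup's leader warp passes is_manager_warp=True, e.g.
    ``is_manager_warp = (warp_idx == 0) or (warp_idx == 4)``, while every other warp
    passes False and skips the tensormap init/update entirely.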
20 | """ 21 | 22 | @cute.jit 23 | def init_tensormap_from_atom( 24 | self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean 25 | ) -> None: 26 | if is_manager_warp: 27 | with cute.arch.elect_one(): 28 | cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr) 29 | cute.arch.sync_warp() 30 | return 31 | 32 | @cute.jit 33 | def update_tensormap( 34 | self, 35 | tensor_gmem: Tuple[cute.Tensor, ...], 36 | tma_copy_atom: Tuple[cute.CopyAtom, ...], 37 | tensormap_gmem_ptr: Tuple[cute.Pointer, ...], 38 | is_manager_warp: Boolean, 39 | tensormap_smem_ptr: Tuple[cute.Pointer, ...], 40 | ) -> None: 41 | # updates before touching tensormap in global memory 42 | if is_manager_warp: 43 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): 44 | for copy_atom, tensor, smem_ptr in zip( 45 | tma_copy_atom, tensor_gmem, tensormap_smem_ptr 46 | ): 47 | cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr) 48 | # wait until it's safe to update tensormap in global memory 49 | with cute.arch.elect_one(): 50 | cute.arch.cp_async_bulk_commit_group() 51 | cute.arch.cp_async_bulk_wait_group(0, read=True) 52 | cute.arch.sync_warp() 53 | # updates to tensormap in global memory 54 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): 55 | for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): 56 | cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) 57 | else: 58 | for copy_atom, tensor, gmem_ptr in zip( 59 | tma_copy_atom, tensor_gmem, tensormap_gmem_ptr 60 | ): 61 | cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr) 62 | cute.arch.sync_warp() 63 | cute.nvgpu.cpasync.fence_tma_desc_release() 64 | 65 | @cute.jit 66 | def update_tensormap_shape( 67 | self, 68 | tensormap_gmem_ptr: Tuple[cute.Pointer, ...], 69 | is_manager_warp: Boolean, 70 | tensormap_smem_ptr: Tuple[cute.Pointer, ...], 71 | shapes: Tuple[Int32, ...], 72 | orders: cutlass.Constexpr[Tuple[int, ...]], 73 | ) -> None: 74 | # updates before touching tensormap in global memory 75 | if is_manager_warp: 76 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): 77 | for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders): 78 | smem_ptr_i32 = smem_ptr.toint().ir_value() 79 | llvm.inline_asm( 80 | None, 81 | [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()], 82 | "{\n\t" 83 | ".reg .b64 smem_ptr_i64;\n\t" 84 | "cvt.u64.u32 smem_ptr_i64, $0;\n\t" 85 | f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t" 86 | "}\n", 87 | "r,r", 88 | has_side_effects=True, 89 | is_align_stack=False, 90 | asm_dialect=llvm.AsmDialect.AD_ATT, 91 | ) 92 | # wait until it's safe to update tensormap in global memory 93 | with cute.arch.elect_one(): 94 | cute.arch.cp_async_bulk_commit_group() 95 | cute.arch.cp_async_bulk_wait_group(0, read=True) 96 | cute.arch.sync_warp() 97 | # updates to tensormap in global memory 98 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): 99 | for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): 100 | cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) 101 | else: 102 | assert len(shapes) == len(orders) == len(tensormap_gmem_ptr) 103 | for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders): 104 | gmem_ptr_i64 = gmem_ptr.toint().ir_value() 105 | llvm.inline_asm( 106 | None, 107 | [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()], 108 | 
f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;", 109 | "l,r", 110 | has_side_effects=True, 111 | is_align_stack=False, 112 | asm_dialect=llvm.AsmDialect.AD_ATT, 113 | ) 114 | cute.arch.sync_warp() 115 | cute.nvgpu.cpasync.fence_tma_desc_release() 116 | -------------------------------------------------------------------------------- /benchmarks/benchmark_softmax.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Type 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from triton.testing import do_bench 8 | 9 | import cutlass 10 | import cutlass.torch as cutlass_torch 11 | 12 | from quack.softmax import softmax 13 | 14 | try: 15 | from liger_kernel.transformers.functional import liger_softmax 16 | except ImportError: 17 | liger_softmax = None 18 | 19 | 20 | def run_softmax( 21 | M, 22 | N, 23 | dtype: Type[cutlass.Numeric], 24 | warmup_iterations=10, 25 | iterations=1000, 26 | ): 27 | if not torch.cuda.is_available(): 28 | raise RuntimeError(f"Ampere GPU is required to run this example!") 29 | 30 | print(f"Tensor dimensions: [{M}, {N}]") 31 | print(f"Input and Output Data type: {dtype}") 32 | 33 | torch_dtype = cutlass_torch.dtype(dtype) 34 | 35 | device = "cuda" 36 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype) 37 | 38 | print(f"Input tensor shapes:") 39 | print(f"x: {x.shape}, dtype: {x.dtype}") 40 | out = softmax(x) 41 | # compiled_func_ref = torch.compile(lambda x: F.softmax(x, dim=-1)) 42 | compiled_func_ref = torch.compile(lambda x: F.softmax(x, dim=-1)) 43 | fn = lambda: softmax(x) 44 | time.sleep(0.5) 45 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 46 | mem_bw = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9) 47 | print(f"Kernel execution time: {avg_time:.4f} ms") 48 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 49 | 50 | fn = lambda: compiled_func_ref(x) 51 | for _ in range(5): fn() # warm up 52 | time.sleep(0.5) 53 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 54 | mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9) 55 | print(f"Torch compile kernel execution time: {avg_time:.4f} ms") 56 | print(f"Torch compile mem throughput: {mem_bw_ref:.2f} GB/s") 57 | 58 | if liger_softmax is not None: 59 | fn = lambda: liger_softmax(x) 60 | for _ in range(5): fn() # warm up 61 | time.sleep(0.5) 62 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 63 | mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9) 64 | print(f"Liger kernel execution time: {avg_time:.4f} ms") 65 | print(f"Liger mem throughput: {mem_bw_ref:.2f} GB/s") 66 | 67 | return mem_bw, mem_bw_ref 68 | 69 | 70 | def run_softmax_backward( 71 | M, 72 | N, 73 | dtype: Type[cutlass.Numeric], 74 | warmup_iterations=10, 75 | iterations=1000, 76 | ): 77 | if not torch.cuda.is_available(): 78 | raise RuntimeError(f"Ampere GPU is required to run this example!") 79 | 80 | print(f"Tensor dimensions: [{M}, {N}]") 81 | print(f"Input and Output Data type: {dtype}") 82 | 83 | torch_dtype = cutlass_torch.dtype(dtype) 84 | 85 | device = "cuda" 86 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True) 87 | x_ref = x.detach().clone().requires_grad_() 88 | 89 | print(f"Input tensor shapes:") 90 | print(f"x: {x.shape}, dtype: {x.dtype}") 91 | 92 | y = softmax(x) 93 | dy = torch.randn_like(y) 94 | 95 | time.sleep(0.5) 96 | fn = lambda: 
torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) 97 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 98 | # Memory: read dy and y, write ax backward 99 | mem_bw = round(3 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9) 100 | print(f"Kernel execution time: {avg_time:.4f} ms") 101 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 102 | 103 | # Reference implementation 104 | y_ref = F.softmax(x_ref, dim=-1) 105 | compiled_func_ref = torch.compile(lambda: torch.autograd.grad(y_ref, x_ref, grad_outputs=dy, retain_graph=True)) 106 | 107 | for _ in range(5): compiled_func_ref() # warm up 108 | time.sleep(0.5) 109 | avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations) 110 | mem_bw_ref = round(3 * x.numel() * dtype.width // 8 / (avg_time_ref / 1000) / 1e9) 111 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms") 112 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 113 | 114 | return mem_bw, mem_bw_ref 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser( 119 | description="Benchmark softmax forward and backward passes" 120 | ) 121 | parser.add_argument("--M", default=8192, type=int) 122 | parser.add_argument("--N", default=16384, type=int) 123 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16) 124 | parser.add_argument("--warmup_iterations", default=10, type=int) 125 | parser.add_argument("--iterations", default=100, type=int) 126 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass") 127 | 128 | args = parser.parse_args() 129 | torch.manual_seed(0) 130 | 131 | if args.backward: 132 | print("=== Softmax Backward Pass Benchmark ===") 133 | run_softmax_backward( 134 | args.M, 135 | args.N, 136 | dtype=args.dtype, 137 | warmup_iterations=args.warmup_iterations, 138 | iterations=args.iterations, 139 | ) 140 | else: 141 | print("=== Softmax Forward Pass Benchmark ===") 142 | run_softmax( 143 | args.M, 144 | args.N, 145 | dtype=args.dtype, 146 | warmup_iterations=args.warmup_iterations, 147 | iterations=args.iterations, 148 | ) 149 | 150 | # MN_pairs = [(32768, 256), (32768, 512), (32768, 1024), (32768, 2048), (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768), (32768, 65536), (16384, 131072), (8192, 262144)] 151 | # # MN_pairs = [(32768, 1024)] 152 | # results = [] 153 | # for M, N in MN_pairs: 154 | # res = run_softmax( 155 | # M, 156 | # N, 157 | # dtype=args.dtype, 158 | # warmup_iterations=args.warmup_iterations, 159 | # iterations=args.iterations, 160 | # ) 161 | # results.append(res) 162 | # # print(results) 163 | # print([x for x, _ in results]) 164 | -------------------------------------------------------------------------------- /docs/dsl_control_flow.rst: -------------------------------------------------------------------------------- 1 | .. _dsl_control_flow: 2 | .. |DC| replace:: dynamic compilation 3 | .. |IR| replace:: intermediate representation (IR) 4 | .. |DSL| replace:: CuTe DSL 5 | .. |Constexpr| replace:: **Constexpr** (compile-time Python value) 6 | 7 | Control Flow 8 | ================== 9 | 10 | 11 | Overview 12 | -------- 13 | |DSL| walks Python's AST and converts each control-flow construct it finds into 14 | structured |IR|. 
You can therefore write ordinary Python loops and branches 15 | while the compiler decides—statement by statement—whether to 16 | 17 | * **evaluate at compile time** if it's a native Python control flow, or 18 | * **emit intermediate representation (IR)** when the control flow is marked as dynamic. 19 | 20 | Passing |IR| values to a native Python control flow will result in an error. 21 | 22 | 23 | For Loops 24 | --------- 25 | |DSL| recognises three kinds of ranges for ``for`` loops: 26 | 27 | * ``range`` – the Python built-in, always lowered to |IR| 28 | * ``cutlass.range`` - Same as Python built-in ``range``, but supports advanced unrolling and pipelining control 29 | * ``cutlass.range_constexpr`` – unrolled at compile time 30 | 31 | 32 | range(...)/cutlass.range(...) 33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | Use when you *always* want a loop in the generated |IR|, even if the inputs 35 | are Python values. 36 | 37 | cutlass.range_constexpr(...) 38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 39 | Runs in the Python interpreter and is fully unrolled before code generation. 40 | All loop indices must be |Constexpr|. 41 | 42 | 43 | **Example:** 44 | 45 | .. code-block:: python 46 | 47 | @cute.jit 48 | def control_flow_examples(bound: cutlass.Int32): 49 | n = 10 50 | 51 | # ✅ This loop is Python loop, evaluated at compile time. 52 | for i in cutlass.range_constexpr(n): 53 | cute.printf("%d\\n", i) 54 | 55 | # ✅ This loop is dynamic, even when bound is Python value. 56 | for i in range(n): 57 | cute.printf("%d\\n", i) 58 | 59 | # ❌ This loop bound is a dynamic value, not allowed in Python loop. 60 | # Should use `range` instead. 61 | for i in cutlass.range_constexpr(bound): 62 | cute.printf("%d\\n", i) 63 | 64 | # ✅ This loop is dynamic, emitted IR loop. 65 | for i in range(bound): 66 | cute.printf("%d\\n", i) 67 | 68 | # ✅ This loop is dynamic, emitted IR loop with unrolling 69 | for i in cutlass.range(bound, unroll=2): 70 | cute.printf("%d\\n", i) 71 | 72 | 73 | If-Else Statements 74 | ------------------ 75 | 76 | Standard Python ``if``/``elif``/``else`` is supported. 77 | 78 | * **Predicate without annotation** → lowered to |IR|. 79 | * **Predicate annotated with `cutlass.const_expr`** → evaluated at compile time. 80 | 81 | **Example:** 82 | 83 | .. code-block:: python 84 | 85 | @cute.jit 86 | def main(const_var: cutlass.Constexpr, dynamic_var: cutlass.Int32): 87 | # ✅ This branch is Python branch, evaluated at compile time. 88 | if cutlass.const_expr(const_var): 89 | cute.printf("Const branch\\n") 90 | else: 91 | cute.printf("Const else\\n") 92 | 93 | # ✅ This branch is dynamic branch, emitted IR branch. 94 | if dynamic_var == 10: 95 | cute.printf("Dynamic True\\n") 96 | else: 97 | cute.printf("Dynamic False\\n") 98 | 99 | # ❌ Using a dynamic value with `cutlass.const_expr` is not allowed. 100 | if cutlass.const_expr(dynamic_var == 10): 101 | cute.printf("Bound is 10\\n") 102 | 103 | 104 | While Loops 105 | ----------- 106 | 107 | Standard Python ``while`` is supported. 108 | 109 | * **Condition without annotation** → lowered to |IR|. 110 | * **Condition annotated with `cutlass.const_expr`** → evaluated at compile time. 111 | 112 | **Example:** 113 | 114 | .. code-block:: python 115 | 116 | @cute.jit 117 | def main(dynamic_var: cutlass.Int32): 118 | n = 0 119 | 120 | # ✅ This is Python while loop, evaluated at compile time. 121 | while cutlass.const_expr(n < 10): 122 | cute.printf("Const branch\\n") 123 | n += 1 124 | 125 | # ✅ This is dynamic while loop, emitted IR while loop. 
126 | while dynamic_var == 10: 127 | cute.printf("Dynamic True\\n") 128 | n += 1 129 | 130 | # ❌ Using a dynamic value with `cutlass.const_expr` is not allowed. 131 | while cutlass.const_expr(n < dynamic_var): 132 | n += 1 133 | 134 | 135 | Compile-Time Metaprogramming 136 | ---------------------------- 137 | 138 | Mix compile-time constructs with normal |DSL| code to generate specialised 139 | kernels without runtime overhead. A compile-time flag can, for example, toggle 140 | an optional **ReLU** epilogue: 141 | 142 | .. code-block:: python 143 | 144 | @cute.kernel 145 | def gemm(..., do_relu: cutlass.Constexpr): 146 | # main GEMM work 147 | ... 148 | if cutlass.const_expr(do_relu): # compile-time guard 149 | # ReLU code is emitted only when do_relu is True 150 | ... 151 | 152 | .. code-block:: text 153 | 154 | gemm(..., False) # ReLU is omitted from the generated |IR| 155 | gemm(..., True) # ReLU is included 156 | 157 | 158 | Limitations of Dynamic Control Flow 159 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 160 | 161 | * Early-exit ``break``, ``continue``, ``pass`` or raising exception from 162 | control flow body are not yet supported. 163 | * Operations in the control flow body are traced only when tracing is active in 164 | that region. 165 | * Values originating in control flow body are not available outside the control 166 | flow. 167 | * Changing type of a variable in control flow body is not allowed. 168 | 169 | **Example:** 170 | 171 | .. code-block:: python 172 | 173 | @cute.jit 174 | def control_flow_negative_examples(predicate: cutlass.Boolean): 175 | n = 10 176 | 177 | # ❌ This loop is dynamic, early-exit isn't allowed. 178 | for i in range(n): 179 | if i == 5: 180 | break # Early-exit 181 | 182 | if predicate: 183 | val = 10 184 | # ❌ return from control flow body is not allowed. 185 | return 186 | # ❌ Raising exception from control flow body is not allowed. 187 | raise ValueError("This is not allowed") 188 | # ❌ Using pass in control flow body is not allowed. 189 | pass 190 | 191 | # ❌ val is not available outside the dynamic if 192 | cute.printf("%d\\n", val) 193 | 194 | if predicate: 195 | # ❌ Changing type of a variable in control flow body is not allowed. 196 | n = 10.0 197 | 198 | -------------------------------------------------------------------------------- /tests/test_softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. 
2 | 3 | import pytest 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | from quack.softmax import softmax 9 | 10 | 11 | torch._dynamo.config.cache_size_limit = 1024 12 | torch._dynamo.config.accumulated_cache_size_limit = 1024 13 | 14 | 15 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32]) 16 | # @pytest.mark.parametrize("input_dtype", [torch.float32]) 17 | @pytest.mark.parametrize( 18 | "N", 19 | [192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144], 20 | # [32768] 21 | ) 22 | @pytest.mark.parametrize("M", [1, 37, 199]) 23 | @pytest.mark.parametrize("function", [softmax, torch.compile(softmax, fullgraph=True)]) 24 | # @pytest.mark.parametrize("M", [1]) 25 | def test_softmax(M, N, input_dtype, function): 26 | """Test Softmax forward and backward passes against reference implementation.""" 27 | device = "cuda" 28 | # Set tolerance based on dtype 29 | if input_dtype == torch.bfloat16: 30 | atol = 1e-2 31 | rtol = 1e-2 32 | elif input_dtype == torch.float16: 33 | atol = 1e-3 34 | rtol = 1e-3 35 | else: 36 | atol = 1e-4 37 | rtol = 1e-4 38 | 39 | torch.random.manual_seed(0) 40 | # Create input tensors (scale down to avoid overflow in softmax) 41 | x = (0.1 * torch.randn(M, N, device=device, dtype=input_dtype)).requires_grad_() 42 | x_ref = x.detach().clone().requires_grad_(True) 43 | 44 | # Forward pass 45 | out = function(x) 46 | out_ref = F.softmax(x_ref, dim=-1) 47 | 48 | # Check output shape and dtype 49 | assert out.shape == x.shape 50 | assert out.dtype == input_dtype 51 | # Check accuracy 52 | torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) 53 | # Check softmax properties 54 | # Sum along last dimension should be 1 55 | sums = torch.sum(out, dim=-1) 56 | torch.testing.assert_close(sums, torch.ones_like(sums), atol=1e-4, rtol=1e-4) 57 | # All values should be positive 58 | assert (out >= 0).all() 59 | # All values should be <= 1 60 | assert (out <= 1).all() 61 | 62 | # Test backward pass 63 | dy = torch.randn_like(out) 64 | torch.cuda.synchronize() # without sync, torch.autograd gets wrong results 65 | (dx,) = torch.autograd.grad(out, x, grad_outputs=dy) 66 | (dx_ref,) = torch.autograd.grad(out_ref, x_ref, grad_outputs=dy) 67 | assert dx.shape == dy.shape 68 | assert dx.dtype == input_dtype 69 | torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol) 70 | 71 | 72 | @pytest.mark.parametrize("input_dtype", [torch.float16, torch.float32]) 73 | @pytest.mark.parametrize("function", [softmax, torch.compile(softmax, fullgraph=True)]) 74 | def test_softmax_extreme_values(input_dtype, function): 75 | """Test Softmax with extreme input values.""" 76 | device = "cuda" 77 | M, N = 16, 1024 78 | # Test with large positive values 79 | x_large = torch.full((M, N), 10.0, device=device, dtype=input_dtype) 80 | out_large = function(x_large) 81 | # Should be uniform since all values are the same 82 | expected = torch.full_like(out_large, 1.0 / N) 83 | torch.testing.assert_close(out_large, expected, atol=1e-3, rtol=1e-3) 84 | # Test with large negative values 85 | x_small = torch.full((M, N), -10.0, device=device, dtype=input_dtype) 86 | out_small = function(x_small) 87 | # Should also be uniform 88 | torch.testing.assert_close(out_small, expected, atol=1e-3, rtol=1e-3) 89 | # Test with mixed extreme values 90 | x_mixed = torch.zeros((M, N), device=device, dtype=input_dtype) 91 | x_mixed[:, 0] = 10.0 # One large value per row 92 | x_mixed[:, 1:] = -10.0 # Rest are small 93 | out_mixed = 
function(x_mixed) 94 | # First column should be close to 1, rest close to 0 95 | assert (out_mixed[:, 0] > 0.99).all() 96 | assert (out_mixed[:, 1:] < 0.01).all() 97 | 98 | 99 | @pytest.mark.parametrize("function", [softmax, torch.compile(softmax, fullgraph=True)]) 100 | def test_softmax_numerical_stability(function): 101 | """Test that softmax is numerically stable.""" 102 | device = "cuda" 103 | M, N = 8, 512 104 | # Create input with a wide range of values 105 | x = torch.randn(M, N, device=device, dtype=torch.float32) 106 | # Add large constant to test numerical stability 107 | x_shifted = x + 100.0 108 | out = function(x) 109 | out_shifted = function(x_shifted) 110 | # Results should be identical (softmax is translation invariant) 111 | torch.testing.assert_close(out, out_shifted, atol=1e-6, rtol=1e-6) 112 | 113 | 114 | # def test_softmax_backward_properties(): 115 | # """Test mathematical properties of softmax backward pass.""" 116 | # device = "cuda" 117 | # M, N = 16, 1024 118 | 119 | # torch.random.manual_seed(42) 120 | # # Create test inputs 121 | # x = torch.randn(M, N, device=device, dtype=torch.float32, requires_grad=True) 122 | # y = F.softmax(x, dim=-1) 123 | 124 | # # Test 1: Gradient of uniform upstream should sum to zero along last dim 125 | # dy_uniform = torch.ones_like(y) 126 | # dx = softmax_backward(dy_uniform, y.detach()) 127 | 128 | # # Sum along last dimension should be approximately zero 129 | # row_sums = torch.sum(dx, dim=-1) 130 | # torch.testing.assert_close(row_sums, torch.zeros_like(row_sums), atol=1e-6, rtol=1e-6) 131 | 132 | # # Test 2: Gradient should be zero when upstream gradient is proportional to y 133 | # # (since softmax(x) ∝ y means dy ∝ y, so dy - dot*1 = 0) 134 | # for scale in [0.5, 1.0, 2.0]: 135 | # dy_prop = scale * y 136 | # dx_prop = softmax_backward(dy_prop, y.detach()) 137 | # torch.testing.assert_close(dx_prop, torch.zeros_like(dx_prop), atol=1e-6, rtol=1e-6) 138 | 139 | 140 | # def test_softmax_backward_edge_cases(): 141 | # """Test softmax backward with edge cases.""" 142 | # device = "cuda" 143 | 144 | # # Test with one-hot softmax output (peaked distribution) 145 | # M, N = 8, 512 146 | # y = torch.zeros(M, N, device=device, dtype=torch.float32) 147 | # y[:, 0] = 1.0 # One-hot at first position 148 | 149 | # dy = torch.randn_like(y) 150 | # dx = softmax_backward(dy, y) 151 | 152 | # # Check that gradients exist and are finite 153 | # assert torch.isfinite(dx).all() 154 | 155 | # # Test with uniform softmax output 156 | # y_uniform = torch.full((M, N), 1.0/N, device=device, dtype=torch.float32) 157 | # dy_uniform = torch.randn_like(y_uniform) 158 | # dx_uniform = softmax_backward(dy_uniform, y_uniform) 159 | 160 | # # Check that gradients exist and are finite 161 | # assert torch.isfinite(dx_uniform).all() 162 | 163 | # # Row sums should be zero (property of softmax jacobian) 164 | # row_sums = torch.sum(dx_uniform, dim=-1) 165 | # torch.testing.assert_close(row_sums, torch.zeros_like(row_sums), atol=1e-6, rtol=1e-6) 166 | -------------------------------------------------------------------------------- /benchmarks/benchmark_topk.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Type 4 | 5 | import torch 6 | from triton.testing import do_bench 7 | 8 | import cutlass 9 | import cutlass.torch as cutlass_torch 10 | 11 | from quack.topk import topk, topk_bwd 12 | 13 | try: 14 | import rtopk 15 | except ImportError: 16 | rtopk = None 17 | 18 | 19 
| def run_topk( 20 | M, 21 | N, 22 | k, 23 | dtype: Type[cutlass.Numeric], 24 | softmax: bool = False, 25 | backward: bool = False, 26 | warmup_iterations=10, 27 | iterations=1000, 28 | ): 29 | if not torch.cuda.is_available(): 30 | raise RuntimeError(f"CUDA GPU is required to run this example!") 31 | 32 | print(f"Tensor dimensions: [{M}, {N}], k={k}") 33 | print(f"Input and Output Data type: {dtype}") 34 | 35 | torch_dtype = cutlass_torch.dtype(dtype) 36 | 37 | device = "cuda" 38 | x = torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True) 39 | 40 | print(f"Input tensor shapes:") 41 | print(f"x: {x.shape}, dtype: {x.dtype}") 42 | out, idx = topk(x, k, softmax=softmax) 43 | print(f"Output shape: {out.shape}") 44 | 45 | if backward: 46 | dvalues = torch.randn_like(out) 47 | fn = lambda: topk_bwd(dvalues, out, idx, N, softmax=softmax) 48 | else: 49 | # Benchmark our implementation 50 | fn = lambda: topk(x, k, softmax=softmax) 51 | fn() # warm up 52 | time.sleep(0.5) 53 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 54 | # Memory: read input (M*N elements), write output (M*k elements) 55 | if backward: 56 | mem_accessed = (M * N + 2 * M * k) * dtype.width // 8 57 | else: 58 | mem_accessed = (M * N + M * k) * dtype.width // 8 59 | mem_bw = round(mem_accessed / (avg_time / 1000) / 1e9, 2) 60 | print(f"Kernel execution time: {avg_time:.4f} ms") 61 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 62 | 63 | # Benchmark PyTorch reference 64 | if backward: 65 | fn_ref = lambda: torch.autograd.grad( 66 | torch.softmax(torch.topk(x, k, dim=-1, largest=True, sorted=True)[0], dim=-1) 67 | if softmax 68 | else torch.topk(x, k, dim=-1, largest=True, sorted=True)[0], 69 | x, 70 | grad_outputs=dvalues, 71 | retain_graph=True, 72 | ) 73 | for _ in range(5): 74 | fn_ref() 75 | time.sleep(0.5) 76 | avg_time_ref = do_bench(fn_ref, warmup=warmup_iterations, rep=iterations) 77 | mem_bw_ref = round(mem_accessed / (avg_time_ref / 1000) / 1e9, 2) 78 | print(f"Ref backward execution time: {avg_time_ref:.4f} ms") 79 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 80 | speedup = avg_time_ref / avg_time 81 | print(f"Speedup: {speedup:.2f}x") 82 | else: 83 | fn_ref = lambda: torch.topk(x, k, dim=-1, largest=True, sorted=True)[0] 84 | for _ in range(5): fn_ref() # warm up 85 | time.sleep(0.5) 86 | avg_time_ref = do_bench(fn_ref, warmup=warmup_iterations, rep=iterations) 87 | mem_bw_ref = round(mem_accessed / (avg_time_ref / 1000) / 1e9, 2) 88 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms") 89 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 90 | 91 | speedup = avg_time_ref / avg_time 92 | print(f"Speedup: {speedup:.2f}x") 93 | 94 | if rtopk is not None: 95 | fn_rtopk = lambda: rtopk.ops.rtopk(x, k, max_iter=512) 96 | for _ in range(5): fn_rtopk() # warm up 97 | time.sleep(0.5) 98 | avg_time_ref = do_bench(fn_rtopk, warmup=warmup_iterations, rep=iterations) 99 | mem_bw_ref = round(mem_accessed / (avg_time_ref / 1000) / 1e9, 2) 100 | print(f"RTopK kernel execution time: {avg_time_ref:.4f} ms") 101 | print(f"RTopK mem throughput: {mem_bw_ref:.2f} GB/s") 102 | 103 | # do_bench doesn't seem very accurate for very fast kernels, so we use pytorch_profiler 104 | from flash_attn.utils.benchmark import pytorch_profiler 105 | pytorch_profiler(fn) 106 | if rtopk is not None: 107 | pytorch_profiler(fn_rtopk) 108 | 109 | return mem_bw, mem_bw_ref, speedup 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser( 114 | description="Benchmark top-k 
operation" 115 | ) 116 | parser.add_argument("--M", default=8192, type=int) 117 | parser.add_argument("--N", default=1024, type=int) 118 | parser.add_argument("--k", default=32, type=int) 119 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16) 120 | parser.add_argument("--softmax", action="store_true", help="Apply softmax to top-k values") 121 | parser.add_argument("--warmup_iterations", default=10, type=int) 122 | parser.add_argument("--iterations", default=100, type=int) 123 | parser.add_argument("--sweep", action="store_true", help="Run sweep across different N and k values") 124 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass") 125 | 126 | args = parser.parse_args() 127 | torch.manual_seed(0) 128 | 129 | cutlass.cuda.initialize_cuda_context() 130 | 131 | if args.sweep: 132 | print("=== Top-K Sweep Benchmark ===") 133 | # Test different combinations of N and k 134 | N_values = [64, 128, 256, 512, 1024] 135 | k_values = [8, 16, 32] 136 | 137 | results = [] 138 | for N in N_values: 139 | for k in k_values: 140 | if k > N // 2: # Skip if k is too large relative to N 141 | continue 142 | print(f"\n--- N={N}, k={k} ---") 143 | try: 144 | mem_bw, mem_bw_ref, speedup = run_topk( 145 | args.M, 146 | N, 147 | k, 148 | dtype=args.dtype, 149 | softmax=args.softmax, 150 | backward=args.backward, 151 | warmup_iterations=args.warmup_iterations, 152 | iterations=args.iterations, 153 | ) 154 | results.append((N, k, mem_bw, mem_bw_ref, speedup)) 155 | except Exception as e: 156 | print(f"Error with N={N}, k={k}: {e}") 157 | 158 | print("\n=== Summary ===") 159 | print("N\tk\tOurs (GB/s)\tRef (GB/s)\tSpeedup") 160 | for N, k, mem_bw, mem_bw_ref, speedup in results: 161 | print(f"{N}\t{k}\t{mem_bw}\t\t{mem_bw_ref}\t\t{speedup:.2f}x") 162 | else: 163 | print("=== Top-K Benchmark ===") 164 | run_topk( 165 | args.M, 166 | args.N, 167 | args.k, 168 | dtype=args.dtype, 169 | softmax=args.softmax, 170 | backward=args.backward, 171 | warmup_iterations=args.warmup_iterations, 172 | iterations=args.iterations, 173 | ) 174 | -------------------------------------------------------------------------------- /benchmarks/benchmark_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Type 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from triton.testing import do_bench 8 | 9 | import cutlass 10 | import cutlass.torch as cutlass_torch 11 | 12 | from quack.cross_entropy import cross_entropy_fwd, cross_entropy 13 | 14 | 15 | def run_cross_entropy( 16 | M, 17 | N, 18 | dtype: Type[cutlass.Numeric], 19 | warmup_iterations=2, 20 | iterations=200, 21 | return_dx=False, 22 | ): 23 | if not torch.cuda.is_available(): 24 | raise RuntimeError(f"Ampere GPU is required to run this example!") 25 | 26 | print(f"Tensor dimensions: [{M}, {N}]") 27 | print(f"Input Data type: {dtype}") 28 | 29 | torch_dtype = cutlass_torch.dtype(dtype) 30 | 31 | device = "cuda" 32 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype) 33 | target = torch.randint(0, N, (M,), device=device, dtype=torch.int64) 34 | 35 | loss = cross_entropy_fwd(x, target, return_dx=return_dx) 36 | 37 | compiled_func_ref = torch.compile(lambda x, target: F.cross_entropy(x, target, reduction='none')) 38 | 39 | fn = lambda: cross_entropy_fwd(x, target, return_dx=return_dx) 40 | time.sleep(0.5) 41 | 
avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 42 | # Memory bandwidth calculation: read x (M*N elements) + read target (M elements) + write loss (M elements) 43 | mem_bytes = (M * N * (2 if return_dx else 1) + M + M) * dtype.width // 8 44 | mem_bw = round(mem_bytes / (avg_time / 1000) / 1e9) 45 | print(f"Kernel execution time: {avg_time:.4f} ms") 46 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 47 | 48 | fn_ref = lambda: compiled_func_ref(x, target) 49 | for _ in range(5): fn_ref() # warm up 50 | time.sleep(0.5) 51 | avg_time = do_bench(fn_ref, warmup=warmup_iterations, rep=iterations) 52 | mem_bytes = (M * N + M + M) * dtype.width // 8 53 | mem_bw_ref = round(mem_bytes / (avg_time / 1000) / 1e9) 54 | print(f"Ref kernel execution time: {avg_time:.4f} ms") 55 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 56 | 57 | # from flash_attn.utils.benchmark import pytorch_profiler 58 | # pytorch_profiler(fn) 59 | # pytorch_profiler(fn_ref) 60 | # pytorch_profiler(torch.compile(torch.logsumexp), x, dim=-1) 61 | 62 | return mem_bw, mem_bw_ref 63 | 64 | 65 | 66 | def run_cross_entropy_backward( 67 | M, 68 | N, 69 | dtype: Type[cutlass.Numeric], 70 | warmup_iterations=10, 71 | iterations=1000, 72 | ): 73 | if not torch.cuda.is_available(): 74 | raise RuntimeError(f"Ampere GPU is required to run this example!") 75 | 76 | print(f"Tensor dimensions: [{M}, {N}]") 77 | print(f"Input Data type: {dtype}") 78 | 79 | torch_dtype = cutlass_torch.dtype(dtype) 80 | 81 | device = "cuda" 82 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True) 83 | target = torch.randint(0, N, (M,), device=device, dtype=torch.int64) 84 | x_ref = x.detach().clone().requires_grad_() 85 | 86 | print(f"Input tensor shapes:") 87 | print(f"x: {x.shape}, dtype: {x.dtype}") 88 | print(f"target: {target.shape}, dtype: {target.dtype}") 89 | 90 | loss = cross_entropy(x, target, reduction="none") 91 | dloss = torch.randn(M, device=device, dtype=torch.float32) 92 | torch.cuda.synchronize() 93 | 94 | # Reference implementation 95 | loss_ref = F.cross_entropy(x_ref, target, reduction='none') 96 | compiled_func_ref = torch.compile(lambda: torch.autograd.grad(loss_ref, x_ref, grad_outputs=dloss, retain_graph=True)) 97 | 98 | for _ in range(5): compiled_func_ref() # warm up 99 | time.sleep(0.5) 100 | avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations) 101 | mem_bw_ref = round((2 * x.numel() * x.element_size() + target.numel() * target.element_size() + 102 | dloss.numel() * dloss.element_size()) / (avg_time_ref / 1000) / 1e9) 103 | 104 | time.sleep(0.5) 105 | fn = lambda: torch.autograd.grad(loss, x, grad_outputs=dloss, retain_graph=True) 106 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 107 | # Memory bandwidth calculation: read x (M*N) + read target (M) + read dloss (M) + write grad (M*N) 108 | mem_bw = round((2 * x.numel() * x.element_size() + target.numel() * target.element_size() + 109 | dloss.numel() * dloss.element_size()) / (avg_time / 1000) / 1e9) 110 | print(f"Kernel execution time: {avg_time:.4f} ms") 111 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 112 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms") 113 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 114 | 115 | return mem_bw, mem_bw_ref 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser( 120 | description="Benchmark cross entropy forward and backward passes" 121 | ) 122 | parser.add_argument("--M", default=8192, type=int) 123 | 
parser.add_argument("--N", default=16384, type=int) 124 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16) 125 | parser.add_argument("--warmup_iterations", default=10, type=int) 126 | parser.add_argument("--iterations", default=100, type=int) 127 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass") 128 | parser.add_argument("--fwd_dx", action="store_true", help="Benchmark forward pass that also computes dx") 129 | 130 | args = parser.parse_args() 131 | torch.manual_seed(0) 132 | cutlass.cuda.initialize_cuda_context() 133 | 134 | 135 | if args.backward: 136 | print("=== Cross Entropy Backward Pass Benchmark ===") 137 | run_cross_entropy_backward( 138 | args.M, 139 | args.N, 140 | dtype=args.dtype, 141 | warmup_iterations=args.warmup_iterations, 142 | iterations=args.iterations, 143 | ) 144 | else: 145 | print("=== Cross Entropy Forward Pass Benchmark ===") 146 | run_cross_entropy( 147 | args.M, 148 | args.N, 149 | dtype=args.dtype, 150 | warmup_iterations=args.warmup_iterations, 151 | iterations=args.iterations, 152 | return_dx=args.fwd_dx, 153 | ) 154 | 155 | 156 | ''' 157 | #MN_pairs = [(32768, 256), (32768, 512), (32768, 1024), (32768, 2048), (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768), (32768, 65536), (16384, 131072), (8192, 262144)] 158 | MN_pairs = [(32768, 65536)] 159 | results = [] 160 | for M, N in MN_pairs: 161 | res = run_cross_entropy_backward( 162 | M, 163 | N, 164 | dtype=args.dtype, 165 | warmup_iterations=args.warmup_iterations, 166 | iterations=args.iterations, 167 | ) 168 | results.append(res) 169 | print(results) 170 | #print([x for x, _ in results]) 171 | ''' 172 | -------------------------------------------------------------------------------- /tests/test_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025, Tri Dao. 
2 | import math 3 | import pytest 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from quack.linear import linear_func, linear_act_func 8 | from quack.gemm_interface import ( 9 | gemm_add_inplace, 10 | gemm_dact, 11 | gemm_act_ref, 12 | gemm_dact_ref, 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 17 | @pytest.mark.parametrize("has_bias", [False, True]) 18 | @pytest.mark.parametrize("out_features", [1504, 2048]) 19 | @pytest.mark.parametrize("in_features", [736, 4096]) 20 | # @pytest.mark.parametrize("out_features", [2048]) 21 | # @pytest.mark.parametrize("in_features", [4096]) 22 | def test_linear(in_features, out_features, has_bias, input_dtype): 23 | device = "cuda" 24 | torch.random.manual_seed(0) 25 | m = 1920 26 | x = torch.randn((m, in_features), device=device, dtype=input_dtype) 27 | x = x[::2].requires_grad_(True) # Testing non-contiguous 28 | w = ( 29 | torch.randn((out_features, in_features), device=device, dtype=input_dtype) 30 | / math.sqrt(in_features) 31 | ).requires_grad_() 32 | bias = torch.randn(out_features, device=device, requires_grad=True) if has_bias else None 33 | x_ref, w_ref, bias_ref = [ 34 | t.detach().clone().float().requires_grad_(True) if t is not None else None 35 | for t in (x, w, bias) 36 | ] 37 | x_pt, w_pt, bias_pt = [ 38 | t.detach().clone().to(x.dtype).requires_grad_(True) if t is not None else None 39 | for t in (x, w, bias) 40 | ] 41 | out = linear_func(x, w, bias, tuned=False) # Disable tuning for faster test 42 | out_ref = F.linear(x_ref, w_ref, bias_ref) 43 | out_pt = F.linear(x_pt, w_pt, bias_pt) 44 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-6 45 | dout = torch.randn_like(out) 46 | out.backward(dout) 47 | out_ref.backward(dout.float()) 48 | out_pt.backward(dout) 49 | assert (x.grad - x_ref.grad).abs().max() < 2 * (x_pt.grad - x_ref.grad).abs().max() + 1e-6 50 | assert (w.grad - w_ref.grad).abs().max() < 2 * (w_pt.grad - w_ref.grad).abs().max() + 1e-6 51 | if bias is not None: 52 | assert (bias.grad - bias_ref.grad).abs().max() < 2 * ( 53 | bias_pt.grad - bias_ref.grad 54 | ).abs().max() + 1e-6 55 | 56 | 57 | @pytest.mark.parametrize("store_preact", [False, True]) 58 | @pytest.mark.parametrize("activation", ["relu", "relu_sq", "gelu_tanh_approx"]) 59 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 60 | @pytest.mark.parametrize("has_bias", [False, True]) 61 | @pytest.mark.parametrize("out_features", [1504, 2048]) 62 | @pytest.mark.parametrize("in_features", [736, 4096]) 63 | # @pytest.mark.parametrize("out_features", [2048]) 64 | # @pytest.mark.parametrize("in_features", [4096]) 65 | def test_linear_act(in_features, out_features, has_bias, input_dtype, activation, store_preact): 66 | device = "cuda" 67 | torch.random.manual_seed(0) 68 | m = 1920 69 | x = torch.randn((m, in_features), device=device, dtype=input_dtype) 70 | x = x[::2].requires_grad_(True) # Testing non-contiguous 71 | w = ( 72 | torch.randn((out_features, in_features), device=device, dtype=input_dtype) 73 | / math.sqrt(in_features) 74 | ).requires_grad_() 75 | bias = torch.randn(out_features, device=device, requires_grad=True) if has_bias else None 76 | # Disable tuning for faster test 77 | preact, postact = linear_act_func( 78 | x, w, activation, bias=bias, store_preact=store_preact, tuned=False 79 | ) 80 | preact_ref, postact_ref = gemm_act_ref( 81 | x.float(), w.float().T, activation=activation, bias=bias, store_preact=store_preact 82 | ) 83 | preact_pt, postact_pt = gemm_act_ref( 84 | x, 
w.T, activation=activation, bias=bias, store_preact=store_preact 85 | ) 86 | assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-6 87 | if store_preact: 88 | assert preact is not None and preact_ref is not None 89 | assert (preact - preact_ref).abs().max() < 2 * (preact_pt - preact_ref).abs().max() + 1e-6 90 | 91 | 92 | @pytest.mark.parametrize("activation", ["relu", "relu_sq", "gelu_tanh_approx"]) 93 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 94 | @pytest.mark.parametrize("k", [736, 1024]) 95 | @pytest.mark.parametrize("n", [1504, 2048]) 96 | def test_gemm_dact(n, k, input_dtype, activation): 97 | """Test GEMM with activation gradient computation.""" 98 | device = "cuda" 99 | torch.random.manual_seed(0) 100 | m = 960 101 | dout_input = torch.randn((m, k), device=device, dtype=input_dtype) 102 | weight = torch.randn((n, k), device=device, dtype=input_dtype) / math.sqrt(k) 103 | preact = torch.randn((m, n), device=device, dtype=input_dtype, requires_grad=True) 104 | # Disable tuning for faster test 105 | dx, postact = gemm_dact(dout_input, weight.T, preact, activation=activation, tuned=False) 106 | dx_ref, postact_ref = gemm_dact_ref( 107 | dout_input.float(), weight.float().T, preact.float(), activation=activation 108 | ) 109 | dx_pt, postact_pt = gemm_dact_ref(dout_input, weight.T, preact, activation=activation) 110 | assert (dx - dx_ref).abs().max() < 2 * (dx_pt - dx_ref).abs().max() + 1e-5 111 | assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-5 112 | 113 | 114 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 115 | @pytest.mark.parametrize("n", [1504, 2048]) 116 | @pytest.mark.parametrize("k", [736, 1024]) 117 | @pytest.mark.parametrize("m", [960, 1920]) 118 | def test_gemm_add_inplace(m, k, n, input_dtype): 119 | """Test in-place GEMM with addition: C += A @ B.""" 120 | device = "cuda" 121 | torch.random.manual_seed(0) 122 | A = torch.randn((m, k), device=device, dtype=input_dtype) 123 | B = torch.randn((k, n), device=device, dtype=input_dtype) 124 | C = torch.randn((m, n), device=device, dtype=input_dtype) 125 | # Save original C for reference computation 126 | C_og = C.clone() 127 | gemm_add_inplace(A, B, C, tuned=False) 128 | C_ref = C_og.float() + torch.mm(A.float(), B.float()) 129 | C_pt = C_og + torch.mm(A, B) 130 | assert (C - C_ref).abs().max() < 2 * (C_pt - C_ref).abs().max() + 1e-5 131 | 132 | 133 | @pytest.mark.parametrize("alpha_beta_type", ["float", "tensor"]) 134 | @pytest.mark.parametrize("alpha", [0.5, 1.0, 2.0]) 135 | @pytest.mark.parametrize("beta", [0.0, 0.5, 1.0, 1.5]) 136 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 137 | @pytest.mark.parametrize("n", [512, 1024]) 138 | @pytest.mark.parametrize("k", [256, 768]) 139 | @pytest.mark.parametrize("m", [480, 960]) 140 | def test_gemm_add_inplace_alpha_beta(m, k, n, input_dtype, alpha, beta, alpha_beta_type): 141 | """Test in-place GEMM with alpha/beta scaling: C = alpha * A @ B + beta * C.""" 142 | device = "cuda" 143 | torch.random.manual_seed(42) 144 | A = torch.randn((m, k), device=device, dtype=input_dtype) 145 | B = torch.randn((k, n), device=device, dtype=input_dtype) 146 | C = torch.randn((m, n), device=device, dtype=input_dtype) 147 | if alpha_beta_type == "tensor": 148 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32) 149 | beta = torch.tensor(beta, device=device, dtype=torch.float32) 150 | C_og = C.clone() 151 | gemm_add_inplace(A, B, C, alpha=alpha, beta=beta, 
tuned=False) 152 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha 153 | beta_val = beta.item() if torch.is_tensor(beta) else beta 154 | C_ref = alpha_val * torch.mm(A.float(), B.float()) + beta_val * C_og.float() 155 | C_pt = alpha_val * torch.mm(A, B) + beta_val * C_og 156 | assert (C - C_ref).abs().max() < 2 * (C_pt - C_ref).abs().max() + 1e-4 157 | -------------------------------------------------------------------------------- /quack/gemm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from functools import partial 3 | 4 | from torch import Tensor 5 | 6 | import cutlass.cute as cute 7 | import cutlass.torch as cutlass_torch 8 | from cutlass import Float32 9 | from cutlass.cute.runtime import from_dlpack, make_ptr 10 | 11 | from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters 12 | from quack.gemm_wrapper_utils import GemmWrapperBase 13 | from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100 14 | 15 | 16 | def gemm( 17 | # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k 18 | A: Tensor, 19 | B: Tensor, # (l, n, k) or (n, total_k) if varlen_k 20 | D: Tensor, # (l, m, n) or (total_m, n) if varlen_m 21 | C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m 22 | tile_count_semaphore: Optional[Tensor], # (1,) 23 | tile_M: int, 24 | tile_N: int, 25 | cluster_M: int, 26 | cluster_N: int, 27 | pingpong: bool = False, 28 | persistent: bool = True, 29 | max_swizzle_size: int = 8, 30 | rowvec_bias: Optional[Tensor] = None, # (l, n) 31 | colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m 32 | alpha: float | Tensor = 1.0, 33 | beta: float | Tensor = 1.0, 34 | cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length 35 | cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length 36 | A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen 37 | batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler 38 | add_to_output: bool = False, 39 | ) -> None: 40 | varlen = cu_seqlens_m is not None or cu_seqlens_k is not None 41 | assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), ( 42 | "Only one of cu_seqlens_m and cu_seqlens_k can be specified" 43 | ) 44 | gather_A = A_idx is not None 45 | if gather_A: 46 | assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)" 47 | assert cluster_N == 1, "gather_A requires cluster_N=1" 48 | if varlen: 49 | assert persistent, "varlen requires persistent=True" 50 | if add_to_output: 51 | assert cu_seqlens_m is None, "Add to output not supported with varlen_m" 52 | if cu_seqlens_m is not None: 53 | assert A.stride(-1) == 1, "varlen_m requires A to be k-major" 54 | assert D.stride(-1) == 1, "varlen_m requires D to be n-major" 55 | if cu_seqlens_k is not None: 56 | assert A.stride(-2) == 1, "varlen_k requires A to be m-major" 57 | assert B.stride(-2) == 1, "varlen_k requires B to be n-major" 58 | 59 | L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( 60 | A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx 61 | ) 62 | GemmWrapperBase.permute_tensors( 63 | tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None 64 | ) 65 | 
GemmWrapperBase.extract_dtypes(tensor_infos) 66 | major_configs = { 67 | "A": ("m", "k", "l"), 68 | "B": ("n", "k", "l"), 69 | "D": ("m", "n", "l"), 70 | "C": ("m", "n", "l"), 71 | } 72 | GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) 73 | 74 | device_capacity = get_device_capacity(A.device) 75 | assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" 76 | GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90 77 | 78 | acc_dtype = Float32 79 | tile_shape_mn = (tile_M, tile_N) 80 | cluster_shape_mnk = (cluster_M, cluster_N, 1) 81 | if not GemmCls.is_valid_dtypes( 82 | tensor_infos["A"].dtype, 83 | tensor_infos["B"].dtype, 84 | acc_dtype, 85 | tensor_infos["D"].dtype, 86 | tensor_infos["A"].major, 87 | tensor_infos["B"].major, 88 | ): 89 | raise TypeError("Skipping due to unsupported combination of types and majors") 90 | 91 | max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 92 | GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) 93 | 94 | def scalar_arg(scalar: float | Tensor): 95 | if isinstance(scalar, float): 96 | return Float32(scalar) if scalar != 1.0 else None 97 | else: 98 | assert isinstance(scalar, Tensor) 99 | return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) 100 | 101 | epi_args = GemmCls.EpilogueArguments( 102 | scalar_arg(alpha), 103 | scalar_arg(beta), 104 | mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic( 105 | leading_dim=1 106 | ) 107 | if rowvec_bias is not None 108 | else None, 109 | mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( 110 | leading_dim=1 if cu_seqlens_m is None else 0 111 | ) 112 | if colvec_bias is not None 113 | else None, 114 | add_to_output=add_to_output, 115 | ) 116 | scheduler_args = GemmWrapperBase.create_scheduler_args( 117 | max_active_clusters, 118 | tile_count_semaphore, 119 | batch_idx_permute, 120 | max_swizzle_size, 121 | ) 122 | 123 | # Create varlen arguments if needed (assumes persistent=True when varlen) 124 | varlen_args = GemmWrapperBase.create_varlen_args( 125 | cu_seqlens_m, 126 | cu_seqlens_k, 127 | A_idx, 128 | max_active_clusters, 129 | cluster_shape_mnk, 130 | tensor_infos, 131 | GemmCls.num_epi_tensormaps, 132 | pingpong, 133 | ) 134 | 135 | current_stream = cutlass_torch.current_stream() 136 | compile_key = GemmWrapperBase.get_compile_key( 137 | tensor_infos, 138 | None, # activation 139 | tile_shape_mn, 140 | cluster_shape_mnk, 141 | pingpong, 142 | persistent, 143 | tile_count_semaphore is not None, 144 | device_capacity, 145 | # Technically we don't need to recompile for different max_swizzle_size, but currently 146 | # not recompiling will skew the autotuning results due to power throttling. 147 | # Effectively we're recompiling as a way to pause between benchmarks during autotuning. 
148 | max_swizzle_size, 149 | rowvec_bias.dtype if rowvec_bias is not None else None, 150 | colvec_bias.dtype if colvec_bias is not None else None, 151 | 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0), 152 | 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0), 153 | add_to_output, 154 | cu_seqlens_m is not None, 155 | cu_seqlens_k is not None, 156 | gather_A, 157 | batch_idx_permute is not None, 158 | key_tensor_names=("A", "B", "D", "C"), 159 | ) 160 | cache = gemm.compile_cache 161 | if compile_key not in cache: 162 | if device_capacity[0] == 9: 163 | GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) 164 | gemm_obj = GemmCls( 165 | acc_dtype, 166 | tensor_infos["A"].dtype, 167 | tile_shape_mn, 168 | cluster_shape_mnk, 169 | gather_A=gather_A, 170 | ) 171 | cache[compile_key] = cute.compile( 172 | gemm_obj, 173 | tensor_infos["A"].cute_tensor, 174 | tensor_infos["B"].cute_tensor, 175 | tensor_infos["D"].cute_tensor, 176 | tensor_infos["C"].cute_tensor, 177 | epi_args, 178 | scheduler_args, 179 | varlen_args, 180 | current_stream, 181 | ) 182 | cache[compile_key]( 183 | tensor_infos["A"].cute_tensor, 184 | tensor_infos["B"].cute_tensor, 185 | tensor_infos["D"].cute_tensor, 186 | tensor_infos["C"].cute_tensor, 187 | epi_args, 188 | scheduler_args, 189 | varlen_args, 190 | current_stream, 191 | ) 192 | 193 | 194 | gemm.compile_cache = {} 195 | -------------------------------------------------------------------------------- /quack/gemm_dact.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao. 2 | from typing import Optional, Tuple 3 | from functools import partial 4 | 5 | from torch import Tensor 6 | 7 | import cutlass 8 | import cutlass.cute as cute 9 | from cutlass import Float32, const_expr 10 | import cutlass.torch as cutlass_torch 11 | 12 | from quack.gemm_sm90 import GemmSm90 13 | from quack.gemm_sm100 import GemmSm100 14 | from quack.gemm_default_epi import GemmDefaultEpiMixin 15 | from quack.gemm_act import GemmActMixin 16 | from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters 17 | from quack.gemm_wrapper_utils import GemmWrapperBase 18 | import quack.activation 19 | 20 | 21 | class GemmDActMixin(GemmActMixin): 22 | # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout) 23 | # and return 2 arguments (dx, out) 24 | EpilogueArguments = GemmActMixin.EpilogueArguments 25 | EpilogueParams = GemmActMixin.EpilogueParams 26 | 27 | @cute.jit 28 | def epi_visit_subtile( 29 | self, 30 | params: EpilogueParams, 31 | epi_loop_tensors: Tuple[cute.Tensor, ...], 32 | tRS_rD: cute.Tensor, 33 | tRS_rC: Optional[cute.Tensor] = None, 34 | ) -> Optional[cute.Tensor]: 35 | assert tRS_rC is not None 36 | # We don't add C to the accumulator 37 | GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None) 38 | tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype) 39 | tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype)) 40 | # If we don't have .shape here, the compiler generates local stores and loads 41 | if const_expr(params.act_fn is not None): 42 | tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype) 43 | if const_expr(self.arch < 100): 44 | for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): 45 | tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i]) 46 | else: 47 | for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): 48 | ( 49 | (tRS_rD[2 * 
i], tRS_rD[2 * i + 1]), 50 | (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]), 51 | ) = params.act_fn( 52 | (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]), 53 | (tRS_rD[2 * i], tRS_rD[2 * i + 1]), 54 | ) 55 | else: 56 | tRS_rPostAct = tRS_rC_acc 57 | # Type conversion 58 | tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype) 59 | tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) 60 | return tRS_rPostAct_out 61 | 62 | 63 | class GemmDActSm90(GemmDActMixin, GemmSm90): 64 | pass 65 | 66 | 67 | class GemmDActSm100(GemmDActMixin, GemmSm100): 68 | pass 69 | 70 | 71 | dact_fn_map = { 72 | None: None, 73 | "relu": quack.activation.drelu, 74 | "relu_sq": quack.activation.drelu_sq, 75 | "gelu_tanh_approx": quack.activation.dgelu_tanh_approx, 76 | } 77 | 78 | 79 | def gemm_dact( 80 | A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m 81 | B: Tensor, # (l, n, k) 82 | Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m 83 | PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m 84 | PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m 85 | tile_count_semaphore: Optional[Tensor], # (1,) 86 | activation: Optional[str], 87 | tile_M: int, 88 | tile_N: int, 89 | cluster_M: int, 90 | cluster_N: int, 91 | pingpong: bool = True, 92 | persistent: bool = True, 93 | max_swizzle_size: int = 8, 94 | cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length 95 | A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m 96 | ) -> None: 97 | if cu_seqlens_m is not None: 98 | assert persistent, "varlen_m requires persistent=True" 99 | assert A.stride(-1) == 1, "varlen_m requires A to be k-major" 100 | assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major" 101 | assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major" 102 | assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" 103 | gather_A = A_idx is not None 104 | if gather_A: 105 | assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" 106 | assert cluster_N == 1, "gather_A requires cluster_N=1" 107 | assert activation in dact_fn_map, f"Unsupported activation {activation}" 108 | 109 | L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( 110 | A, 111 | B, 112 | Out, 113 | PreAct, 114 | additional_tensors={"PostAct": PostAct}, 115 | cu_seqlens_m=cu_seqlens_m, 116 | A_idx=A_idx, 117 | ) 118 | GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) 119 | GemmWrapperBase.extract_dtypes(tensor_infos) 120 | major_configs = { 121 | "A": ("m", "k", "l"), 122 | "B": ("n", "k", "l"), 123 | "D": ("m", "n", "l"), 124 | "C": ("m", "n", "l"), 125 | "PostAct": ("m", "n", "l"), 126 | } 127 | GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) 128 | 129 | device_capacity = get_device_capacity(A.device) 130 | assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" 131 | GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90 132 | 133 | acc_dtype = Float32 134 | tile_shape_mn = (tile_M, tile_N) 135 | cluster_shape_mnk = (cluster_M, cluster_N, 1) 136 | if not GemmCls.is_valid_dtypes( 137 | tensor_infos["A"].dtype, 138 | tensor_infos["B"].dtype, 139 | acc_dtype, 140 | tensor_infos["D"].dtype, 141 | tensor_infos["A"].major, 142 | tensor_infos["B"].major, 143 | ): 144 | raise TypeError("Skipping due to unsupported combination of types and majors") 145 | 146 | max_active_clusters = 
get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 147 | GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) 148 | act_fn = dact_fn_map[activation] 149 | epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn) 150 | scheduler_args = GemmWrapperBase.create_scheduler_args( 151 | max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size 152 | ) 153 | 154 | # Create varlen arguments if needed (assumes persistent=True when varlen_m) 155 | varlen_args = GemmWrapperBase.create_varlen_args( 156 | cu_seqlens_m, 157 | None, # cu_seqlens_k 158 | A_idx, 159 | max_active_clusters, 160 | cluster_shape_mnk, 161 | tensor_infos, 162 | GemmCls.num_epi_tensormaps, 163 | pingpong, 164 | ) 165 | 166 | current_stream = cutlass_torch.current_stream() 167 | compile_key = GemmWrapperBase.get_compile_key( 168 | tensor_infos, 169 | activation, 170 | tile_shape_mn, 171 | cluster_shape_mnk, 172 | pingpong, 173 | persistent, 174 | tile_count_semaphore is not None, 175 | device_capacity, 176 | max_swizzle_size, 177 | cu_seqlens_m is not None, 178 | A_idx is not None, 179 | key_tensor_names=("A", "B", "D", "PostAct", "C"), 180 | ) 181 | cache = gemm_dact.compile_cache 182 | if compile_key not in cache: 183 | if device_capacity[0] == 9: 184 | GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) 185 | gemm = GemmCls( 186 | acc_dtype, 187 | tensor_infos["A"].dtype, 188 | tile_shape_mn, 189 | cluster_shape_mnk, 190 | gather_A=gather_A, 191 | ) 192 | cache[compile_key] = cute.compile( 193 | gemm, 194 | tensor_infos["A"].cute_tensor, 195 | tensor_infos["B"].cute_tensor, 196 | tensor_infos["D"].cute_tensor, 197 | tensor_infos["C"].cute_tensor, 198 | epi_args, 199 | scheduler_args, 200 | varlen_args, 201 | current_stream, 202 | ) 203 | cache[compile_key]( 204 | tensor_infos["A"].cute_tensor, 205 | tensor_infos["B"].cute_tensor, 206 | tensor_infos["D"].cute_tensor, 207 | tensor_infos["C"].cute_tensor, 208 | epi_args, 209 | scheduler_args, 210 | varlen_args, 211 | current_stream, 212 | ) 213 | 214 | 215 | gemm_dact.compile_cache = {} 216 | -------------------------------------------------------------------------------- /quack/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. 
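# Device-side helpers used throughout this package: packed f32x2 arithmetic with
# round-to-nearest rounding, cluster (DSMEM) remote stores via mapa/st.async,
# approximate sqrt/ceil and byte-permute through inline PTX, k-dimension predication
# and out-of-bounds fill utilities, f32x2 <-> i64 bit casts, a warp-level prefix sum,
# and global-memory atomics.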
2 | 3 | import math 4 | from functools import partial 5 | from typing import Optional, Tuple, Union 6 | 7 | import cutlass 8 | import cutlass.cute as cute 9 | 10 | from cutlass import Float32, Int32, Boolean, const_expr 11 | from cutlass.cutlass_dsl import T, dsl_user_op 12 | from cutlass._mlir.dialects import llvm, nvvm, vector 13 | 14 | 15 | # cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default 16 | fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) 17 | mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) 18 | add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) 19 | sub_packed_f32x2 = partial( 20 | cute.arch.calc_packed_f32x2_op, 21 | src_c=None, 22 | calc_func=nvvm.sub_packed_f32x2, 23 | rnd=nvvm.RoundingModeKind.RN, 24 | ) 25 | 26 | 27 | @dsl_user_op 28 | def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: 29 | return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) 30 | 31 | 32 | @cute.jit 33 | def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32: 34 | if const_expr(isinstance(x, cute.Pointer)): 35 | return Float32(cute.make_tensor(x, cute.make_layout(1))[0]) 36 | else: 37 | assert isinstance(x, Float32) 38 | return x 39 | 40 | 41 | @dsl_user_op 42 | def set_block_rank( 43 | smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None 44 | ) -> Int32: 45 | """Map the given smem pointer to the address at another CTA rank in the cluster.""" 46 | smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() 47 | return Int32( 48 | llvm.inline_asm( 49 | T.i32(), 50 | [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], 51 | "mapa.shared::cluster.u32 $0, $1, $2;", 52 | "=r,r,r", 53 | has_side_effects=False, 54 | is_align_stack=False, 55 | asm_dialect=llvm.AsmDialect.AD_ATT, 56 | ) 57 | ) 58 | 59 | 60 | @dsl_user_op 61 | def store_shared_remote( 62 | val: float | Float32 | Int32 | cutlass.Int64, 63 | smem_ptr: cute.Pointer, 64 | mbar_ptr: cute.Pointer, 65 | peer_cta_rank_in_cluster: cute.typing.Int, 66 | *, 67 | loc=None, 68 | ip=None, 69 | ) -> None: 70 | remote_smem_ptr_i32 = set_block_rank( 71 | smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip 72 | ).ir_value() 73 | remote_mbar_ptr_i32 = set_block_rank( 74 | mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip 75 | ).ir_value() 76 | if const_expr(isinstance(val, float)): 77 | val = Float32(val) 78 | assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" 79 | suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] 80 | constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] 81 | llvm.inline_asm( 82 | None, 83 | [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], 84 | f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];", 85 | f"r,{constraint},r", 86 | has_side_effects=True, 87 | is_align_stack=False, 88 | asm_dialect=llvm.AsmDialect.AD_ATT, 89 | ) 90 | 91 | 92 | @dsl_user_op 93 | def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32: 94 | return Float32( 95 | nvvm.fmin( 96 | T.f32(), 97 | Float32(a).ir_value(loc=loc, ip=ip), 98 | Float32(b).ir_value(loc=loc, ip=ip), 99 | loc=loc, 100 | ip=ip, 101 | ) 102 | ) 103 | 104 | 105 | @dsl_user_op 106 | def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32: 107 | return Float32( 108 | llvm.inline_asm( 109 | T.f32(), 110 | 
[Float32(a).ir_value(loc=loc, ip=ip)], 111 | "sqrt.approx.f32 $0, $1;", 112 | "=f,f", 113 | has_side_effects=False, 114 | is_align_stack=False, 115 | asm_dialect=llvm.AsmDialect.AD_ATT, 116 | ) 117 | ) 118 | 119 | 120 | @dsl_user_op 121 | def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32: 122 | return Int32( 123 | llvm.inline_asm( 124 | T.i32(), 125 | [Float32(a).ir_value(loc=loc, ip=ip)], 126 | "cvt.rpi.ftz.s32.f32 $0, $1;", 127 | "=r,f", 128 | has_side_effects=False, 129 | is_align_stack=False, 130 | asm_dialect=llvm.AsmDialect.AD_ATT, 131 | ) 132 | ) 133 | 134 | 135 | @dsl_user_op 136 | def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32: 137 | return Int32( 138 | llvm.inline_asm( 139 | T.i32(), 140 | [ 141 | Int32(a).ir_value(loc=loc, ip=ip), 142 | Int32(b).ir_value(loc=loc, ip=ip), 143 | Int32(c).ir_value(loc=loc, ip=ip), 144 | ], 145 | "prmt.b32 $0, $1, $2, $3;", 146 | "=r,r,r,r", 147 | has_side_effects=False, 148 | is_align_stack=False, 149 | asm_dialect=llvm.AsmDialect.AD_ATT, 150 | ) 151 | ) 152 | 153 | 154 | @cute.jit 155 | def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: 156 | # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" 157 | tApA = cute.make_fragment( 158 | cute.make_layout( 159 | (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), 160 | stride=(cute.size(tAcA, mode=[2]), 0, 1), 161 | ), 162 | Boolean, 163 | ) 164 | for rest_v in cutlass.range_constexpr(tApA.shape[0]): 165 | for rest_k in cutlass.range_constexpr(tApA.shape[2]): 166 | tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) 167 | return tApA 168 | 169 | 170 | @cute.jit 171 | def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None: 172 | """Fill out-of-bounds values in shared memory tensor. 
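Slices whose predicate is False are overwritten with fill_value; if tXpX is None, every slice of tXsX is filled unconditionally.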
173 | 174 | Args: 175 | tXsX: Shared memory tensor to fill 176 | tXpX: Predicate tensor indicating valid elements 177 | fill_value: Value to fill OOB locations with 178 | """ 179 | tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) 180 | tXrX_fill.fill(fill_value) 181 | for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]): 182 | for rest_k in cutlass.range_constexpr(tXsX.shape[2]): 183 | if const_expr(tXpX is not None): 184 | if not tXpX[rest_v, 0, rest_k]: 185 | cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) 186 | else: 187 | cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) 188 | 189 | 190 | @dsl_user_op 191 | def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: 192 | vec_f32x2 = vector.from_elements( 193 | T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip 194 | ) 195 | vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2) 196 | res = cutlass.Int64( 197 | vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip) 198 | ) 199 | return res 200 | 201 | 202 | @dsl_user_op 203 | def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: 204 | vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip) 205 | vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1) 206 | res0 = Float32( 207 | vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) 208 | ) 209 | res1 = Float32( 210 | vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) 211 | ) 212 | return res0, res1 213 | 214 | 215 | @cute.jit 216 | def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32: 217 | if const_expr(lane is None): 218 | lane = cute.arch.lane_idx() 219 | for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): 220 | offset = 1 << i 221 | # Very important that we set mask_and_clamp to 0 222 | partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) 223 | if lane >= offset: 224 | val += partial_sum 225 | return val 226 | 227 | 228 | @dsl_user_op 229 | def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: 230 | return nvvm.atomicrmw( 231 | res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() 232 | ) 233 | 234 | 235 | @dsl_user_op 236 | def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: 237 | return nvvm.atomicrmw( 238 | res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() 239 | ) 240 | -------------------------------------------------------------------------------- /docs/limitations.rst: -------------------------------------------------------------------------------- 1 | .. _limitations: 2 | 3 | Limitations 4 | ==================== 5 | 6 | 7 | Overview 8 | --------------------- 9 | CuTe DSL is an embedded domain-specific language within Python. It utilizes a subset of Python's 10 | syntax to provide a streamlined programming experience. It is important to understand that CuTe DSL 11 | does NOT implement the complete Python language semantics in its JIT compilation process. 12 | 13 | Programming Model 14 | --------------------- 15 | 16 | **Python Native Data Types** 17 | CuTe DSL supports Python data structures when used for "meta-programming," 18 | but these structures cannot be treated as dynamic values modifiable at runtime. 
19 | For instance, lists and dictionaries can be used to configure kernel parameters 20 | during compilation or serve as containers for dynamic values, 21 | but their structure and organization cannot be altered during kernel execution. 22 | 23 | - **Static Values:** 24 | - Evaluated during JIT compilation phase 25 | - Immutable after compilation completes 26 | - Most Python native types (lists, tuples, dictionaries) are processed as static values 27 | - Primarily utilized for "meta-programming" and configuration purposes 28 | - Example: Lists can contain dynamic values but their structure cannot 29 | be modified during kernel execution 30 | 31 | - **Dynamic Values:** 32 | - Evaluated during runtime execution 33 | - Modifiable during execution of JIT-compiled functions 34 | - Only a specific subset of Python types is supported as dynamic values 35 | - Primitive types are automatically converted when passed as function arguments: 36 | 37 | - ``int`` → ``Int32`` (may be updated to ``Int64`` in future releases) 38 | - ``bool`` → ``Bool`` 39 | - ``float`` → ``Float32`` (may be updated to ``Float64`` in future releases) 40 | 41 | The JIT compiler processes Python native types analogously to C++ template parameters. 42 | The compiled code cannot manipulate dynamic values of composite types 43 | such as lists, tuples, or dictionaries. 44 | 45 | For example, the following code does not work as a traditional Python program inside a JIT function. 46 | 47 | .. code:: python 48 | 49 | @cute.jit 50 | def foo(a: Float32, b: Float32, i: Int32, res: cute.Tensor): 51 | xs = [a, b] 52 | # indexing list with dynamic index is not supported in CuTe DSL: 53 | res[0] = xs[i] 54 | 55 | if i == 0: 56 | # This will always append Float32(3.0) to the list regardless 57 | # of the runtime value of `i` 58 | xs.append(Float32(3.0)) 59 | 60 | for i in range(10): 61 | # This only appends one element to the list at compile time 62 | # because the loop is not unrolled at compile time 63 | xs.append(Float32(1.0)) 64 | 65 | **Python Function** 66 | The DSL currently does not implement support for return values from Python functions, 67 | although this capability is planned for future releases. 68 | 69 | Example: 70 | 71 | .. code:: python 72 | 73 | @cute.jit 74 | def foo(): 75 | return 1 # Currently unsupported in CuTe DSL 76 | 77 | **Expression or Statement with Dependent Type** 78 | CuTe DSL implements static typing and does not support dependent types. 79 | The type of each expression must be determinable during compile time, 80 | in contrast to standard Python which implements dynamic typing. 81 | 82 | Example illustrating functionality in Python that is not supported in the DSL: 83 | 84 | .. code:: python 85 | 86 | # Valid in standard Python, but unsupported in CuTe DSL 87 | max(int(1), float(2.0)) # => 2.0 : float 88 | max(int(3), float(2.0)) # => 3 : int 89 | 90 | In CuTe DSL, types are promoted. For example: 91 | 92 | .. code:: python 93 | 94 | @cute.jit 95 | def foo(a: Int32, b: Float32, res: cute.Tensor): 96 | res[0] = max(a, b) # Type is automatically promoted to Float32 97 | 98 | The following code, which uses an inline if-else expression with dependent types, 99 | is not supported in CuTe DSL: 100 | 101 | ..
code:: python 102 | 103 | @cute.jit 104 | def foo(cond: Boolean, a: Int32, b: Float32, res: cute.Tensor): 105 | res[0] = a if cond else b 106 | 107 | 108 | **Control Flow** 109 | The DSL transforms Python control flow statements (``if``, ``for``, ``while``) 110 | during Abstract Syntax Tree (AST) processing into structured control flow in MLIR, 111 | which has the same constraints as dependent types. For instance, 112 | changing the type of a variable in a loop body is not allowed. 113 | 114 | - Variables must be defined prior to the control flow statement 115 | - Type consistency must be maintained throughout the control flow statement 116 | - Early exit or return from if-else statements is not supported 117 | 118 | Example illustrating functionality in Python that is not supported in the DSL: 119 | 120 | .. code:: python 121 | 122 | @cute.jit 123 | def foo(): 124 | a = Int32(1) 125 | for i in range(10): 126 | a = Float32(2) # Changing the type inside the loop body is not allowed in the DSL 127 | 128 | 129 | **Built-in Operators** 130 | The DSL transforms built-in operators like ``and``, ``or``, ``max``, ``min``, etc. 131 | into MLIR operations. They also follow the same constraints as dependent types. 132 | For instance, ``a and b`` requires ``a`` and ``b`` to be of the same type. 133 | 134 | 135 | **Special Variables** 136 | The DSL treats ``_`` as a special variable whose value is meant to be ignored. 137 | Reading ``_`` is not allowed in the DSL. 138 | 139 | Example illustrating functionality in Python that is not supported in the DSL: 140 | 141 | .. code:: python 142 | 143 | @cute.jit 144 | def foo(): 145 | _ = 1 146 | print(_) # This is not allowed in the DSL 147 | 148 | 149 | **Object Oriented Programming** 150 | The DSL is implemented on top of Python and supports Python's object-oriented programming (OOP) features 151 | for meta-programming at compile time. 152 | 153 | However, similar to other composed data types, the DSL provides limited support for OOP when objects 154 | contain dynamic values. It is strongly recommended to avoid passing dynamic values between member methods 155 | through class state in your code. 156 | 157 | The following example illustrates functionality in Python that is not supported in the DSL 158 | without implementing the ``DynamicExpression`` protocol: 159 | 160 | .. code:: python 161 | 162 | class Foo: 163 | def __init__(self, a: Int32): 164 | self.a = a 165 | 166 | def set_a(self, i: Int32): 167 | self.a = i 168 | 169 | def get_a(self): 170 | return self.a 171 | 172 | @cute.jit 173 | def foo(a: Int32, res: cute.Tensor): 174 | foo = Foo(a) 175 | for i in range(10): 176 | foo.set_a(i) 177 | 178 | # This fails to compile because `a` is assigned a local value defined within the for-loop body 179 | # and is not visible outside of the loop body 180 | res[0] = foo.get_a() 181 | 182 | The example above fails to compile because ``Foo.a`` is assigned a local value defined within the for-loop body, 183 | which is not visible outside the loop body. 184 | 185 | The CuTe DSL implements an internal mechanism that provides limited support for OOP patterns via a protocol. 186 | As the DSL continues to evolve to support additional features, this mechanism is subject to change 187 | and is not recommended for direct use in user code, for the sake of portability. 188 | 189 | 190 | **CuTe Layout algebra in native Python** 191 | The entirety of CuTe Layout algebra operations and APIs requires JIT compilation.
These 192 | functionalities are exclusively available within JIT-compiled functions and cannot be 193 | accessed in standard Python execution environments. 194 | 195 | Additionally, there exists a restricted set of data types that can be passed as arguments 196 | to JIT-compiled functions, which further constrains their usage in native Python contexts. 197 | Only the following CuTe algebra types are supported as JIT function arguments: ``Tensor``, ``Pointer``, 198 | ``Shape``, ``Stride``, ``Coord``, and ``IntTuple``. For ``Stride``, we don't support ``ScaledBasis`` 199 | in a native Python context. Unfortunately, in the first release, we don't support 200 | passing ``Layout`` in a native Python context. 201 | 202 | 203 | Suggestions 204 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 205 | 206 | For reliable and predictable results: 207 | 208 | - Avoid dependent types in your code 209 | - Implement explicit type conversion for dynamic values 210 | - Clearly distinguish between static (compile-time) and dynamic (runtime) values 211 | - Use type annotations as much as possible to help the JIT compiler 212 | identify types and avoid ambiguity 213 | 214 | 215 | .. code:: python 216 | 217 | # Example demonstrating explicit typing 218 | alpha = 1.0 # Explicitly defined as float using `1.0` instead of `1` 219 | # or `float(1)` 220 | beta = 2.0 # Explicitly defined as float 221 | result = max(alpha, beta) # Will correctly perform float comparison 222 | -------------------------------------------------------------------------------- /quack/sort/sorting_networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. 2 | """ 3 | Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html 4 | 5 | This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
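Each entry of `networks` maps a size to a list of levels; each level is a list of (i, j) comparator pairs that are mutually independent and together form one layer of the network.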
6 | """ 7 | 8 | # fmt: off 9 | # ruff: noqa 10 | # isort: skip_file 11 | 12 | import cutlass 13 | import cutlass.cute as cute 14 | 15 | from quack.sort.utils import compare_and_swap 16 | 17 | 18 | networks = { 19 | # Size 2: 1 CEs, depth 1 20 | 2: [[(0, 1)]], 21 | 22 | # Size 4: 5 CEs, depth 3 23 | 4: [ 24 | [(0, 2), (1, 3)], 25 | [(0, 1), (2, 3)], 26 | [(1, 2)] 27 | ], 28 | 29 | # Size 8: 19 CEs, depth 6 30 | 8: [ 31 | [(0, 2), (1, 3), (4, 6), (5, 7)], 32 | [(0, 4), (1, 5), (2, 6), (3, 7)], 33 | [(0, 1), (2, 3), (4, 5), (6, 7)], 34 | [(2, 4), (3, 5)], 35 | [(1, 4), (3, 6)], 36 | [(1, 2), (3, 4), (5, 6)] 37 | ], 38 | 39 | # Size 16: 60 CEs, depth 10 40 | 16: [ 41 | [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)], 42 | [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)], 43 | [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)], 44 | [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)], 45 | [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)], 46 | [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)], 47 | [(2, 4), (3, 6), (9, 12), (11, 13)], 48 | [(3, 5), (6, 8), (7, 9), (10, 12)], 49 | [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)], 50 | [(6, 7), (8, 9)] 51 | ], 52 | 53 | # Size 32: 185 CEs, depth 14 54 | 32: [ 55 | [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)], 56 | [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)], 57 | [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)], 58 | [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)], 59 | [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)], 60 | [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)], 61 | [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)], 62 | [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)], 63 | [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)], 64 | [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)], 65 | [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)], 66 | [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)], 67 | [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)], 68 | [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)] 69 | ], 70 | 71 | # Size 64: 521 CEs, depth 21 72 | 64: [ 73 | [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)], 74 | [(0, 1), (2, 3), 
(4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)], 75 | [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)], 76 | [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)], 77 | [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)], 78 | [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)], 79 | [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)], 80 | [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)], 81 | [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)], 82 | [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)], 83 | [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)], 84 | [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)], 85 | [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)], 86 | [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)], 87 | [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)], 88 | [(3, 16), (7, 20), 
(11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)], 89 | [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)], 90 | [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)], 91 | [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)], 92 | [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)], 93 | [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)] 94 | ], 95 | 96 | } 97 | 98 | 99 | @cute.jit 100 | def optimal_sort( 101 | arr: cute.Tensor, 102 | n: cutlass.Constexpr[int], 103 | start: cutlass.Constexpr[int] = 0, 104 | ascending: cutlass.Constexpr[bool] = True 105 | ) -> None: 106 | """ 107 | Optimal sorting network dispatcher. 108 | 109 | Args: 110 | arr: Array to sort 111 | n: Size of array (must be power of 2 and available in networks) 112 | start: Starting index (default 0) 113 | ascending: Sort in ascending order (default True) 114 | 115 | Source: https://bertdobbelaere.github.io/sorting_networks.html 116 | """ 117 | assert n in networks 118 | for level in networks[n]: 119 | for i, j in level: 120 | compare_and_swap(arr, start + i, start + j, ascending) 121 | -------------------------------------------------------------------------------- /quack/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from torch.amp import custom_fwd, custom_bwd 9 | 10 | 11 | from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact 12 | 13 | 14 | def linear_fwd_convert_type(*tensors): 15 | autocast_dtype = torch.get_autocast_dtype("cuda") 16 | if torch.is_autocast_enabled(): 17 | tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors) 18 | return tensors 19 | 20 | 21 | def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad): 22 | needs_input_grad, needs_weight_grad = needs_x_w_grad 23 | if not needs_input_grad: 24 | weight, weight_og = None, None 25 | if not needs_weight_grad: 26 | x = None 27 | ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None) 28 | 29 | 30 | def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn): 31 | if ctx.needs_input_grad[0]: 32 | assert weight is not None 33 | return matmul_fn(dout, weight) 34 | else: 35 | return None 36 | 37 | 38 | def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn): 39 | if ctx.needs_input_grad[1]: 40 | assert x is not None 41 | x = x.reshape(-1, x.shape[-1]) 42 
| # fuse_grad_accum is not compatible with torch.compile 43 | if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling(): 44 | dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype) 45 | else: 46 | # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape) 47 | matmul_inplace_fn(dout.T, x, weight_og.grad) 48 | dweight = weight_og.grad 49 | weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again 50 | else: 51 | dweight = None 52 | return dweight 53 | 54 | 55 | class LinearFunc(torch.autograd.Function): 56 | matmul_fwd_fn = gemm 57 | matmul_bwd_dx = partial(gemm, dynamic_scheduler=True) 58 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True) 59 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True) 60 | 61 | # Use classmethod instead of staticmethod to allow inheritance 62 | @classmethod 63 | @custom_fwd(device_type="cuda") 64 | def forward(cls, ctx, x, weight, bias=None, fuse_grad_accum=False): 65 | """ 66 | x: (..., in_features) 67 | weight: (out_features, in_features) 68 | bias: (out_features,) or None 69 | out: (..., out_features) 70 | """ 71 | ctx.weight_dtype = weight.dtype 72 | ctx.fuse_grad_accum = fuse_grad_accum 73 | weight_og = weight 74 | x, weight = linear_fwd_convert_type(x, weight) 75 | batch_shape = x.shape[:-1] 76 | x = x.reshape(-1, x.shape[-1]) 77 | # out = F.linear(x, weight) 78 | out = cls.matmul_fwd_fn(x, weight.T, bias=bias) 79 | linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]) 80 | ctx.bias_dtype = bias.dtype if bias is not None else None 81 | return out.reshape(*batch_shape, out.shape[-1]) 82 | 83 | @classmethod 84 | @custom_bwd(device_type="cuda") 85 | def backward(cls, ctx, dout, *args): 86 | """ 87 | dout: (..., out_features) 88 | """ 89 | x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum 90 | batch_shape = dout.shape[:-1] 91 | dout = dout.reshape(-1, dout.shape[-1]) 92 | dbias = ( 93 | dout.sum(0, dtype=ctx.bias_dtype) 94 | if ctx.bias_dtype is not None and ctx.needs_input_grad[2] 95 | else None 96 | ) 97 | dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx) 98 | dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None 99 | dweight = linear_bwd_compute_weight_grad( 100 | ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace 101 | ) 102 | # return extra Nones for other classes that inherit from LinearFunc 103 | return dx, dweight, dbias, *([None] * 10) 104 | 105 | 106 | class LinearUntunedFunc(LinearFunc): 107 | # Passing in tuned=False to disable tuning at runtime 108 | matmul_fwd_fn = partial(gemm, tuned=False) 109 | matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False) 110 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False) 111 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True) 112 | 113 | 114 | def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True): 115 | fn_cls = LinearFunc if tuned else LinearUntunedFunc 116 | return fn_cls.apply(x, weight, bias, fuse_grad_accum) 117 | 118 | 119 | class LinearActFunc(LinearFunc): 120 | matmul_fwd_fn = gemm_act 121 | 122 | # Use classmethod instead of staticmethod to allow inheritance 123 | @classmethod 124 | @custom_fwd(device_type="cuda") 125 | def forward( 126 | cls, ctx, x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False 127 | ): 128 | """ 129 | x: (..., in_features) 130 | weight: 
(out_features, in_features) 131 | bias: (out_features,) or None 132 | out: (..., out_features) 133 | Return both out and post-activation, but only out is differentiable. 134 | """ 135 | ctx.weight_dtype = weight.dtype 136 | ctx.fuse_grad_accum = fuse_grad_accum 137 | weight_og = weight 138 | x, weight = linear_fwd_convert_type(x, weight) 139 | batch_shape = x.shape[:-1] 140 | x = x.reshape(-1, x.shape[-1]) 141 | out, postact = cls.matmul_fwd_fn( 142 | x, weight.T, bias=bias, activation=activation, store_preact=store_preact 143 | ) 144 | linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]) 145 | if out is not None: 146 | out = out.reshape(*batch_shape, out.shape[-1]) 147 | ctx.bias_dtype = bias.dtype if bias is not None else None 148 | ctx.mark_non_differentiable(postact) 149 | ctx.set_materialize_grads(False) # We don't want to materialize grads for postact 150 | return out, postact.reshape(*batch_shape, postact.shape[-1]) 151 | 152 | 153 | class LinearActUntunedFunc(LinearActFunc): 154 | # Passing in tuned=False to disable tuning at runtime 155 | matmul_fwd_fn = partial(gemm_act, tuned=False) 156 | matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False) 157 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False) 158 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True) 159 | 160 | 161 | def linear_act_func( 162 | x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True 163 | ): 164 | fn_cls = LinearActFunc if tuned else LinearActUntunedFunc 165 | return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum) 166 | 167 | 168 | class DActLinearFunc(LinearFunc): 169 | matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True) 170 | 171 | # Use classmethod instead of staticmethod to allow inheritance 172 | @classmethod 173 | @custom_fwd(device_type="cuda") 174 | def forward(cls, ctx, preact, weight, x, activation, fuse_grad_accum=False): 175 | """ 176 | x: (..., in_features) 177 | weight: (out_features, in_features) 178 | out: (..., out_features) 179 | Takes in an extra preact argument which is the pre-activation, to be used in the backward pass. 
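preact (rather than x) is saved for backward; the backward pass recomputes x from preact while computing dpreact via the fused gemm_dact.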
180 | """ 181 | ctx.weight_dtype = weight.dtype 182 | ctx.fuse_grad_accum = fuse_grad_accum 183 | weight_og = weight 184 | x, weight = linear_fwd_convert_type(x, weight) 185 | batch_shape = x.shape[:-1] 186 | x = x.reshape(-1, x.shape[-1]) 187 | out = cls.matmul_fwd_fn(x, weight.T) 188 | # Store preact instead of x, we will recompute x in the backward pass 189 | linear_fwd_postprocess( 190 | ctx, preact, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2] 191 | ) 192 | ctx.activation = activation 193 | return out.reshape(*batch_shape, out.shape[-1]) 194 | 195 | @classmethod 196 | @custom_bwd(device_type="cuda") 197 | def backward(cls, ctx, dout): 198 | """ 199 | dout: (..., out_features) 200 | """ 201 | # weight_og is None if not ctx.fuse_grad_accum 202 | preact, weight, weight_og = ctx.saved_tensors 203 | batch_shape = dout.shape[:-1] 204 | dout = dout.reshape(-1, dout.shape[-1]) 205 | preact = preact.reshape(-1, preact.shape[-1]) 206 | if ctx.needs_input_grad[0]: 207 | assert weight is not None 208 | dpreact, x = cls.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation) 209 | else: 210 | dpreact, x = None, None 211 | dpreact = dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None 212 | dweight = linear_bwd_compute_weight_grad( 213 | ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace 214 | ) 215 | return dpreact, dweight, *([None] * 3) 216 | 217 | 218 | class DActLinearUntunedFunc(DActLinearFunc): 219 | # Passing in tuned=False to disable tuning at runtime 220 | matmul_fwd_fn = partial(gemm, tuned=False) 221 | matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False) 222 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False) 223 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True) 224 | 225 | 226 | def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True): 227 | fn_cls = DActLinearFunc if tuned else DActLinearUntunedFunc 228 | return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum) 229 | 230 | 231 | class Linear(nn.Linear): 232 | def __init__( 233 | self, 234 | in_features: int, 235 | out_features: int, 236 | bias: bool = False, 237 | device=None, 238 | dtype=None, 239 | fuse_grad_accum: bool = False, 240 | ) -> None: 241 | super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) 242 | self.fuse_grad_accum = fuse_grad_accum 243 | 244 | def forward(self, input: Tensor) -> Tensor: 245 | if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0: 246 | return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum) 247 | else: 248 | return F.linear(input, self.weight, self.bias) 249 | -------------------------------------------------------------------------------- /benchmarks/benchmark_rmsnorm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from typing import Type, Optional 4 | 5 | import torch 6 | from triton.testing import do_bench 7 | 8 | import cutlass 9 | import cutlass.torch as cutlass_torch 10 | from cutlass.cute.runtime import from_dlpack 11 | from quack.rmsnorm import rmsnorm_fwd, rmsnorm_ref, rmsnorm, rmsnorm_bwd 12 | import cutlass.cute as cute 13 | 14 | try: 15 | import cudnn 16 | except ImportError: 17 | cudnn = None 18 | 19 | 20 | def run_rmsnorm( 21 | M, 22 | N, 23 | dtype: torch.dtype, 24 | residual_dtype: Optional[torch.dtype] = None, 25 | warmup_iterations=5, 26 | iterations=100, 
27 | ): 28 | if not torch.cuda.is_available(): 29 | raise RuntimeError(f"Ampere GPU is required to run this example!") 30 | 31 | print(f"Tensor dimensions: [{M}, {N}]") 32 | print(f"Input and Output Data type: {dtype}") 33 | 34 | device = "cuda" 35 | x = torch.randn(M, N, device=device, dtype=dtype) 36 | if residual_dtype is not None: 37 | residual = torch.randn(M, N, device=device, dtype=residual_dtype) 38 | else: 39 | residual = None 40 | w = torch.randn(N, device=device, dtype=torch.float32) 41 | 42 | print(f"Input tensor shapes:") 43 | print(f"x: {x.shape}, dtype: {x.dtype}") 44 | print(f"w: {w.shape}, dtype: {w.dtype}") 45 | 46 | eps = 1e-6 47 | 48 | print("Executing kernel...") 49 | rmsnorm_fwd(x, w, residual=residual, eps=eps, store_rstd=True) 50 | 51 | compiled_func_ref = torch.compile(rmsnorm_ref) 52 | 53 | fn = lambda: rmsnorm_fwd(x, w, residual=residual, eps=eps) 54 | time.sleep(0.5) 55 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 56 | mem_bytes = (2 * x.numel() * dtype.itemsize + w.numel() * 4) 57 | if residual is not None: 58 | mem_bytes += 2 * residual.numel() * residual.dtype.itemsize 59 | mem_bw = round(mem_bytes / (avg_time / 1000) / 1e9) 60 | print(f"Kernel execution time: {avg_time:.4f} ms") 61 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 62 | 63 | fn = lambda: compiled_func_ref(x, w, residual=residual, eps=eps) 64 | for _ in range(5): fn() # warm up 65 | time.sleep(0.5) 66 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations) 67 | mem_bytes_ref = mem_bytes 68 | mem_bw_ref = round(mem_bytes_ref / (avg_time / 1000) / 1e9) 69 | print(f"Ref kernel execution time: {avg_time:.4f} ms") 70 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 71 | 72 | if cudnn is not None: 73 | run_cudnn = rmsnorm_cudnn_setup(M, N, dtype) 74 | time.sleep(0.5) 75 | avg_time = do_bench(run_cudnn, warmup=warmup_iterations, rep=iterations) 76 | mem_bytes_cudnn = (2 * x.numel() * dtype.itemsize + w.numel() * 4) 77 | mem_bw_cudnn = round(mem_bytes_cudnn / (avg_time / 1000) / 1e9) 78 | print(f"Cudnn kernel execution time: {avg_time:.4f} ms") 79 | print(f"Cudnn mem throughput: {mem_bw_cudnn:.2f} GB/s") 80 | 81 | return mem_bw, mem_bw_ref 82 | 83 | 84 | def rmsnorm_cudnn_setup(M, N, dtype): 85 | x_gpu = torch.empty(M, N, dtype=dtype, device="cuda") 86 | scale_gpu = torch.empty(1, N, dtype=dtype, device="cuda") 87 | epsilon_cpu = torch.ones((1, 1), dtype=torch.float32, device="cpu") 88 | out_gpu = torch.empty_like(x_gpu) 89 | inv_var_gpu = torch.empty(M, 1, dtype=torch.float32, device="cuda") 90 | handle = cudnn.create_handle() 91 | graph = cudnn.pygraph( 92 | handle=handle, 93 | intermediate_data_type=cudnn.data_type.FLOAT, 94 | compute_data_type=cudnn.data_type.FLOAT, 95 | ) 96 | # create tensor handles with the graph API 97 | x = graph.tensor_like(x_gpu.detach()).set_name("X") 98 | scale = graph.tensor_like(scale_gpu.detach()).set_name("scale") 99 | epsilon = graph.tensor_like(epsilon_cpu).set_name("epsilon") 100 | (out, inv_var) = graph.rmsnorm( 101 | name="rmsnorm", 102 | input=x, 103 | norm_forward_phase=cudnn.norm_forward_phase.TRAINING, 104 | scale=scale, 105 | epsilon=epsilon, 106 | ) 107 | # enable all outputs 108 | out.set_name("output").set_output(True).set_data_type(out_gpu.dtype) 109 | inv_var.set_name("inv_var").set_output(True).set_data_type(inv_var_gpu.dtype) 110 | graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) 111 | # Mapping of (handles -> memory) 112 | variant_pack = { 113 | x: x_gpu.detach(), 114 | scale: scale_gpu.detach(), 115 | 
epsilon: epsilon_cpu, 116 | out: out_gpu, 117 | inv_var: inv_var_gpu, 118 | } 119 | workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) 120 | 121 | def run(*args, **kwargs): 122 | graph.execute(variant_pack, workspace) 123 | return out_gpu, inv_var_gpu 124 | 125 | return run 126 | 127 | 128 | def run_rmsnorm_bwd( 129 | M, 130 | N, 131 | dtype: torch.dtype, 132 | residual_dtype: Optional[torch.dtype] = None, 133 | warmup_iterations=5, 134 | iterations=100, 135 | ): 136 | if not torch.cuda.is_available(): 137 | raise RuntimeError(f"Ampere GPU is required to run this example!") 138 | 139 | print(f"Tensor dimensions: [{M}, {N}]") 140 | print(f"Input and Output Data type: {dtype}") 141 | 142 | device = "cuda" 143 | 144 | # Set up forward pass inputs with gradients enabled 145 | x = torch.randn(M, N, device=device, dtype=dtype, requires_grad=True) 146 | x_ref = x.detach().clone().requires_grad_() 147 | w = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True) 148 | w_ref = w.detach().clone().requires_grad_() 149 | if residual_dtype is not None: 150 | residual = torch.randn(M, N, device=device, dtype=residual_dtype, requires_grad=True) 151 | residual_ref = residual.detach().clone().requires_grad_() 152 | else: 153 | residual, residual_ref = None, None 154 | 155 | print(f"Input tensor shapes:") 156 | print(f"x: {x.shape}, dtype: {x.dtype}") 157 | print(f"w: {w.shape}, dtype: {w.dtype}") 158 | 159 | eps = 1e-6 160 | 161 | # Forward pass to get outputs and rstd 162 | y = rmsnorm(x, w, residual=residual, eps=eps) 163 | if residual is not None: 164 | y, residual_out = y 165 | else: 166 | residual_out = None 167 | 168 | # Create upstream gradients 169 | dy = torch.randn_like(y) 170 | rstd = torch.randn(M, device=device, dtype=torch.float32) 171 | dresidual_out = torch.randn_like(residual_out) if residual is not None else None 172 | 173 | def mem_in_bytes(*args): 174 | return sum(t.numel() * t.dtype.itemsize for t in args if t is not None) 175 | 176 | time.sleep(0.5) 177 | # Benchmark custom backward pass 178 | # fn = lambda: torch.autograd.grad(y, [x, w], grad_outputs=dy, retain_graph=True) 179 | def fn(): 180 | # x.grad = None # Reset gradients to avoid accumulation 181 | # y.backward(dy, retain_graph=True) 182 | rmsnorm_bwd(x if residual is None else residual_out, w, dy, rstd, dresidual_out=dresidual_out) 183 | 184 | avg_time = do_bench(fn, grad_to_none=(x,), warmup=warmup_iterations, rep=iterations) 185 | sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 2 186 | mem_bytes = mem_in_bytes(x if residual is None else residual_out, w, dy, dresidual_out, x, residual if residual is not None and residual.dtype != x.dtype else None) 187 | mem_bw = round(mem_bytes / (avg_time / 1000) / 1e9) 188 | print(f"Kernel execution time: {avg_time:.4f} ms") 189 | print(f"Mem throughput: {mem_bw:.2f} GB/s") 190 | from flash_attn.utils.benchmark import pytorch_profiler 191 | pytorch_profiler(fn) 192 | 193 | # Reference implementation 194 | y_ref = torch.compile(rmsnorm_ref)(x_ref, w_ref, eps=eps) 195 | compiled_func_ref = lambda: torch.autograd.grad(y_ref, [x_ref, w_ref], grad_outputs=dy, retain_graph=True) 196 | # def f(): 197 | # x_ref.grad = None # Reset gradients to avoid accumulation 198 | # w_ref.grad = None 199 | # rmsnorm_ref(x_ref, w_ref, eps=eps).backward(dy) 200 | # compiled_func_ref = torch.compile(f) 201 | 202 | for _ in range(5): compiled_func_ref() # warm up 203 | time.sleep(0.5) 204 | avg_time_ref = do_bench(compiled_func_ref, 
warmup=warmup_iterations, rep=iterations) 205 | mem_bytes_ref = (3 * x.numel() * dtype.itemsize + w.numel() * 4 + x.shape[0] * 4 + sm_count * w.numel() * 4) 206 | mem_bw_ref = round(mem_bytes_ref / (avg_time_ref / 1000) / 1e9) 207 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms") 208 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s") 209 | pytorch_profiler(compiled_func_ref) 210 | 211 | return mem_bw, mem_bw_ref 212 | 213 | 214 | if __name__ == "__main__": 215 | parser = argparse.ArgumentParser( 216 | description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels" 217 | ) 218 | parser.add_argument("--M", default=32768, type=int) 219 | parser.add_argument("--N", default=32768, type=int) 220 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16) 221 | parser.add_argument("--residual_dtype", type=cutlass.dtype, choices=[None, cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=None) 222 | parser.add_argument("--warmup_iterations", default=10, type=int) 223 | parser.add_argument("--iterations", default=100, type=int) 224 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass") 225 | 226 | args = parser.parse_args() 227 | 228 | if args.backward: 229 | print("=== RMSNorm Backward Pass Benchmark ===") 230 | run_rmsnorm_bwd( 231 | args.M, 232 | args.N, 233 | dtype=cutlass.torch.dtype(args.dtype), 234 | residual_dtype=cutlass.torch.dtype(args.residual_dtype) if args.residual_dtype else None, 235 | warmup_iterations=args.warmup_iterations, 236 | iterations=args.iterations, 237 | ) 238 | else: 239 | print("=== RMSNorm Forward Pass Benchmark ===") 240 | run_rmsnorm( 241 | args.M, 242 | args.N, 243 | dtype=cutlass.torch.dtype(args.dtype), 244 | residual_dtype=cutlass.torch.dtype(args.residual_dtype) if args.residual_dtype else None, 245 | warmup_iterations=args.warmup_iterations, 246 | iterations=args.iterations, 247 | ) 248 | # # MN_pairs = [(32768, 256), (32768, 512), (32768, 1024), (32768, 2048), (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768), (32768, 65536), (16384, 131072), (8192, 262144)] 249 | # # MN_pairs = [(32768, 2048)] 250 | # MN_pairs = [(16384, 65536)] 251 | # results = [] 252 | # for M, N in MN_pairs: 253 | # res = run_rmsnorm( 254 | # M, 255 | # N, 256 | # dtype=cutlass.BFloat16, 257 | # skip_ref_check=False, 258 | # benchmark=True, 259 | # warmup_iterations=args.warmup_iterations, 260 | # iterations=args.iterations, 261 | # ) 262 | # results.append(res) 263 | # print(results) 264 | # print("\nPASS") 265 | -------------------------------------------------------------------------------- /quack/linear_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Tri Dao 2 | from typing import Optional, Literal 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from torch.amp import custom_fwd, custom_bwd 9 | 10 | from quack.cross_entropy import cross_entropy, cross_entropy_fwd_out 11 | from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace 12 | from quack.linear import linear_fwd_convert_type 13 | 14 | 15 | def linear_cross_entropy_func( 16 | x: Tensor, # (..., d) 17 | weight: Tensor, # (V, d) 18 | bias: Optional[Tensor], # (V,) or None 19 | target: Tensor, # (...,), int or long 20 | ignore_index: int = -100, 21 | reduction: Literal["none", 
"mean", "sum"] = "mean", 22 | inplace_backward: bool = False, 23 | ) -> Tensor: 24 | y = F.linear(x, weight, bias) # (..., V) 25 | return cross_entropy( 26 | y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward 27 | ) 28 | 29 | 30 | def linear_cross_entropy_func_ref( 31 | x: Tensor, # (..., d) 32 | weight: Tensor, # (V, d) 33 | bias: Optional[Tensor], # (V,) or None 34 | target: Tensor, # (...,), int or long 35 | ignore_index: int = -100, 36 | reduction: Literal["none", "mean", "sum"] = "mean", 37 | ) -> Tensor: 38 | y = F.linear(x, weight, bias) # (..., V) 39 | return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction) 40 | 41 | 42 | def chunked_linear_cross_entropy_fwd( 43 | x: Tensor, # (B*L, d) where B is batch, L is seqlen 44 | weight: Tensor, # (V, d) where V is vocab size 45 | target: Tensor, # (B*L,) 46 | chunk_size: int = 4096, 47 | ignore_index: int = -100, 48 | tuned: bool = True, 49 | ) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: 50 | """ 51 | Chunked forward pass for linear cross entropy. 52 | 53 | Splits input along batch dimension, computes matmul and cross_entropy_fwd 54 | for each chunk, stores dx for each chunk, and accumulates dw. 55 | 56 | Returns: 57 | loss: (B*L,) loss values 58 | dx: (B*L, d) gradient w.r.t. input 59 | dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last) 60 | last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation) 61 | last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation) 62 | """ 63 | B_L, d = x.shape 64 | V, _ = weight.shape 65 | device = x.device 66 | num_chunks = (B_L + chunk_size - 1) // chunk_size 67 | # Since we use gemm with TMA we require some alignment 68 | assert chunk_size % 8 == 0, "chunk_size must be multiple of 8" 69 | assert B_L % 8 == 0 70 | # Pre-allocate outputs 71 | loss = torch.empty(B_L, device=device, dtype=torch.float32) 72 | logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype) 73 | dx = torch.empty_like(x) 74 | # Last chunk of dw will be deferred to the backward pass 75 | dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None 76 | last_dlogits_chunk = None 77 | last_x_chunk = None 78 | 79 | # Process in chunks 80 | for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate( 81 | zip(*(t.split(chunk_size) for t in (x, target, loss, dx))) 82 | ): 83 | chunk_len = x_chunk.shape[0] 84 | logits_chunk = logits_chunk_preallocated[:chunk_len] # (chunk_len, V) 85 | torch.mm(x_chunk, weight.mT, out=logits_chunk) 86 | # Compute cross entropy forward with gradients 87 | dlogits_chunk = logits_chunk # inplace_backward 88 | cross_entropy_fwd_out( 89 | logits_chunk, 90 | target_chunk, 91 | None, # target_logit 92 | loss=loss_chunk, 93 | lse=None, # we don't need lse here 94 | dx=dlogits_chunk, 95 | ignore_index=ignore_index, 96 | ) 97 | # Compute dx for this chunk: dlogits @ weight 98 | torch.mm(dlogits_chunk, weight, out=dx_chunk) # (chunk_len, d) 99 | # Compute dw for all chunks except the last 100 | if i == num_chunks - 1: 101 | # Last chunk: save for backward pass 102 | last_dlogits_chunk = dlogits_chunk 103 | last_x_chunk = x_chunk 104 | elif i == 0: 105 | # First chunk: dw = dlogits.T @ x_chunk 106 | gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned) 107 | else: 108 | # Middle chunks: dw += dlogits.T @ x_chunk 109 | gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned) 110 | return loss, dx, 
dw, last_dlogits_chunk, last_x_chunk 111 | 112 | 113 | class ChunkedLinearCrossEntropyFunction(torch.autograd.Function): 114 | @staticmethod 115 | @custom_fwd(device_type="cuda") 116 | def forward( 117 | ctx, 118 | x: Tensor, 119 | weight: Tensor, 120 | target: Tensor, 121 | ignore_index: int = -100, 122 | reduction: Literal["mean", "sum"] = "mean", 123 | chunk_size: int = 4096, 124 | tuned: bool = True, 125 | ): 126 | """ 127 | Forward pass computes loss and stores dx and dw for backward. 128 | """ 129 | ctx.weight_dtype = weight.dtype 130 | x, weight = linear_fwd_convert_type(x, weight) 131 | batch_shape = x.shape[:-1] 132 | x = x.reshape(-1, x.shape[-1]) 133 | # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training 134 | loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd( 135 | x, weight, target, chunk_size, ignore_index, tuned=tuned 136 | ) 137 | loss_sum = loss.sum() 138 | loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float() 139 | ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale) 140 | ctx.batch_shape = batch_shape 141 | ctx.ignore_index = ignore_index 142 | ctx.reduction = reduction 143 | ctx.tuned = tuned 144 | return loss_sum if loss_scale is None else loss_sum * loss_scale 145 | 146 | @staticmethod 147 | @custom_bwd(device_type="cuda") 148 | def backward(ctx, dloss): 149 | """ 150 | Backward pass scales pre-computed gradients by dloss and completes 151 | the last chunk's dw computation. 152 | dloss is a scalar. 153 | """ 154 | dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors 155 | tuned = ctx.tuned 156 | if loss_scale is not None: 157 | dloss = dloss * loss_scale 158 | # TODO: the case where x or weight doesn't require grad 159 | dx.mul_(dloss) 160 | dx = dx.reshape(*ctx.batch_shape, dx.shape[-1]) 161 | # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk) 162 | if dw is None: 163 | # Only had one chunk, compute dw directly with dloss scaling 164 | dw = gemm( 165 | last_dlogits_chunk.T, 166 | last_x_chunk, 167 | out_dtype=ctx.weight_dtype, 168 | alpha=dloss, 169 | tuned=tuned, 170 | ) 171 | else: 172 | # Add last chunk's contribution with dloss scaling 173 | # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk) 174 | # We use alpha=dloss, beta=dloss 175 | if ctx.weight_dtype == dw.dtype: 176 | gemm_add_inplace( 177 | last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned 178 | ) 179 | else: 180 | dw = gemm_add( 181 | last_dlogits_chunk.T, 182 | last_x_chunk, 183 | dw, 184 | alpha=dloss, 185 | beta=dloss, 186 | out_dtype=ctx.weight_dtype, 187 | tuned=tuned, 188 | ) 189 | return dx, dw, None, None, None, None, None 190 | 191 | 192 | def chunked_linear_cross_entropy( 193 | x: Tensor, 194 | weight: Tensor, 195 | target: Tensor, 196 | chunk_size: int = 4096, 197 | ignore_index: int = -100, 198 | reduction: Literal["mean", "sum"] = "mean", 199 | tuned: bool = True, 200 | ) -> Tensor: 201 | """ 202 | Chunked linear cross entropy with automatic differentiation support. 
203 | 204 | Args: 205 | x: Input tensor of shape (B*L, d) 206 | weight: Weight tensor of shape (V, d) 207 | target: Target indices of shape (B*L,) 208 | chunk_size: Size of chunks to process 209 | ignore_index: Index to ignore in loss computation 210 | reduction: Type of reduction to apply 211 | tuned: Whether to use tuned kernels 212 | 213 | Returns: 214 | Loss tensor with specified reduction 215 | """ 216 | if reduction not in ["mean", "sum"]: 217 | raise ValueError(f"Invalid reduction: {reduction}") 218 | loss = ChunkedLinearCrossEntropyFunction.apply( 219 | x, weight, target, ignore_index, reduction, chunk_size, tuned 220 | ) 221 | return loss 222 | 223 | 224 | class LinearCrossEntropy(nn.Linear): 225 | def __init__( 226 | self, 227 | in_features: int, 228 | out_features: int, 229 | bias: bool = False, 230 | ignore_index: int = -100, 231 | reduction: Literal["none", "mean", "sum"] = "mean", 232 | chunk_size: Optional[int] = None, 233 | inplace_backward: bool = False, 234 | tuned: bool = True, 235 | device=None, 236 | dtype=None, 237 | ) -> None: 238 | super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) 239 | self.ignore_index = ignore_index 240 | self.reduction = reduction 241 | self.chunk_size = chunk_size 242 | self.inplace_backward = inplace_backward 243 | self.tuned = tuned 244 | 245 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 246 | if ( 247 | self.bias is None 248 | and input.is_cuda 249 | and input.stride(-1) == 1 250 | and self.in_features % 8 == 0 251 | and self.out_features % 8 == 0 252 | and input.shape[:-1].numel() % 8 == 0 253 | and self.chunk_size is not None 254 | and self.chunk_size % 8 == 0 255 | and self.reduction in ["mean", "sum"] 256 | ): 257 | return chunked_linear_cross_entropy( 258 | input, 259 | self.weight, 260 | target, 261 | chunk_size=self.chunk_size, 262 | ignore_index=self.ignore_index, 263 | reduction=self.reduction, 264 | tuned=self.tuned, 265 | ) 266 | else: 267 | return linear_cross_entropy_func( 268 | input, 269 | self.weight, 270 | self.bias, 271 | target, 272 | ignore_index=self.ignore_index, 273 | reduction=self.reduction, 274 | inplace_backward=self.inplace_backward, 275 | ) 276 | -------------------------------------------------------------------------------- /tests/test_symmetric_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from quack.gemm_interface import gemm_symmetric as symmetric_dense_gemm 5 | 6 | class TestSymmetricGemm: 7 | """Unit tests for symmetric dense GEMM wrapper.""" 8 | 9 | @pytest.fixture(params=[torch.float16, torch.bfloat16, torch.float32]) 10 | def dtype(self, request): 11 | """Test different data types.""" 12 | return request.param 13 | 14 | @property 15 | def default_shape(self): 16 | """Default shape for most tests (L, M, K).""" 17 | return (2, 1024, 512) 18 | 19 | def torch_reference(self, a, b=None, C=None, alpha=1.0, beta=1.0): 20 | """Reference implementation using PyTorch operations. 
21 | 22 | Args: 23 | a: Input tensor A of shape (L, M, K) 24 | b: Input tensor B of shape (L, M, K) - if None, uses A (symmetric case) 25 | C: Optional additive tensor C of shape (L, M, M) 26 | alpha: Scaling factor for A @ B^T 27 | beta: Scaling factor for C 28 | 29 | Returns: 30 | Result tensor of shape (L, M, M) 31 | """ 32 | if b is None: 33 | b = a 34 | 35 | # Use einsum for batched matrix multiplication: A @ B^T 36 | # a: (L, M, K), b: (L, M, K) -> result: (L, M, M) 37 | result = alpha * torch.einsum("lmk,lnk->lmn", a, b) 38 | 39 | if C is not None: 40 | result = result + beta * C 41 | 42 | return result 43 | 44 | def create_test_tensor(self, L, M, K, dtype, device, stride_pattern="m_major", seed=None): 45 | """Create test tensor with specified stride pattern. 46 | 47 | Args: 48 | L, M, K: Tensor dimensions 49 | dtype: Data type 50 | device: Device ('cuda' or 'cpu') 51 | stride_pattern: How to arrange strides - 'm_major' means M has stride 1, 'k_major' means K has stride 1 52 | seed: Random seed for reproducibility 53 | """ 54 | if stride_pattern == "m_major": 55 | # M has stride 1: (L, M, K) with strides (M*K, 1, M) 56 | tensor = torch.empty_strided((L, M, K), (M * K, 1, M), dtype=dtype, device=device) 57 | elif stride_pattern == "k_major": 58 | # K has stride 1: (L, M, K) with strides (M*K, K, 1) 59 | tensor = torch.empty_strided((L, M, K), (M * K, K, 1), dtype=dtype, device=device) 60 | else: 61 | raise ValueError(f"Unsupported stride pattern: {stride_pattern}") 62 | 63 | # Fill with random data 64 | if seed is not None: 65 | torch.manual_seed(seed) 66 | tensor.uniform_(-2, 2) 67 | return tensor 68 | 69 | def create_symmetric_tensor(self, L, M, dtype, device, seed=None): 70 | """Create a symmetric tensor of shape (L, M, M).""" 71 | if seed is not None: 72 | torch.manual_seed(seed) 73 | 74 | tensor = torch.randn(L, M, M, dtype=dtype, device=device) 75 | 76 | for l in range(L): 77 | matrix = tensor[l, :, :] 78 | tensor[l, :, :] = (matrix + matrix.T) / 2 79 | 80 | return tensor 81 | 82 | def test_basic_symmetric_gemm(self, dtype): 83 | """Test basic symmetric GEMM without bias.""" 84 | if not torch.cuda.is_available(): 85 | pytest.skip("CUDA not available") 86 | 87 | L, M, K = self.default_shape 88 | device = "cuda" 89 | 90 | # Create input tensor A with stride 1 along M dimension 91 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42) 92 | 93 | print(f"a.shape = {a.shape}, a.stride = {a.stride()}") 94 | 95 | # Test symmetric case (B = A.transpose(-2, -1) for symmetric GEMM) 96 | result_quack = symmetric_dense_gemm(a, a.transpose(-2, -1), C=None) 97 | result_torch = self.torch_reference(a, a) 98 | 99 | assert result_quack.shape == result_torch.shape == (L, M, M) 100 | 101 | if dtype == torch.float32: 102 | torch.testing.assert_close(result_quack, result_torch, atol=1e-4, rtol=1e-4) 103 | else: # float16, bfloat16 104 | torch.testing.assert_close(result_quack, result_torch, atol=1e-2, rtol=1e-2) 105 | 106 | def test_symmetric_gemm_with_bias(self, dtype): 107 | """Test symmetric GEMM with bias tensor C.""" 108 | if not torch.cuda.is_available(): 109 | pytest.skip("CUDA not available") 110 | 111 | L, M, K = self.default_shape 112 | device = "cuda" 113 | 114 | # Create input tensors 115 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42) 116 | c = self.create_symmetric_tensor(L, M, dtype, device, seed=123) 117 | 118 | # Compute with our wrapper 119 | result_quack = symmetric_dense_gemm(a, a.transpose(-2, -1), C=c, alpha=1.0, beta=1.0) 120 | 121 
| # Compute reference 122 | result_torch = self.torch_reference(a, a, C=c) 123 | 124 | # Check shapes match 125 | assert result_quack.shape == result_torch.shape == (L, M, M) 126 | 127 | # Check values match 128 | if dtype == torch.float32: 129 | torch.testing.assert_close(result_quack, result_torch, atol=1e-4, rtol=1e-4) 130 | else: 131 | torch.testing.assert_close(result_quack, result_torch, atol=1e-2, rtol=1e-2) 132 | 133 | def test_alpha_beta_scaling(self, dtype): 134 | """Test alpha and beta scaling factors.""" 135 | if not torch.cuda.is_available(): 136 | pytest.skip("CUDA not available") 137 | 138 | L, M, K = self.default_shape 139 | device = "cuda" 140 | alpha, beta = 2.5, 0.5 141 | 142 | # Create input tensors 143 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42) 144 | c = self.create_symmetric_tensor(L, M, dtype, device, seed=123) 145 | 146 | # Compute with our wrapper 147 | result_quack = symmetric_dense_gemm(a, a.transpose(-2, -1), C=c, alpha=alpha, beta=beta) 148 | 149 | # Compute reference 150 | result_torch = self.torch_reference(a, a, C=c, alpha=alpha, beta=beta) 151 | 152 | # Check values match 153 | if dtype == torch.float32: 154 | torch.testing.assert_close(result_quack, result_torch, atol=1e-4, rtol=1e-4) 155 | else: 156 | torch.testing.assert_close(result_quack, result_torch, atol=1e-2, rtol=1e-2) 157 | 158 | def test_symmetry_property(self, dtype): 159 | """Test that output is actually symmetric (D = D^T).""" 160 | if not torch.cuda.is_available(): 161 | pytest.skip("CUDA not available") 162 | 163 | L, M, K = self.default_shape 164 | device = "cuda" 165 | 166 | # Create input tensor 167 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42) 168 | 169 | # Compute symmetric GEMM 170 | result = symmetric_dense_gemm(a, a.transpose(-2, -1), C=None, alpha=1.0, beta=1.0) 171 | 172 | # Check symmetry for each batch 173 | for l in range(L): 174 | matrix = result[l, :, :] 175 | torch.testing.assert_close(matrix, matrix.T, atol=1e-6, rtol=1e-6) 176 | 177 | def test_different_sizes(self): 178 | """Test various matrix sizes to ensure robustness.""" 179 | if not torch.cuda.is_available(): 180 | pytest.skip("CUDA not available") 181 | 182 | device = "cuda" 183 | dtype = torch.float16 184 | 185 | test_sizes = [ 186 | (3, 128, 128), 187 | (5, 256, 256), 188 | (5, 1024, 1024), 189 | (3, 2048, 2048), 190 | (1, 4096, 4096), 191 | ] 192 | 193 | for L, M, K in test_sizes: 194 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42) 195 | 196 | result = symmetric_dense_gemm(a, a.transpose(-2, -1), C=None, alpha=1.0, beta=1.0) 197 | expected = self.torch_reference(a, a) 198 | 199 | assert result.shape == (L, M, M) 200 | torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2) 201 | 202 | # Verify symmetry 203 | for l in range(L): 204 | matrix = result[l, :, :] 205 | torch.testing.assert_close(matrix, matrix.T, atol=1e-6, rtol=1e-6) 206 | 207 | def test_different_stride_patterns(self, dtype): 208 | """Test symmetric GEMM with different stride patterns (m_major vs k_major).""" 209 | if not torch.cuda.is_available(): 210 | pytest.skip("CUDA not available") 211 | 212 | L, M, K = self.default_shape 213 | device = "cuda" 214 | 215 | a_m_major = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42) 216 | 217 | data_contiguous = a_m_major.contiguous() 218 | a_k_major = torch.empty_strided((L, M, K), (M * K, K, 1), dtype=dtype, device=device) 219 | a_k_major.copy_(data_contiguous) 220 | 221 | assert 
torch.equal(a_m_major, a_k_major), "Input tensors should have identical values" 222 | assert a_m_major.stride() != a_k_major.stride(), "Stride patterns should be different" 223 | 224 | result_m_major = symmetric_dense_gemm(a_m_major, a_m_major.transpose(-2, -1), C=None, alpha=1.0, beta=1.0) 225 | result_k_major = symmetric_dense_gemm(a_k_major, a_k_major.transpose(-2, -1), C=None, alpha=1.0, beta=1.0) 226 | 227 | assert result_m_major.shape == result_k_major.shape == (L, M, M) 228 | 229 | if dtype == torch.float32: 230 | torch.testing.assert_close(result_m_major, result_k_major, atol=1e-6, rtol=1e-6) 231 | else: 232 | torch.testing.assert_close(result_m_major, result_k_major, atol=1e-4, rtol=1e-4) 233 | 234 | expected = self.torch_reference(a_m_major, a_m_major) 235 | 236 | if dtype == torch.float32: 237 | torch.testing.assert_close(result_m_major, expected, atol=1e-4, rtol=1e-4) 238 | torch.testing.assert_close(result_k_major, expected, atol=1e-4, rtol=1e-4) 239 | else: 240 | torch.testing.assert_close(result_m_major, expected, atol=1e-2, rtol=1e-2) 241 | torch.testing.assert_close(result_k_major, expected, atol=1e-2, rtol=1e-2) 242 | 243 | 244 | def run_tests(): 245 | """Run all tests manually (for debugging).""" 246 | test_class = TestSymmetricGemm() 247 | 248 | try: 249 | # Test basic functionality 250 | print("Testing basic symmetric GEMM...") 251 | test_class.test_basic_symmetric_gemm(torch.float16) 252 | print("✓ Basic test passed") 253 | 254 | # Test with bias 255 | print("Testing with bias...") 256 | test_class.test_symmetric_gemm_with_bias(torch.float16) 257 | print("✓ Bias test passed") 258 | 259 | # Test scaling 260 | print("Testing alpha/beta scaling...") 261 | test_class.test_alpha_beta_scaling(torch.float16) 262 | print("✓ Scaling test passed") 263 | 264 | # Test symmetry 265 | print("Testing symmetry property...") 266 | test_class.test_symmetry_property(torch.float16) 267 | print("✓ Symmetry test passed") 268 | 269 | # Test different sizes 270 | print("Testing different sizes...") 271 | test_class.test_different_sizes() 272 | print("✓ Different sizes test passed") 273 | 274 | # Test different stride patterns 275 | print("Testing different stride patterns...") 276 | test_class.test_different_stride_patterns(torch.float16) 277 | print("✓ Different stride patterns test passed") 278 | 279 | print("\n🎉 All tests passed!") 280 | 281 | except Exception as e: 282 | print(f"\n❌ Test failed with error: {e}") 283 | import traceback 284 | 285 | traceback.print_exc() 286 | 287 | 288 | if __name__ == "__main__": 289 | run_tests() 290 | -------------------------------------------------------------------------------- /tests/test_linear_varlen_k.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025, Tri Dao. 2 | import math 3 | import pytest 4 | import torch 5 | 6 | from quack.gemm_interface import ( 7 | gemm, 8 | gemm_ref, 9 | gemm_add, 10 | gemm_add_ref, 11 | gemm_add_inplace, 12 | ) 13 | 14 | 15 | def generate_A_with_gather(m, total_k, device, dtype, gather_A=False): 16 | """Generate A matrix and optionally A_idx for gather_A case with varlen_k. 
17 | 18 | Args: 19 | m: Number of rows 20 | total_k: Number of columns needed 21 | device: Device to create tensors on 22 | dtype: Data type of tensors 23 | gather_A: Whether to create gather indices 24 | 25 | Returns: 26 | A: Matrix of shape (m, larger_k) if gather_A else (m, total_k) 27 | A_idx: Index tensor of shape (total_k,) if gather_A else None 28 | """ 29 | if gather_A: 30 | # Create random indices for gathering from a larger A matrix 31 | larger_k = total_k * 2 # Make A larger than needed 32 | A = torch.randn((m, larger_k), device=device, dtype=dtype) 33 | # Make A m-major 34 | A = A.T.contiguous().T 35 | # Create random indices to gather from A 36 | A_idx = torch.randperm(larger_k, device=device, dtype=torch.int32)[:total_k] 37 | else: 38 | A = torch.randn((m, total_k), device=device, dtype=dtype) 39 | # Make A m-major 40 | A = A.T.contiguous().T 41 | A_idx = None 42 | return A, A_idx 43 | 44 | 45 | @pytest.mark.parametrize("permute_batch", [False, True]) 46 | @pytest.mark.parametrize("gather_A", [False, True]) 47 | # @pytest.mark.parametrize("gather_A", [False]) 48 | @pytest.mark.parametrize("dynamic_scheduler", [False, True]) 49 | # @pytest.mark.parametrize("dynamic_scheduler", [False]) 50 | @pytest.mark.parametrize("alpha_is_tensor", [False, True]) 51 | # @pytest.mark.parametrize("alpha_is_tensor", [False]) 52 | @pytest.mark.parametrize("alpha", [1.0, 0.93]) 53 | # @pytest.mark.parametrize("alpha", [1.0]) 54 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 55 | @pytest.mark.parametrize("n", [1024, 1504, 4096]) 56 | @pytest.mark.parametrize("m", [2048, 1064, 8192]) 57 | # @pytest.mark.parametrize("n", [1024]) 58 | # @pytest.mark.parametrize("m", [2048]) 59 | @pytest.mark.parametrize("num_groups", [2, 4]) 60 | # @pytest.mark.parametrize("num_groups", [2]) 61 | def test_gemm_varlen_k( 62 | num_groups, 63 | m, 64 | n, 65 | input_dtype, 66 | alpha, 67 | alpha_is_tensor, 68 | dynamic_scheduler, 69 | gather_A, 70 | permute_batch, 71 | ): 72 | device = "cuda" 73 | torch.random.manual_seed(42) 74 | seq_lens = torch.randint(50, 300, (num_groups,), device="cpu") 75 | total_k = seq_lens.sum().item() 76 | # Create cumulative sequence lengths (num_groups + 1) 77 | cu_seqlens_k = torch.cat( 78 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)] 79 | ) 80 | cu_seqlens_k = cu_seqlens_k.to(device) 81 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A) 82 | avg_k = total_k / num_groups 83 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k) 84 | if alpha_is_tensor: 85 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32) 86 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha 87 | if permute_batch: 88 | batch_idx_permute = torch.randperm(num_groups, device=device).to(torch.int32) 89 | else: 90 | batch_idx_permute = None 91 | out = gemm( 92 | A, 93 | B, 94 | alpha=alpha, 95 | cu_seqlens_k=cu_seqlens_k, 96 | A_idx=A_idx, 97 | batch_idx_permute=batch_idx_permute, 98 | dynamic_scheduler=dynamic_scheduler, 99 | tuned=False, 100 | ) 101 | assert out.shape == (num_groups, m, n) 102 | out_ref = gemm_ref( 103 | A.float(), 104 | B.float(), 105 | alpha=alpha_val, 106 | cu_seqlens_k=cu_seqlens_k, 107 | A_idx=A_idx, 108 | ) 109 | out_pt = gemm_ref(A, B, alpha=alpha_val, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx) 110 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4 111 | 112 | 113 | @pytest.mark.parametrize("gather_A", [False, True]) 114 | # 
@pytest.mark.parametrize("gather_A", [False]) 115 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 116 | @pytest.mark.parametrize("n", [1024]) 117 | @pytest.mark.parametrize("m", [2048]) 118 | def test_gemm_varlen_k_with_zero_lengths( 119 | m, 120 | n, 121 | input_dtype, 122 | gather_A, 123 | ): 124 | device = "cuda" 125 | torch.random.manual_seed(42) 126 | seq_lens = torch.tensor([150, 64, 0, 200, 0], device="cpu", dtype=torch.int32) 127 | num_groups = seq_lens.shape[0] 128 | total_k = seq_lens.sum().item() 129 | cu_seqlens_k = torch.cat( 130 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)] 131 | ) 132 | cu_seqlens_k = cu_seqlens_k.to(device) 133 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A) 134 | avg_k = total_k / num_groups 135 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k) 136 | out = gemm( 137 | A, 138 | B, 139 | cu_seqlens_k=cu_seqlens_k, 140 | A_idx=A_idx, 141 | dynamic_scheduler=False, 142 | tuned=False, 143 | ) 144 | assert out.shape == (num_groups, m, n) 145 | out_ref = gemm_ref( 146 | A.float(), 147 | B.float(), 148 | cu_seqlens_k=cu_seqlens_k, 149 | A_idx=A_idx, 150 | ) 151 | out_pt = gemm_ref(A, B, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx) 152 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4 153 | 154 | 155 | @pytest.mark.parametrize("gather_A", [False, True]) 156 | # @pytest.mark.parametrize("gather_A", [False]) 157 | @pytest.mark.parametrize("dynamic_scheduler", [False, True]) 158 | # @pytest.mark.parametrize("dynamic_scheduler", [False]) 159 | @pytest.mark.parametrize("C_major", ["m", "n"]) 160 | @pytest.mark.parametrize("alpha_is_tensor", [False, True]) 161 | @pytest.mark.parametrize("beta_is_tensor", [False, True]) 162 | @pytest.mark.parametrize("beta", [0.0, 1.17]) 163 | @pytest.mark.parametrize("alpha", [1.0, 0.93]) 164 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 165 | @pytest.mark.parametrize("n", [1024, 1504]) 166 | @pytest.mark.parametrize("m", [2048, 1024]) 167 | @pytest.mark.parametrize("num_groups", [2, 4]) 168 | def test_gemm_add_varlen_k( 169 | num_groups, 170 | m, 171 | n, 172 | input_dtype, 173 | alpha, 174 | beta, 175 | alpha_is_tensor, 176 | beta_is_tensor, 177 | C_major, 178 | dynamic_scheduler, 179 | gather_A, 180 | ): 181 | device = "cuda" 182 | torch.random.manual_seed(42) 183 | seq_lens = torch.randint(50, 300, (num_groups,), device="cpu") 184 | total_k = seq_lens.sum().item() 185 | # Create cumulative sequence lengths (num_groups + 1) 186 | cu_seqlens_k = torch.cat( 187 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)] 188 | ) 189 | cu_seqlens_k = cu_seqlens_k.to(device) 190 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A) 191 | # Make A m-major 192 | A = A.T.contiguous().T 193 | avg_k = total_k / num_groups 194 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k) 195 | C = torch.randn((num_groups, m, n), device=device, dtype=input_dtype) 196 | if C_major == "m": 197 | C = C.permute(0, 2, 1).contiguous().permute(0, 2, 1) 198 | if alpha_is_tensor: 199 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32) 200 | if beta_is_tensor: 201 | beta = torch.tensor(beta, device=device, dtype=torch.float32) 202 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha 203 | beta_val = beta.item() if torch.is_tensor(beta) else beta 204 | out = gemm_add( 205 | A, 206 | B, 207 | C, 208 | alpha=alpha, 209 | 
beta=beta, 210 | cu_seqlens_k=cu_seqlens_k, 211 | A_idx=A_idx, 212 | dynamic_scheduler=dynamic_scheduler, 213 | tuned=False, 214 | ) 215 | assert out.shape == (num_groups, m, n) 216 | out_ref = gemm_add_ref( 217 | A.float(), 218 | B.float(), 219 | C.float(), 220 | alpha=alpha_val, 221 | beta=beta_val, 222 | cu_seqlens_k=cu_seqlens_k, 223 | A_idx=A_idx, 224 | ) 225 | out_pt = gemm_add_ref( 226 | A, B, C, alpha=alpha_val, beta=beta_val, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx 227 | ) 228 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4 229 | 230 | 231 | @pytest.mark.parametrize("gather_A", [False, True]) 232 | # @pytest.mark.parametrize("gather_A", [False]) 233 | @pytest.mark.parametrize("dynamic_scheduler", [False, True]) 234 | # @pytest.mark.parametrize("dynamic_scheduler", [False]) 235 | @pytest.mark.parametrize("alpha_is_tensor", [False, True]) 236 | @pytest.mark.parametrize("beta_is_tensor", [False, True]) 237 | @pytest.mark.parametrize("beta", [0.0, 1.17]) 238 | @pytest.mark.parametrize("alpha", [1.0, 0.93]) 239 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) 240 | @pytest.mark.parametrize("n", [1024, 1504]) 241 | @pytest.mark.parametrize("m", [2048, 1024]) 242 | @pytest.mark.parametrize("num_groups", [2, 4]) 243 | def test_gemm_add_inplace_varlen_k( 244 | num_groups, 245 | m, 246 | n, 247 | input_dtype, 248 | alpha, 249 | beta, 250 | alpha_is_tensor, 251 | beta_is_tensor, 252 | dynamic_scheduler, 253 | gather_A, 254 | ): 255 | device = "cuda" 256 | torch.random.manual_seed(42) 257 | seq_lens = torch.randint(50, 300, (num_groups,), device="cpu") 258 | total_k = seq_lens.sum().item() 259 | # Create cumulative sequence lengths (num_groups + 1) 260 | cu_seqlens_k = torch.cat( 261 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)] 262 | ) 263 | cu_seqlens_k = cu_seqlens_k.to(device) 264 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A) 265 | # Make A m-major 266 | A = A.T.contiguous().T 267 | avg_k = total_k / num_groups 268 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k) 269 | out = torch.randn((num_groups, m, n), device=device, dtype=input_dtype) 270 | if alpha_is_tensor: 271 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32) 272 | if beta_is_tensor: 273 | beta = torch.tensor(beta, device=device, dtype=torch.float32) 274 | # Save original out for reference computation 275 | out_og = out.clone() 276 | gemm_add_inplace( 277 | A, 278 | B, 279 | out, 280 | alpha=alpha, 281 | beta=beta, 282 | cu_seqlens_k=cu_seqlens_k, 283 | A_idx=A_idx, 284 | dynamic_scheduler=dynamic_scheduler, 285 | tuned=False, 286 | ) 287 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha 288 | beta_val = beta.item() if torch.is_tensor(beta) else beta 289 | out_ref = gemm_add_ref( 290 | A.float(), 291 | B.float(), 292 | out_og.float(), 293 | out=None, # Don't use in-place for reference 294 | alpha=alpha_val, 295 | beta=beta_val, 296 | cu_seqlens_k=cu_seqlens_k, 297 | A_idx=A_idx, 298 | ) 299 | out_pt = gemm_add_ref( 300 | A, 301 | B, 302 | out_og, 303 | out=None, 304 | alpha=alpha_val, 305 | beta=beta_val, 306 | cu_seqlens_k=cu_seqlens_k, 307 | A_idx=A_idx, 308 | ) 309 | assert out.shape == (num_groups, m, n), ( 310 | f"Output shape mismatch: {out.shape} vs expected ({num_groups}, {m}, {n})" 311 | ) 312 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4 313 | 
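Usage sketch for the variable-length-K GEMM interface exercised by tests/test_linear_varlen_k.py above. This is a minimal, illustrative call rather than repository code: the `gemm` entry point and its `cu_seqlens_k` / `A_idx` keywords are taken directly from the tests, while the concrete sizes below are assumptions.

# Minimal sketch (assumed sizes); mirrors test_gemm_varlen_k above.
import torch
from quack.gemm_interface import gemm

device = "cuda"
seq_lens = torch.tensor([150, 64, 200], dtype=torch.int32)  # per-group K extents
cu_seqlens_k = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
).to(device)  # (num_groups + 1,) cumulative offsets starting at 0
m, n = 2048, 1024
total_k = int(seq_lens.sum())
A = torch.randn(m, total_k, device=device, dtype=torch.bfloat16)
A = A.T.contiguous().T  # make A m-major, as the tests do
B = torch.randn(total_k, n, device=device, dtype=torch.bfloat16)
out = gemm(A, B, cu_seqlens_k=cu_seqlens_k, A_idx=None, dynamic_scheduler=False, tuned=False)
assert out.shape == (seq_lens.numel(), m, n)  # one (m, n) result per K-group

As in the tests, each group's slice of B (and of A's columns, optionally gathered through A_idx) contributes one batch entry of the output.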
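Similarly, a usage sketch for the chunked linear + cross-entropy path defined earlier in this dump (quack/linear_cross_entropy.py). It is illustrative only: the LinearCrossEntropy module and its chunk_size/reduction arguments come from that file, and the dimensions are assumptions chosen to satisfy the multiple-of-8 alignment checks in LinearCrossEntropy.forward.

# Minimal sketch (assumed sizes) of the chunked linear cross-entropy module.
import torch
from quack.linear_cross_entropy import LinearCrossEntropy

B_L, d, V = 8192, 4096, 32768  # (tokens, hidden size, vocab size), all multiples of 8
lce = LinearCrossEntropy(d, V, bias=False, chunk_size=4096, reduction="mean",
                         device="cuda", dtype=torch.bfloat16)
x = torch.randn(B_L, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (B_L,), device="cuda")
loss = lce(x, target)  # chunked forward: computes the loss and caches dx / partial dw
loss.backward()        # backward rescales the cached grads and adds the deferred last-chunk dw

As the forward/backward docstrings in that file note, the heavy matmuls happen during the chunked forward pass; backward only scales the precomputed dx/dw by dloss and completes the final chunk's dw contribution.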
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------