├── media
│   ├── QuACK.png
│   ├── dsmem.png
│   ├── tv_layout.png
│   ├── liger-16k-65k.png
│   ├── max_reduction.png
│   ├── our-16k-131k.png
│   ├── warp_storage.png
│   ├── block_reduction.png
│   ├── our-16k-131k-sol.png
│   ├── pytorch-16k-131k.png
│   ├── thread_reduction.png
│   ├── warp_reduction.png
│   ├── cluster_reduction.png
│   ├── liger-16k-65k-ncu.png
│   ├── gpu-memory-hierarchy.png
│   ├── memory-access-hierarchy.png
│   ├── productivity-performance.png
│   ├── combined_kernel_benchmarks_final.png
│   └── our-16k-131k-arithmetic-intensity-white.png
├── benchmarks
│   ├── visual_outputs
│   │   ├── quack_vs_pytorch_speedup.png
│   │   ├── rmsnorm_speedup_comparison.png
│   │   ├── quack_vs_torchcompile_speedup.png
│   │   ├── quack_vs_pytorch_backward_speedup.png
│   │   ├── rmsnorm_backward_speedup_comparison.png
│   │   └── quack_vs_torchcompile_backward_speedup.png
│   ├── benchmark_layernorm.py
│   ├── benchmark_softmax.py
│   ├── benchmark_topk.py
│   ├── benchmark_cross_entropy.py
│   └── benchmark_rmsnorm.py
├── quack
│   ├── __init__.py
│   ├── compile_utils.py
│   ├── sort
│   │   ├── utils.py
│   │   ├── bitonic_sort.py
│   │   └── sorting_networks.py
│   ├── broadcast_utils.py
│   ├── sm100_utils.py
│   ├── mlp.py
│   ├── fast_math.py
│   ├── reduction_base.py
│   ├── gemm_config.py
│   ├── cute_dsl_utils.py
│   ├── sm90_utils.py
│   ├── tensormap_manager.py
│   ├── gemm.py
│   ├── gemm_dact.py
│   ├── utils.py
│   ├── linear.py
│   └── linear_cross_entropy.py
├── .pre-commit-config.yaml
├── CLAUDE.md
├── pyproject.toml
├── .github
│   └── workflows
│       └── publish.yaml
├── README.md
├── tests
│   ├── test_linear_cross_entropy.py
│   ├── test_topk.py
│   ├── test_layernorm.py
│   ├── test_softmax.py
│   ├── test_linear.py
│   ├── test_symmetric_gemm.py
│   └── test_linear_varlen_k.py
├── .gitignore
├── docs
│   ├── dsl_control_flow.rst
│   └── limitations.rst
└── LICENSE
/media/QuACK.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/QuACK.png
--------------------------------------------------------------------------------
/media/dsmem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/dsmem.png
--------------------------------------------------------------------------------
/media/tv_layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/tv_layout.png
--------------------------------------------------------------------------------
/media/liger-16k-65k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/liger-16k-65k.png
--------------------------------------------------------------------------------
/media/max_reduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/max_reduction.png
--------------------------------------------------------------------------------
/media/our-16k-131k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/our-16k-131k.png
--------------------------------------------------------------------------------
/media/warp_storage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/warp_storage.png
--------------------------------------------------------------------------------
/media/block_reduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/block_reduction.png
--------------------------------------------------------------------------------
/media/our-16k-131k-sol.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/our-16k-131k-sol.png
--------------------------------------------------------------------------------
/media/pytorch-16k-131k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/pytorch-16k-131k.png
--------------------------------------------------------------------------------
/media/thread_reduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/thread_reduction.png
--------------------------------------------------------------------------------
/media/warp_reduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/warp_reduction.png
--------------------------------------------------------------------------------
/media/cluster_reduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/cluster_reduction.png
--------------------------------------------------------------------------------
/media/liger-16k-65k-ncu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/liger-16k-65k-ncu.png
--------------------------------------------------------------------------------
/media/gpu-memory-hierarchy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/gpu-memory-hierarchy.png
--------------------------------------------------------------------------------
/media/memory-access-hierarchy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/memory-access-hierarchy.png
--------------------------------------------------------------------------------
/media/productivity-performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/productivity-performance.png
--------------------------------------------------------------------------------
/media/combined_kernel_benchmarks_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/combined_kernel_benchmarks_final.png
--------------------------------------------------------------------------------
/media/our-16k-131k-arithmetic-intensity-white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/media/our-16k-131k-arithmetic-intensity-white.png
--------------------------------------------------------------------------------
/benchmarks/visual_outputs/quack_vs_pytorch_speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_pytorch_speedup.png
--------------------------------------------------------------------------------
/benchmarks/visual_outputs/rmsnorm_speedup_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/rmsnorm_speedup_comparison.png
--------------------------------------------------------------------------------
/benchmarks/visual_outputs/quack_vs_torchcompile_speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_torchcompile_speedup.png
--------------------------------------------------------------------------------
/benchmarks/visual_outputs/quack_vs_pytorch_backward_speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_pytorch_backward_speedup.png
--------------------------------------------------------------------------------
/benchmarks/visual_outputs/rmsnorm_backward_speedup_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/rmsnorm_backward_speedup_comparison.png
--------------------------------------------------------------------------------
/benchmarks/visual_outputs/quack_vs_torchcompile_backward_speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dao-AILab/quack/HEAD/benchmarks/visual_outputs/quack_vs_torchcompile_backward_speedup.png
--------------------------------------------------------------------------------
/quack/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.2"
2 |
3 | from quack.rmsnorm import rmsnorm
4 | from quack.softmax import softmax
5 | from quack.cross_entropy import cross_entropy
6 |
7 | __all__ = [
8 | "rmsnorm",
9 | "softmax",
10 | "cross_entropy",
11 | ]
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | rev: v0.11.13
4 | hooks:
5 | - id: ruff-check
6 | args: [--fix, --exit-non-zero-on-fix]
7 | files: ^(quack|tests)/.*\.py$
8 | - id: ruff-format
9 | files: ^(quack|tests)/.*\.py$
10 |
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # Project convention
2 | For code that uses cute-dsl (as part of a function decorated with cute.jit or
3 | cute.kernel), only a subset of Python syntax is supported.
4 | Follow [Control Flow](docs/dsl_control_flow.rst) and [Limitations](docs/limitations.rst).
5 | # Code style
6 | - Favor concise, self-explanatory code
7 | - Avoid unnecessary comments
8 | - Avoid unnecessary line breaks
9 | - Empty lines should not have any space
10 |
--------------------------------------------------------------------------------
/quack/compile_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2 |
3 | from typing import Optional
4 |
5 | import cutlass.cute as cute
6 |
7 |
8 | def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
9 | if leading_dim < 0:
10 | leading_dim = len(shape) + leading_dim
11 | if dtype is None:
12 | return None
13 | stride = tuple(
14 | cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
15 | for i in range(len(shape))
16 | )
17 | return cute.runtime.make_fake_tensor(
18 | dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
19 | )
20 |
--------------------------------------------------------------------------------
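A hedged usage sketch of `make_fake_tensor`; the dtype, shape, and alignment below are made-up values, and only the helper's own signature comes from the file above:

```python
# Hypothetical usage: build fake tensors for ahead-of-time compilation of a kernel
# that should accept any row-major (M, N) bf16 input whose strides are 8-element aligned.
import cutlass
from quack.compile_utils import make_fake_tensor

x_fake = make_fake_tensor(cutlass.BFloat16, (8192, 4096), divisibility=8)  # stride-1 last dim
# dtype=None expresses "this optional operand is absent" and simply yields None.
assert make_fake_tensor(None, (4096,)) is None
```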
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "quack-kernels"
7 | dynamic = ["version"]
8 | requires-python = ">=3.10"
9 | dependencies = [
10 | "nvidia-cutlass-dsl==4.3.3",
11 | "torch",
12 | "apache-tvm-ffi>=0.1.5,<0.2",
13 | "torch-c-dlpack-ext",
14 | ]
15 |
16 | [project.optional-dependencies]
17 | dev = [
18 | "pre-commit",
19 | "ruff",
20 | ]
21 |
22 | [tool.setuptools.packages.find]
23 | exclude = ["tests", "benchmarks"]
24 |
25 | [tool.setuptools.dynamic]
26 | version = {attr = "quack.__version__"}
27 |
28 | [tool.ruff]
29 | line-length = 100
30 |
31 | [tool.ruff.lint]
32 | ignore = [
33 | "E731", # do not assign a lambda expression, use a def
34 | "E741", # Do not use variables named 'I', 'O', or 'l'
35 | "F841", # local variable is assigned to but never used
36 | ]
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Build and Publish to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*'
7 |
8 | jobs:
9 | build-and-publish:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 |
14 | - name: Set up Python
15 | uses: actions/setup-python@v5
16 | with:
17 | python-version: '3.9'
18 |
19 | - name: Install build dependencies
20 | run: |
21 | python -m pip install --upgrade pip
22 | pip install build twine
23 |
24 | - name: Build package
25 | run: python -m build
26 |
27 | - name: Create GitHub Release
28 | uses: softprops/action-gh-release@v2
29 | with:
30 | files: dist/*
31 | generate_release_notes: true
32 |
33 | - name: Publish to PyPI
34 | env:
35 | TWINE_USERNAME: __token__
36 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
37 | run: |
38 | twine upload dist/*
39 |
--------------------------------------------------------------------------------
/quack/sort/utils.py:
--------------------------------------------------------------------------------
1 | import cutlass.cute as cute
2 | from cutlass import Float32, const_expr
3 |
4 | import quack.utils as utils
5 |
6 |
7 | @cute.jit
8 | def compare_and_swap(
9 | arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False
10 | ) -> None:
11 | """Compare and swap elements at indices i and j in ascending or descending order."""
12 | if const_expr(use_selection):
13 | a, b = arr[i], arr[j]
14 | if (a > b) ^ (not ascending):
15 | arr[i] = b
16 | arr[j] = a
17 | # if const_expr(ascending):
18 | # if a > b:
19 | # arr[i] = b
20 | # arr[j] = a
21 | # else:
22 | # if a < b:
23 | # arr[i] = b
24 | # arr[j] = a
25 | else:
26 | min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin
27 | max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax
28 | if const_expr(ascending):
29 | arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j])
30 | else:
31 | arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j])
32 |
--------------------------------------------------------------------------------
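For intuition, the direction handling in `compare_and_swap` can be modeled in ordinary Python, independent of the CuTe DSL types (a sketch of the selection path only):

```python
# Plain-Python model of compare_and_swap's selection path, for intuition only.
def compare_and_swap_ref(arr, i, j, ascending=True):
    a, b = arr[i], arr[j]
    # Swap when the pair is out of order for the requested direction:
    #   ascending  -> swap if a > b
    #   descending -> swap otherwise (the XOR with `not ascending` flips the test)
    if (a > b) ^ (not ascending):
        arr[i], arr[j] = b, a

vals = [3.0, 1.0]
compare_and_swap_ref(vals, 0, 1, ascending=True)
assert vals == [1.0, 3.0]
compare_and_swap_ref(vals, 0, 1, ascending=False)
assert vals == [3.0, 1.0]
```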
/quack/broadcast_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 | from typing import Callable
3 |
4 | import cutlass
5 | import cutlass.cute as cute
6 | from cutlass import Float32, const_expr
7 |
8 | from quack.layout_utils import make_acc_tensor_mn_view
9 |
10 |
11 | @cute.jit
12 | def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
13 | if const_expr(tCrC.element_type != Float32): # Convert to f32
14 | tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
15 | tCrC_f32.store(tCrC.load().to(Float32))
16 | else:
17 | tCrC_f32 = tCrC
18 | # this happens to work for frgA layout too, not just acc layout
19 | tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
20 | if const_expr(is_colvec):
21 | assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
22 | for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
23 | tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
24 | else:
25 | assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
26 | for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
27 | tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
28 | if const_expr(tCrC.element_type != Float32): # Convert back to original dtype
29 | tCrC.store(tCrC_f32.load().to(tCrC.element_type))
30 |
--------------------------------------------------------------------------------
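Semantically, `vec_op` applies `op` between every row (or column) of the accumulator fragment and a per-row (per-column) vector, computing in fp32 and converting back. A dense-tensor equivalent in plain PyTorch, as a sketch for intuition rather than how the kernel executes:

```python
# PyTorch sketch of what vec_op computes over a full (M, N) tile.
import torch

def vec_op_ref(c: torch.Tensor, vec: torch.Tensor, op, is_colvec: bool) -> torch.Tensor:
    c32 = c.float()                       # compute in fp32, as the kernel does
    out = op(c32, vec[:, None]) if is_colvec else op(c32, vec[None, :])
    return out.to(c.dtype)                # convert back to the original dtype

c = torch.randn(4, 8, dtype=torch.bfloat16)
row_scales = torch.randn(4)               # one value per row (is_colvec=True)
print(vec_op_ref(c, row_scales, torch.mul, is_colvec=True).shape)  # torch.Size([4, 8])
```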
/README.md:
--------------------------------------------------------------------------------
1 | # 🦆 QuACK: A Quirky Assortment of CuTe Kernels 🦆
2 |
3 | Kernels are written in the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
4 |
5 | ## Installation
6 |
7 | ``` bash
8 | pip install quack-kernels
9 | ```
10 |
11 | ## Requirements
12 |
13 | - H100 or B200 GPU
14 | - CUDA toolkit 12.9+
15 | - Python 3.12
16 |
17 | ## Kernels 🐥
18 |
19 | - 🦆 RMSNorm forward + backward
20 | - 🦆 Softmax forward + backward
21 | - 🦆 Cross entropy forward + backward
22 | - 🦆 Layernorm forward
23 | - 🦆 Hopper gemm + epilogue
24 | - 🦆 Blackwell gemm + epilogue
25 |
26 | ## Usage
27 |
28 | ```python
29 | from quack import rmsnorm, softmax, cross_entropy
30 | ```
31 |
32 | ## Documentation
33 |
34 | [2025-07-10] We have a comprehensive
35 | [blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
36 | to speed-of-light, right in the comfort of Python thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
37 |
38 | ## Performance
39 |
48 | See our [blogpost](media/2025-07-10-membound-sol.md) for the details.
49 |
50 | ## Development
51 |
52 | To set up the development environment:
53 |
54 | ```bash
55 | pip install -e '.[dev]'
56 | pre-commit install
57 | ```
58 |
--------------------------------------------------------------------------------
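A minimal usage sketch for the kernels exported above; the exact argument names and defaults are assumptions modeled on the benchmark scripts in this repo (e.g. `layernorm(x, w, eps=...)`), not a documented API:

```python
# Hedged usage sketch: signatures are assumed, following the benchmark scripts.
import torch
from quack import rmsnorm, softmax

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.float32)  # fp32 weight, as in benchmark_layernorm.py

y = rmsnorm(x, w, eps=1e-6)   # assumed signature, analogous to layernorm(x, w, eps=...)
p = softmax(x)                # row-wise softmax over the last dimension
```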
/quack/sm100_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 |
3 | from typing import Type, Union
4 |
5 | import cutlass.cute as cute
6 | import cutlass.utils.blackwell_helpers as sm100_utils_og
7 | from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode
8 | from cutlass.cutlass_dsl import Numeric, dsl_user_op
9 |
10 |
11 | @dsl_user_op
12 | def make_smem_layout_cpasync_a(
13 | tiled_mma: cute.TiledMma,
14 | mma_tiler_mnk: cute.Tile,
15 | a_dtype: Type[Numeric],
16 | num_stages: int,
17 | *,
18 | loc=None,
19 | ip=None,
20 | ) -> Union[cute.Layout, cute.ComposedLayout]:
21 | """
22 | :param tiled_mma: The tiled MMA used to partition tensor A
23 | :type tiled_mma: cute.TiledMma
24 | :param mma_tiler_mnk: The MMA tile shape
25 |     :type mma_tiler_mnk: cute.Tile
26 | :param a_dtype: The element type for tensor A
27 | :type a_dtype: Type[Numeric]
28 | :param num_stages: The number of pipeline stages for tensor A
29 | :type num_stages: int
30 |
31 | :return: SMEM layout for tensor A
32 | :rtype: Union[cute.Layout, cute.ComposedLayout]
33 | """
34 |
35 | is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K
36 | a_smem_shape = tiled_mma.partition_shape_A(
37 | cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip)
38 | )
39 | a_smem_shape_mn_k = (
40 | cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1],
41 | cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2],
42 | )
43 | a_smem_layout_atom = sm100_utils_og.make_smem_layout_atom(
44 | sm100_utils_og.get_smem_layout_atom_ab(
45 | tiled_mma.op.a_major_mode,
46 | a_dtype,
47 | a_smem_shape_mn_k,
48 | loc=loc,
49 | ip=ip,
50 | ),
51 | a_dtype,
52 | loc=loc,
53 | ip=ip,
54 | )
55 | a_smem_layout_staged = cute.tile_to_shape(
56 | a_smem_layout_atom,
57 | cute.append(a_smem_shape_mn_k, num_stages, loc=loc, ip=ip),
58 | order=((1, 0, 2) if not is_k_major else (0, 1, 2)),
59 | loc=loc,
60 | ip=ip,
61 | )
62 | return a_smem_layout_staged
63 |
--------------------------------------------------------------------------------
/quack/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 |
7 | from quack.linear import linear_act_func, act_linear_func
8 |
9 |
10 | def mlp_func(x, weight1, weight2, activation: str, fuse_grad_accum=False, tuned=True):
11 | preact, postact = linear_act_func(
12 | x,
13 | weight1,
14 | activation,
15 | store_preact=torch.is_grad_enabled(),
16 | fuse_grad_accum=fuse_grad_accum,
17 | tuned=tuned,
18 | )
19 | out = act_linear_func(
20 | preact,
21 | weight2,
22 | postact,
23 | activation=activation,
24 | fuse_grad_accum=fuse_grad_accum,
25 | tuned=tuned,
26 | )
27 | return out
28 |
29 |
30 | class MLP(nn.Module):
31 | def __init__(
32 | self,
33 | in_features,
34 | hidden_features=None,
35 | out_features=None,
36 | bias1=False,
37 | bias2=False,
38 | activation="gelu",
39 | device=None,
40 | dtype=None,
41 | fuse_grad_accum: bool = False,
42 | tuned: bool = True,
43 | ):
44 | factory_kwargs = {"device": device, "dtype": dtype}
45 | super().__init__()
46 | out_features = out_features if out_features is not None else in_features
47 | hidden_features = hidden_features if hidden_features is not None else 4 * in_features
48 | self.activation = activation
49 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
50 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
51 | self.fuse_grad_accum = fuse_grad_accum
52 | self.tuned = tuned
53 |
54 | def forward(self, input: Tensor) -> Tensor:
55 | if (
56 | self.fc1.bias is None
57 | and self.fc2.bias is None
58 | and input.is_cuda
59 | and input.stride(-1) == 1
60 | and self.fc1.in_features % 8 == 0
61 | and self.fc1.out_features % 8 == 0
62 | and self.fc2.out_features % 8 == 0
63 | ):
64 | return mlp_func(
65 | input,
66 | self.fc1.weight,
67 | self.fc2.weight,
68 | activation=self.activation,
69 | fuse_grad_accum=self.fuse_grad_accum,
70 | tuned=self.tuned,
71 | )
72 | else:
73 | y = self.fc1(input)
74 | return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2])
75 |
--------------------------------------------------------------------------------
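A hedged usage sketch of the `MLP` module above; the shapes and dtype are made up, and it assumes the `"gelu"` activation string (the constructor default) is accepted by the fused path:

```python
# Hypothetical usage of quack.mlp.MLP; falls back to plain PyTorch layers when the
# fused-path conditions (no biases, CUDA input, contiguous last dim, dims % 8 == 0) fail.
import torch
from quack.mlp import MLP

mlp = MLP(4096, activation="gelu", device="cuda", dtype=torch.bfloat16)
x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
y = mlp(x)
print(y.shape)   # torch.Size([8192, 4096])
```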
/quack/fast_math.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 |
3 | from typing import Tuple
4 | from dataclasses import dataclass
5 |
6 | import cutlass
7 | import cutlass.cute as cute
8 | from cutlass import Int32, Uint32
9 | from cutlass.cutlass_dsl import T, dsl_user_op
10 | from cutlass._mlir.dialects import llvm
11 |
12 | from quack.cute_dsl_utils import ParamsBase
13 |
14 |
15 | @cute.jit
16 | def clz(x: Int32) -> Int32:
17 | # for i in cutlass.range_constexpr(32):
18 | # if (1 << (31 - i)) & x:
19 | # return Int32(i)
20 | # return Int32(32)
21 | # Early exit is not supported yet
22 | res = Int32(32)
23 | done = False
24 | for i in cutlass.range(32):
25 | if ((1 << (31 - i)) & x) and not done:
26 | res = Int32(i)
27 | done = True
28 | return res
29 |
30 |
31 | def find_log2(x: Int32) -> Int32:
32 | a: Int32 = Int32(31 - clz(x))
33 | return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2.
34 |
35 |
36 | @dsl_user_op
37 | def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
38 | return Uint32(
39 | llvm.inline_asm(
40 | T.i32(),
41 | [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
42 | "mul.hi.u32 $0, $1, $2;",
43 | "=r,r,r",
44 | has_side_effects=False,
45 | is_align_stack=False,
46 | asm_dialect=llvm.AsmDialect.AD_ATT,
47 | )
48 | )
49 |
50 |
51 | @dataclass
52 | class FastDivmod(ParamsBase):
53 | divisor: Int32
54 | multiplier: Uint32
55 | shift_right: Uint32
56 |
57 | # called by host
58 | @staticmethod
59 | def create(divisor: Int32) -> "FastDivmod":
60 | """Construct the FastDivmod object, in host code.
61 | This precomputes some values based on the divisor and is computationally expensive.
62 | """
63 | p = Uint32(31 + find_log2(divisor))
64 | divisor_u32 = Uint32(divisor)
65 | multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
66 | shift_right = Uint32(p - 32)
67 | return FastDivmod(divisor, multiplier, shift_right)
68 |
69 | @cute.jit
70 | def div(self, dividend: Int32) -> Int32:
71 | return (
72 | Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
73 | if self.divisor != 1
74 | else dividend
75 | )
76 |
77 | def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
78 | quotient = self.div(dividend)
79 | remainder = dividend - quotient * self.divisor
80 | return quotient, remainder
81 |
--------------------------------------------------------------------------------
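The div/divmod identities above can be checked on the host with ordinary Python integers; a sketch of the same arithmetic without the DSL types, where `(dividend * multiplier) >> p` plays the role of `umulhi` followed by the right shift:

```python
# Plain-integer check of the FastDivmod math:
#   p          = 31 + ceil(log2(divisor))
#   multiplier = ceil(2**p / divisor)
#   quotient   = (dividend * multiplier) >> p   (== umulhi(dividend, multiplier) >> (p - 32))
def fast_divmod_ref(dividend: int, divisor: int):
    p = 31 + (divisor - 1).bit_length()          # ceil(log2(divisor)) via bit_length
    multiplier = ((1 << p) + divisor - 1) // divisor
    q = dividend if divisor == 1 else (dividend * multiplier) >> p
    return q, dividend - q * divisor

for dividend in range(0, 100_000, 7):
    for divisor in (1, 3, 5, 96, 128, 1000):
        assert fast_divmod_ref(dividend, divisor) == divmod(dividend, divisor)
```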
/benchmarks/benchmark_layernorm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Type
4 |
5 | import torch
6 | from triton.testing import do_bench
7 |
8 | import cutlass
9 | import cutlass.torch as cutlass_torch
10 | from cutlass.cute.runtime import from_dlpack
11 | from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref
12 | import cutlass.cute as cute
13 |
14 | try:
15 | import cudnn
16 | except ImportError:
17 | cudnn = None
18 |
19 |
20 | def run_layernorm(
21 | M,
22 | N,
23 | dtype: Type[cutlass.Numeric],
24 | warmup_iterations=2,
25 | iterations=200,
26 | ):
27 | if not torch.cuda.is_available():
28 |         raise RuntimeError("A CUDA GPU is required to run this benchmark!")
29 |
30 | print(f"Tensor dimensions: [{M}, {N}]")
31 | print(f"Input and Output Data type: {dtype}")
32 |
33 | torch_dtype = cutlass_torch.dtype(dtype)
34 |
35 | device = "cuda"
36 | x = torch.randn(M, N, device=device, dtype=torch_dtype)
37 | w = torch.randn(N, device=device, dtype=torch.float32)
38 |
39 | print(f"Input tensor shapes:")
40 | print(f"x: {x.shape}, dtype: {x.dtype}")
41 | print(f"w: {w.shape}, dtype: {w.dtype}")
42 |
43 | eps = 1e-6
44 |
45 | print("Executing kernel...")
46 | out, rstd, mean = layernorm(x, w, eps=eps, return_rstd=True, return_mean=True)
47 |
48 | compiled_func_ref = torch.compile(layernorm_ref)
49 |
50 | fn = lambda: layernorm(x, w, eps=eps)
51 | time.sleep(0.5)
52 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
53 | mem_bw = (2 * x.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9
54 | print(f"Kernel execution time: {avg_time:.4f} ms")
55 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
56 |
57 | fn = lambda: compiled_func_ref(x, w, eps=eps)
58 | for _ in range(5):
59 | fn() # warm up
60 | time.sleep(0.5)
61 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
62 | mem_bw_ref = (2 * x.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9
63 | print(f"Ref kernel execution time: {avg_time:.4f} ms")
64 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
65 |
66 | return mem_bw, mem_bw_ref
67 |
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser(
71 |         description="Benchmark the LayerNorm forward kernel against a torch.compile reference"
72 | )
73 | parser.add_argument("--M", default=32768, type=int)
74 | parser.add_argument("--N", default=16384, type=int)
75 | parser.add_argument("--warmup_iterations", default=10, type=int)
76 | parser.add_argument("--iterations", default=100, type=int)
77 |
78 | args = parser.parse_args()
79 |
80 | run_layernorm(
81 | args.M,
82 | args.N,
83 | dtype=cutlass.BFloat16,
84 | warmup_iterations=args.warmup_iterations,
85 | iterations=args.iterations,
86 | )
87 |
--------------------------------------------------------------------------------
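As a quick sanity check of the throughput formula in `run_layernorm` (bytes moved ≈ read `x` + write `out`; the fp32 weight is negligible), here is the arithmetic for the default problem size with a hypothetical runtime:

```python
# Worked example of the bandwidth formula (the 0.70 ms timing is hypothetical).
M, N, bytes_per_elem = 32768, 16384, 2          # default sizes, bf16
bytes_moved = 2 * M * N * bytes_per_elem        # read x + write out ≈ 2.15 GB
avg_time_ms = 0.70                              # hypothetical measurement
mem_bw = bytes_moved / (avg_time_ms / 1000) / 1e9
print(f"{mem_bw:.0f} GB/s")                     # ≈ 3068 GB/s for these numbers
```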
/quack/reduction_base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2 |
3 | from typing import Type, Tuple, Optional
4 |
5 | import cutlass
6 | import cutlass.cute as cute
7 | from cutlass import Int32, Int64, Float32, const_expr
8 |
9 | import quack.copy_utils as copy_utils
10 |
11 |
12 | class ReductionBase:
13 | def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32):
14 | self.dtype = dtype
15 | self.N = N
16 | self.stage = stage
17 | self.reduction_dtype = reduction_dtype
18 |
19 | def _threads_per_row(self):
20 | raise NotImplementedError()
21 |
22 | def _num_threads(self):
23 | return 128 if self.N <= 16384 else 256
24 |
25 | def _set_cluster_n(self):
26 | self.cluster_n = 1
27 |
28 | def _get_tiled_copy(self, vecsize: int = 1):
29 | assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
30 | threads_per_row = self._threads_per_row()
31 | num_threads = self._num_threads()
32 | assert num_threads % cute.arch.WARP_SIZE == 0
33 | num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
34 | tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
35 | tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize)
36 | return tiled_copy, tiler_mn, threads_per_row
37 |
38 | def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
39 | num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
40 | warps_per_row = (
41 | num_warps
42 | if cute.rank(tv_layout.shape[0]) == 1
43 | else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
44 | )
45 | return cute.make_ordered_layout(
46 | (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
47 | order=(1, 0, 2),
48 | )
49 |
50 | def _allocate_reduction_buffer_and_mbar(
51 | self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False
52 | ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
53 | reduction_buffer = smem.allocate_tensor(
54 | self.reduction_dtype,
55 | self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
56 | byte_alignment=8,
57 | )
58 | if const_expr(self.cluster_n > 1):
59 | mbar_ptr = smem.allocate_array(
60 | Int64, num_elems=self.stage if not is_persistent else self.stage * 2
61 | )
62 | else:
63 | mbar_ptr = None
64 | return reduction_buffer, mbar_ptr
65 |
66 | @cute.jit
67 | def _initialize_cluster(
68 | self,
69 | tidx: Int32,
70 | mbar_ptr: cute.Pointer,
71 | num_warps: int,
72 | is_persistent: bool = False,
73 | ):
74 | if const_expr(self.cluster_n > 1):
75 | if tidx < self.stage: # Initialize full barrier
76 | cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
77 | if const_expr(is_persistent): # Initialize empty barrier
78 | cute.arch.mbarrier_init(
79 | mbar_ptr + self.stage + tidx, num_warps * self.cluster_n
80 | )
81 | cute.arch.mbarrier_init_fence()
82 | # Cluster arrive after barrier init
83 | cute.arch.cluster_arrive_relaxed()
84 |
--------------------------------------------------------------------------------
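To make the shape arithmetic in `_get_tiled_copy` concrete, here is the same computation with assumed subclass parameters (the `threads_per_row`, `vecsize`, and `N` values below are hypothetical choices, not values from the repo):

```python
# Worked example of _get_tiled_copy's tiler arithmetic, with assumed parameters.
import math

N, vecsize, cluster_n = 65536, 8, 1
threads_per_row, num_threads = 32, 256          # assumed subclass choices (N > 16384 -> 256 threads)
num_blocks_N = math.ceil(N // vecsize / (threads_per_row * cluster_n))   # 8192 / 32 = 256
tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row)
print(num_blocks_N, tiler_mn)   # 256 (8, 65536): one CTA tile spans 8 rows and the full N columns
```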
/quack/gemm_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 | import itertools
3 | from typing import Optional, List, Literal
4 | from functools import partial
5 | from dataclasses import dataclass
6 |
7 |
8 | @dataclass(frozen=True)
9 | class GemmConfig:
10 | tile_m: int = 128
11 | tile_n: int = 192
12 | pingpong: bool = True
13 | cluster_m: int = 2
14 | cluster_n: int = 1
15 | swap_ab: bool = False
16 | # raster_order: int = 1
17 | max_swizzle_size: int = 8
18 |
19 |
20 | def get_all_configs(
21 | device_capacity: Literal[9, 10] = 9,
22 | epilogue: Optional[str] = None,
23 | tune_coop: bool = True,
24 | # tune_raster_order=True,
25 | ) -> List[GemmConfig]:
26 | assert device_capacity in [9, 10]
27 | if device_capacity == 9:
28 | tile_n_vals = [128, 144, 160, 176, 192, 208]
29 | tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [
30 | (128, 224),
31 | (128, 256),
32 | # (192, 256), # Getting IOT instruction (core dumped) in the bwd
33 | ]
34 | tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)]
35 | if epilogue in ["gated"]:
36 | tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192]
37 | tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0]
38 | elif epilogue in ["lse"]:
39 | tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192]
40 | tile_mn_vals = []
41 | if tune_coop:
42 | tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals]
43 | tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals]
44 | cluster = [(1, 2), (2, 1)]
45 | # cluster = [(1, 1), (1, 2), (2, 1)]
46 | if epilogue in ["lse"]:
47 | cluster = [(1, 2), (2, 1)]
48 | swap_ab_vals = [False, True]
49 | if epilogue in ["lse", "gated"]:
50 | swap_ab_vals = [False]
51 | # raster_swizzle = (
52 | # [(0, 1)]
53 | # if not tune_raster_order
54 | # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
55 | # )
56 | return [
57 | GemmConfig(
58 | tile_m=tile_m,
59 | tile_n=tile_n,
60 | pingpong=pingpong,
61 | cluster_m=cluster_m,
62 | cluster_n=cluster_n,
63 | swap_ab=swap_ab,
64 | # raster_order=raster_order,
65 | # max_swizzle_size=max_swizzle_size,
66 | )
67 | for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product(
68 | tile_mn_vals,
69 | cluster,
70 | swap_ab_vals,
71 | # raster_swizzle,
72 | )
73 | ]
74 | elif device_capacity == 10:
75 | tile_n_vals = [128, 160, 192, 224, 256]
76 | tile_n_64_vals = [128, 192, 256]
77 | tile_mn_cluster_vals = (
78 | [(128, tile_n, (1, 2)) for tile_n in tile_n_vals]
79 | # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals]
80 | + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals]
81 | + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals]
82 | )
83 | swap_ab_vals = [False, True]
84 | if epilogue in ["lse", "gated"]:
85 | swap_ab_vals = [False]
86 | max_swizzle_size_vals = [4, 8, 16]
87 | GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100
88 | return [
89 | GemmConfigCls(
90 | tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms
91 | )
92 | for (m, n, (cm, cn)), sab, ms in itertools.product(
93 | tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals
94 | )
95 | ]
96 |
--------------------------------------------------------------------------------
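A small sketch of how this config sweep might be consumed by an autotuner; only `get_all_configs` and `GemmConfig` come from the file above, and the selection step is a placeholder:

```python
# Hedged sketch: enumerate the SM90 tuning space, then pick the fastest config.
from quack.gemm_config import get_all_configs

configs = get_all_configs(device_capacity=9, epilogue=None, tune_coop=True)
print(f"{len(configs)} candidate configs, e.g. {configs[0]}")

# Hypothetical selection step (benchmark_fn is a placeholder, not an API in this repo):
# best = min(configs, key=benchmark_fn)
```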
/tests/test_linear_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 |
3 | import pytest
4 | import torch
5 |
6 | from quack.linear_cross_entropy import (
7 | chunked_linear_cross_entropy,
8 | linear_cross_entropy_func_ref,
9 | )
10 |
11 |
12 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
13 | @pytest.mark.parametrize("reduction", ["mean", "sum"])
14 | @pytest.mark.parametrize("V", [32000, 50264, 128256])
15 | # @pytest.mark.parametrize("V", [32000])
16 | @pytest.mark.parametrize("d", [768, 1024])
17 | # @pytest.mark.parametrize("d", [768])
18 | @pytest.mark.parametrize("B_L", [8, 16, 24])
19 | @pytest.mark.parametrize("chunk_size", [16])
20 | def test_chunked_linear_cross_entropy(B_L, d, V, chunk_size, reduction, input_dtype):
21 | """Test chunked linear cross entropy against reference implementation."""
22 | device = "cuda"
23 | atol, rtol = 1e-3, 1e-3
24 | torch.random.manual_seed(0)
25 | x = (torch.randn(B_L, d, device=device, dtype=input_dtype) * 0.1).requires_grad_()
26 | weight = (torch.randn(V, d, device=device, dtype=input_dtype) / (d**0.5)).requires_grad_()
27 | target = torch.randint(0, V, (B_L,), device=device, dtype=torch.int64)
28 | x_ref = x.detach().clone().requires_grad_(True)
29 | weight_ref = weight.detach().clone().requires_grad_(True)
30 | x_pt = x.detach().clone().requires_grad_(True)
31 | weight_pt = weight.detach().clone().requires_grad_(True)
32 | loss_ref = linear_cross_entropy_func_ref(
33 | x_ref.float(), weight_ref.float(), None, target, reduction=reduction
34 | )
35 | loss_pt = linear_cross_entropy_func_ref(x_pt, weight_pt, None, target, reduction=reduction)
36 | # Chunked implementation
37 | loss = chunked_linear_cross_entropy(
38 | x, weight, target, chunk_size=chunk_size, reduction=reduction, tuned=False
39 | )
40 | assert (loss - loss_ref).abs().max() < 3 * (loss_pt - loss_ref).abs().max() + 1e-5
41 | loss.backward()
42 | loss_ref.backward()
43 | loss_pt.backward()
44 | assert (x.grad - x_ref.grad).abs().max() < 2 * (x_pt.grad - x_ref.grad).abs().max() + 1e-4
45 | assert (weight.grad - weight_ref.grad).abs().max() < 2 * (
46 | weight_pt.grad - weight_ref.grad
47 | ).abs().max() + 1e-4
48 |
49 |
50 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
51 | @pytest.mark.parametrize("reduction", ["mean", "sum"])
52 | @pytest.mark.parametrize("chunk_size", [256, 1024])
53 | def test_chunked_linear_cross_entropy_ignore_index(input_dtype, reduction, chunk_size):
54 | """Test chunked linear cross entropy with ignore_index."""
55 | device = "cuda"
56 | B_L, d, V = 1024, 512, 2048
57 | ignore_index = V - 1
58 | atol, rtol = 1e-3, 1e-3
59 | torch.random.manual_seed(42)
60 | x = (torch.randn(B_L, d, device=device, dtype=input_dtype) * 0.1).requires_grad_()
61 | weight = (torch.randn(V, d, device=device, dtype=input_dtype) / (d**0.5)).requires_grad_()
62 | target = torch.randint(0, V, (B_L,), device=device, dtype=torch.int64)
63 | x_ref = x.detach().clone().requires_grad_(True)
64 | weight_ref = weight.detach().clone().requires_grad_(True)
65 | x_pt = x.detach().clone().requires_grad_(True)
66 | weight_pt = weight.detach().clone().requires_grad_(True)
67 | # Set some targets to ignore_index
68 | ignore_mask = torch.rand(B_L, device=device) < 0.2
69 | target[ignore_mask] = ignore_index
70 | loss_ref = linear_cross_entropy_func_ref(
71 | x_ref.float(), weight_ref.float(), None, target, reduction=reduction
72 | )
73 | loss_pt = linear_cross_entropy_func_ref(x_pt, weight_pt, None, target, reduction=reduction)
74 | # Chunked implementation
75 | loss = chunked_linear_cross_entropy(
76 | x, weight, target, chunk_size=chunk_size, reduction=reduction, tuned=False
77 | )
78 | assert (loss - loss_ref).abs().max() < 3 * (loss_pt - loss_ref).abs().max() + 1e-5
79 | loss.backward()
80 | loss_ref.backward()
81 | loss_pt.backward()
82 | assert (x.grad - x_ref.grad).abs().max() < 2 * (x_pt.grad - x_ref.grad).abs().max() + 1e-4
83 | assert (weight.grad - weight_ref.grad).abs().max() < 2 * (
84 | weight_pt.grad - weight_ref.grad
85 | ).abs().max() + 1e-4
86 |
--------------------------------------------------------------------------------
/quack/cute_dsl_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 |
3 | from typing import Tuple
4 | from functools import lru_cache
5 | from dataclasses import dataclass, fields
6 |
7 | import torch
8 |
9 | try:
10 | from triton.tools.disasm import extract
11 | except ImportError:
12 | extract = None
13 |
14 | import cutlass
15 | import cutlass.cute as cute
16 | from cutlass import Int32, Int64, Float16, BFloat16, Float32
17 | from cutlass.base_dsl.typing import JitArgument
18 | from cutlass.cutlass_dsl import NumericMeta
19 |
20 |
21 | StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))
22 |
23 |
24 | load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
25 | cute_compile_og = cute.compile
26 |
27 |
28 | torch2cute_dtype_map = {
29 | torch.float16: Float16,
30 | torch.bfloat16: BFloat16,
31 | torch.float32: Float32,
32 | torch.int32: Int32,
33 | torch.int64: Int64,
34 | }
35 |
36 |
37 | @lru_cache
38 | def get_max_active_clusters(cluster_size):
39 | return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)
40 |
41 |
42 | @lru_cache
43 | def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
44 | return torch.cuda.get_device_capability(device)
45 |
46 |
47 | @dataclass
48 | class ParamsBase:
49 | def __extract_mlir_values__(self):
50 | all_fields = [getattr(self, field.name) for field in fields(self)]
51 | non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
52 | values, self._values_pos = [], []
53 | for obj in non_constexpr_fields:
54 | obj_values = cutlass.extract_mlir_values(obj)
55 | values += obj_values
56 | self._values_pos.append(len(obj_values))
57 | return values
58 |
59 | def __new_from_mlir_values__(self, values):
60 | all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
61 | constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
62 | non_constexpr_fields = {
63 | n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
64 | }
65 | for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
66 | non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
67 | values = values[n_items:]
68 | return self.__class__(**non_constexpr_fields, **constexpr_fields)
69 |
70 |
71 | @dataclass
72 | class ArgumentsBase(JitArgument):
73 | def __c_pointers__(self):
74 | all_fields = [getattr(self, field.name) for field in fields(self)]
75 | non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
76 | c_ptrs = []
77 | for obj in non_constexpr_fields:
78 | if hasattr(obj, "__c_pointers__"):
79 | c_ptrs.extend(obj.__c_pointers__())
80 | return c_ptrs
81 |
82 | def __get_mlir_types__(self):
83 | all_fields = [getattr(self, field.name) for field in fields(self)]
84 | non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]
85 | types, self._values_pos = [], []
86 | for obj in non_constexpr_fields:
87 | if hasattr(obj, "__get_mlir_types__"):
88 | obj_types = obj.__get_mlir_types__()
89 | types.extend(obj_types)
90 | self._values_pos.append(len(obj_types))
91 | else:
92 | self._values_pos.append(0)
93 | return types
94 |
95 | def __new_from_mlir_values__(self, values):
96 | all_fields = {field.name: getattr(self, field.name) for field in fields(self)}
97 | constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}
98 | non_constexpr_fields = {
99 | n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)
100 | }
101 | for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):
102 | non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])
103 | values = values[n_items:]
104 | return self.__class__(**non_constexpr_fields, **constexpr_fields)
105 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .claude/
3 |
4 | # Created by https://www.toptal.com/developers/gitignore/api/python
5 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
6 |
7 | ### Python ###
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/#use-with-ide
117 | .pdm.toml
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
169 | ### Python Patch ###
170 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
171 | poetry.toml
172 |
173 | # ruff
174 | .ruff_cache/
175 |
176 | # LSP config files
177 | pyrightconfig.json
178 |
179 | # End of https://www.toptal.com/developers/gitignore/api/python
180 |
--------------------------------------------------------------------------------
/quack/sm90_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 |
3 | from typing import Type, Union, Optional
4 |
5 | import cutlass
6 | import cutlass.cute as cute
7 | import cutlass.utils.hopper_helpers as sm90_utils_og
8 | from cutlass.cute.nvgpu import warpgroup
9 | from cutlass.cutlass_dsl import Numeric, dsl_user_op
10 | from cutlass import Float32, Int32, Boolean, const_expr
11 | from cutlass.utils import LayoutEnum
12 |
13 |
14 | @dsl_user_op
15 | def make_smem_layout(
16 | dtype: Type[Numeric],
17 | layout: LayoutEnum,
18 | tile: cute.Tile,
19 | stage: Optional[int] = None,
20 | *,
21 | loc=None,
22 | ip=None,
23 | ) -> Union[cute.Layout, cute.ComposedLayout]:
24 | shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip)
25 | major_mode_size = shape[1] if layout.is_n_major_c() else shape[0]
26 | smem_layout_atom = warpgroup.make_smem_layout_atom(
27 | sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
28 | dtype,
29 | )
30 | smem_layout_staged = cute.tile_to_shape(
31 | smem_layout_atom,
32 | cute.append(shape, stage) if const_expr(stage is not None) else shape,
33 | order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2),
34 | )
35 | return smem_layout_staged
36 |
37 |
38 | # For compatibility with blackwell_helpers.py
39 | make_smem_layout_epi = make_smem_layout
40 |
41 |
42 | @dsl_user_op
43 | def partition_for_epilogue(
44 | cT: cute.Tensor,
45 | epi_tile: cute.Tile,
46 | tiled_copy: cute.TiledCopy,
47 | tidx: Int32,
48 | reference_src: bool, # do register tensors reference the src or dst layout of the tiled copy
49 | *,
50 | loc=None,
51 | ip=None,
52 | ) -> cute.Tensor:
53 | thr_copy = tiled_copy.get_slice(tidx)
54 | cT_epi = cute.flat_divide(cT, epi_tile)
55 | # (CPY, CPY_M, CPY_N, EPI_M, EPI_N)
56 | if const_expr(reference_src):
57 | return thr_copy.partition_S(cT_epi, loc=loc, ip=ip)
58 | else:
59 | return thr_copy.partition_D(cT_epi, loc=loc, ip=ip)
60 |
61 |
62 | @cute.jit
63 | def gemm(
64 | tiled_mma: cute.TiledMma,
65 | acc: cute.Tensor,
66 | tCrA: cute.Tensor,
67 | tCrB: cute.Tensor,
68 | zero_init: cutlass.Constexpr[bool] = False,
69 | wg_wait: cutlass.Constexpr[int] = 0,
70 | # A_in_regs: cutlass.Constexpr[bool] = False,
71 | swap_AB: cutlass.Constexpr[bool] = False,
72 | ) -> None:
73 | if const_expr(swap_AB):
74 | gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False)
75 | else:
76 | warpgroup.fence()
77 | # We make a new mma_atom since we'll be modifying its attribute (accumulate).
78 | # Otherwise the compiler complains "operand #0 does not dominate this use"
79 | mma_atom = cute.make_mma_atom(tiled_mma.op)
80 | mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init)
81 | for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
82 | cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
83 | mma_atom.set(warpgroup.Field.ACCUMULATE, True)
84 | warpgroup.commit_group()
85 | if const_expr(wg_wait >= 0):
86 | warpgroup.wait_group(wg_wait)
87 |
88 |
89 | def gemm_zero_init(
90 | tiled_mma: cute.TiledMma,
91 | shape: cute.Shape,
92 | tCrA: cute.Tensor,
93 | tCrB: cute.Tensor,
94 | A_idx: Optional[Int32] = None,
95 | B_idx: Optional[Int32] = None,
96 | wg_wait: int = -1,
97 | swap_AB: bool = False,
98 | ) -> cute.Tensor:
99 | if const_expr(swap_AB):
100 | return gemm_zero_init(
101 | tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
102 | )
103 | else:
104 | acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
105 | rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
106 | rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
107 | gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
108 | return acc
109 |
110 |
111 | def gemm_w_idx(
112 | tiled_mma: cute.TiledMma,
113 | acc: cute.Tensor,
114 | tCrA: cute.Tensor,
115 | tCrB: cute.Tensor,
116 | zero_init: Boolean,
117 | A_idx: Optional[Int32] = None,
118 | B_idx: Optional[Int32] = None,
119 | wg_wait: int = -1,
120 | swap_AB: bool = False,
121 | ) -> None:
122 | if const_expr(swap_AB):
123 | gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False)
124 | else:
125 | rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
126 | rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
127 | gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
128 |
--------------------------------------------------------------------------------
/quack/sort/bitonic_sort.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2 |
3 | import math
4 | from typing import Optional
5 |
6 | import cutlass
7 | import cutlass.cute as cute
8 | from cutlass import Int32, Float32, const_expr
9 |
10 | import quack.utils as utils
11 | from quack.sort.utils import compare_and_swap
12 | from quack.sort.sorting_networks import optimal_sort
13 |
14 |
15 | @cute.jit
16 | def bitonic_merge(
17 | arr: cute.Tensor,
18 | n: Optional[cutlass.Constexpr[int]] = None,
19 | start: cutlass.Constexpr[int] = 0,
20 | ascending: cutlass.Constexpr[bool] = True,
21 | ) -> None:
22 | """Merge a bitonic sequence into a sorted sequence using iterative approach."""
23 | if const_expr(n is None):
24 | n = cute.size(arr.shape)
25 | if const_expr(n > 1):
26 | num_levels = int(math.log2(n))
27 | assert n == 2**num_levels, "n must be a power of 2"
28 | # This one must be range_constexpr otherwise it's very slow for n = 128
29 | for level in cutlass.range_constexpr(num_levels):
30 | length = n >> level # n // (2^level)
31 | step = length // 2
32 | for i in cutlass.range(n // length, unroll_full=True):
33 | start_i = start + i * length
34 | for j in cutlass.range(step, unroll_full=True):
35 | compare_and_swap(arr, start_i + j, start_i + j + step, ascending)
36 |
37 |
38 | @cute.jit
39 | def bitonic_sort(
40 | arr: cute.Tensor,
41 | n: Optional[cutlass.Constexpr[int]] = None,
42 | start: cutlass.Constexpr[int] = 0,
43 | ascending: cutlass.Constexpr[bool] = True,
44 | ) -> None:
45 | """
46 | Bitonic sort for small arrays of size N (power of 2, N <= 128).
47 |
48 | Args:
49 | arr: Array to sort
50 | n: Size of array (must be power of 2 and <= 128)
51 | start: Starting index (default 0)
52 | ascending: Sort in ascending order (default True)
53 | """
54 | if const_expr(n is None):
55 | n = cute.size(arr.shape)
56 | assert n <= 128
57 | if const_expr(n > 1):
58 | if const_expr(n in [2, 4, 8, 16, 32, 64]):
59 | optimal_sort(arr, n, start, ascending)
60 | else: # Fall back to bitonic sort
61 | assert n % 2 == 0
62 | # Sort first half in ascending order
63 | bitonic_sort(arr, n // 2, start, True)
64 | # Sort second half in descending order
65 | bitonic_sort(arr, n // 2, start + n // 2, False)
66 | # Merge the whole sequence
67 | bitonic_merge(arr, n, start, ascending)
68 |
69 |
70 | @cute.jit
71 | def bitonic_topk_merge(
72 | arr0: cute.Tensor,
73 | arr1: cute.Tensor,
74 | k: Optional[cutlass.Constexpr[int]] = None,
75 | start0: cutlass.Constexpr[int] = 0,
76 | start1: cutlass.Constexpr[int] = 0,
77 | ascending: cutlass.Constexpr[bool] = False,
78 | ) -> None:
79 | if const_expr(k is None):
80 | k = cute.size(arr0.shape)
81 | if const_expr(arr0.element_type == Float32):
82 | minmax_fn = utils.fmin if ascending else cute.arch.fmax
83 | else:
84 | minmax_fn = min if ascending else max
85 | # Write the top k elements to the first half of the array
86 |     for i in cutlass.range(k, unroll_full=True):
87 | arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
88 | # Now the 1st half is bitonic, we just need to merge it
89 | bitonic_merge(arr0, k, start0, ascending)
90 |
91 |
92 | @cute.jit
93 | def bitonic_topk(
94 | arr: cute.Tensor,
95 | k: cutlass.Constexpr[int],
96 | ascending: cutlass.Constexpr[bool] = False,
97 | warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
98 | ) -> cute.Tensor:
99 | """
100 | Bitonic top-k for small arrays of size N (power of 2, N <= 128).
101 |
102 | Args:
103 | arr: Array to sort
104 | k: must be power of 2 and <= 128
105 | ascending: Sort in ascending order (default False)
106 | """
107 | assert arr.element_type in [Float32, Int32]
108 | n = cute.size(arr.shape)
109 | assert k == 1 << int(math.log2(k)), "k must be a power of 2"
110 | assert n % k == 0, "n must be divisible by k"
111 | topk_vals = cute.make_fragment(k, arr.element_type)
112 | for v in cutlass.range(k, unroll_full=True):
113 | topk_vals[v] = arr[v]
114 | bitonic_sort(topk_vals, ascending=ascending)
115 | for i in cutlass.range(1, n // k, unroll_full=True):
116 | other_vals = cute.make_fragment(k, arr.element_type)
117 | for v in cutlass.range(k, unroll_full=True):
118 | other_vals[v] = arr[i * k + v]
119 | bitonic_sort(other_vals, ascending=ascending)
120 | # Merge 2 sorted top-k sequences to get a new top-k sequence
121 | bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
122 | # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
123 | # do duplicate work.
124 | for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
125 | other_vals = cute.make_fragment(k, arr.element_type)
126 | for v in cutlass.range(k, unroll_full=True):
127 | other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
128 | bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
129 | return topk_vals
130 |
--------------------------------------------------------------------------------
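The key step in `bitonic_topk_merge` — an elementwise max of one sorted list against the other read backwards, which leaves a bitonic sequence containing the combined top-k — can be checked with a small plain-Python model (a sketch independent of the DSL; a plain sort stands in for the final bitonic merge):

```python
# Plain-Python model of the top-k merge trick used in bitonic_topk_merge.
def topk_merge_ref(a, b, k):
    """a and b are each sorted descending with k candidates; return the combined top-k."""
    # For every i, max(a[i], b[k-1-i]) keeps whichever candidate can still make the top-k;
    # the resulting sequence is bitonic, so one bitonic merge (here: a sort) finishes the job.
    merged = [max(a[i], b[k - 1 - i]) for i in range(k)]
    return sorted(merged, reverse=True)

a, b = [9, 7, 4, 1], [8, 6, 5, 2]
assert topk_merge_ref(a, b, 4) == sorted(a + b, reverse=True)[:4]   # [9, 8, 7, 6]
```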
/tests/test_topk.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2 |
3 | import pytest
4 | import torch
5 |
6 |
7 | from quack.topk import topk
8 |
9 | torch._dynamo.config.cache_size_limit = 1024
10 | torch._dynamo.config.accumulated_cache_size_limit = 1024
11 |
12 |
13 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
14 | # @pytest.mark.parametrize("input_dtype", [torch.float32])
15 | @pytest.mark.parametrize(
16 | "N, k",
17 | [(64, 16), (128, 32), (256, 16), (512, 32), (1024, 32), (4096, 32), (4096, 64), (4096, 128)],
18 | # [(64, 16)],
19 | )
20 | @pytest.mark.parametrize("M", [1, 37, 199])
21 | # @pytest.mark.parametrize("M", [1])
22 | @pytest.mark.parametrize("softmax", [False, True])
23 | # @pytest.mark.parametrize("softmax", [False])
24 | @pytest.mark.parametrize("function", [topk, torch.compile(topk, fullgraph=True)])
25 | # @pytest.mark.parametrize("function", [torch.compile(topk, fullgraph=True)])
26 | def test_topk(M, N, k, input_dtype, softmax, function):
27 | """Test TopK forward/backward against PyTorch reference implementation."""
28 | device = "cuda"
29 | # Set tolerance based on dtype
30 | if input_dtype == torch.bfloat16:
31 | atol = 1e-2
32 | rtol = 1e-2
33 | elif input_dtype == torch.float16:
34 | atol = 1e-3
35 | rtol = 1e-3
36 | else:
37 | atol = 1e-3
38 | rtol = 5e-4
39 |
40 | torch.random.manual_seed(0)
41 | # Create input tensors
42 | x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
43 | out_val, out_idx = function(x, k, softmax=softmax)
44 | out_val_ref, out_idx_ref = torch.topk(x.detach(), k, dim=-1, largest=True, sorted=True)
45 | if softmax:
46 | out_val_ref = torch.softmax(out_val_ref.float(), dim=-1).to(input_dtype)
47 |
48 | # Check output shape and dtype
49 | assert out_val.shape == (M, k)
50 | assert out_val.dtype == input_dtype
51 | # Check accuracy - values should match the reference
52 | torch.testing.assert_close(out_val, out_val_ref, atol=atol, rtol=rtol)
53 |
54 | if not softmax:
55 | # 1. Values should be in descending order
56 | assert torch.all(out_val[:, :-1] >= out_val[:, 1:]), "Some rows not in descending order"
57 | # 2. Values indexed at output indices should match output values
58 | indexed_vals = torch.gather(x, 1, out_idx.long())
59 | torch.testing.assert_close(indexed_vals, out_val, atol=atol, rtol=rtol)
60 |
61 | # Backward check
62 | dvalues = torch.randn_like(out_val)
63 | out_val.backward(dvalues)
64 | dx_ref = torch.zeros_like(x)
65 | dx_ref.scatter_(1, out_idx.long(), dvalues)
66 | torch.testing.assert_close(x.grad, dx_ref, atol=1e-3, rtol=1e-3)
67 | else:
68 | # For softmax case, check that probabilities sum to 1
69 | torch.testing.assert_close(
70 | out_val.float().sum(dim=-1),
71 | torch.ones(M, device=device, dtype=torch.float32),
72 | atol=1e-2,
73 | rtol=1e-2,
74 | msg="Softmax probabilities don't sum to 1",
75 | )
76 | dvalues = torch.randn_like(out_val)
77 | out_val.backward(dvalues)
78 | dot = (dvalues.float() * out_val.float()).sum(dim=1, keepdim=True)
79 | grad_topk = out_val.float() * (dvalues.float() - dot)
80 | grad_topk = grad_topk.to(input_dtype)
81 | dx_ref = torch.zeros_like(x)
82 | dx_ref.scatter_(1, out_idx.long(), grad_topk)
83 | torch.testing.assert_close(x.grad, dx_ref, atol=1e-3, rtol=1e-3)
84 |
85 |
86 | # @pytest.mark.parametrize("input_dtype", [torch.float16, torch.float32])
87 | # def test_topk_extreme_values(input_dtype):
88 | # """Test TopK with extreme input values."""
89 | # device = "cuda"
90 | # M, N, k = 16, 64, 16
91 |
92 | # # Test with identical values
93 | # x_uniform = torch.full((M, N), 1.0, device=device, dtype=input_dtype)
94 | # out_uniform = topk(x_uniform, k)
95 | # # All output values should be 1.0
96 | # expected = torch.full((M, k), 1.0, device=device, dtype=input_dtype)
97 | # torch.testing.assert_close(out_uniform, expected, atol=1e-3, rtol=1e-3)
98 |
99 | # # Test with large range of values
100 | # x_range = torch.arange(N, dtype=input_dtype, device=device).unsqueeze(0).expand(M, -1)
101 | # out_range = topk(x_range, k)
102 | # # Should get the largest k values in descending order
103 | # expected_range = torch.arange(N-1, N-k-1, -1, dtype=input_dtype, device=device).unsqueeze(0).expand(M, -1)
104 | # torch.testing.assert_close(out_range, expected_range, atol=1e-6, rtol=1e-6)
105 |
106 |
107 | # def test_topk_edge_cases():
108 | # """Test TopK edge cases."""
109 | # device = "cuda"
110 |
111 | # # Test k=1 (single maximum)
112 | # M, N = 8, 64
113 | # x = torch.randn(M, N, device=device, dtype=torch.float32)
114 | # out_val = topk(x, 1)
115 | # out_val_ref = torch.max(x, dim=-1, keepdim=True)[0]
116 | # torch.testing.assert_close(out_val, out_val_ref, atol=1e-6, rtol=1e-6)
117 |
118 | # # Test with negative values
119 | # x_neg = torch.randn(M, N, device=device, dtype=torch.float32) - 10.0
120 | # out_neg = topk(x_neg, 8)
121 | # out_ref_neg, _ = torch.topk(x_neg, 8, dim=-1, largest=True, sorted=True)
122 | # torch.testing.assert_close(out_neg, out_ref_neg, atol=1e-6, rtol=1e-6)
123 |
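The backward checks above rebuild the expected gradient inline; the same math written as a standalone helper, purely for readability (editor's sketch with a hypothetical name, not an API of this repo):

    import torch

    def topk_backward_ref(dvalues, out_val, out_idx, N, softmax=False):
        # Route the upstream gradient back to the selected positions; with softmax,
        # first apply the softmax Jacobian-vector product g = y * (dy - sum(dy * y)).
        if softmax:
            dot = (dvalues.float() * out_val.float()).sum(dim=1, keepdim=True)
            dvalues = (out_val.float() * (dvalues.float() - dot)).to(out_val.dtype)
        dx = torch.zeros(out_val.shape[0], N, device=out_val.device, dtype=out_val.dtype)
        return dx.scatter_(1, out_idx.long(), dvalues)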
--------------------------------------------------------------------------------
/tests/test_layernorm.py:
--------------------------------------------------------------------------------
1 | # tests/test_layernorm.py
2 |
3 | import pytest
4 | import torch
5 |
6 | from quack.rmsnorm import layernorm_fwd, layernorm_ref, layernorm_rstd_ref, layernorm_mean_ref
7 |
8 |
9 | @pytest.mark.parametrize("eps", [1e-5, 1e-6])
10 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
11 | @pytest.mark.parametrize("M", [1, 37, 199])
12 | @pytest.mark.parametrize(
13 | "N", [256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]
14 | ) # , 32768])
15 | def test_layernorm_forward(M, N, input_dtype, eps):
16 | """Test LayerNorm forward pass against reference implementation."""
17 | device = "cuda"
18 |
19 | # tolerance depends on precision
20 | if input_dtype == torch.bfloat16:
21 | atol = 1e-2
22 | rtol = 1e-2
23 | elif input_dtype == torch.float16:
24 | atol = 1e-3
25 | rtol = 1e-3
26 | else:
27 | atol = 1e-4
28 | rtol = 1e-4
29 |
30 | torch.random.manual_seed(0)
31 | x = torch.randn(M, N, device=device, dtype=input_dtype, requires_grad=True)
32 | weight = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
33 |
34 | # pure‐PyTorch refs
35 | x_ref = x.detach().clone().requires_grad_()
36 | weight_ref = weight.detach().clone().requires_grad_()
37 |
38 | out, rstd, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=True)
39 | out_ref = layernorm_ref(x_ref, weight_ref, eps=eps)
40 | rstd_ref_val = layernorm_rstd_ref(x_ref, eps=eps)
41 | mean_ref_val = layernorm_mean_ref(x_ref)
42 |
43 | # shapes & dtypes
44 | assert out.shape == x.shape
45 | assert out.dtype == input_dtype
46 | assert rstd.shape == (M,) and rstd.dtype == torch.float32
47 | assert mean.shape == (M,) and mean.dtype == torch.float32
48 |
49 | # numeric check
50 | torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
51 | torch.testing.assert_close(rstd, rstd_ref_val, atol=6e-4, rtol=6e-4)
52 | torch.testing.assert_close(mean, mean_ref_val, atol=6e-4, rtol=6e-4)
53 |
54 |
55 | @pytest.mark.parametrize("return_rstd", [True, False])
56 | @pytest.mark.parametrize("return_mean", [True, False])
57 | def test_layernorm_return_rstd_option(return_rstd, return_mean):
58 | """Test that the return_rstd and return_mean options work correctly."""
59 | device = "cuda"
60 | M, N = 32, 1024
61 | eps = 1e-6
62 |
63 | x = torch.randn(M, N, device=device, dtype=torch.float16)
64 | weight = torch.randn(N, device=device, dtype=torch.float32)
65 |
66 | if return_rstd and return_mean:
67 | out, rstd, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=True)
68 | assert out.shape == (M, N)
69 | assert rstd.shape == (M,)
70 | assert rstd.dtype == torch.float32
71 | assert mean.shape == (M,)
72 | assert mean.dtype == torch.float32
73 | elif return_rstd and not return_mean:
74 | out, rstd = layernorm_fwd(x, weight, eps=eps, return_rstd=True, return_mean=False)
75 | assert out.shape == (M, N)
76 | assert rstd.shape == (M,)
77 | assert rstd.dtype == torch.float32
78 | elif not return_rstd and return_mean:
79 | out, mean = layernorm_fwd(x, weight, eps=eps, return_rstd=False, return_mean=True)
80 | assert out.shape == (M, N)
81 | assert mean.shape == (M,)
82 | assert mean.dtype == torch.float32
83 | else:
84 | out = layernorm_fwd(x, weight, eps=eps, return_rstd=False, return_mean=False)
85 | assert out.shape == (M, N)
86 | assert isinstance(out, torch.Tensor)
87 |
88 |
89 | def test_layernorm_input_validation():
90 | """Test input validation and error handling."""
91 | device = "cuda"
92 |
93 | # Test 3D input (should fail)
94 | x_3d = torch.randn(2, 32, 1024, device=device, dtype=torch.float16)
95 | weight = torch.randn(1024, device=device, dtype=torch.float32)
96 |
97 | with pytest.raises(AssertionError, match="Input must be 2D"):
98 | layernorm_fwd(x_3d, weight)
99 |
100 | # Test weight dimension mismatch
101 | x = torch.randn(32, 1024, device=device, dtype=torch.float16)
102 | weight_wrong = torch.randn(512, device=device, dtype=torch.float32)
103 |
104 | with pytest.raises(ValueError, match="Mismatched mW.shape[0]*"):
105 | layernorm_fwd(x, weight_wrong)
106 |
107 | # Test CPU tensors (should fail)
108 | x_cpu = torch.randn(32, 1024, dtype=torch.float16)
109 | weight_cpu = torch.randn(1024, dtype=torch.float32)
110 |
111 | # with pytest.raises(AssertionError, match="Tensors must be on CUDA device"):
112 | # With torch.library custom op, this now fails with NotImplementedError
113 | with pytest.raises(NotImplementedError):
114 | layernorm_fwd(x_cpu, weight_cpu)
115 |
116 | # Test unsupported dtype
117 | x = torch.randn(32, 1024, device=device, dtype=torch.float64)
118 | weight = torch.randn(1024, device=device, dtype=torch.float32)
119 |
120 | with pytest.raises(AssertionError, match="Unsupported dtype"):
121 | layernorm_fwd(x, weight)
122 |
123 | # Test wrong weight dtype
124 | x = torch.randn(32, 1024, device=device, dtype=torch.float16)
125 | weight_wrong_dtype = torch.randn(1024, device=device, dtype=torch.float16)
126 |
127 | with pytest.raises(AssertionError, match="Weight must be float32"):
128 | layernorm_fwd(x, weight_wrong_dtype)
129 |
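For reference while reading the checks above, a minimal sketch of the forward being tested, assuming the usual bias-free LayerNorm definition (the actual reference implementations live in quack.rmsnorm; the name below is illustrative):

    import torch

    def layernorm_fwd_sketch(x, weight, eps=1e-6):
        # Normalize over the last dim in fp32, then scale by weight.
        xf = x.float()
        mean = xf.mean(dim=-1, keepdim=True)
        rstd = torch.rsqrt(xf.var(dim=-1, unbiased=False, keepdim=True) + eps)
        out = ((xf - mean) * rstd * weight).to(x.dtype)
        return out, rstd.squeeze(-1), mean.squeeze(-1)  # matches the (out, rstd, mean) shapes asserted above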
--------------------------------------------------------------------------------
/quack/tensormap_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 |
3 | from typing import Tuple
4 | from dataclasses import dataclass
5 |
6 | import cutlass
7 | import cutlass.cute as cute
8 | from cutlass.cutlass_dsl import Boolean, const_expr, Int32
9 | from cutlass.utils import TensorMapUpdateMode, TensorMapManager
10 | from cutlass._mlir.dialects import llvm
11 |
12 |
13 | @dataclass(frozen=True)
14 | class TensorMapManagerSm90(TensorMapManager):
15 | """
16 | We have to subclass cutlass.utils.TensorMapManager because it takes in warp_id and only
17 | performs the operation if warp_id matches the current warp.
18 | But for Hopper pingpong GEMM we want to call it from both warp 0 and warp 4,
19 | so we take in a boolean `is_manager_warp` that determines whether to perform the operation.
20 | """
21 |
22 | @cute.jit
23 | def init_tensormap_from_atom(
24 | self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean
25 | ) -> None:
26 | if is_manager_warp:
27 | with cute.arch.elect_one():
28 | cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr)
29 | cute.arch.sync_warp()
30 | return
31 |
32 | @cute.jit
33 | def update_tensormap(
34 | self,
35 | tensor_gmem: Tuple[cute.Tensor, ...],
36 | tma_copy_atom: Tuple[cute.CopyAtom, ...],
37 | tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
38 | is_manager_warp: Boolean,
39 | tensormap_smem_ptr: Tuple[cute.Pointer, ...],
40 | ) -> None:
41 | # updates before touching tensormap in global memory
42 | if is_manager_warp:
43 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
44 | for copy_atom, tensor, smem_ptr in zip(
45 | tma_copy_atom, tensor_gmem, tensormap_smem_ptr
46 | ):
47 | cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr)
48 | # wait until it's safe to update tensormap in global memory
49 | with cute.arch.elect_one():
50 | cute.arch.cp_async_bulk_commit_group()
51 | cute.arch.cp_async_bulk_wait_group(0, read=True)
52 | cute.arch.sync_warp()
53 | # updates to tensormap in global memory
54 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
55 | for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
56 | cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
57 | else:
58 | for copy_atom, tensor, gmem_ptr in zip(
59 | tma_copy_atom, tensor_gmem, tensormap_gmem_ptr
60 | ):
61 | cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr)
62 | cute.arch.sync_warp()
63 | cute.nvgpu.cpasync.fence_tma_desc_release()
64 |
65 | @cute.jit
66 | def update_tensormap_shape(
67 | self,
68 | tensormap_gmem_ptr: Tuple[cute.Pointer, ...],
69 | is_manager_warp: Boolean,
70 | tensormap_smem_ptr: Tuple[cute.Pointer, ...],
71 | shapes: Tuple[Int32, ...],
72 | orders: cutlass.Constexpr[Tuple[int, ...]],
73 | ) -> None:
74 | # updates before touching tensormap in global memory
75 | if is_manager_warp:
76 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
77 | for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders):
78 | smem_ptr_i32 = smem_ptr.toint().ir_value()
79 | llvm.inline_asm(
80 | None,
81 | [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()],
82 | "{\n\t"
83 | ".reg .b64 smem_ptr_i64;\n\t"
84 | "cvt.u64.u32 smem_ptr_i64, $0;\n\t"
85 | f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t"
86 | "}\n",
87 | "r,r",
88 | has_side_effects=True,
89 | is_align_stack=False,
90 | asm_dialect=llvm.AsmDialect.AD_ATT,
91 | )
92 | # wait until it's safe to update tensormap in global memory
93 | with cute.arch.elect_one():
94 | cute.arch.cp_async_bulk_commit_group()
95 | cute.arch.cp_async_bulk_wait_group(0, read=True)
96 | cute.arch.sync_warp()
97 | # updates to tensormap in global memory
98 | if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM):
99 | for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr):
100 | cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr)
101 | else:
102 | assert len(shapes) == len(orders) == len(tensormap_gmem_ptr)
103 | for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders):
104 | gmem_ptr_i64 = gmem_ptr.toint().ir_value()
105 | llvm.inline_asm(
106 | None,
107 | [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()],
108 | f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;",
109 | "l,r",
110 | has_side_effects=True,
111 | is_align_stack=False,
112 | asm_dialect=llvm.AsmDialect.AD_ATT,
113 | )
114 | cute.arch.sync_warp()
115 | cute.nvgpu.cpasync.fence_tma_desc_release()
116 |
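A minimal sketch of the `is_manager_warp` predicate this class expects (illustrative only; real kernels compute the warp index through the CuTe DSL): per the docstring, Hopper pingpong GEMM uses warps 0 and 4 as tensormap managers, so the caller passes a boolean instead of a single warp_id.

    def is_manager_warp_pingpong(warp_idx: int) -> bool:
        # Warps 0 and 4 both manage tensormaps in the pingpong schedule.
        return warp_idx in (0, 4)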
--------------------------------------------------------------------------------
/benchmarks/benchmark_softmax.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Type
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from triton.testing import do_bench
8 |
9 | import cutlass
10 | import cutlass.torch as cutlass_torch
11 |
12 | from quack.softmax import softmax
13 |
14 | try:
15 | from liger_kernel.transformers.functional import liger_softmax
16 | except ImportError:
17 | liger_softmax = None
18 |
19 |
20 | def run_softmax(
21 | M,
22 | N,
23 | dtype: Type[cutlass.Numeric],
24 | warmup_iterations=10,
25 | iterations=1000,
26 | ):
27 | if not torch.cuda.is_available():
28 | raise RuntimeError(f"Ampere GPU is required to run this example!")
29 |
30 | print(f"Tensor dimensions: [{M}, {N}]")
31 | print(f"Input and Output Data type: {dtype}")
32 |
33 | torch_dtype = cutlass_torch.dtype(dtype)
34 |
35 | device = "cuda"
36 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype)
37 |
38 | print(f"Input tensor shapes:")
39 | print(f"x: {x.shape}, dtype: {x.dtype}")
40 | out = softmax(x)
41 | # compiled_func_ref = torch.compile(lambda x: F.softmax(x, dim=-1))
42 | compiled_func_ref = torch.compile(lambda x: F.softmax(x, dim=-1))
43 | fn = lambda: softmax(x)
44 | time.sleep(0.5)
45 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
46 | mem_bw = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
47 | print(f"Kernel execution time: {avg_time:.4f} ms")
48 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
49 |
50 | fn = lambda: compiled_func_ref(x)
51 | for _ in range(5): fn() # warm up
52 | time.sleep(0.5)
53 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
54 | mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
55 | print(f"Torch compile kernel execution time: {avg_time:.4f} ms")
56 | print(f"Torch compile mem throughput: {mem_bw_ref:.2f} GB/s")
57 |
58 | if liger_softmax is not None:
59 | fn = lambda: liger_softmax(x)
60 | for _ in range(5): fn() # warm up
61 | time.sleep(0.5)
62 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
63 | mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
64 | print(f"Liger kernel execution time: {avg_time:.4f} ms")
65 | print(f"Liger mem throughput: {mem_bw_ref:.2f} GB/s")
66 |
67 | return mem_bw, mem_bw_ref
68 |
69 |
70 | def run_softmax_backward(
71 | M,
72 | N,
73 | dtype: Type[cutlass.Numeric],
74 | warmup_iterations=10,
75 | iterations=1000,
76 | ):
77 | if not torch.cuda.is_available():
78 | raise RuntimeError(f"Ampere GPU is required to run this example!")
79 |
80 | print(f"Tensor dimensions: [{M}, {N}]")
81 | print(f"Input and Output Data type: {dtype}")
82 |
83 | torch_dtype = cutlass_torch.dtype(dtype)
84 |
85 | device = "cuda"
86 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True)
87 | x_ref = x.detach().clone().requires_grad_()
88 |
89 | print(f"Input tensor shapes:")
90 | print(f"x: {x.shape}, dtype: {x.dtype}")
91 |
92 | y = softmax(x)
93 | dy = torch.randn_like(y)
94 |
95 | time.sleep(0.5)
96 | fn = lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)
97 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
98 | # Memory: read dy and y, write dx
99 | mem_bw = round(3 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
100 | print(f"Kernel execution time: {avg_time:.4f} ms")
101 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
102 |
103 | # Reference implementation
104 | y_ref = F.softmax(x_ref, dim=-1)
105 | compiled_func_ref = torch.compile(lambda: torch.autograd.grad(y_ref, x_ref, grad_outputs=dy, retain_graph=True))
106 |
107 | for _ in range(5): compiled_func_ref() # warm up
108 | time.sleep(0.5)
109 | avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations)
110 | mem_bw_ref = round(3 * x.numel() * dtype.width // 8 / (avg_time_ref / 1000) / 1e9)
111 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
112 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
113 |
114 | return mem_bw, mem_bw_ref
115 |
116 |
117 | if __name__ == "__main__":
118 | parser = argparse.ArgumentParser(
119 | description="Benchmark softmax forward and backward passes"
120 | )
121 | parser.add_argument("--M", default=8192, type=int)
122 | parser.add_argument("--N", default=16384, type=int)
123 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16)
124 | parser.add_argument("--warmup_iterations", default=10, type=int)
125 | parser.add_argument("--iterations", default=100, type=int)
126 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass")
127 |
128 | args = parser.parse_args()
129 | torch.manual_seed(0)
130 |
131 | if args.backward:
132 | print("=== Softmax Backward Pass Benchmark ===")
133 | run_softmax_backward(
134 | args.M,
135 | args.N,
136 | dtype=args.dtype,
137 | warmup_iterations=args.warmup_iterations,
138 | iterations=args.iterations,
139 | )
140 | else:
141 | print("=== Softmax Forward Pass Benchmark ===")
142 | run_softmax(
143 | args.M,
144 | args.N,
145 | dtype=args.dtype,
146 | warmup_iterations=args.warmup_iterations,
147 | iterations=args.iterations,
148 | )
149 |
150 | # MN_pairs = [(32768, 256), (32768, 512), (32768, 1024), (32768, 2048), (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768), (32768, 65536), (16384, 131072), (8192, 262144)]
151 | # # MN_pairs = [(32768, 1024)]
152 | # results = []
153 | # for M, N in MN_pairs:
154 | # res = run_softmax(
155 | # M,
156 | # N,
157 | # dtype=args.dtype,
158 | # warmup_iterations=args.warmup_iterations,
159 | # iterations=args.iterations,
160 | # )
161 | # results.append(res)
162 | # # print(results)
163 | # print([x for x, _ in results])
164 |
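The benchmarks above all report throughput as total bytes moved divided by measured time; the same arithmetic pulled out into a standalone helper (editor's sketch, hypothetical name):

    def achieved_bandwidth_gbps(num_elements, bytes_per_element, num_passes, time_ms):
        # num_passes = 2 for softmax forward (read x, write y), 3 for backward (read dy, read y, write dx)
        total_bytes = num_passes * num_elements * bytes_per_element
        return total_bytes / (time_ms / 1000) / 1e9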
--------------------------------------------------------------------------------
/docs/dsl_control_flow.rst:
--------------------------------------------------------------------------------
1 | .. _dsl_control_flow:
2 | .. |DC| replace:: dynamic compilation
3 | .. |IR| replace:: intermediate representation (IR)
4 | .. |DSL| replace:: CuTe DSL
5 | .. |Constexpr| replace:: **Constexpr** (compile-time Python value)
6 |
7 | Control Flow
8 | ==================
9 |
10 |
11 | Overview
12 | --------
13 | |DSL| walks Python's AST and converts each control-flow construct it finds into
14 | structured |IR|. You can therefore write ordinary Python loops and branches
15 | while the compiler decides—statement by statement—whether to
16 |
17 | * **evaluate at compile time** if it's a native Python control flow, or
18 | * **emit intermediate representation (IR)** when the control flow is marked as dynamic.
19 |
20 | Passing |IR| values to a native Python control flow will result in an error.
21 |
22 |
23 | For Loops
24 | ---------
25 | |DSL| recognises three kinds of ranges for ``for`` loops:
26 |
27 | * ``range`` – the Python built-in, always lowered to |IR|
28 | * ``cutlass.range`` – same as the Python built-in ``range``, but supports advanced unrolling and pipelining control
29 | * ``cutlass.range_constexpr`` – unrolled at compile time
30 |
31 |
32 | range(...)/cutlass.range(...)
33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
34 | Use when you *always* want a loop in the generated |IR|, even if the inputs
35 | are Python values.
36 |
37 | cutlass.range_constexpr(...)
38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
39 | Runs in the Python interpreter and is fully unrolled before code generation.
40 | All loop indices must be |Constexpr|.
41 |
42 |
43 | **Example:**
44 |
45 | .. code-block:: python
46 |
47 | @cute.jit
48 | def control_flow_examples(bound: cutlass.Int32):
49 | n = 10
50 |
51 | # ✅ This loop is Python loop, evaluated at compile time.
52 | for i in cutlass.range_constexpr(n):
53 | cute.printf("%d\\n", i)
54 |
55 | # ✅ This loop is dynamic, even when bound is Python value.
56 | for i in range(n):
57 | cute.printf("%d\\n", i)
58 |
59 | # ❌ This loop bound is a dynamic value, not allowed in Python loop.
60 | # Should use `range` instead.
61 | for i in cutlass.range_constexpr(bound):
62 | cute.printf("%d\\n", i)
63 |
64 | # ✅ This loop is dynamic, emitted IR loop.
65 | for i in range(bound):
66 | cute.printf("%d\\n", i)
67 |
68 | # ✅ This loop is dynamic, emitted IR loop with unrolling
69 | for i in cutlass.range(bound, unroll=2):
70 | cute.printf("%d\\n", i)
71 |
72 |
73 | If-Else Statements
74 | ------------------
75 |
76 | Standard Python ``if``/``elif``/``else`` is supported.
77 |
78 | * **Predicate without annotation** → lowered to |IR|.
79 | * **Predicate annotated with `cutlass.const_expr`** → evaluated at compile time.
80 |
81 | **Example:**
82 |
83 | .. code-block:: python
84 |
85 | @cute.jit
86 | def main(const_var: cutlass.Constexpr, dynamic_var: cutlass.Int32):
87 | # ✅ This branch is Python branch, evaluated at compile time.
88 | if cutlass.const_expr(const_var):
89 | cute.printf("Const branch\\n")
90 | else:
91 | cute.printf("Const else\\n")
92 |
93 | # ✅ This branch is dynamic branch, emitted IR branch.
94 | if dynamic_var == 10:
95 | cute.printf("Dynamic True\\n")
96 | else:
97 | cute.printf("Dynamic False\\n")
98 |
99 | # ❌ Using a dynamic value with `cutlass.const_expr` is not allowed.
100 | if cutlass.const_expr(dynamic_var == 10):
101 | cute.printf("Bound is 10\\n")
102 |
103 |
104 | While Loops
105 | -----------
106 |
107 | Standard Python ``while`` is supported.
108 |
109 | * **Condition without annotation** → lowered to |IR|.
110 | * **Condition annotated with `cutlass.const_expr`** → evaluated at compile time.
111 |
112 | **Example:**
113 |
114 | .. code-block:: python
115 |
116 | @cute.jit
117 | def main(dynamic_var: cutlass.Int32):
118 | n = 0
119 |
120 | # ✅ This is Python while loop, evaluated at compile time.
121 | while cutlass.const_expr(n < 10):
122 | cute.printf("Const branch\\n")
123 | n += 1
124 |
125 | # ✅ This is dynamic while loop, emitted IR while loop.
126 | while dynamic_var == 10:
127 | cute.printf("Dynamic True\\n")
128 | n += 1
129 |
130 | # ❌ Using a dynamic value with `cutlass.const_expr` is not allowed.
131 | while cutlass.const_expr(n < dynamic_var):
132 | n += 1
133 |
134 |
135 | Compile-Time Metaprogramming
136 | ----------------------------
137 |
138 | Mix compile-time constructs with normal |DSL| code to generate specialised
139 | kernels without runtime overhead. A compile-time flag can, for example, toggle
140 | an optional **ReLU** epilogue:
141 |
142 | .. code-block:: python
143 |
144 | @cute.kernel
145 | def gemm(..., do_relu: cutlass.Constexpr):
146 | # main GEMM work
147 | ...
148 | if cutlass.const_expr(do_relu): # compile-time guard
149 | # ReLU code is emitted only when do_relu is True
150 | ...
151 |
152 | .. code-block:: text
153 |
154 | gemm(..., False) # ReLU is omitted from the generated IR
155 | gemm(..., True) # ReLU is included
156 |
157 |
158 | Limitations of Dynamic Control Flow
159 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
160 |
161 | * Early-exit ``break``, ``continue``, ``pass``, or raising an exception from a
162 | control flow body is not yet supported.
163 | * Operations in the control flow body are traced only when tracing is active in
164 | that region.
165 | * Values originating in a control flow body are not available outside the control
166 | flow.
167 | * Changing the type of a variable in a control flow body is not allowed.
168 |
169 | **Example:**
170 |
171 | .. code-block:: python
172 |
173 | @cute.jit
174 | def control_flow_negative_examples(predicate: cutlass.Boolean):
175 | n = 10
176 |
177 | # ❌ This loop is dynamic, early-exit isn't allowed.
178 | for i in range(n):
179 | if i == 5:
180 | break # Early-exit
181 |
182 | if predicate:
183 | val = 10
184 | # ❌ return from control flow body is not allowed.
185 | return
186 | # ❌ Raising exception from control flow body is not allowed.
187 | raise ValueError("This is not allowed")
188 | # ❌ Using pass in control flow body is not allowed.
189 | pass
190 |
191 | # ❌ val is not available outside the dynamic if
192 | cute.printf("%d\\n", val)
193 |
194 | if predicate:
195 | # ❌ Changing type of a variable in control flow body is not allowed.
196 | n = 10.0
197 |
198 |
--------------------------------------------------------------------------------
/tests/test_softmax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2 |
3 | import pytest
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | from quack.softmax import softmax
9 |
10 |
11 | torch._dynamo.config.cache_size_limit = 1024
12 | torch._dynamo.config.accumulated_cache_size_limit = 1024
13 |
14 |
15 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16, torch.float32])
16 | # @pytest.mark.parametrize("input_dtype", [torch.float32])
17 | @pytest.mark.parametrize(
18 | "N",
19 | [192, 256, 512, 760, 1024, 1128, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144],
20 | # [32768]
21 | )
22 | @pytest.mark.parametrize("M", [1, 37, 199])
23 | @pytest.mark.parametrize("function", [softmax, torch.compile(softmax, fullgraph=True)])
24 | # @pytest.mark.parametrize("M", [1])
25 | def test_softmax(M, N, input_dtype, function):
26 | """Test Softmax forward and backward passes against reference implementation."""
27 | device = "cuda"
28 | # Set tolerance based on dtype
29 | if input_dtype == torch.bfloat16:
30 | atol = 1e-2
31 | rtol = 1e-2
32 | elif input_dtype == torch.float16:
33 | atol = 1e-3
34 | rtol = 1e-3
35 | else:
36 | atol = 1e-4
37 | rtol = 1e-4
38 |
39 | torch.random.manual_seed(0)
40 | # Create input tensors (scale down to avoid overflow in softmax)
41 | x = (0.1 * torch.randn(M, N, device=device, dtype=input_dtype)).requires_grad_()
42 | x_ref = x.detach().clone().requires_grad_(True)
43 |
44 | # Forward pass
45 | out = function(x)
46 | out_ref = F.softmax(x_ref, dim=-1)
47 |
48 | # Check output shape and dtype
49 | assert out.shape == x.shape
50 | assert out.dtype == input_dtype
51 | # Check accuracy
52 | torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
53 | # Check softmax properties
54 | # Sum along last dimension should be 1
55 | sums = torch.sum(out, dim=-1)
56 | torch.testing.assert_close(sums, torch.ones_like(sums), atol=1e-4, rtol=1e-4)
57 | # All values should be positive
58 | assert (out >= 0).all()
59 | # All values should be <= 1
60 | assert (out <= 1).all()
61 |
62 | # Test backward pass
63 | dy = torch.randn_like(out)
64 | torch.cuda.synchronize() # without sync, torch.autograd gets wrong results
65 | (dx,) = torch.autograd.grad(out, x, grad_outputs=dy)
66 | (dx_ref,) = torch.autograd.grad(out_ref, x_ref, grad_outputs=dy)
67 | assert dx.shape == dy.shape
68 | assert dx.dtype == input_dtype
69 | torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
70 |
71 |
72 | @pytest.mark.parametrize("input_dtype", [torch.float16, torch.float32])
73 | @pytest.mark.parametrize("function", [softmax, torch.compile(softmax, fullgraph=True)])
74 | def test_softmax_extreme_values(input_dtype, function):
75 | """Test Softmax with extreme input values."""
76 | device = "cuda"
77 | M, N = 16, 1024
78 | # Test with large positive values
79 | x_large = torch.full((M, N), 10.0, device=device, dtype=input_dtype)
80 | out_large = function(x_large)
81 | # Should be uniform since all values are the same
82 | expected = torch.full_like(out_large, 1.0 / N)
83 | torch.testing.assert_close(out_large, expected, atol=1e-3, rtol=1e-3)
84 | # Test with large negative values
85 | x_small = torch.full((M, N), -10.0, device=device, dtype=input_dtype)
86 | out_small = function(x_small)
87 | # Should also be uniform
88 | torch.testing.assert_close(out_small, expected, atol=1e-3, rtol=1e-3)
89 | # Test with mixed extreme values
90 | x_mixed = torch.zeros((M, N), device=device, dtype=input_dtype)
91 | x_mixed[:, 0] = 10.0 # One large value per row
92 | x_mixed[:, 1:] = -10.0 # Rest are small
93 | out_mixed = function(x_mixed)
94 | # First column should be close to 1, rest close to 0
95 | assert (out_mixed[:, 0] > 0.99).all()
96 | assert (out_mixed[:, 1:] < 0.01).all()
97 |
98 |
99 | @pytest.mark.parametrize("function", [softmax, torch.compile(softmax, fullgraph=True)])
100 | def test_softmax_numerical_stability(function):
101 | """Test that softmax is numerically stable."""
102 | device = "cuda"
103 | M, N = 8, 512
104 | # Create input with a wide range of values
105 | x = torch.randn(M, N, device=device, dtype=torch.float32)
106 | # Add large constant to test numerical stability
107 | x_shifted = x + 100.0
108 | out = function(x)
109 | out_shifted = function(x_shifted)
110 | # Results should be identical (softmax is translation invariant)
111 | torch.testing.assert_close(out, out_shifted, atol=1e-6, rtol=1e-6)
112 |
113 |
114 | # def test_softmax_backward_properties():
115 | # """Test mathematical properties of softmax backward pass."""
116 | # device = "cuda"
117 | # M, N = 16, 1024
118 |
119 | # torch.random.manual_seed(42)
120 | # # Create test inputs
121 | # x = torch.randn(M, N, device=device, dtype=torch.float32, requires_grad=True)
122 | # y = F.softmax(x, dim=-1)
123 |
124 | # # Test 1: Gradient of uniform upstream should sum to zero along last dim
125 | # dy_uniform = torch.ones_like(y)
126 | # dx = softmax_backward(dy_uniform, y.detach())
127 |
128 | # # Sum along last dimension should be approximately zero
129 | # row_sums = torch.sum(dx, dim=-1)
130 | # torch.testing.assert_close(row_sums, torch.zeros_like(row_sums), atol=1e-6, rtol=1e-6)
131 |
132 | # # Test 2: Gradient should be zero when upstream gradient is proportional to y
133 | # # (since softmax(x) ∝ y means dy ∝ y, so dy - dot*1 = 0)
134 | # for scale in [0.5, 1.0, 2.0]:
135 | # dy_prop = scale * y
136 | # dx_prop = softmax_backward(dy_prop, y.detach())
137 | # torch.testing.assert_close(dx_prop, torch.zeros_like(dx_prop), atol=1e-6, rtol=1e-6)
138 |
139 |
140 | # def test_softmax_backward_edge_cases():
141 | # """Test softmax backward with edge cases."""
142 | # device = "cuda"
143 |
144 | # # Test with one-hot softmax output (peaked distribution)
145 | # M, N = 8, 512
146 | # y = torch.zeros(M, N, device=device, dtype=torch.float32)
147 | # y[:, 0] = 1.0 # One-hot at first position
148 |
149 | # dy = torch.randn_like(y)
150 | # dx = softmax_backward(dy, y)
151 |
152 | # # Check that gradients exist and are finite
153 | # assert torch.isfinite(dx).all()
154 |
155 | # # Test with uniform softmax output
156 | # y_uniform = torch.full((M, N), 1.0/N, device=device, dtype=torch.float32)
157 | # dy_uniform = torch.randn_like(y_uniform)
158 | # dx_uniform = softmax_backward(dy_uniform, y_uniform)
159 |
160 | # # Check that gradients exist and are finite
161 | # assert torch.isfinite(dx_uniform).all()
162 |
163 | # # Row sums should be zero (property of softmax jacobian)
164 | # row_sums = torch.sum(dx_uniform, dim=-1)
165 | # torch.testing.assert_close(row_sums, torch.zeros_like(row_sums), atol=1e-6, rtol=1e-6)
166 |
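The commented-out property tests above rely on the closed-form softmax backward; as a reading aid, that formula as a standalone PyTorch helper (editor's sketch, hypothetical name, not the kernel in quack.softmax):

    import torch

    def softmax_backward_ref(dy, y):
        # dx = y * (dy - sum(dy * y)) along the softmax dimension
        dot = (dy * y).sum(dim=-1, keepdim=True)
        return y * (dy - dot)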
--------------------------------------------------------------------------------
/benchmarks/benchmark_topk.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Type
4 |
5 | import torch
6 | from triton.testing import do_bench
7 |
8 | import cutlass
9 | import cutlass.torch as cutlass_torch
10 |
11 | from quack.topk import topk, topk_bwd
12 |
13 | try:
14 | import rtopk
15 | except ImportError:
16 | rtopk = None
17 |
18 |
19 | def run_topk(
20 | M,
21 | N,
22 | k,
23 | dtype: Type[cutlass.Numeric],
24 | softmax: bool = False,
25 | backward: bool = False,
26 | warmup_iterations=10,
27 | iterations=1000,
28 | ):
29 | if not torch.cuda.is_available():
30 | raise RuntimeError(f"CUDA GPU is required to run this example!")
31 |
32 | print(f"Tensor dimensions: [{M}, {N}], k={k}")
33 | print(f"Input and Output Data type: {dtype}")
34 |
35 | torch_dtype = cutlass_torch.dtype(dtype)
36 |
37 | device = "cuda"
38 | x = torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True)
39 |
40 | print(f"Input tensor shapes:")
41 | print(f"x: {x.shape}, dtype: {x.dtype}")
42 | out, idx = topk(x, k, softmax=softmax)
43 | print(f"Output shape: {out.shape}")
44 |
45 | if backward:
46 | dvalues = torch.randn_like(out)
47 | fn = lambda: topk_bwd(dvalues, out, idx, N, softmax=softmax)
48 | else:
49 | # Benchmark our implementation
50 | fn = lambda: topk(x, k, softmax=softmax)
51 | fn() # warm up
52 | time.sleep(0.5)
53 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
54 | # Memory: read input (M*N elements), write output (M*k elements)
55 | if backward:
56 | mem_accessed = (M * N + 2 * M * k) * dtype.width // 8
57 | else:
58 | mem_accessed = (M * N + M * k) * dtype.width // 8
59 | mem_bw = round(mem_accessed / (avg_time / 1000) / 1e9, 2)
60 | print(f"Kernel execution time: {avg_time:.4f} ms")
61 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
62 |
63 | # Benchmark PyTorch reference
64 | if backward:
65 | fn_ref = lambda: torch.autograd.grad(
66 | torch.softmax(torch.topk(x, k, dim=-1, largest=True, sorted=True)[0], dim=-1)
67 | if softmax
68 | else torch.topk(x, k, dim=-1, largest=True, sorted=True)[0],
69 | x,
70 | grad_outputs=dvalues,
71 | retain_graph=True,
72 | )
73 | for _ in range(5):
74 | fn_ref()
75 | time.sleep(0.5)
76 | avg_time_ref = do_bench(fn_ref, warmup=warmup_iterations, rep=iterations)
77 | mem_bw_ref = round(mem_accessed / (avg_time_ref / 1000) / 1e9, 2)
78 | print(f"Ref backward execution time: {avg_time_ref:.4f} ms")
79 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
80 | speedup = avg_time_ref / avg_time
81 | print(f"Speedup: {speedup:.2f}x")
82 | else:
83 | fn_ref = lambda: torch.topk(x, k, dim=-1, largest=True, sorted=True)[0]
84 | for _ in range(5): fn_ref() # warm up
85 | time.sleep(0.5)
86 | avg_time_ref = do_bench(fn_ref, warmup=warmup_iterations, rep=iterations)
87 | mem_bw_ref = round(mem_accessed / (avg_time_ref / 1000) / 1e9, 2)
88 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
89 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
90 |
91 | speedup = avg_time_ref / avg_time
92 | print(f"Speedup: {speedup:.2f}x")
93 |
94 | if rtopk is not None:
95 | fn_rtopk = lambda: rtopk.ops.rtopk(x, k, max_iter=512)
96 | for _ in range(5): fn_rtopk() # warm up
97 | time.sleep(0.5)
98 | avg_time_ref = do_bench(fn_rtopk, warmup=warmup_iterations, rep=iterations)
99 | mem_bw_ref = round(mem_accessed / (avg_time_ref / 1000) / 1e9, 2)
100 | print(f"RTopK kernel execution time: {avg_time_ref:.4f} ms")
101 | print(f"RTopK mem throughput: {mem_bw_ref:.2f} GB/s")
102 |
103 | # do_bench doesn't seem very accurate for very fast kernels, so we use pytorch_profiler
104 | from flash_attn.utils.benchmark import pytorch_profiler
105 | pytorch_profiler(fn)
106 | if rtopk is not None:
107 | pytorch_profiler(fn_rtopk)
108 |
109 | return mem_bw, mem_bw_ref, speedup
110 |
111 |
112 | if __name__ == "__main__":
113 | parser = argparse.ArgumentParser(
114 | description="Benchmark top-k operation"
115 | )
116 | parser.add_argument("--M", default=8192, type=int)
117 | parser.add_argument("--N", default=1024, type=int)
118 | parser.add_argument("--k", default=32, type=int)
119 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16)
120 | parser.add_argument("--softmax", action="store_true", help="Apply softmax to top-k values")
121 | parser.add_argument("--warmup_iterations", default=10, type=int)
122 | parser.add_argument("--iterations", default=100, type=int)
123 | parser.add_argument("--sweep", action="store_true", help="Run sweep across different N and k values")
124 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass")
125 |
126 | args = parser.parse_args()
127 | torch.manual_seed(0)
128 |
129 | cutlass.cuda.initialize_cuda_context()
130 |
131 | if args.sweep:
132 | print("=== Top-K Sweep Benchmark ===")
133 | # Test different combinations of N and k
134 | N_values = [64, 128, 256, 512, 1024]
135 | k_values = [8, 16, 32]
136 |
137 | results = []
138 | for N in N_values:
139 | for k in k_values:
140 | if k > N // 2: # Skip if k is too large relative to N
141 | continue
142 | print(f"\n--- N={N}, k={k} ---")
143 | try:
144 | mem_bw, mem_bw_ref, speedup = run_topk(
145 | args.M,
146 | N,
147 | k,
148 | dtype=args.dtype,
149 | softmax=args.softmax,
150 | backward=args.backward,
151 | warmup_iterations=args.warmup_iterations,
152 | iterations=args.iterations,
153 | )
154 | results.append((N, k, mem_bw, mem_bw_ref, speedup))
155 | except Exception as e:
156 | print(f"Error with N={N}, k={k}: {e}")
157 |
158 | print("\n=== Summary ===")
159 | print("N\tk\tOurs (GB/s)\tRef (GB/s)\tSpeedup")
160 | for N, k, mem_bw, mem_bw_ref, speedup in results:
161 | print(f"{N}\t{k}\t{mem_bw}\t\t{mem_bw_ref}\t\t{speedup:.2f}x")
162 | else:
163 | print("=== Top-K Benchmark ===")
164 | run_topk(
165 | args.M,
166 | args.N,
167 | args.k,
168 | dtype=args.dtype,
169 | softmax=args.softmax,
170 | backward=args.backward,
171 | warmup_iterations=args.warmup_iterations,
172 | iterations=args.iterations,
173 | )
174 |
--------------------------------------------------------------------------------
/benchmarks/benchmark_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Type
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from triton.testing import do_bench
8 |
9 | import cutlass
10 | import cutlass.torch as cutlass_torch
11 |
12 | from quack.cross_entropy import cross_entropy_fwd, cross_entropy
13 |
14 |
15 | def run_cross_entropy(
16 | M,
17 | N,
18 | dtype: Type[cutlass.Numeric],
19 | warmup_iterations=2,
20 | iterations=200,
21 | return_dx=False,
22 | ):
23 | if not torch.cuda.is_available():
24 | raise RuntimeError(f"Ampere GPU is required to run this example!")
25 |
26 | print(f"Tensor dimensions: [{M}, {N}]")
27 | print(f"Input Data type: {dtype}")
28 |
29 | torch_dtype = cutlass_torch.dtype(dtype)
30 |
31 | device = "cuda"
32 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype)
33 | target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
34 |
35 | loss = cross_entropy_fwd(x, target, return_dx=return_dx)
36 |
37 | compiled_func_ref = torch.compile(lambda x, target: F.cross_entropy(x, target, reduction='none'))
38 |
39 | fn = lambda: cross_entropy_fwd(x, target, return_dx=return_dx)
40 | time.sleep(0.5)
41 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
42 | # Memory bandwidth calculation: read x (M*N) + write dx (M*N, only if return_dx) + read target (M) + write loss (M)
43 | mem_bytes = (M * N * (2 if return_dx else 1) + M + M) * dtype.width // 8
44 | mem_bw = round(mem_bytes / (avg_time / 1000) / 1e9)
45 | print(f"Kernel execution time: {avg_time:.4f} ms")
46 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
47 |
48 | fn_ref = lambda: compiled_func_ref(x, target)
49 | for _ in range(5): fn_ref() # warm up
50 | time.sleep(0.5)
51 | avg_time = do_bench(fn_ref, warmup=warmup_iterations, rep=iterations)
52 | mem_bytes = (M * N + M + M) * dtype.width // 8
53 | mem_bw_ref = round(mem_bytes / (avg_time / 1000) / 1e9)
54 | print(f"Ref kernel execution time: {avg_time:.4f} ms")
55 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
56 |
57 | # from flash_attn.utils.benchmark import pytorch_profiler
58 | # pytorch_profiler(fn)
59 | # pytorch_profiler(fn_ref)
60 | # pytorch_profiler(torch.compile(torch.logsumexp), x, dim=-1)
61 |
62 | return mem_bw, mem_bw_ref
63 |
64 |
65 |
66 | def run_cross_entropy_backward(
67 | M,
68 | N,
69 | dtype: Type[cutlass.Numeric],
70 | warmup_iterations=10,
71 | iterations=1000,
72 | ):
73 | if not torch.cuda.is_available():
74 | raise RuntimeError(f"Ampere GPU is required to run this example!")
75 |
76 | print(f"Tensor dimensions: [{M}, {N}]")
77 | print(f"Input Data type: {dtype}")
78 |
79 | torch_dtype = cutlass_torch.dtype(dtype)
80 |
81 | device = "cuda"
82 | x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True)
83 | target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
84 | x_ref = x.detach().clone().requires_grad_()
85 |
86 | print(f"Input tensor shapes:")
87 | print(f"x: {x.shape}, dtype: {x.dtype}")
88 | print(f"target: {target.shape}, dtype: {target.dtype}")
89 |
90 | loss = cross_entropy(x, target, reduction="none")
91 | dloss = torch.randn(M, device=device, dtype=torch.float32)
92 | torch.cuda.synchronize()
93 |
94 | # Reference implementation
95 | loss_ref = F.cross_entropy(x_ref, target, reduction='none')
96 | compiled_func_ref = torch.compile(lambda: torch.autograd.grad(loss_ref, x_ref, grad_outputs=dloss, retain_graph=True))
97 |
98 | for _ in range(5): compiled_func_ref() # warm up
99 | time.sleep(0.5)
100 | avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations)
101 | mem_bw_ref = round((2 * x.numel() * x.element_size() + target.numel() * target.element_size() +
102 | dloss.numel() * dloss.element_size()) / (avg_time_ref / 1000) / 1e9)
103 |
104 | time.sleep(0.5)
105 | fn = lambda: torch.autograd.grad(loss, x, grad_outputs=dloss, retain_graph=True)
106 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
107 | # Memory bandwidth calculation: read x (M*N) + read target (M) + read dloss (M) + write grad (M*N)
108 | mem_bw = round((2 * x.numel() * x.element_size() + target.numel() * target.element_size() +
109 | dloss.numel() * dloss.element_size()) / (avg_time / 1000) / 1e9)
110 | print(f"Kernel execution time: {avg_time:.4f} ms")
111 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
112 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
113 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
114 |
115 | return mem_bw, mem_bw_ref
116 |
117 |
118 | if __name__ == "__main__":
119 | parser = argparse.ArgumentParser(
120 | description="Benchmark cross entropy forward and backward passes"
121 | )
122 | parser.add_argument("--M", default=8192, type=int)
123 | parser.add_argument("--N", default=16384, type=int)
124 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16)
125 | parser.add_argument("--warmup_iterations", default=10, type=int)
126 | parser.add_argument("--iterations", default=100, type=int)
127 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass")
128 | parser.add_argument("--fwd_dx", action="store_true", help="Benchmark forward pass that also computes dx")
129 |
130 | args = parser.parse_args()
131 | torch.manual_seed(0)
132 | cutlass.cuda.initialize_cuda_context()
133 |
134 |
135 | if args.backward:
136 | print("=== Cross Entropy Backward Pass Benchmark ===")
137 | run_cross_entropy_backward(
138 | args.M,
139 | args.N,
140 | dtype=args.dtype,
141 | warmup_iterations=args.warmup_iterations,
142 | iterations=args.iterations,
143 | )
144 | else:
145 | print("=== Cross Entropy Forward Pass Benchmark ===")
146 | run_cross_entropy(
147 | args.M,
148 | args.N,
149 | dtype=args.dtype,
150 | warmup_iterations=args.warmup_iterations,
151 | iterations=args.iterations,
152 | return_dx=args.fwd_dx,
153 | )
154 |
155 |
156 | '''
157 | #MN_pairs = [(32768, 256), (32768, 512), (32768, 1024), (32768, 2048), (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768), (32768, 65536), (16384, 131072), (8192, 262144)]
158 | MN_pairs = [(32768, 65536)]
159 | results = []
160 | for M, N in MN_pairs:
161 | res = run_cross_entropy_backward(
162 | M,
163 | N,
164 | dtype=args.dtype,
165 | warmup_iterations=args.warmup_iterations,
166 | iterations=args.iterations,
167 | )
168 | results.append(res)
169 | print(results)
170 | #print([x for x, _ in results])
171 | '''
172 |
--------------------------------------------------------------------------------
/tests/test_linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025, Tri Dao.
2 | import math
3 | import pytest
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from quack.linear import linear_func, linear_act_func
8 | from quack.gemm_interface import (
9 | gemm_add_inplace,
10 | gemm_dact,
11 | gemm_act_ref,
12 | gemm_dact_ref,
13 | )
14 |
15 |
16 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
17 | @pytest.mark.parametrize("has_bias", [False, True])
18 | @pytest.mark.parametrize("out_features", [1504, 2048])
19 | @pytest.mark.parametrize("in_features", [736, 4096])
20 | # @pytest.mark.parametrize("out_features", [2048])
21 | # @pytest.mark.parametrize("in_features", [4096])
22 | def test_linear(in_features, out_features, has_bias, input_dtype):
23 | device = "cuda"
24 | torch.random.manual_seed(0)
25 | m = 1920
26 | x = torch.randn((m, in_features), device=device, dtype=input_dtype)
27 | x = x[::2].requires_grad_(True) # Testing non-contiguous
28 | w = (
29 | torch.randn((out_features, in_features), device=device, dtype=input_dtype)
30 | / math.sqrt(in_features)
31 | ).requires_grad_()
32 | bias = torch.randn(out_features, device=device, requires_grad=True) if has_bias else None
33 | x_ref, w_ref, bias_ref = [
34 | t.detach().clone().float().requires_grad_(True) if t is not None else None
35 | for t in (x, w, bias)
36 | ]
37 | x_pt, w_pt, bias_pt = [
38 | t.detach().clone().to(x.dtype).requires_grad_(True) if t is not None else None
39 | for t in (x, w, bias)
40 | ]
41 | out = linear_func(x, w, bias, tuned=False) # Disable tuning for faster test
42 | out_ref = F.linear(x_ref, w_ref, bias_ref)
43 | out_pt = F.linear(x_pt, w_pt, bias_pt)
44 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-6
45 | dout = torch.randn_like(out)
46 | out.backward(dout)
47 | out_ref.backward(dout.float())
48 | out_pt.backward(dout)
49 | assert (x.grad - x_ref.grad).abs().max() < 2 * (x_pt.grad - x_ref.grad).abs().max() + 1e-6
50 | assert (w.grad - w_ref.grad).abs().max() < 2 * (w_pt.grad - w_ref.grad).abs().max() + 1e-6
51 | if bias is not None:
52 | assert (bias.grad - bias_ref.grad).abs().max() < 2 * (
53 | bias_pt.grad - bias_ref.grad
54 | ).abs().max() + 1e-6
55 |
56 |
57 | @pytest.mark.parametrize("store_preact", [False, True])
58 | @pytest.mark.parametrize("activation", ["relu", "relu_sq", "gelu_tanh_approx"])
59 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
60 | @pytest.mark.parametrize("has_bias", [False, True])
61 | @pytest.mark.parametrize("out_features", [1504, 2048])
62 | @pytest.mark.parametrize("in_features", [736, 4096])
63 | # @pytest.mark.parametrize("out_features", [2048])
64 | # @pytest.mark.parametrize("in_features", [4096])
65 | def test_linear_act(in_features, out_features, has_bias, input_dtype, activation, store_preact):
66 | device = "cuda"
67 | torch.random.manual_seed(0)
68 | m = 1920
69 | x = torch.randn((m, in_features), device=device, dtype=input_dtype)
70 | x = x[::2].requires_grad_(True) # Testing non-contiguous
71 | w = (
72 | torch.randn((out_features, in_features), device=device, dtype=input_dtype)
73 | / math.sqrt(in_features)
74 | ).requires_grad_()
75 | bias = torch.randn(out_features, device=device, requires_grad=True) if has_bias else None
76 | # Disable tuning for faster test
77 | preact, postact = linear_act_func(
78 | x, w, activation, bias=bias, store_preact=store_preact, tuned=False
79 | )
80 | preact_ref, postact_ref = gemm_act_ref(
81 | x.float(), w.float().T, activation=activation, bias=bias, store_preact=store_preact
82 | )
83 | preact_pt, postact_pt = gemm_act_ref(
84 | x, w.T, activation=activation, bias=bias, store_preact=store_preact
85 | )
86 | assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-6
87 | if store_preact:
88 | assert preact is not None and preact_ref is not None
89 | assert (preact - preact_ref).abs().max() < 2 * (preact_pt - preact_ref).abs().max() + 1e-6
90 |
91 |
92 | @pytest.mark.parametrize("activation", ["relu", "relu_sq", "gelu_tanh_approx"])
93 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
94 | @pytest.mark.parametrize("k", [736, 1024])
95 | @pytest.mark.parametrize("n", [1504, 2048])
96 | def test_gemm_dact(n, k, input_dtype, activation):
97 | """Test GEMM with activation gradient computation."""
98 | device = "cuda"
99 | torch.random.manual_seed(0)
100 | m = 960
101 | dout_input = torch.randn((m, k), device=device, dtype=input_dtype)
102 | weight = torch.randn((n, k), device=device, dtype=input_dtype) / math.sqrt(k)
103 | preact = torch.randn((m, n), device=device, dtype=input_dtype, requires_grad=True)
104 | # Disable tuning for faster test
105 | dx, postact = gemm_dact(dout_input, weight.T, preact, activation=activation, tuned=False)
106 | dx_ref, postact_ref = gemm_dact_ref(
107 | dout_input.float(), weight.float().T, preact.float(), activation=activation
108 | )
109 | dx_pt, postact_pt = gemm_dact_ref(dout_input, weight.T, preact, activation=activation)
110 | assert (dx - dx_ref).abs().max() < 2 * (dx_pt - dx_ref).abs().max() + 1e-5
111 | assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-5
112 |
113 |
114 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
115 | @pytest.mark.parametrize("n", [1504, 2048])
116 | @pytest.mark.parametrize("k", [736, 1024])
117 | @pytest.mark.parametrize("m", [960, 1920])
118 | def test_gemm_add_inplace(m, k, n, input_dtype):
119 | """Test in-place GEMM with addition: C += A @ B."""
120 | device = "cuda"
121 | torch.random.manual_seed(0)
122 | A = torch.randn((m, k), device=device, dtype=input_dtype)
123 | B = torch.randn((k, n), device=device, dtype=input_dtype)
124 | C = torch.randn((m, n), device=device, dtype=input_dtype)
125 | # Save original C for reference computation
126 | C_og = C.clone()
127 | gemm_add_inplace(A, B, C, tuned=False)
128 | C_ref = C_og.float() + torch.mm(A.float(), B.float())
129 | C_pt = C_og + torch.mm(A, B)
130 | assert (C - C_ref).abs().max() < 2 * (C_pt - C_ref).abs().max() + 1e-5
131 |
132 |
133 | @pytest.mark.parametrize("alpha_beta_type", ["float", "tensor"])
134 | @pytest.mark.parametrize("alpha", [0.5, 1.0, 2.0])
135 | @pytest.mark.parametrize("beta", [0.0, 0.5, 1.0, 1.5])
136 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
137 | @pytest.mark.parametrize("n", [512, 1024])
138 | @pytest.mark.parametrize("k", [256, 768])
139 | @pytest.mark.parametrize("m", [480, 960])
140 | def test_gemm_add_inplace_alpha_beta(m, k, n, input_dtype, alpha, beta, alpha_beta_type):
141 | """Test in-place GEMM with alpha/beta scaling: C = alpha * A @ B + beta * C."""
142 | device = "cuda"
143 | torch.random.manual_seed(42)
144 | A = torch.randn((m, k), device=device, dtype=input_dtype)
145 | B = torch.randn((k, n), device=device, dtype=input_dtype)
146 | C = torch.randn((m, n), device=device, dtype=input_dtype)
147 | if alpha_beta_type == "tensor":
148 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32)
149 | beta = torch.tensor(beta, device=device, dtype=torch.float32)
150 | C_og = C.clone()
151 | gemm_add_inplace(A, B, C, alpha=alpha, beta=beta, tuned=False)
152 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha
153 | beta_val = beta.item() if torch.is_tensor(beta) else beta
154 | C_ref = alpha_val * torch.mm(A.float(), B.float()) + beta_val * C_og.float()
155 | C_pt = alpha_val * torch.mm(A, B) + beta_val * C_og
156 | assert (C - C_ref).abs().max() < 2 * (C_pt - C_ref).abs().max() + 1e-4
157 |
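Every assertion in this file uses the same accuracy criterion: rather than fixed tolerances, the kernel's error against an fp32 reference must stay within roughly twice the error of plain PyTorch run at the same low precision. The pattern as a standalone helper (editor's sketch, hypothetical name):

    def assert_close_to_baseline(out, out_ref, out_pt, slack=1e-6):
        # out: kernel output; out_ref: fp32 reference; out_pt: PyTorch at the kernel's precision
        err = (out - out_ref).abs().max()
        err_baseline = (out_pt - out_ref).abs().max()
        assert err < 2 * err_baseline + slack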
--------------------------------------------------------------------------------
/quack/gemm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from functools import partial
3 |
4 | from torch import Tensor
5 |
6 | import cutlass.cute as cute
7 | import cutlass.torch as cutlass_torch
8 | from cutlass import Float32
9 | from cutlass.cute.runtime import from_dlpack, make_ptr
10 |
11 | from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
12 | from quack.gemm_wrapper_utils import GemmWrapperBase
13 | from quack.gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100
14 |
15 |
16 | def gemm(
17 | # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k
18 | A: Tensor,
19 | B: Tensor, # (l, n, k) or (n, total_k) if varlen_k
20 | D: Tensor, # (l, m, n) or (total_m, n) if varlen_m
21 | C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m
22 | tile_count_semaphore: Optional[Tensor], # (1,)
23 | tile_M: int,
24 | tile_N: int,
25 | cluster_M: int,
26 | cluster_N: int,
27 | pingpong: bool = False,
28 | persistent: bool = True,
29 | max_swizzle_size: int = 8,
30 | rowvec_bias: Optional[Tensor] = None, # (l, n)
31 | colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m
32 | alpha: float | Tensor = 1.0,
33 | beta: float | Tensor = 1.0,
34 | cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
35 | cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length
36 | A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen
37 | batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler
38 | add_to_output: bool = False,
39 | ) -> None:
40 | varlen = cu_seqlens_m is not None or cu_seqlens_k is not None
41 | assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
42 | "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
43 | )
44 | gather_A = A_idx is not None
45 | if gather_A:
46 | assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)"
47 | assert cluster_N == 1, "gather_A requires cluster_N=1"
48 | if varlen:
49 | assert persistent, "varlen requires persistent=True"
50 | if add_to_output:
51 | assert cu_seqlens_m is None, "Add to output not supported with varlen_m"
52 | if cu_seqlens_m is not None:
53 | assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
54 | assert D.stride(-1) == 1, "varlen_m requires D to be n-major"
55 | if cu_seqlens_k is not None:
56 | assert A.stride(-2) == 1, "varlen_k requires A to be m-major"
57 | assert B.stride(-2) == 1, "varlen_k requires B to be n-major"
58 |
59 | L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
60 | A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
61 | )
62 | GemmWrapperBase.permute_tensors(
63 | tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k is not None
64 | )
65 | GemmWrapperBase.extract_dtypes(tensor_infos)
66 | major_configs = {
67 | "A": ("m", "k", "l"),
68 | "B": ("n", "k", "l"),
69 | "D": ("m", "n", "l"),
70 | "C": ("m", "n", "l"),
71 | }
72 | GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
73 |
74 | device_capacity = get_device_capacity(A.device)
75 | assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
76 | GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90
77 |
78 | acc_dtype = Float32
79 | tile_shape_mn = (tile_M, tile_N)
80 | cluster_shape_mnk = (cluster_M, cluster_N, 1)
81 | if not GemmCls.is_valid_dtypes(
82 | tensor_infos["A"].dtype,
83 | tensor_infos["B"].dtype,
84 | acc_dtype,
85 | tensor_infos["D"].dtype,
86 | tensor_infos["A"].major,
87 | tensor_infos["B"].major,
88 | ):
89 | raise TypeError("Skipping due to unsupported combination of types and majors")
90 |
91 | max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
92 | GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
93 |
94 | def scalar_arg(scalar: float | Tensor):
95 | if isinstance(scalar, float):
96 | return Float32(scalar) if scalar != 1.0 else None
97 | else:
98 | assert isinstance(scalar, Tensor)
99 | return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4)
100 |
101 | epi_args = GemmCls.EpilogueArguments(
102 | scalar_arg(alpha),
103 | scalar_arg(beta),
104 | mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
105 | leading_dim=1
106 | )
107 | if rowvec_bias is not None
108 | else None,
109 | mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic(
110 | leading_dim=1 if cu_seqlens_m is None else 0
111 | )
112 | if colvec_bias is not None
113 | else None,
114 | add_to_output=add_to_output,
115 | )
116 | scheduler_args = GemmWrapperBase.create_scheduler_args(
117 | max_active_clusters,
118 | tile_count_semaphore,
119 | batch_idx_permute,
120 | max_swizzle_size,
121 | )
122 |
123 | # Create varlen arguments if needed (assumes persistent=True when varlen)
124 | varlen_args = GemmWrapperBase.create_varlen_args(
125 | cu_seqlens_m,
126 | cu_seqlens_k,
127 | A_idx,
128 | max_active_clusters,
129 | cluster_shape_mnk,
130 | tensor_infos,
131 | GemmCls.num_epi_tensormaps,
132 | pingpong,
133 | )
134 |
135 | current_stream = cutlass_torch.current_stream()
136 | compile_key = GemmWrapperBase.get_compile_key(
137 | tensor_infos,
138 | None, # activation
139 | tile_shape_mn,
140 | cluster_shape_mnk,
141 | pingpong,
142 | persistent,
143 | tile_count_semaphore is not None,
144 | device_capacity,
145 | # Technically we don't need to recompile for different max_swizzle_size, but currently
146 | # not recompiling will skew the autotuning results due to power throttling.
147 | # Effectively we're recompiling as a way to pause between benchmarks during autotuning.
148 | max_swizzle_size,
149 | rowvec_bias.dtype if rowvec_bias is not None else None,
150 | colvec_bias.dtype if colvec_bias is not None else None,
151 | 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0),
152 | 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0),
153 | add_to_output,
154 | cu_seqlens_m is not None,
155 | cu_seqlens_k is not None,
156 | gather_A,
157 | batch_idx_permute is not None,
158 | key_tensor_names=("A", "B", "D", "C"),
159 | )
160 | cache = gemm.compile_cache
161 | if compile_key not in cache:
162 | if device_capacity[0] == 9:
163 | GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
164 | gemm_obj = GemmCls(
165 | acc_dtype,
166 | tensor_infos["A"].dtype,
167 | tile_shape_mn,
168 | cluster_shape_mnk,
169 | gather_A=gather_A,
170 | )
171 | cache[compile_key] = cute.compile(
172 | gemm_obj,
173 | tensor_infos["A"].cute_tensor,
174 | tensor_infos["B"].cute_tensor,
175 | tensor_infos["D"].cute_tensor,
176 | tensor_infos["C"].cute_tensor,
177 | epi_args,
178 | scheduler_args,
179 | varlen_args,
180 | current_stream,
181 | )
182 | cache[compile_key](
183 | tensor_infos["A"].cute_tensor,
184 | tensor_infos["B"].cute_tensor,
185 | tensor_infos["D"].cute_tensor,
186 | tensor_infos["C"].cute_tensor,
187 | epi_args,
188 | scheduler_args,
189 | varlen_args,
190 | current_stream,
191 | )
192 |
193 |
194 | gemm.compile_cache = {}
195 |
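    | # --- Usage sketch (editorial addition, not part of the original file) ---
    | # Illustrative only: the tile and cluster sizes below are assumptions; valid
    | # choices depend on the GPU architecture and the dtypes involved.
    | #
    | #   import torch
    | #   A = torch.randn(1, 1024, 512, device="cuda", dtype=torch.bfloat16)   # (l, m, k)
    | #   B = torch.randn(1, 2048, 512, device="cuda", dtype=torch.bfloat16)   # (l, n, k)
    | #   D = torch.empty(1, 1024, 2048, device="cuda", dtype=torch.bfloat16)  # (l, m, n)
    | #   gemm(A, B, D, C=None, tile_count_semaphore=None,
    | #        tile_M=128, tile_N=128, cluster_M=2, cluster_N=1)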
--------------------------------------------------------------------------------
/quack/gemm_dact.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao.
2 | from typing import Optional, Tuple
3 | from functools import partial
4 |
5 | from torch import Tensor
6 |
7 | import cutlass
8 | import cutlass.cute as cute
9 | from cutlass import Float32, const_expr
10 | import cutlass.torch as cutlass_torch
11 |
12 | from quack.gemm_sm90 import GemmSm90
13 | from quack.gemm_sm100 import GemmSm100
14 | from quack.gemm_default_epi import GemmDefaultEpiMixin
15 | from quack.gemm_act import GemmActMixin
16 | from quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters
17 | from quack.gemm_wrapper_utils import GemmWrapperBase
18 | import quack.activation
19 |
20 |
21 | class GemmDActMixin(GemmActMixin):
22 | # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout)
23 |     # and return 2 values (dx, out)
24 | EpilogueArguments = GemmActMixin.EpilogueArguments
25 | EpilogueParams = GemmActMixin.EpilogueParams
26 |
27 | @cute.jit
28 | def epi_visit_subtile(
29 | self,
30 | params: EpilogueParams,
31 | epi_loop_tensors: Tuple[cute.Tensor, ...],
32 | tRS_rD: cute.Tensor,
33 | tRS_rC: Optional[cute.Tensor] = None,
34 | ) -> Optional[cute.Tensor]:
35 | assert tRS_rC is not None
36 | # We don't add C to the accumulator
37 | GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None)
38 | tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype)
39 | tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype))
40 | # If we don't have .shape here, the compiler generates local stores and loads
41 | if const_expr(params.act_fn is not None):
42 | tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype)
43 | if const_expr(self.arch < 100):
44 | for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True):
45 | tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i])
46 | else:
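    |                 # On SM100+ the activation bwd fn takes pairs of values, presumably so it
    |                 # can use packed f32x2 instructions; process two elements per iteration.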
47 | for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True):
48 | (
49 | (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
50 | (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]),
51 | ) = params.act_fn(
52 | (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]),
53 | (tRS_rD[2 * i], tRS_rD[2 * i + 1]),
54 | )
55 | else:
56 | tRS_rPostAct = tRS_rC_acc
57 | # Type conversion
58 | tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype)
59 | tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype))
60 | return tRS_rPostAct_out
61 |
62 |
63 | class GemmDActSm90(GemmDActMixin, GemmSm90):
64 | pass
65 |
66 |
67 | class GemmDActSm100(GemmDActMixin, GemmSm100):
68 | pass
69 |
70 |
71 | dact_fn_map = {
72 | None: None,
73 | "relu": quack.activation.drelu,
74 | "relu_sq": quack.activation.drelu_sq,
75 | "gelu_tanh_approx": quack.activation.dgelu_tanh_approx,
76 | }
77 |
78 |
79 | def gemm_dact(
80 | A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m
81 | B: Tensor, # (l, n, k)
82 | Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m
83 | PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
84 | PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m
85 | tile_count_semaphore: Optional[Tensor], # (1,)
86 | activation: Optional[str],
87 | tile_M: int,
88 | tile_N: int,
89 | cluster_M: int,
90 | cluster_N: int,
91 | pingpong: bool = True,
92 | persistent: bool = True,
93 | max_swizzle_size: int = 8,
94 | cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length
95 | A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m
96 | ) -> None:
97 | if cu_seqlens_m is not None:
98 | assert persistent, "varlen_m requires persistent=True"
99 | assert A.stride(-1) == 1, "varlen_m requires A to be k-major"
100 | assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major"
101 | assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major"
102 | assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major"
103 | gather_A = A_idx is not None
104 | if gather_A:
105 | assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)"
106 | assert cluster_N == 1, "gather_A requires cluster_N=1"
107 | assert activation in dact_fn_map, f"Unsupported activation {activation}"
108 |
109 | L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors(
110 | A,
111 | B,
112 | Out,
113 | PreAct,
114 | additional_tensors={"PostAct": PostAct},
115 | cu_seqlens_m=cu_seqlens_m,
116 | A_idx=A_idx,
117 | )
118 | GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None)
119 | GemmWrapperBase.extract_dtypes(tensor_infos)
120 | major_configs = {
121 | "A": ("m", "k", "l"),
122 | "B": ("n", "k", "l"),
123 | "D": ("m", "n", "l"),
124 | "C": ("m", "n", "l"),
125 | "PostAct": ("m", "n", "l"),
126 | }
127 | GemmWrapperBase.determine_major_orders(tensor_infos, major_configs)
128 |
129 | device_capacity = get_device_capacity(A.device)
130 | assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported"
131 | GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90
132 |
133 | acc_dtype = Float32
134 | tile_shape_mn = (tile_M, tile_N)
135 | cluster_shape_mnk = (cluster_M, cluster_N, 1)
136 | if not GemmCls.is_valid_dtypes(
137 | tensor_infos["A"].dtype,
138 | tensor_infos["B"].dtype,
139 | acc_dtype,
140 | tensor_infos["D"].dtype,
141 | tensor_infos["A"].major,
142 | tensor_infos["B"].major,
143 | ):
144 | raise TypeError("Skipping due to unsupported combination of types and majors")
145 |
146 | max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0
147 | GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs)
148 | act_fn = dact_fn_map[activation]
149 | epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn)
150 | scheduler_args = GemmWrapperBase.create_scheduler_args(
151 | max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size
152 | )
153 |
154 | # Create varlen arguments if needed (assumes persistent=True when varlen_m)
155 | varlen_args = GemmWrapperBase.create_varlen_args(
156 | cu_seqlens_m,
157 | None, # cu_seqlens_k
158 | A_idx,
159 | max_active_clusters,
160 | cluster_shape_mnk,
161 | tensor_infos,
162 | GemmCls.num_epi_tensormaps,
163 | pingpong,
164 | )
165 |
166 | current_stream = cutlass_torch.current_stream()
167 | compile_key = GemmWrapperBase.get_compile_key(
168 | tensor_infos,
169 | activation,
170 | tile_shape_mn,
171 | cluster_shape_mnk,
172 | pingpong,
173 | persistent,
174 | tile_count_semaphore is not None,
175 | device_capacity,
176 | max_swizzle_size,
177 | cu_seqlens_m is not None,
178 | A_idx is not None,
179 | key_tensor_names=("A", "B", "D", "PostAct", "C"),
180 | )
181 | cache = gemm_dact.compile_cache
182 | if compile_key not in cache:
183 | if device_capacity[0] == 9:
184 | GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent)
185 | gemm = GemmCls(
186 | acc_dtype,
187 | tensor_infos["A"].dtype,
188 | tile_shape_mn,
189 | cluster_shape_mnk,
190 | gather_A=gather_A,
191 | )
192 | cache[compile_key] = cute.compile(
193 | gemm,
194 | tensor_infos["A"].cute_tensor,
195 | tensor_infos["B"].cute_tensor,
196 | tensor_infos["D"].cute_tensor,
197 | tensor_infos["C"].cute_tensor,
198 | epi_args,
199 | scheduler_args,
200 | varlen_args,
201 | current_stream,
202 | )
203 | cache[compile_key](
204 | tensor_infos["A"].cute_tensor,
205 | tensor_infos["B"].cute_tensor,
206 | tensor_infos["D"].cute_tensor,
207 | tensor_infos["C"].cute_tensor,
208 | epi_args,
209 | scheduler_args,
210 | varlen_args,
211 | current_stream,
212 | )
213 |
214 |
215 | gemm_dact.compile_cache = {}
216 |
--------------------------------------------------------------------------------
/quack/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2 |
3 | import math
4 | from functools import partial
5 | from typing import Optional, Tuple, Union
6 |
7 | import cutlass
8 | import cutlass.cute as cute
9 |
10 | from cutlass import Float32, Int32, Boolean, const_expr
11 | from cutlass.cutlass_dsl import T, dsl_user_op
12 | from cutlass._mlir.dialects import llvm, nvvm, vector
13 |
14 |
15 | # cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default
16 | fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
17 | mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
18 | add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN)
19 | sub_packed_f32x2 = partial(
20 | cute.arch.calc_packed_f32x2_op,
21 | src_c=None,
22 | calc_func=nvvm.sub_packed_f32x2,
23 | rnd=nvvm.RoundingModeKind.RN,
24 | )
25 |
26 |
27 | @dsl_user_op
28 | def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
29 | return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
30 |
31 |
32 | @cute.jit
33 | def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
34 | if const_expr(isinstance(x, cute.Pointer)):
35 | return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
36 | else:
37 | assert isinstance(x, Float32)
38 | return x
39 |
40 |
41 | @dsl_user_op
42 | def set_block_rank(
43 | smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None
44 | ) -> Int32:
45 | """Map the given smem pointer to the address at another CTA rank in the cluster."""
46 | smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
47 | return Int32(
48 | llvm.inline_asm(
49 | T.i32(),
50 | [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
51 | "mapa.shared::cluster.u32 $0, $1, $2;",
52 | "=r,r,r",
53 | has_side_effects=False,
54 | is_align_stack=False,
55 | asm_dialect=llvm.AsmDialect.AD_ATT,
56 | )
57 | )
58 |
59 |
60 | @dsl_user_op
61 | def store_shared_remote(
62 | val: float | Float32 | Int32 | cutlass.Int64,
63 | smem_ptr: cute.Pointer,
64 | mbar_ptr: cute.Pointer,
65 | peer_cta_rank_in_cluster: cute.typing.Int,
66 | *,
67 | loc=None,
68 | ip=None,
69 | ) -> None:
70 | remote_smem_ptr_i32 = set_block_rank(
71 | smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
72 | ).ir_value()
73 | remote_mbar_ptr_i32 = set_block_rank(
74 | mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
75 | ).ir_value()
76 | if const_expr(isinstance(val, float)):
77 | val = Float32(val)
78 | assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
79 | suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
80 | constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
81 | llvm.inline_asm(
82 | None,
83 | [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
84 | f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
85 | f"r,{constraint},r",
86 | has_side_effects=True,
87 | is_align_stack=False,
88 | asm_dialect=llvm.AsmDialect.AD_ATT,
89 | )
90 |
91 |
92 | @dsl_user_op
93 | def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
94 | return Float32(
95 | nvvm.fmin(
96 | T.f32(),
97 | Float32(a).ir_value(loc=loc, ip=ip),
98 | Float32(b).ir_value(loc=loc, ip=ip),
99 | loc=loc,
100 | ip=ip,
101 | )
102 | )
103 |
104 |
105 | @dsl_user_op
106 | def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
107 | return Float32(
108 | llvm.inline_asm(
109 | T.f32(),
110 | [Float32(a).ir_value(loc=loc, ip=ip)],
111 | "sqrt.approx.f32 $0, $1;",
112 | "=f,f",
113 | has_side_effects=False,
114 | is_align_stack=False,
115 | asm_dialect=llvm.AsmDialect.AD_ATT,
116 | )
117 | )
118 |
119 |
120 | @dsl_user_op
121 | def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
122 | return Int32(
123 | llvm.inline_asm(
124 | T.i32(),
125 | [Float32(a).ir_value(loc=loc, ip=ip)],
126 | "cvt.rpi.ftz.s32.f32 $0, $1;",
127 | "=r,f",
128 | has_side_effects=False,
129 | is_align_stack=False,
130 | asm_dialect=llvm.AsmDialect.AD_ATT,
131 | )
132 | )
133 |
134 |
135 | @dsl_user_op
136 | def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
137 | return Int32(
138 | llvm.inline_asm(
139 | T.i32(),
140 | [
141 | Int32(a).ir_value(loc=loc, ip=ip),
142 | Int32(b).ir_value(loc=loc, ip=ip),
143 | Int32(c).ir_value(loc=loc, ip=ip),
144 | ],
145 | "prmt.b32 $0, $1, $2, $3;",
146 | "=r,r,r,r",
147 | has_side_effects=False,
148 | is_align_stack=False,
149 | asm_dialect=llvm.AsmDialect.AD_ATT,
150 | )
151 | )
152 |
153 |
154 | @cute.jit
155 | def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
156 | # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
157 | tApA = cute.make_fragment(
158 | cute.make_layout(
159 | (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
160 | stride=(cute.size(tAcA, mode=[2]), 0, 1),
161 | ),
162 | Boolean,
163 | )
164 | for rest_v in cutlass.range_constexpr(tApA.shape[0]):
165 | for rest_k in cutlass.range_constexpr(tApA.shape[2]):
166 | tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
167 | return tApA
168 |
169 |
170 | @cute.jit
171 | def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None:
172 | """Fill out-of-bounds values in shared memory tensor.
173 |
174 | Args:
175 | tXsX: Shared memory tensor to fill
176 | tXpX: Predicate tensor indicating valid elements
177 | fill_value: Value to fill OOB locations with
178 | """
179 | tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
180 | tXrX_fill.fill(fill_value)
181 | for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
182 | for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
183 | if const_expr(tXpX is not None):
184 | if not tXpX[rest_v, 0, rest_k]:
185 | cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
186 | else:
187 | cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
188 |
189 |
190 | @dsl_user_op
191 | def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
192 | vec_f32x2 = vector.from_elements(
193 | T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
194 | )
195 | vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
196 | res = cutlass.Int64(
197 | vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
198 | )
199 | return res
200 |
201 |
202 | @dsl_user_op
203 | def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
204 | vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
205 | vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
206 | res0 = Float32(
207 | vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
208 | )
209 | res1 = Float32(
210 | vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
211 | )
212 | return res0, res1
213 |
214 |
215 | @cute.jit
216 | def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32:
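    |     """Inclusive prefix sum of `val` across a warp (Hillis-Steele scan via shuffle-up).
    | 
    |     After the loop, each lane holds the sum of `val` from lanes 0..lane. `lane` may be
    |     passed in if the caller already has it; otherwise it is read from the hardware.
    |     """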
217 | if const_expr(lane is None):
218 | lane = cute.arch.lane_idx()
219 | for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
220 | offset = 1 << i
221 | # Very important that we set mask_and_clamp to 0
222 | partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
223 | if lane >= offset:
224 | val += partial_sum
225 | return val
226 |
227 |
228 | @dsl_user_op
229 | def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
230 | return nvvm.atomicrmw(
231 | res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
232 | )
233 |
234 |
235 | @dsl_user_op
236 | def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
237 | return nvvm.atomicrmw(
238 | res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
239 | )
240 |
--------------------------------------------------------------------------------
/docs/limitations.rst:
--------------------------------------------------------------------------------
1 | .. _limitations:
2 |
3 | Limitations
4 | ====================
5 |
6 |
7 | Overview
8 | ---------------------
9 | CuTe DSL is an embedded domain-specific language within Python. It utilizes a subset of Python's
10 | syntax to provide a streamlined programming experience. It is important to understand that CuTe DSL
11 | does NOT implement the complete Python language semantics in its JIT compilation process.
12 |
13 | Programming Model
14 | ---------------------
15 |
16 | **Python Native Data Types**
17 | CuTe DSL supports Python data structures when used for "meta-programming,"
18 | but these structures cannot be treated as dynamic values modifiable at runtime.
19 | For instance, lists and dictionaries can be used to configure kernel parameters
20 | during compilation or serve as containers for dynamic values,
21 | but their structure and organization cannot be altered during kernel execution.
22 |
23 | - **Static Values:**
24 | - Evaluated during JIT compilation phase
25 | - Immutable after compilation completes
26 | - Most Python native types (lists, tuples, dictionaries) are processed as static values
27 | - Primarily utilized for "meta-programming" and configuration purposes
28 | - Example: Lists can contain dynamic values but their structure cannot
29 | be modified during kernel execution
30 |
31 | - **Dynamic Values:**
32 | - Evaluated during runtime execution
33 | - Modifiable during execution of JIT-compiled functions
34 | - Only a specific subset of Python types are supported as dynamic values
35 | - Primitive types are automatically converted when passed as function arguments:
36 |
37 | - ``int`` → ``Int32`` (may be updated to ``Int64`` in future releases)
38 | - ``bool`` → ``Bool``
39 | - ``float`` → ``Float32`` (may be updated to ``Float64`` in future releases)
40 |
41 | The JIT compiler processes Python native types analogously to C++ template parameters.
42 | The compiled code cannot manipulate dynamic values of composite types
43 | such as lists, tuples, or dictionaries.
44 |
45 | For example, the following code doesn't work as a traditional Python program inside a JIT function.
46 |
47 | .. code:: python
48 |
49 | @cute.jit
50 | def foo(a: Float32, b: Float32, i: Int32, res: cute.Tensor):
51 | xs = [a, b]
52 | # indexing list with dynamic index is not supported in CuTe DSL:
53 | res[0] = xs[i]
54 |
55 | if i == 0:
56 |             # This will always append Float32(3.0) to the list regardless
57 | # of the runtime value of `i`
58 | xs.append(Float32(3.0))
59 |
60 | for i in range(10):
61 |             # This only appends one element to the list at compile-time
62 |             # because the loop is not unrolled at compile-time
63 | xs.append(Float32(1.0))
64 |
65 | **Python Function**
66 | The DSL currently does not implement support for return values from Python functions,
67 | although this capability is planned for future releases.
68 |
69 | Example:
70 |
71 | .. code:: python
72 |
73 | @cute.jit
74 | def foo():
75 | return 1 # Currently unsupported in CuTe DSL
76 |
77 | **Expression or Statement with Dependent Type**
78 | CuTe DSL implements static typing and does not support dependent types.
79 | The type of each expression must be determinable during compile time,
80 | in contrast to standard Python which implements dynamic typing.
81 |
82 | Example illustrating functionality in Python that is not supported in the DSL:
83 |
84 | .. code:: python
85 |
86 | # Valid in standard Python, but unsupported in CuTe DSL
87 | max(int(1), float(2.0)) # => 2.0 : float
88 | max(int(3), float(2.0)) # => 3 : int
89 |
90 | In CuTe DSL, types are promoted. For example:
91 |
92 | .. code:: python
93 |
94 | @cute.jit
95 | def foo(a: Int32, b: Float32, res: cute.Tensor):
96 | res[0] = max(a, b) # Type is automatically promoted to Float32
97 |
98 | The following code, which uses an inline if-else expression with dependent types,
99 | is not supported in CuTe DSL:
100 |
101 | .. code:: python
102 |
103 | @cute.jit
104 | def foo(cond: Boolean, a: Int32, b: Float32, res: cute.Tensor):
105 | res[0] = a if cond else b
106 |
107 |
108 | **Control Flow**
109 | The DSL transforms Python control flow statements (``if``, ``for``, ``while``)
110 | during Abstract Syntax Tree (AST) processing into structured control flow in MLIR
111 | which has the same constraints as dependent types. For instance,
112 | changing the type of a variable in the loop body is not allowed.
113 |
114 | - Variables must be defined prior to the control flow statement
115 | - Type consistency must be maintained throughout the control flow statement
116 | - Early exit or return from if-else statements is not supported
117 |
118 | Example illustrating functionality in Python that is not supported in the DSL:
119 |
120 | .. code:: python
121 |
122 | @cute.jit
123 | def foo():
124 | a = Int32(1)
125 | for i in range(10):
126 | a = Float32(2) # Changing type inside loop-body is not allowed in the DSL
127 |
128 |
129 | **Built-in Operators**
130 | The DSL transforms built-in operators like ``and``, ``or``, ``max``, ``min``, etc.
131 | into MLIR operations. They also follow the same constraints as dependent types.
132 | For instance, ``a and b`` requires ``a`` and ``b`` to be of the same type.
133 |
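    | A minimal sketch of the constraint (assuming comparisons yield ``Boolean`` operands,
    | so both sides of ``and`` share a type):
    | 
    | .. code:: python
    | 
    |     @cute.jit
    |     def foo(a: Int32, b: Float32, res: cute.Tensor):
    |         # res[0] = a and b  # not allowed: Int32 and Float32 are different types
    |         if (a > 0) and (b > 0.0):  # both operands are Boolean
    |             res[0] = 1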
134 |
135 | **Special Variables**
136 | The DSL treats ``_`` as a special variable whose value is meant to be ignored.
137 | Reading ``_`` is not allowed in the DSL.
138 |
139 | Example illustrating functionality in Python that is not supported in the DSL:
140 |
141 | .. code:: python
142 |
143 | @cute.jit
144 | def foo():
145 | _ = 1
146 | print(_) # This is not allowed in the DSL
147 |
148 |
149 | **Object Oriented Programming**
150 | The DSL is implemented on top of Python and supports Python's object-oriented programming (OOP) features
151 | for meta-programming at compile-time.
152 |
153 | However, similar to other composed data types, the DSL provides limited support for OOP when objects
154 | contain dynamic values. It is strongly recommended to avoid passing dynamic values between member methods
155 | through class state in your code.
156 |
157 | The following example illustrates functionality in Python that is not supported in the DSL
158 | without implementing the ``DynamicExpression`` protocol:
159 |
160 | .. code:: python
161 |
162 | class Foo:
163 | def __init__(self, a: Int32):
164 | self.a = a
165 |
166 | def set_a(self, i: Int32):
167 | self.a = i
168 |
169 | def get_a(self):
170 | return self.a
171 |
172 | @cute.jit
173 | def foo(a: Int32, res: cute.Tensor):
174 | foo = Foo(a)
175 | for i in range(10):
176 | foo.set_a(i)
177 |
178 | # This fails to compile because `a` is assigned a local value defined within the for-loop body
179 | # and is not visible outside of the loop body
180 | res[0] = foo.get_a()
181 |
182 | The example above fails to compile because ``Foo.a`` is assigned a local value defined within the for-loop body,
183 | which is not visible outside the loop body.
184 |
185 | The CuTe DSL implements an internal mechanism that provides limited support for OOP patterns via a protocol.
186 | As the DSL continues to evolve to support additional features, this mechanism is subject to change,
187 | so for better portability it is not recommended for direct use in user code.
188 |
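    | One pattern that sidesteps the issue is to keep the dynamic value in a local variable
    | defined before the loop instead of in class state (a sketch, not a prescribed API):
    | 
    | .. code:: python
    | 
    |     @cute.jit
    |     def foo(a: Int32, res: cute.Tensor):
    |         acc = a              # dynamic value lives in a local defined before the loop
    |         for i in range(10):
    |             acc = acc + i    # type stays Int32 throughout the loop body
    |         res[0] = acc         # visible after the loop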
189 |
190 | **CuTe Layout algebra in native Python**
191 | The entirety of CuTe layout algebra operations and APIs requires JIT compilation. These
192 | functionalities are exclusively available within JIT-compiled functions and cannot be
193 | accessed in standard Python execution environments.
194 | 
195 | Additionally, there is a restricted set of data types that can be passed as arguments
196 | to JIT-compiled functions, which further constrains their usage in native Python contexts.
197 | Only the following CuTe algebra types are supported as JIT function arguments: ``Tensor``, ``Pointer``,
198 | ``Shape``, ``Stride``, ``Coord`` and ``IntTuple``. For ``Stride``, ``ScaledBasis`` is not
199 | supported from a native Python context. Unfortunately, in the first release, passing
200 | ``Layout`` from a native Python context is not supported.
201 |
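    | For instance (a sketch under the constraints above), a layout can be built inside the
    | JIT-compiled function from a ``Shape`` argument rather than passed in from native Python:
    | 
    | .. code:: python
    | 
    |     @cute.jit
    |     def layout_size(shape: cute.Shape, res: cute.Tensor):
    |         layout = cute.make_layout(shape)  # layout algebra is available inside JIT code
    |         res[0] = cute.size(layout)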
202 |
203 | Suggestions
204 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
205 |
206 | For reliable and predictable results:
207 |
208 | - Avoid dependent types in your code
209 | - Implement explicit type conversion for dynamic values
210 | - Clearly distinguish between static (compile-time) and dynamic (runtime) values
211 | - Use type annotations as much as possible to help JIT compiler
212 | to identify type to avoid ambiguity
213 |
214 |
215 | .. code:: python
216 |
217 | # Example demonstrating explicit typing
218 | alpha = 1.0 # Explicitly defined as float using `1.0` instead of `1`
219 | # or `float(1)`
220 | beta = 2.0 # Explicitly defined as float
221 | result = max(alpha, beta) # Will correctly perform float comparison
222 |
--------------------------------------------------------------------------------
/quack/sort/sorting_networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao.
2 | """
3 | Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html
4 |
5 | This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly.
6 | """
7 |
8 | # fmt: off
9 | # ruff: noqa
10 | # isort: skip_file
11 |
12 | import cutlass
13 | import cutlass.cute as cute
14 |
15 | from quack.sort.utils import compare_and_swap
16 |
17 |
18 | networks = {
19 | # Size 2: 1 CEs, depth 1
20 | 2: [[(0, 1)]],
21 |
22 | # Size 4: 5 CEs, depth 3
23 | 4: [
24 | [(0, 2), (1, 3)],
25 | [(0, 1), (2, 3)],
26 | [(1, 2)]
27 | ],
28 |
29 | # Size 8: 19 CEs, depth 6
30 | 8: [
31 | [(0, 2), (1, 3), (4, 6), (5, 7)],
32 | [(0, 4), (1, 5), (2, 6), (3, 7)],
33 | [(0, 1), (2, 3), (4, 5), (6, 7)],
34 | [(2, 4), (3, 5)],
35 | [(1, 4), (3, 6)],
36 | [(1, 2), (3, 4), (5, 6)]
37 | ],
38 |
39 | # Size 16: 60 CEs, depth 10
40 | 16: [
41 | [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)],
42 | [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)],
43 | [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)],
44 | [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)],
45 | [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)],
46 | [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)],
47 | [(2, 4), (3, 6), (9, 12), (11, 13)],
48 | [(3, 5), (6, 8), (7, 9), (10, 12)],
49 | [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)],
50 | [(6, 7), (8, 9)]
51 | ],
52 |
53 | # Size 32: 185 CEs, depth 14
54 | 32: [
55 | [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)],
56 | [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)],
57 | [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)],
58 | [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)],
59 | [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)],
60 | [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)],
61 | [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)],
62 | [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)],
63 | [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)],
64 | [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)],
65 | [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)],
66 | [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)],
67 | [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)],
68 | [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)]
69 | ],
70 |
71 | # Size 64: 521 CEs, depth 21
72 | 64: [
73 | [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)],
74 | [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)],
75 | [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)],
76 | [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)],
77 | [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)],
78 | [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)],
79 | [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)],
80 | [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)],
81 | [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)],
82 | [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)],
83 | [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)],
84 | [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)],
85 | [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)],
86 | [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)],
87 | [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)],
88 | [(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)],
89 | [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), (29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)],
90 | [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)],
91 | [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)],
92 | [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)],
93 | [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)]
94 | ],
95 |
96 | }
97 |
98 |
99 | @cute.jit
100 | def optimal_sort(
101 | arr: cute.Tensor,
102 | n: cutlass.Constexpr[int],
103 | start: cutlass.Constexpr[int] = 0,
104 | ascending: cutlass.Constexpr[bool] = True
105 | ) -> None:
106 | """
107 | Optimal sorting network dispatcher.
108 |
109 | Args:
110 | arr: Array to sort
111 | n: Size of array (must be power of 2 and available in networks)
112 | start: Starting index (default 0)
113 | ascending: Sort in ascending order (default True)
114 |
115 | Source: https://bertdobbelaere.github.io/sorting_networks.html
116 | """
117 | assert n in networks
118 | for level in networks[n]:
119 | for i, j in level:
120 | compare_and_swap(arr, start + i, start + j, ascending)
121 |
--------------------------------------------------------------------------------
/quack/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 | from torch.amp import custom_fwd, custom_bwd
9 |
10 |
11 | from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
12 |
13 |
14 | def linear_fwd_convert_type(*tensors):
15 | autocast_dtype = torch.get_autocast_dtype("cuda")
16 | if torch.is_autocast_enabled():
17 | tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
18 | return tensors
19 |
20 |
21 | def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
22 | needs_input_grad, needs_weight_grad = needs_x_w_grad
23 | if not needs_input_grad:
24 | weight, weight_og = None, None
25 | if not needs_weight_grad:
26 | x = None
27 | ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
28 |
29 |
30 | def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
31 | if ctx.needs_input_grad[0]:
32 | assert weight is not None
33 | return matmul_fn(dout, weight)
34 | else:
35 | return None
36 |
37 |
38 | def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
39 | if ctx.needs_input_grad[1]:
40 | assert x is not None
41 | x = x.reshape(-1, x.shape[-1])
42 | # fuse_grad_accum is not compatible with torch.compile
43 | if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
44 | dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
45 | else:
46 | # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
47 | matmul_inplace_fn(dout.T, x, weight_og.grad)
48 | dweight = weight_og.grad
49 | weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again
50 | else:
51 | dweight = None
52 | return dweight
53 |
54 |
55 | class LinearFunc(torch.autograd.Function):
56 | matmul_fwd_fn = gemm
57 | matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
58 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
59 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
60 |
61 | # Use classmethod instead of staticmethod to allow inheritance
62 | @classmethod
63 | @custom_fwd(device_type="cuda")
64 | def forward(cls, ctx, x, weight, bias=None, fuse_grad_accum=False):
65 | """
66 | x: (..., in_features)
67 | weight: (out_features, in_features)
68 | bias: (out_features,) or None
69 | out: (..., out_features)
70 | """
71 | ctx.weight_dtype = weight.dtype
72 | ctx.fuse_grad_accum = fuse_grad_accum
73 | weight_og = weight
74 | x, weight = linear_fwd_convert_type(x, weight)
75 | batch_shape = x.shape[:-1]
76 | x = x.reshape(-1, x.shape[-1])
77 | # out = F.linear(x, weight)
78 | out = cls.matmul_fwd_fn(x, weight.T, bias=bias)
79 | linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
80 | ctx.bias_dtype = bias.dtype if bias is not None else None
81 | return out.reshape(*batch_shape, out.shape[-1])
82 |
83 | @classmethod
84 | @custom_bwd(device_type="cuda")
85 | def backward(cls, ctx, dout, *args):
86 | """
87 | dout: (..., out_features)
88 | """
89 | x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum
90 | batch_shape = dout.shape[:-1]
91 | dout = dout.reshape(-1, dout.shape[-1])
92 | dbias = (
93 | dout.sum(0, dtype=ctx.bias_dtype)
94 | if ctx.bias_dtype is not None and ctx.needs_input_grad[2]
95 | else None
96 | )
97 | dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
98 | dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
99 | dweight = linear_bwd_compute_weight_grad(
100 | ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
101 | )
102 | # return extra Nones for other classes that inherit from LinearFunc
103 | return dx, dweight, dbias, *([None] * 10)
104 |
105 |
106 | class LinearUntunedFunc(LinearFunc):
107 | # Passing in tuned=False to disable tuning at runtime
108 | matmul_fwd_fn = partial(gemm, tuned=False)
109 | matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
110 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
111 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
112 |
113 |
114 | def linear_func(x, weight, bias=None, fuse_grad_accum=False, tuned=True):
115 | fn_cls = LinearFunc if tuned else LinearUntunedFunc
116 | return fn_cls.apply(x, weight, bias, fuse_grad_accum)
117 |
118 |
119 | class LinearActFunc(LinearFunc):
120 | matmul_fwd_fn = gemm_act
121 |
122 | # Use classmethod instead of staticmethod to allow inheritance
123 | @classmethod
124 | @custom_fwd(device_type="cuda")
125 | def forward(
126 | cls, ctx, x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False
127 | ):
128 | """
129 | x: (..., in_features)
130 | weight: (out_features, in_features)
131 | bias: (out_features,) or None
132 | out: (..., out_features)
133 | Return both out and post-activation, but only out is differentiable.
134 | """
135 | ctx.weight_dtype = weight.dtype
136 | ctx.fuse_grad_accum = fuse_grad_accum
137 | weight_og = weight
138 | x, weight = linear_fwd_convert_type(x, weight)
139 | batch_shape = x.shape[:-1]
140 | x = x.reshape(-1, x.shape[-1])
141 | out, postact = cls.matmul_fwd_fn(
142 | x, weight.T, bias=bias, activation=activation, store_preact=store_preact
143 | )
144 | linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
145 | if out is not None:
146 | out = out.reshape(*batch_shape, out.shape[-1])
147 | ctx.bias_dtype = bias.dtype if bias is not None else None
148 | ctx.mark_non_differentiable(postact)
149 | ctx.set_materialize_grads(False) # We don't want to materialize grads for postact
150 | return out, postact.reshape(*batch_shape, postact.shape[-1])
151 |
152 |
153 | class LinearActUntunedFunc(LinearActFunc):
154 | # Passing in tuned=False to disable tuning at runtime
155 | matmul_fwd_fn = partial(gemm_act, tuned=False)
156 | matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
157 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
158 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
159 |
160 |
161 | def linear_act_func(
162 | x, weight, activation, bias=None, store_preact=True, fuse_grad_accum=False, tuned=True
163 | ):
164 | fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
165 | return fn_cls.apply(x, weight, activation, bias, store_preact, fuse_grad_accum)
166 |
167 |
168 | class DActLinearFunc(LinearFunc):
169 | matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)
170 |
171 | # Use classmethod instead of staticmethod to allow inheritance
172 | @classmethod
173 | @custom_fwd(device_type="cuda")
174 | def forward(cls, ctx, preact, weight, x, activation, fuse_grad_accum=False):
175 | """
176 | x: (..., in_features)
177 | weight: (out_features, in_features)
178 | out: (..., out_features)
179 | Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
180 | """
181 | ctx.weight_dtype = weight.dtype
182 | ctx.fuse_grad_accum = fuse_grad_accum
183 | weight_og = weight
184 | x, weight = linear_fwd_convert_type(x, weight)
185 | batch_shape = x.shape[:-1]
186 | x = x.reshape(-1, x.shape[-1])
187 | out = cls.matmul_fwd_fn(x, weight.T)
188 | # Store preact instead of x, we will recompute x in the backward pass
189 | linear_fwd_postprocess(
190 | ctx, preact, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
191 | )
192 | ctx.activation = activation
193 | return out.reshape(*batch_shape, out.shape[-1])
194 |
195 | @classmethod
196 | @custom_bwd(device_type="cuda")
197 | def backward(cls, ctx, dout):
198 | """
199 | dout: (..., out_features)
200 | """
201 | # weight_og is None if not ctx.fuse_grad_accum
202 | preact, weight, weight_og = ctx.saved_tensors
203 | batch_shape = dout.shape[:-1]
204 | dout = dout.reshape(-1, dout.shape[-1])
205 | preact = preact.reshape(-1, preact.shape[-1])
206 | if ctx.needs_input_grad[0]:
207 | assert weight is not None
208 | dpreact, x = cls.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
209 | else:
210 | dpreact, x = None, None
211 | dpreact = dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None
212 | dweight = linear_bwd_compute_weight_grad(
213 | ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
214 | )
215 | return dpreact, dweight, *([None] * 3)
216 |
217 |
218 | class DActLinearUntunedFunc(DActLinearFunc):
219 | # Passing in tuned=False to disable tuning at runtime
220 | matmul_fwd_fn = partial(gemm, tuned=False)
221 | matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
222 | matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
223 | matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
224 |
225 |
226 | def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
227 | fn_cls = DActLinearFunc if tuned else DActLinearUntunedFunc
228 | return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)
229 |
230 |
231 | class Linear(nn.Linear):
232 | def __init__(
233 | self,
234 | in_features: int,
235 | out_features: int,
236 | bias: bool = False,
237 | device=None,
238 | dtype=None,
239 | fuse_grad_accum: bool = False,
240 | ) -> None:
241 | super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
242 | self.fuse_grad_accum = fuse_grad_accum
243 |
244 | def forward(self, input: Tensor) -> Tensor:
245 | if input.is_cuda and self.in_features % 8 == 0 and self.out_features % 8 == 0:
246 | return linear_func(input, self.weight, self.bias, fuse_grad_accum=self.fuse_grad_accum)
247 | else:
248 | return F.linear(input, self.weight, self.bias)
249 |
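    | # --- Usage sketch (editorial addition, not part of the original file) ---
    | # Drop-in replacement for nn.Linear; shapes and dtype below are illustrative only.
    | #
    | #   import torch
    | #   layer = Linear(1024, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
    | #   x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.bfloat16)
    | #   y = layer(x)  # (8, 512, 4096); takes the fused GEMM path when both feature
    | #                 # dims are multiples of 8, otherwise falls back to F.linear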
--------------------------------------------------------------------------------
/benchmarks/benchmark_rmsnorm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from typing import Type, Optional
4 |
5 | import torch
6 | from triton.testing import do_bench
7 |
8 | import cutlass
9 | import cutlass.torch as cutlass_torch
10 | from cutlass.cute.runtime import from_dlpack
11 | from quack.rmsnorm import rmsnorm_fwd, rmsnorm_ref, rmsnorm, rmsnorm_bwd
12 | import cutlass.cute as cute
13 |
14 | try:
15 | import cudnn
16 | except ImportError:
17 | cudnn = None
18 |
19 |
20 | def run_rmsnorm(
21 | M,
22 | N,
23 | dtype: torch.dtype,
24 | residual_dtype: Optional[torch.dtype] = None,
25 | warmup_iterations=5,
26 | iterations=100,
27 | ):
28 | if not torch.cuda.is_available():
29 |         raise RuntimeError("Ampere GPU is required to run this example!")
30 |
31 | print(f"Tensor dimensions: [{M}, {N}]")
32 | print(f"Input and Output Data type: {dtype}")
33 |
34 | device = "cuda"
35 | x = torch.randn(M, N, device=device, dtype=dtype)
36 | if residual_dtype is not None:
37 | residual = torch.randn(M, N, device=device, dtype=residual_dtype)
38 | else:
39 | residual = None
40 | w = torch.randn(N, device=device, dtype=torch.float32)
41 |
42 |     print("Input tensor shapes:")
43 | print(f"x: {x.shape}, dtype: {x.dtype}")
44 | print(f"w: {w.shape}, dtype: {w.dtype}")
45 |
46 | eps = 1e-6
47 |
48 | print("Executing kernel...")
49 | rmsnorm_fwd(x, w, residual=residual, eps=eps, store_rstd=True)
50 |
51 | compiled_func_ref = torch.compile(rmsnorm_ref)
52 |
53 | fn = lambda: rmsnorm_fwd(x, w, residual=residual, eps=eps)
54 | time.sleep(0.5)
55 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
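    |     # Forward memory traffic: read x and write y (2 * x bytes) plus read w in fp32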
56 | mem_bytes = (2 * x.numel() * dtype.itemsize + w.numel() * 4)
57 | if residual is not None:
58 | mem_bytes += 2 * residual.numel() * residual.dtype.itemsize
59 | mem_bw = round(mem_bytes / (avg_time / 1000) / 1e9)
60 | print(f"Kernel execution time: {avg_time:.4f} ms")
61 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
62 |
63 | fn = lambda: compiled_func_ref(x, w, residual=residual, eps=eps)
64 | for _ in range(5): fn() # warm up
65 | time.sleep(0.5)
66 | avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
67 | mem_bytes_ref = mem_bytes
68 | mem_bw_ref = round(mem_bytes_ref / (avg_time / 1000) / 1e9)
69 | print(f"Ref kernel execution time: {avg_time:.4f} ms")
70 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
71 |
72 | if cudnn is not None:
73 | run_cudnn = rmsnorm_cudnn_setup(M, N, dtype)
74 | time.sleep(0.5)
75 | avg_time = do_bench(run_cudnn, warmup=warmup_iterations, rep=iterations)
76 | mem_bytes_cudnn = (2 * x.numel() * dtype.itemsize + w.numel() * 4)
77 | mem_bw_cudnn = round(mem_bytes_cudnn / (avg_time / 1000) / 1e9)
78 | print(f"Cudnn kernel execution time: {avg_time:.4f} ms")
79 | print(f"Cudnn mem throughput: {mem_bw_cudnn:.2f} GB/s")
80 |
81 | return mem_bw, mem_bw_ref
82 |
83 |
84 | def rmsnorm_cudnn_setup(M, N, dtype):
85 | x_gpu = torch.empty(M, N, dtype=dtype, device="cuda")
86 | scale_gpu = torch.empty(1, N, dtype=dtype, device="cuda")
87 | epsilon_cpu = torch.ones((1, 1), dtype=torch.float32, device="cpu")
88 | out_gpu = torch.empty_like(x_gpu)
89 | inv_var_gpu = torch.empty(M, 1, dtype=torch.float32, device="cuda")
90 | handle = cudnn.create_handle()
91 | graph = cudnn.pygraph(
92 | handle=handle,
93 | intermediate_data_type=cudnn.data_type.FLOAT,
94 | compute_data_type=cudnn.data_type.FLOAT,
95 | )
96 | # create tensor handles with the graph API
97 | x = graph.tensor_like(x_gpu.detach()).set_name("X")
98 | scale = graph.tensor_like(scale_gpu.detach()).set_name("scale")
99 | epsilon = graph.tensor_like(epsilon_cpu).set_name("epsilon")
100 | (out, inv_var) = graph.rmsnorm(
101 | name="rmsnorm",
102 | input=x,
103 | norm_forward_phase=cudnn.norm_forward_phase.TRAINING,
104 | scale=scale,
105 | epsilon=epsilon,
106 | )
107 | # enable all outputs
108 | out.set_name("output").set_output(True).set_data_type(out_gpu.dtype)
109 | inv_var.set_name("inv_var").set_output(True).set_data_type(inv_var_gpu.dtype)
110 | graph.build([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
111 | # Mapping of (handles -> memory)
112 | variant_pack = {
113 | x: x_gpu.detach(),
114 | scale: scale_gpu.detach(),
115 | epsilon: epsilon_cpu,
116 | out: out_gpu,
117 | inv_var: inv_var_gpu,
118 | }
119 | workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
120 |
121 | def run(*args, **kwargs):
122 | graph.execute(variant_pack, workspace)
123 | return out_gpu, inv_var_gpu
124 |
125 | return run
126 |
127 |
128 | def run_rmsnorm_bwd(
129 | M,
130 | N,
131 | dtype: torch.dtype,
132 | residual_dtype: Optional[torch.dtype] = None,
133 | warmup_iterations=5,
134 | iterations=100,
135 | ):
136 | if not torch.cuda.is_available():
137 |         raise RuntimeError("Ampere GPU is required to run this example!")
138 |
139 | print(f"Tensor dimensions: [{M}, {N}]")
140 | print(f"Input and Output Data type: {dtype}")
141 |
142 | device = "cuda"
143 |
144 | # Set up forward pass inputs with gradients enabled
145 | x = torch.randn(M, N, device=device, dtype=dtype, requires_grad=True)
146 | x_ref = x.detach().clone().requires_grad_()
147 | w = torch.randn(N, device=device, dtype=torch.float32, requires_grad=True)
148 | w_ref = w.detach().clone().requires_grad_()
149 | if residual_dtype is not None:
150 | residual = torch.randn(M, N, device=device, dtype=residual_dtype, requires_grad=True)
151 | residual_ref = residual.detach().clone().requires_grad_()
152 | else:
153 | residual, residual_ref = None, None
154 |
155 |     print("Input tensor shapes:")
156 | print(f"x: {x.shape}, dtype: {x.dtype}")
157 | print(f"w: {w.shape}, dtype: {w.dtype}")
158 |
159 | eps = 1e-6
160 |
161 | # Forward pass to get outputs and rstd
162 | y = rmsnorm(x, w, residual=residual, eps=eps)
163 | if residual is not None:
164 | y, residual_out = y
165 | else:
166 | residual_out = None
167 |
168 | # Create upstream gradients
169 | dy = torch.randn_like(y)
170 | rstd = torch.randn(M, device=device, dtype=torch.float32)
171 | dresidual_out = torch.randn_like(residual_out) if residual is not None else None
172 |
173 | def mem_in_bytes(*args):
174 | return sum(t.numel() * t.dtype.itemsize for t in args if t is not None)
175 |
176 | time.sleep(0.5)
177 | # Benchmark custom backward pass
178 | # fn = lambda: torch.autograd.grad(y, [x, w], grad_outputs=dy, retain_graph=True)
179 | def fn():
180 | # x.grad = None # Reset gradients to avoid accumulation
181 | # y.backward(dy, retain_graph=True)
182 | rmsnorm_bwd(x if residual is None else residual_out, w, dy, rstd, dresidual_out=dresidual_out)
183 |
184 | avg_time = do_bench(fn, grad_to_none=(x,), warmup=warmup_iterations, rep=iterations)
185 | sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 2
186 | mem_bytes = mem_in_bytes(x if residual is None else residual_out, w, dy, dresidual_out, x, residual if residual is not None and residual.dtype != x.dtype else None)
187 | mem_bw = round(mem_bytes / (avg_time / 1000) / 1e9)
188 | print(f"Kernel execution time: {avg_time:.4f} ms")
189 | print(f"Mem throughput: {mem_bw:.2f} GB/s")
190 | from flash_attn.utils.benchmark import pytorch_profiler
191 | pytorch_profiler(fn)
192 |
193 | # Reference implementation
194 | y_ref = torch.compile(rmsnorm_ref)(x_ref, w_ref, eps=eps)
195 | compiled_func_ref = lambda: torch.autograd.grad(y_ref, [x_ref, w_ref], grad_outputs=dy, retain_graph=True)
196 | # def f():
197 | # x_ref.grad = None # Reset gradients to avoid accumulation
198 | # w_ref.grad = None
199 | # rmsnorm_ref(x_ref, w_ref, eps=eps).backward(dy)
200 | # compiled_func_ref = torch.compile(f)
201 |
202 | for _ in range(5): compiled_func_ref() # warm up
203 | time.sleep(0.5)
204 | avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations)
205 | mem_bytes_ref = (3 * x.numel() * dtype.itemsize + w.numel() * 4 + x.shape[0] * 4 + sm_count * w.numel() * 4)
206 | mem_bw_ref = round(mem_bytes_ref / (avg_time_ref / 1000) / 1e9)
207 | print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
208 | print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
209 | pytorch_profiler(compiled_func_ref)
210 |
211 | return mem_bw, mem_bw_ref
212 |
213 |
214 | if __name__ == "__main__":
215 | parser = argparse.ArgumentParser(
216 | description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels"
217 | )
218 | parser.add_argument("--M", default=32768, type=int)
219 | parser.add_argument("--N", default=32768, type=int)
220 | parser.add_argument("--dtype", type=cutlass.dtype, choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=cutlass.BFloat16)
221 | parser.add_argument("--residual_dtype", type=cutlass.dtype, choices=[None, cutlass.BFloat16, cutlass.Float16, cutlass.Float32], default=None)
222 | parser.add_argument("--warmup_iterations", default=10, type=int)
223 | parser.add_argument("--iterations", default=100, type=int)
224 | parser.add_argument("--backward", action="store_true", help="Benchmark backward pass instead of forward pass")
225 |
226 | args = parser.parse_args()
227 |
228 | if args.backward:
229 | print("=== RMSNorm Backward Pass Benchmark ===")
230 | run_rmsnorm_bwd(
231 | args.M,
232 | args.N,
233 | dtype=cutlass.torch.dtype(args.dtype),
234 | residual_dtype=cutlass.torch.dtype(args.residual_dtype) if args.residual_dtype else None,
235 | warmup_iterations=args.warmup_iterations,
236 | iterations=args.iterations,
237 | )
238 | else:
239 | print("=== RMSNorm Forward Pass Benchmark ===")
240 | run_rmsnorm(
241 | args.M,
242 | args.N,
243 | dtype=cutlass.torch.dtype(args.dtype),
244 | residual_dtype=cutlass.torch.dtype(args.residual_dtype) if args.residual_dtype else None,
245 | warmup_iterations=args.warmup_iterations,
246 | iterations=args.iterations,
247 | )
248 | # # MN_pairs = [(32768, 256), (32768, 512), (32768, 1024), (32768, 2048), (32768, 4096), (32768, 8192), (32768, 16384), (32768, 32768), (32768, 65536), (16384, 131072), (8192, 262144)]
249 | # # MN_pairs = [(32768, 2048)]
250 | # MN_pairs = [(16384, 65536)]
251 | # results = []
252 | # for M, N in MN_pairs:
253 | # res = run_rmsnorm(
254 | # M,
255 | # N,
256 | # dtype=cutlass.BFloat16,
257 | # skip_ref_check=False,
258 | # benchmark=True,
259 | # warmup_iterations=args.warmup_iterations,
260 | # iterations=args.iterations,
261 | # )
262 | # results.append(res)
263 | # print(results)
264 | # print("\nPASS")
265 |
--------------------------------------------------------------------------------
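
The GB/s figures printed by this benchmark are simply bytes touched divided by elapsed time. A minimal standalone sketch of that arithmetic (the shapes and timing below are made up for illustration, not taken from the benchmark defaults):

```python
import torch

def mem_in_bytes(*tensors):
    # Total byte footprint of every tensor the kernel actually reads or writes.
    return sum(t.numel() * t.dtype.itemsize for t in tensors if t is not None)

M, N = 4096, 8192                              # illustrative shapes
x = torch.empty(M, N, dtype=torch.bfloat16)    # input read by the kernel
dy = torch.empty_like(x)                       # upstream gradient read
dx = torch.empty_like(x)                       # gradient written
w = torch.empty(N, dtype=torch.float32)        # weight read

avg_time_ms = 0.25                             # pretend the kernel took 0.25 ms
# GB/s = bytes / seconds / 1e9; convert milliseconds to seconds first.
mem_bw = mem_in_bytes(x, dy, dx, w) / (avg_time_ms / 1000) / 1e9
print(f"{mem_bw:.1f} GB/s")
```
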
/quack/linear_cross_entropy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025, Tri Dao
2 | from typing import Optional, Literal
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 | from torch.amp import custom_fwd, custom_bwd
9 |
10 | from quack.cross_entropy import cross_entropy, cross_entropy_fwd_out
11 | from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace
12 | from quack.linear import linear_fwd_convert_type
13 |
14 |
15 | def linear_cross_entropy_func(
16 | x: Tensor, # (..., d)
17 | weight: Tensor, # (V, d)
18 | bias: Optional[Tensor], # (V,) or None
19 | target: Tensor, # (...,), int or long
20 | ignore_index: int = -100,
21 | reduction: Literal["none", "mean", "sum"] = "mean",
22 | inplace_backward: bool = False,
23 | ) -> Tensor:
24 | y = F.linear(x, weight, bias) # (..., V)
25 | return cross_entropy(
26 | y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward
27 | )
28 |
29 |
30 | def linear_cross_entropy_func_ref(
31 | x: Tensor, # (..., d)
32 | weight: Tensor, # (V, d)
33 | bias: Optional[Tensor], # (V,) or None
34 | target: Tensor, # (...,), int or long
35 | ignore_index: int = -100,
36 | reduction: Literal["none", "mean", "sum"] = "mean",
37 | ) -> Tensor:
38 | y = F.linear(x, weight, bias) # (..., V)
39 | return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction)
40 |
41 |
42 | def chunked_linear_cross_entropy_fwd(
43 | x: Tensor, # (B*L, d) where B is batch, L is seqlen
44 | weight: Tensor, # (V, d) where V is vocab size
45 | target: Tensor, # (B*L,)
46 | chunk_size: int = 4096,
47 | ignore_index: int = -100,
48 | tuned: bool = True,
49 | ) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
50 | """
51 | Chunked forward pass for linear cross entropy.
52 |
53 | Splits input along batch dimension, computes matmul and cross_entropy_fwd
54 | for each chunk, stores dx for each chunk, and accumulates dw.
55 |
56 | Returns:
57 | loss: (B*L,) loss values
58 | dx: (B*L, d) gradient w.r.t. input
59 | dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last)
60 | last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
61 | last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
62 | """
63 | B_L, d = x.shape
64 | V, _ = weight.shape
65 | device = x.device
66 | num_chunks = (B_L + chunk_size - 1) // chunk_size
67 | # Since we use gemm with TMA we require some alignment
68 | assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
69 | assert B_L % 8 == 0
70 | # Pre-allocate outputs
71 | loss = torch.empty(B_L, device=device, dtype=torch.float32)
72 | logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype)
73 | dx = torch.empty_like(x)
74 | # Last chunk of dw will be deferred to the backward pass
75 | dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
76 | last_dlogits_chunk = None
77 | last_x_chunk = None
78 |
79 | # Process in chunks
80 | for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
81 | zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
82 | ):
83 | chunk_len = x_chunk.shape[0]
84 | logits_chunk = logits_chunk_preallocated[:chunk_len] # (chunk_len, V)
85 | torch.mm(x_chunk, weight.mT, out=logits_chunk)
86 | # Compute cross entropy forward with gradients
87 | dlogits_chunk = logits_chunk # inplace_backward
88 | cross_entropy_fwd_out(
89 | logits_chunk,
90 | target_chunk,
91 | None, # target_logit
92 | loss=loss_chunk,
93 | lse=None, # we don't need lse here
94 | dx=dlogits_chunk,
95 | ignore_index=ignore_index,
96 | )
97 | # Compute dx for this chunk: dlogits @ weight
98 | torch.mm(dlogits_chunk, weight, out=dx_chunk) # (chunk_len, d)
99 | # Compute dw for all chunks except the last
100 | if i == num_chunks - 1:
101 | # Last chunk: save for backward pass
102 | last_dlogits_chunk = dlogits_chunk
103 | last_x_chunk = x_chunk
104 | elif i == 0:
105 | # First chunk: dw = dlogits.T @ x_chunk
106 | gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
107 | else:
108 | # Middle chunks: dw += dlogits.T @ x_chunk
109 | gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
110 | return loss, dx, dw, last_dlogits_chunk, last_x_chunk
111 |
112 |
113 | class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
114 | @staticmethod
115 | @custom_fwd(device_type="cuda")
116 | def forward(
117 | ctx,
118 | x: Tensor,
119 | weight: Tensor,
120 | target: Tensor,
121 | ignore_index: int = -100,
122 | reduction: Literal["mean", "sum"] = "mean",
123 | chunk_size: int = 4096,
124 | tuned: bool = True,
125 | ):
126 | """
127 | Forward pass computes loss and stores dx and dw for backward.
128 | """
129 | ctx.weight_dtype = weight.dtype
130 | x, weight = linear_fwd_convert_type(x, weight)
131 | batch_shape = x.shape[:-1]
132 | x = x.reshape(-1, x.shape[-1])
133 | # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
134 | loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
135 | x, weight, target, chunk_size, ignore_index, tuned=tuned
136 | )
137 | loss_sum = loss.sum()
138 | loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
139 | ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
140 | ctx.batch_shape = batch_shape
141 | ctx.ignore_index = ignore_index
142 | ctx.reduction = reduction
143 | ctx.tuned = tuned
144 | return loss_sum if loss_scale is None else loss_sum * loss_scale
145 |
146 | @staticmethod
147 | @custom_bwd(device_type="cuda")
148 | def backward(ctx, dloss):
149 | """
150 | Backward pass scales pre-computed gradients by dloss and completes
151 | the last chunk's dw computation.
152 | dloss is a scalar.
153 | """
154 | dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
155 | tuned = ctx.tuned
156 | if loss_scale is not None:
157 | dloss = dloss * loss_scale
158 | # TODO: the case where x or weight doesn't require grad
159 | dx.mul_(dloss)
160 | dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
161 | # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
162 | if dw is None:
163 | # Only had one chunk, compute dw directly with dloss scaling
164 | dw = gemm(
165 | last_dlogits_chunk.T,
166 | last_x_chunk,
167 | out_dtype=ctx.weight_dtype,
168 | alpha=dloss,
169 | tuned=tuned,
170 | )
171 | else:
172 | # Add last chunk's contribution with dloss scaling
173 | # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
174 | # We use alpha=dloss, beta=dloss
175 | if ctx.weight_dtype == dw.dtype:
176 | gemm_add_inplace(
177 | last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
178 | )
179 | else:
180 | dw = gemm_add(
181 | last_dlogits_chunk.T,
182 | last_x_chunk,
183 | dw,
184 | alpha=dloss,
185 | beta=dloss,
186 | out_dtype=ctx.weight_dtype,
187 | tuned=tuned,
188 | )
189 | return dx, dw, None, None, None, None, None
190 |
191 |
192 | def chunked_linear_cross_entropy(
193 | x: Tensor,
194 | weight: Tensor,
195 | target: Tensor,
196 | chunk_size: int = 4096,
197 | ignore_index: int = -100,
198 | reduction: Literal["mean", "sum"] = "mean",
199 | tuned: bool = True,
200 | ) -> Tensor:
201 | """
202 | Chunked linear cross entropy with automatic differentiation support.
203 |
204 | Args:
205 | x: Input tensor of shape (B*L, d)
206 | weight: Weight tensor of shape (V, d)
207 | target: Target indices of shape (B*L,)
208 | chunk_size: Size of chunks to process
209 | ignore_index: Index to ignore in loss computation
210 | reduction: Type of reduction to apply
211 | tuned: Whether to use tuned kernels
212 |
213 | Returns:
214 | Loss tensor with specified reduction
215 | """
216 | if reduction not in ["mean", "sum"]:
217 | raise ValueError(f"Invalid reduction: {reduction}")
218 | loss = ChunkedLinearCrossEntropyFunction.apply(
219 | x, weight, target, ignore_index, reduction, chunk_size, tuned
220 | )
221 | return loss
222 |
223 |
224 | class LinearCrossEntropy(nn.Linear):
225 | def __init__(
226 | self,
227 | in_features: int,
228 | out_features: int,
229 | bias: bool = False,
230 | ignore_index: int = -100,
231 | reduction: Literal["none", "mean", "sum"] = "mean",
232 | chunk_size: Optional[int] = None,
233 | inplace_backward: bool = False,
234 | tuned: bool = True,
235 | device=None,
236 | dtype=None,
237 | ) -> None:
238 | super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
239 | self.ignore_index = ignore_index
240 | self.reduction = reduction
241 | self.chunk_size = chunk_size
242 | self.inplace_backward = inplace_backward
243 | self.tuned = tuned
244 |
245 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
246 | if (
247 | self.bias is None
248 | and input.is_cuda
249 | and input.stride(-1) == 1
250 | and self.in_features % 8 == 0
251 | and self.out_features % 8 == 0
252 | and input.shape[:-1].numel() % 8 == 0
253 | and self.chunk_size is not None
254 | and self.chunk_size % 8 == 0
255 | and self.reduction in ["mean", "sum"]
256 | ):
257 | return chunked_linear_cross_entropy(
258 | input,
259 | self.weight,
260 | target,
261 | chunk_size=self.chunk_size,
262 | ignore_index=self.ignore_index,
263 | reduction=self.reduction,
264 | tuned=self.tuned,
265 | )
266 | else:
267 | return linear_cross_entropy_func(
268 | input,
269 | self.weight,
270 | self.bias,
271 | target,
272 | ignore_index=self.ignore_index,
273 | reduction=self.reduction,
274 | inplace_backward=self.inplace_backward,
275 | )
276 |
--------------------------------------------------------------------------------
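
A minimal usage sketch of the chunked path defined above (assuming a CUDA GPU and that quack is importable; the sizes are illustrative but chosen as multiples of 8 so the alignment checks in `LinearCrossEntropy.forward` are satisfied). The forward pass already computes `dx` and all but the last chunk of `dw`, so `backward` only has to scale them and finish the deferred last-chunk GEMM:

```python
import torch
from quack.linear_cross_entropy import LinearCrossEntropy

d, V, num_tokens = 4096, 32768, 8192   # hidden dim, vocab size, B*L tokens (all % 8 == 0)
lce = LinearCrossEntropy(
    d, V, bias=False, chunk_size=4096, reduction="mean",
    device="cuda", dtype=torch.bfloat16,
)

x = torch.randn(num_tokens, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (num_tokens,), device="cuda")

loss = lce(x, target)   # chunked path: only a (chunk_size, V) logits buffer is materialized
loss.backward()         # scales the precomputed dx/dw and completes the last chunk's dw
print(loss.item(), x.grad.shape, lce.weight.grad.shape)
```
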
/tests/test_symmetric_gemm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | from quack.gemm_interface import gemm_symmetric as symmetric_dense_gemm
5 |
6 | class TestSymmetricGemm:
7 | """Unit tests for symmetric dense GEMM wrapper."""
8 |
9 | @pytest.fixture(params=[torch.float16, torch.bfloat16, torch.float32])
10 | def dtype(self, request):
11 | """Test different data types."""
12 | return request.param
13 |
14 | @property
15 | def default_shape(self):
16 | """Default shape for most tests (L, M, K)."""
17 | return (2, 1024, 512)
18 |
19 | def torch_reference(self, a, b=None, C=None, alpha=1.0, beta=1.0):
20 | """Reference implementation using PyTorch operations.
21 |
22 | Args:
23 | a: Input tensor A of shape (L, M, K)
24 | b: Input tensor B of shape (L, M, K) - if None, uses A (symmetric case)
25 | C: Optional additive tensor C of shape (L, M, M)
26 | alpha: Scaling factor for A @ B^T
27 | beta: Scaling factor for C
28 |
29 | Returns:
30 | Result tensor of shape (L, M, M)
31 | """
32 | if b is None:
33 | b = a
34 |
35 | # Use einsum for batched matrix multiplication: A @ B^T
36 | # a: (L, M, K), b: (L, M, K) -> result: (L, M, M)
37 | result = alpha * torch.einsum("lmk,lnk->lmn", a, b)
38 |
39 | if C is not None:
40 | result = result + beta * C
41 |
42 | return result
43 |
44 | def create_test_tensor(self, L, M, K, dtype, device, stride_pattern="m_major", seed=None):
45 | """Create test tensor with specified stride pattern.
46 |
47 | Args:
48 | L, M, K: Tensor dimensions
49 | dtype: Data type
50 | device: Device ('cuda' or 'cpu')
51 | stride_pattern: How to arrange strides - 'm_major' means M has stride 1, 'k_major' means K has stride 1
52 | seed: Random seed for reproducibility
53 | """
54 | if stride_pattern == "m_major":
55 | # M has stride 1: (L, M, K) with strides (M*K, 1, M)
56 | tensor = torch.empty_strided((L, M, K), (M * K, 1, M), dtype=dtype, device=device)
57 | elif stride_pattern == "k_major":
58 | # K has stride 1: (L, M, K) with strides (M*K, K, 1)
59 | tensor = torch.empty_strided((L, M, K), (M * K, K, 1), dtype=dtype, device=device)
60 | else:
61 | raise ValueError(f"Unsupported stride pattern: {stride_pattern}")
62 |
63 | # Fill with random data
64 | if seed is not None:
65 | torch.manual_seed(seed)
66 | tensor.uniform_(-2, 2)
67 | return tensor
68 |
69 | def create_symmetric_tensor(self, L, M, dtype, device, seed=None):
70 | """Create a symmetric tensor of shape (L, M, M)."""
71 | if seed is not None:
72 | torch.manual_seed(seed)
73 |
74 | tensor = torch.randn(L, M, M, dtype=dtype, device=device)
75 |
76 | for l in range(L):
77 | matrix = tensor[l, :, :]
78 | tensor[l, :, :] = (matrix + matrix.T) / 2
79 |
80 | return tensor
81 |
82 | def test_basic_symmetric_gemm(self, dtype):
83 | """Test basic symmetric GEMM without bias."""
84 | if not torch.cuda.is_available():
85 | pytest.skip("CUDA not available")
86 |
87 | L, M, K = self.default_shape
88 | device = "cuda"
89 |
90 | # Create input tensor A with stride 1 along M dimension
91 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42)
92 |
93 | print(f"a.shape = {a.shape}, a.stride = {a.stride()}")
94 |
95 | # Test symmetric case (B = A.transpose(-2, -1) for symmetric GEMM)
96 | result_quack = symmetric_dense_gemm(a, a.transpose(-2, -1), C=None)
97 | result_torch = self.torch_reference(a, a)
98 |
99 | assert result_quack.shape == result_torch.shape == (L, M, M)
100 |
101 | if dtype == torch.float32:
102 | torch.testing.assert_close(result_quack, result_torch, atol=1e-4, rtol=1e-4)
103 | else: # float16, bfloat16
104 | torch.testing.assert_close(result_quack, result_torch, atol=1e-2, rtol=1e-2)
105 |
106 | def test_symmetric_gemm_with_bias(self, dtype):
107 | """Test symmetric GEMM with bias tensor C."""
108 | if not torch.cuda.is_available():
109 | pytest.skip("CUDA not available")
110 |
111 | L, M, K = self.default_shape
112 | device = "cuda"
113 |
114 | # Create input tensors
115 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42)
116 | c = self.create_symmetric_tensor(L, M, dtype, device, seed=123)
117 |
118 | # Compute with our wrapper
119 | result_quack = symmetric_dense_gemm(a, a.transpose(-2, -1), C=c, alpha=1.0, beta=1.0)
120 |
121 | # Compute reference
122 | result_torch = self.torch_reference(a, a, C=c)
123 |
124 | # Check shapes match
125 | assert result_quack.shape == result_torch.shape == (L, M, M)
126 |
127 | # Check values match
128 | if dtype == torch.float32:
129 | torch.testing.assert_close(result_quack, result_torch, atol=1e-4, rtol=1e-4)
130 | else:
131 | torch.testing.assert_close(result_quack, result_torch, atol=1e-2, rtol=1e-2)
132 |
133 | def test_alpha_beta_scaling(self, dtype):
134 | """Test alpha and beta scaling factors."""
135 | if not torch.cuda.is_available():
136 | pytest.skip("CUDA not available")
137 |
138 | L, M, K = self.default_shape
139 | device = "cuda"
140 | alpha, beta = 2.5, 0.5
141 |
142 | # Create input tensors
143 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42)
144 | c = self.create_symmetric_tensor(L, M, dtype, device, seed=123)
145 |
146 | # Compute with our wrapper
147 | result_quack = symmetric_dense_gemm(a, a.transpose(-2, -1), C=c, alpha=alpha, beta=beta)
148 |
149 | # Compute reference
150 | result_torch = self.torch_reference(a, a, C=c, alpha=alpha, beta=beta)
151 |
152 | # Check values match
153 | if dtype == torch.float32:
154 | torch.testing.assert_close(result_quack, result_torch, atol=1e-4, rtol=1e-4)
155 | else:
156 | torch.testing.assert_close(result_quack, result_torch, atol=1e-2, rtol=1e-2)
157 |
158 | def test_symmetry_property(self, dtype):
159 | """Test that output is actually symmetric (D = D^T)."""
160 | if not torch.cuda.is_available():
161 | pytest.skip("CUDA not available")
162 |
163 | L, M, K = self.default_shape
164 | device = "cuda"
165 |
166 | # Create input tensor
167 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42)
168 |
169 | # Compute symmetric GEMM
170 | result = symmetric_dense_gemm(a, a.transpose(-2, -1), C=None, alpha=1.0, beta=1.0)
171 |
172 | # Check symmetry for each batch
173 | for l in range(L):
174 | matrix = result[l, :, :]
175 | torch.testing.assert_close(matrix, matrix.T, atol=1e-6, rtol=1e-6)
176 |
177 | def test_different_sizes(self):
178 | """Test various matrix sizes to ensure robustness."""
179 | if not torch.cuda.is_available():
180 | pytest.skip("CUDA not available")
181 |
182 | device = "cuda"
183 | dtype = torch.float16
184 |
185 | test_sizes = [
186 | (3, 128, 128),
187 | (5, 256, 256),
188 | (5, 1024, 1024),
189 | (3, 2048, 2048),
190 | (1, 4096, 4096),
191 | ]
192 |
193 | for L, M, K in test_sizes:
194 | a = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42)
195 |
196 | result = symmetric_dense_gemm(a, a.transpose(-2, -1), C=None, alpha=1.0, beta=1.0)
197 | expected = self.torch_reference(a, a)
198 |
199 | assert result.shape == (L, M, M)
200 | torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
201 |
202 | # Verify symmetry
203 | for l in range(L):
204 | matrix = result[l, :, :]
205 | torch.testing.assert_close(matrix, matrix.T, atol=1e-6, rtol=1e-6)
206 |
207 | def test_different_stride_patterns(self, dtype):
208 | """Test symmetric GEMM with different stride patterns (m_major vs k_major)."""
209 | if not torch.cuda.is_available():
210 | pytest.skip("CUDA not available")
211 |
212 | L, M, K = self.default_shape
213 | device = "cuda"
214 |
215 | a_m_major = self.create_test_tensor(L, M, K, dtype, device, "m_major", seed=42)
216 |
217 | data_contiguous = a_m_major.contiguous()
218 | a_k_major = torch.empty_strided((L, M, K), (M * K, K, 1), dtype=dtype, device=device)
219 | a_k_major.copy_(data_contiguous)
220 |
221 | assert torch.equal(a_m_major, a_k_major), "Input tensors should have identical values"
222 | assert a_m_major.stride() != a_k_major.stride(), "Stride patterns should be different"
223 |
224 | result_m_major = symmetric_dense_gemm(a_m_major, a_m_major.transpose(-2, -1), C=None, alpha=1.0, beta=1.0)
225 | result_k_major = symmetric_dense_gemm(a_k_major, a_k_major.transpose(-2, -1), C=None, alpha=1.0, beta=1.0)
226 |
227 | assert result_m_major.shape == result_k_major.shape == (L, M, M)
228 |
229 | if dtype == torch.float32:
230 | torch.testing.assert_close(result_m_major, result_k_major, atol=1e-6, rtol=1e-6)
231 | else:
232 | torch.testing.assert_close(result_m_major, result_k_major, atol=1e-4, rtol=1e-4)
233 |
234 | expected = self.torch_reference(a_m_major, a_m_major)
235 |
236 | if dtype == torch.float32:
237 | torch.testing.assert_close(result_m_major, expected, atol=1e-4, rtol=1e-4)
238 | torch.testing.assert_close(result_k_major, expected, atol=1e-4, rtol=1e-4)
239 | else:
240 | torch.testing.assert_close(result_m_major, expected, atol=1e-2, rtol=1e-2)
241 | torch.testing.assert_close(result_k_major, expected, atol=1e-2, rtol=1e-2)
242 |
243 |
244 | def run_tests():
245 | """Run all tests manually (for debugging)."""
246 | test_class = TestSymmetricGemm()
247 |
248 | try:
249 | # Test basic functionality
250 | print("Testing basic symmetric GEMM...")
251 | test_class.test_basic_symmetric_gemm(torch.float16)
252 | print("✓ Basic test passed")
253 |
254 | # Test with bias
255 | print("Testing with bias...")
256 | test_class.test_symmetric_gemm_with_bias(torch.float16)
257 | print("✓ Bias test passed")
258 |
259 | # Test scaling
260 | print("Testing alpha/beta scaling...")
261 | test_class.test_alpha_beta_scaling(torch.float16)
262 | print("✓ Scaling test passed")
263 |
264 | # Test symmetry
265 | print("Testing symmetry property...")
266 | test_class.test_symmetry_property(torch.float16)
267 | print("✓ Symmetry test passed")
268 |
269 | # Test different sizes
270 | print("Testing different sizes...")
271 | test_class.test_different_sizes()
272 | print("✓ Different sizes test passed")
273 |
274 | # Test different stride patterns
275 | print("Testing different stride patterns...")
276 | test_class.test_different_stride_patterns(torch.float16)
277 | print("✓ Different stride patterns test passed")
278 |
279 | print("\n🎉 All tests passed!")
280 |
281 | except Exception as e:
282 | print(f"\n❌ Test failed with error: {e}")
283 | import traceback
284 |
285 | traceback.print_exc()
286 |
287 |
288 | if __name__ == "__main__":
289 | run_tests()
290 |
--------------------------------------------------------------------------------
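
For reference, the wrapper exercised by these tests can be called directly. A minimal sketch (assuming a CUDA GPU), mirroring the tests' `A @ A^T` call and their einsum reference; the tolerances follow the fp16 values used above:

```python
import torch
from quack.gemm_interface import gemm_symmetric

L, M, K = 2, 1024, 512   # illustrative batched shapes: (L, M, K) inputs -> (L, M, M) outputs
a = torch.randn(L, M, K, device="cuda", dtype=torch.float16)

# D = alpha * A @ A^T (+ beta * C); B is passed explicitly as A^T, as in the tests.
d = gemm_symmetric(a, a.transpose(-2, -1), C=None, alpha=1.0, beta=1.0)

# Reference: contract over K with einsum, leaving an (L, M, M) result.
d_ref = torch.einsum("lmk,lnk->lmn", a.float(), a.float()).to(d.dtype)
torch.testing.assert_close(d, d_ref, atol=1e-2, rtol=1e-2)
print(d.shape)
```
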
/tests/test_linear_varlen_k.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025, Tri Dao.
2 | import math
3 | import pytest
4 | import torch
5 |
6 | from quack.gemm_interface import (
7 | gemm,
8 | gemm_ref,
9 | gemm_add,
10 | gemm_add_ref,
11 | gemm_add_inplace,
12 | )
13 |
14 |
15 | def generate_A_with_gather(m, total_k, device, dtype, gather_A=False):
16 | """Generate A matrix and optionally A_idx for gather_A case with varlen_k.
17 |
18 | Args:
19 | m: Number of rows
20 | total_k: Number of columns needed
21 | device: Device to create tensors on
22 | dtype: Data type of tensors
23 | gather_A: Whether to create gather indices
24 |
25 | Returns:
26 | A: Matrix of shape (m, larger_k) if gather_A else (m, total_k)
27 | A_idx: Index tensor of shape (total_k,) if gather_A else None
28 | """
29 | if gather_A:
30 | # Create random indices for gathering from a larger A matrix
31 | larger_k = total_k * 2 # Make A larger than needed
32 | A = torch.randn((m, larger_k), device=device, dtype=dtype)
33 | # Make A m-major
34 | A = A.T.contiguous().T
35 | # Create random indices to gather from A
36 | A_idx = torch.randperm(larger_k, device=device, dtype=torch.int32)[:total_k]
37 | else:
38 | A = torch.randn((m, total_k), device=device, dtype=dtype)
39 | # Make A m-major
40 | A = A.T.contiguous().T
41 | A_idx = None
42 | return A, A_idx
43 |
44 |
45 | @pytest.mark.parametrize("permute_batch", [False, True])
46 | @pytest.mark.parametrize("gather_A", [False, True])
47 | # @pytest.mark.parametrize("gather_A", [False])
48 | @pytest.mark.parametrize("dynamic_scheduler", [False, True])
49 | # @pytest.mark.parametrize("dynamic_scheduler", [False])
50 | @pytest.mark.parametrize("alpha_is_tensor", [False, True])
51 | # @pytest.mark.parametrize("alpha_is_tensor", [False])
52 | @pytest.mark.parametrize("alpha", [1.0, 0.93])
53 | # @pytest.mark.parametrize("alpha", [1.0])
54 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
55 | @pytest.mark.parametrize("n", [1024, 1504, 4096])
56 | @pytest.mark.parametrize("m", [2048, 1064, 8192])
57 | # @pytest.mark.parametrize("n", [1024])
58 | # @pytest.mark.parametrize("m", [2048])
59 | @pytest.mark.parametrize("num_groups", [2, 4])
60 | # @pytest.mark.parametrize("num_groups", [2])
61 | def test_gemm_varlen_k(
62 | num_groups,
63 | m,
64 | n,
65 | input_dtype,
66 | alpha,
67 | alpha_is_tensor,
68 | dynamic_scheduler,
69 | gather_A,
70 | permute_batch,
71 | ):
72 | device = "cuda"
73 | torch.random.manual_seed(42)
74 | seq_lens = torch.randint(50, 300, (num_groups,), device="cpu")
75 | total_k = seq_lens.sum().item()
76 | # Create cumulative sequence lengths (num_groups + 1)
77 | cu_seqlens_k = torch.cat(
78 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
79 | )
80 | cu_seqlens_k = cu_seqlens_k.to(device)
81 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A)
82 | avg_k = total_k / num_groups
83 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k)
84 | if alpha_is_tensor:
85 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32)
86 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha
87 | if permute_batch:
88 | batch_idx_permute = torch.randperm(num_groups, device=device).to(torch.int32)
89 | else:
90 | batch_idx_permute = None
91 | out = gemm(
92 | A,
93 | B,
94 | alpha=alpha,
95 | cu_seqlens_k=cu_seqlens_k,
96 | A_idx=A_idx,
97 | batch_idx_permute=batch_idx_permute,
98 | dynamic_scheduler=dynamic_scheduler,
99 | tuned=False,
100 | )
101 | assert out.shape == (num_groups, m, n)
102 | out_ref = gemm_ref(
103 | A.float(),
104 | B.float(),
105 | alpha=alpha_val,
106 | cu_seqlens_k=cu_seqlens_k,
107 | A_idx=A_idx,
108 | )
109 | out_pt = gemm_ref(A, B, alpha=alpha_val, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx)
110 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4
111 |
112 |
113 | @pytest.mark.parametrize("gather_A", [False, True])
114 | # @pytest.mark.parametrize("gather_A", [False])
115 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
116 | @pytest.mark.parametrize("n", [1024])
117 | @pytest.mark.parametrize("m", [2048])
118 | def test_gemm_varlen_k_with_zero_lengths(
119 | m,
120 | n,
121 | input_dtype,
122 | gather_A,
123 | ):
124 | device = "cuda"
125 | torch.random.manual_seed(42)
126 | seq_lens = torch.tensor([150, 64, 0, 200, 0], device="cpu", dtype=torch.int32)
127 | num_groups = seq_lens.shape[0]
128 | total_k = seq_lens.sum().item()
129 | cu_seqlens_k = torch.cat(
130 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
131 | )
132 | cu_seqlens_k = cu_seqlens_k.to(device)
133 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A)
134 | avg_k = total_k / num_groups
135 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k)
136 | out = gemm(
137 | A,
138 | B,
139 | cu_seqlens_k=cu_seqlens_k,
140 | A_idx=A_idx,
141 | dynamic_scheduler=False,
142 | tuned=False,
143 | )
144 | assert out.shape == (num_groups, m, n)
145 | out_ref = gemm_ref(
146 | A.float(),
147 | B.float(),
148 | cu_seqlens_k=cu_seqlens_k,
149 | A_idx=A_idx,
150 | )
151 | out_pt = gemm_ref(A, B, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx)
152 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4
153 |
154 |
155 | @pytest.mark.parametrize("gather_A", [False, True])
156 | # @pytest.mark.parametrize("gather_A", [False])
157 | @pytest.mark.parametrize("dynamic_scheduler", [False, True])
158 | # @pytest.mark.parametrize("dynamic_scheduler", [False])
159 | @pytest.mark.parametrize("C_major", ["m", "n"])
160 | @pytest.mark.parametrize("alpha_is_tensor", [False, True])
161 | @pytest.mark.parametrize("beta_is_tensor", [False, True])
162 | @pytest.mark.parametrize("beta", [0.0, 1.17])
163 | @pytest.mark.parametrize("alpha", [1.0, 0.93])
164 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
165 | @pytest.mark.parametrize("n", [1024, 1504])
166 | @pytest.mark.parametrize("m", [2048, 1024])
167 | @pytest.mark.parametrize("num_groups", [2, 4])
168 | def test_gemm_add_varlen_k(
169 | num_groups,
170 | m,
171 | n,
172 | input_dtype,
173 | alpha,
174 | beta,
175 | alpha_is_tensor,
176 | beta_is_tensor,
177 | C_major,
178 | dynamic_scheduler,
179 | gather_A,
180 | ):
181 | device = "cuda"
182 | torch.random.manual_seed(42)
183 | seq_lens = torch.randint(50, 300, (num_groups,), device="cpu")
184 | total_k = seq_lens.sum().item()
185 | # Create cumulative sequence lengths (num_groups + 1)
186 | cu_seqlens_k = torch.cat(
187 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
188 | )
189 | cu_seqlens_k = cu_seqlens_k.to(device)
190 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A)
191 | # Make A m-major
192 | A = A.T.contiguous().T
193 | avg_k = total_k / num_groups
194 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k)
195 | C = torch.randn((num_groups, m, n), device=device, dtype=input_dtype)
196 | if C_major == "m":
197 | C = C.permute(0, 2, 1).contiguous().permute(0, 2, 1)
198 | if alpha_is_tensor:
199 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32)
200 | if beta_is_tensor:
201 | beta = torch.tensor(beta, device=device, dtype=torch.float32)
202 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha
203 | beta_val = beta.item() if torch.is_tensor(beta) else beta
204 | out = gemm_add(
205 | A,
206 | B,
207 | C,
208 | alpha=alpha,
209 | beta=beta,
210 | cu_seqlens_k=cu_seqlens_k,
211 | A_idx=A_idx,
212 | dynamic_scheduler=dynamic_scheduler,
213 | tuned=False,
214 | )
215 | assert out.shape == (num_groups, m, n)
216 | out_ref = gemm_add_ref(
217 | A.float(),
218 | B.float(),
219 | C.float(),
220 | alpha=alpha_val,
221 | beta=beta_val,
222 | cu_seqlens_k=cu_seqlens_k,
223 | A_idx=A_idx,
224 | )
225 | out_pt = gemm_add_ref(
226 | A, B, C, alpha=alpha_val, beta=beta_val, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx
227 | )
228 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4
229 |
230 |
231 | @pytest.mark.parametrize("gather_A", [False, True])
232 | # @pytest.mark.parametrize("gather_A", [False])
233 | @pytest.mark.parametrize("dynamic_scheduler", [False, True])
234 | # @pytest.mark.parametrize("dynamic_scheduler", [False])
235 | @pytest.mark.parametrize("alpha_is_tensor", [False, True])
236 | @pytest.mark.parametrize("beta_is_tensor", [False, True])
237 | @pytest.mark.parametrize("beta", [0.0, 1.17])
238 | @pytest.mark.parametrize("alpha", [1.0, 0.93])
239 | @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
240 | @pytest.mark.parametrize("n", [1024, 1504])
241 | @pytest.mark.parametrize("m", [2048, 1024])
242 | @pytest.mark.parametrize("num_groups", [2, 4])
243 | def test_gemm_add_inplace_varlen_k(
244 | num_groups,
245 | m,
246 | n,
247 | input_dtype,
248 | alpha,
249 | beta,
250 | alpha_is_tensor,
251 | beta_is_tensor,
252 | dynamic_scheduler,
253 | gather_A,
254 | ):
255 | device = "cuda"
256 | torch.random.manual_seed(42)
257 | seq_lens = torch.randint(50, 300, (num_groups,), device="cpu")
258 | total_k = seq_lens.sum().item()
259 | # Create cumulative sequence lengths (num_groups + 1)
260 | cu_seqlens_k = torch.cat(
261 | [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
262 | )
263 | cu_seqlens_k = cu_seqlens_k.to(device)
264 | A, A_idx = generate_A_with_gather(m, total_k, device, input_dtype, gather_A)
265 | # Make A m-major
266 | A = A.T.contiguous().T
267 | avg_k = total_k / num_groups
268 | B = torch.randn((total_k, n), device=device, dtype=input_dtype) / math.sqrt(avg_k)
269 | out = torch.randn((num_groups, m, n), device=device, dtype=input_dtype)
270 | if alpha_is_tensor:
271 | alpha = torch.tensor(alpha, device=device, dtype=torch.float32)
272 | if beta_is_tensor:
273 | beta = torch.tensor(beta, device=device, dtype=torch.float32)
274 | # Save original out for reference computation
275 | out_og = out.clone()
276 | gemm_add_inplace(
277 | A,
278 | B,
279 | out,
280 | alpha=alpha,
281 | beta=beta,
282 | cu_seqlens_k=cu_seqlens_k,
283 | A_idx=A_idx,
284 | dynamic_scheduler=dynamic_scheduler,
285 | tuned=False,
286 | )
287 | alpha_val = alpha.item() if torch.is_tensor(alpha) else alpha
288 | beta_val = beta.item() if torch.is_tensor(beta) else beta
289 | out_ref = gemm_add_ref(
290 | A.float(),
291 | B.float(),
292 | out_og.float(),
293 | out=None, # Don't use in-place for reference
294 | alpha=alpha_val,
295 | beta=beta_val,
296 | cu_seqlens_k=cu_seqlens_k,
297 | A_idx=A_idx,
298 | )
299 | out_pt = gemm_add_ref(
300 | A,
301 | B,
302 | out_og,
303 | out=None,
304 | alpha=alpha_val,
305 | beta=beta_val,
306 | cu_seqlens_k=cu_seqlens_k,
307 | A_idx=A_idx,
308 | )
309 | assert out.shape == (num_groups, m, n), (
310 | f"Output shape mismatch: {out.shape} vs expected ({num_groups}, {m}, {n})"
311 | )
312 | assert (out - out_ref).abs().max() < 2 * (out_pt - out_ref).abs().max() + 1e-4
313 |
--------------------------------------------------------------------------------
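
The varlen-K interface exercised above takes cumulative K lengths per group; a minimal sketch (assuming a CUDA GPU) that builds `cu_seqlens_k` and mirrors the non-gather call pattern of these tests:

```python
import math
import torch
from quack.gemm_interface import gemm, gemm_ref

device = "cuda"
m, n = 2048, 1024
seq_lens = torch.tensor([150, 64, 200], dtype=torch.int32)   # per-group K lengths
total_k = int(seq_lens.sum())

# cu_seqlens_k has num_groups + 1 entries: here [0, 150, 214, 414].
cu_seqlens_k = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
).to(device)

A = torch.randn(m, total_k, device=device, dtype=torch.bfloat16)
A = A.T.contiguous().T   # m-major layout, as in the tests
B = torch.randn(total_k, n, device=device, dtype=torch.bfloat16) / math.sqrt(
    total_k / len(seq_lens)
)

out = gemm(A, B, cu_seqlens_k=cu_seqlens_k, A_idx=None, dynamic_scheduler=False, tuned=False)
out_ref = gemm_ref(A.float(), B.float(), cu_seqlens_k=cu_seqlens_k, A_idx=None)
print(out.shape, (out - out_ref).abs().max().item())   # shape is (3, 2048, 1024)
```
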
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------