├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── quantum_attn │ ├── __init__.py │ ├── config.py │ ├── inductor │ ├── __init__.py │ └── kernels │ │ ├── __init__.py │ │ ├── attention.py │ │ └── mm_common.py │ ├── nn.py │ ├── ops.py │ ├── quantum_attn_interface.py │ ├── tk │ ├── __init__.py │ ├── attention.py │ └── utils.py │ └── utils │ ├── __init__.py │ ├── checks.py │ └── types.py └── tests ├── __init__.py └── test_interface.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | _version.py 177 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "fahopper/csrc/cutlass"] 2 | path = fahopper/csrc/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | [submodule "src/quantum_attn/tk_repo"] 5 | path = src/quantum_attn/tk_repo 6 | url = https://github.com/chengzeyi/ThunderKittens.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-docstring-first 6 | - id: check-toml 7 | - id: check-yaml 8 | exclude: packaging/.* 9 | args: 10 | - --allow-multiple-documents 11 | - id: mixed-line-ending 12 | args: [--fix=lf] 13 | - id: end-of-file-fixer 14 | 15 | - repo: https://github.com/omnilib/ufmt 16 | rev: v1.3.3 17 | hooks: 18 | - id: ufmt 19 | additional_dependencies: 20 | - black == 22.3.0 21 | - usort == 1.0.2 22 | 23 | - repo: https://github.com/PyCQA/flake8 24 | rev: 7.1.1 25 | hooks: 26 | - id: flake8 27 | args: [--config=setup.cfg] 28 | exclude: fahopper/.* 29 | 30 | # - repo: https://github.com/PyCQA/pydocstyle 31 | # rev: 6.1.1 32 | # hooks: 33 | # - id: pydocstyle 34 | 35 | # - repo: https://github.com/pre-commit/mirrors-clang-format 36 | # rev: v14.0.6 37 | # hooks: 38 | # - id: clang-format 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Custom License for QuantumAttention 2 | 3 | Copyright (c) 2025 Cheng Zeyi 4 | 5 | Permission is hereby granted, free of charge, to any person or organization with an annual revenue of less than 5,000,000 USD, obtaining a copy of this software and associated 
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | 1. This license applies only to individuals or organizations with an annual revenue of less than 5,000,000 USD. For individuals or organizations with an annual revenue equal to or greater than 5,000,000 USD, a separate commercial license is required. 8 | 9 | 2. Schools, educational institutions, and non-profit research purposes are exempt from the revenue restriction and are granted permission to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software under the specified conditions. 10 | 11 | 3. The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 12 | 13 | 4. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 14 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include MANIFEST.in 2 | include LICENSE 3 | include requirements.txt 4 | recursive-include src/quantum_attn/tk_repo/include * 5 | recursive-include tests * 6 | prune */__pycache__ 7 | global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QuantumAttention 2 | 3 | Better (FP8) attention for Hopper 4 | 5 | ## License 6 | 7 | This project is licensed under a custom license. Individuals or organizations with an annual revenue of less than 5,000,000 USD are granted permission to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the software under the specified conditions. Schools, educational institutions, and non-profit research purposes are exempt from the revenue restriction and are granted permission to use the software under the same conditions. 8 | 9 | For those with an annual revenue equal to or greater than 5,000,000 USD, a separate commercial license is required. 10 | 11 | For more details, please refer to the [LICENSE](./LICENSE.txt) file. 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | 3 | name = "quantum_attn" 4 | # version = "0.0.1" # Remove any existing version parameter. 
5 | dynamic = ["version", "dependencies", "optional-dependencies"] 6 | requires-python = ">=3.8" 7 | authors = [ 8 | {name = "Zeyi Cheng", email = "ichengzeyi@gmail.com"}, 9 | ] 10 | description = "Better FP8 attention for Hopper" 11 | 12 | [project.urls] 13 | 14 | Repository = "https://github.com/chengzeyi/QuantumAttention" 15 | 16 | [build-system] 17 | 18 | requires = ["setuptools>=64", "setuptools_scm>=8"] 19 | # build-backend = "setuptools.build_meta" 20 | 21 | [tool.black] 22 | 23 | line-length = 120 24 | target-version = ["py38"] 25 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_format = %(asctime)s %(filename)s:%(lineno)d %(levelname)s %(message)s 3 | log_date_format = %Y-%m-%d %H:%M:%S 4 | log_cli = true 5 | log_level = INFO 6 | addopts = --capture=tee-sys --verbose --color=auto --durations=0 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Example requirement, can be anything that pip knows 2 | # install with `pip install -r requirements.txt`, and make sure that CI does the same 3 | torch 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [metadata] 5 | license_files = LICENSE 6 | 7 | [pep8] 8 | max-line-length = 120 9 | 10 | [flake8] 11 | # note: we ignore all 501s (line too long) anyway as they're taken care of by black 12 | max-line-length = 120 13 | ignore = E731, E203, E402, W503, W504, F821, E501, B, C4, EXE 14 | per-file-ignores = 15 | __init__.py: F401, F403, F405 16 | exclude = venv 17 | 18 | [pydocstyle] 19 | select = D417 # Missing argument descriptions in the docstring 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import importlib.util 4 | 5 | if importlib.util.find_spec("setuptools_scm") is None: 6 | raise ImportError("setuptools-scm is not installed. Install it by `pip3 install setuptools-scm`") 7 | 8 | import os 9 | import subprocess 10 | import sys 11 | from os import path 12 | 13 | from setuptools import find_packages, setup 14 | from setuptools_scm.version import get_local_dirty_tag 15 | 16 | THIS_DIR = path.dirname(path.abspath(__file__)) 17 | 18 | UPDATE_SUBMODULES = os.environ.get("QUANTUM_ATTN_BUILD_UPDATE_SUBMODULES", "1") == "1" 19 | 20 | 21 | def is_git_directory(path="."): 22 | return subprocess.call(["git", "-C", path, "status"], stderr=subprocess.STDOUT, stdout=open(os.devnull, "w")) == 0 23 | 24 | 25 | if UPDATE_SUBMODULES: 26 | if is_git_directory(THIS_DIR): 27 | print("Updating submodules") 28 | subprocess.run( 29 | ["git", "submodule", "update", "--init", "--recursive"], 30 | check=True, 31 | stdout=sys.stdout, 32 | stderr=sys.stderr, 33 | ) 34 | else: 35 | print("Not a git directory. Skipping submodule update.") 36 | 37 | 38 | def my_local_scheme(version): 39 | # The following is used to build release packages. 40 | # Users should never use it. 
41 | local_version = os.getenv("QUANTUM_ATTN_BUILD_LOCAL_VERSION") 42 | if local_version is None: 43 | return get_local_dirty_tag(version) 44 | return f"+{local_version}" 45 | 46 | 47 | def fetch_requirements(): 48 | with open("requirements.txt") as f: 49 | reqs = f.read().strip().split("\n") 50 | return reqs 51 | 52 | 53 | setup( 54 | name="quantum_attn", 55 | use_scm_version={"write_to": path.join("src", "quantum_attn", "_version.py"), "local_scheme": my_local_scheme}, 56 | package_dir={ 57 | "": "src", 58 | }, 59 | packages=find_packages(where="src"), 60 | python_requires=">=3.8", 61 | include_package_data=True, 62 | install_requires=fetch_requirements(), 63 | extras_require={ 64 | # optional dependencies, required by some features 65 | "all": [], 66 | # dev dependencies. Install them by `pip3 install 'quantum-attn[dev]'` 67 | "dev": [ 68 | "pre-commit", 69 | "pytest>=7.0.0,<8.0.0", # https://github.com/pytest-dev/pytest/issues/12273 70 | "expecttest", 71 | # 72 | "pandas", 73 | "llnl-hatchet", 74 | ], 75 | }, 76 | ) 77 | -------------------------------------------------------------------------------- /src/quantum_attn/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__, version_tuple 3 | except ImportError: 4 | __version__ = "unknown version" 5 | version_tuple = (0, 0, "unknown version") 6 | 7 | import torch 8 | 9 | from . import config, nn, ops 10 | from .quantum_attn_interface import ( 11 | attn_func, 12 | attn_func_with_fallback, 13 | dynamically_quantize_fp8, 14 | fp8_attn_func, 15 | fp8_attn_func_with_fallback, 16 | fp8_token_wise_attn_func, 17 | fp8_token_wise_attn_func_with_fallback, 18 | ) 19 | 20 | if torch._dynamo.is_inductor_supported(): 21 | from . 
import inductor 22 | 23 | __all__ = [ 24 | "attn_func", 25 | "attn_func_with_fallback", 26 | "dynamically_quantize_fp8", 27 | "fp8_attn_func", 28 | "fp8_attn_func_with_fallback", 29 | "fp8_token_wise_attn_func", 30 | "fp8_token_wise_attn_func_with_fallback", 31 | ] 32 | -------------------------------------------------------------------------------- /src/quantum_attn/config.py: -------------------------------------------------------------------------------- 1 | import os # noqa: C101 2 | import sys 3 | 4 | # import torch 5 | 6 | _save_config_ignore = { 7 | # workaround: "Can't pickle " 8 | } 9 | 10 | 11 | use_fast_accum = os.getenv("QUANTUM_ATTN_USE_FAST_ACCUM", "1") == "1" 12 | 13 | 14 | class dynamo: 15 | dynamic = os.getenv("QUANTUM_ATTN_DYNAMIC") == "1" 16 | 17 | mode = os.getenv("QUANTUM_ATTN_MODE", "default") 18 | 19 | 20 | class triton: 21 | enable_fast_math = os.getenv("QUANTUM_ATTN_ENABLE_FAST_MATH", "1") == "1" 22 | 23 | allow_reduced_precision_compute = os.getenv("PARA_ATTN_ALLOW_REDUCED_PRECISION_COMPUTE") == "1" 24 | 25 | 26 | class attention: 27 | skip_supported_check = os.getenv("QUANTUM_ATTN_SKIP_SUPPORTED_CHECK") == "1" 28 | force_eager_fallback = os.getenv("QUANTUM_ATTN_FORCE_EAGER_FALLBACK") == "1" 29 | 30 | enable_tk_tma_kernel = os.getenv("QUANTUM_ATTN_ENABLE_TK_TMA_KERNEL", "1") == "1" 31 | enable_triton_tma_kernel = os.getenv("QUANTUM_ATTN_ENABLE_TRITON_TMA_KERNEL") == "1" 32 | 33 | 34 | try: 35 | from torch.utils._config_module import install_config_module 36 | except ImportError: 37 | # torch<2.2.0 38 | from torch._dynamo.config_utils import install_config_module 39 | 40 | # adds patch, save_config, etc 41 | install_config_module(sys.modules[__name__]) 42 | -------------------------------------------------------------------------------- /src/quantum_attn/inductor/__init__.py: -------------------------------------------------------------------------------- 1 | from . import kernels 2 | -------------------------------------------------------------------------------- /src/quantum_attn/inductor/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import attention 2 | -------------------------------------------------------------------------------- /src/quantum_attn/inductor/kernels/attention.py: -------------------------------------------------------------------------------- 1 | import sympy 2 | 3 | import torch 4 | 5 | if True: 6 | # Put this first to avoid circular import 7 | # ImportError: cannot import name 'CppGemmTemplate' from partially initialized module 'torch._inductor.codegen.cpp_gemm_template' (most likely due to a circular import) 8 | from torch._inductor.lowering import register_lowering 9 | 10 | from torch._inductor import config as inductor_config, ir 11 | from torch._inductor.codegen.triton import TritonOverrides 12 | from torch._inductor.kernel.mm_common import mm_args 13 | from torch._inductor.runtime.runtime_utils import next_power_of_2 14 | from torch._inductor.select_algorithm import autotune_select_algorithm, ExternKernelChoice, TritonTemplate 15 | from torch._inductor.utils import ceildiv as cdiv, is_dynamic, use_aten_gemm_kernels, use_max_autotune 16 | from torch._inductor.virtualized import V 17 | 18 | from quantum_attn import config 19 | 20 | from quantum_attn.utils import checks 21 | from ...tk.attention import load_tk_attention_module 22 | 23 | from .mm_common import acc_type, get_device_shared_memory, mm_options, reduce_block_size_for_cuda, require_dense_memory 24 | 25 | aten = torch.ops.aten 26 | quantum_attn_ops = torch.ops.quantum_attn 27 | 28 | 29 | def tk_attention_forward_kernel( 30 | query, 31 | key, 32 | value, 33 | attn_mask=None, 34 | dropout_p=0.0, 35 | is_causal=False, 36 | *, 37 | scale=None, 38 | ): 39 | assert attn_mask is None 40 | assert dropout_p == 0.0 41 | assert scale is None 42 | 43 | module = load_tk_attention_module(dtype=value.dtype) 44 | out = module.attention_forward(query, key, value, is_causal)[0] 45 | return out 46 | 47 | 48 | tk_attention_forward = ExternKernelChoice( 49 | tk_attention_forward_kernel, 50 | name="quantum_attn_tk_attention_forward", 51 | has_out_variant=False, 52 | ) 53 | 54 | 55 | def tk_fp8_attention_forward_kernel( 56 | query, 57 | key, 58 | value, 59 | scale_q, 60 | scale_k, 61 | attn_mask=None, 62 | dropout_p=0.0, 63 | is_causal=False, 64 | *, 65 | scale=None, 66 | ): 67 | assert attn_mask is None 68 | assert dropout_p == 0.0 69 | assert scale is None 70 | 71 | module = load_tk_attention_module(dtype=value.dtype, is_fp8=True) 72 | out = module.attention_forward(query, key, value, scale_q, scale_k, is_causal)[0] 73 | return out 74 | 75 | 76 | tk_fp8_attention_forward = ExternKernelChoice( 77 | tk_fp8_attention_forward_kernel, 78 | name="quantum_attn_tk_fp8_attention_forward", 79 | has_out_variant=False, 80 | ) 81 | 82 | 83 | def aten_attention_forward_kernel( 84 | query, 85 | key, 86 | value, 87 | attn_mask=None, 88 | dropout_p=0.0, 89 | is_causal=False, 90 | *, 91 | scale=None, 92 | ): 93 | return quantum_attn_ops.attention_forward( 94 | query, 95 | key, 96 | value, 97 | attn_mask=attn_mask, 98 | dropout_p=dropout_p, 99 | is_causal=is_causal, 100 | scale=scale, 101 | ) 102 | 103 | 104 | aten_attention_forward = ExternKernelChoice( 105 | aten_attention_forward_kernel, 106 | name="quantum_attn_aten_attention_forward", 107 | has_out_variant=False, 108 | ) 109 | 110 | 111 | def aten_fp8_attention_forward_kernel( 112 | query, 113 | key, 114 | value, 115 | scale_q, 116 | scale_k, 117 | attn_mask=None, 118 | dropout_p=0.0, 119 | is_causal=False, 120 | *, 121 | scale=None, 122 | ): 123 | return quantum_attn_ops.fp8_attention_forward( 124 | query, 125 | key, 126 
| value, 127 | scale_q, 128 | scale_k, 129 | attn_mask=attn_mask, 130 | dropout_p=dropout_p, 131 | is_causal=is_causal, 132 | scale=scale, 133 | ) 134 | 135 | 136 | aten_fp8_attention_forward = ExternKernelChoice( 137 | aten_fp8_attention_forward_kernel, 138 | name="quantum_attn_aten_fp8_attention_forward", 139 | has_out_variant=False, 140 | ) 141 | 142 | 143 | def persistent_attention_grid(b, h, s, d, meta): 144 | return (min(meta["NUM_SMS"], cdiv(s, meta["BLOCK_M"]) * b * h), 1, 1) 145 | 146 | 147 | attention_forward_template = TritonTemplate( 148 | name="quantum_attn_attention_forward", 149 | grid=persistent_attention_grid, 150 | source=rf""" 151 | import triton 152 | import triton.language as tl 153 | 154 | @triton.jit 155 | def num_threads(): 156 | return tl.extra.cuda.num_threads() 157 | 158 | @triton.jit 159 | def maximum(a, b): 160 | {{% if ENABLE_FAST_MATH %}} 161 | {{% if USE_FP16_COMPUTE %}} 162 | if a.numel % (num_threads() * 2) == 0: 163 | x = tl.inline_asm_elementwise( 164 | "max.ftz.f16x2 $0, $1, $2;", "=r, r, r", [a, b], dtype=tl.float16, is_pure=True, pack=2, 165 | ) 166 | else: 167 | x = tl.inline_asm_elementwise( 168 | "max.ftz.f16 $0, $1, $2;", "=h, h, h", [a, b], dtype=tl.float16, is_pure=True, pack=1, 169 | ) 170 | {{% else %}} 171 | x = tl.inline_asm_elementwise( 172 | "max.ftz.f32 $0, $1, $2;", "=f, f, f", [a, b], dtype=tl.float32, is_pure=True, pack=1, 173 | ) 174 | {{% endif %}} 175 | {{% else %}} 176 | x = a if a > b else b 177 | {{% endif %}} 178 | return x 179 | 180 | @triton.jit 181 | def maximum_(a, b): 182 | {{% if ENABLE_FAST_MATH %}} 183 | {{% if USE_FP16_COMPUTE %}} 184 | if a.numel % (num_threads() * 2) == 0: 185 | x = tl.inline_asm_elementwise( 186 | "max.ftz.f16x2 $0, $1, $2;", "=r, r, r", [a, b], dtype=tl.float16, is_pure=True, pack=2, 187 | ) 188 | else: 189 | x = tl.inline_asm_elementwise( 190 | "max.ftz.f16 $0, $1, $2;", "=h, h, h", [a, b], dtype=tl.float16, is_pure=True, pack=1, 191 | ) 192 | {{% else %}} 193 | x = tl.inline_asm_elementwise( 194 | "max.ftz.f32 $0, $1, $2;", "=f, f, f", [a, b], dtype=tl.float32, is_pure=True, pack=1, 195 | ) 196 | {{% endif %}} 197 | {{% else %}} 198 | x = tl.maximum(a, b) 199 | {{% endif %}} 200 | return x 201 | 202 | @triton.jit 203 | def add(a, b): 204 | {{% if ENABLE_FAST_MATH %}} 205 | {{% if USE_FP16_COMPUTE %}} 206 | if a.numel % (num_threads() * 2) == 0: 207 | x = tl.inline_asm_elementwise( 208 | "add.ftz.f16x2 $0, $1, $2;", "=r, r, r", [a, b], dtype=tl.float16, is_pure=True, pack=2, 209 | ) 210 | else: 211 | x = tl.inline_asm_elementwise( 212 | "add.ftz.f16 $0, $1, $2;", "=h, h, h", [a, b], dtype=tl.float16, is_pure=True, pack=1, 213 | ) 214 | {{% else %}} 215 | x = tl.inline_asm_elementwise( 216 | "add.ftz.f32 $0, $1, $2;", "=f, f, f", [a, b], dtype=tl.float32, is_pure=True, pack=1, 217 | ) 218 | {{% endif %}} 219 | {{% else %}} 220 | x = a + b 221 | {{% endif %}} 222 | return x 223 | 224 | @triton.jit 225 | def sub(a, b): 226 | {{% if ENABLE_FAST_MATH %}} 227 | {{% if USE_FP16_COMPUTE %}} 228 | if a.numel % (num_threads() * 2) == 0: 229 | x = tl.inline_asm_elementwise( 230 | "sub.ftz.f16x2 $0, $1, $2;", "=r, r, r", [a, b], dtype=tl.float16, is_pure=True, pack=2, 231 | ) 232 | else: 233 | x = tl.inline_asm_elementwise( 234 | "sub.ftz.f16 $0, $1, $2;", "=h, h, h", [a, b], dtype=tl.float16, is_pure=True, pack=1, 235 | ) 236 | {{% else %}} 237 | x = tl.inline_asm_elementwise( 238 | "sub.ftz.f32 $0, $1, $2;", "=f, f, f", [a, b], dtype=tl.float32, is_pure=True, pack=1, 239 | ) 240 | {{% endif %}} 241 | {{% 
else %}} 242 | x = a - b 243 | {{% endif %}} 244 | return x 245 | 246 | @triton.jit 247 | def mul(a, b): 248 | {{% if ENABLE_FAST_MATH %}} 249 | {{% if USE_FP16_COMPUTE %}} 250 | if a.numel % (num_threads() * 2) == 0: 251 | x = tl.inline_asm_elementwise( 252 | "mul.ftz.f16x2 $0, $1, $2;", "=r, r, r", [a, b], dtype=tl.float16, is_pure=True, pack=2, 253 | ) 254 | else: 255 | x = tl.inline_asm_elementwise( 256 | "mul.ftz.f16 $0, $1, $2;", "=h, h, h", [a, b], dtype=tl.float16, is_pure=True, pack=1, 257 | ) 258 | {{% else %}} 259 | x = tl.inline_asm_elementwise( 260 | "mul.ftz.f32 $0, $1, $2;", "=f, f, f", [a, b], dtype=tl.float32, is_pure=True, pack=1, 261 | ) 262 | {{% endif %}} 263 | {{% else %}} 264 | x = a * b 265 | {{% endif %}} 266 | return x 267 | 268 | @triton.jit 269 | def div(a, b): 270 | a_fp32 = a.to(tl.float32) 271 | b_fp32 = b.to(tl.float32) 272 | {{% if ENABLE_FAST_MATH %}} 273 | x = tl.inline_asm_elementwise( 274 | "div.approx.ftz.f32 $0, $1, $2;", "=f, f, f", [a_fp32, b_fp32], dtype=tl.float32, is_pure=True, pack=1, 275 | ) 276 | {{% else %}} 277 | x = a_fp32 / b_fp32 278 | {{% endif %}} 279 | x = x.to(a.dtype) 280 | return x 281 | 282 | @triton.jit 283 | def fma(a, b, c): 284 | {{% if ENABLE_FAST_MATH %}} 285 | {{% if USE_FP16_COMPUTE %}} 286 | if a.numel % (num_threads() * 2) == 0: 287 | x = tl.inline_asm_elementwise( 288 | "fma.rn.ftz.f16x2 $0, $1, $2, $3;", "=r, r, r, r", [a, b, c], dtype=tl.float16, is_pure=True, pack=2, 289 | ) 290 | else: 291 | x = tl.inline_asm_elementwise( 292 | "fma.rn.ftz.f16 $0, $1, $2, $3;", "=h, h, h, h", [a, b, c], dtype=tl.float16, is_pure=True, pack=1, 293 | ) 294 | {{% else %}} 295 | x = tl.inline_asm_elementwise( 296 | "fma.rn.ftz.f32 $0, $1, $2, $3;", "=f, f, f, f", [a, b, c], dtype=tl.float32, is_pure=True, pack=1, 297 | ) 298 | {{% endif %}} 299 | {{% else %}} 300 | x = a * b + c 301 | {{% endif %}} 302 | return x 303 | 304 | @triton.jit 305 | def ex2(x): 306 | {{% if ENABLE_FAST_MATH %}} 307 | {{% if USE_FP16_COMPUTE %}} 308 | if x.numel % (num_threads() * 2) == 0: 309 | y = tl.inline_asm_elementwise( 310 | "ex2.approx.f16x2 $0, $1;", "=r, r", [x], dtype=tl.float16, is_pure=True, pack=2, 311 | ) 312 | else: 313 | y = tl.inline_asm_elementwise( 314 | "ex2.approx.f16 $0, $1;", "=h, h", [x], dtype=tl.float16, is_pure=True, pack=1, 315 | ) 316 | {{% else %}} 317 | y = tl.inline_asm_elementwise( 318 | "ex2.approx.ftz.f32 $0, $1;", "=f, f", [x], dtype=tl.float32, is_pure=True, pack=1, 319 | ) 320 | {{% endif %}} 321 | {{% else %}} 322 | y = {TritonOverrides.exp2("x.to(tl.float32)")}.to(x.dtype) 323 | {{% endif %}} 324 | return y 325 | 326 | @triton.jit 327 | def dot(a, b, acc): 328 | {{% if USE_FAST_ACCUM %}} 329 | acc = tl.dot(a, b, acc, out_dtype=acc.dtype) 330 | {{% else %}} 331 | acc += tl.dot(a, b, out_dtype=acc.dtype) 332 | {{% endif %}} 333 | return acc 334 | 335 | @triton.jit 336 | def _attn_fwd_inner( 337 | {{% for i in range(TILES) %}} 338 | acc_{{{{i}}}}, 339 | q_{{{{i}}}}, 340 | {{% endfor %}} 341 | q_scale, 342 | l_i, m_i, 343 | K_desc_ptr, V_desc_ptr, 344 | K_scale_block_ptr, 345 | start_m, # 346 | v_dtype, 347 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 348 | STAGE: tl.constexpr, 349 | N_CTX_Q, N_CTX_K, 350 | TILES: tl.constexpr, 351 | EVEN_N: tl.constexpr, 352 | QK_ACC_TYPE, 353 | ): 354 | # range of values handled by this stage 355 | if STAGE == 1: 356 | if BLOCK_N <= BLOCK_M: 357 | lo, hi = 0, start_m * BLOCK_M 358 | else: 359 | lo, hi = 0, start_m // (BLOCK_N // BLOCK_M) * BLOCK_N 360 | elif STAGE 
== 2: 361 | if BLOCK_N <= BLOCK_M: 362 | lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M 363 | lo = tl.multiple_of(lo, BLOCK_M) 364 | else: 365 | lo, hi = start_m // (BLOCK_N // BLOCK_M) * BLOCK_N, (start_m + 1) * BLOCK_M 366 | lo = tl.multiple_of(lo, BLOCK_N) 367 | # causal = False 368 | else: 369 | lo, hi = 0, N_CTX_K 370 | 371 | {{% if IS_QUANTIZED %}} 372 | K_scale_block_ptr = tl.advance(K_scale_block_ptr, (lo,)) 373 | {{% endif %}} 374 | 375 | # loop over k, v and update accumulator 376 | for start_n in range(lo, hi, BLOCK_N): 377 | start_n = tl.multiple_of(start_n, BLOCK_N) 378 | 379 | # -- compute qk ---- 380 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=QK_ACC_TYPE) 381 | {{% for i in range(TILES) %}} 382 | k = tl._experimental_descriptor_load( 383 | K_desc_ptr, (start_n, BLOCK_K * {{{{i}}}}), (BLOCK_N, BLOCK_K), q_0.dtype 384 | ) 385 | v_{{{{i}}}} = tl._experimental_descriptor_load( 386 | V_desc_ptr, (start_n, BLOCK_K * {{{{i}}}}), (BLOCK_N, BLOCK_K), v_dtype 387 | ) 388 | qk = dot(q_{{{{i}}}}, k.T, qk) 389 | {{% endfor %}} 390 | 391 | {{% if IS_QUANTIZED %}} 392 | k_scale = tl.load(K_scale_block_ptr, boundary_check=(0,)).to(tl.float32) 393 | K_scale_block_ptr = tl.advance(K_scale_block_ptr, (BLOCK_N,)) 394 | 395 | qk = qk * q_scale[:, None] * k_scale[None, :] 396 | {{% if USE_FP16_COMPUTE %}} 397 | qk = qk.to(tl.float16) 398 | {{% endif %}} 399 | {{% else %}} 400 | qk = mul(qk, tl.full([1], {{{{SM_SCALE}}}} * 1.44269504, dtype=qk.dtype)) 401 | {{% endif %}} 402 | 403 | if EVEN_N: 404 | if STAGE == 2: 405 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 406 | offs_n = tl.arange(0, BLOCK_N) 407 | mask = offs_m[:, None] >= (start_n + offs_n[None, :]) 408 | qk = tl.where(mask, qk, tl.full([1], -float("inf"), dtype=qk.dtype)) 409 | else: 410 | offs_n = tl.arange(0, BLOCK_N) 411 | mask = (start_n + offs_n[None, :]) < N_CTX_K 412 | if STAGE == 2: 413 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 414 | mask = mask & (offs_m[:, None] >= (start_n + offs_n[None, :])) 415 | qk = tl.where(mask, qk, tl.full([1], -float("inf"), dtype=qk.dtype)) 416 | m_ij = maximum_(m_i, tl.reduce(qk, 1, maximum)) 417 | qk = sub(qk, m_ij[:, None]) 418 | 419 | {{% if FAST_SOFTMAX %}} 420 | 421 | numerator = ex2(qk) 422 | 423 | denominator = tl.reduce(numerator, 1, add) 424 | p = div(numerator, denominator[:, None]) 425 | 426 | {{% else %}} 427 | 428 | m_i_m_ij = sub(m_i, m_ij) 429 | alpha = ex2(m_i_m_ij) 430 | 431 | {{% for i in range(TILES) %}} 432 | acc_{{{{i}}}} = mul(acc_{{{{i}}}}, alpha[:, None]) 433 | {{% endfor %}} 434 | 435 | p = ex2(qk) 436 | 437 | l_ij = tl.reduce(p, 1, add) 438 | 439 | # -- update m_i and l_i 440 | l_i = fma(l_i, alpha, l_ij) 441 | m_i = m_ij 442 | 443 | {{% endif %}} 444 | 445 | p = p.to(v_0.dtype) 446 | 447 | # -- update output accumulator -- 448 | {{% for i in range(TILES) %}} 449 | acc_{{{{i}}}} = dot(p, v_{{{{i}}}}, acc_{{{{i}}}}) 450 | {{% endfor %}} 451 | return ( 452 | {{% for i in range(TILES) %}} 453 | acc_{{{{i}}}}, 454 | {{% endfor %}} 455 | l_i, m_i, 456 | ) 457 | 458 | {{% if IS_QUANTIZED %}} 459 | {{{{def_kernel("Q", "K", "V", "Q_scale", "K_scale")}}}} 460 | {{% else %}} 461 | {{{{def_kernel("Q", "K", "V")}}}} 462 | {{% endif %}} 463 | Z = {{{{size("Q", 0)}}}} 464 | H = {{{{size("Q", 1)}}}} 465 | N_CTX_Q = {{{{size("Q", 2)}}}} 466 | N_CTX_K = {{{{size("K", 2)}}}} 467 | D = {{{{size("Q", 3)}}}} 468 | 469 | stride_qz = {{{{stride("Q", 0)}}}} 470 | stride_qh = {{{{stride("Q", 1)}}}} 471 | stride_qm = {{{{stride("Q", 2)}}}} 472 | stride_qk = {{{{stride("Q", 
3)}}}} 473 | 474 | stride_kz = {{{{stride("K", 0)}}}} 475 | stride_kh = {{{{stride("K", 1)}}}} 476 | stride_kn = {{{{stride("K", 2)}}}} 477 | stride_kk = {{{{stride("K", 3)}}}} 478 | 479 | stride_vz = {{{{stride("V", 0)}}}} 480 | stride_vh = {{{{stride("V", 1)}}}} 481 | stride_vk = {{{{stride("V", 2)}}}} 482 | stride_vn = {{{{stride("V", 3)}}}} 483 | 484 | {{% if IS_QUANTIZED %}} 485 | stride_q_scale_z = {{{{stride("Q_scale", 0)}}}} 486 | stride_q_scale_h = {{{{stride("Q_scale", 1)}}}} 487 | stride_q_scale_m = {{{{stride("Q_scale", 2)}}}} 488 | 489 | stride_k_scale_z = {{{{stride("K_scale", 0)}}}} 490 | stride_k_scale_h = {{{{stride("K_scale", 1)}}}} 491 | stride_k_scale_m = {{{{stride("K_scale", 2)}}}} 492 | {{% endif %}} 493 | 494 | start_pid = tl.program_id(0) 495 | 496 | workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE 497 | K_desc_ptr = workspace_base 498 | V_desc_ptr = workspace_base + TMA_SIZE 499 | 500 | num_programs_m = (N_CTX_Q + BLOCK_M - 1) // BLOCK_M 501 | for pid in range(start_pid, num_programs_m * Z * H, NUM_SMS): 502 | start_m = pid % num_programs_m 503 | if STAGE & 2: 504 | if start_m // num_programs_m % 2 != 0: 505 | start_m = start_m * (start_m // num_programs_m + 1) - start_m % num_programs_m 506 | off_hz = pid // num_programs_m 507 | off_z = off_hz // H 508 | off_h = off_hz % H 509 | # off_z = off_z.to(tl.int64) 510 | # off_h = off_h.to(tl.int64) 511 | 512 | q_offset = off_z * stride_qz + off_h * stride_qh 513 | k_offset = off_z * stride_kz + off_h * stride_kh 514 | v_offset = off_z * stride_vz + off_h * stride_vh 515 | # o_offset = off_z * stride_qz + off_h * stride_qh 516 | 517 | # block pointers 518 | Q_block_ptr = tl.make_block_ptr( 519 | base=Q + q_offset, 520 | shape=(N_CTX_Q, D), 521 | strides=(stride_qm, stride_qk), 522 | offsets=(start_m * BLOCK_M, 0), 523 | block_shape=(BLOCK_M, BLOCK_K), 524 | order=(1, 0), 525 | ) 526 | 527 | {{% if IS_QUANTIZED %}} 528 | q_scale_offset = off_z * stride_q_scale_z + off_h * stride_q_scale_h 529 | k_scale_offset = off_z * stride_k_scale_z + off_h * stride_k_scale_h 530 | 531 | Q_scale_block_ptr = tl.make_block_ptr( 532 | base=Q_scale + q_scale_offset, 533 | shape=(N_CTX_Q,), 534 | strides=(stride_q_scale_m,), 535 | offsets=(start_m * BLOCK_M,), 536 | block_shape=(BLOCK_M,), 537 | order=(0,), 538 | ) 539 | K_scale_block_ptr = tl.make_block_ptr( 540 | base=K_scale + k_scale_offset, 541 | shape=(N_CTX_K,), 542 | strides=(stride_k_scale_m,), 543 | offsets=(0,), 544 | block_shape=(BLOCK_N,), 545 | order=(0,), 546 | ) 547 | {{% else %}} 548 | K_scale_block_ptr = None 549 | {{% endif %}} 550 | 551 | if start_m < NUM_SMS: 552 | triton.language.extra.cuda.experimental_device_tensormap_create2d( 553 | desc_ptr=K_desc_ptr, 554 | global_address=K + k_offset, 555 | load_size=[BLOCK_N, BLOCK_K], 556 | global_size=[N_CTX_K, D], 557 | element_ty=K.dtype.element_ty, 558 | ) 559 | triton.language.extra.cuda.experimental_device_tensormap_create2d( 560 | desc_ptr=V_desc_ptr, 561 | global_address=V + v_offset, 562 | load_size=[BLOCK_N, BLOCK_K], 563 | global_size=[N_CTX_K, D], 564 | element_ty=V.dtype.element_ty, 565 | ) 566 | 567 | tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(K_desc_ptr) 568 | tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(V_desc_ptr) 569 | 570 | {{% if IS_QUANTIZED %}} 571 | q_scale = tl.load(Q_scale_block_ptr, boundary_check=(0,)).to(tl.float32) 572 | q_scale = q_scale * ({{{{SM_SCALE * 1.44269504}}}}) 573 | {{% else %}} 574 | q_scale = None 575 | {{% endif %}} 576 | 577 | {{% for i in range(TILES) %}} 
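    # Q is consumed in TILES slices of BLOCK_K columns along the head dimension
    # (TILES = ceildiv(head_dim, BLOCK_K)); each q_<i> is one BLOCK_M x BLOCK_K tile of Q.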
578 | q_{{{{i}}}} = tl.load(Q_block_ptr, boundary_check=(0,)) 579 | {{% if i + 1 < TILES %}} 580 | Q_block_ptr = tl.advance(Q_block_ptr, (0, BLOCK_K)) 581 | {{% endif %}} 582 | {{% endfor %}} 583 | 584 | # initialize pointer to m and l 585 | {{% for i in range(TILES) %}} 586 | acc_{{{{i}}}} = tl.zeros([BLOCK_M, BLOCK_K], dtype=ACC_TYPE) 587 | {{% endfor %}} 588 | m_i = tl.full([BLOCK_M], -float("inf"), dtype=acc_0.dtype) 589 | l_i = tl.full([BLOCK_M], 1.0, dtype=acc_0.dtype) 590 | # load scales 591 | # stage 1: off-band 592 | # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE 593 | # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE 594 | if STAGE & 1: 595 | ( 596 | {{% for i in range(TILES) %}} 597 | acc_{{{{i}}}}, 598 | {{% endfor %}} 599 | l_i, m_i, 600 | ) = _attn_fwd_inner( 601 | {{% for i in range(TILES) %}} 602 | acc_{{{{i}}}}, 603 | q_{{{{i}}}}, 604 | {{% endfor %}} 605 | q_scale, 606 | l_i, m_i, K_desc_ptr, V_desc_ptr, 607 | K_scale_block_ptr, 608 | start_m, 609 | V.dtype.element_ty, 610 | BLOCK_M, BLOCK_N, BLOCK_K, 611 | 4 - STAGE, N_CTX_Q, N_CTX_K, 612 | TILES, 613 | EVEN_N, 614 | QK_ACC_TYPE, 615 | ) 616 | # stage 2: on-band 617 | if STAGE & 2: 618 | # barrier makes it easier for compielr to schedule the 619 | # two loops independently 620 | tl.debug_barrier() 621 | ( 622 | {{% for i in range(TILES) %}} 623 | acc_{{{{i}}}}, 624 | {{% endfor %}} 625 | l_i, m_i, 626 | ) = _attn_fwd_inner( 627 | {{% for i in range(TILES) %}} 628 | acc_{{{{i}}}}, 629 | q_{{{{i}}}}, 630 | {{% endfor %}} 631 | q_scale, 632 | l_i, m_i, K_desc_ptr, V_desc_ptr, 633 | K_scale_block_ptr, 634 | start_m, 635 | V.dtype.element_ty, 636 | BLOCK_M, BLOCK_N, BLOCK_K, 637 | 2, N_CTX_Q, N_CTX_K, 638 | TILES, 639 | EVEN_N, 640 | QK_ACC_TYPE, 641 | ) 642 | 643 | # epilogue 644 | start_m = pid % num_programs_m 645 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 646 | off_hz = pid // num_programs_m 647 | off_z = off_hz // H 648 | off_h = off_hz % H 649 | # offs_m = offs_m.to(tl.int64) 650 | # off_z = off_z.to(tl.int64) 651 | # off_h = off_h.to(tl.int64) 652 | 653 | idx_m = offs_m[None, None, :, None] 654 | idx_z = tl.full([1, 1, 1, 1], off_z, dtype=idx_m.dtype) 655 | idx_h = tl.full([1, 1, 1, 1], off_h, dtype=idx_m.dtype) 656 | 657 | {{% for i in range(TILES) %}} 658 | acc_{{{{i}}}} = div(acc_{{{{i}}}}, l_i[:, None]) 659 | acc_{{{{i}}}} = acc_{{{{i}}}}[None, None, :, :] 660 | idx_d = tl.arange({{{{i}}}} * BLOCK_K, {{{{i + 1}}}} * BLOCK_K)[None, None, None, :] 661 | mask = (idx_z < Z) & (idx_h < H) & (idx_m < N_CTX_Q) & (idx_d < D) 662 | acc = acc_{{{{i}}}} 663 | {{% if i == 0 %}} 664 | {{{{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask", indent_width=8)}}}} 665 | {{% else %}} 666 | 667 | {{% endif %}} 668 | {{% endfor %}} 669 | """, 670 | ) 671 | 672 | 673 | def attention_heuristic_configs( 674 | head_dim, 675 | B, 676 | H, 677 | N_CTX_Q, 678 | N_CTX_K, 679 | is_causal=False, 680 | layout=None, 681 | optimize_block_size=True, 682 | ): 683 | import triton 684 | 685 | # https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_fwd_launch_template.h 686 | # is_sm8x = torch.cuda.get_device_capability()[0] == 8 687 | 688 | if not checks.cuda_capability_compare("ge", 9, 0, device=layout.device): 689 | confs = "" 690 | else: 691 | if head_dim == 64: 692 | confs = "128.128.64.8.4 128.128.64.8.3 128.128.64.8.2 256.128.64.8.4 256.128.64.8.3 256.128.64.8.2" 693 | elif head_dim == 128: 694 | confs = "128.128.128.8.3 128.128.128.8.2 
128.128.64.8.3 128.128.64.8.2" 695 | elif head_dim == 256: 696 | confs = "128.64.128.8.3 128.64.64.8.3 128.64.128.8.2 128.64.64.8.2" 697 | 698 | confs = [[int(x) for x in c.split(".")] for c in confs.split() if c] 699 | 700 | BLOCK_DMODEL = max(next_power_of_2(head_dim), 16) 701 | 702 | is_ge_sm90 = layout.device.type == "cuda" and checks.cuda_capability_compare("ge", 9, 0, device=layout.device) 703 | 704 | configs = [] 705 | picked_confs = set() 706 | for c in confs: 707 | BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages = c 708 | 709 | if optimize_block_size and checks.torch_version_compare("ge", "2.2.0"): 710 | n_ctx_q_hint = V.graph.sizevars.size_hint(N_CTX_Q, fallback=inductor_config.unbacked_symint_fallback) 711 | if n_ctx_q_hint <= 32: 712 | BLOCK_M = min(BLOCK_M, 32) 713 | elif n_ctx_q_hint <= 64: 714 | BLOCK_M = min(BLOCK_M, 64) 715 | # elif n_ctx_q_hint <= 96: 716 | # BLOCK_M = min(BLOCK_M, 32) 717 | 718 | n_ctx_k_hint = V.graph.sizevars.size_hint(N_CTX_K, fallback=inductor_config.unbacked_symint_fallback) 719 | if n_ctx_k_hint <= 32: 720 | BLOCK_N = min(BLOCK_N, 32) 721 | elif n_ctx_k_hint <= 64: 722 | BLOCK_N = min(BLOCK_N, 64) 723 | # elif n_ctx_k_hint <= 96: 724 | # BLOCK_N = min(BLOCK_N, 32) 725 | 726 | if layout.device.type == "cuda": 727 | b_hint = V.graph.sizevars.size_hint(B, fallback=inductor_config.unbacked_symint_fallback) 728 | h_hint = V.graph.sizevars.size_hint(H, fallback=inductor_config.unbacked_symint_fallback) 729 | BLOCK_M, _ = reduce_block_size_for_cuda( 730 | BLOCK_M, 1, n_ctx_q_hint, 1, device=layout.device, b=b_hint * h_hint 731 | ) 732 | 733 | while BLOCK_DMODEL % BLOCK_K != 0: 734 | BLOCK_K //= 2 735 | 736 | if BLOCK_M < 128 and is_ge_sm90 and not checks.triton_version_compare("ge", "3.1.0"): 737 | # https://github.com/triton-lang/triton/pull/4492 738 | # Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && "mma -> mma layout conversion is only supported on Ampere"' failed. 
739 | num_warps = min(num_warps, 4) 740 | 741 | key = (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) 742 | if key in picked_confs: 743 | continue 744 | picked_confs.add(key) 745 | 746 | configs.append( 747 | triton.Config( 748 | { 749 | "BLOCK_M": BLOCK_M, 750 | "BLOCK_N": BLOCK_N, 751 | "BLOCK_K": BLOCK_K, 752 | "BLOCK_DMODEL": BLOCK_DMODEL, 753 | }, 754 | num_warps=num_warps, 755 | num_stages=num_stages, 756 | ) 757 | ) 758 | 759 | if not use_max_autotune(): 760 | configs = configs[:1] 761 | 762 | return configs 763 | 764 | 765 | def early_attention_config_prune(configs, query, key, value): 766 | query_dtype = query.get_dtype() 767 | key_dtype = key.get_dtype() 768 | value_dtype = value.get_dtype() 769 | device = query.get_device() 770 | 771 | assert device.type == "cuda" 772 | 773 | max_shared_memory = get_device_shared_memory(device.index) 774 | 775 | filtered_configs = [] 776 | for c in configs: 777 | kw = c.kwargs 778 | BLOCK_M, BLOCK_N, BLOCK_DMODEL = kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_DMODEL"] 779 | num_stages = c.num_stages 780 | required_shared_memory = ( 781 | BLOCK_N * num_stages * (key_dtype.itemsize + value_dtype.itemsize) + BLOCK_M * query_dtype.itemsize 782 | ) * BLOCK_DMODEL 783 | if required_shared_memory <= max_shared_memory: 784 | filtered_configs.append(c) 785 | return filtered_configs 786 | 787 | 788 | def get_attention_layout( 789 | query, 790 | dtype, 791 | ): 792 | return ir.FixedLayout( 793 | query.get_device(), 794 | dtype, 795 | ir.convert_shape_to_inductor(query.get_size()), 796 | ) 797 | 798 | 799 | def generate_attention_template_choices( 800 | choices, 801 | query, 802 | key, 803 | value, 804 | scale_q=None, 805 | scale_k=None, 806 | attn_mask=None, 807 | dropout_p=0.0, 808 | is_causal=False, 809 | *, 810 | scale=None, 811 | layout2=None, 812 | enable_max_autotune=False, 813 | ): 814 | from torch._inductor.utils import get_num_sms, get_tma_workspace_arg, TMA_DESCRIPTOR_SIZE 815 | 816 | query_size, key_size = (x.get_size() for x in (query, key)) 817 | Lq = query_size[-1] 818 | Lq = V.graph.sizevars.evaluate_static_shape(Lq) 819 | N_CTX_Q, N_CTX_K = query_size[2], key_size[2] 820 | B, H = query_size[0], query_size[1] 821 | 822 | if scale is None: 823 | scale = float(1.0 / (Lq**0.5)) 824 | 825 | key_t = ir.PermuteView.create(key, [0, 1, 3, 2]) 826 | m1, n1, k1, layout1, mat1, mat2 = mm_args(query, key_t) 827 | stage = 3 if is_causal else 1 828 | 829 | args = [query, key, value] 830 | if scale_q is not None: 831 | args += [scale_q, scale_k] 832 | if attn_mask is not None: 833 | args.append(attn_mask) 834 | 835 | dynamic = is_dynamic(*args) 836 | 837 | triton_configs = [] 838 | heuristic_configs = attention_heuristic_configs( 839 | Lq, 840 | B, 841 | H, 842 | N_CTX_Q, 843 | N_CTX_K, 844 | is_causal=is_causal, 845 | layout=layout2, 846 | optimize_block_size=not dynamic or enable_max_autotune, 847 | ) 848 | triton_configs.extend(heuristic_configs) 849 | 850 | triton_configs = early_attention_config_prune(triton_configs, query, key, value) 851 | 852 | for fa_config in triton_configs: 853 | mm_options_ = mm_options(fa_config, m1, n1, k1, layout1) 854 | mm_options_["ACC_TYPE"] = acc_type(value.get_dtype()) 855 | if scale_q is None: 856 | mm_options_["QK_ACC_TYPE"] = acc_type(value.get_dtype()) 857 | else: 858 | mm_options_["QK_ACC_TYPE"] = "tl.float32" 859 | fast_softmax = not dynamic and mm_options_["BLOCK_N"] >= N_CTX_K 860 | even_n_symbolic = ( 861 | # it isn't worth guarding on this 862 | sympy.gcd(N_CTX_K, mm_options_["BLOCK_N"]) 863 | == 
mm_options_["BLOCK_N"] 864 | ) 865 | 866 | attention_forward_template.maybe_append_choice( 867 | choices, 868 | input_nodes=args, 869 | layout=layout2, 870 | workspace_arg=get_tma_workspace_arg( 871 | num_tma_descriptors=2, 872 | device=query.get_device(), 873 | ), 874 | SM_SCALE=float(1.0 / (Lq**0.5)) if scale is None else scale, 875 | STAGE=stage, 876 | TILES=cdiv(Lq, mm_options_["BLOCK_K"]), 877 | EVEN_N=even_n_symbolic, 878 | NUM_STAGES=fa_config.num_stages, 879 | FAST_SOFTMAX=fast_softmax, 880 | USE_FP16_COMPUTE=mm_options_["ACC_TYPE"] == "tl.float16", 881 | TMA_SIZE=TMA_DESCRIPTOR_SIZE, 882 | NUM_SMS=get_num_sms(), 883 | IS_QUANTIZED=scale_q is not None, 884 | **mm_options_, 885 | ) 886 | 887 | 888 | def tuned_attention_forward( 889 | query, 890 | key, 891 | value, 892 | attn_mask=None, 893 | dropout_p=0.0, 894 | is_causal=False, 895 | *, 896 | scale=None, 897 | layout=None, 898 | scale_q=None, 899 | scale_k=None, 900 | ): 901 | assert (scale_q is None) == (scale_k is None) 902 | 903 | k1 = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) 904 | n2 = V.graph.sizevars.evaluate_static_shape(value.get_size()[-1]) 905 | 906 | use_tk_tma_kernel = ( 907 | config.attention.enable_tk_tma_kernel 908 | and query.get_dtype() in (torch.float16, torch.bfloat16, torch.float8_e4m3fn) 909 | and checks.cuda_capability_compare("ge", 9, 0) 910 | and attn_mask is None 911 | and dropout_p == 0.0 912 | and scale is None 913 | and k1 == n2 914 | and k1 in (64, 128, 256) 915 | and ( 916 | query.get_dtype() != torch.float8_e4m3fn 917 | or (scale_q is not None and len(scale_q.get_size()) + 2 == len(query.get_size())) 918 | ) 919 | ) 920 | 921 | use_triton_tma_kernel = ( 922 | config.attention.enable_triton_tma_kernel 923 | and query.get_dtype() in (torch.float16, torch.bfloat16, torch.float8_e4m3fn) 924 | and checks.torch_version_compare("ge", "2.7.0") 925 | and checks.has_triton_tma_support() 926 | and attn_mask is None 927 | and dropout_p == 0.0 928 | and k1 == n2 929 | and k1 in (64, 128, 256) 930 | and ( 931 | query.get_dtype() != torch.float8_e4m3fn 932 | or (scale_q is not None and len(scale_q.get_size()) + 1 == len(query.get_size())) 933 | ) 934 | ) 935 | 936 | use_aten_attention_kernel = use_aten_gemm_kernels() 937 | 938 | query, key, value = (ir.ExternKernel.realize_input(x) for x in (query, key, value)) 939 | if use_tk_tma_kernel: 940 | query, key, value = (require_dense_memory(x) for x in (query, key, value)) 941 | if attn_mask is not None: 942 | attn_mask = require_dense_memory(attn_mask) 943 | elif use_triton_tma_kernel: 944 | query = require_dense_memory(query, num_dims=1) 945 | key, value = (require_dense_memory(x, num_dims=2) for x in (key, value)) 946 | if attn_mask is not None: 947 | attn_mask = require_dense_memory(attn_mask, num_dims=2) 948 | 949 | if scale_q is not None: 950 | scale_q, scale_k = (ir.ExternKernel.realize_input(x) for x in (scale_q, scale_k)) 951 | if use_tk_tma_kernel: 952 | scale_q, scale_k = (require_dense_memory(x) for x in (scale_q, scale_k)) 953 | elif use_triton_tma_kernel: 954 | scale_q, scale_k = (require_dense_memory(x, num_dims=1) for x in (scale_q, scale_k)) 955 | for x in (query, key, value, scale_q, scale_k, attn_mask): 956 | if x is not None: 957 | x.freeze_layout() 958 | key_t = ir.PermuteView.create(key, [0, 1, 3, 2]) 959 | m1, n1, k1, layout1, mat1, mat2 = mm_args(query, key_t) 960 | 961 | # if scale is None or math.isnan( 962 | # scale): # og_scale.as_float_unchecked() could be nan 963 | # scale = float(1.0 / (k1**0.5)) 964 | 965 | kwargs = { 
966 | "dropout_p": dropout_p, 967 | "is_causal": is_causal, 968 | "scale": scale, 969 | } 970 | ordered_kwargs_for_cpp_kernel = [ 971 | "dropout_p", 972 | "is_causal", 973 | "scale", 974 | ] 975 | 976 | if attn_mask is None: 977 | args = [query, key, value] 978 | if scale_q is not None: 979 | args += [scale_q, scale_k] 980 | kwargs["attn_mask"] = None 981 | ordered_kwargs_for_cpp_kernel.insert(0, "attn_mask") 982 | else: 983 | args = [query, key, value] 984 | if scale_q is not None: 985 | args += [scale_q, scale_k] 986 | args.append(attn_mask) 987 | 988 | if layout is None: 989 | layout2 = get_attention_layout(query, value.get_dtype()) 990 | else: 991 | layout2 = layout 992 | 993 | choices = [] 994 | if use_tk_tma_kernel: 995 | if query.get_dtype() == torch.float8_e4m3fn: 996 | choices.append( 997 | tk_fp8_attention_forward.bind( 998 | args, 999 | layout=layout2, 1000 | **kwargs, 1001 | ) 1002 | ) 1003 | else: 1004 | choices.append( 1005 | tk_attention_forward.bind( 1006 | args, 1007 | layout=layout2, 1008 | **kwargs, 1009 | ) 1010 | ) 1011 | if use_triton_tma_kernel: 1012 | generate_attention_template_choices( 1013 | choices, *args, layout2=layout2, enable_max_autotune=use_max_autotune(), **kwargs 1014 | ) 1015 | if use_aten_attention_kernel: 1016 | if query.get_dtype() == torch.float8_e4m3fn: 1017 | choices.append( 1018 | aten_fp8_attention_forward.bind( 1019 | args, 1020 | layout=layout2, 1021 | **kwargs, 1022 | ) 1023 | ) 1024 | else: 1025 | choices.append( 1026 | aten_attention_forward.bind( 1027 | args, 1028 | layout=layout2, 1029 | **kwargs, 1030 | ) 1031 | ) 1032 | if not use_max_autotune(): 1033 | choices = choices[:1] 1034 | return autotune_select_algorithm("quantum_attn_attention_forward", choices, args, layout2) 1035 | 1036 | 1037 | register_lowering(quantum_attn_ops.attention_forward.default, type_promotion_kind=None)(tuned_attention_forward) 1038 | 1039 | 1040 | @register_lowering(quantum_attn_ops.fp8_attention_forward.default, type_promotion_kind=None) 1041 | def fp8_attention_forward( 1042 | query, 1043 | key, 1044 | value, 1045 | scale_q=None, 1046 | scale_k=None, 1047 | attn_mask=None, 1048 | dropout_p=0.0, 1049 | is_causal=False, 1050 | *, 1051 | scale=None, 1052 | layout=None, 1053 | ): 1054 | return tuned_attention_forward( 1055 | query, 1056 | key, 1057 | value, 1058 | attn_mask, 1059 | dropout_p, 1060 | is_causal, 1061 | scale=scale, 1062 | layout=layout, 1063 | scale_q=scale_q, 1064 | scale_k=scale_k, 1065 | ) 1066 | -------------------------------------------------------------------------------- /src/quantum_attn/inductor/kernels/mm_common.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import unittest 3 | 4 | import torch 5 | from torch._inductor import ir 6 | from torch._inductor.kernel import mm_common 7 | from torch._inductor.utils import ceildiv as cdiv 8 | from torch._inductor.virtualized import V 9 | 10 | from quantum_attn import config 11 | from quantum_attn.utils import checks 12 | 13 | 14 | def require_dense_memory(x, num_dims=None): 15 | strides = x.get_stride() 16 | try: 17 | strides_hint = [V.graph.sizevars.size_hint(s) for s in strides] 18 | stride_order = ir.get_stride_order(strides_hint) 19 | except Exception: 20 | # logger.warning(f"Failed to get size hint for strides: {strides}", exc_info=True) 21 | stride_order = list(reversed(range(len(x.get_size())))) 22 | required_strides = ir.FlexibleLayout.stride_ordered(x.get_size(), stride_order) 23 | if isinstance(x.layout, ir.FlexibleLayout): 
24 | x.freeze_layout_with_same_order(required_strides) 25 | else: 26 | for i, (size, stride, required_stride) in enumerate(zip(x.layout.size, x.layout.stride, required_strides)): 27 | if num_dims is not None and i + num_dims < len(x.layout.size): 28 | continue 29 | if stride != required_stride and size != 1: 30 | x = ir.ExternKernel.copy_input(x) 31 | x.freeze_layout_with_same_order(required_strides) 32 | break 33 | return x 34 | 35 | 36 | def acc_type(dtype): 37 | if dtype == torch.float16: 38 | if config.triton.allow_reduced_precision_compute and torch.version.hip is None: 39 | return "tl.float16" 40 | else: 41 | return "tl.float32" 42 | elif dtype == torch.bfloat16: 43 | return "tl.float32" 44 | else: 45 | return f"tl.{dtype}".replace("torch.", "") 46 | 47 | 48 | @functools.cache 49 | def get_device_shared_memory(device=0): 50 | try: 51 | from triton.runtime import driver 52 | 53 | if hasattr(driver, "active"): 54 | return driver.active.utils.get_device_properties(device)["max_shared_mem"] 55 | return driver.utils.get_device_properties(device)["max_shared_mem"] 56 | except Exception: 57 | return 1024**3 58 | 59 | 60 | def mm_options(c, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None, optimize_block_size=True): 61 | with unittest.mock.patch.object(mm_common, "acc_type", acc_type): 62 | if checks.torch_version_compare("ge", "2.3.0"): 63 | options = mm_common.mm_options(c, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=b_prologue_cast_type) 64 | else: 65 | options = mm_common.mm_options(c, sym_k, layout, b_prologue_cast_type=b_prologue_cast_type) 66 | options["ENABLE_FAST_MATH"] = ( 67 | config.triton.enable_fast_math 68 | and checks.has_triton_language("inline_asm_elementwise") 69 | and checks.is_nvidia_cuda() 70 | ) 71 | device_capability = ( 72 | torch.cuda.get_device_capability(layout.device.index or 0) if checks.is_nvidia_cuda() else (0, 0) 73 | ) 74 | cuda_arch = device_capability[0] * 100 + device_capability[1] * 10 75 | options["CUDA_ARCH"] = cuda_arch 76 | cuda_version = checks.torch_cuda_version() 77 | cuda_version = cuda_version[0] * 1000 + cuda_version[1] * 10 78 | options["CUDA_VERSION"] = cuda_version 79 | options["USE_FAST_ACCUM"] = config.use_fast_accum 80 | 81 | return options 82 | 83 | 84 | def reduce_block_size_for_cuda(BLOCK_M, BLOCK_N, m, n, device=None, b=1): 85 | if device is None: 86 | device_index = 0 87 | else: 88 | device_index = device.index or 0 89 | avail_sms = torch.cuda.get_device_properties(device_index).multi_processor_count 90 | if b * cdiv(m, BLOCK_M) * cdiv(n, BLOCK_N) < avail_sms: 91 | # if checks.cuda_capability_compare("ge", 9, 0, device=device): 92 | # # Keep using WGMMA 93 | # min_m = 64 94 | # else: 95 | # min_m = 16 96 | min_m = 16 97 | while True: 98 | if BLOCK_M >= BLOCK_N: 99 | if BLOCK_M > min_m and b * cdiv(m, BLOCK_M // 2) * cdiv(n, BLOCK_N) <= avail_sms: 100 | BLOCK_M //= 2 101 | continue 102 | if BLOCK_N > 16 and b * cdiv(m, BLOCK_M) * cdiv(n, BLOCK_N // 2) <= avail_sms: 103 | BLOCK_N //= 2 104 | continue 105 | else: 106 | if BLOCK_N > 16 and b * cdiv(m, BLOCK_M) * cdiv(n, BLOCK_N // 2) <= avail_sms: 107 | BLOCK_N //= 2 108 | continue 109 | if BLOCK_M > min_m and b * cdiv(m, BLOCK_M // 2) * cdiv(n, BLOCK_N) <= avail_sms: 110 | BLOCK_M //= 2 111 | continue 112 | break 113 | return BLOCK_M, BLOCK_N 114 | -------------------------------------------------------------------------------- /src/quantum_attn/nn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, 
Union 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch._higher_order_ops.utils import _set_compilation_env 6 | from torch.fx.experimental.proxy_tensor import _temp_remove_pre_dispatch_torch_function_mode 7 | 8 | from quantum_attn import config 9 | from quantum_attn.utils import checks 10 | 11 | quantum_attn_ops = torch.ops.quantum_attn 12 | 13 | 14 | def _dynamically_quantize_fp8(t: torch.Tensor, *, reduction_dim=-1) -> Tuple[torch.Tensor, torch.Tensor]: 15 | eps = torch.finfo(torch.float32).eps 16 | q_max = torch.finfo(torch.float8_e4m3fn).max 17 | scale = t.abs().amax(reduction_dim, keepdim=True).mul(1.0 / q_max).clamp_min(eps) 18 | t_fp8 = (t / scale).clamp(-q_max, q_max).to(torch.float8_e4m3fn) 19 | return t_fp8, scale.squeeze(reduction_dim).to(torch.float32) 20 | 21 | 22 | def dynamically_quantize_fp8(t: torch.Tensor, *, reduction_dim=-1) -> Tuple[torch.Tensor, torch.Tensor]: 23 | from torch._subclasses.fake_tensor import is_fake 24 | 25 | if any(is_fake(x) for x in (t,)): 26 | out = _dynamically_quantize_fp8(t, reduction_dim=reduction_dim) 27 | return out 28 | 29 | if torch.compiler.is_dynamo_compiling(): 30 | out = _dynamically_quantize_fp8(t, reduction_dim=reduction_dim) 31 | return out 32 | 33 | with _set_compilation_env(): 34 | with torch._dynamo.utils.disable_cache_limit(): 35 | with _temp_remove_pre_dispatch_torch_function_mode(): 36 | out = torch.compile( 37 | _dynamically_quantize_fp8, backend="inductor", fullgraph=True, dynamic=config.dynamo.dynamic 38 | )( 39 | t, 40 | reduction_dim=reduction_dim, 41 | ) 42 | return out 43 | 44 | 45 | _TK_TMA_SUPPORTED_HEAD_DIMS = [64, 128, 256] 46 | 47 | 48 | def _tk_tma_supported_head_dim(n: Union[int, torch.SymInt]) -> bool: 49 | return n in _TK_TMA_SUPPORTED_HEAD_DIMS 50 | 51 | 52 | def _validate_tk_tma_input( 53 | query: torch.Tensor, 54 | key: torch.Tensor, 55 | value: torch.Tensor, 56 | attn_mask: Optional[torch.Tensor] = None, 57 | dropout_p=0.0, 58 | is_causal=False, 59 | scale=None, 60 | scaling_method=None, 61 | ) -> Tuple[bool, str]: 62 | if any(t.requires_grad for t in (query, key, value)): 63 | return False, "NYI: query, key, and value must be leaf tensors" 64 | 65 | if attn_mask is not None: 66 | return False, "NYI: attn_mask must be None" 67 | if dropout_p != 0.0: 68 | return False, "NYI: dropout_p must be 0.0" 69 | if scale is not None: 70 | return False, "NYI: scale must be None" 71 | if scaling_method is None: 72 | if query.dtype != key.dtype or query.dtype != value.dtype: 73 | return ( 74 | False, 75 | f"Expected query, key, and value to have the same dtype, but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, and value.dtype: {value.dtype} instead.", 76 | ) 77 | if query.dtype not in (torch.float16, torch.bfloat16): 78 | return ( 79 | False, 80 | f"Expected query, key, and value to have dtype torch.float16 or torch.bfloat16, but got query.dtype: {query.dtype} instead.", 81 | ) 82 | else: 83 | if scaling_method != "head-wise": 84 | return False, f"Unsupported scaling_method: {scaling_method}" 85 | if query.dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): 86 | return ( 87 | False, 88 | f"Expected query to have dtype torch.float16, torch.bfloat16, or torch.float8_e4m3fn, but got query.dtype: {query.dtype} instead.", 89 | ) 90 | if query.dtype != key.dtype: 91 | return ( 92 | False, 93 | f"Expected query and key to have the same dtype, but got query.dtype: {query.dtype}, key.dtype: {key.dtype} instead.", 94 | ) 95 | if value.dtype not in (torch.float16, torch.bfloat16): 96 | return ( 
97 | False, 98 | f"Expected value to have dtype torch.float16 or torch.bfloat16, but got value.dtype: {value.dtype} instead.", 99 | ) 100 | if query.device != key.device or query.device != value.device: 101 | return ( 102 | False, 103 | f"Expected query, key, and value to have the same device type, but got query.device: {query.device}, key.device: {key.device}, and value.device: {value.device} instead.", 104 | ) 105 | if query.device.type != "cuda": 106 | return False, "Expected query, key, and value to be on a CUDA device" 107 | if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: 108 | return False, "NYI: query, key, and value must be 4D tensors" 109 | if key.size(-2) != value.size(-2): 110 | return ( 111 | False, 112 | f"Expect key and value to have the same sequence length but got Sk={key.size(-2)} and Sv={value.size(-2)}.", 113 | ) 114 | if value.size(-1) != query.size(-1): 115 | return False, "NYI: query and value must have the same embedding dimension" 116 | if query.size(-3) != key.size(-3): 117 | return ( 118 | False, 119 | f"Expect query and key/value to have the same number of heads but got Hq={query.size(-3)} and Hkv={key.size(-3)}.", 120 | ) 121 | if not _tk_tma_supported_head_dim(query.size(-1)): 122 | return False, f"Unsupported head dimension: {query.size(-1)}" 123 | 124 | return True, "" 125 | 126 | 127 | _TRITON_TMA_SDPA_SUPPORTED_HEAD_DIMS = [64, 128, 256] 128 | 129 | 130 | def _triton_tma_sdpa_supported_head_dim(n: Union[int, torch.SymInt]) -> bool: 131 | """Returns true if the head dim is supported by FlexAttention""" 132 | return n in _TRITON_TMA_SDPA_SUPPORTED_HEAD_DIMS 133 | 134 | 135 | def _validate_triton_tma_sdpa_input( 136 | query: torch.Tensor, 137 | key: torch.Tensor, 138 | value: torch.Tensor, 139 | attn_mask: Optional[torch.Tensor] = None, 140 | dropout_p=0.0, 141 | is_causal=False, 142 | scale=None, 143 | scaling_method=None, 144 | ) -> Tuple[bool, str]: 145 | if any(t.requires_grad for t in (query, key, value)): 146 | return False, "NYI: query, key, and value must be leaf tensors" 147 | 148 | if attn_mask is not None: 149 | return False, "NYI: attn_mask must be None" 150 | if dropout_p != 0.0: 151 | return False, "NYI: dropout_p must be 0.0" 152 | if scaling_method is None: 153 | if query.dtype != key.dtype or query.dtype != value.dtype: 154 | return ( 155 | False, 156 | f"Expected query, key, and value to have the same dtype, but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, and value.dtype: {value.dtype} instead.", 157 | ) 158 | if query.dtype not in (torch.float16, torch.bfloat16): 159 | return ( 160 | False, 161 | f"Expected query, key, and value to have dtype torch.float16 or torch.bfloat16, but got query.dtype: {query.dtype} instead.", 162 | ) 163 | else: 164 | if scaling_method != "token-wise": 165 | return False, f"Unsupported scaling_method: {scaling_method}" 166 | if query.dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): 167 | return ( 168 | False, 169 | f"Expected query to have dtype torch.float16, torch.bfloat16, or torch.float8_e4m3fn, but got query.dtype: {query.dtype} instead.", 170 | ) 171 | if query.dtype != key.dtype: 172 | return ( 173 | False, 174 | f"Expected query and key to have the same dtype, but got query.dtype: {query.dtype}, key.dtype: {key.dtype} instead.", 175 | ) 176 | if value.dtype not in (torch.float16, torch.bfloat16): 177 | return ( 178 | False, 179 | f"Expected value to have dtype torch.float16 or torch.bfloat16, but got value.dtype: {value.dtype} instead.", 180 | ) 181 | if 
query.device != key.device or query.device != value.device: 182 | return ( 183 | False, 184 | f"Expected query, key, and value to have the same device type, but got query.device: {query.device}, key.device: {key.device}, and value.device: {value.device} instead.", 185 | ) 186 | if query.device.type != "cuda": 187 | return False, "Expected query, key, and value to be on a CUDA device" 188 | if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: 189 | return False, "NYI: query, key, and value must be 4D tensors" 190 | if key.size(-2) != value.size(-2): 191 | return ( 192 | False, 193 | f"Expect key and value to have the same sequence length but got Sk={key.size(-2)} and Sv={value.size(-2)}.", 194 | ) 195 | if value.size(-1) != query.size(-1): 196 | return False, "NYI: query and value must have the same embedding dimension" 197 | if query.size(-3) != key.size(-3): 198 | return ( 199 | False, 200 | f"Expect query and key/value to have the same number of heads but got Hq={query.size(-3)} and Hkv={key.size(-3)}.", 201 | ) 202 | if not _triton_tma_sdpa_supported_head_dim(query.size(-1)): 203 | return False, f"Unsupported head dimension: {query.size(-1)}" 204 | 205 | return True, "" 206 | 207 | 208 | @torch.compiler.assume_constant_result 209 | def _pre_check_can_use_tk_tma_attention(device): 210 | if device.type != "cuda": 211 | return False, f"Expected device to be on a CUDA device, but got device: {device} instead." 212 | if not config.attention.enable_tk_tma_kernel: 213 | return False, "TK TMA kernel is disabled" 214 | if not checks.cuda_capability_compare("ge", 9, 0, device=device): 215 | return False, "Minimum CUDA capability of 9.0 is required" 216 | return True, "" 217 | 218 | 219 | def can_use_tk_tma_attention( 220 | query: torch.Tensor, 221 | key: torch.Tensor, 222 | value: torch.Tensor, 223 | attn_mask: Optional[torch.Tensor] = None, 224 | dropout_p: float = 0.0, 225 | is_causal: bool = False, 226 | *, 227 | scale: Optional[float] = None, 228 | scaling_method: Optional[str] = None, 229 | ) -> Tuple[bool, str]: 230 | supported, reason = _pre_check_can_use_tk_tma_attention(device=query.device) 231 | if not supported: 232 | return False, reason 233 | 234 | valid, reason = _validate_tk_tma_input( 235 | query, key, value, attn_mask, dropout_p, is_causal, scale, scaling_method=scaling_method 236 | ) 237 | if not valid: 238 | return False, reason 239 | 240 | return True, "" 241 | 242 | 243 | @torch.compiler.assume_constant_result 244 | def _pre_check_can_use_triton_tma_attention(device): 245 | if device.type != "cuda": 246 | return False, f"Expected device to be on a CUDA device, but got device: {device} instead." 
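    # The checks that follow gate the Triton TMA path: the kernel must be enabled in the
    # config, the device must be Hopper-class (compute capability >= 9.0), PyTorch must be
    # >= 2.7.0, and the installed Triton must expose TMA descriptor support.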
247 | if not config.attention.enable_triton_tma_kernel: 248 | return False, "Triton TMA kernel is disabled" 249 | if not checks.cuda_capability_compare("ge", 9, 0, device=device): 250 | return False, "Minimum CUDA capability of 9.0 is required" 251 | if not checks.torch_version_compare("ge", "2.7.0"): 252 | return False, "Minimum PyTorch version of 2.7.0 is required" 253 | if not checks.has_triton_tma_support(): 254 | return False, "Triton TMA support is required" 255 | return True, "" 256 | 257 | 258 | def can_use_triton_tma_attention( 259 | query: torch.Tensor, 260 | key: torch.Tensor, 261 | value: torch.Tensor, 262 | attn_mask: Optional[torch.Tensor] = None, 263 | dropout_p: float = 0.0, 264 | is_causal: bool = False, 265 | *, 266 | scale: Optional[float] = None, 267 | scaling_method: Optional[str] = None, 268 | ) -> Tuple[bool, str]: 269 | supported, reason = _pre_check_can_use_triton_tma_attention(device=query.device) 270 | if not supported: 271 | return False, reason 272 | 273 | valid, reason = _validate_triton_tma_sdpa_input( 274 | query, key, value, attn_mask, dropout_p, is_causal, scale, scaling_method=scaling_method 275 | ) 276 | if not valid: 277 | return False, reason 278 | 279 | return True, "" 280 | 281 | 282 | def can_use_attention( 283 | query: torch.Tensor, 284 | key: torch.Tensor, 285 | value: torch.Tensor, 286 | attn_mask: Optional[torch.Tensor] = None, 287 | dropout_p: float = 0.0, 288 | is_causal: bool = False, 289 | *, 290 | scale: Optional[float] = None, 291 | scaling_method: Optional[str] = None, 292 | ) -> Tuple[bool, str]: 293 | if checks.get_constant_attr("quantum_attn.config", "attention.skip_supported_check"): 294 | return True, "" 295 | prefix_and_funcs = [ 296 | ("tk_tma", can_use_tk_tma_attention), 297 | ("triton_tma_sdpa", can_use_triton_tma_attention), 298 | ] 299 | reasons = [] 300 | for prefix, func in prefix_and_funcs: 301 | supported, reason = func( 302 | query, key, value, attn_mask, dropout_p, is_causal, scale=scale, scaling_method=scaling_method 303 | ) 304 | if supported: 305 | return True, "" 306 | reasons.append((f"{prefix}: {reason}")) 307 | return False, " ".join(f"[{reason}]" for reason in reasons) 308 | 309 | 310 | def _attention_wrapper( 311 | query: torch.Tensor, 312 | key: torch.Tensor, 313 | value: torch.Tensor, 314 | attn_mask: Optional[torch.Tensor] = None, 315 | dropout_p: float = 0.0, 316 | is_causal: bool = False, 317 | *, 318 | scale: Optional[float] = None, 319 | ) -> Tensor: 320 | return quantum_attn_ops.attention_forward( 321 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 322 | ) 323 | 324 | 325 | def attention( 326 | query: torch.Tensor, 327 | key: torch.Tensor, 328 | value: torch.Tensor, 329 | attn_mask: Optional[torch.Tensor] = None, 330 | dropout_p: float = 0.0, 331 | is_causal: bool = False, 332 | *, 333 | scale: Optional[float] = None, 334 | ) -> Tensor: 335 | if not torch._dynamo.is_dynamo_supported(): 336 | raise RuntimeError("attention_forward requires dynamo support") 337 | if not torch._dynamo.is_inductor_supported(): 338 | raise RuntimeError("attention_forward requires inductor support") 339 | 340 | supported, reason = can_use_attention( 341 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 342 | ) 343 | if not supported: 344 | raise ValueError(f"Unsupported input: {reason}") 345 | 346 | # if scale is None: 347 | # scale = 1.0 / math.sqrt(query.size(-1)) 348 | 349 | from torch._subclasses.fake_tensor import is_fake 350 | 351 
| if any(is_fake(x) for x in (query, key, value, attn_mask)): 352 | out = _attention_wrapper( 353 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 354 | ) 355 | return out 356 | 357 | if torch.compiler.is_dynamo_compiling(): 358 | # mark head_dim and number of heads to be static 359 | for x in [query, key, value]: 360 | torch._dynamo.mark_static(x, -3) 361 | torch._dynamo.mark_static(x, -1) 362 | out = _attention_wrapper( 363 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 364 | ) 365 | return out 366 | 367 | if config.attention.force_eager_fallback: 368 | out = _attention_wrapper( 369 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 370 | ) 371 | return out 372 | 373 | with _set_compilation_env(): 374 | with torch._dynamo.utils.disable_cache_limit(): 375 | with _temp_remove_pre_dispatch_torch_function_mode(): 376 | out = torch.compile( 377 | _attention_wrapper, 378 | backend="inductor", 379 | fullgraph=True, 380 | dynamic=config.dynamo.dynamic, 381 | mode=config.dynamo.mode, 382 | )( 383 | query, 384 | key, 385 | value, 386 | attn_mask=attn_mask, 387 | dropout_p=dropout_p, 388 | is_causal=is_causal, 389 | scale=scale, 390 | ) 391 | return out 392 | 393 | 394 | def _fp8_attention_wrapper( 395 | query: torch.Tensor, 396 | key: torch.Tensor, 397 | value: torch.Tensor, 398 | attn_mask: Optional[torch.Tensor] = None, 399 | dropout_p: float = 0.0, 400 | is_causal: bool = False, 401 | *, 402 | scale: Optional[float] = None, 403 | scale_q: Optional[torch.Tensor] = None, 404 | scale_k: Optional[torch.Tensor] = None, 405 | scaling_method: Optional[str] = None, 406 | ) -> Tensor: 407 | if (scale_q is None) != (scale_k is None): 408 | raise ValueError("scale_q and scale_k must be both provided or both not provided") 409 | 410 | if scale_q is None: 411 | if scaling_method == "head-wise": 412 | reduction_dim = [query.dim() - 2, query.dim() - 1] 413 | elif scaling_method == "token-wise": 414 | reduction_dim = query.dim() - 1 415 | else: 416 | raise ValueError(f"Unsupported scaling_method: {scaling_method}") 417 | query, scale_q = _dynamically_quantize_fp8(query, reduction_dim=reduction_dim) 418 | key, scale_k = _dynamically_quantize_fp8(key, reduction_dim=reduction_dim) 419 | 420 | return quantum_attn_ops.fp8_attention_forward( 421 | query, 422 | key, 423 | value, 424 | scale_q, 425 | scale_k, 426 | attn_mask=attn_mask, 427 | dropout_p=dropout_p, 428 | is_causal=is_causal, 429 | scale=scale, 430 | ) 431 | 432 | 433 | def fp8_attention( 434 | query: torch.Tensor, 435 | key: torch.Tensor, 436 | value: torch.Tensor, 437 | attn_mask: Optional[torch.Tensor] = None, 438 | dropout_p: float = 0.0, 439 | is_causal: bool = False, 440 | *, 441 | scale: Optional[float] = None, 442 | scale_q: Optional[torch.Tensor] = None, 443 | scale_k: Optional[torch.Tensor] = None, 444 | scaling_method: Optional[str] = None, 445 | ) -> Tensor: 446 | if not torch._dynamo.is_dynamo_supported(): 447 | raise RuntimeError("fp8_attention_forward requires dynamo support") 448 | if not torch._dynamo.is_inductor_supported(): 449 | raise RuntimeError("fp8_attention_forward requires inductor support") 450 | 451 | supported, reason = can_use_attention( 452 | query, 453 | key, 454 | value, 455 | attn_mask=attn_mask, 456 | dropout_p=dropout_p, 457 | is_causal=is_causal, 458 | scale=scale, 459 | scaling_method=scaling_method, 460 | ) 461 | if not supported: 462 | raise ValueError(reason) 463 | 464 | # if 
scale is None: 465 | # scale = 1.0 / math.sqrt(query.size(-1)) 466 | 467 | from torch._subclasses.fake_tensor import is_fake 468 | 469 | if any(is_fake(x) for x in (query, key, value, attn_mask, scale_q, scale_k)): 470 | out = _fp8_attention_wrapper( 471 | query, 472 | key, 473 | value, 474 | attn_mask=attn_mask, 475 | dropout_p=dropout_p, 476 | is_causal=is_causal, 477 | scale=scale, 478 | scale_q=scale_q, 479 | scale_k=scale_k, 480 | scaling_method=scaling_method, 481 | ) 482 | return out 483 | 484 | if torch.compiler.is_dynamo_compiling(): 485 | # mark head_dim and number of heads to be static 486 | for x in [query, key, value]: 487 | torch._dynamo.mark_static(x, -3) 488 | torch._dynamo.mark_static(x, -1) 489 | out = _fp8_attention_wrapper( 490 | query, 491 | key, 492 | value, 493 | attn_mask=attn_mask, 494 | dropout_p=dropout_p, 495 | is_causal=is_causal, 496 | scale=scale, 497 | scale_q=scale_q, 498 | scale_k=scale_k, 499 | scaling_method=scaling_method, 500 | ) 501 | return out 502 | 503 | if config.attention.force_eager_fallback: 504 | out = _fp8_attention_wrapper( 505 | query, 506 | key, 507 | value, 508 | attn_mask=attn_mask, 509 | dropout_p=dropout_p, 510 | is_causal=is_causal, 511 | scale=scale, 512 | scale_q=scale_q, 513 | scale_k=scale_k, 514 | scaling_method=scaling_method, 515 | ) 516 | return out 517 | 518 | with _set_compilation_env(): 519 | with torch._dynamo.utils.disable_cache_limit(): 520 | with _temp_remove_pre_dispatch_torch_function_mode(): 521 | out = torch.compile( 522 | _fp8_attention_wrapper, 523 | backend="inductor", 524 | fullgraph=True, 525 | dynamic=config.dynamo.dynamic, 526 | mode=config.dynamo.mode, 527 | )( 528 | query, 529 | key, 530 | value, 531 | attn_mask=attn_mask, 532 | dropout_p=dropout_p, 533 | is_causal=is_causal, 534 | scale=scale, 535 | scale_q=scale_q, 536 | scale_k=scale_k, 537 | scaling_method=scaling_method, 538 | ) 539 | return out 540 | -------------------------------------------------------------------------------- /src/quantum_attn/ops.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | aten = torch.ops.aten 6 | 7 | # torch.compile() support is only enabled for pytorch >= 2.4 8 | # The reason for this is that we are using the new custom_op and register_fake 9 | # APIs, which support inplace modification of inputs in the function itself 10 | if torch.__version__ >= "2.4.0": 11 | _torch_custom_op_wrapper = torch.library.custom_op 12 | _torch_register_fake_wrapper = torch.library.register_fake 13 | else: 14 | raise RuntimeError("Your PyTorch version is too old. 
Please upgrade to PyTorch >= 2.4.0") 15 | 16 | 17 | def _attention_forward( 18 | query: torch.Tensor, 19 | key: torch.Tensor, 20 | value: torch.Tensor, 21 | attn_mask: Optional[torch.Tensor] = None, 22 | dropout_p: float = 0.0, 23 | is_causal: bool = False, 24 | *, 25 | scale: Optional[float] = None, 26 | ) -> torch.Tensor: 27 | return aten.scaled_dot_product_attention( 28 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 29 | ).contiguous() 30 | 31 | 32 | @_torch_custom_op_wrapper("quantum_attn::attention_forward", mutates_args=(), device_types=("cuda",)) 33 | def attention_forward( 34 | query: torch.Tensor, 35 | key: torch.Tensor, 36 | value: torch.Tensor, 37 | attn_mask: Optional[torch.Tensor] = None, 38 | dropout_p: float = 0.0, 39 | is_causal: bool = False, 40 | *, 41 | scale: Optional[float] = None, 42 | ) -> torch.Tensor: 43 | return _attention_forward( 44 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 45 | ) 46 | 47 | 48 | @_torch_register_fake_wrapper("quantum_attn::attention_forward") 49 | def _( 50 | query: torch.Tensor, 51 | key: torch.Tensor, 52 | value: torch.Tensor, 53 | attn_mask: Optional[torch.Tensor] = None, 54 | dropout_p: float = 0.0, 55 | is_causal: bool = False, 56 | *, 57 | scale: Optional[float] = None, 58 | ) -> torch.Tensor: 59 | return _attention_forward( 60 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 61 | ) 62 | 63 | 64 | def _fp8_attention_forward( 65 | query: torch.Tensor, 66 | key: torch.Tensor, 67 | value: torch.Tensor, 68 | scale_q: Optional[torch.Tensor] = None, 69 | scale_k: Optional[torch.Tensor] = None, 70 | attn_mask: Optional[torch.Tensor] = None, 71 | dropout_p: float = 0.0, 72 | is_causal: bool = False, 73 | *, 74 | scale: Optional[float] = None, 75 | ) -> torch.Tensor: 76 | out_dtype = value.dtype 77 | 78 | query = query.to(out_dtype) 79 | key = key.to(out_dtype) 80 | 81 | if scale_q is not None: 82 | assert scale_q.dim() == scale_k.dim() 83 | scale_q = scale_q.to(out_dtype) 84 | scale_k = scale_k.to(out_dtype) 85 | 86 | while scale_q.dim() < query.dim(): 87 | scale_q = scale_q.unsqueeze(-1) 88 | scale_k = scale_k.unsqueeze(-1) 89 | 90 | query = query * scale_q 91 | key = key * scale_k 92 | 93 | return aten.scaled_dot_product_attention( 94 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 95 | ).contiguous() 96 | 97 | 98 | @_torch_custom_op_wrapper("quantum_attn::fp8_attention_forward", mutates_args=(), device_types=("cuda",)) 99 | def fp8_attention_forward( 100 | query: torch.Tensor, 101 | key: torch.Tensor, 102 | value: torch.Tensor, 103 | scale_q: Optional[torch.Tensor] = None, 104 | scale_k: Optional[torch.Tensor] = None, 105 | attn_mask: Optional[torch.Tensor] = None, 106 | dropout_p: float = 0.0, 107 | is_causal: bool = False, 108 | *, 109 | scale: Optional[float] = None, 110 | ) -> torch.Tensor: 111 | return _fp8_attention_forward( 112 | query, 113 | key, 114 | value, 115 | scale_q, 116 | scale_k, 117 | attn_mask=attn_mask, 118 | dropout_p=dropout_p, 119 | is_causal=is_causal, 120 | scale=scale, 121 | ) 122 | 123 | 124 | @_torch_register_fake_wrapper("quantum_attn::fp8_attention_forward") 125 | def _( 126 | query: torch.Tensor, 127 | key: torch.Tensor, 128 | value: torch.Tensor, 129 | scale_q: Optional[torch.Tensor] = None, 130 | scale_k: Optional[torch.Tensor] = None, 131 | attn_mask: Optional[torch.Tensor] = None, 132 | dropout_p: float = 0.0, 133 | 
is_causal: bool = False, 134 | *, 135 | scale: Optional[float] = None, 136 | ) -> torch.Tensor: 137 | return _fp8_attention_forward( 138 | query, 139 | key, 140 | value, 141 | scale_q, 142 | scale_k, 143 | attn_mask=attn_mask, 144 | dropout_p=dropout_p, 145 | is_causal=is_causal, 146 | scale=scale, 147 | ) 148 | -------------------------------------------------------------------------------- /src/quantum_attn/quantum_attn_interface.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from quantum_attn.nn import attention, can_use_attention, dynamically_quantize_fp8, fp8_attention 7 | 8 | __all__ = [ 9 | "attn_func", 10 | "attn_func_with_fallback", 11 | "fp8_attn_func", 12 | "fp8_attn_func_with_fallback", 13 | "fp8_token_wise_attn_func", 14 | "fp8_token_wise_attn_func_with_fallback", 15 | "dynamically_quantize_fp8", 16 | ] 17 | 18 | torch_sdpa = F.scaled_dot_product_attention 19 | 20 | 21 | def _define_composite_implicit_autograd_op(namespace, name, signature): 22 | def decorator(fn): 23 | torch.library.define(f"{namespace}::{name}", signature) 24 | torch.library.impl(f"{namespace}::{name}", ["CompositeImplicitAutograd"])(fn) 25 | 26 | ns = getattr(torch.ops, namespace) 27 | op = getattr(ns, name) 28 | return op 29 | 30 | return decorator 31 | 32 | 33 | def _define_quantum_attn_composite_implicit_autograd_op(name, signature): 34 | return _define_composite_implicit_autograd_op("quantum_attn", name, signature) 35 | 36 | 37 | def sdpa_dispatcher(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None): 38 | return tuple(filter(torch.is_tensor, (query, key, value, attn_mask))) 39 | 40 | 41 | def attn_func( 42 | query: torch.Tensor, 43 | key: torch.Tensor, 44 | value: torch.Tensor, 45 | attn_mask: Optional[torch.Tensor] = None, 46 | dropout_p: float = 0.0, 47 | is_causal: bool = False, 48 | *, 49 | scale: float = None, 50 | ) -> torch.Tensor: 51 | return attention( 52 | query, 53 | key, 54 | value, 55 | attn_mask=attn_mask, 56 | dropout_p=dropout_p, 57 | is_causal=is_causal, 58 | scale=scale, 59 | ) 60 | 61 | 62 | @_define_quantum_attn_composite_implicit_autograd_op( 63 | "attn_func_with_fallback", 64 | "(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None) -> Tensor", 65 | ) 66 | def attn_func_with_fallback( 67 | query: torch.Tensor, 68 | key: torch.Tensor, 69 | value: torch.Tensor, 70 | attn_mask: Optional[torch.Tensor] = None, 71 | dropout_p: float = 0.0, 72 | is_causal: bool = False, 73 | *, 74 | scale: float = None, 75 | ) -> torch.Tensor: 76 | supported, reason = can_use_attention( 77 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale 78 | ) 79 | if supported: 80 | return attn_func( 81 | query, 82 | key, 83 | value, 84 | attn_mask=attn_mask, 85 | dropout_p=dropout_p, 86 | is_causal=is_causal, 87 | scale=scale, 88 | ) 89 | 90 | return torch_sdpa( 91 | query, 92 | key, 93 | value, 94 | attn_mask=attn_mask, 95 | dropout_p=dropout_p, 96 | is_causal=is_causal, 97 | scale=scale, 98 | ) 99 | 100 | 101 | def fp8_attn_func( 102 | query: torch.Tensor, 103 | key: torch.Tensor, 104 | value: torch.Tensor, 105 | attn_mask: Optional[torch.Tensor] = None, 106 | dropout_p: float = 0.0, 107 | is_causal: bool = False, 108 | *, 109 | scale: float = None, 110 | scale_q: Optional[torch.Tensor] = None, 111 | scale_k: Optional[torch.Tensor] = None, 112 | scaling_method: Optional[str] = None, 113 | ) -> torch.Tensor: 114 | if scaling_method is None: 115 | scaling_method = "head-wise" 116 | return fp8_attention( 117 | query, 118 | key, 119 | value, 120 | attn_mask=attn_mask, 121 | dropout_p=dropout_p, 122 | is_causal=is_causal, 123 | scale=scale, 124 | scale_q=scale_q, 125 | scale_k=scale_k, 126 | scaling_method=scaling_method, 127 | ) 128 | 129 | 130 | @_define_quantum_attn_composite_implicit_autograd_op( 131 | "fp8_attn_func_with_fallback", 132 | "(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, str? 
scaling_method=None) -> Tensor", 133 | ) 134 | def fp8_attn_func_with_fallback( 135 | query: torch.Tensor, 136 | key: torch.Tensor, 137 | value: torch.Tensor, 138 | attn_mask: Optional[torch.Tensor] = None, 139 | dropout_p: float = 0.0, 140 | is_causal: bool = False, 141 | *, 142 | scale: float = None, 143 | scaling_method: Optional[str] = None, 144 | ) -> torch.Tensor: 145 | if scaling_method is None: 146 | scaling_method = "head-wise" 147 | if can_use_attention( 148 | query, 149 | key, 150 | value, 151 | attn_mask=attn_mask, 152 | dropout_p=dropout_p, 153 | is_causal=is_causal, 154 | scale=scale, 155 | scaling_method=scaling_method, 156 | )[0]: 157 | return fp8_attn_func( 158 | query, 159 | key, 160 | value, 161 | attn_mask=attn_mask, 162 | dropout_p=dropout_p, 163 | is_causal=is_causal, 164 | scale=scale, 165 | scaling_method=scaling_method, 166 | ) 167 | 168 | return torch_sdpa( 169 | query, 170 | key, 171 | value, 172 | attn_mask=attn_mask, 173 | dropout_p=dropout_p, 174 | is_causal=is_causal, 175 | scale=scale, 176 | ) 177 | 178 | 179 | def fp8_token_wise_attn_func( 180 | query: torch.Tensor, 181 | key: torch.Tensor, 182 | value: torch.Tensor, 183 | attn_mask: Optional[torch.Tensor] = None, 184 | dropout_p: float = 0.0, 185 | is_causal: bool = False, 186 | *, 187 | scale: float = None, 188 | scale_q: Optional[torch.Tensor] = None, 189 | scale_k: Optional[torch.Tensor] = None, 190 | ) -> torch.Tensor: 191 | return fp8_attention( 192 | query, 193 | key, 194 | value, 195 | attn_mask=attn_mask, 196 | dropout_p=dropout_p, 197 | is_causal=is_causal, 198 | scale=scale, 199 | scale_q=scale_q, 200 | scale_k=scale_k, 201 | scaling_method="token-wise", 202 | ) 203 | 204 | 205 | @_define_quantum_attn_composite_implicit_autograd_op( 206 | "fp8_token_wise_attn_func_with_fallback", 207 | "(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor", 208 | ) 209 | def fp8_token_wise_attn_func_with_fallback( 210 | query: torch.Tensor, 211 | key: torch.Tensor, 212 | value: torch.Tensor, 213 | attn_mask: Optional[torch.Tensor] = None, 214 | dropout_p: float = 0.0, 215 | is_causal: bool = False, 216 | *, 217 | scale: float = None, 218 | ) -> torch.Tensor: 219 | if can_use_attention( 220 | query, 221 | key, 222 | value, 223 | attn_mask=attn_mask, 224 | dropout_p=dropout_p, 225 | is_causal=is_causal, 226 | scale=scale, 227 | scaling_method="token-wise", 228 | )[0]: 229 | return fp8_token_wise_attn_func( 230 | query, 231 | key, 232 | value, 233 | attn_mask=attn_mask, 234 | dropout_p=dropout_p, 235 | is_causal=is_causal, 236 | scale=scale, 237 | scaling_method="token-wise", 238 | ) 239 | 240 | return torch_sdpa( 241 | query, 242 | key, 243 | value, 244 | attn_mask=attn_mask, 245 | dropout_p=dropout_p, 246 | is_causal=is_causal, 247 | scale=scale, 248 | ) 249 | -------------------------------------------------------------------------------- /src/quantum_attn/tk/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WaveSpeedAI/QuantumAttention/0baf4b3fd3c568964acd4f176c35e8f073c70c20/src/quantum_attn/tk/__init__.py -------------------------------------------------------------------------------- /src/quantum_attn/tk/attention.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | 4 | import torch 5 | from torch.utils.cpp_extension import load_inline 6 | 7 | from . 
import utils 8 | 9 | TK_ATTENTION_SOURCE = """ 10 | #include "kittens.cuh" 11 | #include 12 | #include 13 | 14 | using namespace kittens; 15 | namespace cg = cooperative_groups; 16 | 17 | #if defined(TK_ATTN_DTYPE_FP16) 18 | typedef half DType; 19 | #elif defined(TK_ATTN_DTYPE_BF16) 20 | typedef bf16 DType; 21 | #else 22 | #error "Unsupported dtype" 23 | #endif 24 | 25 | #if defined(TK_ATTN_IS_FP8) 26 | typedef fp8e4m3 QKDType; 27 | #else 28 | typedef DType QKDType; 29 | #endif 30 | 31 | __device__ static inline float fast_exp2f(float x) { 32 | float y; 33 | asm volatile ( "ex2.approx.ftz.f32 %0, %1; " : "=f"(y) : "f"(x)); 34 | return y; 35 | } 36 | 37 | namespace kittens { 38 | namespace base_ops { 39 | 40 | struct fast_exp2 { 41 | template static __device__ inline T op(const T &x) { return fast_exp2f(x); } 42 | }; 43 | template<> __device__ inline float fast_exp2::op (const float &x ) { return fast_exp2f(x); } 44 | template<> __device__ inline float2 fast_exp2::op(const float2 &x) { return float2{fast_exp2f(x.x), fast_exp2f(x.y)}; } 45 | 46 | } 47 | } 48 | 49 | // unary_map(att_block, att_block); 50 | 51 | template struct fwd_attend_ker_tile_dims {}; 52 | template<> struct fwd_attend_ker_tile_dims<64> { 53 | constexpr static int tile_width = (64); 54 | constexpr static int qo_height = (4*16); 55 | constexpr static int kv_height = (8*16); 56 | #if defined(TK_ATTN_IS_FP8) 57 | constexpr static int stages = (3); 58 | #else 59 | constexpr static int stages = (2); 60 | #endif 61 | }; 62 | template<> struct fwd_attend_ker_tile_dims<128> { 63 | constexpr static int tile_width = (128); 64 | constexpr static int qo_height = (4*16); 65 | constexpr static int kv_height = (8*16); 66 | constexpr static int stages = (2); 67 | }; 68 | template<> struct fwd_attend_ker_tile_dims<256> { 69 | constexpr static int tile_width = (256); 70 | constexpr static int qo_height = (4*16); 71 | constexpr static int kv_height = (4*16); 72 | constexpr static int stages = (2); 73 | }; 74 | 75 | template struct fwd_globals { 76 | using q_tile = st::qo_height, fwd_attend_ker_tile_dims::tile_width>; 77 | using k_tile = st::kv_height, fwd_attend_ker_tile_dims::tile_width>; 78 | using v_tile = st::kv_height, fwd_attend_ker_tile_dims::tile_width>; 79 | // using l_col_vec = col_vec::qo_height, fwd_attend_ker_tile_dims::tile_width>>; 80 | using o_tile = st::qo_height, fwd_attend_ker_tile_dims::tile_width>; 81 | 82 | using q_gl = gl; 83 | using k_gl = gl; 84 | using v_gl = gl; 85 | // using l_gl = gl; 86 | using o_gl = gl; 87 | 88 | q_gl q; 89 | k_gl k; 90 | v_gl v; 91 | // l_gl l; 92 | o_gl o; 93 | 94 | #if defined(TK_ATTN_IS_FP8) 95 | float* scale_q; 96 | float* scale_k; 97 | #endif 98 | 99 | const int N; 100 | const int hr; 101 | }; 102 | 103 | template 104 | __global__ __launch_bounds__(((CONSUMER_WARPGROUPS+PRODUCER_WARPGROUPS)*kittens::WARPGROUP_WARPS)*kittens::WARP_THREADS, 1) 105 | void fwd_attend_ker(const __grid_constant__ fwd_globals g) { 106 | constexpr int NUM_WARPGROUPS = (CONSUMER_WARPGROUPS+PRODUCER_WARPGROUPS); 107 | constexpr int NUM_WORKERS = (NUM_WARPGROUPS*kittens::WARPGROUP_WARPS); 108 | 109 | extern __shared__ int __shm[]; 110 | tma_swizzle_allocator al((int*)&__shm[0]); 111 | int warpid = kittens::warpid(), warpgroupid = warpid/kittens::WARPGROUP_WARPS; 112 | 113 | using K = fwd_attend_ker_tile_dims; 114 | 115 | using q_tile = st; 116 | using k_tile = st; 117 | using v_tile = st; 118 | // using l_col_vec = col_vec>; 119 | using o_tile = st; 120 | 121 | union TileUnion { 122 | q_tile q; 123 | o_tile o; 124 
| }; 125 | 126 | TileUnion (&q_o_smem)[CONSUMER_WARPGROUPS] = al.allocate(); 127 | k_tile (&k_smem)[K::stages] = al.allocate(); 128 | v_tile (&v_smem)[K::stages] = al.allocate(); 129 | // l_col_vec (&l_smem)[CONSUMER_WARPGROUPS] = al.allocate(); 130 | 131 | int kv_blocks = (g.N + K::kv_height - 1) / (K::kv_height); 132 | int kv_head_idx = blockIdx.y / g.hr; 133 | int seq_idx = blockIdx.x * CONSUMER_WARPGROUPS; 134 | 135 | __shared__ kittens::semaphore qsmem_semaphore, k_smem_arrived[K::stages], v_smem_arrived[K::stages], compute_done[K::stages], qk_done[K::stages]; 136 | if (threadIdx.x == 0) { 137 | init_semaphore(qsmem_semaphore, 0, 1); 138 | for(int j = 0; j < K::stages; j++) { 139 | init_semaphore(k_smem_arrived[j], 0, 1); 140 | init_semaphore(v_smem_arrived[j], 0, 1); 141 | if constexpr (D >= 128) { 142 | init_semaphore(qk_done[j], CONSUMER_WARPGROUPS, 0); 143 | } 144 | init_semaphore(compute_done[j], CONSUMER_WARPGROUPS, 0); 145 | } 146 | 147 | tma::expect_bytes(qsmem_semaphore, sizeof(q_tile) * CONSUMER_WARPGROUPS); 148 | 149 | for (int wg = 0; wg < CONSUMER_WARPGROUPS; wg++) { 150 | coord q_tile_idx = {blockIdx.z, blockIdx.y, (seq_idx) + wg, 0}; 151 | tma::load_async(q_o_smem[wg].q, g.q, q_tile_idx, qsmem_semaphore); 152 | } 153 | 154 | for (int j = 0; j < K::stages - 1; j++) { 155 | coord kv_tile_idx = {blockIdx.z, kv_head_idx, j, 0}; 156 | tma::expect_bytes(k_smem_arrived[j], sizeof(k_tile)); 157 | tma::load_async(k_smem[j], g.k, kv_tile_idx, k_smem_arrived[j]); 158 | tma::expect_bytes(v_smem_arrived[j], sizeof(v_tile)); 159 | tma::load_async(v_smem[j], g.v, kv_tile_idx, v_smem_arrived[j]); 160 | } 161 | } 162 | __syncthreads(); 163 | 164 | int pipe_idx = K::stages - 1; 165 | 166 | if(warpgroupid == NUM_WARPGROUPS-1) { 167 | // warpgroup::decrease_registers<32>(); 168 | warpgroup::producer_registers(); 169 | 170 | int kv_iters; 171 | if constexpr (is_causal) { 172 | kv_iters = (seq_idx * K::qo_height) - 1 + (CONSUMER_WARPGROUPS * K::qo_height); 173 | kv_iters = ((kv_iters / K::kv_height) == 0) ? 
(0) : ((kv_iters / K::kv_height) - 1); 174 | } 175 | else { kv_iters = kv_blocks-2; } 176 | 177 | if(warpid == NUM_WORKERS-4) { 178 | for (auto kv_idx = pipe_idx - 1; kv_idx <= kv_iters; kv_idx++) { 179 | coord kv_tile_idx = {blockIdx.z, kv_head_idx, kv_idx+1, 0}; 180 | if constexpr (D >= 128) { 181 | if (kv_idx >= pipe_idx) { 182 | wait(qk_done[(kv_idx-pipe_idx)%K::stages], ((kv_idx-(pipe_idx))/K::stages)%2); 183 | } 184 | } else { 185 | if (kv_idx >= pipe_idx) { 186 | wait(compute_done[(kv_idx-pipe_idx)%K::stages], ((kv_idx-(pipe_idx))/K::stages)%2); 187 | } 188 | } 189 | tma::expect_bytes(k_smem_arrived[(kv_idx+1)%K::stages], sizeof(k_tile)); 190 | tma::load_async(k_smem[(kv_idx+1)%K::stages], g.k, kv_tile_idx, k_smem_arrived[(kv_idx+1)%K::stages]); 191 | if constexpr (D >= 128) { 192 | if (kv_idx >= pipe_idx) { 193 | wait(compute_done[(kv_idx-pipe_idx)%K::stages], ((kv_idx-(pipe_idx))/K::stages)%2); 194 | } 195 | } 196 | tma::expect_bytes(v_smem_arrived[(kv_idx+1)%K::stages], sizeof(v_tile)); 197 | tma::load_async(v_smem[(kv_idx+1)%K::stages], g.v, kv_tile_idx, v_smem_arrived[(kv_idx+1)%K::stages]); 198 | } 199 | } 200 | } 201 | else { 202 | // warpgroup::increase_registers<160>(); 203 | warpgroup::consumer_registers(); 204 | 205 | rt_fl<16, K::tile_width> o_reg; 206 | 207 | col_vec> max_vec, norm_vec, max_vec_last_scaled, att_block_sum; 208 | 209 | #if defined(TK_ATTN_IS_FP8) 210 | float scale_q = g.scale_q[blockIdx.z*gridDim.y + blockIdx.y]; 211 | float scale_k = g.scale_k[blockIdx.z*gridDim.y + blockIdx.y]; 212 | 213 | float scale = scale_q * scale_k; 214 | if constexpr (D == 64) { scale *= 1.44269504089f*0.125f; } 215 | else if constexpr (D == 128) { scale *= 1.44269504089f*0.08838834764f; } 216 | else { scale *= 1.44269504089f*0.0625f; } 217 | #endif 218 | 219 | neg_infty(max_vec); 220 | zero(norm_vec); 221 | zero(o_reg); 222 | 223 | int kv_iters; 224 | if constexpr (is_causal) { 225 | kv_iters = (seq_idx * K::qo_height) - 1 + (CONSUMER_WARPGROUPS * K::qo_height); 226 | kv_iters = (kv_iters / K::kv_height); 227 | } 228 | else { kv_iters = kv_blocks - 1; } 229 | 230 | wait(qsmem_semaphore, 0); 231 | 232 | for (auto kv_idx = 0; kv_idx <= kv_iters; kv_idx++) { 233 | rt_fl<16, K::kv_height> att_block; 234 | rt att_block_mma; 235 | 236 | wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2); 237 | warpgroup::mm_ABt(att_block, q_o_smem[warpgroupid].q, k_smem[(kv_idx)%K::stages]); 238 | col_vec> max_vec_scaled; 239 | 240 | #if defined(TK_ATTN_IS_FP8) 241 | mul(max_vec_last_scaled, max_vec, scale); 242 | #else 243 | if constexpr (D == 64) { mul(max_vec_last_scaled, max_vec, 1.44269504089f*0.125f); } 244 | else if constexpr (D == 128) { mul(max_vec_last_scaled, max_vec, 1.44269504089f*0.08838834764f); } 245 | else { mul(max_vec_last_scaled, max_vec, 1.44269504089f*0.0625f); } 246 | #endif 247 | 248 | warpgroup::mma_async_wait(); 249 | if constexpr (D >= 128) { 250 | if(warpgroup::laneid() == 0) arrive(qk_done[(kv_idx)%K::stages], 1); 251 | } 252 | 253 | if constexpr (is_causal) { 254 | if (kv_idx == kv_iters-1 || kv_idx == kv_iters) { 255 | const int q_blk = (seq_idx * (K::qo_height/kittens::TILE_ROW_DIM)) + warpid; 256 | int k_blk = (kv_idx * (K::kv_height/kittens::TILE_ROW_DIM)); 257 | 258 | #pragma unroll 259 | for (auto j = 0; j < (K::kv_height/kittens::TILE_ROW_DIM); j++) { 260 | auto k_idx = k_blk + j; 261 | auto &attn_subtile = reinterpret_cast&>(att_block.tiles[0][j]); 262 | 263 | if (k_idx > q_blk) { neg_infty (attn_subtile); } 264 | else if (k_idx == q_blk) { 
make_causal(attn_subtile, attn_subtile, kittens::base_types::constants::neg_infty()); } 265 | __syncwarp(); 266 | } 267 | } 268 | } 269 | else { 270 | if (kv_idx == kv_iters && g.N % K::kv_height != 0) { 271 | right_fill(att_block, att_block, g.N % K::kv_height, kittens::base_types::constants::neg_infty()); 272 | } 273 | } 274 | 275 | row_max(max_vec, att_block, max_vec); 276 | 277 | #if defined(TK_ATTN_IS_FP8) 278 | mul(att_block, att_block, scale); 279 | mul(max_vec_scaled, max_vec, scale); 280 | #else 281 | if constexpr (D == 64) { 282 | mul(att_block, att_block, 1.44269504089f*0.125f); 283 | mul(max_vec_scaled, max_vec, 1.44269504089f*0.125f); 284 | } 285 | else if constexpr (D == 128) { 286 | mul(att_block, att_block, 1.44269504089f*0.08838834764f); 287 | mul(max_vec_scaled, max_vec, 1.44269504089f*0.08838834764f); 288 | } 289 | else { 290 | mul(att_block, att_block, 1.44269504089f*0.0625f); 291 | mul(max_vec_scaled, max_vec, 1.44269504089f*0.0625f); 292 | } 293 | #endif 294 | 295 | sub_row(att_block, att_block, max_vec_scaled); 296 | exp2(att_block, att_block); 297 | sub(max_vec_last_scaled, max_vec_last_scaled, max_vec_scaled); 298 | exp2(max_vec_last_scaled, max_vec_last_scaled); 299 | mul(norm_vec, norm_vec, max_vec_last_scaled); 300 | row_sum(norm_vec, att_block, norm_vec); 301 | copy(att_block_mma, att_block); 302 | mul_row(o_reg, o_reg, max_vec_last_scaled); 303 | 304 | wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2); 305 | warpgroup::mma_AB(o_reg, att_block_mma, v_smem[(kv_idx)%K::stages]); 306 | warpgroup::mma_async_wait(); 307 | 308 | if(warpgroup::laneid() == 0) arrive(compute_done[(kv_idx)%K::stages], 1); 309 | } 310 | 311 | div_row(o_reg, o_reg, norm_vec); 312 | warpgroup::store(q_o_smem[warpgroupid].o, o_reg); 313 | warpgroup::sync(warpgroupid+4); 314 | 315 | if (warpid % 4 == 0) { 316 | coord o_tile_idx = {blockIdx.z, blockIdx.y, (seq_idx) + warpgroupid, 0}; 317 | tma::store_async(g.o, q_o_smem[warpgroupid].o, o_tile_idx); 318 | } 319 | 320 | // mul(max_vec_scaled, max_vec_scaled, 0.69314718056f); 321 | // log(norm_vec, norm_vec); 322 | // add(norm_vec, norm_vec, max_vec_scaled); 323 | 324 | // if constexpr (D == 64) { mul(norm_vec, norm_vec, -8.0f); } 325 | // else { mul(norm_vec, norm_vec, -11.313708499f); } 326 | 327 | // warpgroup::store(l_smem[warpgroupid], norm_vec); 328 | // warpgroup::sync(warpgroupid+4); 329 | 330 | // if (warpid % 4 == 0) { 331 | // coord tile_idx = {blockIdx.z, blockIdx.y, 0, (seq_idx) + warpgroupid}; 332 | // tma::store_async(g.l, l_smem[warpgroupid], tile_idx); 333 | // } 334 | tma::store_async_wait(); 335 | } 336 | } 337 | 338 | #include "pyutils/torch_helpers.cuh" 339 | #include 340 | #include 341 | 342 | std::vector 343 | #if TK_ATTN_IS_FP8 344 | attention_forward(const torch::Tensor &q, const torch::Tensor &k, const torch::Tensor &v, const torch::Tensor &scale_q, const torch::Tensor &scale_k, bool causal) 345 | #else 346 | attention_forward(const torch::Tensor &q, const torch::Tensor &k, const torch::Tensor &v, bool causal) 347 | #endif 348 | { 349 | CHECK_CUDA(q); 350 | CHECK_CUDA(k); 351 | CHECK_CUDA(v); 352 | #if TK_ATTN_IS_FP8 353 | CHECK_CUDA(scale_q); 354 | CHECK_CUDA(scale_k); 355 | #endif 356 | 357 | TORCH_CHECK(q.device() == k.device(), "Q and K tensors must be on the same device"); 358 | TORCH_CHECK(q.device() == v.device(), "Q and V tensors must be on the same device"); 359 | 360 | TORCH_CHECK(q.dim() == 4, "Q tensor must have 4 dimensions"); 361 | TORCH_CHECK(k.dim() == 4, "K tensor must have 4 dimensions"); 
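    // q, k and v are all expected in (batch, heads, seq_len, head_dim) layout; the checks
    // below enforce the same 4-D contract as the Python-side validators in nn.py.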
362 | TORCH_CHECK(v.dim() == 4, "V tensor must have 4 dimensions"); 363 | 364 | auto batch = q.size(0); 365 | auto seq_len_q = q.size(2); 366 | auto seq_len_kv = k.size(2); 367 | auto head_dim = q.size(3); 368 | auto is_causal = causal; 369 | auto qo_heads = q.size(1); 370 | auto kv_heads = k.size(1); 371 | 372 | // check to see that these dimensions match for all inputs 373 | TORCH_CHECK(q.size(0) == batch, "Q batch dimension - idx 0 - must match for all inputs"); 374 | TORCH_CHECK(k.size(0) == batch, "K batch dimension - idx 0 - must match for all inputs"); 375 | TORCH_CHECK(v.size(0) == batch, "V batch dimension - idx 0 - must match for all inputs"); 376 | 377 | TORCH_CHECK(q.size(2) == seq_len_q, "Q sequence length dimension - idx 2 - must match for all inputs"); 378 | TORCH_CHECK(k.size(2) == seq_len_kv, "K sequence length dimension - idx 2 - must match for all inputs"); 379 | TORCH_CHECK(v.size(2) == seq_len_kv, "V sequence length dimension - idx 2 - must match for all inputs"); 380 | 381 | TORCH_CHECK(q.size(3) == head_dim, "Q head dimension - idx 3 - must match for all non-vector inputs"); 382 | TORCH_CHECK(k.size(3) == head_dim, "K head dimension - idx 3 - must match for all non-vector inputs"); 383 | TORCH_CHECK(v.size(3) == head_dim, "V head dimension - idx 3 - must match for all non-vector inputs"); 384 | 385 | TORCH_CHECK(qo_heads >= kv_heads, "QO heads must be greater than or equal to KV heads"); 386 | TORCH_CHECK(qo_heads % kv_heads == 0, "QO heads must be divisible by KV heads"); 387 | TORCH_CHECK(q.size(1) == qo_heads, "QO head dimension - idx 1 - must match for all inputs"); 388 | TORCH_CHECK(k.size(1) == kv_heads, "KV head dimension - idx 1 - must match for all inputs"); 389 | TORCH_CHECK(v.size(1) == kv_heads, "KV head dimension - idx 1 - must match for all inputs"); 390 | 391 | #if TK_ATTN_IS_FP8 392 | TORCH_CHECK(scale_q.dtype() == torch::kFloat32, "Q scale tensor must be of type float32"); 393 | TORCH_CHECK(scale_k.dtype() == torch::kFloat32, "K scale tensor must be of type float32"); 394 | 395 | TORCH_CHECK(scale_q.dim() == 2, "Q scale tensor must have 2 dimensions"); 396 | TORCH_CHECK(scale_k.dim() == 2, "K scale tensor must have 2 dimensions"); 397 | 398 | TORCH_CHECK(scale_q.size(0) == batch, "Q scale batch dimension - idx 0 - must match for all inputs"); 399 | TORCH_CHECK(scale_k.size(0) == batch, "K scale batch dimension - idx 0 - must match for all inputs"); 400 | TORCH_CHECK(scale_q.size(1) == qo_heads, "Q scale head dimension - idx 1 - must match for all inputs"); 401 | TORCH_CHECK(scale_k.size(1) == kv_heads, "K scale head dimension - idx 1 - must match for all inputs"); 402 | #endif 403 | 404 | torch::DeviceGuard device_guard(q.device()); 405 | 406 | torch:: Tensor q_ = q.contiguous(); 407 | torch:: Tensor k_ = k.contiguous(); 408 | torch:: Tensor v_ = v.contiguous(); 409 | 410 | auto hr = qo_heads / kv_heads; 411 | 412 | void* q_ptr = q_.data_ptr(); 413 | void* k_ptr = k_.data_ptr(); 414 | void* v_ptr = v_.data_ptr(); 415 | 416 | QKDType* d_q = reinterpret_cast(q_ptr); 417 | QKDType* d_k = reinterpret_cast(k_ptr); 418 | DType* d_v = reinterpret_cast(v_ptr); 419 | 420 | // for the returned outputs 421 | torch::Tensor o = torch::empty({static_cast(batch), 422 | static_cast(qo_heads), 423 | static_cast(seq_len_q), 424 | static_cast(head_dim)}, v.options().memory_format(at::MemoryFormat::Contiguous)); 425 | 426 | // auto l_vec_stride_h = (seq_len_q * sizeof(float) + 15) / 16 * 16 / sizeof(float); 427 | // torch::Tensor l_vec = 
torch::empty_strided({static_cast(batch), 428 | // static_cast(qo_heads), 429 | // static_cast(seq_len_q)}, 430 | // {static_cast(qo_heads * l_vec_stride_h), 431 | // static_cast(l_vec_stride_h), 432 | // 1}, 433 | // torch::dtype(torch::kFloat).device(q.device())); 434 | 435 | DType* o_ptr = reinterpret_cast(o.data_ptr()); 436 | DType* d_o = reinterpret_cast(o_ptr); 437 | 438 | // float* l_ptr = reinterpret_cast(l_vec.data_ptr()); 439 | // float* d_l = reinterpret_cast(l_ptr); 440 | 441 | #if TK_ATTN_IS_FP8 442 | torch::Tensor scale_q_ = scale_q.contiguous(); 443 | torch::Tensor scale_k_ = scale_k.contiguous(); 444 | 445 | void* scale_q_ptr = scale_q_.data_ptr(); 446 | void* scale_k_ptr = scale_k_.data_ptr(); 447 | 448 | float* d_scale_q = reinterpret_cast(scale_q_ptr); 449 | float* d_scale_k = reinterpret_cast(scale_k_ptr); 450 | #endif 451 | 452 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 453 | 454 | if (head_dim == 64) { 455 | constexpr int CONSUMER_WARPGROUPS = (3); 456 | constexpr int PRODUCER_WARPGROUPS = (1); 457 | constexpr int NUM_WARPGROUPS = (CONSUMER_WARPGROUPS+PRODUCER_WARPGROUPS); 458 | constexpr int NUM_WORKERS = (NUM_WARPGROUPS*kittens::WARPGROUP_WARPS); 459 | 460 | using q_tile = st::qo_height, fwd_attend_ker_tile_dims<64>::tile_width>; 461 | using k_tile = st::kv_height, fwd_attend_ker_tile_dims<64>::tile_width>; 462 | using v_tile = st::kv_height, fwd_attend_ker_tile_dims<64>::tile_width>; 463 | // using l_col_vec = col_vec::qo_height, fwd_attend_ker_tile_dims<64>::tile_width>>; 464 | using o_tile = st::qo_height, fwd_attend_ker_tile_dims<64>::tile_width>; 465 | 466 | using q_global = gl; 467 | using k_global = gl; 468 | using v_global = gl; 469 | // using l_global = gl; 470 | using o_global = gl; 471 | 472 | using globals = fwd_globals<64>; 473 | 474 | q_global qg_arg{d_q, static_cast(batch), static_cast(qo_heads), static_cast(seq_len_q), nullptr}; 475 | k_global kg_arg{d_k, static_cast(batch), static_cast(kv_heads), static_cast(seq_len_kv), nullptr}; 476 | v_global vg_arg{d_v, static_cast(batch), static_cast(kv_heads), static_cast(seq_len_kv), nullptr}; 477 | // l_global lg_arg{d_l, static_cast(batch), static_cast(qo_heads), nullptr, static_cast(l_vec_stride_h)}; 478 | o_global og_arg{d_o, static_cast(batch), static_cast(qo_heads), static_cast(seq_len_q), nullptr}; 479 | 480 | #if TK_ATTN_IS_FP8 481 | globals g{qg_arg, kg_arg, vg_arg/* , lg_arg */, og_arg, d_scale_q, d_scale_k, static_cast(seq_len_kv), static_cast(hr)}; 482 | #else 483 | globals g{qg_arg, kg_arg, vg_arg/* , lg_arg */, og_arg, static_cast(seq_len_kv), static_cast(hr)}; 484 | #endif 485 | 486 | auto mem_size = kittens::MAX_SHARED_MEMORY; 487 | // auto threads = NUM_WORKERS * kittens::WARP_THREADS; 488 | 489 | constexpr int block_size_m = CONSUMER_WARPGROUPS*fwd_attend_ker_tile_dims<64>::qo_height; 490 | int num_m_blocks = (seq_len_q + block_size_m - 1) / block_size_m; 491 | dim3 grid(num_m_blocks, qo_heads, batch); 492 | 493 | if (is_causal) { 494 | CHECK_CUDA_ERROR(cudaFuncSetAttribute( 495 | fwd_attend_ker<64, true, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS>, 496 | cudaFuncAttributeMaxDynamicSharedMemorySize, 497 | mem_size 498 | )); 499 | 500 | fwd_attend_ker<64, true, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS><<>>(g); 501 | } 502 | else { 503 | CHECK_CUDA_ERROR(cudaFuncSetAttribute( 504 | fwd_attend_ker<64, false, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS>, 505 | cudaFuncAttributeMaxDynamicSharedMemorySize, 506 | mem_size 507 | )); 508 | 509 | fwd_attend_ker<64, false, 
CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS><<>>(g); 510 | } 511 | } 512 | 513 | if (head_dim == 128) { 514 | constexpr int CONSUMER_WARPGROUPS = (3); 515 | constexpr int PRODUCER_WARPGROUPS = (1); 516 | constexpr int NUM_WARPGROUPS = (CONSUMER_WARPGROUPS+PRODUCER_WARPGROUPS); 517 | constexpr int NUM_WORKERS = (NUM_WARPGROUPS*kittens::WARPGROUP_WARPS); 518 | 519 | using q_tile = st::qo_height, fwd_attend_ker_tile_dims<128>::tile_width>; 520 | using k_tile = st::kv_height, fwd_attend_ker_tile_dims<128>::tile_width>; 521 | using v_tile = st::kv_height, fwd_attend_ker_tile_dims<128>::tile_width>; 522 | // using l_col_vec = col_vec::qo_height, fwd_attend_ker_tile_dims<128>::tile_width>>; 523 | using o_tile = st::qo_height, fwd_attend_ker_tile_dims<128>::tile_width>; 524 | 525 | using q_global = gl; 526 | using k_global = gl; 527 | using v_global = gl; 528 | // using l_global = gl; 529 | using o_global = gl; 530 | 531 | using globals = fwd_globals<128>; 532 | 533 | q_global qg_arg{d_q, static_cast(batch), static_cast(qo_heads), static_cast(seq_len_q), nullptr}; 534 | k_global kg_arg{d_k, static_cast(batch), static_cast(kv_heads), static_cast(seq_len_kv), nullptr}; 535 | v_global vg_arg{d_v, static_cast(batch), static_cast(kv_heads), static_cast(seq_len_kv), nullptr}; 536 | // l_global lg_arg{d_l, static_cast(batch), static_cast(qo_heads), nullptr, static_cast(l_vec_stride_h)}; 537 | o_global og_arg{d_o, static_cast(batch), static_cast(qo_heads), static_cast(seq_len_q), nullptr}; 538 | 539 | #if TK_ATTN_IS_FP8 540 | globals g{qg_arg, kg_arg, vg_arg/* , lg_arg */, og_arg, d_scale_q, d_scale_k, static_cast(seq_len_kv), static_cast(hr)}; 541 | #else 542 | globals g{qg_arg, kg_arg, vg_arg/* , lg_arg */, og_arg, static_cast(seq_len_kv), static_cast(hr)}; 543 | #endif 544 | 545 | auto mem_size = kittens::MAX_SHARED_MEMORY; 546 | // auto threads = NUM_WORKERS * kittens::WARP_THREADS; 547 | 548 | constexpr int block_size_m = CONSUMER_WARPGROUPS*fwd_attend_ker_tile_dims<128>::qo_height; 549 | int num_m_blocks = (seq_len_q + block_size_m - 1) / block_size_m; 550 | dim3 grid(num_m_blocks, qo_heads, batch); 551 | 552 | if (is_causal) { 553 | CHECK_CUDA_ERROR(cudaFuncSetAttribute( 554 | fwd_attend_ker<128, true, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS>, 555 | cudaFuncAttributeMaxDynamicSharedMemorySize, 556 | mem_size 557 | )); 558 | 559 | fwd_attend_ker<128, true, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS><<>>(g); 560 | } 561 | else { 562 | CHECK_CUDA_ERROR(cudaFuncSetAttribute( 563 | fwd_attend_ker<128, false, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS>, 564 | cudaFuncAttributeMaxDynamicSharedMemorySize, 565 | mem_size 566 | )); 567 | 568 | fwd_attend_ker<128, false, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS><<>>(g); 569 | } 570 | } 571 | 572 | if (head_dim == 256) { 573 | constexpr int CONSUMER_WARPGROUPS = (2); 574 | constexpr int PRODUCER_WARPGROUPS = (1); 575 | constexpr int NUM_WARPGROUPS = (CONSUMER_WARPGROUPS+PRODUCER_WARPGROUPS); 576 | constexpr int NUM_WORKERS = (NUM_WARPGROUPS*kittens::WARPGROUP_WARPS); 577 | 578 | using q_tile = st::qo_height, fwd_attend_ker_tile_dims<256>::tile_width>; 579 | using k_tile = st::kv_height, fwd_attend_ker_tile_dims<256>::tile_width>; 580 | using v_tile = st::kv_height, fwd_attend_ker_tile_dims<256>::tile_width>; 581 | // using l_col_vec = col_vec::qo_height, fwd_attend_ker_tile_dims<256>::tile_width>>; 582 | using o_tile = st::qo_height, fwd_attend_ker_tile_dims<256>::tile_width>; 583 | 584 | using q_global = gl; 585 | using k_global = gl; 586 | using v_global = gl; 587 
| // using l_global = gl; 588 | using o_global = gl; 589 | 590 | using globals = fwd_globals<256>; 591 | 592 | q_global qg_arg{d_q, static_cast(batch), static_cast(qo_heads), static_cast(seq_len_q), nullptr}; 593 | k_global kg_arg{d_k, static_cast(batch), static_cast(kv_heads), static_cast(seq_len_kv), nullptr}; 594 | v_global vg_arg{d_v, static_cast(batch), static_cast(kv_heads), static_cast(seq_len_kv), nullptr}; 595 | // l_global lg_arg{d_l, static_cast(batch), static_cast(qo_heads), nullptr, static_cast(l_vec_stride_h)}; 596 | o_global og_arg{d_o, static_cast(batch), static_cast(qo_heads), static_cast(seq_len_q), nullptr}; 597 | 598 | #if TK_ATTN_IS_FP8 599 | globals g{qg_arg, kg_arg, vg_arg/* , lg_arg */, og_arg, d_scale_q, d_scale_k, static_cast(seq_len_kv), static_cast(hr)}; 600 | #else 601 | globals g{qg_arg, kg_arg, vg_arg/* , lg_arg */, og_arg, static_cast(seq_len_kv), static_cast(hr)}; 602 | #endif 603 | 604 | auto mem_size = kittens::MAX_SHARED_MEMORY; 605 | // auto threads = NUM_WORKERS * kittens::WARP_THREADS; 606 | 607 | constexpr int block_size_m = CONSUMER_WARPGROUPS*fwd_attend_ker_tile_dims<256>::qo_height; 608 | int num_m_blocks = (seq_len_q + block_size_m - 1) / block_size_m; 609 | dim3 grid(num_m_blocks, qo_heads, batch); 610 | 611 | if (is_causal) { 612 | CHECK_CUDA_ERROR(cudaFuncSetAttribute( 613 | fwd_attend_ker<256, true, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS>, 614 | cudaFuncAttributeMaxDynamicSharedMemorySize, 615 | mem_size 616 | )); 617 | 618 | fwd_attend_ker<256, true, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS><<>>(g); 619 | } 620 | else { 621 | CHECK_CUDA_ERROR(cudaFuncSetAttribute( 622 | fwd_attend_ker<256, false, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS>, 623 | cudaFuncAttributeMaxDynamicSharedMemorySize, 624 | mem_size 625 | )); 626 | 627 | fwd_attend_ker<256, false, CONSUMER_WARPGROUPS, PRODUCER_WARPGROUPS><<>>(g); 628 | } 629 | } 630 | 631 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 632 | 633 | return {o/* , l_vec */}; 634 | } 635 | """ 636 | 637 | 638 | @functools.cache 639 | def load_tk_attention_module(dtype, is_fp8=False): 640 | extra_cuda_cflags = [ 641 | "-std=c++20", 642 | # "-U__CUDA_NO_HALF_OPERATORS__", 643 | # "-U__CUDA_NO_HALF_CONVERSIONS__", 644 | # "-U__CUDA_NO_HALF2_OPERATORS__", 645 | # "-U__CUDA_NO_HALF2_CONVERSIONS__", 646 | # "-U__CUDA_NO_BFLOAT16_OPERATORS__", 647 | # "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 648 | # "-U__CUDA_NO_BFLOAT162_OPERATORS__", 649 | # "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 650 | "--expt-extended-lambda", 651 | "--expt-relaxed-constexpr", 652 | "--use_fast_math", 653 | # "--ptxas-options=-v", # printing out number of registers 654 | # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers 655 | "-lineinfo", 656 | "-O3", 657 | "-Xcudafe --diag_suppress=2361", 658 | "--threads=4", 659 | "-DNDEBUG", 660 | "-DKITTENS_HOPPER", 661 | ] 662 | if dtype == torch.float16: 663 | extra_cuda_cflags.append("-DTK_ATTN_DTYPE_FP16") 664 | elif dtype == torch.bfloat16: 665 | extra_cuda_cflags.append("-DTK_ATTN_DTYPE_BF16") 666 | else: 667 | raise ValueError(f"Unsupported dtype: {dtype}") 668 | 669 | if is_fp8: 670 | extra_cuda_cflags.append("-DTK_ATTN_IS_FP8") 671 | 672 | old_torch_cuda_arch_list = os.getenv("TORCH_CUDA_ARCH_LIST") 673 | os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0a" 674 | try: 675 | module = load_inline( 676 | name=f"quantum_attn_tk_attention_dtype_{str(dtype).replace('torch.', '')}_is_fp8_{is_fp8}", 677 | cpp_sources=[ 678 | "std::vector attention_forward(const 
torch::Tensor &q, const torch::Tensor &k, const torch::Tensor &v, const torch::Tensor &scale_q, const torch::Tensor &scale_k, bool causal);" 679 | if is_fp8 680 | else "std::vector attention_forward(const torch::Tensor &q, const torch::Tensor &k, const torch::Tensor &v, bool causal);" 681 | ], 682 | cuda_sources=[TK_ATTENTION_SOURCE], 683 | extra_cflags=["-std=c++20", "-O3", "-DNDEBUG"], 684 | extra_cuda_cflags=extra_cuda_cflags, 685 | extra_ldflags=["-lcuda", "-lcudart"], 686 | extra_include_paths=[utils.get_tk_include_dir()], 687 | functions=["attention_forward"], 688 | verbose=True, 689 | ) 690 | finally: 691 | if old_torch_cuda_arch_list is None: 692 | os.environ.pop("TORCH_CUDA_ARCH_LIST") 693 | else: 694 | os.environ["TORCH_CUDA_ARCH_LIST"] = old_torch_cuda_arch_list 695 | return module 696 | -------------------------------------------------------------------------------- /src/quantum_attn/tk/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import quantum_attn 4 | 5 | 6 | def get_tk_include_dir(): 7 | return os.path.join(os.path.dirname(quantum_attn.__file__), "tk_repo", "include") 8 | -------------------------------------------------------------------------------- /src/quantum_attn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WaveSpeedAI/QuantumAttention/0baf4b3fd3c568964acd4f176c35e8f073c70c20/src/quantum_attn/utils/__init__.py -------------------------------------------------------------------------------- /src/quantum_attn/utils/checks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import importlib 4 | 5 | import torch 6 | from packaging import version 7 | 8 | 9 | @torch.compiler.assume_constant_result 10 | def get_constant_attr(module, attr): 11 | obj = importlib.import_module(module) 12 | attr = attr.split(".") 13 | for a in attr: 14 | obj = getattr(obj, a) 15 | return obj 16 | 17 | 18 | def torch_version_compare(op, v): 19 | return getattr(version.parse(torch.__version__).release, f"__{op}__")(version.parse(v).release) 20 | 21 | 22 | @functools.cache 23 | def has_triton_package() -> bool: 24 | try: 25 | import triton 26 | 27 | return triton is not None 28 | except ImportError: 29 | return False 30 | 31 | 32 | def triton_version_compare(op, v): 33 | if not has_triton_package(): 34 | return None 35 | import triton 36 | 37 | return getattr(version.parse(triton.__version__).release, f"__{op}__")(version.parse(v).release) 38 | 39 | 40 | def has_triton_language(attr): 41 | if not has_triton_package(): 42 | return False 43 | import triton.language as tl 44 | 45 | return hasattr(tl, attr) 46 | 47 | 48 | def has_triton_tma_support(): 49 | if not has_triton_language("_experimental_descriptor_load"): 50 | return False 51 | 52 | import triton.language as tl 53 | 54 | return hasattr(tl.extra.cuda, "experimental_tensormap_fenceproxy_acquire") 55 | 56 | 57 | def is_nvidia_cuda(): 58 | return torch.version.hip is None and torch.cuda.is_available() 59 | 60 | 61 | def cuda_capability_compare(op, major, minor, *, device=None): 62 | if not is_nvidia_cuda(): 63 | return None 64 | return getattr(torch.cuda.get_device_capability(device), f"__{op}__")((major, minor)) 65 | 66 | 67 | def torch_cuda_version(): 68 | if torch.version.cuda is None: 69 | return (0, 0) 70 | cuda_version = str(torch.version.cuda) 71 | return tuple(int(x) for x in cuda_version.split("."))[:2] 72 | 73 | 74 | def 
--------------------------------------------------------------------------------
/src/quantum_attn/utils/types.py:
--------------------------------------------------------------------------------
1 | def is_fp8_type(dtype):
2 |     return dtype.is_floating_point and dtype.itemsize == 1
3 | 
4 | 
5 | def is_8bit_type(dtype):
6 |     return dtype.itemsize == 1
7 | 
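
Example (illustrative, not a file in this repository): a quick sanity check of the two dtype predicates above, assuming a PyTorch build that exposes the FP8 dtypes. torch.float8_e4m3fn is one byte and floating point, so it satisfies both predicates; torch.int8 is one byte but not floating point; torch.bfloat16 is two bytes and satisfies neither.

import torch

from quantum_attn.utils.types import is_8bit_type, is_fp8_type

assert is_fp8_type(torch.float8_e4m3fn)   # 1-byte floating point
assert is_8bit_type(torch.float8_e4m3fn)  # 1 byte
assert not is_fp8_type(torch.int8)        # 1 byte, but integer
assert is_8bit_type(torch.int8)
assert not is_8bit_type(torch.bfloat16)   # 2 bytes
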
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WaveSpeedAI/QuantumAttention/0baf4b3fd3c568964acd4f176c35e8f073c70c20/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_interface.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | import quantum_attn
4 | import torch
5 | import torch.nn.functional as F
6 | from quantum_attn import quantum_attn_interface
7 | from torch.nn.attention import sdpa_kernel, SDPBackend
8 | 
9 | if not torch.cuda.is_available():
10 |     pytest.skip("CUDA is not available", allow_module_level=True)
11 | 
12 | 
13 | def flash_attention(query, key, value, is_causal=False):
14 |     with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
15 |         return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
16 | 
17 | 
18 | def cudnn_sdpa(query, key, value, is_causal=False):
19 |     with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
20 |         return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
21 | 
22 | 
23 | def vanilla_attention(query, key, value, is_causal=False):
24 |     return quantum_attn_interface.attn_func(query, key, value, is_causal=is_causal)
25 | 
26 | 
27 | def fp8_attention(query, key, value, is_causal=False):
28 |     return quantum_attn_interface.fp8_attn_func(query, key, value, is_causal=is_causal)
29 | 
30 | 
31 | @torch.no_grad()
32 | def _test_attn_func(B, H, S_Q, S_KV, D, dtype, device, is_causal, force_eager_fallback, is_fp8=False):
33 |     if is_causal and S_Q != S_KV:
34 |         pytest.skip("Causal attention is only supported for S_Q == S_KV")
35 | 
36 |     if is_fp8:
37 |         attn_func = fp8_attention
38 |     else:
39 |         attn_func = vanilla_attention
40 | 
41 |     torch.manual_seed(0)
42 |     query = torch.randn(B, H, S_Q, D, dtype=dtype, device=device)
43 |     key = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)
44 |     value = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)
45 | 
46 |     with quantum_attn.config.patch(
47 |         {
48 |             "attention.force_eager_fallback": force_eager_fallback,
49 |         }
50 |     ):
51 |         try:
52 |             attn_out = attn_func(query, key, value, is_causal=is_causal)
53 |         except ValueError as e:
54 |             pytest.skip(str(e))
55 | 
56 |     fa_out = flash_attention(query, key, value, is_causal=is_causal)
57 | 
58 |     rmse = torch.sqrt(F.mse_loss(attn_out, fa_out))
59 |     print(f"RMSE: {rmse}")
60 |     assert rmse < 1e-2, f"RMSE: {rmse}"
61 | 
62 | 
63 | @pytest.mark.parametrize("B", [1, 2])
64 | @pytest.mark.parametrize("H", [8, 16])
65 | @pytest.mark.parametrize("S_Q", [1024, 999])
66 | @pytest.mark.parametrize("S_KV", [1024, 999])
67 | @pytest.mark.parametrize("D", [64, 128, 256])
68 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
69 | @pytest.mark.parametrize("device", ["cuda"])
70 | @pytest.mark.parametrize("is_causal", [False, True])
71 | @pytest.mark.parametrize("force_eager_fallback", [False])
72 | def test_attn_func(B, H, S_Q, S_KV, D, dtype, device, is_causal, force_eager_fallback):
73 |     _test_attn_func(B, H, S_Q, S_KV, D, dtype, device, is_causal, force_eager_fallback)
74 | 
75 | 
76 | @pytest.mark.parametrize("B", [1, 2])
77 | @pytest.mark.parametrize("H", [8, 16])
78 | @pytest.mark.parametrize("S_Q", [1024, 1000])
79 | @pytest.mark.parametrize("S_KV", [1024, 1000])
80 | @pytest.mark.parametrize("D", [64, 128, 256])
81 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
82 | @pytest.mark.parametrize("device", ["cuda"])
83 | @pytest.mark.parametrize("is_causal", [False, True])
84 | @pytest.mark.parametrize("force_eager_fallback", [False])
85 | def test_fp8_attn_func(B, H, S_Q, S_KV, D, dtype, device, is_causal, force_eager_fallback):
86 |     _test_attn_func(B, H, S_Q, S_KV, D, dtype, device, is_causal, force_eager_fallback, is_fp8=True)
87 | 
88 | 
89 | @torch.no_grad()
90 | def _test_benchmark_attn_func(D, dtype, device, is_causal, is_fp8=False):
91 |     import triton
92 | 
93 |     torch.manual_seed(0)
94 | 
95 |     B = 16
96 |     H = 16
97 |     S_Q = 8192
98 |     S_KV = 8192
99 | 
100 |     query = torch.randn(B, H, S_Q, D, dtype=dtype, device=device)
101 |     key = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)
102 |     value = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)
103 | 
104 |     def attention_fn():
105 |         if is_fp8:
106 |             fp8_attention(query, key, value, is_causal)
107 |         else:
108 |             vanilla_attention(query, key, value, is_causal)
109 | 
110 |     try:
111 |         attention_fn()
112 |     except ValueError as e:
113 |         pytest.skip(str(e))
114 | 
115 |     def fa_fn():
116 |         flash_attention(query, key, value, is_causal)
117 | 
118 |     def cudnn_sdpa_fn():
119 |         cudnn_sdpa(query, key, value, is_causal)
120 | 
121 |     flops_per_matmul = 2 * B * H * S_Q * S_KV * D
122 |     total_flops = 2 * flops_per_matmul
123 | 
124 |     if is_causal:
125 |         total_flops //= 2
126 | 
127 |     ms_fa = triton.testing.do_bench(fa_fn)
128 |     tflops_fa = total_flops * 1e-12 / (ms_fa * 1e-3)
129 |     print(f"TFLOPS (Flash Attention): {tflops_fa:.2f}")
130 | 
131 |     if D <= 128:
132 |         ms_cudnn_sdpa = triton.testing.do_bench(cudnn_sdpa_fn)
133 |         tflops_cudnn_sdpa = total_flops * 1e-12 / (ms_cudnn_sdpa * 1e-3)
134 |         print(f"TFLOPS (CUDNN SDPA): {tflops_cudnn_sdpa:.2f}")
135 | 
136 |     ms_quantum_attention = triton.testing.do_bench(attention_fn)
137 |     tflops_quantum_attention = total_flops * 1e-12 / (ms_quantum_attention * 1e-3)
138 |     print(f"TFLOPS (Quantum Attention): {tflops_quantum_attention:.2f}")
139 | 
140 | 
141 | @pytest.mark.parametrize("D", [64, 128, 256])
142 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
143 | @pytest.mark.parametrize("device", ["cuda"])
144 | @pytest.mark.parametrize("is_causal", [False, True])
145 | def test_benchmark_attn_func(D, dtype, device, is_causal, is_fp8=False):
146 |     _test_benchmark_attn_func(D, dtype, device, is_causal, is_fp8)
147 | 
148 | 
149 | @pytest.mark.parametrize("D", [64, 128, 256])
150 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
151 | @pytest.mark.parametrize("device", ["cuda"])
152 | @pytest.mark.parametrize("is_causal", [False, True])
153 | def test_benchmark_fp8_attn_func(D, dtype, device, is_causal):
154 |     _test_benchmark_attn_func(D, dtype, device, is_causal, is_fp8=True)
155 | 
--------------------------------------------------------------------------------
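
Example (illustrative, not a file in this repository): a minimal standalone sketch mirroring what tests/test_interface.py exercises, comparing fp8_attn_func against the FlashAttention SDPA backend on random inputs. The shapes are arbitrary; the only APIs assumed are the ones the tests above already use.

import torch
import torch.nn.functional as F
from quantum_attn import quantum_attn_interface
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, S, D = 2, 16, 1024, 128  # arbitrary illustrative shapes
dtype, device = torch.bfloat16, "cuda"

torch.manual_seed(0)
query = torch.randn(B, H, S, D, dtype=dtype, device=device)
key = torch.randn(B, H, S, D, dtype=dtype, device=device)
value = torch.randn(B, H, S, D, dtype=dtype, device=device)

# Dynamically quantized FP8 attention from this package.
out = quantum_attn_interface.fp8_attn_func(query, key, value, is_causal=True)

# Reference output from PyTorch's FlashAttention SDPA backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    ref = F.scaled_dot_product_attention(query, key, value, is_causal=True)

print("RMSE vs. flash attention:", torch.sqrt(F.mse_loss(out, ref)).item())

The correctness and benchmark tests above can be run directly with pytest; the -s flag keeps the printed RMSE/TFLOPS output visible, e.g. pytest -s tests/test_interface.py -k benchmark.
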