├── .dev ├── .clang-format ├── .clang_format.hook ├── .cpplint_pre_commit.hook ├── .pre-commit-config-cpp.yaml ├── .pre-commit-config.yaml ├── clear.sh ├── commit-prepare.sh ├── init_dev.sh ├── init_prod.sh ├── init_prod_mini.sh ├── install.sh └── uninstall.sh ├── .github ├── .gitignore └── workflows │ └── issue.yml ├── .gitignore ├── .gitmodules ├── LICENSE ├── MANIFEST.in ├── README.md ├── bench ├── .gitignore ├── NVIDIA_A30.png ├── NVIDIA_A30_ffpa+acc+f16+L1_Speedup.png ├── NVIDIA_A30_ffpa+acc+f32+L1_Speedup.png ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f16+L1_Speedup.png ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f32+L1_Speedup.png ├── NVIDIA_GeForce_RTX_4090.png ├── NVIDIA_GeForce_RTX_4090_ffpa+acc+f16+L1_Speedup.png ├── NVIDIA_GeForce_RTX_4090_ffpa+acc+f32+L1_Speedup.png ├── NVIDIA_L20.png ├── NVIDIA_L20_ffpa+acc+f16+L1_Speedup.png ├── NVIDIA_L20_ffpa+acc+f32+L1_Speedup.png ├── bank_conflicts_check.sh └── bench.sh ├── csrc ├── .gitignore ├── cuffpa │ ├── ffpa_attn_F16F16F16_L1.cu │ ├── ffpa_attn_F16F16F32_L1.cu │ ├── ffpa_attn_templates_L1.cuh │ └── launch_templates.cuh ├── deprecated │ ├── faster_prefill_attn_F16F16F16F16_L1.cu │ └── faster_prefill_attn_F32F16F16F32_L1.cu ├── extension │ ├── .gitignore │ ├── fused_mla_F16F16F16_L1.cu │ ├── fused_mla_F16F16F32_L1.cu │ ├── fused_mla_templates_L1.cuh │ └── launch_templates.cuh └── pybind │ ├── ffpa_attn_api.cc │ └── fused_mla_api.cc ├── env.py ├── ffpa_attn ├── .gitignore ├── __init__.py ├── interface.py └── version.py ├── include ├── .gitignore ├── cuffpa │ ├── cp_async.cuh │ ├── deprecated │ │ ├── mma_utils.cuh │ │ └── smem_swizzle.cuh │ ├── logging.cuh │ ├── mma.cuh │ ├── prefill.cuh │ ├── swizzle.cuh │ ├── utils.cuh │ └── warp.cuh └── extension │ └── .gitignore ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── .gitignore ├── requirements.txt ├── swizzle_layout.py ├── test_ffpa_attn.py └── test_fused_mla.py /.dev/.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignArrayOfStructures: None 7 | AlignConsecutiveMacros: None 8 | AlignConsecutiveAssignments: None 9 | AlignConsecutiveBitFields: None 10 | AlignConsecutiveDeclarations: None 11 | AlignEscapedNewlines: Right 12 | AlignOperands: Align 13 | AlignTrailingComments: true 14 | AllowAllArgumentsOnNextLine: true 15 | AllowAllConstructorInitializersOnNextLine: true 16 | AllowAllParametersOfDeclarationOnNextLine: true 17 | AllowShortEnumsOnASingleLine: true 18 | AllowShortBlocksOnASingleLine: Never 19 | AllowShortCaseLabelsOnASingleLine: false 20 | AllowShortFunctionsOnASingleLine: All 21 | AllowShortLambdasOnASingleLine: All 22 | AllowShortIfStatementsOnASingleLine: Never 23 | AllowShortLoopsOnASingleLine: false 24 | AlwaysBreakAfterDefinitionReturnType: None 25 | AlwaysBreakAfterReturnType: None 26 | AlwaysBreakBeforeMultilineStrings: false 27 | AlwaysBreakTemplateDeclarations: MultiLine 28 | AttributeMacros: 29 | - __capability 30 | BinPackArguments: true 31 | BinPackParameters: true 32 | BraceWrapping: 33 | AfterCaseLabel: false 34 | AfterClass: false 35 | AfterControlStatement: Never 36 | AfterEnum: false 37 | AfterFunction: false 38 | AfterNamespace: false 39 | AfterObjCDeclaration: false 40 | AfterStruct: false 41 | AfterUnion: false 42 | AfterExternBlock: false 43 | BeforeCatch: false 44 | BeforeElse: false 45 | BeforeLambdaBody: false 
46 | BeforeWhile: false 47 | IndentBraces: false 48 | SplitEmptyFunction: true 49 | SplitEmptyRecord: true 50 | SplitEmptyNamespace: true 51 | BreakBeforeBinaryOperators: None 52 | BreakBeforeConceptDeclarations: true 53 | BreakBeforeBraces: Attach 54 | BreakBeforeInheritanceComma: false 55 | BreakInheritanceList: BeforeColon 56 | BreakBeforeTernaryOperators: true 57 | BreakConstructorInitializersBeforeComma: false 58 | BreakConstructorInitializers: BeforeColon 59 | BreakAfterJavaFieldAnnotations: false 60 | BreakStringLiterals: true 61 | ColumnLimit: 80 62 | # CommentPragmas: '^ IWYU pragma:' 63 | # CommentPragmas: '^[^ ]' 64 | CommentPragmas: '^\\.+' 65 | CompactNamespaces: false 66 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 67 | ConstructorInitializerIndentWidth: 4 68 | ContinuationIndentWidth: 4 69 | Cpp11BracedListStyle: true 70 | DeriveLineEnding: true 71 | DerivePointerAlignment: false 72 | DisableFormat: false 73 | EmptyLineAfterAccessModifier: Never 74 | EmptyLineBeforeAccessModifier: LogicalBlock 75 | ExperimentalAutoDetectBinPacking: false 76 | FixNamespaceComments: true 77 | ForEachMacros: 78 | - foreach 79 | - Q_FOREACH 80 | - BOOST_FOREACH 81 | IfMacros: 82 | - KJ_IF_MAYBE 83 | IncludeBlocks: Preserve 84 | IncludeCategories: 85 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 86 | Priority: 2 87 | SortPriority: 0 88 | CaseSensitive: false 89 | - Regex: '^(<|"(gtest|gmock|isl|json)/)' 90 | Priority: 3 91 | SortPriority: 0 92 | CaseSensitive: false 93 | - Regex: '.*' 94 | Priority: 1 95 | SortPriority: 0 96 | CaseSensitive: false 97 | IncludeIsMainRegex: '(Test)?$' 98 | IncludeIsMainSourceRegex: '' 99 | IndentAccessModifiers: false 100 | IndentCaseLabels: false 101 | IndentCaseBlocks: false 102 | IndentGotoLabels: true 103 | IndentPPDirectives: None 104 | IndentExternBlock: AfterExternBlock 105 | IndentRequires: false 106 | IndentWidth: 2 107 | IndentWrappedFunctionNames: false 108 | InsertTrailingCommas: None 109 | JavaScriptQuotes: Leave 110 | JavaScriptWrapImports: true 111 | KeepEmptyLinesAtTheStartOfBlocks: true 112 | LambdaBodyIndentation: Signature 113 | MacroBlockBegin: '' 114 | MacroBlockEnd: '' 115 | MaxEmptyLinesToKeep: 1 116 | NamespaceIndentation: None 117 | ObjCBinPackProtocolList: Auto 118 | ObjCBlockIndentWidth: 2 119 | ObjCBreakBeforeNestedBlockParam: true 120 | ObjCSpaceAfterProperty: false 121 | ObjCSpaceBeforeProtocolList: true 122 | PenaltyBreakAssignment: 2 123 | PenaltyBreakBeforeFirstCallParameter: 19 124 | PenaltyBreakComment: 300 125 | PenaltyBreakFirstLessLess: 120 126 | PenaltyBreakString: 1000 127 | PenaltyBreakTemplateDeclaration: 10 128 | PenaltyExcessCharacter: 1000000 129 | PenaltyReturnTypeOnItsOwnLine: 60 130 | PenaltyIndentedWhitespace: 0 131 | PointerAlignment: Left 132 | PPIndentWidth: -1 133 | ReferenceAlignment: Pointer 134 | ReflowComments: false 135 | ShortNamespaceLines: 1 136 | SortIncludes: CaseSensitive 137 | SortJavaStaticImport: Before 138 | SortUsingDeclarations: true 139 | SpaceAfterCStyleCast: false 140 | SpaceAfterLogicalNot: false 141 | SpaceAfterTemplateKeyword: true 142 | SpaceBeforeAssignmentOperators: true 143 | SpaceBeforeCaseColon: false 144 | SpaceBeforeCpp11BracedList: false 145 | SpaceBeforeCtorInitializerColon: true 146 | SpaceBeforeInheritanceColon: true 147 | SpaceBeforeParens: ControlStatements 148 | SpaceAroundPointerQualifiers: Default 149 | SpaceBeforeRangeBasedForLoopColon: true 150 | SpaceInEmptyBlock: false 151 | SpaceInEmptyParentheses: false 152 | SpacesBeforeTrailingComments: 2 153 | 
SpacesInAngles: Never 154 | SpacesInConditionalStatement: false 155 | SpacesInContainerLiterals: true 156 | SpacesInCStyleCastParentheses: false 157 | SpacesInLineCommentPrefix: 158 | Minimum: 1 159 | Maximum: -1 160 | SpacesInParentheses: false 161 | SpacesInSquareBrackets: false 162 | SpaceBeforeSquareBrackets: false 163 | BitFieldColonSpacing: Both 164 | Standard: Latest 165 | StatementAttributeLikeMacros: 166 | - Q_EMIT 167 | StatementMacros: 168 | - Q_UNUSED 169 | - QT_REQUIRE_VERSION 170 | TabWidth: 8 171 | UseCRLF: false 172 | UseTab: Never 173 | WhitespaceSensitiveMacros: 174 | - STRINGIZE 175 | - PP_STRINGIZE 176 | - BOOST_PP_STRINGIZE 177 | - NS_SWIFT_NAME 178 | - CF_SWIFT_NAME 179 | ... 180 | -------------------------------------------------------------------------------- /.dev/.clang_format.hook: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | readonly VERSION="14.0.0" 5 | 6 | version=$(clang-format -version) 7 | 8 | if ! [[ $version == *"$VERSION"* ]]; then 9 | echo "clang-format version check failed." 10 | echo "a version containing '$VERSION' is needed, but got '$version'" 11 | echo "you can install the right version, and make a soft link to it on your '$PATH'" 12 | exit -1 13 | fi 14 | 15 | clang-format -style=google $@ 16 | -------------------------------------------------------------------------------- /.dev/.cpplint_pre_commit.hook: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #TOTAL_ERRORS=0 4 | #echo "HAHAHAHAHHA" 5 | #exit 5 6 | # 7 | #files=$( 8 | # 9 | #if [[ ! $TRAVIS_BRANCH ]]; then 10 | # # install cpplint on local machine. 11 | # if [[ ! $(which cpplint) ]]; then 12 | # pip install cpplint 13 | # fi 14 | # # diff files on local machine. 15 | # files=$(git diff --cached --name-status | awk '{print $2}') 16 | #else 17 | # # diff files between PR and latest commit on Travis CI. 18 | # branch_ref=$(git rev-parse "$TRAVIS_BRANCH") 19 | # head_ref=$(git rev-parse HEAD) 20 | # files=$(git diff --name-status $branch_ref $head_ref | awk '{print $2}') 21 | #fi 22 | ## The trick to remove deleted files: https://stackoverflow.com/a/2413151 23 | #for file in $files; do 24 | # echo $file 25 | # if [[ $file =~ ^(patches/.*) ]]; then 26 | # continue; 27 | # else 28 | # cpplint --filter=-readability/fn_size $file; 29 | # TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?); 30 | # fi 31 | #done 32 | # 33 | #exit $TOTAL_ERRORS 34 | 35 | if git rev-parse --verify HEAD >/dev/null 2>&1 36 | then 37 | against=HEAD 38 | else 39 | # Initial commit: diff against an empty tree object 40 | against=4b825dc642cb6eb9a060e54bf8d69288fbee4904 41 | fi 42 | 43 | # Redirect output to stderr. 44 | exec 1>&2 45 | 46 | cpplint=cpplint 47 | sum=0 48 | filters='-build/include_order,-build/namespaces,-legal/copyright,-runtime/references,-build/include_what_you_use' 49 | 50 | # for cpp 51 | for file in $(git diff-index --name-status $against -- | grep -E '\.[ch](pp)?$' | awk '{print $2}'); do 52 | $cpplint --filter=$filters $file 53 | sum=$(expr ${sum} + $?)
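  # NOTE (editor): cpplint exits non-zero when a file has style issues; accumulating its exit status into ${sum} makes this hook fail (exit 1 below) if any checked file fails.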
54 | done 55 | 56 | if [ ${sum} -eq 0 ]; then 57 | exit 0 58 | else 59 | exit 1 60 | fi 61 | -------------------------------------------------------------------------------- /.dev/.pre-commit-config-cpp.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: ed714747d7acbc5790b171702bb012af3b0fe145 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-symlinks 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | - id: detect-private-key 10 | - id: check-symlinks 11 | - id: check-added-large-files 12 | 13 | - repo: local 14 | hooks: 15 | - id: copyright_checker 16 | name: copyright_checker 17 | entry: python ./.copyright.hook 18 | language: system 19 | files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ 20 | exclude: (?!.*third_party)^.*$ 21 | 22 | - repo: local 23 | hooks: 24 | - id: clang-format-with-version-check 25 | name: clang-format 26 | description: Format files with ClangFormat. 27 | entry: bash .clang_format.hook -i 28 | language: system 29 | files: \.(c|cc|cxx|cpp|cu|hxx|proto)$ 30 | 31 | - repo: local 32 | hooks: 33 | - id: cpplint-cpp-source 34 | name: cpplint 35 | description: Check C++ code style using cpplint.py. 36 | entry: bash .cpplint_pre_commit.hook 37 | language: system 38 | files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$ 39 | -------------------------------------------------------------------------------- /.dev/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-docstring-first 6 | - id: check-toml 7 | - id: check-yaml 8 | exclude: packaging/.* 9 | args: 10 | - --allow-multiple-documents 11 | - id: mixed-line-ending 12 | args: [--fix=lf] 13 | - id: end-of-file-fixer 14 | 15 | - repo: https://github.com/omnilib/ufmt 16 | rev: v1.3.3 17 | hooks: 18 | - id: ufmt 19 | additional_dependencies: 20 | - black == 22.3.0 21 | - usort == 1.0.2 22 | 23 | - repo: https://github.com/PyCQA/flake8 24 | rev: 7.1.1 25 | hooks: 26 | - id: flake8 27 | args: [--config=setup.cfg] 28 | exclude: generator.py 29 | 30 | - repo: https://github.com/PyCQA/pydocstyle 31 | rev: 6.1.1 32 | hooks: 33 | - id: pydocstyle 34 | -------------------------------------------------------------------------------- /.dev/clear.sh: -------------------------------------------------------------------------------- 1 | rm -rf build dist *.egg-info __pycache__ 2 | rm -rf $(find . 
-name __pycache__) 3 | -------------------------------------------------------------------------------- /.dev/commit-prepare.sh: -------------------------------------------------------------------------------- 1 | path=$(cd `dirname $0`; pwd) 2 | cd $path 3 | 4 | # cpp & python format lint 5 | # sudo apt-get update 6 | # sudo apt-get install clang-format -y 7 | pip install pre-commit 8 | pip install yapf 9 | pip install cpplint 10 | pre-commit install -c ./.dev/.pre-commit-config.yaml # only lint for python 11 | # pre-commit install -c ./.dev/.pre-commit-config-cpp.yaml # both python + cpp 12 | -------------------------------------------------------------------------------- /.dev/init_dev.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | export ENABLE_FFPA_ALL_STAGES=1 3 | export ENABLE_FFPA_ALL_HEADDIM=0 4 | export ENABLE_FFPA_AMPERE=0 5 | export ENABLE_FFPA_HOPPER=0 6 | export ENABLE_FFPA_DEBUG=1 7 | set +x 8 | -------------------------------------------------------------------------------- /.dev/init_prod.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | export ENABLE_FFPA_ALL_STAGES=1 3 | export ENABLE_FFPA_ALL_HEADDIM=1 4 | export ENABLE_FFPA_AMPERE=1 5 | export ENABLE_FFPA_HOPPER=1 6 | export ENABLE_FFPA_DEBUG=0 7 | set +x 8 | -------------------------------------------------------------------------------- /.dev/init_prod_mini.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | export ENABLE_FFPA_ALL_STAGES=1 3 | export ENABLE_FFPA_ALL_HEADDIM=0 4 | export ENABLE_FFPA_AMPERE=0 5 | export ENABLE_FFPA_HOPPER=0 6 | export ENABLE_FFPA_DEBUG=0 7 | set +x 8 | -------------------------------------------------------------------------------- /.dev/install.sh: -------------------------------------------------------------------------------- 1 | rm -rf $(find . -name __pycache__) 2 | python3 setup.py bdist_wheel && cd dist 3 | python3 -m pip install *.whl 4 | cd .. && rm -rf build *.egg-info 5 | rm -rf $(find . -name __pycache__) 6 | -------------------------------------------------------------------------------- /.dev/uninstall.sh: -------------------------------------------------------------------------------- 1 | python3 -m pip uninstall ffpa-attn -y 2 | -------------------------------------------------------------------------------- /.github/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info 22 | dist 23 | .tmp 24 | bin 25 | .cache 26 | -------------------------------------------------------------------------------- /.github/workflows/issue.yml: -------------------------------------------------------------------------------- 1 | name: issues 2 | on: 3 | schedule: 4 | - cron: "0 0 * * 0" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v9.0.0 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 7 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 
19 | close-issue-message: "This issue was closed because it has been inactive for 7 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info 22 | dist 23 | .tmp 24 | bin 25 | .cache 26 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "cutlass"] 2 | path = cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include MANIFEST.in 2 | include LICENSE 3 | include requirements.txt 4 | recursive-include tests * 5 | global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp .gitignore *__pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 |

🤖FFPA(Split-D): Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for large headdim🐑

4 | 📚FFPA(Split-D) Blog | 📈L20 ~1.9x↑🎉 | 📈A30 ~1.8x↑🎉 | 📈3080 ~2.9x↑🎉 | 📈4090 ~2.1x↑🎉

5 |

6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 |
16 | 17 |
18 |

🤖FFPA: 1.8x~3x🎉faster vs SDPA EA with or without MMA Acc F32

19 |
20 | 21 | 🤖[WIP] **FFPA**: Yet another **Faster Flash Prefill Attention** with **O(1) SRAM complexity** & **O(d/4) or O(1) register complexity** for large headdim (D > 256), almost **1.8x~3x** 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: [📈L20 ~1.9x↑🎉](#L1-bench-l20), [📈A30 ~1.8x↑🎉](#L1-bench-a30), [📈3080 ~2.9x↑🎉](#L1-bench-3080), [📈4090 ~2.1x↑🎉](#L1-bench-4090). **FFPA Attention Algo: Fine-grained tiling** for large headim, **FA-2 Attention Algo: Coarse-grained tiling** for small headidm. 22 | 23 | 27 | image 28 | 29 | 30 | 💡NOTE: This project is still in its early dev stages and now provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻star this repo to support me ~) 31 | 32 | ## ©️Citations🎉🎉 33 | 34 | ```BibTeX 35 | @misc{ffpa-attn@2025, 36 | title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.}, 37 | url={https://github.com/xlite-dev/ffpa-attn.git}, 38 | note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git}, 39 | author={xlite-dev etc}, 40 | year={2025} 41 | } 42 | ``` 43 | 44 | ## 📖 Contents 45 | 46 | - [📖 Installation⚙️](#install) 47 | - [📖 Python Testing👇](#python-test) 48 | - [📖 FFPA L1~L3 Design💡](#ffpa-design) 49 | - [📈 FFPA L1: L20 ~1.9x↑🎉](#L1-bench-l20) 50 | - [📈 FFPA L1: A30 ~1.8x↑🎉](#L1-bench-a30) 51 | - [📈 FFPA L1: 3080 ~2.9x↑🎉](#L1-bench-3080) 52 | - [📈 FFPA L1: 4090 ~2.1x↑🎉](#L1-bench-4090) 53 | - [📖 Fully Fused MLA w/ FFPA🎉](#fused-mla) 54 | 55 | ## 📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡 56 |
57 | 58 | We have extended FlashAttention for large headdim (D > 256) by implementing **Fine-grained Tiling** at the **MMA level (GEMM style)** for the Q@K^T and P@V matmul. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance compared to SDPA with or without MMA Accumulation F32 (**1.8x~3x** 🎉 faster than SDPA EA). 59 | 60 | We have named this new attention tiling technique **FFPA: Faster Flash Prefill Attention**. We have designed three `(L1~L3)` levels of FFPA based on SRAM and register complexity considerations. All levels will not introduce any additional VRAM requirements, ensuring that the HBM memory complexity remains same as FlashAttention. 👇 61 | 62 | - [x] 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity. 63 | - [ ] 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation. 64 | - [ ] 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading. 65 | 66 | By leveraging this approach, we can achieve better performance than SDPA EA for very large headdim (D > 256, `FA-2 not supported`). Approximate SRAM and register complexity analysis for FFPA L1~L3 level is as follows: (`d`=headdim, `C,Br,Bc`=Constant, `Br=Bc`, let O(C)≈O(1)) 👇 67 | 68 | |📚Complexity| 📚FFPA L1 | 📚FFPA L2 | 📚FFPA L3 | 📚FA-2 | 69 | |:---:|:---:|:---:|:---:|:---:| 70 | |SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ | 71 | |Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)| ≈O(d/2), d↑ | 72 | |HBM| ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O| ≈FA2≈O(Nd), O | ≈O(Nd), O | 73 | |Extra HBM| ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l | 74 | 75 | **📚👇Core Features🎉🎉**: I have implemented **FFPA L1~L3** using pure MMA PTX instructions, which supports many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, **Fully QKV Fine-grained Tiling(GEMM style)**, Collective Store, etc. 
76 | 77 | |📚Feature |📚Feature |📚Feature |📚Feature| 78 | |:---:|:---:|:---:|:---:| 79 | |✔️Tensor Cores |✔️**MMA(m16n8k16)** |✔️Tile Block(Br, Bc) |✔️Tile MMA/Warp | 80 | |✔️**Split Q**(FA-2)|✔️Pack LDST(128 bits)|✔️SMEM **Swizzle/Pad** |✔️Copy Async | 81 | |✔️**Reg Double Buffers** |✔️QKV **Multi-Stages(1~4)** |✔️Collective Store(**Shfl**)|✔️**Prefetch QKV** g2s | 82 | |✔️**QKV Fine-grained Tiling**|✔️**Shared QKV** SMEM|✔️Mixed MMA Acc|✔️**Persist Q** s2r/g2s| 83 | 84 | - 📚 case: FFPA `L1` kernel template signature: [ffpa_attn_templates_L1.cuh](csrc/cuffpa/ffpa_attn_templates_L1.cuh) 85 | 86 | ```CUDA 87 | template< 88 | const int kHeadDim, // Headdim, 32~1024 89 | const int kMmaAtomM, // MMA Atom M, 16 90 | const int kMmaAtomN, // MMA Atom N, 8 91 | const int kMmaAtomK, // MMA Atom K, 16 92 | const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] 93 | const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] 94 | const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] 95 | const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] 96 | const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M 97 | const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N 98 | const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M 99 | const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... 100 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 101 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 102 | const int kOStorageAccFloat32, // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or half. 103 | const int kPrefetchQK, // Prefetch QK at the appropriate time point. 104 | const int kPrefetchPV, // Prefetch V at the appropriate time point. 105 | const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. 106 | const int kPersistQs2r, // Persist load Q s2r for headdim < 512, more registers, but still keep O(1) SRAM. 107 | const int kPersistQg2s, // Persist load Q g2s for headdim <= 320, more SRAM, but still keep register usage. 108 | const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. 109 | const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4) 110 | const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4) 111 | const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding 112 | const int kPadK, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding 113 | const int kPadV // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding 114 | > __global__ void // Q, K, V, O -> [B, H, N, D] 115 | // FFPA Attention Algo: Fine-grained tiling at MMA level for large headdim (d>=256), 116 | // which can be 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32. 117 | ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...); 118 | // FA-2 Attention Algo: Coarse-grained tiling at Attention level for small headdim (d<256), 119 | // which can achieve 95%-105%🎉 of SDPA FA-2 BE performance with MMA Acc F32 for N<=4096, 120 | // and can be almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc (Q@K^T F32 + 121 | // P@V F16) across the full range of N.
122 | ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...); 123 | ``` 124 | 125 | ## 📖 Prerequisites 126 |
127 | 128 | - Python >= 3.10 129 | - PyTorch >= 2.4.0, CUDA >= 12.4 130 | - flash-attention >= 2.6.3 (for test) 131 | - Recommended: PyTorch 2.5.1, CUDA 12.5 132 | - Docker: nvcr.io/nvidia/pytorch:24.10-py3 133 | 134 | ## 📖 Installation 135 | 136 |
137 | 138 | The FFPA implemented in this repo can be install as a python library, namely, `ffpa-attn` library (optional). 139 | ```bash 140 | git clone https://github.com/xlite-dev/ffpa-attn.git 141 | # clone, then, run bash .dev/install.sh directly or run commands: 142 | python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y 143 | ``` 144 | 145 | ## 📖 FFPA L1 (Level 1): Benchmark 🎉🎉 146 | 147 |
148 | 149 | L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, **D=320-1024(FA2 not supported 👀)**. (Notes, `*`=MMA Acc F32, `^`=MMA Acc F16, Softmax Acc dtype is always be F32, T=TFLOPS, 👇Benchmark) 150 | 151 | - 📚 NVIDIA L20 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**) 152 | 153 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 154 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 155 | |SDPA EA|56T|63T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T| 156 | |FFPA L1*|102T|102T|103T|104T|103T|95T|95T|95T|95T|96T|95T|94T| 157 | |Speedup|1.82x|1.62x|1.78x|1.79x|1.87x|1.7x|1.76x|1.73x|1.76x|1.75x|1.76x|1.68x| 158 | |FFPA L1^|104T|103T|103T|102T|104T|103T|102T|94T|94T|94T|100T|100T| 159 | |Speedup|1.86x|1.63x|1.78x|1.76x|1.89x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x| 160 | 161 | - 📚 NVIDIA L20 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~1.9x↑🎉**) 162 | 163 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 164 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 165 | |SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T| 166 | |FFPA L1*|105T|102T|104T|103T|105T|95T|95T|94T|94T|94T|102T|101T| 167 | |Speedup|1.88x|1.59x|1.79x|1.78x|1.91x|1.7x|1.76x|1.71x|1.74x|1.71x|1.89x|1.8x| 168 | |FFPA L1^|104T|103T|103T|102T|103T|103T|102T|94T|94T|94T|100T|100T| 169 | |Speedup|1.86x|1.61x|1.78x|1.76x|1.87x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x| 170 | 171 |
172 | 173 | 174 |
175 | 176 |
177 | 178 | - 📚 NVIDIA A30 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**) 179 | 180 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 181 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 182 | |SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T| 183 | |FFPA L1*|45T|44T|44T|43T|43T|38T|37T|37T|37T|36T|33T|32T| 184 | |Speedup|1.8x|1.76x|1.83x|1.79x|1.79x|1.58x|1.61x|1.68x|1.68x|1.64x|1.5x|1.78x| 185 | |FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|40T|34T| 186 | |Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.82x|1.89x| 187 | 188 | - 📚 NVIDIA A30 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~1.9x↑🎉**) 189 | 190 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 191 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 192 | |SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T| 193 | |FFPA L1*|48T|46T|46T|43T|44T|38T|38T|38T|37T|36T|40T|34T| 194 | |Speedup|1.92x|1.84x|1.92x|1.79x|1.83x|1.58x|1.65x|1.73x|1.68x|1.64x|1.82x|1.89x| 195 | |FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|39T|34T| 196 | |Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.77x|1.89x| 197 | 198 |
199 | 200 | 201 |
202 | 203 |
204 | 205 | - 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~2.5x↑🎉**) 206 | 207 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 208 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 209 | |SDPA EA|13T|16T|11T|16T|15T|15T|15T|15T|14T|14T|14T|14T| 210 | |FFPA L1*|33T|31T|30T|30T|30T|27T|27T|26T|26T|26T|26T|25T| 211 | |Speedup|2.54x|1.94x|2.73x|1.88x|2.0x|1.8x|1.8x|1.73x|1.86x|1.86x|1.86x|1.79x| 212 | |FFPA L1^|43T|41T|39T|39T|39T|39T|39T|36T|34T|33T|31T|33T| 213 | |Speedup|3.31x|2.56x|3.55x|2.44x|2.6x|2.6x|2.6x|2.4x|2.43x|2.36x|2.21x|2.36x| 214 | 215 | - 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~2.9x↑🎉**) 216 | 217 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 218 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 219 | |SDPA EA|13T|15T|12T|15T|14T|15T|14T|14T|14T|14T|14T|14T| 220 | |FFPA L1*|38T|36T|34T|35T|34T|31T|32T|31T|30T|28T|27T|27T| 221 | |Speedup|2.92x|2.4x|2.83x|2.33x|2.43x|2.07x|2.29x|2.21x|2.14x|2.0x|1.93x|1.93x| 222 | |FFPA L1^|44T|41T|39T|39T|38T|39T|39T|36T|34T|32T|31T|33T| 223 | |Speedup|3.38x|2.73x|3.25x|2.6x|2.71x|2.6x|2.79x|2.57x|2.43x|2.29x|2.21x|2.36x| 224 | 225 |
226 | 227 | 228 |
229 | 230 |
231 | 232 | - 📚 NVIDIA RTX 4090 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**) 233 | 234 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 235 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 236 | |SDPA EA|81T|94T|85T|85T|79T|81T|79T|80T|79T|80T|78T|78T| 237 | |FFPA L1*|149T|150T|150T|150T|150T|140T|140T|140T|139T|139T|137T|134T| 238 | |Speedup|1.84x|1.6x|1.76x|1.76x|1.9x|1.73x|1.77x|1.75x|1.76x|1.74x|1.76x|1.72x| 239 | |FFPA L1^|194T|194T|189T|191T|197T|188T|184T|180T|177T|172T|171T|171T| 240 | |Speedup|2.4x|2.06x|2.22x|2.25x|2.49x|2.32x|2.33x|2.25x|2.24x|2.15x|2.19x|2.19x| 241 | 242 | - 📚 NVIDIA RTX 4090 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~2.1x↑🎉**) 243 | 244 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024| 245 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 246 | |SDPA EA|82T|92T|85T|84T|78T|81T|79T|80T|78T|79T|77T|78T| 247 | |FFPA L1*|176T|170T|171T|171T|171T|161T|160T|161T|160T|158T|165T|164T| 248 | |Speedup|2.15x|1.85x|2.01x|2.04x|2.19x|1.99x|2.03x|2.01x|2.05x|2.0x|2.14x|2.1x| 249 | |FFPA L1^|200T|191T|189T|191T|188T|188T|186T|179T|175T|173T|172T|170T| 250 | |Speedup|2.44x|2.08x|2.22x|2.27x|2.41x|2.32x|2.35x|2.24x|2.24x|2.19x|2.23x|2.18x| 251 | 252 |
253 | 254 | 255 |
256 | 257 | ## 📖 Python Testing 258 |
259 | 260 | 👇You can test many custom FFPA kernels via Python and figure out the difference in their performance. The `--gen-bench` and `--plot` options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR 🎉🎉. 261 | 262 | - 📚 case: B=1, H=48, N=8192, D=320(`FA2 not supported`) 263 | ```python 264 | # You can test on many devices, such as Volta, Ampere, Ada, Hopper, ... 265 | cd tests && python3 test_ffpa_attn.py --B 1 --H 48 --N 8192 --show-all --D 320 266 | ---------------------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5-------------------- 267 | (sdpa): ['-0.02380371'], time:73.66518ms, TFLOPS:56.19 (+0.00 %)(~1.00x) 268 | (ffpa+acc+f32+L1+stage1): ['-0.02378845'], time:52.87361ms, TFLOPS:78.28 (+39.32%)(~1.39x) 269 | (ffpa+acc+f32+L1+stage2): ['-0.02378845'], time:40.84062ms, TFLOPS:101.35(+29.46%)(~1.80x) 270 | (ffpa+acc+f32+L1+stage3): ['-0.02378845'], time:40.49534ms, TFLOPS:102.21(+0.85 %)(~1.82x) 271 | (ffpa+acc+f32+L1+stage4): ['-0.02378845'], time:40.88177ms, TFLOPS:101.25(+0.00 %)(~1.80x) 272 | (ffpa+acc+f16+L1+stage1): ['-0.02378845'], time:53.43298ms, TFLOPS:77.46 (+0.00 %)(~1.38x) 273 | (ffpa+acc+f16+L1+stage2): ['-0.02378845'], time:39.76068ms, TFLOPS:104.10(+1.85 %)(~1.85x) 274 | (ffpa+acc+f16+L1+stage3): ['-0.02378845'], time:39.54901ms, TFLOPS:104.66(+0.54 %)(~1.86x) 275 | (ffpa+acc+f16+L1+stage4): ['-0.02378845'], time:41.06554ms, TFLOPS:100.79(+0.00 %)(~1.79x) 276 | -------------------------------------------------------------------------------------------------------- 277 | ``` 278 | - 📚 case: Generate benchmark table and speedup bar plots on Your device. 279 | ```bash 280 | cd tests && pip install matplotlib && python3 test_ffpa_attn.py --gen-bench --show-all --plot 281 | ``` 282 | - 📚 case: Compare small headdim (d<256, e.g 64), FFPA-L1 vs SDPA FA-2 BE. 283 | ```python 284 | # Enable ffpa-attn small d kernel which using coarse-grained tiling method. 
285 | export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1 286 | cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20 287 | ---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5-------------------- 288 | (sdpa): ['0.00802612'], time:0.148057ms, TFLOPS:59.14 (+0.00 %)(~1.00x) 289 | (ffpa+acc+f32+L1+stage1): ['0.00803375'], time:0.103807ms, TFLOPS:84.34 (+42.63%)(~1.43x) 290 | (ffpa+acc+f32+L1+stage2): ['0.00803375'], time:0.102233ms, TFLOPS:85.64 (+1.54 %)(~1.45x) 291 | (ffpa+acc+f32+L1+stage3): ['0.00803375'], time:0.102519ms, TFLOPS:85.40 (+0.00 %)(~1.44x) 292 | (ffpa+acc+f32+L1+stage4): ['0.00803375'], time:0.102043ms, TFLOPS:85.80 (+0.19 %)(~1.45x) 293 | (ffpa+acc+f16+L1+stage1): ['0.00795746'], time:0.104713ms, TFLOPS:83.61 (+0.00 %)(~1.41x) 294 | (ffpa+acc+f16+L1+stage2): ['0.00795746'], time:0.102949ms, TFLOPS:85.05 (+0.00 %)(~1.44x) 295 | (ffpa+acc+f16+L1+stage3): ['0.00795746'], time:0.108957ms, TFLOPS:80.36 (+0.00 %)(~1.36x) 296 | (ffpa+acc+f16+L1+stage4): ['0.00795746'], time:0.103282ms, TFLOPS:84.77 (+0.00 %)(~1.43x) 297 | -------------------------------------------------------------------------------------------------------- 298 | cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 4096 --check --show-all --D 64 # NVIDIA L20 299 | -------------------------B=1, H=32, N=4096, D=64, Warmup: 1, Iters: 5----------------------------------- 300 | (sdpa): ['0.01959229'], time:1.397752ms, TFLOPS:100.24(+0.00 %)(~1.00x) 301 | (ffpa+acc+f32+L1+stage1): ['0.01959229'], time:1.368856ms, TFLOPS:102.36(+2.11 %)(~1.02x) 302 | (ffpa+acc+f32+L1+stage2): ['0.01959229'], time:1.367807ms, TFLOPS:102.44(+0.08 %)(~1.02x) 303 | (ffpa+acc+f32+L1+stage3): ['0.01959229'], time:1.367855ms, TFLOPS:102.43(+0.00 %)(~1.02x) 304 | (ffpa+acc+f32+L1+stage4): ['0.01959229'], time:1.368045ms, TFLOPS:102.42(+0.00 %)(~1.02x) 305 | (ffpa+acc+f16+L1+stage1): ['0.01957703'], time:1.389312ms, TFLOPS:100.85(+0.00 %)(~1.01x) 306 | (ffpa+acc+f16+L1+stage2): ['0.01957703'], time:1.388311ms, TFLOPS:100.92(+0.00 %)(~1.01x) 307 | (ffpa+acc+f16+L1+stage3): ['0.01957703'], time:1.386976ms, TFLOPS:101.02(+0.00 %)(~1.01x) 308 | (ffpa+acc+f16+L1+stage4): ['0.01957703'], time:1.387834ms, TFLOPS:100.96(+0.00 %)(~1.01x) 309 | -------------------------------------------------------------------------------------------------------- 310 | ``` 311 | 312 | 💡NOTE: Please check all configurable environment variables in [env.py](./env.py). 313 | 314 | ## 📖 Fully Fused MLA with FFPA 🎉 315 | 316 |
317 | 318 | Extending the support of FA for large headdim is meaningful in the context of **DeepSeek MLA**. For example, when FA supports headdim values greater than 512, we can achieve fully Fused MLA into a single CUDA kernel, after W_UK/W_UV are absorbed into W_Q/W_O (resulting in C_kv/C_q with `dc/dc' >= 512`). TODO list👇: 319 | 320 | - [ ] 📚Fully Fused MLA into a single CUDA kernel using **FFPA** Algo and Tensor Cores. 321 | 322 | ## ©️License 323 | 324 |
325 | 326 | GNU General Public License v3.0 327 | 328 | ## 🎉Contribute 329 | 330 |
331 | 332 | How to contribute? Wecome to star⭐️ this repo to support me👆🏻 ~ 333 | 334 |
335 | 336 | 337 | 338 | 339 | Star History Chart 340 | 341 | 342 |
343 | 344 | ## 📖 References 345 |
346 | 347 | - [flash-attention](https://github.com/Dao-AILab/flash-attention) 348 | - [LeetCUDA](https://github.com/xlite-dev/LeetCUDA) 349 | - [flashinfer](https://github.com/flashinfer-ai/flashinfer) 350 | -------------------------------------------------------------------------------- /bench/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info 22 | dist 23 | -------------------------------------------------------------------------------- /bench/NVIDIA_A30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_A30.png -------------------------------------------------------------------------------- /bench/NVIDIA_A30_ffpa+acc+f16+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_A30_ffpa+acc+f16+L1_Speedup.png -------------------------------------------------------------------------------- /bench/NVIDIA_A30_ffpa+acc+f32+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_A30_ffpa+acc+f32+L1_Speedup.png -------------------------------------------------------------------------------- /bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png -------------------------------------------------------------------------------- /bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f16+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f16+L1_Speedup.png -------------------------------------------------------------------------------- /bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f32+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f32+L1_Speedup.png -------------------------------------------------------------------------------- /bench/NVIDIA_GeForce_RTX_4090.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_4090.png -------------------------------------------------------------------------------- /bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f16+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f16+L1_Speedup.png 
-------------------------------------------------------------------------------- /bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f32+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f32+L1_Speedup.png -------------------------------------------------------------------------------- /bench/NVIDIA_L20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_L20.png -------------------------------------------------------------------------------- /bench/NVIDIA_L20_ffpa+acc+f16+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_L20_ffpa+acc+f16+L1_Speedup.png -------------------------------------------------------------------------------- /bench/NVIDIA_L20_ffpa+acc+f32+L1_Speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_L20_ffpa+acc+f32+L1_Speedup.png -------------------------------------------------------------------------------- /bench/bank_conflicts_check.sh: -------------------------------------------------------------------------------- 1 | cd tests 2 | ncu \ 3 | --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \ 4 | python3 test.py --B 1 --H 8 --N 1024 --D 320 --w 0 --i 1 --show-all 5 | cd .. 6 | -------------------------------------------------------------------------------- /bench/bench.sh: -------------------------------------------------------------------------------- 1 | # benchmarks 2 | cd tests 3 | python3 test.py --gen-bench --show-all --plot 4 | cd .. 5 | -------------------------------------------------------------------------------- /csrc/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info -------------------------------------------------------------------------------- /csrc/cuffpa/ffpa_attn_F16F16F16_L1.cu: -------------------------------------------------------------------------------- 1 | #include "launch_templates.cuh" 2 | using namespace ffpa; 3 | 4 | 5 | void ffpa_mma_acc_f16_L1(torch::Tensor Q, 6 | torch::Tensor K, 7 | torch::Tensor V, 8 | torch::Tensor O, 9 | int stages) { 10 | CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] 11 | CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] 12 | CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] 13 | CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] 14 | const int d = Q.size(3); // B, H, N, d 15 | // Q@K^T or P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 
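// NOTE (editor): this translation unit pins both MMA accumulators to fp16, i.e. the `^` variants
// in the README benchmark tables, trading ~3-4 decimal digits of precision (see the
// kOStorageAccFloat32 comment in launch_templates.cuh) for speed; the f32/mixed variants
// live in ffpa_attn_F16F16F32_L1.cu.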
16 | constexpr int kMmaAccFloat32QK = 0; 17 | constexpr int kMmaAccFloat32PV = 0; 18 | 19 | #ifdef ENABLE_FFPA_ALL_STAGES 20 | // dispatch stages 21 | if (stages == 2) { 22 | DISPATCH_HEADDIM(LAUNCHER_L1, 2); 23 | } else if (stages == 3) { 24 | DISPATCH_HEADDIM(LAUNCHER_L1, 3); 25 | } else if (stages == 4) { 26 | DISPATCH_HEADDIM(LAUNCHER_L1, 4); 27 | } else { 28 | DISPATCH_HEADDIM(LAUNCHER_L1, 1); 29 | } 30 | #else 31 | // dispatch stages 32 | if (stages == 2) { 33 | DISPATCH_HEADDIM(LAUNCHER_L1, 2); 34 | } else { 35 | DISPATCH_HEADDIM(LAUNCHER_L1, 1); 36 | } 37 | #endif 38 | } 39 | -------------------------------------------------------------------------------- /csrc/cuffpa/ffpa_attn_F16F16F32_L1.cu: -------------------------------------------------------------------------------- 1 | #include "launch_templates.cuh" 2 | using namespace ffpa; 3 | 4 | 5 | void ffpa_mma_acc_f32_L1(torch::Tensor Q, 6 | torch::Tensor K, 7 | torch::Tensor V, 8 | torch::Tensor O, 9 | int stages) { 10 | CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] 11 | CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] 12 | CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] 13 | CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] 14 | const int d = Q.size(3); // B, H, N, d 15 | // Q@K^T or P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 16 | #ifdef ENABLE_FFPA_FORCE_QK_F16 17 | constexpr int kMmaAccFloat32QK = 0; 18 | #else 19 | constexpr int kMmaAccFloat32QK = 1; 20 | #endif 21 | #ifdef ENABLE_FFPA_FORCE_PV_F16 22 | constexpr int kMmaAccFloat32PV = 0; 23 | #else 24 | constexpr int kMmaAccFloat32PV = 1; 25 | #endif 26 | 27 | #ifdef ENABLE_FFPA_ALL_STAGES 28 | // dispatch stages 29 | if (stages == 2) { 30 | DISPATCH_HEADDIM(LAUNCHER_L1, 2); 31 | } else if (stages == 3) { 32 | DISPATCH_HEADDIM(LAUNCHER_L1, 3); 33 | } else if (stages == 4) { 34 | DISPATCH_HEADDIM(LAUNCHER_L1, 4); 35 | } else { 36 | DISPATCH_HEADDIM(LAUNCHER_L1, 1); 37 | } 38 | #else 39 | // dispatch stages 40 | if (stages == 2) { 41 | DISPATCH_HEADDIM(LAUNCHER_L1, 2); 42 | } else { 43 | DISPATCH_HEADDIM(LAUNCHER_L1, 1); 44 | } 45 | #endif 46 | } 47 | -------------------------------------------------------------------------------- /csrc/cuffpa/launch_templates.cuh: -------------------------------------------------------------------------------- 1 | #include "ffpa_attn_templates_L1.cuh" 2 | using namespace ffpa; 3 | 4 | static constexpr int kMaxDForSmallDKernel = 64; 5 | static constexpr int kMaxDForOStoreFloat32 = 64; 6 | static constexpr int kMaxDForSmallBlockTile = 256; 7 | 8 | template<const int kHeadDim> 9 | static constexpr int getConfigMmaTileSeqLenQP() { 10 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S 11 | #if defined(BUILD_FFPA_ATTN_MMA_L20) 12 | constexpr int kMmaTileSeqLenQP = ( 13 | kHeadDim <= kMaxDForSmallBlockTile) ? 4: 8; 14 | #else 15 | constexpr int kMmaTileSeqLenQP = ( 16 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 8; 17 | #endif 18 | #else // if undef ENABLE_FFPA_PERSIST_KV_G2S 19 | // O(1) SRAM complexity, may always use large tile for 20 | // ffpa large d kernel. TODO: tune block size for L20/4090/3080 etc. 21 | #if defined(BUILD_FFPA_ATTN_MMA_L20) 22 | constexpr int kMmaTileSeqLenQP = ( 23 | kHeadDim <= kMaxDForSmallBlockTile) ? 4: 8; 24 | #else 25 | constexpr int kMmaTileSeqLenQP = ( 26 | kHeadDim <= kMaxDForSmallBlockTile) ?
8: 8; 27 | #endif 28 | #endif 29 | return kMmaTileSeqLenQP; 30 | } 31 | 32 | template<const int kHeadDim> 33 | static constexpr int getConfigWarpTileSeqLenK() { 34 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S 35 | #if defined(BUILD_FFPA_ATTN_MMA_L20) 36 | constexpr int kWarpTileSeqLenK = ( 37 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 16; 38 | #else 39 | constexpr int kWarpTileSeqLenK = ( 40 | kHeadDim <= kMaxDForSmallBlockTile) ? 16: 16; 41 | #endif 42 | #else // if undef ENABLE_FFPA_PERSIST_KV_G2S 43 | #if defined(BUILD_FFPA_ATTN_MMA_L20) 44 | constexpr int kWarpTileSeqLenK = ( 45 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 16; 46 | #else 47 | constexpr int kWarpTileSeqLenK = ( 48 | kHeadDim <= kMaxDForSmallBlockTile) ? 16: 16; 49 | #endif 50 | #endif 51 | return kWarpTileSeqLenK; 52 | } 53 | 54 | template<const int kHeadDim> 55 | static constexpr int getConfigWarpTileHeadDimV() { 56 | constexpr int kMmaAtomN = 8; 57 | constexpr int kMmaTileHeadDimV = 1; 58 | constexpr int kWarpTileHeadDimV = ( 59 | kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); 60 | return kWarpTileHeadDimV; 61 | } 62 | 63 | static constexpr int getConfigShareSmemQKV() { 64 | #if defined(ENABLE_FFPA_QKV_SMEM_SHARE) 65 | constexpr int kShareSmemQKV = 1; 66 | #else 67 | constexpr int kShareSmemQKV = 0; 68 | #endif 69 | return kShareSmemQKV; 70 | } 71 | 72 | template<const int kHeadDim> 73 | static constexpr int getConfigOStorageAccFloat32() { 74 | // 0/1, The precision of the O storage buffer can differ from 75 | // that of the MMA, supporting either FP32 or Half precision. 76 | // FP16 can provide precision to approximately 3-4 decimal places. 77 | // Thus, if the error does not exceed 1e-3, using FP16 storage is 78 | // sufficient for most applications. 79 | return ((kHeadDim <= kMaxDForOStoreFloat32)) ? 1 : 0; 80 | } 81 | 82 | template<const int kStageQKV> 83 | static constexpr int getConfigPrefetchQKV() { 84 | // Prefetch QKV at the appropriate time point. 85 | #if defined(ENABLE_FFPA_PREFETCH_QKV) 86 | #if defined(ENABLE_FFPA_PERSIST_KV_G2S) 87 | constexpr int kPrefetchQKV = 1; // kStageQKV is unused 88 | #else 89 | constexpr int kPrefetchQKV = (kStageQKV > 1) ? 1 : 0; 90 | #endif 91 | #else 92 | constexpr int kPrefetchQKV = 0; 93 | #endif 94 | return kPrefetchQKV; 95 | } 96 | 97 | template<const int kHeadDim, const int kStageQK> 98 | static constexpr int getConfigPersistQg2s() { 99 | // Persist load Q g2s for headdim < 512, more SRAM, but still 100 | // keep register usage. 101 | #if defined(ENABLE_FFPA_PERSIST_Q_G2S) 102 | constexpr int kPersistQg2s = (kHeadDim < 256) ? 1 : ( 103 | (kHeadDim <= 320) ? ((kStageQK < 3) ? 1 : 0) : 0 104 | ); 105 | #else 106 | constexpr int kPersistQg2s = 0; 107 | #endif 108 | return kPersistQg2s; 109 | } 110 | 111 | static constexpr int getConfigPersistQs2r() { 112 | // Persist load Q s2r for headdim < 512, more registers, 113 | // but still keep O(1) SRAM.
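// NOTE (editor): rough per-thread register budget, cf. the R_D NOTE in
// getConfigQKVSmemMaxSize below: usage grows ~O(d/4) (d=256 -> 64 regs, d=512 -> 128 regs),
// which is why the launcher force-disables Q s2r once kHeadDim > 256.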
114 | #ifdef ENABLE_FFPA_PERSIST_Q_S2R 115 | constexpr int kPersistQs2r = 1; 116 | #else 117 | constexpr int kPersistQs2r = 0; 118 | #endif 119 | return kPersistQs2r; 120 | } 121 | 122 | static constexpr int getConfigPersistVs2r() { 123 | #ifdef ENABLE_FFPA_PERSIST_V_S2R 124 | constexpr int kPersistVs2r = 1; 125 | #else 126 | constexpr int kPersistVs2r = 0; 127 | #endif 128 | return kPersistVs2r; 129 | } 130 | 131 | static constexpr int getConfigRegistersPipeKV() { 132 | #ifdef ENABLE_FFPA_REGISTERS_PIPE_KV 133 | constexpr int kRegPipeKV = 1; 134 | #else 135 | constexpr int kRegPipeKV = 0; 136 | #endif 137 | return kRegPipeKV; 138 | } 139 | 140 | static constexpr int getConfigPadQ() { 141 | #ifdef ENABLE_FFPA_SMEM_SWIZZLE_Q 142 | constexpr int kPadQ = 0; 143 | #else 144 | constexpr int kPadQ = 8; 145 | #endif 146 | return kPadQ; 147 | } 148 | 149 | static constexpr int getConfigPadK() { 150 | #ifdef ENABLE_FFPA_SMEM_SWIZZLE_K 151 | constexpr int kPadK = 0; 152 | #else 153 | constexpr int kPadK = 8; 154 | #endif 155 | return kPadK; 156 | } 157 | 158 | static constexpr int getConfigPadV() { 159 | #ifdef ENABLE_FFPA_SMEM_SWIZZLE_V 160 | constexpr int kPadV = 0; 161 | #else 162 | constexpr int kPadV = 8; 163 | #endif 164 | return kPadV; 165 | } 166 | 167 | template<const int kNumThreads> 168 | static inline dim3 getConfigBlock() { 169 | dim3 block(kNumThreads); 170 | return block; 171 | } 172 | 173 | template<const int Br> 174 | static inline dim3 getConfigGrid( 175 | const int B, const int H, const int N) { 176 | // Tr(=N/Br), batch_size x num_heads 177 | // try grid(N/Br, B * H) or grid(N/Br, H, B) 178 | #ifdef ENABLE_FFPA_LAUNCH_GRID_DNHB 179 | dim3 grid(utils::div_ceil(N, Br), H, B); 180 | #else 181 | dim3 grid(utils::div_ceil(N, Br), B * H); 182 | #endif 183 | return grid; 184 | } 185 | 186 | template< 187 | const int Br, 188 | const int Bc, 189 | const int kMmaAtomM, 190 | const int kMmaAtomN, 191 | const int kMmaAtomK, 192 | const int kHeadDim, 193 | const int kShareSmemQKV, 194 | const int kPersistQg2s, 195 | const int kPersistQs2r, 196 | const int kStageQK, 197 | const int kStagePV, 198 | const int kPadQ, 199 | const int kPadK, 200 | const int kPadV 201 | > 202 | static constexpr int getConfigQKVSmemMaxSize() { 203 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S 204 | if constexpr (kHeadDim <= kMaxDForSmallDKernel) { // e.g > 128 will use large d kernel 205 | // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem. 206 | constexpr int Q_smem_size = ( 207 | (kHeadDim / kMmaAtomK) * (Br * (kMmaAtomK + kPadQ))); 208 | constexpr int K_smem_size = ( 209 | (kHeadDim / kMmaAtomK) * (Bc * (kMmaAtomK + kPadK))); 210 | constexpr int V_smem_size = ( 211 | (kHeadDim / (kMmaAtomN * 2)) * (Bc * (kMmaAtomN * 2 + kPadV))); 212 | constexpr int kQSmemMaxSize = Q_smem_size * 2; 213 | constexpr int kKSmemMaxSize = K_smem_size * 2; 214 | constexpr int kVSmemMaxSize = V_smem_size * 2; 215 | constexpr int kQKSmemMaxSize = ( 216 | kQSmemMaxSize > kKSmemMaxSize ? kQSmemMaxSize : kKSmemMaxSize); 217 | constexpr int kQKVSmemMaxSize = ( 218 | (kShareSmemQKV && kPersistQs2r) ? 219 | (kQKSmemMaxSize + kVSmemMaxSize) : // QK shared the same smem 220 | (kQSmemMaxSize + kKSmemMaxSize + kVSmemMaxSize) 221 | ); 222 | return kQKVSmemMaxSize; 223 | } else { 224 | // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem. 225 | constexpr int Q_smem_size = ((kPersistQg2s ?
(kHeadDim / kMmaAtomK) : kStageQK) * 226 | (Br * (kMmaAtomK + kPadQ))) * 2; 227 | constexpr int K_smem_size = ((kStageQK) * (Bc * (kMmaAtomK + kPadK))) * 2; 228 | constexpr int V_smem_size = (kStagePV * (Bc * (kMmaAtomN * 2 + kPadV))) * 2; 229 | constexpr int kQKSmemMaxSize = (Q_smem_size + K_smem_size); 230 | constexpr int kVSmemMaxSize = V_smem_size; 231 | // try to let V reuse all Q+K smem after Q@K^T, reduce smem usage. 232 | constexpr int kQKVSmemMaxSize = ( 233 | (kShareSmemQKV && (!kPersistQg2s)) ? 234 | ((kQKSmemMaxSize > kVSmemMaxSize) ? kQKSmemMaxSize: kVSmemMaxSize) : 235 | (kQKSmemMaxSize + kVSmemMaxSize) 236 | ); 237 | // NOTE: R_D registers usage, s=2, d=64, 16 regs; d=128, 32 regs; 238 | // d=256, 64 regs; d=512, 128 regs; d=1024, 256 regs; 239 | return kQKVSmemMaxSize; 240 | } 241 | #else 242 | // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem. 243 | constexpr int Q_smem_size = ((kPersistQg2s ? (kHeadDim / kMmaAtomK) : kStageQK) * 244 | (Br * (kMmaAtomK + kPadQ))) * 2; 245 | constexpr int K_smem_size = ((kStageQK) * (Bc * (kMmaAtomK + kPadK))) * 2; 246 | constexpr int V_smem_size = (kStagePV * (Bc * (kMmaAtomN * 2 + kPadV))) * 2; 247 | constexpr int kQKSmemMaxSize = (Q_smem_size + K_smem_size); 248 | constexpr int kVSmemMaxSize = V_smem_size; 249 | // try to let V reuse all Q+K smem after Q@K^T, reduce smem usage. 250 | constexpr int kQKVSmemMaxSize = ( 251 | (kShareSmemQKV && (!kPersistQg2s)) ? 252 | ((kQKSmemMaxSize > kVSmemMaxSize) ? kQKSmemMaxSize: kVSmemMaxSize) : 253 | (kQKSmemMaxSize + kVSmemMaxSize) 254 | ); 255 | // NOTE: R_D registers usage, s=2, d=64, 16 regs; d=128, 32 regs; 256 | // d=256, 64 regs; d=512, 128 regs; d=1024, 256 regs; 257 | return kQKVSmemMaxSize; 258 | #endif 259 | } 260 | 261 | template< 262 | const int kHeadDim, // Headdim, 32~1024 263 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 264 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 265 | const int kStage 266 | > 267 | void launch_ffpa_mma_L1_template(torch::Tensor Q, 268 | torch::Tensor K, 269 | torch::Tensor V, 270 | torch::Tensor O) { 271 | // Q,K,V,O with [B, H, N, D] layout, B=batch, H=head, N=seqlen, D=dim 272 | // TODO: support BNHD layout, Q,K,V,O with [B, N, H, D] layout. 273 | constexpr int kMmaAtomM = 16; 274 | constexpr int kMmaAtomN = 8; 275 | constexpr int kMmaAtomK = 16; 276 | // Split-Q(FA-2) Algo, Tile MMA across Q and keep KV access for all MMAs. 277 | constexpr int kMmaTileSeqLenQ = getConfigMmaTileSeqLenQP<kHeadDim>(); 278 | constexpr int kMmaTileSeqLenK = 1; 279 | constexpr int kMmaTileSeqLenP = getConfigMmaTileSeqLenQP<kHeadDim>(); 280 | constexpr int kMmaTileHeadDimV = 1; 281 | constexpr int kWarpTileSeqLenQ = 1; 282 | constexpr int kWarpTileSeqLenK = getConfigWarpTileSeqLenK<kHeadDim>(); 283 | constexpr int kWarpTileSeqLenP = 1; 284 | constexpr int kWarpTileHeadDimV = getConfigWarpTileHeadDimV<kHeadDim>(); 285 | constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; 286 | constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; 287 | static_assert(Br == Bc, "Br must be equal to Bc to avoid illegal memory access."); 288 | constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; 289 | constexpr int kOStorageAccFloat32 = getConfigOStorageAccFloat32<kHeadDim>(); 290 | // Apply different multi stages policy for QK and V. 291 | // TODO: tune stages for Q@K and P@V.
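// NOTE (editor): kStage selects the cp.async pipeline depth: with 2 stages, the g2s copy of
// the next K/V tile overlaps MMA compute on the current tile; 3~4 stages deepen the pipeline
// at the cost of a proportionally larger K/V SMEM footprint (see getConfigQKVSmemMaxSize).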
292 |   constexpr int kStageQK = kStage; // <= 4
293 |   constexpr int kStagePV = kStage; // <= 4
294 |   // Prefetch QKV, Persist Q g2s/s2r, Shared QKV smem.
295 |   constexpr int kShareSmemQKV = getConfigShareSmemQKV();
296 |   constexpr int kPrefetchQK = getConfigPrefetchQKV();
297 |   constexpr int kPrefetchPV = getConfigPrefetchQKV();
298 |   constexpr int kPersistQs2r = getConfigPersistQs2r();
299 |   constexpr int kPersistQg2s = getConfigPersistQg2s();
300 |   constexpr int kRegPipeKV = getConfigRegistersPipeKV();
301 |   // QKV smem swizzle: 0 -> smem swizzle, != 0 -> smem padding.
302 |   constexpr int kPadQ = getConfigPadQ();
303 |   constexpr int kPadK = getConfigPadK();
304 |   constexpr int kPadV = getConfigPadV();
305 |   // Calculate SRAM size needed per block.
306 |   constexpr int kQKVSmemMaxSize = getConfigQKVSmemMaxSize<
307 |     Br, Bc, kMmaAtomM, kMmaAtomN, kMmaAtomK, kHeadDim, kShareSmemQKV,
308 |     kPersistQg2s, kPersistQs2r, kStageQK, kStagePV, kPadQ, kPadK,
309 |     kPadV
310 |   >();
311 | 
312 |   const int QKV_batch = Q.size(0);
313 |   const int QKV_head = Q.size(1);
314 |   const int QKV_seqlen = Q.size(2); // QKV_seqlen
315 |   assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)
316 | 
317 |   const dim3 block = getConfigBlock<kNumThreads>(); // 4/8 warps per block
318 |   const dim3 grid = getConfigGrid<Br>
(QKV_batch, QKV_head, QKV_seqlen); 319 | // Precompute softmax scale and Tc 320 | const int Tc = utils::div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d] 321 | const float scale = 1.0f / sqrt((float) kHeadDim); 322 | 323 | #define LAUNCH_TEMPLATE_FUNC(TEMPLATE_FUNC) \ 324 | cudaFuncSetAttribute( \ 325 | TEMPLATE_FUNC, \ 326 | cudaFuncAttributeMaxDynamicSharedMemorySize, \ 327 | kQKVSmemMaxSize \ 328 | ); \ 329 | TEMPLATE_FUNC<<>>( \ 330 | reinterpret_cast(Q.data_ptr()), \ 331 | reinterpret_cast(K.data_ptr()), \ 332 | reinterpret_cast(V.data_ptr()), \ 333 | reinterpret_cast(O.data_ptr()), \ 334 | QKV_seqlen, \ 335 | QKV_head, \ 336 | scale, \ 337 | Tc \ 338 | ); 339 | 340 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S 341 | if constexpr (kHeadDim <= kMaxDForSmallDKernel) { // e.g > 128 will use large d kernel 342 | constexpr int kPersistVs2r = getConfigPersistVs2r(); // only for d < 256 343 | 344 | auto ffpa_mma_L1_small_d_kernel_func = ( 345 | ffpa_mma_stages_split_q_L1_small_d_template< 346 | kHeadDim, 347 | kMmaAtomM, 348 | kMmaAtomN, 349 | kMmaAtomK, 350 | kMmaTileSeqLenQ, 351 | kMmaTileSeqLenK, 352 | kMmaTileSeqLenP, 353 | kMmaTileHeadDimV, 354 | kWarpTileSeqLenQ, 355 | kWarpTileSeqLenK, 356 | kWarpTileSeqLenP, 357 | kWarpTileHeadDimV, 358 | kMmaAccFloat32QK, 359 | kMmaAccFloat32PV, 360 | kOStorageAccFloat32, 361 | kPrefetchQK, 362 | kPrefetchPV, 363 | kShareSmemQKV, 364 | kPersistQs2r, 365 | kPersistVs2r, 366 | // Force disable KV registers ping pong buffers 367 | // while V s2r is enabled. 368 | (kPersistVs2r) ? 0 : kRegPipeKV, 369 | 1, /*kStageQK unused*/ 370 | 1, /*kStagePV unused*/ 371 | kPadQ, 372 | kPadK, 373 | kPadV 374 | > 375 | ); 376 | LAUNCH_TEMPLATE_FUNC(ffpa_mma_L1_small_d_kernel_func); 377 | } else { // large headdim > kMaxDForSmallDKernel (e.g 128) 378 | auto ffpa_mma_L1_large_d_kernel_func = ( 379 | ffpa_mma_stages_split_q_L1_large_d_template< 380 | kHeadDim, 381 | kMmaAtomM, 382 | kMmaAtomN, 383 | kMmaAtomK, 384 | kMmaTileSeqLenQ, 385 | kMmaTileSeqLenK, 386 | kMmaTileSeqLenP, 387 | kMmaTileHeadDimV, 388 | kWarpTileSeqLenQ, 389 | kWarpTileSeqLenK, 390 | kWarpTileSeqLenP, 391 | kWarpTileHeadDimV, 392 | kMmaAccFloat32QK, 393 | kMmaAccFloat32PV, 394 | kOStorageAccFloat32, 395 | kPrefetchQK, 396 | kPrefetchPV, 397 | (kPersistQg2s) ? 0 : kShareSmemQKV, 398 | // Force disable Q s2r for d >= 256, Q s2r for large d will 399 | // need too many register, thus, introduce performance drops. 400 | (kPersistQg2s || kHeadDim > 256) ? 0 : kPersistQs2r, 401 | kPersistQg2s, 402 | kRegPipeKV, 403 | kStageQK, 404 | kStagePV, 405 | kPadQ, 406 | kPadK, 407 | kPadV 408 | > 409 | ); 410 | LAUNCH_TEMPLATE_FUNC(ffpa_mma_L1_large_d_kernel_func); 411 | } 412 | #else 413 | auto ffpa_mma_L1_large_d_kernel_func = ( 414 | ffpa_mma_stages_split_q_L1_large_d_template< 415 | kHeadDim, 416 | kMmaAtomM, 417 | kMmaAtomN, 418 | kMmaAtomK, 419 | kMmaTileSeqLenQ, 420 | kMmaTileSeqLenK, 421 | kMmaTileSeqLenP, 422 | kMmaTileHeadDimV, 423 | kWarpTileSeqLenQ, 424 | kWarpTileSeqLenK, 425 | kWarpTileSeqLenP, 426 | kWarpTileHeadDimV, 427 | kMmaAccFloat32QK, 428 | kMmaAccFloat32PV, 429 | kOStorageAccFloat32, 430 | kPrefetchQK, 431 | kPrefetchPV, 432 | (kPersistQg2s) ? 0 : kShareSmemQKV, 433 | // Force disable Q s2r for d >= 256, Q s2r for large d will 434 | // need too many register, thus, introduce performance drops. 435 | (kPersistQg2s || kHeadDim > 256) ? 
0 : kPersistQs2r, 436 | kPersistQg2s, 437 | kRegPipeKV, 438 | kStageQK, 439 | kStagePV, 440 | kPadQ, 441 | kPadK, 442 | kPadV 443 | > 444 | ); 445 | LAUNCH_TEMPLATE_FUNC(ffpa_mma_L1_large_d_kernel_func); 446 | #endif 447 | 448 | #undef LAUNCH_TEMPLATE_FUNC 449 | } 450 | 451 | // dispatch headdim 452 | #define LAUNCHER_L1(D, S) \ 453 | case D: \ 454 | launch_ffpa_mma_L1_template< \ 455 | (D), \ 456 | kMmaAccFloat32QK, \ 457 | kMmaAccFloat32PV, \ 458 | (S) \ 459 | >(Q, K, V, O); \ 460 | break; 461 | 462 | #ifdef ENABLE_FFPA_DEBUG 463 | // minimal kernels for debug mode 464 | #define DISPATCH_HEADDIM(LAUNCHER, S) \ 465 | { \ 466 | switch (d) \ 467 | { \ 468 | LAUNCHER(32, (S)); \ 469 | LAUNCHER(64, (S)); \ 470 | LAUNCHER(128, (S)); \ 471 | LAUNCHER(256, (S)); \ 472 | LAUNCHER(320, (S)); \ 473 | LAUNCHER(512, (S)); \ 474 | LAUNCHER(1024, (S)); \ 475 | default: \ 476 | throw std::runtime_error( \ 477 | "headdim not support!"); \ 478 | break; \ 479 | } \ 480 | } 481 | 482 | #else 483 | #ifdef ENABLE_FFPA_ALL_HEADDIM 484 | // multiple of 32 485 | #define DISPATCH_HEADDIM(LAUNCHER, S) \ 486 | { \ 487 | switch (d) \ 488 | { \ 489 | LAUNCHER(32, (S)); \ 490 | LAUNCHER(64, (S)); \ 491 | LAUNCHER(96, (S)); \ 492 | LAUNCHER(128, (S)); \ 493 | LAUNCHER(160, (S)); \ 494 | LAUNCHER(192, (S)); \ 495 | LAUNCHER(224, (S)); \ 496 | LAUNCHER(256, (S)); \ 497 | LAUNCHER(288, (S)); \ 498 | LAUNCHER(320, (S)); \ 499 | LAUNCHER(352, (S)); \ 500 | LAUNCHER(384, (S)); \ 501 | LAUNCHER(416, (S)); \ 502 | LAUNCHER(448, (S)); \ 503 | LAUNCHER(480, (S)); \ 504 | LAUNCHER(512, (S)); \ 505 | LAUNCHER(544, (S)); \ 506 | LAUNCHER(576, (S)); \ 507 | LAUNCHER(608, (S)); \ 508 | LAUNCHER(640, (S)); \ 509 | LAUNCHER(672, (S)); \ 510 | LAUNCHER(704, (S)); \ 511 | LAUNCHER(736, (S)); \ 512 | LAUNCHER(768, (S)); \ 513 | LAUNCHER(800, (S)); \ 514 | LAUNCHER(832, (S)); \ 515 | LAUNCHER(864, (S)); \ 516 | LAUNCHER(896, (S)); \ 517 | LAUNCHER(928, (S)); \ 518 | LAUNCHER(960, (S)); \ 519 | LAUNCHER(992, (S)); \ 520 | LAUNCHER(1024, (S)); \ 521 | default: \ 522 | throw std::runtime_error( \ 523 | "headdim not support!"); \ 524 | break; \ 525 | } \ 526 | } 527 | #else 528 | // multiple of 64 529 | #define DISPATCH_HEADDIM(LAUNCHER, S) \ 530 | { \ 531 | switch (d) \ 532 | { \ 533 | LAUNCHER(256, (S)); \ 534 | LAUNCHER(320, (S)); \ 535 | LAUNCHER(384, (S)); \ 536 | LAUNCHER(448, (S)); \ 537 | LAUNCHER(512, (S)); \ 538 | LAUNCHER(576, (S)); \ 539 | LAUNCHER(640, (S)); \ 540 | LAUNCHER(704, (S)); \ 541 | LAUNCHER(768, (S)); \ 542 | LAUNCHER(832, (S)); \ 543 | LAUNCHER(896, (S)); \ 544 | LAUNCHER(960, (S)); \ 545 | LAUNCHER(1024, (S)); \ 546 | default: \ 547 | throw std::runtime_error( \ 548 | "headdim not support!"); \ 549 | break; \ 550 | } \ 551 | } 552 | #endif 553 | 554 | #endif 555 | -------------------------------------------------------------------------------- /csrc/extension/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info -------------------------------------------------------------------------------- /csrc/extension/fused_mla_F16F16F16_L1.cu: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
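A minimal sketch (illustrative, not from the original sources) of how a kernel entry point such as ffpa_mma_acc_f32_L1, declared in the pybind API below, might drive the LAUNCHER_L1 / DISPATCH_HEADDIM macros from launch_templates.cuh above; the real definitions live in csrc/cuffpa/ffpa_attn_F16F16F32_L1.cu and may differ:

```cpp
// Illustrative sketch only; assumes LAUNCHER_L1 and DISPATCH_HEADDIM
// (defined above) are in scope, plus the usual torch extension headers.
void ffpa_mma_acc_f32_L1(torch::Tensor Q, torch::Tensor K, torch::Tensor V,
                         torch::Tensor O, int stages) {
  constexpr int kMmaAccFloat32QK = 1; // Q@K^T MMA Acc in fp32
  constexpr int kMmaAccFloat32PV = 1; // P@V   MMA Acc in fp32
  const int d = Q.size(3); // headdim, switched on by DISPATCH_HEADDIM
  // Map the runtime `stages` value onto the compile-time (S) parameter;
  // with ENABLE_FFPA_ALL_STAGES the dispatch may extend up to 4 stages.
  if (stages == 2) {
    DISPATCH_HEADDIM(LAUNCHER_L1, 2);
  } else {
    DISPATCH_HEADDIM(LAUNCHER_L1, 1);
  }
}
```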
/csrc/extension/fused_mla_F16F16F32_L1.cu:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/csrc/extension/fused_mla_templates_L1.cuh:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/csrc/extension/launch_templates.cuh:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/csrc/pybind/ffpa_attn_api.cc:
--------------------------------------------------------------------------------
1 | #include <torch/types.h>
2 | #include <torch/extension.h>
3 | 
4 | #define STRINGFY(str) #str
5 | #define TORCH_BINDING_COMMON_EXTENSION(func) \
6 |   m.def(STRINGFY(func), &func, STRINGFY(func));
7 | 
8 | void ffpa_mma_acc_f16_L1(torch::Tensor Q, torch::Tensor K, torch::Tensor V,
9 |                          torch::Tensor O, int stages);
10 | 
11 | void ffpa_mma_acc_f32_L1(torch::Tensor Q, torch::Tensor K, torch::Tensor V,
12 |                          torch::Tensor O, int stages);
13 | 
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |   TORCH_BINDING_COMMON_EXTENSION(ffpa_mma_acc_f16_L1)
16 |   TORCH_BINDING_COMMON_EXTENSION(ffpa_mma_acc_f32_L1)
17 | }
18 | 
--------------------------------------------------------------------------------
/csrc/pybind/fused_mla_api.cc:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | 
5 | 
6 | class ENV(object):
7 |     # ENVs for FFPA kernels compiling
8 | 
9 |     # Project dir, path to faster-prefill-attention
10 |     PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
11 | 
12 |     # Enable debug mode for FFPA, fast build of minimal kernels, default False.
13 |     ENABLE_FFPA_DEBUG = bool(int(os.environ.get("ENABLE_FFPA_DEBUG", 0)))
14 | 
15 |     # Enable building FFPA kernels for Ada devices (sm89, L20, 4090, etc),
16 |     # default True.
17 |     ENABLE_FFPA_ADA = bool(int(os.environ.get("ENABLE_FFPA_ADA", 1)))
18 | 
19 |     # Enable building FFPA kernels for Ampere devices (sm80, A30, A100, etc),
20 |     # default True.
21 |     ENABLE_FFPA_AMPERE = bool(int(os.environ.get("ENABLE_FFPA_AMPERE", 1)))
22 | 
23 |     # Enable building FFPA kernels for Hopper devices (sm90, H100, H20, etc),
24 |     # default False.
25 |     ENABLE_FFPA_HOPPER = bool(int(os.environ.get("ENABLE_FFPA_HOPPER", 0)))
26 | 
27 |     # Enable all multi-stage kernels or not: stages 1~4 if True else 1~2, default True.
28 |     ENABLE_FFPA_ALL_STAGES = bool(int(os.environ.get("ENABLE_FFPA_ALL_STAGES", 1)))
29 | 
30 |     # Enable all headdims for FFPA kernels or not, default False.
31 |     # True: headdim ranges from 32 to 1024 with step 32, range(32, 1024, 32);
32 |     # False: headdim ranges from 256 to 1024 with step 64, range(256, 1024, 64).
33 |     ENABLE_FFPA_ALL_HEADDIM = bool(int(os.environ.get("ENABLE_FFPA_ALL_HEADDIM", 0)))
34 | 
35 |     # Force Q@K^T to use fp16 as the MMA Acc dtype for FFPA Acc F32 kernels, default False.
36 |     # FFPA Acc F32 kernels MMA Acc = Mixed Q@K^T MMA Acc F16 + P@V MMA Acc F32.
37 |     ENABLE_FFPA_FORCE_QK_F16 = bool(int(os.environ.get("ENABLE_FFPA_FORCE_QK_F16", 0)))
38 | 
39 |     # Force P@V to use fp16 as the MMA Acc dtype for FFPA Acc F32 kernels, default False.
40 |     # FFPA Acc F32 kernels MMA Acc = Mixed Q@K^T MMA Acc F32 + P@V MMA Acc F16.
41 |     ENABLE_FFPA_FORCE_PV_F16 = bool(int(os.environ.get("ENABLE_FFPA_FORCE_PV_F16", 0)))
42 | 
43 |     # Enable FFPA to prefetch QKV at the appropriate time point, default True, boosts 5%~10%.
44 |     ENABLE_FFPA_PREFETCH_QKV = bool(int(os.environ.get("ENABLE_FFPA_PREFETCH_QKV", 1)))
45 | 
46 |     # Enable QKV smem shared policy, default False (False is preferred for MMA & g2s overlap).
47 |     # Set it to True if you want to run FFPA on a low-SRAM device.
48 |     ENABLE_FFPA_QKV_SMEM_SHARE = bool(
49 |         int(os.environ.get("ENABLE_FFPA_QKV_SMEM_SHARE", 0))
50 |     )
51 | 
52 |     # Enable smem swizzle for Q, default True. True: bank conflicts free for Q smem
53 |     # via swizzle; False: bank conflicts free for Q smem via padding.
54 |     ENABLE_FFPA_SMEM_SWIZZLE_Q = bool(
55 |         int(os.environ.get("ENABLE_FFPA_SMEM_SWIZZLE_Q", 1))
56 |     )
57 | 
58 |     # Enable smem swizzle for K, default True. True: bank conflicts free for K smem
59 |     # via swizzle; False: bank conflicts free for K smem via padding.
60 |     ENABLE_FFPA_SMEM_SWIZZLE_K = bool(
61 |         int(os.environ.get("ENABLE_FFPA_SMEM_SWIZZLE_K", 1))
62 |     )
63 | 
64 |     # Enable smem swizzle for V, now default True. True: bank conflicts free for V smem
65 |     # via swizzle; False: bank conflicts free for V smem via padding. FIXME(DefTruth):
66 |     # swizzled V could not reach good performance at first; it is enabled by default
67 |     # now that the performance issue has been fixed. (Fixed)
68 |     ENABLE_FFPA_SMEM_SWIZZLE_V = bool(
69 |         int(os.environ.get("ENABLE_FFPA_SMEM_SWIZZLE_V", 1))
70 |     )
71 | 
72 |     # Persist Q g2s loads for headdim <= 320, more SRAM, but register usage is unchanged.
73 |     ENABLE_FFPA_PERSIST_Q_G2S = bool(
74 |         int(os.environ.get("ENABLE_FFPA_PERSIST_Q_G2S", 0))
75 |     )
76 | 
77 |     # Persist KV g2s loads for headdim <= 256, more SRAM. If True, automatically use the
78 |     # flash-attn algo that tiles at the attention level for headdim <= 256, and the
79 |     # ffpa-attn fine-grained tiling at the MMA level for headdim > 256.
80 |     ENABLE_FFPA_PERSIST_KV_G2S = bool(
81 |         int(os.environ.get("ENABLE_FFPA_PERSIST_KV_G2S", 0))
82 |     )
83 | 
84 |     # Persist Q s2r loads for headdim < 512 to reduce Q g2s and s2r IO access,
85 |     # while still keeping O(1) SRAM complexity. Default value is False. This option
86 |     # introduces more registers for Q frags as the headdim becomes larger. We should
87 |     # choose whether to enable it according to the balance between register usage
88 |     # and IO access reduction.
89 |     ENABLE_FFPA_PERSIST_Q_S2R = bool(
90 |         int(os.environ.get("ENABLE_FFPA_PERSIST_Q_S2R", 0))
91 |     )
92 | 
93 |     # Persist V s2r only for the small-d kernel, more registers.
94 |     ENABLE_FFPA_PERSIST_V_S2R = bool(
95 |         int(os.environ.get("ENABLE_FFPA_PERSIST_V_S2R", ENABLE_FFPA_PERSIST_KV_G2S))
96 |     )
97 | 
98 |     # Registers ping-pong double buffers for ldmatrix & MMA computation overlapping.
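    # e.g (illustrative) flags like this one are picked up from the environment
    # at build time:
    #   export ENABLE_FFPA_REGISTERS_PIPE_KV=1 && python3 setup.py install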
99 | ENABLE_FFPA_REGISTERS_PIPE_KV = bool( 100 | int(os.environ.get("ENABLE_FFPA_REGISTERS_PIPE_KV", 0)) 101 | ) 102 | 103 | # if True: grid(N/Br, H, B) else: grid(N/Br, B * H) 104 | ENABLE_FFPA_LAUNCH_GRID_DNHB = bool( 105 | int(os.environ.get("ENABLE_FFPA_LAUNCH_GRID_DNHB", 0)) 106 | ) 107 | 108 | @classmethod 109 | def project_dir(cls): 110 | return cls.PROJECT_DIR 111 | 112 | @classmethod 113 | def enable_debug(cls): 114 | return cls.ENABLE_FFPA_DEBUG 115 | 116 | @classmethod 117 | def enable_ada(cls): 118 | return cls.ENABLE_FFPA_ADA 119 | 120 | @classmethod 121 | def enable_ampere(cls): 122 | return cls.ENABLE_FFPA_AMPERE 123 | 124 | @classmethod 125 | def enable_hopper(cls): 126 | return cls.ENABLE_FFPA_HOPPER 127 | 128 | @classmethod 129 | def enable_all_mutistages(cls): 130 | return cls.ENABLE_FFPA_ALL_STAGES 131 | 132 | @classmethod 133 | def enable_all_headdim(cls): 134 | return cls.ENABLE_FFPA_ALL_HEADDIM 135 | 136 | @classmethod 137 | def enable_force_pv_fp16(cls): 138 | return cls.ENABLE_FFPA_FORCE_PV_F16 139 | 140 | @classmethod 141 | def enable_force_qk_fp16(cls): 142 | return cls.ENABLE_FFPA_FORCE_QK_F16 143 | 144 | @classmethod 145 | def enable_prefetch_qkv(cls): 146 | return cls.ENABLE_FFPA_PREFETCH_QKV 147 | 148 | @classmethod 149 | def enable_qkv_smem_share(cls): 150 | return cls.ENABLE_FFPA_QKV_SMEM_SHARE 151 | 152 | @classmethod 153 | def enable_smem_swizzle_q(cls): 154 | return cls.ENABLE_FFPA_SMEM_SWIZZLE_Q 155 | 156 | @classmethod 157 | def enable_smem_swizzle_k(cls): 158 | return cls.ENABLE_FFPA_SMEM_SWIZZLE_K 159 | 160 | @classmethod 161 | def enable_smem_swizzle_v(cls): 162 | return cls.ENABLE_FFPA_SMEM_SWIZZLE_V 163 | 164 | @classmethod 165 | def enable_persist_q_g2s(cls): 166 | return cls.ENABLE_FFPA_PERSIST_Q_G2S 167 | 168 | @classmethod 169 | def enable_persist_kv_g2s(cls): 170 | return cls.ENABLE_FFPA_PERSIST_KV_G2S 171 | 172 | @classmethod 173 | def enable_persist_q_s2r(cls): 174 | return cls.ENABLE_FFPA_PERSIST_Q_S2R 175 | 176 | @classmethod 177 | def enable_persist_v_s2r(cls): 178 | if cls.enable_persist_kv_g2s(): 179 | return cls.ENABLE_FFPA_PERSIST_V_S2R 180 | return False 181 | 182 | @classmethod 183 | def enable_registers_pipe_kv(cls): 184 | return cls.ENABLE_FFPA_REGISTERS_PIPE_KV 185 | 186 | @classmethod 187 | def enable_launch_grid_dnhb(cls): 188 | return cls.ENABLE_FFPA_LAUNCH_GRID_DNHB 189 | 190 | @classmethod 191 | def env_cuda_cflags(cls): 192 | extra_env_cflags = [] 193 | if cls.enable_debug(): 194 | extra_env_cflags.append("-DENABLE_FFPA_DEBUG") 195 | if cls.enable_all_mutistages(): 196 | extra_env_cflags.append("-DENABLE_FFPA_ALL_STAGES") 197 | if cls.enable_all_headdim(): 198 | extra_env_cflags.append("-DENABLE_FFPA_ALL_HEADDIM") 199 | if cls.enable_force_qk_fp16(): 200 | extra_env_cflags.append("-DENABLE_FFPA_FORCE_QK_F16") 201 | if cls.enable_force_pv_fp16(): 202 | extra_env_cflags.append("-DENABLE_FFPA_FORCE_PV_F16") 203 | if cls.enable_prefetch_qkv(): 204 | extra_env_cflags.append("-DENABLE_FFPA_PREFETCH_QKV") 205 | if cls.enable_qkv_smem_share(): 206 | extra_env_cflags.append("-DENABLE_FFPA_QKV_SMEM_SHARE") 207 | if cls.enable_smem_swizzle_q(): 208 | extra_env_cflags.append("-DENABLE_FFPA_SMEM_SWIZZLE_Q") 209 | if cls.enable_smem_swizzle_k(): 210 | extra_env_cflags.append("-DENABLE_FFPA_SMEM_SWIZZLE_K") 211 | if cls.enable_smem_swizzle_v(): 212 | extra_env_cflags.append("-DENABLE_FFPA_SMEM_SWIZZLE_V") 213 | if cls.enable_persist_q_g2s(): 214 | extra_env_cflags.append("-DENABLE_FFPA_PERSIST_Q_G2S") 215 | if 
cls.enable_persist_kv_g2s():
216 |             extra_env_cflags.append("-DENABLE_FFPA_PERSIST_KV_G2S")
217 |         if cls.enable_persist_q_s2r():
218 |             extra_env_cflags.append("-DENABLE_FFPA_PERSIST_Q_S2R")
219 |         if cls.enable_persist_v_s2r():
220 |             extra_env_cflags.append("-DENABLE_FFPA_PERSIST_V_S2R")
221 |         if cls.enable_registers_pipe_kv():
222 |             extra_env_cflags.append("-DENABLE_FFPA_REGISTERS_PIPE_KV")
223 |         if cls.enable_launch_grid_dnhb():
224 |             extra_env_cflags.append("-DENABLE_FFPA_LAUNCH_GRID_DNHB")
225 | 
226 |         if cls.enable_persist_kv_g2s():
227 |             assert (
228 |                 cls.enable_persist_q_g2s()
229 |             ), "PERSIST_Q_G2S must be enabled if PERSIST_KV_G2S is enabled."
230 |             if cls.enable_qkv_smem_share():
231 |                 assert cls.enable_persist_q_s2r(), (
232 |                     "PERSIST_Q_S2R must be enabled if QKV_SMEM_SHARE and "
233 |                     "PERSIST_KV_G2S are enabled."
234 |                 )
235 |         else:
236 |             assert not all(
237 |                 (cls.enable_persist_q_s2r(), cls.enable_persist_q_g2s())
238 |             ), "PERSIST_Q_G2S and PERSIST_Q_S2R can not both be enabled."
239 |             assert not all(
240 |                 (cls.enable_qkv_smem_share(), cls.enable_persist_q_g2s())
241 |             ), "PERSIST_Q_G2S and QKV_SMEM_SHARE can not both be enabled."
242 |             assert not all(
243 |                 (cls.enable_qkv_smem_share(), cls.enable_persist_kv_g2s())
244 |             ), "PERSIST_KV_G2S and QKV_SMEM_SHARE can not both be enabled."
245 |         return extra_env_cflags
246 | 
247 |     @classmethod
248 |     def list_ffpa_env(cls):
249 |         def formatenv(name, value):
250 |             try:
251 |                 print(
252 |                     f"{name:<30}: {str(value):<5} -> command:"
253 |                     f" export {name}={int(value)}"
254 |                 )
255 |             except Exception:
256 |                 print(f"{name:<30}: {value}")
257 | 
258 |         pretty_print_line("FFPA-ATTN ENVs")
259 |         formatenv("PROJECT_DIR", cls.project_dir())
260 |         formatenv("ENABLE_FFPA_DEBUG", cls.enable_debug())
261 |         formatenv("ENABLE_FFPA_ADA", cls.enable_ada())
262 |         formatenv("ENABLE_FFPA_AMPERE", cls.enable_ampere())
263 |         formatenv("ENABLE_FFPA_HOPPER", cls.enable_hopper())
264 |         formatenv("ENABLE_FFPA_ALL_STAGES", cls.enable_all_mutistages())
265 |         formatenv("ENABLE_FFPA_ALL_HEADDIM", cls.enable_all_headdim())
266 |         formatenv("ENABLE_FFPA_PREFETCH_QKV", cls.enable_prefetch_qkv())
267 |         formatenv("ENABLE_FFPA_FORCE_QK_F16", cls.enable_force_qk_fp16())
268 |         formatenv("ENABLE_FFPA_FORCE_PV_F16", cls.enable_force_pv_fp16())
269 |         formatenv("ENABLE_FFPA_PERSIST_Q_G2S", cls.enable_persist_q_g2s())
270 |         formatenv("ENABLE_FFPA_PERSIST_KV_G2S", cls.enable_persist_kv_g2s())
271 |         formatenv("ENABLE_FFPA_PERSIST_Q_S2R", cls.enable_persist_q_s2r())
272 |         formatenv("ENABLE_FFPA_PERSIST_V_S2R", cls.enable_persist_v_s2r())
273 |         formatenv("ENABLE_FFPA_QKV_SMEM_SHARE", cls.enable_qkv_smem_share())
274 |         formatenv("ENABLE_FFPA_SMEM_SWIZZLE_Q", cls.enable_smem_swizzle_q())
275 |         formatenv("ENABLE_FFPA_SMEM_SWIZZLE_K", cls.enable_smem_swizzle_k())
276 |         formatenv("ENABLE_FFPA_SMEM_SWIZZLE_V", cls.enable_smem_swizzle_v())
277 |         formatenv("ENABLE_FFPA_REGISTERS_PIPE_KV", cls.enable_registers_pipe_kv())
278 |         formatenv("ENABLE_FFPA_LAUNCH_GRID_DNHB", cls.enable_launch_grid_dnhb())
279 |         pretty_print_line()
280 | 
281 |     @staticmethod
282 |     def get_device_name():
283 |         device_name = torch.cuda.get_device_name(torch.cuda.current_device())
284 |         # we may run the GPU under WSL2, so add a WSL2 tag.
285 | if "Laptop" in device_name: 286 | device_name += " WSL2" 287 | return device_name 288 | 289 | @staticmethod 290 | def get_device_capability(): 291 | return torch.cuda.get_device_capability(torch.cuda.current_device()) 292 | 293 | @staticmethod 294 | def get_build_sources(build_pkg: bool = False): 295 | def csrc(sub_dir, filename): 296 | csrc_file = f"{ENV.project_dir()}/csrc/{sub_dir}/{filename}" 297 | if ENV.enable_debug() or build_pkg: 298 | pretty_print_line(f"csrc_file: {csrc_file}", sep="", mode="left") 299 | return csrc_file 300 | 301 | if ENV.enable_debug() or build_pkg: 302 | pretty_print_line() 303 | build_sources = [ 304 | csrc("pybind", "ffpa_attn_api.cc"), 305 | csrc("cuffpa", "ffpa_attn_F16F16F16_L1.cu"), 306 | csrc("cuffpa", "ffpa_attn_F16F16F32_L1.cu"), 307 | ] 308 | if ENV.enable_debug() or build_pkg: 309 | pretty_print_line() 310 | return build_sources 311 | 312 | @staticmethod 313 | def get_build_cuda_cflags(build_pkg: bool = False): 314 | device_name = ENV.get_device_name() 315 | extra_cuda_cflags = [] 316 | extra_cuda_cflags.append("-O3") 317 | extra_cuda_cflags.append("-std=c++17") 318 | extra_cuda_cflags.append("-U__CUDA_NO_HALF_OPERATORS__") 319 | extra_cuda_cflags.append("-U__CUDA_NO_HALF_CONVERSIONS__") 320 | extra_cuda_cflags.append("-U__CUDA_NO_HALF2_OPERATORS__") 321 | extra_cuda_cflags.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") 322 | extra_cuda_cflags.append("--expt-relaxed-constexpr") 323 | extra_cuda_cflags.append("--expt-extended-lambda") 324 | extra_cuda_cflags.append("--use_fast_math") 325 | extra_cuda_cflags.append( 326 | "-diag-suppress 177" if not build_pkg else "--ptxas-options=-v" 327 | ) 328 | extra_cuda_cflags.append( 329 | "-Xptxas -v" if not build_pkg else "--ptxas-options=-O3" 330 | ) 331 | extra_cuda_cflags.append( 332 | "-DBUILD_FFPA_ATTN_MMA_L20" if "L20" in device_name else "" 333 | ) 334 | extra_cuda_cflags.append( 335 | "-DBUILD_FFPA_ATTN_MMA_4090" if "4090" in device_name else "" 336 | ) 337 | extra_cuda_cflags.append( 338 | "-DBUILD_FFPA_ATTN_MMA_3080" if "3080" in device_name else "" 339 | ) 340 | extra_cuda_cflags.extend(ENV.env_cuda_cflags()) 341 | extra_cuda_cflags.append(f"-I {ENV.project_dir()}/include") 342 | extra_cuda_cflags.append(f"-I {ENV.project_dir()}/csrc/cuffpa") 343 | return extra_cuda_cflags 344 | 345 | @staticmethod 346 | def get_build_cflags(): 347 | extra_cflags = [] 348 | extra_cflags.append("-std=c++17") 349 | return extra_cflags 350 | 351 | @staticmethod 352 | def get_cuda_bare_metal_version(cuda_dir): 353 | # helper function to get cuda version 354 | import subprocess 355 | 356 | from packaging.version import parse 357 | 358 | raw_output = subprocess.check_output( 359 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 360 | ) 361 | output = raw_output.split() 362 | release_idx = output.index("release") + 1 363 | bare_metal_version = parse(output[release_idx].split(",")[0]) 364 | 365 | return raw_output, bare_metal_version 366 | 367 | @staticmethod 368 | def build_ffpa_from_sources(verbose: bool = False): 369 | from torch.utils.cpp_extension import load 370 | 371 | torch_arch_list_env = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 372 | # Load the CUDA kernel as a python module 373 | pretty_print_line( 374 | f"Loading ffpa_attn lib on device: {ENV.get_device_name()}, " 375 | f"capability: {ENV.get_device_capability()}, " 376 | f"Arch ENV: {torch_arch_list_env}" 377 | ) 378 | return load( 379 | name="pyffpa_cuda", 380 | sources=ENV.get_build_sources(), 381 | 
extra_cuda_cflags=ENV.get_build_cuda_cflags(), 382 | extra_cflags=ENV.get_build_cflags(), 383 | verbose=verbose, 384 | ) 385 | 386 | @staticmethod 387 | def try_load_ffpa_library(force_build: bool = False, verbose: bool = False): 388 | use_ffpa_attn_package = False 389 | if not force_build: 390 | # check if can import ffpa_attn 391 | try: 392 | import ffpa_attn 393 | 394 | pretty_print_line("Import ffpa_attn library done, use it!") 395 | use_ffpa_attn_package = True 396 | return ffpa_attn, use_ffpa_attn_package 397 | except Exception: 398 | pretty_print_line("Can't import ffpa_attn, force build from sources") 399 | pretty_print_line( 400 | "Also may need export LD_LIBRARY_PATH=" 401 | "PATH-TO/torch/lib:$LD_LIBRARY_PATH" 402 | ) 403 | ffpa_attn = ENV.build_ffpa_from_sources(verbose=verbose) 404 | use_ffpa_attn_package = False 405 | return ffpa_attn, use_ffpa_attn_package 406 | else: 407 | pretty_print_line("Force ffpa_attn lib build from sources") 408 | ffpa_attn = ENV.build_ffpa_from_sources(verbose=verbose) 409 | use_ffpa_attn_package = False 410 | return ffpa_attn, use_ffpa_attn_package 411 | 412 | 413 | def pretty_print_line( 414 | m: str = "", sep: str = "-", mode: str = "center", width: int = 150 415 | ): 416 | res_len = width - len(m) 417 | if mode == "center": 418 | left_len = int(res_len / 2) 419 | right_len = res_len - left_len 420 | pretty_line = sep * left_len + m + sep * right_len 421 | elif mode == "left": 422 | pretty_line = m + sep * res_len 423 | else: 424 | pretty_line = sep * res_len + m 425 | print(pretty_line) 426 | 427 | 428 | if __name__ == "__main__": 429 | # Debug: show FFPA ENV information. run: python3 env.py 430 | ENV.list_ffpa_env() 431 | -------------------------------------------------------------------------------- /ffpa_attn/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | .egg-info 22 | dist 23 | -------------------------------------------------------------------------------- /ffpa_attn/__init__.py: -------------------------------------------------------------------------------- 1 | from .interface import ( 2 | faster_prefill_attn_func, 3 | ffpa, 4 | ffpa_acc_f16_L1, 5 | ffpa_acc_f32_L1, 6 | ffpa_mma_acc_f16_L1, 7 | ffpa_mma_acc_f32_L1, 8 | LevelType, 9 | MMAAccType, 10 | ) 11 | from .version import __version__ 12 | 13 | 14 | # e.g pyffpa.L1 15 | L1 = LevelType.L1 16 | L2 = LevelType.L2 17 | L3 = LevelType.L3 18 | # e.g pyffpa.FP32 19 | FP32 = MMAAccType.FP32 20 | FP16 = MMAAccType.FP16 21 | -------------------------------------------------------------------------------- /ffpa_attn/interface.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from functools import partial 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | # from pyffpa_cuda.cpython.*.so 8 | from pyffpa_cuda import ffpa_mma_acc_f16_L1, ffpa_mma_acc_f32_L1 9 | 10 | 11 | class LevelType(Enum): 12 | L1 = 0 13 | L2 = 1 14 | L3 = 2 15 | 16 | 17 | class MMAAccType(Enum): 18 | FP32 = 0 19 | FP16 = 1 20 | 21 | 22 | def faster_prefill_attn_func( 23 | q: torch.Tensor, 24 | k: torch.Tensor, 25 | v: torch.Tensor, 26 | o: Optional[torch.Tensor] = None, 27 | num_stages: int = 2, 28 | level: LevelType = LevelType.L1, 29 | acc: MMAAccType = MMAAccType.FP32, 
30 | ): 31 | # Q, K, V, O: [B, H, N, D] layout 32 | if not isinstance(o, torch.Tensor) or o is None: 33 | o = torch.zeros_like(q) 34 | assert level == LevelType.L1, "only support FFPA L1 level now." 35 | if acc == MMAAccType.FP32: 36 | ffpa_mma_acc_f32_L1(q, k, v, o, num_stages) 37 | else: 38 | ffpa_mma_acc_f16_L1(q, k, v, o, num_stages) 39 | return o 40 | 41 | 42 | ffpa: callable = faster_prefill_attn_func 43 | ffpa_acc_f32_L1 = partial( 44 | faster_prefill_attn_func, level=LevelType.L1, acc=MMAAccType.FP32 45 | ) 46 | ffpa_acc_f16_L1 = partial( 47 | faster_prefill_attn_func, level=LevelType.L1, acc=MMAAccType.FP16 48 | ) 49 | -------------------------------------------------------------------------------- /ffpa_attn/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.2" # type: ignore 2 | -------------------------------------------------------------------------------- /include/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info 22 | dist 23 | -------------------------------------------------------------------------------- /include/cuffpa/cp_async.cuh: -------------------------------------------------------------------------------- 1 | // cp async operations 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace ffpa { 16 | namespace cp_async { 17 | 18 | // Simple wrappers for cp.async/ld/st instructions. 
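// e.g (illustrative) a minimal g2s copy sequence built from these wrappers:
//   ffpa::cp_async::cp_async<16>(Q_smem_u32_ptr, &Q_gmem_ptr[offset]);
//   ffpa::cp_async::commit_group();
//   ffpa::cp_async::wait_group<0>(); // block until all committed groups land
//   __syncthreads();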
19 | __device__ __forceinline__ void commit_group() { 20 | asm volatile("cp.async.commit_group;\n" ::); 21 | } 22 | 23 | // e.g: wait_group<1>(); 24 | template 25 | __device__ __forceinline__ void wait_group() { 26 | asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); 27 | } 28 | 29 | // e.g: cp_async(smem_ptr, gmem_ptr); 30 | template 31 | __device__ __forceinline__ void cp_async( 32 | uint32_t smem_ptr, const T* gmem_ptr) { 33 | static_assert(kBytes == 16 || kBytes == 8); // 8 or 4 halfs 34 | if constexpr (kBytes == 16) { 35 | asm volatile( 36 | "cp.async.cg.shared.global.L2::128B " 37 | "[%0], [%1], %2, %3;\n" 38 | ::"r"(smem_ptr), "l"(gmem_ptr), 39 | "n"(16), "r"(16) 40 | ); 41 | } else { 42 | asm volatile( 43 | "cp.async.ca.shared.global.L2::128B " 44 | "[%0], [%1], %2, %3;\n" 45 | ::"r"(smem_ptr), "l"(gmem_ptr), 46 | "n"(8), "r"(8) 47 | ); 48 | } 49 | } 50 | 51 | // e.g ldg_sync_128b(...); 52 | template 53 | __device__ __forceinline__ void ldg_sync_128b( 54 | T0 * mem_dst_ptr, T1 * gmem_src_ptr) { 55 | using _128b_t = uint4; 56 | _128b_t * dst_128b_ptr = reinterpret_cast<_128b_t*>( 57 | mem_dst_ptr); 58 | _128b_t * src_128b_ptr = reinterpret_cast<_128b_t*>( 59 | gmem_src_ptr); 60 | *(dst_128b_ptr) = *(src_128b_ptr); 61 | } 62 | 63 | // e.g stg_sync_128b(...); 64 | template 65 | __device__ __forceinline__ void stg_sync_128b( 66 | T0 * gmem_dst_ptr, T1 * mem_src_ptr) { 67 | using _128b_t = uint4; 68 | _128b_t * dst_128b_ptr = reinterpret_cast<_128b_t*>( 69 | gmem_dst_ptr); 70 | _128b_t * src_128b_ptr = reinterpret_cast<_128b_t*>( 71 | mem_src_ptr); 72 | *(dst_128b_ptr) = *(src_128b_ptr); 73 | } 74 | 75 | } // cp_async 76 | } // ffpa 77 | -------------------------------------------------------------------------------- /include/cuffpa/deprecated/mma_utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | using namespace nvcuda; 16 | 17 | #define WARP_SIZE 32 18 | #define INT4(value) (reinterpret_cast(&(value))[0]) 19 | #define FLOAT2(value) (reinterpret_cast(&(value))[0]) 20 | #define FLOAT4(value) (reinterpret_cast(&(value))[0]) 21 | #define HALF2(value) (reinterpret_cast(&(value))[0]) 22 | #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) 23 | #define LDST32BITS(value) (reinterpret_cast(&(value))[0]) 24 | #define LDST64BITS(value) (reinterpret_cast(&(value))[0]) 25 | #define LDST128BITS(value) (reinterpret_cast(&(value))[0]) 26 | // gmem -> smem 27 | #define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) 28 | #define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) 29 | #define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) 30 | // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 
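// e.g (illustrative): CP_ASYNC_CG(smem_u32_addr, &gmem_ptr[idx], 16);
//                     CP_ASYNC_COMMIT_GROUP(); CP_ASYNC_WAIT_GROUP(0);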
31 | #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 32 | #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) 33 | // ldmatrix 34 | #define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 35 | #define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 36 | #define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 37 | #define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) 38 | #define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) 39 | #define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) 40 | // mma m16n8k16 acc f32 or f16 41 | #define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) 42 | #define HMMA16816F32(RD0, RD1, RD2, RD3, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1, RC2, RC3) asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=r"(RD0), "=r"(RD1), "=r"(RD2), "=r"(RD3): "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1), "r"(RC2), "r"(RC3)) 43 | 44 | 45 | __device__ __host__ inline 46 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } 47 | 48 | template 49 | __device__ inline T warp_reduce_sum(T val) { 50 | #pragma unroll 51 | for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { 52 | val += __shfl_xor_sync(0xffffffff, val, mask, kWarpSize); 53 | } 54 | return val; 55 | } 56 | 57 | template 58 | __device__ inline T warp_reduce_max(T val) { 59 | #pragma unroll 60 | for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) { 61 | val = max(val, __shfl_xor_sync(0xffffffff, val, mask, kWarpSize)); 62 | } 63 | return val; 64 | } 65 | 66 | template 67 | __device__ inline void fill_3D_regs(T (&R)[M][N][K], T val) { 68 | #pragma unroll 69 | for (int i = 0; i < M; ++i) { 70 | #pragma unroll 71 | for (int j = 0; j < N; ++j) { 72 | #pragma unroll 73 | for (int k = 0; k < K; ++k) { 74 | R[i][j][k] = val; 75 | } 76 | } 77 | } 78 | } 79 | 80 | template 81 | __device__ inline void fill_2D_regs(T (&R)[M][N], T val) { 82 | #pragma unroll 83 | for (int i = 0; i < M; ++i) { 84 | #pragma unroll 85 | for (int j = 0; j < N; ++j) { 86 | R[i][j] = val; 87 | } 88 | } 89 | } 90 | 91 | template 92 | __device__ inline void fill_1D_regs(T (&S)[M], T val) { 93 | #pragma unroll 94 | for (int i = 0; i < M; ++i) { 95 | S[i] = val; 96 | } 97 | } 98 | 99 | #ifdef FFPA_MMA_DEBUG 100 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) 
\ 101 | { \ 102 | if (tid == 0) { \ 103 | float2 v_reg = __half22float2(HALF2(R)); \ 104 | printf("[T0] " format ", V0=%f, V1=%f\n", \ 105 | ##__VA_ARGS__, v_reg.x, v_reg.y); \ 106 | } \ 107 | } 108 | 109 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) \ 110 | { \ 111 | if (tid < 32) { \ 112 | float2 v_reg = __half22float2(HALF2(R)); \ 113 | printf("[T%d] " format ", V0=%f, V1=%f\n", \ 114 | tid, ##__VA_ARGS__, v_reg.x, v_reg.y);\ 115 | } \ 116 | } 117 | 118 | #define FFPA_MMA_PRINT_REG(R, format, ...) \ 119 | { \ 120 | { \ 121 | float2 v_reg = __half22float2(HALF2(R)); \ 122 | printf(format", V0=%f, V1=%f\n", \ 123 | ##__VA_ARGS__, v_reg.x, v_reg.y); \ 124 | } \ 125 | } 126 | 127 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) \ 128 | { \ 129 | { \ 130 | float2 v_reg_0 = __half22float2(HALF2(R0)); \ 131 | float2 v_reg_1 = __half22float2(HALF2(R1)); \ 132 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \ 133 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \ 134 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \ 135 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \ 136 | } \ 137 | } \ 138 | } 139 | 140 | #define FFPA_MMA_CHECK_PRINT_T32_REG(R0, R1, format, ...) \ 141 | { \ 142 | if (tid < 32){ \ 143 | float2 v_reg_0 = __half22float2(HALF2(R0)); \ 144 | float2 v_reg_1 = __half22float2(HALF2(R1)); \ 145 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \ 146 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \ 147 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \ 148 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \ 149 | } \ 150 | } \ 151 | } 152 | 153 | #define FFPA_MMA_PRINT_T0(format, ...) \ 154 | { \ 155 | if (tid == 0) { \ 156 | printf("[T0] " format, ##__VA_ARGS__); \ 157 | } \ 158 | } 159 | 160 | #define FFPA_MMA_PRINT_T32(format, ...) \ 161 | { \ 162 | if (tid < 32) { \ 163 | printf("[T%d] " format, tid, ##__VA_ARGS__);\ 164 | } \ 165 | } 166 | 167 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) \ 168 | { \ 169 | if (lane_id == 0) { \ 170 | float2 v_reg = __half22float2(HALF2(R)); \ 171 | printf("[L0] " format", V0=%f, V1=%f\n", \ 172 | ##__VA_ARGS__, v_reg.x, v_reg.y); \ 173 | } \ 174 | } 175 | 176 | #define FFPA_MMA_PRINT_L0(format, ...) \ 177 | { \ 178 | if (lane_id == 0) { \ 179 | printf("[L0] " format, ##__VA_ARGS__); \ 180 | } \ 181 | } 182 | 183 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) \ 184 | { \ 185 | if (tid == 0 && blockIdx.z == 0) { \ 186 | printf("----------------------------------------\n"); \ 187 | printf(format, ##__VA_ARGS__); \ 188 | for (int i = 0; i < Br; ++i) { \ 189 | for (int j = 0; j < kMmaTileSeqLenK; ++j) { \ 190 | printf("[%d][%d]=%f", i, j, (B)[i][j]); \ 191 | } \ 192 | printf("\n"); \ 193 | } \ 194 | printf("----------------------------------------\n"); \ 195 | } \ 196 | __syncthreads(); \ 197 | } 198 | 199 | #else 200 | 201 | #define FFPA_MMA_PRINT_REG(R, format, ...) {} 202 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) {} 203 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) {} 204 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) {} 205 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) {} 206 | #define FFPA_MMA_PRINT_T0(format, ...) {} 207 | #define FFPA_MMA_PRINT_T32(format, ...) {} 208 | #define FFPA_MMA_PRINT_L0(format, ...) {} 209 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) 
{} 210 | 211 | #endif 212 | 213 | #define STRINGFY(str) #str 214 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 215 | m.def(STRINGFY(func), &func, STRINGFY(func)); 216 | 217 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 218 | if(((T).options().dtype() != (th_type))) { \ 219 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 220 | throw std::runtime_error("values must be "#th_type); \ 221 | } 222 | 223 | #define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ 224 | if (((T2).size(0) != (T1).size(0)) || \ 225 | ((T2).size(1) != (T1).size(1)) || \ 226 | ((T2).size(2) != (T1).size(2)) || \ 227 | ((T2).size(3) != (T1).size(3))) { \ 228 | throw std::runtime_error("Tensor size mismatch!"); \ 229 | } 230 | -------------------------------------------------------------------------------- /include/cuffpa/deprecated/smem_swizzle.cuh: -------------------------------------------------------------------------------- 1 | // Manually SMEM swizzling for bank conflict free. 2 | // ---------------------------------------------------------------- 3 | // Manually SMEM swizzling for bank conflict free. 4 | // ---------------------------------------------------------------- 5 | // [INFO] Assert smem store layout col_stride <= 16, prefer 16. | 6 | // [INFO] For logical_col_stride > 16, we have to permute the | 7 | // [INFO] smem store layout using col major ZigZag method: | 8 | // [INFO] e.g, --> Q smem logical layout [Br][64]. | 9 | // [INFO] --> col major ZigZag permuted --> | 10 | // [INFO] --> Q smem store layout [4][Br][16]. | 11 | // ---------------------------------------------------------------- 12 | // ---------------------------------------------------------------- 13 | // -------------------------swizzle layout------------------------- 14 | // --------------------logical col 0~64, step 8-------------------- 15 | // ---------------------smem col 0~16, step 8---------------------- 16 | // ---------------------------------------------------------------- 17 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 18 | // |row 0 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 19 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 20 | // |row 1 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 21 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 22 | // |row 2 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 23 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 24 | // |row 3 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 25 | // ---------------------------------------------------------------- 26 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 27 | // |row 4 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 28 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 29 | // |row 5 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 30 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 31 | // |row 6 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 32 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 33 | // |row 7 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 34 | // ---------------------------------------------------------------- 35 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 36 | // |row 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 37 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 38 | // |row 9 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 39 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 40 | // |row 10| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 41 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 42 | // |row 11| 0 | 8 | 0 | 8 
| 0 | 8 | 0 | 8 | 43 | // ---------------------------------------------------------------- 44 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 45 | // |row 12| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 46 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 47 | // |row 13| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 48 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 49 | // |row 14| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 50 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 51 | // |row 15| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 52 | // ---------------------------------------------------------------- 53 | #pragma once 54 | #include 55 | #include 56 | #include 57 | #include 58 | #include 59 | 60 | // i: row index; j: col index. 61 | template 62 | static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) { 63 | // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep; 64 | static_assert(kColStride <= 16, "Currently, kColStride must be less than or equal to 16."); 65 | static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4."); 66 | static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep."); 67 | if constexpr (kStep == 8) { 68 | return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3; 69 | } else { 70 | static_assert(kStep == 4); 71 | return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2; 72 | } 73 | } 74 | 75 | // i: row index; j: col index 76 | // e.g kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 half as 128 bits memory issue. 77 | template 78 | static __device__ __forceinline__ int swizzle_permuted_Q_j(int i, int j) { 79 | return swizzle_permuted_j(i, j); 80 | } 81 | 82 | // i: row index; j: col index 83 | // e.g kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 half as 128 bits memory issue. 84 | template 85 | static __device__ __forceinline__ int swizzle_permuted_K_j(int i, int j) { 86 | return swizzle_permuted_j(i, j); 87 | } 88 | 89 | // i: row index; j: col index 90 | // e.g kColStride = kMmaAtomN * 2 = 16, kStep = 8 -> load 8 half as 128 bits memory issue. 91 | template 92 | static __device__ __forceinline__ int swizzle_permuted_V_j(int i, int j) { 93 | return swizzle_permuted_j(i, j); 94 | } 95 | -------------------------------------------------------------------------------- /include/cuffpa/logging.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #define STRINGFY(str) #str 17 | #define TORCH_BINDING_COMMON_EXTENSION(func) \ 18 | m.def(STRINGFY(func), &func, STRINGFY(func)); 19 | 20 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ 21 | if(((T).options().dtype() != (th_type))) { \ 22 | std::cout << "Tensor Info:" << (T).options() << std::endl; \ 23 | throw std::runtime_error("values must be "#th_type); \ 24 | } 25 | 26 | #define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \ 27 | if (((T2).size(0) != (T1).size(0)) || \ 28 | ((T2).size(1) != (T1).size(1)) || \ 29 | ((T2).size(2) != (T1).size(2)) || \ 30 | ((T2).size(3) != (T1).size(3))) { \ 31 | throw std::runtime_error("Tensor size mismatch!"); \ 32 | } 33 | 34 | #define HALF2(value) (reinterpret_cast(&(value))[0]) 35 | 36 | 37 | #ifdef ENABLE_FFPA_DEBUG 38 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) 
\ 39 | { \ 40 | if (tid == 0) { \ 41 | float2 v_reg = __half22float2(HALF2(R)); \ 42 | printf("[T0] " format ", V0=%f, V1=%f\n", \ 43 | ##__VA_ARGS__, v_reg.x, v_reg.y); \ 44 | } \ 45 | } 46 | 47 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) \ 48 | { \ 49 | if (tid < 32) { \ 50 | float2 v_reg = __half22float2(HALF2(R)); \ 51 | printf("[T%d] " format ", V0=%f, V1=%f\n", \ 52 | tid, ##__VA_ARGS__, v_reg.x, v_reg.y);\ 53 | } \ 54 | } 55 | 56 | #define FFPA_MMA_PRINT_REG(R, format, ...) \ 57 | { \ 58 | { \ 59 | float2 v_reg = __half22float2(HALF2(R)); \ 60 | printf(format", V0=%f, V1=%f\n", \ 61 | ##__VA_ARGS__, v_reg.x, v_reg.y); \ 62 | } \ 63 | } 64 | 65 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) \ 66 | { \ 67 | { \ 68 | float2 v_reg_0 = __half22float2(HALF2(R0)); \ 69 | float2 v_reg_1 = __half22float2(HALF2(R1)); \ 70 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \ 71 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \ 72 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \ 73 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \ 74 | } \ 75 | } \ 76 | } 77 | 78 | #define FFPA_MMA_CHECK_PRINT_T32_REG(R0, R1, format, ...) \ 79 | { \ 80 | if (tid < 32){ \ 81 | float2 v_reg_0 = __half22float2(HALF2(R0)); \ 82 | float2 v_reg_1 = __half22float2(HALF2(R1)); \ 83 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \ 84 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \ 85 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \ 86 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \ 87 | } \ 88 | } \ 89 | } 90 | 91 | #define FFPA_MMA_PRINT_T0(format, ...) \ 92 | { \ 93 | if (tid == 0) { \ 94 | printf("[T0] " format, ##__VA_ARGS__); \ 95 | } \ 96 | } 97 | 98 | #define FFPA_MMA_PRINT_T32(format, ...) \ 99 | { \ 100 | if (tid < 32) { \ 101 | printf("[T%d] " format, tid, ##__VA_ARGS__);\ 102 | } \ 103 | } 104 | 105 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) \ 106 | { \ 107 | if (lane_id == 0) { \ 108 | float2 v_reg = __half22float2(HALF2(R)); \ 109 | printf("[L0] " format", V0=%f, V1=%f\n", \ 110 | ##__VA_ARGS__, v_reg.x, v_reg.y); \ 111 | } \ 112 | } 113 | 114 | #define FFPA_MMA_PRINT_L0(format, ...) \ 115 | { \ 116 | if (lane_id == 0) { \ 117 | printf("[L0] " format, ##__VA_ARGS__); \ 118 | } \ 119 | } 120 | 121 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) \ 122 | { \ 123 | if (tid == 0 && blockIdx.z == 0) { \ 124 | printf("----------------------------------------\n"); \ 125 | printf(format, ##__VA_ARGS__); \ 126 | for (int i = 0; i < Br; ++i) { \ 127 | for (int j = 0; j < kMmaTileSeqLenK; ++j) { \ 128 | printf("[%d][%d]=%f", i, j, (B)[i][j]); \ 129 | } \ 130 | printf("\n"); \ 131 | } \ 132 | printf("----------------------------------------\n"); \ 133 | } \ 134 | __syncthreads(); \ 135 | } 136 | 137 | #else 138 | 139 | #define FFPA_MMA_PRINT_REG(R, format, ...) {} 140 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) {} 141 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) {} 142 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) {} 143 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) {} 144 | #define FFPA_MMA_PRINT_T0(format, ...) {} 145 | #define FFPA_MMA_PRINT_T32(format, ...) {} 146 | #define FFPA_MMA_PRINT_L0(format, ...) {} 147 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) 
{} 148 | 149 | #endif 150 | 151 | -------------------------------------------------------------------------------- /include/cuffpa/mma.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace ffpa { 15 | namespace mma { 16 | 17 | enum class MMAMode { 18 | kAutoZeroFill = 0U, 19 | kInplaceUpdate = 1U, 20 | }; 21 | 22 | // Simple wrappers for mma and ldmatrix instructions. 23 | template 24 | __device__ __forceinline__ void m16n8k16_f16f16f16( 25 | uint32_t * RD0, uint32_t * RD1, 26 | uint32_t * RA0, uint32_t * RA1, uint32_t * RA2, uint32_t * RA3, 27 | uint32_t * RB0, uint32_t * RB1 28 | ) { 29 | if constexpr (mma_mode == MMAMode::kInplaceUpdate) { 30 | asm volatile( 31 | "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " 32 | "{%0, %1}, " 33 | "{%2, %3, %4, %5}, " 34 | "{%6, %7}, " 35 | "{%8, %9};\n" 36 | : "=r"(RD0[0]), "=r"(RD1[0]) 37 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]), 38 | "r"(RB0[0]), "r"(RB1[0]), 39 | "r"(RD0[0]), "r"(RD1[0]) 40 | ); 41 | } else { 42 | // WARN: seems can not get good performance while stage = 1. 43 | asm volatile( 44 | "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " 45 | "{%0, %1}, " 46 | "{%2, %3, %4, %5}, " 47 | "{%6, %7}, " 48 | "{%8, %9};\n" 49 | : "=r"(RD0[0]), "=r"(RD1[0]) 50 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]), 51 | "r"(RB0[0]), "r"(RB1[0]), 52 | "r"(0), "r"(0) 53 | ); 54 | } 55 | } 56 | 57 | template 58 | __device__ __forceinline__ void m16n8k16_f16f16f32( 59 | uint32_t * RD0, uint32_t * RD1, uint32_t * RD2, uint32_t * RD3, 60 | uint32_t * RA0, uint32_t * RA1, uint32_t * RA2, uint32_t * RA3, 61 | uint32_t * RB0, uint32_t * RB1 62 | ) { 63 | // "h" = .u16 reg; "r" = .u32 reg; "l" = .u64 reg; 64 | // "f" = .f32 reg; "d" = .f64 reg 65 | if constexpr (mma_mode == MMAMode::kInplaceUpdate) { 66 | asm volatile( 67 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 68 | "{%0, %1, %2, %3}, " 69 | "{%4, %5, %6, %7}, " 70 | "{%8, %9}, " 71 | "{%10, %11, %12, %13};\n" 72 | : "=r"(RD0[0]), "=r"(RD1[0]), "=r"(RD2[0]), "=r"(RD3[0]) 73 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]), 74 | "r"(RB0[0]), "r"(RB1[0]), 75 | "r"(RD0[0]), "r"(RD1[0]), "r"(RD2[0]), "r"(RD3[0]) 76 | ); 77 | } else { 78 | // WARN: seems can not get good performance while stage = 1. 
79 | asm volatile( 80 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " 81 | "{%0, %1, %2, %3}, " 82 | "{%4, %5, %6, %7}, " 83 | "{%8, %9}, " 84 | "{%10, %11, %12, %13};\n" 85 | : "=r"(RD0[0]), "=r"(RD1[0]), "=r"(RD2[0]), "=r"(RD3[0]) 86 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]), 87 | "r"(RB0[0]), "r"(RB1[0]), 88 | "r"(0), "r"(0), "r"(0), "r"(0) 89 | ); 90 | } 91 | } 92 | 93 | __device__ __forceinline__ void ldmatrix_m8n8x4( 94 | uint32_t * R0, uint32_t * R1, uint32_t * R2, uint32_t * R3, 95 | uint32_t smem_ptr 96 | ) { 97 | asm volatile( 98 | "ldmatrix.sync.aligned.x4.m8n8.shared.b16 " 99 | "{%0, %1, %2, %3}, [%4];\n" 100 | : "=r"(R0[0]), "=r"(R1[0]), "=r"(R2[0]), "=r"(R3[0]) 101 | : "r"(smem_ptr) 102 | ); 103 | } 104 | 105 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans( 106 | uint32_t * R0, uint32_t * R1, uint32_t * R2, uint32_t * R3, 107 | uint32_t smem_ptr 108 | ) { 109 | asm volatile( 110 | "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 " 111 | "{%0, %1, %2, %3}, [%4];\n" 112 | : "=r"(R0[0]), "=r"(R1[0]), "=r"(R2[0]), "=r"(R3[0]) 113 | : "r"(smem_ptr) 114 | ); 115 | } 116 | 117 | __device__ __forceinline__ void ldmatrix_m8n8x2( 118 | uint32_t * R0, uint32_t * R1, 119 | uint32_t smem_ptr 120 | ) { 121 | asm volatile( 122 | "ldmatrix.sync.aligned.x2.m8n8.shared.b16 " 123 | "{%0, %1}, [%2];\n" 124 | : "=r"(R0[0]), "=r"(R1[0]) 125 | : "r"(smem_ptr) 126 | ); 127 | } 128 | 129 | __device__ __forceinline__ void ldmatrix_m8n8x2_trans( 130 | uint32_t * R0, uint32_t * R1, 131 | uint32_t smem_ptr 132 | ) { 133 | asm volatile( 134 | "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 " 135 | "{%0, %1}, [%2];\n" 136 | : "=r"(R0[0]), "=r"(R1[0]) 137 | : "r"(smem_ptr) 138 | ); 139 | } 140 | 141 | } // mma 142 | } // ffpa 143 | -------------------------------------------------------------------------------- /include/cuffpa/prefill.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cuffpa/mma.cuh" // ffpa::mma 3 | #include "cuffpa/warp.cuh" // ffpa::warp 4 | #include "cuffpa/swizzle.cuh" // ffpa::swizzle 5 | #include "cuffpa/cp_async.cuh" // ffpa::cp_async 6 | #include "cuffpa/utils.cuh" // ffpa::utils 7 | 8 | namespace ffpa { 9 | namespace prefill { 10 | // prefill utils: prefetch/load QKV g2s funcs, rescale/softmax funcs etc. 11 | 12 | template< 13 | const int kHeadDim, // Headdim, 32,64,128 14 | const int kMmaAtomM, // MMA Atom M, 16 15 | const int kMmaAtomN, // MMA Atom N, 8 16 | const int kMmaAtomK, // MMA Atom K, 16 17 | const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] 18 | const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] 19 | const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] 20 | const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] 21 | const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M 22 | const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N 23 | const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M 24 | const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... 25 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 26 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 27 | const int kOStorageAccFloat32, // 0/1, MMA Acc always be f32/f16, but O storage can be fp32 or half. 
28 | const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point. 29 | const int kPrefetchPV, // Prefetch V at the Appropriate Time Point. 30 | const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. 31 | const int kPersistQs2r, // Persist load Q s2r for headdim < 320, more registers, but still keep O(1) SRAM. 32 | const int kPersistQg2s, // Persist load Q g2s for headdim < 320, more SRAM, but still keep register usage. 33 | const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. 34 | const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4) 35 | const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4) 36 | const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding 37 | const int kPadK, 38 | const int kPadV 39 | > 40 | __device__ __forceinline__ void check_large_d_compiling_states() { 41 | // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. 42 | // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. 43 | static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 44 | static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T 45 | static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V 46 | static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T 47 | static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( 48 | kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V 49 | static_assert(kMmaAccFloat32QK == 0 || kMmaAccFloat32QK == 1); 50 | static_assert(kMmaAccFloat32PV == 0 || kMmaAccFloat32PV == 1); 51 | static_assert(kOStorageAccFloat32 == 0 || kOStorageAccFloat32 == 1); 52 | // Make sure that Br >= Bc, for shared memory reuse. 53 | static_assert( 54 | (kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ) >= 55 | (kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK)); 56 | static_assert(kPrefetchQK == 0 || kPrefetchQK == 1); 57 | static_assert(kPrefetchPV == 0 || kPrefetchPV == 1); 58 | static_assert(kShareSmemQKV == 0 || kShareSmemQKV == 1); 59 | // Persist load Q s2r for headdim < 512, more registers, but still keep O(1) SRAM. 60 | static_assert(kPersistQs2r == 0 || kPersistQs2r == 1); 61 | // Persist load Q g2s for headdim < 512, more SRAM, but still keep register usage. 62 | static_assert(kPersistQg2s == 0 || kPersistQg2s == 1); 63 | // kPersistQg2s and kPersistQs2r can not both enabled for large d kernel. 64 | static_assert((kPersistQg2s & kPersistQs2r) == 0); 65 | // kPersistQg2s and kShareSmemQKV can not both enabled for large d kernel.. 66 | static_assert((kPersistQg2s & kShareSmemQKV) == 0); 67 | // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. 68 | static_assert(kRegPipeKV == 0 || kRegPipeKV == 1); 69 | // May apply different multi stages policy for QK and V. 
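// e.g (illustrative) kStageQK=2 double buffers the K tile g2s copies for
// Q@K^T while kStagePV is tuned independently for V; both are restricted
// to 1..4 by the asserts below.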
70 | static_assert(kStageQK < 5 && kStageQK > 0); // QK (<=4) 71 | static_assert(kStagePV < 5 && kStagePV > 0); // V (<=4) 72 | static_assert(kPadQ >= 0 && kPadQ % 8 == 0); // 0,8,16 73 | static_assert(kPadK >= 0 && kPadK % 8 == 0); // 0,8,16 74 | static_assert(kPadV >= 0 && kPadV % 8 == 0); // 0,8,16 75 | } 76 | 77 | 78 | template< 79 | const int kHeadDim, // Headdim, 32~1024 80 | const int kMmaAtomM, // MMA Atom M, 16 81 | const int kMmaAtomN, // MMA Atom N, 8 82 | const int kMmaAtomK, // MMA Atom K, 16 83 | const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] 84 | const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] 85 | const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] 86 | const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] 87 | const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M 88 | const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N 89 | const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M 90 | const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... 91 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 92 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32. 93 | const int kOStorageAccFloat32, // 0/1, MMA Acc always be f32/f16, but O storage can be fp32 or half. 94 | const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point. 95 | const int kPrefetchPV, // Prefetch V at the Appropriate Time Point. 96 | const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V. 97 | const int kPersistQs2r, // Persist load Q s2r for headdim <= 128, more registers. 98 | const int kPersistVs2r, // Persist load V s2r for headdim <= 128, more registers. 99 | const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping. 100 | const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4) 101 | const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4) 102 | const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding 103 | const int kPadK, 104 | const int kPadV 105 | > 106 | __device__ __forceinline__ void check_small_d_compiling_states() { 107 | // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. 108 | // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. 109 | static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 110 | static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T 111 | static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V 112 | static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T 113 | static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( 114 | kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V 115 | static_assert(kMmaAccFloat32QK == 0 || kMmaAccFloat32QK == 1); 116 | static_assert(kMmaAccFloat32PV == 0 || kMmaAccFloat32PV == 1); 117 | static_assert(kOStorageAccFloat32 == 0 || kOStorageAccFloat32 == 1); 118 | // Make sure that Br >= Bc, for shared memory reuse. 
119 |   static_assert(
120 |     (kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ) >=
121 |     (kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK));
122 |   static_assert(kPrefetchQK == 0 || kPrefetchQK == 1);
123 |   static_assert(kPrefetchPV == 0 || kPrefetchPV == 1);
124 |   static_assert(kShareSmemQKV == 0 || kShareSmemQKV == 1);
125 |   // Persist load Q s2r for headdim <= 128, more registers.
126 |   static_assert(kPersistQs2r == 0 || kPersistQs2r == 1);
127 |   // Persist load V s2r for headdim <= 128, more registers.
128 |   static_assert(kPersistVs2r == 0 || kPersistVs2r == 1);
129 |   if constexpr (kShareSmemQKV) {
130 |     // kPersistQs2r must be enabled if kShareSmemQKV is set to 1.
131 |     static_assert(kPersistQs2r == 1);
132 |   }
133 |   // Registers Ping pong double buffers for ldmatrix s2r & mma
134 |   // computation overlapping.
135 |   static_assert(kRegPipeKV == 0 || kRegPipeKV == 1);
136 |   // kRegPipeKV and kPersistVs2r can not both be enabled.
137 |   static_assert((kRegPipeKV & kPersistVs2r) == 0);
138 |   // May apply different multi stages policy for QK and V.
139 |   static_assert(kStageQK < 5 && kStageQK > 0); // QK (<=4)
140 |   static_assert(kStagePV < 5 && kStagePV > 0); // V  (<=4)
141 |   static_assert(kPadQ >= 0 && kPadQ % 8 == 0); // 0,8,16
142 |   static_assert(kPadK >= 0 && kPadK % 8 == 0); // 0,8,16
143 |   static_assert(kPadV >= 0 && kPadV % 8 == 0); // 0,8,16
144 | }
145 | 
146 | 
147 | template<
148 |   const int BrOrBc,
149 |   const int kTileSize,
150 |   const int kHeadDim,
151 |   const int kMmaAtomK,
152 |   const int kNumThreads,
153 |   const int kPad
154 | >
155 | __device__ __forceinline__ void cp_async_qkv_g2s(
156 |   uint32_t smem_base_ptr, // QKV smem base ptr
157 |   const half * gmem_ptr,  // QKV gmem ptr
158 |   const int gmem_offset,  // QKV gmem_offset
159 |   const int n_tile_id,    // seqlen offset, Q_tile_id * Br, tile_K_seqlen * Bc
160 |   const int d_tile_id,    // headdim offset, tile_K_d * kMmaAtomK, tile_V_d * kMmaAtomN * 2
161 |   const int stage         // stage * QKV tile_size
162 | ) {
163 |   // QK: tile_K_d < (kHeadDim / kMmaAtomK)
164 |   // V:  tile_V_d < (kHeadDim / kMmaAtomN * 2)
165 |   if (d_tile_id >= (kHeadDim / kMmaAtomK)) { return; }
166 |   const int tid = threadIdx.x; // within block
167 |   constexpr bool kSwizzle = (kPad == 0) ? true : false;
168 |   // Mapping QKV tid -> smem, tile size [64/128, 16]
169 |   // Br 64, tid / 2, row 0~64
170 |   const int load_smem_BrOrBc = (tid / (kNumThreads / BrOrBc));
171 |   // (tid % 2) * 8, 0,8,...
172 |   const int load_smem_d = (
173 |     tid % (kNumThreads / BrOrBc)) * (kMmaAtomK / (kNumThreads / BrOrBc));
174 |   // Mapping QKV tid -> gmem, tile size [64/128, 16], row offset by
175 |   // n_tile_id(seqlen), col offset by d_tile_id(Headdim).
176 |   const int load_gmem_BrOrBc = (n_tile_id * BrOrBc) + load_smem_BrOrBc;
177 |   const int load_gmem_d = (d_tile_id * kMmaAtomK) + load_smem_d; // 0,8
178 |   // Offset by QKV global gmem_offset.
179 |   const int load_gmem_addr = (
180 |     gmem_offset + load_gmem_BrOrBc * kHeadDim + load_gmem_d);
181 | 
182 |   // cp async & apply swizzle or padding.
183 |   #pragma unroll
184 |   for (int i = 0; i < (kMmaAtomK / (kNumThreads / BrOrBc)); i += 8) {
185 |     const uint32_t load_smem_ptr = (
186 |       smem_base_ptr + (stage * kTileSize +
187 |       load_smem_BrOrBc * (kMmaAtomK + kPad) +
188 |       (kSwizzle ?
swizzle::permuted( 189 | load_smem_BrOrBc, load_smem_d + i) : 190 | load_smem_d + i ) 191 | ) * sizeof(half)); 192 | cp_async::cp_async<16>(load_smem_ptr, &(gmem_ptr[load_gmem_addr + i])); 193 | } 194 | // cp_async::commit_group(); 195 | } 196 | 197 | 198 | template< 199 | const int kTrans, 200 | const int kNumRegs, 201 | const int kTileSize, 202 | const int kMmaAtomM, 203 | const int kMmaAtomN, 204 | const int kMmaAtomK, 205 | const int kPad 206 | > 207 | __device__ __forceinline__ void sync_fetch_qkv_frags_s2r( 208 | uint32_t smem_base_ptr, // QKV smem base ptr 209 | uint32_t * R, // Register ptr, R_QKV 210 | const int mma_tile_id, // Q warp_QP 0~num MMAs, KV warp_KV 0 211 | const int warp_tile_id, // Q 0, KV 0~kWarpTileSeqLenK 212 | const int n_tile_id, // seqlen QK 0, V tile_V_Bc 213 | const int stage 214 | ) { 215 | const int lane_id = threadIdx.x % WARP_SIZE; // 0~31 216 | constexpr bool kSwizzle = (kPad == 0) ? true : false; 217 | if constexpr (kTrans) { 218 | // load V m8n8x2 via ldmatrix.x2.trans 219 | static_assert(kNumRegs == 2); 220 | // mma_tile_id = warp_KV == 0, warp_tile_id = (j % 2), n_tile_id = tile_V_Bc 221 | // warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + (j % 2) * kMmaAtomN; 222 | const int warp_smem_d = warp_tile_id * kMmaAtomN; 223 | const int lane_smem_Bc = n_tile_id * kMmaAtomK + lane_id % 16; 224 | const int lane_smem_d = warp_smem_d; // 0,8 225 | const uint32_t lane_smem_ptr = ( 226 | smem_base_ptr + (stage * kTileSize + 227 | lane_smem_Bc * (kMmaAtomN * 2 + kPad) + 228 | (kSwizzle ? swizzle::permuted( 229 | lane_smem_Bc, lane_smem_d): 230 | lane_smem_d) 231 | ) * sizeof(half) 232 | ); 233 | mma::ldmatrix_m8n8x2_trans(&R[0], &R[1], lane_smem_ptr); 234 | } else { 235 | static_assert(kNumRegs == 2 || kNumRegs == 4); 236 | if constexpr (kNumRegs == 4) { 237 | // load Q m8n8x4 via ldmatrix.x4 238 | // mma_tile_id = warp_QP, kWarpTileSeqLenQ=1 239 | // warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + 0 * kMmaAtomM 240 | const int warp_smem_Br = mma_tile_id * (kMmaAtomM); 241 | const int lane_smem_Br = warp_smem_Br + lane_id % 16; // 0~15 242 | const int lane_smem_d = (lane_id / 16) * 8; // 0,8 243 | const uint32_t lane_smem_ptr = ( 244 | smem_base_ptr + (stage * kTileSize + 245 | lane_smem_Br * (kMmaAtomK + kPad) + 246 | (kSwizzle ? swizzle::permuted( 247 | lane_smem_Br, lane_smem_d): 248 | lane_smem_d) 249 | ) * sizeof(half) 250 | ); 251 | mma::ldmatrix_m8n8x4(&R[0], &R[1], &R[2], &R[3], lane_smem_ptr); 252 | } else { 253 | // load K m8n8x2 via ldmatrix.x2 254 | // mma_tile_id = warp_KV == 0, warp_tile_id = j 255 | // warp_smem_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; 256 | const int warp_smem_Bc = warp_tile_id * kMmaAtomN; 257 | const int lane_smem_Bc = warp_smem_Bc + lane_id % 8; // 0~7 258 | const int lane_smem_d = ((lane_id / 8) % 2) * 8; // 0,8 259 | const uint32_t lane_smem_ptr = ( 260 | smem_base_ptr + (stage * kTileSize + 261 | lane_smem_Bc * (kMmaAtomK + kPad) + 262 | (kSwizzle ? 
                   swizzle::permuted(
263 |                      lane_smem_Bc, lane_smem_d):
264 |                      lane_smem_d )
265 |         ) * sizeof(half)
266 |       );
267 |       mma::ldmatrix_m8n8x2(&R[0], &R[1], lane_smem_ptr);
268 |     }
269 |   }
270 | }
271 | 
272 | 
273 | template<const int kWarpTileSeqLenK, const int kMmaAccFloat32>
274 | __device__ __forceinline__ void sync_online_safe_softmax(
275 |   uint32_t * R_S,                 // &R_S[0][0][0]
276 |   const float scale,              // 1 / sqrt(d)
277 |   float * lane_row_max_new,       // &lane_row_max_new[0][0]
278 |   float * lane_row_sum_new,       // &lane_row_sum_new[0][0]
279 |   float * lane_block_row_max_old, // &lane_block_row_max_old[0][0]
280 |   float * lane_block_row_sum_old  // &lane_block_row_sum_old[0][0]
281 | ) {
282 |   if constexpr (kMmaAccFloat32) {
283 |     // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
284 |     { // kWarpTileSeqLenQ = 1
285 |       // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
286 |       #pragma unroll
287 |       for (int j = 0; j < kWarpTileSeqLenK; ++j) {
288 |         const float* t_fptr_S_0_1 = reinterpret_cast<const float*>(R_S + j * 4); // &R_S[0][j][0]
289 |         // This should be the row max after S = (Q @ K^T) / sqrt(d)
290 |         const float tmp_max_0 = max(t_fptr_S_0_1[0], t_fptr_S_0_1[1]) * scale;
291 |         const float tmp_max_1 = max(t_fptr_S_0_1[2], t_fptr_S_0_1[3]) * scale;
292 |         lane_row_max_new[0] = max(lane_row_max_new[0], tmp_max_0);
293 |         lane_row_max_new[1] = max(lane_row_max_new[1], tmp_max_1);
294 |       } // end for kWarpTileSeqLenK
295 | 
296 |       // Warp level reduce max, warp_size = 4
297 |       // Each thread contains the maximum of 2 rows of Br,
298 |       // and only the values of T0, T4, ..., T28 are used.
299 |       lane_row_max_new[0] = warp::reduce_max<float, 4>(lane_row_max_new[0]);
300 |       lane_row_max_new[1] = warp::reduce_max<float, 4>(lane_row_max_new[1]);
301 |     } // end for kWarpTileSeqLenQ
302 | 
303 |     // static_assert(kWarpTileSeqLenQ == 1);
304 |     // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
305 |     { // kWarpTileSeqLenQ = 1
306 |       // Use latest global row max without update.
307 |       // Br 0, row_id, 0~7,  16~23, 32~39, 48~55;
308 |       // Br 1, row_id, 8~15, 24~31, 40~47, 56~63;
309 |       // Apply m_new = max(m_old, m_new) here.
310 |       const float block_row_max_new_0 = max(lane_block_row_max_old[0], lane_row_max_new[0]);
311 |       const float block_row_max_new_1 = max(lane_block_row_max_old[1], lane_row_max_new[1]);
312 | 
313 |       #pragma unroll
314 |       for (int j = 0; j < kWarpTileSeqLenK; ++j) {
315 |         // R_S[][][4] 4 32bit registers with each contains 1 F32 element.
316 |         // (x,y) 0~7->{c0, c1}, (z,w)->8~15 {c2, c3}
317 |         float* t_fptr_S_0_1 = reinterpret_cast<float*>(R_S + j * 4);
318 |         half*  t_hptr_S_0_1 = reinterpret_cast< half*>(R_S + j * 4);
319 |         // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z in registers;
320 |         t_fptr_S_0_1[0] = __expf(__fmaf_rn(t_fptr_S_0_1[0], scale, - block_row_max_new_0));
321 |         t_fptr_S_0_1[1] = __expf(__fmaf_rn(t_fptr_S_0_1[1], scale, - block_row_max_new_0));
322 |         t_fptr_S_0_1[2] = __expf(__fmaf_rn(t_fptr_S_0_1[2], scale, - block_row_max_new_1));
323 |         t_fptr_S_0_1[3] = __expf(__fmaf_rn(t_fptr_S_0_1[3], scale, - block_row_max_new_1));
324 |         lane_row_sum_new[0] += (t_fptr_S_0_1[0] + t_fptr_S_0_1[1]);
325 |         lane_row_sum_new[1] += (t_fptr_S_0_1[2] + t_fptr_S_0_1[3]);
326 |         // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
327 |         // Also convert F32 -> half for P@V MMA, reuse R_S as P.
328 |         t_hptr_S_0_1[0] = __float2half_rn(t_fptr_S_0_1[0]);
329 |         t_hptr_S_0_1[1] = __float2half_rn(t_fptr_S_0_1[1]);
330 |         t_hptr_S_0_1[2] = __float2half_rn(t_fptr_S_0_1[2]);
331 |         t_hptr_S_0_1[3] = __float2half_rn(t_fptr_S_0_1[3]);
332 |       } // end for kWarpTileSeqLenK
333 | 
334 |       // Warp level reduce sum, warp_size = 4
335 |       lane_row_sum_new[0] = warp::reduce_sum<float, 4>(lane_row_sum_new[0]);
336 |       lane_row_sum_new[1] = warp::reduce_sum<float, 4>(lane_row_sum_new[1]);
337 |     }
338 | 
339 |   } else {
340 |     // MMA Acc F16
341 |     // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
342 |     { // kWarpTileSeqLenQ = 1
343 |       // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
344 |       #pragma unroll
345 |       for (int j = 0; j < kWarpTileSeqLenK; ++j) {
346 |         const half* t_hptr_S_0_1 = reinterpret_cast<const half*>(R_S + j * 2);
347 |         // This should be the row max after S = (Q @ K^T) / sqrt(d)
348 |         const float tmp_max_0 = __half2float(__hmax(t_hptr_S_0_1[0], t_hptr_S_0_1[1])) * scale;
349 |         const float tmp_max_1 = __half2float(__hmax(t_hptr_S_0_1[2], t_hptr_S_0_1[3])) * scale;
350 |         lane_row_max_new[0] = max(lane_row_max_new[0], tmp_max_0);
351 |         lane_row_max_new[1] = max(lane_row_max_new[1], tmp_max_1);
352 |       } // end for kWarpTileSeqLenK
353 | 
354 |       // Warp level reduce max, warp_size = 4
355 |       // Each thread contains the maximum of 2 rows of Br,
356 |       // and only the values of T0, T4, ..., T28 are used.
357 |       lane_row_max_new[0] = warp::reduce_max<float, 4>(lane_row_max_new[0]);
358 |       lane_row_max_new[1] = warp::reduce_max<float, 4>(lane_row_max_new[1]);
359 |     } // end for kWarpTileSeqLenQ
360 | 
361 |     // static_assert(kWarpTileSeqLenQ == 1);
362 |     // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
363 |     { // kWarpTileSeqLenQ = 1
364 |       // Apply m_new = max(m_old, m_new) here.
365 |       const float block_row_max_new_0 = max(lane_block_row_max_old[0], lane_row_max_new[0]);
366 |       const float block_row_max_new_1 = max(lane_block_row_max_old[1], lane_row_max_new[1]);
367 | 
368 |       #pragma unroll
369 |       for (int j = 0; j < kWarpTileSeqLenK; ++j) {
370 |         half* t_hptr_S_0_1 = reinterpret_cast<half*>(R_S + j * 2);
371 |         // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z;
372 |         float4 t_reg_S_0_1;
373 |         t_reg_S_0_1.x = __expf(__fmaf_rn(
374 |           __half2float(t_hptr_S_0_1[0]), scale, - block_row_max_new_0));
375 |         t_reg_S_0_1.y = __expf(__fmaf_rn(
376 |           __half2float(t_hptr_S_0_1[1]), scale, - block_row_max_new_0));
377 |         t_reg_S_0_1.z = __expf(__fmaf_rn(
378 |           __half2float(t_hptr_S_0_1[2]), scale, - block_row_max_new_1));
379 |         t_reg_S_0_1.w = __expf(__fmaf_rn(
380 |           __half2float(t_hptr_S_0_1[3]), scale, - block_row_max_new_1));
381 |         lane_row_sum_new[0] += (t_reg_S_0_1.x + t_reg_S_0_1.y);
382 |         lane_row_sum_new[1] += (t_reg_S_0_1.z + t_reg_S_0_1.w);
383 |         // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
384 |         t_hptr_S_0_1[0] = __float2half_rn(t_reg_S_0_1.x);
385 |         t_hptr_S_0_1[1] = __float2half_rn(t_reg_S_0_1.y);
386 |         t_hptr_S_0_1[2] = __float2half_rn(t_reg_S_0_1.z);
387 |         t_hptr_S_0_1[3] = __float2half_rn(t_reg_S_0_1.w);
388 |       } // end for kWarpTileSeqLenK
389 | 
390 |       // Warp level reduce sum, warp_size = 4
391 |       lane_row_sum_new[0] = warp::reduce_sum<float, 4>(lane_row_sum_new[0]);
392 |       lane_row_sum_new[1] = warp::reduce_sum<float, 4>(lane_row_sum_new[1]);
393 |     }
394 |   }
395 | }
396 | 
397 | 
398 | __device__ __forceinline__ void sync_precompute_rescale_factors(
399 |   float * rescale_o_factor_0,           // rescale factor
400 |   float * rescale_o_factor_1,           // rescale factor
401 |   const float * lane_row_max_new,       // &lane_row_max_new[0][0]
402 |   const float * lane_block_row_max_old, // &lane_block_row_max_old[0][0]
403 |   const int n_tile_id                   // tile_K_seqlen
404 | ) {
405 |   float block_row_max_new_0 = lane_row_max_new[0];
406 |   float block_row_max_new_1 = lane_row_max_new[1];
407 |   float block_row_max_old_0 = lane_block_row_max_old[0];
408 |   float block_row_max_old_1 = lane_block_row_max_old[1];
409 |   // NOTE: max(-inf, val) = val.
410 |   block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
411 |   block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
412 |   // Avoid inf value while using m_old for rescaling O.
413 |   block_row_max_old_0 = (n_tile_id > 0 ? block_row_max_old_0 :
414 |                          block_row_max_new_0);
415 |   block_row_max_old_1 = (n_tile_id > 0 ? block_row_max_old_1 :
416 |                          block_row_max_new_1);
417 |   // Precompute rescale_o_factor_0 & rescale_o_factor_1, avoid redundant exp.
418 |   rescale_o_factor_0[0] = __expf(block_row_max_old_0 - block_row_max_new_0);
419 |   rescale_o_factor_1[0] = __expf(block_row_max_old_1 - block_row_max_new_1);
420 | }
421 | 
422 | template<const int kMmaAccFloat32, const int kOStorageAccFloat32>
423 | __device__ __forceinline__ void sync_rescaling_tiling_o(
424 |   uint32_t * R_D,                   // &R_D[0][0][0]
425 |   uint32_t * R_O,                   // &R_O[0]
426 |   const float * rescale_o_factor_0, // rescale factor
427 |   const float * rescale_o_factor_1, // rescale factor
428 |   const int n_tile_id,              // tile_K_seqlen
429 |   const int d_tile_id               // j
430 | ) {
431 |   // Now, we get [Br,8] slice of [Br,d], each warp(MMA) contains m16n8.
432 |   // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old.
433 |   // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V
434 |   // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper)
435 |   if constexpr (kMmaAccFloat32) {
436 |     const float* t_fptr_O_0_1 = reinterpret_cast<const float*>(R_O);
437 |     if constexpr (kOStorageAccFloat32) {
438 |       // (x,y) 0~7->{c0, c1}, (z,w)->8~15 {c2, c3} kWarpTileSeqLenP=1
439 |       float* t_fptr_D_0_1 = reinterpret_cast<float*>(R_D + d_tile_id * 4); // &(R_D[0][j][0])
440 |       t_fptr_D_0_1[0] = __fmaf_rn(rescale_o_factor_0[0], t_fptr_D_0_1[0], t_fptr_O_0_1[0]);
441 |       t_fptr_D_0_1[1] = __fmaf_rn(rescale_o_factor_0[0], t_fptr_D_0_1[1], t_fptr_O_0_1[1]);
442 |       t_fptr_D_0_1[2] = __fmaf_rn(rescale_o_factor_1[0], t_fptr_D_0_1[2], t_fptr_O_0_1[2]);
443 |       t_fptr_D_0_1[3] = __fmaf_rn(rescale_o_factor_1[0], t_fptr_D_0_1[3], t_fptr_O_0_1[3]);
444 |     } else {
445 |       half* t_hptr_D_0_1 = reinterpret_cast<half*>(R_D + d_tile_id * 2);
446 |       t_hptr_D_0_1[0] = __float2half_rn(__fmaf_rn(
447 |         rescale_o_factor_0[0], __half2float(t_hptr_D_0_1[0]), t_fptr_O_0_1[0]));
448 |       t_hptr_D_0_1[1] = __float2half_rn(__fmaf_rn(
449 |         rescale_o_factor_0[0], __half2float(t_hptr_D_0_1[1]), t_fptr_O_0_1[1]));
450 |       t_hptr_D_0_1[2] = __float2half_rn(__fmaf_rn(
451 |         rescale_o_factor_1[0], __half2float(t_hptr_D_0_1[2]), t_fptr_O_0_1[2]));
452 |       t_hptr_D_0_1[3] = __float2half_rn(__fmaf_rn(
453 |         rescale_o_factor_1[0], __half2float(t_hptr_D_0_1[3]), t_fptr_O_0_1[3]));
454 |     }
455 |   } else {
456 |     // MMA Acc F16
457 |     const half* t_hptr_O_0_1 = reinterpret_cast<const half*>(R_O);
458 |     if constexpr (kOStorageAccFloat32) {
459 |       // (x,y) 0~7->{c0, c1}, (z,w)->8~15 {c2, c3} kWarpTileSeqLenP=1
460 |       float* t_fptr_D_0_1 = reinterpret_cast<float*>(R_D + d_tile_id * 4);
461 |       t_fptr_D_0_1[0] = __fmaf_rn(
462 |         rescale_o_factor_0[0], t_fptr_D_0_1[0], __half2float(t_hptr_O_0_1[0]));
463 |       t_fptr_D_0_1[1] = __fmaf_rn(
464 |         rescale_o_factor_0[0], t_fptr_D_0_1[1], __half2float(t_hptr_O_0_1[1]));
465 |       t_fptr_D_0_1[2] = __fmaf_rn(
466 |         rescale_o_factor_1[0], t_fptr_D_0_1[2], __half2float(t_hptr_O_0_1[2]));
467 |       t_fptr_D_0_1[3] = __fmaf_rn(
468 |         rescale_o_factor_1[0], t_fptr_D_0_1[3], __half2float(t_hptr_O_0_1[3]));
469 |     } else {
470 |       half* t_hptr_D_0_1 = reinterpret_cast<half*>(R_D + d_tile_id * 2);
471 |       t_hptr_D_0_1[0] = __float2half_rn(__fmaf_rn(rescale_o_factor_0[0],
472 |         __half2float(t_hptr_D_0_1[0]), __half2float(t_hptr_O_0_1[0])));
473 |       t_hptr_D_0_1[1] = __float2half_rn(__fmaf_rn(rescale_o_factor_0[0],
474 |         __half2float(t_hptr_D_0_1[1]), __half2float(t_hptr_O_0_1[1])));
475 |       t_hptr_D_0_1[2] = __float2half_rn(__fmaf_rn(rescale_o_factor_1[0],
476 |         __half2float(t_hptr_D_0_1[2]), __half2float(t_hptr_O_0_1[2])));
477 |       t_hptr_D_0_1[3] = __float2half_rn(__fmaf_rn(rescale_o_factor_1[0],
478 |         __half2float(t_hptr_D_0_1[3]), __half2float(t_hptr_O_0_1[3])));
479 |     }
480 |   }
481 | }
482 | 
483 | __device__ __forceinline__ void sync_update_max_expsum(
484 |   float * lane_row_max_new,         // &lane_row_max_new[0][0]
485 |   float * lane_row_sum_new,         // &lane_row_sum_new[0][0]
486 |   float * lane_block_row_max_old,   // &lane_block_row_max_old[0][0]
487 |   float * lane_block_row_sum_old,   // &lane_block_row_sum_old[0][0]
488 |   const float * rescale_o_factor_0, // rescale factor 0 exp(m_old - m_new)
489 |   const float * rescale_o_factor_1  // rescale factor 1 exp(m_old - m_new)
490 | ) {
491 |   // Now, we can update m, l after O has been scaled.
492 |   // Update l = exp(m_old - m_new) * l_old + row_sum(P).
493 |   lane_block_row_sum_old[0] = (__fmaf_rn(
494 |     rescale_o_factor_0[0], lane_block_row_sum_old[0], lane_row_sum_new[0]));
495 |   lane_block_row_sum_old[1] = (__fmaf_rn(
496 |     rescale_o_factor_1[0], lane_block_row_sum_old[1], lane_row_sum_new[1]));
497 |   // 2. Then, update block row max for each lane.
498 |   lane_block_row_max_old[0] = max(lane_block_row_max_old[0], lane_row_max_new[0]);
499 |   lane_block_row_max_old[1] = max(lane_block_row_max_old[1], lane_row_max_new[1]);
500 | }
501 | 
502 | 
503 | template<const int kWarpTileHeadDimV, const int kOStorageAccFloat32>
504 | __device__ __forceinline__ void sync_rescaling_final_o(
505 |   uint32_t * R_D,                      // Final O after loop over N
506 |   const float * lane_block_row_sum_old // &lane_block_row_sum_old[0][0]
507 | ) {
508 |   // Finally, we still have to rescale O once more.
509 |   // O_output(D) = ( 1/l_final ) * O_final (FA2 paper)
510 |   // static_assert(kWarpTileSeqLenP == 1);
511 |   { // kWarpTileSeqLenP = 1
512 |     const float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[0]);
513 |     const float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[1]);
514 |     #pragma unroll
515 |     for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
516 |       // Scaling in registers & convert F32 -> half for O collective store.
517 |       if constexpr (kOStorageAccFloat32) {
518 |         const float* t_fptr_D_0_1 = reinterpret_cast<const float*>(R_D + j * 4);
519 |         half* t_hptr_D_0_1 = reinterpret_cast< half*>(R_D + j * 4);
520 |         t_hptr_D_0_1[0] = __float2half_rn(rescale_factor_0 * t_fptr_D_0_1[0]);
521 |         t_hptr_D_0_1[1] = __float2half_rn(rescale_factor_0 * t_fptr_D_0_1[1]);
522 |         t_hptr_D_0_1[2] = __float2half_rn(rescale_factor_1 * t_fptr_D_0_1[2]);
523 |         t_hptr_D_0_1[3] = __float2half_rn(rescale_factor_1 * t_fptr_D_0_1[3]);
524 |       } else {
525 |         half* t_hptr_D_0_1 = reinterpret_cast<half*>(R_D + j * 2);
526 |         t_hptr_D_0_1[0] = __float2half_rn(rescale_factor_0 * __half2float(t_hptr_D_0_1[0]));
527 |         t_hptr_D_0_1[1] = __float2half_rn(rescale_factor_0 * __half2float(t_hptr_D_0_1[1]));
528 |         t_hptr_D_0_1[2] = __float2half_rn(rescale_factor_1 * __half2float(t_hptr_D_0_1[2]));
529 |         t_hptr_D_0_1[3] = __float2half_rn(rescale_factor_1 * __half2float(t_hptr_D_0_1[3]));
530 |       }
531 |     } // end for kWarpTileHeadDimV
532 |   } // end for kWarpTileSeqLenP = 1
533 | }
534 | 
535 | 
536 | template<
537 |   const int Br,
538 |   const int kHeadDim,
539 |   const int kMmaAtomM,
540 |   const int kMmaAtomN,
541 |   const int kWarpTileHeadDimV,
542 |   const int kOStorageAccFloat32
543 | >
544 | __device__ __forceinline__ void sync_store_o_r2g(
545 |   half * gmem_ptr,       // O gmem ptr
546 |   const int gmem_offset, // O gmem global offset
547 |   const int n_tile_id,   // curr tile id (seqlen) O_tile_id
548 |   const int mma_tile_id, // Q warp_QP 0~num MMAs, KV warp_KV 0
549 |   uint32_t * R_D,        // Final scaled O
550 |   uint32_t * R_Q,        // R_Q[1][4] for registers reuse
551 |   uint32_t * R_K         // R_K[8][2] for registers reuse
552 | ) {
553 |   // Store O(D): Write O[Br,d] from regs -> gmem, collective store
554 |   // with reg reuse & warp shuffle.
555 |   const int lane_id = threadIdx.x % WARP_SIZE; // 0~31
556 |   // static_assert(kWarpTileSeqLenP == 1);
557 |   { // kWarpTileSeqLenP = 1
558 |     #pragma unroll
559 |     for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
560 |       // reuse R_Q[1][4], R_K[8][2] for collective store.
561 |       uint32_t* t_uptr_Z_0 = reinterpret_cast<uint32_t*>(R_Q);
562 |       uint32_t* t_uptr_Z_1 = reinterpret_cast<uint32_t*>(R_K);
563 |       const int offset = (kOStorageAccFloat32) ?
j * 4 : j * 2; 564 | t_uptr_Z_0[0] = R_D[offset + 0]; 565 | t_uptr_Z_1[0] = R_D[offset + 1]; 566 | t_uptr_Z_0[1] = __shfl_sync((0xffffffff), R_D[offset + 0], lane_id + 1, 4); 567 | t_uptr_Z_0[2] = __shfl_sync((0xffffffff), R_D[offset + 0], lane_id + 2, 4); 568 | t_uptr_Z_0[3] = __shfl_sync((0xffffffff), R_D[offset + 0], lane_id + 3, 4); 569 | t_uptr_Z_1[1] = __shfl_sync((0xffffffff), R_D[offset + 1], lane_id + 1, 4); 570 | t_uptr_Z_1[2] = __shfl_sync((0xffffffff), R_D[offset + 1], lane_id + 2, 4); 571 | t_uptr_Z_1[3] = __shfl_sync((0xffffffff), R_D[offset + 1], lane_id + 3, 4); 572 | 573 | // st.global.v4 128 bits. [Br,d] 574 | if (lane_id % 4 == 0) { 575 | // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 kWarpTileSeqLenP = 1 576 | // int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + 0 * kMmaAtomM; 577 | const int store_warp_regs_O_Br = mma_tile_id * (kMmaAtomM); 578 | const int store_lane_gmem_O_Br = n_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 579 | // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) warp_KV = 0 580 | // int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; 581 | const int store_warp_regs_O_d = j * kMmaAtomN; 582 | const int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) 583 | const int store_gmem_O_addr_0 = ( 584 | gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); 585 | const int store_gmem_O_addr_1 = ( 586 | gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); 587 | cp_async::stg_sync_128b(&gmem_ptr[store_gmem_O_addr_0], t_uptr_Z_0); 588 | cp_async::stg_sync_128b(&gmem_ptr[store_gmem_O_addr_1], t_uptr_Z_1); 589 | } 590 | } // end for kWarpTileHeadDimV 591 | } // kWarpTileSeqLenP = 1 592 | } 593 | 594 | } // prefill 595 | } // ffpa 596 | -------------------------------------------------------------------------------- /include/cuffpa/swizzle.cuh: -------------------------------------------------------------------------------- 1 | // Manually SMEM swizzling for bank conflict free. 2 | // ---------------------------------------------------------------- 3 | // [INFO] Assert smem store layout col_stride <= 16, prefer 16. | 4 | // [INFO] For logical_col_stride > 16, we have to permute the | 5 | // [INFO] smem store layout using col major ZigZag method: | 6 | // [INFO] e.g, --> Q smem logical layout [Br][64]. | 7 | // [INFO] --> col major ZigZag permuted --> | 8 | // [INFO] --> Q smem store layout [4][Br][16]. 
| 9 | // ---------------------------------------------------------------- 10 | // ---------------------------------------------------------------- 11 | // -------------------------swizzle layout------------------------- 12 | // --------------------logical col 0~64, step 8-------------------- 13 | // ---------------------smem col 0~16, step 8---------------------- 14 | // ---------------------------------------------------------------- 15 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 16 | // |row 0 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 17 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 18 | // |row 1 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 19 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 20 | // |row 2 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 21 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 22 | // |row 3 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 23 | // ---------------------------------------------------------------- 24 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 25 | // |row 4 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 26 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 27 | // |row 5 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 28 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 29 | // |row 6 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 30 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 31 | // |row 7 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 32 | // ---------------------------------------------------------------- 33 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 34 | // |row 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 35 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 36 | // |row 9 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 37 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 38 | // |row 10| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 39 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 40 | // |row 11| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 41 | // ---------------------------------------------------------------- 42 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 43 | // |row 12| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 44 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 45 | // |row 13| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 46 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 47 | // |row 14| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 48 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 49 | // |row 15| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 50 | // ---------------------------------------------------------------- 51 | #pragma once 52 | #include 53 | 54 | namespace ffpa { 55 | namespace swizzle { 56 | 57 | // i: row index; j: col index. 
58 | template<const int kColStride = 16, const int kStep = 8>
59 | static __device__ __forceinline__ int permuted(int i, int j) {
60 |   // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep;
61 |   static_assert(kColStride <= 16, "Currently, kColStride must be less than or equal to 16.");
62 |   static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4.");
63 |   static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep.");
64 |   if constexpr (kStep == 8) {
65 |     return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3;
66 |   } else {
67 |     static_assert(kStep == 4);
68 |     return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2;
69 |   }
70 | }
71 | 
72 | } // swizzle
73 | } // ffpa
74 | 
75 | 
-------------------------------------------------------------------------------- /include/cuffpa/utils.cuh: --------------------------------------------------------------------------------
1 | #pragma once
2 | #include "cuffpa/logging.cuh" // log
3 | 
4 | namespace ffpa {
5 | namespace utils {
6 | 
7 | __device__ __host__ inline
8 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
9 | 
10 | template<typename T, const int M, const int N, const int K>
11 | __device__ inline void fill_3D_regs(T (&R)[M][N][K], T val) {
12 |   #pragma unroll
13 |   for (int i = 0; i < M; ++i) {
14 |     #pragma unroll
15 |     for (int j = 0; j < N; ++j) {
16 |       #pragma unroll
17 |       for (int k = 0; k < K; ++k) {
18 |         R[i][j][k] = val;
19 |       }
20 |     }
21 |   }
22 | }
23 | 
24 | template<typename T, const int M, const int N>
25 | __device__ inline void fill_2D_regs(T (&R)[M][N], T val) {
26 |   #pragma unroll
27 |   for (int i = 0; i < M; ++i) {
28 |     #pragma unroll
29 |     for (int j = 0; j < N; ++j) {
30 |       R[i][j] = val;
31 |     }
32 |   }
33 | }
34 | 
35 | template<typename T, const int M>
36 | __device__ inline void fill_1D_regs(T (&S)[M], T val) {
37 |   #pragma unroll
38 |   for (int i = 0; i < M; ++i) {
39 |     S[i] = val;
40 |   }
41 | }
42 | 
43 | } // utils
44 | } // ffpa
45 | 
46 | 
-------------------------------------------------------------------------------- /include/cuffpa/warp.cuh: --------------------------------------------------------------------------------
1 | #pragma once
2 | #include 
3 | #include 
4 | #include 
5 | #include 
6 | #include 
7 | #include 
8 | #include 
9 | #include 
10 | #include 
11 | #include 
12 | #include 
13 | 
14 | #define WARP_SIZE 32
15 | 
16 | namespace ffpa {
17 | namespace warp {
18 | 
19 | template<typename T, const int kWarpSize = WARP_SIZE>
20 | __device__ inline T reduce_sum(T val) {
21 |   #pragma unroll
22 |   for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
23 |     val += __shfl_xor_sync(0xffffffff, val, mask, kWarpSize);
24 |   }
25 |   return val;
26 | }
27 | 
28 | template<typename T, const int kWarpSize = WARP_SIZE>
29 | __device__ inline T reduce_max(T val) {
30 |   #pragma unroll
31 |   for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
32 |     val = max(val, __shfl_xor_sync(0xffffffff, val, mask, kWarpSize));
33 |   }
34 |   return val;
35 | }
36 | 
37 | } // warp
38 | } // ffpa
39 | 
-------------------------------------------------------------------------------- /include/extension/.gitignore: --------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
22 | dist
23 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | packaging
2 | ninja
3 | torch>=2.4.0
4 | numpy
5 | 
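The prefill helpers in include/cuffpa/prefill.cuh above implement the FA2-style online softmax per [Br, Bc] tile: m = max(m_old, m_new), P = exp(S - m), O = exp(m_old - m) * O_old + P @ V, l = exp(m_old - m) * l_old + rowsum(P), and a final O / l. As a sanity reference for that bookkeeping, here is a minimal NumPy sketch; it is not a file in this repository, and the names (online_attention_ref, Bc) are illustrative only:

# online_softmax_ref.py -- illustrative sketch, not part of this repository.
import numpy as np

def online_attention_ref(Q, K, V, Bc=64):
    # Tile over the KV sequence the way prefill.cuh loops over tile_K_seqlen.
    N, d = K.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((Q.shape[0],), -np.inf)  # running row max (lane_block_row_max_old)
    l = np.zeros((Q.shape[0],))          # running exp-sum (lane_block_row_sum_old)
    O = np.zeros_like(Q)                 # running output accumulator (R_D)
    for s0 in range(0, N, Bc):
        S = (Q @ K[s0:s0 + Bc].T) * scale             # S = Q @ K^T / sqrt(d)
        m_new = np.maximum(m, S.max(axis=1))          # m = max(m_old, m_new)
        P = np.exp(S - m_new[:, None])                # P = exp(S - m)
        rescale = np.exp(m - m_new)                   # exp(m_old - m)
        O = rescale[:, None] * O + P @ V[s0:s0 + Bc]  # sync_rescaling_tiling_o
        l = rescale * l + P.sum(axis=1)               # sync_update_max_expsum
        m = m_new
    return O / l[:, None]                             # sync_rescaling_final_o

# Sanity check against a direct (two-pass) softmax attention:
rng = np.random.default_rng(0)
Q = rng.standard_normal((128, 64))
K = rng.standard_normal((256, 64))
V = rng.standard_normal((256, 64))
S = (Q @ K.T) / np.sqrt(64.0)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(online_attention_ref(Q, K, V), ref, atol=1e-6)

In this sketch the first tile is handled implicitly: m starts at -inf, so the rescale factor exp(m_old - m) underflows to exactly 0. The CUDA code instead guards with n_tile_id > 0 in sync_precompute_rescale_factors, avoiding inf inputs to the fast-math __expf.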
-------------------------------------------------------------------------------- /setup.cfg: --------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=1
3 | 
4 | [metadata]
5 | license_files = LICENSE
6 | description_file = README.md
7 | 
8 | [pep8]
9 | max-line-length = 120
10 | 
11 | [flake8]
12 | max-line-length = 120
13 | ignore = E731, E203, E402, W503, W504, F821, E501, B, C4, EXE
14 | per-file-ignores =
15 |     __init__.py: F401, F403, F405
16 | exclude = venv
17 | 
18 | [pydocstyle]
19 | select = D417 # Missing argument descriptions in the docstring
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | from env import ENV
4 | from packaging.version import Version
5 | from setuptools import find_packages, setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
7 | import warnings
8 | warnings.filterwarnings("ignore")
9 | 
10 | 
11 | def get_long_description():
12 |     description = (Path(ENV.project_dir()) / "README.md").read_text(encoding="utf-8")
13 |     return description
14 | 
15 | 
16 | # package name managed by pip, which can be removed by `pip uninstall ffpa-attn -y`
17 | PACKAGE_NAME = "ffpa-attn"
18 | 
19 | ext_modules = []
20 | generator_flag = []
21 | cc_flag = []
22 | 
23 | ENV.list_ffpa_env()
24 | 
25 | if ENV.enable_ampere():
26 |     cc_flag.append("-gencode")
27 |     cc_flag.append("arch=compute_80,code=sm_80")
28 | 
29 | if ENV.enable_ada():
30 |     cc_flag.append("-gencode")
31 |     cc_flag.append("arch=compute_89,code=sm_89")
32 | 
33 | if ENV.enable_hopper():
34 |     if CUDA_HOME is not None:
35 |         _, bare_metal_version = ENV.get_cuda_bare_metal_version(CUDA_HOME)
36 |         if bare_metal_version >= Version("11.8"):
37 |             cc_flag.append("-gencode")
38 |             cc_flag.append("arch=compute_90,code=sm_90")
39 | 
40 | assert cc_flag, "cc_flag can not be empty."
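# NOTE: illustrative aside, not part of the original setup.py. With Ampere
# and Ada both enabled above, cc_flag ends up as:
#   ["-gencode", "arch=compute_80,code=sm_80",
#    "-gencode", "arch=compute_89,code=sm_89"]
# nvcc then compiles the kernels once per listed architecture and embeds
# SASS for both sm_80 and sm_89 in the single extension module built below.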
41 | 42 | # cuda module 43 | # may need export LD_LIBRARY_PATH=PATH-TO/torch/lib:$LD_LIBRARY_PATH 44 | ext_modules.append( 45 | CUDAExtension( 46 | # package name for import 47 | name="pyffpa_cuda", 48 | sources=ENV.get_build_sources(build_pkg=True), 49 | extra_compile_args={ 50 | # add c compile flags 51 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 52 | # add nvcc compile flags 53 | "nvcc": ENV.get_build_cuda_cflags(build_pkg=True) 54 | + generator_flag 55 | + cc_flag, 56 | }, 57 | include_dirs=[ 58 | Path(ENV.project_dir()) / "include", 59 | ], 60 | ) 61 | ) 62 | 63 | 64 | def fetch_requirements(): 65 | with open("requirements.txt") as f: 66 | reqs = f.read().strip().split("\n") 67 | return reqs 68 | 69 | 70 | setup( 71 | name=PACKAGE_NAME, 72 | version="0.0.2", 73 | author="DefTruth", 74 | author_email="qyjdef@163.com", 75 | license="GNU General Public License v3.0", 76 | packages=find_packages( 77 | exclude=( 78 | "build", 79 | "dist", 80 | "include", 81 | "csrc", 82 | "tests", 83 | "bench", 84 | "tmp", 85 | "cuffpa_py.egg-info", 86 | "ffpa_attn.egg-info", 87 | "__pycache__", 88 | "third_party", 89 | ) 90 | ), 91 | description="FFPA: Yet another Faster Flash Prefill Attention for large headdim, 1.8x~3x faster than SDPA EA.", 92 | long_description=get_long_description(), 93 | long_description_content_type="text/markdown", 94 | url="https://github.com/DefTruth/ffpa-attn-mma.git", 95 | ext_modules=ext_modules, 96 | cmdclass={"build_ext": BuildExtension}, 97 | python_requires=">=3.10", 98 | install_requires=fetch_requirements(), 99 | extras_require={ 100 | "all": [], 101 | "dev": [ 102 | "pre-commit", 103 | "packaging", 104 | "ninja", 105 | ], 106 | }, 107 | ) 108 | -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.a 3 | *.dylib 4 | *.dll 5 | *.lib 6 | .DS_Store 7 | build 8 | *.whl 9 | tmp 10 | __pycache__ 11 | *.onnx 12 | *.engine 13 | *.pt 14 | *.pth 15 | *.nsys* 16 | *.ncu* 17 | *.sqlite* 18 | *.engine 19 | *.bin 20 | outupt 21 | *.egg-info 22 | dist 23 | .tmp 24 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | -------------------------------------------------------------------------------- /tests/swizzle_layout.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def pretty_print_line( 5 | m: str = "", sep: str = "-", width: int = 130, return_str: bool = False 6 | ): 7 | res_len = width - len(m) 8 | left_len = int(res_len / 2) 9 | right_len = res_len - left_len 10 | pretty_line = sep * left_len + m + sep * right_len 11 | if not return_str: 12 | print(pretty_line) 13 | else: 14 | return pretty_line 15 | 16 | 17 | PERMUTED_DOCS_STRING = """---------------------------------------------------------------- 18 | [INFO] Assert smem store layout col_stride <= 16, prefer 16. | 19 | [INFO] For logical_col_stride > 16, we have to permute the | 20 | [INFO] smem store layout using col major ZigZag method: | 21 | [INFO] e.g, --> Q smem logical layout [Br][64]. | 22 | [INFO] --> col major ZigZag permuted --> | 23 | [INFO] --> Q smem store layout [4][Br][16]. 
| 24 | ----------------------------------------------------------------""" 25 | 26 | def swizzle_permuted_j( 27 | i: int, j: int, col_stride: int = 16, num_elems_per_128b: int = 8 28 | ): 29 | # i: row index; j: col index. col_stride <= 16. 30 | # assert col_stride <= 16, f"col_stride must <= 16, but got {col_stride}" 31 | # for col_stride > 16, we have to permute it using col major ZigZag order. 32 | # e.g, Q smem logical layout [Br,d]=[Br,64] -> store layout [4][Br][16]. 33 | return ( 34 | (int(j / num_elems_per_128b) ^ int(i / 4)) 35 | % (int(col_stride / num_elems_per_128b)) 36 | ) * num_elems_per_128b 37 | 38 | 39 | def print_smem_swizzle_layout( 40 | rows: int = 16, 41 | logical_col_stride: int = 16, 42 | num_elems_per_128b: int = 8, 43 | smem_pading: int = 0, 44 | show_logical_col_id: bool = False, 45 | use_logical_col_stride: bool = False, 46 | ): 47 | # ---------------------------------------------------------------- 48 | # [INFO] Assert smem store layout col_stride <= 16, prefer 16. | 49 | # [INFO] For logical_col_stride > 16, we have to permute the | 50 | # [INFO] smem store layout using col major ZigZag method: | 51 | # [INFO] e.g, --> Q smem logical layout [Br][64]. | 52 | # [INFO] --> col major ZigZag permuted --> | 53 | # [INFO] --> Q smem store layout [4][Br][16]. | 54 | # ---------------------------------------------------------------- 55 | # ---------------------------------------------------------------- 56 | # -------------------------swizzle layout------------------------- 57 | # --------------------logical col 0~64, step 8-------------------- 58 | # ---------------------smem col 0~16, step 8---------------------- 59 | # ---------------------------------------------------------------- 60 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 61 | # |row 0 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 62 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 63 | # |row 1 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 64 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 65 | # |row 2 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 66 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 67 | # |row 3 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 68 | # ---------------------------------------------------------------- 69 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 70 | # |row 4 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 71 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 72 | # |row 5 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 73 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 74 | # |row 6 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 75 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 76 | # |row 7 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 77 | # ---------------------------------------------------------------- 78 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 79 | # |row 8 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 80 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 81 | # |row 9 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 82 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 83 | # |row 10| 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 84 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 85 | # |row 11| 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 | 86 | # 
---------------------------------------------------------------- 87 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 | 88 | # |row 12| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 89 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15| 90 | # |row 13| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 91 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23| 92 | # |row 14| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 93 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31| 94 | # |row 15| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 | 95 | # ---------------------------------------------------------------- 96 | str_len = 0 97 | total_banks = 0 98 | assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8" 99 | # 4 bytes per bank 100 | banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4 101 | if use_logical_col_stride: 102 | banks_per_col = int((logical_col_stride * 2) / 4) 103 | if logical_col_stride > 16: 104 | print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}") 105 | if smem_pading == 8: 106 | banks_per_col += 4 107 | print( 108 | f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}" 109 | ) 110 | 111 | banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4) 112 | for i in range(rows): 113 | layout_str_len = 0 114 | banks_str_len = 0 115 | 116 | # bank_layout_str 117 | banks_start = total_banks % 32 # 32 banks in total 118 | banks_end = banks_start + banks_per_col 119 | bank_layout_str = "|bank |" 120 | max_bank_str_len = 0 121 | if logical_col_stride >= 16 and (not use_logical_col_stride): 122 | for k in range(int(logical_col_stride / 16)): 123 | for j in range(banks_start, banks_end, banks_per_num_elems_per_128b): 124 | curr_bank_str = ( 125 | f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|" 126 | ) 127 | max_bank_str_len = max(max_bank_str_len, len(curr_bank_str)) 128 | bank_layout_str += curr_bank_str 129 | else: 130 | for j in range(banks_start, banks_end, banks_per_num_elems_per_128b): 131 | curr_bank_str = f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|" 132 | max_bank_str_len = max(max_bank_str_len, len(curr_bank_str)) 133 | bank_layout_str += curr_bank_str 134 | 135 | # smem_layout_str 136 | logical_col_ids = [] 137 | smem_layout_col_ids = [] 138 | if logical_col_stride >= 16 and (not use_logical_col_stride): 139 | for k in range(int(logical_col_stride / 16)): 140 | for j in range(0, 16, num_elems_per_128b): 141 | layout_j = swizzle_permuted_j(i, j, 16, num_elems_per_128b) 142 | logical_col_ids.append(k * 16 + j) 143 | smem_layout_col_ids.append(layout_j) 144 | else: 145 | for j in range(0, logical_col_stride, num_elems_per_128b): 146 | layout_j = swizzle_permuted_j( 147 | i, j, logical_col_stride, num_elems_per_128b 148 | ) 149 | logical_col_ids.append(j) 150 | smem_layout_col_ids.append(layout_j) 151 | 152 | smem_layout_str = f"|row {i:<2}|" 153 | 154 | r = 0 155 | for c, l in zip(logical_col_ids, smem_layout_col_ids): 156 | smem_layout_str += ( 157 | pretty_print_line( 158 | (f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"), 159 | sep=" ", 160 | width=(max_bank_str_len - 1), 161 | return_str=True, 162 | ) 163 | + "|" 164 | ) 165 | r += 1 166 | if logical_col_stride >= 16: 167 | if smem_pading == 8 and (r > 1 and r % 2 == 0): 168 | smem_layout_str += ( 169 | pretty_print_line( 170 | ("pad"), 171 | sep=" ", 172 | width=max_bank_str_len - 1, 173 | return_str=True, 174 | ) 175 | + "|" 176 | ) 177 | else: 178 | if 
smem_pading == 8: 179 | smem_layout_str += ( 180 | pretty_print_line( 181 | ("pad"), 182 | sep=" ", 183 | width=max_bank_str_len - 1, 184 | return_str=True, 185 | ) 186 | + "|" 187 | ) 188 | 189 | layout_str_len = len(smem_layout_str) 190 | str_len = max(layout_str_len, banks_str_len) 191 | 192 | # print banks and smem layout 193 | if i == 0: 194 | print("-" * str_len) 195 | pretty_print_line("swizzle layout", width=str_len) 196 | pretty_print_line( 197 | f"logical col 0~{logical_col_stride}, " f"step {num_elems_per_128b}", 198 | width=str_len, 199 | ) 200 | pretty_print_line( 201 | f"smem col 0~16, step {num_elems_per_128b}" 202 | if logical_col_stride >= 16 203 | else f"smem col 0~8, step {num_elems_per_128b}", 204 | width=str_len, 205 | ) 206 | print("-" * str_len) 207 | print(bank_layout_str) 208 | print(smem_layout_str) 209 | if (i + 1) % 4 == 0 and i != (rows - 1): 210 | print("-" * str_len) 211 | total_banks += banks_per_col 212 | print("-" * str_len) 213 | 214 | 215 | def get_args(): 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument("--rows", type=int, default=16) 218 | parser.add_argument("--smem-padding", "--pad", type=int, default=0) 219 | parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8) 220 | parser.add_argument( 221 | "--logical-col-stride", "--logical-col", "--col", type=int, default=64 222 | ) 223 | parser.add_argument( 224 | "--use-logical-col-stride", "--use-logical-col", action="store_true" 225 | ) 226 | parser.add_argument( 227 | "--show-logical-col-id", "--show-logical-col", action="store_true" 228 | ) 229 | return parser.parse_args() 230 | 231 | 232 | if __name__ == "__main__": 233 | args = get_args() 234 | print(args) 235 | print(PERMUTED_DOCS_STRING) 236 | print_smem_swizzle_layout( 237 | rows=args.rows, 238 | logical_col_stride=args.logical_col_stride, 239 | num_elems_per_128b=args.num_elems_per_128b, 240 | smem_pading=args.smem_padding, 241 | show_logical_col_id=args.show_logical_col_id, 242 | use_logical_col_stride=args.use_logical_col_stride, 243 | ) 244 | -------------------------------------------------------------------------------- /tests/test_ffpa_attn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import random 5 | import sys 6 | import time 7 | from functools import partial 8 | from typing import Optional 9 | 10 | import numpy as np 11 | import torch 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | from torch.nn.attention import sdpa_kernel, SDPBackend 15 | try: 16 | from flash_attn import flash_attn_func 17 | has_flash_attn = True 18 | except ImportError: 19 | flash_attn_func = None 20 | has_flash_attn = False 21 | 22 | sys.path.append("../") 23 | from env import ENV, pretty_print_line 24 | 25 | torch.set_grad_enabled(False) 26 | torch.set_printoptions( 27 | precision=6, threshold=8, edgeitems=3, linewidth=120, sci_mode=False 28 | ) 29 | 30 | 31 | def get_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--no-rand-q", "--no-rq", action="store_true") 34 | parser.add_argument("--no-rand-k", "--no-rk", action="store_true") 35 | parser.add_argument("--no-rand-v", "--no-rv", action="store_true") 36 | parser.add_argument("--no-rand-qkv", "--no-rqkv", action="store_true") 37 | parser.add_argument("--run-torch-unfused", "--torch", action="store_true") 38 | parser.add_argument("--run-flash-attn", "--flash", action="store_true") 39 | parser.add_argument("--check", 
action="store_true") 40 | parser.add_argument("--check-all", action="store_true") 41 | parser.add_argument("--show-all", "--show", action="store_true") 42 | parser.add_argument("--show-less", "--show-l", action="store_true") 43 | parser.add_argument("--show-matrix", action="store_true") 44 | parser.add_argument("--only-flops-matmul", "--flops-mm", action="store_true") 45 | parser.add_argument("--B", type=int, default=None) 46 | parser.add_argument("--H", type=int, default=None) 47 | parser.add_argument("--N", type=int, default=None) 48 | parser.add_argument("--D", type=int, default=None) 49 | parser.add_argument("--MAX-D", "--MD", type=int, default=1024) 50 | parser.add_argument("--seed", type=int, default=None) 51 | parser.add_argument("--sleep", type=float, default=0.05) 52 | parser.add_argument("--debug", action="store_true") 53 | parser.add_argument("--verbose", "--v", action="store_true") 54 | parser.add_argument("--warmup", "--w", type=int, default=1) 55 | parser.add_argument("--iters", "--i", type=int, default=5) 56 | parser.add_argument("--tag-hints", "--tags", "--hints", type=str, default=None) 57 | parser.add_argument( 58 | "--plot-flops", "--plot", action="store_true", help="Plot TFLOPS" 59 | ) 60 | parser.add_argument( 61 | "--save-dir", "--dir", type=str, default="tmp", help="Save dir for plot" 62 | ) 63 | parser.add_argument( 64 | "--save-tag", "--tag", type=str, default=None, help="Save name for plot" 65 | ) 66 | parser.add_argument("--gen-bench-table", "--gen-bench", action="store_true") 67 | parser.add_argument( 68 | "--force-build", "--build", action="store_true", help="Force build from sources" 69 | ) 70 | 71 | return parser.parse_args() 72 | 73 | 74 | args = get_args() 75 | ENV.list_ffpa_env() 76 | 77 | 78 | def set_rand_seed(seed: int = 1): 79 | random.seed(seed) 80 | np.random.seed(seed) 81 | torch.manual_seed(seed) 82 | torch.cuda.manual_seed_all(seed) 83 | 84 | 85 | pretty_print_line() 86 | print(args) 87 | pretty_print_line() 88 | 89 | # Load the CUDA kernel as a python module 90 | ffpa_attn, use_ffpa_attn_package = ENV.try_load_ffpa_library( 91 | force_build=args.force_build, verbose=args.verbose 92 | ) 93 | if use_ffpa_attn_package: 94 | import ffpa_attn # tricks for IDE code search 95 | 96 | 97 | def get_mha_tflops( 98 | B: int, H: int, N: int, D: int, secs: float = 1.0, only_matmul: bool = False 99 | ): 100 | # Q @ K^T FLOPs 101 | flops_qk = B * H * N * N * (2 * D - 1) 102 | 103 | # Scaling FLOPs 104 | flops_scaling = B * H * N * N 105 | 106 | # Safe_Softmax FLOPs 107 | flops_row_max = B * H * N * (N - 1) # row max 108 | flops_subtract_max = B * H * N * N # sub max 109 | flops_exp = B * H * N * N # pointwise exp 110 | flops_row_sum = B * H * N * (N - 1) # row sum 111 | flops_normalization = B * H * N * N # normalization 112 | 113 | flops_safe_softmax = ( 114 | flops_row_max 115 | + flops_subtract_max 116 | + flops_exp 117 | + flops_row_sum 118 | + flops_normalization 119 | ) 120 | 121 | # P @ V FLOPs 122 | flops_pv = B * H * N * D * (2 * N - 1) 123 | 124 | # Total FLOPs 125 | total_flops = flops_qk + flops_scaling + flops_safe_softmax + flops_pv 126 | if only_matmul: 127 | total_flops = flops_qk + flops_pv 128 | 129 | # Convert to TFLOPS 130 | # 1 TFLOPS = 10^12 FLOPS 131 | # ref: https://imgtec.eetrend.com/blog/2021/100062210.html. 
132 | tflops = total_flops * 1e-12 / (secs) 133 | 134 | return tflops 135 | 136 | 137 | MAX_TFLOPS = -1 138 | STATIS_INFO: dict[str, list[float | int] | set] = {} 139 | STATIS_INFO["headdim"] = set() 140 | TOATL_TFLOPS: dict[str, float] = {} 141 | SDPA_TFLOPS = -1 142 | 143 | def run_benchmark( 144 | perf_func: callable, 145 | q: torch.Tensor, 146 | k: torch.Tensor, 147 | v: torch.Tensor, 148 | tag: str, 149 | out: Optional[torch.Tensor] = None, 150 | s: Optional[torch.Tensor] = None, # DEBUG 151 | stages: int = -1, 152 | warmup: int = args.warmup, 153 | iters: int = args.iters, 154 | show_matrix: bool = args.show_matrix, 155 | only_show_improved: bool = not args.show_all, 156 | ): 157 | 158 | global MAX_TFLOPS 159 | global MAX_HEADDIM_CFG 160 | global SDPA_TFLOPS 161 | 162 | tag_hints: str = args.tag_hints # e.g "share-qkv,tiling-kv,swizzle" 163 | if tag_hints: 164 | tag_hints: list = tag_hints.strip().split(",") 165 | tag_hints.append("sdpa") 166 | tag_hints.append("unfused") 167 | hit_hints = False 168 | for hint in tag_hints: 169 | if hint in tag: 170 | hit_hints = True 171 | if not hit_hints: 172 | return None, None 173 | 174 | B, H, N, D = q.size() 175 | if "flash" in tag: 176 | B, N, H, D = q.size() 177 | 178 | if "unfused" in tag and (not args.run_torch_unfused): 179 | return None, None 180 | if "flash" in tag and ((not args.run_flash_attn) 181 | or (not has_flash_attn) or (D > 256)): 182 | return None, None 183 | 184 | STATIS_INFO["headdim"].add(D) 185 | 186 | max_supported_D = MAX_HEADDIM_CFG.get(tag, None) 187 | # skip if headdim not supported. 188 | if max_supported_D is not None: 189 | if D > max_supported_D: 190 | return None, None 191 | 192 | if out is not None: 193 | out.fill_(0) 194 | if s is not None: 195 | s.fill_(0) 196 | if out is not None: 197 | for i in range(warmup): 198 | if stages >= 1: 199 | if s is not None: 200 | perf_func(q, k, v, out, s, stages) 201 | else: 202 | perf_func(q, k, v, out, stages) 203 | else: 204 | perf_func(q, k, v, out) 205 | else: 206 | for i in range(warmup): 207 | _ = perf_func(q, k, v) 208 | 209 | torch.cuda.synchronize() 210 | start = time.time() 211 | # iters 212 | if out is not None: 213 | for i in range(iters): 214 | if stages >= 1: 215 | if s is not None: 216 | perf_func(q, k, v, out, s, stages) 217 | else: 218 | perf_func(q, k, v, out, stages) 219 | else: 220 | perf_func(q, k, v, out) 221 | else: 222 | for i in range(iters): 223 | out = perf_func(q, k, v) 224 | torch.cuda.synchronize() 225 | end = time.time() 226 | total_secs = end - start 227 | total_time = (end - start) * 1000 # ms 228 | mean_time = total_time / iters 229 | mean_secs = total_secs / iters 230 | 231 | TFLOPS = get_mha_tflops(B, H, N, D, mean_secs, only_matmul=args.only_flops_matmul) 232 | if tag in STATIS_INFO: 233 | STATIS_INFO[tag].append(int(round(TFLOPS))) 234 | else: 235 | STATIS_INFO[tag] = [] 236 | STATIS_INFO[tag].append(int(round(TFLOPS))) 237 | 238 | if "sdpa" in tag: 239 | SDPA_TFLOPS = TFLOPS 240 | out_info = f"{tag}" 241 | out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist() 242 | out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist() 243 | out_val_first = [round(v, 8) for v in out_val_first] 244 | out_val_last = [round(v, 8) for v in out_val_last] 245 | if not args.show_less: 246 | out_val = out_val_first[:2] 247 | out_val.append(out_val_last[-1]) 248 | else: 249 | out_val = out_val_first[:1] 250 | out_val = [f"{v:<12}" for v in out_val] 251 | if args.show_less: 252 | out_val = [v.strip() for v in out_val] 253 | 254 | if 
SDPA_TFLOPS > 0:
255 |         speedup_sdpa = TFLOPS / SDPA_TFLOPS
256 |     else:
257 |         speedup_sdpa = 1.0
258 | 
259 |     # calculate TFLOPS improved.
260 |     if TFLOPS > MAX_TFLOPS:
261 |         if MAX_TFLOPS > 0:
262 |             improve = ((TFLOPS - MAX_TFLOPS) / MAX_TFLOPS) * 100
263 |             improve = round(improve, 2)
264 |         else:
265 |             improve = 0
266 | 
267 |         MAX_TFLOPS = TFLOPS
268 |         print(
269 |             f"{out_info:>25}: {out_val}, time:{str(mean_time)[:8]}ms, "
270 |             f"TFLOPS:{TFLOPS:<6.2f}(+{improve:<5.2f}%)(~{speedup_sdpa:<4.2f}x)"
271 |         )
272 |     else:
273 |         improve = 0
274 |         if (not only_show_improved) or (("flash" in tag) or ("sdpa" in tag)):
275 |             print(
276 |                 f"{out_info:>25}: {out_val}, time:{str(mean_time)[:8]}ms, "
277 |                 f"TFLOPS:{TFLOPS:<6.2f}(+{improve:<5.2f}%)(~{speedup_sdpa:<4.2f}x)"
278 |             )
279 | 
280 |     if show_matrix:
281 |         print(out)
282 |     time.sleep(args.sleep)
283 |     torch.cuda.synchronize()
284 |     return out.clone() if args.check else out, mean_time
285 | 
286 | 
287 | def get_best_tflops():
288 |     global STATIS_INFO
289 |     sdpa_tflops = STATIS_INFO["(sdpa)"]
290 |     if ENV.enable_all_mutistages():
291 |         ffpa_l1_f32_best_tflops = [
292 |             max(x, y, z, w)
293 |             for x, y, z, w in zip(
294 |                 STATIS_INFO["(ffpa+acc+f32+L1+stage1)"],
295 |                 STATIS_INFO["(ffpa+acc+f32+L1+stage2)"],
296 |                 STATIS_INFO["(ffpa+acc+f32+L1+stage3)"],
297 |                 STATIS_INFO["(ffpa+acc+f32+L1+stage4)"],
298 |             )
299 |         ]
300 |         ffpa_l1_f16_best_tflops = [
301 |             max(x, y, z, w)
302 |             for x, y, z, w in zip(
303 |                 STATIS_INFO["(ffpa+acc+f16+L1+stage1)"],
304 |                 STATIS_INFO["(ffpa+acc+f16+L1+stage2)"],
305 |                 STATIS_INFO["(ffpa+acc+f16+L1+stage3)"],
306 |                 STATIS_INFO["(ffpa+acc+f16+L1+stage4)"],
307 |             )
308 |         ]
309 |     else:
310 |         ffpa_l1_f32_best_tflops = [
311 |             max(x, y)
312 |             for x, y in zip(
313 |                 STATIS_INFO["(ffpa+acc+f32+L1+stage1)"],
314 |                 STATIS_INFO["(ffpa+acc+f32+L1+stage2)"],
315 |             )
316 |         ]
317 |         ffpa_l1_f16_best_tflops = [
318 |             max(x, y)
319 |             for x, y in zip(
320 |                 STATIS_INFO["(ffpa+acc+f16+L1+stage1)"],
321 |                 STATIS_INFO["(ffpa+acc+f16+L1+stage2)"],
322 |             )
323 |         ]
324 | 
325 |     # calculate improved
326 |     ffpa_l1_f32_speedup = [
327 |         round(f / s, 2) for f, s in zip(ffpa_l1_f32_best_tflops, sdpa_tflops)
328 |     ]
329 |     ffpa_l1_f16_speedup = [
330 |         round(f / s, 2) for f, s in zip(ffpa_l1_f16_best_tflops, sdpa_tflops)
331 |     ]
332 |     STATIS_INFO["ffpa+acc-f32+L1(best)"] = ffpa_l1_f32_best_tflops
333 |     STATIS_INFO["ffpa+acc-f16+L1(best)"] = ffpa_l1_f16_best_tflops
334 |     STATIS_INFO["ffpa+acc-f32+L1(speedup)"] = ffpa_l1_f32_speedup
335 |     STATIS_INFO["ffpa+acc-f16+L1(speedup)"] = ffpa_l1_f16_speedup
336 | 
337 |     return (
338 |         sdpa_tflops,
339 |         ffpa_l1_f32_best_tflops,
340 |         ffpa_l1_f16_best_tflops,
341 |         ffpa_l1_f32_speedup,
342 |         ffpa_l1_f16_speedup,
343 |     )
344 | 
345 | 
346 | def gen_bench_markdown_table():
347 |     global STATIS_INFO
348 |     STATIS_INFO["headdim"] = sorted(list(STATIS_INFO["headdim"]))
349 |     pretty_print_line("FFPA-L1 Benchmark Data")
350 |     print(STATIS_INFO)
351 |     pretty_print_line()
352 |     headdims = [str(d) for d in STATIS_INFO["headdim"]]
353 |     num_headdim = len(headdims)
354 |     table_header = "|Algorithm|" + "|".join(headdims) + "|\n"
355 |     table_header += "|:---:|" + ":---:|" * num_headdim
356 |     (
357 |         sdpa_tflops,
358 |         ffpa_l1_f32_best_tflops,
359 |         ffpa_l1_f16_best_tflops,
360 |         ffpa_l1_f32_speedup,
361 |         ffpa_l1_f16_speedup,
362 |     ) = get_best_tflops()
363 | 
364 |     # sdpa, ffpa, speedup strings.
    # sdpa, ffpa, speedup strings.
    sdpa_tflops_str = "|SDPA EA|" + "|".join([str(s) + "T" for s in sdpa_tflops]) + "|"
    ffpa_l1_f32_tflops_str = (
        "|FFPA L1*|" + "|".join([str(f) + "T" for f in ffpa_l1_f32_best_tflops]) + "|"
    )
    ffpa_l1_f32_speedup_str = (
        "|Speedup|" + "|".join([str(fs) + "x" for fs in ffpa_l1_f32_speedup]) + "|"
    )
    ffpa_l1_f16_tflops_str = (
        "|FFPA L1^|" + "|".join([str(f) + "T" for f in ffpa_l1_f16_best_tflops]) + "|"
    )
    ffpa_l1_f16_speedup_str = (
        "|Speedup|" + "|".join([str(fs) + "x" for fs in ffpa_l1_f16_speedup]) + "|"
    )
    pretty_print_line("FFPA-L1 Best Benchmark Markdown Table:")
    print("\n")
    print(table_header)
    print(sdpa_tflops_str)
    print(ffpa_l1_f32_tflops_str)
    print(ffpa_l1_f32_speedup_str)
    print(ffpa_l1_f16_tflops_str)
    print(ffpa_l1_f16_speedup_str)
    print("\n")


def sort_tflops_by_headdim():
    global STATIS_INFO
    NEW_STATIS_INFO = {}
    # "headdim" may still be a set here if gen_bench_markdown_table() was not
    # called first; normalize to sorted lists before indexing.
    headdims_asc = sorted(list(STATIS_INFO["headdim"]))
    headdims = sorted(headdims_asc, reverse=True)
    NEW_STATIS_INFO["headdim"] = headdims
    for tag, tflops in STATIS_INFO.items():
        if tag == "headdim":
            continue
        # re-order each per-tag TFLOPS list from ascending to descending headdim.
        new_tflops = []
        for d in headdims:
            idx = headdims_asc.index(d)
            new_tflops.append(tflops[idx])
        NEW_STATIS_INFO[tag] = new_tflops
    return NEW_STATIS_INFO
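

# A quick illustration of the re-ordering above (hypothetical values): if the
# sweep ran D = 320, 384 and some tag recorded [90, 95] TFLOPS in that order,
# the descending-headdim view becomes headdim = [384, 320], TFLOPS = [95, 90].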


def plot_speedup_bar(
    speedup: list[float],
    extra_tag: str = "ffpa+acc+f32+L1",
    headdim: Optional[list[int]] = None,
):
    import matplotlib.pyplot as plt

    _ = plt.subplots(figsize=(16, 9))[1]  # fig, axs
    plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.07)
    x = range(len(speedup))
    # Plot the bar chart, using a different random color for each bar.
    random.seed(0)
    for i, value in enumerate(speedup):
        random_color = (random.random(), random.random(), random.random())
        plt.bar(i, value, color=random_color)
    plt.xlabel("Headdim(D)", fontsize=15, fontweight="bold")
    plt.xticks(x, headdim, fontsize=15, fontweight="bold")
    plt.ylabel(f"{extra_tag.upper()} SpeedUp", fontsize=15, fontweight="bold")
    plt.title(
        f"SpeedUp of {extra_tag.upper()} vs SDPA EA, {ENV.get_device_name()}",
        fontsize=15,
        fontweight="bold",
    )
    # Set the y-axis range according to the data.
    plt.ylim(0, max(speedup) + 0.5)
    plt.yticks(fontweight="bold")
    for i, v in enumerate(speedup):
        plt.text(i, v + 0.1, str(v), ha="center", fontsize=20, fontweight="bold")
    plt.grid(True)
    # Save the figure.
    device_name = ENV.get_device_name().replace(" ", "_")
    if args.save_tag:
        save_path = (
            f"{args.save_dir}/{device_name}_{args.save_tag}_{extra_tag}_Speedup.png"
        )
    else:
        save_path = f"{args.save_dir}/{device_name}_{extra_tag}_Speedup.png"
    os.makedirs(args.save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    pretty_print_line(f"plot FFPA Speedup bar done, saved as {save_path}")
    plt.close()


def plot_tflops(level: str = "L1"):
    import matplotlib.pyplot as plt
    import numpy as np

    ax: plt.Axes = plt.subplots(figsize=(16, 9))[1]  # fig, axs
    plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.07)
    B = 1 if not args.B else args.B
    H = 48 if not args.H else args.H
    N = 8192 if not args.N else args.N
    ax.set_title(
        f"FFPA {level} vs SDPA EA, {ENV.get_device_name()}, "
        f"B={B}, H={H}, N={N}, Warmup={args.warmup}, "
        f"Iters={args.iters}",
        fontsize=15,
        fontweight="bold",
    )
    ax.set_xlabel("Headdim(D)", fontsize=15, fontweight="bold")
    ax.set_ylabel("TFLOPS", fontsize=15, fontweight="bold")
    ax.grid(True)

    get_best_tflops()
    new_statis_info = sort_tflops_by_headdim()

    ax.set_xticks(np.arange(0, len(new_statis_info["headdim"]), 1))
    ax.set_xticklabels(
        new_statis_info["headdim"],
        rotation=45,
        ha="right",
        fontsize=15,
        fontweight="bold",
    )
    exclude_tags = {"headdim"}

    # draw the best/baseline curves; drop the raw speedup series from the plot.
    draw_tags = set(new_statis_info.keys())
    draw_tags.discard("headdim")
    draw_tags.discard("ffpa+acc-f32+L1(speedup)")
    draw_tags.discard("ffpa+acc-f16+L1(speedup)")
    draw_tags.add("ffpa+acc-f32+L1(best)")
    draw_tags.add("ffpa+acc-f16+L1(best)")
    draw_tags = sorted(list(draw_tags))

    def skip_it(tag: str) -> bool:
        for etag in exclude_tags:
            if etag in tag:
                return True
        if tag not in draw_tags:
            return True
        return False

    for tag, tflops in new_statis_info.items():
        if skip_it(tag):
            continue
        if tag == "(sdpa)":
            ax.plot(tflops, label=tag, linewidth=3, color="green")
        elif tag == "(unfused)":
            ax.plot(tflops, label=tag, linewidth=3, color="black")
        elif "ffpa+acc-f32+L1(best)" in tag:
            ax.plot(tflops, label=tag, linewidth=4, color="blue")
        elif "ffpa+acc-f16+L1(best)" in tag:
            ax.plot(tflops, label=tag, linewidth=4, color="red")
        else:
            ax.plot(tflops, label=tag, linestyle="--")

    for label in ax.get_yticklabels():
        label.set_fontweight("bold")

    ax.legend()
    device_name = ENV.get_device_name().replace(" ", "_")
    if args.save_tag:
        save_path = f"{args.save_dir}/{device_name}_{args.save_tag}.png"
    else:
        save_path = f"{args.save_dir}/{device_name}.png"
    os.makedirs(args.save_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    pretty_print_line(f"plot FFPA TFLOPS done, saved as {save_path}")
    plt.close()
    plot_speedup_bar(
        new_statis_info["ffpa+acc-f32+L1(speedup)"],
        "ffpa+acc+f32+L1",
        new_statis_info["headdim"],
    )
    plot_speedup_bar(
        new_statis_info["ffpa+acc-f16+L1(speedup)"],
        "ffpa+acc+f16+L1",
        new_statis_info["headdim"],
    )
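

# get_qkvo() below materializes one input set per (B, H, N, D) config: q/k/v/o
# in the (B, H, N, D) layout consumed by FFPA and SDPA, fq/fk/fv in
# (B, N, H, D) for flash_attn_func, and tk/tv in (B, H, D, N). A minimal shape
# check (illustrative only):
#   q, k, v, o, fq, fk, fv, tk, tv = get_qkvo(1, 48, 8192, 320)
#   assert fq.shape == (1, 8192, 48, 320) and tk.shape == (1, 48, 320, 8192)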


def get_qkvo(B, H, N, D):
    if not (args.no_rand_q or args.no_rand_qkv):
        q = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        q = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
    if not (args.no_rand_k or args.no_rand_qkv):
        k = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        k = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
    if not (args.no_rand_v or args.no_rand_qkv):
        v = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
    else:
        v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()

    o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
    # transpose (H, N) -> (N, H) for FA2.
    fq = q.transpose(1, 2).contiguous()
    fk = k.transpose(1, 2).contiguous()
    fv = v.transpose(1, 2).contiguous()
    # transpose (N, D) -> (D, N) for V smem swizzle.
    tk = k.transpose(-2, -1).contiguous()  # [B,H,N,D] -> [B,H,D,N]
    tv = v.transpose(-2, -1).contiguous()  # [B,H,N,D] -> [B,H,D,N]

    return q, k, v, o, fq, fk, fv, tk, tv


# un-fused naive attention: two explicit GEMMs with a softmax in between.
def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v
    return y


def sdpa(q: Tensor, k: Tensor, v: Tensor, use_flash: bool = False):
    if not use_flash:
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            out: Tensor = F.scaled_dot_product_attention(q, k, v)
    else:
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out: Tensor = F.scaled_dot_product_attention(q, k, v)
    return out


def check_all_close(
    out_flash_or_sdpa: torch.Tensor,
    out_mma: torch.Tensor,
    tag: str = "out_mma",
    check_all: bool = False,
    is_flash: bool = False,
):
    if any((out_flash_or_sdpa is None, out_mma is None)):
        return
    if is_flash:
        true_tag = "out_flash"
        out_flash_or_sdpa = out_flash_or_sdpa.transpose(1, 2)
    else:
        true_tag = "out_sdpa"
    if check_all:
        # print the first few 8-row tiles of both outputs for eyeballing.
        seq_len = out_mma.size(-2)
        for i in range(int(seq_len / 8)):
            if i < 4:
                pretty_print_line()
                print(f"{true_tag}[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
                print(out_flash_or_sdpa[:, :, (i * 8) : (i + 1) * 8, :].float())
                print(f"{tag}[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
                print(out_mma[:, :, (i * 8) : (i + 1) * 8, :].float())
        pretty_print_line()
    diff = torch.abs(out_flash_or_sdpa - out_mma)
    all_close = str(torch.allclose(out_flash_or_sdpa, out_mma, atol=1e-2))
    pretty_print_line(
        f"{true_tag} vs {tag:<15}, all close: {all_close:<6}, "
        f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
        f"mean diff: {diff.mean().item():.6f}",
        sep="",
        mode="left",
    )


Bs = [1] if not args.B else [args.B]
Hs = [48] if not args.H else [args.H]
Ns = [8192] if not args.N else [args.N]
Ds = list(range(320, args.MAX_D + 64, 64)) if not args.D else [args.D]
# batch_size, n_head, seq_len, head_dim (B, H, N, D)
BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds]
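
# For example, with the defaults above (B=1, H=48, N=8192) and args.MAX_D set
# to 1024 (an illustrative value), the sweep covers D in {320, 384, ..., 1024}:
# twelve (B, H, N, D) configurations in total.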

# max headdim supported for different methods; skip a config if D > max_D.
MAX_HEADDIM_CFG: dict[str, int] = {
    # FFPA, SDPA, Naive MHA.
    "(sdpa)": 4096,  # may have no hard limit
    "(unfused)": 4096,  # may have no hard limit
    "(ffpa+acc+f16+L1+stage1)": 1024,  # may have no hard limit
    "(ffpa+acc+f16+L1+stage2)": 1024,  # may have no hard limit
    "(ffpa+acc+f16+L1+stage3)": 1024,  # may have no hard limit
    "(ffpa+acc+f16+L1+stage4)": 1024,  # may have no hard limit
    "(ffpa+acc+f32+L1+stage1)": 1024,  # may have no hard limit
    "(ffpa+acc+f32+L1+stage2)": 1024,  # may have no hard limit
    "(ffpa+acc+f32+L1+stage3)": 1024,  # may have no hard limit
    "(ffpa+acc+f32+L1+stage4)": 1024,  # may have no hard limit
}

seed = args.seed if args.seed else random.choice(range(10000))
set_rand_seed(seed)
pretty_print_line()
pretty_print_line(
    f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
    f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}"
)

for (B, H, N, D) in BHNDs:
    MAX_TFLOPS = -1
    SDPA_TFLOPS = -1
    q, k, v, o, fq, fk, fv, tk, tv = get_qkvo(B, H, N, D)
    torch.cuda.synchronize()
    pretty_print_line()
    pretty_print_line(
        f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}"
    )
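
    # run_benchmark() invokes every FFPA entry point positionally as
    # perf_func(q, k, v, out, stages), so a single stage-2 f32 call is
    # equivalent to (illustrative):
    #   ffpa_attn.ffpa_mma_acc_f32_L1(q, k, v, o, 2)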

    # Baselines: naive unfused MHA, SDPA (EA for D > 256, flash otherwise),
    # and flash-attn (skipped automatically when D > 256 or it is unavailable).
    out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "(unfused)")
    out_sdpa, _ = run_benchmark(
        partial(sdpa, use_flash=(D <= 256)), q, k, v, "(sdpa)"
    )
    out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")

    # FFPA kernels: either call the dedicated L1 entry points directly, or go
    # through the unified ffpa(...) dispatch of the installed ffpa_attn package.
    if not use_ffpa_attn_package:
        ffpa_l1_funcs = {
            "f32": ffpa_attn.ffpa_mma_acc_f32_L1,
            "f16": ffpa_attn.ffpa_mma_acc_f16_L1,
        }
    else:
        ffpa_l1_funcs = {
            "f32": partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP32),
            "f16": partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP16),
        }

    # f32 accumulation first, then f16; stages 3/4 only when enabled via ENV.
    max_stages = 4 if ENV.enable_all_mutistages() else 2
    ffpa_outs: dict[str, Optional[torch.Tensor]] = {}
    for acc, ffpa_func in ffpa_l1_funcs.items():
        for stages in range(1, max_stages + 1):
            name = f"out_ffpa_l1_{acc}{stages}"
            ffpa_outs[name], _ = run_benchmark(
                ffpa_func,
                q,
                k,
                v,
                f"(ffpa+acc+{acc}+L1+stage{stages})",
                o,
                stages=stages,
            )
    pretty_print_line()

    del q, k, v, o, fq, fk, fv, tk, tv
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    if args.check:
        for name, out_ffpa in ffpa_outs.items():
            check_all_close(out_sdpa, out_ffpa, name, args.check_all)
        pretty_print_line()

if args.gen_bench_table:
    gen_bench_markdown_table()

if args.plot_flops:
    plot_tflops()
--------------------------------------------------------------------------------
/tests/test_fused_mla.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------