├── .dev
│   ├── .clang-format
│   ├── .clang_format.hook
│   ├── .cpplint_pre_commit.hook
│   ├── .pre-commit-config-cpp.yaml
│   ├── .pre-commit-config.yaml
│   ├── clear.sh
│   ├── commit-prepare.sh
│   ├── init_dev.sh
│   ├── init_prod.sh
│   ├── init_prod_mini.sh
│   ├── install.sh
│   └── uninstall.sh
├── .github
│   ├── .gitignore
│   └── workflows
│       └── issue.yml
├── .gitignore
├── .gitmodules
├── LICENSE
├── MANIFEST.in
├── README.md
├── bench
│   ├── .gitignore
│   ├── NVIDIA_A30.png
│   ├── NVIDIA_A30_ffpa+acc+f16+L1_Speedup.png
│   ├── NVIDIA_A30_ffpa+acc+f32+L1_Speedup.png
│   ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png
│   ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f16+L1_Speedup.png
│   ├── NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f32+L1_Speedup.png
│   ├── NVIDIA_GeForce_RTX_4090.png
│   ├── NVIDIA_GeForce_RTX_4090_ffpa+acc+f16+L1_Speedup.png
│   ├── NVIDIA_GeForce_RTX_4090_ffpa+acc+f32+L1_Speedup.png
│   ├── NVIDIA_L20.png
│   ├── NVIDIA_L20_ffpa+acc+f16+L1_Speedup.png
│   ├── NVIDIA_L20_ffpa+acc+f32+L1_Speedup.png
│   ├── bank_conflicts_check.sh
│   └── bench.sh
├── csrc
│   ├── .gitignore
│   ├── cuffpa
│   │   ├── ffpa_attn_F16F16F16_L1.cu
│   │   ├── ffpa_attn_F16F16F32_L1.cu
│   │   ├── ffpa_attn_templates_L1.cuh
│   │   └── launch_templates.cuh
│   ├── deprecated
│   │   ├── faster_prefill_attn_F16F16F16F16_L1.cu
│   │   └── faster_prefill_attn_F32F16F16F32_L1.cu
│   ├── extension
│   │   ├── .gitignore
│   │   ├── fused_mla_F16F16F16_L1.cu
│   │   ├── fused_mla_F16F16F32_L1.cu
│   │   ├── fused_mla_templates_L1.cuh
│   │   └── launch_templates.cuh
│   └── pybind
│       ├── ffpa_attn_api.cc
│       └── fused_mla_api.cc
├── env.py
├── ffpa_attn
│   ├── .gitignore
│   ├── __init__.py
│   ├── interface.py
│   └── version.py
├── include
│   ├── .gitignore
│   ├── cuffpa
│   │   ├── cp_async.cuh
│   │   ├── deprecated
│   │   │   ├── mma_utils.cuh
│   │   │   └── smem_swizzle.cuh
│   │   ├── logging.cuh
│   │   ├── mma.cuh
│   │   ├── prefill.cuh
│   │   ├── swizzle.cuh
│   │   ├── utils.cuh
│   │   └── warp.cuh
│   └── extension
│       └── .gitignore
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
    ├── .gitignore
    ├── requirements.txt
    ├── swizzle_layout.py
    ├── test_ffpa_attn.py
    └── test_fused_mla.py
/.dev/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | Language: Cpp
3 | # BasedOnStyle: LLVM
4 | AccessModifierOffset: -1
5 | AlignAfterOpenBracket: Align
6 | AlignArrayOfStructures: None
7 | AlignConsecutiveMacros: None
8 | AlignConsecutiveAssignments: None
9 | AlignConsecutiveBitFields: None
10 | AlignConsecutiveDeclarations: None
11 | AlignEscapedNewlines: Right
12 | AlignOperands: Align
13 | AlignTrailingComments: true
14 | AllowAllArgumentsOnNextLine: true
15 | AllowAllConstructorInitializersOnNextLine: true
16 | AllowAllParametersOfDeclarationOnNextLine: true
17 | AllowShortEnumsOnASingleLine: true
18 | AllowShortBlocksOnASingleLine: Never
19 | AllowShortCaseLabelsOnASingleLine: false
20 | AllowShortFunctionsOnASingleLine: All
21 | AllowShortLambdasOnASingleLine: All
22 | AllowShortIfStatementsOnASingleLine: Never
23 | AllowShortLoopsOnASingleLine: false
24 | AlwaysBreakAfterDefinitionReturnType: None
25 | AlwaysBreakAfterReturnType: None
26 | AlwaysBreakBeforeMultilineStrings: false
27 | AlwaysBreakTemplateDeclarations: MultiLine
28 | AttributeMacros:
29 | - __capability
30 | BinPackArguments: true
31 | BinPackParameters: true
32 | BraceWrapping:
33 | AfterCaseLabel: false
34 | AfterClass: false
35 | AfterControlStatement: Never
36 | AfterEnum: false
37 | AfterFunction: false
38 | AfterNamespace: false
39 | AfterObjCDeclaration: false
40 | AfterStruct: false
41 | AfterUnion: false
42 | AfterExternBlock: false
43 | BeforeCatch: false
44 | BeforeElse: false
45 | BeforeLambdaBody: false
46 | BeforeWhile: false
47 | IndentBraces: false
48 | SplitEmptyFunction: true
49 | SplitEmptyRecord: true
50 | SplitEmptyNamespace: true
51 | BreakBeforeBinaryOperators: None
52 | BreakBeforeConceptDeclarations: true
53 | BreakBeforeBraces: Attach
54 | BreakBeforeInheritanceComma: false
55 | BreakInheritanceList: BeforeColon
56 | BreakBeforeTernaryOperators: true
57 | BreakConstructorInitializersBeforeComma: false
58 | BreakConstructorInitializers: BeforeColon
59 | BreakAfterJavaFieldAnnotations: false
60 | BreakStringLiterals: true
61 | ColumnLimit: 80
62 | # CommentPragmas: '^ IWYU pragma:'
63 | # CommentPragmas: '^[^ ]'
64 | CommentPragmas: '^\\.+'
65 | CompactNamespaces: false
66 | ConstructorInitializerAllOnOneLineOrOnePerLine: false
67 | ConstructorInitializerIndentWidth: 4
68 | ContinuationIndentWidth: 4
69 | Cpp11BracedListStyle: true
70 | DeriveLineEnding: true
71 | DerivePointerAlignment: false
72 | DisableFormat: false
73 | EmptyLineAfterAccessModifier: Never
74 | EmptyLineBeforeAccessModifier: LogicalBlock
75 | ExperimentalAutoDetectBinPacking: false
76 | FixNamespaceComments: true
77 | ForEachMacros:
78 | - foreach
79 | - Q_FOREACH
80 | - BOOST_FOREACH
81 | IfMacros:
82 | - KJ_IF_MAYBE
83 | IncludeBlocks: Preserve
84 | IncludeCategories:
85 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
86 | Priority: 2
87 | SortPriority: 0
88 | CaseSensitive: false
89 | - Regex: '^(<|"(gtest|gmock|isl|json)/)'
90 | Priority: 3
91 | SortPriority: 0
92 | CaseSensitive: false
93 | - Regex: '.*'
94 | Priority: 1
95 | SortPriority: 0
96 | CaseSensitive: false
97 | IncludeIsMainRegex: '(Test)?$'
98 | IncludeIsMainSourceRegex: ''
99 | IndentAccessModifiers: false
100 | IndentCaseLabels: false
101 | IndentCaseBlocks: false
102 | IndentGotoLabels: true
103 | IndentPPDirectives: None
104 | IndentExternBlock: AfterExternBlock
105 | IndentRequires: false
106 | IndentWidth: 2
107 | IndentWrappedFunctionNames: false
108 | InsertTrailingCommas: None
109 | JavaScriptQuotes: Leave
110 | JavaScriptWrapImports: true
111 | KeepEmptyLinesAtTheStartOfBlocks: true
112 | LambdaBodyIndentation: Signature
113 | MacroBlockBegin: ''
114 | MacroBlockEnd: ''
115 | MaxEmptyLinesToKeep: 1
116 | NamespaceIndentation: None
117 | ObjCBinPackProtocolList: Auto
118 | ObjCBlockIndentWidth: 2
119 | ObjCBreakBeforeNestedBlockParam: true
120 | ObjCSpaceAfterProperty: false
121 | ObjCSpaceBeforeProtocolList: true
122 | PenaltyBreakAssignment: 2
123 | PenaltyBreakBeforeFirstCallParameter: 19
124 | PenaltyBreakComment: 300
125 | PenaltyBreakFirstLessLess: 120
126 | PenaltyBreakString: 1000
127 | PenaltyBreakTemplateDeclaration: 10
128 | PenaltyExcessCharacter: 1000000
129 | PenaltyReturnTypeOnItsOwnLine: 60
130 | PenaltyIndentedWhitespace: 0
131 | PointerAlignment: Left
132 | PPIndentWidth: -1
133 | ReferenceAlignment: Pointer
134 | ReflowComments: false
135 | ShortNamespaceLines: 1
136 | SortIncludes: CaseSensitive
137 | SortJavaStaticImport: Before
138 | SortUsingDeclarations: true
139 | SpaceAfterCStyleCast: false
140 | SpaceAfterLogicalNot: false
141 | SpaceAfterTemplateKeyword: true
142 | SpaceBeforeAssignmentOperators: true
143 | SpaceBeforeCaseColon: false
144 | SpaceBeforeCpp11BracedList: false
145 | SpaceBeforeCtorInitializerColon: true
146 | SpaceBeforeInheritanceColon: true
147 | SpaceBeforeParens: ControlStatements
148 | SpaceAroundPointerQualifiers: Default
149 | SpaceBeforeRangeBasedForLoopColon: true
150 | SpaceInEmptyBlock: false
151 | SpaceInEmptyParentheses: false
152 | SpacesBeforeTrailingComments: 2
153 | SpacesInAngles: Never
154 | SpacesInConditionalStatement: false
155 | SpacesInContainerLiterals: true
156 | SpacesInCStyleCastParentheses: false
157 | SpacesInLineCommentPrefix:
158 | Minimum: 1
159 | Maximum: -1
160 | SpacesInParentheses: false
161 | SpacesInSquareBrackets: false
162 | SpaceBeforeSquareBrackets: false
163 | BitFieldColonSpacing: Both
164 | Standard: Latest
165 | StatementAttributeLikeMacros:
166 | - Q_EMIT
167 | StatementMacros:
168 | - Q_UNUSED
169 | - QT_REQUIRE_VERSION
170 | TabWidth: 8
171 | UseCRLF: false
172 | UseTab: Never
173 | WhitespaceSensitiveMacros:
174 | - STRINGIZE
175 | - PP_STRINGIZE
176 | - BOOST_PP_STRINGIZE
177 | - NS_SWIFT_NAME
178 | - CF_SWIFT_NAME
179 | ...
180 |
--------------------------------------------------------------------------------
/.dev/.clang_format.hook:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | readonly VERSION="14.0.0"
5 |
6 | version=$(clang-format -version)
7 |
8 | if ! [[ $version == *"$VERSION"* ]]; then
9 |   echo "clang-format version check failed."
10 |   echo "a version contains '$VERSION' is needed, but get '$version'"
11 |   echo "you can install the right version, and make a soft-link to '$PATH' env"
12 | exit -1
13 | fi
14 |
15 | clang-format -style=google "$@"
16 |
--------------------------------------------------------------------------------
/.dev/.cpplint_pre_commit.hook:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #TOTAL_ERRORS=0
4 | #echo "HAHAHAHAHHA"
5 | #exit 5
6 | #
7 | #files=$(
8 | #
9 | #if [[ ! $TRAVIS_BRANCH ]]; then
10 | # # install cpplint on local machine.
11 | # if [[ ! $(which cpplint) ]]; then
12 | # pip install cpplint
13 | # fi
14 | # # diff files on local machine.
15 | #    files=$(git diff --cached --name-status | awk '{print $2}')
16 | #else
17 | # # diff files between PR and latest commit on Travis CI.
18 | #    branch_ref=$(git rev-parse "$TRAVIS_BRANCH")
19 | # head_ref=$(git rev-parse HEAD)
20 | #    files=$(git diff --name-status $branch_ref $head_ref | awk '{print $2}')
21 | #fi
22 | ## The trick to remove deleted files: https://stackoverflow.com/a/2413151
23 | #for file in $files; do
24 | # echo $file
25 | # if [[ $file =~ ^(patches/.*) ]]; then
26 | # continue;
27 | # else
28 | # cpplint --filter=-readability/fn_size $file;
29 | #        TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
30 | # fi
31 | #done
32 | #
33 | #exit $TOTAL_ERRORS
34 |
35 | if git rev-parse --verify HEAD >/dev/null 2>&1
36 | then
37 | against=HEAD
38 | else
39 | # Initial commit: diff against an empty tree object
40 | against=4b825dc642cb6eb9a060e54bf8d69288fbee4904
41 | fi
42 |
43 | # Redirect output to stderr.
44 | exec 1>&2
45 |
46 | cpplint=cpplint
47 | sum=0
48 | filters='-build/include_order,-build/namespaces,-legal/copyright,-runtime/references,-build/include_what_you_use'
49 |
50 | # for cpp
51 | for file in $(git diff-index --name-status $against -- | grep -E '\.[ch](pp)?$' | awk '{print $2}'); do
52 | $cpplint --filter=$filters $file
53 | sum=$(expr ${sum} + $?)
54 | done
55 |
56 | if [ ${sum} -eq 0 ]; then
57 | exit 0
58 | else
59 | exit 1
60 | fi
61 |
--------------------------------------------------------------------------------
/.dev/.pre-commit-config-cpp.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: ed714747d7acbc5790b171702bb012af3b0fe145
4 | hooks:
5 | - id: check-merge-conflict
6 | - id: check-symlinks
7 | - id: end-of-file-fixer
8 | - id: trailing-whitespace
9 | - id: detect-private-key
10 | - id: check-symlinks
11 | - id: check-added-large-files
12 |
13 | - repo: local
14 | hooks:
15 | - id: copyright_checker
16 | name: copyright_checker
17 | entry: python ./.copyright.hook
18 | language: system
19 | files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
20 | exclude: (?!.*third_party)^.*$
21 |
22 | - repo: local
23 | hooks:
24 | - id: clang-format-with-version-check
25 | name: clang-format
26 | description: Format files with ClangFormat.
27 | entry: bash .clang_format.hook -i
28 | language: system
29 | files: \.(c|cc|cxx|cpp|cu|hxx|proto)$
30 |
31 | - repo: local
32 | hooks:
33 | - id: cpplint-cpp-source
34 | name: cpplint
35 | description: Check C++ code style using cpplint.py.
36 | entry: bash .cpplint_pre_commit.hook
37 | language: system
38 | files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
39 |
--------------------------------------------------------------------------------
/.dev/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.0.1
4 | hooks:
5 | - id: check-docstring-first
6 | - id: check-toml
7 | - id: check-yaml
8 | exclude: packaging/.*
9 | args:
10 | - --allow-multiple-documents
11 | - id: mixed-line-ending
12 | args: [--fix=lf]
13 | - id: end-of-file-fixer
14 |
15 | - repo: https://github.com/omnilib/ufmt
16 | rev: v1.3.3
17 | hooks:
18 | - id: ufmt
19 | additional_dependencies:
20 | - black == 22.3.0
21 | - usort == 1.0.2
22 |
23 | - repo: https://github.com/PyCQA/flake8
24 | rev: 7.1.1
25 | hooks:
26 | - id: flake8
27 | args: [--config=setup.cfg]
28 | exclude: generator.py
29 |
30 | - repo: https://github.com/PyCQA/pydocstyle
31 | rev: 6.1.1
32 | hooks:
33 | - id: pydocstyle
34 |
--------------------------------------------------------------------------------
/.dev/clear.sh:
--------------------------------------------------------------------------------
1 | rm -rf build dist *.egg-info __pycache__
2 | rm -rf $(find . -name __pycache__)
3 |
--------------------------------------------------------------------------------
/.dev/commit-prepare.sh:
--------------------------------------------------------------------------------
1 | path=$(cd `dirname $0`; pwd)
2 | cd $path
3 |
4 | # cpp & python format lint
5 | # sudo apt-get update
6 | # sudo apt-get install clang-format -y
7 | pip install pre-commit
8 | pip install yapf
9 | pip install cpplint
10 | pre-commit install -c ./.dev/.pre-commit-config.yaml # only lint for python
11 | # pre-commit install -c ./.dev/.pre-commit-config-cpp.yaml # both python + cpp
12 |
--------------------------------------------------------------------------------
/.dev/init_dev.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | export ENABLE_FFPA_ALL_STAGES=1
3 | export ENABLE_FFPA_ALL_HEADDIM=0
4 | export ENABLE_FFPA_AMPERE=0
5 | export ENABLE_FFPA_HOPPER=0
6 | export ENABLE_FFPA_DEBUG=1
7 | set +x
8 |
--------------------------------------------------------------------------------
/.dev/init_prod.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | export ENABLE_FFPA_ALL_STAGES=1
3 | export ENABLE_FFPA_ALL_HEADDIM=1
4 | export ENABLE_FFPA_AMPERE=1
5 | export ENABLE_FFPA_HOPPER=1
6 | export ENABLE_FFPA_DEBUG=0
7 | set +x
8 |
--------------------------------------------------------------------------------
/.dev/init_prod_mini.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | export ENABLE_FFPA_ALL_STAGES=1
3 | export ENABLE_FFPA_ALL_HEADDIM=0
4 | export ENABLE_FFPA_AMPERE=0
5 | export ENABLE_FFPA_HOPPER=0
6 | export ENABLE_FFPA_DEBUG=0
7 | set +x
8 |
--------------------------------------------------------------------------------
/.dev/install.sh:
--------------------------------------------------------------------------------
1 | rm -rf $(find . -name __pycache__)
2 | python3 setup.py bdist_wheel && cd dist
3 | python3 -m pip install *.whl
4 | cd .. && rm -rf build *.egg-info
5 | rm -rf $(find . -name __pycache__)
6 |
--------------------------------------------------------------------------------
/.dev/uninstall.sh:
--------------------------------------------------------------------------------
1 | python3 -m pip uninstall ffpa-attn -y
2 |
--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
22 | dist
23 | .tmp
24 | bin
25 | .cache
26 |
--------------------------------------------------------------------------------
/.github/workflows/issue.yml:
--------------------------------------------------------------------------------
1 | name: issues
2 | on:
3 | schedule:
4 | - cron: "0 0 * * 0"
5 |
6 | jobs:
7 | close-issues:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | issues: write
11 | pull-requests: write
12 | steps:
13 | - uses: actions/stale@v9.0.0
14 | with:
15 | days-before-issue-stale: 30
16 | days-before-issue-close: 7
17 | stale-issue-label: "stale"
18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
19 | close-issue-message: "This issue was closed because it has been inactive for 7 days since being marked as stale."
20 | days-before-pr-stale: -1
21 | days-before-pr-close: -1
22 | repo-token: ${{ secrets.GITHUB_TOKEN }}
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
22 | dist
23 | .tmp
24 | bin
25 | .cache
26 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "cutlass"]
2 | path = cutlass
3 | url = https://github.com/NVIDIA/cutlass.git
4 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include MANIFEST.in
2 | include LICENSE
3 | include requirements.txt
4 | recursive-include tests *
5 | global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp .gitignore *__pycache__
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 |
18 |
🤖FFPA: 1.8x~3x🎉faster vs SDPA EA with or without MMA Acc F32
19 |
20 |
21 | 🤖[WIP] **FFPA**: Yet another **Faster Flash Prefill Attention** with **O(1) SRAM complexity** & **O(d/4) or O(1) register complexity** for large headdim (D > 256), almost **1.8x~3x** 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: [📈L20 ~1.9x↑🎉](#L1-bench-l20), [📈A30 ~1.8x↑🎉](#L1-bench-a30), [📈3080 ~2.9x↑🎉](#L1-bench-3080), [📈4090 ~2.1x↑🎉](#L1-bench-4090). **FFPA Attention Algo: Fine-grained tiling** for large headdim, **FA-2 Attention Algo: Coarse-grained tiling** for small headdim.
22 |
23 |
27 |
28 |
29 |
30 | 💡NOTE: This project is still in its early development stage and currently provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻star this repo to support me ~)
31 |
32 | ## ©️Citations🎉🎉
33 |
34 | ```BibTeX
35 | @misc{ffpa-attn-2025,
36 | title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
37 | url={https://github.com/xlite-dev/ffpa-attn.git},
38 | note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git},
39 | author={xlite-dev etc},
40 | year={2025}
41 | }
42 | ```
43 |
44 | ## 📖 Contents
45 |
46 | - [📖 Installation⚙️](#install)
47 | - [📖 Python Testing👇](#python-test)
48 | - [📖 FFPA L1~L3 Design💡](#ffpa-design)
49 | - [📈 FFPA L1: L20 ~1.9x↑🎉](#L1-bench-l20)
50 | - [📈 FFPA L1: A30 ~1.8x↑🎉](#L1-bench-a30)
51 | - [📈 FFPA L1: 3080 ~2.9x↑🎉](#L1-bench-3080)
52 | - [📈 FFPA L1: 4090 ~2.1x↑🎉](#L1-bench-4090)
53 | - [📖 Fully Fused MLA w/ FFPA🎉](#fused-mla)
54 |
55 | ## 📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡
56 |
57 |
58 | We have extended FlashAttention for large headdim (D > 256) by implementing **Fine-grained Tiling** at the **MMA level (GEMM style)** for the Q@K^T and P@V matmuls. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance than SDPA with or without MMA Accumulation F32 (**1.8x~3x** 🎉 faster than SDPA EA).
59 |
60 | We have named this new attention tiling technique **FFPA: Faster Flash Prefill Attention**. We have designed three levels (`L1~L3`) of FFPA based on SRAM and register complexity considerations. None of these levels introduces additional VRAM requirements, ensuring that the HBM memory complexity remains the same as FlashAttention. 👇
61 |
62 | - [x] 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
63 | - [ ] 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
64 | - [ ] 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
65 |
66 | By leveraging this approach, we can achieve better performance than SDPA EA for very large headdim (D > 256, `FA-2 not supported`). An approximate SRAM and register complexity analysis for the FFPA L1~L3 levels is as follows: (`d`=headdim, `C,Br,Bc`=Constant, `Br=Bc`, let O(C)≈O(1)) 👇
67 |
68 | |📚Complexity| 📚FFPA L1 | 📚FFPA L2 | 📚FFPA L3 | 📚FA-2 |
69 | |:---:|:---:|:---:|:---:|:---:|
70 | |SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ |
71 | |Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)| ≈O(d/2), d↑ |
72 | |HBM| ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O| ≈FA2≈O(Nd), O | ≈O(Nd), O |
73 | |Extra HBM| ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l |
74 |
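To make the O(1) SRAM claim concrete, here is a back-of-the-envelope sketch (assumptions: fp16 elements, Br=Bc=64, 2 stages; it ignores padding/swizzle and the V tile, which the real accounting in [launch_templates.cuh](csrc/cuffpa/launch_templates.cuh) does include, and the helper names are hypothetical):

```python
# Rough per-block SMEM estimate in bytes (fp16 = 2 bytes per element).
def ffpa_l1_qk_smem_bytes(Br=64, Bc=64, stages=2, mma_atom_k=16):
    # FFPA L1 keeps only a (Br x 16) Q tile and a (Bc x 16) K tile resident
    # per stage, so the footprint is independent of the headdim d.
    return stages * (Br * mma_atom_k + Bc * mma_atom_k) * 2

def fa2_qkv_smem_bytes(Br=64, Bc=64, d=1024):
    # FA-2 style coarse-grained tiling holds full (Br x d) Q and (Bc x d) K/V
    # tiles, so the footprint grows linearly with d.
    return (Br * d + 2 * Bc * d) * 2

print(ffpa_l1_qk_smem_bytes())     # 8192 bytes, for any headdim
print(fa2_qkv_smem_bytes(d=1024))  # 393216 bytes at d=1024
```
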
75 | **📚👇Core Features🎉🎉**: I have implemented **FFPA L1~L3** using pure MMA PTX instructions, which supports many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, **Fully QKV Fine-grained Tiling(GEMM style)**, Collective Store, etc.
76 |
77 | |📚Feature |📚Feature |📚Feature |📚Feature|
78 | |:---:|:---:|:---:|:---:|
79 | |✔️Tensor Cores |✔️**MMA(m16n8k16)** |✔️Tile Block(Br, Bc) |✔️Tile MMA/Warp |
80 | |✔️**Split Q**(FA-2)|✔️Pack LDST(128 bits)|✔️SMEM **Swizzle/Pad** |✔️Copy Async |
81 | |✔️**Reg Double Buffers** |✔️QKV **Multi-Stages(1~4)** |✔️Collective Store(**Shfl**)|✔️**Prefetch QKV** g2s |
82 | |✔️**QKV Fine-grained Tiling**|✔️**Shared QKV** SMEM|✔️Mixed MMA Acc|✔️**Persist Q** s2r/g2s|
83 |
84 | - 📚 case: FFPA `L1` kernel template signature: [ffpa_attn_templates_L1.cuh](csrc/cuffpa/ffpa_attn_templates_L1.cuh)
85 |
86 | ```CUDA
87 | template<
88 | const int kHeadDim, // Headdim, 32~1024
89 | const int kMmaAtomM, // MMA Atom M, 16
90 | const int kMmaAtomN, // MMA Atom N, 8
91 | const int kMmaAtomK, // MMA Atom K, 16
92 | const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
93 | const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
94 | const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
95 | const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
96 | const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
97 | const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
98 | const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
99 | const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
100 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
101 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
102 | const int kOStorageAccFloat32, // 0/1, MMA Acc always be f32/f16, but O storage can be fp32 or half.
103 | const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point.
104 | const int kPrefetchPV, // Prefetch V at the Appropriate Time Point.
105 | const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V.
106 | const int kPersistQs2r, // Persist load Q s2r for headdim < 512, more registers, but still keep O(1) SRAM.
107 | const int kPersistQg2s, // Persist load Q g2s for headdim <= 320, more SRAM, but still keep register usage.
108 | const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping.
109 | const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4)
110 | const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4)
111 | const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
112 | const int kPadK, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
113 | const int kPadV // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
114 | > __global__ void // Q, K, V, O -> [B, H, N, D]
115 | // FFPA Attention Algo: Fine-grained tiling at MMA level for large headdim (d>=256),
116 | // which can achieve 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32.
117 | ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...);
118 | // FA-2 Attention Algo: Coarse-grained tiling at Attention level for small headdim (d<256),
119 | // which can achieve 95%-105%🎉 performance as SDPA FA-2 BE with MMA Acc F32 for N<=4096,
120 | // and achieve almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc(Q@K^T F32 +
121 | // P@V F16) for all range N.
122 | ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...);
123 | ```
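
The tile parameters above compose multiplicatively. As a sanity check, here is a tiny sketch that mirrors how [launch_templates.cuh](csrc/cuffpa/launch_templates.cuh) derives `Br`, `Bc`, and the thread count from these template parameters (values taken from the comments in the signature above):

```python
# Mirrors the constexpr math in launch_ffpa_mma_L1_template().
WARP_SIZE = 32
kMmaAtomM, kMmaAtomN = 16, 8
kMmaTileSeqLenQ, kMmaTileSeqLenK = 4, 1
kWarpTileSeqLenQ, kWarpTileSeqLenK = 1, 8

Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ   # 16 * 4 * 1 = 64
Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK   # 8 * 1 * 8  = 64
assert Br == Bc  # the launcher statically asserts this to avoid illegal access
threads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK  # 128 -> 4 warps/block
print(Br, Bc, threads)  # 64 64 128
```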
124 |
125 | ## 📖 Prerequisites
126 |
127 |
128 | - Python >= 3.10
129 | - PyTorch >= 2.4.0, CUDA >= 12.4
130 | - flash-attention >= 2.6.3 (for test)
131 | - Recommended: PyTorch 2.5.1, CUDA 12.5
132 | - Docker: nvcr.io/nvidia/pytorch:24.10-py3
133 |
134 | ## 📖 Installation
135 |
136 |
137 |
138 | The FFPA kernels implemented in this repo can be installed as a Python library, namely, the `ffpa-attn` library (optional).
139 | ```bash
140 | git clone https://github.com/xlite-dev/ffpa-attn.git
141 | # after cloning, either run `bash .dev/install.sh` directly or run these commands:
142 | python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
143 | ```
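
After installation, a minimal smoke test might look like the sketch below. This is an assumption-laden sketch: the Python entry points live in [ffpa_attn/interface.py](ffpa_attn/interface.py) and are not shown in this dump; the C++ side binds `ffpa_mma_acc_f32_L1(Q, K, V, O, stages)` (see [csrc/cuffpa/ffpa_attn_F16F16F32_L1.cu](csrc/cuffpa/ffpa_attn_F16F16F32_L1.cu)), which expects `[B, H, N, D]` fp16 CUDA tensors and writes `O` in place.

```python
import torch

B, H, N, D = 1, 48, 8192, 320
q = torch.randn(B, H, N, D, dtype=torch.half, device="cuda").contiguous()
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.zeros_like(q)  # output buffer, written in place by the kernel

# Hypothetical call; check ffpa_attn/interface.py for the actual export.
# from ffpa_attn import ffpa_mma_acc_f32_L1
# ffpa_mma_acc_f32_L1(q, k, v, o, 2)  # stages=2
```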
144 |
145 | ## 📖 FFPA L1 (Level 1): Benchmark 🎉🎉
146 |
147 |
148 |
149 | L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, **D=320-1024 (FA2 not supported 👀)**. (Notation: `*`=MMA Acc F32, `^`=MMA Acc F16, Softmax Acc dtype is always F32, `T`=TFLOPS, 👇Benchmark)
150 |
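The TFLOPS figures below follow the usual non-causal attention FLOP count of 4·B·H·N²·D (2·B·H·N²·D each for Q@K^T and P@V). A quick sanity check against the timing printed in the Python Testing section lands within ~0.5% of the printed 56.19T (the small gap is timing/rounding; treat this as a sketch, not the exact formula used by the test script):

```python
# Reproduce the SDPA EA entry for B=1, H=48, N=8192, D=320.
B, H, N, D = 1, 48, 8192, 320
flops = 4 * B * H * N * N * D   # Q@K^T + P@V, non-causal
time_s = 73.66518e-3            # ms, from the test log below
print(flops / time_s / 1e12)    # ~55.97 TFLOPS (printed: 56.19)
```
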
151 | - 📚 NVIDIA L20 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**)
152 |
153 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
154 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
155 | |SDPA EA|56T|63T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
156 | |FFPA L1*|102T|102T|103T|104T|103T|95T|95T|95T|95T|96T|95T|94T|
157 | |Speedup|1.82x|1.62x|1.78x|1.79x|1.87x|1.7x|1.76x|1.73x|1.76x|1.75x|1.76x|1.68x|
158 | |FFPA L1^|104T|103T|103T|102T|104T|103T|102T|94T|94T|94T|100T|100T|
159 | |Speedup|1.86x|1.63x|1.78x|1.76x|1.89x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|
160 |
161 | - 📚 NVIDIA L20 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~1.9x↑🎉**)
162 |
163 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
164 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
165 | |SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
166 | |FFPA L1*|105T|102T|104T|103T|105T|95T|95T|94T|94T|94T|102T|101T|
167 | |Speedup|1.88x|1.59x|1.79x|1.78x|1.91x|1.7x|1.76x|1.71x|1.74x|1.71x|1.89x|1.8x|
168 | |FFPA L1^|104T|103T|103T|102T|103T|103T|102T|94T|94T|94T|100T|100T|
169 | |Speedup|1.86x|1.61x|1.78x|1.76x|1.87x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|
170 |
171 |
172 |

173 |

174 |
175 |
176 |
177 |
178 | - 📚 NVIDIA A30 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**)
179 |
180 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
181 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
182 | |SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
183 | |FFPA L1*|45T|44T|44T|43T|43T|38T|37T|37T|37T|36T|33T|32T|
184 | |Speedup|1.8x|1.76x|1.83x|1.79x|1.79x|1.58x|1.61x|1.68x|1.68x|1.64x|1.5x|1.78x|
185 | |FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|40T|34T|
186 | |Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.82x|1.89x|
187 |
188 | - 📚 NVIDIA A30 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~1.9x↑🎉**)
189 |
190 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
191 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
192 | |SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
193 | |FFPA L1*|48T|46T|46T|43T|44T|38T|38T|38T|37T|36T|40T|34T|
194 | |Speedup|1.92x|1.84x|1.92x|1.79x|1.83x|1.58x|1.65x|1.73x|1.68x|1.64x|1.82x|1.89x|
195 | |FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|39T|34T|
196 | |Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.77x|1.89x|
197 |
198 |
199 |

200 |

201 |
202 |
203 |
204 |
205 | - 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~2.5x↑🎉**)
206 |
207 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
208 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
209 | |SDPA EA|13T|16T|11T|16T|15T|15T|15T|15T|14T|14T|14T|14T|
210 | |FFPA L1*|33T|31T|30T|30T|30T|27T|27T|26T|26T|26T|26T|25T|
211 | |Speedup|2.54x|1.94x|2.73x|1.88x|2.0x|1.8x|1.8x|1.73x|1.86x|1.86x|1.86x|1.79x|
212 | |FFPA L1^|43T|41T|39T|39T|39T|39T|39T|36T|34T|33T|31T|33T|
213 | |Speedup|3.31x|2.56x|3.55x|2.44x|2.6x|2.6x|2.6x|2.4x|2.43x|2.36x|2.21x|2.36x|
214 |
215 | - 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~2.9x↑🎉**)
216 |
217 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
218 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
219 | |SDPA EA|13T|15T|12T|15T|14T|15T|14T|14T|14T|14T|14T|14T|
220 | |FFPA L1*|38T|36T|34T|35T|34T|31T|32T|31T|30T|28T|27T|27T|
221 | |Speedup|2.92x|2.4x|2.83x|2.33x|2.43x|2.07x|2.29x|2.21x|2.14x|2.0x|1.93x|1.93x|
222 | |FFPA L1^|44T|41T|39T|39T|38T|39T|39T|36T|34T|32T|31T|33T|
223 | |Speedup|3.38x|2.73x|3.25x|2.6x|2.71x|2.6x|2.79x|2.57x|2.43x|2.29x|2.21x|2.36x|
224 |
225 |
226 |

227 |

228 |
229 |
230 |
231 |
232 | - 📚 NVIDIA RTX 4090 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**)
233 |
234 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
235 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
236 | |SDPA EA|81T|94T|85T|85T|79T|81T|79T|80T|79T|80T|78T|78T|
237 | |FFPA L1*|149T|150T|150T|150T|150T|140T|140T|140T|139T|139T|137T|134T|
238 | |Speedup|1.84x|1.6x|1.76x|1.76x|1.9x|1.73x|1.77x|1.75x|1.76x|1.74x|1.76x|1.72x|
239 | |FFPA L1^|194T|194T|189T|191T|197T|188T|184T|180T|177T|172T|171T|171T|
240 | |Speedup|2.4x|2.06x|2.22x|2.25x|2.49x|2.32x|2.33x|2.25x|2.24x|2.15x|2.19x|2.19x|
241 |
242 | - 📚 NVIDIA RTX 4090 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~2.1x↑🎉**)
243 |
244 | |Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
245 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
246 | |SDPA EA|82T|92T|85T|84T|78T|81T|79T|80T|78T|79T|77T|78T|
247 | |FFPA L1*|176T|170T|171T|171T|171T|161T|160T|161T|160T|158T|165T|164T|
248 | |Speedup|2.15x|1.85x|2.01x|2.04x|2.19x|1.99x|2.03x|2.01x|2.05x|2.0x|2.14x|2.1x|
249 | |FFPA L1^|200T|191T|189T|191T|188T|188T|186T|179T|175T|173T|172T|170T|
250 | |Speedup|2.44x|2.08x|2.22x|2.27x|2.41x|2.32x|2.35x|2.24x|2.24x|2.19x|2.23x|2.18x|
251 |
252 |
253 |

254 |

255 |
256 |
257 | ## 📖 Python Testing
258 |
259 |
260 | 👇You can test many custom FFPA kernels via Python and compare their performance. The `--gen-bench` and `--plot` options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR 🎉🎉.
261 |
262 | - 📚 case: B=1, H=48, N=8192, D=320(`FA2 not supported`)
263 | ```bash
264 | # You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
265 | cd tests && python3 test_ffpa_attn.py --B 1 --H 48 --N 8192 --show-all --D 320
266 | ---------------------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5--------------------
267 | (sdpa): ['-0.02380371'], time:73.66518ms, TFLOPS:56.19 (+0.00 %)(~1.00x)
268 | (ffpa+acc+f32+L1+stage1): ['-0.02378845'], time:52.87361ms, TFLOPS:78.28 (+39.32%)(~1.39x)
269 | (ffpa+acc+f32+L1+stage2): ['-0.02378845'], time:40.84062ms, TFLOPS:101.35(+29.46%)(~1.80x)
270 | (ffpa+acc+f32+L1+stage3): ['-0.02378845'], time:40.49534ms, TFLOPS:102.21(+0.85 %)(~1.82x)
271 | (ffpa+acc+f32+L1+stage4): ['-0.02378845'], time:40.88177ms, TFLOPS:101.25(+0.00 %)(~1.80x)
272 | (ffpa+acc+f16+L1+stage1): ['-0.02378845'], time:53.43298ms, TFLOPS:77.46 (+0.00 %)(~1.38x)
273 | (ffpa+acc+f16+L1+stage2): ['-0.02378845'], time:39.76068ms, TFLOPS:104.10(+1.85 %)(~1.85x)
274 | (ffpa+acc+f16+L1+stage3): ['-0.02378845'], time:39.54901ms, TFLOPS:104.66(+0.54 %)(~1.86x)
275 | (ffpa+acc+f16+L1+stage4): ['-0.02378845'], time:41.06554ms, TFLOPS:100.79(+0.00 %)(~1.79x)
276 | --------------------------------------------------------------------------------------------------------
277 | ```
278 | - 📚 case: Generate benchmark table and speedup bar plots on Your device.
279 | ```bash
280 | cd tests && pip install matplotlib && python3 test_ffpa_attn.py --gen-bench --show-all --plot
281 | ```
282 | - 📚 case: Compare small headdim (d<256, e.g. 64), FFPA-L1 vs SDPA FA-2 BE.
283 | ```bash
284 | # Enable the ffpa-attn small-d kernel, which uses the coarse-grained tiling method.
285 | export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1
286 | cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20
287 | ---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5--------------------
288 | (sdpa): ['0.00802612'], time:0.148057ms, TFLOPS:59.14 (+0.00 %)(~1.00x)
289 | (ffpa+acc+f32+L1+stage1): ['0.00803375'], time:0.103807ms, TFLOPS:84.34 (+42.63%)(~1.43x)
290 | (ffpa+acc+f32+L1+stage2): ['0.00803375'], time:0.102233ms, TFLOPS:85.64 (+1.54 %)(~1.45x)
291 | (ffpa+acc+f32+L1+stage3): ['0.00803375'], time:0.102519ms, TFLOPS:85.40 (+0.00 %)(~1.44x)
292 | (ffpa+acc+f32+L1+stage4): ['0.00803375'], time:0.102043ms, TFLOPS:85.80 (+0.19 %)(~1.45x)
293 | (ffpa+acc+f16+L1+stage1): ['0.00795746'], time:0.104713ms, TFLOPS:83.61 (+0.00 %)(~1.41x)
294 | (ffpa+acc+f16+L1+stage2): ['0.00795746'], time:0.102949ms, TFLOPS:85.05 (+0.00 %)(~1.44x)
295 | (ffpa+acc+f16+L1+stage3): ['0.00795746'], time:0.108957ms, TFLOPS:80.36 (+0.00 %)(~1.36x)
296 | (ffpa+acc+f16+L1+stage4): ['0.00795746'], time:0.103282ms, TFLOPS:84.77 (+0.00 %)(~1.43x)
297 | --------------------------------------------------------------------------------------------------------
298 | cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 4096 --check --show-all --D 64 # NVIDIA L20
299 | -------------------------B=1, H=32, N=4096, D=64, Warmup: 1, Iters: 5-----------------------------------
300 | (sdpa): ['0.01959229'], time:1.397752ms, TFLOPS:100.24(+0.00 %)(~1.00x)
301 | (ffpa+acc+f32+L1+stage1): ['0.01959229'], time:1.368856ms, TFLOPS:102.36(+2.11 %)(~1.02x)
302 | (ffpa+acc+f32+L1+stage2): ['0.01959229'], time:1.367807ms, TFLOPS:102.44(+0.08 %)(~1.02x)
303 | (ffpa+acc+f32+L1+stage3): ['0.01959229'], time:1.367855ms, TFLOPS:102.43(+0.00 %)(~1.02x)
304 | (ffpa+acc+f32+L1+stage4): ['0.01959229'], time:1.368045ms, TFLOPS:102.42(+0.00 %)(~1.02x)
305 | (ffpa+acc+f16+L1+stage1): ['0.01957703'], time:1.389312ms, TFLOPS:100.85(+0.00 %)(~1.01x)
306 | (ffpa+acc+f16+L1+stage2): ['0.01957703'], time:1.388311ms, TFLOPS:100.92(+0.00 %)(~1.01x)
307 | (ffpa+acc+f16+L1+stage3): ['0.01957703'], time:1.386976ms, TFLOPS:101.02(+0.00 %)(~1.01x)
308 | (ffpa+acc+f16+L1+stage4): ['0.01957703'], time:1.387834ms, TFLOPS:100.96(+0.00 %)(~1.01x)
309 | --------------------------------------------------------------------------------------------------------
310 | ```
311 |
312 | 💡NOTE: Please check all configurable environment variables in [env.py](./env.py).
313 |
314 | ## 📖 Fully Fused MLA with FFPA 🎉
315 |
316 |
317 |
318 | Extending FA support to large headdim is meaningful in the context of **DeepSeek MLA**. For example, once FA supports headdim values greater than 512, we can fully fuse MLA into a single CUDA kernel, after W_UK/W_UV are absorbed into W_Q/W_O (resulting in C_kv/C_q with `dc/dc' >= 512`); see the absorption sketch after the TODO list👇:
319 |
320 | - [ ] 📚Fully Fused MLA into a single CUDA kernel using **FFPA** Algo and Tensor Cores.
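
To make the absorption step concrete, here is a tiny numerical sketch (simplified: single head, no RoPE, illustrative names). Folding `W_UK` into the query side means attention runs directly over the `dc`-dim latent, which is why headdim >= 512 support matters:

```python
import torch

torch.manual_seed(0)
n, dh, dc = 4, 128, 512            # tokens, per-head dim, KV latent dim
W_uk = torch.randn(dh, dc) * 0.02  # key up-projection (W_UK)
q = torch.randn(n, dh)             # per-head queries
c_kv = torch.randn(n, dc)          # compressed KV latent (C_kv)

# Naive MLA: decompress keys first, then score.
k = c_kv @ W_uk.T                  # [n, dh]
scores_naive = q @ k.T             # [n, n]

# Absorbed: fold W_UK into the query projection; the effective attention
# "headdim" becomes dc >= 512, hence the need for large-d FA support.
scores_absorbed = (q @ W_uk) @ c_kv.T
assert torch.allclose(scores_naive, scores_absorbed, atol=1e-4)
```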
321 |
322 | ## ©️License
323 |
324 |
325 |
326 | GNU General Public License v3.0
327 |
328 | ## 🎉Contribute
329 |
330 |
331 |
332 | How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~
333 |
334 |
343 |
344 | ## 📖 References
345 |
346 |
347 | - [flash-attention](https://github.com/Dao-AILab/flash-attention)
348 | - [LeetCUDA](https://github.com/xlite-dev/LeetCUDA)
349 | - [flashinfer](https://github.com/flashinfer-ai/flashinfer)
350 |
--------------------------------------------------------------------------------
/bench/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
22 | dist
23 |
--------------------------------------------------------------------------------
/bench/NVIDIA_A30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_A30.png
--------------------------------------------------------------------------------
/bench/NVIDIA_A30_ffpa+acc+f16+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_A30_ffpa+acc+f16+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_A30_ffpa+acc+f32+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_A30_ffpa+acc+f32+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png
--------------------------------------------------------------------------------
/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f16+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f16+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f32+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2_ffpa+acc+f32+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_GeForce_RTX_4090.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_4090.png
--------------------------------------------------------------------------------
/bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f16+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f16+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f32+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_GeForce_RTX_4090_ffpa+acc+f32+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_L20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_L20.png
--------------------------------------------------------------------------------
/bench/NVIDIA_L20_ffpa+acc+f16+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_L20_ffpa+acc+f16+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/NVIDIA_L20_ffpa+acc+f32+L1_Speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xlite-dev/ffpa-attn/3b90cc1d47eb921503dd57848e13a52f2a8cf2be/bench/NVIDIA_L20_ffpa+acc+f32+L1_Speedup.png
--------------------------------------------------------------------------------
/bench/bank_conflicts_check.sh:
--------------------------------------------------------------------------------
1 | cd tests
2 | ncu \
3 | --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
4 |   python3 test_ffpa_attn.py --B 1 --H 8 --N 1024 --D 320 --w 0 --i 1 --show-all
5 | cd ..
6 |
--------------------------------------------------------------------------------
/bench/bench.sh:
--------------------------------------------------------------------------------
1 | # benchmarks
2 | cd tests
3 | python3 test_ffpa_attn.py --gen-bench --show-all --plot
4 | cd ..
5 |
--------------------------------------------------------------------------------
/csrc/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
--------------------------------------------------------------------------------
/csrc/cuffpa/ffpa_attn_F16F16F16_L1.cu:
--------------------------------------------------------------------------------
1 | #include "launch_templates.cuh"
2 | using namespace ffpa;
3 |
4 |
5 | void ffpa_mma_acc_f16_L1(torch::Tensor Q,
6 | torch::Tensor K,
7 | torch::Tensor V,
8 | torch::Tensor O,
9 | int stages) {
10 | CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
11 | CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
12 | CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
13 | CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
14 | const int d = Q.size(3); // B, H, N, d
15 | // Q@K^T or P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
16 | constexpr int kMmaAccFloat32QK = 0;
17 | constexpr int kMmaAccFloat32PV = 0;
18 |
19 | #ifdef ENABLE_FFPA_ALL_STAGES
20 | // dispatch stages
21 | if (stages == 2) {
22 | DISPATCH_HEADDIM(LAUNCHER_L1, 2);
23 | } else if (stages == 3) {
24 | DISPATCH_HEADDIM(LAUNCHER_L1, 3);
25 | } else if (stages == 4) {
26 | DISPATCH_HEADDIM(LAUNCHER_L1, 4);
27 | } else {
28 | DISPATCH_HEADDIM(LAUNCHER_L1, 1);
29 | }
30 | #else
31 | // dispatch stages
32 | if (stages == 2) {
33 | DISPATCH_HEADDIM(LAUNCHER_L1, 2);
34 | } else {
35 | DISPATCH_HEADDIM(LAUNCHER_L1, 1);
36 | }
37 | #endif
38 | }
39 |
--------------------------------------------------------------------------------
/csrc/cuffpa/ffpa_attn_F16F16F32_L1.cu:
--------------------------------------------------------------------------------
1 | #include "launch_templates.cuh"
2 | using namespace ffpa;
3 |
4 |
5 | void ffpa_mma_acc_f32_L1(torch::Tensor Q,
6 | torch::Tensor K,
7 | torch::Tensor V,
8 | torch::Tensor O,
9 | int stages) {
10 | CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
11 | CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
12 | CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
13 | CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
14 | const int d = Q.size(3); // B, H, N, d
15 | // Q@K^T or P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
16 | #ifdef ENABLE_FFPA_FORCE_QK_F16
17 | constexpr int kMmaAccFloat32QK = 0;
18 | #else
19 | constexpr int kMmaAccFloat32QK = 1;
20 | #endif
21 | #ifdef ENABLE_FFPA_FORCE_PV_F16
22 | constexpr int kMmaAccFloat32PV = 0;
23 | #else
24 | constexpr int kMmaAccFloat32PV = 1;
25 | #endif
26 |
27 | #ifdef ENABLE_FFPA_ALL_STAGES
28 | // dispatch stages
29 | if (stages == 2) {
30 | DISPATCH_HEADDIM(LAUNCHER_L1, 2);
31 | } else if (stages == 3) {
32 | DISPATCH_HEADDIM(LAUNCHER_L1, 3);
33 | } else if (stages == 4) {
34 | DISPATCH_HEADDIM(LAUNCHER_L1, 4);
35 | } else {
36 | DISPATCH_HEADDIM(LAUNCHER_L1, 1);
37 | }
38 | #else
39 | // dispatch stages
40 | if (stages == 2) {
41 | DISPATCH_HEADDIM(LAUNCHER_L1, 2);
42 | } else {
43 | DISPATCH_HEADDIM(LAUNCHER_L1, 1);
44 | }
45 | #endif
46 | }
47 |
--------------------------------------------------------------------------------
/csrc/cuffpa/launch_templates.cuh:
--------------------------------------------------------------------------------
1 | #include "ffpa_attn_templates_L1.cuh"
2 | using namespace ffpa;
3 |
4 | static constexpr int kMaxDForSmallDKernel = 64;
5 | static constexpr int kMaxDForOStoreFloat32 = 64;
6 | static constexpr int kMaxDForSmallBlockTile = 256;
7 |
8 | template <const int kHeadDim>
9 | static constexpr int getConfigMmaTileSeqLenQP() {
10 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S
11 | #if defined(BUILD_FFPA_ATTN_MMA_L20)
12 | constexpr int kMmaTileSeqLenQP = (
13 | kHeadDim <= kMaxDForSmallBlockTile) ? 4: 8;
14 | #else
15 | constexpr int kMmaTileSeqLenQP = (
16 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 8;
17 | #endif
18 | #else // if undef ENABLE_FFPA_PERSIST_KV_G2S
19 | // O(1) SRAM complexity, may always use large tile for
20 | // ffpa large d kernel. TODO: tune block size for L20/4090/3080 etc.
21 | #if defined(BUILD_FFPA_ATTN_MMA_L20)
22 | constexpr int kMmaTileSeqLenQP = (
23 | kHeadDim <= kMaxDForSmallBlockTile) ? 4: 8;
24 | #else
25 | constexpr int kMmaTileSeqLenQP = (
26 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 8;
27 | #endif
28 | #endif
29 | return kMmaTileSeqLenQP;
30 | }
31 |
32 | template <const int kHeadDim>
33 | static constexpr int getConfigWarpTileSeqLenK() {
34 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S
35 | #if defined(BUILD_FFPA_ATTN_MMA_L20)
36 | constexpr int kWarpTileSeqLenK = (
37 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 16;
38 | #else
39 | constexpr int kWarpTileSeqLenK = (
40 | kHeadDim <= kMaxDForSmallBlockTile) ? 16: 16;
41 | #endif
42 | #else // if undef ENABLE_FFPA_PERSIST_KV_G2S
43 | #if defined(BUILD_FFPA_ATTN_MMA_L20)
44 | constexpr int kWarpTileSeqLenK = (
45 | kHeadDim <= kMaxDForSmallBlockTile) ? 8: 16;
46 | #else
47 | constexpr int kWarpTileSeqLenK = (
48 | kHeadDim <= kMaxDForSmallBlockTile) ? 16: 16;
49 | #endif
50 | #endif
51 | return kWarpTileSeqLenK;
52 | }
53 |
54 | template <const int kHeadDim>
55 | static constexpr int getConfigWarpTileHeadDimV() {
56 | constexpr int kMmaAtomN = 8;
57 | constexpr int kMmaTileHeadDimV = 1;
58 | constexpr int kWarpTileHeadDimV = (
59 | kHeadDim / (kMmaAtomN * kMmaTileHeadDimV));
60 | return kWarpTileHeadDimV;
61 | }
62 |
63 | static constexpr int getConfigShareSmemQKV() {
64 | #if defined(ENABLE_FFPA_QKV_SMEM_SHARE)
65 | constexpr int kShareSmemQKV = 1;
66 | #else
67 | constexpr int kShareSmemQKV = 0;
68 | #endif
69 | return kShareSmemQKV;
70 | }
71 |
72 | template <const int kHeadDim>
73 | static constexpr int getConfigOStorageAccFloat32() {
74 | // 0/1, The precision of the O storage buffer can differ from
75 | // that of the MMA, supporting either FP32 or Half precision.
76 | // FP16 can provide precision to approximately 3-4 decimal places.
77 | // Thus, if the error does not exceed 1e-3, using FP16 storage is
78 | // sufficient for most applications.
79 | return ((kHeadDim <= kMaxDForOStoreFloat32)) ? 1 : 0;
80 | }
81 |
82 | template <const int kStageQKV>
83 | static constexpr int getConfigPrefetchQKV() {
84 | // Prefetch QKV at the appropriate time point.
85 | #if defined(ENABLE_FFPA_PREFETCH_QKV)
86 | #if defined(ENABLE_FFPA_PERSIST_KV_G2S)
87 | constexpr int kPrefetchQKV = 1; // kStageQKV is unused
88 | #else
89 | constexpr int kPrefetchQKV = (kStageQKV > 1) ? 1 : 0;
90 | #endif
91 | #else
92 | constexpr int kPrefetchQKV = 0;
93 | #endif
94 | return kPrefetchQKV;
95 | }
96 |
97 | template <const int kHeadDim, const int kStageQK>
98 | static constexpr int getConfigPersistQg2s() {
99 |   // Persist load Q g2s for headdim <= 320, more SRAM, but still
100 | // keep register usage.
101 | #if defined(ENABLE_FFPA_PERSIST_Q_G2S)
102 | constexpr int kPersistQg2s = (kHeadDim < 256) ? 1 : (
103 | (kHeadDim <= 320) ? ((kStageQK < 3) ? 1 : 0) : 0
104 | );
105 | #else
106 | constexpr int kPersistQg2s = 0;
107 | #endif
108 | return kPersistQg2s;
109 | }
110 |
111 | static constexpr int getConfigPersistQs2r() {
112 | // Persist load Q s2r for headdim < 512, more registers,
113 | // but still keep O(1) SRAM.
114 | #ifdef ENABLE_FFPA_PERSIST_Q_S2R
115 | constexpr int kPersistQs2r = 1;
116 | #else
117 | constexpr int kPersistQs2r = 0;
118 | #endif
119 | return kPersistQs2r;
120 | }
121 |
122 | static constexpr int getConfigPersistVs2r() {
123 | #ifdef ENABLE_FFPA_PERSIST_V_S2R
124 | constexpr int kPersistVs2r = 1;
125 | #else
126 | constexpr int kPersistVs2r = 0;
127 | #endif
128 | return kPersistVs2r;
129 | }
130 |
131 | static constexpr int getConfigRegistersPipeKV() {
132 | #ifdef ENABLE_FFPA_REGISTERS_PIPE_KV
133 | constexpr int kRegPipeKV = 1;
134 | #else
135 | constexpr int kRegPipeKV = 0;
136 | #endif
137 | return kRegPipeKV;
138 | }
139 |
140 | static constexpr int getConfigPadQ() {
141 | #ifdef ENABLE_FFPA_SMEM_SWIZZLE_Q
142 | constexpr int kPadQ = 0;
143 | #else
144 | constexpr int kPadQ = 8;
145 | #endif
146 | return kPadQ;
147 | }
148 |
149 | static constexpr int getConfigPadK() {
150 | #ifdef ENABLE_FFPA_SMEM_SWIZZLE_K
151 | constexpr int kPadK = 0;
152 | #else
153 | constexpr int kPadK = 8;
154 | #endif
155 | return kPadK;
156 | }
157 |
158 | static constexpr int getConfigPadV() {
159 | #ifdef ENABLE_FFPA_SMEM_SWIZZLE_V
160 | constexpr int kPadV = 0;
161 | #else
162 | constexpr int kPadV = 8;
163 | #endif
164 | return kPadV;
165 | }
166 |
167 | template <const int kNumThreads>
168 | static inline dim3 getConfigBlock() {
169 | dim3 block(kNumThreads);
170 | return block;
171 | }
172 |
173 | template <const int Br>
174 | static inline dim3 getConfigGrid(
175 | const int B, const int H, const int N) {
176 | // Tr(=N/Br), batch_size x num_heads
177 | // try grid(N/Br, B * H) or grid(N/Br, H, B)
178 | #ifdef ENABLE_FFPA_LAUNCH_GRID_DNHB
179 | dim3 grid(utils::div_ceil(N, Br), H, B);
180 | #else
181 | dim3 grid(utils::div_ceil(N, Br), B * H);
182 | #endif
183 | return grid;
184 | }
185 |
186 | template<
187 | const int Br,
188 | const int Bc,
189 | const int kMmaAtomM,
190 | const int kMmaAtomN,
191 | const int kMmaAtomK,
192 | const int kHeadDim,
193 | const int kShareSmemQKV,
194 | const int kPersistQg2s,
195 | const int kPersistQs2r,
196 | const int kStageQK,
197 | const int kStagePV,
198 | const int kPadQ,
199 | const int kPadK,
200 | const int kPadV
201 | >
202 | static constexpr int getConfigQKVSmemMaxSize() {
203 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S
204 |   if constexpr (kHeadDim <= kMaxDForSmallDKernel) { // d > kMaxDForSmallDKernel uses the large-d kernel
205 | // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem.
206 | constexpr int Q_smem_size = (
207 | (kHeadDim / kMmaAtomK) * (Br * (kMmaAtomK + kPadQ)));
208 | constexpr int K_smem_size = (
209 | (kHeadDim / kMmaAtomK) * (Bc * (kMmaAtomK + kPadK)));
210 | constexpr int V_smem_size = (
211 | (kHeadDim / (kMmaAtomN * 2)) * (Bc * (kMmaAtomN * 2 + kPadV)));
212 | constexpr int kQSmemMaxSize = Q_smem_size * 2;
213 | constexpr int kKSmemMaxSize = K_smem_size * 2;
214 | constexpr int kVSmemMaxSize = V_smem_size * 2;
215 | constexpr int kQKSmemMaxSize = (
216 | kQSmemMaxSize > kKSmemMaxSize ? kQSmemMaxSize : kKSmemMaxSize);
217 | constexpr int kQKVSmemMaxSize = (
218 | (kShareSmemQKV && kPersistQs2r) ?
219 | (kQKSmemMaxSize + kVSmemMaxSize) : // QK shared the same smem
220 | (kQSmemMaxSize + kKSmemMaxSize + kVSmemMaxSize)
221 | );
222 | return kQKVSmemMaxSize;
223 | } else {
224 | // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem.
225 | constexpr int Q_smem_size = ((kPersistQg2s ? (kHeadDim / kMmaAtomK) : kStageQK) *
226 | (Br * (kMmaAtomK + kPadQ))) * 2;
227 | constexpr int K_smem_size = ((kStageQK) * (Bc * (kMmaAtomK + kPadK))) * 2;
228 | constexpr int V_smem_size = (kStagePV * (Bc * (kMmaAtomN * 2 + kPadV))) * 2;
229 | constexpr int kQKSmemMaxSize = (Q_smem_size + K_smem_size);
230 | constexpr int kVSmemMaxSize = V_smem_size;
231 | // try to let V reuse all Q+K smem after Q@K^T, reduce smem usage.
232 | constexpr int kQKVSmemMaxSize = (
233 | (kShareSmemQKV && (!kPersistQg2s)) ?
234 | ((kQKSmemMaxSize > kVSmemMaxSize) ? kQKSmemMaxSize: kVSmemMaxSize) :
235 | (kQKSmemMaxSize + kVSmemMaxSize)
236 | );
237 | // NOTE: R_D registers usage, s=2, d=64, 16 regs; d=128, 32 regs;
238 | // d=256, 64 regs; d=512, 128 regs; d=1024, 256 regs;
239 | return kQKVSmemMaxSize;
240 | }
241 | #else
242 | // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem.
243 | constexpr int Q_smem_size = ((kPersistQg2s ? (kHeadDim / kMmaAtomK) : kStageQK) *
244 | (Br * (kMmaAtomK + kPadQ))) * 2;
245 | constexpr int K_smem_size = ((kStageQK) * (Bc * (kMmaAtomK + kPadK))) * 2;
246 | constexpr int V_smem_size = (kStagePV * (Bc * (kMmaAtomN * 2 + kPadV))) * 2;
247 | constexpr int kQKSmemMaxSize = (Q_smem_size + K_smem_size);
248 | constexpr int kVSmemMaxSize = V_smem_size;
249 | // try to let V reuse all Q+K smem after Q@K^T, reduce smem usage.
250 | constexpr int kQKVSmemMaxSize = (
251 | (kShareSmemQKV && (!kPersistQg2s)) ?
252 | ((kQKSmemMaxSize > kVSmemMaxSize) ? kQKSmemMaxSize: kVSmemMaxSize) :
253 | (kQKSmemMaxSize + kVSmemMaxSize)
254 | );
255 | // NOTE: R_D registers usage, s=2, d=64, 16 regs; d=128, 32 regs;
256 | // d=256, 64 regs; d=512, 128 regs; d=1024, 256 regs;
257 | return kQKVSmemMaxSize;
258 | #endif
259 | }
260 |
261 | template<
262 | const int kHeadDim, // Headdim, 32~1024
263 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
264 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
265 | const int kStage
266 | >
267 | void launch_ffpa_mma_L1_template(torch::Tensor Q,
268 | torch::Tensor K,
269 | torch::Tensor V,
270 | torch::Tensor O) {
271 | // Q,K,V,O with [B, H, N, D] layout, B=batch, H=head, N=seqlen, D=dim
272 | // TODO: support BNHD layout, Q,K,V,O with [B, N, H, D] layout.
273 | constexpr int kMmaAtomM = 16;
274 | constexpr int kMmaAtomN = 8;
275 | constexpr int kMmaAtomK = 16;
276 | // Split-Q(FA-2) Algo, Tile MMA across Q and keep KV access for all MMAs.
277 |   constexpr int kMmaTileSeqLenQ = getConfigMmaTileSeqLenQP<kHeadDim>();
278 | constexpr int kMmaTileSeqLenK = 1;
279 |   constexpr int kMmaTileSeqLenP = getConfigMmaTileSeqLenQP<kHeadDim>();
280 | constexpr int kMmaTileHeadDimV = 1;
281 | constexpr int kWarpTileSeqLenQ = 1;
282 |   constexpr int kWarpTileSeqLenK = getConfigWarpTileSeqLenK<kHeadDim>();
283 | constexpr int kWarpTileSeqLenP = 1;
284 |   constexpr int kWarpTileHeadDimV = getConfigWarpTileHeadDimV<kHeadDim>();
285 | constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ;
286 | constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK;
287 | static_assert(Br == Bc, "Br must be equal Bc to avoid illegal memory access.");
288 | constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK;
289 |   constexpr int kOStorageAccFloat32 = getConfigOStorageAccFloat32<kHeadDim>();
290 | // Apply different multi stages policy for QK and V.
291 | // TODO: tune stages for Q@K and P@V.
292 | constexpr int kStageQK = kStage; // <= 4
293 | constexpr int kStagePV = kStage; // <= 4
294 | // Prefetch QKV, Persist Q g2s/s2r, Shared QKV smem.
295 | constexpr int kShareSmemQKV = getConfigShareSmemQKV();
296 |   constexpr int kPrefetchQK = getConfigPrefetchQKV<kStageQK>();
297 |   constexpr int kPrefetchPV = getConfigPrefetchQKV<kStagePV>();
298 | constexpr int kPersistQs2r = getConfigPersistQs2r();
299 | constexpr int kPersistQg2s = getConfigPersistQg2s();
300 | constexpr int kRegPipeKV = getConfigRegistersPipeKV();
301 | // QKV smem swizzle, 0 for smem swizzle, !0 for smem padding.
302 | constexpr int kPadQ = getConfigPadQ();
303 | constexpr int kPadK = getConfigPadK();
304 | constexpr int kPadV = getConfigPadV();
305 | // Calculate SRAM size needed per block.
306 | constexpr int kQKVSmemMaxSize = getConfigQKVSmemMaxSize<
307 | Br, Bc, kMmaAtomM, kMmaAtomN, kMmaAtomK, kHeadDim, kShareSmemQKV,
308 | kPersistQg2s, kPersistQs2r, kStageQK, kStagePV, kPadQ, kPadK,
309 | kPadV
310 | >();
311 |
312 | const int QKV_batch = Q.size(0);
313 | const int QKV_head = Q.size(1);
314 | const int QKV_seqlen = Q.size(2); // QKV_seqlen
315 | assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)
316 |
317 | const dim3 block = getConfigBlock(); // 4/8 warps per block
318 | const dim3 grid = getConfigGrid<Br>(QKV_batch, QKV_head, QKV_seqlen);
319 | // Precompute softmax scale and Tc
320 | const int Tc = utils::div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d]
321 | const float scale = 1.0f / sqrt((float) kHeadDim);
322 |
323 | #define LAUNCH_TEMPLATE_FUNC(TEMPLATE_FUNC) \
324 | cudaFuncSetAttribute( \
325 | TEMPLATE_FUNC, \
326 | cudaFuncAttributeMaxDynamicSharedMemorySize, \
327 | kQKVSmemMaxSize \
328 | ); \
329 | TEMPLATE_FUNC<<<grid, block, kQKVSmemMaxSize>>>( \
330 | reinterpret_cast<half*>(Q.data_ptr()), \
331 | reinterpret_cast<half*>(K.data_ptr()), \
332 | reinterpret_cast<half*>(V.data_ptr()), \
333 | reinterpret_cast<half*>(O.data_ptr()), \
334 | QKV_seqlen, \
335 | QKV_head, \
336 | scale, \
337 | Tc \
338 | );
339 |
340 | #ifdef ENABLE_FFPA_PERSIST_KV_G2S
341 | if constexpr (kHeadDim <= kMaxDForSmallDKernel) { // e.g d > 128 will use the large d kernel
342 | constexpr int kPersistVs2r = getConfigPersistVs2r(); // only for d < 256
343 |
344 | auto ffpa_mma_L1_small_d_kernel_func = (
345 | ffpa_mma_stages_split_q_L1_small_d_template<
346 | kHeadDim,
347 | kMmaAtomM,
348 | kMmaAtomN,
349 | kMmaAtomK,
350 | kMmaTileSeqLenQ,
351 | kMmaTileSeqLenK,
352 | kMmaTileSeqLenP,
353 | kMmaTileHeadDimV,
354 | kWarpTileSeqLenQ,
355 | kWarpTileSeqLenK,
356 | kWarpTileSeqLenP,
357 | kWarpTileHeadDimV,
358 | kMmaAccFloat32QK,
359 | kMmaAccFloat32PV,
360 | kOStorageAccFloat32,
361 | kPrefetchQK,
362 | kPrefetchPV,
363 | kShareSmemQKV,
364 | kPersistQs2r,
365 | kPersistVs2r,
366 | // Force disable KV registers ping pong buffers
367 | // while V s2r is enabled.
368 | (kPersistVs2r) ? 0 : kRegPipeKV,
369 | 1, /*kStageQK unused*/
370 | 1, /*kStagePV unused*/
371 | kPadQ,
372 | kPadK,
373 | kPadV
374 | >
375 | );
376 | LAUNCH_TEMPLATE_FUNC(ffpa_mma_L1_small_d_kernel_func);
377 | } else { // large headdim > kMaxDForSmallDKernel (e.g 128)
378 | auto ffpa_mma_L1_large_d_kernel_func = (
379 | ffpa_mma_stages_split_q_L1_large_d_template<
380 | kHeadDim,
381 | kMmaAtomM,
382 | kMmaAtomN,
383 | kMmaAtomK,
384 | kMmaTileSeqLenQ,
385 | kMmaTileSeqLenK,
386 | kMmaTileSeqLenP,
387 | kMmaTileHeadDimV,
388 | kWarpTileSeqLenQ,
389 | kWarpTileSeqLenK,
390 | kWarpTileSeqLenP,
391 | kWarpTileHeadDimV,
392 | kMmaAccFloat32QK,
393 | kMmaAccFloat32PV,
394 | kOStorageAccFloat32,
395 | kPrefetchQK,
396 | kPrefetchPV,
397 | (kPersistQg2s) ? 0 : kShareSmemQKV,
398 | // Force disable Q s2r for d > 256; Q s2r for large d would
399 | // need too many registers and thus introduce performance drops.
400 | (kPersistQg2s || kHeadDim > 256) ? 0 : kPersistQs2r,
401 | kPersistQg2s,
402 | kRegPipeKV,
403 | kStageQK,
404 | kStagePV,
405 | kPadQ,
406 | kPadK,
407 | kPadV
408 | >
409 | );
410 | LAUNCH_TEMPLATE_FUNC(ffpa_mma_L1_large_d_kernel_func);
411 | }
412 | #else
413 | auto ffpa_mma_L1_large_d_kernel_func = (
414 | ffpa_mma_stages_split_q_L1_large_d_template<
415 | kHeadDim,
416 | kMmaAtomM,
417 | kMmaAtomN,
418 | kMmaAtomK,
419 | kMmaTileSeqLenQ,
420 | kMmaTileSeqLenK,
421 | kMmaTileSeqLenP,
422 | kMmaTileHeadDimV,
423 | kWarpTileSeqLenQ,
424 | kWarpTileSeqLenK,
425 | kWarpTileSeqLenP,
426 | kWarpTileHeadDimV,
427 | kMmaAccFloat32QK,
428 | kMmaAccFloat32PV,
429 | kOStorageAccFloat32,
430 | kPrefetchQK,
431 | kPrefetchPV,
432 | (kPersistQg2s) ? 0 : kShareSmemQKV,
433 | // Force disable Q s2r for d > 256; Q s2r for large d would
434 | // need too many registers and thus introduce performance drops.
435 | (kPersistQg2s || kHeadDim > 256) ? 0 : kPersistQs2r,
436 | kPersistQg2s,
437 | kRegPipeKV,
438 | kStageQK,
439 | kStagePV,
440 | kPadQ,
441 | kPadK,
442 | kPadV
443 | >
444 | );
445 | LAUNCH_TEMPLATE_FUNC(ffpa_mma_L1_large_d_kernel_func);
446 | #endif
447 |
448 | #undef LAUNCH_TEMPLATE_FUNC
449 | }
450 |
451 | // dispatch headdim
452 | #define LAUNCHER_L1(D, S) \
453 | case D: \
454 | launch_ffpa_mma_L1_template< \
455 | (D), \
456 | kMmaAccFloat32QK, \
457 | kMmaAccFloat32PV, \
458 | (S) \
459 | >(Q, K, V, O); \
460 | break;
461 |
462 | #ifdef ENABLE_FFPA_DEBUG
463 | // minimal kernels for debug mode
464 | #define DISPATCH_HEADDIM(LAUNCHER, S) \
465 | { \
466 | switch (d) \
467 | { \
468 | LAUNCHER(32, (S)); \
469 | LAUNCHER(64, (S)); \
470 | LAUNCHER(128, (S)); \
471 | LAUNCHER(256, (S)); \
472 | LAUNCHER(320, (S)); \
473 | LAUNCHER(512, (S)); \
474 | LAUNCHER(1024, (S)); \
475 | default: \
476 | throw std::runtime_error( \
477 | "headdim not support!"); \
478 | break; \
479 | } \
480 | }
481 |
482 | #else
483 | #ifdef ENABLE_FFPA_ALL_HEADDIM
484 | // multiple of 32
485 | #define DISPATCH_HEADDIM(LAUNCHER, S) \
486 | { \
487 | switch (d) \
488 | { \
489 | LAUNCHER(32, (S)); \
490 | LAUNCHER(64, (S)); \
491 | LAUNCHER(96, (S)); \
492 | LAUNCHER(128, (S)); \
493 | LAUNCHER(160, (S)); \
494 | LAUNCHER(192, (S)); \
495 | LAUNCHER(224, (S)); \
496 | LAUNCHER(256, (S)); \
497 | LAUNCHER(288, (S)); \
498 | LAUNCHER(320, (S)); \
499 | LAUNCHER(352, (S)); \
500 | LAUNCHER(384, (S)); \
501 | LAUNCHER(416, (S)); \
502 | LAUNCHER(448, (S)); \
503 | LAUNCHER(480, (S)); \
504 | LAUNCHER(512, (S)); \
505 | LAUNCHER(544, (S)); \
506 | LAUNCHER(576, (S)); \
507 | LAUNCHER(608, (S)); \
508 | LAUNCHER(640, (S)); \
509 | LAUNCHER(672, (S)); \
510 | LAUNCHER(704, (S)); \
511 | LAUNCHER(736, (S)); \
512 | LAUNCHER(768, (S)); \
513 | LAUNCHER(800, (S)); \
514 | LAUNCHER(832, (S)); \
515 | LAUNCHER(864, (S)); \
516 | LAUNCHER(896, (S)); \
517 | LAUNCHER(928, (S)); \
518 | LAUNCHER(960, (S)); \
519 | LAUNCHER(992, (S)); \
520 | LAUNCHER(1024, (S)); \
521 | default: \
522 | throw std::runtime_error( \
523 | "headdim not support!"); \
524 | break; \
525 | } \
526 | }
527 | #else
528 | // multiple of 64
529 | #define DISPATCH_HEADDIM(LAUNCHER, S) \
530 | { \
531 | switch (d) \
532 | { \
533 | LAUNCHER(256, (S)); \
534 | LAUNCHER(320, (S)); \
535 | LAUNCHER(384, (S)); \
536 | LAUNCHER(448, (S)); \
537 | LAUNCHER(512, (S)); \
538 | LAUNCHER(576, (S)); \
539 | LAUNCHER(640, (S)); \
540 | LAUNCHER(704, (S)); \
541 | LAUNCHER(768, (S)); \
542 | LAUNCHER(832, (S)); \
543 | LAUNCHER(896, (S)); \
544 | LAUNCHER(960, (S)); \
545 | LAUNCHER(1024, (S)); \
546 | default: \
547 | throw std::runtime_error( \
548 | "headdim not support!"); \
549 | break; \
550 | } \
551 | }
552 | #endif
553 |
554 | #endif
555 |
--------------------------------------------------------------------------------
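The shared-memory budget computed by getConfigQKVSmemMaxSize above is plain arithmetic over the tile shapes. A minimal Python sketch of that arithmetic, assuming the common Br = Bc = 64 tiling, zero padding, and half-precision (2-byte) elements; the function and variable names here are illustrative, not part of the library:

```python
# Minimal sketch of the byte arithmetic in getConfigQKVSmemMaxSize.
# Assumed config: Br = Bc = 64, kMmaAtomN = 8, kMmaAtomK = 16, kPadQ/K/V = 0.
def qkv_smem_max_size(head_dim: int, stage_qk: int = 2, stage_pv: int = 2,
                      share_smem_qkv: bool = False,
                      persist_q_g2s: bool = False) -> int:
    br = bc = 64
    mma_atom_n, mma_atom_k = 8, 16
    # kPersistQg2s keeps the whole Q row tile resident: headdim/kMmaAtomK stages.
    q_stages = (head_dim // mma_atom_k) if persist_q_g2s else stage_qk
    q_smem = q_stages * (br * mma_atom_k) * 2        # Q tiles, bytes
    k_smem = stage_qk * (bc * mma_atom_k) * 2        # K tiles, bytes
    v_smem = stage_pv * (bc * (mma_atom_n * 2)) * 2  # V tiles, bytes
    if share_smem_qkv and not persist_q_g2s:         # V reuses Q+K smem after Q@K^T
        return max(q_smem + k_smem, v_smem)
    return q_smem + k_smem + v_smem

for d in (64, 256, 1024):
    print(d, qkv_smem_max_size(d), qkv_smem_max_size(d, share_smem_qkv=True))
# Prints the same sizes (12288 / 8192 bytes) for every d: smem stays O(1) in
# headdim unless kPersistQg2s is enabled, which makes Q smem scale with d.
```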
/csrc/extension/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | outupt
21 | *.egg-info
--------------------------------------------------------------------------------
/csrc/extension/fused_mla_F16F16F16_L1.cu:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/csrc/extension/fused_mla_F16F16F32_L1.cu:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/csrc/extension/fused_mla_templates_L1.cuh:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/csrc/extension/launch_templates.cuh:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/csrc/pybind/ffpa_attn_api.cc:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <torch/types.h>
3 |
4 | #define STRINGFY(str) #str
5 | #define TORCH_BINDING_COMMON_EXTENSION(func) \
6 | m.def(STRINGFY(func), &func, STRINGFY(func));
7 |
8 | void ffpa_mma_acc_f16_L1(torch::Tensor Q, torch::Tensor K, torch::Tensor V,
9 | torch::Tensor O, int stages);
10 |
11 | void ffpa_mma_acc_f32_L1(torch::Tensor Q, torch::Tensor K, torch::Tensor V,
12 | torch::Tensor O, int stages);
13 |
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 | TORCH_BINDING_COMMON_EXTENSION(ffpa_mma_acc_f16_L1)
16 | TORCH_BINDING_COMMON_EXTENSION(ffpa_mma_acc_f32_L1)
17 | }
18 |
--------------------------------------------------------------------------------
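Once the extension is compiled, the two symbols bound above are callable directly from Python with the signature declared here (ffpa_attn/interface.py below wraps them). A minimal sketch, with illustrative tensor sizes:

```python
import torch
from pyffpa_cuda import ffpa_mma_acc_f32_L1  # the compiled extension module

B, H, N, D = 1, 8, 1024, 320  # [B, H, N, D] layout; sizes are illustrative
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
           for _ in range(3))
o = torch.zeros_like(q)             # O is written in place by the kernel
ffpa_mma_acc_f32_L1(q, k, v, o, 2)  # last argument: number of stages
```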
/csrc/pybind/fused_mla_api.cc:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 |
6 | class ENV(object):
7 | # ENVs for FFPA kernels compiling
8 |
9 | # Project dir, path to faster-prefill-attention
10 | PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
11 |
12 | # Enable debug mode for FFPA, fast build minimal kernels, default False.
13 | ENABLE_FFPA_DEBUG = bool(int(os.environ.get("ENABLE_FFPA_DEBUG", 0)))
14 |
15 | # Enable build FFPA kernels for Ada devices (sm89, L20, 4090, etc),
16 | # default True.
17 | ENABLE_FFPA_ADA = bool(int(os.environ.get("ENABLE_FFPA_ADA", 1)))
18 |
19 | # Enable build FFPA kernels for Ampere devices (sm80, A30, A100, etc),
20 | # default True.
21 | ENABLE_FFPA_AMPERE = bool(int(os.environ.get("ENABLE_FFPA_AMPERE", 1)))
22 |
23 | # Enable build FFPA kernels for Hopper devices (sm90, H100, H20, etc),
24 | # default False.
25 | ENABLE_FFPA_HOPPER = bool(int(os.environ.get("ENABLE_FFPA_HOPPER", 0)))
26 |
27 | # Enable all multi-stage kernels or not: if True, stages 1~4; else 1~2. Default True.
28 | ENABLE_FFPA_ALL_STAGES = bool(int(os.environ.get("ENABLE_FFPA_ALL_STAGES", 1)))
29 |
30 | # Enable all headdims for FFPA kernels or not, default False.
31 | # True, headdim will range from 32 to 1024 with step = 32, i.e. 32,64,...,1024
32 | # False, headdim will range from 256 to 1024 with step = 64, i.e. 256,320,...,1024
33 | ENABLE_FFPA_ALL_HEADDIM = bool(int(os.environ.get("ENABLE_FFPA_ALL_HEADDIM", 0)))
34 |
35 | # Enable force Q@K^T use fp16 as MMA Acc dtype for FFPA Acc F32 kernels, default False.
36 | # FFPA Acc F32 kernels MMA Acc = Mixed Q@K^T MMA Acc F16 + P@V MMA Acc F32.
37 | ENABLE_FFPA_FORCE_QK_F16 = bool(int(os.environ.get("ENABLE_FFPA_FORCE_QK_F16", 0)))
38 |
39 | # Enable force P@V use fp16 as MMA Acc dtype, for FFPA Acc F32 kernels, default False.
40 | # FFPA Acc F32 kernels MMA Acc = Mixed Q@K^T MMA Acc F32 + P@V MMA Acc F16.
41 | ENABLE_FFPA_FORCE_PV_F16 = bool(int(os.environ.get("ENABLE_FFPA_FORCE_PV_F16", 0)))
42 |
43 | # Enable FFPA Prefetch QKV at the Appropriate Time Point, default True, boosts 5%~10%.
44 | ENABLE_FFPA_PREFETCH_QKV = bool(int(os.environ.get("ENABLE_FFPA_PREFETCH_QKV", 1)))
45 |
46 | # Enable QKV smem shared policy, default False (preferred for MMA & g2s overlap).
47 | # Please, set it as True if you want to run FFPA on low SRAM device.
48 | ENABLE_FFPA_QKV_SMEM_SHARE = bool(
49 | int(os.environ.get("ENABLE_FFPA_QKV_SMEM_SHARE", 0))
50 | )
51 |
52 | # Enable smem swizzle for Q, default True. True: bank conflicts free for Q smem
53 | # via swizzle; False: bank conflicts free for Q smem via padding.
54 | ENABLE_FFPA_SMEM_SWIZZLE_Q = bool(
55 | int(os.environ.get("ENABLE_FFPA_SMEM_SWIZZLE_Q", 1))
56 | )
57 |
58 | # Enable smem swizzle for K, default True. True: bank conflicts free for K smem
59 | # via swizzle; False: bank conflicts free for K smem via padding.
60 | ENABLE_FFPA_SMEM_SWIZZLE_K = bool(
61 | int(os.environ.get("ENABLE_FFPA_SMEM_SWIZZLE_K", 1))
62 | )
63 |
64 | # Enable smem swizzle for V, now default True. True: bank conflicts free for V smem
65 | # via swizzle; False: bank conflicts free for V smem via padding. NOTE(DefTruth):
66 | # swizzled V initially could not reach good performance; that issue has been
67 | # fixed, so it is now enabled by default.
68 | ENABLE_FFPA_SMEM_SWIZZLE_V = bool(
69 | int(os.environ.get("ENABLE_FFPA_SMEM_SWIZZLE_V", 1))
70 | )
71 |
72 | # Persist load Q g2s for headdim <= 320; more SRAM, but register usage stays unchanged.
73 | ENABLE_FFPA_PERSIST_Q_G2S = bool(
74 | int(os.environ.get("ENABLE_FFPA_PERSIST_Q_G2S", 0))
75 | )
76 |
77 | # Persist load KV g2s for headdim <= 256, more SRAM. If True, auto use the flash-attn
78 | # algo (tiling at the attention level) for headdim <= 256 and the ffpa-attn
79 | # fine-grained tiling at the MMA level for headdim > 256.
80 | ENABLE_FFPA_PERSIST_KV_G2S = bool(
81 | int(os.environ.get("ENABLE_FFPA_PERSIST_KV_G2S", 0))
82 | )
83 |
84 | # Persist load Q s2r for headdim < 512 to reduce Q g2s and s2r IO access,
85 | # while still keeping O(1) SRAM complexity. Default False. This option
86 | # introduces more registers for Q frags as the headdim grows, so choose
87 | # whether to enable it according to the balance between register usage
88 | # and IO access reduction.
89 | ENABLE_FFPA_PERSIST_Q_S2R = bool(
90 | int(os.environ.get("ENABLE_FFPA_PERSIST_Q_S2R", 0))
91 | )
92 |
93 | # Persist V s2r only for small d kernel, more registers.
94 | ENABLE_FFPA_PERSIST_V_S2R = bool(
95 | int(os.environ.get("ENABLE_FFPA_PERSIST_V_S2R", ENABLE_FFPA_PERSIST_KV_G2S))
96 | )
97 |
98 | # Registers Ping pong double buffers for ldmatrix & mma computation overlapping.
99 | ENABLE_FFPA_REGISTERS_PIPE_KV = bool(
100 | int(os.environ.get("ENABLE_FFPA_REGISTERS_PIPE_KV", 0))
101 | )
102 |
103 | # if True: grid(N/Br, H, B) else: grid(N/Br, B * H)
104 | ENABLE_FFPA_LAUNCH_GRID_DNHB = bool(
105 | int(os.environ.get("ENABLE_FFPA_LAUNCH_GRID_DNHB", 0))
106 | )
107 |
108 | @classmethod
109 | def project_dir(cls):
110 | return cls.PROJECT_DIR
111 |
112 | @classmethod
113 | def enable_debug(cls):
114 | return cls.ENABLE_FFPA_DEBUG
115 |
116 | @classmethod
117 | def enable_ada(cls):
118 | return cls.ENABLE_FFPA_ADA
119 |
120 | @classmethod
121 | def enable_ampere(cls):
122 | return cls.ENABLE_FFPA_AMPERE
123 |
124 | @classmethod
125 | def enable_hopper(cls):
126 | return cls.ENABLE_FFPA_HOPPER
127 |
128 | @classmethod
129 | def enable_all_mutistages(cls):
130 | return cls.ENABLE_FFPA_ALL_STAGES
131 |
132 | @classmethod
133 | def enable_all_headdim(cls):
134 | return cls.ENABLE_FFPA_ALL_HEADDIM
135 |
136 | @classmethod
137 | def enable_force_pv_fp16(cls):
138 | return cls.ENABLE_FFPA_FORCE_PV_F16
139 |
140 | @classmethod
141 | def enable_force_qk_fp16(cls):
142 | return cls.ENABLE_FFPA_FORCE_QK_F16
143 |
144 | @classmethod
145 | def enable_prefetch_qkv(cls):
146 | return cls.ENABLE_FFPA_PREFETCH_QKV
147 |
148 | @classmethod
149 | def enable_qkv_smem_share(cls):
150 | return cls.ENABLE_FFPA_QKV_SMEM_SHARE
151 |
152 | @classmethod
153 | def enable_smem_swizzle_q(cls):
154 | return cls.ENABLE_FFPA_SMEM_SWIZZLE_Q
155 |
156 | @classmethod
157 | def enable_smem_swizzle_k(cls):
158 | return cls.ENABLE_FFPA_SMEM_SWIZZLE_K
159 |
160 | @classmethod
161 | def enable_smem_swizzle_v(cls):
162 | return cls.ENABLE_FFPA_SMEM_SWIZZLE_V
163 |
164 | @classmethod
165 | def enable_persist_q_g2s(cls):
166 | return cls.ENABLE_FFPA_PERSIST_Q_G2S
167 |
168 | @classmethod
169 | def enable_persist_kv_g2s(cls):
170 | return cls.ENABLE_FFPA_PERSIST_KV_G2S
171 |
172 | @classmethod
173 | def enable_persist_q_s2r(cls):
174 | return cls.ENABLE_FFPA_PERSIST_Q_S2R
175 |
176 | @classmethod
177 | def enable_persist_v_s2r(cls):
178 | if cls.enable_persist_kv_g2s():
179 | return cls.ENABLE_FFPA_PERSIST_V_S2R
180 | return False
181 |
182 | @classmethod
183 | def enable_registers_pipe_kv(cls):
184 | return cls.ENABLE_FFPA_REGISTERS_PIPE_KV
185 |
186 | @classmethod
187 | def enable_launch_grid_dnhb(cls):
188 | return cls.ENABLE_FFPA_LAUNCH_GRID_DNHB
189 |
190 | @classmethod
191 | def env_cuda_cflags(cls):
192 | extra_env_cflags = []
193 | if cls.enable_debug():
194 | extra_env_cflags.append("-DENABLE_FFPA_DEBUG")
195 | if cls.enable_all_mutistages():
196 | extra_env_cflags.append("-DENABLE_FFPA_ALL_STAGES")
197 | if cls.enable_all_headdim():
198 | extra_env_cflags.append("-DENABLE_FFPA_ALL_HEADDIM")
199 | if cls.enable_force_qk_fp16():
200 | extra_env_cflags.append("-DENABLE_FFPA_FORCE_QK_F16")
201 | if cls.enable_force_pv_fp16():
202 | extra_env_cflags.append("-DENABLE_FFPA_FORCE_PV_F16")
203 | if cls.enable_prefetch_qkv():
204 | extra_env_cflags.append("-DENABLE_FFPA_PREFETCH_QKV")
205 | if cls.enable_qkv_smem_share():
206 | extra_env_cflags.append("-DENABLE_FFPA_QKV_SMEM_SHARE")
207 | if cls.enable_smem_swizzle_q():
208 | extra_env_cflags.append("-DENABLE_FFPA_SMEM_SWIZZLE_Q")
209 | if cls.enable_smem_swizzle_k():
210 | extra_env_cflags.append("-DENABLE_FFPA_SMEM_SWIZZLE_K")
211 | if cls.enable_smem_swizzle_v():
212 | extra_env_cflags.append("-DENABLE_FFPA_SMEM_SWIZZLE_V")
213 | if cls.enable_persist_q_g2s():
214 | extra_env_cflags.append("-DENABLE_FFPA_PERSIST_Q_G2S")
215 | if cls.enable_persist_kv_g2s():
216 | extra_env_cflags.append("-DENABLE_FFPA_PERSIST_KV_G2S")
217 | if cls.enable_persist_q_s2r():
218 | extra_env_cflags.append("-DENABLE_FFPA_PERSIST_Q_S2R")
219 | if cls.enable_persist_v_s2r():
220 | extra_env_cflags.append("-DENABLE_FFPA_PERSIST_V_S2R")
221 | if cls.enable_registers_pipe_kv():
222 | extra_env_cflags.append("-DENABLE_FFPA_REGISTERS_PIPE_KV")
223 | if cls.enable_launch_grid_dnhb():
224 | extra_env_cflags.append("-DENBALE_FFPA_LAUNCH_GRID_DNHB")
225 |
226 | if cls.enable_persist_kv_g2s():
227 | assert (
228 | cls.enable_persist_q_g2s()
229 | ), "PERSIST_Q_G2S must be enable if PERSIST_KV_G2S is enabled."
230 | if cls.enable_qkv_smem_share():
231 | assert (
232 | cls.enable_persist_q_s2r()
233 | ), "PERSIST_Q_S2R must be enable if QKV_SMEM_SHARE and "
234 | "PERSIST_KV_G2S are enabled."
235 | else:
236 | assert not all(
237 | (cls.enable_persist_q_s2r(), cls.enable_persist_q_g2s())
238 | ), "PERSIST_Q_G2S and PERSIST_Q_S2R can not both enabled."
239 | assert not all(
240 | (cls.enable_qkv_smem_share(), cls.enable_persist_q_g2s())
241 | ), "PERSIST_Q_G2S and QKV_SMEM_SHARE can not both enabled."
242 | assert not all(
243 | (cls.enable_qkv_smem_share(), cls.enable_persist_kv_g2s())
244 | ), "PERSIST_KV_G2S and QKV_SMEM_SHARE can not both enabled."
245 | return extra_env_cflags
246 |
247 | @classmethod
248 | def list_ffpa_env(cls):
249 | def formatenv(name, value):
250 | try:
251 | print(
252 | f"{name:<30}: {str(value):<5} -> command:"
253 | f" export {name}={int(value)}"
254 | )
255 | except Exception:
256 | print(f"{name:<30}: {value}")
257 |
258 | pretty_print_line("FFPA-ATTN ENVs")
259 | formatenv("PROJECT_DIR", cls.project_dir())
260 | formatenv("ENABLE_FFPA_DEBUG", cls.enable_debug())
261 | formatenv("ENABLE_FFPA_ADA", cls.enable_ada())
262 | formatenv("ENABLE_FFPA_AMPERE", cls.enable_ampere())
263 | formatenv("ENABLE_FFPA_HOPPER", cls.enable_hopper())
264 | formatenv("ENABLE_FFPA_ALL_STAGES", cls.enable_all_mutistages())
265 | formatenv("ENABLE_FFPA_ALL_HEADDIM", cls.enable_all_headdim())
266 | formatenv("ENABLE_FFPA_PREFETCH_QKV", cls.enable_prefetch_qkv())
267 | formatenv("ENABLE_FFPA_FORCE_QK_F16", cls.enable_force_qk_fp16())
268 | formatenv("ENABLE_FFPA_FORCE_PV_F16", cls.enable_force_pv_fp16())
269 | formatenv("ENABLE_FFPA_PERSIST_Q_G2S", cls.enable_persist_q_g2s())
270 | formatenv("ENABLE_FFPA_PERSIST_KV_G2S", cls.enable_persist_kv_g2s())
271 | formatenv("ENABLE_FFPA_PERSIST_Q_S2R", cls.enable_persist_q_s2r())
272 | formatenv("ENABLE_FFPA_PERSIST_V_S2R", cls.enable_persist_v_s2r())
273 | formatenv("ENABLE_FFPA_QKV_SMEM_SHARE", cls.enable_qkv_smem_share())
274 | formatenv("ENABLE_FFPA_SMEM_SWIZZLE_Q", cls.enable_smem_swizzle_q())
275 | formatenv("ENABLE_FFPA_SMEM_SWIZZLE_K", cls.enable_smem_swizzle_k())
276 | formatenv("ENABLE_FFPA_SMEM_SWIZZLE_V", cls.enable_smem_swizzle_v())
277 | formatenv("ENABLE_FFPA_REGISTERS_PIPE_KV", cls.enable_registers_pipe_kv())
278 | formatenv("ENABLE_FFPA_LAUNCH_GRID_DNHB", cls.enable_launch_grid_dnhb())
279 | pretty_print_line()
280 |
281 | @staticmethod
282 | def get_device_name():
283 | device_name = torch.cuda.get_device_name(torch.cuda.current_device())
284 | # Since we run laptop GPUs under WSL2, add a WSL2 tag.
285 | if "Laptop" in device_name:
286 | device_name += " WSL2"
287 | return device_name
288 |
289 | @staticmethod
290 | def get_device_capability():
291 | return torch.cuda.get_device_capability(torch.cuda.current_device())
292 |
293 | @staticmethod
294 | def get_build_sources(build_pkg: bool = False):
295 | def csrc(sub_dir, filename):
296 | csrc_file = f"{ENV.project_dir()}/csrc/{sub_dir}/{filename}"
297 | if ENV.enable_debug() or build_pkg:
298 | pretty_print_line(f"csrc_file: {csrc_file}", sep="", mode="left")
299 | return csrc_file
300 |
301 | if ENV.enable_debug() or build_pkg:
302 | pretty_print_line()
303 | build_sources = [
304 | csrc("pybind", "ffpa_attn_api.cc"),
305 | csrc("cuffpa", "ffpa_attn_F16F16F16_L1.cu"),
306 | csrc("cuffpa", "ffpa_attn_F16F16F32_L1.cu"),
307 | ]
308 | if ENV.enable_debug() or build_pkg:
309 | pretty_print_line()
310 | return build_sources
311 |
312 | @staticmethod
313 | def get_build_cuda_cflags(build_pkg: bool = False):
314 | device_name = ENV.get_device_name()
315 | extra_cuda_cflags = []
316 | extra_cuda_cflags.append("-O3")
317 | extra_cuda_cflags.append("-std=c++17")
318 | extra_cuda_cflags.append("-U__CUDA_NO_HALF_OPERATORS__")
319 | extra_cuda_cflags.append("-U__CUDA_NO_HALF_CONVERSIONS__")
320 | extra_cuda_cflags.append("-U__CUDA_NO_HALF2_OPERATORS__")
321 | extra_cuda_cflags.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
322 | extra_cuda_cflags.append("--expt-relaxed-constexpr")
323 | extra_cuda_cflags.append("--expt-extended-lambda")
324 | extra_cuda_cflags.append("--use_fast_math")
325 | extra_cuda_cflags.append(
326 | "-diag-suppress 177" if not build_pkg else "--ptxas-options=-v"
327 | )
328 | extra_cuda_cflags.append(
329 | "-Xptxas -v" if not build_pkg else "--ptxas-options=-O3"
330 | )
331 | extra_cuda_cflags.append(
332 | "-DBUILD_FFPA_ATTN_MMA_L20" if "L20" in device_name else ""
333 | )
334 | extra_cuda_cflags.append(
335 | "-DBUILD_FFPA_ATTN_MMA_4090" if "4090" in device_name else ""
336 | )
337 | extra_cuda_cflags.append(
338 | "-DBUILD_FFPA_ATTN_MMA_3080" if "3080" in device_name else ""
339 | )
340 | extra_cuda_cflags.extend(ENV.env_cuda_cflags())
341 | extra_cuda_cflags.append(f"-I {ENV.project_dir()}/include")
342 | extra_cuda_cflags.append(f"-I {ENV.project_dir()}/csrc/cuffpa")
343 | return extra_cuda_cflags
344 |
345 | @staticmethod
346 | def get_build_cflags():
347 | extra_cflags = []
348 | extra_cflags.append("-std=c++17")
349 | return extra_cflags
350 |
351 | @staticmethod
352 | def get_cuda_bare_metal_version(cuda_dir):
353 | # helper function to get cuda version
354 | import subprocess
355 |
356 | from packaging.version import parse
357 |
358 | raw_output = subprocess.check_output(
359 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
360 | )
361 | output = raw_output.split()
362 | release_idx = output.index("release") + 1
363 | bare_metal_version = parse(output[release_idx].split(",")[0])
364 |
365 | return raw_output, bare_metal_version
366 |
367 | @staticmethod
368 | def build_ffpa_from_sources(verbose: bool = False):
369 | from torch.utils.cpp_extension import load
370 |
371 | torch_arch_list_env = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
372 | # Load the CUDA kernel as a python module
373 | pretty_print_line(
374 | f"Loading ffpa_attn lib on device: {ENV.get_device_name()}, "
375 | f"capability: {ENV.get_device_capability()}, "
376 | f"Arch ENV: {torch_arch_list_env}"
377 | )
378 | return load(
379 | name="pyffpa_cuda",
380 | sources=ENV.get_build_sources(),
381 | extra_cuda_cflags=ENV.get_build_cuda_cflags(),
382 | extra_cflags=ENV.get_build_cflags(),
383 | verbose=verbose,
384 | )
385 |
386 | @staticmethod
387 | def try_load_ffpa_library(force_build: bool = False, verbose: bool = False):
388 | use_ffpa_attn_package = False
389 | if not force_build:
390 | # check if can import ffpa_attn
391 | try:
392 | import ffpa_attn
393 |
394 | pretty_print_line("Import ffpa_attn library done, use it!")
395 | use_ffpa_attn_package = True
396 | return ffpa_attn, use_ffpa_attn_package
397 | except Exception:
398 | pretty_print_line("Can't import ffpa_attn, force build from sources")
399 | pretty_print_line(
400 | "Also may need export LD_LIBRARY_PATH="
401 | "PATH-TO/torch/lib:$LD_LIBRARY_PATH"
402 | )
403 | ffpa_attn = ENV.build_ffpa_from_sources(verbose=verbose)
404 | use_ffpa_attn_package = False
405 | return ffpa_attn, use_ffpa_attn_package
406 | else:
407 | pretty_print_line("Force ffpa_attn lib build from sources")
408 | ffpa_attn = ENV.build_ffpa_from_sources(verbose=verbose)
409 | use_ffpa_attn_package = False
410 | return ffpa_attn, use_ffpa_attn_package
411 |
412 |
413 | def pretty_print_line(
414 | m: str = "", sep: str = "-", mode: str = "center", width: int = 150
415 | ):
416 | res_len = width - len(m)
417 | if mode == "center":
418 | left_len = int(res_len / 2)
419 | right_len = res_len - left_len
420 | pretty_line = sep * left_len + m + sep * right_len
421 | elif mode == "left":
422 | pretty_line = m + sep * res_len
423 | else:
424 | pretty_line = sep * res_len + m
425 | print(pretty_line)
426 |
427 |
428 | if __name__ == "__main__":
429 | # Debug: show FFPA ENV information. run: python3 env.py
430 | ENV.list_ffpa_env()
431 |
--------------------------------------------------------------------------------
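Because the ENV class body reads its flags when env.py is imported, the environment variables must be set first. A minimal sketch of inspecting which -D macros a given flag combination produces (the printed list is indicative, not exhaustive):

```python
import os

# ENV reads its flags when the class body executes, so set vars before import.
os.environ["ENABLE_FFPA_PERSIST_Q_G2S"] = "1"
os.environ["ENABLE_FFPA_PERSIST_KV_G2S"] = "1"  # requires Q_G2S, per the assert

from env import ENV

print(ENV.env_cuda_cflags())
# indicative: [..., '-DENABLE_FFPA_PERSIST_Q_G2S', '-DENABLE_FFPA_PERSIST_KV_G2S']
ENV.list_ffpa_env()  # pretty-print every FFPA ENV with its export command
```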
/ffpa_attn/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | outupt
21 | .egg-info
22 | dist
23 |
--------------------------------------------------------------------------------
/ffpa_attn/__init__.py:
--------------------------------------------------------------------------------
1 | from .interface import (
2 | faster_prefill_attn_func,
3 | ffpa,
4 | ffpa_acc_f16_L1,
5 | ffpa_acc_f32_L1,
6 | ffpa_mma_acc_f16_L1,
7 | ffpa_mma_acc_f32_L1,
8 | LevelType,
9 | MMAAccType,
10 | )
11 | from .version import __version__
12 |
13 |
14 | # e.g pyffpa.L1
15 | L1 = LevelType.L1
16 | L2 = LevelType.L2
17 | L3 = LevelType.L3
18 | # e.g pyffpa.FP32
19 | FP32 = MMAAccType.FP32
20 | FP16 = MMAAccType.FP16
21 |
--------------------------------------------------------------------------------
/ffpa_attn/interface.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from functools import partial
3 | from typing import Optional
4 |
5 | import torch
6 |
7 | # from pyffpa_cuda.cpython.*.so
8 | from pyffpa_cuda import ffpa_mma_acc_f16_L1, ffpa_mma_acc_f32_L1
9 |
10 |
11 | class LevelType(Enum):
12 | L1 = 0
13 | L2 = 1
14 | L3 = 2
15 |
16 |
17 | class MMAAccType(Enum):
18 | FP32 = 0
19 | FP16 = 1
20 |
21 |
22 | def faster_prefill_attn_func(
23 | q: torch.Tensor,
24 | k: torch.Tensor,
25 | v: torch.Tensor,
26 | o: Optional[torch.Tensor] = None,
27 | num_stages: int = 2,
28 | level: LevelType = LevelType.L1,
29 | acc: MMAAccType = MMAAccType.FP32,
30 | ):
31 | # Q, K, V, O: [B, H, N, D] layout
32 | if o is None or not isinstance(o, torch.Tensor):
33 | o = torch.zeros_like(q)
34 | assert level == LevelType.L1, "only support FFPA L1 level now."
35 | if acc == MMAAccType.FP32:
36 | ffpa_mma_acc_f32_L1(q, k, v, o, num_stages)
37 | else:
38 | ffpa_mma_acc_f16_L1(q, k, v, o, num_stages)
39 | return o
40 |
41 |
42 | ffpa: callable = faster_prefill_attn_func
43 | ffpa_acc_f32_L1 = partial(
44 | faster_prefill_attn_func, level=LevelType.L1, acc=MMAAccType.FP32
45 | )
46 | ffpa_acc_f16_L1 = partial(
47 | faster_prefill_attn_func, level=LevelType.L1, acc=MMAAccType.FP16
48 | )
49 |
--------------------------------------------------------------------------------
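End to end, the wrapper above is used as follows; a minimal sketch with illustrative shapes (inputs are half precision in [B, H, N, D] layout, and the seqlen must be a multiple of the block tile, per the launch-side assert):

```python
import torch
from ffpa_attn import ffpa, L1, FP32

B, H, N, D = 1, 8, 1024, 512  # illustrative sizes; N a multiple of max(Br, Bc)
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
           for _ in range(3))
o = ffpa(q, k, v, num_stages=2, level=L1, acc=FP32)  # O allocated when not passed
```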
/ffpa_attn/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.2" # type: ignore
2 |
--------------------------------------------------------------------------------
/include/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | outupt
21 | *.egg-info
22 | dist
23 |
--------------------------------------------------------------------------------
/include/cuffpa/cp_async.cuh:
--------------------------------------------------------------------------------
1 | // cp async operations
2 | #pragma once
3 | #include <cuda_bf16.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_fp8.h>
6 | #include <cuda_runtime.h>
7 | #include <float.h>
8 | #include <mma.h>
9 | #include <stdio.h>
10 | #include <stdlib.h>
11 | #include <torch/extension.h>
12 | #include <torch/types.h>
13 | #include <vector>
14 |
15 | namespace ffpa {
16 | namespace cp_async {
17 |
18 | // Simple wrappers for cp.async/ld/st instructions.
19 | __device__ __forceinline__ void commit_group() {
20 | asm volatile("cp.async.commit_group;\n" ::);
21 | }
22 |
23 | // e.g: wait_group<1>();
24 | template<const int n>
25 | __device__ __forceinline__ void wait_group() {
26 | asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
27 | }
28 |
29 | // e.g: cp_async(smem_ptr, gmem_ptr);
30 | template<const int kBytes = 16, typename T>
31 | __device__ __forceinline__ void cp_async(
32 | uint32_t smem_ptr, const T* gmem_ptr) {
33 | static_assert(kBytes == 16 || kBytes == 8); // 8 or 4 halfs
34 | if constexpr (kBytes == 16) {
35 | asm volatile(
36 | "cp.async.cg.shared.global.L2::128B "
37 | "[%0], [%1], %2, %3;\n"
38 | ::"r"(smem_ptr), "l"(gmem_ptr),
39 | "n"(16), "r"(16)
40 | );
41 | } else {
42 | asm volatile(
43 | "cp.async.ca.shared.global.L2::128B "
44 | "[%0], [%1], %2, %3;\n"
45 | ::"r"(smem_ptr), "l"(gmem_ptr),
46 | "n"(8), "r"(8)
47 | );
48 | }
49 | }
50 |
51 | // e.g ldg_sync_128b(...);
52 | template<typename T0, typename T1>
53 | __device__ __forceinline__ void ldg_sync_128b(
54 | T0 * mem_dst_ptr, T1 * gmem_src_ptr) {
55 | using _128b_t = uint4;
56 | _128b_t * dst_128b_ptr = reinterpret_cast<_128b_t*>(
57 | mem_dst_ptr);
58 | _128b_t * src_128b_ptr = reinterpret_cast<_128b_t*>(
59 | gmem_src_ptr);
60 | *(dst_128b_ptr) = *(src_128b_ptr);
61 | }
62 |
63 | // e.g stg_sync_128b(...);
64 | template<typename T0, typename T1>
65 | __device__ __forceinline__ void stg_sync_128b(
66 | T0 * gmem_dst_ptr, T1 * mem_src_ptr) {
67 | using _128b_t = uint4;
68 | _128b_t * dst_128b_ptr = reinterpret_cast<_128b_t*>(
69 | gmem_dst_ptr);
70 | _128b_t * src_128b_ptr = reinterpret_cast<_128b_t*>(
71 | mem_src_ptr);
72 | *(dst_128b_ptr) = *(src_128b_ptr);
73 | }
74 |
75 | } // cp_async
76 | } // ffpa
77 |
--------------------------------------------------------------------------------
/include/cuffpa/deprecated/mma_utils.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <algorithm>
3 | #include <cuda_bf16.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_fp8.h>
6 | #include <cuda_runtime.h>
7 | #include <float.h>
8 | #include <mma.h>
9 | #include <stdio.h>
10 | #include <stdlib.h>
11 | #include <time.h>
12 | #include <torch/extension.h>
13 | #include <torch/types.h>
14 | #include <vector>
15 | using namespace nvcuda;
16 |
17 | #define WARP_SIZE 32
18 | #define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
19 | #define FLOAT2(value) (reinterpret_cast<float2*>(&(value))[0])
20 | #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
21 | #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
22 | #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
23 | #define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
24 | #define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
25 | #define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
26 | // gmem -> smem
27 | #define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
28 | #define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
29 | #define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
30 | // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
31 | #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
32 | #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
33 | // ldmatrix
34 | #define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
35 | #define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
36 | #define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
37 | #define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
38 | #define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
39 | #define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
40 | // mma m16n8k16 acc f32 or f16
41 | #define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
42 | #define HMMA16816F32(RD0, RD1, RD2, RD3, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1, RC2, RC3) asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" : "=r"(RD0), "=r"(RD1), "=r"(RD2), "=r"(RD3): "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1), "r"(RC2), "r"(RC3))
43 |
44 |
45 | __device__ __host__ inline
46 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
47 |
48 | template<typename T, const int kWarpSize = WARP_SIZE>
49 | __device__ inline T warp_reduce_sum(T val) {
50 | #pragma unroll
51 | for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
52 | val += __shfl_xor_sync(0xffffffff, val, mask, kWarpSize);
53 | }
54 | return val;
55 | }
56 |
57 | template<typename T, const int kWarpSize = WARP_SIZE>
58 | __device__ inline T warp_reduce_max(T val) {
59 | #pragma unroll
60 | for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
61 | val = max(val, __shfl_xor_sync(0xffffffff, val, mask, kWarpSize));
62 | }
63 | return val;
64 | }
65 |
66 | template<typename T, const int M, const int N, const int K>
67 | __device__ inline void fill_3D_regs(T (&R)[M][N][K], T val) {
68 | #pragma unroll
69 | for (int i = 0; i < M; ++i) {
70 | #pragma unroll
71 | for (int j = 0; j < N; ++j) {
72 | #pragma unroll
73 | for (int k = 0; k < K; ++k) {
74 | R[i][j][k] = val;
75 | }
76 | }
77 | }
78 | }
79 |
80 | template<typename T, const int M, const int N>
81 | __device__ inline void fill_2D_regs(T (&R)[M][N], T val) {
82 | #pragma unroll
83 | for (int i = 0; i < M; ++i) {
84 | #pragma unroll
85 | for (int j = 0; j < N; ++j) {
86 | R[i][j] = val;
87 | }
88 | }
89 | }
90 |
91 | template<typename T, const int M>
92 | __device__ inline void fill_1D_regs(T (&S)[M], T val) {
93 | #pragma unroll
94 | for (int i = 0; i < M; ++i) {
95 | S[i] = val;
96 | }
97 | }
98 |
99 | #ifdef FFPA_MMA_DEBUG
100 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) \
101 | { \
102 | if (tid == 0) { \
103 | float2 v_reg = __half22float2(HALF2(R)); \
104 | printf("[T0] " format ", V0=%f, V1=%f\n", \
105 | ##__VA_ARGS__, v_reg.x, v_reg.y); \
106 | } \
107 | }
108 |
109 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) \
110 | { \
111 | if (tid < 32) { \
112 | float2 v_reg = __half22float2(HALF2(R)); \
113 | printf("[T%d] " format ", V0=%f, V1=%f\n", \
114 | tid, ##__VA_ARGS__, v_reg.x, v_reg.y);\
115 | } \
116 | }
117 |
118 | #define FFPA_MMA_PRINT_REG(R, format, ...) \
119 | { \
120 | { \
121 | float2 v_reg = __half22float2(HALF2(R)); \
122 | printf(format", V0=%f, V1=%f\n", \
123 | ##__VA_ARGS__, v_reg.x, v_reg.y); \
124 | } \
125 | }
126 |
127 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) \
128 | { \
129 | { \
130 | float2 v_reg_0 = __half22float2(HALF2(R0)); \
131 | float2 v_reg_1 = __half22float2(HALF2(R1)); \
132 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \
133 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \
134 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \
135 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \
136 | } \
137 | } \
138 | }
139 |
140 | #define FFPA_MMA_CHECK_PRINT_T32_REG(R0, R1, format, ...) \
141 | { \
142 | if (tid < 32){ \
143 | float2 v_reg_0 = __half22float2(HALF2(R0)); \
144 | float2 v_reg_1 = __half22float2(HALF2(R1)); \
145 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \
146 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \
147 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \
148 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \
149 | } \
150 | } \
151 | }
152 |
153 | #define FFPA_MMA_PRINT_T0(format, ...) \
154 | { \
155 | if (tid == 0) { \
156 | printf("[T0] " format, ##__VA_ARGS__); \
157 | } \
158 | }
159 |
160 | #define FFPA_MMA_PRINT_T32(format, ...) \
161 | { \
162 | if (tid < 32) { \
163 | printf("[T%d] " format, tid, ##__VA_ARGS__);\
164 | } \
165 | }
166 |
167 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) \
168 | { \
169 | if (lane_id == 0) { \
170 | float2 v_reg = __half22float2(HALF2(R)); \
171 | printf("[L0] " format", V0=%f, V1=%f\n", \
172 | ##__VA_ARGS__, v_reg.x, v_reg.y); \
173 | } \
174 | }
175 |
176 | #define FFPA_MMA_PRINT_L0(format, ...) \
177 | { \
178 | if (lane_id == 0) { \
179 | printf("[L0] " format, ##__VA_ARGS__); \
180 | } \
181 | }
182 |
183 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) \
184 | { \
185 | if (tid == 0 && blockIdx.z == 0) { \
186 | printf("----------------------------------------\n"); \
187 | printf(format, ##__VA_ARGS__); \
188 | for (int i = 0; i < Br; ++i) { \
189 | for (int j = 0; j < kMmaTileSeqLenK; ++j) { \
190 | printf("[%d][%d]=%f", i, j, (B)[i][j]); \
191 | } \
192 | printf("\n"); \
193 | } \
194 | printf("----------------------------------------\n"); \
195 | } \
196 | __syncthreads(); \
197 | }
198 |
199 | #else
200 |
201 | #define FFPA_MMA_PRINT_REG(R, format, ...) {}
202 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) {}
203 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) {}
204 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) {}
205 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) {}
206 | #define FFPA_MMA_PRINT_T0(format, ...) {}
207 | #define FFPA_MMA_PRINT_T32(format, ...) {}
208 | #define FFPA_MMA_PRINT_L0(format, ...) {}
209 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) {}
210 |
211 | #endif
212 |
213 | #define STRINGFY(str) #str
214 | #define TORCH_BINDING_COMMON_EXTENSION(func) \
215 | m.def(STRINGFY(func), &func, STRINGFY(func));
216 |
217 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
218 | if(((T).options().dtype() != (th_type))) { \
219 | std::cout << "Tensor Info:" << (T).options() << std::endl; \
220 | throw std::runtime_error("values must be "#th_type); \
221 | }
222 |
223 | #define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
224 | if (((T2).size(0) != (T1).size(0)) || \
225 | ((T2).size(1) != (T1).size(1)) || \
226 | ((T2).size(2) != (T1).size(2)) || \
227 | ((T2).size(3) != (T1).size(3))) { \
228 | throw std::runtime_error("Tensor size mismatch!"); \
229 | }
230 |
--------------------------------------------------------------------------------
/include/cuffpa/deprecated/smem_swizzle.cuh:
--------------------------------------------------------------------------------
1 | // Manually SMEM swizzling for bank conflict free.
2 | // ----------------------------------------------------------------
3 | // Manually SMEM swizzling for bank conflict free.
4 | // ----------------------------------------------------------------
5 | // [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
6 | // [INFO] For logical_col_stride > 16, we have to permute the |
7 | // [INFO] smem store layout using col major ZigZag method: |
8 | // [INFO] e.g, --> Q smem logical layout [Br][64]. |
9 | // [INFO] --> col major ZigZag permuted --> |
10 | // [INFO] --> Q smem store layout [4][Br][16]. |
11 | // ----------------------------------------------------------------
12 | // ----------------------------------------------------------------
13 | // -------------------------swizzle layout-------------------------
14 | // --------------------logical col 0~64, step 8--------------------
15 | // ---------------------smem col 0~16, step 8----------------------
16 | // ----------------------------------------------------------------
17 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
18 | // |row 0 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
19 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
20 | // |row 1 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
21 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
22 | // |row 2 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
23 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
24 | // |row 3 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
25 | // ----------------------------------------------------------------
26 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
27 | // |row 4 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
28 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
29 | // |row 5 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
30 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
31 | // |row 6 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
32 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
33 | // |row 7 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
34 | // ----------------------------------------------------------------
35 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
36 | // |row 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
37 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
38 | // |row 9 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
39 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
40 | // |row 10| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
41 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
42 | // |row 11| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
43 | // ----------------------------------------------------------------
44 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
45 | // |row 12| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
46 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
47 | // |row 13| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
48 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
49 | // |row 14| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
50 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
51 | // |row 15| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
52 | // ----------------------------------------------------------------
53 | #pragma once
54 | #include <cuda_bf16.h>
55 | #include <cuda_fp16.h>
56 | #include <cuda_fp8.h>
57 | #include <cuda_runtime.h>
58 | #include <mma.h>
59 |
60 | // i: row index; j: col index.
61 | template<const int kColStride = 16, const int kStep = 8>
62 | static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) {
63 | // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep;
64 | static_assert(kColStride <= 16, "Currently, kColStride must be less than or equal to 16.");
65 | static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4.");
66 | static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep.");
67 | if constexpr (kStep == 8) {
68 | return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3;
69 | } else {
70 | static_assert(kStep == 4);
71 | return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2;
72 | }
73 | }
74 |
75 | // i: row index; j: col index
76 | // e.g kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 half as 128 bits memory issue.
77 | template<const int kMmaAtomK = 16>
78 | static __device__ __forceinline__ int swizzle_permuted_Q_j(int i, int j) {
79 | return swizzle_permuted_j<kMmaAtomK, 8>(i, j);
80 | }
81 |
82 | // i: row index; j: col index
83 | // e.g kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 half as 128 bits memory issue.
84 | template<const int kMmaAtomK = 16>
85 | static __device__ __forceinline__ int swizzle_permuted_K_j(int i, int j) {
86 | return swizzle_permuted_j<kMmaAtomK, 8>(i, j);
87 | }
88 |
89 | // i: row index; j: col index
90 | // e.g kColStride = kMmaAtomN * 2 = 16, kStep = 8 -> load 8 half as 128 bits memory issue.
91 | template<const int kMmaAtomN = 8>
92 | static __device__ __forceinline__ int swizzle_permuted_V_j(int i, int j) {
93 | return swizzle_permuted_j<kMmaAtomN * 2, 8>(i, j);
94 | }
95 |
--------------------------------------------------------------------------------
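The permutation implemented by swizzle_permuted_j is pure integer math and is easy to sanity-check on the host. A minimal Python sketch that reproduces the kColStride = 16, kStep = 8 layout table above:

```python
# Host-side replica of swizzle_permuted_j with kColStride = 16, kStep = 8.
def swizzle_permuted_j(i: int, j: int, col_stride: int = 16, step: int = 8) -> int:
    # ((j / kStep) ^ (i / 4)) % (kColStride / kStep) * kStep
    return (((j // step) ^ (i // 4)) % (col_stride // step)) * step

for i in range(8):  # one 4-row period plus its mirrored period
    print(f"row {i}:", [swizzle_permuted_j(i, j) for j in (0, 8)])
# rows 0~3 print [0, 8] and rows 4~7 print [8, 0], matching the table.
```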
/include/cuffpa/logging.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <algorithm>
3 | #include <cuda_bf16.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_fp8.h>
6 | #include <cuda_runtime.h>
7 | #include <float.h>
8 | #include <mma.h>
9 | #include <stdio.h>
10 | #include <stdlib.h>
11 | #include <time.h>
12 | #include <torch/extension.h>
13 | #include <torch/types.h>
14 | #include <vector>
15 |
16 | #define STRINGFY(str) #str
17 | #define TORCH_BINDING_COMMON_EXTENSION(func) \
18 | m.def(STRINGFY(func), &func, STRINGFY(func));
19 |
20 | #define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
21 | if(((T).options().dtype() != (th_type))) { \
22 | std::cout << "Tensor Info:" << (T).options() << std::endl; \
23 | throw std::runtime_error("values must be "#th_type); \
24 | }
25 |
26 | #define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
27 | if (((T2).size(0) != (T1).size(0)) || \
28 | ((T2).size(1) != (T1).size(1)) || \
29 | ((T2).size(2) != (T1).size(2)) || \
30 | ((T2).size(3) != (T1).size(3))) { \
31 | throw std::runtime_error("Tensor size mismatch!"); \
32 | }
33 |
34 | #define HALF2(value) (reinterpret_cast(&(value))[0])
35 |
36 |
37 | #ifdef ENABLE_FFPA_DEBUG
38 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) \
39 | { \
40 | if (tid == 0) { \
41 | float2 v_reg = __half22float2(HALF2(R)); \
42 | printf("[T0] " format ", V0=%f, V1=%f\n", \
43 | ##__VA_ARGS__, v_reg.x, v_reg.y); \
44 | } \
45 | }
46 |
47 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) \
48 | { \
49 | if (tid < 32) { \
50 | float2 v_reg = __half22float2(HALF2(R)); \
51 | printf("[T%d] " format ", V0=%f, V1=%f\n", \
52 | tid, ##__VA_ARGS__, v_reg.x, v_reg.y);\
53 | } \
54 | }
55 |
56 | #define FFPA_MMA_PRINT_REG(R, format, ...) \
57 | { \
58 | { \
59 | float2 v_reg = __half22float2(HALF2(R)); \
60 | printf(format", V0=%f, V1=%f\n", \
61 | ##__VA_ARGS__, v_reg.x, v_reg.y); \
62 | } \
63 | }
64 |
65 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) \
66 | { \
67 | { \
68 | float2 v_reg_0 = __half22float2(HALF2(R0)); \
69 | float2 v_reg_1 = __half22float2(HALF2(R1)); \
70 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \
71 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \
72 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \
73 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \
74 | } \
75 | } \
76 | }
77 |
78 | #define FFPA_MMA_CHECK_PRINT_T32_REG(R0, R1, format, ...) \
79 | { \
80 | if (tid < 32){ \
81 | float2 v_reg_0 = __half22float2(HALF2(R0)); \
82 | float2 v_reg_1 = __half22float2(HALF2(R1)); \
83 | if ((fabs(v_reg_0.x - v_reg_1.x) > 0.01f) || \
84 | (fabs(v_reg_0.y - v_reg_1.y) > 0.01f)) { \
85 | printf(format", R0, V0=%f, V1=%f, R1, V0=%f, V1=%f\n", \
86 | ##__VA_ARGS__, v_reg_0.x, v_reg_0.y, v_reg_1.x, v_reg_1.y); \
87 | } \
88 | } \
89 | }
90 |
91 | #define FFPA_MMA_PRINT_T0(format, ...) \
92 | { \
93 | if (tid == 0) { \
94 | printf("[T0] " format, ##__VA_ARGS__); \
95 | } \
96 | }
97 |
98 | #define FFPA_MMA_PRINT_T32(format, ...) \
99 | { \
100 | if (tid < 32) { \
101 | printf("[T%d] " format, tid, ##__VA_ARGS__);\
102 | } \
103 | }
104 |
105 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) \
106 | { \
107 | if (lane_id == 0) { \
108 | float2 v_reg = __half22float2(HALF2(R)); \
109 | printf("[L0] " format", V0=%f, V1=%f\n", \
110 | ##__VA_ARGS__, v_reg.x, v_reg.y); \
111 | } \
112 | }
113 |
114 | #define FFPA_MMA_PRINT_L0(format, ...) \
115 | { \
116 | if (lane_id == 0) { \
117 | printf("[L0] " format, ##__VA_ARGS__); \
118 | } \
119 | }
120 |
121 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) \
122 | { \
123 | if (tid == 0 && blockIdx.z == 0) { \
124 | printf("----------------------------------------\n"); \
125 | printf(format, ##__VA_ARGS__); \
126 | for (int i = 0; i < Br; ++i) { \
127 | for (int j = 0; j < kMmaTileSeqLenK; ++j) { \
128 | printf("[%d][%d]=%f", i, j, (B)[i][j]); \
129 | } \
130 | printf("\n"); \
131 | } \
132 | printf("----------------------------------------\n"); \
133 | } \
134 | __syncthreads(); \
135 | }
136 |
137 | #else
138 |
139 | #define FFPA_MMA_PRINT_REG(R, format, ...) {}
140 | #define FFPA_MMA_CHECK_PRINT_REG(R0, R1, format, ...) {}
141 | #define FFPA_MMA_PRINT_T0_REG(R, format, ...) {}
142 | #define FFPA_MMA_PRINT_T32_REG(R, format, ...) {}
143 | #define FFPA_MMA_PRINT_L0_REG(R, format, ...) {}
144 | #define FFPA_MMA_PRINT_T0(format, ...) {}
145 | #define FFPA_MMA_PRINT_T32(format, ...) {}
146 | #define FFPA_MMA_PRINT_L0(format, ...) {}
147 | #define FFPA_MMA_PRINT_T0_B0_MATRIX(B, format, ...) {}
148 |
149 | #endif
150 |
151 |
--------------------------------------------------------------------------------
/include/cuffpa/mma.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <cuda_bf16.h>
3 | #include <cuda_fp16.h>
4 | #include <cuda_fp8.h>
5 | #include <cuda_runtime.h>
6 | #include <float.h>
7 | #include <mma.h>
8 | #include <stdio.h>
9 | #include <stdlib.h>
10 | #include <torch/extension.h>
11 | #include <torch/types.h>
12 | #include <vector>
13 |
14 | namespace ffpa {
15 | namespace mma {
16 |
17 | enum class MMAMode {
18 | kAutoZeroFill = 0U,
19 | kInplaceUpdate = 1U,
20 | };
21 |
22 | // Simple wrappers for mma and ldmatrix instructions.
23 | template<MMAMode mma_mode>
24 | __device__ __forceinline__ void m16n8k16_f16f16f16(
25 | uint32_t * RD0, uint32_t * RD1,
26 | uint32_t * RA0, uint32_t * RA1, uint32_t * RA2, uint32_t * RA3,
27 | uint32_t * RB0, uint32_t * RB1
28 | ) {
29 | if constexpr (mma_mode == MMAMode::kInplaceUpdate) {
30 | asm volatile(
31 | "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
32 | "{%0, %1}, "
33 | "{%2, %3, %4, %5}, "
34 | "{%6, %7}, "
35 | "{%8, %9};\n"
36 | : "=r"(RD0[0]), "=r"(RD1[0])
37 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]),
38 | "r"(RB0[0]), "r"(RB1[0]),
39 | "r"(RD0[0]), "r"(RD1[0])
40 | );
41 | } else {
42 | // WARN: seems unable to reach good performance while stage = 1.
43 | asm volatile(
44 | "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
45 | "{%0, %1}, "
46 | "{%2, %3, %4, %5}, "
47 | "{%6, %7}, "
48 | "{%8, %9};\n"
49 | : "=r"(RD0[0]), "=r"(RD1[0])
50 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]),
51 | "r"(RB0[0]), "r"(RB1[0]),
52 | "r"(0), "r"(0)
53 | );
54 | }
55 | }
56 |
57 | template<MMAMode mma_mode>
58 | __device__ __forceinline__ void m16n8k16_f16f16f32(
59 | uint32_t * RD0, uint32_t * RD1, uint32_t * RD2, uint32_t * RD3,
60 | uint32_t * RA0, uint32_t * RA1, uint32_t * RA2, uint32_t * RA3,
61 | uint32_t * RB0, uint32_t * RB1
62 | ) {
63 | // "h" = .u16 reg; "r" = .u32 reg; "l" = .u64 reg;
64 | // "f" = .f32 reg; "d" = .f64 reg
65 | if constexpr (mma_mode == MMAMode::kInplaceUpdate) {
66 | asm volatile(
67 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
68 | "{%0, %1, %2, %3}, "
69 | "{%4, %5, %6, %7}, "
70 | "{%8, %9}, "
71 | "{%10, %11, %12, %13};\n"
72 | : "=r"(RD0[0]), "=r"(RD1[0]), "=r"(RD2[0]), "=r"(RD3[0])
73 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]),
74 | "r"(RB0[0]), "r"(RB1[0]),
75 | "r"(RD0[0]), "r"(RD1[0]), "r"(RD2[0]), "r"(RD3[0])
76 | );
77 | } else {
78 | // WARN: seems unable to reach good performance while stage = 1.
79 | asm volatile(
80 | "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
81 | "{%0, %1, %2, %3}, "
82 | "{%4, %5, %6, %7}, "
83 | "{%8, %9}, "
84 | "{%10, %11, %12, %13};\n"
85 | : "=r"(RD0[0]), "=r"(RD1[0]), "=r"(RD2[0]), "=r"(RD3[0])
86 | : "r"(RA0[0]), "r"(RA1[0]), "r"(RA2[0]), "r"(RA3[0]),
87 | "r"(RB0[0]), "r"(RB1[0]),
88 | "r"(0), "r"(0), "r"(0), "r"(0)
89 | );
90 | }
91 | }
92 |
93 | __device__ __forceinline__ void ldmatrix_m8n8x4(
94 | uint32_t * R0, uint32_t * R1, uint32_t * R2, uint32_t * R3,
95 | uint32_t smem_ptr
96 | ) {
97 | asm volatile(
98 | "ldmatrix.sync.aligned.x4.m8n8.shared.b16 "
99 | "{%0, %1, %2, %3}, [%4];\n"
100 | : "=r"(R0[0]), "=r"(R1[0]), "=r"(R2[0]), "=r"(R3[0])
101 | : "r"(smem_ptr)
102 | );
103 | }
104 |
105 | __device__ __forceinline__ void ldmatrix_m8n8x4_trans(
106 | uint32_t * R0, uint32_t * R1, uint32_t * R2, uint32_t * R3,
107 | uint32_t smem_ptr
108 | ) {
109 | asm volatile(
110 | "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 "
111 | "{%0, %1, %2, %3}, [%4];\n"
112 | : "=r"(R0[0]), "=r"(R1[0]), "=r"(R2[0]), "=r"(R3[0])
113 | : "r"(smem_ptr)
114 | );
115 | }
116 |
117 | __device__ __forceinline__ void ldmatrix_m8n8x2(
118 | uint32_t * R0, uint32_t * R1,
119 | uint32_t smem_ptr
120 | ) {
121 | asm volatile(
122 | "ldmatrix.sync.aligned.x2.m8n8.shared.b16 "
123 | "{%0, %1}, [%2];\n"
124 | : "=r"(R0[0]), "=r"(R1[0])
125 | : "r"(smem_ptr)
126 | );
127 | }
128 |
129 | __device__ __forceinline__ void ldmatrix_m8n8x2_trans(
130 | uint32_t * R0, uint32_t * R1,
131 | uint32_t smem_ptr
132 | ) {
133 | asm volatile(
134 | "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 "
135 | "{%0, %1}, [%2];\n"
136 | : "=r"(R0[0]), "=r"(R1[0])
137 | : "r"(smem_ptr)
138 | );
139 | }
140 |
141 | } // mma
142 | } // ffpa
143 |
--------------------------------------------------------------------------------
/include/cuffpa/prefill.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "cuffpa/mma.cuh" // ffpa::mma
3 | #include "cuffpa/warp.cuh" // ffpa::warp
4 | #include "cuffpa/swizzle.cuh" // ffpa::swizzle
5 | #include "cuffpa/cp_async.cuh" // ffpa::cp_async
6 | #include "cuffpa/utils.cuh" // ffpa::utils
7 |
8 | namespace ffpa {
9 | namespace prefill {
10 | // prefill utils: prefetch/load QKV g2s funcs, rescale/softmax funcs etc.
11 |
12 | template<
13 | const int kHeadDim, // Headdim, 32,64,128
14 | const int kMmaAtomM, // MMA Atom M, 16
15 | const int kMmaAtomN, // MMA Atom N, 8
16 | const int kMmaAtomK, // MMA Atom K, 16
17 | const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
18 | const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
19 | const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
20 | const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
21 | const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
22 | const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
23 | const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
24 | const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
25 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
26 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
27 | const int kOStorageAccFloat32, // 0/1, MMA Acc always be f32/f16, but O storage can be fp32 or half.
28 | const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point.
29 | const int kPrefetchPV, // Prefetch V at the Appropriate Time Point.
30 | const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V.
31 | const int kPersistQs2r, // Persist load Q s2r for headdim < 320, more registers, but still keep O(1) SRAM.
32 | const int kPersistQg2s, // Persist load Q g2s for headdim < 320, more SRAM, but still keep register usage.
33 | const int kRegPipeKV, // Registers Ping pong double buffers for ldmatrix s2r & mma computation overlapping.
34 | const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4)
35 | const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4)
36 | const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
37 | const int kPadK,
38 | const int kPadV
39 | >
40 | __device__ __forceinline__ void check_large_d_compiling_states() {
41 | // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
42 | // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
43 | static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
44 | static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
45 | static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
46 | static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
47 | static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
48 | kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
49 | static_assert(kMmaAccFloat32QK == 0 || kMmaAccFloat32QK == 1);
50 | static_assert(kMmaAccFloat32PV == 0 || kMmaAccFloat32PV == 1);
51 | static_assert(kOStorageAccFloat32 == 0 || kOStorageAccFloat32 == 1);
52 | // Make sure that Br >= Bc, for shared memory reuse.
53 | static_assert(
54 | (kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ) >=
55 | (kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK));
56 | static_assert(kPrefetchQK == 0 || kPrefetchQK == 1);
57 | static_assert(kPrefetchPV == 0 || kPrefetchPV == 1);
58 | static_assert(kShareSmemQKV == 0 || kShareSmemQKV == 1);
59 | // Persist load Q s2r for headdim < 512, more registers, but still keep O(1) SRAM.
60 | static_assert(kPersistQs2r == 0 || kPersistQs2r == 1);
61 | // Persist load Q g2s for headdim < 512, more SRAM, but still keep register usage.
62 | static_assert(kPersistQg2s == 0 || kPersistQg2s == 1);
63 | // kPersistQg2s and kPersistQs2r cannot both be enabled for the large d kernel.
64 | static_assert((kPersistQg2s & kPersistQs2r) == 0);
65 | // kPersistQg2s and kShareSmemQKV cannot both be enabled for the large d kernel.
66 | static_assert((kPersistQg2s & kShareSmemQKV) == 0);
67 | // Register ping-pong double buffers for overlapping ldmatrix s2r with MMA compute.
68 | static_assert(kRegPipeKV == 0 || kRegPipeKV == 1);
69 | // May apply different multi stages policy for QK and V.
70 | static_assert(kStageQK < 5 && kStageQK > 0); // QK (<=4)
71 | static_assert(kStagePV < 5 && kStagePV > 0); // V (<=4)
72 | static_assert(kPadQ >= 0 && kPadQ % 8 == 0); // 0,8,16
73 | static_assert(kPadK >= 0 && kPadK % 8 == 0); // 0,8,16
74 | static_assert(kPadV >= 0 && kPadV % 8 == 0); // 0,8,16
75 | }
76 |
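These static_asserts only fire when the template is instantiated. A hypothetical instantiation in the spirit of the parameter comments (Br = 16*4*1 = 64, Bc = 8*1*8 = 64, d = 256, hence kWarpTileHeadDimV = 256/8 = 32; the values are illustrative, not read from this repo's launch_templates.cuh):

    template <int kHeadDim = 256>
    __global__ void ffpa_large_d_example_kernel(/* ... */) {
      ffpa::prefill::check_large_d_compiling_states<
          kHeadDim, 16, 8, 16,   // d, MMA atom M/N/K (m16n8k16)
          4, 1, 4, 1,            // MMA tiles for Q@K^T and P@V
          1, 8, 1, kHeadDim / 8, // warp tiles: Br=64, Bc=64, d per warp
          1, 1, 1,               // f32 acc for QK and PV, f32 O storage
          1, 1, 0,               // prefetch QK/PV, no QKV smem sharing
          1, 0, 0,               // persist Q s2r; Q g2s and reg-pipe KV off
          2, 2,                  // 2-stage pipelines for QK and PV
          0, 0, 0>();            // swizzle instead of padding
      // ... attention mainloop elided ...
    }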
77 |
78 | template<
79 | const int kHeadDim, // Headdim, 32~1024
80 | const int kMmaAtomM, // MMA Atom M, 16
81 | const int kMmaAtomN, // MMA Atom N, 8
82 | const int kMmaAtomK, // MMA Atom K, 16
83 | const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
84 | const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
85 | const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
86 | const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
87 | const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
88 | const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
89 | const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
90 | const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
91 | const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
92 | const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
93 | const int kOStorageAccFloat32, // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or half.
94 | const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point.
95 | const int kPrefetchPV, // Prefetch V at the Appropriate Time Point.
96 | const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V.
97 | const int kPersistQs2r, // Persist load Q s2r for headdim <= 128, more registers.
98 | const int kPersistVs2r, // Persist load V s2r for headdim <= 128, more registers.
99 | const int kRegPipeKV, // Register ping-pong double buffers for overlapping ldmatrix s2r with MMA compute.
100 | const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4)
101 | const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4)
102 | const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
103 | const int kPadK,
104 | const int kPadV
105 | >
106 | __device__ __forceinline__ void check_small_d_compiling_states() {
107 | // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
108 | // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
109 | static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
110 | static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
111 | static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
112 | static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
113 | static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
114 | kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
115 | static_assert(kMmaAccFloat32QK == 0 || kMmaAccFloat32QK == 1);
116 | static_assert(kMmaAccFloat32PV == 0 || kMmaAccFloat32PV == 1);
117 | static_assert(kOStorageAccFloat32 == 0 || kOStorageAccFloat32 == 1);
118 | // Make sure that Br >= Bc, for shared memory reuse.
119 | static_assert(
120 | (kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ) >=
121 | (kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK));
122 | static_assert(kPrefetchQK == 0 || kPrefetchQK == 1);
123 | static_assert(kPrefetchPV == 0 || kPrefetchPV == 1);
124 | static_assert(kShareSmemQKV == 0 || kShareSmemQKV == 1);
125 | // Persist load Q s2r for headdim <= 128, more registers.
126 | static_assert(kPersistQs2r == 0 || kPersistQs2r == 1);
127 | // Persist load V s2r for headdim <= 128, more registers.
128 | static_assert(kPersistVs2r == 0 || kPersistVs2r == 1);
129 | if constexpr (kShareSmemQKV) {
130 | // kPersistQs2r must be enabled if kShareSmemQKV is set to 1.
131 | static_assert(kPersistQs2r == 1);
132 | }
133 | // Register ping-pong double buffers for overlapping ldmatrix s2r
134 | // with MMA compute.
135 | static_assert(kRegPipeKV == 0 || kRegPipeKV == 1);
136 | // kRegPipeKV and kPersistVs2r cannot both be enabled.
137 | static_assert((kRegPipeKV & kPersistVs2r) == 0);
138 | // May apply different multi stages policy for QK and V.
139 | static_assert(kStageQK < 5 && kStageQK > 0); // QK (<=4)
140 | static_assert(kStagePV < 5 && kStagePV > 0); // V (<=4)
141 | static_assert(kPadQ >= 0 && kPadQ % 8 == 0); // 0,8,16
142 | static_assert(kPadK >= 0 && kPadK % 8 == 0); // 0,8,16
143 | static_assert(kPadV >= 0 && kPadV % 8 == 0); // 0,8,16
144 | }
145 |
146 |
147 | template<
148 | const int BrOrBc,
149 | const int kTileSize,
150 | const int kHeadDim,
151 | const int kMmaAtomK,
152 | const int kNumThreads,
153 | const int kPad
154 | >
155 | __device__ __forceinline__ void cp_async_qkv_g2s(
156 | uint32_t smem_base_ptr, // QKV smem base ptr
157 | const half * gmem_ptr, // QKV gmem ptr
158 | const int gmem_offset, // QKV gmem_offset
159 | const int n_tile_id, // seqlen offset, Q_tile_id * Br, tile_K_seqlen * Bc
160 | const int d_tile_id, // headdim offset, tile_K_d * kMmaAtomK, tile_V_d * kMmaAtomN * 2
161 | const int stage // stage * QKV tile_size
162 | ) {
163 | // QK: tile_K_d < (kHeadDim / kMmaAtomK)
164 | // V: tile_V_d < (kHeadDim / (kMmaAtomN * 2))
165 | if (d_tile_id >= (kHeadDim / kMmaAtomK)) { return; }
166 | const int tid = threadIdx.x; // within block
167 | constexpr bool kSwizzle = (kPad == 0);
168 | // Mapping QKV tid -> smem, tile size [64/128, 16]
169 | // Br 64, tid / 2, row 0~64
170 | const int load_smem_BrOrBc = (tid / (kNumThreads / BrOrBc));
171 | // (tid % 2) * 8, 0,8,...
172 | const int load_smem_d = (
173 | tid % (kNumThreads / BrOrBc)) * (kMmaAtomK / (kNumThreads / BrOrBc));
174 | // Mapping QKV tid -> gmem, tile size [64/128, 16], row offset by
175 | // n_tile_id(seqlen), col offset by d_tile_id(Headdim).
176 | const int load_gmem_BrOrBc = (n_tile_id * BrOrBc) + load_smem_BrOrBc;
177 | const int load_gmem_d = (d_tile_id * kMmaAtomK) + load_smem_d; // 0,8
178 | // Offset by QKV global gmem_offset.
179 | const int load_gmem_addr = (
180 | gmem_offset + load_gmem_BrOrBc * kHeadDim + load_gmem_d);
181 |
182 | // cp async & apply swizzle or padding.
183 | #pragma unroll
184 | for (int i = 0; i < (kMmaAtomK / (kNumThreads / BrOrBc)); i += 8) {
185 | const uint32_t load_smem_ptr = (
186 | smem_base_ptr + (stage * kTileSize +
187 | load_smem_BrOrBc * (kMmaAtomK + kPad) +
188 | (kSwizzle ? swizzle::permuted(
189 | load_smem_BrOrBc, load_smem_d + i) :
190 | load_smem_d + i )
191 | ) * sizeof(half));
192 | cp_async::cp_async<16>(load_smem_ptr, &(gmem_ptr[load_gmem_addr + i]));
193 | }
194 | // cp_async::commit_group();
195 | }
196 |
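For concreteness, with the values the inline comments assume (BrOrBc = 64, kNumThreads = 128, kMmaAtomK = 16), two threads share one [1,16]-half row and the i-loop runs exactly once, so each thread issues a single 16-byte cp.async per (stage, d-tile). A compile-time restatement of that arithmetic (assumed values, illustration only):

    #include <cuda_fp16.h>
    // tid/2 picks the row (0..63); (tid%2)*8 picks the 8-half chunk (0 or 8).
    constexpr int kThreadsPerRow   = 128 / 64;             // kNumThreads / BrOrBc
    constexpr int kHalvesPerThread = 16 / kThreadsPerRow;  // kMmaAtomK / kThreadsPerRow
    static_assert(kHalvesPerThread * sizeof(half) == 16, "one 16B cp.async per thread");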
197 |
198 | template<
199 | const int kTrans,
200 | const int kNumRegs,
201 | const int kTileSize,
202 | const int kMmaAtomM,
203 | const int kMmaAtomN,
204 | const int kMmaAtomK,
205 | const int kPad
206 | >
207 | __device__ __forceinline__ void sync_fetch_qkv_frags_s2r(
208 | uint32_t smem_base_ptr, // QKV smem base ptr
209 | uint32_t * R, // Register ptr, R_QKV
210 | const int mma_tile_id, // Q warp_QP 0~num MMAs, KV warp_KV 0
211 | const int warp_tile_id, // Q 0, KV 0~kWarpTileSeqLenK
212 | const int n_tile_id, // seqlen QK 0, V tile_V_Bc
213 | const int stage
214 | ) {
215 | const int lane_id = threadIdx.x % WARP_SIZE; // 0~31
216 | constexpr bool kSwizzle = (kPad == 0);
217 | if constexpr (kTrans) {
218 | // load V m8n8x2 via ldmatrix.x2.trans
219 | static_assert(kNumRegs == 2);
220 | // mma_tile_id = warp_KV == 0, warp_tile_id = (j % 2), n_tile_id = tile_V_Bc
221 | // warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + (j % 2) * kMmaAtomN;
222 | const int warp_smem_d = warp_tile_id * kMmaAtomN;
223 | const int lane_smem_Bc = n_tile_id * kMmaAtomK + lane_id % 16;
224 | const int lane_smem_d = warp_smem_d; // 0,8
225 | const uint32_t lane_smem_ptr = (
226 | smem_base_ptr + (stage * kTileSize +
227 | lane_smem_Bc * (kMmaAtomN * 2 + kPad) +
228 | (kSwizzle ? swizzle::permuted(
229 | lane_smem_Bc, lane_smem_d):
230 | lane_smem_d)
231 | ) * sizeof(half)
232 | );
233 | mma::ldmatrix_m8n8x2_trans(&R[0], &R[1], lane_smem_ptr);
234 | } else {
235 | static_assert(kNumRegs == 2 || kNumRegs == 4);
236 | if constexpr (kNumRegs == 4) {
237 | // load Q m8n8x4 via ldmatrix.x4
238 | // mma_tile_id = warp_QP, kWarpTileSeqLenQ=1
239 | // warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + 0 * kMmaAtomM
240 | const int warp_smem_Br = mma_tile_id * (kMmaAtomM);
241 | const int lane_smem_Br = warp_smem_Br + lane_id % 16; // 0~15
242 | const int lane_smem_d = (lane_id / 16) * 8; // 0,8
243 | const uint32_t lane_smem_ptr = (
244 | smem_base_ptr + (stage * kTileSize +
245 | lane_smem_Br * (kMmaAtomK + kPad) +
246 | (kSwizzle ? swizzle::permuted(
247 | lane_smem_Br, lane_smem_d):
248 | lane_smem_d)
249 | ) * sizeof(half)
250 | );
251 | mma::ldmatrix_m8n8x4(&R[0], &R[1], &R[2], &R[3], lane_smem_ptr);
252 | } else {
253 | // load K m8n8x2 via ldmatrix.x2
254 | // mma_tile_id = warp_KV == 0, warp_tile_id = j
255 | // warp_smem_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
256 | const int warp_smem_Bc = warp_tile_id * kMmaAtomN;
257 | const int lane_smem_Bc = warp_smem_Bc + lane_id % 8; // 0~7
258 | const int lane_smem_d = ((lane_id / 8) % 2) * 8; // 0,8
259 | const uint32_t lane_smem_ptr = (
260 | smem_base_ptr + (stage * kTileSize +
261 | lane_smem_Bc * (kMmaAtomK + kPad) +
262 | (kSwizzle ? swizzle::permuted(
263 | lane_smem_Bc, lane_smem_d):
264 | lane_smem_d )
265 | ) * sizeof(half)
266 | );
267 | mma::ldmatrix_m8n8x2(&R[0], &R[1], lane_smem_ptr);
268 | }
269 | }
270 | }
271 |
272 |
273 | template<const int kWarpTileSeqLenK, const int kMmaAccFloat32>
274 | __device__ __forceinline__ void sync_online_safe_softmax(
275 | uint32_t * R_S, // &R_S[0][0][0]
276 | const float scale, // 1 / sqrt(d)
277 | float * lane_row_max_new, // &lane_row_max_new[0][0]
278 | float * lane_row_sum_new, // &lane_row_sum_new[0][0]
279 | float * lane_block_row_max_old, // &lane_block_row_max_old[0][0]
280 | float * lane_block_row_sum_old // &lane_block_row_sum_old[0][0]
281 | ) {
282 | if constexpr (kMmaAccFloat32) {
283 | // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
284 | { // kWarpTileSeqLenQ = 1
285 | // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
286 | #pragma unroll
287 | for (int j = 0; j < kWarpTileSeqLenK; ++j) {
288 | const float* t_fptr_S_0_1 = reinterpret_cast<const float*>(R_S + j * 4); // &R_S[0][j][0]
289 | // This should be the row max after S = (Q @ K^T) / sqrt(d)
290 | const float tmp_max_0 = max(t_fptr_S_0_1[0], t_fptr_S_0_1[1]) * scale;
291 | const float tmp_max_1 = max(t_fptr_S_0_1[2], t_fptr_S_0_1[3]) * scale;
292 | lane_row_max_new[0] = max(lane_row_max_new[0], tmp_max_0);
293 | lane_row_max_new[1] = max(lane_row_max_new[1], tmp_max_1);
294 | } // end for kWarpTileSeqLenK
295 |
296 | // Warp level reduce max, warp_size = 4
297 | // Each thread contains the maximum of 2 rows of Br,
298 | // and only the values of T0, T4, ..., T28 are used.
299 | lane_row_max_new[0] = warp::reduce_max<float, 4>(lane_row_max_new[0]);
300 | lane_row_max_new[1] = warp::reduce_max<float, 4>(lane_row_max_new[1]);
301 | } // end for kWarpTileSeqLenQ
302 |
303 | // static_assert(kWarpTileSeqLenQ == 1);
304 | // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
305 | { // kWarpTileSeqLenQ = 1
306 | // Use latest global row max without update.
307 | // Br 0, row_id, 0~7, 16~23, 32~39, 48~55;
308 | // Br 1, row_id, 8~15, 24~31, 40~47, 56~63;
309 | // Apply m_new = max(m_old, m_new) here.
310 | const float block_row_max_new_0 = max(lane_block_row_max_old[0], lane_row_max_new[0]);
311 | const float block_row_max_new_1 = max(lane_block_row_max_old[1], lane_row_max_new[1]);
312 |
313 | #pragma unroll
314 | for (int j = 0; j < kWarpTileSeqLenK; ++j) {
315 | // R_S[][][4] 4 32bit registers with each contains 1 F32 element.
316 | // (x,y) 0~7->{c0, c1}, (z,w)->8~15 {c2, c3}
317 | float* t_fptr_S_0_1 = reinterpret_cast<float*>(R_S + j * 4);
318 | half* t_hptr_S_0_1 = reinterpret_cast< half*>(R_S + j * 4);
319 | // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z in registers;
320 | t_fptr_S_0_1[0] = __expf(__fmaf_rn(t_fptr_S_0_1[0], scale, - block_row_max_new_0));
321 | t_fptr_S_0_1[1] = __expf(__fmaf_rn(t_fptr_S_0_1[1], scale, - block_row_max_new_0));
322 | t_fptr_S_0_1[2] = __expf(__fmaf_rn(t_fptr_S_0_1[2], scale, - block_row_max_new_1));
323 | t_fptr_S_0_1[3] = __expf(__fmaf_rn(t_fptr_S_0_1[3], scale, - block_row_max_new_1));
324 | lane_row_sum_new[0] += (t_fptr_S_0_1[0] + t_fptr_S_0_1[1]);
325 | lane_row_sum_new[1] += (t_fptr_S_0_1[2] + t_fptr_S_0_1[3]);
326 | // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
327 | // Also convert F32 -> half for P@V MMA, reuse R_S as P.
328 | t_hptr_S_0_1[0] = __float2half_rn(t_fptr_S_0_1[0]);
329 | t_hptr_S_0_1[1] = __float2half_rn(t_fptr_S_0_1[1]);
330 | t_hptr_S_0_1[2] = __float2half_rn(t_fptr_S_0_1[2]);
331 | t_hptr_S_0_1[3] = __float2half_rn(t_fptr_S_0_1[3]);
332 | } // end for kWarpTileSeqLenK
333 |
334 | // Warp level reduce sum, warp_size = 4
335 | lane_row_sum_new[0] = warp::reduce_sum<float, 4>(lane_row_sum_new[0]);
336 | lane_row_sum_new[1] = warp::reduce_sum<float, 4>(lane_row_sum_new[1]);
337 | }
338 |
339 | } else {
340 | // MMA Acc F16
341 | // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
342 | { // kWarpTileSeqLenQ = 1
343 | // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
344 | #pragma unroll
345 | for (int j = 0; j < kWarpTileSeqLenK; ++j) {
346 | const half* t_hptr_S_0_1 = reinterpret_cast<const half*>(R_S + j * 2);
347 | // This should be the row max after S = (Q @ K^T) / sqrt(d)
348 | const float tmp_max_0 = __half2float(__hmax(t_hptr_S_0_1[0], t_hptr_S_0_1[1])) * scale;
349 | const float tmp_max_1 = __half2float(__hmax(t_hptr_S_0_1[2], t_hptr_S_0_1[3])) * scale;
350 | lane_row_max_new[0] = max(lane_row_max_new[0], tmp_max_0);
351 | lane_row_max_new[1] = max(lane_row_max_new[1], tmp_max_1);
352 | } // end for kWarpTileSeqLenK
353 |
354 | // Warp level reduce max, warp_size = 4
355 | // Each thread contains the maximum of 2 rows of Br,
356 | // and only the values of T0, T4, ..., T28 are used.
357 | lane_row_max_new[0] = warp::reduce_max<float, 4>(lane_row_max_new[0]);
358 | lane_row_max_new[1] = warp::reduce_max<float, 4>(lane_row_max_new[1]);
359 | } // end for kWarpTileSeqLenQ
360 |
361 | // static_assert(kWarpTileSeqLenQ == 1);
362 | // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
363 | { // kWarpTileSeqLenQ = 1
364 | // Apply m_new = max(m_old, m_new) here.
365 | const float block_row_max_new_0 = max(lane_block_row_max_old[0], lane_row_max_new[0]);
366 | const float block_row_max_new_1 = max(lane_block_row_max_old[1], lane_row_max_new[1]);
367 |
368 | #pragma unroll
369 | for (int j = 0; j < kWarpTileSeqLenK; ++j) {
370 | half* t_hptr_S_0_1 = reinterpret_cast<half*>(R_S + j * 2);
371 | // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z;
372 | float4 t_reg_S_0_1;
373 | t_reg_S_0_1.x = __expf(__fmaf_rn(
374 | __half2float(t_hptr_S_0_1[0]), scale, - block_row_max_new_0));
375 | t_reg_S_0_1.y = __expf(__fmaf_rn(
376 | __half2float(t_hptr_S_0_1[1]), scale, - block_row_max_new_0));
377 | t_reg_S_0_1.z = __expf(__fmaf_rn(
378 | __half2float(t_hptr_S_0_1[2]), scale, - block_row_max_new_1));
379 | t_reg_S_0_1.w = __expf(__fmaf_rn(
380 | __half2float(t_hptr_S_0_1[3]), scale, - block_row_max_new_1));
381 | lane_row_sum_new[0] += (t_reg_S_0_1.x + t_reg_S_0_1.y);
382 | lane_row_sum_new[1] += (t_reg_S_0_1.z + t_reg_S_0_1.w);
383 | // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
384 | t_hptr_S_0_1[0] = __float2half_rn(t_reg_S_0_1.x);
385 | t_hptr_S_0_1[1] = __float2half_rn(t_reg_S_0_1.y);
386 | t_hptr_S_0_1[2] = __float2half_rn(t_reg_S_0_1.z);
387 | t_hptr_S_0_1[3] = __float2half_rn(t_reg_S_0_1.w);
388 | } // end for kWarpTileSeqLenK
389 |
390 | // Warp level reduce sum, warp_size = 4
391 | lane_row_sum_new[0] = warp::reduce_sum<float, 4>(lane_row_sum_new[0]);
392 | lane_row_sum_new[1] = warp::reduce_sum<float, 4>(lane_row_sum_new[1]);
393 | }
394 | }
395 | }
396 |
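Both branches implement the same online-softmax recurrence, only on different register layouts (f32 vs packed f16 accumulators). A scalar reference sketch it should agree with (independent of this file's tiling; note the kernel folds the rescale of earlier chunks into O rather than touching past P values):

    __device__ void online_softmax_row_ref(const float* s, float* p, int n,
                                           int chunk, float scale) {
      float m = -INFINITY, l = 0.0f;   // running row max and exp-sum
      for (int c0 = 0; c0 < n; c0 += chunk) {
        const int c1 = min(c0 + chunk, n);
        float m_new = m;
        for (int i = c0; i < c1; ++i) m_new = fmaxf(m_new, s[i] * scale);
        const float r = __expf(m - m_new);       // exp(m_old - m_new)
        l *= r;
        for (int i = 0; i < c0; ++i) p[i] *= r;  // the kernel rescales O instead
        for (int i = c0; i < c1; ++i) {
          p[i] = __expf(s[i] * scale - m_new);   // P = exp(S*scale - m_new)
          l += p[i];
        }
        m = m_new;
      }
      for (int i = 0; i < n; ++i) p[i] /= l;     // cf. sync_rescaling_final_o
    }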
397 |
398 | __device__ __forceinline__ void sync_precompute_rescale_factors(
399 | float * rescale_o_factor_0, // rescale factor
400 | float * rescale_o_factor_1, // rescale factor
401 | const float * lane_row_max_new, // &lane_row_max_new[0][0]
402 | const float * lane_block_row_max_old, // &lane_block_row_max_old[0][0]
403 | const int n_tile_id // tile_K_seqlen
404 | ) {
405 | float block_row_max_new_0 = lane_row_max_new[0];
406 | float block_row_max_new_1 = lane_row_max_new[1];
407 | float block_row_max_old_0 = lane_block_row_max_old[0];
408 | float block_row_max_old_1 = lane_block_row_max_old[1];
409 | // NOTE: max(-inf, val) = val.
410 | block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
411 | block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
412 | // Avoid inf value while using m_old for rescaling O.
413 | block_row_max_old_0 = (n_tile_id > 0 ? block_row_max_old_0 :
414 | block_row_max_new_0);
415 | block_row_max_old_1 = (n_tile_id > 0 ? block_row_max_old_1 :
416 | block_row_max_new_1);
417 | // Precompute rescale_o_factor_0 & rescale_o_factor_1, avoid redundant exp.
418 | rescale_o_factor_0[0] = __expf(block_row_max_old_0 - block_row_max_new_0);
419 | rescale_o_factor_1[0] = __expf(block_row_max_old_1 - block_row_max_new_1);
420 | }
421 |
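In FA2 notation, the two factors computed here are exactly the exponential terms of the online update rules (restating the comments above in LaTeX):

    m_{new}    = \max(m_{old}, \tilde{m})
    O_{new}    = e^{m_{old} - m_{new}} \cdot O_{old} + \tilde{P} V
    \ell_{new} = e^{m_{old} - m_{new}} \cdot \ell_{old} + \tilde{\ell}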
422 | template<const int kMmaAccFloat32, const int kOStorageAccFloat32>
423 | __device__ __forceinline__ void sync_rescaling_tiling_o(
424 | uint32_t * R_D, // &R_D[0][0][0]
425 | uint32_t * R_O, // &R_O[0]
426 | const float * rescale_o_factor_0, // rescale factor
427 | const float * rescale_o_factor_1, // rescale factor
428 | const int n_tile_id, // tile_K_seqlen
429 | const int d_tile_id // j
430 | ) {
431 | // Now, we get [Br,8] slice of [Br,d], each warp(MMA) contains m16n8.
432 | // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old.
433 | // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V
434 | // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper)
435 | if constexpr (kMmaAccFloat32) {
436 | const float* t_fptr_O_0_1 = reinterpret_cast<const float*>(R_O);
437 | if constexpr (kOStorageAccFloat32) {
438 | // (x,y) 0~7->{c0, c1}, (z,w)->8~15 {c2, c3} kWarpTileSeqLenP=1
439 | float* t_fptr_D_0_1 = reinterpret_cast<float*>(R_D + d_tile_id * 4); // &(R_D[0][j][0])
440 | t_fptr_D_0_1[0] = __fmaf_rn(rescale_o_factor_0[0], t_fptr_D_0_1[0], t_fptr_O_0_1[0]);
441 | t_fptr_D_0_1[1] = __fmaf_rn(rescale_o_factor_0[0], t_fptr_D_0_1[1], t_fptr_O_0_1[1]);
442 | t_fptr_D_0_1[2] = __fmaf_rn(rescale_o_factor_1[0], t_fptr_D_0_1[2], t_fptr_O_0_1[2]);
443 | t_fptr_D_0_1[3] = __fmaf_rn(rescale_o_factor_1[0], t_fptr_D_0_1[3], t_fptr_O_0_1[3]);
444 | } else {
445 | half* t_hptr_D_0_1 = reinterpret_cast<half*>(R_D + d_tile_id * 2);
446 | t_hptr_D_0_1[0] = __float2half_rn(__fmaf_rn(
447 | rescale_o_factor_0[0], __half2float(t_hptr_D_0_1[0]), t_fptr_O_0_1[0]));
448 | t_hptr_D_0_1[1] = __float2half_rn(__fmaf_rn(
449 | rescale_o_factor_0[0], __half2float(t_hptr_D_0_1[1]), t_fptr_O_0_1[1]));
450 | t_hptr_D_0_1[2] = __float2half_rn(__fmaf_rn(
451 | rescale_o_factor_1[0], __half2float(t_hptr_D_0_1[2]), t_fptr_O_0_1[2]));
452 | t_hptr_D_0_1[3] = __float2half_rn(__fmaf_rn(
453 | rescale_o_factor_1[0], __half2float(t_hptr_D_0_1[3]), t_fptr_O_0_1[3]));
454 | }
455 | } else {
456 | // MMA Acc F16
457 | const half* t_hptr_O_0_1 = reinterpret_cast<const half*>(R_O);
458 | if constexpr (kOStorageAccFloat32) {
459 | // (x,y) 0~7->{c0, c1}, (z,w)->8~15 {c2, c3} kWarpTileSeqLenP=1
460 | float* t_fptr_D_0_1 = reinterpret_cast<float*>(R_D + d_tile_id * 4);
461 | t_fptr_D_0_1[0] = __fmaf_rn(
462 | rescale_o_factor_0[0], t_fptr_D_0_1[0], __half2float(t_hptr_O_0_1[0]));
463 | t_fptr_D_0_1[1] = __fmaf_rn(
464 | rescale_o_factor_0[0], t_fptr_D_0_1[1], __half2float(t_hptr_O_0_1[1]));
465 | t_fptr_D_0_1[2] = __fmaf_rn(
466 | rescale_o_factor_1[0], t_fptr_D_0_1[2], __half2float(t_hptr_O_0_1[2]));
467 | t_fptr_D_0_1[3] = __fmaf_rn(
468 | rescale_o_factor_1[0], t_fptr_D_0_1[3], __half2float(t_hptr_O_0_1[3]));
469 | } else {
470 | half* t_hptr_D_0_1 = reinterpret_cast<half*>(R_D + d_tile_id * 2);
471 | t_hptr_D_0_1[0] = __float2half_rn(__fmaf_rn(rescale_o_factor_0[0],
472 | __half2float(t_hptr_D_0_1[0]), __half2float(t_hptr_O_0_1[0])));
473 | t_hptr_D_0_1[1] = __float2half_rn(__fmaf_rn(rescale_o_factor_0[0],
474 | __half2float(t_hptr_D_0_1[1]), __half2float(t_hptr_O_0_1[1])));
475 | t_hptr_D_0_1[2] = __float2half_rn(__fmaf_rn(rescale_o_factor_1[0],
476 | __half2float(t_hptr_D_0_1[2]), __half2float(t_hptr_O_0_1[2])));
477 | t_hptr_D_0_1[3] = __float2half_rn(__fmaf_rn(rescale_o_factor_1[0],
478 | __half2float(t_hptr_D_0_1[3]), __half2float(t_hptr_O_0_1[3])));
479 | }
480 | }
481 | }
482 |
483 | __device__ __forceinline__ void sync_update_max_expsum(
484 | float * lane_row_max_new, // &lane_row_max_new[0][0]
485 | float * lane_row_sum_new, // &lane_row_sum_new[0][0]
486 | float * lane_block_row_max_old, // &lane_block_row_max_old[0][0]
487 | float * lane_block_row_sum_old, // &lane_block_row_sum_old[0][0]
488 | const float * rescale_o_factor_0, // rescale factor 0 exp(m_old - m_new)
489 | const float * rescale_o_factor_1 // rescale factor 1 exp(m_old - m_new)
490 | ) {
491 | // Now, we can update m, l after O has been scaled.
492 | // Update l = exp(m_old - m_new) * l_old + row_sum(P).
493 | lane_block_row_sum_old[0] = (__fmaf_rn(
494 | rescale_o_factor_0[0], lane_block_row_sum_old[0], lane_row_sum_new[0]));
495 | lane_block_row_sum_old[1] = (__fmaf_rn(
496 | rescale_o_factor_1[0], lane_block_row_sum_old[1], lane_row_sum_new[1]));
497 | // 2. Then, update block row max for each lane.
498 | lane_block_row_max_old[0] = max(lane_block_row_max_old[0], lane_row_max_new[0]);
499 | lane_block_row_max_old[1] = max(lane_block_row_max_old[1], lane_row_max_new[1]);
500 | }
501 |
502 |
503 | template<const int kWarpTileHeadDimV, const int kOStorageAccFloat32>
504 | __device__ __forceinline__ void sync_rescaling_final_o(
505 | uint32_t * R_D, // Final O after loop over N
506 | const float * lane_block_row_sum_old // &lane_block_row_sum_old[0][0]
507 | ) {
508 | // Finally, we still have to rescale O once more.
509 | // O_output(D) = ( 1/l_final ) * O_final (FA2 paper)
510 | // static_assert(kWarpTileSeqLenP == 1);
511 | { // kWarpTileSeqLenP = 1
512 | const float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[0]);
513 | const float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[1]);
514 | #pragma unroll
515 | for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
516 | // Scaling in registers & convert F32 -> half for O collective store.
517 | if constexpr (kOStorageAccFloat32) {
518 | const float* t_fptr_D_0_1 = reinterpret_cast<const float*>(R_D + j * 4);
519 | half* t_hptr_D_0_1 = reinterpret_cast< half*>(R_D + j * 4);
520 | t_hptr_D_0_1[0] = __float2half_rn(rescale_factor_0 * t_fptr_D_0_1[0]);
521 | t_hptr_D_0_1[1] = __float2half_rn(rescale_factor_0 * t_fptr_D_0_1[1]);
522 | t_hptr_D_0_1[2] = __float2half_rn(rescale_factor_1 * t_fptr_D_0_1[2]);
523 | t_hptr_D_0_1[3] = __float2half_rn(rescale_factor_1 * t_fptr_D_0_1[3]);
524 | } else {
525 | half* t_hptr_D_0_1 = reinterpret_cast<half*>(R_D + j * 2);
526 | t_hptr_D_0_1[0] = __float2half_rn(rescale_factor_0 * __half2float(t_hptr_D_0_1[0]));
527 | t_hptr_D_0_1[1] = __float2half_rn(rescale_factor_0 * __half2float(t_hptr_D_0_1[1]));
528 | t_hptr_D_0_1[2] = __float2half_rn(rescale_factor_1 * __half2float(t_hptr_D_0_1[2]));
529 | t_hptr_D_0_1[3] = __float2half_rn(rescale_factor_1 * __half2float(t_hptr_D_0_1[3]));
530 | }
531 | } // end for kWarpTileHeadDimV
532 | } // end for kWarpTileSeqLenP = 1
533 | }
534 |
535 |
536 | template<
537 | const int Br,
538 | const int kHeadDim,
539 | const int kMmaAtomM,
540 | const int kMmaAtomN,
541 | const int kWarpTileHeadDimV,
542 | const int kOStorageAccFloat32
543 | >
544 | __device__ __forceinline__ void sync_store_o_r2g(
545 | half * gmem_ptr, // O gmem ptr
546 | const int gmem_offset, // O gmem global offset
547 | const int n_tile_id, // curr tile id (seqlen) O_tile_id
548 | const int mma_tile_id, // Q warp_QP 0~num MMAs, KV warp_KV 0
549 | uint32_t * R_D, // Final scaled O
550 | uint32_t * R_Q, // R_Q[1][4] for registers reuse
551 | uint32_t * R_K // R_K[8][2] for registers reuse
552 | ) {
553 | // Store O(D): Write O[Br,d] from regs -> gmem, collective store
554 | // with reg reuse & warp shuffle.
555 | const int lane_id = threadIdx.x % WARP_SIZE; // 0~31
556 | // static_assert(kWarpTileSeqLenP == 1);
557 | { // kWarpTileSeqLenP = 1
558 | #pragma unroll
559 | for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
560 | // reuse R_Q[1][4], R_K[8][2] for collective store.
561 | uint32_t* t_uptr_Z_0 = reinterpret_cast<uint32_t*>(R_Q);
562 | uint32_t* t_uptr_Z_1 = reinterpret_cast<uint32_t*>(R_K);
563 | const int offset = (kOStorageAccFloat32) ? j * 4 : j * 2;
564 | t_uptr_Z_0[0] = R_D[offset + 0];
565 | t_uptr_Z_1[0] = R_D[offset + 1];
566 | t_uptr_Z_0[1] = __shfl_sync((0xffffffff), R_D[offset + 0], lane_id + 1, 4);
567 | t_uptr_Z_0[2] = __shfl_sync((0xffffffff), R_D[offset + 0], lane_id + 2, 4);
568 | t_uptr_Z_0[3] = __shfl_sync((0xffffffff), R_D[offset + 0], lane_id + 3, 4);
569 | t_uptr_Z_1[1] = __shfl_sync((0xffffffff), R_D[offset + 1], lane_id + 1, 4);
570 | t_uptr_Z_1[2] = __shfl_sync((0xffffffff), R_D[offset + 1], lane_id + 2, 4);
571 | t_uptr_Z_1[3] = __shfl_sync((0xffffffff), R_D[offset + 1], lane_id + 3, 4);
572 |
573 | // st.global.v4 128 bits. [Br,d]
574 | if (lane_id % 4 == 0) {
575 | // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 kWarpTileSeqLenP = 1
576 | // int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + 0 * kMmaAtomM;
577 | const int store_warp_regs_O_Br = mma_tile_id * (kMmaAtomM);
578 | const int store_lane_gmem_O_Br = n_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
579 | // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) warp_KV = 0
580 | // int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
581 | const int store_warp_regs_O_d = j * kMmaAtomN;
582 | const int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
583 | const int store_gmem_O_addr_0 = (
584 | gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
585 | const int store_gmem_O_addr_1 = (
586 | gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
587 | cp_async::stg_sync_128b(&gmem_ptr[store_gmem_O_addr_0], t_uptr_Z_0);
588 | cp_async::stg_sync_128b(&gmem_ptr[store_gmem_O_addr_1], t_uptr_Z_1);
589 | }
590 | } // end for kWarpTileHeadDimV
591 | } // kWarpTileSeqLenP = 1
592 | }
593 |
594 | } // prefill
595 | } // ffpa
596 |
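The width-4 __shfl_sync calls in sync_store_o_r2g are what enable the 128-bit store: within each 4-lane group, lane 4k collects its neighbors' 2-half fragments before the single st.global.v4. A stripped-down sketch of the same pattern (hypothetical helper; dst must be 16-byte aligned):

    __device__ void gather_store_128b_example(uint32_t frag, half* dst, int lane_id) {
      uint32_t z[4];
      z[0] = frag;                                           // own fragment
      z[1] = __shfl_sync(0xffffffff, frag, lane_id + 1, 4);  // width=4 keeps the
      z[2] = __shfl_sync(0xffffffff, frag, lane_id + 2, 4);  // exchange in-group
      z[3] = __shfl_sync(0xffffffff, frag, lane_id + 3, 4);
      if (lane_id % 4 == 0) {                                // lane 4k writes 8 halves
        *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(z);
      }
    }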
--------------------------------------------------------------------------------
/include/cuffpa/swizzle.cuh:
--------------------------------------------------------------------------------
1 | // Manual SMEM swizzling for bank-conflict-free access.
2 | // ----------------------------------------------------------------
3 | // [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
4 | // [INFO] For logical_col_stride > 16, we have to permute the |
5 | // [INFO] smem store layout using col major ZigZag method: |
6 | // [INFO] e.g, --> Q smem logical layout [Br][64]. |
7 | // [INFO] --> col major ZigZag permuted --> |
8 | // [INFO] --> Q smem store layout [4][Br][16]. |
9 | // ----------------------------------------------------------------
10 | // ----------------------------------------------------------------
11 | // -------------------------swizzle layout-------------------------
12 | // --------------------logical col 0~64, step 8--------------------
13 | // ---------------------smem col 0~16, step 8----------------------
14 | // ----------------------------------------------------------------
15 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
16 | // |row 0 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
17 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
18 | // |row 1 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
19 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
20 | // |row 2 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
21 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
22 | // |row 3 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
23 | // ----------------------------------------------------------------
24 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
25 | // |row 4 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
26 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
27 | // |row 5 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
28 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
29 | // |row 6 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
30 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
31 | // |row 7 | 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
32 | // ----------------------------------------------------------------
33 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
34 | // |row 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
35 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
36 | // |row 9 | 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
37 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
38 | // |row 10| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
39 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
40 | // |row 11| 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8 |
41 | // ----------------------------------------------------------------
42 | // |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
43 | // |row 12| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
44 | // |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
45 | // |row 13| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
46 | // |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
47 | // |row 14| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
48 | // |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
49 | // |row 15| 8 | 0 | 8 | 0 | 8 | 0 | 8 | 0 |
50 | // ----------------------------------------------------------------
51 | #pragma once
52 | #include <cuda_runtime.h>
53 |
54 | namespace ffpa {
55 | namespace swizzle {
56 |
57 | // i: row index; j: col index.
58 | template<const int kColStride = 16, const int kStep = 8>
59 | static __device__ __forceinline__ int permuted(int i, int j) {
60 | // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep;
61 | static_assert(kColStride <= 16, "Currently, kColStride must be less than or equal to 16.");
62 | static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4.");
63 | static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep.");
64 | if constexpr (kStep == 8) {
65 | return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3;
66 | } else {
67 | static_assert(kStep == 4);
68 | return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2;
69 | }
70 | }
71 |
72 | } // swizzle
73 | } // ffpa
74 |
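The XOR pattern is easy to sanity-check off-device. This standalone C++ loop (re-deriving the kColStride=16, kStep=8 formula rather than calling the header) reproduces the layout table in the comment block at the top of the file:

    #include <cstdio>
    static int permuted_ref(int i, int j) { return (((j >> 3) ^ (i >> 2)) % 2) << 3; }
    int main() {
      for (int i = 0; i < 8; ++i) {          // rows 0..7
        for (int j = 0; j < 64; j += 8)      // logical cols, step 8
          std::printf("%2d ", permuted_ref(i, j));
        std::printf("\n");                   // rows 0..3: 0 8 0 8 ...; rows 4..7: 8 0 8 0 ...
      }
      return 0;
    }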
75 |
--------------------------------------------------------------------------------
/include/cuffpa/utils.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "cuffpa/logging.cuh" // log
3 |
4 | namespace ffpa {
5 | namespace utils {
6 |
7 | __device__ __host__ inline
8 | int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
9 |
10 | template<typename T, const int M, const int N, const int K>
11 | __device__ inline void fill_3D_regs(T (&R)[M][N][K], T val) {
12 | #pragma unroll
13 | for (int i = 0; i < M; ++i) {
14 | #pragma unroll
15 | for (int j = 0; j < N; ++j) {
16 | #pragma unroll
17 | for (int k = 0; k < K; ++k) {
18 | R[i][j][k] = val;
19 | }
20 | }
21 | }
22 | }
23 |
24 | template<typename T, const int M, const int N>
25 | __device__ inline void fill_2D_regs(T (&R)[M][N], T val) {
26 | #pragma unroll
27 | for (int i = 0; i < M; ++i) {
28 | #pragma unroll
29 | for (int j = 0; j < N; ++j) {
30 | R[i][j] = val;
31 | }
32 | }
33 | }
34 |
35 | template<typename T, const int M>
36 | __device__ inline void fill_1D_regs(T (&S)[M], T val) {
37 | #pragma unroll
38 | for (int i = 0; i < M; ++i) {
39 | S[i] = val;
40 | }
41 | }
42 |
43 | } // utils
44 | } // ffpa
45 |
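Typical use in the attention kernels (assumed, consistent with the softmax code in prefill.cuh): running row-maxes start at -inf so max(m_old, m_new) is well-defined on the first tile, and exp-sums start at zero:

    __device__ void init_softmax_state_example() {
      float lane_row_max_new[1][2];
      float lane_row_sum_new[1][2];
      ffpa::utils::fill_2D_regs<float, 1, 2>(lane_row_max_new, -INFINITY);
      ffpa::utils::fill_2D_regs<float, 1, 2>(lane_row_sum_new, 0.0f);
    }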
46 |
--------------------------------------------------------------------------------
/include/cuffpa/warp.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <cuda.h>
3 | #include <cuda_bf16.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_fp8.h>
6 | #include <cuda_runtime.h>
7 | #include <float.h>
8 | #include <math.h>
9 | #include <stdio.h>
10 | #include <stdlib.h>
11 | #include <torch/types.h>
12 | #include <vector>
13 |
14 | #define WARP_SIZE 32
15 |
16 | namespace ffpa {
17 | namespace warp {
18 |
19 | template<typename T, const int kWarpSize = WARP_SIZE>
20 | __device__ inline T reduce_sum(T val) {
21 | #pragma unroll
22 | for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
23 | val += __shfl_xor_sync(0xffffffff, val, mask, kWarpSize);
24 | }
25 | return val;
26 | }
27 |
28 | template<typename T, const int kWarpSize = WARP_SIZE>
29 | __device__ inline T reduce_max(T val) {
30 | #pragma unroll
31 | for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
32 | val = max(val, __shfl_xor_sync(0xffffffff, val, mask, kWarpSize));
33 | }
34 | return val;
35 | }
36 |
37 | } // warp
38 | } // ffpa
39 |
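Note the butterfly masks: with kWarpSize = 4 the loop only visits masks 2 and 1, so each aligned 4-lane group reduces independently, which is exactly what the "warp_size = 4" comments in prefill.cuh rely on. A sketch of that call (template arguments as reconstructed above):

    // All four lanes of a group end up holding the group-wide maximum.
    __device__ float group4_row_max_example(float partial_max) {
      return ffpa::warp::reduce_max<float, 4>(partial_max);
    }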
--------------------------------------------------------------------------------
/include/extension/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
22 | dist
23 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | packaging
2 | ninja
3 | torch>=2.4.0
4 | numpy
5 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=1
3 |
4 | [metadata]
5 | license_files = LICENSE
6 | description_file = README.md
7 |
8 | [pep8]
9 | max-line-length = 120
10 |
11 | [flake8]
12 | max-line-length = 120
13 | ignore = E731, E203, E402, W503, W504, F821, E501, B, C4, EXE
14 | per-file-ignores =
15 | __init__.py: F401, F403, F405
16 | exclude = venv
17 |
18 | [pydocstyle]
19 | select = D417 # Missing argument descriptions in the docstring
20 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from env import ENV
4 | from packaging.version import Version
5 | from setuptools import find_packages, setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
7 | import warnings
8 | warnings.filterwarnings("ignore")
9 |
10 |
11 | def get_long_description():
12 | description = (Path(ENV.project_dir()) / "README.md").read_text(encoding="utf-8")
13 | return description
14 |
15 |
16 | # package name managed by pip, which can be removed via `pip uninstall ffpa-attn -y`
17 | PACKAGE_NAME = "ffpa-attn"
18 |
19 | ext_modules = []
20 | generator_flag = []
21 | cc_flag = []
22 |
23 | ENV.list_ffpa_env()
24 |
25 | if ENV.enable_ampere():
26 | cc_flag.append("-gencode")
27 | cc_flag.append("arch=compute_80,code=sm_80")
28 |
29 | if ENV.enable_ada():
30 | cc_flag.append("-gencode")
31 | cc_flag.append("arch=compute_89,code=sm_89")
32 |
33 | if ENV.enable_hopper():
34 | if CUDA_HOME is not None:
35 | _, bare_metal_version = ENV.get_cuda_bare_metal_version(CUDA_HOME)
36 | if bare_metal_version >= Version("11.8"):
37 | cc_flag.append("-gencode")
38 | cc_flag.append("arch=compute_90,code=sm_90")
39 |
40 | assert len(cc_flag) > 0, "cc_flag must not be empty; enable at least one arch."
41 |
42 | # cuda module
43 | # may need export LD_LIBRARY_PATH=PATH-TO/torch/lib:$LD_LIBRARY_PATH
44 | ext_modules.append(
45 | CUDAExtension(
46 | # package name for import
47 | name="pyffpa_cuda",
48 | sources=ENV.get_build_sources(build_pkg=True),
49 | extra_compile_args={
50 | # add c compile flags
51 | "cxx": ["-O3", "-std=c++17"] + generator_flag,
52 | # add nvcc compile flags
53 | "nvcc": ENV.get_build_cuda_cflags(build_pkg=True)
54 | + generator_flag
55 | + cc_flag,
56 | },
57 | include_dirs=[
58 | Path(ENV.project_dir()) / "include",
59 | ],
60 | )
61 | )
62 |
63 |
64 | def fetch_requirements():
65 | with open("requirements.txt") as f:
66 | reqs = f.read().strip().split("\n")
67 | return reqs
68 |
69 |
70 | setup(
71 | name=PACKAGE_NAME,
72 | version="0.0.2",
73 | author="DefTruth",
74 | author_email="qyjdef@163.com",
75 | license="GNU General Public License v3.0",
76 | packages=find_packages(
77 | exclude=(
78 | "build",
79 | "dist",
80 | "include",
81 | "csrc",
82 | "tests",
83 | "bench",
84 | "tmp",
85 | "cuffpa_py.egg-info",
86 | "ffpa_attn.egg-info",
87 | "__pycache__",
88 | "third_party",
89 | )
90 | ),
91 | description="FFPA: Yet another Faster Flash Prefill Attention for large headdim, 1.8x~3x faster than SDPA EA.",
92 | long_description=get_long_description(),
93 | long_description_content_type="text/markdown",
94 | url="https://github.com/DefTruth/ffpa-attn-mma.git",
95 | ext_modules=ext_modules,
96 | cmdclass={"build_ext": BuildExtension},
97 | python_requires=">=3.10",
98 | install_requires=fetch_requirements(),
99 | extras_require={
100 | "all": [],
101 | "dev": [
102 | "pre-commit",
103 | "packaging",
104 | "ninja",
105 | ],
106 | },
107 | )
108 |
--------------------------------------------------------------------------------
/tests/.gitignore:
--------------------------------------------------------------------------------
1 | *.so
2 | *.a
3 | *.dylib
4 | *.dll
5 | *.lib
6 | .DS_Store
7 | build
8 | *.whl
9 | tmp
10 | __pycache__
11 | *.onnx
12 | *.engine
13 | *.pt
14 | *.pth
15 | *.nsys*
16 | *.ncu*
17 | *.sqlite*
18 | *.engine
19 | *.bin
20 | output
21 | *.egg-info
22 | dist
23 | .tmp
24 |
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 |
--------------------------------------------------------------------------------
/tests/swizzle_layout.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def pretty_print_line(
5 | m: str = "", sep: str = "-", width: int = 130, return_str: bool = False
6 | ):
7 | res_len = width - len(m)
8 | left_len = int(res_len / 2)
9 | right_len = res_len - left_len
10 | pretty_line = sep * left_len + m + sep * right_len
11 | if not return_str:
12 | print(pretty_line)
13 | else:
14 | return pretty_line
15 |
16 |
17 | PERMUTED_DOCS_STRING = """----------------------------------------------------------------
18 | [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
19 | [INFO] For logical_col_stride > 16, we have to permute the |
20 | [INFO] smem store layout using col major ZigZag method: |
21 | [INFO] e.g, --> Q smem logical layout [Br][64]. |
22 | [INFO] --> col major ZigZag permuted --> |
23 | [INFO] --> Q smem store layout [4][Br][16]. |
24 | ----------------------------------------------------------------"""
25 |
26 | def swizzle_permuted_j(
27 | i: int, j: int, col_stride: int = 16, num_elems_per_128b: int = 8
28 | ):
29 | # i: row index; j: col index. col_stride <= 16.
30 | # assert col_stride <= 16, f"col_stride must <= 16, but got {col_stride}"
31 | # for col_stride > 16, we have to permute it using col major ZigZag order.
32 | # e.g, Q smem logical layout [Br,d]=[Br,64] -> store layout [4][Br][16].
33 | return (
34 | (int(j / num_elems_per_128b) ^ int(i / 4))
35 | % (int(col_stride / num_elems_per_128b))
36 | ) * num_elems_per_128b
37 |
38 |
39 | def print_smem_swizzle_layout(
40 | rows: int = 16,
41 | logical_col_stride: int = 16,
42 | num_elems_per_128b: int = 8,
43 | smem_pading: int = 0,
44 | show_logical_col_id: bool = False,
45 | use_logical_col_stride: bool = False,
46 | ):
47 | # ----------------------------------------------------------------
48 | # [INFO] Assert smem store layout col_stride <= 16, prefer 16. |
49 | # [INFO] For logical_col_stride > 16, we have to permute the |
50 | # [INFO] smem store layout using col major ZigZag method: |
51 | # [INFO] e.g, --> Q smem logical layout [Br][64]. |
52 | # [INFO] --> col major ZigZag permuted --> |
53 | # [INFO] --> Q smem store layout [4][Br][16]. |
54 | # ----------------------------------------------------------------
55 | # ----------------------------------------------------------------
56 | # -------------------------swizzle layout-------------------------
57 | # --------------------logical col 0~64, step 8--------------------
58 | # ---------------------smem col 0~16, step 8----------------------
59 | # ----------------------------------------------------------------
60 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
61 | # |row 0 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
62 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
63 | # |row 1 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
64 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
65 | # |row 2 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
66 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
67 | # |row 3 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
68 | # ----------------------------------------------------------------
69 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
70 | # |row 4 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
71 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
72 | # |row 5 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
73 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
74 | # |row 6 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
75 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
76 | # |row 7 | 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
77 | # ----------------------------------------------------------------
78 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
79 | # |row 8 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
80 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
81 | # |row 9 | 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
82 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
83 | # |row 10| 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
84 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
85 | # |row 11| 0:0 | 8:8 |16:0 |24:8 |32:0 |40:8 |48:0 |56:8 |
86 | # ----------------------------------------------------------------
87 | # |bank |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |b 0~3 |b 4~7 |
88 | # |row 12| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
89 | # |bank |b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|b 8~11|b12~15|
90 | # |row 13| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
91 | # |bank |b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|b16~19|b20~23|
92 | # |row 14| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
93 | # |bank |b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|b24~27|b28~31|
94 | # |row 15| 0:8 | 8:0 |16:8 |24:0 |32:8 |40:0 |48:8 |56:0 |
95 | # ----------------------------------------------------------------
96 | str_len = 0
97 | total_banks = 0
98 | assert smem_pading == 0 or smem_pading == 8, "smem_pading must be 0 or 8"
99 | # 4 bytes per bank
100 | banks_per_col = int((16 * 2) / 4) if logical_col_stride >= 16 else 4
101 | if use_logical_col_stride:
102 | banks_per_col = int((logical_col_stride * 2) / 4)
103 | if logical_col_stride > 16:
104 | print(f"[WARN] col_stride must <= 16, but got {logical_col_stride}")
105 | if smem_pading == 8:
106 | banks_per_col += 4
107 | print(
108 | f"[INFO] smem padding 8 half values, 4 banks, banks_per_col: {banks_per_col}"
109 | )
110 |
111 | banks_per_num_elems_per_128b = int((num_elems_per_128b * 2) / 4)
112 | for i in range(rows):
113 | layout_str_len = 0
114 | banks_str_len = 0
115 |
116 | # bank_layout_str
117 | banks_start = total_banks % 32 # 32 banks in total
118 | banks_end = banks_start + banks_per_col
119 | bank_layout_str = "|bank |"
120 | max_bank_str_len = 0
121 | if logical_col_stride >= 16 and (not use_logical_col_stride):
122 | for k in range(int(logical_col_stride / 16)):
123 | for j in range(banks_start, banks_end, banks_per_num_elems_per_128b):
124 | curr_bank_str = (
125 | f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|"
126 | )
127 | max_bank_str_len = max(max_bank_str_len, len(curr_bank_str))
128 | bank_layout_str += curr_bank_str
129 | else:
130 | for j in range(banks_start, banks_end, banks_per_num_elems_per_128b):
131 | curr_bank_str = f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|"
132 | max_bank_str_len = max(max_bank_str_len, len(curr_bank_str))
133 | bank_layout_str += curr_bank_str
134 |
135 | # smem_layout_str
136 | logical_col_ids = []
137 | smem_layout_col_ids = []
138 | if logical_col_stride >= 16 and (not use_logical_col_stride):
139 | for k in range(int(logical_col_stride / 16)):
140 | for j in range(0, 16, num_elems_per_128b):
141 | layout_j = swizzle_permuted_j(i, j, 16, num_elems_per_128b)
142 | logical_col_ids.append(k * 16 + j)
143 | smem_layout_col_ids.append(layout_j)
144 | else:
145 | for j in range(0, logical_col_stride, num_elems_per_128b):
146 | layout_j = swizzle_permuted_j(
147 | i, j, logical_col_stride, num_elems_per_128b
148 | )
149 | logical_col_ids.append(j)
150 | smem_layout_col_ids.append(layout_j)
151 |
152 | smem_layout_str = f"|row {i:<2}|"
153 |
154 | r = 0
155 | for c, l in zip(logical_col_ids, smem_layout_col_ids):
156 | smem_layout_str += (
157 | pretty_print_line(
158 | (f"{c:>2}:{l:<2}" if show_logical_col_id else f"{l:<2}"),
159 | sep=" ",
160 | width=(max_bank_str_len - 1),
161 | return_str=True,
162 | )
163 | + "|"
164 | )
165 | r += 1
166 | if logical_col_stride >= 16:
167 | if smem_pading == 8 and (r > 1 and r % 2 == 0):
168 | smem_layout_str += (
169 | pretty_print_line(
170 | ("pad"),
171 | sep=" ",
172 | width=max_bank_str_len - 1,
173 | return_str=True,
174 | )
175 | + "|"
176 | )
177 | else:
178 | if smem_pading == 8:
179 | smem_layout_str += (
180 | pretty_print_line(
181 | ("pad"),
182 | sep=" ",
183 | width=max_bank_str_len - 1,
184 | return_str=True,
185 | )
186 | + "|"
187 | )
188 |
189 | layout_str_len = len(smem_layout_str)
190 | str_len = max(layout_str_len, banks_str_len)
191 |
192 | # print banks and smem layout
193 | if i == 0:
194 | print("-" * str_len)
195 | pretty_print_line("swizzle layout", width=str_len)
196 | pretty_print_line(
197 | f"logical col 0~{logical_col_stride}, " f"step {num_elems_per_128b}",
198 | width=str_len,
199 | )
200 | pretty_print_line(
201 | f"smem col 0~16, step {num_elems_per_128b}"
202 | if logical_col_stride >= 16
203 | else f"smem col 0~8, step {num_elems_per_128b}",
204 | width=str_len,
205 | )
206 | print("-" * str_len)
207 | print(bank_layout_str)
208 | print(smem_layout_str)
209 | if (i + 1) % 4 == 0 and i != (rows - 1):
210 | print("-" * str_len)
211 | total_banks += banks_per_col
212 | print("-" * str_len)
213 |
214 |
215 | def get_args():
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument("--rows", type=int, default=16)
218 | parser.add_argument("--smem-padding", "--pad", type=int, default=0)
219 | parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
220 | parser.add_argument(
221 | "--logical-col-stride", "--logical-col", "--col", type=int, default=64
222 | )
223 | parser.add_argument(
224 | "--use-logical-col-stride", "--use-logical-col", action="store_true"
225 | )
226 | parser.add_argument(
227 | "--show-logical-col-id", "--show-logical-col", action="store_true"
228 | )
229 | return parser.parse_args()
230 |
231 |
232 | if __name__ == "__main__":
233 | args = get_args()
234 | print(args)
235 | print(PERMUTED_DOCS_STRING)
236 | print_smem_swizzle_layout(
237 | rows=args.rows,
238 | logical_col_stride=args.logical_col_stride,
239 | num_elems_per_128b=args.num_elems_per_128b,
240 | smem_pading=args.smem_padding,
241 | show_logical_col_id=args.show_logical_col_id,
242 | use_logical_col_stride=args.use_logical_col_stride,
243 | )
244 |
--------------------------------------------------------------------------------
/tests/test_ffpa_attn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import random
5 | import sys
6 | import time
7 | from functools import partial
8 | from typing import Optional
9 |
10 | import numpy as np
11 | import torch
12 | from torch import Tensor
13 | from torch.nn import functional as F
14 | from torch.nn.attention import sdpa_kernel, SDPBackend
15 | try:
16 | from flash_attn import flash_attn_func
17 | has_flash_attn = True
18 | except ImportError:
19 | flash_attn_func = None
20 | has_flash_attn = False
21 |
22 | sys.path.append("../")
23 | from env import ENV, pretty_print_line
24 |
25 | torch.set_grad_enabled(False)
26 | torch.set_printoptions(
27 | precision=6, threshold=8, edgeitems=3, linewidth=120, sci_mode=False
28 | )
29 |
30 |
31 | def get_args():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--no-rand-q", "--no-rq", action="store_true")
34 | parser.add_argument("--no-rand-k", "--no-rk", action="store_true")
35 | parser.add_argument("--no-rand-v", "--no-rv", action="store_true")
36 | parser.add_argument("--no-rand-qkv", "--no-rqkv", action="store_true")
37 | parser.add_argument("--run-torch-unfused", "--torch", action="store_true")
38 | parser.add_argument("--run-flash-attn", "--flash", action="store_true")
39 | parser.add_argument("--check", action="store_true")
40 | parser.add_argument("--check-all", action="store_true")
41 | parser.add_argument("--show-all", "--show", action="store_true")
42 | parser.add_argument("--show-less", "--show-l", action="store_true")
43 | parser.add_argument("--show-matrix", action="store_true")
44 | parser.add_argument("--only-flops-matmul", "--flops-mm", action="store_true")
45 | parser.add_argument("--B", type=int, default=None)
46 | parser.add_argument("--H", type=int, default=None)
47 | parser.add_argument("--N", type=int, default=None)
48 | parser.add_argument("--D", type=int, default=None)
49 | parser.add_argument("--MAX-D", "--MD", type=int, default=1024)
50 | parser.add_argument("--seed", type=int, default=None)
51 | parser.add_argument("--sleep", type=float, default=0.05)
52 | parser.add_argument("--debug", action="store_true")
53 | parser.add_argument("--verbose", "--v", action="store_true")
54 | parser.add_argument("--warmup", "--w", type=int, default=1)
55 | parser.add_argument("--iters", "--i", type=int, default=5)
56 | parser.add_argument("--tag-hints", "--tags", "--hints", type=str, default=None)
57 | parser.add_argument(
58 | "--plot-flops", "--plot", action="store_true", help="Plot TFLOPS"
59 | )
60 | parser.add_argument(
61 | "--save-dir", "--dir", type=str, default="tmp", help="Save dir for plot"
62 | )
63 | parser.add_argument(
64 | "--save-tag", "--tag", type=str, default=None, help="Save name for plot"
65 | )
66 | parser.add_argument("--gen-bench-table", "--gen-bench", action="store_true")
67 | parser.add_argument(
68 | "--force-build", "--build", action="store_true", help="Force build from sources"
69 | )
70 |
71 | return parser.parse_args()
72 |
73 |
74 | args = get_args()
75 | ENV.list_ffpa_env()
76 |
77 |
78 | def set_rand_seed(seed: int = 1):
79 | random.seed(seed)
80 | np.random.seed(seed)
81 | torch.manual_seed(seed)
82 | torch.cuda.manual_seed_all(seed)
83 |
84 |
85 | pretty_print_line()
86 | print(args)
87 | pretty_print_line()
88 |
89 | # Load the CUDA kernel as a python module
90 | ffpa_attn, use_ffpa_attn_package = ENV.try_load_ffpa_library(
91 | force_build=args.force_build, verbose=args.verbose
92 | )
93 | if use_ffpa_attn_package:
94 | import ffpa_attn # trick for IDE code search
95 |
96 |
97 | def get_mha_tflops(
98 | B: int, H: int, N: int, D: int, secs: float = 1.0, only_matmul: bool = False
99 | ):
100 | # Q @ K^T FLOPs
101 | flops_qk = B * H * N * N * (2 * D - 1)
102 |
103 | # Scaling FLOPs
104 | flops_scaling = B * H * N * N
105 |
106 | # Safe_Softmax FLOPs
107 | flops_row_max = B * H * N * (N - 1) # row max
108 | flops_subtract_max = B * H * N * N # sub max
109 | flops_exp = B * H * N * N # pointwise exp
110 | flops_row_sum = B * H * N * (N - 1) # row sum
111 | flops_normalization = B * H * N * N # normalization
112 |
113 | flops_safe_softmax = (
114 | flops_row_max
115 | + flops_subtract_max
116 | + flops_exp
117 | + flops_row_sum
118 | + flops_normalization
119 | )
120 |
121 | # P @ V FLOPs
122 | flops_pv = B * H * N * D * (2 * N - 1)
123 |
124 | # Total FLOPs
125 | total_flops = flops_qk + flops_scaling + flops_safe_softmax + flops_pv
126 | if only_matmul:
127 | total_flops = flops_qk + flops_pv
128 |
129 | # Convert to TFLOPS
130 | # 1 TFLOPS = 10^12 FLOPS
131 | # ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
132 | tflops = total_flops * 1e-12 / (secs)
133 |
134 | return tflops
135 |
136 |
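With --only-flops-matmul, only the first and last terms survive, collapsing to the usual back-of-the-envelope attention count (a restatement of the code above, in LaTeX):

    \mathrm{FLOPs_{mm}} = BHN^2(2D-1) + BHND(2N-1) \approx 4\,B\,H\,N^2\,D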
137 | MAX_TFLOPS = -1
138 | STATIS_INFO: dict[str, list[float | int] | set] = {}
139 | STATIS_INFO["headdim"] = set()
140 | TOATL_TFLOPS: dict[str, float] = {}
141 | SDPA_TFLOPS = -1
142 |
143 | def run_benchmark(
144 | perf_func: callable,
145 | q: torch.Tensor,
146 | k: torch.Tensor,
147 | v: torch.Tensor,
148 | tag: str,
149 | out: Optional[torch.Tensor] = None,
150 | s: Optional[torch.Tensor] = None, # DEBUG
151 | stages: int = -1,
152 | warmup: int = args.warmup,
153 | iters: int = args.iters,
154 | show_matrix: bool = args.show_matrix,
155 | only_show_improved: bool = not args.show_all,
156 | ):
157 |
158 | global MAX_TFLOPS
159 | global MAX_HEADDIM_CFG
160 | global SDPA_TFLOPS
161 |
162 | tag_hints: str = args.tag_hints # e.g "share-qkv,tiling-kv,swizzle"
163 | if tag_hints:
164 | tag_hints: list = tag_hints.strip().split(",")
165 | tag_hints.append("sdpa")
166 | tag_hints.append("unfused")
167 | hit_hints = False
168 | for hint in tag_hints:
169 | if hint in tag:
170 | hit_hints = True
171 | if not hit_hints:
172 | return None, None
173 |
174 | B, H, N, D = q.size()
175 | if "flash" in tag:
176 | B, N, H, D = q.size()
177 |
178 | if "unfused" in tag and (not args.run_torch_unfused):
179 | return None, None
180 | if "flash" in tag and ((not args.run_flash_attn)
181 | or (not has_flash_attn) or (D > 256)):
182 | return None, None
183 |
184 | STATIS_INFO["headdim"].add(D)
185 |
186 | max_supported_D = MAX_HEADDIM_CFG.get(tag, None)
187 | # skip if headdim not supported.
188 | if max_supported_D is not None:
189 | if D > max_supported_D:
190 | return None, None
191 |
192 | if out is not None:
193 | out.fill_(0)
194 | if s is not None:
195 | s.fill_(0)
196 | if out is not None:
197 | for i in range(warmup):
198 | if stages >= 1:
199 | if s is not None:
200 | perf_func(q, k, v, out, s, stages)
201 | else:
202 | perf_func(q, k, v, out, stages)
203 | else:
204 | perf_func(q, k, v, out)
205 | else:
206 | for i in range(warmup):
207 | _ = perf_func(q, k, v)
208 |
209 | torch.cuda.synchronize()
210 | start = time.time()
211 | # iters
212 | if out is not None:
213 | for i in range(iters):
214 | if stages >= 1:
215 | if s is not None:
216 | perf_func(q, k, v, out, s, stages)
217 | else:
218 | perf_func(q, k, v, out, stages)
219 | else:
220 | perf_func(q, k, v, out)
221 | else:
222 | for i in range(iters):
223 | out = perf_func(q, k, v)
224 | torch.cuda.synchronize()
225 | end = time.time()
226 | total_secs = end - start
227 | total_time = (end - start) * 1000 # ms
228 | mean_time = total_time / iters
229 | mean_secs = total_secs / iters
230 |
231 | TFLOPS = get_mha_tflops(B, H, N, D, mean_secs, only_matmul=args.only_flops_matmul)
232 | if tag in STATIS_INFO:
233 | STATIS_INFO[tag].append(int(round(TFLOPS)))
234 | else:
235 | STATIS_INFO[tag] = []
236 | STATIS_INFO[tag].append(int(round(TFLOPS)))
237 |
238 | if "sdpa" in tag:
239 | SDPA_TFLOPS = TFLOPS
240 | out_info = f"{tag}"
241 | out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
242 | out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
243 | out_val_first = [round(v, 8) for v in out_val_first]
244 | out_val_last = [round(v, 8) for v in out_val_last]
245 | if not args.show_less:
246 | out_val = out_val_first[:2]
247 | out_val.append(out_val_last[-1])
248 | else:
249 | out_val = out_val_first[:1]
250 | out_val = [f"{v:<12}" for v in out_val]
251 | if args.show_less:
252 | out_val = [v.strip() for v in out_val]
253 |
254 | if SDPA_TFLOPS > 0:
255 | speedup_sdpa = TFLOPS / SDPA_TFLOPS
256 | else:
257 | speedup_sdpa = 1.0
258 |
259 |     # calculate TFLOPS improvement.
260 | if TFLOPS > MAX_TFLOPS:
261 | if MAX_TFLOPS > 0:
262 | improve = ((TFLOPS - MAX_TFLOPS) / MAX_TFLOPS) * 100
263 | improve = round(improve, 2)
264 | else:
265 | improve = 0
266 |
267 | MAX_TFLOPS = TFLOPS
268 | print(
269 | f"{out_info:>25}: {out_val}, time:{str(mean_time)[:8]}ms, "
270 | f"TFLOPS:{TFLOPS:<6.2f}(+{improve:<5.2f}%)(~{speedup_sdpa:<4.2f}x)"
271 | )
272 | else:
273 | improve = 0
274 | if (not only_show_improved) or (("flash" in tag) or ("sdpa" in tag)):
275 | print(
276 | f"{out_info:>25}: {out_val}, time:{str(mean_time)[:8]}ms, "
277 | f"TFLOPS:{TFLOPS:<6.2f}(+{improve:<5.2f}%)(~{speedup_sdpa:<4.2f}x)"
278 | )
279 |
280 | if show_matrix:
281 | print(out)
282 | time.sleep(args.sleep)
283 | torch.cuda.synchronize()
284 | return out.clone() if args.check else out, mean_time
285 |
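# Sidebar: run_benchmark() times with time.time() around a synchronized
# region. A possible alternative (a sketch, not what this script does) is
# CUDA events, which avoid host-side timer jitter:
#
#   start_evt = torch.cuda.Event(enable_timing=True)
#   end_evt = torch.cuda.Event(enable_timing=True)
#   start_evt.record()
#   for _ in range(iters):
#       perf_func(q, k, v, out)
#   end_evt.record()
#   torch.cuda.synchronize()
#   mean_ms = start_evt.elapsed_time(end_evt) / iters  # milliseconds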
286 |
287 | def get_best_tflops():
288 |     global STATIS_INFO
289 |     sdpa_tflops = STATIS_INFO["(sdpa)"] # needed by both branches below
290 |     if ENV.enable_all_mutistages():
291 | ffpa_l1_f32_best_tflops = [
292 | max(x, y, z, w)
293 | for x, y, z, w in zip(
294 | STATIS_INFO["(ffpa+acc+f32+L1+stage1)"],
295 | STATIS_INFO["(ffpa+acc+f32+L1+stage2)"],
296 | STATIS_INFO["(ffpa+acc+f32+L1+stage3)"],
297 | STATIS_INFO["(ffpa+acc+f32+L1+stage4)"],
298 | )
299 | ]
300 | ffpa_l1_f16_best_tflops = [
301 | max(x, y, z, w)
302 | for x, y, z, w in zip(
303 | STATIS_INFO["(ffpa+acc+f16+L1+stage1)"],
304 | STATIS_INFO["(ffpa+acc+f16+L1+stage2)"],
305 | STATIS_INFO["(ffpa+acc+f16+L1+stage3)"],
306 | STATIS_INFO["(ffpa+acc+f16+L1+stage4)"],
307 | )
308 | ]
309 | else:
310 | ffpa_l1_f32_best_tflops = [
311 | max(x, y)
312 | for x, y in zip(
313 | STATIS_INFO["(ffpa+acc+f32+L1+stage1)"],
314 | STATIS_INFO["(ffpa+acc+f32+L1+stage2)"],
315 | )
316 | ]
317 | ffpa_l1_f16_best_tflops = [
318 | max(x, y)
319 | for x, y in zip(
320 | STATIS_INFO["(ffpa+acc+f16+L1+stage1)"],
321 | STATIS_INFO["(ffpa+acc+f16+L1+stage2)"],
322 | )
323 | ]
324 |
325 | # calculate improved
326 | ffpa_l1_f32_speedup = [
327 | round(f / s, 2) for f, s in zip(ffpa_l1_f32_best_tflops, sdpa_tflops)
328 | ]
329 | ffpa_l1_f16_speedup = [
330 | round(f / s, 2) for f, s in zip(ffpa_l1_f16_best_tflops, sdpa_tflops)
331 | ]
332 | STATIS_INFO["ffpa+acc-f32+L1(best)"] = ffpa_l1_f32_best_tflops
333 | STATIS_INFO["ffpa+acc-f16+L1(best)"] = ffpa_l1_f16_best_tflops
334 | STATIS_INFO["ffpa+acc-f32+L1(speedup)"] = ffpa_l1_f32_speedup
335 | STATIS_INFO["ffpa+acc-f16+L1(speedup)"] = ffpa_l1_f16_speedup
336 |
337 | return (
338 | sdpa_tflops,
339 | ffpa_l1_f32_best_tflops,
340 | ffpa_l1_f16_best_tflops,
341 | ffpa_l1_f32_speedup,
342 | ffpa_l1_f16_speedup,
343 | )
344 |
345 |
346 | def gen_bench_markdown_table():
347 | global STATIS_INFO
348 | STATIS_INFO["headdim"] = sorted(list(STATIS_INFO["headdim"]))
349 | pretty_print_line("FFPA-L1 Benchmark Data")
350 | print(STATIS_INFO)
351 | pretty_print_line()
352 | headdims = [str(d) for d in STATIS_INFO["headdim"]]
353 | num_headdim = len(headdims)
354 | table_header = "|Algorithm|" + "|".join(headdims) + "|\n"
355 | table_header += "|:---:|" + ":---:|" * num_headdim
356 | (
357 | sdpa_tflops,
358 | ffpa_l1_f32_best_tflops,
359 | ffpa_l1_f16_best_tflops,
360 | ffpa_l1_f32_speedup,
361 | ffpa_l1_f16_speedup,
362 | ) = get_best_tflops()
363 |
364 | # sdpa, ffpa, speedup strings.
365 | sdpa_tflops_str = "|SDPA EA|" + "|".join([str(s) + "T" for s in sdpa_tflops]) + "|"
366 | ffpa_l1_f32_tflops_str = (
367 | "|FFPA L1*|" + "|".join([str(f) + "T" for f in ffpa_l1_f32_best_tflops]) + "|"
368 | )
369 | ffpa_l1_f32_speedup_str = (
370 | "|Speedup|" + "|".join([str(fs) + "x" for fs in ffpa_l1_f32_speedup]) + "|"
371 | )
372 | ffpa_l1_f16_tflops_str = (
373 | "|FFPA L1^|" + "|".join([str(f) + "T" for f in ffpa_l1_f16_best_tflops]) + "|"
374 | )
375 | ffpa_l1_f16_speedup_str = (
376 | "|Speedup|" + "|".join([str(fs) + "x" for fs in ffpa_l1_f16_speedup]) + "|"
377 | )
378 | pretty_print_line("FFPA-L1 Best Benchmark Markdown Table:")
379 | print("\n")
380 | print(table_header)
381 | print(sdpa_tflops_str)
382 | print(ffpa_l1_f32_tflops_str)
383 | print(ffpa_l1_f32_speedup_str)
384 | print(ffpa_l1_f16_tflops_str)
385 | print(ffpa_l1_f16_speedup_str)
386 | print("\n")
387 |
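# Illustrative shape of the generated table (values are placeholders, not
# measurements):
#   |Algorithm|320|384|448|
#   |:---:|:---:|:---:|:---:|
#   |SDPA EA|25T|25T|24T|
#   |FFPA L1*|60T|59T|58T|
#   |Speedup|2.4x|2.36x|2.42x|
#   |FFPA L1^|62T|61T|60T|
#   |Speedup|2.48x|2.44x|2.5x|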
388 |
389 | def sort_tflops_by_headdim():
390 | global STATIS_INFO
391 | NEW_STATIS_INFO = {}
392 |     asc_headdims = sorted(STATIS_INFO["headdim"]) # list or set; matches append order
393 |     NEW_STATIS_INFO["headdim"] = headdims = sorted(asc_headdims, reverse=True)
394 | for tag, tflops in STATIS_INFO.items():
395 | if tag == "headdim":
396 | continue
397 | new_tflops = []
398 | for d in headdims:
399 |             idx = asc_headdims.index(d)
400 | new_tflops.append(tflops[idx])
401 | NEW_STATIS_INFO[tag] = new_tflops
402 | return NEW_STATIS_INFO
403 |
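# Example of the reordering above (hypothetical inputs): with
#   STATIS_INFO = {"headdim": {320, 384}, "(sdpa)": [10, 12]}
# where tflops were appended for D in ascending order, the result is
#   {"headdim": [384, 320], "(sdpa)": [12, 10]}.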
404 |
405 | def plot_speedup_bar(
406 |     speedup: list[float], extra_tag: str = "ffpa+acc+f32+L1", headdim: Optional[list[int]] = None
407 | ):
408 | import matplotlib.pyplot as plt
409 |
410 | _ = plt.subplots(figsize=(16, 9))[1] # fig, axs
411 | plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.07)
412 | x = range(len(speedup))
413 | # Plot the bar chart, setting a different color for each bar
414 | random.seed(0)
415 | for i, value in enumerate(speedup):
416 | random_color = (random.random(), random.random(), random.random())
417 | plt.bar(i, value, color=random_color)
418 | plt.xlabel("Headdim(D)", fontsize=15, fontweight="bold")
419 | plt.xticks(x, headdim, fontsize=15, fontweight="bold")
420 | plt.ylabel(f"{extra_tag.upper()} SpeedUp", fontsize=15, fontweight="bold")
421 | plt.title(
422 | f"SpeedUp of {extra_tag.upper()} vs SDPA EA, {ENV.get_device_name()}",
423 | fontsize=15,
424 | fontweight="bold",
425 | )
426 | # Set the range of the y-axis, adjusted according to the data
427 | plt.ylim(0, max(speedup) + 0.5)
428 | plt.yticks(fontweight="bold")
429 | for i, v in enumerate(speedup):
430 | plt.text(i, v + 0.1, str(v), ha="center", fontsize=20, fontweight="bold")
431 | plt.grid(True)
432 | # Display the graph
433 | device_name = ENV.get_device_name().replace(" ", "_")
434 | if args.save_tag:
435 | save_path = (
436 | f"{args.save_dir}/{device_name}_{args.save_tag}_{extra_tag}_Speedup.png"
437 | )
438 | else:
439 | save_path = f"{args.save_dir}/{device_name}_{extra_tag}_Speedup.png"
440 | plt.tight_layout()
441 | plt.savefig(save_path, dpi=300)
442 | pretty_print_line(f"plot FFPA Speedup bar done, saved as {save_path}")
443 | plt.close()
444 |
445 |
446 | def plot_tflops(level: str = "L1"):
447 | import matplotlib.pyplot as plt
448 | import numpy as np
449 |
450 | ax: plt.Axes = plt.subplots(figsize=(16, 9))[1] # fig, axs
451 | plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.07)
452 | B = 1 if not args.B else args.B
453 | H = 48 if not args.H else args.H
454 | N = 8192 if not args.N else args.N
455 | ax.set_title(
456 | f"FFPA {level} vs SDPA EA, {ENV.get_device_name()}, "
457 | f"B={B}, H={H}, N={N}, Warmup={args.warmup}, "
458 | f"Iters={args.iters}",
459 | fontsize=15,
460 | fontweight="bold",
461 | )
462 | ax.set_xlabel("Headdim(D)", fontsize=15, fontweight="bold")
463 | ax.set_ylabel("TFLOPS", fontsize=15, fontweight="bold")
464 | ax.grid(True)
465 |
466 | get_best_tflops()
467 | new_statis_info = sort_tflops_by_headdim()
468 |
469 | ax.set_xticks(np.arange(0, len(new_statis_info["headdim"]), 1))
470 | ax.set_xticklabels(
471 | new_statis_info["headdim"],
472 | rotation=45,
473 | ha="right",
474 | fontsize=15,
475 | fontweight="bold",
476 | )
477 | exclude_tags = []
478 | exclude_tags.append("headdim")
479 | exclude_tags = set(exclude_tags)
480 |
481 | draw_tags = list(new_statis_info.keys())
482 | draw_tags.remove("headdim")
483 | draw_tags.remove("ffpa+acc-f32+L1(speedup)")
484 | draw_tags.remove("ffpa+acc-f16+L1(speedup)")
485 | draw_tags = set(draw_tags)
486 | draw_tags.add("ffpa+acc-f32+L1(best)")
487 | draw_tags.add("ffpa+acc-f16+L1(best)")
488 | draw_tags = sorted(list(draw_tags))
489 |
490 | def skip_it(tag: str) -> bool:
491 | for etag in exclude_tags:
492 | if etag in tag:
493 | return True
494 | if tag not in draw_tags:
495 | return True
496 | return False
497 |
498 | for tag, tflops in new_statis_info.items():
499 | if skip_it(tag):
500 | continue
501 | if tag == "(sdpa)":
502 | ax.plot(tflops, label=tag, linewidth=3, color="green")
503 | elif tag == "(unfused)":
504 | ax.plot(tflops, label=tag, linewidth=3, color="black")
505 | else:
506 | if "ffpa+acc-f32+L1(best)" in tag:
507 | ax.plot(tflops, label=tag, linewidth=4, color="blue")
508 | elif "ffpa+acc-f16+L1(best)" in tag:
509 | ax.plot(tflops, label=tag, linewidth=4, color="red")
510 | else:
511 | ax.plot(tflops, label=tag, linestyle="--")
512 |
513 | for label in ax.get_yticklabels():
514 | label.set_fontweight("bold")
515 |
516 | ax.legend()
517 | device_name = ENV.get_device_name().replace(" ", "_")
518 | if args.save_tag:
519 | save_path = f"{args.save_dir}/{device_name}_{args.save_tag}.png"
520 | else:
521 | save_path = f"{args.save_dir}/{device_name}.png"
522 | os.makedirs(args.save_dir, exist_ok=True)
523 | plt.tight_layout()
524 | plt.savefig(save_path, dpi=300)
525 | pretty_print_line(f"plot FFPA TFLOPS done, saved as {save_path}")
526 | plt.close()
527 | plot_speedup_bar(
528 | new_statis_info["ffpa+acc-f32+L1(speedup)"],
529 | "ffpa+acc+f32+L1",
530 | new_statis_info["headdim"],
531 | )
532 | plot_speedup_bar(
533 | new_statis_info["ffpa+acc-f16+L1(speedup)"],
534 | "ffpa+acc+f16+L1",
535 | new_statis_info["headdim"],
536 | )
537 |
538 |
539 | def get_qkvo(B, H, N, D):
540 | if not (args.no_rand_q or args.no_rand_qkv):
541 | q = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
542 | else:
543 | q = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
544 | if not (args.no_rand_k or args.no_rand_qkv):
545 | k = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
546 | else:
547 | k = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
548 | if not (args.no_rand_v or args.no_rand_qkv):
549 | v = torch.randn((B, H, N, D), dtype=torch.half, device="cuda")
550 | else:
551 | v = torch.ones(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
552 |
553 | o = torch.zeros(B, H, N, D, device="cuda", dtype=torch.half).contiguous()
554 | # transpose (H,N) -> (N,H) for FA2.
555 | fq = q.transpose(1, 2).contiguous()
556 | fk = k.transpose(1, 2).contiguous()
557 | fv = v.transpose(1, 2).contiguous()
558 | # transpose (N,D) -> (D,N) for V smem swizzle.
559 | tk = k.transpose(-2, -1).contiguous() # [B,H,N,D] -> [B,H,D,N]
560 | tv = v.transpose(-2, -1).contiguous() # [B,H,N,D] -> [B,H,D,N]
561 |
562 | return q, k, v, o, fq, fk, fv, tk, tv
563 |
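# Shape summary, per the transposes above: q/k/v/o are [B, H, N, D];
# fq/fk/fv are [B, N, H, D] (flash-attn layout); tk/tv are [B, H, D, N].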
564 |
565 | # un-fused naive attn
566 | def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
567 | att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))
568 | att = F.softmax(att, dim=-1)
569 | y = att @ v
570 | return y
571 |
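# i.e. y = softmax(Q @ K^T / sqrt(D)) @ V with the full N x N score matrix
# materialized, so activation memory grows as O(N^2) -- the baseline that the
# fused kernels (sdpa/flash/ffpa) below avoid.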
572 |
573 | def sdpa(q: Tensor, k: Tensor, v: Tensor, use_flash: bool = False):
574 | if not use_flash:
575 | with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
576 | out: Tensor = F.scaled_dot_product_attention(q, k, v)
577 | else:
578 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
579 | out: Tensor = F.scaled_dot_product_attention(q, k, v)
580 | return out
581 |
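# sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION) pins PyTorch SDPA to the
# memory-efficient backend; the FLASH_ATTENTION backend is only requested
# when D <= 256 (see the partial(sdpa, use_flash=(D <= 256)) calls below).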
582 |
583 | def check_all_close(
584 | out_flash_or_sdpa: torch.Tensor,
585 | out_mma: torch.Tensor,
586 | tag: str = "out_mma",
587 | check_all: bool = False,
588 | is_flash: bool = False,
589 | ):
590 | if any((out_flash_or_sdpa is None, out_mma is None)):
591 | return
592 | if is_flash:
593 | true_tag = "out_flash"
594 | out_flash_or_sdpa = out_flash_or_sdpa.transpose(1, 2)
595 | else:
596 | true_tag = "out_sdpa"
597 | if check_all:
598 |         for i in range(N // 8):
599 | if i < 4:
600 | pretty_print_line()
601 | print(f"{true_tag}[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
602 | print(out_flash_or_sdpa[:, :, (i * 8) : (i + 1) * 8, :].float())
603 | print(f"{tag}[:, :, {(i * 8)}:{(i + 1) * 8}, :]:\n")
604 | print(out_mma[:, :, (i * 8) : (i + 1) * 8, :].float())
605 | pretty_print_line()
606 | diff = torch.abs(out_flash_or_sdpa - out_mma)
607 | all_close = str(torch.allclose(out_flash_or_sdpa, out_mma, atol=1e-2))
608 | pretty_print_line(
609 | f"{true_tag} vs {tag:<15}, all close: {all_close:<6}, "
610 | f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
611 | f"mean diff: {diff.mean().item():.6f}",
612 | sep="",
613 | mode="left",
614 | )
615 |
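# Note: atol=1e-2 is a deliberately loose absolute tolerance; presumably it
# accounts for fp16 accumulation drift over long sequences relative to the
# SDPA reference.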
616 |
617 | Bs = [1] if not args.B else [args.B]
618 | Hs = [48] if not args.H else [args.H]
619 | Ns = [8192] if not args.N else [args.N]
620 | Ds = list(range(320, args.MAX_D + 64, 64)) if not args.D else [args.D]
621 | # batch_size, n_head, seq_len, head_dim (B,H,N,D)
622 | BHNDs = [(B, H, N, D) for B in Bs for H in Hs for N in Ns for D in Ds]
623 | # max headdim supported for different methods. skip if D > max_D.
624 | MAX_HEADDIM_CFG: dict[str, int] = {
625 | # FFPA, SDPA, Naive MHA.
626 |     "(sdpa)": 4096, # may have no limit
627 |     "(unfused)": 4096, # may have no limit
628 |     "(ffpa+acc+f16+L1+stage1)": 1024, # may have no limit
629 |     "(ffpa+acc+f16+L1+stage2)": 1024, # may have no limit
630 |     "(ffpa+acc+f16+L1+stage3)": 1024, # may have no limit
631 |     "(ffpa+acc+f16+L1+stage4)": 1024, # may have no limit
632 |     "(ffpa+acc+f32+L1+stage1)": 1024, # may have no limit
633 |     "(ffpa+acc+f32+L1+stage2)": 1024, # may have no limit
634 |     "(ffpa+acc+f32+L1+stage3)": 1024, # may have no limit
635 |     "(ffpa+acc+f32+L1+stage4)": 1024, # may have no limit
636 | }
637 |
638 | seed = args.seed if args.seed else random.choice(range(10000))
639 | set_rand_seed(seed)
640 | pretty_print_line()
641 | pretty_print_line(
642 | f"B: batch_size, H: n_head, N: seq_len, D: head_dim, "
643 | f"seed: {seed}, Warmup: {args.warmup}, Iters: {args.iters}"
644 | )
645 |
646 | for (B, H, N, D) in BHNDs:
647 | MAX_TFLOPS = -1
648 | SDPA_TFLOPS = -1
649 | q, k, v, o, fq, fk, fv, tk, tv = get_qkvo(B, H, N, D)
650 | torch.cuda.synchronize()
651 | pretty_print_line()
652 | pretty_print_line(
653 | f"B={B}, H={H}, N={N}, D={D}, Warmup: {args.warmup}, Iters: {args.iters}"
654 | )
655 | if not use_ffpa_attn_package:
656 | # Naive MHA, FFPA, SDPA (D > 256)
657 | out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "(unfused)")
658 | out_sdpa, _ = run_benchmark(
659 | partial(sdpa, use_flash=(D <= 256)), q, k, v, "(sdpa)"
660 | )
661 | out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
662 | out_ffpa_l1_f321, _ = run_benchmark(
663 | ffpa_attn.ffpa_mma_acc_f32_L1,
664 | q,
665 | k,
666 | v,
667 | "(ffpa+acc+f32+L1+stage1)",
668 | o,
669 | stages=1,
670 | )
671 | out_ffpa_l1_f322, _ = run_benchmark(
672 | ffpa_attn.ffpa_mma_acc_f32_L1,
673 | q,
674 | k,
675 | v,
676 | "(ffpa+acc+f32+L1+stage2)",
677 | o,
678 | stages=2,
679 | )
680 | if ENV.enable_all_mutistages():
681 | out_ffpa_l1_f323, _ = run_benchmark(
682 | ffpa_attn.ffpa_mma_acc_f32_L1,
683 | q,
684 | k,
685 | v,
686 | "(ffpa+acc+f32+L1+stage3)",
687 | o,
688 | stages=3,
689 | )
690 | out_ffpa_l1_f324, _ = run_benchmark(
691 | ffpa_attn.ffpa_mma_acc_f32_L1,
692 | q,
693 | k,
694 | v,
695 | "(ffpa+acc+f32+L1+stage4)",
696 | o,
697 | stages=4,
698 | )
699 | out_ffpa_l1_f161, _ = run_benchmark(
700 | ffpa_attn.ffpa_mma_acc_f16_L1,
701 | q,
702 | k,
703 | v,
704 | "(ffpa+acc+f16+L1+stage1)",
705 | o,
706 | stages=1,
707 | )
708 | out_ffpa_l1_f162, _ = run_benchmark(
709 | ffpa_attn.ffpa_mma_acc_f16_L1,
710 | q,
711 | k,
712 | v,
713 | "(ffpa+acc+f16+L1+stage2)",
714 | o,
715 | stages=2,
716 | )
717 | if ENV.enable_all_mutistages():
718 | out_ffpa_l1_f163, _ = run_benchmark(
719 | ffpa_attn.ffpa_mma_acc_f16_L1,
720 | q,
721 | k,
722 | v,
723 | "(ffpa+acc+f16+L1+stage3)",
724 | o,
725 | stages=3,
726 | )
727 | out_ffpa_l1_f164, _ = run_benchmark(
728 | ffpa_attn.ffpa_mma_acc_f16_L1,
729 | q,
730 | k,
731 | v,
732 | "(ffpa+acc+f16+L1+stage4)",
733 | o,
734 | stages=4,
735 | )
736 | else:
737 | # Naive MHA, FFPA, SDPA (D > 256)
738 | out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "(unfused)")
739 | out_sdpa, _ = run_benchmark(
740 | partial(sdpa, use_flash=(D <= 256)), q, k, v, "(sdpa)"
741 | )
742 | out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
743 | out_ffpa_l1_f321, _ = run_benchmark(
744 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP32),
745 | q,
746 | k,
747 | v,
748 | "(ffpa+acc+f32+L1+stage1)",
749 | o,
750 | stages=1,
751 | )
752 | out_ffpa_l1_f322, _ = run_benchmark(
753 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP32),
754 | q,
755 | k,
756 | v,
757 | "(ffpa+acc+f32+L1+stage2)",
758 | o,
759 | stages=2,
760 | )
761 | if ENV.enable_all_mutistages():
762 | out_ffpa_l1_f323, _ = run_benchmark(
763 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP32),
764 | q,
765 | k,
766 | v,
767 | "(ffpa+acc+f32+L1+stage3)",
768 | o,
769 | stages=3,
770 | )
771 | out_ffpa_l1_f324, _ = run_benchmark(
772 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP32),
773 | q,
774 | k,
775 | v,
776 | "(ffpa+acc+f32+L1+stage4)",
777 | o,
778 | stages=4,
779 | )
780 | out_ffpa_l1_f161, _ = run_benchmark(
781 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP16),
782 | q,
783 | k,
784 | v,
785 | "(ffpa+acc+f16+L1+stage1)",
786 | o,
787 | stages=1,
788 | )
789 | out_ffpa_l1_f162, _ = run_benchmark(
790 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP16),
791 | q,
792 | k,
793 | v,
794 | "(ffpa+acc+f16+L1+stage2)",
795 | o,
796 | stages=2,
797 | )
798 | if ENV.enable_all_mutistages():
799 | out_ffpa_l1_f163, _ = run_benchmark(
800 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP16),
801 | q,
802 | k,
803 | v,
804 | "(ffpa+acc+f16+L1+stage3)",
805 | o,
806 | stages=3,
807 | )
808 | out_ffpa_l1_f164, _ = run_benchmark(
809 | partial(ffpa_attn.ffpa, level=ffpa_attn.L1, acc=ffpa_attn.FP16),
810 | q,
811 | k,
812 | v,
813 | "(ffpa+acc+f16+L1+stage4)",
814 | o,
815 | stages=4,
816 | )
817 | pretty_print_line()
818 |
819 | del q
820 | del k
821 | del v
822 | del o
823 | del fq
824 | del fk
825 | del fv
826 | del tk
827 | del tv
828 | torch.cuda.empty_cache()
829 | torch.cuda.synchronize()
830 | if args.check:
831 | check_all_close(out_sdpa, out_ffpa_l1_f321, "out_ffpa_l1_f321", args.check_all)
832 | check_all_close(out_sdpa, out_ffpa_l1_f322, "out_ffpa_l1_f322", args.check_all)
833 | if ENV.enable_all_mutistages():
834 | check_all_close(
835 | out_sdpa, out_ffpa_l1_f323, "out_ffpa_l1_f323", args.check_all
836 | )
837 | check_all_close(
838 | out_sdpa, out_ffpa_l1_f324, "out_ffpa_l1_f324", args.check_all
839 | )
840 | check_all_close(out_sdpa, out_ffpa_l1_f161, "out_ffpa_l1_f161", args.check_all)
841 |     check_all_close(out_sdpa, out_ffpa_l1_f162, "out_ffpa_l1_f162", args.check_all)
842 | if ENV.enable_all_mutistages():
843 | check_all_close(
844 | out_sdpa, out_ffpa_l1_f163, "out_ffpa_l1_f163", args.check_all
845 | )
846 | check_all_close(
847 | out_sdpa, out_ffpa_l1_f164, "out_ffpa_l1_f164", args.check_all
848 | )
849 | pretty_print_line()
850 |
851 | if args.gen_bench_table:
852 | gen_bench_markdown_table()
853 |
854 | if args.plot_flops:
855 | plot_tflops()
856 |
--------------------------------------------------------------------------------
/tests/test_fused_mla.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------