├── .clang-format
├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── README.md
├── bench
│   ├── CMakeLists.txt
│   └── device
│       ├── CMakeLists.txt
│       ├── harness.h
│       ├── simt_binary_or_binary_and_dsrgemm_nn_n_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_nn_t_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_nt_n_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_nt_t_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_tn_n_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_tn_t_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_tt_n_sm50.cu
│       ├── simt_binary_or_binary_and_dsrgemm_tt_t_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_nn_n_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_nn_t_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_nt_n_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_nt_t_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_tn_n_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_tn_t_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_tt_n_sm50.cu
│       ├── simt_binary_or_binary_and_ssrgemm_tt_t_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_nn_n_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_nn_t_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_nt_n_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_nt_t_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_tn_n_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_tn_t_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_tt_n_sm50.cu
│       ├── simt_maximum_minimum_dsrgemm_tt_t_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_nn_n_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_nn_t_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_nt_n_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_nt_t_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_tn_n_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_tn_t_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_tt_n_sm50.cu
│       ├── simt_maximum_minimum_ssrgemm_tt_t_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_nn_n_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_nn_t_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_nt_n_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_nt_t_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_tn_n_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_tn_t_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_tt_n_sm50.cu
│       ├── simt_maximum_multiplies_dsrgemm_tt_t_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_nn_n_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_nn_t_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_nt_n_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_nt_t_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_tn_n_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_tn_t_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_tt_n_sm50.cu
│       ├── simt_maximum_multiplies_ssrgemm_tt_t_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_nn_n_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_nn_t_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_nt_n_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_nt_t_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_tn_n_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_tn_t_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_tt_n_sm50.cu
│       ├── simt_maximum_plus_dsrgemm_tt_t_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_nn_n_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_nn_t_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_nt_n_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_nt_t_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_tn_n_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_tn_t_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_tt_n_sm50.cu
│       ├── simt_maximum_plus_ssrgemm_tt_t_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_nn_n_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_nn_t_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_nt_n_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_nt_t_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_tn_n_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_tn_t_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_tt_n_sm50.cu
│       ├── simt_minimum_maximum_dsrgemm_tt_t_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_nn_n_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_nn_t_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_nt_n_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_nt_t_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_tn_n_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_tn_t_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_tt_n_sm50.cu
│       ├── simt_minimum_maximum_ssrgemm_tt_t_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_nn_n_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_nn_t_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_nt_n_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_nt_t_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_tn_n_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_tn_t_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_tt_n_sm50.cu
│       ├── simt_minimum_multiplies_dsrgemm_tt_t_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_nn_n_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_nn_t_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_nt_n_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_nt_t_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_tn_n_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_tn_t_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_tt_n_sm50.cu
│       ├── simt_minimum_multiplies_ssrgemm_tt_t_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_nn_n_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_nn_t_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_nt_n_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_nt_t_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_tn_n_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_tn_t_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_tt_n_sm50.cu
│       ├── simt_minimum_plus_dsrgemm_tt_t_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_nn_n_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_nn_t_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_nt_n_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_nt_t_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_tn_n_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_tn_t_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_tt_n_sm50.cu
│       ├── simt_minimum_plus_ssrgemm_tt_t_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_nn_n_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_nn_t_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_nt_n_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_nt_t_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_tn_n_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_tn_t_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_tt_n_sm50.cu
│       ├── simt_plus_multiplies_dsrgemm_tt_t_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_nn_n_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_nn_t_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_nt_n_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_nt_t_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_tn_n_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_tn_t_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_tt_n_sm50.cu
│       ├── simt_plus_multiplies_ssrgemm_tt_t_sm50.cu
│       └── simt_sm50.py
├── examples
│   ├── 00_minplus_srgemm
│   │   ├── CMakeLists.txt
│   │   └── minplus_srgemm.cu
│   ├── 01_userdefined_semiring
│   │   ├── CMakeLists.txt
│   │   └── userdefined_semiring.cu
│   ├── 02_splitk_srgemm
│   │   ├── CMakeLists.txt
│   │   └── splitk_srgemm.cu
│   └── CMakeLists.txt
├── include
│   └── cuasr
│       ├── arch
│       │   └── srmma.h
│       ├── functional.h
│       ├── gemm
│       │   ├── device
│       │   │   ├── default_srgemm_configuration.h
│       │   │   ├── srgemm.h
│       │   │   └── srgemm_splitk_parallel.h
│       │   ├── epilogue
│       │   │   └── thread
│       │   │       └── semiring_linear_combination.h
│       │   ├── kernel
│       │   │   ├── default_srgemm.h
│       │   │   ├── default_srgemm_splitk_parallel.h
│       │   │   ├── srgemm.h
│       │   │   └── srgemm_splitk_parallel.h
│       │   ├── thread
│       │   │   ├── srmma.h
│       │   │   └── srmma_sm50.h
│       │   ├── threadblock
│       │   │   ├── default_srmma.h
│       │   │   ├── default_srmma_core.h
│       │   │   ├── default_srmma_core_simt.h
│       │   │   └── srmma_pipelined.h
│       │   └── warp
│       │       └── srmma_simt.h
│       └── reduction
│           ├── kernel
│           │   └── reduce_split_k.h
│           └── thread
│               ├── reduce.h
│               └── reduction_operators.h
├── test
│   ├── CMakeLists.txt
│   ├── device
│   │   ├── CMakeLists.txt
│   │   ├── simt_binary_or_binary_and_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_binary_or_binary_and_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_maximum_minimum_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_maximum_minimum_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_maximum_multiplies_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_maximum_multiplies_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_maximum_plus_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_maximum_plus_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_minimum_maximum_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_minimum_maximum_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_minimum_multiplies_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_minimum_multiplies_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_minimum_plus_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_minimum_plus_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_nn_n_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_nn_t_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_nt_n_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_nt_t_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_tn_n_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_tn_t_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_tt_n_sm50.cu
│   │   ├── simt_plus_multiplies_dsrgemm_tt_t_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_nn_n_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_nn_t_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_nt_n_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_nt_t_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_tn_n_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_tn_t_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_tt_n_sm50.cu
│   │   ├── simt_plus_multiplies_ssrgemm_tt_t_sm50.cu
│   │   ├── simt_sm50.py
│   │   └── testbed.h
│   ├── harness.cpp
│   └── regress
│       ├── CMakeLists.txt
│       ├── Matrix_test.cpp
│       ├── Srgemm_test.cu
│       ├── include
│       │   └── fwgpu
│       │       ├── Matrix.hpp
│       │       ├── cpu_srgemm.hpp
│       │       ├── gpu_srgemm.cuh
│       │       ├── gpu_srgemm.hpp
│       │       └── utils.hpp
│       ├── src
│       │   ├── cutlass_srgemm.cu
│       │   └── utils.cu
│       └── utils.cuh
└── tools
    └── include
        └── cuasr
            └── reference
                └── srgemm
                    └── host_srgemm.h

--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
---
Language: Cpp
BasedOnStyle: WebKit
AccessModifierOffset: -2
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: Right
AlignOperands: false
AlignTrailingComments: true
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: true
BinPackParameters: false
BreakBeforeBraces: Custom
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  AfterExternBlock: false
  BeforeCatch: true
  BeforeElse: true
  IndentBraces: false
  SplitEmptyFunction: true
  SplitEmptyRecord: true
  SplitEmptyNamespace: true
BreakBeforeBinaryOperators: All
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: BeforeComma
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 90
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: false
ForEachMacros:
  - foreach
  - Q_FOREACH
  - BOOST_FOREACH
IncludeBlocks: Preserve
IncludeCategories:
  - Regex: '^"(llvm|llvm-c|clang|clang-c)/'
    Priority: 2
  - Regex: '^(<|"(gtest|gmock|isl|json)/)'
    Priority: 3
  - Regex: '.*'
    Priority: 1
IncludeIsMainRegex: '(Test)?$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 2
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 99999999
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Right
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: true
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: false
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 4
UseTab: Never
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# IDE files:
.vscode

# For OSX users:
.DS_Store

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Build directory
/build/*

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# CMake crap:
CMakeLists.txt.user
CMakeCache.txt
CMakeFiles
CMakeScripts
Testing
Makefile
cmake_install.cmake
install_manifest.txt
compile_commands.json
CTestTestfile.cmake
_deps

# GDB history
.gdb_history
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "test/gtest"]
	path = test/gtest
	url = https://github.com/google/googletest.git
[submodule "bench/benchmark"]
	path = bench/benchmark
	url = https://github.com/google/benchmark.git
[submodule "cutlass"]
	path = cutlass
	url = https://github.com/NVIDIA/cutlass.git
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.13)
project(cuASR CUDA CXX)

# RELEASE config by default if none is provided:
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE "RELEASE")
  set(BUILD_TYPE_INFERRED_RELEASE TRUE)
endif()

# first convert build type string to uppercase, and then compare
string(TOUPPER "${CMAKE_BUILD_TYPE}" uppercase_CMAKE_BUILD_TYPE)
if(NOT uppercase_CMAKE_BUILD_TYPE MATCHES "^(DEBUG|RELEASE|RELWITHDEBINFO)$")
  message(FATAL_ERROR "Invalid value for CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
endif()

# always dump compiler invocation commands to compile_commands.json
# note that this does not include nvcc invocations >:(
set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE)

# Switches for testing, benchmarks and examples
option(CUASR_TEST "Build cuASR test suite. Use with CUASR_TEST_LEVEL={0|1|2}." ON)
option(CUASR_BENCH "Build cuASR benchmark suite." ON)
option(CUASR_EXAMPLE "Build cuASR examples." ON)

# By default, build fat binaries. TODO add sm_80 here
set(CUASR_CUDA_ARCHS "60 61 70 72 75" CACHE STRING
    "List of CUDA architectures to compile for.")
"60 61 70 72 75") 27 | 28 | # CUDA native compiler (nvcc) only supports upto C++14 for now 29 | find_package(CUDA REQUIRED) 30 | set(CMAKE_CXX_EXTENSIONS OFF) 31 | set(CMAKE_CXX_STANDARD 14) 32 | 33 | # C++ compiler flags for target compile options 34 | set(cuASR_CXX_FLAGS -Wall -Wextra -Wno-unused-parameter -Wno-uninitialized -Wno-strict-aliasing) 35 | set(cuASR_CXX_FLAGS_DEBUG -O0 -g3 -DDEBUG ${cuASR_CXX_FLAGS}) 36 | set(cuASR_CXX_FLAGS_RELEASE -O3 -DNDEBUG ${cuASR_CXX_FLAGS}) 37 | set(cuASR_CXX_FLAGS_RELWITHDEBINFO -O3 -g3 -DNDEBUG ${cuASR_CXX_FLAGS}) 38 | 39 | # CUDA compiler flags for target compile options 40 | set(cuASR_CUDA_FLAGS --expt-relaxed-constexpr) 41 | set(cuASR_CUDA_FLAGS_DEBUG -G ${cuASR_CUDA_FLAGS}) 42 | set(cuASR_CUDA_FLAGS_RELEASE -O3 ${cuASR_CUDA_FLAGS}) 43 | set(cuASR_CUDA_FLAGS_RELWITHDEBINFO -G ${cuASR_CUDA_FLAGS}) 44 | set(CMAKE_CUDA_ARCHITECTURES ${CUASR_CUDA_ARCHS}) 45 | 46 | # the sub-modules update themselves with git, so find git 47 | find_package(Git QUIET) 48 | 49 | # make sure we have cutlass checked-out and is up-to-date 50 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 51 | message(STATUS "Checking submodule version for NVIDIA/cutlass") 52 | execute_process( 53 | COMMAND ${GIT_EXECUTABLE} submodule update --init ${PROJECT_SOURCE_DIR}/cutlass 54 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 55 | OUTPUT_VARIABLE GIT_SUBMOD_STDOUT OUTPUT_STRIP_TRAILING_WHITESPACE 56 | ERROR_VARIABLE GIT_SUBMOD_STDERR ERROR_STRIP_TRAILING_WHITESPACE 57 | RESULT_VARIABLE GIT_SUBMOD_RESULT 58 | ) 59 | if(NOT GIT_SUBMOD_RESULT EQUAL "0") 60 | message(FATAL_ERROR "git submodule update --init failed with ${GIT_SUBMOD_RESULT}, please checkout cutlass manually. Git stdout was ${GIT_SUBMOD_STDOUT}. Git stderr was ${GIT_SUBMOD_STDERR}.") 61 | elseif(NOT ${GIT_SUBMOD_STDOUT} STREQUAL "") 62 | message(STATUS ${GIT_SUBMOD_STDOUT}) 63 | endif() 64 | endif() 65 | 66 | if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/cutlass/include") 67 | message(FATAL_ERROR "Cutlass submodule is not present and automatic checkout failed, please checkout cutlass manually.") 68 | endif() 69 | 70 | if(CUASR_TEST) 71 | enable_testing() 72 | add_subdirectory(test) 73 | if(NOT DEFINED CUASR_TEST_LEVEL) 74 | set(CUASR_TEST_LEVEL 0) 75 | endif() 76 | endif() 77 | 78 | if(CUASR_BENCH) 79 | add_subdirectory(bench) 80 | if(NOT DEFINED CUASR_BENCH_LEVEL) 81 | set(CUASR_BENCH_LEVEL 0) 82 | endif() 83 | endif() 84 | 85 | if(CUASR_EXAMPLE) 86 | add_subdirectory(examples) 87 | endif() 88 | 89 | message(STATUS "") 90 | message(STATUS "BUILD SUMMARY:") 91 | message(STATUS " Build type : ${uppercase_CMAKE_BUILD_TYPE}") 92 | message(STATUS " CMAKE_GENERATOR : ${CMAKE_GENERATOR}") 93 | message(STATUS " C++ Compiler : ${CMAKE_CXX_COMPILER}") 94 | message(STATUS " C++ Compiler version : ${CMAKE_CXX_COMPILER_VERSION}") 95 | message(STATUS " CUDA Compiler : ${CMAKE_CUDA_COMPILER}") 96 | message(STATUS " CUDA Compiler version: ${CMAKE_CUDA_COMPILER_VERSION}") 97 | message(STATUS " Build tests : ${CUASR_TEST}") 98 | message(STATUS " Test level : ${CUASR_TEST_LEVEL}") 99 | message(STATUS " Build benchmarks : ${CUASR_BENCH}") 100 | message(STATUS " Bench level : ${CUASR_BENCH_LEVEL}") 101 | message(STATUS " Build examples : ${CUASR_EXAMPLE}") 102 | message(STATUS " Found CUDA? 
: ${CUDA_FOUND}") 103 | message(STATUS " CXX flags : ${cuASR_CXX_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}") 104 | message(STATUS " CUDA flags : ${cuASR_CUDA_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}") 105 | if (BUILD_TYPE_INFERRED_RELEASE) 106 | message(WARNING "No build type provided, defaulted to RELEASE configuration.") 107 | endif() 108 | -------------------------------------------------------------------------------- /bench/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # first make sure we have benchmark checked-out and its up-to-date 2 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 3 | message(STATUS "Checking submodule version for google/benchmark") 4 | execute_process( 5 | COMMAND ${GIT_EXECUTABLE} submodule update --init ${PROJECT_SOURCE_DIR}/bench/benchmark 6 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 7 | OUTPUT_VARIABLE GIT_SUBMOD_STDOUT OUTPUT_STRIP_TRAILING_WHITESPACE 8 | ERROR_VARIABLE GIT_SUBMOD_STDERR ERROR_STRIP_TRAILING_WHITESPACE 9 | RESULT_VARIABLE GIT_SUBMOD_RESULT 10 | ) 11 | if(NOT GIT_SUBMOD_RESULT EQUAL "0") 12 | message(FATAL_ERROR "git submodule update --init failed with ${GIT_SUBMOD_RESULT}, please checkout benchmark manually. Git stdout was ${GIT_SUBMOD_STDOUT}. Git stderr was ${GIT_SUBMOD_STDERR}.") 13 | elseif(NOT ${GIT_SUBMOD_STDOUT} STREQUAL "") 14 | message(STATUS ${GIT_SUBMOD_STDOUT}) 15 | endif() 16 | endif() 17 | 18 | if(NOT EXISTS "${PROJECT_SOURCE_DIR}/bench/benchmark/include") 19 | message(FATAL_ERROR "GTest submodule is not present and automatic checkout failed, please checkout benchmark manually.") 20 | endif() 21 | 22 | set(BENCHMARK_ENABLE_TESTING OFF) 23 | add_subdirectory(benchmark) 24 | add_subdirectory(device) 25 | -------------------------------------------------------------------------------- /bench/device/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SIMT_BENCH_SRCS CONFIGURE_DEPENDS *.cu) 2 | add_executable(cuasr_bench_srgemm_device 3 | ${SIMT_BENCH_SRCS} 4 | ) 5 | target_include_directories( 6 | cuasr_bench_srgemm_device 7 | PRIVATE 8 | ${PROJECT_SOURCE_DIR}/include/ 9 | ${PROJECT_SOURCE_DIR}/tools/include/ 10 | ${PROJECT_SOURCE_DIR}/cutlass/include/ 11 | ${PROJECT_SOURCE_DIR}/cutlass/tools/util/include/ 12 | ) 13 | target_link_libraries(cuasr_bench_srgemm_device 14 | benchmark 15 | benchmark_main 16 | ${cuASR_LIB_NAME} 17 | ) 18 | if(NOT DEFINED CUASR_BENCH_LEVEL) 19 | set(CUASR_BENCH_LEVEL 0) 20 | endif() 21 | target_compile_definitions(cuasr_bench_srgemm_device 22 | PRIVATE CUASR_BENCH_LEVEL=${CUASR_BENCH_LEVEL} 23 | ) 24 | -------------------------------------------------------------------------------- /bench/device/harness.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cutlass/util/distribution.h" 4 | #include "cutlass/util/host_tensor.h" 5 | #include "cutlass/util/reference/host/tensor_compare.h" 6 | #include "cutlass/util/reference/host/tensor_copy.h" 7 | #include "cutlass/util/reference/host/tensor_fill.h" 8 | #include "cutlass/util/reference/host/tensor_norm.h" 9 | #include "cutlass/util/tensor_view_io.h" 10 | 11 | #include "cuasr/reference/srgemm/host_srgemm.h" 12 | 13 | 14 | namespace cuasr { 15 | namespace bench { 16 | namespace device { 17 | 18 | 19 | namespace { 20 | inline char const *to_string(cutlass::Status status) { 21 | switch (status) { 22 | case cutlass::Status::kSuccess: 23 | return "kSuccess"; 24 | case 
/bench/device/harness.h:
--------------------------------------------------------------------------------
#pragma once

#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/tensor_view_io.h"

#include "cuasr/reference/srgemm/host_srgemm.h"


namespace cuasr {
namespace bench {
namespace device {


namespace {
inline char const *to_string(cutlass::Status status) {
  switch (status) {
    case cutlass::Status::kSuccess:
      return "kSuccess";
    case cutlass::Status::kErrorMisalignedOperand:
      return "kErrorMisalignedOperand";
    case cutlass::Status::kErrorInvalidLayout:
      return "kErrorInvalidLayout";
    case cutlass::Status::kErrorInvalidProblem:
      return "kErrorInvalidProblem";
    case cutlass::Status::kErrorNotSupported:
      return "kErrorNotSupported";
    case cutlass::Status::kErrorWorkspaceNull:
      return "kErrorWorkspaceNull";
    case cutlass::Status::kErrorInternal:
      return "kErrorInternal";
    case cutlass::Status::kInvalid:
      return "kInvalid";
    default:
      break;
  }
  return "invalid";
}
}

// Given a SIMT SRGEMM, sets up host and device tensors for the benchmark loop
template <typename Srgemm>
class BenchHarness {
  using ElementAccumulator = typename Srgemm::ElementAccumulator;
  using ElementCompute =
      typename Srgemm::SrgemmKernel::Epilogue::OutputOp::ElementCompute;

  cutlass::gemm::GemmCoord problem_size;
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  uint64_t seed;

  cutlass::HostTensor<typename Srgemm::ElementA, typename Srgemm::LayoutA> tensor_A;
  cutlass::HostTensor<typename Srgemm::ElementB, typename Srgemm::LayoutB> tensor_B;
  cutlass::HostTensor<typename Srgemm::ElementC, typename Srgemm::LayoutC> tensor_C;
  cutlass::HostTensor<typename Srgemm::ElementC, typename Srgemm::LayoutC> tensor_D;
  cutlass::HostTensor<typename Srgemm::ElementC, typename Srgemm::LayoutC> reference_D;

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
      cutlass::TensorView<Element, Layout> view,
      cutlass::Distribution::Kind dist_kind,
      uint64_t seed) {
    if (dist_kind == cutlass::Distribution::Uniform) {
      double scope_max, scope_min;
      int bits_input  = cutlass::sizeof_bits<Element>::value;
      int bits_output = cutlass::sizeof_bits<typename Srgemm::ElementC>::value;

      if (bits_input == 1) {
        scope_max = 2;
        scope_min = 0;
      }
      else if (bits_input <= 8) {
        scope_max = 2;
        scope_min = -2;
      }
      else if (bits_output == 16) {
        scope_max = 5;
        scope_min = -5;
      }
      else {
        scope_max = 8;
        scope_min = -8;
      }

      cutlass::reference::host::TensorFillRandomUniform(
          view, seed, scope_max, scope_min, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {
      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {
      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {
      cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
    }
    else {
      return false;
    }
    return true;
  }

public:
  BenchHarness() = delete;

  // Methods
  BenchHarness(
      cutlass::gemm::GemmCoord problem_size_,
      cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
      cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
      cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
      uint64_t seed_ = 2080)
      : problem_size(problem_size_)
      , init_A(init_A_)
      , init_B(init_B_)
      , init_C(init_C_)
      , seed(seed_) {
    this->initialize(problem_size);
  }


  // Initializes data structures on both host-device side for benchmark
  auto initialize(cutlass::gemm::GemmCoord problem_size) -> void {
    // Allocate the GEMM workspace
    tensor_A.resize(problem_size.mk());
    tensor_B.resize(problem_size.kn());
    tensor_C.resize(problem_size.mn());
    tensor_D.resize(problem_size.mn());
    reference_D.resize(problem_size.mn(), false);

    initialize_tensor(tensor_A.host_view(), init_A, seed + 2019);
    initialize_tensor(tensor_B.host_view(), init_B, seed + 2018);
    initialize_tensor(tensor_C.host_view(), init_C, seed + 2017);

    // It is possible to randomly initialize to all zeros, so override this with
    // non-zeros in the upper left corner of each operand.
    tensor_A.host_view().at({ 0, 0 }) = typename Srgemm::ElementA(1);
    tensor_B.host_view().at({ 0, 0 }) = typename Srgemm::ElementB(1);
    tensor_C.host_view().at({ 0, 0 }) = typename Srgemm::ElementC(1);

    cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_D.sync_device();
  }

  // Runs one loop of the benchmark on initialized tensors
  auto
  run(int split_k_slices = 1,
      ElementCompute alpha = ElementCompute(Srgemm::MultiplicationOp::Identity),
      ElementCompute beta  = ElementCompute(Srgemm::MultiplicationOp::Identity))
      -> cutlass::Status {
    // Initialize the GEMM operator
    typename Srgemm::Arguments arguments {
      problem_size,          //
      tensor_A.device_ref(), //
      tensor_B.device_ref(), //
      tensor_C.device_ref(), //
      tensor_D.device_ref(), //
      { alpha, beta },       //
      split_k_slices         //
    };

    Srgemm gemm_op;
    size_t workspace_size = Srgemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    cutlass::Status status = gemm_op.initialize(arguments, workspace.get());

    // Run the GEMM
    status = gemm_op();
    return status;
  }
};


} // namespace device
} // namespace bench
} // namespace cuasr
--------------------------------------------------------------------------------
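As a usage sketch of the harness above (this is not one of the generated benchmark sources, which come from simt_sm50.py; the Srgemm instantiation, benchmark name, and problem sizes are illustrative assumptions):

#include <benchmark/benchmark.h>

#include "harness.h"
#include "cuasr/functional.h"
#include "cuasr/gemm/device/default_srgemm_configuration.h"
#include "cuasr/gemm/device/srgemm.h"

// Hypothetical benchmark body: construct the harness once per registered size
// and time repeated SRGEMM launches on the pre-initialized device tensors.
static void BM_simt_minimum_plus_ssrgemm(benchmark::State &state) {
  using RowMajor = cutlass::layout::RowMajor;
  using Srgemm   = cuasr::gemm::device::Srgemm<
      cuasr::minimum<float>, cuasr::plus<float>, // tropical semiring
      float, RowMajor, float, RowMajor, float, RowMajor, float>;

  int const dim = static_cast<int>(state.range(0));
  cuasr::bench::device::BenchHarness<Srgemm> harness({ dim, dim, dim });

  for (auto _ : state) {
    auto status = harness.run();
    cudaDeviceSynchronize();
    if (status != cutlass::Status::kSuccess) {
      state.SkipWithError("SRGEMM launch failed");
      break;
    }
  }
}
BENCHMARK(BM_simt_minimum_plus_ssrgemm)->Arg(1024)->Arg(2048);
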
/examples/00_minplus_srgemm/CMakeLists.txt:
--------------------------------------------------------------------------------
add_executable(minplus_srgemm
  minplus_srgemm.cu
)
target_include_directories(minplus_srgemm
  PRIVATE
  ${PROJECT_SOURCE_DIR}/include
  ${PROJECT_SOURCE_DIR}/cutlass/include
  ${CUDA_INCLUDE_DIRS}
)
target_compile_options(minplus_srgemm
  PUBLIC
  # C++ compiler flags
  $<$<COMPILE_LANGUAGE:CXX>:${cuASR_CXX_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}>

  # CUDA compiler flags
  $<$<COMPILE_LANGUAGE:CUDA>:${cuASR_CUDA_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}>
)
--------------------------------------------------------------------------------
/examples/00_minplus_srgemm/minplus_srgemm.cu:
--------------------------------------------------------------------------------
#include <chrono>
#include <iostream>
#include <random>

#include "cuasr/gemm/device/default_srgemm_configuration.h"
#include "cuasr/gemm/device/srgemm.h"
#include "cuasr/functional.h"

auto cuasr_minplus_srsgemm_nt_n(
    int M,
    int N,
    int K,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float *C,
    int ldc,
    float *D,
    bool do_epilogue_min,
    cudaStream_t stream = nullptr) -> int {
  // compile time configuration of this srgemm kernel using OperatorClass
  using OperatorClass    = cutlass::arch::OpClassSimt;
  using SmArch           = cutlass::arch::Sm50;
  using AdditionOp       = cuasr::minimum<float>;
  using MultiplicationOp = cuasr::plus<float>;

  using TropicalConfig = typename cuasr::gemm::device::DefaultSemiRingConfiguration<
      float, float, float, float, OperatorClass, //
      AdditionOp, MultiplicationOp, SmArch>;

  using ColumnMajor = cutlass::layout::ColumnMajor;
  using RowMajor    = cutlass::layout::RowMajor;

  using cuASR_MinPlus_SGEMM = cuasr::gemm::device::Srgemm<
      AdditionOp,       // Thread level SemiRing operator
      MultiplicationOp, // Thread level SemiRing operator
      float,            // element type of A
      ColumnMajor,      // layout of A
      float,            // element type of B
      RowMajor,         // layout of B
      float,            // element type of C
      ColumnMajor,      // layout of C
      float             // element type of D
      >;

  float alpha = MultiplicationOp::Identity;
  float beta
      = do_epilogue_min ? MultiplicationOp::Identity : MultiplicationOp::Annihilator;

  // construct kernel arguments struct
  cuASR_MinPlus_SGEMM::Arguments args(
      { M, N, K },    // Problem dimensions
      { A, lda },     // Tensor-ref for source matrix A
      { B, ldb },     // Tensor-ref for source matrix B
      { C, ldc },     // Tensor-ref for source matrix C
      { D, ldc },     // Tensor-ref for destination matrix D
      { alpha, beta } //
  );

  // launch SRGEMM kernel
  cuASR_MinPlus_SGEMM minplus_gemm;
  cutlass::Status status = minplus_gemm(args, nullptr, stream);
  return static_cast<int>(status);
}

auto cuasr_minplus_srsgemm_nt_n(
    int M,
    int N,
    int K,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float *C,
    int ldc,
    bool do_epilogue_min,
    cudaStream_t stream) -> int {
  return cuasr_minplus_srsgemm_nt_n(
      M, N, K, A, lda, B, ldb, C, ldc, C, do_epilogue_min, stream);
}

auto rng_init_matrix(float *buf, int len, int seed, float min = 0.5, float max = 1.5)
    -> void {
  auto rng  = std::mt19937_64(seed);
  auto dist = std::uniform_real_distribution<float>(min, max);
  for (auto i = 0; i < len; ++i) {
    buf[i] = dist(rng);
  }
}

int main() {
  using namespace std::chrono;
  // problem size
  constexpr int M       = 4096;
  constexpr int N       = 4096;
  constexpr int K       = 4096;
  constexpr int repeats = 1;

  std::cout << "Running tropical SRGEMM on A = " << M << 'x' << K << " and B = " << K
            << 'x' << N << '\n';

  std::cout << "Allocating and initializing host/device buffers\n";
  float *A = new float[M * K];
  float *B = new float[K * N];
  float *C = new float[M * N];

  rng_init_matrix(A, M * K, 3090 + 0);
  rng_init_matrix(B, K * N, 3090 + 1);
  rng_init_matrix(C, M * N, 3090 + 2);

  float *d_A, *d_B, *d_C;
  cudaMalloc((void **)&d_A, sizeof(float) * M * K);
  cudaMalloc((void **)&d_B, sizeof(float) * K * N);
  cudaMalloc((void **)&d_C, sizeof(float) * M * N);

  cudaMemcpy(d_A, A, sizeof(float) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, B, sizeof(float) * K * N, cudaMemcpyHostToDevice);
  cudaMemcpy(d_C, C, sizeof(float) * M * N, cudaMemcpyHostToDevice);

  auto retval = 0;
  auto start  = high_resolution_clock::now();
  for (int i = 0; i < repeats; ++i) {
    retval |= cuasr_minplus_srsgemm_nt_n(M, N, K, d_A, M, d_B, K, d_C, M, true, nullptr);
    cudaDeviceSynchronize();
  }
  auto end   = high_resolution_clock::now();
  auto delta = duration_cast<nanoseconds>(end - start).count();

  if (retval) {
    std::cout << "Error code " << retval << '\n';
    return retval;
  }
std::cout << "Min-Plus SRGEMM FLOP/s = " << (repeats * 2.0 * M * N * K) / (delta / 1'000'000'000.0) 136 | << '\n'; 137 | return 0; 138 | } 139 | -------------------------------------------------------------------------------- /examples/01_userdefined_semiring/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(userdefined_semiring 2 | userdefined_semiring.cu 3 | ) 4 | target_include_directories(userdefined_semiring 5 | PRIVATE 6 | ${CUDA_INCLUDE_DIRS} 7 | ${PROJECT_SOURCE_DIR}/include 8 | ${PROJECT_SOURCE_DIR}/tools/include/ 9 | ${PROJECT_SOURCE_DIR}/cutlass/include 10 | ${PROJECT_SOURCE_DIR}/cutlass/tools/util/include/ 11 | ) 12 | target_compile_options(userdefined_semiring 13 | PUBLIC 14 | # C++ compiler flags 15 | $<$,$>: 16 | ${cuASR_CXX_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}> 17 | 18 | # CUDA compiler flags 19 | $<$,$>: 20 | ${cuASR_CUDA_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}> 21 | ) 22 | -------------------------------------------------------------------------------- /examples/01_userdefined_semiring/userdefined_semiring.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuasr/gemm/device/default_srgemm_configuration.h" 6 | #include "cuasr/gemm/device/srgemm.h" 7 | #include "cuasr/functional.h" 8 | 9 | #include "cuasr/reference/srgemm/host_srgemm.h" 10 | 11 | /* cuASR for Galois Field Semiring GEMM : A Demo of cuasr extension 12 | * 13 | * In this example, we show how to define a custom semiring GEMM operator 14 | * that is not supported by the provided default SRGEMM configurations in cuASR. 15 | * 16 | * Galois Field SRGEMM explained here is an implementation of GEMM over GF(2) field 17 | * arithmetic. cuasr/functional.h already contains an implementation of binary_and 18 | * operation, so we must define a binary_xor here in order to define our own out of 19 | * library ring. 20 | * 21 | * GF(2) GEMM: 22 | * Addition operator = binary XOR 23 | * Multiplication Operator = binary AND 24 | * Zero = Addition Identity = false 25 | * Multiplicative Annihilator = false 26 | * 27 | * The primary thing that needs to be done for this is contained in the anonymous 28 | * namespace below. All cuasr ring operators are defined as default constructible structs 29 | * that contain many overloads of operator() with which the cuasr SRGEMM core kernel can 30 | * invoke them. Although verbose, the different scalar and cutlass::Array overloads 31 | * of each operator allow for optimizations to be done, primarily for unrolling. These 32 | * structs need minimal knowledge of CUDA and are still quite short to implement at around 33 | * 50 lines. 34 | * 35 | * This operator struct must also contain a constexpr definition of the Identity and 36 | * Annihilator elements for the user defined operator, as these are used within the core 37 | * cuasr SRGEMM kernel to initialize the accumulators and during the epilogue to see if a 38 | * load from the C matrix is needed. In our case of xor operation, this is as simple as 39 | * including `static T constexpr Identity = static_cast(false);` in the struct 40 | * definition. 41 | * 42 | * After the operator struct is defined, the rest is some simple boilerplate for 43 | * instantiating the cuasr::gemm::device::Srgemm template such as input matrix data types, 44 | * leading dimensions, alignments as well as the tile shapes for threadblock, warp and 45 | * instruction level SRGEMM. 
/examples/01_userdefined_semiring/CMakeLists.txt:
--------------------------------------------------------------------------------
add_executable(userdefined_semiring
  userdefined_semiring.cu
)
target_include_directories(userdefined_semiring
  PRIVATE
  ${CUDA_INCLUDE_DIRS}
  ${PROJECT_SOURCE_DIR}/include
  ${PROJECT_SOURCE_DIR}/tools/include/
  ${PROJECT_SOURCE_DIR}/cutlass/include
  ${PROJECT_SOURCE_DIR}/cutlass/tools/util/include/
)
target_compile_options(userdefined_semiring
  PUBLIC
  # C++ compiler flags
  $<$<COMPILE_LANGUAGE:CXX>:${cuASR_CXX_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}>

  # CUDA compiler flags
  $<$<COMPILE_LANGUAGE:CUDA>:${cuASR_CUDA_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}>
)
--------------------------------------------------------------------------------
/examples/01_userdefined_semiring/userdefined_semiring.cu:
--------------------------------------------------------------------------------
#include <chrono>
#include <iostream>
#include <random>

#include "cuasr/gemm/device/default_srgemm_configuration.h"
#include "cuasr/gemm/device/srgemm.h"
#include "cuasr/functional.h"

#include "cuasr/reference/srgemm/host_srgemm.h"

/* cuASR for Galois Field Semiring GEMM : A Demo of cuasr extension
 *
 * In this example, we show how to define a custom semiring GEMM operator
 * that is not supported by the provided default SRGEMM configurations in cuASR.
 *
 * Galois Field SRGEMM explained here is an implementation of GEMM over GF(2) field
 * arithmetic. cuasr/functional.h already contains an implementation of the binary_and
 * operation, so we must define a binary_xor here in order to define our own
 * out-of-library ring.
 *
 * GF(2) GEMM:
 *    Addition operator          = binary XOR
 *    Multiplication Operator    = binary AND
 *    Zero = Addition Identity   = false
 *    Multiplicative Annihilator = false
 *
 * The primary thing that needs to be done for this is contained in the anonymous
 * namespace below. All cuasr ring operators are defined as default constructible
 * structs that contain many overloads of operator() with which the cuasr SRGEMM core
 * kernel can invoke them. Although verbose, the different scalar and cutlass::Array
 * overloads of each operator allow for optimizations to be done, primarily for
 * unrolling. These structs need minimal knowledge of CUDA and are still quite short
 * to implement at around 50 lines.
 *
 * This operator struct must also contain a constexpr definition of the Identity and
 * Annihilator elements for the user defined operator, as these are used within the
 * core cuasr SRGEMM kernel to initialize the accumulators and during the epilogue to
 * see if a load from the C matrix is needed. In our case of the xor operation, this
 * is as simple as including `static T constexpr Identity = static_cast<T>(false);`
 * in the struct definition.
 *
 * After the operator struct is defined, the rest is some simple boilerplate for
 * instantiating the cuasr::gemm::device::Srgemm template, such as input matrix data
 * types, leading dimensions, alignments, as well as the tile shapes for threadblock,
 * warp and instruction level SRGEMM. In the case of SIMT SRGEMM, the only valid
 * `InstructionShape` is <1, 1, 1> since each lane processes a single element at a
 * time. ThreadblockShape and WarpShape are the two main points of optimization as
 * they affect the amount of shared memory and register usage and unrolling. Since
 * SRGEMM only supports SIMT instructions, OperatorClass must be set to OpClassSimt.
 * SmArch can be set to Sm50 for SRGEMM on Maxwell or later, which only supports
 * 2 stage SRGEMM. Support for Sm80 (Ampere) multi-stage pipelined SRGEMM is planned
 * for the future.
 */

// clang-format off
namespace {
template <typename T, int N = 1>
struct binary_xor {
  static T constexpr Identity = static_cast<T>(false);

  // expose base scalar operator
  __host__ __device__
  T operator()(T lhs, T const &rhs) const {
    lhs ^= rhs;
    return lhs;
  }

  __host__ __device__
  cutlass::Array<T, N>
  operator()(cutlass::Array<T, N> const &lhs, cutlass::Array<T, N> const &rhs) const {
    cutlass::Array<T, N> result;
    #pragma unroll
    for (int i = 0; i < N; ++i) {
      result[i] = this->operator()(lhs[i], rhs[i]);
    }
    return result;
  }

  __host__ __device__
  cutlass::Array<T, N>
  operator()(cutlass::Array<T, N> const &lhs, T const &scalar) const {
    cutlass::Array<T, N> result;
    #pragma unroll
    for (int i = 0; i < N; ++i) {
      result[i] = this->operator()(lhs[i], scalar);
    }
    return result;
  }

  __host__ __device__
  cutlass::Array<T, N>
  operator()(T const &scalar, cutlass::Array<T, N> const &rhs) const {
    cutlass::Array<T, N> result;
    #pragma unroll
    for (int i = 0; i < N; ++i) {
      result[i] = this->operator()(scalar, rhs[i]);
    }
    return result;
  }
};
} // namespace
// clang-format on

// GF(2) xor-and SRGEMM
auto cuasr_gf_srgemm_nnn(
    int M,
    int N,
    int K,
    int const *A,
    int lda,
    int const *B,
    int ldb,
    int *C,
    int ldc,
    int *D,
    bool do_epilogue_and,
    cudaStream_t stream = nullptr) -> int {
  // compile time configuration of this srgemm kernel
  using OperatorClass = cutlass::arch::OpClassSimt;
  using SmArch        = cutlass::arch::Sm50;

  using AdditionOp       = binary_xor<int>;
  using MultiplicationOp = cuasr::binary_and<int>;
  using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
      AdditionOp, MultiplicationOp, int, 1>;

  static int constexpr AlignmentA = 1;
  static int constexpr AlignmentB = 1;
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
  using WarpShape        = cutlass::gemm::GemmShape<16, 64, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using ThreadblockSwizzle =
      typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
  static int constexpr Stages = 2;

  using RowMajor = cutlass::layout::RowMajor;

  using cuASRGaloisFieldSrgemm = cuasr::gemm::device::Srgemm<
      AdditionOp,         // Thread level SemiRing operator
      MultiplicationOp,   // Thread level SemiRing operator
      int,                // element type of A
      RowMajor,           // layout of A
      int,                // element type of B
      RowMajor,           // layout of B
      int,                // element type of C
      RowMajor,           // layout of C
      int,                // element type of D
      OperatorClass,      // Logical operator class (SIMT/Tensor)
      SmArch,             // CUDA architecture
      ThreadblockShape,   // GEMM shape at CTA level
      WarpShape,          // GEMM shape at warp level
      InstructionShape,   // GEMM shape at thread level
      EpilogueOutputOp,   // Epilogue operator at thread level
      ThreadblockSwizzle, // GEMM threadblock swizzler
      Stages,             // Pipeline stages for shmem
      AlignmentA,         // Alignment of A elements
      AlignmentB,         // Alignment of B elements
      false               // SplitKSerial
      >;

  int alpha = MultiplicationOp::Identity;
  int beta  = do_epilogue_and ? MultiplicationOp::Identity : MultiplicationOp::Annihilator;

  // construct kernel arguments struct
  cuASRGaloisFieldSrgemm::Arguments args(
      { M, N, K },    // Problem dimensions
      { A, lda },     // Tensor-ref for source matrix A
      { B, ldb },     // Tensor-ref for source matrix B
      { C, ldc },     // Tensor-ref for source matrix C
      { D, ldc },     // Tensor-ref for destination matrix D
      { alpha, beta } //
  );

  // launch SRGEMM kernel
  cuASRGaloisFieldSrgemm gf_srgemm;
  cutlass::Status status = gf_srgemm(args, nullptr, stream);
  return static_cast<int>(status);
}

auto cuasr_gf_srgemm_nnn(
    int M,
    int N,
    int K,
    int const *A,
    int lda,
    int const *B,
    int ldb,
    int *C,
    int ldc,
    bool do_epilogue_and,
    cudaStream_t stream) -> int {
  return cuasr_gf_srgemm_nnn(M, N, K, A, lda, B, ldb, C, ldc, C, do_epilogue_and, stream);
}

auto rng_init_matrix(int *buf, int len, int seed) -> void {
  auto rng  = std::mt19937_64(seed);
  auto dist = std::bernoulli_distribution(0.025);
  for (auto i = 0; i < len; ++i) {
    buf[i] = static_cast<int>(dist(rng));
  }
}

// compares result of SRGEMM to a CPU kernel as reference
auto compare_host_reference(
    int M,
    int N,
    int K,
    int alpha,
    int *A,
    int lda,
    int *B,
    int ldb,
    int beta,
    int *C,
    int ldc,
    int *reference_D,
    int *device_D) -> bool {
  using AdditionOp       = binary_xor<int>;
  using MultiplicationOp = cuasr::binary_and<int>;
  using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
      AdditionOp, MultiplicationOp, int, 1>;
  using RowMajor = cutlass::layout::RowMajor;

  cuasr::reference::host::Srgemm<
      AdditionOp,                                    //
      MultiplicationOp,                              //
      int, RowMajor,                                 //
      int, RowMajor,                                 //
      int, RowMajor,                                 //
      typename EpilogueOutputOp::ElementCompute,     //
      typename EpilogueOutputOp::ElementAccumulator, //
      EpilogueOutputOp>
      reference_srgemm;

  reference_srgemm(
      { M, N, K },                            //
      alpha, { A, lda }, { B, ldb },          //
      beta, { C, ldc }, { reference_D, ldc }, //
      AdditionOp::Identity);

  auto is_correct = true;
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      is_correct &= (reference_D[(ldc * n) + m] == device_D[(ldc * n) + m]);
    }
  }
  return is_correct;
}
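
// Illustrative compile-time checks (added for exposition, not in the original
// example): these are the semiring constants that drive the epilogue selection
// in main() below. XOR's additive identity is 0; AND's identity is 1 and its
// annihilator is 0, so beta = Annihilator makes the epilogue skip the C load.
static_assert(binary_xor<int>::Identity == 0, "GF(2) additive identity");
static_assert(cuasr::binary_and<int>::Identity == 1, "GF(2) multiplicative identity");
static_assert(cuasr::binary_and<int>::Annihilator == 0, "GF(2) multiplicative annihilator");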

int main() {
  using namespace std::chrono;
  // problem size
  constexpr int M = 512; // 4096
  constexpr int N = 512;
  constexpr int K = 512;
  constexpr bool do_epilogue_and = true;

  std::cout << "Running Xor-And Galois Field SRGEMM on A = " << M << 'x' << K
            << " and B = " << K << 'x' << N << '\n';

  // input matrices
  std::cout << "Allocating and initializing host/device buffers\n";
  int *A = new int[M * K];
  int *B = new int[K * N];
  int *C = new int[M * N];

  // output
  int *reference_D = new int[M * N];
  int *device_D    = new int[M * N];

  rng_init_matrix(A, M * K, 3090 + 0);
  rng_init_matrix(B, K * N, 3090 + 1);
  rng_init_matrix(C, M * N, 3090 + 2);

  int *d_A, *d_B, *d_C;
  cudaMalloc((void **)&d_A, sizeof(int) * M * K);
  cudaMalloc((void **)&d_B, sizeof(int) * K * N);
  cudaMalloc((void **)&d_C, sizeof(int) * M * N);

  cudaMemcpy(d_A, A, sizeof(int) * M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, B, sizeof(int) * K * N, cudaMemcpyHostToDevice);
  cudaMemcpy(d_C, C, sizeof(int) * M * N, cudaMemcpyHostToDevice);

  auto start = high_resolution_clock::now();

  auto retval
      = cuasr_gf_srgemm_nnn(M, N, K, d_A, M, d_B, K, d_C, M, do_epilogue_and, nullptr);
  retval |= cudaDeviceSynchronize();
  auto end = high_resolution_clock::now();
  duration<double> delta = (end - start);

  if (retval) {
    std::cout << "Error code " << retval << '\n';
    return retval;
  }

  std::cout << "Xor-And Galois Field SRGEMM FLOP/s = "
            << (2.0 * M * N * K) / delta.count() << '\n';

  cudaMemcpy(device_D, d_C, sizeof(int) * M * N, cudaMemcpyDeviceToHost);

  // compare against host
  std::cout << "Comparing against reference host-side SRGEMM : ";
  int alpha = cuasr::binary_and<int>::Identity;
  int beta  = do_epilogue_and ? cuasr::binary_and<int>::Identity
                              : cuasr::binary_and<int>::Annihilator;
  auto is_correct = compare_host_reference(
      M, N, K, alpha, A, M, B, N, beta, C, M, reference_D, device_D);

  if (is_correct) {
    std::cout << "PASSED!\n";
  }
  else {
    std::cout << "FAILED!\n";
  }
  return !is_correct;
}
--------------------------------------------------------------------------------
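One brief usage note on the wrappers in the file above: because they take a cudaStream_t, SRGEMM launches can be enqueued on a user-managed stream. The helper below is a hypothetical sketch (it assumes device buffers allocated as in main() and square leading dimensions), not part of the example file:

// Hypothetical helper: stream-ordered launch of the GF(2) SRGEMM wrapper above.
int run_gf_srgemm_on_stream(int M, int N, int K,
                            int const *d_A, int const *d_B, int *d_C) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // leading dimensions as in main() above (square problem)
  int rc = cuasr_gf_srgemm_nnn(M, N, K, d_A, M, d_B, K, d_C, M,
                               /*do_epilogue_and=*/true, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return rc;
}
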
/examples/02_splitk_srgemm/CMakeLists.txt:
--------------------------------------------------------------------------------
add_executable(splitk_srgemm
  splitk_srgemm.cu
)
target_include_directories(splitk_srgemm
  PRIVATE
  ${CUDA_INCLUDE_DIRS}
  ${PROJECT_SOURCE_DIR}/include
  ${PROJECT_SOURCE_DIR}/tools/include/
  ${PROJECT_SOURCE_DIR}/cutlass/include
  ${PROJECT_SOURCE_DIR}/cutlass/tools/util/include/
)
target_compile_options(splitk_srgemm
  PUBLIC
  # C++ compiler flags
  $<$<COMPILE_LANGUAGE:CXX>:${cuASR_CXX_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}>

  # CUDA compiler flags
  $<$<COMPILE_LANGUAGE:CUDA>:${cuASR_CUDA_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}>
)
--------------------------------------------------------------------------------
/examples/02_splitk_srgemm/splitk_srgemm.cu:
--------------------------------------------------------------------------------
/***************************************************************************************************
 * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
 **************************************************************************************************/

#include <chrono>
#include <cstdlib>
#include <iostream>
#include <random>

#include "cuasr/functional.h"
#include "cuasr/gemm/device/default_srgemm_configuration.h"
#include "cuasr/gemm/device/srgemm.h"
#include "cuasr/gemm/device/srgemm_splitk_parallel.h"

#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/device_memory.h"

auto cuasr_splitk_minplus_srsgemm_tn_t(
    int M,
    int N,
    int K,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float *C,
    int ldc,
    float *D,
    bool do_epilogue_min,
    int split_k_slices,
    cudaStream_t stream = nullptr) -> int {
  // compile time configuration of this srgemm kernel using OperatorClass
  using OperatorClass    = cutlass::arch::OpClassSimt;
  using SmArch           = cutlass::arch::Sm50;
  using AdditionOp       = cuasr::minimum<float>;
  using MultiplicationOp = cuasr::plus<float>;

  using TropicalConfig = typename cuasr::gemm::device::DefaultSemiRingConfiguration<
      float, float, float, float, OperatorClass, //
      AdditionOp, MultiplicationOp, SmArch>;

  using ColumnMajor = cutlass::layout::ColumnMajor;
  using RowMajor    = cutlass::layout::RowMajor;

  using cuASR_SplitK_SRGEMM = cuasr::gemm::device::SrgemmSplitKParallel<
      AdditionOp,       // Thread level SemiRing operator
      MultiplicationOp, // Thread level SemiRing operator
      float,            // element type of A
      RowMajor,         // layout of A
      float,            // element type of B
      ColumnMajor,      // layout of B
      float,            // element type of C
      RowMajor,         // layout of C
      float             // element type of D
      >;

  // setup runtime configuration
  float alpha = MultiplicationOp::Identity;
  float beta
      = do_epilogue_min ? MultiplicationOp::Identity : MultiplicationOp::Annihilator;

  // construct kernel arguments struct
  cuASR_SplitK_SRGEMM::Arguments args(
      { M, N, K },     // Problem dimensions
      { A, lda },      // Tensor-ref for source matrix A
      { B, ldb },      // Tensor-ref for source matrix B
      { C, ldc },      // Tensor-ref for source matrix C
      { D, ldc },      // Tensor-ref for destination matrix D
      { alpha, beta }, // epilogue scalars
      split_k_slices   // number of K dimension slices
  );

  // using the arguments, query for the extra workspace required by the parallel reduction
  size_t workspace_size = cuASR_SplitK_SRGEMM::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // construct cuasr kernel depending on templates
  cuASR_SplitK_SRGEMM splitk_minplus_srgemm_op;

  // Initialize cuasr kernel with arguments and workspace ptr
  cutlass::Status status = splitk_minplus_srgemm_op.initialize(args, workspace.get());

  // launch split-K parallel SRGEMM kernel
  status = splitk_minplus_srgemm_op();

  return static_cast<int>(status);
}

auto cuasr_splitk_minplus_srsgemm_tn_t(
    int M,
    int N,
    int K,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float *C,
    int ldc,
    bool do_epilogue_min,
    int split_k_slices,
    cudaStream_t stream) -> int {
  return cuasr_splitk_minplus_srsgemm_tn_t(
      M, N, K, A, lda, B, ldb, C, ldc, C, do_epilogue_min, split_k_slices, stream);
}
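
// Illustrative host-side sketch (added for exposition; not used by this
// example): what split-K means for one output element of min-plus SRGEMM.
// Each of the `split_k_slices` K-slices folds a partial result starting from
// the additive identity, and the partials are then combined with the semiring
// AdditionOp -- the device-side analogue of this second loop is the parallel
// reduction kernel in include/cuasr/reduction/kernel/reduce_split_k.h.
inline float splitk_minplus_one_element(
    float const *a_row, float const *b_col, int K, int split_k_slices) {
  using AdditionOp = cuasr::minimum<float>;
  AdditionOp add;
  float result        = AdditionOp::Identity;
  int const slice_len = (K + split_k_slices - 1) / split_k_slices;
  for (int slice = 0; slice < split_k_slices; ++slice) {
    float partial   = AdditionOp::Identity; // accumulator for this K-slice
    int const k_end = (slice + 1) * slice_len < K ? (slice + 1) * slice_len : K;
    for (int k = slice * slice_len; k < k_end; ++k) {
      partial = add(partial, a_row[k] + b_col[k]);
    }
    result = add(result, partial); // reduce partials with the AdditionOp
  }
  return result;
}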

auto cuasr_minplus_srsgemm_tn_t(
    int M,
    int N,
    int K,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float *C,
    int ldc,
    float *D,
    bool do_epilogue_min,
    cudaStream_t stream = nullptr) -> int {
  // compile time configuration of this srgemm kernel using OperatorClass
  using OperatorClass    = cutlass::arch::OpClassSimt;
  using SmArch           = cutlass::arch::Sm50;
  using AdditionOp       = cuasr::minimum<float>;
  using MultiplicationOp = cuasr::plus<float>;

  using TropicalConfig = typename cuasr::gemm::device::DefaultSemiRingConfiguration<
      float, float, float, float, OperatorClass, //
      AdditionOp, MultiplicationOp, SmArch>;

  using ColumnMajor = cutlass::layout::ColumnMajor;
  using RowMajor    = cutlass::layout::RowMajor;

  using cuASR_MinPlus_SGEMM = cuasr::gemm::device::Srgemm<
      AdditionOp,       // Thread level SemiRing operator
      MultiplicationOp, // Thread level SemiRing operator
      float,            // element type of A
      RowMajor,         // layout of A
      float,            // element type of B
      ColumnMajor,      // layout of B
      float,            // element type of C
      RowMajor,         // layout of C
      float             // element type of D
      >;

  float alpha = MultiplicationOp::Identity;
  float beta
      = do_epilogue_min ? MultiplicationOp::Identity : MultiplicationOp::Annihilator;

  // construct kernel arguments struct
  cuASR_MinPlus_SGEMM::Arguments args(
      { M, N, K },    // Problem dimensions
      { A, lda },     // Tensor-ref for source matrix A
      { B, ldb },     // Tensor-ref for source matrix B
      { C, ldc },     // Tensor-ref for source matrix C
      { D, ldc },     // Tensor-ref for destination matrix D
      { alpha, beta } //
  );

  // launch SRGEMM kernel
  cuASR_MinPlus_SGEMM minplus_gemm;
  cutlass::Status status = minplus_gemm(args, nullptr, stream);
  return static_cast<int>(status);
}

auto cuasr_minplus_srsgemm_tn_t(
    int M,
    int N,
    int K,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float *C,
    int ldc,
    bool do_epilogue_min,
    cudaStream_t stream) -> int {
  return cuasr_minplus_srsgemm_tn_t(
      M, N, K, A, lda, B, ldb, C, ldc, C, do_epilogue_min, stream);
}

auto rng_init_matrix(float *buf, int len, int seed, float min = 0.5, float max = 1.5)
    -> void {
  auto rng  = std::mt19937_64(seed);
  auto dist = std::uniform_real_distribution<float>(min, max);
  for (auto i = 0; i < len; ++i) {
    buf[i] = dist(rng);
  }
}

int main(int argc, const char *argv[]) {
  using namespace std::chrono;

  // problem size
  constexpr int M = 128;
  constexpr int N = 128;
  constexpr int K = 128 * 32;
  constexpr int lda = N; // num cols if row major, num rows if col major
  constexpr int ldb = K; // num cols if row major, num rows if col major
  constexpr int ldc = N; // num cols if row major, num rows if col major
  constexpr int repeats = 10;
  int split_k_slices = 8;
  if (argc > 1) {
    split_k_slices = std::atoi(argv[1]);
  }

  std::cout << "Running tropical SRGEMM on A = " << M << 'x' << K << " and B = " << K
            << 'x' << N << " with " << split_k_slices << " split-K slices."
206 | 
207 |   std::cout << "Allocating and initializing host/device buffers\n";
208 |   float *A = new float[M * K];
209 |   float *B = new float[K * N];
210 |   float *C = new float[M * N];
211 |   float *C_splitk = new float[M * N];
212 | 
213 |   rng_init_matrix(A, M * K, 3090 + 0);
214 |   rng_init_matrix(B, K * N, 3090 + 1);
215 |   rng_init_matrix(C, M * N, 3090 + 2);
216 | 
217 |   auto retval = 0;
218 | 
219 |   float *d_A, *d_B, *d_C_regular, *d_C_splitk;
220 |   retval |= cudaMalloc((void **)&d_A, sizeof(float) * M * K);
221 |   retval |= cudaMalloc((void **)&d_B, sizeof(float) * K * N);
222 |   retval |= cudaMalloc((void **)&d_C_regular, sizeof(float) * M * N);
223 |   retval |= cudaMalloc((void **)&d_C_splitk, sizeof(float) * M * N);
224 | 
225 |   retval |= cudaMemcpy(d_A, A, sizeof(float) * M * K, cudaMemcpyHostToDevice);
226 |   retval |= cudaMemcpy(d_B, B, sizeof(float) * K * N, cudaMemcpyHostToDevice);
227 |   retval |= cudaMemcpy(d_C_regular, C, sizeof(float) * M * N, cudaMemcpyHostToDevice);
228 |   retval |= cudaMemcpy(d_C_splitk, C, sizeof(float) * M * N, cudaMemcpyHostToDevice);
229 | 
230 |   if (retval > 0) {
231 |     std::cout << "Could not allocate or copy to device.\n";
232 |     return retval;
233 |   }
234 | 
235 |   // run the tests
236 |   auto start = high_resolution_clock::now();
237 |   for (int i = 0; i < repeats; ++i) {
238 |     retval |= cuasr_minplus_srsgemm_tn_t(
239 |         M, N, K, d_A, lda, d_B, ldb, d_C_regular, ldc, true, nullptr);
240 |     retval |= cudaDeviceSynchronize();
241 |   }
242 |   auto end = high_resolution_clock::now();
243 |   auto delta_regular = duration_cast<nanoseconds>(end - start).count();
244 | 
245 |   retval = 0;
246 |   start = high_resolution_clock::now();
247 |   for (int i = 0; i < repeats; ++i) {
248 |     retval |= cuasr_splitk_minplus_srsgemm_tn_t(
249 |         M, N, K, d_A, lda, d_B, ldb, d_C_splitk, ldc, true, split_k_slices, nullptr);
250 |     retval |= cudaDeviceSynchronize();
251 |   }
252 |   end = high_resolution_clock::now();
253 |   auto delta_splitk = duration_cast<nanoseconds>(end - start).count();
254 | 
255 |   if (retval) {
256 |     std::cout << "Error code " << retval << '\n';
257 |     return retval;
258 |   }
259 | 
260 |   // print perf numbers
261 |   std::cout << "Min-Plus SRGEMM FLOP/s = "
262 |             << (repeats * 2.0 * M * N * K) / (delta_regular / 1'000'000'000.0) << '\n';
263 | 
264 |   std::cout << "Min-Plus Split-K SRGEMM FLOP/s = "
265 |             << (repeats * 2.0 * M * N * K) / (delta_splitk / 1'000'000'000.0) << '\n';
266 | 
267 |   std::cout << "Split-K speedup over regular = "
268 |             << static_cast<double>(delta_regular) / delta_splitk << '\n';
269 | 
270 |   // verify correctness
271 |   cudaMemcpy(C, d_C_regular, sizeof(float) * M * N, cudaMemcpyDeviceToHost);
272 |   cudaMemcpy(C_splitk, d_C_splitk, sizeof(float) * M * N, cudaMemcpyDeviceToHost);
273 |   auto is_correct = true;
274 |   for (int n = 0; n < N; ++n) {
275 |     for (int m = 0; m < M; ++m) {
276 |       is_correct &= (C[(ldc * n) + m] == C_splitk[(ldc * n) + m]);
277 |     }
278 |   }
279 | 
280 |   if (is_correct) {
281 |     std::cout << "Split-K matches regular SRGEMM\n";
282 |     return 0;
283 |   }
284 |   else {
285 |     std::cout << "Split-K does NOT match regular SRGEMM\n";
286 |     return 1;
287 |   }
288 | }
289 | 
--------------------------------------------------------------------------------
/examples/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(00_minplus_srgemm)
2 | add_subdirectory(01_userdefined_semiring)
3 | add_subdirectory(02_splitk_srgemm)
4 | 
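The example above only cross-checks the split-K result against the regular SRGEMM; an independent check is also easy to write on the host. The helper below is a minimal sketch, not part of the repository, assuming A is M x K row-major with row stride `lda`, B is K x N column-major with column stride `ldb`, and C is M x N row-major with row stride `ldc` (the layouts the `tn_t` kernels above use); it needs `<algorithm>` and `<limits>`.

```cpp
// Hypothetical host-side reference for min-plus SRGEMM with the epilogue-min
// enabled, i.e. C[m][n] = min(C[m][n], min_k(A[m][k] + B[k][n])).
void reference_minplus_srgemm(
    int M, int N, int K,
    float const *A, int lda,   // row-major M x K
    float const *B, int ldb,   // column-major K x N
    float *C, int ldc) {       // row-major M x N, read-modify-write
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      // additive identity of the min-plus semi-ring
      float acc = std::numeric_limits<float>::max();
      for (int k = 0; k < K; ++k) {
        acc = std::min(acc, A[(m * lda) + k] + B[(n * ldb) + k]);
      }
      C[(m * ldc) + n] = std::min(acc, C[(m * ldc) + n]);
    }
  }
}
```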
--------------------------------------------------------------------------------
/include/cuasr/arch/srmma.h:
--------------------------------------------------------------------------------
  1 | /***************************************************************************************************
  2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
  3 |  **************************************************************************************************/
  4 | /*! \file
  5 |     \brief Templates exposing architecture support for multiply-add operations
  6 | */
  7 | 
  8 | #pragma once
  9 | 
 10 | #include "cutlass/array.h"
 11 | #include "cutlass/arch/mma.h"
 12 | #include "cutlass/gemm/gemm.h"
 13 | 
 14 | /////////////////////////////////////////////////////////////////////////////////////////////////
 15 | 
 16 | namespace cuasr {
 17 | namespace arch {
 18 | 
 19 | /// Matrix product operator for all semi-rings
 20 | template <
 21 |   /// Size of the matrix product (concept: GemmShape)
 22 |   typename Shape_,
 23 |   /// Number of threads participating
 24 |   int kThreads_,
 25 |   /// Data type of A elements
 26 |   typename ElementA,
 27 |   /// Layout of A matrix (concept: MatrixLayout)
 28 |   typename LayoutA,
 29 |   /// Data type of B elements
 30 |   typename ElementB,
 31 |   /// Layout of B matrix (concept: MatrixLayout)
 32 |   typename LayoutB,
 33 |   /// Element type of C matrix
 34 |   typename ElementC,
 35 |   /// Layout of C matrix (concept: MatrixLayout)
 36 |   typename LayoutC,
 37 |   /// addition operator of the semi-ring
 38 |   typename AdditionOp,
 39 |   /// multiplication operator of the semi-ring
 40 |   typename MultiplicationOp
 41 | >
 42 | struct Srmma;
 43 | 
 44 | 
 45 | /////////////////////////////////////////////////////////////////////////////////////////////////
 46 | 
 47 | /// Semi-ring multiply-add specialized for 1 element per instruction
 48 | template <
 49 |   /// Data type of A elements
 50 |   typename ElementA,
 51 |   /// Layout of A matrix (concept: MatrixLayout)
 52 |   typename LayoutA,
 53 |   /// Data type of B elements
 54 |   typename ElementB,
 55 |   /// Layout of B matrix (concept: MatrixLayout)
 56 |   typename LayoutB,
 57 |   /// Element type of C matrix
 58 |   typename ElementC,
 59 |   /// Layout of C matrix (concept: MatrixLayout)
 60 |   typename LayoutC,
 61 |   /// Addition operator of the semi-ring
 62 |   typename AdditionOp,
 63 |   /// Multiplication operator of the semi-ring
 64 |   typename MultiplicationOp>
 65 | struct Srmma<
 66 |   cutlass::gemm::GemmShape<1, 1, 1>,
 67 |   1,
 68 |   ElementA,
 69 |   LayoutA,
 70 |   ElementB,
 71 |   LayoutB,
 72 |   ElementC,
 73 |   LayoutC,
 74 |   AdditionOp,
 75 |   MultiplicationOp> {
 76 |   using Shape = cutlass::gemm::GemmShape<1, 1, 1>;
 77 | 
 78 |   // semi-ring operators must be default constructible and
 79 |   // have a binary invocation () operator
 80 |   AdditionOp add;
 81 |   MultiplicationOp mult;
 82 | 
 83 |   CUTLASS_HOST_DEVICE
 84 |   void operator()(
 85 |     cutlass::Array<ElementC, 1> &d,
 86 |     cutlass::Array<ElementA, 1> const &a,
 87 |     cutlass::Array<ElementB, 1> const &b,
 88 |     cutlass::Array<ElementC, 1> const &c
 89 |   ) {
 90 |     d[0] = add(c[0], mult(a[0], b[0]));
 91 |   }
 92 | };
 93 | 
 94 | /////////////////////////////////////////////////////////////////////////////////////////////////
 95 | 
 96 | } // namespace arch
 97 | } // namespace cuasr
 98 | 
 99 | /////////////////////////////////////////////////////////////////////////////////////////////////
100 | 
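The 1x1x1 specialization is where every SRGEMM bottoms out: one `d = add(c, mult(a, b))` per thread per instruction. The snippet below is illustrative only, exercising that specialization directly with the tropical operators defined in `cuasr/functional.h`; the values and the demo function name are made up for the example.

```cpp
#include "cuasr/arch/srmma.h"
#include "cuasr/functional.h"
#include "cutlass/layout/matrix.h"

void srmma_scalar_demo() {
  // min-plus instance of the scalar semi-ring multiply-add above
  using Srmma = cuasr::arch::Srmma<
      cutlass::gemm::GemmShape<1, 1, 1>, 1,
      float, cutlass::layout::RowMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::RowMajor,
      cuasr::minimum<float>, cuasr::plus<float>>;

  cutlass::Array<float, 1> a, b, c, d;
  a[0] = 2.0f; b[0] = 3.0f; c[0] = 4.0f;

  Srmma srmma;
  srmma(d, a, b, c);  // d[0] = min(c[0], a[0] + b[0]) = min(4, 5) = 4
}
```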
--------------------------------------------------------------------------------
/include/cuasr/functional.h:
--------------------------------------------------------------------------------
  1 | /***************************************************************************************************
  2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
  3 |  **************************************************************************************************/
  4 | /*! \file
  5 |     \brief Defines basic semi-ring operators together with their identity and
  6 |     annihilator constants given type T.
  7 | 
  8 |     This is inspired by the Standard Library's <functional> header.
  9 | */
 10 | 
 11 | #pragma once
 12 | 
 13 | #include "cutlass/array.h"
 14 | #include "cutlass/cutlass.h"
 15 | 
 16 | #include <cstdint>
 17 | #include <limits>
 18 | 
 19 | namespace cuasr {
 20 | using cutlass::Array;
 21 | 
 22 | namespace {
 23 | 
 24 | // helpers to get the +inf/-inf and lowest/max for integrals/floats
 25 | // NOTE: we only use lowest/max finite values even for floats for now to avoid
 26 | // having to use actual +inf/-inf-ies. In practice, lowest/max for
 27 | // floats should behave the same as +inf/-inf
 28 | template <typename T>
 29 | constexpr auto get_inf() noexcept {
 30 |   return std::numeric_limits<T>::max();
 31 | }
 32 | 
 33 | template <typename T>
 34 | constexpr auto get_neginf() noexcept {
 35 |   return std::numeric_limits<T>::lowest();
 36 | }
 37 | }
 38 | 
 39 | template <typename T, int N = 1>
 40 | struct plus {
 41 |   static T constexpr Identity    = static_cast<T>(0);
 42 |   static T constexpr Annihilator = get_inf<T>();
 43 | 
 44 |   // scalar operator
 45 |   CUTLASS_HOST_DEVICE
 46 |   T operator()(T lhs, T const &rhs) const {
 47 |     lhs += rhs;
 48 |     return lhs;
 49 |   }
 50 | 
 51 |   CUTLASS_HOST_DEVICE
 52 |   Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
 53 |     Array<T, N> result;
 54 |     CUTLASS_PRAGMA_UNROLL
 55 |     for (int i = 0; i < N; ++i) {
 56 |       result[i] = this->operator()(lhs[i], rhs[i]);
 57 |     }
 58 |     return result;
 59 |   }
 60 | 
 61 |   CUTLASS_HOST_DEVICE
 62 |   Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
 63 |     Array<T, N> result;
 64 |     CUTLASS_PRAGMA_UNROLL
 65 |     for (int i = 0; i < N; ++i) {
 66 |       result[i] = this->operator()(lhs[i], scalar);
 67 |     }
 68 |     return result;
 69 |   }
 70 | 
 71 |   CUTLASS_HOST_DEVICE
 72 |   Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
 73 |     Array<T, N> result;
 74 |     CUTLASS_PRAGMA_UNROLL
 75 |     for (int i = 0; i < N; ++i) {
 76 |       result[i] = this->operator()(scalar, rhs[i]);
 77 |     }
 78 |     return result;
 79 |   }
 80 | };
 81 | 
 82 | template <typename T, int N = 1>
 83 | struct multiplies {
 84 |   static T constexpr Identity    = static_cast<T>(1);
 85 |   static T constexpr Annihilator = static_cast<T>(0);
 86 | 
 87 |   // scalar operator
 88 |   CUTLASS_HOST_DEVICE
 89 |   T operator()(T lhs, T const &rhs) const {
 90 |     lhs *= rhs;
 91 |     return lhs;
 92 |   }
 93 | 
 94 |   CUTLASS_HOST_DEVICE
 95 |   Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
 96 |     Array<T, N> result;
 97 |     CUTLASS_PRAGMA_UNROLL
 98 |     for (int i = 0; i < N; ++i) {
 99 |       result[i] = this->operator()(lhs[i], rhs[i]);
100 |     }
101 |     return result;
102 |   }
103 | 
104 |   CUTLASS_HOST_DEVICE
105 |   Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
106 |     Array<T, N> result;
107 |     CUTLASS_PRAGMA_UNROLL
108 |     for (int i = 0; i < N; ++i) {
109 |       result[i] = this->operator()(lhs[i], scalar);
110 |     }
111 |     return result;
112 |   }
113 | 
114 |   CUTLASS_HOST_DEVICE
115 |   Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
116 |     Array<T, N> result;
117 |     CUTLASS_PRAGMA_UNROLL
118 |     for (int i = 0; i < N; ++i) {
119 |       result[i] = this->operator()(scalar, rhs[i]);
120 |     }
121 |     return result;
122 |   }
123 | };
124 | 
125 | template <typename T, int N = 1>
126 | struct minimum {
127 |   static T constexpr Identity    = get_inf<T>();
128 |   static T constexpr Annihilator = get_neginf<T>();
129 | 
130 |   // scalar operator
131 |   CUTLASS_HOST_DEVICE
132 |   T operator()(T const &lhs, T const &rhs) const { return (rhs < lhs ? rhs : lhs); }
133 | 
134 |   CUTLASS_HOST_DEVICE
135 |   Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
136 |     Array<T, N> result;
137 |     CUTLASS_PRAGMA_UNROLL
138 |     for (int i = 0; i < N; ++i) {
139 |       result[i] = this->operator()(lhs[i], rhs[i]);
140 |     }
141 |     return result;
142 |   }
143 | 
144 |   CUTLASS_HOST_DEVICE
145 |   Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
146 |     Array<T, N> result;
147 |     CUTLASS_PRAGMA_UNROLL
148 |     for (int i = 0; i < N; ++i) {
149 |       result[i] = this->operator()(lhs[i], scalar);
150 |     }
151 |     return result;
152 |   }
153 | 
154 |   CUTLASS_HOST_DEVICE
155 |   Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
156 |     Array<T, N> result;
157 |     CUTLASS_PRAGMA_UNROLL
158 |     for (int i = 0; i < N; ++i) {
159 |       result[i] = this->operator()(scalar, rhs[i]);
160 |     }
161 |     return result;
162 |   }
163 | };
164 | 
165 | template <typename T, int N = 1>
166 | struct maximum {
167 |   static T constexpr Identity    = get_neginf<T>();
168 |   static T constexpr Annihilator = get_inf<T>();
169 | 
170 |   // scalar operator
171 |   CUTLASS_HOST_DEVICE
172 |   T operator()(T const &lhs, T const &rhs) const { return (lhs < rhs ? rhs : lhs); }
173 | 
174 |   CUTLASS_HOST_DEVICE
175 |   Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
176 |     Array<T, N> result;
177 |     CUTLASS_PRAGMA_UNROLL
178 |     for (int i = 0; i < N; ++i) {
179 |       result[i] = this->operator()(lhs[i], rhs[i]);
180 |     }
181 |     return result;
182 |   }
183 | 
184 |   CUTLASS_HOST_DEVICE
185 |   Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
186 |     Array<T, N> result;
187 |     CUTLASS_PRAGMA_UNROLL
188 |     for (int i = 0; i < N; ++i) {
189 |       result[i] = this->operator()(lhs[i], scalar);
190 |     }
191 |     return result;
192 |   }
193 | 
194 |   CUTLASS_HOST_DEVICE
195 |   Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
196 |     Array<T, N> result;
197 |     CUTLASS_PRAGMA_UNROLL
198 |     for (int i = 0; i < N; ++i) {
199 |       result[i] = this->operator()(scalar, rhs[i]);
200 |     }
201 |     return result;
202 |   }
203 | };
204 | 
205 | template <typename T, int N = 1>
206 | struct binary_and {
207 |   static T constexpr Identity    = static_cast<T>(true);
208 |   static T constexpr Annihilator = static_cast<T>(false);
209 | 
210 |   // scalar operator
211 |   CUTLASS_HOST_DEVICE
212 |   T operator()(T lhs, T const &rhs) const { return lhs && rhs; }
213 | 
214 |   CUTLASS_HOST_DEVICE
215 |   Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
216 |     Array<T, N> result;
217 |     CUTLASS_PRAGMA_UNROLL
218 |     for (int i = 0; i < N; ++i) {
219 |       result[i] = this->operator()(lhs[i], rhs[i]);
220 |     }
221 |     return result;
222 |   }
223 | 
224 |   CUTLASS_HOST_DEVICE
225 |   Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
226 |     Array<T, N> result;
227 |     CUTLASS_PRAGMA_UNROLL
228 |     for (int i = 0; i < N; ++i) {
229 |       result[i] = this->operator()(lhs[i], scalar);
230 |     }
231 |     return result;
232 |   }
233 | 
234 |   CUTLASS_HOST_DEVICE
235 |   Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
236 |     Array<T, N> result;
237 |     CUTLASS_PRAGMA_UNROLL
238 |     for (int i = 0; i < N; ++i) {
239 |       result[i] = this->operator()(scalar, rhs[i]);
240 |     }
241 |     return result;
242 |   }
243 | };
244 | 
245 | template <typename T, int N = 1>
246 | struct binary_or {
247 |   static T constexpr Identity    = static_cast<T>(false);
248 |   static T constexpr Annihilator = static_cast<T>(true);
249 | 
250 |   // scalar operator
251 |   CUTLASS_HOST_DEVICE
252 |   T operator()(T lhs, T const &rhs) const { return lhs || rhs; }
253 | 
254 |   CUTLASS_HOST_DEVICE
255 |   Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
256 |     Array<T, N> result;
257 |     CUTLASS_PRAGMA_UNROLL
258 |     for (int i = 0; i < N; ++i) {
259 |       result[i] = this->operator()(lhs[i], rhs[i]);
260 |     }
261 |     return result;
262 |   }
263 | 
264 |   CUTLASS_HOST_DEVICE
265 |   Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
266 |     Array<T, N> result;
267 |     CUTLASS_PRAGMA_UNROLL
268 |     for (int i = 0; i < N; ++i) {
269 |       result[i] = this->operator()(lhs[i], scalar);
270 |     }
271 |     return result;
272 |   }
273 | 
274 |   CUTLASS_HOST_DEVICE
275 |   Array<T, N> operator()(T const &scalar, Array<T, N> const &rhs) const {
276 |     Array<T, N> result;
277 |     CUTLASS_PRAGMA_UNROLL
278 |     for (int i = 0; i < N; ++i) {
279 |       result[i] = this->operator()(scalar, rhs[i]);
280 |     }
281 |     return result;
282 |   }
283 | };
284 | 
285 | } // namespace cuasr
286 | 
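Every operator above follows the same contract, which is also what a user-supplied semi-ring operator (as in `examples/01_userdefined_semiring`) must satisfy: default constructible, `Identity` and `Annihilator` constants, and scalar plus `Array<T, N>` call operators. The "absolute difference" operator below is a hypothetical illustration, not part of the library; the two mixed scalar/array overloads are elided but follow exactly the pattern above.

```cpp
// Hypothetical user-defined operator, sketched against the contract used by
// the structs in this header (CUTLASS_HOST_DEVICE, Array, and get_inf are
// assumed to be in scope as they are above).
template <typename T, int N = 1>
struct absolute_difference {
  static T constexpr Identity    = static_cast<T>(0);  // |x - 0| = x
  static T constexpr Annihilator = get_inf<T>();       // |x - inf| = inf

  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    return (lhs < rhs) ? (rhs - lhs) : (lhs - rhs);
  }

  CUTLASS_HOST_DEVICE
  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
    Array<T, N> result;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[i] = this->operator()(lhs[i], rhs[i]);
    }
    return result;
  }
};
```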
--------------------------------------------------------------------------------
/include/cuasr/gemm/device/default_srgemm_configuration.h:
--------------------------------------------------------------------------------
  1 | /***************************************************************************************************
  2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
  3 |  **************************************************************************************************/
  4 | /*! \file
  5 |     \brief Definitions for SRGEMM configuration structures.
  6 | */
  7 | 
  8 | #pragma once
  9 | 
 10 | #include "cutlass/cutlass.h"
 11 | #include "cutlass/numeric_types.h"
 12 | #include "cutlass/gemm/gemm.h"
 13 | 
 14 | #include "cuasr/functional.h"
 15 | #include "cuasr/arch/srmma.h"
 16 | #include "cuasr/gemm/epilogue/thread/semiring_linear_combination.h"
 17 | 
 18 | #include <cstdint>
 19 | 
 20 | ////////////////////////////////////////////////////////////////////////////////
 21 | 
 22 | namespace cuasr {
 23 | namespace gemm {
 24 | namespace device {
 25 | 
 26 | ////////////////////////////////////////////////////////////////////////////////
 27 | 
 28 | template <
 29 |   typename ElementA,
 30 |   typename ElementB,
 31 |   typename ElementC,
 32 |   typename ElementAccumulator,
 33 |   typename OperatorClass,
 34 |   typename AdditionOp,
 35 |   typename MultiplicationOp,
 36 |   typename ArchTag
 37 | >
 38 | struct DefaultSemiRingConfiguration;
 39 | 
 40 | ////////////////////////////////////////////////////////////////////////////////
 41 | 
 42 | // Plus-Times semi-ring GEMM configuration
 43 | // this is the traditional GEMM
 44 | template <
 45 |   typename Element,
 46 |   typename ArchTag
 47 | >
 48 | struct DefaultSemiRingConfiguration<
 49 |   Element,
 50 |   Element,
 51 |   Element,
 52 |   Element,
 53 |   cutlass::arch::OpClassSimt,
 54 |   cuasr::plus<Element>,
 55 |   cuasr::multiplies<Element>,
 56 |   ArchTag> {
 57 | 
 58 |   static int constexpr kAlignmentA = 1;
 59 |   static int constexpr kAlignmentB = 1;
 60 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
 61 |   using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
 62 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
 63 |   static int constexpr kStages = 2;
 64 | 
 65 |   using AdditionOp = cuasr::plus<Element>;
 66 |   using MultiplicationOp = cuasr::multiplies<Element>;
 67 | 
 68 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
 69 |       AdditionOp, MultiplicationOp, Element, 1>;
 70 | };
 71 | 
 72 | // Min-Plus (tropical) semi-ring GEMM configuration
 73 | // example application: All-Pairs Shortest Path
 74 | template <
 75 |   typename Element,
 76 |   typename ArchTag
 77 | >
 78 | struct DefaultSemiRingConfiguration<
 79 |   Element,
 80 |   Element,
 81 |   Element,
 82 |   Element,
 83 |   cutlass::arch::OpClassSimt,
 84 |   cuasr::minimum<Element>,
 85 |   cuasr::plus<Element>,
 86 |   ArchTag> {
 87 | 
 88 |   static int constexpr kAlignmentA = 1;
 89 |   static int constexpr kAlignmentB = 1;
 90 |   using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
 91 |   using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
 92 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
 93 |   static int constexpr kStages = 2;
 94 | 
 95 |   using AdditionOp = cuasr::minimum<Element>;
 96 |   using MultiplicationOp = cuasr::plus<Element>;
 97 | 
 98 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
 99 |       AdditionOp, MultiplicationOp, Element, 1>;
100 | };
101 | 
102 | // Max-Plus semi-ring GEMM configuration
103 | // example application: Viterbi algorithm
104 | template <
105 |   typename Element,
106 |   typename ArchTag
107 | >
108 | struct DefaultSemiRingConfiguration<
109 |   Element,
110 |   Element,
111 |   Element,
112 |   Element,
113 |   cutlass::arch::OpClassSimt,
114 |   cuasr::maximum<Element>,
115 |   cuasr::plus<Element>,
116 |   ArchTag> {
117 | 
118 |   static int constexpr kAlignmentA = 1;
119 |   static int constexpr kAlignmentB = 1;
120 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
121 |   using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
122 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
123 |   static int constexpr kStages = 2;
124 | 
125 |   using AdditionOp = cuasr::maximum<Element>;
126 |   using MultiplicationOp = cuasr::plus<Element>;
127 | 
128 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
129 |       AdditionOp, MultiplicationOp, Element, 1>;
130 | };
131 | 
132 | // Max-Min
133 | template <
134 |   typename Element,
135 |   typename ArchTag
136 | >
137 | struct DefaultSemiRingConfiguration<
138 |   Element,
139 |   Element,
140 |   Element,
141 |   Element,
142 |   cutlass::arch::OpClassSimt,
143 |   cuasr::maximum<Element>,
144 |   cuasr::minimum<Element>,
145 |   ArchTag> {
146 | 
147 |   static int constexpr kAlignmentA = 1;
148 |   static int constexpr kAlignmentB = 1;
149 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
150 |   using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
151 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
152 |   static int constexpr kStages = 2;
153 | 
154 |   using AdditionOp = cuasr::maximum<Element>;
155 |   using MultiplicationOp = cuasr::minimum<Element>;
156 | 
157 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
158 |       AdditionOp, MultiplicationOp, Element, 1>;
159 | };
160 | 
161 | // Min-Max
162 | template <
163 |   typename Element,
164 |   typename ArchTag
165 | >
166 | struct DefaultSemiRingConfiguration<
167 |   Element,
168 |   Element,
169 |   Element,
170 |   Element,
171 |   cutlass::arch::OpClassSimt,
172 |   cuasr::minimum<Element>,
173 |   cuasr::maximum<Element>,
174 |   ArchTag> {
175 | 
176 |   static int constexpr kAlignmentA = 1;
177 |   static int constexpr kAlignmentB = 1;
178 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
179 |   using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
180 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
181 |   static int constexpr kStages = 2;
182 | 
183 |   using AdditionOp = cuasr::minimum<Element>;
184 |   using MultiplicationOp = cuasr::maximum<Element>;
185 | 
186 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
187 |       AdditionOp, MultiplicationOp, Element, 1>;
188 | };
189 | 
190 | // Min-Times
191 | template <
192 |   typename Element,
193 |   typename ArchTag
194 | >
195 | struct DefaultSemiRingConfiguration<
196 |   Element,
197 |   Element,
198 |   Element,
199 |   Element,
200 |   cutlass::arch::OpClassSimt,
201 |   cuasr::minimum<Element>,
202 |   cuasr::multiplies<Element>,
203 |   ArchTag> {
204 | 
205 |   static int constexpr kAlignmentA = 1;
206 |   static int constexpr kAlignmentB = 1;
207 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
208 |   using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
209 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
210 |   static int constexpr kStages = 2;
211 | 
212 |   using AdditionOp = cuasr::minimum<Element>;
213 |   using MultiplicationOp = cuasr::multiplies<Element>;
214 | 
215 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
216 |       AdditionOp, MultiplicationOp, Element, 1>;
217 | };
218 | 
219 | // Max-Times
220 | template <
221 |   typename Element,
222 |   typename ArchTag
223 | >
224 | struct DefaultSemiRingConfiguration<
225 |   Element,
226 |   Element,
227 |   Element,
228 |   Element,
229 |   cutlass::arch::OpClassSimt,
230 |   cuasr::maximum<Element>,
231 |   cuasr::multiplies<Element>,
232 |   ArchTag> {
233 | 
234 |   static int constexpr kAlignmentA = 1;
235 |   static int constexpr kAlignmentB = 1;
236 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
237 |   using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
238 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
239 |   static int constexpr kStages = 2;
240 | 
241 |   using AdditionOp = cuasr::maximum<Element>;
242 |   using MultiplicationOp = cuasr::multiplies<Element>;
243 | 
244 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
245 |       AdditionOp, MultiplicationOp, Element, 1>;
246 | };
247 | 
248 | // Or-And boolean ring
249 | template <
250 |   typename Element,
251 |   typename ArchTag
252 | >
253 | struct DefaultSemiRingConfiguration<
254 |   Element,
255 |   Element,
256 |   Element,
257 |   Element,
258 |   cutlass::arch::OpClassSimt,
259 |   cuasr::binary_or<Element>,
260 |   cuasr::binary_and<Element>,
261 |   ArchTag> {
262 | 
263 |   static int constexpr kAlignmentA = 1;
264 |   static int constexpr kAlignmentB = 1;
265 |   using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 8>;
266 |   using WarpShape = cutlass::gemm::GemmShape<16, 64, 8>;
267 |   using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
268 |   static int constexpr kStages = 2;
269 | 
270 |   using AdditionOp = cuasr::binary_or<Element>;
271 |   using MultiplicationOp = cuasr::binary_and<Element>;
272 | 
273 |   using EpilogueOutputOp = cuasr::epilogue::thread::SemiringLinearCombination<
274 |       AdditionOp, MultiplicationOp, Element, 1>;
275 | };
276 | 
277 | ////////////////////////////////////////////////////////////////////////////////
278 | 
279 | } // namespace device
280 | } // namespace gemm
281 | } // namespace cuasr
282 | 
283 | ////////////////////////////////////////////////////////////////////////////////
284 | 
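These configuration traits exist so that callers do not have to spell out tile shapes and epilogues at every instantiation. A short sketch of how they are typically consumed, using only what this header defines (the alias names here are illustrative, not from the repository):

```cpp
// Pull tuned tile shapes and the epilogue for the min-plus (tropical)
// semi-ring on SM50 SIMT from the default configuration.
using Config = cuasr::gemm::device::DefaultSemiRingConfiguration<
    float, float, float, float, cutlass::arch::OpClassSimt,
    cuasr::minimum<float>, cuasr::plus<float>, cutlass::arch::Sm50>;

using TileShape = Config::ThreadblockShape;  // GemmShape<128, 128, 8>
using Epilogue  = Config::EpilogueOutputOp;  // semiring linear combination
static_assert(Config::kStages == 2, "SM50 SIMT SRGEMM uses a two-stage mainloop");
```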
--------------------------------------------------------------------------------
/include/cuasr/gemm/epilogue/thread/semiring_linear_combination.h:
--------------------------------------------------------------------------------
  1 | /***************************************************************************************************
  2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu).
  3 |  **************************************************************************************************/
  4 | /*! \file
  5 |     \brief Functor performing linear combination operations used by epilogues.
  6 | */
  7 | 
  8 | #pragma once
  9 | 
 10 | #include "cutlass/array.h"
 11 | #include "cutlass/cutlass.h"
 12 | #include "cutlass/functional.h"
 13 | #include "cutlass/numeric_conversion.h"
 14 | #include "cutlass/numeric_types.h"
 15 | 
 16 | /////////////////////////////////////////////////////////////////////////////////////////////////
 17 | 
 18 | namespace cuasr {
 19 | namespace epilogue {
 20 | namespace thread {
 21 | 
 22 | /////////////////////////////////////////////////////////////////////////////////////////////////
 23 | 
 24 | /// Applies a linear combination operator to an array of elements.
 25 | ///
 26 | /// D = add_op(mult_op(alpha, accumulator), mult_op(beta, source))
 27 | ///
 28 | template <
 29 |   typename AdditionOp_,        ///< Addition operator of this semi-ring
 30 |   typename MultiplicationOp_,  ///< Multiplication operator of this semi-ring
 31 |   typename ElementOutput_,     ///< Data type used to load and store tensors
 32 |   int Count,                   ///< Number of elements computed per operation
 33 |   typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
 34 |   typename ElementCompute_
 35 |       = ElementOutput_,  ///< Data type used to compute linear combination
 36 |   cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest>
 37 | class SemiringLinearCombination {
 38 | public:
 39 |   using AdditionOp = AdditionOp_;
 40 |   using MultiplicationOp = MultiplicationOp_;
 41 | 
 42 |   using ElementOutput = ElementOutput_;
 43 |   using ElementAccumulator = ElementAccumulator_;
 44 |   using ElementCompute = ElementCompute_;
 45 |   static int const kCount = Count;
 46 | 
 47 |   using FragmentOutput = cutlass::Array<ElementOutput, kCount>;
 48 |   using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;
 49 |   using ComputeFragment = cutlass::Array<ElementCompute, kCount>;
 50 | 
 51 |   static cutlass::FloatRoundStyle const kRound = Round;
 52 | 
 53 |   /// Host-constructable parameters structure
 54 |   struct Params {
 55 |     ElementCompute alpha;  ///< scales accumulators
 56 |     ElementCompute beta;   ///< scales source tensor
 57 |     ElementCompute const
 58 |         *alpha_ptr;  ///< pointer to accumulator scalar - if not null, loads it from memory
 59 |     ElementCompute const
 60 |         *beta_ptr;   ///< pointer to source scalar - if not null, loads it from memory
 61 | 
 62 |     //
 63 |     // Methods
 64 |     //
 65 | 
 66 |     CUTLASS_HOST_DEVICE
 67 |     Params()
 68 |         : alpha(MultiplicationOp::Identity)
 69 |         , beta(MultiplicationOp::Annihilator)
 70 |         , alpha_ptr(nullptr)
 71 |         , beta_ptr(nullptr) { }
 72 | 
 73 |     CUTLASS_HOST_DEVICE
 74 |     Params(ElementCompute alpha, ElementCompute beta)
 75 |         : alpha(alpha)
 76 |         , beta(beta)
 77 |         , alpha_ptr(nullptr)
 78 |         , beta_ptr(nullptr) { }
 79 | 
 80 |     CUTLASS_HOST_DEVICE
 81 |     Params(ElementCompute alpha)
 82 |         : alpha(alpha)
 83 |         , beta(MultiplicationOp::Annihilator)
 84 |         , alpha_ptr(nullptr)
 85 |         , beta_ptr(nullptr) { }
 86 | 
 87 |     CUTLASS_HOST_DEVICE
 88 |     Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
 89 |         : alpha(MultiplicationOp::Identity)
 90 |         , beta(MultiplicationOp::Annihilator)
 91 |         , alpha_ptr(alpha_ptr)
 92 |         , beta_ptr(beta_ptr) { }
 93 | 
 94 |     CUTLASS_HOST_DEVICE
 95 |     Params(ElementCompute const *alpha_ptr)
 96 |         : alpha(MultiplicationOp::Identity)
 97 |         , beta(MultiplicationOp::Annihilator)
 98 |         , alpha_ptr(alpha_ptr)
 99 |         , beta_ptr(nullptr) { }
100 |   };
101 | 
102 | private:
103 |   // scalars
104 |   ElementCompute alpha_;
105 |   ElementCompute beta_;
106 |   AdditionOp add_op_;
107 |   MultiplicationOp mult_op_;
108 | 
109 | public:
110 |   /// Constructs the function object, possibly loading from pointers in host memory
111 |   CUTLASS_HOST_DEVICE
112 |   SemiringLinearCombination(Params const &params) {
113 |     alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
114 |     beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
115 |   }
116 | 
117 |   /// Returns true if source is needed
118 |   CUTLASS_HOST_DEVICE
119 |   bool is_source_needed() const {
120 |     ElementCompute kAdditiveIdentity = AdditionOp::Identity;
121 |     ElementCompute kMultiplicativeIdentity = MultiplicationOp::Identity;
122 | 
123 |     // no source needed if mult_op(beta, C[i,j]) is equal to add_op's identity
124 |     return (kAdditiveIdentity != mult_op_(beta_, kMultiplicativeIdentity));
125 |   }
126 | 
127 |   /// Functionally required for serial reduction in the epilogue
128 |   CUTLASS_HOST_DEVICE
129 |   void set_k_partition(int k_partition) {
130 |     if (k_partition) {
131 |       ElementCompute kMultiplicativeIdentity = MultiplicationOp::Identity;
132 |       beta_ = kMultiplicativeIdentity;
133 |     }
134 |   }
135 | 
136 |   /// Computes semiring linear scale and translate
137 |   /// D = add_op_(mult_op_(alpha, accumulator), mult_op_(beta, source))
138 |   CUTLASS_HOST_DEVICE
139 |   FragmentOutput
140 |   operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const {
141 |     // Convert source to internal compute numeric type
142 |     cutlass::NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round>
143 |         source_converter;
144 |     cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
145 |         accumulator_converter;
146 | 
147 |     ComputeFragment converted_source = source_converter(source);
148 |     ComputeFragment converted_accumulator = accumulator_converter(accumulator);
149 | 
150 |     // Perform binary operations
151 |     // X = mult_op(beta, C)
152 |     ComputeFragment intermediate = mult_op_(beta_, converted_source);
153 | 
154 |     // D = add_op(mult_op(alpha, Accum), X)
155 |     intermediate = add_op_(mult_op_(alpha_, converted_accumulator), intermediate);
156 | 
157 |     // Convert to destination numeric type
158 |     cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
159 |         destination_converter;
160 | 
161 |     return destination_converter(intermediate);
162 |   }
163 | 
164 |   /// Computes semiring linear scaling: D = mult_op_(alpha, accumulator)
165 |   CUTLASS_HOST_DEVICE
166 |   FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
167 |     // Convert accumulator to internal compute numeric type
168 |     cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
169 |         accumulator_converter;
170 | 
171 |     ComputeFragment converted_accumulator = accumulator_converter(accumulator);
172 | 
173 |     // Perform binary operations
174 |     ComputeFragment intermediate;
175 | 
176 |     intermediate = mult_op_(alpha_, converted_accumulator);  // D = mult_op(alpha, Accum)
177 | 
178 |     // Convert to destination numeric type
179 |     cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
180 |         destination_converter;
181 | 
182 |     return destination_converter(intermediate);
183 |   }
184 | };
185 | 
186 | /////////////////////////////////////////////////////////////////////////////////////////////////
187 | 
188 | } // namespace thread
189 | } // namespace epilogue
190 | } // namespace cuasr
191 | 
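The default `Params` (alpha = `MultiplicationOp::Identity`, beta = `MultiplicationOp::Annihilator`) are exactly the values that make the epilogue a no-op on the source. A scalar walk-through for min-plus (add = min, mult = +), using only the operator definitions from `cuasr/functional.h`; the function and its values are illustrative, not library code:

```cpp
// D = add(mult(alpha, accum), mult(beta, source)). With the min-plus defaults
// alpha = plus<float>::Identity = 0 and beta = plus<float>::Annihilator = +inf
// (numeric max), this collapses to D = min(0 + accum, inf + source) = accum,
// i.e. the source C is ignored -- which is what is_source_needed() detects via
// mult(beta, Identity) == AdditionOp::Identity.
float epilogue_minplus(float accum, float source, float alpha, float beta) {
  cuasr::minimum<float> add;
  cuasr::plus<float> mult;
  return add(mult(alpha, accum), mult(beta, source));
}
```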
--------------------------------------------------------------------------------
/include/cuasr/gemm/kernel/default_srgemm.h:
--------------------------------------------------------------------------------
  1 | /***************************************************************************************************
  2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
  3 |  **************************************************************************************************/
  4 | /*! \file
  5 |     \brief
  6 |       Default kernel-level SRGEMM definitions combine threadblock-scoped matrix srmma
  7 |       with the appropriate threadblock-scoped epilogue.
  8 | 
  9 |       Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
 10 |       accommodated by exchanging A and B operands and assuming transposed layouts. Partial
 11 |       specializations here choose 'device::GemmTransposed' to implement this functionality.
 12 | */
 13 | 
 14 | #pragma once
 15 | 
 16 | #include "cutlass/cutlass.h"
 17 | 
 18 | #include "cutlass/layout/matrix.h"
 19 | #include "cutlass/numeric_types.h"
 20 | 
 21 | #include "cutlass/gemm/gemm.h"
 22 | #include "cutlass/gemm/kernel/gemm_pipelined.h"
 23 | #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
 24 | 
 25 | #include "cutlass/epilogue/threadblock/epilogue.h"
 26 | #include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
 27 | #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
 28 | 
 29 | #include "cuasr/arch/srmma.h"
 30 | #include "cuasr/gemm/kernel/srgemm.h"
 31 | #include "cuasr/gemm/threadblock/default_srmma.h"
 32 | 
 33 | ////////////////////////////////////////////////////////////////////////////////
 34 | 
 35 | namespace cuasr {
 36 | namespace gemm {
 37 | namespace kernel {
 38 | 
 39 | ////////////////////////////////////////////////////////////////////////////////
 40 | 
 41 | template <
 42 |   /// Element type for A matrix operand
 43 |   typename ElementA_,
 44 |   /// Layout type for A matrix operand
 45 |   typename LayoutA_,
 46 |   /// Access granularity of A matrix in units of elements
 47 |   int kAlignmentA,
 48 |   /// Element type for B matrix operand
 49 |   typename ElementB_,
 50 |   /// Layout type for B matrix operand
 51 |   typename LayoutB_,
 52 |   /// Access granularity of B matrix in units of elements
 53 |   int kAlignmentB,
 54 |   /// Element type for C and D matrix operands
 55 |   typename ElementC_,
 56 |   /// Layout type for C and D matrix operands
 57 |   typename LayoutC_,
 58 |   /// Element type for internal accumulation
 59 |   typename ElementAccumulator,
 60 |   /// Operator class tag
 61 |   typename OperatorClass,
 62 |   /// Tag indicating architecture to tune for
 63 |   typename ArchTag,
 64 |   /// Threadblock-level tile size (concept: GemmShape)
 65 |   typename ThreadblockShape,
 66 |   /// Warp-level tile size (concept: GemmShape)
 67 |   typename WarpShape,
 68 |   /// Instruction-level tile size (concept: GemmShape)
 69 |   typename InstructionShape,
 70 |   /// Addition operator of the semi-ring
 71 |   typename AdditionOp,
 72 |   /// Multiplication operator of the semi-ring
 73 |   typename MultiplicationOp,
 74 |   /// Epilogue output operator
 75 |   typename EpilogueOutputOp,
 76 |   /// Threadblock-level swizzling operator
 77 |   typename ThreadblockSwizzle,
 78 |   /// Number of stages used in the pipelined mainloop
 79 |   int Stages,
 80 |   /// If true, kernel is configured to support serial reduction in the
 81 |   /// epilogue
 82 |   bool SplitKSerial>
 83 | struct DefaultSrgemm;
 84 | 
 85 | template <
 86 |   /// Element type for A matrix operand
 87 |   typename ElementA,
 88 |   /// Layout type for A matrix operand
 89 |   typename LayoutA,
 90 |   /// Access granularity of A matrix in units of elements
 91 |   int kAlignmentA,
 92 |   /// Element type for B matrix operand
 93 |   typename ElementB,
 94 |   /// Layout type for B matrix operand
 95 |   typename LayoutB,
 96 |   /// Access granularity of B matrix in units of elements
 97 |   int kAlignmentB,
 98 |   /// Element type for C and D matrix operands
 99 |   typename ElementC,
100 |   /// Element type for internal accumulation
101 |   typename ElementAccumulator,
102 |   /// Tag indicating architecture to tune for
103 |   typename ArchTag,
104 |   /// Threadblock-level tile size (concept: GemmShape)
105 |   typename ThreadblockShape,
106 |   ///
Warp-level tile size (concept: GemmShape) 107 | typename WarpShape, 108 | /// Addition operator of the semi-ring 109 | typename AdditionOp, 110 | /// Multiplication operator of the semi-ring 111 | typename MultiplicationOp, 112 | /// Epilogue output operator 113 | typename EpilogueOutputOp, 114 | /// Threadblock-level swizzling operator 115 | typename ThreadblockSwizzle, 116 | /// If true, kernel is configured to support serial reduction in the epilogue 117 | bool SplitKSerial 118 | > 119 | struct DefaultSrgemm< 120 | ElementA, 121 | LayoutA, 122 | kAlignmentA, 123 | ElementB, 124 | LayoutB, 125 | kAlignmentB, 126 | ElementC, 127 | cutlass::layout::RowMajor, 128 | ElementAccumulator, 129 | cutlass::arch::OpClassSimt, 130 | ArchTag, 131 | ThreadblockShape, 132 | WarpShape, 133 | cutlass::gemm::GemmShape<1, 1, 1>, 134 | AdditionOp, 135 | MultiplicationOp, 136 | EpilogueOutputOp, 137 | ThreadblockSwizzle, 138 | 2, 139 | SplitKSerial> { 140 | /// Define the threadblock-scoped matrix multiply-accumulate 141 | using Srmma = typename cuasr::gemm::threadblock::DefaultSrmma< 142 | ElementA, 143 | LayoutA, 144 | kAlignmentA, 145 | ElementB, 146 | LayoutB, 147 | kAlignmentB, 148 | ElementAccumulator, 149 | cutlass::layout::RowMajor, 150 | cutlass::arch::OpClassSimt, 151 | cutlass::arch::Sm50, 152 | ThreadblockShape, 153 | WarpShape, 154 | cutlass::gemm::GemmShape<1, 1, 1>, 155 | AdditionOp, 156 | MultiplicationOp, 157 | 2>::ThreadblockSrmma; 158 | 159 | static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; 160 | static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); 161 | 162 | /// Define the epilogue 163 | using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< 164 | ThreadblockShape, 165 | typename Srmma::Operator, 166 | EpilogueOutputOp, 167 | kEpilogueElementsPerAccess 168 | >::Epilogue; 169 | 170 | /// Define the kernel-level GEMM operator. 171 | using SrgemmKernel = cuasr::gemm::kernel::Srgemm< 172 | Srmma, 173 | AdditionOp, 174 | MultiplicationOp, 175 | Epilogue, 176 | ThreadblockSwizzle, 177 | SplitKSerial 178 | >; 179 | }; 180 | 181 | //////////////////////////////////////////////////////////////////////////////// 182 | 183 | } // namespace kernel 184 | } // namespace gemm 185 | } // namespace cuasr 186 | 187 | //////////////////////////////////////////////////////////////////////////////// 188 | -------------------------------------------------------------------------------- /include/cuasr/gemm/kernel/default_srgemm_splitk_parallel.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | 3 | **************************************************************************************************/ 4 | 5 | /*! \file 6 | \brief 7 | Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with 8 | the appropriate threadblock-scoped epilogue. 9 | 10 | Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are 11 | accommodated by exchanging A and B operands and assuming transposed layouts. Partial 12 | specializations here choose 'device::GemmTransposed' to implement this functionality. 
 13 | */
 14 | 
 15 | #pragma once
 16 | 
 17 | #include "cutlass/cutlass.h"
 18 | 
 19 | #include "cuasr/gemm/kernel/default_srgemm.h"
 20 | #include "cuasr/gemm/kernel/srgemm_splitk_parallel.h"
 21 | 
 22 | ////////////////////////////////////////////////////////////////////////////////
 23 | 
 24 | namespace cuasr {
 25 | namespace gemm {
 26 | namespace kernel {
 27 | 
 28 | ////////////////////////////////////////////////////////////////////////////////
 29 | 
 30 | template <
 31 |   /// Element type for A matrix operand
 32 |   typename ElementA_,
 33 |   /// Layout type for A matrix operand
 34 |   typename LayoutA_,
 35 |   /// Access granularity of A matrix in units of elements
 36 |   int kAlignmentA,
 37 |   /// Element type for B matrix operand
 38 |   typename ElementB_,
 39 |   /// Layout type for B matrix operand
 40 |   typename LayoutB_,
 41 |   /// Access granularity of B matrix in units of elements
 42 |   int kAlignmentB,
 43 |   /// Element type for C and D matrix operands
 44 |   typename ElementC_,
 45 |   /// Layout type for C and D matrix operands
 46 |   typename LayoutC_,
 47 |   /// Element type for internal accumulation
 48 |   typename ElementAccumulator,
 49 |   /// Operator class tag
 50 |   typename OperatorClass,
 51 |   /// Tag indicating architecture to tune for
 52 |   typename ArchTag,
 53 |   /// Threadblock-level tile size (concept: GemmShape)
 54 |   typename ThreadblockShape,
 55 |   /// Warp-level tile size (concept: GemmShape)
 56 |   typename WarpShape,
 57 |   /// Instruction-level tile size (concept: GemmShape)
 58 |   typename InstructionShape,
 59 |   /// Addition operator of the semi-ring
 60 |   typename AdditionOp,
 61 |   /// Multiplication operator of the semi-ring
 62 |   typename MultiplicationOp,
 63 |   /// Epilogue output operator
 64 |   typename EpilogueOutputOp,
 65 |   /// Threadblock-level swizzling operator
 66 |   typename ThreadblockSwizzle,
 67 |   /// Number of stages used in the pipelined mainloop
 68 |   int Stages
 69 | >
 70 | struct DefaultSrgemmSplitKParallel {
 71 | 
 72 |   // Define threadblock-scoped split-K matrix multiply using
 73 |   // the basic SRGEMM's kernel level main loop
 74 |   using Default = DefaultSrgemm<
 75 |     ElementA_,
 76 |     LayoutA_,
 77 |     kAlignmentA,
 78 |     ElementB_,
 79 |     LayoutB_,
 80 |     kAlignmentB,
 81 |     ElementAccumulator,
 82 |     LayoutC_,
 83 |     ElementAccumulator,
 84 |     OperatorClass,
 85 |     ArchTag,
 86 |     ThreadblockShape,
 87 |     WarpShape,
 88 |     InstructionShape,
 89 |     AdditionOp,
 90 |     MultiplicationOp,
 91 |     EpilogueOutputOp,
 92 |     ThreadblockSwizzle,
 93 |     Stages,
 94 |     false
 95 |   >;
 96 | 
 97 |   /// Define the semiring matrix multiply operator
 98 |   using Srmma = typename Default::Srmma;
 99 | 
100 |   /// Define the epilogue
101 |   using Epilogue = typename Default::Epilogue;
102 | 
103 |   /// Define the kernel-level GEMM operator.
104 | using SrgemmKernel = kernel::SrgemmSplitKParallel< 105 | Srmma, 106 | AdditionOp, 107 | MultiplicationOp, 108 | Epilogue, 109 | ThreadblockSwizzle 110 | >; 111 | }; 112 | 113 | /////////////////////////////////////////////////////////////////////////////////////////////////// 114 | 115 | } // namespace kernel 116 | } // namespace gemm 117 | } // namespace cuasr 118 | 119 | /////////////////////////////////////////////////////////////////////////////////////////////////// 120 | -------------------------------------------------------------------------------- /include/cuasr/gemm/kernel/srgemm.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved. 3 | **************************************************************************************************/ 4 | /*! \file 5 | \brief Template for a pipelined Semiring GEMM kernel. Does not compute batching or support split-K. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "cutlass/cutlass.h" 11 | 12 | #include "cutlass/gemm/gemm.h" 13 | #include "cutlass/matrix_coord.h" 14 | #include "cutlass/semaphore.h" 15 | 16 | #include "cuasr/arch/srmma.h" 17 | 18 | ///////////////////////////////////////////////////////////////////////////////////////////////// 19 | 20 | namespace cuasr { 21 | namespace gemm { 22 | namespace kernel { 23 | 24 | ///////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | // SemiRing Gemm kernel that support custom thread level MMA and init values. 27 | template < 28 | typename Srmma_, ///! Threadblock-scoped matrix multiply-accumulate 29 | typename AdditionOp_, ///! Addition operator of the semi-ring 30 | typename MultiplicationOp_, ///! Multiplication operator of the semi-ring 31 | typename Epilogue_, ///! Epilogue 32 | typename ThreadblockSwizzle_, ///! Threadblock swizzling function 33 | bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. 
34 | > 35 | struct Srgemm { 36 | 37 | using Srmma = Srmma_; 38 | using Epilogue = Epilogue_; 39 | using OutputOp = typename Epilogue::OutputOp; 40 | using ThreadblockSwizzle = ThreadblockSwizzle_; 41 | using AdditionOp = AdditionOp_; 42 | using MultiplicationOp = MultiplicationOp_; 43 | static bool const kSplitKSerial = SplitKSerial; 44 | 45 | /// Warp count (concept: GemmShape) 46 | using WarpCount = typename Srmma::WarpCount; 47 | static int const kThreadCount = 32 * WarpCount::kCount; 48 | 49 | /// Parameters structure 50 | struct Params { 51 | cutlass::gemm::GemmCoord problem_size; 52 | cutlass::gemm::GemmCoord grid_tiled_shape; 53 | typename Srmma::IteratorA::Params params_A; 54 | typename Srmma::IteratorA::TensorRef ref_A; 55 | typename Srmma::IteratorB::Params params_B; 56 | typename Srmma::IteratorB::TensorRef ref_B; 57 | typename Epilogue::OutputTileIterator::Params params_C; 58 | typename Epilogue::OutputTileIterator::TensorRef ref_C; 59 | typename Epilogue::OutputTileIterator::Params params_D; 60 | typename Epilogue::OutputTileIterator::TensorRef ref_D; 61 | typename OutputOp::Params output_op; 62 | int *semaphore; 63 | int gemm_k_iterations; 64 | int gemm_k_size; 65 | 66 | // 67 | // Methods 68 | // 69 | 70 | CUTLASS_HOST_DEVICE 71 | Params() { } 72 | 73 | CUTLASS_HOST_DEVICE 74 | Params( 75 | cutlass::gemm::GemmCoord const & problem_size, 76 | cutlass::gemm::GemmCoord const & grid_tiled_shape, 77 | typename Srmma::IteratorA::TensorRef ref_A, 78 | typename Srmma::IteratorB::TensorRef ref_B, 79 | typename Epilogue::OutputTileIterator::TensorRef ref_C, 80 | typename Epilogue::OutputTileIterator::TensorRef ref_D, 81 | typename OutputOp::Params output_op = typename OutputOp::Params(), 82 | int *semaphore = nullptr 83 | ): 84 | problem_size(problem_size), 85 | grid_tiled_shape(grid_tiled_shape), 86 | params_A(ref_A.layout()), 87 | ref_A(ref_A), 88 | params_B(ref_B.layout()), 89 | ref_B(ref_B), 90 | params_C(ref_C.layout()), 91 | ref_C(ref_C), 92 | params_D(ref_D.layout()), 93 | ref_D(ref_D), 94 | output_op(output_op), 95 | semaphore(semaphore) { 96 | 97 | int total_gemm_k_iterations = (problem_size.k() + Srmma::Shape::kK - 1) / Srmma::Shape::kK; 98 | int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); 99 | 100 | gemm_k_size = gemm_k_iterations * Srmma::Shape::kK; 101 | } 102 | }; 103 | 104 | /// Shared memory storage structure 105 | union SharedStorage { 106 | typename Srmma::SharedStorage main_loop; 107 | typename Epilogue::SharedStorage epilogue; 108 | }; 109 | 110 | // 111 | // Methods 112 | // 113 | 114 | CUTLASS_HOST_DEVICE 115 | Srgemm() { } 116 | 117 | /// Determines whether kernel satisfies alignment 118 | static cutlass::Status can_implement( 119 | cutlass::gemm::GemmCoord const & problem_size, 120 | typename Srmma::IteratorA::TensorRef ref_A, 121 | typename Srmma::IteratorB::TensorRef ref_B, 122 | typename Epilogue::OutputTileIterator::TensorRef ref_C, 123 | typename Epilogue::OutputTileIterator::TensorRef ref_D) { 124 | 125 | static int const kAlignmentA = Srmma::IteratorA::AccessType::kElements; 126 | static int const kAlignmentB = Srmma::IteratorB::AccessType::kElements; 127 | static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; 128 | 129 | if (!TensorRef_aligned(ref_A, kAlignmentA)) { 130 | return cutlass::Status::kErrorMisalignedOperand; 131 | } 132 | 133 | if (!TensorRef_aligned(ref_B, kAlignmentB)) { 134 | return cutlass::Status::kErrorMisalignedOperand; 135 | } 136 | 137 | if 
(!TensorRef_aligned(ref_C, kAlignmentC)) { 138 | return cutlass::Status::kErrorMisalignedOperand; 139 | } 140 | 141 | if (!TensorRef_aligned(ref_D, kAlignmentC)) { 142 | return cutlass::Status::kErrorMisalignedOperand; 143 | } 144 | 145 | if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || 146 | (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || 147 | (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { 148 | 149 | return cutlass::Status::kErrorMisalignedOperand; 150 | } 151 | 152 | return cutlass::Status::kSuccess; 153 | } 154 | 155 | /// Executes one GEMM 156 | CUTLASS_DEVICE 157 | void operator()(Params const ¶ms, SharedStorage &shared_storage) { 158 | constexpr typename OutputOp::ElementCompute kAdditiveIdentity = AdditionOp::Identity; 159 | 160 | // Compute threadblock location 161 | ThreadblockSwizzle threadblock_swizzle; 162 | 163 | cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); 164 | 165 | // Early exit if CTA is out of range 166 | if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || 167 | params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { 168 | 169 | return; 170 | } 171 | 172 | // Compute initial location in logical coordinates 173 | cutlass::MatrixCoord tb_offset_A{ 174 | threadblock_tile_offset.m() * Srmma::Shape::kM, 175 | threadblock_tile_offset.k() * params.gemm_k_size, 176 | }; 177 | 178 | cutlass::MatrixCoord tb_offset_B{ 179 | threadblock_tile_offset.k() * params.gemm_k_size, 180 | threadblock_tile_offset.n() * Srmma::Shape::kN 181 | }; 182 | 183 | // Problem size is a function of threadblock index in the K dimension 184 | int problem_size_k = min( 185 | params.problem_size.k(), 186 | (threadblock_tile_offset.k() + 1) * params.gemm_k_size); 187 | 188 | // Compute threadblock-scoped matrix multiply-add 189 | int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Srmma::Shape::kK - 1) / Srmma::Shape::kK; 190 | 191 | // Compute position within threadblock 192 | int thread_idx = threadIdx.x; 193 | 194 | // Construct iterators to A and B operands 195 | typename Srmma::IteratorA iterator_A( 196 | params.params_A, 197 | params.ref_A.data(), 198 | {params.problem_size.m(), problem_size_k}, 199 | thread_idx, 200 | tb_offset_A); 201 | 202 | typename Srmma::IteratorB iterator_B( 203 | params.params_B, 204 | params.ref_B.data(), 205 | {problem_size_k, params.problem_size.n()}, 206 | thread_idx, 207 | tb_offset_B); 208 | 209 | int warp_idx = threadIdx.x / 32; 210 | int lane_idx = threadIdx.x % 32; 211 | 212 | // 213 | // Main loop 214 | // 215 | 216 | // Construct thread-scoped matrix multiply 217 | Srmma srmma_thrblock_op(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, kAdditiveIdentity); 218 | 219 | // need to clear accumulators to additive identity for SemiRing Gemm 220 | typename Srmma::FragmentC accumulators; 221 | accumulators.fill(kAdditiveIdentity); 222 | 223 | if (!kSplitKSerial || gemm_k_iterations > 0) { 224 | // Compute threadblock-scoped matrix multiply-add 225 | srmma_thrblock_op(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); 226 | } 227 | 228 | // 229 | // Epilogue 230 | // 231 | 232 | OutputOp output_op(params.output_op); 233 | 234 | // 235 | // Masked tile iterators constructed from members 236 | // 237 | 238 | threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); 239 | 240 | //assume identity swizzle 241 | cutlass::MatrixCoord 
threadblock_offset( 242 | threadblock_tile_offset.m() * Srmma::Shape::kM, 243 | threadblock_tile_offset.n() * Srmma::Shape::kN 244 | ); 245 | 246 | int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); 247 | 248 | // Construct the semaphore. 249 | cutlass::Semaphore semaphore(params.semaphore + block_idx, thread_idx); 250 | 251 | // If performing a reduction via split-K, fetch the initial synchronization 252 | if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { 253 | 254 | // Fetch the synchronization lock initially but do not block. 255 | semaphore.fetch(); 256 | 257 | // Indicate which position in a serial reduction the output operator is currently updating 258 | output_op.set_k_partition(threadblock_tile_offset.k()); 259 | } 260 | 261 | // Tile iterator loading from source tensor. 262 | typename Epilogue::OutputTileIterator iterator_C( 263 | params.params_C, 264 | params.ref_C.data(), 265 | params.problem_size.mn(), 266 | thread_idx, 267 | threadblock_offset 268 | ); 269 | 270 | // Tile iterator writing to destination tensor. 271 | typename Epilogue::OutputTileIterator iterator_D( 272 | params.params_D, 273 | params.ref_D.data(), 274 | params.problem_size.mn(), 275 | thread_idx, 276 | threadblock_offset 277 | ); 278 | 279 | Epilogue epilogue( 280 | shared_storage.epilogue, 281 | thread_idx, 282 | warp_idx, 283 | lane_idx); 284 | 285 | // Wait on the semaphore - this latency may have been covered by iterator construction 286 | if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { 287 | 288 | // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 289 | if (threadblock_tile_offset.k()) { 290 | iterator_C = iterator_D; 291 | } 292 | 293 | semaphore.wait(threadblock_tile_offset.k()); 294 | 295 | __threadfence(); 296 | } 297 | 298 | // Execute the epilogue operator to update the destination tensor. 299 | epilogue(output_op, iterator_D, accumulators, iterator_C); 300 | 301 | // 302 | // Release the semaphore 303 | // 304 | 305 | if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { 306 | 307 | int lock = 0; 308 | if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { 309 | 310 | // The final threadblock resets the semaphore for subsequent grids. 311 | lock = 0; 312 | } 313 | else { 314 | // Otherwise, the semaphore is incremented 315 | lock = threadblock_tile_offset.k() + 1; 316 | } 317 | 318 | __threadfence(); 319 | semaphore.release(lock); 320 | } 321 | } 322 | }; 323 | 324 | 325 | ///////////////////////////////////////////////////////////////////////////////////////////////// 326 | 327 | } // namespace kernel 328 | } // namespace gemm 329 | } // namespace cuasr 330 | 331 | ///////////////////////////////////////////////////////////////////////////////////////////////// 332 | -------------------------------------------------------------------------------- /include/cuasr/gemm/kernel/srgemm_splitk_parallel.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | 3 | **************************************************************************************************/ 4 | /*! \file 5 | \brief Template for 3D SRGEMM performing a reduction over K partitions in parallel. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "cutlass/cutlass.h" 11 | 12 | #include "cutlass/gemm/gemm.h" 13 | #include "cutlass/matrix_coord.h" 14 | 15 | ///////////////////////////////////////////////////////////////////////////////////////////////// 16 | 17 | namespace cuasr { 18 | namespace gemm { 19 | namespace kernel { 20 | 21 | ///////////////////////////////////////////////////////////////////////////////////////////////// 22 | 23 | template < 24 | typename Srmma_, ///! Threadblock-scoped matrix multiply-accumulate 25 | typename AdditionOp_, ///! Addition operator of the semi-ring 26 | typename MultiplicationOp_, ///! Multiplication operator of the semi-ring 27 | typename Epilogue_, ///! Epilogue 28 | typename ThreadblockSwizzle_ ///! Threadblock swizzling function 29 | > 30 | struct SrgemmSplitKParallel { 31 | 32 | using Srmma = Srmma_; 33 | using AdditionOp = AdditionOp_; 34 | using MultiplicationOp = MultiplicationOp_; 35 | using Epilogue = Epilogue_; 36 | using OutputOp = typename Epilogue::OutputOp; 37 | using ThreadblockSwizzle = ThreadblockSwizzle_; 38 | 39 | /// Warp count (concept: GemmShape) 40 | using WarpCount = typename Srmma::WarpCount; 41 | static int const kThreadCount = 32 * WarpCount::kCount; 42 | static int const kAlignmentK = Srmma::Operator::Shape::kK; 43 | 44 | /// Parameters structure 45 | struct Params { 46 | cutlass::gemm::GemmCoord problem_size; 47 | cutlass::gemm::GemmCoord grid_tiled_shape; 48 | typename Srmma::IteratorA::Params params_A; 49 | typename Srmma::IteratorA::TensorRef ref_A; 50 | typename Srmma::IteratorB::Params params_B; 51 | typename Srmma::IteratorB::TensorRef ref_B; 52 | typename Epilogue::OutputTileIterator::Params params_D; 53 | typename Epilogue::OutputTileIterator::TensorRef ref_D; 54 | typename OutputOp::Params output_op; 55 | int64_t splitk_slice_stride; 56 | int gemm_k_size; 57 | 58 | // 59 | // Methods 60 | // 61 | 62 | CUTLASS_HOST_DEVICE 63 | Params() { } 64 | 65 | CUTLASS_HOST_DEVICE 66 | Params( 67 | cutlass::gemm::GemmCoord const & problem_size, 68 | cutlass::gemm::GemmCoord const & grid_tiled_shape, 69 | typename Srmma::IteratorA::TensorRef ref_A, 70 | typename Srmma::IteratorB::TensorRef ref_B, 71 | typename Epilogue::OutputTileIterator::TensorRef ref_D, 72 | typename OutputOp::Params output_op, 73 | int64_t splitk_slice_stride 74 | ): 75 | problem_size(problem_size), 76 | grid_tiled_shape(grid_tiled_shape), 77 | params_A(ref_A.layout()), 78 | ref_A(ref_A), 79 | params_B(ref_B.layout()), 80 | ref_B(ref_B), 81 | params_D(ref_D.layout()), 82 | ref_D(ref_D), 83 | output_op(output_op), 84 | splitk_slice_stride(splitk_slice_stride) { 85 | 86 | int full_gemm_k_iterations = problem_size.k() / Srmma::Shape::kK; 87 | int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); 88 | 89 | gemm_k_size = gemm_k_iterations * Srmma::Shape::kK; 90 | } 91 | }; 92 | 93 | /// Shared memory storage structure 94 | union SharedStorage { 95 | typename Srmma::SharedStorage main_loop; 96 | typename Epilogue::SharedStorage epilogue; 97 | }; 98 | 99 | // 100 | // Methods 101 | // 102 | 103 | CUTLASS_HOST_DEVICE 104 | SrgemmSplitKParallel() { } 105 | 106 | /// Executes one GEMM 107 | CUTLASS_DEVICE 108 | void operator()(Params const ¶ms, SharedStorage &shared_storage) { 109 | constexpr typename OutputOp::ElementCompute kAdditiveIdentity = AdditionOp::Identity; 110 | 111 | // Compute threadblock location 112 | ThreadblockSwizzle threadblock_swizzle; 113 | 114 | cutlass::gemm::GemmCoord threadblock_tile_offset = 115 | 
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); 116 | 117 | // Early exit if CTA is out of range 118 | if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || 119 | params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { 120 | 121 | return; 122 | } 123 | 124 | // Compute initial location in logical coordinates 125 | cutlass::MatrixCoord tb_offset_A{ 126 | threadblock_tile_offset.m() * Srmma::Shape::kM, 127 | threadblock_tile_offset.k() * params.gemm_k_size, 128 | }; 129 | 130 | cutlass::MatrixCoord tb_offset_B{ 131 | threadblock_tile_offset.k() * params.gemm_k_size, 132 | threadblock_tile_offset.n() * Srmma::Shape::kN 133 | }; 134 | 135 | // Problem size is a function of threadblock index in the K dimension 136 | int problem_size_k; 137 | if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { 138 | problem_size_k = params.problem_size.k(); 139 | } 140 | else { 141 | problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; 142 | } 143 | 144 | // Compute threadblock-scoped matrix multiply-add 145 | int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Srmma::Shape::kK - 1) / Srmma::Shape::kK; 146 | 147 | // Compute position within threadblock 148 | int thread_idx = threadIdx.x; 149 | 150 | // Construct iterators to A and B operands 151 | typename Srmma::IteratorA iterator_A( 152 | params.params_A, 153 | params.ref_A.data(), 154 | {params.problem_size.m(), problem_size_k}, 155 | thread_idx, 156 | tb_offset_A); 157 | 158 | typename Srmma::IteratorB iterator_B( 159 | params.params_B, 160 | params.ref_B.data(), 161 | {problem_size_k, params.problem_size.n()}, 162 | thread_idx, 163 | tb_offset_B); 164 | 165 | int warp_idx = threadIdx.x / 32; 166 | int lane_idx = threadIdx.x % 32; 167 | 168 | // 169 | // Main loop 170 | // 171 | 172 | // Construct thread-scoped matrix multiply 173 | Srmma srmma_thrblock_op(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, kAdditiveIdentity); 174 | 175 | // need to clear accumulators to additive identity for SemiRing Gemm 176 | typename Srmma::FragmentC accumulators; 177 | accumulators.fill(kAdditiveIdentity); 178 | 179 | srmma_thrblock_op(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); 180 | 181 | // 182 | // Epilogue 183 | // 184 | 185 | OutputOp output_op(params.output_op); 186 | 187 | // 188 | // Masked tile iterators constructed from members 189 | // 190 | 191 | threadblock_tile_offset = 192 | threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); 193 | 194 | // assume identity swizzle 195 | cutlass::MatrixCoord threadblock_offset( 196 | threadblock_tile_offset.m() * Srmma::Shape::kM, 197 | threadblock_tile_offset.n() * Srmma::Shape::kN 198 | ); 199 | 200 | // Tile iterator writing to output tile 201 | typename Epilogue::OutputTileIterator iterator_D( 202 | params.params_D, 203 | params.ref_D.data(), 204 | params.problem_size.mn(), 205 | thread_idx, 206 | threadblock_offset 207 | ); 208 | 209 | iterator_D.add_pointer_offset(params.splitk_slice_stride * threadblock_tile_offset.k()); 210 | 211 | // Execute the epilogue 212 | Epilogue epilogue( 213 | shared_storage.epilogue, 214 | thread_idx, 215 | warp_idx, 216 | lane_idx); 217 | 218 | // Run efficient epilogue 219 | epilogue(output_op, iterator_D, accumulators, iterator_D); 220 | } 221 | }; 222 | 223 | ///////////////////////////////////////////////////////////////////////////////////////////////// 224 | 225 | } // namespace kernel 226 | } // namespace gemm 227 | } // namespace cuasr 228 | 
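The kernel above only produces partial products: each threadblock at K-tile offset k writes its slice's result into a workspace offset by splitk_slice_stride elements, and no cross-slice accumulation happens here. A host-side sketch of the two-step flow follows; `launch`, `gemm_params`, and `red_params` are editorial placeholders standing in for the usual cutlass-style kernel launch, not names from this library.

// Step 1: partial SRGEMMs, one workspace slice per K partition:
//   workspace_i = A(:, slice_i) (x) B(slice_i, :)          for i in [0, k_slices)
// Step 2: fold the slices with the semiring addition operator:
//   D = AdditionOp(workspace_0, ..., workspace_{k_slices-1})
//
// using GemmKernel      = cuasr::gemm::kernel::SrgemmSplitKParallel<...>;   // as above
// using ReductionKernel = cuasr::reduction::kernel::ReduceSplitK<...>;      // later in this tree
//
// GemmKernel::Params gemm_params(
//     problem_size, grid_tiled_shape, ref_A, ref_B,
//     ref_workspace, output_op_params, slice_stride);
// launch<GemmKernel>(gemm_params);        // fills the workspace
//
// ReductionKernel::Params red_params(
//     problem_size.mn(), k_slices, slice_stride,
//     ref_workspace, ref_D, ref_C);
// launch<ReductionKernel>(red_params);    // folds the slices into D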
-------------------------------------------------------------------------------- /include/cuasr/gemm/thread/srmma.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved. 3 | **************************************************************************************************/ 4 | /*! \file 5 | \brief Templates exposing architecture support for warp-level multiply-add operations 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "cutlass/cutlass.h" 11 | #include "cutlass/array.h" 12 | #include "cutlass/numeric_types.h" 13 | #include "cutlass/gemm/gemm.h" 14 | #include "cutlass/arch/mma.h" 15 | 16 | #include "cutlass/gemm/thread/mma.h" 17 | 18 | ///////////////////////////////////////////////////////////////////////////////////////////////// 19 | 20 | namespace cuasr { 21 | namespace gemm { 22 | namespace thread { 23 | 24 | ///////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | /// Structure to compute the matrix product 27 | template < 28 | /// Size of the Gemm problem - concept: gemm::GemmShape<> 29 | typename Shape, 30 | /// Data type of A elements 31 | typename ElementA, 32 | /// Layout of A matrix (concept: MatrixLayout) 33 | typename LayoutA, 34 | /// Data type of B elements 35 | typename ElementB, 36 | /// Layout of B matrix (concept: MatrixLayout) 37 | typename LayoutB, 38 | /// Element type of C matrix 39 | typename ElementC, 40 | /// Layout of C matrix (concept: MatrixLayout) 41 | typename LayoutC, 42 | /// Addition operator of the semi-ring 43 | typename AdditionOp, 44 | /// Multiplication operator of the semi-ring 45 | typename MultiplicationOp, 46 | /// Used for partial specialization 47 | typename Enable = bool 48 | > 49 | struct Srmma; 50 | 51 | ///////////////////////////////////////////////////////////////////////////////////////////////// 52 | 53 | } // namespace thread 54 | } // namespace gemm 55 | } // namespace cuasr 56 | 57 | ///////////////////////////////////////////////////////////////////////////////////////////////// 58 | 59 | // 60 | // Overloads specialized for existing architectures 61 | // 62 | 63 | #include "cuasr/gemm/thread/srmma_sm50.h" 64 | 65 | ///////////////////////////////////////////////////////////////////////////////////////////////// 66 | -------------------------------------------------------------------------------- /include/cuasr/gemm/thread/srmma_sm50.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved. 3 | **************************************************************************************************/ 4 | /*! 
\file
5 |     \brief Templates exposing architecture support for multiply-add operations
6 | */
7 |
8 | #pragma once
9 |
10 | #include "cutlass/cutlass.h"
11 | #include "cutlass/tensor_ref.h"
12 | #include "cutlass/layout/matrix.h"
13 | #include "cutlass/gemm/gemm.h"
14 |
15 | #include "cuasr/arch/srmma.h"
16 |
17 | /////////////////////////////////////////////////////////////////////////////////////////////////
18 |
19 | namespace cuasr {
20 | namespace gemm {
21 | namespace thread {
22 |
23 | /////////////////////////////////////////////////////////////////////////////////////////////////
24 |
25 | /// Template that handles all packed matrix layouts
26 | template <
27 |   /// Size of the Gemm problem - concept: cutlass::gemm::GemmShape<>
28 |   typename Shape_,
29 |   /// Data type of A elements
30 |   typename ElementA_,
31 |   /// Layout of A matrix (concept: layout::MapFunc)
32 |   typename LayoutA_,
33 |   /// Data type of B elements
34 |   typename ElementB_,
35 |   /// Layout of B matrix (concept: layout::MapFunc)
36 |   typename LayoutB_,
37 |   /// Element type of C matrix
38 |   typename ElementC_,
39 |   /// Layout of C matrix (concept: layout::MapFunc)
40 |   typename LayoutC_,
41 |   /// Addition operator of the semi-ring
42 |   typename AdditionOp_,
43 |   /// Multiplication operator of the semi-ring
44 |   typename MultiplicationOp_
45 | >
46 | struct SrmmaGeneric {
47 |
48 |   /// Size of the Gemm problem - concept: cutlass::gemm::GemmShape<>
49 |   using Shape = Shape_;
50 |
51 |   /// Data type of operand A
52 |   using ElementA = ElementA_;
53 |
54 |   /// Layout of A matrix (concept: layout::MapFunc)
55 |   using LayoutA = LayoutA_;
56 |
57 |   /// Data type of operand B
58 |   using ElementB = ElementB_;
59 |
60 |   /// Layout of B matrix (concept: layout::MapFunc)
61 |   using LayoutB = LayoutB_;
62 |
63 |   /// Element type of operand C
64 |   using ElementC = ElementC_;
65 |
66 |   /// Layout of C matrix (concept: layout::MapFunc)
67 |   using LayoutC = LayoutC_;
68 |
69 |   /// Underlying semi-ring operators
70 |   using AdditionOp = AdditionOp_;
71 |   using MultiplicationOp = MultiplicationOp_;
72 |
73 |   /// A operand storage
74 |   using FragmentA = cutlass::Array<ElementA, Shape::kMK::kCount>;
75 |
76 |   /// B operand storage
77 |   using FragmentB = cutlass::Array<ElementB, Shape::kKN::kCount>;
78 |
79 |   /// C operand storage
80 |   using FragmentC = cutlass::Array<ElementC, Shape::kMN::kCount>;
81 |
82 |   /// Instruction
83 |   using SrmmaOp = arch::Srmma<
84 |     cutlass::gemm::GemmShape<1,1,1>,
85 |     1,
86 |     ElementA, LayoutA,
87 |     ElementB, LayoutB,
88 |     ElementC, LayoutC,
89 |     AdditionOp, MultiplicationOp>;
90 |
91 |   //
92 |   // Methods
93 |   //
94 |
95 |   /// Computes a generalized matrix product on any semi-ring
96 |   CUTLASS_HOST_DEVICE
97 |   void operator()(
98 |     FragmentC & D,
99 |     FragmentA const & A,
100 |     FragmentB const & B,
101 |     FragmentC const & C) {
102 |
103 |     cutlass::TensorRef<ElementA const, LayoutA> a_ref(
104 |       reinterpret_cast<ElementA const *>(&A), LayoutA::packed({Shape::kM, Shape::kK}));
105 |
106 |     cutlass::TensorRef<ElementB const, LayoutB> b_ref(
107 |       reinterpret_cast<ElementB const *>(&B), LayoutB::packed({Shape::kK, Shape::kN}));
108 |
109 |     cutlass::TensorRef<ElementC, LayoutC> d_ref(
110 |       reinterpret_cast<ElementC *>(&D), LayoutC::packed({ Shape::kM, Shape::kN }));
111 |
112 |     SrmmaOp srmma_op;
113 |
114 |     // Copy accumulators
115 |     D = C;
116 |
117 |     // Compute matrix product
118 |     CUTLASS_PRAGMA_UNROLL
119 |     for (int k = 0; k < Shape::kK; ++k) {
120 |
121 |       CUTLASS_PRAGMA_UNROLL
122 |       for (int n = 0; n < Shape::kN; ++n) {
123 |
124 |         CUTLASS_PRAGMA_UNROLL
125 |         for (int m = 0; m < Shape::kM; ++m) {
126 |
127 |           int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m;
128 |
129 |           cutlass::MatrixCoord mn(m_serpentine, n);
130 |           cutlass::MatrixCoord mk(m_serpentine, k);
131 |           cutlass::MatrixCoord kn(k, n);
132 |
133 |           cutlass::Array<ElementC, 1> d;
134 |           cutlass::Array<ElementA, 1> a;
135 |           cutlass::Array<ElementB, 1> b;
136 |
137 |           d[0] = d_ref.at(mn);
138 |           a[0] = a_ref.at(mk);
139 |           b[0] = b_ref.at(kn);
140 |
141 |           srmma_op(d, a, b, d);
142 |
143 |           d_ref.at(mn) = d[0];
144 |         }
145 |       }
146 |     }
147 |   }
148 | };
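The serpentine index above reverses the row traversal on odd columns, so the A-operand element loaded for the last row of one column is reused immediately at the start of the next column. A self-contained sketch of the visit order; the helper name is an editorial illustration, not part of the library:

// Visit order for Shape::kM == 4 within one (k, n) slice:
//   n = 0: m_serpentine -> 0, 1, 2, 3
//   n = 1: m_serpentine -> 3, 2, 1, 0   (a_ref.at({3, k}) is reused at the turn)
//   n = 2: m_serpentine -> 0, 1, 2, 3
constexpr int serpentine_row(int m, int n, int kM) {
  return (n % 2) ? (kM - 1 - m) : m;
}
static_assert(serpentine_row(3, 0, 4) == 3 && serpentine_row(0, 1, 4) == 3,
              "adjacent column transitions land on the same row of A");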
149 |
150 |
151 | /////////////////////////////////////////////////////////////////////////////////////////////////
152 |
153 | /// Template that handles conventional layouts for FFMA and DFMA GEMM
154 | template <
155 |   /// Size of the Gemm problem - concept: cutlass::gemm::GemmShape<>
156 |   typename Shape_,
157 |   /// Data type of A elements
158 |   typename ElementA_,
159 |   /// Layout of A matrix (concept: layout::MapFunc)
160 |   typename LayoutA_,
161 |   /// Data type of B elements
162 |   typename ElementB_,
163 |   /// Layout of B matrix (concept: layout::MapFunc)
164 |   typename LayoutB_,
165 |   /// Element type of C matrix
166 |   typename ElementC_,
167 |   /// Layout of C matrix (concept: layout::MapFunc)
168 |   typename LayoutC_,
169 |   /// Addition operator of the semi-ring
170 |   typename AdditionOp_,
171 |   /// Multiplication operator of the semi-ring
172 |   typename MultiplicationOp_
173 | >
174 | struct Srmma<
175 |   Shape_,
176 |   ElementA_,
177 |   LayoutA_,
178 |   ElementB_,
179 |   LayoutB_,
180 |   ElementC_,
181 |   LayoutC_,
182 |   AdditionOp_,
183 |   MultiplicationOp_,
184 |   bool
185 | > {
186 |
187 |   /// Size of the Gemm problem - concept: cutlass::gemm::GemmShape<>
188 |   using Shape = Shape_;
189 |
190 |   /// Data type of operand A
191 |   using ElementA = ElementA_;
192 |
193 |   /// Layout of A matrix (concept: layout::MapFunc)
194 |   using LayoutA = LayoutA_;
195 |
196 |   /// Data type of operand B
197 |   using ElementB = ElementB_;
198 |
199 |   /// Layout of B matrix (concept: layout::MapFunc)
200 |   using LayoutB = LayoutB_;
201 |
202 |   /// Element type of operand C
203 |   using ElementC = ElementC_;
204 |
205 |   /// Layout of C matrix (concept: layout::MapFunc)
206 |   using LayoutC = LayoutC_;
207 |
208 |   /// Underlying semi-ring operators
209 |   using AdditionOp = AdditionOp_;
210 |   using MultiplicationOp = MultiplicationOp_;
211 |
212 |   /// A operand storage
213 |   using FragmentA = Array<ElementA, Shape::kMK::kCount>;
214 |
215 |   /// B operand storage
216 |   using FragmentB = Array<ElementB, Shape::kKN::kCount>;
217 |
218 |   /// C operand storage
219 |   using FragmentC = Array<ElementC, Shape::kMN::kCount>;
220 |
221 |   //
222 |   // Methods
223 |   //
224 |
225 |   /// Computes a matrix product for any semi-ring
226 |   CUTLASS_HOST_DEVICE
227 |   void operator()(
228 |     FragmentC & D,
229 |     FragmentA const & A,
230 |     FragmentB const & B,
231 |     FragmentC const & C) {
232 |
233 |     SrmmaGeneric<
234 |       Shape,
235 |       ElementA,
236 |       LayoutA,
237 |       ElementB,
238 |       LayoutB,
239 |       ElementC,
240 |       LayoutC,
241 |       AdditionOp,
242 |       MultiplicationOp> srmma;
243 |
244 |     srmma(D, A, B, C);
245 |   }
246 | };
247 |
248 |
249 | /////////////////////////////////////////////////////////////////////////////////////////////////
250 |
251 | } // namespace thread
252 | } // namespace gemm
253 | } // namespace cuasr
254 |
255 | /////////////////////////////////////////////////////////////////////////////////////////////////
256 |
-------------------------------------------------------------------------------- /include/cuasr/gemm/threadblock/default_srmma.h: --------------------------------------------------------------------------------
1 |
/*************************************************************************************************** 2 | * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved. 3 | **************************************************************************************************/ 4 | /*! \file 5 | \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "cutlass/cutlass.h" 11 | #include "cutlass/numeric_types.h" 12 | #include "cutlass/arch/arch.h" 13 | #include "cutlass/arch/wmma.h" 14 | 15 | #include "cutlass/layout/matrix.h" 16 | #include "cutlass/transform/threadblock/predicated_tile_iterator.h" 17 | #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" 18 | 19 | #include "cuasr/gemm/threadblock/default_srmma_core.h" 20 | 21 | //////////////////////////////////////////////////////////////////////////////// 22 | 23 | namespace cuasr { 24 | namespace gemm { 25 | namespace threadblock { 26 | 27 | //////////////////////////////////////////////////////////////////////////////// 28 | 29 | template < 30 | /// Element type for A matrix operand 31 | typename ElementA_, 32 | /// Layout type for A matrix operand 33 | typename LayoutA_, 34 | /// Access granularity of A matrix in units of elements 35 | int kAlignmentA, 36 | /// Element type for B matrix operand 37 | typename ElementB_, 38 | /// Layout type for B matrix operand 39 | typename LayoutB_, 40 | /// Access granularity of B matrix in units of elements 41 | int kAlignmentB, 42 | /// Element type for internal accumulation 43 | typename ElementAccumulator_, 44 | /// Layout type for C and D matrix operands 45 | typename LayoutC_, 46 | /// Operator class tag 47 | typename OperatorClass_, 48 | /// Tag indicating architecture to tune for 49 | typename ArchTag_, 50 | /// Threadblock-level tile size (concept: GemmShape) 51 | typename ThreadblockShape_, 52 | /// Warp-level tile size (concept: GemmShape) 53 | typename WarpShape_, 54 | /// Instruction-level tile size (concept: GemmShape) 55 | typename InstructionShape_, 56 | /// Addition operator of the semi-ring 57 | typename AdditionOp_, 58 | /// Multiplication operator of the semi-ring 59 | typename MultiplicationOp_, 60 | /// Number of stages used in the pipelined mainloop 61 | int Stages, 62 | /// Store the accumulators in row major or column major. 63 | /// Row major is used when output layout is interleaved. 
64 |   bool AccumulatorsInRowMajor = false
65 | >
66 | struct DefaultSrmma;
67 |
68 | ////////////////////////////////////////////////////////////////////////////////
69 |
70 | /// Specialization for row-major output (OperatorClass Simt)
71 | template <
72 |     /// Element type for A matrix operand
73 |     typename ElementA,
74 |     /// Layout type for A matrix operand
75 |     typename LayoutA,
76 |     /// Access granularity of A matrix in units of elements
77 |     int kAlignmentA,
78 |     /// Element type for B matrix operand
79 |     typename ElementB,
80 |     /// Layout type for B matrix operand
81 |     typename LayoutB,
82 |     /// Access granularity of B matrix in units of elements
83 |     int kAlignmentB,
84 |     /// Element type for internal accumulation
85 |     typename ElementAccumulator,
86 |     /// Tag indicating architecture to tune for
87 |     typename ArchTag,
88 |     /// Threadblock-level tile size (concept: GemmShape)
89 |     typename ThreadblockShape,
90 |     /// Warp-level tile size (concept: GemmShape)
91 |     typename WarpShape,
92 |     /// Instruction-level tile size (concept: GemmShape)
93 |     typename InstructionShape,
94 |     /// Addition operator of the semi-ring
95 |     typename AdditionOp,
96 |     /// Multiplication operator of the semi-ring
97 |     typename MultiplicationOp>
98 | struct DefaultSrmma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
99 |                     kAlignmentB, ElementAccumulator, cutlass::layout::RowMajor,
100 |                     cutlass::arch::OpClassSimt, ArchTag, ThreadblockShape,
101 |                     WarpShape, InstructionShape, AdditionOp, MultiplicationOp, 2, false> {
102 |   // Define the SrmmaCore components
103 |   using SrmmaCore = typename cuasr::gemm::threadblock::DefaultSrmmaCore<
104 |       ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
105 |       ElementB, LayoutB, ElementAccumulator, cutlass::layout::RowMajor,
106 |       cutlass::arch::OpClassSimt, AdditionOp, MultiplicationOp, 2>;
107 |
108 |   // Define iterators over tiles from the A operand
109 |   using IteratorA =
110 |       cutlass::transform::threadblock::PredicatedTileIterator<
111 |           cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
112 |           ElementA, LayoutA, 1, typename SrmmaCore::IteratorThreadMapA, kAlignmentA>;
113 |
114 |   // Define iterators over tiles from the B operand
115 |   using IteratorB =
116 |       cutlass::transform::threadblock::PredicatedTileIterator<
117 |           cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
118 |           ElementB, LayoutB, 0, typename SrmmaCore::IteratorThreadMapB, kAlignmentB>;
119 |
120 |   // Define the threadblock-scoped pipelined matrix multiply
121 |   using ThreadblockSrmma = cuasr::gemm::threadblock::SrmmaPipelined<
122 |       typename SrmmaCore::Shape, IteratorA, typename SrmmaCore::SmemIteratorA,
123 |       IteratorB, typename SrmmaCore::SmemIteratorB, ElementAccumulator,
124 |       cutlass::layout::RowMajor, typename SrmmaCore::MmaPolicy>;
125 | };
126 |
127 |
128 | } // namespace threadblock
129 | } // namespace gemm
130 | } // namespace cuasr
131 |
132 | ////////////////////////////////////////////////////////////////////////////////
133 |
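DefaultSrmma is the composition point where element types, layouts, tile shapes, and the two semiring functors are resolved into concrete iterators plus the pipelined threadblock SRGEMM. A minimal sketch of a single-precision min-plus instantiation; the tile shapes are illustrative picks rather than tuned values, and cuasr::minimum / cuasr::plus are assumed to come from cuasr/functional.h:

using MinPlusThreadblockSrgemm = cuasr::gemm::threadblock::DefaultSrmma<
    float, cutlass::layout::ColumnMajor, 1,   // A: element, layout, alignment
    float, cutlass::layout::ColumnMajor, 1,   // B: element, layout, alignment
    float,                                    // accumulator element
    cutlass::layout::RowMajor,                // C/D layout (this specialization)
    cutlass::arch::OpClassSimt,               // SIMT math units
    cutlass::arch::Sm50,                      // architecture tag
    cutlass::gemm::GemmShape<64, 64, 8>,      // threadblock tile
    cutlass::gemm::GemmShape<32, 32, 8>,      // warp tile
    cutlass::gemm::GemmShape<1, 1, 1>,        // instruction tile (SIMT)
    cuasr::minimum<float>,                    // semiring "addition"
    cuasr::plus<float>,                       // semiring "multiplication"
    2>::ThreadblockSrmma;                     // two-stage pipeline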
-------------------------------------------------------------------------------- /include/cuasr/gemm/threadblock/default_srmma_core.h: --------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
3 |  **************************************************************************************************/
4 | /*! \file
5 |     \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
6 |     layout of the global memory fragments, data types, and internal tile sizes.
7 |
8 |     Partial specializations for threadblock::Mma operations targeting SIMT math instructions.
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cutlass/cutlass.h"
14 | #include "cutlass/functional.h"
15 | #include "cutlass/array.h"
16 | #include "cutlass/numeric_types.h"
17 | #include "cutlass/matrix_shape.h"
18 |
19 | #include "cutlass/arch/cache_operation.h"
20 | #include "cutlass/gemm/warp/mma.h"
21 |
22 | #include "cuasr/gemm/threadblock/srmma_pipelined.h"
23 |
24 | /////////////////////////////////////////////////////////////////////////////////////////////////
25 |
26 | namespace cuasr {
27 | namespace gemm {
28 | namespace threadblock {
29 |
30 | /////////////////////////////////////////////////////////////////////////////////////////////////
31 |
32 | /// Template defining default matrix multiply operators inferred from threadblock tile size,
33 | /// global memory data layout, and target math instruction.
34 | template <
35 |     /// Shape of threadblock-scoped matrix multiply operator
36 |     typename Shape,
37 |     /// Shape of warp-level matrix multiply operator
38 |     typename WarpShape,
39 |     /// Shape of one matrix production operation (concept: GemmShape)
40 |     typename InstructionShape,
41 |     /// Element data type of A operand
42 |     typename ElementA,
43 |     /// Layout of operand A
44 |     typename LayoutA,
45 |     /// Element data type of B operand
46 |     typename ElementB,
47 |     /// Layout of operand B
48 |     typename LayoutB,
49 |     /// Data type of accumulator
50 |     typename ElementC,
51 |     /// Layout of accumulator
52 |     typename LayoutC,
53 |     /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
54 |     typename OperatorClass,
55 |     /// Addition operator of the semi-ring
56 |     typename AdditionOp_,
57 |     /// Multiplication operator of the semi-ring
58 |     typename MultiplicationOp_,
59 |     /// Number of stages
60 |     int Stages = 2,
61 |     /// Store the accumulators in row major or column major.
62 |     /// Row major is used when output layout is interleaved.
63 |     bool AccumulatorsInRowMajor = false,
64 |     /// Cache operation of operand A
65 |     cutlass::arch::CacheOperation::Kind CacheOpA =
66 |         cutlass::arch::CacheOperation::Global,
67 |     /// Cache operation of operand B
68 |     cutlass::arch::CacheOperation::Kind CacheOpB =
69 |         cutlass::arch::CacheOperation::Global,
70 |     /// per-element transformation for elements of A
71 |     cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone,
72 |     /// per-element transformation for elements of B
73 |     cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone,
74 |     bool IsComplex = false // (is_complex<ElementA>::value || is_complex<ElementB>::value)
75 | >
76 | struct DefaultSrmmaCore;
77 |
78 | /////////////////////////////////////////////////////////////////////////////////////////////////
79 |
80 | } // namespace threadblock
81 | } // namespace gemm
82 | } // namespace cuasr
83 |
84 | #include "cuasr/gemm/threadblock/default_srmma_core_simt.h"
85 |
-------------------------------------------------------------------------------- /include/cuasr/gemm/threadblock/srmma_pipelined.h: --------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
3 |  **************************************************************************************************/
4 | /*! \file
5 |     \brief Template for a double-buffered threadblock-scoped GEMM kernel.
6 | */
7 |
8 | #pragma once
9 |
10 | #include "cutlass/cutlass.h"
11 | #include "cutlass/array.h"
12 | #include "cutlass/aligned_buffer.h"
13 | #include "cutlass/numeric_conversion.h"
14 |
15 | #include "cutlass/numeric_types.h"
16 | #include "cutlass/matrix_shape.h"
17 |
18 | #include "cutlass/gemm/gemm.h"
19 | #include "cutlass/gemm/threadblock/mma_base.h"
20 |
21 | /////////////////////////////////////////////////////////////////////////////////////////////////
22 |
23 | namespace cuasr {
24 | namespace gemm {
25 | namespace threadblock {
26 |
27 | /////////////////////////////////////////////////////////////////////////////////////////////////
28 |
29 | /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
30 | template <
31 |   /// Size of the Gemm problem - concept: gemm::GemmShape<>
32 |   typename Shape_,
33 |   /// Iterates over tiles of A operand in global memory
34 |   //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
35 |   typename IteratorA_,
36 |   /// Iterates over tiles of A operand in shared memory
37 |   /// (concept: WriteableTileIterator | RandomAccessTileIterator)
38 |   typename SmemIteratorA_,
39 |   /// Iterates over tiles of B operand in global memory
40 |   //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
41 |   typename IteratorB_,
42 |   /// Iterates over tiles of B operand in shared memory
43 |   /// (concept: WriteableTileIterator | RandomAccessTileIterator)
44 |   typename SmemIteratorB_,
45 |   /// Data type of accumulator matrix
46 |   typename ElementC_,
47 |   /// Layout of accumulator matrix
48 |   typename LayoutC_,
49 |   /// Policy describing tuning details (concept: MmaPolicy)
50 |   typename Policy_,
51 |   /// Transformation applied to A operand
52 |   typename TransformA_ = cutlass::NumericArrayConverter<
53 |     typename SmemIteratorA_::Element,
54 |     typename IteratorA_::Element,
55 |     IteratorA_::Fragment::kElements>,
56 |   ///
57 |   /// Transformation applied to B operand
58 |   typename TransformB_ = cutlass::NumericArrayConverter<
59 |     typename SmemIteratorB_::Element,
60 |     typename IteratorB_::Element,
61 |     IteratorB_::Fragment::kElements>,
62 |   /// Used for partial specialization
63 |   typename Enable = bool
64 | >
65 | class SrmmaPipelined : public cutlass::gemm::threadblock::MmaBase<Shape_, Policy_, 2> {
66 | public:
67 |
68 |   ///< Base class
69 |   using Base = cutlass::gemm::threadblock::MmaBase<Shape_, Policy_, 2>;
70 |
71 |   using Shape = Shape_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
72 |   using IteratorA = IteratorA_;     ///< Iterates over tiles of A operand in global memory
73 |   using IteratorB = IteratorB_;     ///< Iterates over tiles of B operand in global memory
74 |   using ElementC = ElementC_;       ///< Data type of accumulator matrix
75 |   using LayoutC = LayoutC_;         ///< Layout of accumulator matrix
76 |   using Policy = Policy_;           ///< Policy describing tuning details
77 |
78 |   using SmemIteratorA = SmemIteratorA_;
79 |   using SmemIteratorB = SmemIteratorB_;
80 |
81 |   using TransformA = TransformA_;
82 |   using TransformB = TransformB_;
83 |
84 |   //
85 |   // Dependent types
86 |   //
87 |
88 |   /// Fragment of operand A loaded from global memory
89 |   using FragmentA = typename IteratorA::Fragment;
90 |
91 |   /// Fragment of operand B loaded from global memory
92 |   using FragmentB = typename IteratorB::Fragment;
93 |
94 |   /// Fragment of accumulator tile
95 |   using FragmentC = typename Policy::Operator::FragmentC;
96 |
97 |   /// Warp-level Mma
98 |   using Operator = typename Policy::Operator;
99 |
100 |   // statically assert kStages for SrmmaPipelined is two (Double-buffered pipeline)
101 |   static_assert((Base::kStages==2), "SrmmaPipelined requires kStages set to value 2");
102 |
103 | private:
104 |
105 |   using WarpFragmentA = typename Operator::FragmentA;
106 |   using WarpFragmentB = typename Operator::FragmentB;
107 |
108 | protected:
109 |
110 |   /// Iterator to write threadblock-scoped tile of A operand to shared memory
111 |   SmemIteratorA smem_iterator_A_;
112 |
113 |   /// Iterator to write threadblock-scoped tile of B operand to shared memory
114 |   SmemIteratorB smem_iterator_B_;
115 |
116 |   ElementC additive_identity_;
117 |
118 | public:
119 |
120 |   /// Construct from tensor references
121 |   CUTLASS_DEVICE
122 |   SrmmaPipelined(
123 |     typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
124 |     int thread_idx,                               ///< ID within the threadblock
125 |     int warp_idx,                                 ///< ID of warp
126 |     int lane_idx,                                 ///< ID of each thread within a warp
127 |     ElementC additive_identity                    ///< Identity value of the addition op
128 |   ):
129 |     Base(shared_storage, thread_idx, warp_idx, lane_idx),
130 |     smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
131 |     smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
132 |     additive_identity_(additive_identity) {
133 |
134 |     // Compute warp location within threadblock tile by mapping the warp_id to
135 |     // three coordinates:
136 |     //   _m: the warp's position within the threadblock along the M dimension
137 |     //   _n: the warp's position within the threadblock along the N dimension
138 |     //   _k: the warp's position within the threadblock along the K dimension
139 |
140 |     int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
141 |     int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
142 |
143 |     int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
144 |     int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
145 |
146 |     // Add per-warp offsets in units of warp-level tiles
147 |     this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
148 |     this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
149 |   }
150 |
151 |   /// Perform a threadblock-scoped matrix multiply-accumulate
152 |   CUTLASS_DEVICE
153 |   void operator()(
154 |     int gemm_k_iterations,                  ///< number of iterations of the mainloop
155 |     FragmentC &accum,                       ///< destination accumulator tile
156 |     IteratorA iterator_A,                   ///< iterator over A operand in global memory
157 |     IteratorB iterator_B,                   ///< iterator over B operand in global memory
158 |     FragmentC const &src_accum,             ///< source accumulator tile
159 |     TransformA transform_A = TransformA(),  ///< transformation applied to A fragment
160 |     TransformB transform_B = TransformB()) { ///< transformation applied to B fragment
161 |
162 |     //
163 |     // Prologue
164 |     //
165 |
166 |     // Perform accumulation in the 'd' output operand
167 |     accum = src_accum;
168 |
169 |     FragmentA tb_frag_A;
170 |     FragmentB tb_frag_B;
171 |
172 |     tb_frag_A.fill(additive_identity_);
173 |     tb_frag_B.fill(additive_identity_);
174 |
175 |     // The last kblock is loaded in the prolog
176 |     iterator_A.load(tb_frag_A);
177 |     iterator_B.load(tb_frag_B);
178 |
179 |     ++iterator_A;
180 |     ++iterator_B;
181 |
182 |     this->smem_iterator_A_.store(transform_A(tb_frag_A));
183 |     this->smem_iterator_B_.store(transform_B(tb_frag_B));
184 |
185 |     ++this->smem_iterator_A_;
186 |     ++this->smem_iterator_B_;
187 |
188 |     __syncthreads();
189 |
190 |     // Pair of fragments used to overlap shared memory loads and math instructions
191 |     WarpFragmentA warp_frag_A[2];
192 |     WarpFragmentB warp_frag_B[2];
193 |
194 |     this->warp_tile_iterator_A_.set_kgroup_index(0);
195 |     this->warp_tile_iterator_B_.set_kgroup_index(0);
196 |
197 |     this->warp_tile_iterator_A_.load(warp_frag_A[0]);
198 |     this->warp_tile_iterator_B_.load(warp_frag_B[0]);
199 |
200 |     ++this->warp_tile_iterator_A_;
201 |     ++this->warp_tile_iterator_B_;
202 |
203 |     Operator warp_mma;
204 |
205 |     int smem_write_stage_idx = 1;
206 |
207 |     // Avoid reading out of bounds
208 |     if (gemm_k_iterations <= 1) {
209 |       iterator_A.clear_mask();
210 |       iterator_B.clear_mask();
211 |     }
212 |
213 |     // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
214 |     // shared memory loads (which have the tightest latency requirement).
215 |
216 |     //
217 |     // Mainloop
218 |     //
219 |
220 |     // Note: The main loop does not support Base::kWarpGemmIterations == 2.
221 |     CUTLASS_GEMM_LOOP
222 |     for (; gemm_k_iterations > 0; --gemm_k_iterations) {
223 |       //
224 |       // Loop over GEMM K dimension
225 |       //
226 |       CUTLASS_PRAGMA_UNROLL
227 |       for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
228 |
229 |         // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
230 |         // as the case may be.
231 |
232 |         if (warp_mma_k == Base::kWarpGemmIterations - 1) {
233 |
234 |           // Write fragments to shared memory
235 |           this->smem_iterator_A_.store(transform_A(tb_frag_A));
236 |
237 |           this->smem_iterator_B_.store(transform_B(tb_frag_B));
238 |
239 |           __syncthreads();
240 |
241 |           ++this->smem_iterator_B_;
242 |           ++this->smem_iterator_A_;
243 |
244 |           // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
245 |           if (smem_write_stage_idx == 1) {
246 |             this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
247 |             this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
248 |           }
249 |           else {
250 |             this->warp_tile_iterator_A_.add_tile_offset(
251 |                 {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
252 |             this->warp_tile_iterator_B_.add_tile_offset(
253 |                 {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
254 |                  0});
255 |           }
256 |
257 |           smem_write_stage_idx ^= 1;
258 |         }
259 |
260 |         this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
261 |         this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
262 |
263 |         this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
264 |         this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
265 |
266 |         ++this->warp_tile_iterator_A_;
267 |         ++this->warp_tile_iterator_B_;
268 |
269 |         if (warp_mma_k == 0) {
270 |
271 |           iterator_A.load(tb_frag_A);
272 |           iterator_B.load(tb_frag_B);
273 |
274 |           ++iterator_A;
275 |           ++iterator_B;
276 |
277 |           // Avoid reading out of bounds if this was the last loop iteration
278 |           if (gemm_k_iterations <= 2) {
279 |             iterator_A.clear_mask();
280 |             iterator_B.clear_mask();
281 |           }
282 |         }
283 |
284 |         warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
285 |       }
286 |     }
287 |
288 |   }
289 | };
290 |
291 | /////////////////////////////////////////////////////////////////////////////////////////////////
292 |
293 | } // namespace threadblock
294 | } // namespace gemm
295 | } // namespace cuasr
296 |
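The additive_identity_ plumbing above is what distinguishes this mainloop from CUTLASS's: residue tiles are padded with AdditionOp::Identity rather than with zero, so out-of-bounds slots cannot perturb the accumulation. A scalar analogue of the rule, assuming the (min, +) semiring whose additive identity is +infinity:

#include <algorithm>
#include <limits>

// Editorial sketch: padded slots contribute min(acc, inf + inf) == acc, i.e. a no-op.
float minplus_dot(float const *a, float const *b, int valid_k, int padded_k) {
  float const kIdentity = std::numeric_limits<float>::infinity(); // AdditionOp::Identity
  float acc = kIdentity;
  for (int k = 0; k < padded_k; ++k) {
    float va = (k < valid_k) ? a[k] : kIdentity;  // mirrors tb_frag_A.fill(additive_identity_)
    float vb = (k < valid_k) ? b[k] : kIdentity;
    acc = std::min(acc, va + vb);                 // AdditionOp(acc, MultiplicationOp(va, vb))
  }
  return acc;
}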
-------------------------------------------------------------------------------- /include/cuasr/gemm/warp/srmma_simt.h: --------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). All rights reserved.
3 |  **************************************************************************************************/
4 | /*! \file
5 |     \brief Templates implementing warp-level matrix multiply-accumulate operations.
6 | */
7 |
8 | #pragma once
9 |
10 | #include "cutlass/cutlass.h"
11 | #include "cutlass/array.h"
12 | #include "cutlass/numeric_types.h"
13 | #include "cutlass/matrix_shape.h"
14 |
15 | #include "cutlass/gemm/gemm.h"
16 | #include "cutlass/gemm/warp/mma.h"
17 | #include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
18 |
19 | #include "cuasr/gemm/thread/srmma.h"
20 |
21 | /////////////////////////////////////////////////////////////////////////////////////////////////
22 |
23 | namespace cuasr {
24 | namespace gemm {
25 | namespace warp {
26 |
27 | /////////////////////////////////////////////////////////////////////////////////////////////////
28 |
29 | /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
30 | template <
31 |   /// Size of the Gemm problem - concept: gemm::GemmShape<>
32 |   typename Shape_,
33 |   /// Data type of A elements
34 |   typename ElementA_,
35 |   /// Layout of A matrix (concept: MatrixLayout)
36 |   typename LayoutA_,
37 |   /// Data type of B elements
38 |   typename ElementB_,
39 |   /// Layout of B matrix (concept: MatrixLayout)
40 |   typename LayoutB_,
41 |   /// Element type of C matrix
42 |   typename ElementC_,
43 |   /// Layout of C matrix (concept: MatrixLayout)
44 |   typename LayoutC_,
45 |   /// Shape of the warp in units of thread (concept: MmaSimtPolicy)
46 |   typename Policy_,
47 |   /// Addition operator of the semi-ring
48 |   typename AdditionOp_,
49 |   /// Multiplication operator of the semi-ring
50 |   typename MultiplicationOp_,
51 |   /// Number of partitions along K dimension
52 |   int PartitionsK = 1,
53 |   /// Used for partial specialization
54 |   typename Enable = bool
55 | >
56 | class SrmmaSimt {
57 | public:
58 |   /// Shape of warp-level matrix operation (concept: GemmShape)
59 |   using Shape = Shape_;
60 |
61 |   /// Data type of multiplicand A
62 |   using ElementA = ElementA_;
63 |
64 |   /// Layout of multiplicand A
65 |   using LayoutA = LayoutA_;
66 |
67 |   /// Data type of multiplicand B
68 |   using ElementB = ElementB_;
69 |
70 |   /// Layout of multiplicand B
71 |   using LayoutB = LayoutB_;
72 |
73 |   /// Data type of accumulator matrix C
74 |   using ElementC = ElementC_;
75 |
76 |   /// Layout of accumulator matrix C
77 |   using LayoutC = LayoutC_;
78 |
79 |   /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
80 |   using Policy = Policy_;
81 |
82 |   /// Indicates class of matrix operator
83 |   using OperatorClass = cutlass::arch::OpClassSimt;
84 |
85 |   /// Underlying semi-ring operators
86 |   using AdditionOp = AdditionOp_;
87 |   using MultiplicationOp = MultiplicationOp_;
88 |
89 |   using ThreadLayoutA = typename cutlass::platform::conditional<
90 |       cutlass::platform::is_same<cutlass::layout::ColumnMajorInterleaved<4>, LayoutA>::
91 |           value,
92 |       cutlass::layout::ColumnMajor,
93 |       typename cutlass::platform::conditional<
94 |           cutlass::platform::is_same<cutlass::layout::RowMajorInterleaved<4>, LayoutA>::
95 |               value,
96 |           cutlass::layout::RowMajor,
97 |           LayoutA>::type>::type;
98 |
99 |   using ThreadLayoutB = typename cutlass::platform::conditional<
100 |       cutlass::platform::is_same<cutlass::layout::ColumnMajorInterleaved<4>, LayoutB>::
101 |           value,
102 |       cutlass::layout::ColumnMajor,
103 |       typename cutlass::platform::conditional<
104 |           cutlass::platform::is_same<cutlass::layout::RowMajorInterleaved<4>, LayoutB>::
105 |               value,
106 |           cutlass::layout::RowMajor,
107 |           LayoutB>::type>::type;
108 |
109 |   static constexpr bool use_dp4a
110 |       = (cutlass::platform::is_same<cutlass::layout::ColumnMajorInterleaved<4>, LayoutA>::
111 |              value
112 |          || cutlass::platform::is_same<cutlass::layout::RowMajorInterleaved<4>, LayoutA>::
113 |                 value)
114 |         && cutlass::platform::is_same<int8_t, ElementA>::value
115 |         && cutlass::platform::is_same<int8_t, ElementB>::value;
116 |
117 |   using dp4a_type = typename cutlass::platform::conditional< use_dp4a , int8_t, bool >::type;
118 |
119 |   /// Thread-level matrix multiply accumulate operator
120 |   using ThreadMma = cuasr::gemm::thread::Srmma<
121 |     cutlass::gemm::GemmShape<
122 |       Shape::kM / Policy::WarpShape::kRow,
123 |       Shape::kN / Policy::WarpShape::kColumn,
124 |       Policy::LaneMmaShape::kK>,
125 |     ElementA,
126 |     ThreadLayoutA,
127 |     ElementB,
128 |     ThreadLayoutB,
129 |     ElementC,
130 |     LayoutC,
131 |     AdditionOp,
132 |     MultiplicationOp,
133 |     dp4a_type
134 |   >;
135 |
136 | public:
137 |
138 |   /// Iterates over the A operand in memory
139 |   using IteratorA = cutlass::gemm::warp::MmaSimtTileIterator<
140 |     cutlass::MatrixShape<Shape::kM, Policy::LaneMmaShape::kK>,
141 |     cutlass::gemm::Operand::kA,
142 |     ElementA,
143 |     LayoutA,
144 |     Policy,
145 |     PartitionsK,
146 |     Shape::kK
147 |   >;
148 |
149 |   /// Storage for A tile
150 |   using FragmentA = typename IteratorA::Fragment;
151 |
152 |   /// Iterates over the B operand in memory
153 |   using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator<
154 |     cutlass::MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,
155 |     cutlass::gemm::Operand::kB,
156 |     ElementB,
157 |     LayoutB,
158 |     Policy,
159 |     PartitionsK,
160 |     Shape::kK
161 |   >;
162 |
163 |   /// Storage for B tile
164 |   using FragmentB = typename IteratorB::Fragment;
165 |
166 |   /// Iterates over the C operand in memory
167 |   using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
168 |     cutlass::MatrixShape<Shape::kM, Shape::kN>,
169 |     cutlass::gemm::Operand::kC,
170 |     ElementC,
171 |     LayoutC,
172 |     Policy
173 |   >;
174 |
175 |   /// Storage for C tile
176 |   using FragmentC = typename ThreadMma::FragmentC;
177 |
178 | public:
179 |
180 |   //
181 |   // Methods
182 |   //
183 |
184 |   /// Ctor
185 |   CUTLASS_DEVICE
186 |   SrmmaSimt() {}
187 |
188 |   /// Performs a warp-level matrix multiply-accumulate operation
189 |   CUTLASS_DEVICE
190 |   void operator()(
191 |     FragmentC &d,
192 |     FragmentA const &a,
193 |     FragmentB const &b,
194 |     FragmentC const &c, int group_idx = 0) const {
195 |
196 |     ThreadMma srmma;
197 |
198 |     srmma(d, a, b, c);
199 |   }
200 | };
201 |
202 | /////////////////////////////////////////////////////////////////////////////////////////////////
203 |
204 | } // namespace warp
205 | } // namespace gemm
206 | } // namespace cuasr
207 |
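ThreadMma's shape above is the warp tile divided by the lane grid in Policy, so each of the 32 lanes owns a small register tile. A worked instance of the shape arithmetic; the concrete shapes below are illustrative assumptions, not values taken from the library:

#include "cutlass/gemm/gemm.h"

// Assumed shapes for illustration: a 32x64 warp tile and a 4x8 lane grid.
using IllustrativeWarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
constexpr int kLaneRows = 4;   // stands in for Policy::WarpShape::kRow
constexpr int kLaneCols = 8;   // stands in for Policy::WarpShape::kColumn
static_assert(kLaneRows * kLaneCols == 32, "exactly one warp of lanes");
static_assert(IllustrativeWarpShape::kM / kLaneRows == 8 &&
              IllustrativeWarpShape::kN / kLaneCols == 8,
              "each lane accumulates an 8x8 tile per LaneMmaShape::kK slice");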
-------------------------------------------------------------------------------- /include/cuasr/reduction/kernel/reduce_split_k.h: --------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Redistribution and use in source and binary forms, with or without modification, are permitted
5 |  * provided that the following conditions are met:
6 |  *     * Redistributions of source code must retain the above copyright notice, this list of
7 |  *       conditions and the following disclaimer.
8 |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of
9 |  *       conditions and the following disclaimer in the documentation and/or other materials
10 |  *       provided with the distribution.
11 |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 |  *       to endorse or promote products derived from this software without specific prior written
13 |  *       permission.
14 |  *
15 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 |  *
24 |  **************************************************************************************************/
25 | /*! \file
26 |     \brief Kernel performing a reduction over densely packed tensors in global memory
27 | */
28 |
29 | #pragma once
30 |
31 | #include "cutlass/cutlass.h"
32 | #include "cutlass/tensor_ref.h"
33 | #include "cutlass/numeric_types.h"
34 | #include "cutlass/array.h"
35 | #include "cutlass/functional.h"
36 | #include "cutlass/matrix_shape.h"
37 | #include "cutlass/numeric_conversion.h"
38 | #include "cutlass/layout/matrix.h"
39 |
40 | #include "cuasr/reduction/thread/reduction_operators.h"
41 |
42 | /////////////////////////////////////////////////////////////////////////////////////////////////
43 |
44 | namespace cuasr {
45 | namespace reduction {
46 | namespace kernel {
47 |
48 | /////////////////////////////////////////////////////////////////////////////////////////////////
49 |
50 | template <
51 |   typename Shape_,              ///< shape of CTA (concept: MatrixShape)
52 |   typename OutputOp_ ,          ///< output operator (concept: epilogue::thread operator)
53 |   typename ReductionOp_,        ///< functional addition operator to be used for reduction
54 |   int PartitionsPerStage = 4    ///< number of partitions to issue
55 | >
56 | class ReduceSplitK {
57 | public:
58 |   // type aliases
59 |   using Shape = Shape_;
60 |   using ReductionOp = ReductionOp_;
61 |   using OutputOp = OutputOp_;
62 |
63 |   using ElementWorkspace = typename ReductionOp::Element;
64 |   using ElementAccumulator = typename ReductionOp::ElementAccumulator;
65 |   using ElementOutput = typename OutputOp::ElementOutput;
66 |
67 |   // static storage
68 |   static int const kElementsPerAccess = OutputOp::kCount;
69 |   static int const kPartitionsPerStage = PartitionsPerStage;
70 |
71 |   using WorkspaceTensorRef = cutlass::TensorRef<ElementWorkspace, cutlass::layout::RowMajor>;
72 |   using OutputTensorRef = cutlass::TensorRef<ElementOutput, cutlass::layout::RowMajor>;
73 |
74 |   using FragmentWorkspace = cutlass::AlignedArray<ElementWorkspace, kElementsPerAccess>;
75 |   using FragmentAccumulator = cutlass::Array<ElementAccumulator, kElementsPerAccess>;
76 |   using FragmentOutput = cutlass::AlignedArray<ElementOutput, kElementsPerAccess>;
77 |
78 |   //
79 |   // Types nested
80 |   //
81 |
82 |   /// Params structure
83 |   struct Params {
84 |
85 |     cutlass::MatrixCoord problem_size;
86 |     int partitions;
87 |     size_t partition_stride;
88 |     WorkspaceTensorRef workspace;
89 |     OutputTensorRef destination;
90 |     OutputTensorRef source;
91 |     typename OutputOp::Params output;
92 |     typename ReductionOp::Params reduction;
93 |
94 |     //
95 |     // Methods
96 |     //
97 |
98 |     CUTLASS_HOST_DEVICE
99 |     Params() { }
100 |
101 |     CUTLASS_HOST_DEVICE
102 |     Params(
103 |       cutlass::MatrixCoord problem_size_,
104 |       int partitions_,
105 |       size_t partition_stride_,
106 |       WorkspaceTensorRef workspace_,
107 |       OutputTensorRef destination_,
108 |       OutputTensorRef source_,
109 |       typename OutputOp::Params output_ = typename OutputOp::Params(),
110 |       typename ReductionOp::Params reduction_ = typename ReductionOp::Params()
111 |     ):
112 |       problem_size(problem_size_),
113 |       partitions(partitions_),
114 |       partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess),
115 |       workspace(workspace_),
116 |       destination(destination_),
117 |       source(source_),
118 |       output(output_),
119 |       reduction(reduction_) {
120 |
121 |     }
122 |   };
123 |
124 |   struct SharedStorage { };
125 |
126 | public:
127 |
128 |   /// Computes the grid size given a chosen threadblock shape
129 |   CUTLASS_HOST_DEVICE
130 |   static dim3 grid_shape(
131 |     cutlass::MatrixCoord problem_size) {
132 |
133 |     return dim3(
134 |       (problem_size.row() + Shape::kRow - 1) / Shape::kRow,
135 |       (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn);
136 |   }
137 |
138 |   /// Determines the threadblock shape
139 |   CUTLASS_HOST_DEVICE
140 |   static dim3 block_shape() {
141 |     return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow);
142 |   }
143 |
144 |   /// Perform a reduction
145 |   CUTLASS_DEVICE
146 |   void operator()(Params const &params, SharedStorage &storage) {
147 |
148 |     // Determine CTA position
149 |     cutlass::MatrixCoord thread_offset(
150 |       int(blockIdx.x) * Shape::kRow + threadIdx.y,
151 |       int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess
152 |     );
153 |
154 |     // One guard conditional
155 |     if (!(thread_offset.row() < params.problem_size.row() &&
156 |           thread_offset.column() < params.problem_size.column())) {
157 |
158 |       return;
159 |     }
160 |
161 |
162 |     ReductionOp reduction_op(params.reduction);
163 |
164 |     FragmentAccumulator accumulator;
165 |
166 |     ElementWorkspace kReductionIdentity = ReductionOp::Identity;
167 |     accumulator.fill(kReductionIdentity);
168 |
169 |     //
170 |     // Load the first slice
171 |     //
172 |
173 |     char const *workspace_ptr =
174 |       reinterpret_cast<char const *>(
175 |         params.workspace.data() + params.workspace.offset(thread_offset));
176 |
177 |     FragmentWorkspace workspace_frag[kPartitionsPerStage];
178 |
179 |     //
180 |     // Construct the output operator
181 |     //
182 |
183 |     OutputOp output_op(params.output);
184 |
185 |     //
186 |     // Load and accumulate with a simple batched loading sequence.
187 |     //
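    // A note on the staging below: with kPartitionsPerStage == 4 and, say,
    // params.partitions == 10, the outer loop runs for k = 0, 4, 8. Each stage
    // issues up to four slice loads back-to-back and only then folds them into
    // the accumulator, so the later loads overlap the earlier reduction_op calls.
    // The (k + i < params.partitions) guards simply mask the ragged final stage.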
188 |
189 |     CUTLASS_PRAGMA_NO_UNROLL
190 |     for (int k = 0; k < params.partitions; k += kPartitionsPerStage) {
191 |
192 |       CUTLASS_PRAGMA_UNROLL
193 |       for (int i = 0; i < kPartitionsPerStage; ++i) {
194 |         if (k + i < params.partitions) {
195 |           workspace_frag[i] = *reinterpret_cast<FragmentWorkspace const *>(workspace_ptr);
196 |           workspace_ptr += params.partition_stride;
197 |         }
198 |       }
199 |
200 |       CUTLASS_PRAGMA_UNROLL
201 |       for (int i = 0; i < kPartitionsPerStage; ++i) {
202 |         if (k + i < params.partitions) {
203 |           accumulator = reduction_op(accumulator, workspace_frag[i]);
204 |         }
205 |       }
206 |     }
207 |
208 |     //
209 |     // Conditionally load the source
210 |     //
211 |
212 |     FragmentOutput source_frag;
213 |
214 |     source_frag.fill(kReductionIdentity);
215 |
216 |     FragmentOutput const *source_ptr = reinterpret_cast<FragmentOutput const *>(
217 |       params.source.data() + params.source.offset(thread_offset));
218 |
219 |     if (output_op.is_source_needed()) {
220 |       reinterpret_cast<FragmentOutput &>(source_frag) = *source_ptr;
221 |     }
222 |
223 |     //
224 |     // Compute the output
225 |     //
226 |
227 |     typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag);
228 |
229 |     //
230 |     // Store
231 |     //
232 |
233 |     FragmentOutput *dest_ptr = reinterpret_cast<FragmentOutput *>(
234 |       params.destination.data() + params.destination.offset(thread_offset));
235 |
236 |     *dest_ptr = reinterpret_cast<FragmentOutput const &>(output_frag);
237 |   }
238 | };
239 |
240 | /////////////////////////////////////////////////////////////////////////////////////////////////
241 |
242 | } // namespace kernel
243 | } // namespace reduction
244 | } // namespace cuasr
245 |
-------------------------------------------------------------------------------- /include/cuasr/reduction/thread/reduce.h: --------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Redistribution and use in source and binary forms, with or without modification, are permitted
5 |  * provided that the following conditions are met:
6 |  *     * Redistributions of source code must retain the above copyright notice, this list of
7 |  *       conditions and the following disclaimer.
8 |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of
9 |  *       conditions and the following disclaimer in the documentation and/or other materials
10 |  *       provided with the distribution.
11 |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 |  *       to endorse or promote products derived from this software without specific prior written
13 |  *       permission.
14 |  *
15 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 |  *
24 |  **************************************************************************************************/
25 | /*! \file
26 |     \brief Defines basic thread level reduction with specializations for Array.
27 | */
28 |
29 | #pragma once
30 |
31 | #include "cutlass/cutlass.h"
32 | #include "cutlass/numeric_types.h"
33 | #include "cutlass/array.h"
34 | #include "cutlass/half.h"
35 | #include "cutlass/functional.h"
36 |
37 | /////////////////////////////////////////////////////////////////////////////////////////////////
38 |
39 | namespace cuasr {
40 | namespace reduction {
41 | namespace thread {
42 |
43 | // Structure to compute the thread level reduction with semiring addition operator
44 | template <typename AdditionOp, typename T, int N = 1>
45 | struct Reduce {
46 |   CUTLASS_HOST_DEVICE
47 |   T operator()(T lhs, T const &rhs) const {
48 |     AdditionOp add;
49 |     return add(lhs, rhs);
50 |   }
51 |
52 |   CUTLASS_HOST_DEVICE
53 |   cutlass::Array<T, 1> operator()(cutlass::Array<T, N> const &in) const {
54 |     cutlass::Array<T, 1> result;
55 |     result.fill(AdditionOp::Identity);
56 |
57 |     CUTLASS_PRAGMA_UNROLL
58 |     for (auto i = 0; i < N; ++i) {
59 |       result[0] = this->operator()(result[0], in[i]);
60 |     }
61 |
62 |     return result;
63 |   }
64 | };
65 |
66 | /////////////////////////////////////////////////////////////////////////////////////////////////
67 |
68 | } // namespace thread
69 | } // namespace reduction
70 | } // namespace cuasr
71 |
72 | /////////////////////////////////////////////////////////////////////////////////////////////////
-------------------------------------------------------------------------------- /include/cuasr/reduction/thread/reduction_operators.h: --------------------------------------------------------------------------------
1 | /***************************************************************************************************
2 |  * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Redistribution and use in source and binary forms, with or without modification, are permitted
5 |  * provided that the following conditions are met:
6 |  *     * Redistributions of source code must retain the above copyright notice, this list of
7 |  *       conditions and the following disclaimer.
8 |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of
9 |  *       conditions and the following disclaimer in the documentation and/or other materials
10 |  *       provided with the distribution.
11 |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 |  *       to endorse or promote products derived from this software without specific prior written
13 |  *       permission.
14 |  *
15 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 |  **************************************************************************************************/
24 | /*!
\file 26 | \brief Kernel performing a reduction over densely packed tensors in global memory 27 | */ 28 | 29 | #pragma once 30 | 31 | #include "cutlass/cutlass.h" 32 | #include "cutlass/tensor_ref.h" 33 | #include "cutlass/numeric_types.h" 34 | #include "cutlass/array.h" 35 | #include "cutlass/functional.h" 36 | #include "cutlass/numeric_conversion.h" 37 | 38 | namespace cuasr { 39 | namespace reduction { 40 | namespace thread { 41 | 42 | /// Mixed-precision reduction with a functional reduction operator 43 | template < 44 | typename AdditionOp_, 45 | typename ElementAccumulator_, 46 | typename Element_, 47 | int Count = 1 48 | > 49 | struct SemiringReduce { 50 | // Type aliases 51 | using AdditionOp = AdditionOp_; 52 | using ElementAccumulator = ElementAccumulator_; 53 | using Element = Element_; 54 | 55 | // Static members 56 | static int const kCount = Count; 57 | static Element constexpr Identity = AdditionOp::Identity; 58 | 59 | using FragmentAccumulator = cutlass::Array; 60 | using FragmentElement = cutlass::Array; 61 | 62 | // Types nested 63 | struct Params { }; 64 | 65 | // Data members 66 | Params params; 67 | 68 | // Methods 69 | 70 | /// Constructor 71 | CUTLASS_HOST_DEVICE 72 | SemiringReduce(Params params_ = Params()): params(params_) { } 73 | 74 | /// Operator 75 | CUTLASS_HOST_DEVICE 76 | FragmentAccumulator operator()( 77 | FragmentAccumulator accumulator, 78 | FragmentElement element) const { 79 | 80 | AdditionOp op; 81 | cutlass::NumericArrayConverter< 82 | ElementAccumulator, 83 | Element, 84 | kCount, 85 | cutlass::PreferredRoundingMode::kRound> converter; 86 | 87 | return op(accumulator, converter(element)); 88 | } 89 | }; 90 | 91 | } // namespace thread 92 | } // namespace reduction 93 | } // namespace cuasr 94 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # first make sure we have gtest checked-out and its up-to-date 2 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 3 | message(STATUS "Checking submodule version for google/googletest") 4 | execute_process( 5 | COMMAND ${GIT_EXECUTABLE} submodule update --init ${PROJECT_SOURCE_DIR}/test/gtest 6 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 7 | OUTPUT_VARIABLE GIT_SUBMOD_STDOUT OUTPUT_STRIP_TRAILING_WHITESPACE 8 | ERROR_VARIABLE GIT_SUBMOD_STDERR ERROR_STRIP_TRAILING_WHITESPACE 9 | RESULT_VARIABLE GIT_SUBMOD_RESULT 10 | ) 11 | if(NOT GIT_SUBMOD_RESULT EQUAL "0") 12 | message(FATAL_ERROR "git submodule update --init failed with ${GIT_SUBMOD_RESULT}, please checkout gtest manually. Git stdout was ${GIT_SUBMOD_STDOUT}. 
Git stderr was ${GIT_SUBMOD_STDERR}.") 13 | elseif(NOT ${GIT_SUBMOD_STDOUT} STREQUAL "") 14 | message(STATUS ${GIT_SUBMOD_STDOUT}) 15 | endif() 16 | endif() 17 | 18 | if(NOT EXISTS "${PROJECT_SOURCE_DIR}/test/gtest/googletest/include") 19 | message(FATAL_ERROR "GTest submodule is not present and automatic checkout failed, please checkout gtest manually.") 20 | endif() 21 | 22 | add_subdirectory(gtest) 23 | add_subdirectory(regress) 24 | add_subdirectory(device) 25 | -------------------------------------------------------------------------------- /test/device/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SIMT_TEST_SRCS CONFIGURE_DEPENDS *.cu) 2 | add_executable(cuasr_test_srgemm_device 3 | ${PROJECT_SOURCE_DIR}/test/harness.cpp 4 | ${SIMT_TEST_SRCS} 5 | ) 6 | target_include_directories( 7 | cuasr_test_srgemm_device 8 | PRIVATE 9 | ${PROJECT_SOURCE_DIR}/include/ 10 | ${PROJECT_SOURCE_DIR}/tools/include/ 11 | ${PROJECT_SOURCE_DIR}/cutlass/include/ 12 | ${PROJECT_SOURCE_DIR}/cutlass/tools/util/include/ 13 | ) 14 | target_link_libraries(cuasr_test_srgemm_device 15 | gtest 16 | ${cuASR_LIB_NAME} 17 | ) 18 | add_test( 19 | NAME cuasr_test_srgemm_device 20 | COMMAND cuasr_test_srgemm_device 21 | ) 22 | if(NOT DEFINED CUASR_TEST_LEVEL) 23 | set(CUASR_TEST_LEVEL 0) 24 | endif() 25 | target_compile_definitions(cuasr_test_srgemm_device 26 | PRIVATE CUASR_TEST_LEVEL=${CUASR_TEST_LEVEL} 27 | ) 28 | -------------------------------------------------------------------------------- /test/device/simt_sm50.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # this file creates the test/unit/gemm/device simt tests and the CMake file to go with it 4 | ################################################################################ 5 | # parameters 6 | # Edge - for tiles, the edges represent the length of one side 7 | # Ratio - the maximum ratio between 2 edges, limits the skinnyness of tiles 8 | # MaxEdge - maximum length of each edge 9 | # Min/Max - minimum/maximum of the product of edge lengths 10 | ################################################################################ 11 | THREADS_PER_WARP = 32 12 | WARPS_PER_TB_EDGE = [1, 2, 4, 8, 16] 13 | WARPS_PER_TB_RATIO = 2 14 | WARPS_PER_TB_MAX = 16 15 | # NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases 16 | 17 | WARP_SHAPE_EDGES = [8, 16, 32, 64, 128, 256] 18 | WARP_SHAPE_RATIO = 4 19 | WARP_SHAPE_MAX = 64*64 20 | WARP_SHAPE_MIN = 8*8 21 | 22 | THREADBLOCK_MAX_EDGE = 256 23 | 24 | UNROLL_MIN = 8 25 | 26 | # char, type bits/elem, max tile, L0 threadblock tiles 27 | precisions = [ 28 | ["d", "double", 64, 64*64, [[64, 64], [32, 32]]], 29 | ["s", "float", 32, 128 * 30 | 128, [[128, 256], [128, 128], [64, 64]]], 31 | # ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], 32 | # ["i", "int", 32, 128*128, [[128, 64], [16, 32]]], 33 | ] 34 | 35 | transposes = [ 36 | [False, False, True], 37 | [False, False, False], 38 | [False, True, True], 39 | [False, True, False], 40 | [True, False, True], 41 | [True, False, False], 42 | [True, True, True], 43 | [True, True, False], 44 | ] 45 | 46 | semiring_operators = [ 47 | ["plus", "multiplies"], # regular GEMM 48 | ["minimum", "plus"], # min-plus (tropical) 49 | ["maximum", "plus"], # max-plus 50 | ["minimum", "maximum"], # min-max 51 | ["maximum", "minimum"], # max-min 52 | ["minimum", "multiplies"], # min-multiplies 53 | ["maximum", "multiplies"], # 
max-multiplies 54 | ["binary_or", "binary_and"] # or-and 55 | ] 56 | 57 | testfile_header = """\ 58 | /*************************************************************************************************** 59 | * Copyright (c) 2020, Vijay Thakkar (thakkarv@gatech.edu). 60 | **************************************************************************************************/ 61 | ///////////////////////////////////////////////////////////////// 62 | // THIS TEST FILE IS GENERATED AUTOMATICALLY : DO NOT MODIFY // 63 | ///////////////////////////////////////////////////////////////// 64 | 65 | #include "gtest/gtest.h" 66 | 67 | /// from upstream cutlass 68 | #include "cutlass/cutlass.h" 69 | #include "cutlass/gemm/gemm.h" 70 | #include "cutlass/gemm/threadblock/threadblock_swizzle.h" 71 | 72 | /// from cuasr lib 73 | #include "cuasr/gemm/device/default_srgemm_configuration.h" 74 | #include "cuasr/gemm/device/srgemm.h" 75 | #include "cuasr/functional.h" 76 | 77 | /// from cuasr tools 78 | #include "cuasr/reference/srgemm/host_srgemm.h" 79 | 80 | /// from local test dir 81 | #include "testbed.h" 82 | 83 | """ 84 | 85 | test_header_template = """\ 86 | //////////////////////////////////////////////////////////////////////////////// 87 | // Elements / Thread: {:3.0f} x {:3.0f} 88 | // Threads / Warp: {:3.0f} x {:3.0f} 89 | // Warps / Block: {:3.0f} x {:3.0f} 90 | // Threadblock: {:3.0f} x {:3.0f} x {:3.0f} 91 | """ 92 | 93 | test_template = """\ 94 | #if defined(CUASR_TEST_LEVEL) and (CUASR_TEST_LEVEL >= {21}) 95 | TEST(SM50_device_{0}_{1}_{2}srgemm_{4}{5}_{6}, {10}x{11}x{12}_{13}x{14}x1_{15}x{16}_{17}x{18}_{19}x{20}) {{ 96 | using precision = {3}; 97 | using OpClass = cutlass::arch::OpClassSimt; 98 | using SmArch = cutlass::arch::Sm50; 99 | 100 | using ThreadblockShape = cutlass::gemm::GemmShape<{10}, {11}, {12}>; 101 | using WarpShape = cutlass::gemm::GemmShape<{13}, {14}, {12}>; 102 | using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; 103 | 104 | using Config = typename cuasr::gemm::device::DefaultSemiRingConfiguration< // 105 | precision, precision, precision, precision, OpClass, // 106 | cuasr::{0}<precision>, cuasr::{1}<precision>, SmArch>; 107 | 108 | using AddOp = Config::AdditionOp; 109 | using MultOp = Config::MultiplicationOp; 110 | using EpilogueOutputOp = Config::EpilogueOutputOp; 111 | 112 | using Srgemm = cuasr::gemm::device::Srgemm< // 113 | AddOp, MultOp, // 114 | precision, cutlass::layout::{7}Major, // 115 | precision, cutlass::layout::{8}Major, // 116 | precision, cutlass::layout::{9}Major, // 117 | precision, OpClass, SmArch, // 118 | ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, // 119 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2>; 120 | 121 | EXPECT_TRUE(cuasr::test::gemm::device::TestAllGemm<Srgemm>()); 122 | }} 123 | #endif 124 | 125 | """ 126 | 127 | 128 | def write_test_file_header(testfile): 129 | testfile.write(testfile_header) 130 | 131 | 132 | def write_test_to_file( 133 | testfile, 134 | add_op, 135 | mult_op, 136 | precision_char, 137 | precision_type, 138 | transA, 139 | transB, 140 | transC, 141 | threadblock_tile, 142 | unroll, 143 | warp_shape, 144 | thread_tileM, 145 | thread_tileN, 146 | warp_threadsM, 147 | warp_threadsN, 148 | warps_per_tb, 149 | test_level): 150 | print("{:.0f}x{:.0f}x{:.0f}__{:.0f}x{:.0f}_{:.0f}x{:.0f}_{:.0f}x{:.0f}".format( 151 | threadblock_tile[0], threadblock_tile[1], unroll, 152 | thread_tileM, thread_tileN, 153 | warp_threadsM, warp_threadsN, 154 | warps_per_tb[0], warps_per_tb[1])) 155 | 156 | 
testfile.write(test_header_template.format( 157 | thread_tileM, thread_tileN, 158 | warp_threadsM, warp_threadsN, 159 | warps_per_tb[0], warps_per_tb[1], 160 | threadblock_tile[0], threadblock_tile[1], unroll 161 | )) 162 | 163 | trans_typeA = "Column" if transA == "n" else "Row" 164 | trans_typeB = "Column" if transB == "n" else "Row" 165 | trans_typeC = "Column" if transC == "n" else "Row" 166 | print(precision_type) 167 | testfile.write(test_template.format( 168 | add_op, # 0 169 | mult_op, # 1 170 | precision_char, # 2 171 | precision_type, # 3 172 | transA, # 4 173 | transB, # 5 174 | transC, # 6 175 | trans_typeA, # 7 176 | trans_typeB, # 8 177 | trans_typeC, # 9 178 | int(threadblock_tile[0]), # 10 179 | int(threadblock_tile[1]), # 11 180 | int(unroll), # 12 181 | int(warp_shape[0]), # 13 182 | int(warp_shape[1]), # 14 183 | int(thread_tileM), # 15 184 | int(thread_tileN), # 16 185 | int(warp_threadsM), # 17 186 | int(warp_threadsN), # 18 187 | int(warps_per_tb[0]), # 19 188 | int(warps_per_tb[1]), # 20 189 | int(test_level) # 21 190 | )) 191 | 192 | 193 | def main(output_dir: str): 194 | # warps per threadblock 195 | warps_per_threadblocks = [] 196 | for warps_per_tb0 in WARPS_PER_TB_EDGE: 197 | for warps_per_tb1 in WARPS_PER_TB_EDGE: 198 | if (warps_per_tb0 / warps_per_tb1 <= WARPS_PER_TB_RATIO) \ 199 | and (warps_per_tb1 / warps_per_tb0 <= WARPS_PER_TB_RATIO) \ 200 | and (warps_per_tb0 * warps_per_tb1 <= WARPS_PER_TB_MAX): 201 | warps_per_threadblocks.append([warps_per_tb0, warps_per_tb1]) 202 | print("Warps Per Threadblocks", warps_per_threadblocks) 203 | 204 | # warp shapes 205 | warp_shapes = [] 206 | for warp0 in WARP_SHAPE_EDGES: 207 | for warp1 in WARP_SHAPE_EDGES: 208 | if (warp0 / warp1 <= WARP_SHAPE_RATIO) \ 209 | and (warp1 / warp0 <= WARP_SHAPE_RATIO) \ 210 | and (warp0 * warp1 <= WARP_SHAPE_MAX) \ 211 | and (warp0 * warp1 > WARP_SHAPE_MIN): 212 | warp_shapes.append([warp0, warp1]) 213 | print("Warp Shapes", warp_shapes) 214 | 215 | # create kernels 216 | # create a file for each precision/transpose 217 | # each file contains many tile sizes 218 | 219 | # for all semiring add/mul pairs 220 | num_tests = 0 221 | testcount_L0 = 0 222 | testcount_L1 = 0 223 | testcount_L2 = 0 224 | for add_op, mult_op in semiring_operators: 225 | 226 | # precisions 227 | for precision in precisions: 228 | precision_char = precision[0] 229 | precision_type = precision[1] 230 | precision_bits = precision[2] 231 | tb_max_elements = precision[3] 232 | tb_tiles_L0 = precision[4] 233 | 234 | # transposes 235 | for transpose in transposes: 236 | # get transpose char 237 | column_major_A = transpose[0] 238 | column_major_B = transpose[1] 239 | column_major_C = transpose[2] 240 | transA = "n" if column_major_A else "t" 241 | transB = "n" if column_major_B else "t" 242 | transC = "n" if column_major_C else "t" 243 | 244 | # open file 245 | testfile_name = "simt_{}_{}_{}srgemm_{}{}_{}_sm50.cu".format( 246 | add_op, mult_op, precision_char, 247 | transA, transB, transC) 248 | print("\n", testfile_name) 249 | 250 | filePath = os.path.join(output_dir, testfile_name) 251 | with open(filePath, "w") as testfile: 252 | write_test_file_header(testfile) 253 | 254 | # keeps track of which L0 and L1 test shapes have been seen 255 | seen_tb_tiles_L0 = {} 256 | seen_tb_tiles_L1 = {} 257 | 258 | # for each combination of tile sizes 259 | for warps_per_tb in warps_per_threadblocks: 260 | for warp_shape in warp_shapes: 261 | warp_threadsM = 0 262 | if warp_shape[0] > warp_shape[1]: 263 | warp_threadsM = 8 264 | 
else: 265 | warp_threadsM = 4 266 | warp_threadsN = THREADS_PER_WARP / warp_threadsM 267 | 268 | # skip shapes with conflicting rectangularity 269 | # they are unlikely to be fastest 270 | blockG = warps_per_tb[0] > warps_per_tb[1] 271 | blockL = warps_per_tb[0] < warps_per_tb[1] 272 | warpG = warp_shape[0] > warp_shape[1] 273 | warpL = warp_shape[0] < warp_shape[1] 274 | 275 | blockG2 = warps_per_tb[0] > warps_per_tb[1]*2 276 | blockL2 = warps_per_tb[0] * \ 277 | 2 < warps_per_tb[1] 278 | warpG2 = warp_shape[0] > warp_shape[1]*2 279 | warpL2 = warp_shape[0]*2 < warp_shape[1] 280 | 281 | if blockG2 and warpL: 282 | continue 283 | if blockL2 and warpG: 284 | continue 285 | if warpG2 and blockL: 286 | continue 287 | if warpL2 and blockG: 288 | continue 289 | 290 | # check threadblock ratios and max 291 | threadblock_tile = [warp_shape[0]*warps_per_tb[0], 292 | warp_shape[1]*warps_per_tb[1]] 293 | if threadblock_tile[0] * threadblock_tile[1] > tb_max_elements: 294 | continue 295 | if threadblock_tile[0] > THREADBLOCK_MAX_EDGE: 296 | continue 297 | if threadblock_tile[1] > THREADBLOCK_MAX_EDGE: 298 | continue 299 | total_threads = THREADS_PER_WARP * \ 300 | warps_per_tb[0]*warps_per_tb[1] 301 | 302 | # calculate unroll 303 | # ensure that every iteration does at least one full load of A and B 304 | unroll_min0 = total_threads / threadblock_tile[0] 305 | unroll_min1 = total_threads / threadblock_tile[1] 306 | unroll = max(UNROLL_MIN, unroll_min0, unroll_min1) 307 | 308 | thread_tileM = warp_shape[0] / warp_threadsM 309 | thread_tileN = warp_shape[1] / warp_threadsN 310 | if thread_tileM < 2 or thread_tileN < 2: 311 | continue 312 | if thread_tileM * thread_tileN * precision_bits > 8 * 8 * 32: 313 | continue 314 | 315 | # epilogue currently does not support tile N < THREADS_PER_WARP 316 | if threadblock_tile[1] < THREADS_PER_WARP: 317 | continue 318 | 319 | # limit smem 320 | shmem_bitsA = threadblock_tile[0] * unroll * 2 * precision_bits 321 | shmem_bitsB = threadblock_tile[1] * unroll * 2 * precision_bits 322 | shmem_KiBs = ((shmem_bitsA + shmem_bitsB) / 8) / 1024 323 | if (shmem_KiBs > 48): 324 | continue 325 | 326 | test_level = -1 327 | for tileId in range(0, len(tb_tiles_L0)): 328 | tbTile = tb_tiles_L0[tileId] 329 | if tbTile[0] == threadblock_tile[0] and tbTile[1] == threadblock_tile[1]: 330 | if tuple(tbTile) not in seen_tb_tiles_L0: 331 | test_level = 0 332 | testcount_L0 += 1 333 | seen_tb_tiles_L0[tuple(tbTile)] = True 334 | 335 | # test level 1 336 | if test_level < 0: 337 | if tuple(threadblock_tile) not in seen_tb_tiles_L1: 338 | test_level = 1 339 | testcount_L1 += 1 340 | seen_tb_tiles_L1[tuple(threadblock_tile)] = True 341 | 342 | # test level 2 343 | if test_level < 0: 344 | test_level = 2 345 | testcount_L2 += 1 346 | 347 | # write this tile to file 348 | write_test_to_file( 349 | testfile, 350 | add_op, 351 | mult_op, 352 | precision_char, 353 | precision_type, 354 | transA, 355 | transB, 356 | transC, 357 | threadblock_tile, 358 | unroll, 359 | warp_shape, 360 | thread_tileM, 361 | thread_tileN, 362 | warp_threadsM, 363 | warp_threadsN, 364 | warps_per_tb, 365 | test_level) 366 | num_tests += 1 367 | print("Total test count per semi-ring = {}".format(num_tests//len(semiring_operators))) 368 | 369 | 370 | if __name__ == "__main__": 371 | main(".") 372 | -------------------------------------------------------------------------------- /test/device/testbed.h: -------------------------------------------------------------------------------- 1 | #include <fstream> 2 | #include <iostream> 3 | #include <sstream> 4 | 
#include "cutlass/util/distribution.h" 6 | #include "cutlass/util/host_tensor.h" 7 | #include "cutlass/util/reference/host/gemm.h" 8 | #include "cutlass/util/reference/host/tensor_compare.h" 9 | #include "cutlass/util/reference/host/tensor_copy.h" 10 | #include "cutlass/util/reference/host/tensor_fill.h" 11 | #include "cutlass/util/reference/host/tensor_norm.h" 12 | #include "cutlass/util/tensor_view_io.h" 13 | 14 | #include "cuasr/reference/srgemm/host_srgemm.h" 15 | 16 | #include 17 | 18 | namespace cuasr { 19 | namespace test { 20 | namespace gemm { 21 | namespace device { 22 | 23 | namespace { 24 | inline char const *to_string(cutlass::Status status) { 25 | switch (status) { 26 | case cutlass::Status::kSuccess: 27 | return "kSuccess"; 28 | case cutlass::Status::kErrorMisalignedOperand: 29 | return "kErrorMisalignedOperand"; 30 | case cutlass::Status::kErrorInvalidLayout: 31 | return "kErrorInvalidLayout"; 32 | case cutlass::Status::kErrorInvalidProblem: 33 | return "kErrorInvalidProblem"; 34 | case cutlass::Status::kErrorNotSupported: 35 | return "kErrorNotSupported"; 36 | case cutlass::Status::kErrorWorkspaceNull: 37 | return "kErrorWorkspaceNull"; 38 | case cutlass::Status::kErrorInternal: 39 | return "kErrorInternal"; 40 | case cutlass::Status::kInvalid: 41 | return "kInvalid"; 42 | default: 43 | break; 44 | } 45 | return "invalid"; 46 | } 47 | } 48 | 49 | // Given a SIMT SRGEMM, runs test cases against it 50 | template 51 | struct Testbed { 52 | using ElementAccumulator = typename Srgemm::ElementAccumulator; 53 | using ElementCompute = 54 | typename Srgemm::SrgemmKernel::Epilogue::OutputOp::ElementCompute; 55 | 56 | /// Initialization 57 | cutlass::Distribution::Kind init_A; 58 | cutlass::Distribution::Kind init_B; 59 | cutlass::Distribution::Kind init_C; 60 | uint64_t seed; 61 | 62 | cutlass::HostTensor tensor_A; 63 | cutlass::HostTensor tensor_B; 64 | cutlass::HostTensor tensor_C; 65 | cutlass::HostTensor tensor_D; 66 | cutlass::HostTensor reference_D; 67 | 68 | // Methods 69 | Testbed( 70 | cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, 71 | cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, 72 | cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, 73 | uint64_t seed_ = 2080) 74 | : init_A(init_A_) 75 | , init_B(init_B_) 76 | , init_C(init_C_) 77 | , seed(seed_) { } 78 | 79 | /// Helper to initialize a tensor view 80 | template 81 | bool initialize_tensor( 82 | cutlass::TensorView view, 83 | cutlass::Distribution::Kind dist_kind, 84 | uint64_t seed) { 85 | if (dist_kind == cutlass::Distribution::Uniform) { 86 | double scope_max, scope_min; 87 | int bits_input = cutlass::sizeof_bits::value; 88 | int bits_output = cutlass::sizeof_bits::value; 89 | 90 | if (bits_input == 1) { 91 | scope_max = 2; 92 | scope_min = 0; 93 | } 94 | else if (bits_input <= 8) { 95 | scope_max = 2; 96 | scope_min = -2; 97 | } 98 | else if (bits_output == 16) { 99 | scope_max = 5; 100 | scope_min = -5; 101 | } 102 | else { 103 | scope_max = 8; 104 | scope_min = -8; 105 | } 106 | 107 | cutlass::reference::host::TensorFillRandomUniform( 108 | view, seed, scope_max, scope_min, 0); 109 | } 110 | else if (dist_kind == cutlass::Distribution::Identity) { 111 | cutlass::reference::host::TensorFillIdentity(view); 112 | } 113 | else if (dist_kind == cutlass::Distribution::Gaussian) { 114 | cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); 115 | } 116 | else if (dist_kind == cutlass::Distribution::Sequential) { 117 | 
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); 118 | } 119 | else { 120 | EXPECT_TRUE(false) << "Not implemented"; 121 | return false; 122 | } 123 | 124 | return true; 125 | } 126 | 127 | /// Initializes data structures 128 | void initialize(cutlass::gemm::GemmCoord problem_size) { 129 | // Allocate the GEMM workspace 130 | tensor_A.resize(problem_size.mk()); 131 | tensor_B.resize(problem_size.kn()); 132 | tensor_C.resize(problem_size.mn()); 133 | tensor_D.resize(problem_size.mn()); 134 | reference_D.resize(problem_size.mn(), false); 135 | 136 | EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); 137 | EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); 138 | EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); 139 | 140 | // It is possible to randomly initialize to all zeros, so override this with non-zeros 141 | // in the upper left corner of each operand. 142 | tensor_A.host_view().at({ 0, 0 }) = typename Srgemm::ElementA(1); 143 | tensor_B.host_view().at({ 0, 0 }) = typename Srgemm::ElementB(1); 144 | tensor_C.host_view().at({ 0, 0 }) = typename Srgemm::ElementC(1); 145 | 146 | cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); 147 | 148 | tensor_A.sync_device(); 149 | tensor_B.sync_device(); 150 | tensor_C.sync_device(); 151 | tensor_D.sync_device(); 152 | } 153 | 154 | /// Compares the device output against the host reference and dumps all tensors to a file on mismatch 155 | bool compare_reference( 156 | cutlass::gemm::GemmCoord problem_size, ElementCompute alpha, ElementCompute beta) { 157 | tensor_D.sync_host(); 158 | 159 | bool passed = cutlass::reference::host::TensorEquals( 160 | reference_D.host_view(), tensor_D.host_view()); 161 | EXPECT_TRUE(passed); 162 | 163 | if (!passed) { 164 | // record failed test cases to a file for debug records 165 | std::string add_op_name_full(abi::__cxa_demangle( 166 | typeid(typename Srgemm::AdditionOp).name(), // 167 | nullptr, nullptr, nullptr)); 168 | 169 | std::string mult_op_name_full(abi::__cxa_demangle( 170 | typeid(typename Srgemm::MultiplicationOp).name(), // 171 | nullptr, nullptr, nullptr)); 172 | 173 | std::string add_op_name( 174 | add_op_name_full.substr(0, add_op_name_full.find_first_of('<'))); 175 | std::string mult_op_name( 176 | mult_op_name_full.substr(0, mult_op_name_full.find_first_of('<'))); 177 | 178 | std::stringstream fname; 179 | fname << "error_Srgemm_device_" << problem_size.m() << 'x' << problem_size.n() 180 | << 'x' << problem_size.k() << '_' << add_op_name << '_' << mult_op_name << '_' 181 | << Srgemm::ThreadblockShape::kM << 'x' << Srgemm::ThreadblockShape::kN << 'x' 182 | << Srgemm::ThreadblockShape::kK << '_' << Srgemm::WarpShape::kM << 'x' 183 | << Srgemm::WarpShape::kN << 'x' << Srgemm::WarpShape::kK << ".txt"; 184 | 185 | std::ofstream file(fname.str()); 186 | file << "problem: " << problem_size << ", alpha: " << alpha << ", beta: " << beta 187 | << "\n\n"; 188 | 189 | file << "Addition operator: " << add_op_name_full << '\n'; 190 | file << "Multiplication operator: " << mult_op_name_full << '\n'; 191 | 192 | file << "A =\n" 193 | << tensor_A.host_view() << "\nB =\n" 194 | << tensor_B.host_view() << "\nC =\n" 195 | << tensor_C.host_view() << "\n\nReference =\n" 196 | << reference_D.host_view() << "\nComputed =\n" 197 | << tensor_D.host_view(); 198 | } 199 | 200 | return passed; 201 | } 202 | 203 | /// Verifies the device result against the host reference SRGEMM 204 | bool verify( 205 | cutlass::gemm::GemmCoord 
problem_size, ElementCompute alpha, ElementCompute beta) { 206 | cuasr::reference::host::Srgemm< 207 | typename Srgemm::AdditionOp, // 208 | typename Srgemm::MultiplicationOp, // 209 | typename Srgemm::ElementA, typename Srgemm::LayoutA, // 210 | typename Srgemm::ElementB, typename Srgemm::LayoutB, // 211 | typename Srgemm::ElementC, typename Srgemm::LayoutC, // 212 | typename Srgemm::EpilogueOutputOp::ElementCompute, // 213 | typename Srgemm::EpilogueOutputOp::ElementAccumulator, // 214 | typename Srgemm::EpilogueOutputOp> 215 | reference_srgemm; 216 | 217 | reference_srgemm( 218 | problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), // 219 | beta, tensor_C.host_ref(), reference_D.host_ref(), // 220 | Srgemm::AdditionOp::Identity); 221 | 222 | return compare_reference(problem_size, alpha, beta); 223 | } 224 | 225 | // Executes one test 226 | bool 227 | run(cutlass::gemm::GemmCoord problem_size, 228 | int split_k_slices = 1, 229 | ElementCompute alpha = ElementCompute(Srgemm::MultiplicationOp::Identity), 230 | ElementCompute beta = ElementCompute(Srgemm::MultiplicationOp::Identity)) { 231 | this->initialize(problem_size); 232 | 233 | // Initialize the GEMM operator 234 | typename Srgemm::Arguments arguments { 235 | problem_size, // 236 | tensor_A.device_ref(), // 237 | tensor_B.device_ref(), // 238 | tensor_C.device_ref(), // 239 | tensor_D.device_ref(), // 240 | { alpha, beta }, // 241 | split_k_slices // 242 | }; 243 | 244 | Srgemm gemm_op; 245 | size_t workspace_size = Srgemm::get_workspace_size(arguments); 246 | cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); 247 | cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); 248 | EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); 249 | 250 | // Run the GEMM 251 | status = gemm_op(); 252 | EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); 253 | 254 | // Verify 255 | bool passed = this->verify(problem_size, alpha, beta); 256 | if (!passed) { 257 | std::cout << "Error with split_k_slices = " << split_k_slices 258 | << ", alpha: " << alpha << " and beta: " << beta << std::endl; 259 | } 260 | 261 | return passed; 262 | } 263 | }; 264 | 265 | ///////////////////////////////////////////////////////////////////////////////// 266 | 267 | // Wrapper that runs different problem sizes and input combinations 268 | template <typename Srgemm> 269 | bool TestAllGemm() { 270 | bool passed = true; 271 | 272 | int const kMinimumOperandElementSize = std::min( 273 | int(cutlass::sizeof_bits<typename Srgemm::ElementA>::value), 274 | int(cutlass::sizeof_bits<typename Srgemm::ElementB>::value)); 275 | 276 | int const kAlignment 277 | = cutlass::platform::is_same< 278 | typename Srgemm::OperatorClass, cutlass::arch::OpClassSimt>::value 279 | ? 1 280 | : 128 / kMinimumOperandElementSize; 281 | 282 | // int8_t gemm alignment constraints 283 | int const kAlignmentM 284 | = cutlass::platform::is_same< 285 | typename Srgemm::OperatorClass, cutlass::arch::OpClassSimt>::value 286 | && cutlass::platform::is_same<typename Srgemm::ElementA, int8_t>::value 287 | && cutlass::platform::is_same< 288 | typename Srgemm::LayoutA, cutlass::layout::ColumnMajor>::value 289 | ? 4 290 | : kAlignment; 291 | 292 | int const kAlignmentN 293 | = cutlass::platform::is_same< 294 | typename Srgemm::OperatorClass, cutlass::arch::OpClassSimt>::value 295 | && cutlass::platform::is_same<typename Srgemm::ElementB, int8_t>::value 296 | && cutlass::platform::is_same< 297 | typename Srgemm::LayoutB, cutlass::layout::RowMajor>::value 298 | ? 
4 299 | : kAlignment; 300 | 301 | int const kAlignmentK 302 | = cutlass::platform::is_same< 303 | typename Srgemm::OperatorClass, cutlass::arch::OpClassSimt>::value 304 | && cutlass::platform::is_same<typename Srgemm::ElementA, int8_t>::value 305 | && cutlass::platform::is_same<typename Srgemm::ElementB, int8_t>::value 306 | && (cutlass::platform::is_same< 307 | typename Srgemm::LayoutA, cutlass::layout::RowMajor>::value 308 | || cutlass::platform::is_same< 309 | typename Srgemm::LayoutB, cutlass::layout::ColumnMajor>::value) 310 | ? 4 311 | : kAlignment; 312 | 313 | int problem_size_m[] = { kAlignmentM, 512 - 3 * kAlignmentM }; 314 | 315 | int problem_size_n[] = { kAlignmentN, 512 - 2 * kAlignmentN }; 316 | 317 | int problem_size_k[] 318 | = { kAlignmentK, 319 | Srgemm::ThreadblockShape::kK * (Srgemm::kStages + 1) - kAlignmentK }; 320 | 321 | // TODO: add split-K SRGEMM 322 | int split_k_slices[] = { 1, 2, 3, 8 }; 323 | 324 | double problem_alpha[] = { Srgemm::MultiplicationOp::Identity }; 325 | double problem_beta[] = { Srgemm::MultiplicationOp::Annihilator }; 326 | 327 | Testbed<Srgemm> testbed; 328 | using ElementCompute = typename Srgemm::EpilogueOutputOp::ElementCompute; 329 | 330 | for (int m : problem_size_m) { 331 | for (int n : problem_size_n) { 332 | for (int k : problem_size_k) { 333 | for (int split_k : split_k_slices) { 334 | if (!Srgemm::kSplitKSerial && split_k > 1) { 335 | continue; 336 | } 337 | 338 | if (split_k > 1 && k / Srgemm::ThreadblockShape::kK < split_k) { 339 | continue; 340 | } 341 | 342 | for (auto alpha : problem_alpha) { 343 | for (auto beta : problem_beta) { 344 | cutlass::gemm::GemmCoord problem_size(m, n, k); 345 | 346 | passed = testbed.run( 347 | problem_size, split_k, cutlass::from_real<ElementCompute>(alpha), 348 | cutlass::from_real<ElementCompute>(beta)); 349 | 350 | if (!passed) { 351 | return false; 352 | } 353 | } 354 | } 355 | } 356 | } 357 | } 358 | } 359 | 360 | return passed; 361 | } 362 | 363 | } // namespace device 364 | } // namespace gemm 365 | } // namespace test 366 | } // namespace cuasr 367 | -------------------------------------------------------------------------------- /test/harness.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | 3 | auto main(int argc, char **argv) -> int { 4 | ::testing::InitGoogleTest(&argc, argv); 5 | return RUN_ALL_TESTS(); 6 | } 7 | -------------------------------------------------------------------------------- /test/regress/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # cuasr library configuration 2 | add_library(deprecated_libfwgpu ${cuASR_LIB_TYPE} 3 | ./src/cutlass_srgemm.cu 4 | ./src/utils.cu 5 | ) 6 | target_include_directories(deprecated_libfwgpu 7 | PUBLIC ${PROJECT_SOURCE_DIR}/include ${CUDA_INCLUDE_DIRS} 8 | PRIVATE ${PROJECT_SOURCE_DIR}/cutlass/include 9 | PRIVATE ${PROJECT_SOURCE_DIR}/test/regress/include 10 | ) 11 | target_compile_options(deprecated_libfwgpu 12 | PUBLIC 13 | # C++ compiler flags 14 | $<$,$>: 15 | ${cuASR_CXX_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}> 16 | 17 | # CUDA compiler flags 18 | $<$,$>: 19 | ${cuASR_CUDA_FLAGS_${uppercase_CMAKE_BUILD_TYPE}}> 20 | ) 21 | 22 | ### Matrix tests 23 | add_executable(Matrix_tests ${PROJECT_SOURCE_DIR}/test/harness.cpp Matrix_test.cpp) 24 | target_include_directories(Matrix_tests 25 | PRIVATE gtest/googletest/include 26 | PRIVATE ${PROJECT_SOURCE_DIR}/test/regress/include) 27 | target_link_libraries(Matrix_tests gtest deprecated_libfwgpu) 28 | add_test( 29 | NAME Matrix_tests 30 | COMMAND Matrix_tests 31 | ) 32 | 33 | ### SemiRing 
GEMM tests 34 | add_executable(tropical_gemm_tests ${PROJECT_SOURCE_DIR}/test/harness.cpp Srgemm_test.cu) 35 | target_include_directories(tropical_gemm_tests 36 | PRIVATE gtest/googletest/include 37 | PRIVATE ${PROJECT_SOURCE_DIR}/test/regress/include) 38 | target_link_libraries(tropical_gemm_tests gtest deprecated_libfwgpu) 39 | add_test( 40 | NAME tropical_gemm_tests 41 | COMMAND tropical_gemm_tests 42 | ) 43 | -------------------------------------------------------------------------------- /test/regress/Matrix_test.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | 3 | #include "fwgpu/Matrix.hpp" 4 | 5 | TEST(cuASR_Matrix, BasicConstructorCorrect) { 6 | auto x = fwgpu::Matrix<float>(6, 2); 7 | for (auto i = 0u; i < 12; ++i) { 8 | x(i) = (float)i; 9 | } 10 | 11 | EXPECT_EQ(size_t { 12 }, x.size()); 12 | EXPECT_EQ(size_t { 12 * sizeof(float) }, x.bytesize()); 13 | EXPECT_EQ(size_t { 6 }, x.num_rows()); 14 | EXPECT_EQ(size_t { 2 }, x.num_cols()); 15 | EXPECT_FLOAT_EQ(10.0f, x(10)); 16 | EXPECT_FLOAT_EQ(0.0f, x(0, 0)); 17 | EXPECT_FLOAT_EQ(8.0f, x(2, 1)); 18 | EXPECT_FLOAT_EQ(11.0f, x(5, 1)); 19 | } 20 | 21 | TEST(cuASR_Matrix, InitializerListConstructorCorrect) { 22 | // [8.0 3.0 0.0 1.0] 23 | // [2.0 5.0 4.0 9.0] 24 | // [7.0 6.0 10. 13.] 25 | auto x = fwgpu::Matrix<float>( 26 | 3, 4, { 8.0, 2.0, 7.0, 3.0, 5.0, 6.0, 0.0, 4.0, 10.0, 1.0, 9.0, 13.0 }); 27 | 28 | EXPECT_EQ(size_t { 12 }, x.size()); 29 | EXPECT_EQ(size_t { 12 * sizeof(float) }, x.bytesize()); 30 | EXPECT_EQ(size_t { 3 }, x.num_rows()); 31 | EXPECT_EQ(size_t { 4 }, x.num_cols()); 32 | EXPECT_FLOAT_EQ(8.0f, x(0, 0)); 33 | EXPECT_FLOAT_EQ(1.0f, x(0, 3)); 34 | EXPECT_FLOAT_EQ(10.0f, x(2, 2)); 35 | } 36 | 37 | TEST(cuASR_Matrix, RandomFloatMatrixConstructorCorrect) { 38 | size_t const seed = 8; 39 | auto const minimum = 1.0545; 40 | auto const maximum = 28.1; 41 | auto x = fwgpu::Matrix<double>(9, 8, seed, minimum, maximum); 42 | 43 | EXPECT_EQ(size_t { 9 * 8 }, x.size()); 44 | EXPECT_EQ(size_t { 9 * 8 * sizeof(double) }, x.bytesize()); 45 | EXPECT_EQ(size_t { 9 }, x.num_rows()); 46 | EXPECT_EQ(size_t { 8 }, x.num_cols()); 47 | 48 | for (auto i = 0u; i < x.size(); ++i) { 49 | double const val = x(i); 50 | EXPECT_TRUE((val >= minimum && val <= maximum)); 51 | } 52 | } 53 | 54 | TEST(cuASR_Matrix, RandomIntMatrixConstructorCorrect) { 55 | size_t const seed = 8; 56 | auto const minimum = 1; 57 | auto const maximum = 128; 58 | auto x = fwgpu::Matrix<int>(7, 5, seed, minimum, maximum); 59 | 60 | EXPECT_EQ(size_t { 7 * 5 }, x.size()); 61 | EXPECT_EQ(size_t { 7 * 5 * sizeof(int) }, x.bytesize()); 62 | EXPECT_EQ(size_t { 7 }, x.num_rows()); 63 | EXPECT_EQ(size_t { 5 }, x.num_cols()); 64 | 65 | for (auto i = 0u; i < x.size(); ++i) { 66 | int const val = x(i); 67 | EXPECT_TRUE((val >= minimum && val <= maximum)); 68 | } 69 | } 70 | 71 | TEST(cuASR_Matrix, CopyConstructorCorrect) { 72 | auto from = fwgpu::Matrix<float>(5, 7, 0.0f); 73 | auto to = from; 74 | EXPECT_TRUE(from == to); 75 | } 76 | 77 | TEST(cuASR_Matrix, MoveConstructorCorrect) { 78 | auto from = fwgpu::Matrix<float>(5, 7, 0.0f); 79 | EXPECT_EQ(from.num_rows(), 5); 80 | EXPECT_EQ(from.num_cols(), 7); 81 | EXPECT_TRUE(from.get_buf() != nullptr); 82 | 83 | auto to = fwgpu::Matrix<float>(std::move(from)); 84 | EXPECT_EQ(to.num_rows(), 5); 85 | EXPECT_EQ(to.num_cols(), 7); 86 | EXPECT_TRUE(to.get_buf() != nullptr); 87 | 88 | EXPECT_EQ(from.num_rows(), 0); 89 | EXPECT_EQ(from.num_cols(), 0); 90 | EXPECT_TRUE(from.get_buf() == nullptr); 91 | } 92 | 
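// A hedged addition (editor's sketch, not part of the original suite): exercises both
// failure paths of the element-wise equality operator declared in fwgpu/Matrix.hpp,
// i.e. a single-element mismatch and a dimension mismatch.
TEST(cuASR_Matrix, EqualityChecksDimsThenElements) {
  auto a = fwgpu::Matrix<float>(2, 2, { 1.0f, 2.0f, 3.0f, 4.0f });
  auto b = a;
  EXPECT_TRUE(a == b);

  b(1, 1) = 5.0f; // perturb one element
  EXPECT_TRUE(a != b);

  auto c = fwgpu::Matrix<float>(2, 3, 0.0f); // same type, different shape
  EXPECT_TRUE(a != c);
}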
93 | TEST(cuASR_Matrix, ConstantConstructorCorrect) { 94 | auto mat = fwgpu::Matrix<int>(6, 2, 42); 95 | EXPECT_EQ(mat.num_rows(), 6); 96 | EXPECT_EQ(mat.num_cols(), 2); 97 | EXPECT_EQ(mat(0, 0), 42); 98 | EXPECT_EQ(mat(3, 0), 42); 99 | EXPECT_EQ(mat(3, 1), 42); 100 | EXPECT_EQ(mat(5, 1), 42); 101 | } 102 | 103 | TEST(cuASR_Matrix, BufferConstructorCorrect) { 104 | std::vector<int> matvals(12, 42); 105 | auto mat = fwgpu::Matrix<int>(6, 2, matvals.data()); 106 | 107 | EXPECT_EQ(mat.num_rows(), 6); 108 | EXPECT_EQ(mat.num_cols(), 2); 109 | EXPECT_EQ(mat(0, 0), 42); 110 | EXPECT_EQ(mat(3, 0), 42); 111 | EXPECT_EQ(mat(3, 1), 42); 112 | EXPECT_EQ(mat(5, 1), 42); 113 | } 114 | 115 | TEST(cuASR_Matrix, CopyAssignmentCorrect) { 116 | auto from = fwgpu::Matrix<float>(5, 7, 0.0f); 117 | fwgpu::Matrix<float> to(1, 1); 118 | to = from; 119 | EXPECT_TRUE(from == to); 120 | } 121 | 122 | TEST(cuASR_Matrix, ColumnMajorLayoutCorrect) { 123 | auto mat = fwgpu::Matrix<float>( 124 | 4, 4, 125 | { 126 | 0.840187728, 0.911647379, 0.277774721, 0.364784479, // 127 | 0.394382924, 0.197551370, 0.553969979, 0.513400912, // 128 | 0.729605675, 0.335222751, 0.477397054, 0.952229738, // 129 | 0.798440039, 0.768229604, 0.628870904, 0.916195095 // 130 | }); 131 | 132 | // corners 133 | EXPECT_FLOAT_EQ(mat(0, 0), 0.840187728); 134 | EXPECT_FLOAT_EQ(mat(0, 3), 0.798440039); 135 | EXPECT_FLOAT_EQ(mat(3, 0), 0.364784479); 136 | EXPECT_FLOAT_EQ(mat(3, 3), 0.916195095); 137 | 138 | // middle 2x2 139 | EXPECT_FLOAT_EQ(mat(1, 1), 0.197551370); 140 | EXPECT_FLOAT_EQ(mat(1, 2), 0.335222751); 141 | EXPECT_FLOAT_EQ(mat(2, 1), 0.553969979); 142 | EXPECT_FLOAT_EQ(mat(2, 2), 0.477397054); 143 | } 144 | 145 | TEST(cuASR_Matrix, RowMajorLayoutCorrect) { 146 | auto mat = fwgpu::Matrix<float, fwgpu::RowMajor>( 147 | 4, 4, 148 | { 149 | 0.840187728, 0.911647379, 0.277774721, 0.364784479, // 150 | 0.394382924, 0.197551370, 0.553969979, 0.513400912, // 151 | 0.729605675, 0.335222751, 0.477397054, 0.952229738, // 152 | 0.798440039, 0.768229604, 0.628870904, 0.916195095 // 153 | }); 154 | 155 | // corners 156 | EXPECT_FLOAT_EQ(mat(0, 0), 0.840187728); 157 | EXPECT_FLOAT_EQ(mat(0, 3), 0.364784479); 158 | EXPECT_FLOAT_EQ(mat(3, 0), 0.798440039); 159 | EXPECT_FLOAT_EQ(mat(3, 3), 0.916195095); 160 | 161 | // middle 2x2 162 | EXPECT_FLOAT_EQ(mat(1, 1), 0.197551370); 163 | EXPECT_FLOAT_EQ(mat(1, 2), 0.553969979); 164 | EXPECT_FLOAT_EQ(mat(2, 1), 0.335222751); 165 | EXPECT_FLOAT_EQ(mat(2, 2), 0.477397054); 166 | } 167 | -------------------------------------------------------------------------------- /test/regress/include/fwgpu/Matrix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstddef> 4 | #include <initializer_list> 5 | #include <ostream> 6 | #include <random> 7 | #include <type_traits> 8 | #include <vector> 9 | 10 | namespace fwgpu { 11 | 12 | struct ColumnMajor { 13 | size_t m_rows; 14 | size_t m_cols; 15 | 16 | ColumnMajor() = delete; 17 | ColumnMajor(size_t rows, size_t cols) 18 | : m_rows(rows) 19 | , m_cols(cols) {}; 20 | 21 | auto linearize(size_t row_idx, size_t col_idx) const noexcept -> size_t { 22 | return row_idx + (m_rows * col_idx); 23 | } 24 | }; 25 | 26 | struct RowMajor { 27 | size_t m_rows; 28 | size_t m_cols; 29 | 30 | RowMajor() = delete; 31 | RowMajor(size_t rows, size_t cols) 32 | : m_rows(rows) 33 | , m_cols(cols) {}; 34 | 35 | auto linearize(size_t row_idx, size_t col_idx) const noexcept -> size_t { 36 | return (row_idx * m_cols) + col_idx; 37 | } 38 | }; 39 | 40 | /* 41 | * Matrix data structure for a tightly packed 2D array. 42 | * ElementT = float and column major by default. 
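 * The Layout parameter (ColumnMajor above, or RowMajor) is the policy that maps
 * a (row, col) pair to a flat index in the backing buffer.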
43 | **/ 44 | template <typename ElementT = float, typename Layout = ColumnMajor> 45 | class Matrix { 46 | private: 47 | Layout m_layout; 48 | std::vector<ElementT> m_host_buf; 49 | 50 | public: 51 | /* 52 | * No default constructor. 53 | **/ 54 | Matrix() = delete; 55 | 56 | /* 57 | * Default destructor. 58 | **/ 59 | ~Matrix() = default; 60 | 61 | /* 62 | * De-facto default constructor: allocate ElementT buffer of size rows*cols 63 | **/ 64 | Matrix(size_t rows, size_t cols) 65 | : m_layout(rows, cols) 66 | , m_host_buf(rows * cols) { } 67 | 68 | /* 69 | * Assign buf from external source. 70 | * TODO: not sure we should allow this? 71 | **/ 72 | Matrix(size_t rows, size_t cols, ElementT *buf) 73 | : m_layout(rows, cols) 74 | , m_host_buf(buf, buf + (rows * cols)) { } 75 | 76 | /* 77 | * Allocates and initializes the matrix with input value. 78 | **/ 79 | Matrix(size_t rows, size_t cols, ElementT val) 80 | : m_layout(rows, cols) 81 | , m_host_buf(rows * cols, val) { } 82 | 83 | /* 84 | * Random Fill Constructor: allocates and initializes the matrix 85 | * with random numbers in the input range. 86 | **/ 87 | Matrix( 88 | size_t rows, 89 | size_t cols, 90 | size_t seed, 91 | ElementT min = ElementT(0.0), 92 | ElementT max = ElementT(1.0)) 93 | : m_layout(rows, cols) 94 | , m_host_buf(rows * cols) { 95 | using Distribution = std::conditional_t< 96 | std::is_integral<ElementT>::value, // if ElementT is integral 97 | std::uniform_int_distribution<ElementT>, // use int dist 98 | std::uniform_real_distribution<ElementT> // otherwise floating point dist 99 | >; 100 | auto rng = std::mt19937_64(seed); 101 | auto dist = Distribution(min, max); 102 | for (auto i = 0ull; i < (rows * cols); ++i) { 103 | m_host_buf[i] = dist(rng); 104 | } 105 | } 106 | 107 | /* 108 | * Allocates and initializes the matrix from an initializer list. 109 | * This mainly makes testing easier. 110 | **/ 111 | Matrix(size_t rows, size_t cols, const std::initializer_list<ElementT> &elements) 112 | : m_layout(rows, cols) 113 | , m_host_buf(rows * cols) { 114 | auto i = 0ull; 115 | for (auto val : elements) { 116 | m_host_buf[i++] = val; 117 | } 118 | } 119 | 120 | /* 121 | * Copy constructor: deep copy other 122 | **/ 123 | Matrix(const Matrix &other) 124 | : m_layout(other.m_layout) 125 | , m_host_buf(other.m_host_buf) { } 126 | 127 | /* 128 | * Move constructor: sink other into this 129 | */ 130 | Matrix(Matrix &&other) 131 | : m_layout(other.m_layout) 132 | , m_host_buf(std::move(other.m_host_buf)) { 133 | other.m_layout.m_rows = 0; 134 | other.m_layout.m_cols = 0; 135 | } 136 | 137 | /* 138 | * Copy assignment operator. 139 | **/ 140 | auto operator=(const Matrix &other) -> Matrix & { 141 | m_layout = other.m_layout; 142 | m_host_buf = other.m_host_buf; 143 | return *this; 144 | } 145 | 146 | /* 147 | * Returns a non-owning, const pointer to the backing buffer of type ElementT[]. 148 | **/ 149 | auto get_buf() const noexcept -> const ElementT * { return m_host_buf.data(); } 150 | 151 | /* 152 | * Returns a non-owning pointer to the backing buffer of type ElementT[]. 153 | **/ 154 | auto get_buf() noexcept -> ElementT * { return m_host_buf.data(); } 155 | 156 | /* 157 | * Returns total number of elements stored in the matrix. 158 | **/ 159 | auto size() const noexcept -> size_t { return m_layout.m_rows * m_layout.m_cols; } 160 | 161 | /* 162 | * Returns total number of bytes occupied by the backing store ElementT[]. 163 | **/ 164 | auto bytesize() const noexcept -> size_t { return size() * sizeof(ElementT); } 165 | 166 | /* 167 | * Returns true if matrix has (0, 0) dimensions. False otherwise. 
168 | **/ 169 | auto is_empty() const noexcept -> bool { 170 | return (m_layout.m_rows == 0) || (m_layout.m_cols == 0); 171 | } 172 | 173 | /* 174 | * Returns number of rows in the matrix. 175 | **/ 176 | auto num_rows() const noexcept -> size_t { return m_layout.m_rows; } 177 | 178 | /* 179 | * Returns number of columns in the matrix. 180 | **/ 181 | auto num_cols() const noexcept -> size_t { return m_layout.m_cols; } 182 | 183 | /* 184 | * Linear index into the flat buffer. 185 | **/ 186 | auto operator[](size_t idx) -> ElementT & { return m_host_buf[idx]; } 187 | auto operator[](size_t idx) const -> ElementT const & { return m_host_buf[idx]; } 188 | 189 | /* 190 | * Linear index into the flat buffer. 191 | **/ 192 | auto operator()(size_t idx) -> ElementT & { return m_host_buf[idx]; } 193 | auto operator()(size_t idx) const -> ElementT const & { return m_host_buf[idx]; } 194 | 195 | /* 196 | * Matrix index with major-dimension offset. 197 | * The Layout policy parameter (ColumnMajor or RowMajor) decides the 198 | * linearization. 199 | */ 200 | auto operator()(size_t row_idx, size_t col_idx) -> ElementT & { 201 | return m_host_buf[m_layout.linearize(row_idx, col_idx)]; 202 | } 203 | 204 | auto operator()(size_t row_idx, size_t col_idx) const -> ElementT const & { 205 | return m_host_buf[m_layout.linearize(row_idx, col_idx)]; 206 | } 207 | }; 208 | 209 | // Element-wise equality test for two matrices of the same template type. 210 | template <typename ElementT, typename Layout> 211 | inline auto operator==(const Matrix<ElementT, Layout> &lhs, const Matrix<ElementT, Layout> &rhs) -> bool { 212 | // both dims must match first 213 | if ((lhs.num_rows() != rhs.num_rows()) || (lhs.num_cols() != rhs.num_cols())) { 214 | return false; 215 | } 216 | 217 | for (auto i = 0ull; i < lhs.size(); ++i) { 218 | if (lhs[i] != rhs[i]) { 219 | return false; 220 | } 221 | } 222 | 223 | return true; 224 | } 225 | 226 | // Element-wise inequality test for two matrices of the same template type. 227 | template <typename ElementT, typename Layout> 228 | inline auto operator!=(const Matrix<ElementT, Layout> &lhs, const Matrix<ElementT, Layout> &rhs) -> bool { 229 | return !(lhs == rhs); 230 | } 231 | 232 | // Prints matrix to stdout; prefer using this only for small matrices. 
233 | template <typename ElementT, typename Layout> 234 | inline auto operator<<(std::ostream &os, const Matrix<ElementT, Layout> &mat) -> std::ostream & { 235 | for (auto row_idx = 0ull; row_idx < mat.num_rows(); ++row_idx) { 236 | os << '[' << mat(row_idx, 0); 237 | 238 | for (auto col_idx = 1ull; col_idx < mat.num_cols(); ++col_idx) { 239 | os << ", " << mat(row_idx, col_idx); 240 | } 241 | 242 | os << "]\n"; 243 | } 244 | 245 | return os; 246 | } 247 | 248 | } // namespace fwgpu 249 | -------------------------------------------------------------------------------- /test/regress/include/fwgpu/cpu_srgemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <algorithm> // std::min 3 | namespace fwgpu { 4 | 5 | template <typename T> 6 | inline auto cpu_srgemm_naive( 7 | int m, int n, int k, const T *A, int lda, const T *B, int ldb, T *C, int ldc) 8 | -> void { 9 | for (int row = 0; row < m; ++row) { 10 | for (int col = 0; col < n; ++col) { 11 | T mindist = C[row + (col * ldc)]; 12 | for (int i = 0; i < k; ++i) { 13 | mindist = std::min(mindist, A[row + (i * lda)] + B[i + (col * ldb)]); 14 | } 15 | C[row + (col * ldc)] = mindist; 16 | } 17 | } 18 | } 19 | 20 | template <typename TData, typename TIdx> 21 | inline auto cpu_fwgemm_naive( 22 | int m, 23 | int n, 24 | int k, 25 | const TData *A, 26 | int lda, 27 | const TData *B, 28 | int ldb, 29 | TData *dist, 30 | int ldc, 31 | TIdx *parent) -> void { 32 | for (int row = 0; row < m; ++row) { 33 | for (int col = 0; col < n; ++col) { 34 | // dist and parent for this vertex pair (row, col) 35 | TData curr_dist = dist[row + (col * ldc)]; 36 | TIdx curr_parent = parent[row + (col * ldc)]; 37 | for (int kk = 0; kk < k; ++kk) { 38 | TData prod = A[row + (kk * lda)] + B[kk + (col * ldb)]; 39 | if (prod < curr_dist) { 40 | curr_dist = prod; 41 | curr_parent = kk; 42 | } 43 | } 44 | dist[row + (col * ldc)] = curr_dist; 45 | parent[row + (col * ldc)] = curr_parent; 46 | } 47 | } 48 | } 49 | 50 | } // namespace fwgpu 51 | -------------------------------------------------------------------------------- /test/regress/include/fwgpu/gpu_srgemm.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cstddef> 4 | 5 | namespace fwgpu { 6 | 7 | template <typename T> 8 | __global__ auto 9 | gpu_srgemm_naive(int m, int n, int k, T *A, int lda, T *B, int ldb, T *dist, int ldc) 10 | -> void { 11 | size_t ty = blockIdx.y * blockDim.y + threadIdx.y; 12 | size_t tx = blockIdx.x * blockDim.x + threadIdx.x; 13 | 14 | size_t n_idx = ty; 15 | while (n_idx < n) { 16 | size_t m_idx = tx; 17 | while (m_idx < m) { 18 | // initialize current minimum distance 19 | T mindist = dist[(n_idx * ldc) + m_idx]; 20 | for (size_t k_idx = 0; k_idx < k; ++k_idx) { 21 | // calculate the distance between n_idx->m_idx by going through k_idx 22 | T thisone = A[(k_idx * lda) + m_idx] + B[(n_idx * ldb) + k_idx]; 23 | if (thisone < mindist) { 24 | mindist = thisone; 25 | } 26 | } 27 | // finally, store new min distance to dist matrix 28 | dist[(n_idx * ldc) + m_idx] = mindist; 29 | m_idx += gridDim.x * blockDim.x; 30 | } 31 | n_idx += gridDim.y * blockDim.y; 32 | } 33 | } 34 | 35 | } // namespace fwgpu 36 | -------------------------------------------------------------------------------- /test/regress/include/fwgpu/gpu_srgemm.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace fwgpu { 4 | 5 | // Cutlass semiring GEMM over the {min, +} (tropical) semiring 6 | auto cutlass_srsgemm_nn( 7 | int M, 8 | int N, 9 | int K, 10 | float const *A, 11 | int lda, 12 | float 
const *B, 13 | int ldb, 14 | float *C, 15 | int ldc, 16 | float *D, 17 | bool do_epilogue_min = true, 18 | void *stream = nullptr) -> int; 19 | 20 | // Same {min, +} SRGEMM; overload that accumulates in place into C 21 | auto cutlass_srsgemm_nn( 22 | int M, 23 | int N, 24 | int K, 25 | float const *A, 26 | int lda, 27 | float const *B, 28 | int ldb, 29 | float *C, 30 | int ldc, 31 | bool do_epilogue_min = true, 32 | void *stream = nullptr) -> int; 33 | 34 | } // namespace fwgpu 35 | -------------------------------------------------------------------------------- /test/regress/include/fwgpu/utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <cstddef> // size_t 3 | namespace fwgpu { 4 | 5 | // device memory allocation of size bytes 6 | auto malloc_device(void **dptr, size_t size) -> int; 7 | 8 | // unified managed memory allocation of size bytes 9 | auto malloc_unified(void **dptr, size_t size) -> int; 10 | 11 | // free cuda allocated memory, managed or unmanaged 12 | auto free_device(void *dptr) -> int; 13 | 14 | // MEMCPY API 15 | // memory copy: device -> host 16 | auto memcpy_d2h(void *dest, const void *src, size_t size) -> int; 17 | 18 | // memory copy: host -> device 19 | auto memcpy_h2d(void *dest, const void *src, size_t size) -> int; 20 | 21 | // memory copy: host -> host 22 | auto memcpy_h2h(void *dest, const void *src, size_t size) -> int; 23 | 24 | // memory copy: device -> device 25 | auto memcpy_d2d(void *dest, const void *src, size_t size) -> int; 26 | 27 | // memory copy: direction inferred based on src and dest. Requires unified memory. 28 | auto memcpy_inferred(void *dest, const void *src, size_t size) -> int; 29 | 30 | auto memcpy_2d_h2d( 31 | void *dest, 32 | size_t dpitch, 33 | const void *src, 34 | size_t spitch, 35 | size_t width, 36 | size_t height) -> int; 37 | auto memcpy_2d_d2h( 38 | void *dest, 39 | size_t dpitch, 40 | const void *src, 41 | size_t spitch, 42 | size_t width, 43 | size_t height) -> int; 44 | auto memcpy_2d_d2d( 45 | void *dest, 46 | size_t dpitch, 47 | const void *src, 48 | size_t spitch, 49 | size_t width, 50 | size_t height) -> int; 51 | auto memcpy_2d_inferred( 52 | void *dest, 53 | size_t dpitch, 54 | const void *src, 55 | size_t spitch, 56 | size_t width, 57 | size_t height) -> int; 58 | } // namespace fwgpu 59 | -------------------------------------------------------------------------------- /test/regress/src/cutlass_srgemm.cu: -------------------------------------------------------------------------------- 1 | #include "fwgpu/gpu_srgemm.hpp" 2 | 3 | #include "cuasr/arch/srmma.h" 4 | #include "cuasr/gemm/device/default_srgemm_configuration.h" 5 | #include "cuasr/gemm/device/srgemm.h" 6 | 7 | #include "cutlass/functional.h" 8 | 9 | namespace fwgpu { 10 | 11 | auto cutlass_srsgemm_nn( 12 | int M, 13 | int N, 14 | int K, 15 | float const *A, 16 | int lda, 17 | float const *B, 18 | int ldb, 19 | float *C, 20 | int ldc, 21 | float *D, 22 | bool do_epilogue_min, 23 | void *stream) -> int { 24 | cudaStream_t stream_ = nullptr; 25 | if (stream) { 26 | stream_ = *(static_cast<cudaStream_t *>(stream)); 27 | } 28 | // compile time configuration of this srgemm kernel 29 | using OperatorClass = cutlass::arch::OpClassSimt; 30 | using SmArch = cutlass::arch::Sm50; 31 | using TropicalConfig = typename cuasr::gemm::device::DefaultSemiRingConfiguration< 32 | float, float, float, float, OperatorClass, cuasr::minimum<float>, 33 | cuasr::plus<float>, SmArch>; 34 | 35 | using AdditionOp = TropicalConfig::AdditionOp; 36 | using MultiplicationOp = 
TropicalConfig::MultiplicationOp; 37 | using ColumnMajor = cutlass::layout::ColumnMajor; 38 | using ThreadblockShape = typename TropicalConfig::ThreadblockShape; 39 | using WarpShape = typename TropicalConfig::WarpShape; 40 | using InstructionShape = typename TropicalConfig::InstructionShape; 41 | using EpilogueOutputOp = typename TropicalConfig::EpilogueOutputOp; 42 | using ThreadblockSwizzle = 43 | typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; 44 | constexpr int Stages = TropicalConfig::kStages; 45 | constexpr int AlignmentA = TropicalConfig::kAlignmentA; 46 | constexpr int AlignmentB = TropicalConfig::kAlignmentB; 47 | 48 | using cuASR_MinPlus_SGEMM = cuasr::gemm::device::Srgemm< 49 | AdditionOp, // Thread level SemiRing operator 50 | MultiplicationOp, // Thread level SemiRing operator 51 | float, // element type of A 52 | ColumnMajor, // layout of A 53 | float, // element type of B 54 | ColumnMajor, // layout of B 55 | float, // element type of C 56 | ColumnMajor, // layout of C 57 | float, // element type of D 58 | OperatorClass, // Logical operator class (SIMT/Tensor) 59 | SmArch, // cuda architecture 60 | ThreadblockShape, // GEMM shape at CTA level 61 | WarpShape, // GEMM shape at Warp level 62 | InstructionShape, // GEMM shape at thread level 63 | EpilogueOutputOp, // Epilogue operator at thread level 64 | ThreadblockSwizzle, // GEMM threadblock swizzler 65 | Stages, // Pipeline stages for shmem 66 | AlignmentA, // Alignment of A elements 67 | AlignmentB, // Alignment of B elements 68 | false // SplitKSerial 69 | >; 70 | 71 | float alpha = MultiplicationOp::Identity; 72 | float beta 73 | = do_epilogue_min ? MultiplicationOp::Identity : MultiplicationOp::Annihilator; 74 | // construct kernel arguments struct 75 | cuASR_MinPlus_SGEMM::Arguments args( 76 | { M, N, K }, // Problem dimensions 77 | { A, lda }, // Tensor-ref for source matrix A 78 | { B, ldb }, // Tensor-ref for source matrix B 79 | { C, ldc }, // Tensor-ref for source matrix C 80 | { D, ldc }, // Tensor-ref for destination matrix D 81 | { alpha, beta } // epilogue scalars; beta selects whether the final min with C happens 82 | ); 83 | 84 | // launch SRGEMM kernel 85 | cuASR_MinPlus_SGEMM minplus_gemm; 86 | cutlass::Status status = minplus_gemm(args, nullptr, stream_); 87 | return static_cast<int>(status); 88 | } 89 | 90 | auto cutlass_srsgemm_nn( 91 | int M, 92 | int N, 93 | int K, 94 | float const *A, 95 | int lda, 96 | float const *B, 97 | int ldb, 98 | float *C, 99 | int ldc, 100 | bool do_epilogue_min, 101 | void *stream) -> int { 102 | return cutlass_srsgemm_nn(M, N, K, A, lda, B, ldb, C, ldc, C, do_epilogue_min, stream); 103 | } 104 | 105 | } // namespace fwgpu 106 | -------------------------------------------------------------------------------- /test/regress/src/utils.cu: -------------------------------------------------------------------------------- 1 | #include "fwgpu/utils.hpp" 2 | 3 | #include <cuda_runtime.h> 4 | 5 | namespace fwgpu { 6 | 7 | auto malloc_device(void **dptr, size_t size) -> int { 8 | auto retval = static_cast<int>(cudaMalloc(dptr, size)); 9 | return retval; 10 | } 11 | 12 | auto malloc_unified(void **dptr, size_t size) -> int { 13 | auto retval = static_cast<int>(cudaMallocManaged(dptr, size)); 14 | return retval; 15 | } 16 | 17 | auto memcpy_inferred(void *dest, const void *src, size_t size) -> int { 18 | auto retval = static_cast<int>(cudaMemcpy(dest, src, size, cudaMemcpyDefault)); 19 | return retval; 20 | } 21 | 22 | auto free_device(void *dbuf) -> int { 23 | auto retval = static_cast<int>(cudaFree(dbuf)); 24 | return 
retval; 25 | } 26 | 27 | auto memcpy_d2h(void *dest, const void *src, size_t size) -> int { 28 | auto retval = static_cast<int>(cudaMemcpy(dest, src, size, cudaMemcpyDeviceToHost)); 29 | return retval; 30 | } 31 | 32 | auto memcpy_h2d(void *dest, const void *src, size_t size) -> int { 33 | auto retval = static_cast<int>(cudaMemcpy(dest, src, size, cudaMemcpyHostToDevice)); 34 | return retval; 35 | } 36 | 37 | auto memcpy_h2h(void *dest, const void *src, size_t size) -> int { 38 | auto retval = static_cast<int>(cudaMemcpy(dest, src, size, cudaMemcpyHostToHost)); 39 | return retval; 40 | } 41 | 42 | auto memcpy_d2d(void *dest, const void *src, size_t size) -> int { 43 | auto retval = static_cast<int>(cudaMemcpy(dest, src, size, cudaMemcpyDeviceToDevice)); 44 | return retval; 45 | } 46 | 47 | auto memcpy_2d_h2d( 48 | void *dest, 49 | size_t dpitch, 50 | const void *src, 51 | size_t spitch, 52 | size_t width, 53 | size_t height) -> int { 54 | auto retval = static_cast<int>( 55 | cudaMemcpy2D(dest, dpitch, src, spitch, width, height, cudaMemcpyHostToDevice)); 56 | return retval; 57 | } 58 | 59 | auto memcpy_2d_d2h( 60 | void *dest, 61 | size_t dpitch, 62 | const void *src, 63 | size_t spitch, 64 | size_t width, 65 | size_t height) -> int { 66 | auto retval = static_cast<int>( 67 | cudaMemcpy2D(dest, dpitch, src, spitch, width, height, cudaMemcpyDeviceToHost)); 68 | return retval; 69 | } 70 | 71 | auto memcpy_2d_d2d( 72 | void *dest, 73 | size_t dpitch, 74 | const void *src, 75 | size_t spitch, 76 | size_t width, 77 | size_t height) -> int { 78 | auto retval = static_cast<int>( 79 | cudaMemcpy2D(dest, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice)); 80 | return retval; 81 | } 82 | 83 | auto memcpy_2d_inferred( 84 | void *dest, 85 | size_t dpitch, 86 | const void *src, 87 | size_t spitch, 88 | size_t width, 89 | size_t height) -> int { 90 | auto retval = static_cast<int>( 91 | cudaMemcpy2D(dest, dpitch, src, spitch, width, height, cudaMemcpyDefault)); 92 | return retval; 93 | } 94 | 95 | } // namespace fwgpu 96 | -------------------------------------------------------------------------------- /test/regress/utils.cuh: -------------------------------------------------------------------------------- 1 | #ifndef cuASR_INTERNAL_UTILS 2 | #define cuASR_INTERNAL_UTILS 3 | 4 | #include <tuple> 5 | 6 | #include "fwgpu/Matrix.hpp" 7 | 8 | namespace fwgpu { 9 | namespace internal { 10 | 11 | template <typename T> 12 | inline auto alloc_and_init_device_gemm_mats( 13 | const Matrix<T> &A, const Matrix<T> &B, const Matrix<T> &C) 14 | -> std::tuple<T *, T *, T *> { 15 | // allocate for inputs and outputs on device 16 | void *d_A, *d_B, *d_C; 17 | cudaMalloc(&d_A, A.bytesize()); 18 | cudaMalloc(&d_B, B.bytesize()); 19 | cudaMalloc(&d_C, C.bytesize()); 20 | 21 | // copy inputs to device 22 | cudaMemcpy(d_A, A.get_buf(), A.bytesize(), cudaMemcpyHostToDevice); 23 | cudaMemcpy(d_B, B.get_buf(), B.bytesize(), cudaMemcpyHostToDevice); 24 | cudaMemcpy(d_C, C.get_buf(), C.bytesize(), cudaMemcpyHostToDevice); 25 | 26 | return std::make_tuple( 27 | reinterpret_cast<T *>(d_A), reinterpret_cast<T *>(d_B), reinterpret_cast<T *>(d_C)); 28 | } 29 | 30 | template <typename T> 31 | inline auto dealloc_device_gemm_mats(std::tuple<T *, T *, T *> device_ptrs) -> void { 32 | cudaFree(std::get<0>(device_ptrs)); 33 | cudaFree(std::get<1>(device_ptrs)); 34 | cudaFree(std::get<2>(device_ptrs)); 35 | } 36 | 37 | template <typename T> 38 | inline auto alloc_and_init_device_gemm_mats( 39 | const Matrix<T> &A, const Matrix<T> &B, const Matrix<T> &C, const Matrix<T> &D) 40 | -> std::tuple<T *, T *, T *, T *> { 41 | // allocate for inputs and outputs on device 42 | void *d_A, *d_B, *d_C, *d_D;
43 | cudaMalloc(&d_A, A.bytesize()); 44 | cudaMalloc(&d_B, B.bytesize()); 45 | cudaMalloc(&d_C, C.bytesize()); 46 | cudaMalloc(&d_D, D.bytesize()); 47 | 48 | // copy inputs to device 49 | cudaMemcpy(d_A, A.get_buf(), A.bytesize(), cudaMemcpyHostToDevice); 50 | cudaMemcpy(d_B, B.get_buf(), B.bytesize(), cudaMemcpyHostToDevice); 51 | cudaMemcpy(d_C, C.get_buf(), C.bytesize(), cudaMemcpyHostToDevice); 52 | cudaMemcpy(d_D, D.get_buf(), D.bytesize(), cudaMemcpyHostToDevice); 53 | 54 | return std::make_tuple( 55 | reinterpret_cast<T *>(d_A), reinterpret_cast<T *>(d_B), reinterpret_cast<T *>(d_C), 56 | reinterpret_cast<T *>(d_D)); 57 | } 58 | 59 | template <typename T> 60 | inline auto dealloc_device_gemm_mats(std::tuple<T *, T *, T *, T *> device_ptrs) -> void { 61 | cudaFree(std::get<0>(device_ptrs)); 62 | cudaFree(std::get<1>(device_ptrs)); 63 | cudaFree(std::get<2>(device_ptrs)); 64 | cudaFree(std::get<3>(device_ptrs)); 65 | } 66 | 67 | } // namespace internal 68 | } // namespace fwgpu 69 | 70 | #endif // cuASR_INTERNAL_UTILS 71 | -------------------------------------------------------------------------------- /tools/include/cuasr/reference/srgemm/host_srgemm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cutlass/coord.h" 4 | #include "cutlass/functional.h" 5 | #include "cutlass/numeric_conversion.h" 6 | #include "cutlass/numeric_types.h" 7 | 8 | #include "cutlass/arch/mma.h" 9 | #include "cutlass/gemm/gemm.h" 10 | #include "cutlass/tensor_view.h" 11 | #include "cutlass/util/host_tensor.h" 12 | 13 | namespace cuasr { 14 | namespace reference { 15 | namespace host { 16 | 17 | /// Host side SemiRing GEMM for rank-2 tensors for testing. 18 | template < 19 | typename AdditionOp, 20 | typename MultiplicationOp, 21 | typename ElementA, 22 | typename LayoutA, 23 | typename ElementB, 24 | typename LayoutB, 25 | typename ElementC, 26 | typename LayoutC, 27 | typename ScalarType, 28 | typename ComputeType, 29 | typename EpilogueOp, 30 | typename ConvertOp = cutlass::NumericConverter<ElementC, ScalarType>> 31 | struct Srgemm { 32 | public: 33 | void operator()( 34 | cutlass::gemm::GemmCoord problem_size, 35 | ComputeType alpha, 36 | cutlass::TensorRef<ElementA, LayoutA> tensor_a, 37 | cutlass::TensorRef<ElementB, LayoutB> tensor_b, 38 | ComputeType beta, 39 | cutlass::TensorRef<ElementC, LayoutC> tensor_c, 40 | cutlass::TensorRef<ElementC, LayoutC> tensor_d, 41 | ComputeType add_identity) { 42 | static_assert( 43 | LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, 44 | "Tensors must be of rank 2"); 45 | 46 | // use OMP to speed up host side reference GEMM if we can 47 | #pragma omp parallel proc_bind(spread) firstprivate( \ 48 | problem_size, tensor_a, tensor_b, tensor_c, tensor_d, add_identity, \ 49 | alpha, beta) 50 | { 51 | int const M = problem_size.m(); 52 | int const N = problem_size.n(); 53 | int const K = problem_size.k(); 54 | 55 | // Blocking necessary to speedup reference implementation 56 | constexpr int Mblock = 32; 57 | constexpr int Nblock = 32; 58 | 59 | ConvertOp convert_op; 60 | AdditionOp add_op; 61 | MultiplicationOp mult_op; 62 | 63 | #pragma omp for schedule(static) collapse(2) 64 | for (int row_block = 0; row_block < M; row_block += Mblock) { 65 | for (int col_block = 0; col_block < N; col_block += Nblock) { 66 | // init registers 67 | ComputeType accum[Mblock][Nblock]; 68 | for (int j = 0; j < Nblock; j++) { 69 | for (int i = 0; i < Mblock; i++) { 70 | accum[i][j] = add_identity; 71 | } 72 | } 73 | 74 | // main loop over k-dim 75 | for (int k_block = 0; k_block < K; ++k_block) { 76 | for (int j = 0; j < Nblock; j++) { 77 | for 
(int i = 0; i < Mblock; i++) { 78 | int row = row_block + i; 79 | int col = col_block + j; 80 | if (row < M && col < N) { 81 | ElementA a = tensor_a.at(cutlass::MatrixCoord(row, k_block)); 82 | ElementB b = tensor_b.at(cutlass::MatrixCoord(k_block, col)); 83 | 84 | ComputeType compute_a(static_cast<ComputeType>(a)); 85 | ComputeType compute_b(static_cast<ComputeType>(b)); 86 | 87 | accum[i][j] = add_op(mult_op(compute_a, compute_b), accum[i][j]); 88 | } 89 | } 90 | } 91 | } 92 | 93 | // perform epilogue operator 94 | for (int j = 0; j < Nblock; j++) { 95 | for (int i = 0; i < Mblock; i++) { 96 | int row = row_block + i; 97 | int col = col_block + j; 98 | cutlass::MatrixCoord coord(row, col); 99 | if (row < M && col < N) { 100 | auto c = tensor_c.at(coord); 101 | tensor_d.at(coord) = convert_op( // 102 | add_op( // 103 | mult_op(alpha, accum[i][j]), // 104 | mult_op(beta, c) // 105 | ) // 106 | ); 107 | } 108 | } 109 | } 110 | } 111 | } // #pragma omp for 112 | } // #pragma omp parallel 113 | } 114 | }; 115 | 116 | } // namespace host 117 | } // namespace reference 118 | } // namespace cuasr 119 | --------------------------------------------------------------------------------
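The host reference above is the ordinary GEMM triple loop with (+, x) swapped for a configurable (add_op, mult_op) pair and the accumulator seeded with the additive identity instead of zero. The following self-contained sketch distills that idea for the min-plus (tropical) case without any CUTLASS machinery; the MinPlus struct, the srgemm_host helper, and the 2x2 inputs are illustrative only and are not part of the library.

#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

// Tropical (min-plus) semiring: "addition" = min with identity +inf, "multiplication" = +.
struct MinPlus {
  static constexpr float add_identity = std::numeric_limits<float>::infinity();
  static float add(float a, float b) { return std::min(a, b); }
  static float mult(float a, float b) { return a + b; }
};

// C[m x n] = A[m x k] (x) B[k x n] over the semiring, all column-major,
// mirroring the accumulator loop of the host reference above.
template <typename Ring>
void srgemm_host(int m, int n, int k, const float *A, const float *B, float *C) {
  for (int col = 0; col < n; ++col) {
    for (int row = 0; row < m; ++row) {
      float acc = Ring::add_identity;  // seed with the additive identity, not 0
      for (int i = 0; i < k; ++i)
        acc = Ring::add(acc, Ring::mult(A[row + i * m], B[i + col * k]));
      C[row + col * m] = acc;
    }
  }
}

int main() {
  // 2x2 shortest-path style example: C(i,j) = min_k (A(i,k) + B(k,j)).
  std::vector<float> A = { 0.f, 2.f, 1.f, 0.f };  // column-major
  std::vector<float> B = { 0.f, 3.f, 4.f, 0.f };
  std::vector<float> C(4);
  srgemm_host<MinPlus>(2, 2, 2, A.data(), B.data(), C.data());
  for (int r = 0; r < 2; ++r)
    std::printf("%g %g\n", C[r], C[r + 2]);
  return 0;
}

Wired into the tropical semiring this way, the same loop structure computes one relaxation step of all-pairs shortest paths, which is exactly what the fwgpu Floyd-Warshall helpers (cpu_srgemm_naive, gpu_srgemm_naive) in the regress tree implement by hand.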