├── tools
    ├── CMakeLists.txt
    └── hands-on-opt
    │   ├── CMakeLists.txt
    │   └── hands-on-opt.cpp
├── ubsan.supp
├── examples
    ├── torch
    │   ├── .gitignore
    │   ├── linear
    │   │   ├── linear.py
    │   │   ├── cuda
    │   │   │   ├── run_hom.sh
    │   │   │   ├── linear_test.py
    │   │   │   ├── run_iree.sh
    │   │   │   ├── linear.py
    │   │   │   ├── run.sh
    │   │   │   ├── run_fp16.cu
    │   │   │   ├── run_fp16.sh
    │   │   │   ├── b.cu
    │   │   │   ├── run.cu
    │   │   │   ├── a.cu
    │   │   │   └── c.cu
    │   │   ├── run.sh
    │   │   └── run.cpp
    │   ├── elementwise
    │   │   ├── add.py
    │   │   ├── run_add.sh
    │   │   └── add.cu
    │   ├── benchmark.py
    │   ├── gelu
    │   │   └── generate_gelu.py
    │   ├── layernorm
    │   │   ├── cuda
    │   │   │   ├── layernorm.py
    │   │   │   ├── run.sh
    │   │   │   ├── layernorm.cu
    │   │   │   ├── ln_gemm.cu
    │   │   │   └── gemm_with_mean_var.cu
    │   │   └── generate_layernorm.py
    │   ├── bert
    │   │   ├── modify_for_iree.py
    │   │   ├── parse_iree.py
    │   │   ├── parse_hom.py
    │   │   ├── compile.sh
    │   │   ├── run_iree.sh
    │   │   ├── benchmark_bert.py
    │   │   ├── convert_hf.py
    │   │   └── run_hom.sh
    │   ├── softmax
    │   │   └── generate_softmax.py
    │   └── bert_attention
    │   │   ├── run_cuseqlen.sh
    │   │   ├── bert_attention.py
    │   │   ├── bert_self_attention.py
    │   │   ├── run_bert_self_attn.sh
    │   │   ├── run_bert_attn.sh
    │   │   ├── cuSeqLen.cu
    │   │   ├── bert_attn.cu
    │   │   └── bert_self_attn.cu
    └── mlir
    │   ├── cpu_gemm
    │   ├── README.md
    │   ├── matmul.mlir
    │   ├── run.sh
    │   └── naive.mlir
    │   └── utils
    │   ├── run.sh
    │   └── fill_and_print.mlir
├── lib
    ├── Dialect
    │   ├── CMakeLists.txt
    │   ├── HOM
    │   │   ├── CMakeLists.txt
    │   │   ├── HOMOps.cpp
    │   │   └── HOMSerializeWeight.cpp
    │   └── HOMNVGPU
    │   │   ├── CMakeLists.txt
    │   │   ├── HOMNVGPUOps.cpp
    │   │   ├── HOMNVGPUAutotunePass.cpp
    │   │   ├── HOMNVGPUFusionPass.cpp
    │   │   └── HOMNVGPULegalizeGemmPass.cpp
    ├── Conversions
    │   ├── MatMulCPUOptimize
    │   │   └── CMakeLists.txt
    │   ├── CMakeLists.txt
    │   ├── Function
    │   │   ├── CMakeLists.txt
    │   │   └── OptimizeMemory.cpp
    │   ├── HOM
    │   │   ├── CMakeLists.txt
    │   │   └── HOMToHOMNVGPU.cpp
    │   ├── FP32toFP16
    │   │   └── CMakeLists.txt
    │   └── Tosa
    │   │   ├── CMakeLists.txt
    │   │   └── TosaToHOM.cpp
    ├── WeightsEngine
    │   ├── CMakeLists.txt
    │   └── WeightsEngine.cpp
    ├── NVGPUKernels
    │   ├── GemmManifest.cu
    │   ├── GemmRunner.cu
    │   └── CMakeLists.txt
    ├── CMakeLists.txt
    └── ExecutionEngine
    │   ├── ExecutionEngine.cpp
    │   └── CutlassCAPI.cu
├── include
    ├── CMakeLists.txt
    ├── Dialect
    │   ├── CMakeLists.txt
    │   ├── HOM
    │   │   ├── CMakeLists.txt
    │   │   ├── Passes.h
    │   │   ├── Passes.td
    │   │   ├── HOMOps.h
    │   │   ├── HOMTypesBase.td
    │   │   └── HOMFusion.pdll
    │   └── HOMNVGPU
    │   │   ├── HOMNVGPUAutotune.pdll
    │   │   ├── Passes.h
    │   │   ├── CMakeLists.txt
    │   │   ├── HOMNVGPUFusion.pdll
    │   │   ├── Passes.td
    │   │   ├── HOMNVGPUOps.h
    │   │   ├── HOMNVGPULegalizeGemm.pdll
    │   │   └── HOMNVGPUOps.td
    ├── Conversions
    │   ├── CMakeLists.txt
    │   ├── Function
    │   │   ├── CMakeLists.txt
    │   │   ├── Passes.td
    │   │   └── Passes.h
    │   ├── FP32toFP16
    │   │   ├── Passes.td
    │   │   ├── CMakeLists.txt
    │   │   ├── Passes.h
    │   │   └── FP32toFP16.pdll
    │   ├── HOM
    │   │   ├── Passes.td
    │   │   ├── CMakeLists.txt
    │   │   ├── Passes.h
    │   │   └── HOMToHOMNVGPU.pdll
    │   ├── Tosa
    │   │   ├── CMakeLists.txt
    │   │   ├── Passes.td
    │   │   └── Passes.h
    │   └── MatMulCPUOptimize
    │   │   └── Passes.h
    ├── half.h
    ├── InitAllDialects.h
    ├── Utils.pdll
    ├── WeightsEngine
    │   └── WeightsEngine.h
    ├── InitAllPasses.h
    ├── NVGPUKernels
    │   ├── GemmManifest.h
    │   ├── GemmProfiler.h
    │   ├── OperationRunner.h
    │   ├── CuSeqLen.h
    │   └── GatherRunner.h
    └── ExecutionEngine
    │   └── HandsOnRunnerUtils.h
├── requirements.txt
├── cmake
    ├── CMakeLists.txt
    └── check_simd.cmake
├── .clangd
├── lsan.supp
├── .gitmodules
├── .clang-format
├── .gitignore
├── .pre-commit-config.yaml
├── README_OLD.md
├── CMakeLists.txt
└── README.md

/tools/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(hands-on-opt)
2 |
--------------------------------------------------------------------------------
/ubsan.supp:
--------------------------------------------------------------------------------
1 | # The root cause of this error is currently unknown.
2 | vptr:shared_ptr_base.h
3 |
--------------------------------------------------------------------------------
/examples/torch/.gitignore:
--------------------------------------------------------------------------------
1 | *.mlir
2 | *.txt
3 | *.s
4 | run
5 | *.dot
6 | *.vmfb
7 |
--------------------------------------------------------------------------------
/lib/Dialect/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(HOM)
2 | add_subdirectory(HOMNVGPU)
3 |
--------------------------------------------------------------------------------
/include/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(Dialect)
2 | add_subdirectory(Conversions)
3 |
--------------------------------------------------------------------------------
/include/Dialect/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(HOM)
2 | add_subdirectory(HOMNVGPU)
3 |
--------------------------------------------------------------------------------
/lib/Conversions/MatMulCPUOptimize/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_mlir_library(MatMulCPUOptimization MatMulCPUOptimize.cpp)
2 |
--------------------------------------------------------------------------------
/examples/mlir/cpu_gemm/README.md:
--------------------------------------------------------------------------------
1 | # \[WIP\] Matmul Optimize demo
2 |
3 | Simply run `sh run.sh`; it should print the achieved GFLOPS.
4 |
--------------------------------------------------------------------------------
/include/Conversions/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(Function)
2 | add_subdirectory(Tosa)
3 | add_subdirectory(HOM)
4 | add_subdirectory(FP32toFP16)
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://llvm.github.io/torch-mlir/package-index/ --pre
2 | torch-mlir==20240127.1096
3 | transformers==4.37.2
4 | pre-commit==3.6.0
5 |
--------------------------------------------------------------------------------
/cmake/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(
2 |   COPY .
3 |   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}
4 |   FILES_MATCHING
5 |   PATTERN *.cmake
6 |   PATTERN CMakeFiles EXCLUDE)
7 |
--------------------------------------------------------------------------------
/.clangd:
--------------------------------------------------------------------------------
1 | CompileFlags:
2 |   Add: --no-cuda-version-check
3 |   Remove: [-fno-lifetime-dse, -forward-unknown-to-host-compiler, --generate-code=*, -Xcompiler=-fPIC]
4 |
--------------------------------------------------------------------------------
/lib/WeightsEngine/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | aux_source_directory(. DIR_LIB_SRCS)
2 | set(LLVM_LINK_COMPONENTS Core Support)
3 |
4 | add_llvm_library(WeightsEngine ${DIR_LIB_SRCS})
5 |
--------------------------------------------------------------------------------
/lsan.supp:
--------------------------------------------------------------------------------
1 | # This is a known leak since we don't call deallocFn when exiting the program.
2 | leak:HandsOnRunnerUtils.cpp
3 |
4 | # The cause of this leak is currently unknown.
5 | leak:cuda/run
6 |
--------------------------------------------------------------------------------
/include/half.h:
--------------------------------------------------------------------------------
1 | namespace mlir {
2 | namespace hands_on_mlir {
3 | #ifdef __clang__
4 | typedef _Float16 fp16;
5 | #endif
6 | } // namespace hands_on_mlir
7 | } // namespace mlir
8 |
--------------------------------------------------------------------------------
/lib/Conversions/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(MatMulCPUOptimize)
2 | add_subdirectory(Tosa)
3 | add_subdirectory(Function)
4 | add_subdirectory(HOM)
5 | add_subdirectory(FP32toFP16)
6 |
--------------------------------------------------------------------------------
/include/Conversions/Function/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(LLVM_TARGET_DEFINITIONS Passes.td)
2 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name HOMToFuncTransforms)
3 | add_public_tablegen_target(HOMToFuncTransformsPassIncGen)
4 |
--------------------------------------------------------------------------------
/include/Conversions/FP32toFP16/Passes.td:
--------------------------------------------------------------------------------
1 | include "mlir/Pass/PassBase.td"
2 |
3 | def HOMFP32ToFP16Pass : Pass<"hom-fp32-to-fp16", "mlir::func::FuncOp"> {
4 |   let summary = "Convert FP32 operations to FP16";
5 |   let dependentDialects = ["::mlir::func::FuncDialect"];
6 | }
7 |
--------------------------------------------------------------------------------
/lib/NVGPUKernels/GemmManifest.cu:
--------------------------------------------------------------------------------
1 | #include "NVGPUKernels/GemmManifest.h"
2 |
3 | namespace mlir {
4 | namespace hands_on_mlir {
5 | namespace homnvgpu_kernel {
6 |
7 | GemmManifest manifest;
8 |
9 | } // namespace homnvgpu_kernel
10 | } // namespace hands_on_mlir
11 | } // namespace mlir
12 |
--------------------------------------------------------------------------------
/lib/Conversions/Function/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB sources_ *.cpp)
2 |
3 | add_mlir_library(
4 |   HOMToFuncTransforms
5 |   ${sources_}
6 |   DEPENDS
7 |   HOMToFuncTransformsPassIncGen
8 |   LINK_LIBS
9 |   PUBLIC
10 |   MLIRIR
11 |   MLIRPass
12 |   MLIRTransforms)
13 |
--------------------------------------------------------------------------------
/include/Conversions/HOM/Passes.td:
--------------------------------------------------------------------------------
1 | include "mlir/Pass/PassBase.td"
2 |
3 | def HOMToHOMNVGPUPass : Pass<"hom-to-homnvgpu", "mlir::func::FuncOp"> {
4 |   let summary = "Lowering HOM to HOMNVGPU";
5 |   let dependentDialects = ["::mlir::hands_on_mlir::homnvgpu::HOMNVGPUDialect"];
6 | }
7 |
--------------------------------------------------------------------------------
/lib/Conversions/HOM/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB sources_ *.cpp)
2 |
3 | add_mlir_library(
4 |   HOMToHOMNVGPUNVGPUTransforms
5 |   ${sources_}
6 |   DEPENDS
7 |   HOMToHOMNVGPUNVGPUTransformsPassIncGen
8 |   HOMToHOMNVGPUPDLLPatternsIncGen
9 |   LINK_LIBS
10 |   MLIRIR
11 |   MLIRPass
12 |   MLIRTransforms)
13 |
--------------------------------------------------------------------------------
/include/Conversions/Tosa/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(LLVM_TARGET_DEFINITIONS Passes.td)
2 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaToHOMTransforms)
3 | add_public_tablegen_target(TosaToHOMTransformsPassIncGen)
4 |
5 | add_mlir_pdll_library(TosaToHOMPDLLPatternsIncGen TosaToHOM.pdll
6 |                       TosaToHOM.pdll.h.inc)
7 |
--------------------------------------------------------------------------------
/lib/Conversions/FP32toFP16/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | aux_source_directory(. DIR_LIB_SRCS)
2 |
3 | add_mlir_library(
4 |   HOMFP32ToFP16Transforms
5 |   ${DIR_LIB_SRCS}
6 |   DEPENDS
7 |   HOMFP32ToFP16TransformsPassIncGen
8 |   HOMFP32ToFP16PDLLPatternsIncGen
9 |   LINK_LIBS
10 |   PUBLIC
11 |   MLIRIR
12 |   MLIRPass
13 |   MLIRTransforms)
14 |
--------------------------------------------------------------------------------
/include/InitAllDialects.h:
--------------------------------------------------------------------------------
1 | #include "HOM/HOMOps.h"
2 | #include "HOMNVGPU/HOMNVGPUOps.h"
3 |
4 | namespace mlir {
5 | namespace hands_on_mlir {
6 | inline void registerAllDialects(DialectRegistry &registry) {
7 |   registry.insert<hom::HOMDialect, homnvgpu::HOMNVGPUDialect>();
8 | }
9 | } // namespace hands_on_mlir
10 | } // namespace mlir
11 |
--------------------------------------------------------------------------------
/include/Conversions/FP32toFP16/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(LLVM_TARGET_DEFINITIONS Passes.td)
2 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name HOMFP32ToFP16Transforms)
3 | add_public_tablegen_target(HOMFP32ToFP16TransformsPassIncGen)
4 |
5 | add_mlir_pdll_library(HOMFP32ToFP16PDLLPatternsIncGen FP32toFP16.pdll
6 |                       HOMFP32ToFP16.pdll.h.inc)
7 |
--------------------------------------------------------------------------------
/include/Conversions/HOM/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | set(LLVM_TARGET_DEFINITIONS Passes.td)
2 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name HOMToHOMNVGPUNVGPUTransforms)
3 | add_public_tablegen_target(HOMToHOMNVGPUNVGPUTransformsPassIncGen)
4 |
5 | add_mlir_pdll_library(HOMToHOMNVGPUPDLLPatternsIncGen HOMToHOMNVGPU.pdll
6 |                       HOMToHOMNVGPU.pdll.h.inc)
7 |
--------------------------------------------------------------------------------
/lib/Conversions/Tosa/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB sources_ *.cpp)
2 |
3 | add_mlir_library(
4 |   TosaToHOMTransforms
5 |   ${sources_}
6 |   DEPENDS
7 |   TosaToHOMTransformsPassIncGen
8 |   TosaToHOMPDLLPatternsIncGen
9 |   WeightsEngine
10 |   LINK_COMPONENTS
11 |   Core
12 |   LINK_LIBS
13 |   PUBLIC
14 |   MLIRIR
15 |   MLIRPass
16 |   MLIRTransforms
17 |   WeightsEngine)
18 |
--------------------------------------------------------------------------------
/examples/mlir/cpu_gemm/matmul.mlir:
--------------------------------------------------------------------------------
1 | func.func @main() {
2 |   %A = memref.alloc() : memref<2088x2048xf32>
3 |   // Align %B and %C since they are shape-cast to vector types.
4 |   %B = memref.alloc() {alignment = 32} : memref<2048x2048xf32>
5 |   %C = memref.alloc() {alignment = 32} : memref<2088x2048xf32>
6 |   linalg.matmul ins(%A, %B : memref<2088x2048xf32>, memref<2048x2048xf32>) outs(%C : memref<2088x2048xf32>)
7 |   return
8 | }
9 |
--------------------------------------------------------------------------------
/lib/NVGPUKernels/GemmRunner.cu:
--------------------------------------------------------------------------------
1 | #include "NVGPUKernels/GemmRunner.h"
2 |
3 | namespace mlir {
4 | namespace hands_on_mlir {
5 | namespace homnvgpu_kernel {
6 |
7 | GemmOperationRunnerBase::~GemmOperationRunnerBase() {}
8 | bool GemmOperationRunnerBase::contains(const char *str) {
9 |   return strstr(description_.name, str) != nullptr;
10 | }
11 |
12 | } // namespace homnvgpu_kernel
13 | } // namespace hands_on_mlir
14 | } // namespace mlir
15 |
--------------------------------------------------------------------------------
/include/Dialect/HOM/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_mlir_dialect(HOMOps hom)
2 | set(LLVM_TARGET_DEFINITIONS HOMOps.td)
3 | add_mlir_doc(HOMOps HOMDialect Dialects/ -gen-dialect-doc)
4 |
5 | set(LLVM_TARGET_DEFINITIONS Passes.td)
6 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name HOMTransforms)
7 | add_public_tablegen_target(HOMTransformsPassIncGen)
8 |
9 | add_mlir_pdll_library(HOMFusionPDLLPatternsIncGen HOMFusion.pdll
10 |                       HOMFusion.pdll.h.inc)
11 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "thirdparty/llvm-project"]
2 | path = thirdparty/llvm-project
3 | url = https://github.com/llvm/llvm-project.git
4 | branch = main
5 | [submodule "thirdparty/cutlass"]
6 | path = thirdparty/cutlass
7 | url = https://github.com/NVIDIA/cutlass.git
8 | [submodule "thirdparty/TransformerEngine"]
9 | path = thirdparty/TransformerEngine
10 | url = https://github.com/NVIDIA/TransformerEngine.git
11 |
--------------------------------------------------------------------------------
/include/Dialect/HOMNVGPU/HOMNVGPUAutotune.pdll:
--------------------------------------------------------------------------------
1 | #include "HOMNVGPU/HOMNVGPUOps.td"
2 |
3 | Rewrite profileMatmul(op0 : Op);
4 |
5 | Pattern {
6 |   let kernel = attr<"0 : i32">;
7 |   let matmul = op<homnvgpu.matmul>(input0 : Value, input1 : Value,
8 |                                    input2 : Value){kernel_name = kernel};
9 |
10 |   rewrite matmul with { profileMatmul(matmul); };
11 | }
12 |
--------------------------------------------------------------------------------
/lib/Dialect/HOM/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB _sources *.cpp)
2 |
3 | add_mlir_library(
4 |   MLIRHOM
5 |   ${_sources}
6 |   ADDITIONAL_HEADER_DIRS
7 |   ${HANDS_ON_MLIR_INCLUDE_DIR}/Dialect/HOM
8 |   DEPENDS
9 |   MLIRHOMOpsIncGen
10 |   HOMTransformsPassIncGen
11 |   HOMFusionPDLLPatternsIncGen
12 |   LINK_LIBS
13 |   PUBLIC
14 |   MLIRIR
15 |   MLIRPass
16 |   MLIRSupport
17 |   MLIRParser
18 |   MLIRFuncDialect
19 |   MLIRQuantDialect
20 |   MLIRRewrite
21 |   MLIRTransforms)
22 |
--------------------------------------------------------------------------------
/examples/mlir/utils/run.sh:
--------------------------------------------------------------------------------
1 | hands-on-opt -convert-linalg-to-loops -lower-affine -convert-vector-to-llvm -convert-memref-to-llvm -convert-scf-to-cf -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts fill_and_print.mlir | mlir-cpu-runner -O3 -e main -entry-point-result=void -shared-libs=../../../llvm-project/build/lib/libmlir_c_runner_utils.dylib -shared-libs=../../../llvm-project/build/lib/libmlir_runner_utils.dylib -shared-libs=../../build/lib/libhands_on_mlir_runner_utils.dylib
2 |
--------------------------------------------------------------------------------
/include/Conversions/FP32toFP16/Passes.h:
--------------------------------------------------------------------------------
1 | #ifndef HOM_CONVERSIONS_FP32TOFP16_PASS_H_
2 | #define HOM_CONVERSIONS_FP32TOFP16_PASS_H_
3 |
4 | #include "mlir/Dialect/Func/IR/FuncOps.h"
5 | #include "mlir/Pass/Pass.h"
6 | #include "mlir/Transforms/DialectConversion.h"
7 |
8 | namespace mlir {
9 | namespace hands_on_mlir {
10 |
11 | #define GEN_PASS_DECL_HOMFP32TOFP16PASS
12 | #define GEN_PASS_REGISTRATION
13 | #include "Conversions/FP32toFP16/Passes.h.inc"
14 |
15 | } // namespace hands_on_mlir
16 | } // namespace mlir
17 |
18 | #endif
19 |
--------------------------------------------------------------------------------
/examples/torch/linear/linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_mlir
3 |
4 | torch.manual_seed(42)
5 |
6 |
7 | class A(torch.nn.Module):
8 |     def __init__(self) -> None:
9 |         super().__init__()
10 |         self.fc = torch.nn.Linear(100, 10)
11 |
12 |     def forward(self, x):
13 |         return self.fc(x)
14 |
15 |
16 | a = A()
17 |
18 | x = torch.ones(2, 3, 100)
19 |
20 | print(a(x))
21 |
22 | module = torch_mlir.compile(a, x, output_type="tosa")
23 | with open("linear.mlir", "w") as fl:
24 |     print(module, file=fl, end="")
25 |
--------------------------------------------------------------------------------
/include/Conversions/HOM/Passes.h:
--------------------------------------------------------------------------------
1 | #ifndef HOM_CONVERSIONS_HOM_TRANSFORMS_PASSES_H
2 | #define HOM_CONVERSIONS_HOM_TRANSFORMS_PASSES_H
3 |
4 | #include "Dialect/HOM/HOMOps.h"
5 | #include "Dialect/HOMNVGPU/HOMNVGPUOps.h"
6 | #include "mlir/Pass/Pass.h"
7 |
8 | namespace mlir {
9 | namespace hands_on_mlir {
10 |
11 | #define GEN_PASS_DECL_HOMTOHOMNVGPUPASS
12 | #define GEN_PASS_REGISTRATION
13 | #include "Conversions/HOM/Passes.h.inc"
14 |
15 | } // namespace hands_on_mlir
16 | } // namespace mlir
17 |
18 | #endif // HOM_CONVERSIONS_HOM_TRANSFORMS_PASSES_H
19 |
--------------------------------------------------------------------------------
/lib/Dialect/HOMNVGPU/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB _sources *.cpp)
2 |
3 | add_mlir_library(
4 |   MLIRHOMNVGPU
5 |   ${_sources}
6 |   ADDITIONAL_HEADER_DIRS
7 |   ${HANDS_ON_MLIR_INCLUDE_DIR}/Dialect/HOMNVGPU
8 |   DEPENDS
9 |   MLIRHOMNVGPUOpsIncGen
10 |   HOMNVGPUTransformsPassIncGen
11 |   HOMNVGPUFusionPDLLPatternsIncGen
12 |   HOMNVGPUAutotunePDLLPatternsIncGen
13 |   HOMNVGPULegalizeGemmPDLLPatternsIncGen
14 |   LINK_LIBS
15 |   PUBLIC
16 |   MLIRIR
17 |   MLIRPass
18 |   MLIRSupport
19 |   MLIRParser
20 |   MLIRFuncDialect
21 |   MLIRQuantDialect
22 |   MLIRRewrite
23 |   MLIRTransforms)
24 |
--------------------------------------------------------------------------------
/.clang-format:
--------------------------------------------------------------------------------
1 | # This file is used by clang-format to autoformat source code
2 | #
3 | # clang-format is part of the LLVM toolchain.
4 | # You need to install LLVM and Clang to format source code.
5 | #
6 | # The basic usage is:
7 | #   clang-format -i -style=file PATH/TO/SOURCE/CODE
8 | #
9 | # The -style=file option implicitly uses the ".clang-format" file located
10 | # in one of the parent directories.
11 | # The -i flag means in-place change.
12 | #
13 | # The documentation for clang-format is at
14 | #   http://clang.llvm.org/docs/ClangFormat.html
15 | #   http://clang.llvm.org/docs/ClangFormatStyleOptions.html
16 |
17 | BasedOnStyle: LLVM
18 |
--------------------------------------------------------------------------------
/examples/torch/linear/cuda/run_hom.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | # Generated by Kimi Chat
4 |
5 | # Iterate over all matching .mlir files in the current directory and its subdirectories
6 | find . -type f -name "hom_*.mlir" | while read mlir_file; do
7 |
8 |   pattern="hom_linear_([0-9]+)_([0-9]+)_([0-9]+)\.mlir"
9 |
10 |   if [[ $mlir_file =~ $pattern ]]; then
11 |     M="${BASH_REMATCH[1]}"
12 |     N="${BASH_REMATCH[2]}"
13 |     K="${BASH_REMATCH[3]}"
14 |
15 |     # Print the extracted values
16 |     echo "M: $M, N: $N, K: $K"
17 |   fi
18 |
19 |   bash ./run_fp16.sh $mlir_file
20 |
21 | done
22 |
23 | echo "Compilation process completed for all matching files."
24 |
--------------------------------------------------------------------------------
/examples/torch/elementwise/add.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_mlir
3 |
4 | torch.manual_seed(42)
5 |
6 | a = torch.rand((3, 3, 3))
7 | b = torch.rand((3, 1, 3))
8 |
9 |
10 | class Wrapper(torch.nn.Module):
11 |     def __init__(self):
12 |         super().__init__()
13 |
14 |     def forward(self, a, b):
15 |         return a + b
16 |
17 |
18 | model = Wrapper()
19 |
20 | model.eval()
21 |
22 | print(a, b)
23 |
24 | print(model(a, b))
25 |
26 | with torch.no_grad():
27 |     module = torch_mlir.compile(model, (a, b), output_type="tosa")
28 |     with open("add.mlir", "w") as fl:
29 |         print(module, file=fl, end="")
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 |
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 |
34 | # build and vscode
35 | build/*
36 | .vscode/*
37 | optimized.mlir
38 | .cache/*
39 |
40 | # Memo
41 | memo.md
42 |
43 | # nsys
44 | *.nsys-rep
45 |
46 | # python
47 | __pycache__/
48 |
49 | # log
50 | *.log
51 |
--------------------------------------------------------------------------------
/examples/torch/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import BertTokenizer
3 |
4 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
5 |
6 |
7 | @torch.no_grad()
8 | def speed_test(model, input_list):
9 |     print([(i.shape, i.dtype) for i in input_list])
10 |
11 |     for _ in range(10):
12 |         model(*input_list)
13 |
14 |     a = torch.cuda.Event(True)
15 |     b = torch.cuda.Event(True)
16 |     a.record()
17 |
18 |     for _ in range(1000):
19 |         model(*input_list)
20 |
21 |     b.record()
22 |
23 |     torch.cuda.synchronize()
24 |
25 |     time = a.elapsed_time(b)
26 |
27 |     print(time / 1000)
28 |
--------------------------------------------------------------------------------
/examples/mlir/cpu_gemm/run.sh:
--------------------------------------------------------------------------------
1 | ../../build/bin/hands-on-opt --matmul-cpu-optimize --convert-linalg-to-affine-loops \
2 |   -lower-affine -convert-scf-to-cf -convert-vector-to-llvm \
3 |   -finalize-memref-to-llvm -convert-arith-to-llvm --convert-math-to-llvm \
4 |   -convert-func-to-llvm -reconcile-unrealized-casts naive.mlir | \
5 |   mlir-cpu-runner -O3 -e main \
6 |   -entry-point-result=void \
7 |   -shared-libs=../../../llvm-project/build/lib/libmlir_c_runner_utils.dylib \
8 |   -shared-libs=../../../llvm-project/build/lib/libmlir_runner_utils.dylib \
9 |   -shared-libs=../../build/lib/libhands_on_mlir_runner_utils.dylib
10 |
--------------------------------------------------------------------------------
/examples/torch/gelu/generate_gelu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_mlir
3 |
4 | hidden_states = torch.rand((1, 10, 100))
5 |
6 |
7 | class Wrapper(torch.nn.Module):
8 |     def __init__(self):
9 |         super().__init__()
10 |         self.ln = torch.nn.LayerNorm(100)
11 |
12 |     def forward(self, hidden_states):
13 |         return torch.nn.functional.gelu(hidden_states)
14 |
15 |
16 | model = Wrapper()
17 |
18 | model.eval()
19 |
20 | with torch.no_grad():
21 |     module = torch_mlir.compile(
22 |         model, hidden_states, output_type="tosa", use_tracing=True
23 |     )
24 |     with open("gelu.mlir", "w") as fl:
25 |         print(module, file=fl, end="")
26 |
--------------------------------------------------------------------------------
/include/Conversions/FP32toFP16/FP32toFP16.pdll:
--------------------------------------------------------------------------------
1 | #include "HOM/HOMOps.td"
2 | #include "HOMNVGPU/HOMNVGPUOps.td"
3 | #include "mlir/Dialect/Tosa/IR/TosaOps.td"
4 |
5 | Constraint okToInsertCast(op : Op)[{
6 |   auto constOp = dyn_cast<tosa::ConstOp>(op);
7 |   if (constOp.getResult().getType().getElementType().isF16() ||
8 |       !constOp.getResult().getType().getElementType().isF32()) {
9 |     return failure();
10 |   }
11 |   return success();
12 | }];
13 |
14 | Rewrite generateCastOp(op : Op);
15 |
16 | Pattern {
17 |   let root = op<tosa.const>;
18 |   okToInsertCast(root);
19 |
20 |   rewrite root with { generateCastOp(root); };
21 | }
22 |
--------------------------------------------------------------------------------
/include/Conversions/Tosa/Passes.td:
--------------------------------------------------------------------------------
1 | include "mlir/Pass/PassBase.td"
2 |
3 | def TosaToHOMPass : Pass<"tosa-to-hom", "mlir::func::FuncOp"> {
4 |   let summary = "Lower Tosa to the HOM dialect";
5 |   let dependentDialects = [
6 |     "::mlir::hands_on_mlir::hom::HOMDialect", "::mlir::pdl::PDLDialect",
7 |     "::mlir::pdl_interp::PDLInterpDialect"
8 |   ];
9 | }
10 |
11 | def TosaConstantFoldingPass
12 |     : Pass<"hom-tosa-constant-folding", "mlir::func::FuncOp"> {
13 |   let summary = "Constant folding for Tosa";
14 |   let dependentDialects = [
15 |     "::mlir::tosa::TosaDialect", "::mlir::pdl::PDLDialect",
16 |     "::mlir::pdl_interp::PDLInterpDialect"
17 |   ];
18 | }
19 |
--------------------------------------------------------------------------------
/include/Dialect/HOM/Passes.h:
--------------------------------------------------------------------------------
1 | #ifndef HOM_HOMTRANSFORMS_PASSES_H
2 | #define HOM_HOMTRANSFORMS_PASSES_H
3 |
4 | #include <memory>
5 |
6 | #include "mlir/Dialect/Func/IR/FuncOps.h"
7 | #include "mlir/Dialect/PDL/IR/PDL.h"
"mlir/Dialect/PDL/IR/PDL.h" 8 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 9 | #include "mlir/Pass/Pass.h" 10 | 11 | namespace mlir { 12 | namespace hands_on_mlir { 13 | namespace hom { 14 | 15 | #define GEN_PASS_DECL_HOMFUSIONPASS 16 | #define GEN_PASS_DECL_HOMSERIALIZEWEIGHTPASS 17 | #define GEN_PASS_REGISTRATION 18 | #include "HOM/Passes.h.inc" 19 | 20 | } // namespace hom 21 | } // namespace hands_on_mlir 22 | } // namespace mlir 23 | 24 | #endif // HOM_HOMTRANSFORMS_PASSES_H 25 | -------------------------------------------------------------------------------- /include/Dialect/HOM/Passes.td: -------------------------------------------------------------------------------- 1 | include "mlir/Pass/PassBase.td" 2 | 3 | def HOMFusionPass : Pass<"hom-fusion", "mlir::func::FuncOp"> { 4 | let summary = "HOM Fusion Pass"; 5 | let dependentDialects = [ 6 | "::mlir::hands_on_mlir::hom::HOMDialect", "::mlir::tosa::TosaDialect", 7 | "::mlir::pdl::PDLDialect", "::mlir::pdl_interp::PDLInterpDialect" 8 | ]; 9 | } 10 | 11 | def HOMSerializeWeightPass 12 | : Pass<"hom-serialize-weight", "mlir::func::FuncOp"> { 13 | let summary = "Serialize Weight Pass"; 14 | let dependentDialects = [ 15 | "::mlir::hands_on_mlir::hom::HOMDialect", "::mlir::pdl::PDLDialect", 16 | "::mlir::pdl_interp::PDLInterpDialect" 17 | ]; 18 | } 19 | -------------------------------------------------------------------------------- /examples/torch/layernorm/cuda/layernorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_mlir 3 | 4 | torch.manual_seed(42) 5 | 6 | hidden_states = torch.rand((1, 2, 10)) 7 | 8 | 9 | class Wrapper(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.ln = torch.nn.LayerNorm(10) 13 | 14 | def forward(self, hidden_states): 15 | return self.ln(hidden_states) 16 | 17 | 18 | model = Wrapper() 19 | 20 | model.eval() 21 | 22 | print(hidden_states) 23 | 24 | print(model(hidden_states)) 25 | 26 | with torch.no_grad(): 27 | module = torch_mlir.compile(model, hidden_states, output_type="tosa") 28 | with open("layernorm.mlir", "w") as fl: 29 | print(module, file=fl, end="") 30 | -------------------------------------------------------------------------------- /examples/torch/layernorm/generate_layernorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_mlir 3 | 4 | hidden_states = torch.rand((1, 10, 100)) 5 | 6 | 7 | class Wrapper(torch.nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.ln = torch.nn.LayerNorm(10) 11 | self.l = torch.nn.Linear(100, 10, bias=False) 12 | 13 | def forward(self, hidden_states): 14 | return self.ln(self.l(hidden_states)) 15 | 16 | 17 | model = Wrapper() 18 | 19 | model.eval() 20 | 21 | with torch.no_grad(): 22 | module = torch_mlir.compile( 23 | model, hidden_states, output_type="tosa", use_tracing=True 24 | ) 25 | with open("layernorm.mlir", "w") as fl: 26 | print(module, file=fl, end="") 27 | -------------------------------------------------------------------------------- /tools/hands-on-opt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) 2 | get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) 3 | 4 | add_llvm_tool(hands-on-opt hands-on-opt.cpp) 5 | 6 | target_link_libraries( 7 | hands-on-opt 8 | PRIVATE ${dialect_libs} 9 | ${conversion_libs} 10 | HOMToFuncTransforms 11 | 
12 |           MLIROptLib
13 |           MLIRArithDialect
14 |           MatMulCPUOptimization
15 |           MLIRHOM
16 |           MLIRHOMNVGPU
17 |           WeightsEngine
18 |           TosaToHOMTransforms
19 |           HOMFP32ToFP16Transforms
20 |           GemmManifestAndProfiler)
21 |
22 | mlir_check_all_link_libraries(hands-on-opt)
23 |
--------------------------------------------------------------------------------
/include/Dialect/HOMNVGPU/Passes.h:
--------------------------------------------------------------------------------
1 | #ifndef HOMNVGPU_TRANSFORMS_PASSES_H
2 | #define HOMNVGPU_TRANSFORMS_PASSES_H
3 |
4 | #include <memory>
5 |
6 | #include "mlir/Dialect/Func/IR/FuncOps.h"
7 | #include "mlir/Dialect/PDL/IR/PDL.h"
8 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
9 | #include "mlir/Pass/Pass.h"
10 |
11 | namespace mlir {
12 | namespace hands_on_mlir {
13 | namespace homnvgpu {
14 |
15 | #define GEN_PASS_DECL_HOMNVGPUFUSIONPASS
16 | #define GEN_PASS_DECL_HOMNVGPUAUTOTUNEPASS
17 | #define GEN_PASS_DECL_HOMNVGPULEGALIZEGEMMPASS
18 | #define GEN_PASS_REGISTRATION
19 | #include "HOMNVGPU/Passes.h.inc"
20 |
21 | } // namespace homnvgpu
22 | } // namespace hands_on_mlir
23 | } // namespace mlir
24 |
25 | #endif // HOMNVGPU_TRANSFORMS_PASSES_H
26 |
--------------------------------------------------------------------------------
/include/Utils.pdll:
--------------------------------------------------------------------------------
1 | #include "HOM/HOMOps.td"
2 |
3 | Rewrite getSingleFloatValue(op : Op) -> F32Attr[{
4 |   auto value = dyn_cast<tosa::ConstOp>(op).getValueAttr();
5 |   auto elementType = value.getElementType();
6 |
7 |   auto data = value.getValues<APFloat>()[0].convertToFloat();
8 |   return rewriter.getFloatAttr(elementType, data);
9 | }];
10 |
11 | Constraint isSingleFloatConstant(op : Op)[{
12 |   auto constOp = dyn_cast<tosa::ConstOp>(op);
13 |   if (constOp) {
14 |     auto value = constOp.getValueAttr();
15 |     auto elementType = value.getElementType();
16 |     return success(elementType.isF32() && value.getNumElements() == 1);
17 |   }
18 |   return failure();
19 | }];
20 |
--------------------------------------------------------------------------------
/include/Dialect/HOMNVGPU/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_mlir_dialect(HOMNVGPUOps homnvgpu)
2 | set(LLVM_TARGET_DEFINITIONS HOMNVGPUOps.td)
3 | add_mlir_doc(HOMNVGPUOps HOMNVGPUDialect Dialects/ -gen-dialect-doc)
4 |
5 | set(LLVM_TARGET_DEFINITIONS Passes.td)
6 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name HOMNVGPUTransforms)
7 | add_public_tablegen_target(HOMNVGPUTransformsPassIncGen)
8 |
9 | add_mlir_pdll_library(HOMNVGPUFusionPDLLPatternsIncGen HOMNVGPUFusion.pdll
10 |                       HOMNVGPUFusion.pdll.h.inc)
11 | add_mlir_pdll_library(HOMNVGPULegalizeGemmPDLLPatternsIncGen
12 |                       HOMNVGPULegalizeGemm.pdll HOMNVGPULegalizeGemm.pdll.h.inc)
13 | add_mlir_pdll_library(HOMNVGPUAutotunePDLLPatternsIncGen HOMNVGPUAutotune.pdll
14 |                       HOMNVGPUAutotune.pdll.h.inc)
15 |
--------------------------------------------------------------------------------
/include/Dialect/HOMNVGPU/HOMNVGPUFusion.pdll:
--------------------------------------------------------------------------------
1 | #include "HOM/HOMOps.td"
2 | #include "HOMNVGPU/HOMNVGPUOps.td"
3 | #include "mlir/Dialect/Tosa/IR/TosaOps.td"
4 |
5 | Constraint hasOneUse(op : Op)[{ return success(op->hasOneUse()); }];
6 |
7 | // We register native rewrite functions here instead of using PDLL rewrite
8 | // statements; the latter crashes for reasons we have not tracked down. Error message: `runtime error: member access
9 | // within null pointer of type 'mlir::IRObjectWithUseList'`
10 | Rewrite generateGemmLnGemm(op0 : Op, op1 : Op, op2 : Op);
11 | Rewrite updateMaskWithCuSeqLen(op0 : Op, op1 : Op);
12 |
13 | Pattern {
14 |   let mask = op(input : Value);
15 |   let attn = op(input0 : Value, mask);
16 |
17 |   rewrite attn with { updateMaskWithCuSeqLen(mask, attn); };
18 | }
19 |
--------------------------------------------------------------------------------
/tools/hands-on-opt/hands-on-opt.cpp:
--------------------------------------------------------------------------------
1 | #include "mlir/IR/MLIRContext.h"
2 | #include "mlir/InitAllDialects.h"
3 | #include "mlir/InitAllPasses.h"
4 | #include "mlir/Support/FileUtilities.h"
5 | #include "mlir/Tools/mlir-opt/MlirOptMain.h"
6 |
7 | #include "InitAllDialects.h"
8 | #include "InitAllPasses.h"
9 |
10 | int main(int argc, char **argv) {
11 |   // Register all MLIR passes.
12 |   mlir::registerAllPasses();
13 |   mlir::hands_on_mlir::registerAllPasses();
14 |
15 |   mlir::DialectRegistry registry;
16 |   // Register all MLIR core dialects.
17 |   registerAllDialects(registry);
18 |   // Register dialects in hands-on-mlir project.
19 |   mlir::hands_on_mlir::registerAllDialects(registry);
20 |
21 |   return mlir::failed(mlir::MlirOptMain(
22 |       argc, argv, "hands-on-mlir optimizer driver", registry));
23 | }
24 |
--------------------------------------------------------------------------------
/examples/torch/bert/modify_for_iree.py:
--------------------------------------------------------------------------------
1 | for name in ["bert-base-uncased", "bert-large-uncased"]:
2 |     for bs in [1, 8, 16, 32]:
3 |         for seq_len in [64, 128]:
4 |
5 |             with open(f"{name}_{bs}_{seq_len}.mlir", "r") as fl_in:
6 |                 with open(f"iree_{name}_{bs}_{seq_len}.mlir", "w") as fl_out:
7 |                     found = False
8 |                     for line in fl_in.readlines():
9 |                         if not found and "func.func @forward" in line:
10 |                             idx = line.find(") ->")
11 |                             line = (
12 |                                 line[:idx]
13 |                                 + ", %arg3: !hal.buffer {iree.abi.output = 0 : index}"
14 |                                 + line[idx:]
15 |                             )
16 |                             found = True
17 |                         print(line, file=fl_out, end="")
18 |
--------------------------------------------------------------------------------
/examples/torch/bert/parse_iree.py:
--------------------------------------------------------------------------------
1 | res = {"base": {}, "large": {}}
2 |
3 | with open("iree_new.log", "r") as fl:
4 |     state = 0
5 |     mode = ""
6 |     bs = 0
7 |     seq = 0
8 |     for line in fl.readlines():
9 |         line = line.strip()
10 |         if "Compilation successful:" in line:
11 |             mode = "base" if "bert-base" in line else "large"
12 |             idx = line.find("d_")
13 |             line = line[idx + 2 :]
14 |             idx = line.find("_")
15 |             bs = int(line[:idx])
16 |             seq = int(line[idx + 1 : -5])
17 |         elif "BM_forward/process_time/real_time_mean" in line:
18 |             idx = line.find("ms")
19 |             res[mode][(bs, seq)] = float(line[idx - 5 : idx])
20 |
21 | for mode in ["base", "large"]:
22 |     for bs in [1, 8, 16, 32]:
23 |         for seq in [64, 128]:
24 |             try:
25 |                 print(f"{mode} ({bs}, {seq}): {res[mode][(bs, seq)]}")
26 |             except KeyError:
27 |                 print("No info")
28 |
--------------------------------------------------------------------------------
/include/Dialect/HOM/HOMOps.h:
--------------------------------------------------------------------------------
1 | #ifndef HOM_HOMDIALECT_H
2 | #define HOM_HOMDIALECT_H
3 |
4 | #include "mlir/Dialect/Quant/QuantTypes.h"
5 | #include "mlir/Dialect/Shape/IR/Shape.h"
6 | #include "mlir/IR/Attributes.h"
7 | #include "mlir/IR/Builders.h"
8 | #include "mlir/IR/BuiltinAttributes.h"
"mlir/IR/BuiltinAttributes.h" 9 | #include "mlir/IR/BuiltinTypes.h" 10 | #include "mlir/IR/Dialect.h" 11 | #include "mlir/IR/DialectImplementation.h" 12 | #include "mlir/IR/Location.h" 13 | #include "mlir/IR/MLIRContext.h" 14 | #include "mlir/IR/OpDefinition.h" 15 | #include "mlir/IR/Operation.h" 16 | #include "mlir/IR/TensorEncoding.h" 17 | #include "mlir/IR/TypeUtilities.h" 18 | #include "mlir/IR/Types.h" 19 | #include "mlir/Interfaces/InferTypeOpInterface.h" 20 | #include "mlir/Interfaces/SideEffectInterfaces.h" 21 | #include "mlir/Support/LogicalResult.h" 22 | #include "llvm/ADT/StringRef.h" 23 | 24 | #include "HOM/HOMOpsDialect.h.inc" 25 | 26 | #define GET_OP_CLASSES 27 | #include "HOM/HOMOps.h.inc" 28 | 29 | #endif // HOM_HOMDIALECT_H 30 | -------------------------------------------------------------------------------- /include/WeightsEngine/WeightsEngine.h: -------------------------------------------------------------------------------- 1 | #ifndef HANDS_ON_MLIR_WEIGHTSENGINE_WEIGHTSENGINE_H_ 2 | #define HANDS_ON_MLIR_WEIGHTSENGINE_WEIGHTSENGINE_H_ 3 | 4 | #include "mlir/IR/BuiltinAttributeInterfaces.h" 5 | #include "mlir/IR/BuiltinTypeInterfaces.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace mlir { 12 | namespace hands_on_mlir { 13 | class WeightsEngine { 14 | size_t weightsIds; 15 | std::unordered_map> weightsMap; 16 | 17 | public: 18 | WeightsEngine() { weightsIds = 0; } 19 | size_t addWeight(std::shared_ptr); 20 | size_t addWeight(ElementsAttr &); 21 | void removeWeight(size_t idx); 22 | 23 | template 24 | static void serializeWeightToDisk(const ShapedType &shape, T *data, 25 | const std::string &fileName); 26 | }; 27 | 28 | extern WeightsEngine gWe; 29 | 30 | } // namespace hands_on_mlir 31 | } // namespace mlir 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /include/Dialect/HOMNVGPU/Passes.td: -------------------------------------------------------------------------------- 1 | include "mlir/Pass/PassBase.td" 2 | 3 | def HOMNVGPUFusionPass : Pass<"homnvgpu-fusion", "mlir::func::FuncOp"> { 4 | let summary = "HOM Fusion Pass for NVIDIA GPU"; 5 | let dependentDialects = [ 6 | "::mlir::hands_on_mlir::hom::HOMDialect", "::mlir::tosa::TosaDialect", 7 | "::mlir::pdl::PDLDialect", "::mlir::pdl_interp::PDLInterpDialect" 8 | ]; 9 | } 10 | 11 | def HOMNVGPULegalizeGemmPass 12 | : Pass<"homnvgpu-legalize-gemm", "mlir::func::FuncOp"> { 13 | let summary = "HOM Fusion Pass for NVIDIA GPU"; 14 | let dependentDialects = [ 15 | "::mlir::tosa::TosaDialect", "::mlir::pdl::PDLDialect", 16 | "::mlir::pdl_interp::PDLInterpDialect" 17 | ]; 18 | } 19 | 20 | def HOMNVGPUAutotunePass : Pass<"homnvgpu-autotune", "mlir::func::FuncOp"> { 21 | let summary = "HOM Fusion Pass for NVIDIA GPU"; 22 | let dependentDialects = [ 23 | "::mlir::tosa::TosaDialect", "::mlir::pdl::PDLDialect", 24 | "::mlir::pdl_interp::PDLInterpDialect" 25 | ]; 26 | } 27 | -------------------------------------------------------------------------------- /include/Dialect/HOMNVGPU/HOMNVGPUOps.h: -------------------------------------------------------------------------------- 1 | #ifndef HOM_HOMNVGPUDIALECT_H 2 | #define HOM_HOMNVGPUDIALECT_H 3 | 4 | #include "mlir/Dialect/Quant/QuantTypes.h" 5 | #include "mlir/Dialect/Shape/IR/Shape.h" 6 | #include "mlir/IR/Attributes.h" 7 | #include "mlir/IR/Builders.h" 8 | #include "mlir/IR/BuiltinAttributes.h" 9 | #include "mlir/IR/BuiltinTypes.h" 10 | #include "mlir/IR/Dialect.h" 11 | #include 
"mlir/IR/DialectImplementation.h" 12 | #include "mlir/IR/Location.h" 13 | #include "mlir/IR/MLIRContext.h" 14 | #include "mlir/IR/OpDefinition.h" 15 | #include "mlir/IR/Operation.h" 16 | #include "mlir/IR/TensorEncoding.h" 17 | #include "mlir/IR/TypeUtilities.h" 18 | #include "mlir/IR/Types.h" 19 | #include "mlir/Interfaces/InferTypeOpInterface.h" 20 | #include "mlir/Interfaces/SideEffectInterfaces.h" 21 | #include "mlir/Support/LogicalResult.h" 22 | #include "llvm/ADT/StringRef.h" 23 | 24 | #include "HOMNVGPU/HOMNVGPUOpsDialect.h.inc" 25 | 26 | #define GET_OP_CLASSES 27 | #include "HOMNVGPU/HOMNVGPUOps.h.inc" 28 | 29 | #endif // HOM_HOMNVGPUDIALECT_H 30 | -------------------------------------------------------------------------------- /examples/torch/softmax/generate_softmax.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append( 4 | "/Users/pzzzzz/MyProjects/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir" 5 | ) 6 | 7 | import torch 8 | import torch_mlir 9 | from transformers import BertConfig, BertForMaskedLM 10 | 11 | hidden_states = torch.rand((1, 10)) 12 | 13 | 14 | class Wrapper(torch.nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | config = BertConfig().from_pretrained("bert-base-uncased") 18 | config.num_hidden_layers = 1 19 | config.hidden_size = 24 20 | self.model = BertForMaskedLM(config) 21 | 22 | def forward(self, hidden_states): 23 | return torch.nn.functional.softmax(hidden_states, dim=-1) 24 | 25 | 26 | model = Wrapper() 27 | 28 | model.eval() 29 | 30 | with torch.no_grad(): 31 | module = torch_mlir.compile( 32 | model, hidden_states, output_type="tosa", use_tracing=True 33 | ) 34 | # output = model(*encoded_input_list) 35 | with open("softmax.mlir", "w") as fl: 36 | print(module, file=fl, end="") 37 | -------------------------------------------------------------------------------- /include/InitAllPasses.h: -------------------------------------------------------------------------------- 1 | #include "Conversions/FP32toFP16/Passes.h" 2 | #include "Conversions/Function/Passes.h" 3 | #include "Conversions/HOM/Passes.h" 4 | #include "Conversions/MatMulCPUOptimize/Passes.h" 5 | #include "Conversions/Tosa/Passes.h" 6 | #include "HOM/Passes.h" 7 | #include "HOMNVGPU/Passes.h" 8 | 9 | namespace mlir { 10 | namespace hands_on_mlir { 11 | inline void registerAllPasses() { 12 | registerExtractInitFuncPass(); 13 | registerHOMFP32ToFP16Pass(); 14 | registerHOMFuncToLLVMPipelines(); 15 | registerHOMNVGPUToFuncPass(); 16 | registerHOMToFuncPass(); 17 | registerHOMToHOMNVGPUPass(); 18 | registerMatMulCPUOptimizePass(); 19 | registerUnifyLLVMFuncInterfacePass(); 20 | hom::registerHOMFusionPass(); 21 | hom::registerHOMSerializeWeightPass(); 22 | hom::registerTosaToHOMPass(); 23 | hom::registerTosaConstantFoldingPass(); 24 | hom::registerTosaToHOMPipelines(); 25 | homnvgpu::registerHOMNVGPUFusionPass(); 26 | homnvgpu::registerHOMNVGPUAutotunePass(); 27 | homnvgpu::registerHOMNVGPULegalizeGemmPass(); 28 | } 29 | } // namespace hands_on_mlir 30 | } // namespace mlir 31 | -------------------------------------------------------------------------------- /include/Conversions/Function/Passes.td: -------------------------------------------------------------------------------- 1 | include "mlir/Pass/PassBase.td" 2 | 3 | def HOMToFuncPass : Pass<"hom-to-func", "mlir::func::FuncOp"> { 4 | let summary = "Lowering HOM to Func"; 5 | let dependentDialects = ["::mlir::func::FuncDialect"]; 6 | } 7 | 8 | def 
8 | def HOMNVGPUToFuncPass : Pass<"homnvgpu-to-func", "mlir::func::FuncOp"> {
9 |   let summary = "Lowering HOMNVGPU to Func";
10 |   let dependentDialects =
11 |       ["::mlir::func::FuncDialect", "::mlir::arith::ArithDialect"];
12 | }
13 |
14 | def OptimizeMemoryPass : Pass<"hom-opti-mem", "mlir::func::FuncOp"> {
15 |   let summary = "Optimize memory usage";
16 |   let dependentDialects =
17 |       ["::mlir::func::FuncDialect", "::mlir::arith::ArithDialect"];
18 | }
19 |
20 | def ExtractInitFuncPass : Pass<"extract-init-func", "mlir::ModuleOp"> {
21 |   let summary = "Extract the init function";
22 |   let dependentDialects =
23 |       ["::mlir::arith::ArithDialect", "::mlir::func::FuncDialect"];
24 | }
25 |
26 | def UnifyLLVMFuncInterfacePass
27 |     : Pass<"unify-llvm-func-interface", "mlir::ModuleOp"> {
28 |   let summary = "Unify the LLVM func interface";
29 |   let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
30 | }
31 |
--------------------------------------------------------------------------------
/examples/mlir/utils/fill_and_print.mlir:
--------------------------------------------------------------------------------
1 | func.func @main() {
2 |   %A = memref.alloc() : memref<4x8xf32>
3 |   %i = arith.constant 0 : index
4 |   %j = arith.constant 1 : index
5 |   %M = arith.constant 4 : i64
6 |   %N = arith.constant 8 : i64
7 |   %c = arith.constant 4.0 : f32
8 |
9 |   linalg.fill ins(%c : f32) outs(%A : memref<4x8xf32>)
10 |
11 |   %B = memref.cast %A : memref<4x8xf32> to memref<*xf32>
12 |
13 |   func.call @printMemrefF32(%B) : (memref<*xf32>) -> ()
14 |   func.call @print2DMatrixF32(%B) : (memref<*xf32>) -> ()
15 |
16 |   func.call @fill2DRandomMatrixF32(%B) : (memref<*xf32>) -> ()
17 |   func.call @printMemrefF32(%B) : (memref<*xf32>) -> ()
18 |   func.call @print2DMatrixF32(%B) : (memref<*xf32>) -> ()
19 |
20 |   func.call @fill2DIncMatrixF32(%B) : (memref<*xf32>) -> ()
21 |   func.call @printMemrefF32(%B) : (memref<*xf32>) -> ()
22 |   func.call @print2DMatrixF32(%B) : (memref<*xf32>) -> ()
23 |
24 |   return
25 | }
26 |
27 | func.func private @print2DMatrixF32(memref<*xf32>)
28 | func.func private @fill2DRandomMatrixF32(memref<*xf32>)
29 | func.func private @fill2DIncMatrixF32(memref<*xf32>)
30 | func.func private @printMemrefF32(memref<*xf32>)
31 |
--------------------------------------------------------------------------------
/examples/torch/linear/run.sh:
--------------------------------------------------------------------------------
1 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-serialize-weight --hom-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface linear.mlir | \
2 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir | \
3 | ../../../thirdparty/llvm-project/build/bin/llc > linear.s
4 |
5 | clang++-18 linear.s -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-18 -std=gnu++17 -g -o liblinear.so
6 |
7 | clang++-18 run.cpp -fsanitize=address,undefined -I../../../include/ -I../../../thirdparty/llvm-project/mlir/include/ -I../../../thirdparty/llvm-project/llvm/include/ -I../../../thirdparty/llvm-project/build/include/ -L./ -L../../../build/lib/ -L../../../thirdparty/llvm-project/build/lib -lLLVM-18 -lhands_on_mlir_runner_utils -llinear -lhands_on_mlir_execution_engine -Wl,-rpath,../../../build/lib -Wl,-rpath,../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ -std=gnu++17 -o run
8 |
9 | LSAN_OPTIONS=suppressions=../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../ubsan.supp ./run
10 |
--------------------------------------------------------------------------------
/examples/torch/bert/parse_hom.py:
--------------------------------------------------------------------------------
1 | res = {"base": {}, "large": {}}
2 |
3 | with open("hom_sync.log", "r") as fl:
4 |     bs = 0
5 |     seq = 0
6 |     mode = ""
7 |     autotune = False
8 |     p = 0
9 |     for line in fl.readlines():
10 |         line = line.strip()
11 |         if "Tag" in line:
12 |             mode = "base" if "base" in line else "large"
13 |             line = line.split()
14 |             bs = int(line[2])
15 |             seq = int(line[3])
16 |             autotune = int(line[1])
17 |             if (bs, seq, autotune) not in res[mode]:
18 |                 res[mode][(bs, seq, autotune)] = "err"
19 |             p = 2
20 |             continue
21 |         elif p == 1 and "E2E" in line:
22 |             idx = line.find(":")
23 |             if res[mode][(bs, seq, autotune)] == "err":
24 |                 res[mode][(bs, seq, autotune)] = float(line[idx + 1 : -2])
25 |
26 |         p -= 1
27 |         pre = line
28 |
29 | for mode in ["base", "large"]:
30 |     for bs in [1, 8, 16, 32]:
31 |         for seq in [64, 128]:
32 |             for auto in [0, 1]:
33 |                 try:
34 |                     print(f"{mode} ({bs}, {seq}, {auto}): {res[mode][(bs, seq, auto)]}")
35 |                 except KeyError:
36 |                     print("No info")
37 |
--------------------------------------------------------------------------------
/examples/torch/linear/cuda/linear_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.realpath("../.."))
5 |
6 | import torch
7 | from benchmark import speed_test
8 |
9 | torch.manual_seed(42)
10 |
11 | for hidden in [768, 1024]:
12 |     for m in [64]:
13 |         for n in [hidden, hidden * 4]:
14 |             for k in [hidden, hidden * 4]:
15 |
16 |                 if hidden != 1024:
17 |                     continue
18 |
19 |                 if n != hidden or k != hidden * 4:
20 |                     continue
21 |
22 |                 class LinearWithoutBias(torch.nn.Module):
23 |                     def __init__(self) -> None:
24 |                         super().__init__()
25 |                         self.fc = torch.nn.Linear(k, n, bias=False)
26 |
27 |                     def forward(self, x):
28 |                         return self.fc(x)
29 |
30 |                 class LinearWithResidual(torch.nn.Module):
31 |                     def __init__(self) -> None:
32 |                         super().__init__()
33 |                         self.fc = torch.nn.Linear(100, 10, bias=False)
34 |
35 |                     def forward(self, x):
36 |                         return self.fc(x) + x
37 |
38 |                 a = LinearWithoutBias().cuda().half()
39 |
40 |                 x = torch.ones(m, k).half().cuda()
41 |
42 |                 print(m, n, k)
43 |
44 |                 speed_test(a, [x])
45 |
--------------------------------------------------------------------------------
/examples/torch/linear/cuda/run_iree.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/bash
2 |
3 | # Generated by Kimi Chat
4 |
5 | # Iterate over all iree*.mlir files in the current directory and its subdirectories
6 | find . -type f -name "iree*.mlir" | while read mlir_file; do
7 |   # Define the output file name, replacing the .mlir suffix with .vmfb
8 |   output_file="${mlir_file%.mlir}.vmfb"
9 |
10 |   # Compile the MLIR file with iree-compile and write the result to the .vmfb file.
11 |   # This assumes that iree-compile accepts the -o option to specify the output file
12 |   # and that it can handle wildcards in file names directly.
13 |   ~/stuff/iree-build/tools/iree-compile --iree-hal-target-backends=cuda --iree-hal-cuda-llvm-target-arch=sm_86 -o "$output_file" "$mlir_file"
14 |
15 |   # Check whether compilation succeeded
16 |   if [ $? -eq 0 ]; then
17 |     echo "Compilation successful: $output_file"
18 |   else
19 |     echo "Compilation failed for $mlir_file"
20 |   fi
21 |
22 |   pattern="iree_linear_([0-9]+)_([0-9]+)_([0-9]+)\.mlir"
23 |
24 |   if [[ $mlir_file =~ $pattern ]]; then
25 |     M="${BASH_REMATCH[1]}"
26 |     N="${BASH_REMATCH[2]}"
27 |     K="${BASH_REMATCH[3]}"
28 |
29 |     # Print the extracted values
30 |     echo "M: $M, N: $N, K: $K"
31 |   fi
32 |
33 |   ~/stuff/iree-build/tools/iree-benchmark-module --module=$output_file \
34 |     --iree-hal-target-backends=cuda --device=cuda://0 \
35 |     --function=forward \
36 |     --device_allocator=caching \
37 |     --input=1x${M}x${K}xf16=-1 --benchmark_repetitions=10
38 |
39 | done
40 |
41 | echo "Compilation process completed for all matching files."
42 |
--------------------------------------------------------------------------------
/include/NVGPUKernels/GemmManifest.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "NVGPUKernels/GemmRunner.h"
4 | #include <cstddef>
5 | #include <cstdint>
6 | #include <memory>
7 | #include <utility>
8 | #include <vector>
9 |
10 | namespace mlir {
11 | namespace hands_on_mlir {
12 | namespace homnvgpu_kernel {
13 |
14 | class GemmManifest;
15 |
16 | void initialize_all(GemmManifest &manifest);
17 |
18 | class GemmManifest {
19 | private:
20 |   std::vector<std::unique_ptr<GemmOperationRunnerBase>> ops_;
21 |   bool isInitialized;
22 |
23 | public:
24 |   void updateAllKernels() {
25 |     if (!isInitialized) {
26 |       init();
27 |     }
28 |   }
29 |
30 |   GemmManifest() : isInitialized(false) {}
31 |
32 |   auto &getKernel(int32_t idx) {
33 |     if (!isInitialized) {
34 |       init();
35 |     }
36 |     return ops_[idx];
37 |   };
38 |
39 |   auto &operator[](int idx) {
40 |     if (!isInitialized) {
41 |       init();
42 |     }
43 |     return ops_[idx];
44 |   }
45 |
46 |   void append(GemmOperationRunnerBase *operation) {
47 |     ops_.emplace_back(operation);
48 |   }
49 |
50 |   void reserve(size_t size) { ops_.reserve(size); }
51 |
52 |   auto size() {
53 |     if (!isInitialized) {
54 |       init();
55 |     }
56 |     return ops_.size();
57 |   }
58 |
59 |   void init() {
60 |     if (!isInitialized) {
61 |       initialize_all(*this);
62 |       isInitialized = true;
63 |     }
64 |   }
65 | };
66 |
67 | extern GemmManifest manifest;
68 |
69 | } // namespace homnvgpu_kernel
70 | } // namespace hands_on_mlir
71 | } // namespace mlir
72 |
--------------------------------------------------------------------------------
/examples/torch/linear/cuda/linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_mlir
3 |
4 | torch.manual_seed(42)
5 |
6 | hidden_state = 768
7 |
8 | with torch.no_grad():
9 |     for m in [64]:
10 |         for n in [hidden_state, hidden_state * 4]:
11 |             for k in [hidden_state, hidden_state * 4]:
12 |
13 |                 class LinearWithoutBias(torch.nn.Module):
14 |                     def __init__(self) -> None:
15 |                         super().__init__()
16 |                         self.fc = torch.nn.Linear(k, n, bias=False)
17 |
18 |                     def forward(self, x):
19 |                         return self.fc(x)
20 |
21 |                 class LinearWithResidual(torch.nn.Module):
22 |                     def __init__(self) -> None:
23 |                         super().__init__()
24 |                         self.fc = torch.nn.Linear(100, 10, bias=False)
25 |
26 |                     def forward(self, x):
27 |                         return self.fc(x) + x
28 |
29 |                 a = LinearWithoutBias()
30 |
31 |                 x = torch.ones(1, 64, k)
32 |
33 |                 module = torch_mlir.compile(a, x, output_type="tosa")
34 |                 with open(f"hom_linear_{m}_{n}_{k}.mlir", "w") as fl:
35 |                     print(module, file=fl, end="")
36 |
37 |                 a = a.half()
38 |                 x = x.half()
39 |
40 |                 module = torch_mlir.compile(a, x, output_type="tosa")
41 |                 with open(f"iree_linear_{m}_{n}_{k}.mlir", "w") as fl:
42 |                     print(module, file=fl, end="")
end="") 43 | -------------------------------------------------------------------------------- /lib/Dialect/HOM/HOMOps.cpp: -------------------------------------------------------------------------------- 1 | #include "mlir/IR/Builders.h" 2 | #include "mlir/IR/BuiltinTypes.h" 3 | #include "mlir/IR/DialectImplementation.h" 4 | #include "mlir/IR/MLIRContext.h" 5 | #include "mlir/IR/Operation.h" 6 | #include "mlir/IR/OperationSupport.h" 7 | #include "mlir/Transforms/InliningUtils.h" 8 | #include "llvm/ADT/TypeSwitch.h" 9 | #include "llvm/AsmParser/Parser.h" 10 | #include "llvm/IR/Attributes.h" 11 | #include "llvm/IR/Function.h" 12 | #include "llvm/IR/Type.h" 13 | #include "llvm/Support/SourceMgr.h" 14 | 15 | #include "HOM/HOMOps.h" 16 | 17 | using namespace mlir; 18 | using namespace hands_on_mlir::hom; 19 | 20 | #include "HOM/HOMOpsDialect.cpp.inc" 21 | 22 | void HOMDialect::initialize() { 23 | addOperations< 24 | #define GET_OP_LIST 25 | #include "HOM/HOMOps.cpp.inc" 26 | >(); 27 | } 28 | 29 | #define GET_OP_CLASSES 30 | #include "HOM/HOMOps.cpp.inc" 31 | 32 | //===----------------------------------------------------------------------===// 33 | // ConstantOp 34 | //===----------------------------------------------------------------------===// 35 | 36 | OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { 37 | assert(adaptor.getOperands().empty() && "constant has no operands"); 38 | 39 | // Return the held attribute value. 40 | return getIdxAttr(); 41 | } 42 | 43 | /// Print a `constant` op. 44 | /// 45 | /// op ::= attr-dict $value 46 | /// 47 | /// When the `value` and `output` have different type, it just uses the default 48 | /// operator assembly format as a fallback. 49 | // void ConstantOp::print(::mlir::OpAsmPrinter &p) {} 50 | -------------------------------------------------------------------------------- /lib/Conversions/HOM/HOMToHOMNVGPU.cpp: -------------------------------------------------------------------------------- 1 | #include "Conversions/HOM/Passes.h" 2 | #include "Conversions/Tosa/Passes.h" 3 | #include "HOM/HOMOps.h" 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 6 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" 7 | #include "mlir/IR/BuiltinAttributes.h" 8 | #include "mlir/IR/PatternMatch.h" 9 | #include "mlir/Parser/Parser.h" 10 | #include "mlir/Pass/PassManager.h" 11 | #include "mlir/Pass/PassRegistry.h" 12 | #include "mlir/Support/LogicalResult.h" 13 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 14 | 15 | #define PASS_NAME "hom-to-homnvgpu" 16 | #define DEBUG_TYPE PASS_NAME 17 | 18 | namespace mlir { 19 | namespace hands_on_mlir { 20 | 21 | #define GEN_PASS_DEF_HOMTOHOMNVGPUPASS 22 | #include "Conversions/HOM/Passes.h.inc" 23 | 24 | #include "Conversions/HOM/HOMToHOMNVGPU.pdll.h.inc" 25 | 26 | namespace { 27 | struct HOMToHOMNVGPUPass : impl::HOMToHOMNVGPUPassBase { 28 | void runOnOperation() final; 29 | 30 | LogicalResult initialize(MLIRContext *ctx) override; 31 | 32 | private: 33 | FrozenRewritePatternSet patterns; 34 | }; 35 | 36 | LogicalResult HOMToHOMNVGPUPass::initialize(MLIRContext *ctx) { 37 | RewritePatternSet patternList(ctx); 38 | 39 | populateGeneratedPDLLPatterns(patternList); 40 | patterns = std::move(patternList); 41 | return success(); 42 | } 43 | 44 | void HOMToHOMNVGPUPass::runOnOperation() { 45 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 46 | } 47 | } // namespace 48 | 49 | } // namespace hands_on_mlir 50 | } // namespace mlir 51 | 
-------------------------------------------------------------------------------- /examples/torch/layernorm/cuda/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | ../../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-serialize-weight --hom-to-homnvgpu --homnvgpu-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface layernorm.mlir |\ 4 | ../../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 5 | ../../../../thirdparty/llvm-project/build/bin/llc > layernorm_nvgpu.s 6 | 7 | clang++-17 layernorm_nvgpu.s -fPIC -shared -L../../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o liblayernorm_nvgpu.so 8 | 9 | clang++-17 layernorm.cu -g -debug -fsanitize=address,undefined -I../../../../include/ -I../../../../thirdparty/llvm-project/mlir/include/ -I../../../../thirdparty/llvm-project/llvm/include/ -I../../../../thirdparty/cutlass/include/ -I../../../../thirdparty/TransformerEngine/transformer_engine/common/include -I../../../../thirdparty/llvm-project/build/include/ -L./ -L../../../../build/lib/ -L../../../../thirdparty/llvm-project/build/lib -lLLVM-17 -lhands_on_mlir_runner_utils -llayernorm_nvgpu -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -L$CUDA_HOME/lib64 \ 10 | -lcudart_static -Wl,-rpath,../../../../build/lib -Wl,-rpath,../../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ --cuda-gpu-arch=sm_89 -std=gnu++17 -o run 11 | 12 | LSAN_OPTIONS=suppressions=../../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0 ./run 13 | -------------------------------------------------------------------------------- /examples/torch/bert_attention/run_cuseqlen.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-serialize-weight --hom-to-homnvgpu --homnvgpu-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface cu_seqlen.mlir |\ 4 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 5 | ../../../thirdparty/llvm-project/build/bin/llc > cuseqlen_nvgpu.s 6 | 7 | clang++-17 cuseqlen_nvgpu.s -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libcuseqlen_nvgpu.so 8 | 9 | clang++-17 cuSeqLen.cu -g -debug -fsanitize=address,undefined -I../../../include/ -I../../../thirdparty/llvm-project/mlir/include/ -I../../../thirdparty/TransformerEngine/transformer_engine/common/include -I../../../thirdparty/llvm-project/llvm/include/ -I../../../thirdparty/cutlass/include/ -I../../../thirdparty/llvm-project/build/include/ -L./ -L../../../build/lib/ -L../../../thirdparty/llvm-project/build/lib -L../../../thirdparty/TransformerEngine -lLLVM-17 -lhands_on_mlir_runner_utils -lcuseqlen_nvgpu -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -ltransformer_engine -L$CUDA_HOME/lib64 \ 10 | -lcudart_static -Wl,-rpath,../../../build/lib -Wl,-rpath,../../../thirdparty/TransformerEngine -Wl,-rpath,../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ --cuda-gpu-arch=sm_89 -std=gnu++17 -o run 11 | 12 | 
LSAN_OPTIONS=suppressions=../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0 ./run 13 | -------------------------------------------------------------------------------- /lib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Dialect) 2 | add_subdirectory(Conversions) 3 | add_subdirectory(WeightsEngine) 4 | 5 | add_mlir_library( 6 | static_mlir_async_runtime 7 | STATIC 8 | ${LLVM_MLIR_SOURCE_DIR}/lib/ExecutionEngine/AsyncRuntime.cpp 9 | EXCLUDE_FROM_LIBMLIR 10 | LINK_LIBS 11 | PUBLIC 12 | ${LLVM_PTHREAD_LIB}) 13 | 14 | target_compile_definitions(static_mlir_async_runtime 15 | PRIVATE MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS) 16 | 17 | set(LLVM_OPTIONAL_SOURCES ExecutionEngine/HandsOnRunnerUtils.cpp 18 | ExecutionEngine/ExecutionEngine.cpp) 19 | 20 | add_mlir_library(hands_on_mlir_runner_utils SHARED 21 | ExecutionEngine/HandsOnRunnerUtils.cpp) 22 | 23 | add_mlir_library( 24 | hands_on_mlir_execution_engine 25 | SHARED 26 | ExecutionEngine/ExecutionEngine.cpp 27 | LINK_COMPONENTS 28 | Core 29 | Support 30 | LINK_LIBS 31 | PUBLIC 32 | MLIRIR 33 | MLIRPass 34 | MLIRTransforms 35 | WeightsEngine) 36 | 37 | target_compile_definitions(hands_on_mlir_runner_utils 38 | PRIVATE hands_on_mlir_runner_utils_EXPORTS) 39 | 40 | if(ENABLE_CUDA) 41 | add_mlir_library( 42 | hands_on_mlir_nvgpu_runner_utils 43 | SHARED 44 | ExecutionEngine/HandsOnNVGPURunnerUtils.cu 45 | LINK_LIBS 46 | WeightsEngine 47 | nvToolsExt 48 | cublas 49 | dl 50 | GemmManifestAndProfiler) 51 | target_compile_options(hands_on_mlir_nvgpu_runner_utils PRIVATE -fexceptions 52 | -lcublas) 53 | target_link_options(hands_on_mlir_nvgpu_runner_utils PRIVATE -lcublas) 54 | add_subdirectory(NVGPUKernels) 55 | endif() 56 | -------------------------------------------------------------------------------- /examples/torch/elementwise/run_add.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/wsl/lib:/home/shared_folder/cudnn-linux-x86_64-9.0.0.312_cuda12-archive/lib 4 | 5 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-to-homnvgpu --homnvgpu-fusion --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface add.mlir |\ 6 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 7 | ../../../thirdparty/llvm-project/build/bin/llc > add_nvgpu.s 8 | 9 | clang++-17 add_nvgpu.s -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libadd_nvgpu.so 10 | 11 | clang++-17 add.cu -g -debug -fsanitize=address,undefined -I../../../include/ -I../../../thirdparty/llvm-project/mlir/include/ -I../../../thirdparty/llvm-project/llvm/include/ -I../../../thirdparty/cutlass/include/ -I../../../thirdparty/TransformerEngine/transformer_engine/common/include -I../../../thirdparty/llvm-project/build/include/ -L./ -L../../../build/lib/ -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -lhands_on_mlir_runner_utils -ladd_nvgpu -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -L$CUDA_HOME/lib64 \ 12 | -lcudart_static -Wl,-rpath,../../../build/lib 
-Wl,-rpath,../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ --cuda-gpu-arch=sm_86 -std=gnu++17 -o run 13 | 14 | LSAN_OPTIONS=suppressions=../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0 ./run 15 |
-------------------------------------------------------------------------------- /include/NVGPUKernels/GemmProfiler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "driver_types.h" 5 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 6 | #include <cstddef> 7 | #include <cstdint> 8 | #include <functional> 9 | #include <limits> 10 | #include <map> 11 | #include <memory> 12 | #include <tuple> 13 | #include <utility> 14 | #include <vector> 15 | 16 | namespace mlir { 17 | namespace hands_on_mlir { 18 | namespace homnvgpu_kernel { 19 | 20 | class GemmProfiler { 21 | 22 | C_UnrankedMemRefType a, b, c, tb; 23 | 24 | float alpha_, beta_; 25 | 26 | int64_t M_, N_, K_, activation_; 27 | 28 | cudaEvent_t start_, stop_; 29 | 30 | static void updateShape(C_UnrankedMemRefType &A, int64_t m, int64_t n, 31 | int64_t k); 32 | 33 | std::vector<int64_t> splitKFactor; 34 | 35 | float 36 | profileHelper(std::function<void()> runFn, const char *kernelName, 37 | float previousBestTime = std::numeric_limits<float>::max()); 38 | 39 | // Best (kernel index, split-K factor) cached per (M, N, K, activation). 40 | std::map<std::tuple<int64_t, int64_t, int64_t, int64_t>, 41 | std::tuple<int64_t, int64_t>> timingCache; 42 | 43 | void updateSplitKFactor(int32_t K); 44 | 45 | public: 46 | GemmProfiler() = delete; 47 | 48 | GemmProfiler(int64_t M, int64_t N, int64_t K, int64_t activation, float alpha, 49 | float beta); 50 | ~GemmProfiler(); 51 | 52 | std::tuple<int64_t, int64_t> profile(); 53 | 54 | std::tuple<int64_t, int64_t> profile(int64_t M, int64_t N, int64_t K,
 55 | int64_t activation, float alpha, 56 | float beta); 57 | }; 58 | 59 | } // namespace homnvgpu_kernel 60 | 61 | } // namespace hands_on_mlir 62 | } // namespace mlir 63 |
-------------------------------------------------------------------------------- /examples/torch/linear/cuda/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | ../../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-to-homnvgpu --homnvgpu-fusion --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface linear.mlir |\ 4 | ../../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 5 | ../../../../thirdparty/llvm-project/build/bin/llc > linear_nvgpu.s 6 | 7 | clang++-18 linear_nvgpu.s -fPIC -shared -L../../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../../thirdparty/llvm-project/build/lib -lLLVM-18 -std=gnu++17 -g -o liblinear_nvgpu.so 8 | 9 | clang++-18 run.cu -g -debug -fsanitize=address,undefined -I../../../../include/ -I../../../../thirdparty/llvm-project/mlir/include/ -I../../../../thirdparty/llvm-project/llvm/include/ -I../../../../thirdparty/cutlass/include/ -I../../../../thirdparty/llvm-project/build/include/ -I../../../../thirdparty/TransformerEngine/transformer_engine/common/include -L../../../../thirdparty/TransformerEngine/ -L./ -L../../../../build/lib/ -L../../../../thirdparty/llvm-project/build/lib -lLLVM-18 -lhands_on_mlir_runner_utils -ltransformer_engine -llinear_nvgpu -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -L$CUDA_HOME/lib64 \ 10 | -lcudart_static -Wl,-rpath,../../../../build/lib -Wl,-rpath,../../../../thirdparty/llvm-project/build/lib
-Wl,-rpath,../../../../thirdparty/TransformerEngine/ -Wl,-rpath,./ --cuda-gpu-arch=sm_89 -std=gnu++17 -o run 11 | 12 | LSAN_OPTIONS=suppressions=../../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0,detect_odr_violation=0 ./run 13 |
-------------------------------------------------------------------------------- /examples/torch/bert_attention/bert_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_mlir 3 | from transformers import BertConfig, BertForMaskedLM 4 | from transformers.models.bert.modeling_bert import BertAttention 5 | 6 | torch.manual_seed(42) 7 | 8 | hs = 768 9 | 10 | encoded_input_list = [ 11 | torch.rand((1, 64, hs)), 12 | torch.ones((1, 64), dtype=torch.int64), 13 | ] 14 | 15 | 16 | class BertAttentionWrapper(torch.nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | config = BertConfig().from_pretrained("bert-base-uncased") 20 | self.attn = BertAttention(config) 21 | self.attn.self.query.bias.data.zero_() 22 | self.attn.self.key.bias.data.zero_() 23 | self.attn.self.value.bias.data.zero_() 24 | self.attn.output.dense.bias.data.zero_() 25 | self.model = BertForMaskedLM(config) 26 | 27 | def forward(self, hidden, mask): 28 | mask = mask[:, None, None, :] 29 | mask = mask.to(torch.float32) 30 | mask = (1.0 - mask) * torch.finfo(torch.float32).min 31 | return self.model.bert.encoder(hidden, mask).last_hidden_state 32 | 33 | 34 | model = BertAttentionWrapper() 35 | model.eval() 36 | 37 | with open("0.txt", "w") as fl: 38 | for i in encoded_input_list[0].view(-1): 39 | print(float(i), file=fl) 40 | thing = model(*encoded_input_list) 41 | 42 | print(thing) 43 | 44 | with open("1.txt", "w") as fl: 45 | for i in thing.view(-1): 46 | print(float(i), file=fl) 47 | 48 | with torch.no_grad(): 49 | module = torch_mlir.compile( 50 | model, encoded_input_list, output_type="TOSA", use_tracing=True 51 | ) 52 | # output = model(*encoded_input_list) 53 | with open("bert_attn.mlir", "w") as fl: 54 | print(module, file=fl, end="") 55 |
-------------------------------------------------------------------------------- /include/Dialect/HOMNVGPU/HOMNVGPULegalizeGemm.pdll: -------------------------------------------------------------------------------- 1 | #include "HOMNVGPU/HOMNVGPUOps.td" 2 | #include "mlir/Dialect/Tosa/IR/TosaOps.td" 3 | 4 | Rewrite generateTranspose(op0 : Op); 5 | Constraint isF16(op 6 | : Op)[{ 7 | if (auto tp = dyn_cast<ShapedType>(op->getResult(0).getType())) { 8 | return success(tp.getElementType().isF16()); 9 | } 10 | return failure(); 11 | }]; 12 | 13 | Pattern { 14 | let transa = attr<"0 : i1">; 15 | let transb = attr<"0 : i1">; 16 | let kernel = attr<"0 : i32">; 17 | let matmul = op<homnvgpu.matmul>(input0 18 | : Value, input1 19 | : Value, input2 20 | : Value){transa = transa, transb = transb, 21 | kernel_name = kernel}; 22 | isF16(matmul); 23 | 24 | rewrite matmul with { generateTranspose(matmul); }; 25 | } 26 | 27 | // Currently we do not support transpose for cutlass. 28 | Pattern { 29 | let transa = attr<"0 : i1">; 30 | let transb = attr<"0 : i1">; 31 | let trans = attr<"1 : i1">; 32 | let kernel = attr<"0 : i32">; 33 | let transpose = op<tosa.transpose>(input1 : Value, perm : Value); 34 | let matmul = op<homnvgpu.matmul>(input0 35 | : Value, transpose, input2 36 | : Value){ 37 | transa = transa, 38 | transb = transb, 39 | act = act : Attr, 40 | alpha = alpha : Attr, 41 | beta = beta : Attr 42 | }; 43 | isF16(matmul); 44 | 45 | rewrite matmul with { 46 | 47 | replace matmul with op<homnvgpu.matmul>(input0,
input1, input2){ 48 | transa = transa, transb = trans, kernel_name = kernel, 49 | alpha = alpha, beta = beta, act = act}; 50 | }; 51 | } 52 |
-------------------------------------------------------------------------------- /examples/torch/linear/run.cpp: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 3 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 4 | #include "llvm/Support/Error.h" 5 | #include <cstdlib> 6 | #include <iostream> 7 | 8 | struct Res { 9 | C_UnrankedMemRefType a; 10 | }; 11 | 12 | #define RowMajor(A, i, j, k) \ 13 | ((A).data[(i) * (A).strides[0] + (j) * (A).strides[1] + (k) * (A).strides[2]]) 14 | 15 | int main() { 16 | C_UnrankedMemRefType a; 17 | 18 | a.rank = 3; 19 | 20 | a.descriptor = malloc(sizeof(StridedMemRefType<float, 3>)); 21 | auto des = static_cast<StridedMemRefType<float, 3> *>(a.descriptor); 22 | des->data = new float[3 * 100]; 23 | des->basePtr = des->data; 24 | des->sizes[0] = 1; 25 | des->sizes[1] = 3; 26 | des->sizes[2] = 100; 27 | des->strides[0] = 300; 28 | des->strides[1] = 100; 29 | des->strides[2] = 1; 30 | for (int i = 0; i < 300; i++) { 31 | des->data[i] = 1; 32 | } 33 | 34 | Res b; 35 | mlir::hands_on_mlir::ExecutionEngine e("liblinear.so"); 36 | 37 | auto res = e.invoke("forward", a.rank, a.descriptor, 38 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 39 | if (res) { 40 | llvm::handleAllErrors(std::move(res)); 41 | } 42 | auto c = DynamicMemRefType<float>(b.a); 43 | std::cout << c.rank << std::endl; 44 | for (int i = 0; i < c.sizes[0]; i++) { 45 | for (int j = 0; j < c.sizes[1]; j++) { 46 | for (int k = 0; k < c.sizes[2]; k++) { 47 | std::cout << RowMajor(c, i, j, k) << " "; 48 | } 49 | std::cout << std::endl; 50 | } 51 | std::cout << std::endl; 52 | } 53 | 54 | delete[] des->data; 55 | delete[] c.data; 56 | 57 | free(a.descriptor); 58 | free(b.a.descriptor); 59 | } 60 |
-------------------------------------------------------------------------------- /examples/torch/bert/compile.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/wsl/lib:/home/shared_folder/cudnn-linux-x86_64-9.0.0.312_cuda12-archive/lib 4 | 5 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-fp32-to-fp16 --hom-to-homnvgpu --homnvgpu-fusion bert.mlir > pre_tune.mlir 6 | 7 | # ../../../build/bin/hands-on-opt --homnvgpu-legalize-gemm --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --hom-func-to-llvm-pipeline pre_tune.mlir | \ 8 | # ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 9 | # ../../../thirdparty/llvm-project/build/bin/llc > bert_nvgpu.s 10 | 11 | ../../../build/bin/hands-on-opt --homnvgpu-autotune --homnvgpu-legalize-gemm --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --hom-func-to-llvm-pipeline pre_tune.mlir | \ 12 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 13 | ../../../thirdparty/llvm-project/build/bin/llc > bert_autotune_nvgpu.s 14 | 15 | clang++-17 bert_nvgpu.s -O3 -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libbert_nvgpu.so 16 | clang++-17 bert_autotune_nvgpu.s -O3 -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine
-lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libbert_autotune_nvgpu.so 17 | 18 | # CUDNN_FRONTEND_LOG_INFO=1 CUDNN_FRONTEND_LOG_FILE=stderr CUDNN_LOGERR_DBG=3 CUDNN_LOGDEST_DBG=stderr \ 19 | # LSAN_OPTIONS=suppressions=../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0,detect_odr_violation=0 ./run 20 |
-------------------------------------------------------------------------------- /examples/torch/bert_attention/bert_self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_mlir 3 | from transformers import BertConfig 4 | from transformers.models.bert.modeling_bert import BertSelfAttention 5 | 6 | torch.manual_seed(42) 7 | 8 | encoded_input_list = [ 9 | torch.rand((1, 64, 128)), 10 | torch.ones((1, 64), dtype=torch.int64), 11 | ] 12 | 13 | 14 | class BertAttentionWrapper(torch.nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | config = BertConfig().from_pretrained("bert-base-uncased") 18 | config.num_hidden_layers = 2 19 | config.hidden_size = 128 20 | config.num_attention_heads = 2 21 | self.attn = BertSelfAttention(config) 22 | self.attn.query.bias.data.zero_() 23 | self.attn.key.bias.data.zero_() 24 | self.attn.value.bias.data.zero_() 25 | self.linear = torch.nn.Linear(128, 128, bias=False) 26 | self.ln = torch.nn.LayerNorm(128) 27 | 28 | def forward(self, hidden, mask): 29 | mask = mask[:, None, None, :] 30 | mask = mask.to(torch.float32) 31 | mask = (1.0 - mask) * torch.finfo(torch.float32).min 32 | return self.attn(hidden, mask)[0] 33 | 34 | 35 | model = BertAttentionWrapper() 36 | model.eval() 37 | 38 | with open("0.txt", "w") as fl: 39 | for i in encoded_input_list[0].view(-1): 40 | print(float(i), file=fl) 41 | thing = model(*encoded_input_list) 42 | 43 | print(thing) 44 | 45 | with open("1.txt", "w") as fl: 46 | for i in thing.view(-1): 47 | print(float(i), file=fl) 48 | 49 | with torch.no_grad(): 50 | module = torch_mlir.compile( 51 | model, encoded_input_list, output_type="TOSA", use_tracing=True 52 | ) 53 | # output = model(*encoded_input_list) 54 | with open("bert_self_attn.mlir", "w") as fl: 55 | print(module, file=fl, end="") 56 |
-------------------------------------------------------------------------------- /examples/torch/bert/run_iree.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/bash 2 | 3 | # Generated by Kimi Chat 4 | 5 | # Iterate over all .mlir files in the current directory and its subdirectories 6 | find . -type f -name "iree*.mlir" | while read mlir_file; do 7 | # Derive the output file name by replacing the .mlir suffix with .vmfb 8 | output_file="${mlir_file%.mlir}.vmfb" 9 | 10 | # Compile the MLIR file with iree-compile and write the result to the .vmfb file 11 | # This assumes that iree-compile accepts the -o option to specify the output file 12 | # and that it can handle wildcards in file names directly 13 | ~/stuff/iree-build/tools/iree-compile --iree-hal-target-backends=cuda --iree-opt-demote-f32-to-f16 --iree-hal-cuda-llvm-target-arch=sm_86 -o "$output_file" "$mlir_file" 14 | 15 | # Check whether the compilation succeeded 16 | if [ $?
-eq 0 ]; then 17 | echo "Compilation successful: $output_file" 18 | else 19 | echo "Compilation failed for $mlir_file" 20 | fi 21 | 22 | pattern="iree_.*_([0-9]+)_([0-9]+)\.mlir" 23 | 24 | if [[ $mlir_file =~ $pattern ]]; then 25 | bs="${BASH_REMATCH[1]}" 26 | seq_len="${BASH_REMATCH[2]}" 27 | 28 | # Print the extracted values 29 | echo "bs: $bs, seq_len: $seq_len" 30 | fi 31 | 32 | ~/stuff/iree-build/tools/iree-benchmark-module --module=$output_file \ 33 | --iree-hal-target-backends=cuda --device=cuda://0 \ 34 | --function=forward \ 35 | --device_allocator=caching \ 36 | --input=${bs}x${seq_len}xi64=1 --input=${bs}x${seq_len}xi64=1 --input=${bs}x${seq_len}xi64=1 --input="&${bs}x${seq_len}x30522xf16" --benchmark_repetitions=10 37 | 38 | nsys profile -o $output_file ~/stuff/iree-build/tools/iree-benchmark-module --module=$output_file \ 39 | --iree-hal-target-backends=cuda --device=cuda://0 \ 40 | --function=forward \ 41 | --device_allocator=caching \ 42 | --input=${bs}x${seq_len}xi64=1 --input=${bs}x${seq_len}xi64=1 --input=${bs}x${seq_len}xi64=1 --input="&${bs}x${seq_len}x30522xf16" --benchmark_repetitions=2 43 | 44 | done 45 | 46 | echo "Compilation process completed for all matching files." 47 |
-------------------------------------------------------------------------------- /include/NVGPUKernels/OperationRunner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cutlass/coord.h" 4 | #include "cutlass/gemm_coord.h" 5 | #include "cutlass/library/types.h" 6 | 7 | namespace mlir { 8 | namespace hands_on_mlir { 9 | 10 | class OperationRunner { 11 | public: 12 | virtual void poly() {} 13 | }; 14 | 15 | struct MathInstructionDescription { 16 | 17 | using NumericTypeID = cutlass::library::NumericTypeID; 18 | using OpcodeClassID = cutlass::library::OpcodeClassID; 19 | using MathOperationID = cutlass::library::MathOperationID; 20 | 21 | /// Shape of the target math instruction 22 | cutlass::gemm::GemmCoord instruction_shape; 23 | 24 | /// Describes the data type of the internal accumulator 25 | NumericTypeID element_accumulator; 26 | 27 | /// Classification of math instruction 28 | OpcodeClassID opcode_class; 29 | 30 | /// Type of math operation performed 31 | MathOperationID math_operation; 32 | 33 | // 34 | // Methods 35 | // 36 | 37 | MathInstructionDescription( 38 | cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), 39 | NumericTypeID element_accumulator = NumericTypeID::kInvalid, 40 | OpcodeClassID opcode_class = OpcodeClassID::kInvalid, 41 | MathOperationID math_operation = MathOperationID::kMultiplyAdd) 42 | : instruction_shape(instruction_shape), 43 | element_accumulator(element_accumulator), opcode_class(opcode_class), 44 | math_operation(math_operation) {} 45 | 46 | // Equality operator 47 | inline bool operator==(MathInstructionDescription const &rhs) const { 48 | return ((instruction_shape == rhs.instruction_shape) && 49 | (element_accumulator == rhs.element_accumulator) && 50 | (opcode_class == rhs.opcode_class) && 51 | (math_operation == rhs.math_operation)); 52 | } 53 | 54 | // Inequality operator 55 | inline bool operator!=(MathInstructionDescription const &rhs) const { 56 | return !(*this == rhs); 57 | } 58 | }; 59 | 60 | } // namespace hands_on_mlir 61 | } // namespace mlir 62 |
-------------------------------------------------------------------------------- /examples/torch/bert_attention/run_bert_self_attn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash
2 | 3 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/wsl/lib:/home/shared_folder/cudnn-linux-x86_64-9.0.0.312_cuda12-archive/lib 4 | 5 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-fp32-to-fp16 --hom-to-homnvgpu --homnvgpu-fusion --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface bert_self_attn.mlir |\ 6 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 7 | ../../../thirdparty/llvm-project/build/bin/llc > bert_self_attn_nvgpu.s 8 | 9 | clang++-17 bert_self_attn_nvgpu.s -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libbert_self_attn_nvgpu.so 10 | 11 | clang++-17 bert_self_attn.cu -g -debug -fsanitize=address,undefined -I../../../include/ -I../../../thirdparty/llvm-project/mlir/include/ -I../../../thirdparty/TransformerEngine/transformer_engine/common/include -I../../../thirdparty/llvm-project/llvm/include/ -I../../../thirdparty/cutlass/include/ -I../../../thirdparty/llvm-project/build/include/ -L./ -L../../../build/lib/ -L../../../thirdparty/llvm-project/build/lib -L../../../thirdparty/TransformerEngine -lLLVM-17 -lhands_on_mlir_runner_utils -lbert_self_attn_nvgpu -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -ltransformer_engine -L$CUDA_HOME/lib64 \ 12 | -lcudart_static -Wl,-rpath,../../../build/lib -Wl,-rpath,../../../thirdparty/TransformerEngine -Wl,-rpath,../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ --cuda-gpu-arch=sm_89 -std=gnu++17 -o run 13 | 14 | CUDNN_FRONTEND_LOG_INFO=1 CUDNN_FRONTEND_LOG_FILE=stderr CUDNN_LOGERR_DBG=3 CUDNN_LOGDEST_DBG=stderr LSAN_OPTIONS=suppressions=../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0,detect_odr_violation=0 ./run 15 |
-------------------------------------------------------------------------------- /examples/torch/linear/cuda/run_fp16.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/Utils.h" 5 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 6 | #include "llvm/Support/Error.h" 7 | #include <cstdlib> 8 | #include <iostream> 9 | #include <string> 10 | 11 | struct Res { 12 | UnrankedMemRefType<half> a; 13 | }; 14 | 15 | #define RowMajor(A, des, i, j, k) \ 16 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1] + \ 17 | (k) * (des).strides[2]]) 18 | 19 | int main(int argc, char *argv[]) { 20 | 21 | auto m = std::atoi(argv[1]), n = std::atoi(argv[2]), k = std::atoi(argv[3]); 22 | 23 | auto a = allocHelper<half, 3>({1, m, k}, nvgpuAllocer); 24 | 25 | Res b; 26 | 27 | std::string filename = "libhom_linear_" + std::to_string(m) + "_" + 28 | std::to_string(n) + "_" + std::to_string(k) + ".so"; 29 | 30 | mlir::hands_on_mlir::ExecutionEngine e(filename); 31 | cudaEvent_t start, stop; 32 | checkCudaErrors(cudaEventCreate(&start)); 33 | checkCudaErrors(cudaEventCreate(&stop)); 34 | 35 | auto res = e.invoke("forward", a.rank, a.descriptor, 36 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 37 | if (res) { 38 | llvm::handleAllErrors(std::move(res)); 39 | } 40 | 41 | checkCudaErrors(cudaEventRecord(start)); 42 | 43 |
for (int i = 0; i < 1000; i++) { 44 | res = e.invoke("forward", a.rank, a.descriptor, 45 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 46 | if (res) { 47 | llvm::handleAllErrors(std::move(res)); 48 | } 49 | } 50 | 51 | checkCudaErrors(cudaEventRecord(stop)); 52 | checkCudaErrors(cudaEventSynchronize(stop)); 53 | 54 | float msecTotal = 0; 55 | checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop)); 56 | 57 | std::cout << "E2E latency: " << msecTotal / 1000.0 << "ms" << std::endl; 58 | 59 | free(a.descriptor); 60 | free(b.a.descriptor); 61 | } 62 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | repos: 3 | - repo: https://github.com/pycqa/isort 4 | rev: 5.13.2 5 | hooks: 6 | - id: isort 7 | args: ["--profile", "black"] 8 | - repo: https://github.com/pre-commit/mirrors-clang-format 9 | rev: v17.0.6 10 | hooks: 11 | - id: clang-format 12 | alias: clang-format-pdll-td 13 | types_or: [file] 14 | args: [--style=LLVM] 15 | files: ^.*\.(pdll|td)$ 16 | - repo: https://github.com/pre-commit/mirrors-clang-format 17 | rev: v17.0.6 18 | hooks: 19 | - id: clang-format 20 | types_or: [c++, c, cuda] 21 | args: [--style=LLVM] 22 | exclude: | 23 | (?x)^(.*cubin.cpp$ | .*fmha_cubin.h)$ 24 | - repo: https://github.com/pre-commit/pre-commit-hooks 25 | rev: v4.5.0 26 | hooks: 27 | - id: check-added-large-files 28 | exclude: | 29 | (?x)^(.*cubin.cpp)$ 30 | - id: check-merge-conflict 31 | - id: check-symlinks 32 | - id: detect-private-key 33 | - id: end-of-file-fixer 34 | - id: check-yaml 35 | - id: trailing-whitespace 36 | - repo: https://github.com/psf/black 37 | rev: 24.1.1 38 | hooks: 39 | - id: black 40 | name: black 41 | description: "Black: The uncompromising Python code formatter" 42 | entry: black 43 | language: python 44 | minimum_pre_commit_version: 2.9.2 45 | require_serial: true 46 | types_or: [python, pyi] 47 | - repo: https://github.com/PyCQA/autoflake 48 | rev: v1.6.1 49 | hooks: 50 | - id: autoflake 51 | args: 52 | [ 53 | "--in-place", 54 | "--remove-all-unused-imports", 55 | "--remove-unused-variables", 56 | ] 57 | - repo: https://github.com/cheshirekow/cmake-format-precommit 58 | rev: v0.6.10 59 | hooks: 60 | - id: cmake-format 61 | - repo: https://github.com/codespell-project/codespell 62 | rev: v2.2.4 63 | hooks: 64 | - id: codespell 65 | args: 66 | - --skip=.git,thirdparty* 67 | - -L te 68 | -------------------------------------------------------------------------------- /examples/torch/bert_attention/run_bert_attn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/wsl/lib:/home/shared_folder/cudnn-linux-x86_64-9.0.0.312_cuda12-archive/lib 4 | 5 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-fp32-to-fp16 --hom-to-homnvgpu --homnvgpu-fusion --homnvgpu-legalize-gemm --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --extract-init-func -convert-func-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm -unify-llvm-func-interface bert_attn.mlir |\ 6 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 7 | ../../../thirdparty/llvm-project/build/bin/llc > bert_attn_nvgpu.s 8 | 9 | clang++-17 bert_attn_nvgpu.s -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine 
-lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libbert_attn_nvgpu.so 10 | 11 | clang++-17 bert_attn.cu -g -debug -fsanitize=address,undefined -I../../../include/ -I../../../thirdparty/llvm-project/mlir/include/ -I../../../thirdparty/cutlass/tools/library/include/ -I../../../thirdparty/TransformerEngine/transformer_engine/common/include -I../../../thirdparty/llvm-project/llvm/include/ -I../../../thirdparty/cutlass/include/ -I../../../thirdparty/llvm-project/build/include/ -L./ -L../../../build/lib/ -L../../../thirdparty/llvm-project/build/lib -L../../../thirdparty/TransformerEngine -lLLVM-17 -lhands_on_mlir_runner_utils -lbert_attn_nvgpu -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -ltransformer_engine -L$CUDA_HOME/lib64 \ 12 | -lcudart_static -Wl,-rpath,../../../build/lib -Wl,-rpath,../../../thirdparty/TransformerEngine -Wl,-rpath,../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ --cuda-gpu-arch=sm_89 -std=gnu++17 -o run 13 | 14 | CUDNN_FRONTEND_LOG_INFO=1 CUDNN_FRONTEND_LOG_FILE=stderr CUDNN_LOGERR_DBG=3 CUDNN_LOGDEST_DBG=stderr LSAN_OPTIONS=suppressions=../../../lsan.supp UBSAN_OPTIONS=suppressions=../../../ubsan.supp ASAN_OPTIONS=protect_shadow_gap=0,detect_odr_violation=0 ./run 15 |
-------------------------------------------------------------------------------- /include/Conversions/HOM/HOMToHOMNVGPU.pdll: -------------------------------------------------------------------------------- 1 | #include "HOM/HOMOps.td" 2 | #include "HOMNVGPU/HOMNVGPUOps.td" 3 | 4 | Pattern { 5 | let root = op<hom.matmul>(input0 : Value, input1 : Value, input2 : Value); 6 | 7 | replace root with op<homnvgpu.matmul>(input0, input1, input2){ 8 | alpha = attr<"1.0 : f32">, beta = attr<"1.0 : f32">, 9 | act = attr<"0 : i32">}; 10 | } 11 | 12 | Pattern { 13 | let root = op<hom.matmul>(input0 14 | : Value, input1 15 | : Value) 16 | ->(resultType 17 | : Type); 18 | 19 | rewrite root with { 20 | let dummy_tensor = op<hom.dummy_tensor>->(resultType); 21 | replace root with op<homnvgpu.matmul>(input0, input1, dummy_tensor){ 22 | alpha = attr<"1.0 : f32">, beta = attr<"0.0 : f32">, 23 | act = attr<"0 : i32">}; 24 | }; 25 | } 26 | 27 | Pattern { 28 | let act = attr<"0 : i32">; 29 | let matmul = op<homnvgpu.matmul>( 30 | input0 31 | : Value, input1 32 | : Value, input2 33 | : Value){act = act, alpha = A : Attr, beta = B : Attr}; 34 | let root = op<hom.relu>(matmul); 35 | 36 | rewrite root with { 37 | replace root with op<homnvgpu.matmul>(input0, input1, input2){ 38 | act = attr<"1 : i32">, alpha = A, beta = B}; 39 | erase matmul; 40 | }; 41 | } 42 | 43 | Pattern { 44 | let root = op<hom.layernorm>(input 45 | : Value){axis = axis : Attr, eps = eps : Attr}; 46 | 47 | replace root with op<homnvgpu.layernorm>(input){axis = axis, eps = eps}; 48 | } 49 | 50 | Pattern { 51 | let root = op<hom.bert_mha>( 52 | qkv 53 | : Value, mask 54 | : Value){scale = scale : Attr, head_num = head_num : Attr}; 55 | 56 | replace root with op<homnvgpu.bert_mha>(qkv, mask){scale = scale, 57 | head_num = head_num}; 58 | } 59 | 60 | Pattern { 61 | let root = op<hom.add>(a : Value, b : Value); 62 | 63 | replace root with op<homnvgpu.add>(a, b); 64 | } 65 | 66 | Pattern { 67 | let root = op<hom.gather>(a : Value, b : Value); 68 | 69 | replace root with op<homnvgpu.gather>(a, b); 70 | } 71 |
-------------------------------------------------------------------------------- /lib/WeightsEngine/WeightsEngine.cpp: -------------------------------------------------------------------------------- 1 | #include "WeightsEngine/WeightsEngine.h" 2 | #include "WeightsEngine/Utils.h" 3 | #include "half.h" 4 | #include <cstddef> 5 | #include <filesystem> 6 | #include <memory> 7 | #include <string> 8 | #include <system_error> 9 | namespace mlir { 10 | namespace
hands_on_mlir { 11 | 12 | template <typename T> 13 | static void printNativeElement(const T &element, llvm::raw_ostream &out) { 14 | out << element; 15 | } 16 | 17 | template <> 18 | void printNativeElement(const fp16 &element, llvm::raw_ostream &out) { 19 | out << float(element); 20 | } 21 | 22 | template <typename T> 23 | void WeightsEngine::serializeWeightToDisk(const ShapedType &shape, T *data, 24 | const std::string &fileName) { 25 | auto dimSize = shape.getShape(); 26 | std::error_code EC; 27 | llvm::raw_fd_ostream out(fileName, EC); 28 | // To-do: Change to a better store format. 29 | for (auto i : dimSize) { 30 | out << i << " "; 31 | } 32 | out << "\n"; 33 | auto totalSize = shape.getNumElements(); 34 | for (int i = 0; i < totalSize; i++) { 35 | printNativeElement(data[i], out); 36 | out << " "; 37 | } 38 | out << "\n"; 39 | } 40 | 41 | size_t WeightsEngine::addWeight(std::shared_ptr<void> weight) { 42 | weightsMap[weightsIds++] = weight; 43 | return weightsIds - 1; 44 | } 45 | 46 | size_t WeightsEngine::addWeight(ElementsAttr &elements) { 47 | std::shared_ptr<void> sPtr; 48 | auto idx = addWeight(sPtr); 49 | 50 | auto fn = [&](auto dataPtr) { 51 | serializeWeightToDisk( 52 | elements.getShapedType(), dataPtr.get(), 53 | std::filesystem::path(__FILE__).parent_path().string() + 54 | std::string("/../../examples/torch/linear/") + std::to_string(idx) + 55 | ".txt"); 56 | sPtr = dataPtr; 57 | }; 58 | 59 | universalCastElementsToPtr(elements, fn); 60 | 61 | return idx; 62 | } 63 | 64 | void WeightsEngine::removeWeight(size_t idx) { 65 | auto iter = weightsMap.find(idx); 66 | if (iter != weightsMap.end()) { 67 | weightsMap.erase(iter); 68 | } 69 | } 70 | 71 | WeightsEngine gWe; 72 | 73 | } // namespace hands_on_mlir 74 | } // namespace mlir 75 |
-------------------------------------------------------------------------------- /examples/torch/linear/cuda/run_fp16.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/wsl/lib:/home/shared_folder/cudnn-linux-x86_64-9.0.0.312_cuda12-archive/lib 4 | 5 | input_file=$1 6 | file_name=${input_file%.mlir} 7 | file_name=${file_name#./} 8 | asm_file=${input_file%.mlir}.s 9 | so_file=lib$file_name.so 10 | 11 | echo "Processing $input_file" 12 | echo "so_file $so_file" 13 | 14 | ../../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-fp32-to-fp16 --hom-to-homnvgpu --homnvgpu-fusion --homnvgpu-legalize-gemm --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --hom-func-to-llvm-pipeline $input_file |\ 15 | ../../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 16 | ../../../../thirdparty/llvm-project/build/bin/llc > $asm_file 17 | 18 | clang++-17 $asm_file -O3 -g -fPIC -shared -L../../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o $so_file 19 | 20 | clang++-17 run_fp16.cu -g -O3 -I../../../../thirdparty/cutlass/tools/library/include -I../../../../include/ -I../../../../thirdparty/llvm-project/mlir/include/ -I../../../../thirdparty/TransformerEngine/transformer_engine/common/include -L../../../../thirdparty/TransformerEngine/ -I../../../../thirdparty/llvm-project/llvm/include/ -I../../../../thirdparty/cutlass/include/ -I../../../../thirdparty/llvm-project/build/include/ -L./ -L../../../../build/lib/
-L../../../../thirdparty/llvm-project/build/lib -ltransformer_engine -lLLVM-17 -lhands_on_mlir_runner_utils -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -ldl -lpthread -lrt -L$CUDA_HOME/lib64 \ 21 | -lcudart_static -Wl,-rpath,../../../../build/lib -Wl,-rpath,../../../../thirdparty/llvm-project/build/lib -Wl,-rpath,../../../../thirdparty/TransformerEngine/ -Wl,-rpath,./ --cuda-gpu-arch=sm_86 -std=gnu++17 -o run 22 | 23 | pattern="hom_linear_([0-9]+)_([0-9]+)_([0-9]+)\.mlir" 24 | if [[ $input_file =~ $pattern ]]; then 25 | M="${BASH_REMATCH[1]}" 26 | N="${BASH_REMATCH[2]}" 27 | K="${BASH_REMATCH[3]}" 28 | # Print the extracted values 29 | echo "run_fp16.sh: M: $M, N: $N, K: $K" 30 | fi 31 | 32 | ./run $M $N $K 33 |
-------------------------------------------------------------------------------- /lib/NVGPUKernels/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(Python3 REQUIRED) 2 | 3 | execute_process( 4 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../python/ 5 | COMMAND 6 | ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR} 7 | ${Python3_EXECUTABLE} ${HANDS_ON_MLIR_SOURCE_DIR}/python/gemm_generator.py 8 | --operations "*" # To-do: make it configurable 9 | --build-dir ${PROJECT_BINARY_DIR} --curr-build-dir 10 | ${CMAKE_CURRENT_BINARY_DIR} --generator-target library --architectures 11 | "${CMAKE_CUDA_ARCHITECTURES}" 12 | # To-do: make it configurable 13 | --kernels "*_nn_*" --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}" 14 | --kernel-filter-file "${KERNEL_FILTER_FILE}" --selected-kernel-list 15 | "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}" --cuda-version 16 | "${CMAKE_CUDA_COMPILER_VERSION}" --log-level DEBUG 17 | --disable-cutlass-package-imports 18 | RESULT_VARIABLE cutlass_kernel_INSTANCE_GENERATION_RESULT 19 | OUTPUT_VARIABLE cutlass_kernel_INSTANCE_GENERATION_OUTPUT 20 | OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_kernel_instance_generation.log 21 | ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_kernel_instance_generation.log) 22 | 23 | if(NOT cutlass_kernel_INSTANCE_GENERATION_RESULT EQUAL 0) 24 | message( 25 | FATAL_ERROR 26 | "Error generating library instances. See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_kernel_instance_generation.log" 27 | ) 28 | endif() 29 | 30 | message( 31 | STATUS 32 | "Completed generation of library instances. See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_kernel_instance_generation.log for more information." 33 | ) 34 | 35 | set(CUTLASS_KERNEL_MANIFEST_CMAKE_FILE 36 | ${CMAKE_CURRENT_BINARY_DIR}/generated/manifest.cmake) 37 | if(EXISTS "${CUTLASS_KERNEL_MANIFEST_CMAKE_FILE}") 38 | include(${CUTLASS_KERNEL_MANIFEST_CMAKE_FILE}) 39 | else() 40 | message( 41 | STATUS 42 | "auto-generated library manifest cmake file (${CUTLASS_KERNEL_MANIFEST_CMAKE_FILE}) not found."
43 | ) 44 | endif() 45 | 46 | add_library(GemmManifestAndProfiler STATIC GemmManifest.cu GemmProfiler.cu 47 | GemmRunner.cu) 48 | 49 | target_compile_options(GemmManifestAndProfiler PRIVATE -fPIC) 50 | target_link_libraries(GemmManifestAndProfiler cutlass_library_objs) 51 |
-------------------------------------------------------------------------------- /lib/Dialect/HOM/HOMSerializeWeight.cpp: -------------------------------------------------------------------------------- 1 | #include <cstddef> 2 | #include <memory> 3 | #include <string> 4 | #include <utility> 5 | 6 | #include "Conversions/Tosa/Passes.h" 7 | #include "HOM/HOMOps.h" 8 | #include "WeightsEngine/WeightsEngine.h" 9 | #include "mlir/Dialect/Func/IR/FuncOps.h" 10 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 11 | #include "mlir/IR/BuiltinAttributes.h" 12 | #include "mlir/IR/PatternMatch.h" 13 | #include "mlir/Parser/Parser.h" 14 | #include "mlir/Support/LogicalResult.h" 15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 | #include "llvm/ADT/APInt.h" 17 | #include "llvm/Support/raw_ostream.h" 18 | 19 | #define PASS_NAME "hom-serialize-weight" 20 | #define DEBUG_TYPE PASS_NAME 21 | 22 | namespace mlir { 23 | namespace hands_on_mlir { 24 | namespace hom { 25 | 26 | #define GEN_PASS_DEF_HOMSERIALIZEWEIGHTPASS 27 | #include "HOM/Passes.h.inc" 28 | 29 | namespace { 30 | 31 | struct SerializeTosaConstOp : public OpRewritePattern<tosa::ConstOp> { 32 | using OpRewritePattern::OpRewritePattern; 33 | 34 | LogicalResult matchAndRewrite(tosa::ConstOp op, 35 | PatternRewriter &rewriter) const override { 36 | auto loc = op.getLoc(); 37 | auto value = op.getValueAttr(); 38 | 39 | auto idx = gWe.addWeight(value); 40 | 41 | auto constantOP = rewriter.create<ConstantOp>(loc, value.getType(), idx); 42 | 43 | while (!op->getUses().empty()) { 44 | op->getUses().begin()->set(constantOP.getResult()); 45 | } 46 | 47 | rewriter.eraseOp(op); 48 | 49 | return success(); 50 | } 51 | }; 52 | 53 | struct HOMSerializeWeightPass 54 | : impl::HOMSerializeWeightPassBase<HOMSerializeWeightPass> { 55 | void runOnOperation() final; 56 | 57 | LogicalResult initialize(MLIRContext *ctx) override; 58 | 59 | private: 60 | FrozenRewritePatternSet patterns; 61 | }; 62 | 63 | LogicalResult HOMSerializeWeightPass::initialize(MLIRContext *ctx) { 64 | RewritePatternSet patternList(ctx); 65 | 66 | patternList.add<SerializeTosaConstOp>(ctx); 67 | patterns = std::move(patternList); 68 | return success(); 69 | } 70 | 71 | void HOMSerializeWeightPass::runOnOperation() { 72 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 73 | } 74 | 75 | } // namespace 76 | } // namespace hom 77 | } // namespace hands_on_mlir 78 | } // namespace mlir 79 |
-------------------------------------------------------------------------------- /examples/torch/bert_attention/cuSeqLen.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/Utils.h" 5 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 6 | #include "llvm/Support/Error.h" 7 | #include <cstdint> 8 | #include <cstdlib> 9 | #include <cuda_runtime.h> 10 | #include <iostream> 11 | 12 | #define RowMajor(A, des, i) ((A)[(i) * (des).strides[0]]) 13 | 14 | int main() { 15 | 16 | constexpr int64_t bs = 16; 17 | constexpr int64_t seq_len = 64; 18 | 19 | auto a = allocHelper<int64_t, 2>({bs, seq_len}, nvgpuAllocer); 20 | 21 | auto des = static_cast<StridedMemRefType<int64_t, 2> *>(a.descriptor); 22 | 23 | auto host_ptr = new int64_t[seq_len * bs]; 24 | checkCudaErrors(cudaMalloc(&(des->data), sizeof(int64_t) * bs *
seq_len)); 25 | 26 | int32_t data[] = {64, 12, 31, 32, 33, 34, 35, 36}; 27 | for (int i = 0; i < bs; i++) { 28 | for (int j = 0; j < seq_len; j++) { 29 | host_ptr[i * seq_len + j] = j < data[i % 8]; 30 | std::cout << host_ptr[i * seq_len + j] << " "; 31 | } 32 | std::cout << std::endl; 33 | } 34 | cudaMemcpy(des->data, host_ptr, sizeof(int64_t) * seq_len * bs, 35 | cudaMemcpyHostToDevice); 36 | 37 | UnrankedMemRefType<int32_t> b; 38 | mlir::hands_on_mlir::ExecutionEngine e("libcuseqlen_nvgpu.so"); 39 | 40 | auto res = e.invoke("forward", a.rank, a.descriptor, 41 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 42 | if (res) { 43 | llvm::handleAllErrors(std::move(res)); 44 | } 45 | 46 | res = e.invoke("forward", a.rank, a.descriptor, 47 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 48 | if (res) { 49 | llvm::handleAllErrors(std::move(res)); 50 | } 51 | 52 | auto new_host = new int32_t[bs * seq_len]; 53 | 54 | auto c = DynamicMemRefType<int32_t>(b); 55 | std::cout << c.rank << std::endl; 56 | cudaMemcpy(new_host, c.data, sizeof(int32_t) * c.sizes[0], 57 | cudaMemcpyDeviceToHost); 58 | for (int i = 0; i < c.sizes[0]; i++) { 59 | std::cout << RowMajor(new_host, c, i) << " "; 60 | } 61 | std::cout << std::endl; 62 | 63 | cudaFree(des->data); 64 | cudaFree(c.data); 65 | 66 | delete[] host_ptr; 67 | delete[] new_host; 68 | 69 | free(a.descriptor); 70 | free(b.descriptor); 71 | } 72 |
-------------------------------------------------------------------------------- /examples/torch/layernorm/cuda/layernorm.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 3 | #include "NVGPUKernels/Utils.h" 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include "llvm/Support/Error.h" 6 | #include <cstdlib> 7 | #include <cuda_runtime.h> 8 | #include <iostream> 9 | 10 | struct Res { 11 | C_UnrankedMemRefType a; 12 | }; 13 | 14 | #define RowMajor(A, des, i, j) \ 15 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1]]) 16 | 17 | int main() { 18 | C_UnrankedMemRefType a; 19 | 20 | a.rank = 2; 21 | 22 | a.descriptor = malloc(sizeof(StridedMemRefType<float, 2>)); 23 | auto des = static_cast<StridedMemRefType<float, 2> *>(a.descriptor); 24 | float host_ptr[] = {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 25 | 0.7936, 0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 26 | 0.7411, 0.4294, 0.8854, 0.5739, 0.2666, 0.6274}; 27 | checkCudaErrors(cudaMalloc(&(des->data), sizeof(host_ptr))); 28 | std::cout << des->data << std::endl; 29 | des->basePtr = des->data; 30 | des->sizes[0] = 2; 31 | des->sizes[1] = 10; 32 | des->strides[0] = 10; 33 | des->strides[1] = 1; 34 | cudaMemcpy(des->data, host_ptr, sizeof(host_ptr), cudaMemcpyHostToDevice); 35 | 36 | Res b; 37 | mlir::hands_on_mlir::ExecutionEngine e("liblayernorm_nvgpu.so"); 38 | 39 | auto res = e.invoke("forward", a.rank, a.descriptor, 40 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 41 | if (res) { 42 | llvm::handleAllErrors(std::move(res)); 43 | } 44 | 45 | res = e.invoke("forward", a.rank, a.descriptor, 46 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 47 | if (res) { 48 | llvm::handleAllErrors(std::move(res)); 49 | } 50 | 51 | auto c = DynamicMemRefType<float>(b.a); 52 | std::cout << c.rank << " " << c.sizes[0] << " " << c.sizes[1] << std::endl; 53 | cudaMemcpy(host_ptr, c.data, sizeof(float) * c.sizes[0] * c.sizes[1], 54 | cudaMemcpyDeviceToHost); 55 | for (int i = 0; i < c.sizes[0]; i++) { 56 | for (int j = 0; j < c.sizes[1]; j++) { 57 | std::cout << RowMajor(host_ptr, c, i, j) << " "; 58 |
std::cout << std::endl; 59 | } 60 | } 61 | 62 | cudaFree(des->data); 63 | cudaFree(c.data); 64 | 65 | free(a.descriptor); 66 | free(b.a.descriptor); 67 | } 68 |
-------------------------------------------------------------------------------- /README_OLD.md: -------------------------------------------------------------------------------- 1 | # Hands-on-MLIR 2 | 3 | A simple project that optimizes `linalg.matmul` using the MLIR framework. Development is currently in progress. Feel free to create an issue if you have any suggestions or problems. 4 | 5 | # What can it do? 6 | 7 | Currently, this project can lower `linalg.matmul` to the `affine` dialect with tiling. It also provides a simple benchmark to measure the optimization's gFLOPS. However, it is not fast right now (about 2 gFLOPS, compared to ~100 gFLOPS for MKL). 8 | 9 | # To-do 10 | 11 | + explicit affine data packing mechanism. (`affineDataCopyGenerate` simply cannot work when the tensor shape is unknown. Maybe I should implement it myself.) 12 | + vector ld/st & compute. 13 | + And more... 14 | 15 | # Install 16 | 17 | ## Install MLIR 18 | 19 | Install it in whatever way you prefer. This project should be compatible with the main branch of MLIR. 20 | 21 | ## Install this project 22 | 23 | If you didn't enable the address sanitizer when building MLIR, please remove the following lines from CMakeLists.txt. (I'm too lazy to make it configurable.) 24 | 25 | ``` 26 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address") 27 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") 28 | ``` 29 | 30 | Then use the following commands to compile. 31 | 32 | ``` 33 | $ cd Hands-on-MLIR 34 | $ mkdir build && cd build 35 | $ cmake -G Ninja .. \ 36 | -DMLIR_DIR=/your/path/to/llvm-project/build/lib/cmake/mlir \ 37 | -DLLVM_DIR=/your/path/to/llvm-project/build/lib/cmake/llvm \ 38 | -DLLVM_ENABLE_ASSERTIONS=ON 39 | ``` 40 | 41 | Or you can use this setup in VSCode.
42 | 43 | ``` 44 | "cmake.configureArgs": [ 45 | "-DMLIR_DIR=/your/path/to/llvm-project/build/lib/cmake/mlir", 46 | "-DLLVM_DIR=/your/path/to/llvm-project/build/lib/cmake/llvm", 47 | "-DLLVM_ENABLE_ASSERTIONS=ON" 48 | ], 49 | ``` 50 | 51 | # Reference 52 | 53 | + [MLIR](https://github.com/llvm/llvm-project/): copied a lot of code from it. 54 | + [buddy-mlir](https://github.com/buddy-compiler/buddy-mlir): likewise copied a lot from it. 55 | + [polymage-labs/mlirx](https://github.com/polymage-labs/mlirx): the version is too old, so much of it can no longer be reused. 56 | + [Polyhedral Model (three-part series)](https://mp.weixin.qq.com/s?__biz=MzI3MDQ2MjA3OA==&mid=2247485130&idx=1&sn=a5773bf17e6854d1238b035366641bcc&chksm=ead1fbdbdda672cdf9b2480a431cef85e4d377d07f8c586a932adabd50656cbdcd7d891156bf&mpshare=1&scene=1&srcid=&sharer_sharetime=1569677798809&sharer_shareid=b33ef36fa0caf5cb82e76916516aa7df#rd): a good introduction to the basic concepts of polyhedral optimization.
-------------------------------------------------------------------------------- /include/Conversions/Function/Passes.h: -------------------------------------------------------------------------------- 1 | #ifndef HOM_CONVERSIONS_FUNC_TRANSFORMS_PASSES_H 2 | #define HOM_CONVERSIONS_FUNC_TRANSFORMS_PASSES_H 3 | 4 | #include "mlir/Conversion/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/Pass/Pass.h" 7 | #include "mlir/Pass/PassManager.h" 8 | #include "mlir/Pass/PassRegistry.h" 9 | #include "mlir/Support/LogicalResult.h" 10 | #include "mlir/Transforms/DialectConversion.h" 11 | 12 | #include "Conversions/Function/FunctionUtils.h" 13 | #include "HOM/HOMOps.h" 14 | 15 | namespace mlir { 16 | namespace hands_on_mlir { 17 | 18 | #define GEN_PASS_DECL_EXTRACTINITFUNCPASS 19 | #define GEN_PASS_DECL_HOMTOFUNCPASS 20 | #define GEN_PASS_DECL_HOMNVGPUTOFUNCPASS 21 | #define GEN_PASS_DECL_UNIFYLLVMFUNCINTERFACEPASS 22 | #define GEN_PASS_DECL_OPTIMIZEMEMORYPASS 23 | #define GEN_PASS_REGISTRATION 24 | #include "Conversions/Function/Passes.h.inc" 25 | 26 | namespace { 27 | struct ConvertHOMDummyTensorOp 28 | : public OpConversionPattern<hom::DummyTensorOp> { 29 | using OpConversionPattern::OpConversionPattern; 30 | 31 | LogicalResult 32 | matchAndRewrite(hom::DummyTensorOp op, OpAdaptor adaptor, 33 | ConversionPatternRewriter &rewriter) const override { 34 | 35 | auto loc = op->getLoc(); 36 | auto moduleOp = op->getParentOfType<ModuleOp>(); 37 | 38 | auto allocFn = lookupOrCreateAllocDummyTensorF32Fn(moduleOp); 39 | 40 | auto allocCaller = 41 | rewriter.create<func::CallOp>(loc, allocFn, ArrayRef<Value>{}); 42 | 43 | while (!op.use_empty()) { 44 | op->getUses().begin()->set(allocCaller->getResult(0)); 45 | } 46 | 47 | // maybeInsertDeallocFn(rewriter, op, {allocCaller->getResult(0)}); 48 | rewriter.eraseOp(op); 49 | 50 | return success(); 51 | } 52 | }; 53 | } // namespace 54 | 55 | inline void registerHOMFuncToLLVMPipelines() { 56 | PassPipelineRegistration<>( 57 | "hom-func-to-llvm-pipeline", "Convert HOM func call to llvm ir", 58 | [](OpPassManager &pm) { 59 | pm.addPass(createExtractInitFuncPass()); 60 | pm.addPass(createConvertFuncToLLVMPass()); 61 | pm.addPass(createFinalizeMemRefToLLVMConversionPass()); 62 | pm.addPass(createArithToLLVMConversionPass()); 63 | pm.addPass(createUnifyLLVMFuncInterfacePass()); 64 | }); 65 | } 66 | 67 | } // namespace hands_on_mlir 68 | } // namespace mlir 69 | 70 | #endif // HOM_CONVERSIONS_FUNC_TRANSFORMS_PASSES_H 71 |
-------------------------------------------------------------------------------- /examples/torch/linear/cuda/b.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include
"ExecutionEngine/HandsOnNVGPURunnerUtils.h" 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include "llvm/Support/Error.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define RowMajor(A, des, i, j, k) \ 12 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1] + \ 13 | (k) * (des).strides[2]]) 14 | 15 | int main() { 16 | 17 | int64_t m = 64 * 8, n = 768, k = 768; 18 | 19 | auto a = allocHelper({1, m, k}, nvgpuAllocer); 20 | auto b = allocHelper({1, k, n}, nvgpuAllocer); 21 | auto c = allocHelper({1, m, n}, nvgpuAllocer); 22 | auto d = allocHelper({1, m, n}, nvgpuAllocer); 23 | 24 | auto a_host = new half[m * k]; 25 | auto b_host = new half[n * k]; 26 | auto c_host = new half[m * n]; 27 | 28 | for (int i = 0; i < m * k; i++) { 29 | a_host[i] = 0.1; 30 | } 31 | 32 | for (int i = 0; i < n * k; i++) { 33 | b_host[i] = 0.2; 34 | } 35 | 36 | for (int i = 0; i < m * n; i++) { 37 | c_host[i] = 0.1; 38 | } 39 | 40 | auto desA = static_cast *>(a.descriptor); 41 | auto desb = static_cast *>(b.descriptor); 42 | auto desd = static_cast *>(d.descriptor); 43 | 44 | cudaMemcpy(desA->data, a_host, sizeof(half) * m * k, cudaMemcpyHostToDevice); 45 | cudaMemcpy(desb->data, b_host, sizeof(half) * n * k, cudaMemcpyHostToDevice); 46 | 47 | cutlassGemmF16(a.rank, a.descriptor, false, b.rank, b.descriptor, false, 48 | c.rank, c.descriptor, d.rank, d.descriptor, 0, 1, 0, 219, 1); 49 | 50 | auto err = cudaStreamSynchronize(nullptr); 51 | if (err != cudaSuccess) { 52 | std::cout << cudaGetErrorString(err) << std::endl; 53 | exit(-1); 54 | } 55 | 56 | cudaMemcpy(c_host, desd->data, sizeof(half) * m * n, cudaMemcpyDeviceToHost); 57 | 58 | for (int i = 0; i < 1; i++) { 59 | for (int j = 0; j < m; j++) { 60 | for (int kk = 0; kk < n; kk++) { 61 | if (std::abs(float(c_host[i]) - k * 0.02) > 1e-1) { 62 | std::cout << float(c_host[i]) << std::endl; 63 | std::cout << float(c_host[i]) - k * 0.01 << std::endl; 64 | std::cout << "Not ok" << std::endl; 65 | } 66 | } 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /examples/torch/linear/cuda/run.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 3 | #include "NVGPUKernels/Utils.h" 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include "llvm/Support/Error.h" 6 | #include 7 | #include 8 | #include 9 | 10 | struct Res { 11 | C_UnrankedMemRefType a; 12 | }; 13 | 14 | #define RowMajor(A, des, i, j, k) \ 15 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1] + \ 16 | (k) * (des).strides[2]]) 17 | 18 | int main() { 19 | C_UnrankedMemRefType a; 20 | 21 | a.rank = 3; 22 | 23 | a.descriptor = malloc(sizeof(StridedMemRefType)); 24 | auto des = static_cast *>(a.descriptor); 25 | auto host_ptr = new float[3 * 200000]; 26 | checkCudaErrors(cudaMalloc(&(des->data), sizeof(float) * 3 * 200000)); 27 | std::cout << des->data << std::endl; 28 | des->basePtr = des->data; 29 | des->sizes[0] = 2; 30 | des->sizes[1] = 3; 31 | des->sizes[2] = 100000; 32 | des->strides[0] = 300000; 33 | des->strides[1] = 100000; 34 | des->strides[2] = 1; 35 | for (int i = 0; i < 600000; i++) { 36 | host_ptr[i] = 1; 37 | } 38 | cudaMemcpy(des->data, host_ptr, sizeof(float) * 3 * 200000, 39 | cudaMemcpyHostToDevice); 40 | 41 | Res b; 42 | mlir::hands_on_mlir::ExecutionEngine e("liblinear_nvgpu.so"); 43 | 44 | auto res = e.invoke("forward", 
a.rank, a.descriptor, 45 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 46 | if (res) { 47 | llvm::handleAllErrors(std::move(res)); 48 | } 49 | 50 | res = e.invoke("forward", a.rank, a.descriptor, 51 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 52 | if (res) { 53 | llvm::handleAllErrors(std::move(res)); 54 | } 55 | 56 | auto c = DynamicMemRefType(b.a); 57 | std::cout << c.rank << std::endl; 58 | cudaMemcpy(host_ptr, c.data, 59 | sizeof(float) * c.sizes[0] * c.sizes[1] * c.sizes[2], 60 | cudaMemcpyDeviceToHost); 61 | for (int i = 0; i < c.sizes[0]; i++) { 62 | for (int j = 0; j < c.sizes[1]; j++) { 63 | for (int k = 0; k < c.sizes[2]; k++) { 64 | std::cout << RowMajor(host_ptr, c, i, j, k) << " "; 65 | } 66 | std::cout << std::endl; 67 | } 68 | std::cout << std::endl; 69 | } 70 | 71 | cudaFree(des->data); 72 | cudaFree(c.data); 73 | 74 | delete[] host_ptr; 75 | 76 | free(a.descriptor); 77 | free(b.a.descriptor); 78 | } 79 | -------------------------------------------------------------------------------- /examples/torch/linear/cuda/a.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 3 | #include "NVGPUKernels/Utils.h" 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include "transformer_engine/gemm.h" 6 | #include "transformer_engine/transformer_engine.h" 7 | #include "llvm/Support/Error.h" 8 | #include 9 | #include 10 | #include 11 | 12 | int main() { 13 | using namespace transformer_engine; 14 | 15 | unsigned long long m = 2048, n = 3072, k = 768; 16 | half *a, *b, *c; 17 | checkCudaErrors(cudaMalloc(&a, sizeof(half) * m * k)); 18 | checkCudaErrors(cudaMalloc(&b, sizeof(half) * n * k)); 19 | checkCudaErrors(cudaMalloc(&c, sizeof(half) * m * n)); 20 | 21 | TensorWrapper a_tensor(a, {m, k}, NVTEWrapperDTypeMap::kType); 22 | TensorWrapper b_tensor(b, {n, k}, NVTEWrapperDTypeMap::kType); 23 | TensorWrapper c_tensor(c, {m, n}, NVTEWrapperDTypeMap::kType); 24 | 25 | auto workspace_buffer = getDummyPointer(4 * 1024 * 1024); 26 | auto pre_gelu_buffer = nullptr; 27 | 28 | TensorWrapper workspace(workspace_buffer.get(), {4 * 1024 * 1024}, 29 | NVTEWrapperDTypeMap::kType); 30 | TensorWrapper bias(nullptr, std::vector{0}, 31 | NVTEWrapperDTypeMap::kType); 32 | TensorWrapper pre_gelu(pre_gelu_buffer, std::vector{0}, 33 | NVTEWrapperDTypeMap::kType); 34 | 35 | auto mpCount = getMulitProcessorCount(); 36 | nvte_cublas_gemm(b_tensor.data(), a_tensor.data(), c_tensor.data(), 37 | bias.data(), pre_gelu.data(), true, false, false, 38 | workspace.data(), false, false, mpCount, nullptr); 39 | 40 | cudaEvent_t s, t; 41 | checkCudaErrors(cudaEventCreate(&s)); 42 | checkCudaErrors(cudaEventCreate(&t)); 43 | 44 | checkCudaErrors(cudaEventRecord(s)); 45 | 46 | for (int i = 0; i < 10000; i++) { 47 | nvte_cublas_gemm(b_tensor.data(), a_tensor.data(), c_tensor.data(), 48 | bias.data(), pre_gelu.data(), true, false, false, 49 | workspace.data(), false, false, mpCount, nullptr); 50 | } 51 | 52 | cudaEventRecord(t); 53 | checkCudaErrors(cudaEventSynchronize(t)); 54 | float msecTotal = 0; 55 | checkCudaErrors(cudaEventElapsedTime(&msecTotal, s, t)); 56 | 57 | std::cout << "E2E latency: " << msecTotal / 1000.0 / 10000.0 << "s" 58 | << std::endl; 59 | std::cout << "GFlops: " 60 | << (m * n * k * 2.0) * 1e-9 / (msecTotal / 10000.0 / 1000.0) 61 | << std::endl; 62 | } 63 | -------------------------------------------------------------------------------- 
/lib/Conversions/Function/OptimizeMemory.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "Conversions/Function/FunctionUtils.h" 11 | #include "Conversions/Function/Passes.h" 12 | #include "HOM/HOMOps.h" 13 | #include "WeightsEngine/WeightsEngine.h" 14 | #include "mlir/Dialect/Arith/IR/Arith.h" 15 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 | #include "mlir/Dialect/Func/IR/FuncOps.h" 17 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 18 | #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" 19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 21 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 22 | #include "mlir/IR/Builders.h" 23 | #include "mlir/IR/BuiltinAttributes.h" 24 | #include "mlir/IR/BuiltinOps.h" 25 | #include "mlir/IR/BuiltinTypes.h" 26 | #include "mlir/IR/Location.h" 27 | #include "mlir/IR/PatternMatch.h" 28 | #include "mlir/IR/ValueRange.h" 29 | #include "mlir/Parser/Parser.h" 30 | #include "mlir/Support/LLVM.h" 31 | #include "mlir/Support/LogicalResult.h" 32 | #include "mlir/Transforms/DialectConversion.h" 33 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 34 | #include "llvm/ADT/APInt.h" 35 | #include "llvm/ADT/STLExtras.h" 36 | #include "llvm/Support/ErrorHandling.h" 37 | #include "llvm/Support/raw_ostream.h" 38 | 39 | #define PASS_NAME "unify-llvm-func-interface" 40 | #define DEBUG_TYPE PASS_NAME 41 | 42 | namespace mlir { 43 | namespace hands_on_mlir { 44 | 45 | #define GEN_PASS_DEF_OPTIMIZEMEMORYPASS 46 | #include "Conversions/Function/Passes.h.inc" 47 | 48 | namespace { 49 | 50 | struct OptimizeMemoryPass : impl::OptimizeMemoryPassBase { 51 | void runOnOperation() final; 52 | 53 | LogicalResult initialize(MLIRContext *ctx) override; 54 | 55 | private: 56 | FrozenRewritePatternSet patterns; 57 | }; 58 | 59 | struct OptimizeMemory : public OpRewritePattern { 60 | using OpRewritePattern::OpRewritePattern; 61 | 62 | LogicalResult matchAndRewrite(func::CallOp op, 63 | PatternRewriter &rewriter) const override { 64 | return failure(); 65 | } 66 | }; 67 | 68 | LogicalResult OptimizeMemoryPass::initialize(MLIRContext *ctx) { 69 | RewritePatternSet patternList(ctx); 70 | patterns = std::move(patternList); 71 | return success(); 72 | } 73 | 74 | void OptimizeMemoryPass::runOnOperation() { 75 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 76 | } 77 | 78 | } // namespace 79 | } // namespace hands_on_mlir 80 | } // namespace mlir 81 | -------------------------------------------------------------------------------- /cmake/check_simd.cmake: -------------------------------------------------------------------------------- 1 | macro(CHECK_SIMD) 2 | 3 | include(CheckCXXSourceRuns) 4 | 5 | # ------------------------------------------------------------------------------- 6 | # Check Intel SSE 7 | # ------------------------------------------------------------------------------- 8 | 9 | set(CMAKE_REQUIRED_FLAGS -msse) 10 | check_cxx_source_runs( 11 | " 12 | #include 13 | int main() { 14 | __m128 x; 15 | x = _mm_set_ps(1.0f,1.0f,1.0f,1.0f); 16 | return 0; 17 | } 18 | " 19 | HAVE_SSE) 20 | 21 | if(${HAVE_SSE}) 22 | message(STATUS "\tSSE support - yes") 23 | else() 24 | message(STATUS "\tSSE support - no") 25 | endif(${HAVE_SSE}) 26 | 27 | # ------------------------------------------------------------------------------- 28 | # Check Intel AVX2 29 | # 
------------------------------------------------------------------------------- 30 | 31 | set(CMAKE_REQUIRED_FLAGS -mavx2) 32 | check_cxx_source_runs( 33 | " 34 | #include 35 | int main() { 36 | int data[8] = {0,0,0,0,0,0,0,0}; 37 | __m256i a = _mm256_loadu_si256((const __m256i *)data); 38 | __m256i b = _mm256_bslli_epi128(a, 1); 39 | return 0; 40 | } 41 | " 42 | HAVE_AVX2) 43 | 44 | if(${HAVE_AVX2}) 45 | message(STATUS "\tAVX2 support - yes") 46 | else() 47 | message(STATUS "\tAVX2 support - no") 48 | endif(${HAVE_AVX2}) 49 | 50 | # ------------------------------------------------------------------------------- 51 | # Check Intel AVX512 52 | # ------------------------------------------------------------------------------- 53 | 54 | set(CMAKE_REQUIRED_FLAGS -mavx512f) 55 | check_cxx_source_runs( 56 | " 57 | #include 58 | int main() { 59 | float data[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; 60 | __m512 vector = _mm512_loadu_ps(data); 61 | return 0; 62 | } 63 | " 64 | HAVE_AVX512) 65 | 66 | if(${HAVE_AVX512}) 67 | message(STATUS "\tAVX512 support - yes") 68 | else() 69 | message(STATUS "\tAVX512 support - no") 70 | endif(${HAVE_AVX512}) 71 | 72 | # ------------------------------------------------------------------------------- 73 | # Check Arm Neon 74 | # ------------------------------------------------------------------------------- 75 | 76 | check_cxx_source_runs( 77 | " 78 | #include 79 | int main() { 80 | float32x4_t a; 81 | float A[] = {1.0,2.0,3.0,4.0}; 82 | a = vld1q_f32(A); 83 | return 0; 84 | } 85 | " 86 | HAVE_NEON) 87 | 88 | if(${HAVE_NEON}) 89 | message(STATUS "\tArm Neon support - yes") 90 | else() 91 | message(STATUS "\tArm Neon support - no") 92 | endif(${HAVE_NEON}) 93 | 94 | endmacro(CHECK_SIMD) 95 | -------------------------------------------------------------------------------- /examples/torch/bert/benchmark_bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.realpath("..")) 5 | 6 | import torch 7 | import torch.cuda.nvtx as nvtx 8 | from benchmark import speed_test 9 | from transformers import BertConfig, BertForMaskedLM, BertTokenizer 10 | 11 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 12 | 13 | 14 | def gen_input(bs, max_len): 15 | text = "Hello I'm a [MASK] model." 
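# Tokenize the prompt once, broadcast it to the batch below, then right-pad each field out to max_len.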
16 | encoded_input = tokenizer(text, return_tensors="pt") 17 | encoded_input_list = [ 18 | encoded_input["input_ids"].expand(bs, -1), 19 | encoded_input["attention_mask"].expand(bs, -1), 20 | encoded_input["token_type_ids"].expand(bs, -1), 21 | ] 22 | 23 | encoded_input_list = [ 24 | torch.concat( 25 | [ 26 | i, 27 | ( 28 | torch.zeros(bs, max_len - i.shape[1], dtype=torch.int64) 29 | if idx != 0 30 | else ( 31 | torch.tensor( 32 | [ 33 | [102 for k in range(max_len - i.shape[1])] 34 | for j in range(bs) 35 | ] 36 | ) 37 | if idx == 2 38 | else torch.ones(bs, max_len - i.shape[1], dtype=torch.int64) 39 | ) 40 | ), 41 | ], 42 | dim=-1, 43 | ).cuda() 44 | for idx, i in enumerate(encoded_input_list) 45 | ] 46 | 47 | return encoded_input_list 48 | 49 | 50 | for model_mode in ["base", "large"]: 51 | for bs in [1, 8, 16, 32]: 52 | for max_len in [64, 128]: 53 | i = gen_input(bs, max_len) 54 | 55 | class BertWrapper(torch.nn.Module): 56 | def __init__(self): 57 | with torch.cuda.amp.autocast(): 58 | super().__init__() 59 | config = BertConfig().from_pretrained( 60 | f"bert-{model_mode}-uncased" 61 | ) 62 | self.model = BertForMaskedLM(config) 63 | 64 | def forward(self, input_ids, attention_mask, token_type_ids): 65 | with torch.cuda.amp.autocast(): 66 | return self.model( 67 | input_ids, attention_mask, token_type_ids 68 | ).logits 69 | 70 | model = BertWrapper() 71 | model.eval() 72 | model = model.cuda() 73 | 74 | msg = f"model: {model_mode}, bs: {bs}, seq_len: {max_len}" 75 | 76 | print(msg) 77 | 78 | with nvtx.range(msg): 79 | speed_test(model, i) 80 | -------------------------------------------------------------------------------- /examples/torch/bert/convert_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.amp 3 | import torch_mlir 4 | from transformers import BertConfig, BertForMaskedLM, BertTokenizer 5 | 6 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased") 7 | 8 | text = "Hello I'm a [MASK] model." 
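# The loop below expands this single encoding to every (model, batch, length) configuration, pads input_ids with token id 102 and the two masks with zeros, traces the wrapped model to TOSA through torch_mlir, and dumps the MLIR plus reference inputs/outputs for later comparison.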
9 | encoded_input = tokenizer(text, return_tensors="pt") 10 | 11 | 12 | for name in ["bert-base-uncased", "bert-large-uncased"]: 13 | for bs in [32]: 14 | for len in [128]: 15 | encoded_input_list = [ 16 | encoded_input["input_ids"].expand(bs, -1), 17 | encoded_input["attention_mask"].expand(bs, -1), 18 | encoded_input["token_type_ids"].expand(bs, -1), 19 | ] 20 | 21 | encoded_input_list = [ 22 | torch.concat( 23 | [ 24 | i, 25 | ( 26 | torch.zeros(bs, len - i.shape[1], dtype=torch.int64) 27 | if idx != 0 28 | else torch.tensor( 29 | [ 30 | [102 for k in range(len - i.shape[1])] 31 | for j in range(bs) 32 | ] 33 | ) 34 | ), 35 | ], 36 | dim=-1, 37 | ) 38 | for idx, i in enumerate(encoded_input_list) 39 | ] 40 | 41 | class BertWrapper(torch.nn.Module): 42 | def __init__(self): 43 | super().__init__() 44 | config = BertConfig().from_pretrained(name) 45 | self.model = BertForMaskedLM(config) 46 | 47 | def forward(self, input_ids, attention_mask, token_type_ids): 48 | return self.model(input_ids, attention_mask, token_type_ids).logits 49 | 50 | model = BertWrapper() 51 | model.eval() 52 | 53 | output = model(*encoded_input_list) 54 | 55 | with torch.no_grad(): 56 | module = torch_mlir.compile( 57 | model, encoded_input_list, output_type="TOSA", use_tracing=True 58 | ) 59 | 60 | with open(f"{name}_{bs}_{len}.mlir", "w") as fl: 61 | print(module, file=fl, end="") 62 | 63 | for idx in range(3): 64 | with open(f"{name}_{bs}_{len}_{idx}.txt", "w") as fl: 65 | for i in encoded_input_list[idx].reshape(-1): 66 | print(int(i), file=fl) 67 | 68 | with open(f"{name}_{bs}_{len}_3.txt", "w") as fl: 69 | for i in output.reshape(-1): 70 | print(float(i), file=fl) 71 | -------------------------------------------------------------------------------- /include/NVGPUKernels/CuSeqLen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/OperationRunner.h" 5 | #include "NVGPUKernels/Utils.h" 6 | #include "thrust/iterator/counting_iterator.h" 7 | #include "thrust/iterator/discard_iterator.h" 8 | #include "thrust/iterator/transform_iterator.h" 9 | #include "thrust/reduce.h" 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | namespace mlir { 21 | namespace hands_on_mlir { 22 | namespace homnvgpu_kernel { 23 | 24 | template 25 | class CuSeqLenRunner : public mlir::hands_on_mlir::OperationRunner { 26 | 27 | public: 28 | struct MakeKey : public std::unary_function { 29 | 30 | int64_t row_length_; 31 | 32 | MakeKey(int64_t row_length) : row_length_(row_length) {} 33 | 34 | __host__ __device__ constexpr int64_t operator()(const int64_t &x) const { 35 | return (x / row_length_) & 1; 36 | } 37 | }; 38 | 39 | template 40 | struct Cast : public std::unary_function { 41 | 42 | __host__ __device__ constexpr Result operator()(const Arg &x) const { 43 | return Result(x); 44 | } 45 | }; 46 | 47 | public: 48 | Status run(int rankIn, void *desIn, int rankOut, void *desOut) { 49 | auto In = convertToDynamicMemRefType(rankIn, desIn); 50 | auto Out = convertToDynamicMemRefType(rankOut, desOut); 51 | 52 | auto inTotlaSize = std::accumulate(In.sizes, In.sizes + rankIn, 1, 53 | std::multiplies()); 54 | 55 | assert(In.sizes[0] + 1 == Out.sizes[0]); 56 | assert(In.rank == 2); 57 | assert(Out.rank == 1); 58 | checkCudaErrors(cudaMemset( 59 | Out.data, 0, sizeof(int32_t))); // Use memset to avoid malloc by thrust. 
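// One keyed reduction computes every row sum at once: MakeKey tags each flattened index with its row parity, which is all reduce_by_key needs to close one row's segment before the next begins, while Cast converts the mask elements to int32_t on the fly. // The inclusive_scan below then turns the per-row sums at Out[1..] into cumulative offsets, with Out[0] pinned to zero by the memset above.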
60 | 61 | auto key = thrust::make_transform_iterator( 62 | thrust::make_counting_iterator(0), MakeKey(In.sizes[1])); 63 | auto inPtr = thrust::device_pointer_cast(In.data); 64 | auto in = thrust::make_transform_iterator( 65 | inPtr, Cast()); 66 | auto outPtr = thrust::device_pointer_cast(Out.data); 67 | 68 | thrust::reduce_by_key(key, key + inTotlaSize, in, 69 | thrust::make_discard_iterator(), outPtr + 1); 70 | 71 | thrust::inclusive_scan(outPtr + 1, outPtr + In.sizes[0] + 1, outPtr + 1); 72 | 73 | return Status::kSuccess; 74 | } 75 | }; 76 | 77 | } // namespace homnvgpu_kernel 78 | } // namespace hands_on_mlir 79 | } // namespace mlir 80 | -------------------------------------------------------------------------------- /lib/Dialect/HOMNVGPU/HOMNVGPUOps.cpp: -------------------------------------------------------------------------------- 1 | #include "mlir/IR/Builders.h" 2 | #include "mlir/IR/BuiltinTypes.h" 3 | #include "mlir/IR/DialectImplementation.h" 4 | #include "mlir/IR/MLIRContext.h" 5 | #include "mlir/IR/Operation.h" 6 | #include "mlir/IR/OperationSupport.h" 7 | #include "mlir/Support/LogicalResult.h" 8 | #include "mlir/Transforms/InliningUtils.h" 9 | #include "llvm/ADT/TypeSwitch.h" 10 | #include "llvm/AsmParser/Parser.h" 11 | #include "llvm/IR/Attributes.h" 12 | #include "llvm/IR/Function.h" 13 | #include "llvm/IR/Type.h" 14 | #include "llvm/Support/SourceMgr.h" 15 | 16 | #include "HOMNVGPU/HOMNVGPUOps.h" 17 | 18 | using namespace mlir; 19 | using namespace hands_on_mlir::homnvgpu; 20 | 21 | #include "HOMNVGPU/HOMNVGPUOpsDialect.cpp.inc" 22 | 23 | void HOMNVGPUDialect::initialize() { 24 | addOperations< 25 | #define GET_OP_LIST 26 | #include "HOMNVGPU/HOMNVGPUOps.cpp.inc" 27 | >(); 28 | } 29 | 30 | #define GET_OP_CLASSES 31 | #include "HOMNVGPU/HOMNVGPUOps.cpp.inc" 32 | 33 | LogicalResult MatmulWithVarMeanOp::inferReturnTypes( 34 | ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, 35 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 36 | ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, 37 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 38 | auto types = operands.getTypes(); 39 | auto type0 = dyn_cast(types[0]); 40 | auto type1 = dyn_cast(types[1]); 41 | auto type2 = dyn_cast(types[2]); 42 | 43 | if (type0 && type1 && type2 && 44 | type0.getElementType() == type1.getElementType() && 45 | type0.getElementType() == type2.getElementType()) { 46 | auto shape0 = type0.getShape(); 47 | auto shape1 = type1.getShape(); 48 | auto shape2 = type2.getShape(); 49 | 50 | if (shape0.size() == shape1.size() && shape0.size() == shape2.size() && 51 | shape0.size() == 3 && shape0[2] == shape1[1]) { 52 | auto resultType = RankedTensorType::get({std::max(shape0[0], shape2[0]), 53 | std::max(shape0[1], shape2[1]), 54 | std::max(shape1[2], shape2[2])}, 55 | type0.getElementType()); 56 | auto varType = RankedTensorType::get( 57 | {resultType.getShape()[0] * resultType.getShape()[1]}, 58 | type0.getElementType()); 59 | auto meanType = RankedTensorType::get( 60 | {resultType.getShape()[0] * resultType.getShape()[1]}, 61 | type0.getElementType()); 62 | 63 | inferredReturnTypes.emplace_back(resultType); 64 | inferredReturnTypes.emplace_back(varType); 65 | inferredReturnTypes.emplace_back(meanType); 66 | return success(); 67 | } 68 | } 69 | 70 | return failure(); 71 | } 72 | -------------------------------------------------------------------------------- /include/Dialect/HOM/HOMTypesBase.td: 
-------------------------------------------------------------------------------- 1 | #ifndef HOM_TYPES_BASE 2 | #define HOM_TYPES_BASE 3 | 4 | include "mlir/IR/OpBase.td" 5 | include "mlir/IR/BuiltinTypes.td" 6 | 7 | def HOM_UInt8 : UI<8>; 8 | def HOM_UInt16 : UI<16>; 9 | 10 | def HOM_Int4 : I<4>; 11 | def HOM_Int8 : I<8>; 12 | def HOM_Int16 : I<16>; 13 | def HOM_Int32 : I<32>; 14 | def HOM_Int48 : I<48>; 15 | def HOM_Int64 : I<64>; 16 | 17 | def HOM_SignedInt 18 | : AnyTypeOf<[HOM_Int8, HOM_Int16, HOM_Int32, HOM_Int48, HOM_Int64]>; 19 | 20 | def HOM_Bool : I<1>; 21 | 22 | // No unsigned unquantized int types. 23 | def HOM_Int : AnyTypeOf<[HOM_Bool, HOM_UInt8, HOM_UInt16, HOM_SignedInt]>; 24 | 25 | def HOM_Int32Or64 : AnyTypeOf<[HOM_Int32, HOM_Int64]>; 26 | 27 | //===----------------------------------------------------------------------===// 28 | // Floating-point types. 29 | //===----------------------------------------------------------------------===// 30 | def HOM_Float : AnyTypeOf<[F32, F16, BF16]>; 31 | 32 | //===----------------------------------------------------------------------===// 33 | // Multi-category types. 34 | //===----------------------------------------------------------------------===// 35 | def HOM_AnyNumber : AnyTypeOf<[HOM_Int, HOM_Float], "number">; 36 | 37 | //===----------------------------------------------------------------------===// 38 | // Tensor types 39 | //===----------------------------------------------------------------------===// 40 | 41 | def HOM_Int32Tensor : TensorOf<[HOM_Int32]>; 42 | def HOM_Int32Or64Tensor : TensorOf<[HOM_Int32Or64]>; 43 | 44 | // Either ranked or unranked tensor of HOM supported element types. 45 | def HOM_Tensor : TensorOf<[HOM_AnyNumber]>; 46 | 47 | // Must be ranked but no further constraints 48 | def HOM_RankedTensor : RankedTensorOf<[HOM_AnyNumber]>; 49 | 50 | def HOM_UnrankedTensor : AnyTypeOf<[UnrankedTensorOf<[HOM_AnyNumber]>]>; 51 | 52 | def HOM_Tensor1D 53 | : AnyTypeOf<[HOM_UnrankedTensor, 1DTensorOf < [HOM_AnyNumber] > ]>; 54 | def HOM_Tensor2D 55 | : AnyTypeOf<[HOM_UnrankedTensor, 2DTensorOf < [HOM_AnyNumber] > ]>; 56 | def HOM_Tensor3D 57 | : AnyTypeOf<[HOM_UnrankedTensor, 3DTensorOf < [HOM_AnyNumber] > ]>; 58 | def HOM_Tensor4D 59 | : AnyTypeOf<[HOM_UnrankedTensor, 4DTensorOf < [HOM_AnyNumber] > ]>; 60 | def HOM_Tensor5D 61 | : AnyTypeOf<[HOM_UnrankedTensor, TensorRankOf<[HOM_AnyNumber], [5]>]>; 62 | 63 | // Any tensor element type allowed in HOM ops. 
64 | def HOM_ElementType 65 | : Type, "hom.dtype">; 66 | 67 | class HOM_TensorOfOrNone allowedTypes, string description = ""> 68 | : AnyTypeOf<[TensorOf, NoneType], description>; 69 | 70 | def HOM_MemRef 71 | : AnyTypeOf<[UnrankedMemRefOf<[HOM_AnyNumber]>, MemRefOf<[HOM_AnyNumber]>]>; 72 | 73 | def HOM_TensorOrMemRef : AnyTypeOf<[HOM_Tensor, HOM_MemRef]>; 74 | 75 | #endif // HOM_TYPES_BASE 76 | -------------------------------------------------------------------------------- /examples/torch/elementwise/add.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/Utils.h" 5 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 6 | #include "llvm/Support/Error.h" 7 | #include 8 | #include 9 | #include 10 | 11 | #define RowMajor(A, des, i, j, k) \ 12 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1] + \ 13 | (k) * (des).strides[2]]) 14 | 15 | int main() { 16 | 17 | auto a = allocHelper({3, 3, 3}, nvgpuAllocer); 18 | auto b = allocHelper({3, 1, 3}, nvgpuAllocer); 19 | 20 | auto Ades = static_cast *>(a.descriptor); 21 | float host_ptr_a[] = {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 22 | 0.7936, 0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 23 | 0.7411, 0.4294, 0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 24 | 0.4414, 0.2969, 0.8317, 0.1053, 0.2695, 0.3588}; 25 | float host_ptr_b[] = {0.1994, 0.5472, 0.0062, 0.9516, 0.0753, 26 | 0.8860, 0.5832, 0.3376, 0.809}; 27 | cudaMemcpy(Ades->data, host_ptr_a, sizeof(host_ptr_a), 28 | cudaMemcpyHostToDevice); 29 | 30 | auto Bdes = static_cast *>(b.descriptor); 31 | cudaMemcpy(Bdes->data, host_ptr_b, sizeof(host_ptr_b), 32 | cudaMemcpyHostToDevice); 33 | 34 | UnrankedMemRefType cc; 35 | mlir::hands_on_mlir::ExecutionEngine e("libadd_nvgpu.so"); 36 | 37 | auto res = e.invoke("forward", a.rank, a.descriptor, b.rank, b.descriptor, 38 | mlir::hands_on_mlir::ExecutionEngine::result(cc)); 39 | if (res) { 40 | llvm::handleAllErrors(std::move(res)); 41 | } 42 | 43 | res = e.invoke("forward", a.rank, a.descriptor, b.rank, b.descriptor, 44 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 45 | if (res) { 46 | llvm::handleAllErrors(std::move(res)); 47 | } 48 | 49 | auto c = DynamicMemRefType(cc); 50 | std::cout << c.rank << " " << c.sizes[0] << " " << c.sizes[1] << " " 51 | << c.sizes[2] << std::endl; 52 | cudaMemcpy(host_ptr_a, c.data, 53 | sizeof(float) * c.sizes[0] * c.sizes[1] * c.sizes[2], 54 | cudaMemcpyDeviceToHost); 55 | for (int i = 0; i < c.sizes[0]; i++) { 56 | for (int j = 0; j < c.sizes[1]; j++) { 57 | for (int k = 0; k < c.sizes[2]; k++) { 58 | std::cout << RowMajor(host_ptr_a, c, i, j, k) << " "; 59 | } 60 | } 61 | std::cout << std::endl; 62 | } 63 | 64 | cudaFree(Ades->data); 65 | cudaFree(Bdes->data); 66 | cudaFree(c.data); 67 | 68 | free(a.descriptor); 69 | free(b.descriptor); 70 | free(cc.descriptor); 71 | } 72 | -------------------------------------------------------------------------------- /include/NVGPUKernels/GatherRunner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/OperationRunner.h" 5 | #include "NVGPUKernels/Utils.h" 6 | #include "driver_types.h" 7 | #include "thrust/device_ptr.h" 8 | #include "thrust/iterator/counting_iterator.h" 9 | #include "thrust/iterator/transform_iterator.h" 
10 | #include "thrust/iterator/zip_iterator.h" 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | namespace mlir { 25 | namespace hands_on_mlir { 26 | namespace homnvgpu_kernel { 27 | 28 | template 29 | class GatherRunner : public mlir::hands_on_mlir::OperationRunner { 30 | public: 31 | template 32 | struct GetRealIndexFn : public std::unary_function { 33 | 34 | thrust::device_ptr indices; 35 | T row_length; 36 | 37 | __host__ __device__ constexpr T operator()(const T &x) const { 38 | return indices[x / row_length] * row_length + (x % row_length); 39 | } 40 | 41 | GetRealIndexFn(thrust::device_ptr indices_, T row_length_) 42 | : indices(indices_), row_length(row_length_) {} 43 | }; 44 | 45 | public: 46 | Status run(int rankIndices, void *desIndices, int rankValue, void *desValue, 47 | int rankOut, void *desOut) { 48 | auto indices = convertToDynamicMemRefType(rankIndices, desIndices); 49 | auto value = convertToDynamicMemRefType(rankValue, desValue); 50 | auto out = convertToDynamicMemRefType(rankOut, desOut); 51 | 52 | assert(value.rank == 3); 53 | assert(value.sizes[0] == 1); 54 | assert(indices.rank + 1 == out.rank); 55 | assert(indices.rank == 2); 56 | 57 | assert(out.sizes[out.rank - 1] == value.sizes[2]); 58 | for (auto i = 0; i < indices.rank; i++) { 59 | assert(indices.sizes[i] == out.sizes[i]); 60 | } 61 | 62 | auto indices_thrust_ptr = thrust::device_pointer_cast(indices.data); 63 | auto value_thrust_ptr = thrust::device_pointer_cast(value.data); 64 | auto out_thrust_ptr = thrust::device_pointer_cast(out.data); 65 | 66 | auto total_size = std::accumulate(out.sizes, out.sizes + out.rank, 1, 67 | std::multiplies<>()); 68 | 69 | auto map_iter = thrust::make_transform_iterator( 70 | thrust::make_counting_iterator(0), 71 | GetRealIndexFn(indices_thrust_ptr, value.sizes[2])); 72 | 73 | thrust::gather(map_iter, map_iter + total_size, value_thrust_ptr, 74 | out_thrust_ptr); 75 | 76 | return Status::kSuccess; 77 | } 78 | }; 79 | 80 | } // namespace homnvgpu_kernel 81 | } // namespace hands_on_mlir 82 | } // namespace mlir 83 | -------------------------------------------------------------------------------- /lib/Conversions/Tosa/TosaToHOM.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "Conversions/Tosa/Passes.h" 7 | #include "HOM/HOMOps.h" 8 | #include "mlir/Dialect/Func/IR/FuncOps.h" 9 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 10 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" 11 | #include "mlir/IR/BuiltinAttributes.h" 12 | #include "mlir/IR/PatternMatch.h" 13 | #include "mlir/Parser/Parser.h" 14 | #include "mlir/Pass/PassManager.h" 15 | #include "mlir/Pass/PassRegistry.h" 16 | #include "mlir/Support/LogicalResult.h" 17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 | 19 | #define PASS_NAME "tosa-to-hom" 20 | #define DEBUG_TYPE PASS_NAME 21 | 22 | namespace mlir { 23 | namespace hands_on_mlir { 24 | namespace hom { 25 | 26 | #define GEN_PASS_DEF_TOSATOHOMPASS 27 | #include "Conversions/Tosa/Passes.h.inc" 28 | #include "Conversions/Tosa/TosaToHOM.pdll.h.inc" 29 | 30 | namespace { 31 | 32 | struct TosaToHOMPass : impl::TosaToHOMPassBase { 33 | void runOnOperation() final; 34 | 35 | LogicalResult initialize(MLIRContext *ctx) override; 36 | 37 | private: 38 | FrozenRewritePatternSet patterns; 39 | }; 40 | 41 | struct ConvertTosaMatmulOp : public 
OpRewritePattern { 42 | using OpRewritePattern::OpRewritePattern; 43 | 44 | LogicalResult matchAndRewrite(tosa::MatMulOp op, 45 | PatternRewriter &rewriter) const override { 46 | auto loc = op.getLoc(); 47 | 48 | int useSize = 0; 49 | for (auto iter = op->getUsers().begin(); iter != op->getUsers().end(); 50 | ++iter, ++useSize) { 51 | if (useSize == 2) { 52 | break; 53 | } 54 | } 55 | 56 | if (useSize == 1) { 57 | if (auto addOp = llvm::dyn_cast(*(op->getUsers().begin()))) { 58 | auto homMatmulAddOp = rewriter.create( 59 | loc, addOp.getResult().getType(), op.getA(), op.getB(), 60 | addOp.getInput2()); 61 | while (!addOp->getUses().empty()) { 62 | addOp->getUses().begin()->set(homMatmulAddOp.getResult()); 63 | } 64 | rewriter.eraseOp(addOp); 65 | op->dropAllUses(); 66 | rewriter.eraseOp(op); 67 | return success(); 68 | } 69 | } 70 | 71 | auto homMatmulOp = rewriter.create( 72 | loc, op.getResult().getType(), op->getOperand(0), op->getOperand(1)); 73 | 74 | rewriter.replaceOp(op, homMatmulOp); 75 | 76 | return success(); 77 | } 78 | }; 79 | 80 | LogicalResult TosaToHOMPass::initialize(MLIRContext *ctx) { 81 | RewritePatternSet patternList(ctx); 82 | 83 | populateGeneratedPDLLPatterns(patternList); 84 | patternList.add(ctx); 85 | patterns = std::move(patternList); 86 | return success(); 87 | } 88 | 89 | void TosaToHOMPass::runOnOperation() { 90 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 91 | } 92 | 93 | } // namespace 94 | } // namespace hom 95 | } // namespace hands_on_mlir 96 | } // namespace mlir 97 | -------------------------------------------------------------------------------- /lib/ExecutionEngine/ExecutionEngine.cpp: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "Conversions/Function/FunctionUtils.h" 3 | #include "mlir/IR/BuiltinOps.h" 4 | #include "mlir/IR/BuiltinTypes.h" 5 | #include "llvm/Support/Error.h" 6 | #include 7 | #include 8 | 9 | #define DEBUG_TYPE "execution-engine" 10 | 11 | using namespace mlir; 12 | using namespace hands_on_mlir; 13 | using llvm::Error; 14 | using llvm::Expected; 15 | using llvm::StringError; 16 | 17 | /// Wrap a string into an llvm::StringError. 
18 | static Error makeStringError(const Twine &message) { 19 | return llvm::make_error<StringError>(message.str(), 20 | llvm::inconvertibleErrorCode()); 21 | } 22 | 23 | Expected<void *> ExecutionEngine::lookupHandle(StringRef name) const { 24 | if (handle) { 25 | auto symbol = dlsym(handle, name.str().c_str()); 26 | auto err = dlerror(); 27 | if (err) { 28 | return makeStringError(err); 29 | } else { 30 | return symbol; 31 | } 32 | } 33 | return makeStringError("Handle is invalid."); 34 | } 35 | 36 | Expected<void (*)(void **)> 37 | ExecutionEngine::lookupPacked(StringRef name) const { 38 | auto result = lookup(name); 39 | if (!result) 40 | return result.takeError(); 41 | return reinterpret_cast<void (*)(void **)>(result.get()); 42 | } 43 | 44 | Expected<void *> ExecutionEngine::lookup(StringRef name) const { 45 | const std::string lookupName = "_hom_ciface_" + name.str(); 46 | auto expectedSymbol = lookupHandle(lookupName); 47 | 48 | if (!expectedSymbol) { 49 | std::string errorMessage; 50 | llvm::raw_string_ostream os(errorMessage); 51 | llvm::handleAllErrors(expectedSymbol.takeError(), 52 | [&os](llvm::ErrorInfoBase &ei) { ei.log(os); }); 53 | return makeStringError(os.str()); 54 | } 55 | 56 | if (void *fptr = expectedSymbol.get()) 57 | return fptr; 58 | return makeStringError("looked up function is null"); 59 | } 60 | 61 | Error ExecutionEngine::invokePacked(StringRef name, 62 | MutableArrayRef<void *> args) { 63 | auto expectedFPtr = lookupPacked(name); 64 | if (!expectedFPtr) 65 | return expectedFPtr.takeError(); 66 | auto fptr = *expectedFPtr; 67 | 68 | (*fptr)(args.data()); 69 | 70 | return Error::success(); 71 | } 72 | 73 | Expected<PackedArguments> 74 | ExecutionEngine::invokeInit(StringRef name) { 75 | SmallVector<char> concatName; 76 | 77 | int32_t argNum; 78 | 79 | auto argNumError = 80 | invokeInternal((name + kArgNum).toStringRef(concatName), result(argNum)); 81 | if (argNumError) { 82 | llvm::handleAllErrors(std::move(argNumError)); 83 | } 84 | 85 | concatName.clear(); 86 | 87 | PackedArguments packedArgs(argNum); 88 | 89 | auto packedMemRefRes = invokeInternal((name + kInit).toStringRef(concatName), 90 | result(packedArgs)); 91 | 92 | if (packedMemRefRes) { 93 | return packedMemRefRes; 94 | } else { 95 | return packedArgs; 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /examples/torch/bert/run_hom.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/bash 2 | 3 | # Generated by Kimi Chat 4 | 5 | clang++-17 bert.cu -O3 -I../../../thirdparty/cutlass/tools/library/include -I../../../include/ -I../../../thirdparty/llvm-project/mlir/include/ -I../../../thirdparty/TransformerEngine/transformer_engine/common/include -I../../../thirdparty/llvm-project/llvm/include/ -I../../../thirdparty/cutlass/include/ -I../../../thirdparty/llvm-project/build/include/ -L./ -L../../../build/lib/ -L../../../thirdparty/llvm-project/build/lib -L../../../thirdparty/TransformerEngine -lLLVM-17 -lhands_on_mlir_runner_utils -lhands_on_mlir_nvgpu_runner_utils -lhands_on_mlir_execution_engine -ltransformer_engine -ldl -lpthread -lrt -L$CUDA_HOME/lib64 \ 6 | -lcudart_static -Wl,-rpath,../../../build/lib -Wl,-rpath,../../../thirdparty/TransformerEngine -Wl,-rpath,../../../thirdparty/llvm-project/build/lib -Wl,-rpath,./ --cuda-gpu-arch=sm_86 -std=gnu++17 -o run 7 | 8 | # Iterate over all matching .mlir files in the current directory and its subdirectories 9 | find . -type f -name "bert-*.mlir" | while read mlir_file; do 10 | # Define the output file name, replacing the .mlir suffix with .so 11 | output_file="${mlir_file%.mlir}" 12 | output_file="${output_file#./}" 13 | 14 | ../../../build/bin/hands-on-opt --tosa-to-hom-pipeline --hom-fusion --hom-fp32-to-fp16 --hom-to-homnvgpu --homnvgpu-fusion $mlir_file > pre_tune.mlir 15 | 16 | # Check whether the compilation succeeded 17 | if [ $? -eq 0 ]; then 18 | echo "Compilation successful: $output_file" 19 | else 20 | echo "Compilation failed for $mlir_file" 21 | fi 22 | 23 | pattern="bert-.*_([0-9]+)_([0-9]+)\.mlir" 24 | 25 | if [[ $mlir_file =~ $pattern ]]; then 26 | bs="${BASH_REMATCH[1]}" 27 | seq_len="${BASH_REMATCH[2]}" 28 | 29 | # Print the extracted values 30 | echo "bs: $bs, seq_len: $seq_len" 31 | fi 32 | 33 | ../../../build/bin/hands-on-opt --homnvgpu-legalize-gemm --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --hom-func-to-llvm-pipeline pre_tune.mlir | \ 34 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 35 | ../../../thirdparty/llvm-project/build/bin/llc > $output_file.s 36 | 37 | clang++-17 $output_file.s -O3 -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o lib${output_file}.so 38 | 39 | ./run $bs $seq_len $output_file 0 40 | 41 | nsys profile -o true_sync_$output_file ./run $bs $seq_len $output_file 0 42 | 43 | ../../../build/bin/hands-on-opt --homnvgpu-autotune --homnvgpu-legalize-gemm --tosa-layerwise-constant-fold --hom-serialize-weight --homnvgpu-to-func --hom-func-to-llvm-pipeline pre_tune.mlir | \ 44 | ../../../thirdparty/llvm-project/build/bin/mlir-translate --mlir-to-llvmir |\ 45 | ../../../thirdparty/llvm-project/build/bin/llc > autotune_$output_file.s 46 | 47 | clang++-17 autotune_$output_file.s -O3 -fPIC -shared -L../../../build/lib/ -lhands_on_mlir_execution_engine -lhands_on_mlir_nvgpu_runner_utils -L../../../thirdparty/llvm-project/build/lib -lLLVM-17 -std=gnu++17 -g -o libautotune_${output_file}.so 48 | 49 | ./run $bs $seq_len $output_file 1 50 | 51 | nsys profile -o autotune_true_sync_$output_file ./run $bs $seq_len $output_file 1 52 | 53 | done 54 | 55 | echo "Compilation process completed for all matching files." 56 | -------------------------------------------------------------------------------- /lib/ExecutionEngine/CutlassCAPI.cu: -------------------------------------------------------------------------------- 1 | #include <cuda_runtime.h> 2 | 3 | #include "cutlass/cutlass.h" 4 | #include "cutlass/epilogue/collective/default_epilogue.hpp" 5 | #include "cutlass/epilogue/thread/linear_combination.h" 6 | #include "cutlass/gemm/collective/collective_builder.hpp" 7 | #include "cutlass/gemm/device/gemm.h" 8 | #include "cutlass/gemm/device/gemm_universal_adapter.h" 9 | #include "cutlass/gemm/kernel/gemm_universal.hpp" 10 | 11 | #include "cutlass/half.h" 12 | #include "cutlass/util/host_tensor.h" 13 | #include "cutlass/util/packed_stride.hpp" 14 | 15 | using namespace cute; 16 | // Copy examples from cutlass here just to make sure we can really compile 17 | // cutlass. 18 | auto cutlassMatmul(int M, int N, int K, float alpha, float const *A, int lda, 19 | float const *B, int ldb, float beta, float *C, int ldc) { 20 | 21 | // Define type definition for single-precision CUTLASS GEMM with 22 | // column-major input matrices and 128x128x8 threadblock tile size (chosen 23 | // by default). 
24 | // 25 | // To keep the interface manageable, several helpers are defined for 26 | // plausible compositions including the following example for 27 | // single-precision GEMM. Typical values are used as default template 28 | // arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for 29 | // more details. 30 | // 31 | // To view the full gemm device API interface, see 32 | // `cutlass/gemm/device/gemm.h` 33 | 34 | using ColumnMajor = cutlass::layout::ColumnMajor; 35 | 36 | using CutlassGemm = 37 | cutlass::gemm::device::Gemm<float,       // Data-type of A matrix 38 | ColumnMajor, // Layout of A matrix 39 | float,       // Data-type of B matrix 40 | ColumnMajor, // Layout of B matrix 41 | float,       // Data-type of C matrix 42 | ColumnMajor>; // Layout of C matrix 43 | 44 | // Define a CUTLASS GEMM type 45 | CutlassGemm gemm_operator; 46 | 47 | // Construct the CUTLASS GEMM arguments object. 48 | // 49 | // One of CUTLASS's design patterns is to define gemm argument objects that 50 | // are constructible in host code and passed to kernels by value. These may 51 | // include pointers, strides, scalars, and other arguments needed by Gemm 52 | // and its components. 53 | // 54 | // The benefits of this pattern are (1.) a structured, composable strategy 55 | // for passing host-constructible arguments to kernels and (2.) minimized 56 | // initialization overhead on kernel entry. 57 | // 58 | CutlassGemm::Arguments args( 59 | {M, N, K}, // Gemm Problem dimensions 60 | {A, lda}, // Tensor-ref for source matrix A 61 | {B, ldb}, // Tensor-ref for source matrix B 62 | {C, ldc}, // Tensor-ref for source matrix C 63 | {C, ldc}, // Tensor-ref for destination matrix D (may be different 64 | // memory than source C matrix) 65 | {alpha, beta}); // Scalars used in the Epilogue 66 | 67 | // 68 | // Launch the CUTLASS GEMM kernel. 69 | // 70 | 71 | cutlass::Status status = gemm_operator(args); 72 | 73 | // 74 | // Return a cudaError_t if the CUTLASS GEMM operator returned an error code. 75 | // 76 | 77 | if (status != cutlass::Status::kSuccess) { 78 | return cudaErrorUnknown; 79 | } 80 | 81 | // Return success, if no errors were encountered. 
82 | return cudaSuccess; 83 | } 84 | -------------------------------------------------------------------------------- /include/Conversions/MatMulCPUOptimize/Passes.h: -------------------------------------------------------------------------------- 1 | #ifndef HOM_MATMULCPUOPTIMIZE_PASSES_H 2 | #define HOM_MATMULCPUOPTIMIZE_PASSES_H 3 | 4 | #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 5 | #include "mlir/Dialect/Affine/Analysis/Utils.h" 6 | #include "mlir/Dialect/Affine/IR/AffineOps.h" 7 | #include "mlir/Dialect/Affine/LoopUtils.h" 8 | #include "mlir/Dialect/Affine/Utils.h" 9 | #include "mlir/Dialect/Arith/IR/Arith.h" 10 | #include "mlir/Dialect/Func/IR/FuncOps.h" 11 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 12 | #include "mlir/Dialect/Math/IR/Math.h" 13 | #include "mlir/IR/BuiltinDialect.h" 14 | #include "mlir/IR/Dialect.h" 15 | #include "mlir/IR/IntegerSet.h" 16 | #include "mlir/IR/Operation.h" 17 | #include "mlir/IR/TypeUtilities.h" 18 | #include "mlir/IR/Value.h" 19 | #include "mlir/Pass/Pass.h" 20 | #include "mlir/Transforms/DialectConversion.h" 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 | #include "llvm/Support/Debug.h" 23 | 24 | #include "mlir/Pass/Pass.h" 25 | #include <memory> 26 | 27 | namespace mlir { 28 | namespace hands_on_mlir { 29 | struct MatMulCPUOptimizePass 30 | : public PassWrapper<MatMulCPUOptimizePass, OperationPass<func::FuncOp>> { 31 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulCPUOptimizePass) 32 | 33 | StringRef getArgument() const final { return "matmul-cpu-optimize"; } 34 | StringRef getDescription() const final { 35 | return "MatMul Optimization on CPU."; 36 | } 37 | 38 | MatMulCPUOptimizePass() = default; 39 | MatMulCPUOptimizePass(const MatMulCPUOptimizePass &) {} 40 | 41 | void getDependentDialects(DialectRegistry &registry) const override { 42 | registry.insert<affine::AffineDialect, arith::ArithDialect, 43 | func::FuncDialect, linalg::LinalgDialect, math::MathDialect>(); 44 | } 45 | void runOnOperation() final; 46 | 47 | // Copied from llvm-project 48 | template <typename AttributeT> 49 | void simplifyAndUpdateAttribute(Operation *op, StringAttr name, 50 | AttributeT attr) { 51 | auto &simplified = simplifiedAttributes[attr]; 52 | if (simplified == attr) 53 | return; 54 | 55 | // This is a newly encountered attribute. 56 | if (!simplified) { 57 | // Try to simplify the value of the attribute. 58 | auto value = attr.getValue(); 59 | auto simplifiedValue = simplify(value); 60 | if (simplifiedValue == value) { 61 | simplified = attr; 62 | return; 63 | } 64 | simplified = AttributeT::get(simplifiedValue); 65 | } 66 | 67 | // Simplification was successful, so update the attribute. 68 | op->setAttr(name, simplified); 69 | } 70 | 71 | IntegerSet simplify(IntegerSet set) { 72 | return affine::simplifyIntegerSet(set); 73 | } 74 | 75 | /// Performs basic affine map simplifications. 76 | AffineMap simplify(AffineMap map) { 77 | MutableAffineMap mMap(map); 78 | mMap.simplify(); 79 | return mMap.getAffineMap(); 80 | } 81 | 82 | DenseMap<Attribute, Attribute> simplifiedAttributes; 83 | 84 | // Copy end. 
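// Blocking parameters for the generated loop nest: a 6x16 register-tile micro-kernel wrapped in K/M/N cache blocks, sized with the intent of keeping the packed panels cache-resident.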
85 | 86 | const size_t M_KERNEL_SIZE = 6; 87 | const size_t N_KERNEL_SIZE = 16; 88 | const int32_t K_BLOCK_SIZE = 1024; 89 | const int32_t M_BLOCK_SIZE = 384; 90 | const int32_t N_BLOCK_SIZE = 1024; 91 | }; 92 | 93 | inline void registerMatMulCPUOptimizePass() { 94 | PassRegistration(); 95 | } 96 | } // namespace hands_on_mlir 97 | 98 | } // namespace mlir 99 | 100 | #endif // HOM_CONVERSIONS_FUNC_TRANSFORMS_PASSES_H 101 | -------------------------------------------------------------------------------- /examples/torch/linear/cuda/c.cu: -------------------------------------------------------------------------------- 1 | #include "cutlass/arch/arch.h" 2 | #include "cutlass/arch/mma.h" 3 | #include "cutlass/arch/wmma.h" 4 | #include "cutlass/cutlass.h" 5 | #include "cutlass/gemm/device/gemm.h" 6 | #include "cutlass/gemm/device/gemm_universal.h" 7 | #include "cutlass/gemm/device/gemm_universal_adapter.h" 8 | #include "cutlass/gemm/gemm_enumerated_types.h" 9 | #include "cutlass/gemm/kernel/default_gemm_universal.h" 10 | #include "cutlass/half.h" 11 | #include "cutlass/layout/matrix.h" 12 | #include "cutlass/numeric_types.h" 13 | 14 | using cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2_base = 15 | typename cutlass::gemm::kernel::DefaultGemmUniversal< 16 | cutlass::half_t, cutlass::layout::RowMajor, 17 | cutlass::ComplexTransform::kNone, 2, // transposed B operand 18 | cutlass::half_t, cutlass::layout::RowMajor, 19 | cutlass::ComplexTransform::kNone, 2, // transposed A operand 20 | cutlass::half_t, cutlass::layout::RowMajor, float, 21 | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, 22 | cutlass::gemm::GemmShape<256, 128, 32>, 23 | cutlass::gemm::GemmShape<64, 64, 32>, 24 | cutlass::gemm::GemmShape<16, 8, 8>, 25 | 26 | cutlass::epilogue::thread::LinearCombination, 28 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, 2, 29 | cutlass::arch::OpMultiplyAdd>::GemmKernel; 30 | 31 | // Define named type 32 | struct cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2 33 | : public cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2_base {}; 34 | 35 | int main() { 36 | int64_t m = 64, n = 768, k = 768; 37 | 38 | cutlass::gemm::device::GemmUniversal< 39 | cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, 40 | cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::RowMajor, 41 | float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, 42 | cutlass::gemm::GemmShape<256, 128, 32>, 43 | cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 8>, 44 | cutlass::epilogue::thread::LinearCombination, 46 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, 2, 2, 2, 47 | cutlass::arch::OpMultiplyAdd> 48 | gemm; 49 | 50 | half *a, *b, *c, *d; 51 | cudaMalloc(&a, sizeof(half) * m * k); 52 | cudaMalloc(&b, sizeof(half) * n * k); 53 | cudaMalloc(&c, sizeof(half) * m * n); 54 | cudaMalloc(&d, sizeof(half) * m * n); 55 | 56 | cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align2_base::Arguments 57 | args(cutlass::gemm::GemmUniversalMode::kBatched, {int(m), int(n), int(k)}, 58 | 1, {1.0f, 0.0f}, a, b, c, c, m * k, n * k, m * n, m * n, k, n, n, n); 59 | 60 | decltype(gemm)::Arguments argsGemm( 61 | cutlass::gemm::GemmUniversalMode::kGemm, {int(m), int(n), int(k)}, 1, 62 | {1.0, 1.0}, a, b, c, d, m * k, n * k, m * n, m * n, k, n, n, n); 63 | 64 | auto res = gemm.can_implement(argsGemm); 65 | if (res != cutlass::Status::kSuccess) { 66 | std::cout << "Not good" << std::endl; 67 | } 68 | 69 | res = gemm.initialize(argsGemm); 70 | 71 | 
if (res != cutlass::Status::kSuccess) { 72 | std::cout << "Not good" << std::endl; 73 | } 74 | 75 | std::cout << "Before run." << std::endl; 76 | 77 | gemm(argsGemm); 78 | 79 | res = gemm.run(); 80 | 81 | auto err = cudaStreamSynchronize(nullptr); 82 | 83 | if (err != cudaSuccess) { 84 | std::cout << "A: " << cudaGetErrorString(err) << std::endl; 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /examples/torch/bert_attention/bert_attn.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/Utils.h" 5 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 6 | #include "llvm/Support/Error.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #define RowMajor(A, des, i, j, k) \ 18 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1] + \ 19 | (k) * (des).strides[2]]) 20 | 21 | int main() { 22 | constexpr int64_t seq_len = 64; 23 | constexpr int64_t hidden_size = 768; 24 | auto hidden_state = 25 | allocHelper({1, seq_len, hidden_size}, nvgpuAllocer); 26 | auto mask = allocHelper({1, seq_len}, nvgpuAllocer); 27 | 28 | auto hidden_des = 29 | static_cast *>(hidden_state.descriptor); 30 | auto mask_des = static_cast *>(mask.descriptor); 31 | 32 | std::vector hidden_data(hidden_size * seq_len); 33 | 34 | std::ifstream in; 35 | in.open("0.txt"); 36 | float a; 37 | size_t ii = 0; 38 | while (in >> a) { 39 | assert(ii < hidden_data.size()); 40 | hidden_data[ii++] = a; 41 | } 42 | 43 | checkCudaErrors(cudaMemcpy(hidden_des->data, hidden_data.data(), 44 | sizeof(half) * hidden_data.size(), 45 | cudaMemcpyHostToDevice)); 46 | 47 | int64_t mask_data[seq_len]; 48 | for (auto &i : mask_data) { 49 | i = 1; 50 | } 51 | checkCudaErrors(cudaMemcpy(mask_des->data, mask_data, sizeof(mask_data), 52 | cudaMemcpyHostToDevice)); 53 | 54 | UnrankedMemRefType b; 55 | mlir::hands_on_mlir::ExecutionEngine e("libbert_attn_nvgpu.so"); 56 | 57 | auto res = e.invoke("forward", hidden_state.rank, hidden_state.descriptor, 58 | mask.rank, mask.descriptor, 59 | mlir::hands_on_mlir::ExecutionEngine::result(b)); 60 | if (res) { 61 | llvm::handleAllErrors(std::move(res)); 62 | } 63 | 64 | // res = e.invoke("forward", hidden_state.rank, hidden_state.descriptor, 65 | // mask.rank, mask.descriptor, 66 | // mlir::hands_on_mlir::ExecutionEngine::result(b)); 67 | // if (res) { 68 | // llvm::handleAllErrors(std::move(res)); 69 | // } 70 | 71 | auto c = DynamicMemRefType(b); 72 | std::cout << c.rank << std::endl; 73 | assert(std::accumulate(c.sizes, c.sizes + c.rank, 1, std::multiplies<>()) == 74 | hidden_data.size()); 75 | 76 | std::vector thing; 77 | in.close(); 78 | in.open("1.txt"); 79 | while (in >> a) { 80 | thing.emplace_back(a); 81 | } 82 | checkCudaErrors(cudaMemcpy(hidden_data.data(), c.data, 83 | sizeof(half) * hidden_data.size(), 84 | cudaMemcpyDeviceToHost)); 85 | for (int i = 0; i < c.sizes[0]; i++) { 86 | for (int j = 0; j < c.sizes[1]; j++) { 87 | for (int k = 0; k < c.sizes[2]; k++) { 88 | std::cout << float(RowMajor(hidden_data, c, i, j, k) - 89 | RowMajor(thing, c, i, j, k)) 90 | << " "; 91 | } 92 | std::cout << std::endl; 93 | } 94 | std::cout << std::endl; 95 | } 96 | std::cout << std::endl; 97 | 98 | cudaFree(hidden_des->data); 99 | cudaFree(c.data); 100 | 101 | 
free(hidden_state.descriptor); 102 | free(mask.descriptor); 103 | free(b.descriptor); 104 | } 105 | -------------------------------------------------------------------------------- /lib/Dialect/HOMNVGPU/HOMNVGPUAutotunePass.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "Conversions/Tosa/Passes.h" 12 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 13 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 14 | #include "HOMNVGPU/HOMNVGPUOps.h" 15 | #include "NVGPUKernels/GemmProfiler.h" 16 | #include "cuda_runtime_api.h" 17 | #include "driver_types.h" 18 | #include "mlir/Dialect/Func/IR/FuncOps.h" 19 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 20 | #include "mlir/IR/Builders.h" 21 | #include "mlir/IR/BuiltinAttributeInterfaces.h" 22 | #include "mlir/IR/BuiltinAttributes.h" 23 | #include "mlir/IR/BuiltinTypes.h" 24 | #include "mlir/IR/Location.h" 25 | #include "mlir/IR/Operation.h" 26 | #include "mlir/IR/PatternMatch.h" 27 | #include "mlir/IR/Value.h" 28 | #include "mlir/Parser/Parser.h" 29 | #include "mlir/Support/LogicalResult.h" 30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31 | #include "llvm/ADT/APInt.h" 32 | #include "llvm/ADT/ArrayRef.h" 33 | #include "llvm/ADT/SmallPtrSet.h" 34 | #include "llvm/ADT/SmallVector.h" 35 | #include "llvm/Support/Casting.h" 36 | #include "llvm/Support/ErrorHandling.h" 37 | #include "llvm/Support/raw_ostream.h" 38 | 39 | #define PASS_NAME "homnvgpu-autotune" 40 | #define DEBUG_TYPE PASS_NAME 41 | 42 | namespace mlir { 43 | namespace hands_on_mlir { 44 | namespace homnvgpu { 45 | 46 | #define GEN_PASS_DEF_HOMNVGPUAUTOTUNEPASS 47 | #include "HOMNVGPU/HOMNVGPUAutotune.pdll.h.inc" 48 | #include "HOMNVGPU/Passes.h.inc" 49 | 50 | namespace { 51 | 52 | static void profileMatmulImpl(PatternRewriter &rewriter, Operation *gemm_) { 53 | using namespace homnvgpu_kernel; 54 | auto gemm = dyn_cast(gemm_); 55 | 56 | if (!gemm.getResult().getType().getElementType().isF16()) { 57 | llvm_unreachable("Not supported."); 58 | } 59 | 60 | auto A = gemm.getOperand0().getType().getShape(); 61 | auto B = gemm.getOperand1().getType().getShape(); 62 | 63 | // To-do: Use a dedicated logger to log this. 
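// For now, dump the A/B operand shapes to stderr so each profiling run can be matched to the GEMM it tuned.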
64 | for (auto i : A) { 65 | std::cerr << i << " "; 66 | } 67 | 68 | std::cerr << std::endl; 69 | 70 | for (auto i : B) { 71 | std::cerr << i << " "; 72 | } 73 | 74 | std::cerr << std::endl; 75 | 76 | auto M = A[0] * A[1], N = B[2], K = A[2]; 77 | auto alpha = gemm.getAlpha().convertToFloat(); 78 | auto beta = gemm.getBeta().convertToFloat(); 79 | auto act = gemm.getAct(); 80 | 81 | if (act != 0 && act != 1) { 82 | return; 83 | } 84 | 85 | static GemmProfiler profiler(M, N, K, act, alpha, beta); 86 | 87 | auto [bestIdx, bestSplitKFactor] = 88 | profiler.profile(M, N, K, act, alpha, beta); 89 | 90 | gemm.setKernelName(bestIdx + 1); 91 | gemm.setSplitKFactor(bestSplitKFactor); 92 | } 93 | 94 | struct HOMNVGPUAutotunePass 95 | : impl::HOMNVGPUAutotunePassBase { 96 | void runOnOperation() final; 97 | 98 | LogicalResult initialize(MLIRContext *ctx) override; 99 | 100 | private: 101 | FrozenRewritePatternSet patterns; 102 | }; 103 | 104 | LogicalResult HOMNVGPUAutotunePass::initialize(MLIRContext *ctx) { 105 | RewritePatternSet patternList(ctx); 106 | 107 | populateGeneratedPDLLPatterns(patternList); 108 | patternList.getPDLPatterns().registerRewriteFunction("profileMatmul", 109 | profileMatmulImpl); 110 | patterns = std::move(patternList); 111 | return success(); 112 | } 113 | 114 | void HOMNVGPUAutotunePass::runOnOperation() { 115 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 116 | } 117 | 118 | } // namespace 119 | } // namespace homnvgpu 120 | } // namespace hands_on_mlir 121 | } // namespace mlir 122 | -------------------------------------------------------------------------------- /examples/torch/bert_attention/bert_self_attn.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/ExecutionEngine.h" 2 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 3 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 4 | #include "NVGPUKernels/Utils.h" 5 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 6 | #include "llvm/Support/Error.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | struct Res { 18 | UnrankedMemRefType a; 19 | }; 20 | 21 | #define RowMajor(A, des, i, j, k) \ 22 | ((A)[(i) * (des).strides[0] + (j) * (des).strides[1] + \ 23 | (k) * (des).strides[2]]) 24 | 25 | int main() { 26 | constexpr int64_t seq_len = 64; 27 | constexpr int64_t hidden_size = 128; 28 | auto hidden_state = 29 | allocHelper({1, seq_len, hidden_size}, nvgpuAllocer); 30 | auto mask = allocHelper({1, seq_len}, nvgpuAllocer); 31 | 32 | auto hidden_des = 33 | static_cast *>(hidden_state.descriptor); 34 | auto mask_des = static_cast *>(mask.descriptor); 35 | 36 | std::vector hidden_data(hidden_size * seq_len); 37 | 38 | std::ifstream in; 39 | in.open("0.txt"); 40 | float a; 41 | size_t ii = 0; 42 | while (in >> a) { 43 | assert(ii < hidden_data.size()); 44 | hidden_data[ii++] = a; 45 | } 46 | 47 | checkCudaErrors(cudaMemcpy(hidden_des->data, hidden_data.data(), 48 | sizeof(half) * hidden_data.size(), 49 | cudaMemcpyHostToDevice)); 50 | 51 | int64_t mask_data[seq_len]; 52 | for (auto &i : mask_data) { 53 | i = 1; 54 | } 55 | checkCudaErrors(cudaMemcpy(mask_des->data, mask_data, sizeof(mask_data), 56 | cudaMemcpyHostToDevice)); 57 | 58 | UnrankedMemRefType b; 59 | mlir::hands_on_mlir::ExecutionEngine e("libbert_self_attn_nvgpu.so"); 60 | 61 | auto res = e.invoke("forward", hidden_state.rank, hidden_state.descriptor, 62 | mask.rank, mask.descriptor, 63 | 
mlir::hands_on_mlir::ExecutionEngine::result(b)); 64 | if (res) { 65 | llvm::handleAllErrors(std::move(res)); 66 | } 67 | 68 | // res = e.invoke("forward", hidden_state.rank, hidden_state.descriptor, 69 | // mask.rank, mask.descriptor, 70 | // mlir::hands_on_mlir::ExecutionEngine::result(b)); 71 | // if (res) { 72 | // llvm::handleAllErrors(std::move(res)); 73 | // } 74 | 75 | auto c = DynamicMemRefType(b); 76 | std::cout << c.rank << std::endl; 77 | assert(std::accumulate(c.sizes, c.sizes + c.rank, 1, std::multiplies<>()) == 78 | hidden_data.size()); 79 | 80 | std::vector thing; 81 | in.close(); 82 | in.open("1.txt"); 83 | while (in >> a) { 84 | thing.emplace_back(a); 85 | } 86 | checkCudaErrors(cudaMemcpy(hidden_data.data(), c.data, 87 | sizeof(half) * hidden_data.size(), 88 | cudaMemcpyDeviceToHost)); 89 | for (int i = 0; i < c.sizes[0]; i++) { 90 | for (int j = 0; j < c.sizes[1]; j++) { 91 | for (int k = 0; k < c.sizes[2]; k++) { 92 | std::cout << float(RowMajor(hidden_data, c, i, j, k) - 93 | RowMajor(thing, c, i, j, k)) 94 | << " "; 95 | } 96 | std::cout << std::endl; 97 | } 98 | std::cout << std::endl; 99 | } 100 | std::cout << std::endl; 101 | 102 | cudaFree(hidden_des->data); 103 | cudaFree(c.data); 104 | 105 | free(hidden_state.descriptor); 106 | free(mask.descriptor); 107 | free(b.descriptor); 108 | } 109 | -------------------------------------------------------------------------------- /include/Conversions/Tosa/Passes.h: -------------------------------------------------------------------------------- 1 | #ifndef HOM_CONVERSIONS_TOSA_TRANSFORMS_PASSES_H 2 | #define HOM_CONVERSIONS_TOSA_TRANSFORMS_PASSES_H 3 | 4 | #include 5 | #include 6 | 7 | #include "WeightsEngine/Utils.h" 8 | #include "mlir/Dialect/Func/IR/FuncOps.h" 9 | #include "mlir/Dialect/PDL/IR/PDL.h" 10 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 11 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 12 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" 13 | #include "mlir/Pass/Pass.h" 14 | #include "mlir/Pass/PassManager.h" 15 | 16 | namespace mlir { 17 | namespace hands_on_mlir { 18 | namespace hom { 19 | 20 | #define GEN_PASS_DECL_TOSATOHOMPASS 21 | #define GEN_PASS_DECL_TOSACONSTANTFOLDINGPASS 22 | #define GEN_PASS_REGISTRATION 23 | #include "Conversions/Tosa/Passes.h.inc" 24 | 25 | template std::shared_ptr getNewData(size_t size) { 26 | static std::shared_ptr newData; 27 | static size_t newDataSize = 0; 28 | if (size > newDataSize) { 29 | newData.reset(new T[size]); 30 | newDataSize = size; 31 | } 32 | return newData; 33 | } 34 | 35 | template 36 | auto doCastFolding(T *data, Type newType, ArrayRef size, 37 | size_t totalSize) { 38 | auto newData = getNewData(totalSize); 39 | for (size_t i = 0; i < totalSize; i++) { 40 | newData.get()[i] = data[i]; 41 | } 42 | return getDenseElementsAttr(newType, size, newData.get(), totalSize); 43 | } 44 | 45 | template 46 | tosa::ConstOp foldCast(PatternRewriter &rewriter, tosa::ConstOp oldConst, 47 | Type newType, T *data, ArrayRef size) { 48 | auto requestNewDataSize = 49 | std::accumulate(size.begin(), size.end(), 1, std::multiplies()); 50 | 51 | DenseElementsAttr denseAttr; 52 | 53 | if (newType.isF32()) { 54 | denseAttr = 55 | doCastFolding(data, newType, size, requestNewDataSize); 56 | } else if (newType.isF16()) { 57 | denseAttr = doCastFolding(data, newType, size, requestNewDataSize); 58 | } else if (newType.isIntOrIndex()) { 59 | auto intType = llvm::dyn_cast(newType); 60 | switch (intType.getWidth()) { 61 | case 64: 62 | denseAttr = 63 | doCastFolding(data, newType, 
34 | 35 | template <typename T, typename CastType> 36 | auto doCastFolding(T *data, Type newType, ArrayRef<int64_t> size, 37 | size_t totalSize) { 38 | auto newData = getNewData<CastType>(totalSize); 39 | for (size_t i = 0; i < totalSize; i++) { 40 | newData.get()[i] = data[i]; 41 | } 42 | return getDenseElementsAttr(newType, size, newData.get(), totalSize); 43 | } 44 | 45 | template <typename T> 46 | tosa::ConstOp foldCast(PatternRewriter &rewriter, tosa::ConstOp oldConst, 47 | Type newType, T *data, ArrayRef<int64_t> size) { 48 | auto requestNewDataSize = 49 | std::accumulate(size.begin(), size.end(), 1, std::multiplies()); 50 | 51 | DenseElementsAttr denseAttr; 52 | 53 | if (newType.isF32()) { 54 | denseAttr = 55 | doCastFolding<T, float>(data, newType, size, requestNewDataSize); 56 | } else if (newType.isF16()) { 57 | denseAttr = doCastFolding<T, _Float16>(data, newType, size, requestNewDataSize); 58 | } else if (newType.isIntOrIndex()) { 59 | auto intType = llvm::dyn_cast<IntegerType>(newType); 60 | switch (intType.getWidth()) { 61 | case 64: 62 | denseAttr = 63 | doCastFolding<T, int64_t>(data, newType, size, requestNewDataSize); 64 | break; 65 | case 32: 66 | denseAttr = 67 | doCastFolding<T, int32_t>(data, newType, size, requestNewDataSize); 68 | break; 69 | case 16: 70 | denseAttr = 71 | doCastFolding<T, int16_t>(data, newType, size, requestNewDataSize); 72 | break; 73 | case 8: 74 | denseAttr = 75 | doCastFolding<T, int8_t>(data, newType, size, requestNewDataSize); 76 | break; 77 | default: 78 | llvm_unreachable("Unsupported integer width."); 79 | } 80 | } else { 81 | llvm_unreachable("Unsupported type."); 82 | } 83 | 84 | return rewriter.create<tosa::ConstOp>(oldConst->getLoc(), denseAttr.getType(), 85 | denseAttr); 86 | } 87 | 88 | inline void registerTosaToHOMPipelines() { 89 | PassPipelineRegistration<>( 90 | "tosa-to-hom-pipeline", 91 | "Convert TOSA operators to hom with some optimization", 92 | [](OpPassManager &pm) { 93 | tosa::TosaLayerwiseConstantFoldPassOptions tosaConstFoldOption; 94 | tosaConstFoldOption.aggressiveReduceConstant = true; 95 | pm.addPass( 96 | tosa::createTosaLayerwiseConstantFoldPass(tosaConstFoldOption)); 97 | pm.addPass(createTosaConstantFoldingPass()); 98 | pm.addPass( 99 | tosa::createTosaLayerwiseConstantFoldPass(tosaConstFoldOption)); 100 | pm.addPass(createTosaToHOMPass()); 101 | }); 102 | } 103 | 104 | } // namespace hom 105 | } // namespace hands_on_mlir 106 | } // namespace mlir 107 | 108 | #endif // HOM_CONVERSIONS_TOSA_TRANSFORMS_PASSES_H 109 | -------------------------------------------------------------------------------- /examples/mlir/cpu_gemm/naive.mlir: -------------------------------------------------------------------------------- 1 | // Driver for sgemm matmul with initialization and GFLOPS reporting. 2 | func.func @main() { 3 | %A = memref.alloc() : memref<2088x2048xf32> 4 | // Align %B and %C since these are shape cast to vector types. 5 | %B = memref.alloc() {alignment = 32} : memref<2048x2048xf32> 6 | %C = memref.alloc() {alignment = 32} : memref<2088x2048xf32> 7 | %C1 = memref.alloc() {alignment = 32} : memref<2088x2048xf32> 8 | 9 | %cf1 = arith.constant 1.00000e+01 : f32 // Large cf1 here to ensure beta is correct.
10 | 11 | %AA = memref.cast %A : memref<2088x2048xf32> to memref<*xf32> 12 | %BB = memref.cast %B : memref<2048x2048xf32> to memref<*xf32> 13 | %CC1 = memref.cast %C : memref<2088x2048xf32> to memref<*xf32> 14 | %CC2 = memref.cast %C1 : memref<2088x2048xf32> to memref<*xf32> 15 | 16 | %AAA = memref.cast %A : memref<2088x2048xf32> to memref 17 | %BBB = memref.cast %B : memref<2048x2048xf32> to memref 18 | %CCC1 = memref.cast %C : memref<2088x2048xf32> to memref 19 | %CCC2 = memref.cast %C1 : memref<2088x2048xf32> to memref 20 | 21 | func.call @fill2DRandomMatrixF32(%AA) : (memref<*xf32>) -> () 22 | func.call @fill2DRandomMatrixF32(%BB) : (memref<*xf32>) -> () 23 | 24 | linalg.fill ins(%cf1 : f32) outs(%C : memref<2088x2048xf32>) 25 | linalg.fill ins(%cf1 : f32) outs(%C1 : memref<2088x2048xf32>) 26 | func.call @matmul(%A, %B, %C) : (memref<2088x2048xf32>, memref<2048x2048xf32>, memref<2088x2048xf32>) -> () 27 | func.call @mmatmul(%AAA, %BBB, %CCC1) : (memref, memref, memref) -> () 28 | func.call @validateF32WithRefMatmul(%AA, %BB, %CC2, %CC1) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>) -> () 29 | 30 | %reps = arith.constant 5 : index 31 | 32 | //warm up 33 | func.call @matmul(%A, %B, %C) : (memref<2088x2048xf32>, memref<2048x2048xf32>, memref<2088x2048xf32>) -> () 34 | func.call @mmatmul(%AAA, %BBB, %CCC1) : (memref, memref, memref) -> () 35 | %t_start = func.call @rtclock() : () -> (f64) 36 | affine.for %ti = 0 to %reps { 37 | func.call @mmatmul(%AAA, %BBB, %CCC1) : (memref, memref, memref) -> () 38 | func.call @matmul(%A, %B, %C) : (memref<2088x2048xf32>, memref<2048x2048xf32>, memref<2088x2048xf32>) -> () 39 | } 40 | %t_end = func.call @rtclock() : () -> (f64) 41 | %pC = memref.cast %C : memref<2088x2048xf32> to memref<*xf32> 42 | 43 | %c0 = arith.constant 0 : index 44 | %c1 = arith.constant 1 : index 45 | %M = memref.dim %C, %c0 : memref<2088x2048xf32> 46 | %N = memref.dim %C, %c1 : memref<2088x2048xf32> 47 | %K = memref.dim %A, %c1 : memref<2088x2048xf32> 48 | 49 | %t = arith.subf %t_end, %t_start : f64 50 | %f1 = arith.muli %M, %N : index 51 | %f2 = arith.muli %f1, %K : index 52 | // 2*M*N*K. 
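// Worked example of the count below (illustrative arithmetic): with M = 2088,
// N = 2048, K = 2048, one matmul costs 2*M*N*K = 17,515,413,504, about 1.75e10
// flops, so num_flops = reps * 2*M*N*K is about 8.76e10 for reps = 5, and the
// reported value is flops = num_flops / (t_end - t_start).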
53 | %c2 = arith.constant 2 : index 54 | %f3 = arith.muli %c2, %f2 : index 55 | %num_flops = arith.muli %reps, %f3 : index 56 | %num_flops_i = arith.index_cast %num_flops : index to i64 57 | %num_flops_f = arith.sitofp %num_flops_i : i64 to f64 58 | %flops = arith.divf %num_flops_f, %t : f64 59 | func.call @printFlops(%flops) : (f64) -> () 60 | 61 | return 62 | } 63 | 64 | #K_UB = affine_map<(d0) -> (480, d0 * -480 + 2048)> 65 | #I_LB = affine_map<(d0) -> (d0 * 110)> 66 | #I_UB = affine_map<(d0) -> (696, d0 * 110 + 110)> 67 | 68 | func.func @matmul(%arg0: memref<2088x2048xf32>, %arg1: memref<2048x2048xf32>, %arg2: memref<2088x2048xf32>) { 69 | //linalg.matmul ins(%arg0,%arg1:memref<2088x2048xf32>,memref<2048x2048xf32>) outs(%arg2:memref<2088x2048xf32>) 70 | return 71 | } 72 | 73 | func.func @mmatmul(%arg0: memref, %arg1: memref, %arg2: memref) { 74 | linalg.matmul ins(%arg0,%arg1:memref,memref) outs(%arg2:memref) 75 | return 76 | } 77 | 78 | func.func private @printFlops(f64) 79 | func.func private @rtclock() -> f64 80 | func.func private @print2DMatrixF32(memref<*xf32>) 81 | func.func private @fill2DRandomMatrixF32(memref<*xf32>) 82 | func.func private @fill2DIncMatrixF32(memref<*xf32>) 83 | func.func private @printMemrefF32(memref<*xf32>) 84 | func.func private @validateF32WithRefMatmul(memref<*xf32>,memref<*xf32>,memref<*xf32>,memref<*xf32>) 85 | -------------------------------------------------------------------------------- /include/ExecutionEngine/HandsOnRunnerUtils.h: -------------------------------------------------------------------------------- 1 | #ifndef HANDS_ON_MLIR_EXECUTIONENGINE_RUNNERUTILS_H 2 | #define HANDS_ON_MLIR_EXECUTIONENGINE_RUNNERUTILS_H 3 | 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #ifdef _WIN32 // Copied from official mlir project 10 | #ifndef HANDS_ON_MLIR_RUNNERUTILS_EXPORT 11 | #ifdef mlir_runner_utils_EXPORTS 12 | // We are building this library 13 | #define HANDS_ON_MLIR_RUNNERUTILS_EXPORT __declspec(dllexport) 14 | #else 15 | // We are using this library 16 | #define HANDS_ON_MLIR_RUNNERUTILS_EXPORT __declspec(dllimport) 17 | #endif // mlir_runner_utils_EXPORTS 18 | #endif // MLIR_RUNNERUTILS_EXPORT 19 | #else 20 | // Non-windows: use visibility attributes. 21 | #define HANDS_ON_MLIR_RUNNERUTILS_EXPORT __attribute__((visibility("default"))) 22 | #endif // _WIN32 23 | 24 | // Need this type for extern C functions since these functions cannot return 25 | // template type. However, a normal type with template member function is 26 | // allowed. 
27 | struct C_UnrankedMemRefType : public UnrankedMemRefType<float> { 28 | C_UnrankedMemRefType() = default; 29 | 30 | template <typename T> 31 | C_UnrankedMemRefType(const UnrankedMemRefType<T> &memref) { 32 | this->rank = memref.rank; 33 | this->descriptor = memref.descriptor; 34 | } 35 | }; 36 | 37 | template <typename T> 38 | auto convertToDynamicMemRefType(int64_t rank, void *dst) { 39 | UnrankedMemRefType<T> unrankType = {rank, dst}; 40 | DynamicMemRefType<T> dyType(unrankType); 41 | return dyType; 42 | } 43 | 44 | using allocFnType = std::function<void(void **, size_t)>; 45 | 46 | template <typename ElementType, int rank> 47 | static UnrankedMemRefType<ElementType> 48 | allocHelper(const std::vector<int64_t> &sizes, allocFnType customAllocer) { 49 | auto returnMemRef = UnrankedMemRefType<ElementType>(); 50 | returnMemRef.rank = rank; 51 | returnMemRef.descriptor = 52 | malloc(sizeof(StridedMemRefType<ElementType, rank>)); 53 | auto des = static_cast<StridedMemRefType<ElementType, rank> *>( 54 | returnMemRef.descriptor); 55 | 56 | assert(rank == sizes.size()); 57 | 58 | auto totalSize = 59 | std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); 60 | 61 | customAllocer(reinterpret_cast<void **>(&(des->data)), 62 | sizeof(ElementType) * totalSize); 63 | 64 | des->basePtr = des->data; 65 | des->offset = 0; 66 | int64_t strides = 1; 67 | for (int i = 0; i < rank; i++) { 68 | des->sizes[i] = sizes[i]; 69 | des->strides[rank - i - 1] = strides; 70 | strides *= sizes[rank - i - 1]; 71 | } 72 | return returnMemRef; 73 | } 74 | 75 | extern "C" { 76 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT void print2DMatrixF32(int64_t rank, void *dst); 77 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT void fill2DRandomMatrixF32(int64_t rank, 78 | void *dst); 79 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT void fill2DIncMatrixF32(int64_t rank, 80 | void *dst); 81 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT void validateF32WithRefMatmul(int64_t, void *, 82 | int64_t, void *, 83 | int64_t, void *, 84 | int64_t, void *); 85 | 86 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT void deallocF32(int64_t rank, void *dst); 87 | 88 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT C_UnrankedMemRefType 89 | allocF32(int32_t elementNum); 90 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT 91 | C_UnrankedMemRefType alloc3DMemRefF32(int32_t, int32_t, int32_t); 92 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT C_UnrankedMemRefType 93 | allocByMemRefF32(int64_t rank, void *dst); 94 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT C_UnrankedMemRefType 95 | allocConstantF32(int32_t idx); 96 | 97 | HANDS_ON_MLIR_RUNNERUTILS_EXPORT void matmulAddF32(int64_t, void *, int64_t, 98 | void *, int64_t, void *, 99 | int64_t, void *); 100 | } 101 | 102 | #endif 103 | -------------------------------------------------------------------------------- /include/Dialect/HOM/HOMFusion.pdll: -------------------------------------------------------------------------------- 1 | #include "HOM/HOMOps.td" 2 | #include "Utils.pdll" 3 | #include "mlir/Dialect/Tosa/IR/TosaOps.td" 4 | 5 | Constraint checkReshapeRemovable(op0 : Op, op1 : Op, op2 : Op); 6 | Constraint checkMHAQKVReshape(op0 : Op, op1 : Op, op2 : Op); 7 | Constraint checkMHAQKVTransposePerm(op0 : Op, op1 : Op, op2 : Op); 8 | Constraint checkTransposeReshapeChangeable(op0 : Op); 9 | Rewrite changeTransposeReshape(op0 : Op, op1 : Op, op2 : Op); 10 | Rewrite removeRedundantReshape(op0 : Op, op1 : Op); 11 | Rewrite buildMHAOp(op0 12 | : Op, value 13 | : Attr, op2 14 | : Value, op3 15 | : Op, op4 16 | : Op, op5 17 | : Op) 18 | ->Op; 19 | 20 | // Try to remove meaningless reshape here.
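// Intuition for the pattern below: reshape(A) matmul'ed with reshape(B),
// followed by a reshape of the product, is just matmul(A, B) whenever the
// reshapes round-trip the shapes; the native checkReshapeRemovable constraint
// performs that shape verification on the C++ side, so the rewrite can drop
// all three reshapes.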
21 | Pattern { 22 | let A : Value<_ : HOM_RankedTensor>; 23 | let B : Value<_ : HOM_RankedTensor>; 24 | let reshapeA = op(A); 25 | let reshapeB = op(B); 26 | let matmul = op(reshapeA, reshapeB); 27 | let reshapeC = op(matmul); 28 | checkReshapeRemovable(reshapeA, reshapeB, reshapeC); 29 | 30 | replace reshapeC with op(A, B); 31 | } 32 | 33 | Pattern { 34 | let newShape : Attr; 35 | let reshapeA = op(A : Value) { 36 | new_shape = newShape 37 | } -> (output : Type); 38 | let reshapeB = op(A){new_shape = newShape}; 39 | 40 | rewrite reshapeB with { removeRedundantReshape(reshapeA, reshapeB); }; 41 | } 42 | 43 | Pattern { 44 | let matmul = op(input0 : Value, input1 : Value); 45 | let root = op(matmul, input2 : Value); 46 | 47 | replace root with op(input0, input1, input2); 48 | } 49 | 50 | // Try to move transpose closer to matmul. 51 | Pattern { 52 | let perm = op; 53 | let transpose = op(input1 : Value, perm); 54 | let reshape = op(transpose); 55 | checkTransposeReshapeChangeable(reshape); 56 | 57 | rewrite reshape with { changeTransposeReshape(transpose, perm, reshape); }; 58 | } 59 | 60 | // MHA 61 | Pattern { 62 | let hiddenState : Value<_ : HOM_RankedTensor>; 63 | let qWeights = op; 64 | let kWeights = op; 65 | let vWeights = op; 66 | let q = op(hiddenState, qWeights); 67 | let k = op(hiddenState, kWeights); 68 | let v = op(hiddenState, vWeights); 69 | 70 | let newShape : DenseI64ArrayAttr; 71 | let reshapeQ = op(q){new_shape = newShape}; 72 | let reshapeK = op(k){new_shape = newShape}; 73 | let reshapeV = op(v){new_shape = newShape}; 74 | checkMHAQKVReshape(reshapeQ, reshapeK, reshapeV); 75 | 76 | let permQ = op; 77 | let permK = op; 78 | let permV = op; 79 | 80 | let transposeQ = op(reshapeQ, permQ); 81 | let transposeK = op(reshapeK, permK); 82 | let transposeV = op(reshapeV, permV); 83 | checkMHAQKVTransposePerm(permQ, permK, permV); 84 | 85 | let qkBMM = op(transposeQ, transposeK); 86 | let scale = op; 87 | isSingleFloatConstant(scale); 88 | let shift = attr<"0 : i8">; 89 | 90 | let calMHA = op( 91 | op(op(op(qkBMM, scale){shift = shift}, mask 92 | : Value)), 93 | transposeV); 94 | 95 | let mha = op(op(calMHA, permMHA 96 | : Value < _ 97 | : HOM_RankedTensor >)); 98 | 99 | rewrite mha with { 100 | let scaleAttr = getSingleFloatValue(scale); 101 | replace mha with buildMHAOp(reshapeQ, scaleAttr, mask, q, k, v); 102 | }; 103 | } 104 | 105 | Pattern { 106 | let cast = op(input : Value); 107 | let root = op(cast, values : Value); 108 | 109 | replace root with op(input, values); 110 | } 111 | 112 | Pattern { 113 | let reshape = op(input : Value); 114 | let root = op(reshape, values : Value); 115 | let reshape0 = op(root); 116 | 117 | replace reshape0 with op(input, values); 118 | } 119 | -------------------------------------------------------------------------------- /lib/Dialect/HOMNVGPU/HOMNVGPUFusionPass.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "Conversions/Tosa/Passes.h" 11 | #include "HOM/HOMOps.h" 12 | #include "HOMNVGPU/HOMNVGPUOps.h" 13 | #include "WeightsEngine/WeightsEngine.h" 14 | #include "mlir/Dialect/Func/IR/FuncOps.h" 15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16 | #include "mlir/IR/Builders.h" 17 | #include "mlir/IR/BuiltinAttributeInterfaces.h" 18 | #include "mlir/IR/BuiltinAttributes.h" 19 | #include "mlir/IR/BuiltinTypes.h" 20 | #include "mlir/IR/Location.h" 21 | #include "mlir/IR/Operation.h" 22 | 
#include "mlir/IR/PatternMatch.h" 23 | #include "mlir/IR/Value.h" 24 | #include "mlir/Parser/Parser.h" 25 | #include "mlir/Support/LogicalResult.h" 26 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 | #include "llvm/ADT/APInt.h" 28 | #include "llvm/ADT/ArrayRef.h" 29 | #include "llvm/ADT/SmallPtrSet.h" 30 | #include "llvm/ADT/SmallVector.h" 31 | #include "llvm/Support/Casting.h" 32 | #include "llvm/Support/ErrorHandling.h" 33 | #include "llvm/Support/raw_ostream.h" 34 | 35 | #define PASS_NAME "homnvgpu-fusion" 36 | #define DEBUG_TYPE PASS_NAME 37 | 38 | namespace mlir { 39 | namespace hands_on_mlir { 40 | namespace homnvgpu { 41 | 42 | #define GEN_PASS_DEF_HOMNVGPUFUSIONPASS 43 | #include "Dialect/HOMNVGPU/HOMNVGPUFusion.pdll.h.inc" 44 | #include "HOMNVGPU/Passes.h.inc" 45 | 46 | namespace { 47 | 48 | static void generateGemmLnGemmImpl(PatternRewriter &rewriter, Operation *gemm0_, 49 | Operation *ln_, Operation *gemm1_) { 50 | 51 | auto gemm0 = dyn_cast(gemm0_); 52 | auto ln = dyn_cast(ln_); 53 | auto gemm1 = dyn_cast(gemm1_); 54 | 55 | auto gemmWithVarMean = rewriter.create( 56 | gemm0->getLoc(), gemm0.getOperand0(), gemm0.getOperand1(), 57 | gemm0.getOperand2(), gemm0.getAlpha(), gemm0.getBeta(), gemm0.getAct(), 58 | ln.getEps()); 59 | auto LnGemm = rewriter.create( 60 | gemm1->getLoc(), gemm1.getResult().getType(), gemmWithVarMean.getOutput(), 61 | gemm1.getOperand1(), gemm1.getOperand2(), gemmWithVarMean.getVar(), 62 | gemmWithVarMean.getMean(), gemm1.getAlpha(), gemm1.getBeta(), 63 | gemm1.getAct()); 64 | 65 | gemm1.replaceAllUsesWith(LnGemm.getResult()); 66 | } 67 | 68 | static void updateMaskWithCuSeqLenImpl(PatternRewriter &rewriter, 69 | Operation *mask_, Operation *bert_mha_) { 70 | 71 | auto mask = dyn_cast(mask_); 72 | auto bert_mha = dyn_cast(bert_mha_); 73 | 74 | homnvgpu::CuSeqLenOp newMask; 75 | 76 | mask->getBlock()->walk([&](homnvgpu::CuSeqLenOp op) { 77 | if (op.getInput() == mask.getInput()) { 78 | newMask = op; 79 | } 80 | }); 81 | 82 | if (!newMask) { 83 | auto shape = mask.getType().getShape(); 84 | rewriter.setInsertionPointToStart(mask->getBlock()); 85 | newMask = rewriter.create( 86 | mask->getLoc(), 87 | RankedTensorType::get({shape[0] + 1}, rewriter.getI32Type()), 88 | mask.getInput()); 89 | } 90 | 91 | bert_mha->setOperand(1, newMask.getOutput()); 92 | } 93 | 94 | struct HOMNVGPUFusionPass : impl::HOMNVGPUFusionPassBase { 95 | void runOnOperation() final; 96 | 97 | LogicalResult initialize(MLIRContext *ctx) override; 98 | 99 | private: 100 | FrozenRewritePatternSet patterns; 101 | }; 102 | 103 | LogicalResult HOMNVGPUFusionPass::initialize(MLIRContext *ctx) { 104 | RewritePatternSet patternList(ctx); 105 | 106 | populateGeneratedPDLLPatterns(patternList); 107 | patternList.getPDLPatterns().registerRewriteFunction("generateGemmLnGemm", 108 | generateGemmLnGemmImpl); 109 | patternList.getPDLPatterns().registerRewriteFunction( 110 | "updateMaskWithCuSeqLen", updateMaskWithCuSeqLenImpl); 111 | patterns = std::move(patternList); 112 | return success(); 113 | } 114 | 115 | void HOMNVGPUFusionPass::runOnOperation() { 116 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 117 | } 118 | 119 | } // namespace 120 | } // namespace homnvgpu 121 | } // namespace hands_on_mlir 122 | } // namespace mlir 123 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # ===- CMakeLists.txt - HANDS_ON_MLIR-mlir cmake root 
-----------------*- cmake 2 | # -*-===// 3 | # 4 | # Configure the HANDS_ON_MLIR-mlir build. 5 | # 6 | # ===----------------------------------------------------------------------===// 7 | 8 | cmake_minimum_required(VERSION 3.10) 9 | 10 | if(POLICY CMP0077) 11 | cmake_policy(SET CMP0077 NEW) 12 | endif() 13 | 14 | if(POLICY CMP0116) 15 | cmake_policy(SET CMP0116 OLD) 16 | endif() 17 | 18 | if(POLICY CMP0074) 19 | cmake_policy(SET CMP0074 OLD) 20 | endif() 21 | 22 | # ------------------------------------------------------------------------------- 23 | # Project setup and globals 24 | # ------------------------------------------------------------------------------- 25 | 26 | project(Hands-on-MLIR LANGUAGES CXX C) 27 | 28 | set(CMAKE_CXX_STANDARD 20) 29 | set(CMAKE_CXX_STANDARD_REQUIRED YES) 30 | 31 | # ------------------------------------------------------------------------------- 32 | # Options and settings 33 | # ------------------------------------------------------------------------------- 34 | 35 | option(LLVM_INCLUDE_TOOLS "Generate build targets for the LLVM tools." ON) 36 | option(LLVM_BUILD_TOOLS 37 | "Build the LLVM tools. If OFF, just generate build targets." ON) 38 | 39 | # ------------------------------------------------------------------------------- 40 | # MLIR/LLVM Configuration 41 | # ------------------------------------------------------------------------------- 42 | find_package(MLIR REQUIRED CONFIG) 43 | 44 | set(LLVM_MLIR_BINARY_DIR ${MLIR_DIR}/../../../bin) 45 | set(LLVM_MLIR_SOURCE_DIR ${MLIR_DIR}/../../../../mlir) 46 | 47 | list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") 48 | list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") 49 | 50 | include(TableGen) 51 | include(AddLLVM) 52 | include(AddMLIR) 53 | include(HandleLLVMOptions) 54 | 55 | # ------------------------------------------------------------------------------- 56 | # HANDS_ON_MLIR configuration 57 | # ------------------------------------------------------------------------------- 58 | 59 | # HANDS_ON_MLIR project. 
60 | set(HANDS_ON_MLIR_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) 61 | set(HANDS_ON_MLIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/bin) 62 | set(HANDS_ON_MLIR_INCLUDE_DIR ${HANDS_ON_MLIR_SOURCE_DIR}/include/) 63 | 64 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${HANDS_ON_MLIR_BINARY_DIR}) 65 | 66 | set(HANDS_ON_MLIR_EXAMPLES 67 | OFF 68 | CACHE BOOL "Build examples") 69 | 70 | # Add MLIR and LLVM headers to the include path 71 | include_directories(${LLVM_INCLUDE_DIRS}) 72 | include_directories(${MLIR_INCLUDE_DIRS}) 73 | 74 | # Add HANDS_ON_MLIR files to the include path 75 | include_directories(${HANDS_ON_MLIR_MAIN_INCLUDE_DIR}) 76 | include_directories(${HANDS_ON_MLIR_INCLUDE_DIR}) 77 | include_directories(${HANDS_ON_MLIR_INCLUDE_DIR}/Dialect) 78 | include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/Dialect) 79 | include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) 80 | include_directories(${HANDS_ON_MLIR_SOURCE_DIR}/lib) 81 | include_directories(${HANDS_ON_MLIR_SOURCE_DIR}/thirdparty/half-2.2.0/include) 82 | 83 | # ------------------------------------------------------------------------------- 84 | # CUDA configuration 85 | # ------------------------------------------------------------------------------- 86 | 87 | if(ENABLE_CUDA) 88 | add_compile_definitions(ENABLE_CUDA) 89 | enable_language(CUDA) 90 | set(CMAKE_CUDA_STANDARD 17) 91 | set(CMAKE_CUDA_STANDARD_REQUIRED YES) 92 | find_package(CUDAToolkit REQUIRED) 93 | include_directories(${HANDS_ON_MLIR_SOURCE_DIR}/thirdparty/cutlass/include) 94 | include_directories( 95 | ${HANDS_ON_MLIR_SOURCE_DIR}/thirdparty/TransformerEngine/transformer_engine/common/include 96 | ) 97 | include_directories( 98 | ${HANDS_ON_MLIR_SOURCE_DIR}/thirdparty/cutlass/tools/library/include) 99 | include_directories( 100 | ${HANDS_ON_MLIR_SOURCE_DIR}/thirdparty/cutlass/tools/util/include) 101 | include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 102 | link_directories(${HANDS_ON_MLIR_SOURCE_DIR}/thirdparty/TransformerEngine/) 103 | add_link_options(-ltransformer_engine) 104 | endif() 105 | 106 | # ------------------------------------------------------------------------------- 107 | # Hardware detection 108 | # ------------------------------------------------------------------------------- 109 | 110 | include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/check_simd.cmake) 111 | check_simd() 112 | 113 | # ------------------------------------------------------------------------------- 114 | # Directory setup 115 | # ------------------------------------------------------------------------------- 116 | 117 | add_subdirectory(cmake) 118 | add_subdirectory(include) 119 | add_subdirectory(lib) 120 | add_subdirectory(tools) 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hands-on-MLIR 2 | 3 | WIP. This project is under heavy development, so no documentation is available yet. End-to-end BERT is runnable in fp16 on an RTX 3090 with an rtol of about 1e-3; see `examples/torch/bert`. The code is quite messy right now and has not been cleaned up yet. 4 | 5 | # Features 6 | 7 | + Lower torch models from TOSA to the LLVM dialect. 8 | + End-to-end HuggingFace BERT model support (limited) 9 | + E2E support is currently limited. The limitations are as follows (TE stands for Transformer Engine, HOM for the Hands-on-MLIR project): 10 | 1. `seqlen % 64 == 0` (TE limitation) 11 | 2. `head_dim % 64 == 0` (TE limitation) 12 | 3.
`head_num > 1` (HOM limitation: the reshape pattern has some issues with `head_num == 1`) 13 | 4. fp16 only (HOM limitation: no fp32 pass has been written yet) 14 | 5. Padding mode only (HOM limitation) 15 | 6. NVIDIA GPUs only (HOM limitation) 16 | 7. sm80+ only (TE limitation) 17 | 8. Native Linux only (TE limitation) 18 | 9. Static shapes only (HOM limitation) 19 | + Some simple fusion passes 20 | + GEMM + GELU fusion 21 | + Packed-QKV BERT attention 22 | + etc. 23 | + Autotuning for cutlass 24 | + Only supports the GEMM and GEMM + GELU ops with Row,Row,Row layout 25 | + The tilings come from the official cutlass repo, with some customization for GELU 26 | + Provides roughly a 20% performance boost 27 | + fp16 with fp32 accumulation only 28 | + sm < 90 only, since no sm90 cutlass kernels are generated 29 | + Serial split-K only 30 | 31 | # To-do 32 | 33 | + Clean up the code 34 | + Improve precision 35 | + More fusion patterns 36 | 37 | # Prerequisites 38 | 39 | + For nvcc, the host compiler must be gcc >= 11. Clang is untested. 40 | + The C++ code must be compiled with clang (as recent as possible, for gnu++20 support): 41 | 1. For `_Float16` support (this dependency may be removed in the future). 42 | 2. It uses some other unusual constructs and simply cannot be compiled by gcc. 43 | 3. `-std=gnu++20` is needed for template-lambda support. 44 | + cuDNN > 8.9.7. If cuDNN is not installed through a package manager, you also need to set the `CUDNN_PATH` environment variable so cudnn-frontend can find the correct cuDNN location. 45 | + Only tested on sm86 and sm89; sm versions below 80 are not supported. 46 | + Only tested on Linux. WSL is not supported. 47 | 48 | # Install 49 | 50 | ## Clone submodules 51 | 52 | ``` 53 | # Written down from memory; may not be exact. 54 | git submodule init 55 | git submodule update --recursive 56 | ``` 57 | 58 | ## Install thirdparty 59 | 60 | ... 61 | 62 | ## Install python env 63 | 64 | Install the essential Python packages. Note that this project requires a Python venv, since Transformer Engine needs torch built with CUDA while `torch-mlir` needs a preview, CPU-only build of torch; a sketch of the resulting setup follows the install instructions below. 65 | 66 | ``` 67 | # Install script not finished. 68 | pip install -r requirements.txt 69 | pre-commit install 70 | ``` 71 | 72 | ## Install MLIR 73 | 74 | Install it in whatever way you prefer; this project should be compatible with the main branch of MLIR. There is also a copy under `thirdparty/llvm-project`, which is the revision currently being developed against, and you can use that. **Using lld is strongly recommended for faster linking.** 75 | 76 | ## Install this project 77 | 78 | Use the following commands to compile. **Using lld is strongly recommended for faster linking.** 79 | 80 | ``` 81 | $ cd Hands-on-MLIR 82 | $ mkdir build && cd build 83 | $ cmake -G Ninja .. \ 84 | -DMLIR_DIR=/your/path/to/llvm-project/build/lib/cmake/mlir \ 85 | -DLLVM_DIR=/your/path/to/llvm-project/build/lib/cmake/llvm \ 86 | -DLLVM_ENABLE_ASSERTIONS=ON \ 87 | -DLLVM_USE_LINKER=lld \ 88 | -DENABLE_CUDA=ON \ 89 | -DCMAKE_CUDA_ARCHITECTURES=<your sm version> 90 | ``` 91 | 92 | or use this setup in VSCode: 93 | 94 | ``` 95 | "cmake.configureArgs": [ 96 | "-DMLIR_DIR=/your/path/to/llvm-project/build/lib/cmake/mlir", 97 | "-DLLVM_DIR=/your/path/to/llvm-project/build/lib/cmake/llvm", 98 | "-DLLVM_ENABLE_ASSERTIONS=ON", 99 | "-DLLVM_USE_LINKER=lld", 100 | "-DENABLE_CUDA=ON", 101 | "-DCMAKE_CUDA_ARCHITECTURES=<your sm version>", 102 | // "-DLLVM_USE_SANITIZER=Address;Undefined" -- add this option if you want to enable the sanitizers; you may need to add it to the LLVM build as well. 103 | ], 104 | ```
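A minimal sketch of the two-environment setup the Python-env note above calls for (the environment names and the CUDA-torch install line are illustrative, not part of the project's scripts):

```
# Compile-side venv: torch-mlir wants a preview, CPU-only torch.
python -m venv .venv-compile
. .venv-compile/bin/activate
pip install -r requirements.txt   # pulls torch-mlir from its package index
pre-commit install

# Run-side venv: Transformer Engine wants a CUDA-enabled torch.
python -m venv .venv-run
. .venv-run/bin/activate
pip install torch
```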
105 | 106 | # Reference 107 | 108 | + [MLIR](https://github.com/llvm/llvm-project/): borrowed a great deal from it. 109 | + [buddy-mlir](https://github.com/buddy-compiler/buddy-mlir): likewise borrowed a lot from it. 110 | + [polymage-labs/mlirx](https://github.com/polymage-labs/mlirx): the version is too old, so much of it could not be reused. 111 | + [Three articles on the Polyhedral Model](https://mp.weixin.qq.com/s?__biz=MzI3MDQ2MjA3OA==&mid=2247485130&idx=1&sn=a5773bf17e6854d1238b035366641bcc&chksm=ead1fbdbdda672cdf9b2480a431cef85e4d377d07f8c586a932adabd50656cbdcd7d891156bf&mpshare=1&scene=1&srcid=&sharer_sharetime=1569677798809&sharer_shareid=b33ef36fa0caf5cb82e76916516aa7df#rd): for the basic concepts of polyhedral optimization. 112 | -------------------------------------------------------------------------------- /lib/Dialect/HOMNVGPU/HOMNVGPULegalizeGemmPass.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "Conversions/Tosa/Passes.h" 12 | #include "HOM/HOMOps.h" 13 | #include "HOMNVGPU/HOMNVGPUOps.h" 14 | #include "WeightsEngine/WeightsEngine.h" 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" 16 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" 17 | #include "mlir/IR/Builders.h" 18 | #include "mlir/IR/BuiltinAttributeInterfaces.h" 19 | #include "mlir/IR/BuiltinAttributes.h" 20 | #include "mlir/IR/BuiltinTypes.h" 21 | #include "mlir/IR/Location.h" 22 | #include "mlir/IR/Operation.h" 23 | #include "mlir/IR/PatternMatch.h" 24 | #include "mlir/IR/Value.h" 25 | #include "mlir/Parser/Parser.h" 26 | #include "mlir/Support/LogicalResult.h" 27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 | #include "llvm/ADT/APInt.h" 29 | #include "llvm/ADT/ArrayRef.h" 30 | #include "llvm/ADT/SmallPtrSet.h" 31 | #include "llvm/ADT/SmallVector.h" 32 | #include "llvm/Support/Casting.h" 33 | #include "llvm/Support/ErrorHandling.h" 34 | #include "llvm/Support/raw_ostream.h" 35 | 36 | #define PASS_NAME "homnvgpu-legalize-gemm" 37 | #define DEBUG_TYPE PASS_NAME 38 | 39 | namespace mlir { 40 | namespace hands_on_mlir { 41 | namespace homnvgpu { 42 | 43 | #define GEN_PASS_DEF_HOMNVGPULEGALIZEGEMMPASS 44 | #include "HOMNVGPU/HOMNVGPULegalizeGemm.pdll.h.inc" 45 | #include "HOMNVGPU/Passes.h.inc" 46 | 47 | namespace { 48 | 49 | static void generateTransposeImpl(PatternRewriter &rewriter, 50 | Operation *matmul_) { 51 | auto matmul = dyn_cast<homnvgpu::MatmulOp>(matmul_); 52 | 53 | if (auto defining = matmul->getOperand(1).getDefiningOp()) { 54 | if (auto constOp = dyn_cast<tosa::ConstOp>(defining)) { 55 | auto oldType = constOp.getResult().getType(); 56 | auto oldShape = oldType.getShape(); 57 | 58 | assert(oldShape.size() == 3); 59 | 60 | SmallVector<int64_t> newShape = {oldShape[0], oldShape[2], oldShape[1]}; 61 | 62 | tosa::TransposeOp transposeOp; 63 | auto needCreateTranspose = [&transposeOp, &constOp]() { 64 | for (const auto &user : constOp->getUsers()) { 65 | if (auto op = dyn_cast<tosa::TransposeOp>(user)) { 66 | llvm::SmallVector<int32_t> perms; 67 | if (op.getConstantPerms(perms).succeeded() && perms[0] == 0 && 68 | perms[1] == 2 && perms[2] == 1) { 69 | transposeOp = op; 70 | return false; 71 | } 72 | } 73 | } 74 | return true; 75 | }; 76 | 77 | if (needCreateTranspose()) { 78 | auto permAttr = DenseIntElementsAttr::get( 79 | RankedTensorType::get({3}, rewriter.getI32Type()), 80 | ArrayRef<int32_t>{0, 2, 1}); 81 | 82 | rewriter.setInsertionPointToStart(matmul->getBlock()); 83 | auto perm = rewriter.create<tosa::ConstOp>( 84 | constOp->getLoc(), permAttr.getType(), permAttr); 85 | transposeOp = rewriter.create<tosa::TransposeOp>( 86 | constOp->getLoc(), 87 | RankedTensorType::get(newShape, oldType.getElementType()), 88 | constOp.getResult(), perm.getResult()); 89 | }
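// Why this rewrite (a reading of the pass, not a documented contract): by
// handing the matmul a pre-transposed weight constant and setting transb, the
// transpose is folded into the GEMM's layout handling instead of executing as
// a separate runtime op; an existing {0, 2, 1} transpose of the same constant
// is reused above to avoid duplicating the weights.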
90 | matmul->setOperand(1, transposeOp.getResult()); 91 | matmul.setTransb(true); 92 | return; 93 | } else if (auto transposeOp = dyn_cast<tosa::TransposeOp>( 94 | matmul->getOperand(1).getDefiningOp())) { 95 | auto perm = 96 | dyn_cast<tosa::ConstOp>(transposeOp.getPerms().getDefiningOp()) 97 | .getValue() 98 | .getValues<int32_t>(); 99 | if (perm.size() == 3 && perm[0] == 0 && perm[1] == 2 && perm[2] == 1) { 100 | matmul->setOperand(1, transposeOp.getInput1()); 101 | matmul.setTransb(true); 102 | return; 103 | } 104 | } 105 | } 106 | llvm_unreachable("This format is not supported."); 107 | } 108 | 109 | struct HOMNVGPULegalizeGemmPass 110 | : impl::HOMNVGPULegalizeGemmPassBase<HOMNVGPULegalizeGemmPass> { 111 | void runOnOperation() final; 112 | 113 | LogicalResult initialize(MLIRContext *ctx) override; 114 | 115 | private: 116 | FrozenRewritePatternSet patterns; 117 | }; 118 | 119 | LogicalResult HOMNVGPULegalizeGemmPass::initialize(MLIRContext *ctx) { 120 | RewritePatternSet patternList(ctx); 121 | 122 | populateGeneratedPDLLPatterns(patternList); 123 | patternList.getPDLPatterns().registerRewriteFunction("generateTranspose", 124 | generateTransposeImpl); 125 | patterns = std::move(patternList); 126 | return success(); 127 | } 128 | 129 | void HOMNVGPULegalizeGemmPass::runOnOperation() { 130 | (void)applyPatternsAndFoldGreedily(getOperation(), patterns); 131 | } 132 | 133 | } // namespace 134 | } // namespace homnvgpu 135 | } // namespace hands_on_mlir 136 | } // namespace mlir 137 | -------------------------------------------------------------------------------- /examples/torch/layernorm/cuda/ln_gemm.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 2 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 3 | #include "NVGPUKernels/Utils.h" 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | void fillRand(_Float16 *a, int64_t m, int64_t n, int64_t k) { 11 | for (int i = 0; i < m; i++) { 12 | for (int j = 0; j < n; j++) { 13 | for (int ii = 0; ii < k; ii++) { 14 | a[i * (n * k) + j * k + ii] = float(rand()) / ((float)RAND_MAX / 1); 15 | } 16 | } 17 | } 18 | } 19 | 20 | void plusOne(_Float16 *a, _Float16 *b, int64_t m, int64_t n, int64_t k) { 21 | for (int i = 0; i < m; i++) { 22 | for (int j = 0; j < n; j++) { 23 | for (int ii = 0; ii < k; ii++) { 24 | a[i * (n * k) + j * k + ii] = 25 | b[i * (n * k) + j * k + ii] * b[i * (n * k) + j * k + ii] + 1; 26 | } 27 | } 28 | } 29 | } 30 | 31 | void print3D(_Float16 *a, int64_t m, int64_t n, int64_t k) { 32 | 33 | std::cout << "=====================\n"; 34 | for (int i = 0; i < m; i++) { 35 | for (int j = 0; j < n; j++) { 36 | for (int ii = 0; ii < k; ii++) { 37 | std::cout << float(a[i * (n * k) + j * k + ii]) << " "; 38 | } 39 | std::cout << std::endl; 40 | } 41 | std::cout << std::endl; 42 | } 43 | std::cout << "=====================\n"; 44 | } 45 | 46 | void print1D(_Float16 *a, int64_t m) { 47 | std::cout << "=====================\n"; 48 | for (int i = 0; i < m; i++) { 49 | std::cout << float(a[i]) << " "; 50 | } 51 | std::cout << "\n=====================\n"; 52 | } 53 | 54 | int main() { 55 | 56 | int bs = 2; 57 | 58 | int m = 2, n = 8, k = 32; 59 | int scale = 2; 60 | 61 | auto a = allocHelper<_Float16, 3>({bs, m, k}, nvgpuAllocer); 62 | auto b = allocHelper<_Float16, 3>({1, k, n}, nvgpuAllocer); 63 | auto c = allocHelper<_Float16, 3>({bs, m, n}, nvgpuAllocer); 64 | auto d = allocHelper<_Float16, 3>({1, n, n * scale}, nvgpuAllocer); 65 | auto e = allocHelper<_Float16, 3>({bs, m, n * scale}, nvgpuAllocer); 66 |
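// Shape of this example (a reading of the two kernel calls below):
// cutlassGemmWithVarMeanF16 computes c = a x b and, per output row, the mean
// and variance that layernorm needs; cutlassLayernormGemmF16 then normalizes
// c with those statistics on the fly while multiplying by d into e, so the
// layernorm never runs as a separate global-memory pass. mean and var hold
// one element per (batch, row) pair, hence the bs * m allocations below.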
67 | auto mean = allocHelper<_Float16, 1>({bs * m}, nvgpuAllocer); 68 | auto var = allocHelper<_Float16, 1>({bs * m}, nvgpuAllocer); 69 | 70 | _Float16 *host_ptr = new _Float16[bs * m * n * k * scale]; 71 | 72 | auto des = static_cast<StridedMemRefType<_Float16, 3> *>(a.descriptor); 73 | fillRand(host_ptr, des->sizes[0], des->sizes[1], des->sizes[2]); 74 | print3D(host_ptr, bs, m, k); 75 | checkCudaErrors(cudaMemcpy(des->data, host_ptr, sizeof(int16_t) * bs * m * k, 76 | cudaMemcpyHostToDevice)); 77 | 78 | des = static_cast<StridedMemRefType<_Float16, 3> *>(b.descriptor); 79 | fillRand(host_ptr, 1, k, n); 80 | print3D(host_ptr, 1, k, n); 81 | checkCudaErrors(cudaMemcpy(des->data, host_ptr, sizeof(int16_t) * 1 * k * n, 82 | cudaMemcpyHostToDevice)); 83 | 84 | des = static_cast<StridedMemRefType<_Float16, 3> *>(d.descriptor); 85 | fillRand(host_ptr, 1, n, n * scale); 86 | print3D(host_ptr, 1, n, n * scale); 87 | checkCudaErrors(cudaMemcpy(des->data, host_ptr, 88 | sizeof(int16_t) * 1 * n * n * scale, 89 | cudaMemcpyHostToDevice)); 90 | 91 | auto desMean = static_cast<StridedMemRefType<_Float16, 1> *>(mean.descriptor); 92 | auto desVar = static_cast<StridedMemRefType<_Float16, 1> *>(var.descriptor); 93 | 94 | cutlassGemmWithVarMeanF16(a.rank, a.descriptor, b.rank, b.descriptor, c.rank, 95 | c.descriptor, c.rank, c.descriptor, var.rank, 96 | var.descriptor, mean.rank, mean.descriptor, 1, 0, 0, 97 | 1e-6); 98 | 99 | cutlassLayernormGemmF16(c.rank, c.descriptor, d.rank, d.descriptor, e.rank, 100 | e.descriptor, e.rank, e.descriptor, var.rank, 101 | var.descriptor, mean.rank, mean.descriptor, 1, 0, 0); 102 | 103 | des = static_cast<StridedMemRefType<_Float16, 3> *>(c.descriptor); 104 | 105 | checkCudaErrors(cudaMemcpy(host_ptr, des->data, sizeof(_Float16) * bs * m * n, 106 | cudaMemcpyDeviceToHost)); 107 | print3D(host_ptr, bs, m, n); 108 | 109 | checkCudaErrors(cudaMemcpy(host_ptr, desMean->data, sizeof(_Float16) * bs * m, 110 | cudaMemcpyDeviceToHost)); 111 | print1D(host_ptr, bs * m); 112 | 113 | checkCudaErrors(cudaMemcpy(host_ptr, desVar->data, sizeof(_Float16) * bs * m, 114 | cudaMemcpyDeviceToHost)); 115 | print1D(host_ptr, bs * m); 116 | 117 | des = static_cast<StridedMemRefType<_Float16, 3> *>(e.descriptor); 118 | 119 | checkCudaErrors(cudaMemcpy(host_ptr, des->data, 120 | sizeof(_Float16) * bs * m * n * scale, 121 | cudaMemcpyDeviceToHost)); 122 | print3D(host_ptr, bs, m, n * scale); 123 | } 124 | -------------------------------------------------------------------------------- /examples/torch/layernorm/cuda/gemm_with_mean_var.cu: -------------------------------------------------------------------------------- 1 | #include "ExecutionEngine/HandsOnNVGPURunnerUtils.h" 2 | #include "ExecutionEngine/HandsOnRunnerUtils.h" 3 | #include "NVGPUKernels/Utils.h" 4 | #include "mlir/ExecutionEngine/CRunnerUtils.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | void fillRand(_Float16 *a, int64_t m, int64_t n, int64_t k) { 11 | for (int i = 0; i < m; i++) { 12 | for (int j = 0; j < n; j++) { 13 | for (int ii = 0; ii < k; ii++) { 14 | a[i * (n * k) + j * k + ii] = float(rand()) / ((float)RAND_MAX / 5); 15 | } 16 | } 17 | } 18 | } 19 | 20 | void plusOne(_Float16 *a, _Float16 *b, int64_t m, int64_t n, int64_t k) { 21 | for (int i = 0; i < m; i++) { 22 | for (int j = 0; j < n; j++) { 23 | for (int ii = 0; ii < k; ii++) { 24 | a[i * (n * k) + j * k + ii] = 25 | b[i * (n * k) + j * k + ii] * b[i * (n * k) + j * k + ii] + 1; 26 | } 27 | } 28 | } 29 | } 30 | 31 | void print3D(_Float16 *a, int64_t m, int64_t n, int64_t k) { 32 | 33 | std::cout << "=====================\n"; 34 | for (int i = 0; i < m; i++) { 35 | for (int j = 0; j < n; j++) {
36 | for (int ii = 0; ii < k; ii++) { 37 | std::cout << float(a[i * (n * k) + j * k + ii]) << " "; 38 | } 39 | std::cout << std::endl; 40 | } 41 | std::cout << std::endl; 42 | } 43 | std::cout << "=====================\n"; 44 | } 45 | 46 | void print1D(_Float16 *a, int64_t m) { 47 | std::cout << "=====================\n"; 48 | for (int i = 0; i < m; i++) { 49 | std::cout << float(a[i]) << " "; 50 | } 51 | std::cout << "\n=====================\n"; 52 | } 53 | 54 | int main() { 55 | 56 | int m = 8, n = 8, k = 8; 57 | 58 | C_UnrankedMemRefType a, b, c, mean, var; 59 | a.rank = b.rank = c.rank = 3; 60 | mean.rank = var.rank = 1; 61 | a.descriptor = malloc(sizeof(StridedMemRefType<_Float16, 3>)); 62 | b.descriptor = malloc(sizeof(StridedMemRefType<_Float16, 3>)); 63 | c.descriptor = malloc(sizeof(StridedMemRefType<_Float16, 3>)); 64 | mean.descriptor = malloc(sizeof(StridedMemRefType<_Float16, 1>)); 65 | var.descriptor = malloc(sizeof(StridedMemRefType<_Float16, 1>)); 66 | _Float16 *host_ptr = new _Float16[2 * m * n * k]; 67 | 68 | auto des = static_cast<StridedMemRefType<_Float16, 3> *>(a.descriptor); 69 | checkCudaErrors(cudaMalloc(&(des->data), sizeof(_Float16) * 2 * m * k)); 70 | fillRand(host_ptr, 2, m, k); 71 | print3D(host_ptr, 2, m, k); 72 | checkCudaErrors(cudaMemcpy(des->data, host_ptr, sizeof(int16_t) * 2 * m * k, 73 | cudaMemcpyHostToDevice)); 74 | des->sizes[0] = 2; 75 | des->sizes[1] = m; 76 | des->sizes[2] = k; 77 | 78 | des = static_cast<StridedMemRefType<_Float16, 3> *>(b.descriptor); 79 | checkCudaErrors(cudaMalloc(&(des->data), sizeof(_Float16) * 1 * k * n)); 80 | fillRand(host_ptr, 1, k, n); 81 | print3D(host_ptr, 1, k, n); 82 | checkCudaErrors(cudaMemcpy(des->data, host_ptr, sizeof(int16_t) * 1 * k * n, 83 | cudaMemcpyHostToDevice)); 84 | des->sizes[0] = 1; 85 | des->sizes[1] = k; 86 | des->sizes[2] = n; 87 | 88 | des = static_cast<StridedMemRefType<_Float16, 3> *>(c.descriptor); 89 | checkCudaErrors(cudaMalloc(&(des->data), sizeof(_Float16) * 2 * m * n)); 90 | fillRand(host_ptr, 2, m, n); 91 | checkCudaErrors(cudaMemcpy(des->data, host_ptr, sizeof(int16_t) * 2 * m * n, 92 | cudaMemcpyHostToDevice)); 93 | des->sizes[0] = 2; 94 | des->sizes[1] = m; 95 | des->sizes[2] = n; 96 | 97 | auto desMean = static_cast<StridedMemRefType<_Float16, 1> *>(mean.descriptor); 98 | checkCudaErrors(cudaMalloc(&(desMean->data), sizeof(_Float16) * 2 * n)); 99 | fillRand(host_ptr, 1, 1, 2 * n); 100 | checkCudaErrors(cudaMemcpy(desMean->data, host_ptr, sizeof(int16_t) * 2 * n, 101 | cudaMemcpyHostToDevice)); 102 | desMean->sizes[0] = 2 * n; 103 | print1D(host_ptr, 2 * n); 104 | 105 | auto desVar = static_cast<StridedMemRefType<_Float16, 1> *>(var.descriptor); 106 | checkCudaErrors(cudaMalloc(&(desVar->data), sizeof(_Float16) * 2 * n)); 107 | plusOne(host_ptr, host_ptr, 1, 1, 2 * n); 108 | checkCudaErrors(cudaMemcpy(desVar->data, host_ptr, sizeof(int16_t) * 2 * n, 109 | cudaMemcpyHostToDevice)); 110 | print1D(host_ptr, 2 * n); 111 | desVar->sizes[0] = 2 * n; 112 | 113 | cutlassGemmWithVarMeanF16(a.rank, a.descriptor, b.rank, b.descriptor, c.rank, 114 | c.descriptor, c.rank, c.descriptor, var.rank, 115 | var.descriptor, mean.rank, mean.descriptor, 1, 0, 0, 116 | 1e-6); 117 | 118 | checkCudaErrors(cudaMemcpy(host_ptr, des->data, sizeof(_Float16) * 2 * m * n, 119 | cudaMemcpyDeviceToHost)); 120 | print3D(host_ptr, 2, m, n); 121 | 122 | checkCudaErrors(cudaMemcpy(host_ptr, desMean->data, sizeof(_Float16) * 2 * n, 123 | cudaMemcpyDeviceToHost)); 124 | print1D(host_ptr, 2 * n); 125 | 126 | checkCudaErrors(cudaMemcpy(host_ptr, desVar->data, sizeof(_Float16) * 2 * n, 127 | cudaMemcpyDeviceToHost)); 128 | print1D(host_ptr, 2 * n); 129 | } 130 |
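// Host-side reference for the statistics printed above -- a sketch under the
// assumption (from the op name) that the kernel produces layernorm-style,
// biased (1/N) per-row statistics of c = a x b. referenceVarMean is a
// hypothetical helper for validation, not part of the runtime API.
static void referenceVarMean(const _Float16 *c, int64_t rows, int64_t cols,
                             float *mean, float *var) {
  for (int64_t r = 0; r < rows; r++) {
    float sum = 0.f, sqSum = 0.f;
    for (int64_t j = 0; j < cols; j++) {
      float v = c[r * cols + j]; // row-major host copy of the GEMM output
      sum += v;
      sqSum += v * v;
    }
    mean[r] = sum / cols;
    var[r] = sqSum / cols - mean[r] * mean[r]; // E[x^2] - E[x]^2
  }
}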
-------------------------------------------------------------------------------- /include/Dialect/HOMNVGPU/HOMNVGPUOps.td: -------------------------------------------------------------------------------- 1 | #ifndef HOM_HOMNVGPUDIALECT_TD 2 | #define HOM_HOMNVGPUDIALECT_TD 3 | 4 | include "mlir/IR/OpBase.td" 5 | include "mlir/IR/AttrTypeBase.td" 6 | include "mlir/IR/EnumAttr.td" 7 | include "mlir/IR/SymbolInterfaces.td" 8 | include "mlir/Interfaces/SideEffectInterfaces.td" 9 | include "mlir/IR/OpAsmInterface.td" 10 | include "HOM/HOMTypesBase.td" // Just reuses the hom types. 11 | include "mlir/Interfaces/InferTypeOpInterface.td" 12 | 13 | //===----------------------------------------------------------------------===// 14 | // HOMNVGPU Dialect Definition. 15 | //===----------------------------------------------------------------------===// 16 | 17 | def HOMNVGPU_Dialect : Dialect { 18 | let name = "homnvgpu"; 19 | let summary = "The Hands on MLIR Dialect for NVIDIA GPUs."; 20 | let description = 21 | [{ The `HOMNVGPU` dialect carries NVIDIA-GPU-specific operations and optimizations. }]; 22 | let cppNamespace = "::mlir::hands_on_mlir::homnvgpu"; 23 | } 24 | 25 | class HOMNVGPU_Op<string mnemonic, list<Trait> traits = []> 26 | : Op<HOMNVGPU_Dialect, mnemonic, traits>; 27 | 28 | def HOMNVGPU_PrintOp : HOMNVGPU_Op<"print", []> { 29 | let arguments = (ins F64Tensor : $input); 30 | 31 | let assemblyFormat = "$input attr-dict `:` type($input)"; 32 | } 33 | 34 | def HOMNVGPU_MatmulOp : HOMNVGPU_Op<"matmul", []> { 35 | let arguments = (ins HOM_RankedTensor 36 | : $operand0, HOM_RankedTensor 37 | : $operand1, HOM_Tensor 38 | : $operand2, F32Attr 39 | : $alpha, F32Attr 40 | : $beta, I32Attr 41 | : $act, DefaultValuedAttr<BoolAttr, "false"> 42 | : $transa, DefaultValuedAttr<BoolAttr, "false"> 43 | : $transb, DefaultValuedAttr<I32Attr, "0"> 44 | : $kernel_name, DefaultValuedAttr<I32Attr, "1"> 45 | : $split_k_factor); // Only serial split-K is supported for now. 46 | 47 | let results = (outs HOM_RankedTensor : $output); 48 | }
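// Illustrative generic form of homnvgpu.matmul (the values, shapes, and
// attribute numbers are made up for exposition; the op defines no custom
// assemblyFormat, so it prints in MLIR's generic form):
//
//   %y = "homnvgpu.matmul"(%a, %b, %bias)
//          {alpha = 1.0 : f32, beta = 1.0 : f32, act = 0 : i32,
//           transa = false, transb = false,
//           kernel_name = 0 : i32, split_k_factor = 1 : i32}
//        : (tensor<1x64x64xf16>, tensor<1x64x64xf16>, tensor<1x64x64xf16>)
//          -> tensor<1x64x64xf16>
//
// The autotune pass overwrites kernel_name and split_k_factor with the
// profiler's best candidate, and the legalize-gemm pass flips transb when it
// rewrites the weight operand into a transposed layout.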
49 | 50 | def HOMNVGPU_MatmulWithVarMeanOp 51 | : HOMNVGPU_Op<"matmul_with_var_mean", 52 | [DeclareOpInterfaceMethods<InferTypeOpInterface>]> { 53 | let arguments = (ins HOM_RankedTensor 54 | : $operand0, HOM_RankedTensor 55 | : $operand1, HOM_Tensor 56 | : $operand2, F32Attr 57 | : $alpha, F32Attr 58 | : $beta, I32Attr 59 | : $act, F32Attr 60 | : $eps, DefaultValuedAttr<BoolAttr, "false"> 61 | : $transa, DefaultValuedAttr<BoolAttr, "false"> 62 | : $transb); 63 | 64 | let results = (outs HOM_RankedTensor 65 | : $output, HOM_RankedTensor 66 | : $var, HOM_RankedTensor 67 | : $mean); 68 | } 69 | 70 | def HOMNVGPU_LayernormMatmulOp : HOMNVGPU_Op<"ln_matmul", []> { 71 | let arguments = (ins HOM_RankedTensor 72 | : $operand0, HOM_RankedTensor 73 | : $operand1, HOM_Tensor 74 | : $operand2, HOM_RankedTensor 75 | : $var, HOM_RankedTensor 76 | : $mean, F32Attr 77 | : $alpha, F32Attr 78 | : $beta, I32Attr 79 | : $act, DefaultValuedAttr<BoolAttr, "false"> 80 | : $transa, DefaultValuedAttr<BoolAttr, "false"> 81 | : $transb); 82 | 83 | let results = (outs HOM_RankedTensor : $output); 84 | } 85 | 86 | def HOMNVGPU_LayernormOp : HOMNVGPU_Op<"layernorm", []> { 87 | let summary = "Layernorm operator"; 88 | 89 | let description = [{Do layernorm.}]; 90 | 91 | let arguments = (ins HOM_Tensor : $input, F32Attr : $eps, I32Attr : $axis); 92 | 93 | let results = (outs HOM_Tensor : $output); 94 | } 95 | 96 | def HOMNVGPU_BertMhaOp : HOMNVGPU_Op<"bert_mha", []> { 97 | let summary = "Bert's MHA operator"; 98 | 99 | let description = [{Compute BERT multi-head attention over a packed QKV tensor.}]; 100 | 101 | let arguments = (ins HOM_RankedTensor 102 | : $qkv, HOM_RankedTensor 103 | : $mask, F32Attr 104 | : $scale, I64Attr 105 | : $head_num); 106 | 107 | let results = (outs HOM_RankedTensor : $output); 108 | } 109 | 110 | def HOMNVGPU_CuSeqLenOp : HOMNVGPU_Op<"cu_seqlen", []> { 111 | let summary = "Cumulative sequence-length operator"; 112 | 113 | let description = [{Compute cumulative sequence lengths.}]; 114 | 115 | let arguments = (ins HOM_RankedTensor : $input); 116 | 117 | let results = (outs HOM_RankedTensor : $output); 118 | } 119 | 120 | def HOMNVGPU_AddOp : HOMNVGPU_Op<"add", [Commutative]> { 121 | let summary = "Elementwise addition operator"; 122 | 123 | let description = 124 | [{Elementwise addition of input1 and 125 | input2. Axes of size 1 will be broadcast as necessary.}]; 126 | 127 | let arguments = (ins HOM_Tensor : $input1, HOM_Tensor : $input2); 128 | 129 | let results = (outs HOM_Tensor : $output); 130 | } 131 | 132 | def HOMNVGPU_GatherOp : HOMNVGPU_Op<"gather", []> { 133 | let arguments = (ins HOM_RankedTensor : $indices, HOM_RankedTensor : $value); 134 | 135 | let results = (outs HOM_RankedTensor : $output); 136 | } 137 | 138 | #endif // HOM_HOMNVGPUDIALECT_TD 139 | --------------------------------------------------------------------------------