├── bazel ├── BUILD └── lit.bzl ├── .bazelversion ├── .bazelignore ├── requirements.txt ├── lib ├── Conversion │ ├── CMakeLists.txt │ └── PolyToStandard │ │ ├── PolyToStandard.td │ │ ├── CMakeLists.txt │ │ ├── PolyToStandard.h │ │ ├── BUILD │ │ └── PolyToStandard.cpp ├── Analysis │ ├── CMakeLists.txt │ └── ReduceNoiseAnalysis │ │ ├── CMakeLists.txt │ │ ├── BUILD │ │ ├── ReduceNoiseAnalysis.h │ │ └── ReduceNoiseAnalysis.cpp ├── Dialect │ ├── CMakeLists.txt │ ├── Noisy │ │ ├── NoisyTypes.h │ │ ├── NoisyDialect.td │ │ ├── NoisyDialect.h │ │ ├── NoisyTypes.td │ │ ├── NoisyOps.h │ │ ├── NoisyDialect.cpp │ │ ├── CMakeLists.txt │ │ ├── NoisyOps.cpp │ │ ├── NoisyOps.td │ │ └── BUILD │ └── Poly │ │ ├── PolyDialect.h │ │ ├── PolyTypes.h │ │ ├── PolyDialect.td │ │ ├── PolyOps.h │ │ ├── PolyTypes.td │ │ ├── PolyTraits.h │ │ ├── PolyPatterns.td │ │ ├── CMakeLists.txt │ │ ├── PolyDialect.cpp │ │ ├── PolyOps.td │ │ ├── BUILD │ │ └── PolyOps.cpp ├── Transform │ ├── CMakeLists.txt │ ├── Arith │ │ ├── MulToAdd.h │ │ ├── Passes.h │ │ ├── MulToAddPdll.h │ │ ├── CMakeLists.txt │ │ ├── Passes.td │ │ ├── MulToAddPdll.cpp │ │ ├── BUILD │ │ ├── MulToAdd.pdll │ │ └── MulToAdd.cpp │ ├── Affine │ │ ├── AffineFullUnroll.h │ │ ├── Passes.h │ │ ├── AffineFullUnrollPatternRewrite.h │ │ ├── CMakeLists.txt │ │ ├── Passes.td │ │ ├── AffineFullUnroll.cpp │ │ ├── AffineFullUnrollPatternRewrite.cpp │ │ └── BUILD │ └── Noisy │ │ ├── Passes.h │ │ ├── ReduceNoiseOptimizer.h │ │ ├── CMakeLists.txt │ │ ├── Passes.td │ │ ├── BUILD │ │ └── ReduceNoiseOptimizer.cpp └── CMakeLists.txt ├── BUILD ├── .gitmodules ├── tests ├── ctlz_simple.mlir ├── poly_to_llvm_main.c ├── poly_verifier.mlir ├── noisy_syntax.mlir ├── lit.cmake.site.cfg.py.in ├── CMakeLists.txt ├── cse.mlir ├── poly_to_llvm_eval.mlir ├── BUILD ├── poly_to_llvm.mlir ├── control_flow_sink.mlir ├── code_motion.mlir ├── mul_to_add_pdll.mlir ├── mul_to_add.mlir ├── ctlz_runner.mlir ├── affine_loop_unroll.mlir ├── sccp.mlir ├── lit.cfg.py ├── 
poly_syntax.mlir ├── lit.cmake.cfg.py ├── ctlz.mlir ├── poly_canonicalize.mlir ├── poly_to_standard.mlir └── noisy_reduce_noise.mlir ├── .bazelrc ├── .gitignore ├── tools ├── CMakeLists.txt ├── BUILD └── tutorial-opt.cpp ├── .github └── workflows │ ├── build_and_test.yml │ └── build_and_test_cmake.yml ├── CMakeLists.txt ├── extensions.bzl ├── MODULE.bazel └── README.md /bazel/BUILD: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 8.3.1 2 | -------------------------------------------------------------------------------- /.bazelignore: -------------------------------------------------------------------------------- 1 | externals 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lit==18.1.8 2 | -------------------------------------------------------------------------------- /lib/Conversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(PolyToStandard) 2 | -------------------------------------------------------------------------------- /lib/Analysis/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(ReduceNoiseAnalysis) 2 | -------------------------------------------------------------------------------- /lib/Dialect/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Noisy) 2 | add_subdirectory(Poly) 3 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | # An MLIR tutorial 2 | 3 | package( 4 | 
default_visibility = ["//visibility:public"], 5 | ) 6 | -------------------------------------------------------------------------------- /lib/Transform/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Affine) 2 | add_subdirectory(Arith) 3 | add_subdirectory(Noisy) 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "externals/llvm-project"] 2 | path = externals/llvm-project 3 | url = https://github.com/llvm/llvm-project.git 4 | -------------------------------------------------------------------------------- /lib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Conversion) 2 | add_subdirectory(Dialect) 3 | add_subdirectory(Analysis) 4 | add_subdirectory(Transform) 5 | -------------------------------------------------------------------------------- /tests/ctlz_simple.mlir: -------------------------------------------------------------------------------- 1 | // RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s 2 | 3 | func.func @main(%arg0: i32) -> i32 { 4 | // CHECK-NOT: math.ctlz 5 | // CHECK: call 6 | %0 = math.ctlz %arg0 : i32 7 | func.return %0 : i32 8 | } 9 | -------------------------------------------------------------------------------- /lib/Analysis/ReduceNoiseAnalysis/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(ReduceNoiseAnalysis 2 | ReduceNoiseAnalysis.cpp 3 | 4 | ${PROJECT_SOURCE_DIR}/lib/Analysis/ReduceNoiseAnalysis/ 5 | ADDITIONAL_HEADER_DIRS 6 | LINK_LIBS PUBLIC 7 | ortools::ortools 8 | ) 9 | -------------------------------------------------------------------------------- /tests/poly_to_llvm_main.c: -------------------------------------------------------------------------------- 1 | 
#include <stdio.h> 2 | 3 | // This is the function we want to call from LLVM 4 | int test_poly_fn(int x); 5 | 6 | int main(int argc, char *argv[]) { 7 | int i = 1; 8 | int result = test_poly_fn(i); 9 | printf("Result: %d\n", result); 10 | 11 | return 0; 12 | } 13 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyTypes.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TYPES_NOISY_NOISYTYPES_H_ 2 | #define LIB_TYPES_NOISY_NOISYTYPES_H_ 3 | 4 | #include "mlir/include/mlir/IR/DialectImplementation.h" 5 | 6 | #define GET_TYPEDEF_CLASSES 7 | #include "lib/Dialect/Noisy/NoisyTypes.h.inc" 8 | 9 | #endif // LIB_TYPES_NOISY_NOISYTYPES_H_ 10 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyDialect.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYDIALECT_H_ 2 | #define LIB_DIALECT_POLY_POLYDIALECT_H_ 3 | 4 | // Required because the .h.inc file refers to MLIR classes and does not itself 5 | // have any includes.
6 | #include "mlir/include/mlir/IR/DialectImplementation.h" 7 | 8 | #include "lib/Dialect/Poly/PolyDialect.h.inc" 9 | 10 | #endif // LIB_DIALECT_POLY_POLYDIALECT_H_ 11 | -------------------------------------------------------------------------------- /lib/Transform/Arith/MulToAdd.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_ARITH_MULTOADD_H_ 2 | #define LIB_TRANSFORM_ARITH_MULTOADD_H_ 3 | 4 | #include "mlir/Pass/Pass.h" 5 | 6 | namespace mlir { 7 | namespace tutorial { 8 | 9 | #define GEN_PASS_DECL_MULTOADD 10 | #include "lib/Transform/Arith/Passes.h.inc" 11 | 12 | } // namespace tutorial 13 | } // namespace mlir 14 | 15 | #endif // LIB_TRANSFORM_ARITH_MULTOADD_H_ 16 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyTypes.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TYPES_POLY_POLYTYPES_H_ 2 | #define LIB_TYPES_POLY_POLYTYPES_H_ 3 | 4 | // Required because the .h.inc file refers to MLIR classes and does not itself 5 | // have any includes. 
6 | #include "mlir/include/mlir/IR/DialectImplementation.h" 7 | 8 | #define GET_TYPEDEF_CLASSES 9 | #include "lib/Dialect/Poly/PolyTypes.h.inc" 10 | 11 | #endif // LIB_TYPES_POLY_POLYTYPES_H_ 12 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | common --enable_bzlmod 2 | common --action_env=BAZEL_CXXOPTS=-std=c++17 3 | common --cxxopt='-std=c++17' 4 | common --deleted_packages=externals 5 | build:macos --apple_platform_type=macos 6 | build:macos --macos_minimum_os=10.13 7 | build:macos --macos_sdk_version=10.13 8 | build:macos_arm64 --cpu=darwin_arm64 9 | common --copt=-fdiagnostics-color=always 10 | common --test_output=errors 11 | common -c dbg 12 | -------------------------------------------------------------------------------- /lib/Transform/Affine/AffineFullUnroll.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_AFFINE_AFFINEFULLUNROLL_H_ 2 | #define LIB_TRANSFORM_AFFINE_AFFINEFULLUNROLL_H_ 3 | 4 | #include "mlir/Pass/Pass.h" 5 | 6 | namespace mlir { 7 | namespace tutorial { 8 | 9 | #define GEN_PASS_DECL_AFFINEFULLUNROLL 10 | #include "lib/Transform/Affine/Passes.h.inc" 11 | 12 | } // namespace tutorial 13 | } // namespace mlir 14 | 15 | #endif // LIB_TRANSFORM_AFFINE_AFFINEFULLUNROLL_H_ 16 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyDialect.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_NOISY_NOISYDIALECT_TD_ 2 | #define LIB_DIALECT_NOISY_NOISYDIALECT_TD_ 3 | 4 | include "mlir/IR/OpBase.td" 5 | 6 | def Noisy_Dialect : Dialect { 7 | let name = "noisy"; 8 | let summary = "A dialect for arithmetic on noisy i32s"; 9 | 10 | let cppNamespace = "::mlir::tutorial::noisy"; 11 | 12 | let useDefaultTypePrinterParser = 1; 13 | } 14 | 15 | #endif // 
LIB_DIALECT_NOISY_NOISYDIALECT_TD_ 16 | -------------------------------------------------------------------------------- /lib/Transform/Arith/Passes.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_ARITH_PASSES_H_ 2 | #define LIB_TRANSFORM_ARITH_PASSES_H_ 3 | 4 | #include "lib/Transform/Arith/MulToAdd.h" 5 | #include "lib/Transform/Arith/MulToAddPdll.h" 6 | 7 | namespace mlir { 8 | namespace tutorial { 9 | 10 | #define GEN_PASS_REGISTRATION 11 | #include "lib/Transform/Arith/Passes.h.inc" 12 | 13 | } // namespace tutorial 14 | } // namespace mlir 15 | 16 | #endif // LIB_TRANSFORM_ARITH_PASSES_H_ 17 | -------------------------------------------------------------------------------- /lib/Transform/Noisy/Passes.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_NOISY_PASSES_H_ 2 | #define LIB_TRANSFORM_NOISY_PASSES_H_ 3 | 4 | #include "lib/Transform/Noisy/ReduceNoiseOptimizer.h" 5 | 6 | namespace mlir { 7 | namespace tutorial { 8 | namespace noisy { 9 | 10 | #define GEN_PASS_REGISTRATION 11 | #include "lib/Transform/Noisy/Passes.h.inc" 12 | 13 | } // namespace noisy 14 | } // namespace tutorial 15 | } // namespace mlir 16 | 17 | #endif // LIB_TRANSFORM_NOISY_PASSES_H_ 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # bazel build files 2 | bazel-bin 3 | bazel-mlir-tutorial 4 | bazel-out 5 | bazel-testlogs 6 | MODULE.bazel.lock 7 | 8 | # cmake related files 9 | # ignore the user specified CMake presets in subproject directories. 
10 | /*/CMakeUserPresets.json 11 | 12 | # Nested build directory 13 | /build* 14 | 15 | # Visual Studio built-in CMake configuration 16 | .vscode* 17 | /CMakeSettings.json 18 | # Compilation databases 19 | compile_commands.json 20 | tablegen_compile_commands.yml 21 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyDialect.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_NOISY_NOISYDIALECT_H_ 2 | #define LIB_DIALECT_NOISY_NOISYDIALECT_H_ 3 | 4 | // Required because the .h.inc file refers to MLIR classes and does not itself 5 | // have any includes. 6 | #include "mlir/include/mlir/IR/DialectImplementation.h" 7 | 8 | #include "lib/Dialect/Noisy/NoisyDialect.h.inc" 9 | 10 | 11 | constexpr int INITIAL_NOISE = 12; 12 | constexpr int MAX_NOISE = 26; 13 | 14 | #endif // LIB_DIALECT_NOISY_NOISYDIALECT_H_ 15 | -------------------------------------------------------------------------------- /lib/Analysis/ReduceNoiseAnalysis/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | ) 4 | 5 | cc_library( 6 | name = "ReduceNoiseAnalysis", 7 | srcs = ["ReduceNoiseAnalysis.cpp"], 8 | hdrs = ["ReduceNoiseAnalysis.h"], 9 | deps = [ 10 | "//lib/Dialect/Noisy", 11 | "@or-tools//ortools/base", 12 | "@or-tools//ortools/linear_solver", 13 | "@llvm-project//llvm:Support", 14 | "@llvm-project//mlir:IR", 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /lib/Transform/Affine/Passes.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_AFFINE_PASSES_H_ 2 | #define LIB_TRANSFORM_AFFINE_PASSES_H_ 3 | 4 | #include "lib/Transform/Affine/AffineFullUnroll.h" 5 | #include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h" 6 | 7 | namespace mlir { 8 | namespace tutorial { 9 | 10 | 
#define GEN_PASS_REGISTRATION 11 | #include "lib/Transform/Affine/Passes.h.inc" 12 | 13 | } // namespace tutorial 14 | } // namespace mlir 15 | 16 | #endif // LIB_TRANSFORM_AFFINE_PASSES_H_ 17 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyTypes.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_NOISY_NOISYTYPES_TD_ 2 | #define LIB_DIALECT_NOISY_NOISYTYPES_TD_ 3 | 4 | include "NoisyDialect.td" 5 | include "mlir/IR/AttrTypeBase.td" 6 | 7 | class Noisy_Type<string name, string typeMnemonic> : TypeDef<Noisy_Dialect, name> { 8 | let mnemonic = typeMnemonic; 9 | } 10 | 11 | def Noisy_I32 : Noisy_Type<"NoisyI32", "i32"> { 12 | let summary = "A type for approximate 32-bit integers."; 13 | } 14 | 15 | #endif // LIB_DIALECT_NOISY_NOISYTYPES_TD_ 16 | -------------------------------------------------------------------------------- /tests/poly_verifier.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s 2>%t; FileCheck %s < %t 2 | 3 | func.func @test_invalid_evalop(%arg0: !poly.poly<10>, %cst: i64) -> i64 { 4 | // This is a little brittle, since it matches both the error message 5 | // emitted by Has32BitArguments as well as that of EvalOp::verify. 6 | // I manually tested that they both fire when the input is as below.
7 | // CHECK: to be a 32-bit integer 8 | %0 = poly.eval %arg0, %cst : (!poly.poly<10>, i64) -> i64 9 | return %0 : i64 10 | } 11 | -------------------------------------------------------------------------------- /lib/Transform/Noisy/ReduceNoiseOptimizer.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_NOISY_REDUCENOISEOPTIMIZER_H_ 2 | #define LIB_TRANSFORM_NOISY_REDUCENOISEOPTIMIZER_H_ 3 | 4 | #include "mlir/Pass/Pass.h" 5 | 6 | namespace mlir { 7 | namespace tutorial { 8 | namespace noisy { 9 | 10 | #define GEN_PASS_DECL_REDUCENOISEOPTIMIZER 11 | #include "lib/Transform/Noisy/Passes.h.inc" 12 | 13 | } // namespace noisy 14 | } // namespace tutorial 15 | } // namespace mlir 16 | 17 | #endif // LIB_TRANSFORM_NOISY_REDUCENOISEOPTIMIZER_H_ 18 | -------------------------------------------------------------------------------- /tests/noisy_syntax.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s | FileCheck %s 2 | // Check for syntax 3 | 4 | // CHECK-LABEL: test_op_syntax 5 | func.func @test_op_syntax() { 6 | %0 = arith.constant 3 : i5 7 | %1 = arith.constant 4 : i5 8 | %2 = noisy.encode %0 : i5 -> !noisy.i32 9 | %3 = noisy.encode %1 : i5 -> !noisy.i32 10 | %4 = noisy.add %2, %3 : !noisy.i32 11 | %5 = noisy.mul %4, %4 : !noisy.i32 12 | %6 = noisy.reduce_noise %5 : !noisy.i32 13 | %7 = noisy.decode %6 : !noisy.i32 -> i5 14 | return 15 | } 16 | -------------------------------------------------------------------------------- /lib/Transform/Affine/AffineFullUnrollPatternRewrite.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_AFFINE_AFFINEFULLUNROLLPATTERNREWRITE_H_ 2 | #define LIB_TRANSFORM_AFFINE_AFFINEFULLUNROLLPATTERNREWRITE_H_ 3 | 4 | #include "mlir/Pass/Pass.h" 5 | 6 | namespace mlir { 7 | namespace tutorial { 8 | 9 | #define GEN_PASS_DECL_AFFINEFULLUNROLLPATTERNREWRITE 10 | #include 
"lib/Transform/Affine/Passes.h.inc" 11 | 12 | } // namespace tutorial 13 | } // namespace mlir 14 | 15 | #endif // LIB_TRANSFORM_AFFINE_AFFINEFULLUNROLLPATTERNREWRITE_H_ 16 | -------------------------------------------------------------------------------- /lib/Transform/Noisy/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(NoisyPasses 2 | ReduceNoiseOptimizer.cpp 3 | 4 | ${PROJECT_SOURCE_DIR}/lib/Transform/Noisy/ 5 | ADDITIONAL_HEADER_DIRS 6 | 7 | DEPENDS 8 | MLIRNoisy 9 | MLIRNoisyPasses 10 | 11 | LINK_LIBS PUBLIC 12 | ReduceNoiseAnalysis 13 | ) 14 | 15 | set(LLVM_TARGET_DEFINITIONS Passes.td) 16 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name Noisy) 17 | add_public_tablegen_target(MLIRNoisyPasses) 18 | add_mlir_doc(Passes NoisyPasses ./ -gen-pass-doc) 19 | -------------------------------------------------------------------------------- /tests/lit.cmake.site.cfg.py.in: -------------------------------------------------------------------------------- 1 | @LIT_SITE_CFG_IN_HEADER@ 2 | 3 | config.llvm_tools_dir = lit_config.substitute("@LLVM_TOOLS_DIR@") 4 | config.mlir_obj_dir = "@MLIR_BINARY_DIR@" 5 | config.llvm_shlib_ext = "@SHLIBEXT@" 6 | config.project_binary_dir = "@PROJECT_BINARY_DIR@" 7 | config.project_source_dir = "@PROJECT_SOURCE_DIR@" 8 | 9 | import lit.llvm 10 | lit.llvm.initialize(lit_config, config) 11 | 12 | # Let the main config do the real work. 
13 | lit_config.load_config(config, "@PROJECT_SOURCE_DIR@/tests/lit.cmake.cfg.py") 14 | -------------------------------------------------------------------------------- /lib/Transform/Affine/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(AffineFullUnroll 2 | AffineFullUnroll.cpp 3 | AffineFullUnrollPatternRewrite.cpp 4 | 5 | ${PROJECT_SOURCE_DIR}/lib/Transform/Affine/ 6 | ADDITIONAL_HEADER_DIRS 7 | 8 | DEPENDS 9 | MLIRAffineFullUnrollPasses 10 | 11 | LINK_LIBS PUBLIC 12 | ) 13 | 14 | set(LLVM_TARGET_DEFINITIONS Passes.td) 15 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name Affine) 16 | add_public_tablegen_target(MLIRAffineFullUnrollPasses) 17 | add_mlir_doc(Passes AffinePasses ./ -gen-pass-doc) 18 | -------------------------------------------------------------------------------- /tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | configure_lit_site_cfg( 2 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.cmake.site.cfg.py.in 3 | ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py 4 | MAIN_CONFIG 5 | ${CMAKE_CURRENT_SOURCE_DIR}/lit.cmake.cfg.py 6 | ) 7 | 8 | set (MLIR_TUTORIAL_TEST_DEPENDS 9 | FileCheck count not 10 | mlir-opt 11 | mlir-runner 12 | # tutorial-opt 13 | ) 14 | 15 | add_lit_testsuite(check-mlir-tutorial "Running the MLIR tutorial regression tests" 16 | ${CMAKE_CURRENT_BINARY_DIR} 17 | DEPENDS ${MLIR_TUTORIAL_TEST_DEPENDS} 18 | ) -------------------------------------------------------------------------------- /lib/Transform/Noisy/Passes.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_NOISY_PASSES_TD_ 2 | #define LIB_TRANSFORM_NOISY_PASSES_TD_ 3 | 4 | include "mlir/Pass/PassBase.td" 5 | 6 | def ReduceNoiseOptimizer : Pass<"noisy-reduce-noise-optimizer"> { 7 | let summary = "Insert reduce_noise ops optimally"; 8 | let description = [{ 9 | Solves an integer linear program to select the 
optimal locations in the IR 10 | to insert `reduce_noise` ops. 11 | }]; 12 | let dependentDialects = ["mlir::tutorial::noisy::NoisyDialect"]; 13 | } 14 | 15 | #endif // LIB_TRANSFORM_NOISY_PASSES_TD_ 16 | -------------------------------------------------------------------------------- /tests/cse.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt -cse %s | FileCheck %s 2 | 3 | // CHECK-LABEL: @test_simple_cse 4 | func.func @test_simple_cse() -> !poly.poly<10> { 5 | %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> 6 | // CHECK: poly.from_tensor 7 | %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> 8 | // exactly one mul op 9 | // CHECK-NEXT: poly.mul 10 | // CHECK-NEXT: poly.add 11 | %2 = poly.mul %p0, %p0 : !poly.poly<10> 12 | %3 = poly.mul %p0, %p0 : !poly.poly<10> 13 | %4 = poly.add %2, %3 : !poly.poly<10> 14 | return %4 : !poly.poly<10> 15 | } 16 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyOps.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_NOISY_NOISYOPS_H_ 2 | #define LIB_DIALECT_NOISY_NOISYOPS_H_ 3 | 4 | #include "lib/Dialect/Noisy/NoisyDialect.h" 5 | #include "lib/Dialect/Noisy/NoisyTypes.h" 6 | #include "mlir/Interfaces/InferTypeOpInterface.h" 7 | #include "mlir/Interfaces/InferIntRangeInterface.h" 8 | #include "mlir/include/mlir/IR/BuiltinOps.h" 9 | #include "mlir/include/mlir/IR/BuiltinTypes.h" 10 | #include "mlir/include/mlir/IR/Dialect.h" 11 | 12 | #define GET_OP_CLASSES 13 | #include "lib/Dialect/Noisy/NoisyOps.h.inc" 14 | 15 | #endif // LIB_DIALECT_NOISY_NOISYOPS_H_ 16 | -------------------------------------------------------------------------------- /lib/Transform/Arith/MulToAddPdll.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ 2 | #define 
LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ 3 | 4 | #include "mlir/Pass/Pass.h" 5 | #include "mlir/IR/PatternMatch.h" 6 | #include "mlir/Dialect/Arith/IR/Arith.h" 7 | #include "mlir/Parser/Parser.h" 8 | 9 | namespace mlir { 10 | namespace tutorial { 11 | 12 | #define GEN_PASS_DECL_MULTOADDPDLL 13 | #include "lib/Transform/Arith/Passes.h.inc" 14 | 15 | #include "lib/Transform/Arith/MulToAddPdll.h.inc" 16 | 17 | } // namespace tutorial 18 | } // namespace mlir 19 | 20 | #endif // LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ 21 | -------------------------------------------------------------------------------- /tests/poly_to_llvm_eval.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt --poly-to-llvm %s | mlir-translate --mlir-to-llvmir | llc --relocation-model=pic -filetype=obj > %t 2 | // RUN: clang -c %project_source_dir/tests/poly_to_llvm_main.c 3 | // RUN: clang poly_to_llvm_main.o %t -o eval_test.out 4 | // RUN: ./eval_test.out | FileCheck %s 5 | 6 | // CHECK: 9 7 | func.func @test_poly_fn(%arg : i32) -> i32 { 8 | // 2 + 3x + 4x^2 evaluated at x=1, should be 2+3+4 9 | %input = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<3> 10 | %0 = poly.eval %input, %arg: (!poly.poly<3>, i32) -> i32 11 | return %0 : i32 12 | } 13 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyDialect.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYDIALECT_TD_ 2 | #define LIB_DIALECT_POLY_POLYDIALECT_TD_ 3 | 4 | include "mlir/IR/OpBase.td" 5 | 6 | def Poly_Dialect : Dialect { 7 | let name = "poly"; 8 | let summary = "A dialect for polynomial math"; 9 | let description = [{ 10 | The poly dialect defines types and operations for single-variable 11 | polynomials over integers. 
12 | }]; 13 | 14 | let cppNamespace = "::mlir::tutorial::poly"; 15 | 16 | let useDefaultTypePrinterParser = 1; 17 | let hasConstantMaterializer = 1; 18 | } 19 | 20 | #endif // LIB_DIALECT_POLY_POLYDIALECT_TD_ 21 | -------------------------------------------------------------------------------- /lib/Transform/Arith/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_pdll_library(MulToAddPdllIncGen 2 | MulToAdd.pdll 3 | MulToAddPdll.h.inc 4 | ) 5 | 6 | add_mlir_library(MulToAdd 7 | MulToAdd.cpp 8 | MulToAddPdll.cpp 9 | 10 | ${PROJECT_SOURCE_DIR}/lib/Transform/Arith/ 11 | ADDITIONAL_HEADER_DIRS 12 | 13 | DEPENDS 14 | MLIRMulToAddPasses 15 | MulToAddPdllIncGen 16 | 17 | LINK_LIBS PUBLIC 18 | ) 19 | 20 | set(LLVM_TARGET_DEFINITIONS Passes.td) 21 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name Arith) 22 | add_public_tablegen_target(MLIRMulToAddPasses) 23 | add_mlir_doc(Passes ArithPasses ./ -gen-pass-doc) 24 | -------------------------------------------------------------------------------- /tools/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) 2 | get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) 3 | 4 | set (LIBS 5 | ${dialect_libs} 6 | ${conversion_libs} 7 | AffineFullUnroll 8 | MLIRNoisy 9 | MLIROptLib 10 | MLIRPass 11 | MLIRPoly 12 | MulToAdd 13 | NoisyPasses 14 | PolyToStandard 15 | ortools::ortools 16 | ) 17 | 18 | add_llvm_executable(tutorial-opt tutorial-opt.cpp) 19 | 20 | llvm_update_compile_flags(tutorial-opt) 21 | target_link_libraries(tutorial-opt PRIVATE ${LIBS}) 22 | 23 | mlir_check_all_link_libraries(tutorial-opt) 24 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyOps.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYOPS_H_ 2 | #define 
LIB_DIALECT_POLY_POLYOPS_H_ 3 | 4 | #include "lib/Dialect/Poly/PolyDialect.h" 5 | #include "lib/Dialect/Poly/PolyTraits.h" 6 | #include "lib/Dialect/Poly/PolyTypes.h" 7 | #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project 8 | #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project 9 | #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project 10 | #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project 11 | 12 | #define GET_OP_CLASSES 13 | #include "lib/Dialect/Poly/PolyOps.h.inc" 14 | 15 | #endif // LIB_DIALECT_POLY_POLYOPS_H_ 16 | -------------------------------------------------------------------------------- /lib/Conversion/PolyToStandard/PolyToStandard.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_ 2 | #define LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_ 3 | 4 | include "mlir/Pass/PassBase.td" 5 | 6 | def PolyToStandard : Pass<"poly-to-standard"> { 7 | let summary = "Lower `poly` to standard MLIR dialects."; 8 | 9 | let description = [{ 10 | This pass lowers the `poly` dialect to standard MLIR, a mixture of affine, 11 | tensor, and arith. 12 | }]; 13 | let dependentDialects = [ 14 | "mlir::arith::ArithDialect", 15 | "mlir::tutorial::poly::PolyDialect", 16 | "mlir::tensor::TensorDialect", 17 | "mlir::scf::SCFDialect", 18 | ]; 19 | } 20 | 21 | #endif // LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_ 22 | -------------------------------------------------------------------------------- /tests/BUILD: -------------------------------------------------------------------------------- 1 | load("//bazel:lit.bzl", "glob_lit_tests") 2 | 3 | # Bundle together all of the test utilities that are used by tests. 
4 | filegroup( 5 | name = "test_utilities", 6 | testonly = True, 7 | data = [ 8 | "//tests:lit.cfg.py", 9 | "//tests:poly_to_llvm_main.c", 10 | "//tools:tutorial-opt", 11 | "@llvm-project//clang:clang", 12 | "@llvm-project//llvm:FileCheck", 13 | "@llvm-project//llvm:count", 14 | "@llvm-project//llvm:llc", 15 | "@llvm-project//llvm:not", 16 | "@llvm-project//mlir:mlir-runner", 17 | "@llvm-project//mlir:mlir-opt", 18 | "@llvm-project//mlir:mlir-translate", 19 | "@mlir_tutorial_pip_deps//lit", 20 | ], 21 | ) 22 | 23 | glob_lit_tests() 24 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyTypes.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYTYPES_TD_ 2 | #define LIB_DIALECT_POLY_POLYTYPES_TD_ 3 | 4 | include "PolyDialect.td" 5 | include "mlir/IR/AttrTypeBase.td" 6 | 7 | // A base class for all types in this dialect 8 | class Poly_Type<string name, string typeMnemonic> : TypeDef<Poly_Dialect, name> { 9 | let mnemonic = typeMnemonic; 10 | } 11 | 12 | def Polynomial : Poly_Type<"Polynomial", "poly"> { 13 | let summary = "A polynomial with u32 coefficients"; 14 | 15 | let description = [{ 16 | A type for polynomials with integer coefficients in a single-variable polynomial ring.
17 | }]; 18 | 19 | let parameters = (ins "int":$degreeBound); 20 | let assemblyFormat = "`<` $degreeBound `>`"; 21 | } 22 | 23 | #endif // LIB_DIALECT_POLY_POLYTYPES_TD_ 24 | -------------------------------------------------------------------------------- /tests/poly_to_llvm.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt --poly-to-llvm %s | mlir-translate --mlir-to-llvmir | llc --relocation-model=pic -filetype=obj > %t 2 | // RUN: clang -c %project_source_dir/tests/poly_to_llvm_main.c 3 | // RUN: clang poly_to_llvm_main.o %t -o a.out 4 | // RUN: ./a.out | FileCheck %s 5 | 6 | // CHECK: 351 7 | func.func @test_poly_fn(%arg : i32) -> i32 { 8 | %tens = tensor.splat %arg : tensor<10xi32> 9 | %input = poly.from_tensor %tens : tensor<10xi32> -> !poly.poly<10> 10 | %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> 11 | %1 = poly.add %0, %input : !poly.poly<10> 12 | %2 = poly.mul %1, %1 : !poly.poly<10> 13 | %3 = poly.sub %2, %input : !poly.poly<10> 14 | %4 = poly.eval %3, %arg: (!poly.poly<10>, i32) -> i32 15 | return %4 : i32 16 | } 17 | -------------------------------------------------------------------------------- /lib/Transform/Affine/Passes.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_AFFINE_PASSES_TD_ 2 | #define LIB_TRANSFORM_AFFINE_PASSES_TD_ 3 | 4 | include "mlir/Pass/PassBase.td" 5 | 6 | def AffineFullUnroll : Pass<"affine-full-unroll"> { 7 | let summary = "Fully unroll all affine loops"; 8 | let description = [{ 9 | Fully unroll all affine loops. 10 | }]; 11 | let dependentDialects = ["mlir::affine::AffineDialect"]; 12 | } 13 | 14 | def AffineFullUnrollPatternRewrite : Pass<"affine-full-unroll-rewrite"> { 15 | let summary = "Fully unroll all affine loops using the pattern rewrite engine"; 16 | let description = [{ 17 | Fully unroll all affine loops using the pattern rewrite engine. 
18 | }]; 19 | let dependentDialects = ["mlir::affine::AffineDialect"]; 20 | } 21 | 22 | #endif // LIB_TRANSFORM_AFFINE_PASSES_TD_ 23 | -------------------------------------------------------------------------------- /lib/Conversion/PolyToStandard/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_library(PolyToStandard 2 | PolyToStandard.cpp 3 | 4 | ${PROJECT_SOURCE_DIR}/lib/Conversion/PolyToStandard/ 5 | ADDITIONAL_HEADER_DIRS 6 | 7 | DEPENDS 8 | PolyToStandardPassIncGen 9 | 10 | LINK_COMPONENTS 11 | Core 12 | 13 | LINK_LIBS PUBLIC 14 | MLIRPoly 15 | MLIRArithDialect 16 | MLIRFuncDialect 17 | MLIRFuncTransforms 18 | MLIRIR 19 | MLIRPass 20 | MLIRSCFDialect 21 | MLIRTensorDialect 22 | MLIRTransforms 23 | ) 24 | 25 | set(LLVM_TARGET_DEFINITIONS PolyToStandard.td) 26 | mlir_tablegen(PolyToStandard.h.inc -gen-pass-decls -name PolyToStandard) 27 | add_dependencies(mlir-headers MLIRPolyOpsIncGen) 28 | add_public_tablegen_target(PolyToStandardPassIncGen) 29 | add_mlir_doc(PolyToStandard PolyToStandard PolyToStandard/ -gen-pass-doc) 30 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyTraits.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYTRAITS_H_ 2 | #define LIB_DIALECT_POLY_POLYTRAITS_H_ 3 | 4 | #include "mlir/include/mlir/IR/OpDefinition.h" 5 | 6 | namespace mlir::tutorial::poly { 7 | 8 | template <typename ConcreteType> 9 | class Has32BitArguments : public OpTrait::TraitBase<ConcreteType, Has32BitArguments> { 10 | public: 11 | static LogicalResult verifyTrait(Operation *op) { 12 | for (auto type : op->getOperandTypes()) { 13 | // OK to skip non-integer operand types 14 | if (!type.isIntOrIndex()) continue; 15 | 16 | if (!type.isInteger(32)) { 17 | return op->emitOpError() 18 | << "requires each numeric operand to be a 32-bit integer"; 19 | } 20 | } 21 | 22 | return success(); 23 | } 24 | }; 25 | 26 | } 27 | 28 | #endif //
LIB_DIALECT_POLY_POLYTRAITS_H_ 29 | -------------------------------------------------------------------------------- /lib/Conversion/PolyToStandard/PolyToStandard.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_ 2 | #define LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_ 3 | 4 | #include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project 5 | 6 | // Extra includes needed for dependent dialects 7 | #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project 8 | #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project 9 | 10 | namespace mlir { 11 | namespace tutorial { 12 | namespace poly { 13 | 14 | #define GEN_PASS_DECL 15 | #include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" 16 | 17 | #define GEN_PASS_REGISTRATION 18 | #include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" 19 | 20 | } // namespace poly 21 | } // namespace tutorial 22 | } // namespace mlir 23 | 24 | #endif // LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_ 25 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyPatterns.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYPATTERNS_TD_ 2 | #define LIB_DIALECT_POLY_POLYPATTERNS_TD_ 3 | 4 | include "PolyOps.td" 5 | include "mlir/Dialect/Complex/IR/ComplexOps.td" 6 | include "mlir/IR/PatternBase.td" 7 | 8 | def LiftConjThroughEval : Pat< 9 | (Poly_EvalOp $f, (ConjOp $z, $fastmath)), 10 | (ConjOp (Poly_EvalOp $f, $z), $fastmath) 11 | >; 12 | 13 | def HasOneUse: Constraint, "has one use">; 14 | 15 | // Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses. 
16 | def DifferenceOfSquares : Pattern< 17 | (Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)), 18 | [ 19 | (Poly_AddOp:$sum $x, $y), 20 | (Poly_SubOp:$diff $x, $y), 21 | (Poly_MulOp:$res $sum, $diff), 22 | ], 23 | [(HasOneUse:$lhs), (HasOneUse:$rhs)] 24 | >; 25 | 26 | #endif // LIB_DIALECT_POLY_POLYPATTERNS_TD_ 27 | -------------------------------------------------------------------------------- /lib/Transform/Arith/Passes.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_TRANSFORM_ARITH_PASSES_TD_ 2 | #define LIB_TRANSFORM_ARITH_PASSES_TD_ 3 | 4 | include "mlir/Dialect/PDL/IR/PDLDialect.td" 5 | include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td" 6 | include "mlir/Pass/PassBase.td" 7 | 8 | def MulToAdd : Pass<"mul-to-add"> { 9 | let summary = "Convert multiplications to repeated additions"; 10 | let description = [{ 11 | Convert multiplications to repeated additions. 12 | }]; 13 | } 14 | 15 | def MulToAddPdll : Pass<"mul-to-add-pdll"> { 16 | let summary = "Convert multiplications to repeated additions using pdll"; 17 | let description = [{ 18 | Convert multiplications to repeated additions (using pdll). 
19 | }]; 20 | let dependentDialects = [ 21 | "mlir::pdl::PDLDialect", 22 | "mlir::pdl_interp::PDLInterpDialect", 23 | ]; 24 | } 25 | 26 | #endif // LIB_TRANSFORM_ARITH_PASSES_TD_ 27 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyDialect.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Dialect/Noisy/NoisyDialect.h" 2 | 3 | #include "lib/Dialect/Noisy/NoisyOps.h" 4 | #include "lib/Dialect/Noisy/NoisyTypes.h" 5 | #include "mlir/include/mlir/IR/Builders.h" 6 | #include "llvm/include/llvm/ADT/TypeSwitch.h" 7 | 8 | #include "lib/Dialect/Noisy/NoisyDialect.cpp.inc" 9 | #define GET_TYPEDEF_CLASSES 10 | #include "lib/Dialect/Noisy/NoisyTypes.cpp.inc" 11 | #define GET_OP_CLASSES 12 | #include "lib/Dialect/Noisy/NoisyOps.cpp.inc" 13 | 14 | namespace mlir { 15 | namespace tutorial { 16 | namespace noisy { 17 | 18 | void NoisyDialect::initialize() { 19 | addTypes< 20 | #define GET_TYPEDEF_LIST 21 | #include "lib/Dialect/Noisy/NoisyTypes.cpp.inc" 22 | >(); 23 | addOperations< 24 | #define GET_OP_LIST 25 | #include "lib/Dialect/Noisy/NoisyOps.cpp.inc" 26 | >(); 27 | } 28 | 29 | } // namespace noisy 30 | } // namespace tutorial 31 | } // namespace mlir 32 | -------------------------------------------------------------------------------- /tests/control_flow_sink.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt -control-flow-sink %s | FileCheck %s 2 | 3 | // Test that operations can be sunk. 
4 | 5 | // CHECK-LABEL: @test_simple_sink 6 | func.func @test_simple_sink(%arg0: i1) -> !poly.poly<10> { 7 | %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> 8 | %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> 9 | %1 = arith.constant dense<[9, 8, 16]> : tensor<3xi32> 10 | %p1 = poly.from_tensor %1 : tensor<3xi32> -> !poly.poly<10> 11 | // CHECK-NOT: poly.from_tensor 12 | // CHECK: scf.if 13 | %4 = scf.if %arg0 -> (!poly.poly<10>) { 14 | // CHECK: poly.from_tensor 15 | %2 = poly.mul %p0, %p0 : !poly.poly<10> 16 | scf.yield %2 : !poly.poly<10> 17 | // CHECK: else 18 | } else { 19 | // CHECK: poly.from_tensor 20 | %3 = poly.mul %p1, %p1 : !poly.poly<10> 21 | scf.yield %3 : !poly.poly<10> 22 | } 23 | return %4 : !poly.poly<10> 24 | } 25 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Inlining `add_mlir_dialect(Noisy noisy)` commands so that 2 | # we can custom name `*.inc` generated files. 
3 | set(LLVM_TARGET_DEFINITIONS NoisyOps.td) 4 | mlir_tablegen(NoisyOps.h.inc -gen-op-decls) 5 | mlir_tablegen(NoisyOps.cpp.inc -gen-op-defs) 6 | mlir_tablegen(NoisyTypes.h.inc -gen-typedef-decls -typedefs-dialect=noisy) 7 | mlir_tablegen(NoisyTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=noisy) 8 | mlir_tablegen(NoisyDialect.h.inc -gen-dialect-decls -dialect=noisy) 9 | mlir_tablegen(NoisyDialect.cpp.inc -gen-dialect-defs -dialect=noisy) 10 | add_public_tablegen_target(MLIRNoisyOpsIncGen) 11 | add_dependencies(mlir-headers MLIRNoisyOpsIncGen) 12 | 13 | add_mlir_doc(NoisyDialect NoisyDialect Noisy/ -gen-dialect-doc) 14 | 15 | add_mlir_dialect_library(MLIRNoisy 16 | NoisyDialect.cpp 17 | NoisyOps.cpp 18 | 19 | ADDITIONAL_HEADER_DIRS 20 | ${PROJECT_SOURCE_DIR}/lib/Dialect/Noisy 21 | ) 22 | -------------------------------------------------------------------------------- /lib/Transform/Affine/AffineFullUnroll.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Transform/Affine/AffineFullUnroll.h" 2 | #include "mlir/Dialect/Affine/IR/AffineOps.h" 3 | #include "mlir/Dialect/Affine/LoopUtils.h" 4 | #include "mlir/include/mlir/Pass/Pass.h" 5 | 6 | namespace mlir { 7 | namespace tutorial { 8 | 9 | #define GEN_PASS_DEF_AFFINEFULLUNROLL 10 | #include "lib/Transform/Affine/Passes.h.inc" 11 | 12 | using mlir::affine::AffineForOp; 13 | using mlir::affine::loopUnrollFull; 14 | 15 | // A pass that manually walks the IR 16 | struct AffineFullUnroll : impl::AffineFullUnrollBase { 17 | using AffineFullUnrollBase::AffineFullUnrollBase; 18 | 19 | void runOnOperation() { 20 | getOperation()->walk([&](AffineForOp op) { 21 | if (failed(loopUnrollFull(op))) { 22 | op.emitError("unrolling failed"); 23 | signalPassFailure(); 24 | } 25 | }); 26 | } 27 | }; 28 | 29 | 30 | 31 | } // namespace tutorial 32 | } // namespace mlir 33 | -------------------------------------------------------------------------------- /tests/code_motion.mlir: 
-------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s --loop-invariant-code-motion > %t 2 | // RUN: FileCheck %s < %t 3 | 4 | module { 5 | // CHECK-LABEL: func.func @test_loop_invariant_code_motion 6 | func.func @test_loop_invariant_code_motion() -> !poly.poly<10> { 7 | %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> 8 | %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> 9 | 10 | %1 = arith.constant dense<[9, 8, 16]> : tensor<3xi32> 11 | %p1 = poly.from_tensor %1 : tensor<3xi32> -> !poly.poly<10> 12 | // CHECK: poly.mul 13 | 14 | // CHECK: affine.for 15 | %ret_val = affine.for %i = 0 to 100 iter_args(%sum_iter = %p0) -> !poly.poly<10> { 16 | // The poly.mul should be hoisted out of the loop. 17 | // CHECK-NOT: poly.mul 18 | %2 = poly.mul %p0, %p1 : !poly.poly<10> 19 | %sum_next = poly.add %sum_iter, %2 : !poly.poly<10> 20 | affine.yield %sum_next : !poly.poly<10> 21 | } 22 | 23 | return %ret_val : !poly.poly<10> 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.h: -------------------------------------------------------------------------------- 1 | #ifndef LIB_ANALYSIS_REDUCENOISEANALYSIS_REDUCENOISEANALYSIS_H_ 2 | #define LIB_ANALYSIS_REDUCENOISEANALYSIS_REDUCENOISEANALYSIS_H_ 3 | 4 | #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project 5 | #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project 6 | #include "mlir/include/mlir/IR/Value.h" // from @llvm-project 7 | 8 | namespace mlir { 9 | namespace tutorial { 10 | 11 | class ReduceNoiseAnalysis { 12 | public: 13 | ReduceNoiseAnalysis(Operation *op); 14 | ~ReduceNoiseAnalysis() = default; 15 | 16 | /// Return true if a reduce_noise op should be inserted after the given 17 | /// operation, according to the solution to the optimization problem. 
18 | bool shouldInsertReduceNoise(Operation *op) const { 19 | return solution.lookup(op); 20 | } 21 | 22 | private: 23 | llvm::DenseMap solution; 24 | }; 25 | 26 | } // namespace tutorial 27 | } // namespace mlir 28 | 29 | #endif // LIB_ANALYSIS_REDUCENOISEANALYSIS_REDUCENOISEANALYSIS_H_ 30 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Inlining `add_mlir_dialect(Poly poly)` commands so that 2 | # we can custom name `*.inc` generated files. 3 | set(LLVM_TARGET_DEFINITIONS PolyOps.td) 4 | mlir_tablegen(PolyOps.h.inc -gen-op-decls) 5 | mlir_tablegen(PolyOps.cpp.inc -gen-op-defs) 6 | mlir_tablegen(PolyTypes.h.inc -gen-typedef-decls -typedefs-dialect=poly) 7 | mlir_tablegen(PolyTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=poly) 8 | mlir_tablegen(PolyDialect.h.inc -gen-dialect-decls -dialect=poly) 9 | mlir_tablegen(PolyDialect.cpp.inc -gen-dialect-defs -dialect=poly) 10 | add_public_tablegen_target(MLIRPolyOpsIncGen) 11 | add_dependencies(mlir-headers MLIRPolyOpsIncGen) 12 | 13 | add_mlir_doc(PolyDialect PolyDialect Poly/ -gen-dialect-doc) 14 | 15 | set(LLVM_TARGET_DEFINITIONS PolyPatterns.td) 16 | mlir_tablegen(PolyCanonicalize.cpp.inc -gen-rewriters) 17 | add_public_tablegen_target(MLIRPolyCanonicalizationIncGen) 18 | 19 | add_mlir_dialect_library(MLIRPoly 20 | PolyDialect.cpp 21 | PolyOps.cpp 22 | 23 | ADDITIONAL_HEADER_DIRS 24 | ${PROJECT_SOURCE_DIR}/lib/Dialect/Poly 25 | 26 | LINK_LIBS PUBLIC 27 | ) 28 | -------------------------------------------------------------------------------- /.github/workflows/build_and_test.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test w/Bazel 2 | permissions: read-all 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | jobs: 11 | build-and-test: 12 | runs-on: ubuntu-latest 13 
| steps: 14 | - name: Check out repository code 15 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4.2.2 16 | 17 | - name: Cache bazel build artifacts 18 | uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # pin@v4.2.3 19 | with: 20 | path: | 21 | ~/.cache/bazel 22 | # add extensions.bzl so that a new build occurs after an LLVM commit hash update 23 | key: ${{ runner.os }}-bazel-${{ hashFiles('extensions.bzl') }}-${{ hashFiles('.bazelversion', '.bazelrc', 'MODULE.bazel') }} 24 | restore-keys: | 25 | ${{ runner.os }}-bazel-${{ hashFiles('extensions.bzl') }} 26 | 27 | - name: "Run `bazel build`" 28 | run: | 29 | bazel build -c fastbuild //... 30 | 31 | - name: "Run `bazel test`" 32 | run: | 33 | bazel test -c fastbuild //... 34 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyDialect.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Dialect/Poly/PolyDialect.h" 2 | 3 | #include "lib/Dialect/Poly/PolyOps.h" 4 | #include "lib/Dialect/Poly/PolyTypes.h" 5 | #include "mlir/include/mlir/IR/Builders.h" 6 | #include "llvm/include/llvm/ADT/TypeSwitch.h" 7 | 8 | #include "lib/Dialect/Poly/PolyDialect.cpp.inc" 9 | #define GET_TYPEDEF_CLASSES 10 | #include "lib/Dialect/Poly/PolyTypes.cpp.inc" 11 | #define GET_OP_CLASSES 12 | #include "lib/Dialect/Poly/PolyOps.cpp.inc" 13 | 14 | namespace mlir { 15 | namespace tutorial { 16 | namespace poly { 17 | 18 | void PolyDialect::initialize() { 19 | addTypes< 20 | #define GET_TYPEDEF_LIST 21 | #include "lib/Dialect/Poly/PolyTypes.cpp.inc" 22 | >(); 23 | addOperations< 24 | #define GET_OP_LIST 25 | #include "lib/Dialect/Poly/PolyOps.cpp.inc" 26 | >(); 27 | } 28 | 29 | Operation *PolyDialect::materializeConstant(OpBuilder &builder, Attribute value, 30 | Type type, Location loc) { 31 | auto coeffs = dyn_cast(value); 32 | if (!coeffs) 33 | return nullptr; 34 | return builder.create(loc, type, coeffs); 
35 | } 36 | 37 | } // namespace poly 38 | } // namespace tutorial 39 | } // namespace mlir 40 | -------------------------------------------------------------------------------- /tests/mul_to_add_pdll.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s --mul-to-add-pdll | FileCheck %s 2 | 3 | func.func @just_power_of_two(%arg: i32) -> i32 { 4 | %0 = arith.constant 8 : i32 5 | %1 = arith.muli %arg, %0 : i32 6 | func.return %1 : i32 7 | } 8 | 9 | // CHECK-LABEL: func.func @just_power_of_two( 10 | // CHECK-SAME: %[[ARG:.*]]: i32 11 | // CHECK-SAME: ) -> i32 { 12 | // CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]] 13 | // CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]] 14 | // CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]] 15 | // CHECK: return %[[SUM_2]] : i32 16 | // CHECK: } 17 | 18 | 19 | func.func @power_of_two_plus_one(%arg: i32) -> i32 { 20 | %0 = arith.constant 9 : i32 21 | %1 = arith.muli %arg, %0 : i32 22 | func.return %1 : i32 23 | } 24 | 25 | // CHECK-LABEL: func.func @power_of_two_plus_one( 26 | // CHECK-SAME: %[[ARG:.*]]: i32 27 | // CHECK-SAME: ) -> i32 { 28 | // CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]] 29 | // CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]] 30 | // CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]] 31 | // CHECK: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[ARG]] 32 | // CHECK: return %[[SUM_3]] : i32 33 | // CHECK: } 34 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.20.0) 2 | 3 | project(mlir-tutorial LANGUAGES CXX C) 4 | 5 | set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to") 6 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 7 | set(BUILD_DEPS ON) 8 | 9 | find_package(MLIR REQUIRED CONFIG) 10 | 11 | message(STATUS "Using MLIRConfig.cmake in: 
${MLIR_DIR}") 12 | message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") 13 | 14 | set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) 15 | 16 | include(AddLLVM) 17 | include(TableGen) 18 | 19 | list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") 20 | include(AddMLIR) 21 | include_directories(${LLVM_INCLUDE_DIRS}) 22 | include_directories(${MLIR_INCLUDE_DIRS}) 23 | include_directories(${PROJECT_SOURCE_DIR}) 24 | include_directories(${PROJECT_SOURCE_DIR}/externals/llvm-project) 25 | include_directories(${PROJECT_BINARY_DIR}) 26 | 27 | message(STATUS "Fetching or-tools...") 28 | include(FetchContent) 29 | FetchContent_Declare( 30 | or-tools 31 | GIT_REPOSITORY https://github.com/google/or-tools.git 32 | GIT_TAG v9.11 33 | ) 34 | FetchContent_MakeAvailable(or-tools) 35 | message(STATUS "Done fetching or-tools") 36 | 37 | add_subdirectory(tests) 38 | add_subdirectory(tools) 39 | add_subdirectory(lib) 40 | -------------------------------------------------------------------------------- /tests/mul_to_add.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s --mul-to-add > %t 2 | // RUN: FileCheck %s < %t 3 | 4 | func.func @just_power_of_two(%arg: i32) -> i32 { 5 | %0 = arith.constant 8 : i32 6 | %1 = arith.muli %arg, %0 : i32 7 | func.return %1 : i32 8 | } 9 | 10 | // CHECK-LABEL: func.func @just_power_of_two( 11 | // CHECK-SAME: %[[ARG:.*]]: i32 12 | // CHECK-SAME: ) -> i32 { 13 | // CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]] 14 | // CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]] 15 | // CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]] 16 | // CHECK: return %[[SUM_2]] : i32 17 | // CHECK: } 18 | 19 | 20 | func.func @power_of_two_plus_one(%arg: i32) -> i32 { 21 | %0 = arith.constant 9 : i32 22 | %1 = arith.muli %arg, %0 : i32 23 | func.return %1 : i32 24 | } 25 | 26 | // CHECK-LABEL: func.func @power_of_two_plus_one( 27 | // CHECK-SAME: %[[ARG:.*]]: i32 28 | // CHECK-SAME: ) -> i32 { 29 | // 
CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]] 30 | // CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]] 31 | // CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]] 32 | // CHECK: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[ARG]] 33 | // CHECK: return %[[SUM_3]] : i32 34 | // CHECK: } 35 | -------------------------------------------------------------------------------- /lib/Conversion/PolyToStandard/BUILD: -------------------------------------------------------------------------------- 1 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | gentbl_cc_library( 8 | name = "pass_inc_gen", 9 | tbl_outs = [ 10 | ( 11 | [ 12 | "-gen-pass-decls", 13 | "-name=PolyToStandard", 14 | ], 15 | "PolyToStandard.h.inc", 16 | ), 17 | ], 18 | tblgen = "@llvm-project//mlir:mlir-tblgen", 19 | td_file = "PolyToStandard.td", 20 | deps = [ 21 | "@llvm-project//mlir:OpBaseTdFiles", 22 | "@llvm-project//mlir:PassBaseTdFiles", 23 | ], 24 | ) 25 | 26 | cc_library( 27 | name = "PolyToStandard", 28 | srcs = ["PolyToStandard.cpp"], 29 | hdrs = ["PolyToStandard.h"], 30 | deps = [ 31 | "pass_inc_gen", 32 | "//lib/Dialect/Poly", 33 | "@llvm-project//mlir:ArithDialect", 34 | "@llvm-project//mlir:FuncDialect", 35 | "@llvm-project//mlir:FuncTransforms", 36 | "@llvm-project//mlir:IR", 37 | "@llvm-project//mlir:Pass", 38 | "@llvm-project//mlir:SCFDialect", 39 | "@llvm-project//mlir:TensorDialect", 40 | "@llvm-project//mlir:Transforms", 41 | ], 42 | ) 43 | -------------------------------------------------------------------------------- /tests/ctlz_runner.mlir: -------------------------------------------------------------------------------- 1 | // RUN: mlir-opt %s \ 2 | // RUN: -pass-pipeline="builtin.module( \ 3 | // RUN: convert-math-to-funcs{convert-ctlz}, \ 4 | // RUN: func.func(convert-scf-to-cf,convert-arith-to-llvm), \ 5 | // RUN: convert-func-to-llvm, \ 6 | // RUN: convert-cf-to-llvm, \ 7 | // 
RUN: reconcile-unrealized-casts)" \ 8 | // RUN: | mlir-runner -e test_7i32_to_29 -entry-point-result=i32 > %t 9 | // RUN: FileCheck %s --check-prefix=CHECK_TEST_7i32_TO_29 < %t 10 | 11 | func.func @test_7i32_to_29() -> i32 { 12 | %arg = arith.constant 7 : i32 13 | %0 = math.ctlz %arg : i32 14 | func.return %0 : i32 15 | } 16 | // CHECK_TEST_7i32_TO_29: 29 17 | 18 | 19 | // RUN: mlir-opt %s \ 20 | // RUN: -pass-pipeline="builtin.module( \ 21 | // RUN: convert-math-to-funcs{convert-ctlz}, \ 22 | // RUN: func.func(convert-scf-to-cf,convert-arith-to-llvm), \ 23 | // RUN: convert-func-to-llvm, \ 24 | // RUN: convert-cf-to-llvm, \ 25 | // RUN: reconcile-unrealized-casts)" \ 26 | // RUN: | mlir-runner -e test_7i64_to_61 -entry-point-result=i64 > %t 27 | // RUN: FileCheck %s --check-prefix=CHECK_TEST_7i64_TO_61 < %t 28 | func.func @test_7i64_to_61() -> i64 { 29 | %arg = arith.constant 7 : i64 30 | %0 = math.ctlz %arg : i64 31 | func.return %0 : i64 32 | } 33 | // CHECK_TEST_7i64_TO_61: 61 34 | -------------------------------------------------------------------------------- /tools/BUILD: -------------------------------------------------------------------------------- 1 | # The "tools" directory contains binary targets that expose the public API of 2 | # the passes in the project. 3 | 4 | package( 5 | default_visibility = ["//visibility:public"], 6 | ) 7 | 8 | # We name the tool `tutorial-opt` following the pattern of `mlir-opt`. 
9 | cc_binary( 10 | name = "tutorial-opt", 11 | srcs = ["tutorial-opt.cpp"], 12 | includes = ["include"], 13 | deps = [ 14 | "//lib/Conversion/PolyToStandard", 15 | "//lib/Dialect/Noisy", 16 | "//lib/Dialect/Poly", 17 | "//lib/Transform/Affine:Passes", 18 | "//lib/Transform/Arith:Passes", 19 | "//lib/Transform/Noisy:Passes", 20 | "@llvm-project//mlir:AllPassesAndDialects", 21 | "@llvm-project//mlir:ArithToLLVM", 22 | "@llvm-project//mlir:BufferizationPipelines", 23 | "@llvm-project//mlir:BufferizationTransforms", 24 | "@llvm-project//mlir:ControlFlowToLLVM", 25 | "@llvm-project//mlir:FuncToLLVM", 26 | "@llvm-project//mlir:LinalgTransforms", 27 | "@llvm-project//mlir:MemRefToLLVM", 28 | "@llvm-project//mlir:MemRefTransforms", 29 | "@llvm-project//mlir:MlirOptLib", 30 | "@llvm-project//mlir:Pass", 31 | "@llvm-project//mlir:SCFToControlFlow", 32 | "@llvm-project//mlir:TensorToLinalg", 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /tests/affine_loop_unroll.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s --affine-full-unroll > %t 2 | // RUN: FileCheck %s < %t 3 | 4 | // RUN: tutorial-opt %s --affine-full-unroll-rewrite > %t 5 | // RUN: FileCheck %s < %t 6 | 7 | func.func @test_single_nested_loop(%buffer: memref<4xi32>) -> (i32) { 8 | %sum_0 = arith.constant 0 : i32 9 | // CHECK-LABEL: test_single_nested_loop 10 | // CHECK-NOT: affine.for 11 | %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 { 12 | %t = affine.load %buffer[%i] : memref<4xi32> 13 | %sum_next = arith.addi %sum_iter, %t : i32 14 | affine.yield %sum_next : i32 15 | } 16 | return %sum : i32 17 | } 18 | 19 | func.func @test_doubly_nested_loop(%buffer: memref<4x3xi32>) -> (i32) { 20 | %sum_0 = arith.constant 0 : i32 21 | // CHECK-LABEL: test_doubly_nested_loop 22 | // CHECK-NOT: affine.for 23 | %sum = affine.for %i = 0 to 4 iter_args(%sum_iter = %sum_0) -> i32 { 24 | 
%sum_nested_0 = arith.constant 0 : i32 25 | %sum_nested = affine.for %j = 0 to 3 iter_args(%sum_nested_iter = %sum_nested_0) -> i32 { 26 | %t = affine.load %buffer[%i, %j] : memref<4x3xi32> 27 | %sum_nested_next = arith.addi %sum_nested_iter, %t : i32 28 | affine.yield %sum_nested_next : i32 29 | } 30 | %sum_next = arith.addi %sum_iter, %sum_nested : i32 31 | affine.yield %sum_next : i32 32 | } 33 | return %sum : i32 34 | } 35 | -------------------------------------------------------------------------------- /tests/sccp.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt -pass-pipeline="builtin.module(func.func(sccp))" %s | FileCheck %s 2 | 3 | // Note how sscp creates new constants for the computed values, 4 | // though it does not remove the dead code. 5 | 6 | // CHECK-LABEL: @test_arith_sccp 7 | // CHECK-NEXT: %[[v0:.*]] = arith.constant 63 : i32 8 | // CHECK-NEXT: %[[v1:.*]] = arith.constant 49 : i32 9 | // CHECK-NEXT: %[[v2:.*]] = arith.constant 14 : i32 10 | // CHECK-NEXT: %[[v3:.*]] = arith.constant 8 : i32 11 | // CHECK-NEXT: %[[v4:.*]] = arith.constant 7 : i32 12 | // CHECK-NEXT: return %[[v2]] : i32 13 | func.func @test_arith_sccp() -> i32 { 14 | %0 = arith.constant 7 : i32 15 | %1 = arith.constant 8 : i32 16 | %2 = arith.addi %0, %0 : i32 17 | %3 = arith.muli %0, %0 : i32 18 | %4 = arith.addi %2, %3 : i32 19 | return %2 : i32 20 | } 21 | 22 | // CHECK-LABEL: @test_poly_sccp 23 | func.func @test_poly_sccp() -> !poly.poly<10> { 24 | %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> 25 | %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> 26 | // CHECK: poly.constant dense<[2, 8, 20, 24, 18]> 27 | // CHECK: poly.constant dense<[1, 4, 10, 12, 9]> 28 | // CHECK: poly.constant dense<[1, 2, 3]> 29 | // CHECK-NOT: poly.mul 30 | // CHECK-NOT: poly.add 31 | %2 = poly.mul %p0, %p0 : !poly.poly<10> 32 | %3 = poly.mul %p0, %p0 : !poly.poly<10> 33 | %4 = poly.add %2, %3 : !poly.poly<10> 34 | return 
%2 : !poly.poly<10> 35 | } 36 | -------------------------------------------------------------------------------- /lib/Transform/Noisy/BUILD: -------------------------------------------------------------------------------- 1 | # Passes that work with the Noisy dialect 2 | 3 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") 4 | 5 | package( 6 | default_visibility = ["//visibility:public"], 7 | ) 8 | 9 | gentbl_cc_library( 10 | name = "pass_inc_gen", 11 | tbl_outs = [ 12 | ( 13 | [ 14 | "-gen-pass-decls", 15 | "-name=Noisy", 16 | ], 17 | "Passes.h.inc", 18 | ), 19 | ( 20 | ["-gen-pass-doc"], 21 | "NoisyPasses.md", 22 | ), 23 | ], 24 | tblgen = "@llvm-project//mlir:mlir-tblgen", 25 | td_file = "Passes.td", 26 | deps = [ 27 | "@llvm-project//mlir:OpBaseTdFiles", 28 | "@llvm-project//mlir:PassBaseTdFiles", 29 | ], 30 | ) 31 | 32 | cc_library( 33 | name = "ReduceNoiseOptimizer", 34 | srcs = ["ReduceNoiseOptimizer.cpp"], 35 | hdrs = [ 36 | "Passes.h", 37 | "ReduceNoiseOptimizer.h", 38 | ], 39 | deps = [ 40 | ":pass_inc_gen", 41 | "//lib/Analysis/ReduceNoiseAnalysis", 42 | "//lib/Dialect/Noisy", 43 | "@llvm-project//llvm:Support", 44 | "@llvm-project//mlir:Analysis", 45 | "@llvm-project//mlir:Pass", 46 | "@llvm-project//mlir:Transforms", 47 | ], 48 | ) 49 | 50 | cc_library( 51 | name = "Passes", 52 | hdrs = ["Passes.h"], 53 | deps = [ 54 | ":ReduceNoiseOptimizer", 55 | ":pass_inc_gen", 56 | ], 57 | ) 58 | -------------------------------------------------------------------------------- /tests/lit.cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from lit.formats import ShTest 5 | 6 | # oddly, the `config` variable is defined in the context of the lit runner that 7 | # runs this module. 
8 | 9 | config.name = "mlir_tutorial" 10 | config.test_format = ShTest() 11 | config.suffixes = [".mlir"] 12 | 13 | # lit executes relative to the directory 14 | # 15 | # bazel-bin/tests/.runfiles/_main/ 16 | # 17 | # which contains all the binary targets included in via the `data` attribute in 18 | # the lit.bzl macro, which in turn gets them from the filegroup //tests:test_utilities. 19 | # To manually inspect what is included in the filesystem in situ, add the 20 | # following to this script and run `bazel test //tests:` 21 | # 22 | # import subprocess 23 | # 24 | # print(subprocess.run(["pwd",]).stdout) 25 | # print(subprocess.run(["ls", "-l", os.environ["RUNFILES_DIR"]]).stdout) 26 | # print(subprocess.run([ "env", ]).stdout) 27 | # 28 | # Bazel defines RUNFILES_DIR which includes _main/ and third party 29 | # dependencies as their own directory. Generally, it seems that $PWD == 30 | # $RUNFILES_DIR/_main/ 31 | runfiles_dir = Path(os.environ["RUNFILES_DIR"]) 32 | 33 | # Fix tool paths to use _main instead of mlir_tutorial 34 | tool_relpaths = [ 35 | "+_repo_rules+llvm-project/mlir", 36 | "+_repo_rules+llvm-project/llvm", 37 | "_main/tools", 38 | ] 39 | 40 | config.environment["PATH"] = ( 41 | ":".join(str(runfiles_dir.joinpath(Path(path))) for path in tool_relpaths) 42 | + ":" 43 | + os.environ["PATH"] 44 | ) 45 | 46 | substitutions = { 47 | "%project_source_dir": str(runfiles_dir.joinpath(Path('_main'))), 48 | } 49 | config.substitutions.extend(substitutions.items()) 50 | -------------------------------------------------------------------------------- /extensions.bzl: -------------------------------------------------------------------------------- 1 | """Module extensions for MLIR Tutorial dependencies.""" 2 | 3 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") 4 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 5 | load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") 6 | 7 | def 
_mlir_tutorial_deps_impl(module_ctx): 8 | """Implementation of the mlir_tutorial_deps module extension.""" 9 | 10 | # Download LLVM/MLIR using a git repository 11 | new_git_repository( 12 | name = "llvm-raw", 13 | build_file_content = "# empty", 14 | commit = "d9190f8141661bd6120dea61d28ae8940fd775d0", 15 | init_submodules = False, 16 | remote = "https://github.com/llvm/llvm-project.git", 17 | ) 18 | 19 | # Optional LLVM dependencies for performance 20 | maybe( 21 | http_archive, 22 | name = "llvm_zstd", 23 | build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", 24 | sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", 25 | strip_prefix = "zstd-1.5.2", 26 | urls = [ 27 | "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz", 28 | ], 29 | ) 30 | 31 | maybe( 32 | http_archive, 33 | name = "llvm_zlib", 34 | build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD", 35 | sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731", 36 | strip_prefix = "zlib-ng-2.0.7", 37 | urls = [ 38 | "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip", 39 | ], 40 | ) 41 | 42 | mlir_tutorial_deps = module_extension( 43 | implementation = _mlir_tutorial_deps_impl, 44 | ) 45 | -------------------------------------------------------------------------------- /tests/poly_syntax.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s > %t 2 | // RUN FileCheck %s < %t 3 | 4 | module { 5 | // CHECK-LABEL: test_type_syntax 6 | func.func @test_type_syntax(%arg0: !poly.poly<10>) -> !poly.poly<10> { 7 | // CHECK: poly.poly 8 | return %arg0 : !poly.poly<10> 9 | } 10 | 11 | // CHECK-LABEL: test_op_syntax 12 | func.func @test_op_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> { 13 | // CHECK: poly.add 14 | %0 = poly.add %arg0, %arg1 : !poly.poly<10> 15 | // CHECK: poly.sub 16 | %1 = poly.sub %arg0, %arg1 : 
!poly.poly<10> 17 | // CHECK: poly.mul 18 | %2 = poly.mul %arg0, %arg1 : !poly.poly<10> 19 | 20 | %3 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> 21 | // CHECK: poly.from_tensor 22 | %4 = poly.from_tensor %3 : tensor<3xi32> -> !poly.poly<10> 23 | 24 | %5 = arith.constant 7 : i32 25 | // CHECK: poly.eval 26 | %6 = poly.eval %4, %5 : (!poly.poly<10>, i32) -> i32 27 | 28 | %z = complex.constant [1.0, 2.0] : complex 29 | // CHECK: poly.eval 30 | %complex_eval = poly.eval %4, %z : (!poly.poly<10>, complex) -> complex 31 | 32 | %7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>> 33 | // CHECK: poly.add 34 | %8 = poly.add %7, %7 : tensor<2x!poly.poly<10>> 35 | 36 | // CHECK: poly.constant 37 | %10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> 38 | %11 = poly.constant dense<[2, 3, 4]> : tensor<3xi8> : !poly.poly<10> 39 | %12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10> 40 | %13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10> 41 | 42 | // CHECK: poly.to_tensor 43 | %14 = poly.to_tensor %1 : !poly.poly<10> -> tensor<10xi32> 44 | 45 | return %4 : !poly.poly<10> 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /lib/Transform/Affine/AffineFullUnrollPatternRewrite.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Transform/Affine/AffineFullUnrollPatternRewrite.h" 2 | #include "mlir/Dialect/Affine/IR/AffineOps.h" 3 | #include "mlir/Dialect/Affine/LoopUtils.h" 4 | #include "mlir/IR/PatternMatch.h" 5 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 6 | #include "mlir/include/mlir/Pass/Pass.h" 7 | 8 | namespace mlir { 9 | namespace tutorial { 10 | 11 | #define GEN_PASS_DEF_AFFINEFULLUNROLLPATTERNREWRITE 12 | #include "lib/Transform/Affine/Passes.h.inc" 13 | 14 | using mlir::affine::AffineForOp; 15 | using mlir::affine::loopUnrollFull; 16 | 17 | // A pattern that matches on AffineForOp and unrolls it.
18 | struct AffineFullUnrollPattern : public OpRewritePattern<AffineForOp> { 19 | AffineFullUnrollPattern(mlir::MLIRContext *context) 20 | : OpRewritePattern<AffineForOp>(context, /*benefit=*/1) {} 21 | // Propagates loopUnrollFull's LogicalResult, so the pattern reports failure whenever the unroll does not succeed. 22 | LogicalResult matchAndRewrite(AffineForOp op, 23 | PatternRewriter &rewriter) const override { 24 | // This is technically not allowed, since in a RewritePattern all 25 | // modifications to the IR are supposed to go through the `rewriter` arg, 26 | // but it works for our limited test cases. 27 | return loopUnrollFull(op); 28 | } 29 | }; 30 | 31 | // A pass that invokes the pattern rewrite engine. 32 | struct AffineFullUnrollPatternRewrite 33 | : impl::AffineFullUnrollPatternRewriteBase<AffineFullUnrollPatternRewrite> { 34 | using AffineFullUnrollPatternRewriteBase::AffineFullUnrollPatternRewriteBase; 35 | void runOnOperation() override { 36 | mlir::RewritePatternSet patterns(&getContext()); 37 | patterns.add<AffineFullUnrollPattern>(&getContext()); 38 | // One could use GreedyRewriteConfig here to slightly tweak the behavior of 39 | // the pattern application. 40 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 41 | } 42 | }; 43 | 44 | } // namespace tutorial 45 | } // namespace mlir 46 | -------------------------------------------------------------------------------- /tests/lit.cmake.cfg.py: -------------------------------------------------------------------------------- 1 | # -*- Python -*- 2 | 3 | import os 4 | 5 | import lit.formats 6 | import lit.util 7 | 8 | from lit.llvm import llvm_config 9 | 10 | # Configuration file for the 'lit' test runner. 11 | 12 | # name: The name of this test suite. 13 | config.name = "MLIR_TUTORIAL" 14 | 15 | config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) 16 | 17 | # suffixes: A list of file extensions to treat as test files. 18 | config.suffixes = [".mlir"] 19 | 20 | # test_source_root: The root path where tests are located. 21 | config.test_source_root = os.path.dirname(__file__) 22 | 23 | # test_exec_root: The root path where tests should be run.
24 | config.test_exec_root = os.path.join(config.project_binary_dir, "tests") 25 | 26 | config.substitutions.append(("%PATH%", config.environment["PATH"])) 27 | config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) 28 | config.substitutions.append(("%project_source_dir", config.project_source_dir)) 29 | 30 | llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"]) 31 | 32 | llvm_config.use_default_substitutions() 33 | 34 | # excludes: A list of directories to exclude from the testsuite. The 'Inputs' 35 | # subdirectories contain auxiliary inputs for various tests in their parent 36 | # directories. 37 | config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"] 38 | 39 | # test_exec_root: The root path where tests should be run. 40 | config.test_exec_root = os.path.join(config.project_binary_dir, "test") 41 | config.project_tools_dir = os.path.join(config.project_binary_dir, "tools") 42 | 43 | # Tweak the PATH to include the tools dir. 
44 | llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) 45 | 46 | tool_dirs = [config.project_tools_dir, config.llvm_tools_dir] 47 | tools = [ 48 | "mlir-opt", 49 | "mlir-runner", 50 | "tutorial-opt" 51 | ] 52 | 53 | llvm_config.add_tool_substitutions(tools, tool_dirs) 54 | -------------------------------------------------------------------------------- /lib/Transform/Arith/MulToAddPdll.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Transform/Arith/MulToAddPdll.h" 2 | #include "mlir/Dialect/Arith/IR/Arith.h" 3 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 4 | #include "mlir/IR/PatternMatch.h" 5 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 6 | #include "mlir/include/mlir/Pass/Pass.h" 7 | 8 | namespace mlir { 9 | namespace tutorial { 10 | 11 | #define GEN_PASS_DEF_MULTOADDPDLL 12 | #include "lib/Transform/Arith/Passes.h.inc" 13 | // Native implementations of the value-returning PDLL constraints `Halve` and `MinusOne` (registered in registerNativeConstraints). 14 | LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results, 15 | ArrayRef<PDLValue> args) { 16 | Attribute attr = args[0].cast<Attribute>(); 17 | IntegerAttr cAttr = cast<IntegerAttr>(attr); 18 | int64_t value = cAttr.getValue().getSExtValue(); 19 | results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2)); 20 | return success(); 21 | } 22 | 23 | LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results, 24 | ArrayRef<PDLValue> args) { 25 | Attribute attr = args[0].cast<Attribute>(); 26 | IntegerAttr cAttr = cast<IntegerAttr>(attr); 27 | int64_t value = cAttr.getValue().getSExtValue(); 28 | results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value - 1)); 29 | return success(); 30 | } 31 | 32 | void registerNativeConstraints(RewritePatternSet &patterns) { 33 | patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl); 34 | patterns.getPDLPatterns().registerConstraintFunction("MinusOne", minusOneImpl); 35 | } 36 | 37 | struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> { 38 | using MulToAddPdllBase::MulToAddPdllBase; 39 | 40 | void
runOnOperation() override { 41 | mlir::RewritePatternSet patterns(&getContext()); 42 | populateGeneratedPDLLPatterns(patterns); 43 | registerNativeConstraints(patterns); 44 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 45 | } 46 | }; 47 | 48 | } // namespace tutorial 49 | } // namespace mlir 50 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyOps.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Dialect/Noisy/NoisyOps.h" 2 | 3 | namespace mlir { 4 | namespace tutorial { 5 | namespace noisy { 6 | 7 | ConstantIntRanges initialNoiseRange() { 8 | return ConstantIntRanges::fromUnsigned(APInt(32, 0), 9 | APInt(32, INITIAL_NOISE)); 10 | } 11 | // Result range of a noise-adding binary op: the union of both operand ranges, with the upper bound raised by one. 12 | ConstantIntRanges unionPlusOne(ArrayRef<ConstantIntRanges> inputRanges) { 13 | auto lhsRange = inputRanges[0]; 14 | auto rhsRange = inputRanges[1]; 15 | auto joined = lhsRange.rangeUnion(rhsRange); 16 | return ConstantIntRanges::fromUnsigned(joined.umin(), joined.umax() + 1); 17 | } 18 | 19 | void EncodeOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges, 20 | SetIntRangeFn setResultRange) { 21 | setResultRange(getResult(), initialNoiseRange()); 22 | } 23 | 24 | void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges, 25 | SetIntRangeFn setResultRange) { 26 | setResultRange(getResult(), unionPlusOne(inputRanges)); 27 | } 28 | 29 | void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges, 30 | SetIntRangeFn setResultRange) { 31 | setResultRange(getResult(), unionPlusOne(inputRanges)); 32 | } 33 | // Multiplication sums the operands' lower bounds and their upper bounds. 34 | void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges, 35 | SetIntRangeFn setResultRange) { 36 | auto lhsRange = inputRanges[0]; 37 | auto rhsRange = inputRanges[1]; 38 | setResultRange(getResult(), ConstantIntRanges::fromUnsigned( 39 | lhsRange.umin() + rhsRange.umin(), 40 | lhsRange.umax() + rhsRange.umax())); 41 | } 42 | 43 | void ReduceNoiseOp::inferResultRanges(ArrayRef<ConstantIntRanges> inputRanges, 44 | SetIntRangeFn setResultRange) { 45 |
setResultRange(getResult(), initialNoiseRange()); 46 | } 47 | 48 | } // namespace noisy 49 | } // namespace tutorial 50 | } // namespace mlir 51 | -------------------------------------------------------------------------------- /lib/Transform/Affine/BUILD: -------------------------------------------------------------------------------- 1 | # Passes that work with the Affine dialect 2 | 3 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") 4 | 5 | package( 6 | default_visibility = ["//visibility:public"], 7 | ) 8 | 9 | gentbl_cc_library( 10 | name = "pass_inc_gen", 11 | tbl_outs = [ 12 | ( 13 | [ 14 | "-gen-pass-decls", 15 | "-name=Affine", 16 | ], 17 | "Passes.h.inc", 18 | ), 19 | ( 20 | ["-gen-pass-doc"], 21 | "AffinePasses.md", 22 | ), 23 | ], 24 | tblgen = "@llvm-project//mlir:mlir-tblgen", 25 | td_file = "Passes.td", 26 | deps = [ 27 | "@llvm-project//mlir:OpBaseTdFiles", 28 | "@llvm-project//mlir:PassBaseTdFiles", 29 | ], 30 | ) 31 | 32 | cc_library( 33 | name = "AffineFullUnroll", 34 | srcs = ["AffineFullUnroll.cpp"], 35 | hdrs = [ 36 | "AffineFullUnroll.h", 37 | "Passes.h", 38 | ], 39 | deps = [ 40 | ":pass_inc_gen", 41 | "@llvm-project//mlir:AffineDialect", 42 | "@llvm-project//mlir:AffineUtils", 43 | "@llvm-project//mlir:FuncDialect", 44 | "@llvm-project//mlir:Pass", 45 | "@llvm-project//mlir:Transforms", 46 | ], 47 | ) 48 | 49 | cc_library( 50 | name = "AffineFullUnrollPatternRewrite", 51 | srcs = ["AffineFullUnrollPatternRewrite.cpp"], 52 | hdrs = [ 53 | "AffineFullUnrollPatternRewrite.h", 54 | "Passes.h", 55 | ], 56 | deps = [ 57 | ":pass_inc_gen", 58 | "@llvm-project//mlir:AffineDialect", 59 | "@llvm-project//mlir:AffineUtils", 60 | "@llvm-project//mlir:FuncDialect", 61 | "@llvm-project//mlir:Pass", 62 | "@llvm-project//mlir:Transforms", 63 | ], 64 | ) 65 | 66 | cc_library( 67 | name = "Passes", 68 | hdrs = ["Passes.h"], 69 | deps = [ 70 | ":AffineFullUnroll", 71 | ":AffineFullUnrollPatternRewrite", 72 | ":pass_inc_gen", 73 | ], 74 | ) 75 |
-------------------------------------------------------------------------------- /bazel/lit.bzl: -------------------------------------------------------------------------------- 1 | """Macros for defining lit tests.""" 2 | 3 | load("@bazel_skylib//lib:paths.bzl", "paths") 4 | load("@rules_python//python:py_test.bzl", "py_test") 5 | 6 | _DEFAULT_FILE_EXTS = ["mlir"] 7 | 8 | def lit_test(name = None, src = None, size = "small", tags = None): 9 | """Define a lit test. 10 | 11 | In its simplest form, a manually defined lit test would look like this: 12 | 13 | py_test( 14 | name = "ops.mlir.test", 15 | srcs = ["@llvm-project//llvm:lit"], 16 | args = ["-v", "tests/ops.mlir"], 17 | data = [":test_utilities", ":ops.mlir"], 18 | size = "small", 19 | main = "lit.py", 20 | ) 21 | 22 | Where the `ops.mlir` file contains the test cases in standard RUN + CHECK 23 | format. 24 | 25 | The adjacent :test_utilities target contains all the tools (like mlir-opt) 26 | and files (like lit.cfg.py) that are needed to run a lit test. lit.cfg.py 27 | further specifies the lit configuration, including augmenting $PATH to 28 | include any mlir-opt-like tools. 29 | 30 | This macro simplifies the above definition by filling in the boilerplate. 31 | 32 | Args: 33 | name: the name of the test. 34 | src: the source file for the test. 35 | size: the size of the test. 36 | tags: tags to pass to the target.
37 | """ 38 | if not src: 39 | fail("src must be specified") 40 | name = name or src + ".test" 41 | 42 | filegroup_name = name + ".filegroup" 43 | native.filegroup( 44 | name = filegroup_name, 45 | srcs = [src], 46 | ) 47 | 48 | py_test( 49 | name = name, 50 | size = size, 51 | # -v ensures lit outputs useful info during test failures 52 | args = ["-v", paths.join(native.package_name(), src)], 53 | data = ["@mlir_tutorial//tests:test_utilities", filegroup_name], 54 | deps = ["@mlir_tutorial_pip_deps//lit"], 55 | srcs = ["@llvm-project//llvm:lit"], 56 | main = "lit.py", 57 | python_version = "PY3", 58 | tags = tags, 59 | ) 60 | 61 | def glob_lit_tests(): 62 | """Searches the caller's directory for files (with _DEFAULT_FILE_EXTS extensions) to run as lit tests.""" 63 | tests = native.glob(["*." + ext for ext in _DEFAULT_FILE_EXTS]) 64 | for curr_test in tests: 65 | lit_test(src = curr_test, size = "small") 66 | -------------------------------------------------------------------------------- /lib/Transform/Arith/BUILD: -------------------------------------------------------------------------------- 1 | # Passes that work with the Arith dialect 2 | 3 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") 4 | 5 | package( 6 | default_visibility = ["//visibility:public"], 7 | ) 8 | 9 | gentbl_cc_library( 10 | name = "pass_inc_gen", 11 | tbl_outs = [ 12 | ( 13 | [ 14 | "-gen-pass-decls", 15 | "-name=Arith", 16 | ], 17 | "Passes.h.inc", 18 | ), 19 | ( 20 | ["-gen-pass-doc"], 21 | "ArithPasses.md", 22 | ), 23 | ], 24 | tblgen = "@llvm-project//mlir:mlir-tblgen", 25 | td_file = "Passes.td", 26 | deps = [ 27 | "@llvm-project//mlir:OpBaseTdFiles", 28 | "@llvm-project//mlir:PDLDialectTdFiles", 29 | "@llvm-project//mlir:PDLInterpOpsTdFiles", 30 | "@llvm-project//mlir:PassBaseTdFiles", 31 | ], 32 | ) 33 | 34 | cc_library( 35 | name = "MulToAdd", 36 | srcs = ["MulToAdd.cpp"], 37 | hdrs = ["MulToAdd.h"], 38 | deps = [ 39 | ":pass_inc_gen", 40 | "@llvm-project//mlir:ArithDialect", 41 | "@llvm-project//mlir:FuncDialect", 42 |
"@llvm-project//mlir:Pass", 43 | "@llvm-project//mlir:Transforms", 44 | ], 45 | ) 46 | 47 | cc_library( 48 | name = "Passes", 49 | hdrs = ["Passes.h"], 50 | deps = [ 51 | ":MulToAdd", 52 | ":MulToAddPdll", 53 | ":pass_inc_gen", 54 | ], 55 | ) 56 | 57 | gentbl_cc_library( 58 | name = "MulToAddPdllIncGen", 59 | tbl_outs = [ 60 | ( 61 | ["-x=cpp"], 62 | "MulToAddPdll.h.inc", 63 | ), 64 | ], 65 | tblgen = "@llvm-project//mlir:mlir-pdll", 66 | td_file = "MulToAdd.pdll", 67 | deps = [ 68 | "@llvm-project//mlir:ArithDialect", 69 | "@llvm-project//mlir:FuncDialect", 70 | "@llvm-project//mlir:ArithOpsTdFiles", 71 | ], 72 | ) 73 | 74 | cc_library( 75 | name = "MulToAddPdll", 76 | srcs = ["MulToAddPdll.cpp"], 77 | hdrs = ["MulToAddPdll.h"], 78 | deps = [ 79 | ":pass_inc_gen", 80 | ":MulToAddPdllIncGen", 81 | "@llvm-project//mlir:ArithDialect", 82 | "@llvm-project//mlir:FuncDialect", 83 | "@llvm-project//mlir:Pass", 84 | "@llvm-project//mlir:Transforms", 85 | ], 86 | ) 87 | 88 | -------------------------------------------------------------------------------- /tests/ctlz.mlir: -------------------------------------------------------------------------------- 1 | // RUN: mlir-opt %s --convert-math-to-funcs=convert-ctlz | FileCheck %s 2 | 3 | func.func @main(%arg0: i32) { 4 | %0 = math.ctlz %arg0 : i32 5 | func.return 6 | } 7 | // CHECK-LABEL: func.func @main( 8 | // CHECK-SAME: %[[VAL_0:.*]]: i32 9 | // CHECK-SAME: ) { 10 | // CHECK: %[[VAL_1:.*]] = call @__mlir_math_ctlz_i32(%[[VAL_0]]) : (i32) -> i32 11 | // CHECK: return 12 | // CHECK: } 13 | 14 | // CHECK-LABEL: func.func private @__mlir_math_ctlz_i32( 15 | // CHECK-SAME: %[[ARG:.*]]: i32 16 | // CHECK-SAME: ) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} { 17 | // CHECK: %[[C_32:.*]] = arith.constant 32 : i32 18 | // CHECK: %[[C_0:.*]] = arith.constant 0 : i32 19 | // CHECK: %[[ARGCMP:.*]] = arith.cmpi eq, %[[ARG]], %[[C_0]] : i32 20 | // CHECK: %[[OUT:.*]] = scf.if %[[ARGCMP]] -> (i32) { 21 | // CHECK: scf.yield %[[C_32]]
: i32 22 | // CHECK: } else { 23 | // CHECK: %[[C_1INDEX:.*]] = arith.constant 1 : index 24 | // CHECK: %[[C_1I32:.*]] = arith.constant 1 : i32 25 | // CHECK: %[[C_32INDEX:.*]] = arith.constant 32 : index 26 | // CHECK: %[[N:.*]] = arith.constant 0 : i32 27 | // CHECK: %[[FOR_RET:.*]]:2 = scf.for %[[I:.*]] = %[[C_1INDEX]] to %[[C_32INDEX]] step %[[C_1INDEX]] 28 | // CHECK: iter_args(%[[ARG_ITER:.*]] = %[[ARG]], %[[N_ITER:.*]] = %[[N]]) -> (i32, i32) { 29 | // CHECK: %[[COND:.*]] = arith.cmpi slt, %[[ARG_ITER]], %[[C_0]] : i32 30 | // CHECK: %[[IF_RET:.*]]:2 = scf.if %[[COND]] -> (i32, i32) { 31 | // CHECK: scf.yield %[[ARG_ITER]], %[[N_ITER]] : i32, i32 32 | // CHECK: } else { 33 | // CHECK: %[[N_NEXT:.*]] = arith.addi %[[N_ITER]], %[[C_1I32]] : i32 34 | // CHECK: %[[ARG_NEXT:.*]] = arith.shli %[[ARG_ITER]], %[[C_1I32]] : i32 35 | // CHECK: scf.yield %[[ARG_NEXT]], %[[N_NEXT]] : i32, i32 36 | // CHECK: } 37 | // CHECK: scf.yield %[[IF_RET]]#0, %[[IF_RET]]#1 : i32, i32 38 | // CHECK: } 39 | // CHECK: scf.yield %[[FOR_RET]]#1 : i32 40 | // CHECK: } 41 | // CHECK: return %[[OUT]] : i32 42 | // CHECK: } 43 | // NOCVT-NOT: __mlir_math_ctlz_i32 44 | -------------------------------------------------------------------------------- /lib/Transform/Arith/MulToAdd.pdll: -------------------------------------------------------------------------------- 1 | #include "mlir/Dialect/Arith/IR/ArithOps.td" 2 | // Matches positive powers of two; the `value > 0` guard excludes 0 and negatives and keeps `value - 1` from overflowing. 3 | Constraint IsPowerOfTwo(attr: Attr) [{ 4 | int64_t value = cast<::mlir::IntegerAttr>(attr).getValue().getSExtValue(); 5 | return success(value > 0 && (value & (value - 1)) == 0); 6 | }]; 7 | 8 | // Currently, constraints that return values must be defined in C++ 9 | Constraint Halve(attr: Attr) -> Attr; 10 | Constraint MinusOne(attr: Attr) -> Attr; 11 | 12 | // Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2, otherwise do 13 | // nothing.
14 | Pattern PowerOfTwoExpandRhs with benefit(2) { 15 | let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value); 16 | IsPowerOfTwo(const); 17 | let halved: Attr = Halve(const); 18 | 19 | rewrite root with { 20 | let newConst = op<arith.constant> {value = halved}; 21 | let newMul = op<arith.muli>(newConst, rhs); 22 | let newAdd = op<arith.addi>(newMul, newMul); 23 | replace root with newAdd; 24 | }; 25 | } 26 | // Mirror of PowerOfTwoExpandRhs with the constant as the right-hand operand. 27 | Pattern PowerOfTwoExpandLhs with benefit(2) { 28 | let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr}); 29 | IsPowerOfTwo(const); 30 | let halved: Attr = Halve(const); 31 | 32 | rewrite root with { 33 | let newConst = op<arith.constant> {value = halved}; 34 | let newMul = op<arith.muli>(lhs, newConst); 35 | let newAdd = op<arith.addi>(newMul, newMul); 36 | replace root with newAdd; 37 | }; 38 | } 39 | 40 | // Replace y = 9*x with y = 8*x + x 41 | Pattern PeelFromMulRhs with benefit(1) { 42 | let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr}); 43 | 44 | // We are guaranteed `const` is not a power of two, because the greedy 45 | // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since 46 | // it has higher benefit.
47 | let minusOne: Attr = MinusOne(const); 48 | 49 | rewrite root with { 50 | let newConst = op<arith.constant> {value = minusOne}; 51 | let newMul = op<arith.muli>(lhs, newConst); 52 | let newAdd = op<arith.addi>(newMul, lhs); 53 | replace root with newAdd; 54 | }; 55 | } 56 | // Mirror of PeelFromMulRhs with the constant as the left-hand operand. 57 | Pattern PeelFromMulLhs with benefit(1) { 58 | let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value); 59 | let minusOne: Attr = MinusOne(const); 60 | 61 | rewrite root with { 62 | let newConst = op<arith.constant> {value = minusOne}; 63 | let newMul = op<arith.muli>(newConst, rhs); 64 | let newAdd = op<arith.addi>(newMul, rhs); 65 | replace root with newAdd; 66 | }; 67 | } 68 | -------------------------------------------------------------------------------- /tests/poly_canonicalize.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt --canonicalize %s | FileCheck %s 2 | 3 | // CHECK-LABEL: @test_simple 4 | func.func @test_simple() -> !poly.poly<10> { 5 | // CHECK: poly.constant dense<[2, 4, 6]> 6 | // CHECK-NEXT: return 7 | %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> 8 | %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> 9 | %2 = poly.add %p0, %p0 : !poly.poly<10> 10 | %3 = poly.mul %p0, %p0 : !poly.poly<10> 11 | %4 = poly.add %2, %3 : !poly.poly<10> 12 | return %2 : !poly.poly<10> 13 | } 14 | 15 | // CHECK-LABEL: func.func @test_difference_of_squares 16 | // CHECK-SAME: %[[x:.+]]: !poly.poly<3>, 17 | // CHECK-SAME: %[[y:.+]]: !poly.poly<3> 18 | func.func @test_difference_of_squares( 19 | %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> { 20 | // CHECK: %[[sum:.+]] = poly.add %[[x]], %[[y]] 21 | // CHECK: %[[diff:.+]] = poly.sub %[[x]], %[[y]] 22 | // CHECK: %[[mul:.+]] = poly.mul %[[sum]], %[[diff]] 23 | %2 = poly.mul %0, %0 : !poly.poly<3> 24 | %3 = poly.mul %1, %1 : !poly.poly<3> 25 | %4 = poly.sub %2, %3 : !poly.poly<3> 26 | %5 = poly.add %4, %4 : !poly.poly<3> 27 | return %5 : !poly.poly<3> 28 | } 29 | 30 | // CHECK-LABEL: func.func @test_difference_of_squares_other_uses 31 | //
CHECK-SAME: %[[x:.+]]: !poly.poly<3>, 32 | // CHECK-SAME: %[[y:.+]]: !poly.poly<3> 33 | func.func @test_difference_of_squares_other_uses( 34 | %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> { 35 | // The canonicalization does not occur because x_squared has a second use. 36 | // CHECK: %[[x_squared:.+]] = poly.mul %[[x]], %[[x]] 37 | // CHECK: %[[y_squared:.+]] = poly.mul %[[y]], %[[y]] 38 | // CHECK: %[[diff:.+]] = poly.sub %[[x_squared]], %[[y_squared]] 39 | // CHECK: %[[sum:.+]] = poly.add %[[diff]], %[[x_squared]] 40 | %2 = poly.mul %0, %0 : !poly.poly<3> 41 | %3 = poly.mul %1, %1 : !poly.poly<3> 42 | %4 = poly.sub %2, %3 : !poly.poly<3> 43 | %5 = poly.add %4, %2 : !poly.poly<3> 44 | return %5 : !poly.poly<3> 45 | } 46 | 47 | // CHECK-LABEL: func.func @test_normalize_conj_through_eval 48 | // CHECK-SAME: %[[f:.+]]: !poly.poly<3>, 49 | // CHECK-SAME: %[[z:.+]]: complex 50 | func.func @test_normalize_conj_through_eval( 51 | %f: !poly.poly<3>, %z: complex) -> complex { 52 | // CHECK: %[[evaled:.+]] = poly.eval %[[f]], %[[z]] 53 | // CHECK-NEXT: %[[eval_bar:.+]] = complex.conj %[[evaled]] 54 | // CHECK-NEXT: return %[[eval_bar]] 55 | %z_bar = complex.conj %z : complex 56 | %evaled = poly.eval %f, %z_bar : (!poly.poly<3>, complex) -> complex 57 | return %evaled : complex 58 | } 59 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/NoisyOps.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_NOISY_NOISYOPS_TD_ 2 | #define LIB_DIALECT_NOISY_NOISYOPS_TD_ 3 | 4 | include "NoisyDialect.td" 5 | include "NoisyTypes.td" 6 | include "mlir/IR/BuiltinAttributes.td" 7 | include "mlir/IR/CommonTypeConstraints.td" 8 | include "mlir/IR/OpBase.td" 9 | include "mlir/Interfaces/InferIntRangeInterface.td" 10 | include "mlir/Interfaces/InferTypeOpInterface.td" 11 | include "mlir/Interfaces/SideEffectInterfaces.td" 12 | 13 | class Noisy_BinOp : Op 17 | ]> {
18 | let arguments = (ins Noisy_I32:$lhs, Noisy_I32:$rhs); 19 | let results = (outs Noisy_I32:$output); 20 | let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))"; 21 | } 22 | 23 | def Noisy_AddOp : Noisy_BinOp<"add"> { 24 | let summary = "Addition operation between noisy ints. Adds noise."; 25 | } 26 | 27 | def Noisy_SubOp : Noisy_BinOp<"sub"> { 28 | let summary = "Subtraction operation between noisy ints. Adds noise."; 29 | } 30 | 31 | def Noisy_MulOp : Noisy_BinOp<"mul"> { 32 | let summary = "Multiplication operation between noisy ints. Multiplies noise."; 33 | } 34 | 35 | def Noisy_EncodeOp : Op]> { 37 | let summary = "Encodes a noisy i32 from a small-width integer, injecting 12 bits of noise."; 38 | let arguments = (ins AnyIntOfWidths<[1, 2, 3, 4, 5]>:$input); 39 | let results = (outs Noisy_I32:$output); 40 | let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))"; 41 | } 42 | 43 | def Noisy_DecodeOp : Op { 44 | let summary = "Decodes a noisy integer to a regular integer, failing if the noise is too high."; 45 | let arguments = (ins Noisy_I32:$input); 46 | let results = (outs AnyIntOfWidths<[1, 2, 3, 4, 5]>:$output); 47 | let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; 48 | } 49 | 50 | def Noisy_ReduceNoiseOp : Op]> { 52 | let summary = "Reduces the noise in a noisy integer to a fixed noise level.
Expensive!"; 53 | let arguments = (ins Noisy_I32:$input); 54 | let results = (outs Noisy_I32:$output); 55 | let assemblyFormat = "$input attr-dict `:` qualified(type($output))"; 56 | } 57 | 58 | #endif // LIB_DIALECT_NOISY_NOISYOPS_TD_ 59 | -------------------------------------------------------------------------------- /.github/workflows/build_and_test_cmake.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test w/CMake 2 | permissions: read-all 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # pin@v4.2.2 15 | with: 16 | submodules: recursive 17 | - uses: seanmiddleditch/gha-setup-ninja@master 18 | 19 | - name: Install prerequisites 20 | run: | 21 | sudo apt update 22 | sudo apt install -y uuid-dev 23 | 24 | - name: Cache LLVM artifact 25 | id: cache-llvm 26 | uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # pin@v4.2.3 27 | with: 28 | path: | 29 | ./externals/llvm-project 30 | key: ${{ runner.os }}-cmake-${{ hashFiles('extensions.bzl') }}-${{ hashFiles('**/CMakeLists.txt') }} 31 | 32 | - name: Cache mlir-tutorial build 33 | id: cache-mlir-tutorial 34 | uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # pin@v4.2.3 35 | with: 36 | path: | 37 | ./build 38 | key: ${{ runner.os }}-cmake-${{ hashFiles('extensions.bzl') }}-${{ hashFiles('**/CMakeLists.txt') }} 39 | 40 | - name: Git config 41 | run: | 42 | git config --global --add safe.directory ${GITHUB_WORKSPACE} 43 | 44 | - name: Build LLVM 45 | if: steps.cache-llvm.outputs.cache-hit != 'true' 46 | run: | 47 | LLVM_COMMIT=$(grep 'commit = ' ${GITHUB_WORKSPACE}/extensions.bzl | head -n 1 | cut -d'"' -f 2) 48 | git submodule update --init --recursive 49 | cd externals/llvm-project 50 | git checkout ${LLVM_COMMIT} 51 | mkdir build && cd build 52 | cmake -G Ninja
../llvm -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_RTTI=ON -DLLVM_TARGETS_TO_BUILD="host" 53 | cmake --build . --target check-mlir 54 | # `mkdir -p`: ./build may already exist when restored by the "Cache mlir-tutorial build" step. 55 | - name: Build and test mlir-tutorial 56 | run: | 57 | mkdir -p build && cd build 58 | cmake -DLLVM_DIR=${GITHUB_WORKSPACE}/externals/llvm-project/build/lib/cmake/llvm -DMLIR_DIR=${GITHUB_WORKSPACE}/externals/llvm-project/build/lib/cmake/mlir -DBUILD_DEPS="ON" -DBUILD_SHARED_LIBS="OFF" .. 59 | cmake --build . --target MLIRAffineFullUnrollPasses 60 | cmake --build . --target MLIRMulToAddPasses 61 | cmake --build . --target MLIRNoisyPasses 62 | cmake --build . --target mlir-headers 63 | cmake --build . --target tutorial-opt 64 | cmake --build . --target check-mlir-tutorial 65 | -------------------------------------------------------------------------------- /lib/Dialect/Noisy/BUILD: -------------------------------------------------------------------------------- 1 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | td_library( 8 | name = "td_files", 9 | srcs = [ 10 | "NoisyDialect.td", 11 | "NoisyOps.td", 12 | "NoisyTypes.td", 13 | ], 14 | deps = [ 15 | "@llvm-project//mlir:BuiltinDialectTdFiles", 16 | "@llvm-project//mlir:InferIntRangeInterfaceTdFiles", 17 | "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", 18 | "@llvm-project//mlir:OpBaseTdFiles", 19 | "@llvm-project//mlir:SideEffectInterfacesTdFiles", 20 | ], 21 | ) 22 | 23 | gentbl_cc_library( 24 | name = "dialect_inc_gen", 25 | tbl_outs = [ 26 | ( 27 | ["-gen-dialect-decls"], 28 | "NoisyDialect.h.inc", 29 | ), 30 | ( 31 | ["-gen-dialect-defs"], 32 | "NoisyDialect.cpp.inc", 33 | ), 34 | ], 35 | tblgen = "@llvm-project//mlir:mlir-tblgen", 36 | td_file = "NoisyDialect.td", 37 | deps = [ 38 | ":td_files", 39 | ], 40 | ) 41 | 42 | gentbl_cc_library( 43 | name = "types_inc_gen", 44 | tbl_outs = [ 45 | (
46 | ["-gen-typedef-decls"], 47 | "NoisyTypes.h.inc", 48 | ), 49 | ( 50 | ["-gen-typedef-defs"], 51 | "NoisyTypes.cpp.inc", 52 | ), 53 | ], 54 | tblgen = "@llvm-project//mlir:mlir-tblgen", 55 | td_file = "NoisyTypes.td", 56 | deps = [ 57 | ":dialect_inc_gen", 58 | ":td_files", 59 | ], 60 | ) 61 | 62 | gentbl_cc_library( 63 | name = "ops_inc_gen", 64 | tbl_outs = [ 65 | ( 66 | ["-gen-op-decls"], 67 | "NoisyOps.h.inc", 68 | ), 69 | ( 70 | ["-gen-op-defs"], 71 | "NoisyOps.cpp.inc", 72 | ), 73 | ], 74 | tblgen = "@llvm-project//mlir:mlir-tblgen", 75 | td_file = "NoisyOps.td", 76 | deps = [ 77 | ":dialect_inc_gen", 78 | ":td_files", 79 | ":types_inc_gen", 80 | ], 81 | ) 82 | 83 | cc_library( 84 | name = "Noisy", 85 | srcs = [ 86 | "NoisyDialect.cpp", 87 | "NoisyOps.cpp", 88 | ], 89 | hdrs = [ 90 | "NoisyDialect.h", 91 | "NoisyOps.h", 92 | "NoisyTypes.h", 93 | ], 94 | deps = [ 95 | ":dialect_inc_gen", 96 | ":ops_inc_gen", 97 | ":types_inc_gen", 98 | "@llvm-project//mlir:ComplexDialect", 99 | "@llvm-project//mlir:Dialect", 100 | "@llvm-project//mlir:IR", 101 | "@llvm-project//mlir:InferIntRangeInterface", 102 | "@llvm-project//mlir:InferTypeOpInterface", 103 | "@llvm-project//mlir:Support", 104 | ], 105 | ) 106 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyOps.td: -------------------------------------------------------------------------------- 1 | #ifndef LIB_DIALECT_POLY_POLYOPS_TD_ 2 | #define LIB_DIALECT_POLY_POLYOPS_TD_ 3 | 4 | include "PolyDialect.td" 5 | include "PolyTypes.td" 6 | include "mlir/IR/BuiltinAttributes.td" 7 | include "mlir/IR/OpBase.td" 8 | include "mlir/Interfaces/InferTypeOpInterface.td" 9 | include "mlir/Interfaces/SideEffectInterfaces.td" 10 | 11 | // Type constraint for poly binop arguments: polys, vectors of polys, or 12 | // tensors of polys.
13 | def PolyOrContainer : TypeOrValueSemanticsContainer; 14 | 15 | // Inject verification that all integer-like arguments are 32-bits 16 | def Has32BitArguments : NativeOpTrait<"Has32BitArguments"> { 17 | let cppNamespace = "::mlir::tutorial::poly"; 18 | } 19 | 20 | class Poly_BinOp : Op { 21 | let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs); 22 | let results = (outs PolyOrContainer:$output); 23 | let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))"; 24 | let hasFolder = 1; 25 | let hasCanonicalizer = 1; 26 | } 27 | 28 | def Poly_AddOp : Poly_BinOp<"add"> { 29 | let summary = "Addition operation between polynomials."; 30 | } 31 | 32 | def Poly_SubOp : Poly_BinOp<"sub"> { 33 | let summary = "Subtraction operation between polynomials."; 34 | } 35 | 36 | def Poly_MulOp : Poly_BinOp<"mul"> { 37 | let summary = "Multiplication operation between polynomials."; 38 | } 39 | 40 | def Poly_FromTensorOp : Op { 41 | let summary = "Creates a Polynomial from integer coefficients stored in a tensor."; 42 | let arguments = (ins TensorOf<[AnyInteger]>:$input); 43 | let results = (outs Polynomial:$output); 44 | let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))"; 45 | let hasFolder = 1; 46 | } 47 | 48 | def Poly_ToTensorOp : Op { 49 | let summary = "Converts a polynomial to a tensor of its integer coefficients."; 50 | let arguments = (ins Polynomial:$input); 51 | let results = (outs TensorOf<[AnyInteger]>:$output); 52 | let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; 53 | } 54 | // Evaluation points (and evaluation results) may be integers or complex numbers. 55 | def IntOrComplex : AnyTypeOf<[AnyInteger, AnyComplex]>; 56 | 57 | def Poly_EvalOp : Op, Has32BitArguments]> { 58 | let summary = "Evaluates a Polynomial at a given input value."; 59 | let arguments = (ins Polynomial:$polynomial, IntOrComplex:$point); 60 | let results = (outs IntOrComplex:$output); 61 | let assemblyFormat = "$polynomial `,` $point attr-dict `:` `(`
qualified(type($polynomial)) `,` type($point) `)` `->` type($output)"; 62 | let hasVerifier = 1; 63 | let hasCanonicalizer = 1; 64 | } 65 | // A polynomial built from an integer elements attribute holding its coefficients. 66 | def Poly_ConstantOp : Op { 67 | let summary = "Define a constant polynomial via an attribute."; 68 | let arguments = (ins AnyIntElementsAttr:$coefficients); 69 | let results = (outs Polynomial:$output); 70 | let assemblyFormat = "$coefficients attr-dict `:` qualified(type($output))"; 71 | let hasFolder = 1; 72 | } 73 | 74 | 75 | #endif // LIB_DIALECT_POLY_POLYOPS_TD_ 76 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/BUILD: -------------------------------------------------------------------------------- 1 | load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | td_library( 8 | name = "td_files", 9 | srcs = [ 10 | "PolyDialect.td", 11 | "PolyOps.td", 12 | "PolyPatterns.td", 13 | "PolyTypes.td", 14 | ], 15 | deps = [ 16 | "@llvm-project//mlir:BuiltinDialectTdFiles", 17 | "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", 18 | "@llvm-project//mlir:OpBaseTdFiles", 19 | "@llvm-project//mlir:SideEffectInterfacesTdFiles", 20 | ], 21 | ) 22 | 23 | gentbl_cc_library( 24 | name = "dialect_inc_gen", 25 | tbl_outs = [ 26 | ( 27 | ["-gen-dialect-decls"], 28 | "PolyDialect.h.inc", 29 | ), 30 | ( 31 | ["-gen-dialect-defs"], 32 | "PolyDialect.cpp.inc", 33 | ), 34 | ], 35 | tblgen = "@llvm-project//mlir:mlir-tblgen", 36 | td_file = "PolyDialect.td", 37 | deps = [ 38 | ":td_files", 39 | ], 40 | ) 41 | 42 | gentbl_cc_library( 43 | name = "types_inc_gen", 44 | tbl_outs = [ 45 | ( 46 | ["-gen-typedef-decls"], 47 | "PolyTypes.h.inc", 48 | ), 49 | ( 50 | ["-gen-typedef-defs"], 51 | "PolyTypes.cpp.inc", 52 | ), 53 | ], 54 | tblgen = "@llvm-project//mlir:mlir-tblgen", 55 | td_file = "PolyTypes.td", 56 | deps = [ 57 | ":dialect_inc_gen", 58 | ":td_files", 59 | ], 60 | ) 61 | # Op declarations/definitions (PolyOps.h.inc, PolyOps.cpp.inc) generated from PolyOps.td. 62 | gentbl_cc_library(
63 | name = "ops_inc_gen", 64 | tbl_outs = [ 65 | ( 66 | ["-gen-op-decls"], 67 | "PolyOps.h.inc", 68 | ), 69 | ( 70 | ["-gen-op-defs"], 71 | "PolyOps.cpp.inc", 72 | ), 73 | ], 74 | tblgen = "@llvm-project//mlir:mlir-tblgen", 75 | td_file = "PolyOps.td", 76 | deps = [ 77 | ":dialect_inc_gen", 78 | ":td_files", 79 | ":types_inc_gen", 80 | ], 81 | ) 82 | # Rewrite patterns (PolyCanonicalize.cpp.inc) generated from PolyPatterns.td via -gen-rewriters. 83 | gentbl_cc_library( 84 | name = "canonicalize_inc_gen", 85 | tbl_outs = [ 86 | ( 87 | ["-gen-rewriters"], 88 | "PolyCanonicalize.cpp.inc", 89 | ), 90 | ], 91 | tblgen = "@llvm-project//mlir:mlir-tblgen", 92 | td_file = "PolyPatterns.td", 93 | deps = [ 94 | ":td_files", 95 | ":types_inc_gen", 96 | "@llvm-project//mlir:ComplexOpsTdFiles", 97 | ], 98 | ) 99 | # The Poly dialect C++ library: dialect, ops, types, and traits. 100 | cc_library( 101 | name = "Poly", 102 | srcs = [ 103 | "PolyDialect.cpp", 104 | "PolyOps.cpp", 105 | ], 106 | hdrs = [ 107 | "PolyDialect.h", 108 | "PolyOps.h", 109 | "PolyTraits.h", 110 | "PolyTypes.h", 111 | ], 112 | deps = [ 113 | ":canonicalize_inc_gen", 114 | ":dialect_inc_gen", 115 | ":ops_inc_gen", 116 | ":types_inc_gen", 117 | "@llvm-project//mlir:ComplexDialect", 118 | "@llvm-project//mlir:Dialect", 119 | "@llvm-project//mlir:IR", 120 | "@llvm-project//mlir:InferTypeOpInterface", 121 | "@llvm-project//mlir:Support", 122 | ], 123 | ) 124 | -------------------------------------------------------------------------------- /lib/Transform/Noisy/ReduceNoiseOptimizer.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Transform/Noisy/ReduceNoiseOptimizer.h" 2 | 3 | #include "lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.h" 4 | #include "lib/Dialect/Noisy/NoisyOps.h" 5 | #include "lib/Dialect/Noisy/NoisyTypes.h" 6 | #include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 7 | #include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 8 | #include "mlir/include/mlir/Analysis/DataFlowFramework.h" 9 | #include "mlir/include/mlir/IR/Visitors.h" 10 | #include
"mlir/include/mlir/Pass/Pass.h" 11 | 12 | namespace mlir { 13 | namespace tutorial { 14 | namespace noisy { 15 | 16 | #define GEN_PASS_DEF_REDUCENOISEOPTIMIZER 17 | #include "lib/Transform/Noisy/Passes.h.inc" 18 | 19 | struct ReduceNoiseOptimizer 20 | : impl::ReduceNoiseOptimizerBase { 21 | using ReduceNoiseOptimizerBase::ReduceNoiseOptimizerBase; 22 | 23 | void runOnOperation() { 24 | Operation *module = getOperation(); 25 | 26 | // FIXME: Should have some way to mark failure when solver is infeasible 27 | ReduceNoiseAnalysis analysis(module); 28 | OpBuilder b(&getContext()); 29 | 30 | module->walk([&](Operation *op) { 31 | if (!analysis.shouldInsertReduceNoise(op)) 32 | return; 33 | 34 | b.setInsertionPointAfter(op); 35 | auto reduceOp = b.create(op->getLoc(), op->getResult(0)); 36 | op->getResult(0).replaceAllUsesExcept(reduceOp.getResult(), {reduceOp}); 37 | }); 38 | 39 | // Use the int range analysis to confirm the noise is always below the 40 | // maximum. 41 | DataFlowSolver solver; 42 | // The IntegerRangeAnalysis depends on DeadCodeAnalysis, but this 43 | // dependence is not automatic and fails silently. 
44 | solver.load(); 45 | solver.load(); 46 | if (failed(solver.initializeAndRun(module))) { 47 | getOperation()->emitOpError() << "Failed to run the analysis.\n"; 48 | signalPassFailure(); 49 | return; 50 | } 51 | 52 | auto result = module->walk([&](Operation *op) { 53 | if (!llvm::isa(*op)) { 55 | return WalkResult::advance(); 56 | } 57 | const dataflow::IntegerValueRangeLattice *opRange = 58 | solver.lookupState( 59 | op->getResult(0)); 60 | if (!opRange || opRange->getValue().isUninitialized()) { 61 | op->emitOpError() 62 | << "Found op without a set integer range; did the analysis fail?"; 63 | return WalkResult::interrupt(); 64 | } 65 | 66 | ConstantIntRanges range = opRange->getValue().getValue(); 67 | if (range.umax().getZExtValue() > MAX_NOISE) { 68 | op->emitOpError() << "Found op after which the noise exceeds the " 69 | "allowable maximum of " 70 | << MAX_NOISE 71 | << "; it was: " << range.umax().getZExtValue() 72 | << "\n"; 73 | return WalkResult::interrupt(); 74 | } 75 | 76 | return WalkResult::advance(); 77 | }); 78 | 79 | if (result.wasInterrupted()) { 80 | getOperation()->emitOpError() 81 | << "Detected error in the noise analysis.\n"; 82 | signalPassFailure(); 83 | } 84 | } 85 | }; 86 | 87 | } // namespace noisy 88 | } // namespace tutorial 89 | } // namespace mlir 90 | -------------------------------------------------------------------------------- /lib/Dialect/Poly/PolyOps.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Dialect/Poly/PolyOps.h" 2 | 3 | #include "mlir/Dialect/CommonFolders.h" 4 | #include "mlir/Dialect/Complex/IR/Complex.h" 5 | #include "mlir/IR/PatternMatch.h" 6 | 7 | // Required after PatternMatch.h 8 | #include "lib/Dialect/Poly/PolyCanonicalize.cpp.inc" 9 | 10 | namespace mlir { 11 | namespace tutorial { 12 | namespace poly { 13 | 14 | OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) { 15 | return adaptor.getCoefficients(); 16 | } 17 | 18 | OpFoldResult 
AddOp::fold(AddOp::FoldAdaptor adaptor) { 19 | return constFoldBinaryOp( 20 | adaptor.getOperands(), [&](APInt a, APInt b) { return a + b; }); 21 | } 22 | 23 | OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) { 24 | return constFoldBinaryOp( 25 | adaptor.getOperands(), [&](APInt a, APInt b) { return a - b; }); 26 | } 27 | 28 | OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) { 29 | auto lhs = dyn_cast_or_null(adaptor.getOperands()[0]); 30 | auto rhs = dyn_cast_or_null(adaptor.getOperands()[1]); 31 | 32 | if (!lhs || !rhs) return nullptr; 33 | 34 | auto degree = llvm::cast(getResult().getType()).getDegreeBound(); 35 | auto maxIndex = lhs.size() + rhs.size() - 1; 36 | 37 | SmallVector result; 38 | result.reserve(maxIndex); 39 | for (int i = 0; i < maxIndex; ++i) { 40 | result.push_back(APInt((*lhs.begin()).getBitWidth(), 0)); 41 | } 42 | 43 | int i = 0; 44 | for (auto lhsIt = lhs.value_begin(); lhsIt != lhs.value_end(); 45 | ++lhsIt) { 46 | int j = 0; 47 | for (auto rhsIt = rhs.value_begin(); rhsIt != rhs.value_end(); 48 | ++rhsIt) { 49 | // index is modulo degree because poly's semantics are defined modulo x^N 50 | // = 1. 51 | result[(i + j) % degree] += *rhsIt * (*lhsIt); 52 | ++j; 53 | } 54 | ++i; 55 | } 56 | 57 | return DenseIntElementsAttr::get( 58 | RankedTensorType::get(static_cast(result.size()), 59 | IntegerType::get(getContext(), 32)), 60 | result); 61 | } 62 | 63 | OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) { 64 | // Returns null if the cast failed, which corresponds to a failed fold. 65 | return dyn_cast_or_null(adaptor.getInput()); 66 | } 67 | 68 | LogicalResult EvalOp::verify() { 69 | auto pointTy = getPoint().getType(); 70 | bool isSignlessInteger = pointTy.isSignlessInteger(32); 71 | auto complexPt = dyn_cast(pointTy); 72 | return isSignlessInteger || complexPt ? 
success() 73 | : emitOpError( 74 | "argument point must be a 32-bit " 75 | "integer, or a complex number"); 76 | } 77 | 78 | void AddOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results, 79 | ::mlir::MLIRContext *context) {} 80 | 81 | void SubOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results, 82 | ::mlir::MLIRContext *context) { 83 | results.add(context); 84 | } 85 | 86 | void MulOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results, 87 | ::mlir::MLIRContext *context) {} 88 | 89 | void EvalOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results, 90 | ::mlir::MLIRContext *context) { 91 | results.add(context); 92 | } 93 | 94 | } // namespace poly 95 | } // namespace tutorial 96 | } // namespace mlir 97 | -------------------------------------------------------------------------------- /tools/tutorial-opt.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Conversion/PolyToStandard/PolyToStandard.h" 2 | #include "lib/Dialect/Noisy/NoisyDialect.h" 3 | #include "lib/Dialect/Poly/PolyDialect.h" 4 | #include "lib/Transform/Affine/Passes.h" 5 | #include "lib/Transform/Arith/Passes.h" 6 | #include "lib/Transform/Noisy/Passes.h" 7 | #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 8 | #include "mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 9 | #include "mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 10 | #include "mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 11 | #include "mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" 12 | #include "mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h" 13 | #include "mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h" 14 | #include "mlir/include/mlir/Dialect/Linalg/Passes.h" 15 | #include "mlir/include/mlir/InitAllDialects.h" 16 | #include "mlir/include/mlir/InitAllPasses.h" 17 | #include "mlir/include/mlir/Pass/PassManager.h" 
18 | #include "mlir/include/mlir/Pass/PassRegistry.h" 19 | #include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h" 20 | #include "mlir/include/mlir/Transforms/Passes.h" 21 | 22 | void polyToLLVMPipelineBuilder(mlir::OpPassManager &manager) { 23 | // Poly 24 | manager.addPass(mlir::tutorial::poly::createPolyToStandard()); 25 | manager.addPass(mlir::createCanonicalizerPass()); 26 | 27 | manager.addPass(mlir::createConvertElementwiseToLinalgPass()); 28 | manager.addPass(mlir::createConvertTensorToLinalgPass()); 29 | 30 | // One-shot bufferize, from 31 | // https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation 32 | mlir::bufferization::OneShotBufferizePassOptions bufferizationOptions; 33 | bufferizationOptions.bufferizeFunctionBoundaries = true; 34 | manager.addPass( 35 | mlir::bufferization::createOneShotBufferizePass(bufferizationOptions)); 36 | mlir::bufferization::BufferDeallocationPipelineOptions deallocationOptions; 37 | mlir::bufferization::buildBufferDeallocationPipeline(manager, 38 | deallocationOptions); 39 | 40 | manager.addPass(mlir::createConvertLinalgToLoopsPass()); 41 | 42 | // Needed to lower memref.subview 43 | manager.addPass(mlir::memref::createExpandStridedMetadataPass()); 44 | 45 | manager.addPass(mlir::createSCFToControlFlowPass()); 46 | manager.addPass(mlir::createConvertControlFlowToLLVMPass()); 47 | manager.addPass(mlir::createArithToLLVMConversionPass()); 48 | manager.addPass(mlir::createConvertFuncToLLVMPass()); 49 | manager.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); 50 | manager.addPass(mlir::createReconcileUnrealizedCastsPass()); 51 | 52 | // Cleanup 53 | manager.addPass(mlir::createCanonicalizerPass()); 54 | manager.addPass(mlir::createSCCPPass()); 55 | manager.addPass(mlir::createCSEPass()); 56 | manager.addPass(mlir::createSymbolDCEPass()); 57 | } 58 | 59 | int main(int argc, char **argv) { 60 | mlir::DialectRegistry registry; 61 | registry.insert(); 62 | registry.insert(); 63 |
mlir::registerAllDialects(registry); 64 | mlir::registerAllPasses(); 65 | 66 | mlir::tutorial::registerAffinePasses(); 67 | mlir::tutorial::registerArithPasses(); 68 | mlir::tutorial::noisy::registerNoisyPasses(); 69 | 70 | // Dialect conversion passes 71 | mlir::tutorial::poly::registerPolyToStandardPasses(); 72 | 73 | mlir::PassPipelineRegistration<>( 74 | "poly-to-llvm", "Run passes to lower the poly dialect to LLVM", 75 | polyToLLVMPipelineBuilder); 76 | 77 | return mlir::asMainReturnCode( 78 | mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry)); 79 | } 80 | -------------------------------------------------------------------------------- /lib/Transform/Arith/MulToAdd.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Transform/Arith/MulToAdd.h" 2 | #include "mlir/Dialect/Arith/IR/Arith.h" 3 | #include "mlir/IR/PatternMatch.h" 4 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 5 | #include "mlir/include/mlir/Pass/Pass.h" 6 | 7 | namespace mlir { 8 | namespace tutorial { 9 | 10 | #define GEN_PASS_DEF_MULTOADD 11 | #include "lib/Transform/Arith/Passes.h.inc" 12 | 13 | using arith::AddIOp; 14 | using arith::ConstantOp; 15 | using arith::MulIOp; 16 | 17 | // Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2 greater 18 | // than 1; otherwise do nothing.
19 | struct PowerOfTwoExpand : public OpRewritePattern { 20 | PowerOfTwoExpand(mlir::MLIRContext *context) 21 | : OpRewritePattern(context, /*benefit=*/2) {} 22 | 23 | LogicalResult matchAndRewrite(MulIOp op, 24 | PatternRewriter &rewriter) const override { 25 | Value lhs = op.getOperand(0); 26 | 27 | // Canonicalization patterns ensure the constant is on the right, if there 28 | // is a constant. See 29 | // https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules 30 | Value rhs = op.getOperand(1); 31 | auto rhsDefiningOp = rhs.getDefiningOp(); 32 | if (!rhsDefiningOp) { 33 | return failure(); 34 | } 35 | 36 | int64_t value = rhsDefiningOp.value(); 37 | bool is_power_of_two = value > 1 && (value & (value - 1)) == 0; // C <= 1 is rejected: C == 1 would become x*0 + x*0, and C <= 0 never reaches a base case. 38 | 39 | if (!is_power_of_two) { 40 | return failure(); 41 | } 42 | 43 | ConstantOp newConstant = rewriter.create( 44 | rhsDefiningOp.getLoc(), 45 | rewriter.getIntegerAttr(rhs.getType(), value / 2)); 46 | MulIOp newMul = rewriter.create(op.getLoc(), lhs, newConstant); 47 | AddIOp newAdd = rewriter.create(op.getLoc(), newMul, newMul); 48 | 49 | rewriter.replaceOp(op, newAdd); 50 | if (rhsDefiningOp->use_empty()) { rewriter.eraseOp(rhsDefiningOp); } // The constant may still feed other ops; eraseOp requires it to be unused. 51 | 52 | return success(); 53 | } 54 | }; 55 | 56 | // Replace y = 9*x with y = 8*x + x. Only applies when the constant is at least 2. 57 | struct PeelFromMul : public OpRewritePattern { 58 | PeelFromMul(mlir::MLIRContext *context) 59 | : OpRewritePattern(context, /*benefit=*/1) {} 60 | 61 | LogicalResult matchAndRewrite(MulIOp op, 62 | PatternRewriter &rewriter) const override { 63 | Value lhs = op.getOperand(0); 64 | Value rhs = op.getOperand(1); 65 | auto rhsDefiningOp = rhs.getDefiningOp(); 66 | if (!rhsDefiningOp) { 67 | return failure(); 68 | } 69 | 70 | int64_t value = rhsDefiningOp.value(); 71 | if (value < 2) { return failure(); } // For C < 2, peeling C - 1 never reaches a power of two (and C - 1 overflows at INT64_MIN). 72 | // We are guaranteed `value` is not a power of two, because the greedy 73 | // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since 74 | // it has higher benefit.
75 | 76 | ConstantOp newConstant = rewriter.create( 77 | rhsDefiningOp.getLoc(), 78 | rewriter.getIntegerAttr(rhs.getType(), value - 1)); 79 | MulIOp newMul = rewriter.create(op.getLoc(), lhs, newConstant); 80 | AddIOp newAdd = rewriter.create(op.getLoc(), newMul, lhs); 81 | 82 | rewriter.replaceOp(op, newAdd); 83 | if (rhsDefiningOp->use_empty()) { rewriter.eraseOp(rhsDefiningOp); } // The constant may still feed other ops; eraseOp requires it to be unused. 84 | 85 | return success(); 86 | } 87 | }; 88 | 89 | struct MulToAdd : impl::MulToAddBase { 90 | using MulToAddBase::MulToAddBase; 91 | 92 | void runOnOperation() { 93 | mlir::RewritePatternSet patterns(&getContext()); 94 | patterns.add(&getContext()); 95 | patterns.add(&getContext()); 96 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 97 | } 98 | }; 99 | 100 | } // namespace tutorial 101 | } // namespace mlir 102 | -------------------------------------------------------------------------------- /MODULE.bazel: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Bazel now uses Bzlmod by default to manage external dependencies. 3 | # Please consider migrating your external dependencies from WORKSPACE to MODULE.bazel.
4 | # 5 | # For more details, please check https://github.com/bazelbuild/bazel/issues/18958 6 | ############################################################################### 7 | 8 | module( 9 | name = "mlir_tutorial", 10 | version = "1.0.0", 11 | repo_name = "mlir_tutorial", 12 | ) 13 | 14 | # Dependencies available in BCR 15 | bazel_dep(name = "bazel_skylib", version = "1.7.1") 16 | bazel_dep(name = "rules_python", version = "1.2.0") 17 | bazel_dep(name = "platforms", version = "0.0.11") 18 | bazel_dep(name = "rules_cc", version = "0.1.1") 19 | bazel_dep(name = "rules_java", version = "8.12.0") 20 | bazel_dep(name = "protobuf", version = "30.1") 21 | bazel_dep(name = "rules_proto", version = "7.1.0") 22 | bazel_dep(name = "rules_pkg", version = "1.1.0") 23 | bazel_dep(name = "re2", version = "2024-07-02.bcr.1") 24 | bazel_dep(name = "abseil-cpp", version = "20250512.1") 25 | bazel_dep(name = "or-tools", version = "9.12") 26 | bazel_dep(name = "eigen", version = "4.0.0-20241125.bcr.2") 27 | bazel_dep(name = "highs", version = "1.11.0") 28 | bazel_dep(name = "pcre2", version = "10.46-DEV") 29 | bazel_dep(name = "glpk", version = "5.0.bcr.4") 30 | bazel_dep(name = "bliss", version = "0.73") 31 | bazel_dep(name = "scip", version = "9.2.0.bcr.3") 32 | bazel_dep(name = "zlib-ng", version = "2.0.7") 33 | 34 | # Hedron's Compile Commands Extractor for Bazel 35 | # https://github.com/hedronvision/bazel-compile-commands-extractor 36 | bazel_dep(name = "hedron_compile_commands", dev_dependency = True) 37 | git_override( 38 | module_name = "hedron_compile_commands", 39 | remote = "https://github.com/hedronvision/bazel-compile-commands-extractor.git", 40 | commit = "0e990032f3c5a866e72615cf67e5ce22186dcb97", 41 | # Replace the commit hash (above) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main). 42 | # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). 
43 | ) 44 | 45 | # Use module extensions for LLVM and other dependencies that aren't in BCR 46 | mlir_tutorial_deps = use_extension("//:extensions.bzl", "mlir_tutorial_deps") 47 | use_repo(mlir_tutorial_deps, 48 | "llvm-raw", 49 | "llvm_zstd", 50 | "llvm_zlib" 51 | ) 52 | 53 | # The subset of LLVM backend targets that should be compiled 54 | _LLVM_TARGETS = [ 55 | "X86", 56 | # The Bazel dependency graph for mlir-opt fails to load (at the analysis 57 | # step) without the NVPTX target in this list, because mlir/test:TestGPU 58 | # depends on the //llvm:NVPTXCodeGen target, which is not defined unless this 59 | # is included. @j2kun asked the LLVM maintainers for tips on how to fix this, 60 | # see https://github.com/llvm/llvm-project/issues/63135 61 | "NVPTX", 62 | # Needed for Apple M1 targets, see 63 | # https://github.com/j2kun/mlir-tutorial/issues/11 64 | "AArch64", 65 | ] 66 | 67 | # Configure LLVM project using use_repo_rule 68 | llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") 69 | llvm_configure( 70 | name = "llvm-project", 71 | targets = _LLVM_TARGETS, 72 | ) 73 | 74 | # Configure Python dependencies 75 | python = use_extension("@rules_python//python/extensions:python.bzl", "python") 76 | python.toolchain(python_version = "3.13") 77 | use_repo(python, "python_3_13") 78 | 79 | pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") 80 | pip.parse( 81 | hub_name = "mlir_tutorial_pip_deps", 82 | python_version = "3.13", 83 | requirements_lock = "//:requirements.txt", 84 | ) 85 | use_repo(pip, "mlir_tutorial_pip_deps") -------------------------------------------------------------------------------- /tests/poly_to_standard.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt --poly-to-standard %s | FileCheck %s 2 | 3 | // CHECK-LABEL: test_lower_add 4 | func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { 5 | //
CHECK: arith.addi 6 | %2 = poly.add %0, %1: !poly.poly<10> 7 | return %2 : !poly.poly<10> 8 | } 9 | 10 | // CHECK-LABEL: test_lower_sub 11 | func.func @test_lower_sub(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { 12 | // CHECK: arith.subi 13 | %2 = poly.sub %0, %1: !poly.poly<10> 14 | return %2 : !poly.poly<10> 15 | } 16 | 17 | // CHECK-LABEL: test_lower_to_tensor( 18 | // CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T]] { 19 | // CHECK-NEXT: return %[[V0]] : [[T]] 20 | func.func @test_lower_to_tensor(%0: !poly.poly<10>) -> tensor<10xi32> { 21 | %2 = poly.to_tensor %0: !poly.poly<10> -> tensor<10xi32> 22 | return %2 : tensor<10xi32> 23 | } 24 | 25 | // CHECK-LABEL: test_lower_from_tensor( 26 | // CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T]] { 27 | // CHECK-NEXT: return %[[V0]] : [[T]] 28 | func.func @test_lower_from_tensor(%0 : tensor<10xi32>) -> !poly.poly<10> { 29 | %2 = poly.from_tensor %0: tensor<10xi32> -> !poly.poly<10> 30 | return %2 : !poly.poly<10> 31 | } 32 | 33 | // CHECK-LABEL: test_lower_from_tensor_extend( 34 | // CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T2:tensor<20xi32>]] { 35 | // CHECK: %[[V1:.*]] = tensor.pad %[[V0]] low[0] high[10] 36 | // CHECK: return %[[V1]] : [[T2]] 37 | func.func @test_lower_from_tensor_extend(%0 : tensor<10xi32>) -> !poly.poly<20> { 38 | %2 = poly.from_tensor %0: tensor<10xi32> -> !poly.poly<20> 39 | return %2 : !poly.poly<20> 40 | } 41 | 42 | // CHECK-LABEL: test_lower_add_and_fold 43 | func.func @test_lower_add_and_fold() { 44 | // CHECK: arith.constant dense<[2, 3, 4]> : tensor<3xi32> 45 | %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> 46 | // CHECK: arith.constant dense<[3, 4, 5]> : tensor<3xi32> 47 | %1 = poly.constant dense<[3, 4, 5]> : tensor<3xi32> : !poly.poly<10> 48 | // would be an addi, but it was folded 49 | // CHECK: arith.constant 50 | %2 = poly.add %0, %1: !poly.poly<10> 51 | return 52 | } 53 | 54 | // CHECK-LABEL: test_lower_mul 55 | // 
CHECK-SAME: (%[[p0:.*]]: [[T:tensor<10xi32>]], %[[p1:.*]]: [[T]]) -> [[T]] { 56 | // CHECK: %[[cst:.*]] = arith.constant dense<0> : [[T]] 57 | // CHECK: %[[c0:.*]] = arith.constant 0 : index 58 | // CHECK: %[[c10:.*]] = arith.constant 10 : index 59 | // CHECK: %[[c1:.*]] = arith.constant 1 : index 60 | // CHECK: %[[outer:.*]] = scf.for %[[outer_iv:.*]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[outer_iter_arg:.*]] = %[[cst]]) -> ([[T]]) { 61 | // CHECK: %[[inner:.*]] = scf.for %[[inner_iv:.*]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[inner_iter_arg:.*]] = %[[outer_iter_arg]]) -> ([[T]]) { 62 | // CHECK: %[[index_sum:.*]] = arith.addi %arg2, %arg4 63 | // CHECK: %[[dest_index:.*]] = arith.remui %[[index_sum]], %[[c10]] 64 | // CHECK-DAG: %[[p0_extracted:.*]] = tensor.extract %[[p0]][%[[outer_iv]]] 65 | // CHECK-DAG: %[[p1_extracted:.*]] = tensor.extract %[[p1]][%[[inner_iv]]] 66 | // CHECK: %[[coeff_mul:.*]] = arith.muli %[[p0_extracted]], %[[p1_extracted]] 67 | // CHECK: %[[accum:.*]] = tensor.extract %[[inner_iter_arg]][%[[dest_index]]] 68 | // CHECK: %[[to_insert:.*]] = arith.addi %[[coeff_mul]], %[[accum]] 69 | // CHECK: %[[inserted:.*]] = tensor.insert %[[to_insert]] into %[[inner_iter_arg]][%[[dest_index]]] 70 | // CHECK: scf.yield %[[inserted]] 71 | // CHECK: } 72 | // CHECK: scf.yield %[[inner]] 73 | // CHECK: } 74 | // CHECK: return %[[outer]] 75 | // CHECK: } 76 | func.func @test_lower_mul(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { 77 | %2 = poly.mul %0, %1: !poly.poly<10> 78 | return %2 : !poly.poly<10> 79 | } 80 | 81 | 82 | // CHECK-LABEL: test_lower_eval 83 | // CHECK-SAME: (%[[poly:.*]]: [[T:tensor<10xi32>]], %[[point:.*]]: i32) -> i32 { 84 | // CHECK: %[[c1:.*]] = arith.constant 1 : index 85 | // CHECK: %[[c10:.*]] = arith.constant 10 : index 86 | // CHECK: %[[c11:.*]] = arith.constant 11 : index 87 | // CHECK: %[[accum:.*]] = arith.constant 0 : i32 88 | // CHECK: %[[loop:.*]] = scf.for %[[iv:.*]] = %[[c1]] to %[[c11]] 
step %[[c1]] iter_args(%[[iter_arg:.*]] = %[[accum]]) -> (i32) { 89 | // CHECK: %[[coeffIndex:.*]] = arith.subi %[[c10]], %[[iv]] 90 | // CHECK: %[[mulOp:.*]] = arith.muli %[[point]], %[[iter_arg]] 91 | // CHECK: %[[nextCoeff:.*]] = tensor.extract %[[poly]][%[[coeffIndex]]] 92 | // CHECK: %[[next:.*]] = arith.addi %[[mulOp]], %[[nextCoeff]] 93 | // CHECK: scf.yield %[[next]] 94 | // CHECK: } 95 | // CHECK: return %[[loop]] 96 | // CHECK: } 97 | func.func @test_lower_eval(%0 : !poly.poly<10>, %1 : i32) -> i32 { 98 | %2 = poly.eval %0, %1: (!poly.poly<10>, i32) -> i32 99 | return %2 : i32 100 | } 101 | 102 | 103 | // CHECK-LABEL: test_lower_many 104 | // CHECK-NOT: poly 105 | func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 { 106 | %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> 107 | %1 = poly.add %0, %arg : !poly.poly<10> 108 | %2 = poly.mul %1, %1 : !poly.poly<10> 109 | %3 = poly.sub %2, %arg : !poly.poly<10> 110 | %4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32 111 | return %4 : i32 112 | } 113 | -------------------------------------------------------------------------------- /tests/noisy_reduce_noise.mlir: -------------------------------------------------------------------------------- 1 | // RUN: tutorial-opt %s --noisy-reduce-noise-optimizer | FileCheck %s 2 | // Check for syntax 3 | 4 | // CHECK-LABEL: test_insert_noise_reduction_ops_mul 5 | // CHECK: [[V0:%.*]] = arith.constant 3 6 | // CHECK-NEXT: [[V1:%.*]] = arith.constant 4 7 | // CHECK-NEXT: [[V2:%.*]] = noisy.encode [[V0]] 8 | // CHECK-NEXT: [[V3:%.*]] = noisy.encode [[V1]] 9 | // CHECK-NEXT: [[V4:%.*]] = noisy.mul [[V2]], [[V3]] 10 | // CHECK-NEXT: [[V4_R:%.*]] = noisy.reduce_noise [[V4]] 11 | // CHECK-NEXT: [[V5:%.*]] = noisy.mul [[V4_R]], [[V4_R]] 12 | // CHECK-NEXT: [[V5_R:%.*]] = noisy.reduce_noise [[V5]] 13 | // CHECK-NEXT: [[V6:%.*]] = noisy.mul [[V5_R]], [[V5_R]] 14 | // CHECK-NEXT: [[V6_R:%.*]] = noisy.reduce_noise [[V6]] 15 | // This last mul 
does not need to be reduced 16 | // CHECK-NEXT: [[V7:%.*]] = noisy.mul [[V6_R]], [[V6_R]] 17 | // CHECK-NEXT: [[V8:%.*]] = noisy.decode [[V7]] 18 | // CHECK-NEXT: return 19 | func.func @test_insert_noise_reduction_ops_mul() -> i5 { 20 | %0 = arith.constant 3 : i5 21 | %1 = arith.constant 4 : i5 22 | %2 = noisy.encode %0 : i5 -> !noisy.i32 23 | %3 = noisy.encode %1 : i5 -> !noisy.i32 24 | %4 = noisy.mul %2, %3 : !noisy.i32 25 | %5 = noisy.mul %4, %4 : !noisy.i32 26 | %6 = noisy.mul %5, %5 : !noisy.i32 27 | %7 = noisy.mul %6, %6 : !noisy.i32 28 | %8 = noisy.decode %7 : !noisy.i32 -> i5 29 | return %8 : i5 30 | } 31 | 32 | // CHECK-LABEL: test_insert_noise_reduction_ops_add_none_needed 33 | // CHECK-NOT: noisy.reduce_noise 34 | func.func @test_insert_noise_reduction_ops_add_none_needed() -> i5 { 35 | %0 = arith.constant 3 : i5 36 | %1 = arith.constant 4 : i5 37 | %2 = noisy.encode %0 : i5 -> !noisy.i32 38 | %3 = noisy.encode %1 : i5 -> !noisy.i32 39 | %4 = noisy.add %2, %3 : !noisy.i32 40 | %5 = noisy.add %4, %4 : !noisy.i32 41 | %6 = noisy.add %5, %5 : !noisy.i32 42 | %7 = noisy.add %6, %6 : !noisy.i32 43 | %8 = noisy.decode %7 : !noisy.i32 -> i5 44 | return %8 : i5 45 | } 46 | 47 | 48 | // CHECK-LABEL: test_add_after_mul 49 | // CHECK: noisy.mul 50 | // CHECK: noisy.reduce_noise 51 | // CHECK: noisy.add 52 | // CHECK: noisy.add 53 | // CHECK: noisy.add 54 | // CHECK: noisy.add 55 | // CHECK: noisy.decode 56 | // CHECK: return 57 | func.func @test_add_after_mul() -> i5 { 58 | %0 = arith.constant 3 : i5 59 | %1 = arith.constant 4 : i5 60 | %2 = noisy.encode %0 : i5 -> !noisy.i32 61 | %3 = noisy.encode %1 : i5 -> !noisy.i32 62 | // Noise: 12 63 | %4 = noisy.mul %2, %3 : !noisy.i32 64 | // Noise: 24 65 | %5 = noisy.add %4, %3 : !noisy.i32 66 | // Noise: 25 67 | %6 = noisy.add %5, %5 : !noisy.i32 68 | // Noise: 26 69 | %7 = noisy.add %6, %6 : !noisy.i32 70 | // Noise: 27 71 | %8 = noisy.add %7, %7 : !noisy.i32 72 | %9 = noisy.decode %8 : !noisy.i32 -> i5 73 | return %9 : 
i5 74 | } 75 | 76 | // This test checks that the solver can find a single insertion point 77 | // for a reduce_noise op that handles two branches, each of which would 78 | // also need a reduce_noise op if handled separately. 79 | // CHECK-LABEL: test_single_insertion_branching 80 | // CHECK: noisy.mul 81 | // CHECK-NOT: noisy.add 82 | // CHECK-COUNT-1: noisy.reduce_noise 83 | // CHECK-NOT: noisy.reduce_noise 84 | func.func @test_single_insertion_branching() -> i5 { 85 | %0 = arith.constant 3 : i5 86 | %1 = arith.constant 4 : i5 87 | %2 = noisy.encode %0 : i5 -> !noisy.i32 88 | %3 = noisy.encode %1 : i5 -> !noisy.i32 89 | // Noise: 12 90 | %4 = noisy.mul %2, %3 : !noisy.i32 91 | // Noise: 24 92 | 93 | // branch 1 94 | %b1 = noisy.add %4, %3 : !noisy.i32 95 | // Noise: 25 96 | %b2 = noisy.add %b1, %3 : !noisy.i32 97 | // Noise: 25 98 | %b3 = noisy.add %b2, %3 : !noisy.i32 99 | // Noise: 26 100 | %b4 = noisy.add %b3, %3 : !noisy.i32 101 | // Noise: 27 102 | 103 | // branch 2 104 | %c1 = noisy.sub %4, %2 : !noisy.i32 105 | // Noise: 25 106 | %c2 = noisy.sub %c1, %3 : !noisy.i32 107 | // Noise: 25 108 | %c3 = noisy.sub %c2, %3 : !noisy.i32 109 | // Noise: 26 110 | %c4 = noisy.sub %c3, %3 : !noisy.i32 111 | // Noise: 27 112 | 113 | %x1 = noisy.decode %b4 : !noisy.i32 -> i5 114 | %x2 = noisy.decode %c4 : !noisy.i32 -> i5 115 | %x3 = arith.addi %x1, %x2 : i5 116 | return %x3 : i5 117 | } 118 | 119 | // same as test_single_insertion_branching, but because the last two values 120 | // are multiplied, we need two reduce_noise ops, one on each branch. 
121 | // CHECK-LABEL: test_double_insertion_branching 122 | // CHECK: noisy.mul 123 | // CHECK: noisy.add 124 | // CHECK-COUNT-2: noisy.reduce_noise 125 | // CHECK-NOT: noisy.reduce_noise 126 | // CHECK: noisy.mul 127 | func.func @test_double_insertion_branching() -> i5 { 128 | %0 = arith.constant 3 : i5 129 | %1 = arith.constant 4 : i5 130 | %2 = noisy.encode %0 : i5 -> !noisy.i32 131 | %3 = noisy.encode %1 : i5 -> !noisy.i32 132 | // Noise: 12 133 | %4 = noisy.mul %2, %3 : !noisy.i32 134 | // Noise: 24 135 | 136 | // branch 1 137 | %b1 = noisy.add %4, %3 : !noisy.i32 138 | // Noise: 25 139 | %b2 = noisy.add %b1, %3 : !noisy.i32 140 | // Noise: 25 141 | %b3 = noisy.add %b2, %3 : !noisy.i32 142 | // Noise: 26 143 | %b4 = noisy.add %b3, %3 : !noisy.i32 144 | // Noise: 27 145 | 146 | // branch 2 147 | %c1 = noisy.add %4, %2 : !noisy.i32 148 | // Noise: 25 149 | %c2 = noisy.add %c1, %3 : !noisy.i32 150 | // Noise: 25 151 | %c3 = noisy.add %c2, %3 : !noisy.i32 152 | // Noise: 26 153 | %c4 = noisy.add %c3, %3 : !noisy.i32 154 | // Noise: 27 155 | 156 | %exit = noisy.mul %b4, %c4 : !noisy.i32 157 | 158 | %x1 = noisy.decode %exit : !noisy.i32 -> i5 159 | return %x1 : i5 160 | } 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLIR For Beginners 2 | 3 | This is the code repository for a series of articles on the 4 | [MLIR framework](https://mlir.llvm.org/) for building compilers. 5 | 6 | ## Articles 7 | 8 | 1. [Build System (Getting Started)](https://jeremykun.com/2023/08/10/mlir-getting-started/) 9 | 2. [Running and Testing a Lowering](https://jeremykun.com/2023/08/10/mlir-running-and-testing-a-lowering/) 10 | 3. [Writing Our First Pass](https://jeremykun.com/2023/08/10/mlir-writing-our-first-pass/) 11 | 4. [Using Tablegen for Passes](https://jeremykun.com/2023/08/10/mlir-using-tablegen-for-passes/) 12 | 5. 
[Defining a New Dialect](https://jeremykun.com/2023/08/21/mlir-defining-a-new-dialect/) 13 | 6. [Using Traits](https://jeremykun.com/2023/09/07/mlir-using-traits/) 14 | 7. [Folders and Constant Propagation](https://jeremykun.com/2023/09/11/mlir-folders/) 15 | 8. [Verifiers](https://jeremykun.com/2023/09/13/mlir-verifiers/) 16 | 9. [Canonicalizers and Declarative Rewrite Patterns](https://jeremykun.com/2023/09/20/mlir-canonicalizers-and-declarative-rewrite-patterns/) 17 | 10. [Dialect Conversion](https://jeremykun.com/2023/10/23/mlir-dialect-conversion/) 18 | 11. [Lowering through LLVM](https://jeremykun.com/2023/11/01/mlir-lowering-through-llvm/) 19 | 12. [A Global Optimization and Dataflow Analysis](https://jeremykun.com/2023/11/15/mlir-a-global-optimization-and-dataflow-analysis/) 20 | 13. [Defining Patterns with PDLL](https://www.jeremykun.com/2024/08/04/mlir-pdll/) 21 | 22 | ## Bazel build 23 | 24 | Bazel is one of two supported build systems for this tutorial. The other is 25 | CMake. If you're unfamiliar with Bazel, you can read the tutorials at 26 | [https://bazel.build/start](https://bazel.build/start). Familiarity with Bazel 27 | is not required to build or test, but it is required to follow the articles in 28 | the tutorial series, and the Bazel setup is explained in the first article, 29 | [Build System (Getting Started)](https://jeremykun.com/2023/08/10/mlir-getting-started/). 30 | The CMake build is maintained, but was added at article 10 (Dialect Conversion) 31 | and will not be explained in the articles. 32 | 33 | **Note**: This project has been upgraded to Bazel 8.3.1 and migrated to use 34 | Bzlmod for dependency management, replacing the traditional WORKSPACE file 35 | approach. Dependencies are now managed through `MODULE.bazel` using the 36 | Bazel Central Registry (BCR) where possible.
37 | 38 | ### Prerequisites 39 | 40 | Install Bazelisk by following the instructions at 41 | [https://github.com/bazelbuild/bazelisk#installation](https://github.com/bazelbuild/bazelisk#installation). 42 | This should create the `bazel` command on your system. 43 | 44 | You should also have a modern C++ compiler on your system, either `gcc` or 45 | `clang`, which Bazel will detect. 46 | 47 | **Bazel Version**: This project requires Bazel 8.3.1 or newer. The specific 48 | version is pinned in `.bazelversion`. 49 | 50 | ### Build and test 51 | 52 | Run 53 | 54 | ```bash 55 | bazel build ...:all 56 | bazel test ...:all 57 | ``` 58 | 59 | ### Dependency Management 60 | 61 | The project uses Bzlmod (MODULE.bazel) for dependency management: 62 | 63 | - **Core dependencies**: Managed through the Bazel Central Registry (BCR) 64 | - rules_python, rules_java, protobuf, abseil-cpp, or-tools, etc. 65 | - **LLVM dependencies**: Managed through a custom module extension 66 | - LLVM/MLIR source code via git repository 67 | - **Development tools**: hedron_compile_commands via git_override 68 | 69 | This approach provides better dependency resolution, versioning, and 70 | compatibility compared to the legacy WORKSPACE approach. 71 | 72 | ## CMake build 73 | 74 | CMake is one of two supported build systems for this tutorial. The other is 75 | Bazel. If you're unfamiliar with CMake, you can read the tutorials at 76 | [https://cmake.org/getting-started/](https://cmake.org/getting-started/). The 77 | CMake build is maintained, but was added at article 10 (Dialect Conversion) and 78 | will not be explained in the articles.
79 | 80 | ### Prerequisites 81 | 82 | * Make sure you have installed everything needed to build LLVM: 83 | https://llvm.org/docs/GettingStarted.html#software 84 | * This recipe uses Ninja, so make sure it is installed as well: 85 | https://github.com/ninja-build/ninja/wiki/Pre-built-Ninja-packages 86 | 87 | ### Checking out the code 88 | 89 | Check out the tutorial, including the LLVM dependency (submodules): 90 | 91 | ```bash 92 | git clone --recurse-submodules https://github.com/j2kun/mlir-tutorial.git 93 | cd mlir-tutorial 94 | ``` 95 | 96 | ### Building dependencies 97 | 98 | Note: The following steps are suitable for macOS and use Ninja as the build 99 | system; they should not be hard to adapt to your environment. 100 | 101 | *Build LLVM/MLIR* 102 | 103 | ```bash 104 | #!/bin/sh 105 | 106 | BUILD_SYSTEM=Ninja 107 | BUILD_TAG=ninja 108 | THIRDPARTY_LLVM_DIR=$PWD/externals/llvm-project 109 | BUILD_DIR=$THIRDPARTY_LLVM_DIR/build 110 | INSTALL_DIR=$THIRDPARTY_LLVM_DIR/install 111 | 112 | mkdir -p $BUILD_DIR 113 | mkdir -p $INSTALL_DIR 114 | 115 | pushd $BUILD_DIR 116 | 117 | cmake ../llvm -G $BUILD_SYSTEM \ 118 | -DCMAKE_CXX_COMPILER="$(xcrun --find clang++)" \ 119 | -DCMAKE_C_COMPILER="$(xcrun --find clang)" \ 120 | -DCMAKE_INSTALL_PREFIX=$INSTALL_DIR \ 121 | -DLLVM_LOCAL_RPATH=$INSTALL_DIR/lib \ 122 | -DLLVM_PARALLEL_COMPILE_JOBS=7 \ 123 | -DLLVM_PARALLEL_LINK_JOBS=1 \ 124 | -DLLVM_BUILD_EXAMPLES=OFF \ 125 | -DLLVM_INSTALL_UTILS=ON \ 126 | -DCMAKE_OSX_ARCHITECTURES="$(uname -m)" \ 127 | -DCMAKE_BUILD_TYPE=Release \ 128 | -DLLVM_ENABLE_ASSERTIONS=ON \ 129 | -DLLVM_CCACHE_BUILD=ON \ 130 | -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ 131 | -DLLVM_ENABLE_PROJECTS='mlir' \ 132 | -DDEFAULT_SYSROOT="$(xcrun --show-sdk-path)" \ 133 | -DCMAKE_OSX_SYSROOT="$(xcrun --show-sdk-path)" 134 | 135 | cmake --build .
--target check-mlir 136 | 137 | popd 138 | ``` 139 | 140 | ### Build and test 141 | 142 | ```bash 143 | #!/bin/sh 144 | 145 | BUILD_SYSTEM="Ninja" 146 | BUILD_DIR=./build-`echo ${BUILD_SYSTEM}| tr '[:upper:]' '[:lower:]'` 147 | 148 | rm -rf $BUILD_DIR 149 | mkdir $BUILD_DIR 150 | pushd $BUILD_DIR 151 | 152 | LLVM_BUILD_DIR=externals/llvm-project/build 153 | cmake -G $BUILD_SYSTEM .. \ 154 | -DLLVM_DIR="$LLVM_BUILD_DIR/lib/cmake/llvm" \ 155 | -DMLIR_DIR="$LLVM_BUILD_DIR/lib/cmake/mlir" \ 156 | -DBUILD_DEPS="ON" \ 157 | -DBUILD_SHARED_LIBS="OFF" \ 158 | -DCMAKE_BUILD_TYPE=Debug 159 | 160 | popd 161 | 162 | cmake --build $BUILD_DIR --target MLIRAffineFullUnrollPasses 163 | cmake --build $BUILD_DIR --target MLIRMulToAddPasses 164 | cmake --build $BUILD_DIR --target MLIRNoisyPasses 165 | cmake --build $BUILD_DIR --target mlir-headers 166 | cmake --build $BUILD_DIR --target mlir-doc 167 | cmake --build $BUILD_DIR --target tutorial-opt 168 | cmake --build $BUILD_DIR --target check-mlir-tutorial 169 | ``` 170 | -------------------------------------------------------------------------------- /lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.h" 2 | 3 | #include 4 | 5 | #include "lib/Dialect/Noisy/NoisyOps.h" 6 | #include "mlir/include/mlir/IR/Operation.h" 7 | #include "mlir/include/mlir/IR/Value.h" 8 | #include "ortools/linear_solver/linear_solver.h" 9 | #include "llvm/Support/Debug.h" 10 | #include "llvm/include/llvm/ADT/DenseMap.h" 11 | #include "llvm/include/llvm/ADT/TypeSwitch.h" 12 | 13 | using namespace operations_research; 14 | 15 | namespace mlir { 16 | namespace tutorial { 17 | 18 | #define DEBUG_TYPE "ReduceNoiseAnalysis" 19 | 20 | // This needs only be larger than 32, since we're hard coding i32s in this 21 | // tutorial. 
22 | constexpr int IF_THEN_AUX = 100; 23 | 24 | std::string nameAndLoc(Operation *op) { 25 | std::string varName; 26 | llvm::raw_string_ostream ss(varName); 27 | ss << op->getName() << "_" << op->getLoc(); 28 | return ss.str(); 29 | } 30 | 31 | ReduceNoiseAnalysis::ReduceNoiseAnalysis(Operation *op) { 32 | std::unique_ptr solver(MPSolver::CreateSolver("SCIP")); 33 | MPObjective *const objective = solver->MutableObjective(); 34 | objective->SetMinimization(); 35 | 36 | llvm::DenseMap decisionVariables; 37 | llvm::DenseMap ssaNoiseVariables; 38 | std::vector allVariables; 39 | 40 | // First walk the IR to define variables for all values and ops, 41 | // and constraint initial conditions. 42 | op->walk([&](Operation *op) { 43 | // FIXME: assumes all reduce_noise ops have already been removed and their 44 | // values forwarded. 45 | if (!isa(op)) { 46 | return; 47 | } 48 | 49 | std::string varName = "InsertReduceNoise_" + nameAndLoc(op); 50 | auto decisionVar = solver->MakeIntVar(0, 1, varName); 51 | decisionVariables.insert(std::make_pair(op, decisionVar)); 52 | allVariables.push_back(decisionVar); 53 | objective->SetCoefficient(decisionVar, 1); 54 | 55 | int index = 0; 56 | for (auto operand : op->getOperands()) { 57 | if (ssaNoiseVariables.contains(operand)) { 58 | continue; 59 | } 60 | std::string varName = 61 | "NoiseAt_" + nameAndLoc(op) + "_arg_" + std::to_string(index++); 62 | auto ssaNoiseVar = solver->MakeNumVar(0, MAX_NOISE, varName); 63 | allVariables.push_back(ssaNoiseVar); 64 | ssaNoiseVariables.insert(std::make_pair(operand, ssaNoiseVar)); 65 | } 66 | 67 | if (!ssaNoiseVariables.contains(op->getResult(0))) { 68 | std::string varName = "NoiseAt_" + nameAndLoc(op) + "_result"; 69 | auto ssaNoiseVar = solver->MakeNumVar(0, MAX_NOISE, varName); 70 | allVariables.push_back(ssaNoiseVar); 71 | ssaNoiseVariables.insert(std::make_pair(op->getResult(0), ssaNoiseVar)); 72 | } 73 | }); 74 | 75 | // Define constraints on the noise at each SSA value 76 | for (auto 
item : ssaNoiseVariables) { 77 | auto value = item.first; 78 | auto var = item.second; 79 | // An input node has noise equal to the initial noise, though we're being a 80 | // bit sloppy by saying that EVERY block argument counts as an input node. 81 | // In the tutorial, there is no control flow, so these are the function 82 | // arguments of the main function being analyzed. A real compiler would 83 | // need to handle this more generically. 84 | if (isa(value) || 85 | isa(value.getDefiningOp())) { 86 | MPConstraint *const ct = 87 | solver->MakeRowConstraint(INITIAL_NOISE, INITIAL_NOISE, ""); 88 | ct->SetCoefficient(var, 1); 89 | } 90 | } 91 | 92 | std::string cstName; 93 | // Define the decision variable constraints 94 | op->walk([&](Operation *op) { 95 | llvm::TypeSwitch(*op) 96 | .Case([&](auto op) { 97 | // result_noise = input_noise (1 - reduce_decision) + 12 * 98 | // reduce_decision but linearized due to the quadratic term 99 | // input_noise * reduce_decision 100 | 101 | auto inf = solver->infinity(); 102 | auto lhsNoiseVar = ssaNoiseVariables.lookup(op.getLhs()); 103 | auto rhsNoiseVar = ssaNoiseVariables.lookup(op.getRhs()); 104 | auto resultNoiseVar = ssaNoiseVariables.lookup(op.getResult()); 105 | auto reduceNoiseDecision = decisionVariables.lookup(op); 106 | 107 | // result_noise >= 12 * reduce_decision 108 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_1"; 109 | MPConstraint *const ct1 = 110 | solver->MakeRowConstraint(0.0, inf, cstName); 111 | ct1->SetCoefficient(resultNoiseVar, 1); 112 | ct1->SetCoefficient(reduceNoiseDecision, -INITIAL_NOISE); 113 | 114 | // result_noise <= 12 + (1 - reduce_decision) * BIG_CONST 115 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_2"; 116 | MPConstraint *const ct2 = solver->MakeRowConstraint( 117 | 0.0, INITIAL_NOISE * IF_THEN_AUX, cstName); 118 | ct2->SetCoefficient(resultNoiseVar, 1); 119 | ct2->SetCoefficient(reduceNoiseDecision, IF_THEN_AUX); 120 | 121 | // result_noise >= input_noise - 
reduce_decision * BIG_CONST 122 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_3"; 123 | MPConstraint *const ct3 = 124 | solver->MakeRowConstraint(0.0, inf, cstName); 125 | ct3->SetCoefficient(resultNoiseVar, 1); 126 | ct3->SetCoefficient(reduceNoiseDecision, IF_THEN_AUX); 127 | // The input noise is the sum of the two argument noises 128 | if (op.getLhs() == op.getRhs()) { 129 | ct3->SetCoefficient(lhsNoiseVar, -2); 130 | } else { 131 | ct3->SetCoefficient(lhsNoiseVar, -1); 132 | ct3->SetCoefficient(rhsNoiseVar, -1); 133 | } 134 | 135 | // result_noise <= input_noise + reduce_decision * BIG_CONST 136 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_4"; 137 | MPConstraint *const ct4 = 138 | solver->MakeRowConstraint(-inf, 0.0, cstName); 139 | ct4->SetCoefficient(resultNoiseVar, 1); 140 | ct4->SetCoefficient(reduceNoiseDecision, -IF_THEN_AUX); 141 | if (op.getLhs() == op.getRhs()) { 142 | ct4->SetCoefficient(lhsNoiseVar, -2); 143 | } else { 144 | ct4->SetCoefficient(lhsNoiseVar, -1); 145 | ct4->SetCoefficient(rhsNoiseVar, -1); 146 | } 147 | 148 | // ensure the noise before the reduce_noise op (input_noise) 149 | // also is not too large 150 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_5"; 151 | MPConstraint *const ct5 = 152 | solver->MakeRowConstraint(0.0, MAX_NOISE, cstName); 153 | if (op.getLhs() == op.getRhs()) { 154 | ct5->SetCoefficient(lhsNoiseVar, 2); 155 | } else { 156 | ct5->SetCoefficient(lhsNoiseVar, 1); 157 | ct5->SetCoefficient(rhsNoiseVar, 1); 158 | } 159 | }) 160 | .Case([&](auto op) { 161 | // Same as for MulOp, but the noise combination function is more 162 | // complicated because it involves a maximum. 
163 | auto inf = solver->infinity(); 164 | auto lhsNoiseVar = ssaNoiseVariables.lookup(op.getLhs()); 165 | auto rhsNoiseVar = ssaNoiseVariables.lookup(op.getRhs()); 166 | auto resultNoiseVar = ssaNoiseVariables.lookup(op.getResult()); 167 | auto reduceNoiseDecision = decisionVariables.lookup(op); 168 | 169 | // result_noise >= 12 * reduce_decision 170 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_1"; 171 | MPConstraint *const ct1 = 172 | solver->MakeRowConstraint(0.0, inf, cstName); 173 | ct1->SetCoefficient(resultNoiseVar, 1); 174 | ct1->SetCoefficient(reduceNoiseDecision, -INITIAL_NOISE); 175 | 176 | // result_noise <= 12 + (1 - reduce_decision) * BIG_CONST 177 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_2"; 178 | MPConstraint *const ct2 = solver->MakeRowConstraint( 179 | 0.0, INITIAL_NOISE * IF_THEN_AUX, cstName); 180 | ct2->SetCoefficient(resultNoiseVar, 1); 181 | ct2->SetCoefficient(reduceNoiseDecision, IF_THEN_AUX); 182 | 183 | // for AddOp, the input noise is the max of the two argument noises 184 | // plus one. Model this with an extra variable Z and two constraints: 185 | // 186 | // lhs_noise + 1 <= Z <= MAX_NOISE 187 | // rhs_noise + 1 <= Z <= MAX_NOISE 188 | // input_noise := Z 189 | // 190 | // Then add theze Z variables to the minimization objective, and 191 | // they will be clamped to the larger of the two lower bounds. 192 | cstName = "Z_" + nameAndLoc(op); 193 | auto zVar = solver->MakeNumVar(0, MAX_NOISE, cstName); 194 | allVariables.push_back(zVar); 195 | // The objective coefficient is not all that important: the solver 196 | // cannot cheat by making Z larger than necessary, since making Z 197 | // larger than it needs to be would further increase the need to 198 | // insert reduce_noise ops, which would be more expensive. 
199 | objective->SetCoefficient(zVar, 0.1); 200 | 201 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_z1"; 202 | MPConstraint *const zCt1 = 203 | solver->MakeRowConstraint(1.0, inf, cstName); 204 | zCt1->SetCoefficient(zVar, 1); 205 | zCt1->SetCoefficient(lhsNoiseVar, -1); 206 | 207 | if (op.getLhs() != op.getRhs()) { 208 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_z2"; 209 | MPConstraint *const zCt2 = 210 | solver->MakeRowConstraint(1.0, inf, cstName); 211 | zCt2->SetCoefficient(zVar, 1); 212 | zCt2->SetCoefficient(rhsNoiseVar, -1); 213 | } 214 | 215 | // result_noise >= input_noise - reduce_decision * BIG_CONST 216 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_3"; 217 | MPConstraint *const ct3 = 218 | solver->MakeRowConstraint(0.0, inf, cstName); 219 | ct3->SetCoefficient(resultNoiseVar, 1); 220 | ct3->SetCoefficient(reduceNoiseDecision, IF_THEN_AUX); 221 | ct3->SetCoefficient(zVar, -1); 222 | 223 | // result_noise <= input_noise + reduce_decision * BIG_CONST 224 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_4"; 225 | MPConstraint *const ct4 = 226 | solver->MakeRowConstraint(-inf, 0.0, cstName); 227 | ct4->SetCoefficient(resultNoiseVar, 1); 228 | ct4->SetCoefficient(reduceNoiseDecision, -IF_THEN_AUX); 229 | ct4->SetCoefficient(zVar, -1); 230 | 231 | // ensure the noise before the reduce_noise op (input_noise) 232 | // also is not too large 233 | cstName = "DecisionDynamics_" + nameAndLoc(op) + "_5"; 234 | MPConstraint *const ct5 = 235 | solver->MakeRowConstraint(0.0, MAX_NOISE, cstName); 236 | ct5->SetCoefficient(zVar, 1); 237 | }); 238 | }); 239 | 240 | // Uncomment if you want to read the model's textual description, 241 | // generally not for those unfamiliar with linear programming. 
242 | // std::string modelAsString; 243 | // solver->ExportModelAsLpFormat(false, &modelAsString); 244 | // LLVM_DEBUG(llvm::dbgs() << "Model string = " << modelAsString << "\n"); 245 | 246 | solver->Solve(); 247 | LLVM_DEBUG(llvm::dbgs() << "Problem solved in " << solver->wall_time() 248 | << " milliseconds" 249 | << "\n"); 250 | 251 | LLVM_DEBUG(llvm::dbgs() << "Solution:\n"); 252 | LLVM_DEBUG(llvm::dbgs() << "Objective value = " << objective->Value() 253 | << "\n"); 254 | // LLVM_DEBUG(llvm::dbgs() << "Variables:\n"); 255 | // for (auto var : allVariables) { 256 | // LLVM_DEBUG(llvm::dbgs() << " " << var->name() << " = " 257 | // << var->solution_value() << "\n"); 258 | // } 259 | 260 | for (auto item : decisionVariables) { 261 | solution.insert(std::make_pair(item.first, item.second->solution_value())); 262 | } 263 | } 264 | 265 | } // namespace tutorial 266 | } // namespace mlir 267 | -------------------------------------------------------------------------------- /lib/Conversion/PolyToStandard/PolyToStandard.cpp: -------------------------------------------------------------------------------- 1 | #include "lib/Conversion/PolyToStandard/PolyToStandard.h" 2 | 3 | #include "lib/Dialect/Poly/PolyOps.h" 4 | #include "lib/Dialect/Poly/PolyTypes.h" 5 | #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project 6 | #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project 7 | #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project 8 | #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project 9 | #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project 10 | #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project 11 | 12 | namespace mlir { 13 | namespace tutorial { 14 | namespace poly { 15 | 16 | #define GEN_PASS_DEF_POLYTOSTANDARD 17 | #include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" 18 | 19 | class PolyToStandardTypeConverter : public 
TypeConverter { 20 | public: 21 | PolyToStandardTypeConverter(MLIRContext *ctx) { 22 | addConversion([](Type type) { return type; }); 23 | addConversion([ctx](PolynomialType type) -> Type { 24 | int degreeBound = type.getDegreeBound(); 25 | IntegerType elementTy = 26 | IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless); 27 | return RankedTensorType::get({degreeBound}, elementTy); 28 | }); 29 | 30 | // We don't include any custom materialization hooks because this lowering 31 | // is all done in a single pass. The dialect conversion framework works by 32 | // resolving intermediate (mid-pass) type conflicts by inserting 33 | // unrealized_conversion_cast ops, and only converting those to custom 34 | // materializations if they persist at the end of the pass. In our case, 35 | // we'd only need to use custom materializations if we split this lowering 36 | // across multiple passes. 37 | } 38 | }; 39 | 40 | struct ConvertAdd : public OpConversionPattern { 41 | ConvertAdd(mlir::MLIRContext *context) 42 | : OpConversionPattern(context) {} 43 | 44 | using OpConversionPattern::OpConversionPattern; 45 | 46 | LogicalResult matchAndRewrite( 47 | AddOp op, OpAdaptor adaptor, 48 | ConversionPatternRewriter &rewriter) const override { 49 | arith::AddIOp addOp = rewriter.create( 50 | op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); 51 | rewriter.replaceOp(op.getOperation(), addOp); 52 | return success(); 53 | } 54 | }; 55 | 56 | struct ConvertSub : public OpConversionPattern { 57 | ConvertSub(mlir::MLIRContext *context) 58 | : OpConversionPattern(context) {} 59 | 60 | using OpConversionPattern::OpConversionPattern; 61 | 62 | LogicalResult matchAndRewrite( 63 | SubOp op, OpAdaptor adaptor, 64 | ConversionPatternRewriter &rewriter) const override { 65 | arith::SubIOp subOp = rewriter.create( 66 | op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); 67 | rewriter.replaceOp(op.getOperation(), subOp); 68 | return success(); 69 | } 70 | }; 71 | 72 | struct ConvertMul 
: public OpConversionPattern { 73 | ConvertMul(mlir::MLIRContext *context) 74 | : OpConversionPattern(context) {} 75 | 76 | using OpConversionPattern::OpConversionPattern; 77 | 78 | LogicalResult matchAndRewrite( 79 | MulOp op, OpAdaptor adaptor, 80 | ConversionPatternRewriter &rewriter) const override { 81 | auto polymulTensorType = cast(adaptor.getLhs().getType()); 82 | auto numTerms = polymulTensorType.getShape()[0]; 83 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); 84 | 85 | // Create an all-zeros tensor to store the result 86 | auto polymulResult = b.create( 87 | polymulTensorType, DenseElementsAttr::get(polymulTensorType, 0)); 88 | 89 | // Loop bounds and step. 90 | auto lowerBound = 91 | b.create(b.getIndexType(), b.getIndexAttr(0)); 92 | auto numTermsOp = 93 | b.create(b.getIndexType(), b.getIndexAttr(numTerms)); 94 | auto step = 95 | b.create(b.getIndexType(), b.getIndexAttr(1)); 96 | 97 | auto p0 = adaptor.getLhs(); 98 | auto p1 = adaptor.getRhs(); 99 | 100 | // for i = 0, ..., N-1 101 | // for j = 0, ..., N-1 102 | // product[i+j (mod N)] += p0[i] * p1[j] 103 | auto outerLoop = b.create( 104 | lowerBound, numTermsOp, step, ValueRange(polymulResult.getResult()), 105 | [&](OpBuilder &builder, Location loc, Value p0Index, 106 | ValueRange loopState) { 107 | ImplicitLocOpBuilder b(op.getLoc(), builder); 108 | auto innerLoop = b.create( 109 | lowerBound, numTermsOp, step, loopState, 110 | [&](OpBuilder &builder, Location loc, Value p1Index, 111 | ValueRange loopState) { 112 | ImplicitLocOpBuilder b(op.getLoc(), builder); 113 | auto accumTensor = loopState.front(); 114 | auto destIndex = b.create( 115 | b.create(p0Index, p1Index), numTermsOp); 116 | auto mulOp = b.create( 117 | b.create(p0, ValueRange(p0Index)), 118 | b.create(p1, ValueRange(p1Index))); 119 | auto result = b.create( 120 | mulOp, b.create(accumTensor, 121 | destIndex.getResult())); 122 | auto stored = b.create(result, accumTensor, 123 | destIndex.getResult()); 124 | 
b.create(stored.getResult()); 125 | }); 126 | 127 | b.create(innerLoop.getResults()); 128 | }); 129 | 130 | rewriter.replaceOp(op, outerLoop.getResult(0)); 131 | return success(); 132 | } 133 | }; 134 | 135 | struct ConvertEval : public OpConversionPattern { 136 | ConvertEval(mlir::MLIRContext *context) 137 | : OpConversionPattern(context) {} 138 | 139 | using OpConversionPattern::OpConversionPattern; 140 | 141 | LogicalResult matchAndRewrite( 142 | EvalOp op, OpAdaptor adaptor, 143 | ConversionPatternRewriter &rewriter) const override { 144 | auto polyTensorType = 145 | cast(adaptor.getPolynomial().getType()); 146 | auto numTerms = polyTensorType.getShape()[0]; 147 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); 148 | 149 | auto lowerBound = 150 | b.create(b.getIndexType(), b.getIndexAttr(1)); 151 | auto numTermsOp = b.create(b.getIndexType(), 152 | b.getIndexAttr(numTerms)); 153 | auto upperBound = b.create(b.getIndexType(), 154 | b.getIndexAttr(numTerms + 1)); 155 | auto step = lowerBound; 156 | 157 | auto poly = adaptor.getPolynomial(); 158 | auto point = adaptor.getPoint(); 159 | 160 | // Horner's method: 161 | // 162 | // accum = 0 163 | // for i = 1, 2, ..., N 164 | // accum = point * accum + coeff[N - i] 165 | auto accum = 166 | b.create(b.getI32Type(), b.getI32IntegerAttr(0)); 167 | auto loop = b.create( 168 | lowerBound, upperBound, step, accum.getResult(), 169 | [&](OpBuilder &builder, Location loc, Value loopIndex, 170 | ValueRange loopState) { 171 | ImplicitLocOpBuilder b(op.getLoc(), builder); 172 | auto accum = loopState.front(); 173 | auto coeffIndex = b.create(numTermsOp, loopIndex); 174 | auto mulOp = b.create(point, accum); 175 | auto result = b.create( 176 | mulOp, b.create(poly, coeffIndex.getResult())); 177 | b.create(result.getResult()); 178 | }); 179 | 180 | rewriter.replaceOp(op, loop.getResult(0)); 181 | return success(); 182 | } 183 | }; 184 | 185 | struct ConvertFromTensor : public OpConversionPattern { 186 | 
ConvertFromTensor(mlir::MLIRContext *context) 187 | : OpConversionPattern(context) {} 188 | 189 | using OpConversionPattern::OpConversionPattern; 190 | 191 | LogicalResult matchAndRewrite( 192 | FromTensorOp op, OpAdaptor adaptor, 193 | ConversionPatternRewriter &rewriter) const override { 194 | auto resultTensorTy = cast( 195 | typeConverter->convertType(op->getResultTypes()[0])); 196 | auto resultShape = resultTensorTy.getShape()[0]; 197 | auto resultEltTy = resultTensorTy.getElementType(); 198 | 199 | auto inputTensorTy = op.getInput().getType(); 200 | auto inputShape = inputTensorTy.getShape()[0]; 201 | 202 | // Zero pad the tensor if the coefficients' size is less than the polynomial 203 | // degree. 204 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); 205 | auto coeffValue = adaptor.getInput(); 206 | if (inputShape < resultShape) { 207 | SmallVector low, high; 208 | low.push_back(rewriter.getIndexAttr(0)); 209 | high.push_back(rewriter.getIndexAttr(resultShape - inputShape)); 210 | coeffValue = b.create( 211 | resultTensorTy, coeffValue, low, high, 212 | b.create(rewriter.getIntegerAttr(resultEltTy, 0)), 213 | /*nofold=*/false); 214 | } 215 | 216 | rewriter.replaceOp(op, coeffValue); 217 | return success(); 218 | } 219 | }; 220 | 221 | struct ConvertToTensor : public OpConversionPattern { 222 | ConvertToTensor(mlir::MLIRContext *context) 223 | : OpConversionPattern(context) {} 224 | 225 | using OpConversionPattern::OpConversionPattern; 226 | 227 | LogicalResult matchAndRewrite( 228 | ToTensorOp op, OpAdaptor adaptor, 229 | ConversionPatternRewriter &rewriter) const override { 230 | rewriter.replaceOp(op, adaptor.getInput()); 231 | return success(); 232 | } 233 | }; 234 | 235 | struct ConvertConstant : public OpConversionPattern { 236 | ConvertConstant(mlir::MLIRContext *context) 237 | : OpConversionPattern(context) {} 238 | 239 | using OpConversionPattern::OpConversionPattern; 240 | 241 | LogicalResult matchAndRewrite( 242 | ConstantOp op, OpAdaptor adaptor, 
243 | ConversionPatternRewriter &rewriter) const override { 244 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); 245 | auto constOp = b.create(adaptor.getCoefficients()); 246 | auto fromTensorOp = 247 | b.create(op.getResult().getType(), constOp); 248 | rewriter.replaceOp(op, fromTensorOp.getResult()); 249 | return success(); 250 | } 251 | }; 252 | 253 | struct PolyToStandard : impl::PolyToStandardBase { 254 | using PolyToStandardBase::PolyToStandardBase; 255 | 256 | void runOnOperation() override { 257 | MLIRContext *context = &getContext(); 258 | auto *module = getOperation(); 259 | 260 | ConversionTarget target(*context); 261 | target.addLegalDialect(); 262 | target.addIllegalDialect(); 263 | 264 | RewritePatternSet patterns(context); 265 | PolyToStandardTypeConverter typeConverter(context); 266 | patterns.add(typeConverter, 268 | context); 269 | 270 | populateFunctionOpInterfaceTypeConversionPattern( 271 | patterns, typeConverter); 272 | target.addDynamicallyLegalOp([&](func::FuncOp op) { 273 | return typeConverter.isSignatureLegal(op.getFunctionType()) && 274 | typeConverter.isLegal(&op.getBody()); 275 | }); 276 | 277 | populateReturnOpTypeConversionPattern(patterns, typeConverter); 278 | target.addDynamicallyLegalOp( 279 | [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); 280 | 281 | populateCallOpTypeConversionPattern(patterns, typeConverter); 282 | target.addDynamicallyLegalOp( 283 | [&](func::CallOp op) { return typeConverter.isLegal(op); }); 284 | 285 | populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); 286 | target.markUnknownOpDynamicallyLegal([&](Operation *op) { 287 | return isNotBranchOpInterfaceOrReturnLikeOp(op) || 288 | isLegalForBranchOpInterfaceTypeConversionPattern(op, 289 | typeConverter) || 290 | isLegalForReturnOpTypeConversionPattern(op, typeConverter); 291 | }); 292 | 293 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) { 294 | signalPassFailure(); 295 | } 296 | } 297 | }; 298 | 
299 | } // namespace poly 300 | } // namespace tutorial 301 | } // namespace mlir 302 | --------------------------------------------------------------------------------