├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── README.md ├── build ├── cmake └── module │ ├── AddDMC.cmake │ └── DeclarativeCompiler.cmake ├── env ├── include ├── CMakeLists.txt └── dmc │ ├── CMakeLists.txt │ ├── Dynamic │ ├── Alias.h │ ├── DynamicAttribute.h │ ├── DynamicContext.h │ ├── DynamicDialect.h │ ├── DynamicObject.h │ ├── DynamicOperation.h │ ├── DynamicType.h │ ├── Metadata.h │ └── TypeIDAllocator.h │ ├── Embed │ ├── Constraints.h │ ├── Expose.h │ ├── InMemoryDef.h │ ├── Init.h │ ├── OpFormatGen.h │ ├── ParserPrinter.h │ ├── PythonGen.h │ └── TypeFormatGen.h │ ├── IO │ └── ModuleWriter.h │ ├── Kind.h │ ├── Python │ ├── APComplex.h │ ├── DialectAsm.h │ ├── OpAsm.h │ ├── Polymorphic.h │ └── PyMLIR.h │ ├── Spec │ ├── CMakeLists.txt │ ├── DialectGen.h │ ├── FormatOp.h │ ├── FormatOp.td │ ├── HasChildren.h │ ├── NamedConstraints.h │ ├── OpType.h │ ├── ParameterList.h │ ├── ParameterList.td │ ├── Parsing.h │ ├── ReparseOpInterface.h │ ├── ReparseOpInterface.td │ ├── SpecAttrBase.h │ ├── SpecAttrDetail.h │ ├── SpecAttrImplementation.h │ ├── SpecAttrSwitch.h │ ├── SpecAttrs.h │ ├── SpecDialect.h │ ├── SpecKinds.h │ ├── SpecOps.h │ ├── SpecRegion.h │ ├── SpecRegionSwitch.h │ ├── SpecSuccessor.h │ ├── SpecSuccessorSwitch.h │ ├── SpecTypeDetail.h │ ├── SpecTypeImplementation.h │ ├── SpecTypeSwitch.h │ ├── SpecTypes.h │ └── Support.h │ └── Traits │ ├── Kinds.h │ ├── OpTrait.h │ ├── Registry.h │ ├── SpecTraits.h │ └── StandardTraits.h ├── lib ├── CMakeLists.txt ├── Dynamic │ ├── CMakeLists.txt │ ├── DynamicAttribute.cpp │ ├── DynamicContext.cpp │ ├── DynamicDialect.cpp │ ├── DynamicDialectImpl.cpp │ ├── DynamicObject.cpp │ ├── DynamicOperation.cpp │ ├── DynamicType.cpp │ └── TypeIDAllocator.cpp ├── Embed │ ├── CMakeLists.txt │ ├── Constraints.cpp │ ├── Expose.cpp │ ├── FormatUtils.cpp │ ├── FormatUtils.h │ ├── InMemoryDef.cpp │ ├── Init.cpp │ ├── OpFormatGen.cpp │ ├── ParserPrinter.cpp │ ├── PythonGen.cpp │ ├── Scope.cpp │ ├── Scope.h │ ├── Spec.cpp │ └── 
TypeFormatGen.cpp ├── IO │ ├── CMakeLists.txt │ └── ModuleWriter.cpp ├── Python │ ├── AsmUtils.h │ ├── Attribute.cpp │ ├── Attribute.h │ ├── BuildableType.cpp │ ├── CMakeLists.txt │ ├── Context.cpp │ ├── Context.h │ ├── DialectAsm.cpp │ ├── DllInit.cpp │ ├── Expose.cpp │ ├── Expose.h │ ├── ExposeArrayAttr.cpp │ ├── ExposeAttribute.cpp │ ├── ExposeBuilder.cpp │ ├── ExposeDialectAsm.cpp │ ├── ExposeDictAttr.cpp │ ├── ExposeElementsAttr.cpp │ ├── ExposeFunctionType.cpp │ ├── ExposeIntFPAttr.cpp │ ├── ExposeLocation.cpp │ ├── ExposeModule.cpp │ ├── ExposeOpAsm.cpp │ ├── ExposeOpaqueType.cpp │ ├── ExposeOps.cpp │ ├── ExposeParser.cpp │ ├── ExposeShapedTypes.cpp │ ├── ExposeStandardTypes.cpp │ ├── ExposeSymbolRefAttr.cpp │ ├── ExposeType.cpp │ ├── ExposeValue.cpp │ ├── ExternalModule.cpp │ ├── Identifier.cpp │ ├── Identifier.h │ ├── Location.cpp │ ├── Location.h │ ├── Module.cpp │ ├── Module.h │ ├── OpAsm.cpp │ ├── OwningModuleRef.cpp │ ├── OwningModuleRef.h │ ├── Parser.cpp │ ├── Parser.h │ ├── Type.cpp │ ├── Type.h │ └── Utility.h ├── Spec │ ├── CMakeLists.txt │ ├── DialectGen.cpp │ ├── FormatOp.cpp │ ├── OpReparsing.cpp │ ├── OpType.cpp │ ├── ParameterList.cpp │ ├── Parsing.cpp │ ├── ReparseOpInterface.cpp │ ├── SpecAttrImplementation.cpp │ ├── SpecAttrs.cpp │ ├── SpecDialect.cpp │ ├── SpecOps.cpp │ ├── SpecRegion.cpp │ ├── SpecSuccessor.cpp │ ├── SpecTypeDetail.cpp │ ├── SpecTypeImplementation.cpp │ ├── SpecTypeParsing.cpp │ └── SpecTypes.cpp └── Traits │ ├── CMakeLists.txt │ ├── GenericConstructor.h │ ├── OpTrait.cpp │ ├── Registry.cpp │ ├── SpecTraits.cpp │ └── StandardTraits.cpp ├── list_targets ├── lua ├── .gitignore ├── CMakeLists.txt ├── Lua.g4 ├── Makefile ├── binarytree.lua ├── builtins.cpp ├── fannkuch.lua ├── impl.cpp ├── impl.h ├── lib.h ├── lib.mlir ├── loops.lua ├── lua.mlir ├── luac.py ├── markov.in ├── markov.lua ├── markov.mlir ├── perf.mlir ├── preliminary-results.txt ├── test.c └── test.lua ├── oec ├── .gitignore ├── CMakeLists.txt ├── Makefile ├── 
ast.txt ├── dl_stencil.cpp ├── laplace.mlir ├── main.cpp ├── python.mlir ├── stencil.py └── test.py ├── spec ├── dialect.mlir ├── laplace.mlir ├── stencil.mlir └── test.mlir └── tools ├── CMakeLists.txt ├── dialectgen.cpp └── spec.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | bin/ 3 | *.swp 4 | .idea/ 5 | .clangd/ 6 | compile_commands.json 7 | *.swo 8 | mlir.so 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "llvm-project"] 2 | path = llvm-project 3 | url = https://github.com/Mogball/llvm-project 4 | [submodule "pybind11"] 5 | path = pybind11 6 | url = https://github.com/Mogball/pybind11 7 | [submodule "lua/rx-cpp"] 8 | path = lua/rx-cpp 9 | url = https://github.com/stevedonovan/rx-cpp 10 | [submodule "oec/open-earth-compiler"] 11 | path = oec/open-earth-compiler 12 | url = https://github.com/Mogball/open-earth-compiler.git 13 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | # C++17 4 | set(CMAKE_CXX_STANDARD 17) 5 | set(CMAKE_CXX_STANDARD_REQUIRED TRUE) 6 | set(CMAKE_CXX_EXTENSIONS OFF) 7 | 8 | project(declarative-compiler) 9 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 10 | 11 | # Enable RTTI and exceptions. 
12 | set(LLVM_ENABLE_RTTI ON CACHE INTERNAL "") 13 | set(LLVM_ENABLE_EH ON CACHE INTERNAL "") 14 | # Configure LLVM and MLIR directories 15 | set(LLVM_PROJECT_DIR llvm-project) 16 | set(LLVM_DIR ${LLVM_PROJECT_DIR}/llvm) 17 | set(MLIR_DIR ${LLVM_PROJECT_DIR}/mlir) 18 | set(LLVM_BINARY_DIR ${CMAKE_BINARY_DIR}/${LLVM_DIR}) 19 | set(MLIR_BINARY_DIR ${LLVM_BINARY_DIR}/tools/mlir) 20 | # Add LLVM and MLIR targets 21 | set(LLVM_TARGETS_TO_BUILD "host;NVPTX" CACHE INTERNAL "") 22 | set(LLVM_ENABLE_PROJECTS "mlir;clang" CACHE INTERNAL "") 23 | add_subdirectory(${LLVM_DIR}) 24 | # Configure mlir-tblgen from built target 25 | set(MLIR_TABLEGEN_EXE $) 26 | # Configure LLVM and MLIR include directories 27 | set(LLVM_INCLUDE_DIRS ${LLVM_DIR}/include ${LLVM_BINARY_DIR}/include) 28 | set(MLIR_INCLUDE_DIRS ${MLIR_DIR}/include ${MLIR_BINARY_DIR}/include) 29 | 30 | message(STATUS "Building project in directory: ${CMAKE_BINARY_DIR}") 31 | message(STATUS "LLVM directory: ${LLVM_DIR}") 32 | message(STATUS "MLIR directory: ${MLIR_DIR}") 33 | 34 | # Include project-specific configs 35 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/module") 36 | include(AddDMC) 37 | 38 | # Python bindings. 
39 | find_package(Python3 REQUIRED COMPONENTS Interpreter Development) 40 | add_subdirectory(pybind11) 41 | 42 | # Global options 43 | add_compile_options(-Wall -fdiagnostics-color) 44 | 45 | # Global includes 46 | include_directories( 47 | include 48 | ${CMAKE_BINARY_DIR}/include 49 | ${LLVM_INCLUDE_DIRS} 50 | ${MLIR_INCLUDE_DIRS} 51 | ) 52 | 53 | add_subdirectory(include) 54 | add_subdirectory(lib) 55 | add_subdirectory(tools) 56 | add_subdirectory(lua) 57 | add_subdirectory(oec) 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Declarative MLIR Compilers 2 | 3 | [Design document](https://docs.google.com/document/d/1eAgIQZZ2dItJFSrCxemt7fwH0CD4w6_ueLKVl6UL-NU/edit?usp=sharing) 4 | 5 | ## Building DMC 6 | 7 | Build requirements: 8 | 9 | - `cmake >= 3.15` (the top-level `CMakeLists.txt` requires 3.14; the multi-target `cmake --build -t` commands below need 3.15) 10 | - `python >= 3.6` 11 | 12 | Arch Linux: 13 | 14 | ```bash 15 | sudo pacman -Sy cmake python 16 | ``` 17 | 18 | macOS: 19 | 20 | ```bash 21 | brew install cmake python3 22 | ``` 23 | 24 | Clone the repo and its submodules: 25 | 26 | ```bash 27 | git clone git@github.com:Mogball/declarative-mlir-compiler.git 28 | cd declarative-mlir-compiler 29 | git submodule update --init 30 | ``` 31 | 32 | Configure CMake with whatever generator you 33 | prefer: 34 | 35 | ```bash 36 | mkdir -p bin && cd bin/ 37 | cmake ../ -G "Unix Makefiles"  # or -G Ninja, etc. 38 | ``` 39 | 40 | Then build the target `gen`. Here and below, `$BINDIR` is the build directory created above (`bin/`). 41 | 42 | ```bash 43 | cmake --build $BINDIR --target gen 44 | ``` 45 | 46 | ## Using DMC 47 | 48 | Let's define a dead simple dialect. Make sure the shared library is built with 49 | `cmake --build $BINDIR -t mlir`, and that `$BINDIR/lib/Python` is added to your 50 | `PYTHONPATH`.
51 | 52 | ```mlir 53 | // toy.mlir 54 | Dialect @toy { 55 | Op @add3(one: !dmc.Any, two: !dmc.Any, three: !dmc.Any) -> (res: !dmc.Any) 56 | traits [@SameType<"one", "two", "three", "res">] 57 | } 58 | ``` 59 | 60 | Our Python script will need to load the dialect. 61 | 62 | ```python 63 | from mlir import * 64 | 65 | dialects = registerDynamicDialects(parseSourceFile("toy.mlir")) 66 | toy = dialects[0] 67 | ``` 68 | 69 | Okay, let's generate a simple program. 70 | 71 | ```python 72 | m = ModuleOp() 73 | func = FuncOp("add3_impl", FunctionType([I64Type()] * 3, [I64Type()])) 74 | m.append(func) 75 | 76 | entry = func.addEntryBlock() 77 | args = entry.args 78 | b = Builder() 79 | b.insertAtStart(entry) 80 | add3 = b.create(toy.add3, one=args[0], two=args[1], three=args[2], 81 | res=I64Type()) 82 | b.create(ReturnOp, operands=[add3.res()]) 83 | verify(m) 84 | print(m) 85 | ``` 86 | 87 | The output should be: 88 | 89 | ```mlir 90 | module { 91 | func @add3_impl(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 { 92 | %0 = "toy.add3"(%arg0, %arg1, %arg2) : (i64, i64, i64) -> i64 93 | return %0 : i64 94 | } 95 | } 96 | ``` 97 | 98 | ## Building the Lua Compiler 99 | 100 | The Lua compiler requires `antlr >= 4`. On Arch, install the Pacman package 101 | `antlr4`. On macOS, install the Homebrew package `antlr`. 102 | 103 | First, build the CMake targets `mlir`, `mlir-translate`, and `lua-parser`. 104 | 105 | ```bash 106 | cmake --build $BINDIR -t mlir -t mlir-translate -t lua-parser 107 | ``` 108 | 109 | The Python shared library and autogenerated Lua parser must be added to your 110 | `PYTHONPATH`. These can be found under `$BINDIR/lib/Python` and 111 | `$BINDIR/lua/parser`, respectively. The directories holding the built tools must also be on your `PATH`: 112 | 113 | ```bash 114 | export PATH=$PATH:$BINDIR/tools:$BINDIR/llvm-project/llvm/bin:$BINDIR/lua 115 | export PYTHONPATH=$BINDIR/lua/parser:$BINDIR/lib/Python 116 | ``` 117 | 118 | ## Using the Lua Compiler 119 | 120 | Let's define a basic Lua file.
121 | 122 | ```lua 123 | -- example.lua 124 | function fib(n) 125 | if n == 0 then return 0 end 126 | if n == 1 then return 1 end 127 | return fib(n-1) + fib(n-2) 128 | end 129 | 130 | print("fib(5) = ", fib(5)) 131 | ``` 132 | 133 | The Python script `lua/luac.py` calls the ANTLR parser and lowers the program to the MLIR LLVM dialect; `mlir-translate` then emits 134 | LLVM IR. The resulting file must be compiled (`-O2` strongly recommended) and 135 | linked against a runtime `lua/impl.cpp` and the builtins `lua/builtins.cpp`. 136 | 137 | ```bash 138 | python3 lua/luac.py example.lua > example.mlir 139 | mlir-translate --mlir-to-llvmir example.mlir -o example.ll 140 | clang++ example.ll lua/impl.cpp lua/builtins.cpp -O2 -std=c++17 -o example 141 | ``` 142 | 143 | Running (hopefully) produces the correct output: 144 | 145 | ```bash 146 | me$ ./example 147 | fib(5) = 5 148 | ``` 149 | -------------------------------------------------------------------------------- /build: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Build a single CMake target in the bin/ build directory: ./build TARGET 3 | # Abort if bin/ is missing rather than running cmake in the wrong directory. 4 | cd bin || exit 1 5 | cmake --build . --target "${1:?usage: ./build <target>}" 6 | -------------------------------------------------------------------------------- /cmake/module/AddDMC.cmake: -------------------------------------------------------------------------------- 1 | # LLVM sets -fvisibility-inlines-hidden. We must match the setting or else 2 | # the linker will give warnings.
3 | if(NOT WIN32 AND NOT CYGWIN AND NOT (${CMAKE_SYSTEM_NAME} MATCHES "AIX" AND CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) 4 | check_cxx_compiler_flag("-fvisibility-inlines-hidden" SUPPORTS_FVISIBILITY_INLINES_HIDDEN_FLAG) 5 | append_if(SUPPORTS_FVISIBILITY_INLINES_HIDDEN_FLAG "-fvisibility-inlines-hidden" CMAKE_CXX_FLAGS) 6 | endif() 7 | 8 | -------------------------------------------------------------------------------- /cmake/module/DeclarativeCompiler.cmake: -------------------------------------------------------------------------------- 1 | # add_dialect(dialect dialect_namespace) 2 | # Runs mlir-tblgen on `${dialect}.td`, emitting op declarations into 3 | # `${dialect}.h.inc` and op definitions into `${dialect}.cpp.inc`, and creates 4 | # the `${dialect}IncGen` target. `dialect_namespace` is currently unused. 5 | # NOTE(review): MLIR_TABLEGEN_INCLUDE is not set in the top-level 6 | # CMakeLists.txt; confirm it is defined before calling this function. 7 | function(add_dialect dialect dialect_namespace) 8 | set(LLVM_TARGET_DEFINITIONS ${dialect}.td) 9 | mlir_tablegen(${dialect}.h.inc -gen-op-decls -I ${MLIR_TABLEGEN_INCLUDE}) 10 | mlir_tablegen(${dialect}.cpp.inc -gen-op-defs -I ${MLIR_TABLEGEN_INCLUDE}) 11 | add_public_tablegen_target(${dialect}IncGen) 12 | endfunction() 13 | -------------------------------------------------------------------------------- /env: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | export PATH=$PATH:$(pwd)/bin/tools:$(pwd):$(pwd)/bin/llvm-project/llvm/bin:$(pwd)/bin/lua:$(pwd)/lua 3 | export PATH=$PATH:$(pwd)/bin/bin 4 | export LD_LIBRARY_PATH=$(pwd)/bin/llvm-project/llvm/lib 5 | export PYTHONPATH=$(pwd)/bin/lua/parser:$(pwd)/bin/lib/Python:$(pwd)/bin/oec 6 | -------------------------------------------------------------------------------- /include/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(dmc) 2 | -------------------------------------------------------------------------------- /include/dmc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Spec) 2 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/Alias.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 |
#include "Metadata.h" 4 | 5 | #include 6 | 7 | namespace dmc { 8 | 9 | /// An alias allows shorthand definitions of types or attributes. 10 | /// See `dmc.Alias`. Aliases are erased during parsing of the dialect module. 11 | /// 12 | /// `DynamicDialect::parseType`, for example, upon parsing a type alias, returns 13 | /// the aliased type. 14 | class TypeAlias : public TypeMetadata { 15 | public: 16 | TypeAlias(llvm::StringRef name, mlir::Type aliasedType, 17 | llvm::Optional builder = {}) 18 | : TypeMetadata{name, builder}, 19 | aliasedType{aliasedType} {} 20 | 21 | inline auto getAliasedType() { return aliasedType; } 22 | 23 | private: 24 | mlir::Type aliasedType; 25 | }; 26 | 27 | /// `DynamicDialect::parseAttribute` will directly return the aliased 28 | /// attribute. 29 | class AttributeAlias : public AttributeMetadata { 30 | public: 31 | AttributeAlias(llvm::StringRef name, mlir::Attribute aliasedAttr, 32 | llvm::Optional builder = {}, 33 | mlir::Type type = {}) 34 | : AttributeMetadata{name, builder, type}, 35 | aliasedAttr{aliasedAttr} {} 36 | 37 | inline auto getAliasedAttr() { return aliasedAttr; } 38 | 39 | private: 40 | mlir::Attribute aliasedAttr; 41 | }; 42 | 43 | } // end namespace dmc 44 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/DynamicAttribute.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Metadata.h" 4 | #include "DynamicObject.h" 5 | #include "dmc/Kind.h" 6 | #include "dmc/Spec/ParameterList.h" 7 | 8 | #include 9 | 10 | namespace dmc { 11 | 12 | /// Forward declarations. 13 | class DynamicDialect; 14 | 15 | namespace detail { 16 | struct DynamicAttributeStorage; 17 | } // end namespace detail 18 | 19 | /// DynamicAttribute underlying class. Each dynamic Attribute instance holds 20 | /// a reference to an instance of this class. Implementation details are 21 | /// similar to DynamicType.
22 | class DynamicAttributeImpl : public DynamicObject, public AttributeMetadata { 23 | public: 24 | /// Create a dynamic attribute with the given name and parameter spec. 25 | explicit DynamicAttributeImpl(DynamicDialect *dialect, llvm::StringRef name, 26 | NamedParameterRange paramSpec); 27 | 28 | /// Getters. 29 | inline DynamicDialect *getDialect() { return dialect; } 30 | inline auto getParamSpec() { return paramSpec; } 31 | 32 | /// Delegate parser and printer. 33 | mlir::Attribute parseAttribute(mlir::Location loc, 34 | mlir::DialectAsmParser &parser); 35 | void printAttribute(mlir::Attribute attr, mlir::DialectAsmPrinter &printer); 36 | void setFormat(std::string parserName, std::string printerName); 37 | 38 | private: 39 | /// The dialect to which this attribute belongs. 40 | DynamicDialect *dialect; 41 | /// The dynamic attribute is formed by composing other attributes. The 42 | /// attributes must be Spec attributes. 43 | NamedParameterRange paramSpec; 44 | 45 | /// The function names of the custom parser and printer, if present. 46 | llvm::Optional parserFcn, printerFcn; 47 | 48 | friend class DynamicAttribute; 49 | }; 50 | 51 | class DynamicAttribute 52 | : public mlir::Attribute::AttrBase { 54 | /// Provide a single kind for casting to DynamicAttribute. 55 | static constexpr auto DynamicAttributeKind = dmc::Kind::FIRST_DYNAMIC_ATTR; 56 | 57 | public: 58 | using Base::Base; 59 | 60 | /// Get a DynamicAttribute with a backing DynamicAttributeImpl and parameter 61 | /// values. 62 | static DynamicAttribute get(DynamicAttributeImpl *impl, 63 | llvm::ArrayRef params); 64 | static DynamicAttribute getChecked( 65 | mlir::Location loc, DynamicAttributeImpl *impl, 66 | llvm::ArrayRef params); 67 | /// Verify that the parameter attributes are valid. 68 | static mlir::LogicalResult verifyConstructionInvariants( 69 | mlir::Location loc, DynamicAttributeImpl *impl, 70 | llvm::ArrayRef params); 71 | 72 | /// Allow casting Attribute to DynamicAttribute.
73 | static bool kindof(unsigned kind) { return kind == DynamicAttributeKind; } 74 | 75 | /// Getters. 76 | DynamicAttributeImpl *getDynImpl(); 77 | llvm::ArrayRef getParams(); 78 | }; 79 | 80 | } // end namespace dmc 81 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/DynamicContext.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "TypeIDAllocator.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace dmc { 11 | 12 | /// Forward declarations. 13 | class DynamicDialect; 14 | 15 | /// Manages the creation and lifetime of dynamic MLIR objects: 16 | /// Dialects, Operations, Types, and Attributes. 17 | class DynamicContext : public mlir::Dialect { 18 | public: 19 | ~DynamicContext(); 20 | DynamicContext(mlir::MLIRContext *ctx); 21 | 22 | /// Subclass dialect to make the dynamic context globally accessible. 23 | static llvm::StringRef getDialectNamespace() { return "dyn_context"; } 24 | 25 | /// Getters. 26 | TypeIDAllocator *getTypeIDAlloc() { return typeIdAlloc; } 27 | 28 | /// Create a DynamicDialect and return an instance registered with 29 | /// the MLIRContext. 30 | DynamicDialect *createDynamicDialect(llvm::StringRef name); 31 | /// Look up the dynamic dialect belonging to a dynamic MLIR object. This is 32 | /// necessary since aliased types and attributes do subclass a generic class. 33 | /// 34 | /// TODO This is not an ideal solution.
35 | DynamicDialect *lookupDialectFor(mlir::Type type); 36 | DynamicDialect *lookupDialectFor(mlir::Attribute attr); 37 | DynamicDialect *lookupDialectFor(mlir::OperationName opName); 38 | mlir::LogicalResult registerDialectSymbol(DynamicDialect *dialect, 39 | mlir::Type type); 40 | mlir::LogicalResult registerDialectSymbol(DynamicDialect *dialect, 41 | mlir::Attribute attr); 42 | mlir::LogicalResult registerDialectSymbol(DynamicDialect *dialect, 43 | mlir::OperationName opName); 44 | 45 | private: 46 | class Impl; 47 | TypeIDAllocator *typeIdAlloc; 48 | std::unique_ptr impl; 49 | }; 50 | 51 | } // end namespace dmc 52 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/DynamicObject.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace dmc { 6 | 7 | /// Forward declarations. 8 | class DynamicContext; 9 | 10 | /// TypeIDs are not associated with a class type but are assigned to an instance 11 | /// of a dynamic object that mocks an otherwise statically known class. 12 | class DynamicObject { 13 | public: 14 | explicit DynamicObject(DynamicContext *ctx); 15 | 16 | inline DynamicContext *getDynContext() const { return ctx; } 17 | inline mlir::TypeID getTypeID() const { return typeId; } 18 | 19 | private: 20 | DynamicContext *ctx; 21 | mlir::TypeID typeId; 22 | }; 23 | 24 | } // end namespace dmc 25 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/DynamicType.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Metadata.h" 4 | #include "DynamicObject.h" 5 | #include "dmc/Kind.h" 6 | #include "dmc/Spec/ParameterList.h" 7 | 8 | #include 9 | 10 | namespace dmc { 11 | 12 | /// Forward declarations.
13 | class DynamicDialect; 14 | 15 | namespace detail{ 16 | struct DynamicTypeStorage; 17 | } // end namespace detail 18 | 19 | /// DynamicType underlying class. The class stores type class functions like 20 | /// the parser, printer, and conversions. Each dynamic Type instance holds a 21 | /// reference to an instance of this class. 22 | class DynamicTypeImpl : public DynamicObject, public TypeMetadata { 23 | public: 24 | /// Create a dynamic type with the provided name and parameter spec. 25 | explicit DynamicTypeImpl(DynamicDialect *dialect, llvm::StringRef name, 26 | NamedParameterRange paramSpec); 27 | 28 | /// Getters. 29 | inline DynamicDialect *getDialect() { return dialect; } 30 | inline auto getParamSpec() { return paramSpec; } 31 | 32 | /// Delegate parser and printer. 33 | mlir::Type parseType(mlir::Location loc, mlir::DialectAsmParser &parser); 34 | void printType(mlir::Type type, mlir::DialectAsmPrinter &printer); 35 | void setFormat(std::string parserName, std::string printerName); 36 | 37 | private: 38 | /// The dialect to which this type belongs. 39 | DynamicDialect *dialect; 40 | /// The parameters are defined by Attribute constraints. The Attribute 41 | /// instances must be Spec attributes. 42 | NamedParameterRange paramSpec; 43 | 44 | /// The function names of the custom parser and printer, if present. 45 | llvm::Optional parserFcn, printerFcn; 46 | 47 | friend class DynamicType; 48 | }; 49 | 50 | /// DynamicType class. Stores parameters according to DynamicTypeImpl. 51 | class DynamicType : public mlir::Type::TypeBase { 53 | /// Static casting with DynamicType doesn't make sense, so provide 54 | /// a single kind for casting to DynamicType. 55 | static constexpr auto DynamicTypeKind = dmc::Kind::FIRST_DYNAMIC_TYPE; 56 | 57 | public: 58 | using Base::Base; 59 | 60 | /// Get a DynamicType with a backing DynamicTypeImpl and provided parameter 61 | /// values.
62 | static DynamicType get(DynamicTypeImpl *impl, 63 | llvm::ArrayRef params); 64 | static DynamicType getChecked(mlir::Location loc, DynamicTypeImpl *impl, 65 | llvm::ArrayRef params); 66 | /// Verify that the parameter attributes are valid. 67 | static mlir::LogicalResult verifyConstructionInvariants( 68 | mlir::Location loc, DynamicTypeImpl *impl, 69 | llvm::ArrayRef params); 70 | 71 | /// Allow casting of Type to DynamicType. 72 | static bool kindof(unsigned kind) { return kind == DynamicTypeKind; } 73 | 74 | /// Getters. 75 | DynamicTypeImpl *getDynImpl(); 76 | llvm::ArrayRef getParams(); 77 | mlir::Attribute getParam(llvm::StringRef name); 78 | }; 79 | 80 | DynamicType buildDynamicType( 81 | llvm::StringRef dialectName, llvm::StringRef typeName, 82 | llvm::ArrayRef params, mlir::Location loc); 83 | mlir::Type getAliasedType(llvm::StringRef dialectName, llvm::StringRef typeName, 84 | mlir::MLIRContext *ctx); 85 | 86 | } // end namespace dmc 87 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/Metadata.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /// A structure to store metadata for dynamic types and attributes, such as 6 | /// constant builder strings and names. This is to provide a unified source of 7 | /// metadata across aliases and concrete types. 8 | /// 9 | /// The metadata will have to be queried directly from the dynamic dialect. 10 | namespace dmc { 11 | 12 | class TypeMetadata { 13 | public: 14 | explicit TypeMetadata(llvm::StringRef name, 15 | llvm::Optional builder) 16 | : name{name}, 17 | builder{builder} {} 18 | 19 | inline auto getName() { return name; } 20 | inline auto getBuilder() { return builder; } 21 | 22 | private: 23 | /// The name of the type. 24 | llvm::StringRef name; 25 | /// An optional Python builder. 
26 | llvm::Optional builder; 27 | }; 28 | 29 | class AttributeMetadata { 30 | public: 31 | explicit AttributeMetadata(llvm::StringRef name, 32 | llvm::Optional builder, 33 | mlir::Type type) 34 | : name{name}, 35 | builder{builder}, 36 | type{type} {} 37 | 38 | inline auto getName() { return name; } 39 | inline auto getBuilder() { return builder; } 40 | inline auto getType() { return type; } 41 | 42 | private: 43 | /// The name of the attribute. 44 | llvm::StringRef name; 45 | /// An optional Python builder. 46 | llvm::Optional builder; 47 | /// An optional attribute type. 48 | mlir::Type type; 49 | }; 50 | 51 | } // end namespace dmc 52 | -------------------------------------------------------------------------------- /include/dmc/Dynamic/TypeIDAllocator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace dmc { 6 | 7 | /// MLIR relies on static type IDs of classes, such as Dialect, Type, 8 | /// and Attribute, to manage objects. Since we are dynamically creating 9 | /// objects, we need to dynamically allocate TypeIDs. 
10 | class TypeIDAllocator { 11 | public: 12 | /// Polymorphic base: a virtual destructor makes destroying an implementation 13 | /// through a `TypeIDAllocator *` well-defined. 14 | virtual ~TypeIDAllocator() = default; 15 | 16 | /// Allocate a TypeID for a dynamically created object. 17 | virtual mlir::TypeID allocateID() = 0; 18 | }; 19 | 20 | TypeIDAllocator *getFixedTypeIDAllocator(); 21 | 22 | } // end namespace dmc 23 | -------------------------------------------------------------------------------- /include/dmc/Embed/Constraints.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace dmc { 9 | namespace py { 10 | 11 | mlir::LogicalResult registerConstraint(mlir::Location loc, llvm::StringRef expr, 12 | std::string &funcName); 13 | 14 | mlir::LogicalResult evalConstraint(const std::string &funcName, 15 | mlir::Type type); 16 | mlir::LogicalResult evalConstraint(const std::string &funcName, 17 | mlir::Attribute attr); 18 | 19 | } // end namespace py 20 | } // end namespace dmc 21 | -------------------------------------------------------------------------------- /include/dmc/Embed/Expose.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace dmc { 7 | class DynamicDialect; 8 | namespace py { 9 | void exposeDialectInternal(DynamicDialect *dialect, 10 | llvm::ArrayRef scope); 11 | } // end namespace py 12 | } // end namespace dmc 13 | -------------------------------------------------------------------------------- /include/dmc/Embed/InMemoryDef.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "PythonGen.h" 4 | 5 | #include 6 | 7 | namespace pybind11 { 8 | class module; 9 | } 10 | 11 | namespace dmc { 12 | namespace py { 13 | 14 | class InMemoryStream { 15 | public: 16 | inline PythonGenStream &stream() { return pgs; } 17 | inline const std::string &str() { return os.str(); } 18 | 19 | protected: 20 | std::string buf; 21 | llvm::raw_string_ostream os{buf}; 22 | PythonGenStream pgs{os}; 23 | }; 24 | 25 | class InMemoryDef :
public InMemoryStream { 26 | public: 27 | explicit InMemoryDef(llvm::StringRef fcnName, llvm::StringRef fcnSig); 28 | ~InMemoryDef(); 29 | }; 30 | 31 | class InMemoryClass : public InMemoryStream { 32 | public: 33 | explicit InMemoryClass( 34 | llvm::StringRef clsName, llvm::ArrayRef parentCls, 35 | pybind11::module &m); 36 | ~InMemoryClass(); 37 | 38 | private: 39 | pybind11::module &m; 40 | }; 41 | 42 | } // end namespace py 43 | } // end namespace dmc 44 | -------------------------------------------------------------------------------- /include/dmc/Embed/Init.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace mlir { 4 | class MLIRContext; 5 | namespace py { 6 | void init(MLIRContext *ctx); 7 | } // end namespace py 8 | } // end namespace mlir 9 | -------------------------------------------------------------------------------- /include/dmc/Embed/OpFormatGen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "PythonGen.h" 4 | #include "dmc/Spec/SpecOps.h" 5 | 6 | mlir::LogicalResult generateOpFormat(dmc::OperationOp op, 7 | dmc::py::PythonGenStream &parserOs, 8 | dmc::py::PythonGenStream &printerOs); 9 | -------------------------------------------------------------------------------- /include/dmc/Embed/ParserPrinter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace mlir { 7 | class OpAsmParser; 8 | class OpAsmPrinter; 9 | class Operation; 10 | struct OperationState; 11 | class DialectAsmParser; 12 | class DialectAsmPrinter; 13 | class Attribute; 14 | } // end namespace mlir 15 | 16 | namespace dmc { 17 | class DynamicOperation; 18 | namespace py { 19 | bool execParser(const std::string &name, mlir::OpAsmParser &parser, 20 | mlir::OperationState &result); 21 | void execPrinter(const std::string &name, mlir::OpAsmPrinter &printer, 22 | 
mlir::Operation *op, DynamicOperation *spec); 23 | bool execParser(const std::string &name, mlir::DialectAsmParser &parser, 24 | std::vector &result); 25 | template 26 | void execPrinter(const std::string &name, mlir::DialectAsmPrinter &printer, 27 | DynamicT type); 28 | } // end namespace py 29 | } // end namespace dmc 30 | -------------------------------------------------------------------------------- /include/dmc/Embed/PythonGen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace dmc { 7 | namespace py { 8 | 9 | class PythonGenStream { 10 | public: 11 | class Line { 12 | public: 13 | template Line &operator<<(ArgT &&arg) & { 14 | s.os << std::forward(arg); 15 | return *this; 16 | } 17 | 18 | template Line &&operator<<(ArgT &&arg) && { 19 | return std::move(operator<<(std::forward(arg))); 20 | } 21 | 22 | inline Line &operator<<(PythonGenStream &(*fcn)(PythonGenStream &)) & { 23 | fcn(s); 24 | return *this; 25 | } 26 | 27 | inline Line &&operator<<(PythonGenStream &(*fcn)(PythonGenStream &)) && { 28 | return std::move(operator<<(fcn)); 29 | } 30 | 31 | inline Line &operator<<(std::function fcn) & { 32 | fcn(*this); 33 | return *this; 34 | } 35 | 36 | ~Line(); 37 | Line(Line &&line); 38 | 39 | private: 40 | explicit Line(PythonGenStream &s); 41 | Line(const Line &) = delete; 42 | 43 | PythonGenStream &s; 44 | bool newline; 45 | 46 | friend class PythonGenStream; 47 | }; 48 | 49 | explicit PythonGenStream(llvm::raw_ostream &os); 50 | 51 | Line line(); 52 | 53 | PythonGenStream &block(llvm::StringRef ty, llvm::Twine expr); 54 | PythonGenStream &endblock(); 55 | 56 | inline PythonGenStream &if_(llvm::Twine expr) { 57 | return block("if", expr); 58 | } 59 | inline PythonGenStream &else_() { 60 | endif(); 61 | return block("else", ""); 62 | } 63 | inline PythonGenStream &def(llvm::Twine decl) { 64 | return block("def", decl); 65 | } 66 | 67 | inline PythonGenStream 
&endif() { return endblock(); } 68 | inline PythonGenStream &enddef() { return endblock(); } 69 | 70 | inline PythonGenStream &incr() { 71 | changeIndent(4); 72 | return *this; 73 | } 74 | inline PythonGenStream &decr() { 75 | changeIndent(-4); 76 | return *this; 77 | } 78 | 79 | private: 80 | void changeIndent(int delta); 81 | 82 | llvm::raw_ostream &os; 83 | int indent; 84 | 85 | friend class Line; 86 | friend PythonGenStream &incr(PythonGenStream &); 87 | friend PythonGenStream &decr(PythonGenStream &); 88 | }; 89 | 90 | inline PythonGenStream &incr(PythonGenStream &s) { return s.incr(); } 91 | inline PythonGenStream &decr(PythonGenStream &s) { return s.decr(); } 92 | 93 | } // end namespace py 94 | } // end namespace dmc 95 | -------------------------------------------------------------------------------- /include/dmc/Embed/TypeFormatGen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "PythonGen.h" 4 | #include "dmc/Spec/ParameterList.h" 5 | #include "dmc/Spec/FormatOp.h" 6 | 7 | template 8 | mlir::LogicalResult generateTypeFormat(OpT op, DynamicT *impl, 9 | dmc::py::PythonGenStream &parserOs, 10 | dmc::py::PythonGenStream &printerOs); 11 | -------------------------------------------------------------------------------- /include/dmc/IO/ModuleWriter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "dmc/Dynamic/DynamicContext.h" 8 | 9 | namespace dmc { 10 | 11 | /// Forward declarations. 12 | class DynamicOperation; 13 | 14 | /// This class provides an API for writing DynamicOperations to 15 | /// a single MLIR Module. It hides some of the gritty underworkings. 
16 | class ModuleWriter { 17 | public: 18 | explicit ModuleWriter(DynamicContext *ctx); 19 | 20 | inline mlir::ModuleOp getModule() { return module; } 21 | 22 | mlir::FuncOp createFunction( 23 | llvm::StringRef name, 24 | llvm::ArrayRef argTys, 25 | llvm::ArrayRef retTys); 26 | 27 | private: 28 | mlir::OpBuilder builder; 29 | mlir::ModuleOp module; 30 | }; 31 | 32 | class FunctionWriter { 33 | public: 34 | explicit FunctionWriter(mlir::FuncOp func); 35 | 36 | inline auto getArguments() { return func.getArguments(); } 37 | 38 | mlir::Operation *createOp( 39 | DynamicOperation *op, mlir::ValueRange args, 40 | llvm::ArrayRef retTys); 41 | 42 | mlir::Operation *createOp( 43 | llvm::StringRef name, mlir::ValueRange args, 44 | llvm::ArrayRef retTys); 45 | 46 | private: 47 | mlir::OpBuilder builder; 48 | mlir::FuncOp func; 49 | 50 | mlir::Block *entryBlock; 51 | 52 | /// Create a generic Operation. 53 | mlir::Operation *createOp( 54 | mlir::OperationName opName, mlir::ValueRange args, 55 | llvm::ArrayRef retTys); 56 | }; 57 | 58 | } // end namespace dmc 59 | -------------------------------------------------------------------------------- /include/dmc/Kind.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace dmc { 7 | namespace Kind { 8 | 9 | constexpr auto FIRST_SPEC_TYPE = 10 | mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE; 11 | constexpr auto FIRST_SPEC_ATTR = 12 | mlir::Attribute::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_ATTR; 13 | 14 | constexpr auto FIRST_DYNAMIC_TYPE = 15 | mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_1_TYPE; 16 | constexpr auto FIRST_DYNAMIC_ATTR = 17 | mlir::Attribute::Kind::FIRST_PRIVATE_EXPERIMENTAL_1_ATTR; 18 | 19 | } // end namespace Kind 20 | } // end namespace dmc 21 | -------------------------------------------------------------------------------- /include/dmc/Python/APComplex.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | /// Common type rebinds. 8 | namespace pybind11 { 9 | namespace detail { 10 | 11 | // warning: precision loss ahead 12 | // std::complex 13 | template <> struct type_caster> { 14 | bool load(handle src, bool convert) { 15 | if (!src) 16 | return false; 17 | if (!convert && !PyComplex_Check(src.ptr())) 18 | return false; 19 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 20 | if (result.real == -1.0 && PyErr_Occurred()) { 21 | PyErr_Clear(); 22 | return false; 23 | } 24 | /// Store to 64-bit integer 25 | using storage_t = uint64_t; 26 | auto realVal = static_cast(result.real); 27 | auto imagVal = static_cast(result.imag); 28 | constexpr unsigned bitWidth = sizeof(storage_t) * CHAR_BIT; 29 | value = std::complex{{bitWidth, llvm::makeArrayRef(realVal)}, 30 | {bitWidth, llvm::makeArrayRef(imagVal)}}; 31 | return true; 32 | } 33 | 34 | static handle cast(const std::complex &src, 35 | return_value_policy /* policy */, 36 | handle /* parent */) { 37 | return PyComplex_FromDoubles(src.real().getZExtValue(), 38 | src.imag().getZExtValue()); 39 | } 40 | 41 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 42 | }; 43 | 44 | // std::complex 45 | template <> struct type_caster> { 46 | bool load(handle src, bool convert) { 47 | if (!src) 48 | return false; 49 | if (!convert && !PyComplex_Check(src.ptr())) 50 | return false; 51 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 52 | if (result.real == -1.0 && PyErr_Occurred()) { 53 | PyErr_Clear(); 54 | return false; 55 | } 56 | /// Store to double. 
57 | value = std::complex{llvm::APFloat{result.real}, 58 | llvm::APFloat{result.imag}}; 59 | return true; 60 | } 61 | 62 | static handle cast(const std::complex &src, 63 | return_value_policy /* policy */, 64 | handle /* parent */) { 65 | return PyComplex_FromDoubles(src.real().convertToDouble(), 66 | src.imag().convertToDouble()); 67 | } 68 | 69 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 70 | }; 71 | 72 | } // end namespace detail 73 | } // end namespace pybind11 74 | -------------------------------------------------------------------------------- /include/dmc/Python/DialectAsm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dmc/Spec/ParameterList.h" 4 | 5 | #include 6 | #include 7 | 8 | namespace dmc { 9 | class DynamicType; 10 | class DynamicAttribute; 11 | namespace py { 12 | 13 | class TypeWrap { 14 | public: 15 | explicit TypeWrap(DynamicType type); 16 | explicit TypeWrap(DynamicAttribute attr); 17 | 18 | auto getParams() { return params; } 19 | auto getSpec() { return paramSpec; } 20 | 21 | private: 22 | llvm::ArrayRef params; 23 | NamedParameterRange paramSpec; 24 | }; 25 | 26 | class TypeResultWrap { 27 | public: 28 | explicit TypeResultWrap(std::vector &result) 29 | : result{result} {} 30 | 31 | auto &getImpl() { return result; } 32 | 33 | private: 34 | std::vector &result; 35 | }; 36 | 37 | } // end namespace py 38 | } // end namespace dmc 39 | -------------------------------------------------------------------------------- /include/dmc/Python/OpAsm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace dmc { 7 | class DynamicOperation; 8 | class TypeConstraintTrait; 9 | class AttrConstraintTrait; 10 | class RegionConstraintTrait; 11 | class SuccessorConstraintTrait; 12 | namespace py { 13 | 14 | class OperationWrap { 15 | public: 16 | explicit OperationWrap(mlir::Operation *op, 
DynamicOperation *spec); 17 | 18 | auto *getOp() { return op; } 19 | auto *getSpec() { return spec; } 20 | 21 | mlir::Value getOperand(std::string name); 22 | mlir::Value getResult(std::string name); 23 | mlir::ValueRange getOperandGroup(std::string name); 24 | mlir::ValueRange getResultGroup(std::string name); 25 | mlir::Region &getRegion(std::string name); 26 | 27 | mlir::Value getOperandOrResult(llvm::StringRef name); 28 | mlir::ValueRange getOperandOrResultGroup(llvm::StringRef name); 29 | 30 | private: 31 | mlir::Operation *op; 32 | DynamicOperation *spec; 33 | TypeConstraintTrait *type; 34 | AttrConstraintTrait *attr; 35 | SuccessorConstraintTrait *succ; 36 | RegionConstraintTrait *region; 37 | }; 38 | 39 | } // end namespace py 40 | } // end namespace dmc 41 | -------------------------------------------------------------------------------- /include/dmc/Python/PyMLIR.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Polymorphic.h" 4 | 5 | #include 6 | 7 | namespace mlir { 8 | class MLIRContext; 9 | namespace py { 10 | void getModule(pybind11::module &m); 11 | void setMLIRContext(MLIRContext *ctx); 12 | MLIRContext *getMLIRContext(); 13 | } // end namespace py 14 | } // end namespace mlir 15 | -------------------------------------------------------------------------------- /include/dmc/Spec/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS ParameterList.td) 2 | mlir_tablegen(ParameterList.h.inc -gen-op-interface-decls) 3 | mlir_tablegen(ParameterList.cpp.inc -gen-op-interface-defs) 4 | add_public_tablegen_target(DMCParameterListIncGen) 5 | 6 | set(LLVM_TARGET_DEFINITIONS ReparseOpInterface.td) 7 | mlir_tablegen(ReparseOpInterface.h.inc -gen-op-interface-decls) 8 | mlir_tablegen(ReparseOpInterface.cpp.inc -gen-op-interface-defs) 9 | add_public_tablegen_target(DMCReparseOpInterfaceIncGen) 10 | 11 | 
set(LLVM_TARGET_DEFINITIONS FormatOp.td) 12 | mlir_tablegen(FormatOp.h.inc -gen-op-interface-decls) 13 | mlir_tablegen(FormatOp.cpp.inc -gen-op-interface-defs) 14 | add_public_tablegen_target(DMCFormatOpIncGen) 15 | -------------------------------------------------------------------------------- /include/dmc/Spec/DialectGen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecOps.h" 4 | #include "dmc/Dynamic/DynamicContext.h" 5 | 6 | #include 7 | 8 | namespace dmc { 9 | 10 | mlir::LogicalResult registerDialect(DialectOp dialectOp, DynamicContext *ctx, 11 | llvm::ArrayRef scope); 12 | mlir::LogicalResult registerAllDialects(mlir::ModuleOp dialects, 13 | DynamicContext *ctx); 14 | 15 | } // end namespace dmc 16 | -------------------------------------------------------------------------------- /include/dmc/Spec/FormatOp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | #include "dmc/Spec/FormatOp.h.inc" 7 | } // end namespace mlir 8 | -------------------------------------------------------------------------------- /include/dmc/Spec/FormatOp.td: -------------------------------------------------------------------------------- 1 | #ifndef FORMAT_OP_TD 2 | #define FORMAT_OP_TD 3 | 4 | include "mlir/IR/OpBase.td" 5 | 6 | def FormatOp : OpInterface<"FormatOp"> { 7 | let description = [{ 8 | This class provides an interface and verifier for operations that define 9 | MLIR objects with an optional format string, for generating parsers and 10 | printers. 11 | }]; 12 | 13 | let methods = [ 14 | InterfaceMethod<[{ Get the format string attribute. 
}], 15 | "StringAttr", "getAssemblyFormat", (ins), [{}], [{ 16 | return this->getOperation()->template getAttrOfType( 17 | getFmtAttrName()); 18 | }] 19 | > 20 | ]; 21 | 22 | let verify = [{ 23 | auto fmtAttr = $_op->getAttr(getFmtAttrName()); 24 | if (fmtAttr && !fmtAttr.isa()) 25 | return $_op->emitOpError("expected attribute '") << getFmtAttrName() 26 | << "' to be a string attribute"; 27 | return success(); 28 | }]; 29 | 30 | let extraTraitClassDeclaration = [{ 31 | static llvm::StringLiteral getFmtAttrName() { return "fmt"; } 32 | }]; 33 | } 34 | 35 | #endif // FORMAT_OP_TD 36 | -------------------------------------------------------------------------------- /include/dmc/Spec/HasChildren.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace dmc { 6 | namespace OpTrait { 7 | 8 | /// Check that an operation is one of the specified types. 9 | namespace detail { 10 | template struct IsOneOfImpl; 11 | 12 | template 13 | struct IsOneOfImpl { 14 | bool operator()(mlir::Operation *op) { 15 | return llvm::isa(op) || IsOneOfImpl{}(op); 16 | } 17 | }; 18 | 19 | template <> struct IsOneOfImpl<> { 20 | bool operator()(mlir::Operation *) { return false; } 21 | }; 22 | } // end namespace detail 23 | 24 | template 25 | bool isOneOf(mlir::Operation *op) { 26 | return detail::IsOneOfImpl{}(op); 27 | } 28 | 29 | /// Print a list of operation names. 
30 | namespace detail { 31 | template struct PrintOpNamesImpl; 32 | 33 | template 34 | struct PrintOpNamesImpl { 35 | void operator()(mlir::InFlightDiagnostic &diag) { 36 | diag << OpType::getOperationName() << ", "; 37 | PrintOpNamesImpl{}(diag); 38 | } 39 | }; 40 | 41 | template 42 | struct PrintOpNamesImpl { 43 | void operator()(mlir::InFlightDiagnostic &diag) { 44 | diag << OpType::getOperationName(); 45 | } 46 | }; 47 | } // end namespace detail 48 | 49 | template 50 | void printOpNames(mlir::InFlightDiagnostic &diag) { 51 | return detail::PrintOpNamesImpl{}(diag); 52 | } 53 | 54 | /// Assert that if an operation has children, the children must each be one of 55 | /// the specified operations. 56 | template 57 | struct HasOnlyChildren { 58 | /// ConcreteType must be an iteratable op. 59 | template 60 | struct Impl : public mlir::OpTrait::TraitBase { 61 | public: 62 | static mlir::LogicalResult verifyTrait(mlir::Operation *op) { 63 | for (auto &child : llvm::cast(op)) { 64 | if (!isOneOf(&child)) { 65 | op->emitOpError("has invalid child operation '") 66 | << child.getName() << "'\n"; 67 | auto diag = child.emitOpError("must be one of [ "); 68 | printOpNames(diag); 69 | return diag << " ]"; 70 | } 71 | } 72 | return mlir::success(); 73 | } 74 | }; 75 | }; 76 | 77 | } // end namespace OpTrait 78 | } // end namespace dmc 79 | -------------------------------------------------------------------------------- /include/dmc/Spec/NamedConstraints.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecKinds.h" 4 | 5 | #include 6 | 7 | namespace dmc { 8 | namespace detail { 9 | struct NamedConstraintStorage; 10 | } // end namespace detail 11 | 12 | struct NamedConstraint { 13 | llvm::StringRef name; 14 | mlir::Attribute attr; 15 | 16 | bool isVariadic() const; 17 | }; 18 | 19 | namespace detail { 20 | inline mlir::Attribute unwrap(const NamedConstraint &a) { return a.attr; } 21 | } // end namespace detail 
22 | 23 | class OpRegion : public mlir::Attribute::AttrBase< 24 | OpRegion, mlir::Attribute, detail::NamedConstraintStorage> { 25 | public: 26 | using Base::Base; 27 | 28 | static OpRegion getChecked(mlir::Location loc, 29 | llvm::ArrayRef regions); 30 | 31 | static bool kindof(unsigned kind) 32 | { return kind == AttrKinds::OpRegion; } 33 | 34 | llvm::ArrayRef getRegions() const; 35 | 36 | inline unsigned getNumRegions() { return std::size(getRegions()); } 37 | inline const NamedConstraint *getRegion(unsigned idx) 38 | { return &getRegions()[idx]; } 39 | inline llvm::StringRef getRegionName(unsigned idx) 40 | { return getRegion(idx)->name; } 41 | inline mlir::Attribute getRegionAttr(unsigned idx) 42 | { return getRegion(idx)->attr; } 43 | 44 | inline auto getRegionAttrs() const 45 | { return llvm::map_range(getRegions(), detail::unwrap); } 46 | 47 | inline auto begin() const { return std::begin(getRegionAttrs()); } 48 | inline auto end() const { return std::end(getRegionAttrs()); } 49 | inline auto size() const { return std::size(getRegions()); } 50 | }; 51 | 52 | class OpSuccessor : public mlir::Attribute::AttrBase< 53 | OpSuccessor, mlir::Attribute, detail::NamedConstraintStorage> { 54 | public: 55 | using Base::Base; 56 | 57 | static OpSuccessor getChecked(mlir::Location loc, 58 | llvm::ArrayRef successors); 59 | 60 | static bool kindof(unsigned kind) 61 | { return kind == AttrKinds::OpSuccessor; } 62 | 63 | llvm::ArrayRef getSuccessors() const; 64 | 65 | inline unsigned getNumSuccessors() { return std::size(getSuccessors()); } 66 | inline const NamedConstraint *getSuccessor(unsigned idx) 67 | { return &getSuccessors()[idx]; } 68 | inline llvm::StringRef getSuccessorName(unsigned idx) 69 | { return getSuccessor(idx)->name; } 70 | inline mlir::Attribute getSuccessorAttr(unsigned idx) 71 | { return getSuccessor(idx)->attr; } 72 | 73 | inline auto getSuccessorAttrs() const 74 | { return llvm::map_range(getSuccessors(), detail::unwrap); } 75 | 76 | inline auto 
begin() const { return std::begin(getSuccessorAttrs()); } 77 | inline auto end() const { return std::end(getSuccessorAttrs()); } 78 | inline auto size() const { return std::size(getSuccessors()); } 79 | }; 80 | 81 | } // end namespace dmc 82 | -------------------------------------------------------------------------------- /include/dmc/Spec/OpType.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecKinds.h" 4 | 5 | #include 6 | 7 | namespace dmc { 8 | namespace detail { 9 | struct OpTypeStorage; 10 | } // end namespace detail 11 | 12 | struct NamedType { 13 | llvm::StringRef name; 14 | mlir::Type type; 15 | 16 | bool isVariadic() const; 17 | }; 18 | 19 | class OpType : public mlir::Type::TypeBase { 21 | public: 22 | using Base::Base; 23 | 24 | static OpType getChecked(mlir::Location loc, 25 | llvm::ArrayRef operands, 26 | llvm::ArrayRef results); 27 | 28 | static bool kindof(unsigned kind) { return kind == TypeKinds::OpType; } 29 | 30 | llvm::ArrayRef getOperands(); 31 | llvm::ArrayRef getResults(); 32 | 33 | inline unsigned getNumOperands() { return std::size(getOperands()); } 34 | inline unsigned getNumResults() { return std::size(getResults()); } 35 | inline const NamedType *getOperand(unsigned idx) 36 | { return &getOperands()[idx]; } 37 | inline const NamedType *getResult(unsigned idx) 38 | { return &getResults()[idx]; } 39 | 40 | inline llvm::StringRef getOperandName(unsigned idx) 41 | { return getOperand(idx)->name; } 42 | inline llvm::StringRef getResultName(unsigned idx) 43 | { return getResult(idx)->name; } 44 | inline mlir::Type getOperandType(unsigned idx) 45 | { return getOperand(idx)->type; } 46 | inline mlir::Type getResultType(unsigned idx) 47 | { return getResult(idx)->type; } 48 | 49 | inline auto getOperandTypes() 50 | { return llvm::map_range(getOperands(), &unwrap); } 51 | inline auto getResultTypes() 52 | { return llvm::map_range(getResults(), &unwrap); } 53 | 54 | inline 
auto operand_begin() { return std::begin(getOperands()); } 55 | inline auto operand_end() { return std::end(getOperands()); } 56 | inline auto result_begin() { return std::begin(getResults()); } 57 | inline auto result_end() { return std::end(getResults()); } 58 | 59 | private: 60 | static mlir::Type unwrap(const NamedType &a) { return a.type; } 61 | }; 62 | 63 | } // end namespace dmc 64 | -------------------------------------------------------------------------------- /include/dmc/Spec/ParameterList.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecKinds.h" 4 | #include "Parsing.h" 5 | 6 | #include 7 | #include 8 | 9 | namespace mlir { 10 | namespace detail { 11 | struct NamedParameterStorage; 12 | } // end namespace detail 13 | 14 | class NamedParameter : public Attribute::AttrBase< 15 | NamedParameter, Attribute, detail::NamedParameterStorage> { 16 | public: 17 | using Base::Base; 18 | 19 | static NamedParameter get(StringRef name, Attribute constraint); 20 | static NamedParameter getChecked(Location loc, StringRef name, 21 | Attribute constraint); 22 | static LogicalResult verifyConstructionInvariants( 23 | Location loc, StringRef name, Attribute constraint); 24 | 25 | static bool kindof(unsigned kind) { 26 | return kind == dmc::AttrKinds::NamedParameter; 27 | } 28 | 29 | StringRef getName() const; 30 | Attribute getConstraint() const; 31 | }; 32 | 33 | using NamedParameterRange = iterator_range::iterator, NamedParameter (*)(Attribute)>>; 35 | 36 | #include "dmc/Spec/ParameterList.h.inc" 37 | 38 | } // end namespace mlir 39 | 40 | namespace dmc { 41 | using NamedParameterRange = mlir::NamedParameterRange; 42 | } // end namespace dmc 43 | -------------------------------------------------------------------------------- /include/dmc/Spec/ParameterList.td: -------------------------------------------------------------------------------- 1 | #ifndef PARAMETER_LIST_TD 2 | #define 
PARAMETER_LIST_TD 3 | 4 | include "mlir/IR/OpBase.td" 5 | 6 | /// Interface for operations that have a parameter list of attributes. 7 | def ParameterList : OpInterface<"ParameterList"> { 8 | let description = [{ 9 | This class provides an interface and verifier for operations that take 10 | a list of SpecAttr parameters: . 11 | }]; 12 | 13 | let methods = [ 14 | InterfaceMethod<[{ 15 | Get the Op's parameter list. 16 | }], 17 | "NamedParameterRange", "getParameters", (ins), [{}], [{ 18 | auto attr = this->getOperation()->template 19 | getAttrOfType(getParametersAttrName()); 20 | NamedParameter (*unwrap)(Attribute) = [](Attribute attr) { 21 | return attr.cast(); 22 | }; 23 | return llvm::map_range(attr.getValue(), unwrap); 24 | }] 25 | >, 26 | InterfaceMethod<[{ 27 | Replace the Op's parameter list. 28 | }], 29 | "void", "setParameters", (ins "ArrayRef":$params), [{}], [{ 30 | auto *op = this->getOperation(); 31 | op->setAttr(getParametersAttrName(), 32 | ArrayAttr::get(params, op->getContext())); 33 | }] 34 | >, 35 | InterfaceMethod<[{ 36 | Print the parameter list. 37 | }], 38 | "void", "printParameters", (ins "OpAsmPrinter &":$printer), [{}], [{ 39 | auto params = dyn_cast(this->getOperation()).getParameters(); 40 | if (!params.empty()) { 41 | printer << '<'; 42 | llvm::interleaveComma(params, printer, [&](Attribute attr) { 43 | auto param = attr.cast(); 44 | printer << param.getName() << ": " << param.getConstraint(); 45 | }); 46 | printer << '>'; 47 | } 48 | }] 49 | >, 50 | ]; 51 | 52 | let verify = [{ 53 | auto paramsAttr = $_op->getAttrOfType(getParametersAttrName()); 54 | if (!paramsAttr) 55 | return $_op->emitOpError("expected an ArrayAttr named: ") 56 | << getParametersAttrName(); 57 | return success(); 58 | }]; 59 | 60 | let extraTraitClassDeclaration = [{ 61 | static llvm::StringLiteral getParametersAttrName() { return "params"; } 62 | }]; 63 | 64 | /// Parser and printer for parameter lists. 
65 | let extraClassDeclaration = [{ 66 | static ParseResult parse(OpAsmParser &parser, NamedAttrList &attrList); 67 | }]; 68 | } 69 | 70 | #endif // PARAMETER_LIST_TD 71 | -------------------------------------------------------------------------------- /include/dmc/Spec/Parsing.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dmc/Traits/OpTrait.h" 4 | 5 | #include 6 | #include 7 | 8 | namespace dmc { 9 | // Forward declarations 10 | class OpType; 11 | class OpRegion; 12 | class OpSuccessor; 13 | 14 | namespace impl { 15 | 16 | /// Parse a single attribute using an OpAsmParser. 17 | mlir::ParseResult parseSingleAttribute(mlir::OpAsmParser &parser, 18 | mlir::Attribute &attr); 19 | 20 | /// Parse an optional parameter list. 21 | mlir::ParseResult parseOptionalParameterList(mlir::DialectAsmParser &parser, 22 | mlir::ArrayAttr &attr); 23 | mlir::ParseResult parseOptionalParameterList(mlir::OpAsmParser &parser, 24 | mlir::ArrayAttr &attr); 25 | 26 | /// Print a parameter list. 27 | void printOptionalParameterList(mlir::OpAsmPrinter &printer, 28 | llvm::ArrayRef params); 29 | void printOptionalParameterList(mlir::DialectAsmPrinter &printer, 30 | llvm::ArrayRef params); 31 | 32 | /// Parse and print an op trait list attribute in pretty form. 33 | mlir::ParseResult parseOptionalOpTraitList(mlir::OpAsmParser &parser, 34 | OpTraitsAttr &traitArr); 35 | void printOptionalOpTraitList(mlir::OpAsmPrinter &printer, 36 | OpTraitsAttr traitArr); 37 | 38 | /// Parse and print an op region attribute list. 
39 | mlir::ParseResult parseOpRegion(mlir::OpAsmParser &parser, 40 | mlir::Attribute &opRegion); 41 | void printOpRegion(llvm::raw_ostream &os, mlir::Attribute opRegion); 42 | mlir::ParseResult parseOptionalRegionList(mlir::OpAsmParser &parser, 43 | OpRegion &opRegion); 44 | template 45 | void printOptionalRegionList(PrinterT &printer, OpRegion opRegion); 46 | 47 | /// Parse and print an op successor attribute list. 48 | mlir::ParseResult parseOpSuccessor(mlir::OpAsmParser &parser, 49 | mlir::Attribute &opSucc); 50 | void printOpSuccessor(llvm::raw_ostream &os, mlir::Attribute opSucc); 51 | mlir::ParseResult parseOptionalSuccessorList(mlir::OpAsmParser &parser, 52 | OpSuccessor &opSucc); 53 | template 54 | void printOptionalSuccessorList(PrinterT &printer, OpSuccessor opSucc); 55 | 56 | /// Parse and print a list of integers, which may be empty. 57 | /// int-list ::= (int (`,` int)*)? 58 | template 59 | mlir::ParseResult parseIntegerList(mlir::DialectAsmParser &parser, 60 | ListT &ints) { 61 | std::remove_reference_t().front())> val; 62 | auto ret = parser.parseOptionalInteger(val); 63 | if (ret.hasValue()) { // tri-state 64 | if (*ret) // failed to parse integer 65 | return mlir::failure(); 66 | ints.push_back(val); 67 | while (!parser.parseOptionalComma()) { 68 | if (parser.parseInteger(val)) 69 | return mlir::failure(); 70 | ints.push_back(val); 71 | } 72 | } 73 | return mlir::success(); 74 | } 75 | 76 | template 77 | void printIntegerList(mlir::DialectAsmPrinter &printer, 78 | ListT &ints) { 79 | llvm::interleaveComma(ints, printer, [&](auto val) { printer << val; }); 80 | } 81 | 82 | /// Parse and print an OpType. 
83 | mlir::ParseResult parseOpType(mlir::OpAsmParser &parser, OpType &opType); 84 | template void printOpType(PrinterT &printer, OpType opType); 85 | 86 | } // end namespace impl 87 | } // end namespace dmc 88 | -------------------------------------------------------------------------------- /include/dmc/Spec/ReparseOpInterface.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | #include "dmc/Spec/ReparseOpInterface.h.inc" 7 | } // end namespace mlir 8 | -------------------------------------------------------------------------------- /include/dmc/Spec/ReparseOpInterface.td: -------------------------------------------------------------------------------- 1 | #ifndef REPARSE_OP_INTERFACE_TD 2 | #define REPARSE_OP_INTERFACE_TD 3 | 4 | include "mlir/IR/OpBase.td" 5 | 6 | /// An interface that allows operations to reparse themselves to resolve 7 | /// opaque objects. 8 | /// 9 | /// This is the second-best solution after incremental parsing: effectively, 10 | /// interpreting MLIR as it is parsed. This would allow for full on-the-fly 11 | /// verification and resolution of symbol references and otherwise opaque 12 | /// values with full error reporting. 13 | /// 14 | /// However, the parser API is not amenable to this solution, so a reparsable 15 | /// interface will have to suffice. 16 | def ReparseOpInterface : OpInterface<"ReparseOpInterface"> { 17 | let description = [{ 18 | Because dynamic objects are registered after first-parse of a MLIR module, 19 | dynamic types and attributes, for example, remain opaque and must be 20 | resolved during second pass. 21 | 22 | Lacking a generic interface to traverse operation types and attributes, 23 | however, the best we can do is ask the Spec operations to reparse parts 24 | of themselves that might contain opaque objects. 25 | }]; 26 | 27 | let methods = [ 28 | InterfaceMethod<[{ 29 | Request that the operation reparse itself. 
30 | }], 31 | "ParseResult", "reparse", (ins) 32 | >, 33 | ]; 34 | } 35 | 36 | #endif // REPARSE_OP_INTERFACE_TD 37 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecAttrBase.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /// Make these definitions public so that other constraint kinds can use them. 6 | namespace dmc { 7 | namespace detail { 8 | 9 | /// OneAttrStorage implementation. Store one attribute. 10 | struct OneAttrStorage : public mlir::AttributeStorage { 11 | using KeyTy = mlir::Attribute; 12 | 13 | inline explicit OneAttrStorage(KeyTy key) : attr{key} {} 14 | inline bool operator==(const KeyTy &key) const { return key == attr; } 15 | static llvm::hash_code hashKey(const KeyTy &key) { 16 | return hash_value(key); 17 | } 18 | 19 | static OneAttrStorage *construct(mlir::AttributeStorageAllocator &alloc, 20 | const KeyTy &key) { 21 | return new (alloc.allocate()) OneAttrStorage{key}; 22 | } 23 | 24 | KeyTy attr; 25 | }; 26 | 27 | } // end namespace detail 28 | } // end namespace dmc 29 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecAttrDetail.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecAttrImplementation.h" 4 | #include "SpecTypeDetail.h" 5 | 6 | namespace dmc { 7 | 8 | /// Place full declaration in header to allow template usage. 
9 | namespace detail {
10 |
11 | struct TypedAttrStorage : public mlir::AttributeStorage { // Storage keyed on a single mlir::Type; members are only declared here and defined out of line.
12 | using KeyTy = mlir::Type;
13 |
14 | explicit TypedAttrStorage(KeyTy key);
15 | bool operator==(const KeyTy &key) const;
16 | static llvm::hash_code hashKey(const KeyTy &key);
17 | static TypedAttrStorage *construct(
18 | mlir::AttributeStorageAllocator &alloc, const KeyTy &key);
19 |
20 | KeyTy type;
21 | };
22 |
23 | } // end namespace detail
24 |
25 | /// AttrConstraint on an IntegerAttr with a specified underlying Type.
26 | template // NOTE(review): the template parameter list (ConcreteType, Kind, UnderlyingT are used below) appears stripped in this dump.
28 | class TypedAttrBase
29 | : public SpecAttr {
30 | public:
31 | using Base = TypedAttrBase;
32 | using Parent = SpecAttr;
33 | using Underlying = UnderlyingT;
34 | using Parent::Parent;
35 |
36 | static ConcreteType getChecked( // Checked construction of the constraint wrapping `ty`.
37 | mlir::Location loc, UnderlyingT ty) {
38 | return Parent::getChecked(loc, Kind, ty);
39 | }
40 |
41 | static mlir::LogicalResult verifyConstructionInvariants( // Rejects a null underlying type.
42 | mlir::Location loc, UnderlyingT ty) {
43 | if (!ty)
44 | return mlir::emitError(loc) << "Type cannot be null";
45 | return mlir::success();
46 | }
47 |
48 | mlir::LogicalResult verify(mlir::Attribute attr) { // attr must be of the expected attribute class and its type must satisfy the stored type constraint.
49 | return mlir::success(attr.isa() &&
50 | mlir::succeeded(this->getImpl()->type.template cast()
51 | .verify(attr.cast().getType())));
52 | }
53 |
54 | static mlir::Attribute parse(mlir::DialectAsmParser &parser) { // Syntax: `<` integer-width `>`.
55 | auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
56 | unsigned width;
57 | if (parser.parseLess() || parser.parseInteger(width) ||
58 | parser.parseGreater())
59 | return {};
60 | return getChecked(loc, UnderlyingT::getChecked(loc, width));
61 | }
62 |
63 | void print(mlir::DialectAsmPrinter &printer) { // Prints the attribute name followed by `<width>`.
64 | printer << ConcreteType::getAttrName() << '<'
65 | << this->getImpl()->type.template cast().getWidth() << '>';
66 | }
67 | };
68 |
69 | } // end namespace dmc
70 |
--------------------------------------------------------------------------------
/include/dmc/Spec/SpecAttrImplementation.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "SpecKinds.h"
4 |
5 | #include
6 | #include
7 | #include
8 |
9 | namespace dmc {
10 |
11 | namespace SpecAttrs {
12 | bool is(mlir::Attribute base);
13 | mlir::LogicalResult delegateVerify(mlir::Attribute base,
14 | mlir::Attribute attr);
15 | } // end namespace SpecAttrs
16 |
17 | template // NOTE(review): the template parameter list (which must declare SpecKind) appears stripped in this dump.
19 | class SpecAttr
20 | : public mlir::Attribute::AttrBase {
22 | public:
23 | static constexpr auto Kind = SpecKind;
24 |
25 | using Parent = mlir::Attribute::AttrBase;
27 | using Base = SpecAttr;
28 | using Parent::Parent;
29 |
30 | static bool kindof(unsigned kind) { return kind == Kind; } // Matches exactly the kind this SpecAttr was instantiated with.
31 | };
32 |
33 | template
34 | class SimpleAttr : public SpecAttr {
35 | public:
36 | using Parent = SpecAttr;
37 | using Base = SimpleAttr;
38 | using Parent::Parent;
39 |
40 | static ConcreteType get(mlir::MLIRContext *ctx) {
41 | return Parent::get(ctx, Kind);
42 | }
43 | static mlir::Attribute parse(mlir::DialectAsmParser &parser) { // Parameterless: consumes nothing and returns get() for the parser's context.
44 | return get(parser.getBuilder().getContext());
45 | }
46 | void print(mlir::DialectAsmPrinter &printer) { // Prints only the attribute name.
47 | printer << ConcreteType::getAttrName();
48 | }
49 | };
50 |
51 | /// Verify Attribute constraints.
52 | namespace impl {
53 | mlir::LogicalResult verifyAttrConstraints(
54 | mlir::Operation *op, mlir::DictionaryAttr opAttrs);
55 | } // end namespace impl
56 |
57 | } // end namespace dmc
58 |
--------------------------------------------------------------------------------
/include/dmc/Spec/SpecAttrSwitch.h:
--------------------------------------------------------------------------------
1 | #include "SpecAttrs.h" // NOTE(review): unlike the other Spec headers this file has no #pragma once; including it twice in one TU would redefine kindSwitch.
2 |
3 | namespace dmc {
4 | namespace SpecAttrs {
5 |
6 | /// Big switch table: select the constraint class for a SpecAttrs kind and invoke `action`'s templated call operator with it.
7 | template 8 | auto kindSwitch(const ActionT &action, unsigned kind) { 9 | switch (kind) { 10 | default: 11 | return action.template operator()(); 12 | case Bool: 13 | return action.template operator()(); 14 | case Index: 15 | return action.template operator()(); 16 | case APInt: 17 | return action.template operator()(); 18 | case AnyI: 19 | return action.template operator()(); 20 | case I: 21 | return action.template operator()(); 22 | case SI: 23 | return action.template operator()(); 24 | case UI: 25 | return action.template operator()(); 26 | case F: 27 | return action.template operator()(); 28 | case String: 29 | return action.template operator()(); 30 | case Type: 31 | return action.template operator()(); 32 | case Unit: 33 | return action.template operator()(); 34 | case Dictionary: 35 | return action.template operator()(); 36 | case Elements: 37 | return action.template operator()(); 38 | case DenseElements: 39 | return action.template operator()(); 40 | case ElementsOf: 41 | return action.template operator()(); 42 | case RankedElements: 43 | return action.template operator()(); 44 | case StringElements: 45 | return action.template operator()(); 46 | case Array: 47 | return action.template operator()(); 48 | case ArrayOf: 49 | return action.template operator()(); 50 | case SymbolRef: 51 | return action.template operator()(); 52 | case FlatSymbolRef: 53 | return action.template operator()(); 54 | case Constant: 55 | return action.template operator()(); 56 | case AnyOf: 57 | return action.template operator()(); 58 | case AllOf: 59 | return action.template operator()(); 60 | case OfType: 61 | return action.template operator()(); 62 | case Optional: 63 | return action.template operator()(); 64 | case Default: 65 | return action.template operator()(); 66 | case Isa: 67 | return action.template operator()(); 68 | case Py: 69 | return action.template operator()(); 70 | } 71 | } 72 | 73 | template 74 | auto kindSwitch(const ActionT &action, mlir::Attribute base) 
{ 75 | assert(SpecAttrs::is(base) && "Not a SpecAttr"); 76 | KindActionWrapper wrapper{action, base}; 77 | return kindSwitch(wrapper, base.getKind()); 78 | } 79 | 80 | } // end namespace SpecAttrs 81 | } // end namespace dmc 82 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecDialect.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace dmc { 6 | 7 | namespace detail { 8 | struct WidthStorage; 9 | struct WidthListStorage; 10 | } // end namespace detail 11 | 12 | /// This dialect defines an DSL/IR that describes 13 | /// - Dialects and their properties 14 | /// - Operations, their types, operands, results, properties, and traits 15 | /// 16 | /// Some properties and traits/verifiers hook into functions defined natively. 17 | /// If generating from a higher-level DSL (e.g. Python), these may hook into 18 | /// Python functions with MLIR bindings for complete features. 19 | /// 20 | /// Dialect operations, types, and attributes define their own parsing and 21 | /// printing syntax, which is used by the generated dialect. 22 | /// 23 | /// The Spec dialect has the following (planned) Ops: 24 | /// - A Dialect module-level Op that defines a dialect and its properties 25 | /// - An Operation Op that defines individual operations 26 | /// - A Type Op that defines custom types 27 | /// - An Attribute Op that defines custom attributes 28 | /// - Ops to define hooks into higher-level DSLs, e.g. function prototypes for 29 | /// Python hooks 30 | class SpecDialect : public mlir::Dialect { 31 | public: 32 | explicit SpecDialect(mlir::MLIRContext *ctx); 33 | static llvm::StringRef getDialectNamespace() { return "dmc"; } 34 | 35 | /// Custom parser and printer for operand and result type specs. 
36 | mlir::Type parseType(mlir::DialectAsmParser &parser) const override; 37 | void printType(mlir::Type type, 38 | mlir::DialectAsmPrinter &printer) const override; 39 | 40 | /// Custom parser and printer for attributes. 41 | mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser, 42 | mlir::Type type) const override; 43 | void printAttribute(mlir::Attribute attribute, 44 | mlir::DialectAsmPrinter &printer) const override; 45 | }; 46 | 47 | } // end namespace dmc 48 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecKinds.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dmc/Kind.h" 4 | 5 | namespace dmc { 6 | 7 | namespace SpecTypes { 8 | enum Kinds { 9 | Any = Kind::FIRST_SPEC_TYPE, 10 | None, 11 | AnyOf, 12 | AllOf, 13 | 14 | AnyInteger, 15 | AnyI, 16 | AnyIntOfWidths, 17 | 18 | AnySignlessInteger, 19 | I, 20 | SignlessIntOfWidths, 21 | 22 | AnySignedInteger, 23 | SI, 24 | SignedIntOfWidths, 25 | 26 | AnyUnsignedInteger, 27 | UI, 28 | UnsignedIntOfWidths, 29 | 30 | Index, 31 | 32 | AnyFloat, 33 | F, 34 | FloatOfWidths, 35 | BF16, 36 | 37 | AnyComplex, 38 | Complex, 39 | 40 | Opaque, 41 | Function, 42 | 43 | Variadic, // Optional is a subset of Variadic 44 | 45 | Isa, 46 | 47 | /// Generic Python type constraint. 48 | Py, 49 | 50 | LAST_SPEC_TYPE 51 | }; 52 | } // end namespace SpecTypes 53 | 54 | namespace SpecAttrs { 55 | enum Kinds { 56 | Any = Kind::FIRST_SPEC_ATTR, 57 | Bool, 58 | Index, 59 | APInt, 60 | 61 | AnyI, 62 | I, 63 | SI, 64 | UI, 65 | F, 66 | 67 | String, 68 | Type, 69 | Unit, 70 | Dictionary, 71 | Elements, 72 | DenseElements, 73 | ElementsOf, 74 | RankedElements, 75 | StringElements, 76 | Array, 77 | ArrayOf, 78 | 79 | SymbolRef, 80 | FlatSymbolRef, 81 | 82 | Constant, 83 | AnyOf, 84 | AllOf, 85 | OfType, 86 | 87 | Optional, 88 | Default, 89 | 90 | Isa, 91 | 92 | /// Generic Python attribute constraint. 
93 | Py, 94 | 95 | /// Non-attribute-constraint kinds. 96 | OpTrait, 97 | OpTraits, 98 | 99 | LAST_SPEC_ATTR 100 | }; 101 | } // end namespace SpecAttrs 102 | 103 | namespace SpecRegion { 104 | enum Kinds { 105 | Any = SpecAttrs::LAST_SPEC_ATTR, 106 | Sized, 107 | IsolatedFromAbove, 108 | Variadic, 109 | 110 | LAST_SPEC_REGION 111 | }; 112 | } // end namespace SpecRegion 113 | 114 | namespace SpecSuccessor { 115 | enum Kinds { 116 | Any = SpecRegion::LAST_SPEC_REGION, 117 | Variadic, 118 | 119 | LAST_SPEC_SUCCESSOR 120 | }; 121 | } // end namespace SpecSuccessor 122 | 123 | namespace TypeKinds { 124 | enum Kinds { 125 | OpType = SpecTypes::LAST_SPEC_TYPE, 126 | 127 | LAST_KIND 128 | }; 129 | } // end namespace TypeKinds 130 | 131 | namespace AttrKinds { 132 | enum Kinds { 133 | OpRegion= SpecSuccessor::LAST_SPEC_SUCCESSOR, 134 | OpSuccessor, 135 | NamedParameter, 136 | 137 | LAST_KIND 138 | }; 139 | } // end namespace AttrKinds 140 | 141 | } // end namespace dmc 142 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecRegion.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecKinds.h" 4 | #include "SpecAttrBase.h" 5 | 6 | #include 7 | #include 8 | 9 | /// Forward declarations. 10 | namespace mlir { 11 | class OpAsmParser; 12 | class OpAsmPrinter; 13 | }; 14 | 15 | namespace dmc { 16 | class OpRegion; 17 | 18 | namespace SpecRegion { 19 | bool is(mlir::Attribute base); 20 | mlir::LogicalResult delegateVerify(mlir::Attribute base, mlir::Region ®ion); 21 | /// TODO Instead of avoiding Dialect::printAttribute, use it. 22 | std::string toString(mlir::Attribute opRegion); 23 | } // end namespace SpecRegion 24 | 25 | namespace detail { 26 | struct SizedRegionAttrStorage; 27 | } // end namespace detail 28 | 29 | /// Match any region. 
30 | class AnyRegion : public mlir::Attribute::AttrBase<
31 | AnyRegion, mlir::Attribute, mlir::AttributeStorage> {
32 | public:
33 | using Base::Base;
34 | static llvm::StringLiteral getName() { return "Any"; }
35 | static bool kindof(unsigned kind) { return kind == SpecRegion::Any; }
36 |
37 | static AnyRegion get(mlir::MLIRContext *ctx) {
38 | return Base::get(ctx, SpecRegion::Any);
39 | }
40 |
41 | inline mlir::LogicalResult verify(mlir::Region &) { return mlir::success(); } // Accepts every region.
42 |
43 | static Attribute parse(mlir::OpAsmParser &parser);
44 | void print(llvm::raw_ostream &os);
45 | };
46 |
47 | /// Match a region with a given number of blocks.
48 | class SizedRegion : public mlir::Attribute::AttrBase<
49 | SizedRegion, mlir::Attribute, detail::SizedRegionAttrStorage> {
50 | public:
51 | using Base::Base;
52 | static llvm::StringLiteral getName() { return "Sized"; }
53 | static bool kindof(unsigned kind) { return kind == SpecRegion::Sized; }
54 |
55 | static SizedRegion getChecked(mlir::Location loc, unsigned size);
56 | static mlir::LogicalResult verifyConstructionInvariants(mlir::Location loc,
57 | unsigned size);
58 |
59 | mlir::LogicalResult verify(mlir::Region &region); // Defined out of line; checks `region` against the configured size.
60 |
61 | static Attribute parse(mlir::OpAsmParser &parser);
62 | void print(llvm::raw_ostream &os);
63 | };
64 |
65 | /// Match a region isolated from above.
66 | class IsolatedFromAboveRegion : public mlir::Attribute::AttrBase<
67 | IsolatedFromAboveRegion, mlir::Attribute, mlir::AttributeStorage> {
68 | public:
69 | using Base::Base;
70 | static llvm::StringLiteral getName() { return "IsolatedFromAbove"; }
71 | static bool kindof(unsigned kind)
72 | { return kind == SpecRegion::IsolatedFromAbove; }
73 |
74 | static IsolatedFromAboveRegion get(mlir::MLIRContext *ctx) {
75 | return Base::get(ctx, SpecRegion::IsolatedFromAbove);
76 | }
77 |
78 | mlir::LogicalResult verify(mlir::Region &region); // Defined out of line.
79 |
80 | static Attribute parse(mlir::OpAsmParser &parser);
81 | void print(llvm::raw_ostream &os);
82 | };
83 |
84 | /// Variadic regions: wraps a single region constraint attribute, stored in detail::OneAttrStorage.
85 | class VariadicRegion : public mlir::Attribute::AttrBase<
86 | VariadicRegion, mlir::Attribute, detail::OneAttrStorage> {
87 | public:
88 | using Base::Base;
89 | static llvm::StringLiteral getName() { return "Variadic"; }
90 | static bool kindof(unsigned kind) { return kind == SpecRegion::Variadic; }
91 |
92 | static VariadicRegion getChecked(mlir::Location loc,
93 | mlir::Attribute regionConstraint);
94 | static mlir::LogicalResult verifyConstructionInvariants(
95 | mlir::Location loc, mlir::Attribute regionConstraint);
96 |
97 | mlir::LogicalResult verify(mlir::Region &region); // Defined out of line.
98 |
99 | static Attribute parse(mlir::OpAsmParser &parser);
100 | void print(llvm::raw_ostream &os);
101 | };
102 |
103 | /// Verify Region constraints.
104 | namespace impl { 105 | mlir::LogicalResult verifyRegionConstraints( 106 | mlir::Operation *op, OpRegion opRegions); 107 | } // end namespace impl 108 | 109 | } // end namespace dmc 110 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecRegionSwitch.h: -------------------------------------------------------------------------------- 1 | #include "SpecRegion.h" 2 | #include "Support.h" 3 | 4 | namespace dmc { 5 | namespace SpecRegion { 6 | 7 | template 8 | auto kindSwitch(const ActionT &action, unsigned kind) { 9 | switch (kind) { 10 | default: 11 | return action.template operator()(); 12 | case Sized: 13 | return action.template operator()(); 14 | case IsolatedFromAbove: 15 | return action.template operator()(); 16 | case Variadic: 17 | return action.template operator()(); 18 | } 19 | } 20 | 21 | template 22 | auto kindSwitch(const ActionT &action, mlir::Attribute base) { 23 | assert(SpecRegion::is(base) && "Not a SpecRegion"); 24 | KindActionWrapper wrapper{action, base}; 25 | return kindSwitch(wrapper, base.getKind()); 26 | } 27 | 28 | } // end namespace SpecRegion 29 | } // end namespace dmc 30 | -------------------------------------------------------------------------------- /include/dmc/Spec/SpecSuccessor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "SpecKinds.h" 4 | #include "SpecAttrBase.h" 5 | 6 | #include 7 | 8 | /// Forward declarations. 9 | namespace mlir { 10 | class OpAsmParser; 11 | class OpAsmPrinter; 12 | }; 13 | 14 | namespace dmc { 15 | class OpSuccessor; 16 | 17 | namespace SpecSuccessor { 18 | bool is(mlir::Attribute base); 19 | mlir::LogicalResult delegateVerify(mlir::Attribute base, mlir::Block *block); 20 | /// TODO Instead of avoiding Dialect::printAttribute, use it. 21 | std::string toString(mlir::Attribute opSucc); 22 | } // end namespace SpecSuccessor 23 | 24 | /// Match any successor. 
25 | class AnySuccessor : public mlir::Attribute::AttrBase<
26 | AnySuccessor, mlir::Attribute, mlir::AttributeStorage> {
27 | public:
28 | using Base::Base;
29 | static llvm::StringLiteral getName() { return "Any"; }
30 | static bool kindof(unsigned kind) { return kind == SpecSuccessor::Any; }
31 |
32 | static AnySuccessor get(mlir::MLIRContext *ctx) {
33 | return Base::get(ctx, SpecSuccessor::Any);
34 | }
35 |
36 | inline mlir::LogicalResult verify(mlir::Block *) { return mlir::success(); } // Accepts every successor block.
37 |
38 | static Attribute parse(mlir::OpAsmParser &parser);
39 | void print(llvm::raw_ostream &os);
40 | };
41 |
42 | /// Variadic successors: wraps a single successor constraint attribute, stored in detail::OneAttrStorage.
43 | class VariadicSuccessor : public mlir::Attribute::AttrBase<
44 | VariadicSuccessor, mlir::Attribute, detail::OneAttrStorage> {
45 | public:
46 | using Base::Base;
47 | static llvm::StringLiteral getName() { return "Variadic"; }
48 | static bool kindof(unsigned kind) { return kind == SpecSuccessor::Variadic; }
49 |
50 | static VariadicSuccessor getChecked(mlir::Location loc,
51 | mlir::Attribute succConstraint);
52 | static mlir::LogicalResult verifyConstructionInvariants(
53 | mlir::Location loc, mlir::Attribute succConstraint);
54 |
55 | mlir::LogicalResult verify(mlir::Block *block);
56 |
57 | static Attribute parse(mlir::OpAsmParser &parser);
58 | void print(llvm::raw_ostream &os);
59 | };
60 |
61 | namespace impl {
62 | mlir::LogicalResult verifySuccessorConstraints(mlir::Operation *op,
63 | OpSuccessor opSuccs);
64 | } // end namespace impl
65 |
66 | } // end namespace dmc
67 |
--------------------------------------------------------------------------------
/include/dmc/Spec/SpecSuccessorSwitch.h:
--------------------------------------------------------------------------------
1 | #include "SpecSuccessor.h" // NOTE(review): unlike SpecSuccessor.h this file has no #pragma once; including it twice in one TU would redefine kindSwitch.
2 | #include "Support.h"
3 |
4 | namespace dmc {
5 | namespace SpecSuccessor {
6 |
7 | template
8 | auto kindSwitch(const ActionT &action, unsigned kind) { // Unlisted kinds (e.g. SpecSuccessor::Any) take the default branch.
9 | switch (kind) {
10 | default:
11 | return action.template operator()();
12 | case Variadic:
13 | return action.template operator()();
14 | }
15 | }
16 |
17 | template
18 | auto kindSwitch(const ActionT &action, mlir::Attribute base) { // `base` must be a SpecSuccessor (asserted).
19 | assert(SpecSuccessor::is(base) && "Not a SpecSuccessor");
20 | KindActionWrapper wrapper{action, base};
21 | return kindSwitch(wrapper, base.getKind());
22 | }
23 |
24 | } // end namespace SpecSuccessor
25 | } // end namespace dmc
26 |
--------------------------------------------------------------------------------
/include/dmc/Spec/SpecTypeImplementation.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "SpecKinds.h"
4 | #include "OpType.h"
5 | #include "dmc/Dynamic/DynamicOperation.h"
6 |
7 | #include
8 | #include
9 |
10 | namespace dmc {
11 |
12 | namespace SpecTypes {
13 | bool is(mlir::Type base);
14 | mlir::LogicalResult delegateVerify(mlir::Type base, mlir::Type ty);
15 | } // end namespace SpecTypes
16 |
17 | /// A SpecType is used to define a TypeConstraint. Each SpecType
18 | /// implements a TypeConstraint called on DynamicOperations during
19 | /// trait and Op verification.
20 | template // NOTE(review): the template parameter list (which must declare SpecKind) appears stripped in this dump.
22 | class SpecType
23 | : public mlir::Type::TypeBase {
24 | public:
25 | static constexpr auto Kind = SpecKind;
26 |
27 | /// Explicitly define Base class for templated classes.
28 | using Parent = mlir::Type::TypeBase;
29 | using Base = SpecType;
30 |
31 | /// Inherit parent constructors to pass onto child classes.
32 | using Parent::Parent;
33 |
34 | /// All SpecType subclasses implement a function of the signature
35 | ///
36 | /// LogicalResult verify(Type ty)
37 | ///
38 | /// which executes the TypeConstraint. Because mlir::Type is a CRTP
39 | /// class, we have to manually create a virtual table using the kind.
40 |
41 | /// Compare type kinds.
42 | static bool kindof(unsigned kind) { return kind == Kind; }
43 | };
44 |
45 | /// Simple type shorthand class.
46 | template // NOTE(review): the template parameter list (ConcreteType, Kind) appears stripped in this dump.
47 | class SimpleType : public SpecType {
48 | public:
49 | using Parent = SpecType;
50 | using Base = SimpleType;
51 |
52 | using Parent::Parent;
53 |
54 | /// Dispatch to simple Type getter.
55 | static ConcreteType get(mlir::MLIRContext *ctx) {
56 | return Parent::get(ctx, Kind);
57 | }
58 | /// Parser for simple types: consumes no tokens and returns get() for the parser's context.
59 | static mlir::Type parse(mlir::DialectAsmParser &parser) {
60 | return get(parser.getBuilder().getContext());
61 | }
62 | /// Printer for simple types: emits only the type name.
63 | void print(mlir::DialectAsmPrinter &printer) {
64 | printer << ConcreteType::getTypeName();
65 | }
66 | };
67 |
68 | /// Verify Type constraints.
69 | namespace impl {
70 | mlir::LogicalResult verifyTypeConstraints(mlir::Operation *op, OpType opTy);
71 | } // end namespace impl
72 |
73 | } // end namespace dmc
74 |
--------------------------------------------------------------------------------
/include/dmc/Spec/SpecTypeSwitch.h:
--------------------------------------------------------------------------------
1 | #include "SpecTypes.h" // NOTE(review): unlike the other Spec headers this file has no #pragma once; including it twice in one TU would redefine kindSwitch.
2 | #include "Support.h"
3 |
4 | namespace dmc {
5 | namespace SpecTypes {
6 |
7 | /// Big switch table: select the constraint class for a SpecTypes kind and invoke `action`'s templated call operator with it.
8 | template
9 | auto kindSwitch(const ActionT &action, unsigned kind) {
10 | switch (kind) {
11 | default: // NOTE(review): SpecTypes::Function (SpecKinds.h) has no case below and lands here alongside Any; confirm whether a dedicated case is missing.
12 | return action.template operator()();
13 | case None:
14 | return action.template operator()();
15 | case AnyOf:
16 | return action.template operator()();
17 | case AllOf:
18 | return action.template operator()();
19 | case AnyInteger:
20 | return action.template operator()();
21 | case AnyI:
22 | return action.template operator()();
23 | case AnyIntOfWidths:
24 | return action.template operator()();
25 | case AnySignlessInteger:
26 | return action.template operator()();
27 | case I:
28 | return action.template operator()();
29 | case SignlessIntOfWidths:
30 | return action.template operator()();
31 | case AnySignedInteger:
32 | return action.template operator()();
33 | case SI:
34 | return action.template operator()();
35 | case SignedIntOfWidths:
36 | return action.template operator()();
37 | case AnyUnsignedInteger:
38 | return action.template operator()();
39 | case UI:
40 | return action.template operator()();
41 | case UnsignedIntOfWidths:
42 | return action.template operator()();
43 | case Index:
44 | return action.template operator()();
45 | case AnyFloat:
46 | return action.template operator()();
47 | case F:
48 | return action.template operator()();
49 | case FloatOfWidths:
50 | return action.template operator()();
51 | case BF16:
52 | return action.template operator()();
53 | case AnyComplex:
54 | return action.template operator()();
55 | case Complex:
56 | return action.template operator()();
57 | case Opaque:
58 | return action.template operator()();
59 | case Variadic:
60 | return action.template operator()();
61 | case Isa:
62 | return action.template operator()();
63 | case Py:
64 | return action.template operator()();
65 | }
66 | }
67 |
68 | template
69 | auto kindSwitch(const ActionT &action, mlir::Type base) { // `base` must be a SpecType (asserted); each branch receives it cast to the selected class.
70 | assert(SpecTypes::is(base) && "Not a SpecType");
71 | KindActionWrapper wrapper{action, base};
72 | return kindSwitch(wrapper, base.getKind());
73 | }
74 |
75 | } // end namespace SpecTypes
76 | } // end namespace dmc
77 |
--------------------------------------------------------------------------------
/include/dmc/Spec/Support.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 | #include
6 |
7 | namespace dmc {
8 |
9 | /// A list that sorts its contents on creation. NOTE(review): it publicly derives from llvm::SmallVector, so mutation is not prevented; "immutable" holds only by convention.
10 | template
11 | struct ImmutableSortedList : public llvm::SmallVector {
12 | /// Sort on creation with comparator.
13 | template
14 | ImmutableSortedList(const Container &c,
15 | ComparatorT comparator = ComparatorT{})
16 | : llvm::SmallVector{std::begin(c), std::end(c)} {
17 | llvm::sort(std::begin(*this), std::end(*this), comparator);
18 | }
19 |
20 | /// Compare list sizes and contents.
21 | bool operator==(const ImmutableSortedList &other) const {
22 | if (this->size() != other.size())
23 | return false;
24 | return std::equal(this->begin(), this->end(), other.begin());
25 | }
26 |
27 | /// Hash list values.
28 | llvm::hash_code hash() const {
29 | return llvm::hash_combine_range(this->begin(), this->end());
30 | }
31 | };
32 |
33 | template
34 | ImmutableSortedList getSortedListOf(llvm::ArrayRef arr) { // Builds a sorted copy of `arr` using a default-constructed ComparatorT.
35 | return {arr, ComparatorT{}};
36 | }
37 |
38 | /// Wrapper for kind switches with an Arg instance.
39 | template
40 | struct KindActionWrapper {
41 | const ActionT &action;
42 | ArgT base;
43 |
44 | template
45 | auto operator()() const {
46 | return action(base.template cast()); // Cast `base` to the class selected by kindSwitch and hand it to `action`.
47 | }
48 | };
49 |
50 | template
51 | struct ParseAction {
52 | ParserT &parser;
53 |
54 | template
55 | RetT operator()() const { // Delegates to the selected class's static parse().
56 | return ConcreteType::parse(parser);
57 | }
58 | };
59 |
60 | template
61 | struct PrintAction {
62 | PrinterT &printer;
63 |
64 | template
65 | int operator()(ConcreteType base) const { // Prints `base`; the returned 0 carries no information.
66 | base.print(printer);
67 | return 0;
68 | }
69 | };
70 |
71 | template
72 | struct VerifyAction {
73 | ArgT arg;
74 |
75 | template
76 | mlir::LogicalResult operator()(ConcreteType base) const { // Runs the selected constraint's verify() on `arg`.
77 | return base.verify(arg);
78 | }
79 | };
80 |
81 |
82 | } // end namespace dmc
83 |
--------------------------------------------------------------------------------
/include/dmc/Traits/Kinds.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "dmc/Spec/SpecKinds.h"
4 |
5 | namespace dmc {
6 | namespace Traits {
7 | enum Kind {
8 | IsTerminator,
9 | IsCommutative,
10 | IsIsolatedFromAbove,
11 |
12 | OperandsAreFloatLike,
13 | OperandsAreSignlessIntegerLike,
14 | ResultsAreBoolLike,
15 | ResultsAreFloatLike,
16 | ResultsAreSignlessIntegerLike,
17 |
18 | SameOperandsShape,
19 | SameOperandsAndResultShape,
20 | SameOperandsElementType,
21 | SameOperandsAndResultElementType,
22 | SameOperandsAndResultType,
23 | SameTypeOperands,
24 |
25 | NOperands,
26 | AtLeastNOperands,
27 | NRegions,
28 | AtLeastNRegions,
29 | NResults,
30 | AtLeastNResults,
31 | NSuccessors,
32 | AtLeastNSuccessors,
33 |
34 | SameVariadicOperandSizes,
35 | SameVariadicResultSizes,
36 | SizedOperandSegments,
37 | SizedResultSegments,
38 | TypeConstraintTrait,
39 | AttrConstraintTrait,
40 |
41 | NUM_TRAITS
42 | };
43 | } // end namespace Traits
44 |
45 | namespace TraitAttr {
46 | enum Kinds {
47 | OpTrait = AttrKinds::LAST_KIND, // Continue attribute-kind numbering after AttrKinds.
48 | OpTraits,
49 |
50 | LAST_TRAIT_ATTR
51 | };
52 | } // end namespace TraitAttr
53 | } // end namespace dmc
54 |
--------------------------------------------------------------------------------
/include/dmc/Traits/OpTrait.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "Kinds.h"
4 |
5 | #include
6 |
7 | namespace dmc {
8 |
9 | namespace detail {
10 | struct OpTraitStorage;
11 | struct OpTraitsStorage;
12 | } // end namespace detail
13 |
14 | /// An attribute representing a parameterized op trait.
15 | class OpTraitAttr : public mlir::Attribute::AttrBase<
16 | OpTraitAttr, mlir::Attribute, detail::OpTraitStorage> {
17 | public:
18 | using Base::Base;
19 |
20 | static bool kindof(unsigned kind) { return kind == TraitAttr::OpTrait; }
21 |
22 | /// Attribute hooks.
23 | static OpTraitAttr get(mlir::FlatSymbolRefAttr nameAttr,
24 | mlir::ArrayAttr paramAttr);
25 | static OpTraitAttr getChecked(
26 | mlir::Location loc, mlir::FlatSymbolRefAttr nameAttr,
27 | mlir::ArrayAttr paramAttr);
28 | static mlir::LogicalResult verifyConstructionInvariants(
29 | mlir::Location loc, mlir::FlatSymbolRefAttr nameAttr,
30 | mlir::ArrayAttr paramAttr);
31 |
32 | /// Parsing and printing.
33 | static OpTraitAttr parse(mlir::DialectAsmParser &parser);
34 | void print(mlir::DialectAsmPrinter &printer);
35 |
36 | /// Getters.
37 | llvm::StringRef getName();
38 | llvm::ArrayRef getParameters();
39 | };
40 |
41 | /// An attribute representing a dynamic operation's dynamic op traits.
42 | class OpTraitsAttr : public mlir::Attribute::AttrBase<
43 | OpTraitsAttr, mlir::Attribute, detail::OpTraitsStorage> {
44 | public:
45 | using Base::Base;
46 |
47 | static bool kindof(unsigned kind) { return kind == TraitAttr::OpTraits; }
48 |
49 | /// Attribute hooks.
50 | static OpTraitsAttr get(mlir::ArrayAttr traits);
51 | static OpTraitsAttr getChecked(mlir::Location loc, mlir::ArrayAttr traits);
52 | static mlir::LogicalResult verifyConstructionInvariants(
53 | mlir::Location loc, mlir::ArrayAttr traits);
54 |
55 | /// Parsing and printing.
56 | static OpTraitsAttr parse(mlir::DialectAsmParser &parser);
57 | void print(mlir::DialectAsmPrinter &printer);
58 |
59 | /// Getters.
60 | inline auto getValue() {
61 | return llvm::map_range(getUnderlyingValue(), [](mlir::Attribute attr)
62 | { return attr.cast(); }); // NOTE(review): cast's target type appears stripped in this dump (presumably OpTraitAttr); confirm.
63 | }
64 |
65 | private:
66 | llvm::ArrayRef getUnderlyingValue();
67 | };
68 |
69 | } // end namespace dmc
70 |
--------------------------------------------------------------------------------
/include/dmc/Traits/Registry.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "dmc/Dynamic/DynamicOperation.h"
4 |
5 | #include
6 | #include
7 |
8 | namespace dmc {
9 |
10 | using Trait = std::unique_ptr;
11 |
12 | /// The trait constructor leverages MLIR's attribute system to store generic
13 | /// values to pass to a "trait constructor". This is used to generically create
14 | /// parameterized traits, such as @NSuccessors<2>.
15 | class TraitConstructor {
16 | public:
17 | using ArgsT = llvm::ArrayRef;
18 |
19 | /// Create the constructor with a signature verifier and a call function.
20 | template
21 | TraitConstructor(VerifyFn verifyFunc, CallFn callFunc)
22 | : verifyFunc{verifyFunc}, callFunc{callFunc} {}
23 |
24 | /// Convertible from nullptr and to bool.
25 | TraitConstructor(std::nullptr_t) {}
26 | operator bool() const { return verifyFunc && callFunc; } // True only when both the verifier and the call function are set.
27 |
28 | /// Delegate to internal functions.
29 | inline auto call(ArgsT args) { return callFunc(args); }
30 | inline auto verify(mlir::Location loc, ArgsT args) {
31 | return verifyFunc(loc, args);
32 | }
33 |
34 | private:
35 | /// The internal functions.
36 | std::function verifyFunc;
37 | std::function callFunc;
38 | };
39 |
40 | /// The trait registry is a Dialect so that it can be stored inside the
41 | /// MLIRContext for later lookup.
42 | class TraitRegistry : public mlir::Dialect {
43 | public:
44 | explicit TraitRegistry(mlir::MLIRContext *ctx);
45 | static llvm::StringRef getDialectNamespace() { return "trait"; }
46 |
47 | /// Register a trait constructor.
48 | void registerTrait(llvm::StringRef name, TraitConstructor &&getter);
49 | /// Lookup a trait constructor.
50 | TraitConstructor lookupTrait(llvm::StringRef name);
51 |
52 | /// OpTrait attribute parsing and printing.
53 | mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser,
54 | mlir::Type type) const override;
55 | void printAttribute(mlir::Attribute attr,
56 | mlir::DialectAsmPrinter &printer) const override;
57 |
58 | private:
59 | llvm::StringMap traitRegistry;
60 | };
61 |
62 | /// Out-of-line definitions: none follow in this header.
63 |
64 | } // end namespace dmc
65 |
--------------------------------------------------------------------------------
/lib/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(Dynamic)
2 | add_subdirectory(IO)
3 | add_subdirectory(Spec)
4 | add_subdirectory(Traits)
5 | add_subdirectory(Python)
6 | add_subdirectory(Embed)
7 |
--------------------------------------------------------------------------------
/lib/Dynamic/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_library(DMCDynamic
2 | DynamicContext.cpp
3 | DynamicDialect.cpp
4 | DynamicDialectImpl.cpp
5 | DynamicObject.cpp
6 | DynamicOperation.cpp
7 | DynamicType.cpp
8 | DynamicAttribute.cpp
9 | TypeIDAllocator.cpp
10 | )
11 | target_link_libraries(DMCDynamic
12 | DMCSpec
13 | DMCEmbed
14 | MLIRIR
15 | )
16 |
--------------------------------------------------------------------------------
/lib/Dynamic/DynamicContext.cpp:
-------------------------------------------------------------------------------- 1 | #include "dmc/Dynamic/DynamicContext.h" 2 | #include "dmc/Dynamic/DynamicDialect.h" 3 | #include "dmc/Dynamic/DynamicType.h" 4 | #include "dmc/Dynamic/DynamicAttribute.h" 5 | #include "dmc/Embed/Init.h" 6 | 7 | #include 8 | #include 9 | 10 | using namespace mlir; 11 | 12 | namespace dmc { 13 | 14 | class DynamicContext::Impl { 15 | friend class DynamicContext; 16 | 17 | /// A registry of symbols and their associated dynamic dialect. 18 | DenseMap dialectSymbols; 19 | /// Return the dynamic dialect registered for `sym`, or null if there is none. 20 | template DynamicDialect *lookupDialectFor(SymbolT sym) { 21 | auto it = dialectSymbols.find(sym.getAsOpaquePointer()); 22 | return it == std::end(dialectSymbols) ? nullptr : it->second; 23 | } 24 | /// Associate `sym` with `dialect`; fails if `sym` is already registered. 25 | template 26 | LogicalResult registerDialectSymbol(DynamicDialect *dialect, SymbolT sym) { 27 | auto [it, inserted] = dialectSymbols.try_emplace(sym.getAsOpaquePointer(), 28 | dialect); 29 | return success(inserted); 30 | } 31 | }; 32 | 33 | DynamicContext::~DynamicContext() = default; 34 | 35 | DynamicContext::DynamicContext(MLIRContext *ctx) 36 | : Dialect{getDialectNamespace(), ctx, TypeID::get()}, 37 | typeIdAlloc{getFixedTypeIDAllocator()}, 38 | impl{std::make_unique()} { 39 | // Automatically initialize the interpreter 40 | py::init(ctx); 41 | } 42 | /// NOTE(review): ownership of `dialect` is only taken inside `ctor`; if getOrCreateDialect finds an existing dialect under `name` and never calls `ctor`, `dialect` leaks and is not the registered instance -- confirm callers never reuse a name. 43 | DynamicDialect *DynamicContext::createDynamicDialect(StringRef name) { 44 | auto *dialect = new DynamicDialect{name, this}; 45 | auto typeId = dynamic_cast(dialect)->getTypeID(); 46 | auto ctor = [dialect, typeId]() { 47 | std::unique_ptr ptr{dialect}; 48 | ptr->dialectID = typeId; 49 | return ptr; 50 | }; 51 | getContext()->getOrCreateDialect(name, typeId, ctor); 52 | return dialect; 53 | } 54 | 55 | DynamicDialect *DynamicContext::lookupDialectFor(Type type) { 56 | if (auto dynTy = type.dyn_cast()) 57 | return dynTy.getDynImpl()->getDialect(); 58 | return impl->lookupDialectFor(type); 59 | } 60 | 61 | DynamicDialect
*DynamicContext::lookupDialectFor(Attribute attr) { 62 | if (auto dynAttr = attr.dyn_cast()) 63 | return dynAttr.getDynImpl()->getDialect(); 64 | return impl->lookupDialectFor(attr); 65 | } 66 | 67 | DynamicDialect *DynamicContext::lookupDialectFor(OperationName opName) { 68 | return impl->lookupDialectFor(opName); 69 | } 70 | 71 | LogicalResult DynamicContext::registerDialectSymbol(DynamicDialect *dialect, 72 | Type type) { 73 | return impl->registerDialectSymbol(dialect, type); 74 | } 75 | 76 | LogicalResult DynamicContext::registerDialectSymbol(DynamicDialect *dialect, 77 | Attribute attr) { 78 | return impl->registerDialectSymbol(dialect, attr); 79 | } 80 | 81 | LogicalResult DynamicContext::registerDialectSymbol(DynamicDialect *dialect, 82 | OperationName opName) { 83 | return impl->registerDialectSymbol(dialect, opName); 84 | } 85 | 86 | } // end namespace dmc 87 | -------------------------------------------------------------------------------- /lib/Dynamic/DynamicDialect.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Dynamic/Alias.h" 2 | #include "dmc/Dynamic/DynamicContext.h" 3 | #include "dmc/Dynamic/DynamicType.h" 4 | #include "dmc/Dynamic/DynamicAttribute.h" 5 | #include "dmc/Dynamic/DynamicDialect.h" 6 | #include "dmc/Dynamic/DynamicOperation.h" 7 | 8 | #include 9 | 10 | using namespace mlir; 11 | 12 | namespace dmc { 13 | 14 | std::unique_ptr 15 | DynamicDialect::createDynamicOp(StringRef name) { 16 | /// Allocate on heap so AbstractOperation references stay valid. 17 | /// Ownership must be passed to DynamicContext.
18 | return std::make_unique(name, this); 19 | } 20 | 21 | LogicalResult DynamicDialect::createDynamicType(StringRef name, 22 | NamedParameterRange paramSpec) { 23 | return registerDynamicType( 24 | std::make_unique(this, name, paramSpec)); 25 | } 26 | 27 | LogicalResult DynamicDialect::createDynamicAttr(StringRef name, 28 | NamedParameterRange paramSpec) { 29 | return registerDynamicAttr( 30 | std::make_unique(this, name, paramSpec)); 31 | } 32 | 33 | Type DynamicDialect::parseType(DialectAsmParser &parser) const { 34 | auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 35 | /// Get the type name. 36 | StringRef name; 37 | if (parser.parseKeyword(&name)) 38 | return {}; 39 | 40 | /// Return a type alias if one is found. 41 | auto *typeAlias = lookupTypeAlias(name); 42 | if (typeAlias) 43 | return typeAlias->getAliasedType(); 44 | 45 | /// Lookup a dynamic type and call its parser. 46 | auto *typeImpl = lookupType(name); 47 | if (!typeImpl) { 48 | emitError(loc) << "Unknown type name: " << name; 49 | return {}; 50 | } 51 | return typeImpl->parseType(loc, parser); 52 | } 53 | 54 | void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const { 55 | auto dynTy = type.cast(); 56 | dynTy.getDynImpl()->printType(type, printer); 57 | } 58 | 59 | /// TODO Typed custom attributes. Combining types and attributes requires 60 | /// user-written verification code, which isn't possible until a higher-level 61 | /// language is incorporated. 62 | Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser, 63 | Type type) const { 64 | if (type && !type.isa()) { 65 | parser.emitError(parser.getCurrentLocation(), 66 | "typed custom attributes currently unsupported"); 67 | return {}; 68 | } 69 | auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 70 | 71 | /// Get the attribute name. 72 | StringRef name; 73 | if (parser.parseKeyword(&name)) 74 | return {}; 75 | 76 | /// Return an attribute alias if one is found.
77 | auto *attrAlias = lookupAttrAlias(name); 78 | if (attrAlias) 79 | return attrAlias->getAliasedAttr(); 80 | 81 | /// Lookup a dynamic attribute and call its parser. 82 | auto *attrImpl = lookupAttr(name); 83 | if (!attrImpl) { 84 | emitError(loc) << "Unknown attribute name: " << name; 85 | return {}; 86 | } 87 | return attrImpl->parseAttribute(loc, parser); 88 | } 89 | 90 | void DynamicDialect::printAttribute(Attribute attr, 91 | DialectAsmPrinter &printer) const { 92 | auto dynAttr = attr.cast(); 93 | dynAttr.getDynImpl()->printAttribute(attr, printer); 94 | } 95 | 96 | } // end namespace dmc 97 | -------------------------------------------------------------------------------- /lib/Dynamic/DynamicObject.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Dynamic/DynamicContext.h" 2 | #include "dmc/Dynamic/DynamicObject.h" 3 | 4 | namespace dmc { 5 | 6 | DynamicObject::DynamicObject(DynamicContext *ctx) 7 | : ctx{ctx}, 8 | typeId{ctx->getTypeIDAlloc()->allocateID()} {} 9 | 10 | } // end namespace dmc 11 | -------------------------------------------------------------------------------- /lib/Dynamic/TypeIDAllocator.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Dynamic/TypeIDAllocator.h" 2 | #include "llvm/Support/ErrorHandling.h" 3 | using namespace mlir; 4 | 5 | namespace dmc { 6 | 7 | /// Pre-allocate a pool of TypeIDs. Definitely a hack.
8 | namespace { 9 | 10 | template 11 | using IDList = std::array; 12 | 13 | namespace detail { 14 | 15 | template struct IDReserve {}; 16 | 17 | template 18 | IDList allocateIDPoolImpl(std::index_sequence) { 19 | return {TypeID::get>()...}; 20 | } 21 | 22 | template auto allocateIDPool() { 23 | return allocateIDPoolImpl(std::make_index_sequence()); 24 | } 25 | 26 | } // end namespace detail 27 | /// Hands out TypeIDs from a fixed-size pool of pre-generated IDs; each call consumes one, and exhausting the pool is a fatal error. 28 | template 29 | class FixedTypeIDAllocator : public TypeIDAllocator { 30 | public: 31 | TypeID allocateID() override { 32 | if (index >= ids.size()) llvm::report_fatal_error("Out of TypeIDs"); // not assert(): it is compiled out under NDEBUG, and indexing past `ids` would be UB 33 | return ids[index++]; 34 | } 35 | 36 | private: 37 | std::size_t index{}; 38 | IDList ids = detail::allocateIDPool(); 39 | }; 40 | 41 | } // end anonymous namespace 42 | 43 | TypeIDAllocator *getFixedTypeIDAllocator() { 44 | static FixedTypeIDAllocator<2048> typeIdAllocator; 45 | return &typeIdAllocator; 46 | } 47 | 48 | } // end namespace dmc 49 | -------------------------------------------------------------------------------- /lib/Embed/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(DMCEmbed 2 | Constraints.cpp 3 | Spec.cpp 4 | OpFormatGen.cpp 5 | TypeFormatGen.cpp 6 | PythonGen.cpp 7 | InMemoryDef.cpp 8 | ParserPrinter.cpp 9 | Expose.cpp 10 | FormatUtils.cpp 11 | FormatUtils.h 12 | Scope.cpp 13 | ) 14 | 15 | target_link_libraries(DMCEmbed PUBLIC 16 | MLIRIR 17 | MLIRTableGen 18 | pybind11 19 | pymlir 20 | ) 21 | 22 | add_library(DMCEmbedInit Init.cpp) 23 | target_link_libraries(DMCEmbedInit PUBLIC 24 | pybind11 25 | pymlir 26 | ) 27 | -------------------------------------------------------------------------------- /lib/Embed/Constraints.cpp: 
-------------------------------------------------------------------------------- 1 | #include "Scope.h" 2 | #include "dmc/Embed/Constraints.h" 3 | #include "dmc/Traits/StandardTraits.h" 4 | #include "dmc/Dynamic/DynamicOperation.h" 5 | 6 | /// The polymorphic_type_hook must be visible so that
Type and Attribute can be 7 | /// downcasted to their appropriate derived classes. 8 | #include "dmc/Python/Polymorphic.h" 9 | #include "dmc/Python/OpAsm.h" 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace pybind11; 16 | 17 | namespace dmc { 18 | namespace py { 19 | 20 | namespace { 21 | class ConstraintRegistry { 22 | public: 23 | static ConstraintRegistry &get() { 24 | static ConstraintRegistry instance; 25 | return instance; 26 | } 27 | 28 | /// Function registers a constraint and returns the name. Throws on error. 29 | std::string registerConstraint(std::string expr) { 30 | // Substitute `{self}` 31 | dict fmtArgs{"self"_a = "arg"}; 32 | auto pyExpr = pybind11::cast(expr).cast().format(**fmtArgs); 33 | // Wrap in a function and register it in the main scope 34 | std::string funcName{"anonymous_constraint_"}; 35 | funcName += std::to_string(idx++); 36 | dict funcExpr{"func_name"_a = funcName, "expr"_a = pyExpr}; 37 | auto funcStr = "def {func_name}(arg): return {expr}"_s 38 | .format(**funcExpr); 39 | exec(funcStr, getInternalScope()); 40 | return funcName; 41 | } 42 | 43 | template 44 | LogicalResult evalConstraint(const std::string &funcName, ArgT arg) { 45 | return success( 46 | getInternalScope()[funcName.c_str()](arg).template cast()); 47 | } 48 | 49 | private: 50 | ConstraintRegistry() = default; 51 | 52 | std::size_t idx{}; 53 | }; 54 | } // end anonymous namespace 55 | 56 | LogicalResult registerConstraint(Location loc, StringRef expr, 57 | std::string &funcName) { 58 | try { 59 | funcName = ConstraintRegistry::get().registerConstraint(expr.str()); 60 | } catch (const std::runtime_error &e) { 61 | return emitError(loc) << "Failed to create Python constraint: " << e.what(); 62 | } 63 | return success(); 64 | } 65 | 66 | LogicalResult evalConstraint(const std::string &funcName, Type type) { 67 | return ConstraintRegistry::get().evalConstraint(funcName, type); 68 | } 69 | 70 | LogicalResult evalConstraint(const std::string &funcName, 
Attribute attr) { 71 | return ConstraintRegistry::get().evalConstraint(funcName, attr); 72 | } 73 | 74 | } // end namespace py 75 | 76 | Region &LoopLike::getLoopRegion(DynamicOperation *impl, Operation *op) { 77 | py::OperationWrap wrap{op, impl}; 78 | return wrap.getRegion(region.str()); 79 | } 80 | 81 | bool LoopLike::isDefinedOutside(DynamicOperation *impl, Operation *op, 82 | Value value) { 83 | return !getLoopRegion(impl, op).isAncestor(value.getParentRegion()) && 84 | py::getMainScope()[definedOutsideFcn.str().c_str()](op, value).cast(); 85 | } 86 | 87 | bool LoopLike::canBeHoisted(DynamicOperation *impl, Operation *op) { 88 | return py::getMainScope()[canBeHoistedFcn.str().c_str()](op).cast(); 89 | } 90 | 91 | } // end namespace dmc 92 | -------------------------------------------------------------------------------- /lib/Embed/InMemoryDef.cpp: -------------------------------------------------------------------------------- 1 | #include "Scope.h" 2 | #include "dmc/Embed/InMemoryDef.h" 3 | 4 | #include 5 | #include 6 | 7 | using namespace llvm; 8 | using namespace pybind11; 9 | 10 | namespace dmc { 11 | namespace py { 12 | 13 | InMemoryDef::InMemoryDef(StringRef fcnName, StringRef fcnSig) { 14 | pgs.def(fcnName + fcnSig); 15 | } 16 | 17 | InMemoryDef::~InMemoryDef() { 18 | pgs.enddef(); 19 | // Store the parser/printer in the internal scope 20 | exec(os.str(), getInternalScope()); 21 | } 22 | 23 | InMemoryClass::InMemoryClass(StringRef clsName, ArrayRef parentCls, 24 | module &m) : m{m} { 25 | // intercept invalid class names 26 | auto valid = StringSwitch(clsName) 27 | .Case("return", false) 28 | .Case("def", false) 29 | .Case("class", false) 30 | .Case("assert", false) 31 | .Default(true); 32 | auto name = clsName.str(); 33 | if (!valid) 34 | name.front() = std::toupper(name.front()); 35 | 36 | auto line = pgs.line() << "class " << name << "("; 37 | llvm::interleaveComma(parentCls, line, [&](StringRef cls) { line << cls; }); 38 | line << "):" << incr; 39 | 
} 40 | 41 | InMemoryClass::~InMemoryClass() { 42 | pgs.endblock(); 43 | exec(os.str(), m.attr("__dict__")); 44 | } 45 | 46 | } // end namespace py 47 | } // end namespace dmc 48 | -------------------------------------------------------------------------------- /lib/Embed/Init.cpp: -------------------------------------------------------------------------------- 1 | #include "Scope.h" 2 | #include "dmc/Embed/Constraints.h" 3 | #include "dmc/Python/PyMLIR.h" 4 | 5 | #include 6 | 7 | namespace { 8 | PYBIND11_EMBEDDED_MODULE(mlir, m) { 9 | mlir::py::getModule(m); 10 | } 11 | } // end anonymous namespace 12 | 13 | using namespace pybind11; 14 | 15 | namespace mlir { 16 | namespace py { 17 | 18 | static bool inited{false}; 19 | 20 | void init(MLIRContext *ctx) { 21 | if (inited) 22 | return; 23 | inited = true; 24 | 25 | setMLIRContext(ctx); 26 | initialize_interpreter(); 27 | } 28 | 29 | } // end namespace py 30 | } // end namespace mlir 31 | -------------------------------------------------------------------------------- /lib/Embed/ParserPrinter.cpp: -------------------------------------------------------------------------------- 1 | #include "Scope.h" 2 | #include "dmc/Dynamic/DynamicOperation.h" 3 | #include "dmc/Dynamic/DynamicType.h" 4 | #include "dmc/Dynamic/DynamicAttribute.h" 5 | #include "dmc/Python/OpAsm.h" 6 | #include "dmc/Python/DialectAsm.h" 7 | 8 | #include 9 | #include 10 | 11 | using namespace mlir; 12 | using namespace pybind11; 13 | 14 | // is_copy_constructible is true for OperationState even though it contains 15 | // vector, so explicitly mark it as non copy constructible. 
16 | template <> struct std::is_copy_constructible 17 | : public std::false_type {}; 18 | 19 | namespace dmc { 20 | namespace py { 21 | 22 | bool execParser(const std::string &name, OpAsmParser &parser, 23 | OperationState &result) { 24 | constexpr auto parser_policy = return_value_policy::reference; 25 | ensureBuiltins(getInternalModule()); 26 | auto fcn = getInternalScope()[name.c_str()]; 27 | return fcn.operator()(parser, result).cast(); 28 | } 29 | 30 | void execPrinter(const std::string &name, OpAsmPrinter &printer, Operation *op, 31 | DynamicOperation *spec) { 32 | constexpr auto printer_policy = return_value_policy::reference; 33 | ensureBuiltins(getInternalModule()); 34 | auto fcn = getInternalScope()[name.c_str()]; 35 | OperationWrap wrap{op, spec}; 36 | fcn.operator()(printer, &wrap); 37 | } 38 | 39 | bool execParser(const std::string &name, DialectAsmParser &parser, 40 | std::vector &result) { 41 | constexpr auto parser_policy = return_value_policy::reference; 42 | ensureBuiltins(getInternalModule()); 43 | auto fcn = getInternalScope()[name.c_str()]; 44 | TypeResultWrap wrap{result}; 45 | return fcn.operator()(parser, wrap).cast(); 46 | } 47 | 48 | template 49 | void execPrinter(const std::string &name, DialectAsmPrinter &printer, 50 | DynamicT t) { 51 | constexpr auto printer_policy = return_value_policy::reference; 52 | ensureBuiltins(getInternalModule()); 53 | auto fcn = getInternalScope()[name.c_str()]; 54 | TypeWrap wrap{t}; 55 | fcn.operator()(printer, &wrap); 56 | } 57 | 58 | template void execPrinter(const std::string &name, DialectAsmPrinter &printer, 59 | DynamicType type); 60 | template void execPrinter(const std::string &name, DialectAsmPrinter &printer, 61 | DynamicAttribute attr); 62 | 63 | } // end namespace py 64 | } // end namespace dmc 65 | -------------------------------------------------------------------------------- /lib/Embed/PythonGen.cpp: -------------------------------------------------------------------------------- 1 | 
#include "dmc/Embed/PythonGen.h" 2 | 3 | using namespace llvm; 4 | 5 | namespace dmc { 6 | namespace py { 7 | 8 | PythonGenStream::Line::Line(PythonGenStream &s) 9 | : s{s}, newline{true} { 10 | s.os.indent(s.indent); 11 | } 12 | 13 | PythonGenStream::Line::~Line() { 14 | if (newline) 15 | s.os << "\n"; 16 | } 17 | 18 | PythonGenStream::Line::Line(Line &&line) 19 | : s{line.s}, newline{line.newline} { 20 | line.newline = false; 21 | } 22 | 23 | PythonGenStream::PythonGenStream(raw_ostream &os) 24 | : os{os}, indent{0} {} 25 | 26 | PythonGenStream::Line PythonGenStream::line() { 27 | return Line{*this}; 28 | } 29 | 30 | PythonGenStream &PythonGenStream::block(StringRef ty, Twine expr) { 31 | line() << ty << " " << expr << ":"; 32 | incr(); 33 | return *this; 34 | } 35 | 36 | PythonGenStream &PythonGenStream::endblock() { 37 | decr(); 38 | return *this; 39 | } 40 | 41 | void PythonGenStream::changeIndent(int delta) { 42 | indent += delta; 43 | assert(indent >= 0 && "Negative indent"); 44 | } 45 | 46 | } // end namespace py 47 | } // end namespace dmc 48 | -------------------------------------------------------------------------------- /lib/Embed/Scope.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | using namespace pybind11; 4 | 5 | namespace dmc { 6 | namespace py { 7 | 8 | module getInternalModule() { 9 | return module::import("mlir"); 10 | } 11 | 12 | void ensureBuiltins(module m) { 13 | auto scope = m.attr("__dict__").cast(); 14 | if (!scope.contains("__builtins__")) 15 | scope["__builtins__"] = PyEval_GetBuiltins(); 16 | } 17 | 18 | } // end namespace py 19 | } // end namespace dmc 20 | -------------------------------------------------------------------------------- /lib/Embed/Scope.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace dmc { 6 | namespace py { 7 | 8 | pybind11::module getInternalModule(); 9 | 10 | inline 
auto getInternalScope() { 11 | return getInternalModule().attr("__dict__"); 12 | } 13 | 14 | inline auto getMainScope() { 15 | return pybind11::module::import("__main__").attr("__dict__"); 16 | } 17 | 18 | void ensureBuiltins(pybind11::module m); 19 | 20 | } // end namespace py 21 | } // end namespace dmc 22 | -------------------------------------------------------------------------------- /lib/Embed/Spec.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Embed/Constraints.h" 2 | #include "dmc/Spec/SpecTypes.h" 3 | #include "dmc/Spec/SpecAttrs.h" 4 | 5 | using namespace mlir; 6 | 7 | namespace dmc { 8 | 9 | namespace detail { 10 | 11 | /// Store the constraint expression. 12 | struct PyConstraintStorage { 13 | using KeyTy = StringRef; 14 | 15 | explicit PyConstraintStorage(KeyTy key) : expr{key} {} 16 | bool operator==(KeyTy key) const { return key == expr; } 17 | static llvm::hash_code hashKey(KeyTy key) { return hash_value(key); } 18 | 19 | StringRef expr; 20 | /// Not part of the key, but store the function name. Initialize to empty. 21 | std::string funcName{}; 22 | }; 23 | 24 | struct PyTypeStorage : public PyConstraintStorage, public TypeStorage { 25 | using PyConstraintStorage::PyConstraintStorage; 26 | 27 | static PyTypeStorage *construct(TypeStorageAllocator &alloc, KeyTy key) { 28 | auto expr = alloc.copyInto(key); 29 | return new (alloc.allocate()) PyTypeStorage{expr}; 30 | } 31 | }; 32 | 33 | struct PyAttrStorage : public PyConstraintStorage, public AttributeStorage { 34 | using PyConstraintStorage::PyConstraintStorage; 35 | 36 | static PyAttrStorage *construct(AttributeStorageAllocator &alloc, KeyTy key) { 37 | auto expr = alloc.copyInto(key); 38 | return new (alloc.allocate()) PyAttrStorage{expr}; 39 | } 40 | }; 41 | 42 | } // end namespace detail 43 | 44 | /// PyType implementation. 
45 | PyType PyType::getChecked(Location loc, StringRef expr) { 46 | auto ret = Base::get(loc.getContext(), Kind, expr); 47 | auto &funcName = ret.getImpl()->funcName; 48 | if (funcName.empty()) { 49 | if (failed(py::registerConstraint(loc, expr, funcName))) 50 | return {}; 51 | } 52 | return ret; 53 | } 54 | 55 | LogicalResult PyType::verify(Type ty) { 56 | return py::evalConstraint(getImpl()->funcName, ty); 57 | } 58 | 59 | void PyType::print(DialectAsmPrinter &printer) { 60 | printer << getTypeName() << "<\"" << getImpl()->expr << "\">"; 61 | } 62 | 63 | /// PyAttr implementation. 64 | PyAttr PyAttr::getChecked(Location loc, StringRef expr) { 65 | auto ret = Base::get(loc.getContext(), Kind, expr); 66 | auto &funcName = ret.getImpl()->funcName; 67 | if (funcName.empty()) { 68 | if (failed(py::registerConstraint(loc, expr, funcName))) 69 | return {}; 70 | } 71 | return ret; 72 | } 73 | 74 | LogicalResult PyAttr::verify(Attribute attr) { 75 | return py::evalConstraint(getImpl()->funcName, attr); 76 | } 77 | 78 | void PyAttr::print(DialectAsmPrinter &printer) { 79 | printer << getAttrName() << "<\"" << getImpl()->expr << "\">"; 80 | } 81 | 82 | } // end namespace dmc 83 | -------------------------------------------------------------------------------- /lib/IO/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(DMCIO 2 | ModuleWriter.cpp 3 | ) 4 | target_link_libraries(DMCIO MLIRIR) 5 | -------------------------------------------------------------------------------- /lib/IO/ModuleWriter.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/IO/ModuleWriter.h" 2 | #include "dmc/Dynamic/DynamicOperation.h" 3 | 4 | using namespace mlir; 5 | 6 | namespace dmc { 7 | 8 | ModuleWriter::ModuleWriter(DynamicContext *ctx) 9 | : builder{ctx->getContext()}, 10 | // TODO location data in Python file of call 11 | module{ModuleOp::create(builder.getUnknownLoc())} {} 12 | 13 
| FuncOp ModuleWriter::createFunction( 14 | StringRef name, 15 | ArrayRef argTys, ArrayRef retTys) { 16 | auto funcType = builder.getFunctionType(argTys, retTys); 17 | // TODO location data 18 | auto funcOp = FuncOp::create(builder.getUnknownLoc(), name, funcType); 19 | module.push_back(funcOp); 20 | return funcOp; 21 | } 22 | 23 | FunctionWriter::FunctionWriter(FuncOp func) 24 | : builder{func.getContext()}, 25 | func(func), 26 | entryBlock{func.addEntryBlock()} { 27 | builder.setInsertionPointToStart(entryBlock); 28 | } 29 | 30 | Operation *FunctionWriter::createOp( 31 | DynamicOperation *op, ValueRange args, ArrayRef retTys) { 32 | return createOp(op->getOpInfo(), args, retTys); 33 | } 34 | 35 | Operation *FunctionWriter::createOp( 36 | StringRef name, ValueRange args, ArrayRef retTys) { 37 | return createOp(OperationName{name, func.getContext()}, args, retTys); 38 | } 39 | 40 | Operation *FunctionWriter::createOp( 41 | OperationName opName, ValueRange args, ArrayRef retTys) { 42 | // TODO location data 43 | OperationState state{builder.getUnknownLoc(), opName}; 44 | state.addOperands(args); 45 | state.addTypes(retTys); 46 | return builder.createOperation(state); 47 | } 48 | 49 | } // end namespace dmc 50 | -------------------------------------------------------------------------------- /lib/Python/AsmUtils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | template 10 | void exposeLiteralParser(pybind11::class_ &cls, NameT name, 11 | FcnT fcn) { 12 | cls.def(name, fcn); 13 | } 14 | 15 | template 17 | void exposeLiteralParser(pybind11::class_ &cls, NameT name, 18 | FcnT fcn, TailTs ...tail) { 19 | exposeLiteralParser(cls, name, fcn); 20 | exposeLiteralParser(cls, tail...); 21 | } 22 | 23 | template 24 | void exposeAllLiteralParsers(pybind11::class_ &cls) { 25 | exposeLiteralParser( 26 | cls, 27 | "parseArrow", &ParserT::parseArrow, 28 | 
"parseColon", &ParserT::parseColon, 29 | "parseComma", &ParserT::parseComma, 30 | "parseEqual", &ParserT::parseEqual, 31 | "parseLess", &ParserT::parseLess, 32 | "parseGreater", &ParserT::parseGreater, 33 | "parseLParen", &ParserT::parseLParen, 34 | "parseRParen", &ParserT::parseRParen, 35 | "parseLSquare", &ParserT::parseLSquare, 36 | "parseRSquare", &ParserT::parseRSquare, 37 | "parseOptionalArrow", &ParserT::parseOptionalArrow, 38 | "parseOptionalColon", &ParserT::parseOptionalColon, 39 | "parseOptionalComma", &ParserT::parseOptionalComma, 40 | "parseOptionalLess", &ParserT::parseOptionalLess, 41 | "parseOptionalGreater", &ParserT::parseOptionalGreater, 42 | "parseOptionalLParen", &ParserT::parseOptionalLParen, 43 | "parseOptionalRParen", &ParserT::parseOptionalRParen, 44 | "parseOptionalLSquare", &ParserT::parseOptionalLSquare, 45 | "parseOptionalRSquare", &ParserT::parseOptionalRSquare, 46 | "parseOptionalEllipsis", &ParserT::parseOptionalEllipsis); 47 | } 48 | 49 | } // end namespace py 50 | } // end namespace mlir 51 | -------------------------------------------------------------------------------- /lib/Python/Attribute.cpp: -------------------------------------------------------------------------------- 1 | #include "Context.h" 2 | #include "Identifier.h" 3 | #include "Attribute.h" 4 | #include "Location.h" 5 | 6 | #include 7 | #include 8 | 9 | using namespace pybind11; 10 | 11 | namespace mlir { 12 | namespace py { 13 | 14 | ArrayAttr getArrayAttr(const std::vector &attrs) { 15 | return ArrayAttr::get(attrs, getMLIRContext()); 16 | } 17 | 18 | AttributeArray *getArrayAttrValue(ArrayAttr attr) { 19 | return new AttributeArray{attr.getValue()}; 20 | } 21 | 22 | auto wrapIndex(ptrdiff_t i, unsigned sz) { 23 | if (i < 0) 24 | i += sz; 25 | if (i < 0 || static_cast(i) >= sz) 26 | throw index_error{}; 27 | return i; 28 | } 29 | 30 | AttributeArray *arrayGetSlice(ArrayAttr attr, pybind11::slice s) { 31 | size_t start, stop, step, sliceLength; 32 | if 
(!s.compute(attr.size(), &start, &stop, &step, &sliceLength)) 33 | throw error_already_set{}; 34 | auto *ret = new AttributeArray; 35 | ret->reserve(sliceLength); 36 | for (unsigned i = 0; i < sliceLength; ++i) { 37 | ret->push_back(attr[start]); 38 | start += step; 39 | } 40 | return ret; 41 | } 42 | 43 | Attribute arrayGetIndex(ArrayAttr attr, ptrdiff_t i) { 44 | i = wrapIndex(i, attr.size()); 45 | return attr[static_cast(i)]; 46 | } 47 | 48 | DictionaryAttr getDictionaryAttr(const AttributeMap &attrs) { 49 | NamedAttrList attrList; 50 | for (auto &[name, attr] : attrs) { 51 | attrList.push_back({getIdentifierChecked(name), attr}); 52 | } 53 | return DictionaryAttr::get(attrList, getMLIRContext()); 54 | } 55 | 56 | AttributeMap *getDictionaryAttrValue(DictionaryAttr attr) { 57 | auto *ret = new AttributeMap; 58 | ret->reserve(attr.size()); 59 | for (auto &[name, attr] : attr) { 60 | ret->emplace(name.strref(), attr); 61 | } 62 | return ret; 63 | } 64 | 65 | Attribute dictionaryAttrGetItem(DictionaryAttr attr, const std::string &key) { 66 | auto ret = attr.get(key); 67 | if (!ret) 68 | throw key_error{}; 69 | return ret; 70 | } 71 | 72 | FloatAttr getFloatAttr(Type ty, double val, Location loc) { 73 | if (!ty) 74 | throw std::invalid_argument{"Float type cannot be null"}; 75 | if (failed(FloatAttr::verifyConstructionInvariants( 76 | loc, ty, val))) 77 | throw std::invalid_argument{"Bad float representation"}; 78 | return FloatAttr::get(ty, val); 79 | } 80 | 81 | IntegerAttr getIntegerAttr(Type ty, int64_t val, Location loc) { 82 | if (!ty) 83 | throw std::invalid_argument{"Integer type cannot be null"}; 84 | if (failed(IntegerAttr::verifyConstructionInvariants( 85 | loc, ty, val))) 86 | throw std::invalid_argument{"Bad integer representation"}; 87 | return IntegerAttr::get(ty, val); 88 | } 89 | 90 | OpaqueAttr getOpaqueAttr(const std::string &dialect, const std::string &data, 91 | Type type, Location loc) { 92 | auto id = getIdentifierChecked(dialect); 93 | if 
(failed(OpaqueAttr::verifyConstructionInvariants(loc, id, data, type))) 94 | throw std::invalid_argument{"Invalid OpaqueAttr construction"}; 95 | return OpaqueAttr::get(id, data, type, getMLIRContext()); 96 | } 97 | 98 | std::string getOpaqueAttrDialect(OpaqueAttr attr) { 99 | return attr.getDialectNamespace().str(); 100 | } 101 | 102 | std::string getOpaqueAttrData(OpaqueAttr attr) { 103 | return attr.getAttrData().str(); 104 | } 105 | 106 | } // end namespace py 107 | } // end namespace mlir 108 | -------------------------------------------------------------------------------- /lib/Python/Attribute.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | /// ArrayAttr. 9 | using AttributeArray = std::vector; 10 | 11 | ArrayAttr getArrayAttr(const std::vector &attrs); 12 | AttributeArray *getArrayAttrValue(ArrayAttr attr); 13 | AttributeArray *arrayGetSlice(ArrayAttr attr, pybind11::slice s); 14 | Attribute arrayGetIndex(ArrayAttr attr, ptrdiff_t i); 15 | 16 | /// DictionaryAttr. 17 | using AttributeMap = std::unordered_map; 18 | 19 | DictionaryAttr getDictionaryAttr(const AttributeMap &attrs); 20 | AttributeMap *getDictionaryAttrValue(DictionaryAttr attr); 21 | Attribute dictionaryAttrGetItem(DictionaryAttr attr, const std::string &key); 22 | 23 | /// FloatAttr. 24 | FloatAttr getFloatAttr(Type ty, double val, Location loc); 25 | 26 | /// IntegerAttr. 27 | IntegerAttr getIntegerAttr(Type ty, int64_t val, Location loc); 28 | 29 | /// OpaqueAttr. 
// Review annotation for the four dump lines below (no code changed apart
// from a comment typo).
// - lib/Python/Attribute.h, lines 30-36: OpaqueAttr helper declarations.
// - lib/Python/BuildableType.cpp: buildDynamicType and getAliasedType look up
//   a registered dialect by name, require it to be a dynamic dialect, then
//   look up a type / type alias in it, throwing std::invalid_argument at each
//   failed step. exposeDynamicTypes binds them as build_dynamic_type (params
//   default to an empty list, location to the unknown location) and
//   get_aliased_type.
// - lib/Python/CMakeLists.txt: builds the pymlir library, the DMCDLLInit
//   library and the `mlir` pybind11 extension module.
// - lib/Python/Context.cpp: a function-local-static singleton holding the
//   global MLIRContext. The DialectRegistration members are declared before
//   `context`, so they are constructed first. setMLIRContext only swaps the
//   raw pointer and takes no ownership; the caller must keep a replacement
//   context alive while it is in use.
// - lib/Python/Context.h, lines 1-10: header and getMLIRContext doc comment.
// NOTE(review): buildDynamicType returns DynamicType::getChecked(...) as-is;
//   if getChecked reports failure by returning a null type (the usual MLIR
//   convention -- confirm), Python receives a null Type instead of an error.
// NOTE(review): getAliasedType reports a missing alias as "Unknown type ...",
//   which reads like a missing type definition rather than a missing alias.
// NOTE(review): pymlir compiles BuildableType.cpp and DialectAsm.cpp, which
//   use DMC dynamic types, but links no DMC library; those symbols are only
//   resolved by the `mlir` module target -- confirm this is intended.
30 | OpaqueAttr getOpaqueAttr(const std::string &dialect, const std::string &data, 31 | Type type, Location loc); 32 | std::string getOpaqueAttrDialect(OpaqueAttr attr); 33 | std::string getOpaqueAttrData(OpaqueAttr attr); 34 | 35 | } // end namespace py 36 | } // end namespace mlir 37 | -------------------------------------------------------------------------------- /lib/Python/BuildableType.cpp: -------------------------------------------------------------------------------- 1 | #include "Context.h" 2 | #include "Location.h" 3 | #include "dmc/Dynamic/DynamicContext.h" 4 | #include "dmc/Dynamic/DynamicDialect.h" 5 | #include "dmc/Dynamic/DynamicType.h" 6 | #include "dmc/Dynamic/Alias.h" 7 | #include "dmc/Python/Polymorphic.h" 8 | 9 | using namespace pybind11; 10 | using namespace mlir; 11 | 12 | namespace dmc { 13 | namespace py { 14 | 15 | static Type buildDynamicType( 16 | std::string dialectName, std::string typeName, 17 | const std::vector ¶ms, Location loc) { 18 | auto *dialect = mlir::py::getMLIRContext()->getRegisteredDialect(dialectName); 19 | if (!dialect) 20 | throw std::invalid_argument{"Unknown dialect name: " + dialectName}; 21 | auto *dynDialect = dynamic_cast(dialect); 22 | if (!dynDialect) 23 | throw std::invalid_argument{"Not a dynamic dialect: " + dialectName}; 24 | auto *impl = dynDialect->lookupType(typeName); 25 | if (!impl) 26 | throw std::invalid_argument{"Unknown type '" + typeName + "' in dialect '" + 27 | dialectName + "'"}; 28 | return DynamicType::getChecked(loc, impl, params); 29 | } 30 | 31 | static Type getAliasedType(std::string dialectName, std::string aliasName) { 32 | auto *dialect = mlir::py::getMLIRContext()->getRegisteredDialect(dialectName); 33 | if (!dialect) 34 | throw std::invalid_argument{"Unknown dialect name: " + dialectName}; 35 | auto *dynDialect = dynamic_cast(dialect); 36 | if (!dynDialect) 37 | throw std::invalid_argument{"Not a dynamic dialect: " + dialectName}; 38 | auto *alias =
dynDialect->lookupTypeAlias(aliasName); 39 | if (!alias) 40 | throw std::invalid_argument{"Unknown type '" + aliasName + "' in dialect '" + 41 | dialectName + "'"}; 42 | return alias->getAliasedType(); 43 | } 44 | 45 | void exposeDynamicTypes(module &m) { 46 | m.def("build_dynamic_type", &buildDynamicType, 47 | "dialectName"_a, "typeName"_a, "params"_a = std::vector{}, 48 | "location"_a = mlir::py::getUnknownLoc()); 49 | 50 | m.def("get_aliased_type", &getAliasedType); 51 | } 52 | 53 | } // end namespace py 54 | } // end namespace dmc 55 | -------------------------------------------------------------------------------- /lib/Python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(pymlir 2 | Context.cpp 3 | Context.h 4 | OwningModuleRef.cpp 5 | OwningModuleRef.h 6 | Parser.cpp 7 | Parser.h 8 | Module.cpp 9 | Module.h 10 | Location.cpp 11 | Location.h 12 | Identifier.cpp 13 | Identifier.h 14 | Type.cpp 15 | Type.h 16 | Attribute.cpp 17 | Attribute.h 18 | 19 | BuildableType.cpp 20 | DialectAsm.cpp 21 | OpAsm.cpp 22 | Utility.h 23 | AsmUtils.h 24 | 25 | # Reduce compilation time by splitting the functions 26 | ExposeBuilder.cpp 27 | ExposeDialectAsm.cpp 28 | ExposeOpAsm.cpp 29 | ExposeOps.cpp 30 | ExposeValue.cpp 31 | ExposeShapedTypes.cpp 32 | ExposeStandardTypes.cpp 33 | ExposeOpaqueType.cpp 34 | ExposeFunctionType.cpp 35 | ExposeElementsAttr.cpp 36 | ExposeSymbolRefAttr.cpp 37 | ExposeIntFPAttr.cpp 38 | ExposeDictAttr.cpp 39 | ExposeArrayAttr.cpp 40 | ExposeAttribute.cpp 41 | ExposeParser.cpp 42 | ExposeModule.cpp 43 | ExposeLocation.cpp 44 | ExposeType.cpp 45 | Expose.cpp 46 | Expose.h 47 | ) 48 | 49 | target_include_directories(pymlir PUBLIC 50 | ${Python3_INCLUDE_DIRS} 51 | ) 52 | target_link_libraries(pymlir PUBLIC 53 | ${Python3_LIBRARIES} 54 | pybind11 55 | MLIRIR 56 | MLIRParser 57 | MLIRStandardOps 58 | MLIRLLVMIR 59 | MLIRSCF 60 | MLIRTransformUtils 61 | MLIRSCFToStandard 62 | MLIRStandardToLLVM
63 | MLIRTransforms 64 | ) 65 | 66 | add_library(DMCDLLInit DllInit.cpp) 67 | target_link_libraries(DMCDLLInit PUBLIC 68 | pybind11 69 | pymlir 70 | ) 71 | 72 | pybind11_add_module(mlir ExternalModule.cpp) 73 | target_link_libraries(mlir PUBLIC 74 | pymlir 75 | DMCDynamic 76 | DMCSpec 77 | DMCTraits 78 | DMCDLLInit 79 | ) 80 | -------------------------------------------------------------------------------- /lib/Python/Context.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecDialect.h" 2 | #include "dmc/Traits/Registry.h" 3 | #include "dmc/Dynamic/DynamicContext.h" 4 | 5 | #include 6 | #include 7 | 8 | namespace mlir { 9 | namespace py { 10 | 11 | /// Wrap the global MLIR context in a singleton class, using member objects 12 | /// to ensure initialization order of dialect registrations. Static objects 13 | /// are not guaranteed by the standard to be initialized in any order. 14 | class GlobalContextHandle { 15 | public: 16 | static GlobalContextHandle &instance() { 17 | static GlobalContextHandle instance; 18 | return instance; 19 | } 20 | 21 | MLIRContext *getContext() { return ptr; } 22 | void setContext(MLIRContext *ctx) { ptr = ctx; } 23 | 24 | private: 25 | /// Initialization order is guaranteed.
26 | DialectRegistration standardOpsDialect; 27 | DialectRegistration scfDialect; 28 | DialectRegistration llvmDialect; 29 | DialectRegistration specDialect; 30 | DialectRegistration traitRegistry; 31 | MLIRContext context; 32 | MLIRContext *ptr{&context}; 33 | }; 34 | 35 | MLIRContext *getMLIRContext() { 36 | return GlobalContextHandle::instance().getContext(); 37 | } 38 | 39 | void setMLIRContext(MLIRContext *ctx) { 40 | GlobalContextHandle::instance().setContext(ctx); 41 | } 42 | 43 | } // end namespace py 44 | } // end namespace mlir 45 | -------------------------------------------------------------------------------- /lib/Python/Context.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | /// Store a global MLIR context instance. All calls to MLIR functions through 9 | /// the Python API will use this instance. This simplifies the Python API as 10 | /// users will not need to pass a context handle to all function calls.
// Review annotation for the three dump lines below (no code changed).
// - lib/Python/Context.h, lines 11-15: getMLIRContext() declaration.
//   setMLIRContext (defined in Context.cpp) is not declared here; presumably
//   another header declares it (e.g. dmc/Python/PyMLIR.h) -- confirm.
// - lib/Python/DialectAsm.cpp: TypeWrap copies the parameters and parameter
//   spec of a DynamicType or DynamicAttribute. The Python getParameter(name)
//   returns the parameter whose spec name matches and throws
//   std::invalid_argument otherwise. TypeResultWrap.append pushes an
//   Attribute onto the container returned by getImpl().
// - lib/Python/DllInit.cpp: init(MLIRContext *) has an empty body and ignores
//   its argument.
// - lib/Python/Expose.cpp: getModule registers all bindings. exposeTypeBase
//   runs before exposeAttribute/exposeType because, per Expose.h, pybind11
//   needs Type registered before it can appear in default arguments.
// - lib/Python/Expose.h: pybind11 class_ aliases and the expose* entry points.
//   exposeLocation, exposeArrayAttr, exposeModule, etc. are not called from
//   getModule directly; presumably exposeAttribute/exposeType/exposeOps call
//   them.
// - lib/Python/ExposeArrayAttr.cpp: ArrayAttr bindings. Two __getitem__
//   overloads are registered (slice first, then integer index), and pybind11
//   tries overloads in registration order. __iter__ uses keep_alive<0, 1> so
//   the returned iterator keeps the attribute alive.
// - lib/Python/ExposeDialectAsm.cpp, lines 1-11: includes and using-directives.
11 | MLIRContext *getMLIRContext(); 12 | 13 | } // end namespace py 14 | } // end namespace mlir 15 | -------------------------------------------------------------------------------- /lib/Python/DialectAsm.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Python/DialectAsm.h" 2 | #include "dmc/Dynamic/DynamicType.h" 3 | #include "dmc/Dynamic/DynamicAttribute.h" 4 | 5 | #include 6 | 7 | using namespace pybind11; 8 | using namespace mlir; 9 | 10 | namespace dmc { 11 | namespace py { 12 | 13 | TypeWrap::TypeWrap(DynamicType type) 14 | : params{type.getParams()}, 15 | paramSpec{type.getDynImpl()->getParamSpec()} {} 16 | 17 | TypeWrap::TypeWrap(DynamicAttribute attr) 18 | : params{attr.getParams()}, 19 | paramSpec{attr.getDynImpl()->getParamSpec()} {} 20 | 21 | void exposeTypeWrap(module &m) { 22 | class_(m, "TypeWrap") 23 | .def("getParameter", [](TypeWrap &wrap, std::string name) { 24 | for (auto [param, spec] : llvm::zip(wrap.getParams(), wrap.getSpec())) { 25 | if (spec.getName() == name) 26 | return param; 27 | } 28 | throw std::invalid_argument{"Unknown parameter name: " + name}; 29 | }); 30 | 31 | class_(m, "TypeResultWrap") 32 | .def("append", [](TypeResultWrap &wrap, Attribute attr) { 33 | wrap.getImpl().push_back(attr); 34 | }); 35 | } 36 | 37 | } // end namespace py 38 | } // end namespace dmc 39 | -------------------------------------------------------------------------------- /lib/Python/DllInit.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "dmc/Python/PyMLIR.h" 4 | 5 | using namespace pybind11; 6 | 7 | namespace mlir { 8 | namespace py { 9 | 10 | void init(MLIRContext *ctx) {} 11 | 12 | } // end namespace py 13 | } // end namespace mlir 14 | -------------------------------------------------------------------------------- /lib/Python/Expose.cpp: -------------------------------------------------------------------------------- 1 |
#include "Expose.h" 2 | 3 | #include 4 | 5 | using namespace pybind11; 6 | 7 | namespace mlir { 8 | namespace py { 9 | 10 | void getModule(module &m) { 11 | exposeParser(m); 12 | auto type = exposeTypeBase(m); 13 | exposeAttribute(m); 14 | exposeType(m, type); 15 | exposeValue(m); 16 | exposeOps(m); 17 | 18 | exposeOpAsm(m); 19 | exposeDialectAsm(m); 20 | 21 | exposeBuilder(m); 22 | } 23 | 24 | } // end namespace py 25 | } // end namespace mlir 26 | -------------------------------------------------------------------------------- /lib/Python/Expose.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dmc/Dynamic/DynamicOperation.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace mlir { 10 | namespace py { 11 | 12 | using AttrClass = pybind11::class_; 13 | using TypeClass = pybind11::class_; 14 | using OpClass = pybind11::class_; 15 | 16 | void exposeParser(pybind11::module &m); 17 | /// pybind11 needs Type to be exposed before it can be used in default args. 18 | TypeClass exposeTypeBase(pybind11::module &m); 19 | void exposeType(pybind11::module &m, TypeClass &type); 20 | void exposeAttribute(pybind11::module &m); 21 | void exposeValue(pybind11::module &m); 22 | 23 | /// Operations. 24 | void exposeOps(pybind11::module &m); 25 | void exposeModule(pybind11::module &m, OpClass &cls); 26 | 27 | /// Attribute subclasses. 28 | void exposeLocation(pybind11::module &m, AttrClass &attr); 29 | void exposeArrayAttr(pybind11::module &m, AttrClass &attr); 30 | void exposeDictAttr(pybind11::module &m, AttrClass &attr); 31 | void exposeIntFPAttr(pybind11::module &m, AttrClass &attr); 32 | void exposeSymbolRefAttr(pybind11::module &m, AttrClass &attr); 33 | void exposeElementsAttr(pybind11::module &m, AttrClass &attr); 34 | 35 | /// Type subclasses.
36 | void exposeFunctionType(pybind11::module &m, TypeClass &type); 37 | void exposeOpaqueType(pybind11::module &m, TypeClass &type); 38 | void exposeStandardNumericTypes(pybind11::module &m, TypeClass &type); 39 | void exposeShapedTypes(pybind11::module &m, TypeClass &type); 40 | 41 | /// Parsers and printers. 42 | void exposeOpAsm(pybind11::module &m); 43 | void exposeDialectAsm(pybind11::module &m); 44 | 45 | void exposeBuilder(pybind11::module &m); 46 | 47 | } // end namespace py 48 | } // end namespace mlir 49 | -------------------------------------------------------------------------------- /lib/Python/ExposeArrayAttr.cpp: -------------------------------------------------------------------------------- 1 | #include "Attribute.h" 2 | #include "Utility.h" 3 | 4 | #include 5 | 6 | using namespace pybind11; 7 | 8 | namespace mlir { 9 | namespace py { 10 | 11 | template auto nullcheck(FcnT fcn) { 12 | return ::nullcheck(fcn, "array attribute"); 13 | } 14 | 15 | void exposeArrayAttr(module &m, class_ &attr) { 16 | class_(m, "ArrayAttr", attr) 17 | .def(init(&getArrayAttr)) 18 | .def("getValue", nullcheck(&getArrayAttrValue)) 19 | .def("empty", nullcheck([](ArrayAttr attr) { return attr.size() == 0; })) 20 | .def("__getitem__", nullcheck(&arrayGetSlice)) 21 | .def("__getitem__", nullcheck(&arrayGetIndex)) 22 | .def("__len__", nullcheck([](ArrayAttr attr) { return attr.size(); })) 23 | .def("__iter__", nullcheck([](ArrayAttr attr) { 24 | return make_iterator(attr.begin(), attr.end()); 25 | }), keep_alive<0, 1>()); 26 | } 27 | 28 | } // end namespace py 29 | } // end namespace mlir 30 | -------------------------------------------------------------------------------- /lib/Python/ExposeDialectAsm.cpp: -------------------------------------------------------------------------------- 1 | #include "Utility.h" 2 | #include "AsmUtils.h" 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace pybind11; 9 | using namespace mlir; 10 | using namespace llvm; 11 | 12 |
// Review annotation for the five dump lines below (no code changed).
// - lib/Python/ExposeDialectAsm.cpp, lines 12-74: printDimensionListOrRaw
//   prints an ArrayAttr as an "x"-separated list followed by a trailing "x".
//   An IntegerAttr -1 prints as '?', other IntegerAttrs print via
//   getZExtValue, and any other element prints via operator<<. Non-array
//   attributes fall back to printAttribute. exposeDialectAsm binds
//   DialectAsmPrinter and DialectAsmParser; parseDimensionList and
//   parseAttribute return a (value, status) tuple instead of raising. It then
//   calls dmc::py::exposeTypeWrap.
// - lib/Python/ExposeDictAttr.cpp: DictionaryAttr bindings (key iteration,
//   items(), __getitem__, __contains__, __len__, empty, getValue).
// - lib/Python/ExposeFunctionType.cpp: FunctionType constructor from input and
//   result type lists (both default to empty), plus `inputs` / `results`
//   read-only properties.
// - lib/Python/ExposeIntFPAttr.cpp: FloatAttr and IntegerAttr constructors
//   (location defaults to the unknown location) and value getters.
// - lib/Python/ExposeLocation.cpp: the LocationAttr base (==, !=, repr, hash,
//   isXxxLoc predicates), the concrete location classes, and implicit
//   conversions to Location.
// - lib/Python/ExposeOpaqueType.cpp, lines 1-38: getOpaqueType validates its
//   arguments with OpaqueType::verifyConstructionInvariants and throws
//   std::invalid_argument on failure.
// NOTE(review): an empty ArrayAttr prints a lone "x", and IntegerAttr values
//   below -1 print as large unsigned numbers through getZExtValue -- confirm
//   both are intended.
// NOTE(review): `static` inside the anonymous namespace is redundant.
// NOTE(review): the FunctionType `inputs`/`results` properties and FusedLoc's
//   `locs` property (getLocations) return vectors allocated with `new`.
//   pybind11's def_property_readonly defaults to
//   return_value_policy::reference_internal, under which pybind11 does not
//   delete the returned object, so these vectors are likely leaked on every
//   access; returning by value would avoid that -- confirm. (The `results`
//   lambda also names its local `inputs`.)
namespace dmc { 13 | namespace py { 14 | extern void exposeTypeWrap(module &m); 15 | } // end namespace py 16 | } // end namespace dmc 17 | 18 | namespace mlir { 19 | namespace py { 20 | 21 | namespace { 22 | static void printDimensionListOrRaw(DialectAsmPrinter &p, Attribute attr) { 23 | if (auto arr = attr.dyn_cast()) { 24 | interleave(arr, p, [&](Attribute el) { 25 | if (auto i = el.dyn_cast()) { 26 | if (i.getValue().getSExtValue() == -1) { 27 | p << '?'; 28 | } else { 29 | p << i.getValue().getZExtValue(); 30 | } 31 | } else { 32 | p << el; 33 | } 34 | }, "x"); 35 | p << "x"; 36 | } else { 37 | p.printAttribute(attr); 38 | } 39 | } 40 | } // end anonymous namespace 41 | 42 | void exposeDialectAsm(module &m) { 43 | class_> 44 | (m, "DialectAsmPrinter") 45 | .def("print", [](DialectAsmPrinter &p, std::string val) { 46 | p << val; 47 | }) 48 | .def("printAttribute", &DialectAsmPrinter::printAttribute) 49 | .def("printDimensionListOrRaw", &printDimensionListOrRaw); 50 | 51 | class_> 52 | parserCls{m, "DialectAsmParser"}; 53 | exposeAllLiteralParsers(parserCls); 54 | parserCls 55 | .def("parseDimensionList", [](DialectAsmParser &parser, 56 | bool allowDynamic) { 57 | SmallVector dims; 58 | if (failed(parser.parseDimensionList(dims, allowDynamic))) 59 | return make_tuple(nullptr, false); 60 | std::vector ret{std::begin(dims), std::end(dims)}; 61 | return pybind11::make_tuple(std::move(ret), true); 62 | }) 63 | .def("parseAttribute", [](DialectAsmParser &parser) { 64 | Attribute attr; 65 | auto ret = parser.parseAttribute(attr); 66 | return make_tuple(attr, ret); 67 | }); 68 | 69 | dmc::py::exposeTypeWrap(m); 70 | } 71 | 72 | } // end namespace py 73 | } // end namespace mlir 74 | -------------------------------------------------------------------------------- /lib/Python/ExposeDictAttr.cpp: -------------------------------------------------------------------------------- 1 | #include "Attribute.h" 2 | #include "Utility.h" 3 | 4 | #include 5 | #include 6 | 7 | using
namespace pybind11; 8 | 9 | namespace mlir { 10 | namespace py { 11 | 12 | template auto nullcheck(FcnT fcn) { 13 | return ::nullcheck(fcn, "dictionary attribute"); 14 | } 15 | 16 | void exposeDictAttr(module &m, class_ &attr) { 17 | class_(m, "DictionaryAttr", attr) 18 | .def(init(&getDictionaryAttr)) 19 | .def("getValue", nullcheck(&getDictionaryAttrValue)) 20 | .def("empty", nullcheck([](DictionaryAttr attr) { return attr.empty(); })) 21 | .def("__iter__", nullcheck([](DictionaryAttr attr) { 22 | return make_key_iterator(attr.begin(), attr.end()); 23 | }), keep_alive<0, 1>()) 24 | .def("items", nullcheck([](DictionaryAttr attr) { 25 | return make_iterator(attr.begin(), attr.end()); 26 | }), keep_alive<0, 1>()) 27 | .def("__getitem__", nullcheck(&dictionaryAttrGetItem)) 28 | .def("__contains__", nullcheck( 29 | [](DictionaryAttr attr, const std::string &key) -> bool { 30 | return !!attr.get(key); 31 | })) 32 | .def("__len__", nullcheck( 33 | [](DictionaryAttr attr) { return attr.size(); })); 34 | } 35 | 36 | } // end namespace py 37 | } // end namespace mlir 38 | -------------------------------------------------------------------------------- /lib/Python/ExposeFunctionType.cpp: -------------------------------------------------------------------------------- 1 | #include "Context.h" 2 | #include "Utility.h" 3 | #include "Type.h" 4 | #include "Expose.h" 5 | 6 | #include 7 | 8 | using namespace pybind11; 9 | 10 | namespace mlir { 11 | namespace py { 12 | 13 | template auto nullcheck(FcnT fcn) { 14 | return ::nullcheck(fcn, "function type"); 15 | } 16 | 17 | FunctionType getFunctionType(TypeListRef inputs, TypeListRef results) { 18 | return FunctionType::get(inputs, results, getMLIRContext()); 19 | } 20 | 21 | void exposeFunctionType(pybind11::module &m, TypeClass &type) { 22 | class_(m, "FunctionType", type) 23 | .def(init(&getFunctionType), "inputs"_a = TypeList{}, 24 | "results"_a = TypeList{}) 25 | .def_property_readonly("inputs", nullcheck([](FunctionType ty) {
26 | auto inputs = ty.getInputs(); 27 | return new std::vector{std::begin(inputs), std::end(inputs)}; 28 | })) 29 | .def_property_readonly("results", nullcheck([](FunctionType ty) { 30 | auto inputs = ty.getResults(); 31 | return new std::vector{std::begin(inputs), std::end(inputs)}; 32 | })); 33 | } 34 | 35 | } // end namespace py 36 | } // end namespace mlir 37 | -------------------------------------------------------------------------------- /lib/Python/ExposeIntFPAttr.cpp: -------------------------------------------------------------------------------- 1 | #include "Attribute.h" 2 | #include "Location.h" 3 | #include "Utility.h" 4 | 5 | #include 6 | 7 | using namespace pybind11; 8 | 9 | namespace mlir { 10 | namespace py { 11 | 12 | template auto nullcheck(FcnT fcn) { 13 | return ::nullcheck(fcn, "attribute"); 14 | } 15 | 16 | void exposeIntFPAttr(module &m, class_ &attr) { 17 | /// TODO handle conversions from APFloat and APInt to Python floats and 18 | /// integers to support arbitrary-precision values.
19 | class_(m, "FloatAttr", attr) 20 | .def(init(&getFloatAttr), "type"_a, "value"_a, 21 | "location"_a = getUnknownLoc()) 22 | .def("getValue", nullcheck( 23 | [](FloatAttr attr) { return attr.getValueAsDouble(); })); 24 | 25 | class_(m, "IntegerAttr", attr) 26 | .def(init(&getIntegerAttr), "type"_a, "value"_a, 27 | "location"_a = getUnknownLoc()) 28 | .def("getInt", nullcheck(&IntegerAttr::getInt)) 29 | .def("getSInt", nullcheck(&IntegerAttr::getSInt)) 30 | .def("getUInt", nullcheck(&IntegerAttr::getUInt)); 31 | } 32 | 33 | } // end namespace py 34 | } // end namespace mlir 35 | -------------------------------------------------------------------------------- /lib/Python/ExposeLocation.cpp: -------------------------------------------------------------------------------- 1 | #include "Utility.h" 2 | #include "Location.h" 3 | 4 | #include 5 | #include 6 | 7 | using namespace llvm; 8 | using namespace mlir; 9 | using namespace pybind11; 10 | 11 | namespace mlir { 12 | namespace py { 13 | 14 | template auto isa() { 15 | return ::isa(); 16 | } 17 | 18 | void exposeLocation(module &m, class_ &attr) { 19 | class_ locAttr{m, "LocationAttr", attr}; 20 | locAttr 21 | .def(init()) 22 | .def(self == self) 23 | .def(self != self) 24 | .def("__repr__", StringPrinter{}) 25 | .def("__hash__", [](LocationAttr loc) { return hash_value(loc); }) 26 | .def("isUnknownLoc", isa()) 27 | .def("isCallSiteLoc", isa()) 28 | .def("isFileLineColLoc", isa()) 29 | .def("isFusedLoc", isa()) 30 | .def("isNameLoc", isa()); 31 | 32 | class_ loc{m, "Location", locAttr}; 33 | 34 | class_(m, "UnknownLoc", locAttr) 35 | .def(init(&getUnknownLoc)); 36 | 37 | class_(m, "CallSiteLoc", locAttr) 38 | .def(init(&getCallSiteLoc)) 39 | .def_property_readonly("callee", &getCallee) 40 | .def_property_readonly("caller", &getCaller); 41 | 42 | class_(m, "FileLineColLoc", locAttr) 43 | .def(init(&getFileLineColLoc), "file"_a, "line"_a = 1, "col"_a = 1) 44 | .def_property_readonly("filename", &getFilename) 45 |
.def_property_readonly("line", &getLine) 46 | .def_property_readonly("col", &getColumn); 47 | 48 | class_(m, "FusedLoc", locAttr) 49 | .def(init(&getFusedLoc)) 50 | .def_property_readonly("locs", &getLocations); 51 | 52 | class_(m, "NameLoc", locAttr) 53 | .def(init(overload(&getNameLoc))) 54 | .def(init(overload(&getNameLoc))) 55 | .def_property_readonly("name", &getName) 56 | .def_property_readonly("child", &getChildLoc); 57 | 58 | implicitly_convertible(); 59 | implicitly_convertible_from_all< 60 | LocationAttr, UnknownLoc, CallSiteLoc, 61 | FileLineColLoc, FusedLoc, NameLoc>(loc); 62 | } 63 | 64 | } // end namespace py 65 | } // end namespace mlir 66 | -------------------------------------------------------------------------------- /lib/Python/ExposeOpaqueType.cpp: -------------------------------------------------------------------------------- 1 | #include "Location.h" 2 | #include "Utility.h" 3 | #include "Context.h" 4 | #include "Identifier.h" 5 | #include "Type.h" 6 | #include "Expose.h" 7 | 8 | using namespace pybind11; 9 | 10 | namespace mlir { 11 | namespace py { 12 | 13 | template auto nullcheck(FcnT fcn) { 14 | return ::nullcheck(fcn, "opaque type"); 15 | } 16 | 17 | OpaqueType getOpaqueType(const std::string &dialect, 18 | const std::string &typeData) { 19 | auto id = getIdentifierChecked(dialect); 20 | if (failed(OpaqueType::verifyConstructionInvariants( 21 | getUnknownLoc(), id, typeData))) 22 | throw std::invalid_argument{"Bad opaque type arguments"}; 23 | return OpaqueType::get(id, typeData, getMLIRContext()); 24 | } 25 | 26 | void exposeOpaqueType(pybind11::module &m, TypeClass &type) { 27 | class_(m, "OpaqueType", type) 28 | .def(init(&getOpaqueType)) 29 | .def_property_readonly("dialectNamespace", nullcheck([](OpaqueType ty) { 30 | return ty.getDialectNamespace().str(); 31 | })) 32 | .def_property_readonly("typeData", nullcheck([](OpaqueType ty) { 33 | return ty.getTypeData().str(); 34 | })); 35 | } 36 | 37 | } // end namespace py 38 | } // end
// Review annotation for the three dump lines below (no code changed).
// - lib/Python/ExposeParser.cpp: binds parseSourceFile.
// - lib/Python/ExposeStandardTypes.cpp: ComplexType, IndexType, IntegerType
//   (with its SignednessSemantics enum), FloatType and NoneType bindings.
//   It also defines module-level BF16Type/F16Type/F32Type/F64Type functions
//   that return FloatType instances built in the global context.
// - lib/Python/ExposeSymbolRefAttr.cpp: SymbolRefAttr and FlatSymbolRefAttr
//   bindings; the string getters copy StringRefs into std::string.
// - lib/Python/ExternalModule.cpp: the `mlir` extension module entry point.
//   registerDynamicDialects registers each dialect op of the given module.
//   It returns a list containing, per dialect, the result of evaluating the
//   dialect namespace as a Python expression in the `mlir` module's __dict__.
// - lib/Python/Identifier.cpp: only the first dump line number.
// NOTE(review): ComplexType and IntegerType are constructed through
//   getChecked; if that returns a null type on invalid input (confirm), the
//   Python constructor yields a null object instead of raising.
// NOTE(review): in registerDynamicDialects, `scope` is never cleared, so each
//   dialect is registered with the names of every dialect processed before
//   it -- confirm this is intended. The dialect returned by
//   getRegisteredDialect is dereferenced without a null check.
namespace mlir 39 | -------------------------------------------------------------------------------- /lib/Python/ExposeParser.cpp: -------------------------------------------------------------------------------- 1 | #include "Parser.h" 2 | 3 | #include 4 | 5 | using namespace pybind11; 6 | 7 | namespace mlir { 8 | namespace py { 9 | 10 | void exposeParser(module &m) { 11 | m.def("parseSourceFile", &parseSourceFile); 12 | } 13 | 14 | } // end namespace py 15 | } // end namespace mlir 16 | -------------------------------------------------------------------------------- /lib/Python/ExposeStandardTypes.cpp: -------------------------------------------------------------------------------- 1 | #include "Context.h" 2 | #include "Location.h" 3 | #include "Utility.h" 4 | #include "Type.h" 5 | #include "Expose.h" 6 | 7 | #include 8 | 9 | using namespace mlir; 10 | using namespace pybind11; 11 | 12 | namespace mlir { 13 | namespace py { 14 | 15 | template auto nullcheck(FcnT fcn) { 16 | return ::nullcheck(fcn, "type"); 17 | } 18 | 19 | void exposeStandardNumericTypes(pybind11::module &m, TypeClass &type) { 20 | class_(m, "ComplexType", type) 21 | .def(init(&ComplexType::getChecked), "elementType"_a, 22 | "location"_a = getUnknownLoc()) 23 | .def_property_readonly("elementType", 24 | nullcheck(&ComplexType::getElementType)); 25 | 26 | class_(m, "IndexType", type) 27 | .def(init([]() { return IndexType::get(getMLIRContext()); })); 28 | 29 | class_ intType{m, "IntegerType", type}; 30 | intType 31 | .def(init(overload( 32 | &IntegerType::getChecked)), "width"_a, "location"_a = getUnknownLoc()) 33 | .def(init(overload(&IntegerType::getChecked)), 35 | "width"_a, "signedness"_a, "location"_a = getUnknownLoc()) 36 | .def_property_readonly("width", nullcheck(&IntegerType::getWidth)) 37 | .def_property_readonly("signedness", nullcheck(&IntegerType::getSignedness)) 38 | .def("isSignless", nullcheck(&IntegerType::isSignless)) 39 | .def("isSigned", nullcheck(&IntegerType::isSigned)) 40 |
.def("isUnsigned", nullcheck(&IntegerType::isUnsigned)); 41 | 42 | enum_(intType, "SignednessSemantics") 43 | .value("Signless", IntegerType::Signless) 44 | .value("Signed", IntegerType::Signed) 45 | .value("Unsigned", IntegerType::Unsigned) 46 | .export_values(); 47 | 48 | class_(m, "FloatType", type) 49 | .def_property_readonly("width", nullcheck(&FloatType::getWidth)); 50 | // llvm::fltSemantics definition not publicly visible 51 | 52 | m.def("BF16Type", []() { return FloatType::getBF16(getMLIRContext()); }); 53 | m.def("F16Type", []() { return FloatType::getF16(getMLIRContext()); }); 54 | m.def("F32Type", []() { return FloatType::getF32(getMLIRContext()); }); 55 | m.def("F64Type", []() { return FloatType::getF64(getMLIRContext()); }); 56 | 57 | class_(m, "NoneType", type) 58 | .def(init([]() { return mlir::NoneType::get(getMLIRContext()); })); 59 | } 60 | 61 | } // end namespace py 62 | } // end namespace mlir 63 | -------------------------------------------------------------------------------- /lib/Python/ExposeSymbolRefAttr.cpp: -------------------------------------------------------------------------------- 1 | #include "Utility.h" 2 | #include "Context.h" 3 | 4 | #include 5 | #include 6 | 7 | using namespace pybind11; 8 | 9 | namespace mlir { 10 | namespace py { 11 | 12 | template auto nullcheck(FcnT fcn) { 13 | return ::nullcheck(fcn, "symbol reference attribute"); 14 | } 15 | 16 | void exposeSymbolRefAttr(module &m, class_ &attr) { 17 | class_ symbolRefAttr{m, "SymbolRefAttr", attr}; 18 | symbolRefAttr 19 | .def(init([](const std::string &value, 20 | const std::vector &refs) { 21 | return SymbolRefAttr::get(value, refs, getMLIRContext()); 22 | })) 23 | .def_property_readonly("root", nullcheck([](SymbolRefAttr attr) 24 | { return attr.getRootReference().str(); })) 25 | .def_property_readonly("leaf", nullcheck([](SymbolRefAttr attr) 26 | { return attr.getLeafReference().str(); })) 27 | .def("getNestedReferences", nullcheck([](SymbolRefAttr attr) { 28 | auto
refs = attr.getNestedReferences(); 29 | return std::vector{std::begin(refs), 30 | std::end(refs)}; 31 | })); 32 | 33 | // FlatSymbolRefAttr subclasses SymbolRefAttr 34 | class_(m, "FlatSymbolRefAttr", symbolRefAttr) 35 | .def(init([](const std::string &value) { 36 | return FlatSymbolRefAttr::get(value, getMLIRContext()); 37 | })) 38 | .def("getValue", nullcheck([](FlatSymbolRefAttr attr) { 39 | return attr.getValue().str(); 40 | })); 41 | } 42 | 43 | } // end namespace py 44 | } // end namespace mlir 45 | -------------------------------------------------------------------------------- /lib/Python/ExternalModule.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Python/PyMLIR.h" 2 | #include "dmc/Dynamic/DynamicContext.h" 3 | #include "dmc/Dynamic/DynamicDialect.h" 4 | #include "dmc/Spec/DialectGen.h" 5 | #include "dmc/Spec/SpecOps.h" 6 | #include "dmc/Embed/Expose.h" 7 | 8 | #include 9 | 10 | using namespace dmc; 11 | using namespace mlir; 12 | using namespace pybind11; 13 | 14 | PYBIND11_MODULE(mlir, m) { 15 | mlir::py::getModule(m); 16 | // ownership is given to MLIRContext 17 | auto *ctx = mlir::py::getMLIRContext()->getOrCreateDialect(); 18 | 19 | m.def("registerDynamicDialects", [ctx](ModuleOp module) { 20 | list ret; 21 | std::vector scope; 22 | for (auto dialectOp : module.getOps()) { 23 | scope.push_back(dialectOp.getName()); 24 | if (failed(registerDialect(dialectOp, ctx, scope))) 25 | throw std::invalid_argument{"Failed to register dialect: " + 26 | dialectOp.getName().str()}; 27 | auto *dialect = 28 | mlir::py::getMLIRContext()->getRegisteredDialect(dialectOp.getName()); 29 | ret.append(eval(dialect->getNamespace().str(), 30 | module::import("mlir").attr("__dict__"))); 31 | } 32 | return ret; 33 | }); 34 | } 35 | -------------------------------------------------------------------------------- /lib/Python/Identifier.cpp: -------------------------------------------------------------------------------- 1
// Review annotation for the dump line below (no code changed).
// - lib/Python/Identifier.cpp / Identifier.h: getIdentifierChecked rejects
//   empty strings and strings with embedded NUL characters by throwing
//   std::invalid_argument. Otherwise it interns the identifier in the global
//   context.
// - lib/Python/Location.cpp, lines 1-28: getUnknownLoc (global context) and
//   the CallSiteLoc construction / accessor helpers.
| #include "Context.h" 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | Identifier getIdentifierChecked(std::string id){ 9 | if (id.empty()) 10 | throw std::invalid_argument{"Identifier cannot be an empty string."}; 11 | if (id.find('\0') != std::string::npos) 12 | throw std::invalid_argument{"Identifier cannot contain null characters."}; 13 | return Identifier::get(id, getMLIRContext()); 14 | } 15 | 16 | } // end namespace py 17 | } // end namespace mlir 18 | -------------------------------------------------------------------------------- /lib/Python/Identifier.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | Identifier getIdentifierChecked(std::string id); 9 | 10 | } // end namespace py 11 | } // end namespace mlir 12 | -------------------------------------------------------------------------------- /lib/Python/Location.cpp: -------------------------------------------------------------------------------- 1 | #include "Context.h" 2 | #include "Utility.h" 3 | #include "Identifier.h" 4 | 5 | #include 6 | 7 | namespace mlir { 8 | namespace py { 9 | 10 | /// UnknownLoc. 11 | UnknownLoc getUnknownLoc() { 12 | return UnknownLoc::get(getMLIRContext()).cast(); 13 | } 14 | 15 | /// CallSiteLoc. 16 | CallSiteLoc getCallSiteLoc(Location callee, Location caller) { 17 | return CallSiteLoc::get(callee, caller).cast(); 18 | } 19 | 20 | Location getCallee(CallSiteLoc loc) { 21 | return loc.getCallee(); 22 | } 23 | 24 | Location getCaller(CallSiteLoc loc) { 25 | return loc.getCaller(); 26 | } 27 | 28 | /// FileLineColLoc.
// Review annotation for the dump line below (no code changed).
// - lib/Python/Location.cpp, lines 29-76: FileLineColLoc, FusedLoc and NameLoc
//   helpers. getFileLineColLoc, getFusedLoc and getNameLoc(name) build their
//   locations in the global context; getNameLoc(name, child) nests `child`.
//   Names go through getIdentifierChecked, so empty or NUL-containing names
//   throw std::invalid_argument.
// - lib/Python/Location.h, lines 1-16: the start of the matching declarations.
// NOTE(review): getLocations returns a vector allocated with `new`. Whether it
//   is ever freed depends on the pybind11 return-value policy of the binding
//   that exposes it -- verify it is not leaked.
29 | FileLineColLoc getFileLineColLoc(std::string filename, unsigned line, 30 | unsigned col) { 31 | return FileLineColLoc::get(filename, line, col, getMLIRContext()) 32 | .cast(); 33 | } 34 | 35 | std::string getFilename(FileLineColLoc loc) { 36 | return loc.getFilename().str(); 37 | } 38 | 39 | unsigned getLine(FileLineColLoc loc) { 40 | return loc.getLine(); 41 | } 42 | 43 | unsigned getColumn(FileLineColLoc loc) { 44 | return loc.getColumn(); 45 | } 46 | 47 | /// FusedLoc. 48 | FusedLoc getFusedLoc(const std::vector &locs) { 49 | return FusedLoc::get(locs, getMLIRContext()).cast(); 50 | } 51 | 52 | std::vector *getLocations(FusedLoc loc) { 53 | return new std::vector{loc.getLocations()}; 54 | } 55 | 56 | /// NameLoc. 57 | NameLoc getNameLoc(std::string name, Location child) { 58 | return NameLoc::get(getIdentifierChecked(name), child).cast(); 59 | } 60 | 61 | NameLoc getNameLoc(std::string name) { 62 | return NameLoc::get(getIdentifierChecked(name), getMLIRContext()) 63 | .cast(); 64 | } 65 | 66 | std::string getName(NameLoc loc) { 67 | return loc.getName().str(); 68 | } 69 | 70 | Location getChildLoc(NameLoc loc) { 71 | return loc.getChildLoc(); 72 | } 73 | 74 | } // end namespace py 75 | } // end namespace mlir 76 | -------------------------------------------------------------------------------- /lib/Python/Location.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | /// UnknownLoc. 9 | UnknownLoc getUnknownLoc(); 10 | 11 | /// CallSiteLoc. 12 | CallSiteLoc getCallSiteLoc(Location callee, Location caller); 13 | Location getCallee(CallSiteLoc loc); 14 | Location getCaller(CallSiteLoc loc); 15 | 16 | /// FileLineColLoc.
// Review annotation for the dump line below (no code changed).
// - lib/Python/Location.h, lines 17-35: the remaining location declarations.
// - lib/Python/Module.cpp: getModuleName returns the module's name as a
//   std::string, or an empty optional when the module has none.
// - lib/Python/Module.h, lines 1-8: includes and namespace opening.
// NOTE(review): nothing in Module.cpp's body appears to use Location.h,
//   Utility.h, or the llvm / pybind11 using-directives -- confirm before
//   trimming them.
// NOTE(review): Module.h has no `#pragma once`.
17 | FileLineColLoc getFileLineColLoc(std::string filename, unsigned line, 18 | unsigned col); 19 | std::string getFilename(FileLineColLoc loc); 20 | unsigned getLine(FileLineColLoc loc); 21 | unsigned getColumn(FileLineColLoc loc); 22 | 23 | /// FusedLoc. 24 | FusedLoc getFusedLoc(const std::vector &locs); 25 | std::vector *getLocations(FusedLoc loc); 26 | 27 | /// NameLoc. 28 | NameLoc getNameLoc(std::string name, Location child); 29 | NameLoc getNameLoc(std::string name); 30 | std::string getName(NameLoc loc); 31 | Location getChildLoc(NameLoc loc); 32 | 33 | } // end namespace py 34 | } // end namespace mlir 35 | -------------------------------------------------------------------------------- /lib/Python/Module.cpp: -------------------------------------------------------------------------------- 1 | #include "Location.h" 2 | #include "Utility.h" 3 | 4 | #include 5 | #include 6 | 7 | using namespace llvm; 8 | using namespace pybind11; 9 | 10 | namespace mlir { 11 | namespace py { 12 | 13 | std::optional getModuleName(ModuleOp moduleOp) { 14 | if (auto name = moduleOp.getName()) 15 | return name->str(); 16 | return {}; 17 | } 18 | 19 | } // end namespace py 20 | } // end namespace mlir 21 | -------------------------------------------------------------------------------- /lib/Python/Module.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | /// Getters.
// Review annotation for the two dump lines below (no code changed).
// - lib/Python/Module.h, lines 9-13: getModuleName declaration.
// - lib/Python/OwningModuleRef.cpp/.h: printModuleRef throws
//   std::invalid_argument{"module is null"} for a null module, but
//   getOwnedModule dereferences the ref without the same check -- confirm
//   that is safe for a null ref.
// - lib/Python/Parser.cpp/.h: parseSourceFile parses `filename` into the
//   global context and returns the ModuleOp released from the owning ref.
//   The result is presumably null when parsing fails -- confirm against the
//   MLIR parseSourceFile contract. Nothing here erases the released module;
//   its eventual owner must. The SourceMgr and SourceMgrDiagnosticHandler are
//   heap-allocated and never freed (see the TODO). Since the handler object
//   is never destroyed, each call presumably leaves one more diagnostic
//   handler registered on the global context.
// - lib/Python/Type.cpp/.h: getIntOrFloatBitWidth throws
//   std::invalid_argument for types that are neither integer nor float.
// - lib/Python/Utility.h, lines 1-9: header and includes.
// NOTE(review): Parser.h, like Module.h, has no `#pragma once`.
9 | std::optional getModuleName(ModuleOp moduleOp); 10 | 11 | } // end namespace py 12 | } // end namespace mlir 13 | -------------------------------------------------------------------------------- /lib/Python/OwningModuleRef.cpp: -------------------------------------------------------------------------------- 1 | #include "Utility.h" 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | std::string printModuleRef(OwningModuleRef &moduleRef) { 9 | if (!moduleRef) 10 | throw std::invalid_argument{"module is null"}; 11 | return StringPrinter{}(*moduleRef); 12 | } 13 | 14 | ModuleOp getOwnedModule(OwningModuleRef &moduleRef) { 15 | return *moduleRef; 16 | } 17 | 18 | } // end namespace py 19 | } // end namespace mlir 20 | -------------------------------------------------------------------------------- /lib/Python/OwningModuleRef.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | std::string printModuleRef(OwningModuleRef &moduleRef); 9 | ModuleOp getOwnedModule(OwningModuleRef &moduleRef); 10 | 11 | } // end namespace py 12 | } // end namespace mlir 13 | -------------------------------------------------------------------------------- /lib/Python/Parser.cpp: -------------------------------------------------------------------------------- 1 | #include "Context.h" 2 | #include "Utility.h" 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace llvm; 9 | 10 | namespace mlir { 11 | namespace py { 12 | 13 | /// Parse a source file from a given filename. Provide a source manager and 14 | /// a diagnostic handler for the parse. 15 | ModuleOp parseSourceFile(std::string filename) { 16 | // TODO 100% a memory leak. The SourceMgr needs to be kept alive. Make python 17 | // manage the lifetime of the SourceMgr.
18 | auto *sourceMgr = new SourceMgr; 19 | new SourceMgrDiagnosticHandler{*sourceMgr, getMLIRContext()}; 20 | auto ret = parseSourceFile(filename, *sourceMgr, getMLIRContext()); 21 | return ret.release(); 22 | } 23 | 24 | } // end namespace py 25 | } // end namespace mlir 26 | -------------------------------------------------------------------------------- /lib/Python/Parser.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace mlir { 4 | namespace py { 5 | 6 | ModuleOp parseSourceFile(std::string filename); 7 | 8 | } // end namespace py 9 | } // end namespace mlir 10 | -------------------------------------------------------------------------------- /lib/Python/Type.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace mlir { 4 | namespace py { 5 | 6 | unsigned getIntOrFloatBitWidth(Type ty) { 7 | if (!ty.isIntOrFloat()) 8 | throw std::invalid_argument{"only integer or float types have bit widths"}; 9 | return ty.getIntOrFloatBitWidth(); 10 | } 11 | 12 | } // end namespace py 13 | } // end namespace mlir 14 | -------------------------------------------------------------------------------- /lib/Python/Type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace mlir { 6 | namespace py { 7 | 8 | unsigned getIntOrFloatBitWidth(Type ty); 9 | 10 | } // end namespace py 11 | } // end namespace mlir 12 | -------------------------------------------------------------------------------- /lib/Python/Utility.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dmc/Python/Polymorphic.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | /// Shorthands.
10 | namespace mlir { 11 | namespace py { 12 | /// Container shorthands and their matching const-reference aliases. 13 | using StringList = std::vector; 14 | using ValueList = std::vector; 15 | using TypeList = std::vector; 16 | using AttrList = std::vector; 17 | using BlockList = std::vector; 18 | using AttrDict = std::unordered_map; 19 | 20 | using StringListRef = const StringList &; 21 | using ValueListRef = const ValueList &; 22 | using TypeListRef = const TypeList &; 23 | using AttrListRef = const AttrList &; 24 | using BlockListRef = const BlockList &; 25 | using AttrDictRef = const AttrDict &; 26 | 27 | } // end namespace py 28 | } // end namespace mlir 29 | 30 | /// Create a printer for MLIR objects to std::string. The object is printed through its print(raw_ostream &) member into a local buffer. 31 | template 32 | struct StringPrinter { 33 | std::string operator()(T t) const { 34 | std::string buf; 35 | llvm::raw_string_ostream os{buf}; 36 | t.print(os); 37 | return std::move(os.str()); // NOTE(review): raw_string_ostream::str() presumably flushes and returns buf by reference, so the move avoids a copy. 38 | } 39 | }; 40 | 41 | /// Cast to an overloaded function type. Returns fcn unchanged; presumably callers spell the desired function-pointer type as the template argument so that an overloaded name resolves to a single overload. 42 | template 43 | auto overload(FcnT fcn) { return fcn; } 44 | 45 | /// Move a value to the heap and let Python manage its lifetime. NOTE(review): the heap object is default-constructed and then move-assigned from t, so the type must be default-constructible and move-assignable; confirm callers only pass rvalues. 46 | template 47 | std::unique_ptr moveToHeap(T &&t) { 48 | auto ptr = std::make_unique(); 49 | *ptr = std::move(t); 50 | return ptr; 51 | } 52 | 53 | /// Automatically wrap function calls in a nullcheck of the primary argument. The wrapper throws std::invalid_argument (message: name followed by " is null") when the first argument tests false; otherwise it forwards all arguments to fcn. 54 | template 55 | std::function> 56 | nullcheck(FcnT fcn, std::string name, 57 | std::enable_if_t> * = 0) { 58 | return [fcn, name](auto t, auto ...ts) { 59 | if (!t) 60 | throw std::invalid_argument{name + " is null"}; 61 | return fcn(t, ts...); 62 | }; 63 | } 64 | 65 | /// Automatically wrap member function calls in a nullcheck of the object. Throws std::invalid_argument like the overload above when the object tests false; otherwise calls fcn on the object with the remaining arguments. 66 | template 67 | std::function 68 | nullcheck(RetT(ObjT::*fcn)(ArgTs...), std::string name) { 69 | return [fcn, name](auto t, ArgTs ...args) -> RetT { 70 | if (!t) 71 | throw std::invalid_argument{name + " is null"}; 72 | return (t.*fcn)(args...); 73 | }; 74 | } 75 | 76 | /// For const member functions (same null-check behavior as the overload above).
77 | template 78 | std::function 79 | nullcheck(RetT(ObjT::*fcn)(ArgTs...) const, std::string name) { 80 | return [fcn, name](auto t, ArgTs ...args) -> RetT { 81 | if (!t) 82 | throw std::invalid_argument{name + " is null"}; 83 | return (t.*fcn)(args...); 84 | }; 85 | } 86 | 87 | /// Create an isa<> check: returns a lambda that takes a From value and returns the result of calling its templated isa member. 88 | template 89 | auto isa() { 90 | return [](From f) { return f.template isa(); }; 91 | } 92 | 93 | /// Automatically generate implicit conversions to parent class with 94 | /// LLVM polymorphism: implicit conversion statements and constructors. 95 | namespace detail { 96 | /// Primary template (declared only); the specializations below handle one derived type, or peel off the first of several and recurse on the rest. 97 | template 98 | struct implicitly_convertible_from_all_helper; 99 | 100 | template 101 | struct implicitly_convertible_from_all_helper { 102 | template static void doit(ClassT &cls) { 103 | cls.def(pybind11::init()); 104 | pybind11::implicitly_convertible(); 105 | } 106 | }; 107 | 108 | template 109 | struct implicitly_convertible_from_all_helper { 110 | template static void doit(ClassT &cls) { 111 | implicitly_convertible_from_all_helper::doit(cls); 112 | implicitly_convertible_from_all_helper::doit(cls); 113 | } 114 | }; 115 | 116 | } // end namespace detail 117 | /// For each derived type, add a constructor from it to cls and register an implicit conversion (presumably from each DerivedT to BaseT — confirm). 118 | template 119 | void implicitly_convertible_from_all(pybind11::class_ &cls) { 120 | detail::implicitly_convertible_from_all_helper< 121 | BaseT, DerivedTs...>::doit(cls); 122 | } 123 | -------------------------------------------------------------------------------- /lib/Spec/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(DMCSpec 2 | SpecDialect.cpp 3 | SpecTypes.cpp 4 | SpecTypeDetail.cpp 5 | SpecOps.cpp 6 | SpecAttrs.cpp 7 | SpecRegion.cpp 8 | SpecSuccessor.cpp 9 | SpecTypeImplementation.cpp 10 | SpecAttrImplementation.cpp 11 | DialectGen.cpp 12 | ParameterList.cpp 13 | ReparseOpInterface.cpp 14 | OpReparsing.cpp 15 | Parsing.cpp 16 | OpType.cpp 17 | FormatOp.cpp 18 | ) 19 | target_link_libraries(DMCSpec 20 | MLIRIR 21 | MLIRTransforms 22 | ) 23 |
add_dependencies(DMCSpec 24 | DMCParameterListIncGen 25 | DMCReparseOpInterfaceIncGen 26 | DMCFormatOpIncGen 27 | ) 28 | -------------------------------------------------------------------------------- /lib/Spec/FormatOp.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/FormatOp.h" 2 | 3 | namespace mlir { 4 | #include "dmc/Spec/FormatOp.cpp.inc" 5 | } // end namespace mlir 6 | -------------------------------------------------------------------------------- /lib/Spec/ParameterList.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/ParameterList.h" 2 | #include "dmc/Spec/SpecAttrs.h" 3 | #include "dmc/Spec/Parsing.h" 4 | 5 | using namespace dmc; 6 | 7 | namespace mlir { 8 | 9 | #include "dmc/Spec/ParameterList.cpp.inc" 10 | 11 | namespace detail { 12 | struct NamedParameterStorage : public AttributeStorage { 13 | using KeyTy = std::pair; 14 | 15 | explicit NamedParameterStorage(StringRef name, Attribute constraint) 16 | : name{name}, constraint{constraint} {} 17 | 18 | bool operator==(const KeyTy &key) const { 19 | return key.first == name && key.second == constraint; 20 | } 21 | 22 | static llvm::hash_code hashKey(const KeyTy &key) { 23 | return llvm::hash_combine(key.first, key.second); 24 | } 25 | 26 | static NamedParameterStorage *construct(AttributeStorageAllocator &alloc, 27 | const KeyTy &key) { 28 | return new (alloc.allocate()) 29 | NamedParameterStorage{alloc.copyInto(key.first), key.second}; 30 | } 31 | 32 | StringRef name; 33 | Attribute constraint; 34 | }; 35 | } // end namespace detail 36 | 37 | NamedParameter NamedParameter::get(StringRef name, Attribute constraint) { 38 | return Base::get(constraint.getContext(), dmc::AttrKinds::NamedParameter, 39 | name, constraint); 40 | } 41 | 42 | NamedParameter NamedParameter::getChecked(Location loc, StringRef name, 43 | Attribute constraint) { 44 | return Base::getChecked(loc,
dmc::AttrKinds::NamedParameter, name, 45 | constraint); 46 | } 47 | 48 | LogicalResult NamedParameter::verifyConstructionInvariants( 49 | Location loc, StringRef name, Attribute constraint) { 50 | if (!dmc::SpecAttrs::is(constraint) && !constraint.isa()) 51 | return emitError(loc) << "expected a valid attribute constraint"; 52 | return success(); 53 | } 54 | 55 | StringRef NamedParameter::getName() const { 56 | return getImpl()->name; 57 | } 58 | 59 | Attribute NamedParameter::getConstraint() const { 60 | return getImpl()->constraint; 61 | } 62 | /// Parse an optional angle-bracketed, comma-separated list of `name : constraint` entries and append it to attrList as an array attribute (empty when no list is present). A type attribute constraint is wrapped in dmc::OfTypeAttr. Fails if any entry does not parse or is rejected by NamedParameter::getChecked. 63 | ParseResult ParameterList::parse(OpAsmParser &parser, 64 | NamedAttrList &attrList) { 65 | SmallVector params; 66 | if (succeeded(parser.parseOptionalLess())) { 67 | StringRef name; 68 | Attribute constraint; 69 | do { 70 | auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 71 | if (parser.parseKeyword(&name) || parser.parseColon() || 72 | dmc::impl::parseSingleAttribute(parser, constraint)) 73 | return failure(); 74 | if (auto type = constraint.dyn_cast()) 75 | constraint = dmc::OfTypeAttr::get(type.getValue()); 76 | auto param = NamedParameter::getChecked(loc, name, constraint); if (!param) return failure(); params.push_back(param); // getChecked emits an error and yields a null attribute for an invalid constraint; never store a null parameter. 77 | } while (succeeded(parser.parseOptionalComma())); 78 | if (parser.parseGreater()) 79 | return failure(); 80 | } 81 | attrList.append(Trait::getParametersAttrName(), 82 | parser.getBuilder().getArrayAttr(params)); 83 | return success(); 84 | } 85 | 86 | } // end namespace mlir 87 | -------------------------------------------------------------------------------- /lib/Spec/ReparseOpInterface.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/ReparseOpInterface.h" 2 | 3 | namespace mlir { 4 | #include "dmc/Spec/ReparseOpInterface.cpp.inc" 5 | } // end namespace mlir 6 | -------------------------------------------------------------------------------- /lib/Spec/SpecAttrImplementation.cpp:
-------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecAttrSwitch.h" 2 | 3 | using namespace mlir; 4 | 5 | namespace dmc { 6 | 7 | namespace SpecAttrs { 8 | /// True if base's kind lies in the half-open range [Any, LAST_SPEC_ATTR). 9 | bool is(Attribute base) { 10 | return Any <= base.getKind() && base.getKind() < LAST_SPEC_ATTR; 11 | } 12 | /// Check attr against base: a plain (non-constraint) attribute must compare equal, a spec constraint is dispatched through the kind switch. 13 | LogicalResult delegateVerify(Attribute base, Attribute attr) { 14 | /// If not an attribute constraint, do a direct comparison. 15 | if (!is(base)) 16 | return success(base == attr); 17 | /// Use the switch table. 18 | VerifyAction action{attr}; 19 | return SpecAttrs::kindSwitch(action, base); 20 | } 21 | 22 | } // end namespace SpecAttrs 23 | 24 | namespace impl { 25 | /// Verify that op's attribute named attr.first satisfies the constraint attr.second, emitting an op error otherwise. A missing attribute is accepted only when the constraint passes the isa check below (presumably an optional-attribute constraint — confirm). 26 | LogicalResult verifyAttribute(Operation *op, NamedAttribute attr) { 27 | auto opAttr = op->getAttr(attr.first); 28 | if (!opAttr) { 29 | /// A missing optional attribute is okay. 30 | if (attr.second.isa()) 31 | return success(); 32 | return op->emitOpError("missing attribute '") << attr.first << '\''; 33 | } 34 | auto baseAttr = attr.second; 35 | if (SpecAttrs::is(baseAttr)) { 36 | if (failed(SpecAttrs::delegateVerify(baseAttr, opAttr))) 37 | return op->emitOpError("attribute '") << attr.first << '\'' 38 | << ", which is '" << opAttr << "', failed to satisfy '" 39 | << baseAttr << '\''; 40 | } else if (baseAttr != opAttr) 41 | return op->emitOpError("expected attribute '") << attr.first << '\'' 42 | << " to be '" << baseAttr << "' but got '" << opAttr << '\''; 43 | return success(); 44 | } 45 | /// Verify every (name, constraint) entry of opAttrs against op, stopping at the first failure. 46 | LogicalResult verifyAttrConstraints( 47 | Operation *op, mlir::DictionaryAttr opAttrs) { 48 | for (auto &attr : opAttrs.getValue()) { 49 | if (failed(verifyAttribute(op, attr))) 50 | return failure(); 51 | } 52 | return success(); 53 | } 54 | 55 | } // end namespace impl 56 | 57 | } // end namespace dmc 58 | -------------------------------------------------------------------------------- /lib/Spec/SpecSuccessor.cpp:
-------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecSuccessor.h" 2 | #include "dmc/Spec/SpecSuccessorSwitch.h" 3 | #include "dmc/Spec/Parsing.h" 4 | 5 | #include 6 | 7 | using namespace mlir; 8 | 9 | namespace dmc { 10 | 11 | namespace SpecSuccessor { 12 | /// True if base's kind lies in the half-open range [Any, LAST_SPEC_SUCCESSOR). 13 | bool is(Attribute base) { 14 | return Any <= base.getKind() && base.getKind() < LAST_SPEC_SUCCESSOR; 15 | } 16 | /// Verify block against the successor constraint base via the kind switch. NOTE(review): unlike SpecAttrs/SpecTypes::delegateVerify there is no fallback when base is not a successor constraint; confirm callers guarantee SpecSuccessor::is(base). 17 | LogicalResult delegateVerify(Attribute base, Block *block) { 18 | VerifyAction action{block}; 19 | return SpecSuccessor::kindSwitch(action, base); 20 | } 21 | /// Render a successor constraint to a string using impl::printOpSuccessor. 22 | std::string toString(Attribute opSucc) { 23 | std::string ret; 24 | llvm::raw_string_ostream os{ret}; 25 | impl::printOpSuccessor(os, opSucc); 26 | return std::move(os.str()); 27 | } 28 | 29 | } // end namespace SpecSuccessor 30 | 31 | /// VariadicSuccessor. 32 | VariadicSuccessor VariadicSuccessor::getChecked(Location loc, 33 | Attribute succConstraint) { 34 | return Base::getChecked(loc, SpecSuccessor::Variadic, succConstraint); 35 | } 36 | 37 | LogicalResult VariadicSuccessor::verifyConstructionInvariants( 38 | Location loc, Attribute succConstraint) { 39 | if (!SpecSuccessor::is(succConstraint)) 40 | return emitError(loc) << "expected a valid successor constraint"; 41 | return success(); 42 | } 43 | 44 | LogicalResult VariadicSuccessor::verify(Block *block) { 45 | return SpecSuccessor::delegateVerify(getImpl()->attr, block); 46 | } 47 | 48 | /// Parsing 49 | Attribute AnySuccessor::parse(OpAsmParser &parser) { 50 | return get(parser.getBuilder().getContext()); 51 | } 52 | /// Parse the nested successor constraint between angle brackets. Returns a null Attribute if parsing fails; getChecked rejects a nested attribute that is not a successor constraint. 53 | Attribute VariadicSuccessor::parse(OpAsmParser &parser) { 54 | Attribute opSucc; 55 | auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 56 | if (parser.parseLess() || impl::parseOpSuccessor(parser, opSucc) || 57 | parser.parseGreater()) 58 | return {}; 59 | return getChecked(loc, opSucc); 60 | } 61 | 62 | void AnySuccessor::print(llvm::raw_ostream &os) { 63 | os << getName(); 64 | }
65 | 66 | void VariadicSuccessor::print(llvm::raw_ostream &os) { 67 | os << getName() << '<'; 68 | impl::printOpSuccessor(os, getImpl()->attr); 69 | os << '>'; 70 | } 71 | 72 | } // end namespace dmc 73 | -------------------------------------------------------------------------------- /lib/Spec/SpecTypeDetail.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecTypeDetail.h" 2 | 3 | #include 4 | 5 | using namespace mlir; 6 | 7 | namespace dmc { 8 | namespace impl { 9 | 10 | LogicalResult verifyIntWidth(Location loc, unsigned width) { 11 | switch (width) { 12 | case 1: 13 | case 8: 14 | case 16: 15 | case 32: 16 | case 64: 17 | return success(); 18 | default: 19 | return emitError(loc) << "integer width must be one of " 20 | << "[1, 8, 16, 32, 64]"; 21 | } 22 | } 23 | 24 | LogicalResult verifyFloatWidth(Location loc, unsigned width) { 25 | switch (width) { 26 | case 16: 27 | case 32: 28 | case 64: 29 | return success(); 30 | default: 31 | return emitError(loc) << "float width must be one of [16, 32, 64]"; 32 | } 33 | } 34 | 35 | LogicalResult verifyFloatType(unsigned width, Type ty) { 36 | switch (width) { 37 | case 16: 38 | return success(ty.isF16()); 39 | case 32: 40 | return success(ty.isF32()); 41 | case 64: 42 | return success(ty.isF64()); 43 | default: 44 | llvm_unreachable("Invalid floating point width"); 45 | return failure(); 46 | } 47 | } 48 | /// Verify that widths is non-empty, contains no duplicates, and that every entry is accepted by verifyWidth. 49 | LogicalResult verifyWidthList( 50 | Location loc, ArrayRef widths, 51 | LogicalResult (&verifyWidth)(Location, unsigned)) { 52 | /// Check empty list. 53 | if (widths.empty()) 54 | return emitError(loc) << "empty width list"; 55 | /// Check for duplicate values. 56 | std::unordered_set widthSet{std::begin(widths), std::end(widths)}; 57 | if (std::size(widthSet) != std::size(widths)) 58 | return emitError(loc) << "duplicate widths in width list"; 59 | /// Verify individual widths.
60 | for (auto width : widths) 61 | if (failed(verifyWidth(loc, width))) 62 | return failure(); 63 | return success(); 64 | } 65 | 66 | } // end namespace impl 67 | } // end namespace dmc 68 | -------------------------------------------------------------------------------- /lib/Spec/SpecTypeImplementation.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecTypeSwitch.h" 2 | #include "dmc/Traits/SpecTraits.h" 3 | 4 | #include 5 | 6 | using namespace mlir; 7 | 8 | namespace dmc { 9 | namespace SpecTypes { 10 | 11 | bool is(Type base) { 12 | return Any <= base.getKind() && base.getKind() < LAST_SPEC_TYPE; 13 | } 14 | 15 | LogicalResult delegateVerify(Type base, Type ty) { 16 | /// If not a type constraint, use a direct comparison. 17 | if (!is(base)) 18 | return success(base == ty); 19 | /// Use the switch table. 20 | VerifyAction action{ty}; 21 | return SpecTypes::kindSwitch(action, base); 22 | } 23 | 24 | } // end namespace SpecTypes 25 | 26 | /// Type verification. 27 | namespace impl { 28 | /// Check each type in tys against the constraint at the same position in baseTys; name labels the diagnostic (e.g. "operand" or "result"). 29 | template 30 | LogicalResult verifyTypeRange(Operation *op, OpTypeRange baseTys, 31 | TypeRange tys, StringRef name) { 32 | auto firstTy = std::begin(tys), tyEnd = std::end(tys); 33 | auto tyIt = firstTy; 34 | for (auto baseIt = std::begin(baseTys), baseEnd = std::end(baseTys); 35 | baseIt != baseEnd || tyIt != tyEnd; ++tyIt, ++baseIt) { 36 | /// Number of operands and results are verified by previous traits. If that invariant is ever broken, the assert catches it in debug builds and the check after it keeps release builds (assert compiled out) from dereferencing an end iterator.
37 | assert(baseIt != baseEnd && tyIt != tyEnd); if (baseIt == baseEnd || tyIt == tyEnd) return op->emitOpError() << "has a mismatched number of " << name << " types"; 38 | if (failed(SpecTypes::delegateVerify(*baseIt, *tyIt))) 39 | return op->emitOpError() << name << " #" << std::distance(firstTy, tyIt) 40 | << " must be " << *baseIt << " but got " << *tyIt; 41 | } 42 | return success(); 43 | } 44 | /// Check value group i (fetched with getValues) against the constraint at position i in baseTys; valIdx numbers values across all groups for diagnostics. 45 | template 46 | LogicalResult verifyVariadicTypes(Operation *op, OpTypeRange baseTys, 47 | GetValueGroup getValues, const char *name) { 48 | unsigned groupIdx = 0, valIdx = 0; 49 | auto values = getValues(op, groupIdx); 50 | for (auto tyIt = std::begin(baseTys), tyEnd = std::end(baseTys); 51 | tyIt != tyEnd || std::begin(values) != std::end(values); 52 | ++tyIt, values = getValues(op, ++groupIdx)) { 53 | assert(tyIt != tyEnd); 54 | assert(std::begin(values) != std::end(values) || 55 | (*tyIt).template isa()); 56 | for (auto valIt = std::begin(values), valEnd = std::end(values); 57 | valIt != valEnd; ++valIt, ++valIdx) { 58 | // TODO custom type descriptions with dynamic types 59 | auto valType = (*valIt).getType(); 60 | if (failed(SpecTypes::delegateVerify(*tyIt, valType))) 61 | return op->emitOpError() << name << " #" << valIdx << " must be " 62 | << *tyIt << " but got " << valType; 63 | } 64 | } 65 | return success(); 66 | } 67 | 68 | template 70 | LogicalResult verifyGroupTypes( 71 | Operation *op, TypeRange types, DynamicOperation *info, GetAllFcn getAll, 72 | const char *name) { 73 | if (info->getTrait()) { 74 | return verifyVariadicTypes(op, types, SizedT::getGroup, name); 75 | } else if (info->getTrait()) { 76 | return verifyVariadicTypes(op, types, SameT::getGroup, name); 77 | } else { 78 | return verifyTypeRange(op, types, (op->*getAll)(), name); 79 | } 80 | } 81 | 82 | LogicalResult 83 | verifyOperandTypes(Operation *op, OpType opTy, DynamicOperation *info) { 84 | return verifyGroupTypes( 85 | op, opTy.getOperandTypes(), info, &Operation::getOperandTypes, "operand"); 86 | } 87 | 88 | LogicalResult 89 | verifyResultTypes(Operation *op, OpType opTy, DynamicOperation *info) {
90 | return verifyGroupTypes( 91 | op, opTy.getResultTypes(), info, &Operation::getResultTypes, "result"); 92 | } 93 | /// Verify operand types, then result types; result types are not checked when operand verification fails. 94 | LogicalResult verifyTypeConstraints(Operation *op, OpType opTy) { 95 | auto *info = DynamicOperation::of(op); 96 | return failure(failed(verifyOperandTypes(op, opTy, info)) || 97 | failed(verifyResultTypes(op, opTy, info))); 98 | } 99 | 100 | } // end namespace impl 101 | 102 | } // end namespace dmc 103 | -------------------------------------------------------------------------------- /lib/Spec/SpecTypeParsing.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecTypes.h" 2 | #include "dmc/Spec/SpecDialect.h" 3 | -------------------------------------------------------------------------------- /lib/Traits/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(DMCTraits 2 | SpecTraits.cpp 3 | Registry.cpp 4 | OpTrait.cpp 5 | StandardTraits.cpp 6 | ) 7 | target_link_libraries(DMCTraits 8 | DMCDynamic 9 | DMCSpec 10 | MLIRIR 11 | ) 12 | -------------------------------------------------------------------------------- /lib/Traits/GenericConstructor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dmc/Traits/Registry.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace dmc { 10 | 11 | /// Helpers for unpacking generic argument arrays. 12 | namespace detail { 13 | 14 | /// Unpack a typed argument at an index. 15 | template 16 | auto unpackArg(llvm::ArrayRef args) { 17 | return args[I].cast(); 18 | } 19 | 20 | /// Call the function with the given arguments. 21 | template 22 | auto callFcn(FcnT fcn, llvm::ArrayRef args, 23 | std::integer_sequence) { 24 | return fcn(unpackArg(args)...); 25 | } 26 | 27 | /// Check that an argument at an index is the correct type.
28 | template 29 | mlir::LogicalResult checkArgType(mlir::Location loc, 30 | llvm::ArrayRef args) { 31 | if (!args[I].isa()) 32 | return mlir::emitError(loc) << "trait constructor expected " 33 | << typeid(ArgT).name() << " for argument #" << I << " but got " 34 | << args[I]; 35 | return mlir::success(); 36 | } 37 | 38 | /// AND a variadic list of LogicalResults. Every argument is evaluated before the fold runs, so each failing check emits its own diagnostic. The non-template base case must be `inline`: this header is included from other files (e.g. Registry.cpp), and a non-inline definition would be multiply defined once two translation units include it. 39 | inline mlir::LogicalResult andFold() { return mlir::success(); } 40 | 41 | template 42 | mlir::LogicalResult andFold(mlir::LogicalResult first, BoolT... vals) { 43 | return mlir::success(mlir::succeeded(first) && 44 | mlir::succeeded(andFold(vals...))); 45 | } 46 | 47 | /// Check that the provided argument array has the correct signature. 48 | template 49 | auto checkFcnSignature(mlir::Location loc, llvm::ArrayRef args, 50 | std::integer_sequence) { 51 | return andFold(checkArgType(loc, args)...); 52 | } 53 | 54 | } // end namespace detail 55 | /// Type-erased wrapper around a trait constructor taking ArgTs: verifySignature checks a generic argument array (count, then per-argument type), and callConstructor unpacks the array and calls the stored constructor. 56 | template 57 | class GenericConstructor { 58 | public: 59 | using ConstructorT = Trait (*)(ArgTs...); 60 | using Indices = std::make_integer_sequence; 61 | 62 | GenericConstructor(ConstructorT ctor) : ctor(ctor) {} 63 | 64 | mlir::LogicalResult verifySignature( 65 | mlir::Location loc, llvm::ArrayRef args) const { 66 | if (llvm::size(args) != sizeof...(ArgTs)) 67 | return mlir::emitError(loc) << "expected " << sizeof...(ArgTs) 68 | << " arguments to trait constructor but got " << llvm::size(args); 69 | return detail::checkFcnSignature(loc, args, Indices{}); 70 | } 71 | 72 | Trait callConstructor(llvm::ArrayRef args) const { 73 | return detail::callFcn(ctor, args, Indices{}); 74 | } 75 | 76 | private: 77 | ConstructorT ctor; 78 | }; 79 | 80 | } // end namespace dmc 81 | -------------------------------------------------------------------------------- /lib/Traits/Registry.cpp: -------------------------------------------------------------------------------- 1 | #include "GenericConstructor.h" 2 | #include "dmc/Traits/Registry.h" 3 | #include "dmc/Traits/OpTrait.h" 4 | #include
"dmc/Traits/StandardTraits.h" 5 | #include "dmc/Traits/SpecTraits.h" 6 | 7 | using namespace mlir; 8 | 9 | namespace dmc { 10 | namespace { 11 | 12 | /// Unwrap values stored inside attributes. 13 | template struct unwrap { 14 | auto operator()(const T &t) const { return t.getValue(); } 15 | }; 16 | 17 | template <> struct unwrap { 18 | auto operator()(IntegerAttr val) const { return val.getValue().getZExtValue(); } 19 | }; 20 | 21 | template <> struct unwrap { 22 | auto operator()(Attribute val) const { return val; } 23 | }; 24 | 25 | /// Get a trait constructor's signature from a function type. 26 | template struct TraitSignature; 27 | template 28 | struct TraitSignature {}; 29 | 30 | /// Generic trait constructor registration. 31 | template 32 | void registerTrait(TraitRegistry *reg, TraitSignature) { 33 | GenericConstructor ctorObj{[](ArgTs... args) -> Trait { 34 | return std::make_unique(unwrap{}(args)...); 35 | }}; 36 | TraitConstructor traitCtor{ 37 | [ctorObj](Location loc, ArrayRef args) { 38 | return ctorObj.verifySignature(loc, args); 39 | }, 40 | [ctorObj](ArrayRef args) { 41 | return ctorObj.callConstructor(args); 42 | } 43 | }; 44 | reg->registerTrait(TraitT::getName(), std::move(traitCtor)); 45 | } 46 | 47 | template 48 | void registerTrait(TraitRegistry *reg) { 49 | registerTrait(reg, TraitSignature{}); 50 | } 51 | 52 | template 53 | void registerTraits(TraitRegistry *reg) { 54 | (void) std::initializer_list{0, (registerTrait(reg), 0)...}; 55 | } 56 | 57 | } // end anonymous namespace 58 | 59 | TraitRegistry::TraitRegistry(MLIRContext *ctx) 60 | : Dialect{getDialectNamespace(), ctx, TypeID::get()} { 61 | addAttributes< 62 | OpTraitAttr, OpTraitsAttr 63 | >(); 64 | registerTraits< 65 | IsTerminator(), IsCommutative(), IsIsolatedFromAbove(), 66 | MemoryAlloc(), MemoryFree(), MemoryRead(), MemoryWrite(), 67 | Alloc(Attribute), Free(Attribute), ReadFrom(Attribute), WriteTo(Attribute), 68 | NoSideEffects(), 69 | 70 | LoopLike(StringAttr, StringAttr, 
StringAttr), 71 | 72 | OperandsAreFloatLike(), OperandsAreSignlessIntegerLike(), 73 | ResultsAreBoolLike(), ResultsAreFloatLike(), 74 | ResultsAreSignlessIntegerLike(), 75 | 76 | SameOperandsShape(), SameOperandsAndResultShape(), 77 | SameOperandsElementType(), SameOperandsAndResultElementType(), 78 | SameOperandsAndResultType(), SameTypeOperands(), 79 | 80 | SameVariadicOperandSizes(), SameVariadicResultSizes(), 81 | SizedOperandSegments(), SizedResultSegments(), 82 | 83 | NOperands(IntegerAttr), AtLeastNOperands(IntegerAttr), 84 | NRegions(IntegerAttr), AtLeastNRegions(IntegerAttr), 85 | NResults(IntegerAttr), AtLeastNResults(IntegerAttr), 86 | NSuccessors(IntegerAttr), AtLeastNSuccessors(IntegerAttr), 87 | 88 | HasParent(StringAttr), SingleBlockImplicitTerminator(StringAttr) 89 | >(this); 90 | } 91 | 92 | void TraitRegistry::registerTrait(StringRef name, TraitConstructor &&getter) { 93 | auto [it, inserted] = traitRegistry.try_emplace( 94 | name, std::forward(getter)); 95 | assert(inserted && "Trait has already been registered"); 96 | } 97 | 98 | TraitConstructor TraitRegistry::lookupTrait(StringRef name) { 99 | auto it = traitRegistry.find(name); 100 | if (it == std::end(traitRegistry)) 101 | return nullptr; 102 | return it->second; 103 | } 104 | 105 | } // end namespace dmc 106 | -------------------------------------------------------------------------------- /lib/Traits/StandardTraits.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Traits/StandardTraits.h" 2 | 3 | using namespace mlir; 4 | 5 | namespace dmc { 6 | 7 | LogicalResult HasParent::verifyOp(Operation *op) const { 8 | auto parentOpName = op->getParentOp()->getName().getStringRef(); 9 | if (parentOpName == parentName) 10 | return success(); 11 | return op->emitOpError() << "expects parent op '" << parentName 12 | << "' but got '" << parentOpName << '\''; 13 | } 14 | 15 | LogicalResult SingleBlockImplicitTerminator::verifyOp(Operation *op) const { 
16 | /// Each region should have zero or one block, and the block must be 17 | /// terminated by `terminatorName` op. 18 | auto regions = op->getRegions(); 19 | unsigned idx{}; 20 | for (auto it = std::begin(regions), e = std::end(regions); it != e; 21 | ++it, ++idx) { 22 | auto ®ion = *it; 23 | if (region.empty()) 24 | continue; 25 | 26 | if (std::next(std::begin(region)) != std::end(region)) 27 | return op->emitOpError("expects region #") << idx 28 | << " to have 0 or 1 blocks"; 29 | 30 | auto &block = region.front(); 31 | if (block.empty()) 32 | return op->emitOpError("expects a non-empty block"); 33 | 34 | auto *term = block.getTerminator(); 35 | if (term->getName().getStringRef() == terminatorName) 36 | continue; 37 | 38 | return op->emitOpError("expects regions to end with '") << terminatorName 39 | << "' but found '" << term->getName() << "'"; 40 | } 41 | return success(); 42 | } 43 | 44 | static auto unpackTargets(Attribute targets) { 45 | std::vector ret; 46 | if (auto name = targets.dyn_cast()) { 47 | ret.push_back(name.getValue()); 48 | } else { 49 | for (auto target : targets.cast()) 50 | ret.push_back(target.cast().getValue()); 51 | } 52 | return ret; 53 | } 54 | 55 | Alloc::Alloc(Attribute targets) 56 | : ValueMemoryEffect{unpackTargets(targets)} {} 57 | Free::Free(Attribute targets) 58 | : ValueMemoryEffect{unpackTargets(targets)} {} 59 | ReadFrom::ReadFrom(Attribute targets) 60 | : ValueMemoryEffect{unpackTargets(targets)} {} 61 | WriteTo::WriteTo(Attribute targets) 62 | : ValueMemoryEffect{unpackTargets(targets)} {} 63 | 64 | } // end namespace dmc 65 | -------------------------------------------------------------------------------- /list_targets: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | if [ ! -d bin ]; then 3 | ./configure 4 | fi 5 | 6 | cd bin && cmake --build . 
--target help 7 | -------------------------------------------------------------------------------- /lua/.gitignore: -------------------------------------------------------------------------------- 1 | *.ll 2 | *.o 3 | *.s 4 | main 5 | libc.mlir 6 | test.mlir 7 | main.mlir 8 | libnone.mlir 9 | test-markov 10 | test-perf 11 | a.out 12 | -------------------------------------------------------------------------------- /lua/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | #set(LUAC_EXE $) 2 | #set(MLIR_TRANSLATE_EXE $) 3 | #set(CLANG_EXE $) 4 | # Generate the Python Lua parser (parser/*.py) from Lua.g4 with ANTLR; build it through the lua-parser target. 5 | find_program(ANTLR NAMES antlr antlr4 REQUIRED) 6 | add_custom_command( 7 | OUTPUT parser/Lua.interp 8 | parser/Lua.tokens 9 | parser/LuaLexer.interp 10 | parser/LuaLexer.py 11 | parser/LuaLexer.tokens 12 | parser/LuaListener.py 13 | parser/LuaParser.py 14 | COMMAND ${ANTLR} -Dlanguage=Python3 -o parser ${CMAKE_CURRENT_SOURCE_DIR}/Lua.g4 15 | DEPENDS Lua.g4 16 | COMMENT "Generating Lua parser from ANTLR" 17 | ) 18 | add_custom_target(lua-parser DEPENDS parser/LuaParser.py) 19 | -------------------------------------------------------------------------------- /lua/Makefile: -------------------------------------------------------------------------------- 1 | CFLAGS=-Ofast -g -flto 2 | FILE=fannkuch.lua 3 | 4 | main: main.o impl.o builtins.o 5 | clang++ main.o impl.o builtins.o -o main $(CFLAGS) 6 | 7 | builtins.o: builtins.cpp lib.h 8 | clang++ -c -std=c++17 builtins.cpp -o builtins.o $(CFLAGS) 9 | 10 | impl.o: impl.cpp lib.h 11 | clang++ -c -std=c++17 impl.cpp -o impl.o $(CFLAGS) 12 | 13 | main.s: mainopt.ll 14 | clang -S mainopt.ll $(CFLAGS) -o main.s 15 | 16 | main.o: mainopt.ll 17 | clang -c mainopt.ll -o main.o $(CFLAGS) 18 | 19 | mainopt.ll: main.ll 20 | clang -S -emit-llvm main.ll -o mainopt.ll $(CFLAGS) 21 | 22 | main.ll: main.mlir 23 | mlir-translate -mlir-to-llvmir main.mlir -o main.ll 24 | 25 | main.mlir: luac.py $(FILE) lua.mlir lib.mlir 26 | python3 luac.py $(FILE)
> main.mlir 27 | .PHONY: clean 28 | clean: 29 | rm -f *.o 30 | rm -f *.ll 31 | rm -f main.mlir main main.s 32 | -------------------------------------------------------------------------------- /lua/binarytree.lua: -------------------------------------------------------------------------------- 1 | -- The Computer Language Benchmarks Game 2 | -- http://benchmarksgame.alioth.debian.org/ 3 | -- contributed by Mike Pall 4 | -- Builds complete binary trees of {item, left, right} tables and checksums them with ItemCheck. 5 | function BottomUpTree(item, depth) 6 | if depth > 0 then 7 | local i = item + item 8 | depth = depth - 1 9 | local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) 10 | return { item, left, right } 11 | else 12 | return { item } 13 | end 14 | end 15 | 16 | function ItemCheck(tree) 17 | if tree[2] then 18 | return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) 19 | else 20 | return tree[1] 21 | end 22 | end 23 | 24 | local mindepth = 4 25 | local maxdepth = 17 26 | 27 | local stretchdepth = maxdepth + 1 28 | local stretchtree = BottomUpTree(0, stretchdepth) 29 | print("stretch tree of depth", stretchdepth, "check:", ItemCheck(stretchtree)) 30 | 31 | local longlivedtree = BottomUpTree(0, maxdepth) 32 | 33 | for depth=mindepth,maxdepth,2 do 34 | local iterations = 2 ^ (maxdepth - depth + mindepth) 35 | local check = 0 36 | for i=1,iterations do 37 | check = check + ItemCheck(BottomUpTree(1, depth)) + 38 | ItemCheck(BottomUpTree(-1, depth)) 39 | end 40 | print(iterations*2, "trees of depth", depth, "check:", check) 41 | end 42 | 43 | print("long lived tree of depth", maxdepth, "check:", ItemCheck(longlivedtree)) 44 | -------------------------------------------------------------------------------- /lua/fannkuch.lua: -------------------------------------------------------------------------------- 1 | -- The Computer Language Benchmarks Game 2 | -- https://salsa.debian.org/benchmarksgame-team/benchmarksgame/ 3 | -- contributed by Mike Pall 4 | -- fannkuch(n) returns (checksum, maximum flip count) over all permutations of 1..n. 5 | local function fannkuch(n) 6 | local p, q, s, sign, maxflips, sum = {}, {}, {}, 1, 0, 0 7 | for i=1,n do p[i] = i;
q[i] = i; s[i] = i end 8 | repeat 9 | -- Copy and flip. 10 | local q1 = p[1] -- Cache 1st element. 11 | if q1 ~= 1 then 12 | for i=2,n do q[i] = p[i] end -- Work on a copy. 13 | local flips = 1 14 | local exit = false 15 | repeat 16 | local qq = q[q1] 17 | if qq == 1 then -- ... until 1st element is 1. 18 | sum = sum + sign*flips 19 | if flips > maxflips then maxflips = flips end -- New maximum? 20 | exit = true 21 | else 22 | q[q1] = q1 23 | if q1 >= 4 then 24 | local i, j = 2, q1 - 1 25 | repeat q[i], q[j] = q[j], q[i]; i = i + 1; j = j - 1; until i >= j 26 | end 27 | q1 = qq; flips = flips + 1 28 | end 29 | until exit 30 | end 31 | -- Permute. 32 | if sign == 1 then 33 | p[2], p[1] = p[1], p[2]; sign = -1 -- Rotate 1<-2. 34 | else 35 | p[2], p[3] = p[3], p[2]; sign = 1 -- Rotate 1<-2 and 1<-2<-3. 36 | local i = 3 37 | local loop = true 38 | while i <= n and loop do 39 | local sx = s[i] 40 | if sx ~= 1 then 41 | s[i] = sx-1 42 | loop = false 43 | else 44 | if i == n then 45 | return sum, maxflips 46 | end -- Out of permutations. 47 | s[i] = i 48 | -- Rotate 1<-...<-i+1.
49 | local t = p[1] 50 | for j=1,i do 51 | p[j] = p[j+1] 52 | end 53 | p[i+1] = t 54 | i = i + 1 55 | end 56 | end 57 | end 58 | until false 59 | end 60 | 61 | local n = 10 62 | local sum, flips = fannkuch(n) 63 | print(sum) 64 | print("Pfannkuchen(", n, ") = ", flips) 65 | -------------------------------------------------------------------------------- /lua/impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "lib.h" 4 | 5 | #include 6 | 7 | namespace lua { 8 | 9 | std::string &as_std_string(TObject *val); 10 | 11 | } // end namespace lua 12 | -------------------------------------------------------------------------------- /lua/lib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | /******************************************************************************* 11 | * Definitions 12 | ******************************************************************************/ 13 | 14 | enum { 15 | NIL, 16 | BOOL, 17 | NUM, 18 | STR, 19 | TBL, 20 | FCN, 21 | INT = 10 22 | }; 23 | 24 | struct Object; 25 | 26 | typedef struct Pack { 27 | int32_t size; 28 | struct Object *objs; 29 | } TPack; 30 | 31 | typedef struct Object *TCapture; 32 | 33 | typedef TPack (*lua_fcn_t)(TCapture, TPack); 34 | 35 | typedef struct Closure { 36 | lua_fcn_t addr; 37 | TCapture capture; 38 | } TClosure; 39 | 40 | typedef struct Object { 41 | int32_t type; 42 | union { 43 | int64_t u; 44 | 45 | bool b; 46 | double num; 47 | void *impl; 48 | }; 49 | } TObject; 50 | 51 | 52 | /******************************************************************************* 53 | * Simple Value Manipulation 54 | ******************************************************************************/ 55 | 56 | TObject lua_nil(void); 57 | 58 | TObject lua_alloc(void); 59 | void lua_copy(TObject *ptr, TObject val); 60 | 61 | int32_t
lua_get_type(TObject val); 62 | void lua_set_type(TObject *ptr, int32_t ty); 63 | 64 | bool lua_get_bool_val(TObject val); 65 | void lua_set_bool_val(TObject *ptr, bool b); 66 | 67 | double lua_get_double_val(TObject val); 68 | void lua_set_double_val(TObject *ptr, double fp); 69 | 70 | lua_fcn_t lua_get_fcn_addr(TObject val); 71 | void lua_set_fcn_addr(TObject val, lua_fcn_t fcn_addr); 72 | 73 | TCapture lua_get_capture_pack(TObject val); 74 | void lua_set_capture_pack(TObject val, TCapture capture); 75 | 76 | uint64_t lua_get_value_union(TObject val); 77 | void lua_set_value_union(TObject *ptr, uint64_t u); 78 | 79 | /******************************************************************************* 80 | * Pack Manipulation 81 | ******************************************************************************/ 82 | 83 | TCapture lua_new_capture(int32_t size); 84 | void lua_add_capture(TCapture capture, TObject *ptr, int32_t idx); 85 | 86 | TPack lua_get_ret_pack(int32_t size); 87 | TPack lua_get_arg_pack(int32_t size); 88 | 89 | void lua_pack_insert(TPack pack, TObject val, int32_t idx); 90 | void lua_pack_insert_all(TPack pack, TPack tail, int32_t idx); 91 | TObject lua_pack_get(TPack pack, int32_t idx); 92 | int32_t lua_pack_get_size(TPack pack); 93 | 94 | /******************************************************************************* 95 | * Tables and Strings 96 | ******************************************************************************/ 97 | 98 | TObject lua_new_table(void); 99 | void lua_table_set(TObject tbl, TObject key, TObject val); 100 | TObject lua_table_get(TObject tbl, TObject key); 101 | void lua_table_set(TObject tbl, TObject key, TObject val); 102 | 103 | TObject lua_list_size(TObject tbl); 104 | TObject lua_load_string(const char *data, uint64_t len); 105 | 106 | #ifdef __cplusplus 107 | } 108 | #endif 109 | -------------------------------------------------------------------------------- /lua/loops.lua: 
-------------------------------------------------------------------------------- 1 | a = 0 2 | b = 0 3 | for i=5,1000000,2 do 4 | a = (a + 5) * 1 5 | 6 | for j=1,10000,3 do 7 | b = b + i + j 8 | end 9 | 10 | end 11 | 12 | print(a) 13 | print(b) 14 | -------------------------------------------------------------------------------- /lua/markov.in: -------------------------------------------------------------------------------- 1 | the more we try the more we do 2 | -------------------------------------------------------------------------------- /lua/markov.lua: -------------------------------------------------------------------------------- 1 | -- Markov Chain Program in Lua 2 | function allwords () 3 | local line = io.read() -- current line 4 | local pos = 1 -- current position in the line 5 | return function () -- iterator function 6 | while line do -- repeat while there are lines 7 | local s, e = string.find(line, "%w+", pos) 8 | if s then -- found a word? 9 | pos = e + 1 -- update next position 10 | return string.sub(line, s, e) -- return the word 11 | else 12 | line = io.read() -- word not found; try next line 13 | pos = 1 -- restart from first position 14 | end 15 | end 16 | return nil -- no more lines: end of traversal 17 | end 18 | end 19 | 20 | function prefix (w1, w2) 21 | return w1 .. ' ' ..
w2 22 | end 23 | 24 | local statetab 25 | 26 | function insert (index, value) 27 | if not statetab[index] then 28 | statetab[index] = {n=0} 29 | end 30 | table.insert(statetab[index], value) 31 | end 32 | 33 | local N = 2 34 | local MAXGEN = 10000 35 | local NOWORD = "\n" 36 | 37 | -- build table 38 | statetab = {} 39 | local w1, w2 = NOWORD, NOWORD 40 | for w in allwords() do 41 | insert(prefix(w1, w2), w) 42 | w1 = w2; w2 = w; 43 | end 44 | insert(prefix(w1, w2), NOWORD) 45 | 46 | -- generate text 47 | w1 = NOWORD; w2 = NOWORD -- reinitialize 48 | for i=1,MAXGEN do 49 | local list = statetab[prefix(w1, w2)] 50 | -- choose a random item from list 51 | local r = math.random(#list) 52 | local nextword = list[r] 53 | if nextword == NOWORD then 54 | print() 55 | return 56 | end 57 | io.write(nextword, " ") 58 | w1 = w2; w2 = nextword 59 | end 60 | -------------------------------------------------------------------------------- /lua/preliminary-results.txt: -------------------------------------------------------------------------------- 1 | clang flags: -Ofast -flto 2 | 3 | binarytree N = 18 4 | luajit -jon -O3: 9.0 sec 5 | luajit -joff -O3: 12.8 sec 6 | luac -Oall: 7.7 sec 7 | luac -Onone: 9.3 sec 8 | lua: 45.1 sec 9 | 10 | fannkuch N = 12 11 | luajit -jon -O3: 42.4 sec 12 | luajit -joff -O3: 184.0 sec 13 | luac -Oall: 22.3 sec 14 | luac -Onone: 166.0 sec 15 | lua: 459.4 sec 16 | -------------------------------------------------------------------------------- /lua/test.c: -------------------------------------------------------------------------------- 1 | #include "lib.h" 2 | #include 3 | #include 4 | 5 | int main() { 6 | int a = 10; 7 | for (int i = 10; i < 100; ++i) { 8 | a = a * 2 + a; // NOTE(review): a at least triples on each of the 90 iterations, so it overflows int; signed overflow is UB and the printed value is not well-defined 9 | int g = a; 10 | int q; 11 | if (i == 55 && g > 22) { 12 | q = 44; 13 | } else { 14 | q = a + 44; 15 | } 16 | a += q; 17 | } 18 | printf("%d\n", a); 19 | } 20 | -------------------------------------------------------------------------------- /lua/test.lua:
-------------------------------------------------------------------------------- 1 | i = {} 2 | i[4] = "hi" 3 | print(i[4]) 4 | print(i[5], "end") 5 | -------------------------------------------------------------------------------- /oec/.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.s 3 | *.ll 4 | *.lowered 5 | main 6 | __pycache__ 7 | -------------------------------------------------------------------------------- /oec/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(open-earth-compiler) 2 | 3 | check_language(CUDA) 4 | if (CMAKE_CUDA_COMPILER) 5 | enable_language(CUDA) 6 | else () 7 | message(SEND_ERROR "OEC requires CUDA") 8 | endif () 9 | find_library(CUDA_RUNTIME_LIBRARY cuda) 10 | 11 | pybind11_add_module(dl_stencil dl_stencil.cpp) 12 | target_include_directories(dl_stencil PUBLIC 13 | ${Python3_INCLUDE_DIRS} 14 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 15 | ) 16 | target_link_libraries(dl_stencil PUBLIC 17 | ${Python3_LIBRARIES} 18 | ${CUDA_RUNTIME_LIBRARY} 19 | pybind11 20 | cuda-runtime-wrappers 21 | ) 22 | -------------------------------------------------------------------------------- /oec/Makefile: -------------------------------------------------------------------------------- 1 | main: main.o laplace.o 2 | clang++ main.o laplace.o -o main -fPIE -L$(LD_LIBRARY_PATH) -lcuda-runtime-wrappers -lcuda 3 | 4 | main.o: main.cpp 5 | clang++ -O3 -c main.cpp -o main.o -fPIE -I/opt/cuda/include -g 6 | 7 | laplace.o: laplace.ll 8 | clang -c laplace.ll -o laplace.o -fPIE 9 | 10 | laplace.ll: laplace.mlir.lowered 11 | mlir-translate --mlir-to-llvmir laplace.mlir.lowered -o laplace.ll 12 | 13 | laplace.mlir.lowered: laplace.mlir 14 | oec-opt --stencil-shape-inference --convert-stencil-to-std --cse --parallel-loop-tiling='parallel-loop-tile-sizes=128,1,1' --canonicalize --test-gpu-greedy-parallel-loop-mapping
--convert-parallel-loops-to-gpu --canonicalize --lower-affine --convert-scf-to-std --stencil-kernel-to-cubin laplace.mlir > laplace.mlir.lowered 15 | 16 | clean: 17 | rm -f *.s 18 | rm -f *.o 19 | rm -f *.ll 20 | rm -f *.lowered 21 | -------------------------------------------------------------------------------- /oec/dl_stencil.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | static void cuda_init() { 11 | static CUdevice device; 12 | static CUcontext context; 13 | static bool inited{false}; 14 | 15 | if (!inited) { 16 | inited = true; 17 | cuInit(0); 18 | cuDeviceGet(&device, 0); 19 | cuCtxCreate(&context, 0, device); 20 | } 21 | } 22 | 23 | namespace py = pybind11; 24 | 25 | // TODO f32 26 | struct stencil_t { 27 | double *allocatedPtr; 28 | double *alignedPtr; 29 | int32_t offset; 30 | int32_t sizes[3]; 31 | int32_t strides[3]; 32 | }; 33 | 34 | using stencil_fcn_t = void (*)(stencil_t *, stencil_t *); 35 | 36 | static std::size_t compute_mem_size(py::buffer_info &info) { 37 | std::size_t size = info.itemsize; 38 | for (py::ssize_t i = 0; i < info.ndim; ++i) { 39 | size *= info.shape[i]; 40 | } 41 | return size; 42 | } 43 | 44 | extern "C" { 45 | void mgpuMemAlloc(CUdeviceptr *ptr, uint64_t size); 46 | void mgpuMemFree(CUdeviceptr ptr); 47 | } 48 | 49 | // Owns one allocation made with mgpuMemAlloc and frees it exactly once, also 50 | // when an exception unwinds out of the bound stencil call. 51 | struct device_buffer { 52 | CUdeviceptr ptr{}; 53 | explicit device_buffer(std::size_t size) { mgpuMemAlloc(&ptr, size); } 54 | ~device_buffer() { if (ptr) { mgpuMemFree(ptr); } } 55 | device_buffer(const device_buffer &) = delete; 56 | device_buffer &operator=(const device_buffer &) = delete; 57 | }; 58 | 59 | // Narrows a numpy extent/stride to the int32_t fields of stencil_t, refusing 60 | // values that would silently wrap. 61 | static int32_t to_i32(py::ssize_t v) { 62 | if (v < INT32_MIN || v > INT32_MAX) { 63 | throw std::runtime_error{"array extent or stride does not fit in 32 bits"}; 64 | } 65 | return static_cast<int32_t>(v); 66 | } 67 | 68 | // True when the buffer is densely packed in row-major (C) order, which the 69 | // flat host<->device copies below assume. 70 | static bool is_c_contiguous(const py::buffer_info &info) { 71 | py::ssize_t expected = info.itemsize; 72 | for (py::ssize_t i = info.ndim; i-- > 0;) { 73 | if (info.shape[i] != 1 && info.strides[i] != expected) { 74 | return false; 75 | } 76 | expected *= info.shape[i]; 77 | } 78 | return true; 79 | } 80 | 81 | // Describes a device copy of a 3D f64 buffer as a memref descriptor. numpy 82 | // strides are in bytes; memref descriptor strides are counted in elements. 83 | static stencil_t make_stencil(CUdeviceptr mem, const py::buffer_info &info) { 84 | auto *data = reinterpret_cast<double *>(mem); 85 | stencil_t desc{data, data, 0, {}, {}}; 86 | for (int d = 0; d < 3; ++d) { 87 | desc.sizes[d] = to_i32(info.shape[d]); 88 | desc.strides[d] = to_i32(info.strides[d] / info.itemsize); 89 | } 90 | return desc; 91 | } 92 | 93 | // Throws if a CUDA driver API call did not return CUDA_SUCCESS. 94 | static void check_cuda(CUresult result, const char *what) { 95 | if (result != CUDA_SUCCESS) { 96 | throw std::runtime_error{std::string{what} + " failed"}; 97 | } 98 | } 99 |
100 | /// Opens `dl_name` and wraps its MLIR C-interface entry point 101 | /// `_mlir_ciface_<sym_name>` as a callable taking (input, output) buffers. 102 | /// Returns an empty function if the library or symbol cannot be resolved. 103 | static std::function 104 | bind_stencil(std::string sym_name, std::string dl_name) { 105 | dlerror(); // Clear stale error state so the messages below are accurate. 106 | void *handle = dlopen(dl_name.c_str(), RTLD_LAZY | RTLD_NODELETE); 107 | if (handle == nullptr) { 108 | std::cerr << "dlopen(" << dl_name << ") error: " << dlerror() << std::endl; 109 | return nullptr; 110 | } 111 | std::string ciface_sym = "_mlir_ciface_" + sym_name; 112 | dlerror(); 113 | void *fcn_handle = dlsym(handle, ciface_sym.c_str()); 114 | if (fcn_handle == nullptr) { 115 | const char *err = dlerror(); 116 | std::cerr << "dlsym(" << ciface_sym << ") error: " 117 | << (err ? err : "symbol resolved to null") << std::endl; 118 | return nullptr; 119 | } 120 | stencil_fcn_t stencil_fcn = reinterpret_cast(fcn_handle); 121 | // TODO more than one input/output 122 | return [stencil_fcn](py::buffer input, py::buffer output) { 123 | py::buffer_info input_info = input.request(); 124 | // The result is copied back into `output`, so it must be writable. 125 | py::buffer_info output_info = output.request(/*writable=*/true); 126 | 127 | if (input_info.ndim != 3) { 128 | throw std::runtime_error{"incompatible input shape: expected 3D array"}; 129 | } 130 | if (output_info.ndim != 3) { 131 | throw std::runtime_error{"incompatible output shape: expected 3D array"}; 132 | } 133 | 134 | if (input_info.format != py::format_descriptor::format()) { 135 | throw std::runtime_error{"incompatible input format: expected f64"}; 136 | } 137 | if (output_info.format != py::format_descriptor::format()) { 138 | throw std::runtime_error{"incompatible output format: expected f64"}; 139 | } 140 | 141 | if (!is_c_contiguous(input_info)) { 142 | throw std::runtime_error{"incompatible input layout: expected C-contiguous array"}; 143 | } 144 | if (!is_c_contiguous(output_info)) { 145 | throw std::runtime_error{"incompatible output layout: expected C-contiguous array"}; 146 | } 147 | 148 | std::size_t input_mem_size = compute_mem_size(input_info); 149 | std::size_t output_mem_size = compute_mem_size(output_info); 150 | 151 | // Each buffer is released exactly once by its destructor on every exit path. 152 | device_buffer input_mem{input_mem_size}; 153 | device_buffer output_mem{output_mem_size}; 154 | 155 | check_cuda(cuMemcpyHtoD(input_mem.ptr, input_info.ptr, input_mem_size), "cuMemcpyHtoD(input)"); 156 | // Seed the device output with the host contents so cells outside the stencil's 157 | // store bounds come back unchanged rather than with whatever the fresh allocation held. 158 | check_cuda(cuMemcpyHtoD(output_mem.ptr, output_info.ptr, output_mem_size), "cuMemcpyHtoD(output)"); 159 | 160 | stencil_t input_stencil = make_stencil(input_mem.ptr, input_info); 161 | stencil_t output_stencil = make_stencil(output_mem.ptr, output_info); 162 | 163 | stencil_fcn(&input_stencil, &output_stencil); 164 | 165 | check_cuda(cuMemcpyDtoH(output_info.ptr, output_mem.ptr, output_mem_size), "cuMemcpyDtoH(output)"); 166 | }; 167 | } 168 | 169 | PYBIND11_MODULE(dl_stencil, m) { 170 | m.doc() = "Stencil Dynamic Library Binding"; 171 | 172 | m.def("cuda_init", &cuda_init); 173 | m.def("bind_stencil", &bind_stencil); 174 | } 175 | -------------------------------------------------------------------------------- /oec/laplace.mlir: --------------------------------------------------------------------------------
1 | 2 | 3 | module { 4 | func @laplace(%arg0: !stencil.field, %arg1: !stencil.field) attributes {stencil.program} { 5 | stencil.assert %arg0([-4, -4, -4] : [68, 68, 68]) : !stencil.field 6 | stencil.assert %arg1([-4, -4, -4] : [68, 68, 68]) : !stencil.field 7 | %0 = stencil.load %arg0 : (!stencil.field) -> !stencil.temp 8 | %1 = stencil.apply (%arg2 = %0 : !stencil.temp) -> !stencil.temp { 9 | %2 = stencil.access %arg2 [-1, 0, 0] : (!stencil.temp) -> f64 10 | %3 = stencil.access %arg2 [1, 0, 0] : (!stencil.temp) -> f64 11 | %4 = stencil.access %arg2 [0, 1, 0] : (!stencil.temp) -> f64 12 | %5 = stencil.access %arg2 [0, -1, 0] : (!stencil.temp) -> f64 13 | %6 = stencil.access %arg2 [0, 0, 0] : (!stencil.temp) -> f64 14 | %7 = addf %2, %3 : f64 15 | %8 = addf %4, %5 : f64 16 | %9 = addf %7, %8 : f64 17 | %cst = constant -4.000000e+00 : f64 18 | %10 = mulf %6, %cst : f64 19 | %11 = addf %10, %9 : f64 20 | stencil.return %11 : f64 21 | } 22 | stencil.store %1 to %arg1([0, 0, 0] : [64, 64, 64]) : !stencil.temp to !stencil.field 23 | return 24 | } 25 | 26 | func @fill(%arg0: memref<72x72x72xf64>) { 27 | %c0 = constant 0 : index 28 | %c1 = constant 1 : index 29 | %c72 = constant 72 : index 30 | scf.parallel (%i, %j, %k) = (%c0, %c0, %c0) to (%c72, %c72, %c72) step (%c1, %c1, %c1) { 31 | %0 = index_cast %i : index to i64 32 | %1 = index_cast %j : index to i64 33 | %2 = index_cast %k : index to i64 34 | %3 = sitofp %0 : i64 to f64 35 | %4 = sitofp %1 : i64 to f64 36 | %5 = sitofp %2 : i64 to f64 37 | %6 = addf %3, %4 : f64 38 | %7 = addf %6, %5 : f64 39 | store %7, %arg0[%i, %j, %k] : memref<72x72x72xf64> 40 | scf.yield 41 | } 42 | return 43 | } 44 | } 45 |
-------------------------------------------------------------------------------- /oec/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | extern "C" { 10 | typedef struct { 11 | double *allocatedPtr; 12 | double *alignedPtr; 13 | int32_t offset; 14 | int32_t sizes[3]; 15 | int32_t strides[3]; 16 | } MemRefType3D; 17 | 18 | void mgpuMemAlloc(CUdeviceptr *ptr, uint64_t size); 19 | void mgpuMemFree(CUdeviceptr ptr); 20 | 21 | void _mlir_ciface_laplace(MemRefType3D *input, MemRefType3D *output); 22 | void _mlir_ciface_fill(MemRefType3D *inout); 23 | double _mlir_ciface_get(MemRefType3D *input, int32_t i, int32_t j, int32_t k); 24 | } 25 | 26 | int main() { 27 | cuInit(0); 28 | CUdevice device; 29 | cuDeviceGet(&device, 0); 30 | CUcontext context; 31 | cuCtxCreate(&context, 0, device); 32 | 33 | constexpr int32_t dim = 72; 34 | constexpr auto mem_size = dim * dim * dim * sizeof(double); 35 | constexpr auto mem_align = sizeof(double); 36 | 37 | CUdeviceptr input_mem{}, output_mem{}; 38 | std::size_t input_space = mem_size + mem_align; 39 | std::size_t output_space = input_space; 40 | mgpuMemAlloc(&input_mem, input_space); 41 | mgpuMemAlloc(&output_mem, output_space); 42 | 43 | auto *input_ptr = (void *) input_mem; 44 | if (!std::align(mem_align, mem_size, input_ptr, input_space)) { 45 | std::cout << "Failed to align input memory" << std::endl; 46 | return -1; 47 | } 48 | 49 | auto *output_ptr = (void *) output_mem; 50 | if (!std::align(mem_align, mem_size, output_ptr, output_space)) { 51 | std::cout << "Failed to align output memory" << std::endl; 52 | return -1; 53 | } 54 | 55 | MemRefType3D input{ (double *) input_mem, (double *) input_ptr, 0, 56 | { dim, dim, dim }, 57 | { dim * dim, dim, 58 | 1 }}; // memref descriptor strides count elements, not bytes 59 | 60 | MemRefType3D output{ (double *) output_mem, (double *) output_ptr, 0, 61 | {
dim, dim, dim }, 62 | { dim * dim, dim, 63 | 1 }}; // memref descriptor strides count elements, not bytes 64 | 65 | _mlir_ciface_fill(&input); 66 | _mlir_ciface_laplace(&input, &output); 67 | 68 | mgpuMemFree(input_mem); 69 | mgpuMemFree(output_mem); 70 | return 0; 71 | } 72 | -------------------------------------------------------------------------------- /oec/test.py: -------------------------------------------------------------------------------- 1 | import stencil 2 | import numpy as np 3 | 4 | @stencil.program 5 | def laplace(a, b): 6 | stencil.cast(a, [-4, -4, -4], [68, 68, 68]) 7 | stencil.cast(b, [-4, -4, -4], [68, 68, 68]) 8 | atmp = stencil.load(a) 9 | 10 | def applyFcn(c) -> float: 11 | return c[0, 0, 0] + c[-1, 0, 0] + c[1, 0, 0] + c[0, 1, 0] + c[0, -1, 0] 12 | 13 | btmp = stencil.apply(atmp, applyFcn) 14 | stencil.store(b, btmp, [0, 0, 0], [64, 64, 64]) 15 | return 16 | 17 | a = np.empty([72, 72, 72], dtype='d') 18 | b = np.empty([72, 72, 72], dtype='d') 19 | a.fill(3) 20 | laplace(a, b) 21 | laplace(b, a) 22 | laplace(a, b) 23 | laplace(b, a) 24 | laplace(a, b) 25 | laplace(b, a) 26 | laplace(a, b) 27 | print(b) 28 | print(b[32,32,32]) 29 | -------------------------------------------------------------------------------- /spec/dialect.mlir: -------------------------------------------------------------------------------- 1 | dmc.Dialect @test { 2 | dmc.Op @op_a(arg0 : !dmc.AnyInteger, arg1 : !dmc.AnyOf, !dmc.AnyFloat>) -> (ret0 : !dmc.UI<32>) 3 | { attr0 = #dmc.APInt } 4 | dmc.Op @op_b(arg0 : !dmc.AnyFloat, arg1 : !dmc.F<16>) -> (ret0 : !dmc.BF16, ret1 : !dmc.SI<32>) 5 | { attr1 = #dmc.Bool } 6 | dmc.Op @my_ret(arg0 : !dmc.AnyInteger, arg1 : !dmc.Variadic) -> () 7 | { attr2 = #dmc.Optional<#dmc.Bool> } 8 | traits [@SameVariadicOperandSizes, @AtLeastNOperands<1>, @IsTerminator] 9 | 10 | dmc.Type @CustomType 11 | dmc.Op @op_c(arg0 : !test.CustomType) -> (ret0 : !test.CustomType) 12 | traits [@HasParent<"func">] 13 | 14 | dmc.Type @Array2D 15 | dmc.Alias
@IsArray2D -> !dmc.Isa<@test::@Array2D> 16 | dmc.Op @transpose(arg0 : !test.IsArray2D) -> (ret0 : !test.IsArray2D) 17 | 18 | dmc.Op @get_value() -> (value : !dmc.Any) 19 | 20 | dmc.Attr @CustomAttr 21 | dmc.Attr @Pair 22 | dmc.Alias @IsPair -> #dmc.Isa<@test::@Pair> 23 | dmc.Alias @IsCustomAttr -> #dmc.Isa<@test::@CustomAttr> 24 | dmc.Op @op_d() -> () { attr3 = #test.IsPair, 25 | attr4 = #test.IsCustomAttr } 26 | 27 | dmc.Attr @Box 28 | dmc.Attr @CustomPair, second: #dmc.Isa<@test::@Box>> 29 | 30 | dmc.Type @BoxType> 31 | dmc.Op @op_e(arg0 : !test.BoxType<#test.Box<6>>) -> () 32 | 33 | dmc.Op @op_regions() -> () {} (r0 : Any, r1 : Sized<2>, Rs : Variadic) 34 | dmc.Op @ret() -> () traits [@IsTerminator] 35 | 36 | dmc.Alias @IsInteger -> !dmc.Py<"isinstance({self}, IntegerType)"> 37 | dmc.Alias @ArraySize3 -> #dmc.Py<"isinstance({self}, ArrayAttr) and len({self}) == 3"> 38 | dmc.Op @op_py(arg0 : !test.IsInteger) -> () { index = #test.ArraySize3 } 39 | 40 | dmc.Op @op_succ() -> () [s0 : Any, s1 : Any, Ss : Variadic] traits [@IsTerminator] 41 | } 42 | -------------------------------------------------------------------------------- /spec/laplace.mlir: -------------------------------------------------------------------------------- 1 | func @laplace(%arg0: !stencil.field, %arg1: !stencil.field) attributes {stencil.program} { 2 | stencil.assert %arg0([-4, -4, -4] : [68, 68, 68]) : !stencil.field 3 | stencil.assert %arg1([-4, -4, -4] : [68, 68, 68]) : !stencil.field 4 | %0 = stencil.load %arg0 : (!stencil.field) -> !stencil.temp 5 | %1 = stencil.apply(%0) : (!stencil.temp) -> !stencil.temp (%arg2: !stencil.temp) { 6 | %2 = stencil.access %arg2 [-1, 0, 0] : (!stencil.temp) -> f64 7 | %3 = stencil.access %arg2 [1, 0, 0] : (!stencil.temp) -> f64 8 | %4 = stencil.access %arg2 [0, 1, 0] : (!stencil.temp) -> f64 9 | %5 = stencil.access %arg2 [0, -1, 0] : (!stencil.temp) -> f64 10 | %6 = stencil.access %arg2 [0, 0, 0] : (!stencil.temp) -> f64 11 | %7 = addf %2, %3 : f64 12 | %8 =
addf %4, %5 : f64 13 | %9 = addf %7, %8 : f64 14 | %cst = constant -4.000000e+00 : f64 15 | %10 = mulf %6, %cst : f64 16 | %11 = addf %10, %9 : f64 17 | stencil.return %11 : f64 18 | } 19 | stencil.store %1 to %arg1([0, 0, 0] : [64, 64, 64]) : !stencil.temp to !stencil.field 20 | return 21 | } 22 | -------------------------------------------------------------------------------- /spec/stencil.mlir: -------------------------------------------------------------------------------- 1 | Dialect @stencil { 2 | /// Base constraints. 3 | Alias @Shape -> #dmc.AllOf<#dmc.Array, #dmc.ArrayOf<#dmc.APInt>> 4 | Alias @ArrayCount3 -> #dmc.Py<"isinstance({self}, ArrayAttr) and len({self}) == 3"> 5 | 6 | /// Stencil types: FieldType and TempType, both subclass GridType. 7 | Type @field 8 | { fmt = "`<` dims($shape) $type `>`" } 9 | Type @temp 10 | { fmt = "`<` dims($shape) $type `>`" } 11 | Alias @Field -> !dmc.Isa<@stencil::@field> 12 | Alias @Temp -> !dmc.Isa<@stencil::@temp> 13 | 14 | /// Element type and index attribute constraints. 15 | Alias @None -> !dmc.None { builder = "NoneType()" } 16 | Alias @Element -> !dmc.AnyOf 17 | Alias @Index -> #dmc.AllOf<#dmc.ArrayOf<#dmc.APInt>, #stencil.ArrayCount3> 18 | { type = !stencil.None } 19 | Alias @OptionalIndex -> #dmc.Optional<#stencil.Index> 20 | { type = !stencil.None } 21 | 22 | /// AssertOp 23 | Op @assert(field: !stencil.Field) -> () { lb = #stencil.Index, 24 | ub = #stencil.Index } 25 | config { fmt = "$field `(` $lb `:` $ub `)` attr-dict-with-keyword `:` type($field)" } 26 | 27 | /// AccessOp 28 | Op @access(temp: !stencil.Temp) -> (res: !stencil.Element) 29 | { offset = #stencil.Index } 30 | config { fmt = "$temp $offset attr-dict-with-keyword `:` functional-type($temp, $res)" } 31 | 32 | /// LoadOp 33 | Op @load(field: !stencil.Field) -> (res: !stencil.Temp) 34 | { lb = #stencil.OptionalIndex, ub = #stencil.OptionalIndex } 35 | config { fmt = "$field (`(` $lb^ `:` $ub `)`)?
attr-dict-with-keyword `:` functional-type($field, $res)" } 36 | 37 | /// StoreOp 38 | Op @store(temp: !stencil.Temp, field: !stencil.Field) -> () 39 | { lb = #stencil.Index, ub = #stencil.Index } 40 | config { fmt = "$temp `to` $field `(` $lb `:` $ub `)` attr-dict-with-keyword `:` type($temp) `to` type($field)" } 41 | 42 | /// ApplyOp 43 | Op @apply(operands: !dmc.Variadic) -> (res: !dmc.Variadic) 44 | { lb = #stencil.OptionalIndex, ub = #stencil.OptionalIndex } 45 | (region: Sized<1>) 46 | traits [@SameVariadicOperandSizes, @SameVariadicResultSizes, 47 | @SingleBlockImplicitTerminator<"stencil.return">] 48 | config { is_isolated_from_above = true, 49 | fmt = "`(` $operands `)` `:` functional-type($operands, $res) attr-dict-with-keyword $region (`to` `(` $lb^ `:` $ub `)`)?" } 50 | 51 | /// ReturnOp 52 | Op @return(operands: !dmc.Variadic) -> () 53 | { unroll = #stencil.OptionalIndex } 54 | traits [@SameVariadicOperandSizes, @HasParent<"stencil.apply">, @IsTerminator] 55 | config { fmt = "(`unroll` $unroll^)?
$operands attr-dict-with-keyword `:` type($operands)" } 56 | } 57 | -------------------------------------------------------------------------------- /spec/test.mlir: -------------------------------------------------------------------------------- 1 | module { // Exercises the ops, types and attributes declared by dmc.Dialect @test in spec/dialect.mlir. 2 | func @test0(%arg0 : i32, %arg1 : i64, %arg2 : f16) -> (bf16) { 3 | %0 = "test.op_a"(%arg0, %arg2) { attr0 = 6 } : (i32, f16) -> ui32 4 | %1, %2 = "test.op_b"(%arg2, %arg2) { attr1 = true } : (f16, f16) -> (bf16, si32) 5 | "test.my_ret"(%2, %0, %1, %arg1) { attr2 = false } : (si32, ui32, bf16, i64) -> () 6 | } 7 | func @test1() -> !test.Array2D<3,2> { 8 | %0 = "test.get_value"() : () -> !test.Array2D<5,3> 9 | %1 = "test.get_value"() : () -> i32 10 | %2 = "test.transpose"(%0) : (!test.Array2D<5,3>) -> !test.Array2D<3,5> 11 | "test.my_ret"(%1, %2) : (i32, !test.Array2D<3,5>) -> () 12 | } 13 | 14 | func @test2() -> i32 { 15 | %0 = "test.get_value"() : () -> !test.CustomType 16 | %1 = "test.op_c"(%0) : (!test.CustomType) -> !test.CustomType 17 | %2 = "test.get_value"() : () -> i32 18 | "test.my_ret"(%2) : (i32) -> () 19 | } 20 | 21 | func @test3() -> i32 { 22 | %0 = "test.get_value"() : () -> i32 23 | "test.op_d"() { attr3 = #test.Pair<1, 2>, attr4 = #test.CustomAttr } : () -> () 24 | %1 = "test.get_value"() : () -> !test.BoxType<#test.Box<6>> 25 | "test.op_e"(%1) : (!test.BoxType<#test.Box<6>>) -> () 26 | "test.my_ret"(%0) 27 | { attrUnknown = #test.CustomPair<#test.Box<3>, #test.Box<400>>, 28 | attrTraits = #trait.OpTraits<[#trait.OpTrait<@IsTerminator<1,2>>]>} 29 | : (i32) -> () 30 | } 31 | 32 | func @test4() -> () { 33 | %0 = "test.get_value"() : () -> i32 34 | "test.op_regions"() ({ 35 | "test.ret"() : () -> () 36 | }, { 37 | "test.ret"() : () -> () 38 | ^bb1: 39 | "test.ret"() : () -> () 40 | }, { 41 | "test.ret"() : () -> () 42 | }, { 43 | "test.ret"() : () -> () 44 | ^bb1: 45 | "test.ret"() : () -> () 46 | }): () -> () 47 | "test.ret"() : () -> () 48 | } 49 | 50 | func @test5() -> () { 51 | %0 =
"test.get_value"() : () -> i32 52 | "test.op_py"(%0) { index = [1, 2, 3] } : (i32) -> () 53 | "test.ret"() : () -> () 54 | } 55 | 56 | func @test6() -> () { 57 | ^bb0: 58 | "test.op_succ"() [^bb1, ^bb2] : () -> () 59 | ^bb1: 60 | "test.op_succ"() [^bb2, ^bb2, ^bb2] : () -> () 61 | ^bb2: 62 | "test.ret"() : () -> () 63 | } 64 | 65 | } 66 | -------------------------------------------------------------------------------- /tools/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(spec spec.cpp) 2 | target_link_libraries(spec 3 | DMCDynamic 4 | DMCIO 5 | DMCSpec 6 | DMCTraits 7 | DMCEmbed 8 | LLVMSupport 9 | MLIRParser 10 | DMCEmbedInit 11 | ) 12 | 13 | add_executable(gen dialectgen.cpp) 14 | target_link_libraries(gen 15 | DMCSpec 16 | DMCDynamic 17 | DMCTraits 18 | DMCEmbed 19 | LLVMSupport 20 | MLIRStandardOps 21 | MLIRParser 22 | MLIRLLVMIR 23 | MLIRSCFToStandard 24 | MLIRStandardToLLVM 25 | DMCEmbedInit 26 | ) 27 | -------------------------------------------------------------------------------- /tools/dialectgen.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecDialect.h" 2 | #include "dmc/Spec/DialectGen.h" 3 | #include "dmc/Traits/Registry.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | using namespace mlir; 21 | using namespace llvm; 22 | using namespace dmc; 23 | 24 | static DialectRegistration specDialectRegistration; 25 | static DialectRegistration registerTraits; 26 | static DialectRegistration registerStdOps; 27 | static DialectRegistration registerScfOps; 28 | static DialectRegistration registerLlvmOps; 29 | 30 | int main(int argc, char *argv[]) { 31 | if (argc != 3) { 32 | llvm::errs() << "Usage: gen \n"; 33 | return -1; 34 | } 35 | 36 | MLIRContext ctx; 37 | auto *dynCtx =
ctx.getOrCreateDialect(); 38 | 39 | SourceMgr dialectSrcMgr; 40 | SourceMgrDiagnosticHandler dialectDiag{dialectSrcMgr, &ctx}; 41 | auto dialectModule = mlir::parseSourceFile(argv[1], dialectSrcMgr, &ctx); 42 | if (!dialectModule) { 43 | llvm::errs() << "Failed to load dialect module: " << argv[1] << "\n"; 44 | return -1; 45 | } 46 | if (failed(verify(*dialectModule))) { 47 | llvm::errs() << "Failed to verify dialect module: " << argv[1] << "\n"; 48 | return -1; 49 | } 50 | // Register the dynamic dialects declared in argv[1] before argv[2] is parsed (argv[2] presumably uses them). 51 | if (failed(registerAllDialects(*dialectModule, dynCtx))) { 52 | llvm::errs() << "Failed to register dynamic dialects\n"; 53 | return -1; 54 | } 55 | 56 | SourceMgr mlirSrcMgr; 57 | SourceMgrDiagnosticHandler mlirDiag{mlirSrcMgr, &ctx}; 58 | auto mlirModule = mlir::parseSourceFile(argv[2], mlirSrcMgr, &ctx); 59 | if (!mlirModule) { 60 | llvm::errs() << "Failed to load MLIR module: " << argv[2] << "\n"; 61 | return -1; 62 | } 63 | if (failed(verify(*mlirModule))) { 64 | llvm::errs() << "Failed to verify MLIR module: " << argv[2] << "\n"; 65 | return -1; 66 | } 67 | 68 | mlirModule->print(llvm::outs()); 69 | llvm::outs() << "\n"; 70 | 71 | return 0; 72 | } 73 | -------------------------------------------------------------------------------- /tools/spec.cpp: -------------------------------------------------------------------------------- 1 | #include "dmc/Spec/SpecDialect.h" 2 | #include "dmc/Spec/DialectGen.h" 3 | #include "dmc/Traits/Registry.h" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace mlir; 16 | using namespace llvm; 17 | using namespace dmc; 18 | 19 | static DialectRegistration specDialectRegistration; 20 | static DialectRegistration registerTraits; 21 | 22 | int main(int argc, char *argv[]) { 23 | if (argc != 2) { 24 | llvm::errs() << "Usage: spec \n"; 25 | return -1; 26 | } 27 | 28 | MLIRContext ctx; 29 | auto *dynCtx = ctx.getOrCreateDialect(); 30 | SourceMgr srcMgr; 31 |
SourceMgrDiagnosticHandler srcMgrDiagHandler{srcMgr, &ctx}; 32 | auto mlirModule = mlir::parseSourceFile(argv[1], srcMgr, &ctx); 33 | if (!mlirModule) { 34 | llvm::errs() << "Failed to load MLIR file: " << argv[1] << "\n"; 35 | return -1; 36 | } 37 | if (failed(verify(*mlirModule))) { 38 | llvm::errs() << "Failed to verify MLIR module: " << argv[1] << "\n"; 39 | return -1; 40 | } 41 | mlirModule->print(llvm::outs()); 42 | llvm::outs() << "\n"; 43 | // Register the dialects declared in the spec module; this happens only after the module has been printed. 44 | if (failed(registerAllDialects(*mlirModule, dynCtx))) { 45 | llvm::errs() << "Failed to register dynamic dialects\n"; 46 | return -1; 47 | } 48 | } 49 | --------------------------------------------------------------------------------