├── .clang-format ├── .github ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── ci.yml │ └── docs.yml ├── .gitignore ├── LICENSE ├── README.md ├── Setup.hs ├── docs ├── code-of-conduct.md └── contributing.md ├── mlir-hs.cabal ├── src └── MLIR │ ├── AST.hs │ ├── AST │ ├── Builder.hs │ ├── Dialect │ │ ├── Affine.hs │ │ ├── Arith.hs │ │ ├── ControlFlow.hs │ │ ├── Func.hs │ │ ├── LLVM.hs │ │ ├── Linalg.hs │ │ ├── MemRef.hs │ │ ├── Shape.hs │ │ ├── Tensor.hs │ │ ├── Vector.hs │ │ └── X86Vector.hs │ ├── IStorableArray.hs │ ├── PatternUtil.hs │ ├── Rewrite.hs │ └── Serialize.hs │ ├── Native.hs │ └── Native │ ├── ExecutionEngine.hs │ ├── FFI.hs │ └── Pass.hs ├── stack.yaml ├── tblgen ├── hs-generators.cc └── mlir-hs-tblgen.cc └── test ├── MLIR ├── ASTSpec.hs ├── BuilderSpec.hs ├── NativeSpec.hs ├── RewriteSpec.hs └── Test │ └── Generators.hs └── Spec.hs /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | AlwaysBreakTemplateDeclarations: Yes 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Expected Behavior 2 | 3 | 4 | ## Actual Behavior 5 | 6 | 7 | ## Steps to Reproduce the Problem 8 | 9 | 1. 10 | 1. 11 | 1. 12 | 13 | ## Specifications 14 | 15 | - Version: 16 | - Platform: -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Fixes # 2 | 3 | > It's a good idea to open an issue first for discussion. 
4 | 5 | - [ ] Tests pass 6 | - [ ] Appropriate changes to README are included in PR -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Haskell CI 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | types: [ opened, synchronize ] 9 | schedule: 10 | # Always regenerate once every 4 hour 11 | - cron: '15 */4 * * *' 12 | workflow_dispatch: 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - name: Checkout the repository 20 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 21 | 22 | - name: Setup Haskell Stack 23 | id: setup-haskell 24 | uses: haskell-actions/setup@f7d8a55550ba6c8e4fdba2f1e56e14f595218dd9 # v2.5.1 25 | with: 26 | enable-stack: true 27 | stack-no-global: true 28 | stack-version: 'latest' 29 | 30 | - name: Cache .stack-work 31 | uses: actions/cache@704facf57e6136b1bc63b828d79edcd491f0ee84 # v3.3.2 32 | with: 33 | path: .stack-work 34 | key: stack-work-${{ runner.os }}-${{ hashFiles('stack.yaml', '**/*.cabal') }}-${{ hashFiles('src/*', 'tblgen/*', 'test/*') }} 35 | restore-keys: | 36 | stack-work-${{ runner.os }}-${{ hashFiles('stack.yaml', '**/*.cabal') }}- 37 | stack-work-${{ runner.os }}- 38 | 39 | - name: Cache ~/.stack 40 | uses: actions/cache@704facf57e6136b1bc63b828d79edcd491f0ee84 # v3.3.2 41 | with: 42 | path: ${{ steps.setup-haskell.outputs.stack-root }} 43 | key: stack-root-${{ runner.os }}-${{ hashFiles('stack.yaml', '**/*.cabal') }} 44 | restore-keys: stack-root-${{ runner.os }}- 45 | 46 | - name: Install zstd 47 | run: sudo apt install libzstd-dev 48 | 49 | - name: Install Ninja 50 | uses: llvm/actions/install-ninja@55d844821959226fab4911f96f37071c1d4c3268 51 | 52 | - name: Clone LLVM repo 53 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 54 | with: 55 | repository: 
llvm/llvm-project 56 | ref: 'main' 57 | path: 'llvm-project' 58 | 59 | - name: Ccache for C++ compilation 60 | uses: hendrikmuhs/ccache-action@6d1841ec156c39a52b1b23a810da917ab98da1f4 # v1.2.10 61 | 62 | - name: Install dependencies (LLVM & MLIR) 63 | run: | 64 | export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" 65 | cmake -B llvm-project/build -DLLVM_CCACHE_BUILD=ON \ 66 | -DLLVM_BUILD_LLVM_DYLIB=ON -DMLIR_BUILD_MLIR_C_DYLIB=ON -DCMAKE_BUILD_TYPE=Release \ 67 | -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_TARGETS_TO_BUILD="host" \ 68 | -DCMAKE_INSTALL_PREFIX=$HOME/mlir_shared llvm-project/llvm 69 | cmake --build llvm-project/build -t install 70 | echo "$HOME/mlir_shared/bin" >> $GITHUB_PATH 71 | env: 72 | CC: clang 73 | CXX: clang++ 74 | CMAKE_GENERATOR: Ninja 75 | 76 | - name: Install dependencies (Haskell) 77 | run: | 78 | stack build --only-dependencies --test --no-run-tests 79 | 80 | - name: Build mlir-hs 81 | run: | 82 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/mlir_shared/lib 83 | stack build --ghc-options "-Wall -Werror -fforce-recomp" --test --no-run-tests 84 | 85 | - name: Run mlir-hs tests 86 | run: | 87 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/mlir_shared/lib 88 | stack test --ghc-options "-Wall -Werror -fforce-recomp" 89 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish Haddock docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout the repository 15 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 16 | 17 | - name: Setup Haskell Stack 18 | uses: actions/setup-haskell@v1 19 | with: 20 | enable-stack: true 21 | stack-no-global: true 22 | stack-version: 'latest' 23 | 24 | - name: Cache 25 | uses: actions/cache@v2 26 | with: 27 | path: ~/.stack 28 | key: ${{ runner.os 
}}-docs-${{ hashFiles('**/*.cabal', 'stack*.yaml') }} 29 | restore-keys: ${{ runner.os }}-docs- 30 | 31 | - name: Install zstd 32 | run: sudo apt install libzstd-dev 33 | 34 | - name: Install Ninja 35 | uses: llvm/actions/install-ninja@55d844821959226fab4911f96f37071c1d4c3268 36 | 37 | - name: Clone LLVM repo 38 | uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 39 | with: 40 | repository: llvm/llvm-project 41 | ref: 'main' 42 | path: 'llvm-project' 43 | 44 | - name: Ccache for C++ compilation 45 | uses: hendrikmuhs/ccache-action@4687d037e4d7cf725512d9b819137a3af34d39b3 46 | 47 | - name: Install dependencies (LLVM & MLIR) 48 | run: | 49 | export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" 50 | cmake -B llvm-project/build -DLLVM_CCACHE_BUILD=ON \ 51 | -DLLVM_BUILD_LLVM_DYLIB=ON -DMLIR_BUILD_MLIR_C_DYLIB=ON -DCMAKE_BUILD_TYPE=Release \ 52 | -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_TARGETS_TO_BUILD="host" \ 53 | -DCMAKE_INSTALL_PREFIX=$HOME/mlir_shared llvm-project/llvm 54 | cmake --build llvm-project/build -t install 55 | echo "$HOME/mlir_shared/bin" >> $GITHUB_PATH 56 | env: 57 | CC: clang 58 | CXX: clang++ 59 | CMAKE_GENERATOR: Ninja 60 | 61 | - name: Build mlir-hs documentation 62 | run: | 63 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/mlir_shared/lib 64 | stack haddock --no-haddock-deps --force-dirty 65 | cp -r `stack path --local-install-root`/doc haddock 66 | 67 | - name: Deploy to GitHub Pages 68 | uses: "JamesIves/github-pages-deploy-action@3dbacc7e69578703f91f077118b3475862cb09b8" # 4.1.0 69 | with: 70 | token: ${{ secrets.GITHUB_TOKEN }} 71 | branch: gh-pages # The branch the action should deploy to. 72 | folder: haddock # The folder the action should deploy. 73 | clean: false # If true, automatically remove deleted files from the deploy branch. 
74 | commit-message: Updating gh-pages from ${{ github.sha }} 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Stack specific files 2 | .stack-work 3 | stack.yaml.lock 4 | # A directory for the tblgen executables 5 | .bin 6 | # Generated files 7 | src/MLIR/AST/Dialect/Generated 8 | test/MLIR/AST/Dialect/Generated 9 | # LLVM src checkout dir 10 | llvm-project 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mlir-hs - Haskell bindings for MLIR 2 | 3 | **🚨 This is an early-stage project. All details are subject to arbitrary changes. 🚨** 4 | 5 | Note that the `main` branch tracks the current HEAD of [LLVM](https://github.com/llvm/llvm-project) 6 | and so it is likely to be incompatible with any past releases. We are planning to 7 | provide release-specific branches in the future, but only once the API stabilizes. 8 | For now your best bet is to develop against MLIR built from source. See the 9 | [Building MLIR from source](#building-mlir-from-source) section for guidance. 10 | 11 | ## Building 12 | 13 | The only prerequisite for building mlir-hs is that you have MLIR installed 14 | somewhere, and the `llvm-config` binary from that installation is available 15 | in your `PATH` (a good way to verify this is to run `which llvm-config`). 16 | 17 | Once that is in place, we recommend using [Stack](https://haskellstack.org) 18 | for development.
To build the project, run `stack build`; the test 19 | suite can be executed using `stack test`. 20 | 21 | ### Building MLIR from source 22 | 23 | The instructions below assume that you have `cmake` and `ninja` installed. 24 | You should be able to get them from your favorite package manager. 25 | 26 | 1. Clone the latest LLVM code (or use `git pull` if you cloned it before) into the root of this repository: 27 | ```bash 28 | git clone https://github.com/llvm/llvm-project 29 | ``` 30 | 31 | 2. Create a temporary build directory: 32 | ```bash 33 | mkdir llvm-project/build 34 | ``` 35 | 36 | 3. Configure the build using CMake. Remember to replace `$PREFIX` with the directory 37 | where you want MLIR to be installed. The command below uses the Ninja build system (`-G Ninja`), enables the MLIR project (`-DLLVM_ENABLE_PROJECTS=mlir`), sets the install prefix (`-DCMAKE_INSTALL_PREFIX`), and builds the shared libraries (`-DMLIR_BUILD_MLIR_C_DYLIB=ON`, `-DLLVM_BUILD_LLVM_DYLIB=ON`). See [LLVM documentation](https://llvm.org/docs/CMake.html) 38 | for an extended explanation and other potentially interesting build flags. 39 | ```bash 40 | cmake -B llvm-project/build \ 41 | -G Ninja \ 42 | -DLLVM_ENABLE_PROJECTS=mlir \ 43 | -DCMAKE_INSTALL_PREFIX=$PREFIX \ 44 | -DMLIR_BUILD_MLIR_C_DYLIB=ON \ 45 | -DLLVM_BUILD_LLVM_DYLIB=ON \ 46 | llvm-project/llvm 47 | ``` 48 | For development purposes we additionally recommend using 49 | `-DCMAKE_BUILD_TYPE=RelWithDebInfo -DLLVM_ENABLE_ASSERTIONS=ON` 50 | to retain debug information and enable internal LLVM assertions. If you install MLIR 51 | to a custom prefix (`CMAKE_INSTALL_PREFIX`), add `$PREFIX/bin` to `PATH` and `$PREFIX/lib` 52 | to `LD_LIBRARY_PATH` so that subsequent builds (e.g., `stack`) can find it. 53 | 54 | 4. Build and install MLIR. Note that this uses the installation prefix specified 55 | in the previous step. 56 | ```bash 57 | cmake --build llvm-project/build -t install 58 | ``` 59 | 60 | ## Contributing 61 | 62 | Contributions of all kinds are welcome! If you're planning to implement a larger feature, 63 | consider posting an issue so that we can discuss it before you put in the work.
64 |
65 | ## License
66 |
67 | See the LICENSE file.
68 |
69 | mlir-hs is an early-stage project, not an official Google product.
70 |
--------------------------------------------------------------------------------
/Setup.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2021 Google LLC
2 | --
3 | -- Licensed under the Apache License, Version 2.0 (the "License");
4 | -- you may not use this file except in compliance with the License.
5 | -- You may obtain a copy of the License at
6 | --
7 | --      http://www.apache.org/licenses/LICENSE-2.0
8 | --
9 | -- Unless required by applicable law or agreed to in writing, software
10 | -- distributed under the License is distributed on an "AS IS" BASIS,
11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | -- See the License for the specific language governing permissions and
13 | -- limitations under the License.
14 |
15 | import Data.Char
16 | import Data.List
17 | import Data.Maybe
18 |
19 | import System.Info
20 | import System.Directory
21 | import System.FilePath
22 |
23 | import Distribution.ModuleName hiding (main)
24 | import Distribution.Simple
25 | import Distribution.Simple.Setup
26 | import Distribution.Simple.Program
27 | import Distribution.Types.BuildInfo.Lens
28 | import Distribution.Types.Library.Lens
29 | import Distribution.Types.TestSuite.Lens
30 | import Distribution.Types.GenericPackageDescription
31 | import Distribution.Types.CondTree
32 |
33 | import Control.Monad
34 | import Control.Lens.Setter
35 | import Control.Lens.Operators ((&))
36 |
37 | -- | LLVM major version this package builds against; 'getLLVMConfig' only
38 | -- accepts an @llvm-config@ whose version lies within it.
39 | llvmVersion :: Version
40 | llvmVersion = mkVersion [21]
41 |
42 | -- | The @llvm-config@ program. Its version is taken from the leading run of
43 | -- digits and dots printed by @llvm-config --version@.
44 | llvmConfigProgram :: Program
45 | llvmConfigProgram = (simpleProgram "llvm-config")
46 |   { programFindVersion =
47 |       findProgramVersion "--version" (takeWhile (\c -> isDigit c || c == '.'))
48 |   }
49 |
50 | -- | Locate @llvm-config@, requiring its version to lie within 'llvmVersion',
51 | -- and return a function that runs it with the given arguments and yields
52 | -- its standard output.
53 | getLLVMConfig :: ConfigFlags -> IO ([String] -> IO String)
54 | getLLVMConfig confFlags = do
55 |   (program, _, _) <- requireProgramVersion
56 |     verbosity
57 |     llvmConfigProgram
58 |     (withinVersion llvmVersion)
59 |     (configPrograms confFlags)
60 |   return $ getProgramOutput verbosity program
61 |   where verbosity = fromFlag $ configVerbosity confFlags
62 |
63 | -- | The C++ compiler used to build @mlir-hs-tblgen@ (see 'buildTblgen').
64 | ccProgram :: Program
65 | ccProgram = simpleProgram "c++"
66 |
67 | -- | Locate the C++ compiler and return a function that runs it with the
68 | -- given arguments.
69 | getCC :: ConfigFlags -> IO ([String] -> IO ())
70 | getCC confFlags = do
71 |   (program, _) <- requireProgram verbosity ccProgram (configPrograms confFlags)
72 |   return $ runProgram verbosity program
73 |   where verbosity = fromFlag $ configVerbosity confFlags
74 |
75 | -- | Is this flag an include directory (@-I...@)?
76 | isIncludeDir :: String -> Bool
77 | isIncludeDir = ("-I" `isPrefixOf`)
78 |
79 | -- | Is this flag a library search directory (@-L...@)?
80 | isLibDir :: String -> Bool
81 | isLibDir = ("-L" `isPrefixOf`)
82 |
83 | -- | Backends of @mlir-hs-tblgen@. 'show' yields the value passed to its
84 | -- @--generator@ flag.
85 | data TblGenerator = OpGenerator | TestGenerator
86 | instance Show TblGenerator where
87 |   show OpGenerator = "hs-op-defs"
88 |   show TestGenerator = "hs-tests"
89 |
90 | -- | Strip leading and trailing whitespace (such as a trailing newline in
91 | -- @llvm-config@ output).
92 | trim :: String -> String
93 | trim = dropWhileEnd isSpace . dropWhile isSpace
94 |
95 | -- | Compile @mlir-hs-tblgen@ into @.bin/@ under the current directory and
96 | -- return a function that runs it on one TableGen file. The function takes
97 | -- the generator backend, the @.td@ path (resolved against LLVM's include
98 | -- directory), the output path (resolved against the current directory) and
99 | -- extra arguments appended to the command line.
100 | buildTblgen :: ConfigFlags -> IO (TblGenerator -> FilePath -> FilePath -> [ProgArg] -> IO ())
101 | buildTblgen confFlags = do
102 |   -- TODO(apaszke): Cache compilation.
103 |   cwd <- getCurrentDirectory
104 |   llvmConfig <- getLLVMConfig confFlags
105 |   cxxFlags <- words <$> llvmConfig ["--cxxflags"]
106 |   ldFlags <- words <$> llvmConfig ["--ldflags"]
107 |   cppFlags <- words <$> llvmConfig ["--cppflags"]
108 |   includeDir <- trim <$> llvmConfig ["--includedir"]
109 |   cc <- getCC confFlags
110 |   -- On mingw32 we link against -lLLVM-<major> (e.g. -lLLVM-21) instead of
111 |   -- -lMLIRTableGen.
112 |   let windowsLLVMVersion = "-" ++ (show $ head $ versionNumbers llvmVersion)
113 |       tblgenExe = cwd </> ".bin/mlir-hs-tblgen"
114 |   ensureDirectory $ cwd </> ".bin"
115 |   cc $ sources ++ cxxFlags ++ ldFlags ++
116 |     [ "-lMLIR"
117 |     , if os == "mingw32" then "-lLLVM" ++ windowsLLVMVersion else "-lMLIRTableGen"
118 |     , "-lLLVMTableGen"
119 |     , "-o", tblgenExe
120 |     ]
121 |   let tblgenProgram = ConfiguredProgram
122 |         { programId = "mlir-hs-tblgen"
123 |         , programVersion = Nothing
124 |         , programDefaultArgs = ("-I" <> includeDir) : cppFlags
125 |         , programOverrideArgs = []
126 |         , programOverrideEnv = []
127 |         , programProperties = mempty
128 |         , programLocation = FoundOnSystem tblgenExe
129 |         , programMonitorFiles = []
130 |         }
131 |   return $ \generator tdPath outputPath opts -> do
132 |     putStrLn $ "Generating " <> (cwd </> outputPath)
133 |     runProgram verbosity tblgenProgram $
134 |       [ "--write-if-changed"
135 |       , "--generator", show generator
136 |       , includeDir </> tdPath
137 |       , "-o", cwd </> outputPath
138 |       ] ++ opts
139 |   where
140 |     verbosity = fromFlag $ configVerbosity confFlags
141 |     sources =
142 |       [ "tblgen/mlir-hs-tblgen.cc"
143 |       , "tblgen/hs-generators.cc"
144 |       ]
145 |
146 | -- | Create a directory along with any missing parents; a no-op if it
147 | -- already exists.
148 | ensureDirectory :: FilePath -> IO ()
149 | ensureDirectory = createDirectoryIfMissing True
150 |
151 | -- | Custom Setup entry point. The configure hook builds @mlir-hs-tblgen@,
152 | -- generates a Haskell module for every dialect listed below (and a test
153 | -- spec for each of them except LinalgStructured), then adds the generated
154 | -- modules and the @llvm-config@ compiler/linker flags to the package
155 | -- description before running the default configure step.
156 | main :: IO ()
157 | main = defaultMainWithHooks simpleUserHooks
158 |   { hookedPrograms = [ llvmConfigProgram, ccProgram ]
159 |   , confHook = \(genericPackageDesc, hookedBuildInfo) confFlags -> do
160 |       tblgen <- buildTblgen confFlags
161 |       -- (generated module name suffix, .td path under LLVM's include dir, extra tblgen args)
162 |       let dialects =
163 |             [ ("Func"            , "mlir/Dialect/Func/IR/FuncOps.td", [])
164 |             , ("Arith"           , "mlir/Dialect/Arith/IR/ArithOps.td", ["-strip-prefix", "Arith_"])
165 |             , ("ControlFlow"     , "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td", ["-dialect-name", "ControlFlow"])
166 |             , ("Vector"          , "mlir/Dialect/Vector/IR/VectorOps.td", ["-strip-prefix", "Vector_"])
167 |             , ("Shape"           , "mlir/Dialect/Shape/IR/ShapeOps.td", ["-strip-prefix", "Shape_"])
168 |             , ("LLVM"            , "mlir/Dialect/LLVMIR/LLVMOps.td", ["-strip-prefix", "LLVM_", "-dialect-name", "LLVM"])
169 |             , ("Linalg"          , "mlir/Dialect/Linalg/IR/LinalgOps.td", [])
170 |             , ("LinalgStructured", "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td", ["-dialect-name", "LinalgStructured"])
171 |             , ("Tensor"          , "mlir/Dialect/Tensor/IR/TensorOps.td", ["-strip-prefix", "Tensor_"])
172 |             , ("X86Vector"       , "mlir/Dialect/X86Vector/X86Vector.td", ["-dialect-name", "X86Vector"])
173 |             ]
174 |       ensureDirectory "src/MLIR/AST/Dialect/Generated"
175 |       generatedModules <- forM dialects $ \(dialect, tdPath, opts) -> do
176 |         tblgen OpGenerator tdPath ("src/MLIR/AST/Dialect/Generated/" <> dialect <> ".hs") opts
177 |         return $ fromString $ "MLIR.AST.Dialect.Generated." <> dialect
178 |
179 |       -- TODO: Do I need to do anything about the rpath?
180 |       llvmConfig <- getLLVMConfig confFlags
181 |       (llvmLibDirFlags , llvmLdFlags) <- partition isLibDir . words <$> llvmConfig ["--ldflags"]
182 |       (llvmIncludeFlags, llvmCcFlags) <- partition isIncludeDir . words <$> llvmConfig ["--cflags"]
183 |       -- 'partition' guarantees the prefix, so 'mapMaybe' drops nothing here;
184 |       -- it just avoids the partial 'fromJust'.
185 |       let llvmIncludeDirs = mapMaybe (stripPrefix "-I") llvmIncludeFlags
186 |       let llvmLibDirs = mapMaybe (stripPrefix "-L") llvmLibDirFlags
187 |       condLib <- maybe (fail "mlir-hs Setup.hs: expected the package to define a library") return
188 |                    (condLibrary genericPackageDesc)
189 |       let newLibrary = condTreeData condLib
190 |             & over (libBuildInfo . buildInfo . ccOptions     ) (<> llvmCcFlags     )
191 |             & over (libBuildInfo . buildInfo . includeDirs   ) (<> llvmIncludeDirs )
192 |             & over (libBuildInfo . buildInfo . ldOptions     ) (<> llvmLdFlags     )
193 |             & over (libBuildInfo . buildInfo . extraLibDirs  ) (<> llvmLibDirs     )
194 |             & over (libBuildInfo . buildInfo . otherModules  ) (<> generatedModules)
195 |             & over (libBuildInfo . buildInfo . autogenModules) (<> generatedModules)
196 |       let newCondLibrary = condLib { condTreeData = newLibrary }
197 |
198 |       ensureDirectory "test/MLIR/AST/Dialect/Generated"
199 |       -- LinalgStructured gets bindings but no generated test spec.
200 |       generatedSpecModules <- liftM catMaybes $ forM dialects $ \(dialect, tdPath, opts) -> do
201 |         case dialect of
202 |           "LinalgStructured" -> return Nothing
203 |           _ -> do
204 |             tblgen TestGenerator tdPath ("test/MLIR/AST/Dialect/Generated/" <> dialect <> "Spec.hs") opts
205 |             return $ Just $ fromString $ "MLIR.AST.Dialect.Generated." <> dialect <> "Spec"
206 |       (testSuiteName, condTestSuite) <- case condTestSuites genericPackageDesc of
207 |         [suite] -> return suite
208 |         _       -> fail "mlir-hs Setup.hs: expected exactly one test suite"
209 |       let newTestSuite = condTreeData condTestSuite
210 |             & over (testBuildInfo . otherModules) (<> generatedSpecModules)
211 |
212 |       let newGenericPackageDesc = genericPackageDesc
213 |             { condLibrary = Just newCondLibrary
214 |             , condTestSuites = [(testSuiteName, condTestSuite { condTreeData = newTestSuite })]
215 |             }
216 |       confHook simpleUserHooks (newGenericPackageDesc, hookedBuildInfo) confFlags
217 |   }
218 |
219 |
220 |
--------------------------------------------------------------------------------
/docs/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, gender identity and expression, level of
9 | experience, education, socio-economic status, nationality, personal appearance,
10 | race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or reject 41 | comments, commits, code, wiki edits, issues, and other contributions that are 42 | not aligned to this Code of Conduct, or to ban temporarily or permanently any 43 | contributor for other behaviors that they deem inappropriate, threatening, 44 | offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when the Project 56 | Steward has a reasonable belief that an individual's behavior may have a 57 | negative impact on the project or its community. 58 | 59 | ## Conflict Resolution 60 | 61 | We do not believe that all conflict is bad; healthy debate and disagreement 62 | often yield positive results. However, it is never okay to be disrespectful or 63 | to engage in behavior that violates the project’s code of conduct. 64 | 65 | If you see someone violating the code of conduct, you are encouraged to address 66 | the behavior directly with those involved. Many issues can be resolved quickly 67 | and easily, and this gives people more control over the outcome of their 68 | dispute. If you are unable to resolve the matter for any reason, or if the 69 | behavior is threatening or harassing, report it. We are dedicated to providing 70 | an environment where participants feel welcome and safe. 71 | 72 | Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the 73 | Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to 74 | receive and address reported violations of the code of conduct. They will then 75 | work with a committee consisting of representatives from the Open Source 76 | Programs Office and the Google Open Source Strategy team. If for any reason you 77 | are uncomfortable reaching out to the Project Steward, please email 78 | opensource@google.com. 79 | 80 | We will investigate every complaint, but you may not receive a direct response. 81 | We will use our discretion in determining when and how to follow up on reported 82 | incidents, which may range from not taking action to permanent expulsion from 83 | the project and project-sponsored spaces. 
We will notify the accused of the 84 | report and provide them an opportunity to discuss it before any action is taken. 85 | The identity of the reporter will be omitted from the details of the report 86 | supplied to the accused. In potentially harmful situations, such as ongoing 87 | harassment or threats to anyone's safety, we may take action without notice. 88 | 89 | ## Attribution 90 | 91 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4, 92 | available at 93 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 94 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code Reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 
29 | -------------------------------------------------------------------------------- /mlir-hs.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 3.0 2 | 3 | name: mlir-hs 4 | version: 0.1.0.0 5 | description: Haskell bindings to MLIR 6 | homepage: https://github.com/google/mlir-hs#readme 7 | bug-reports: https://github.com/google/mlir-hs/issues 8 | author: Adam Paszke 9 | maintainer: apaszke@google.com 10 | copyright: 2021 Google 11 | license: Apache-2.0 12 | license-file: LICENSE 13 | build-type: Custom 14 | extra-source-files: 15 | README.md 16 | 17 | source-repository head 18 | type: git 19 | location: https://github.com/google/mlir-hs 20 | 21 | common defaults 22 | default-language: Haskell2010 23 | default-extensions: 24 | BlockArguments 25 | , DeriveGeneric 26 | , DerivingVia 27 | , FlexibleContexts 28 | , FlexibleInstances 29 | , FunctionalDependencies 30 | , GADTs 31 | , GeneralizedNewtypeDeriving 32 | , LambdaCase 33 | , OverloadedStrings 34 | , PatternSynonyms 35 | , QuasiQuotes 36 | , RecordWildCards 37 | , RecursiveDo 38 | , ScopedTypeVariables 39 | , StandaloneDeriving 40 | , TemplateHaskell 41 | , TupleSections 42 | , TypeApplications 43 | , TypeSynonymInstances 44 | , ViewPatterns 45 | 46 | custom-setup 47 | setup-depends: 48 | base 49 | , Cabal 50 | , lens 51 | , directory 52 | , filepath 53 | 54 | library 55 | import: defaults 56 | hs-source-dirs: src 57 | exposed-modules: 58 | MLIR.AST 59 | , MLIR.AST.Builder 60 | , MLIR.AST.Dialect.Affine 61 | , MLIR.AST.Dialect.Arith 62 | , MLIR.AST.Dialect.ControlFlow 63 | , MLIR.AST.Dialect.Func 64 | , MLIR.AST.Dialect.LLVM 65 | , MLIR.AST.Dialect.Linalg 66 | , MLIR.AST.Dialect.MemRef 67 | , MLIR.AST.Dialect.Shape 68 | , MLIR.AST.Dialect.Tensor 69 | , MLIR.AST.Dialect.Vector 70 | , MLIR.AST.Dialect.X86Vector 71 | , MLIR.AST.IStorableArray 72 | , MLIR.AST.PatternUtil 73 | , MLIR.AST.Rewrite 74 | , MLIR.AST.Serialize 75 | , MLIR.Native 76 | , 
MLIR.Native.ExecutionEngine 77 | , MLIR.Native.FFI 78 | , MLIR.Native.Pass 79 | build-depends: 80 | base >=4.7 && <5 81 | , inline-c 82 | , mtl 83 | , raw-strings-qq 84 | , array 85 | , containers 86 | , bytestring 87 | , transformers 88 | extra-libraries: 89 | MLIR-C 90 | 91 | test-suite spec 92 | import: defaults 93 | type: exitcode-stdio-1.0 94 | main-is: Spec.hs 95 | ghc-options: -Wall 96 | hs-source-dirs: test 97 | build-depends: 98 | base >=4.7 && <5 99 | , mlir-hs 100 | , array 101 | , hspec 102 | , transformers 103 | , bytestring 104 | , raw-strings-qq 105 | , vector 106 | , mtl 107 | , QuickCheck 108 | , generic-random 109 | , containers 110 | other-modules: 111 | MLIR.ASTSpec 112 | , MLIR.BuilderSpec 113 | , MLIR.NativeSpec 114 | , MLIR.RewriteSpec 115 | , MLIR.Test.Generators 116 | build-tool-depends: hspec-discover:hspec-discover 117 | -------------------------------------------------------------------------------- /src/MLIR/AST.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2021 Google LLC 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 
14 | 15 | module MLIR.AST where 16 | 17 | import qualified Data.ByteString as BS 18 | 19 | import Data.Typeable 20 | import Data.Int 21 | import Data.Word 22 | import Data.Coerce 23 | import Data.Ix 24 | import Data.Array.IArray 25 | import Foreign.Ptr 26 | import Foreign.Marshal.Alloc 27 | import Foreign.Marshal.Array 28 | import qualified Language.C.Inline as C 29 | import qualified Data.ByteString.Unsafe as BS 30 | import Control.Monad 31 | import Control.Monad.IO.Class 32 | import Control.Monad.Trans.Cont 33 | import qualified Data.Map.Strict as M 34 | 35 | import qualified MLIR.Native as Native 36 | import qualified MLIR.Native.FFI as Native 37 | import qualified MLIR.AST.Dialect.Affine as Affine 38 | import MLIR.AST.Serialize 39 | import MLIR.AST.IStorableArray 40 | 41 | type Name = BS.ByteString 42 | type UInt = Word 43 | 44 | data Signedness = Signed | Unsigned | Signless 45 | deriving Eq 46 | data Type = 47 | -- Builtin types 48 | -- See 49 | BFloat16Type 50 | | Float16Type 51 | | Float32Type 52 | | Float64Type 53 | | Float80Type 54 | | Float128Type 55 | | ComplexType Type 56 | | IndexType 57 | | IntegerType Signedness UInt 58 | | TupleType [Type] 59 | | NoneType 60 | | FunctionType [Type] [Type] 61 | | MemRefType { memrefTypeShape :: [Maybe Int] 62 | , memrefTypeElement :: Type 63 | , memrefTypeLayout :: Maybe Attribute 64 | , memrefTypeMemorySpace :: Maybe Attribute } 65 | | RankedTensorType { rankedTensorTypeShape :: [Maybe Int] 66 | , rankedTensorTypeElement :: Type 67 | , rankedTensorTypeEncoding :: Maybe Attribute } 68 | | VectorType { vectorTypeShape :: [Int] 69 | , vectorTypeElement :: Type } 70 | | UnrankedMemRefType { unrankedMemrefTypeElement :: Type 71 | , unrankedMemrefTypeMemorySpace :: Attribute } 72 | | UnrankedTensorType { unrankedTensorTypeElement :: Type } 73 | | OpaqueType { opaqueTypeNamespace :: Name 74 | , opaqueTypeData :: BS.ByteString } 75 | | forall t. 
(Typeable t, Eq t, FromAST t Native.Type) => DialectType t 76 | -- GHC cannot derive Eq due to the existential case, so we implement Eq below 77 | -- deriving Eq 78 | 79 | instance Eq Type where 80 | a == b = case (a, b) of 81 | (BFloat16Type , BFloat16Type ) -> True 82 | (Float16Type , Float16Type ) -> True 83 | (Float32Type , Float32Type ) -> True 84 | (Float64Type , Float64Type ) -> True 85 | (Float80Type , Float80Type ) -> True 86 | (Float128Type , Float128Type ) -> True 87 | (ComplexType a1 , ComplexType b1 ) -> a1 == b1 88 | (IndexType , IndexType ) -> True 89 | (IntegerType a1 a2 , IntegerType b1 b2 ) -> (a1, a2) == (b1, b2) 90 | (TupleType a1 , TupleType b1 ) -> a1 == b1 91 | (NoneType , NoneType ) -> True 92 | (FunctionType a1 a2, FunctionType b1 b2) -> (a1, a2) == (b1, b2) 93 | (MemRefType a1 a2 a3 a4 , MemRefType b1 b2 b3 b4 ) -> (a1, a2, a3, a4) == (b1, b2, b3, b4) 94 | (RankedTensorType a1 a2 a3, RankedTensorType b1 b2 b3) -> (a1, a2, a3 ) == (b1, b2, b3 ) 95 | (VectorType a1 a2 , VectorType b1 b2 ) -> (a1, a2) == (b1, b2) 96 | (UnrankedMemRefType a1 a2, UnrankedMemRefType b1 b2 ) -> (a1, a2) == (b1, b2) 97 | (UnrankedTensorType a1 , UnrankedTensorType b1 ) -> (a1 ) == (b1 ) 98 | (OpaqueType a1 a2 , OpaqueType b1 b2 ) -> (a1, a2) == (b1, b2) 99 | (DialectType a1 , DialectType b1 ) -> case cast a1 of 100 | Just a1' -> a1' == b1 101 | Nothing -> False 102 | _ -> False 103 | 104 | data Location = 105 | UnknownLocation 106 | | FileLocation { locPath :: BS.ByteString, locLine :: UInt, locColumn :: UInt } 107 | | NameLocation { locName :: BS.ByteString, locChild :: Location } 108 | | FusedLocation { locLocations :: [Location], locMetadata :: Maybe Attribute } 109 | -- TODO(jpienaar): Add support C API side and implement these 110 | | CallSiteLocation 111 | | OpaqueLocation 112 | 113 | data Binding = Bind [Name] Operation 114 | 115 | pattern Do :: Operation -> Binding 116 | pattern Do op = Bind [] op 117 | 118 | pattern (:=) :: Name -> Operation -> Binding 
119 | pattern (:=) name op = Bind [name] op 120 | 121 | pattern (::=) :: [Name] -> Operation -> Binding 122 | pattern (::=) names op = Bind names op 123 | 124 | data Block = Block { 125 | blockName :: Name 126 | , blockArgs :: [(Name, Type)] 127 | , blockBody :: [Binding] 128 | } 129 | 130 | data Region = Region [Block] 131 | 132 | data Attribute = 133 | ArrayAttr [Attribute] 134 | | DictionaryAttr (M.Map Name Attribute) 135 | | FloatAttr Type Double 136 | | IntegerAttr Type Int 137 | | BoolAttr Bool 138 | | StringAttr BS.ByteString 139 | | TypeAttr Type 140 | | AffineMapAttr Affine.Map 141 | | UnitAttr 142 | | DenseArrayAttr DenseElements 143 | | DenseElementsAttr Type DenseElements 144 | -- Represents Attribute textually represented. 145 | | AsmTextAttr BS.ByteString 146 | | forall t. (Typeable t, Eq t, Show t, FromAST t Native.Attribute) => DialectAttr t 147 | -- GHC cannot derive Eq due to the existential case, so we implement Eq below 148 | -- deriving Eq 149 | -- TODO(apaszke): (Flat) SymbolRef, IntegerSet, Opaque 150 | 151 | instance Eq Attribute where 152 | a == b = case (a, b) of 153 | (ArrayAttr a1, ArrayAttr b1) -> a1 == b1 154 | (DictionaryAttr a1, DictionaryAttr b1) -> a1 == b1 155 | (FloatAttr a1 a2, FloatAttr b1 b2) -> (a1, a2) == (b1, b2) 156 | (IntegerAttr a1 a2, IntegerAttr b1 b2) -> (a1, a2) == (b1, b2) 157 | (BoolAttr a1, BoolAttr b1) -> a1 == b1 158 | (StringAttr a1, StringAttr b1) -> a1 == b1 159 | (TypeAttr a1, TypeAttr b1) -> a1 == b1 160 | (AffineMapAttr a1, AffineMapAttr b1) -> a1 == b1 161 | (UnitAttr, UnitAttr) -> True 162 | (DenseArrayAttr a1, DenseArrayAttr b1) -> a1 == b1 163 | (DenseElementsAttr a1 a2, DenseElementsAttr b1 b2) -> (a1, a2) == (b1, b2) 164 | (AsmTextAttr a1, AsmTextAttr b1) -> a1 == b1 165 | (DialectAttr a1, DialectAttr b1) -> case cast a1 of 166 | Just a1' -> a1' == b1 167 | Nothing -> False 168 | _ -> False 169 | 170 | data DenseElements 171 | = forall i. 
(Show i, Ix i) => DenseUInt8 (IStorableArray i Word8 ) 172 | | forall i. (Show i, Ix i) => DenseInt8 (IStorableArray i Int8 ) 173 | | forall i. (Show i, Ix i) => DenseUInt32 (IStorableArray i Word32) 174 | | forall i. (Show i, Ix i) => DenseInt32 (IStorableArray i Int32 ) 175 | | forall i. (Show i, Ix i) => DenseUInt64 (IStorableArray i Word64) 176 | | forall i. (Show i, Ix i) => DenseInt64 (IStorableArray i Int64 ) 177 | | forall i. (Show i, Ix i) => DenseFloat (IStorableArray i Float ) 178 | | forall i. (Show i, Ix i) => DenseDouble (IStorableArray i Double) 179 | 180 | -- Note that we use a relaxed notion of equality, where the indices don't matter! 181 | -- TODO: Use a faster comparison? We could really just use memcmp here. 182 | instance Eq DenseElements where 183 | a == b = case (a, b) of 184 | (DenseUInt8 da, DenseUInt8 db) -> elems da == elems db 185 | (DenseInt8 da, DenseInt8 db) -> elems da == elems db 186 | (DenseUInt32 da, DenseUInt32 db) -> elems da == elems db 187 | (DenseInt32 da, DenseInt32 db) -> elems da == elems db 188 | (DenseUInt64 da, DenseUInt64 db) -> elems da == elems db 189 | (DenseInt64 da, DenseInt64 db) -> elems da == elems db 190 | (DenseFloat da, DenseFloat db) -> elems da == elems db 191 | (DenseDouble da, DenseDouble db) -> elems da == elems db 192 | _ -> False 193 | 194 | data ResultTypes = Explicit [Type] | Inferred 195 | 196 | type NamedAttributes = M.Map Name Attribute 197 | 198 | data AbstractOperation operand = Operation { 199 | opName :: Name, 200 | opLocation :: Location, 201 | opResultTypes :: ResultTypes, 202 | opOperands :: [operand], 203 | opRegions :: [Region], 204 | opSuccessors :: [Name], 205 | opAttributes :: M.Map Name Attribute 206 | } 207 | type Operation = AbstractOperation Name 208 | 209 | -------------------------------------------------------------------------------- 210 | -- Builtin operations 211 | 212 | pattern NoAttrs :: M.Map Name Attribute 213 | pattern NoAttrs <- _ -- Accept any attributes 214 | where 
NoAttrs = M.empty 215 | -- | A single-entry attribute dictionary. 216 | namedAttribute :: Name -> Attribute -> NamedAttributes 217 | namedAttribute name value = M.singleton name value 218 | -- | A builtin.module operation whose only region holds the given block; attributes are ignored when matching and empty when building. 219 | pattern ModuleOp :: Block -> Operation 220 | pattern ModuleOp body = Operation 221 | { opName = "builtin.module" 222 | , opLocation = UnknownLocation 223 | , opResultTypes = Explicit [] 224 | , opOperands = [] 225 | , opRegions = [Region [body]] 226 | , opSuccessors = [] 227 | , opAttributes = NoAttrs 228 | } 229 | -- | Attributes of a func.func: its symbol name (sym_name) and function type (function_type). Matching reads the same keys that construction writes, so built maps round-trip; other keys are ignored when matching. 230 | pattern FuncAttrs :: Name -> Type -> M.Map Name Attribute 231 | pattern FuncAttrs name ty <- 232 | ((\d -> (M.lookup "sym_name" d, M.lookup "function_type" d)) -> 233 | (Just (StringAttr name), Just (TypeAttr ty))) 234 | where FuncAttrs name ty = M.fromList [("sym_name", StringAttr name), 235 | ("function_type", TypeAttr ty)] 236 | -- | A func.func operation with the given location, symbol name, function type and body region. 237 | pattern FuncOp :: Location -> Name -> Type -> Region -> Operation 238 | pattern FuncOp loc name ty body = Operation 239 | { opName = "func.func" 240 | , opLocation = loc 241 | , opResultTypes = Explicit [] 242 | , opOperands = [] 243 | , opRegions = [body] 244 | , opSuccessors = [] 245 | , opAttributes = FuncAttrs name ty 246 | } 247 | 248 | -------------------------------------------------------------------------------- 249 | -- AST -> Native translation 250 | 251 | C.context $ C.baseCtx <> Native.mlirCtx 252 | -- stdint.h supplies the fixed-width integer types (int64_t, intptr_t, ...) used in the C snippets of this module. 253 | C.include "<stdint.h>" 254 | C.include "mlir-c/IR.h" 255 | C.include "mlir-c/BuiltinTypes.h" 256 | C.include "mlir-c/BuiltinAttributes.h" 257 | -- | Translate AST locations to native MLIR locations; CallSiteLocation and OpaqueLocation are not supported yet. 258 | instance FromAST Location Native.Location where 259 | fromAST ctx env loc = case loc of 260 | UnknownLocation -> Native.getUnknownLocation ctx 261 | FileLocation file line col -> do 262 | Native.withStringRef file \fileStrRef -> 263 | Native.getFileLineColLocation ctx fileStrRef cline ccol 264 | where cline = fromIntegral line 265 | ccol = fromIntegral col 266 | FusedLocation locLocations locMetadata -> do 267 | metadata <- case locMetadata of 268 | -- TODO: Consider factoring out to convenience function.
269 | Nothing -> [C.exp| MlirAttribute { mlirAttributeGetNull() } |] 270 | Just l -> fromAST ctx env l 271 | evalContT $ do 272 | (numLocs, locs) <- packFromAST ctx env locLocations 273 | liftIO $ [C.exp| MlirLocation { 274 | mlirLocationFusedGet($(MlirContext ctx), 275 | $(intptr_t numLocs), $(MlirLocation* locs), 276 | $(MlirAttribute metadata)) 277 | } |] 278 | NameLocation name childLoc -> do 279 | Native.withStringRef name \nameStrRef -> do 280 | nativeChildLoc <- fromAST ctx env childLoc 281 | Native.getNameLocation ctx nameStrRef nativeChildLoc 282 | -- TODO(jpienaar): Fix 283 | _ -> error "Unimplemented Location case" 284 | 285 | instance FromAST Type Native.Type where 286 | fromAST ctx env ty = case ty of 287 | BFloat16Type -> [C.exp| MlirType { mlirBF16TypeGet($(MlirContext ctx)) } |] 288 | Float16Type -> [C.exp| MlirType { mlirF16TypeGet($(MlirContext ctx)) } |] 289 | Float32Type -> [C.exp| MlirType { mlirF32TypeGet($(MlirContext ctx)) } |] 290 | Float64Type -> [C.exp| MlirType { mlirF64TypeGet($(MlirContext ctx)) } |] 291 | Float80Type -> error "Float80Type missing in the MLIR C API!" 292 | Float128Type -> error "Float128Type missing in the MLIR C API!" 
293 | ComplexType e -> do 294 | ne <- fromAST ctx env e 295 | [C.exp| MlirType { mlirComplexTypeGet($(MlirType ne)) } |] 296 | FunctionType args rets -> evalContT $ do 297 | (numArgs, nativeArgs) <- packFromAST ctx env args 298 | (numRets, nativeRets) <- packFromAST ctx env rets 299 | liftIO $ [C.exp| MlirType { 300 | mlirFunctionTypeGet($(MlirContext ctx), 301 | $(intptr_t numArgs), $(MlirType* nativeArgs), 302 | $(intptr_t numRets), $(MlirType* nativeRets)) 303 | } |] 304 | IndexType -> [C.exp| MlirType { mlirIndexTypeGet($(MlirContext ctx)) } |] 305 | IntegerType signedness width -> case signedness of 306 | Signless -> [C.exp| MlirType { 307 | mlirIntegerTypeGet($(MlirContext ctx), $(unsigned int cwidth)) 308 | } |] 309 | Signed -> [C.exp| MlirType { 310 | mlirIntegerTypeSignedGet($(MlirContext ctx), $(unsigned int cwidth)) 311 | } |] 312 | Unsigned -> [C.exp| MlirType { 313 | mlirIntegerTypeUnsignedGet($(MlirContext ctx), $(unsigned int cwidth)) 314 | } |] 315 | where cwidth = fromIntegral width 316 | MemRefType shape elTy layout memSpace -> evalContT $ do 317 | (rank, nativeShape) <- packArray shapeI64 318 | liftIO $ do 319 | nativeElTy <- fromAST ctx env elTy 320 | nativeSpace <- case memSpace of 321 | Just space -> fromAST ctx env space 322 | Nothing -> return $ coerce nullPtr 323 | nativeLayout <- case layout of 324 | Just alayout -> fromAST ctx env alayout 325 | Nothing -> return $ coerce nullPtr 326 | [C.exp| MlirType { 327 | mlirMemRefTypeGet($(MlirType nativeElTy), 328 | $(intptr_t rank), $(int64_t* nativeShape), 329 | $(MlirAttribute nativeLayout), $(MlirAttribute nativeSpace)) 330 | } |] 331 | where shapeI64 = fmap (maybe (-1) fromIntegral) shape :: [Int64] 332 | NoneType -> [C.exp| MlirType { mlirNoneTypeGet($(MlirContext ctx)) } |] 333 | OpaqueType _ _ -> notImplemented 334 | RankedTensorType shape elTy encoding -> evalContT $ do 335 | (rank, nativeShape) <- packArray shapeI64 336 | liftIO $ do 337 | nativeElTy <- fromAST ctx env elTy 338 | 
nativeEncoding <- case encoding of 339 | Just enc -> fromAST ctx env enc 340 | Nothing -> return $ coerce nullPtr 341 | [C.exp| MlirType { 342 | mlirRankedTensorTypeGet($(intptr_t rank), $(int64_t* nativeShape), 343 | $(MlirType nativeElTy), $(MlirAttribute nativeEncoding)) 344 | } |] 345 | where shapeI64 = fmap (maybe (-1) fromIntegral) shape :: [Int64] 346 | TupleType tys -> evalContT $ do 347 | (numTypes, nativeTypes) <- packFromAST ctx env tys 348 | liftIO $ [C.exp| MlirType { 349 | mlirTupleTypeGet($(MlirContext ctx), $(intptr_t numTypes), $(MlirType* nativeTypes)) 350 | } |] 351 | UnrankedMemRefType elTy attr -> do 352 | nativeElTy <- fromAST ctx env elTy 353 | nativeAttr <- fromAST ctx env attr 354 | [C.exp| MlirType { 355 | mlirUnrankedMemRefTypeGet($(MlirType nativeElTy), $(MlirAttribute nativeAttr)) 356 | } |] 357 | UnrankedTensorType elTy -> do 358 | nativeElTy <- fromAST ctx env elTy 359 | [C.exp| MlirType { 360 | mlirUnrankedTensorTypeGet($(MlirType nativeElTy)) 361 | } |] 362 | VectorType shape elTy -> evalContT $ do 363 | (rank, nativeShape) <- packArray shapeI64 364 | liftIO $ do 365 | nativeElTy <- fromAST ctx env elTy 366 | [C.exp| MlirType { 367 | mlirVectorTypeGet($(intptr_t rank), $(int64_t* nativeShape), $(MlirType nativeElTy)) 368 | } |] 369 | where shapeI64 = fmap fromIntegral shape :: [Int64] 370 | DialectType t -> fromAST ctx env t 371 | 372 | 373 | instance FromAST Region Native.Region where 374 | fromAST ctx env@(valueEnv, _) (Region blocks) = do 375 | region <- [C.exp| MlirRegion { mlirRegionCreate() } |] 376 | blockEnv <- foldM (initAppendBlock region) mempty blocks 377 | mapM_ (fromAST ctx (valueEnv, blockEnv)) blocks 378 | return region 379 | where 380 | initAppendBlock :: Native.Region -> BlockMapping -> Block -> IO BlockMapping 381 | initAppendBlock region blockEnv block = do 382 | nativeBlock <- initBlock block 383 | [C.exp| void { 384 | mlirRegionAppendOwnedBlock($(MlirRegion region), $(MlirBlock nativeBlock)) 385 | } |] 386 | 
return $ blockEnv <> (M.singleton (blockName block) nativeBlock) 387 | -- Create a detached native block with the AST block's argument types, all at unknown locations. 388 | initBlock :: Block -> IO Native.Block 389 | initBlock Block{..} = do 390 | -- TODO: Use proper locations 391 | let locations = replicate (length blockArgs) UnknownLocation 392 | evalContT $ do 393 | let blockArgTypes = snd <$> blockArgs 394 | (numBlockArgs, nativeArgTypes) <- packFromAST ctx env blockArgTypes 395 | (_, locs) <- packFromAST ctx env locations 396 | liftIO $ [C.exp| MlirBlock { 397 | mlirBlockCreate($(intptr_t numBlockArgs), $(MlirType* nativeArgTypes), $(MlirLocation* locs)) 398 | } |] 399 | 400 | -- | Fill in the native block pre-created for this AST block: bind its arguments, then append each operation in order and bind its results. Block arguments shadow outer names, and later results shadow earlier bindings with the same name. 401 | instance FromAST Block Native.Block where 402 | fromAST ctx (outerValueEnv, blockEnv) Block{..} = do 403 | let block = blockEnv M.! blockName 404 | nativeBlockArgs <- getBlockArgs block 405 | let blockArgNames = fst <$> blockArgs 406 | let argValueEnv = M.fromList $ zip blockArgNames nativeBlockArgs 407 | foldM_ (appendInstr block) (argValueEnv <> outerValueEnv) blockBody -- Map's <> is left-biased: block arguments take precedence over outer names 408 | return block 409 | where 410 | appendInstr :: Native.Block -> ValueMapping -> Binding -> IO ValueMapping 411 | appendInstr block valueEnv (Bind names operation) = do 412 | nativeOperation <- fromAST ctx (valueEnv, blockEnv) operation 413 | [C.exp| void { 414 | mlirBlockAppendOwnedOperation($(MlirBlock block), 415 | $(MlirOperation nativeOperation)) 416 | } |] 417 | nativeResults <- getOperationResults nativeOperation 418 | return $ (M.fromList $ zip names nativeResults) <> valueEnv -- left-biased union: fresh results take precedence over earlier bindings 419 | -- Read all arguments of a native block into a list. 420 | getBlockArgs :: Native.Block -> IO [Native.Value] 421 | getBlockArgs block = do 422 | numArgs <- [C.exp| intptr_t { mlirBlockGetNumArguments($(MlirBlock block)) } |] 423 | allocaArray (fromIntegral numArgs) \nativeArgs -> do 424 | [C.block| void { 425 | for (intptr_t i = 0; i < $(intptr_t numArgs); ++i) { 426 | $(MlirValue* nativeArgs)[i] = mlirBlockGetArgument($(MlirBlock block), i); 427 | } 428 | } |] 429 | unpackArray numArgs nativeArgs 430 | -- Read all results of a native operation into a list. 431 | getOperationResults :: Native.Operation -> IO
[Native.Value]
getOperationResults op = do
  numResults <- [C.exp| intptr_t { mlirOperationGetNumResults($(MlirOperation op)) } |]
  -- Copy every result handle into a temporary buffer, then unpack it into a list.
  allocaArray (fromIntegral numResults) \nativeResults -> do
    [C.block| void {
      for (intptr_t i = 0; i < $(intptr_t numResults); ++i) {
        $(MlirValue* nativeResults)[i] = mlirOperationGetResult($(MlirOperation op), i);
      }
    } |]
    unpackArray numResults nativeResults


-- | Translates an AST 'Attribute' into a native MLIR attribute in context @ctx@.
--
-- Nested attributes and types are translated recursively through 'fromAST'.
-- Array-valued inputs are marshalled into scoped buffers (via 'evalContT' or an
-- 'unsafeWithIStorableArray' callback) that are only used for the duration of
-- the MLIR C API call consuming them.
instance FromAST Attribute Native.Attribute where
  fromAST ctx env attr = case attr of
    ArrayAttr attrs -> evalContT $ do
      (numAttrs, nativeAttrs) <- packFromAST ctx env attrs
      liftIO $ [C.exp| MlirAttribute {
        mlirArrayAttrGet($(MlirContext ctx), $(intptr_t numAttrs), $(MlirAttribute* nativeAttrs))
      } |]
    DictionaryAttr dict -> evalContT $ do
      (numAttrs, nativeAttrs) <- packNamedAttrs ctx env dict
      liftIO $ [C.exp| MlirAttribute {
        mlirDictionaryAttrGet($(MlirContext ctx), $(intptr_t numAttrs), $(MlirNamedAttribute* nativeAttrs))
      } |]
    -- Dialect-specific attributes are serialized by their own FromAST instance.
    DialectAttr at -> fromAST ctx env at
    FloatAttr ty value -> do
      nativeType <- fromAST ctx env ty
      -- 'coerce' turns the Haskell value into the CDouble passed to C below.
      let nativeValue = coerce value
      [C.exp| MlirAttribute {
        mlirFloatAttrDoubleGet($(MlirContext ctx), $(MlirType nativeType), $(double nativeValue))
      } |]
    IntegerAttr ty value -> do
      nativeType <- fromAST ctx env ty
      let nativeValue = fromIntegral value
      [C.exp| MlirAttribute {
        mlirIntegerAttrGet($(MlirType nativeType), $(int64_t nativeValue))
      } |]
    BoolAttr value -> do
      let nativeValue = if value then 1 else 0
      [C.exp| MlirAttribute {
        mlirBoolAttrGet($(MlirContext ctx), $(int nativeValue))
      } |]
    StringAttr value ->
      Native.withStringRef value \(Native.StringRef ptr len) ->
        [C.exp| MlirAttribute {
          mlirStringAttrGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
        } |]
    -- Parses MLIR's textual attribute syntax.
    -- NOTE(review): per the MLIR C API docs mlirAttributeParseGet yields a null
    -- attribute when parsing fails; that case is not checked here.
    AsmTextAttr value ->
      Native.withStringRef value \(Native.StringRef ptr len) ->
        [C.exp| MlirAttribute {
          mlirAttributeParseGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
        } |]
    TypeAttr ty -> do
      nativeType <- fromAST ctx env ty
      [C.exp| MlirAttribute { mlirTypeAttrGet($(MlirType nativeType)) } |]
    AffineMapAttr afMap -> do
      nativeMap <- fromAST ctx env afMap
      [C.exp| MlirAttribute { mlirAffineMapAttrGet($(MlirAffineMap nativeMap)) } |]
    UnitAttr -> [C.exp| MlirAttribute { mlirUnitAttrGet($(MlirContext ctx)) } |]
    DenseArrayAttr storage -> case storage of
      DenseInt8 arr -> do
        let size = fromIntegral $ rangeSize $ bounds arr
        unsafeWithIStorableArray arr \valuesPtr ->
          [C.exp| MlirAttribute {
            mlirDenseI8ArrayGet($(MlirContext ctx), $(intptr_t size),
                                $(const int8_t* valuesPtr))
          } |]
      DenseInt32 arr -> do
        let size = fromIntegral $ rangeSize $ bounds arr
        unsafeWithIStorableArray arr \valuesPtr ->
          [C.exp| MlirAttribute {
            mlirDenseI32ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const int32_t* valuesPtr))
          } |]
      DenseInt64 arr -> do
        let size = fromIntegral $ rangeSize $ bounds arr
        unsafeWithIStorableArray arr \valuesPtr ->
          [C.exp| MlirAttribute {
            mlirDenseI64ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const int64_t* valuesPtr))
          } |]
      DenseFloat arr -> do
        let size = fromIntegral $ rangeSize $ bounds arr
        unsafeWithIStorableArray arr \valuesPtrHs -> do
          -- Reinterpret the Haskell element pointer as the C float pointer.
          let valuesPtr = castPtr valuesPtrHs
          [C.exp| MlirAttribute {
            mlirDenseF32ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const float* valuesPtr))
          } |]
      DenseDouble arr -> do
        let size = fromIntegral $ rangeSize $ bounds arr
        unsafeWithIStorableArray arr \valuesPtrHs -> do
          -- Reinterpret the Haskell element pointer as the C double pointer.
          let valuesPtr = castPtr valuesPtrHs
          [C.exp| MlirAttribute {
            mlirDenseF64ArrayGet($(MlirContext ctx), $(intptr_t size),
                                 $(const double* valuesPtr))
          } |]
      -- Remaining storages (e.g. the unsigned DenseUInt* kinds) have no
      -- dense-array constructor handled here.
      _ -> error "Found a DenseArray datatype unsupported in the MLIR API"
    DenseElementsAttr ty storage -> do
      nativeType <- fromAST ctx env ty
      case storage of
        DenseUInt8 arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrUInt8Get($(MlirType nativeType), $(intptr_t size),
                                            $(const uint8_t* valuesPtr))
            } |]
        DenseInt8 arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrInt8Get($(MlirType nativeType), $(intptr_t size),
                                           $(const int8_t* valuesPtr))
            } |]
        DenseUInt32 arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrUInt32Get($(MlirType nativeType), $(intptr_t size),
                                             $(const uint32_t* valuesPtr))
            } |]
        DenseInt32 arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrInt32Get($(MlirType nativeType), $(intptr_t size),
                                            $(const int32_t* valuesPtr))
            } |]
        DenseUInt64 arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrUInt64Get($(MlirType nativeType), $(intptr_t size),
                                             $(const uint64_t* valuesPtr))
            } |]
        DenseInt64 arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtr ->
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrInt64Get($(MlirType nativeType), $(intptr_t size),
                                            $(const int64_t* valuesPtr))
            } |]
        DenseFloat arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtrHs -> do
            let valuesPtr = castPtr valuesPtrHs
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrFloatGet($(MlirType nativeType), $(intptr_t size),
                                            $(const float* valuesPtr))
            } |]
        DenseDouble arr -> do
          let size = fromIntegral $ rangeSize $ bounds arr
          unsafeWithIStorableArray arr \valuesPtrHs -> do
            let valuesPtr = castPtr valuesPtrHs
            [C.exp| MlirAttribute {
              mlirDenseElementsAttrDoubleGet($(MlirType nativeType), $(intptr_t size),
                                             $(const double* valuesPtr))
            } |]


-- | Translates an AST 'Operation' into a newly created native operation.
--
-- Operand and successor names are resolved through the value and block maps
-- in @env@ (with 'M.!', so an unbound name is an error). For 'Inferred' result
-- types MLIR result type inference is enabled; if it fails, 'error' is raised.
-- Regions are added with mlirOperationStateAddOwnedRegions.
instance FromAST Operation Native.Operation where
  fromAST ctx env@(valueEnv, blockEnv) Operation{..} = evalContT $ do
    (namePtr, nameLen) <- ContT $ BS.unsafeUseAsCStringLen opName
    let nameLenSizeT = fromIntegral nameLen
    (infersResults, (numResultTypes, nativeResultTypes)) <- case opResultTypes of
      Inferred -> return (CTrue, (0, nullPtr))
      Explicit types -> (CFalse,) <$> packFromAST ctx env types
    nativeLocation <- liftIO $ fromAST ctx env opLocation
    (numOperands, nativeOperands) <- packArray $ fmap (valueEnv M.!) opOperands
    (numRegions, nativeRegions) <- packFromAST ctx env opRegions
    (numSuccessors, nativeSuccessors) <- packArray $ fmap (blockEnv M.!) opSuccessors
    (numAttributes, nativeAttributes) <- packNamedAttrs ctx env opAttributes
    -- NB: This is nullable when result type inference is enabled
    maybeOperation <- liftIO $ Native.nullable <$> [C.block| MlirOperation {
      MlirOperationState state = mlirOperationStateGet(
        (MlirStringRef){$(char* namePtr), $(size_t nameLenSizeT)},
        $(MlirLocation nativeLocation));
      if ($(bool infersResults)) {
        mlirOperationStateEnableResultTypeInference(&state);
      } else {
        mlirOperationStateAddResults(
          &state, $(intptr_t numResultTypes), $(MlirType* nativeResultTypes));
      }
      mlirOperationStateAddOperands(
        &state, $(intptr_t numOperands), $(MlirValue* nativeOperands));
      mlirOperationStateAddOwnedRegions(
        &state, $(intptr_t numRegions), $(MlirRegion* nativeRegions));
      mlirOperationStateAddSuccessors(
        &state, $(intptr_t numSuccessors), $(MlirBlock* nativeSuccessors));
      mlirOperationStateAddAttributes(
        &state, $(intptr_t numAttributes), $(MlirNamedAttribute* nativeAttributes));
      return mlirOperationCreate(&state);
    } |]
    case maybeOperation of
      Just operation -> return operation
      Nothing -> error $ "Type inference failed for operation " ++ show opName

--------------------------------------------------------------------------------
-- Utilities for AST -> Native translation

-- | Marshals a map of named attributes into a C array of @MlirNamedAttribute@
-- allocated with 'allocaBytesAligned' (using the C @sizeof@/@alignof@).
-- The buffer lives until the enclosing 'ContT' computation finishes.
-- Returns the element count and the array pointer.
packNamedAttrs :: Native.Context -> ValueAndBlockMapping
               -> M.Map Name Attribute -> ContT r IO (C.CIntPtr, Ptr Native.NamedAttribute)
packNamedAttrs ctx env attrDict = do
  let arrSize = M.size attrDict
  elemSize <- liftIO $ fromIntegral <$> [C.exp| size_t { sizeof(MlirNamedAttribute) } |]
  elemAlign <- liftIO $ fromIntegral <$> [C.exp| size_t { alignof(MlirNamedAttribute) } |]
  ptr <- ContT $ allocaBytesAligned (arrSize * elemSize) elemAlign
  flip mapM_ (zip [0..] $ M.toList attrDict) \(i, (name, attr)) -> do
    -- withStringRef is entered in ContT, so the string ref stays valid until
    -- the enclosing evalContT finishes.
    nameRef <- ContT $ Native.withStringRef name
    liftIO $ do
      nativeAttr <- fromAST ctx env attr
      ident <- Native.createIdentifier ctx nameRef
      -- Index with intptr_t, matching the CIntPtr element count returned below.
      [C.exp| void {
        $(MlirNamedAttribute* ptr)[$(intptr_t i)] =
          mlirNamedAttributeGet($(MlirIdentifier ident), $(MlirAttribute nativeAttr));
      } |]
  return (fromIntegral arrSize, ptr)

-- | C @true@ as a 'C.CBool'.
pattern CTrue :: C.CBool
pattern CTrue = C.CBool 1

-- | C @false@ as a 'C.CBool'.
pattern CFalse :: C.CBool
pattern CFalse = C.CBool 0

-- | Placeholder for unimplemented functionality; fails with "Not implemented" when forced.
notImplemented :: forall a. a
notImplemented = error "Not implemented"

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Builder.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE UndecidableInstances #-}
-- | Monadic builders for MLIR ASTs: a fresh-name supply, a block builder that
-- records bindings and block arguments, and a region builder that names blocks.
module MLIR.AST.Builder where

import MLIR.AST
import Data.String
import Data.Functor
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Writer
import Control.Monad.Reader

--------------------------------------------------------------------------------
-- Value

-- | An SSA value reference: a name paired with its type.
data Value = Name :> Type

-- | The type of a value.
typeOf :: Value -> Type
typeOf (_ :> ty) = ty

-- | The name of a value, as used in an operation's operand list.
operand :: Value -> Name
operand (n :> _) = n

-- | The names of a list of values, in order.
operands :: [Value] -> [Name]
operands = fmap operand

--------------------------------------------------------------------------------
-- Name supply monad

-- | Name-supply state: the next numeric id to hand out.
newtype NameSupply = NameSupply { nextName :: Int }

-- | Transformer that generates fresh, sequentially numbered names.
newtype NameSupplyT m a = NameSupplyT (StateT NameSupply m a)
  deriving (Functor, Applicative, Monad,
            MonadTrans, MonadFix,
            MonadReader r, MonadWriter w)

-- State operations refer to the underlying monad, not to the name counter.
instance MonadState s m => MonadState s (NameSupplyT m) where
  get = lift get
  put = lift . put

-- | Monads that can produce fresh 'Name's.
class Monad m => MonadNameSupply m where
  freshName :: m Name

instance MonadNameSupply m => MonadNameSupply (ReaderT r m) where
  freshName = lift freshName

-- | Runs a name-supply computation with numbering starting at 0.
evalNameSupplyT :: Monad m => NameSupplyT m a -> m a
evalNameSupplyT (NameSupplyT a) = evalStateT a $ NameSupply 0

instance Monad m => MonadNameSupply (NameSupplyT m) where
  freshName = NameSupplyT $ do
    curId <- gets nextName
    modify \s -> s { nextName = nextName s + 1 }
    return $ fromString $ show curId

-- | A fresh value of the given type.
freshValue :: MonadNameSupply m => Type -> m Value
freshValue ty = freshName <&> (:> ty)

-- | A fresh value whose name is the next fresh name prefixed with "arg".
freshBlockArg :: MonadNameSupply m => Type -> m Value
freshBlockArg ty = (("arg" <>) <$> freshName) <&> (:> ty)

--------------------------------------------------------------------------------
-- Block builder monad

-- TODO(apaszke): Thread locations through
-- TODO(apaszke): Use a writer monad
-- | Everything recorded while building one block.
data BlockBindings = BlockBindings
  { blockBindings :: SnocList Binding
  , blockArguments :: SnocList Value
  , blockDefaultLocation :: Location
  }

-- | Concatenates bindings and arguments in order. The right operand's default
-- location wins unless it is 'UnknownLocation', in which case the left one is
-- kept. This keeps '<>' associative and makes 'mempty' a two-sided identity,
-- so @x <> mempty@ no longer drops @x@'s location.
instance Semigroup BlockBindings where
  BlockBindings bs args loc <> BlockBindings bs' args' loc' =
    BlockBindings (bs <> bs') (args <> args') mergedLoc
    where
      mergedLoc = case loc' of
        UnknownLocation -> loc
        _ -> loc'

instance Monoid BlockBindings where
  mempty = BlockBindings mempty mempty UnknownLocation

-- | Transformer that accumulates the bindings and arguments of one block.
newtype BlockBuilderT m a = BlockBuilderT (StateT BlockBindings m a)
  deriving (Functor, Applicative, Monad,
            MonadTrans, MonadFix,
            MonadReader r, MonadWriter w)

-- State operations refer to the underlying monad, not to the block bindings.
instance MonadState s m => MonadState s (BlockBuilderT m) where
  get = lift get
  put = lift . put

-- | Monads into which result-less operations can be emitted.
class Monad m => MonadBlockDecl m where
  emitOp_ :: Operation -> m ()
-- | Monads that build a block: they emit operations with results, declare
-- block arguments, and set the default location for emitted operations.
class MonadBlockDecl m => MonadBlockBuilder m where
  emitOp :: Operation -> m [Value]
  blockArgument :: Type -> m Value
  setDefaultLocation :: Location -> m ()

-- | Witness returned by a block builder once the block is complete.
data EndOfBlock = EndOfBlock

-- | Marks the end of a block after its terminator has been emitted.
terminateBlock :: Monad m => m EndOfBlock
terminateBlock = return EndOfBlock

-- | Marks the end of a block that has no terminator.
noTerminator :: Monad m => m EndOfBlock
noTerminator = return EndOfBlock

-- | Runs a block builder from empty bindings, returning its result together
-- with the block arguments and bindings in emission order.
runBlockBuilder :: Monad m => BlockBuilderT m a -> m (a, ([Value], [Binding]))
runBlockBuilder (BlockBuilderT act) = do
  (result, BlockBindings binds args _) <- runStateT act mempty
  return (result, (unsnocList args, unsnocList binds))

instance Monad m => MonadBlockDecl (BlockBuilderT m) where
  -- Like 'emitOp', an op without a location inherits the block's default one.
  emitOp_ opNoLoc = BlockBuilderT $ do
    loc <- gets blockDefaultLocation
    let op = case opLocation opNoLoc of
          UnknownLocation -> opNoLoc { opLocation = loc }
          _ -> opNoLoc
    case opResultTypes op of
      Inferred -> error "Builder doesn't support inferred result types!"
      Explicit [] -> modify \s -> s { blockBindings = blockBindings s .:. (Do op) }
      Explicit _ -> error "emitOp_ can only be used on ops that have no results"

instance MonadNameSupply m => MonadBlockBuilder (BlockBuilderT m) where
  emitOp opNoLoc = BlockBuilderT $ do
    loc <- gets blockDefaultLocation
    let op = case opLocation opNoLoc of
          UnknownLocation -> opNoLoc { opLocation = loc }
          _ -> opNoLoc
    results <- case opResultTypes op of
      Inferred -> error "Builder doesn't support inferred result types!"
      Explicit tys -> lift $ mapM freshValue tys
    modify \s -> s { blockBindings = blockBindings s .:. (operands results ::= op) }
    return results
  blockArgument ty = BlockBuilderT $ do
    value <- lift $ freshValue ty
    modify \s -> s { blockArguments = blockArguments s .:. value }
    return value
  setDefaultLocation loc = BlockBuilderT $ modify \s -> s { blockDefaultLocation = loc }

--------------------------------------------------------------------------------
-- Region builder monad

-- TODO(apaszke): Make this a writer, assign block names only at the very end
-- | Blocks built so far and the id used to name the next one.
data RegionBuilderState = RegionBuilderState
  { blocks :: SnocList Block
  , nextBlockId :: Int
  }

-- | Transformer that appends blocks to a region.
newtype RegionBuilderT m a = RegionBuilderT (StateT RegionBuilderState m a)
  deriving (Functor, Applicative, Monad,
            MonadTrans, MonadFix,
            MonadReader r, MonadWriter w)

-- State operations refer to the underlying monad, not to the region state.
instance MonadState s m => MonadState s (RegionBuilderT m) where
  get = lift get
  put = lift . put

-- | Name assigned to a block by the region builder.
type BlockName = Name

-- | Monads that can append a block to the region under construction.
class Monad m => MonadRegionBuilder m where
  appendBlock :: BlockBuilderT m EndOfBlock -> m BlockName

-- | Marks the end of a region builder.
endOfRegion :: Monad m => m ()
endOfRegion = return ()

-- | Runs a region builder and collects its blocks in build order.
buildRegion :: Monad m => RegionBuilderT m () -> m Region
buildRegion (RegionBuilderT regionBuilder) =
  Region . unsnocList . blocks <$> execStateT regionBuilder (RegionBuilderState mempty 0)

-- | Builds a block, names it "bb" followed by the next free block id, appends
-- it to the region, and returns that name.
buildBlock :: Monad m => BlockBuilderT m EndOfBlock -> RegionBuilderT m BlockName
buildBlock builder = RegionBuilderT $ do
  (EndOfBlock, (args, body)) <- lift $ runBlockBuilder builder
  makeBlock args body
  where
    makeBlock args body = do
      curBlockId <- gets nextBlockId
      modify (\s -> s { nextBlockId = nextBlockId s + 1 })
      let blockName = "bb" <> (fromString $ show curBlockId)
      let block = Block blockName (fmap (\(n :> t) -> (n, t)) args) body
      modify (\s -> s { blocks = blocks s .:. block })
      return blockName

--------------------------------------------------------------------------------
-- Builtin dialect

-- | Builds a single block named "0".
soleBlock :: Monad m => BlockBuilderT m EndOfBlock -> m Block
soleBlock builder = do
  (EndOfBlock, (args, body)) <- runBlockBuilder builder
  return $ Block "0" (fmap (\(n :> t) -> (n, t)) args) body

-- | Builds a module whose single body block has no terminator.
buildModule :: Monad m => BlockBuilderT m () -> m Operation
buildModule build = liftM ModuleOp $ soleBlock $ build >> noTerminator

-- | Emits a function declaration (a FuncOp with an empty region).
declareFunction :: MonadBlockDecl m => Name -> Type -> m ()
declareFunction name funcTy =
  emitOp_ $ FuncOp UnknownLocation name funcTy $ Region [] 

-- | Emits a function whose argument types are taken from its first block.
-- Fails with 'error' if the body builder produces no blocks.
buildFunction :: MonadBlockDecl m
              => Name -> [Type] -> NamedAttributes
              -> RegionBuilderT (NameSupplyT m) () -> m ()
buildFunction name retTypes attrs bodyBuilder = do
  body@(Region blocks) <- evalNameSupplyT $ buildRegion bodyBuilder
  let argTypes = case blocks of
        [] -> error $ "buildFunction cannot be used for function declarations! " ++
                      "Build at least one block!"
        (Block _ args _) : _ -> fmap snd args
  let op = FuncOp UnknownLocation name (FunctionType argTypes retTypes) body
  emitOp_ $ op { opAttributes = opAttributes op <> attrs }

-- | Emits a function with a single block whose arguments determine the
-- function's argument types.
buildSimpleFunction :: MonadBlockDecl m
                    => Name -> [Type] -> NamedAttributes
                    -> BlockBuilderT (NameSupplyT m) EndOfBlock -> m ()
buildSimpleFunction name retTypes attrs bodyBuilder = do
  block <- evalNameSupplyT $ soleBlock bodyBuilder
  let argTypes = fmap snd $ blockArgs block
  let fTy = FunctionType argTypes retTypes
  let op = FuncOp UnknownLocation name fTy $ Region [block]
  emitOp_ $ op { opAttributes = opAttributes op <> attrs }

--------------------------------------------------------------------------------
-- Utilities

-- | A list appended to at the end in O(1); elements are stored reversed.
newtype SnocList a = SnocList [a]

-- | Appends an element at the end.
(.:.) :: SnocList a -> a -> SnocList a
(SnocList t) .:. h = SnocList (h : t)

-- | The elements in insertion order.
unsnocList :: SnocList a -> [a]
unsnocList (SnocList l) = reverse l

instance Semigroup (SnocList a) where
  SnocList l <> SnocList r = SnocList (r <> l)

instance Monoid (SnocList a) where
  mempty = SnocList []

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Affine.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module MLIR.AST.Dialect.Affine where

import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Language.C.Inline as C

import qualified MLIR.Native.FFI as Native
import MLIR.AST.Serialize

C.context $ C.baseCtx <> Native.mlirCtx

C.include "mlir-c/AffineExpr.h"
C.include "mlir-c/AffineMap.h"

-- | Affine expressions over dimension and symbol identifiers.
data Expr =
    Dimension Int      -- ^ Dimension identifier at the given position.
  | Symbol Int         -- ^ Symbol identifier at the given position.
  | Constant Int       -- ^ Integer constant.
  | Add Expr Expr
  | Mul Expr Expr
  | Mod Expr Expr
  | FloorDiv Expr Expr
  | CeilDiv Expr Expr
  deriving Eq

-- | An affine map with the given numbers of dimensions and symbols and one
-- result expression per element of 'mapExprs'.
data Map = Map { mapDimensionCount :: Int
               , mapSymbolCount :: Int
               , mapExprs :: [Expr]
               }
  deriving Eq


-- | Builds native affine expressions bottom-up. For a binary node, the left
-- operand is translated first, then the right, then the node itself.
instance FromAST Expr Native.AffineExpr where
  fromAST ctx env expr = case expr of
      Dimension idx -> do
        let natIdx = fromIntegral idx
        [C.exp| MlirAffineExpr { mlirAffineDimExprGet($(MlirContext ctx), $(intptr_t natIdx)) } |]
      Symbol idx -> do
        let natIdx = fromIntegral idx
        [C.exp| MlirAffineExpr { mlirAffineSymbolExprGet($(MlirContext ctx), $(intptr_t natIdx)) } |]
      Constant val -> do
        let natVal = fromIntegral val
        [C.exp| MlirAffineExpr { mlirAffineConstantExprGet($(MlirContext ctx), $(int64_t natVal)) } |]
      Add l r -> binaryExpr l r \natL natR ->
        [C.exp| MlirAffineExpr { mlirAffineAddExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
      Mul l r -> binaryExpr l r \natL natR ->
        [C.exp| MlirAffineExpr { mlirAffineMulExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
      Mod l r -> binaryExpr l r \natL natR ->
        [C.exp| MlirAffineExpr { mlirAffineModExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
      FloorDiv l r -> binaryExpr l r \natL natR ->
        [C.exp| MlirAffineExpr { mlirAffineFloorDivExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
      CeilDiv l r -> binaryExpr l r \natL natR ->
        [C.exp| MlirAffineExpr { mlirAffineCeilDivExprGet($(MlirAffineExpr natL), $(MlirAffineExpr natR)) } |]
    where
      -- Translate both operands (left first), then let the caller build the node.
      binaryExpr l r combine = do
        natL <- fromAST ctx env l
        natR <- fromAST ctx env r
        combine natL natR


-- | Creates a native affine map from the dimension/symbol counts and the
-- packed result expressions.
instance FromAST Map Native.AffineMap where
  fromAST ctx env affineMap = evalContT $ do
    (numExprs, nativeExprs) <- packFromAST ctx env (mapExprs affineMap)
    let nativeDimCount    = fromIntegral (mapDimensionCount affineMap)
        nativeSymbolCount = fromIntegral (mapSymbolCount affineMap)
    liftIO [C.exp| MlirAffineMap {
      mlirAffineMapGet($(MlirContext ctx),
                       $(intptr_t nativeDimCount),
                       $(intptr_t nativeSymbolCount),
                       $(intptr_t numExprs), $(MlirAffineExpr* nativeExprs))
    } |]

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Arith.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Re-exports the generated bindings for the @arith@ dialect.
module MLIR.AST.Dialect.Arith
  ( module MLIR.AST.Dialect.Generated.Arith
  ) where

import MLIR.AST.Dialect.Generated.Arith

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/ControlFlow.hs:
--------------------------------------------------------------------------------
-- Copyright 2022 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Hand-written helpers for the @cf@ dialect (branches), plus re-exports of
-- the generated bindings.
module MLIR.AST.Dialect.ControlFlow
  ( module MLIR.AST.Dialect.ControlFlow
  , module MLIR.AST.Dialect.Generated.ControlFlow
  ) where

import Prelude hiding (return)
import Data.Array.IArray

import MLIR.AST
import MLIR.AST.Builder

import MLIR.AST.Dialect.Generated.ControlFlow

-- | An unconditional @cf.br@ to @block@ that forwards @args@ as successor operands.
pattern Branch :: Location -> BlockName -> [Name] -> Operation
pattern Branch loc block args = Operation
  { opName        = "cf.br"
  , opLocation    = loc
  , opResultTypes = Explicit []
  , opOperands    = args
  , opRegions     = []
  , opSuccessors  = [block]
  , opAttributes  = NoAttrs
  }

-- | Emits a @cf.br@ to @block@ with @args@ and ends the current block.
br :: MonadBlockBuilder m => BlockName -> [Value] -> m EndOfBlock
br block args = do
  _ <- emitOp (Branch UnknownLocation block (operands args))
  terminateBlock

-- | Emits a @cf.cond_br@ on @cond@ and ends the current block. The operands
-- are the condition, then the true-branch arguments, then the false-branch
-- arguments. Their group sizes are recorded in @operand_segment_sizes@.
cond_br :: MonadBlockBuilder m => Value -> BlockName -> [Value] -> BlockName -> [Value] -> m EndOfBlock
cond_br cond trueBlock trueArgs falseBlock falseArgs =
  emitOp_ condBranchOp *> terminateBlock
  where
    segmentSizes = fromIntegral <$> [1, length trueArgs, length falseArgs]
    condBranchOp = Operation
      { opName        = "cf.cond_br"
      , opLocation    = UnknownLocation
      , opResultTypes = Explicit []
      , opOperands    = operands $ [cond] <> trueArgs <> falseArgs
      , opRegions     = []
      , opSuccessors  = [trueBlock, falseBlock]
      , opAttributes  = namedAttribute "operand_segment_sizes" $
          DenseArrayAttr $ DenseInt32 $ listArray (0 :: Int, 2) segmentSizes
      }

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Func.hs:
--------------------------------------------------------------------------------
-- Copyright 2022 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Re-exports the generated bindings for the @func@ dialect.
module MLIR.AST.Dialect.Func
  ( module MLIR.AST.Dialect.Generated.Func
  ) where

import MLIR.AST.Dialect.Generated.Func

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/LLVM.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Hand-written LLVM dialect types, plus re-exports of the generated LLVM
-- dialect operations.
module MLIR.AST.Dialect.LLVM (
  -- * Types
    Type(..)
  , pattern Array
  , pattern Void
  , pattern LiteralStruct
  -- * Operations
  , module MLIR.AST.Dialect.Generated.LLVM
  ) where

import MLIR.AST.Dialect.Generated.LLVM

import Data.Typeable
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Language.C.Inline as C

import qualified MLIR.AST as AST
import qualified MLIR.AST.Serialize as AST
import qualified MLIR.Native as Native
import qualified MLIR.Native.FFI as Native

C.context $ C.baseCtx <> Native.mlirCtx
C.include "mlir-c/Dialect/LLVM.h"

-- | LLVM dialect types, embedded into 'AST.Type' through 'AST.DialectType'.
data Type = ArrayType Int AST.Type
          | VoidType
          | LiteralStructType [AST.Type]
          -- TODO(apaszke): Structures, functions, vectors, etc.
  deriving Eq

-- | Creates the corresponding native LLVM type; literal structs are created
-- non-packed (the trailing @false@ argument).
instance AST.FromAST Type Native.Type where
  fromAST ctx env llvmTy = case llvmTy of
    ArrayType size elemTy -> do
      nativeElemTy <- AST.fromAST ctx env elemTy
      let nativeSize = fromIntegral size
      [C.exp| MlirType { mlirLLVMArrayTypeGet($(MlirType nativeElemTy), $(unsigned int nativeSize)) } |]
    VoidType ->
      [C.exp| MlirType { mlirLLVMVoidTypeGet($(MlirContext ctx)) } |]
    LiteralStructType fields -> evalContT $ do
      (numFields, nativeFields) <- AST.packFromAST ctx env fields
      liftIO [C.exp| MlirType {
        mlirLLVMStructTypeLiteralGet($(MlirContext ctx), $(intptr_t numFields),
                                     $(MlirType* nativeFields), false)
      } |]


-- | Recovers an LLVM dialect type from a generic dialect type, if it is one.
castLLVMType :: AST.Type -> Maybe Type
castLLVMType (AST.DialectType dty) = cast dty
castLLVMType _ = Nothing

-- | An LLVM array type with @n@ elements of type @t@.
pattern Array :: Int -> AST.Type -> AST.Type
pattern Array n t <- (castLLVMType -> Just (ArrayType n t))
  where Array n t = AST.DialectType (ArrayType n t)

-- | The LLVM void type.
pattern Void :: AST.Type
pattern Void <- (castLLVMType -> Just VoidType)
  where Void = AST.DialectType VoidType

-- | An LLVM literal (unnamed) struct type with the given field types.
pattern LiteralStruct :: [AST.Type] -> AST.Type
pattern LiteralStruct fields <- (castLLVMType -> Just (LiteralStructType fields))
  where LiteralStruct fields = AST.DialectType (LiteralStructType fields)

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Linalg.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Re-exports the generated bindings for the Linalg dialect, including the
-- generated structured operations.
module MLIR.AST.Dialect.Linalg
  ( module MLIR.AST.Dialect.Generated.Linalg
  , module MLIR.AST.Dialect.Generated.LinalgStructured
  ) where

import MLIR.AST.Dialect.Generated.Linalg
import MLIR.AST.Dialect.Generated.LinalgStructured

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/MemRef.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module MLIR.AST.Dialect.MemRef where

import MLIR.AST

-- | A @memref.load@ with an explicit location. It loads a value of type @ty@
-- from memref @src@ at indices @idx@. Unlike 'Load', it matches operations
-- whatever their location.
pattern LoadAt :: Location -> Type -> Name -> [Name] -> Operation
pattern LoadAt loc ty src idx = Operation
  { opName = "memref.load"
  , opLocation = loc
  , opResultTypes = Explicit [ty]
  , opOperands = src : idx
  , opRegions = []
  , opSuccessors = []
  , opAttributes = NoAttrs
  }

-- | 'LoadAt' fixed to 'UnknownLocation'. It builds operations with an unknown
-- location and matches only such operations (the original interface).
pattern Load :: Type -> Name -> [Name] -> Operation
pattern Load ty src idx = LoadAt UnknownLocation ty src idx

-- | A @memref.store@ with an explicit location. It stores value @src@ into
-- memref @dst@ at indices @idx@. Unlike 'Store', it matches operations
-- whatever their location.
pattern StoreAt :: Location -> Name -> Name -> [Name] -> Operation
pattern StoreAt loc src dst idx = Operation
  { opName = "memref.store"
  , opLocation = loc
  , opResultTypes = Explicit []
  , opOperands = src : dst : idx
  , opRegions = []
  , opSuccessors = []
  , opAttributes = NoAttrs
  }

-- | 'StoreAt' fixed to 'UnknownLocation'. It builds operations with an
-- unknown location and matches only such operations (the original interface).
pattern Store :: Name -> Name -> [Name] -> Operation
pattern Store src dst idx = StoreAt UnknownLocation src dst idx

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Shape.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- | Re-exports the generated bindings for the @shape@ dialect.
module MLIR.AST.Dialect.Shape
  ( module MLIR.AST.Dialect.Generated.Shape
  ) where

import MLIR.AST.Dialect.Generated.Shape

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Tensor.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Re-exports the generated bindings for the @tensor@ dialect.
module MLIR.AST.Dialect.Tensor
  ( module MLIR.AST.Dialect.Generated.Tensor
  ) where

import MLIR.AST.Dialect.Generated.Tensor

--------------------------------------------------------------------------------
-- /src/MLIR/AST/Dialect/Vector.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# OPTIONS_GHC -Wno-name-shadowing #-}

module MLIR.AST.Dialect.Vector
  ( module MLIR.AST.Dialect.Vector
  , module MLIR.AST.Dialect.Generated.Vector
  ) where

import Data.Typeable
import qualified Data.Map.Strict as M
import qualified Data.ByteString as BS
import qualified Language.C.Inline as C

import MLIR.AST.Dialect.Generated.Vector
import qualified MLIR.AST as AST
import qualified MLIR.AST.Serialize as AST
import qualified MLIR.AST.Dialect.Affine as Affine
import qualified MLIR.Native as Native
import qualified MLIR.Native.FFI as Native

-- | Kind of a dimension in an iteration space described by iterator attributes.
data IteratorKind = Parallel | Reduction
  deriving (Eq, Show)

-- | Vector-dialect attributes this module can build and recognize.
data Attribute = IteratorAttr IteratorKind
  deriving (Eq, Show)

-- | Downcast a generic attribute to a vector-dialect 'Attribute'.
-- Returns 'Nothing' unless it is a 'AST.DialectAttr' holding one.
castVectorAttr :: AST.Attribute -> Maybe Attribute
castVectorAttr attr = case attr of
  AST.DialectAttr dialectAttr -> cast dialectAttr
  _ -> Nothing

-- | Textual MLIR form of an iterator kind, as handed to the attribute parser.
--
-- The kind is a parameter of the attribute. Without it, both kinds would
-- print the same string and could not be told apart (nor parsed).
showIterator :: IteratorKind -> BS.ByteString
showIterator Parallel  = "#vector.iterator_type<parallel>"
showIterator Reduction = "#vector.iterator_type<reduction>"

C.context $ C.baseCtx <> Native.mlirCtx

C.include "mlir-c/IR.h"

-- | Materialize a vector attribute in the given context by parsing its
-- textual form with @mlirAttributeParseGet@.
-- The value/block mapping argument is not used.
instance AST.FromAST Attribute Native.Attribute where
  fromAST ctx _ ty = case ty of
    IteratorAttr t -> do
      let value = showIterator t
      Native.withStringRef value \(Native.StringRef ptr len) ->
        [C.exp| MlirAttribute {
          mlirAttributeParseGet($(MlirContext ctx), (MlirStringRef){$(char* ptr), $(size_t len)})
        } |]

-- | Recognize a single iterator attribute wrapped in 'AST.DialectAttr'.
iterFromAttribute :: AST.Attribute -> Maybe IteratorKind
iterFromAttribute attr = case attr of
  AST.DialectAttr subAttr -> case cast subAttr of
    Just (IteratorAttr kind) -> Just kind
    _ -> Nothing
  _ -> Nothing

-- | Recognize an array attribute whose every element is an iterator attribute.
-- Fails (returns 'Nothing') if any element is not.
itersFromAttribute :: AST.Attribute -> Maybe [IteratorKind]
itersFromAttribute attr = case attr of
  AST.ArrayAttr subAttrs -> traverse iterFromAttribute subAttrs
  _ -> Nothing
-- | Bidirectional view of an array attribute as a list of iterator kinds.
-- Matching fails unless every element is an iterator attribute.
pattern IteratorAttrs :: [IteratorKind] -> AST.Attribute
pattern IteratorAttrs kinds <- (itersFromAttribute -> Just kinds)
  where
    IteratorAttrs kinds = AST.ArrayAttr [AST.DialectAttr (IteratorAttr k) | k <- kinds]

-- | Attribute dictionary of a @vector.contract@.
--
-- The three affine maps (lhs, rhs, accumulator, in that order) live under
-- @indexing_maps@, and the iterator kinds under @iterator_types@.
-- Matching ignores any other entries; building produces exactly these two.
pattern ContractAttrs :: Affine.Map -> Affine.Map -> Affine.Map -> [IteratorKind] -> AST.NamedAttributes
pattern ContractAttrs mapLhs mapRhs mapAcc kinds <-
  ((\attrs -> (M.lookup "indexing_maps" attrs, M.lookup "iterator_types" attrs)) ->
     ( Just (AST.ArrayAttr [AST.AffineMapAttr mapLhs, AST.AffineMapAttr mapRhs, AST.AffineMapAttr mapAcc])
     , Just (IteratorAttrs kinds)
     ))
  where
    ContractAttrs mapLhs mapRhs mapAcc kinds =
      M.insert "indexing_maps" (AST.ArrayAttr (fmap AST.AffineMapAttr [mapLhs, mapRhs, mapAcc])) $
        M.singleton "iterator_types" (IteratorAttrs kinds)

-- | A @vector.contract@ operation.
--
-- Fields: location, the single result type, the lhs/rhs/accumulator operands,
-- their three indexing maps, and the iterator kinds. It has no regions and
-- no successors.
pattern Contract :: AST.Location -> AST.Type -> AST.Name -> AST.Name -> AST.Name
                 -> Affine.Map -> Affine.Map -> Affine.Map -> [IteratorKind]
                 -> AST.Operation
pattern Contract loc resTy lhs rhs acc mapLhs mapRhs mapAcc kinds = AST.Operation
  { opName        = "vector.contract"
  , opLocation    = loc
  , opResultTypes = AST.Explicit [resTy]
  , opOperands    = [lhs, rhs, acc]
  , opRegions     = []
  , opSuccessors  = []
  , opAttributes  = ContractAttrs mapLhs mapRhs mapAcc kinds
  }

--------------------------------------------------------------------------------
/src/MLIR/AST/Dialect/X86Vector.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

module MLIR.AST.Dialect.X86Vector
  ( module MLIR.AST.Dialect.Generated.X86Vector
  ) where

import MLIR.AST.Dialect.Generated.X86Vector
--------------------------------------------------------------------------------
/src/MLIR/AST/IStorableArray.hs:
--------------------------------------------------------------------------------
module MLIR.AST.IStorableArray (IStorableArray, unsafeWithIStorableArray) where

import Data.Ix
import Data.Array.Storable
import Data.Array.Base
import Foreign.Ptr
import Foreign.Storable
import System.IO.Unsafe

-- | An immutable ('IArray') view of a 'StorableArray'.
--
-- The constructor is not exported. The 'IArray' instance reads the
-- underlying mutable array through 'unsafeDupablePerformIO', so the wrapped
-- array must never be mutated once it is wrapped.
newtype IStorableArray i e = UnsafeIStorableArray (StorableArray i e)

-- | Run an action with a raw pointer to the array's storage.
-- Writing through the pointer would break the immutability the 'IArray'
-- instance relies on.
unsafeWithIStorableArray :: IStorableArray i e -> (Ptr e -> IO c) -> IO c
unsafeWithIStorableArray (UnsafeIStorableArray arr) = withStorableArray arr

instance Storable e => IArray IStorableArray e where
  bounds (UnsafeIStorableArray arr) = unsafeDupablePerformIO $ getBounds arr
  numElements = rangeSize . bounds
  unsafeArray bs inits = unsafeDupablePerformIO $ do
    arr <- newArray_ bs
    -- 'inits' pairs 0-based element offsets with values.
    mapM_ (uncurry $ unsafeWrite arr) inits
    return $ UnsafeIStorableArray arr
  unsafeAt (UnsafeIStorableArray arr) i = unsafeDupablePerformIO $ unsafeRead arr i

instance (Ix i, Show i, Show e, Storable e) => Show (IStorableArray i e) where
  showsPrec = showsIArray

-- | Arrays are equal when they have the same bounds and the same elements.
instance (Ix i, Eq e, Storable e) => Eq (IStorableArray i e) where
  -- Elements are compared by 0-based offset, which is what 'unsafeAt'
  -- expects; equal bounds guarantee equal element counts.
  a == b = bounds a == bounds b
        && and [unsafeAt a i == unsafeAt b i | i <- [0 .. numElements a - 1]]
--------------------------------------------------------------------------------
/src/MLIR/AST/PatternUtil.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module MLIR.AST.PatternUtil
  ( pattern I32ArrayAttr
  , pattern I64ArrayAttr
  , pattern AffineMapArrayAttr
  , DummyIx
  ) where

import Data.Traversable
import Data.Array

import MLIR.AST
import qualified MLIR.AST.Dialect.Affine as Affine

-- | Payloads of an array attribute made only of signed 32-bit integer
-- attributes. 'Nothing' for any other attribute, or if any element differs.
unwrapI32ArrayAttr :: Attribute -> Maybe [Int]
unwrapI32ArrayAttr attr = case attr of
    ArrayAttr items -> for items asI32
    _ -> Nothing
  where
    asI32 (IntegerAttr (IntegerType Signed 32) n) = Just n
    asI32 _ = Nothing

-- | Bidirectional view of an array of signed 32-bit integer attributes.
pattern I32ArrayAttr :: [Int] -> Attribute
pattern I32ArrayAttr ints <- (unwrapI32ArrayAttr -> Just ints)
  where
    I32ArrayAttr ints = ArrayAttr [IntegerAttr (IntegerType Signed 32) n | n <- ints]

-- | Payloads of an array attribute made only of signed 64-bit integer
-- attributes. 'Nothing' for any other attribute, or if any element differs.
unwrapI64ArrayAttr :: Attribute -> Maybe [Int]
unwrapI64ArrayAttr attr = case attr of
    ArrayAttr items -> for items asI64
    _ -> Nothing
  where
    asI64 (IntegerAttr (IntegerType Signed 64) n) = Just n
    asI64 _ = Nothing

-- | Bidirectional view of an array of signed 64-bit integer attributes.
pattern I64ArrayAttr :: [Int] -> Attribute
pattern I64ArrayAttr ints <- (unwrapI64ArrayAttr -> Just ints)
  where
    I64ArrayAttr ints = ArrayAttr [IntegerAttr (IntegerType Signed 64) n | n <- ints]

-- | Affine maps of an array attribute made only of affine-map attributes.
-- 'Nothing' for any other attribute, or if any element differs.
unwrapAffineMapArrayAttr :: Attribute -> Maybe [Affine.Map]
unwrapAffineMapArrayAttr attr = case attr of
    ArrayAttr items -> for items asAffineMap
    _ -> Nothing
  where
    asAffineMap (AffineMapAttr m) = Just m
    asAffineMap _ = Nothing

-- | Bidirectional view of an array of affine-map attributes.
pattern AffineMapArrayAttr :: [Affine.Map] -> Attribute
pattern AffineMapArrayAttr maps <- (unwrapAffineMapArrayAttr -> Just maps)
  where
    AffineMapArrayAttr maps = ArrayAttr [AffineMapAttr m | m <- maps]

-- | An index type with no values; every 'Ix' method raises an error.
-- NOTE(review): presumably a placeholder index for array types that are never
-- indexed — confirm at use sites.
data DummyIx
deriving instance Eq DummyIx
deriving instance Ord DummyIx
deriving instance Show DummyIx

instance Ix DummyIx where
  range _     = invalidIx
  index _ _   = invalidIx
  inRange _ _ = invalidIx

-- Shared failure used by every 'Ix' method of 'DummyIx'.
invalidIx :: a
invalidIx = error "Invalid index"
--------------------------------------------------------------------------------
/src/MLIR/AST/Rewrite.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Single-operation rewriting over the AST.
--
-- A rule is applied to each operation nested inside a root operation. It may
-- replace the operation with values produced through the builder.
module MLIR.AST.Rewrite
  ( RewriteBuilderT
  , OpRewriteM
  , OpRewrite
  , RewriteResult(..)
  , pattern ReplaceOne
  , applyClosedOpRewrite
  , applyClosedOpRewriteT
  ) where

import qualified Data.Map.Strict as M
import Control.Monad.Reader
import Control.Monad.Identity

import qualified MLIR.AST as AST
import MLIR.AST hiding (Operation)
import MLIR.AST.Builder

-- For convenience we pass in operations with operands already substituted
-- for builder-compliant values that can be used to replace the operation
-- with an arbitrary program constructed through a builder expression.
type Operation = AST.AbstractOperation Value

-- | Maps names of values in the input AST to the builder values that stand
-- for them in the rewritten program.
type ValueMapping = M.Map Name Value
-- | Maps names of blocks in the input AST to the names of the blocks emitted
-- in their place.
type BlockMapping = M.Map BlockName BlockName
type BlockAndValueMapping = (ValueMapping, BlockMapping)

-- | Reader carrying the current value/block substitution.
type SubstT = ReaderT BlockAndValueMapping
type RewriteT m = SubstT (NameSupplyT m)
-- | Monad in which rewrite rules run: a block builder layered on the
-- substitution reader and the name supply.
type RewriteBuilderT m = BlockBuilderT (RewriteT m)

-- | What the driver should do with the operation a rule was given.
data RewriteResult
  = Replace [Value]
    -- ^ Use these values in place of the operation's results. The driver does
    -- not re-emit the operation. The number of values must equal the number
    -- of result names bound by the operation, otherwise 'error' is raised.
  | Skip
    -- ^ Re-emit the operation with substituted operands. Nested regions are
    -- copied with a rule that always returns 'Skip', so nothing inside them
    -- is rewritten.
  | Traverse
    -- ^ Re-emit the operation and apply the same rule to its nested regions.

-- | Shorthand for replacing an operation with a single value.
pattern ReplaceOne :: Value -> RewriteResult
pattern ReplaceOne val = Replace [val]

-- | A rewrite rule. It receives an operation whose operands (and successors)
-- have already been substituted; see 'substOp'.
type OpRewriteM m = Operation -> RewriteBuilderT m RewriteResult
-- | A rewrite rule that needs no effects beyond the builder.
type OpRewrite = OpRewriteM Identity

-- | Run an action with extra value mappings in scope.
-- NOTE(review): '<>' on 'M.Map' is a left-biased union, so an existing entry in
-- @vm@ wins over the same key in @upd@. A nested region that re-binds an outer
-- name would keep the outer mapping — confirm names cannot repeat here.
extendValueMap :: MonadReader BlockAndValueMapping m => ValueMapping -> m a -> m a
extendValueMap upd = local \(vm, bm) -> (vm <> upd, bm)

-- | Run an action with extra block mappings in scope.
-- Uses the same left-biased union as 'extendValueMap'.
extendBlockMap :: MonadReader BlockAndValueMapping m => BlockMapping -> m a -> m a
extendBlockMap upd = local \(vm, bm) -> (vm, bm <> upd)

-- | Pure entry point: rewrite every operation nested inside the given one.
--
-- "Closed" means rewriting starts from empty value and block mappings. The
-- operation therefore must not refer to values or blocks defined outside of
-- it; 'substOp' fails on unknown names.
applyClosedOpRewrite :: OpRewrite -> AST.Operation -> AST.Operation
applyClosedOpRewrite rule op = runIdentity $ applyClosedOpRewriteT rule op

-- | Monadic version of 'applyClosedOpRewrite'; runs with a fresh name supply.
applyClosedOpRewriteT :: MonadFix m => OpRewriteM m -> AST.Operation -> m AST.Operation
applyClosedOpRewriteT rule op = evalNameSupplyT $ applyOpRewrite rule op

-- Rewrite the regions of @op@ starting from empty mappings.
-- The root operation itself is never passed to @rule@; only the operations
-- inside its regions are.
applyOpRewrite :: MonadFix m => OpRewriteM m -> AST.Operation -> NameSupplyT m AST.Operation
applyOpRewrite rule op = flip runReaderT (mempty, mempty) $ do
  newRegions <- mapM (applyOpRewriteRegion rule) $ opRegions op
  return $ op { opRegions = newRegions }

-- Rebuild a region block by block.
--
-- The old-name -> new-name block map is only complete once every block has
-- been rebuilt, so it is fed back into the whole traversal with 'mfix'.
-- Presumably this lets successors refer to blocks emitted later, and it relies
-- on successor lookups in 'substOp' staying lazy during construction
-- — TODO confirm.
applyOpRewriteRegion :: MonadFix m => OpRewriteM m -> Region -> RewriteT m Region
applyOpRewriteRegion rule (Region blocks) = do
  buildRegion $ void $ mfix \blockSubst -> extendBlockMap blockSubst $ go mempty blocks
  where
    -- Rebuild blocks in order, accumulating old-name -> new-name entries.
    go blockSubst bs = case bs of
      [] -> return blockSubst
      (block@(Block oldName _ _) : rest) -> do
        newName <- applyOpRewriteBlock rule block
        go (blockSubst <> M.singleton oldName newName) rest

-- Rebuild one block.
-- Create fresh block arguments, map the old argument names to them, then
-- process the body one operation at a time.
applyOpRewriteBlock :: MonadFix m => OpRewriteM m -> Block -> RegionBuilderT (RewriteT m) BlockName
applyOpRewriteBlock rule Block{..} = do
  buildBlock $ do
    let (blockArgNames, blockArgTypes) = unzip blockArgs
    newBlockArgs <- mapM blockArgument blockArgTypes
    extendValueMap (M.fromList $ zip blockArgNames newBlockArgs) $ go blockBody
  where
    go bs = case bs of
      [] -> terminateBlock
      ((Bind names astOp) : rest) -> do
        op <- substOp astOp
        answer <- rule op
        newValues <- case answer of
          Replace newValues -> do
            unless (length names == length newValues) $
              error "Rewrite rule returned an incorrect number of values"
            return newValues
          Traverse -> opRewriteTraverse op
          Skip -> opRewriteSkip op
        -- Later operations see the old result names mapped to the new values.
        extendValueMap (M.fromList $ zip names newValues) $ go rest

    -- Re-emit the operation after applying the same rule inside its regions.
    opRewriteTraverse op = do
      newRegions <- lift $ mapM (applyOpRewriteRegion rule) $ opRegions op
      emitOp $ op { opRegions = newRegions, opOperands = operands (opOperands op) }

    -- Re-emit the operation, copying its regions without rewriting them.
    opRewriteSkip op = do
      -- Note that we still have to traverse the subregions in case they close-over
      -- any values that have had their names updated.
      newRegions <- lift $ mapM (applyOpRewriteRegion (const $ return Skip)) $ opRegions op
      emitOp $ op { opRegions = newRegions, opOperands = operands (opOperands op) }

-- | Rename an operation's operands and successors through the current mappings.
-- 'M.!' raises an error (when the element is forced) for any name that has no
-- mapping.
substOp :: MonadReader BlockAndValueMapping m => AST.Operation -> m Operation
substOp op = do
  (valueMap, blockMap) <- ask
  let newOperands = fmap (valueMap M.!) $ opOperands op
  let newSuccessors = fmap (blockMap M.!) $ opSuccessors op
  return $ op { opOperands = newOperands, opSuccessors = newSuccessors }

-- TODO(apaszke): Multi-op-patterns.
Sketch: 120 | -- 121 | -- type OperationExpr = AbstractOperation Value 122 | -- data Value = BlockArgument Builder.Value | OperationResult Builder.Value Int OperationExpr 123 | -- 124 | -- instance Builder.IsValue Value where 125 | -- getValue (BlockArgument val) = val 126 | -- getValue (Result val _ _) = val 127 | -- 128 | -- pattern Result :: Int -> OperationExpr -> Value 129 | -- pattern Result idx opExpr <- OperationResult _ idx opExpr 130 | -- 131 | -- pattern Result0 :: OperationExpr -> Value 132 | -- pattern Result0 opExpr <- OperationResult _ 0 opExpr 133 | -- 134 | -- asOperationExpr :: MonadRewrite m => Operation -> m OperationExpr 135 | -- asOperationExpr = undefined 136 | -- 137 | -- TODO(apaszke): Multi-op-patterns removals with multi-op-patterns. Sketch: 138 | -- 139 | -- erase :: MonadRewrite m => Operation -> m () 140 | -- erase = undefined 141 | -------------------------------------------------------------------------------- /src/MLIR/AST/Serialize.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2021 Google LLC 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 
module MLIR.AST.Serialize (
  ValueMapping,
  BlockMapping,
  ValueAndBlockMapping,
  FromAST(..),
  packFromAST, packArray, unpackArray) where

import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Array
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import qualified Language.C.Inline as C
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as M

import qualified MLIR.Native as Native
import qualified MLIR.Native.FFI as Native

type Name = BS.ByteString

-- | Native values already created, keyed by their AST value names.
type ValueMapping = M.Map Name Native.Value
-- | Native blocks already created, keyed by their AST block names.
type BlockMapping = M.Map Name Native.Block
type ValueAndBlockMapping = (ValueMapping, BlockMapping)

-- | AST nodes that can be converted into a native MLIR object.
--
-- The 'ValueAndBlockMapping' supplies the native objects for names the node
-- refers to. The functional dependency fixes one native type per AST type.
class FromAST ast native | ast -> native where
  fromAST :: Native.Context -> ValueAndBlockMapping -> ast -> IO native

-- | Convert every AST node with 'fromAST', then pack the results into a
-- temporary C array (see 'packArray').
packFromAST :: (FromAST ast native, Storable native)
            => Native.Context -> ValueAndBlockMapping
            -> [ast] -> ContT r IO (C.CIntPtr, Ptr native)
packFromAST ctx env asts = packArray =<< liftIO (mapM (fromAST ctx env) asts)

-- TODO(apaszke): Unify this with packing utilities from ExecutionEngine?
-- | Copy a list into a temporary C array and yield its length and pointer.
--
-- The storage comes from 'withArrayLen' and is released when the surrounding
-- continuation returns, so the pointer must not escape the 'ContT' computation.
packArray :: Storable a => [a] -> ContT r IO (C.CIntPtr, Ptr a)
packArray xs = ContT $ \k ->
  withArrayLen xs $ \len ptr -> k (fromIntegral len, ptr)

-- | Read @size@ consecutive elements starting at the pointer.
-- A non-positive size yields the empty list.
unpackArray :: Storable a => C.CIntPtr -> Ptr a -> IO [a]
unpackArray size arrPtr = peekArray (fromIntegral size) arrPtr
--------------------------------------------------------------------------------
/src/MLIR/Native.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-|
This module defines a set of Haskell types wrapping references to native C++
MLIR objects along with some basic operations on them. See the submodules for
more specialized components such as an 'MLIR.Native.ExecutionEngine.ExecutionEngine'
or 'MLIR.Native.Pass.PassManager'.
20 | -} 21 | module MLIR.Native ( 22 | -- * Contexts 23 | Context, 24 | createContext, 25 | destroyContext, 26 | withContext, 27 | HasContext(..), 28 | -- ** Dialect registration 29 | registerAllDialects, 30 | getNumLoadedDialects, 31 | -- * Type 32 | Type, 33 | -- * Location 34 | Location, 35 | getFileLineColLocation, 36 | getNameLocation, 37 | getUnknownLocation, 38 | -- * Operation 39 | Operation, 40 | getOperationName, 41 | showOperation, 42 | showOperationWithLocation, 43 | verifyOperation, 44 | -- * Region 45 | Region, 46 | getOperationRegions, 47 | getRegionBlocks, 48 | -- * Block 49 | Block, 50 | showBlock, 51 | getBlockOperations, 52 | -- * Module 53 | Module, 54 | createEmptyModule, 55 | parseModule, 56 | destroyModule, 57 | getModuleBody, 58 | moduleAsOperation, 59 | moduleFromOperation, 60 | showModule, 61 | -- * StringRef 62 | StringRef(..), 63 | withStringRef, 64 | -- * Identifier 65 | Identifier, 66 | createIdentifier, 67 | identifierString, 68 | -- * LogicalResult 69 | LogicalResult, 70 | pattern Failure, 71 | pattern Success, 72 | -- * Debugging utilities 73 | setDebugMode, 74 | HasDump(..), 75 | ) where 76 | 77 | import qualified Data.ByteString as BS 78 | 79 | import Foreign.Ptr 80 | import Foreign.Storable 81 | import Foreign.Marshal.Alloc 82 | import Foreign.Marshal.Array 83 | import qualified Language.C.Inline as C 84 | 85 | import Control.Monad 86 | import Control.Monad.IO.Class 87 | import Control.Monad.Trans.Cont 88 | import Control.Exception (bracket) 89 | 90 | import MLIR.Native.FFI 91 | 92 | C.context $ C.baseCtx <> mlirCtx 93 | 94 | C.include "mlir-c/Support.h" 95 | C.include "mlir-c/Debug.h" 96 | C.include "mlir-c/IR.h" 97 | C.include "mlir-c/Pass.h" 98 | C.include "mlir-c/Conversion.h" 99 | C.include "mlir-c/RegisterEverything.h" 100 | 101 | C.verbatim stringCallbackDecl 102 | 103 | -- TODO(apaszke): Flesh this out based on the header 104 | 105 | -------------------------------------------------------------------------------- 106 | -- 
-- Context management

-- | Allocate a fresh native MLIR context.
-- Release it with 'destroyContext', or use 'withContext'.
createContext :: IO Context
createContext =
  [C.exp| MlirContext { mlirContextCreate() } |]

-- | Release a native MLIR context.
destroyContext :: Context -> IO ()
destroyContext context =
  [C.exp| void { mlirContextDestroy($(MlirContext context)) } |]

-- | Run an action with a freshly created context.
-- The context is destroyed afterwards, even if the action throws.
withContext :: (Context -> IO a) -> IO a
withContext action = bracket createContext destroyContext action

-- TODO(apaszke): Can this be pure?
-- | Native values that can report the MLIR context owning them.
class HasContext a where
  -- | Look up the context that manages the storage of the native value.
  getContext :: a -> IO Context

--------------------------------------------------------------------------------
-- Dialect registration

-- | Register every dialect known to @mlirRegisterAllDialects@ with the
-- 'Context', then load all available dialects into it.
registerAllDialects :: Context -> IO ()
registerAllDialects context = [C.block| void {
    MlirDialectRegistry registry = mlirDialectRegistryCreate();
    mlirRegisterAllDialects(registry);
    mlirContextAppendDialectRegistry($(MlirContext context), registry);
    mlirDialectRegistryDestroy(registry);
    mlirContextLoadAllAvailableDialects($(MlirContext context));
  } |]

-- | Number of dialects currently loaded in the 'Context'.
getNumLoadedDialects :: Context -> IO Int
getNumLoadedDialects context =
  fmap fromIntegral
    [C.exp| intptr_t { mlirContextGetNumLoadedDialects($(MlirContext context)) } |]

--------------------------------------------------------------------------------
-- Locations

-- | Create an unknown source location.
getUnknownLocation :: Context -> IO Location
getUnknownLocation context =
  [C.exp| MlirLocation { mlirLocationUnknownGet($(MlirContext context)) } |]

-- | Create a file/line/column source location.
-- The file name is given as a 'StringRef'.
getFileLineColLocation :: Context -> StringRef -> C.CUInt -> C.CUInt -> IO Location
getFileLineColLocation context (StringRef fileName fileNameLen) lineNo colNo =
  [C.exp| MlirLocation {
    mlirLocationFileLineColGet(
      $(MlirContext context),
      (MlirStringRef){$(char* fileName), $(size_t fileNameLen)},
      $(unsigned int lineNo),
      $(unsigned int colNo)) } |]

-- | Create a location that attaches a name (given as a 'StringRef') to a
-- child location.
getNameLocation :: Context -> StringRef -> Location -> IO Location
getNameLocation context (StringRef nameData nameLen) child =
  [C.exp| MlirLocation {
    mlirLocationNameGet(
      $(MlirContext context),
      (MlirStringRef){$(char* nameData), $(size_t nameLen)},
      $(MlirLocation child)) } |]

-- TODO(apaszke): No destructor for locations?

--------------------------------------------------------------------------------
-- Operation

-- | The name identifier of the given operation.
getOperationName :: Operation -> IO Identifier
getOperationName operation =
  [C.exp| MlirIdentifier { mlirOperationGetName($(MlirOperation operation)) } |]

-- | Render the operation with the MLIR printer, using default printing flags.
showOperation :: Operation -> IO BS.ByteString
showOperation operation = showSomething \sink ->
  [C.block| void {
    MlirOpPrintingFlags printFlags = mlirOpPrintingFlagsCreate();
    mlirOperationPrintWithFlags($(MlirOperation operation), printFlags,
                                HaskellMlirStringCallback, $(void* sink));
    mlirOpPrintingFlagsDestroy(printFlags);
  } |]

-- TODO(jpienaar): This should probably support more general printing options.
-- | Show the operation with location using the MLIR printer.
showOperationWithLocation :: Operation -> IO BS.ByteString
showOperationWithLocation op = showSomething \ctx ->
  [C.block| void {
    MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
    mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/false);
    mlirOperationPrintWithFlags($(MlirOperation op), flags,
                                HaskellMlirStringCallback, $(void* ctx));
    mlirOpPrintingFlagsDestroy(flags);
  } |]

-- | Run the MLIR verifier on the operation.
-- Returns 'True' if it verified successfully.
verifyOperation :: Operation -> IO Bool
verifyOperation op =
  -- Treat any non-zero C bool as true instead of comparing against exactly 1.
  (/= 0) <$> [C.exp| bool { mlirOperationVerify($(MlirOperation op)) } |]

--------------------------------------------------------------------------------
-- Region

-- | Returns the first Region of an Operation, or 'Nothing' if it has none.
getOperationFirstRegion :: Operation -> IO (Maybe Region)
getOperationFirstRegion op = nullable <$> [C.exp| MlirRegion {
    mlirOperationGetFirstRegion($(MlirOperation op))
  } |]

-- | Returns the Region that follows the given one in its parent Operation,
-- or 'Nothing' if it is the last one.
getOperationNextRegion :: Region -> IO (Maybe Region)
getOperationNextRegion region = nullable <$> [C.exp| MlirRegion {
    mlirRegionGetNextInOperation($(MlirRegion region))
  } |]

-- | Returns the regions of an Operation, in order.
getOperationRegions :: Operation -> IO [Region]
getOperationRegions op = unrollIOMaybe getOperationNextRegion (getOperationFirstRegion op)

-- | Returns the first Block of a Region, or 'Nothing' if there is none.
getRegionFirstBlock :: Region -> IO (Maybe Block)
getRegionFirstBlock region = nullable <$> [C.exp| MlirBlock {
    mlirRegionGetFirstBlock($(MlirRegion region))
  } |]

-- | Returns the next Block in a Region.
getRegionNextBlock :: Block -> IO (Maybe Block)
getRegionNextBlock blk = fmap nullable
  [C.exp| MlirBlock { mlirBlockGetNextInRegion($(MlirBlock blk)) } |]

-- | Collect every block of a region, first to last.
getRegionBlocks :: Region -> IO [Block]
getRegionBlocks region =
  unrollIOMaybe getRegionNextBlock (getRegionFirstBlock region)

--------------------------------------------------------------------------------
-- Block

-- | Render the block with the MLIR printer.
showBlock :: Block -> IO BS.ByteString
showBlock blk = showSomething \sink ->
  [C.exp| void {
    mlirBlockPrint($(MlirBlock blk), HaskellMlirStringCallback, $(void* sink))
  } |]

-- | The first operation of a block, or 'Nothing' if there is none.
getFirstOperationBlock :: Block -> IO (Maybe Operation)
getFirstOperationBlock blk = fmap nullable
  [C.exp| MlirOperation { mlirBlockGetFirstOperation($(MlirBlock blk)) } |]

-- | The operation after the given one in its block.
-- 'Nothing' when the given operation is the last in the block.
getNextOperationBlock :: Operation -> IO (Maybe Operation)
getNextOperationBlock operation = fmap nullable
  [C.exp| MlirOperation { mlirOperationGetNextInBlock($(MlirOperation operation)) } |]

-- | Collect every operation of a block, first to last.
getBlockOperations :: Block -> IO [Operation]
getBlockOperations blk =
  unrollIOMaybe getNextOperationBlock (getFirstOperationBlock blk)

--------------------------------------------------------------------------------
-- Module

instance HasContext Module where
  getContext mdl =
    [C.exp| MlirContext { mlirModuleGetContext($(MlirModule mdl)) } |]

-- | Create a module with an empty body at the given location.
createEmptyModule :: Location -> IO Module
createEmptyModule loc =
  [C.exp| MlirModule { mlirModuleCreateEmpty($(MlirLocation loc)) } |]

-- | Parse a module from a string.
-- Returns 'Nothing' in case of parse failure.
parseModule :: Context -> StringRef -> IO (Maybe Module)
parseModule context (StringRef src srcLen) = fmap nullable
  [C.exp| MlirModule {
    mlirModuleCreateParse($(MlirContext context),
                          (MlirStringRef){$(char* src), $(size_t srcLen)})
  } |]

-- | Release all resources owned by a 'Module'.
destroyModule :: Module -> IO ()
destroyModule mdl =
  [C.exp| void { mlirModuleDestroy($(MlirModule mdl)) } |]

-- | Retrieve the block containing all module definitions.
getModuleBody :: Module -> IO Block
getModuleBody mdl =
  [C.exp| MlirBlock { mlirModuleGetBody($(MlirModule mdl)) } |]

-- TODO(apaszke): Can this be pure?
-- | View a module as its underlying 'Operation'.
moduleAsOperation :: Module -> IO Operation
moduleAsOperation mdl =
  [C.exp| MlirOperation { mlirModuleGetOperation($(MlirModule mdl)) } |]

-- | Inverse of 'moduleAsOperation'.
-- Returns 'Nothing' if the operation is not a builtin MLIR module operation.
moduleFromOperation :: Operation -> IO (Maybe Module)
moduleFromOperation operation = fmap nullable
  [C.exp| MlirModule { mlirModuleFromOperation($(MlirOperation operation)) } |]

-- | Render the module with the MLIR printer.
showModule :: Module -> IO BS.ByteString
showModule mdl = moduleAsOperation mdl >>= showOperation

--------------------------------------------------------------------------------
-- StringRef

-- | A raw (character pointer, length) pair, passed to MLIR as an
-- @MlirStringRef@.
data StringRef = StringRef (Ptr C.CChar) C.CSize

-- MLIR sometimes expects null-terminated StringRefs, so we can't use
-- unsafeUseAsCStringLen, because ByteStrings are not guaranteed to have a terminator
-- | Use a 'BS.ByteString' as a 'StringRef'. This is O(n) due to MLIR sometimes
-- requiring the 'StringRef's to be null-terminated.
withStringRef :: BS.ByteString -> (StringRef -> IO a) -> IO a
withStringRef s f = BS.useAsCString s (\cstr -> f (StringRef cstr byteCount))
  where
    -- Length of the payload, excluding the terminator added by useAsCString.
    byteCount = fromIntegral (BS.length s)

-- | Copy the bytes referenced by a 'StringRef' into a fresh 'BS.ByteString'.
-- This is an O(n) operation.
peekStringRef :: StringRef -> IO BS.ByteString
peekStringRef (StringRef ref size) =
  let byteCount = fromIntegral size
  in BS.packCStringLen (ref, byteCount)

--------------------------------------------------------------------------------
-- Identifier

-- | View an identifier as a 'StringRef'.
-- The result is valid for as long as the 'Context' managing the identifier.
identifierString :: Identifier -> IO StringRef
identifierString ident = evalContT do
  dataOut   <- ContT alloca
  lengthOut <- ContT alloca
  liftIO do
    -- Copy both fields of the returned MlirStringRef out through the
    -- Haskell-allocated slots.
    [C.block| void {
      MlirStringRef identStr = mlirIdentifierStr($(MlirIdentifier ident));
      *$(const char** dataOut) = identStr.data;
      *$(size_t* lengthOut) = identStr.length;
    } |]
    peek dataOut >>= \ptr -> StringRef ptr <$> peek lengthOut

-- | Create an identifier from a 'StringRef'.
createIdentifier :: Context -> StringRef -> IO Identifier
createIdentifier context (StringRef chars charCount) =
  [C.exp| MlirIdentifier {
    mlirIdentifierGet($(MlirContext context), (MlirStringRef){$(char* chars), $(size_t charCount)})
  } |]

--------------------------------------------------------------------------------
-- Utilities

-- | Run a native printer and collect what it writes as a 'BS.ByteString'.
--
-- The printer receives an opaque pointer to a two-slot buffer. Slot 0 holds
-- the output data pointer (initially null). Slot 1 points at a byte counter
-- (initially 0). Afterwards, the data pointer and counter are read back, the
-- bytes copied out, and the data pointer released with 'free'.
-- NOTE(review): assumes the C-side string callback fills both slots with a
-- malloc-compatible buffer — confirm against stringCallbackDecl.
showSomething :: (Ptr () -> IO ()) -> IO BS.ByteString
showSomething action =
  allocaArray @(Ptr ()) 2 \slots ->
    alloca @C.CSize \sizeSlot -> do
      poke sizeSlot 0
      pokeElemOff slots 0 nullPtr
      pokeElemOff slots 1 (castPtr sizeSlot)
      action (castPtr slots)
      buffer <- castPtr <$> peek slots
      byteCount <- peek sizeSlot
      result <- peekStringRef (StringRef buffer byteCount)
      free buffer
      pure result

-- | Repeatedly run a "get next" step, starting from the first element.
-- Stops at the first 'Nothing' and collects the elements in order.
unrollIOMaybe :: (a -> IO (Maybe a)) -> IO (Maybe a) -> IO [a]
unrollIOMaybe next start =
  start >>= maybe (return []) (\x -> (x :) <$> unrollIOMaybe next (next x))

--------------------------------------------------------------------------------
-- Debugging utilities

-- | Turn MLIR's global debug logging on or off.
setDebugMode :: Bool -> IO ()
setDebugMode enable =
  let flag = if enable then 1 else 0
  in [C.exp| void { mlirEnableGlobalDebug($(bool flag)) } |]


-- | Native objects that can be dumped to the standard error output.
class HasDump a where
  -- | Display the value in the standard error output.
  dump :: a -> IO ()

instance HasDump Operation where
  dump operation = [C.exp| void { mlirOperationDump($(MlirOperation operation)) } |]

instance HasDump Module where
  dump mdl = moduleAsOperation mdl >>= dump
--------------------------------------------------------------------------------
/src/MLIR/Native/ExecutionEngine.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module MLIR.Native.ExecutionEngine where

import Foreign.Ptr
import Foreign.Storable
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Data.Int
import qualified Language.C.Inline as C

import Control.Exception (bracket)
import Control.Monad

import MLIR.Native
import MLIR.Native.FFI

C.context $ C.baseCtx <> mlirCtx

C.include "mlir-c/ExecutionEngine.h"

-- TODO(apaszke): Flesh this out based on the header

--------------------------------------------------------------------------------
-- Execution engine

-- | Create an execution engine for a module at optimization level 3.
-- Returns 'Nothing' when MLIR hands back a null engine.
-- Use 'createExecutionEngineWithOptLevel' to pick a different level.
--
-- TODO(apaszke): Allow loading shared libraries
createExecutionEngine :: Module -> IO (Maybe ExecutionEngine)
createExecutionEngine = createExecutionEngineWithOptLevel 3

-- | Create an execution engine for a module, passing @optLevel@ to
-- @mlirExecutionEngineCreate@ as its optimization level. No shared libraries
-- are loaded and object dumping is disabled. Returns 'Nothing' when MLIR
-- hands back a null engine.
--
-- NOTE(review): the level is not validated here; presumably it should be in
-- the 0..3 range used by LLVM's -O levels — confirm against the MLIR C API.
createExecutionEngineWithOptLevel :: Int -> Module -> IO (Maybe ExecutionEngine)
createExecutionEngineWithOptLevel optLevel m = nullable <$>
  [C.exp| MlirExecutionEngine {
    mlirExecutionEngineCreate($(MlirModule m), $(int cOptLevel), 0, NULL, false)
  } |]
  where
    cOptLevel = fromIntegral optLevel :: C.CInt

-- | Release an execution engine created by 'createExecutionEngine'.
destroyExecutionEngine :: ExecutionEngine -> IO ()
destroyExecutionEngine eng =
  [C.exp| void { mlirExecutionEngineDestroy($(MlirExecutionEngine eng)) } |]

-- | Create an execution engine, run the action with it, and destroy it
-- afterwards (even if the action throws). The action receives 'Nothing' if
-- creation failed, in which case there is nothing to destroy.
withExecutionEngine :: Module -> (Maybe ExecutionEngine -> IO a) -> IO a
withExecutionEngine m = bracket (createExecutionEngine m)
                                (\case Just e -> destroyExecutionEngine e
                                       Nothing -> return ())


-- | Any 'Storable' value, used to pass heterogeneous arguments to native code.
data SomeStorable = forall a. Storable a => SomeStorable a

-- | Invoke the function with the given name through MLIR's packed calling
-- convention. Each argument is copied into its own temporary cell. Pointers
-- to those cells fill an array whose final slot points at storage for the
-- result. Returns the result read back from that storage on 'Success', or
-- 'Nothing' on 'Failure'.
executionEngineInvoke :: forall result. Storable result
                      => ExecutionEngine -> StringRef -> [SomeStorable] -> IO (Maybe result)
executionEngineInvoke eng (StringRef namePtr nameLen) args =
  withPackedPtr \packPtr resultPtr -> do
    result <- [C.exp| MlirLogicalResult {
      mlirExecutionEngineInvokePacked($(MlirExecutionEngine eng),
                                      (MlirStringRef){$(char* namePtr), $(size_t nameLen)},
                                      $(void** packPtr))
    } |]
    case result of
      Success -> Just <$> peek resultPtr
      Failure -> return Nothing
  where
    numArgs = length args

    -- TODO(apaszke): Are tuples exploded, or stored as pointers?
    -- Allocate the pointer array (one slot per argument plus one for the
    -- result) and the result cell, then store the arguments.
    withPackedPtr :: (Ptr (Ptr ()) -> Ptr result -> IO a) -> IO a
    withPackedPtr f =
      allocaArray (numArgs + 1) \packedPtr ->
        alloca @result \resultPtr -> do
          pokeElemOff packedPtr numArgs (castPtr resultPtr)
          withStoredArgs args packedPtr $ f packedPtr resultPtr

    -- Copy each argument into a temporary cell and record its address in
    -- consecutive slots of the pointer array. The cells live only as long as
    -- the nested allocas, i.e. for the duration of the continuation.
    withStoredArgs :: [SomeStorable] -> Ptr (Ptr ()) -> IO a -> IO a
    withStoredArgs [] _ m = m
    withStoredArgs (SomeStorable h:t) nextArgPtr m =
      alloca \argPtr -> do
        poke argPtr h
        poke nextArgPtr (castPtr argPtr)
        withStoredArgs t (advancePtr nextArgPtr 1) m

-- | Write the fields into a temporary buffer of 8-byte slots (field @i@ lands
-- at byte offset @8 * i@) and pass a pointer to it to the continuation.
-- Every field must be exactly 8 bytes in size with an alignment of at most
-- 8 bytes; otherwise this calls 'error'. The buffer is only valid inside the
-- continuation.
packStruct64 :: [SomeStorable] -> (Ptr () -> IO a) -> IO a
packStruct64 fields f =
  allocaArray (length fields) \(structPtr :: Ptr Int64) -> do
    forM_ (zip [0..] fields) \(i, SomeStorable field) -> do
      unless (sizeOf field == 8) $
        error "packStruct64 expects all fields to be exactly 8 bytes in size"
      unless (alignment field <= 8) $
        error "packStruct64 expects all fields to have an alignment of at most 8 bytes"
      pokeElemOff (castPtr structPtr) i field
    f $ castPtr structPtr

--------------------------------------------------------------------------------
/src/MLIR/Native/FFI.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_HADDOCK hide #-}
module MLIR.Native.FFI where

import Foreign.Ptr
import Foreign.Storable
import qualified Language.C.Inline as C
import qualified Language.C.Types as C
import qualified Language.C.Inline.Context as C.Context

import Text.RawString.QQ

import Data.Int
import Data.Coerce
import qualified Data.Map as Map

-- realloc/abort (stdlib.h) and memcpy (string.h) are used by the string
-- callback defined below.
C.include "<stdlib.h>"
C.include "<string.h>"
C.include "mlir-c/Support.h"

-- TODO(apaszke): Better buffering?
C.verbatim [r|
void HaskellMlirStringCallback(MlirStringRef ref, void* ctxRaw) {
  // ctxRaw points at two slots (as laid out by MLIR.Native.showSomething):
  // slot 0 holds the accumulated heap buffer (char*), slot 1 points at the
  // number of bytes accumulated so far (size_t*).
  void** ctx = ctxRaw;
  char** data_ptr = ctxRaw;
  size_t* size_ptr = ctx[1];
  // An empty chunk adds nothing; skip it rather than issuing a no-op realloc.
  if (ref.length == 0) return;
  size_t old_size = *size_ptr;
  size_t new_size = old_size + ref.length;
  char* grown = realloc(*data_ptr, new_size);
  if (grown == NULL) {
    // Out of memory. Storing NULL into slot 0 would leak the old buffer and
    // the memcpy below would write through a null pointer, so fail loudly.
    abort();
  }
  *data_ptr = grown;
  *size_ptr = new_size;
  memcpy(grown + old_size, ref.data, ref.length);
}
|]

-- | The C prototype of @HaskellMlirStringCallback@, as source text.
stringCallbackDecl :: String
stringCallbackDecl = [r|
void HaskellMlirStringCallback(MlirStringRef ref, void* ctxRaw);
|]

-- Empty types used only as phantom pointees of the native handles below.
data MlirContextObject
data MlirLocationObject
data MlirModuleObject
data MlirOperationObject
data MlirPassManagerObject
data MlirPassObject
data MlirExecutionEngineObject
data MlirTypeObject
data MlirBlockObject
data MlirRegionObject
data MlirAttributeObject
data MlirValueObject
data MlirIdentifierObject
data MlirAffineExprObject
data MlirAffineMapObject

-- | A native MLIR context.
newtype Context = ContextPtr (Ptr MlirContextObject)
  deriving Storable via (Ptr ())
-- | A native MLIR pass instance.
newtype Pass = PassPtr (Ptr MlirPassObject)
  deriving Storable via (Ptr ())
-- | A native MLIR pass manager instance.
newtype PassManager = PassManagerPtr (Ptr MlirPassManagerObject)
  deriving Storable via (Ptr ())
-- | A native MLIR location object.
newtype Location = LocationPtr (Ptr MlirLocationObject)
  deriving Storable via (Ptr ())
-- | A native MLIR operation instance.
newtype Operation = OperationPtr (Ptr MlirOperationObject)
  deriving Storable via (Ptr ())
-- | A native MLIR module operation.
-- Since every module is an operation, it can be converted to
-- an 'Operation' using 'MLIR.Native.moduleAsOperation'.
newtype Module = ModulePtr (Ptr MlirModuleObject)
  deriving Storable via (Ptr ())
-- | A native MLIR execution engine.
newtype ExecutionEngine = ExecutionEnginePtr (Ptr MlirExecutionEngineObject)
  deriving Storable via (Ptr ())
-- | A native MLIR type object.
newtype Type = TypePtr (Ptr MlirTypeObject)
  deriving Storable via (Ptr ())
-- | A native MLIR block object.
-- Every block is a list of 'Operation's.
newtype Block = BlockPtr (Ptr MlirBlockObject)
  deriving Storable via (Ptr ())
-- | A native MLIR region.
newtype Region = RegionPtr (Ptr MlirRegionObject)
  deriving Storable via (Ptr ())
-- | A native MLIR attribute.
newtype Attribute = AttributePtr (Ptr MlirAttributeObject)
  deriving Storable via (Ptr ())
-- | A native MLIR value object.
-- Every 'Value' is either a 'Block' argument or an output from an 'Operation'.
newtype Value = ValuePtr (Ptr MlirValueObject)
  deriving Storable via (Ptr ())
-- | A native MLIR identifier.
-- Identifiers are strings interned in the MLIR context.
newtype Identifier = IdentifierPtr (Ptr MlirIdentifierObject)
  deriving Storable via (Ptr ())
-- | A native MLIR affine expression object.
newtype AffineExpr = AffineExprPtr (Ptr MlirAffineExprObject)
  deriving Storable via (Ptr ())
-- | A native MLIR affine map object.
newtype AffineMap = AffineMapPtr (Ptr MlirAffineMapObject)
  deriving Storable via (Ptr ())
data NamedAttribute -- C structs cannot be represented in Haskell

-- | A result code for many failable MLIR operations.
-- The only valid cases are 'Success' and 'Failure'.
newtype LogicalResult = UnsafeMkLogicalResult Int8
  deriving Storable via Int8
  deriving Eq

instance Show LogicalResult where
  show Success = "Success"
  show Failure = "Failure"

-- NOTE(review): these patterns match exactly 1 and 0. The MLIR C API is
-- believed to treat any non-zero value as success; MLIR itself only produces
-- 1 and 0, but confirm before relying on other values never appearing.

-- | Indicates a successful completion of an MLIR operation.
pattern Success :: LogicalResult
pattern Success = UnsafeMkLogicalResult 1
-- | Indicates a failure of an MLIR operation. Inspect the diagnostics output
-- to find the cause of the issue.
pattern Failure :: LogicalResult
pattern Failure = UnsafeMkLogicalResult 0

{-# COMPLETE Success, Failure #-}

-- | The inline-C context that maps the MLIR C API handle type names to the
-- Haskell newtypes declared in this module.
mlirCtx :: C.Context
mlirCtx = mempty {
  -- This is a lie...
  -- All of those types are really C structs that hold a single pointer, but
  -- dealing with structs is just way too complicated. For simplicity, we
  -- assume that the layout of the struct is equal to the layout of a single
  -- pointer here, but I'm not 100% sure if that's a good assumption.
  C.Context.ctxTypesTable = Map.fromList [
    (C.TypeName "MlirContext", [t|Context|])
  , (C.TypeName "MlirLocation", [t|Location|])
  , (C.TypeName "MlirModule", [t|Module|])
  , (C.TypeName "MlirOperation", [t|Operation|])
  , (C.TypeName "MlirPassManager", [t|PassManager|])
  , (C.TypeName "MlirPass", [t|Pass|])
  , (C.TypeName "MlirExecutionEngine", [t|ExecutionEngine|])
  , (C.TypeName "MlirLogicalResult", [t|LogicalResult|])
  , (C.TypeName "MlirType", [t|Type|])
  , (C.TypeName "MlirBlock", [t|Block|])
  , (C.TypeName "MlirRegion", [t|Region|])
  , (C.TypeName "MlirAttribute", [t|Attribute|])
  , (C.TypeName "MlirNamedAttribute", [t|NamedAttribute|])
  , (C.TypeName "MlirValue", [t|Value|])
  , (C.TypeName "MlirIdentifier", [t|Identifier|])
  , (C.TypeName "MlirAffineExpr", [t|AffineExpr|])
  , (C.TypeName "MlirAffineMap", [t|AffineMap|])
  ]
}

-- | Map a null native handle to 'Nothing' and any other handle to 'Just' it.
nullable :: Coercible a (Ptr ()) => a -> Maybe a
nullable x = if coerce x == nullPtr then Nothing else Just x

--------------------------------------------------------------------------------
/src/MLIR/Native/Pass.hs:
--------------------------------------------------------------------------------
-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
module MLIR.Native.Pass where

import qualified Language.C.Inline as C

import Control.Exception (bracket)

import MLIR.Native.FFI

C.context $ C.baseCtx <> mlirCtx

C.include "mlir-c/IR.h"
C.include "mlir-c/Pass.h"
C.include "mlir-c/Conversion.h"

-- TODO(apaszke): Flesh this out based on the header

--------------------------------------------------------------------------------
-- Pass manager

-- | Allocate a pass manager for the given context. Release it with
-- 'destroyPassManager', or use 'withPassManager' to scope its lifetime.
createPassManager :: Context -> IO PassManager
createPassManager context =
  [C.exp| MlirPassManager { mlirPassManagerCreate($(MlirContext context)) } |]

-- | Release a pass manager obtained from 'createPassManager'.
destroyPassManager :: PassManager -> IO ()
destroyPassManager manager =
  [C.exp| void { mlirPassManagerDestroy($(MlirPassManager manager)) } |]

-- | Run an action with a fresh pass manager that is destroyed afterwards,
-- even if the action throws.
withPassManager :: Context -> (PassManager -> IO a) -> IO a
withPassManager context action =
  bracket (createPassManager context) destroyPassManager action

-- | Run every pass registered on the manager over the given operation.
runPasses :: PassManager -> Operation -> IO LogicalResult
runPasses manager target =
  [C.exp| MlirLogicalResult { mlirPassManagerRunOnOp($(MlirPassManager manager), $(MlirOperation target)) } |]

--------------------------------------------------------------------------------
-- Transform passes

--------------------------------------------------------------------------------
-- Conversion passes
-- Each helper below creates one conversion pass and hands ownership of it to
-- the pass manager via mlirPassManagerAddOwnedPass.

addConvertMemRefToLLVMPass :: PassManager -> IO ()
addConvertMemRefToLLVMPass manager =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager manager), mlirCreateConversionFinalizeMemRefToLLVMConversionPass())
  } |]

addConvertArithToLLVMPass :: PassManager -> IO ()
addConvertArithToLLVMPass manager =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager manager), mlirCreateConversionArithToLLVMConversionPass())
  } |]

addConvertControlFlowToLLVMPass :: PassManager -> IO ()
addConvertControlFlowToLLVMPass manager =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager manager), mlirCreateConversionConvertControlFlowToLLVMPass())
  } |]

addConvertFuncToLLVMPass :: PassManager -> IO ()
addConvertFuncToLLVMPass manager =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager manager), mlirCreateConversionConvertFuncToLLVMPass())
  } |]

addConvertVectorToLLVMPass :: PassManager -> IO ()
addConvertVectorToLLVMPass manager =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager manager), mlirCreateConversionConvertVectorToLLVMPass())
  } |]

addConvertReconcileUnrealizedCastsPass :: PassManager -> IO ()
addConvertReconcileUnrealizedCastsPass manager =
  [C.exp| void {
    mlirPassManagerAddOwnedPass($(MlirPassManager manager), mlirCreateConversionReconcileUnrealizedCasts())
  } |]

--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
resolver: lts-17.15

packages:
- .

--------------------------------------------------------------------------------
/tblgen/hs-generators.cc:
--------------------------------------------------------------------------------
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "llvm/ADT/STLExtras.h" 20 | #include "llvm/ADT/Sequence.h" 21 | #include "llvm/ADT/StringExtras.h" 22 | #include "llvm/ADT/StringMap.h" 23 | #include "llvm/ADT/StringRef.h" 24 | #include "llvm/ADT/StringSet.h" 25 | #include "llvm/ADT/iterator_range.h" 26 | #include "llvm/Support/CommandLine.h" 27 | #include "llvm/Support/FormatAdapters.h" 28 | #include "llvm/Support/FormatCommon.h" 29 | #include "llvm/Support/FormatVariadic.h" 30 | #include "llvm/Support/Path.h" 31 | #include "llvm/Support/Signals.h" 32 | #include "llvm/Support/raw_ostream.h" 33 | #include "llvm/TableGen/Error.h" 34 | #include "llvm/TableGen/Record.h" 35 | #include "llvm/TableGen/TableGenBackend.h" 36 | #include "mlir/TableGen/Argument.h" 37 | #include "mlir/TableGen/Class.h" 38 | #include "mlir/TableGen/CodeGenHelpers.h" 39 | #include "mlir/TableGen/Format.h" 40 | #include "mlir/TableGen/Interfaces.h" 41 | #include "mlir/TableGen/Operator.h" 42 | #include "mlir/TableGen/Region.h" 43 | #include "mlir/TableGen/SideEffects.h" 44 | #include "mlir/TableGen/Trait.h" 45 | 46 | namespace { 47 | 48 | llvm::cl::opt ExplainMissing( 49 | "explain-missing", 50 | llvm::cl::desc("Print the reason for skipping operations from output")); 51 | 52 | llvm::cl::opt StripOpPrefix( 53 | "strip-prefix", llvm::cl::desc("Prefix to strip from def names"), 54 | llvm::cl::value_desc("prefix")); 55 | 56 | llvm::cl::opt DialectName( 57 | "dialect-name", llvm::cl::desc("Override the inferred dialect name"), 58 | llvm::cl::value_desc("dialect")); 59 | 60 | template 61 | llvm::iterator_range make_range(const C& x) { 62 | return llvm::make_range(x.begin(), x.end()); 63 | } 64 | 65 | template ()( 67 | std::declval()))> 68 | std::vector map_vector(const C& container, FunTy f) { 69 | std::vector results; 70 | for (const auto& v : container) { 71 | results.push_back(f(v)); 72 | } 73 | return results; 74 | } 75 | 76 | void warn(llvm::StringRef op_name, const std::string& 
reason) { 77 | if (!ExplainMissing) return; 78 | llvm::errs() << llvm::formatv( 79 | "{0} {1}\n", llvm::fmt_align(op_name, llvm::AlignStyle::Left, 40), 80 | reason); 81 | } 82 | 83 | void warn(const mlir::tblgen::Operator& op, const std::string& reason) { 84 | warn(op.getOperationName(), reason); 85 | } 86 | 87 | struct AttrPatternTemplate { 88 | const char* _pattern; 89 | const char* _type; 90 | std::vector provided_constraints; 91 | std::vector type_var_defaults; 92 | }; 93 | 94 | using attr_print_state = llvm::StringSet<>; 95 | class AttrPattern { 96 | public: 97 | virtual ~AttrPattern() = default; 98 | virtual std::string type() const = 0; 99 | virtual std::string match(std::string name) const = 0; 100 | virtual const std::vector& provided_constraints() const = 0; 101 | virtual void print(llvm::raw_ostream& os, 102 | attr_print_state& optional_attr_defs) const = 0; 103 | }; 104 | 105 | struct NameSource { 106 | NameSource(const char* prefix) : prefix(prefix) {} 107 | NameSource(const NameSource&) = delete; 108 | std::string fresh() { return std::string(prefix) + std::to_string(suffix++); } 109 | private: 110 | const char* prefix; 111 | int suffix = 0; 112 | }; 113 | 114 | class SimpleAttrPattern : public AttrPattern { 115 | public: 116 | SimpleAttrPattern(const AttrPatternTemplate& tmpl, NameSource& gen) 117 | : _type_var_defaults(tmpl.type_var_defaults) { 118 | _pattern = tmpl._pattern; 119 | if (tmpl.type_var_defaults.empty()) { 120 | _type = tmpl._type; 121 | _provided_constraints = 122 | map_vector(tmpl.provided_constraints, 123 | [](const char* c) { return std::string(c); }); 124 | } else if (tmpl.type_var_defaults.size() == 1) { 125 | std::string var = gen.fresh(); 126 | _type_vars.push_back(var); 127 | _type = llvm::formatv(tmpl._type, var); 128 | _provided_constraints = map_vector( 129 | tmpl.provided_constraints, 130 | [&var](const char* c) { return llvm::formatv(c, var).str(); }); 131 | } else { 132 | std::abort(); // Not sure how to splat arbitrary 
many vars into formatv. 133 | } 134 | } 135 | 136 | std::string match(std::string name) const override { return llvm::formatv(_pattern, name); } 137 | std::string type() const override { return _type; } 138 | const std::vector& provided_constraints() const override { return _provided_constraints; } 139 | const std::vector& type_vars() const { return _type_vars; } 140 | const std::vector& type_var_defaults() const { return _type_var_defaults; } 141 | 142 | void print(llvm::raw_ostream& os, 143 | attr_print_state& optional_attr_defs) const override {} 144 | private: 145 | const char* _pattern; 146 | std::string _type; 147 | std::vector _provided_constraints; 148 | std::vector _type_vars; 149 | const std::vector _type_var_defaults; 150 | }; 151 | 152 | class OptionalAttrPattern : public AttrPattern { 153 | public: 154 | OptionalAttrPattern(llvm::StringRef attr_kind, SimpleAttrPattern base) 155 | : base(std::move(base)), attr_kind(attr_kind) {} 156 | 157 | std::string type() const override { 158 | return "Maybe " + base.type(); 159 | } 160 | std::string match(std::string name) const override { 161 | return llvm::formatv("Optional{0} {1}", attr_kind, name); 162 | } 163 | const std::vector& provided_constraints() const override { return base.provided_constraints(); } 164 | 165 | void print(llvm::raw_ostream& os, 166 | attr_print_state& optional_attr_defs) const override { 167 | if (!optional_attr_defs.contains(attr_kind)) { 168 | if (base.provided_constraints().empty()) { 169 | const char* kOptionalHandler = R"( 170 | pattern Optional{0} :: Maybe {1} -> Maybe Attribute 171 | pattern Optional{0} x <- ((\case Just ({2}) -> Just y; Nothing -> Nothing) -> x) 172 | where Optional{0} x = case x of Just y -> Just ({2}); Nothing -> Nothing 173 | )"; 174 | os << llvm::formatv(kOptionalHandler, attr_kind, base.type(), 175 | base.match("y")); 176 | } else { 177 | const char *kOptionalHandlerConstr = R"( 178 | data Maybe{0}Adapter = forall {4:$[ ]}. 
({3:$[, ]}) => AdaptMaybe{0} (Maybe ({1})) 179 | 180 | unwrapMaybe{0} :: Maybe Attribute -> Maybe{0}Adapter 181 | unwrapMaybe{0} = \case 182 | Just ({2}) -> AdaptMaybe{0} (Just y) 183 | _ -> AdaptMaybe{0} {5:[]}Nothing 184 | 185 | pattern Optional{0} :: () => ({3:$[, ]}) => Maybe {1} -> Maybe Attribute 186 | pattern Optional{0} x <- (unwrapMaybe{0} -> AdaptMaybe{0} x) 187 | where Optional{0} x = case x of Just y -> Just ({2}); Nothing -> Nothing 188 | )"; 189 | std::vector default_apps; 190 | for (const char* d : base.type_var_defaults()) { 191 | default_apps.push_back("@" + std::string(d) + " "); 192 | } 193 | os << llvm::formatv(kOptionalHandlerConstr, 194 | attr_kind, // 0 195 | base.type(), // 1 196 | base.match("y"), // 2 197 | make_range(base.provided_constraints()), // 3 198 | make_range(base.type_vars()), // 4 199 | make_range(default_apps)); // 5 200 | } 201 | optional_attr_defs.insert(attr_kind); 202 | } 203 | } 204 | 205 | private: 206 | SimpleAttrPattern base; 207 | llvm::StringRef attr_kind; 208 | }; 209 | 210 | using attr_pattern_map = llvm::StringMap; 211 | 212 | const attr_pattern_map& getAttrPatternTemplates() { 213 | static const attr_pattern_map* kAttrHandlers = new attr_pattern_map{ 214 | {"AnyAttr", {"{0}", "Attribute", {}, {}}}, 215 | {"AffineMapArrayAttr", {"PatternUtil.AffineMapArrayAttr {0}", "[Affine.Map]", {}, {}}}, 216 | {"AffineMapAttr", {"AffineMapAttr {0}", "Affine.Map", {}, {}}}, 217 | {"ArrayAttr", {"ArrayAttr {0}", "[Attribute]", {}, {}}}, 218 | {"BoolAttr", {"BoolAttr {0}", "Bool", {}, {}}}, 219 | {"DenseI32ArrayAttr", {"PatternUtil.I32ArrayAttr {0}", "[Int]", {}, {}}}, 220 | {"DictionaryAttr", {"DictionaryAttr {0}", "(M.Map Name Attribute)", {}, {}}}, 221 | {"F32Attr", {"FloatAttr Float32Type {0}", "Double", {}, {}}}, 222 | {"F64Attr", {"FloatAttr Float64Type {0}", "Double", {}, {}}}, 223 | {"I32Attr", {"IntegerAttr (IntegerType Signless 32) {0}", "Int", {}, {}}}, 224 | {"I64Attr", {"IntegerAttr (IntegerType Signless 64) {0}", 
"Int", {}, {}}}, 225 | {"I64ArrayAttr", {"PatternUtil.I64ArrayAttr {0}", "[Int]", {}, {}}}, 226 | {"I64ElementsAttr", {"DenseElementsAttr (IntegerType Signless 64) (DenseInt64 {0})", 227 | "(AST.IStorableArray {0} Int64)", {"Ix {0}", "Show {0}"}, {"PatternUtil.DummyIx"}}}, 228 | {"IndexAttr", {"IntegerAttr IndexType {0}", "Int", {}, {}}}, 229 | {"StrAttr", {"StringAttr {0}", "BS.ByteString", {}, {}}}, 230 | // TODO(jpienaar): We could specialize this one more to query Type. 231 | {"TypedAttrInterface", {"{0}", "Attribute", {}, {}}}, 232 | }; 233 | return *kAttrHandlers; 234 | } 235 | 236 | // Returns nullptr when the attribute pattern couldn't be constructed. 237 | std::unique_ptr tryGetAttrPattern( 238 | const mlir::tblgen::NamedAttribute& nattr, NameSource& gen) { 239 | llvm::StringRef attr_kind = nattr.attr.getAttrDefName(); 240 | if (getAttrPatternTemplates().count(attr_kind) != 1) return nullptr; 241 | const AttrPatternTemplate& tmpl = getAttrPatternTemplates().lookup(attr_kind); 242 | if (!nattr.attr.isOptional()) { 243 | return std::make_unique(tmpl, gen); 244 | } else { 245 | auto pat = std::make_unique( 246 | attr_kind, SimpleAttrPattern(tmpl, gen)); 247 | return pat; 248 | } 249 | } 250 | 251 | const std::string sanitizeName(llvm::StringRef name, std::optional idx = std::nullopt) { 252 | static const llvm::StringSet<>* kReservedNames = new llvm::StringSet<>{ 253 | // TODO(apaszke): Add more keywords 254 | // Haskell keywords 255 | "in", "data", "if" 256 | }; 257 | if (name.empty()) { 258 | assert(idx); 259 | return llvm::formatv("_unnamed{0}", *idx); 260 | } else if (kReservedNames->contains(name)) { 261 | auto new_name = name.str(); 262 | new_name.push_back('_'); 263 | return new_name; 264 | } else { 265 | return name.str(); 266 | } 267 | } 268 | 269 | std::string getDialectName(llvm::ArrayRef op_defs) { 270 | mlir::tblgen::Operator any_op(op_defs.front()); 271 | assert(std::all_of( 272 | op_defs.begin(), op_defs.end(), [&any_op](const llvm::Record* op) 
{ 273 | return mlir::tblgen::Operator(op).getDialectName() == 274 | any_op.getDialectName(); 275 | })); 276 | std::string dialect_name; 277 | if (DialectName.empty()) { 278 | dialect_name = any_op.getDialectName().str(); 279 | dialect_name[0] = llvm::toUpper(dialect_name[0]); 280 | } else { 281 | dialect_name = DialectName; 282 | } 283 | return dialect_name; 284 | } 285 | 286 | class OpAttrPattern { 287 | OpAttrPattern(std::string name, std::vector binders, 288 | std::vector attrs, 289 | std::vector> patterns) 290 | : name(std::move(name)), 291 | binders(std::move(binders)), 292 | attrs(std::move(attrs)), 293 | patterns(std::move(patterns)) {} 294 | 295 | public: 296 | static std::optional buildFor(mlir::tblgen::Operator& op) { 297 | if (op.getNumAttributes() == 0) return OpAttrPattern("NoAttrs", {}, {}, {}); 298 | 299 | NameSource gen("a"); 300 | std::vector binders; 301 | std::vector attrs; 302 | std::vector> patterns; 303 | for (const auto& named_attr : op.getAttributes()) { 304 | // Derived attributes are never materialized and don't have to be 305 | // specified. 
306 | if (named_attr.attr.isDerivedAttr()) continue; 307 | 308 | auto pattern = tryGetAttrPattern(named_attr, gen); 309 | if (!pattern) { 310 | if (named_attr.attr.hasDefaultValue()) { 311 | warn(op, llvm::formatv("unsupported attr {0} (but has default value)", 312 | named_attr.attr.getAttrDefName())); 313 | continue; 314 | } 315 | if (named_attr.attr.isOptional()) { 316 | warn(op, llvm::formatv("unsupported attr {0} (but is optional)", 317 | named_attr.attr.getAttrDefName())); 318 | continue; 319 | } 320 | warn(op, llvm::formatv("unsupported attr ({0})", 321 | named_attr.attr.getAttrDefName())); 322 | return std::nullopt; 323 | } 324 | binders.push_back(sanitizeName(named_attr.name) + "_"); 325 | attrs.push_back(named_attr); 326 | patterns.push_back(std::move(pattern)); 327 | } 328 | if (binders.empty()) return OpAttrPattern("NoAttrs", {}, {}, {}); 329 | std::string name = "Internal" + op.getCppClassName().str() + "Attributes"; 330 | return OpAttrPattern(std::move(name), std::move(binders), std::move(attrs), 331 | std::move(patterns)); 332 | } 333 | 334 | void print(llvm::raw_ostream& os, attr_print_state& optional_attr_defs) { 335 | if (name == "NoAttrs") return; 336 | // `M.lookup "attr_name" m` for every attribute 337 | std::vector lookups; 338 | // Patterns from handlers, but wrapped in "Just (...)" when non-optional 339 | std::vector lookup_patterns; 340 | // `[("attr_name", attr_pattern)]` for every non-optional attribute 341 | std::vector singleton_pairs; 342 | for (size_t i = 0; i < attrs.size(); ++i) { 343 | const mlir::tblgen::NamedAttribute& nattr = attrs[i]; 344 | const AttrPattern& pattern = *patterns[i]; 345 | pattern.print(os, optional_attr_defs); 346 | lookups.push_back(llvm::formatv("M.lookup \"{0}\" m", nattr.name)); 347 | std::string inst_pattern = pattern.match(binders[i]); 348 | if (nattr.attr.isOptional()) { 349 | lookup_patterns.push_back(inst_pattern); 350 | singleton_pairs.push_back(llvm::formatv( 351 | "(Data.Maybe.maybeToList $ 
(\"{0}\",) <$> {1})", nattr.name, inst_pattern)); 352 | } else { 353 | lookup_patterns.push_back(llvm::formatv("Just ({0})", inst_pattern)); 354 | singleton_pairs.push_back( 355 | llvm::formatv("[(\"{0}\", {1})]", nattr.name, inst_pattern)); 356 | } 357 | } 358 | const char* kAttributePattern = R"( 359 | pattern {0} :: () => ({6:$[, ]}) => {1:$[ -> ]} -> NamedAttributes 360 | pattern {0} {2:$[ ]} <- ((\m -> ({3:$[, ]})) -> ({4:$[, ]})) 361 | where {0} {2:$[ ]} = M.fromList $ {5:$[ ++ ]} 362 | )"; 363 | os << llvm::formatv(kAttributePattern, 364 | name, // 0 365 | make_range(types()), // 1 366 | make_range(binders), // 2 367 | make_range(lookups), // 3 368 | make_range(lookup_patterns), // 4 369 | make_range(singleton_pairs), // 5 370 | make_range(provided_constraints())); // 6 371 | } 372 | 373 | std::vector types() const { 374 | return map_vector(patterns, [](const std::unique_ptr& p) { 375 | return p->type(); 376 | }); 377 | } 378 | std::vector provided_constraints() const { 379 | std::vector result; 380 | for (auto& p : patterns) { 381 | for (auto& c : p->provided_constraints()) { 382 | result.push_back(c); 383 | } 384 | } 385 | return result; 386 | } 387 | 388 | std::string name; 389 | std::vector binders; 390 | 391 | private: 392 | std::vector attrs; 393 | std::vector> patterns; 394 | }; 395 | // Renders the Haskell 'Operation' record for `def`: the right-hand side of an op pattern synonym (is_pattern) or the argument of AST.emitOp in a builder. Warns (prefixed with `what_for`) and returns nullopt for variadic regions, successors, and, in patterns, variable-length results, variable-length operands other than a single variadic list, or operand segment sizes. 396 | std::optional buildOperation( 397 | const llvm::Record* def, bool is_pattern, const std::string& what_for, 398 | const std::string& location_expr, 399 | const std::vector& type_exprs, 400 | const std::vector& operand_exprs, 401 | const std::vector& region_exprs, 402 | const OpAttrPattern& attr_pattern) { 403 | mlir::tblgen::Operator op(def); 404 | auto fail = [&op, &what_for](std::string reason) { 405 | warn(op, llvm::formatv("couldn't construct {0}: {1}", what_for, reason)); 406 | return std::optional(); 407 | }; 408 | 409 | // Skip currently unsupported cases 410 | if (op.getNumVariadicRegions() != 0) return fail("variadic regions"); 411 | if
(op.getNumSuccessors() != 0) return fail("successors"); 412 | 413 | // Prepare results 414 | std::string type_expr; 415 | if (op.getNumResults() == 0) { 416 | assert(type_exprs.size() == op.getNumResults()); 417 | type_expr = "[]"; 418 | } else if (op.getNumVariableLengthResults() == 0 && 419 | op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 420 | assert(type_exprs.size() == 1); 421 | type_expr = llvm::formatv("[{0:$[, ]}]", 422 | make_range(std::vector( 423 | op.getNumResults(), type_exprs.front()))); 424 | } else if (op.getNumVariableLengthResults() == 0) { 425 | assert(type_exprs.size() == op.getNumResults()); 426 | type_expr = llvm::formatv("[{0:$[, ]}]", make_range(type_exprs)); 427 | } else if (!is_pattern) { 428 | assert(type_exprs.size() == op.getNumResults()); 429 | std::vector list_type_exprs; 430 | for (int i = 0; i < op.getNumResults(); ++i) { 431 | auto& result = op.getResult(i); 432 | if (result.isOptional()) { 433 | list_type_exprs.push_back("(Data.Maybe.maybeToList " + type_exprs[i] + ")"); 434 | } else if (result.isVariadic()) { 435 | list_type_exprs.push_back(type_exprs[i]); 436 | } else { 437 | assert(!result.isVariableLength()); 438 | list_type_exprs.push_back("[" + type_exprs[i] + "]"); 439 | } 440 | } 441 | type_expr = llvm::formatv("({0:$[ ++ ]})", make_range(list_type_exprs)); 442 | } else { 443 | return fail("unsupported variable length results"); 444 | } 445 | 446 | // Prepare operands 447 | std::string operand_expr; 448 | assert(operand_exprs.size() == op.getNumOperands()); 449 | if (op.getNumOperands() == 1 && op.getOperand(0).isVariadic()) { 450 | // Note that this expr already should represent a list 451 | operand_expr = operand_exprs.front(); 452 | } else if (op.getNumVariableLengthOperands() == 0) { 453 | operand_expr = llvm::formatv("[{0:$[, ]}]", make_range(operand_exprs)); 454 | } else if (!is_pattern) { 455 | std::vector operand_list_exprs; 456 | for (int i = 0; i < op.getNumOperands(); ++i) { 457 | auto& operand = 
op.getOperand(i); 458 | if (operand.isOptional()) { 459 | operand_list_exprs.push_back("(Data.Maybe.maybeToList " + operand_exprs[i] + ")"); 460 | } else if (operand.isVariadic()) { 461 | operand_list_exprs.push_back(operand_exprs[i]); 462 | } else { 463 | assert(!operand.isVariableLength()); 464 | operand_list_exprs.push_back("[" + operand_exprs[i] + "]"); 465 | } 466 | } 467 | operand_expr = 468 | llvm::formatv("({0:$[ ++ ]})", make_range(operand_list_exprs)); 469 | } else { 470 | return fail("unsupported variable length operands"); 471 | } 472 | if (is_pattern && op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) return fail("operand segment sizes in patterns"); // the `<> AST.namedAttribute ...` suffix built below is an expression, not a valid pattern 473 | std::string extra_attrs; 474 | if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 475 | std::vector segment_sizes; 476 | for (int i = 0; i < op.getNumOperands(); ++i) { 477 | auto& operand = op.getOperand(i); 478 | if (operand.isOptional()) { 479 | segment_sizes.push_back(llvm::formatv( 480 | "case {0} of Just _ -> 1; Nothing -> 0", operand_exprs[i])); 481 | } else if (operand.isVariadic()) { 482 | segment_sizes.push_back("Prelude.length " + operand_exprs[i]); 483 | } else { 484 | assert(!operand.isVariableLength()); 485 | segment_sizes.push_back("1"); 486 | } 487 | } 488 | const char* kOperandSegmentsAttr = R"( 489 | <> AST.namedAttribute "operand_segment_sizes" 490 | (DenseElementsAttr (VectorType [{0}] $ IntegerType Unsigned 32) $ 491 | DenseUInt32 $ IArray.listArray (1 :: Int, {0}) $ Prelude.fromIntegral <$> [{1:$[, ]}]) 492 | )"; 493 | extra_attrs = llvm::formatv(kOperandSegmentsAttr, 494 | segment_sizes.size(), 495 | make_range(segment_sizes)); 496 | } 497 | 498 | const char* kPatternExplicitType = R"(Operation 499 | { opName = "{0}" 500 | , opLocation = {1} 501 | , opResultTypes = Explicit {2} 502 | , opOperands = {3} 503 | , opRegions = [{4:$[ , ]}] 504 | , opSuccessors = [] 505 | , opAttributes = ({5}{6}{7:$[ ]}){8} 506 | })"; 507 | return 
llvm::formatv(kPatternExplicitType, 508 | op.getOperationName(), // 0 509 | location_expr, // 1 510 | type_expr, // 2 511 | operand_expr, // 3 512 | 
make_range(region_exprs), // 4 513 | attr_pattern.name, // 5 514 | attr_pattern.binders.empty() ? "" : " ", // 6 515 | make_range(attr_pattern.binders), // 7 516 | extra_attrs) // 8 517 | .str(); 518 | } 519 | // Replaces every '.' in `name` with '_'. 520 | // TODO(apaszke): Make this more reliable 521 | std::string legalizeBuilderName(std::string name) { 522 | for (size_t i = 0; i < name.size(); ++i) { 523 | if (name[i] == '.') name[i] = '_'; 524 | } 525 | return name; 526 | } 527 | // Drops the dialect prefix (everything up to and including the first '.') from an operation name; asserts that the name contains a '.'. 528 | std::string stripDialect(std::string name) { 529 | size_t dialect_sep_loc = name.find('.'); 530 | assert(dialect_sep_loc != std::string::npos); 531 | return name.substr(dialect_sep_loc + 1); 532 | } 533 | // Emits a monadic builder for `op`, named after the operation without its dialect prefix. Arguments: result types (omitted when inferred from the first fixed-length operand of a SameOperandsAndResultType op), operands, attributes, then region builders. It yields () or EndOfBlock (terminators) for no results, a Value for one result, [Value] otherwise. Warns and emits nothing for variadic regions, successors, failed type inference, or when the operation record cannot be built. 534 | void emitBuilderMethod(mlir::tblgen::Operator& op, 535 | const OpAttrPattern& attr_pattern, llvm::raw_ostream& os) { 536 | auto fail = [&op](std::string reason) { 537 | warn(op, "couldn't construct builder: " + reason); 538 | }; 539 | 540 | if (op.getNumVariadicRegions() != 0) return fail("variadic regions"); 541 | if (op.getNumSuccessors() != 0) return fail("successors"); 542 | 543 | const char* result_type; 544 | std::string prologue; 545 | const char* continuation = ""; 546 | if (op.getNumResults() == 0) { 547 | prologue = "Control.Monad.void "; 548 | if (op.getTrait("::mlir::OpTrait::IsTerminator")) { 549 | result_type = "EndOfBlock"; 550 | continuation = "\n AST.terminateBlock"; 551 | } else { 552 | result_type = "()"; 553 | } 554 | } else if (op.getNumResults() == 1) { 555 | result_type = "Value"; 556 | prologue = "Control.Monad.liftM Prelude.head "; 557 | } else { 558 | result_type = "[Value]"; 559 | prologue = ""; 560 | } 561 | 562 | std::string builder_name = sanitizeName(legalizeBuilderName(stripDialect(op.getOperationName()))); 563 | 564 | std::vector builder_arg_types; 565 | 566 | // TODO(apaszke): Use inference (op.getSameTypeAsResult) 567 | std::vector type_exprs; 568 | std::vector type_binders; 569 | if (op.getNumResults() == 0) { 570 | // Nothing to do.
571 | } else if (op.getNumVariableLengthResults() == 0 && 572 | op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 573 | for (int i = 0; i < op.getNumOperands(); ++i) { 574 | const auto& operand = op.getOperand(i); if (operand.isVariableLength()) continue; 575 | type_exprs.push_back("(AST.typeOf " + sanitizeName(operand.name, i) + "_)"); // same (indexed) name as the operand binder built below, so unnamed operands resolve too 576 | break; 577 | } 578 | if (type_exprs.empty()) return fail("type inference failed"); 579 | } else { 580 | int result_nr = 0; 581 | for (const mlir::tblgen::NamedTypeConstraint& result : op.getResults()) { 582 | type_binders.push_back(llvm::formatv("ty{0}", result_nr++)); 583 | type_exprs.push_back(type_binders.back()); 584 | if (result.isOptional()) { 585 | builder_arg_types.push_back("Maybe Type"); 586 | } else if (result.isVariadic()) { 587 | builder_arg_types.push_back("[Type]"); 588 | } else { 589 | assert(!result.isVariableLength()); 590 | builder_arg_types.push_back("Type"); 591 | } 592 | } 593 | } 594 | 595 | std::vector operand_binders; 596 | std::vector operand_name_exprs; 597 | operand_name_exprs.reserve(op.getNumOperands()); 598 | for (int i = 0; i < op.getNumOperands(); ++i) { 599 | const auto& operand = op.getOperand(i); 600 | std::string operand_name = sanitizeName(operand.name, i) + "_"; 601 | operand_binders.push_back(operand_name); 602 | if (operand.isOptional()) { 603 | builder_arg_types.push_back("Maybe Value"); 604 | operand_name_exprs.push_back("(AST.operand <$> " + operand_name + ")"); 605 | } else if (operand.isVariadic()) { 606 | builder_arg_types.push_back("[Value]"); 607 | operand_name_exprs.push_back("(AST.operands " + operand_name + ")"); 608 | } else { 609 | assert(!operand.isVariableLength()); 610 | builder_arg_types.push_back("Value"); 611 | operand_name_exprs.push_back("(AST.operand " + operand_name + ")"); 612 | } 613 | } 614 | 615 | auto attr_types = attr_pattern.types(); 616 | builder_arg_types.insert(builder_arg_types.end(), attr_types.begin(), 617 | attr_types.end()); 618 | 619 | std::vector
region_builder_binders; 620 | std::vector region_binders; 621 | if (op.getNumRegions() > 0) { 622 | std::string region_prologue; 623 | NameSource gen("_unnamed_region"); 624 | for (const mlir::tblgen::NamedRegion& region : op.getRegions()) { 625 | std::string name = region.name.empty() ? gen.fresh() : sanitizeName(region.name) + "_"; 626 | region_builder_binders.push_back(name + "Builder"); 627 | region_binders.push_back(name); 628 | builder_arg_types.push_back("RegionBuilderT m ()"); 629 | region_prologue += llvm::formatv( 630 | "{0} <- AST.buildRegion {1}\n ", 631 | region_binders.back(), region_builder_binders.back()); 632 | } 633 | prologue = region_prologue + prologue; 634 | } 635 | 636 | builder_arg_types.push_back(""); // To add the arrow before m 637 | 638 | std::optional operation = 639 | buildOperation(&op.getDef(), false, "builder", "UnknownLocation", 640 | type_exprs, operand_name_exprs, region_binders, 641 | attr_pattern); 642 | if (!operation) return; 643 | 644 | const char* kBuilder = R"( 645 | -- | A builder for @{10}@.
646 | {0} :: ({11:$[, ]}) => MonadBlockBuilder m => {1:$[ -> ]}m {2} 647 | {0} {3:$[ ]} {4:$[ ]} {5:$[ ]} {6:$[ ]} = do 648 | {7}(AST.emitOp ({8})){9} 649 | )"; 650 | os << llvm::formatv(kBuilder, 651 | builder_name, // 0 652 | make_range(builder_arg_types), // 1 653 | result_type, // 2 654 | make_range(type_binders), // 3 655 | make_range(operand_binders), // 4 656 | make_range(attr_pattern.binders), // 5 657 | make_range(region_builder_binders), // 6 658 | prologue, // 7 659 | *operation, // 8 660 | continuation, // 9 661 | op.getOperationName(), // 10 662 | make_range(attr_pattern.provided_constraints())); // 11 663 | } 664 | // Emits an (implicitly bidirectional) pattern synonym over 'AbstractOperation' for `def`, named after the TableGen record minus StripOpPrefix and the trailing "Op". Warns and emits nothing for variadic results, regions, successors, unexpected record names, variable-length operands other than a single variadic operand, or when the operation record cannot be built. 665 | void emitPattern(const llvm::Record* def, const OpAttrPattern& attr_pattern, 666 | llvm::raw_ostream& os) { 667 | mlir::tblgen::Operator op(def); 668 | auto fail = [&op](std::string reason) { 669 | return warn(op, llvm::formatv("couldn't construct pattern: {0}", reason)); 670 | }; 671 | 672 | // Skip currently unsupported cases 673 | if (op.getNumVariableLengthResults() != 0) return fail("variadic results"); 674 | if (op.getNumRegions() != 0) return fail("regions"); 675 | if (op.getNumSuccessors() != 0) return fail("successors"); 676 | if (!def->getName().ends_with("Op")) return fail("unsupported name format"); 677 | if (!def->getName().starts_with(StripOpPrefix)) return fail("missing prefix"); 678 | 679 | // Drop the stripped prefix and "Op" from the end.
680 | llvm::StringRef pattern_name = 681 | def->getName().drop_back(2).drop_front(StripOpPrefix.length()); 682 | 683 | std::vector pattern_arg_types{"Location"}; 684 | 685 | // Prepare results 686 | std::vector type_binders; 687 | if (op.getNumResults() > 0 && 688 | op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 689 | assert(op.getNumVariableLengthResults() == 0); 690 | pattern_arg_types.push_back("Type"); 691 | type_binders.push_back("ty"); 692 | } else { 693 | size_t result_count = 0; 694 | for (int i = 0; i < op.getNumResults(); ++i) { 695 | pattern_arg_types.push_back("Type"); 696 | type_binders.push_back(llvm::formatv("ty{0}", result_count++)); 697 | } 698 | } 699 | 700 | // Prepare operands 701 | std::vector operand_binders; 702 | if (op.getNumOperands() == 1 && op.getOperand(0).isVariadic()) { 703 | // Single variadic arg is easy to handle 704 | pattern_arg_types.push_back("[operand]"); 705 | operand_binders.push_back(sanitizeName(op.getOperand(0).name, 0) + "_"); 706 | } else { 707 | // Non-variadic case 708 | for (int i = 0; i < op.getNumOperands(); ++i) { 709 | const auto& operand = op.getOperand(i); 710 | if (operand.isVariableLength()) 711 | return fail("unsupported variable length operand"); 712 | pattern_arg_types.push_back("operand"); 713 | operand_binders.push_back(sanitizeName(operand.name, i) + "_"); 714 | } 715 | } 716 | 717 | // Prepare attribute pattern 718 | auto attr_types = attr_pattern.types(); 719 | pattern_arg_types.insert(pattern_arg_types.end(), attr_types.begin(), 720 | attr_types.end()); 721 | 722 | std::optional operation = buildOperation( 723 | def, true, "pattern", "loc", 724 | type_binders, operand_binders, {}, attr_pattern); 725 | if (!operation) return; 726 | 727 | const char* kPatternExplicitType = R"( 728 | -- | A pattern for @{6}@.
729 | pattern {0} :: () => ({7:$[, ]}) => {1:$[ -> ]} -> AbstractOperation operand 730 | pattern {0} loc {2:$[ ]} {3:$[ ]} {4:$[ ]} = {5} 731 | )"; 732 | os << llvm::formatv(kPatternExplicitType, 733 | pattern_name, // 0 734 | make_range(pattern_arg_types), // 1 735 | make_range(type_binders), // 2 736 | make_range(operand_binders), // 3 737 | make_range(attr_pattern.binders), // 4 738 | *operation, // 5 739 | op.getOperationName(), // 6 740 | make_range(attr_pattern.provided_constraints())); // 7 741 | 742 | } 743 | // Converts an op's Markdown description into Haddock comment lines: strips the first non-empty line's indentation from every line, unwraps [text](url) links, escapes ' " @ < $ #, turns code fences and backticks into '@', and prefixes every line with "-- ". 744 | std::string formatDescription(mlir::tblgen::Operator op) { 745 | std::string description; 746 | description = "\n" + op.getDescription().str(); 747 | size_t pos = 0; // measure the indentation of the first non-empty line 748 | while (description[pos] == '\n') ++pos; 749 | size_t leading_spaces = 0; 750 | while (description[pos++] == ' ') ++leading_spaces; 751 | if (leading_spaces) { 752 | std::string leading_spaces_str; 753 | for (size_t i = 0; i < leading_spaces; ++i) leading_spaces_str += "[ ]"; 754 | description = std::regex_replace(description, std::regex("\n" + leading_spaces_str), "\n"); 755 | } 756 | description = std::regex_replace(description, std::regex("\\[([^\\]\\n]*)\\]\\([^)\\n]*\\)"), "$1"); // unwrap [text](url) links; bounded classes (instead of greedy .*) keep two links on one line from merging into a single match 757 | description = std::regex_replace(description, std::regex("(['\"@<$#])"), "\\$1"); 758 | description = std::regex_replace(description, std::regex("```mlir"), "@"); 759 | description = std::regex_replace(description, std::regex("```"), "@"); 760 | description = std::regex_replace(description, std::regex("`"), "@"); 761 | description = std::regex_replace(description, std::regex("\n"), "\n-- "); 762 | return description; 763 | } 764 | 765 | } // namespace 766 | // Backend for --generator=hs-op-defs: writes the generated Haskell module (attribute patterns, op patterns, builders) for all Op records. Returns true, which TableGen treats as an error, when there are no Op records; false otherwise. 767 | bool emitOpTableDefs(const llvm::RecordKeeper& recordKeeper, 768 | llvm::raw_ostream& os) { 769 | auto defs = recordKeeper.getAllDerivedDefinitions("Op"); 770 | 771 | if (defs.empty()) return true; 772 | // TODO(apaszke): Emit a module header to avoid leaking internal definitions.
773 | auto dialect_name = getDialectName(defs); 774 | os << "{-# OPTIONS_GHC -Wno-unused-imports #-}\n"; 775 | os << "{-# OPTIONS_HADDOCK hide, prune, not-home #-}\n\n"; 776 | os << "module MLIR.AST.Dialect.Generated." << dialect_name << " where\n"; 777 | os << R"( 778 | import Prelude (Int, Double, Maybe(..), Bool(..), (++), (<$>), ($), (<>), Show) 779 | import qualified Prelude 780 | import Data.Int (Int64) 781 | import qualified Data.Maybe 782 | import Data.Array (Ix) 783 | import qualified Data.Array.IArray as IArray 784 | import qualified Data.ByteString as BS 785 | import qualified Data.Map.Strict as M 786 | import qualified Control.Monad 787 | 788 | import MLIR.AST ( Attribute(..), Type(..), AbstractOperation(..), ResultTypes(..) 789 | , Location(..), Signedness(..), DenseElements(..) 790 | , NamedAttributes, Name 791 | , pattern NoAttrs ) 792 | import qualified MLIR.AST as AST 793 | import MLIR.AST.Builder (Value, EndOfBlock, MonadBlockBuilder, RegionBuilderT) 794 | import qualified MLIR.AST.Builder as AST 795 | import qualified MLIR.AST.IStorableArray as AST 796 | import qualified MLIR.AST.PatternUtil as PatternUtil 797 | import qualified MLIR.AST.Dialect.Affine as Affine 798 | )"; 799 | 800 | attr_print_state attr_pattern_state; 801 | for (const auto* def : defs) { 802 | mlir::tblgen::Operator op(*def); 803 | if (op.hasDescription()) { 804 | os << llvm::formatv("\n-- * {0}\n-- ${0}", stripDialect(op.getOperationName())); 805 | os << formatDescription(op); 806 | os << "\n"; 807 | } 808 | std::optional attr_pattern = OpAttrPattern::buildFor(op); 809 | if (!attr_pattern) continue; 810 | attr_pattern->print(os, attr_pattern_state); 811 | emitPattern(def, *attr_pattern, os); 812 | emitBuilderMethod(op, *attr_pattern, os); 813 | } 814 | 815 | return false; 816 | } 817 | // Backend for --generator=hs-tests: writes an Hspec module checking, for every op with an attribute pattern, that the pattern round-trips and still matches when extra attributes are present. Returns true when there are no Op records. 818 | bool emitTestTableDefs(const llvm::RecordKeeper& recordKeeper, 819 | llvm::raw_ostream& os) { 820 | auto defs = recordKeeper.getAllDerivedDefinitions("Op"); 821 | if (defs.empty()) return
true; 822 | 823 | auto dialect_name = getDialectName(defs); 824 | os << "{-# OPTIONS_GHC -Wno-unused-imports #-}\n\n"; 825 | const char* module_header = R"( 826 | module MLIR.AST.Dialect.Generated.{0}Spec where 827 | 828 | import Prelude (IO, Maybe(..), ($), (<>)) 829 | import qualified Prelude 830 | import Test.Hspec (Spec) 831 | import qualified Test.Hspec as Hspec 832 | import Test.QuickCheck ((===)) 833 | import qualified Test.QuickCheck as QC 834 | 835 | import MLIR.AST (pattern NoAttrs) 836 | import MLIR.AST.Dialect.{0} 837 | 838 | import MLIR.Test.Generators () 839 | 840 | main :: IO () 841 | main = Hspec.hspec spec 842 | 843 | spec :: Spec 844 | spec = do 845 | )"; 846 | os << llvm::formatv(module_header, dialect_name); 847 | for (const auto* def : defs) { 848 | mlir::tblgen::Operator op(*def); 849 | std::optional attr_pattern = OpAttrPattern::buildFor(op); 850 | if (!attr_pattern) continue; 851 | os << "\n Hspec.describe \"" << op.getOperationName() << "\" $ do"; 852 | const char* bidirectional_test_template = R"( 853 | Hspec.it "has a bidirectional attr pattern" $ do 854 | let wrapUnwrap ({1:$[, ]}) = case ({0} {1:$[ ]}) <> Prelude.mempty of 855 | {0} {2:$[ ]} -> Just ({2:$[, ]}) 856 | _ -> Nothing 857 | QC.property $ \args -> wrapUnwrap args === Just args 858 | )"; 859 | os << llvm::formatv( 860 | bidirectional_test_template, attr_pattern->name, 861 | make_range(attr_pattern->binders), 862 | make_range(map_vector(attr_pattern->binders, [](const std::string& b) { 863 | return b + "_match"; 864 | }))); 865 | const char* pattern_extensibility_test_template = R"( 866 | Hspec.it "accepts additional attributes" $ do 867 | QC.property $ do 868 | ({1:$[, ]}) <- QC.arbitrary 869 | extraAttrs <- QC.arbitrary 870 | let match = case ({0} {1:$[ ]}) <> extraAttrs of 871 | {0} {2:$[ ]} -> Just ({2:$[, ]}) 872 | _ -> Nothing 873 | Prelude.return $ match === Just ({1:$[, ]}) 874 | )"; 875 | os << llvm::formatv( 876 | pattern_extensibility_test_template,
attr_pattern->name, 877 | make_range(attr_pattern->binders), 878 | make_range(map_vector(attr_pattern->binders, [](const std::string& b) { 879 | return b + "_match"; 880 | }))); 881 | // Attr pattern matches with extra attributes are covered by the "accepts additional attributes" test above. 882 | // TODO(apaszke): Test bidirectionality of op pattern. 883 | // TODO(apaszke): Test op pattern matches with more attributes. 884 | // TODO(apaszke): Test builder output matches op pattern. 885 | // TODO(apaszke): Figure out how to do tests with translation. 886 | } 887 | return false; 888 | } 889 | -------------------------------------------------------------------------------- /tblgen/mlir-hs-tblgen.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #include "llvm/ADT/StringExtras.h" 16 | #include "llvm/Support/CommandLine.h" 17 | #include "llvm/Support/FormatVariadic.h" 18 | #include "llvm/Support/InitLLVM.h" 19 | #include "llvm/Support/ManagedStatic.h" 20 | #include "llvm/Support/Signals.h" 21 | #include "llvm/TableGen/Error.h" 22 | #include "llvm/TableGen/Main.h" 23 | #include "llvm/TableGen/Record.h" 24 | #include "llvm/TableGen/TableGenBackend.h" 25 | 26 | using namespace llvm; 27 | 28 | using generator_function = bool(const llvm::RecordKeeper& recordKeeper, 29 | llvm::raw_ostream& os); 30 | 31 | struct GeneratorInfo { 32 | const char* name; 33 | generator_function* generator; 34 | }; 35 | 36 | extern generator_function emitOpTableDefs; 37 | extern generator_function emitTestTableDefs; 38 | 39 | static std::array generators {{ 40 | {"hs-op-defs", emitOpTableDefs}, 41 | {"hs-tests", emitTestTableDefs}, 42 | }}; 43 | 44 | generator_function* generator; 45 | 46 | int main(int argc, char **argv) { 47 | llvm::InitLLVM y(argc, argv); 48 | llvm::cl::opt generatorOpt("generator", llvm::cl::desc("Generator to run"), cl::Required); 49 | cl::ParseCommandLineOptions(argc, argv); 50 | for (const auto& spec : generators) { 51 | if (generatorOpt == spec.name) { 52 | generator = spec.generator; 53 | break; 54 | } 55 | } 56 | if (!generator) { 57 | llvm::errs() << "Invalid generator type\n"; 58 | abort(); 59 | } 60 | 61 | return TableGenMain(argv[0], [](raw_ostream& os, const RecordKeeper &records) { 62 | return generator(records, os); 63 | }); 64 | } 65 | -------------------------------------------------------------------------------- /test/MLIR/ASTSpec.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2021 Google LLC 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 
5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 14 | -- | Tests translating 'MLIR.AST' operations into native MLIR, checking their printed form and, for one matmul, JIT-executed results. 15 | module MLIR.ASTSpec where 16 | 17 | import Test.Hspec 18 | 19 | import Text.RawString.QQ 20 | import Data.Int 21 | import Data.Char 22 | import Data.Maybe 23 | import Data.Foldable 24 | import Foreign.Ptr 25 | import Foreign.Storable 26 | import Foreign.ForeignPtr 27 | import qualified Data.ByteString as BS 28 | import qualified Data.Vector.Storable as V 29 | import Control.Monad.Trans.Cont 30 | import Control.Monad.IO.Class 31 | 32 | import MLIR.AST 33 | import MLIR.AST.Serialize 34 | import qualified MLIR.AST.Dialect.Arith as Arith 35 | import qualified MLIR.AST.Dialect.Func as Func 36 | import qualified MLIR.AST.Dialect.MemRef as MemRef 37 | import qualified MLIR.AST.Dialect.Affine as Affine 38 | import qualified MLIR.AST.Dialect.Vector as Vector 39 | import qualified MLIR.Native as MLIR 40 | import qualified MLIR.Native.Pass as MLIR 41 | import qualified MLIR.Native.ExecutionEngine as MLIR 42 | 43 | -- | Storable wrapper whose 'alignment' is 64 bytes; size, peek and poke delegate to the wrapped value. NOTE(review): presumably used so Storable vectors of it are allocated 64-byte aligned for the JIT-compiled code; confirm. 44 | newtype AlignedStorable a = Aligned a 45 | deriving instance Num a => Num (AlignedStorable a) 46 | deriving instance Fractional a => Fractional (AlignedStorable a) 47 | deriving instance Show a => Show (AlignedStorable a) 48 | deriving instance Eq a => Eq (AlignedStorable a) 49 | deriving instance Ord a => Ord (AlignedStorable a) 50 | instance Storable a => Storable (AlignedStorable a) where 51 | sizeOf (Aligned x) = sizeOf x 52 | alignment (Aligned _) = 64 53 | peek ptr = Aligned <$> peek (castPtr ptr) 54 | poke ptr (Aligned x) = poke (castPtr ptr) x 55 | 56 | -- | Removes, from every non-empty line, the indentation of the first line containing a non-space character; empty lines stay empty. Errors (via 'fromJust') if no line has a non-space character or if a non-empty line lacks that indentation. 57 | trimLeadingSpaces :: BS.ByteString ->
BS.ByteString 58 | trimLeadingSpaces str = BS.intercalate "\n" strippedLines 59 | where 60 | space = fromIntegral $ ord ' ' 61 | ls = BS.split (fromIntegral $ ord '\n') str 62 | indentDepth = fromJust $ asum $ (BS.findIndex (/= space)) <$> filter (/= "") ls 63 | indent = BS.replicate indentDepth space 64 | strippedLines = flip fmap ls \case "" -> "" 65 | l -> fromJust $ BS.stripPrefix indent l 66 | 67 | -- | Translates @op@ in a fresh context with all dialects registered, asserts that it verifies, and compares its printed form (without locations) to the expected text. The expected text's first character (the leading newline) is dropped, a trailing newline is appended, and its indentation is removed with 'trimLeadingSpaces'. 68 | shouldShowAs :: Operation -> BS.ByteString -> Expectation 69 | shouldShowAs op expectedWithLeadingNewline = do 70 | MLIR.withContext \ctx -> do 71 | MLIR.registerAllDialects ctx 72 | nativeOp <- fromAST ctx (mempty, mempty) op 73 | MLIR.verifyOperation nativeOp >>= (`shouldBe` True) 74 | let expected = trimLeadingSpaces $ BS.append (BS.tail expectedWithLeadingNewline) "\n" 75 | MLIR.showOperation nativeOp >>= (`shouldBe` expected) 76 | -- | Like 'shouldShowAs', but compares the printed form including locations. 77 | shouldShowWithLocationAs :: Operation -> BS.ByteString -> Expectation 78 | shouldShowWithLocationAs op expectedWithLeadingNewline = do 79 | MLIR.withContext \ctx -> do 80 | MLIR.registerAllDialects ctx 81 | nativeOp <- fromAST ctx (mempty, mempty) op 82 | MLIR.verifyOperation nativeOp >>= (`shouldBe` True) 83 | let expected = trimLeadingSpaces $ BS.append (BS.tail expectedWithLeadingNewline) "\n" 84 | MLIR.showOperationWithLocation nativeOp >>= (`shouldBe` expected) 85 | -- | Lowers @op@ to the LLVM dialect, JIT-compiles it, and calls its @matmul8x8x8@ function on two 64-element inputs (1,2,..,64 and 1,3,..,127) plus a zeroed output buffer, expecting the output to equal 'expectedOutput'. 86 | shouldImplementMatmul :: Operation -> Expectation 87 | shouldImplementMatmul op = evalContT $ do 88 | ctx <- ContT $ MLIR.withContext 89 | m <- liftIO $ do 90 | MLIR.registerAllDialects ctx 91 | o <- fromAST ctx (mempty, mempty) op 92 | Just m <- MLIR.moduleFromOperation o -- view the translated op as a module for the execution engine below 93 | MLIR.withPassManager ctx \pm -> do 94 | MLIR.addConvertMemRefToLLVMPass pm 95 | MLIR.addConvertVectorToLLVMPass pm 96 | MLIR.addConvertFuncToLLVMPass pm 97 | MLIR.addConvertArithToLLVMPass pm 98 | MLIR.addConvertControlFlowToLLVMPass pm 99 | MLIR.addConvertReconcileUnrealizedCastsPass pm 100 | result <- MLIR.runPasses pm o 101 | result
`shouldBe` MLIR.Success 102 | return m 103 | (a, _ ) <- withMemrefArg0 $ V.unsafeThaw $ V.iterateN 64 (+1.0) (1.0 :: AlignedStorable Float) 104 | (b, _ ) <- withMemrefArg0 $ V.unsafeThaw $ V.iterateN 64 (+2.0) (1.0 :: AlignedStorable Float) 105 | (c, cVec) <- withMemrefArg0 $ V.unsafeThaw $ V.replicate 64 (0.0 :: AlignedStorable Float) 106 | Just eng <- ContT $ MLIR.withExecutionEngine m 107 | name <- ContT $ MLIR.withStringRef "matmul8x8x8" 108 | liftIO $ do 109 | Just () <- MLIR.executionEngineInvoke @() eng name [a, b, c] 110 | cVecFinal <- V.unsafeFreeze cVec 111 | cVecFinal `shouldBe` expectedOutput 112 | where 113 | -- Packs a vector into a struct representing a rank-0 memref: the data pointer twice (presumably the allocated and aligned pointers) followed by a zero Int64 offset 114 | withMemrefArg0 :: ContT r IO (V.MVector s a) -> ContT r IO (MLIR.SomeStorable, V.MVector s a) 115 | withMemrefArg0 mkVec = do 116 | vec@(V.MVector _ fptr) <- mkVec 117 | ptr <- ContT $ withForeignPtr fptr 118 | structPtr <- ContT $ MLIR.packStruct64 119 | [MLIR.SomeStorable ptr, MLIR.SomeStorable ptr, MLIR.SomeStorable (0 :: Int64)] 120 | return (MLIR.SomeStorable structPtr, vec) 121 | -- Row-major 8x8 product of the two inputs above, e.g. 2724 = sum over k in 0..7 of (1+k)*(1+16k) 122 | expectedOutput :: V.Vector (AlignedStorable Float) 123 | expectedOutput = V.fromList $ fmap Aligned 124 | [ 2724, 2796, 2868, 2940, 3012, 3084, 3156, 3228 125 | , 6372, 6572, 6772, 6972, 7172, 7372, 7572, 7772 126 | , 10020, 10348, 10676, 11004, 11332, 11660, 11988, 12316 127 | , 13668, 14124, 14580, 15036, 15492, 15948, 16404, 16860 128 | , 17316, 17900, 18484, 19068, 19652, 20236, 20820, 21404 129 | , 20964, 21676, 22388, 23100, 23812, 24524, 25236, 25948 130 | , 24612, 25452, 26292, 27132, 27972, 28812, 29652, 30492 131 | , 28260, 29228, 30196, 31164, 32132, 33100, 34068, 35036 132 | ] 133 | 134 | -- | Adds the unit attribute @llvm.emit_c_interface@ to @op@'s attributes (presumably so the function gets the C-interface wrapper that 'MLIR.executionEngineInvoke' calls; confirm). 135 | emitted :: Operation -> Operation 136 | emitted op = op { opAttributes = opAttributes op <> namedAttribute "llvm.emit_c_interface" UnitAttr } 137 | 138 | -- | Translation tests comparing printed MLIR and, for vector.contract, JIT-executed results. 139 | spec :: Spec 140 | spec = do 141 | describe "AST translation" $ do 142 | it "Can translate an empty module" $ do 143 | let m =
ModuleOp $ Block "0" [] [] 144 | m `shouldShowAs` [r| 145 | module { 146 | }|] 147 | 148 | it "Can translate a module with location" $ do 149 | -- TODO(jpienaar): This builds the module explicitly using the Operation 150 | -- interface to set the opLocation on an Operation corresponding to a 151 | -- Module. 152 | let m = Operation { 153 | opName = "builtin.module" 154 | , opLocation = FusedLocation [ 155 | NameLocation "first" UnknownLocation 156 | , NameLocation "last" UnknownLocation 157 | ] Nothing 158 | , opResultTypes = Explicit [] 159 | , opOperands = [] 160 | , opRegions = [Region [Block "0" [] []]] 161 | , opSuccessors = [] 162 | , opAttributes = NoAttrs 163 | } 164 | m `shouldShowWithLocationAs` [r| 165 | module { 166 | } loc(#loc2) 167 | #loc = loc("first") 168 | #loc1 = loc("last") 169 | #loc2 = loc(fused[#loc, #loc1])|] 170 | 171 | it "Can construct a matmul via vector.matrix_multiply" $ do 172 | let v64Ty = VectorType [64] Float32Type 173 | let v64refTy = MemRefType { memrefTypeShape = [] 174 | , memrefTypeElement = v64Ty 175 | , memrefTypeLayout = Nothing 176 | , memrefTypeMemorySpace = Nothing } 177 | let m = ModuleOp $ Block "0" [] [ 178 | Do $ emitted $ FuncOp UnknownLocation "matmul8x8x8" (FunctionType [v64refTy, v64refTy, v64refTy] []) $ Region [ 179 | Block "0" [("arg0", v64refTy), ("arg1", v64refTy), ("arg2", v64refTy)] 180 | [ "0" := MemRef.Load v64Ty "arg0" [] 181 | , "1" := MemRef.Load v64Ty "arg1" [] 182 | , "2" := Vector.Matmul UnknownLocation v64Ty "0" "1" 8 8 8 183 | , "3" := MemRef.Load v64Ty "arg2" [] 184 | , "4" := Arith.AddF UnknownLocation v64Ty "3" "2" 185 | , Do $ MemRef.Store "4" "arg2" [] 186 | , Do $ Func.Return UnknownLocation [] 187 | ] 188 | ] 189 | ] -- NOTE(review): memref types in the expected text below read as memref> (their type parameters appear lost when this listing was extracted); confirm against the repository. 190 | m `shouldShowAs` [r| 191 | module { 192 | func.func @matmul8x8x8(%arg0: memref>, %arg1: memref>, %arg2: memref>) attributes {llvm.emit_c_interface} { 193 | %0 = memref.load %arg0[] : memref> 194 | %1 = memref.load %arg1[] : memref> 195 | %2 = vector.matrix_multiply %0,
%1 {lhs_columns = 8 : i32, lhs_rows = 8 : i32, rhs_columns = 8 : i32} : (vector<64xf32>, vector<64xf32>) -> vector<64xf32> 196 | %3 = memref.load %arg2[] : memref> 197 | %4 = arith.addf %3, %2 : vector<64xf32> 198 | memref.store %4, %arg2[] : memref> 199 | return 200 | } 201 | }|] 202 | 203 | it "Can translate matmul via vector.contract" $ do 204 | let v8x8Ty = VectorType [8, 8] Float32Type 205 | let v8x8RefTy = MemRefType [] v8x8Ty Nothing Nothing 206 | let m = ModuleOp $ Block "0" [] [ 207 | Do $ emitted $ FuncOp UnknownLocation "matmul8x8x8" (FunctionType [v8x8RefTy, v8x8RefTy, v8x8RefTy] []) $ Region [ 208 | Block "0" [("arg0", v8x8RefTy), ("arg1", v8x8RefTy), ("arg2", v8x8RefTy)] 209 | [ "0" := MemRef.Load v8x8Ty "arg0" [] 210 | , "1" := MemRef.Load v8x8Ty "arg1" [] 211 | , "2" := MemRef.Load v8x8Ty "arg2" [] 212 | , "3" := Vector.Contract UnknownLocation v8x8Ty "0" "1" "2" 213 | (Affine.Map 3 0 [Affine.Dimension 0, Affine.Dimension 2]) 214 | (Affine.Map 3 0 [Affine.Dimension 2, Affine.Dimension 1]) 215 | (Affine.Map 3 0 [Affine.Dimension 0, Affine.Dimension 1]) 216 | [Vector.Parallel, Vector.Parallel, Vector.Reduction] 217 | , Do $ MemRef.Store "3" "arg2" [] 218 | , Do $ Func.Return UnknownLocation [] 219 | ] 220 | ] 221 | ] 222 | shouldImplementMatmul m -- NOTE(review): the memref types and #vector.kind in the expected text below also appear to have lost their angle-bracketed parameters in extraction; confirm. 223 | m `shouldShowAs` [r| 224 | #map = affine_map<(d0, d1, d2) -> (d0, d2)> 225 | #map1 = affine_map<(d0, d1, d2) -> (d2, d1)> 226 | #map2 = affine_map<(d0, d1, d2) -> (d0, d1)> 227 | module { 228 | func.func @matmul8x8x8(%arg0: memref>, %arg1: memref>, %arg2: memref>) attributes {llvm.emit_c_interface} { 229 | %0 = memref.load %arg0[] : memref> 230 | %1 = memref.load %arg1[] : memref> 231 | %2 = memref.load %arg2[] : memref> 232 | %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %0, %1, %2 : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32> 233 | memref.store %3, %arg2[] : memref> 234 | return 235 | } 236 | }|] 237 |
238 | -- | Runs this spec on its own. 239 | main :: IO () 240 | main = hspec spec 241 | -------------------------------------------------------------------------------- /test/MLIR/BuilderSpec.hs: -------------------------------------------------------------------------------- 1 | -- Copyright 2021 Google LLC 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 14 | -- | Tests for the monadic builder API of 'MLIR.AST.Builder'. 15 | module MLIR.BuilderSpec where 16 | 17 | import Test.Hspec 18 | 19 | import Control.Monad.Identity 20 | 21 | import MLIR.AST 22 | import MLIR.AST.Builder 23 | import MLIR.AST.Serialize 24 | import qualified Data.ByteString as BS 25 | import qualified MLIR.AST.Dialect.Arith as Arith 26 | import qualified MLIR.AST.Dialect.ControlFlow as Cf 27 | import qualified MLIR.AST.Dialect.Func as Func 28 | import qualified MLIR.Native as MLIR 29 | 30 | -- | Translates @op@ in a fresh context with all dialects registered, dumps it, and then asserts that it verifies. 31 | verifyAndDump :: Operation -> Expectation 32 | verifyAndDump op = 33 | MLIR.withContext \ctx -> do 34 | MLIR.registerAllDialects ctx 35 | nativeOp <- fromAST ctx (mempty, mempty) op 36 | MLIR.dump nativeOp 37 | MLIR.verifyOperation nativeOp >>= (`shouldBe` True) 38 | 39 | -- | Builder API tests: single- and multi-function modules, MonadFix block loops, and default-location propagation. 40 | spec :: Spec 41 | spec = do 42 | describe "Builder API" $ do 43 | let combineFunc name ty combine = -- function `name` returning `combine x y` for two fresh block arguments x, y of type `ty` 44 | buildSimpleFunction name [ty] NoAttrs do 45 | x <- blockArgument ty 46 | y <- blockArgument ty 47 | z <- combine x y 48 | Func.return [z] 49 | 50 | it "Can construct a simple add function" $ do 51 | let m = runIdentity $ buildModule $ combineFunc "add" Float32Type
Arith.addf 52 | verifyAndDump m 53 | 54 | it "Can construct a module with two simple functions" $ do 55 | let m = runIdentity $ buildModule $ do 56 | combineFunc "add_fp32" Float32Type Arith.addf 57 | combineFunc "add_fp64" Float64Type Arith.addf 58 | verifyAndDump m 59 | -- entry jumps to header with false, so header branches to body; body jumps back with true, so header then branches to exit. mdo lets blocks name blocks defined later. 60 | it "Can loop blocks with MonadFix" $ do 61 | let f32 = Float32Type 62 | let i1 = IntegerType Signless 1 63 | let m = runIdentity $ buildModule $ do 64 | buildFunction "one_shot_loop" [f32] NoAttrs mdo 65 | _entry <- buildBlock do 66 | false <- Arith.constant i1 $ IntegerAttr i1 0 67 | Cf.br header [false] 68 | header <- buildBlock do 69 | isDone <- blockArgument i1 70 | result <- Arith.constant f32 $ FloatAttr f32 1234.0 71 | Cf.cond_br isDone exit [result] body [result] 72 | body <- buildBlock do 73 | _ <- blockArgument f32 74 | true <- Arith.constant i1 $ IntegerAttr i1 1 75 | Cf.br header [true] 76 | exit <- buildBlock do 77 | result <- blockArgument f32 78 | Func.return [result] 79 | endOfRegion 80 | verifyAndDump m 81 | -- ops built after setDefaultLocation carry that location, while earlier ones keep loc(unknown), as the expected output shows 82 | it "location propagation in builder" $ do 83 | let f32 = Float32Type 84 | let m = runIdentity $ buildModule $ do 85 | buildSimpleFunction "f_loc" [f32] NoAttrs do 86 | x <- blockArgument f32 87 | y <- Arith.addf x x 88 | setDefaultLocation (FileLocation "file.mlir" 4 10) 89 | z <- Arith.addf y y 90 | Func.return [z] 91 | MLIR.withContext \ctx -> do 92 | MLIR.registerAllDialects ctx 93 | nativeOp <- fromAST ctx (mempty, mempty) m 94 | MLIR.verifyOperation nativeOp >>= (`shouldBe` True) 95 | MLIR.showOperationWithLocation nativeOp >>= (`shouldBe` BS.intercalate "\n" [ 96 | "#loc = loc(unknown)" 97 | , "module {" 98 | , " func.func @f_loc(%arg0: f32 loc(unknown)) -> f32 {" 99 | , " %0 = arith.addf %arg0, %arg0 : f32 loc(#loc)" 100 | , " %1 = arith.addf %0, %0 : f32 loc(#loc1)" 101 | , " return %1 : f32 loc(#loc1)" 102 | , " } loc(#loc)" 103 | , "} loc(#loc)" 104 | , "#loc1 = loc(\"file.mlir\":4:10)" 105 | , ""]) 106 | 107 | -- | Runs this spec on its own. 108 | main :: IO () 109 | main = hspec spec
110 | 
--------------------------------------------------------------------------------
/test/MLIR/NativeSpec.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2021 Google LLC
2 | --
3 | -- Licensed under the Apache License, Version 2.0 (the "License");
4 | -- you may not use this file except in compliance with the License.
5 | -- You may obtain a copy of the License at
6 | --
7 | --      http://www.apache.org/licenses/LICENSE-2.0
8 | --
9 | -- Unless required by applicable law or agreed to in writing, software
10 | -- distributed under the License is distributed on an "AS IS" BASIS,
11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | -- See the License for the specific language governing permissions and
13 | -- limitations under the License.
14 | 
15 | module MLIR.NativeSpec where
16 | 
17 | import Test.Hspec hiding (shouldContain, shouldStartWith)
18 | 
19 | import Text.RawString.QQ
20 | 
21 | import Data.Int
22 | import Data.Maybe
23 | import Data.Char (ord)
24 | import qualified Data.ByteString as BS
25 | import Control.Monad
26 | import Foreign.Storable
27 | 
28 | import qualified MLIR.Native as MLIR
29 | import qualified MLIR.Native.Pass as MLIR
30 | import qualified MLIR.Native.ExecutionEngine as MLIR
31 | 
32 | -- | A module with a single @add@ function that doubles its i32 argument,
33 | -- written exactly as the MLIR printer renders it (the round-trip test below
34 | -- compares printed output against this string).
35 | exampleModuleStr :: BS.ByteString
36 | exampleModuleStr = pack $ [r|module {
37 |   func.func @add(%arg0: i32) -> i32 attributes {llvm.emit_c_interface} {
38 |     %0 = arith.addi %arg0, %arg0 : i32
39 |     return %0 : i32
40 |   }
41 | }
42 | |]
43 | 
44 | -- | Converts an ASCII 'String' into a 'BS.ByteString', one byte per character.
45 | -- Non-ASCII characters raise an error instead of being silently truncated to
46 | -- their low 8 bits.
47 | pack :: String -> BS.ByteString
48 | pack = BS.pack . fmap toByte
49 |   where
50 |     toByte c
51 |       | ord c < 128 = fromIntegral (ord c)
52 |       | otherwise = error ("pack: non-ASCII character " ++ show c)
53 | 
54 | -- | Creates a context with all dialects registered, shared by a describe
55 | -- block through 'beforeAll'.
56 | -- TODO(apaszke): Clean up. Nothing in this module destroys the context.
57 | prepareContext :: IO MLIR.Context
58 | prepareContext = do
59 |   ctx <- MLIR.createContext
60 |   MLIR.registerAllDialects ctx
61 |   return ctx
62 | 
63 | -- | Parses 'exampleModuleStr' in the given context. If the string does not
64 | -- parse, this fails immediately with a descriptive error rather than a later
65 | -- 'fromJust' crash.
66 | parseExampleModule :: MLIR.Context -> IO MLIR.Module
67 | parseExampleModule ctx =
68 |   MLIR.withStringRef exampleModuleStr (MLIR.parseModule ctx)
69 |     >>= maybe (fail "Failed to parse exampleModuleStr") return
70 | 
71 | -- Hspec's shouldContain/shouldStartWith only accept lists (and are hidden in
72 | -- the import above); these variants check substrings of a ByteString.
73 | shouldContain :: BS.ByteString -> BS.ByteString -> Expectation
74 | shouldContain str sub = str `shouldSatisfy` BS.isInfixOf sub
75 | 
76 | shouldStartWith :: BS.ByteString -> BS.ByteString -> Expectation
77 | shouldStartWith str sub = str `shouldSatisfy` BS.isPrefixOf sub
78 | 
79 | -- Tests read what they need from a module and destroy it before asserting,
80 | -- so a failing expectation does not skip the destroyModule call.
81 | spec :: Spec
82 | spec = do
83 |   describe "Basics" $ do
84 |     it "Can create a context" $ MLIR.withContext $ const $ return ()
85 | 
86 |     it "Can load dialects" $ do
87 |       MLIR.withContext \ctx -> do
88 |         MLIR.registerAllDialects ctx
89 |         numDialects <- MLIR.getNumLoadedDialects ctx
90 |         numDialects `shouldSatisfy` (> 1)
91 | 
92 |   describe "Modules" $ beforeAll prepareContext $ do
93 |     it "Can create an empty module" $ \ctx -> do
94 |       loc <- MLIR.getUnknownLocation ctx
95 |       m <- MLIR.createEmptyModule loc
96 |       str <- MLIR.showModule m
97 |       MLIR.destroyModule m
98 |       str `shouldBe` "module {\n}\n"
99 | 
100 |     it "Can parse an example module" $ \ctx -> do
101 |       exampleModule <- parseExampleModule ctx
102 |       exampleModuleStr' <- MLIR.showModule exampleModule
103 |       MLIR.destroyModule exampleModule
104 |       exampleModuleStr' `shouldBe` exampleModuleStr
105 | 
106 |     it "Fails to parse garbage" $ \ctx -> do
107 |       maybeModule <- MLIR.withStringRef "asdf" $ MLIR.parseModule ctx
108 |       isNothing maybeModule `shouldBe` True
109 | 
110 |     it "Can create an empty module with location" $ \ctx -> do
111 |       MLIR.withStringRef "test.cc" $ \nameRef -> do
112 |         loc <- MLIR.getFileLineColLocation ctx nameRef 21 45
113 |         m <- MLIR.createEmptyModule loc
114 |         str <- (MLIR.moduleAsOperation >=> MLIR.showOperationWithLocation) m
115 |         MLIR.destroyModule m
116 |         str `shouldContain` "loc(\"test.cc\":21:45)"
117 | 
118 |     it "Can create an empty module with name location" $ \ctx -> do
119 |       MLIR.withStringRef "WhatIamCalled" $ \nameRef -> do
120 |         loc <- MLIR.getNameLocation ctx nameRef =<< MLIR.getUnknownLocation ctx
121 |         m <- MLIR.createEmptyModule loc
122 |         str <- (MLIR.moduleAsOperation >=> MLIR.showOperationWithLocation) m
123 |         MLIR.destroyModule m
124 |         str `shouldContain` "loc(\"WhatIamCalled\")"
125 | 
126 |     it "Can extract first operation (Function) of module" $ \ctx -> do
127 |       exampleModule <- parseExampleModule ctx
128 |       operations <- (MLIR.getModuleBody >=> MLIR.getBlockOperations) exampleModule
129 |       functionStr' <- MLIR.showOperation $ head operations
130 |       MLIR.destroyModule exampleModule
131 |       functionStr' `shouldStartWith` "func.func @add(%arg0: i32) -> i32"
132 | 
133 |     it "Can show operations inside region of function" $ \ctx -> do
134 |       exampleModule <- parseExampleModule ctx
135 |       operations <- (MLIR.getModuleBody >=> MLIR.getBlockOperations) exampleModule
136 |       regions <- MLIR.getOperationRegions (head operations)
137 |       blocks <- MLIR.getRegionBlocks (head regions)
138 |       ops <- MLIR.getBlockOperations $ head blocks
139 |       opStrs <- mapM MLIR.showOperation ops
140 |       MLIR.destroyModule exampleModule
141 |       BS.intercalate " ; " opStrs `shouldBe` "%0 = arith.addi %arg0, %arg0 : i32 ; func.return %0 : i32"
142 | 
143 |   describe "Evaluation engine" $ beforeAll prepareContext $ do
144 |     it "Can evaluate the example module" $ \ctx -> do
145 |       m <- parseExampleModule ctx
146 |       lowerToLLVM m
147 |       result <- run @Int32 m "add" [MLIR.SomeStorable (123 :: Int32)]
148 |       MLIR.destroyModule m
149 |       result `shouldBe` 246
150 |   where
151 |     -- Runs the func-to-LLVM, arith-to-LLVM and reconcile-unrealized-casts
152 |     -- passes over the module, erroring out if the pipeline reports failure.
153 |     lowerToLLVM :: MLIR.Module -> IO ()
154 |     lowerToLLVM m = do
155 |       ctx <- MLIR.getContext m
156 |       o <- MLIR.moduleAsOperation m
157 |       MLIR.withPassManager ctx \pm -> do
158 |         MLIR.addConvertFuncToLLVMPass pm
159 |         MLIR.addConvertArithToLLVMPass pm
160 |         MLIR.addConvertReconcileUnrealizedCastsPass pm
161 |         result <- MLIR.runPasses pm o
162 |         when (result == MLIR.Failure) $ error "Failed to lower to LLVM!"
163 | 
164 |     -- Builds an execution engine for the module, invokes function @name@ with
165 |     -- @args@ and returns its single result, erroring out on either failure.
166 |     run :: forall result. Storable result
167 |         => MLIR.Module -> BS.ByteString -> [MLIR.SomeStorable] -> IO result
168 |     run m name args = do
169 |       MLIR.withExecutionEngine m \maybeEng -> do
170 |         let eng = fromMaybe (error "Failed to compile the module") maybeEng
171 |         MLIR.withStringRef name $ \nameRef -> do
172 |           maybeValue <- MLIR.executionEngineInvoke eng nameRef args
173 |           case maybeValue of
174 |             Just value -> return value
175 |             Nothing -> error "Failed to run the example program!"
176 | 
177 | main :: IO ()
178 | main = hspec spec
179 | 
--------------------------------------------------------------------------------
/test/MLIR/RewriteSpec.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2021 Google LLC
2 | --
3 | -- Licensed under the Apache License, Version 2.0 (the "License");
4 | -- you may not use this file except in compliance with the License.
5 | -- You may obtain a copy of the License at
6 | --
7 | --      http://www.apache.org/licenses/LICENSE-2.0
8 | --
9 | -- Unless required by applicable law or agreed to in writing, software
10 | -- distributed under the License is distributed on an "AS IS" BASIS,
11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | -- See the License for the specific language governing permissions and
13 | -- limitations under the License.
14 | 
15 | module MLIR.RewriteSpec where
16 | 
17 | import Test.Hspec
18 | 
19 | import Control.Monad.Identity
20 | 
21 | import MLIR.AST
22 | import MLIR.AST.Builder
23 | import MLIR.AST.Serialize
24 | import MLIR.AST.Rewrite
25 | import qualified MLIR.AST.Dialect.Arith as Arith
26 | import qualified MLIR.AST.Dialect.Func as Func
27 | import qualified MLIR.Native as MLIR
28 | 
29 | 
30 | -- | Translates an AST operation into a fresh native context (all dialects
31 | -- registered), dumps it, and expects the MLIR verifier to accept it.
32 | verifyAndDump :: Operation -> Expectation
33 | verifyAndDump astOp = MLIR.withContext \ctx -> do
34 |   MLIR.registerAllDialects ctx
35 |   native <- fromAST ctx (mempty, mempty) astOp
36 |   MLIR.dump native
37 |   isValid <- MLIR.verifyOperation native
38 |   isValid `shouldBe` True
39 | 
40 | 
41 | spec :: Spec
42 | spec = describe "Rewrite API" $
43 |   it "Can replace adds with multiplies" $ do
44 |     let original = runIdentity $ buildModule $
45 |           buildSimpleFunction "f" [Float32Type] NoAttrs do
46 |             a <- blockArgument Float32Type
47 |             b <- blockArgument Float32Type
48 |             s <- Arith.addf a b
49 |             t <- Arith.addf s s
50 |             Func.return [t]
51 |     verifyAndDump (applyClosedOpRewrite addToMul original)
52 |   where
53 |     -- Replaces each arith.addf with an arith.mulf over the same operands;
54 |     -- every other operation gets the Traverse result.
55 |     addToMul op
56 |       | Arith.AddF _ _ lhs rhs <- op = ReplaceOne <$> Arith.mulf lhs rhs
57 |       | otherwise = return Traverse
58 | 
59 | 
60 | main :: IO ()
61 | main = hspec spec
62 | 
--------------------------------------------------------------------------------
/test/MLIR/Test/Generators.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2021 Google LLC
2 | --
3 | -- Licensed under the Apache License, Version 2.0 (the "License");
4 | -- you may not use this file except in compliance with the License.
5 | -- You may obtain a copy of the License at
6 | --
7 | --      http://www.apache.org/licenses/LICENSE-2.0
8 | --
9 | -- Unless required by applicable law or agreed to in writing, software
10 | -- distributed under the License is distributed on an "AS IS" BASIS,
11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | -- See the License for the specific language governing permissions and
13 | -- limitations under the License.
14 | 
15 | {-# OPTIONS_GHC -Wno-orphans #-}
16 | module MLIR.Test.Generators where
17 | 
18 | import Control.Monad
19 | import Test.QuickCheck
20 | import GHC.Generics
21 | import Generic.Random
22 | import Data.Array.IArray
23 | import qualified Data.Map.Strict as M
24 | import qualified Data.ByteString.Char8 as BS8
25 | 
26 | import MLIR.AST
27 | import qualified MLIR.AST.Dialect.Affine as Affine
28 | 
29 | -- | Names are packed from arbitrary 'String's ('BS8.pack' keeps only the low
30 | -- 8 bits of each character).
31 | instance Arbitrary Name where
32 |   arbitrary = BS8.pack <$> arbitrary
33 | 
34 | -- | Attributes: mostly leaves, with arrays and dictionaries mixed in through
35 | -- 'recursiveArbitrary'.
36 | instance Arbitrary Attribute where
37 |   arbitrary = recursiveArbitrary leafGenerators recGenerators
38 |     where
39 |       leafGenerators =
40 |         [ FloatAttr <$> arbitraryFloatType <*> arbitrary
41 |         , IntegerAttr <$> arbitraryIntegerType <*> arbitrary
42 |         , BoolAttr <$> arbitrary
43 |         , StringAttr <$> arbitrary
44 |         , TypeAttr <$> arbitrary
45 |         , AffineMapAttr <$> arbitrary
46 |         , pure UnitAttr
47 |           -- A 1-D vector of i32 whose length matches the generated values.
48 |         , do
49 |             values <- arbitrary
50 |             return $ DenseElementsAttr
51 |               (VectorType [length values] (IntegerType Signless 32))
52 |               (DenseUInt32 $ listArray (1, length values) $ values)
53 |         ]
54 |       recGenerators =
55 |         [ ArrayAttr <$> arbitrarySubtrees
56 |         , DictionaryAttr . M.fromList <$> (traverse arbitraryName =<< arbitrarySubtrees)
57 |         ]
58 |       -- Pairs an attribute with a freshly generated key.
59 |       arbitraryName :: Attribute -> Gen (Name, Attribute)
60 |       arbitraryName attr = (,attr) <$> arbitrary
61 | 
62 | -- | Between 0 and @size@ values, each generated at a size strictly below the
63 | -- current one, so recursion through this helper shrinks.
64 | arbitrarySubtrees :: Arbitrary a => Gen [a]
65 | arbitrarySubtrees = sized $ \size -> do
66 |   numSubtrees <- chooseInt (0, size)
67 |   replicateM numSubtrees do
68 |     subsize <- chooseInt (0, size - 1)
69 |     resize subsize arbitrary
70 | 
71 | -- | At size 0 only leaves are produced; otherwise each leaf generator has
72 | -- weight 9 and each recursive generator weight 1.
73 | recursiveArbitrary :: [Gen a] -> [Gen a] -> Gen a
74 | recursiveArbitrary leafGenerators recGenerators = sized $ \size -> do
75 |   case size > 0 of
76 |     False -> oneof leafGenerators
77 |     True -> frequency $ ((9,) <$> leafGenerators) <> ((1,) <$> recGenerators)
78 | 
79 | arbitraryFloatType :: Gen Type
80 | arbitraryFloatType = oneof $ fmap pure
81 |   [ BFloat16Type
82 |   , Float128Type
83 |   , Float16Type
84 |   , Float32Type
85 |   , Float64Type
86 |   , Float80Type
87 |   ]
88 | 
89 | -- | Integer types of width 1, 8, 16, 32 or 64 with arbitrary signedness.
90 | arbitraryIntegerType :: Gen Type
91 | arbitraryIntegerType = IntegerType <$> arbitrary <*> elements [1, 8, 16, 32, 64]
92 | 
93 | scalarTypeGenerators :: [Gen Type]
94 | scalarTypeGenerators =
95 |   [ pure BFloat16Type
96 |   , ComplexType <$> arbitraryFloatType
97 |   , pure Float128Type
98 |   , pure Float16Type
99 |   , pure Float32Type
100 |   , pure Float64Type
101 |   , pure Float80Type
102 |   , pure IndexType
103 |   , arbitraryIntegerType
104 |   ]
105 | 
106 | arbitraryScalarType :: Gen Type
107 | arbitraryScalarType = oneof scalarTypeGenerators
108 | 
109 | -- | Types: scalars plus shaped and opaque leaves, with function and tuple
110 | -- types built recursively from 'arbitrarySubtrees'.
111 | instance Arbitrary Type where
112 |   arbitrary = recursiveArbitrary leafGenerators recGenerators
113 |     where
114 |       leafGenerators = scalarTypeGenerators ++
115 |         [ MemRefType <$> arbitrary
116 |                      <*> arbitraryScalarType
117 |                      <*> frequency [(9, pure Nothing), (1, arbitrary)]
118 |                      <*> frequency [(9, pure Nothing), (1, arbitrary)]
119 |         , pure NoneType
120 |         , OpaqueType <$> arbitrary <*> arbitrary
121 |         , RankedTensorType <$> frequency [(1, listOf $ Just <$> arbitrary), (1, arbitrary)]
122 |                            <*> arbitraryScalarType
123 |                            <*> arbitrary
124 |         , UnrankedMemRefType <$> arbitraryScalarType <*> arbitrary
125 |         , UnrankedTensorType <$> arbitraryScalarType
126 |         , VectorType <$> frequency [(1, (:[]) <$> arbitrary), (1, arbitrary)]
127 |                      <*> arbitraryScalarType
128 |         ]
129 |       recGenerators =
130 |         [ FunctionType <$> arbitrarySubtrees <*> arbitrarySubtrees
131 |         , TupleType <$> arbitrarySubtrees
132 |         ]
133 | 
134 | 
135 | instance Arbitrary Signedness where
136 |   arbitrary = genericArbitrary uniform
137 | 
138 | instance Arbitrary Affine.Map where
139 |   arbitrary = genericArbitrary uniform
140 | 
141 | -- | Affine expressions. A binary node spends one unit of the size budget on
142 | -- itself and splits the rest between its two operands. Every subtree is
143 | -- therefore strictly smaller than its parent (as in 'arbitrarySubtrees'), and
144 | -- the expression depth is bounded by the size parameter.
145 | instance Arbitrary Affine.Expr where
146 |   arbitrary = sized $ \size ->
147 |     case size > 0 of
148 |       False -> oneof leafGenerators
149 |       True -> do
150 |         l <- chooseInt (0, size - 1)
151 |         frequency $ ((9,) <$> leafGenerators) <> ((1,) <$> recGenerators l (size - 1 - l))
152 |     where
153 |       leafGenerators = fmap (<$> arbitrary)
154 |         [Affine.Dimension, Affine.Symbol, Affine.Constant]
155 |       recGenerators l r = smallerArbitrary2 l r <$>
156 |         [Affine.Add, Affine.Mul, Affine.Mod, Affine.FloorDiv, Affine.CeilDiv]
157 |       smallerArbitrary2 l r f = f <$> resize l arbitrary <*> resize r arbitrary
158 | 
159 | deriving instance Show Attribute
160 | deriving instance Show Signedness
161 | deriving instance Show Affine.Map
162 | deriving instance Show Affine.Expr
163 | deriving instance Show DenseElements
164 | 
165 | -- NOTE(review): every 'Type' shows as the empty string, presumably to keep
166 | -- counterexample output short; confirm before relying on it when debugging.
167 | instance Show Type where
168 |   show _ = ""
169 | 
170 | deriving instance Generic Signedness
171 | deriving instance Generic Affine.Map
172 | deriving instance Generic Affine.Expr
173 | 
--------------------------------------------------------------------------------
/test/Spec.hs:
--------------------------------------------------------------------------------
1 | {-# OPTIONS_GHC -F -pgmF hspec-discover #-}
2 | 
--------------------------------------------------------------------------------