├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── CODEOWNERS
├── LICENSE
├── Main.lean
├── README.md
├── SHerLOC.lean
├── SHerLOC
│   ├── AST1.lean
│   ├── Basic.lean
│   └── Parsing
│       ├── Basic.lean
│       ├── Functions.lean
│       ├── Identifiers.lean
│       ├── Intermediate.lean
│       ├── Modules.lean
│       ├── Numbers.lean
│       ├── Operations.lean
│       ├── Parser.lean
│       ├── Programs.lean
│       └── Types.lean
├── lake-manifest.json
├── lakefile.lean
├── lean-toolchain
└── prepare_tests.sh

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]
  workflow_dispatch:

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: leanprover/lean-action@v1
      - name: Run tests
        run: lake exe sherloc
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/.lake/
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
* trjohnb@amazon.com
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/Main.lean:
--------------------------------------------------------------------------------
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan
-/
import SHerLOC

open System IO FilePath Process FS Std

def main (args : List String) : IO UInt32 := do
  if args.length = 0 then
    -- Test mode: parse every `.mlir` file in the `Tests` directory.
    let o ← output { cmd := "ls", args := #["Tests"] }
    let files := o.stdout.splitOn "\n"
    let files := files.filter fun s => s.takeRight 5 = ".mlir"
    let mut passed := []
    let mut failed := []
    for file in files do
      let fp : FilePath := System.mkFilePath ["Tests", file]
      let content ← readFile fp
      let content := StableHLO.Parsing.parse content
      IO.print s!"Parsing {file}... "
      match content with
      | .ok p =>
        passed := file :: passed
        let fpReport : FilePath := System.mkFilePath ["Tests", file ++ ".report"]
        -- `writeFile` truncates the file, so collect every message first and
        -- write them once (writing inside the loop kept only the last one).
        let mut reportText := ""
        for msg in p.2.report do
          reportText := reportText ++ s!"File {file}, {msg}\n"
        writeFile fpReport reportText
        IO.println "success"
      | .error _ =>
        failed := file :: failed
        IO.println "failure"
    IO.println "\nFailed tests:\n"
    for file in failed do
      IO.println file
    IO.println ""
    IO.println s!"Passed: {passed.length}, Failed: {failed.length}"
    if failed.length > 0 then
      return 1
    else
      return 0
  else if args.length = 1 then
    -- Single-file mode: write `<file>.ast` and `<file>.report` next to the input.
    let file := args[0]!
    let fp : FilePath := System.mkFilePath [file]
    let content ← readFile fp
    let content := StableHLO.Parsing.parse content
    match content with
    | .ok p =>
      let fpAST : FilePath := System.mkFilePath [file ++ ".ast"]
      let fpReport : FilePath := System.mkFilePath [file ++ ".report"]
      writeFile fpAST s!"{repr p.1}\n"
      writeFile fpReport s!"{p.2.report}\n"
      return 0
    | .error e =>
      IO.println s!"{e.2.2}"
      IO.println s!"{e.2.1}"
      IO.println s!"{e.1}"
      return 1

  else panic! s!"Unexpected number of arguments: {args.length}"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SHerLOC

SHerLOC is a program analyzer for [StableHLO programs](https://openxla.org/stablehlo). It is written in [Lean](https://leanprover-community.github.io/index.html).

SHerLOC aims to transform a StableHLO program written in concrete generic syntax into a well-formed, typed abstract syntax tree. It also reports information such as the use of undocumented, unspecified, underspecified, or deprecated constructs.

## Installation

To use SHerLOC, you must [install Lean](https://leanprover-community.github.io/get_started.html). If you want to use SHerLOC on StableHLO programs written in pretty syntax, you also need to [install StableHLO](https://github.com/openxla/stablehlo?tab=readme-ov-file#build-instructions) (you do not need to build the Python bindings).

Then clone this repository.

## Usage

To run SHerLOC, go to the SHerLOC directory and run:

```
lake exe sherloc myprogram.mlir
```

This produces two files, `myprogram.mlir.ast` and `myprogram.mlir.report`, which contain, respectively, a dump of the abstract syntax tree and the information reported about the program. If parsing fails, the error is printed instead and the command exits with status 1.

Run without arguments, `lake exe sherloc` parses every `.mlir` file in the `Tests` directory and prints a summary of passed and failed files; it exits with status 1 if any file fails to parse. This is the command that CI runs.

If the StableHLO program is in pretty syntax, you can convert it to generic syntax using `stablehlo-opt`:

```
stablehlo-opt -mlir-print-op-generic myprogrampretty.mlir > myprogramgeneric.mlir
```

To produce a StableHLO program in generic syntax from JAX, you can use the following Python example:

```python
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir

# Parses a StableHLO module given as text and re-prints it in MLIR generic op form
def get_stablehlo_asm(module_str):
    with jax_mlir.make_ir_context() as ctx:
        stablehlo_module = ir.Module.parse(module_str, context=ctx)
        return stablehlo_module.operation.get_asm(print_generic_op_form=True, enable_debug_info=False)

## -----

import jax
from jax import export
import jax.numpy as jnp
import numpy as np

def plus(x, y):
    return jnp.add(x, y)

# Create abstract input shapes:
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]
stablehlo_add = export.export(jax.jit(plus))(*input_shapes).mlir_module()

print(get_stablehlo_asm(stablehlo_add))
```
--------------------------------------------------------------------------------
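The `.ast` file written by `lake exe sherloc myprogram.mlir` is a `repr` dump of values built from the types in `SHerLOC/AST1.lean`. As a rough illustration, the sketch below hand-builds the value one might expect for the single generic-syntax line `%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>`. It is not taken from the parser, and several assumptions are unverified:

- A scalar `tensor<i32>` is represented with an empty `shape`.
- Value ids are stored without the `%` sigil, as `parseValueId` suggests.
- `stablehlo.add` maps to the `stablehlo` constructor with `OpCode.add`.

The names `i32Scalar` and `addOp` are hypothetical.

```lean
import SHerLOC

open StableHLO.Parsing

-- Hypothetical helper: the AST for `tensor<i32>`, assuming a rank-0 tensor
-- is represented by an empty `shape` list.
def i32Scalar : ValueType :=
  .tensorType
    { shape := []
      tensorElementTypeGen := .classic (.integerType { sign := .signed, size := .b32 }) }

-- Hypothetical AST for
--   %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-- assuming ids are stored without `%` and that the op has no regions or attributes.
def addOp : Operation :=
  .stablehlo .add ["arg0", "arg1"] [] [] ["0"]
    { domain := [i32Scalar, i32Scalar], range := [i32Scalar] }

-- Prints the derived `Repr`, the same kind of text that ends up in a `.ast` file.
#eval addOp
```

To check these assumptions, compare this output with the `.ast` file produced for a real program.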
/SHerLOC.lean:
--------------------------------------------------------------------------------
import SHerLOC.Basic
--------------------------------------------------------------------------------
/SHerLOC/AST1.lean:
--------------------------------------------------------------------------------
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan
-/

/-!
# AST resulting from parsing

-/

namespace StableHLO.Parsing

abbrev FuncId := String

abbrev ValueId := String

abbrev UnusedId := String

abbrev AttrId := String

inductive Signedness where
  | signed
  | unsigned
  deriving Repr, Inhabited, Nonempty

inductive IntegerSize where
  | b2
  | b4
  | b8
  | b16
  | b32
  | b64
  deriving Repr, Inhabited, Nonempty

inductive Sign where
  | plus
  | minus
  deriving Repr, Inhabited, Nonempty

structure IntegerLiteral where
  sign : Sign
  decimal : Nat
  deriving Repr, Inhabited, Nonempty

structure FloatLiteralDecimal where
  integerPart : IntegerLiteral
  fractionalPart : IntegerLiteral
  scientificPart : IntegerLiteral
  deriving Repr, Inhabited, Nonempty

inductive FloatLiteral where
  | decimal (literal : FloatLiteralDecimal)
  | hexaDecimal (literal : Nat)
  deriving Repr, Inhabited, Nonempty

inductive BooleanLiteral where
  | true
  | false
  deriving Repr, Inhabited, Nonempty

structure ComplexLiteral where
  real : FloatLiteral
  imaginary : FloatLiteral
  deriving Repr, Inhabited, Nonempty

inductive ElementLiteral where
  | booleanLiteral (literal : BooleanLiteral)
  | floatLiteral (literal : FloatLiteral)
  | complexLiteral (literal : ComplexLiteral)
  | stringLiteral (literal : String)
  deriving Repr, Inhabited, Nonempty

inductive DenseLiteral where
  | denseDimension (literal : List DenseLiteral)
  | denseElements (literal : List ElementLiteral)
  deriving Repr, Inhabited, Nonempty

abbrev TensorLiteral := DenseLiteral

inductive ComparisonDirection where
  | eq
  | ne
  | ge
  | gt
  | le
  | lt
  deriving Repr, Inhabited, Nonempty

inductive CompareType where
  | float
  | totalOrder
  | signed
  | unsigned
  deriving Repr, Inhabited, Nonempty

inductive PrecisionConfig where
  | default
  | high
  | highest
  deriving Repr, Inhabited, Nonempty

inductive FftType where
  | fft
  | ifft
  | rfft
  | irfft
  deriving Repr, Inhabited, Nonempty

inductive ChannelType where
  | deviceToDevice
  | hostToDevice
  deriving Repr, Inhabited, Nonempty

inductive RngDistribution where
  | uniform
  | normal
  deriving Repr, Inhabited, Nonempty

inductive RngAlgorithm where
  | default
  | threeFry
  | philox
  deriving Repr, Inhabited, Nonempty

inductive TransposeA where
  | noTranspose
  | transpose
  | adjoint
  deriving Repr, Inhabited, Nonempty

inductive EnumLiteral where
  | comparisonDirection (enum : ComparisonDirection)
  | compareType (enum : CompareType)
  | precisionConfig (enum : PrecisionConfig)
  | fftType (enum : FftType)
  | channelType (enum : ChannelType)
  | rngDistribution (enum : RngDistribution)
  | rngAlgorithm (enum : RngAlgorithm)
  | transposeA (enum : TransposeA)
  deriving Repr, Inhabited, Nonempty

inductive ArrayLiteral where
  | array64 (literal : List IntegerLiteral)
  | array1 (literal : List BooleanLiteral)
  deriving Repr, Inhabited, Nonempty

inductive ConvolutionMode where
  | i
  | o
  | f
  | one
  | b
  | zero
  | two
  deriving Repr, Inhabited, Nonempty

structure Convolution where
  lhs : List ConvolutionMode
  rhs : List ConvolutionMode
  result : List ConvolutionMode
  deriving Repr, Inhabited, Nonempty

structure IntegerType where
  sign : Signedness
  size : IntegerSize
  deriving Repr, Inhabited, Nonempty

inductive FloatType where
  | f8E3M4
  | f8E4M3
  | f8E4M3FN
  | f8E5M2
  | f8E4M3FNUZ
  | f8E5M2FNUZ
  | f8E4M3B11FNUZ
  | bf16
  | f16
  | f32
  | f64
  | tf32
  deriving Repr, Inhabited, Nonempty

inductive NumberType where
  | integerType (type : IntegerType)
  | floatType (type : FloatType)
  deriving Repr, Inhabited, Nonempty

inductive ComplexType where
  | f32
  | f64
  deriving Repr, Inhabited, Nonempty

inductive TensorElementType where
  | booleanType
  | integerType (t : IntegerType)
  | floatType (t : FloatType)
  | complexType (t : ComplexType)
  deriving Repr, Inhabited, Nonempty

structure QuantizationParameter where
  quantizationScale : FloatLiteral
  quantizationZeroPoint : IntegerLiteral
  deriving Repr, Inhabited, Nonempty

structure QuantizationBasics where
  quantizationStorageType : IntegerType
  quantizationStorageMinMax : Option (IntegerLiteral × IntegerLiteral)
  quantizationExpressedType : FloatType
  quantizationDimension : Option IntegerLiteral
  deriving Repr, Inhabited, Nonempty

structure QuantizedTensorElementType where
  quantizationBasics : QuantizationBasics
  quantizationParameters : List QuantizationParameter
  deriving Repr, Inhabited, Nonempty

inductive TensorElementTypeGen where
  | classic (t : TensorElementType)
  | quantized (t : QuantizedTensorElementType)
  deriving Repr, Inhabited, Nonempty

inductive DimensionSize where
  | known (size : Nat)
  | unknown
  deriving Repr, Inhabited, Nonempty

structure TensorType where
  shape : List DimensionSize
  tensorElementTypeGen : TensorElementTypeGen
  deriving Repr, Inhabited, Nonempty

inductive ValueType where
  | tensorType (tensor : TensorType)
  | tokenType
  | tupleType (elements : List ValueType)
  deriving Repr, Inhabited, Nonempty

structure FunctionType where
  domain : List ValueType
  range : List ValueType
  deriving Repr, Inhabited, Nonempty

inductive NonValueType where
  | tensorElementType (t : TensorElementType)
  | quantizedTensorElementType (t : QuantizedTensorElementType)
  | functionType (t : FunctionType)
  | stringType
  deriving Repr, Inhabited, Nonempty

inductive SType where
  | valueType (t : ValueType)
  | nonValueType (t : NonValueType)
  deriving Repr, Inhabited, Nonempty

mutual

inductive StableHLORecordFieldValue where
  | one (literal : Nat)
  | many (literal : List Nat)
  | type (literal : FloatType)
  | bool (literal : Bool)
  deriving Repr, Inhabited, Nonempty

inductive StableHLORecordField where
  | mk (name : String) (value : StableHLORecordFieldValue)
  deriving Repr, Inhabited, Nonempty

inductive Literal where
  | enum (literal : EnumLiteral)
  | element (literal : ElementLiteral)
  | tensor (literal : TensorLiteral)
  | string (literal : String)
  | stableHLORecord (literal : List StableHLORecordField)
  | convolution (literal : Convolution)
  | func (literal : FuncId)
  | list (literal : List Literal)
  | dictionary (literal : List Attribute)
  | array (literal : ArrayLiteral)
  deriving Repr, Inhabited, Nonempty

inductive Constant where
  | mk (literal : Literal) (typ : Option SType)
  deriving Repr, Inhabited, Nonempty

inductive Attribute where
  | mk (id : AttrId) (constant : Constant)
  deriving Repr, Inhabited, Nonempty

end

structure FuncInput where
  id : FuncId
  typ : ValueType
  deriving Repr, Inhabited, Nonempty

inductive OpCode where
  | abs
  | add
  | afterAll
  | allGather
  | allReduce
  | allToAll
  | and
  | atan2
  | batchNormGrad
  | batchNormInference
  | batchNormTraining
  | bitcastConvert
  | broadcastInDim
  | case
  | cbrt
  | ceil
  | cholesky
  | clamp
  | collectiveBroadcast
  | collectivePermute
  | compare
  | complex
  | composite
  | concatenate
  | constant
  | convert
  | convolution
  | cosine
  | countLeadingZeros
  | customCall
  | divide
  | dot
  | dotGeneral
  | dynamicBroadcastInDim
  | dynamicConv
  | dynamicGather
  | dynamicIota
  | dynamicPad
  | dynamicReshape
  | dynamicSlice
  | dynamicUpdateSlice
  | exponential
  | exponentialMinusOne
  | fft
  | floor
  | gather
  | getDimensionSize
  | getTupleElement
  | if
  | imag
  | infeed
  | iota
  | isFinite
  | log
  | logPlusOne
  | logistic
  | map
  | maximum
  | minimum
  | multiply
  | negate
  | not
  | optimizationBarrier
  | or
  | outfeed
  | pad
  | partitionId
  | popcnt
  | power
  | real
  | realDynamicSlice
  | recv
  | reduce
  | reducePrecision
  | reduceScatter
  | reduceWindow
  | remainder
  | replicaId
  | reshape
  | reverse
  | rng
  | rngBitGenerator
  | roundNearestAfz
  | roundNearestEven
  | rsqrt
  | scatter
  | select
  | selectAndScatter
  | send
  | shiftLeft
  | shiftRightArithmetic
  | shiftRightLogical
  | sign
  | sine
  | slice
  | sort
  | sqrt
  | subtract
  | tan
  | tanh
  | transpose
  | triangularSolve
  | tuple
  | uniformDequantize
  | uniformQuantize
  | while
  | xor
  deriving Repr, Inhabited, Nonempty

mutual

inductive InputFunc where
  | mk
    (funcInputs : List FuncInput)
    (body : List Operation)
  deriving Repr, Inhabited, Nonempty

inductive Operation where
  | stablehlo
    (opCode : OpCode)
    (inputValues : List ValueId)
    (inputFunctions : List InputFunc)
    (inputAttributes : List Attribute)
    (outputs : List ValueId)
    (signature : FunctionType)
  | tanh (operand : ValueId) (typ : FunctionType)
  | other
    (name : String)
    (inputValues : List ValueId)
    (inputFunctions : List InputFunc)
    (inputAttributes : List Attribute)
    (outputs : List ValueId)
    (signature : FunctionType)
  | return
    (operands : List ValueId)
    (signature : FunctionType)
  | call
    (callee : FuncId)
    (inputValues : List ValueId)
    (outputs : List ValueId)
    (signature : FunctionType)
  deriving Repr, Inhabited, Nonempty

end

structure Function where
  funcId : FuncId
  funcArgAttrs : List (List Attribute)
  funcResAttrs : List (List Attribute)
  funcType : FunctionType
  funcBody : InputFunc
  deriving Repr, Inhabited, Nonempty

structure Module where
  modId : Option FuncId
  modAttrs : List Attribute
  modFuncs : List Function
  deriving Repr, Inhabited, Nonempty

def Program := List Module
  deriving Repr, Inhabited, Nonempty

end StableHLO.Parsing
--------------------------------------------------------------------------------
/SHerLOC/Basic.lean:
-------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Basic 8 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Basic.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.Parsing.Programs 7 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Functions.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Operations 9 | import SHerLOC.Parsing.Intermediate 10 | 11 | namespace StableHLO.Parsing 12 | 13 | def tryParseDEntryFunctionType : PState (Option FunctionType) := do 14 | tryParseDictionaryEntry "function_type" parseFunctionType 15 | 16 | def parseDictionaryAttributesInner : PState (List Attribute) := do 17 | parseList "{" "}" "," parseAttribute 18 | 19 | def parseDictionaryAttributesOutter : PState (List (List Attribute)) := do 20 | parseList "[" "]" "," parseDictionaryAttributesInner 21 | 22 | def tryParseDEntryResultAttributes : PState (Option (List (List Attribute))) := do 23 | tryParseDictionaryEntry "res_attrs" parseDictionaryAttributesOutter 24 | 25 | def tryParseDEntryArgAttributes : PState (Option (List (List Attribute))) := do 26 | tryParseDictionaryEntry "arg_attrs" parseDictionaryAttributesOutter 27 | 28 | def tryParseDEntrySymName : PState (Option String) := do 29 | tryParseDictionaryEntry "sym_name" parseString 30 | 31 | def tryParseDEntrySymVisibility : PState (Option String) := do 32 | tryParseDictionaryEntry "sym_visibility" parseString 33 | 34 | def parseFunctionDictionaryAttributes : PState (String × FunctionType × (List (List Attribute)) × (List (List Attribute))) := do 35 | let mut functionName : Option String := none 36 | let mut functionType : Option FunctionType := none 37 | let mut functionVisibility : Option String := none 38 | let mut functionResultAttributes : List (List Attribute) := [] 39 | let mut functionArgAttributes : List (List Attribute) := [] 40 | let mut functionNameTodo : Bool := true 41 | let mut functionTypeTodo : Bool := true 42 | let mut functionVisibilityTodo : Bool := true 43 | let mut functionResultAttributesTodo : Bool := true 44 | let mut functionArgAttributesTodo : Bool := true 45 | let mut count := 0 46 | 47 | for _ in [:5] do 48 | if functionNameTodo then 49 | if let some name ← 
tryParseDEntrySymName then 50 | functionName := name 51 | functionNameTodo := false 52 | count := count + 1 53 | if ← is "," then parseItem "," else break 54 | 55 | if functionTypeTodo then 56 | if let some t ← tryParseDEntryFunctionType then 57 | functionType := t 58 | functionTypeTodo := false 59 | count := count + 1 60 | if ← is "," then parseItem "," else break 61 | 62 | if functionResultAttributesTodo then 63 | if let some res ← tryParseDEntryResultAttributes then 64 | functionResultAttributes := res 65 | functionResultAttributesTodo := false 66 | count := count + 1 67 | if ← is "," then parseItem "," else break 68 | 69 | if functionArgAttributesTodo then 70 | if let some res ← tryParseDEntryArgAttributes then 71 | functionArgAttributes := res 72 | functionArgAttributesTodo := false 73 | count := count + 1 74 | if ← is "," then parseItem "," else break 75 | 76 | if functionVisibilityTodo then 77 | if let some visibility ← tryParseDEntrySymVisibility then 78 | functionVisibility := visibility 79 | functionVisibilityTodo := false 80 | count := count + 1 81 | if ← is "," then parseItem "," else break 82 | 83 | if count = 5 then break 84 | 85 | if let some name := functionName then 86 | if let some typ := functionType then 87 | return (name, typ, functionArgAttributes, functionResultAttributes) 88 | else 89 | throw <| ← error "function type (function_type entry)" 90 | else 91 | throw <| ← error "function name (sym_name entry)" 92 | 93 | def parseFunction : PState Function := do 94 | parseItems ["\"func.func\"", "(", ")"] 95 | parseItem "<{" 96 | let (name,typ,argAttrs,resAttrs) ← parseFunctionDictionaryAttributes 97 | parseItem "}>" 98 | let mut funcInputs : List FuncInput := [] 99 | parseItem "({" 100 | if ← is "^" then 101 | discard <| parseUnusedId 102 | funcInputs ← parseInputFuncInputs 103 | parseItem ":" 104 | let operations ← parseListAuxNoSep "}" parseOperation [] 105 | let body : InputFunc := InputFunc.mk funcInputs operations 106 | parseItem "})" 107 | parseItems [":","(",")","->","(",")"] 108 | let r : Function
:= { funcId := name , funcArgAttrs := argAttrs , funcResAttrs := resAttrs , funcType := typ, funcBody := body } 109 | return r 110 | 111 | def parseFunctions : PState (List Function) := do 112 | parseListAuxNoSep "}" parseFunction [] 113 | 114 | end StableHLO.Parsing 115 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Identifiers.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | 9 | namespace StableHLO.Parsing 10 | 11 | def parseValueId : PState String := do 12 | parseItem "%" 13 | parseId 14 | 15 | def parseValueIdRes : PState String := do 16 | let r ← parseValueId 17 | let mut r' := "" 18 | if ← isParse ":" then 19 | r' ← parseId 20 | r' := ":" ++ r' 21 | let r := r ++ r' 22 | return r 23 | 24 | def parseValueIdOpArg : PState String := do 25 | let r ← parseValueId 26 | let mut r' := "" 27 | if ← isParse "#" then 28 | r' ← parseId 29 | r' := "#" ++ r' 30 | let r := r ++ r' 31 | return r 32 | 33 | def parseFuncId : PState String := do 34 | parseItem "@" 35 | parseFId 36 | 37 | def parseUnusedId : PState String := do 38 | parseItem "^" 39 | parseId 40 | 41 | def parseAttrId : PState String := do 42 | parseId 43 | 44 | end StableHLO.Parsing 45 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Intermediate.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Identifiers 9 | import SHerLOC.Parsing.Types 10 | 11 | namespace StableHLO.Parsing 12 | 13 | def parseStableHLORecordFieldValue : PState (StableHLORecordFieldValue) := do 14 | if (← is "[") then 15 | let value ← parseDecimals 16 | return StableHLORecordFieldValue.many value 17 | else if (← isDigit) then 18 | let value ← parseDecimal 19 | return StableHLORecordFieldValue.one value 20 | else if (← isParse "true") then 21 | return StableHLORecordFieldValue.bool true 22 | else if (← isParse "false") then 23 | return StableHLORecordFieldValue.bool false 24 | else 25 | let type ← parseFloatType 26 | return StableHLORecordFieldValue.type type 27 | 28 | def parseStableHLORecordField : PState (StableHLORecordField) := do 29 | let name ← parseId 30 | parseItem "=" 31 | let value ← parseStableHLORecordFieldValue 32 | return StableHLORecordField.mk name value 33 | 34 | def parseRecord : PState (List StableHLORecordField) := do 35 | let r ← parseList "<" ">" "," parseStableHLORecordField 36 | return r 37 | 38 | mutual 39 | 40 | partial def parseLiteral : PState Literal := do 41 | skip 42 | if (← isDigit) || (← isChar '+') || (← isChar '-') then 43 | return Literal.element <| ElementLiteral.floatLiteral <| ← parseFloatLiteral 44 | if ← isChar 'd' then 45 | return Literal.tensor <| ← parseTensorLiteral 46 | if (← is "tr") || (← is "fa") then 47 | return Literal.element <| ElementLiteral.booleanLiteral <| ← parseBooleanLiteral 48 | if (← isChar '(') then 49 | return Literal.element <| ElementLiteral.complexLiteral <| ← parseComplexLiteral 50 | if ← isChar '"' then 51 | return Literal.string <| ← parseStringLiteral 52 | if ← isChar 'a' then 53 | report "literal array" 54 | return Literal.array <| ← parseArrayLiteral 55 | 56 | if ← isParse "#stablehlo" then { 57 | if (← isParse ".") then { 58 | report "literal record" 59 | if ← isParse "conv" then return 
Literal.convolution <| ← parseConvolution 60 | if ← isParse "dot_algorithm" then return Literal.stableHLORecord <| ← parseRecord 61 | if ← isParse "dot" then return Literal.stableHLORecord <| ← parseRecord 62 | if ← isParse "channel_handle" then return Literal.stableHLORecord <| ← parseRecord 63 | if ← isParse "scatter" then return Literal.stableHLORecord <| ← parseRecord 64 | if ← isParse "gather" then return Literal.stableHLORecord <| ← parseRecord 65 | } else return Literal.enum <| ← parseEnumLiteral 66 | } 67 | 68 | if ← isChar '[' then 69 | report "literal list" 70 | return Literal.list <| ← parseList "[" "]" "," parseLiteral 71 | 72 | if ← isChar '{' then 73 | report "literal attribute" 74 | return Literal.dictionary <| ← parseAttributes 75 | 76 | if ← isChar '@' then 77 | report "literal function" 78 | return Literal.func <| ← parseFuncId 79 | 80 | throw <| (← error "literal") 81 | 82 | partial def parseConstant : PState Constant := do 83 | let literal ← parseLiteral 84 | let mut typ : Option SType := none 85 | if ← isParse ":" then 86 | typ ← parseType 87 | let r : Constant := Constant.mk literal typ 88 | return r 89 | 90 | partial def parseAttribute : PState Attribute := do 91 | if ← isParse "use_global_device_ids" then 92 | report "literal use_global_device_ids" 93 | return Attribute.mk "use_global_device_ids" <| Constant.mk (Literal.element (ElementLiteral.booleanLiteral BooleanLiteral.true)) none 94 | else 95 | let id ← parseId 96 | parseItem "=" 97 | let constant ← parseConstant 98 | return Attribute.mk id constant 99 | 100 | partial def parseAttributes : PState (List Attribute) := do 101 | parseList "{" "}" "," parseAttribute 102 | 103 | end 104 | 105 | def parseValueUseList : PState (List ValueId) := do 106 | parseList "(" ")" "," parseValueIdOpArg 107 | 108 | def tryParseDictionaryEntry (name : String) (parser : PState T) : PState (Option T) := do 109 | if ← is name then 110 | parseItem name 111 | parseItem "=" 112 | let t ← parser 113 | return some 
t 114 | else return none 115 | 116 | def parseDictionaryProperties : PState (List Attribute) := do 117 | parseList "<{" "}>" "," parseAttribute 118 | 119 | end StableHLO.Parsing 120 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Modules.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Operations 9 | import SHerLOC.Parsing.Functions 10 | import SHerLOC.Parsing.Intermediate 11 | 12 | namespace StableHLO.Parsing 13 | 14 | def parseModule : PState Module := do 15 | parseItems ["\"builtin.module\"", "(", ")"] 16 | let mut name : Option FuncId := none 17 | if ← is "<{" then 18 | parseItem "<{" 19 | parseItem "sym_name" 20 | parseItem "=" 21 | name ← parseString 22 | parseItem "}>" 23 | parseItem "(" 24 | if (← isParse "{") then 25 | if (← isParse "^bb0:") then -- Empty module 26 | let r : Module := { modId := name, modAttrs := [], modFuncs := [] } 27 | parseItems ["}",")"] 28 | parseItems [":","(",")","->","(",")"] 29 | return r 30 | let region ← parseFunctions 31 | parseItems ["}",")"] 32 | let mut attributes : List Attribute := [] 33 | if ← is "{" then 34 | attributes ← parseAttributes 35 | parseItems [":","(",")","->","(",")"] 36 | let r : Module := { modId := name, modAttrs := attributes, modFuncs := region } 37 | return r 38 | else 39 | let r : Module := { modId := name, modAttrs := [], modFuncs := [] } 40 | return r 41 | 42 | partial def parseModules : PState (List Module) := do 43 | let done ← done? 
44 | if done then 45 | return [] 46 | else 47 | let mod ← parseModule 48 | let mods ← parseModules 49 | return mod :: mods 50 | 51 | 52 | 53 | end StableHLO.Parsing 54 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Numbers.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Identifiers 9 | 10 | namespace StableHLO.Parsing 11 | 12 | def parseBooleanLiteral : PState BooleanLiteral := do 13 | if ← isParse "true" then return BooleanLiteral.true 14 | if ← isParse "false" then return BooleanLiteral.false 15 | throw <| ← error "Boolean literal" 16 | 17 | def parseIntegerLiteral : PState IntegerLiteral := do 18 | let mut sign := Sign.plus 19 | if ← isParse "+" then sign := Sign.plus 20 | else if ← isParse "-" then sign := Sign.minus 21 | let mut nat : Option Nat := none 22 | if ← is "0x" then 23 | nat ← parseHexaDecimal 24 | else 25 | nat ← parseDecimal 26 | if let some v := nat then 27 | let parseResult := { sign := sign , decimal := v } 28 | return parseResult 29 | else 30 | throw <| ← error "Integer literal" 31 | 32 | def parseFloatLiteral : PState FloatLiteral := do 33 | let mut sign := Sign.plus 34 | if ← isParse "+" then sign := Sign.plus 35 | else if ← isParse "-" then sign := Sign.minus 36 | if ← is "0x" then 37 | let nat ← parseHexaDecimal 38 | return FloatLiteral.hexaDecimal nat 39 | else 40 | let nat ← parseDecimal 41 | let integerPart : IntegerLiteral := { sign := sign , decimal := nat } 42 | let mut fractionalPart : IntegerLiteral := { sign := Sign.plus, decimal := 0 } 43 | if ← isParse "." 
then 44 | fractionalPart := {fractionalPart with decimal := ← parseDecimal} 45 | let mut scientificPart : IntegerLiteral:= { sign := Sign.plus, decimal := 0 } 46 | if (← isParse "e") || (← isParse "E") then 47 | let mut scientificSign := Sign.plus 48 | if ← isParse "+" then scientificSign := Sign.plus 49 | else if ← isParse "-" then scientificSign := Sign.minus 50 | let nat ← parseDecimal 51 | scientificPart := { sign := scientificSign, decimal := nat } 52 | let parseResult := FloatLiteral.decimal 53 | { integerPart := integerPart, 54 | fractionalPart := fractionalPart, 55 | scientificPart := scientificPart 56 | } 57 | return parseResult 58 | 59 | def parseComplexLiteral : PState ComplexLiteral := do 60 | parseItem "(" 61 | let realPart ← parseFloatLiteral 62 | parseItem "," 63 | let imaginaryPart ← parseFloatLiteral 64 | parseItem ")" 65 | let parseResult := { real := realPart, imaginary := imaginaryPart } 66 | return parseResult 67 | 68 | def parseElementLiteral : PState ElementLiteral := do 69 | skip 70 | if (← isDigit) || (← is "+") || (← is "-") then 71 | return ElementLiteral.floatLiteral <| ← parseFloatLiteral 72 | if (← is "t") || (← is "f") then 73 | return ElementLiteral.booleanLiteral <| ← parseBooleanLiteral 74 | if (← is "(") then 75 | return ElementLiteral.complexLiteral <| ← parseComplexLiteral 76 | if (← is "\"") then -- Not a good idea to try to parse these directly as numbers, they can be extremely large 77 | return ElementLiteral.stringLiteral <| ← parseString 78 | throw <| ← error "Element literal" 79 | 80 | def parseDenseElements (closingMark : String) : PState (List ElementLiteral) := do 81 | parseListAux closingMark "," parseElementLiteral 82 | 83 | partial def parseDenseLiteral : PState DenseLiteral := do 84 | if ← is "[" then 85 | let denseDimension ← parseList "[" "]" "," parseDenseLiteral 86 | return DenseLiteral.denseDimension denseDimension 87 | else 88 | let denseElements ← parseDenseElements "]" 89 | return DenseLiteral.denseElements 
denseElements 90 | 91 | def parseTensorLiteral : PState TensorLiteral := do 92 | parseItem "dense" 93 | parseItem "<" 94 | if ← is "[" then 95 | let denseLiteral ← parseDenseLiteral 96 | parseItem ">" 97 | return denseLiteral 98 | else 99 | let denseElements ← parseDenseElements ">" 100 | let denseLiteral := DenseLiteral.denseElements denseElements 101 | parseItem ">" 102 | return denseLiteral 103 | 104 | def parseStringLiteral : PState String := do 105 | parseString 106 | 107 | def parseComparisonDirection : PState ComparisonDirection := do 108 | let mut r := none 109 | if ← isParse "EQ" then r := ComparisonDirection.eq 110 | if ← isParse "NE" then r := ComparisonDirection.ne 111 | if ← isParse "GE" then r := ComparisonDirection.ge 112 | if ← isParse "GT" then r := ComparisonDirection.gt 113 | if ← isParse "LE" then r := ComparisonDirection.le 114 | if ← isParse "LT" then r := ComparisonDirection.lt 115 | if let some res := r then 116 | return res 117 | else throw <| ← error "comparison direction" 118 | 119 | def parseCompareType : PState CompareType := do 120 | let mut r := none 121 | if ← isParse "FLOAT" then r := CompareType.float 122 | if ← isParse "TOTALORDER" then r := CompareType.totalOrder 123 | if ← isParse "SIGNED" then r := CompareType.signed 124 | if ← isParse "UNSIGNED" then r := CompareType.unsigned 125 | if let some res := r then 126 | return res 127 | else throw <| ← error "compare type" 128 | 129 | def parsePrecisionConfig : PState PrecisionConfig := do 130 | let mut r := none 131 | if ← isParse "DEFAULT" then r := PrecisionConfig.default 132 | if ← isParse "HIGHEST" then r := PrecisionConfig.highest 133 | if ← isParse "HIGH" then r := PrecisionConfig.high 134 | if let some res := r then 135 | return res 136 | else throw <| ← error "precision config" 137 | 138 | def parseFftType : PState FftType := do 139 | let mut r := none 140 | if ← isParse "FFT" then r := FftType.fft 141 | if ← isParse "IFFT" then r := FftType.ifft 142 | if ← isParse "RFFT"
then r := FftType.rfft 143 | if ← isParse "IRFFT" then r := FftType.irfft 144 | if let some res := r then 145 | return res 146 | else throw <| ← error "FFT type" 147 | 148 | def parseChannelType : PState ChannelType := do 149 | let mut r := none 150 | if ← isParse "DEVICE_TO_DEVICE" then r := ChannelType.deviceToDevice 151 | if ← isParse "HOST_TO_DEVICE" then r := ChannelType.hostToDevice 152 | if let some res := r then 153 | return res 154 | else throw <| ← error "channel type" 155 | 156 | def parseRngDistribution : PState RngDistribution := do 157 | let mut r := none 158 | if ← isParse "UNIFORM" then r := RngDistribution.uniform 159 | if ← isParse "NORMAL" then r := RngDistribution.normal 160 | if let some res := r then 161 | return res 162 | else throw <| ← error "rng distribution" 163 | 164 | def parseRngAlgorithm : PState RngAlgorithm := do 165 | let mut r := none 166 | if ← isParse "DEFAULT" then r := RngAlgorithm.default 167 | if ← isParse "THREE_FRY" then r := RngAlgorithm.threeFry 168 | if ← isParse "PHILOX" then r := RngAlgorithm.philox 169 | if let some res := r then 170 | return res 171 | else throw <| ← error "rng algorithm" 172 | 173 | def parseTransposeA : PState TransposeA := do 174 | let mut r := none 175 | if ← isParse "NO_TRANSPOSE" then r := TransposeA.noTranspose 176 | if ← isParse "TRANSPOSE" then r := TransposeA.transpose 177 | if ← isParse "ADJOINT" then r := TransposeA.adjoint 178 | if let some res := r then 179 | return res 180 | else throw <| ← error "transpose annotation" 181 | 182 | def parseEnumLiteral : PState EnumLiteral := do 183 | parseItem "<" 184 | let mut r := none 185 | if ← isParse "comparison_direction" then r := EnumLiteral.comparisonDirection <| ← parseComparisonDirection 186 | if ← isParse "comparison_type" then r := EnumLiteral.compareType <| ← parseCompareType 187 | if ← isParse "precision" then r := EnumLiteral.precisionConfig <| ← parsePrecisionConfig 188 | if ← isParse "fft_type" then r := EnumLiteral.fftType <| ←
parseFftType 189 | if ← isParse "channel_type" then r := EnumLiteral.channelType <| ← parseChannelType 190 | if ← isParse "rng_distribution" then r := EnumLiteral.rngDistribution <| ← parseRngDistribution 191 | if ← isParse "rng_algorithm" then r := EnumLiteral.rngAlgorithm <| ← parseRngAlgorithm 192 | if ← isParse "transpose" then r := EnumLiteral.transposeA <| ← parseTransposeA 193 | if let some res := r then 194 | parseItem ">" 195 | return res 196 | else throw <| ← error "enumeration" 197 | 198 | def parseArrayLiteral : PState ArrayLiteral := do 199 | parseItems ["array", "<"] 200 | if ← isParse "i64" then 201 | let mut r := [] 202 | if ← isParse ":" then 203 | r ← parseListAux ">" "," parseIntegerLiteral 204 | parseItem ">" 205 | return ArrayLiteral.array64 r 206 | if ← isParse "i1" then 207 | let mut r := [] 208 | if ← isParse ":" then 209 | r ← parseListAux ">" "," parseBooleanLiteral 210 | parseItem ">" 211 | return ArrayLiteral.array1 r 212 | throw <| ← error "array literal" 213 | 214 | def parseConvolutionMode : PState ConvolutionMode := do 215 | let mut r := none 216 | if (← isParse "o") then r := ConvolutionMode.o 217 | else if (← isParse "f") then r := ConvolutionMode.f 218 | else if (← isParse "i") then r := ConvolutionMode.i 219 | else if (← isParse "0") then r := ConvolutionMode.zero 220 | else if (← isParse "1") then r := ConvolutionMode.one 221 | else if (← isParse "b") then r := ConvolutionMode.b 222 | else if (← isParse "2") then r := ConvolutionMode.two 223 | if let some res := r then return res 224 | else throw <| ← error "convolution mode" 225 | 226 | def parseConvolutionModes : PState (List ConvolutionMode) := do 227 | parseList "[" "]" "," parseConvolutionMode 228 | 229 | def parseConvolution : PState Convolution := do 230 | parseItem "<" 231 | let lhs ← parseConvolutionModes 232 | parseItem "x" 233 | let rhs ← parseConvolutionModes 234 | parseItem "->" 235 | let result ← parseConvolutionModes 236 | parseItem ">" 237 | return { lhs, rhs, 
result } 238 | 239 | end StableHLO.Parsing 240 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Operations.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Identifiers 9 | import SHerLOC.Parsing.Intermediate 10 | 11 | namespace StableHLO.Parsing 12 | 13 | def parseOpOutputs : PState (List ValueId) := do 14 | let r ← parseListAux "=" "," parseValueIdRes 15 | return r 16 | 17 | def parseInputFuncInput : PState FuncInput := do 18 | let id ← parseValueId 19 | parseItem ":" 20 | let typ ← parseValueType 21 | return { id := id , typ := typ } 22 | 23 | def parseInputFuncInputs : PState (List FuncInput) := do 24 | let r ← parseList "(" ")" "," parseInputFuncInput 25 | return r 26 | 27 | def parseReturn : PState Operation := do 28 | let arguments ← parseValueUseList 29 | parseItem ":" 30 | let functiontype ← parseFunctionType 31 | let parseResult := Operation.return arguments functiontype 32 | return parseResult 33 | 34 | def parseCall (outputs : List ValueId) : PState Operation := do 35 | parseItem "\"func.call\"" 36 | let arguments ← parseValueUseList 37 | parseItem "<{" 38 | parseItem "callee" 39 | parseItem "=" 40 | let callee ← parseFuncId 41 | parseItem "}>" 42 | parseItem ":" 43 | let typ ← parseFunctionType 44 | let r := Operation.call callee arguments outputs typ 45 | return r 46 | 47 | def parseOpCode : PState OpCode := do 48 | parseItems ["\"", "stablehlo."] 49 | let opCodeString ← parseId 50 | let mut opCode : Option OpCode := none 51 | match opCodeString with 52 | | "abs" => opCode := some OpCode.abs 53 | | "add" => opCode := some OpCode.add 54 | | "after_all" => opCode := some OpCode.afterAll 
55 | | "all_gather" => opCode := some OpCode.allGather 56 | | "all_reduce" => opCode := some OpCode.allReduce 57 | | "all_to_all" => opCode := some OpCode.allToAll 58 | | "and" => opCode := some OpCode.and 59 | | "atan2" => opCode := some OpCode.atan2 60 | | "batch_norm_grad" => opCode := some OpCode.batchNormGrad 61 | | "batch_norm_inference" => opCode := some OpCode.batchNormInference 62 | | "batch_norm_training" => opCode := some OpCode.batchNormTraining 63 | | "bitcast_convert" => opCode := some OpCode.bitcastConvert 64 | | "broadcast_in_dim" => opCode := some OpCode.broadcastInDim 65 | | "case" => opCode := some OpCode.case 66 | | "cbrt" => opCode := some OpCode.cbrt 67 | | "ceil" => opCode := some OpCode.ceil 68 | | "cholesky" => opCode := some OpCode.cholesky 69 | | "clamp" => opCode := some OpCode.clamp 70 | | "collective_broadcast" => opCode := some OpCode.collectiveBroadcast 71 | | "collective_permute" => opCode := some OpCode.collectivePermute 72 | | "compare" => opCode := some OpCode.compare 73 | | "complex" => opCode := some OpCode.complex 74 | | "composite" => opCode := some OpCode.composite 75 | | "concatenate" => opCode := some OpCode.concatenate 76 | | "constant" => opCode := some OpCode.constant 77 | | "convert" => opCode := some OpCode.convert 78 | | "convolution" => opCode := some OpCode.convolution 79 | | "cosine" => opCode := some OpCode.cosine 80 | | "count_leading_zeros" => opCode := some OpCode.countLeadingZeros 81 | | "custom_call" => opCode := some OpCode.customCall 82 | | "divide" => opCode := some OpCode.divide 83 | | "dot" => opCode := some OpCode.dot 84 | | "dot_general" => opCode := some OpCode.dotGeneral 85 | | "dynamic_broadcast_in_dim" => opCode := some OpCode.dynamicBroadcastInDim 86 | | "dynamic_conv" => opCode := some OpCode.dynamicConv 87 | | "dynamic_gather" => opCode := some OpCode.dynamicGather 88 | | "dynamic_iota" => opCode := some OpCode.dynamicIota 89 | | "dynamic_pad" => opCode := some OpCode.dynamicPad 90 | | 
"dynamic_reshape" => opCode := some OpCode.dynamicReshape 91 | | "dynamic_slice" => opCode := some OpCode.dynamicSlice 92 | | "dynamic_update_slice" => opCode := some OpCode.dynamicUpdateSlice 93 | | "exponential" => opCode := some OpCode.exponential 94 | | "exponential_minus_one" => opCode := some OpCode.exponentialMinusOne 95 | | "fft" => opCode := some OpCode.fft 96 | | "floor" => opCode := some OpCode.floor 97 | | "gather" => opCode := some OpCode.gather 98 | | "get_dimension_size" => opCode := some OpCode.getDimensionSize 99 | | "get_tuple_element" => opCode := some OpCode.getTupleElement 100 | | "if" => opCode := some OpCode.if 101 | | "imag" => opCode := some OpCode.imag 102 | | "infeed" => opCode := some OpCode.infeed 103 | | "iota" => opCode := some OpCode.iota 104 | | "is_finite" => opCode := some OpCode.isFinite 105 | | "log" => opCode := some OpCode.log 106 | | "log_plus_one" => opCode := some OpCode.logPlusOne 107 | | "logistic" => opCode := some OpCode.logistic 108 | | "map" => opCode := some OpCode.map 109 | | "maximum" => opCode := some OpCode.maximum 110 | | "minimum" => opCode := some OpCode.minimum 111 | | "multiply" => opCode := some OpCode.multiply 112 | | "negate" => opCode := some OpCode.negate 113 | | "not" => opCode := some OpCode.not 114 | | "optimization_barrier" => opCode := some OpCode.optimizationBarrier 115 | | "or" => opCode := some OpCode.or 116 | | "outfeed" => opCode := some OpCode.outfeed 117 | | "pad" => opCode := some OpCode.pad 118 | | "partition_id" => opCode := some OpCode.partitionId 119 | | "popcnt" => opCode := some OpCode.popcnt 120 | | "power" => opCode := some OpCode.power 121 | | "real" => opCode := some OpCode.real 122 | | "real_dynamic_slice" => opCode := some OpCode.realDynamicSlice 123 | | "recv" => opCode := some OpCode.recv 124 | | "reduce" => opCode := some OpCode.reduce 125 | | "reduce_precision" => opCode := some OpCode.reducePrecision 126 | | "reduce_scatter" => opCode := some OpCode.reduceScatter 127 | | 
"reduce_window" => opCode := some OpCode.reduceWindow 128 | | "remainder" => opCode := some OpCode.remainder 129 | | "replica_id" => opCode := some OpCode.replicaId 130 | | "reshape" => opCode := some OpCode.reshape 131 | | "reverse" => opCode := some OpCode.reverse 132 | | "rng" => opCode := some OpCode.rng 133 | | "rng_bit_generator" => opCode := some OpCode.rngBitGenerator 134 | | "round_nearest_afz" => opCode := some OpCode.roundNearestAfz 135 | | "round_nearest_even" => opCode := some OpCode.roundNearestEven 136 | | "rsqrt" => opCode := some OpCode.rsqrt 137 | | "scatter" => opCode := some OpCode.scatter 138 | | "select" => opCode := some OpCode.select 139 | | "select_and_scatter" => opCode := some OpCode.selectAndScatter 140 | | "send" => opCode := some OpCode.send 141 | | "shift_left" => opCode := some OpCode.shiftLeft 142 | | "shift_right_arithmetic" => opCode := some OpCode.shiftRightArithmetic 143 | | "shift_right_logical" => opCode := some OpCode.shiftRightLogical 144 | | "sign" => opCode := some OpCode.sign 145 | | "sine" => opCode := some OpCode.sine 146 | | "slice" => opCode := some OpCode.slice 147 | | "sort" => opCode := some OpCode.sort 148 | | "sqrt" => opCode := some OpCode.sqrt 149 | | "subtract" => opCode := some OpCode.subtract 150 | | "tan" => opCode := some OpCode.tan 151 | | "tanh" => opCode := some OpCode.tanh 152 | | "transpose" => opCode := some OpCode.transpose 153 | | "triangular_solve" => opCode := some OpCode.triangularSolve 154 | | "tuple" => opCode := some OpCode.tuple 155 | | "uniform_dequantize" => opCode := some OpCode.uniformDequantize 156 | | "uniform_quantize" => opCode := some OpCode.uniformQuantize 157 | | "while" => opCode := some OpCode.while 158 | | "xor" => opCode := some OpCode.xor 159 | | _ => opCode := none 160 | if let some op := opCode then 161 | parseItem "\"" 162 | return op 163 | else throw (← errorSimple (String.join ["Unknown op code: '", opCodeString, "'"])) 164 | 165 | mutual 166 | 167 | partial def 
parseInputFunc : PState InputFunc := do 168 | parseItem "{" 169 | let mut funcInputs : List FuncInput := [] 170 | if ← is "^" then 171 | discard parseUnusedId 172 | funcInputs ← parseInputFuncInputs 173 | parseItem ":" 174 | let body ← parseInputFuncBody 175 | parseItem "}" 176 | return InputFunc.mk funcInputs body 177 | 178 | partial def parseOpInputFuncs : PState (List InputFunc) := do 179 | let r ← parseList "(" ")" "," parseInputFunc 180 | return r 181 | 182 | partial def parseOperationDictionaryAttributes : PState (List Attribute) := do 183 | let r ← parseList "<{" "}>" "," parseAttribute 184 | return r 185 | 186 | partial def parseOperationBasic (op : OpCode) (opOutputs : List ValueId) : PState Operation := do 187 | let opInputValues ← parseValueUseList 188 | let mut opInputAttrs := [] 189 | if ← is "<{" then 190 | opInputAttrs ← parseOperationDictionaryAttributes 191 | let mut opInputFuncs := [] 192 | if ← is "(" then 193 | opInputFuncs ← parseOpInputFuncs 194 | parseItem ":" 195 | let functiontype ← parseFunctionType 196 | let operation := Operation.stablehlo op opInputValues opInputFuncs opInputAttrs opOutputs functiontype 197 | return operation 198 | 199 | partial def parseOtherDialect (opOutputs : List ValueId) : PState Operation := do 200 | let name ← parseString 201 | report s!"undocumented operation: {name}" 202 | let opInputValues ← parseValueUseList 203 | let mut opInputAttrs := [] 204 | if ← is "<{" then 205 | opInputAttrs ← parseOperationDictionaryAttributes 206 | let mut opInputFuncs := [] 207 | if ← is "(" then 208 | opInputFuncs ← parseOpInputFuncs 209 | parseItem ":" 210 | let functiontype ← parseFunctionType 211 | let operation := Operation.other name opInputValues opInputFuncs opInputAttrs opOutputs functiontype 212 | return operation 213 | 214 | partial def parseStableHLO (opOutputs : List ValueId) : PState Operation := do 215 | let opCode ← parseOpCode 216 | match opCode with 217 | | OpCode.abs => parseOperationBasic OpCode.abs opOutputs 
218 | | OpCode.add => parseOperationBasic OpCode.add opOutputs 219 | | OpCode.afterAll => parseOperationBasic OpCode.afterAll opOutputs 220 | | OpCode.allGather => parseOperationBasic OpCode.allGather opOutputs 221 | | OpCode.allReduce => parseOperationBasic OpCode.allReduce opOutputs 222 | | OpCode.allToAll => parseOperationBasic OpCode.allToAll opOutputs 223 | | OpCode.and => parseOperationBasic OpCode.and opOutputs 224 | | OpCode.atan2 => parseOperationBasic OpCode.atan2 opOutputs 225 | | OpCode.batchNormGrad => parseOperationBasic OpCode.batchNormGrad opOutputs 226 | | OpCode.batchNormInference => parseOperationBasic OpCode.batchNormInference opOutputs 227 | | OpCode.batchNormTraining => parseOperationBasic OpCode.batchNormTraining opOutputs 228 | | OpCode.bitcastConvert => parseOperationBasic OpCode.bitcastConvert opOutputs 229 | | OpCode.broadcastInDim => parseOperationBasic OpCode.broadcastInDim opOutputs 230 | | OpCode.case => parseOperationBasic OpCode.case opOutputs 231 | | OpCode.cbrt => parseOperationBasic OpCode.cbrt opOutputs 232 | | OpCode.ceil => parseOperationBasic OpCode.ceil opOutputs 233 | | OpCode.cholesky => parseOperationBasic OpCode.cholesky opOutputs 234 | | OpCode.clamp => parseOperationBasic OpCode.clamp opOutputs 235 | | OpCode.collectiveBroadcast => parseOperationBasic OpCode.collectiveBroadcast opOutputs 236 | | OpCode.collectivePermute => parseOperationBasic OpCode.collectivePermute opOutputs 237 | | OpCode.compare => parseOperationBasic OpCode.compare opOutputs 238 | | OpCode.complex => parseOperationBasic OpCode.complex opOutputs 239 | | OpCode.composite => parseOperationBasic OpCode.composite opOutputs 240 | | OpCode.concatenate => parseOperationBasic OpCode.concatenate opOutputs 241 | | OpCode.constant => parseOperationBasic OpCode.constant opOutputs 242 | | OpCode.convert => parseOperationBasic OpCode.convert opOutputs 243 | | OpCode.convolution => parseOperationBasic OpCode.convolution opOutputs 244 | | OpCode.cosine => 
parseOperationBasic OpCode.cosine opOutputs 245 | | OpCode.countLeadingZeros => parseOperationBasic OpCode.countLeadingZeros opOutputs 246 | | OpCode.customCall => parseOperationBasic OpCode.customCall opOutputs 247 | | OpCode.divide => parseOperationBasic OpCode.divide opOutputs 248 | | OpCode.dot => parseOperationBasic OpCode.dot opOutputs 249 | | OpCode.dotGeneral => parseOperationBasic OpCode.dotGeneral opOutputs 250 | | OpCode.dynamicBroadcastInDim => parseOperationBasic OpCode.dynamicBroadcastInDim opOutputs 251 | | OpCode.dynamicConv => parseOperationBasic OpCode.dynamicConv opOutputs 252 | | OpCode.dynamicGather => parseOperationBasic OpCode.dynamicGather opOutputs 253 | | OpCode.dynamicIota => parseOperationBasic OpCode.dynamicIota opOutputs 254 | | OpCode.dynamicPad => parseOperationBasic OpCode.dynamicPad opOutputs 255 | | OpCode.dynamicReshape => parseOperationBasic OpCode.dynamicReshape opOutputs 256 | | OpCode.dynamicSlice => parseOperationBasic OpCode.dynamicSlice opOutputs 257 | | OpCode.dynamicUpdateSlice => parseOperationBasic OpCode.dynamicUpdateSlice opOutputs 258 | | OpCode.exponential => parseOperationBasic OpCode.exponential opOutputs 259 | | OpCode.exponentialMinusOne => parseOperationBasic OpCode.exponentialMinusOne opOutputs 260 | | OpCode.fft => parseOperationBasic OpCode.fft opOutputs 261 | | OpCode.floor => parseOperationBasic OpCode.floor opOutputs 262 | | OpCode.gather => parseOperationBasic OpCode.gather opOutputs 263 | | OpCode.getDimensionSize => parseOperationBasic OpCode.getDimensionSize opOutputs 264 | | OpCode.getTupleElement => parseOperationBasic OpCode.getTupleElement opOutputs 265 | | OpCode.if => parseOperationBasic OpCode.if opOutputs 266 | | OpCode.imag => parseOperationBasic OpCode.imag opOutputs 267 | | OpCode.infeed => 268 | report "Semantics implementation defined infeed" 269 | parseOperationBasic OpCode.infeed opOutputs 270 | | OpCode.iota => parseOperationBasic OpCode.iota opOutputs 271 | | OpCode.isFinite => 
parseOperationBasic OpCode.isFinite opOutputs 272 | | OpCode.log => parseOperationBasic OpCode.log opOutputs 273 | | OpCode.logPlusOne => parseOperationBasic OpCode.logPlusOne opOutputs 274 | | OpCode.logistic => parseOperationBasic OpCode.logistic opOutputs 275 | | OpCode.map => parseOperationBasic OpCode.map opOutputs 276 | | OpCode.maximum => parseOperationBasic OpCode.maximum opOutputs 277 | | OpCode.minimum => parseOperationBasic OpCode.minimum opOutputs 278 | | OpCode.multiply => parseOperationBasic OpCode.multiply opOutputs 279 | | OpCode.negate => parseOperationBasic OpCode.negate opOutputs 280 | | OpCode.not => parseOperationBasic OpCode.not opOutputs 281 | | OpCode.optimizationBarrier => parseOperationBasic OpCode.optimizationBarrier opOutputs 282 | | OpCode.or => parseOperationBasic OpCode.or opOutputs 283 | | OpCode.outfeed => parseOperationBasic OpCode.outfeed opOutputs 284 | | OpCode.pad => parseOperationBasic OpCode.pad opOutputs 285 | | OpCode.partitionId => parseOperationBasic OpCode.partitionId opOutputs 286 | | OpCode.popcnt => parseOperationBasic OpCode.popcnt opOutputs 287 | | OpCode.power => parseOperationBasic OpCode.power opOutputs 288 | | OpCode.real => parseOperationBasic OpCode.real opOutputs 289 | | OpCode.realDynamicSlice => parseOperationBasic OpCode.realDynamicSlice opOutputs -- Undocumented 290 | | OpCode.recv => parseOperationBasic OpCode.recv opOutputs 291 | | OpCode.reduce => parseOperationBasic OpCode.reduce opOutputs 292 | | OpCode.reducePrecision => parseOperationBasic OpCode.reducePrecision opOutputs 293 | | OpCode.reduceScatter => parseOperationBasic OpCode.reduceScatter opOutputs 294 | | OpCode.reduceWindow => parseOperationBasic OpCode.reduceWindow opOutputs 295 | | OpCode.remainder => parseOperationBasic OpCode.remainder opOutputs 296 | | OpCode.replicaId => parseOperationBasic OpCode.replicaId opOutputs 297 | | OpCode.reshape => parseOperationBasic OpCode.reshape opOutputs 298 | | OpCode.reverse => parseOperationBasic OpCode.reverse
opOutputs 299 | | OpCode.rng => 300 | report "explore for deprecation rng" 301 | parseOperationBasic OpCode.rng opOutputs 302 | | OpCode.rngBitGenerator => parseOperationBasic OpCode.rngBitGenerator opOutputs 303 | | OpCode.roundNearestAfz => parseOperationBasic OpCode.roundNearestAfz opOutputs 304 | | OpCode.roundNearestEven => parseOperationBasic OpCode.roundNearestEven opOutputs 305 | | OpCode.rsqrt => parseOperationBasic OpCode.rsqrt opOutputs 306 | | OpCode.scatter => parseOperationBasic OpCode.scatter opOutputs 307 | | OpCode.select => parseOperationBasic OpCode.select opOutputs 308 | | OpCode.selectAndScatter => parseOperationBasic OpCode.selectAndScatter opOutputs 309 | | OpCode.send => parseOperationBasic OpCode.send opOutputs 310 | | OpCode.shiftLeft => parseOperationBasic OpCode.shiftLeft opOutputs 311 | | OpCode.shiftRightArithmetic => parseOperationBasic OpCode.shiftRightArithmetic opOutputs 312 | | OpCode.shiftRightLogical => parseOperationBasic OpCode.shiftRightLogical opOutputs 313 | | OpCode.sign => parseOperationBasic OpCode.sign opOutputs 314 | | OpCode.sine => parseOperationBasic OpCode.sine opOutputs 315 | | OpCode.slice => parseOperationBasic OpCode.slice opOutputs 316 | | OpCode.sort => parseOperationBasic OpCode.sort opOutputs 317 | | OpCode.sqrt => parseOperationBasic OpCode.sqrt opOutputs 318 | | OpCode.subtract => parseOperationBasic OpCode.subtract opOutputs 319 | | OpCode.tan => parseOperationBasic OpCode.tan opOutputs 320 | | OpCode.tanh => { 321 | let opInputValues ← parseValueUseList 322 | -- could provide better error messages by ensuring the input is not a dictionary 323 | if opInputValues.length ≠ 1 then throw <| (← error "tanh operation: wrong number of arguments") 324 | else { 325 | parseItem ":" 326 | let functionType ← parseFunctionType 327 | return Operation.tanh (opInputValues.get!
0) functionType 328 | } 329 | } 330 | | OpCode.transpose => parseOperationBasic OpCode.transpose opOutputs 331 | | OpCode.triangularSolve => parseOperationBasic OpCode.triangularSolve opOutputs 332 | | OpCode.tuple => parseOperationBasic OpCode.tuple opOutputs 333 | | OpCode.uniformDequantize => parseOperationBasic OpCode.uniformDequantize opOutputs 334 | | OpCode.uniformQuantize => parseOperationBasic OpCode.uniformQuantize opOutputs 335 | | OpCode.while => 336 | report "semantics not fully decided " 337 | parseOperationBasic OpCode.while opOutputs 338 | | OpCode.xor => parseOperationBasic OpCode.xor opOutputs 339 | 340 | partial def parseOperation : PState Operation := do 341 | if ← isParse "\"func.return\"" then 342 | let r ← parseReturn 343 | return r 344 | if ← isParse "\"stablehlo.return\"" then 345 | let r ← parseReturn 346 | return r 347 | let mut opOutputs := [] 348 | if ← is "%" then 349 | opOutputs ← parseOpOutputs 350 | parseItem "=" 351 | if ← is "\"func.call\"" then 352 | let r ← parseCall opOutputs 353 | return r 354 | 355 | if ← is "\"check." then 356 | let r ← parseOtherDialect opOutputs 357 | return r 358 | 359 | if ← is "\"interpreter." then 360 | let r ← parseOtherDialect opOutputs 361 | return r 362 | 363 | if ← is "\"chlo." then 364 | let r ← parseOtherDialect opOutputs 365 | return r 366 | 367 | let operation ← parseStableHLO opOutputs 368 | 369 | return operation 370 | 371 | partial def parseInputFuncBody : PState (List Operation) := do 372 | let r ← parseListAuxNoSep "}" parseOperation [] 373 | return r 374 | 375 | end 376 | 377 | end StableHLO.Parsing 378 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Parser.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | 8 | namespace StableHLO.Parsing 9 | 10 | structure Trace where 11 | startLine : Nat 12 | startColumn : Nat 13 | parser : String 14 | deriving Repr, Inhabited, Nonempty 15 | 16 | instance : ToString Trace where 17 | toString := fun t : Trace => s!"({t.startLine},{t.startColumn}):{t.parser}" 18 | 19 | instance : ToString (List Trace) where 20 | toString := fun t : List Trace => t.foldl (fun s : String => fun t : Trace => s ++ s!"{t}\n") "\n" 21 | 22 | structure Derivation where 23 | startLine : Nat 24 | startColumn : Nat 25 | endLine : Nat 26 | endColumn : Nat 27 | parser : String 28 | deriving Repr, Inhabited, Nonempty 29 | 30 | instance : ToString Derivation where 31 | toString := fun t : Derivation => s!"{t.parser} ({t.startLine},{t.startColumn}):({t.endLine},{t.endColumn})" 32 | 33 | instance : ToString (List Derivation) where 34 | toString := fun t : List Derivation => t.foldl (fun s : String => fun t : Derivation => s ++ s!"{t}\n") "\n" 35 | 36 | structure ParsingState where 37 | source : String -- Source data being parsed 38 | index : Nat -- Index into source data 39 | stop : Nat 40 | lineNumber : Nat 41 | columnNumber : Nat 42 | trace : List Trace -- For debugging the parser 43 | derivations : List Derivation -- For debugging the parser 44 | report : List String 45 | deriving Repr, Inhabited, Nonempty 46 | 47 | abbrev PState (T : Type) := StateT ParsingState (Except (String × List Trace × List Derivation)) T 48 | 49 | def error (msg : String) : PState (String × (List Trace) × (List Derivation)) := do 50 | let st ← get 51 | let mut token := "" 52 | let mut started := false 53 | for i in [st.index:st.stop] do 54 | let c := if let some c := st.source.get? ⟨ i ⟩ then c else panic s!"Indexing error in ParsingState.error" 55 | if ! 
started then 56 | if c = ' ' || c = '\t' || c = '\n' then continue 57 | else 58 | started := true 59 | token := token.push c 60 | else if c = ' ' || c = '\t' || c = '\n' then break 61 | else token := token.push c 62 | let errorMsg := s!"Parsing error line {st.lineNumber}, column {st.columnNumber} : expected {msg} but found {token}" 63 | return (errorMsg, st.trace, st.derivations) 64 | 65 | def errorSimple (msg : String) : PState (String × (List Trace) × (List Derivation)) := do 66 | let st ← get 67 | let errorMsg := s!"Parsing error line {st.lineNumber}, column {st.columnNumber} : {msg}" 68 | return (errorMsg, st.trace, st.derivations) 69 | 70 | def report (msg : String) : PState Unit := do 71 | let st ← get 72 | let msg := s!"line {st.lineNumber}, column {st.columnNumber}: {msg}\n" 73 | set { st with report := msg :: st.report} 74 | 75 | def skipComment (index : Nat) (st : ParsingState) : Nat := Id.run do 76 | let mut count := 0 77 | for i in [index:st.stop] do 78 | let c := st.source.get! ⟨ i ⟩ 79 | count := count + 1 80 | if c = '\n' then 81 | break 82 | return count 83 | 84 | def skip : PState Unit := do 85 | let st ← get 86 | let mut count := 0 87 | let mut lines := 0 88 | let mut column := st.columnNumber 89 | for i in [st.index:st.stop] do 90 | let c := st.source.get! ⟨ i ⟩ 91 | if i < st.index + count then continue else if c = '\n' then 92 | count := count + 1 93 | lines := lines + 1 94 | column := 0 95 | else if c = ' ' then 96 | count := count + 1 97 | column := column + 1 98 | else if c = '\t' then 99 | count := count + 1 100 | column := column + 8 101 | else if c = '/' && i + 1 < st.stop && st.source.get! ⟨ i + 1 ⟩ = '/' then 102 | count := count + skipComment (st.index + count) (← get) 103 | lines := lines + 1 104 | column := 0 105 | else break 106 | set { st with 107 | index := st.index + count, 108 | lineNumber := st.lineNumber + lines, 109 | columnNumber := column 110 | } 111 | 112 | def done?
: PState Bool := do 113 | skip 114 | let st ← get 115 | if st.index >= st.stop then return true else return false 116 | 117 | def parseItem (keyword : String) : PState Unit := do 118 | skip 119 | let st ← get 120 | let sub : Substring := { str := st.source, startPos := ⟨ st.index ⟩ , stopPos := ⟨ st.index + keyword.length ⟩ } 121 | if sub.beq keyword.toSubstring then 122 | set { st with 123 | index := st.index + keyword.length, 124 | columnNumber := st.columnNumber + keyword.length 125 | } 126 | else 127 | throw <| ← error keyword 128 | 129 | def is (keyword : String) : PState Bool := do 130 | skip 131 | let st ← get 132 | let sub : Substring := { str := st.source, startPos := ⟨ st.index ⟩ , stopPos := ⟨ st.index + keyword.length ⟩ } 133 | return sub.beq keyword.toSubstring 134 | 135 | def isParse (keyword : String) : PState Bool := do 136 | skip 137 | let st ← get 138 | let sub : Substring := { str := st.source, startPos := ⟨ st.index ⟩ , stopPos := ⟨ st.index + keyword.length ⟩ } 139 | if sub.beq keyword.toSubstring then 140 | set { st with 141 | index := st.index + keyword.length, 142 | columnNumber := st.columnNumber + keyword.length 143 | } 144 | return true 145 | else 146 | return false 147 | 148 | def isDigit : PState Bool := do 149 | skip 150 | let st ← get 151 | let c := st.source.get! ⟨ st.index ⟩ 152 | return c.isDigit 153 | 154 | def isChar (c : Char) : PState Bool := do 155 | skip 156 | let st ← get 157 | let c' := st.source.get! ⟨ st.index ⟩ 158 | return c = c' 159 | 160 | def parseItems (keywords : List String) : PState Unit := do 161 | for i in [:keywords.length] do 162 | parseItem <| keywords.get! i 163 | 164 | def parseFId : PState String := do 165 | skip 166 | let st ← get 167 | let mut token := "" 168 | for i in [st.index:st.stop] do 169 | let c := st.source.get! ⟨ i ⟩ 170 | if c.isAlphanum || c = '_' || c = '.' 
|| c = '"' || c = '<' || c = '>' then token := token.push c 171 | else break 172 | if token.length != 0 then 173 | set { st with 174 | index := st.index + token.length, 175 | columnNumber := st.columnNumber + token.length 176 | } 177 | return token 178 | else 179 | throw <| ← error s!"Id" 180 | 181 | def parseId : PState String := do 182 | skip 183 | let st ← get 184 | let mut token := "" 185 | for i in [st.index:st.stop] do 186 | let c := st.source.get! ⟨ i ⟩ 187 | if c.isAlphanum || c = '_' || c = '.' then token := token.push c 188 | else break 189 | if token.length != 0 then 190 | set { st with 191 | index := st.index + token.length, 192 | columnNumber := st.columnNumber + token.length 193 | } 194 | return token 195 | else 196 | throw <| ← error s!"Id" 197 | 198 | def parseDecimal : PState Nat := do 199 | skip 200 | let st ← get 201 | let mut token := "" 202 | for i in [st.index:st.stop] do 203 | let c := st.source.get! ⟨ i ⟩ 204 | if c.isDigit then token := token.push c 205 | else break 206 | if token.length != 0 then 207 | set { st with 208 | index := st.index + token.length, 209 | columnNumber := st.columnNumber + token.length 210 | } 211 | return token.toNat! 212 | else 213 | throw <| ← error s!"Decimal" 214 | 215 | def isHexDigit (c : Char) : Bool := 216 | c.val ≥ 48 && c.val ≤ 57 || c.val ≥ 65 && c.val ≤ 70 || c.val ≥ 97 && c.val ≤ 102 217 | 218 | def toNatHex (s : String) : Nat := 219 | let r := s.foldl (fun n c => n*16 + ( 220 | if c.isDigit then c.toNat - '0'.toNat 221 | else 222 | if c.val <= 70 then 10 + (c.toNat - 'A'.toNat) 223 | else 10 + (c.toNat - 'a'.toNat))) 0 224 | r 225 | 226 | def parseHexaDecimal : PState Nat := do 227 | skip 228 | parseItem "0x" 229 | let st ← get 230 | let mut token := "" 231 | for i in [st.index:st.stop] do 232 | let c := st.source.get! 
⟨ i ⟩ 233 | if isHexDigit c then token := token.push c 234 | else break 235 | if token.length != 0 then 236 | set { st with 237 | index := st.index + token.length, 238 | columnNumber := st.columnNumber + token.length 239 | } 240 | return toNatHex token 241 | else 242 | throw <| ← error s!"HexaDecimal" 243 | 244 | def parseString : PState String := do 245 | skip 246 | parseItem "\"" 247 | let st ← get 248 | let mut token := "" 249 | let mut escaped := false 250 | for i in [st.index:st.stop] do 251 | let c := st.source.get! ⟨ i ⟩ 252 | if c = '"' then 253 | if escaped then 254 | token := token.push c 255 | escaped := false 256 | else 257 | break 258 | else if c = '\\' then 259 | escaped := !escaped 260 | token := token.push c 261 | else 262 | token := token.push c 263 | escaped := false 264 | set { st with 265 | index := st.index + token.length, 266 | columnNumber := st.columnNumber + token.length 267 | } 268 | parseItem "\"" 269 | return token 270 | 271 | partial def parseListOneorMoreAux (separator : String) (parse : PState T) (acc : List T) : PState (List T) := do 272 | if ← isParse separator then 273 | parseListOneorMoreAux separator parse ((← parse) :: acc) 274 | else return acc.reverse 275 | 276 | partial def parseListOneorMore (separator : String) (parse : PState T) : PState (List T) := do 277 | let head ← parse 278 | let tail ← parseListOneorMoreAux separator parse [] 279 | return head :: tail 280 | 281 | partial def parseListAux' (closingMark : String) (separator : String) (parse : PState T) (acc : List T) : PState (List T) := do 282 | if ← is closingMark then return acc.reverse 283 | if ← isParse separator then 284 | parseListAux' closingMark separator parse ((← parse) :: acc) 285 | else 286 | parseListAux' closingMark separator parse ((← parse) :: acc) 287 | 288 | partial def parseListAux (closingMark : String) (separator : String) (parse : PState T) : PState (List T) := do 289 | parseListAux' closingMark separator parse [] 290 | 291 | def
parseList (openingMark closingMark : String) (separator : String) (parse : PState T) : PState (List T) := do 292 | parseItem openingMark 293 | let attrs ← parseListAux closingMark separator parse 294 | parseItem closingMark 295 | return attrs 296 | 297 | partial def parseListAuxNoSep (closingMark : String) (parse : PState T) (acc : List T) : PState (List T) := do 298 | if ← is closingMark then return acc.reverse 299 | parseListAuxNoSep closingMark parse ((← parse) :: acc) 300 | 301 | def parseListNoSep (openingMark closingMark : String) (parse : PState T) : PState (List T) := do 302 | parseItem openingMark 303 | let attrs ← parseListAuxNoSep closingMark parse [] 304 | parseItem closingMark 305 | return attrs 306 | 307 | def parseDecimals : PState (List Nat) := do 308 | parseList "[" "]" "," parseDecimal 309 | 310 | end StableHLO.Parsing 311 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Programs.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Modules 9 | 10 | namespace StableHLO.Parsing 11 | 12 | def parse (src : String) : Except (String × List Trace × List Derivation) (List Module × ParsingState) := do 13 | parseModules.run <| ParsingState.mk src 0 src.length 1 0 [] [] [] 14 | 15 | end StableHLO.Parsing 16 | -------------------------------------------------------------------------------- /SHerLOC/Parsing/Types.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Jean-Baptiste Tristan 5 | -/ 6 | import SHerLOC.AST1 7 | import SHerLOC.Parsing.Parser 8 | import SHerLOC.Parsing.Numbers 9 | 10 | namespace StableHLO.Parsing 11 | 12 | def tryParseIntegerType : PState (Option IntegerType) := do 13 | let mut r : Option IntegerType := none 14 | if ← isChar 'i' then { 15 | if ← isParse "i32" then r := some { sign := Signedness.signed , size := IntegerSize.b32 } 16 | if ← isParse "i64" then r := some { sign := Signedness.signed , size := IntegerSize.b64 } 17 | if ← isParse "i2" then r := some { sign := Signedness.signed , size := IntegerSize.b2 } 18 | if ← isParse "i4" then r := some { sign := Signedness.signed , size := IntegerSize.b4 } 19 | if ← isParse "i8" then r := some { sign := Signedness.signed , size := IntegerSize.b8 } 20 | if ← isParse "i16" then r := some { sign := Signedness.signed , size := IntegerSize.b16 } 21 | } 22 | if ← isParse "ui32" then r := some { sign := Signedness.unsigned , size := IntegerSize.b32 } 23 | if ← isParse "ui64" then r := some { sign := Signedness.unsigned , size := IntegerSize.b64 } 24 | if ← isParse "ui2" then r := some { sign := Signedness.unsigned , size := IntegerSize.b2 } 25 | if ← isParse "ui4" then r := some { sign := Signedness.unsigned , size := IntegerSize.b4 } 26 | if ← isParse "ui8" then r := some { sign := Signedness.unsigned , size := IntegerSize.b8 } 27 | if ← isParse "ui16" then r := some { sign := Signedness.unsigned , size := IntegerSize.b16 } 28 | return r 29 | 30 | def parseIntegerType : PState IntegerType := do 31 | if let some r ← tryParseIntegerType then return r 32 | else throw <| ← error "Integer type" 33 | 34 | def tryParseFloatType : PState (Option FloatType) := do 35 | let mut r : Option FloatType := none 36 | if ← isChar 'f' then { 37 | if ← isParse "f16" then r := some FloatType.f16 38 | if ← isParse "f32" then r := some FloatType.f32 39 | if ← isParse "f64" then r := some FloatType.f64 40 | if ← isParse "f8E3M4" then r := some FloatType.f8E3M4 41 | if ← 
isParse "f8E4M3B11FNUZ" then r := some FloatType.f8E4M3B11FNUZ 42 | if ← isParse "f8E4M3FNUZ" then r := some FloatType.f8E4M3FNUZ 43 | if ← isParse "f8E4M3FN" then r := some FloatType.f8E4M3FN 44 | if ← isParse "f8E4M3" then r := some FloatType.f8E4M3 45 | if ← isParse "f8E5M2FNUZ" then r := some FloatType.f8E5M2FNUZ 46 | if ← isParse "f8E5M2" then r := some FloatType.f8E5M2 47 | } 48 | if ← isParse "bf16" then r := some FloatType.bf16 49 | if ← isParse "tf32" then r := some FloatType.tf32 50 | return r 51 | 52 | def parseFloatType : PState FloatType := do 53 | if let some r ← tryParseFloatType then return r 54 | else throw <| ← error "Float type" 55 | 56 | def parseNumberType : PState NumberType := do 57 | if let some r ← tryParseIntegerType then return NumberType.integerType r 58 | else if let some r ← tryParseFloatType then return NumberType.floatType r 59 | else throw <| ← error "Number type" 60 | 61 | def parseComplexElementType : PState ComplexType := do 62 | if ← isParse "f32" then return ComplexType.f32 63 | else if ← isParse "f64" then return ComplexType.f64 64 | else throw <| ← error "Complex element type" 65 | 66 | def parseComplexType : PState ComplexType := do 67 | parseItem "complex" 68 | parseItem "<" 69 | let t ← parseComplexElementType 70 | parseItem ">" 71 | return t 72 | 73 | def tryParseDimensionSize : PState (Option DimensionSize) := do 74 | let mut r := none 75 | if (← isDigit) then 76 | r := some <| DimensionSize.known <| ← parseDecimal 77 | if (← isParse "?") then 78 | r := some <| DimensionSize.unknown 79 | return r 80 | 81 | partial def parseShape : PState (List DimensionSize) := do 82 | if let some dim ← tryParseDimensionSize then 83 | parseItem "x" 84 | let dims ← parseShape 85 | return dim :: dims 86 | else 87 | return [] 88 | 89 | def parseTensorElementType : PState TensorElementType := do 90 | if let some r ← tryParseIntegerType then return TensorElementType.integerType r 91 | if ← isParse "i1" then return 
TensorElementType.booleanType 92 | if ← is "complex" then return TensorElementType.complexType <| ← parseComplexType 93 | if let some r ← tryParseFloatType then return TensorElementType.floatType r 94 | throw <| ← error "TensorElementType" 95 | 96 | def parseQuantizationParameter : PState QuantizationParameter := do 97 | let quantizationScale ← parseFloatLiteral 98 | let mut quantizationZeroPoint := { sign := Sign.plus , decimal := 0 } 99 | if (← isParse ":") then 100 | quantizationZeroPoint ← parseIntegerLiteral 101 | let parseResult := 102 | { quantizationScale := quantizationScale, 103 | quantizationZeroPoint := quantizationZeroPoint 104 | } 105 | return parseResult 106 | 107 | def parseQuantizationParameters : PState (List QuantizationParameter) := do 108 | if ← is "{" then 109 | let quantizationParameters ← parseList "{" "}" "," parseQuantizationParameter 110 | return quantizationParameters 111 | else 112 | let quantizationParameter ← parseQuantizationParameter 113 | return [quantizationParameter] 114 | 115 | def parseQuantizedTensorElementType : PState QuantizedTensorElementType := do 116 | parseItem "!quant.uniform" 117 | parseItem "<" 118 | let quantizationStorageType ← parseIntegerType 119 | let mut quantizationStorageMinMax := none 120 | if ← isParse "<" then 121 | let min ← parseIntegerLiteral 122 | parseItem ":" 123 | let max ← parseIntegerLiteral 124 | quantizationStorageMinMax := some (min,max) 125 | parseItem ">" 126 | parseItem ":" 127 | let quantizationExpressedType ← parseFloatType 128 | let mut quantizationDimension := none 129 | if ← isParse ":" then 130 | quantizationDimension := some (← parseIntegerLiteral) 131 | parseItem "," 132 | let quantizationParameters ← parseQuantizationParameters 133 | parseItem ">" 134 | let quantizationBasics : QuantizationBasics := 135 | { quantizationStorageType := quantizationStorageType, 136 | quantizationStorageMinMax := quantizationStorageMinMax, 137 | quantizationExpressedType := quantizationExpressedType, 
138 | quantizationDimension := quantizationDimension 139 | } 140 | let parseResult : QuantizedTensorElementType := 141 | { quantizationBasics := quantizationBasics 142 | quantizationParameters := quantizationParameters 143 | } 144 | return parseResult 145 | 146 | def parseTensorElementTypeGen : PState TensorElementTypeGen := do 147 | if ← is "!quant.uniform" 148 | then 149 | let quantizedTensorElementType ← parseQuantizedTensorElementType 150 | return TensorElementTypeGen.quantized quantizedTensorElementType 151 | else 152 | let tensorElementType ← parseTensorElementType 153 | return TensorElementTypeGen.classic tensorElementType 154 | 155 | def parseTensorType : PState TensorType := do 156 | parseItem "tensor" 157 | parseItem "<" 158 | let shape ← parseShape 159 | let tensorElementTypeGen ← parseTensorElementTypeGen 160 | parseItem ">" 161 | return { shape := shape, tensorElementTypeGen := tensorElementTypeGen } 162 | 163 | def parseTokenType : PState ValueType := do 164 | parseItem "!stablehlo.token" 165 | return ValueType.tokenType 166 | 167 | mutual 168 | 169 | partial def parseTupleType : PState ValueType := do 170 | parseItem "tuple" 171 | let TupleElementTypes ← parseList "<" ">" "," parseValueType 172 | return ValueType.tupleType TupleElementTypes 173 | 174 | partial def parseValueType : PState ValueType := do 175 | if ← is "tensor" then return ValueType.tensorType <| ← parseTensorType 176 | else if ← is "tuple" then 177 | let r ← parseTupleType 178 | return r 179 | else if ← is "!stablehlo.token" then 180 | let r ← parseTokenType 181 | return r 182 | else throw <| ← error "Value Type" 183 | 184 | end 185 | 186 | def parseValueTypesOutput : PState (List ValueType) := do 187 | let mut valueTypes : List ValueType := [] 188 | if ← is "(" then 189 | valueTypes ← parseList "(" ")" "," parseValueType 190 | else 191 | let r ← parseValueType 192 | valueTypes := [r] 193 | return valueTypes 194 | 195 | def parseValueTypes : PState (List ValueType) := do 196 | 
parseList "(" ")" "," parseValueType 197 | 198 | 199 | def parseFunctionType : PState FunctionType := do 200 | let inputTypes ← parseValueTypes 201 | parseItem "-" 202 | parseItem ">" 203 | let outputType ← parseValueTypesOutput 204 | return { domain := inputTypes, range := outputType } 205 | 206 | def parseStringType : PState NonValueType := do 207 | parseItem "string" 208 | return NonValueType.stringType 209 | 210 | def parseType : PState SType := do 211 | if (← is "tensor") || (← is "tuple") || (← is "!stablehlo.token") then 212 | return SType.valueType <| ← parseValueType 213 | return SType.nonValueType <| NonValueType.tensorElementType <| ← parseTensorElementType 214 | 215 | end StableHLO.Parsing 216 | -------------------------------------------------------------------------------- /lake-manifest.json: -------------------------------------------------------------------------------- 1 | {"version": "1.1.0", 2 | "packagesDir": ".lake/packages", 3 | "packages": [], 4 | "name": "SHerLOC", 5 | "lakeDir": ".lake"} 6 | -------------------------------------------------------------------------------- /lakefile.lean: -------------------------------------------------------------------------------- 1 | import Lake 2 | open Lake DSL 3 | 4 | package "SHerLOC" where 5 | -- add package configuration options here 6 | 7 | lean_lib «SHerLOC» where 8 | -- add library configuration options here 9 | 10 | @[default_target] 11 | lean_exe "sherloc" where 12 | root := `Main 13 | -------------------------------------------------------------------------------- /lean-toolchain: -------------------------------------------------------------------------------- 1 | leanprover/lean4:v4.10.0 2 | -------------------------------------------------------------------------------- /prepare_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! 
-d stablehlo ] 4 | then 5 | echo "The stablehlo directory is missing" 6 | echo "To prepare tests, install stablehlo: https://github.com/openxla/stablehlo" 7 | exit 1 8 | fi 9 | 10 | if [ ! -f stablehlo/build/bin/stablehlo-opt ] 11 | then 12 | echo "Missing: stablehlo/build/bin/stablehlo-opt" 13 | exit 1 14 | fi 15 | 16 | shopt="stablehlo/build/bin/stablehlo-opt -mlir-print-op-generic -split-input-file" 17 | 18 | interpret_test=stablehlo/stablehlo/tests/interpret 19 | 20 | if [ ! -d $interpret_test ] 21 | then 22 | echo "Missing tests: $interpret_test" 23 | exit 1 24 | fi 25 | 26 | for test in `ls $interpret_test/*.mlir` 27 | do 28 | name=interpret_`basename $test` 29 | $shopt $test > Tests/$name 30 | done 31 | 32 | test_data=stablehlo/stablehlo/testdata 33 | 34 | if [ ! -d $test_data ] 35 | then 36 | echo "Missing tests: $test_data" 37 | exit 1 38 | fi 39 | 40 | for test in `ls $test_data/*.mlir` 41 | do 42 | name=testdata_`basename $test` 43 | $shopt $test > Tests/$name 44 | done 45 | 46 | test_data_dynamic=stablehlo/stablehlo/testdata/dynamic 47 | 48 | if [ ! -d $test_data_dynamic ] 49 | then 50 | echo "Missing tests: $test_data_dynamic" 51 | exit 1 52 | fi 53 | 54 | for test in `ls $test_data_dynamic/*.mlir` 55 | do 56 | name=testdata__dynamic_`basename $test` 57 | $shopt $test > Tests/$name 58 | done 59 | 60 | test_data_quantized=stablehlo/stablehlo/testdata/quantized 61 | 62 | if [ ! -d $test_data_quantized ] 63 | then 64 | echo "Missing tests: $test_data_quantized" 65 | exit 1 66 | fi 67 | 68 | for test in `ls $test_data_quantized/*.mlir` 69 | do 70 | name=testdata_quantized_`basename $test` 71 | $shopt $test > Tests/$name 72 | done 73 | --------------------------------------------------------------------------------