├── interop ├── nki │ ├── __init__.py │ ├── isa │ │ ├── constants.py │ │ └── __init__.py │ ├── typing.py │ └── language │ │ └── __init__.py ├── klr │ ├── .clang-format │ ├── .gitignore │ ├── reformat.sh │ ├── peg_parser │ │ ├── README.md │ │ ├── Tokens │ │ └── compat.c │ ├── region.h │ ├── klir_common.hpp │ ├── _cli.py │ ├── peg_parser.c │ ├── __init__.py │ ├── load_test.cpp │ ├── stdc.h │ ├── frontend.h │ ├── cbor.h │ ├── klir_common.cpp │ ├── README.md │ ├── Makefile │ ├── region.c │ ├── klr_ffi.c │ └── lean_ast.h ├── README.md ├── requirements.txt ├── MANIFEST.in ├── test │ ├── apis.py │ ├── examples │ │ ├── __init__.py │ │ ├── getting_started.py │ │ ├── prof.py │ │ ├── layout.py │ │ ├── tensor_addition.py │ │ ├── rmsnorm.py │ │ ├── transpose2d.py │ │ ├── average_pool.py │ │ ├── mm.py │ │ └── index.py │ ├── runner.py │ ├── test_nki_allocation.py │ ├── test_fstr.py │ ├── test_examples.py │ ├── eval │ │ └── kernels.py │ ├── test_enum.py │ ├── test_dyn_ap.py │ ├── test_memory.py │ ├── test_nki_isa_tensor_scalar.py │ ├── test_list.py │ ├── test_dict.py │ └── test_basic.py ├── notes.md ├── setup.py └── pyproject.toml ├── lean-toolchain ├── KLR ├── TGR │ ├── Basic.lean │ └── Dot.lean ├── Semantics.lean ├── K │ ├── Common.lean │ └── K3 │ │ ├── AST.lean │ │ └── DotK3.lean ├── Serde.lean ├── Core.lean ├── NKI │ ├── Typed.lean │ ├── Typed │ │ └── Types.lean │ ├── Patterns.lean │ └── SimplifyOperators.lean ├── Py.lean ├── NKI.lean ├── Util.lean ├── Util │ ├── Gensym.lean │ ├── NumBytesTest.lean │ ├── Meta.lean │ ├── BigArray.lean │ ├── ToBytesTest.lean │ ├── FromBytesTest.lean │ ├── Padding.lean │ ├── Common.lean │ ├── Plausible.lean │ ├── BitVec.lean │ ├── Hex.lean │ ├── Float.lean │ └── NumBytes.lean ├── Trace │ ├── Extension.lean │ └── Lang.lean ├── Semantics │ ├── Notation.lean │ ├── Float.lean │ └── Tactics.lean ├── Extract │ ├── Extract │ │ ├── ASDL.lean │ │ └── Python.lean │ └── Extract.lean ├── Py │ ├── Util.lean │ └── PosLemmas.lean ├── CompileHLO.lean ├── Serde │ 
├── File.lean │ └── Attr.lean ├── Trace.lean ├── Model.lean ├── Export.lean └── Core │ └── Pretty.lean ├── bin ├── klr ├── update-manifest ├── check-libs ├── license-header.txt ├── klr-gather ├── make-wheel └── rename-wheels ├── .gitignore ├── .github └── CODEOWNERS ├── docs └── disk_format.md ├── KLR.lean ├── lakefile.lean ├── lake-manifest.json └── README.md /interop/nki/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lean-toolchain: -------------------------------------------------------------------------------- 1 | leanprover/lean4:v4.23.0 2 | -------------------------------------------------------------------------------- /interop/klr/.clang-format: -------------------------------------------------------------------------------- 1 | SortIncludes: false 2 | -------------------------------------------------------------------------------- /interop/klr/.gitignore: -------------------------------------------------------------------------------- 1 | cbor_test 2 | simplify_test 3 | -------------------------------------------------------------------------------- /interop/README.md: -------------------------------------------------------------------------------- 1 | Python bindings for [KLR](https://github.com/leanprover/KLR) 2 | -------------------------------------------------------------------------------- /KLR/TGR/Basic.lean: -------------------------------------------------------------------------------- 1 | import KLR.TGR.AST 2 | import KLR.TGR.Compile 3 | import KLR.TGR.Dot 4 | import KLR.TGR.Py 5 | -------------------------------------------------------------------------------- /bin/klr: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT=$(dirname "$(dirname "$(readlink -f "$0")")") 4 | 5 | exec "$ROOT/.lake/build/bin/klr" "$@" 6 | 
-------------------------------------------------------------------------------- /interop/requirements.txt: -------------------------------------------------------------------------------- 1 | # This is the place to add pinned versions if we have any build troubles 2 | ml_dtypes 3 | numpy 4 | pytest 5 | -------------------------------------------------------------------------------- /interop/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # All files must be below the python project root 2 | include klr/README.md 3 | include klr/*.[ch] 4 | include klr/klr.bin 5 | graft klr/peg_parser 6 | -------------------------------------------------------------------------------- /interop/klr/reformat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -u -o pipefail 3 | trap "kill 0" SIGINT SIGTERM 4 | 5 | for f in klir_*.[hc]pp 6 | do 7 | clang-format -i $f 8 | done 9 | -------------------------------------------------------------------------------- /interop/klr/peg_parser/README.md: -------------------------------------------------------------------------------- 1 | This directory contains files from the CPython distribution. 2 | These files are not compiled directly. 3 | See comments in peg_parser.c for more details. 
4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | \.lake/ 2 | __pycache__/ 3 | *.o 4 | *.so 5 | *.dSYM 6 | .vscode/ 7 | *.egg-info/ 8 | build/ 9 | dist/ 10 | klr.bin 11 | /.wheel/ 12 | env 13 | *.klr 14 | *.klir 15 | interop/klr/load_test 16 | -------------------------------------------------------------------------------- /interop/nki/isa/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | class reduce_cmd(Enum): 3 | """Engine Register Reduce commands""" 4 | idle = 0 5 | reset = 1 6 | reset_reduce = 3 7 | reduce = 2 8 | 9 | 10 | -------------------------------------------------------------------------------- /interop/nki/typing.py: -------------------------------------------------------------------------------- 1 | 2 | def scalar(x): pass 3 | 4 | class tensor: 5 | def __init__(self, dtype, shape): 6 | self.dtype = dtype 7 | self.shape = shape 8 | 9 | class mutable(tensor): 10 | pass 11 | -------------------------------------------------------------------------------- /interop/test/apis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025, Amazon.com. All Rights Reserved 3 | 4 | """ 5 | 6 | # import APIs needed by test cases. 
7 | 8 | import numpy as np 9 | import nki 10 | import nki.isa as nisa 11 | import nki.language as nl 12 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners 2 | * @govereau @seanmcl @jtristan @joonwonc @ppotapov-aws @waahm7 @jiancheng-aws @yongweiy @kerrijoe-aws 3 | -------------------------------------------------------------------------------- /bin/update-manifest: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -u -o pipefail 3 | trap "kill 0" SIGINT SIGTERM 4 | 5 | # Update all of the manifest files in the repo 6 | 7 | for f in `fd lake-manifest.json` 8 | do 9 | d=`dirname $f` 10 | echo $d 11 | (cd $d; lake update) 12 | done 13 | -------------------------------------------------------------------------------- /interop/test/examples/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'getting_started', 3 | 'layout', 4 | 'index', 5 | 'mm', 6 | 'prof', 7 | 'average_pool', 8 | 'fused_mamba', 9 | 'layernorm', 10 | 'matmul', 11 | 'rmsnorm', 12 | 'sd_attention', 13 | 'tensor_addition', 14 | 'transpose2d', 15 | ] 16 | -------------------------------------------------------------------------------- /docs/disk_format.md: -------------------------------------------------------------------------------- 1 | # KLR On-Disk Format 2 | 3 | KLR uses [CBOR](https://cbor.io) as its on-disk format. 4 | 5 | The exact details are still in flux.
6 | 7 | ## Assignment of tags 8 | 9 | | Lean Type | Type Tag | 10 | |-|-| 11 | | Internal ASTs | 1 - 99 | 12 | | Core.* | 100 - 149 | 13 | | Core.Operators.* | 150 - 234 | 14 | | KLRMetaData | 235 (0xeb) | 15 | | Contents | 236 (0xec) | 16 | | Option | 255 (0xff) | 17 | -------------------------------------------------------------------------------- /KLR/Semantics.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Markus de Medeiros 5 | -/ 6 | 7 | import KLR.Semantics.Lib 8 | import KLR.Semantics.Memory 9 | import KLR.Semantics.NML 10 | import KLR.Semantics.Float 11 | import KLR.Semantics.KLR 12 | import KLR.Semantics.NKI 13 | import KLR.Semantics.Logic 14 | import KLR.Semantics.ProofRules 15 | import KLR.Semantics.Examples 16 | -------------------------------------------------------------------------------- /bin/check-libs: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -u -o pipefail 3 | trap "kill 0" SIGINT SIGTERM 4 | 5 | if [[ $# -ne 1 ]]; then 6 | echo "Usage: check-libs file" 7 | exit 1 8 | fi 9 | 10 | file=$1; shift 11 | os=$(uname) 12 | 13 | echo "Listing library dependencies of ${file}..." 14 | 15 | echo "Architecture: $(uname -m)" 16 | 17 | if [[ $os == "Darwin" ]]; then 18 | otool -L $file 19 | elif [[ $os == "Linux" ]]; then 20 | ldd $file 21 | else 22 | echo "Unsupported OS" 23 | exit 1 24 | fi 25 | -------------------------------------------------------------------------------- /KLR/K/Common.lean: -------------------------------------------------------------------------------- 1 | import KLR.Core 2 | 3 | namespace KLR.K 4 | 5 | /- A region of HBM, described as an offset from some named tensor as well as 6 | an access pattern. 
The access pattern is in source-order (slowest-first), not 7 | ISA-order (fastest-first) -/ 8 | structure HbmLocation (Scalar : Type) where 9 | name : String 10 | offset : Scalar 11 | pattern : List Core.APPair 12 | deriving Inhabited, Repr, BEq 13 | 14 | instance [ToString T] : ToString (HbmLocation T) where 15 | toString 16 | | .mk name offset pattern => s!"HbmLoc(at: {name}[{offset}], pattern: {repr pattern})" 17 | -------------------------------------------------------------------------------- /interop/nki/language/__init__.py: -------------------------------------------------------------------------------- 1 | # KLR implemetations of NKI langauge APIs 2 | 3 | class NKIObject: 4 | pass 5 | 6 | def copy(): pass 7 | def exp(): pass 8 | def add(): pass 9 | def multiply(): pass 10 | 11 | def par_dim(x): return x 12 | 13 | int32 = "int32" 14 | float16 = "float16" 15 | uint8 = "uint8" 16 | 17 | class int8(NKIObject): 18 | itemsize = 1 19 | 20 | def __init__(self, x): 21 | self.value = x 22 | 23 | def __str__(self): 24 | return "uint8" 25 | 26 | class float32(NKIObject): 27 | itemsize = 4 28 | 29 | class tile_size: 30 | pmax = 128 31 | psum_fmax = 128 32 | 33 | -------------------------------------------------------------------------------- /bin/license-header.txt: -------------------------------------------------------------------------------- 1 | Copyright KLR Contributors 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /interop/test/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from klr.frontend import Kernel 4 | 5 | def run_success(f, args): 6 | F = Kernel(f) # parse python 7 | j = json.loads(F.specialize(args)) 8 | if len(j['errors']) > 0: 9 | assert False, j['errors'][0] 10 | j = json.loads(F.trace("tmp.klr")) 11 | if len(j['errors']) > 0: 12 | assert False, j['errors'][0] 13 | os.remove("tmp.klr") 14 | 15 | def run_fail(f, args): 16 | F = Kernel(f) # parse python 17 | j = json.loads(F.specialize(args)) 18 | if len(j['errors']) > 0: 19 | return 20 | j = json.loads(F.trace("tmp.klr")) 21 | if len(j['errors']) > 0: 22 | return 23 | os.remove("tmp.klr") 24 | assert False, "expecting failure" 25 | -------------------------------------------------------------------------------- /KLR/Serde.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import KLR.Serde.Attr 18 | import KLR.Serde.Basic 19 | import KLR.Serde.Elab 20 | import KLR.Serde.File 21 | import KLR.Serde.Test 22 | -------------------------------------------------------------------------------- /KLR/Core.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import KLR.Core.Basic 18 | -- TODO fix pretty printer 19 | --import KLR.Core.Pretty 20 | import KLR.Core.Indexing 21 | import KLR.Core.LowerAP 22 | -------------------------------------------------------------------------------- /KLR/NKI/Typed.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import KLR.NKI.Typed.Basic 18 | import KLR.NKI.Typed.Context 19 | import KLR.NKI.Typed.Elab 20 | import KLR.NKI.Typed.Test 21 | import KLR.NKI.Typed.Types 22 | -------------------------------------------------------------------------------- /KLR/Py.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import KLR.Py.Basic 18 | import KLR.Py.Parser 19 | import KLR.Py.PosLemmas 20 | import KLR.Py.Pretty 21 | import KLR.Py.Test 22 | import KLR.Py.Tokenizer 23 | import KLR.Py.Util 24 | -------------------------------------------------------------------------------- /bin/klr-gather: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run the NKI frontend's gather step from a local build 4 | # 5 | # In order for this to work, you must set PY_VER to a supported 6 | # python version, and you must have a executable named python${PY_VER} 7 | # pointing to that Python interpreter. 8 | # If you installed python through brew, then you should already have 9 | # this. The default is version 3.10, this choice is arbitrary. 10 | # 11 | # You must also have run make in the interop/klr directory for the 12 | # python version that you are using. 
For example: 13 | # 14 | # cd interop/klr 15 | # make PY_VER=3.10 16 | # 17 | # Will build the Python 3.10 version of the NKI frontend. 18 | 19 | set -e 20 | 21 | V=${PY_VER:-3.10} 22 | 23 | PYTHONPATH=interop:${PYTHONPATH} \ 24 | python${V} -c 'import klr._cli; klr._cli.gather()' $@ 25 | -------------------------------------------------------------------------------- /KLR.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import KLR.Compile 18 | import KLR.Core 19 | import KLR.Export 20 | import KLR.File 21 | import KLR.NKI 22 | --import KLR.Py 23 | import KLR.Python 24 | import KLR.Serde 25 | import KLR.Trace 26 | import KLR.Util 27 | --import KLR.Semantics 28 | -------------------------------------------------------------------------------- /interop/test/examples/getting_started.py: -------------------------------------------------------------------------------- 1 | from apis import * 2 | 3 | def nki_tensor_add_kernel(a_input, b_input, c_output): 4 | """NKI kernel to compute element-wise addition of two input tensors 5 | """ 6 | 7 | # Check all input/output tensor shapes are the same for element-wise operation 8 | assert a_input.shape == b_input.shape == c_output.shape 9 | 10 | # Check size of the first dimension does not exceed on-chip memory tile size limit, 11 | # so that we don't need to tile the input to keep this example simple 12 | assert a_input.shape[0] <= nl.tile_size.pmax 13 | 14 | # Load the inputs from device memory to on-chip memory 15 | a_tile = nl.load(a_input) 16 | b_tile = nl.load(b_input) 17 | 18 | # Specify the computation (in our case: a + b) 19 | c_tile = nl.add(a_tile, b_tile) 20 | 21 | # Store the result to c_output from on-chip memory to device memory 22 | nl.store(c_output, value=c_tile) 23 | -------------------------------------------------------------------------------- /interop/test/test_nki_allocation.py: -------------------------------------------------------------------------------- 1 | # tests of NKI allocation API 2 | 3 | import os 4 | import pytest 5 | 6 | from apis import * 7 | 8 | from klr.frontend import Kernel 9 | 10 | # Success cases 11 | # (these functions should load and trace to KLR) 12 | 13 | def simple(): 14 | tensor1 = nl.ndarray((32,32), np.uint8, buffer=nl.sbuf, name="t") 15 | tensor2 = nl.ndarray((32,8), np.float32, buffer=nl.sbuf) 16 | 17 | # test each function in turn 18 | @pytest.mark.parametrize("f", [ 19 | 
simple, 20 | ]) 21 | def test_succeed(f): 22 | F = Kernel(f) # parse python 23 | F.specialize() 24 | F.trace("tmp.klr") 25 | if os.path.exists("tmp.klr"): 26 | os.remove("tmp.klr") 27 | 28 | # Failing cases 29 | # (These functions are expected to fail elaboration to KLR) 30 | 31 | def too_large(): 32 | nl.ndarray((512,16), np.float32, buffer=nl.sbuf) 33 | 34 | @pytest.mark.parametrize("f", [ 35 | too_large, 36 | ]) 37 | def test_fails(f): 38 | F = Kernel(f) 39 | with pytest.raises(Exception): 40 | F() 41 | -------------------------------------------------------------------------------- /interop/test/examples/prof.py: -------------------------------------------------------------------------------- 1 | from apis import * 2 | import math 3 | 4 | def tensor_exp_kernel_(in_tensor, out_tensor): 5 | """NKI kernel to compute elementwise exponential of an input tensor 6 | Args: 7 | in_tensor: an input tensor of ANY 2D shape (up to SBUF size) 8 | out_tensor: an output tensor of ANY 2D shape (up to SBUF size) 9 | """ 10 | sz_p, sz_f = in_tensor.shape 11 | i_f = nl.arange(sz_f)[None, :] 12 | for p in nl.affine_range(math.ceil(sz_p / nl.tile_size.pmax)): 13 | # Generate tensor indices for the input/output tensors 14 | # pad index to pmax, for simplicity 15 | i_p = p * nl.tile_size.pmax + nl.arange(nl.tile_size.pmax)[:, None] 16 | # Load input data from external memory to on-chip memory 17 | # only read up to sz_p 18 | in_tile = nl.load(in_tensor[i_p, i_f], mask=(i_p" where is incremented 16 | until a unique name is found. 
-/ 17 | partial def GensymEnv.gensym (env : GensymEnv) (suggestion : String) : (String × GensymEnv) := Id.run do 18 | let mut env := env 19 | if !env.gensymUsed.contains suggestion then 20 | env := {env with gensymUsed := env.gensymUsed.insert suggestion } 21 | return (suggestion, env) 22 | 23 | let mut name := s!"gs_{env.gensymCounter}" 24 | repeat do 25 | if !env.gensymUsed.contains name then 26 | env := { env with gensymUsed := env.gensymUsed.insert name, gensymCounter := env.gensymCounter + 1 } 27 | break 28 | else 29 | env := { env with gensymCounter := env.gensymCounter + 1 } 30 | name := s!"gs_{env.gensymCounter}" 31 | (name, env) 32 | -------------------------------------------------------------------------------- /interop/nki/isa/__init__.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | # KLR implemetations of NKI ISA APIs 3 | 4 | class reduce_cmd(Enum): 5 | """Engine Register Reduce commands""" 6 | idle = 0 7 | reset = 1 8 | reset_reduce = 3 9 | reduce = 2 10 | 11 | class dge_mode(Enum): 12 | none = 0 13 | swde = 1 14 | hwge = 2 15 | unknown = 3 16 | 17 | def psum_raw_ptr(address, size): 18 | p_start, f_start = address 19 | p_size, f_size = size 20 | p_end = p_start + p_size 21 | f_end = f_start + f_size 22 | return psum[p_start:p_end, f_start:f_end] 23 | 24 | def sbuf_raw_ptr(address, size): 25 | p_start, f_start = address 26 | p_size, f_size = size 27 | p_end = p_start + p_size 28 | f_end = f_start + f_size 29 | return sbuf[p_start:p_end, f_start:f_end] 30 | 31 | class reduce_cmd(Enum): 32 | """Engine Register Reduce commands """ 33 | idle = 0 34 | reset = 1 35 | reset_reduce = 2 36 | reduce = 3 37 | load_reduce = 4 38 | 39 | def psum_raw_ptr(address, size): 40 | p_start, f_start = address 41 | p_size, f_size = size 42 | p_end = p_start + p_size 43 | f_end = f_start + f_size 44 | return psum[p_start:p_end, f_start:f_end] 45 | 46 | def sbuf_raw_ptr(address, size): 47 | p_start, f_start = 
address 48 | p_size, f_size = size 49 | p_end = p_start + p_size 50 | f_end = f_start + f_size 51 | return sbuf[p_start:p_end, f_start:f_end] 52 | -------------------------------------------------------------------------------- /interop/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import Extension, setup 2 | import os 3 | import sys 4 | 5 | # MACOSX_DEPLOYMENT_TARGET tells compiler the oldest macOS version to support. 6 | # It also determines the "macosx_13_0" part of the wheel name, 7 | # and pip won't install that wheel on macOS older than that. 8 | # If we don't set MACOSX_DEPLOYMENT_TARGET explicitly, setuptools will set 9 | # it to whatever your python installation was built with, 10 | # and if you used the python.org installer it's something ancient like 10.9, 11 | # which causes compiler warnings about modern-ish things like aligned_alloc(). 12 | if sys.platform == 'darwin' and not os.getenv('MACOSX_DEPLOYMENT_TARGET'): 13 | os.environ['MACOSX_DEPLOYMENT_TARGET'] = '13.0' 14 | 15 | # Note, because we are building an extension module, we will get arch specific wheels. 16 | # This is important because we are putting an arch-specific binary into the wheel 17 | # as an "extra" file. 
18 | # 19 | # TODO: move this to pyproject.toml once Brazil supports newer versions of setuptools 20 | setup( 21 | ext_modules=[ 22 | Extension( 23 | name="klr.frontend", 24 | sources=[ 25 | "klr/cbor.c", 26 | "klr/frontend.c", 27 | "klr/gather.c", 28 | "klr/peg_parser.c", 29 | "klr/region.c", 30 | "klr/serde.c", 31 | "klr/serde_common.c", 32 | "klr/serde_file.c", 33 | "klr/serde_python_core.c", 34 | ], 35 | ), 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /interop/test/examples/layout.py: -------------------------------------------------------------------------------- 1 | from apis import * 2 | 3 | def tensor_exp_kernel_1(in_tensor, out_tensor): 4 | """NKI kernel to compute elementwise exponential of an input tensor 5 | 6 | Args: 7 | in_tensor: an input tensor of shape [256,512] 8 | out_tensor: an output tensor of shape [256,512] 9 | """ 10 | i_f = nl.arange(512)[None, :] 11 | 12 | for k in nl.affine_range(2): 13 | # Generate tensor indices for the input/output tensors 14 | i_p = k * nl.tile_size.pmax + nl.arange(nl.tile_size.pmax)[:, None] 15 | 16 | # Load input data from HBM to on-chip memory 17 | in_tile = nl.load(in_tensor[i_p, i_f]) 18 | 19 | # perform the computation 20 | out_tile = nl.exp(in_tile) 21 | 22 | # store the results back to HBM 23 | nl.store(out_tensor[i_p, i_f], value=out_tile) 24 | 25 | 26 | def tensor_exp_kernel_2(in_tensor, out_tensor): 27 | """NKI kernel to compute elementwise exponential of an input tensor 28 | 29 | Args: 30 | in_tensor: an input tensor of shape [128,512] 31 | out_tensor: an output tensor of shape [128,512] 32 | """ 33 | # Generate indices for the input/output tensors 34 | i_p = nl.arange(128)[:, None] 35 | i_f = nl.arange(512)[None, :] 36 | 37 | # Load input data from HBM to on-chip memory 38 | in_tile = nl.load(in_tensor[i_p, i_f]) 39 | 40 | # perform the computation: 41 | out_tile = nl.exp(in_tile) 42 | 43 | # store the results back to HBM 44 | 
nl.store(out_tensor[i_p, i_f], value=out_tile) 45 | -------------------------------------------------------------------------------- /interop/klr/klir_common.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Paul Govereau, Sean McLaughlin 5 | */ 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | 11 | namespace klr { 12 | 13 | #define check_size(s, n) \ 14 | static_assert(sizeof(s) == n, "sizeof " #s " unexpected") 15 | 16 | typedef uint8_t u8; 17 | typedef uint64_t u64; 18 | typedef int64_t i64; 19 | 20 | typedef bool Bool; 21 | typedef int32_t Int; 22 | typedef uint32_t Nat; 23 | typedef float Float; 24 | 25 | check_size(Float, 4); 26 | 27 | using String = std::string; 28 | 29 | struct Prop {}; 30 | 31 | template using Ptr = std::shared_ptr; 32 | 33 | template Ptr ptr() { return std::make_shared(); } 34 | 35 | template using List = std::list; 36 | 37 | // template List list() { return ptr>(); } 38 | 39 | template using Option = std::optional; 40 | 41 | bool deserialize_tag(FILE *in, u8 *type, u8 *constructor, u8 *len); 42 | bool deserialize_array_start(FILE *in, u64 *size); 43 | bool deserialize_option(FILE *in, bool *isSome); 44 | 45 | Prop Prop_des(FILE *out); 46 | Bool Bool_des(FILE *out); 47 | Nat Nat_des(FILE *out); 48 | Int Int_des(FILE *out); 49 | Float Float_des(FILE *out); 50 | String String_des(FILE *out); 51 | 52 | } // namespace klr 53 | -------------------------------------------------------------------------------- /interop/klr/_cli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # Released under Apache 2.0 license as described in the file LICENSE. 
3 | # Authors: Paul Govereau, Sean McLaughlin 4 | 5 | import importlib 6 | import importlib.resources 7 | import klr.frontend as fe 8 | import os 9 | import subprocess 10 | import sys 11 | 12 | # This function is only used from within a pip-installed environment 13 | # Local developers can use ./bin/klr from the github root 14 | def run_klr(): 15 | # FIXME: Perhaps should use the scripts directory? How do we do that? 16 | # see https://packaging.python.org/en/latest/specifications/binary-distribution-format/ 17 | bin = importlib.resources.files('klr').joinpath('klr.bin') 18 | args = [bin] + sys.argv[1:] 19 | cp = subprocess.run(args) 20 | sys.exit(cp.returncode) 21 | 22 | # This function is used internally by the klr binary. 23 | # For a pip-installed environment, this is available as script "klr-gather" 24 | # For local developers, this is called by ./bin/klr-gather 25 | def gather(): 26 | if len(sys.argv) != 4: 27 | print(f"Usage: {sys.argv[0]} module function outfile", file=sys.stderr) 28 | sys.exit(1) 29 | 30 | _, module, fn, outfile = sys.argv 31 | try: 32 | m = importlib.import_module(module) 33 | f = getattr(m, fn) 34 | K = fe.Kernel(f) 35 | with open(outfile, "w") as f: 36 | f.write(K.serialize_python()) 37 | except Exception as e: 38 | print(str(e), file=sys.stderr) 39 | sys.exit(1) 40 | -------------------------------------------------------------------------------- /interop/klr/peg_parser.c: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Paul Govereau, Sean McLaughlin 5 | */ 6 | #include "frontend.h" 7 | #include "ast_python.h" 8 | 9 | /* 10 | Many of the things in peg_parser are also in libpython, some private, some 11 | public. 
To avoid name conflicts, we create a single compilation unit with 12 | everything declared static, except for the two functions at the bottom of 13 | this file. This ensures our version of the PEG parser will not get confused 14 | with the version in the user's libpython. 15 | */ 16 | 17 | #include "peg_parser/compat.c" 18 | #include "peg_parser/token.c" 19 | #include "peg_parser/ast_python.c" 20 | #include "peg_parser/tokenizer.c" 21 | #include "peg_parser/pegen.c" 22 | #include "peg_parser/string_parser.c" 23 | #include "peg_parser/action_helpers.c" 24 | #include "peg_parser/parser.c" 25 | 26 | // -- Public interface to our version of the PEG parser 27 | 28 | struct _mod* parse_string(const char *str, PyObject* filename) { 29 | PyArena *arena = _PyArena_New(); 30 | if (!arena) 31 | return NULL; 32 | 33 | struct _mod *result = _PyPegen_run_parser_from_string(str, filename, arena); 34 | if (result) 35 | result->arena = arena; 36 | else 37 | _PyArena_Free(arena); 38 | return result; 39 | } 40 | 41 | // Everything (including m) is in the region 42 | void free_python_ast(struct _mod *m) { 43 | if (m) 44 | _PyArena_Free(m->arena); 45 | } 46 | -------------------------------------------------------------------------------- /interop/klr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # Released under Apache 2.0 license as described in the file LICENSE. 3 | # Authors: Paul Govereau, Sean McLaughlin 4 | 5 | import json 6 | from tempfile import NamedTemporaryFile 7 | from typing import Optional, Sequence, Union 8 | 9 | from . 
import frontend 10 | 11 | class NKIObject: 12 | pass 13 | 14 | # wrapper around kernel.specialize() 15 | def _specialize_kernel( 16 | kernel: frontend.Kernel, 17 | args: Optional[tuple] = None, 18 | kwargs: Optional[dict] = None, 19 | *, 20 | arch: int, 21 | grid: Optional[int] = None, 22 | schedule: Optional[Sequence[tuple[str, Union[str, Sequence[str]]]]] = None, 23 | address_rotation: bool = False, 24 | ): 25 | flags = [("address_rotation", address_rotation)] 26 | metadata_json_str = kernel.specialize(args, kwargs, arch, grid, schedule, flags) 27 | metadata = json.loads(metadata_json_str) 28 | return metadata 29 | 30 | # wrapper around tracing step via KLR's Lean FFI. 31 | def _trace_kernel( 32 | kernel: frontend.Kernel, 33 | *, 34 | dst_filepath: str, 35 | ) -> dict: 36 | """ 37 | Trace Python to KLIR 38 | 39 | Returns: dict of metadata 40 | """ 41 | metadata_json_str = kernel.trace(dst_filepath) 42 | metadata = json.loads(metadata_json_str) 43 | with open(metadata.get("errors")) as f: 44 | rv = f.read() 45 | if rv: 46 | raise Exception("Error(s) during tracing" + rv) 47 | return metadata 48 | -------------------------------------------------------------------------------- /KLR/Trace/Extension.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import Lean 18 | 19 | /- 20 | Definition of an enviromnet extension for tracking compiler builtins. 21 | Environment extensions have to be declared in separate modules from where they are used. 22 | 23 | See also: 24 | Builtin.lean - population of the environment in the `nki` macro. 25 | Term.lean - use of the environment to construct builtin list. 26 | -/ 27 | 28 | namespace KLR.Trace 29 | open Lean 30 | 31 | structure Builtin where 32 | nkiName : Name 33 | leanName : Name 34 | deriving Repr 35 | 36 | structure Builtins where 37 | builtins : Array Builtin := #[] 38 | deriving Inhabited, Repr 39 | 40 | def addEntry (s : Builtins) (e : Builtin) := 41 | { s with builtins := s.builtins.push e } 42 | 43 | initialize extension : SimplePersistentEnvExtension Builtin Builtins <- 44 | registerSimplePersistentEnvExtension { 45 | addEntryFn := addEntry 46 | addImportedFn := fun es => (mkStateFromImportedEntries addEntry {} es) 47 | } 48 | -------------------------------------------------------------------------------- /interop/klr/load_test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | A simple example of using the C++ deserialization API. 
3 | This file can be compiled and run with: 4 | 5 | # gcc -std=c17 -Wall -c region.c cbor.c 6 | # g++ --std=c++17 -Wall load_test.cpp klir_serde.cpp klir_common.cpp cbor.o region.o 7 | # ./a.out test.klr 8 | */ 9 | 10 | #include 11 | #include "klir_ast.hpp" 12 | #include "klir_serde.hpp" 13 | 14 | using namespace std; 15 | using namespace klr; 16 | 17 | int main(int argc, char **argv) { 18 | if (argc != 2) { 19 | cout << "invalid args, expecting one KLR filename" << endl; 20 | return 1; 21 | } 22 | 23 | FILE *in = fopen(argv[1], "r"); 24 | if (!in) { 25 | perror("fopen"); 26 | return 1; 27 | } 28 | 29 | Ptr file = KLRFile_des(in); 30 | if (file->major != 0 || file->minor != 0 || file->patch != 12) 31 | throw runtime_error("Wrong KLR version"); 32 | 33 | cout << "KLR file header, version: " << 34 | file->major << "." << 35 | file->minor << "." << 36 | file->patch << endl; 37 | 38 | Ptr data = KLRMetaData_des(in); 39 | cout << "KLR meta data : " << 40 | data->format << endl; 41 | 42 | Ptr contents = Contents_des(in); 43 | cout << "KLR content type : " << 44 | static_cast(contents->tag) << endl; 45 | 46 | if (contents->tag != Contents::Tag::lnc) 47 | throw runtime_error("Wrong KLIR content type"); 48 | 49 | ContentsLncWrapper *w = static_cast(contents.get()); 50 | Ptr k = w->kernel; 51 | cout << "KLR Kernel : " << k << endl; 52 | 53 | fclose(in); 54 | 55 | return 0; 56 | } 57 | -------------------------------------------------------------------------------- /lakefile.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import Lake 18 | open Lake DSL 19 | 20 | package "KLR" where 21 | 22 | @[default_target] 23 | lean_lib "KLR" where 24 | defaultFacets := #[LeanLib.staticFacet] 25 | 26 | lean_lib "Extract" where 27 | srcDir := "KLR/Extract" 28 | 29 | @[default_target] 30 | lean_exe "klr" where 31 | nativeFacets := fun _ => #[Module.oFacet] 32 | root := `Main 33 | supportInterpreter := true 34 | 35 | require Cli from git 36 | "https://github.com/leanprover/lean4-cli.git" @ "v4.23.0" 37 | 38 | require plausible from git 39 | "https://github.com/leanprover-community/plausible" @ "v4.23.0" 40 | 41 | require TensorLib from git 42 | "https://github.com/leanprover/TensorLib.git" @ "v0.0.16" 43 | 44 | require SHerLOC from git 45 | "https://github.com/leanprover/SHerLOC.git" @ "c74ae090d4326cca9ff636184c330a67ca039ef6" 46 | 47 | -- Comment the above and uncomment this for local development 48 | -- require TensorLib from "../TensorLib" 49 | 50 | --require iris from git 51 | -- "https://github.com/markusdemedeiros/iris-lean.git" 52 | -------------------------------------------------------------------------------- /interop/klr/stdc.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Paul Govereau, Sean McLaughlin 5 | */ 6 | #pragma once 7 | 8 | // A simple header file to check the C version and bring C standard 9 | // definitions into scope. 
10 | 11 | #if defined(__STDC_NO_ATOMICS__) 12 | #error Compiler does not support atomic types 13 | #endif 14 | 15 | // Standard definitions (free standing) 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | // Standard integer types 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | // Standard C utilites (free standing) 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | #ifndef __has_builtin 38 | #define __has_builtin(x) 0 39 | #endif 40 | 41 | #if __has_builtin(__builtin_expect) 42 | #define likely(x) (__builtin_expect((x), 1)) 43 | #define unlikely(x) (__builtin_expect((x), 0)) 44 | #else 45 | #define likely(x) (x) 46 | #define unlikely(x) (x) 47 | #endif 48 | 49 | #define check_size(s,n) \ 50 | static_assert(sizeof(s) == n, "sizeof "#s" unexpected") 51 | 52 | typedef int8_t i8; 53 | typedef int16_t i16; 54 | typedef int32_t i32; 55 | typedef int64_t i64; 56 | 57 | typedef uint8_t u8; 58 | typedef uint16_t u16; 59 | typedef uint32_t u32; 60 | typedef uint64_t u64; 61 | 62 | typedef float f32; 63 | typedef double f64; 64 | 65 | check_size(f32, 4); 66 | check_size(f64, 8); 67 | -------------------------------------------------------------------------------- /interop/test/test_examples.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from examples import * 5 | from klr.frontend import Kernel 6 | 7 | 8 | @pytest.mark.parametrize("module,name", 9 | [ (getting_started, "nki_tensor_add_kernel"), 10 | (layout, "tensor_exp_kernel_1"), 11 | (layout, "tensor_exp_kernel_2"), 12 | (index, "tensor_split_kernel_"), 13 | (index, "tensor_maxpool_kernel_"), 14 | (mm, "matmul_128x128x512_spmd_nisa"), 15 | (mm, "matmul_128x128x512_nl"), 16 | (prof, "tensor_exp_kernel_"), 17 | (average_pool, "tensor_avgpool_kernel_"), 18 | (fused_mamba, "mamba_v1"), 19 | (fused_mamba, "mamba_v2"), 20 | (fused_mamba, "mamba_v3"), 21 | 
(layernorm, "nki_layernorm_kernel_v1"), 22 | (layernorm, "nki_layernorm_kernel_v2"), 23 | (matmul, "nki_matmul_basic_"), 24 | (matmul, "nki_matmul_tiled_"), 25 | (matmul, "nki_matmul_hoist_load_"), 26 | (matmul, "nki_matmul_block_free_dimension_"), 27 | (matmul, "nki_matmul_fully_optimized_"), 28 | (rmsnorm, "nki_rmsnorm_kernel"), 29 | (sd_attention, "fused_self_attn_for_SD_small_head_size"), 30 | (tensor_addition, "nki_tensor_add_kernel_"), 31 | (tensor_addition, "nki_tensor_add"), 32 | (transpose2d, "tensor_transpose2D_kernel_") 33 | ]) 34 | 35 | 36 | def test_parse(module, name): 37 | f = getattr(module, name) 38 | K = Kernel(f) 39 | print(K) 40 | 41 | 42 | if __name__ == '__main__': 43 | module, name = getting_started, "nki_tensor_add_kernel" 44 | a = np.ndarray((128,512)) 45 | f = getattr(module, name) 46 | K = Kernel(f) 47 | K.specialize((a,a,a)) 48 | print(K) 49 | -------------------------------------------------------------------------------- /KLR/Util/NumBytesTest.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import KLR.Util.NumBytes 18 | 19 | open KLR.Util(NumBytes) 20 | 21 | section Test 22 | 23 | private structure Foo where 24 | x : Int8 25 | y : Int32 26 | z : Int8 × Int16 27 | deriving Inhabited, NumBytes 28 | 29 | #guard NumBytes.numBytes (default:Foo) == 8 30 | 31 | /-- 32 | error: deriving NumBytes only works on single structures 33 | -/ 34 | #guard_msgs in 35 | mutual 36 | private structure Foo1 where 37 | x : Int8 38 | deriving NumBytes 39 | 40 | private structure Foo2 where 41 | x : Int8 42 | deriving NumBytes 43 | end 44 | 45 | /-- 46 | error: deriving NumBytes only works on single structures 47 | -/ 48 | #guard_msgs in 49 | mutual 50 | private structure Bar1 where 51 | x : Int8 52 | deriving NumBytes 53 | 54 | private structure Bar2 where 55 | x : Int8 56 | -- No deriving clause here 57 | end 58 | 59 | /-- 60 | error: deriving NumBytes only works on single structures 61 | -/ 62 | #guard_msgs in 63 | private inductive Baz where 64 | | x : Int -> Baz 65 | | y : Nat -> Baz 66 | deriving NumBytes 67 | 68 | end Test 69 | -------------------------------------------------------------------------------- /KLR/Util/Meta.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import KLR.Util.Common 18 | import Lean.Elab 19 | 20 | open KLR(Err) 21 | open Lean(MonadError Name Syntax Term) 22 | open Lean.Elab.Term(TermElabM elabBinders elabTerm) 23 | 24 | /- 25 | This namespace is intended to be a library of utilities for 26 | metaprogramming in KLR. The functions we've been using are spread 27 | over many files and namespaces, and are hard to remember. Let's use 28 | this to organize as many of our standard patterns as possible. Some 29 | of them, like stringToStrLit are just abbreviations for other library 30 | functions. Even so, I still think there's value to having them in one spot. 31 | -/ 32 | namespace KLR.Util.Meta 33 | 34 | def stringToStrLit (s : String) : Term := Syntax.mkStrLit s 35 | 36 | def nameToIdent (n : Name) : Term := Lean.mkIdent n 37 | 38 | def nameToString [Monad m] [MonadError m] (n : Name) : m String := do 39 | let .str _ s := n | throwError "Not a string name" 40 | return s 41 | 42 | def nameToStrLit [Monad m] [MonadError m] (n : Name) : m Term := do 43 | return stringToStrLit (<- nameToString n) 44 | 45 | end KLR.Util.Meta 46 | -------------------------------------------------------------------------------- /interop/test/eval/kernels.py: -------------------------------------------------------------------------------- 1 | import nki.isa as nisa 2 | import nki.language as nl 3 | import numpy as np 4 | 5 | # t + 1.0 with no access pattern 6 | def kernel1(a): 7 | a_tile = nl.load(a) 8 | b_tile = nisa.tensor_scalar(a_tile, np.add, 1.0) 9 | b = nl.ndarray(b_tile.shape, dtype=b_tile.dtype, buffer=nl.sbuf) 10 | nl.store(b, b_tile) 11 | return b 12 | 13 | # t - 5.0 with no access pattern 14 | def kernel2(a): 15 | a_tile = nl.load(a) 16 | b_tile = nisa.tensor_scalar(a_tile, np.subtract, 5.0) 17 | b = nl.ndarray(b_tile.shape, dtype=b_tile.dtype, buffer=nl.sbuf) 18 | nl.store(b, b_tile) 19 | return b 20 | 21 | # t * 5.0 with no access pattern 22 | def kernel3(a): 23 | a_tile = nl.load(a) 24 | b_tile = 
nisa.tensor_scalar(a_tile, np.multiply, 5.0) 25 | b = nl.ndarray(b_tile.shape, dtype=b_tile.dtype, buffer=nl.sbuf) 26 | nl.store(b, b_tile) 27 | return b 28 | 29 | # t * 5.0 - 1 with no access pattern 30 | def kernel4(a): 31 | a_tile = nl.load(a) 32 | b_tile = nisa.tensor_scalar(a_tile, np.multiply, 5.0, False, np.subtract, 1.0, False) 33 | b = nl.ndarray(b_tile.shape, dtype=b_tile.dtype, buffer=nl.sbuf) 34 | nl.store(b, b_tile) 35 | return b 36 | 37 | # 1 - t * 5.0 with no access pattern 38 | def kernel5(a): 39 | a_tile = nl.load(a) 40 | b_tile = nisa.tensor_scalar(a_tile, np.multiply, 5.0, False, np.subtract, 1.0, True) 41 | b = nl.ndarray(b_tile.shape, dtype=b_tile.dtype, buffer=nl.sbuf) 42 | nl.store(b, b_tile) 43 | return b 44 | 45 | # 1 - (5.0 - t) with no access pattern 46 | def kernel6(a): 47 | a_tile = nl.load(a) 48 | b_tile = nisa.tensor_scalar(a_tile, np.subtract, 5.0, True, np.subtract, 1.0, True) 49 | b = nl.ndarray(b_tile.shape, dtype=b_tile.dtype, buffer=nl.sbuf) 50 | nl.store(b, b_tile) 51 | return b 52 | -------------------------------------------------------------------------------- /KLR/K/K3/AST.lean: -------------------------------------------------------------------------------- 1 | import KLR.TGR.AST 2 | import KLR.K.Operators 3 | import Lean 4 | import TensorLib.Tensor 5 | 6 | namespace KLR.K.K3 7 | 8 | open KLR.TGR(TensorTy) 9 | 10 | abbrev Var := String 11 | 12 | /- A tensor in K3 has a name, shape, and datatype -/ 13 | structure TensorK3 where 14 | name : Var 15 | type : TensorTy 16 | deriving Inhabited, Repr, BEq 17 | 18 | /- A scalar in K3 can be a float, int, or vector, where a vector is a named 19 | variable with a size and datatype. 
-/ 20 | inductive ScalarK3 21 | | float (f : Float32) 22 | | int (f : Nat) 23 | | vector (name : Var) (size : Nat) (dtype : TensorLib.Dtype) 24 | deriving Inhabited, Repr, BEq 25 | 26 | abbrev OperatorK3 := KLR.K.Operator TensorK3 ScalarK3 27 | 28 | /- K3 functions take a list of arguments as input and have a list of outputs. 29 | The input arguments can be referred to by name, and it is assumed that by the 30 | end of the instruction stream the named output tensors will have been written to. -/ 31 | structure FunctionK3 where 32 | name : String 33 | inputs : List TensorK3 34 | outputs : List TensorK3 35 | statements : List OperatorK3 36 | deriving Inhabited, Repr, BEq 37 | 38 | instance : ToString TensorK3 where 39 | toString t := 40 | s!"%{t.name}<{t.type.shape.val.toString}>" 41 | 42 | instance : ToString ScalarK3 where 43 | toString 44 | | .float f => s!"{f}" 45 | | .int i => s!"{i}" 46 | | .vector name size dtype=> s!"{name}<{size}x{dtype}>" 47 | 48 | instance : ToString FunctionK3 where 49 | toString f := 50 | let inputs := f.inputs.map ToString.toString |> ",".intercalate 51 | let outputs := f.outputs.map ToString.toString |> ",".intercalate 52 | let body := f.statements.map ToString.toString |> "\n\t".intercalate 53 | s!"def {f.name}({inputs}) -> {outputs} :\n\t{body}" 54 | -------------------------------------------------------------------------------- /interop/klr/frontend.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Paul Govereau, Sean McLaughlin 5 | */ 6 | #pragma once 7 | #include "stdc.h" 8 | #include "region.h" 9 | 10 | #define PY_SSIZE_T_CLEAN 11 | #include 12 | 13 | #if PY_MINOR_VERSION == 9 14 | #define Py_IsNone(x) ((x) == Py_None) 15 | #define Py_IsTrue(x) ((x) == Py_True) 16 | #endif 17 | 18 | // Front-end version (place holder) 19 | #define KLR_VERSION 1 20 | 21 | // The place where we live 22 | #ifdef IS_NKI_REPO 23 | #define MODULE_ROOT "nki._klr" 24 | #else 25 | #define MODULE_ROOT "klr" 26 | #endif 27 | 28 | // The front-end is accessed through the class Kernel; one instance 29 | // per kernel. Each instance has a `struct kernel` on the C side. 30 | 31 | struct kernel { 32 | PyObject_HEAD 33 | PyObject *f; // Kernel function 34 | struct region *region; 35 | struct lean_kernel *lean_kernel; 36 | }; 37 | 38 | // peg_parser.c 39 | struct _mod* parse_string(const char *str, PyObject* filename); 40 | void free_python_ast(struct _mod *m); 41 | 42 | // gather.c 43 | PyObject* specialize(struct kernel *k, PyObject *args, PyObject *kws, PyObject *arch, PyObject *grid, PyObject *schedule, PyObject *flags); 44 | const char* serialize_python(struct kernel *k); 45 | const char* trace(struct kernel *k, const char *dst_file, const char *dst_format, const char *dbg_file); 46 | 47 | // klr_ffi.c 48 | 49 | // Initialize Lean and the KLR module. 50 | // On failure, returns false with a Python exception set. 51 | bool initialize_KLR_lean_ffi(void); 52 | 53 | // Sanity tests 54 | PyObject* lean_ffi_hello(PyObject *self, PyObject *args); 55 | PyObject* lean_ffi_throw(PyObject *self, PyObject *args); 56 | PyObject* lean_ffi_panic(PyObject *self, PyObject *args); 57 | -------------------------------------------------------------------------------- /interop/test/test_enum.py: -------------------------------------------------------------------------------- 1 | # This file exercises the Lean partial evaluator with 2 | # a set of basic unit tests. 
Each function is parsed, 3 | # handed to Lean, where it is checked and reduced to KLR. 4 | 5 | import os 6 | import pytest 7 | 8 | from enum import Enum 9 | from klr.frontend import Kernel 10 | 11 | # Success cases 12 | # (these functions should load and trace to KLR) 13 | 14 | class E(Enum): 15 | x = 1 16 | y = 2 17 | z = 3 18 | 19 | def f(self): 20 | print(self.name + " " + str(self.value)) 21 | 22 | def enum1(): 23 | e = E(name="x", value=1) 24 | assert e.name == "x" 25 | assert e.value == 1 26 | 27 | def enum2(): 28 | e = E.x 29 | assert e.name == "x" 30 | assert e.value == 1 31 | 32 | def enum3(): 33 | assert E.y.name == "y" 34 | assert E.y.value == 2 35 | 36 | def enum4(): 37 | assert E.z.name == "z" 38 | assert E.z.value == 3 39 | 40 | def enumEq1(): 41 | assert E.x == E.x 42 | 43 | def enumEq2(): 44 | assert E.x != E.y 45 | 46 | def enumEq3(): 47 | assert E.y != E.z 48 | 49 | # test each function in turn 50 | @pytest.mark.parametrize("f", [ 51 | enum1, 52 | enum2, 53 | enum3, 54 | enum4, 55 | enumEq1, 56 | enumEq2, 57 | enumEq3, 58 | ]) 59 | def test_succeed(f): 60 | F = Kernel(f) # parse python 61 | F.specialize() 62 | F.trace("tmp.klr") 63 | os.remove("tmp.klr") 64 | 65 | 66 | # Boundary crossing 67 | # Check objects that originate from Python 68 | 69 | def cross1(e): 70 | assert e.name == "x" 71 | assert e.value == 1 72 | 73 | def cross2(e): 74 | # technically this reference is crossing 75 | assert E.x.x.name == "x" 76 | 77 | # test each function in turn 78 | @pytest.mark.parametrize("f", [ 79 | cross1, 80 | cross2, 81 | ]) 82 | def test_crossing(f): 83 | F = Kernel(f) # parse python 84 | F.specialize((E.x,)) 85 | F.trace("tmp.klr") 86 | os.remove("tmp.klr") 87 | -------------------------------------------------------------------------------- /interop/test/test_dyn_ap.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import nki.isa as nisa 4 | import nki.language as nl 5 | 6 | from runner 
import * 7 | 8 | def test_dynamic_ap(): 9 | t = nl.ndarray((128, 128, 128), nl.float32, name="t") 10 | ap1 = t.ap([[128*128, 128], [128, 1], [1, 1]]) 11 | ap2 = t.ap([[128*128, 128], [128, 1], [1, 1]], offset=32) 12 | i = 1 # no way to produce a dynamic scalar yet, but it int is also immediate 13 | ap2 = t.ap([[128*128, 128], [128, 1], [1, 1]], offset=32, scalar_offset=i) 14 | v = nl.ndarray((128,1), nl.float32, name="v") 15 | ap3 = t.ap([[128*128, 128], [128, 1], [1, 1]], offset=32, vector_offset=v, indirect_dim=0) 16 | 17 | def test_dyn_ap_fail(): 18 | t = nl.ndarray((128, 128, 128), nl.float32, name="t") 19 | i = 1 # no way to produce a dynamic scalar yet, but it int is also immediate 20 | ap2 = t.ap( 21 | [[128*128, 128], [128, 1], [1, 1]], offset=32, scalar_offset=i 22 | ).ap( 23 | [[128*128, 128], [128, 1], [1, 1]], offset=32, scalar_offset=i 24 | ) 25 | 26 | @pytest.mark.parametrize("f", [ 27 | test_dynamic_ap 28 | ]) 29 | def test_succeed(f): 30 | F = Kernel(f) 31 | F.specialize() 32 | rv = F.trace("out.klr") 33 | rv = json.loads(rv) 34 | errs = rv["errors"] 35 | assert errs is None or errs == [], f"errors are set. errors {errs}" 36 | 37 | if os.path.exists("out.klr"): 38 | os.remove("out.klr") 39 | 40 | @pytest.mark.parametrize("f", [ 41 | test_dyn_ap_fail 42 | ]) 43 | def test_fail(f): 44 | F = Kernel(f) 45 | F.specialize() 46 | rv = F.trace("out.klr") 47 | rv = json.loads(rv) 48 | errs = rv["errors"] 49 | assert errs and len(errs) > 0 50 | 51 | if os.path.exists("out.klr"): 52 | os.remove("out.klr") 53 | 54 | if __name__ == "__main__": 55 | test_succeed(test_dynamic_ap) 56 | test_fail(test_dyn_ap_fail) 57 | -------------------------------------------------------------------------------- /KLR/Semantics/Notation.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Markus de Medeiros 5 | -/ 6 | 7 | import Lean 8 | import KLR.Semantics.NML 9 | 10 | open Lean Elab Meta 11 | 12 | elab "#debug_expr " termStx:term : command => 13 | open Lean Lean.Elab Command Term in 14 | liftTermElabM 15 | try 16 | let s ← elabTerm termStx (expectedType? := none) 17 | logInfo s!"elaboration: {s}" 18 | catch | _ => throwError s!"failure" 19 | 20 | #debug_expr @NML.Locals.bind 21 | 22 | 23 | /-## NML Values -/ 24 | declare_syntax_cat nml_val 25 | syntax "%u " : nml_val 26 | syntax "%b " term : nml_val 27 | syntax "%d " term : nml_val 28 | syntax "%i " term : nml_val 29 | syntax "%a " term : nml_val 30 | syntax "%p " term : nml_val 31 | syntax "%s " term : nml_val 32 | 33 | syntax "[nml_val|" nml_val "]" : term 34 | 35 | macro_rules 36 | | `([nml_val| %u]) => `(@NML.Value.unit _) 37 | | `([nml_val| %b $b]) => `(@NML.Value.bool _ $b) 38 | 39 | 40 | -- #check [nml_val| %u ] 41 | -- #check [nml_val| %b true ] 42 | 43 | declare_syntax_cat nml_binding 44 | syntax term " ↣ " nml_val : nml_binding 45 | syntax "[nml_binding|" nml_binding "]" : term 46 | macro_rules 47 | | `([nml_binding| $x:term ↣ $y:nml_val]) => `(fun l => @NML.Locals.bind _ l $x [nml_val| $y]) 48 | -- #check [nml_binding| "x" ↣ %b true] 49 | 50 | declare_syntax_cat nml_locals 51 | 52 | elab "[nml_locs|" tuples:nml_binding,* "]" : term => do 53 | let mut _ : Expr := Expr.const ``NML.nolocals [] 54 | for tup in tuples.getElems do 55 | match tup with 56 | | `(nml_binding| $p:term ↣ $v) => logInfo s!"OK {p} and {v}" 57 | | stx => throwError s!"{stx}" 58 | return Expr.const ``Unit.unit [] 59 | -- return b 60 | 61 | -- #check ([nml_locs| "x" ↣ %b true, "y" ↣ %u]) 62 | -- #eval Lean.Meta.reduce `(NML.nolocals _) 63 | -------------------------------------------------------------------------------- /KLR/Util/BigArray.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, 
Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | /- 18 | A simple big array type 19 | 20 | This structure is designed to avoid memory copies when adding elements to an 21 | array. It does this by keeping an array of arrays, each of which will only be 22 | filled to its initial capacity. 23 | -/ 24 | 25 | namespace KLR.Util 26 | 27 | structure BigArray (a : Type) where 28 | capacity : Nat := 1024 29 | children : Nat := 10 30 | arr : Array (Array a) := (Array.emptyWithCapacity children).push 31 | (Array.emptyWithCapacity capacity) 32 | 33 | namespace BigArray 34 | 35 | def push (ba : BigArray a) (x : a) : BigArray a := 36 | let index := ba.arr.size - 1 37 | let child := ba.arr[index]! 38 | if child.size >= ba.capacity then 39 | let child := Array.emptyWithCapacity ba.capacity 40 | let child := child.push x 41 | let arr := ba.arr.push child 42 | { ba with arr } 43 | else 44 | let child := child.push x 45 | let arr := ba.arr.set! index child 46 | { ba with arr } 47 | 48 | #guard 49 | let ba : BigArray Nat := { capacity := 2 } 50 | let ba := ba.push 1 51 | let ba := ba.push 2 52 | let ba := ba.push 3 53 | ba.arr.size == 2 && 54 | ba.arr[0]!.size == 2 && 55 | ba.arr[1]!.size == 1 && 56 | ba.arr[0]![0]! == 1 && 57 | ba.arr[0]![1]! == 2 && 58 | ba.arr[1]![0]! 
== 3 59 | -------------------------------------------------------------------------------- /interop/klr/peg_parser/Tokens: -------------------------------------------------------------------------------- 1 | ENDMARKER 2 | NAME 3 | NUMBER 4 | STRING 5 | NEWLINE 6 | INDENT 7 | DEDENT 8 | 9 | LPAR '(' 10 | RPAR ')' 11 | LSQB '[' 12 | RSQB ']' 13 | COLON ':' 14 | COMMA ',' 15 | SEMI ';' 16 | PLUS '+' 17 | MINUS '-' 18 | STAR '*' 19 | SLASH '/' 20 | VBAR '|' 21 | AMPER '&' 22 | LESS '<' 23 | GREATER '>' 24 | EQUAL '=' 25 | DOT '.' 26 | PERCENT '%' 27 | LBRACE '{' 28 | RBRACE '}' 29 | EQEQUAL '==' 30 | NOTEQUAL '!=' 31 | LESSEQUAL '<=' 32 | GREATEREQUAL '>=' 33 | TILDE '~' 34 | CIRCUMFLEX '^' 35 | LEFTSHIFT '<<' 36 | RIGHTSHIFT '>>' 37 | DOUBLESTAR '**' 38 | PLUSEQUAL '+=' 39 | MINEQUAL '-=' 40 | STAREQUAL '*=' 41 | SLASHEQUAL '/=' 42 | PERCENTEQUAL '%=' 43 | AMPEREQUAL '&=' 44 | VBAREQUAL '|=' 45 | CIRCUMFLEXEQUAL '^=' 46 | LEFTSHIFTEQUAL '<<=' 47 | RIGHTSHIFTEQUAL '>>=' 48 | DOUBLESTAREQUAL '**=' 49 | DOUBLESLASH '//' 50 | DOUBLESLASHEQUAL '//=' 51 | AT '@' 52 | ATEQUAL '@=' 53 | RARROW '->' 54 | ELLIPSIS '...' 55 | COLONEQUAL ':=' 56 | EXCLAMATION '!' 
57 | 58 | OP 59 | AWAIT 60 | ASYNC 61 | TYPE_IGNORE 62 | TYPE_COMMENT 63 | SOFT_KEYWORD 64 | FSTRING_START 65 | FSTRING_MIDDLE 66 | FSTRING_END 67 | COMMENT 68 | NL 69 | ERRORTOKEN 70 | 71 | # These aren't used by the C tokenizer but are needed for tokenize.py 72 | ENCODING 73 | -------------------------------------------------------------------------------- /interop/test/test_memory.py: -------------------------------------------------------------------------------- 1 | # tests of pointers and memory allocation 2 | 3 | import pytest 4 | import os 5 | # import nki.typing as nt 6 | 7 | from klr.frontend import Kernel 8 | 9 | # Success cases 10 | # (these functions should load and trace to KLR) 11 | 12 | def pointers(): 13 | x = sbuf[0:128, 0:512] 14 | y = x[32:,:] 15 | a = psum[:64,:512] 16 | b = a[32:,:] 17 | sb = sbuf[:,1024:2048] 18 | left = sb[:,0:512] 19 | right = sb[:,512:] 20 | big = sbuf[None,:] 21 | 22 | def views(): 23 | ptr = sbuf[0:32, 0:128] 24 | assert ptr.start == (0,0) 25 | assert ptr.size == (32,128) 26 | tensor1 = ptr.view("float32", (32, 32)) 27 | assert tensor1.shape == (32,32) 28 | tensor2 = ptr.view("int8", (32, 128)) 29 | assert tensor2.shape == (32,128) 30 | tensor3 = ptr.view("int8", (16, 64)) 31 | assert tensor3.shape == (16,64) 32 | 33 | # test each function in turn 34 | @pytest.mark.parametrize("f", [ 35 | pointers, 36 | views 37 | ]) 38 | def test_succeed(f): 39 | F = Kernel(f) # parse python 40 | F.specialize() 41 | 42 | # Failing cases 43 | # (These functions are expected to fail elaboration to KLR) 44 | 45 | def bad_pointer1(): sbuf[1:,0:512] # sbuf must start at 0,32,64, or 96 46 | def bad_pointer2(): sbuf[0:32,5:32] # starts must be even 47 | def bad_pointer3(): sbuf[None,0:511] # ends must be even 48 | def bad_pointer4(): sbuf[None,0:0x50000] # too big free dim 49 | def bad_pointer5(): sbuf[0:130,0:512] # too big pdim 50 | 51 | def too_large1(): sbuf[0:32, 0:16].view("float32", (64, 4)) 52 | def too_large2(): sbuf[0:32, 
0:16].view("float32", (32, 5)) 53 | 54 | @pytest.mark.parametrize("f", [ 55 | # bad_pointer1, 56 | # bad_pointer2, 57 | # bad_pointer3, 58 | # bad_pointer4, 59 | # bad_pointer5, 60 | # too_large1, 61 | # too_large2, 62 | ]) 63 | def test_fails(f): 64 | F = Kernel(f) 65 | with pytest.raises(Exception): 66 | F.specialize() 67 | F.trace("out.klr") 68 | 69 | if os.exists("out.klr"): 70 | os.remove("out.klr") 71 | -------------------------------------------------------------------------------- /KLR/Util/ToBytesTest.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import KLR.Util.ToBytes 18 | 19 | namespace KLR.Util 20 | 21 | private structure Foo where 22 | x : Int8 23 | y : Int16 24 | z : Int32 25 | deriving ToBytes 26 | 27 | #guard toBytes (Foo.mk 0x0 0x1 0x2) == ⟨ #[0, 1, 0, 2, 0, 0, 0] ⟩ 28 | 29 | private structure Bar where 30 | y : Foo 31 | z : Int32 32 | deriving ToBytes 33 | 34 | #guard toBytes (Bar.mk (Foo.mk 0x0 0x1 0x2) 3) == ⟨ #[0, 1, 0, 2, 0, 0, 0, 3, 0, 0, 0] ⟩ 35 | 36 | /-- 37 | error: deriving ToBytes only works on single structures or inductives all of whose branches have a single ToBytes argument 38 | -/ 39 | #guard_msgs in 40 | mutual 41 | private structure Foo1 where 42 | x : Int8 43 | deriving ToBytes 44 | 45 | private structure Foo2 where 46 | x : Int8 47 | deriving ToBytes 48 | end 49 | 50 | /-- 51 | error: deriving ToBytes only works on single structures or inductives all of whose branches have a single ToBytes argument 52 | -/ 53 | #guard_msgs in 54 | mutual 55 | private structure Bar1 where 56 | x : Int8 57 | deriving ToBytes 58 | 59 | private structure Bar2 where 60 | x : Int8 61 | -- No deriving clause here 62 | end 63 | 64 | private inductive FooI where 65 | | X (x : Int8) 66 | | Y (y : Int16) 67 | | Z (z : Int32) 68 | deriving ToBytes 69 | 70 | #guard toBytes (FooI.X 0x2) == ⟨ #[2] ⟩ 71 | #guard toBytes (FooI.Y 0x2) == ⟨ #[2, 0] ⟩ 72 | #guard toBytes (FooI.Z 0x2) == ⟨ #[2, 0, 0, 0] ⟩ 73 | -------------------------------------------------------------------------------- /KLR/Util/FromBytesTest.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import KLR.Util.FromBytes 18 | 19 | namespace KLR.Util 20 | 21 | #guard fromBytes (Vector UInt8 4) ⟨ #[0, 1, 0, 2] ⟩ == .ok (Vector.mk #[(0 : UInt8), 1, 0, 2] (by simp), ⟨ #[] ⟩) 22 | 23 | private structure Foo where 24 | x : Int8 25 | y : Int16 26 | z : Int32 27 | deriving BEq, FromBytes, Inhabited, NumBytes 28 | 29 | #guard fromBytes Foo ⟨ #[0, 1, 0, 2, 0, 0, 0, 77] ⟩ == .ok (Foo.mk 0x0 0x1 0x2, ⟨ #[77] ⟩) 30 | 31 | private structure Bar where 32 | y : Foo 33 | z : Int32 34 | deriving BEq, FromBytes, Inhabited, NumBytes 35 | 36 | #guard fromBytes Bar ⟨ #[0, 1, 0, 2, 0, 0, 0, 3, 0, 0, 0, 77] ⟩ == .ok (Bar.mk (Foo.mk 0x0 0x1 0x2) 3, ⟨ #[77] ⟩) 37 | 38 | /-- 39 | error: deriving FromBytes only works on single structures 40 | -/ 41 | #guard_msgs in 42 | mutual 43 | private structure Foo1 where 44 | x : Int8 45 | deriving FromBytes 46 | 47 | private structure Foo2 where 48 | x : Int8 49 | deriving NumBytes 50 | end 51 | 52 | /-- 53 | error: deriving FromBytes only works on single structures 54 | -/ 55 | #guard_msgs in 56 | mutual 57 | private structure Bar1 where 58 | x : Int8 59 | deriving FromBytes 60 | 61 | private structure Bar2 where 62 | x : Int8 63 | -- No deriving clause here 64 | end 65 | 66 | /-- 67 | error: deriving FromBytes only works on single structures 68 | -/ 69 | #guard_msgs in 70 | private inductive Baz where 71 | | x : Int -> Baz 72 | | y : Nat -> Baz 73 | deriving FromBytes 74 | -------------------------------------------------------------------------------- /KLR/Util/Padding.lean: 
-------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import KLR.Util.FromBytes 18 | import KLR.Util.NumBytes 19 | import KLR.Util.ToBytes 20 | 21 | namespace KLR.Util 22 | 23 | open Lean(Json ToJson toJson) 24 | 25 | structure Padding (n : Nat) where 26 | deriving Inhabited 27 | 28 | namespace Padding 29 | 30 | def get (_ : Padding n) (_ : Nat) : Nat := 0 31 | 32 | instance : GetElem (Padding n) Nat Nat (fun _ i => i < n) where 33 | getElem _ _ _ := 0 34 | 35 | #guard 36 | let v : Padding 5 := Padding.mk 37 | v[0] == 0 && v[4]! 
== 0 && v[5]?.isNone 38 | 39 | instance : Repr (Padding n) where 40 | reprPrec _ _ := s!"Padding of size {n}" 41 | 42 | instance : BEq (Padding n) where 43 | beq _ _ := true 44 | 45 | instance : NumBytes (Padding n) where 46 | numBytes _ := n 47 | 48 | instance : ToBytes (Padding n) where 49 | toBytes _ := ByteArray.zeros n 50 | 51 | instance : FromBytes (Padding n) where 52 | fromBytesUnchecked arr := do 53 | let zeros := arr.take n 54 | let mut i := 0 55 | for byte in zeros.data do 56 | if byte != 0 then throw s!"Nonzero padding at index {i}" 57 | i := i + 1 58 | return (⟨⟩, arr.drop n) 59 | 60 | instance : ToJson (Padding n) where 61 | toJson _ := Json.str s!"Padding of size {n}" 62 | 63 | instance : ToSexp (Padding n) where 64 | toSexp _ := Sexp.atom s!"Padding of size {n}" 65 | 66 | instance : FromSexp (Padding n) where 67 | fromSexp? _ := default 68 | 69 | end Padding 70 | 71 | end KLR.Util 72 | -------------------------------------------------------------------------------- /KLR/Util/Common.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | namespace KLR 18 | 19 | /- 20 | The default choice for an error monad is `Except String`, used for simple 21 | computations that can fail. 22 | 23 | Provide automatic lifting of Err to any monad that supports throwing strings 24 | as errors. 
25 | -/ 26 | abbrev Err := Except String 27 | 28 | instance : MonadLift Err IO where 29 | monadLift 30 | | .ok x => return x 31 | | .error s => throw $ .userError s 32 | 33 | instance [Monad m] [MonadExcept String m] : MonadLift Err m where 34 | monadLift 35 | | .ok x => return x 36 | | .error s => throw s 37 | 38 | instance [BEq a] : BEq (Err a) where 39 | beq x y := match x, y with 40 | | .ok a, .ok b => a == b 41 | | .error msg, .error msg' => msg == msg' 42 | | _, _ => false 43 | 44 | /- 45 | The default choice for a state monad is `EStateM String`. 46 | -/ 47 | abbrev StM := EStateM String 48 | 49 | def impossible {a : Type} [h : Inhabited a] (msg : String := "") := 50 | @panic a h s!"Invariant violation: {msg}" 51 | 52 | def _root_.Except.get! [Inhabited α] (v : Except ε α) : α := 53 | match v with 54 | | .error _ => impossible 55 | | .ok x => x 56 | 57 | def _root_.Except.getD (v : Except ε α) (default : α) : α := 58 | match v with 59 | | .ok v => v 60 | | .error _ => default 61 | 62 | -- TODO: Deprecate this 63 | def get! [Inhabited a] (x : Err a) : a := 64 | x.get! 65 | 66 | def natDivCeil (num denom : Nat) : Nat := (num + denom - 1) / denom 67 | 68 | end KLR 69 | -------------------------------------------------------------------------------- /KLR/Extract/Extract/ASDL.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import Extract.Basic 18 | import KLR.Core 19 | import Lean 20 | 21 | /- 22 | Output functions for ASDL 23 | -/ 24 | 25 | namespace Extract.ASDL 26 | open Lean Meta 27 | 28 | local instance : ToString Name where 29 | toString 30 | | .str _ s => s 31 | | n => n.toString 32 | 33 | private def typeName (ty : SimpleType) : String := 34 | match ty with 35 | | .option t => typeName t ++ "?" 36 | | .list t => typeName t ++ "*" 37 | | _ => s!"{ty.name}" 38 | 39 | 40 | private def genType (ty : LeanType) (topLevel : Bool := false) : MetaM Unit := 41 | match ty with 42 | | .simple _ => pure () 43 | | .prod name fields => do 44 | if topLevel then 45 | IO.print s!"\n{name} = (" 46 | else 47 | IO.print s!"{name}(" 48 | let fieldStrs := fields.map (fun f => s!"{typeName f.type} {f.name}") 49 | IO.print (String.intercalate ", " fieldStrs) 50 | IO.println ")" 51 | | .sum name variants => do 52 | if ty.isEnum then do 53 | IO.println s!"{name} =" 54 | for v in variants do 55 | IO.println s!" | {v.name}" 56 | else do 57 | IO.println s!"{name} =" 58 | for v in variants do 59 | IO.print s!" | " 60 | genType v 61 | IO.println "" 62 | return () 63 | 64 | def generateNkiAST : MetaM Unit := do 65 | let tys <- klrAST 66 | for t in tys do 67 | genType t (topLevel := true) 68 | return () 69 | 70 | -- NOTE: Uncomment for debugging 71 | -- run_meta generateNkiAST 72 | -------------------------------------------------------------------------------- /interop/test/test_nki_isa_tensor_scalar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2025, Amazon.com. All Rights Reserved 3 | 4 | """ 5 | import unittest 6 | from apis import * 7 | 8 | """ 9 | Unit tests for tensor_scalar. 10 | 11 | If using these tests for development, you can generate the NKI.json from 12 | the top-level like so: 13 | 14 | PYTHONPATH=interop:interop/test ./bin/gather test_nki_isa_tensor_scalar.kernel1 > kernel1.json 15 | 16 | and then, e.g. 
17 | 18 | lake exe klr trace kernel1.json 19 | lake exe klr compile kernel1.json 20 | """ 21 | 22 | # utility function - allocate memory in DRAM 23 | def alloc_like(t): 24 | return nl.ndarray(t.shape, dtype=t.dtype, buffer=nl.shared_hbm) 25 | 26 | # utility function - allocate memory in DRAM and copy SBUF tile to it 27 | def dram_tile(a): 28 | b = alloc_like(a) 29 | nl.store(b, a) 30 | return b 31 | 32 | # test kernel 1 : t - 1.0 with no access pattern 33 | def kernel1(a): 34 | a_tile = nl.load(a) 35 | b_tile = nisa.tensor_scalar(a_tile, np.subtract, 1.0) 36 | return dram_tile(b_tile) 37 | 38 | # test kernel 2 : t - 1.0 with ellipsis access pattern 39 | def kernel2(a): 40 | a_tile = nl.load(a[...]) 41 | b_tile = nisa.tensor_scalar(a_tile, np.subtract, 1.0) 42 | return dram_tile(b_tile) 43 | 44 | # test kernel 2 : t - 1.0 with simple tile access pattern 45 | def kernel3(a): 46 | a_tile = nl.load(a[0:128,0:512]) 47 | b_tile = nisa.tensor_scalar(a_tile, np.subtract, 1.0) 48 | return dram_tile(b_tile) 49 | 50 | # The above example will fail tracing with: 51 | # nl.store(b, b_tile) 52 | # ^-- incompatible shapes [10, 10] [128, 512] 53 | # This is because inferArguments is very dumb. 54 | # You can use the kernel below for testing to get proper arguments. 55 | def kernel3b(): 56 | a = nl.ndarray((128,512), dtype="float32", buffer=nl.shared_hbm) 57 | return kernel3(a) 58 | 59 | def tensor_scalar(dst, src, op, scalar): 60 | src_tile = nl.load(src) 61 | dst_tile = nisa.tensor_scalar(src_tile, op, 1.0) 62 | nl.store(dst, dst_tile) 63 | 64 | def kernel4(a): 65 | b = alloc_like(a) 66 | for x in range(4): 67 | tensor_scalar(b[x,0:10], a[x,0:10], np.subtract, 1.0) 68 | return b 69 | -------------------------------------------------------------------------------- /interop/test/examples/tensor_addition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024, Amazon.com. 
All Rights Reserved 3 | 4 | NKI implementation for tensor addition NKI tutorial. 5 | 6 | """ 7 | from apis import * 8 | 9 | def nki_tensor_add_kernel_(a_input, b_input, c_output): 10 | """NKI kernel to compute element-wise addition of two input tensors 11 | 12 | This kernel assumes strict input/output tile-sizes, of up-to [128,512] 13 | 14 | Args: 15 | a_input: a first input tensor, of shape [128,512] 16 | b_input: a second input tensor, of shape [128,512] 17 | c_output: an output tensor, of shape [128,512] 18 | """ 19 | 20 | # Calculate tile offsets based on current 'program' 21 | offset_i_x = nl.program_id(0) * 128 22 | offset_i_y = nl.program_id(1) * 512 23 | 24 | # Generate tensor indices to index tensors a and b 25 | ix = offset_i_x + nl.arange(128)[:, None] 26 | iy = offset_i_y + nl.arange(512)[None, :] 27 | 28 | # Load input data from device memory (HBM) to on-chip memory (SBUF) 29 | # We refer to an indexed portion of a tensor as an intermediate tensor 30 | a_tile = nl.load(a_input[ix, iy]) 31 | b_tile = nl.load(b_input[ix, iy]) 32 | 33 | # compute a + b 34 | c_tile = a_tile + b_tile 35 | 36 | # store the addition results back to device memory (c_output) 37 | nl.store(c_output[ix, iy], value=c_tile) 38 | 39 | 40 | def nki_tensor_add(a_input, b_input, c_output): 41 | """NKI kernel caller to compute element-wise addition of two input tensors 42 | 43 | This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs 44 | 45 | Args: 46 | a_input: a first input tensor, of shape [N*128, M*512] 47 | b_input: a second input tensor, of shape [N*128, M*512] 48 | 49 | Returns: 50 | a tensor of shape [N*128, M*512], the result of a_input + b_input 51 | """ 52 | 53 | # The SPMD launch grid denotes the number of kernel instances. 
54 | # In this case, we use a 2D grid where the size of each invocation is 128x512 55 | grid_x = a_input.shape[0] // 128 56 | grid_y = a_input.shape[1] // 512 57 | #c_output = np.zeros(a_input.shape, dtype=a_input.dtype) 58 | 59 | nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input, c_output) 60 | 61 | return c_output 62 | -------------------------------------------------------------------------------- /interop/klr/cbor.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Paul Govereau, Sean McLaughlin, Claude 5 | */ 6 | #pragma once 7 | #include "stdc.h" 8 | #include 9 | 10 | #include "region.h" 11 | 12 | // Note: this code was written by Q, with minor edits by Q's human assistant PG 13 | 14 | // Encoding functions 15 | bool cbor_encode_uint(FILE *out, u64 value); 16 | bool cbor_encode_int(FILE *out, i64 value); 17 | bool cbor_encode_bool(FILE *out, bool value); 18 | bool cbor_encode_float(FILE *out, float value); 19 | bool cbor_encode_double(FILE *out, double value); 20 | bool cbor_encode_string(FILE *out, const char *s, u64 len); 21 | bool cbor_encode_array_start(FILE *out, u64 size); 22 | bool cbor_encode_tag(FILE *out, u8 type, u8 constructor, u8 len); 23 | bool cbor_encode_option(FILE *out, bool isSome); 24 | 25 | // Decoding functions 26 | bool cbor_decode_uint(FILE *in, u64 *value); 27 | bool cbor_decode_int(FILE *in, i64 *value); 28 | bool cbor_decode_bool(FILE *in, bool *value); 29 | bool cbor_decode_float(FILE *in, float *value); 30 | bool cbor_decode_double(FILE *in, double *value); 31 | bool cbor_decode_string(FILE *in, char **s, void*(alloc)(void*,size_t), void *arg); 32 | bool cbor_decode_array_start(FILE *in, u64 *size); 33 | bool cbor_decode_tag(FILE *in, u8 *type, u8 *constructor, u8 *len); 34 | bool cbor_decode_option(FILE *in, bool *isSome); 
35 | 36 | // Functions for Lean-generated code 37 | static inline bool String_ser(FILE *out, const char *s) { 38 | return cbor_encode_string(out, s, 0); 39 | } 40 | 41 | struct Prop {}; 42 | 43 | static inline bool Prop_ser(FILE *out, struct Prop p) { 44 | (void)out; 45 | (void)p; 46 | return true; 47 | } 48 | 49 | static inline bool Prop_des(FILE *out, struct region *region, struct Prop *p) { 50 | (void)out; 51 | (void)region; 52 | (void)p; 53 | return true; 54 | } 55 | 56 | bool Bool_des(FILE *out, struct region *region, bool *x); 57 | bool Nat_des(FILE *out, struct region *region, u32 *x); 58 | bool Int_des(FILE *out, struct region *region, i32 *x); 59 | bool Float_des(FILE *out, struct region *region, float *x); 60 | bool String_des(FILE *out, struct region *region, char **s); 61 | 62 | -------------------------------------------------------------------------------- /interop/klr/klir_common.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE.
4 | Authors: Paul Govereau, Sean McLaughlin 5 | */ 6 | #include "klir_common.hpp" 7 | namespace klr { 8 | 9 | extern "C" { 10 | // Decoding functions 11 | bool cbor_decode_uint(FILE *in, u64 *value); 12 | bool cbor_decode_int(FILE *in, i64 *value); 13 | bool cbor_decode_bool(FILE *in, bool *value); 14 | bool cbor_decode_float(FILE *in, float *value); 15 | bool cbor_decode_double(FILE *in, double *value); 16 | bool cbor_decode_string(FILE *in, char **s, void *(alloc)(void *, size_t), 17 | void *arg); 18 | bool cbor_decode_array_start(FILE *in, u64 *size); 19 | bool cbor_decode_tag(FILE *in, u8 *type, u8 *constructor, u8 *len); 20 | bool cbor_decode_option(FILE *in, bool *isSome); 21 | } 22 | 23 | bool deserialize_tag(FILE *in, u8 *type, u8 *constructor, u8 *len) { 24 | return cbor_decode_tag(in, type, constructor, len); 25 | } 26 | 27 | bool deserialize_array_start(FILE *in, u64 *size) { 28 | return cbor_decode_array_start(in, size); 29 | } 30 | 31 | bool deserialize_option(FILE *in, bool *isSome) { 32 | return cbor_decode_option(in, isSome); 33 | } 34 | 35 | Prop Prop_des(FILE *in) { return Prop(); } 36 | 37 | Bool Bool_des(FILE *in) { 38 | Bool res; 39 | if (!cbor_decode_bool(in, &res)) 40 | throw std::runtime_error("expecting Bool"); 41 | return res; 42 | } 43 | 44 | Nat Nat_des(FILE *in) { 45 | u64 res; 46 | if (!cbor_decode_uint(in, &res)) 47 | throw std::runtime_error("expecting Nat"); 48 | return res; 49 | } 50 | 51 | Int Int_des(FILE *in) { 52 | i64 res; 53 | if (!cbor_decode_int(in, &res)) 54 | throw std::runtime_error("expecting Int"); 55 | return res; 56 | } 57 | 58 | Float Float_des(FILE *in) { 59 | float res; 60 | if (!cbor_decode_float(in, &res)) 61 | throw std::runtime_error("expecting Float"); 62 | return res; 63 | } 64 | 65 | String String_des(FILE *in) { 66 | char *res = NULL; 67 | if (!cbor_decode_string(in, &res, NULL, NULL)) 68 | throw std::runtime_error("expecting String"); 69 | 70 | String result = res; 71 | free(res); // std::string make a 
copy 72 | return result; 73 | } 74 | 75 | } // namespace klr 76 | -------------------------------------------------------------------------------- /lake-manifest.json: -------------------------------------------------------------------------------- 1 | {"version": "1.1.0", 2 | "packagesDir": ".lake/packages", 3 | "packages": 4 | [{"url": "https://github.com/leanprover/SHerLOC.git", 5 | "type": "git", 6 | "subDir": null, 7 | "scope": "", 8 | "rev": "c74ae090d4326cca9ff636184c330a67ca039ef6", 9 | "name": "SHerLOC", 10 | "manifestFile": "lake-manifest.json", 11 | "inputRev": "c74ae090d4326cca9ff636184c330a67ca039ef6", 12 | "inherited": false, 13 | "configFile": "lakefile.lean"}, 14 | {"url": "https://github.com/leanprover/TensorLib.git", 15 | "type": "git", 16 | "subDir": null, 17 | "scope": "", 18 | "rev": "d8137a41053423de98c528d75a5fa1a01fa58567", 19 | "name": "TensorLib", 20 | "manifestFile": "lake-manifest.json", 21 | "inputRev": "v0.0.16", 22 | "inherited": false, 23 | "configFile": "lakefile.lean"}, 24 | {"url": "https://github.com/leanprover-community/plausible", 25 | "type": "git", 26 | "subDir": null, 27 | "scope": "", 28 | "rev": "a22e7c1fa7707fb7ea75f2f9fd6b14de2b7b87a9", 29 | "name": "plausible", 30 | "manifestFile": "lake-manifest.json", 31 | "inputRev": "v4.23.0", 32 | "inherited": false, 33 | "configFile": "lakefile.toml"}, 34 | {"url": "https://github.com/leanprover/lean4-cli.git", 35 | "type": "git", 36 | "subDir": null, 37 | "scope": "", 38 | "rev": "41c5d0b8814dec559e2e1441171db434fe2281cc", 39 | "name": "Cli", 40 | "manifestFile": "lake-manifest.json", 41 | "inputRev": "v4.23.0", 42 | "inherited": false, 43 | "configFile": "lakefile.toml"}, 44 | {"url": "https://github.com/leanprover-community/aesop", 45 | "type": "git", 46 | "subDir": null, 47 | "scope": "", 48 | "rev": "247ff80701c76760523b5d7c180b27b7708faf38", 49 | "name": "aesop", 50 | "manifestFile": "lake-manifest.json", 51 | "inputRev": "v4.23.0", 52 | "inherited": true, 53 | "configFile": 
"lakefile.toml"}, 54 | {"url": "https://github.com/leanprover-community/batteries", 55 | "type": "git", 56 | "subDir": null, 57 | "scope": "", 58 | "rev": "d117e2c28cba42e974bc22568ac999492a34e812", 59 | "name": "batteries", 60 | "manifestFile": "lake-manifest.json", 61 | "inputRev": "v4.23.0", 62 | "inherited": true, 63 | "configFile": "lakefile.toml"}], 64 | "name": "KLR", 65 | "lakeDir": ".lake"} 66 | -------------------------------------------------------------------------------- /interop/pyproject.toml: -------------------------------------------------------------------------------- 1 | # NB: All this setuptools stuff took about 1000 hours to figure out so never delete this. 2 | [build-system] 3 | requires = ["setuptools==68.0.0"] # version used by Brazil's Python3PBuildTool as of 2025.07.02 4 | build-backend = "setuptools.build_meta" 5 | 6 | [project] 7 | name = "klr-lang" 8 | version = "0.0.12" 9 | authors = [ 10 | {name = "Paul Govereau", email = "govereau@amazon.com"}, 11 | {name = "Sean McLaughlin", email = "seanmcl@amazon.com"}, 12 | ] 13 | description = "Intermediate language for tensor compilers" 14 | readme = "README.md" 15 | license = {text = "Apache-2.0"} 16 | keywords = ["trainium", "tpu", "pallas", "triton", "gpu"] 17 | 18 | # Note, dependencies here are "abstract" while the same 19 | # lines in requirements.txt are "pinned". 20 | # https://stackoverflow.com/questions/74508024/is-requirements-txt-still-needed-when-using-pyproject-toml 21 | # For now we'll just keep both, but maybe we can drop one or the 22 | # other. requirements.txt is nice to have for installing 23 | # deps from GitHub Actions. 24 | dependencies = [ 25 | "numpy", 26 | ] 27 | requires-python = ">= 3.9" 28 | 29 | [project.urls] 30 | Repository = "https://github.com/leanprover/KLR" 31 | 32 | [tool.pytest.ini_options] 33 | pythonpath = "."
# Needed for tests to pass 34 | testpaths = [ 35 | "test", 36 | ] 37 | 38 | [tool.setuptools] 39 | packages = ["klr"] 40 | 41 | [project.scripts] 42 | klr = "klr._cli:run_klr" 43 | klr-gather = "klr._cli:gather" 44 | 45 | [tool.cibuildwheel] 46 | # Skip unsupported python versions as well as 32-bit platforms, which are not supported anymore. 47 | skip = "pp* *-win32 *-manylinux_i686 *-musllinux_*" 48 | # Let's use a more recent version of the manylinux image for more modern compilers 49 | 50 | # Build fails without this when we updated to Lean 4.16.0 51 | # auditwheel: error: cannot repair "/tmp/cibuildwheel/built_wheel/klr_lang-0.0.7-cp38-cp38-linux_x86_64.whl" to "manylinux2014_x86_64" ABI because of the presence of too-recent versioned symbols. You'll need to compile the wheel on an older toolchain. 52 | # https://github.com/pypa/cibuildwheel/issues/1982 53 | # https://github.com/Blosc/python-blosc2/blob/99525d3141ac802e60b3d7bea4dabd1f2ae92b8f/pyproject.toml#L54-L55 54 | manylinux-x86_64-image = "manylinux_2_28" 55 | manylinux-aarch64-image = "manylinux_2_28" 56 | -------------------------------------------------------------------------------- /KLR/K/K3/DotK3.lean: -------------------------------------------------------------------------------- 1 | import KLR.K.K3.AST 2 | import SHerLOC.Analysis.Graph 3 | 4 | open StableHLO.Analysis (Graph Edge Vertex) 5 | 6 | namespace KLR.K.K3 7 | 8 | /- This module outputs a K3 program as a graph in DOT format -/ 9 | 10 | /- DOT identifiers can't start with numbers, so we need to sanitize them -/ 11 | def sanitize (var : String) : String := 12 | s!"node_{var}" 13 | 14 | /- Makes a graph node for an argument to the kernel -/ 15 | def makeArgNode (argName : String) : Vertex := 16 | .mk 17 | (sanitize argName) 18 | (.mk [ 19 | ("label", s!"arg\\n{argName}"), 20 | ("shape", "box"), 21 | ("style", "filled"), 22 | ("fillcolor", "lightgray"), 23 | ("color", "gray") 24 | ]) 25 | 26 | /- Makes a graph node for an operator, with the 
output tensor as the label -/ 27 | def makeOpNode (op : OperatorK3) (output : TensorK3) : Vertex := 28 | let attrs := match op with 29 | | .matmulP .. => [ 30 | ("style", "filled"), 31 | ("fillcolor", "lightpink"), 32 | ("color", "red") 33 | ] 34 | | _ => [] 35 | .mk 36 | (sanitize output.name) 37 | (.mk ([ 38 | ("label", s!"{name op}\\n{output.name}\n{output.type.shape}"), 39 | ] ++ attrs)) 40 | 41 | /- Makes a graph edge from one tensor to another. The names should be sanitized -/ 42 | def makeEdge (source : String) (dest : String) : Edge := 43 | .mk 44 | source 45 | dest 46 | (.mk []) 47 | 48 | /- Produces a graph from a K3 function. -/ 49 | def graph (f : FunctionK3) : Graph := Id.run do 50 | let mut vertices := [] 51 | let mut edges := [] 52 | for op in f.statements do 53 | let targets := targets op 54 | let tensorDeps := dependencies op 55 | let scalarDeps := scalarDependencies op |>.filterMap fun dep => 56 | match dep with 57 | | .vector name size dtype => 58 | .some { name, type := ⟨⟨[size]⟩, dtype⟩ } 59 | | _ => none 60 | let deps := tensorDeps ++ scalarDeps 61 | for target in targets do 62 | vertices := (makeOpNode op target) :: vertices 63 | for dep in deps do 64 | for target in targets do 65 | edges := (makeEdge (sanitize dep.name) (sanitize target.name)) :: edges 66 | 67 | for arg in f.inputs do 68 | vertices := (makeArgNode arg.name) :: vertices 69 | 70 | ⟨f.name, vertices, edges⟩ 71 | 72 | end KLR.K.K3 73 | -------------------------------------------------------------------------------- /interop/test/test_list.py: -------------------------------------------------------------------------------- 1 | # This file exercises the Lean partial evaluator with 2 | # a set of basic unit tests. Each function is parsed, 3 | # handed to Lean, where it is checked and reduced to KLR. 
4 | 5 | import os 6 | import pytest 7 | from runner import * 8 | import numpy as np 9 | 10 | from klr.frontend import Kernel 11 | 12 | # Success cases 13 | # (these functions should load and trace to KLR) 14 | 15 | def expr_list(t): 16 | assert [1,2,False] 17 | assert not [] 18 | 19 | def modify(): 20 | l = [1] 21 | x = l 22 | assert x == [1] 23 | l[0] = 5 24 | assert l[0] == 5 25 | assert l == [5] 26 | assert x == [5] 27 | 28 | def append(): 29 | l = [] 30 | l.append(2) 31 | assert l[0] == 2 32 | l.append(5) 33 | assert l == [2,5] 34 | 35 | def clear(): 36 | l = [1,2,3] 37 | l.clear() 38 | assert l == [] 39 | 40 | def copy(): 41 | l = [1,2,3] 42 | x = l 43 | y = l.copy() 44 | x[0] = 4 45 | assert l == [4,2,3] 46 | assert x == l 47 | assert y == [1,2,3] 48 | 49 | def count(): 50 | l = [1,(2,3),4] 51 | assert l.count() == 3 52 | l.clear() 53 | assert l.count() == 0 54 | assert [[2]].count() == 1 55 | 56 | def extend(): 57 | l = [] 58 | l.extend((1,2,3)) 59 | assert l == [1,2,3] 60 | l.extend([4]) 61 | assert l == [1,2,3,4] 62 | 63 | def index(): 64 | l = [1,2,3] 65 | i = l.index(3) 66 | assert i == 2 67 | assert l.index(7) == None 68 | 69 | def pop(): 70 | l = [1,2,3] 71 | x = l.pop() 72 | assert x == 3 73 | assert l == [1,2] 74 | 75 | def remove(): 76 | l = [3,1,3,2,3,3] 77 | l.remove(3) 78 | assert l == [1,2] 79 | 80 | def reverse(): 81 | l = [1,2,4,3] 82 | l.reverse() 83 | assert l == [3,4,2,1] 84 | 85 | # test each function in turn 86 | @pytest.mark.parametrize("f", [ 87 | expr_list, 88 | modify, 89 | append, 90 | clear, 91 | copy, 92 | count, 93 | extend, 94 | index, 95 | pop, 96 | remove, 97 | reverse, 98 | ]) 99 | def test_succeed(f): 100 | run_success(f, ()) 101 | 102 | # Failing cases 103 | # (These functions are expected to fail elaboration to KLR) 104 | 105 | def out_of_bounds(): 106 | l = [1,2] 107 | l[3] = 0 108 | 109 | @pytest.mark.parametrize("f", [ 110 | out_of_bounds, 111 | ]) 112 | def test_fails(f): 113 | run_fail(f, ()) 114 | 
-------------------------------------------------------------------------------- /KLR/Util/Plausible.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import Plausible 18 | 19 | /- 20 | Common instances for Plausible 21 | -/ 22 | namespace KLR.Util.Plausible 23 | 24 | instance : Plausible.Shrinkable Int8 := ⟨ fun _ => [] ⟩ 25 | instance : Plausible.SampleableExt Int8 := 26 | Plausible.SampleableExt.mkSelfContained do 27 | let x <- Plausible.Gen.chooseNatLt 0 UInt8.size (Nat.zero_lt_succ _) 28 | return x.val.toInt8 29 | 30 | instance : Plausible.Shrinkable Int16 := ⟨ fun _ => [] ⟩ 31 | instance : Plausible.SampleableExt Int16 := 32 | Plausible.SampleableExt.mkSelfContained do 33 | let x <- Plausible.Gen.chooseNatLt 0 UInt16.size (Nat.zero_lt_succ _) 34 | return x.val.toInt16 35 | 36 | instance : Plausible.Shrinkable Int32 := ⟨ fun _ => [] ⟩ 37 | instance : Plausible.SampleableExt Int32 := 38 | Plausible.SampleableExt.mkSelfContained do 39 | let x <- Plausible.Gen.chooseNatLt 0 UInt32.size (Nat.zero_lt_succ _) 40 | return x.val.toInt32 41 | 42 | instance : Plausible.Shrinkable Int64 := ⟨ fun _ => [] ⟩ 43 | instance : Plausible.SampleableExt Int64 := 44 | Plausible.SampleableExt.mkSelfContained do 45 | let x <- Plausible.Gen.chooseNatLt 0 UInt64.size (Nat.zero_lt_succ _) 46 | return x.val.toInt64 47 | 48 | 49 | 
instance : Plausible.Shrinkable Float32 := ⟨ fun _ => [] ⟩ 50 | instance : Plausible.SampleableExt Float32 := 51 | Plausible.SampleableExt.mkSelfContained do 52 | let x <- Plausible.Gen.chooseNatLt 0 UInt32.size (Nat.zero_lt_succ _) 53 | return Float32.ofBits x.val.toUInt32 54 | 55 | instance : Plausible.Shrinkable Float := ⟨ fun _ => [] ⟩ 56 | instance : Plausible.SampleableExt Float := 57 | Plausible.SampleableExt.mkSelfContained do 58 | let x <- Plausible.Gen.chooseNatLt 0 UInt64.size (Nat.zero_lt_succ _) 59 | return Float.ofBits x.val.toUInt64 60 | -------------------------------------------------------------------------------- /interop/test/examples/rmsnorm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024, Amazon.com. All Rights Reserved 3 | 4 | RMSNorm NKI kernel implementation. 5 | 6 | """ 7 | 8 | import math 9 | 10 | from apis import * 11 | 12 | def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor): 13 | # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor 14 | # Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor)) 15 | # and N = a_tensor.shape[1] 16 | # Reduction (mean) is performed in the free (2nd) dimension 17 | 18 | # Make sure shapes match 19 | assert a_tensor.shape[1] == g_tensor.shape[0] 20 | assert a_tensor.shape == out_tensor.shape 21 | 22 | # Generate tensor indices to index input tensor 23 | ix = nl.arange(128)[:, None] 24 | iw = nl.arange(1)[:, None] 25 | iy = nl.arange(a_tensor.shape[1])[None, :] 26 | 27 | num_rows = a_tensor.shape[0] 28 | 29 | # Load RMSNorm weight once, reused by rows/tiles of a_tensor 30 | g_tile = nl.load(g_tensor.reshape((1, g_tensor.shape[0]))[iw, iy]) 31 | 32 | # Process 128 rows at a time due to 128-partition tile size limitation 33 | # Since we're not reducing across the first dimension 34 | # Tiles can be processed independently 35 | for i in nl.affine_range(math.ceil(a_tensor.shape[0]/128)): 36 | 37 | # Load input data from external 
memory to on-chip memory 38 | a_tile = nl.load(a_tensor[i * 128 + ix, iy], 39 | mask=(i * 128 + ix < num_rows)) 40 | 41 | # Compute element-wise square of a_tensor 42 | in_square = nl.square(a_tile) 43 | 44 | # Calculate sum of squared elements, along last dimension 45 | square_sum = nl.sum(in_square, axis=[1]) 46 | 47 | # Scale and get a reciprocal 48 | mean = square_sum / a_tensor.shape[1] 49 | 50 | # Take square root of mean and then reciprocal with 51 | # rsqrt API (one ISA instruction) 52 | rms_reciprocal = nl.rsqrt(mean) 53 | 54 | # Scale the input tensor 55 | out_tile = nl.multiply(a_tile, rms_reciprocal) 56 | 57 | # Broadcast weight along first axis to match tensor shape 58 | # num_rows_active = min(num_rows - i * 128, 128) 59 | g_bcast = g_tile.broadcast_to((128, g_tensor.shape[0])) 60 | 61 | # Multiply with the RMSNorm weight 62 | out_tile[...] = nl.multiply(out_tile, g_bcast, 63 | mask=(i * 128 + ix < num_rows)) 64 | 65 | # store the addition results back to external memory (out_tensor) 66 | nl.store(out_tensor[i * 128 + ix, iy], value=out_tile, 67 | mask=(i * 128 + ix < num_rows)) 68 | -------------------------------------------------------------------------------- /interop/test/examples/transpose2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024, Amazon.com. All Rights Reserved 3 | 4 | NKI baremetal implementation for transpose2d NKI tutorial. 5 | """ 6 | from apis import * 7 | 8 | def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D): 9 | """ 10 | NKI kernel to reorder the elements on axis[1] of the input tensor. 11 | 12 | Every row of the input tensor is a flattened row-major 2D matrix. 13 | The shape2D argument defines the dimensions of the flattened matrices (#rows,#cols). 14 | Our goal in this kernel is to transpose these flattened 2D matrices, i.e. make them (#cols,#rows). 
15 | 16 | Example: 17 | in_tensor = [a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3] 18 | shape2D = (3,4) 19 | this means that in_tensor has 3 rows and 4 columns, i.e. can be represented as: 20 | [a0,a1,a2,a3] 21 | [b0,b1,b2,b3] 22 | [c0,c1,c2,c3] 23 | after transpose, we expect to get: 24 | [a0,b0,c0] 25 | [a1,b1,c1] 26 | [a2,b2,c2] 27 | [a3,b3,c3] 28 | Thus, out_tensor is expected to be [a0,b0,c0,a1,b1,c1,a2,b2,c2,a3,b3,c3] 29 | 30 | Args: 31 | in_tensor: an input tensor 32 | shape2D: tuple representing the dimensions to be transposed: (#rows, #cols) 33 | out_tensor: an output (transposed) tensor 34 | """ 35 | # Gather input shapes 36 | sz_p, _ = in_tensor.shape 37 | 38 | # Load input data from external memory to on-chip memory 39 | in_tile = nl.load(in_tensor) 40 | 41 | # Performing f1/f2 transpose 42 | # ========================== 43 | # The desired transpose pattern is provided as an input: 44 | sz_f1, sz_f2 = shape2D 45 | 46 | # We're going to need 3 indices to perform f1:f2 transpose. 47 | # - i_p0 is the parallel index 48 | # - i_f1 and i_f2 are both free-dim indices, and will be used to transpose between the f1/f2 axes 49 | i_p0 = nl.arange(sz_p)[:, None, None] 50 | i_f1 = nl.arange(sz_f1)[None, :, None] 51 | i_f2 = nl.arange(sz_f2)[None, None, :] 52 | 53 | # Perform the transposition via a SBUF-to-SBUF copy, with access-pattern manipulation 54 | # Note that we have 2D tensors and 3 indices, since we need to represent a 2D access pattern *per partition* 55 | # RHS traverses an F1 x F2 matrix in a row major manner 56 | # LHS traverses an F2 x F1 (new) matrix in a row major manner 57 | out_tile = nl.ndarray(shape=(sz_p, sz_f2*sz_f1), dtype=out_tensor.dtype) 58 | out_tile[i_p0, i_f2*sz_f1+i_f1] = nl.copy(in_tile[i_p0, i_f1*sz_f2+i_f2]) 59 | 60 | # Finally, we store out_tile to external memory 61 | nl.store(out_tensor, value=out_tile) 62 | -------------------------------------------------------------------------------- /bin/rename-wheels: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e -u -o pipefail 3 | trap "kill 0" SIGINT SIGTERM 4 | 5 | # Used to rename wheels before uploading wheels to pypi 6 | 7 | # pypi can't handle the 99.0 version tags. We need to move those back to versions it is OK with 8 | # While sketchy in the extreme, this should hopefully work. Why? 9 | # - The step that fails without the 99.0 bump is 'delocate' which looks for libs 10 | # linked in at different versions. 11 | # - Lean/Lake don't have external dependencies besides system libs 12 | # https://leanprover.zulipchat.com/#narrow/channel/424609-lean-at-aws/topic/.E2.9C.94.20Build.20static.20binary 13 | # I think this is likely a Lean/Lake bug: https://github.com/leanprover/lean4/pull/6631/files 14 | # - We are setting the version back to the OS that it was compiled on. 15 | # - Counterpoint to things being OK: if it's true that there are only system libraries, then 16 | # `delocate-wheel` shouldn't do anything, but then it's not clear why it's complaining 17 | # about the 99.0 min version of Lean. 
18 | 19 | 20 | # Example output of ls -1: 21 | # klr-0.0.3-cp310-cp310-macosx_99_0_arm64.whl 22 | # klr-0.0.3-cp310-cp310-macosx_99_0_x86_64.whl 23 | # klr-0.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 24 | # klr-0.0.3-cp311-cp311-macosx_99_0_arm64.whl 25 | # klr-0.0.3-cp311-cp311-macosx_99_0_x86_64.whl 26 | # klr-0.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 27 | # klr-0.0.3-cp312-cp312-macosx_99_0_arm64.whl 28 | # klr-0.0.3-cp312-cp312-macosx_99_0_x86_64.whl 29 | # klr-0.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 30 | # klr-0.0.3-cp313-cp313-macosx_99_0_arm64.whl 31 | # klr-0.0.3-cp313-cp313-macosx_99_0_x86_64.whl 32 | # klr-0.0.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 33 | # klr-0.0.3-cp38-cp38-macosx_99_0_arm64.whl 34 | # klr-0.0.3-cp38-cp38-macosx_99_0_x86_64.whl 35 | # klr-0.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 36 | # klr-0.0.3-cp39-cp39-macosx_99_0_arm64.whl 37 | # klr-0.0.3-cp39-cp39-macosx_99_0_x86_64.whl 38 | # klr-0.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 39 | # Quote expansions so each filename reaches mv as a single argument (no word splitting). 40 | x86_suffix=macosx_99_0_x86_64 41 | arm_suffix=macosx_99_0_arm64 42 | for file in *; do 43 | if [[ "$file" == *"$x86_suffix"* ]]; then 44 | newname=$(echo "$file" | sed "s/$x86_suffix/macosx_13_0_x86_64/g") 45 | mv -- "$file" "$newname" 46 | echo "Renamed: $file -> $newname" 47 | fi 48 | if [[ "$file" == *"$arm_suffix"* ]]; then 49 | newname=$(echo "$file" | sed "s/$arm_suffix/macosx_14_0_arm64/g") 50 | mv -- "$file" "$newname" 51 | echo "Renamed: $file -> $newname" 52 | fi 53 | done 54 | -------------------------------------------------------------------------------- /KLR/Semantics/Float.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE.
4 | Authors: Markus de Medeiros 5 | -/ 6 | import Init.Data.Int.Basic 7 | import KLR.Core.Basic 8 | import KLR.Util 9 | 10 | /- 11 | ## Floating point typeclasses 12 | 13 | The operational semantics, as well as the logic, are parameterized by a representation 14 | of floating point numbers. 15 | -/ 16 | 17 | /-- Abstract representation of a floating point number. This class includes only 18 | those syntax classes necessary to express floating point calculations, extensions 19 | of `FloatRep` are used to state that the floats have different properties. -/ 20 | class FloatRep (f : Type _) extends Add f, Inhabited f where 21 | 22 | /- TODO: A floating point representation contains a well-behaved family of integers -/ 23 | /- TODO: Assoc, Comm, etc -/ 24 | /- TODO: Fields etc... can we axiomatize ℝ in a typeclass rather than import mathlib? -/ 25 | 26 | /- 27 | ## Floating point instances 28 | -/ 29 | 30 | /-- Symbolic representation of floats 31 | These floating point numbers are uninterpreted, so two SymbolicFloat calculations are 32 | equal iff they perform the same operations in the same order. -/ 33 | inductive SymbolicFloat where 34 | | value : SymbolicFloat 35 | | add : SymbolicFloat → SymbolicFloat → SymbolicFloat 36 | deriving BEq 37 | 38 | instance : FloatRep SymbolicFloat where 39 | default := .value 40 | add := .add 41 | 42 | 43 | /- Indexed families of floating point types, used to represent multiple floating point types within 44 | the same program. 45 | TODO: Could also add coercions if that's what the hw does 46 | -/ 47 | inductive FloatFamily {I : Type _} (ty : I → Type _) where 48 | | err : FloatFamily ty 49 | | ftype : (i : I) → (v : ty i) → FloatFamily ty 50 | 51 | class abbrev FloatFamilyIndex (I : Type _) := BEq I, LawfulBEq I 52 | 53 | /-- Pointwise lifting of a binop to a function of FloatFamilies, error when applied to differently 54 | tagged floating point types.
-/ 55 | def lift_binop {I : Type _} {ty : I → Type _} [FloatFamilyIndex I] 56 | (op : {i : I} → ty i → ty i → ty i) : 57 | FloatFamily ty → FloatFamily ty → FloatFamily ty 58 | | .ftype i1 v1, .ftype i2 v2 => 59 | dite (i2 == i1) 60 | (fun H => .ftype i1 (op v1 ((congrArg ty <| LawfulBEq.eq_of_beq H) ▸ v2))) 61 | (fun _ => .err) 62 | | _, _ => .err 63 | 64 | instance (I : Type _) (ty : I → Type _) [FloatFamilyIndex I] [∀ {i : I}, FloatRep (ty i)] : 65 | FloatRep (FloatFamily ty) where 66 | default := .err 67 | add := lift_binop Add.add 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kernel Language Representation (KLR) 2 | 3 | This repository contains an implementation of KLR, a core language and 4 | elaborators for machine learning kernels. The goal of KLR is to define a common 5 | representation for kernel functions with a precise formal semantics along with 6 | translations from common kernel languages to the KLR core language. The initial 7 | focus of KLR is the 8 | [Neuron Kernel Interface](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/index.html), 9 | and the [Trainium](https://aws.amazon.com/ai/machine-learning/trainium/) hardware. 10 | 11 | # Building on Amazon Linux 2023 12 | 13 | To build the FFI utilities like Archive, please use 14 | 15 | LIBRARY_PATH=/usr/lib64 lake build 16 | 17 | (or export the path) to get linking to work correctly. 18 | 19 | # Quick Start 20 | 21 | The easiest way to get started using KLR is to install the python package 22 | using `pip`: 23 | 24 | ``` 25 | # pip install klr-lang 26 | # klr gather test.py test_kernel -o test_kernel.klr 27 | # klr trace test_kernel.klr 28 | ``` 29 | 30 | For more information see the [Getting Started Guide](docs/getting_started.md) 31 | 32 | # Interop with Python 33 | 34 | The KLR compiler starts with Python code (e.g. 
NKI kernels), and converts this 35 | into an instance of the abstract syntax tree found in `KLR/Python.lean`. The 36 | current version of KLR uses the CPython parser to do this conversion. The 37 | parsing process is called "gather" and involves the following steps: 38 | 39 | 1. Load the Python interpreter with our custom CPython extension module 40 | 2. Find the kernel function and extract its source code 41 | 3. Run the CPython parser 42 | 4. Transform the CPython AST to our Python AST 43 | 5. Repeat steps 2-4 for each newly found reference 44 | 6. Serialize the AST to the KLR on-disk format 45 | 7. From Lean, deserialize the Python AST from the on-disk format 46 | 47 | This process is complex and brittle, and will be replaced by a proper (pure 48 | Lean) parser in future versions of KLR. 49 | 50 | # Steps to make a new version/wheel 51 | 52 | 1. Bump the build or minor version in 53 | - interop/pyproject.toml (Deployment to PyPI will fail if you forget this.) 54 | - Main.klrCmd (Nothing will break if you don't, but we'd like to be consistent) 55 | 2. Create a git tag of the form v1.2.3 and push it to the KLR repo 56 | 57 | This should trigger a build that uploads the artifacts to PyPI. 58 | 59 | # Adding a new Lake package 60 | 61 | If you want to add a new directory with its own lakefile, 62 | please ensure you add a (relative) symlink to lean-toolchain in that directory. 63 | Otherwise the VSCode plugin will use the top-level directory's 64 | lakefile. 65 | -------------------------------------------------------------------------------- /interop/klr/README.md: -------------------------------------------------------------------------------- 1 | # C/C++ Front-end 2 | 3 | This directory contains a C/C++ version of the KLR front-end for NKI. The main 4 | entry point is `frontend.c` which defines a Python extension module and a type 5 | called `Kernel` which provides the API for the front-end.
The typical use of 6 | this API will create a new `Kernel` object from a user's Python function, 7 | serialize the result, and later deserialize to Python types: 8 | 9 | ```python 10 | import frontend 11 | K = frontend.Kernel(kernel) 12 | K.specialize(arguments) 13 | data = K.serialize() 14 | ... 15 | nki_ast = frontend.deserialize(data) 16 | ``` 17 | 18 | ## Lean Generated Files 19 | 20 | The following files have been generated from Lean sources. 21 | 22 | | C File | Lean Source | Method | 23 | |-|-|-| 24 | | ast_common.h | KLR/Serde/File.lean | KLR/Extract/C | 25 | | ast_file.h | KLR/File.lean | KLR/Extract/C | 26 | | ast_python_core.h | KLR/Python.lean | KLR/Extract/C | 27 | | ast_nki.h | KLR/NKI/Basic.lean | KLR/Extract/C | 28 | | ast_nki.py | KLR/NKI/Basic.lean | KLR/Extract/Python | 29 | | serde_common.[hc] | KLR/Serde/File.lean | KLR/Extract/Serde | 30 | | serde_file.[hc] | KLR/File.lean | KLR/Extract/Serde | 31 | | serde_python_core.[hc] | KLR/Python.lean | KLR/Extract/Serde | 32 | | serde_nki.[hc] | KLR/NKI/Basic.lean | KLR/Extract/Serde | 33 | 34 | ## CPython Sources 35 | 36 | The files in the `peg_parser` directory are from the CPython sources. These 37 | files have been lightly modified for use in NKI. The modifications are mostly 38 | marking functions as `static` and removing unused code. See comments in 39 | `peg_parser.c` for more details. 40 | 41 | ### Generating PEG Parser files 42 | 43 | Several of the source files are generated from the parser generator in the 44 | CPython source tree. The inputs for generating these files are the `Tokens` and 45 | `python.gram` files (contained here) and the `Python.asdl` file in the CPython 46 | sources. 47 | 48 | The following commands can be used to generate derived sources.
49 | 50 | Files: `peg_parser/ast_python.c` and `ast_python.h` 51 | ```sh 52 | $PYSRC/Parser/asdl_c.py -d $PYSRC/Parser/Python.asdl \ 53 | -C peg_parser/ast_python.c -H ast_python.h -i ignored 54 | rm ignored 55 | ``` 56 | 57 | Files: `peg_parser/token.c` 58 | ```sh 59 | python $PYSRC/Tools/build/generate_token.py h Tokens tmp1 60 | python $PYSRC/Tools/build/generate_token.py c Tokens tmp2 61 | cat tmp1 tmp2 > peg_parser/token.c 62 | rm tmp1 tmp2 63 | ``` 64 | 65 | Files: `peg_parser/parser.c` 66 | ```sh 67 | PYTHONPATH=$PYSRC/Tools/peg_generator \ 68 | python3 -m pegen -q c python.gram Tokens -o peg_parser/parser.c 69 | ``` 70 | 71 | -------------------------------------------------------------------------------- /KLR/Extract/Extract.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -/ 16 | 17 | import Extract.Basic 18 | import Extract.C 19 | import Extract.Cpp 20 | import Extract.Python 21 | import Extract.Serde 22 | import Extract.SerdeCpp 23 | import Extract.ToPython 24 | import Extract.ASDL 25 | import Lean 26 | 27 | namespace Extract 28 | open Lean Meta 29 | -- Run `m` with its standard output redirected into `file`, which is opened in write mode. 30 | private def withFile (file : String) (m : MetaM Unit) : MetaM Unit := do 31 | let h <- IO.FS.Handle.mk file IO.FS.Mode.write 32 | IO.withStdout (.ofHandle h) m 33 | 34 | private def dir := "interop/klr" 35 | 36 | -- Note: please leave commented-out items as we may need to bring these 37 | -- back in follow-up commits. 38 | run_meta do 39 | --withFile s!"{dir}/ast_common.h" C.generateCommonAST 40 | --withFile s!"{dir}/ast_file.h" C.generateFileAST 41 | --withFile s!"{dir}/ast_python_core.h" C.generatePythonAST 42 | --withFile s!"{dir}/ast_nki.h" C.generateNkiAST 43 | --withFile s!"{dir}/ast_nki.py" Python.generateNkiAST 44 | --withFile s!"{dir}/ast_klir.h" C.generateKlrAST 45 | --withFile s!"{dir}/serde_common.h" Serde.generateCommonH 46 | --withFile s!"{dir}/serde_common.c" Serde.generateCommonC 47 | --withFile s!"{dir}/serde_file.h" Serde.generateFileH 48 | --withFile s!"{dir}/serde_file.c" Serde.generateFileC 49 | --withFile s!"{dir}/serde_python_core.h" Serde.generatePythonH 50 | --withFile s!"{dir}/serde_python_core.c" Serde.generatePythonC 51 | ---withFile s!"{dir}/serde_nki.h" Serde.generateNkiH 52 | ---withFile s!"{dir}/serde_nki.c" Serde.generateNkiC 53 | ---withFile s!"{dir}/serde_klir.h" Serde.generateKlrH 54 | ---withFile s!"{dir}/serde_klir.c" Serde.generateKlrC 55 | --withFile s!"{dir}/topy_nki.h" ToPython.generateNkiH 56 | --withFile s!"{dir}/topy_nki.c" ToPython.generateNkiC 57 | -- C++ 58 | withFile s!"{dir}/klir_ast.hpp" Cpp.generateKlrAST 59 | withFile s!"{dir}/klir_pretty_print.hpp" Cpp.generateKlrPrettyPrintHeader 60 | withFile s!"{dir}/klir_pretty_print.cpp" 
Cpp.generateKlrPrettyPrint 61 | withFile s!"{dir}/klir_serde.hpp" SerdeCpp.generateKlrH
withFile s!"{dir}/klir_serde.cpp" SerdeCpp.generateKlrC 63 | withFile s!"{dir}/NKI.asdl" ASDL.generateNkiAST 64 | -------------------------------------------------------------------------------- /KLR/Py/Util.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import Lean 18 | 19 | open Lean (ToJson FileMap) 20 | 21 | deriving instance ToJson for String.Pos 22 | 23 | namespace KLR.Py 24 | 25 | structure Span where 26 | pos : String.Pos := {} 27 | stopPos : String.Pos := {} 28 | deriving ToJson, Repr, Inhabited, BEq 29 | 30 | structure FileInfo where 31 | content : String 32 | fileMap : FileMap 33 | fileName : String 34 | 35 | def FileInfo.formatError (f : FileInfo) (pre : String) (msg : String) (span : Span) : String := 36 | let fileMap := f.fileMap 37 | let input := f.content 38 | 39 | let { line, column } := fileMap.toPosition span.pos 40 | let { line := lineEnd, column := columnEnd } := f.fileMap.toPosition (span.stopPos) 41 | 42 | let startPos := fileMap.ofPosition { line := line, column := 0 } 43 | let endPos := fileMap.ofPosition { line := lineEnd + 1, column := 0 } 44 | let code := input.extract startPos endPos 45 | let lines := (code.split (· == '\n')).filter (not ∘ String.isEmpty) 46 | let markedLines : List String := (lines.zip (List.range lines.length)).map ( 47 | fun (line, i) => 48 | let markStart := if 
i == 0 then column else 0 49 | let markEnd := if i == lines.length - 1 then columnEnd else line.length 50 | let pre := "".pushn ' ' markStart 51 | let mark := "".pushn '^' (markEnd - markStart) 52 | line ++ "\n" ++ pre ++ mark 53 | ) 54 | let markedCode := Std.Format.nest 2 ("\n" ++ Std.Format.joinSep markedLines "\n") 55 | 56 | let posMsg := s!"{pre}: {f.fileName}:{line}:{column + 1}:" 57 | let errMsg := s!"{msg}" 58 | s!"{posMsg}{markedCode}\n{errMsg}" 59 | 60 | namespace _Debug 61 | 62 | def input := "1 63 | 23456 64 | 79 65 | a 66 | 80000 67 | bcd" 68 | def pos : Span := ⟨input.find (· == '3'), input.next <| input.find (· == '8')⟩ 69 | def info : FileInfo := ⟨input, input.toFileMap, "/usr/code/my.py"⟩ 70 | def err := info.formatError "SyntaxError" "invalid syntax" pos 71 | 72 | /-- 73 | info: SyntaxError: /usr/code/my.py:2:2: 74 | 23456 75 | ^^^^ 76 | 79 77 | ^^^^ 78 | a 79 | ^^^ 80 | 80000 81 | ^^^ 82 | invalid syntax 83 | -/ 84 | #guard_msgs in #eval IO.println err 85 | 86 | end _Debug 87 | -------------------------------------------------------------------------------- /interop/klr/peg_parser/compat.c: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | */ 5 | 6 | // Forward declare private APIs that used to be in public headers, but later moved to internal headers 7 | Py_ssize_t _PyUnicode_ScanIdentifier(PyObject *); // hidden in https://github.com/python/cpython/commit/8a73b57b 8 | PyObject* _PyUnicode_DecodeUnicodeEscapeInternal(const char *, Py_ssize_t, const char *, Py_ssize_t *, const char **); // hidden in https://github.com/python/cpython/commit/d8c5d76d 9 | PyObject* _PyBytes_DecodeEscape(const char *, Py_ssize_t, const char *, const char **); // hidden in https://github.com/python/cpython/commit/7d41ead9 10 | 11 | // These functions were introduced in Python 3.10 and are used by the parser. 12 | #if PY_MINOR_VERSION < 10 13 | static inline PyObject *_Py_NewRef(PyObject *obj) { 14 | Py_INCREF(obj); 15 | return obj; 16 | } 17 | #define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) 18 | 19 | static PyObject *_PyImport_GetModuleAttrString(const char *modname, 20 | const char *attrname) { 21 | PyObject *mod = PyImport_ImportModule(modname); 22 | if (mod == NULL) 23 | return NULL; 24 | 25 | PyObject *result = PyObject_GetAttrString(mod, attrname); 26 | Py_DECREF(mod); 27 | return result; 28 | } 29 | #else 30 | PyObject* _PyImport_GetModuleAttrString(const char *, const char *); // hidden in https://github.com/python/cpython/commit/2e92edbf 31 | #endif 32 | 33 | // An alternate implementation of PyArena which uses our region allocator. 
34 | 35 | typedef struct _arena { 36 | struct region *region; 37 | PyObject *objects; 38 | } PyArena; 39 | 40 | static void _PyArena_Free(PyArena *arena) { 41 | if (arena) { 42 | if (arena->region) 43 | region_destroy(arena->region); 44 | if (arena->objects) 45 | Py_DECREF(arena->objects); 46 | PyMem_Free(arena); 47 | } 48 | } 49 | 50 | static PyArena *_PyArena_New(void) { 51 | PyArena *arena = (PyArena *)PyMem_Malloc(sizeof(PyArena)); 52 | if (!arena) 53 | return (PyArena *)PyErr_NoMemory(); 54 | 55 | arena->region = region_create(); 56 | arena->objects = PyList_New(0); 57 | if (!arena->region || !arena->objects) { 58 | _PyArena_Free(arena); 59 | return (PyArena *)PyErr_NoMemory(); 60 | } 61 | return arena; 62 | } 63 | 64 | static void *_PyArena_Malloc(PyArena *arena, size_t size) { 65 | void *p = region_alloc(arena->region, size); 66 | if (!p) 67 | return PyErr_NoMemory(); 68 | return p; 69 | } 70 | 71 | static int _PyArena_AddPyObject(PyArena *arena, PyObject *obj) { 72 | int r = PyList_Append(arena->objects, obj); 73 | if (r >= 0) 74 | Py_DECREF(obj); 75 | return r; 76 | } 77 | -------------------------------------------------------------------------------- /interop/test/test_dict.py: -------------------------------------------------------------------------------- 1 | # This file exercises the Lean partial evaluator with 2 | # a set of basic unit tests. Each function is parsed, 3 | # handed to Lean, where it is checked and reduced to KLR. 
4 | 5 | import os 6 | import pytest 7 | from runner import * 8 | 9 | from klr.frontend import Kernel 10 | 11 | # Success cases 12 | # (these functions should load and trace to KLR) 13 | 14 | def expr_dict(): 15 | assert {'a':1} 16 | assert dict() 17 | assert dict([('a', 1), ('b', 2)]) 18 | 19 | def modify(): 20 | d = {'a':1} 21 | assert d['a'] == 1 22 | d['a'] = 2 23 | assert d['a'] == 2 24 | d['b'] = 7 25 | assert d['b'] == 7 26 | assert d == {'a':2, 'b':7} 27 | 28 | def clear(): 29 | d = {'a':1} 30 | d.clear() 31 | assert d == dict() 32 | 33 | def copy(): 34 | d = {'a':1} 35 | d2 = d.copy() 36 | d['b'] = 4 37 | assert d == {'a':1, 'b':4} 38 | assert d2 == {'a':1} 39 | 40 | def get(): 41 | d = dict() 42 | d['a'] = 1 43 | assert d.get('a') == 1 44 | assert d.get('a', 5) == 1 45 | assert d.get('b') == None 46 | assert d.get('b', 5) == 5 47 | 48 | def items(): 49 | d = {'a':1, 'b':2} 50 | assert d.items() == [ ('a', 1), ('b', 2) ] 51 | 52 | def keys(): 53 | d = {'a':1, 'b':2} 54 | assert d.keys() == ['a','b'] 55 | 56 | def pop(): 57 | d = {'a':1, 'b':2} 58 | x = d.pop('a') 59 | assert x == 1 60 | assert d == {'b':2} 61 | x = d.pop('c', 5) 62 | assert x == 5 63 | assert d == {'b':2} 64 | x = d.pop('c') 65 | assert x == None 66 | assert d == {'b':2} 67 | 68 | def setdefault(): 69 | d = dict([('a', 1)]) 70 | x = d.setdefault('b', 2) 71 | assert x == 2 72 | assert d == {'a':1, 'b':2} 73 | x = d.setdefault('a', 2) 74 | assert x == 1 75 | assert d == {'a':1, 'b':2} 76 | 77 | def values(): 78 | d = {'a':1, 'b':2} 79 | assert d.values() == [1, 2] 80 | 81 | def dict_len(): 82 | assert len({}) == 0 83 | assert len({'a':1, 'b':2}) == 2 84 | d1 = {} 85 | assert len(d1) == 0 86 | d2 = {'a':1, 'b':2} 87 | assert len(d2) == 2 88 | d2.pop('c') 89 | assert len(d2) == 2 90 | d2.pop('a') 91 | assert len(d2) == 1 92 | d2.pop('b') 93 | assert len(d2) == 0 94 | 95 | # test each function in turn 96 | @pytest.mark.parametrize("f", [ 97 | expr_dict, 98 | modify, 99 | clear, 100 | copy, 101 | 
get, 102 | items, 103 | keys, 104 | pop, 105 | setdefault, 106 | values, 107 | dict_len, 108 | ]) 109 | def test_succeed(f): 110 | run_success(f, ()) 111 | 112 | # Failing cases 113 | # (These functions are expected to fail elaboration to KLR) 114 | 115 | def out_of_bounds(): 116 | l = [1,2] 117 | l[3] = 0 118 | 119 | @pytest.mark.parametrize("f", [ 120 | out_of_bounds, 121 | ]) 122 | def test_fails(f): 123 | run_fail(f, ()) 124 | -------------------------------------------------------------------------------- /interop/klr/Makefile: -------------------------------------------------------------------------------- 1 | # A simple Makefile for local development 2 | 3 | PY_VER ?= 3.10 4 | PY_CFG := $(shell which python${PY_VER}-config) 5 | PY_EXT := $(shell ${PY_CFG} --extension-suffix) 6 | PY_DIR := $(shell ${PY_CFG} --prefix) 7 | #PY_CFLAGS := $(shell ${PY_CFG} --cflags) 8 | PY_CFLAGS := -I${PY_DIR}/include/python${PY_VER} 9 | PY_LDFLAGS := $(shell ${PY_CFG} --ldflags) 10 | PY_LIBS := "-lpython${PY_VER}" 11 | 12 | LEAN_CFLAGS := $(shell leanc --print-cflags) 13 | LEAN_LDFLAGS := -L$(shell lean --print-prefix)/lib $(shell leanc --print-ldflags) 14 | 15 | CFLAGS := ${LEAN_CFLAGS} ${PY_CFLAGS} 16 | LDFLAGS := ${KLR_LIBS} ${LEAN_LIBS} ${LEAN_LDFLAGS} ${PY_LDFLAGS} ${PY_LIBS} 17 | CC := clang -std=c17 -g -Wall -Wextra -Wno-unused-command-line-argument 18 | 19 | C := region.c peg_parser.c gather.c frontend.c klr_ffi.c 20 | O := $(patsubst %.c,%.${PY_VER}.o,$C) 21 | 22 | TESTS := cbor_test 23 | 24 | .SUFFIXES: 25 | .PHONY: all versions test clean 26 | 27 | all: frontend${PY_EXT} load_test 28 | 29 | versions: 30 | $(MAKE) PY_VER=3.9 31 | $(MAKE) PY_VER=3.10 32 | $(MAKE) PY_VER=3.11 33 | $(MAKE) PY_VER=3.12 34 | $(MAKE) PY_VER=3.13 35 | 36 | test: ${TESTS} 37 | 38 | clean: 39 | rm -f *.o *.so 40 | rm -f ${TESTS} 41 | rm -f load_test 42 | 43 | peg_parser.${PY_VER}.o: $(wildcard *.c) 44 | 45 | %.${PY_VER}.o: %.c $(wildcard *.h) 46 | ${CC} ${CFLAGS} -c $< -o $@ 47 | 48 | # Lean 
dependencies 49 | # build up KLR_A list with targets back to lake 50 | 51 | D := ../../.lake 52 | LAKE := lake -d ../.. 53 | P := ${D}/packages 54 | 55 | # Main KLR library 56 | A := ${D}/build/lib/libKLR.a 57 | KLR_A := $A 58 | $A: 59 | ${LAKE} build 60 | 61 | # TensorLib 62 | A := ${P}/TensorLib/.lake/build/lib/libTensorLib.a 63 | KLR_A += $A 64 | $A: 65 | ${LAKE} build TensorLib/TensorLib:static 66 | 67 | # Plausible 68 | A := ${P}/plausible/.lake/build/lib/libPlausible.a 69 | KLR_A += $A 70 | $A: 71 | ${LAKE} build plausible/Plausible:static 72 | 73 | # Aesop 74 | A := ${P}/aesop/.lake/build/lib/libAesop.a 75 | KLR_A += $A 76 | $A: 77 | ${LAKE} build aesop/Aesop:static 78 | 79 | # Batteries 80 | A := ${P}/batteries/.lake/build/lib/libBatteries.a 81 | KLR_A += $A 82 | $A: 83 | ${LAKE} build batteries/Batteries:static 84 | 85 | # for debugging 86 | deps: ${KLR_A} 87 | @echo deps ${KLR_A} 88 | 89 | # Python shared library 90 | frontend${PY_EXT}: $(O) $(KLR_A) 91 | ${CC} $^ -dynamiclib -o $@ ${LDFLAGS} 92 | 93 | # Basic test oc CBOR utilities 94 | cbor_test: cbor_test.c 95 | clang -std=c17 ${PY_CFLAGS} $< -o $@ 96 | ./$@ 97 | 98 | # Simulate production build 99 | load_test: *.cpp *.hpp region.c cbor.c 100 | gcc -O0 -g -std=c17 -Wall -c region.c cbor.c 101 | g++ -O0 -g --std=c++17 -Wall -o load_test load_test.cpp klir_serde.cpp klir_common.cpp cbor.o region.o 102 | -------------------------------------------------------------------------------- /interop/test/examples/average_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024, Amazon.com. All Rights Reserved 3 | 4 | NKI implementation for average pool 2D NKI tutorial. 
5 | 6 | """ 7 | from apis import * 8 | 9 | def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size): 10 | """NKI kernel to compute a 2D avg-pool operation 11 | 12 | Args: 13 | in_tensor: an input tensor, of shape C x H x W 14 | pool_size: an integer representing a (square) pool-window size 15 | out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size) 16 | """ 17 | 18 | # Get input/output dimensions 19 | sz_cin, sz_hin, sz_win = in_tensor.shape 20 | sz_cout, sz_hout, sz_wout = out_tensor.shape 21 | assert sz_cin == sz_cout 22 | 23 | # Set relevant sizes 24 | sz_p = sz_cin 25 | sz_pool = pool_size 26 | 27 | # Generate tensor h/w index patterns 28 | # 3D indexing according to [C, H, W] 29 | i_p = nl.arange(sz_p)[:, None, None] # 3D for 30 | i_win = nl.arange(sz_win)[None, None, :] 31 | i_hin = nl.arange(sz_hin)[None, :, None] 32 | 33 | i_wout = nl.arange(sz_wout)[None, None, :] 34 | i_hout = nl.arange(sz_hout)[None, :, None] 35 | 36 | # Generate pool index patterns (requires two extra dimensions, for the pool window) 37 | i_0 = nl.arange(sz_p)[:, None, None, None, None] # 38 | i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer 39 | i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner 40 | i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer 41 | i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner 42 | 43 | # Load input data from external memory to on-chip memory 44 | # Declare ndarray to force a 3D tensor (temporary requirement) 45 | in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype) 46 | in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win]) 47 | 48 | # Perform the pooling operation: 49 | # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-average two dimension. 50 | # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation. 
51 | # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides 52 | # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2]. 53 | # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4]. 54 | out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size) 55 | 56 | # Store the results back to external memory 57 | nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile) 58 | 59 | 60 | # Reference NumPy implementation 61 | def np_average_pool_2D(in_tensor, pool_size): 62 | c, h_in, w_in = in_tensor.shape 63 | reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size) 64 | return np.nanmean(reshaped, axis=(2, 4)) 65 | -------------------------------------------------------------------------------- /KLR/Semantics/Tactics.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 
4 | Authors: Markus de Medeiros 5 | -/ 6 | 7 | import KLR.Semantics.Lib 8 | import KLR.Semantics.NML 9 | import KLR.Semantics.Logic 10 | import KLR.Semantics.SmallStep 11 | import KLR.Semantics.ProofRules 12 | 13 | 14 | open Iris.BI.BIBase KLR.Core Iris NML Iris.BI 15 | -- Proof-mode helper tactics; each one applies a weakest-precondition rule through `Entails.trans`. 16 | macro "wp_sync_pure " t1:term ", " t2:term : tactic => 17 | `(tactic| refine Entails.trans ?_ <| wpPureSync $t1 $t2 (by simp)) 18 | 19 | macro "wp_sync_val" : tactic => 20 | `(tactic| refine Entails.trans ?_ <| wpValVal (by rfl) (by rfl)) 21 | 22 | macro "wp_desync" : tactic => 23 | `(tactic| refine Entails.trans ?_ <| wand_entails <| wpDesync) 24 | 25 | macro "wp_resync" : tactic => 26 | `(tactic| refine Entails.trans ?_ <| wand_entails <| wpResync) 27 | 28 | -- macro "dwp_left_pure " t:term : tactic => 29 | -- `(tactic| apply Entails.trans ?_ <| wand_entails <| dwpPureL $t (Hx := by simp)) 30 | -- 31 | -- macro "dwp_right_pure " t:term : tactic => 32 | -- `(tactic| apply Entails.trans ?_ <| wand_entails <| dwpPureR $t (Hx := by simp)) 33 | /-- Reshape `P ⊢ Q` into `P ⊢ emp ∗ (emp -∗ Q)`: both the frame and the wand premise are `emp`. -/ 34 | theorem tac_uwp_elim_triv_both {P Q : @PROP DataT} (H : P ⊢ Q) : P ⊢ emp ∗ (emp -∗ Q) := by 35 | apply H.trans 36 | iintro H 37 | isplit r 38 | · exact fun n x a a => a 39 | iintro - 40 | iexact H 41 | /-- Reshape `P ⊢ R ∗ Q` into `P ⊢ R ∗ (emp -∗ Q)`: the wand premise is `emp`. -/ 42 | theorem tac_uwp_elim_triv_pre {P Q R : @PROP DataT} (H : P ⊢ R ∗ Q) : P ⊢ R ∗ (emp -∗ Q) := by 43 | apply H.trans 44 | iintro ⟨HR, HQ⟩ 45 | isplit l [HR] 46 | · iexact HR 47 | iintro - 48 | iexact HQ 49 | /-- Reshape `P ⊢ R -∗ Q` into `P ⊢ emp ∗ (R -∗ Q)`: the frame is `emp` (`R` is auto-bound; it is not listed among the binders). -/ 50 | theorem tac_uwp_elim_triv_post {P Q : @PROP DataT} (H : P ⊢ R -∗ Q) : P ⊢ emp ∗ (R -∗ Q) := by 51 | apply H.trans 52 | iintro H 53 | isplit r [H] 54 | · exact fun n x a a => a 55 | iexact H 56 | -- `uwp_left u`: apply `dwpL u` via `Entails.trans`, then simp, try the `tac_uwp_elim_triv_*` lemmas to drop trivial `emp` frames/premises, and enter proof mode with `istart`. NOTE(review): `try` scopes over the whole remaining `<;>` chain, so if `simp` fails the later steps (including `istart`) are skipped; confirm this is intended. 57 | macro "uwp_left " u:term : tactic => `(tactic| 58 | apply Entails.trans ?_ (dwpL $u ?_) <;> 59 | try simp <;> 60 | (first 61 | | apply tac_uwp_elim_triv_both 62 | | apply tac_uwp_elim_triv_pre 63 | | apply tac_uwp_elim_triv_post 64 | | skip) <;> 65 | istart) 66 | -- `uwp_right u`: same steps as `uwp_left`, but applies `dwpR u` with its side condition discharged by `simp`. 67 | macro "uwp_right " u:term : tactic => `(tactic| 68 | apply Entails.trans ?_
(dwpR $u (by simp)) <;> 69 | try simp <;> 70 | (first 71 | | apply tac_uwp_elim_triv_both 72 | | apply tac_uwp_elim_triv_pre 73 | | apply tac_uwp_elim_triv_post 74 | | skip) <;> 75 | istart) 76 | -- `dwp_left u`: reduce the goal with the entailment `u` via `Entails.trans`, simp, try the `tac_uwp_elim_triv_*` lemmas, then `istart`. 77 | macro "dwp_left " u:term : tactic => `(tactic| 78 | apply Entails.trans ?_ $u <;> 79 | simp <;> 80 | (first 81 | | apply tac_uwp_elim_triv_both 82 | | apply tac_uwp_elim_triv_pre 83 | | apply tac_uwp_elim_triv_post 84 | | skip) <;> 85 | istart) 86 | -- `dwp_right u`: currently expands to exactly the same tactic script as `dwp_left`. 87 | macro "dwp_right " u:term : tactic => `(tactic| 88 | apply Entails.trans ?_ $u <;> 89 | simp <;> 90 | (first 91 | | apply tac_uwp_elim_triv_both 92 | | apply tac_uwp_elim_triv_pre 93 | | apply tac_uwp_elim_triv_post 94 | | skip) <;> 95 | istart) 96 | -------------------------------------------------------------------------------- /KLR/Py/PosLemmas.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | /-!
# Lemmas used to prove progress in the tokenizer -/ 18 | 19 | set_option grind.warning false 20 | /-- A position whose byte index is at least that of `pos`. -/ 21 | abbrev PosGe (pos : String.Pos) := 22 | {endPos : String.Pos // endPos.1 ≥ pos.1} 23 | /-- A position whose byte index is strictly greater than that of `pos`. -/ 24 | abbrev PosGt (pos : String.Pos) := 25 | {endPos : String.Pos // endPos.1 > pos.1} 26 | /-- `s.next' pos h`, packaged with a proof that it lies strictly after `pos`. -/ 27 | def PosGt.next' {s : String} (pos : String.Pos) (h : ¬s.atEnd pos = true) 28 | : PosGt pos := 29 | ⟨s.next' pos h, by grind [String.pos_lt_eq, String.lt_next', String.next'_eq]⟩ 30 | /-- A position strictly after `s.next' pos h` is also strictly after `pos`. -/ 31 | def PosGt.fromNext' {s : String} {pos : String.Pos} 32 | (h : ¬s.atEnd pos = true) (next : PosGt (s.next' pos h)) : PosGt pos := 33 | ⟨next.val, by grind [String.pos_lt_eq, String.lt_next', String.next'_eq]⟩ 34 | /-- A position strictly after `pos2` is strictly after any `pos1` with `pos1 ≤ pos2`. -/ 35 | def PosGt.fromLe {pos1 pos2 : String.Pos} 36 | (h : pos1.1 ≤ pos2.1) (posGt : PosGt pos2) : PosGt pos1 := 37 | ⟨posGt.val, by grind⟩ 38 | /-- If `pos` is not at the end of `s`, its byte index is below that of `s.endPos`. -/ 39 | theorem String.lt_end {s : String} {pos : String.Pos} (h : ¬s.atEnd pos = true) 40 | : pos.1 < s.endPos.1 := by 41 | simp only [String.atEnd, ge_iff_le, decide_eq_true_eq, Nat.not_le] at h 42 | exact h 43 | /-- Advancing with `next'` strictly decreases the number of bytes left before `s.endPos` (useful as a progress/termination measure). -/ 44 | theorem String.lt_sub_next {s : String} {pos : String.Pos} (h : ¬s.atEnd pos = true) 45 | : s.endPos.1 - (s.next' pos h).1 < s.endPos.1 - pos.1 := by 46 | have hlt := String.lt_next' s pos 47 | have heq := String.next'_eq _ _ h 48 | have := String.lt_end h 49 | rw [heq] 50 | simp_all only [pos_lt_eq, next'_eq, gt_iff_lt] 51 | omega 52 | /-- `findAux` never returns a position before its starting position `pos`. -/ 53 | theorem String.findAux_le_start {s : String} {p : Char → Bool} {stopPos pos : String.Pos} 54 | : pos.1 ≤ (s.findAux p stopPos pos).1 := by 55 | rw [String.findAux] 56 | split 57 | next h => 58 | split 59 | · grind 60 | · simp only 61 | have := Nat.sub_lt_sub_left h (lt_next s pos) 62 | have := @String.findAux_le_start s p stopPos (s.next pos) 63 | grind 64 | · grind 65 | termination_by stopPos.1 - pos.1 66 | /-- Adding a non-empty string to a position strictly increases its byte index. -/ 67 | theorem String.add_gt 68 | : ∀ {pos : String.Pos} {s : String}, 69 | s.utf8ByteSize > 0 → (pos + s).byteIdx > pos.byteIdx := by 70 | intro pos s h 71 | simp_all only [Bool.not_eq_true,
gt_iff_lt, String.Pos.byteIdx_addString, Nat.lt_add_right_iff_pos] 72 | 73 | /-- 74 | Tries to solve theorems of the form `s.utf8ByteSize > n`. 75 | -/ 76 | macro "simp_str_size" : tactic => `(tactic|( 77 | rw [←String.size_toUTF8] 78 | simp [String.toUTF8, String.utf8EncodeChar, ByteArray.size] 79 | )) 80 | -------------------------------------------------------------------------------- /KLR/Trace/Lang.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | -/ 16 | 17 | import KLR.Core 18 | import KLR.Trace.ISA 19 | import KLR.Trace.Types 20 | 21 | /- 22 | NKI Language builtins 23 | -/ 24 | 25 | namespace KLR.Trace 26 | open Core 27 | 28 | nki builtin.lang.ndarray 29 | (shape : Shape) 30 | (dtype : Dtype) 31 | (buffer : Option Memory := none) 32 | (name : Option String := none) 33 | (address : Option (Nat × Nat) := none) 34 | (address_rotation : Option Bool := none) := do 35 | let memory := buffer.getD .sbuf 36 | let (parSize, freeSize) := Address.defaultSize shape dtype 37 | let (parOffset, freeOffset) := match address with 38 | | some (par, free) => (some par, some free) 39 | | none => (none, none) 40 | let name <- tensorName name 41 | let address_rotation <- match address_rotation with 42 | | some v => pure v 43 | | none => flags.address_rotation 44 | let address := { name, memory, parSize, freeSize, parOffset, freeOffset : Address } 45 | let tensor <- TensorName.make name dtype shape address address_rotation 46 | return .access (.simple tensor) 47 | 48 | nki builtin.lang.par_dim (t : Term) := do 49 | warn "par_dim is deprecated" 50 | return t 51 | 52 | nki builtin.lang.program_id (axis : Int) := do 53 | if axis != 0 then 54 | throw s!"invalid program axis {axis} (must be zero)" 55 | lookup (nl "_program_id") 56 | 57 | nki builtin.lang.num_programs (axes : Option Int := none) := do 58 | if axes.getD 0 != 0 then 59 | throw s!"invalid program axis {axes} (must be zero)" 60 | lookup (nl "_num_programs") 61 | 62 | nki builtin.lang.program_ndim := do 63 | lookup (nl "_program_ndim") 64 | 65 | nki builtin.lang.ds (start : Int) (size : Int) := do 66 | return .slice start (start + size) (some 1) 67 | 68 | nki builtin.lang.unique_name (name : String) := do 69 | let uniqueName := <- genName name.toName 70 | return .string uniqueName.toString 71 | 72 | nki builtin.lang.device_print 73 | (print_prefix : String) 74 | (tensor : Access) 75 | (output_buffer : Option PrintOutputBuffer := none) 76 | (mask: Option Immediate := 
none) := do 77 | if mask.isSome then throw "mask parameter is not supported" 78 | let buffer := output_buffer.getD .stdout 79 | Trace.add_stmt $ .oper (.devicePrint { 80 | src := .abstract tensor 81 | printPrefix := print_prefix 82 | buffer 83 | }) print_prefix 84 | return .none 85 | -------------------------------------------------------------------------------- /interop/klr/region.c: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Paul Govereau, Sean McLaughlin, Michael Graeb 5 | */ 6 | #include "region.h" 7 | #include "stdc.h" 8 | 9 | #include 10 | 11 | struct block { 12 | size_t size; 13 | size_t offset; 14 | struct block *next; 15 | u8 buf[0]; 16 | }; 17 | 18 | #define BLOCK_SIZE (8192 - sizeof(struct block)) 19 | #define LARGE_SIZE 7168 20 | 21 | static_assert(BLOCK_SIZE >= LARGE_SIZE, 22 | "BLOCK_SIZE must hold anything LARGE_SIZE or less"); 23 | 24 | struct region { 25 | struct block *blocks; 26 | struct block *large; 27 | }; 28 | 29 | static struct block *alloc_block(size_t size) { 30 | struct block *b = calloc(1, size + sizeof(*b)); 31 | if (b) { 32 | b->size = size; 33 | b->offset = 0; 34 | b->next = NULL; 35 | } 36 | return b; 37 | } 38 | 39 | static void free_blocks(struct block *b) { 40 | while (b) { 41 | struct block *tmp = b->next; 42 | free(b); 43 | b = tmp; 44 | } 45 | } 46 | 47 | struct region *region_create() { return calloc(1, sizeof(struct region)); } 48 | 49 | void region_destroy(struct region *region) { 50 | if (region) { 51 | free_blocks(region->blocks); 52 | free_blocks(region->large); 53 | free(region); 54 | } 55 | } 56 | 57 | void *region_try_alloc(struct region *region, size_t size) { 58 | if (unlikely(!region)) 59 | return NULL; 60 | 61 | // check for large block 62 | if (unlikely(size > LARGE_SIZE)) { 63 | struct block *b = 
alloc_block(size); 64 | if (unlikely(!b)) 65 | return NULL; 66 | b->next = region->large; 67 | region->large = b; 68 | return b->buf; 69 | } 70 | 71 | struct block *b = region->blocks; 72 | if (unlikely(!b || size > b->size - b->offset)) { 73 | b = alloc_block(BLOCK_SIZE); 74 | if (unlikely(!b)) 75 | return NULL; 76 | b->next = region->blocks; 77 | region->blocks = b; 78 | } 79 | 80 | void *p = b->buf + b->offset; 81 | b->offset += size; 82 | return p; 83 | } 84 | 85 | void *region_alloc(struct region *region, size_t size) { 86 | void *p = region_try_alloc(region, size); 87 | if (unlikely(!p)) { 88 | fprintf(stderr, "Out Of Memory. NKI compiler will abort the program.\n"); 89 | abort(); 90 | } 91 | 92 | return p; 93 | } 94 | 95 | char *region_strdup(struct region *region, const char *src) { 96 | assert(src); 97 | size_t size = strlen(src) + 1; 98 | char *dst = region_alloc(region, size); 99 | // copy everything, including null-terminator 100 | memcpy(dst, src, size); 101 | return dst; 102 | } 103 | 104 | char *region_strndup(struct region *region, const char *src, size_t len) { 105 | assert(src); 106 | 107 | // check for null-terminator earlier than `len` 108 | len = strnlen(src, len); 109 | 110 | char *dst = region_alloc(region, len + 1); 111 | memcpy(dst, src, len); 112 | dst[len] = 0; // add null-terminator 113 | return dst; 114 | } 115 | -------------------------------------------------------------------------------- /interop/test/examples/mm.py: -------------------------------------------------------------------------------- 1 | from apis import * 2 | 3 | def matmul_128x128x512_spmd_nisa(A_T, B, result): 4 | """NKI kernel to compute a 128x128x512 matrix multiplication operation 5 | Use SPMD program IDs to index into the full A and B input tensor to get tiles 6 | for 128x128x512 matrix multiplication. 
7 | 8 | Args: 9 | A_T: an input tensor of shape [K=128,M=512], 10 | a left hand side argument of the matrix multiplication, 11 | B: an input tensor of shape [K=128,N=1024], 12 | a right hand side argument of the matrix multiplication 13 | result: the resulting output tensor of shape [M=128,N=512] 14 | """ 15 | # Defining starting indexes for input A.T and B 16 | i_A_T_col = nl.program_id(0) * 128 17 | i_B_col = nl.program_id(1) * 512 18 | 19 | # Loading the inputs (HBM->SBUF) 20 | A_T_tile = nl.load(A_T[0:128, i_A_T_col:i_A_T_col+128]) 21 | B_tile = nl.load(B[0:128, i_B_col:i_B_col+512]) 22 | 23 | # Perform the matrix-multplication 24 | # Note1: p-dim of both input tiles is mapped to the contraction dimension, aligned 25 | # with TensorE layout requirements (LayoutConstraint #1: For MatMult, contraction 26 | # axis must be mapped to P-dim) 27 | # Note2: A NKI matmul instruction always writes to PSUM in float32 data-type 28 | result_psum = nisa.nc_matmul(A_T_tile, B_tile) 29 | 30 | # Copy the result from PSUM back to SBUF, and cast to expected output data-type 31 | result_sbuf = nl.copy(result_psum, dtype=result.dtype) 32 | 33 | # Store back into result tile with the correct SPMD offsets. 34 | nl.store(result[i_A_T_col:i_A_T_col+128, i_B_col:i_B_col+512], value=result_sbuf) 35 | 36 | 37 | def matmul_128x128x512_nl(A, B, result): 38 | """NKI kernel to compute a 128x128x512 matrix multiplication operation. 39 | Use SPMD program IDs to index into the full A and B input tensor to get tiles 40 | for 128x128x512 matrix multiplication. 
41 | 42 | Args: 43 | A: an input tensor of shape [M=512,K=128], 44 | a left hand side argument of the matrix multiplication, 45 | B: an input tensor of shape [K=128,N=1024], 46 | a right hand side argument of the matrix multiplication 47 | result: the resulting output tensor of shape [M=128,N=512] 48 | """ 49 | # Defining starting indexes for input A and B 50 | i_A_row = nl.program_id(0) * 128 51 | i_B_col = nl.program_id(1) * 512 52 | 53 | # Loading the inputs (HBM->SBUF) 54 | A_tile = nl.load(A[i_A_row:i_A_row+128, 0:128]) 55 | B_tile = nl.load(B[0:128, i_B_col:i_B_col+512]) 56 | 57 | # Perform the matrix-multiplication 58 | # Note1: nl.matmul will invoke a transpose on A_tile before performing the actual matmul operation 59 | # Note2: A NKI matmul instruction always writes to PSUM in float32 data-type 60 | result_psum = nl.matmul(A_tile, B_tile) 61 | 62 | # Copy the result from PSUM back to SBUF, and cast to expected output data-type 63 | result_sbuf = nl.copy(result_psum, dtype=result.dtype) 64 | 65 | # The result of a [128,128] x [128,512] matrix multiplication has a shape of [128, 512]. 66 | # This dictates which indices to use to address the result tile. 67 | nl.store(result[i_A_row:i_A_row+128, i_B_col:i_B_col+512], value=result_sbuf) 68 | -------------------------------------------------------------------------------- /KLR/CompileHLO.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import Cli 18 | import KLR 19 | import SHerLOC 20 | import KLR.TGR.Basic 21 | import SHerLOC.Analysis.Graph 22 | import KLR.K.K3.CompileK3 23 | import KLR.K.K2.CompileK2 24 | import KLR.K.K2.AST 25 | import KLR.K.K3.DotK3 26 | import KLR.K.K1.CompileK1 27 | import KLR.K.K1.AST 28 | import KLR.K.K1.InterpK1 29 | -- Extract the value from a compiler pass result, panicking with the pass's error message on failure. 30 | private def unwrap {T Q : Type} [Inhabited T] [Inhabited Q] 31 | (x : Except String T × Q) : T × Q := 32 | match x with 33 | | (.ok a, s) => (a, s) 34 | | (.error e, _) => panic! s!"Error: {e}" 35 | -- `compile-hlo` entry point: parse the HLO file and lower it HLO -> TGR -> K3 -> K2 -> K1. k1.txt is always written; the other stages only with `intermediates`, and `evaluate` also runs the K1 interpreter. Returns 1 if parsing fails, otherwise 0 (even if K1 evaluation fails). 36 | def compileHlo (p : Parsed) : IO UInt32 := do 37 | let file := p.positionalArg! "file" |>.as! String 38 | let s <- IO.FS.readFile file 39 | match StableHLO.Parsing.parse s with 40 | | .ok (hlo, _) => 41 | -- compile HLO to TGR 42 | let (_, s) := KLR.TGR.Compile.compile hlo |> unwrap 43 | let func := s.program.functions.head! 44 | if p.hasFlag "intermediates" then 45 | writeContent "tgr.txt" p (toString s.program) 46 | writeContent "tgr.dot" p s!"{KLR.TGR.Graph.graph func}" 47 | writeContent "py" p (KLR.TGR.Py.compile s.program) 48 | -- compile TGR to K3 49 | IO.println s!"Compiling TGR to K3" 50 | let (func, _) := KLR.K.K3.compile func |> unwrap 51 | if p.hasFlag "intermediates" then 52 | writeContent "k3.txt" p s!"{func}" 53 | writeContent "k3.dot" p (KLR.K.K3.graph func |> toString) 54 | -- compile K3 to K2 55 | IO.println s!"Compiling K3 to K2" 56 | let (func, _) := KLR.K.K2.compile func |> unwrap 57 | if p.hasFlag "intermediates" then 58 | writeContent "k2.txt" p s!"{KLR.K.K2.formatProgramK2 func}" 59 | -- compile K2 to K1 60 | IO.println s!"Compiling K2 to K1" 61 | let (func, _) := KLR.K.K1.compile func |> unwrap 62 | writeContent "k1.txt" p s!"{KLR.K.K1.formatProgramK1 func}" 63 | if p.hasFlag "evaluate" then 64 | IO.println s!"Evaluating K1" 65 | match KLR.K.K1.Interp.interp func with 66 | | (.ok (), ctx) =>
67 | writeContent "k1.result" p s!"{ctx}" 68 | | (.error e, ctx) => 69 | IO.eprintln s!"Error interpreting K1: {e}" 70 | writeContent "k1.err" p s!"{ctx}" 71 | return 0 72 | | .error e => 73 | IO.eprintln s!"Error parsing HLO: {e}" 74 | return 1 75 | -- Command-line definition for `compile-hlo`. 76 | def compileHloCmd := `[Cli| 77 | "compile-hlo" VIA compileHlo; 78 | "Compile HLO graph to KLR graph" 79 | 80 | FLAGS: 81 | o, outfile : String; "Name of output file" 82 | i, intermediates : Unit; "Write intermediate files" 83 | e, evaluate : Unit; "Evaluate the K1 program on random inputs as a smoke test" 84 | ARGS: 85 | file : String; "File of HLO graph in .mlir format" 86 | ] 87 | -------------------------------------------------------------------------------- /KLR/Extract/Extract/Python.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import Extract.Basic 18 | import KLR.Python 19 | import Lean 20 | 21 | /- 22 | Output functions for Python 23 | -/ 24 | 25 | namespace Extract.Python 26 | open Lean Meta 27 | -- Python identifier for a `.str` name: its last component with apostrophes replaced by '_', post-processed by `f`; panics on any other kind of name. 28 | private def PyName (name : Name) (f : String -> String := id) : String := 29 | match name with 30 | | .str _ s => f (s.replace "'" "_") 31 | | _ => panic!
"found invalid name" 32 | -- NOTE(review): this instance is not `local`, so it may take precedence over Lean's own `ToString Name` in modules that import this file, and it panics on non-`.str` names. 33 | instance : ToString Name where toString n := PyName n 34 | -- Python type annotation for a simple field type; `const`/`enum` names become quoted forward references. 35 | private def genType : SimpleType -> String 36 | | .bool => "bool" 37 | | .nat | .int => "int" 38 | | .float => "float" 39 | | .string => "str" 40 | | .prop => panic! "TODO" 41 | | .const name 42 | | .enum name => s!"\"{PyName name}\"" 43 | | .option .string => "str" 44 | | .option t => s!"Optional[{genType t}]" 45 | | .list .string => "list[str]" 46 | | .list t => s!"list[{genType t}]" 47 | | .pair .. => panic! "TODO" 48 | -- Append a '_' separator unless `s` is empty or already ends with one. 49 | private def under (s : String) : String := 50 | if s == "" || s.endsWith "_" then s 51 | else s ++ "_" 52 | -- Print the Python declaration(s) for `ty`: a NamedTuple class for a product (name prefixed via `under pre`), an Enum plus one accessor function per variant for an enumeration, and otherwise the declarations of each variant followed by a union alias (variants listed in reverse order). 53 | private def genPyType (ty : LeanType) (pre : String := "") : MetaM Unit := 54 | match ty with 55 | | .simple _ => pure () 56 | | .prod name fields => do 57 | IO.println "" 58 | IO.println s!"class {under pre}{name}(NamedTuple):" 59 | if fields.length == 0 then 60 | IO.println " pass" 61 | for f in fields do 62 | IO.println s!" {f.name} : {genType f.type}" 63 | | .sum name variants => do 64 | if ty.isEnum then do 65 | IO.println "" 66 | IO.println s!"class {name}(Enum):" 67 | for v in variants do 68 | IO.println s!"
{(toString v.name).capitalize} = auto()" 69 | IO.println "" 70 | for v in variants do 71 | IO.println s!"def {name}_{v.name}(): return {name}.{(toString v.name).capitalize}" 72 | else do 73 | let mut tys := [] 74 | for v in variants do 75 | genPyType v (PyName name) 76 | tys := (under (PyName name) ++ PyName v.name) :: tys 77 | let rhs := String.intercalate " | " tys 78 | IO.println "" 79 | IO.println s!"{name} = {rhs}" 80 | return () 81 | -- Preamble printed at the top of each generated Python file. 82 | private def header := 83 | "# This file is automatically generated from KLR 84 | # Manual edits will be overwritten 85 | 86 | from typing import NamedTuple, Optional 87 | from enum import Enum, auto" 88 | -- Print the Python AST module: the header, then declarations for `KLR.Core.Pos` and every type in `pythonAST`. 89 | def generatePythonAST : MetaM Unit := do 90 | IO.println header 91 | let tys <- pythonAST 92 | let pty <- collectLeanType `KLR.Core.Pos 93 | for t in pty :: tys do 94 | genPyType t 95 | return () 96 | -- Same as `generatePythonAST`, but for the types in `nkiAST`. 97 | def generateNkiAST : MetaM Unit := do 98 | IO.println header 99 | let tys <- nkiAST 100 | let pty <- collectLeanType `KLR.Core.Pos 101 | for t in pty :: tys do 102 | genPyType t 103 | return () 104 | -------------------------------------------------------------------------------- /KLR/Serde/File.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | -/ 16 | 17 | import KLR.Util 18 | import KLR.Serde.Attr 19 | import KLR.Serde.Basic 20 | import KLR.Serde.Elab 21 | 22 | namespace KLR.Serde 23 | 24 | /- 25 | We would like all of our CBOR encoded files to be easily identifiable as KLR 26 | files, both for our tools and for third-party tools that understand CBOR. The 27 | CBOR specification defines a special tag value (55799), which indicates that 28 | what follows is a CBOR encoded value. This special value is meant to help 29 | decoders identify CBOR data streams. When this value is encoded as a CBOR 30 | tagged value, it will be three bytes, the 16-bit tag indicator (0xd9), followed 31 | by the big-endian representation of 55799 (0xd9 0xf7): 32 | 33 | 0xd9 0xd9 0xf7 , 34 | 35 | and this three byte sequence can be used to identify a CBOR file. Quoting the 36 | specification: 37 | 38 | The serialization of this tag's head is 0xd9d9f7, which does not appear to be 39 | in use as a distinguishing mark for any frequently used file types. In 40 | particular, 0xd9d9f7 is not a valid start of a Unicode text in any Unicode 41 | encoding if it is followed by a valid CBOR data item. 42 | 43 | We leverage this by specifically marking our `KLRFile` structure with a tag of 44 | 0xd9, and changing the `mk` constructor to have tag 0xf7. When this is 45 | converted to a 16-bit tag by our `cborTag` function (see Serde.Basic), we end 46 | up with 0xd9d9f7, the "self-described CBOR" marker. 47 | 48 | The `KLRFile` structure contains a semantic version number, and nothing else so 49 | that tools unaware of our format can read the header. Hence, every KLR file 50 | starts with: 51 | 52 | 0xd9 0xd9 0xf7 0x83 major minor patch 53 | | | | | 54 | | | | +-- list of length 3 55 | | +----+-- 55799 tag value (self-described CBOR) 56 | +-- CBOR code for 16-bit tagged value 57 | 58 | This is all handled automatically by our derived instances. 
59 | -/ 60 | /-- Versioned header at the start of every KLR file; its CBOR byte layout is described in the comment above. -/ 61 | @[serde tag = 0xd9] 62 | structure KLRFile where 63 | major : Nat := 0 -- TODO come up with a way to manage versions 64 | minor : Nat := 0 65 | patch : Nat := 12 66 | deriving BEq, Repr 67 | -- Tagging the constructor with 0xf7 makes the header encode with the self-described CBOR tag 0xd9d9f7 (see above). 68 | attribute [serde tag = 0xf7] KLRFile.mk 69 | 70 | deriving instance ToCBOR for KLRFile 71 | deriving instance FromCBOR for KLRFile 72 | -- The encoding starts with the self-described CBOR marker followed by a 3-element list, and it round-trips. 73 | #guard (toCBOR { : KLRFile}).take 4 == .mk #[0xd9, 0xd9, 0xf7, 0x83] 74 | #guard (fromCBOR (toCBOR { : KLRFile }) : Err KLRFile) == .ok { : KLRFile } 75 | 76 | /- 77 | Immediately following the KLRFile structure we place a KLRMetaData structure. 78 | We are careful to choose a tag value (0xeb00) that is unassigned according to: 79 | 80 | https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml 81 | 82 | Because everything is contained within this tagged value, we do not have to 83 | worry about avoiding assigned tags on our other (internal) types. 84 | -/ 85 | /-- Metadata that immediately follows the `KLRFile` header. -/ 86 | -- TODO what do we need here? 87 | @[serde tag = 0xeb] 88 | structure KLRMetaData where 89 | format : String := "NKI" 90 | deriving BEq, Repr, ToCBOR, FromCBOR 91 | -------------------------------------------------------------------------------- /KLR/NKI/Typed/Types.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | -/ 16 | 17 | import KLR.NKI.Typed.Basic 18 | import KLR.NKI.Typed.Context 19 | import KLR.Py 20 | 21 | namespace KLR.NKI.Typed 22 | 23 | open KLR.Py (Span Ident FileInfo) 24 | /-- State threaded through elaboration: current span, fresh-name counter, warnings, current return type, and an in-loop flag. -/ 25 | structure State where 26 | span : Option Span := none 27 | nameCount : Nat := 0 28 | warnings : List String := [] 29 | retTyp : Option Typ := none 30 | inLoop : Bool := false 31 | /-- Elaboration monad: reads the `FileInfo` used to format error messages and fails with a formatted `String`. -/ 32 | abbrev ElabM := ReaderT FileInfo <| EStateM String State 33 | 34 | /-! # -----------Start: Error Reporting-------------- -/ 35 | /-- Fail with `msg`, formatted by `FileInfo.formatError` with prefix `pre` at `span`. -/ 36 | def throwAt {α} (span : Span) (pre : String) (msg : String) : ElabM α := do 37 | let msg := (← read).formatError pre msg span 38 | MonadExcept.throw msg 39 | /-- Fail at the current span, or at a default span when none is set. -/ 40 | def throw {α} (pre : String) (msg : String) : ElabM α := do 41 | let pos := (← get).span.getD {} 42 | throwAt pos pre msg 43 | /-- Fail with an `InternalError`. -/ 44 | def throwInternal {α} (msg : String) : ElabM α := 45 | throw "InternalError" msg 46 | /-- Fail with a `SyntaxError`. -/ 47 | def throwSyntax {α} (msg : String) : ElabM α := 48 | throw "SyntaxError" msg 49 | /-- Fail with a `TypeError`. -/ 50 | def throwType {α} (msg : String) : ElabM α := 51 | throw "TypeError" msg 52 | /-- Fail with a `TypeError` about an empty statement list. -/ 53 | def throwEmptyStmts {α} : ElabM α := 54 | throwType "cannot have empty statements here" 55 | /-- Fail with an `Unsupported` error. -/ 56 | def throwUnsupported {α} (msg : String) : ElabM α := 57 | throw "Unsupported" msg 58 | 59 | /-! # -----------End: Error Reporting-------------- -/ 60 | 61 | /-! # -----------Start: Getters/Setters-------------- -/ 62 | /-- The current span; fails with an internal error if none is set. -/ 63 | def getSpan!
: ElabM Span := do 64 | let some span := (← get).span 65 | | throwInternal "missing position information" 66 | return span 67 | /-- The return type set by the innermost `withRetTyp`, if any. -/ 68 | def getRetTyp : ElabM (Option Typ) := 69 | return (← get).retTyp 70 | /-- Whether elaboration is currently inside a loop. -/ 71 | def inLoop : ElabM Bool := 72 | return (← get).inLoop 73 | /-- Set the in-loop flag; it is not reset automatically (see `withLoop`). -/ 74 | def enterLoop : ElabM Unit := 75 | modifyGet fun s => 76 | ((), {s with inLoop := true}) 77 | /-- Run `x` with the in-loop flag set, then restore the previous flag. NOTE(review): the flag is not restored if `x` throws. -/ 78 | def withLoop {α} (x : ElabM α) : ElabM α := do 79 | let saved ← inLoop 80 | enterLoop 81 | let a ← x 82 | modifyGet fun s => 83 | (a, {s with inLoop := saved}) 84 | /-- Run `x` with `retTyp` as the current return type, then restore the previous one (not restored if `x` throws). -/ 85 | def withRetTyp {α} (retTyp : Typ) (x : ElabM α) : ElabM α := do 86 | let s ← get 87 | let saved := s.retTyp 88 | set {s with retTyp := retTyp} 89 | let a ← x 90 | modifyGet fun s => 91 | (a, {s with retTyp := saved}) 92 | /-- Run `x` with `span` as the current span, then restore the previous one (not restored if `x` throws). -/ 93 | def withSpan {α} (span : Span) (x : ElabM α) : ElabM α := do 94 | let s ← get 95 | let savedSpan := s.span 96 | set {s with span := span} 97 | let a ← x 98 | modifyGet fun s => 99 | (a, {s with span := savedSpan}) 100 | /-- Rewrite the stored return type, if any, with `Γ.apply`. -/ 101 | def applyCtxToRet (Γ : Context) : ElabM Unit := 102 | modifyGet fun s => 103 | ((), {s with retTyp := s.retTyp.map Γ.apply}) 104 | 105 | /-! # -----------End: Getters/Setters-------------- -/ 106 | 107 | /-! # -----------Start: Name Generation-------------- -/ 108 | /-- A fresh name of the form `'T<n>`, incrementing the counter. -/ 109 | def freshName : ElabM Ident := 110 | modifyGet fun s => 111 | (s!"\'T{s.nameCount}", {s with nameCount := s.nameCount + 1}) 112 | 113 | /-! # -----------End: Name Generation-------------- -/ 114 | -------------------------------------------------------------------------------- /KLR/Util/BitVec.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import KLR.Util.Common 18 | import Plausible 19 | 20 | namespace KLR.Util.BitVec 21 | 22 | open Plausible(Gen) 23 | 24 | /-- Convert a character's code point to a BitVec 8 (reduced modulo 256, so only meaningful for ASCII) -/ 25 | private def charToBitVec8 (c : Char) : BitVec 8 := 26 | BitVec.ofNat 8 c.toNat 27 | 28 | private def nul : Char := Char.ofNat 0 29 | 30 | /-- Convert a BitVec 8 to a character -/ 31 | private def bitVec8ToChar (bv : BitVec 8) : Char := 32 | Char.ofNat bv.toNat 33 | 34 | /-- Convert a list of BitVec 8 values to a string, stopping at the first null char -/ 35 | private def bitVecListToString (bvList : List (BitVec 8)) : String := 36 | String.mk $ (bvList.map bitVec8ToChar).takeWhile fun c => c != nul 37 | 38 | private def isAsciiChar (c : Char) : Bool := c.toNat < 128 39 | 40 | private def isAscii (s : String) : Bool := s.data.all isAsciiChar 41 | 42 | /-- Convert a string to a single BitVec of appropriate size in little-endian order -/ 43 | def asciiStringToBitVec (n : Nat) (s : String) : Err (BitVec n) := do 44 | if n % 8 != 0 then throw "size should be divisible by 8" else 45 | if !isAscii s then throw "not an ascii string" else 46 | if n < s.length * 8 then throw "string is too long for storage" else 47 | let mut bytes : BitVec n := 0 48 | let mut i := 0 49 | for byte in s.toList.map charToBitVec8 do 50 | let shiftAmount := i * 8 51 | let shiftedB := BitVec.zeroExtend n byte 52 | let shiftedVal := shiftedB <<< shiftAmount 53 | i := i + 1 54 | bytes := bytes ||| shiftedVal 55 | return bytes 56 | 57 | /-- Convert a BitVec back to a
string, assuming little-endian byte order; decoding stops at the first zero byte -/ 58 | def bitVecToAsciiString {n : Nat} (bv : BitVec n) : Err String := do 59 | if n % 8 != 0 then throw "size should be divisible by 8" else 60 | let mut chars := [] 61 | for i in [0:n/8] do 62 | let shiftAmount := i * 8 63 | let mask := BitVec.zeroExtend n (BitVec.ofNat 8 0xFF) 64 | let shiftedMask := mask <<< shiftAmount 65 | let extractedByte := (bv &&& shiftedMask) >>> shiftAmount 66 | chars := BitVec.truncate 8 extractedByte :: chars 67 | return bitVecListToString chars.reverse 68 | -- Random ASCII generators; `Gen.choose` bounds are inclusive, so the upper bound is 127 (the last ASCII code point). 69 | private def smallNatGen : Gen Nat := Plausible.Gen.choose Nat 0 127 (by omega) 70 | private def asciiCharGen : Gen Char := smallNatGen.bind fun n => return Char.ofNat n 71 | private def asciiStringGen : Gen String := do 72 | let l <- Plausible.Gen.listOf asciiCharGen 73 | return String.mk l 74 | -- Encode `s` into an `n`-bit vector and decode it back. 75 | private def roundTrip (n : Nat) (s : String) : Err String := do 76 | bitVecToAsciiString (<- asciiStringToBitVec n s) 77 | 78 | #guard roundTrip 32 "sean" == .ok "sean" 79 | #guard roundTrip 48 "sean" == .ok "sean" -- we truncate 0s 80 | #guard roundTrip 24 "sean" == .error "string is too long for storage" 81 | #guard roundTrip 25 "sean" == .error "size should be divisible by 8" 82 | #guard roundTrip 24 "🏎️" == .error "not an ascii string" 83 | -- We 0-pad to fill the gaps 84 | #guard get! $ do 85 | let lhs <- asciiStringToBitVec 48 "sean" 86 | let rhs <- asciiStringToBitVec 32 "sean" 87 | let rhs := BitVec.zeroExtend 48 rhs 88 | return lhs == rhs 89 | 90 | end KLR.Util.BitVec 91 | -------------------------------------------------------------------------------- /KLR/Util/Hex.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import Plausible 18 | 19 | namespace KLR.Util.Hex 20 | /-- Lowercase hexadecimal encoding of `input`: two characters per byte, high nibble first. -/ 21 | def encode (input : ByteArray) : String := Id.run do 22 | let hexChars := "0123456789abcdef".toList 23 | let mut result := "" 24 | for b in input do 25 | let hi := (b >>> 4).toNat 26 | let lo := (b &&& 0xF).toNat 27 | result := result.push (hexChars[hi]!) 28 | result := result.push (hexChars[lo]!) 29 | return result 30 | /-- Value of a single hex digit (either case), or `none` for any other character. -/ 31 | private def hexCharToNibble (c : Char) : Option UInt8 := 32 | let u10 : UInt8 := 10 33 | if '0' ≤ c && c ≤ '9' then c.toUInt8 - '0'.toUInt8 34 | else if 'a' ≤ c && c ≤ 'f' then c.toUInt8 - 'a'.toUInt8 + u10 35 | else if 'A' ≤ c && c ≤ 'F' then c.toUInt8 - 'A'.toUInt8 + u10 36 | else none 37 | /-- Combine a high and a low hex digit into one byte. -/ 38 | private def hexCharToUInt8 (high : Char) (low : Char) : Option UInt8 := do 39 | let highNibble ← hexCharToNibble high 40 | let lowNibble ← hexCharToNibble low 41 | return (highNibble <<< 4) + lowNibble 42 | /-- Parse eight hex digits as a big-endian 32-bit value (`c0 c1` is the most significant byte). -/ 43 | def hexCharsToBitVecBE (c0 c1 c2 c3 c4 c5 c6 c7: Char) : Option (BitVec 32) := do 44 | let b0 := (← hexCharToUInt8 c0 c1).toBitVec 45 | let b1 := (← hexCharToUInt8 c2 c3).toBitVec 46 | let b2 := (← hexCharToUInt8 c4 c5).toBitVec 47 | let b3 := (← hexCharToUInt8 c6 c7).toBitVec 48 | pure (b0 ++ b1 ++ b2 ++ b3) 49 | /-- Decode a hex string into bytes; odd-length input is treated as if it had a leading '0', and any non-hex character yields `none`. -/ 50 | def decode (s : String) : Option ByteArray := Id.run do 51 | let rec split : List Char -> List (Char × Char) 52 | | [] | [_] => [] 53 | | x0 :: x1 :: xs => (x0, x1) :: split xs 54 | let s := if s.length % 2 == 0 then s else "0" ++ s 55 | let mut buf := ByteArray.emptyWithCapacity (s.length / 2) 56 | for (hi, lo) in split s.data
do 57 | match hexCharToUInt8 hi lo with 58 | | none => return none 59 | | some n => buf := buf.push n 60 | return buf 61 | /-- Hex-encode the UTF-8 bytes of `input`. -/ 62 | def encodeString (input : String) : String := encode input.toUTF8 63 | /-- Inverse of `encodeString`: `none` if `input` is not valid hex or the decoded bytes are not valid UTF-8. -/ 64 | def decodeString (input : String) : Option String := 65 | match decode input with 66 | | none => none 67 | | some bytes => String.fromUTF8? bytes 68 | 69 | section Test 70 | 71 | open Plausible 72 | 73 | private local instance : Repr ByteArray where 74 | reprPrec arr n := reprPrec arr.data n 75 | 76 | private local instance : BEq ByteArray where 77 | beq (x y : ByteArray) := x.data == y.data 78 | 79 | private local instance : Shrinkable ByteArray where 80 | 81 | private local instance : SampleableExt ByteArray := 82 | SampleableExt.mkSelfContained do 83 | let data ← SampleableExt.interpSample (Array UInt8) 84 | return ByteArray.mk data 85 | 86 | #guard encodeString "plausible" == "706c61757369626c65" 87 | 88 | #guard 89 | let b := ByteArray.mk #[1] 90 | decode (encode b) == some b 91 | 92 | #guard 93 | let s := "klr-is-the-best" 94 | let e := encodeString s 95 | let d := decodeString e 96 | d == some s 97 | 98 | /-- 99 | info: Unable to find a counter-example 100 | --- 101 | warning: declaration uses 'sorry' 102 | -/ 103 | #guard_msgs in 104 | example (arr : ByteArray) : 105 | let s := encode arr 106 | let v := decode s 107 | some arr == v := by plausible 108 | 109 | end Test 110 | 111 | end KLR.Util.Hex 112 | -------------------------------------------------------------------------------- /KLR/Util/Float.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-/

import KLR.Util.Common
import KLR.Util.Plausible

/-
Utilities for the builtin Float types
-/

namespace KLR.Util

/-
A simple implementation of String -> Float

Note: Could easily add support for e.g. 3e5 since ofScientific supports this.
-/

-- Strip an optional leading '+' or '-'. The flag is `true` when the number is
-- non-negative (no sign, or an explicit '+').
private def sign : List Char -> (Bool × List Char)
  | '+' :: s => (true, s)
  | '-' :: s => (false, s)
  | s => (true, s)

-- Consume a single leading '.', reporting whether one was present.
private def dot : List Char -> (Bool × List Char)
  | '.' :: s => (true, s)
  | s => (false, s)

-- Note: String.toNat? would require we find the dot and split the string,
-- so just do the whole thing here in one spot
-- Read a (possibly empty) run of decimal digits, accumulating into `acc`.
-- Returns the value and the unconsumed suffix, which starts at the first
-- non-digit character.
private def nat (acc : Nat) : List Char -> (Nat × List Char)
  | [] => (acc, [])
  | ch :: s => Id.run do
    if ch < '0' || ch > '9' then
      return (acc, ch :: s)
    let n := ch.val - '0'.val
    let acc := acc * 10 + n.toNat
    nat acc s

-- Split `[+-]digits[.digits]` into (nonNegative, mantissa, fractionDigits), so
-- that the magnitude is `mantissa * 10^(-fractionDigits)`. For example "-12.50"
-- gives (false, 1250, 2).
-- Parsing stops silently at the first character that does not fit this shape.
-- Any remaining characters (e.g. an exponent "e5") are discarded.
private def floatParts (cs : List Char) : (Bool × Nat × Nat) :=
  let (sgn, cs) := sign cs
  let (a, cs) := nat 0 cs
  let (haveDot, cs) := dot cs
  if haveDot then
    let (b, cs') := nat 0 cs
    -- number of fraction digits actually consumed by `nat`
    let d := cs.length - cs'.length
    let m := a * 10^d + b
    (sgn, m, d)
  else
    (sgn, a, 0)

-- Parse a decimal string into a `Float`.
-- The exact strings "NaN", "inf" and "-inf" map to the corresponding special
-- values. Everything else goes through `floatParts`, so malformed input never
-- fails: "" gives 0.0 and "3e5" gives 3.0.
def parseFloat (s : String) : Float :=
  match s with
  | "NaN" => 0.0 / 0.0
  | "inf" => 1.0 / 0.0
  | "-inf" => -1.0 / 0.0
  | _ =>
    let (sgn, m, d) := floatParts s.toList
    let f := Float.ofScientific m true d
    if sgn then f else -f

-- I tried to extract the logic between parseFloat and parseFloat32 but
-- the general form is more painful than the copying of 5 lines of code.
/--
`Float32` counterpart of `parseFloat`. It uses the same special strings ("NaN",
"inf", "-inf") and the same `[+-]digits[.digits]` grammar, but builds the value
with `Float32.ofScientific`.
-/
def parseFloat32 (s : String) : Float32 :=
  match s with
  | "NaN" => 0.0 / 0.0
  | "inf" => 1.0 / 0.0
  | "-inf" => -1.0 / 0.0
  | _ =>
    let (sgn, m, d) := floatParts s.toList
    let f := Float32.ofScientific m true d
    if sgn then f else -f

-- Check that `parseFloat` inverts `Float.toString`:
-- * NaN must parse back to NaN.
-- * Infinities must compare equal.
-- * Finite values must differ by less than 0.000001 (absolute).
private def roundTrip (f : Float) : Bool :=
  let f' := parseFloat f.toString
  if f.isNaN then f'.isNaN
  else if f.isInf then f == f'
  else (f - f').abs < 0.000001

#guard roundTrip (0.0 / 0.0)
#guard roundTrip (1.0 / 0.0)
#guard roundTrip (-1.0 / 0.0)
#guard roundTrip 0.0
#guard roundTrip (-0.0)
#guard roundTrip (-34.55)
#guard roundTrip 3.1415926

-- Same round-trip check for `parseFloat32` / `Float32.toString`.
private def roundTrip32 (f : Float32) : Bool :=
  let f' := parseFloat32 f.toString
  if f.isNaN then f'.isNaN
  else if f.isInf then f == f'
  else (f - f').abs < 0.000001

#guard roundTrip32 (0.0 / 0.0)
#guard roundTrip32 (1.0 / 0.0)
#guard roundTrip32 (-1.0 / 0.0)
#guard roundTrip32 0.0
#guard roundTrip32 (-0.0)
#guard roundTrip32 (-34.55)
#guard roundTrip32 3.1415926

-- Property test: `roundTrip` holds for randomly generated `Float`s.
-- Only the 64-bit parser is property-tested here.
/--
info: Unable to find a counter-example
---
warning: declaration uses 'sorry'
-/
#guard_msgs in
example (f : Float) : roundTrip f := by plausible

--------------------------------------------------------------------------------
/KLR/Trace.lean:
--------------------------------------------------------------------------------
/-
Copyright KLR Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-/

import KLR.Core
import KLR.Trace.Builtin
import KLR.Trace.NKI
import KLR.Trace.Python
import KLR.Trace.Term
import KLR.Trace.Types

namespace KLR.Trace
open Compile.Pass (PassM)

-- Keywords recognized by the tracer (KLR keywords)
-- Limits come from:
-- https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/nki_arch_guides.html
-- Each entry binds a memory name (`hbm`, `sbuf`, `psum`) to a `Term.pointer`
-- carrying that memory's `parSize` / `freeSize` limits. The `hbm` values are
-- placeholders (see the TODO).
def keywords : List (Name × Term) :=
  let ptr name memory parSize freeSize :=
    (name, Term.pointer { name := name.toString, memory, parSize, freeSize })
  [ ptr `hbm .hbm 0xffffffff 0xffffffff -- TODO: size of HBM?
  , ptr `sbuf .sbuf 128 0x30000
  , ptr `psum .psum 128 0x4000
  ]

-- Global tracer environment: the memory keywords above, then the builtin,
-- Python and NKI environments (in that order).
def globalEnv := keywords ++ builtinEnv ++ pythonEnv ++ NKIEnv

-- Architecture-dependent integer constants exposed to kernels: `arch` itself
-- plus the `tile_size` limits, with a per-arch total SBUF size.
-- NOTE(review): any `arch` other than 2, 3 or 4 yields `[]`. That drops `arch`
-- and every `base` entry too, not just the SBUF size. Confirm this is intended
-- rather than falling back to `base`.
def kernelEnv (arch : Nat) : List (Name × Term) :=
  let base := [
    const_int (`arch) arch,
    const_int (.str (nl "tile_size") "pmax") 128,
    const_int (.str (nl "tile_size") "psum_fmax") 512,
    const_int (.str (nl "tile_size") "gemm_stationary_fmax") 128,
    const_int (.str (nl "tile_size") "gemm_moving_fmax") 512,
    const_int (.str (nl "tile_size") "bn_stats_fmax") 512,
    const_int (.str (nl "tile_size") "psum_min_align") 4,
    const_int (.str (nl "tile_size") "sbuf_min_align") 1,
  ]
  match arch with
  | 2 => base ++ [const_int (.str (nl "tile_size") "total_available_sbuf_size") 180224]
  | 3 => base ++ [const_int (.str (nl "tile_size") "total_available_sbuf_size") 212984]
  | 4 => base ++ [const_int (.str (nl "tile_size") "total_available_sbuf_size") 245752]
  | _ => []

-- Trace one NKI kernel. The program-id bindings are placed in front of
-- `kernelEnv k.arch ++ globalEnv`:
-- * `pid = none` binds `_program_id = 0`, `_num_programs = 1`,
--   `_program_ndim = 0`.
-- * `pid = some (p, n)` binds `_program_id = p`, `_num_programs = n`,
--   `_program_ndim = 1`.
def runNkiKernel
    (k : KLR.NKI.Kernel)
    (genDebug : Bool)
    (pid : Option (Nat × Nat))
    : PassM (TraceResult Core.Kernel) := do
  let int i := Term.int i
  let env := match pid with
    | none => (nl "_program_id", int 0) ::
              (nl "_num_programs", int 1) ::
              (nl "_program_ndim", int 0) :: kernelEnv k.arch ++ globalEnv
    | some (p,n) => (nl "_program_id", int p) ::
                    (nl "_num_programs", int n) ::
                    (nl "_program_ndim", int 1) :: kernelEnv k.arch ++ globalEnv
  tracer genDebug env (traceKernel k)

-- TODO: check that inputs and outputs are the same
-- TODO: check that shared constants are the same
-- TODO: check that schedule edges make sense
-- Trace the kernel once for each program id in [0, k.grid.max 1), then
-- assemble a `Core.LncKernel`:
-- * name, inputs and outputs are taken from program 0's traced kernel;
-- * bodies are collected in program-id order;
-- * edges are copied from `k`.
-- The per-program trace results are returned in program-id order, with their
-- payload replaced by `()`.
def runLncKernels (k : NKI.Kernel) (genDebug : Bool := false)
    : PassM (List (TraceResult Unit) × Core.LncKernel) := do
  let num := k.grid.max 1
  let res <- runNkiKernel k genDebug (0, num)
  let k0 := res.result

  -- Both lists are built by prepending and reversed once at the end.
  let mut result := [{ res with result := () }]
  let mut bodies := [res.result.body]
  for i in [1:num] do
    let res <- runNkiKernel k genDebug (i,num)
    result := { res with result := () } :: result
    bodies := res.result.body :: bodies

  let kernel : Core.LncKernel := {
    name := k0.name
    inputs := k0.inputs
    outputs := k0.outputs
    bodies := bodies.reverse
    sharedConstants := []
    edges := k.edges
  }
  return (result.reverse, kernel)

--------------------------------------------------------------------------------
/KLR/Model.lean:
--------------------------------------------------------------------------------
/-
Copyright KLR Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-/

import Cli
import KLR.Python
import KLR.NKI
open Cli

-- Handler for the `model` command. It runs:
-- 1. gather the kernel (`gatherTmp`);
-- 2. infer arguments;
-- 3. NKI simplification;
-- 4. operator simplification;
-- 5. annotation;
-- 6. pattern simplification.
-- It then writes the standalone Lean model via `writeContent "lean"`.
-- Warnings are printed and do not stop the pipeline.
def modelKLR (p : Parsed) : IO UInt32 := do
  /- Load the Python file and perform simplifications -/
  let kernel : KLR.Python.Kernel <- gatherTmp p
  -- IO.println s!"[Kernel] \n{Lean.toJson kernel}"
  let (kernel, warnings) := kernel.inferArguments
  warnings.forM IO.eprintln
  let kernel : KLR.NKI.Kernel <- KLR.NKI.simplify kernel
  let (kernel, w) <- KLR.NKI.simplifyOperators kernel
  w.forM IO.println
  let kernel <- KLR.NKI.annotate kernel
  let kernel <- KLR.NKI.simplifyPatterns kernel
  -- IO.println s!"{Lean.toJson kernel}"
  match (NKI.model kernel : Err NMLModel) with
  | .error s => throw <| (IO.userError s)
  | .ok m =>
    writeContent "lean" p (NKI.pprint_standalone_model m)
    return 0

-- Handler for the `equiv` command. Both named kernels are gathered from the
-- same module, and each is run through the same pipeline as `modelKLR`.
-- Then a Lean file stating the relational goal between the two models is
-- written. Warnings from both kernels are reported.
def equivKLR (p : Parsed) : IO UInt32 := do
  let debug := p.hasFlag "debug"
  let file := p.positionalArg! "moduleFileName" |>.as! String
  let kernelL := p.positionalArg! "kernelFunctionNameL" |>.as! String
  let kernelR := p.positionalArg! "kernelFunctionNameR" |>.as! String
  let dir := (p.flag? "klr-module-dir").map fun x => x.as! String
  let kernelL : KLR.Python.Kernel ← IO.FS.withTempFile fun _ tmpName => do
    gatherRun file kernelL tmpName.toString dir debug
    KLR.File.readKLRFile tmpName .cbor
  let kernelR : KLR.Python.Kernel ← IO.FS.withTempFile fun _ tmpName => do
    gatherRun file kernelR tmpName.toString dir debug
    KLR.File.readKLRFile tmpName .cbor
  let (kernelL, warnings) := kernelL.inferArguments
  warnings.forM IO.eprintln
  let (kernelR, warnings) := kernelR.inferArguments
  -- Previously these warnings were bound but silently dropped.
  warnings.forM IO.eprintln
  let kernelL : KLR.NKI.Kernel <- KLR.NKI.simplify kernelL
  let kernelR : KLR.NKI.Kernel <- KLR.NKI.simplify kernelR
  let (kernelL, w) <- KLR.NKI.simplifyOperators kernelL
  w.forM IO.println
  let (kernelR, w) <- KLR.NKI.simplifyOperators kernelR
  w.forM IO.println
  let kernelL <- KLR.NKI.annotate kernelL
  let kernelR <- KLR.NKI.annotate kernelR
  let kernelL <- KLR.NKI.simplifyPatterns kernelL
  let kernelR <- KLR.NKI.simplifyPatterns kernelR
  match (NKI.model kernelL : Err NMLModel) with
  | .error s => throw <| (IO.userError s)
  | .ok mL =>
    match (NKI.model kernelR : Err NMLModel) with
    | .error s => throw <| (IO.userError s)
    | .ok mR =>
      writeContent "lean" p (NKI.pprint_relational_goal mL mR)
      return 0

-- CLI definition for `model`.
def modelKLRCmd:= `[Cli|
  "model" VIA modelKLR;
  "Emit a Lean model of a KLR kernel which describes its semantics exactly."

  FLAGS:
    o, outfile : String; "Name of output file"

  ARGS:
    moduleFileName : String; "File of the Python module with the kernel function"
    kernelFunctionName : String; "Name of the kernel function"
]

-- CLI definition for `equiv`.
-- NOTE(review): `equivKLR` reads the `debug` and `klr-module-dir` flags, but
-- neither is declared under FLAGS here, so users cannot pass them. Confirm
-- whether they should be added.
def equivKLRCmd:= `[Cli|
  "equiv" VIA equivKLR;
  "Emit a Lean file containing an open theorem that two KLR kernels are equivalent."

  FLAGS:
    o, outfile : String; "Name of output file"

  ARGS:
    moduleFileName : String; "File of the Python module with the kernel function"
    kernelFunctionNameL : String; "Name of the left kernel function"
    kernelFunctionNameR : String; "Name of the right kernel function"
]

--------------------------------------------------------------------------------
/KLR/Serde/Attr.lean:
--------------------------------------------------------------------------------
/-
Copyright KLR Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-/

import Lean

/-
# Attribute for serialization and de-serialization

This module defines a new attribute and some meta-programming utilities for
accessing the attribute data. Note: new attributes must be defined in a
separate module from where they are used. See `Test.lean` for more details.
-/
namespace KLR.Serde
open Lean Meta

-- Payload stored by the `serde` attribute: the explicitly assigned tag.
private structure SerdeTag where
  tag : Nat
  deriving Inhabited, BEq, Repr

-- Surface syntax of the attribute, e.g. `@[serde tag = 3]`.
syntax (name := serde) "serde" "tag" "=" num : attr

-- Parametric attribute recording a `SerdeTag` for each declaration annotated
-- with `@[serde tag = N]`. Any other form of the attribute is rejected with
-- "invalid [serde] attribute".
private initialize tags : ParametricAttribute SerdeTag <-
  registerParametricAttribute {
    name := `serde
    descr := "Assign Serde tag"
    getParam name stx := do
      let `(attr| serde tag = $t:num ) := stx
        | throwError "invalid [serde] attribute"
      return ⟨ TSyntax.getNat t ⟩
  }

-- Return the serde tag for a name (if any)
def serdeTag [Monad m] [MonadEnv m] (name : Name) : m (Option Nat) := do
  return (tags.getParam? (<- getEnv) name).map (·.tag)

-- Check for duplicates and reverse list (private function used below)
-- Throws as soon as a tag is seen a second time. Otherwise returns the input
-- reversed (serdeMap builds its list backwards, so this restores its order).
private def checkDups (l : List (a × Nat)) : MetaM (List (a × Nat)) := do
  let rec loop l1 l2 :=
    match l1, l2 with
    | l1, [] => return l1
    | l1, x :: xs =>
      if l1.any fun p => p.2 == x.2
      then throwError s!"duplicate Serde tag found {x.2}"
      else loop (x :: l1) xs
  loop [] l

/-
Return a mapping of constructor names to serde tags for a given type. Any
constructors without assigned tags will be automatically assigned a tag
starting equal to the previous tag plus one, while avoiding any user defined
values.

One important use case is adding a new constructor. If we add a new constructor
to a type, but we don't want to add it at the end, then we can manually assign
it a tag and the other constructors will end up with their previous assignments.

  inductive A | a | c | d
  -- mapping is a=0, c=1, d=2

  -- changed to:
  inductive A | a | b | c | d
  -- mapping is: a=0, b=1, c=2, d=3 (INCORRECT)

  @[serde tag = 3 ] A.b
  -- mapping is: a=0, b=3, c=1, d=2 (correct)

NOTE(review): tracing `serdeMap` as written gives a different result for this
example. After the explicitly tagged `b` it sets `current := nextTag 3`, so the
result is a=0, b=3, c=4, d=5, not c=1, d=2. Either the example or the
`some n` branch of serdeMap needs reconciling; confirm against Test.lean.

See Test.lean for more details.
-/
/--
Smallest tag strictly greater than `n` that is not one of the explicitly
assigned tags in `user` (which holds `some t` for every tagged constructor).
-/
partial def nextTag (user : List (Option Nat)) (n : Nat) : Nat :=
  let n := n + 1
  if user.contains (some n)
  then nextTag user n
  else n

/--
Map every constructor of the inductive type `name` to its serde tag, in
constructor order.

Constructors with `@[serde tag = N]` get `N`. Every other constructor gets the
current automatic tag, and the next automatic tag is computed with `nextTag`
from the last tag assigned. Automatic tags therefore never equal an explicit
one. Throws if two constructors still end up with the same tag.
-/
def serdeMap (name : Name) : MetaM (List (Name × Nat)) := do
  let tcs <- getConstInfoInduct name
  let user <- tcs.ctors.mapM serdeTag
  let nextTag := nextTag user
  -- Start at 0 unless some constructor explicitly claims 0. Otherwise an
  -- untagged first constructor would collide with it and `checkDups` would
  -- throw. (`nextTag` only skips values strictly greater than its argument.)
  let mut current := if user.contains (some 0) then nextTag 0 else 0
  let mut res := []
  for ctor in tcs.ctors do
    match <- serdeTag ctor with
    | none =>
      res := (ctor, current) :: res
      current := nextTag current
    | some n =>
      res := (ctor, n) :: res
      current := nextTag n
  -- `res` was built backwards; checkDups also reverses it into constructor order.
  checkDups res

abbrev Tags := Nat × List (Name × Nat)

-- Convenience function: serdeTag and serdeMap
-- Returns the type's own tag together with its constructor map. Throws if the
-- type itself has no serde tag.
def serdeTags (name : Name) : MetaM Tags := do
  match <- serdeTag name with
  | none => throwError s!"No serde tags for {name}"
  | some t => return (t, <- serdeMap name)

--------------------------------------------------------------------------------
/KLR/NKI/Patterns.lean:
--------------------------------------------------------------------------------
/-
Copyright KLR Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-/

import KLR.Compile.Pass
import KLR.NKI.Basic
import KLR.Util

/-
Simplify Patterns found in let statements.

This module will rewrite complex patterns into simple patterns.
For example, the following:

  let x, y, z = e

will be transformed to:

  let tmp = e
  let x = tmp[0]
  let y = tmp[1]
  let z = tmp[2]
-/

namespace KLR.NKI
open Compile.Pass

-- Simplify Pattern = Simpat
abbrev Simpat := Pass Unit

-- Build a simple subscript expression, e.g. v[2,3,1]
-- A single `.access` on variable `v` with one integer coordinate per element
-- of `ix`. Every node gets a dummy position (line 0).
private def subscript (v : Name) (ix : List Int) : Expr :=
  let pos : Pos := { line := 0 }
  let ix := ix.map fun n => .coord ⟨.value (.int n), pos ⟩
  let e' := .access ⟨ .var v, pos ⟩ ix
  ⟨ e', pos ⟩

/-
Expand a complex pattern into a set of subscript expressions.

  ((x,y),z) = v

becomes

  x = v[0,0]
  y = v[0,1]
  z = v[1]

Each variable is bound to a single subscript of `v` whose index list is the
path to that variable. `ix` is the path so far, and `i` is the position of the
first pattern in `ps`.
-/
private def expand (v : Name) (ix : List Int) (i : Int) (ps : List Pattern) : List Stmt' :=
  match ps with
  | [] => []
  | p :: ps =>
    let l := match p with
      | .var n => [.letM (.var n) none (subscript v (ix ++ [i]))]
      | .tuple ps => expand v (ix ++ [i]) 0 ps
    l ++ expand v ix (i+1) ps

section Testing

-- Recover the index path from a statement produced by `expand`.
private def extract : Stmt' -> List Int
  | .letM _ _ ⟨.access _ ix, _⟩ => ix.map fun
    | .coord ⟨.value (.int i), _⟩ => i
    | _ => panic! "invalid"
  | _ => panic! "invalid"

private def test (ps : List Pattern) : List (List Int) :=
  let l := expand `x [] 0 ps
  l.map extract

private def p1 : List Pattern :=
  [.tuple [.var `a, .var `b], .tuple [.var `c], .var `d]

private def p2 : List Pattern :=
  [.var `a, .var `b, .var `c, .var `d]

private def p3 : List Pattern :=
  [.tuple [.tuple [.var `a]]]

#guard test p1 == [[0,0], [0,1], [1,0], [2]]
#guard test p2 == [[0], [1], [2], [3]]
#guard test p3 == [[0,0,0]]

end Testing

mutual
-- Rewrite one statement. Every resulting statement carries the original
-- statement's position.
private def stmt (s : Stmt) : Simpat (List Stmt) :=
  withPos s.pos do
    let l <- stmt' s.stmt
    let l := l.map fun s' => ⟨ s', s.pos ⟩
    return l
termination_by sizeOf s
decreasing_by cases s; simp; omega

-- Rewrite a statement list, concatenating the per-statement results.
private def stmts (l : List Stmt) : Simpat (List Stmt) := do
  l.flatMapM stmt
termination_by sizeOf l

-- Core rewrite:
-- * Statements without tuple patterns are returned unchanged.
-- * A `let` with a non-empty tuple pattern becomes a `let` of a fresh variable,
--   followed by the per-component bindings from `expand`.
-- * if/for/while bodies are rewritten recursively.
-- * An empty tuple pattern is an error.
private def stmt' (s : Stmt') : Simpat (List Stmt') := do
  match s with
  | .expr ..
  | .assert ..
  | .ret ..
  | .declare ..
  | .setM ..
  | .breakLoop
  | .continueLoop
  | .letM (.var ..) .. => return [s]
  | .letM (.tuple []) .. => throw "internal error: empty tuple pattern not allowed in let binding"
  | .letM (.tuple ps) ty e => do
    let x <- freshName
    let st : Stmt' := .letM (.var x) ty e
    return st :: expand x [] 0 ps
  | .ifStm c t e => return [.ifStm c (<- stmts t) (<- stmts e)]
  | .forLoop x iter body => return [.forLoop x iter (<- stmts body)]
  | .whileLoop test body => return [.whileLoop test (<- stmts body)]
termination_by sizeOf s
end

-- Simplify patterns in a function body.
private def func (f : Fun) : Simpat Fun :=
  return { f with body := <- stmts f.body }

-- Entry point: simplify let-patterns in every function of the kernel.
def simplifyPatterns (k : Kernel) : Simpat Kernel := do
  return { k with funs := <- k.funs.mapM func }

--------------------------------------------------------------------------------
/KLR/Export.lean:
--------------------------------------------------------------------------------
/-
Copyright KLR Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-/

import KLR.Python
import Lean

/-
This module defines C export versions of inductive constructors. Both the Lean
definitions and the C header files are generated by the meta-programs in this
module.
-/

namespace KLR.Export
open Lean Meta

/-
Create a C-compatible name from a Lean Name

Note: We remove "'", replacing this with an empty string.
This is a bit nicer
on the C side, and doesn't create any name conflicts for our types.

Components are joined with "_". A name directly under `KLR` (or at top level)
keeps only its last component, so e.g. `KLR.Python.Expr.mk` becomes
`Python_Expr_mk`.
-/
private def Cname : Name -> String
  | .anonymous => ""
  | .str .anonymous s
  | .str `KLR s => s.replace "'" ""
  | .str n s => Cname n ++ "_" ++ s.replace "'" ""
  | .num n k => Cname n ++ "_" ++ toString k

/-
Create the C type which corresponds to a given Lean type.

Note: I have determined these by inspecting the output of the Lean compiler for
the types we have; this may not be generally correct.

`Bool` and the listed enumeration types map to `uint8_t`, `Float` maps to
`double`, and everything else is treated as a boxed `lean_object*`.
-/
private def Ctype : Expr -> String
  | .const `Bool [] => "uint8_t"
  | .const `Float [] => "double"
  | .const `KLR.Python.Ctx []
  | .const `KLR.Python.BoolOp []
  | .const `KLR.Python.CmpOp []
  | .const `KLR.Python.UnaryOp []
  | .const `KLR.Python.BinOp [] => "uint8_t"
  | _ => "lean_object*"

/-
Generates a set of definitions of the form

  @[export X] def X := C

For each constructor, C, of an inductive type. As a side effect this function
will emit a C declaration for the function that will be generated by the Lean
compiler. For example, if you have:

  inductive I where | C : Bool -> Nat -> I

then we get a Lean definition:

  @[export I_C] def I_C := I.C

and a C declaration

  lean_object* I_C(uint8_t, lean_object*);

This code does not handle types with uninstantiated mvars.
-/
private def generate (name : Name) : MetaM Unit := do
  -- Stays true only if every constructor is nullary.
  let mut isEnum := true
  -- Nullary constructors; their extern declarations are printed after the loop.
  let mut atoms : Array String := #[]
  let tci <- getConstInfoInduct name
  for c in tci.ctors do
    let ci <- getConstInfoCtor c
    let id := mkIdent (Cname c).toName
    -- C type of every argument in the constructor's (reduced) telescope.
    let ts <- forallTelescopeReducing ci.type fun xs _ => do
      xs.mapM fun x => return Ctype (<- x.fvarId!.getDecl).type
    if ts.size > 0 then
      isEnum := false
      IO.println s!"lean_object* {Cname c}({",".intercalate ts.toList});"
    else
      atoms := atoms.push (Cname c)
    let cmd <- `(@[export $id] def $id := $(mkIdent c))
    liftCommandElabM (Elab.Command.elabCommand cmd)
  -- Nullary constructors of an all-nullary type are declared `uint8_t`,
  -- matching the enum entries in `Ctype`. Otherwise they are boxed objects.
  for a in atoms do
    if isEnum then
      IO.println s!"extern uint8_t {a};"
    else
      IO.println s!"extern lean_object* {a};"

/-
Search a namespace for all inductive definitions and generate externs for
each one.

Throws if `name` is not a namespace.
-/
private def generateNamespace (name : Name) : MetaM Unit := do
  unless (<- getEnv).isNamespace name do
    throwError "{name} is not a namespace"
  for n in (<- getEnv).getNamespaceSet.toList do
    if name.isPrefixOf n && (<- isInductive n) then
      generate n

/-
Build extern definitions and C header file for types we need.
-/
-- Runs while this module is elaborated. Stdout is redirected to
-- interop/klr/lean_ast.h, so the C declarations printed by `generate` are written
-- to that header, while the `@[export]` definitions are elaborated into this
-- module.
-- NOTE(review): the `#include` literal below shows no header name in this
-- copy; presumably an angle-bracket include was lost in extraction — confirm
-- against the repository.
run_meta
  let h <- IO.FS.Handle.mk "interop/klr/lean_ast.h" IO.FS.Mode.write
  IO.withStdout (.ofHandle h) do
    IO.println "#include "
    generate `KLR.Core.Pos
    generateNamespace `KLR.Python

--------------------------------------------------------------------------------
/interop/klr/klr_ffi.c:
--------------------------------------------------------------------------------
#include

#include "frontend.h"

// forward declarations
lean_object* initialize_KLR(uint8_t builtin, lean_object* w);
void lean_initialize_runtime_module();
lean_object* lean_io_error_to_string(lean_object*);
lean_object* klr_frontend_hello(lean_object*);
lean_object* klr_frontend_panic(lean_object*);
lean_object* klr_frontend_throw(lean_object*);
lean_object* klr_frontend_trace(lean_object*, lean_object*, lean_object*);

// Given a lean_io_result, sets a Python exception.
// Steals the reference to lean_io_result.
// The exception type is RuntimeError. A non-empty err_msg_prefix is prepended
// to the Lean error text as "<prefix>: <message>".
static void set_pyerr_from_lean_io_result_wprefix(lean_obj_arg l_io_result, const char *err_msg_prefix) {
  b_lean_obj_res l_io_error = lean_io_result_get_error(l_io_result); // borrows reference to arg, returns borrowed reference
  // Take our own reference: lean_io_error_to_string consumes its argument,
  // while l_io_error is only borrowed from l_io_result.
  lean_inc_ref(l_io_error);
  lean_obj_res l_string = lean_io_error_to_string(l_io_error); // steals reference to arg, returns new reference
  const char *c_str = lean_string_cstr(l_string); // borrows reference to arg, returns borrowed c-str

  PyObject *py_exc_type = PyExc_RuntimeError;
  if (err_msg_prefix && err_msg_prefix[0] != 0) {
    PyErr_Format(py_exc_type, "%s: %s", err_msg_prefix, c_str);
  } else {
    PyErr_SetString(py_exc_type, c_str);
  }

  // c_str points into l_string, so l_string is released only after the
  // Python exception has been created from it.
  lean_dec(l_string);
  lean_dec(l_io_result);
}

// Given a lean_io_result, sets a Python exception.
// Steals the reference to lean_io_result.
static void set_pyerr_from_lean_io_result(lean_obj_arg l_io_result) {
  set_pyerr_from_lean_io_result_wprefix(l_io_result, NULL/*err_msg_prefix*/);
}

// Given an "IO Uint32" (lean_io_result with uint32 value), return a new PyLongObject.
// If the lean_io_result is error, sets a Python exception and returns NULL.
// Steals the reference to the lean_io_result.
static PyObject *pylong_from_uint32_lean_io_result(lean_obj_arg l_io_result) {
  if (!lean_io_result_is_ok(l_io_result)) {
    set_pyerr_from_lean_io_result(l_io_result); // steals reference to arg
    return NULL;
  }

  // The value is read (unboxed) before l_io_result, which owns it, is released.
  uint32_t u32 = lean_unbox_uint32(lean_io_result_get_value(l_io_result));
  lean_dec_ref(l_io_result);
  return PyLong_FromUnsignedLong(u32);
}


// Initialize Lean and the KLR module.
// Returns true if successful.
// Otherwise returns false and sets a Python exception
bool initialize_KLR_lean_ffi() {
  // This code initially copied from:
  // https://lean-lang.org/doc/reference/4.22.0-rc2//Run-Time-Code/Foreign-Function-Interface/
  // https://github.com/leanprover/lean4/blob/master/src/lake/examples/reverse-ffi/main.c
  lean_initialize_runtime_module();

  // Disable panic messages during initialization.
  // The Lean compiler also does this when generating Main functions.
  // See: https://github.com/leanprover/lean4/commit/2018dc0
  lean_set_panic_messages(false);
  lean_obj_res l_io_result = initialize_KLR(1 /*builtin*/, lean_io_mk_world());
  lean_set_panic_messages(true);
  if (lean_io_result_is_ok(l_io_result)) {
    lean_dec_ref(l_io_result);
  } else {
    // Consumes l_io_result and raises RuntimeError("Lean initialization failed: ...").
    set_pyerr_from_lean_io_result_wprefix(l_io_result, "Lean initialization failed");
    return false;
  }
  lean_io_mark_end_initialization();

  return true;
}

// Python entry point: runs klr_frontend_hello. Returns its UInt32 result as a
// Python int, or NULL with a RuntimeError set if the Lean IO action failed.
PyObject* lean_ffi_hello(PyObject *self, PyObject *args) {
  (void)self;
  (void)args;

  lean_obj_res l_io_result = klr_frontend_hello(lean_io_mk_world());
  return pylong_from_uint32_lean_io_result(l_io_result);
}

// Python entry point: runs klr_frontend_throw, converting its result the same
// way as lean_ffi_hello.
PyObject* lean_ffi_throw(PyObject *self, PyObject *args) {
  (void)self;
  (void)args;

  lean_obj_res l_io_result = klr_frontend_throw(lean_io_mk_world());
  return pylong_from_uint32_lean_io_result(l_io_result);
}

// Python entry point: runs klr_frontend_panic, converting its result the same
// way as lean_ffi_hello.
PyObject* lean_ffi_panic(PyObject *self, PyObject *args) {
  (void)self;
  (void)args;

  lean_obj_res l_io_result = klr_frontend_panic(lean_io_mk_world());
  return pylong_from_uint32_lean_io_result(l_io_result);
}

--------------------------------------------------------------------------------
/interop/klr/lean_ast.h:
--------------------------------------------------------------------------------
#include
// NOTE(review): this header is written by the `run_meta` block in
// KLR/Export.lean (it opens "interop/klr/lean_ast.h" for writing). Regenerate
// it from there rather than editing by hand.
lean_object* Core_Pos_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*);
lean_object* Python_Const_bool(uint8_t);
lean_object* Python_Const_int(lean_object*);
lean_object* Python_Const_float(double);
lean_object* Python_Const_string(lean_object*);
lean_object* Python_Const_tensor(lean_object*,lean_object*);
extern lean_object* Python_Const_none;
extern lean_object* Python_Const_ellipsis;
lean_object* Python_Keyword_mk(lean_object*,lean_object*,lean_object*);
extern uint8_t Python_Ctx_load;
extern uint8_t Python_Ctx_store;
extern uint8_t Python_Ctx_del;
lean_object* Python_Expr_mk(lean_object*,lean_object*);
lean_object* Python_Kernel_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*);
lean_object* Python_Class_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*);
lean_object* Python_Stmt_expr(lean_object*);
lean_object* Python_Stmt_assert(lean_object*,lean_object*);
lean_object* Python_Stmt_ret(lean_object*);
lean_object* Python_Stmt_assign(lean_object*,lean_object*);
lean_object* Python_Stmt_augAssign(lean_object*,uint8_t,lean_object*);
lean_object* Python_Stmt_annAssign(lean_object*,lean_object*,lean_object*);
lean_object* Python_Stmt_ifStm(lean_object*,lean_object*,lean_object*);
lean_object* Python_Stmt_forLoop(lean_object*,lean_object*,lean_object*,lean_object*);
lean_object* Python_Stmt_whileLoop(lean_object*,lean_object*,lean_object*);
extern lean_object* Python_Stmt_pass;
extern lean_object* Python_Stmt_breakLoop;
extern lean_object* Python_Stmt_continueLoop;
lean_object* Python_Expr_const(lean_object*);
lean_object* Python_Expr_name(lean_object*,uint8_t);
lean_object* Python_Expr_attr(lean_object*,lean_object*,uint8_t);
lean_object* Python_Expr_tuple(lean_object*,uint8_t);
lean_object* Python_Expr_list(lean_object*,uint8_t);
lean_object* Python_Expr_dict(lean_object*,lean_object*);
lean_object* Python_Expr_subscript(lean_object*,lean_object*,uint8_t);
lean_object* Python_Expr_slice(lean_object*,lean_object*,lean_object*);
lean_object* Python_Expr_boolOp(uint8_t,lean_object*);
lean_object* Python_Expr_binOp(uint8_t,lean_object*,lean_object*);
lean_object* Python_Expr_unaryOp(uint8_t,lean_object*);
lean_object* Python_Expr_compare(lean_object*,lean_object*,lean_object*);
lean_object* Python_Expr_ifExp(lean_object*,lean_object*,lean_object*);
lean_object* Python_Expr_call(lean_object*,lean_object*,lean_object*);
lean_object* Python_Expr_starred(lean_object*,uint8_t);
lean_object* Python_Expr_object(lean_object*,lean_object*);
lean_object* Python_Expr_format(lean_object*,lean_object*);
lean_object* Python_Expr_joined(lean_object*);
lean_object* Python_Stmt_mk(lean_object*,lean_object*);
extern uint8_t Python_CmpOp_eq;
extern uint8_t Python_CmpOp_ne;
extern uint8_t Python_CmpOp_lt;
extern uint8_t Python_CmpOp_le;
extern uint8_t Python_CmpOp_gt;
extern uint8_t Python_CmpOp_ge;
extern uint8_t Python_CmpOp_is;
extern uint8_t Python_CmpOp_isNot;
extern uint8_t Python_CmpOp_isIn;
extern uint8_t Python_CmpOp_notIn;
extern uint8_t Python_UnaryOp_invert;
extern uint8_t Python_UnaryOp_not;
extern uint8_t Python_UnaryOp_uadd;
extern uint8_t Python_UnaryOp_usub;
lean_object* Python_Fun_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*);
extern uint8_t Python_BinOp_add;
extern uint8_t Python_BinOp_sub;
extern uint8_t Python_BinOp_mul;
extern uint8_t Python_BinOp_matmul;
extern uint8_t Python_BinOp_div;
extern uint8_t Python_BinOp_mod;
extern uint8_t Python_BinOp_pow;
extern uint8_t Python_BinOp_lshift;
extern uint8_t Python_BinOp_rshift;
extern uint8_t Python_BinOp_or;
extern uint8_t Python_BinOp_xor;
extern uint8_t Python_BinOp_and;
extern uint8_t Python_BinOp_floor;
extern uint8_t Python_BoolOp_land;
extern uint8_t Python_BoolOp_lor;
lean_object* Python_Args_mk(lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*,lean_object*);

--------------------------------------------------------------------------------
/KLR/Util/NumBytes.lean:
-------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -/ 16 | 17 | import KLR.Util.Enum 18 | import Lean 19 | import TensorLib.Common 20 | 21 | namespace KLR.Util 22 | 23 | open Lean(Command Expr InductiveVal Name Term getEnv isInductive mkIdent) 24 | open Lean.Elab(registerDerivingHandler) 25 | open Lean.Elab.Command(CommandElabM liftTermElabM elabCommand) 26 | open Lean.Elab.Deriving(Context Header mkContext mkHeader mkInstanceCmds) 27 | open Lean.Elab.Term(TermElabM) 28 | open TensorLib(impossible) 29 | /-- Size of a value in bytes. The instances below give fixed widths for primitive numeric types; `deriving NumBytes` on a structure sums the sizes of its fields. -/ 30 | class NumBytes (a : Type) where 31 | numBytes : a -> Nat 32 | 33 | export NumBytes(numBytes) 34 | 35 | namespace NumBytes 36 | 37 | instance : NumBytes Int8 where 38 | numBytes _ := 1 39 | 40 | instance : NumBytes UInt8 where 41 | numBytes _ := 1 42 | 43 | instance : NumBytes Int16 where 44 | numBytes _ := 2 45 | 46 | instance : NumBytes UInt16 where 47 | numBytes _ := 2 48 | 49 | instance : NumBytes Int32 where 50 | numBytes _ := 4 51 | 52 | instance : NumBytes UInt32 where 53 | numBytes _ := 4 54 | 55 | instance : NumBytes UInt64 where 56 | numBytes _ := 8 57 | 58 | instance : NumBytes Int64 where 59 | numBytes _ := 8 60 | 61 | instance : NumBytes Float32 where 62 | numBytes _ := 4 63 | 64 | instance : NumBytes Float where 65 | numBytes _ := 8 66 | 67 | instance : NumBytes (BitVec n) where 68 | numBytes _ := (n + 7) / 8 -- n bits rounded up to whole bytes 69 |
70 | instance [NumBytes a][NumBytes b] : NumBytes (a × b) where 71 | numBytes := fun (x, y) => numBytes x + numBytes y 72 | 73 | instance [Enum a] : NumBytes a where 74 | numBytes _ := 1 75 | 76 | instance [Inhabited a][NumBytes a] : NumBytes (Vector a n) where 77 | numBytes _ := n * numBytes (default : a) -- assumes every element has the size of `default`, so this is exact only for fixed-width element types 78 | 79 | end NumBytes 80 | /-- Deriving header for a `NumBytes` instance on `indVal`. -/ 81 | def mkNumBytesHeader (indVal : InductiveVal) : TermElabM Header := do 82 | mkHeader ``NumBytes 1 indVal 83 | /-- Body of the derived function: the sum of `numBytes` over the structure's flattened fields (`includeSubobjectFields := false`). -/ 84 | def mkNumBytesBody (header : Header) (e : Expr): TermElabM Term := do 85 | let indName := e.getAppFn.constName! 86 | let env <- getEnv 87 | let fields := Lean.getStructureFieldsFlattened env indName (includeSubobjectFields := false) 88 | let target := mkIdent header.targetNames[0]! 89 | let apps : Array Term <- fields.mapM fun f => ``(NumBytes.numBytes ($target).$(mkIdent f)) 90 | let body : Term <- `(Array.sum #[ $apps,* ]) 91 | return body 92 | /-- Emit the private auxiliary `def` that computes `numBytes` for the target type. -/ 93 | def mkNumBytesFunction (ctx : Context) : TermElabM Command := do 94 | let auxFunName := ctx.auxFunNames[0]! 95 | let header ← mkNumBytesHeader ctx.typeInfos[0]!
96 | let binders := header.binders 97 | let type ← Lean.Elab.Term.elabTerm header.targetType none 98 | let body ← mkNumBytesBody header type 99 | `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Nat := $body:term) 100 | /-- Commands for the auxiliary function, followed by the `NumBytes` instance commands from `mkInstanceCmds`. -/ 101 | private def mkNumBytesInstance (declName : Name) : TermElabM (Array Command) := do 102 | let ctx ← mkContext "NumBytes" declName 103 | let cmds := #[← mkNumBytesFunction ctx] ++ (← mkInstanceCmds ctx ``NumBytes #[declName]) 104 | return cmds 105 | 106 | private def errMsg := "deriving NumBytes only works on single structures" 107 | /-- Deriving handler: succeeds only for a single, non-mutual structure; anything else is an error. `throwError` never returns, so the failing branch needs no `return false`. -/ 108 | def mkNumBytesInstanceHandler (declNames : Array Name) : CommandElabM Bool := match declNames with 109 | | #[] => impossible "Expected a type" 110 | | #[t] => do 111 | if (Lean.isStructure (<- getEnv) t) && (<- Lean.getConstInfoInduct t).all.length == 1 then 112 | let cmds ← liftTermElabM <| mkNumBytesInstance t 113 | cmds.forM elabCommand 114 | return true 115 | else 116 | throwError errMsg 117 | | _ => throwError errMsg 118 | 119 | initialize 120 | registerDerivingHandler ``NumBytes mkNumBytesInstanceHandler 121 | 122 | end KLR.Util 123 | -------------------------------------------------------------------------------- /KLR/Core/Pretty.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | -/ 16 | 17 | import KLR.Core.Basic 18 | import KLR.Util 19 | 20 | namespace KLR.Core 21 | open Std 22 | 23 | /- 24 | This is a simple pretty printer for KLR terms. At some point, we may want to 25 | make this output valid python syntax that would parse and elaborate to the same 26 | KLR kernel. At the moment, there are too many unknowns to spend time on this. 27 | The format here is just for ease of debugging, feel free to modify as you wish. 28 | -/ 29 | /-- Wrap `f` in angle brackets. -/ 30 | private def abracket (f : Format) : Format := 31 | Format.bracket "<" f ">" 32 | /-- Comma-separated list in parentheses. -/ 33 | private def args [ToFormat a] (l : List a) : Format := 34 | .paren (.joinSep l ",") 35 | /-- Comma-separated list in square brackets. -/ 36 | private def sqArgs [ToFormat a] (l : List a) : Format := 37 | .sbracket (.joinSep l ",") 38 | 39 | instance [ToFormat a] [ToFormat b] : ToFormat (a × b) where 40 | format | (a,b) => .paren (format a ++ "," ++ format b) 41 | 42 | instance : ToFormat Memory where 43 | format 44 | | .hbm => "hbm" 45 | | .sbuf => "sbuf" 46 | | .psum => "psum" 47 | | .reg => "reg" 48 | 49 | instance : ToFormat Dtype where 50 | format dty := 51 | match (reprStr dty).toName with 52 | | .str _ name => name 53 | | _ => impossible "dtype repr must be a name" 54 | 55 | instance : ToFormat Address where 56 | format a := format a.memory ++ sqArgs [format a.partitionOffset, format a.freeOffset, format a.size] 57 | 58 | instance : ToFormat Shape where 59 | format s := sqArgs (s.parDim :: s.freeDims) 60 | 61 | instance : ToFormat TensorName where 62 | format t := t.name 63 | 64 | instance : ToFormat Slice where 65 | format s := .joinSep [format s.l, format s.u, format s.step] ":" 66 | 67 | instance : ToFormat Index where 68 | format 69 | | .coord i => format i 70 | | .slice s => format s 71 | 72 | instance : ToFormat AccessBasic where 73 | format acc := format acc.tensor ++ sqArgs acc.indexes 74 | 75 | instance : ToFormat APPair where 76 | format ap := args [ap.step, Int.ofNat ap.num] 77 | -- Note: `.sbracket` around `sqArgs` prints doubled brackets, i.e. `t[[offset,parNum,...]]`. 78 | instance : ToFormat AccessPattern where 79 | format ap := format
ap.tensor ++ (.sbracket <| sqArgs <| 80 | format ap.offset :: format ap.parNum :: ap.freePattern.map format) 81 | 82 | instance : ToFormat Access where 83 | format 84 | | .simple t => format t 85 | | .basic acc => format acc 86 | | .pattern ap => format ap 87 | 88 | instance : ToFormat Operator where 89 | format 90 | | .load => "load" 91 | | .save => "save" 92 | | .memset x => f!"memset {x}" 93 | | .tensorScalar .. => "tensor_scalar{..}" 94 | | .tensorScalarAddr .. => "tensor_scalar_addr{..}" 95 | 96 | instance : ToFormat Value where 97 | format 98 | | .var x => x 99 | | .bool b => format b 100 | | .int i => format i 101 | | .float f => format f 102 | | .access a => format a 103 | 104 | instance : ToFormat Expr where 105 | format 106 | | .value v => format v 107 | | .call f a k => format f ++ args (a ++ k.map Prod.snd) 108 | -- `ret` needs a trailing space; without it the value is glued to the keyword (e.g. `retx`). 109 | instance : ToFormat Stmt where 110 | format 111 | | .ret e => "ret " ++ format e 112 | | .assign x e => x ++ " = " ++ format e 113 | | .store d op as => format d ++ " := " ++ format op ++ args as 114 | /-- Tensor name followed by `[dtype,shape,address]`. -/ 115 | def ppFullTensor (t : TensorName) : Format := 116 | t.name ++ sqArgs [ format t.dtype, format t.shape, format t.address ] 117 | 118 | instance : ToFormat Kernel where 119 | format k := 120 | let lines l := Format.joinSep l .line 121 | let nest_lines l := Format.nest 2 (.align true ++ lines l) 122 | lines [ 123 | Format.text k.name, 124 | "inputs:", nest_lines (k.inputs.map ppFullTensor), 125 | "outputs:", nest_lines (k.outputs.map ppFullTensor), 126 | "internal:", nest_lines (k.internal.map ppFullTensor), 127 | "body:", nest_lines (k.body.map format) 128 | ] 129 | -------------------------------------------------------------------------------- /interop/test/test_basic.py: -------------------------------------------------------------------------------- 1 | # This file exercises the Lean partial evaluator with 2 | # a set of basic unit tests. Each function is parsed, 3 | # handed to Lean, where it is checked and reduced to KLR.
4 | 5 | import pytest 6 | import numpy as np 7 | 8 | import klr.frontend 9 | from klr.frontend import Kernel 10 | 11 | # Success cases 12 | # (these functions should load and trace to KLR) 13 | 14 | def const_stmt(t): 15 | "this will be ignored because it has no effect" 16 | 1 # so will this, it is a simple constant 17 | 1.0 # so will this 18 | False # and this 19 | None # and this 20 | (1,2) # and this 21 | [1,2] # and this 22 | 23 | string = "a string" 24 | integer = -3 25 | floating = 1.23 26 | boolean = True 27 | nothing = None 28 | triple = (1, floating, False) 29 | list3 = [string, triple, klr.frontend] 30 | 31 | def expr_name(t): 32 | # these names will end up in the global environment after parsing 33 | # they will be eliminated after substitution during tracing 34 | string, integer, floating, boolean, nothing 35 | # constant tuples are also OK 36 | triple 37 | # as are constant lists 38 | list3 39 | # as are module references 40 | np 41 | 42 | def expr_tuple(t): 43 | assert (1,False,"hello") 44 | 45 | def expr_list(t): 46 | assert [1,2,False] 47 | assert not [] 48 | 49 | def expr_subscript(t): 50 | t[1] 51 | t[1,2] 52 | t[1:2:3] 53 | t[1:2] 54 | t[1:] 55 | t[1::] 56 | t[1::2] 57 | t[1:2:None] 58 | t[1:None:2] 59 | t[:] 60 | t[:,:] 61 | t[:,:,:] 62 | t[...] 63 | t[1,...] 
64 | t[...,1] 65 | t[:,None] 66 | t[1] 67 | 68 | def expr_bool_op(t): 69 | True and 1 and [1] and [] and True # evals to [] 70 | False or None or [] or 1 # evals to 1 71 | 1 or None # evals to 1 72 | (False,) or 1 # evals to (False,) 73 | 74 | def expr_cmp_op(t): 75 | assert 1 == 1 76 | assert [] == [] 77 | assert not ([1,2] == [1]) 78 | assert not ([] < []) 79 | assert [] < [1] 80 | assert not ([1,2] < [1,2]) 81 | assert [1,1] < [1,2] 82 | assert [1,2] < [1,2,3] 83 | assert 1.2 < 2 84 | assert 1 < 1.2 85 | assert 1.2 < 1.3 86 | assert 0.5 < True 87 | assert not (0.5 < False) 88 | assert "a" < "ab" 89 | assert (1,2) is (1,2) 90 | assert not ([1,2] is [1,2]) 91 | assert 1 in (1,2) 92 | assert 1 in [3,2,1] 93 | assert 1 not in (2,3,4) 94 | assert 1 not in [] 95 | 96 | def assign(t): 97 | x = y = 1 98 | assert x == y 99 | x, y = [1,2] 100 | assert x == 1 101 | assert y == 2 102 | (x,y), z = a, [b,c] = ((1,2),(3,4)) 103 | assert x == 1 104 | assert y == 2 105 | assert z == (3,4) 106 | assert a == (1,2) 107 | assert b == 3 108 | assert c == 4 109 | 110 | def ifs(t): 111 | x = 0 112 | if x: x = 1 113 | else: x = 2 114 | assert x == 2 115 | if x: x = 1 116 | else: x = 2 117 | assert x == 1 118 | 119 | def loops(t): 120 | for x in [1,2,3,4]: 121 | if x == 1: continue 122 | assert x != 1 123 | if x == 3: break 124 | assert x == 3 125 | 126 | # some undefined names are OK 127 | def undefined_ok(t): 128 | nl.foo(t) 129 | 130 | def min_max_test(t): 131 | assert 1 == min(1, 2) 132 | assert 1 == min(2, 1) 133 | assert 1.0 == min(2.0, 1) 134 | assert 1.0 == min(1, 2.0) 135 | assert 1 == min([1, 2]) 136 | assert 1 == min([2, 1]) 137 | assert 1.0 == min([2.0, 1]) 138 | assert 1.0 == min([1, 2.0]) 139 | assert 2 == max(1, 2) 140 | assert 2 == max(2, 1) 141 | assert 2.0 == max(2.0, 1) 142 | assert 2.0 == max(1, 2.0) 143 | assert 2 == max([1, 2]) 144 | assert 2 == max([2, 1]) 145 | assert 2.0 == max([2.0, 1]) 146 | assert 2.0 == max([1, 2.0]) 147 | 148 | # test each function in turn 
149 | @pytest.mark.parametrize("f", [ 150 | const_stmt, 151 | expr_name, 152 | expr_tuple, 153 | expr_list, 154 | expr_subscript, 155 | expr_bool_op, 156 | expr_cmp_op, 157 | assign, 158 | ifs, 159 | loops, 160 | undefined_ok, 161 | min_max_test 162 | ]) 163 | def test_succeed(f): 164 | t = np.zeros((10,10,10), dtype=np.float32) 165 | F = Kernel(f) # parse python 166 | file = F.specialize((t,)) 167 | 168 | # Failing cases 169 | # (These functions are expected to fail elaboration to KLR) 170 | 171 | def name_not_found(): 172 | x 173 | 174 | @pytest.mark.parametrize("f", [ 175 | name_not_found, 176 | ]) 177 | def test_fails(f): 178 | F = Kernel(f) 179 | with pytest.raises(Exception): 180 | F() 181 | 182 | if __name__ == '__main__': 183 | F = Parser(name_not_found) 184 | print(F()) 185 | -------------------------------------------------------------------------------- /KLR/TGR/Dot.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | Released under Apache 2.0 license as described in the file LICENSE. 4 | Authors: Paul Biberstein 5 | -/ 6 | 7 | import KLR.TGR.AST 8 | import SHerLOC.Analysis.Graph 9 | 10 | open StableHLO.Analysis (Graph Edge Vertex) 11 | 12 | /- This module provides a way to convert an TGR function into a DOT graph representation. -/ 13 | namespace KLR.TGR.Graph 14 | 15 | /- 16 | Process the name `var` so that it can used as a node ID in DOT format. 17 | Notably, IDs can't start with a digit, so we prefix it with "node_". 
18 | -/ 19 | def sanitize (var : String) : String := 20 | s!"node_{var}" 21 | /-- Vertex for the function's return statement. -/ 22 | def makeReturnNode (funcName : String) : Vertex := 23 | .mk 24 | s!"return_{funcName}" 25 | (.mk [ 26 | ("label", s!"return\\n{funcName}"), 27 | ("shape", "box"), 28 | ("style", "filled"), 29 | ("fillcolor", "lightgray"), 30 | ("color", "gray") 31 | ]) 32 | def makeOpNode (op : Operator) (output : String) (ty : KLR.TGR.TensorTy): Vertex := -- vertex for the variable `output` computed by `op`; arg/batchMatmul/slice get distinct colors 33 | let attrs := match op with 34 | | .arg .. => [ 35 | ("shape", "diamond"), 36 | ("style", "filled"), 37 | ("fillcolor", "lightgreen"), 38 | ("color", "green") 39 | ] 40 | | .batchMatmul .. => [ 41 | ("style", "filled"), 42 | ("fillcolor", "lightpink"), 43 | ("color", "red") 44 | ] 45 | | .slice .. => [ 46 | ("style", "filled"), 47 | ("fillcolor", "lightblue"), 48 | ("color", "blue") 49 | ] 50 | | _ => [] 51 | .mk 52 | (sanitize output) 53 | (.mk ([ 54 | ("label", s!"{opName op}\\n{output}\\n{ty.shape}"), 55 | ] ++ attrs)) 56 | /-- Vertex for one use of a constant; the ID includes `usedBy` so each use gets its own vertex. -/ 57 | def makeConstNode (op : String) (name : String) (shape : TensorTy) (usedBy : String): Vertex := 58 | .mk 59 | s!"node_const_{name}_{usedBy}" 60 | (.mk [ 61 | ("label", s!"{op}\\n{name}\\n{shape.shape}"), 62 | ("shape", "diamond"), 63 | ("style", "filled"), 64 | ("fillcolor", "lightyellow"), 65 | ("color", "yellow") 66 | ]) 67 | /-- Edge from `source` to `dest` with no attributes. -/ 68 | def makeEdge (source : String) (dest : String) : Edge := 69 | .mk 70 | source 71 | dest 72 | (.mk []) 73 | 74 | /- 75 | Convert a TGR function to a DOT graph, where each variable is a vertex 76 | and an edge exists from A to B if A is used in the computation of B. 77 | 78 | Note: since constants are reused in many parts of the function, they can 79 | cause the graph to have long edges that cross over other nodes. To avoid this, 80 | we create a separate vertex for each use of a constant.
81 | -/ 82 | def graph (f : TGR.Function) : Graph := Id.run do 83 | let mut vertices := [] 84 | let mut edges := [] 85 | /- Every variable in the function that is the result of a `const` or `full` operator, paired with the operator name and the variable's tensor type. Never reassigned, so not `mut`. -/ 86 | let consts := f.statements.filterMap (fun 87 | | .assign v (.const ..) shape => .some ("const", v, shape) 88 | | .assign v (.full ..) shape => .some ("full", v, shape) 89 | | _ => .none) 90 | /- A closure that creates edges from a list of inputs to an output variable. 91 | If the input is a constant, it creates a vertex for that constant. -/ 92 | let (makeEdges : List String → String → (List Vertex) × (List Edge)) := fun inputs output => Id.run do 93 | let mut vertices := [] 94 | let mut edges := [] 95 | for input in inputs do 96 | if let .some (op, v, shape) := consts.find? fun (_, v, _) => v == input then 97 | let node := makeConstNode op v shape output 98 | vertices := node :: vertices 99 | edges := (makeEdge node.id output) :: edges 100 | else 101 | edges := (makeEdge (sanitize input) output) :: edges 102 | return (vertices, edges) 103 | 104 | /- Walk the program statements and create vertices and edges. -/ 105 | for s in f.statements do 106 | match s with 107 | | .assign _ (.const ..) _ | .assign _ (.full ..)
_ => pure () 108 | | .assign v op ty => 109 | let deps := dependencies op 110 | let (newVertices, newEdges) := makeEdges deps (sanitize v) 111 | vertices := [makeOpNode op v ty] ++ newVertices ++ vertices 112 | edges := newEdges ++ edges 113 | | .ret vars => 114 | let node := makeReturnNode f.name 115 | let deps := vars 116 | let (newVertices, newEdges) := makeEdges deps node.id 117 | vertices := [node] ++ newVertices ++ vertices 118 | edges := newEdges ++ edges 119 | | .comment _ => pure () 120 | 121 | pure $ .mk f.name vertices edges 122 | 123 | end KLR.TGR.Graph 124 | -------------------------------------------------------------------------------- /KLR/NKI/SimplifyOperators.lean: -------------------------------------------------------------------------------- 1 | /- 2 | Copyright KLR Contributors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | -/ 16 | 17 | import KLR.Compile.Pass 18 | import KLR.NKI.Basic 19 | import KLR.Python 20 | import KLR.Util 21 | 22 | /- 23 | Simplification pass: convert operators to the ones that don't use mutating assignment 24 | -/ 25 | 26 | namespace KLR.NKI 27 | open Compile.Pass 28 | 29 | abbrev SimplifyOp := Pass Unit 30 | -- True for references to `nki.isa` functions, except `register_alloc`. 31 | private def isISA : Expr -> Bool 32 | | ⟨ .var (.str `nki.isa "register_alloc"), _ ⟩ => false 33 | | ⟨ .var (.str `nki.isa _), _ ⟩ => true 34 | | _ => false 35 | -- Rewrite an ISA call `rhs` so that `dst` becomes its first positional argument; when `accum` is set, also pass `psumAccumulateFlag` (value `1 <<< 7`). Returns `none` when `rhs` is not an ISA call. 36 | private def rewriteOp (rhs: Expr) (dst: Expr) (accum : Bool) : SimplifyOp (Option Stmt') := do 37 | match rhs.expr with 38 | | .call f args kws => 39 | if isISA f then 40 | -- The way we resolve args would treat first positional arg as dst even if 41 | -- keyword arg is present 42 | let args := dst :: args 43 | let kws := if accum then ⟨ "psumAccumulateFlag", ⟨.value (.int (1 <<< 7)), rhs.pos⟩⟩ :: kws else kws 44 | let call : Expr' := .call f args kws 45 | return some (.expr ⟨call, rhs.pos⟩) 46 | else 47 | return none 48 | | _ => return none 49 | -- For `x = ...ndarray(...)` or `x = ...view(...)` without a `name` keyword, add `name=builtin.lang.unique_name("x")`; other statements are returned unchanged. 50 | private def rewriteNdarray (stmt : Stmt') : Stmt' := 51 | -- hacky for now.
Fixme to actually lookup names from environment 52 | -- this most likely belongs somewhere else 53 | match stmt with 54 | | .letM (.var x) ty ⟨.call ⟨.var fname, p0 ⟩ args kws, p1 ⟩ => 55 | if fname.toString.endsWith "ndarray" || fname.toString.endsWith "view" then 56 | if kws.any fun x => x.name == "name" then 57 | stmt 58 | else 59 | let uniqueNameCall := ⟨.call 60 | ⟨.var `builtin.lang.unique_name, p1⟩ 61 | [⟨.value (.string x.toString), p1⟩] 62 | [], 63 | p1⟩ 64 | let kws := ⟨"name", uniqueNameCall⟩ :: kws 65 | .letM (.var x) ty ⟨.call ⟨.var fname, p0 ⟩ args kws, p1 ⟩ 66 | else 67 | stmt 68 | | _ => stmt 69 | -- Statement traversal: rewrites mutating-assignment forms of ISA calls into explicit-destination calls (emitting a deprecation warning), applies `rewriteNdarray` to other `letM`s, and recurses into nested statement lists. 70 | mutual 71 | private def stmt (s : Stmt) : SimplifyOp (Stmt) := do 72 | return ⟨ <- stmt' s.stmt, s.pos ⟩ 73 | termination_by sizeOf s 74 | decreasing_by cases s; simp; omega 75 | 76 | private def stmts (s : List Stmt) : SimplifyOp (List Stmt) := do 77 | return <- s.mapM stmt 78 | termination_by sizeOf s 79 | 80 | private def stmt' (s : Stmt') : SimplifyOp Stmt' := do 81 | let mutAssignWarning := "Mutating assignment (a[...] = foo(...)) form is deprecated.
Use foo(..., dst=a[...]) instead" 82 | match s with 83 | | .letM x ty e => 84 | if let .binOp _ _ ⟨.call f args kws, _⟩ := e.expr then 85 | if isISA f then 86 | if let .var xExpr := x then 87 | match <- rewriteOp ⟨.call f args kws, e.pos⟩ ⟨.var xExpr, e.pos⟩ true with 88 | | some op => 89 | warn mutAssignWarning 90 | return op 91 | | none => return .letM x ty e 92 | if let .call f _ _ := e.expr then 93 | if isISA f then 94 | if let .var xExpr := x then 95 | match <- rewriteOp e ⟨.var xExpr, e.pos⟩ false with 96 | | some op => 97 | warn mutAssignWarning 98 | return op 99 | | none => return .letM x ty e 100 | return rewriteNdarray s 101 | | .setM x e accum => 102 | match <- rewriteOp e x accum with 103 | | some op => 104 | warn mutAssignWarning 105 | return op 106 | | none => return .setM x e accum 107 | -- recur on statements 108 | | .ifStm e thn els => return .ifStm e (<- stmts thn) (<- stmts els) 109 | | .forLoop x iter body => return .forLoop x iter (<- stmts body) 110 | | .whileLoop test body => return .whileLoop test (<- stmts body) 111 | -- statements that only contain expressions don't need to be considered and can simply be passed back 112 | | .expr _ | .assert _ _ | .ret _ | .declare _ _ | .breakLoop | .continueLoop => return s 113 | termination_by sizeOf s 114 | end 115 | -- Apply the rewrite to a function body. 116 | private def func (f : Fun) : SimplifyOp Fun := 117 | return { f with body := <- stmts f.body } 118 | -- Pass entry point: rewrite the body of every function in the kernel. 119 | def simplifyOperators (k : Kernel) : SimplifyOp Kernel := do 120 | return { k with funs := <- k.funs.mapM func } 121 | -------------------------------------------------------------------------------- /interop/test/examples/index.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from apis import * 4 | 5 | def tensor_split_kernel_(in_tensor, out_tensor_even, out_tensor_odd): 6 | """NKI kernel to split an input tensor into two output tensors, along the column axis.
7 | 8 | The even columns of the input tensor will be gathered into the first output tensor, 9 | and the odd columns of the input tensor will be gathered into the second output tensor. 10 | 11 | Args: 12 | in_tensor: an input tensor 13 | out_tensor_even: a first output tensor (will hold the even columns of the input tensor) 14 | out_tensor_odd: a second output tensor (will hold the odd columns of the input tensor) 15 | """ 16 | 17 | # Extract tile sizes. 18 | sz_p, sz_f = in_tensor.shape 19 | sz_fout_even, sz_fout_odd = out_tensor_even.shape[1], out_tensor_odd.shape[1] 20 | 21 | # We assume that all three tensors have the same partition dimension size 22 | # and it does not exceed pmax 23 | assert in_tensor.shape[0] == out_tensor_even.shape[0] == out_tensor_odd.shape[0] 24 | assert in_tensor.shape[0] <= nl.tile_size.pmax 25 | 26 | # Make sure even/odd output tensors have correct free dimension size 27 | assert sz_fout_even == math.ceil(sz_f / 2) 28 | assert sz_fout_odd == math.floor(sz_f / 2) 29 | 30 | # Generate tensor indices for the input/output tensors 31 | i_p = nl.arange(sz_p)[:, None] 32 | i_f = nl.arange(sz_f)[None, :] 33 | i_fout_even = nl.arange(sz_fout_even)[None, :] 34 | i_fout_odd = nl.arange(sz_fout_odd)[None, :] 35 | 36 | # Split pattern: 37 | i_f_even = (2 * i_fout_even) 38 | i_f_odd = (2 * i_fout_odd + 1) 39 | 40 | # Load input data from external memory to on-chip memory 41 | in_tile = nl.load(in_tensor[i_p, i_f]) 42 | 43 | # Perform the split 44 | # these assignments invoke copy instructions under the hood 45 | # which can execute on either Scalar or Vector Engine 46 | # (decided by compiler instruction scheduler) 47 | out_tile_even = in_tile[i_p, i_f_even] 48 | out_tile_odd = in_tile[i_p, i_f_odd] 49 | 50 | # Store the results back to external memory 51 | nl.store(out_tensor_even[i_p, i_fout_even], value=out_tile_even) 52 | nl.store(out_tensor_odd[i_p, i_fout_odd], value=out_tile_odd) 53 | 54 | 55 | def tensor_maxpool_kernel_(in_tensor,
out_tensor, pool_size): 56 | """NKI kernel to compute a 2D max-pool operation 57 | 58 | Args: 59 | in_tensor: an input tensor, of dimensions C x H x W 60 | out_tensor: the resulting output tensor, of dimensions C x (H/P) x (W/P) 61 | pool_size: integer P representing a (square) pool-window size 62 | """ 63 | 64 | # Get input/output dimensions 65 | sz_cin, sz_hin, sz_win = in_tensor.shape 66 | sz_cout, sz_hout, sz_wout = out_tensor.shape 67 | assert sz_cin == sz_cout 68 | 69 | # Set relevant sizes 70 | sz_p = sz_cin 71 | sz_pool = pool_size 72 | 73 | # Generate tensor h/w index patterns 74 | # 3D indexing according to [C, H, W] 75 | i_p = nl.arange(sz_p)[:, None, None] # partition (C) index, broadcast over H and W 76 | i_win = nl.arange(sz_win)[None, None, :] 77 | i_hin = nl.arange(sz_hin)[None, :, None] 78 | 79 | i_wout = nl.arange(sz_wout)[None, None, :] 80 | i_hout = nl.arange(sz_hout)[None, :, None] 81 | 82 | # Generate pool index patterns (requires two extra dimensions, for the pool window) 83 | i_0 = nl.arange(sz_p)[:, None, None, None, None] # partition (C) 84 | i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer 85 | i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner 86 | i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer 87 | i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner 88 | 89 | # Load input data from external memory to on-chip memory 90 | # Declare ndarray to force a 3D tensor (temporary requirement) 91 | in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype) 92 | in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win]) 93 | 94 | # Perform the pooling operation: 95 | # We use numpy's advanced indexing, in order to extend in_tile to 5D, and then reduce-max two dimensions. 96 | # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation. 97 | # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides 98 | # (i.e.
inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2]. 99 | # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4]. 100 | out_tile = nl.max(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) 101 | 102 | # Store the results back to external memory 103 | nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile) 104 |