├── src └── kirin │ ├── py.typed │ ├── decl │ ├── emit │ │ ├── __init__.py │ │ ├── dialect.py │ │ ├── name.py │ │ ├── traits.py │ │ ├── _set_new_attribute.py │ │ ├── repr.py │ │ └── _create_fn.py │ └── camel2snake.py │ ├── lowering │ ├── python │ │ └── __init__.py │ ├── exception.py │ ├── __init__.py │ └── stream.py │ ├── dialects │ ├── vmath │ │ ├── rewrites │ │ │ └── __init__.py │ │ └── _dialect.py │ ├── cf │ │ ├── dialect.py │ │ ├── __init__.py │ │ ├── interp.py │ │ └── abstract.py │ ├── math │ │ └── dialect.py │ ├── scf │ │ ├── _dialect.py │ │ ├── __init__.py │ │ └── interp.py │ ├── func │ │ ├── _dialect.py │ │ ├── rewrite │ │ │ └── __init__.py │ │ └── __init__.py │ ├── ilist │ │ ├── _dialect.py │ │ ├── rewrite │ │ │ ├── __init__.py │ │ │ ├── hint_len.py │ │ │ └── inline_getitem.py │ │ ├── _julia.py │ │ ├── __init__.py │ │ └── passes.py │ ├── py │ │ ├── cmp │ │ │ ├── _dialect.py │ │ │ ├── _julia.py │ │ │ ├── __init__.py │ │ │ ├── stmts.py │ │ │ ├── lowering.py │ │ │ └── interp.py │ │ ├── binop │ │ │ ├── _dialect.py │ │ │ ├── _julia.py │ │ │ ├── lowering.py │ │ │ ├── __init__.py │ │ │ └── stmts.py │ │ ├── list │ │ │ ├── _dialect.py │ │ │ ├── __init__.py │ │ │ ├── lowering.py │ │ │ ├── stmts.py │ │ │ └── interp.py │ │ ├── unary │ │ │ ├── _dialect.py │ │ │ ├── lowering.py │ │ │ ├── constprop.py │ │ │ ├── typeinfer.py │ │ │ ├── stmts.py │ │ │ ├── interp.py │ │ │ └── __init__.py │ │ ├── range.py │ │ ├── __init__.py │ │ ├── base.py │ │ ├── len.py │ │ └── attr.py │ ├── random │ │ ├── _dialect.py │ │ ├── __init__.py │ │ ├── interp.py │ │ └── stmts.py │ ├── lowering │ │ ├── __init__.py │ │ └── range.py │ ├── __init__.py │ └── eltype.py │ ├── analysis │ ├── const │ │ ├── _visitor.py │ │ ├── _visitor.pyi │ │ └── __init__.py │ ├── typeinfer │ │ └── __init__.py │ └── __init__.py │ ├── ir │ ├── traits │ │ ├── region │ │ │ ├── __init__.py │ │ │ └── ssacfg.py │ │ └── __init__.py │ ├── attrs │ │ ├── _types.py │ │ ├── __init__.py │ │ ├── data.py │ │ └── _types.pyi │ ├── nodes │ 
│ └── __init__.py │ └── use.py │ ├── rewrite │ ├── aggressive │ │ ├── __init__.py │ │ └── fold.py │ ├── alias.py │ ├── dce.py │ ├── getfield.py │ ├── type_assert.py │ ├── __init__.py │ ├── fixpoint.py │ ├── chain.py │ ├── getitem.py │ └── call2invoke.py │ ├── serialization │ ├── __init__.py │ ├── base │ │ └── __init__.py │ └── core │ │ ├── __init__.py │ │ ├── serializationunit.py │ │ ├── supportedtypes.py │ │ ├── serializable.py │ │ ├── serializationmodule.py │ │ └── deserializable.py │ ├── passes │ ├── aggressive │ │ ├── __init__.py │ │ ├── unroll.py │ │ └── fold.py │ ├── __init__.py │ ├── canonicalize.py │ ├── post_inference.py │ ├── hint_const.py │ ├── inline.py │ ├── abc.py │ ├── fold.py │ └── typeinfer.py │ ├── emit │ └── __init__.py │ ├── validation │ └── __init__.py │ ├── testing │ └── __init__.py │ ├── __init__.py │ ├── lattice │ ├── __init__.py │ └── empty.py │ ├── interp │ ├── exceptions.py │ ├── concrete.py │ └── undefined.py │ ├── print │ └── __init__.py │ ├── symbol_table.py │ ├── types.py │ ├── graph.py │ ├── worklist.py │ └── source.py ├── test ├── passes │ ├── __init__.py │ └── test_unroll_scf.py ├── print │ └── __init__.py ├── stmt │ ├── __init__.py │ └── test_statement.py ├── verify │ ├── __init__.py │ ├── test_typecheck.py │ └── test_method_verify.py ├── dialects │ ├── __init__.py │ ├── func │ │ ├── __init__.py │ │ ├── test_closurefield.py │ │ └── test_lambdalifting.py │ ├── math │ │ ├── __init__.py │ │ └── test_const.py │ ├── pyrules │ │ ├── __init__.py │ │ └── test_getitem.py │ ├── pystmts │ │ ├── __init__.py │ │ ├── test_getattr.py │ │ ├── test_coll_add.py │ │ ├── test_range.py │ │ └── test_slice.py │ ├── scf │ │ ├── __init__.py │ │ ├── test_fold.py │ │ ├── test_typeinfer.py │ │ └── test_unroll.py │ ├── vmath │ │ └── __init__.py │ ├── kirin_random │ │ ├── __init__.py │ │ └── test_random.py │ ├── py_dialect │ │ ├── __init__.py │ │ ├── test_iter.py │ │ ├── test_tuple_infer.py │ │ └── test_assign.py │ ├── test_numpy.py │ ├── test_func.py │ ├── 
test_debug.py │ ├── test_ilist2list.py │ ├── test_dummy.py │ ├── test_infer_len.py │ └── test_module.py ├── lowering │ ├── __init__.py │ ├── test_353.py │ ├── test_method_hint.py │ ├── test_binding.py │ ├── test_list.py │ ├── test_hint_union_binop.py │ ├── test_337.py │ ├── test_source_info.py │ ├── test_with.py │ └── test_with_binding.py ├── program │ └── py │ │ ├── aha │ │ ├── __init__.py │ │ ├── hoho.py │ │ └── gaga.py │ │ ├── test_tuple_hint.py │ │ ├── test_doc.py │ │ ├── test_index.py │ │ ├── test_87.py │ │ ├── test_const.py │ │ ├── test_glob_pi.py │ │ ├── test_closure.py │ │ ├── test_global_hint.py │ │ ├── test_abs.py │ │ ├── test_noreturn.py │ │ ├── test_signature.py │ │ ├── test_class.py │ │ ├── test_list_append.py │ │ ├── test_loop.py │ │ ├── test_cmp.py │ │ └── test_aggressive.py ├── analysis │ ├── dataflow │ │ ├── constprop │ │ │ ├── __init__.py │ │ │ └── test_worklist.py │ │ ├── typeinfer │ │ │ ├── test_selfref_closure.py │ │ │ ├── test_infer_lambda.py │ │ │ ├── test_unstable.py │ │ │ └── test_inter_method.py │ │ ├── test_cfg.py │ │ └── test_non_pure_const.py │ └── test_callgraph.py ├── ir │ ├── test_region.py │ ├── test_isequal.py │ ├── test_dialect.py │ ├── test_verify.py │ └── test_stmt.py ├── rules │ ├── test_dce.py │ ├── test_fold_br.py │ ├── test_fold.py │ ├── test_apply_type.py │ ├── test_alias.py │ └── test_cse.py ├── emit │ ├── test_julia.py │ └── julia_like.jl ├── test_worklist.py ├── interp │ └── test_select.py ├── testing │ └── test_assert_statements_same.py └── rewrite │ └── test_cse_rewrite.py ├── docs ├── blog │ ├── index.md │ └── posts │ │ ├── qft-code.png │ │ ├── puzzle-pieces.png │ │ └── typeinfer-basic.png ├── assets │ ├── favicon.ico │ ├── food-printing.png │ ├── logo-small-black.png │ └── logo-small-white.png ├── codegen.md ├── comparison.md ├── scripts │ ├── katex.js │ └── gen_ref_nav.py ├── stylesheets │ └── extra.css ├── dialects │ └── python │ │ ├── index.md │ │ ├── sugar.md │ │ └── data.md └── cookbook │ └── index.md ├── 
example ├── food │ ├── dialect.py │ ├── attrs.py │ ├── group.py │ ├── rewrite.py │ ├── stmts.py │ └── interp.py ├── pauli │ ├── dialect.py │ ├── script.py │ ├── group.py │ ├── stmts.py │ └── interp.py ├── simple.py └── README.md ├── .github ├── dependabot.yml └── workflows │ ├── isort.yml │ ├── lint.yml │ ├── doc.yml │ ├── devdoc.yml │ └── pub_doc.yml ├── justfile └── .pre-commit-config.yaml /src/kirin/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/passes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/print/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/stmt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/verify/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/lowering/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/blog/index.md: -------------------------------------------------------------------------------- 1 | # Blog 2 | -------------------------------------------------------------------------------- /src/kirin/decl/emit/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/func/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/math/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/pyrules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/pystmts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/scf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/vmath/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/program/py/aha/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/kirin/lowering/python/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/dialects/kirin_random/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/test/dialects/py_dialect/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/kirin/dialects/vmath/rewrites/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/analysis/dataflow/constprop/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/kirin/analysis/const/_visitor.py: -------------------------------------------------------------------------------- 1 | class _ElemVisitor: 2 | pass 3 | -------------------------------------------------------------------------------- /src/kirin/ir/traits/region/__init__.py: -------------------------------------------------------------------------------- 1 | """Builtin region traits.""" 2 | -------------------------------------------------------------------------------- /src/kirin/rewrite/aggressive/__init__.py: -------------------------------------------------------------------------------- 1 | from .fold import Fold as Fold 2 | -------------------------------------------------------------------------------- /example/food/dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("food") 4 | -------------------------------------------------------------------------------- /example/pauli/dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | _dialect = ir.Dialect("pauli") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/cf/dialect.py: -------------------------------------------------------------------------------- 1 | from kirin 
import ir 2 | 3 | dialect = ir.Dialect("cf") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/math/dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("math") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/scf/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("scf") 4 | -------------------------------------------------------------------------------- /docs/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/assets/favicon.ico -------------------------------------------------------------------------------- /src/kirin/dialects/func/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect(name="func") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("py.ilist") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/cmp/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("py.cmp") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/random/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("random") 4 | 
-------------------------------------------------------------------------------- /src/kirin/dialects/vmath/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("vmath") 4 | -------------------------------------------------------------------------------- /src/kirin/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from .jsonserializer import JSONSerializer as JSONSerializer 2 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/binop/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("py.binop") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/list/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("py.list") 4 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/_dialect.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | 3 | dialect = ir.Dialect("py.unary") 4 | -------------------------------------------------------------------------------- /docs/blog/posts/qft-code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/blog/posts/qft-code.png -------------------------------------------------------------------------------- /docs/assets/food-printing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/assets/food-printing.png 
-------------------------------------------------------------------------------- /docs/assets/logo-small-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/assets/logo-small-black.png -------------------------------------------------------------------------------- /docs/assets/logo-small-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/assets/logo-small-white.png -------------------------------------------------------------------------------- /docs/blog/posts/puzzle-pieces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/blog/posts/puzzle-pieces.png -------------------------------------------------------------------------------- /src/kirin/passes/aggressive/__init__.py: -------------------------------------------------------------------------------- 1 | from .fold import Fold as Fold 2 | from .unroll import UnrollScf as UnrollScf 3 | -------------------------------------------------------------------------------- /docs/blog/posts/typeinfer-basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuEraComputing/kirin/HEAD/docs/blog/posts/typeinfer-basic.png -------------------------------------------------------------------------------- /src/kirin/decl/camel2snake.py: -------------------------------------------------------------------------------- 1 | def camel2snake(name: str) -> str: 2 | return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") 3 | -------------------------------------------------------------------------------- /src/kirin/emit/__init__.py: -------------------------------------------------------------------------------- 1 | from .abc import 
EmitABC as EmitABC, EmitFrame as EmitFrame 2 | from .julia import Julia as Julia, JuliaFrame as JuliaFrame 3 | -------------------------------------------------------------------------------- /src/kirin/validation/__init__.py: -------------------------------------------------------------------------------- 1 | from .validationpass import ( 2 | ValidationPass as ValidationPass, 3 | ValidationSuite as ValidationSuite, 4 | ) 5 | -------------------------------------------------------------------------------- /src/kirin/dialects/func/rewrite/__init__.py: -------------------------------------------------------------------------------- 1 | from .closurefield import ClosureField as ClosureField 2 | from .lambdalifting import LambdaLifting as LambdaLifting 3 | -------------------------------------------------------------------------------- /src/kirin/ir/attrs/_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .abc import Attribute 4 | 5 | 6 | @dataclass 7 | class _TypeAttribute(Attribute): 8 | pass 9 | -------------------------------------------------------------------------------- /src/kirin/analysis/typeinfer/__init__.py: -------------------------------------------------------------------------------- 1 | """Type inference analysis for kirin.""" 2 | 3 | from .solve import TypeResolution as TypeResolution 4 | from .analysis import TypeInference as TypeInference 5 | -------------------------------------------------------------------------------- /src/kirin/lowering/exception.py: -------------------------------------------------------------------------------- 1 | from kirin.exception import StaticCheckError 2 | 3 | 4 | class BuildError(StaticCheckError): 5 | """Base class for all dialect lowering errors.""" 6 | 7 | pass 8 | -------------------------------------------------------------------------------- /.github/dependabot.yml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | -------------------------------------------------------------------------------- /example/pauli/script.py: -------------------------------------------------------------------------------- 1 | from .group import pauli 2 | from .stmts import X, Y, Z 3 | 4 | 5 | @pauli 6 | def main(): 7 | ex = (X() + 2 * Y()) * Z() 8 | return ex 9 | 10 | 11 | main.print() 12 | -------------------------------------------------------------------------------- /src/kirin/serialization/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .context import SerializationContext as SerializationContext 2 | from .serializer import Serializer as Serializer 3 | from .deserializer import Deserializer as Deserializer 4 | -------------------------------------------------------------------------------- /src/kirin/testing/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful tools for testing.""" 2 | 3 | from .statements import ( 4 | assert_statements_same as assert_statements_same, 5 | assert_structurally_same as assert_structurally_same, 6 | ) 7 | -------------------------------------------------------------------------------- /src/kirin/decl/emit/dialect.py: -------------------------------------------------------------------------------- 1 | from kirin.decl.base import BaseModifier 2 | 3 | 4 | class EmitDialect(BaseModifier): 5 | 6 | def emit_dialect(self): 7 | setattr(self.cls, "dialect", self.dialect) 8 | return 9 | -------------------------------------------------------------------------------- /test/program/py/aha/hoho.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import 
basic_no_opt 2 | 3 | 4 | @basic_no_opt 5 | def foo(x: int) -> int: 6 | return x + 1 7 | 8 | 9 | @basic_no_opt 10 | def goo(x: int) -> int: 11 | return x + 1 12 | -------------------------------------------------------------------------------- /example/food/attrs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | # runtime value of a food 5 | @dataclass 6 | class Food: 7 | type: str 8 | 9 | 10 | @dataclass 11 | class Serving: 12 | kind: Food 13 | amount: int 14 | -------------------------------------------------------------------------------- /docs/codegen.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | This page is under construction. The content may be incomplete or incorrect. Submit an issue 3 | on [GitHub](https://github.com/QuEraComputing/kirin/issues/new) if you need help or want to 4 | contribute. 5 | -------------------------------------------------------------------------------- /test/dialects/math/test_const.py: -------------------------------------------------------------------------------- 1 | import math as pymath 2 | 3 | from kirin.dialects import math 4 | 5 | 6 | def test_const(): 7 | assert math.pi == pymath.pi 8 | assert math.e == pymath.e 9 | assert math.tau == pymath.tau 10 | -------------------------------------------------------------------------------- /docs/comparison.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | This page is under construction. The content may be incomplete or incorrect. Submit an issue 3 | on [GitHub](https://github.com/QuEraComputing/kirin/issues/new) if you need help or want to 4 | contribute. 
5 | -------------------------------------------------------------------------------- /test/program/py/aha/gaga.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | 3 | 4 | @basic_no_opt 5 | def foo(x: int) -> int: 6 | assert x == 1 7 | return x + 1 8 | 9 | 10 | @basic_no_opt 11 | def goo(x: int) -> int: 12 | return x + 1 13 | -------------------------------------------------------------------------------- /src/kirin/passes/__init__.py: -------------------------------------------------------------------------------- 1 | from kirin.passes.abc import Pass as Pass 2 | from kirin.passes.fold import Fold as Fold 3 | from kirin.passes.typeinfer import TypeInfer as TypeInfer 4 | 5 | from .default import Default as Default 6 | from .hint_const import HintConst as HintConst 7 | -------------------------------------------------------------------------------- /src/kirin/__init__.py: -------------------------------------------------------------------------------- 1 | # re-exports the public API of the kirin package 2 | from . 
import ir as ir, types as types, lowering as lowering 3 | from .exception import enable_stracetrace, disable_stracetrace 4 | 5 | __all__ = ["ir", "types", "lowering", "enable_stracetrace", "disable_stracetrace"] 6 | -------------------------------------------------------------------------------- /src/kirin/serialization/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .serializable import Serializable as Serializable 2 | from .deserializable import Deserializable as Deserializable 3 | from .serializationunit import SerializationUnit as SerializationUnit 4 | from .serializationmodule import SerializationModule as SerializationModule 5 | -------------------------------------------------------------------------------- /example/pauli/group.py: -------------------------------------------------------------------------------- 1 | from dialect import _dialect 2 | 3 | from kirin import ir 4 | from kirin.prelude import basic_no_opt 5 | 6 | 7 | @ir.dialect_group(basic_no_opt.add(dialect=_dialect)) 8 | def pauli(self): 9 | def run_pass(mt): 10 | # TODO 11 | pass 12 | 13 | return run_pass 14 | -------------------------------------------------------------------------------- /src/kirin/ir/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | """Definition of Kirin's Intermediate Representation (IR) nodes.""" 2 | 3 | from kirin.ir.nodes.base import IRNode as IRNode 4 | from kirin.ir.nodes.stmt import Statement as Statement 5 | from kirin.ir.nodes.block import Block as Block 6 | from kirin.ir.nodes.region import Region as Region 7 | -------------------------------------------------------------------------------- /test/program/py/test_tuple_hint.py: -------------------------------------------------------------------------------- 1 | from kirin import types 2 | from kirin.prelude import basic 3 | 4 | 5 | @basic 6 | def tuple_hint(xs: tuple[int, ...]): 7 | 
types.Tuple[types.Int] 8 | 9 | 10 | def test_tuple_hint(): 11 | assert tuple_hint.arg_types[0].is_subseteq(types.Tuple[types.Vararg(types.Int)]) 12 | -------------------------------------------------------------------------------- /src/kirin/decl/emit/name.py: -------------------------------------------------------------------------------- 1 | from kirin.decl.base import BaseModifier 2 | from kirin.decl.camel2snake import camel2snake 3 | 4 | from ._set_new_attribute import set_new_attribute 5 | 6 | 7 | class EmitName(BaseModifier): 8 | 9 | def emit_name(self): 10 | set_new_attribute(self.cls, "name", camel2snake(self.cls.__name__)) 11 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/rewrite/__init__.py: -------------------------------------------------------------------------------- 1 | from .list import List2IList as List2IList 2 | from .const import ConstList2IList as ConstList2IList 3 | from .unroll import Unroll as Unroll 4 | from .hint_len import HintLen as HintLen 5 | from .flatten_add import FlattenAdd as FlattenAdd 6 | from .inline_getitem import InlineGetItem as InlineGetItem 7 | -------------------------------------------------------------------------------- /test/program/py/test_doc.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from kirin.prelude import basic_no_opt 4 | 5 | 6 | def test_docstring(): 7 | @basic_no_opt 8 | def some_func(x): 9 | "Some kernel function" 10 | return x + 1 11 | 12 | assert inspect.getdoc(some_func) == "Some kernel function" 13 | assert some_func(1) == 2 14 | -------------------------------------------------------------------------------- /src/kirin/dialects/lowering/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains the dialects for choosing different lowering strategies. 
2 | 3 | The dialects defined inside this module do not provide any new statements, it only 4 | provide different lowering strategies for existing statements. 5 | """ 6 | 7 | from . import cf as cf, call as call, func as func, range as range 8 | -------------------------------------------------------------------------------- /test/program/py/test_index.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | 3 | 4 | @basic_no_opt 5 | def setindex(a): 6 | a[1] = 2 7 | return a 8 | 9 | 10 | @basic_no_opt 11 | def index(a): 12 | return a[1] 13 | 14 | 15 | # TODO: actually test the lowered code 16 | def test_index(): 17 | index.code.print() 18 | setindex.code.print() 19 | -------------------------------------------------------------------------------- /test/program/py/test_87.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic 2 | 3 | 4 | @basic 5 | def foo(x: float): 6 | return x + 0.22 7 | 8 | 9 | @basic 10 | def issue_87(x: float): 11 | 12 | def inner(y: float, z: float): 13 | return foo(x) + y + z 14 | 15 | return inner 16 | 17 | 18 | def test_issue_87(): 19 | assert issue_87(1.0)(1, 2) == 4.22 20 | -------------------------------------------------------------------------------- /test/program/py/test_const.py: -------------------------------------------------------------------------------- 1 | from kirin import types 2 | from kirin.prelude import basic 3 | 4 | x = [1, 2, 3] 5 | 6 | 7 | @basic(typeinfer=True) 8 | def main(): 9 | return x[1] 10 | 11 | 12 | main.print(hint="const") 13 | 14 | 15 | def test_const_infer(): 16 | assert main.return_type is not None 17 | assert main.return_type.is_subseteq(types.Int) 18 | -------------------------------------------------------------------------------- /test/ir/test_region.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude 
import basic_no_opt 2 | 3 | 4 | @basic_no_opt 5 | def factorial(n): 6 | if n == 0: 7 | return 1 8 | else: 9 | return n * factorial(n - 1) 10 | 11 | 12 | def test_region_clone(): 13 | assert factorial.callable_region.clone().is_structurally_equal( 14 | factorial.callable_region 15 | ) 16 | -------------------------------------------------------------------------------- /test/dialects/test_numpy.py: -------------------------------------------------------------------------------- 1 | # This file is generated by gen.py 2 | import numpy as np 3 | 4 | from kirin.prelude import basic_no_opt 5 | 6 | 7 | @basic_no_opt 8 | def numpy_passing(x: np.ndarray): 9 | return x 10 | 11 | 12 | numpy_passing.code.print() 13 | 14 | 15 | def test_passing(): 16 | truth = np.arange(3) 17 | 18 | assert numpy_passing(truth) is truth 19 | -------------------------------------------------------------------------------- /docs/scripts/katex.js: -------------------------------------------------------------------------------- 1 | document$.subscribe(({ body }) => { 2 | renderMathInElement(body, { 3 | delimiters: [ 4 | { left: "$$", right: "$$", display: true }, 5 | { left: "$", right: "$", display: false }, 6 | { left: "\\(", right: "\\)", display: false }, 7 | { left: "\\[", right: "\\]", display: true } 8 | ], 9 | }) 10 | }) 11 | -------------------------------------------------------------------------------- /test/program/py/test_glob_pi.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from kirin.prelude import basic_no_opt 4 | from kirin.dialects import py 5 | 6 | 7 | def test_math_pi(): 8 | @basic_no_opt 9 | def main(): 10 | return math.pi 11 | 12 | stmt = main.callable_region.blocks[0].stmts.at(0) 13 | assert isinstance(stmt, py.Constant) 14 | assert stmt.value.unwrap() == math.pi 15 | -------------------------------------------------------------------------------- /test/program/py/test_closure.py: 
-------------------------------------------------------------------------------- 1 | from kirin.prelude import basic 2 | 3 | 4 | @basic 5 | def foo(x: int): # type: ignore 6 | def goo(y: int): 7 | return x + y 8 | 9 | return goo 10 | 11 | 12 | @basic 13 | def main(y: int): 14 | x = 1 15 | f = foo(x) 16 | return f(y) 17 | 18 | 19 | def test_main(): 20 | assert main(1) == 2 21 | assert main(2) == 3 22 | -------------------------------------------------------------------------------- /test/program/py/test_global_hint.py: -------------------------------------------------------------------------------- 1 | from kirin import types 2 | from kirin.prelude import basic 3 | 4 | 5 | def test_global_hint(): 6 | @basic 7 | def main(xs: types.Float) -> None: # type: ignore 8 | return None 9 | 10 | assert main.code.signature.inputs[0] == types.Float # type: ignore 11 | assert main.code.signature.output == types.NoneType # type: ignore 12 | -------------------------------------------------------------------------------- /src/kirin/serialization/core/serializationunit.py: -------------------------------------------------------------------------------- 1 | class SerializationUnit: 2 | kind: str 3 | module_name: str 4 | class_name: str 5 | data: dict 6 | 7 | def __init__(self, kind: str, module_name: str, class_name: str, data: dict): 8 | self.kind = kind 9 | self.module_name = module_name 10 | self.class_name = class_name 11 | self.data = data 12 | -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | coverage-run: 2 | coverage run -m pytest test 3 | 4 | coverage-xml: 5 | coverage xml 6 | 7 | coverage-html: 8 | coverage html 9 | 10 | coverage-report: 11 | coverage report 12 | 13 | coverage-open: 14 | open htmlcov/index.html 15 | 16 | coverage: coverage-run coverage-xml coverage-report 17 | 18 | doc: 19 | mkdocs serve 20 | 21 | doc-build: 22 | mkdocs build 23 | 
-------------------------------------------------------------------------------- /test/lowering/test_353.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from kirin import types 4 | from kirin.prelude import basic 5 | 6 | N = typing.TypeVar("N") 7 | 8 | 9 | class MyTest(typing.Generic[N]): 10 | pass 11 | 12 | 13 | def test_generic_type_hint(): 14 | @basic 15 | def test(obj: MyTest[N]): 16 | return None 17 | 18 | assert test.arg_types[0].is_subseteq(types.PyClass(MyTest)) 19 | -------------------------------------------------------------------------------- /test/verify/test_typecheck.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin import ir 4 | from kirin.prelude import basic 5 | from kirin.dialects import math 6 | 7 | 8 | @basic(verify=False, typeinfer=False) 9 | def check_type_err(a, b): 10 | math.sin(a) 11 | return math.sin(b) 12 | 13 | 14 | def test_check_type(): 15 | with pytest.raises(ir.TypeCheckError): 16 | check_type_err.code.verify_type() 17 | -------------------------------------------------------------------------------- /test/lowering/test_method_hint.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.prelude import basic 3 | 4 | 5 | def test_method_type_hint(): 6 | @basic 7 | def main() -> ir.Method[[int, int], float]: 8 | 9 | def test(x: int, y: int) -> float: 10 | return x * y * 3.0 11 | 12 | return test 13 | 14 | assert main.return_type == types.MethodType[[types.Int, types.Int], types.Float] 15 | -------------------------------------------------------------------------------- /test/program/py/test_abs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.prelude import basic_no_opt 4 | from kirin.dialects.py import builtin 5 | 6 | 7 | @basic_no_opt 8 | def 
abs_kernel(x): 9 | return builtin.Abs(value=x) 10 | 11 | 12 | @pytest.mark.parametrize("x", [-1, 2, 3.0, -3.2]) 13 | def test_abs(x): 14 | 15 | abs_x = abs_kernel(x) 16 | 17 | assert isinstance(abs_x, type(abs(x))) 18 | assert abs_x == abs(x) 19 | -------------------------------------------------------------------------------- /test/lowering/test_binding.py: -------------------------------------------------------------------------------- 1 | from kirin import lowering 2 | from kirin.prelude import basic_no_opt 3 | from kirin.dialects import math 4 | 5 | 6 | @lowering.wraps(math.stmts.sin) 7 | def sin(value: float) -> float: ... 8 | 9 | 10 | @basic_no_opt 11 | def main(x: float): 12 | return sin(x) 13 | 14 | 15 | def test_binding(): 16 | stmt = main.callable_region.blocks[0].stmts.at(0) 17 | assert isinstance(stmt, math.stmts.sin) 18 | -------------------------------------------------------------------------------- /src/kirin/serialization/core/supportedtypes.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from collections.abc import Sequence 3 | 4 | SUPPORTED_PYTHON_TYPES = Union[ 5 | # Python built-in types 6 | bool, 7 | bytes, 8 | bytearray, 9 | dict, 10 | float, 11 | frozenset, 12 | int, 13 | list, 14 | range, 15 | set, 16 | slice, 17 | str, 18 | tuple, 19 | type, 20 | type(None), 21 | Sequence, 22 | ] 23 | -------------------------------------------------------------------------------- /test/dialects/pystmts/test_getattr.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin.prelude import basic 4 | 5 | 6 | def test_getattr(): 7 | @dataclass 8 | class MyFoo: 9 | x: float 10 | y: float 11 | z: float 12 | 13 | foo = MyFoo(1.0, 2.0, 3.0) 14 | 15 | @basic 16 | def main(): 17 | return foo.x + 1.0 18 | 19 | main.print() 20 | out = main() 21 | 22 | assert out == 2.0 23 | 
-------------------------------------------------------------------------------- /src/kirin/decl/emit/traits.py: -------------------------------------------------------------------------------- 1 | from kirin.decl.base import BaseModifier 2 | 3 | from ._set_new_attribute import set_new_attribute 4 | 5 | 6 | class EmitTraits(BaseModifier): 7 | 8 | def emit_traits(self): 9 | # if no parent defines traits, set it to empty set 10 | for base in self.cls.__mro__[-1:0:-1]: 11 | if hasattr(base, "traits"): 12 | return 13 | set_new_attribute(self.cls, "traits", frozenset({})) 14 | -------------------------------------------------------------------------------- /src/kirin/ir/use.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | from dataclasses import dataclass 5 | 6 | if TYPE_CHECKING: 7 | from kirin.ir.nodes.stmt import Statement 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Use: 12 | """A use of an SSA value in a statement.""" 13 | 14 | stmt: Statement 15 | """The statement that uses the SSA value.""" 16 | index: int 17 | """The index of the use in the statement.""" 18 | -------------------------------------------------------------------------------- /src/kirin/analysis/const/_visitor.pyi: -------------------------------------------------------------------------------- 1 | from .lattice import Value, Bottom, Unknown, PartialTuple, PartialLambda 2 | 3 | class _ElemVisitor: 4 | def is_subseteq_Value(self, other: Value) -> bool: ... 5 | def is_subseteq_NotConst(self, other: Unknown) -> bool: ... 6 | def is_subseteq_Unknown(self, other: Bottom) -> bool: ... 7 | def is_subseteq_PartialTuple(self, other: PartialTuple) -> bool: ... 8 | def is_subseteq_PartialLambda(self, other: PartialLambda) -> bool: ... 
9 | -------------------------------------------------------------------------------- /test/program/py/test_noreturn.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.prelude import basic_no_opt 4 | 5 | 6 | @basic_no_opt 7 | def no_return(x): 8 | return 9 | 10 | 11 | def test_noreturn(): 12 | assert no_return(1) is None 13 | 14 | with pytest.raises(ValueError): 15 | no_return() 16 | 17 | 18 | def test_noreturn_with_body(): 19 | @basic_no_opt 20 | def no_return_with_body(x): 21 | x + 1 22 | 23 | assert no_return_with_body(1) is None 24 | -------------------------------------------------------------------------------- /src/kirin/lattice/__init__.py: -------------------------------------------------------------------------------- 1 | from kirin.lattice.abc import ( 2 | Lattice as Lattice, 3 | UnionMeta as UnionMeta, 4 | LatticeMeta as LatticeMeta, 5 | SingletonMeta as SingletonMeta, 6 | BoundedLattice as BoundedLattice, 7 | ) 8 | from kirin.lattice.empty import EmptyLattice as EmptyLattice 9 | from kirin.lattice.mixin import ( 10 | IsSubsetEqMixin as IsSubsetEqMixin, 11 | SimpleJoinMixin as SimpleJoinMixin, 12 | SimpleMeetMixin as SimpleMeetMixin, 13 | ) 14 | -------------------------------------------------------------------------------- /src/kirin/dialects/__init__.py: -------------------------------------------------------------------------------- 1 | """Built-in dialects for Kirin. 2 | 3 | This module contains the built-in dialects for Kirin. Each dialect is an 4 | instance of the `Dialect` class. Each submodule contains a `dialect` variable 5 | that is an instance of the corresponding `Dialect` class. 6 | 7 | The modules can be directly used as dialects. For example, you can write 8 | 9 | ```python 10 | from kirin.dialects import py, func 11 | ``` 12 | 13 | to import the Python and function dialects. 
14 | """ 15 | -------------------------------------------------------------------------------- /.github/workflows/isort.yml: -------------------------------------------------------------------------------- 1 | name: Run isort 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v6 16 | - uses: isort/isort-action@v1 17 | with: 18 | sortPaths: "src" # only sort files in the src directory 19 | -------------------------------------------------------------------------------- /test/dialects/pystmts/test_coll_add.py: -------------------------------------------------------------------------------- 1 | from kirin import types 2 | from kirin.prelude import basic 3 | from kirin.dialects.ilist import IList, IListType 4 | 5 | 6 | @basic(typeinfer=True) 7 | def tuple_new(x: int, xs: tuple): 8 | return xs + (1, x) 9 | 10 | 11 | @basic(typeinfer=True) 12 | def list_new(x: int, xs: IList): 13 | return xs + [1, x] 14 | 15 | 16 | def test_tuple_add(): 17 | assert tuple_new.return_type.is_subseteq(types.Tuple) 18 | assert list_new.return_type.is_subseteq(IListType) 19 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/cmp/_julia.py: -------------------------------------------------------------------------------- 1 | from kirin import emit, interp 2 | 3 | from .stmts import Eq 4 | from ._dialect import dialect 5 | 6 | 7 | @dialect.register(key="emit.julia") 8 | class JuliaEmit(interp.MethodTable): 9 | @interp.impl(Eq) 10 | def add(self, emit_: emit.Julia, frame: emit.JuliaFrame, node: Eq): 11 | lhs = frame.get(node.lhs) 12 | rhs = frame.get(node.rhs) 13 | frame.write_line(f"{frame.ssa[node.result]} = ({lhs} == {rhs})") 14 | return (frame.ssa[node.result],) 15 | 
-------------------------------------------------------------------------------- /src/kirin/serialization/core/serializable.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import TYPE_CHECKING, Protocol, runtime_checkable 3 | 4 | if TYPE_CHECKING: 5 | from kirin.serialization.base.serializer import Serializer 6 | from kirin.serialization.core.serializationunit import SerializationUnit 7 | 8 | 9 | @runtime_checkable 10 | class Serializable(Protocol): 11 | @abstractmethod 12 | def serialize(self, serializer: "Serializer") -> "SerializationUnit": 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/binop/_julia.py: -------------------------------------------------------------------------------- 1 | from kirin import emit, interp 2 | 3 | from .stmts import Add 4 | from ._dialect import dialect 5 | 6 | 7 | @dialect.register(key="emit.julia") 8 | class JuliaEmit(interp.MethodTable): 9 | @interp.impl(Add) 10 | def add(self, emit_: emit.Julia, frame: emit.JuliaFrame, node: Add): 11 | lhs = frame.get(node.lhs) 12 | rhs = frame.get(node.rhs) 13 | frame.write_line(f"{frame.ssa[node.result]} = ({lhs} + {rhs})") 14 | return (frame.ssa[node.result],) 15 | -------------------------------------------------------------------------------- /src/kirin/rewrite/alias.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin import ir 4 | from kirin.rewrite.abc import RewriteRule, RewriteResult 5 | from kirin.dialects.py.assign import Alias 6 | 7 | 8 | @dataclass 9 | class InlineAlias(RewriteRule): 10 | 11 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 12 | if not isinstance(node, Alias): 13 | return RewriteResult() 14 | 15 | node.result.replace_by(node.value) 16 | return RewriteResult(has_done_something=True) 17 
| -------------------------------------------------------------------------------- /src/kirin/passes/canonicalize.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin.ir.method import Method 4 | from kirin.rewrite.abc import RewriteResult 5 | 6 | from .abc import Pass 7 | 8 | 9 | @dataclass 10 | class Canonicalize(Pass): 11 | 12 | def unsafe_run(self, mt: Method) -> RewriteResult: 13 | result = RewriteResult() 14 | for dialect in self.dialects: 15 | for rule in dialect.rules.canonicalize: 16 | result = rule.rewrite(mt.code).join(result) 17 | return result 18 | -------------------------------------------------------------------------------- /src/kirin/passes/post_inference.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin.ir.method import Method 4 | from kirin.rewrite.abc import RewriteResult 5 | 6 | from .abc import Pass 7 | 8 | 9 | @dataclass 10 | class PostInference(Pass): 11 | 12 | def unsafe_run(self, mt: Method) -> RewriteResult: 13 | result = RewriteResult() 14 | for dialect in self.dialects: 15 | for rule in dialect.rules.inference: 16 | result = rule.rewrite(mt.code).join(result) 17 | return result 18 | -------------------------------------------------------------------------------- /test/dialects/py_dialect/test_iter.py: -------------------------------------------------------------------------------- 1 | import kirin.prelude 2 | 3 | 4 | def iter_non_pure(): 5 | @kirin.prelude.basic 6 | def loop(a: str): 7 | out = [] 8 | for i in range(4): 9 | # for i in [1,1,2,3]: # Same result with this line instead 10 | out = out + [a] 11 | return out 12 | 13 | x = loop("a") 14 | assert x == ["a", "a", "a", "a"] 15 | 16 | x = loop("b") 17 | assert x == ["b", "b", "b", "b"] 18 | 19 | x = loop("c") 20 | assert x == ["c", "c", "c", "c"] 21 | 
-------------------------------------------------------------------------------- /test/dialects/scf/test_fold.py: -------------------------------------------------------------------------------- 1 | from kirin.passes import Fold 2 | from kirin.prelude import basic, structural_no_opt 3 | 4 | 5 | def test_simple_loop(): 6 | @structural_no_opt 7 | def simple_loop(): 8 | x = 0 9 | for i in range(2): 10 | x = x + 1 11 | return x 12 | 13 | @basic(fold=True) 14 | def target(): 15 | return 2 16 | 17 | fold = Fold(structural_no_opt) 18 | fold(simple_loop) 19 | assert target.callable_region.is_structurally_equal(simple_loop.callable_region) 20 | -------------------------------------------------------------------------------- /src/kirin/interp/exceptions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | # errors 5 | class InterpreterError(Exception): 6 | """Generic interpreter error. 7 | 8 | This is the base class for all interpreter errors. 
9 | """ 10 | 11 | pass 12 | 13 | 14 | class StackOverflowError(InterpreterError): 15 | """An error raised when the interpreter stack overflows.""" 16 | 17 | pass 18 | 19 | 20 | class FuelExhaustedError(InterpreterError): 21 | """An error raised when the interpreter runs out of fuel.""" 22 | 23 | pass 24 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --md-primary-fg-color: #6437FF; 3 | --md-accent-fg-color: #6437FF; 4 | } 5 | 6 | #logo_light_mode { 7 | display: var(--md-footer-logo-light-mode); 8 | } 9 | 10 | #logo_dark_mode { 11 | display: var(--md-footer-logo-dark-mode); 12 | } 13 | 14 | [data-md-color-scheme="slate"] { 15 | --md-footer-logo-dark-mode: block; 16 | --md-footer-logo-light-mode: none; 17 | } 18 | 19 | [data-md-color-scheme="default"] { 20 | --md-footer-logo-dark-mode: none; 21 | --md-footer-logo-light-mode: block; 22 | } 23 | -------------------------------------------------------------------------------- /docs/dialects/python/index.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | This page is under construction. The content may be incomplete or incorrect. Submit an issue 3 | on [GitHub](https://github.com/QuEraComputing/kirin/issues/new) if you need help or want to 4 | contribute. 5 | 6 | # Python Dialects 7 | 8 | Kirin provides a set of dialects that represents fractions of Python semantics. We will describe 9 | each dialect in this page. The general design principle of these dialects is to provide a composable 10 | set of Python semantics that can be used to build different embedded DSLs inside Python. 
11 | -------------------------------------------------------------------------------- /src/kirin/serialization/core/serializationmodule.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | if TYPE_CHECKING: 4 | from kirin.serialization.base.context import MethodSymbolMeta 5 | from kirin.serialization.core.serializationunit import SerializationUnit 6 | 7 | 8 | class SerializationModule: 9 | symbol_table: dict[str, "MethodSymbolMeta"] 10 | body: "SerializationUnit" 11 | 12 | def __init__( 13 | self, symbol_table: dict[str, "MethodSymbolMeta"], body: "SerializationUnit" 14 | ): 15 | self.symbol_table = symbol_table 16 | self.body = body 17 | -------------------------------------------------------------------------------- /src/kirin/rewrite/dce.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin import ir 4 | from kirin.rewrite.abc import RewriteRule, RewriteResult 5 | 6 | 7 | @dataclass 8 | class DeadCodeElimination(RewriteRule): 9 | 10 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 11 | if self.is_pure(node): 12 | for result in node._results: 13 | if result.uses: 14 | return RewriteResult() 15 | 16 | node.delete() 17 | return RewriteResult(has_done_something=True) 18 | 19 | return RewriteResult() 20 | -------------------------------------------------------------------------------- /test/program/py/test_signature.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from kirin import types 4 | from kirin.prelude import basic 5 | from kirin.dialects.ilist import IList, IListType 6 | 7 | 8 | @basic 9 | def complicated_type(x: IList[tuple[float, float, IList[float, Any]], Any]): 10 | return x 11 | 12 | 13 | def test_complicated_type(): 14 | typ = complicated_type.arg_types[0] 15 | assert isinstance(typ, types.Generic) 16 
| assert typ.is_subseteq( 17 | IListType[ 18 | types.Tuple[types.Float, types.Float, IListType[types.Float]], types.Any 19 | ] 20 | ) 21 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/lowering.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from kirin import lowering 4 | 5 | from . import stmts 6 | from ._dialect import dialect 7 | 8 | 9 | @dialect.register 10 | class Lowering(lowering.FromPythonAST): 11 | 12 | def lower_UnaryOp( 13 | self, state: lowering.State, node: ast.UnaryOp 14 | ) -> lowering.Result: 15 | if op := getattr(stmts, node.op.__class__.__name__, None): 16 | return state.current_frame.push(op(state.lower(node.operand).expect_one())) 17 | else: 18 | raise lowering.BuildError(f"unsupported unary operator {node.op}") 19 | -------------------------------------------------------------------------------- /test/dialects/test_func.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin import ir 4 | from kirin.prelude import structural_no_opt 5 | 6 | 7 | def test_python_func(): 8 | def some_func(x): 9 | return x + 1 10 | 11 | @structural_no_opt 12 | def dumm(x): 13 | return some_func(x) 14 | 15 | with pytest.raises(ir.TypeCheckError): 16 | dumm.code.verify_type() 17 | 18 | some_staff = "" 19 | 20 | @structural_no_opt 21 | def dumm2(x): 22 | return some_staff(x) # type: ignore 23 | 24 | with pytest.raises(ir.TypeCheckError): 25 | dumm.code.verify_type() 26 | -------------------------------------------------------------------------------- /src/kirin/lowering/__init__.py: -------------------------------------------------------------------------------- 1 | from .abc import Result as Result, LoweringABC as LoweringABC 2 | from .frame import Frame as Frame 3 | from .state import State as State 4 | from .exception import BuildError as BuildError 5 | from .python.traits import 
( 6 | FromPythonCall as FromPythonCall, 7 | FromPythonWith as FromPythonWith, 8 | FromPythonRangeLike as FromPythonRangeLike, 9 | FromPythonWithSingleItem as FromPythonWithSingleItem, 10 | ) 11 | from .python.binding import wraps as wraps 12 | from .python.dialect import FromPythonAST as FromPythonAST, akin as akin 13 | from .python.lowering import Python as Python 14 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/_julia.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from kirin import emit, interp 4 | 5 | from .stmts import Range 6 | from ._dialect import dialect 7 | 8 | 9 | @dialect.register(key="emit.julia") 10 | class JuliaMethodTable(interp.MethodTable): 11 | 12 | @interp.impl(Range) 13 | def range(self, emit_: emit.Julia, frame: emit.JuliaFrame, node: Range): 14 | start = frame.get(node.start) 15 | stop = frame.get(node.stop) 16 | step = frame.get(node.step) 17 | frame.write_line(f"{frame.ssa[node.result]} = {start}:{step}:{stop}") 18 | return (frame.ssa[node.result],) 19 | -------------------------------------------------------------------------------- /test/program/py/test_class.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.prelude import basic 4 | 5 | 6 | class Foo: 7 | 8 | def some_kernel(self): 9 | @basic(verify=False) 10 | def goo(x: int): 11 | return x 12 | 13 | return goo 14 | 15 | def another_kernel(self): 16 | @basic(verify=False) 17 | def goo(x: int): 18 | kernel = self.some_kernel() 19 | kernel(x) 20 | 21 | return goo 22 | 23 | 24 | def test_call_method_error(): 25 | foo = Foo() 26 | goo = foo.another_kernel() 27 | 28 | with pytest.raises(Exception): 29 | goo.verify_type() 30 | -------------------------------------------------------------------------------- /src/kirin/passes/hint_const.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin.ir import Method 4 | from kirin.rewrite import Walk, WrapConst 5 | from kirin.analysis import const 6 | from kirin.passes.abc import Pass 7 | from kirin.rewrite.abc import RewriteResult 8 | 9 | 10 | @dataclass 11 | class HintConst(Pass): 12 | 13 | def unsafe_run(self, mt: Method) -> RewriteResult: 14 | constprop = const.Propagate(self.dialects) 15 | if self.no_raise: 16 | frame, _ = constprop.run_no_raise(mt) 17 | else: 18 | frame, _ = constprop.run(mt) 19 | return Walk(WrapConst(frame)).rewrite(mt.code) 20 | -------------------------------------------------------------------------------- /test/verify/test_method_verify.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.prelude import basic 4 | 5 | 6 | def test_verification(capsys): 7 | @basic 8 | def my_kernel(arg1, arg2): 9 | return arg1 + arg2 10 | 11 | with pytest.raises(Exception): 12 | 13 | @basic 14 | def main(): 15 | my_kernel(5) # type: ignore 16 | 17 | captured = capsys.readouterr() 18 | assert "test/verify/test_method_verify.py" in captured.err 19 | assert "line 11" in captured.err 20 | assert ( 21 | "Verification failed for main: expected 3 arguments, got 1" in captured.err 22 | ) 23 | -------------------------------------------------------------------------------- /docs/cookbook/index.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | This page is under construction. The content may be incomplete or incorrect. Submit an issue 3 | on [GitHub](https://github.com/QuEraComputing/kirin/issues/new) if you need help or want to 4 | contribute. 5 | 6 | 7 | # Cookbook 8 | 9 | This cookbook provides a collection of quick examples to help you get started with Kirin. 
Each 10 | example is a self-contained page that demonstrates a specific feature or use case. You can copy and 11 | paste the code snippets into your own project and modify them as needed. 12 | 13 | ## Table of Contents 14 | 15 | - [Continuation of FoodLang](foodlang/cf_rewrite.md) 16 | -------------------------------------------------------------------------------- /src/kirin/dialects/cf/__init__.py: -------------------------------------------------------------------------------- 1 | """Control flow dialect. 2 | 3 | This dialect provides a low-level control flow representation. 4 | 5 | This dialect does not provide any lowering strategies, to lowering 6 | a Python AST to this dialect, use the `kirin.dialects.lowering.cf` dialect 7 | with this dialect. 8 | """ 9 | 10 | from kirin.dialects.cf import abstract as abstract, constprop as constprop 11 | from kirin.dialects.cf.stmts import ( 12 | Branch as Branch, 13 | ConditionalBranch as ConditionalBranch, 14 | ) 15 | from kirin.dialects.cf.interp import CfInterpreter as CfInterpreter 16 | from kirin.dialects.cf.dialect import dialect as dialect 17 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/list/__init__.py: -------------------------------------------------------------------------------- 1 | """The list dialect for Python. 2 | 3 | This module contains the dialect for list semantics in Python, including: 4 | 5 | - The `New` and `Append` statement classes. 6 | - The lowering pass for list operations. 7 | - The concrete implementation of list operations. 8 | - The type inference implementation of list operations. 9 | 10 | This dialect maps `list()`, `ast.List` and `append()` calls to the `New` and `Append` statements. 11 | """ 12 | 13 | from . 
import interp as interp, lowering as lowering, typeinfer as typeinfer 14 | from .stmts import New as New, Append as Append 15 | from ._dialect import dialect as dialect 16 | -------------------------------------------------------------------------------- /src/kirin/serialization/core/deserializable.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import TYPE_CHECKING, Protocol, runtime_checkable 3 | 4 | from typing_extensions import Self 5 | 6 | if TYPE_CHECKING: 7 | from kirin.serialization.base.deserializer import Deserializer 8 | from kirin.serialization.core.serializationunit import SerializationUnit 9 | 10 | 11 | @runtime_checkable 12 | class Deserializable(Protocol): 13 | 14 | @classmethod 15 | @abstractmethod 16 | def deserialize( 17 | cls: type[Self], serUnit: "SerializationUnit", deserializer: "Deserializer" 18 | ) -> Self: 19 | raise NotImplementedError 20 | -------------------------------------------------------------------------------- /example/food/group.py: -------------------------------------------------------------------------------- 1 | from dialect import dialect 2 | 3 | from rewrite import RandomWalkBranch 4 | from kirin.ir import dialect_group 5 | from kirin.prelude import basic_no_opt 6 | from kirin.rewrite import Walk, Fixpoint 7 | from kirin.passes.fold import Fold 8 | 9 | 10 | # create our own food dialect, it runs a random walk on the branches 11 | @dialect_group(basic_no_opt.add(dialect)) 12 | def food(self): 13 | 14 | fold_pass = Fold(self) 15 | 16 | def run_pass(mt, *, fold=True): 17 | Fixpoint(Walk(RandomWalkBranch())).rewrite(mt.code) 18 | 19 | # add const fold 20 | if fold: 21 | fold_pass(mt) 22 | 23 | return run_pass 24 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/constprop.py: -------------------------------------------------------------------------------- 1 | from kirin import 
interp 2 | from kirin.analysis import const 3 | 4 | from . import stmts 5 | from ._dialect import dialect 6 | 7 | 8 | @dialect.register(key="constprop") 9 | class ConstProp(interp.MethodTable): 10 | 11 | @interp.impl(stmts.Not) 12 | def not_( 13 | self, _: const.Propagate, frame: const.Frame, stmt: stmts.Not 14 | ) -> interp.StatementResult[const.Result]: 15 | hint = frame.get(stmt.value) 16 | if isinstance(hint, (const.PartialTuple, const.Value)): 17 | ret = const.Value(not hint.data) 18 | else: 19 | ret = const.Unknown() 20 | return (ret,) 21 | -------------------------------------------------------------------------------- /test/dialects/test_debug.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | from kirin.dialects import debug 3 | 4 | 5 | def test_debug_printing(): 6 | @basic_no_opt.add(debug) 7 | def test_if_inside_for() -> int: 8 | count = 0 9 | for i in range(5): 10 | count = count + 1 11 | something_else = count + 2 12 | debug.info("current count before", count, something_else) 13 | if True: 14 | count = count + 100 15 | debug.info("inside the ifelse", count, something_else) 16 | else: 17 | count = count + 300 18 | return count 19 | 20 | test_if_inside_for() 21 | -------------------------------------------------------------------------------- /test/lowering/test_list.py: -------------------------------------------------------------------------------- 1 | from kirin import types, lowering 2 | from kirin.dialects import cf, py, func 3 | from kirin.dialects.lowering import func as func_lowering 4 | 5 | lower = lowering.Python([cf, func, py.base, py.list, py.assign, func_lowering]) 6 | 7 | 8 | def test_empty_list(): 9 | 10 | def empty_list(): 11 | x = [] 12 | return x 13 | 14 | code = lower.python_function(empty_list) 15 | 16 | list_stmt = code.body.blocks[0].stmts.at(0) # type: ignore 17 | 18 | assert isinstance(list_stmt, py.list.New) 19 | assert len(list_stmt._results) == 1 
20 | 21 | res = list_stmt._results[0] 22 | assert res.type.is_subseteq(types.List) 23 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/list/lowering.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from kirin import types, lowering 4 | 5 | from .stmts import New 6 | from ._dialect import dialect 7 | 8 | 9 | @dialect.register 10 | class PythonLowering(lowering.FromPythonAST): 11 | 12 | def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result: 13 | elts = tuple(state.lower(each).expect_one() for each in node.elts) 14 | 15 | if len(elts): 16 | typ = elts[0].type 17 | for each in elts: 18 | typ = typ.join(each.type) 19 | else: 20 | typ = types.Any 21 | 22 | return state.current_frame.push(New(values=tuple(elts))) 23 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/binop/lowering.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from kirin import lowering 4 | 5 | from . 
import stmts 6 | from ._dialect import dialect 7 | 8 | 9 | @dialect.register 10 | class Lowering(lowering.FromPythonAST): 11 | 12 | def lower_BinOp(self, state: lowering.State, node: ast.BinOp) -> lowering.Result: 13 | lhs = state.lower(node.left).expect_one() 14 | rhs = state.lower(node.right).expect_one() 15 | 16 | if op := getattr(stmts, node.op.__class__.__name__, None): 17 | stmt = op(lhs=lhs, rhs=rhs) 18 | else: 19 | raise lowering.BuildError(f"unsupported binop {node.op}") 20 | return state.current_frame.push(stmt) 21 | -------------------------------------------------------------------------------- /src/kirin/lattice/empty.py: -------------------------------------------------------------------------------- 1 | from kirin.lattice.abc import SingletonMeta, BoundedLattice 2 | 3 | 4 | class EmptyLattice(BoundedLattice["EmptyLattice"], metaclass=SingletonMeta): 5 | """Empty lattice.""" 6 | 7 | def join(self, other: "EmptyLattice") -> "EmptyLattice": 8 | return self 9 | 10 | def meet(self, other: "EmptyLattice") -> "EmptyLattice": 11 | return self 12 | 13 | @classmethod 14 | def bottom(cls): 15 | return cls() 16 | 17 | @classmethod 18 | def top(cls): 19 | return cls() 20 | 21 | def __hash__(self) -> int: 22 | return id(self) 23 | 24 | def is_subseteq(self, other: "EmptyLattice") -> bool: 25 | return True 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-yaml 8 | args: ['--unsafe'] 9 | - id: end-of-file-fixer 10 | - id: trailing-whitespace 11 | - repo: https://github.com/pycqa/isort 12 | rev: 6.0.1 13 | hooks: 14 | - id: isort 15 | name: isort (python) 16 | - repo: https://github.com/psf/black 17 | 
rev: 25.1.0 18 | hooks: 19 | - id: black 20 | - repo: https://github.com/charliermarsh/ruff-pre-commit 21 | # Ruff version. 22 | rev: "v0.12.1" 23 | hooks: 24 | - id: ruff 25 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/cmp/__init__.py: -------------------------------------------------------------------------------- 1 | """The cmp dialect for Python. 2 | 3 | This module contains the dialect for comparison semantics in Python, including: 4 | 5 | - The `Eq`, `NotEq`, `Lt`, `LtE`, `Gt`, `GtE`, `Is`, and `IsNot` statement classes. 6 | - The lowering pass for comparison operations. 7 | - The concrete implementation of comparison operations. 8 | - The Julia emitter for comparison operations. 9 | 10 | This dialect maps `ast.Compare` nodes to the `Eq`, `NotEq`, `Lt`, `LtE`, 11 | `Gt`, `GtE`, `Is`, and `IsNot` statements. 12 | """ 13 | 14 | from . import _julia as _julia, interp as interp, lowering as lowering 15 | from .stmts import * # noqa: F403 16 | from ._dialect import dialect as dialect 17 | -------------------------------------------------------------------------------- /src/kirin/analysis/const/__init__.py: -------------------------------------------------------------------------------- 1 | """Const analysis module. 2 | 3 | This module contains the constant analysis framework for kirin. The constant 4 | analysis framework is built on top of the interpreter framework. 5 | 6 | This module provides a lattice for constant propagation analysis and a 7 | propagation algorithm for computing the constant values for each SSA value in 8 | the IR. 
9 | """ 10 | 11 | from .prop import Frame as Frame, Propagate as Propagate 12 | from .lattice import ( 13 | Value as Value, 14 | Bottom as Bottom, 15 | Result as Result, 16 | Unknown as Unknown, 17 | PartialConst as PartialConst, 18 | PartialTuple as PartialTuple, 19 | PartialLambda as PartialLambda, 20 | ) 21 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/list/stmts.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types, lowering 2 | from kirin.decl import info, statement 3 | 4 | from ._dialect import dialect 5 | 6 | T = types.TypeVar("T") 7 | 8 | 9 | @statement(dialect=dialect) 10 | class New(ir.Statement): 11 | name = "list" 12 | traits = frozenset({lowering.FromPythonCall()}) 13 | values: tuple[ir.SSAValue, ...] = info.argument(T) 14 | result: ir.ResultValue = info.result(types.List[T]) 15 | 16 | 17 | @statement(dialect=dialect) 18 | class Append(ir.Statement): 19 | name = "append" 20 | traits = frozenset({lowering.FromPythonCall()}) 21 | list_: ir.SSAValue = info.argument(types.List[T]) 22 | value: ir.SSAValue = info.argument(T) 23 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/typeinfer.py: -------------------------------------------------------------------------------- 1 | from kirin import types, interp 2 | 3 | from . 
import stmts 4 | from ._dialect import dialect 5 | 6 | 7 | @dialect.register(key="typeinfer") 8 | class TypeInfer(interp.MethodTable): 9 | 10 | @interp.impl(stmts.UAdd) 11 | @interp.impl(stmts.USub) 12 | def uadd( 13 | self, interp, frame: interp.Frame[types.TypeAttribute], stmt: stmts.UnaryOp 14 | ): 15 | return (frame.get(stmt.value),) 16 | 17 | @interp.impl(stmts.Not) 18 | def not_(self, interp, frame, stmt: stmts.Not): 19 | return (types.Bool,) 20 | 21 | @interp.impl(stmts.Invert, types.Int) 22 | def invert(self, interp, frame, stmt): 23 | return (types.Int,) 24 | -------------------------------------------------------------------------------- /src/kirin/print/__init__.py: -------------------------------------------------------------------------------- 1 | """Pretty printing utilities. 2 | 3 | This module provides a pretty printing utility for the IR nodes and other 4 | objects in the compiler. 5 | 6 | The pretty printing utility is implemented using the visitor pattern. The 7 | [`Printable`][kirin.print.Printable] class is the base class for all objects that can be pretty printed. 8 | 9 | The [`Printer`][kirin.print.Printer] class is the visitor that traverses the object and prints the 10 | object to a string. The [`Printer`][kirin.print.Printer] class provides methods for printing different 11 | types of objects. 
12 | """ 13 | 14 | from kirin.print.printer import Printer as Printer 15 | from kirin.print.printable import Printable as Printable 16 | -------------------------------------------------------------------------------- /src/kirin/rewrite/getfield.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin import ir 4 | from kirin.dialects import func 5 | from kirin.rewrite.abc import RewriteRule, RewriteResult 6 | 7 | 8 | @dataclass 9 | class InlineGetField(RewriteRule): 10 | 11 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 12 | if not isinstance(node, func.GetField): 13 | return RewriteResult() 14 | 15 | if not isinstance(node.obj.owner, func.Lambda): 16 | return RewriteResult() 17 | 18 | original = node.obj.owner.captured[node.field] 19 | node.result.replace_by(original) 20 | node.delete() 21 | return RewriteResult(has_done_something=True) 22 | -------------------------------------------------------------------------------- /src/kirin/dialects/func/__init__.py: -------------------------------------------------------------------------------- 1 | """A function dialect that is compatible with python semantics.""" 2 | 3 | from kirin.dialects.func import ( 4 | interp as interp, 5 | constprop as constprop, 6 | typeinfer as typeinfer, 7 | ) 8 | from kirin.dialects.func.attrs import Signature as Signature 9 | from kirin.dialects.func.stmts import ( 10 | Call as Call, 11 | Invoke as Invoke, 12 | Lambda as Lambda, 13 | Return as Return, 14 | Function as Function, 15 | GetField as GetField, 16 | ConstantNone as ConstantNone, 17 | FuncOpCallableInterface as FuncOpCallableInterface, 18 | ) 19 | from kirin.dialects.func._dialect import dialect as dialect 20 | 21 | from . 
import ( 22 | _julia as _julia, 23 | ) 24 | -------------------------------------------------------------------------------- /src/kirin/rewrite/type_assert.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin import ir 4 | from kirin.rewrite.abc import RewriteRule, RewriteResult 5 | from kirin.dialects.py.assign import TypeAssert 6 | 7 | 8 | @dataclass 9 | class InlineTypeAssert(RewriteRule): 10 | 11 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 12 | if not isinstance(node, TypeAssert): 13 | return RewriteResult() 14 | 15 | if node.got.type.is_subseteq(node.expected): 16 | node.got.type = node.got.type.meet(node.expected) 17 | node.result.replace_by(node.got) 18 | node.delete() 19 | return RewriteResult(has_done_something=True) 20 | return RewriteResult() 21 | -------------------------------------------------------------------------------- /src/kirin/interp/concrete.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | from dataclasses import dataclass 5 | 6 | from kirin import ir 7 | 8 | from .abc import InterpreterABC 9 | from .frame import Frame 10 | 11 | 12 | @dataclass 13 | class Interpreter(InterpreterABC[Frame[Any], Any]): 14 | keys = ("main",) 15 | void = None 16 | 17 | def initialize_frame( 18 | self, node: ir.Statement, *, has_parent_access: bool = False 19 | ) -> Frame[Any]: 20 | """Initialize the frame for the given node.""" 21 | return Frame(node, has_parent_access=has_parent_access) 22 | 23 | def run(self, method: ir.Method, *args, **kwargs): 24 | return self.call(method, method, *args, **kwargs) 25 | -------------------------------------------------------------------------------- /src/kirin/decl/emit/_set_new_attribute.py: -------------------------------------------------------------------------------- 1 | """Copied from dataclasses in 
Python 3.10.13.""" 2 | 3 | from types import FunctionType 4 | 5 | 6 | def set_qualname(cls: type, value): 7 | # Ensure that the functions returned from _create_fn uses the proper 8 | # __qualname__ (the class they belong to). 9 | if isinstance(value, FunctionType): 10 | value.__qualname__ = f"{cls.__qualname__}.{value.__name__}" 11 | return value 12 | 13 | 14 | def set_new_attribute(cls: type, name: str, value): 15 | # Never overwrites an existing attribute. Returns True if the 16 | # attribute already exists. 17 | if name in cls.__dict__: 18 | return True 19 | set_qualname(cls, value) 20 | setattr(cls, name, value) 21 | return False 22 | -------------------------------------------------------------------------------- /src/kirin/ir/attrs/__init__.py: -------------------------------------------------------------------------------- 1 | """Compile-time values in Kirin IR. 2 | 3 | This module contains the following: 4 | 5 | ## `abc` module 6 | `abc.AttributeMeta`: The metaclass for all attributes. 7 | `abc.Attribute`: The base class for all attributes. 8 | 9 | ## `types` module 10 | `types.TypeAttribute`: The base class for all type attributes. 11 | `types.PyClass`: A type attribute representing a Python class. 12 | `types.TypeVar`: A type attribute representing a type variable. 13 | `types.Literal`: A type attribute representing a literal type. 14 | `types.Generic`: A type attribute representing a generic type. 15 | `types.Union`: A type attribute representing a union type. 16 | 17 | ## `py` module 18 | 19 | `py.PyAttr`: An attribute representing a Python value. 
20 | """ 21 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/list/interp.py: -------------------------------------------------------------------------------- 1 | from kirin import types, interp 2 | from kirin.dialects.py.binop import Add 3 | 4 | from .stmts import New, Append 5 | from ._dialect import dialect 6 | 7 | 8 | @dialect.register 9 | class ListMethods(interp.MethodTable): 10 | 11 | @interp.impl(New) 12 | def new(self, interp, frame: interp.Frame, stmt: New): 13 | return (list(frame.get_values(stmt.values)),) 14 | 15 | @interp.impl(Add, types.PyClass(list), types.PyClass(list)) 16 | def add(self, interp, frame: interp.Frame, stmt: Add): 17 | return (frame.get(stmt.lhs) + frame.get(stmt.rhs),) 18 | 19 | @interp.impl(Append) 20 | def append(self, interp, frame: interp.Frame, stmt: Append): 21 | frame.get(stmt.list_).append(frame.get(stmt.value)) 22 | -------------------------------------------------------------------------------- /test/program/py/test_list_append.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | from kirin.prelude import python_no_opt 3 | from kirin.dialects import py 4 | 5 | 6 | def test_list_append(): 7 | 8 | @python_no_opt 9 | def test_append(): 10 | x = [] 11 | py.Append(x, 1) 12 | py.Append(x, 2) 13 | return x 14 | 15 | y = test_append() 16 | 17 | assert len(y) == 2 18 | assert y[0] == 1 19 | assert y[1] == 2 20 | 21 | 22 | def test_recursive_append(): 23 | @python_no_opt 24 | def for_loop_append(cntr: int, x: list, n_range: int): 25 | if cntr < n_range: 26 | py.Append(x, cntr) 27 | for_loop_append(cntr + 1, x, n_range) 28 | 29 | return x 30 | 31 | assert for_loop_append(0, [], 5) == [0, 1, 2, 3, 4] 32 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/stmts.py: -------------------------------------------------------------------------------- 1 | from kirin import 
ir, types, lowering 2 | from kirin.decl import info, statement 3 | 4 | from ._dialect import dialect 5 | 6 | T = types.TypeVar("T") 7 | 8 | 9 | @statement 10 | class UnaryOp(ir.Statement): 11 | traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 12 | value: ir.SSAValue = info.argument(T, print=False) 13 | result: ir.ResultValue = info.result(T) 14 | 15 | 16 | @statement(dialect=dialect) 17 | class UAdd(UnaryOp): 18 | name = "uadd" 19 | 20 | 21 | @statement(dialect=dialect) 22 | class USub(UnaryOp): 23 | name = "usub" 24 | 25 | 26 | @statement(dialect=dialect) 27 | class Not(UnaryOp): 28 | name = "not" 29 | 30 | 31 | @statement(dialect=dialect) 32 | class Invert(UnaryOp): 33 | name = "invert" 34 | -------------------------------------------------------------------------------- /test/analysis/test_callgraph.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | from kirin.analysis.callgraph import CallGraph 3 | 4 | 5 | @basic_no_opt 6 | def abc(a, b): 7 | return a + b 8 | 9 | 10 | @basic_no_opt 11 | def bcd(a, b): 12 | return a - b 13 | 14 | 15 | @basic_no_opt 16 | def cde(a, b): 17 | return abc(a, b) + bcd(a, b) 18 | 19 | 20 | @basic_no_opt 21 | def defg(a, b): 22 | return cde(a, b) + abc(a, b) 23 | 24 | 25 | @basic_no_opt 26 | def efg(a, b): 27 | return defg(a, b) + bcd(a, b) 28 | 29 | 30 | def test_callgraph(): 31 | graph = CallGraph(efg) 32 | graph.print() 33 | assert cde in graph.get_neighbors(abc) 34 | assert defg in graph.get_neighbors(abc) 35 | assert cde in graph.get_neighbors(abc) 36 | assert defg in graph.get_neighbors(abc) 37 | -------------------------------------------------------------------------------- /example/pauli/stmts.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | 3 | import numpy as np 4 | 5 | from kirin import ir, types, lowering 6 | from kirin.decl import info, statement 7 | 8 | from 
.dialect import _dialect 9 | 10 | 11 | @statement 12 | class PauliOperator(ir.Statement): 13 | traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 14 | pre_factor: Number = info.attribute(default=1) 15 | result: ir.ResultValue = info.result(types.PyClass(np.matrix)) 16 | 17 | 18 | @statement(dialect=_dialect) 19 | class X(PauliOperator): 20 | pass 21 | 22 | 23 | @statement(dialect=_dialect) 24 | class Y(PauliOperator): 25 | pass 26 | 27 | 28 | @statement(dialect=_dialect) 29 | class Z(PauliOperator): 30 | pass 31 | 32 | 33 | @statement(dialect=_dialect) 34 | class Id(PauliOperator): 35 | pass 36 | -------------------------------------------------------------------------------- /test/analysis/dataflow/constprop/test_worklist.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | from kirin.analysis import const 3 | 4 | 5 | def test_worklist_bfs(): 6 | @basic_no_opt 7 | def make_ker(val: float): 8 | 9 | def ker(i: float): 10 | return i + val 11 | 12 | return ker 13 | 14 | @basic_no_opt 15 | def test(x: str, y: float, flag: bool): 16 | 17 | if x == "x": 18 | val = 1.0 19 | else: 20 | val = 2.0 21 | 22 | if flag: 23 | ker = make_ker(val=val) 24 | 25 | else: 26 | ker = make_ker(val=val) 27 | 28 | return ker 29 | 30 | # test.print() 31 | prop = const.Propagate(basic_no_opt) 32 | frame, ret = prop.run(test) 33 | assert isinstance(ret, const.PartialLambda) 34 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/interp.py: -------------------------------------------------------------------------------- 1 | from kirin import interp 2 | 3 | from . 
import stmts 4 | from ._dialect import dialect 5 | 6 | 7 | @dialect.register 8 | class Concrete(interp.MethodTable): 9 | 10 | @interp.impl(stmts.UAdd) 11 | def uadd(self, interp, frame: interp.Frame, stmt: stmts.UAdd): 12 | return (+frame.get(stmt.value),) 13 | 14 | @interp.impl(stmts.USub) 15 | def usub(self, interp, frame: interp.Frame, stmt: stmts.USub): 16 | return (-frame.get(stmt.value),) 17 | 18 | @interp.impl(stmts.Not) 19 | def not_(self, interp, frame: interp.Frame, stmt: stmts.Not): 20 | return (not frame.get(stmt.value),) 21 | 22 | @interp.impl(stmts.Invert) 23 | def invert(self, interp, frame: interp.Frame, stmt: stmts.Invert): 24 | return (~frame.get(stmt.value),) 25 | -------------------------------------------------------------------------------- /src/kirin/rewrite/__init__.py: -------------------------------------------------------------------------------- 1 | from .cse import CommonSubexpressionElimination as CommonSubexpressionElimination 2 | from .dce import DeadCodeElimination as DeadCodeElimination 3 | from .fold import ConstantFold as ConstantFold 4 | from .walk import Walk as Walk 5 | from .alias import InlineAlias as InlineAlias 6 | from .chain import Chain as Chain 7 | from .inline import Inline as Inline 8 | from .getitem import InlineGetItem as InlineGetItem 9 | from .fixpoint import Fixpoint as Fixpoint 10 | from .getfield import InlineGetField as InlineGetField 11 | from .apply_type import ApplyType as ApplyType 12 | from .compactify import CFGCompactify as CFGCompactify 13 | from .wrap_const import WrapConst as WrapConst 14 | from .call2invoke import Call2Invoke as Call2Invoke 15 | from .type_assert import InlineTypeAssert as InlineTypeAssert 16 | -------------------------------------------------------------------------------- /example/simple.py: -------------------------------------------------------------------------------- 1 | """A minimal language example with a single pass that does nothing.""" 2 | 3 | from kirin import ir 4 | 
from kirin.dialects import cf, py, func, lowering 5 | 6 | 7 | @ir.dialect_group( 8 | [ 9 | func, 10 | lowering.func, 11 | lowering.call, 12 | lowering.cf, 13 | py.base, 14 | py.constant, 15 | py.assign, 16 | py.binop, 17 | py.unary, 18 | ] 19 | ) 20 | def simple(self): 21 | def run_pass(mt): 22 | return mt 23 | 24 | return run_pass 25 | 26 | 27 | @simple 28 | def main(x): 29 | y = x + 1 30 | return y 31 | 32 | 33 | main.print() 34 | 35 | 36 | @simple.add(cf).add(py.cmp) 37 | def main2(x): 38 | y = x + 1 39 | if y > 0: # errors 40 | return y 41 | else: 42 | return -y 43 | 44 | 45 | main2.print() 46 | -------------------------------------------------------------------------------- /src/kirin/passes/inline.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from dataclasses import field, dataclass 3 | 4 | from kirin import ir 5 | from kirin.passes import Pass 6 | from kirin.rewrite import Walk, Inline, Fixpoint, CFGCompactify, DeadCodeElimination 7 | from kirin.rewrite.abc import RewriteResult 8 | 9 | 10 | def aggresive(x: ir.IRNode) -> bool: 11 | return True 12 | 13 | 14 | @dataclass 15 | class InlinePass(Pass): 16 | heuristic: Callable[[ir.IRNode], bool] = field(default=aggresive) 17 | 18 | def unsafe_run(self, mt: ir.Method) -> RewriteResult: 19 | 20 | result = Walk(Inline(heuristic=self.heuristic)).rewrite(mt.code) 21 | result = Walk(CFGCompactify()).rewrite(mt.code).join(result) 22 | 23 | # dce 24 | dce = DeadCodeElimination() 25 | return Fixpoint(Walk(dce)).rewrite(mt.code).join(result) 26 | -------------------------------------------------------------------------------- /test/ir/test_isequal.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.dialects import func 3 | 4 | 5 | def test_is_structurally_equal_ignoring_hint(): 6 | block = ir.Block() 7 | block.args.append_from(types.MethodType, "self") 8 | source_func = 
func.Function( 9 | sym_name="main", 10 | signature=func.Signature( 11 | inputs=(), 12 | output=types.NoneType, 13 | ), 14 | body=ir.Region(block), 15 | ) 16 | 17 | block = ir.Block() 18 | block.args.append_from(types.MethodType, "self") 19 | expected_func = func.Function( 20 | sym_name="main", 21 | signature=func.Signature( 22 | inputs=(), 23 | output=types.NoneType, 24 | ), 25 | body=ir.Region(block), 26 | ) 27 | 28 | assert expected_func.is_structurally_equal(source_func) 29 | -------------------------------------------------------------------------------- /test/lowering/test_hint_union_binop.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from kirin import types 4 | from kirin.prelude import basic 5 | from kirin.dialects import ilist 6 | 7 | 8 | def test_method_union_binop_hint(): 9 | 10 | @basic 11 | def main(x: ilist.IList[float, Any] | list[float]) -> float: 12 | return x[0] 13 | 14 | main.print() 15 | 16 | tps = main.arg_types 17 | 18 | assert len(tps) == 1 19 | assert tps[0] == types.Union( 20 | [ilist.IListType[types.Float, types.Any], types.List[types.Float]] 21 | ) 22 | 23 | 24 | def test_method_union_multi_hint(): 25 | 26 | @basic 27 | def main(x: str | float | int): 28 | return x 29 | 30 | main.print() 31 | 32 | tps = main.arg_types 33 | 34 | assert len(tps) == 1 35 | assert tps[0] == types.Union([types.String, types.Float, types.Int]) 36 | -------------------------------------------------------------------------------- /src/kirin/ir/attrs/data.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Generic, TypeVar 3 | from dataclasses import field, dataclass 4 | 5 | from .abc import Attribute 6 | from .types import TypeAttribute 7 | 8 | T = TypeVar("T", covariant=True) 9 | 10 | 11 | @dataclass(eq=False) 12 | class Data(Attribute, Generic[T]): 13 | """Base class for data attributes. 
14 | 15 | Data attributes are compile-time constants that can be used to 16 | represent runtime data inside the IR. 17 | 18 | This class is meant to be subclassed by specific data attributes. 19 | It provides a `type` attribute that should be set to the type of 20 | the data. 21 | """ 22 | 23 | type: TypeAttribute = field(init=False, repr=False) 24 | 25 | @abstractmethod 26 | def unwrap(self) -> T: 27 | """Returns the underlying data value.""" 28 | ... 29 | -------------------------------------------------------------------------------- /src/kirin/dialects/eltype.py: -------------------------------------------------------------------------------- 1 | """This dialect offers a statement `eltype` for other dialects' 2 | type inference to query/implement the element type of a value. 3 | For example, the `ilist` dialect implements the `eltype` statement 4 | on the `ilist.IList` type to return the element type. 5 | """ 6 | 7 | from kirin import ir, types 8 | from kirin.decl import info, statement 9 | 10 | dialect = ir.Dialect("eltype") 11 | 12 | 13 | @statement(dialect=dialect) 14 | class ElType(ir.Statement): 15 | """Returns the element type of a value. 16 | 17 | This statement is used by other dialects to query the element type of a value. 18 | """ 19 | 20 | container: ir.SSAValue = info.argument(types.Any) 21 | """The value to query the element type of.""" 22 | elem: ir.ResultValue = info.result(types.PyClass(types.TypeAttribute)) 23 | """The element type of the value.""" 24 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/unary/__init__.py: -------------------------------------------------------------------------------- 1 | """The unary dialect for Python. 2 | 3 | This module contains the dialect for unary semantics in Python, including: 4 | 5 | - The `UnaryOp` base class for unary operations. 6 | - The `UAdd`, `USub`, `Not`, and `Invert` statement classes. 7 | - The lowering pass for unary operations. 
8 | - The concrete implementation of unary operations. 9 | - The type inference implementation of unary operations. 10 | - The constant propagation implementation of unary operations. 11 | - The Julia emitter for unary operations. 12 | 13 | This dialect maps `ast.UnaryOp` nodes to the `UAdd`, `USub`, `Not`, and `Invert` statements. 14 | """ 15 | 16 | from . import ( 17 | interp as interp, 18 | lowering as lowering, 19 | constprop as constprop, 20 | typeinfer as typeinfer, 21 | ) 22 | from .stmts import * # noqa: F403 23 | from ._dialect import dialect as dialect 24 | -------------------------------------------------------------------------------- /src/kirin/dialects/random/__init__.py: -------------------------------------------------------------------------------- 1 | from kirin import lowering 2 | 3 | from . import stmts as stmts, interp as interp 4 | from ._dialect import dialect as dialect 5 | 6 | 7 | @lowering.wraps(stmts.Random) 8 | def random() -> float: 9 | """ 10 | Generate a random floating-point number between 0 and 1. 11 | """ 12 | ... 13 | 14 | 15 | @lowering.wraps(stmts.RandInt) 16 | def randint(start: int, stop: int) -> int: 17 | """ 18 | Generate a random integer within the given range. 19 | """ 20 | ... 21 | 22 | 23 | @lowering.wraps(stmts.Uniform) 24 | def uniform(start: float, stop: float) -> float: 25 | """ 26 | Generate a random floating-point number within the given range. 27 | """ 28 | ... 29 | 30 | 31 | @lowering.wraps(stmts.Seed) 32 | def seed(value: int) -> None: 33 | """ 34 | Set the seed for the random number generator. 35 | """ 36 | ...
37 | -------------------------------------------------------------------------------- /example/pauli/interp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from kirin.interp import MethodTable, impl 4 | 5 | from .stmts import X, Y, Z, Id 6 | from .dialect import _dialect 7 | 8 | 9 | @_dialect.register 10 | class PauliMethods(MethodTable): 11 | X_mat = np.array([[0, 1], [1, 0]]) 12 | Y_mat = np.array([[0, -1j], [1j, 0]]) 13 | Z_mat = np.array([[1, 0], [0, -1]]) 14 | Id_mat = np.array([[1, 0], [0, 1]]) 15 | 16 | @impl(X) # (1)! 17 | def x(self, interp, frame, stmt: X): 18 | return (stmt.pre_factor * self.X_mat,) 19 | 20 | @impl(Y) 21 | def y(self, interp, frame, stmt: Y): 22 | return (self.Y_mat * stmt.pre_factor,) 23 | 24 | @impl(Z) 25 | def z(self, interp, frame, stmt: Z): 26 | return (self.Z_mat * stmt.pre_factor,) 27 | 28 | @impl(Id) 29 | def id(self, interp, frame, stmt: Id): 30 | return (self.Id_mat * stmt.pre_factor,) 31 | -------------------------------------------------------------------------------- /src/kirin/ir/attrs/_types.pyi: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .abc import Attribute 4 | from .types import ( 5 | Union, 6 | Generic, 7 | Literal, 8 | PyClass, 9 | TypeVar, 10 | FunctionType, 11 | TypeAttribute, 12 | TypeofMethodType, 13 | ) 14 | 15 | @dataclass 16 | class _TypeAttribute(Attribute): 17 | def is_subseteq_Union(self, other: Union) -> bool: ... 18 | def is_subseteq_Literal(self, other: Literal) -> bool: ... 19 | def is_subseteq_TypeVar(self, other: TypeVar) -> bool: ... 20 | def is_subseteq_PyClass(self, other: PyClass) -> bool: ... 21 | def is_subseteq_Generic(self, other: Generic) -> bool: ... 22 | def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ... 23 | def is_subseteq_TypeofMethodType(self, other: TypeofMethodType) -> bool: ... 
24 | def is_subseteq_FunctionType(self, other: FunctionType) -> bool: ... 25 | -------------------------------------------------------------------------------- /test/lowering/test_337.py: -------------------------------------------------------------------------------- 1 | from kirin import lowering 2 | from kirin.prelude import basic_no_opt 3 | from kirin.dialects import cf, func 4 | 5 | 6 | def test_issue_337(): 7 | def test_if_inside_for() -> int: 8 | count = 0 9 | for i in range(5): 10 | count = count + 1 11 | if True: 12 | count = count + 100 13 | else: 14 | count = count + 300 15 | return count 16 | 17 | lower = lowering.Python(basic_no_opt) 18 | code = lower.python_function(test_if_inside_for, compactify=True) 19 | assert isinstance(code, func.Function) 20 | loop_last_block = code.body.blocks[4] 21 | count_5 = loop_last_block.args[0] 22 | stmt = loop_last_block.stmts.at(-1) 23 | assert isinstance(stmt, cf.ConditionalBranch) 24 | assert stmt.then_arguments[0] is count_5 25 | assert stmt.else_arguments[1] is count_5 26 | -------------------------------------------------------------------------------- /test/analysis/dataflow/typeinfer/test_selfref_closure.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | 3 | from kirin import types 4 | from kirin.prelude import structural_no_opt 5 | from kirin.analysis import TypeInference 6 | 7 | 8 | @mark.xfail(reason="if with early return not supported in scf lowering") 9 | def test_self_ref_closure(): 10 | 11 | @structural_no_opt 12 | def should_work(n_qubits: int): 13 | def self_ref_source(i_layer): 14 | stride = n_qubits // (2**i_layer) 15 | if stride == 0: 16 | return 17 | 18 | self_ref_source(i_layer + 1) 19 | 20 | return self_ref_source 21 | 22 | infer = TypeInference(structural_no_opt) 23 | frame, ret = infer.run(should_work, types.Int) 24 | should_work.print(analysis=frame.entries) 25 | assert ret.is_structurally_equal( 26 | 
types.MethodType[types.Tuple[types.Any], types.NoneType] 27 | ) 28 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/binop/__init__.py: -------------------------------------------------------------------------------- 1 | """The binop dialect for Python. 2 | 3 | This module contains the dialect for binary operation semantics in Python, including: 4 | 5 | - The `Add`, `Sub`, `Mult`, `Div`, `FloorDiv`, `Mod`, `Pow`, 6 | `LShift`, `RShift`, `BitOr`, `BitXor`, and `BitAnd` statement classes. 7 | - The lowering pass for binary operations. 8 | - The concrete implementation of binary operations. 9 | - The type inference implementation of binary operations. 10 | - The Julia emitter for binary operations. 11 | 12 | This dialect maps `ast.BinOp` nodes to the `Add`, `Sub`, `Mult`, `Div`, `FloorDiv`, 13 | `Mod`, `Pow`, `LShift`, `RShift`, `BitOr`, `BitXor`, and `BitAnd` statements. 14 | """ 15 | 16 | from . import ( 17 | _julia as _julia, 18 | interp as interp, 19 | lowering as lowering, 20 | typeinfer as typeinfer, 21 | ) 22 | from .stmts import * # noqa: F403 23 | from ._dialect import dialect as dialect 24 | -------------------------------------------------------------------------------- /src/kirin/dialects/scf/__init__.py: -------------------------------------------------------------------------------- 1 | """A Python-like structural Control Flow dialect. 2 | 3 | This dialect provides constructs for expressing control flow in a structured 4 | manner. The dialect provides constructs for expressing loops and conditionals. 5 | Unlike MLIR SCF dialect, this dialect does not restrict the control flow to 6 | statically analyzable forms. This dialect is designed to be compatible with 7 | Python native control flow constructs. 8 | 9 | This dialect depends on the following dialects: 10 | - `eltype`: for obtaining the element type of a value. 11 | """ 12 | 13 | from . 
import ( 14 | trim as trim, 15 | _julia as _julia, 16 | absint as absint, 17 | interp as interp, 18 | unroll as unroll, 19 | lowering as lowering, 20 | constprop as constprop, 21 | typeinfer as typeinfer, 22 | ) 23 | from .stmts import For as For, Yield as Yield, IfElse as IfElse 24 | from ._dialect import dialect as dialect 25 | -------------------------------------------------------------------------------- /src/kirin/dialects/cf/interp.py: -------------------------------------------------------------------------------- 1 | from kirin.interp import Frame, Successor, Interpreter, MethodTable, impl 2 | from kirin.dialects.cf.stmts import Branch, ConditionalBranch 3 | from kirin.dialects.cf.dialect import dialect 4 | 5 | 6 | @dialect.register 7 | class CfInterpreter(MethodTable): 8 | 9 | @impl(Branch) 10 | def branch(self, interp: Interpreter, frame: Frame, stmt: Branch): 11 | return Successor(stmt.successor, *frame.get_values(stmt.arguments)) 12 | 13 | @impl(ConditionalBranch) 14 | def conditional_branch( 15 | self, interp: Interpreter, frame: Frame, stmt: ConditionalBranch 16 | ): 17 | if frame.get(stmt.cond): 18 | return Successor( 19 | stmt.then_successor, *frame.get_values(stmt.then_arguments) 20 | ) 21 | else: 22 | return Successor( 23 | stmt.else_successor, *frame.get_values(stmt.else_arguments) 24 | ) 25 | -------------------------------------------------------------------------------- /test/dialects/test_ilist2list.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | from kirin.prelude import python_basic 3 | from kirin.dialects import func, ilist, lowering 4 | 5 | 6 | @ir.dialect_group(python_basic.union([func, ilist, lowering.func])) 7 | def basic_desugar(self): 8 | ilist_desugar = ilist.IListDesugar(self) 9 | 10 | def run_pass( 11 | mt: ir.Method, 12 | ) -> None: 13 | ilist_desugar(mt) 14 | 15 | return run_pass 16 | 17 | 18 | def test_ilist2list_rewrite(): 19 | 20 | x = [1, 2, 3, 4] 21 | 22 | 
@basic_desugar 23 | def ilist2_list(): 24 | return x 25 | 26 | ilist2_list.print() 27 | 28 | x = ilist2_list() 29 | 30 | assert isinstance(x, ilist.IList) 31 | 32 | 33 | def test_range_rewrite(): 34 | 35 | r = range(10) 36 | 37 | @basic_desugar 38 | def ilist_range(): 39 | return r 40 | 41 | ilist_range.print() 42 | 43 | x = ilist_range() 44 | 45 | assert isinstance(x, ilist.IList) 46 | -------------------------------------------------------------------------------- /test/program/py/test_loop.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | 3 | 4 | def test_simple(): 5 | @basic_no_opt 6 | def main(x: int): 7 | for i in range(5): 8 | x = x + 1 9 | return x 10 | 11 | assert main.py_func is not None 12 | assert main.py_func(1) == main(1) 13 | 14 | 15 | # generate some more complicated loop 16 | def test_nested(): 17 | @basic_no_opt 18 | def main(x: int): 19 | for i in range(5): 20 | for j in range(5): 21 | x = x + 1 22 | return x 23 | 24 | assert main.py_func is not None 25 | assert main.py_func(1) == main(1) 26 | 27 | 28 | def test_nested2(): 29 | @basic_no_opt 30 | def main(x: int): 31 | for i in range(5): 32 | for j in range(5): 33 | for k in range(5): 34 | x = x + 1 35 | return x 36 | 37 | assert main.py_func is not None 38 | assert main.py_func(1) == main(1) 39 | -------------------------------------------------------------------------------- /test/rules/test_dce.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | from kirin.rewrite import Walk, Fixpoint 3 | from kirin.analysis import const 4 | from kirin.rewrite.dce import DeadCodeElimination 5 | from kirin.rewrite.fold import ConstantFold 6 | from kirin.rewrite.wrap_const import WrapConst 7 | 8 | 9 | @basic_no_opt 10 | def foldable(x: int) -> int: 11 | y = 1 12 | b = y + 2 13 | c = y + b 14 | d = c + 4 15 | return d + x 16 | 17 | 18 | def test_dce(): 
19 | before = foldable(1) 20 | const_prop = const.Propagate(foldable.dialects) 21 | frame, _ = const_prop.run(foldable) 22 | Fixpoint(Walk(WrapConst(frame))).rewrite(foldable.code) 23 | fold = ConstantFold() 24 | Fixpoint(Walk(fold)).rewrite(foldable.code) 25 | 26 | foldable.code.print() 27 | dce = DeadCodeElimination() 28 | Fixpoint(Walk(dce)).rewrite(foldable.code) 29 | foldable.code.print() 30 | 31 | after = foldable(1) 32 | 33 | assert before == after 34 | -------------------------------------------------------------------------------- /src/kirin/symbol_table.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar 2 | from dataclasses import field, dataclass 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | @dataclass 8 | class SymbolTable(Generic[T]): 9 | names: dict[str, T] = field(default_factory=dict) 10 | """The table that maps names to values.""" 11 | prefix: str = field(default="", kw_only=True) 12 | name_count: dict[str, int] = field(default_factory=dict, kw_only=True) 13 | """The count of names that have been requested.""" 14 | 15 | def __getitem__(self, name: str) -> T: 16 | return self.names[name] 17 | 18 | def __contains__(self, name: str) -> bool: 19 | return name in self.names 20 | 21 | def __setitem__(self, name: str, value: T) -> None: 22 | count = self.name_count.setdefault(name, 0) 23 | self.name_count[name] = count + 1 24 | self.names[f"{self.prefix}_{name}_{count}"] = value 25 | 26 | def __delitem__(self, name: str) -> None: 27 | del self.names[name] 28 | -------------------------------------------------------------------------------- /test/dialects/py_dialect/test_tuple_infer.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types as ktypes 2 | from kirin.prelude import structural 3 | from kirin.analysis import TypeInference 4 | 5 | 6 | # stmt_at and results_at taken from kirin type inference tests with 7 | # minimal 
modification 8 | def stmt_at(kernel: ir.Method, block_id: int, stmt_id: int) -> ir.Statement: 9 | return kernel.code.body.blocks[block_id].stmts.at(stmt_id) # type: ignore 10 | 11 | 12 | def results_at(kernel: ir.Method, block_id: int, stmt_id: int): 13 | return stmt_at(kernel, block_id, stmt_id).results 14 | 15 | 16 | def test_tuple_type_infer(): 17 | 18 | @structural(typeinfer=True) 19 | def test(x: bool): 20 | a = [True, False, True] 21 | return (a[0], x) 22 | 23 | typeinfer = TypeInference(structural) 24 | frame, _ = typeinfer.run(test) 25 | 26 | assert [frame.entries[result] for result in results_at(test, 0, 1)] == [ 27 | ktypes.Generic(tuple, ktypes.Bool, ktypes.Bool) 28 | ] 29 | -------------------------------------------------------------------------------- /test/analysis/dataflow/test_cfg.py: -------------------------------------------------------------------------------- 1 | from kirin import lowering 2 | from kirin.prelude import basic_no_opt 3 | from kirin.dialects import func 4 | from kirin.analysis.cfg import CFG 5 | 6 | lower = lowering.Python(basic_no_opt) 7 | 8 | 9 | def deadblock(x): 10 | if x: 11 | return x + 1 12 | else: 13 | return x + 2 14 | return x + 3 15 | 16 | 17 | def test_reachable(): 18 | code = lower.python_function(deadblock, compactify=False) 19 | assert isinstance(code, func.Function) 20 | cfg = CFG(code.body) 21 | assert code.body.blocks[-1] not in cfg.successors 22 | 23 | 24 | def foo(x: int): # type: ignore 25 | def goo(y: int): 26 | return x + y 27 | 28 | return goo 29 | 30 | 31 | def test_foo_cfg(): 32 | code = lower.python_function(foo, compactify=False) 33 | assert isinstance(code, func.Function) 34 | cfg = CFG(code.body) 35 | assert code.body.blocks[0] in cfg.successors 36 | assert code.body.blocks[1] not in cfg.successors 37 | -------------------------------------------------------------------------------- /test/ir/test_dialect.py: -------------------------------------------------------------------------------- 1 | import 
pytest 2 | 3 | from kirin.ir import Dialect 4 | from kirin.types import PyClass 5 | 6 | 7 | def test_py_type_register(): 8 | 9 | class TestClass: 10 | pass 11 | 12 | class OtherTestClass: 13 | pass 14 | 15 | dialect = Dialect("test") 16 | 17 | TestType = dialect.register_py_type( 18 | TestClass, display_name="TestType", prefix="test" 19 | ) 20 | assert TestType == PyClass(TestClass, prefix="test", display_name="TestType") 21 | 22 | assert dialect.python_types == {("test", "TestType"): TestType} 23 | 24 | with pytest.raises(ValueError): 25 | dialect.register_py_type(OtherTestClass, display_name="TestType", prefix="test") 26 | 27 | with pytest.raises(ValueError): 28 | dialect.register_py_type(TestClass, display_name="TestClass", prefix="test") 29 | 30 | with pytest.raises(ValueError): 31 | dialect.register_py_type( 32 | TestClass, display_name="TestType", prefix="other_prefix" 33 | ) 34 | -------------------------------------------------------------------------------- /src/kirin/decl/emit/repr.py: -------------------------------------------------------------------------------- 1 | from kirin.decl.base import BaseModifier 2 | 3 | from ._create_fn import create_fn 4 | from ._set_new_attribute import set_new_attribute 5 | 6 | 7 | class EmitRepr(BaseModifier): 8 | 9 | def emit_repr(self): 10 | if "repr" not in self.params or not self.params["repr"]: 11 | return 12 | 13 | body = [f'ret = "{self.cls.__name__}("'] 14 | for idx, field in enumerate(self.fields): 15 | if idx > 0: 16 | body.append('ret += ", "') 17 | body.append(f'ret += f"{field.name}={{{self._self_name}.{field.name}}}"') 18 | body.append('ret += ")"') 19 | body.append("return ret") 20 | 21 | set_new_attribute( 22 | self.cls, 23 | "__repr__", 24 | create_fn( 25 | "__repr__", 26 | args=[self._self_name], 27 | body=body, 28 | globals=self.globals, 29 | return_type=str, 30 | ), 31 | ) 32 | -------------------------------------------------------------------------------- /test/emit/test_julia.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from kirin import emit 4 | from kirin.prelude import structural 5 | from kirin.dialects import debug 6 | 7 | 8 | @structural.add(debug) 9 | def some_arith(x: int, y: float): 10 | return x + y 11 | 12 | 13 | @structural.add(debug) 14 | def julia_like(x: int, y: int): 15 | for i in range(x): 16 | for j in range(y): 17 | if i == 0: 18 | debug.info("Hello") 19 | else: 20 | debug.info("World") 21 | return some_arith(x + y, 4.0) 22 | 23 | 24 | def test_julia_like(tmp_path): 25 | file = tmp_path / "julia_like.jl" 26 | with open(file, "w") as io: 27 | julia_emit = emit.Julia(structural.add(debug), io=io) 28 | julia_emit.run(julia_like) 29 | 30 | with open(file, "r") as io: 31 | generated = io.read() 32 | 33 | with open(Path(__file__).parent / "julia_like.jl", "r") as io: 34 | target = io.read() 35 | 36 | assert generated.strip() == target.strip() 37 | -------------------------------------------------------------------------------- /test/dialects/scf/test_typeinfer.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | 3 | from kirin import types 4 | from kirin.prelude import structural_no_opt 5 | from kirin.analysis import TypeInference 6 | 7 | type_infer = TypeInference(structural_no_opt) 8 | 9 | 10 | @mark.xfail(reason="for with early return not supported in scf lowering") 11 | def test_inside_return_loop(): 12 | @structural_no_opt 13 | def simple_loop(x: float): 14 | for i in range(0, 3): 15 | return i 16 | return x 17 | 18 | frame, ret = type_infer.run(simple_loop) 19 | assert ret.is_subseteq(types.Int | types.Float) 20 | 21 | 22 | @mark.xfail(reason="if with early return not supported in scf lowering") 23 | def test_simple_ifelse(): 24 | @structural_no_opt 25 | def simple_ifelse(x: int): 26 | cond = x > 0 27 | if cond: 28 | return cond 29 | else: 30 | return 0 31 | 32 | frame, ret = 
type_infer.run(simple_ifelse) 33 | assert ret.is_subseteq(types.Bool | types.Int | types.NoneType) 34 | -------------------------------------------------------------------------------- /test/dialects/test_dummy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin import ir, types, lowering 4 | from kirin.decl import info, statement 5 | from kirin.prelude import basic 6 | 7 | dialect = ir.Dialect("dummy") 8 | 9 | 10 | @statement(dialect=dialect) 11 | class DummyStmt(ir.Statement): 12 | name = "dummy" 13 | traits = frozenset({lowering.FromPythonCall()}) 14 | value: ir.SSAValue = info.argument(types.Int) 15 | option: ir.PyAttr[str] = info.attribute() 16 | result: ir.ResultValue = info.result(types.Int) 17 | 18 | 19 | def test_attribute_lowering(): 20 | @basic.add(dialect) 21 | def test(x: int): 22 | return DummyStmt(x, option=ir.PyAttr("attr")) # type: ignore 23 | 24 | option = test.code.body.blocks[0].stmts.at(0).option # type: ignore 25 | assert isinstance(option, ir.PyAttr) and option.data == "attr" 26 | 27 | with pytest.raises(lowering.BuildError): 28 | 29 | @basic.add(dialect) 30 | def not_working(x: int): 31 | return DummyStmt(x, option=x) # type: ignore 32 | -------------------------------------------------------------------------------- /test/emit/julia_like.jl: -------------------------------------------------------------------------------- 1 | function julia_like(ssa_x, ssa_y) 2 | @label block_0 3 | ssa_0 = 0:1:ssa_x 4 | local ssa_y_1 5 | local ssa_y_2 6 | for ssa_i in ssa_0 7 | ssa_y_3 = ssa_y 8 | ssa_y_4 = ssa_y 9 | ssa_1 = 0:1:ssa_y_3 10 | local ssa_i_1 11 | local ssa_i_2 12 | for ssa_j in ssa_1 13 | ssa_i_3 = ssa_i 14 | ssa_i_4 = ssa_i 15 | ssa_2 = (ssa_i_3 == 0) 16 | if ssa_2 17 | ssa_3 = ssa_2 18 | @info "Hello" 19 | else 20 | ssa_4 = ssa_2 21 | @info "World" 22 | end 23 | ssa_i_1 = ssa_i_3 24 | ssa_i_2 = ssa_i_3 25 | end 26 | ssa_y_1 = ssa_y_3 27 | ssa_y_2 = ssa_y_3 28 | end 29 | 
ssa_5 = (ssa_x + ssa_y_1) 30 | ssa_6 = some_arith(ssa_5, 4.0) 31 | return ssa_6 32 | end 33 | 34 | function some_arith(ssa_x, ssa_y) 35 | @label block_0 36 | ssa_0 = (ssa_x + ssa_y) 37 | return ssa_0 38 | end 39 | -------------------------------------------------------------------------------- /test/dialects/scf/test_unroll.py: -------------------------------------------------------------------------------- 1 | from kirin.passes import Fold 2 | from kirin.prelude import structural_no_opt 3 | from kirin.rewrite import Walk 4 | from kirin.dialects import py, scf, func 5 | 6 | 7 | def test_simple_loop_unroll(): 8 | @structural_no_opt 9 | def simple_loop(x): 10 | for i in range(3): 11 | x = x + i 12 | return x 13 | 14 | fold = Fold(structural_no_opt) 15 | fold(simple_loop) 16 | Walk(scf.unroll.ForLoop()).rewrite(simple_loop.code) 17 | assert len(simple_loop.callable_region.blocks) == 1 18 | stmts = simple_loop.callable_region.blocks[0].stmts 19 | assert isinstance(stmts.at(0), py.Constant) 20 | assert isinstance(stmts.at(1), py.Constant) 21 | assert isinstance(stmts.at(2), py.Add) 22 | assert isinstance(stmts.at(3), py.Constant) 23 | assert isinstance(stmts.at(4), py.Add) 24 | assert isinstance(stmts.at(5), py.Constant) 25 | assert isinstance(stmts.at(6), py.Add) 26 | assert isinstance(stmts.at(7), func.Return) 27 | assert simple_loop(1) == 4 28 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Immutable list dialect for Python. 3 | 4 | This dialect provides a simple, immutable list dialect similar 5 | to Python's built-in list type. 6 | """ 7 | 8 | from . 
import ( 9 | _julia as _julia, 10 | interp as interp, 11 | rewrite as rewrite, 12 | lowering as lowering, 13 | constprop as constprop, 14 | typeinfer as typeinfer, 15 | ) 16 | from .stmts import ( 17 | Map as Map, 18 | New as New, 19 | Push as Push, 20 | Scan as Scan, 21 | Foldl as Foldl, 22 | Foldr as Foldr, 23 | ForEach as ForEach, 24 | IListType as IListType, 25 | ) 26 | from .passes import IListDesugar as IListDesugar 27 | from .runtime import IList as IList 28 | from ._dialect import dialect as dialect 29 | from ._wrapper import ( # careful this is not the builtin range 30 | all as all, 31 | any as any, 32 | map as map, 33 | scan as scan, 34 | foldl as foldl, 35 | foldr as foldr, 36 | range as range, 37 | sorted as sorted, 38 | for_each as for_each, 39 | ) 40 | -------------------------------------------------------------------------------- /src/kirin/dialects/random/interp.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from kirin.interp import Frame, MethodTable, impl 4 | 5 | from . 
import stmts 6 | from ._dialect import dialect 7 | 8 | 9 | @dialect.register 10 | class RandomMethodTable(MethodTable): 11 | 12 | @impl(stmts.Random) 13 | def random(self, interp, frame: Frame, stmt: stmts.Random): 14 | return (random.random(),) 15 | 16 | @impl(stmts.RandInt) 17 | def randint(self, interp, frame: Frame, stmt: stmts.RandInt): 18 | start = frame.get(stmt.start) 19 | stop = frame.get(stmt.stop) 20 | return (random.randint(start, stop),) 21 | 22 | @impl(stmts.Uniform) 23 | def uniform(self, interp, frame: Frame, stmt: stmts.Uniform): 24 | start = frame.get(stmt.start) 25 | stop = frame.get(stmt.stop) 26 | return (random.uniform(start, stop),) 27 | 28 | @impl(stmts.Seed) 29 | def seed(self, interp, frame: Frame, stmt: stmts.Seed): 30 | seed_value = frame.get(stmt.value) 31 | random.seed(seed_value) 32 | return tuple() 33 | -------------------------------------------------------------------------------- /test/test_worklist.py: -------------------------------------------------------------------------------- 1 | from kirin.worklist import WorkList 2 | 3 | 4 | def test_worklist(): 5 | wl = WorkList() 6 | 7 | assert wl.is_empty() 8 | assert not wl 9 | assert len(wl) == 0 10 | 11 | assert wl.pop() is None 12 | 13 | assert wl.is_empty() 14 | assert not wl 15 | assert len(wl) == 0 16 | 17 | wl.append("A") 18 | 19 | assert not wl.is_empty() 20 | assert wl 21 | assert len(wl) == 1 22 | 23 | assert wl.pop() == "A" 24 | 25 | assert wl.is_empty() 26 | assert not wl 27 | assert len(wl) == 0 28 | 29 | wl.append("Z") 30 | wl.extend("BCGFEDCB") 31 | 32 | assert not wl.is_empty() 33 | assert wl 34 | assert len(wl) == 9 35 | 36 | assert wl.pop() == "Z" 37 | assert wl.pop() == "B" 38 | 39 | assert not wl.is_empty() 40 | assert wl 41 | assert len(wl) == 7 42 | 43 | rest = [] 44 | while wl: 45 | rest.append(wl.pop()) 46 | assert rest == list("CGFEDCB") 47 | 48 | assert wl.is_empty() 49 | assert not wl 50 | assert len(wl) == 0 51 | 
-------------------------------------------------------------------------------- /docs/dialects/python/sugar.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | This page is under construction. The content may be incomplete or incorrect. Submit an issue 3 | on [GitHub](https://github.com/QuEraComputing/kirin/issues/new) if you need help or want to 4 | contribute. 5 | 6 | # Dialects for Python Syntax Sugar 7 | 8 | This page contains the dialects designed to represent Python syntax sugar. They provide an implementation 9 | of lowering transform from the corresponding Python AST to the dialects' statements. All the statements are 10 | typed `Any` thus one can always use a custom rewrite pass after type inference to support the desired syntax sugar. 11 | 12 | ## Reference 13 | 14 | ### Indexing 15 | 16 | ::: kirin.dialects.py.indexing 17 | options: 18 | filters: 19 | - "!statement" 20 | show_root_heading: true 21 | show_if_no_docstring: true 22 | 23 | ### Attribute 24 | 25 | ::: kirin.dialects.py.attr 26 | options: 27 | filters: 28 | - "!statement" 29 | show_root_heading: true 30 | show_if_no_docstring: true 31 | -------------------------------------------------------------------------------- /src/kirin/dialects/cf/abstract.py: -------------------------------------------------------------------------------- 1 | from kirin.interp import Successor, MethodTable, AbstractFrame, impl 2 | from kirin.dialects.cf.stmts import Branch, ConditionalBranch 3 | from kirin.analysis.typeinfer import TypeInference 4 | from kirin.dialects.cf.dialect import dialect 5 | 6 | 7 | @dialect.register(key="abstract") 8 | class AbstractMethodTable(MethodTable): 9 | 10 | @impl(Branch) 11 | def branch(self, interp: TypeInference, frame: AbstractFrame, stmt: Branch): 12 | frame.worklist.append( 13 | Successor(stmt.successor, *frame.get_values(stmt.arguments)) 14 | ) 15 | return () 16 | 17 | @impl(ConditionalBranch) 18 | def conditional_branch( 19 
| self, interp: TypeInference, frame: AbstractFrame, stmt: ConditionalBranch 20 | ): 21 | frame.worklist.append( 22 | Successor(stmt.else_successor, *frame.get_values(stmt.else_arguments)) 23 | ) 24 | frame.worklist.append( 25 | Successor(stmt.then_successor, *frame.get_values(stmt.then_arguments)) 26 | ) 27 | return () 28 | -------------------------------------------------------------------------------- /src/kirin/dialects/lowering/range.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from kirin import ir, lowering 4 | from kirin.dialects.py.range import Range as PyRange 5 | from kirin.dialects.ilist.stmts import Range as IListRange 6 | 7 | ilist = ir.Dialect("lowering.range.ilist") 8 | """provides the syntax sugar from built-in range() function to ilist.range() 9 | """ 10 | py = ir.Dialect("lowering.range.py") 11 | """provides the syntax sugar from built-in range() function to py.range() 12 | """ 13 | 14 | 15 | @py.register 16 | class PyLowering(lowering.FromPythonAST): 17 | 18 | @lowering.akin(range) 19 | def lower_Call_range( 20 | self, state: lowering.State, node: ast.Call 21 | ) -> lowering.Result: 22 | return lowering.FromPythonRangeLike().lower(PyRange, state, node) 23 | 24 | 25 | @ilist.register 26 | class IListLowering(lowering.FromPythonAST): 27 | 28 | @lowering.akin(range) 29 | def lower_Call_range( 30 | self, state: lowering.State, node: ast.Call 31 | ) -> lowering.Result: 32 | return lowering.FromPythonRangeLike().lower(IListRange, state, node) 33 | -------------------------------------------------------------------------------- /src/kirin/types.py: -------------------------------------------------------------------------------- 1 | """Bindings for built-in types.""" 2 | 3 | import numbers 4 | 5 | from kirin.ir.attrs.types import ( 6 | Union as Union, 7 | Vararg as Vararg, 8 | AnyType as AnyType, 9 | Generic as Generic, 10 | Literal as Literal, 11 | PyClass as PyClass, 12 | TypeVar as 
TypeVar, 13 | BottomType as BottomType, 14 | FunctionType as FunctionType, 15 | TypeAttribute as TypeAttribute, 16 | TypeofMethodType as TypeofMethodType, 17 | hint2type as hint2type, 18 | is_tuple_of as is_tuple_of, 19 | ) 20 | 21 | Any = AnyType() 22 | Bottom = BottomType() 23 | Int = PyClass(int) 24 | Float = PyClass(float) 25 | Complex = PyClass(complex) 26 | Number = PyClass(numbers.Number) 27 | String = PyClass(str) 28 | Bool = PyClass(bool) 29 | NoneType = PyClass(type(None)) 30 | List = Generic(list, TypeVar("T")) 31 | Slice = Generic(slice, TypeVar("T")) 32 | Tuple = Generic(tuple, Vararg(TypeVar("T"))) 33 | Dict = Generic(dict, TypeVar("K"), TypeVar("V")) 34 | Set = Generic(set, TypeVar("T")) 35 | FrozenSet = Generic(frozenset, TypeVar("T")) 36 | MethodType = TypeofMethodType() 37 | -------------------------------------------------------------------------------- /src/kirin/graph.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterable, Optional, Protocol 2 | 3 | if TYPE_CHECKING: 4 | from kirin import ir 5 | from kirin.print import Printer 6 | 7 | Node = TypeVar("Node") 8 | 9 | 10 | class Graph(Protocol, Generic[Node]): 11 | """The graph interface. 12 | 13 | This interface defines the methods that a graph object must implement. 14 | The graph interface is mainly for compatibility reasons so that one can 15 | use multiple graph implementations interchangeably. 16 | """ 17 | 18 | def get_neighbors(self, node: Node) -> Iterable[Node]: 19 | """Get the neighbors of a node.""" 20 | ... 21 | 22 | def get_nodes(self) -> Iterable[Node]: 23 | """Get all the nodes in the graph.""" 24 | ... 25 | 26 | def get_edges(self) -> Iterable[tuple[Node, Node]]: 27 | """Get all the edges in the graph.""" 28 | ... 29 | 30 | def print( 31 | self, 32 | printer: Optional["Printer"] = None, 33 | analysis: dict["ir.SSAValue", Any] | None = None, 34 | ) -> None: ... 
35 | -------------------------------------------------------------------------------- /test/rules/test_fold_br.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | from kirin.rewrite import Walk, Fixpoint, WrapConst 3 | from kirin.analysis import const 4 | from kirin.rewrite.dce import DeadCodeElimination 5 | from kirin.rewrite.fold import ConstantFold 6 | from kirin.rewrite.compactify import CFGCompactify 7 | 8 | 9 | @basic_no_opt 10 | def branch(x): 11 | if x > 1: 12 | y = x + 1 13 | else: 14 | y = x + 2 15 | 16 | if True: 17 | return y + 1 18 | else: 19 | y + 2 20 | 21 | 22 | def test_branch_elim(): 23 | assert branch(1) == 4 24 | const_prop = const.Propagate(branch.dialects) 25 | frame, ret = const_prop.run(branch) 26 | Walk(Fixpoint(WrapConst(frame))).rewrite(branch.code) 27 | fold = ConstantFold() 28 | Fixpoint(Walk(fold)).rewrite(branch.code) 29 | # TODO: also check the generated CFG 30 | # interp.worklist.visited 31 | Fixpoint(CFGCompactify()).rewrite(branch.code) 32 | Walk(DeadCodeElimination()).rewrite(branch.code) 33 | branch.code.print() 34 | assert len(branch.code.body.blocks) == 4 # type: ignore 35 | -------------------------------------------------------------------------------- /src/kirin/decl/emit/_create_fn.py: -------------------------------------------------------------------------------- 1 | """This module provides a function to create a function dynamically. 2 | 3 | Copied from `dataclasses._create_fn` in Python 3.10.13. 4 | """ 5 | 6 | from dataclasses import MISSING 7 | 8 | 9 | def create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): 10 | # Note that we may mutate locals. Callers beware! 11 | # The only callers are internal to this module, so no 12 | # worries about external callers. 
13 | if locals is None: 14 | locals = {} 15 | return_annotation = "" 16 | if return_type is not MISSING: 17 | locals["_return_type"] = return_type 18 | return_annotation = "->_return_type" 19 | args = ",".join(args) 20 | body = "\n".join(f" {b}" for b in body) 21 | 22 | # Compute the text of the entire function. 23 | txt = f" def {name}({args}){return_annotation}:\n{body}" 24 | 25 | local_vars = ", ".join(locals.keys()) 26 | txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}" 27 | ns = {} 28 | exec(txt, globals, ns) 29 | return ns["__create_fn__"](**locals) 30 | -------------------------------------------------------------------------------- /test/ir/test_verify.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin import ir, lowering 4 | from kirin.decl import statement 5 | from kirin.prelude import basic_no_opt 6 | 7 | dialect = ir.Dialect("foo") 8 | 9 | 10 | @statement(dialect=dialect) 11 | class InvalidStmt(ir.Statement): 12 | traits = frozenset({lowering.FromPythonCall()}) 13 | 14 | def check(self): 15 | raise ValueError("Never triggers") 16 | 17 | 18 | @statement(dialect=dialect) 19 | class InvalidType(ir.Statement): 20 | traits = frozenset({lowering.FromPythonCall()}) 21 | 22 | def check_type(self): 23 | raise ValueError("Never triggers") 24 | 25 | 26 | @ir.dialect_group(basic_no_opt.add(dialect)) 27 | def foo(self): 28 | def run_pass(mt): 29 | pass 30 | 31 | return run_pass 32 | 33 | 34 | def test_invalid_stmt(): 35 | @foo 36 | def test(): 37 | InvalidStmt() 38 | 39 | with pytest.raises(Exception): 40 | test.verify() 41 | 42 | 43 | def test_invalid_type(): 44 | @foo 45 | def test(): 46 | InvalidType() 47 | 48 | with pytest.raises(Exception): 49 | test.verify_type() 50 | -------------------------------------------------------------------------------- /src/kirin/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | """Analysis 
module for kirin. 2 | 3 | This module contains the analysis framework for kirin. The analysis framework is 4 | built on top of the interpreter framework. This module provides a set of base classes 5 | and frameworks for implementing compiler analysis passes on the IR. 6 | 7 | The analysis framework contains the following modules: 8 | 9 | - [`cfg`][kirin.analysis.cfg]: Control flow graph for a given IR. 10 | - [`forward`][kirin.analysis.forward]: Forward dataflow analysis. 11 | - [`callgraph`][kirin.analysis.callgraph]: Call graph for a given IR. 12 | - [`typeinfer`][kirin.analysis.typeinfer]: Type inference analysis. 13 | - [`const`][kirin.analysis.const]: Constant propagation analysis and its lattice. 14 | """ 15 | 16 | from kirin.analysis import const as const 17 | from kirin.analysis.cfg import CFG as CFG 18 | from kirin.analysis.forward import ( 19 | Forward as Forward, 20 | ForwardExtra as ForwardExtra, 21 | ForwardFrame as ForwardFrame, 22 | ) 23 | from kirin.analysis.callgraph import CallGraph as CallGraph 24 | from kirin.analysis.typeinfer import TypeInference as TypeInference 25 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | ruff: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v6 16 | - uses: chartboost/ruff-action@v1 17 | black: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v6 21 | - uses: psf/black@stable 22 | pyright: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v6 26 | - name: Install uv 27 | uses: astral-sh/setup-uv@v7 28 | with: 29 | # Install a specific version of uv.
30 | version: "0.6.14" 31 | enable-cache: true 32 | cache-dependency-glob: "uv.lock" 33 | - name: Install the project 34 | run: uv sync --all-extras --dev 35 | - run: echo "$PWD/.venv/bin" >> $GITHUB_PATH 36 | - uses: jakebailey/pyright-action@v2 37 | with: 38 | pylance-version: latest-release 39 | -------------------------------------------------------------------------------- /src/kirin/worklist.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from queue import SimpleQueue 4 | from typing import Generic, TypeVar, Iterable 5 | 6 | ElemType = TypeVar("ElemType") 7 | 8 | 9 | class WorkList(SimpleQueue, Generic[ElemType]): 10 | """The worklist data structure. 11 | 12 | The worklist is a FIFO queue built on `queue.SimpleQueue`: elements are removed in insertion order, in O(1). 13 | """ 14 | 15 | def __len__(self) -> int: 16 | return self.qsize() 17 | 18 | def __bool__(self) -> bool: 19 | return not self.empty() 20 | 21 | def is_empty(self) -> bool: 22 | return self.empty() 23 | 24 | def append(self, item: ElemType) -> None: 25 | self.put_nowait(item) 26 | 27 | def extend(self, items: Iterable[ElemType]) -> None: 28 | for item in items: 29 | self.put_nowait(item) 30 | 31 | def pop(self) -> ElemType | None: 32 | if self.empty(): 33 | return None 34 | return self.get_nowait() 35 | 36 | 37 | # Remove one function call from critical speed bottleneck 38 | WorkList.is_empty = WorkList.empty 39 | WorkList.append = WorkList.put_nowait 40 | -------------------------------------------------------------------------------- /src/kirin/rewrite/fixpoint.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin.ir import IRNode 4 | from kirin.rewrite.abc import RewriteRule, RewriteResult 5 | 6 | 7 | @dataclass 8 | class Fixpoint(RewriteRule): 9 | """Apply a rewrite rule until a fixpoint is reached.
10 | 11 | The rewrite rule is applied to the node repeatedly until it reports no change, it terminates, or `max_iter` iterations have run. 12 | 13 | ### Parameters 14 | - `rule`: The rewrite rule to apply. 15 | - `max_iter`: The maximum number of iterations to apply the rewrite rule. Default is 32. 16 | """ 17 | 18 | rule: RewriteRule 19 | max_iter: int = 32 20 | 21 | def rewrite(self, node: IRNode) -> RewriteResult: 22 | has_done_something = False 23 | for _ in range(self.max_iter): 24 | result = self.rule.rewrite(node) 25 | if result.terminated: 26 | return result 27 | 28 | if result.has_done_something: 29 | has_done_something = True 30 | else: 31 | return RewriteResult(has_done_something=has_done_something) 32 | 33 | return RewriteResult(exceeded_max_iter=True) 34 | -------------------------------------------------------------------------------- /test/lowering/test_source_info.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.source import SourceInfo 4 | from kirin.prelude import basic_no_opt 5 | 6 | 7 | def get_line_of(target: str) -> int: 8 | for i, line in enumerate(open(__file__), 1): 9 | if target in line: 10 | return i 11 | 12 | 13 | @pytest.mark.parametrize("similar", [True, False]) 14 | def test_stmt_source_info(similar: bool): 15 | @basic_no_opt 16 | def test(x: int): 17 | y = 2 18 | a = 4**2 19 | return y + 2 + a 20 | 21 | if similar: 22 | test = test.similar() 23 | 24 | stmts = test.callable_region.blocks[0].stmts 25 | 26 | def get_line_from_source_info(source: SourceInfo) -> int: 27 | return source.lineno + source.lineno_begin 28 | 29 | for stmt in stmts: 30 | assert stmt.source.file == __file__ 31 | 32 | assert get_line_from_source_info(stmts.at(0).source) == get_line_of("y = 2") 33 | assert get_line_from_source_info(stmts.at(2).source) == get_line_of("a = 4**2") 34 | assert get_line_from_source_info(stmts.at(4).source) == get_line_of( 35 | "return y + 2 + a" 36 | ) 37 |
-------------------------------------------------------------------------------- /test/dialects/test_infer_len.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | from kirin import rewrite 4 | from kirin.prelude import basic 5 | from kirin.dialects import py, ilist 6 | 7 | 8 | def test(): 9 | rule = rewrite.Fixpoint( 10 | rewrite.Walk( 11 | rewrite.Chain( 12 | ilist.rewrite.HintLen(), 13 | rewrite.ConstantFold(), 14 | rewrite.DeadCodeElimination(), 15 | ) 16 | ) 17 | ) 18 | 19 | @basic 20 | def len_func(xs: ilist.IList[int, Literal[3]]): 21 | return len(xs) 22 | 23 | @basic 24 | def len_func3(xs: ilist.IList[int, Any]): 25 | return len(xs) 26 | 27 | rule.rewrite(len_func.code) 28 | rule.rewrite(len_func3.code) 29 | 30 | stmt = len_func.callable_region.blocks[0].stmts.at(0) 31 | assert isinstance(stmt, py.Constant) 32 | assert stmt.value.unwrap() == 3 33 | assert len(len_func.callable_region.blocks[0].stmts) == 2 34 | 35 | stmt = len_func3.callable_region.blocks[0].stmts.at(0) 36 | assert isinstance(stmt, py.Len) 37 | assert len(len_func3.callable_region.blocks[0].stmts) == 2 38 | -------------------------------------------------------------------------------- /src/kirin/ir/traits/region/ssacfg.py: -------------------------------------------------------------------------------- 1 | """SSACFG region trait. 2 | 3 | This module defines the SSACFGRegion trait, which is used to indicate that a 4 | region has an SSACFG graph. 
5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from typing import TYPE_CHECKING, TypeVar 10 | from dataclasses import dataclass 11 | 12 | from kirin.ir.ssa import SSAValue as SSAValue 13 | from kirin.ir.traits.abc import RegionGraph, RegionInterpretationTrait 14 | from kirin.ir.nodes.region import Region 15 | 16 | if TYPE_CHECKING: 17 | from kirin import ir 18 | from kirin.interp.frame import FrameABC 19 | 20 | 21 | @dataclass(frozen=True) 22 | class HasCFG(RegionGraph): 23 | 24 | def get_graph(self, region: ir.Region): 25 | from kirin.analysis.cfg import CFG 26 | 27 | return CFG(region) 28 | 29 | 30 | @dataclass(frozen=True) 31 | class SSACFG(RegionInterpretationTrait): 32 | 33 | ValueType = TypeVar("ValueType") 34 | 35 | @classmethod 36 | def set_region_input( 37 | cls, frame: FrameABC[SSAValue, ValueType], region: Region, *inputs: ValueType 38 | ) -> None: 39 | frame.set_values(region.blocks[0].args, inputs) 40 | -------------------------------------------------------------------------------- /test/analysis/dataflow/typeinfer/test_infer_lambda.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.prelude import structural 3 | from kirin.dialects import ilist 4 | 5 | 6 | def test_infer_lambda(): 7 | @structural(typeinfer=True, fold=False, no_raise=False) 8 | def main(n): 9 | def map_func(i): 10 | return n + 1 11 | 12 | return ilist.map(map_func, ilist.range(4)) 13 | 14 | map_stmt = main.callable_region.blocks[0].stmts.at(-2) 15 | assert isinstance(map_stmt, ilist.Map) 16 | assert map_stmt.result.type == ilist.IListType[types.Int, types.Literal(4)] 17 | 18 | 19 | def test_infer_method_type_hint_call(): 20 | 21 | @structural(typeinfer=True, fold=False, no_raise=False) 22 | def main(n, fx: ir.Method[[int], int]): 23 | return fx(n) 24 | 25 | assert main.return_type == types.Int 26 | 27 | 28 | def test_infer_method_type_hint(): 29 | 30 | @structural(typeinfer=True, fold=False, 
no_raise=False) 31 | def main(n, fx: ir.Method[[int], int]): 32 | def map_func(i): 33 | return n + 1 + fx(i) 34 | 35 | return ilist.map(map_func, ilist.range(4)) 36 | 37 | assert main.return_type == ilist.IListType[types.Int, types.Literal(4)] 38 | -------------------------------------------------------------------------------- /test/program/py/test_cmp.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | 3 | 4 | @basic_no_opt 5 | def not_in(x, y): 6 | return x not in y 7 | 8 | 9 | @basic_no_opt 10 | def in_(x, y): 11 | return x in y 12 | 13 | 14 | @basic_no_opt 15 | def is_(x, y): 16 | return x is y 17 | 18 | 19 | @basic_no_opt 20 | def is_not(x, y): 21 | return x is not y 22 | 23 | 24 | def test_is(): 25 | class Foo: 26 | pass 27 | 28 | a, b = Foo(), Foo() 29 | assert is_(a, b) == (a is b) 30 | assert is_(a, a) == (a is a) 31 | 32 | 33 | def test_is_not(): 34 | class Foo: 35 | pass 36 | 37 | a, b = Foo(), Foo() 38 | assert is_not(a, b) == (a is not b) 39 | assert is_not(a, a) == (a is not a) 40 | 41 | 42 | def test_in(): 43 | assert in_(1, [1, 2, 3]) == (1 in [1, 2, 3]) 44 | assert in_(4, [1, 2, 3]) == (4 in [1, 2, 3]) 45 | assert in_("a", "abc") == ("a" in "abc") 46 | assert in_("d", "abc") == ("d" in "abc") 47 | 48 | 49 | def test_not_in(): 50 | assert not_in(1, [1, 2, 3]) == (1 not in [1, 2, 3]) 51 | assert not_in(4, [1, 2, 3]) == (4 not in [1, 2, 3]) 52 | assert not_in("a", "abc") == ("a" not in "abc") 53 | assert not_in("d", "abc") == ("d" not in "abc") 54 | -------------------------------------------------------------------------------- /test/analysis/dataflow/test_non_pure_const.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types, lowering 2 | from kirin.decl import info, statement 3 | from kirin.prelude import basic_no_opt 4 | from kirin.analysis import const 5 | 6 | dialect = ir.Dialect("mwe") 7 | 8 | 9 | 
@statement(dialect=dialect) 10 | class SideEffect(ir.Statement): 11 | name = "side_effect" 12 | traits = frozenset({lowering.FromPythonCall()}) 13 | value: ir.SSAValue = info.argument(types.Int) 14 | 15 | 16 | @basic_no_opt.add(dialect) 17 | def recursion(kernel, n: int, pos: int): 18 | if pos == n: 19 | return 20 | 21 | kernel(pos) 22 | recursion(kernel, n, pos + 1) 23 | 24 | 25 | @basic_no_opt.add(dialect) 26 | def side_effect(pos: int): 27 | SideEffect(pos) # type: ignore 28 | 29 | 30 | def test_non_pure_const(): 31 | constprop = const.Propagate(basic_no_opt) 32 | frame, ret = constprop.run( 33 | recursion, 34 | const.Value(side_effect), 35 | const.Result.top(), 36 | const.Result.top(), 37 | ) 38 | # recursion.print(analysis=frame.entries) 39 | ret = frame.entries[recursion.callable_region.blocks[2].stmts.at(3).results[0]] 40 | assert isinstance(ret, const.Value) 41 | assert frame.frame_is_not_pure 42 | -------------------------------------------------------------------------------- /test/rules/test_fold.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt, python_no_opt 2 | from kirin.rewrite import Walk, Fixpoint, WrapConst 3 | from kirin.analysis import const 4 | from kirin.rewrite.fold import ConstantFold 5 | 6 | 7 | @basic_no_opt 8 | def foldable(x: int) -> int: 9 | y = 1 10 | b = y + 2 11 | c = y + b 12 | d = c + 4 13 | return d + x 14 | 15 | 16 | def test_const_fold(): 17 | before = foldable(1) 18 | const_prop = const.Propagate(foldable.dialects) 19 | frame, _ = const_prop.run(foldable) 20 | Fixpoint(Walk(WrapConst(frame))).rewrite(foldable.code) 21 | fold = ConstantFold() 22 | Fixpoint(Walk(fold)).rewrite(foldable.code) 23 | after = foldable(1) 24 | 25 | assert before == after 26 | 27 | 28 | def test_const_fold_subroutine(): 29 | 30 | @python_no_opt 31 | def non_pure_subroutine(x: list[int]) -> None: 32 | x.append(1) 33 | 34 | @python_no_opt 35 | def main(): 36 | x = [] 37 | 
non_pure_subroutine(x) 38 | x.append(2) 39 | 40 | old_main_region = main.callable_region.clone() 41 | 42 | fold = ConstantFold() 43 | Fixpoint(Walk(fold)).rewrite(main.code) 44 | 45 | assert old_main_region.is_structurally_equal(main.callable_region) 46 | -------------------------------------------------------------------------------- /test/lowering/test_with.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, lowering 2 | from kirin.decl import info, statement 3 | from kirin.prelude import python_no_opt 4 | from kirin.dialects import cf, py, func 5 | 6 | dialect = ir.Dialect("test") 7 | 8 | 9 | @statement(dialect=dialect) 10 | class Adjoint(ir.Statement): 11 | traits = frozenset({lowering.FromPythonWithSingleItem()}) 12 | body: ir.Region = info.region() 13 | result: ir.ResultValue = info.result() 14 | 15 | 16 | def with_example(x): 17 | y = 1 18 | with Adjoint() as f: # type: ignore 19 | y = x + 1 20 | return y, f 21 | 22 | 23 | def test_with_lowering(): 24 | lower = lowering.Python(python_no_opt.union([cf, func, dialect])) 25 | code = lower.python_function(with_example) 26 | code.print() 27 | assert isinstance(code, func.Function) 28 | stmts = code.body.blocks[0].stmts 29 | assert isinstance(stmts.at(0), py.Constant) 30 | adjoint = stmts.at(1) 31 | assert isinstance(adjoint, Adjoint) 32 | assert len(adjoint.body.blocks) == 1 33 | add = adjoint.body.blocks[0].stmts.at(1) 34 | assert isinstance(add, py.Add) 35 | assert isinstance(add.lhs, ir.BlockArgument) 36 | assert isinstance(add.rhs, ir.SSAValue) 37 | assert adjoint.result.name == "f" 38 | -------------------------------------------------------------------------------- /test/rules/test_apply_type.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.prelude import basic 3 | from kirin.analysis import const 4 | 5 | 6 | @basic(typeinfer=True, fold=True) 7 | def unstable(x: int): 
# type: ignore 8 | y = x + 1 9 | if y > 10: 10 | z = y 11 | else: 12 | z = y + 1.2 13 | return z 14 | 15 | 16 | def test_apply_type(): 17 | def stmt_at(block_id, stmt_id: int) -> ir.Statement: 18 | return unstable.callable_region.blocks[block_id].stmts.at(stmt_id) # type: ignore 19 | 20 | assert stmt_at(0, 0).results.types == [types.Int] 21 | assert stmt_at(0, 0).results[0].hints.get("const") == const.Value(1) 22 | assert stmt_at(0, 1).results.types == [types.Int] 23 | assert stmt_at(0, 2).results.types == [types.Int] 24 | assert stmt_at(0, 2).results[0].hints.get("const") == const.Value(10) 25 | assert stmt_at(0, 3).results.types == [types.Bool] 26 | 27 | assert stmt_at(1, 0).results.types == [types.Int] 28 | assert stmt_at(2, 0).results.types == [types.Float] 29 | assert stmt_at(2, 0).results[0].hints.get("const") == const.Value(1.2) 30 | assert stmt_at(2, 1).results.types == [types.Float] 31 | 32 | stmt = stmt_at(3, 0) 33 | assert stmt.args[0].type == (types.Int | types.Float) 34 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/cmp/stmts.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types, lowering 2 | from kirin.decl import info, statement 3 | 4 | from ._dialect import dialect 5 | 6 | 7 | @statement 8 | class Cmp(ir.Statement): 9 | traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 10 | lhs: ir.SSAValue = info.argument() 11 | rhs: ir.SSAValue = info.argument() 12 | result: ir.ResultValue = info.result(types.Bool) 13 | 14 | 15 | @statement(dialect=dialect) 16 | class Eq(Cmp): 17 | name = "eq" 18 | 19 | 20 | @statement(dialect=dialect) 21 | class NotEq(Cmp): 22 | name = "ne" 23 | 24 | 25 | @statement(dialect=dialect) 26 | class Lt(Cmp): 27 | name = "lt" 28 | 29 | 30 | @statement(dialect=dialect) 31 | class Gt(Cmp): 32 | name = "gt" 33 | 34 | 35 | @statement(dialect=dialect) 36 | class LtE(Cmp): 37 | name = "lte" 38 | 39 | 40 | 
@statement(dialect=dialect) 41 | class GtE(Cmp): 42 | name = "gte" 43 | 44 | 45 | @statement(dialect=dialect) 46 | class Is(Cmp): 47 | name = "is" 48 | 49 | 50 | @statement(dialect=dialect) 51 | class IsNot(Cmp): 52 | name = "is_not" 53 | 54 | 55 | @statement(dialect=dialect) 56 | class In(Cmp): 57 | name = "in" 58 | 59 | 60 | @statement(dialect=dialect) 61 | class NotIn(Cmp): 62 | name = "not_in" 63 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/passes.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.rewrite import Walk, Chain, Fixpoint 3 | from kirin.passes.abc import Pass 4 | from kirin.rewrite.abc import RewriteResult 5 | from kirin.dialects.ilist.rewrite import List2IList, ConstList2IList 6 | 7 | 8 | class IListDesugar(Pass): 9 | """This pass desugars the Python list dialect 10 | to the immutable list dialect by rewriting all 11 | constant `list` type into `IList` type. 12 | """ 13 | 14 | def unsafe_run(self, mt: ir.Method) -> RewriteResult: 15 | for arg in mt.args: 16 | _check_list(arg.type, arg.type) 17 | return Fixpoint(Walk(Chain(ConstList2IList(), List2IList()))).rewrite(mt.code) 18 | 19 | 20 | def _check_list(total: types.TypeAttribute, type_: types.TypeAttribute): 21 | if isinstance(type_, types.Generic): 22 | _check_list(total, type_.body) 23 | for var in type_.vars: 24 | _check_list(total, var) 25 | if type_.vararg: 26 | _check_list(total, type_.vararg.typ) 27 | elif isinstance(type_, types.PyClass): 28 | if issubclass(type_.typ, list): 29 | raise TypeError( 30 | f"Invalid type {total} for this kernel, use IList instead of {type_}." 
31 | ) 32 | return 33 | -------------------------------------------------------------------------------- /src/kirin/passes/aggressive/unroll.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, dataclass 2 | 3 | from kirin.ir import Method 4 | from kirin.passes import Fold, Pass, TypeInfer 5 | from kirin.rewrite import Walk 6 | from kirin.rewrite.abc import RewriteResult 7 | from kirin.dialects.scf.unroll import ForLoop, PickIfElse 8 | 9 | 10 | @dataclass 11 | class UnrollScf(Pass): 12 | """This pass can be used to unroll scf.For loops and inline/expand scf.IfElse when 13 | their inputs are known at compile time. 14 | 15 | usage: 16 | UnrollScf(dialects).fixpoint(method) 17 | 18 | Note: This pass should be used in a fixpoint manner, to unroll nested scf nodes. 19 | 20 | """ 21 | 22 | typeinfer: TypeInfer = field(init=False) 23 | fold: Fold = field(init=False) 24 | 25 | def __post_init__(self): 26 | self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) 27 | self.fold = Fold(self.dialects, no_raise=self.no_raise) 28 | 29 | def unsafe_run(self, mt: Method): 30 | result = RewriteResult() 31 | result = Walk(PickIfElse()).rewrite(mt.code).join(result) 32 | result = Walk(ForLoop()).rewrite(mt.code).join(result) 33 | result = self.fold.unsafe_run(mt).join(result) 34 | self.typeinfer.unsafe_run(mt) 35 | return result 36 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/range.py: -------------------------------------------------------------------------------- 1 | """The range dialect for Python. 2 | 3 | This dialect models the builtin `range()` function in Python. 4 | 5 | The dialect includes: 6 | - The `Range` statement class. 7 | - The lowering pass for the `range()` function. 8 | 9 | This dialect does not include a concrete implementation of the `range()` function, and its only 10 | type inference rule gives `Int` as the element type (`eltype`) of a `range`.
One needs to use another dialect for the concrete 11 | implementation and type inference, e.g., the `ilist` dialect. 12 | """ 13 | 14 | from kirin import ir, types, interp, lowering 15 | from kirin.decl import info, statement 16 | from kirin.dialects import eltype 17 | 18 | dialect = ir.Dialect("py.range") 19 | 20 | 21 | @statement(dialect=dialect) 22 | class Range(ir.Statement): 23 | name = "range" 24 | traits = frozenset({ir.Pure(), lowering.FromPythonRangeLike()}) 25 | start: ir.SSAValue = info.argument(types.Int) 26 | stop: ir.SSAValue = info.argument(types.Int) 27 | step: ir.SSAValue = info.argument(types.Int) 28 | result: ir.ResultValue = info.result(types.PyClass(range)) 29 | 30 | 31 | @dialect.register(key="typeinfer") 32 | class TypeInfer(interp.MethodTable): 33 | 34 | @interp.impl(eltype.ElType, types.PyClass(range)) 35 | def eltype_range(self, interp_, frame: interp.Frame, stmt: eltype.ElType): 36 | return (types.Int,) 37 | -------------------------------------------------------------------------------- /src/kirin/passes/aggressive/fold.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, dataclass 2 | 3 | from kirin.passes import Pass 4 | from kirin.rewrite import ( 5 | Walk, 6 | Chain, 7 | Inline, 8 | Fixpoint, 9 | Call2Invoke, 10 | ConstantFold, 11 | CFGCompactify, 12 | InlineGetItem, 13 | InlineGetField, 14 | DeadCodeElimination, 15 | ) 16 | from kirin.ir.method import Method 17 | from kirin.rewrite.abc import RewriteResult 18 | from kirin.passes.hint_const import HintConst 19 | 20 | 21 | @dataclass 22 | class Fold(Pass): 23 | hint_const: HintConst = field(init=False) 24 | 25 | def __post_init__(self): 26 | self.hint_const = HintConst(self.dialects) 27 | self.hint_const.no_raise = self.no_raise 28 | 29 | def unsafe_run(self, mt: Method) -> RewriteResult: 30 | result = self.hint_const.unsafe_run(mt) 31 | rule = Chain( 32 | ConstantFold(), 33 | Call2Invoke(), 34 | InlineGetField(), 35 |
InlineGetItem(), 36 | DeadCodeElimination(), 37 | ) 38 | result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) 39 | result = Walk(Inline(lambda _: True)).rewrite(mt.code).join(result) 40 | result = Fixpoint(CFGCompactify()).rewrite(mt.code).join(result) 41 | return result 42 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/__init__.py: -------------------------------------------------------------------------------- 1 | """Python dialects module. 2 | 3 | This module contains a set of dialects that represent 4 | different fractions of the Python language. The dialects 5 | are designed to be used in a union to represent the 6 | entire Python language. 7 | """ 8 | 9 | from . import ( 10 | cmp as cmp, 11 | len as len, 12 | attr as attr, 13 | base as base, 14 | list as list, 15 | binop as binop, 16 | range as range, 17 | slice as slice, 18 | tuple as tuple, 19 | unary as unary, 20 | assign as assign, 21 | boolop as boolop, 22 | unpack as unpack, 23 | builtin as builtin, 24 | constant as constant, 25 | indexing as indexing, 26 | iterable as iterable, 27 | ) 28 | from .len import Len as Len 29 | from .attr import GetAttr as GetAttr 30 | from .range import Range as Range 31 | from .slice import Slice as Slice 32 | from .assign import Alias as Alias, SetItem as SetItem 33 | from .boolop import Or as Or, And as And 34 | from .builtin import Abs as Abs, Sum as Sum 35 | from .constant import Constant as Constant 36 | from .indexing import GetItem as GetItem, PyGetItemLike as PyGetItemLike 37 | from .cmp.stmts import * # noqa: F403 38 | from .list.stmts import Append as Append 39 | from .binop.stmts import * # noqa: F403 40 | from .unary.stmts import * # noqa: F403 41 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/base.py: -------------------------------------------------------------------------------- 1 | """Base dialect for Python. 
2 | 3 | This dialect does not contain statements. It contains 4 | lowering rules for `ast.Name` and `ast.Expr`, plus an `emit.julia` method table that renders `ir.PyAttr` as the `repr` of its data. 5 | """ 6 | 7 | import ast 8 | 9 | from kirin import ir, interp, lowering 10 | 11 | dialect = ir.Dialect("py.base") 12 | 13 | 14 | @dialect.register 15 | class PythonLowering(lowering.FromPythonAST): 16 | 17 | def lower_Name(self, state: lowering.State, node: ast.Name) -> lowering.Result: 18 | name = node.id 19 | if isinstance(node.ctx, ast.Load): 20 | value = state.current_frame.get(name) 21 | if value is None: 22 | raise lowering.BuildError(f"{name} is not defined") 23 | return value 24 | elif isinstance(node.ctx, ast.Store): 25 | raise lowering.BuildError("unhandled store operation") 26 | else: # Del 27 | raise lowering.BuildError("unhandled del operation") 28 | 29 | def lower_Expr(self, state: lowering.State, node: ast.Expr) -> lowering.Result: 30 | return state.parent.visit(state, node.value) 31 | 32 | 33 | @dialect.register(key="emit.julia") 34 | class PyAttrMethod(interp.MethodTable): 35 | 36 | @interp.impl(ir.PyAttr) 37 | def py_attr(self, interp, frame: interp.Frame, node: ir.PyAttr): 38 | return repr(node.data) 39 | -------------------------------------------------------------------------------- /src/kirin/rewrite/chain.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from dataclasses import dataclass 3 | 4 | from kirin.ir import IRNode 5 | from kirin.rewrite.abc import RewriteRule, RewriteResult 6 | 7 | 8 | @dataclass 9 | class Chain(RewriteRule): 10 | """Chain multiple rewrites together. 11 | 12 | The chain will apply each rewrite in order until one of the rewrites terminates.
13 | """ 14 | 15 | rules: list[RewriteRule] 16 | 17 | def __init__(self, rule: RewriteRule | Iterable[RewriteRule], *others: RewriteRule): 18 | if isinstance(rule, RewriteRule): 19 | self.rules = [rule, *others] 20 | else: 21 | assert ( 22 | others == () 23 | ), "Cannot pass multiple positional arguments if the first argument is an iterable" 24 | self.rules = list(rule) 25 | 26 | def rewrite(self, node: IRNode) -> RewriteResult: 27 | has_done_something = False 28 | for rule in self.rules: 29 | result = rule.rewrite(node) 30 | if result.terminated: 31 | return result 32 | 33 | if result.has_done_something: 34 | has_done_something = True 35 | return RewriteResult(has_done_something=has_done_something) 36 | 37 | def __repr__(self): 38 | return " -> ".join(map(str, self.rules)) 39 | -------------------------------------------------------------------------------- /src/kirin/dialects/scf/interp.py: -------------------------------------------------------------------------------- 1 | from kirin import interp 2 | 3 | from .stmts import For, Yield, IfElse 4 | from ._dialect import dialect 5 | 6 | 7 | @dialect.register 8 | class Concrete(interp.MethodTable): 9 | 10 | @interp.impl(Yield) 11 | def yield_stmt(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: Yield): 12 | return interp.YieldValue(frame.get_values(stmt.values)) 13 | 14 | @interp.impl(IfElse) 15 | def if_else(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: IfElse): 16 | cond = frame.get(stmt.cond) 17 | if cond: 18 | body = stmt.then_body 19 | else: 20 | body = stmt.else_body 21 | return interp_.frame_call_region(frame, stmt, body, cond) 22 | 23 | @interp.impl(For) 24 | def for_loop(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: For): 25 | iterable = frame.get(stmt.iterable) 26 | loop_vars = frame.get_values(stmt.initializers) 27 | for value in iterable: 28 | loop_vars = interp_.frame_call_region( 29 | frame, stmt, stmt.body, value, *loop_vars 30 | ) 31 | if 
isinstance(loop_vars, interp.ReturnValue): 32 | return loop_vars 33 | elif loop_vars is None: 34 | loop_vars = () 35 | return loop_vars 36 | -------------------------------------------------------------------------------- /src/kirin/ir/traits/__init__.py: -------------------------------------------------------------------------------- 1 | """Kirin IR Traits. 2 | 3 | This module defines the traits that can be used to define the behavior of 4 | Kirin IR nodes. The base trait is `StmtTrait`, which is a `dataclass` that 5 | implements the `__hash__` and `__eq__` methods. 6 | 7 | There are also some basic traits that are provided for convenience, such as 8 | `Pure`, `HasParent`, `ConstantLike`, `IsTerminator`, `NoTerminator`, and 9 | `IsolatedFromAbove`. 10 | """ 11 | 12 | from .abc import ( 13 | Trait as Trait, 14 | AttrTrait as AttrTrait, 15 | StmtTrait as StmtTrait, 16 | RegionGraph as RegionGraph, 17 | RegionInterpretationTrait as RegionInterpretationTrait, 18 | ) 19 | from .basic import ( 20 | Pure as Pure, 21 | HasParent as HasParent, 22 | MaybePure as MaybePure, 23 | ConstantLike as ConstantLike, 24 | IsTerminator as IsTerminator, 25 | NoTerminator as NoTerminator, 26 | IsolatedFromAbove as IsolatedFromAbove, 27 | ) 28 | from .symbol import ( 29 | SymbolTable as SymbolTable, 30 | SymbolOpInterface as SymbolOpInterface, 31 | EntryPointInterface as EntryPointInterface, 32 | ) 33 | from .callable import ( 34 | StaticCall as StaticCall, 35 | HasSignature as HasSignature, 36 | CallableStmtInterface as CallableStmtInterface, 37 | ) 38 | from .region.ssacfg import SSACFG as SSACFG, HasCFG as HasCFG 39 | -------------------------------------------------------------------------------- /.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | name: Documentation (preview) 2 | on: 3 | pull_request: 4 | types: 5 | - opened 6 | - reopened 7 | - synchronize 8 | - closed 9 | concurrency: 10 | group: ${{ github.workflow 
}}-${{ github.event.pull_request.number || github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | documentation: 15 | name: Deploy preview documentation 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v6 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v7 21 | with: 22 | # Install a specific version of uv. 23 | version: "0.6.14" 24 | enable-cache: true 25 | cache-dependency-glob: "uv.lock" 26 | - name: Install Documentation dependencies 27 | run: uv sync --group doc 28 | - name: Set up build cache 29 | uses: actions/cache@v5 30 | id: cache 31 | with: 32 | key: mkdocs-material-${{ github.ref }} 33 | path: .cache 34 | restore-keys: | 35 | mkdocs-material- 36 | - name: Build documentation 37 | env: 38 | GH_TOKEN: ${{ secrets.GH_TOKEN }} 39 | run: | 40 | uv run mkdocs build 41 | - name: Deploy preview 42 | uses: rossjrw/pr-preview-action@v1 43 | with: 44 | source-dir: ./site 45 | -------------------------------------------------------------------------------- /test/dialects/pystmts/test_range.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from kirin import types 4 | from kirin.prelude import basic 5 | from kirin.analysis import const 6 | from kirin.dialects.py.range import Range 7 | 8 | 9 | @basic 10 | def new_range(a: int, b: int, c: int): 11 | x = range(a) 12 | y = range(a, b) 13 | z = range(a, b, c) 14 | return x, y, z 15 | 16 | 17 | new_range.print() 18 | 19 | 20 | def test_new_range(): 21 | stmt = cast(Range, new_range.callable_region.blocks[0].stmts.at(2)) 22 | assert isinstance(hint := stmt.start.hints.get("const"), const.Value) 23 | assert hint.data == 0 24 | assert stmt.stop.type.is_subseteq(types.Int) 25 | assert isinstance(hint := stmt.step.hints.get("const"), const.Value) 26 | assert hint.data == 1 27 | 28 | stmt = cast(Range, new_range.callable_region.blocks[0].stmts.at(4)) 29 | assert stmt.start.type.is_subseteq(types.Int) 30 | assert
stmt.stop.type.is_subseteq(types.Int) 31 | assert stmt.step.type.is_subseteq(types.Int) 32 | assert isinstance(hint := stmt.step.hints.get("const"), const.Value) 33 | assert hint.data == 1 34 | 35 | stmt = cast(Range, new_range.callable_region.blocks[0].stmts.at(5)) 36 | assert stmt.start.type.is_subseteq(types.Int) 37 | assert stmt.stop.type.is_subseteq(types.Int) 38 | assert stmt.step.type.is_subseteq(types.Int) 39 | -------------------------------------------------------------------------------- /src/kirin/rewrite/getitem.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin import ir 4 | from kirin.analysis import const 5 | from kirin.dialects import py 6 | from kirin.rewrite.abc import RewriteRule, RewriteResult 7 | 8 | 9 | @dataclass 10 | class InlineGetItem(RewriteRule): 11 | 12 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 13 | if not isinstance(node, py.indexing.GetItem): 14 | return RewriteResult() 15 | 16 | if not isinstance(node.obj.owner, py.tuple.New): 17 | return RewriteResult() 18 | 19 | if not isinstance(index_value := node.index.hints.get("const"), const.Value): 20 | return RewriteResult() 21 | 22 | if not node.result.uses: 23 | return RewriteResult() 24 | 25 | stmt = node.obj.owner 26 | index = index_value.data 27 | if isinstance(index, int) and ( 28 | 0 <= index < len(stmt.args) or -len(stmt.args) <= index < 0 29 | ): 30 | node.result.replace_by(stmt.args[index]) 31 | return RewriteResult(has_done_something=True) 32 | elif isinstance(index, slice): 33 | new_tuple = py.tuple.New(tuple(stmt.args[index])) 34 | node.replace_by(new_tuple) 35 | return RewriteResult(has_done_something=True) 36 | else: 37 | return RewriteResult() 38 | -------------------------------------------------------------------------------- /src/kirin/interp/undefined.py: -------------------------------------------------------------------------------- 1 | """This module 
provides a singleton class `Undefined` and `UndefinedType` 2 | that represents an undefined value in the Kirin interpreter. 3 | 4 | The `Undefined` class is a singleton that can be used to represent an 5 | undefined value in the interpreter. It is used to indicate that a value 6 | has not been set or is not available. This is used to distinguish between 7 | an undefined value and a Python `None` value. 8 | """ 9 | 10 | from typing_extensions import TypeIs 11 | 12 | 13 | class UndefinedMeta(type): 14 | 15 | def __init__(cls, name, bases, attrs): 16 | super().__init__(name, bases, attrs) 17 | cls._instance = None 18 | 19 | def __call__(cls): 20 | if cls._instance is None: 21 | cls._instance = super().__call__() 22 | return cls._instance 23 | 24 | 25 | class UndefinedType(metaclass=UndefinedMeta): 26 | pass 27 | 28 | 29 | Undefined = UndefinedType() 30 | """Singleton instance of `UndefinedType` that represents an undefined value.""" 31 | 32 | 33 | def is_undefined(value: object) -> TypeIs[UndefinedType]: 34 | """Check if the given value is an instance of `UndefinedType`. 35 | 36 | Args: 37 | value (object): The value to check. 38 | 39 | Returns: 40 | bool: True if the value is an instance of `UndefinedType`, False otherwise. 
41 | """ 42 | return value is Undefined 43 | -------------------------------------------------------------------------------- /src/kirin/lowering/stream.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar, Sequence 2 | from dataclasses import field, dataclass 3 | 4 | Stmt = TypeVar("Stmt") 5 | 6 | 7 | @dataclass 8 | class StmtStream(Generic[Stmt]): 9 | stmts: list[Stmt] = field(default_factory=list) 10 | cursor: int = 0 11 | 12 | def __init__(self, stmts: Sequence[Stmt], cursor: int = 0): 13 | self.stmts = list(stmts) 14 | self.cursor = cursor 15 | 16 | def __iter__(self): 17 | return self 18 | 19 | def __next__(self): 20 | if self.cursor < len(self.stmts): 21 | stmt = self.stmts[self.cursor] 22 | self.cursor += 1 23 | return stmt 24 | else: 25 | raise StopIteration 26 | 27 | def peek(self): 28 | return self.stmts[self.cursor] 29 | 30 | def split(self) -> "StmtStream": 31 | cursor = self.cursor 32 | self.cursor = len(self.stmts) 33 | return StmtStream(self.stmts, cursor) 34 | 35 | def __len__(self): 36 | return len(self.stmts) 37 | 38 | def __getitem__(self, key): 39 | return self.stmts[key] 40 | 41 | def __setitem__(self, key, value): 42 | self.stmts[key] = value 43 | 44 | def pop(self): 45 | stmt = self.stmts[self.cursor] 46 | self.cursor += 1 47 | return stmt 48 | 49 | def __bool__(self): 50 | return self.cursor < len(self.stmts) 51 | -------------------------------------------------------------------------------- /docs/scripts/gen_ref_nav.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | SRC_PATH = "src" 8 | 9 | skip_keywords = [] 10 | 11 | nav = mkdocs_gen_files.Nav() 12 | for path in sorted(Path(SRC_PATH).rglob("*.py")): 13 | module_path = path.relative_to(SRC_PATH).with_suffix("") 14 | doc_path = 
path.relative_to(SRC_PATH).with_suffix(".md") 15 | full_doc_path = Path("reference", doc_path) 16 | 17 | iskip = False 18 | 19 | for kwrd in skip_keywords: 20 | if kwrd in str(doc_path): 21 | iskip = True 22 | break 23 | if iskip: 24 | print("[Ignore]", str(doc_path)) 25 | continue 26 | 27 | print("[>]", str(doc_path)) 28 | 29 | parts = tuple(module_path.parts) 30 | 31 | if parts[-1] == "__init__": 32 | parts = parts[:-1] 33 | doc_path = doc_path.with_name("index.md") 34 | full_doc_path = full_doc_path.with_name("index.md") 35 | elif parts[-1].startswith("_"): 36 | continue 37 | 38 | nav[parts] = doc_path.as_posix() 39 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 40 | ident = ".".join(parts) 41 | fd.write(f"::: {ident}") 42 | 43 | mkdocs_gen_files.set_edit_path(full_doc_path, ".." / path) 44 | 45 | with mkdocs_gen_files.open("reference/SUMMARY.txt", "w") as nav_file: 46 | nav_file.writelines(nav.build_literate_nav()) 47 | -------------------------------------------------------------------------------- /test/rules/test_alias.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic_no_opt 2 | from kirin.rewrite import Walk, Chain, Fixpoint, WrapConst 3 | from kirin.analysis import const 4 | from kirin.rewrite.dce import DeadCodeElimination 5 | from kirin.rewrite.alias import InlineAlias 6 | 7 | 8 | @basic_no_opt 9 | def main_simplify_alias(x: int): 10 | y = x + 1 11 | z = y 12 | z2 = z 13 | return z2 14 | 15 | 16 | def test_alias_inline(): 17 | constprop = const.Propagate(main_simplify_alias.dialects) 18 | frame, ret = constprop.run(main_simplify_alias) 19 | Fixpoint(Walk(WrapConst(frame))).rewrite(main_simplify_alias.code) 20 | Fixpoint(Walk(Chain([InlineAlias(), DeadCodeElimination()]))).rewrite( 21 | main_simplify_alias.code 22 | ) 23 | assert len(main_simplify_alias.callable_region.blocks[0].stmts) == 3 24 | 25 | 26 | @basic_no_opt 27 | def simplify_alias_ref_const(): 28 | y = 3 29 | z = y 
30 | return z 31 | 32 | 33 | def test_alias_inline2(): 34 | constprop = const.Propagate(simplify_alias_ref_const.dialects) 35 | frame, _ = constprop.run(simplify_alias_ref_const) 36 | Fixpoint(Walk(WrapConst(frame))).rewrite(simplify_alias_ref_const.code) 37 | Fixpoint(Walk(Chain([InlineAlias(), DeadCodeElimination()]))).rewrite( 38 | simplify_alias_ref_const.code 39 | ) 40 | simplify_alias_ref_const.code.print() 41 | assert len(simplify_alias_ref_const.callable_region.blocks[0].stmts) == 2 42 | -------------------------------------------------------------------------------- /test/lowering/test_with_binding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Generator 2 | from contextlib import contextmanager 3 | 4 | from kirin import ir, lowering 5 | from kirin.decl import info, statement 6 | from kirin.prelude import structural_no_opt 7 | from kirin.dialects import ilist 8 | 9 | dialect = ir.Dialect("with_binding") 10 | 11 | 12 | @statement(dialect=dialect) 13 | class ContextStatatement(ir.Statement): 14 | traits = frozenset({lowering.FromPythonWithSingleItem()}) 15 | body: ir.Region = info.region(multi=False) 16 | 17 | 18 | @ir.dialect_group(structural_no_opt.add(dialect)) 19 | def dummy(self): 20 | 21 | def run_pass(mt): 22 | 23 | return mt 24 | 25 | return run_pass 26 | 27 | 28 | @lowering.wraps(ContextStatatement) 29 | @contextmanager 30 | def context_statement() -> Generator[Any, None, None]: ...
31 | 32 | 33 | @dummy 34 | def with_binding(): 35 | x = 1 36 | 37 | def fn(x): 38 | return x**2 39 | 40 | with context_statement(): 41 | with context_statement(): 42 | x = ilist.map(fn, ilist.range(10)) 43 | 44 | return x 45 | 46 | 47 | def test_with_binding(): 48 | stmt = with_binding.callable_region.blocks[0].stmts.at(-2) 49 | assert isinstance(stmt, ContextStatatement) 50 | assert len(stmt.body.blocks) == 1 51 | stmt = stmt.body.blocks[0].stmts.at(0) 52 | assert isinstance(stmt, ContextStatatement) 53 | assert len(stmt.body.blocks[0].stmts) == 5 54 | -------------------------------------------------------------------------------- /test/dialects/test_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin import ir, types 4 | from kirin.prelude import basic 5 | from kirin.dialects import py, func, module 6 | 7 | 8 | def test_module(): 9 | fn1 = func.Function( 10 | sym_name="foo", 11 | slots=(), 12 | body=ir.Region( 13 | ir.Block( 14 | [ 15 | x := py.Constant(1), 16 | func.Return(x.result), 17 | ], 18 | argtypes=(types.Any,), 19 | ) 20 | ), 21 | signature=func.Signature(inputs=(), output=types.Int), 22 | ) 23 | 24 | fn2 = func.Function( 25 | sym_name="main", 26 | slots=(), 27 | body=ir.Region( 28 | ir.Block( 29 | [ 30 | x := module.Invoke((), (), callee="foo"), 31 | func.Return(x.result), 32 | ], 33 | argtypes=(types.Any,), 34 | ) 35 | ), 36 | signature=func.Signature(inputs=(), output=types.Int), 37 | ) 38 | 39 | mod = module.Module( 40 | sym_name="test_module", entry="main", body=ir.Region(ir.Block([fn1, fn2])) 41 | ) 42 | 43 | dialects = basic.add(module) 44 | method = ir.Method(dialects=dialects, code=mod) 45 | method.print() 46 | 47 | with pytest.raises(KeyError): 48 | method() 49 | 50 | dialects.update_symbol_table(method) 51 | assert method() == 1 52 | -------------------------------------------------------------------------------- /test/interp/test_select.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import pytest 4 | 5 | from kirin import interp 6 | from kirin.lattice import EmptyLattice 7 | from kirin.prelude import basic 8 | from kirin.dialects import py 9 | from kirin.ir.method import Method 10 | from kirin.ir.nodes.stmt import Statement 11 | from kirin.analysis.forward import Forward, ForwardFrame 12 | 13 | 14 | @dataclass 15 | class DummyInterpreter(Forward[EmptyLattice]): 16 | keys = ("test_interp",) 17 | lattice = EmptyLattice 18 | 19 | def method_self(self, method: Method) -> EmptyLattice: 20 | return EmptyLattice() 21 | 22 | def eval_fallback( 23 | self, frame: ForwardFrame[EmptyLattice], node: Statement 24 | ) -> interp.StatementResult[EmptyLattice]: 25 | ret = super().eval_fallback(frame, node) 26 | print("fallback: ", ret) 27 | return ret 28 | 29 | 30 | @py.tuple.dialect.register(key="test_interp") 31 | class DialectMethodTable(interp.MethodTable): 32 | 33 | @interp.impl(py.tuple.New) 34 | def new_tuple(self, interp: DummyInterpreter, frame, stmt: py.tuple.New): 35 | return (EmptyLattice(),) 36 | 37 | 38 | @basic 39 | def main(x): 40 | return 1 41 | 42 | 43 | def test_interp(): 44 | interp_ = DummyInterpreter(basic) 45 | with pytest.raises(NotImplementedError): 46 | interp_.run(main, EmptyLattice()) 47 | 48 | interp_ = DummyInterpreter(basic) 49 | interp_.run_no_raise(main, EmptyLattice()) 50 | -------------------------------------------------------------------------------- /example/food/rewrite.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from stmts import Eat, Nap, NewFood, RandomBranch 4 | 5 | from kirin import ir 6 | from kirin.dialects import cf 7 | from kirin.rewrite.abc import RewriteRule, RewriteResult 8 | 9 | 10 | @dataclass 11 | class RandomWalkBranch(RewriteRule): 12 | 13 | def rewrite_Statement(self, node: ir.Statement) -> 
RewriteResult: 14 | if not isinstance(node, cf.ConditionalBranch): 15 | return RewriteResult() 16 | node.replace_by( 17 | RandomBranch( 18 | cond=node.cond, 19 | then_arguments=node.then_arguments, 20 | then_successor=node.then_successor, 21 | else_arguments=node.else_arguments, 22 | else_successor=node.else_successor, 23 | ) 24 | ) 25 | return RewriteResult(has_done_something=True) 26 | 27 | 28 | @dataclass 29 | class NewFoodAndNap(RewriteRule): 30 | # sometimes someone is hungry and needs a nap 31 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 32 | if not isinstance(node, Eat): 33 | return RewriteResult() 34 | 35 | # 1. create new stmts: 36 | new_food_stmt = NewFood(type="burger") 37 | nap_stmt = Nap() 38 | 39 | # 2. put them in the ir 40 | new_food_stmt.insert_after(node) 41 | nap_stmt.insert_after(new_food_stmt) 42 | 43 | return RewriteResult(has_done_something=True) 44 | -------------------------------------------------------------------------------- /docs/dialects/python/data.md: -------------------------------------------------------------------------------- 1 | !!! warning 2 | This page is under construction. The content may be incomplete or incorrect. Submit an issue 3 | on [GitHub](https://github.com/QuEraComputing/kirin/issues/new) if you need help or want to 4 | contribute. 5 | 6 | # Dialects that brings in common Python data types 7 | 8 | This page provides a reference for dialects that bring in semantics for common Python data types. 9 | 10 | !!! note 11 | While it is worth noting that using Python semantics can be very convenient, it is also important to remember that the Python semantics are not designed for compilation. Therefore, it is important to be aware of the limitations of using Python semantics in a compilation context especially when it comes to data types. 12 | An example of this is the `list` data type in Python which is a dynamic mutable array. 
When the low-level code is not expecting a dynamic mutable array, it can lead to extra complexity for compilation. An immutable array or a fixed-size array can be a better choice in such cases (see `ilist` dialect). 13 | 14 | ## References 15 | 16 | ### Tuple 17 | 18 | ::: kirin.dialects.py.tuple 19 | options: 20 | filters: 21 | - "!statement" 22 | show_root_heading: true 23 | show_if_no_docstring: true 24 | 25 | ### List 26 | 27 | ::: kirin.dialects.py.list 28 | options: 29 | filters: 30 | - "!statement" 31 | show_root_heading: true 32 | show_if_no_docstring: true 33 | -------------------------------------------------------------------------------- /src/kirin/source.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class SourceInfo: 7 | lineno: int 8 | col_offset: int 9 | end_lineno: int | None 10 | end_col_offset: int | None 11 | file: str | None = None 12 | lineno_begin: int = 0 13 | col_indent: int = 0 14 | 15 | @classmethod 16 | def from_ast( 17 | cls, 18 | node: ast.AST, 19 | file: str | None = None, 20 | ): 21 | end_lineno = getattr(node, "end_lineno", None) 22 | end_col_offset = getattr(node, "end_col_offset", None) 23 | return cls( 24 | getattr(node, "lineno", 0), 25 | getattr(node, "col_offset", 0), 26 | end_lineno, 27 | end_col_offset, 28 | file, 29 | ) 30 | 31 | def offset(self, lineno_begin: int = 0, col_indent: int = 0): 32 | """Set (not accumulate) the line/column offsets applied when reporting. 33 | 34 | Args: 35 | lineno_begin (int): Added to `lineno` when this location is printed. 36 | col_indent (int): Added to `col_offset` when this location is printed.
37 | """ 38 | self.lineno_begin = lineno_begin 39 | self.col_indent = col_indent 40 | 41 | def __repr__(self) -> str: 42 | return ( 43 | f'File "{self.file or "stdin"}", ' 44 | f"line {self.lineno + self.lineno_begin}," 45 | f" col {self.col_offset + self.col_indent}" 46 | ) 47 | -------------------------------------------------------------------------------- /example/food/stmts.py: -------------------------------------------------------------------------------- 1 | from attrs import Food, Serving 2 | from dialect import dialect 3 | 4 | from kirin import ir, types 5 | from kirin.decl import info, statement 6 | 7 | 8 | @statement(dialect=dialect) 9 | class NewFood(ir.Statement): 10 | name = "new_food" 11 | traits = frozenset({ir.Pure(), ir.FromPythonCall()}) 12 | type: str = info.attribute(types.String) 13 | result: ir.ResultValue = info.result(types.PyClass(Food)) 14 | 15 | 16 | @statement(dialect=dialect) 17 | class Cook(ir.Statement): 18 | traits = frozenset({ir.FromPythonCall()}) 19 | target: ir.SSAValue = info.argument(types.PyClass(Food)) 20 | amount: ir.SSAValue = info.argument(types.Int) 21 | result: ir.ResultValue = info.result(types.PyClass(Serving)) 22 | 23 | 24 | @statement(dialect=dialect) 25 | class Eat(ir.Statement): 26 | traits = frozenset({ir.FromPythonCall()}) 27 | target: ir.SSAValue = info.argument(types.PyClass(Serving)) 28 | 29 | 30 | @statement(dialect=dialect) 31 | class Nap(ir.Statement): 32 | traits = frozenset({ir.FromPythonCall()}) 33 | 34 | 35 | @statement(dialect=dialect) 36 | class RandomBranch(ir.Statement): 37 | name = "random_br" 38 | traits = frozenset({ir.IsTerminator()}) 39 | cond: ir.SSAValue = info.argument(types.Bool) 40 | then_arguments: tuple[ir.SSAValue, ...] = info.argument() 41 | else_arguments: tuple[ir.SSAValue, ...] 
= info.argument() 42 | then_successor: ir.Block = info.block() 43 | else_successor: ir.Block = info.block() 44 | -------------------------------------------------------------------------------- /src/kirin/rewrite/call2invoke.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin import ir 4 | from kirin.analysis import const 5 | from kirin.rewrite.abc import RewriteRule, RewriteResult 6 | from kirin.dialects.func import Call, Invoke 7 | 8 | 9 | @dataclass 10 | class Call2Invoke(RewriteRule): 11 | """Rewrite a `Call` statement to an `Invoke` statement.""" 12 | 13 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 14 | if not isinstance(node, Call): 15 | return RewriteResult() 16 | 17 | if (mt := node.callee.hints.get("const")) is None: 18 | return RewriteResult() 19 | 20 | if not isinstance(mt, const.Value): 21 | return RewriteResult() 22 | 23 | if not isinstance(mt.data, ir.Method): 24 | return RewriteResult() 25 | 26 | method = mt.data 27 | trait = method.code.get_present_trait(ir.CallableStmtInterface) 28 | inputs = trait.align_input_args( 29 | method.code, *node.inputs, **dict(zip(node.keys, node.kwargs)) 30 | ) 31 | stmt = Invoke(inputs=inputs, callee=mt.data) 32 | for result, new_result in zip(node.results, stmt.results): 33 | new_result.name = result.name 34 | new_result.type = result.type 35 | if result_hint := result.hints.get("const"): 36 | new_result.hints["const"] = result_hint 37 | 38 | node.replace_by(stmt) 39 | return RewriteResult(has_done_something=True) 40 | -------------------------------------------------------------------------------- /test/testing/test_assert_statements_same.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.testing import assert_statements_same 4 | from kirin.dialects import py 5 | 6 | 7 | def test_same_statements_pass(): 8 | statement_1 = py.Constant(1) 
9 | statement_2 = py.Constant(1) 10 | assert_statements_same(statement_1, statement_2) 11 | 12 | 13 | def test_different_statements_fail(): 14 | statement_1 = py.Constant(1) 15 | statement_2 = py.Mult(statement_1.result, statement_1.result) 16 | with pytest.raises(AssertionError): 17 | assert_statements_same(statement_1, statement_2) 18 | 19 | 20 | def test_same_statement_different_args_fails_1(): 21 | statement_1 = py.Constant(1) 22 | statement_2 = py.Constant(2) 23 | with pytest.raises(AssertionError): 24 | assert_statements_same(statement_1, statement_2) 25 | 26 | 27 | def test_same_statement_different_args_fails_2(): 28 | arg_1 = py.Constant(1) 29 | arg_2 = py.Constant(2) 30 | statement_1 = py.Mult(arg_1.result, arg_1.result) 31 | statement_2 = py.Mult(arg_2.result, arg_2.result) 32 | with pytest.raises(AssertionError): 33 | assert_statements_same(statement_1, statement_2) 34 | 35 | 36 | def test_same_statement_different_args_check_args_false_passes(): 37 | arg_1 = py.Constant(1) 38 | arg_2 = py.Constant(2) 39 | statement_1 = py.Mult(arg_1.result, arg_1.result) 40 | statement_2 = py.Mult(arg_2.result, arg_2.result) 41 | assert_statements_same(statement_1, statement_2, check_args=False) 42 | -------------------------------------------------------------------------------- /test/rules/test_cse.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import basic, basic_no_opt 2 | from kirin.rewrite import Walk, Fixpoint 3 | from kirin.rewrite.cse import CommonSubexpressionElimination 4 | 5 | 6 | @basic_no_opt 7 | def badprogram(x: int, y: int) -> int: 8 | a = x + y 9 | b = x + y 10 | x = a + b 11 | y = a + b 12 | return x + y 13 | 14 | 15 | def test_cse(): 16 | before = badprogram(1, 2) 17 | cse = CommonSubexpressionElimination() 18 | Fixpoint(Walk(cse)).rewrite(badprogram.code) 19 | after = badprogram(1, 2) 20 | 21 | assert before == after 22 | 23 | 24 | @basic_no_opt 25 | def cse_constant(): 26 | x = 1 27 | y = 
2 28 | z = 1 29 | return x + y + z 30 | 31 | 32 | def test_cse_constant(): 33 | # NOTE: issue #61 34 | before = cse_constant() 35 | cse_constant.print() 36 | cse = CommonSubexpressionElimination() 37 | Fixpoint(Walk(cse)).rewrite(cse_constant.code) 38 | after = cse_constant() 39 | cse_constant.print() 40 | assert before == after 41 | assert len(cse_constant.callable_region.blocks[0].stmts) == 5 42 | 43 | 44 | def test_cse_constant_int_float(): 45 | 46 | @basic(fold=False, typeinfer=True) 47 | def gv2(x: int): 48 | y = 1 49 | z = 1.0 50 | return y + z + x 51 | 52 | out = gv2(2) 53 | 54 | Walk(CommonSubexpressionElimination()).rewrite(gv2.code) 55 | gv2.print() 56 | 57 | out2 = gv2(2) 58 | 59 | assert out == out2 60 | assert type(out) is type(out2) 61 | assert type(out) is float 62 | -------------------------------------------------------------------------------- /test/ir/test_stmt.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin.ir import Block 4 | from kirin.source import SourceInfo 5 | from kirin.dialects import py 6 | 7 | 8 | def test_stmt(): 9 | a = py.Constant(0) 10 | x = py.Constant(1) 11 | y = py.Constant(2) 12 | z = py.Add(lhs=x.result, rhs=y.result) 13 | 14 | bb1 = Block([a, x, y, z]) 15 | assert bb1.first_stmt == a 16 | bb1.print() 17 | 18 | a.delete() 19 | assert bb1.first_stmt == x 20 | bb1.print() 21 | 22 | a.insert_before(x) 23 | bb1.print() 24 | assert bb1.first_stmt == a 25 | 26 | a.delete() 27 | a.insert_after(x) 28 | assert bb1.stmts.at(1) == a 29 | 30 | with pytest.raises(ValueError): 31 | a.insert_after(x) 32 | 33 | with pytest.raises(ValueError): 34 | a.insert_before(x) 35 | 36 | 37 | def test_block_append(): 38 | block = Block() 39 | block.stmts.append(py.Constant(1)) 40 | block.stmts.append(py.Constant(1)) 41 | block.print() 42 | assert len(block.stmts) == 2 43 | 44 | 45 | def test_stmt_from_stmt(): 46 | 47 | x = py.Constant(1) 48 | 49 | x.result.hints["const"] =
py.constant.types.Int 50 | 51 | y = x.from_stmt(x) 52 | 53 | assert y.result.hints["const"] == py.constant.types.Int 54 | 55 | 56 | def test_stmt_from_stmt_preserves_source_info(): 57 | x = py.Constant(1) 58 | x.source = SourceInfo(lineno=1, col_offset=0, end_lineno=None, end_col_offset=None) 59 | 60 | y = x.from_stmt(x) 61 | assert y.source == x.source 62 | assert y.source is x.source 63 | -------------------------------------------------------------------------------- /test/stmt/test_statement.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from kirin import ir 4 | from kirin.decl import info, statement 5 | 6 | dialect = ir.Dialect("my_dialect") 7 | 8 | 9 | def test_reserved_verify(): 10 | with pytest.raises(ValueError): 11 | 12 | @statement(dialect=dialect) 13 | class ReserveKeyword(ir.Statement): 14 | name = "my_statement" 15 | traits = frozenset({}) 16 | args: ir.SSAValue = info.argument() 17 | 18 | with pytest.raises(ValueError): 19 | 20 | @statement(dialect=dialect) 21 | class NoAnnotation(ir.Statement): 22 | name = "my_statement" 23 | traits = frozenset({}) 24 | no_annotation = info.argument() # type: ignore 25 | 26 | with pytest.raises(ValueError): 27 | 28 | @statement(dialect=dialect) 29 | class WrongAnnotation(ir.Statement): 30 | name = "my_statement" 31 | traits = frozenset({}) 32 | field: str = info.argument() 33 | 34 | with pytest.raises(ValueError): 35 | 36 | @statement(dialect=dialect) 37 | class WrongFiledSpecifier(ir.Statement): 38 | name = "my_statement" 39 | traits = frozenset({}) 40 | result: ir.ResultValue = info.argument() 41 | 42 | with pytest.raises(ValueError): 43 | 44 | @statement(dialect=dialect) 45 | class WrongResultAnnotation(ir.Statement): 46 | name = "my_statement" 47 | traits = frozenset({}) 48 | result: ir.SSAValue = info.result() 49 | -------------------------------------------------------------------------------- /test/analysis/dataflow/typeinfer/test_unstable.py: 
-------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.prelude import basic_no_opt 3 | from kirin.analysis.typeinfer import TypeInference 4 | 5 | 6 | def test_untable_branch(): 7 | @basic_no_opt 8 | def unstable(x: int): # type: ignore 9 | y = x + 1 10 | if y > 10: 11 | z = y 12 | else: 13 | z = y + 1.2 14 | return z 15 | 16 | infer = TypeInference(dialects=unstable.dialects) 17 | frame, ret = infer.run_no_raise(unstable, types.Int) 18 | assert ret == types.Union(types.Int, types.Float) 19 | 20 | def stmt_at(block_id, stmt_id) -> ir.Statement: 21 | return unstable.code.body.blocks[block_id].stmts.at(stmt_id) # type: ignore 22 | 23 | def results_at(block_id, stmt_id): 24 | return stmt_at(block_id, stmt_id).results 25 | 26 | assert [frame.entries[result] for result in results_at(0, 0)] == [types.Int] 27 | assert [frame.entries[result] for result in results_at(0, 1)] == [types.Int] 28 | assert [frame.entries[result] for result in results_at(0, 2)] == [types.Int] 29 | assert [frame.entries[result] for result in results_at(0, 3)] == [types.Bool] 30 | 31 | assert [frame.entries[result] for result in results_at(1, 0)] == [types.Int] 32 | assert [frame.entries[result] for result in results_at(2, 0)] == [types.Float] 33 | assert [frame.entries[result] for result in results_at(2, 1)] == [types.Float] 34 | 35 | stmt = stmt_at(3, 0) 36 | assert frame.entries[stmt.args[0]] == (types.Int | types.Float) 37 | -------------------------------------------------------------------------------- /test/dialects/kirin_random/test_random.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import kirin.prelude 4 | from kirin.dialects import ilist, random as kirin_random 5 | 6 | 7 | def test_random(): 8 | 9 | random.seed(12) 10 | expected = [random.random() for _ in range(4)] 11 | 12 | @kirin.prelude.basic 13 | def rnd_main() -> ilist.IList: 14 | 
kirin_random.seed(12) 15 | out = [] 16 | for i in range(4): 17 | # for i in [1,1,2,3]: # Same result with this line instead 18 | out = out + [kirin_random.random()] 19 | return out 20 | 21 | out: ilist.IList = rnd_main() 22 | 23 | assert out.data == expected 24 | 25 | 26 | def test_randint(): 27 | 28 | random.seed(12) 29 | expected = [random.randint(i, 10) for i in range(4)] 30 | 31 | @kirin.prelude.basic 32 | def rndint_main() -> ilist.IList: 33 | kirin_random.seed(12) 34 | out = [] 35 | for i in range(4): 36 | out = out + [kirin_random.randint(i, 10)] 37 | return out 38 | 39 | out: ilist.IList = rndint_main() 40 | 41 | assert out.data == expected 42 | 43 | 44 | def test_uniform(): 45 | 46 | random.seed(12) 47 | expected = [random.uniform(i, 10) for i in range(4)] 48 | 49 | @kirin.prelude.basic 50 | def rnduniform_main() -> ilist.IList: 51 | kirin_random.seed(12) 52 | out = [] 53 | for i in range(4): 54 | out = out + [kirin_random.uniform(i, 10)] 55 | return out 56 | 57 | out: ilist.IList = rnduniform_main() 58 | 59 | assert out.data == expected 60 | -------------------------------------------------------------------------------- /test/program/py/test_aggressive.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | from kirin import ir, types, lowering 4 | from kirin.decl import info, statement 5 | from kirin.prelude import basic, basic_no_opt 6 | from kirin.dialects.py import cmp 7 | 8 | dialect = ir.Dialect("dummy2") 9 | 10 | 11 | @statement(dialect=dialect) 12 | class DummyStmt2(ir.Statement): 13 | name = "dummy2" 14 | traits = frozenset({lowering.FromPythonCall()}) 15 | value: ir.SSAValue = info.argument(types.Int) 16 | option: ir.PyAttr[str] = info.attribute() 17 | result: ir.ResultValue = info.result(types.Int) 18 | 19 | 20 | @basic_no_opt.add(dialect) 21 | def unfolable(x: int, y: int): 22 | def inner(): 23 | DummyStmt2(x, option=ir.PyAttr("hello")) 24 | DummyStmt2(y, option=ir.PyAttr("hello")) 25 | 
26 | return inner 27 | 28 | 29 | @basic.add(dialect)(fold=True, aggressive=True) 30 | def main(): 31 | x = DummyStmt2(1, option=ir.PyAttr("hello")) 32 | x = unfolable(x, x) 33 | return x() 34 | 35 | 36 | def test_aggressive_pass(): 37 | const_count = 0 38 | dummy_count = 0 39 | for stmt in main.callable_region.walk(): 40 | if isinstance(stmt, DummyStmt2): 41 | dummy_count += 1 42 | elif stmt.has_trait(ir.ConstantLike): 43 | const_count += 1 44 | assert dummy_count == 3 45 | assert const_count == 2 46 | 47 | 48 | @basic(fold=True, aggressive=True) 49 | def should_fold(): 50 | return 1 < 2 51 | 52 | 53 | def test_should_fold(): 54 | for stmt in should_fold.callable_region.walk(): 55 | assert not isinstance(stmt, cmp.Lt) 56 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/rewrite/hint_len.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types 2 | from kirin.analysis import const 3 | from kirin.dialects import py, scf 4 | from kirin.rewrite.abc import RewriteRule, RewriteResult 5 | from kirin.dialects.ilist.stmts import IListType 6 | 7 | from .._dialect import dialect 8 | 9 | 10 | @dialect.post_inference 11 | class HintLen(RewriteRule): 12 | 13 | def _get_collection_len(self, collection: ir.SSAValue): 14 | coll_type = collection.type 15 | 16 | if not isinstance(coll_type, types.Generic): 17 | return None 18 | 19 | if ( 20 | coll_type.is_subseteq(IListType) 21 | and isinstance(coll_type.vars[1], types.Literal) 22 | and isinstance(coll_type.vars[1].data, int) 23 | ): 24 | return coll_type.vars[1].data 25 | else: 26 | return None 27 | 28 | def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: 29 | 30 | if not ( 31 | isinstance(node, py.Len) 32 | and not isinstance(node.parent_stmt, (scf.For, scf.IfElse)) 33 | ): 34 | return RewriteResult() 35 | 36 | if (coll_len := self._get_collection_len(node.value)) is None: 37 | return RewriteResult() 
38 | 39 | existing_hint = node.result.hints.get("const") 40 | new_hint = const.Value(coll_len) 41 | 42 | if existing_hint is not None and new_hint.is_structurally_equal(existing_hint): 43 | return RewriteResult() 44 | 45 | node.result.hints["const"] = new_hint 46 | return RewriteResult(has_done_something=True) 47 | -------------------------------------------------------------------------------- /example/food/interp.py: -------------------------------------------------------------------------------- 1 | from random import randint 2 | 3 | from attrs import Food, Serving 4 | from stmts import Eat, Nap, Cook, NewFood, RandomBranch 5 | from dialect import dialect 6 | 7 | from kirin.interp import Frame, Successor, Interpreter, MethodTable, impl 8 | 9 | 10 | @dialect.register 11 | class FoodMethods(MethodTable): 12 | 13 | @impl(NewFood) 14 | def new_food(self, interp: Interpreter, frame: Frame, stmt: NewFood): 15 | return (Food(stmt.type),) 16 | 17 | @impl(Eat) 18 | def eat(self, interp: Interpreter, frame: Frame, stmt: Eat): 19 | serving: Serving = frame.get(stmt.target) 20 | print(f"Eating {serving.amount} servings of {serving.kind.type}") 21 | return () 22 | 23 | @impl(Cook) 24 | def cook(self, interp: Interpreter, frame: Frame, stmt: Cook): 25 | food: Food = frame.get(stmt.target) 26 | amount: int = frame.get(stmt.amount) 27 | print(f"Cooking {food.type} {amount}") 28 | 29 | return (Serving(food, amount),) 30 | 31 | @impl(Nap) 32 | def nap(self, interp: Interpreter, frame: Frame, stmt: Nap): 33 | print("Napping!!!") 34 | return () 35 | 36 | @impl(RandomBranch) 37 | def random_branch(self, interp: Interpreter, frame: Frame, stmt: RandomBranch): 38 | frame = interp.state.current_frame() 39 | if randint(0, 1): 40 | return Successor( 41 | stmt.then_successor, *frame.get_values(stmt.then_arguments) 42 | ) 43 | else: 44 | return Successor( 45 | stmt.else_successor, *frame.get_values(stmt.then_arguments) 46 | ) 47 | 
-------------------------------------------------------------------------------- /src/kirin/dialects/py/cmp/lowering.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from kirin import ir, lowering 4 | from kirin.dialects.py import boolop 5 | 6 | from . import stmts 7 | from ._dialect import dialect 8 | 9 | 10 | @dialect.register 11 | class PythonLowering(lowering.FromPythonAST): 12 | 13 | def lower_Compare( 14 | self, state: lowering.State, node: ast.Compare 15 | ) -> lowering.Result: 16 | # NOTE: a key difference here is we need to lower 17 | # the multi-argument comparison operators into binary operators 18 | # since low-level comparision operators are binary + we need a static 19 | # number of arguments in each instruction 20 | lhs = state.lower(node.left).expect_one() 21 | 22 | comparators = [ 23 | state.lower(comparator).expect_one() for comparator in node.comparators 24 | ] 25 | 26 | cmp_results: list[ir.SSAValue] = [] 27 | for op, rhs in zip(node.ops, comparators): 28 | if cls := getattr(stmts, op.__class__.__name__, None): 29 | stmt: stmts.Cmp = cls(lhs=lhs, rhs=rhs) 30 | else: 31 | raise lowering.BuildError(f"unsupported compare operator {op}") 32 | state.current_frame.push(stmt) 33 | cmp_results.append(stmt.result) 34 | lhs = rhs 35 | 36 | if len(cmp_results) == 1: 37 | return cmp_results[0] 38 | 39 | lhs = cmp_results[0] 40 | for rhs in cmp_results[1:]: 41 | and_stmt = boolop.And(lhs=lhs, rhs=rhs) 42 | state.current_frame.push(and_stmt) 43 | lhs = and_stmt.result 44 | 45 | return lhs 46 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/binop/stmts.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types, lowering 2 | from kirin.decl import info, statement 3 | 4 | from ._dialect import dialect 5 | 6 | T = types.TypeVar("T") 7 | 8 | 9 | @statement 10 | class BinOp(ir.Statement): 11 | 
traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 12 | lhs: ir.SSAValue = info.argument(T, print=False) 13 | rhs: ir.SSAValue = info.argument(T, print=False) 14 | result: ir.ResultValue = info.result(T) 15 | 16 | 17 | @statement(dialect=dialect) 18 | class Add(BinOp): 19 | name = "add" 20 | 21 | 22 | @statement(dialect=dialect) 23 | class Sub(BinOp): 24 | name = "sub" 25 | 26 | 27 | @statement(dialect=dialect) 28 | class Mult(BinOp): 29 | name = "mult" 30 | 31 | 32 | @statement(dialect=dialect) 33 | class Div(BinOp): 34 | name = "div" 35 | 36 | 37 | @statement(dialect=dialect) 38 | class Mod(BinOp): 39 | name = "mod" 40 | 41 | 42 | @statement(dialect=dialect) 43 | class Pow(BinOp): 44 | name = "pow" 45 | 46 | 47 | @statement(dialect=dialect) 48 | class LShift(BinOp): 49 | name = "lshift" 50 | 51 | 52 | @statement(dialect=dialect) 53 | class RShift(BinOp): 54 | name = "rshift" 55 | 56 | 57 | @statement(dialect=dialect) 58 | class BitAnd(BinOp): 59 | name = "bitand" 60 | 61 | 62 | @statement(dialect=dialect) 63 | class BitOr(BinOp): 64 | name = "bitor" 65 | 66 | 67 | @statement(dialect=dialect) 68 | class BitXor(BinOp): 69 | name = "bitxor" 70 | 71 | 72 | @statement(dialect=dialect) 73 | class FloorDiv(BinOp): 74 | name = "floordiv" 75 | 76 | 77 | @statement(dialect=dialect) 78 | class MatMult(BinOp): 79 | name = "matmult" 80 | -------------------------------------------------------------------------------- /test/dialects/func/test_closurefield.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from kirin import rewrite 4 | from kirin.prelude import basic 5 | from kirin.dialects import py, func 6 | from kirin.dialects.func.rewrite import closurefield 7 | 8 | 9 | def test_rewrite_closure_inner_lambda(): 10 | @basic 11 | def outer(y: int): 12 | def inner(x: int): 13 | return x * y + 1 14 | 15 | return inner 16 | 17 | inner_ker = outer(y=10) 18 | 19 | @basic 20 | def main_lambda(z: int): 21 | 
return inner_ker(z) 22 | 23 | main_invoke = main_lambda.code.regions[0].blocks[0].stmts.at(0) 24 | inner_lambda = cast(func.Invoke, main_invoke).callee.code 25 | inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0) 26 | assert isinstance( 27 | inner_getfield_stmt, func.GetField 28 | ), "expected GetField before rewrite" 29 | 30 | rewrite.Walk(closurefield.ClosureField()).rewrite(main_lambda.code) 31 | 32 | inner_getfield_stmt = inner_lambda.regions[0].blocks[0].stmts.at(0) 33 | assert isinstance( 34 | inner_getfield_stmt, py.Constant 35 | ), "GetField should be lowered to Constant" 36 | 37 | 38 | def test_rewrite_closure_no_fields(): 39 | @basic 40 | def bar(): 41 | def goo(x: int): 42 | a = (3, 4) 43 | return a[0] 44 | 45 | def boo(y): 46 | return goo(y) + 1 47 | 48 | return boo(4) 49 | 50 | before = bar.code.regions[0].blocks[0].stmts.at(0) 51 | rewrite.Walk(closurefield.ClosureField()).rewrite(bar.code) 52 | after = bar.code.regions[0].blocks[0].stmts.at(0) 53 | assert before is after 54 | -------------------------------------------------------------------------------- /src/kirin/passes/abc.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import ClassVar 3 | from dataclasses import field, dataclass 4 | 5 | from kirin.ir import Method, DialectGroup 6 | from kirin.rewrite.abc import RewriteResult 7 | 8 | 9 | @dataclass 10 | class Pass(ABC): 11 | """A pass is a transformation that is applied to a method. It wraps 12 | the analysis and rewrites needed to transform the method as an independent 13 | unit. 14 | 15 | Unlike LLVM/MLIR passes, a pass in Kirin does not apply to a module, 16 | this is because we focus on individual methods defined within 17 | python modules. This is a design choice to allow seamless integration 18 | within the Python interpreter. 
19 | 20 | A Kirin compile unit is a `ir.Method` object, which is always equivalent 21 | to a LLVM/MLIR module if it were lowered to LLVM/MLIR just like other JIT 22 | compilers. 23 | """ 24 | 25 | name: ClassVar[str] 26 | dialects: DialectGroup 27 | no_raise: bool = field(default=True, kw_only=True) 28 | 29 | def __call__(self, mt: Method) -> RewriteResult: 30 | result = self.unsafe_run(mt) 31 | mt.code.verify() 32 | return result 33 | 34 | def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult: 35 | result = RewriteResult() 36 | for _ in range(max_iter): 37 | result_ = self.unsafe_run(mt) 38 | result = result_.join(result) 39 | if not result_.has_done_something: 40 | break 41 | mt.verify() 42 | return result 43 | 44 | @abstractmethod 45 | def unsafe_run(self, mt: Method) -> RewriteResult: ... 46 | -------------------------------------------------------------------------------- /src/kirin/dialects/random/stmts.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types, lowering 2 | from kirin.decl import info, statement 3 | 4 | from ._dialect import dialect 5 | 6 | 7 | @statement(dialect=dialect) 8 | class Random(ir.Statement): 9 | """random statement, wrapping the random.random function 10 | returns a random floating number between 0 and 1 11 | """ 12 | 13 | traits = frozenset({lowering.FromPythonCall()}) 14 | result: ir.ResultValue = info.result(types.Float) 15 | 16 | 17 | @statement(dialect=dialect) 18 | class RandInt(ir.Statement): 19 | """randint statement, wrapping the random.randint function 20 | returns a random integer between the given range 21 | """ 22 | 23 | traits = frozenset({lowering.FromPythonCall()}) 24 | start: ir.SSAValue = info.argument(types.Int) 25 | stop: ir.SSAValue = info.argument(types.Int) 26 | result: ir.ResultValue = info.result(types.Int) 27 | 28 | 29 | @statement(dialect=dialect) 30 | class Uniform(ir.Statement): 31 | """uniform statement, wrapping the 
random.uniform function 32 | returns a random floating number between the given range 33 | """ 34 | 35 | traits = frozenset({lowering.FromPythonCall()}) 36 | start: ir.SSAValue = info.argument(types.Float) 37 | stop: ir.SSAValue = info.argument(types.Float) 38 | result: ir.ResultValue = info.result(types.Float) 39 | 40 | 41 | @statement(dialect=dialect) 42 | class Seed(ir.Statement): 43 | """seed statement, wrapping the random.seed function 44 | sets the seed for the random number generator 45 | """ 46 | 47 | traits = frozenset({lowering.FromPythonCall()}) 48 | value: ir.SSAValue = info.argument(types.Int) 49 | -------------------------------------------------------------------------------- /.github/workflows/devdoc.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Devopment Branch Docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | documentation: 13 | name: Deploy dev documentation 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v6 17 | with: 18 | fetch-depth: 0 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v7 21 | with: 22 | # Install a specific version of uv. 
23 | version: "0.6.14" 24 | enable-cache: true 25 | cache-dependency-glob: "uv.lock" 26 | - name: Install Documentation dependencies 27 | run: uv sync --group doc 28 | - name: Set up build cache 29 | uses: actions/cache@v5 30 | id: cache 31 | with: 32 | key: mkdocs-material-${{ github.ref }} 33 | path: .cache 34 | restore-keys: | 35 | mkdocs-material- 36 | # derived from: 37 | # https://github.com/RemoteCloud/public-documentation/blob/dev/.github/workflows/build_docs.yml 38 | - name: Configure Git user 39 | run: | 40 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 41 | git config --local user.name "github-actions[bot]" 42 | - name: Deploy documentation 43 | env: 44 | GH_TOKEN: ${{ secrets.GH_TOKEN }} 45 | GOOGLE_ANALYTICS_KEY: ${{ secrets.GOOGLE_ANALYTICS_KEY }} 46 | run: | 47 | git fetch origin gh-pages --depth=1 48 | uv run mike deploy -p dev 49 | -------------------------------------------------------------------------------- /src/kirin/rewrite/aggressive/fold.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from kirin.rewrite import Walk, Chain, Fixpoint 4 | from kirin.analysis import const 5 | from kirin.rewrite.abc import RewriteRule, RewriteResult 6 | from kirin.rewrite.dce import DeadCodeElimination 7 | from kirin.rewrite.fold import ConstantFold 8 | from kirin.ir.nodes.base import IRNode 9 | from kirin.rewrite.inline import Inline 10 | from kirin.rewrite.getitem import InlineGetItem 11 | from kirin.rewrite.getfield import InlineGetField 12 | from kirin.rewrite.compactify import CFGCompactify 13 | from kirin.rewrite.wrap_const import WrapConst 14 | from kirin.rewrite.call2invoke import Call2Invoke 15 | from kirin.rewrite.type_assert import InlineTypeAssert 16 | 17 | 18 | @dataclass 19 | class Fold(RewriteRule): 20 | rule: RewriteRule 21 | 22 | def __init__(self, frame: const.Frame): 23 | rule = Fixpoint( 24 | Chain( 25 | Walk(WrapConst(frame)), 
26 | Walk(Inline(lambda _: True)), 27 | Walk(ConstantFold()), 28 | Walk(Call2Invoke()), 29 | Fixpoint( 30 | Walk( 31 | Chain( 32 | InlineTypeAssert(), 33 | InlineGetItem(), 34 | InlineGetField(), 35 | DeadCodeElimination(), 36 | ) 37 | ) 38 | ), 39 | Walk(CFGCompactify()), 40 | ) 41 | ) 42 | self.rule = rule 43 | 44 | def rewrite(self, node: IRNode) -> RewriteResult: 45 | return self.rule.rewrite(node) 46 | -------------------------------------------------------------------------------- /test/dialects/pystmts/test_slice.py: -------------------------------------------------------------------------------- 1 | from kirin import types 2 | from kirin.prelude import basic_no_opt 3 | from kirin.dialects import py 4 | 5 | 6 | @basic_no_opt 7 | def explicit_slice(): 8 | x = slice(1, 2, 3) 9 | y = slice(1, 2) 10 | z = slice(1) 11 | return x, y, z 12 | 13 | 14 | @basic_no_opt 15 | def wrong_slice(): 16 | x = slice(None, None, None) 17 | y = slice(None, None, 1) 18 | return x, y 19 | 20 | 21 | def test_explicit_slice(): 22 | stmt: py.slice.Slice = explicit_slice.code.body.blocks[0].stmts.at(3) 23 | assert stmt.start.type.is_subseteq(types.Int) 24 | assert stmt.stop.type.is_subseteq(types.Int) 25 | assert stmt.step.type.is_subseteq(types.Int) 26 | assert stmt.result.type.is_subseteq(types.Slice[types.Int]) 27 | 28 | stmt: py.slice.Slice = explicit_slice.code.body.blocks[0].stmts.at(7) 29 | assert stmt.start.type.is_subseteq(types.Int) 30 | assert stmt.stop.type.is_subseteq(types.Int) 31 | assert stmt.step.type.is_subseteq(types.NoneType) 32 | assert stmt.result.type.is_subseteq(types.Slice[types.Int]) 33 | 34 | stmt: py.slice.Slice = explicit_slice.code.body.blocks[0].stmts.at(11) 35 | assert stmt.start.type.is_subseteq(types.NoneType) 36 | assert stmt.stop.type.is_subseteq(types.Int) 37 | assert stmt.step.type.is_subseteq(types.NoneType) 38 | assert stmt.result.type.is_subseteq(types.Slice[types.Int]) 39 | 40 | 41 | def test_wrong_slice(): 42 | stmt: py.slice.Slice = 
wrong_slice.code.body.blocks[0].stmts.at(3) 43 | assert stmt.result.type.is_subseteq(types.Bottom) 44 | 45 | stmt: py.slice.Slice = wrong_slice.code.body.blocks[0].stmts.at(7) 46 | assert stmt.result.type.is_subseteq(types.Bottom) 47 | -------------------------------------------------------------------------------- /src/kirin/passes/fold.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, dataclass 2 | 3 | from kirin.ir import HasCFG, Method 4 | from kirin.rewrite import ( 5 | Walk, 6 | Chain, 7 | Fixpoint, 8 | Call2Invoke, 9 | ConstantFold, 10 | CFGCompactify, 11 | InlineGetItem, 12 | DeadCodeElimination, 13 | ) 14 | from kirin.passes.abc import Pass 15 | from kirin.rewrite.abc import RewriteResult 16 | 17 | from .hint_const import HintConst 18 | 19 | 20 | @dataclass 21 | class Fold(Pass): 22 | """ 23 | Pass that runs a number of small optimization rewrites. 24 | 25 | Specifically, the following rewrites are chained: 26 | 27 | - `ConstantFold` 28 | - `InlineGetItem` 29 | - `Call2Invoke` 30 | - `DeadCodeElimination` 31 | """ 32 | 33 | hint_const: HintConst = field(init=False) 34 | 35 | def __post_init__(self): 36 | self.hint_const = HintConst(self.dialects) 37 | self.hint_const.no_raise = self.no_raise 38 | 39 | def unsafe_run(self, mt: Method) -> RewriteResult: 40 | result = self.hint_const.unsafe_run(mt) 41 | result = ( 42 | Fixpoint( 43 | Walk( 44 | Chain( 45 | ConstantFold(), 46 | InlineGetItem(), 47 | Call2Invoke(), 48 | DeadCodeElimination(), 49 | ) 50 | ) 51 | ) 52 | .rewrite(mt.code) 53 | .join(result) 54 | ) 55 | 56 | if mt.code.has_trait(HasCFG): 57 | result = Walk(CFGCompactify()).rewrite(mt.code).join(result) 58 | 59 | return Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result) 60 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/len.py: -------------------------------------------------------------------------------- 
1 | """The `Len` dialect. 2 | 3 | This dialect maps the `len()` call to the `Len` statement: 4 | 5 | - The `Len` statement class. 6 | - The lowering pass for the `len()` call. 7 | - The concrete implementation of the `len()` call. 8 | """ 9 | 10 | import ast 11 | 12 | from kirin import ir, types, interp, lowering 13 | from kirin.decl import info, statement 14 | from kirin.analysis import const 15 | 16 | dialect = ir.Dialect("py.len") 17 | 18 | 19 | @statement(dialect=dialect) 20 | class Len(ir.Statement): 21 | name = "len" 22 | traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 23 | value: ir.SSAValue = info.argument(types.Any) 24 | result: ir.ResultValue = info.result(types.Int) 25 | 26 | 27 | @dialect.register 28 | class Concrete(interp.MethodTable): 29 | 30 | @interp.impl(Len) 31 | def len(self, interp, frame: interp.Frame, stmt: Len): 32 | return (len(frame.get(stmt.value)),) 33 | 34 | 35 | @dialect.register(key="constprop") 36 | class ConstProp(interp.MethodTable): 37 | 38 | @interp.impl(Len) 39 | def len(self, interp, frame: interp.Frame, stmt: Len): 40 | value = frame.get(stmt.value) 41 | if isinstance(value, const.Value): 42 | return (const.Value(len(value.data)),) 43 | elif isinstance(value, const.PartialTuple): 44 | return (const.Value(len(value.data)),) 45 | else: 46 | return (const.Result.top(),) 47 | 48 | 49 | @dialect.register 50 | class Lowering(lowering.FromPythonAST): 51 | 52 | @lowering.akin(len) 53 | def lower_Call_len(self, state: lowering.State, node: ast.Call) -> lowering.Result: 54 | return state.current_frame.push(Len(state.lower(node.args[0]).expect_one())) 55 | -------------------------------------------------------------------------------- /src/kirin/dialects/ilist/rewrite/inline_getitem.py: -------------------------------------------------------------------------------- 1 | from kirin import ir 2 | from kirin.rewrite import abc 3 | from kirin.analysis import const 4 | from kirin.dialects import py 5 | 6 | from ..stmts import New 
7 | 8 | 9 | class InlineGetItem(abc.RewriteRule): 10 | """Rewrite rule to inline GetItem statements for IList. 11 | 12 | For example if we have an `ilist.New` statement with a list of items, 13 | and we can infer that the index used in `py.GetItem` is constant and within bounds, 14 | we replace the `py.GetItem` with the ssa value in the list when the index is an integer 15 | or with a new `ilist.New` statement containing the sliced items when the index is a slice. 16 | 17 | """ 18 | 19 | def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: 20 | if not isinstance(node, py.GetItem) or not isinstance( 21 | stmt := node.obj.owner, New 22 | ): 23 | return abc.RewriteResult() 24 | 25 | if not isinstance(index_const := node.index.hints.get("const"), const.Value): 26 | return abc.RewriteResult() 27 | 28 | if not node.result.uses: 29 | return abc.RewriteResult() 30 | 31 | index = index_const.data 32 | if isinstance(index, int) and ( 33 | 0 <= index < len(stmt.args) or -len(stmt.args) <= index < 0 34 | ): 35 | node.result.replace_by(stmt.args[index]) 36 | return abc.RewriteResult(has_done_something=True) 37 | elif isinstance(index, slice): 38 | new_tuple = New(tuple(stmt.args[index])) 39 | node.replace_by(new_tuple) 40 | return abc.RewriteResult(has_done_something=True) 41 | else: 42 | return abc.RewriteResult() 43 | -------------------------------------------------------------------------------- /test/dialects/py_dialect/test_assign.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from kirin import types 4 | from kirin.prelude import basic, basic_no_opt 5 | from kirin.analysis import TypeInference 6 | from kirin.dialects import py, func, ilist 7 | 8 | 9 | @basic_no_opt 10 | def main(x): 11 | y: int = x 12 | return y 13 | 14 | 15 | def test_ann_assign(): 16 | stmt = main.callable_region.blocks[0].stmts.at(0) 17 | assert isinstance(stmt, py.assign.TypeAssert) 18 | 19 | typeinfer = 
TypeInference(basic_no_opt) 20 | _, ret = typeinfer.run(main, types.Int) 21 | assert ret.is_structurally_equal(types.Int) 22 | _, ret = typeinfer.run(main, types.Float) 23 | assert ret is ret.bottom() 24 | 25 | 26 | def test_typeinfer_simplify_assert(): 27 | @basic(typeinfer=True, fold=False) 28 | def simplify(x: int): 29 | y: int = x 30 | return y 31 | 32 | stmt = simplify.callable_region.blocks[0].stmts.at(0) 33 | assert isinstance(stmt, func.Return) 34 | 35 | 36 | def test_list_assign(): 37 | @basic_no_opt.add(ilist) 38 | def list_assign(): 39 | xs: ilist.IList[float, Literal[3]] = ilist.IList([1, 2, 3], elem=types.Float) 40 | return xs 41 | 42 | stmt = list_assign.callable_region.blocks[0].stmts.at(3) 43 | assert isinstance(stmt, ilist.New) 44 | assert stmt.elem_type.is_structurally_equal(types.Float) 45 | assert stmt.result.type.is_structurally_equal( 46 | ilist.IListType[types.Float, types.Literal(3)] 47 | ) 48 | 49 | stmt = list_assign.callable_region.blocks[0].stmts.at(4) 50 | assert isinstance(stmt, py.assign.TypeAssert) 51 | assert stmt.expected.is_structurally_equal( 52 | ilist.IListType[types.Float, types.Literal(3)] 53 | ) 54 | -------------------------------------------------------------------------------- /test/rewrite/test_cse_rewrite.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, lowering 2 | from kirin.decl import info, statement 3 | from kirin.prelude import basic_no_opt 4 | from kirin.rewrite.cse import _HASHABLE_SLICE, Info, CommonSubexpressionElimination 5 | from kirin.rewrite.walk import Walk 6 | 7 | dialect = ir.Dialect("test") 8 | 9 | 10 | @statement(dialect=dialect) 11 | class MultiResult(ir.Statement): 12 | traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 13 | result_a: ir.ResultValue = info.result() 14 | result_b: ir.ResultValue = info.result() 15 | 16 | 17 | dummy_dialect = basic_no_opt.add(dialect) 18 | 19 | 20 | def test_multi_result(): 21 | @dummy_dialect 22 | 
def duplicated(): 23 | x, y = MultiResult() # type: ignore 24 | a, b = MultiResult() # type: ignore 25 | return x + a, y + b 26 | 27 | stmt_0 = duplicated.callable_region.blocks[0].stmts.at(0) 28 | stmt_1 = duplicated.callable_region.blocks[0].stmts.at(1) 29 | assert isinstance(stmt_0, MultiResult) 30 | assert isinstance(stmt_1, MultiResult) 31 | 32 | Walk(CommonSubexpressionElimination()).rewrite(duplicated.code) 33 | 34 | stmt_0 = duplicated.callable_region.blocks[0].stmts.at(0) 35 | stmt_1 = duplicated.callable_region.blocks[0].stmts.at(1) 36 | assert isinstance(stmt_0, MultiResult) 37 | assert not isinstance(stmt_1, MultiResult) 38 | 39 | 40 | def test_info(): 41 | info_value = Info(ir.Statement, (), (ir.PyAttr(slice(None)),), (), ()) 42 | 43 | if not _HASHABLE_SLICE: 44 | assert info_value._hashable is False 45 | assert info_value._hash == id(info_value) 46 | else: 47 | assert info_value._hashable is True 48 | info_value._hash = hash((ir.Statement,) + (ir.PyAttr(slice(None)),)) 49 | -------------------------------------------------------------------------------- /test/passes/test_unroll_scf.py: -------------------------------------------------------------------------------- 1 | from kirin.prelude import structural, structural_no_opt 2 | from kirin.dialects import py, func, ilist 3 | from kirin.passes.aggressive import UnrollScf 4 | 5 | 6 | def test_unroll_scf(): 7 | @structural 8 | def main(r: list[int], cond: bool): 9 | if cond: 10 | for i in range(4): 11 | tmp = r[-1] 12 | if i < 2: 13 | tmp += i * 2 14 | else: 15 | for j in range(4): 16 | if i > j: 17 | tmp += i + j 18 | else: 19 | tmp += i - j 20 | 21 | r.append(tmp) 22 | else: 23 | for i in range(4): 24 | r.append(i) 25 | return r 26 | 27 | UnrollScf(structural).fixpoint(main) 28 | 29 | num_adds = 0 30 | num_calls = 0 31 | 32 | for op in main.callable_region.walk(): 33 | if isinstance(op, py.Add): 34 | num_adds += 1 35 | elif isinstance(op, func.Call): 36 | num_calls += 1 37 | 38 | assert num_adds == 10 
39 | assert num_calls == 8 40 | 41 | 42 | def test_dce_unroll_typeinfer(): 43 | # NOTE: tests bug in typeinfer preventing DCE from issue#564 44 | 45 | @structural_no_opt 46 | def main(): 47 | ls = [1, 2, 3] 48 | for i in range(1): 49 | ls[i] = 10 50 | return ls 51 | 52 | UnrollScf(main.dialects).fixpoint(main) 53 | 54 | for stmt in main.callable_region.stmts(): 55 | if isinstance(stmt, py.Constant) and isinstance( 56 | value := stmt.value, ilist.IList 57 | ): 58 | assert not isinstance(value.data, range), "Unused range not eliminated!" 59 | -------------------------------------------------------------------------------- /src/kirin/passes/typeinfer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, dataclass 2 | 3 | from kirin.ir import Method, HasSignature 4 | from kirin.rewrite import Walk, Chain 5 | from kirin.passes.abc import Pass 6 | from kirin.rewrite.abc import RewriteResult 7 | from kirin.dialects.func import Signature 8 | from kirin.analysis.typeinfer import TypeInference 9 | from kirin.rewrite.apply_type import ApplyType 10 | from kirin.rewrite.type_assert import InlineTypeAssert 11 | 12 | from .hint_const import HintConst 13 | from .post_inference import PostInference 14 | 15 | 16 | @dataclass 17 | class TypeInfer(Pass): 18 | hint_const: HintConst = field(init=False) 19 | inference: PostInference = field(init=False) 20 | 21 | def __post_init__(self): 22 | self.infer = TypeInference(self.dialects) 23 | self.hint_const = HintConst(self.dialects, no_raise=self.no_raise) 24 | self.post_inference = PostInference(self.dialects, no_raise=self.no_raise) 25 | 26 | def unsafe_run(self, mt: Method) -> RewriteResult: 27 | result = self.hint_const.unsafe_run(mt) 28 | if self.no_raise: 29 | frame, return_type = self.infer.run_no_raise(mt, *mt.arg_types) 30 | else: 31 | frame, return_type = self.infer.run(mt, *mt.arg_types) 32 | 33 | if trait := mt.code.get_trait(HasSignature): 34 | 
trait.set_signature(mt.code, Signature(mt.arg_types, return_type)) 35 | 36 | result = ( 37 | Chain( 38 | Walk(ApplyType(frame.entries)), 39 | Walk(InlineTypeAssert()), 40 | ) 41 | .rewrite(mt.code) 42 | .join(result) 43 | ) 44 | result = self.post_inference.fixpoint(mt).join(result) 45 | mt.inferred = True 46 | return result 47 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/attr.py: -------------------------------------------------------------------------------- 1 | """Attribute access dialect for Python. 2 | 3 | This module contains the dialect for the Python attribute access statement, including: 4 | 5 | - The `GetAttr` statement class. 6 | - The lowering pass for the attribute access statement. 7 | - The concrete implementation of the attribute access statement. 8 | 9 | This dialect maps `ast.Attribute` nodes to the `GetAttr` statement. 10 | """ 11 | 12 | import ast 13 | 14 | from kirin import ir, interp, lowering 15 | from kirin.decl import info, statement 16 | 17 | dialect = ir.Dialect("py.attr") 18 | 19 | 20 | @statement(dialect=dialect) 21 | class GetAttr(ir.Statement): 22 | name = "getattr" 23 | traits = frozenset({lowering.FromPythonCall()}) 24 | obj: ir.SSAValue = info.argument(print=False) 25 | attrname: str = info.attribute() 26 | result: ir.ResultValue = info.result() 27 | 28 | 29 | @dialect.register 30 | class Concrete(interp.MethodTable): 31 | 32 | @interp.impl(GetAttr) 33 | def getattr(self, interp: interp.Interpreter, frame: interp.Frame, stmt: GetAttr): 34 | return (getattr(frame.get(stmt.obj), stmt.attrname),) 35 | 36 | 37 | @dialect.register 38 | class Lowering(lowering.FromPythonAST): 39 | 40 | def lower_Attribute( 41 | self, state: lowering.State, node: ast.Attribute 42 | ) -> lowering.Result: 43 | 44 | if not isinstance(node.ctx, ast.Load): 45 | raise lowering.BuildError(f"unsupported attribute context {node.ctx}") 46 | 47 | # NOTE: eagerly load global variables 48 | value = 
state.get_global(node, no_raise=True) 49 | if value is not None: 50 | return state.lower(ast.Constant(value.data)).expect_one() 51 | 52 | value = state.lower(node.value).expect_one() 53 | return state.current_frame.push(GetAttr(obj=value, attrname=node.attr)) 54 | -------------------------------------------------------------------------------- /test/analysis/dataflow/typeinfer/test_inter_method.py: -------------------------------------------------------------------------------- 1 | from pytest import mark 2 | 3 | from kirin import types 4 | from kirin.prelude import basic, structural 5 | from kirin.dialects import ilist 6 | 7 | 8 | @mark.xfail(reason="if with early return not supported in scf lowering") 9 | def test_inter_method_infer(): 10 | @basic 11 | def foo(x: int): 12 | if x > 1: 13 | return x + 1 14 | else: 15 | return x - 1.0 16 | 17 | @basic(typeinfer=True, no_raise=False) 18 | def main(x: int): 19 | return foo(x) 20 | 21 | @basic(typeinfer=True, no_raise=False) 22 | def moo(x): 23 | return foo(x) 24 | 25 | assert main.return_type == (types.Int | types.Float) 26 | # assert moo.arg_types[0] == types.Int # type gets narrowed based on callee 27 | assert moo.return_type == (types.Int | types.Float) 28 | # NOTE: inference of moo should not update foo 29 | assert foo.arg_types[0] == types.Int 30 | assert foo.inferred is False 31 | assert foo.return_type is types.Any 32 | 33 | 34 | @mark.xfail(reason="if with early return not supported in scf lowering") 35 | def test_infer_if_return(): 36 | from kirin.prelude import structural 37 | 38 | @structural(typeinfer=True, fold=True, no_raise=False) 39 | def test(b: bool): 40 | if b: 41 | return False 42 | else: 43 | b = not b 44 | 45 | return b 46 | 47 | test.print() 48 | 49 | 50 | def test_method_constant_type_infer(): 51 | 52 | @structural(typeinfer=True, fold=False) 53 | def _new(qid: int): 54 | return 1 55 | 56 | @structural(fold=False, typeinfer=True) 57 | def alloc(n_iter: int): 58 | return ilist.map(_new, 
ilist.range(n_iter)) 59 | 60 | assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any]) 61 | -------------------------------------------------------------------------------- /test/dialects/pyrules/test_getitem.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, types, lowering 2 | from kirin.decl import info, statement 3 | from kirin.prelude import basic_no_opt 4 | from kirin.rewrite import Walk 5 | from kirin.dialects.py import indexing 6 | 7 | dummy = ir.Dialect("dummy") 8 | 9 | 10 | class RegGetItemInterface(indexing.GetItemLike["RegGetItem"]): 11 | 12 | def get_object(self, stmt: "RegGetItem") -> ir.SSAValue: 13 | return stmt.reg 14 | 15 | def get_index(self, stmt: "RegGetItem") -> ir.SSAValue: 16 | return stmt.index 17 | 18 | def new( 19 | self, stmt_type: type["RegGetItem"], obj: ir.SSAValue, index: ir.SSAValue 20 | ) -> "RegGetItem": 21 | return RegGetItem(obj, index) 22 | 23 | 24 | class Register: 25 | pass 26 | 27 | 28 | @statement(dialect=dummy) 29 | class New(ir.Statement): 30 | name = "new" 31 | traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) 32 | result: ir.ResultValue = info.result(types.PyClass(Register)) 33 | 34 | 35 | @statement(dialect=dummy) 36 | class RegGetItem(ir.Statement): 37 | name = "reg.get" 38 | traits = frozenset({ir.Pure(), RegGetItemInterface()}) 39 | reg: ir.SSAValue = info.argument(types.PyClass(Register)) 40 | index: ir.SSAValue = info.argument(types.Int) 41 | result: ir.ResultValue = info.result(types.Int) 42 | 43 | 44 | @basic_no_opt.add(dummy) 45 | def main(): 46 | reg = New() 47 | return reg[0] # type: ignore 48 | 49 | 50 | def test_rewrite_getitem(): 51 | rule = Walk(indexing.RewriteGetItem(RegGetItem, types.PyClass(Register))) 52 | 53 | stmt: ir.Statement = main.code.body.blocks[0].stmts.at(-2) # type: ignore 54 | assert isinstance(stmt, indexing.GetItem) 55 | rule.rewrite(main.code) 56 | stmt: ir.Statement = 
main.code.body.blocks[0].stmts.at(-2) # type: ignore 57 | assert isinstance(stmt, RegGetItem) 58 | -------------------------------------------------------------------------------- /test/dialects/func/test_lambdalifting.py: -------------------------------------------------------------------------------- 1 | from kirin import ir, rewrite 2 | from kirin.prelude import basic 3 | from kirin.dialects import py, func 4 | from kirin.dialects.func.rewrite import lambdalifting 5 | 6 | 7 | def test_rewrite_inner_lambda(): 8 | @basic 9 | def outer(): 10 | def inner(x: int): 11 | return x + 1 12 | 13 | return inner 14 | 15 | pyconstant_stmt = outer.code.regions[0].blocks[0].stmts.at(0) 16 | assert isinstance(pyconstant_stmt, py.Constant), "expected a Constant in outer body" 17 | assert isinstance( 18 | pyconstant_stmt.value, ir.PyAttr 19 | ), "expected a PyAttr in outer body" 20 | assert isinstance( 21 | pyconstant_stmt.value.data.code, func.Lambda 22 | ), "expected a lambda Method in outer body" 23 | 24 | rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer.code) 25 | assert isinstance( 26 | pyconstant_stmt.value.data.code, func.Function 27 | ), "expected a Function in outer body" 28 | 29 | 30 | def test_rewrite_inner_lambda_with_captured_vars(): 31 | @basic 32 | def outer2(): 33 | z = 10 34 | y = 3 + z 35 | 36 | def inner2(x: int): 37 | return x + y + 5 38 | 39 | return inner2 40 | 41 | pyconstant_stmt = outer2.code.regions[0].blocks[0].stmts.at(0) 42 | assert isinstance(pyconstant_stmt, py.Constant), "expected a Constant in outer body" 43 | assert isinstance( 44 | pyconstant_stmt.value, ir.PyAttr 45 | ), "expected a PyAttr in outer body" 46 | assert isinstance( 47 | pyconstant_stmt.value.data.code, func.Lambda 48 | ), "expected a lambda Method in outer body" 49 | rewrite.Walk(lambdalifting.LambdaLifting()).rewrite(outer2.code) 50 | assert isinstance( 51 | pyconstant_stmt.value.data.code, func.Function 52 | ), "expected a Function in outer body" 53 | 
-------------------------------------------------------------------------------- /.github/workflows/pub_doc.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Release Docs 2 | on: 3 | push: 4 | tags: 5 | - "v*" 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | documentation: 13 | name: Deploy release documentation 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v6 17 | with: 18 | fetch-depth: 0 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v7 21 | with: 22 | # Install a specific version of uv. 23 | version: "0.6.14" 24 | enable-cache: true 25 | cache-dependency-glob: "uv.lock" 26 | - name: Install Documentation dependencies 27 | run: uv sync --group doc 28 | - name: Set up build cache 29 | uses: actions/cache@v5 30 | id: cache 31 | with: 32 | key: mkdocs-material-${{ github.ref }} 33 | path: .cache 34 | restore-keys: | 35 | mkdocs-material- 36 | # derived from: 37 | # https://github.com/RemoteCloud/public-documentation/blob/dev/.github/workflows/build_docs.yml 38 | - name: Configure Git user 39 | run: | 40 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 41 | git config --local user.name "github-actions[bot]" 42 | - name: Set release notes tag 43 | run: | 44 | export TAG_PATH=${{ github.ref }} 45 | echo ${TAG_PATH} 46 | echo "TAG_VERSION=${TAG_PATH##*/}" >> $GITHUB_ENV 47 | echo ${TAG_VERSION} 48 | - name: Deploy documentation 49 | env: 50 | GH_TOKEN: ${{ secrets.GH_TOKEN }} 51 | run: | 52 | git fetch origin gh-pages --depth=1 53 | uv run mike deploy --update-alias --push ${TAG_VERSION} latest 54 | -------------------------------------------------------------------------------- /src/kirin/dialects/py/cmp/interp.py: -------------------------------------------------------------------------------- 1 | from kirin import interp 2 | 3 | from . 
import stmts as cmp 4 | from ._dialect import dialect 5 | 6 | 7 | @dialect.register 8 | class CmpMethod(interp.MethodTable): 9 | 10 | @interp.impl(cmp.Eq) 11 | def eq(self, interp, frame: interp.Frame, stmt: cmp.Eq): 12 | return (frame.get(stmt.lhs) == frame.get(stmt.rhs),) 13 | 14 | @interp.impl(cmp.NotEq) 15 | def not_eq(self, interp, frame: interp.Frame, stmt: cmp.NotEq): 16 | return (frame.get(stmt.lhs) != frame.get(stmt.rhs),) 17 | 18 | @interp.impl(cmp.Lt) 19 | def lt(self, interp, frame: interp.Frame, stmt: cmp.Lt): 20 | return (frame.get(stmt.lhs) < frame.get(stmt.rhs),) 21 | 22 | @interp.impl(cmp.LtE) 23 | def lt_eq(self, interp, frame: interp.Frame, stmt: cmp.LtE): 24 | return (frame.get(stmt.lhs) <= frame.get(stmt.rhs),) 25 | 26 | @interp.impl(cmp.Gt) 27 | def gt(self, interp, frame: interp.Frame, stmt: cmp.Gt): 28 | return (frame.get(stmt.lhs) > frame.get(stmt.rhs),) 29 | 30 | @interp.impl(cmp.GtE) 31 | def gt_eq(self, interp, frame: interp.Frame, stmt: cmp.GtE): 32 | return (frame.get(stmt.lhs) >= frame.get(stmt.rhs),) 33 | 34 | @interp.impl(cmp.In) 35 | def in_(self, interp, frame: interp.Frame, stmt: cmp.In): 36 | return (frame.get(stmt.lhs) in frame.get(stmt.rhs),) 37 | 38 | @interp.impl(cmp.NotIn) 39 | def not_in(self, interp, frame: interp.Frame, stmt: cmp.NotIn): 40 | return (frame.get(stmt.lhs) not in frame.get(stmt.rhs),) 41 | 42 | @interp.impl(cmp.Is) 43 | def is_(self, interp, frame: interp.Frame, stmt: cmp.Is): 44 | return (frame.get(stmt.lhs) is frame.get(stmt.rhs),) 45 | 46 | @interp.impl(cmp.IsNot) 47 | def is_not(self, interp, frame: interp.Frame, stmt: cmp.IsNot): 48 | return (frame.get(stmt.lhs) is not frame.get(stmt.rhs),) 49 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This folder contains examples of how to use the Kirin library. 
Each example is a standalone project that can be run independently. 4 | 5 | ## List of Examples 6 | 7 | - `simple.py`: A minimal example that demonstrates how to create a Kirin dialect group and define a kernel with it. 8 | - `food`: A more sophisticated example without any domain-specific content. It demonstrates how to create a new Kirin dialect and combine it with existing dialects, along with custom analysis and rewrites. 9 | - `pauli`: An example that implements a dialect with rewrites that simplify products of Pauli matrices. 10 | 11 | ## Examples outside this folder with more domain-specific content 12 | 13 | ### Quantum Computing 14 | 15 | - [bloqade.qasm2](https://github.com/QuEraComputing/bloqade/tree/main/src/bloqade/qasm2): This is an eDSL for quantum computing, built with Kirin, that embeds the Quantum Assembly Language (QASM) 2.0. It demonstrates how to create multiple dialects using Kirin, run custom analysis and rewrites, and generate code from the dialects (back to QASM 2.0 in this case). 16 | - [bloqade.stim](https://github.com/QuEraComputing/bloqade/tree/main/src/bloqade/stim): This is an eDSL for quantum computing, built with Kirin, that embeds the [STIM](https://github.com/quantumlib/Stim/) language. It demonstrates how to create multiple dialects using Kirin, run custom analysis and rewrites, and generate code from the dialects (back to Stim in this case). 17 | - [bloqade.qBraid](https://github.com/QuEraComputing/bloqade/blob/main/src/bloqade/qbraid/lowering.py): This example demonstrates how to lower from an existing representation into the Kirin IR by using the visitor pattern. 18 | - [bloqade.analysis](https://github.com/QuEraComputing/bloqade/tree/main/src/bloqade/analysis/): This directory contains examples of how to write custom analysis passes using Kirin for quantum computing. 19 | --------------------------------------------------------------------------------