├── demo ├── module │ ├── parent.rc │ ├── main │ │ ├── brother.rc │ │ └── normal.rc │ └── submodule │ │ └── otherlib.rc ├── base │ ├── simple.rc │ ├── print.rc │ ├── array.rc │ ├── binary.rc │ ├── call.rc │ └── fibonacci.rc ├── hello_world.rc ├── opt │ ├── constant_folding.rc │ └── cse.rc ├── template │ ├── fn.rc │ └── class.rc ├── lib │ └── file.rc ├── control │ ├── for.rc │ ├── if.rc │ ├── while.rc │ └── nested_while.rc └── oop │ ├── rc_object.rc │ ├── make_object.rc │ ├── oop.rc │ └── tree.rc ├── project ├── build.properties └── plugins.sbt ├── src ├── main │ └── scala │ │ ├── tools │ │ ├── Traits.scala │ │ ├── Debugger.scala │ │ ├── DumpManager.scala │ │ ├── State.scala │ │ ├── Ext.scala │ │ ├── Render.scala │ │ ├── RcLogger.scala │ │ ├── Mangling.scala │ │ ├── NestScope.scala │ │ └── SymTable.scala │ │ ├── Interface │ │ ├── RCI.scala │ │ └── RCC.scala │ │ ├── compiler │ │ ├── CompileOption.scala │ │ ├── DependencyGraph.scala │ │ └── Driver.scala │ │ ├── lib │ │ ├── Libc.scala │ │ ├── RcObject.rc │ │ └── StdLib.rc │ │ ├── ast │ │ ├── TyInfo.scala │ │ ├── Render.scala │ │ ├── Stmt.scala │ │ ├── AST.scala │ │ ├── Module.scala │ │ ├── ASTBuilder.scala │ │ ├── Expr.scala │ │ └── ASTVisitor.scala │ │ ├── analysis │ │ ├── DomTreeAnalysis.scala │ │ ├── Result.scala │ │ ├── DomFrontierAnalysis.scala │ │ ├── LoopAnalysis.scala │ │ ├── SymScanner.scala │ │ └── ModuleValidate.scala │ │ ├── pass │ │ ├── Pass.scala │ │ ├── PassManager.scala │ │ └── AnalysisManager.scala │ │ ├── parser │ │ ├── RcParser.scala │ │ ├── ModuleParser.scala │ │ ├── RcBaseParser.scala │ │ └── ExprParser.scala │ │ ├── CompileError.scala │ │ ├── codegen │ │ ├── Target.scala │ │ ├── MachineIRPrinter.scala │ │ ├── MachineIRBuilder.scala │ │ ├── PhiEliminate.scala │ │ ├── RegisterAllocation.scala │ │ ├── ASMTrait.scala │ │ ├── MachineFrameInfo.scala │ │ ├── IR.scala │ │ └── GNUASM.scala │ │ ├── Def.scala │ │ ├── transform │ │ ├── CFGSimplify.scala │ │ ├── CSE.scala │ │ ├── CallInliner.scala │ │ └── 
ConstantFolding.scala │ │ ├── mir │ │ ├── Intrinsic.scala │ │ ├── Render.scala │ │ ├── DFCalculator.scala │ │ ├── Value.scala │ │ ├── IRBuilder.scala │ │ ├── IR.scala │ │ ├── CFG.scala │ │ ├── InstVisitor.scala │ │ ├── DomTree.scala │ │ └── Instruction.scala │ │ ├── interpreter │ │ ├── interpreter.scala │ │ └── evaluator.scala │ │ ├── lexer │ │ ├── Token.scala │ │ └── Lexer.scala │ │ ├── ty │ │ ├── Check.scala │ │ ├── Type.scala │ │ ├── TyCtxt.scala │ │ ├── Translator.scala │ │ └── Infer.scala │ │ └── graphviz │ │ └── Backend.scala └── test │ └── scala │ ├── transform │ ├── CSETest.scala │ ├── ConstantFoldingTest.scala │ └── CallInlinerTest.scala │ ├── ast │ └── ExprTest.scala │ ├── tools │ └── StateTest.scala │ ├── mir │ ├── CFGTest.scala │ ├── MIRTestUtil.scala │ └── DFCalculatorTest.scala │ ├── RcTestBase.scala │ ├── integrated │ └── DemoCompileTest.scala │ ├── ty │ ├── TypedTranslatorTest.scala │ └── TyCtxtTest.scala │ ├── codegen │ ├── IRTranslatorTest.scala │ └── CodegenIRTest.scala │ ├── analysis │ ├── LoopAnalysisTest.scala │ ├── SymScannerTest.scala │ ├── ModuleValidateTest.scala │ └── DomTreeAnalysisTest.scala │ ├── parser │ ├── StmtParserTest.scala │ ├── BaseParserTest.scala │ ├── ModuleParserTest.scala │ └── ExprParserTest.scala │ └── lexer │ └── LexerTest.scala ├── .gitignore └── .github └── workflows └── test.yml /demo/module/parent.rc: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/module/main/brother.rc: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo/module/submodule/otherlib.rc: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /project/build.properties: 
-------------------------------------------------------------------------------- 1 | sbt.version = 1.6.2 2 | -------------------------------------------------------------------------------- /demo/base/simple.rc: -------------------------------------------------------------------------------- 1 | def add(a: Int, b: Int) 2 | a + b 3 | end -------------------------------------------------------------------------------- /demo/hello_world.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | print("Hello world") 3 | end -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.jetbrains.scala" % "sbt-ide-settings" % "1.1.1") 2 | -------------------------------------------------------------------------------- /demo/opt/constant_folding.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | val a = 1 + 1 3 | val b = a + 1 4 | end 5 | -------------------------------------------------------------------------------- /demo/template/fn.rc: -------------------------------------------------------------------------------- 1 | def foo(): Int 2 | 3 | end 4 | 5 | def main() 6 | foo(1, 2) 7 | end -------------------------------------------------------------------------------- /demo/base/print.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | val a = 1 3 | val b = a + 2 4 | print("%d %d", a, b) 5 | end -------------------------------------------------------------------------------- /demo/opt/cse.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | val a = 2 3 | val b = a + 1 4 | val c = a + 1 5 | val d = a + 2 6 | end -------------------------------------------------------------------------------- /demo/base/array.rc: 
-------------------------------------------------------------------------------- 1 | def main() 2 | var n = 5 3 | var a = Int[5] { 1, 2, n } 4 | print("a[2] = %d", a[2]) 5 | end -------------------------------------------------------------------------------- /demo/base/binary.rc: -------------------------------------------------------------------------------- 1 | def binary() 2 | val a = 1 3 | val b = a + 1 4 | val c = a + b 5 | val d = a + c 6 | end 7 | 8 | -------------------------------------------------------------------------------- /demo/lib/file.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | var f = File.new() 3 | f.open("test.txt", "w+") 4 | f.write("str") 5 | f.close() 6 | end -------------------------------------------------------------------------------- /src/main/scala/tools/Traits.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | trait In[T] { 5 | var parent: T = null.asInstanceOf[T] 6 | } -------------------------------------------------------------------------------- /demo/control/for.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | # a = a + 1 is error 3 | for(val a = 1; a < 5; a = a + 1) 4 | print("a:%d", a) 5 | end 6 | end -------------------------------------------------------------------------------- /src/main/scala/Interface/RCI.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package Interface 3 | 4 | import interpreter.Interpreter 5 | 6 | @main def rci = Interpreter().run -------------------------------------------------------------------------------- /demo/module/main/normal.rc: -------------------------------------------------------------------------------- 1 | import "brother" 2 | import "../parent" 3 | import "../submodule/otherlib" 4 | 5 | def main() 6 | print("hello, world.") 7 | end 
-------------------------------------------------------------------------------- /demo/template/class.rc: -------------------------------------------------------------------------------- 1 | class TreeNode 2 | var v: T 3 | end 4 | 5 | def main() 6 | val node = TreeNode.new(5) 7 | print("node: %d", node.v) 8 | end -------------------------------------------------------------------------------- /src/main/scala/compiler/CompileOption.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package compiler 3 | 4 | case class CompileOption(srcPath: List[String] = List(), outPath: String = "") 5 | -------------------------------------------------------------------------------- /demo/control/if.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | var a = 2 3 | print("a:%d", a) 4 | if a > 1 5 | print("a>1") 6 | else 7 | print("a<1") 8 | end 9 | print("end") 10 | end -------------------------------------------------------------------------------- /src/main/scala/lib/Libc.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package lib 3 | 4 | object Libc { 5 | val methods = List("printf", "open", "malloc", "fopen", "fgets", "fputs", "system") 6 | } 7 | -------------------------------------------------------------------------------- /demo/base/call.rc: -------------------------------------------------------------------------------- 1 | def add(a: Int, b: Int) 2 | a + b 3 | end 4 | 5 | def main() 6 | var a = 1 7 | var b = 2 8 | val c = add(a, b) 9 | print("c:%d", c) 10 | end -------------------------------------------------------------------------------- /demo/control/while.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | val a = 1 3 | while(a < 5) 4 | a = a + 1 5 | print("int while:%d\n", a) 6 | end 7 | val b = 2 8 | print("b=%d", b) 9 | end 
-------------------------------------------------------------------------------- /src/main/scala/lib/RcObject.rc: -------------------------------------------------------------------------------- 1 | class RcObject 2 | def new() 3 | var a = malloc() 4 | init() 5 | a 6 | end 7 | 8 | def init() 9 | 10 | end 11 | end -------------------------------------------------------------------------------- /demo/oop/rc_object.rc: -------------------------------------------------------------------------------- 1 | class RcObject 2 | def new() 3 | var a = malloc() 4 | init() 5 | a 6 | end 7 | 8 | def init() 9 | var a = 1 10 | end 11 | end -------------------------------------------------------------------------------- /src/main/scala/ast/TyInfo.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ast 3 | 4 | import scala.util.parsing.input.Positional 5 | 6 | enum TyInfo extends ASTNode: 7 | case Spec(ty: Ident) 8 | case Infer 9 | case Nil -------------------------------------------------------------------------------- /demo/control/nested_while.rc: -------------------------------------------------------------------------------- 1 | def main() 2 | val a = 1 3 | while(a < 5) 4 | val b = 2 5 | while(b < 5) 6 | b = b + 1 7 | end 8 | a = a + 1 9 | end 10 | val c = 3 11 | end -------------------------------------------------------------------------------- /demo/base/fibonacci.rc: -------------------------------------------------------------------------------- 1 | def f(x: Int): Int 2 | if x < 3 3 | x 4 | else 5 | f(x-1) + f(x-2) 6 | end 7 | end 8 | 9 | def main() 10 | print("fibonacci(2) result is:%d\n", f(2)) 11 | print("fibonacci(6) result is:%d\n", f(4)) 12 | end -------------------------------------------------------------------------------- /demo/oop/make_object.rc: -------------------------------------------------------------------------------- 1 | class RcObject 2 | def new() 3 | var a = malloc() 4 | print("new 
a%p", a) 5 | init() 6 | a 7 | end 8 | 9 | def init() 10 | print("init") 11 | end 12 | end 13 | 14 | def main() 15 | var obj = RcObject.new() 16 | end -------------------------------------------------------------------------------- /src/main/scala/tools/Debugger.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | object Debugger { 5 | def check(valid : => Boolean, msg: String) = { 6 | if(!valid) { 7 | throw new RuntimeException(msg) 8 | } 9 | } 10 | 11 | def unImpl[T](v: T) = { 12 | println(v.getClass) 13 | ??? 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/analysis/DomTreeAnalysis.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package analysis 3 | 4 | import mir.* 5 | import pass.{Analysis, AnalysisManager} 6 | 7 | case class DomTreeAnalysis() extends Analysis[Function] { 8 | type ResultT = DomTree 9 | def run(irUnit: Function, AM: AnalysisManager[Function]): ResultT = { 10 | DomTreeBuilder().compute(irUnit) 11 | } 12 | } -------------------------------------------------------------------------------- /src/main/scala/pass/Pass.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package pass 3 | 4 | trait Pass[IRUnitT] { 5 | } 6 | 7 | trait Transform[IRUnitT] extends Pass[IRUnitT] { 8 | def run(iRUnitT: IRUnitT, AM: AnalysisManager[IRUnitT]): Unit 9 | } 10 | 11 | trait Analysis[IRUnitT] extends Pass[IRUnitT] { 12 | type ResultT 13 | def run(irUnit: IRUnitT, AM: AnalysisManager[IRUnitT]): ResultT 14 | } -------------------------------------------------------------------------------- /src/main/scala/analysis/Result.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package analysis 3 | 4 | import mir.* 5 | 6 | object Analysis: 7 | given 
DomTreeAnalysis with { 8 | type ResultT = DomTree 9 | } 10 | 11 | given LoopAnalysis with { 12 | type ResultT = LoopInfo 13 | } 14 | 15 | given DomFrontierAnalysis with { 16 | type ResultT = Map[DomTreeNode, Set[DomTreeNode]] 17 | } 18 | 19 | -------------------------------------------------------------------------------- /src/main/scala/parser/RcParser.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package parser 3 | 4 | import ast.{ASTNode, Modules, RcModule} 5 | import lexer.Token 6 | 7 | object RcParser extends ModuleParser { 8 | def apply(tokens: Seq[Token]): Either[RcParserError, RcModule] = { 9 | doParser(tokens, program) 10 | } 11 | 12 | def program: Parser[RcModule] = positioned { 13 | phrase(log(module)("module")) 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/tools/DumpManager.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | import java.nio.file.{Files, Paths} 5 | 6 | object DumpManager { 7 | private var dumpRoot = "RcDump" 8 | def mkDumpRootDir = { 9 | Files.createDirectories(Paths.get(dumpRoot)) 10 | } 11 | 12 | def setDumpRoot(path: String) = { 13 | dumpRoot = path 14 | mkDumpRootDir 15 | } 16 | 17 | def getDumpRoot = dumpRoot 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/CompileError.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | 3 | sealed trait RcCompilationError 4 | 5 | case class RcLexerError(location: Location, msg: String) extends RcCompilationError 6 | case class RcParserError(location: Location, msg: String) extends RcCompilationError 7 | case class RcNotSupported(location: Location, msg: String) extends RcCompilationError 8 | 9 | case class Location(line: Int, column: Int) { 10 | override def toString = 
s"$line:$column" 11 | } -------------------------------------------------------------------------------- /src/main/scala/codegen/Target.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package codegen 3 | 4 | abstract class TargetMachine() { 5 | val cpu: String 6 | val regInfos: List[RegInfo] 7 | val callingConvention: CallingConvention 8 | val wordSize: Int 9 | val gregCount: Int 10 | val asmEmiter: ASMEmiter 11 | } 12 | 13 | enum CallingConvention: 14 | case x86 15 | case x86_64 16 | 17 | case class RegInfo(name: String, asmName: String, alias: Set[String], id: Int, bit: Int = 4) -------------------------------------------------------------------------------- /src/main/scala/ast/Render.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ast 3 | 4 | import tools.Render 5 | 6 | class ClassesRender extends Render { 7 | def rendClasses(fileName: String, directory: String, methods: List[Class]): Unit = { 8 | rend(fileName, directory, methods) { (dot, klass) => 9 | dot.node(klass.name.str) 10 | klass.parent match 11 | case Some(parent) => dot.edge(parent.str, klass.name.str) 12 | case _ => 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /demo/oop/oop.rc: -------------------------------------------------------------------------------- 1 | class RcObject 2 | def init() 3 | 4 | end 5 | 6 | def new() 7 | var a = malloc() 8 | this.init() 9 | a 10 | end 11 | end 12 | 13 | class Point < RcObject 14 | var x:Int 15 | var y:Int 16 | 17 | def move_x(offset: Int): Int 18 | print("x: %d\n", x) 19 | print("offset: %d\n", offset) 20 | x = x + offset 21 | end 22 | end 23 | 24 | def main() 25 | var p = Point.new() 26 | p.move_x(5) 27 | print("a.x:%d\n", p.x) 28 | end -------------------------------------------------------------------------------- /src/main/scala/Def.scala: 
-------------------------------------------------------------------------------- 1 | package rclang 2 | 3 | import ast.* 4 | import ast.ImplicitConversions.* 5 | 6 | object Def { 7 | val Kernel = "Kernel" 8 | val DefaultModule = Kernel 9 | val self = "this" 10 | 11 | // return ptr type 12 | val NewMethod = Method(MethodDecl("new", Params(List()), TyInfo.Nil), Expr.Block(List())) 13 | val RcObject = Class("RcObject", None, List(), List()) 14 | // def selfObj(klass: String) = Param(self, TyInfo.Spec(klass)) 15 | 16 | val version = "RCC: 0.0.1" 17 | } 18 | -------------------------------------------------------------------------------- /src/main/scala/ast/Stmt.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ast 3 | 4 | import ast.Expr.Block 5 | import ast.Ident 6 | import ty.Typed 7 | import scala.util.parsing.input.Positional 8 | 9 | enum Stmt extends ASTNode with Typed: 10 | case Local(name: Ident, tyInfo: TyInfo, value: ast.Expr) 11 | case Expr(expr: ast.Expr) 12 | case While(cond: ast.Expr, body: Block) 13 | case For(init: Stmt, cond: ast.Expr, incr: Stmt, body: Block) 14 | case Assign(name: Ident, value: ast.Expr) 15 | case Break() 16 | case Continue() 17 | -------------------------------------------------------------------------------- /src/main/scala/tools/State.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | class State[T](var value: T) { 5 | 6 | /** Temporarily switch to `newState`, run `f`, then put the previous value back. 7 | * @param newState value installed while `f` runs 8 | * @param f body executed under the temporary state 9 | * @return the value held just before the restore (so changes made by `f` are visible); the previous value is restored even when `f` throws 10 | */ 11 | def by(newState: T)(f:() => Unit) = { 12 | val oldValue = value 13 | value = newState 14 | try { 15 | f() 16 | value 17 | } finally value = oldValue 18 | } 19 | } 20 | 21 | implicit def toState[T](v: T): State[T] = { 22 | new State(v) 23 | } 24 | -------------------------------------------------------------------------------- 
/src/test/scala/transform/CSETest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package transform 3 | 4 | import pass.{AnalysisManager, PassManager} 5 | import mir.* 6 | import tools.RcLogger.* 7 | 8 | class CSETest extends RcTestBase { 9 | describe("simple cse") { 10 | it("should run") { 11 | val fn = getOptDemoFirstFn("cse.rc") 12 | CSE().run(fn, AnalysisManager()) 13 | assert(fn.instructions.count(_.isInstanceOf[Binary]) == 2) 14 | assert(fn.instructions.count(_.isInstanceOf[Load]) == 1) 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/scala/transform/CFGSimplify.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package transform 3 | 4 | import mir.* 5 | import pass.{AnalysisManager, Transform} 6 | 7 | def removeUnreachedBasicBlock(IRUnit: Function): Unit = { 8 | val bbs = dfsBasicBlocks(IRUnit.entry).toSet 9 | val newBBs = IRUnit.bbs.filter(bbs.contains) 10 | IRUnit.bbs = newBBs 11 | } 12 | 13 | class CFGSimplify() extends Transform[Function] { 14 | override def run(IRUnit: Function, AM: AnalysisManager[Function]): Unit = { 15 | removeUnreachedBasicBlock(IRUnit) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.class 3 | *.log 4 | *~ 5 | 6 | # sbt specific 7 | dist/* 8 | target/ 9 | lib_managed/ 10 | src_managed/ 11 | project/boot/ 12 | project/plugins/project/ 13 | project/local-plugins.sbt 14 | .history 15 | .bsp 16 | 17 | # Scala-IDE specific 18 | .scala_dependencies 19 | .cache 20 | .classpath 21 | .project 22 | .settings 23 | classes/ 24 | 25 | # idea 26 | .idea 27 | .idea_modules 28 | /.worksheet/ 29 | 30 | # Dotty-IDE 31 | .dotty-ide-artifact 32 | .dotty-ide.json 33 | 34 | # Visual Studio Code 35 | 
.vscode 36 | 37 | # dump files 38 | RcDump 39 | RcTestDump -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: compiler-test 2 | 3 | on: [push, pull_request] 4 | 5 | permissions: 6 | contents: read 7 | 8 | jobs: 9 | build: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-18.04,windows-2022,macos-10.15] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up JDK 11 18 | uses: actions/setup-java@v3 19 | with: 20 | java-version: '11' 21 | distribution: 'temurin' 22 | cache: 'sbt' 23 | - name: Run tests 24 | run: sbt test 25 | -------------------------------------------------------------------------------- /src/main/scala/mir/Intrinsic.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package mir 3 | import ty.{NilType, PointerType} 4 | import tools.Debugger.*; 5 | 6 | case class Print(arg_list: List[Value]) extends Intrinsic("print", arg_list) { 7 | ty = NilType 8 | } 9 | 10 | case class Open(arg_list: List[Value]) extends Intrinsic("open", arg_list) { 11 | ty = NilType 12 | } 13 | 14 | case class Malloc(arg_list: List[Value]) extends Intrinsic("malloc", arg_list) { 15 | check(arg_list.size == 1, s"malloc arg size should be 1, but get ${arg_list.size}") 16 | ty = PointerType(NilType) 17 | } 18 | 19 | val intrinsics = List("print", "open", "malloc") -------------------------------------------------------------------------------- /src/main/scala/ast/AST.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ast 3 | 4 | import scala.language.implicitConversions 5 | import scala.util.parsing.input.Positional 6 | 7 | case class Ident(str: String) extends ASTNode 8 | 9 | object ImplicitConversions { 10 | implicit def strToId(str: String): Ident = Ident(str) 11 | 
implicit def IdToStr(id: Ident): String = id.str 12 | implicit def boolToAST(b: Boolean): Expr.Bool = Expr.Bool(b) 13 | implicit def intToAST(i: Int): Expr.Number = Expr.Number(i) 14 | } 15 | 16 | trait ASTNode extends Positional 17 | 18 | case class Modules(modules: List[RcModule]) extends ASTNode 19 | object Empty extends ASTNode -------------------------------------------------------------------------------- /src/test/scala/ast/ExprTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ast 3 | 4 | import ast.ImplicitConversions.* 5 | import ast.Expr.* 6 | 7 | import rclang.mir.BasicBlock 8 | 9 | class ExprTest extends RcTestBase with ASTBuilder { 10 | describe("noCapturedLambdaToMethod") { 11 | it("ok") { 12 | val a = BasicBlock("a") 13 | val b = BasicBlock("a") 14 | println(a eq b) 15 | println(a == b) 16 | val lambda = Lambda( 17 | Params(List(Param("a", TyInfo.Spec("Int")))), 18 | Block(List(Stmt.Expr(Expr.Binary(BinaryOp.Add, Expr.Identifier("a"), Expr.Number(1)))))) 19 | val m = lambdaToMethod(lambda.asInstanceOf[Expr.Lambda]) 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/test/scala/tools/StateTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | import org.scalatest.funspec.AnyFunSpec 5 | 6 | class StateTest extends AnyFunSpec { 7 | 8 | describe("StateTest") { 9 | it("NormalBy") { 10 | val s = State("Str") 11 | val result = s.by("tmp") { () => 12 | assert(s.value == "tmp") 13 | } 14 | assert(result == "tmp") 15 | assert(s.value == "Str") 16 | } 17 | 18 | // modify 19 | it("ModifyInBody") { 20 | val s = State("Str") 21 | val result = s.by("tmp") { () => 22 | s.value = "what" 23 | } 24 | assert(result == "what") 25 | assert(s.value == "Str") 26 | } 27 | } 28 | } 29 | 
-------------------------------------------------------------------------------- /src/test/scala/transform/ConstantFoldingTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package transform 3 | 4 | import pass.{AnalysisManager, PassManager} 5 | import mir.* 6 | import tools.RcLogger.* 7 | 8 | class ConstantFoldingTest extends RcTestBase { 9 | 10 | describe("ConstantFoldingTest") { 11 | it("should run") { 12 | val fn = getOptDemoFirstFn("constant_folding.rc") 13 | ConstantFolding().run(fn, AnalysisManager()) 14 | // alloc + store + return 15 | assert(!fn.instructions.exists(_.isInstanceOf[Binary])) 16 | assert(fn.instructions.size == 3) 17 | assert(fn.instructions.takeRight(2).head.getOperand(0).asInstanceOf[Integer].value == 3) 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/tools/Ext.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | import java.io.File 5 | 6 | extension [T](x: T) { 7 | def tap(f: T => Unit): T = { f(x); x } 8 | } 9 | 10 | def run[TL, TR](result: => Either[TL, TR]): TR = { 11 | result match { 12 | case Left(l) => throw new RuntimeException(l.toString) 13 | case Right(r) => r 14 | } 15 | } 16 | 17 | extension [TL, TR](result: => Either[TL, TR]) { 18 | def unwrap: TR = { 19 | result match { 20 | case Left(l) => throw new RuntimeException(l.toString) 21 | case Right(r) => r 22 | } 23 | } 24 | } 25 | 26 | extension (dir: String) { 27 | def /(file: String): String = { 28 | s"$dir${File.separator}$file" 29 | } 30 | } -------------------------------------------------------------------------------- /src/main/scala/pass/PassManager.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package pass 3 | 4 | case class PassManager[IRUnitT]() { 5 | type CallbackT = (IRUnitT, 
Transform[IRUnitT]) => Unit 6 | 7 | def run(IRUnit: IRUnitT, am: AnalysisManager[IRUnitT]) = { 8 | passes.foreach(pass => { 9 | pass.run(IRUnit, am) 10 | callbacksAfterPass.foreach(f => f(IRUnit, pass)) 11 | }) 12 | } 13 | 14 | def addPass(pass: Transform[IRUnitT]): Unit = { 15 | passes = passes :+ pass 16 | } 17 | 18 | def registerAfterPass(callback: CallbackT) = { 19 | callbacksAfterPass = callbacksAfterPass :+ callback 20 | } 21 | 22 | var passes = List[Transform[IRUnitT]]() 23 | var callbacksAfterPass = List[CallbackT]() 24 | } 25 | -------------------------------------------------------------------------------- /demo/oop/tree.rc: -------------------------------------------------------------------------------- 1 | class RcObject 2 | def init() 3 | 4 | end 5 | 6 | def new() 7 | var a = malloc() 8 | this.init() 9 | a 10 | end 11 | end 12 | 13 | class TreeNode 14 | var lhs: TreeNode 15 | var rhs: TreeNode 16 | var value: Int 17 | end 18 | # mirror a tree: rebuild every node with its flipped children swapped (assumed meaning of "flip" - confirm) 19 | def flip(node: TreeNode) 20 | if node == nil 21 | return nil 22 | end 23 | var newLhs = flip(node.lhs) 24 | var newRhs = flip(node.rhs) 25 | TreeNode.new(node.value, newRhs, newLhs) 26 | end 27 | 28 | def main() 29 | var lhs = TreeNode.new(3, nil, nil) 30 | var rhs = TreeNode.new(4, nil, nil) 31 | var root = TreeNode.new(5, lhs, rhs) 32 | var new_root = flip(root) 33 | printf("new root = %d", new_root.value) 34 | printf("new lhs = %d", new_root.lhs.value) 35 | printf("new rhs = %d", new_root.rhs.value) 36 | end -------------------------------------------------------------------------------- /src/main/scala/tools/Render.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | trait Render { 5 | def rendInit(dot: Digraph) = {} 6 | def rendFinal(dot: Digraph) = {} 7 | 8 | protected def rendDotImpl(fileName: String, directory: String)(f: Digraph => Unit): Unit = { 9 | val dot = new Digraph() 10 | rendInit(dot) 11 | f(dot) 12 | rendFinal(dot) 13 | 
dot.render(fileName = fileName, directory = directory, format = "svg") 14 | } 15 | 16 | protected def rend[T](fileName: String, directory: String, v: T)(f: (Digraph, T) => Unit): Unit = { 17 | rendDotImpl(fileName, directory) { dot => 18 | f(dot, v) 19 | } 20 | } 21 | 22 | protected def rend[T](fileName: String, directory: String, list: List[T])(f: (Digraph, T) => Unit): Unit = { 23 | rendDotImpl(fileName, directory) { dot => 24 | list.foreach(v => f(dot, v)) 25 | } 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/scala/pass/AnalysisManager.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package pass 3 | 4 | import analysis.* 5 | import mir.DomTree 6 | import scala.language.implicitConversions 7 | 8 | case class AnalysisManager[IRUnitT]() { 9 | def getResult[AnalysisT <: Analysis[IRUnitT]](IRUnit: IRUnitT)(using analysis: AnalysisT): analysis.ResultT = { 10 | analysis.run(IRUnit, this) 11 | } 12 | 13 | def addAnalysis(analysis: => Analysis[IRUnitT]): Unit = { 14 | val name = analysis.getClass.getTypeName 15 | if (!analyses.contains(name)) { 16 | analyses += (name -> analysis) 17 | } 18 | } 19 | 20 | var analyses = Map[String, Analysis[IRUnitT]]() 21 | } 22 | 23 | def getAnalysisResult[IRUnitT, AnalysisT <: Analysis[IRUnitT]](IRUnit: IRUnitT)(using analysis: AnalysisT): analysis.ResultT = { 24 | val am = AnalysisManager[IRUnitT]() 25 | am.addAnalysis(analysis) 26 | am.getResult(IRUnit) 27 | } -------------------------------------------------------------------------------- /src/main/scala/interpreter/interpreter.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package interpreter 3 | 4 | import scala.io.StdIn.readLine 5 | import ast.* 6 | import lexer.Lexer 7 | import tools.{RcLogger, unwrap} 8 | import parser.{RcExprParser, RcParser} 9 | 10 | case class Interpreter() { 11 | val prompt: 
String = "rc> " 12 | var evaluator = Evaluator() 13 | /** Read-eval-print loop: prompts, interprets each non-empty line, and stops at end of input. */ 14 | def run = { 15 | var isRunning = true 16 | while (isRunning) { 17 | print(prompt) 18 | // readLine yields null at end of input, so test it before appending "\n" 19 | val input = readLine 20 | if(input == null) { 21 | isRunning = false 22 | } 23 | else if(input.nonEmpty) { 24 | val result = interpret(input + "\n") 25 | println(result) 26 | } 27 | } 28 | } 29 | 30 | def interpret(str: String): Any = { 31 | val tokens = Lexer(str).unwrap 32 | println(tokens) 33 | val ast = RcExprParser(tokens).unwrap 34 | evaluator.run_stmt(ast) 35 | } 36 | } -------------------------------------------------------------------------------- /src/main/scala/mir/Render.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package mir 3 | 4 | import tools.{DumpManager, Render} 5 | 6 | object CFGRender extends Render { 7 | // var blocksEntryName = "blocksEntry" 8 | def BBEdges(dot: Digraph, bb: BasicBlock) = { 9 | val name = bb.name 10 | dot.node(name) 11 | dot.edges(bb.terminator.successors.map(b => (name, b.name)).toArray) 12 | } 13 | 14 | def rendBBs(bbs: List[BasicBlock], fileName: String, directory: String = DumpManager.getDumpRoot): Unit = { 15 | rend(fileName, directory, bbs)(BBEdges) 16 | } 17 | 18 | def rendFn(fn: Function, fileName: String, directory: String = DumpManager.getDumpRoot): Unit = { 19 | // blocksEntryName = fn.entry.name 20 | rendBBs(fn.bbs, fileName, directory) 21 | } 22 | 23 | override def rendInit(dot: Digraph): Unit = { 24 | // dot.node("entry", null, collection.mutable.Map("URL" -> "\"https://www.google.com\"")) 25 | // dot.edge("entry", blocksEntryName) 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/scala/codegen/MachineIRPrinter.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package codegen 3 | 4 | import java.io.PrintWriter 5 | 6 | class MachineIRPrinter { 7 | def print(mfs: List[MachineFunction]): Unit 
= mfs.foreach(print) 8 | 9 | def printToWriter(mf: MachineFunction, writer: PrintWriter): Unit = { 10 | writer.write(toStr(mf)) 11 | } 12 | 13 | def print(mf: MachineFunction): Unit = { 14 | val content = toStr(mf) 15 | println(content) 16 | } 17 | 18 | private def toStr(mf: MachineFunction): String = { 19 | (List(mf.name):::mf.bbs.flatMap(toStr)).map(_ + "\n").mkString 20 | } 21 | 22 | private def toStr(mbb: MachineBasicBlock): List[String] = { 23 | List(mbb.name):::mbb.instList.map(toStr) 24 | } 25 | 26 | private def toStr(inst: MachineInstruction): String = { 27 | if(inst.origin == null) { 28 | printf("") 29 | } 30 | s"${inst.getClass.getName.split('.').last} ${inst.operands} #${inst.origin.pos}" 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/test/scala/transform/CallInlinerTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package transform 3 | 4 | import mir.* 5 | 6 | import rclang.pass.AnalysisManager 7 | 8 | class CallInlinerTest extends RcTestBase { 9 | describe("simple") { 10 | it("succ") { 11 | val args = List(Argument("a", ty.Int32Type), Argument("b", ty.Int32Type)) 12 | val bn = new Binary("Add", args(0), args(1)) 13 | val bb = new BasicBlock("b1", List(bn, new Return(bn))) 14 | val f = Function("add", ty.Int32Type, args, bb, List(bb)) 15 | 16 | val a = Integer(1) 17 | val b = Integer(2) 18 | val c = Call(f, List(a, b)) 19 | val d = new Binary("Add", Integer(3), c) 20 | val mainBB = new BasicBlock("mainBB", List(c, d)) 21 | val main = Function("main", ty.NilType, List(), mainBB, List(mainBB)) 22 | CallInliner().run(main, new AnalysisManager[Function]()) 23 | assert(main.bbs.length == 2) 24 | assert(!main.instructions.exists(x => x.isInstanceOf[Call])) 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/analysis/DomFrontierAnalysis.scala: 
--------------------------------------------------------------------------------
package rclang
package analysis

import mir.{DomTreeNode, Function}
import pass.{Analysis, AnalysisManager}
import rclang.analysis.Analysis.given_DomTreeAnalysis

/** Computes, for every dominator-tree node of a function, a set of frontier nodes. */
case class DomFrontierAnalysis() extends Analysis[Function] {
  override type ResultT = Map[DomTreeNode, Set[DomTreeNode]]

  /**
   * Starts with an empty set for every block, then for each node with more than one
   * predecessor walks from each predecessor up the immediate-dominator chain, adding
   * the node to every visited entry until the node's own immediate dominator is reached.
   */
  override def run(irUnit: Function, AM: AnalysisManager[Function]): ResultT = {
    val domTree = AM.getResult[DomTreeAnalysis](irUnit)
    var df = Map[DomTreeNode, Set[DomTreeNode]]()
    for (n <- irUnit.bbs) {
      df = df.updated(domTree(n), Set[DomTreeNode]())
    }

    for (n <- irUnit.bbs.map(domTree(_))) {
      // only join points (more than one predecessor) contribute
      if(n.preds.length > 1) {
        for (p <- n.preds.map(domTree(_))) {
          var runner = p
          while(runner != n.iDom) {
            df = df.updated(runner, df(runner) + n)
            runner = runner.iDom
          }
        }
      }
    }

    df
  }
}
--------------------------------------------------------------------------------
/src/test/scala/mir/CFGTest.scala:
--------------------------------------------------------------------------------
package rclang
package mir

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.BeforeAndAfter
import tools.DumpManager
import mir.*

class CFGTest extends RcTestBase {
  // diamond-shaped CFG: 1 -> {2, 3} -> 4, rebuilt before every test
  var bbs: BBsType = null
  before {
    bbs = mkBBs(
      "1" -> "2",
      "1" -> "3",
      "2" -> "4",
      "3" -> "4",
    )
  }

  describe("canReach") {
    it("success") {
      canReach(bbs("1"), bbs("2")) should be(true)
      canReach(bbs("1"), bbs("3")) should be(true)
      canReach(bbs("1"), bbs("4")) should be(true)
      canReach(bbs("4"), bbs("1")) should be(false)
    }
  }

  describe("predecessors") {
    it("succ") {
      predecessors(bbs("4"), bbs.values.toList) should be(Set(bbs("2"), bbs("3")))
    }
  }

  describe("dfsBasicBlocks") {
    it("succ") {
      dfsBasicBlocks(bbs("1")) should be(List(bbs("1"), bbs("2"), bbs("4"), bbs("3")))
    }
  }
}

--------------------------------------------------------------------------------
/src/test/scala/RcTestBase.scala:
--------------------------------------------------------------------------------
package rclang

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.should.Matchers

import java.io.File
import tools.{DumpManager, RcLogger}
import tools.RcLogger.{log, logf}
import compiler.Driver.*
import mir.MIRTranslator
import tools./

/** Shared base for tests: sets the dump root and offers helpers to lower a source file to MIR. */
class RcTestBase extends AnyFunSpec with BeforeAndAfter with Matchers {
  DumpManager.setDumpRoot("RcTestDump")

  /** Parses, type-processes and translates the source at `srcPath` into a MIR module. */
  def getModule(srcPath: String) = {
    val ast = parse(srcPath)
    val (typedModule, table) = typeProc(ast)
    val mirMod = log(MIRTranslator(table).proc(typedModule), "ToMIR")
    mirMod
  }

  /** First function in the module's function table; throws if the table is empty. */
  def getFirstFn(srcPath: String) = {
    val mirMod = getModule(srcPath)
    mirMod.fnTable.values.head
  }

  def getDemoFirstFn(name: String) = {
    getFirstFn("demo" / name)
  }

  def getOptDemoFirstFn(name: String) = {
    getFirstFn("demo" / "opt" / name)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/MachineIRBuilder.scala:
--------------------------------------------------------------------------------
package rclang
package codegen

/** Appends machine instructions to the current block `mbb`, which must be set before use. */
class MachineIRBuilder() {
  var mbb: MachineBasicBlock = null

  def insert(inst: MachineInstruction) = {
    mbb.insert(inst)
  }

  def insertLoadInst(dst: Dst, addr: Src) = insert(LoadInst(dst, addr))

  def insertStoreInst(addr: Dst, src: Src) = insert(StoreInst(addr, src))

  def insertCallInst(targetFn: String, dst: Dst, params: List[Src]) = insert(CallInst(targetFn, dst, params))

  def insertReturnInst(src: Src) = insert(ReturnInst(src))

  def insertBinaryInst(op: BinaryOperator, dst: Dst, lhs: Src, rhs: Src) = insert(BinaryInst(op, dst, lhs, rhs))

  def insertBranchInst(addr: Src) = insert(BranchInst(addr))

  def insertCondBrInst(cond: Src, addr: Src, condType: CondType) = insert(CondBrInst(cond, addr, condType))

  /** Misspelled legacy name kept for existing callers; prefer `insertCondBrInst`. */
  def insrtCondBrInst(cond: Src, addr: Src, condType: CondType) = insertCondBrInst(cond, addr, condType)

  def insertPhiInst(dst: Dst, incoming: Map[Src, MachineBasicBlock]) = insert(PhiInst(dst, incoming))

  def insertInlineASM(str: String) = insert(InlineASM(str))

  /** Misspelled legacy name kept for existing callers; prefer `insertInlineASM`. */
  def insrtInlineASM(str: String) = insertInlineASM(str)
}

--------------------------------------------------------------------------------
/src/main/scala/lib/StdLib.rc:
--------------------------------------------------------------------------------
class File
  def init

  end

  def open(val path: String, val mode: String)
    handle = fopen(path, mode)
  end

  def read()
    fgets(, n, handle)
  end

  def write(val content: String)
    fputs(content, handle)
  end

  def close()
    fclose(handle)
  end

  var handle: Handle
end

class String
  def split(val c: Char)

  end

  def concat(val str: String)

  end

  def length()

  end

  def indexAt(val i: Int)

  end
end

class List
  def add(val v: Int)

  end
end

class FileSystem
  def get_current_dir()

  end

  def create_dir(val path: String)

  end

  def exist(val path: String)

  end
end

class Time

end

class System
  def cmd(val str: String)
    system(str)
  end
end
--------------------------------------------------------------------------------
/src/main/scala/mir/DFCalculator.scala:
--------------------------------------------------------------------------------
package rclang
package mir

/** Finds join blocks related to a given block through the dominator tree. */
case class DFCalculator(var domTree: DomTree) {
  /**
   * Returns the nodes of `tree` whose block has more than one predecessor and has at
   * least one predecessor, other than the block itself, that `root` strictly dominates.
   *
   * Only `tree` is consulted (the original mixed `tree` and the `domTree` field), and
   * nothing is printed.
   */
  def findParents(tree: DomTree, root: DomTreeNode): Set[DomTreeNode] = {
    val predsOf = predecessorsMap(tree.nodes.keys.toList)
    // n -> x -> y
    // x sdom y
    // n dom y
    tree.nodes.filter { case (block, _) =>
      val preds = predsOf(block)
      // join node -> preds.size > 1
      preds.size > 1 && preds.exists(p => (root sdom tree(p)) && (p != block))
    }.values.toSet
  }

  /** Blocks selected by `findParents` for `bb`, using this calculator's `domTree`. */
  def run(bb: BasicBlock): List[BasicBlock] = {
    val parents = findParents(domTree, domTree(bb))
    parents.toList.map(_.basicBlock)
    // val result = parents.map(node => findParents(domTree, node)).reduce(_ | _)
    // result.map(_.basicBlock).toList
  }
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/PhiEliminate.scala:
--------------------------------------------------------------------------------
package rclang
package codegen

import pass.*

/** Replaces every phi instruction with stores placed in the incoming blocks. */
class PhiEliminate extends Transform[MachineFunction] {
  def run(fn: MachineFunction, am: AnalysisManager[MachineFunction]) = {
    fn.bbs.foreach(bb => {
      // 1. find all phi (snapshot taken before the block is mutated)
      val phis = bb.instList.collect { case phi: PhiInst => phi }
      // 2. replace with x0 = value
      phis.foreach(eliminate(_, bb))
    })
  }

  /**
   * Makes a copy for every incoming value: each incoming block stores its value into a
   * shared fresh register just before its last instruction, and `basicBlock` starts by
   * storing that register into the phi's destination. The phi is then removed.
   */
  def eliminate(phiInst: PhiInst, basicBlock: MachineBasicBlock) = {
    // fresh register numbered by the function's current instruction count
    val target = VReg(basicBlock.parent.instructions.length)
    phiInst.incomings.foreach((v, mbb) => {
      val parent = v.instParent
      val store = StoreInst(target.dup, v).setOrigin(parent.origin)
      // NOTE(review): assumes the incoming block is non-empty and ends with its
      // terminator, so the store lands in front of it — confirm against insertAt.
      mbb.insertAt(store, mbb.instList.length - 1)
    })
    val st = StoreInst(phiInst.dst, target.dup).setOrigin(phiInst.origin)
    basicBlock.insertAtFirst(st)
    // remove phi
    phiInst.removeFromParent()
  }
}

--------------------------------------------------------------------------------
/src/test/scala/integrated/DemoCompileTest.scala:
--------------------------------------------------------------------------------
package rclang
package integrated

import org.scalatest.Outcome
import rclang.compiler.Driver

import java.io.File

/**
 * Recursively collects every file under `dir` whose name ends in ".rc".
 *
 * Returns an empty list when `dir` cannot be listed (`listFiles` yields null for a
 * missing or unreadable directory). Files without an extension are skipped rather
 * than causing an index error.
 */
def listFiles(dir: File): List[File] = {
  val entries = Option(dir.listFiles).map(_.toList).getOrElse(Nil)
  entries.flatMap {
    case file if file.isDirectory => listFiles(file)
    case file if file.getName.endsWith(".rc") => List(file)
    case _ => Nil
  }
}

/** Compiles every demo source (except those under "template" or "lib" paths). */
class DemoCompileTest extends RcTestBase {
  val dir = new File(new File("").getAbsolutePath, "demo")
  val files = listFiles(dir).filter(f => !(f.getPath.contains("template") || f.getPath.contains("lib")))
  describe("run all") {
    it("ok") {
      // keep going after a failure so one run reports every broken demo, with its cause
      val failures = files.flatMap { file =>
        try {
          val ast = Driver.parse(file.getPath)
          Driver.compileAST(ast)
          None
        } catch {
          case scala.util.control.NonFatal(e) => Some(s"${file.getPath}: $e")
        }
      }
      assert(failures.isEmpty, failures.mkString("demos failed to compile:\n", "\n", ""))
    }
  }
}

--------------------------------------------------------------------------------
/src/main/scala/tools/RcLogger.scala:
--------------------------------------------------------------------------------
package rclang
package tools

import cats.effect.unsafe.implicits.global
import cats.effect.{IO, IOApp}
import io.odin.*
import cats.effect.IO
import cats.effect.unsafe.IORuntime

import java.io.{File, PrintWriter}


/** Console logging plus helpers that dump text into files under the dump root. */
object RcLogger {
  val logger: Logger[IO] = consoleLogger()

  /** Logs `str` at warn level (it was previously emitted at info level). */
  def warning(str: String): Unit = {
    logger.warn(str).unsafeRunSync()
  }

  /** Logs `str` at info level, blocking until the message is written. */
  def log(str: String): Unit = {
    logger.info(str).unsafeRunSync()
  }

  /** Returns `result` unchanged after logging "<stage> Finish". */
  def log[T](result: T, stage: String): T = {
    val r = result
    log(stage + " Finish")
    r
  }

  /** Writes `v.toString` to `path` under the dump root. */
  def logf[T](path: String, v: T): Unit = logf(path, v.toString)

  /** Writes `str` to `path` under the dump root. */
  def logf(path: String, str: String): Unit = {
    logf(path)(_.write(str))
  }

  /**
   * Opens `path` under the dump root and hands the writer to `f`.
   * The writer is closed even when `f` throws, so the file handle is never leaked.
   */
  def logf(path: String)(f: PrintWriter => Unit) = {
    val printer = new PrintWriter(new File(DumpManager.getDumpRoot, path))
    try f(printer)
    finally printer.close()
  }

  /** Runs `f` between a "begin" and an "end" marker line tagged with `prefix`. */
  def logSep(prefix: String, f: => Unit) = {
    log(s"------------$prefix begin------------")
    f
    log(s"------------$prefix end------------")
  }
  var level = 2
}
--------------------------------------------------------------------------------
/src/main/scala/lexer/Token.scala:
--------------------------------------------------------------------------------
package rclang
package lexer

import scala.util.parsing.input.Positional

/** Base of all lexer tokens; `Positional` lets each token carry a source position. */
trait Token extends Positional

/** Reserved words. */
enum Keyword extends Token:
  // local
  case VAR
  case VAL
  // method
  case DEF
  case RETURN
  case END
  // control flow
  case IF
  case THEN
  case ELSIF
  case ELSE
  case WHILE
  case FOR
  case BREAK
  case CONTINUE
  // class
  case CLASS
  case SUPER
  case SELF
  case METHODS
  case VARS
  // module
  case IMPORT

/** Punctuation, layout tokens and operators. */
enum Punctuation extends Token:
  case COMMENT
  case EOL
  case COMMA
  case EQL // =
  case SPACE
  case DOT
  case COLON
  case SEMICOLON
  case AT
  case OPERATOR(op: String)

/** Literal values. */
enum Literal extends Token:
  case NUMBER(int: Int)
  case STRING(str: String)
  case TRUE
  case FALSE

/** Bracketing tokens. */
enum Delimiter extends Token:
  case LEFT_PARENT_THESES
  case RIGHT_PARENT_THESES
  case LEFT_SQUARE
  case RIGHT_SQUARE
  case LEFT_BRACKET
  case RIGHT_BRACKET

/** Identifiers, in two token kinds. */
enum Ident extends Token:
  case IDENTIFIER(str: String)
  case UPPER_IDENTIFIER(str: String)
--------------------------------------------------------------------------------
/src/main/scala/Interface/RCC.scala:
--------------------------------------------------------------------------------
package rclang

import Interface.*
import compiler.CompileOption

import scala.util.CommandLineParser
import scopt.OParser

/** Parses the single command-line string handed to `rcc` into a [[CompileOption]]. */
given CommandLineParser.FromString[CompileOption] with {
  /**
   * Tokenises `s` on whitespace and feeds the tokens to the scopt parser.
   * Falls back to a default `CompileOption()` when parsing fails.
   */
  override def fromString(s: String): CompileOption = {
    val builder = OParser.builder[CompileOption]
    val parser = {
      import builder._
      OParser.sequence(
        programName("rcc"),
        head("rcc", "0.0.0"),
        arg[String]("...")
          .unbounded()
          .action((x, c) => c.copy(srcPath = List(x):::c.srcPath))
          .text("optional unbounded args"),
        // NOTE(review): no action is attached, so the value of -o is parsed and dropped.
        opt[String]('o', "output")
          .text("set output file name"),
        opt[Unit]('t', "target")
          .text("show available target")
          .action((_, c) => {
            println("only suppport GNU ASM")
            c
          }),
        help('h', "help").text("display available options"),
      )
    }
    // Split on runs of whitespace and drop empties: splitting on a single ' ' turned
    // repeated or leading spaces into empty arguments, which were taken as source paths.
    val args = s.split("\\s+").filter(_.nonEmpty).toSeq
    OParser.parse(parser, args, CompileOption()) match
      case Some(value) => value
      case None => CompileOption()
  }
}

@main def rcc(option: CompileOption) = compiler.Driver(option)
--------------------------------------------------------------------------------
/src/main/scala/ty/Check.scala:
--------------------------------------------------------------------------------
package rclang
package ty
import ast.Expr
import ast.Stmt
import ast.RcModule
import ast.Expr.*
import ast.TyInfo
import ty.*


/** Type-check entry point. `apply` is currently a no-op; most cases below are unimplemented. */
case object TypeCheck {
  def apply(module: RcModule): Unit = {
    // check(module)
  }

  def check(expr: Expr): Boolean = {
    expr match
      case Number(v) => true
      case Identifier(ident) => ???
      case Bool(b) => true
      // NOTE(review): compares the two operand expressions, not their types — confirm intent.
      case Binary(op, lhs, rhs) => lhs == rhs
      case Str(str) => true
      // NOTE(review): compares the condition expression with the name `Boolean`,
      // not with the condition's type — confirm intent.
      case If(cond, true_branch, false_branch) => cond == Boolean
      case Lambda(args, block) => ???
      case Call(target, args, _) => ???
      case MethodCall(obj, target, args) => ???
      case Block(stmts) => stmts.forall(stmt => check(stmt))
      case Return(expr) => ???
      case Field(expr, ident) => ???
      case Self => true
      case Symbol(ident, _) => ???
      case Index(expr, i) => ???
  }

  def check(stmt: Stmt): Boolean = {
    stmt match
      case Stmt.Local(name, tyInfo, value) => tyInfo != TyInfo.Nil && check(value)
      case Stmt.Expr(expr) => check(expr)
      case Stmt.While(cond, body) => ???
      case Stmt.Assign(name, value) => ???
  }
}

--------------------------------------------------------------------------------
/src/test/scala/mir/MIRTestUtil.scala:
--------------------------------------------------------------------------------
package rclang
package mir

import ty.NilType

import scala.collection.mutable
import scala.collection.mutable.LinkedHashSet
import analysis.Analysis.*
import rclang.compiler.*
import analysis.Analysis.given_LoopAnalysis
import pass.{Analysis, AnalysisManager, getAnalysisResult}

/** Test helpers that need a `BBsType` in scope; it must contain "entry" and "exit" blocks. */
trait MIRTestUtil {
  /** Builds a dominator tree over the given blocks, rooted at "entry". */
  def mkTree(using bbs: BBsType) = {
    val fn = Function("fn", NilType, List(), bbs.values.head, bbs.values.toList)
    val predMap = predecessorsMap(fn.bbs)
    val nodes = bbs("entry") :: bbs.values.toList ::: List(bbs("exit"))
    val builder = DomTreeBuilder()
    builder.compute(LinkedHashSet.from(nodes), predMap, bbs("entry"))
  }

  /** Lets tests refer to a block by its name. */
  implicit def strToBB(str: String)(using bbs: BBsType): BasicBlock = bbs(str)
}

object MIRTestUtil {
  /** Simplifies `fn` in place, then returns its loop analysis result. */
  def getLoopInfo(fn: Function): LoopInfo = {
    Driver.simplify(fn)
    getAnalysisResult[Function, analysis.LoopAnalysis](fn)
  }

  /** A loop made of freshly created blocks named `header` followed by `bbs`. */
  def mkLoop(header: String)(bbs: String*) = {
    Loop((header :: bbs.toList).map(BasicBlock(_)))
  }

  /** Wraps `bbs` in a function whose entry is the block named "entry". */
  def MakeBBsFunction(bbs: BBsType) = {
    val bbList = bbs.values.toList
    val fn = new Function("name", NilType, List(), bbs("entry"), bbList)
    fn
  }
}
--------------------------------------------------------------------------------
/src/main/scala/compiler/DependencyGraph.scala:
--------------------------------------------------------------------------------
package rclang
package compiler

import ast.RcModule
import tools.Render
import tools.DumpManager

/** Directed graph of module references: module name -> names it refers to, in insertion order. */
class DependencyGraph(val modules: Seq[RcModule]) extends Render {
  case class Node(module: String)

  var graph = Map[String, List[String]]()

  /** Appends `dst` to the references of `src`, creating the entry if needed. */
  def addEdge(src: String, dst: String): Unit = {
    graph = graph.updatedWith(src) {
      case Some(value) => Some(value :+ dst)
      case None => Some(List(dst))
    }
  }

  /** Not implemented yet: intended to detect circular references. */
  def checkCircle = {
    // iterate over every node
    // for each node, look for a path that starts there and leads back to itself
  }

  /** Renders the graph; names are wrapped in quotes before being passed to the renderer. */
  def rendGraph(fileName: String, directory: String) = {
    rend(fileName, directory, graph.toList) { (dot, relation) =>
      val (mod, refs) = relation
      dot.node(s"\"$mod\"")
      refs.foreach(ref => dot.edge(s"\"$mod\"", s"\"$ref\""))
    }
  }

  /** Modules indexed by name. */
  def moduleMap = modules.map(m => m.name -> m).toMap

  /** Every name referenced by any module. */
  def moduleSet = modules.flatMap(_.refs).toSet
}

/** Builds the dependency graph with one edge per reference of each module. */
def dependencyResolve(modules: Seq[RcModule]) = {
  val graph = DependencyGraph(modules)
  modules.foreach(m => {
    m.refs.foreach(ref => {
      graph.addEdge(m.name, ref)
    })
  })
  // graph.rendGraph("dependency", DumpManager.getDumpRoot)
  graph
}
--------------------------------------------------------------------------------
/src/main/scala/transform/CSE.scala:
--------------------------------------------------------------------------------
package rclang
package transform

import pass.{AnalysisManager, Transform}
import mir.*

/**
 * Common-subexpression elimination over one function: a repeated `Load` of the same
 * pointer, or a repeated `Binary` with the same operator and operands, is replaced by
 * the first occurrence and then erased.
 */
class CSE extends Transform[Function] {
  // (op, lhs, rhs) -> first instruction computing it
  var exprSet = Map[(String, Value, Value), Value]()
  // load ptr -> load inst
  var avaliableLoads = Map[Value, Value]();

  /**
   * Runs the elimination on `iRUnitT`.
   *
   * Both tables are cleared first: they are instance state, so without the reset a
   * second `run` on the same `CSE` instance could rewrite a function to use values
   * recorded from a previously processed one. They keep this run's entries afterwards.
   */
  override def run(iRUnitT: Function, AM: AnalysisManager[Function]): Unit = {
    exprSet = Map()
    avaliableLoads = Map()
    var workList = List[Instruction]();
    iRUnitT.instructions.foreach { inst => inst match
      case Load(ptr) => {
        // NOTE(review): entries are keyed by pointer only; nothing visible here drops an
        // entry when the memory is written between two loads — confirm this is safe.
        avaliableLoads.get(ptr) match
          case Some(value) => {
            inst.replaceAllUseWith(value)
            workList = inst :: workList
          }
          case None => {
            avaliableLoads = avaliableLoads.updated(ptr, inst)
          }
      }
      case bn @ Binary(op, lhs, rhs) => {
        val key = (op, lhs, rhs)
        exprSet.get(key) match
          case Some(value) => {
            inst.replaceAllUseWith(value)
            workList = inst :: workList
          }
          case None => {
            exprSet = exprSet.updated(key, inst)
          }
      }
      case _ =>
    }

    // print(exprSet)

    // erase only after the walk, so the instruction list is not mutated while iterating
    workList.foreach(_.eraseFromParent)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/ast/Module.scala:
--------------------------------------------------------------------------------
package rclang
package ast

import scala.util.parsing.input.Positional
import ast.Expr.Block
import ast.Ident

import ty.Typed

/** A source module: its top-level items, its name and the names of modules it references. */
case class RcModule(items: List[Item], name: String = "", refs: List[String] = List()) extends ASTNode {
  override def toString: String = s"RcModule:$name\n" + items.mkString("\n")

  /** First top-level method called `name`, if any. */
  def method(name: String): Option[Method] = items.collectFirst { case m: Method if m.decl.name.str == name => m }
}

sealed class Item extends ASTNode with Typed

case class Method(decl: MethodDecl, body: Block) extends Item {
  def name = decl.name
  override def toString: String = s"Method:${name}${decl.generic.map(s => s"<${s.str}>").getOrElse("")}\n${body.toString}"
}

case class Class(name: Ident, parent: Option[Ident], vars: List[FieldDef], methods:List[Method], generic: Option[Ident] = None) extends Item {
  /** Position of the field called `name` in `vars`, or -1 when there is no such field. */
  def fieldIndex(name: String): Int = {
    vars.indexWhere(v => v.name.str == name)
  }
}

case class FieldDef(name: Ident, ty: TyInfo, initValue: Option[Expr]) extends ASTNode
case class Param(name: Ident, ty: TyInfo) extends ASTNode
case class Params(params: List[Param]) extends ASTNode
case class MethodDecl(name: Ident, inputs: Params, outType: TyInfo, generic: Option[Ident] = None) extends ASTNode
--------------------------------------------------------------------------------
/src/main/scala/tools/Mangling.scala:
--------------------------------------------------------------------------------
package rclang
package tools

import rclang.ast.MethodDecl
import rclang.ty.*

// https://github.com/gchatelet/gcc_cpp_mangling_documentation

/** Encodes a type for use in a mangled name; unsupported types go to `Debugger.unImpl`. */
def manglingTypeMap(ty: Type): String = {
  ty match
    case Int32Type => "i"
    case NilType => "v"
    case PointerType(ty) => s"P${manglingTypeMap(ty)}"
    case StructType(name, fields) => IdentName(name).toString
    case _ => Debugger.unImpl(ty)

}

trait Name

/** Renders as "_Z" followed by the name and all parameter encodings. */
class ManglingFn(name: Name, params: List[Name]) {
  override def toString: String = s"_Z$name${params.mkString}"
}

/** Renders as the name prefixed by its length, e.g. "3foo". */
case class IdentName(name: String) extends Name {
  override def toString: String = s"${name.length}$name"
}

/** Renders as "N", the raw scope name, the nested names, then "E". */
case class ScopeName(name: String, subStr: List[Name]) extends Name {
  override def toString: String = s"N$name${subStr.mkString}E"
}

/**
 * Builds the mangled symbol for `fn`.
 *
 * @param fn    declaration supplying the method name and parameter types
 * @param outer enclosing scopes; the first entry is the module, the second the class
 * @throws IllegalArgumentException if `outer` has fewer than two entries
 */
def mangling(fn: MethodDecl, outer: List[String]): String = {
  require(outer.lengthCompare(2) >= 0, s"mangling needs a module and a class name, got: $outer")
  val module = outer.head
  val klass = outer(1)
  // NOTE(review): every type code is wrapped in IdentName, so it is length-prefixed
  // ("i" becomes "1i", and a struct code that already carries a length gets a second
  // one). Output kept as is; confirm whether this encoding is intended.
  val params = fn.inputs.params match
    case p if p.isEmpty => List(IdentName(manglingTypeMap(NilType)))
    case p => p.map(param => IdentName(manglingTypeMap(Infer.translate(param.ty).value)))
  ManglingFn(
    ScopeName(module, List(IdentName(klass), IdentName(fn.name.str))),
    params
  ).toString
}
--------------------------------------------------------------------------------
/src/main/scala/codegen/RegisterAllocation.scala:
--------------------------------------------------------------------------------
package rclang
package codegen

import pass.{AnalysisManager, Transform}

/** Trivial allocator: every non-forced virtual register is replaced by a stack frame slot. */
class StackRegisterAllocation extends Transform[MachineFunction] {
  /** Debug output, compiled out; `str` is by-name so disabled messages are never built. */
  def debug(str: => String): Unit = {
    if(false) {
      println(str)
    }
  }

  /**
   * Gives each distinct non-forced `VReg` operand a temporary frame item (created on
   * first sight, reused afterwards) and replaces the operand with a `FrameIndex` at
   * that item's offset.
   */
  def run(mf: MachineFunction, am: AnalysisManager[MachineFunction]) = {
    debug(s"generate for ${mf.name}")
    val frameInfo = mf.frameInfo
    var regMap = Map[VReg, StackItem]()
    // collected up front, so replacing operands below does not disturb the traversal
    val allVReg = mf.instructions
      .flatMap(_.operands)
      .collect { case v: VReg if !v.force => v }
    allVReg.foreach(reg => {
      debug(reg.toString)
      debug(reg.instParent.operands.toString)
      val item = regMap.get(reg) match {
        case Some(value) => debug("yes"); value
        case None => {
          // update FrameInfo
          debug("no")
          val tmpItem = frameInfo.addItem(TmpItem(reg.size))
          regMap = regMap.updated(reg, tmpItem)
          tmpItem
        }
      }
      debug(s"$reg -> $item")
      // replace operand
      val frameIndex = FrameIndex(item.offset, reg.size)
      reg.replaceFromParent(frameIndex)
      debug(frameIndex.instParent.operands.toString)
    })
  }
}
--------------------------------------------------------------------------------
/src/test/scala/ty/TypedTranslatorTest.scala:
--------------------------------------------------------------------------------
package rclang
package ty
import org.scalatest.funspec.AnyFunSpec
import ast.Expr.*
import ast.Stmt.*
import ast.*
import ast.ImplicitConversions.*
import ast.{ASTBuilder, RcModule}

import ty.Int32Type

class TypedTranslatorTest extends AnyFunSpec with ASTBuilder {

  /** First item of `module`, which must be a method. */
  def getFirstMethodFromModule(module: RcModule): Method = {
    module.items.head match
      case m: Method => m
      case _ => ???
  }
  describe("AddLocal") {
    it("succeed") {
      val m = RcModule(List(makeASTMethod("f", block = List(
        Local("a", TyInfo.Infer, Number(1)),
        Stmt.Expr(Identifier("a"))
      ))))
      val result = TypedTranslator(TyCtxt())(m)
      // the type is read back from the input tree `m`, not from `result`
      val ty = getFirstMethodFromModule(m).body.stmts.last.asInstanceOf[Stmt.Expr].expr.asInstanceOf[Identifier].ty
      assert(ty == Int32Type)
    }
  }

  describe("if") {
    it("succeed") {
      val m = mkFnInMod("f", block = List(
        Stmt.Expr(makeIf(
          Bool(true),
          makeExprBlock(Number(1)),
          makeExprBlock(Number(2)),
        ))
      ))
      val result = TypedTranslator(TyCtxt())(m)
      val ty = getFirstMethodFromModule(m).body.stmts.head.asInstanceOf[Stmt.Expr].expr.asInstanceOf[If].ty
      assert(ty == Int32Type)
    }
  }
}

--------------------------------------------------------------------------------
/src/test/scala/codegen/IRTranslatorTest.scala:
--------------------------------------------------------------------------------
package rclang
package codegen

import mir.*

import rclang.ty.NilType

/** Smoke tests: each case only checks that translation completes without throwing. */
class IRTranslatorTest extends RcTestBase {
  describe("binary") {
    it("ok") {
      val tt = IRTranslator()
      val bn = Binary("Add", Integer(1), Integer(2));
      val bb = BasicBlock("0", List(bn))
      tt.visitBB(bb)
    }
  }

  /** A function "f" whose single block holds one add of two constants. */
  def addFun = {
    val bn = Binary("Add", Integer(1), Integer(2));
    val entry = BasicBlock("0", List(bn))
    val f = Function("f", NilType, List(), entry, List(entry))
    f
  }

  describe("call") {
    it("ok") {
      val f = addFun
      val tt = IRTranslator()
      val call = Call(f, List());
      val bb = BasicBlock("0", List(call))
      tt.visitBB(bb)
    }
  }

  describe("branch") {
    ignore("not implement") {
      val exitBB = BasicBlock("b3", List())
      val elseBB = BasicBlock("b2", List(
        Branch(exitBB)
      ))
      val thenBB = BasicBlock("b1", List(
        Branch(exitBB)
      ))
      val entry = BasicBlock("b0", List(
        CondBranch(Bool(true), thenBB, elseBB)
      ))
      val tt = IRTranslator()
      tt.visitBB(entry)
    }
  }

  describe("function") {
    it("ok") {
      val f = addFun
      val t = IRTranslator()
      t.visit(f)
    }
  }
}

--------------------------------------------------------------------------------
/src/test/scala/analysis/LoopAnalysisTest.scala:
--------------------------------------------------------------------------------
package rclang
package analysis

import mir.*

import org.scalatest.BeforeAndAfter
import pass.{Analysis, AnalysisManager, getAnalysisResult}
import compiler.Driver.{getSrc, parse, simplify, typeProc}
import analysis.Analysis.given_LoopAnalysis
import mir.MIRTestUtil.*
import rclang.ty.NilType

import java.io.File

class LoopAnalysisTest extends RcTestBase {
  describe("normal loop") {
    it("simple") {
      val bbs = mkBBs(
        "entry" -> "header",
        "header" -> "body",
        "body" -> "exit",
        "body" -> "header")
      val fn = MakeBBsFunction(bbs)
      val loopInfo = getLoopInfo(fn)
      loopInfo.loops should be (Map(fn.getBB("header") -> Loop(List("header", "body").map(fn.getBB))))
    }

    it("withLatch") {
      val bbs = mkBBs(
        "entry" -> "header",
        "header" -> "body",
        "body" -> "latch",
        "latch" -> "header",
        "latch" -> "exit")
      val bbList = bbs.values.toList
      val fn = new Function("name", NilType, List(), bbs("entry"), bbList)
      val loopInfo = getLoopInfo(fn)
      loopInfo.loops should be(Map(fn.getBB("header") -> Loop(List("header", "body", "latch").map(fn.getBB))))
    }

    // Not written yet: reported as pending rather than passing with an empty body.
    it("continue") {
      pending
    }
  }

  describe("nested loop") {
    // Not written yet: reported as pending rather than passing with an empty body.
    it("ok") {
      pending
    }
  }

}
--------------------------------------------------------------------------------
/src/test/scala/parser/StmtParserTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package parser 3 | 4 | import org.scalatest.funspec.AnyFunSpec 5 | import lexer.Keyword.* 6 | import lexer.Punctuation.* 7 | import lexer.Literal.* 8 | import lexer.Delimiter.* 9 | import lexer.Ident.* 10 | import ast.* 11 | import ast.Expr.Number 12 | import ast.ImplicitConversions.* 13 | import lexer.Token 14 | 15 | class StmtParserTest extends BaseParserTest with ExprParser { 16 | def apply(tokens: Seq[Token]): Either[RcParserError, (Stmt, Input)] = { 17 | doParserImpl(tokens, statement) 18 | } 19 | 20 | describe("expr") { 21 | it("succeed") { 22 | expectSuccess(List(NUMBER(1), EOL), Stmt.Expr(Expr.Number(1))) 23 | } 24 | } 25 | 26 | def expectSuccess(token: Seq[Token], expect: Stmt): Unit = { 27 | apply(token) match { 28 | case Left(value) => assert(false, value.msg) 29 | case Right((ast, reader)) => assert(ast == expect); assert(reader.atEnd, reader) 30 | } 31 | } 32 | 33 | describe("local") { 34 | it("succeed") { 35 | expectSuccess(mkLocalStmt("a", NUMBER(1)), Stmt.Local("a", TyInfo.Infer, Expr.Number(1))) 36 | } 37 | } 38 | 39 | describe("assign") { 40 | it("succeed") { 41 | expectSuccess(mkAssStmt("a", NUMBER(1)), Stmt.Assign("a", Expr.Number(1))) 42 | } 43 | } 44 | 45 | describe("while") { 46 | it("succeed") { 47 | expectSuccess( 48 | makeWhile(TRUE, mkAssStmt("a", NUMBER(1))), 49 | Stmt.While(trueExpr, makeStmtBlock(Stmt.Assign("a", Expr.Number(1))))) 50 | } 51 | } 52 | } -------------------------------------------------------------------------------- /src/test/scala/mir/DFCalculatorTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package mir 3 | 4 | import ty.NilType 5 | import scala.collection.mutable.LinkedHashSet 6 | 7 | class DFCalculatorTest extends RcTestBase with MIRTestUtil { 8 | // override val bbs = mkBBs( 9 | // "entry" -> "0", 10 
/**
 * Loop detection for a single function, driven by the dominator tree.
 *
 * For every dominator-tree node the first successor that dominates the node
 * is treated as a loop header (the edge node -> successor is a back edge),
 * and the blocks returned by `loopBasicBlocks(header, latch)` become that
 * loop's body.
 *
 * Known limitations (unchanged from the original todo list): nested loops,
 * SCC-based detection and several back edges leaving the same block are not
 * handled; when two latches share a header the later one wins.
 */
class LoopAnalysis extends Analysis[Function] {
  override type ResultT = LoopInfo

  /**
   * @param irUnit function to analyse
   * @param AM     analysis manager used to obtain the dominator tree
   * @return the loops found, keyed by their header block
   */
  override def run(irUnit: Function, AM: AnalysisManager[Function]) = {
    val domTree = AM.getResult[DomTreeAnalysis](irUnit)
    // header -> blocks of the loop closed by the back edge into that header.
    // `visit` is callback based, so the accumulator has to be a local var.
    var loopRanges = Map[BasicBlock, List[BasicBlock]]()
    domTree.visit(node => {
      node.basicBlock.successors
        // back edge: the successor dominates the current node
        .find(succ => domTree(succ) dom node)
        .foreach(header => {
          // Every block on a path from the header to the latch belongs to the
          // loop body: the header dominates the latch, so all paths starting
          // at the header lead to it.
          loopRanges = loopRanges + (header -> loopBasicBlocks(header, node.basicBlock))
        })
    })
    LoopInfo(loopRanges.map((header, bbs) => header -> Loop(bbs)))
  }
}
/**
 * Redirects every recorded use of this value to `v`.
 *
 * Each [[Use]] in `users` is shared with the operand list of the using
 * instruction, so rewriting `use.value` updates those operands in place.
 * The use records are then handed over to `v`, so that a later replacement
 * of `v` still reaches them, and this value's own list is emptied.
 *
 * Replacing a value with itself is a no-op.
 *
 * @param v the value that takes over all uses of this one
 */
def replaceAllUseWith(v: Value) = {
  // Identity check on purpose: a replacement that merely compares equal
  // (e.g. a structurally identical value) must still take over the uses.
  if (v ne this) {
    users.foreach(use => {
      assert(use.value == this)
      use.value = v
    })
    // Transfer ownership of the use records to the replacement.
    v.users = users ::: v.users
    users = List()
  }
}
/**
 * Appends `block` to the builder's block list, attaches it to the current
 * function and makes it the insertion point for following instructions.
 *
 * @param block block to append; a freshly created one when omitted
 * @return the appended block
 */
def insertBasicBlock(block: BasicBlock = createBB()): BasicBlock = {
  block.parent = currentFn
  basicBlocks = basicBlocks :+ block
  // the block just appended is by construction the last one
  currentBasicBlock = block
  block
}
/**
 * Tests for the scoping behaviour of [[TyCtxt]]: `enter` opens a fresh local
 * scope for the duration of its by-name argument and restores the previous
 * one afterwards.
 */
class TyCtxtTest extends AnyFunSpec with BeforeAndAfter {
  var tyCtxt: TyCtxt = _
  before {
    tyCtxt = TyCtxt()
    tyCtxt.global = Map(Ident("a") -> Int32Type, Ident("b") -> FloatType)
  }

  describe("enter") {
    it("returnCallResult") {
      val ty = Nil
      // the lambda is applied here, so `enter` receives its result
      val t = tyCtxt.enter((() => {
        ty
      })())
      assert(t == ty)
    }

    it("block") {
      val ty = StringType
      val t = tyCtxt.enter({
        ty
      })
      assert(t == ty)
    }
  }

  describe("addLocal") {
    /** Body run inside a nested scope: the scope starts empty and shadows `id`. */
    def testEnter(id: Ident): Type = {
      assert(tyCtxt.local.isEmpty)
      val innerTy = StringType
      tyCtxt.addLocal(id, innerTy)
      assert(tyCtxt.lookup(id).contains(innerTy))
      innerTy
    }

    // NOTE: `enter` takes a by-name block. Passing `() => { ... }` would hand
    // it an un-applied function value, so the assertions inside would never
    // execute; the bodies are therefore written as plain blocks.
    it("succeed") {
      tyCtxt.enter {
        val id = Ident("n")
        val ty = NilType
        tyCtxt.addLocal(id, ty)
        assert(tyCtxt.lookup(id).contains(ty))
      }
      // the local binding must be gone once the scope is left
      assert(tyCtxt.lookup(Ident("n")).isEmpty)
    }

    it("sameWithGlobal") {
      tyCtxt.enter {
        // a local named like a global shadows the global
        val id = Ident("a")
        val ty = NilType
        tyCtxt.addLocal(id, ty)
        assert(tyCtxt.lookup(id).contains(ty))
      }
    }

    it("nested") {
      tyCtxt.enter {
        val id = Ident("a")
        val ty = NilType
        tyCtxt.addLocal(id, ty)
        // the inner scope yields its own binding ...
        assert(tyCtxt.enter(testEnter(id)) == StringType)
        // ... and the outer binding is visible again afterwards
        assert(tyCtxt.lookup(id).contains(ty))
      }
    }
  }
}
object SymScanner extends ASTVisitor {
  // Scanner state; all three are (re)initialised by `init` before each scan.
  // Class entry that methods visited right now are added to.
  var currentClass: State[ClassEntry] = null
  // Local table that `Stmt.Local` statements visited right now are added to.
  var currentMethod: State[LocalTable] = null
  // Class name -> entry, handed to the resulting GlobalTable.
  var classTable: Map[String, ClassEntry] = null

  /**
   * Resets the scanner: the current class becomes an empty `Def.Kernel`
   * class, which is also the only entry of the fresh class table.
   */
  def init = {
    this.currentClass = new ClassEntry(Class(Def.Kernel, None, List(), List()))
    this.currentMethod = new LocalTable(null)
    this.classTable = Map(Def.Kernel -> currentClass.value)
  }

  /**
   * Scans a whole module.
   *
   * @param ast module to scan
   * @return a GlobalTable built from the collected class table and `ast`
   */
  def apply(ast: RcModule): GlobalTable = {
    init
    visit(ast)
    new GlobalTable(classTable, ast)
  }

  /**
   * Visits a class:
   *  1. creates a new ClassEntry for it,
   *  2. visits the sub-nodes while that entry is the current class,
   *  3. stores the entry in the class table under the class name.
   *
   * NOTE(review): relies on `State.by` (defined elsewhere) running the
   * callback with the new entry installed and returning that entry - confirm.
   *
   * @param klass class node to scan
   */
  override def visit(klass: Class): R = {
    val result = currentClass.by(new ClassEntry(klass)){ () =>
      super.visit(klass)
    }
    classTable(klass.name) = result
  }

  /**
   * Visits a method:
   *  1. creates a new LocalTable for it,
   *  2. visits the sub-nodes while that table is the current method,
   *  3. adds the table to the current class entry.
   *
   * @param method method node to scan
   */
  override def visit(method: Method): R = {
    val result = currentMethod.by(new LocalTable(method)) { () =>
      super.visit(method)
    }
    currentClass.value.addMethod(result)
  }

  /**
   * Records local declarations in the current method's table; every other
   * statement kind is ignored (and not descended into).
   */
  override def visit(stmt: Stmt): R = {
    stmt match {
      case local: Stmt.Local => currentMethod.value += local
      case _ =>
    }
  }
}
localB)) 44 | assert(t.locals("a") == LocalEntry(1, localA)) 45 | } 46 | 47 | it("MultiMethod") { 48 | val module = RcModule(List(f2, f1)) 49 | val t = SymScanner(module).kernel.methods 50 | val f1T = t("f1") 51 | assert(f1T.locals("b") == LocalEntry(0, localB)) 52 | assert(f1T.locals("a") == LocalEntry(1, localA)) 53 | val f2T = t("f2") 54 | assert(f2T.locals("e") == LocalEntry(0, localE)) 55 | assert(f2T.locals("f") == LocalEntry(1, localF)) 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/test/scala/analysis/ModuleValidateTest.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package analysis 3 | 4 | import org.scalatest.funspec.AnyFunSpec 5 | import ast.ImplicitConversions.* 6 | import ast.* 7 | import analysis.ModuleValidate 8 | 9 | class ModuleValidateTest extends AnyFunSpec with ASTBuilder with ModuleValidate { 10 | describe("ASTValidateTest") { 11 | describe("fieldDefValid") { 12 | describe("SingleDef") { 13 | it("true") { 14 | val f = FieldDef("a", TyInfo.Infer, None) 15 | assert(fieldDefValid(f) == List(ValidateError(f, "Field without initValue need spec Type"))) 16 | } 17 | 18 | it("false") { 19 | val f1 = FieldDef("a", TyInfo.Spec("Int"), None) 20 | assert(fieldDefValid(f1) == List()) 21 | val f2 = FieldDef("a", TyInfo.Infer, Some(Expr.Number(1))) 22 | assert(fieldDefValid(f2) == List()) 23 | } 24 | } 25 | } 26 | 27 | it("should methodDeclValid") { 28 | 29 | } 30 | 31 | describe("methodsDeclValid") { 32 | it("true") { 33 | val fs = List( 34 | makeASTMethod("f1"), 35 | makeASTMethod("f2"), 36 | makeASTMethod("f1") 37 | ).map(_.decl) 38 | assert(methodsDeclValid(fs) == List(ValidateError(fs(2).name, "Method Ident(f1) Dup"))) 39 | } 40 | 41 | it("false") { 42 | val fs = List( 43 | makeASTMethod("f1"), 44 | makeASTMethod("f2"), 45 | makeASTMethod("f3") 46 | ).map(_.decl) 47 | assert(methodsDeclValid(fs) == List()) 48 | } 49 | } 50 | 51 
/**
 * Builds the method-call expression `name.field(args...)`.
 *
 * @param name  identifier of the receiver
 * @param field name of the method being called
 * @param args  call arguments; empty by default
 */
def mkASTMemCall(name: String, field: String, args: List[Expr] = List()) =
  // forward the caller's arguments instead of always passing an empty list
  Expr.MethodCall(Expr.Identifier(name), field, args)
/**
 * Turns a lambda's parameter list and body into a method named
 * `lambda_<random uuid>` with an inferred return type. A body that is not
 * already a block is wrapped in a single-statement block.
 */
def lambdaToMethod(args: Params, body: Expr): Method = {
  val decl = MethodDecl(s"lambda_${uuid}", args, TyInfo.Infer)
  body match
    case block: Expr.Block => Method(decl, block)
    case other => Method(decl, Expr.Block(List(Stmt.Expr(other))))
}
/**
 * Expression nodes of the AST. Every expression carries a type slot through
 * `Typed`.
 */
enum Expr extends ASTNode with Typed :
  // integer literal
  case Number(v: Int)
  // reference to a name
  case Identifier(ident: Ident)
  // boolean literal
  case Bool(b: Boolean)
  // binary operation `lhs op rhs`
  case Binary(op: BinaryOp, lhs: Expr, rhs: Expr)
  // string literal
  case Str(str: String)
  // false_branch holds either a nested elsif or the else part, if any
  case If(cond: Expr, true_branch: Block, false_branch: Option[Expr])
  // anonymous function: parameters and body block
  case Lambda(args: Params, block: Block)
  // call of `target` with `args`, optionally with a generic argument
  case Call(target: Ident, args: List[Expr], generic: Option[Ident] = None)
  // call of method `target` on the value of `obj`
  case MethodCall(obj: Expr, target: Ident, args: List[Expr])
  // sequence of statements
  case Block(stmts: List[Stmt])
  case Return(expr: ast.Expr)
  // member access `expr.ident`
  case Field(expr: Expr, ident: Ident)
  case Self
  // symbol, optionally with a generic argument
  case Symbol(ident: Ident, generic: Option[Ident] = None)
  // indexing `expr[i]`
  case Index(expr: Expr, i: Expr)
  // array of `len` elements with the given initial values
  case Array(len: Int, initValues: List[Expr])

  // Blocks print their statements one per line inside braces; every other
  // node prints the default rendering followed by `:` and its type.
  override def toString: String = this match
    case Expr.Block(stmts) => s"{\n${stmts.mkString("\n")}\n}"
    case _ => s"${super.toString}:${ty}"
/**
 * Parses a class definition: a `class Name [template] [< Parent]` header
 * line, any number of members (items, fields, empty lines) and the closing
 * `end`. Fields and methods are picked out of the member list in source
 * order; every other member is dropped.
 */
def classDefine: Parser[Item] = positioned {
  oneline(CLASS ~> sym ~ template.? ~ (OPERATOR("<") ~> sym).?) ~ log(item | field | noneItem)("class member").* <~ log(END)("class end") ^^ {
    case klass ~ temp ~ parent ~ defines =>
      val fields = defines.collect { case f: FieldDef => f }
      val methods = defines.collect { case m: Method => m }
      Class(klass, parent, fields, methods, temp).asInstanceOf[Item]
  }
}
/**
 * A function of the mid-level IR.
 *
 * @param fnName   name of the function; also stored in the inherited `name`
 * @param retType  return type
 * @param argument formal arguments
 * @param entry    entry block, or null when there is none yet
 * @param bbs      basic blocks of the function
 */
case class Function(private val fnName: String,
                    var retType: Type,
                    var argument: List[Argument],
                    var entry: BasicBlock,
                    var bbs: List[BasicBlock] = List()) extends GlobalValue {
  name = fnName
  // todo: bad design
  if (entry != null) {
    entry.parent = this
  }

  var strTable = List[Str]()

  /** All instructions of the function, block by block. */
  def instructions = bbs.flatMap(_.stmts)

  /** Function type built from the return type and the argument types. */
  def fnType = FnType(retType, argument.map(_.ty))

  /**
   * Looks up a basic block by name.
   *
   * @throws NoSuchElementException if no block of that name exists
   */
  def getBB(name: String): BasicBlock =
    bbs.find(_.name == name).getOrElse(
      // same exception type as a bare `.get`, but with a usable message
      throw new NoSuchElementException(s"BasicBlock '$name' not found in function '$fnName'"))

  override def toString: String = {
    val sign = s"$fnName(${argument.mkString(",")})\n"
    val body = s"{\n${bbs.map(bbToStr).mkString("\n")}\n}"
    sign + body
  }
}
// Test DSL on block names; the tree and block map come from the givens in scope.
extension (n: String) {
  /**
   * Checks that the names in `tree(n).children` are exactly `children`
   * plus `n` itself and "entry".
   */
  def isDom(children: String*)(using tree: DomTree)(using bbs: BBsType): Unit = {
    val expect = (children.toList ::: List(n, "entry")).toSet
    tree(n).children.map(_.name).toSet should be(expect)
  }

  /**
   * Evaluates `tree(child) idom tree(n)`.
   * NOTE(review): the result is not passed to an assertion here; whether
   * `idom` itself asserts is not visible from this file - confirm.
   */
  def isIDom(child: String)(using tree: DomTree)(using bbs: BBsType): Unit = {
    tree(child) idom tree(n)
  }

  /** Same as `isDom()` with no extra names: only `n` and "entry" are expected. */
  def noOtherDom(using tree: DomTree)(using bbs: BBsType): Unit = {
    isDom()
  }
}
/**
 * Size of a value of the given type in bytes.
 *
 * Pointers, strings and function values occupy one machine word (8 bytes).
 * Arrays are the element size times the element count.
 *
 * @throws NotImplementedError for types without a defined size yet
 *                             (`InferType`, `ErrType`, `StructType`)
 */
def sizeof(ty: Type): Int = {
  // PtrLength == WordLength
  val ptrLength = 8
  ty match
    case BooleanType => 1
    case StringType => ptrLength
    case Int32Type => 4
    // 64-bit integer; previously fell through to the catch-all and threw
    case Int64Type => 8
    case FloatType => 4
    case NilType => 4 // todo: fix this
    case FnType(ret, params) => ptrLength
    case InferType => ??? // enum
    case ErrType(msg) => ???
    case StructType(name, fields) => ???
    case PointerType(ty) => ptrLength
    case ArrayType(ty, size) => sizeof(ty) * size
    case _ => ???
}
/**
 * Resolves the type bound to `ident`.
 *
 * Resolution order:
 *  - `malloc` / `this`: a pointer to the type of the current class; `None`
 *    (after printing a diagnostic) when that class is not registered,
 *  - `print`: always `NilType`,
 *  - otherwise the current scope, then the outer scopes from innermost to
 *    outermost, then the globals,
 *  - finally, if the current class is registered, its fields and then the
 *    functions reachable from the current name space.
 *
 * @param ident name to resolve
 * @return the type, or `None` when nothing matches
 */
def lookup(ident: Ident): Option[Type] = ident.str match {
  case "malloc" | "this" =>
    val classTy = getClassTy(Ident(fullName.klass))
    // An unknown class used to end in `None.get`; report it and let the
    // caller handle the missing type instead.
    if (classTy.isEmpty) println(s"${fullName.klass} not found")
    classTy.map(PointerType(_))
  case "print" =>
    Some(NilType)
  case _ =>
    // 1. lexical scopes and globals
    val scoped = local.get(ident) orElse outer.find(_.contains(ident)).map(_(ident)) orElse global.get(ident)
    scoped orElse {
      // 2. members of the current class: fields first, then functions
      globalTable.classTable.get(fullName.klass).flatMap { klass =>
        klass.lookupFieldTy(ident) orElse {
          Some(NestSpace(globalTable, fullName).lookupFn(ident).infer)
        }
      }
    }
}
27 | def toASM: String = { 28 | asm.map(_ match 29 | case ASMInstr(instr) => s"$indent$instr" 30 | case ASMLabel(label) => label 31 | case _ => Debugger.unImpl).mkString("\n") 32 | } 33 | } 34 | 35 | trait Section { 36 | def toASM: String = s".section $decl" + "\n" + getASMString 37 | 38 | protected def decl: String 39 | 40 | protected def getASMString: String 41 | } 42 | 43 | case class TextSection(mfs: List[MFText]) extends Section { 44 | private def fnDecls = mfs.map(mfText => s"$indent.globl ${mfText.mf.name}\n$indent.type ${mfText.mf.name}, @function\n").mkString("\n") 45 | override def getASMString: String = fnDecls + mfs.map(_.toASM).mkString("\n") 46 | 47 | override def decl: String = ".text" 48 | } 49 | 50 | case class StringSection(strTable: Map[String, Label]) extends Section { 51 | override def getASMString: String = strTable.map((str, label) => s"${label.name}:\n$indent.string \"$str\"").mkString("\n") 52 | 53 | override def decl: String = ".rodata" 54 | } 55 | 56 | trait ASMEmiter { 57 | def emitMF(fm: MachineFunction): MFText 58 | 59 | def emitMBB(mbb: MachineBasicBlock): List[ASMText] 60 | 61 | def emitInstr(instr: MachineInstruction): List[ASMText] 62 | } 63 | 64 | def buildASM(mfs: List[MachineFunction], strTable: Map[String, Label]) = { 65 | val stringSection = StringSection(strTable) 66 | val mfList = mfs.map(GNUASMEmiter().emitMF) 67 | val textSection = TextSection(mfList) 68 | ASMFile(List(stringSection, textSection)) 69 | } 70 | 71 | def generateASM(mfs: List[MachineFunction], strTable: Map[String, Label], asmPath: String): Unit = { 72 | val asmFile = buildASM(mfs, strTable) 73 | asmFile.write(asmPath) 74 | } 75 | 76 | trait ASMText { 77 | def str: String 78 | } 79 | 80 | case class ASMLabel(label: String) extends ASMText { 81 | override def str: String = label 82 | } 83 | 84 | case class ASMInstr(instr: String) extends ASMText { 85 | override def str: String = instr 86 | } 87 | 88 | given Conversion[String, ASMInstr] = ASMInstr(_) 
--------------------------------------------------------------------------------
/src/main/scala/interpreter/evaluator.scala:
--------------------------------------------------------------------------------
package rclang
package interpreter

import ast.*
import ast.BinaryOp.*
import ast.Expr.*

import cats.implicits.*
import analysis.SymScanner
import tools.GlobalTable
import ty.{TyCtxt, TypedTranslator}
import ast.ImplicitConversions.strToId

case class Structure(var name: String, var fields: List[String], var values: List[Expr]) {

}

/**
 * Tree-walking evaluator (work in progress).
 *
 * @param fenv functions callable by name
 */
case class Evaluator(var fenv:Map[Ident, Method] = Map()) {
  /** Variable bindings; values are stored as unevaluated expressions. */
  var env = Map[Ident, Expr]()
  var curObj: Structure = null

  /** Calls the function registered under `target`; throws if it is not in `fenv`. */
  def run_call(target: Ident, args: List[Expr]): Any = {
    run_call_impl(fenv(target), args)
  }

  /**
   * Binds `args` to the parameters of `method`, evaluates its body and then
   * restores the caller's bindings, so parameters do not leak out of the call.
   */
  def run_call_impl(method: Method, args: List[Expr]) : Any = {
    val callerEnv = env
    val bindings = method.decl.inputs.params.zip(args).map { case (p, a) => (p.name, a) }.toMap
    env = env ++ bindings
    try run_expr(method.body)
    finally env = callerEnv
  }

  /**
   * Type-checks `mod`, registers its top-level methods and runs `main`.
   * Functions already present in `fenv` take precedence over the module's.
   */
  def run_module(mod: RcModule): Any = {
    val table = SymScanner(mod)
    val tyCtxt = TyCtxt()
    tyCtxt.setGlobalTable(table)
    val typedModule = TypedTranslator(tyCtxt)(mod)
    val declared = typedModule.items.collect { case m: Method => m.decl.name -> m }.toMap
    fenv = declared ++ fenv
    run_call("main", List())
  }

  /** Not implemented yet: meant to evaluate `expr` and return the result as a `T`. */
  def run_expr_t[T](expr: Expr): T = {
    ???
    // run_expr(expr).asInstanceOf[T]
  }

  /** Evaluates an expression; unsupported forms raise `NotImplementedError`. */
  def run_expr(expr: Expr): Any = {
    expr match
      case Number(v) => v
      case Identifier(ident) => run_expr(env(ident))
      case Bool(b) => b
      case Binary(op, lhs, rhs) => {
        // binary operators are still disabled; everything below the ??? is unreachable
        ???
        val l = run_expr_t[Int](lhs)
        val r = run_expr_t[Int](rhs)
        op match
          case Add => l + r
          case Sub => l - r
          case Mul => l * r
          case Div => l / r
          case EQ => l == r
          case LT => l < r
          case GT => l > r
      }
      case Str(str) => str
      case If(cond, true_branch, false_branch) =>
        // the condition evaluates to a plain Boolean, not to a Bool node
        run_expr(cond) match {
          case true => run_expr(true_branch)
          case false => false_branch.map(run_expr).getOrElse(())
          case other => throw new RuntimeException(s"if condition is not a boolean: $other")
        }
      case Lambda(args, block) => ???
      case Call(target, args, _) => run_call(target, args)
      case MethodCall(obj, target, args) => ???
      // an empty block evaluates to unit
      case Block(stmts) => stmts.map(run_stmt).lastOption.getOrElse(())
      case Return(expr) => run_expr(expr)
      case Field(expr, ident) => ???
      case Self => ???
      case Symbol(ident, _) => ???
      case Index(expr, i) => ???
  }

  /** Executes a statement; `Local` and `Assign` store the unevaluated expression in `env`. */
  def run_stmt(stmt: Stmt) = {
    stmt match
      case Stmt.Local(name, tyInfo, value) => env = env.updated(name, value)
      case Stmt.Expr(expr) => run_expr(expr)
      case Stmt.While(cond, body) => while(run_expr_t[Boolean](cond)) run_expr(body)
      case Stmt.Assign(name, value) => env = env.updated(name, value)
  }
}
--------------------------------------------------------------------------------
/src/main/scala/mir/CFG.scala:
--------------------------------------------------------------------------------
package rclang
package mir
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.LinkedHashSet

/** Creates a block whose only instruction is an empty multi-successor terminator. */
def mkBB(name: String): BasicBlock = {
  BasicBlock(name, List(MultiSuccessorsInst()))
}

type BBsType = LinkedHashMap[String, BasicBlock]

/**
 * Builds a CFG from `(from, to)` edges.
 *
 * The map starts with `entry`, continues with all other blocks sorted by name
 * and ends with `exit`. `entry` and `exit` always exist.
 */
def mkBBs(connections: (String, String)*): BBsType = {
  val innerNames = connections
    .flatMap(edge => List(edge._1, edge._2))
    .distinct
    .filter(name => name != "entry" && name != "exit")
    .sorted
  val bbMap: BBsType = LinkedHashMap("entry" -> mkBB("entry"))
  innerNames.foreach(name => bbMap(name) = mkBB(name))
  bbMap("exit") = mkBB("exit")
  connections.foreach((begin, end) => {
    bbMap(begin).terminator.asInstanceOf[MultiSuccessorsInst].add(bbMap(end))
  })
  bbMap
}

/**
 * Like [[mkBBs]], but the inner blocks are named after the numbers in `order`
 * and keep that order instead of being sorted.
 */
def mkBBsByOrder(order: List[Int], connections: (String, String)*): BBsType = {
  val innerNames = order
    .map(_.toString)
    .distinct
    .filter(name => name != "entry" && name != "exit")
  val bbMap: BBsType = LinkedHashMap("entry" -> mkBB("entry"))
  innerNames.foreach(name => bbMap(name) = mkBB(name))
  bbMap("exit") = mkBB("exit")
  connections.foreach((begin, end) => {
    bbMap(begin).terminator.asInstanceOf[MultiSuccessorsInst].add(bbMap(end))
  })
  bbMap
}

/** True when `b` is `a` itself or reachable from `a` through successor edges; terminates on cyclic graphs. */
def canReach(a: BasicBlock, b: BasicBlock): Boolean = {
  val visited = LinkedHashSet[BasicBlock]()
  // `add` returns false for a block that was seen before, which stops the walk on cycles
  def search(cur: BasicBlock): Boolean =
    cur == b || (visited.add(cur) && cur.successors.exists(search))
  search(a)
}

/** Blocks of `bbs` whose terminator lists `bb` as a successor, in the order of `bbs`. */
def predecessors(bb: BasicBlock, bbs: List[BasicBlock]): LinkedHashSet[BasicBlock] = {
  LinkedHashSet.from(bbs.filter(_.terminator.successors.contains(bb)))
}

/** Predecessor set of every block in `bbs`. */
def predecessorsMap(bbs: List[BasicBlock]): Map[BasicBlock, LinkedHashSet[BasicBlock]] = {
  bbs.map(bb => bb -> predecessors(bb, bbs)).toMap
}

/** Blocks reachable from `b` in depth-first pre-order, each listed once. */
def dfsBasicBlocks(b: BasicBlock): List[BasicBlock] = {
  // insertion-ordered set: serves as the visited set and as the result
  val visited = LinkedHashSet[BasicBlock]()
  def dfsImpl(bb: BasicBlock): Unit =
    if (visited.add(bb)) bb.successors.foreach(dfsImpl)
  dfsImpl(b)
  visited.toList
}

/**
 * Blocks reachable from `begin` in depth-first pre-order, where the walk does
 * not continue past `end` (`end` itself is included when it is reached).
 */
def loopBasicBlocks(begin: BasicBlock, end: BasicBlock): List[BasicBlock] = {
  val visited = LinkedHashSet[BasicBlock]()
  def dfsImpl(bb: BasicBlock): Unit =
    if (visited.add(bb) && bb != end) bb.successors.foreach(dfsImpl)
  dfsImpl(begin)
  visited.toList
}
--------------------------------------------------------------------------------
/src/main/scala/transform/CallInliner.scala:
--------------------------------------------------------------------------------
package rclang
package transform

import mir.*
import pass.{AnalysisManager, Transform}

// before
// bb1.
// f() = a + b * c
// val n = 1
// val d = f()
// d + n

// after
// bb1.
// val n = 1
// val tmp = a + b * c
// branch

// bb2.
// val d = tmp
// d + n

/**
 * Inlines calls to small single-block functions.
 *
 * NOTE(review): `getInlineInsts` rewrites the operands of the callee's own
 * instructions in place, and the returned `Alloc` is neither inserted into a
 * block nor used to replace the call's result — confirm before relying on
 * this pass with more than one call site per callee.
 */
class CallInliner extends Transform[Function] {
  // 1. find call
  // 2. args copy
  // 3. return copy
  // 4. call bb split to two bb
  override def run(fn: Function, AM: AnalysisManager[Function]): Unit = {
    // todo: only process for call, invoke maybe cause exception and basicblock should rebuild
    fn.instructions
      .collect { case a: Call => a }
      // todo: once foreach has replaced a call, iteration cannot simply continue from the
      // current position; only one call per block is supported, because the rewrite
      // destroys the block
      .foreach(inst => {
        val f = inst.func
        if(canInline(f)) {
          // make new bb
          val newAfterBB: BasicBlock = new BasicBlock("afterBB")
          val (insts, alloc) = getInlineInsts(f, inst, newAfterBB)
          val instIndex = inst.parent.stmts.indexOf(inst)
          // the instructions after the call move to newAfterBB
          val after = inst.parent.stmts.slice(instIndex + 1, inst.parent.stmts.length)
          newAfterBB.stmts = after
          // the call itself is replaced by the inlined instructions
          inst.parent.stmts = inst.parent.stmts.slice(0, instIndex) ::: insts
          // insert bb
          fn.bbs = fn.bbs:+newAfterBB
        }
      })
  }

  /** Replaces every `Argument` operand by the value `pair` maps it to; mutates and returns the given instructions. */
  def replaceArgument(insts: List[Instruction], pair: Map[Argument, Value]): List[Instruction] = {
    insts.map(inst => {
      val newOperands = inst.getOperands.map {
        case param: Argument => pair(param)
        case other => other
      }
      inst.setOperands(newOperands)
      inst
    })
  }

  /**
   * Produces the instructions that replace `inst`: the callee's instructions with
   * the call arguments substituted and every `Return` turned into a store to a
   * fresh `retTmp` slot followed by a branch to `newAfterBB`.
   *
   * @return the instruction list and the `retTmp` slot
   * @throws RuntimeException when the callee does not end in an explicit `Return`
   */
  def getInlineInsts(f: Function, inst: Call, newAfterBB: BasicBlock): (List[Instruction], Alloc) = {
    val originInstList = f.instructions

    // an implicit return (last instruction is not a Return) must be made explicit during lowering
    originInstList.last match
      case r: Return =>
      case _ => throw new RuntimeException("inlined function must end with an explicit return")

    val argMap = f.argument.zip(inst.args).toMap
    val instList = replaceArgument(originInstList, argMap)

    // slot that keeps the return value
    val alloc = new Alloc("retTmp", f.retType)
    // simply replace every return
    val instWithoutReturn = instList.flatMap {
      case r: Return => List(new Store(r.value, alloc), Branch(newAfterBB))
      case v => List(v)
    }

    (instWithoutReturn, alloc)
  }

  /** Very simple strategy: inline functions with at most one block and fewer than 10 instructions. */
  def canInline(fn: Function): Boolean = {
    fn.bbs.length <= 1 && fn.instructions.length < 10
  }
}
--------------------------------------------------------------------------------
/src/main/scala/codegen/MachineFrameInfo.scala:
--------------------------------------------------------------------------------
package rclang
package codegen

import mir.{Alloc, Argument}

/** A slot in the stack frame: `len` bytes ending at `offset`. */
class StackItem(var len: Int = 0, var offset: Int = 0) {
  def toFrameIndex = {
    FrameIndex(offset, len)
  }
}

case class LocalItem(private val _len: Int, alloc: Alloc) extends StackItem(_len) {
  override def toString: String = s"local:${alloc.id}"
}

case class ArgItem(private val _len: Int, arg: Argument) extends StackItem(_len) {
  override def toString: String = s"arg:${arg.name}"
}

case class TmpItem(private val _len: Int) extends StackItem(_len) {
  override def toString: String = "TmpItem"
}

/** Layout of one function's stack frame. */
case class MachineFrameInfo() {
  var mf: MachineFunction = null

  var items = List[StackItem]()

  def isEmpty = items.isEmpty

  /** Most recently added item; throws on an empty frame. */
  def top = items.last

  def size = items.length

  // offset of the last item; its data lies above that position, so len is not added.
  // An empty frame has length 0.
  def length = items.lastOption.map(_.offset).getOrElse(0)

  /** Rounds `v` up to the next multiple of `base`. */
  def align(v: Int, base: Int) = (v / base + (if v % base == 0 then 0 else 1)) * base

  /** Frame length rounded up to 16 bytes. */
  def alignLength = align(length, 16)

  // invalid
  // -- 0 <- c <== rsp
  // -- 4 <- b
  // -- 8 <- a <== low address
  // rsp - 8 = "abc"
  /** Appends `item` after the current last item and assigns its offset; zero-length items take 8 bytes. */
  def addItem(item: StackItem): StackItem = {
    // todo: fix for 0 size
    val slotLen = if item.len == 0 then 8 else item.len
    item.offset = length + slotLen
    items = items :+ item
    item
  }

  /** Throws when three consecutive items are of three different kinds. */
  def checkValid = {
    items.sliding(3).foreach(window => {
      val kinds = window.map(_.getClass).toSet
      if (kinds.size == 3) {
        throw new RuntimeException("invalid machine frame info")
      }
    })
  }

  /** All local-variable slots, in frame order. */
  def locals = items.collect { case item: LocalItem => item }

  /** All temporary slots, in frame order. */
  def tmps = items.collect { case item: TmpItem => item }

  /** All argument slots, in frame order. */
  def args = items.collect { case item: ArgItem => item }

  /** Text picture of the frame, one item per row followed by its offset. */
  override def toString: String = {
    val line = "| ---------- |"
    val owner = if mf == null then "<no function>" else mf.name
    s"$owner\n0 $line rsp\n" +
    items.map(item => {
      val offsetLine = s"${item.offset.toString.padTo(4, ' ')}$line\n"
      val itemLine = s" | ${item.toString.padTo(10, ' ')} |\n"
      itemLine + offsetLine
    }).mkString
  }
}
--------------------------------------------------------------------------------
/src/main/scala/ty/Translator.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ty 3 | 4 | import ast.{ASTNode, ASTVisitor, Expr, FieldDef, Ident, Item, MethodDecl, Modules, Param, RcModule, Stmt, TyInfo} 5 | import ast.Expr.* 6 | import ast.Stmt.* 7 | import ast.* 8 | 9 | import scala.collection.immutable.Map 10 | 11 | case object TypedTranslator { 12 | var tyCtxt: TyCtxt = TyCtxt() 13 | def apply(tyCtxt: TyCtxt)(module: RcModule): RcModule = { 14 | this.tyCtxt = tyCtxt 15 | // update local table in TypedTranslator will cause Infer ctxt update 16 | // because of pass a typCtxt by Ref 17 | Infer.enter(tyCtxt, RcModuleTrans(module)) 18 | } 19 | 20 | // fn type 21 | def RcModuleTrans(module: RcModule): RcModule = { 22 | tyCtxt.fullName.module = module.name 23 | val items = module.items.map(itemTrans) 24 | RcModule(items) 25 | } 26 | 27 | def itemTrans(item: Item): Item = { 28 | item match 29 | case m: Method => methodTrans(m) 30 | case c: Class => tyCtxt.enter(classTrans(c)) 31 | } 32 | 33 | def classTrans(klass: Class): Class = { 34 | var oldName = tyCtxt.fullName 35 | tyCtxt.fullName.klass = klass.name.str 36 | val methods = klass.methods.map(methodTrans) 37 | tyCtxt.fullName = oldName 38 | klass.copy(methods = methods) 39 | } 40 | 41 | def exprTrans(expr: Expr): Expr = 42 | // println(expr) 43 | (expr match 44 | case Binary(op, lhs, rhs) => Binary(op, lhs.withInfer, rhs.withInfer) 45 | case If(cond, true_branch, false_branch) => { 46 | val false_br = false_branch match 47 | case Some(fBr) => Some(fBr.withInfer) 48 | case None => None 49 | If(cond.withInfer, 50 | true_branch.withInfer.asInstanceOf[Block], 51 | false_br) 52 | } 53 | case Call(target, args, generic) => Call(target, args.map(_.withInfer), generic) 54 | case Return(expr) => Return(expr.withInfer) 55 | case Lambda(args, block) => ??? 56 | case MethodCall(obj, target, args) => ??? 
57 | case Block(stmts) => tyCtxt.enter(Block(stmts.map(stmtTrans))) 58 | case Field(expr, ident) => ??? 59 | case Self => ??? 60 | case Symbol(ident, _) => ??? 61 | case Index(expr, i) => ??? 62 | case _ => expr).withInfer.setPos(expr.pos) 63 | 64 | def stmtTrans(stmt: Stmt): Stmt = 65 | (stmt match 66 | case Local(name, ty, value) => { 67 | val localTy = ty match 68 | case TyInfo.Spec(_) => Infer.translate(ty) 69 | case _ => Infer(value) 70 | tyCtxt.addLocal(name, localTy) 71 | Local(name, ty, value.withInfer).withTy(localTy) 72 | } 73 | case Stmt.Expr(expr) => Stmt.Expr(expr.withInfer) 74 | case While(cond, body) => While(cond.withInfer, body.withInfer) 75 | case Assign(name, value) => Assign(name, value.withInfer) 76 | case For(init, cond, incr, body) => For(init.withInfer, cond.withInfer, incr.withInfer, body.withInfer)) 77 | .withInfer.setPos(stmt.pos) 78 | 79 | def methodTrans(method: Method): Method = { 80 | val inputs = method.decl.inputs.params.map(p => p.name.setPos(p.pos) -> Infer.translate(p.ty)).toMap 81 | tyCtxt.enter(inputs, method.decl)( 82 | method.copy(body = exprTrans(method.body).asInstanceOf[Block]).withInfer) 83 | } 84 | } -------------------------------------------------------------------------------- /src/main/scala/tools/NestScope.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | 4 | import ast.* 5 | 6 | import Expr.* 7 | import ast.ImplicitConversions.* 8 | 9 | import rclang.ty.Infer 10 | 11 | case class FullName(var fn: MethodDecl = MethodDecl("", Params(List()), TyInfo.Nil), var klass: String = Def.DefaultModule, var module: String = "") { 12 | def names = List(module, klass, fn.name.str).filter(_.nonEmpty) 13 | } 14 | 15 | case class NestSpace(val gt: GlobalTable, val fullName: FullName) { 16 | def withClass(klass: String) = { 17 | copy(fullName = fullName.copy(klass = klass)) 18 | } 19 | 20 | def localTable = { 21 | 
gt.classTable(fullName.klass).methods.get(fullName.fn.name) match 22 | case Some(value) => value 23 | case None => ??? 24 | } 25 | 26 | def klassTable = gt.classTable(fullName.klass) 27 | 28 | def fn = { 29 | localTable.astNode 30 | } 31 | 32 | def klass = { 33 | klassTable.astNode 34 | } 35 | 36 | def module: RcModule = { 37 | assert(gt.module.name == fullName.module) 38 | gt.module 39 | } 40 | 41 | // fn in SymbolTable is not be preprocessed 42 | def lookupFn(id: Ident, recursive: Boolean = false): Method = { 43 | // 1. fn, used for recursive 44 | if(recursive && fn.name == id) { 45 | fn 46 | } else { 47 | // 2. class 48 | val klassMethod = klassTable.allMethodsList(gt).find(_.name == id) 49 | val method = klassMethod.getOrElse({ 50 | // 3. module 51 | module.items.find(_ match 52 | case m: Method => m.name == id 53 | case _ => false) match 54 | case Some(value) => value.asInstanceOf[Method] 55 | case None => throw new RuntimeException(s"$fullName can't find $id")}) 56 | method.withInfer 57 | } 58 | } 59 | 60 | def lookupVar(id: Ident): Expr = { 61 | if(id.str == "this") { 62 | val self = Expr.Self.withTy(Infer.translate(klass.name)) 63 | return self 64 | } 65 | // 1. local 66 | localTable.locals.get(id) match 67 | case Some(value) => Identifier(value.name).withTy(value.ty) 68 | // 2. argument 69 | case None => { 70 | fn.decl.inputs.params.find(_.name == id) match 71 | case Some(value) => Identifier(value.name).withTy(Infer.translate(value.ty)) 72 | case None => { 73 | // 3. 
field 74 | klassTable.allInstanceVars(gt).find(_.name == id) match 75 | case Some(value) => Field(Identifier(Def.self), id).withTy(Infer.translate(value.ty)) 76 | case None => throw new RuntimeException() 77 | } 78 | } 79 | } 80 | def findMethodInWhichClass(id: Ident, gt: GlobalTable): Class = { 81 | findMethodInWhichClassImpl(klassTable, id, gt) getOrElse { 82 | // findMethodInWhichClassImpl(gt.classTable(Def.Kernel), id, gt) 83 | gt.module.items.find(_ match 84 | case m:Method => m.name == id 85 | case _ => false) match 86 | case Some(value) => gt.classTable(Def.Kernel).astNode 87 | case None => throw new RuntimeException("fn not in any class") 88 | } 89 | } 90 | 91 | private def findMethodInWhichClassImpl(klass: ClassEntry, id: Ident, gt: GlobalTable): Option[Class] = { 92 | klass.allMethods(gt).find(k => k._2.exists(_.name == id)).map(_._1) 93 | } 94 | } -------------------------------------------------------------------------------- /src/main/scala/ast/ASTVisitor.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package ast 3 | import ast.{ASTNode, Modules} 4 | import ast.Expr.* 5 | import Stmt.{Local, While, Assign} 6 | 7 | trait ASTVisitor { 8 | type R = Unit 9 | 10 | def visit(modules: Modules): R = visitRecursive(modules) 11 | 12 | def visit(module: RcModule): R = visitRecursive(module) 13 | 14 | def visit(item: Item): R = visitRecursive(item) 15 | 16 | def visit(expr: Expr): R = visitRecursive(expr) 17 | 18 | def visit(stmt: Stmt): R = visitRecursive(stmt) 19 | 20 | def visit(ty: TyInfo): R = visitRecursive(ty) 21 | 22 | def visit(decl: MethodDecl): R = visitRecursive(decl) 23 | 24 | def visit(ident: Ident): R = {} 25 | 26 | def visit(param: Param): R = visitRecursive(param) 27 | 28 | def visit(field: FieldDef): R = visitRecursive(field) 29 | 30 | def visit(method: Method): R = visitRecursive(method) 31 | 32 | def visit(klass: Class): R = visitRecursive(klass) 33 | 34 | final def 
visitRecursive(modules: Modules): R = modules.modules.foreach(visit) 35 | 36 | final def visitRecursive(module: RcModule): R = { 37 | module.items.foreach(visit) 38 | } 39 | 40 | final def visitRecursive(item: Item): R = { 41 | item match { 42 | case method: Method => visit(method) 43 | case klass: Class => visit(klass) 44 | case _ => throw new RuntimeException("NoneItem") 45 | } 46 | } 47 | 48 | final def visitRecursive(expr: Expr): R = { 49 | expr match { 50 | case Number(n) => 51 | case Identifier(id) => 52 | case Bool(bool) => 53 | case Binary(op, lhs, rhs) => 54 | case Str(s) => 55 | case If(cond, true_branch, false_branch) => 56 | case Lambda(args, stmts) => 57 | case Call(target, args, _) => 58 | case MethodCall(obj, target, args) => 59 | case block: Block => visitRecursive(block) 60 | case Return(expr) => 61 | case Field(obj, id) => 62 | case Self => 63 | case Symbol(id, _) => 64 | case Index(expr, i) => 65 | } 66 | } 67 | 68 | final def visitRecursive(s: Stmt): R = { 69 | s match { 70 | case Local(id, ty, value) => visit(id); visit(ty); visit(value) 71 | case Stmt.Expr(expr) => visit(expr) 72 | case While(cond, stmts) => visit(cond); visit(stmts) 73 | case Assign(id, value) => visit(id); visit(value) 74 | case _ => throw new RuntimeException("NoneStmt") 75 | } 76 | } 77 | 78 | final def visitRecursive(value: TyInfo): R = { 79 | 80 | } 81 | 82 | final def visitRecursive(b: Block): R = { 83 | b.stmts.foreach(visit) 84 | } 85 | 86 | final def visitRecursive(decl: MethodDecl): R = { 87 | visit(decl.name) 88 | decl.inputs.params.foreach(visit) 89 | visit(decl.outType) 90 | } 91 | 92 | final def visitRecursive(param: Param): R = { 93 | visit(param.name) 94 | visit(param.ty) 95 | } 96 | 97 | final def visitRecursive(field: FieldDef): R = { 98 | visit(field.name) 99 | visit(field.ty) 100 | field.initValue match { 101 | case Some(expr) => visit(expr) 102 | case _ => 103 | } 104 | } 105 | 106 | final def visitRecursive(method: Method): R = { 107 | visit(method.decl) 
108 | visit(method.body) 109 | } 110 | 111 | final def visitRecursive(klass: Class): R = { 112 | visit(klass.name) 113 | klass.parent match { 114 | case Some(parent) => visit(parent) 115 | case None => 116 | } 117 | klass.vars.foreach(visit) 118 | klass.methods.foreach(visit) 119 | } 120 | } -------------------------------------------------------------------------------- /src/main/scala/mir/InstVisitor.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package mir 3 | 4 | 5 | def traverse[T, U](list: List[T])(f: T => U): List[U] = 6 | list.map(f) 7 | 8 | def traverseInst[T <: Instruction](list: List[T]): List[String] = 9 | traverse(list)(inst => { 10 | Printer().visit(inst) 11 | // val user = inst.asInstanceOf[User] 12 | // s"${inst.getClass.getSimpleName}:${user.ty} ${user.operands.map(_.toString).mkString(" ")}" 13 | }) 14 | 15 | trait InstVisitor { 16 | type TRet = Unit 17 | def visit(inst: Instruction): TRet = { 18 | inst match 19 | case i: Call => visit(i) 20 | case i: Call => visit(i) 21 | case i: CondBranch => visit(i) 22 | case i: Branch => visit(i) 23 | case i: Return => visit(i) 24 | case i: Binary => visit(i) 25 | case i: Alloc => visit(i) 26 | case i: Load => visit(i) 27 | case i: Store => visit(i) 28 | case i: PhiNode => visit(i) 29 | } 30 | 31 | def visit(call: Call): TRet = { 32 | 33 | } 34 | 35 | def visit(condbranch: CondBranch): TRet = { 36 | 37 | } 38 | def visit(branch: Branch): TRet = { 39 | 40 | } 41 | def visit(ret: Return): TRet = { 42 | 43 | } 44 | def visit(binary: Binary): TRet = { 45 | 46 | } 47 | def visit(alloc: Alloc): TRet = { 48 | 49 | } 50 | def visit(load: Load): TRet = { 51 | 52 | } 53 | def visit(store: Store): TRet = { 54 | 55 | } 56 | def visit(phinode: PhiNode): TRet = { 57 | 58 | } 59 | } 60 | 61 | class Printer{ 62 | def opsToString(user: User) = { 63 | user.operands.map(_.toString).mkString(" ") 64 | } 65 | 66 | def instName(inst: Instruction) = { 67 | 
inst.getClass.getSimpleName 68 | } 69 | 70 | def visit(inst: Instruction): String = { 71 | val user = inst.asInstanceOf[User] 72 | val instStr = inst match { 73 | // // case BinaryInstBase(lhsValue, rhsValue) => ??? 74 | // // case UnaryInst(operandValue) => ??? 75 | // case intrinsic: Intrinsic => s"${instName(inst)} ${intrinsic.name}: ${user.ty}" 76 | // // case CondBranch(condValue, tBranch, fBranch) => ??? 77 | // // case Branch(destBasicBlock) => ??? 78 | // // case Return(value) => ??? 79 | // case bn @ Binary(op, lhs, rhs) => s"${instName(inst)}: ${user.ty} $op(${lhs}, ${rhs})" 80 | // // case Alloc(id, typ) => ??? 81 | // // case Load(ptr) => ??? 82 | //// case st @ Store(value, ptr) => s"${instName(inst)}: ${user.ty} ${st.value} -> ${st.ptr}" 83 | // case st: Store => s"${instName(inst)}: ${user.ty} ${st.value} -> ${st.ptr}" 84 | // // case GetElementPtr(value, offset) => ??? 85 | // case PhiNode(incomings) => s"PhiNode: ${incomings.map((v, bb) => (v, bb.toString))}${user.ty}" 86 | // case SwitchInst() => ??? 87 | // case MultiSuccessorsInst(bbs) => ??? 
88 | case _ => s"${instName(inst)}:${user.ty} " 89 | // s"${opsToString(user)}" 90 | } 91 | instStr + inst.pos 92 | } 93 | } 94 | 95 | trait MIRVisitor extends InstVisitor { 96 | def visit(fn: Function): TRet = { 97 | fn.instructions.foreach(visit) 98 | } 99 | } 100 | 101 | class MIRPrinter extends MIRVisitor { 102 | override def visit(call: Call): TRet = { 103 | 104 | } 105 | 106 | override def visit(condbranch: CondBranch): TRet = { 107 | 108 | } 109 | 110 | override def visit(branch: Branch): TRet = { 111 | 112 | } 113 | 114 | override def visit(ret: Return): TRet = { 115 | 116 | } 117 | 118 | override def visit(binary: Binary): TRet = { 119 | 120 | } 121 | 122 | override def visit(alloc: Alloc): TRet = { 123 | 124 | } 125 | 126 | override def visit(load: Load): TRet = { 127 | println("load") 128 | } 129 | 130 | override def visit(store: Store): TRet = { 131 | 132 | } 133 | 134 | override def visit(phinode: PhiNode): TRet = { 135 | 136 | } 137 | } -------------------------------------------------------------------------------- /src/main/scala/graphviz/Backend.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | 3 | import org.apache.commons.lang3.SystemUtils 4 | import scala.language.postfixOps 5 | 6 | /** 7 | * @author Depeng Liang 8 | */ 9 | object Backend { 10 | 11 | // http://www.graphviz.org/cgi-bin/man?dot 12 | private val ENGINES = Set( 13 | "dot", "neato", "twopi", "circo", "fdp", "sfdp", "patchwork", "osage" 14 | ) 15 | 16 | val ENGINGSEXECPATH: Map[String, String] = ENGINES.foldLeft(Map[String, String]())({ 17 | (m, st) => { 18 | import sys.process._ 19 | Map[String, String](st -> s"which $st".!!) 
++ m 20 | } 21 | }) 22 | 23 | // http://www.graphviz.org/doc/info/output.html 24 | private val FORMATS = Set( 25 | "bmp", 26 | "canon", "dot", "gv", "xdot", "xdot1.2", "xdot1.4", 27 | "cgimage", 28 | "cmap", 29 | "eps", 30 | "exr", 31 | "fig", 32 | "gd", "gd2", 33 | "gif", 34 | "gtk", 35 | "ico", 36 | "imap", "cmapx", 37 | "imap_np", "cmapx_np", 38 | "ismap", 39 | "jp2", 40 | "jpg", "jpeg", "jpe", 41 | "pct", "pict", 42 | "pdf", 43 | "pic", 44 | "plain", "plain-ext", 45 | "png", 46 | "pov", 47 | "ps", 48 | "ps2", 49 | "psd", 50 | "sgi", 51 | "svg", "svgz", 52 | "tga", 53 | "tif", "tiff", 54 | "tk", 55 | "vml", "vmlz", 56 | "vrml", 57 | "wbmp", 58 | "webp", 59 | "xlib", 60 | "x11" 61 | ) 62 | 63 | /** 64 | * Return command for open a file for default 65 | */ 66 | def viewFileCommand: String = { 67 | if (SystemUtils.IS_OS_LINUX) { 68 | "xdg-open" 69 | } 70 | else if (SystemUtils.IS_OS_MAC || SystemUtils.IS_OS_MAC_OSX) { 71 | "open" 72 | } 73 | else throw new RuntimeException("only support mac OSX or Linux ") 74 | } 75 | 76 | /** 77 | * Return command for execution and name of the rendered file. 78 | * 79 | * @param engine The layout commmand used for rendering ('dot', 'neato', ...). 80 | * @param format The output format used for rendering ('pdf', 'png', ...). 81 | * @param filePath The output path of the source file. 82 | * @return render command to execute. 83 | */ 84 | def command(engine: String, format: String, filePath: String = null): (String, String) = { 85 | require(ENGINES.contains(engine) == true, s"unknown engine: $engine") 86 | require(FORMATS.contains(format) == true, s"unknown format: $format") 87 | Option(filePath) match { 88 | case Some(path) => (s"${ENGINGSEXECPATH(engine)} -T$format -O $path", s"$path.$format") 89 | case None => (s"$engine -T$format", null) 90 | } 91 | } 92 | 93 | /** 94 | * Render file with Graphviz engine into format, return result filename. 95 | * 96 | * @param engine The layout commmand used for rendering ('dot', 'neato', ...). 
97 | * @param format The output format used for rendering ('pdf', 'png', ...). 98 | * @param filePath Path to the DOT source file to render. 99 | */ 100 | @throws(classOf[RuntimeException]) 101 | def render(engine: String = "dot", format: String = "pdf", 102 | filePath: String): String = { 103 | val (args, rendered) = command(engine, format, filePath) 104 | import sys.process._ 105 | try { 106 | args ! 107 | } catch { 108 | case _: Throwable => 109 | val errorMsg = 110 | s"""failed to execute "$args", """ + 111 | """"make sure the Graphviz executables are on your systems' path""" 112 | throw new RuntimeException(errorMsg) 113 | } 114 | rendered 115 | } 116 | 117 | /** 118 | * Open filepath with its default viewing application (platform-specific). 119 | * For know only support linux. 120 | */ 121 | @throws(classOf[RuntimeException]) 122 | def view(filePath: String): Unit = { 123 | val command = s"$viewFileCommand $filePath" 124 | import sys.process._ 125 | try { 126 | command ! 127 | } catch { 128 | case _: Throwable => 129 | val errorMsg = s"failed to execute $command" 130 | throw new RuntimeException(errorMsg) 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /src/main/scala/tools/SymTable.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package tools 3 | import ast.{FieldDef, Ident, Item, MethodDecl} 4 | import ast.TyInfo 5 | import ast.Stmt 6 | import ast.ImplicitConversions.* 7 | import ast.* 8 | 9 | import rclang.ty.{Infer, Type} 10 | 11 | import scala.collection.mutable 12 | import scala.collection.mutable.Map 13 | 14 | object GlobalTablePrinter { 15 | def print(globalTable: GlobalTable): Unit = { 16 | println("GlobalTable") 17 | globalTable.classTable.foreach( (klassName, klass) => { 18 | println(klassName) 19 | print(klass) 20 | }) 21 | println("GlobalTable End") 22 | } 23 | 24 | def print(classEntry: ClassEntry): Unit = { 25 | 
/**
 * Symbol-table entry for one class: wraps the class AST node and the per-method
 * local tables registered through `addMethod`.
 *
 * `gt` starts as null and is assigned by the `GlobalTable` constructor for every
 * entry in its class table; `lookupFieldTy` reads it.
 */
class ClassEntry(val astNode: Class) {
  var gt: GlobalTable = null
  var methods = Map.empty[String, LocalTable]
  //  var fields = Map.empty[String, FieldDef]

  /**
   * Find method `name` in this class, falling back to the parent chain.
   *
   * @return the class that declares the method together with the method, or None
   *         when neither this class nor any ancestor registers it.
   */
  def lookupMethods(name: String, gt: GlobalTable): Option[(Class, Method)] = {
    methods.get(name) match
      case Some(table) => Some((astNode, table.astNode))
      case None =>
        // No parent means the lookup is exhausted.
        astNode.parent.flatMap(parent => gt.classTable(parent).lookupMethods(name, gt))
  }

  /**
   * Type of instance variable `field`, searching this class and its ancestors.
   *
   * @return None when the field is unknown, has `TyInfo.Nil`, or is marked
   *         `TyInfo.Infer` without an initial value to infer from.
   */
  def lookupFieldTy(field: Ident): Option[Type] = {
    allInstanceVars(gt).find(_.name == field).flatMap { fieldDef =>
      fieldDef.ty match
        case TyInfo.Spec(ty) => Some(Infer.translate(ty))
        // Without an initializer there is nothing to infer from.
        case TyInfo.Infer => fieldDef.initValue.map(_.infer)
        case TyInfo.Nil => None
    }
  }

  /** Fields declared directly on this class, keyed by field name. */
  def fields = astNode.vars.map(v => (v.name -> v)).toMap
  //  def addField(fieldDef: FieldDef): Unit = {
  //    fields(fieldDef.name.str) = fieldDef
  //  }

  /** Register (or replace) the local table of a method under the method's name. */
  def addMethod(localTable: LocalTable): Unit = {
    methods(localTable.fnName.str) = localTable
  }

  private def allMethodsImpl(gt: GlobalTable): Map[Class, List[Method]] = {
    val parentMethods = astNode.parent match
      case Some(parent) => gt.classTable(parent).allMethods(gt)
      case None => Map()
    Map(astNode -> astNode.methods) ++ parentMethods
  }

  /** Methods of this class and of every ancestor, grouped by declaring class. */
  def allMethods(gt: GlobalTable): Map[Class, List[Method]] = {
    // allMethodsImpl already builds a fresh map, so no extra copy is needed.
    allMethodsImpl(gt)
  }

  /** Methods of this class followed by the methods of its ancestors. */
  def allMethodsList(gt: GlobalTable): List[Method] = {
    val parentMethods = astNode.parent match
      case Some(parent) => gt.classTable(parent).allMethodsList(gt)
      case None => List()
    astNode.methods ::: parentMethods
  }

  /** Instance variables of this class followed by those of its ancestors. */
  def allInstanceVars(gt: GlobalTable): List[FieldDef] = {
    val parentVars = astNode.parent match
      case Some(parent) => gt.classTable(parent).allInstanceVars(gt)
      case None => List()
    astNode.vars ::: parentVars
  }
}
/**
 * Tracing wrapper around `p`. The wrapped parser always runs; when
 * `RcLogger.level > 1` the attempt and its result are logged afterwards.
 * For packrat readers only the next three tokens are shown.
 */
override def log[T](p: => Parser[T])(name: String): Parser[T] = Parser { input =>
  val result = p(input)
  if (RcLogger.level > 1) {
    val attempt = input match {
      case packrat: PackratReader[Token] =>
        s"trying ${name} at (${take(packrat, 3).mkString(", ")})"
      case _ =>
        "trying " + name + " at " + input
    }
    RcLogger.log(attempt)
    RcLogger.log(name + " --> " + result)
  }
  result
}
/**
 * Parses `p` followed by an optional EOL token and keeps only the result of `p`.
 * Logged under its own name so its traces can be told apart from `oneline`.
 */
protected def onelineOpt[T](p: Parser[T]): Parser[T] = log(p <~ EOL.?)("onelineOpt")
/**
 * Read the whole file at `path` as one string.
 *
 * Lines are joined with "\n" and a final "\n" is always appended, so the last
 * line is terminated even when the file does not end with a newline.
 */
def getSrc(path: String) = {
  val source = Source.fromFile(path)
  // Close the handle even when reading throws.
  try source.getLines().mkString("\n") + "\n"
  finally source.close()
}
/**
 * Dump the frame info of `mf` through `logf` under the name
 * `FrameInfo/<function name>.txt`. The directory
 * `DumpManager.getDumpRoot / "FrameInfo"` is created first when it does not exist.
 */
def dumpFrameInfo(mf: MachineFunction): Unit = {
  val frameInfoDir = Paths.get(DumpManager.getDumpRoot / "FrameInfo")
  val dirMissing = !Files.exists(frameInfoDir)
  if (dirMissing) Files.createDirectories(frameInfoDir)
  val dumpName = f"FrameInfo/${mf.name}.txt"
  logf(dumpName, mf.frameInfo.toString)
}
/**
 * Assemble `srcPath` into the object file `destPath` by running `as`.
 *
 * The arguments are passed as a sequence, so paths containing whitespace are not
 * split into several arguments.
 *
 * @return `Some(destPath)`; `!!` throws when the assembler exits with a non-zero code.
 */
def as(srcPath: String, destPath: String): Option[String] = {
  val args = List(srcPath, "-o", destPath)
  Process("as" :: args).!!
  Some(destPath)
}
/**
 * Token stream for a while loop: WHILE, the parenthesised condition, EOL, the
 * body tokens, then EOL END EOL.
 */
def makeWhile(cond: Token, body: List[Token]): List[Token] = {
  val header: List[Token] = WHILE :: parSround(List(cond))
  val footer: List[Token] = List(EOL, END, EOL)
  header ::: (EOL :: body) ::: footer
}
/**
 * Token stream for a chain of elsif branches: each (condition, body) pair becomes
 * ELSIF cond EOL body EOL. An empty list yields an empty token list.
 */
def makeElsif(lists: List[(Token, Token)]): List[Token] =
  lists.flatMap((cond, body) => List[Token](ELSIF, cond, EOL, body, EOL))
/**
 * Concatenate the token groups with a COMMA between consecutive groups.
 * An empty input yields an empty token list.
 */
def sepListWithComma(tokens: List[List[Token]]): List[Token] = {
  // Every group gets a trailing COMMA; drop the one after the last group.
  // dropRight is a no-op on an empty list, unlike init.
  tokens.flatMap(_.appended(COMMA)).dropRight(1)
}
describe("keyword is not a id") {
  it("succeed") {
    // Each of these words must lex to a single token that is not IDENTIFIER(word).
    val reservedWords =
      List("true", "false", "def", "end", "if", "elsif", "else", "while", "class", "super", "self")
    for (word <- reservedWords) {
      expectKeywordNotId(word)
    }
  }
}
// Interaction between the `val` keyword and identifiers that start with "val".
describe("NoValueToken") {
  it("ValAndValue") {
    // "value" is one identifier, not the keyword followed by "ue".
    expectSuccess("value", IDENTIFIER("value"))
    // "val" followed by a space is the keyword token.
    expectSuccess("val ", VAL)
    // A bare "val" with nothing after it is expected to lex as an identifier.
    expectSuccess("val", IDENTIFIER("val"))
  }
}
/**
 * Stack of lexical scopes used while validating a method body.
 * The innermost (current) scope is the head of the list.
 */
case class ScopeManager() {
  private var scopes = List[Scope]()

  /** Run `f` inside a fresh scope without parameters. */
  def enter[T](f: () => T): T = {
    enter(Params(List()), f)
  }

  /**
   * Run `f` inside a fresh scope pre-populated with the names of `params`.
   * The previous scope stack is restored afterwards, even when `f` throws.
   */
  def enter[T](params: Params, f: () => T): T = {
    val outerScopes = scopes
    scopes ::= Scope(mutable.Set.from(params.params.map(_.name)))
    try f()
    finally scopes = outerScopes
  }

  /** The innermost scope; new scopes are prepended, so this is the head. */
  def curScope: Scope = scopes.head

  /** Declare `ident` in the current scope; false when it was already present. */
  def add(ident: Ident): Boolean = curScope.add(ident)

  /** True when `ident` is declared in the current scope or any enclosing one. */
  def contains(ident: Ident): Boolean = {
    scopes.exists(_.contains(ident))
  }

  /** True when `ident` is declared in the current scope itself. */
  def curContains(ident: Ident): Boolean = curScope.contains(ident)
}

type Result = List[ValidateError]

/** Shared helpers for validators; every check returns the list of errors found. */
trait Validate {
  def dupNameCheck(names: List[Ident]): Result = {
    dupCheck(names, "Name")
  }

  /** One error for every value that repeats an earlier one in `values`. */
  def dupCheck[T <: ASTNode](values: List[T], valueName: String): Result = {
    val s = Set[T]()
    values.filterNot(s.add).map(n => ValidateError(n, s"$valueName $n Dup"))
  }

  /** No errors when `cond` holds, otherwise a single error on `node`. */
  def checkCond(cond: Boolean, node: ASTNode, msg: String): Result = {
    if(cond) then List() else List(ValidateError(node, msg))
  }

  /** Apply `f` to the value when present; an absent value is valid. */
  def checkOption[T](v: Option[T], f: T => Result): Result = {
    v match
      case Some(v) => f(v)
      case None => List()
  }

  def valid: Result = List()
}

/** Validation of a single method: duplicate parameters and name resolution in the body. */
trait MethodValidate extends Validate {
  var scopes = ScopeManager()
  def analysis(method: Method): Result = {
    checkMethod(method)
  }

  /** Errors of the declaration followed by errors of the body. */
  def checkMethod(method: Method): Result = {
    checkMethodDecl(method.decl) ::: checkBlock(method.body, method.decl.inputs)
  }

  def checkMethodDecl(decl: MethodDecl): Result = {
    dupCheck(decl.inputs.params.map(_.name), "MethodParam")
  }

  /** Check the statements of `block` inside a new scope holding `params`. */
  def checkBlock(block: Block, params: Params = Params(List())): Result = {
    scopes.enter(params, () => {
      block.stmts.flatMap(checkStmt)
    })
  }

  def checkStmt(stmt: Stmt): Result = {
    stmt match
      case Stmt.Local(name, ty, value) =>
        val redecl = checkCond(!scopes.curContains(name), name, s"$name redecl in current scope")
        // Record the local so later statements in this scope can refer to it.
        scopes.add(name)
        redecl
      case Stmt.Expr(expr) => checkExpr(expr)
      case Stmt.While(cond, body) => checkExpr(cond) ::: checkExpr(body)
      case Stmt.Assign(name, value) => checkCond(scopes.contains(name), name, s"$name not decl")
  }

  def checkExpr(expr: Expr): Result = {
    expr match
      case Expr.Identifier(id) => checkCond(scopes.contains(id), expr, s"$id not decl")
      case Expr.Binary(op, lhs, rhs) => checkExpr(lhs) ::: checkExpr(rhs)
      case Expr.If(cond, true_branch, false_branch) => checkExpr(cond) ::: checkExpr(true_branch) ::: {
        false_branch match
          case Some(x) => checkExpr(x)
          case None => List()
      }
      case Expr.Lambda(args, block) => checkBlock(block, args)
      case Expr.Call(target, args, _) => args.flatMap(checkExpr)
      case Expr.MethodCall(obj, target, args) => (obj :: args).flatMap(checkExpr)
      // A nested block opens its own scope; delegating back to checkExpr would never terminate.
      case block: Block => checkBlock(block)
      case Expr.Return(expr) => checkExpr(expr)
      // Field And Call can't resolve type in this phase
      case Expr.Field(expr, ident) => checkExpr(expr)
      case Expr.Index(expr, i) => checkExpr(expr)
      case _ => valid
  }
}
/** Passes when parsing `token` as an item fails; fails the test with the parsed value otherwise. */
def expectFailed(token: Seq[Token]): Unit = {
  val parsed = apply(token)
  parsed.fold(
    _ => assert(true),
    value => assert(false, s"expect failed, value: $value"))
}
// Two extra EOL tokens sit between the two local declarations; the expected AST
// still contains exactly the two locals, so blank lines add no statements.
it("multi line with multi eol") {
  expectSuccess(
    makeTokenMethod("foo",
      makeLocal("a", NUMBER(1))
        ::: (EOL :: EOL :: List())
        ::: (makeLocal("a", NUMBER(1)))),
    makeASTMethod("foo",
      block = List(
        makeLocal("a", Number(1)),
        makeLocal("a", Number(1)))))
}
RcModule] = { 141 | doParser(tokens, module) 142 | } 143 | def expectSuccess(token: Seq[Token], expect: RcModule): Unit = { 144 | apply(token) match { 145 | case Left(value) => assert(false, value.msg) 146 | case Right(value) => assert(value == expect) 147 | } 148 | } 149 | 150 | // 1. parse ok 151 | // 2. filter None 152 | it("splitWithEol") { 153 | expectSuccess( 154 | mkTokenClass("Foo"):::EOL::mkEmptyTokenMethod("f"), 155 | RcModule(List(mkASTClass("Foo"), makeASTMethod("f")))) 156 | } 157 | } 158 | } 159 | 160 | describe("templateFn") { 161 | it("define") { 162 | // todo: with method 163 | val foo = mkEmptyTokenMethod("foo", generic = Some("T")) 164 | val fooAST = makeASTMethod("foo", generic = Some("T")) 165 | expectSuccess(foo, fooAST) 166 | } 167 | } 168 | 169 | describe("templateClass") { 170 | it("define") { 171 | val treeNode = mkTokenClass("TreeNode", generic = Some("T")) 172 | val treeNodeAST = Class("TreeNode", None, List(), List(), Some("T")) 173 | expectSuccess(treeNode, treeNodeAST) 174 | } 175 | } 176 | } -------------------------------------------------------------------------------- /src/main/scala/lexer/Lexer.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package lexer 3 | 4 | import lexer.* 5 | import lexer.Keyword.* 6 | import lexer.Punctuation.* 7 | import lexer.Literal.* 8 | import lexer.Delimiter.* 9 | import lexer.Ident.* 10 | import rclang.RcLexerError 11 | import rclang.Location 12 | 13 | import scala.util.matching.Regex 14 | import scala.util.parsing.combinator.RegexParsers 15 | 16 | object Lexer extends RegexParsers { 17 | override def skipWhitespace = false 18 | override val whiteSpace: Regex = "[ \t\r\f]+".r 19 | 20 | def eliminateComment(src: String) = { 21 | val end = if(src.endsWith("\n")) then "\n" else "" 22 | (src.split("\n").map(c => { 23 | val begin = c.indexOf("#") 24 | if (begin != -1) { 25 | c.slice(0, begin) 26 | } else { 27 | c 28 | } 29 | 
}).mkString("\n")) + end 30 | } 31 | def apply(originSrc: String): Either[RcLexerError, List[Token]] = { 32 | val code = eliminateComment(originSrc) 33 | parse(tokens, code) match { 34 | case NoSuccess(msg, next) => Left(RcLexerError(Location(next.pos.line, next.pos.column), msg)) 35 | case Success(result, next) => Right(result) 36 | } 37 | } 38 | 39 | def keyword: Parser[Token] = stringLiteral | trueLiteral | falseLiteral | 40 | defStr | endStr | ifStr | thenStr | elsifStr | elseStr | whileStr | breakStr | 41 | continueStr | classStr | superStr | selfStr | varStr | valStr | importStr | forStr | returnStr 42 | def symbol: Parser[Token] = comment | comma | eol | dot | at | colon | semicolon | 43 | leftParentTheses | rightParentTheses | leftSquare | rightSquare | leftBracket | rightBracket 44 | 45 | def value: Parser[Token] = number | upperIdentifier | identifier 46 | 47 | def ops = "[+\\-*/%^~!><]|(==)".r 48 | def operator: Parser[Token] = positioned { 49 | ops ^^ OPERATOR 50 | } 51 | 52 | def tokens: Parser[List[Token]] = { 53 | phrase(allTokens) 54 | } 55 | 56 | def rep1sepNoDis[T](p : => Parser[T], q : => Parser[Any]): Parser[List[T]] = 57 | p ~ rep(q ~ p) ^^ {case x~y => x::y.map(x => List(x._1.asInstanceOf[T], x._2)).fold(List())(_:::_)} 58 | 59 | def allTokens: Parser[List[Token]] = { 60 | ((rep1sepNoDis(repN(1, notSpacer), spacer.+) ~ spacer.*) | 61 | // BAA is imposible 62 | (rep1sepNoDis(spacer.+, repN(1, notSpacer)) ~ notSpacer.?)) ^^ { 63 | case list ~ t => 64 | list 65 | .fold(List())(_:::_) 66 | .concat(t match { 67 | case Some(v) => List(v) 68 | case None => List() 69 | case _ => t 70 | }) 71 | .filter(_ != SPACE) 72 | } 73 | } 74 | 75 | def space: Parser[Token] = positioned { 76 | whiteSpace.+ ^^^ SPACE 77 | } 78 | 79 | def notSpacer: Parser[Token] = keyword | value | eol 80 | 81 | def spacer: Parser[Token] = symbol | eql | operator | space 82 | 83 | def upperIdentifier: Parser[UPPER_IDENTIFIER] = positioned { 84 | "[A-Z_][a-zA-Z0-9_]*".r ^^ { str => 
UPPER_IDENTIFIER(str) } 85 | } 86 | 87 | def identifier: Parser[IDENTIFIER] = positioned { 88 | "[a-zA-Z_][a-zA-Z0-9_]*".r ^^ { str => IDENTIFIER(str) } 89 | } 90 | 91 | def stringLiteral: Parser[STRING] = positioned { 92 | """"[^"]*"""".r ^^ { str => 93 | val content = str.substring(1, str.length - 1) 94 | STRING(content) 95 | } 96 | } 97 | 98 | def number = positioned { 99 | """(0|[1-9]\d*)""".r ^^ { i => NUMBER(i.toInt) } 100 | } 101 | 102 | def NoValueTokenKeyWord(str: String, token: Token): Parser[Token] = positioned { 103 | str ~ guard(spacer) ^^^ token 104 | } 105 | 106 | def NoValueTokenWithGuard(str: String, token: Token, guardStr: String): Parser[Token] = positioned { 107 | str ~ not(guard(guardStr)) ^^^ token 108 | } 109 | 110 | def NoValueTokenSymbol(str: String, token: Token): Parser[Token] = positioned { 111 | str ^^^ token 112 | } 113 | 114 | def comment = NoValueTokenSymbol("#", COMMENT) 115 | def eol = NoValueTokenSymbol("\n", EOL) 116 | def eql = NoValueTokenWithGuard("=", EQL, "=") 117 | def comma = NoValueTokenSymbol(",", COMMA) 118 | def dot = NoValueTokenSymbol(".", DOT) 119 | def colon = NoValueTokenSymbol(":", COLON) 120 | def semicolon = NoValueTokenSymbol(";", SEMICOLON) 121 | def at = NoValueTokenSymbol("@", AT) 122 | 123 | def trueLiteral = NoValueTokenKeyWord("true", TRUE) 124 | def falseLiteral = NoValueTokenKeyWord("false", FALSE) 125 | 126 | def varStr = NoValueTokenKeyWord("var", VAR) 127 | def valStr = NoValueTokenKeyWord("val", VAL) 128 | def defStr = NoValueTokenKeyWord("def", DEF) 129 | def returnStr = NoValueTokenKeyWord("return", RETURN) 130 | def endStr = NoValueTokenKeyWord("end", END) 131 | 132 | def ifStr = NoValueTokenKeyWord("if", IF) 133 | def thenStr = NoValueTokenKeyWord("then", THEN) 134 | def elsifStr = NoValueTokenKeyWord("elsif", ELSIF) 135 | def elseStr = NoValueTokenKeyWord("else", ELSE) 136 | def whileStr = NoValueTokenKeyWord("while", WHILE) 137 | def forStr = NoValueTokenKeyWord("for", FOR) 138 | def breakStr 
= NoValueTokenKeyWord("break", BREAK) 139 | def continueStr = NoValueTokenKeyWord("continue", CONTINUE) 140 | 141 | def classStr = NoValueTokenKeyWord("class", CLASS) 142 | def superStr = NoValueTokenKeyWord("super", SUPER) 143 | def selfStr = NoValueTokenKeyWord("self", SELF) 144 | def varsStr = NoValueTokenKeyWord("vars", VARS) 145 | def methods = NoValueTokenKeyWord("methods", METHODS) 146 | 147 | def importStr = NoValueTokenKeyWord("import", IMPORT) 148 | 149 | def leftParentTheses = NoValueTokenSymbol("(", LEFT_PARENT_THESES) 150 | def rightParentTheses = NoValueTokenSymbol(")", RIGHT_PARENT_THESES) 151 | def leftSquare = NoValueTokenSymbol("[", LEFT_SQUARE) 152 | def rightSquare = NoValueTokenSymbol("]", RIGHT_SQUARE) 153 | def leftBracket = NoValueTokenSymbol("{", LEFT_BRACKET) 154 | def rightBracket = NoValueTokenSymbol("}", RIGHT_BRACKET) 155 | } -------------------------------------------------------------------------------- /src/main/scala/transform/ConstantFolding.scala: -------------------------------------------------------------------------------- 1 | package rclang 2 | package transform 3 | 4 | import mir.ImplicitConversions.* 5 | import mir.{InstVisitor, *} 6 | import pass.{AnalysisManager, Transform} 7 | 8 | import scala.math 9 | import scala.math.Fractional.Implicits.infixFractionalOps 10 | import scala.math.Integral.Implicits.infixIntegralOps 11 | import scala.math.Numeric.Implicits.infixNumericOps 12 | 13 | 14 | //def isConstant(v: Value): Boolean = { 15 | // v match 16 | // case block: BasicBlock => isConstant(block.stmts.last) 17 | // case value: GlobalValue => ??? 18 | // case user: User => user match 19 | // case constant: Constant => true 20 | // case instruction: Instruction => instruction.getOperands.forall(isConstant) 21 | // case _ => ??? 
/** Tries to simplify `value`.
  *
  * Only instructions can be simplified; they are delegated to `evalInst`.
  * Every other kind of value (arguments, constants, basic blocks, globals, ...)
  * yields `None`, meaning "nothing to replace". Previously the non-instruction,
  * non-constant cases hit `???` and aborted the whole pass with
  * `NotImplementedError`.
  *
  * @param value the value to simplify
  * @return `Some(replacement)` when a simpler value exists, otherwise `None`
  */
def eval(value: Value): Option[Value] = {
  value match
    case instruction: Instruction => evalInst(instruction)
    case _ => None
}
/** Folds an algebraic identity: `x + 0`, `x - 0`, `x * 1`, `x / 1` become `x`.
  *
  * The operator name is compared case-insensitively, so "add" and "Add" are
  * treated the same way.
  *
  * NOTE: this function cannot tell on which side the constant sits. For the
  * non-commutative "sub" and "div" the caller must only pass a constant that is
  * the right-hand operand (`0 - x` is not `x`).
  *
  * @param op operator name of the instruction
  * @param v  the non-constant operand
  * @param c  the constant operand
  * @param bn the instruction being folded (unused here, kept for callers)
  * @return `Some(v)` when `c` is the identity element of `op`, else `None`
  */
def foldBinaryOp(op: String, v: Value, c: Constant, bn: Binary): Option[Value] = {
  op.toLowerCase(java.util.Locale.ROOT) match
    case "add" | "sub" if c.isZero => Some(v)
    case "mul" | "div" if c.isOne => Some(v)
    case _ => None
}

/** Widens a numeric constant to `Double`. */
def getDouble(n: Number): Double = {
  n match
    case Integer(value) => value
    case FP(value) => value
    case _ => ???
}

/** Evaluates `a op b` at compile time for "add" and "sub" (case-insensitive).
  *
  * Two integers are combined with `Int` arithmetic, so overflow wraps instead
  * of saturating through a `Double` round trip. Any other combination is
  * computed in `Double`, and the result takes the kind of `a`: `Integer`
  * truncates, `FP` keeps the fractional part (it used to be truncated by
  * `result.toInt`).
  *
  * Unsupported operators still end in `???`; `foldBinaryInstruction` never
  * passes one.
  */
def compute(op: String, a: Number, b: Number): Number = {
  val opName = op.toLowerCase(java.util.Locale.ROOT)
  (a, b) match
    case (Integer(x), Integer(y)) =>
      opName match
        case "add" => Integer(x + y)
        case "sub" => Integer(x - y)
        case _ => ???
    case _ =>
      val lhs = getDouble(a)
      val rhs = getDouble(b)
      val result = opName match
        case "add" => lhs + rhs
        case "sub" => lhs - rhs
        case _ => ???
      a match
        case Integer(_) => Integer(result.toInt)
        case FP(_) => FP(result.toFloat)
        case _ => ???
}

/** Tries to replace a binary instruction by a simpler value.
  *
  *  - two `Integer` or two `FP` constants with a computable operator
  *    ("add"/"sub") are evaluated by `compute`;
  *  - otherwise, a constant right operand is checked for being the identity
  *    element of the operator;
  *  - a constant left operand is only checked for commutative operators
  *    ("add"/"mul"), because `0 - x` and `1 / x` are not `x`.
  *
  * Anything else returns `None`. Operators that `compute` does not support,
  * or constants of other kinds, no longer run into `???`.
  */
def foldBinaryInstruction(bn: Binary): Option[Value] = {
  val opName = bn.op.toLowerCase(java.util.Locale.ROOT)
  val computable = opName == "add" || opName == "sub"
  val commutative = opName == "add" || opName == "mul"
  (bn.lhs, bn.rhs) match
    case (a: Integer, b: Integer) if computable => Some(compute(bn.op, a, b))
    case (a: FP, b: FP) if computable => Some(compute(bn.op, a, b))
    case (lhs, rhs) =>
      val viaRight = rhs match
        case c: Constant => foldBinaryOp(bn.op, lhs, c, bn)
        case _ => None
      viaRight.orElse(lhs match
        case c: Constant if commutative => foldBinaryOp(bn.op, rhs, c, bn)
        case _ => None
      )
}
/** Computes dominator sets for the basic blocks of a function with the
  * iterative data-flow formulation and stores them in a `DomTree`.
  *
  * In the resulting tree, the `children` list of a block's node holds the nodes
  * of all blocks in that block's dominator set (see `makeTree`).
  */
case class DomTreeBuilder() {
  var visited = LinkedHashSet[BasicBlock]()

  /** Builds the tree for `fn` from the blocks found by `dfsBasicBlocks(fn.entry)`. */
  def compute(fn: Function): DomTree = {
    val predMap = predecessorsMap(fn.bbs)
    val bbs = dfsBasicBlocks(fn.entry)
    compute(LinkedHashSet.from(bbs), predMap, fn.entry)
  }

  // when true, computeImpl prints a trace of every step
  var dumpCompute = true

  def computeLog(str: String) = {
    if (dumpCompute)
      println(str)
  }

  /** Iterates `Dom(n) = {n} ∪ ⋂ Dom(p) for p in pred(n)` to a fixed point.
    *
    * @param nodes all blocks to analyse, including `root`
    * @param pred  predecessors of every block in `nodes`
    * @param root  entry block; it is dominated only by itself
    * @return for every block, the set of blocks that dominate it (itself included)
    */
  def computeImpl(nodes: LinkedHashSet[Node], pred: DomInfoType, root: Node): DomInfoType = {
    assert(root != null)
    val N = nodes
    // start from "everything dominates everything" and shrink
    var Domin: DomInfoType = Map(root -> LinkedHashSet[Node](root))
    (N - root).foreach(n => {
      Domin = Domin.updated(n, N)
    })
    assert(Domin(root).size == 1)

    // every block except the entry
    val workList = N - root
    // Repeat full passes until one pass changes nothing. The old loop ran
    // `while (!change)`, i.e. it stopped after the first pass that changed
    // something and never terminated when the first pass changed nothing.
    var changed = true
    while (changed) {
      changed = false
      workList.foreach(n => {
        computeLog(s"process: ${n.name}")
        val preds = pred(n)
        computeLog(s"preds: ${preds.map(_.name).mkString(", ")}")
        // intersection of the dominator sets of all predecessors; a predecessor
        // that is not among `nodes` does not constrain the result
        val tmpDom = preds.foldLeft(N) { (acc, p) => acc & Domin.getOrElse(p, N) }
        // a block always dominates itself
        val D = tmpDom + n
        if (D != Domin(n)) {
          computeLog(s"Dom: ${D.map(_.name).mkString(", ")}")
          changed = true
          Domin = Domin.updated(n, D)
        }
        computeLog("")
      })
    }
    Domin
  }

  /** Computes the dominator sets and wraps them in a tree owned by `root.parent`. */
  def compute(nodes: LinkedHashSet[Node], pred: DomInfoType, root: Node): DomTree = {
    val Domin = computeImpl(nodes, pred, root)
    makeTree(Domin, root.parent)
  }

  /** Creates one tree node per block, then attaches to each node the nodes of
    * the blocks in its dominator set.
    */
  def makeTree(Domin: DomInfoType, f: Function): DomTree = {
    val tree = DomTree(f)
    // all nodes must exist before any of them can be referenced as a child
    Domin.keys.foreach(bb => tree.addNode(bb))
    Domin.foreach((bb, dominators) => {
      tree.node(bb).addChilds(dominators.map(tree.node).toList)
    })
    tree
  }
}
/** Derives the immediate dominator of every block except `root`.
  *
  * The immediate dominator of `a` is the strict dominator of `a` that does not
  * strictly dominate any other strict dominator of `a`.
  *
  * The strict-dominator sets are computed once and never modified, so the
  * result depends neither on the order in which blocks are visited nor on the
  * iteration order of the sets. Debug output goes through `idomComputeLog`
  * instead of unconditional `println`.
  *
  * @param N     all blocks
  * @param Domin dominator set of every block in `N` (the block itself included)
  * @param root  entry block, which has no immediate dominator
  * @return block -> immediate dominator; a block without any strict dominator
  *         is left out instead of failing on `.head`
  */
def iDomComputeImpl(N: LinkedHashSet[Node], Domin: DomInfoType, root: Node): Map[Node, Node] = {
  // a sdom n  <=>  a dom n && a != n
  val strictDoms: Map[Node, LinkedHashSet[Node]] =
    N.iterator.map(n => n -> (Domin(n) - n)).toMap

  (N - root).flatMap(a => {
    val candidates = strictDoms(a)
    idomComputeLog("a: " + a.name)
    idomComputeLog("tmp: " + candidates.map(_.name).mkString(", "))
    // b is the immediate dominator when no other candidate c has b as a strict dominator
    val immediate = candidates.find(b =>
      !(candidates - b).exists(c => strictDoms(c).contains(b))
    )
    idomComputeLog("idom: " + immediate.map(_.name).getOrElse(""))
    idomComputeLog("")
    immediate.map(a -> _)
  }).toMap
}
force: Boolean): Type = { 28 | if(!force && typed.ty != InferType) { 29 | typed.ty 30 | } else { 31 | infer(typed) 32 | } 33 | } 34 | 35 | private def infer(typed: Typed): Type = { 36 | typed match 37 | case expr: Expr => infer(expr) 38 | case item: Item => infer(item) 39 | case method: Method => infer(method) 40 | case stmt: Stmt => infer(stmt) 41 | case _ => ??? 42 | } 43 | 44 | 45 | private def infer(item: Item): Type = { 46 | item match 47 | case m: Method => infer(m) 48 | case k: Class => { 49 | StructType(k.name.str, ListMap.from(k.vars.map(v => v.name.str -> translate(v.ty)))) 50 | } 51 | } 52 | 53 | private def infer(stmt: Stmt): Type = { 54 | stmt match 55 | case Stmt.Local(name, ty, value) => { 56 | val localTy = ty match 57 | case TyInfo.Spec(ty) => translate(ty) 58 | case TyInfo.Infer => infer(value) 59 | case TyInfo.Nil => NilType 60 | tyCtxt.addLocal(name, localTy) 61 | localTy 62 | } 63 | case Stmt.Expr(expr) => infer(expr) 64 | case Stmt.While(cond, body) => infer(body) 65 | case Stmt.For(_, _, _, body) => infer(body) 66 | // todo: check value, check other member in for and while 67 | case Stmt.Assign(name, value) => lookup(name) 68 | case Stmt.Break() => NilType 69 | case Stmt.Continue() => NilType 70 | } 71 | 72 | private def infer(expr: Expr): Type = { 73 | expr match 74 | case Number(v) => Int32Type 75 | case Identifier(ident) => lookup(ident) 76 | case Bool(b) => BooleanType 77 | case Binary(op, lhs, rhs) => common(lhs, rhs) 78 | case Str(str) => StringType 79 | case If(cond, true_branch, false_branch) => false_branch match 80 | case Some(fBr) => common(true_branch, fBr) 81 | case None => infer(true_branch) 82 | case Return(expr) => infer(expr) 83 | case Block(stmts) => { 84 | if(stmts.isEmpty) { 85 | NilType 86 | } else { 87 | // must have prev info when infer last 88 | tyCtxt.enter(stmts.map(infer).last) 89 | } 90 | } 91 | case Call(target, args, _) => lookup(target) 92 | case Lambda(args, block) => ??? 
93 | case MethodCall(obj, target, args) => { 94 | val makeCallTy = (klass: String) => { 95 | if(target.str == "new") { 96 | return PointerType(TypeBuilder.fromClass(klass, tyCtxt.globalTable)) 97 | } 98 | // GlobalTablePrinter.print(tyCtxt.globalTable) 99 | NestSpace(tyCtxt.globalTable, tyCtxt.fullName.copy(klass = klass)).lookupFn(target).ty match 100 | case FnType(ret, params) => ret 101 | case _ => ??? 102 | } 103 | // obj is a constant or obj is a expr 104 | val ty = obj match { 105 | case Symbol(sym, _) => { 106 | makeCallTy(sym.str) 107 | } 108 | // todo: new的返回类型 109 | case _ => { 110 | val id = obj match 111 | case Identifier(id) => id 112 | case _ => Ident("") 113 | structTyProc(obj.withInfer)(s => { 114 | makeCallTy(s.name) 115 | // NestSpace(tyCtxt.globalTable, tyCtxt.fullName.copy(klass = s.name)).lookupFn(target).ty 116 | }) 117 | } 118 | } 119 | ty 120 | } 121 | case Field(expr, ident) => { 122 | structTyProc(expr.withInfer)(s => { 123 | NestSpace(tyCtxt.globalTable, tyCtxt.fullName.copy(klass = s.name)).lookupVar(ident).ty 124 | }) 125 | } 126 | case Self => ??? 127 | case Symbol(ident, _) => ??? 
128 | case Index(expr, _) => infer(expr) match 129 | case ArrayType(valueT, _) => valueT 130 | case _ => ErrType("failed") 131 | // todo: array type 132 | case Array(len, initValues) => ArrayType(common(initValues), len) 133 | } 134 | 135 | private def lookup(ident: Ident): Type = { 136 | // todo: lookup failed should to class 137 | tyCtxt.lookup(ident).getOrElse(ErrType(s"$ident not found")) 138 | } 139 | 140 | private def infer(f: Method): Type = { 141 | tyCtxt.global.getOrElse(f.name, { 142 | // todo:从这里进的function没有local 143 | val oldName = tyCtxt.fullName 144 | tyCtxt.fullName = tyCtxt.fullName.copy(fn = f.decl) 145 | val ret = if (f.decl.outType == TyInfo.Infer) { 146 | Infer(f.body) 147 | } else { 148 | translate(f.decl.outType) 149 | } 150 | val params = f.decl.inputs.params.map(_.ty).map(translate) 151 | tyCtxt.fullName = oldName 152 | FnType(ret, params) 153 | }) 154 | } 155 | 156 | def translate(info: TyInfo): Type = info match 157 | case TyInfo.Spec(ty) => translate(ty) 158 | case TyInfo.Infer => ErrType("can't translate TyInfo.Infer") 159 | case TyInfo.Nil => NilType 160 | 161 | def translate(ident: Ident): Type = { 162 | ident.str match 163 | case "Boolean" => BooleanType 164 | case "String" => StringType 165 | case "Int" => Int32Type 166 | case "Float" => FloatType 167 | case "Nil" => NilType 168 | case "Handle" => PointerType(Int64Type) 169 | case _ => { 170 | PointerType(TypeBuilder.fromClass(ident.str, tyCtxt.globalTable)) 171 | } 172 | } 173 | 174 | private def common(exprList: List[Expr]): Type = { 175 | val tyList = exprList.map(infer) 176 | val isSame = tyList.forall(_ != tyList.head) 177 | if isSame then tyList.head else ErrType("failed") 178 | } 179 | 180 | private def common(lhs: Expr, rhs: Expr): Type = { 181 | val lt = infer(lhs) 182 | val rt = infer(rhs) 183 | if lt == rt then lt else ErrType("failed") 184 | } 185 | } 186 | 187 | def structTyProc[T](obj: Expr)(f: StructType => T): T = { 188 | structTyProc(obj.ty)(f) 189 | } 190 | 191 | 
/** Applies `f` to the struct type behind `ty`, looking through any number of
  * pointer layers. A type that is neither a struct nor a pointer to one is not
  * supported and ends in `???`.
  */
def structTyProc[T](ty: Type)(f: StructType => T): T = {
  ty match
    case s: StructType => f(s)
    case PointerType(inner) => structTyProc(inner)(f)
    case _ => ???
}

/** Like the single-argument variant, but returns `default` when `ty` is
  * neither a struct nor a pointer chain ending in a struct.
  *
  * `default` used to be ignored: the fallback case was `???`, and the pointer
  * case recursed into the variant without a default.
  */
def structTyProc[T](ty: Type, default: T)(f: StructType => T): T = {
  ty match
    case s: StructType => f(s)
    case PointerType(inner) => structTyProc(inner, default)(f)
    case _ => default
}
/** Call to a built-in routine identified by name; the arguments are its operands. */
class Intrinsic(private val intrName: String, private val args_value: List[Value]) extends Instruction(varOps) {
  name = intrName
  // todo: fix this
  // only these two intrinsics get a result type assigned here
  intrName match
    case "print" => ty = NilType
    case "malloc" => ty = PointerType(NilType)
    case _ =>
  setOperands(args_value)

  def args = getOperands

  def getArg(i: Int): Value = getOperand(i)

  // name followed by the "@PLT" suffix
  def symbol = s"$intrName@PLT"

  override def toString: String = {
    val renderedArgs = args_value.map(_.toString).mkString(" ")
    s"$intrName: $renderedArgs"
  }
}
Instruction(1) { 106 | ty = valuePtr.ty 107 | setOperand(0, valuePtr) 108 | def ptr = getOperand(0) 109 | } 110 | 111 | class Store(valueV: Value, ptrV: Value) extends Instruction(2) { 112 | ty = valueV.ty 113 | setOperand(0, valueV) 114 | setOperand(1, ptrV) 115 | def value = getOperand(0) 116 | def ptr = getOperand(1) 117 | } 118 | 119 | // value: this object 120 | case class GetElementPtr(value: Value, offset: Value, targetTy: Type) extends Instruction(2) { 121 | setOperand(0, value) 122 | setOperand(1, offset) 123 | ty = targetTy 124 | def align = value.ty match 125 | case s: StructType => s.align 126 | case _ => throw RuntimeException("value should be structure type") 127 | } 128 | 129 | case class PhiNode(var incomings: Map[Value, BasicBlock] = Map()) extends Instruction(varOps) { 130 | // avoid recursive 131 | // private def incomingsStr = incomings.map(x => x._2.map(b => s"${x._1} => ${b.name}").mkString("\n")).mkString("\n") 132 | private def incomingStr = "incomings" 133 | 134 | override def toString: String = "Phi" 135 | def addIncoming(value: Value, block: BasicBlock): Unit = { 136 | ty = value.ty 137 | incomings = incomings.updated(value, block) 138 | } 139 | } 140 | 141 | case class SwitchInst() extends Instruction(varOps) with Terminator { 142 | def addCase(cond: Value, bb: BasicBlock) : Unit = { 143 | 144 | } 145 | 146 | override def successors: List[BasicBlock] = { 147 | operands.map(_.asInstanceOf[BasicBlock]) 148 | } 149 | } 150 | 151 | // used for test 152 | case class MultiSuccessorsInst(var bbs: List[BasicBlock] = List()) extends Instruction(varOps) with Terminator { 153 | def add(bb: BasicBlock) : BasicBlock = { 154 | bbs = bbs :+ bb 155 | bb 156 | } 157 | 158 | override def successors: List[BasicBlock] = bbs 159 | } 160 | 161 | sealed class Constant(typ: Type) extends User(0) { 162 | ty = typ 163 | def isZero: Boolean = false 164 | def isOne: Boolean = false 165 | } 166 | 167 | sealed class Number(typ: Type) extends Constant(typ) { 168 | } 
/** Extractor: matches a `Load` and yields its pointer operand. */
object Load {
  def unapply(inst: Value): Option[Value] =
    PartialFunction.condOpt(inst) { case ld: Load => ld.ptr }
}

/** Extractor: matches a `Store` and yields `(value, ptr)`. */
object Store {
  def unapply(inst: Value): Option[(Value, Value)] =
    PartialFunction.condOpt(inst) { case st: Store => (st.value, st.ptr) }
}

/** Extractor: matches a `Binary` and yields `(op, lhs, rhs)`. */
object Binary {
  def unapply(inst: Value): Option[(String, Value, Value)] =
    PartialFunction.condOpt(inst) { case bn: Binary => (bn.op, bn.lhs, bn.rhs) }
}

/** Extractor: matches a `Return` and yields the returned value. */
object Return {
  def unapply(inst: Value): Option[Value] =
    PartialFunction.condOpt(inst) { case rt: Return => rt.value }
}
/**
 * Folds a flat `term op term op term ...` sequence into nested `Expr.Binary` nodes.
 *
 * The input lists mix `Expr` operands and `OPERATOR` tokens; both are `Positional`.
 */
trait BinaryTranslator {
  /** Per-operator ranking used to decide which operator is folded first. */
  val opDefaultInfix = HashMap("+"->10, "-"->10, "*"->10, "/"->10, ">"->5, "<"->5, "==" ->4)

  /**
   * Index of the operator that is folded next.
   *
   * Picks the operator with the SMALLEST value in `opDefaultInfix`, leftmost on ties
   * (despite the "Max" in the name).
   * NOTE(review): this makes `1 + 2 < 3` fold as `1 + (2 < 3)`, which is what
   * BinaryTranslatorTest "compose" currently pins — confirm that is the intended precedence.
   *
   * Throws if `terms` holds no operator, or an operator missing from `opDefaultInfix`.
   */
  def findMaxInfixIndex(terms: List[Positional]): Int =
    terms
      .zipWithIndex
      .collect { case (op: OPERATOR, index) => (op, index) }
      .minBy((op, _) => opDefaultInfix(op.op))._2

  /**
   * Replaces the three elements `operand, operator, operand` centred on `index`
   * with a single `Expr.Binary`, keeping everything before and after untouched.
   */
  def replaceBinaryOp(terms: List[Positional], index: Int): List[Positional] = {
    val before = terms.take(index - 1)
    val folded = Expr.Binary(
      strToOp(terms(index).asInstanceOf[OPERATOR].op),
      terms(index - 1).asInstanceOf[Expr],
      terms(index + 1).asInstanceOf[Expr])
    val after = terms.drop(index + 2)
    before ::: (folded :: after)
  }

  /**
   * Entry point used by the parser: `term` is the first operand and every element
   * of `terms` is one `List(operator, operand)` pair that followed it.
   */
  def termsToBinary(term: Expr, terms: List[List[Positional]]): Expr =
    if terms.isEmpty then term
    else termsToBinary(term :: terms.flatten)

  /**
   * Repeatedly folds one operator until a single expression is left.
   *
   * A one-element list is returned as-is; it is cast to `Expr` (not `Expr.Binary`),
   * so a lone operand no longer fails with a `ClassCastException`.
   */
  def termsToBinary(terms: List[Positional]): Expr = {
    var remaining = terms
    while (remaining.size > 1) {
      val nextIndex = findMaxInfixIndex(remaining)
      remaining = replaceBinaryOp(remaining, nextIndex)
    }
    remaining.head.asInstanceOf[Expr]
  }
}

/** Expression and statement grammar built on top of `RcBaseParser`. */
trait ExprParser extends RcBaseParser with BinaryTranslator {
  /** `: Type` annotation. */
  def typeLimit = COLON ~> ty

  // todo: how should the ambiguity introduced by templates be resolved?
  /** One term followed by any number of `operator term` pairs, folded via `termsToBinary`. */
  def termExpr: Parser[Expr] = positioned {
    term ~ (operator ~ term).* ^^ {
      case term ~ terms => termsToBinary(term, terms.map(a => List(a._1, a._2)))
    }
  }

  /** Any expression: a multi-line `if`, a `return`, or a term expression (tried in that order). */
  def expr: Parser[Expr] = positioned {
    multiLineIf | log(ret)("ret") | termExpr
  }

  /** Array literal `Ty[size]{v1, v2, ...}`; the node's type is set from `ty` and `size`. */
  def array: Parser[Expr] = positioned {
    ty ~ squareSround(number) ~ bracketSround(repsep(termExpr, COMMA)) ^^ {
      case ty ~ NUMBER(size) ~ values => {
        val arr = Expr.Array(size, values)
        arr.ty = ArrayType(Infer.translate(ty), size)
        arr
      }
    }
  }

  def string = positioned { stringLiteral ^^ { case STRING(str) => Expr.Str(str)} }
  def num = positioned { number ^^ { case NUMBER(int) => Expr.Number(int) } }
  def idExpr = positioned { id ^^ Expr.Identifier }


  // term expr; ID should also be placed here

  // memCall: term.x(
  // memCall: term.x(
  // memField: term.x
  // arrayIndex: term[
  /** Alternatives whose first component is itself an expression (left-recursive, hence packrat). */
  lazy val beginWithTerm: PackratParser[Expr] = positioned {
    log(memCall)("memCallStart") | log(memField)("memFieldStart") | array | arrayIndex
  }

  // call: term()
  // memCall: term.x()
  /** A single operand; alternatives are tried left to right. */
  def term: Parser[Expr] = positioned {
    bool | num | string | selfField | call | log(beginWithTerm)("beginWithTerm") | symbol | log(idExpr)("idExpr")
  }

  /** Upper-case symbol with an optional template argument. */
  def symbol = positioned {
    sym ~ template.? ^^ {
      case s ~ temp => Expr.Symbol(s, temp)
    }
  }

  def bool: Parser[Expr] = positioned {
    TRUE ^^ (_ => Expr.Bool(true)) |
      FALSE ^^ (_ => Expr.Bool(false))
  }

  /** Free function call `id<T>?(args)`. */
  def call: Parser[Expr.Call] = positioned {
    id ~ template.? ~ parSround(repsep(termExpr, COMMA)) ^^ {
      case id ~ temp ~ args => Expr.Call(id, args, temp)
    }
  }

  /** `@name`: field access on `Expr.Self`. */
  def selfField: Parser[Expr.Field] = positioned {
    (AT ~> id) ^^ (id => Expr.Field(Expr.Self, id))
  }

  // todo: what should happen if memField is callable?
  /** `obj.name` that is NOT followed by `(` (that case is `memCall`). */
  def memField: Parser[Expr.Field] = positioned {
    log(log(termExpr)("fieldTerm") <~ DOT)("MemberLog") ~ log(id)("FieldLog") <~ not(guard(LEFT_PARENT_THESES)) ^^ {
      case obj ~ name => Expr.Field(obj, name)
    }
  }

  /** Method call `obj.name(args)`. */
  def memCall: Parser[Expr.MethodCall] = positioned {
    (termExpr <~ DOT) ~ id ~ parSround(repsep(log(termExpr)("memCallArgs"), COMMA)) ^^ {
      case obj ~ id ~ args => Expr.MethodCall(obj, id, args)
    }
  }

  /** Indexing `expr[index]`. */
  def arrayIndex: Parser[Expr.Index] = positioned {
    termExpr ~ squareSround(termExpr) ^^ {
      case expr ~ index => Expr.Index(expr, index)
    }
  }

  /** Sequence of statements; blank lines parse as `Empty` and are filtered out. */
  def block: Parser[Block] = positioned {
    rep(log(statement | none)("stmt")) ^^ (stmts => {
      Block(stmts.filter(_ != Empty).map(_.asInstanceOf[Stmt]))
    })
  }

  /**
   * `if cond ... (elsif cond ...)* (else ...)? end`.
   * The `elsif` chain is folded from the right so each one becomes the false
   * branch of the preceding `If`, with the optional `else` block innermost.
   */
  def multiLineIf: Parser[If] = positioned {
    log(oneline(IF ~> log(expr)("if_cond")))("if") ~ block ~ log(elsif.*)("elsif") ~ (oneline(ELSE) ~> log(block)("else block")).? <~ log(END)("end") ^^ {
      case cond ~ if_branch ~ elsif ~ else_branch
      => If(cond, if_branch, elsif.foldRight(else_branch.asInstanceOf[Option[Expr]])(
        (next, acc) => Some(If(next.cond, next.true_branch, acc))))
    }
  }

  /** One `elsif cond` line plus its block; the false branch is filled in by `multiLineIf`. */
  def elsif: Parser[If] = positioned {
    oneline(ELSIF ~> termExpr) ~ block ^^ {
      case cond ~ branch => If(cond, branch, None)
    }
  }

  /** One statement; alternatives are tried top to bottom, the whole thing wrapped in `oneline`. */
  def statement: Parser[Stmt] = positioned {
    oneline(assign
      | whileStmt
      | log(forStmt)("forStmt")
      | log(local)("local")
      | BREAK ^^^ Stmt.Break()
      | CONTINUE ^^^ Stmt.Continue()
      | log(expr)("exprStmt") ^^ Stmt.Expr)
  }

  /** A bare end-of-line, represented as `Empty`. */
  def none: Parser[ASTNode] = positioned {
    EOL ^^^ Empty
  }

  /** `var x = e` / `val x = e`; the declared type is always `TyInfo.Infer`. */
  def local: Parser[Stmt] = positioned {
    ((VAR | VAL) ~> id) ~ (EQL ~> termExpr) ^^ {
      case id ~ expr => Stmt.Local(id, TyInfo.Infer, expr)
    }
  }

  def ret: Parser[Return] = positioned {
    RETURN ~> termExpr ^^ Return
  }

  /** `x = e`. */
  def assign: Parser[Stmt.Assign] = positioned {
    (id <~ EQL) ~ log(termExpr)("stmt assign") ^^ {
      case id ~ expr => Stmt.Assign(id, expr)
    }
  }

  /** `while (cond) ... end`. */
  def whileStmt: Parser[Stmt.While] = positioned {
    oneline(WHILE ~> parSround(termExpr)) ~ block <~ log(END)("end while") ^^ {
      case cond ~ body => Stmt.While(cond, body)
    }
  }

  /** `for (local; cond; assign) ... end`. */
  def forStmt: Parser[Stmt.For] = positioned {
    oneline(FOR ~> parSround(log(local)("init") ~ SEMICOLON ~ log(expr)("cond") ~ SEMICOLON ~ log(assign)("update"))) ~ block <~ END ^^ {
      case init ~ _ ~ cond ~ _ ~ incr ~ body => Stmt.For(init, cond, incr, body)
    }
  }
}

/** Stand-alone entry point that parses a token sequence as a single statement. */
object RcExprParser extends ExprParser {
  def apply(tokens: Seq[Token]) : Either[RcParserError, Stmt] = {
    doParser(tokens, statement)
  }
}
/** Token-level tests for the `expr` parser: each case feeds tokens and compares the resulting AST. */
class ExprParserTest extends ExprParser with BaseParserTest {
  // Runs the `expr` parser and also returns the remaining input so callers can check it was consumed.
  def apply(tokens: Seq[Token]): Either[RcParserError, (Expr, Input)] = {
    doParserImpl(tokens, expr)
  }

  def expectSuccess(token: Token, expect: Expr): Unit = {
    expectSuccess(List(token), expect)
  }

  // Asserts that parsing succeeds, yields `expect`, and leaves no unread tokens.
  def expectSuccess(tokens: Seq[Token], expect: Expr): Unit = {
    apply(tokens) match {
      case Left(value) => assert(false, value.msg)
      case Right((ast, reader)) => {
        println(ast)
        println(expect)
        assert(ast == expect)
        assert(reader.atEnd, reader)
      }
    }
  }

  describe("number") {
    it ("succeed") {
      expectSuccess(NUMBER(3), Number(3))
    }
  }

  describe("identifier") {
    it("succeed") {
      expectSuccess(IDENTIFIER("foo"), Identifier("foo"))
    }
  }

  describe("const") {
    it("succeed") {
      expectSuccess(UPPER_IDENTIFIER("Foo"), Expr.Symbol("Foo"))
    }
  }

  describe("bool") {
    it("succeed") {
      expectSuccess(TRUE, Expr.Bool(true))
      expectSuccess(FALSE, Expr.Bool(false))
    }
  }

  describe("str") {
    it("succeed") {
      expectSuccess(STRING("str"), Str("str"))
    }
  }

  describe("call") {
    it("empty args") {
      expectSuccess(makeCall("foo", List()), Call("foo", List()))
    }

    it("multi args") {
      expectSuccess(makeCall("foo", List(NUMBER(1), NUMBER(2))), Call("foo", List(Number(1), Number(2))))
    }

    // foo(1, node.value): a member-field access used as a call argument
    it("withMemField") {
      expectSuccess(
        List(IDENTIFIER("foo"), LEFT_PARENT_THESES, NUMBER(1), COMMA, IDENTIFIER("node"), DOT, IDENTIFIER("value"), RIGHT_PARENT_THESES),
        Call("foo", List(Number(1), Field(Identifier("node"), Ident("value")))))
    }
  }

  describe("memField") {
    it("normalField") {
      expectSuccess(mkTkMemField("homura", "shield"), mkASTMemField("homura", "shield"))
    }

    // @homura parses as a field on Self
    it("selfField") {
      expectSuccess(List(AT, IDENTIFIER("homura")), Expr.Field(Self, "homura"))
    }
  }

  describe("memCall") {
    it("succeed") {
      expectSuccess(mkTkMemCall("homura", "shot"), mkASTMemCall("homura", "shot"))
    }

    // Foo.new()
    it("new") {
      expectSuccess(
        List(UPPER_IDENTIFIER("Foo"), DOT, IDENTIFIER("new"), LEFT_PARENT_THESES, RIGHT_PARENT_THESES),
        MethodCall(Symbol("Foo"), "new", List()))
    }

    // TreeNode.new(node, lhs, rhs)
    it("multi param") {
      expectSuccess(
        List(UPPER_IDENTIFIER("TreeNode"), DOT, IDENTIFIER("new"), LEFT_PARENT_THESES, IDENTIFIER("node"), COMMA, IDENTIFIER("lhs"), COMMA, IDENTIFIER("rhs"), RIGHT_PARENT_THESES),
        MethodCall(Symbol("TreeNode"), "new", List(Identifier("node"), Identifier("lhs"), Identifier("rhs")))
      )
    }

    // todo: fix this error
    // TreeNode.new(node.value, lhs, rhs)
    it("complex") {
      expectSuccess(
        List(UPPER_IDENTIFIER("TreeNode"), DOT, IDENTIFIER("new"), LEFT_PARENT_THESES, IDENTIFIER("node"), DOT, IDENTIFIER("value"), COMMA, IDENTIFIER("lhs"), COMMA, IDENTIFIER("rhs"), RIGHT_PARENT_THESES),
        MethodCall(Symbol("TreeNode"), "new", List(Field(Identifier("node"), Ident("value")), Identifier("lhs"), Identifier("rhs")))
      )
    }
  }

  describe("arrayIndex") {
    // a[1]
    it("normal number") {
      expectSuccess(
        List(IDENTIFIER("a"), LEFT_SQUARE, NUMBER(1), RIGHT_SQUARE),
        Index(Identifier("a"), Number(1)))
    }

    // a[1 + 2]
    it("index is termExpr") {
      expectSuccess(
        List(IDENTIFIER("a"), LEFT_SQUARE, NUMBER(1), OPERATOR("+"), NUMBER(2), RIGHT_SQUARE),
        Index(Identifier("a"), Binary(Add, Number(1), Number(2))))
    }
  }

  describe("binary") {
    it("single add") {
      expectSuccess(List(NUMBER(1), OPERATOR("+"), NUMBER(2)), Expr.Binary(Add, Number(1), Number(2)))
    }
  }

  describe("return") {
    it("succeed") {
      expectSuccess(List(RETURN, NUMBER(1)), Expr.Return(Expr.Number(1)))
    }
  }

  describe("if") {
    it("full succeed") {
      expectSuccess(
        makeIf(TRUE, NUMBER(1), makeElsif(List((FALSE, NUMBER(2)))), NUMBER(3)),
        makeIf(trueExpr, Number(1), makeLastIf(falseExpr, Number(2), Number(3))))
    }

    // extra EOL tokens after the condition must not change the result
    it("full with multi EOL") {
      expectSuccess(
        makeIf(List(TRUE, EOL, EOL), NUMBER(1), makeElsif(List((FALSE, NUMBER(2)))), NUMBER(3)),
        makeIf(trueExpr, Number(1), makeLastIf(falseExpr, Number(2), Number(3))))
    }

    it("no elsif") {
      expectSuccess(makeIf(TRUE, NUMBER(1), NUMBER(3)),
        makeLastIf(trueExpr, Number(1), Number(3)))
    }

    it("no else") {
      expectSuccess(makeIf(TRUE, NUMBER(1), makeElsif(List((FALSE, NUMBER(2))))),
        makeIf(trueExpr, Number(1), makeIf(falseExpr, Number(2), None)))
    }
  }

  describe("template") {
    // foo<Int>(1, 2)
    it("callFn") {
      val foo = List(IDENTIFIER("foo")) ::: wrapWithAngleBrackets(UPPER_IDENTIFIER("Int")) ::: mkTKArgs(List(NUMBER(1), NUMBER(2)))
      val fooAST = mkASTCall("foo", "Int", List(Number(1), Number(2)))
      expectSuccess(foo, fooAST)
    }

    // TreeNode<Int>.new(5)
    it("callClass") {
      val treeNode = List(UPPER_IDENTIFIER("TreeNode")) ::: wrapWithAngleBrackets(UPPER_IDENTIFIER("Int")) ::: List(DOT, IDENTIFIER("new")) ::: mkTKArgs(List(NUMBER(5)))
      val treeNodeAST = MethodCall(Symbol("TreeNode", Some(Ident("Int"))), Ident("new"), List(Number(5)))
      expectSuccess(treeNode, treeNodeAST)
    }
  }
}

/** Unit tests for the operator-folding helpers in `BinaryTranslator`, independent of the parser. */
class BinaryTranslatorTest extends BaseParserTest with BinaryTranslator {
  // Builds `a op b` as a flat operand/operator/operand list.
  def makeBinary(a: Int, op: String, b: Int) = List(Number(a), OPERATOR(op), Number(b))
  // Builds `a op1 b op2 c` as a flat list.
  def makeMultiBinary(a: Int, op1: String, b: Int, op2: String, c:Int) =
    List(Number(a), OPERATOR(op1), Number(b), OPERATOR(op2), Number(c))

  def oneBn = makeBinary(1, "+", 2)
  def twoAdd = makeMultiBinary(1, "+", 2, "+", 3)
  def addAndLT = makeMultiBinary(1, "+", 2, "<", 3)

  describe("findMaxInfixIndex") {
    it("only one op") {
      assert(findMaxInfixIndex(oneBn) == 1)
    }
    // equal operators: the leftmost one is chosen
    it("multi same infix op") {
      assert(findMaxInfixIndex(twoAdd) == 1)
    }
    // `<` (index 3) is chosen ahead of `+` (index 1)
    it("multi different op") {
      assert(findMaxInfixIndex(addAndLT) == 3)
    }
  }

  describe("replaceBinaryOp") {
    it("succeed") {
      assert(replaceBinaryOp(oneBn, 1) == List(Binary(Add, Number(1), Number(2))))
    }
  }

  describe("compose") {
    // pins the current folding order: 1 + 2 < 3 becomes Add(1, LT(2, 3))
    it("succeed") {
      assert(termsToBinary(addAndLT) == Binary(Add, Number(1), Binary(LT, Number(2), Number(3))))
    }
  }
}
/**
 * Ordered list of machine instructions belonging to one `MachineFunction`.
 *
 * On construction every instruction in `instList` gets this block as its parent,
 * and the block records `f` as parent and `bb` as origin.
 */
class MachineBasicBlock(var instList: List[MachineInstruction], f: MachineFunction, bb: BasicBlock, val name: String) extends InMF with MapOrigin[BasicBlock] with Src {
  instList.foreach(inst => inst.parent = this)
  parent = f
  origin = bb

  /** Appends `inst`, adopts it, and returns it. */
  def insert(inst: MachineInstruction) = {
    instList = instList.appended(inst)
    inst.parent = this
    inst
  }

  /** Prepends `inst` and adopts it. */
  def insertAtFirst(inst: MachineInstruction) = {
    instList = inst :: instList
    inst.parent = this
  }

  /**
   * Inserts `inst` so that it ends up at index `pos` (clamped by `take`/`drop`).
   * The instruction is adopted like in `insert`/`insertAtFirst`; previously its
   * `parent` was left untouched.
   */
  def insertAt(inst: MachineInstruction, pos: Int) = {
    instList = instList.take(pos) ++ List(inst) ++ instList.drop(pos)
    inst.parent = this
  }
}

/** Anything that can appear in an instruction's operand list. */
trait MachineOperand {
  /** Instruction currently holding this operand; `null` until attached. */
  var instParent: MachineInstruction = null.asInstanceOf[MachineInstruction]

  /**
   * Swaps this operand for `newOperand` inside `instParent` and detaches this one.
   *
   * Matching uses `==`, so every operand of the parent that is EQUAL to this one is
   * replaced, not only this instance.
   * NOTE(review): fails with a NullPointerException when `instParent` is still null.
   */
  def replaceFromParent(newOperand: MachineOperand) = {
    instParent.operands = instParent.operands.map(op => if op == this then newOperand else op)
    newOperand.instParent = instParent
    instParent = null
  }
}

/** Operand usable as a source. */
trait Src extends MachineOperand

/** Operand usable as a destination. */
trait Dst extends MachineOperand

/** Register identified by number, with a size (default 4). */
case class VReg(num: Int, size: Int = 4) extends Src with Dst {
  var force: Boolean = false
  /** Fresh instance with the same `num`/`size`; `force` and `instParent` are not copied. */
  def dup = VReg(num, size)
}

case class FrameIndex(offset: Int, size: Int = 4) extends Src with Dst

case class Imm(value: Int) extends Src

case class Label(name: String) extends Src

case class MemoryOperand(base: VReg, displacement: Option[Imm] = None, index: Option[MachineOperand] = None, scale: Option[Imm] = None) extends Src with Dst

/**
 * Base of all machine instructions. Operands live in the mutable `operands` list;
 * subclasses expose typed accessors that read and write fixed positions in it.
 */
sealed trait MachineInstruction extends InMBB with MapOrigin[Value] with Positional {
  var operands: List[MachineOperand] = List()

  /** Replaces operand `i` with `op` and attaches `op` to this instruction. */
  def setOperand(op: MachineOperand, i: Int) = {
    operands = operands.updated(i, op)
    op.instParent = this
  }

  /** Operand `i`, cast (unchecked) to the type expected by the caller. */
  def getOperand[T](i: Int) = operands(i).asInstanceOf[T]

  /** True when this instruction's operands contain the FIRST operand of `inst`. */
  def useIt(inst: MachineInstruction) = {
    // todo: fix this and add test
    operands.nonEmpty && inst.operands.nonEmpty && operands.contains(inst.operands.head)
  }

  /**
   * Removes exactly this instruction from its parent block.
   * Uses reference identity: instructions are case classes, so `!=` would also
   * drop other instructions that merely compare equal to this one.
   */
  def removeFromParent() = {
    parent.instList = parent.instList.filterNot(_ eq this)
  }

  /** Attaches every current operand to this instruction. */
  def initOperands = operands.foreach(op => op.instParent = this)
}

/** Operands: 0 = dst, 1 = addr. */
case class LoadInst(private var _dst: Dst, private var _addr: Src) extends MachineInstruction() {
  operands = List(_dst, _addr)
  initOperands

  def dst_=(newDst: Dst) = setOperand(newDst, 0)

  def dst: Dst = getOperand(0)

  def addr_=(newSrc: Src) = setOperand(newSrc, 1)

  def addr: Src = getOperand(1)
}

/** Extractor over the CURRENT operands (not the constructor arguments). */
object LoadInst {
  def unapply(inst: MachineInstruction): Option[(Dst, Src)] = {
    inst match
      case l: LoadInst => Some((l.dst, l.addr))
      case _ => None
  }
}

/** Operands: 0 = addr, 1 = src. */
case class StoreInst(private val _addr: Dst, private val _src: Src) extends MachineInstruction() {
  operands = List(_addr, _src)
  initOperands

  def addr_=(newDst: Dst) = setOperand(newDst, 0)

  def addr: Dst = getOperand(0)

  def src_=(newSrc: Src) = setOperand(newSrc, 1)

  def src: Src = getOperand(1)
}

/** Extractor over the current operands. */
object StoreInst {
  def unapply(inst: MachineInstruction): Option[(Dst, Src)] = {
    inst match
      case s: StoreInst => Some((s.addr, s.src))
      case _ => None
  }
}

/** Operands: parameters first, destination last. */
case class CallInst(targetFn: String, private val _dst: Dst, private val _params: List[Src]) extends MachineInstruction() {
  operands = _params :+ _dst
  initOperands

  def paramSize = operands.size - 1

  def params: List[Src] = operands.take(paramSize).map(_.asInstanceOf[Src])

  /**
   * Replaces the whole parameter list while keeping the destination last.
   * Rebuilding the list (instead of updating index by index) keeps `dst` intact
   * even when the number of parameters changes.
   */
  def params_=(newParams: List[Src]) = {
    val keptDst = dst
    operands = newParams :+ keptDst
    newParams.foreach(param => param.instParent = this)
  }

  def dst: Dst = getOperand(paramSize)

  def dst_=(newDst: Dst) = setOperand(newDst, paramSize)
}

/** Extractor over the current operands. */
object CallInst {
  def unapply(inst: MachineInstruction): Option[(String, Dst, List[Src])] = {
    inst match
      case c: CallInst => Some((c.targetFn, c.dst, c.params))
      case _ => None
  }
}

/** Operands: 0 = value. */
case class ReturnInst(private val _value: Src) extends MachineInstruction() {
  operands = List(_value)
  initOperands

  def value: Src = getOperand(0)

  def value_=(newV: Src) = setOperand(newV, 0)
}

/** Extractor over the current operands. */
object ReturnInst {
  def unapply(inst: MachineInstruction): Option[Src] = {
    inst match
      case r: ReturnInst => Some(r.value)
      case _ => None
  }
}

/** Operands: 0 = dst, 1 = lhs, 2 = rhs. */
case class BinaryInst(op: BinaryOperator, private val _dst: Dst, private val _lhs: Src, private val _rhs: Src) extends MachineInstruction() {
  operands = List(_dst, _lhs, _rhs)
  initOperands

  def dst_=(newDst: Dst) = setOperand(newDst, 0)

  def dst: Dst = getOperand(0)

  def lhs_=(newLhs: Src) = setOperand(newLhs, 1)

  def lhs: Src = getOperand(1)

  def rhs_=(newRhs: Src) = setOperand(newRhs, 2)

  def rhs: Src = getOperand(2)
}

/** Extractor over the current operands. */
object BinaryInst {
  def unapply(inst: MachineInstruction): Option[(BinaryOperator, Dst, Src, Src)] = {
    inst match
      case b: BinaryInst => Some((b.op, b.dst, b.lhs, b.rhs))
      case _ => None
  }
}

enum CondType:
  case LT
  case GT
  case EQ

/** Operands: 0 = cond, 1 = addr. */
case class CondBrInst(private val _cond: Src, private val _addr: Src, condType: CondType) extends MachineInstruction {
  operands = List(_cond, _addr)
  initOperands

  def cond: Src = getOperand(0)

  def cond_=(newCond: Src) = setOperand(newCond, 0)

  def addr: Src = getOperand(1)

  def addr_=(newAddr: Src) = setOperand(newAddr, 1)
}

/** Extractor over the current operands. */
object CondBrInst {
  def unapply(inst: MachineInstruction): Option[(Src, Src, CondType)] = {
    inst match
      case c: CondBrInst => Some((c.cond, c.addr, c.condType))
      case _ => None
  }
}

/** Operands: 0 = addr. */
case class BranchInst(private val _addr: Src) extends MachineInstruction {
  operands = List(_addr)
  initOperands

  def addr: Src = getOperand(0)

  def addr_=(newAddr: Src) = setOperand(newAddr, 0)
}

/** Extractor over the current operands. */
object BranchInst {
  def unapply(inst: MachineInstruction): Option[(Src)] = {
    inst match
      case b: BranchInst => Some(b.addr)
      case _ => None
  }
}

/**
 * Operands: 0 = dst. The incoming values are kept in `incomings` (not in `operands`),
 * but each incoming key is still attached to this instruction.
 */
case class PhiInst(private val _dst: Dst, var incomings: Map[Src, MachineBasicBlock]) extends MachineInstruction {
  operands = List(_dst)
  initOperands
  incomings.keysIterator.foreach(incoming => incoming.instParent = this)

  def dst: Dst = getOperand(0)

  def dst_=(newDst: Dst) = setOperand(newDst, 0)
}

/** Extractor over the current destination and incoming map. */
object PhiInst {
  def unapply(inst: MachineInstruction): Option[(Dst, Map[Src, MachineBasicBlock])] = {
    inst match
      case c: PhiInst => Some((c.dst, c.incomings))
      case _ => None
  }
}

/** Raw assembly text carried through as an instruction; it has no operands. */
case class InlineASM(str: String) extends MachineInstruction()

enum BinaryOperator:
  case Add
  case Sub
  case GT
  case LT
/** Emits GNU-assembler (AT&T syntax) x86-64 text for machine functions, blocks and instructions. */
class GNUASMEmiter extends ASMEmiter {
  /**
   * Function label, prologue (`pushq %rbp` / `movq %rsp, %rbp`), optional stack
   * reservation, spilling of incoming argument registers to their frame slots,
   * then the text of every basic block.
   */
  override def emitMF(mf: MachineFunction): MFText = {
    // impl convert
    val label = ASMLabel(s"${mf.name}:")
    val saveRBP = ASMInstr("pushq %rbp")
    val setRBP = ASMInstr("movq %rsp, %rbp")
    val allocStackFrame = if !mf.frameInfo.isEmpty then ASMInstr("subq $" + s"${mf.frameInfo.alignLength}, %rsp") else ASMInstr("")
    val initArg = mf.frameInfo.args.zipWithIndex.map((item, i) => {
      ASMInstr(s"${instStr("mov", item.len)} ${paramReg(i, item.len)}, ${-item.offset}(%rbp)")
    })
    val asm = List(label, saveRBP, setRBP, allocStackFrame) ::: initArg ::: mf.bbs.flatMap(emitMBB)
    MFText(asm, mf)
  }

  /** Local label `.<name>:` followed by the block's instructions. */
  override def emitMBB(mbb: MachineBasicBlock): List[ASMText] = {
    List(ASMLabel(s".${mbb.name}:")) ::: mbb.instList.flatMap(emitInstr)
  }

  /**
   * Lowers one machine instruction to assembly lines.
   *
   * Every line gets a trailing `# <pos>` comment taken from the instruction's
   * origin; when no origin was recorded (it is null by default) the comment is
   * omitted instead of failing. The early-return path for calls without a result
   * emits its lines without that comment, as before.
   */
  override def emitInstr(instr: MachineInstruction): List[ASMText] = {
    val list = instr match
      case BinaryInst(op, dst, lhs, rhs) => binaryInstToASM(op.toString, dst, lhs, rhs)
      // load store should not same in mem
      case LoadInst(target, value) => value match
        case Label(label) => List(s"leaq $label(%rip), ${operandToASM(target)}")
        case _ => List(s"${instStr("mov", value)} ${operandToASM(value)}, ${operandToASM(target)}")
      case StoreInst(target, value) => {
        val valueLen = ValueLen(value)
        // frame-to-frame moves and any memory operand go through register 0 as a temporary
        if ((target.isInstanceOf[FrameIndex] && value.isInstanceOf[FrameIndex]) || value.isInstanceOf[MemoryOperand] || target.isInstanceOf[MemoryOperand]) {
          val tmpReg = numToReg(0, valueLen)
          val movToTmp = s"${instStr("mov", valueLen)} ${operandToASM(value)}, $tmpReg"
          val store = s"${instStr("mov", valueLen)} $tmpReg, ${operandToASM(target)}"
          List(movToTmp, store)
        } else {
          List(s"${instStr("mov", valueLen)} ${operandToASM(value)}, ${operandToASM(target)}")
        }
      }
      // todo: value size
      // case ReturnInst(value) => List(s"${instStr("mov", value)} ${operandToASM(value)}, %eax", "popq %rbp", "ret")
      case ReturnInst(value) => {
        val originLen: Int = ValueLen(value)
        val valLen = if originLen == 0 then 8 else originLen
        List(s"${instStr("mov", valLen)} ${operandToASM(value)}, ${returnReg(valLen)}", "leave", "ret")
      }
      case CallInst(target, dst, args) => {
        // argument moves are emitted in reverse parameter order
        val argList = args.zipWithIndex.map((value, i) => value match
          case Label(label) => s"leaq ${label}(%rip), ${paramReg(i, 8)}"
          case _ => {
            val len = ValueLen(value)
            s"${instStr("mov", len)} ${operandToASM(value)}, ${paramReg(i, len)}"
          }
        ).reverse
        val call = s"call $target"
        val dstLen = ValueLen(dst)
        if(dstLen == 0) {
          // todo: for no return value, maybe other way
          return (argList ::: List(call)).map(ASMInstr)
        }
        val saveResult = s"${instStr("mov", dstLen)} ${getRegASM(dstLen, 0)}, ${operandToASM(dst)}"
        (argList ::: List(call, saveResult)).map(ASMInstr)
      }
      case InlineASM(content) => List(content)
      case BranchInst(label) => List(s"jmp .${operandToASM(label)}")
      case CondBrInst(cond, addr, condType) => List(s"${condJump(condType)} .${operandToASM(addr)}")
      case PhiInst(dst, _) => throw new Exception()
      case x => println(x.getClass.toString); ???
    val posNote = Option(instr.origin).fold("")(origin => s" # ${origin.pos}")
    list.map {
      case s: String => s
      case x: ASMInstr => x.instr
      case n => n
    }.map(s => ASMInstr(s"${s}${posNote}"))
  }

  /**
   * Jump mnemonic for a conditional branch.
   * LT and GT keep the previously generated `jle` / `jge`; EQ is `je`
   * (the old string building produced the non-existent mnemonic `jee`).
   */
  private def condJump(condType: CondType): String = condType match {
    case CondType.LT => "jle"
    case CondType.GT => "jge"
    case CondType.EQ => "je"
  }

  /** Size carried by a `VReg` or `FrameIndex`; every other operand counts as 4. */
  private def ValueLen(value: MachineOperand): Int = {
    value match
      // todo: len
      case VReg(_, size) => size
      case FrameIndex(_, size) => size
      case _ => 4
  }

  /**
   * AT&T-syntax text of one operand.
   *
   * Frame indices become negative offsets from `%rbp`. A `MemoryOperand` is only
   * supported as `displacement(base)`; index/scale or a missing displacement hit `???`.
   */
  def operandToASM(operand: MachineOperand, immWithPrefix: Boolean = true): String = {
    operand match
      case Imm(value) => {
        val prefix = if immWithPrefix then "$" else ""
        prefix + value.toString
      }
      case r: VReg => regToASM(r)
      case Label(name) => name
      case FrameIndex(index, _) => s"${-index}(%rbp)"
      case MemoryOperand(VReg(num, size), Some(Imm(displacement)), None, None) =>
        s"${displacement}(${numToReg(num, size)})"
      case _: MemoryOperand => ???
      case _ => ???
  }

  /**
   * Lowers a binary instruction to four lines: load lhs into register 1 (`%ebx`),
   * load rhs into register 0 (`%eax`), apply the operation to the pair, then move
   * register 1 into `dst`. `LT` and `GT` both lower to `cmp`.
   * All four lines use the 4-byte forms regardless of operand size.
   */
  def binaryInstToASM(op: String, dst: MachineOperand, lhs: MachineOperand, rhs: MachineOperand): List[ASMText] = {
    def toStr(op: String): List[ASMText] = {
      val lhsSize = 4
      val eax = getRegASM(lhsSize, 0) // register 0: holds rhs
      val ebx = getRegASM(lhsSize, 1) // register 1: holds lhs and receives the result
      val lhsMov = s"${instStr("mov", lhs)} ${operandToASM(lhs)}, $ebx" // ebx = lhs
      val rhsMov = s"${instStr("mov", rhs)} ${operandToASM(rhs)}, $eax" // eax = rhs
      val bn = s"${instStr(op, lhs)} $eax, $ebx" // op eax, ebx
      val mv = s"${instStr("mov", dst)} $ebx, ${operandToASM(dst)}" // dst = ebx
      List(lhsMov, rhsMov, bn, mv)
    }

    // cmp $3, %eax == 3 < eax
    op match
      case "Add" => toStr("add")
      case "Sub" => toStr("sub")
      case "LT" => toStr("cmp")
      case "GT" => toStr("cmp")
  }

  /**
   * Mnemonic with size suffix for `operand`.
   * NOTE(review): the operand is ignored and the 4-byte suffix is always used;
   * confirm before relying on this for 8-byte operands.
   */
  def instStr(inst: String, operand: MachineOperand): String = {
    inst + instTy(4)
  }

  /** Mnemonic with the size suffix for `len` bytes. */
  def instStr(inst: String, len: Int): String = {
    inst + instTy(len)
  }


  /** `l` for 4 bytes, `q` for every other size. */
  def instTy(size: Int): String = {
    if size == 4 then "l" else "q"
  }

  /** Register name for `num` at width `len`; widths other than 4 and 8 hit `???`. */
  def getRegASM(len: Int, num: Int) = {
    if (len == 4) {
      numToReg4(num)
    } else if (len == 8) {
      numToReg8(num)
    } else {
      ???
    }
  }

  def regToASM(reg: VReg): String = {
    numToReg(reg.num, reg.size)
  }

  /** Like `getRegASM`, but an unsupported size yields the literal text below instead of throwing. */
  def numToReg(num: Int, size: Int) = {
    if(size == 4) {
      numToReg4(num)
    } else if (size == 8) {
      numToReg8(num)
    } else {
      "error num To Reg"
    }
  }

  /** 32-bit register names; `%ebp`/`%esp` are deliberately skipped, unknown numbers give `%out`. */
  def numToReg4(num: Int): String = {
    val name = num match
      case 0 => "eax"
      case 1 => "ebx"
      case 2 => "ecx"
      case 3 => "edx"
      case 4 => "esi"
      case 5 => "edi"
      // case 6 => "ebp" // ebp
      // case 7 => "esp" // esp
      case 6 => "r8d"
      case 7 => "r9d"
      case 8 => "r10d"
      case 9 => "r11d"
      case 10 => "r12d"
      case 11 => "r13d"
      case 12 => "r14d"
      case 13 => "r15d"
      case _ => "out"
    "%" + name
  }

  /**
   * 64-bit register names; 99 is `%rip`, unknown numbers give `%out`.
   * NOTE(review): numbers 6 and up name different physical registers here than in
   * `numToReg4` (6 is `%rbp` here but `%r8d` there) — confirm whether that is intended.
   */
  def numToReg8(num: Int): String = {
    val name = num match
      case 0 => "rax"
      case 1 => "rbx"
      case 2 => "rcx"
      case 3 => "rdx"
      case 4 => "rsi"
      case 5 => "rdi"
      case 6 => "rbp"
      case 7 => "rsp"
      case 8 => "r8"
      case 9 => "r9"
      case 10 => "r10"
      case 11 => "r11"
      case 12 => "r12"
      case 13 => "r13"
      case 14 => "r14"
      case 15 => "r15"
      case 99 => "rip"
      case _ => "out"
    "%" + name
  }

  /** Return-value register for 4 or 8 bytes; other lengths yield a marker string. */
  def returnReg(len: Int) = {
    if (len == 4) {
      "%eax"
    } else if(len == 8) {
      "%rax"
    } else {
      s"errorLen${len}"
    }
  }

  // rdi, rsi, rdx, rcx, r8/r8d, r9/r9d
  /**
   * Register carrying integer argument number `num` (0-based).
   * `len == 4` selects the 32-bit names, any other length the 64-bit names.
   * All six argument registers are supported; any other index hits `???`.
   */
  def paramReg(num: Int, len: Int): String = {
    val names =
      if (len == 4) List("edi", "esi", "edx", "ecx", "r8d", "r9d")
      else List("rdi", "rsi", "rdx", "rcx", "r8", "r9")
    "%" + names.lift(num).getOrElse(???)
  }
}