├── .gitignore ├── README.md ├── bench └── src │ └── main │ └── scala │ ├── ComparisonBenchmark.scala │ ├── FFMBenchmark.scala │ └── InterpreterBenchmark.scala ├── build.sbt ├── example.in ├── hs_err_pid31620.log ├── hs_err_pid32029.log ├── project ├── build.properties └── plugins.sbt ├── run.sh └── src ├── main ├── java │ ├── Foo.java │ └── de │ │ └── szeiger │ │ └── interact │ │ ├── codegen │ │ └── LocalClassLoader.java │ │ └── stc2 │ │ └── MetaClass.java └── scala │ ├── Analyzer.scala │ ├── BitOps.scala │ ├── CleanEmbedded.scala │ ├── Colors.scala │ ├── Compiler.scala │ ├── CreateWiring.scala │ ├── Curry.scala │ ├── Debug.scala │ ├── ExecutionMetrics.scala │ ├── ExpandRules.scala │ ├── Global.scala │ ├── Inline.scala │ ├── Main.scala │ ├── NormalizeCondition.scala │ ├── Parser.scala │ ├── PlanRules.scala │ ├── Prepare.scala │ ├── ResolveEmbedded.scala │ ├── Runtime.scala │ ├── SymCounts.scala │ ├── ast │ ├── AST.scala │ ├── Symbols.scala │ └── Transform.scala │ ├── codegen │ ├── AbstractCodeGen.scala │ ├── BoxOps.scala │ ├── ClassWriter.scala │ ├── ParSupport.scala │ └── dsl │ │ ├── Acc.scala │ │ ├── DSL.scala │ │ ├── Desc.scala │ │ └── TypedDSL.scala │ ├── mt │ ├── CodeGen.scala │ ├── Interpreter.scala │ └── workers │ │ └── Workers.scala │ ├── offheap │ ├── Allocator.scala │ └── MemoryDebugger.scala │ ├── stc1 │ ├── CodeGen.scala │ ├── GenStaticReduce.scala │ ├── Interpreter.scala │ └── PTOps.scala │ ├── stc2 │ ├── CodeGen.scala │ ├── GenStaticReduce.scala │ ├── Interpreter.scala │ └── PTOps.scala │ └── sti │ └── Interpreter.scala └── test ├── resources ├── ack.check ├── ack.in ├── diverging.check ├── diverging.in ├── embedded.check ├── embedded.in ├── fib.check ├── fib.in ├── inlining.check ├── inlining.in ├── lists.check ├── lists.in ├── par-mult.check ├── par-mult.in ├── seq-def.check └── seq-def.in └── scala ├── BitOpsTest.scala ├── LongBitOpsTest.scala ├── MainTest.scala ├── TestUtils.scala └── WorkersTest.scala /.gitignore: 
-------------------------------------------------------------------------------- 1 | target 2 | .bsp 3 | .idea 4 | bench/gen-classes 5 | bench/gen-src 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Language, interpreter and compiler for interaction nets 2 | 3 | The language implements interaction nets and combinators as defined in https://core.ac.uk/download/pdf/81113716.pdf. 4 | 5 | ## Running 6 | 7 | Launch `sbt` and then `runMain Main example.in` in the sbt shell to run [example.in](./example.in) using a single-threaded interpreter. 8 | 9 | Launch `sbt` and then `runMain Debug example.in` in the sbt shell to run [example.in](./example.in) using the step-by-step debugger. 10 | 11 | ## Language 12 | 13 | There are 4 basic statements: `cons` defines a data constructor, `def` defines a function (with optional rules), `match` defines a detached rule, and `let` creates a part of the net that should be evaluated. Comments start with `#` and extend until the end of the line. Expression lists use Haskell-style significant indentation. 14 | 15 | ### Constructors 16 | 17 | Constructors are written as `Id(aux_1, aux_2, ..., aux_n) = principal`. The parentheses are optional for arity 0. The principal port is only used for documentation purposes and can always be omitted. For example: 18 | 19 | ``` 20 | # Natural numbers 21 | cons Z 22 | cons S(n) 23 | ``` 24 | 25 | ### Functions 26 | 27 | Semantically there is no difference between data and code, and it is possible to define and use functions with the normal constructor syntax and detached rules (or conversely use function syntax for data).
A separate syntax exists because it makes sense for data to have the principal port be the return value of the expression, but it is more useful for functions to have the principal port as a parameter: 28 | 29 | ``` 30 | def add(x, y) = r 31 | ``` 32 | 33 | The first argument of the function is always assigned to the principal port while the rest of the arguments and return value(s) make up the auxiliary ports. Note that return values must always be specified for functions because their arity is not fixed (like it is for constructors). Tuple syntax is used for multiple return values, e.g.: 34 | 35 | ``` 36 | def dup(x) = (a, b) 37 | ``` 38 | 39 | In order to define rules together with the function, the parameters and return values must be named. The first argument (i.e. the principal port) can always be omitted by using a `_` wildcard: 40 | 41 | ``` 42 | def add(_, y) = r 43 | ``` 44 | 45 | Other parameters can be omitted if the rules don't need them (because they use curried patterns or are defined as detached rules). 46 | 47 | ### Data 48 | 49 | The interaction nets that should be reduced by the interpreter are defined using `let` statements: 50 | 51 | ``` 52 | let two = S(S(Z)) 53 | y = S(S(S(S(S(Z))))) 54 | x = S(S(S(Z))) 55 | example_3_plus_5 = add(x, y) 56 | ``` 57 | 58 | The body of a `let` statement contains a block of expressions consisting of assignments and nested function / constructor applications. Tuples can be used to group individual values, but they cannot be nested. Variables that are used exactly once are defined as free wires. Additional temporary variables can be introduced, but they must be used exactly twice. The order of expressions and the direction of assignments are irrelevant. 59 | 60 | The syntax for expression blocks uses Haskell-style layout: Additional expressions must start at the same indentation level as the first one. Lines with larger indentation continue the current expression.
Multiple expressions on the same line must be separated by `;`. A trailing `;` at the end of the line is optional. 61 | 62 | ### Operators 63 | 64 | Both constructors and functions can be written as symbolic binary operators. An operator consists of an arbitrary combination of the symbols `*/%+-:<>&^|`. Precedence is based on the first character (the same as in Scala); operators ending with `:` are right-associative, all others are left-associative (again, like in Scala). All operators in a chain of same-precedence operations must have the same associativity. Operator definitions are written in infix syntax: 65 | 66 | ``` 67 | # Linked lists 68 | cons Nil 69 | cons head :: tail 70 | 71 | def _ + y = r 72 | ``` 73 | 74 | The same infix notation is used for applying them in expressions: 75 | 76 | ``` 77 | let example_3_plus_2 = S(S(S(Z))) + S(S(Z)) 78 | ``` 79 | 80 | ### Rules 81 | 82 | Reduction rules for functions can be specified together with the definition using a pattern matching syntax which matches on the first argument: 83 | 84 | ``` 85 | # Addition on natural numbers 86 | def _ + y = r 87 | | Z => y 88 | | S(x) => r = x + S(y) 89 | ``` 90 | 91 | The right-hand side contains an expression list, similar to a `let` clause. All function/constructor arguments and return values (except the principal port on each side of the match) must be used exactly once in the reduction. 92 | 93 | ``` 94 | def _ * y = r 95 | | Z => erase(y) 96 | Z = r 97 | | S(x) => (y1, y2) = dup(y) 98 | x * y1 + y2 = r 99 | ``` 100 | 101 | If the last expression in the block is missing an assignment, it is implicitly assigned to the return value of the function: 102 | 103 | ``` 104 | def _ + y = r 105 | | Z => y # same as y = r 106 | | S(x) => x + S(y) # same as x + S(y) = r 107 | ``` 108 | 109 | The standard `dup` and `erase` functions are pre-defined, and combinators with all user-defined constructors and functions are derived automatically.
The pre-defined functions are equivalent to the following syntax: 110 | 111 | ``` 112 | def erase(_) = () 113 | def dup[label l](x) = (x1, x2) 114 | ``` 115 | 116 | When matching on another function instead of a constructor, a `_` wildcard must be used to mark the first argument (i.e. the principal port) as the designated return value of an assignment expression. The wildcard always expands to the return value of the nearest enclosing assignment. For example: 117 | 118 | ``` 119 | def dup(_) = (a, b) 120 | | dup(_) = (c, d) => (c, d) 121 | ``` 122 | 123 | When matching on a function with no return value (like `erase`), an assignment to an empty tuple can be used to correctly expand the wildcard: 124 | 125 | ``` 126 | def erase(_) 127 | | erase(_) = () => () 128 | ``` 129 | 130 | ### Currying 131 | 132 | It is possible to use both nested patterns and additional matches on auxiliary ports to define curried functions, corresponding to currying on the left-hand side and right-hand side of a match. For example: 133 | 134 | ``` 135 | def fib(_) = r 136 | | Z => 1n 137 | | S(Z) => 1n 138 | | S(S(n)) => (n1, n2) = dup(n) 139 | fib(S(n1)) + fib(n2) 140 | ``` 141 | 142 | This expands to a definition similar to this one (modulo the generated name of the curried function): 143 | 144 | ``` 145 | def fib(_) = r 146 | | Z => 1n 147 | | S(n) => fib2(n) 148 | 149 | def fib2(_) = r 150 | | Z => 1n 151 | | S(n) => (n1, n2) = dup(n) 152 | fib(S(n1)) + fib(n2) 153 | ``` 154 | 155 | Matching on auxiliary ports is done by specifying a comma-separated list in a `def` rule. In the following example `b` is matched by `S(y)` in the second rule after successfully matching `a` with `S(x)`: 156 | 157 | ``` 158 | def foo(a, b) = r 159 | | Z => erase(b); Z 160 | | S(x), S(y) => x + y 161 | ``` 162 | 163 | Restrictions on curried rules: 164 | - All additional matches must be done on the same port of the original match.
165 | - Nested matches must not conflict with another match at the outer layer (e.g. you cannot match on both `f(S(x))` and `f(S(S(x)))`). 166 | 167 | ### Detached Rules 168 | 169 | A rule can be defined independently of a function definition using a `match` statement. These rules can also be defined for `cons`-style constructors (which do not have a special rule syntax like `def`). The expression on the left-hand side is interpreted as a pattern which must correspond to two cells connected via their principal ports. For example: 170 | 171 | ``` 172 | match add(Z, y) => y 173 | match add(S(x), y) => add(x, S(y)) 174 | ``` 175 | 176 | A combination of two constructors can be matched with an assignment, e.g.: 177 | 178 | ``` 179 | # Assimilation rules for S and dup 180 | match S(x) = S(y) => x = y 181 | match dup(dup(_) = (c, d)) = (a, b) => (c, d) 182 | ``` 183 | 184 | Currying works the same as in rules attached to a `def` statement. 185 | 186 | ### Natural Numbers 187 | 188 | There is syntactic support for parsing and printing natural numbers, e.g.: 189 | 190 | ``` 191 | let example_3_times_2 = mult(3n, 2n) 192 | ``` 193 | 194 | The snippet expands to: 195 | 196 | ``` 197 | let example_3_times_2 = mult(S(S(S(Z))), S(S(Z))) 198 | ``` 199 | 200 | This assumes that you have suitable definitions of `Z` and `S` like: 201 | 202 | ``` 203 | cons Z 204 | cons S(n) 205 | ``` 206 | 207 | ### Embedded Values 208 | 209 | A cell (`cons` or `def`) can optionally contain a primitive JVM value of type `int`, `ref` (`java.lang.Object`) or `label`. The type is placed in square brackets after the constructor name: 210 | 211 | ``` 212 | cons Int[int] 213 | cons String[ref] 214 | ``` 215 | 216 | Any match on such a constructor or its use in an expression must associate a variable name with the embedded value. These embedded variables share the same scope as regular variables but cannot be used interchangeably. 
If the same variable is used in a match and in the expansion, the value is automatically moved: 217 | 218 | ``` 219 | def _ + y = r 220 | | Int[i] => intAdd[i](y) 221 | ``` 222 | 223 | `int` and `label` values can also be copied or deleted implicitly by using the variable assigned to them in the match more or less than once in the expansion. `ref` values can only be moved implicitly. 224 | 225 | A static JVM method (or a method in a Scala object) can be invoked to perform a computation on embedded values by calling the method with its fully qualified name in an embedded expression in square brackets: 226 | ``` 227 | def intAdd[int a](_) = r 228 | | Int[b] => [de.szeiger.interact.Runtime.add(a, b, c)] 229 | Int[c] 230 | ``` 231 | 232 | Input parameters (corresponding to variables used in the match) must be of type `int` or a reference type, output parameters (corresponding to variables used in the expansion) must be of type `IntOutput` or `RefOutput`: 233 | 234 | ``` 235 | // Implementation in Scala: 236 | object Runtime { 237 | def add(a: Int, b: Int, res: IntOutput): Unit = 238 | res.setValue(a + b) 239 | } 240 | ``` 241 | 242 | It is up to the implementation of such a method to handle copying and deleting in an appropriate way. Since embedded `ref` values can also be copied and deleted by the `dup` and `erase` functions, the values may implement the `LifecycleManaged` interface to implement these methods. Otherwise, references will be shared or dropped from scope as usual on the JVM. 
243 | 244 | Values of type `label` can be created implicitly by using the same variable in one or more cells without an expression to create it: 245 | 246 | ``` 247 | let (x1, x2) = dup[label1](x) # dup(x) and dup(y) get the same label 248 | (y1, y2) = dup[label1](y) 249 | (z1, z2) = dup[label2](z) # dup(z) gets a different label 250 | (a1, a2) = dup(a) # dup(a) gets the default label 251 | ``` 252 | 253 | All labels created in this way are unique per rule reduction or `let` statement. 254 | 255 | An implementation method may directly return a value instead of using an `IntOutput` / `RefOutput` output parameter. Such a method can be used directly as an embedded computation of a cell: 256 | 257 | ``` 258 | def strlen(_) = r 259 | | String[s] => Int[de.szeiger.interact.Runtime.strlen(s)] 260 | ``` 261 | 262 | Int literals and some basic operators are available in embedded computations, e.g.: 263 | 264 | ``` 265 | let x = Int[3 + 5] 266 | ``` 267 | 268 | Embedded values can be checked in a match to select different expansions for a rule depending on the embedded values: 269 | 270 | ``` 271 | def fib(_) = r 272 | | Int[i] if [i == 0] => Int[1] 273 | if [i == 1] => Int[1] 274 | else => fib(Int[i-1]) + fib(Int[i-2]) 275 | ``` 276 | 277 | The `else` clause is required, all branches must be defined together, and the order of definition is significant. 278 | 279 | Current implementation limitations: 280 | - Currying is not allowed when both sides of the match contain embedded values.
281 | -------------------------------------------------------------------------------- /bench/src/main/scala/ComparisonBenchmark.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import org.openjdk.jmh.annotations._ 4 | import org.openjdk.jmh.infra._ 5 | 6 | import java.util.concurrent.TimeUnit 7 | 8 | /** Compares reduction of interaction-net programs, using the interpreter selected by `spec`, against plain Scala implementations of the same computations (the `*Scala` benchmarks). */ @BenchmarkMode(Array(Mode.Throughput)) 9 | @Fork(value = 1, jvmArgsAppend = Array("-Xmx12g", "-Xss32M", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseZGC")) 10 | @Threads(1) 11 | @Warmup(iterations = 15, time = 1) 12 | @Measurement(iterations = 15, time = 1) 13 | @OutputTimeUnit(TimeUnit.SECONDS) 14 | @State(Scope.Benchmark) 15 | class ComparisonBenchmark { 16 | 17 | @Param(Array("stc2")) 18 | var spec: String = _ 19 | 20 | private val prelude = 21 | """cons Z 22 | |cons S(n) 23 | |def add(_, y) = r 24 | | | Z => y 25 | | | S(x) => add(x, S(y)) 26 | |""".stripMargin 27 | 28 | private val mult1Src = prelude + 29 | """def mult(_, y) = r 30 | | | Z => erase(y); Z 31 | | | S(x) => (y1, y2) = dup(y); add(mult(x, y1), y2) 32 | |let res = mult(100n, 100n) 33 | |""".stripMargin 34 | 35 | private val fib22Src = prelude + 36 | """def add2(_, y) = r 37 | | | Z => y 38 | | | S(x) => S(add2(x, y)) 39 | |def fib(_) = r 40 | | | Z => 1n 41 | | | S(n) => fib2(n) 42 | |def fib2(_) = r 43 | | | Z => 1n 44 | | | S(n) => (n1, n2) = dup(n); add2(fib(S(n1)), fib(n2)) 45 | |let res = fib(22n) 46 | |""".stripMargin 47 | 48 | /* Ackermann(3, 9), measured by the `intAck39` benchmark (baseline: `intAck39Scala`). */ private val intAck39Src = 49 | """cons Int[int] 50 | | 51 | |def ackU(a, b) = r 52 | | | Int[x], Int[y] 53 | | if [x == 0] => Int[y + 1] 54 | | if [y == 0] => ackU(Int[x - 1], Int[1]) 55 | | else => ackU(Int[x - 1], ackU(Int[x], Int[y - 1])) 56 | | 57 | |let resU = ackU(Int[3], Int[9]) 58 | |""".stripMargin 59 | 60 | private val intMult3Src = 61 | """cons Int[int] 62 | |def add(_, _) = r 63 | | | Int[x], Int[y] if [x == 0] => Int[y] 64 | | else => add(Int[x-1], Int[y+1]) 65 | |def mult(_, _) =
r 66 | | | Int[x], Int[y] if [x == 0] => Int[0] 67 | | else => add(Int[y], mult(Int[x-1], Int[y])) 68 | |let res = mult(Int[1000], Int[1000]) 69 | |""".stripMargin 70 | 71 | /** Compiles `source` once for the selected `spec`. `setup()` calls `initData()` before every reduction, so data initialization is included in the measured time. NOTE(review): unlike InterpreterBenchmark, the interpreter is never `dispose()`d between runs; confirm that repeated `initData()` calls do not leak for off-heap specs. */ class PreparedInterpreter(source: String) { 72 | val model: Compiler = new Compiler(Parser.parse(source), Config(spec)) 73 | val inter = model.createInterpreter() 74 | def setup(): BaseInterpreter = { 75 | inter.initData() 76 | inter 77 | } 78 | } 79 | 80 | /* `lazy` defers compilation until first use, after JMH has injected the `spec` @Param (a strict val would see `spec == null` during construction). */ private lazy val mult1Inter: PreparedInterpreter = new PreparedInterpreter(mult1Src) 81 | private lazy val fib22Inter: PreparedInterpreter = new PreparedInterpreter(fib22Src) 82 | private lazy val intAck39Inter: PreparedInterpreter = new PreparedInterpreter(intAck39Src) 83 | private lazy val intMult3Inter: PreparedInterpreter = new PreparedInterpreter(intMult3Src) 84 | 85 | @Benchmark 86 | def mult1(bh: Blackhole): Unit = 87 | bh.consume(mult1Inter.setup().reduce()) 88 | 89 | @Benchmark 90 | def fib22(bh: Blackhole): Unit = 91 | bh.consume(fib22Inter.setup().reduce()) 92 | 93 | @Benchmark 94 | def intAck39(bh: Blackhole): Unit = 95 | bh.consume(intAck39Inter.setup().reduce()) 96 | 97 | @Benchmark 98 | def intMult3(bh: Blackhole): Unit = 99 | bh.consume(intMult3Inter.setup().reduce()) 100 | 101 | /* The `*Scala` benchmarks below are plain-Scala baselines computing the same results as the interaction-net benchmarks above. */ @Benchmark 102 | def mult1Scala(bh: Blackhole): Unit = { 103 | sealed abstract class Nat 104 | case object Z extends Nat 105 | case class S(pred: Nat) extends Nat 106 | def nat(i: Int): Nat = i match { 107 | case 0 => Z 108 | case n => S(nat(n-1)) 109 | } 110 | def add(x: Nat, y: Nat): Nat = x match { 111 | case Z => y 112 | case S(x) => add(x, S(y)) 113 | } 114 | def mult(x: Nat, y: Nat): Nat = x match { 115 | case Z => Z 116 | case S(x) => add(mult(x, y), y) 117 | } 118 | bh.consume(mult(nat(100), nat(100))) 119 | } 120 | 121 | @Benchmark 122 | def fib22Scala(bh: Blackhole): Unit = { 123 | sealed abstract class Nat 124 | case object Z extends Nat 125 | case class S(pred: Nat) extends Nat 126 | def nat(i: Int): Nat = i match { 127 | case 0 => Z 128 |
| case n => S(nat(n-1)) 129 | } 130 | def add2(x: Nat, y: Nat): Nat = x match { 131 | case Z => y 132 | case S(x) => S(add2(x, y)) 133 | } 134 | def fib(x: Nat): Nat = x match { 135 | case Z => nat(1) 136 | case S(n) => fib2(n) 137 | } 138 | def fib2(x: Nat): Nat = x match { 139 | case Z => nat(1) 140 | case S(n) => add2(fib(S(n)), fib(n)) 141 | } 142 | bh.consume(fib(nat(22))) 143 | } 144 | 145 | @Benchmark 146 | def intAck39Scala(bh: Blackhole): Unit = { 147 | def ack(x: Int, y: Int): Int = 148 | if(x == 0) y + 1 149 | else if(y == 0) ack(x-1, 1) 150 | else ack(x-1, ack(x, y-1)) 151 | bh.consume(ack(3, 9)) 152 | } 153 | 154 | @Benchmark 155 | def intMult3Scala(bh: Blackhole): Unit = { 156 | def add(x: Int, y: Int): Int = 157 | if(x == 0) y 158 | else add(x-1, y+1) 159 | def mult(x: Int, y: Int): Int = 160 | if(x == 0) 0 161 | else add(y, mult(x-1, y)) 162 | bh.consume(mult(1000, 1000)) 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /bench/src/main/scala/FFMBenchmark.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.offheap.{Allocator, ArenaAllocator} 4 | import org.openjdk.jmh.annotations._ 5 | import org.openjdk.jmh.infra._ 6 | 7 | //import java.lang.foreign.MemoryLayout.PathElement 8 | import java.util.concurrent.TimeUnit 9 | //import java.lang.foreign._ 10 | import java.nio.ByteBuffer 11 | 12 | @BenchmarkMode(Array(Mode.AverageTime)) 13 | @Fork(value = 1, jvmArgsAppend = Array("-Xmx12g", "-Xss32M", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseZGC", "--enable-native-access=ALL-UNNAMED")) 14 | @Threads(1) 15 | @Warmup(iterations = 10, time = 1) 16 | @Measurement(iterations = 10, time = 1) 17 | @OutputTimeUnit(TimeUnit.MILLISECONDS) 18 | @State(Scope.Benchmark) 19 | class FFMBenchmark { 20 | val u = { 21 | val f = classOf[sun.misc.Unsafe].getDeclaredField("theUnsafe") 22 | f.setAccessible(true) 23 |
f.get(null).asInstanceOf[sun.misc.Unsafe] 24 | } 25 | 26 | /* Runs each enabled variant once while JMH constructs this state object and prints its result; all variants build the same linked structure and compute the same sum, so the printed values should agree. */ println() 27 | println("heap: " + heap0()) 28 | // println("heapByteBuffer: " + heapByteBuffer0()) 29 | //println("autoArena: " + autoArena0()) 30 | // println("customAllocator: " + customAllocator0()) 31 | // println("customAllocatorDirect: " + customAllocatorDirect0()) 32 | //println("ffmByteBuffer: " + ffmByteBuffer0()) 33 | //println("unsafe: " + unsafe0()) 34 | println("unsafeCustom: " + unsafeCustom0()) 35 | println("unsafeAllocator: " + unsafeAllocator0()) 36 | 37 | @Benchmark 38 | def heap(bh: Blackhole): Unit = bh.consume(heap0()) 39 | 40 | // @Benchmark 41 | // def heapByteBuffer(bh: Blackhole): Unit = bh.consume(heapByteBuffer0()) 42 | 43 | //@Benchmark 44 | //def autoArena(bh: Blackhole): Unit = bh.consume(autoArena0()) 45 | 46 | // @Benchmark 47 | // def customAllocator(bh: Blackhole): Unit = bh.consume(customAllocator0()) 48 | // 49 | // @Benchmark 50 | // def customAllocatorDirect(bh: Blackhole): Unit = bh.consume(customAllocatorDirect0()) 51 | 52 | //@Benchmark 53 | //def ffmByteBuffer(bh: Blackhole): Unit = bh.consume(ffmByteBuffer0()) 54 | 55 | //@Benchmark 56 | //def unsafe(bh: Blackhole): Unit = bh.consume(unsafe0()) 57 | 58 | @Benchmark 59 | def unsafeCustom(bh: Blackhole): Unit = bh.consume(unsafeCustom0()) 60 | 61 | @Benchmark 62 | def unsafeAllocator(bh: Blackhole): Unit = bh.consume(unsafeAllocator0()) 63 | 64 | /* Builds 10M cells whose c0/c1 fields point to the previous/next cell and whose interior cells store their index in p1, then sums p1 minus the previous cell's p1 over the interior cells. The other *0 methods run the same workload on off-heap memory with the layout c0 @ +0, c1 @ +8, p0 @ +16, p1 @ +20 (24 bytes per cell). */ def heap0(): Long = { 65 | final class C(var c0: C, var p0: Int, var c1: C, var p1: Int) 66 | val buf = new Array[C](10000000) 67 | for(i <- buf.indices) buf(i) = new C(null, 0, null, 0) 68 | for(i <- 1 until buf.length-1) { 69 | buf(i).c0 = buf(i - 1) 70 | buf(i).p0 = -1 71 | buf(i).c1 = buf(i + 1) 72 | buf(i).p1 = i 73 | } 74 | var sum = 0L 75 | for(i <- 1 until buf.length-1) { 76 | sum += buf(i).p1 77 | sum -= buf(i).c0.p1 78 | } 79 | sum 80 | } 81 | 82 | /* NOTE(review): despite the name, this uses `ByteBuffer.allocateDirect`, i.e. an off-heap buffer; `ByteBuffer.allocate` would be heap-backed. Confirm which is intended. */ def heapByteBuffer0(): Long = { 83 | val buf = new Array[Int](10000000) 84 | val bb =
ByteBuffer.allocateDirect(buf.length * 24) 85 | var next = 0 86 | def alloc(len: Int): Int = { 87 | val a = next 88 | next += len 89 | a 90 | } 91 | for(i <- buf.indices) buf(i) = alloc(24) 92 | for(i <- 1 until buf.length-1) { 93 | bb.putLong(buf(i), buf(i-1)) 94 | bb.putLong(buf(i)+8, buf(i+1)) 95 | bb.putInt(buf(i)+16, -1) 96 | bb.putInt(buf(i)+20, i) 97 | } 98 | var sum = 0L 99 | for(i <- 1 until buf.length-1) { 100 | sum += bb.getInt(buf(i)+20) 101 | sum -= bb.getInt(bb.getLong(buf(i)).toInt+20) 102 | } 103 | sum 104 | } 105 | 106 | // def autoArena0(): Long = { 107 | // import LayoutGlobals._ 108 | // val arena = Arena.ofAuto() 109 | // val buf = new Array[MemorySegment](10000000) 110 | // for(i <- buf.indices) buf(i) = arena.allocate(layout) 111 | // for(i <- 1 until buf.length-1) { 112 | // c0vh.set(buf(i), 0, buf(i-1)) 113 | // c1vh.set(buf(i), 0, buf(i+1)) 114 | // p0vh.set(buf(i), 0, -1) 115 | // p1vh.set(buf(i), 0, i) 116 | // } 117 | // var sum = 0L 118 | // for(i <- 1 until buf.length-1) { 119 | // sum += p1vh.get(buf(i), 0).asInstanceOf[Int] 120 | // sum -= p1vh.get(c0vh.get(buf(i), 0).asInstanceOf[MemorySegment].reinterpret(layout.byteSize()), 0).asInstanceOf[Int] 121 | // } 122 | // sum 123 | // } 124 | // 125 | // def customAllocator0(): Long = { 126 | // import LayoutGlobals2S._ 127 | // val arena = Arena.ofAuto() 128 | // val buf = new Array[Long](10000000) 129 | // val _block = arena.allocate(buf.length * 24) 130 | // val root = MemorySegment.NULL.reinterpret(Long.MaxValue) 131 | // var next = _block.address() 132 | // def alloc(len: Int): Long = { 133 | // val a = next 134 | // next += len 135 | // a 136 | // } 137 | // for(i <- buf.indices) buf(i) = alloc(24) 138 | // for(i <- 1 until buf.length-1) { 139 | // c0vh.set(root, buf(i), buf(i-1)) 140 | // c1vh.set(root, buf(i), buf(i+1)) 141 | // p0vh.set(root, buf(i), -1) 142 | // p1vh.set(root, buf(i), i) 143 | // } 144 | // var sum = 0L 145 | // for(i <- 1 until buf.length-1) { 146 | // sum +=
p1vh.get(root, buf(i)).asInstanceOf[Int] 147 | // sum -= p1vh.get(root, c0vh.get(root, buf(i)).asInstanceOf[Long]).asInstanceOf[Int] 148 | // } 149 | // sum 150 | // } 151 | // 152 | // def customAllocatorDirect0(): Long = { 153 | // val arena = Arena.ofAuto() 154 | // val buf = new Array[Long](10000000) 155 | // val _block = arena.allocate(buf.length * 24) 156 | // val off = _block.address() 157 | // val root = MemorySegment.NULL.reinterpret(Long.MaxValue) 158 | // var next = off 159 | // def alloc(len: Int): Long = { 160 | // val a = next 161 | // next += len 162 | // a 163 | // } 164 | // for(i <- buf.indices) buf(i) = alloc(24) 165 | // for(i <- 1 until buf.length-1) { 166 | // root.set(ValueLayout.JAVA_LONG, buf(i), buf(i-1)) 167 | // root.set(ValueLayout.JAVA_LONG, buf(i)+8, buf(i+1)) 168 | // root.set(ValueLayout.JAVA_INT, buf(i)+16, -1) 169 | // root.set(ValueLayout.JAVA_INT, buf(i)+20, i) 170 | // } 171 | // var sum = 0L 172 | // for(i <- 1 until buf.length-1) { 173 | // sum += root.get(ValueLayout.JAVA_INT, buf(i)+20) 174 | // sum -= root.get(ValueLayout.JAVA_INT, root.get(ValueLayout.JAVA_LONG, buf(i))+20) 175 | // } 176 | // sum 177 | // } 178 | // 179 | // def ffmByteBuffer0(): Long = { 180 | // val arena = Arena.ofAuto() 181 | // val buf = new Array[Int](10000000) 182 | // val bb = arena.allocate(buf.length * 24).asByteBuffer() 183 | // var next = 0 184 | // def alloc(len: Int): Int = { 185 | // val a = next 186 | // next += len 187 | // a 188 | // } 189 | // for(i <- buf.indices) buf(i) = alloc(24) 190 | // for(i <- 1 until buf.length-1) { 191 | // bb.putLong(buf(i), buf(i-1)) 192 | // bb.putLong(buf(i)+8, buf(i+1)) 193 | // bb.putInt(buf(i)+16, -1) 194 | // bb.putInt(buf(i)+20, i) 195 | // } 196 | // var sum = 0L 197 | // for(i <- 1 until buf.length-1) { 198 | // sum += bb.getInt(buf(i)+20) 199 | // sum -= bb.getInt(bb.getLong(buf(i)).toInt+20) 200 | // } 201 | // sum 202 | // } 203 | 204 | def unsafe0(): Long = { 205 | val buf = new
Array[Long](10000000) 206 | for(i <- buf.indices) buf(i) = u.allocateMemory(24) 207 | /* allocateMemory returns uninitialized memory, and cell 0's p1 is read (through cell 1's c0) but never otherwise written; zero it so the sum is deterministic. */ u.putInt(buf(0)+20, 0); for(i <- 1 until buf.length-1) { 208 | u.putLong(buf(i), buf(i-1)) 209 | u.putLong(buf(i)+8, buf(i+1)) 210 | u.putInt(buf(i)+16, -1) 211 | u.putInt(buf(i)+20, i) 212 | } 213 | var sum = 0L 214 | for(i <- 1 until buf.length-1) { 215 | sum += u.getInt(buf(i)+20) 216 | sum -= u.getInt(u.getAddress(buf(i))+20) 217 | } 218 | for(i <- buf.indices) u.freeMemory(buf(i)) 219 | sum 220 | } 221 | 222 | def unsafeCustom0(): Long = { 223 | val buf = new Array[Long](10000000) 224 | val block = u.allocateMemory(buf.length * 24) 225 | var next = block 226 | def alloc(len: Int): Long = { 227 | val a = next 228 | next += len 229 | a 230 | } 231 | for(i <- buf.indices) buf(i) = alloc(24) 232 | /* The block from allocateMemory is uninitialized, and cell 0's p1 is read but never otherwise written; zero it so the sum is deterministic. */ u.putInt(buf(0)+20, 0); for(i <- 1 until buf.length-1) { 233 | u.putLong(buf(i), buf(i-1)) 234 | u.putLong(buf(i)+8, buf(i+1)) 235 | u.putInt(buf(i)+16, -1) 236 | u.putInt(buf(i)+20, i) 237 | } 238 | var sum = 0L 239 | for(i <- 1 until buf.length-1) { 240 | sum += u.getInt(buf(i)+20) 241 | sum -= u.getInt(u.getAddress(buf(i))+20) 242 | } 243 | u.freeMemory(block) 244 | sum 245 | } 246 | 247 | def unsafeAllocator0(): Long = { 248 | val buf = new Array[Long](10000000) 249 | val a = new ArenaAllocator() 250 | for(i <- buf.indices) buf(i) = a.alloc(24) 251 | /* Cell 0's p1 is read but never otherwise written; set it explicitly rather than relying on ArenaAllocator zeroing its memory (not visible here). */ Allocator.putInt(buf(0)+20, 0); for(i <- 1 until buf.length-1) { 252 | Allocator.putLong(buf(i), buf(i-1)) 253 | Allocator.putLong(buf(i)+8, buf(i+1)) 254 | Allocator.putInt(buf(i)+16, -1) 255 | Allocator.putInt(buf(i)+20, i) 256 | } 257 | var sum = 0L 258 | for(i <- 1 until buf.length-1) { 259 | sum += Allocator.getInt(buf(i)+20) 260 | sum -= Allocator.getInt(Allocator.getLong(buf(i))+20) 261 | } 262 | a.dispose() 263 | sum 264 | } 265 | } 266 | 267 | //object LayoutGlobals { 268 | // val layout = MemoryLayout.structLayout( 269 | // ValueLayout.ADDRESS.withName("c0"), 270 | // ValueLayout.ADDRESS.withName("c1"), 271 | // ValueLayout.JAVA_INT.withName("p0"), 272 | // ValueLayout.JAVA_INT.withName("p1"), 273 | // )
274 | // val c0vh = layout.varHandle(PathElement.groupElement("c0")) 275 | // val p0vh = layout.varHandle(PathElement.groupElement("p0")) 276 | // val c1vh = layout.varHandle(PathElement.groupElement("c1")) 277 | // val p1vh = layout.varHandle(PathElement.groupElement("p1")) 278 | //} 279 | // 280 | //object LayoutGlobals2S { 281 | // val layout = MemoryLayout.structLayout( 282 | // ValueLayout.JAVA_LONG.withName("c0"), 283 | // ValueLayout.JAVA_LONG.withName("c1"), 284 | // ValueLayout.JAVA_INT.withName("p0"), 285 | // ValueLayout.JAVA_INT.withName("p1"), 286 | // ) 287 | // val c0vh = layout.varHandle(PathElement.groupElement("c0")) 288 | // val p0vh = layout.varHandle(PathElement.groupElement("p0")) 289 | // val c1vh = layout.varHandle(PathElement.groupElement("c1")) 290 | // val p1vh = layout.varHandle(PathElement.groupElement("p1")) 291 | //} 292 | -------------------------------------------------------------------------------- /bench/src/main/scala/InterpreterBenchmark.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import org.openjdk.jmh.annotations._ 4 | import org.openjdk.jmh.infra._ 5 | import org.openjdk.jmh.runner.{Defaults, Runner} 6 | import org.openjdk.jmh.runner.format.OutputFormatFactory 7 | import org.openjdk.jmh.runner.options.CommandLineOptions 8 | import org.openjdk.jmh.util.Optional 9 | 10 | import java.nio.file.Path 11 | import scala.jdk.CollectionConverters._ 12 | import java.util.concurrent.TimeUnit 13 | import scala.util.control.NonFatal 14 | 15 | // bench/jmh:runMain de.szeiger.interact.InterpreterBenchmark 16 | 17 | /** Measures `reduce()` on the interpreter returned by `InterpreterBenchmark.setup(spec, benchmark)`. `benchmark` presumably names one of the `*Src` programs in the companion object; confirm in `setup`. The interpreter is obtained once per trial; its data is re-initialized before, and disposed after, every measured invocation. */ @BenchmarkMode(Array(Mode.Throughput)) 18 | @Fork(value = 1, jvmArgsAppend = Array("-Xmx12g", "-Xss32M", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseZGC", 19 | "-XX:+UnlockDiagnosticVMOptions")) 20 | @Threads(1) 21 | @Warmup(iterations = 11, time = 1) 22 | @Measurement(iterations = 11, time = 1) 23 | @OutputTimeUnit(TimeUnit.MICROSECONDS) 24 |
@State(Scope.Benchmark) 25 | class InterpreterBenchmark { 26 | 27 | /* Interpreter spec passed to `InterpreterBenchmark.setup`; commented-out entries are disabled variants. */ @Param(Array( 28 | //"sti", 29 | "stc1", 30 | "stc2", 31 | //"mt0.i", //"mt1.i", "mt8.i", 32 | //"mt1000.i", "mt1001.i", "mt1008.i", 33 | //"mt0.c", //"mt1.c", "mt8.c", 34 | //"mt1000.c", "mt1001.c", "mt1008.c", 35 | )) 36 | var spec: String = _ 37 | 38 | /* Program name passed to `InterpreterBenchmark.setup` (the lookup is not visible in this view). */ @Param(Array( 39 | "ack38", 40 | "ack38b", 41 | "intAck38", 42 | "boxedAck38", 43 | "fib22", 44 | "mult1", 45 | "mult2", 46 | "mult3", 47 | "intMult3", 48 | "intFib29", 49 | // "fib29", 50 | )) 51 | var benchmark: String = _ 52 | 53 | /* Obtained once per trial in `setup()`; `prepare()` and `cleanup()` run around every invocation of `run`. */ private[this] var inter: BaseInterpreter = _ 54 | 55 | @Setup(Level.Trial) 56 | def setup(): Unit = inter = InterpreterBenchmark.setup(spec, benchmark) 57 | 58 | /* Level.Invocation keeps `initData()` (and `dispose()` below) out of the measured time of `run`; JMH's Level.Invocation documentation warns about the per-call overhead this adds when single invocations are very short. */ @Setup(Level.Invocation) 59 | def prepare(): Unit = inter.initData() 60 | 61 | @Benchmark 62 | def run(bh: Blackhole): Unit = bh.consume(inter.reduce()) 63 | 64 | @TearDown(Level.Invocation) 65 | def cleanup(): Unit = inter.dispose() 66 | } 67 | 68 | object InterpreterBenchmark { 69 | private val prelude = 70 | """cons Z 71 | |cons S(n) 72 | |def add(_, y) = r 73 | | | Z => y 74 | | | S(x) => add(x, S(y)) 75 | |""".stripMargin 76 | 77 | private val mult1Src = prelude + 78 | """def mult(_, y) = r 79 | | | Z => erase(y); Z 80 | | | S(x) => (y1, y2) = dup(y); add(mult(x, y1), y2) 81 | |let res = mult(100n, 100n) 82 | |""".stripMargin 83 | 84 | private val mult2Src = prelude + 85 | """def mult(_, y) = r 86 | | | Z => erase(y); Z 87 | | | S(x) => (y1, y2) = dup(y); add(mult(x, y1), y2) 88 | |let res1 = mult(100n, 100n) 89 | | res2 = mult(100n, 100n) 90 | | res3 = mult(100n, 100n) 91 | | res4 = mult(100n, 100n) 92 | |""".stripMargin 93 | 94 | private val mult3Src = prelude + 95 | """def mult(_, y) = r 96 | | | Z => erase(y); Z 97 | | | S(x) => (a, b) = dup(y); add(b, mult(x, a)) 98 | |let res = mult(1000n, 1000n) 99 | |""".stripMargin 100 | 101 | private val intMult3Src = 102 | """cons Int[int] 103 | |def add(_, _) = r 104 | | | Int[x], Int[y] if [x == 0] =>
Int[y] 105 | | else => add(Int[x-1], Int[y+1]) 106 | |def mult(_, _) = r 107 | | | Int[x], Int[y] if [x == 0] => Int[0] 108 | | else => add(Int[y], mult(Int[x-1], Int[y])) 109 | |let res = mult(Int[1000], Int[1000]) 110 | |""".stripMargin 111 | 112 | private val fib22Src = prelude + 113 | """def add2(_, y) = r 114 | | | Z => y 115 | | | S(x) => S(add2(x, y)) 116 | |def fib(_) = r 117 | | | Z => 1n 118 | | | S(n) => fib2(n) 119 | |def fib2(_) = r 120 | | | Z => 1n 121 | | | S(n) => (n1, n2) = dup(n); add2(fib(S(n1)), fib(n2)) 122 | |let res = fib(22n) 123 | |""".stripMargin 124 | 125 | private val fib29Src = prelude + 126 | """def add2(_, y) = r 127 | | | Z => y 128 | | | S(x) => S(add2(x, y)) 129 | |def fib(_) = r 130 | | | Z => 1n 131 | | | S(n) => fib2(n) 132 | |def fib2(_) = r 133 | | | Z => 1n 134 | | | S(n) => (n1, n2) = dup(n); add2(fib(S(n1)), fib(n2)) 135 | |let res = fib(29n) 136 | |""".stripMargin 137 | 138 | private val intFib29Src = 139 | """cons Int[int] 140 | |def _ + _ = r 141 | | | Int[x], Int[y] => Int[x + y] 142 | |def fib(_) = r 143 | | | Int[x] if [x == 0] => Int[1] 144 | | if [x == 1] => Int[1] 145 | | else => fib(Int[x-1]) + fib(Int[x-2]) 146 | |let res = fib(Int[29]) 147 | |""".stripMargin 148 | 149 | private val ack38Src = prelude + 150 | """def ack(_, y) = r 151 | | | Z => S(y) 152 | | | S(x) => ack_Sx(y, x) 153 | |def ack_Sx(_, x) = r 154 | | | Z => ack(x, S(Z)) 155 | | | S(y) => (x1, x2) = dup(x); ack(x1, ack_Sx(y, x2)) 156 | |let res = ack(3n, 8n) 157 | |""".stripMargin 158 | 159 | private val ack38bSrc = prelude + 160 | """def pred(_) = r 161 | | | Z => Z 162 | | | S(x) => x 163 | |def ack2(_, a) = b 164 | | | Z => S(a) 165 | | | S(x) => ack2b(a, S(x)) 166 | |def ack2b(_, a) = b 167 | | | Z => ack2(pred(a), S(Z)) 168 | | | S(y) => (a1, a2) = dup(a); ack2(pred(a1), ack2(a2, y)) 169 | |let res2 = ack2(3n, 8n) 170 | |""".stripMargin 171 | 172 | private val intAck38Src = 173 | """cons Int[int] 174 | | 175 | |def ackU(a, b) = r 176 | | | 
Int[x], Int[y] 177 | | if [x == 0] => Int[y + 1] 178 | | if [y == 0] => ackU(Int[x - 1], Int[1]) 179 | | else => ackU(Int[x - 1], ackU(Int[x], Int[y - 1])) 180 | | 181 | |let resU = ackU(Int[3], Int[8]) 182 | |""".stripMargin 183 | 184 | private val boxedAck38Src = 185 | """cons BoxedInt[ref] 186 | | 187 | |def ackB(a, b) = r 188 | | | BoxedInt[x], BoxedInt[y] 189 | | if [de.szeiger.interact.InterpreterBenchmark.is0(x)] => 190 | | BoxedInt[de.szeiger.interact.InterpreterBenchmark.inc(y)] 191 | | [eraseRef(x)] 192 | | if [de.szeiger.interact.InterpreterBenchmark.is0(y)] => 193 | | ackB(BoxedInt[de.szeiger.interact.InterpreterBenchmark.dec(x)], BoxedInt[de.szeiger.interact.InterpreterBenchmark.box(1)]) 194 | | [eraseRef(y)] 195 | | else => 196 | | [de.szeiger.interact.InterpreterBenchmark.ackHelper(x, x1, x2)] 197 | | ackB(BoxedInt[x1], ackB(BoxedInt[x2], BoxedInt[de.szeiger.interact.InterpreterBenchmark.dec(y)])) 198 | | 199 | |let resB = ackB(BoxedInt[de.szeiger.interact.InterpreterBenchmark.box(3)], BoxedInt[de.szeiger.interact.InterpreterBenchmark.box(8)]) 200 | |""".stripMargin 201 | def is0(i: java.lang.Integer): Boolean = i.intValue() == 0 202 | def box(i: Int): java.lang.Integer = Integer.valueOf(i) 203 | def inc(i: java.lang.Integer): java.lang.Integer = box(i.intValue() + 1) 204 | def dec(i: java.lang.Integer): java.lang.Integer = box(i.intValue() - 1) 205 | def ackHelper(i: java.lang.Integer, o1: RefOutput, o2: RefOutput): Unit = { 206 | o1.setValue(dec(i)) 207 | o2.setValue(i) 208 | } 209 | 210 | val testCases = Map( 211 | "ack38" -> ack38Src, 212 | "ack38b" -> ack38bSrc, 213 | "boxedAck38" -> boxedAck38Src, 214 | "intAck38" -> intAck38Src, 215 | "fib22" -> fib22Src, 216 | "intFib29" -> intFib29Src, 217 | "fib29" -> fib29Src, 218 | "mult1" -> mult1Src, 219 | "mult2" -> mult2Src, 220 | "mult3" -> mult3Src, 221 | "intMult3" -> intMult3Src, 222 | ) 223 | 224 | val prepareConfig: Config => Config = 225 | _.copy(collectStats = true, logCodeGenSummary = true, 
showAfter = Set("PlanRules")) 226 | 227 | val benchConfig: Config => Config = 228 | //identity 229 | _.copy(writeOutput = Some(Path.of("gen-classes")), writeJava = Some(Path.of("gen-src")), logGeneratedClasses = None, showAfter = Set("")) 230 | //_.copy(skipCodeGen = Set("")) 231 | 232 | def setup(spec: String, benchmark: String): BaseInterpreter = 233 | new Compiler(Parser.parse(testCases(benchmark)), benchConfig(Config(spec))).createInterpreter() 234 | 235 | def main(args: Array[String]): Unit = { 236 | val cls = classOf[InterpreterBenchmark] 237 | 238 | def run1(testCase: String, spec: String) = { 239 | try { 240 | println(s"-------------------- Running $testCase $spec:") 241 | val i = new Compiler(Parser.parse(testCases(testCase)), prepareConfig(Config(spec))).createInterpreter() 242 | i.initData() 243 | println() 244 | i.reduce() 245 | val m = i.getMetrics 246 | m.log() 247 | val steps = m.getSteps 248 | i.dispose() 249 | println() 250 | val opts = new CommandLineOptions(cls.getName, s"-pbenchmark=$testCase", s"-pspec=$spec") { 251 | override def getOperationsPerInvocation = Optional.of(steps) 252 | } 253 | val runner = new Runner(opts) 254 | runner.run().asScala 255 | } catch { case NonFatal(ex) => 256 | ex.printStackTrace() 257 | Iterable.empty 258 | } 259 | } 260 | 261 | val res = for { 262 | testCase <- cls.getDeclaredField("benchmark").getAnnotation(classOf[Param]).value().toVector 263 | spec <- cls.getDeclaredField("spec").getAnnotation(classOf[Param]).value() 264 | res <- run1(testCase, spec) 265 | } yield res 266 | 267 | println("-------------------- Results") 268 | System.out.flush() 269 | val out = OutputFormatFactory.createFormatInstance(System.out, Defaults.VERBOSITY) 270 | out.endRun(res.asJava) 271 | } 272 | } 273 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | //cancelable in Global := false 2 | 3 | scalacOptions 
++= Seq("-feature") 4 | 5 | Test / fork := true 6 | run / fork := true 7 | run / connectInput := true 8 | 9 | Global / scalaVersion := "2.13.12" 10 | 11 | lazy val main = (project in file(".")) 12 | .settings( 13 | libraryDependencies += "com.github.sbt" % "junit-interface" % "0.13.2" % "test", 14 | libraryDependencies += "com.lihaoyi" %% "fastparse" % "3.0.1", 15 | libraryDependencies ++= Seq("asm", "asm-tree", "asm-util", "asm-commons", "asm-analysis").map(a => "org.ow2.asm" % a % "9.5"), 16 | libraryDependencies += "com.jetbrains.intellij.java" % "java-decompiler-engine" % "233.14015.106", 17 | testOptions += Tests.Argument(TestFrameworks.JUnit, "-a", "-v"), 18 | resolvers += "IntelliJ Releases" at "https://www.jetbrains.com/intellij-repository/releases/", 19 | //scalacOptions ++= Seq("-feature", "-opt:l:inline", "-opt-inline-from:de.szeiger.interact.*", "-opt-inline-from:de.szeiger.interact.**"), 20 | ) 21 | 22 | lazy val bench = (project in file("bench")) 23 | .dependsOn(main) 24 | .enablePlugins(JmhPlugin) 25 | .settings( 26 | Jmh / javaOptions ++= Seq("-Xss32M"), 27 | ) 28 | -------------------------------------------------------------------------------- /example.in: -------------------------------------------------------------------------------- 1 | # Natural numbers 2 | cons Z 3 | cons S(n) 4 | 5 | # Erasure and Duplication 6 | # def erase(_) 7 | # def dup(_) = (a, b) 8 | # | dup(_) = (c, d) => (c, d) 9 | 10 | # Addition 11 | def _ + y = r 12 | | Z => y 13 | | S(x) => x + S(y) 14 | 15 | # Multiplication 16 | def _ * y = r 17 | | Z => erase(y); Z 18 | | S(x) => (y1, y2) = dup(y); x * y1 + y2 19 | 20 | # Example: Computations on natural numbers 21 | let y = 5n 22 | x = 3n 23 | example_3_plus_5 = x + y 24 | 25 | let example_3_times_2 = 3n * 2n 26 | 27 | # Lists 28 | cons Nil 29 | cons head :: tail = l 30 | 31 | def length(list) = r 32 | | Nil => Z 33 | | x :: xs => erase(x); S(length(xs)) 34 | 35 | def map(list, fi, fo) = r 36 | | Nil => erase(fi); erase(fo); 
Nil 37 | | x :: xs => (x, fi2) = dup(fi) 38 | (fo1, fo2) = dup(fo) 39 | fo1 :: map(xs, fi2, fo2) 40 | 41 | # Example: List operations 42 | let l0 = 1n :: 2n :: 3n :: Nil 43 | (l0a, l0b) = dup(l0) 44 | l0_length = length(l0a) 45 | l0_mapped = map(l0b, x, x + 2n) 46 | 47 | # Explicit lambdas 48 | cons in |> out 49 | def apply(l, in) = out 50 | | i |> o => in = i; o 51 | 52 | # Example: List mapping with lambdas 53 | def map2(l, f) = r 54 | | Nil => erase(f); Nil 55 | | x :: xs => (f1, f2) = dup(f) 56 | apply(f1, x) :: map2(xs, f2) 57 | 58 | let l0 = 1n :: 2n :: 3n :: Nil 59 | l0_mapped_lambda = map2(l0, x |> x + 2n) 60 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.9.8 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.3") 2 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sbt 'bench/jmh:run de.szeiger.interact.InterpreterBenchmark' 3 | -------------------------------------------------------------------------------- /src/main/java/Foo.java: -------------------------------------------------------------------------------- 1 | public class Foo { 2 | public Object f() { 3 | return de.szeiger.interact.ast.PayloadType.VOID(); 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/de/szeiger/interact/codegen/LocalClassLoader.java: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.codegen; 2 | 3 | public class LocalClassLoader extends ClassLoader implements 
ClassWriter { 4 | static { registerAsParallelCapable(); } 5 | 6 | public final Class defineClass(String name, byte[] b) throws ClassFormatError { 7 | return defineClass(name, b, 0, b.length); 8 | } 9 | 10 | public void writeClass(String javaName, byte[] classFile) { 11 | defineClass(javaName, classFile); 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/de/szeiger/interact/stc2/MetaClass.java: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.stc2; 2 | 3 | import de.szeiger.interact.ast.Symbol; 4 | 5 | public abstract class MetaClass { 6 | public final Symbol cellSymbol; 7 | public final int symId; 8 | protected MetaClass(Symbol cellSymbol, int symId) { 9 | this.cellSymbol = cellSymbol; 10 | this.symId = symId; 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/scala/Analyzer.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import java.io.PrintStream 4 | import scala.annotation.tailrec 5 | import scala.collection.mutable 6 | import de.szeiger.interact.ast._ 7 | 8 | trait Analyzer[Cell] { self => 9 | def rootCells: IterableOnce[Cell] 10 | def irreduciblePairs: IterableOnce[(Cell, Cell)] 11 | 12 | def getSymbol(c: Cell): Symbol 13 | def getConnected(c: Cell, port: Int): (Cell, Int) 14 | def isFreeWire(c: Cell): Boolean 15 | 16 | def symbolName(c: Cell): String = getSymbol(c).id 17 | def getArity(c: Cell): Int = getSymbol(c).arity 18 | 19 | def getPayload(c: Cell): Any = c match { 20 | case c: IntBox => c.getValue 21 | case c: RefBox => c.getValue 22 | case c => "???" 
23 | } 24 | 25 | private[this] def getAllConnected(c: Cell): Iterator[(Cell, Int)] = 26 | if(isFreeWire(c)) Iterator(getConnected(c, 0)) 27 | else (-1 until getArity(c)).iterator.map(getConnected(c, _)) 28 | 29 | private[this] object Nat { 30 | def unapply(c: Cell): Option[Int] = unapply(c, 0) 31 | @tailrec private[this] def unapply(c: Cell, acc: Int): Option[Int] = (symbolName(c), getArity(c)) match { 32 | case ("Z", 0) => Some(acc) 33 | case ("S", 1) => 34 | val (c2, p2) = getConnected(c, 0) 35 | if(p2 != -1) None else unapply(c2, acc+1) 36 | case _ => None 37 | } 38 | } 39 | /** Cells reachable by traversing all ports, starting from the cells connected to the root free wires. Unconnected (null) port entries are skipped. */ 40 | def reachableCells: Iterator[Cell] = { 41 | val freeWires = rootCells.iterator.filter(isFreeWire).toVector 42 | val s = mutable.HashSet.empty[Cell] 43 | val q = mutable.ArrayBuffer.from(freeWires.flatMap(getAllConnected(_).filter(_ != null).map(_._1))) 44 | while(q.nonEmpty) { 45 | val w = q.last 46 | q.dropRightInPlace(1) 47 | if(s.add(w)) q.addAll(getAllConnected(w).filter(_ != null).map(_._1)) /* getConnected may yield null (allConnections matches case (null, _)), so filter like the initial queue instead of failing in _._1 */ 48 | } 49 | s.iterator 50 | } 51 | 52 | private[this] def allConnections(): (mutable.HashMap[(Cell, Int), (Cell, Int)], mutable.HashSet[Cell]) = { 53 | val m = mutable.HashMap.empty[(Cell, Int), (Cell, Int)] 54 | val s = mutable.HashSet.empty[Cell] 55 | val q = mutable.ArrayBuffer.from(rootCells) 56 | while(q.nonEmpty) { 57 | val c1 = q.last 58 | q.dropRightInPlace(1) 59 | if(s.add(c1)) { 60 | val isWire = isFreeWire(c1) 61 | val conn = getAllConnected(c1).toVector 62 | conn.zipWithIndex.foreach { 63 | case (null, _) => 64 | case ((c2, p2), _p1) => 65 | val p1 = if(isWire) 0 else _p1 - 1 66 | m.put((c1, p1), (c2, p2)) 67 | m.put((c2, p2), (c1, p1)) 68 | } 69 | q.addAll(conn.iterator.filter(_ != null).map(_._1)) 70 | } 71 | } 72 | (m, s) 73 | } 74 | 75 | def log(out: PrintStream, prefix: String = " ", markCut: (Cell, Cell) => Boolean = (_, _) => false, color: Boolean = true): mutable.ArrayBuffer[(Cell, Cell)] = { 76 | val colors = if(color) MaybeColors else NoColors 77 | import colors._ 78 | val cuts = 
mutable.ArrayBuffer.empty[(Cell, Cell)] 79 | def singleRet(s: Symbol): Int = if(!s.isDef) -1 else if(s.returnArity == 1) s.callArity-1 else -2 80 | val freeWires = rootCells.iterator.filter(isFreeWire).toVector 81 | val stack = mutable.Stack.from(freeWires.sortBy(c => symbolName(c))) 82 | val all = allConnections()._1 83 | val shown = mutable.HashSet.empty[Cell] 84 | var lastTmp = 0 85 | def tmp(): String = { lastTmp += 1; s"$$s$lastTmp" } 86 | val subst = mutable.HashMap.from(freeWires.iterator.map(c1 => ((c1, 0), symbolName(c1)))) 87 | //println(s"**** $subst") 88 | //def id(c: Cell): String = if(c == null) "null" else s"${getSymbol(c)}#${System.identityHashCode(c)}" 89 | //all.foreach { case ((c1, p1), (c2, p2)) => println(s" ${id(c1)}:$p1 . ${id(c2)}:$p2") } 90 | def nameOrSubst(c1: Cell, p1: Int, c2: Cell, p2: Int): String = subst.get(c2, p2) match { 91 | case Some(s) => s 92 | case None => 93 | val mark = if(p1 == -1 && p2 == -1 && markCut(c1, c2)) { 94 | cuts.addOne((c1, c2)) 95 | s"${cBlue}<${cuts.length-1}>${cNormal}" 96 | } else "" 97 | if(singleRet(getSymbol(c2)) == p2) mark + show(c2, false) 98 | else { 99 | if(!shown.contains(c2)) stack += c2 100 | val t = tmp() 101 | subst.put((c1, p1), t) 102 | mark + t 103 | } 104 | } 105 | def show(_c1: Cell, withRet: Boolean): String = { 106 | val (c1, freeP) = if(isFreeWire(_c1)) getConnected(_c1, 0) else (_c1, -2) 107 | shown += _c1 108 | val sym = getSymbol(c1) 109 | def list(poss: IndexedSeq[Int]) = poss.map { p1 => 110 | if(p1 == freeP && isFreeWire(_c1)) (getSymbol(_c1), nameOrSubst(c1, p1, _c1, 0)) 111 | else all.get(c1, p1) match { 112 | case Some((c2, p2)) => (getSymbol(c2), nameOrSubst(c1, p1, c2, p2)) 113 | case None => (Symbol.NoSymbol, "?") 114 | } 115 | } 116 | def needsParens(sym1: Symbol, pre1: Int, sym2: Symbol, sym2IsRight: Boolean): Boolean = { 117 | val pre2 = Lexical.precedenceOf(sym2.id) 118 | val r1 = Lexical.isRightAssoc(sym1.id) 119 | val r2 = Lexical.isRightAssoc(sym2.id) 120 | pre2 > 
pre1 || (pre2 >= 0 && (r1 != r2)) || (pre1 == pre2 && r1 != sym2IsRight && r2 != sym2IsRight) 121 | } 122 | val call = c1 match { 123 | case Nat(v) => s"${v}n" 124 | case _ => 125 | val aposs = if(sym.isDef) -1 +: (0 until sym.callArity-1) else 0 until sym.arity 126 | val as0 = list(aposs) 127 | val pr1 = Lexical.precedenceOf(sym.id) 128 | val nameAndValue = sym.payloadType match { 129 | case PayloadType.VOID => s"$cYellow${sym.id}$cNormal" 130 | case _ => 131 | val s = getPayload(c1) match { 132 | case s: String => s"\"$s\"" 133 | case o => String.valueOf(o) 134 | } 135 | s"$cYellow${sym.id}$cNormal[$s]" 136 | } 137 | if(pr1 >= 0 && sym.arity == 2) { 138 | val as1 = as0.zipWithIndex.map { case ((asym, s), idx) => if(needsParens(sym, pr1, asym, idx == 1)) s"($s)" else s } 139 | s"${as1(0)} $nameAndValue ${as1(1)}" 140 | } else { 141 | val as = if(as0.isEmpty) "" else as0.iterator.map(_._2).mkString("(", ", ", ")") 142 | s"$nameAndValue$as" 143 | } 144 | } 145 | if(withRet) { 146 | val rposs = if(sym.isDef) sym.callArity-1 until sym.callArity+sym.returnArity-1 else IndexedSeq(-1) 147 | val rs0 = list(rposs).map(_._2) 148 | rs0.size match { 149 | case 0 => call 150 | case 1 => s"${rs0.head} = $call" 151 | case _ => rs0.mkString("(", ", ", s") = $call") 152 | } 153 | } else call 154 | } 155 | while(stack.nonEmpty) { 156 | val c1 = stack.pop() 157 | if(!shown.contains(c1)) { 158 | val s = show(c1, true) 159 | out.println(s"$prefix$s") 160 | } 161 | } 162 | val irr = irreduciblePairs.iterator.filter { case (c1, c2) => c1 != null && c2 != null }.map { case (c1, c2) => Seq(symbolName(c1), symbolName(c2)).sorted.mkString(" <-> ") }.toVector.sorted 163 | if(irr.nonEmpty) { 164 | out.println() 165 | out.println("Irreducible pairs:") 166 | irr.foreach(s => out.println(s" $s")) 167 | } 168 | cuts 169 | } 170 | 171 | def toDot(out: PrintStream): Unit = { 172 | var lastIdx = 0 173 | def mk(): String = { lastIdx += 1; s"n$lastIdx" } 174 | val cells = allConnections()._2.map(c => 
(c, mk())).toMap 175 | out.println("graph G {") 176 | out.println(" node [shape=plain];") 177 | cells.foreachEntry { (c, l) => 178 | val sym = getSymbol(c) 179 | if(sym.arity == 0) 180 | out.println( 181 | s""" $l [shape=circle label=<${sym.id}>];""".stripMargin 182 | ) 183 | else { 184 | val ports = (sym.arity to 1 by -1).map(i => s"""""").mkString 185 | out.println( 186 | s""" $l [shape=plain label=< 187 | | 188 | | 189 | | 190 | |
${sym.id}
$ports
>];""".stripMargin 191 | ) 192 | } 193 | } 194 | val done = mutable.HashSet.empty[(Cell, Int)] 195 | cells.foreachEntry { (c1, l1) => 196 | getAllConnected(c1).zipWithIndex.foreach { case ((c2, _p2), p1) => 197 | if(!done.contains((c1, p1))) { 198 | val p2 = _p2 + 1 199 | val l2 = cells(c2) 200 | val st = 201 | if(p1 == 0 && p2 == 0) " [style=bold]" 202 | else if(p1 != 0 && p2 != 0) " [style=dashed]" 203 | else "" 204 | out.println(s""" $l1:$p1 -- $l2:$p2$st;""") 205 | done += ((c2, p2)) 206 | } 207 | } 208 | } 209 | out.println("}") 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /src/main/scala/BitOps.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | object BitOps { 4 | @inline def byte0(i: Int): Int = (i & 0xFF).toByte 5 | @inline def byte1(i: Int): Int = ((i >>> 8) & 0xFF).toByte 6 | @inline def byte2(i: Int): Int = ((i >>> 16) & 0xFF).toByte 7 | @inline def byte3(i: Int): Int = ((i >>> 24) & 0xFF).toByte 8 | @inline def intOfBytes(b0: Int, b1: Int, b2: Int, b3: Int): Int = b0.toByte&0xFF | ((b1.toByte&0xFF) << 8) | ((b2.toByte&0xFF) << 16) | ((b3.toByte&0xFF) << 24) 9 | def checkedIntOfBytes(b0: Int, b1: Int, b2: Int, b3: Int): Int = { 10 | assert(b0 >= -128 && b0 <= 127) 11 | assert(b1 >= -128 && b1 <= 127) 12 | assert(b2 >= -128 && b2 <= 127) 13 | assert(b3 >= -128 && b3 <= 127) 14 | intOfBytes(b0, b1, b2, b3) 15 | } 16 | object IntOfBytes { 17 | @inline def unapply(i: Int): Some[(Int, Int, Int, Int)] = Some((byte0(i), byte1(i), byte2(i), byte3(i))) 18 | } 19 | 20 | @inline def short0(i: Int): Int = (i & 0xFFFF).toShort 21 | @inline def short1(i: Int): Int = ((i >>> 16) & 0xFFFF).toShort 22 | @inline def intOfShorts(s0: Int, s1: Int): Int = s0.toShort&0xFFFF | ((s1.toShort&0xFFFF) << 16) 23 | def checkedIntOfShorts(s0: Int, s1: Int): Int = { 24 | assert(s0 >= Short.MinValue && s0 <= Short.MaxValue) 25 | assert(s1 >= 
Short.MinValue && s1 <= Short.MaxValue) 26 | intOfShorts(s0, s1) 27 | } 28 | object IntOfShorts { 29 | @inline def unapply(i: Int): Some[(Int, Int)] = Some((short0(i), short1(i))) 30 | } 31 | 32 | @inline def intOfShortByteByte(s0: Int, b2: Int, b3: Int): Int = s0.toShort&0xFFFF | ((b2.toByte&0xFF) << 16) | ((b3.toByte&0xFF) << 24) 33 | def checkedIntOfShortByteByte(s0: Int, b2: Int, b3: Int): Int = { 34 | assert(s0 >= Short.MinValue && s0 <= Short.MaxValue) 35 | assert(b2 >= -128 && b2 <= 127) 36 | assert(b3 >= -128 && b3 <= 127) 37 | intOfShortByteByte(s0, b2, b3) 38 | } 39 | object IntOfShortByteByte { /* NOTE(review): unapply below yields (short0(i), short1(i)), i.e. the IntOfShorts split, not the (short, byte, byte) layout packed by intOfShortByteByte; the matching decoding would be (short0(i), byte2(i), byte3(i)). Fixing it changes the extractor's arity, so confirm callers first. */ 40 | @inline def unapply(i: Int): Some[(Int, Int)] = Some((short0(i), short1(i))) 41 | } 42 | } 43 | 44 | object LongBitOps { 45 | @inline def short0(l: Long): Int = (l & 0xFFFFL).toShort 46 | @inline def short1(l: Long): Int = ((l >>> 16L) & 0xFFFFL).toShort 47 | @inline def short2(l: Long): Int = ((l >>> 32L) & 0xFFFFL).toShort 48 | @inline def short3(l: Long): Int = ((l >>> 48L) & 0xFFFFL).toShort 49 | @inline def longOfShorts(s0: Int, s1: Int, s2: Int, s3: Int): Long = s0.toShort&0xFFFFL | ((s1.toShort&0xFFFFL) << 16L) | ((s2.toShort&0xFFFFL) << 32L) | ((s3.toShort&0xFFFFL) << 48L) 50 | def checkedLongOfShorts(s0: Int, s1: Int, s2: Int, s3: Int): Long = { 51 | assert(s0 >= Short.MinValue && s0 <= Short.MaxValue) 52 | assert(s1 >= Short.MinValue && s1 <= Short.MaxValue) 53 | assert(s2 >= Short.MinValue && s2 <= Short.MaxValue) 54 | assert(s3 >= Short.MinValue && s3 <= Short.MaxValue) 55 | longOfShorts(s0, s1, s2, s3) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/CleanEmbedded.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | import scala.collection.mutable 6 | import scala.collection.mutable.ArrayBuffer 7 | 8 | /** 9 | * Check usage of variables and assign payload types to 
embedded variables. After this phase: 10 | * - All Apply/ApplyCons constructor calls that need an embedded value contain an Ident 11 | * which either refers to an Ident in the pattern or is computed by an EmbeddedExpr. 12 | * - Implicit creation and copying of labels is made explicit. 13 | * - PayloadTypes are assigned to all embedded Symbols. 14 | * - Embedded variable usage is linear. 15 | * - Embedded expressions are valid assignments and method calls. 16 | * - Method/operator calls are typed but not yet resolved. 17 | */ 18 | class CleanEmbedded(val global: Global) extends Transform with Phase { 19 | import global._ 20 | 21 | override def apply(n: Statement): Vector[Statement] = n match { 22 | case n: MatchRule => apply(n) 23 | case n: Let => apply(n) 24 | case n => Vector(n) 25 | } 26 | 27 | override def apply(mr: MatchRule): Vector[Statement] = { 28 | val emb1IdOpt = checkPatternEmbSym(mr.id1, mr.emb1) 29 | val emb2IdOpt = checkPatternEmbSym(mr.id2, mr.emb2) 30 | val patternIds = Iterator(emb1IdOpt, emb2IdOpt).flatten.toSet 31 | patternIds.foreach { _.sym.setPattern() } 32 | val branches = mr.branches.map { b => 33 | val (reduced2, embRed2) = transformReduction(b.reduced, b.embRed, patternIds) 34 | val cond2 = b.cond.map(checkCond(_)) 35 | b.copy(cond = cond2, reduced = reduced2, embRed = embRed2) 36 | } 37 | Vector(mr.copy(branches = branches)) 38 | } 39 | 40 | override def apply(l: Let): Vector[Statement] = { 41 | val (es, ees) = transformReduction(l.defs, l.embDefs, Set.empty) 42 | Vector(l.copy(defs = es, embDefs = ees)) 43 | } 44 | 45 | // check embedded symbol in pattern and assign payload type to is 46 | private[this] def checkPatternEmbSym(consId: Ident, embId: Option[Ident]): Option[Ident] = { 47 | val consPT = consId.sym.payloadType 48 | embId match { 49 | case s @ Some(id) => 50 | assert(id.sym.isDefined) 51 | id.sym.payloadType = consPT 52 | if(consPT.isEmpty) { 53 | error(s"Constructor has no embedded value", consId) 54 | None 55 | } else s 56 | 
case None => 57 | if(consPT.isDefined && !consPT.canErase) 58 | error(s"Embedded value of type $consPT must be extracted in pattern match", consId) 59 | None 60 | } 61 | } 62 | 63 | // Create new unique embedded symbols in Apply / ApplyCons and assign their PayloadTypes. 64 | // Ensure that all targets have / do not have embedded symbols based on type. 65 | // Returns updated exprs, extracted embedded computations, old symbol to new idents mapping, all new symbols 66 | private[this] def extractConsEmbComps(exprs: Vector[Expr]): (Vector[Expr], Vector[EmbeddedExpr], mutable.HashMap[Symbol, ArrayBuffer[Ident]], Set[Symbol]) = { 67 | val local = new SymbolGen("$e$", isEmbedded = true) 68 | val symbolMap = mutable.HashMap.empty[Symbol, ArrayBuffer[Ident]] 69 | val newEmbComps = Vector.newBuilder[EmbeddedExpr] 70 | val defaultCreate = ArrayBuffer.empty[Symbol] 71 | val allEmbCompSyms = Set.newBuilder[Symbol] 72 | val proc: Transform = new Transform { 73 | def tr[T <: AnyApply](n: T): T = { 74 | val emb2 = n.embedded match { 75 | case s @ Some(e) => 76 | if(n.target.sym.payloadType.isEmpty) 77 | error("Constructor has no embedded value", e) 78 | val prefix = e match { 79 | case i: Ident => i.s 80 | case _ => n.target.sym.id 81 | } 82 | val id = local.id(isEmbedded = true, payloadType = n.target.sym.payloadType, prefix = prefix).setPos(e.pos) 83 | e match { 84 | case oldId: Ident if !oldId.sym.isPattern => 85 | symbolMap.getOrElseUpdate(oldId.sym, ArrayBuffer.empty) += id 86 | case _ => 87 | newEmbComps += EmbeddedAssignment(id, e) 88 | allEmbCompSyms += id.sym 89 | } 90 | Some(id) 91 | case None if n.target.sym.payloadType.canCreate => 92 | val id = local.id(isEmbedded = true, payloadType = n.target.sym.payloadType, prefix = "cr_"+n.target.sym.id).setPos(n.pos) 93 | defaultCreate += id.sym 94 | allEmbCompSyms += id.sym 95 | Some(id) 96 | case None => 97 | if(n.target.sym.payloadType.isDefined) 98 | error(s"Embedded value of type ${n.target.sym.payloadType} must be 
created", n) 99 | None 100 | } 101 | n.copy(embedded = emb2).asInstanceOf[T] 102 | } 103 | override def apply(n: Apply): Apply = tr(super.apply(n)) 104 | override def apply(n: ApplyCons): ApplyCons = tr(super.apply(n)) 105 | } 106 | val exprs2 = exprs.map(proc(_)) 107 | if(defaultCreate.nonEmpty) 108 | newEmbComps += CreateLabels(local(isEmbedded = true, payloadType = PayloadType.REF, prefix = "defCr"), defaultCreate.toVector) 109 | (exprs2, newEmbComps.result(), symbolMap, allEmbCompSyms.result()) 110 | } 111 | 112 | private[this] def checkCond(cond: EmbeddedExpr): EmbeddedExpr = { 113 | val proc: Transform = new Transform { 114 | override def apply(n: Ident): Ident = { 115 | if(n.sym.isDefined && !n.sym.isPattern) 116 | error("Unknown symbol (not in pattern)", n) 117 | n 118 | } 119 | } 120 | proc(cond) 121 | } 122 | 123 | private[this] def checkLinearity(patternIds: Iterable[Ident], consIds: Iterable[Ident], usageCount: mutable.HashMap[Symbol, Int]): Unit = { 124 | patternIds.foreach { id => 125 | val pt = id.sym.payloadType 126 | usageCount(id.sym) match { 127 | case 0 => if(!pt.canErase) error(s"Cannot implicitly erase value of type $pt", id) 128 | case 1 => 129 | case _ => if(!pt.canCopy) error(s"Cannot implicitly copy value of type $pt", id) 130 | } 131 | } 132 | consIds.foreach { id => 133 | val pt = id.sym.payloadType 134 | usageCount(id.sym) match { 135 | case 0 => if(!pt.canCreate) error(s"Cannot implicitly create value of type $pt", id) 136 | case 1 => 137 | case _ => error(s"Duplicate assignment", id) 138 | } 139 | } 140 | } 141 | 142 | private[this] def transformReduction(reduced: Vector[Expr], embComps: Vector[EmbeddedExpr], 143 | patternIds: Set[Ident]): (Vector[Expr], Vector[EmbeddedExpr]) = { 144 | val (reduced2, newEmbComps, newEmbIds, assignedEmbCompSyms) = extractConsEmbComps(reduced) 145 | val allEmbCompSyms = assignedEmbCompSyms ++ newEmbIds.values.flatten.map(_.sym) 146 | val usageCount = mutable.HashMap.from((patternIds.map(_.sym) ++ 
allEmbCompSyms).map((_ , 0))) 147 | def checkLHS(i: Ident): Unit = 148 | if(!allEmbCompSyms.contains(i.sym)) error("Assignment must be to an embedded variable of the reduction", i) 149 | val mapIdents: Transform = new Transform { 150 | override def apply(n: Ident): Ident = { 151 | if(n.sym.isCons || n.sym.isEmpty) n 152 | else { 153 | val n2 = newEmbIds.get(n.sym) match { 154 | case Some(ids) if ids.length > 1 => 155 | error(s"Multiple occurrences of symbol ${n.sym} in cell constructors", n) 156 | n 157 | case Some(ids) => ids.head 158 | case _ => n 159 | } 160 | if (!n2.sym.isPattern && !allEmbCompSyms.contains(n2.sym)) { 161 | error("Unknown symbol (not in pattern or cell constructor)", n) 162 | n /* n2.sym has no usageCount entry (keys are pattern syms ++ allEmbCompSyms): return after reporting instead of throwing NoSuchElementException below */ 163 | } else { 164 | usageCount.update(n2.sym, usageCount(n2.sym) + 1) 165 | n2 } 166 | } 167 | } 168 | override def apply(n: EmbeddedAssignment): EmbeddedAssignment = { 169 | val n2 = super.apply(n) 170 | checkLHS(n2.lhs) 171 | n2 172 | } 173 | } 174 | val mappedEmbComps = ArrayBuffer.empty[EmbeddedExpr] 175 | embComps.foreach(mappedEmbComps += mapIdents(_)) 176 | newEmbComps.foreach(mappedEmbComps += mapIdents(_)) 177 | val unusedNewEmbIds = 178 | newEmbIds.iterator.map { case (oldSym, newIds) => (oldSym, newIds.filter(i => usageCount(i.sym) == 0)) }.filter(_._2.nonEmpty).toVector 179 | unusedNewEmbIds.foreach { case (oldSym, newIds) => 180 | mappedEmbComps += CreateLabels(oldSym, newIds.iterator.map(_.sym).toVector).setPos(newIds.head.pos) 181 | newIds.foreach { i => usageCount.update(i.sym, usageCount(i.sym) + 1) } 182 | } 183 | checkLinearity(patternIds, newEmbIds.iterator.flatMap(_._2).toVector, usageCount) 184 | 185 | // println("************* Reduction:") 186 | // mappedEmbComps.foreach(ShowableNode.print(_)) 187 | // println("newEmbIds: " + newEmbIds.iterator.map { case (k, v) => s"$k -> ${v.mkString("[", ",", "]")}"}.mkString(", ")) 188 | // println("unusedNewEmbIds: " + unusedNewEmbIds.iterator.map { case (k, v) => s"$k -> ${v.mkString("[", ",", "]")}"}.mkString(", 
")) 189 | // println("usageCount: " + usageCount.iterator.map { case (k, v) => s"$k -> $v"}.mkString(", ")) 190 | 191 | (reduced2, mappedEmbComps.toVector) 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /src/main/scala/Colors.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | abstract class Colors { 4 | def cNormal: String 5 | def cBlack: String 6 | def cRed: String 7 | def cGreen: String 8 | def cYellow: String 9 | def cBlue: String 10 | def cMagenta: String 11 | def cCyan: String 12 | def bRed: String 13 | def bGreen: String 14 | def bYellow: String 15 | def bBlue: String 16 | def bMagenta: String 17 | def bCyan: String 18 | } 19 | 20 | object MaybeColors extends Colors { 21 | val useColors: Boolean = System.getProperty("interact.colors", "true").toBoolean 22 | 23 | val (cNormal, cBlack, cRed, cGreen, cYellow, cBlue, cMagenta, cCyan) = 24 | if(useColors) ("\u001B[0m", "\u001B[30m", "\u001B[31m", "\u001B[32m", "\u001B[33m", "\u001B[34m", "\u001B[35m", "\u001B[36m") 25 | else ("", "", "", "", "", "", "", "") 26 | val (bRed, bGreen, bYellow, bBlue, bMagenta, bCyan) = 27 | if(useColors) ("\u001B[41m", "\u001B[42m", "\u001B[43m", "\u001B[44m", "\u001B[45m", "\u001B[46m") 28 | else ("", "", "", "", "", "") 29 | } 30 | 31 | object NoColors extends Colors { 32 | val cNormal, cBlack, cRed, cGreen, cYellow, cBlue, cMagenta, cCyan, bRed, bGreen, bYellow, bBlue, bMagenta, bCyan = "" 33 | } 34 | 35 | object Colors { 36 | def stripColors(s: String): String = { 37 | val b = new StringBuilder(s.length) 38 | var i = 0 39 | while(i < s.length) { 40 | s.charAt(i) match { 41 | case '\u001B' => 42 | val end = s.indexOf('m', i+1) 43 | if(end == -1) i += 1 44 | else i = end+1 45 | case c => 46 | b.append(c) 47 | i += 1 48 | } 49 | } 50 | b.result() 51 | } 52 | } 53 | -------------------------------------------------------------------------------- 
/src/main/scala/Compiler.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | import de.szeiger.interact.offheap.{Allocator, ProxyAllocator, SliceAllocator} 5 | 6 | import java.nio.file.Path 7 | import scala.collection.mutable 8 | 9 | class Compiler(unit0: CompilationUnit, _config: Config = Config.defaultConfig) { 10 | val global = new Global(_config) 11 | import global._ 12 | 13 | private[this] val phases: Vector[Phase] = Vector( 14 | new Prepare(global), 15 | new ExpandRules(global), 16 | new Curry(global), 17 | new CleanEmbedded(global), 18 | new ResolveEmbedded(global), 19 | new CreateWiring(global), 20 | new Inline(global), 21 | ) ++ (if(config.backend.planRules) Vector( 22 | new PlanRules(global), 23 | ) else Vector.empty) 24 | 25 | private[this] val unit1 = if(config.addEraseDup) { 26 | val erase = globalSymbols.define("erase", isCons = true, isDef = true, returnArity = 0) 27 | val dup = globalSymbols.define("dup", isCons = true, isDef = true, arity = 2, returnArity = 2, payloadType = PayloadType.LABEL) 28 | unit0.copy(statements = Vector(DerivedRule(erase, erase), DerivedRule(erase, dup), DerivedRule(dup, dup)) ++ unit0.statements).setPos(unit0.pos) 29 | } else unit0 30 | 31 | val unit = phases.foldLeft(unit1) { case (u, p) => 32 | val u2 = p(u) 33 | if(config.showAfter.contains(p.phaseName) || config.showAfter.contains("*")) 34 | ShowableNode.print(u2, name = s"After phase ${p.phaseName}") 35 | checkThrow() 36 | u2 37 | } 38 | 39 | def createInterpreter(): BaseInterpreter = config.backend.createInterpreter(this) 40 | } 41 | 42 | trait Phase extends (CompilationUnit => CompilationUnit) { 43 | val global: Global 44 | val phaseName: String = getClass.getName.replaceAll(".*\\.", "") 45 | val phaseLogEnabled: Boolean = global.config.phaseLog.contains(phaseName) || global.config.phaseLog.contains("*") 46 | override def toString: String = phaseName 47 | 
48 | @inline final def phaseLog(@inline msg: => String): Unit = if(phaseLogEnabled) global.phaseLog(phaseName, msg) 49 | @inline final def phaseLog(n: ShowableNode, name: => String, prefix: String = ""): Unit = if(phaseLogEnabled) global.phaseLog(phaseName, n, name, prefix) 50 | } 51 | 52 | case class Config( 53 | // Frontend 54 | defaultDerive: Seq[String] = Seq("erase", "dup"), 55 | addEraseDup: Boolean = true, 56 | phaseLog: Set[String] = Set.empty, // show debug log of these phases 57 | showAfter: Set[String] = Set.empty, // log AST after these phases 58 | inlineFull: Boolean = true, // inline rules that can be merged into a single branch 59 | inlineFullAll: Boolean = true, // inline simple matches even when duplicating a parent rule 60 | inlineBranching: Boolean = true, // inline rules that cannot be merged into a single branch (st.c) 61 | inlineUniqueContinuations: Boolean = true, // st.c 62 | loop: Boolean = true, 63 | repeatedInliningLimit: Int = 0, 64 | 65 | // Backend 66 | backend: Backend = STC1Backend, 67 | numThreads: Int = 0, // mt 68 | collectStats: Boolean = false, 69 | useCellCache: Boolean = false, // stc* 70 | biasForCommonDispatch: Boolean = true, // optimize for invokevirtual dispatch of statically known cell types (stc1) 71 | logCodeGenSummary: Boolean = false, // stc*, mt.c 72 | logGeneratedClasses: Option[String] = None, // Log generated classes containing this string (stc*, mt.c) 73 | compilerParallelism: Int = 1, 74 | allCommon: Boolean = false, // compile all methods into CommonCell, not just shared ones (stc1) 75 | reuseCells: Boolean = true, // stc* 76 | writeOutput: Option[Path] = None, // write generated classes to dir or jar file (stc*) 77 | writeJava: Option[Path] = None, // write decompiled classes to dir (stc*) 78 | skipCodeGen: Set[String] = Set.empty, // do not generate classfiles for these Java class names (stc*) 79 | tailCallDepth: Int = 32, // stc2 80 | newAllocator: () => ProxyAllocator = () => new SliceAllocator(), // stc2 
81 | debugMemory: Boolean = false, // stc2 82 | unboxedPrimitives: Boolean = true, // Use unboxed 32-bit primitives and singletons instead of pointers (stc2) 83 | ) { 84 | def withSpec(spec: String): Config = spec match { 85 | case s"sti" => copy(backend = STIBackend) 86 | case s"stc1" => copy(backend = STC1Backend) 87 | case s"stc2" => copy(backend = STC2Backend) 88 | case s"mt${mode}.i" => copy(backend = MTIBackend, numThreads = mode.toInt) 89 | case s"mt${mode}.c" => copy(backend = MTCBackend, numThreads = mode.toInt) 90 | } 91 | } 92 | 93 | abstract class Backend(val name: String) { 94 | def createInterpreter(comp: Compiler): BaseInterpreter 95 | def planRules: Boolean 96 | def inlineBranching: Boolean 97 | def inlineUniqueContinuations: Boolean 98 | def allowPayloadTemp: Boolean 99 | def canReuseLabels: Boolean 100 | 101 | def storageClass(sym: Symbol): Any = sym 102 | def canUnbox(sym: Symbol): Boolean = false 103 | } 104 | 105 | object STIBackend extends Backend("sti") { 106 | def createInterpreter(comp: Compiler): BaseInterpreter = 107 | new sti.Interpreter(comp.global.globalSymbols, comp.unit, comp.global.config) 108 | def planRules: Boolean = false 109 | def inlineBranching: Boolean = false 110 | def inlineUniqueContinuations: Boolean = false 111 | def allowPayloadTemp: Boolean = false 112 | def canReuseLabels: Boolean = true 113 | } 114 | 115 | object STC1Backend extends Backend("stc1") { 116 | def createInterpreter(comp: Compiler): BaseInterpreter = 117 | new stc1.Interpreter(comp.global.globalSymbols, comp.unit, comp.global.config) 118 | def planRules: Boolean = true 119 | def inlineBranching: Boolean = true 120 | def inlineUniqueContinuations: Boolean = true 121 | def allowPayloadTemp: Boolean = true 122 | def canReuseLabels: Boolean = true 123 | } 124 | 125 | object STC2Backend extends Backend("stc2") { 126 | def createInterpreter(comp: Compiler): BaseInterpreter = 127 | new stc2.Interpreter(comp.global.globalSymbols, comp.unit, comp.global.config) 
128 | def planRules: Boolean = true 129 | def inlineBranching: Boolean = true 130 | def inlineUniqueContinuations: Boolean = true 131 | def allowPayloadTemp: Boolean = true 132 | def canReuseLabels: Boolean = false 133 | override def storageClass(sym: Symbol) = (stc2.Interpreter.cellSize(sym.arity, sym.payloadType), sym.payloadType == PayloadType.REF) 134 | override def canUnbox(sym: Symbol): Boolean = stc2.Interpreter.canUnbox(sym, sym.arity) 135 | } 136 | 137 | class MTBackend(name: String, compile: Boolean) extends Backend(name) { 138 | def createInterpreter(comp: Compiler): BaseInterpreter = { 139 | import comp.global._ 140 | val rulePlans = mutable.Map.empty[RuleKey, RuleWiring] 141 | val initialPlans = mutable.ArrayBuffer.empty[InitialRuleWiring] 142 | comp.unit.statements.foreach { 143 | case i: InitialRuleWiring => initialPlans += i 144 | case g: RuleWiring => rulePlans.put(g.key, g) 145 | } 146 | new mt.Interpreter(globalSymbols, rulePlans.values, config, mutable.ArrayBuffer.empty[Let], initialPlans, compile) 147 | } 148 | def planRules: Boolean = compile 149 | def inlineBranching: Boolean = compile 150 | def inlineUniqueContinuations: Boolean = compile 151 | def allowPayloadTemp: Boolean = compile 152 | def canReuseLabels: Boolean = compile 153 | } 154 | 155 | object MTIBackend extends MTBackend("mti", false) 156 | object MTCBackend extends MTBackend("mtc", true) // compiled multi-threaded backend; must not report MTIBackend's name 157 | 158 | object Config { 159 | val defaultConfig: Config = Config() 160 | def apply(spec: String): Config = defaultConfig.withSpec(spec) 161 | } 162 | -------------------------------------------------------------------------------- /src/main/scala/Curry.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | import scala.collection.mutable 6 | 7 | /** 8 | * Create curried and derived rules, and remove all Cons and Def statements. 
9 | */ 10 | class Curry(val global: Global) extends Transform with Phase { 11 | import global._ 12 | 13 | private[this] lazy val defaultDeriveSyms = 14 | config.defaultDerive.iterator.map(globalSymbols.get).filter(_.exists(_.isCons)).map(_.get).toVector 15 | 16 | private[this] def derivedRules(syms1: Vector[Symbol], sym2: Symbol, pos: Position): Vector[DerivedRule] = 17 | syms1.flatMap { sym => 18 | if(sym.id == "erase" || sym.id == "dup") Vector(DerivedRule(sym, sym2).setPos(pos)) 19 | else { error(s"Don't know how to derive '$sym'", pos); Vector.empty } 20 | } 21 | 22 | private[this] def singleNonIdentIdx(es: Seq[Expr]): Int = { 23 | val i1 = es.indexWhere(e => !e.isInstanceOf[Ident]) 24 | if(i1 == -1) -1 25 | else { 26 | val i2 = es.lastIndexWhere(e => !e.isInstanceOf[Ident]) 27 | if(i2 == i1) i1 else -2 28 | } 29 | } 30 | 31 | private[this] def createCurriedDef(lid: Ident, rid: Ident, idx: Int, rhs: Boolean, at: Position): (Ident, Vector[CheckedRule]) = { 32 | val ls = lid.sym 33 | val rs = rid.sym 34 | val curryId = Ident(s"${ls.id}$$${if(rhs) "ac" else "nc"}$$${rs.id}").setPos(lid.pos) 35 | val rules = globalSymbols.get(curryId) match { 36 | case Some(sym) => 37 | if(sym.matchContinuationPort != idx) error(s"Port mismatch in curried ${ls.id} -> ${rs.id} match", at) 38 | curryId.sym = sym 39 | Vector.empty 40 | case None if ls.hasPayload && rs.hasPayload => 41 | error("Implementation limitation: Curried definitions cannot have payload on both sides", at) 42 | Vector.empty 43 | case None => 44 | val curriedPtp = if(ls.hasPayload) ls.payloadType else rs.payloadType 45 | val emb1 = if(ls.hasPayload) Some(mkLocalId("$l", true).setPos(lid.pos)) else None 46 | val emb2 = if(rs.hasPayload) Some(mkLocalId("$r", true).setPos(rid.pos)) else None 47 | curryId.sym = globalSymbols.define(curryId.s, isCons = true, arity = ls.arity + rs.arity - 1, payloadType = curriedPtp, matchContinuationPort = idx) 48 | val largs = (0 until ls.callArity).map(i => 
mkLocalId(s"$$l$i").setPos(lid.pos)).toVector 49 | val rargs = (0 until rs.callArity).map(i => mkLocalId(s"$$r$i").setPos(rid.pos)).toVector 50 | val (keepArgs, splitArgs) = if(rhs) (rargs, largs) else (largs, rargs) 51 | val curryArgs = keepArgs ++ splitArgs.zipWithIndex.filter(_._2 != idx).map(_._1) 52 | val der = derivedRules(defaultDeriveSyms, curryId.sym, Position.unknown) 53 | val fwd = Assignment(Apply(curryId, emb1.orElse(emb2), curryArgs).setPos(at), splitArgs(idx)).setPos(at) 54 | der :+ MatchRule(lid, rid, largs, rargs, emb1, emb2, Vector(Branch(None, Vector.empty, Vector(fwd)).setPos(at))).setPos(at) 55 | } 56 | (curryId, rules) 57 | } 58 | 59 | private[this] def curry(mr: MatchRule): Vector[Statement] = mr.args1.indexWhere(e => !e.isInstanceOf[Ident]) match { 60 | case -1 => 61 | singleNonIdentIdx(mr.args2) match { 62 | case -2 => 63 | error(s"Only one nested match allowed", mr.pos) 64 | Vector.empty 65 | case -1 => 66 | mr.args1.toSet.intersect(mr.args2.toSet).foreach { case i: Ident => 67 | error(s"Duplicate variable '${i.s}' on both sides of a match", i) 68 | } 69 | Vector(mr) 70 | case idx => 71 | val (curryId, curryRules) = createCurriedDef(mr.id1, mr.id2, idx, false, mr.args2(idx).pos) 72 | val ApplyCons(cid, cemb, crargs) = mr.args2(idx) 73 | val clargs = mr.args1 ++ mr.args2.zipWithIndex.filter(_._2 != idx).map(_._1.asInstanceOf[Ident]) 74 | curryRules ++ curry(mr.copy(curryId, cid, clargs, crargs, mr.emb1.orElse(mr.emb2), checkCEmb(cemb), mr.branches)) 75 | } 76 | case idx => 77 | val (curryId, curryRules) = createCurriedDef(mr.id1, mr.id2, idx, true, mr.args1(idx).pos) 78 | val ApplyCons(cid, cemb, clargs) = mr.args1(idx) 79 | val crargs = mr.args2 ++ mr.args1.zipWithIndex.filter(_._2 != idx).map(_._1.asInstanceOf[Ident]) 80 | curryRules ++ curry(mr.copy(curryId, cid, crargs, clargs, mr.emb1.orElse(mr.emb2), checkCEmb(cemb), mr.branches)) 81 | } 82 | 83 | private[this] def checkCEmb(o: Option[EmbeddedExpr]): Option[Ident] = o.flatMap { 84 | 
case i: Ident => Some(i) 85 | case e => 86 | error("Embedded expression in pattern match must be a variable", e) 87 | None 88 | } 89 | 90 | override def apply(n: Statement): Vector[Statement] = n match { 91 | case n: MatchRule => curry(n) 92 | case c: Cons => derivedRules(c.der.map(_.map(_.sym)).getOrElse(defaultDeriveSyms), c.name.sym, c.name.pos) 93 | case d: Def => derivedRules(defaultDeriveSyms, d.name.sym, d.name.pos) 94 | case n => Vector(n) 95 | } 96 | 97 | override def apply(n: CompilationUnit): CompilationUnit = { 98 | val n2 = super.apply(n) 99 | val keys = mutable.HashSet.empty[RuleKey] 100 | n2.statements.foreach { 101 | case c: CheckedRule => 102 | if(!keys.add(c.key)) error(s"Duplicate rule ${c.sym1} <-> ${c.sym2}", c) 103 | case _ => 104 | } 105 | n2 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/main/scala/Debug.scala: -------------------------------------------------------------------------------- 1 | import de.szeiger.interact._ 2 | import de.szeiger.interact.sti.Cell 3 | 4 | import java.nio.file.Path 5 | import scala.annotation.tailrec 6 | import scala.collection.mutable 7 | 8 | object Debug extends App { 9 | val statements = Parser.parse(Path.of(args(0))) 10 | val model = new Compiler(statements, Config(backend = STIBackend)) 11 | val inter = model.createInterpreter().asInstanceOf[sti.Interpreter] 12 | inter.initData() 13 | 14 | var steps = 0 15 | var cuts: mutable.ArrayBuffer[(Cell, Cell)] = _ 16 | 17 | @tailrec 18 | def readLine(): Option[Int] = { 19 | print("> ") 20 | val in = Console.in.readLine() 21 | if(in == "q") None 22 | else in.toIntOption.filter(i => i >= 0 && i < cuts.length) match { 23 | case None => readLine() 24 | case o => o 25 | } 26 | } 27 | 28 | @tailrec def step(): Unit = { 29 | println(s"${MaybeColors.cGreen}At step $steps:${MaybeColors.cNormal}") 30 | cuts = inter.getAnalyzer.log(System.out, markCut = (c1, c2) => inter.getRuleImpl(c1, c2) != null) 31 | 
if(cuts.isEmpty) 32 | println(s"${MaybeColors.cGreen}Irreducible after $steps reductions.${MaybeColors.cNormal}") 33 | else { 34 | steps += 1 35 | readLine() match { 36 | case None => () 37 | case Some(idx) => 38 | inter.reduce1(cuts(idx)._1, cuts(idx)._2) 39 | step() 40 | } 41 | } 42 | } 43 | 44 | step() 45 | } 46 | -------------------------------------------------------------------------------- /src/main/scala/ExecutionMetrics.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import scala.collection.mutable 4 | 5 | class ExecutionMetrics { 6 | private[this] final class MutInt(var i: Int) 7 | private[this] var steps, cellAlloc, proxyAlloc, cellReuse, singletonUse, unboxedCells, loopSave, directTail, singleDispatchTail, labelCreate = 0 8 | private[this] var metrics = mutable.Map.empty[String, MutInt] 9 | 10 | def getSteps: Int = steps 11 | 12 | def recordStats(steps: Int, cellAllocations: Int, proxyAllocations: Int, cachedCellReuse: Int, singletonUse: Int, 13 | unboxedCells: Int, loopSave: Int, directTail: Int, singleDispatchTail: Int, labelCreate: Int): Unit = { 14 | this.steps += steps 15 | this.cellAlloc += cellAllocations 16 | this.proxyAlloc += proxyAllocations 17 | this.cellReuse += cachedCellReuse 18 | this.singletonUse += singletonUse 19 | this.unboxedCells += unboxedCells 20 | this.loopSave += loopSave 21 | this.directTail += directTail 22 | this.singleDispatchTail += singleDispatchTail 23 | this.labelCreate += labelCreate 24 | } 25 | 26 | def recordStats(steps: Int, cellAllocations: Int): Unit = recordStats(steps, cellAllocations, 0, 0, 0, 0, 0, 0, 0, 0) 27 | 28 | def recordMetric(metric: String, inc: Int = 1): Unit = { 29 | val m = metrics.getOrElseUpdate(metric, new MutInt(0)) 30 | m.i += inc 31 | } 32 | 33 | def log(): Unit = { 34 | logStats() 35 | logMetrics() 36 | } 37 | 38 | def logStats(): Unit = { 39 | println(s"Steps: $steps ($loopSave loop, $directTail tail 
($singleDispatchTail single-dispatch), ${steps-loopSave-directTail} other)") 40 | println(s" Cells created: $cellAlloc new ($proxyAlloc proxied), $cellReuse cached, $singletonUse singleton, $unboxedCells unboxed; Labels created: $labelCreate") 41 | } 42 | 43 | def logMetrics(): Unit = { 44 | val data = metrics.toVector.sortBy(_._1).map { case (k, v) => (k, v.i.toString) } 45 | val maxLen = data.iterator.map { case (k, v) => k.length + v.length }.maxOption.getOrElse(0) // .max throws on empty input, and log() calls this even when no metric was recorded 46 | data.foreach { case (k, v) => 47 | val pad = " " * (maxLen-k.length-v.length) 48 | println(s" $k $pad$v") 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/scala/ExpandRules.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | /** 6 | * Complete rules by synthesizing omitted return wires in conditions and reductions. 7 | */ 8 | class ExpandRules(val global: Global) extends Transform with Phase { 9 | import global._ 10 | 11 | val normalizeCond = new NormalizeCondition(global) 12 | 13 | private[this] def wildcardCount(e: Expr): Int = e match { 14 | case _: Wildcard => 1 15 | case e: Apply => e.args.iterator.map(wildcardCount).sum 16 | case _ => 0 17 | } 18 | 19 | private[this] def returnArity(e: Expr): Int = e match { 20 | case e: Apply => e.target.sym.returnArity 21 | case Assignment(lhs, rhs) => wildcardCount(lhs) + wildcardCount(rhs) 22 | case _ => 0 23 | } 24 | 25 | override def apply(n: Statement): Vector[Statement] = n match { 26 | case n: Match => apply(n) 27 | case n: Def => apply(n) 28 | case n => Vector(n) 29 | } 30 | 31 | override def apply(m: Match): Vector[Statement] = returnArity(m.on.head) match { 32 | case 0 => checkedMatch(m.on, m.reduced, m.pos) 33 | case n => 34 | assert(m.on.length == 1) 35 | val p = m.on.head.pos 36 | checkedMatch(Vector(Assignment(m.on.head, Tuple((1 to n).map(i => 
mkLocalId(s"$$ret$i").setPos(p)).toVector).setPos(p)).setPos(p)), m.reduced, m.pos) 37 | } 38 | 39 | override def apply(d: Def): Vector[Statement] = { 40 | d.copy(rules = Vector.empty) +: d.rules.flatMap { r => 41 | val dret = Tuple(d.ret).setPos(d.pos) 42 | checkedMatch(Vector(Assignment(Apply(d.name, d.embeddedId, r.on ++ d.args.drop(r.on.length)).setPos(d.pos), dret).setPos(r.pos)), r.reduced, r.pos) 43 | } 44 | } 45 | 46 | private[this] def connectLastStatement(e: Expr, extraRhs: Vector[Ident]): Expr = e match { 47 | case e: Assignment => e 48 | case e: Tuple => 49 | if(e.exprs.length != extraRhs.length) 50 | error(s"Expected return arity ${extraRhs.length} for reduction but got ${e.exprs.length}", e) 51 | Assignment(Tuple(extraRhs).setPos(extraRhs.head.pos), e).setPos(extraRhs.head.pos) 52 | case e: Apply => 53 | val sym = globalSymbols(e.target.s) 54 | if(sym.returnArity == 0) e 55 | else { 56 | if(sym.returnArity != extraRhs.length) 57 | error(s"Expected return arity ${extraRhs.length} for reduction but got ${sym.returnArity}", e) 58 | Assignment(if(extraRhs.length == 1) extraRhs.head else Tuple(extraRhs).setPos(extraRhs.head.pos), e).setPos(extraRhs.head.pos) 59 | } 60 | case e: NatLit => 61 | if(extraRhs.length != 1) 62 | error(s"Expected return arity ${extraRhs.length} for reduction but got 1", e) 63 | Assignment(extraRhs.head, e).setPos(extraRhs.head.pos) 64 | case e: Ident => 65 | if(extraRhs.length != 1) 66 | error(s"Expected return arity ${extraRhs.length} for reduction but got 1", e) 67 | Assignment(extraRhs.head, e).setPos(e.pos) 68 | } 69 | 70 | private[this] def checkedMatch(on: Vector[Expr], red: Vector[Branch], pos: Position): Vector[Statement] = { 71 | val inlined = normalizeCond.toInline(normalizeCond.toANF(on).map(normalizeCond.toConsOrder)) 72 | inlined match { 73 | case Seq(Assignment(ApplyCons(lid, lemb, largs: Seq[Expr]), ApplyCons(rid, remb, rargs))) => 74 | val compl = if(lid.sym.isDef) largs.takeRight(lid.sym.returnArity) else 
Vector.empty 75 | val connected = red.map { r => 76 | r.copy(reduced = r.reduced.init :+ connectLastStatement(r.reduced.last, compl.asInstanceOf[Vector[Ident]])) 77 | } 78 | val mr = MatchRule(lid, rid, largs, rargs, lemb.map(_.asInstanceOf[Ident]), remb.map(_.asInstanceOf[Ident]), connected).setPos(pos) 79 | Vector(mr) 80 | case _ => 81 | error(s"Invalid match", pos) 82 | Vector.empty 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/main/scala/Global.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | import java.io.{PrintWriter, StringWriter} 6 | import scala.collection.mutable.ArrayBuffer 7 | import scala.util.control.NonFatal 8 | 9 | final class Global(val config: Config) { 10 | final val globalSymbols = new Symbols 11 | 12 | private[this] var hasErrors: Boolean = false 13 | private[this] val accumulated = ArrayBuffer.empty[Notice] 14 | 15 | def dependencyLoader: ClassLoader = getClass.getClassLoader 16 | 17 | def warning(msg: String, at: Node): Unit = warning(msg, at.pos) // use the Position overload; error() would record Severity.Error and set hasErrors 18 | def warning(msg: String, at: Position): Unit = 19 | accumulated += new Notice(msg, at, Severity.Warning) 20 | 21 | def error(msg: String, at: Node): Unit = error(msg, at.pos) 22 | def error(msg: String, at: Position): Unit = { 23 | accumulated += new Notice(msg, at, Severity.Error) 24 | hasErrors = true 25 | } 26 | 27 | def throwError(msg: String, at: Node): Nothing = throwError(msg, at.pos) 28 | def throwError(msg: String, at: Position): Nothing = { 29 | error(msg, at) 30 | throw new ReportedError 31 | } 32 | def tryError[T](f: => T): Option[T] = try Some(f) catch { case _: ReportedError => None } 33 | 34 | def fatal(msg: String, at: Node): Nothing = fatal(msg, at.pos) 35 | def fatal(msg: String, at: Position): Nothing = { 36 | accumulated += new Notice(msg, at, Severity.Fatal) 37 | throw getCompilerResult() 38 
| } 39 | 40 | def internalError(msg: String, at: Position, parent: Throwable = null, atNode: Node = null): Unit = { 41 | accumulated += new Notice(msg, if(at == null && atNode != null) atNode.pos else at, Severity.Fatal, atNode, internal = true) 42 | throw getCompilerResult(parent) 43 | } 44 | 45 | def mkLocalId(name: String, isEmbedded: Boolean = false, payloadType: PayloadType = PayloadType.VOID): Ident = { 46 | val i = Ident(name) 47 | i.sym = Symbol(name, isEmbedded = isEmbedded, payloadType = payloadType) 48 | i 49 | } 50 | 51 | def getCompilerResult(parent: Throwable = null): CompilerResult = new CompilerResult(accumulated.toIndexedSeq, parent) 52 | 53 | def checkThrow(): Unit = 54 | if(hasErrors) throw getCompilerResult() 55 | 56 | def phaseLog(phase: String, msg: String): Unit = 57 | println(s"<$phase> $msg") 58 | 59 | def phaseLog(phase: String, n: ShowableNode, name: String, prefix: String): Unit = { 60 | val p = s"<$phase> " 61 | ShowableNode.print(n, name = name, prefix = p+prefix, prefix1 = p+prefix, highlightTopLevel = false) 62 | } 63 | } 64 | 65 | class Notice(msg: String, at: Position, severity: Severity, atNode: ShowableNode = null, internal: Boolean = false, mark: Node = null) { 66 | def formatted: String = { 67 | import MaybeColors._ 68 | import Notice._ 69 | val b = new StringBuilder 70 | val sev = if(internal) s"${cRed}Internal Error" else if(isError) s"${cRed}Error" else s"${cYellow}Warning" 71 | val msgLines = msg.split('\n') 72 | if(at.isDefined) { 73 | val (line, col) = at.input.find(at.offset) 74 | b.append(s"$sev: $cNormal${at.file}$cCyan:${line+1}:${col+1}$cNormal: ${msgLines.head}$eol") 75 | msgLines.tail.foreach { m => b.append(s"$cBlue| $cNormal$m$eol") } 76 | b.append(s"$cBlue| $cNormal${at.input.getLine(line).stripTrailing()}$eol") 77 | b.append(s"$cBlue| $cGreen${at.input.getCaret(col)}$cNormal") 78 | } else { 79 | b.append(s"$sev: $cNormal ${msgLines.head}$eol") 80 | msgLines.tail.foreach { m => b.append(s"$cBlue| 
$cNormal$m$eol") } 81 | } 82 | if(internal && atNode != null) { 83 | val out = new StringWriter() 84 | val outw = new PrintWriter(out) 85 | outw.print(s"\n$cBlue| ${cRed}AST Context:$cNormal\n") 86 | ShowableNode.print(atNode, outw, prefix = s"$cBlue| ", mark = mark) 87 | b.append(out.toString) 88 | } 89 | b.result() 90 | } 91 | override def toString: String = formatted 92 | def isError: Boolean = severity != Severity.Warning 93 | } 94 | 95 | object Notice { 96 | val eol: String = sys.props("line.separator") 97 | } 98 | 99 | sealed abstract class Severity 100 | object Severity { 101 | case object Warning extends Severity 102 | case object Error extends Severity 103 | case object Fatal extends Severity 104 | } 105 | 106 | class CompilerResult(val notices: IndexedSeq[Notice], parent: Throwable = null) extends Exception(parent) { 107 | lazy val hasErrors = notices.exists(_.isError) 108 | lazy val summary: String = { 109 | val errs = notices.count(_.isError) 110 | val warns = notices.length - errs 111 | def fmt(i: Int, s: String) = if(i == 1) s"1 $s" else s"$i ${s}s" 112 | if(warns > 0) s"${fmt(errs, "error")}, ${fmt(warns, "warning")} found." // pass the singular: fmt appends the plural "s" itself 113 | else s"${fmt(errs, "error")} found." 
114 | } 115 | override def getMessage: String = { 116 | import Notice._ 117 | val b = (new StringBuilder).append(eol) 118 | notices.foreach(n => b.append(n.formatted).append(eol).append(eol)) 119 | b.append(summary).result() 120 | } 121 | } 122 | object CompilerResult { 123 | def tryInternal[T](at: Position)(f: => T): T = try f catch { 124 | case e: CompilerResult => throw e 125 | case e: AssertionError => throw new CompilerResult(Vector(new Notice(e.toString, at, Severity.Fatal, internal = true)), e) 126 | case NonFatal(e) => throw new CompilerResult(Vector(new Notice(e.toString, at, Severity.Fatal, internal = true)), e) 127 | } 128 | def tryInternal[T](atNode: Node)(f: => T): T = try f catch { 129 | case e: CompilerResult => throw e 130 | case e: AssertionError => throw new CompilerResult(Vector(new Notice(e.toString, atNode.pos, Severity.Fatal, internal = true, atNode = atNode)), e) 131 | case NonFatal(e) => throw new CompilerResult(Vector(new Notice(e.toString, atNode.pos, Severity.Fatal, internal = true, atNode = atNode)), e) 132 | } 133 | 134 | def fail(msg: String, at: Position = null, parent: Throwable = null, atNode: Node = null, internal: Boolean = true, mark: Node = null): Nothing = 135 | throw new CompilerResult(Vector(new Notice(msg, if(at != null) at else if(atNode != null) atNode.pos else Position.unknown, Severity.Fatal, atNode, internal, mark)), parent) // keep parent as the exception cause; Notice.formatted dereferences its position, so never hand it null 136 | } 137 | 138 | class ReportedError extends Exception("Internal error: ReportedError thrown by throwError must be caught by a surrounding tryError") 139 | 140 | trait BaseInterpreter { 141 | def getAnalyzer: Analyzer[_] 142 | def initData(): Unit 143 | def reduce(): Unit 144 | def dispose(): Unit 145 | def getMetrics: ExecutionMetrics 146 | } 147 | -------------------------------------------------------------------------------- /src/main/scala/Main.scala: -------------------------------------------------------------------------------- 1 | import java.nio.file.Path 2 | import de.szeiger.interact._ 3 | import 
de.szeiger.interact.ast.ShowableNode 4 | 5 | object Main extends App { 6 | def handleRes(res: CompilerResult, full: Boolean): Unit = { 7 | if(full) res.printStackTrace(System.out) 8 | else { 9 | res.notices.foreach(println) 10 | println(res.summary) 11 | } 12 | if(res.hasErrors) sys.exit(1) 13 | } 14 | try { 15 | val unit = Parser.parse(Path.of(args(0))) 16 | ShowableNode.print(unit) 17 | //statements.foreach(println) 18 | val model = new Compiler(unit, Config(backend = STIBackend, collectStats = true)) 19 | 20 | //println("Constructors:") 21 | //model.constrs.foreach(c => println(s" ${c.show}")) 22 | //println("Defs:") 23 | //model.defs.foreach(d => println(s" ${d.show}")) 24 | //println("Rules:") 25 | //model.rules.foreach(r => if(!r.isInstanceOf[DerivedRule]) println(s" ${r.show}")) 26 | //println("Data:") 27 | //model.data.foreach(r => println(s" ${r.show}")) 28 | //ShowableNode.print(model.unit) 29 | 30 | ShowableNode.print(model.unit) 31 | val inter = model.createInterpreter() 32 | inter.initData() 33 | println("Initial state:") 34 | inter.getAnalyzer.log(System.out) 35 | inter.reduce() 36 | val metrics = inter.getMetrics // may be null, so every dereference below is guarded 37 | if(metrics != null) { metrics.log(); println(s"Irreducible after ${metrics.getSteps} reductions.") } 38 | inter.getAnalyzer.log(System.out) 39 | handleRes(model.global.getCompilerResult(), false) 40 | } catch { case ex: CompilerResult => handleRes(ex, true) } 41 | } 42 | -------------------------------------------------------------------------------- /src/main/scala/NormalizeCondition.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | import scala.collection.mutable 6 | 7 | // Normalize expressions: 8 | // - all compound expressions are unnested (ANF) 9 | // - constructor Idents are converted to Ap 10 | // - only non-constructor Idents can be nested 11 | // - all Ap assignments have the Ap on the RHS 12 | // - all direct 
assignments are untupled 13 | // - wildcards in assignments are resolved 14 | // - only the last expr can be a non-assignment 15 | // - NatLits are expanded 16 | class NormalizeCondition(global: Global) { 17 | import global._ 18 | 19 | private var lastTmp = 0 20 | private def mk(): Ident = { lastTmp += 1; mkLocalId(s"$$u${lastTmp}") } 21 | 22 | def toANF(exprs: Seq[Expr]): Seq[Expr] = { 23 | val assigned = mutable.HashSet.empty[Ident] 24 | val buf = mutable.ArrayBuffer.empty[Expr] 25 | def expandWildcards(e: Expr, wild: Ident): Expr = e match { 26 | case Assignment(ls, rs) => 27 | val wild2 = mk() 28 | val ass2 = Assignment(expandWildcards(ls, wild2), expandWildcards(rs, wild2)).setPos(e.pos) 29 | if(assigned.contains(wild2)) { reorder(unnest(ass2, false)); wild2 } 30 | else ass2 31 | case Tuple(es) => Tuple(es.map(expandWildcards(_, wild))).setPos(e.pos) 32 | case Apply(t, emb, args) => Apply(t, emb, args.map(expandWildcards(_, wild))).setPos(e.pos) 33 | case Wildcard() => 34 | if(assigned.contains(wild)) error(s"Duplicate wildcard in assignment", e) 35 | assigned += wild 36 | wild.setPos(e.pos) 37 | case e: Ident => e 38 | case e: NatLit => e.expand 39 | } 40 | def unnest(e: Expr, nested: Boolean): Expr = e match { 41 | case Assignment(ls, rs) => 42 | if(nested) error("Unexpected nested assignment without wilcard", e) 43 | Assignment(unnest(ls, false), unnest(rs, false)).setPos(e.pos) 44 | case e: Tuple => e.copy(e.exprs.map(unnest(_, true))) 45 | case IdentOrAp(id, emb, args) => 46 | if(id.sym.isCons || args.nonEmpty) { 47 | val ap = Apply(id, emb, args.map(unnest(_, true))).setPos(e.pos) 48 | if(nested) { 49 | val v = mk().setPos(e.pos) 50 | reorder(Assignment(v, ap).setPos(e.pos)) 51 | v 52 | } else ap 53 | } else id 54 | } 55 | def reorder(e: Expr): Unit = e match { 56 | case Assignment(ls: Apply, rs: Apply) => 57 | val sym1 = ls.target.sym 58 | val sym2 = rs.target.sym 59 | if(sym1.returnArity != sym2.returnArity) 60 | error(s"Arity mismatch in assignment: 
${sym1.returnArity} != ${sym2.returnArity}", e) 61 | if(sym1.returnArity == 1) { 62 | val v = mk().setPos(e.pos) 63 | buf += Assignment(v, ls).setPos(ls.pos) 64 | buf += Assignment(v, rs).setPos(rs.pos) 65 | } else { 66 | val vs = (0 until sym1.returnArity).map(_ => mk().setPos(e.pos)).toVector 67 | buf += Assignment(Tuple(vs).setPos(ls.pos), ls).setPos(ls.pos) 68 | buf += Assignment(Tuple(vs).setPos(rs.pos), rs).setPos(rs.pos) 69 | } 70 | case e @ Assignment(ls: Apply, rs) => reorder(e.copy(rs, ls)) 71 | case Assignment(Tuple(ls), Tuple(rs)) => 72 | ls.zip(rs).foreach { case (l, r) => reorder(Assignment(l, r).setPos(e.pos)) } 73 | case e => buf += e 74 | } 75 | exprs.foreach { e => 76 | val wild = mk() 77 | reorder(unnest(expandWildcards(e, wild), false)) 78 | if(assigned.contains(wild)) error("Unexpected wildcard outside of assignment", e) 79 | } 80 | buf.toSeq 81 | } 82 | 83 | // reorder assignments as if every def was a cons 84 | def toConsOrder(e: Expr): Expr = e match { 85 | case Assignment(id @ IdentOrTuple(es), a @ Apply(t, emb, args)) => 86 | if(!t.sym.isDef) Assignment(id, ApplyCons(t, emb, args).setPos(a.pos)).setPos(e.pos) else Assignment(args.head, ApplyCons(t, emb, args.tail ++ es).setPos(a.pos)).setPos(e.pos) 87 | case Apply(t, emb, args) => 88 | if(!t.sym.isDef) ApplyCons(t, emb, args).setPos(e.pos) else Assignment(args.head, ApplyCons(t, emb, args.tail).setPos(e.pos)) 89 | case e => e 90 | } 91 | 92 | // convert from cons-order ANF back to inlined expressions 93 | def toInline(es: Seq[Expr]): Seq[Expr] = { 94 | if(es.isEmpty) es 95 | else { 96 | val vars = mutable.HashMap.from(es.init.map { case a: Assignment => (a.lhs, a) }) 97 | def f(e: Expr): Expr = e match { 98 | case e: Ident => vars.remove(e).map { a => f(a.rhs) }.getOrElse(e) 99 | case e: Tuple => vars.remove(e).map { a => f(a.rhs) }.getOrElse(e) 100 | case e @ ApplyCons(target, emb, args) => e.copy(args = args.map(f)) 101 | case e @ Assignment(l, r) => e.copy(f(r), f(l)) 102 | } 103 | val 
e2 = f(es.last) 104 | (vars.valuesIterator ++ Iterator.single(e2)).toSeq 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/main/scala/Parser.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import fastparse._ 4 | 5 | import java.nio.charset.StandardCharsets 6 | import java.nio.file.{Files, Path} 7 | import scala.collection.mutable 8 | import de.szeiger.interact.ast._ 9 | 10 | 11 | import scala.collection.mutable.ArrayBuffer 12 | 13 | object Lexical { 14 | import NoWhitespace._ 15 | 16 | val reservedTokens = Set("cons", "let", "deriving", "def", "if", "else", "match", "_", ":", "=", "=>", "|") 17 | private val operatorStart = IndexedSeq(Set('*', '/', '%'), Set('+', '-'), Set(':'), Set('<', '>'), Set('=', '!'), Set('&'), Set('^'), Set('|')) 18 | private val operatorCont = operatorStart.iterator.flatten.toSet 19 | val MaxPrecedence = operatorStart.length-1 20 | 21 | def precedenceOf(s: String): Int = 22 | if(reservedTokens.contains(s) || !s.forall(operatorCont.contains)) -1 23 | else { val c = s.charAt(0); operatorStart.indexWhere(_.contains(c)) } 24 | def isRightAssoc(s: String): Boolean = s.endsWith(":") 25 | 26 | def ident[_: P]: P[String] = 27 | P( (letter|"_") ~ (letter | digit | "_").rep ).!.filter(!reservedTokens.contains(_)) 28 | def kw[_: P](s: String) = P( s ~ !(letter | digit | "_") ) 29 | def letter[_: P] = P( CharIn("a-z") | CharIn("A-Z") ) 30 | def digit[_: P] = P( CharIn("0-9") ) 31 | def natLit[_: P] = P( digit.rep(1).! ~ "n" ).map(_.toInt) 32 | def intLit[_: P] = P( (("-").? ~ digit.rep(1)).! 
).map(_.toInt) 33 | 34 | def operator[_: P](precedence: Int): P[String] = 35 | P( CharPred(operatorStart(precedence).contains) ~ CharPred(operatorCont.contains).rep ).!.filter(s => !reservedTokens.contains(s)) 36 | 37 | def anyOperator[_: P]: P[Ident] = 38 | P( CharPred(operatorCont.contains).rep(1).!.filter(s => !reservedTokens.contains(s)).map(Ident(_)) ) 39 | 40 | def stringLit[_: P]: P[String] = P( "\"" ~ (stringChar | stringEscape).rep.! ~ "\"" ) 41 | def stringChar[_: P]: P[Unit] = P( CharsWhile(!s"\\\n\"}".contains(_)) ) 42 | def stringEscape[_: P]: P[Unit] = P( "\\" ~ AnyChar ) 43 | 44 | def comment[$: P] = P( "#" ~ CharsWhile(_ != '\n', 0) ) 45 | def whitespace[$: P](indent: Int) = P( 46 | CharsWhile(_ == ' ', 0) ~ (CharsWhile(_ == ' ', 0) ~ comment.? ~ "\r".? ~ "\n" ~ " ".rep(indent)).rep(0) 47 | ) 48 | } 49 | 50 | trait EmbeddedSyntax { this: Parser => 51 | import Lexical._ 52 | 53 | def identExpr[_: P]: P[Ident] = positioned(ident.map(Ident)) 54 | 55 | def embeddedExpr[_: P]: P[EmbeddedExpr] = 56 | P( positioned(embeddedOperatorExpr(MaxPrecedence)) ) 57 | 58 | def embeddedAssignment[_: P]: P[EmbeddedAssignment] = 59 | P( positioned((identExpr ~ "=" ~ embeddedExpr).map { case (id, ee) => EmbeddedAssignment(id, ee) }) ) 60 | 61 | def embeddedExprOrAssignment[_: P]: P[EmbeddedExpr] = 62 | P( embeddedAssignment | embeddedExpr ) 63 | 64 | def embeddedAp[_: P]: P[EmbeddedApply] = 65 | P( identExpr.rep(1, ".").map(_.toVector) ~ "(" ~ embeddedExpr.rep(0, ",").map(_.toVector) ~ ")" ).map { case (method, args) => 66 | EmbeddedApply(method, args, false, EmbeddedType.Unknown) 67 | } 68 | 69 | def bracketedEmbeddedExpr[_: P]: P[EmbeddedExpr] = 70 | P( "[" ~ embeddedExprOrAssignment ~ "]" ) 71 | 72 | def simpleEmbeddedExpr[_: P]: P[EmbeddedExpr] = 73 | P( embeddedAp | identExpr | intLit.map(IntLit) | stringLit.map(StringLit) | ("(" ~ embeddedExpr ~ ")") ) 74 | 75 | def operatorEmbeddedIdent[_: P](precedence: Int): P[Ident] = 
positioned(operator(precedence).map(Ident)) 76 | 77 | def embeddedOperatorExpr[_: P](precedence: Int): P[EmbeddedExpr] = { 78 | def next = if(precedence == 0) positioned(simpleEmbeddedExpr) else embeddedOperatorExpr(precedence - 1) 79 | P( next ~ (operatorEmbeddedIdent(precedence) ~ next).rep ).map { 80 | case (e, Seq()) => e 81 | case (e, ts) => 82 | val right = ts.count(_._1.s.endsWith(":")) 83 | if(right == 0) 84 | ts.foldLeft(e) { case (z, (o, a)) => EmbeddedApply(Vector(o), Vector(z, a), true, EmbeddedType.Unknown).setPos(o.pos) } 85 | else if(right == ts.length) { 86 | val e2 = ts.last._2 87 | val ts2 = ts.map(_._1).zip(e +: ts.map(_._2).init) 88 | ts2.foldRight(e2) { case ((o, a), z) => EmbeddedApply(Vector(o), Vector(a, z), true, EmbeddedType.Unknown).setPos(o.pos) } 89 | } else sys.error("Chained binary operators must have the same associativity") 90 | } 91 | } 92 | } 93 | 94 | trait Syntax { this: Parser => 95 | import Lexical._ 96 | 97 | def wildcard[_: P]: P[Wildcard] = P("_").map(_ => Wildcard()) 98 | 99 | def nat[_: P]: P[Expr] = P( positioned(natLit.map(NatLit)) ) 100 | 101 | def appOrIdent[_: P]: P[Expr] = 102 | P( 103 | positioned((identExpr ~ bracketedEmbeddedExpr.? ~ ("(" ~ expr.rep(sep = ",") ~ ")").?).map { case (id, embO, argsO) => 104 | if(embO.isDefined || argsO.isDefined) Apply(id, embO, argsO.getOrElse(Vector.empty).toVector) else id 105 | }) 106 | ) 107 | 108 | def tuple[_: P]: P[Tuple] = 109 | P( "(" ~ expr.rep(min = 0, sep = ",").map(_.toVector) ~ ")" ).map(Tuple) 110 | 111 | def simpleExpr[_: P]: P[Expr] = 112 | P( (appOrIdent | wildcard | nat | tuple) ) 113 | 114 | def operatorIdent[_: P](precedence: Int): P[Ident] = positioned(operator(precedence).map(Ident(_))) 115 | 116 | def operatorEx[_: P](precedence: Int): P[Expr] = { 117 | def next = if(precedence == 0) positioned(simpleExpr) else operatorEx(precedence - 1) 118 | P( next ~ (operatorIdent(precedence) ~ bracketedEmbeddedExpr.? 
~ next).rep ).map { 119 | case (e, Seq()) => e 120 | case (e, ts) => 121 | val right = ts.count(_._1.s.endsWith(":")) 122 | if(right == 0) 123 | ts.foldLeft(e) { case (z, (o, oe, a)) => Apply(o, oe, Vector(z, a)).setPos(o.pos) } 124 | else if(right == ts.length) { 125 | val e2 = ts.last._3 126 | val ts2 = ts.map(t => (t._1, t._2)).zip(e +: ts.map(_._3).init) 127 | ts2.foldRight(e2) { case (((o, oe), a), z) => Apply(o, oe, Vector(a, z)).setPos(o.pos) } 128 | } else sys.error("Chained binary operators must have the same associativity") 129 | } 130 | } 131 | 132 | def expr[_: P]: P[Expr] = 133 | P( positioned(operatorEx(MaxPrecedence)) ~ ("=" ~ positioned(operatorEx(MaxPrecedence))).? ).map { 134 | case (e1, None) => e1 135 | case (e1, Some(e2)) => Assignment(e1, e2).setPos(e1.pos) 136 | } 137 | 138 | def params[_: P](min: Int): P[Vector[IdentOrWildcard]] = 139 | P( ("(" ~ param.rep(min = min, sep = ",").map(_.toVector) ~ ")") ) 140 | 141 | def param[_: P]: P[IdentOrWildcard] = 142 | P( positioned(identExpr | wildcard) ) 143 | 144 | def defReturn[_: P]: P[Vector[IdentOrWildcard]] = 145 | P( params(1) | identExpr.map(Vector(_)) ) 146 | 147 | def deriving[_ : P]: P[Vector[Ident]] = 148 | P( kw("deriving") ~/ "(" ~ identExpr.rep(0, sep=",").map(_.toVector) ~ ")" ) 149 | 150 | def payloadType[_: P]: P[PayloadType] = 151 | P(ident).map { 152 | case "int" => (PayloadType.INT) 153 | case "ref" => (PayloadType.REF) 154 | case "label" => (PayloadType.LABEL) 155 | case tpe => sys.error(s"Illegal payload type: $tpe") 156 | } 157 | 158 | def embeddedSpecOpt[_: P]: P[(PayloadType, Option[Ident])] = 159 | P( ("[" ~ payloadType ~ identExpr.? ~ "]" ).? ).map(_.getOrElse((PayloadType.VOID, None))) 160 | 161 | def cons[_: P]: P[Cons] = 162 | P( kw("cons") ~/ (operatorDef | namedCons) ~ ("=" ~ identExpr).? ~ deriving.? 
).map(Cons.tupled) 163 | 164 | def definition[_: P]: P[Def] = 165 | P( kw("def") ~/ (operatorDef | namedDef) ~ ("=" ~ defReturn).?.map(_.getOrElse(Vector.empty)) ~ defRule.rep.map(_.toVector) ).map(Def.tupled) 166 | 167 | def namedCons[_: P]: P[(Ident, Vector[IdentOrWildcard], Boolean, PayloadType, Option[Ident])] = 168 | P( identExpr ~ embeddedSpecOpt ~ params(0).?.map(_.getOrElse(Vector.empty)) ).map { case (n, (pt, eid), as) => (n, as, false, pt, eid) } 169 | 170 | def operatorDef[_: P]: P[(Ident, Vector[IdentOrWildcard], Boolean, PayloadType, Option[Ident])] = 171 | P( param ~ positioned(anyOperator) ~ embeddedSpecOpt ~ param ).map { case (a1, o, (pt, eid), a2) => (o, Vector(a1, a2), true, pt, eid) } 172 | 173 | def namedDef[_: P]: P[(Ident, Vector[IdentOrWildcard], Boolean, PayloadType, Option[Ident])] = 174 | P( identExpr ~ embeddedSpecOpt ~ params(1) ).map { case (n, (pt, eid), as) => (n, as, false, pt, eid) } 175 | 176 | def anyExpr[_ : P]: P[Either[EmbeddedExpr, Expr]] = 177 | P( bracketedEmbeddedExpr.map(Left(_)) | expr.map(Right(_)) ) 178 | 179 | def anyExprs[_: P]: P[Vector[Either[EmbeddedExpr, Expr]]] = 180 | P( anyExpr.rep(1, sep = ";").map(_.toVector) ~ ";".? 
) 181 | 182 | def anyExprBlock[_: P]: P[(Vector[Expr], Vector[EmbeddedExpr])] = 183 | P( pos.flatMapX(p => forIndent(p.column).anyExprBlock2) ) 184 | 185 | def anyExprBlock2[_: P]: P[(Vector[Expr], Vector[EmbeddedExpr])] = 186 | P( (forIndent(indent+1).anyExprs).rep(1) ).map { es => 187 | (es.iterator.flatten.collect { case Right(e) => e }.toVector, es.iterator.flatten.collect { case Left(e) => e }.toVector) 188 | } 189 | 190 | def simpleReduction[_: P]: P[Branch] = 191 | P( positioned("=>" ~ anyExprBlock.map { case (es, ees) => Branch(None, ees, es) }) ) 192 | 193 | def conditionalReductions[_: P]: P[Vector[Branch]] = 194 | P( ("if" ~ (bracketedEmbeddedExpr ~ simpleReduction).map { case (p, r) => r.copy(cond = Some(p))}).rep(1).map(_.toVector) ~ 195 | "else" ~ simpleReduction 196 | ).map { case (rs, r) => rs :+ r } 197 | 198 | def reductions[_: P]: P[Vector[Branch]] = 199 | P( conditionalReductions | simpleReduction.map(Vector(_)) ) 200 | 201 | def condition[_: P]: P[EmbeddedExpr] = 202 | P( "if" ~ bracketedEmbeddedExpr ) 203 | 204 | def defRule[_: P]: P[DefRule] = 205 | P( positioned(("|" ~ expr.rep(1, ",").map(_.toVector) ~ reductions).map(DefRule.tupled)) ) 206 | 207 | def matchStatement[_: P]: P[Match] = 208 | P( "match" ~ expr ~ reductions ).map { case (on, red) => Match(Vector(on), red) } 209 | 210 | def let[_: P]: P[Let] = 211 | P( kw("let") ~/ anyExprBlock ).map { case (defs, emb) => Let(defs, emb, Vector.empty) } 212 | 213 | def unit[_: P]: P[CompilationUnit] = 214 | P( Start ~ pos ~ positioned(cons | let | definition | matchStatement ).rep ~ End ).map { case (p, es) => CompilationUnit(es.toVector).setPos(p) } 215 | } 216 | 217 | class Parser(file: String, indexed: ConvenientParserInput, val indent: Int) extends EmbeddedSyntax with Syntax { 218 | implicit val whitespace: fastparse.Whitespace = 219 | (ctx: P[_]) => Lexical.whitespace(indent)(ctx) 220 | 221 | private[this] val positions = mutable.LongMap.empty[Position] 222 | 223 | def pos[_: P]: 
P[Position] = Index.map { offset => positions.getOrElseUpdate(offset, new Position(offset, file, indexed)) } 224 | 225 | def positioned[T <: Node, _: P](n: => P[T]): P[T] = 226 | P( (pos ~~ n) ).map { case (p, e) => if(!e.pos.isDefined) e.setPos(p) else e } 227 | 228 | def forIndent(i: Int): Parser = new Parser(file, indexed, i) 229 | } 230 | 231 | object Parser { 232 | def parse(input: String, file: String = ""): CompilationUnit = { 233 | val in = new ConvenientParserInput(input) 234 | val p = new Parser(file, in, 0) 235 | fastparse.parse(in, p.unit(_), verboseFailures = true).get.value 236 | } 237 | 238 | def parse(file: Path): CompilationUnit = 239 | parse(new String(Files.readAllBytes(file), StandardCharsets.UTF_8), file.toString) 240 | } 241 | 242 | class ConvenientParserInput(val data: String) extends ParserInput { 243 | def apply(index: Int): Char = data.charAt(index) 244 | def dropBuffer(index: Int): Unit = () 245 | def slice(from: Int, until: Int): String = data.slice(from, until) 246 | def length: Int = data.length 247 | def innerLength: Int = length 248 | def isReachable(index: Int): Boolean = index < length 249 | def checkTraceable(): Unit = () 250 | 251 | private[this] lazy val lineBreaks = { 252 | val b = ArrayBuffer.empty[Int] 253 | for(i <- data.indices) 254 | if(data.charAt(i) == '\n') b += i 255 | b 256 | } 257 | 258 | def find(idx: Int): (Int, Int) = { 259 | val line = lineBreaks.indexWhere(_ > idx) match { 260 | case -1 => lineBreaks.length 261 | case n => n 262 | } 263 | val col = if(line == 0) idx else idx - lineBreaks(line-1) - 1 264 | (line, col) 265 | } 266 | 267 | def prettyIndex(idx: Int): String = { 268 | val (line, pos) = find(idx) 269 | s"${line+1}:${pos+1}" 270 | } 271 | 272 | def getLine(line: Int): String = { 273 | val start = if(line == 0) 0 else lineBreaks(line-1) 274 | val end = if(line == lineBreaks.length) data.length else lineBreaks(line) 275 | data.substring(start, end).dropWhile(c => c == '\r' || c == '\n') 276 | } 277 | 278 
| def getCaret(col: Int): String = (" " * col) + "^" 279 | } 280 | -------------------------------------------------------------------------------- /src/main/scala/Prepare.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | /** 6 | * Create all symbols and check function calls and linearity. 7 | */ 8 | class Prepare(val global: Global) extends Phase { 9 | import global._ 10 | 11 | def apply(unit: CompilationUnit): CompilationUnit = { 12 | unit.statements.foreach { 13 | case c: Cons => 14 | if(globalSymbols.contains(c.name)) error(s"Duplicate cons/def: ${c.name.s}", c.name) 15 | else c.name.sym = globalSymbols.defineCons(c.name.s, c.args.length, c.payloadType) 16 | case d: Def => 17 | if(globalSymbols.contains(d.name)) error(s"Duplicate cons/def: ${d.name.s}", d.name) 18 | else d.name.sym = globalSymbols.defineDef(d.name.s, d.args.length, d.ret.length, d.payloadType) 19 | case _ => 20 | } 21 | val st2 = Transform.mapC(unit.statements)(assign(_, globalSymbols)) 22 | if(st2 eq unit.statements) unit else unit.copy(st2) 23 | } 24 | 25 | private[this] def assign(n: Statement, scope: Symbols): Statement = n match { 26 | case Cons(_, args, _, _, embId, ret, der) => 27 | val sc = scope.sub() 28 | args.foreach(assign(_, sc)) 29 | args.foreach(a => if(a.isWildcard) error("No wildcard parameters allowed in cons", a)) 30 | ret.foreach(assign(_: Expr, sc)) 31 | embId.foreach(assign(_: EmbeddedExpr, sc)) 32 | der.foreach(_.foreach { i => 33 | val symO = scope.get(i) 34 | if(!symO.exists(_.isCons)) 35 | error(s"No constructor '${i.s}' defined", i) 36 | else i.sym = symO.get 37 | }) 38 | n 39 | case Def(_, args, _, _, embId, ret, rules) => 40 | val sc = scope.sub() 41 | args.foreach(assign(_, sc)) 42 | ret.foreach(assign(_, sc)) 43 | embId.foreach(assign(_: EmbeddedExpr, sc)) 44 | val rm = new RefsMap 45 | args.foreach(rm.collectAll) 46 | ret.foreach(rm.collectAll) 
47 | embId.foreach(rm.collectAll) 48 | if(rm.hasNonFree) 49 | error(s"Duplicate variable(s) in def pattern ${rm.nonFree.map(s => s"'$s'").mkString(", ")}", n) 50 | rules.foreach(assign(_, sc, rm)) 51 | n 52 | case Match(on, reduced) => 53 | val sc = scope.sub() 54 | on.foreach(assign(_, sc)) 55 | val rm = new RefsMap 56 | on.foreach(rm.collectAll) 57 | if(rm.hasNonFree) 58 | error(s"Duplicate variable(s) in match pattern ${rm.nonFree.map(s => s"'$s'").mkString(", ")}", n) 59 | reduced.foreach(assign(_, sc)) 60 | n 61 | case n @ Let(defs, embDefs, _) => 62 | val sc = scope.sub() 63 | defs.foreach(assign(_, sc)) 64 | embDefs.foreach(assign(_, sc)) 65 | val refs = new RefsMap 66 | defs.foreach(refs.collectAll) 67 | if(refs.hasError) 68 | error(s"Non-linear use of variable(s) ${refs.err.map(s => s"'$s'").mkString(", ")}", n) 69 | n.copy(free = refs.free.filterNot(_.isEmbedded).map(sym => Ident(sym.id).setSym(sym)).toVector.sortBy(_.s)) 70 | case _: CheckedRule => n 71 | } 72 | 73 | private[this] def assign(n: DefRule, scope: Symbols, defRefs: RefsMap): Unit = { 74 | val sc = scope.sub() 75 | n.on.foreach(assign(_, sc)) 76 | val rm = defRefs.sub() 77 | n.on.foreach(rm.collectAll) 78 | if(rm.hasNonFree) 79 | error(s"Duplicate variable(s) in def rule pattern ${rm.nonFree.map(s => s"'$s'").mkString(", ")}", n) 80 | n.reduced.foreach(assign(_, sc)) 81 | } 82 | 83 | private[this] def assign(n: Branch, scope: Symbols): Unit = { 84 | val sc = scope.sub() 85 | n.cond.foreach(assign(_, sc)) 86 | n.reduced.foreach(assign(_, sc)) 87 | n.embRed.foreach(assign(_, sc)) 88 | } 89 | 90 | private[this] def define(n: Ident, scope: Symbols, embedded: Boolean): Symbol = { 91 | scope.get(n) match { 92 | case Some(s) => 93 | if(s.isEmbedded && !embedded) 94 | error(s"Embedded variable '$s' used in non-embedded context", n) 95 | else if(!s.isEmbedded && embedded) 96 | error(s"Non-embedded variable '$s' used in embedded context", n) 97 | s 98 | case None => 99 | scope.define(n.s, isEmbedded = 
embedded) 100 | } 101 | } 102 | 103 | private[this] def assign(n: Expr, scope: Symbols): Unit = n match { 104 | case n: Ident => 105 | assert(n.sym.isEmpty) 106 | n.sym = define(n, scope, false) 107 | case n: NatLit => 108 | n.sSym = define(Ident("S"), scope, false) 109 | n.zSym = define(Ident("Z"), scope, false) 110 | if(!n.sSym.isCons || n.sSym.arity != 1 || !n.zSym.isCons || n.zSym.arity != 0) 111 | error(s"Nat literal requires appropriate Z() and S(_) constructors", n) 112 | case Apply(id, emb, args) => 113 | assign(id: Expr, scope) 114 | emb.foreach(assign(_: EmbeddedExpr, scope)) 115 | val l = args.length 116 | if(!id.sym.isCons) 117 | error(s"Symbol '${id.sym}' in function call is not a constructor", id) 118 | else if(id.sym.callArity != l) 119 | error(s"Wrong number of arguments for '${id.sym}': got $l, expected ${id.sym.callArity}", n) 120 | if(l == 1) assign(args.head, scope) // tail-recursive call 121 | else if(l > 0) args.foreach(assign(_, scope)) 122 | case n => n.nodeChildren.foreach { 123 | case ch: Expr => assign(ch, scope) 124 | } 125 | } 126 | 127 | private[this] def assign(n: EmbeddedExpr, scope: Symbols): Unit = n match { 128 | case n: Ident => 129 | assert(n.sym.isEmpty) 130 | n.sym = define(n, scope, true) 131 | case EmbeddedApply(_, args, _, _) => 132 | args.foreach(assign(_, scope)) 133 | case n => n.nodeChildren.foreach { 134 | case ch: EmbeddedExpr => assign(ch, scope) 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/main/scala/ResolveEmbedded.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast._ 4 | 5 | import java.lang.reflect.Method 6 | 7 | /** Resolve embedded methods and assign embedded types. 
*/ 8 | class ResolveEmbedded(val global: Global) extends Transform with Phase { 9 | import global._ 10 | 11 | override def apply(_n: Branch): Branch = { 12 | val n2 = super.apply(_n) 13 | n2.cond.foreach { ee => 14 | val (tp, _) = getReturnType(ee) 15 | if(tp != EmbeddedType.Bool) 16 | error(s"Expected boolean type for condition expression, got $tp", ee) 17 | } 18 | n2 19 | } 20 | 21 | override def apply(_n: EmbeddedApply): EmbeddedApply = { 22 | val n2 = super.apply(_n) 23 | val args = n2.args.map(getReturnType) 24 | val (clsName, mName, qn) = { 25 | if(n2.op && args.length == 2 && args(0)._1 == EmbeddedType.PayloadLabel && args(1)._1 == EmbeddedType.PayloadLabel) ResolveEmbedded.eqLabel 26 | else if(n2.op) ResolveEmbedded.operators(n2.methodQN.head) 27 | else (n2.className, n2.methodName, n2.methodQNIds) 28 | } 29 | val methTp = resolveMethod(n2, clsName, mName, args) 30 | val n3 = n2.copy(methodQNIds = qn, embTp = methTp, op = false) 31 | //ShowableNode.print(n3, name = "Resolved") 32 | n3 33 | } 34 | 35 | private[this] def getReturnType(ee: EmbeddedExpr): (EmbeddedType, Boolean /* out */) = ee match { 36 | case _: StringLit => (EmbeddedType.PayloadRef, false) 37 | case _: IntLit => (EmbeddedType.PayloadInt, false) 38 | case ee: Ident => (EmbeddedType.Payload(ee.sym.payloadType), !ee.sym.isPattern) 39 | case ee: EmbeddedApply => (ee.embTp match { 40 | case t: EmbeddedType.Method => t.ret 41 | case t @ EmbeddedType.Unknown => t 42 | }, false) 43 | case _ => (EmbeddedType.PayloadVoid, false) 44 | } 45 | 46 | private[this] def resolveMethod(e: EmbeddedApply, cln: String, mn: String, args: Vector[(EmbeddedType, Boolean)]): EmbeddedType = { 47 | val c = dependencyLoader.loadClass(cln) 48 | def toPT(cl: Class[_]): (EmbeddedType, Boolean) = cl.getName match { 49 | case "void" => (EmbeddedType.PayloadVoid, false) 50 | case "int" => (EmbeddedType.PayloadInt, false) 51 | case "boolean" => (EmbeddedType.Bool, false) 52 | case s if s == classOf[IntOutput].getName => 
(EmbeddedType.PayloadInt, true) 53 | case s if s == classOf[RefOutput].getName => (EmbeddedType.PayloadRef, true) 54 | case _ if !cl.isPrimitive => (EmbeddedType.PayloadRef, false) 55 | case s => (EmbeddedType.Unknown, false) 56 | } 57 | tryError { 58 | val nameCandidates = c.getDeclaredMethods.filter(_.getName == mn).map { m => 59 | EmbeddedType.Method(m, toPT(m.getReturnType)._1, m.getParameterTypes.iterator.map(toPT).toVector) 60 | } 61 | val typeCandidates = nameCandidates.filter { m => 62 | m.args == args.map { 63 | case (EmbeddedType.Payload(PayloadType.LABEL), o) => (EmbeddedType.PayloadRef, o) 64 | case (t, o) => (t, o) 65 | } 66 | } 67 | if(nameCandidates.isEmpty) throwError(s"Method $cln.$mn not found.", e) 68 | else if(typeCandidates.isEmpty) { 69 | val exp = showEmbeddedSignature(args, EmbeddedType.Unknown, mn) 70 | val found = nameCandidates.map(m => showJavaSignature(m.method)).mkString(" ", "\n ", "") 71 | throwError(s"No applicable overload of method $cln.$mn.\nExpected:\n $exp\nFound:\n$found", e) 72 | } 73 | else if(typeCandidates.length > 1) { 74 | val exp = showEmbeddedSignature(args, EmbeddedType.Unknown, mn) 75 | val found = nameCandidates.map(m => showJavaSignature(m.method)).mkString(" ", "\n ", "") 76 | throwError(s"${typeCandidates.length} ambiguous overloads of method $cln.$mn.\nExpected:\n $exp\nAmbiguous:\n$found", e) 77 | } 78 | else typeCandidates.head 79 | }.getOrElse(EmbeddedType.Unknown) 80 | } 81 | 82 | private[this] def showJavaSignature(m: Method): String = { 83 | s"${m.getReturnType.getName} ${m.getName}(${m.getParameterTypes.map(_.getName).mkString(", ")})" 84 | } 85 | 86 | private[this] def showEmbeddedSignature(args: Vector[(EmbeddedType, Boolean)], ret: EmbeddedType, name: String): String = { 87 | def f(t: EmbeddedType, out: Boolean): String = (if (out) "out " else "") + (t match { 88 | case EmbeddedType.Payload(pt) => pt.toString 89 | case EmbeddedType.Bool => "boolean" 90 | case EmbeddedType.Unknown => "?" 
91 | case t => t.toString 92 | }) 93 | s"${f(ret, false)} ${name}(${args.map{ case (t, o) => f(t, o) }.mkString(", ")})" 94 | } 95 | } 96 | 97 | object ResolveEmbedded { 98 | private[this] val runtimeName = Runtime.getClass.getName 99 | private[this] val intrinsicsQN = runtimeName.split('.').toVector 100 | private[this] def mkOp(m: String) = (runtimeName, m, (intrinsicsQN :+ m).map(Ident(_))) 101 | private val eqLabel = mkOp("eqLabel") 102 | private val operators = Map( 103 | "==" -> mkOp("eq"), 104 | "+" -> mkOp("intAdd"), 105 | "-" -> mkOp("intSub"), 106 | "*" -> mkOp("intMult"), 107 | ) 108 | } 109 | -------------------------------------------------------------------------------- /src/main/scala/Runtime.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | trait IntOutput { def setValue(i: Int): Unit } 4 | trait LongOutput { def setValue(l: Long): Unit } 5 | trait IntBox extends IntOutput { def getValue: Int } 6 | trait LongBox extends LongOutput { def getValue: Long } 7 | trait RefOutput { def setValue(o: AnyRef): Unit } 8 | trait RefBox extends RefOutput { def getValue: AnyRef } // Also used for Label 9 | trait LifecycleManaged { def erase(): Unit; def copy(): LifecycleManaged } 10 | 11 | // Standalone boxes used for boxed temporary values in inlined payload computations 12 | final class IntBoxImpl extends IntBox { 13 | private[this] var value: Int = _ 14 | def getValue: Int = value 15 | def setValue(v: Int): Unit = value = v 16 | } 17 | final class LongBoxImpl extends LongBox { 18 | private[this] var value: Long = _ 19 | def getValue: Long = value 20 | def setValue(v: Long): Unit = value = v 21 | } 22 | final class RefBoxImpl extends RefBox { 23 | private[this] var value: AnyRef = _ 24 | def getValue: AnyRef = value 25 | def setValue(v: AnyRef): Unit = value = v 26 | } 27 | 28 | object Runtime { 29 | def add(a: Int, b: Int, res: IntOutput): Unit = res.setValue(a + b) 30 | def mult(a: 
Int, b: Int, res: IntOutput): Unit = res.setValue(a * b) 31 | def strlen(s: String): Int = s.length 32 | 33 | def eraseRef(o: AnyRef): Unit = o match { 34 | case o: LifecycleManaged => o.erase() 35 | case _ => 36 | } 37 | 38 | def dupRef(o: AnyRef, r1: RefOutput, r2: RefOutput): Unit = { 39 | r1.setValue(o) 40 | r2.setValue(o match { 41 | case o: LifecycleManaged => o.copy() 42 | case o => o 43 | }) 44 | } 45 | 46 | def eq(a: AnyRef, b: AnyRef): Boolean = a == b 47 | def eqLabel(a: AnyRef, b: AnyRef): Boolean = a eq b 48 | def eq(a: Int, b: Int): Boolean = a == b 49 | def intAdd(a: Int, b: Int): Int = a + b 50 | def intSub(a: Int, b: Int): Int = a - b 51 | def intMult(a: Int, b: Int): Int = a * b 52 | } 53 | -------------------------------------------------------------------------------- /src/main/scala/SymCounts.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact 2 | 3 | import de.szeiger.interact.ast.Symbol 4 | 5 | import scala.collection.mutable 6 | 7 | final class SymCounts(private val m: mutable.Map[Symbol, Int]) extends AnyVal { 8 | override def toString: String = { 9 | m.iterator.filter(_._2 > 0).toVector.sortBy(_._1.id).map { 10 | case (s, 1) => s.id 11 | case (s, i) => s"$i ${s.id}" 12 | }.mkString("[", ", ", "]") 13 | } 14 | 15 | def size = m.size 16 | 17 | def + (s: Symbol): SymCounts = { 18 | val n = mutable.Map.from(m) 19 | SymCounts.add(n, s, 1) 20 | new SymCounts(n) 21 | } 22 | 23 | def - (s: Symbol): SymCounts = { 24 | val n = mutable.Map.from(m) 25 | SymCounts.add(n, s, -1) 26 | new SymCounts(n) 27 | } 28 | 29 | def ++ (o: SymCounts): SymCounts = { 30 | val n = mutable.Map.from(m) 31 | o.m.foreach { case (k, v) => SymCounts.add(n, k, v) } 32 | new SymCounts(n) 33 | } 34 | 35 | def ++ (it: IterableOnce[Symbol]): SymCounts = { 36 | val n = mutable.Map.from(m) 37 | it.iterator.foreach(SymCounts.add(n, _, 1)) 38 | new SymCounts(n) 39 | } 40 | 41 | def -- (o: SymCounts): SymCounts = { 
42 | val n = mutable.Map.from(m) 43 | o.m.foreach { case (k, v) => SymCounts.add(n, k, -v) } 44 | new SymCounts(n) 45 | } 46 | 47 | def count(s: Symbol): Int = m.getOrElse(s, 0) 48 | 49 | def contains(s: Symbol): Boolean = count(s) > 0 50 | } 51 | 52 | object SymCounts { 53 | private def add(m: mutable.Map[Symbol, Int], s: Symbol, count: Int): Unit = 54 | m.updateWith(s) { 55 | case Some(i) => 56 | val j = i+count 57 | if(j > 0) Some(j) else None 58 | case None if count > 0 => Some(count) 59 | case None => None 60 | } 61 | 62 | private def add(m: mutable.Map[Symbol, Int], it: IterableOnce[Symbol]): Unit = it.iterator.foreach(add(m, _, 1)) 63 | 64 | def apply(ss: Symbol*): SymCounts = { 65 | val m = mutable.Map.empty[Symbol, Int] 66 | add(m, ss) 67 | new SymCounts(m) 68 | } 69 | 70 | def from(syms: IterableOnce[Symbol]): SymCounts = { 71 | val m = mutable.Map.empty[Symbol, Int] 72 | add(m, syms) 73 | new SymCounts(m) 74 | } 75 | 76 | val empty = from(Nil) 77 | } 78 | -------------------------------------------------------------------------------- /src/main/scala/ast/Symbols.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.ast 2 | 3 | import scala.collection.mutable 4 | 5 | final class Symbol(val id: String, val arity: Int = 0, val returnArity: Int = 1, var payloadType: PayloadType = PayloadType.VOID, 6 | val matchContinuationPort: Int = -2, private[this] var _flags: Int = 0) { 7 | import Symbol._ 8 | 9 | def flags: Int = _flags 10 | def isCons: Boolean = (_flags & FLAG_CONS) != 0 11 | def isDef: Boolean = (_flags & FLAG_DEF) != 0 12 | def isEmbedded: Boolean = (_flags & FLAG_EMBEDDED) != 0 13 | def isPattern: Boolean = (_flags & FLAG_PATTERN) != 0 14 | 15 | def setPattern(): Unit = _flags = (_flags | FLAG_PATTERN) 16 | 17 | def isContinuation: Boolean = matchContinuationPort != -2 18 | def callArity: Int = arity + 1 - returnArity 19 | def hasPayload: Boolean = payloadType != PayloadType.VOID 20 | 
override def toString: String = id
def isDefined: Boolean = this != Symbol.NoSymbol
def isEmpty: Boolean = !isDefined
// Arity 0 and no payload type.
def isSingleton: Boolean = arity == 0 && payloadType.isEmpty
// Id plus identity hash, to tell apart distinct symbols that share an id; "" for NoSymbol.
def uniqueStr: String = if(isDefined) s"$id:${System.identityHashCode(this)}" else ""
def show: String = s"$uniqueStr<$payloadType>"
}

object Symbol {
  /** Creates a Symbol, packing the boolean properties into the `FLAG_*` bit set. */
  def apply(id: String, arity: Int = 0, returnArity: Int = 1,
            isCons: Boolean = false, isDef: Boolean = false,
            payloadType: PayloadType = PayloadType.VOID, matchContinuationPort: Int = -2,
            isEmbedded: Boolean = false, isPattern: Boolean = false): Symbol =
    new Symbol(id, arity, returnArity, payloadType, matchContinuationPort,
      (if(isCons) FLAG_CONS else 0) | (if(isDef) FLAG_DEF else 0) | (if(isEmbedded) FLAG_EMBEDDED else 0) | (if(isPattern) FLAG_PATTERN else 0)
    )

  val FLAG_CONS = 1 << 0
  val FLAG_DEF = 1 << 1
  val FLAG_EMBEDDED = 1 << 2
  val FLAG_PATTERN = 1 << 3

  /** Sentinel for "no symbol"; `isDefined` compares against this instance. */
  val NoSymbol = new Symbol("")
}

/** Generates fresh symbols named `prefix + prefix2 + n` for an increasing counter `n` (starting at 1). */
class SymbolGen(prefix2: String, isEmbedded: Boolean = false, payloadType: PayloadType = PayloadType.VOID) {
  private[this] var last = 0

  /** Returns a new Symbol with the next counter value. */
  def apply(isEmbedded: Boolean = isEmbedded, payloadType: PayloadType = payloadType, prefix: String = ""): Symbol = {
    last += 1
    Symbol(prefix+prefix2+last, isEmbedded = isEmbedded, payloadType = payloadType)
  }

  /** Like `apply`, but wraps the new Symbol in an `Ident` whose `sym` is already set. */
  def id(isEmbedded: Boolean = isEmbedded, payloadType: PayloadType = payloadType, prefix: String = ""): Ident = {
    val s = apply(isEmbedded, payloadType, prefix)
    val i = Ident(s.id)
    i.sym = s
    i
  }
}

/** A scope of symbols keyed by id. Lookups fall back to the optional parent scope. */
class Symbols(parent: Option[Symbols] = None) {
  private[this] val syms = mutable.HashMap.empty[String, Symbol]

  /** Defines a new symbol in this scope. The id must not be visible in this scope or any parent. */
  def define(id: String, isCons: Boolean = false, isDef: Boolean = false, arity: Int = 0, returnArity: Int = 1,
      payloadType: PayloadType = PayloadType.VOID, matchContinuationPort: Int = -2,
      isEmbedded: Boolean = false): Symbol = {
    assert(get(id).isEmpty, s"Symbol $id is already defined")
    val sym = Symbol(id, arity, returnArity, isCons, isDef, payloadType, matchContinuationPort, isEmbedded)
    syms.put(id, sym)
    sym
  }
  def defineCons(id: String, arity: Int, payloadType: PayloadType): Symbol =
    define(id, isCons = true, arity = arity, payloadType = payloadType)
  // A def occupies argLen + retLen ports; one of them is the principal port, hence the "- 1" in the arity.
  def defineDef(id: String, argLen: Int, retLen: Int, payloadType: PayloadType): Symbol =
    define(id, isCons = true, isDef = true, arity = argLen + retLen - 1, returnArity = retLen, payloadType = payloadType)
  def getOrAdd(id: Ident): Symbol = get(id).getOrElse(define(id.s))
  def contains(id: Ident): Boolean = get(id).isDefined
  def containsLocal(id: Ident): Boolean = syms.contains(id.s)
  def get(id: Ident): Option[Symbol] = get(id.s)
  def get(id: String): Option[Symbol] = syms.get(id).orElse(parent.flatMap(_.get(id)))
  def apply(id: Ident): Symbol = apply(id.s)
  def apply(id: String): Symbol =
    get(id).getOrElse(sys.error(s"No symbol found for $id"))
  /** All symbols of this scope, followed by those of the parent chain. */
  def symbols: Iterator[Symbol] = syms.valuesIterator ++ parent.map(_.symbols).getOrElse(Iterator.empty)
  def sub(): Symbols = new Symbols(Some(this))
  override def toString: String = s"Symbols(${syms.map { case (_, v) => s"$v"}.mkString(", ")})"
}

/** Reference counts per symbol, layered over an optional parent map.
  * A count of 1 is "free", 2 is "linear". Reaching 3 records an error unless the symbol is
  * embedded and its payload type can be copied. */
class RefsMap(parent: Option[RefsMap] = None) {
  private[this] val data = mutable.Map.empty[Symbol, Int]
  private[this] val hasErr = mutable.Set.empty[Symbol]

  /** Increments the count of `s`, starting from the count visible through the parent chain. */
  def inc(s: Symbol): Unit = data.update(s, {
    val c = apply(s) + 1
    if(c == 3) {
      if(!s.isEmbedded || !s.payloadType.canCopy) hasErr += s
    }
    c
  })

  /** Symbols whose count increased in this layer, paired with the increase relative to the parent. */
  def local: Iterator[(Symbol, Int)] = data.iterator.map {case (s, c) =>
    (s, c - parent.map(_(s)).getOrElse(0))
  }.filter(_._2 > 0)

  /** All (symbol, count) pairs; entries in this layer shadow the parent's. */
  def iterator: Iterator[(Symbol, Int)] = parent match {
    case Some(r) => r.iterator.filter(t => !data.contains(t._1)) ++ data.iterator
    case None => data.iterator
  }

  def apply(s: Symbol): Int = data.getOrElse(s, parent match {
    case Some(r) => r(s)
    case None => 0
  })
  def free: Iterator[Symbol] = iterator.filter(_._2 == 1).map(_._1)
  def linear: Iterator[Symbol] = iterator.filter(_._2 == 2).map(_._1)
  def err: Iterator[Symbol] = parent match {
    case Some(r) => r.err ++ hasErr.iterator
    case None => hasErr.iterator
  }
  def nonFree: Iterator[Symbol] = iterator.filter(_._2 > 1).map(_._1)
  def hasNonFree: Boolean = hasErr.nonEmpty || linear.hasNext || parent.exists(_.hasNonFree)
  def hasError: Boolean = hasErr.nonEmpty || parent.exists(_.hasError)

  // Counts every Ident under `n` that has a defined, non-constructor symbol, restricted to
  // embedded and/or regular symbols as requested.
  private[this] def collect0(n: Node, embedded: Boolean, regular: Boolean): Unit = n match {
    case n: Ident =>
      val use = (n.sym.isEmbedded && embedded) || (!n.sym.isEmbedded && regular)
      if(use && !n.sym.isEmpty && !n.sym.isCons) inc(n.sym)
    case n => n.nodeChildren.foreach(collect0(_, embedded, regular))
  }
  def collectRegular(n: Node): Unit = collect0(n, false, true)
  def collectEmbedded(n: Node): Unit = collect0(n, true, false)
  def collectAll(n: Node): Unit = collect0(n, true, true)

  /** A child layer. An empty non-root layer is skipped, so chains do not accumulate empty levels. */
  def sub(): RefsMap = parent match {
    case Some(r) if data.isEmpty => r.sub()
    case _ => new RefsMap(Some(this))
  }
  def allSymbols: Iterator[Symbol] = iterator.map(_._1).filter(_.isDefined)
}

--------------------------------------------------------------------------------
/src/main/scala/ast/Transform.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.ast

import de.szeiger.interact.{AllocateTemp, BranchWiring, CheckPrincipal, Connection, CreateLabelsComp, EmbArg, InitialRuleWiring, PayloadAssignment, PayloadComputation, PayloadComputationPlan, PayloadMethodApplication, PayloadMethodApplicationWithReturn, ReuseLabelsComp, RuleWiring} 

/** Identity-preserving tree rewriter.
  *
  * Each `apply` overload transforms a node's children. When none of them changed (by reference),
  * it returns the original node; otherwise it returns a copy. Subclasses override the overloads for
  * the node types they want to rewrite. Statement-level overloads return a `Vector`, so a statement
  * can be removed or expanded into several. */
abstract class Transform {
  import Transform._

  def apply(n: Node): Node = n match {
    case n: Branch => apply(n)
    case n: DefRule => apply(n)
    case n: Expr => apply(n)
    case n: EmbeddedExpr => apply(n)
    case n: Statement =>
      val v = apply(n: Statement)
      assert(v.length == 1, s"Statement transform produced ${v.length} statements where exactly 1 was required")
      v.head
    case n: CompilationUnit => apply(n)
  }

  def apply(n: DefRule): DefRule = {
    val on2 = mapC(n.on)(apply)
    val red2 = mapC(n.reduced)(apply)
    if((on2 eq n.on) && (red2 eq n.reduced)) n
    else n.copy(on2, red2)
  }

  def apply(n: IdentOrWildcard): IdentOrWildcard = n match {
    case n: Ident => apply(n)
    case n: Wildcard => apply(n)
  }

  def apply(n: Apply): Apply = {
    val t2 = apply(n.target)
    val emb2 = mapOC(n.embedded)(apply)
    val args2 = mapC(n.args)(apply)
    if((t2 eq n.target) && (emb2 eq n.embedded) && (args2 eq n.args)) n
    else n.copy(t2, emb2, args2)
  }

  def apply(n: ApplyCons): ApplyCons = {
    val t2 = apply(n.target)
    val emb2 = mapOC(n.embedded)(apply)
    val args2 = mapC(n.args)(apply)
    if((t2 eq n.target) && (emb2 eq n.embedded) && (args2 eq n.args)) n
    else n.copy(t2, emb2, args2)
  }

  def apply(n: Tuple): Tuple = {
    val e2 = mapC(n.exprs)(apply)
    if(e2 eq n.exprs) n
    else n.copy(e2)
  }

  def apply(n: Assignment): Assignment = {
    val l2 = apply(n.lhs)
    val r2 = apply(n.rhs)
    if((l2 eq n.lhs) && (r2 eq n.rhs)) n
    else n.copy(l2, r2)
  }

  def apply(n: Expr): Expr = n match {
    case n: IdentOrWildcard => apply(n)
    case n: Apply => apply(n)
    case n: ApplyCons => apply(n)
    case n: Tuple => apply(n)
    case n: Assignment => apply(n)
    case n: NatLit => apply(n)
  }

  def apply(n: NatLit): NatLit = n

  def apply(n: StringLit): StringLit = n

  def apply(n: IntLit): IntLit = n

  def apply(n: Ident): Ident = n

  def apply(n: Wildcard): Wildcard = n

  def apply(n: EmbeddedApply): EmbeddedApply = {
    val m2 = mapC(n.methodQNIds)(apply)
    val args2 = mapC(n.args)(apply)
    if((m2 eq n.methodQNIds) && (args2 eq n.args)) n
    else n.copy(m2, args2, n.op, n.embTp)
  }

  def apply(n: EmbeddedAssignment): EmbeddedAssignment = {
    val l2 = apply(n.lhs)
    val r2 = apply(n.rhs)
    if((l2 eq n.lhs) && (r2 eq n.rhs)) n
    else n.copy(l2, r2)
  }

  def apply(n: CreateLabels): CreateLabels = n

  def apply(n: EmbeddedExpr): EmbeddedExpr = n match {
    case n: StringLit => apply(n)
    case n: IntLit => apply(n)
    case n: Ident => apply(n)
    case n: EmbeddedApply => apply(n)
    case n: EmbeddedAssignment => apply(n)
    case n: CreateLabels => apply(n)
  }

  def apply(n: Match): Vector[Statement] = Vector({
    val on2 = mapC(n.on)(apply)
    val red2 = mapC(n.reduced)(apply)
    if((on2 eq n.on) && (red2 eq n.reduced)) n
    else n.copy(on2, red2)
  })

  def apply(n: Cons): Vector[Statement] = Vector({
    val n2 = apply(n.name)
    val a2 = mapC(n.args)(apply)
    val e2 = mapOC(n.embeddedId)(apply)
    val r2 = mapOC(n.ret)(apply)
    val d2 = mapOC(n.der)(mapC(_)(apply))
    if((n2 eq n.name) && (a2 eq n.args) && (e2 eq n.embeddedId) && (r2 eq n.ret) && (d2 eq n.der)) n
    else n.copy(n2, a2, n.operator, n.payloadType, e2, r2, d2)
  })

  def apply(n: Let): Vector[Statement] = Vector({
    val d2 = mapC(n.defs)(apply)
    val e2 = mapC(n.embDefs)(apply)
    val f2 = mapC(n.free)(apply)
    if((d2 eq n.defs) && (e2 eq n.embDefs) && (f2 eq n.free)) n
    else n.copy(d2, e2, f2)
  })

  def apply(n: Def): Vector[Statement] = Vector({
    val n2 = apply(n.name)
    val a2 = mapC(n.args)(apply)
    val e2 = mapOC(n.embeddedId)(apply)
    val r2 = mapC(n.ret)(apply)
    val u2 = mapC(n.rules)(apply)
    if((n2 eq n.name) && (a2 eq n.args) && (e2 eq n.embeddedId) && (r2 eq n.ret) && (u2 eq n.rules)) n
    else n.copy(n2, a2, n.operator, n.payloadType, e2, r2, u2)
  })

  def apply(n: Statement): Vector[Statement] = n match {
    case n: Match => apply(n)
    case n: Cons => apply(n)
    case n: Let => apply(n)
    case n: Def => apply(n)
    case n: CheckedRule => apply(n)
    case n: RuleWiring => apply(n)
    case n: InitialRuleWiring => apply(n)
  }

  def apply(n: CheckedRule): Vector[Statement] = n match {
    case n: DerivedRule => apply(n)
    case n: MatchRule => apply(n)
  }

  def apply(n: DerivedRule): Vector[Statement] = Vector(n)

  def apply(n: MatchRule): Vector[Statement] = Vector({
    val i1 = apply(n.id1)
    val i2 = apply(n.id2)
    val a12 = mapC(n.args1)(apply)
    val a22 = mapC(n.args2)(apply)
    val emb12 = mapOC(n.emb1)(apply)
    val emb22 = mapOC(n.emb2)(apply)
    val red2 = mapC(n.branches)(apply)
    if((i1 eq n.id1) && (i2 eq n.id2) && (a12 eq n.args1) && (a22 eq n.args2) && (emb12 eq n.emb1) && (emb22 eq n.emb2) && (red2 eq n.branches)) n
    else n.copy(i1, i2, a12, a22, emb12, emb22, red2)
  })

  def apply(n: RuleWiring): Vector[Statement] = Vector({
    val b2 = mapC(n.branches)(apply)
    if(b2 eq n.branches) n
    else n.copy(n.sym1, n.sym2, b2, n.derived)
  })

  def apply(n: InitialRuleWiring): Vector[Statement] = Vector({
    val b2 = apply(n.branch)
    if(b2 eq n.branch) n
    else n.copy(branch = b2)
  })

  def apply(n: BranchWiring): BranchWiring = {
    val co2 = flatMapC(n.conns)(apply)
    val pc2 = flatMapOC(n.payloadComps)(apply)
    val cn2 = flatMapOC(n.cond)(apply)
    val bw2 = mapC(n.branches)(apply)
    if((co2 eq n.conns) && (pc2 eq n.payloadComps) && (cn2 eq n.cond) && (bw2 eq n.branches)) n
    else n.copy(n.cellOffset, n.cells, co2, pc2, cn2, bw2)
  }

  def apply(n: Connection): Set[Connection] = Set(n)

  def apply(n: CompilationUnit): CompilationUnit = {
    val st2 = flatMapC(n.statements)(apply)
    if(st2 eq n.statements) n
    else n.copy(st2)
  }

  def apply(n: Branch): Branch = {
    val cond2 = mapOC(n.cond)(apply)
    val embRed2 = mapC(n.embRed)(apply)
    val red2 = mapC(n.reduced)(apply)
    if((cond2 eq n.cond) && (embRed2 eq n.embRed) && (red2 eq n.reduced)) n
    else n.copy(cond2, embRed2, red2)
  }

  def apply(n: EmbArg): EmbArg = n

  def apply(n: PayloadComputationPlan): Option[PayloadComputationPlan] = n match {
    case n: PayloadComputation => apply(n)
    case n: ReuseLabelsComp => Some(apply(n))
    case n: AllocateTemp => Some(apply(n))
  }

  def apply(n: PayloadComputation): Option[PayloadComputation] = Some(n match {
    case n: PayloadMethodApplication => apply(n)
    case n: PayloadMethodApplicationWithReturn => apply(n)
    case n: PayloadAssignment => apply(n)
    case n: CreateLabelsComp => apply(n)
    case n: CheckPrincipal => apply(n)
  })

  def apply(n: CheckPrincipal): CheckPrincipal = n

  def apply(n: PayloadMethodApplication): PayloadMethodApplication = {
    val ea2 = mapC(n.embArgs)(apply)
    if (ea2 eq n.embArgs) n
    else n.copy(embArgs = ea2)
  }

  def apply(n: PayloadMethodApplicationWithReturn): PayloadComputation = {
    val m2 = apply(n.method)
    val ri2 = apply(n.retIndex)
    if((m2 eq n.method) && (ri2 eq n.retIndex)) n
    else n.copy(method = m2, retIndex = ri2)
  }

  def apply(n: PayloadAssignment): PayloadComputation = {
    val si2 = apply(n.sourceIdx)
    val ti2 = apply(n.targetIdx)
    if((si2 eq n.sourceIdx) && (ti2 eq n.targetIdx)) n
    else n.copy(sourceIdx = si2, targetIdx = ti2)
  }

  def apply(n: CreateLabelsComp): PayloadComputation = {
    val ea2 = mapC(n.embArgs)(apply)
    if (ea2 eq n.embArgs) n
    else n.copy(embArgs = ea2)
  }

  def apply(n: ReuseLabelsComp): PayloadComputationPlan = {
    val ea2 = mapC(n.embArgs)(apply)
    if (ea2 eq n.embArgs) n
    else n.copy(embArgs = ea2)
  }

  def apply(n: AllocateTemp): PayloadComputationPlan = {
    val ea2 = apply(n.ea)
    if (ea2 eq n.ea) n
    else n.copy(ea = ea2.asInstanceOf[EmbArg.Temp])
  }
}

/** Change-tracking collection combinators used by [[Transform]].
  * Each returns its input collection itself (same reference) when `f` left every element unchanged,
  * so callers can detect "no change" with `eq`. */
object Transform {
  /** `xs.map(f)`, or `xs` itself if `f(x) eq x` for every element. */
  def mapC[T,R](xs: Vector[T])(f: T => R): Vector[R] = {
    var changed = false
    val xs2 = xs.map { x =>
      val x2 = f(x)
      if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) changed = true
      x2
    }
    if(changed) xs2 else xs.asInstanceOf[Vector[R]]
  }

  /** `xs.flatMap(f)`, or `xs` itself if every `f(x)` is exactly `Vector(x)` (the same single element). */
  def flatMapC[T,R](xs: Vector[T])(f: T => Vector[R]): Vector[R] = {
    var changed = false
    val xs2 = xs.flatMap { x =>
      val x2 = f(x)
      // Compare the produced elements with `x`, not the freshly built Vector itself (which is never `eq` to `x`).
      if(x2.lengthCompare(1) != 0 || (x2.head.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef])) changed = true
      x2
    }
    if(changed) xs2 else xs.asInstanceOf[Vector[R]]
  }

  /** `xs.flatMap(f)`, or `xs` itself if every `f(x)` is exactly `Set(x)` (the same single element). */
  def flatMapC[T,R](xs: Set[T])(f: T => Set[R]): Set[R] = {
    var changed = false
    val xs2 = xs.flatMap { x =>
      val x2 = f(x)
      if(x2.size != 1 || (x2.head.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef])) changed = true
      x2
    }
    if(changed) xs2 else xs.asInstanceOf[Set[R]]
  }

  /** `xs.flatMap(f)`, or `xs` itself if every `f(x)` is `Some` of the same element `x`. */
  def flatMapOC[T,R](xs: Vector[T])(f: T => Option[R]): Vector[R] = {
    var changed = false
    val xs2 = xs.flatMap { x =>
      val x2 = f(x)
      x2 match {
        case Some(y) if y.asInstanceOf[AnyRef] eq x.asInstanceOf[AnyRef] => ()
        case _ => changed = true
      }
      x2
    }
    if(changed) xs2 else xs.asInstanceOf[Vector[R]]
  }

  /** `o.map(f)`, or `o` itself if it is `None` or `f(x) eq x`. */
  def mapOC[T,R](o: Option[T])(f: T => R): Option[R] = o match {
    case None => o.asInstanceOf[Option[R]]
    case Some(x) =>
      val x2 = f(x)
      if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) Some(x2)
      else o.asInstanceOf[Option[R]]
  }

  /** `o.flatMap(f)`, or `o` itself if it is `None` or `f(x)` is `Some` of the same element. */
  def flatMapOC[T,R](o: Option[T])(f: T => Option[R]): Option[R] = o match {
    case None => o.asInstanceOf[Option[R]]
    case Some(x) =>
      f(x) match {
        case s @ Some(x2) if x.asInstanceOf[AnyRef] ne x2.asInstanceOf[AnyRef] => s
        case _ => o.asInstanceOf[Option[R]]
      }
  }
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/AbstractCodeGen.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen

import de.szeiger.interact.Config
import de.szeiger.interact.ast.Symbol
import de.szeiger.interact.codegen.dsl.{Desc => tp, _}
import org.objectweb.asm.util.{CheckClassAdapter, Textifier, TraceClassVisitor}
import org.objectweb.asm.{ClassReader, ClassWriter => AClassWriter}

import java.io.{OutputStreamWriter, PrintWriter}
import java.util.zip.CRC32

/** Shared helpers for bytecode generators: class emission/logging and Symbol reification. */
abstract class AbstractCodeGen(config: Config) {
  import AbstractCodeGen._

  private[this] def getCRC32(a: Array[Byte]): Long = {
    val crc = new CRC32
    crc.update(a)
    crc.getValue
  }

  /** Assembles `cls`, with frames computed and checked by CheckClassAdapter. Optionally logs a summary
    * or a textual trace, depending on `config`, then hands the bytes to `cl`. */
  protected def addClass(cl: ClassWriter, cls: ClassDSL): Unit = {
    val cw = new AClassWriter(AClassWriter.COMPUTE_FRAMES)
    val ca = new CheckClassAdapter(cw)
    cls.accept(ca)
    val raw = cw.toByteArray
    if(config.logCodeGenSummary) println(s"Generated class ${cls.name} (${raw.length} bytes, crc ${getCRC32(raw)})")
    if(config.logGeneratedClasses.exists(s => s == "*" || cls.name.contains(s))) {
      val cr = new ClassReader(raw)
      // Print only. Do not delegate to `cw`: its bytes were already taken, and re-feeding the class
      // into it would just rebuild a duplicate in memory.
      cr.accept(new TraceClassVisitor(null, new Textifier(), new PrintWriter(new OutputStreamWriter(System.out))), 0)
    }
    cl.writeClass(cls.javaName, raw)
  }

  // Create a new Symbol instance that matches the given Symbol and place it on the stack
  protected def reifySymbol(m: MethodDSL, sym: Symbol): MethodDSL = {
    if(sym.isEmpty) m.invokestatic(symbol_NoSymbol)
    else m.newInitDup(new_Symbol) {
      m.ldc(sym.id).iconst(sym.arity).iconst(sym.returnArity)
      m.iconst(sym.payloadType.value).iconst(sym.matchContinuationPort)
      m.iconst(sym.flags)
    }
  }
}

object AbstractCodeGen {
  val symbolT = tp.c[Symbol]
  val symbol_NoSymbol = symbolT.method("NoSymbol", tp.m()(symbolT))
  val new_Symbol = symbolT.constr(tp.m(tp.c[String], tp.I, tp.I, tp.I, tp.I, tp.I).V)

  // Replaces operator characters with the same `$name` encodings scalac uses, keeping other characters.
  private[this] def encodeName(s: String): String = {
    val b = new StringBuilder()
    s.foreach {
      case '|' => b.append("$bar")
      case '^' => b.append("$up")
      case '&' => b.append("$amp")
      case '<' => b.append("$less")
      case '>' => b.append("$greater")
      case ':' => b.append("$colon")
      case '+' => b.append("$plus")
      case '-' => b.append("$minus")
      case '*' => b.append("$times")
      case '/' => b.append("$div")
      case '%' => b.append("$percent")
      case c => b.append(c)
    }
    b.result()
  }

  /** Encodes the id of a defined symbol for use in generated names. */
  def encodeName(s: Symbol): String = {
    assert(s.isDefined)
    encodeName(s.id)
  }
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/BoxOps.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen

import de.szeiger.interact.{IntBox, IntBoxImpl, LongBox, LongBoxImpl, RefBox, RefBoxImpl}
import de.szeiger.interact.codegen.dsl.{Desc => tp, _}

/** Describes a box interface, its implementation class and the unboxed value type. */
class BoxDesc private (val boxT: InterfaceOwner, val implT: ClassOwner, val unboxedT: ValDesc) {
  lazy val newBox = implT.constr(tp.m().V)
  lazy val getValue = boxT.method("getValue", tp.m()(unboxedT))
  lazy val setValue = boxT.method("setValue", tp.m(unboxedT).V)

  /** Adds a public `value` field with getter and setter to `c`, except for the void box. */
  def implementInterface(c: ClassDSL, owner: ClassOwner): Unit = {
    if(unboxedT != tp.V) {
      val field = owner.field("value", unboxedT)
      c.field(Acc.PUBLIC, field)
      c.setter(field)
      c.getter(field)
    }
  }
}

object BoxDesc {
  val intDesc = new BoxDesc(tp.i[IntBox], tp.c[IntBoxImpl], tp.I)
  val longDesc = new BoxDesc(tp.i[LongBox], tp.c[LongBoxImpl], tp.J)
  val refDesc = new BoxDesc(tp.i[RefBox], tp.c[RefBoxImpl], tp.Object)
  // No box types at all: only `unboxedT` (V) is meaningful for this descriptor.
  val voidDesc = new BoxDesc(null, null, tp.V)
}

/** Emits box-related instructions into `m` for one box type. */
class BoxOps(m: MethodDSL, val boxDesc: BoxDesc) {
  protected[this] lazy val t: TypedDSL = TypedDSL(boxDesc.unboxedT, m)

  def unboxedT: ValDesc = boxDesc.unboxedT

  def load(v: VarIdx) = t.xload(v)
  def store(v: VarIdx) = t.xstore(v)

  def unboxedClass: Class[_] = boxDesc.unboxedT.jvmClass.get
  def boxedClass: Class[_] = boxDesc.boxT.jvmClass.get

  def getBoxValue = m.invoke(boxDesc.getValue)
  def setBoxValue = m.invoke(boxDesc.setValue)
  def newBoxStore(name: String): VarIdx = m.newInitDup(boxDesc.newBox)().storeLocal(boxDesc.implT, name)
  def newBoxStoreDup: VarIdx = m.newInitDup(boxDesc.newBox)().dup.storeLocal(boxDesc.implT)
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/ClassWriter.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen

import de.szeiger.interact.Config
import org.jetbrains.java.decompiler.main.decompiler.ConsoleDecompiler

import java.io.{File, FileOutputStream}
import java.nio.file.{Files, LinkOption, Path}
import java.util.jar.{JarEntry, JarOutputStream}

/** Sink for generated class files. */
trait ClassWriter { self =>
  def writeClass(javaName: String, classFile: Array[Byte]): Unit
  def close(): Unit = ()

  /** A writer that forwards every class (and `close`) to this writer first, then to `w`. */
  def and(w: ClassWriter): ClassWriter = new ClassWriter {
    def writeClass(javaName: String, classFile: Array[Byte]): Unit = {
      self.writeClass(javaName, classFile)
      w.writeClass(javaName, classFile)
    }
    override def close(): Unit = {
      self.close()
      w.close()
    }
  }
}

final class NullClassWriter extends ClassWriter {
  def writeClass(javaName: String, classFile: Array[Byte]): Unit = ()
}

/** Writes classes into a jar file, replacing any existing file at `file`. */
final class JarClassWriter(file: Path) extends ClassWriter {
  if(Files.exists(file)) Files.delete(file)
  private[this] val out = new JarOutputStream(new FileOutputStream(file.toFile))
  private[this] var closed = false

  def writeClass(javaName: String, classFile: Array[Byte]): Unit = {
    assert(!closed)
    out.putNextEntry(new JarEntry(javaName.replace('.', '/')+".class"))
    out.write(classFile)
    out.closeEntry()
  }

  override def close(): Unit = {
    out.close()
    closed = true
  }
}

/** Writes classes as `.class` files into a package directory tree under `dir`, which is cleared first. */
final class ClassDirWriter(dir: Path) extends ClassWriter {
  ClassWriter.delete(dir)

  def writeClass(javaName: String, classFile: Array[Byte]): Unit = {
    val p = javaName.split('.').foldLeft(dir) { case (p, s) => p.resolve(s) }
    Files.createDirectories(p.getParent)
    val f = p.resolveSibling(p.getFileName + ".class")
    Files.write(f, classFile)
  }
}

/** Forwards to `parent`. On close, decompiles the classes at `classes` into Java sources at `out`. */
private final class DecompileAdapter(parent: ClassWriter, classes: Path, out: Path) extends ClassWriter {
  def writeClass(javaName: String, classFile: Array[Byte]): Unit = parent.writeClass(javaName, classFile)

  override def close(): Unit = {
    parent.close()
    ClassWriter.delete(out)
    Files.createDirectories(out)
    val args = Array(classes.toAbsolutePath.toFile.getPath, out.toAbsolutePath.toFile.getPath)
    ConsoleDecompiler.main(args)
  }
}

/** Drops classes whose name is in `skip` (or all classes if `skip` contains "*"). */
private final class SkipClassWriter(parent: ClassWriter, skip: Set[String]) extends ClassWriter {
  def writeClass(javaName: String, classFile: Array[Byte]): Unit =
    if(!skip.contains(javaName) && !skip.contains("*")) parent.writeClass(javaName, classFile)
  override def close(): Unit = parent.close()
}

object ClassWriter {
  /** Recursively deletes `p` if it exists.
    * Symbolic links are deleted themselves and never followed, so a link inside an output directory
    * cannot cause files outside of it to be removed. The directory listing stream is always closed. */
  private[codegen] def delete(p: Path): Unit = if(Files.exists(p, LinkOption.NOFOLLOW_LINKS)) {
    if(Files.isDirectory(p, LinkOption.NOFOLLOW_LINKS)) {
      val children = Files.list(p)
      try children.forEach(delete) finally children.close()
    }
    Files.delete(p)
  }

  /** Builds the writer chain described by `config` on top of `parent`: optional skipping, optional jar or
    * class-directory output, and optional decompilation into Java sources.
    * @throws IllegalArgumentException if Java output is requested without a class output location */
  def apply(config: Config, parent: ClassWriter): ClassWriter = {
    val p = if(config.skipCodeGen.nonEmpty) new SkipClassWriter(parent, config.skipCodeGen) else parent
    val cw: ClassWriter = config.writeOutput match {
      case Some(f) if f.getFileName.toString.endsWith(".jar") => p.and(new JarClassWriter(f))
      case Some(f) => p.and(new ClassDirWriter(f))
      case None => p
    }
    (config.writeJava, config.writeOutput) match {
      case (Some(_), None) =>
        // The decompiler reads the class files back from writeOutput, so that location is required.
        throw new IllegalArgumentException("writeJava requires writeOutput to be set")
      case (Some(sources), Some(classes)) => new DecompileAdapter(cw, classes, sources)
      case _ => cw
    }
  }
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/ParSupport.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen

import java.lang.invoke.VarHandle
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{CountedCompleter, ForkJoinPool}
import scala.annotation.tailrec

object ParSupport {
  /** The common pool's parallelism, capped at 8. */
  val defaultParallelism: Int = math.min(ForkJoinPool.commonPool().getParallelism, 8)

  /** An integer that can only be raised, via a lock-free compare-and-set loop. */
  final class AtomicCounter {
    private[this] val ai = new AtomicInteger(0)

    /** Raises the stored value to `i` if `i` is larger. */
    @tailrec def max(i: Int): Unit = {
      val v = ai.get()
      if(i > v && !ai.compareAndSet(v, i)) max(i)
    }

    def get: Int = ai.get()
  }

  /** Applies `f` to every element of `a` and returns when all applications have finished.
    *
    * With `parallelism <= 1` the elements are processed sequentially, in order. Otherwise `parallelism`
    * ForkJoin tasks pull elements from a shared, synchronized iterator, and the order is unspecified.
    * Elements must not be null, because null marks the end of input. */
  def foreach[T >: Null <: AnyRef](a: IterableOnce[T], parallelism: Int)(f: T => Unit): Unit = {
    if(parallelism <= 1) a.iterator.foreach(f)
    else {
      val it = a.iterator
      def getNext: T = it.synchronized { if(it.hasNext) it.next() else null }
      final class Task(parent: CountedCompleter[_]) extends CountedCompleter[Null](parent) {
        @tailrec def compute: Unit = getNext match {
          case null =>
            VarHandle.releaseFence()
            propagateCompletion()
          case v =>
            f(v)
            compute
        }
      }
      // The root's pending count covers the forked tasks; its own propagateCompletion() accounts for
      // the final step, so the root completes only after every Task has propagated.
      (new CountedCompleter[Null](null, parallelism) {
        override def compute(): Unit = {
          for(i <- 1 to parallelism) new Task(this).fork()
          propagateCompletion()
        }
      }).invoke()
    }
  }
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/dsl/Acc.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen.dsl

import org.objectweb.asm.Opcodes._

/** Builder accessors: each returns a new [[Acc]] with the corresponding ASM `ACC_*` bit OR-ed into `acc`. */
trait AccessFlags extends Any {
  protected[this] def acc: Int
  @inline def PUBLIC = new Acc(acc | ACC_PUBLIC)
  @inline def PRIVATE = new Acc(acc | ACC_PRIVATE)
  @inline def PROTECTED = new Acc(acc | ACC_PROTECTED)
  @inline def STATIC = new Acc(acc | ACC_STATIC)
  @inline def FINAL = new Acc(acc | ACC_FINAL)
  @inline def SUPER = new Acc(acc | ACC_SUPER)
  @inline def SYNCHRONIZED = new Acc(acc | ACC_SYNCHRONIZED)
  @inline def OPEN = new Acc(acc | ACC_OPEN)
  @inline def TRANSITIVE = new Acc(acc | ACC_TRANSITIVE)
  @inline def VOLATILE = new Acc(acc | ACC_VOLATILE)
  @inline def BRIDGE = new Acc(acc | ACC_BRIDGE)
  @inline def STATIC_PHASE = new Acc(acc | ACC_STATIC_PHASE)
  @inline def VARARGS = new Acc(acc | ACC_VARARGS)
  @inline def TRANSIENT = new Acc(acc | ACC_TRANSIENT)
  @inline def NATIVE = new Acc(acc | ACC_NATIVE)
  @inline def INTERFACE = new Acc(acc | ACC_INTERFACE)
  @inline def ABSTRACT = new Acc(acc | ACC_ABSTRACT)
  @inline def STRICT = new Acc(acc | ACC_STRICT)
  @inline def SYNTHETIC = new Acc(acc | ACC_SYNTHETIC)
  @inline def ANNOTATION = new Acc(acc | ACC_ANNOTATION)
  @inline def ENUM = new Acc(acc | ACC_ENUM)
  @inline def MANDATED = new Acc(acc | ACC_MANDATED)
  @inline def MODULE = new Acc(acc | ACC_MODULE)
  @inline def RECORD = new Acc(acc | ACC_RECORD)
  @inline def DEPRECATED = new Acc(acc | ACC_DEPRECATED)
}

/** A set of JVM access flags (value class over the raw bit mask). */
final class Acc(val acc: Int) extends AnyVal with AccessFlags {
  def | (other: Acc): Acc = new Acc(acc | other.acc)
  /** True if every flag in `other` is also set here. */
  def has (other: Acc): Boolean = (acc & other.acc) == other.acc
}

object Acc extends AccessFlags {
  protected[this] final val acc = 0
  @inline def none: Acc = new Acc(0)
}

--------------------------------------------------------------------------------
/src/main/scala/codegen/dsl/Desc.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen.dsl

import org.objectweb.asm.Type

import scala.reflect.ClassTag

/** A JVM descriptor string. Equality and hashing are based on the descriptor alone. */
abstract class Desc {
  def desc: String
  def isArray: Boolean = desc.startsWith("[")
  def isMethod: Boolean = desc.startsWith("(")
  def isClass: Boolean = desc.startsWith("L")
  override def hashCode() = desc.hashCode
  override def equals(obj: Any) = obj match {
    case o: Desc => desc == o.desc
    case _ => false
  }
}
abstract class MethodDesc extends Desc
abstract class ValDesc extends Desc {
  /** The array type with this element type. */
  def a: ValDesc = new Desc.ValDescImpl("["+desc, jvmClass.map(c => c.arrayType()))
  /** Number of JVM stack/local slots: 2 for long and double, otherwise 1. */
  def width: Int = {
    val d = desc
    if(d == "J" || d == "D") 2
    else 1
  }
  def jvmClass: Option[Class[_]]
}
final class PrimitiveValDesc(val desc: String, val jvmType: Class[_]) extends ValDesc {
  def jvmClass = Some(jvmType)
}

/** Factories for descriptors: primitive constants, method descriptors (`m(...)`), and class/interface owners. */
object Desc {
  import java.lang.{String => JString}
  private[dsl] class ValDescImpl(val desc: JString, val jvmClass: Option[Class[_]]) extends ValDesc
  private[this] class MethodDescImpl(val desc: JString) extends MethodDesc
  /** Parameter list of a method descriptor; pick the return type to build the full descriptor. */
  class MethodArgs private[Desc] (params: Seq[ValDesc]) {
    private[this] def d = params.iterator.map(_.desc).mkString("(", "", ")")
    def B: MethodDesc = new MethodDescImpl(s"${d}B")
    def Z: MethodDesc = new MethodDescImpl(s"${d}Z")
    def C: MethodDesc = new MethodDescImpl(s"${d}C")
    def I: MethodDesc = new MethodDescImpl(s"${d}I")
    def S: MethodDesc = new MethodDescImpl(s"${d}S")
    def D: MethodDesc = new MethodDescImpl(s"${d}D")
    def F: MethodDesc = new MethodDescImpl(s"${d}F")
    def J: MethodDesc = new MethodDescImpl(s"${d}J")
    def V: MethodDesc = new MethodDescImpl(s"${d}V")
    def apply(ret: ValDesc): MethodDesc = new MethodDescImpl(s"${d}${ret.desc}")
  }
  val B: PrimitiveValDesc = new PrimitiveValDesc("B", java.lang.Byte.TYPE)
  val Z: PrimitiveValDesc = new PrimitiveValDesc("Z", java.lang.Boolean.TYPE)
  val C: PrimitiveValDesc = new PrimitiveValDesc("C", java.lang.Character.TYPE)
  val I: PrimitiveValDesc = new PrimitiveValDesc("I", java.lang.Integer.TYPE)
  val S: PrimitiveValDesc = new PrimitiveValDesc("S", java.lang.Short.TYPE)
  val D: PrimitiveValDesc = new PrimitiveValDesc("D", java.lang.Double.TYPE)
  val F: PrimitiveValDesc = new PrimitiveValDesc("F", java.lang.Float.TYPE)
  val J: PrimitiveValDesc = new PrimitiveValDesc("J", java.lang.Long.TYPE)
  val V: PrimitiveValDesc = new PrimitiveValDesc("V", java.lang.Void.TYPE)
  val Object: ClassOwner = c[AnyRef]
  val String: ClassOwner = c[String]
  def m(desc: JString): MethodDesc = new MethodDescImpl(desc)
  def m(jMethod: java.lang.reflect.Method): MethodDesc = m(Type.getMethodDescriptor(jMethod))
  def m(params: ValDesc*): MethodArgs = new MethodArgs(params)
  def c(className: JString): ClassOwner = new ClassOwner(className, None)
  def c(cls: Class[_]): ClassOwner = ClassOwner(cls)
  def c[T : ClassTag]: ClassOwner = ClassOwner.apply[T]
  def i(className: JString): InterfaceOwner = new InterfaceOwner(className, None)
  def i(cls: Class[_]): InterfaceOwner = InterfaceOwner(cls)
  def i[T : ClassTag]: InterfaceOwner = InterfaceOwner.apply[T]
  def o(cls: Class[_]): Owner = Owner(cls)
  def o[T : ClassTag]: Owner = Owner.apply[T]
}

/** A class or interface type, identified by its internal (slash-separated) name. */
sealed abstract class Owner extends ValDesc {
  def className: String
  def isInterface: Boolean
  final override def toString: String = className
  def desc: String = s"L$className;"

  def method(name: String, desc: MethodDesc): MethodRef = new MethodRef(this, name, desc)
  def field(name: String, desc: ValDesc): FieldRef = new FieldRef(this, name, desc)
}
object Owner {
  def apply[T](implicit ct: ClassTag[T]): Owner = apply(ct.runtimeClass)
  def apply(cls: Class[_]): Owner =
    if(cls.isInterface) InterfaceOwner(cls) else ClassOwner(cls)
}
class ClassOwner(val className: String, val jvmClass: Option[Class[_]]) extends Owner {
  def javaName = className.replace('/', '.')
  def isInterface = false
  def constr(desc: MethodDesc): ConstructorRef = new ConstructorRef(this, desc)
}
object ClassOwner {
  def apply[T](implicit ct: ClassTag[T]): ClassOwner = apply(ct.runtimeClass)
  def apply(cls: Class[_]): ClassOwner = {
    assert(!cls.isInterface)
    new ClassOwner(cls.getName.replace('.', '/'), Some(cls))
  }
}
class InterfaceOwner(val className: String, val jvmClass: Option[Class[_]]) extends Owner {
  def isInterface = true
}
object InterfaceOwner {
  def apply[T](implicit ct: ClassTag[T]): InterfaceOwner = apply(ct.runtimeClass)
  def apply(cls: Class[_]): InterfaceOwner = {
    assert(cls.isInterface)
    new InterfaceOwner(cls.getName.replace('.', '/'), Some(cls))
  }
}

case class MethodRef(owner: Owner, name: String, desc: MethodDesc) {
  /** The same method name and descriptor, on a different owner. */
  def on(owner: Owner): MethodRef = new MethodRef(owner, name, desc)
}
object MethodRef {
  def apply(jMethod: java.lang.reflect.Method): MethodRef =
    Desc.o(jMethod.getDeclaringClass).method(jMethod.getName, Desc.m(jMethod))
}

case class ConstructorRef(tpe: ClassOwner, desc: MethodDesc)
case class FieldRef(owner: Owner, name: String, desc: ValDesc)

--------------------------------------------------------------------------------
/src/main/scala/codegen/dsl/TypedDSL.scala:
--------------------------------------------------------------------------------
package de.szeiger.interact.codegen.dsl

import org.objectweb.asm.Label
import org.objectweb.asm.Opcodes._

object TypedDSL {
  /** Selects the instruction helper for the value descriptor `d`: a dedicated helper for "I", "J",
    * "F" and "D", and the reference helper for class and array descriptors.
    * Any other descriptor fails with a `MatchError`. */
  def apply(d: ValDesc, m: MethodDSL): TypedDSL = {
    val descriptor = d.desc
    descriptor match {
      case "J" => new LongTypedDSL(m)
      case "I" => new IntTypedDSL(m)
      case "D" => new DoubleTypedDSL(m)
      case "F" => new FloatTypedDSL(m)
      case _ if d.isArray || d.isClass => new RefTypedDSL(m, d)
    }
  }
}

/** Emits into `m` the instruction variants that depend on one value type (`desc`).
  * Every emitting method returns `this` for chaining, except the `if_*` family, which returns the
  * [[ThenDSL]] for the conditional jump. */
abstract class TypedDSL(m: MethodDSL) {
  def desc: ValDesc
  /** Selects dup2/pop2 instead of dup/pop in `xdup` and `xpop`. */
  def wide: Boolean
  def xload(v: VarIdx): this.type
  def xstore(v: VarIdx): this.type
  def xaload: this.type
  def xastore: this.type
  def xreturn: this.type
  /** Pushes this type's zero / null constant. */
  def zero: this.type

  def xdup: this.type = {
    if(wide) m.dup2
    else m.dup
    this
  }

  def xpop: this.type = {
    if(wide) m.pop2
    else m.pop
    this
  }

  def if_== : ThenDSL
  def if_!= : ThenDSL
}

/** Reference-typed (class or array) instruction variants. */
class RefTypedDSL(m: MethodDSL, val desc: ValDesc) extends TypedDSL(m) {
  def wide: Boolean = false

  def xload(v: VarIdx): this.type = {
    m.aload(v)
    this
  }

  def xstore(v: VarIdx): this.type = {
    m.astore(v)
    this
  }

  def xaload: this.type = {
    m.aaload
    this
  }

  def xastore: this.type = {
    m.aastore
    this
  }

  def xreturn: this.type = {
    m.areturn
    this
  }

  def zero: this.type = {
    m.aconst_null
    this
  }

  def if_== : ThenDSL = compareWithNull(IF_ACMPEQ, IF_ACMPNE)
  def if_!= : ThenDSL = compareWithNull(IF_ACMPNE, IF_ACMPEQ)

  // Pushes null so the top-of-stack reference can be compared against it with the given reference-compare opcodes.
  private[this] def compareWithNull(pos: Int, neg: Int): ThenDSL = {
    m.aconst_null
    new ThenDSL(pos, neg, m, new Label)
  }
}

/** Primitive numeric instruction variants. The comparisons emit the subclass's `xcmp` prefix and then
  * an int conditional jump (`IF_ICMP*`). */
abstract class NumericTypedDSL(m: MethodDSL) extends TypedDSL(m) {
  def desc: PrimitiveValDesc

  // Arithmetic
  def xadd: this.type
  def xsub: this.type
  def xmul: this.type
  def xdiv: this.type
  def xrem: this.type
  def xneg: this.type

  // Conversions
  def x2i: this.type
  def x2l: this.type
  def x2f: this.type
  def x2d: this.type

  /** Type-specific instruction(s) emitted before the `IF_ICMP*` jump (nothing for int). */
  protected[this] def xcmp: Unit

  /** Pushes `zero`, then compares as in `ifX`. */
  protected[this] def if0(pos: Int, neg: Int): ThenDSL = {
    zero
    ifX(pos, neg)
  }

  /** Emits `xcmp`, then returns a ThenDSL for the jump opcodes `pos` / `neg`. */
  protected[this] def ifX(pos: Int, neg: Int): ThenDSL = {
    xcmp
    new ThenDSL(pos, neg, m, new Label)
  }

  // Comparison of the top-of-stack value against zero
  def if_== : ThenDSL = if0(IF_ICMPEQ, IF_ICMPNE)
  def if_!= : ThenDSL = if0(IF_ICMPNE, IF_ICMPEQ)
  def if_< : ThenDSL = if0(IF_ICMPLT, IF_ICMPGE)
  def if_> : ThenDSL = if0(IF_ICMPGT, IF_ICMPLE)
  def if_<= : ThenDSL = if0(IF_ICMPLE, IF_ICMPGT)
  def if_>= : ThenDSL = if0(IF_ICMPGE, IF_ICMPLT)

  // Comparison of two values already on the stack
  def ifX_== : ThenDSL = ifX(IF_ICMPEQ, IF_ICMPNE)
  def ifX_!= : ThenDSL = ifX(IF_ICMPNE, IF_ICMPEQ)
  def ifX_< : ThenDSL = ifX(IF_ICMPLT, IF_ICMPGE)
  def ifX_> : ThenDSL = ifX(IF_ICMPGT, IF_ICMPLE)
  def ifX_<= : ThenDSL = ifX(IF_ICMPLE, IF_ICMPGT)
  def ifX_>= : ThenDSL = ifX(IF_ICMPGE, IF_ICMPLT)
}

class IntTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) {
  protected[this] def xcmp: Unit = ()

  def desc: PrimitiveValDesc = Desc.I
  def wide: Boolean = false

  def xload(v: VarIdx) : this.type = { m.iload(v); this }
  def xstore(v: VarIdx): this.type = { m.istore(v); this }
  def xaload : this.type = { m.iaload; this }
  def xastore : this.type = { m.iastore; this }
  def xreturn : this.type = { m.ireturn; this }
  def zero : this.type = { m.iconst(0); this }
  def xadd : this.type = { m.iadd; this }
  def xmul : this.type = { m.imul; this }
  def xsub : this.type = { m.isub; this }
  def xdiv : this.type = { m.idiv; this }
  def xneg : this.type = { m.ineg; this }
  def xrem : this.type = { m.irem; this }
  def x2f : this.type = { m.i2f; this }
  def x2i : this.type = this
  def x2l : this.type =
{ m.i2l; this } 104 | def x2d : this.type = { m.l2d; this } 105 | 106 | override def if_== : ThenDSL = new ThenDSL(IFEQ, IFNE, m, new Label) 107 | override def if_!= : ThenDSL = new ThenDSL(IFNE, IFEQ, m, new Label) 108 | override def if_< : ThenDSL = new ThenDSL(IFLT, IFGE, m, new Label) 109 | override def if_> : ThenDSL = new ThenDSL(IFGT, IFLE, m, new Label) 110 | override def if_<= : ThenDSL = new ThenDSL(IFLE, IFGT, m, new Label) 111 | override def if_>= : ThenDSL = new ThenDSL(IFGE, IFLT, m, new Label) 112 | } 113 | 114 | class LongTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) { 115 | protected[this] def xcmp: Unit = m.lcmp 116 | 117 | def desc: PrimitiveValDesc = Desc.J 118 | def wide: Boolean = true 119 | 120 | def xload(v: VarIdx) : this.type = { m.lload(v); this } 121 | def xstore(v: VarIdx): this.type = { m.lstore(v); this } 122 | def xaload : this.type = { m.laload; this } 123 | def xastore : this.type = { m.lastore; this } 124 | def xreturn : this.type = { m.lreturn; this } 125 | def zero : this.type = { m.lconst(0); this } 126 | def xadd : this.type = { m.ladd; this } 127 | def xmul : this.type = { m.lmul; this } 128 | def xsub : this.type = { m.lsub; this } 129 | def xdiv : this.type = { m.ldiv; this } 130 | def xneg : this.type = { m.lneg; this } 131 | def xrem : this.type = { m.lrem; this } 132 | def x2f : this.type = { m.l2f; this } 133 | def x2i : this.type = { m.l2i; this } 134 | def x2l : this.type = this 135 | def x2d : this.type = { m.l2d; this } 136 | } 137 | 138 | class FloatTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) { 139 | protected[this] def xcmp: Unit = m.fcmpl 140 | 141 | def desc: PrimitiveValDesc = Desc.F 142 | def wide: Boolean = false 143 | 144 | def xload(v: VarIdx) : this.type = { m.fload(v); this } 145 | def xstore(v: VarIdx): this.type = { m.fstore(v); this } 146 | def xaload : this.type = { m.faload; this } 147 | def xastore : this.type = { m.fastore; this } 148 | def xreturn : this.type = { m.freturn; this } 149 | 
def zero : this.type = { m.fconst(0); this } 150 | def xadd : this.type = { m.fadd; this } 151 | def xmul : this.type = { m.fmul; this } 152 | def xsub : this.type = { m.fsub; this } 153 | def xdiv : this.type = { m.fdiv; this } 154 | def xneg : this.type = { m.fneg; this } 155 | def xrem : this.type = { m.frem; this } 156 | def x2f : this.type = this 157 | def x2i : this.type = { m.f2i; this } 158 | def x2l : this.type = { m.f2l; this } 159 | def x2d : this.type = { m.f2d; this } 160 | } 161 | 162 | class DoubleTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) { 163 | protected[this] def xcmp: Unit = m.dcmpl 164 | 165 | def desc: PrimitiveValDesc = Desc.D 166 | def wide: Boolean = true 167 | 168 | def xload(v: VarIdx) : this.type = { m.dload(v); this } 169 | def xstore(v: VarIdx): this.type = { m.dstore(v); this } 170 | def xaload : this.type = { m.daload; this } 171 | def xastore : this.type = { m.dastore; this } 172 | def xreturn : this.type = { m.dreturn; this } 173 | def zero : this.type = { m.dconst(0); this } 174 | def xadd : this.type = { m.dadd; this } 175 | def xmul : this.type = { m.dmul; this } 176 | def xsub : this.type = { m.dsub; this } 177 | def xdiv : this.type = { m.ddiv; this } 178 | def xneg : this.type = { m.dneg; this } 179 | def xrem : this.type = { m.drem; this } 180 | def x2f : this.type = { m.d2f; this } 181 | def x2i : this.type = { m.d2i; this } 182 | def x2l : this.type = { m.d2l; this } 183 | def x2d : this.type = this 184 | } 185 | -------------------------------------------------------------------------------- /src/main/scala/mt/workers/Workers.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.mt.workers 2 | 3 | import java.lang.invoke.MethodHandles 4 | import scala.annotation.tailrec 5 | 6 | object Workers { 7 | private[workers] final val endMarker = new AnyRef 8 | } 9 | 10 | abstract class Worker[T >: Null] extends Thread { 11 | import Worker._ 12 | 13 | 
private[workers] final var _workers: Workers[T] = _ 14 | private[workers] final var _idx: Int = _ 15 | private[workers] final var page: Page = _ 16 | 17 | @inline final def workers = _workers 18 | @inline final def idx = _idx 19 | 20 | @volatile private[this] var _state: Int = 0 21 | 22 | def apply(v: T): Unit 23 | 24 | final override def run(): Unit = { 25 | while(true) { 26 | var item = tryTake() 27 | if(item == null) { 28 | maybeEmpty() 29 | item = take() 30 | } 31 | if(item.asInstanceOf[AnyRef] eq Workers.endMarker) return 32 | apply(item) 33 | } 34 | } 35 | 36 | def maybeEmpty(): Unit = () 37 | 38 | @inline private def tryLock(): Boolean = STATE.weakCompareAndSetAcquire(this, 0, 1) 39 | 40 | @tailrec @inline private final def lock(): Boolean = { 41 | if(tryLock()) true 42 | else { 43 | Thread.onSpinWait() 44 | lock() 45 | } 46 | } 47 | 48 | @inline private final def unlock(): Unit = STATE.setRelease(this, 0) 49 | 50 | final def add(v: T): Unit = { 51 | //println(s"add($v)") 52 | lock() 53 | val p = page 54 | if(p.size < p.data.length) { 55 | p.data(p.size) = v.asInstanceOf[AnyRef] 56 | p.size += 1 57 | } else { 58 | val p2 = new Page(workers.pageSize) 59 | p2.next = p 60 | p2.data(0) = v.asInstanceOf[AnyRef] 61 | p2.size = 1 62 | page = p2 63 | } 64 | unlock() 65 | //println(s"add($v) - done") 66 | } 67 | 68 | private final def tryTake(): T = { 69 | if(tryLock()) { 70 | val v = page.takeOrNull() 71 | val ret = if(v != null) { 72 | v 73 | } else if(page.next != null) { 74 | page = page.next 75 | page.takeOrNull() 76 | } else { 77 | val p = tryStealFromOther(page, (idx+1) % workers.numThreads, workers.numThreads-1) 78 | if(p != null) { 79 | page = p 80 | p.takeOrNull() 81 | } else null 82 | } 83 | unlock() 84 | ret.asInstanceOf[T] 85 | } else null 86 | } 87 | 88 | @tailrec private final def take(): T = { 89 | val ret = tryTake() 90 | if(ret != null) ret 91 | else { 92 | Thread.onSpinWait() 93 | take() 94 | } 95 | } 96 | 97 | private final def trySteal(into: 
Page): Page = { 98 | if(!tryLock()) null 99 | else { 100 | val ret = if(page.next != null) { 101 | val p = page.next 102 | page.next = null 103 | p 104 | } else if(page.size == 0) null 105 | else if(page.size == 1) { 106 | val p = page 107 | page = into 108 | p 109 | } else { 110 | val stealCount = page.size/2 111 | page.size -= stealCount 112 | into.size = stealCount 113 | System.arraycopy(page.data, page.size, into.data, 0, stealCount) 114 | into 115 | } 116 | unlock() 117 | ret 118 | } 119 | } 120 | 121 | @tailrec private final def tryStealFromOther(into: Page, cur: Int, left: Int): Page = { 122 | val p = workers.workers(cur).trySteal(into) 123 | if(p != null) p 124 | else if(left > 0) tryStealFromOther(into, (cur+1) % workers.numThreads, left-1) 125 | else null 126 | } 127 | } 128 | 129 | object Worker { 130 | final val STATE = 131 | MethodHandles.privateLookupIn(classOf[Worker[_]], MethodHandles.lookup).findVarHandle(classOf[Worker[_]], "_state", classOf[Int]) 132 | } 133 | 134 | class Workers[T >: Null](val numThreads: Int, val pageSize: Int, createWorker: Int => Worker[T]) { 135 | val workers = (0 until numThreads).iterator.map { i => 136 | val w = createWorker(i) 137 | w.setDaemon(true) 138 | w.page = new Page(pageSize) 139 | w._workers = this 140 | w._idx = i 141 | w 142 | }.toArray 143 | private[this] var started = false 144 | private[this] var addIdx = 0 145 | 146 | def add(item: T): Unit = { 147 | workers(addIdx).add(item) 148 | addIdx = (addIdx+1) % numThreads 149 | } 150 | 151 | def start(): Unit = { 152 | if(!started) { 153 | workers.foreach(_.start) 154 | started = true 155 | } 156 | } 157 | 158 | def shutdown(): Unit = { 159 | (1 to workers.length).foreach { _ => add(Workers.endMarker.asInstanceOf[T]) } 160 | workers.foreach(_.join()) 161 | } 162 | } 163 | 164 | final class Page(_pageSize: Int) { 165 | val data = new Array[AnyRef](_pageSize) 166 | var size: Int = 0 167 | var next: Page = _ 168 | 169 | def takeOrNull(): AnyRef = 170 | if(size > 0) { 
size -= 1; data(size) } else null 171 | } 172 | -------------------------------------------------------------------------------- /src/main/scala/offheap/Allocator.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.offheap 2 | 3 | import java.util.Arrays 4 | 5 | object Allocator { 6 | val UNSAFE = { 7 | val f = classOf[sun.misc.Unsafe].getDeclaredField("theUnsafe") 8 | f.setAccessible(true) 9 | f.get(null).asInstanceOf[sun.misc.Unsafe] 10 | } 11 | 12 | def proxyElemOffset = -4L 13 | 14 | val (objArrayOffset, objArrayScale) = { 15 | val cl = classOf[Array[AnyRef]] 16 | (UNSAFE.arrayBaseOffset(cl), UNSAFE.arrayIndexScale(cl)) 17 | } 18 | 19 | val (objArrayArrayOffset, objArrayArrayScale) = { 20 | val cl = classOf[Array[Array[AnyRef]]] 21 | (UNSAFE.arrayBaseOffset(cl), UNSAFE.arrayIndexScale(cl)) 22 | } 23 | 24 | // used by code generator: 25 | def putInt(address: Long, value: Int): Unit = UNSAFE.putInt(address, value) 26 | def getInt(address: Long): Int = UNSAFE.getInt(address) 27 | def putLong(address: Long, value: Long): Unit = UNSAFE.putLong(address, value) 28 | def getLong(address: Long): Long = UNSAFE.getLong(address) 29 | def getObject(o: AnyRef, offset: Long): AnyRef = UNSAFE.getObject(o, offset) 30 | def putObject(o: AnyRef, offset: Long, v: AnyRef): Unit = UNSAFE.putObject(o, offset, v) 31 | } 32 | 33 | abstract class Allocator { 34 | def dispose(): Unit 35 | 36 | def alloc(len: Long): Long 37 | def alloc8(): Long = alloc(8) 38 | def alloc16(): Long = alloc(16) 39 | def alloc24(): Long = alloc(24) 40 | def alloc32(): Long = alloc(32) 41 | 42 | def free(address: Long, len: Long): Unit 43 | def free8(o: Long): Unit = free(o, 8) 44 | def free16(o: Long): Unit = free(o, 16) 45 | def free24(o: Long): Unit = free(o, 24) 46 | def free32(o: Long): Unit = free(o, 32) 47 | } 48 | 49 | abstract class ProxyAllocator extends Allocator { 50 | def getProxyPage(o: Long): AnyRef 51 | def getProxy(o: Long): 
AnyRef 52 | def setProxy(o: Long, v: AnyRef): Unit 53 | 54 | def allocProxied(len: Long): Long 55 | def alloc8p(): Long = allocProxied(8) 56 | def alloc16p(): Long = allocProxied(16) 57 | def alloc24p(): Long = allocProxied(24) 58 | def alloc32p(): Long = allocProxied(32) 59 | 60 | def freeProxied(o: Long, len: Long): Unit 61 | def free8p(o: Long): Unit = freeProxied(o, 8) 62 | def free16p(o: Long): Unit = freeProxied(o, 16) 63 | def free24p(o: Long): Unit = freeProxied(o, 24) 64 | def free32p(o: Long): Unit = freeProxied(o, 32) 65 | } 66 | 67 | object SystemAllocator extends Allocator { 68 | import Allocator._ 69 | 70 | def dispose(): Unit = () 71 | def alloc(len: Long): Long = UNSAFE.allocateMemory(len) 72 | def free(address: Long, len: Long): Unit = UNSAFE.freeMemory(address) 73 | } 74 | 75 | final class ArenaAllocator(blockSize: Long = 1024L*1024L*8L) extends Allocator { 76 | import Allocator._ 77 | private[this] var block, end, next = 0L 78 | 79 | def dispose(): Unit = { 80 | while(block != 0L) { 81 | val n = UNSAFE.getLong(block) 82 | UNSAFE.freeMemory(block) 83 | block = n 84 | } 85 | } 86 | 87 | def alloc(len: Long): Long = { 88 | if(next + len >= end) { 89 | allocBlock() 90 | assert(next + len < end) 91 | } 92 | val o = next 93 | next += len 94 | o 95 | } 96 | 97 | def free(address: Long, len: Long): Unit = () 98 | 99 | private[this] def allocBlock(): Unit = { 100 | val b = UNSAFE.allocateMemory(blockSize) 101 | UNSAFE.putLong(b, block) 102 | block = b 103 | next = b + 8 104 | end = b + blockSize 105 | } 106 | } 107 | 108 | final class SliceAllocator(blockSize: Long = 1024L*64L, maxSliceSize: Int = 256, arenaSize: Long = 1024L*1024L*8L) extends ProxyAllocator { 109 | import Allocator._ 110 | assert(blockSize % 8 == 0) 111 | assert(maxSliceSize % 8 == 0) 112 | assert(blockSize >= maxSliceSize) 113 | 114 | private[this] val blockAllocator = new ArenaAllocator(arenaSize) 115 | private[this] val slices: Array[Slice] = Array.tabulate(maxSliceSize >> 3)(i => new 
Slice((i+1) << 3)) 116 | private[this] val proxySlices: Array[ProxySlice] = Array.tabulate(maxSliceSize >> 3)(i => new ProxySlice((i+1) << 3)) 117 | private[this] var proxyPages: Array[Array[AnyRef]] = new Array[Array[AnyRef]](64) 118 | private[this] var proxyPagesLen: Int= 0 119 | 120 | private[this] val slice8: Slice = slices(8 >> 3) 121 | private[this] val slice16: Slice = slices(16 >> 3) 122 | private[this] val slice24: Slice = slices(24 >> 3) 123 | private[this] val slice32: Slice = slices(32 >> 3) 124 | private[this] val slice8p: ProxySlice = proxySlices(8 >> 3) 125 | private[this] val slice16p: ProxySlice = proxySlices(16 >> 3) 126 | private[this] val slice24p: ProxySlice = proxySlices(24 >> 3) 127 | private[this] val slice32p: ProxySlice = proxySlices(32 >> 3) 128 | 129 | override def alloc8(): Long = slice8.alloc() 130 | override def alloc16(): Long = slice16.alloc() 131 | override def alloc24(): Long = slice24.alloc() 132 | override def alloc32(): Long = slice32.alloc() 133 | override def free8(o: Long): Unit = slice8.free(o) 134 | override def free16(o: Long): Unit = slice16.free(o) 135 | override def free24(o: Long): Unit = slice24.free(o) 136 | override def free32(o: Long): Unit = slice32.free(o) 137 | 138 | override def alloc8p(): Long = slice8p.alloc() 139 | override def alloc16p(): Long = slice16p.alloc() 140 | override def alloc24p(): Long = slice24p.alloc() 141 | override def alloc32p(): Long = slice32p.alloc() 142 | override def free8p(o: Long): Unit = slice8p.free(o) 143 | override def free16p(o: Long): Unit = slice16p.free(o) 144 | override def free24p(o: Long): Unit = slice24p.free(o) 145 | override def free32p(o: Long): Unit = slice32p.free(o) 146 | 147 | def dispose(): Unit = blockAllocator.dispose() 148 | def alloc(len: Long): Long = slices((len >> 3).toInt).alloc() 149 | def free(o: Long, len: Long): Unit = slices((len >> 3).toInt).free(o) 150 | def allocProxied(len: Long): Long = proxySlices((len >> 3).toInt).alloc() 151 | def 
freeProxied(o: Long, len: Long): Unit = proxySlices((len >> 3).toInt).free(o) 152 | 153 | def getProxyPage(o: Long): AnyRef = { 154 | val off = UNSAFE.getInt(o-8) 155 | UNSAFE.getObject(proxyPages, off) 156 | } 157 | 158 | def getProxy(o: Long): AnyRef = { 159 | val pp = getProxyPage(o) 160 | val off = UNSAFE.getInt(o-4) 161 | UNSAFE.getObject(pp, off) 162 | } 163 | 164 | def setProxy(o: Long, v: AnyRef): Unit = { 165 | val pp = getProxyPage(o) 166 | val off = UNSAFE.getInt(o-4) 167 | UNSAFE.putObject(pp, off, v) 168 | } 169 | 170 | private[this] final class Slice(sliceSize: Int) { 171 | private[this] val allocSize = ((blockSize / sliceSize) * sliceSize) + 8 172 | private[this] var block, last, next, freeSlice = 0L 173 | 174 | private[this] def allocBlock(): Unit = { 175 | val b = blockAllocator.alloc(allocSize) 176 | UNSAFE.putLong(b, block) 177 | block = b 178 | next = b + 8 179 | last = b + allocSize - sliceSize 180 | } 181 | 182 | def alloc(): Long = { 183 | if(freeSlice != 0L) { 184 | val o = freeSlice 185 | freeSlice = UNSAFE.getLong(o) 186 | o 187 | } else { 188 | if(next >= last) allocBlock() 189 | val o = next 190 | next += sliceSize 191 | o 192 | } 193 | } 194 | 195 | def free(o: Long): Unit = { 196 | UNSAFE.putLong(o, freeSlice) 197 | freeSlice = o 198 | } 199 | } 200 | 201 | private[this] final class ProxySlice(_sliceSize: Int) { 202 | private[this] val sliceAllocSize = _sliceSize + 8 203 | private[this] val numBlocks = (blockSize / sliceAllocSize).toInt 204 | private[this] val allocSize = (numBlocks * sliceAllocSize) + 8 205 | private[this] var block, last, next, freeSlice = 0L 206 | 207 | private[this] def allocBlock(): Unit = { 208 | val b = blockAllocator.alloc(allocSize) 209 | UNSAFE.putLong(b, block) 210 | block = b 211 | next = b + 8 212 | last = b + allocSize - sliceAllocSize 213 | proxyPagesLen += 1 214 | if(proxyPagesLen == proxyPages.length) 215 | proxyPages = Arrays.copyOf(proxyPages, proxyPages.length * 2) 216 | proxyPages(proxyPagesLen-1) 
= new Array[AnyRef](numBlocks) 217 | } 218 | 219 | def alloc(): Long = { 220 | if(freeSlice != 0L) { 221 | val o = freeSlice 222 | freeSlice = UNSAFE.getLong(o) 223 | o 224 | } else { 225 | if(next >= last) allocBlock() 226 | val p = proxyPagesLen-1 227 | UNSAFE.putInt(next, objArrayArrayOffset + p*objArrayArrayScale) 228 | val i = (next-block-8).toInt/sliceAllocSize 229 | UNSAFE.putInt(next + 4, objArrayOffset + i*objArrayScale) 230 | val o = next + 8 231 | next += sliceAllocSize 232 | o 233 | } 234 | } 235 | 236 | def free(o: Long): Unit = { 237 | setProxy(o, null) 238 | UNSAFE.putLong(o, freeSlice) 239 | freeSlice = o 240 | } 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /src/main/scala/offheap/MemoryDebugger.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.offheap 2 | 3 | import scala.collection.mutable 4 | 5 | object MemoryDebugger extends ProxyAllocator { 6 | private[this] var parent: ProxyAllocator = _ 7 | 8 | def setParent(p: ProxyAllocator): Unit = { 9 | objects.clear() 10 | parent = p 11 | } 12 | 13 | private[this] final class Obj(val address: Long, val length: Long, val proxied: Boolean) { 14 | override def toString = s"[$address..${address+length}[" 15 | } 16 | 17 | private[this] val objects = mutable.TreeMap.empty[Long, Obj](implicitly[Ordering[Long]].reverse) 18 | 19 | private[this] def find(o: Long, accessLen: Int): Obj = { 20 | val it = objects.valuesIteratorFrom(o) 21 | if(!it.hasNext) throw new AssertionError(s"Address $o is before all allocated objects") 22 | val v = it.next() 23 | if(o + accessLen > v.address + v.length) 24 | throw new AssertionError(s"Address $o with access length $accessLen is outside of object $v") 25 | v 26 | } 27 | 28 | def dispose(): Unit = { 29 | parent.dispose() 30 | objects.clear() 31 | } 32 | 33 | def alloc(len: Long): Long = { 34 | val o = parent.alloc(len) 35 | objects.put(o, new Obj(o, len, 
false)) 36 | if(o % 8 != 0) throw new AssertionError(s"alloc($len) returned non-aligned address $o") 37 | o 38 | } 39 | 40 | def free(address: Long, len: Long): Unit = { 41 | val obj = find(address, 0) 42 | if(obj.address != address) throw new AssertionError(s"free($address, $len) not called with base address of $obj") 43 | if(obj.proxied) throw new AssertionError(s"free($address, $len) called on proxied object") 44 | parent.free(address, len) 45 | objects.remove(address) 46 | } 47 | 48 | def allocProxied(len: Long): Long = { 49 | val o = parent.allocProxied(len) 50 | objects.put(o, new Obj(o, len, true)) 51 | if(o % 8 != 0) throw new AssertionError(s"allocProxied($len) returned non-aligned address $o") 52 | o 53 | } 54 | 55 | def freeProxied(address: Long, len: Long): Unit = { 56 | val obj = find(address, 0) 57 | if(obj.address != address) throw new AssertionError(s"freeProxied($address, $len) not called with base address of $obj") 58 | if(!obj.proxied) throw new AssertionError(s"freeProxied($address, $len) called on non-proxied object") 59 | parent.freeProxied(address, len) 60 | objects.remove(address) 61 | } 62 | 63 | def getProxyPage(address: Long): AnyRef = { 64 | val obj = find(address, 0) 65 | if(obj.address != address) throw new AssertionError(s"getProxyPage($address) not called with base address of $obj") 66 | if(!obj.proxied) throw new AssertionError(s"getProxyPage($address) called on non-proxied object") 67 | parent.getProxyPage(address) 68 | } 69 | 70 | def getProxy(address: Long): AnyRef = { 71 | val obj = find(address, 0) 72 | if(obj.address != address) throw new AssertionError(s"getProxy($address) not called with base address of $obj") 73 | if(!obj.proxied) throw new AssertionError(s"getProxy($address) called on non-proxied object") 74 | parent.getProxy(address) 75 | } 76 | 77 | def setProxy(address: Long, value: AnyRef): Unit = { 78 | val obj = find(address, 0) 79 | if(obj.address != address) throw new AssertionError(s"setProxy($address, $obj) not 
called with base address of $obj") 80 | if(!obj.proxied) throw new AssertionError(s"setProxy($address, $obj) called on non-proxied object") 81 | parent.setProxy(address, value) 82 | } 83 | 84 | def putInt(address: Long, value: Int): Unit = { 85 | val obj = find(address, 4) 86 | if(obj.proxied && address == obj.address + 4) throw new AssertionError(s"Overwriting 32-bit payload of proxied object") 87 | Allocator.putInt(address, value) 88 | } 89 | 90 | def getInt(address: Long): Int = { 91 | val obj = find(address, 4) 92 | Allocator.getInt(address) 93 | } 94 | 95 | def putLong(address: Long, value: Long): Unit = { 96 | val obj = find(address, 8) 97 | if(obj.proxied && address == obj.address) throw new AssertionError(s"Overwriting 32-bit payload of proxied object") 98 | Allocator.putLong(address, value) 99 | } 100 | 101 | def getLong(address: Long): Long = { 102 | val obj = find(address, 8) 103 | Allocator.getLong(address) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /src/main/scala/stc1/Interpreter.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.stc1 2 | 3 | import de.szeiger.interact.codegen.{ClassWriter, LocalClassLoader} 4 | import de.szeiger.interact._ 5 | import de.szeiger.interact.ast.{CompilationUnit, Symbol, Symbols} 6 | 7 | import java.util.Arrays 8 | import scala.collection.mutable 9 | 10 | abstract class Cell { 11 | def arity: Int 12 | def auxCell(p: Int): Cell 13 | def auxPort(p: Int): Int 14 | def setAux(p: Int, c2: Cell, p2: Int): Unit 15 | def reduce(c: Cell, ptw: Interpreter): Unit 16 | def cellSymbol: Symbol 17 | } 18 | 19 | final class DynamicCell(val cellSymbol: Symbol, val arity: Int) extends Cell { 20 | private[this] final val auxCells = new Array[Cell](arity) 21 | private[this] final val auxPorts = new Array[Int](arity) 22 | def auxCell(p: Int): Cell = auxCells(p) 23 | def auxPort(p: Int): Int = auxPorts(p) 24 | def 
setAux(p: Int, c2: Cell, p2: Int): Unit = { auxCells(p) = c2; auxPorts(p) = p2 } 25 | def reduce(c: Cell, ptw: Interpreter): Unit = () 26 | } 27 | 28 | abstract class InitialRuleImpl { 29 | def reduce(a0: Cell, a1: Cell, ptw: Interpreter): Unit 30 | def freeWires: Array[Symbol] 31 | } 32 | 33 | final class Interpreter(globals: Symbols, compilationUnit: CompilationUnit, config: Config) extends BaseInterpreter { self => 34 | private[this] val initialRuleImpls: Vector[InitialRuleImpl] = { 35 | val lcl = new LocalClassLoader 36 | val cw = ClassWriter(config, lcl) 37 | val initial = new CodeGen("generated", cw, config, compilationUnit, globals).compile() 38 | cw.close() 39 | initial.map { cln => lcl.loadClass(cln).getDeclaredConstructor().newInstance().asInstanceOf[InitialRuleImpl] } 40 | } 41 | private[this] val cutBuffer, irreducible = new CutBuffer(16) 42 | private[this] val freeWires = mutable.HashSet.empty[Cell] 43 | private[this] var metrics: ExecutionMetrics = _ 44 | private[this] var active0, active1: Cell = _ 45 | 46 | def getMetrics: ExecutionMetrics = metrics 47 | 48 | def getAnalyzer: Analyzer[Cell] = new Analyzer[Cell] { 49 | val principals = (cutBuffer.iterator ++ irreducible.iterator).flatMap { case (c1, c2) => Seq((c1, c2), (c2, c1)) }.toMap 50 | def irreduciblePairs: IterableOnce[(Cell, Cell)] = irreducible.iterator 51 | def rootCells = (self.freeWires.iterator ++ principals.keysIterator).toSet 52 | def getSymbol(c: Cell): Symbol = Option(c.cellSymbol).getOrElse(Symbol.NoSymbol) 53 | def getConnected(c: Cell, port: Int): (Cell, Int) = 54 | if(port == -1) principals.get(c).map((_, -1)).orNull else (c.auxCell(port), c.auxPort(port)) 55 | def isFreeWire(c: Cell): Boolean = c.isInstanceOf[DynamicCell] && c.cellSymbol.isDefined 56 | } 57 | 58 | def initData(): Unit = { 59 | cutBuffer.clear() 60 | irreducible.clear() 61 | freeWires.clear() 62 | if(config.collectStats) metrics = new ExecutionMetrics 63 | initialRuleImpls.foreach { rule => 64 | val free = 
rule.freeWires.map(new DynamicCell(_, 1)) 65 | freeWires.addAll(free) 66 | val lhs = new DynamicCell(Symbol.NoSymbol, freeWires.size) 67 | free.iterator.zipWithIndex.foreach { case (c, p) => lhs.setAux(p, c, 0) } 68 | rule.reduce(lhs, new DynamicCell(Symbol.NoSymbol, 0), this) 69 | } 70 | if(config.collectStats) metrics = new ExecutionMetrics 71 | } 72 | 73 | def dispose(): Unit = () 74 | 75 | def reduce(): Unit = 76 | while(true) { 77 | while(active0 != null) { 78 | val a0 = active0 79 | active0 = null 80 | a0.reduce(active1, this) 81 | } 82 | if(cutBuffer.isEmpty) return 83 | val (a0, a1) = cutBuffer.pop() 84 | a0.reduce(a1, this) 85 | } 86 | 87 | // ptw methods: 88 | 89 | def addActive(a0: Cell, a1: Cell): Unit = 90 | if(active0 == null) { active0 = a0; active1 = a1 } else cutBuffer.addOne(a0, a1) 91 | 92 | def addIrreducible(a0: Cell, a1: Cell): Unit = irreducible.addOne(a0, a1) 93 | 94 | def recordStats(steps: Int, cellAllocations: Int, cachedCellReuse: Int, singletonUse: Int, loopSave: Int, labelCreate: Int): Unit = 95 | metrics.recordStats(steps, cellAllocations, 0, cachedCellReuse, singletonUse, 0, loopSave, 0, 0, labelCreate) 96 | 97 | def recordMetric(metric: String, inc: Int): Unit = metrics.recordMetric(metric, inc) 98 | } 99 | 100 | 101 | final class CutBuffer(initialSize: Int) { 102 | private[this] var pairs = new Array[Cell](initialSize*2) 103 | private[this] var len = 0 104 | @inline def addOne(c1: Cell, c2: Cell): Unit = { 105 | if(len == pairs.length) 106 | pairs = Arrays.copyOf(pairs, pairs.length*2) 107 | pairs(len) = c1 108 | pairs(len+1) = c2 109 | len += 2 110 | } 111 | @inline def isEmpty: Boolean = len == 0 112 | @inline def pop(): (Cell, Cell) = { 113 | len -= 2 114 | val c1 = pairs(len) 115 | val c2 = pairs(len+1) 116 | pairs(len) = null 117 | pairs(len+1) = null 118 | (c1, c2) 119 | } 120 | def clear(): Unit = 121 | if(len > 0) { 122 | pairs = new Array[Cell](initialSize * 2) 123 | len = 0 124 | } 125 | def iterator: Iterator[(Cell, 
Cell)] = pairs.iterator.take(len).grouped(2).map { case Seq(c1, c2) => (c1, c2) } 126 | } 127 | -------------------------------------------------------------------------------- /src/main/scala/stc1/PTOps.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.stc1 2 | 3 | import de.szeiger.interact.ast.PayloadType 4 | import de.szeiger.interact.codegen.{BoxDesc, BoxOps} 5 | import de.szeiger.interact.codegen.dsl.{Desc => tp, _} 6 | 7 | object PTOps { 8 | def boxDesc(pt: PayloadType): BoxDesc = pt match { 9 | case PayloadType.INT => BoxDesc.intDesc 10 | case PayloadType.REF => BoxDesc.refDesc 11 | case PayloadType.LABEL => BoxDesc.refDesc 12 | } 13 | def apply(m: MethodDSL, pt: PayloadType): BoxOps = new BoxOps(m, boxDesc(pt)) 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala/stc2/Interpreter.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.stc2 2 | 3 | import de.szeiger.interact.codegen.{ClassWriter, LocalClassLoader} 4 | import de.szeiger.interact._ 5 | import de.szeiger.interact.ast.{CompilationUnit, PayloadType, Symbol, Symbols} 6 | import de.szeiger.interact.offheap.{Allocator, MemoryDebugger, ProxyAllocator} 7 | 8 | import java.util.Arrays 9 | import scala.collection.mutable 10 | 11 | object Defs { 12 | type Cell = Long 13 | } 14 | import Defs._ 15 | 16 | abstract class InitialRuleImpl { 17 | def reduce(a0: Cell, a1: Cell, ptw: Interpreter): Unit 18 | def freeWires: Array[Symbol] 19 | } 20 | 21 | abstract class Dispatch { 22 | def reduce(c1: Cell, c2: Cell, level: Int, ptw: Interpreter): Unit 23 | } 24 | 25 | final class Interpreter(globals: Symbols, compilationUnit: CompilationUnit, config: Config) extends BaseInterpreter { self => 26 | import Interpreter._ 27 | 28 | private[this] val symIds = 
globals.symbols.filter(_.isCons).toVector.sortBy(_.id).iterator.zipWithIndex.toMap 29 | private[this] val reverseSymIds = symIds.map { case (sym, idx) => (idx, sym) } 30 | private[this] val (initialRuleImpls: Vector[InitialRuleImpl], dispatch: Dispatch) = { 31 | val lcl = new LocalClassLoader 32 | val cw = ClassWriter(config, lcl) 33 | val (initial, dispN) = new CodeGen("generated", cw, compilationUnit, globals, symIds, config).compile() 34 | cw.close() 35 | val irs = initial.map { cln => lcl.loadClass(cln).getDeclaredConstructor().newInstance().asInstanceOf[InitialRuleImpl] } 36 | val disp = lcl.loadClass(dispN).getDeclaredConstructor().newInstance().asInstanceOf[Dispatch] 37 | (irs, disp) 38 | } 39 | private[this] val cutBuffer, irreducible = new CutBuffer(16) 40 | private[this] val freeWires = mutable.HashSet.empty[Cell] 41 | private[this] val freeWireLookup = mutable.HashMap.empty[Int, Symbol] 42 | private[this] var metrics: ExecutionMetrics = _ 43 | private[this] var allocator: ProxyAllocator = _ 44 | private[this] val singletons: Array[Cell] = new Array(symIds.size) 45 | private[this] var nextLabel = Long.MinValue 46 | private[this] val tailCallDepth = config.tailCallDepth 47 | 48 | def getMetrics: ExecutionMetrics = metrics 49 | 50 | def getAnalyzer: Analyzer[Cell] = new Analyzer[Cell] { 51 | val principals = (cutBuffer.iterator ++ irreducible.iterator).flatMap { case (c1, c2) => Seq((c1, c2), (c2, c1)) }.toMap 52 | def irreduciblePairs: IterableOnce[(Cell, Cell)] = irreducible.iterator 53 | def rootCells = (self.freeWires.iterator ++ principals.keysIterator).toSet 54 | def getSymbol(c: Cell): Symbol = { 55 | val sid = getSymId(c) 56 | reverseSymIds.getOrElse(sid, freeWireLookup.getOrElse(getSymId(c), new Symbol(s""))) 57 | } 58 | def getConnected(c: Cell, port: Int): (Cell, Int) = 59 | if(port == -1) principals.get(c).map((_, -1)).orNull else findCellAndPort(c, port) 60 | def isFreeWire(c: Cell): Boolean = freeWireLookup.contains(getSymId(c)) 61 | override 
def getPayload(ptr: Long): Any = { 62 | val sym = getSymbol(ptr) 63 | if((ptr & TAGMASK) == TAG_UNBOXED) { 64 | sym.payloadType match { 65 | case PayloadType.INT => (ptr >>> 32).toInt 66 | } 67 | } else { 68 | val address = ptr + payloadOffset(sym.arity, sym.payloadType) 69 | sym.payloadType match { 70 | case PayloadType.INT => Allocator.getInt(address) 71 | case PayloadType.LABEL => "label@" + Allocator.getLong(address) 72 | case PayloadType.REF => getProxy(ptr) 73 | } 74 | } 75 | } 76 | } 77 | 78 | def dispose(): Unit = { 79 | if(allocator != null) { 80 | allocator.dispose() 81 | allocator = null 82 | if(config.debugMemory) MemoryDebugger.setParent(null) 83 | } 84 | } 85 | 86 | def initData(): Unit = { 87 | cutBuffer.clear() 88 | irreducible.clear() 89 | freeWires.clear() 90 | freeWireLookup.clear() 91 | //dispose() 92 | nextLabel = Long.MinValue 93 | if(allocator == null) { 94 | allocator = config.newAllocator() 95 | if(config.debugMemory) { 96 | MemoryDebugger.setParent(allocator) 97 | allocator = MemoryDebugger 98 | } 99 | singletons.indices.foreach { i => 100 | val s = reverseSymIds(i) 101 | if(s.isSingleton) singletons(i) = newCell(i, s.arity) 102 | } 103 | } 104 | if(config.collectStats) metrics = new ExecutionMetrics 105 | initialRuleImpls.foreach { rule => 106 | val fws = rule.freeWires 107 | val off = reverseSymIds.size + freeWireLookup.size 108 | val lhs = newCell(-1, fws.length) 109 | fws.iterator.zipWithIndex.foreach { case (s, i) => 110 | freeWireLookup += ((i+off, s)) 111 | val c = newCell(i+off, 1) 112 | freeWires += c 113 | setAux(lhs, i, c, 0) 114 | } 115 | val rhs = newCell(-1, 0) 116 | rule.reduce(lhs, rhs, this) 117 | } 118 | if(config.collectStats) metrics = new ExecutionMetrics 119 | } 120 | 121 | def reduce(): Unit = 122 | while(!cutBuffer.isEmpty) { 123 | val (a0, a1) = cutBuffer.pop() 124 | dispatch.reduce(a0, a1, tailCallDepth, this) 125 | } 126 | 127 | private final def newCell(symId: Int, arity: Int, pt: PayloadType = 
PayloadType.VOID): Long = { 128 | val o = allocator.alloc(cellSize(arity, pt)) 129 | Allocator.putInt(o, mkHeader(symId)) 130 | o 131 | } 132 | 133 | // ptw methods: 134 | 135 | def addActive(a0: Cell, a1: Cell): Unit = cutBuffer.addOne(a0, a1) 136 | 137 | def addIrreducible(a0: Cell, a1: Cell): Unit = irreducible.addOne(a0, a1) 138 | 139 | def recordStats(steps: Int, cellAllocations: Int, proxyAllocations: Int, cachedCellReuse: Int, singletonUse: Int, 140 | unboxedCells: Int, loopSave: Int, directTail: Int, singleDispatchTail: Int, labelCreate: Int): Unit = 141 | metrics.recordStats(steps, cellAllocations, proxyAllocations, cachedCellReuse, singletonUse, unboxedCells, 142 | loopSave, directTail, singleDispatchTail, labelCreate) 143 | 144 | def recordMetric(metric: String, inc: Int): Unit = metrics.recordMetric(metric, inc) 145 | 146 | def getSingleton(symId: Int): Cell = singletons(symId) 147 | 148 | def allocCell(length: Int): Cell = allocator.alloc(length) 149 | def freeCell(address: Cell, length: Int): Unit = allocator.free(address, length) 150 | def allocProxied(length: Int): Cell = allocator.allocProxied(length) 151 | def freeProxied(address: Cell, length: Int): Unit = allocator.freeProxied(address, length) 152 | 153 | def alloc8(): Long = allocator.alloc8() 154 | def alloc16(): Long = allocator.alloc16() 155 | def alloc24(): Long = allocator.alloc24() 156 | def alloc32(): Long = allocator.alloc32() 157 | def free8(o: Long): Unit = allocator.free8(o) 158 | def free16(o: Long): Unit = allocator.free16(o) 159 | def free24(o: Long): Unit = allocator.free24(o) 160 | def free32(o: Long): Unit = allocator.free32(o) 161 | 162 | def alloc8p(): Long = allocator.alloc8p() 163 | def alloc16p(): Long = allocator.alloc16p() 164 | def alloc24p(): Long = allocator.alloc24p() 165 | def alloc32p(): Long = allocator.alloc32p() 166 | def free8p(o: Long): Unit = allocator.free8p(o) 167 | def free16p(o: Long): Unit = allocator.free16p(o) 168 | def free24p(o: Long): Unit = 
allocator.free24p(o) 169 | def free32p(o: Long): Unit = allocator.free32p(o) 170 | 171 | def getProxyPage(o: Long): AnyRef = allocator.getProxyPage(o) 172 | def getProxy(o: Long): AnyRef = allocator.getProxy(o) 173 | def setProxy(o: Long, v: AnyRef): Unit = allocator.setProxy(o, v) 174 | 175 | def newLabel: Long = { 176 | val r = nextLabel 177 | nextLabel += 1 178 | r 179 | } 180 | } 181 | 182 | object Interpreter { 183 | // Cell layout: 184 | // 64-bit header 185 | // n * 64-bit pointers for arity n 186 | // optional 64-bit payload depending on PayloadType 187 | // 188 | // Header layout: 189 | // LSB ... MSB 190 | // 012 3456789abcdef 0123456789abcdef 0123456789abcdef 0123456789abcdef 191 | // -------------------------------------------------------------------- 192 | // 110 [29-bit symId ] [ padding or 32-bit payload ] 193 | // 194 | // Pointer layouts: 195 | // LSB ... MSB 196 | // 012 3456789abcdef 0123456789abcdef 0123456789abcdef 0123456789abcdef 197 | // -------------------------------------------------------------------- 198 | // 000 [ 64-bit aligned address >> 3: pointer to cell (principal) ] 199 | // 100 [ 64-bit aligned address >> 3: pointer to aux port ] 200 | // 010 [ 29-bit symId ] [ unboxed 32-bit payload ] 201 | 202 | final val TAGWIDTH = 3 203 | final val TAGMASK = 7L 204 | final val ADDRMASK = -8L 205 | final val TAG_HEADER = 3 206 | final val TAG_PRINC_PTR = 0 207 | final val TAG_AUX_PTR = 1 208 | final val TAG_UNBOXED = 2 209 | 210 | final val SYMIDMASK = ((1L << 29)-1) << TAGWIDTH 211 | 212 | def showPtr(l: Long, symIds: Map[Int, Symbol] = null): String = { 213 | var raw = l.toBinaryString 214 | raw = "0"*(64-raw.length) + raw 215 | raw = raw.substring(0, 32) + ":" + raw.substring(32, 61) + ":" + raw.substring(61) 216 | def symStr(sid: Int) = 217 | if(symIds == null) s"$sid" 218 | else s"$sid(${symIds.getOrElse(sid, Symbol.NoSymbol).id})" 219 | val decoded = (l & TAGMASK) match { 220 | case _ if l == 0L => "NULL" 221 | case TAG_HEADER => 222 | 
val sid = ((l & SYMIDMASK) >>> TAGWIDTH).toInt 223 | s"HEADER:${symStr(sid)}:${l >>> 32}" 224 | case TAG_AUX_PTR => s"AUX_PTR:${l & ADDRMASK}" 225 | case TAG_PRINC_PTR => s"PRINC_PTR:${l & ADDRMASK}" 226 | case TAG_UNBOXED => 227 | val sid = ((l & SYMIDMASK) >>> TAGWIDTH).toInt 228 | s"UNBOXED:${symStr(sid)}:${l >>> 32}" 229 | case tag => s"Invalid-Tag-$tag:$l" 230 | } 231 | s"$raw::$decoded" 232 | } 233 | 234 | def cellSize(arity: Int, pt: PayloadType) = arity*8 + 8 + (if(pt == PayloadType.LABEL) 8 else 0) 235 | def auxPtrOffset(p: Int): Int = 8 + (p * 8) 236 | def payloadOffset(arity: Int, pt: PayloadType): Int = if(pt == PayloadType.LABEL) 8 + (arity * 8) else 4 237 | def mkHeader(sid: Int): Int = (sid << TAGWIDTH) | TAG_HEADER 238 | def mkUnboxed(sid: Int): Int = (sid << TAGWIDTH) | TAG_UNBOXED 239 | 240 | private def setAux(c: Long, p: Int, c2: Long, p2: Int): Unit = { 241 | var l = c2 + auxPtrOffset(p2) 242 | if(p2 >= 0) l |= TAG_AUX_PTR 243 | Allocator.putLong(c + auxPtrOffset(p), l) 244 | } 245 | 246 | private def getSymId(ptr: Long) = 247 | if((ptr & TAGMASK) == TAG_UNBOXED) ((ptr & SYMIDMASK) >>> TAGWIDTH).toInt 248 | else Allocator.getInt(ptr) >> TAGWIDTH 249 | 250 | private def findCellAndPort(cellAddress: Long, cellPort: Int): (Long, Int) = { 251 | var ptr = Allocator.getLong(cellAddress + auxPtrOffset(cellPort)) 252 | if((ptr & TAGMASK) != TAG_AUX_PTR) { 253 | (ptr, -1) 254 | } else { 255 | ptr = ptr & ADDRMASK 256 | var p = -1 257 | while((Allocator.getInt(ptr - auxPtrOffset(p)) & TAGMASK) != TAG_HEADER) 258 | p += 1 259 | (ptr - auxPtrOffset(p), p) 260 | } 261 | } 262 | 263 | def canUnbox(sym: Symbol, arity: Int): Boolean = 264 | arity == 0 && (sym.payloadType == PayloadType.INT || sym.payloadType == PayloadType.VOID) 265 | } 266 | 267 | final class CutBuffer(initialSize: Int) { 268 | private[this] var pairs = new Array[Cell](initialSize*2) 269 | private[this] var len = 0 270 | @inline def addOne(c1: Cell, c2: Cell): Unit = { 271 | if(len == 
pairs.length) 272 | pairs = Arrays.copyOf(pairs, pairs.length*2) 273 | pairs(len) = c1 274 | pairs(len+1) = c2 275 | len += 2 276 | } 277 | @inline def isEmpty: Boolean = len == 0 278 | @inline def pop(): (Cell, Cell) = { 279 | len -= 2 280 | val c1 = pairs(len) 281 | val c2 = pairs(len+1) 282 | (c1, c2) 283 | } 284 | def clear(): Unit = 285 | if(len > 0) { 286 | pairs = new Array[Cell](initialSize * 2) 287 | len = 0 288 | } 289 | def iterator: Iterator[(Cell, Cell)] = pairs.iterator.take(len).grouped(2).map { case Seq(c1, c2) => (c1, c2) } 290 | } 291 | -------------------------------------------------------------------------------- /src/main/scala/stc2/PTOps.scala: -------------------------------------------------------------------------------- 1 | package de.szeiger.interact.stc2 2 | 3 | import de.szeiger.interact.ast.PayloadType 4 | import de.szeiger.interact.codegen.{BoxDesc, BoxOps} 5 | import de.szeiger.interact.codegen.dsl.{Desc => tp, _} 6 | import de.szeiger.interact.stc2.PTOps.boxDesc 7 | 8 | class PTOps(m: MethodDSL, pt: PayloadType, codeGen: CodeGen) extends BoxOps(m, PTOps.boxDesc(pt)) { 9 | import codeGen._ 10 | 11 | def isVoid: Boolean = boxDesc.boxT == null 12 | 13 | def loadConst0: this.type = { 14 | pt match { 15 | case PayloadType.INT => 16 | m.iconst(0) 17 | } 18 | this 19 | } 20 | 21 | def setCellPayload(ptw: VarIdx, arity: Int)(loadCell: => Unit)(loadUnboxedPayload: => Unit): this.type = { 22 | pt match { 23 | case PayloadType.LABEL => 24 | loadCell 25 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd 26 | loadUnboxedPayload 27 | m.invokestatic(allocator_putLong) 28 | case PayloadType.INT => 29 | loadCell 30 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd 31 | loadUnboxedPayload 32 | m.invokestatic(allocator_putInt) 33 | case PayloadType.REF => 34 | m.aload(ptw) 35 | loadCell 36 | loadUnboxedPayload 37 | m.invoke(ptw_setProxy) 38 | } 39 | this 40 | } 41 | 42 | def getCellPayload(ptw: VarIdx, arity: Int)(loadCell: => Unit): this.type 
= { 43 | pt match { 44 | case PayloadType.REF => 45 | m.aload(ptw) 46 | loadCell 47 | m.invoke(ptw_getProxy) 48 | case PayloadType.INT => 49 | loadCell 50 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd 51 | m.invokestatic(allocator_getInt) 52 | case PayloadType.LABEL => 53 | loadCell 54 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd 55 | m.invokestatic(allocator_getLong) 56 | } 57 | this 58 | } 59 | 60 | def untag(loadValue: => Unit): this.type = { 61 | pt match { 62 | case PayloadType.INT => 63 | loadValue 64 | m.iconst(32).lushr.l2i 65 | } 66 | this 67 | } 68 | 69 | def tag(symId: Int)(loadValue: => Unit): this.type = { 70 | loadValue 71 | m.i2l.iconst(32).lshl 72 | m.lconst((symId.toLong << Interpreter.TAGWIDTH) | Interpreter.TAG_UNBOXED).lor 73 | this 74 | } 75 | } 76 | 77 | object PTOps { 78 | def boxDesc(pt: PayloadType): BoxDesc = pt match { 79 | case PayloadType.INT => BoxDesc.intDesc 80 | case PayloadType.REF => BoxDesc.refDesc 81 | case PayloadType.LABEL => BoxDesc.longDesc 82 | case PayloadType.VOID => BoxDesc.voidDesc 83 | } 84 | def apply(m: MethodDSL, pt: PayloadType)(implicit codeGen: CodeGen): PTOps = new PTOps(m, pt, codeGen) 85 | } 86 | -------------------------------------------------------------------------------- /src/test/resources/ack.check: -------------------------------------------------------------------------------- 1 | res = 2045n 2 | res2 = 2045n 3 | resB = BoxedInt[2045] 4 | resI = Int[2045] 5 | -------------------------------------------------------------------------------- /src/test/resources/ack.in: -------------------------------------------------------------------------------- 1 | cons Z 2 | cons S(x) 3 | 4 | # Direct encoding as in https://github.com/inpla/inpla/blob/main/sample/AckSZ-3_5.in 5 | # 4182049 reductions (with pre-reduced rhs) 6 | 7 | def ack(a, b) = r 8 | | Z, y => S(y) 9 | | S(x), Z => ack(x, S(Z)) 10 | | S(x), S(y) => (x1, x2) = dup(x); ack(x1, ack(S(x2),y)) 11 | 12 | #def ack(_, y) = r 13 | # | Z => 
S(y) 14 | # | S(x) => ack_Sx(y, x) 15 | # 16 | #def ack_Sx(_, x) = r 17 | # | Z => ack(x, S(Z)) 18 | # | S(y) => (x1, x2) = dup(x); ack(x1, ack_Sx(y, x2)) 19 | 20 | let res = ack(3n, 8n) 21 | 22 | 23 | # Encoding with pred from https://www.user.tu-berlin.de/o.runge/tfs/workshops/gtvmt08/Program/paper_38.pdf 24 | # 8360028 reductions 25 | 26 | def pred(_) = r 27 | | Z => Z 28 | | S(x) => x 29 | def ack2(_, a) = b 30 | | Z => S(a) 31 | | S(x) => ack2b(a, S(x)) 32 | def ack2b(_, a) = b 33 | | Z => ack2(pred(a), S(Z)) 34 | | S(y) => (a1, a2) = dup(a); ack2(pred(a1), ack2(a2, y)) 35 | 36 | let res2 = ack2(3n, 8n) 37 | 38 | 39 | 40 | # Int-based encoding 41 | # 2786001 reductions 42 | 43 | cons Int[int] 44 | 45 | def ackI(a, b) = r 46 | | Int[x], Int[y] if [x == 0] => Int[y + 1] 47 | if [y == 0] => ackI(Int[x-1], Int[1]) 48 | else => ackI(Int[x-1], ackI(Int[x], Int[y-1])) 49 | 50 | let resI = ackI(Int[3], Int[8]) 51 | 52 | 53 | # Boxed Integer version using ref payloads 54 | # 2786001 reductions 55 | 56 | cons BoxedInt[ref] 57 | 58 | def ackB(a, b) = r 59 | | BoxedInt[x], BoxedInt[y] 60 | if [de.szeiger.interact.MainTest.is0(x)] => 61 | BoxedInt[de.szeiger.interact.MainTest.inc(y)] 62 | [eraseRef(x)] 63 | if [de.szeiger.interact.MainTest.is0(y)] => 64 | ackB(BoxedInt[de.szeiger.interact.MainTest.dec(x)], BoxedInt[de.szeiger.interact.MainTest.box(1)]) 65 | [eraseRef(y)] 66 | else => 67 | [de.szeiger.interact.MainTest.ackHelper(x, x1, x2)] 68 | ackB(BoxedInt[x1], ackB(BoxedInt[x2], BoxedInt[de.szeiger.interact.MainTest.dec(y)])) 69 | 70 | let resB = ackB(BoxedInt[de.szeiger.interact.MainTest.box(3)], BoxedInt[de.szeiger.interact.MainTest.box(8)]) 71 | -------------------------------------------------------------------------------- /src/test/resources/diverging.check: -------------------------------------------------------------------------------- 1 | Error: src/test/resources/diverging.in:7:7: Circular expansion (f <-> A) => (f <-> B) => (f <-> C) => (f <-> D) 2 | | | A => 
f(B) 3 | | ^ 4 | 5 | Error: src/test/resources/diverging.in:8:7: Circular expansion (f <-> B) => (f <-> C) => (f <-> D) => (f <-> A) 6 | | | B => f(C) 7 | | ^ 8 | 9 | Error: src/test/resources/diverging.in:9:7: Circular expansion (f <-> C) => (f <-> D) => (f <-> A) => (f <-> B) 10 | | | C => f(D) 11 | | ^ 12 | 13 | Error: src/test/resources/diverging.in:10:7: Circular expansion (f <-> D) => (f <-> A) => (f <-> B) => (f <-> C) 14 | | | D => f(A) 15 | | ^ 16 | 17 | 4 errors found. 18 | -------------------------------------------------------------------------------- /src/test/resources/diverging.in: -------------------------------------------------------------------------------- 1 | cons A 2 | cons B 3 | cons C 4 | cons D 5 | 6 | def f(_) = r 7 | | A => f(B) 8 | | B => f(C) 9 | | C => f(D) 10 | | D => f(A) 11 | -------------------------------------------------------------------------------- /src/test/resources/embedded.check: -------------------------------------------------------------------------------- 1 | dummy = Dummy[d:0:0] 2 | dummy2b = Dummy[d2:1:1] 3 | fib10 = Int[89] 4 | i1 = Int[42] 5 | i2 = Int[42] 6 | len = Int[5] 7 | mult = Int[6] 8 | r1 = Pair(1n, 1n) 9 | r2 = Pair(2n, 2n) 10 | simple = Int[42] 11 | str = String["foo"] 12 | sum = Int[8] 13 | -------------------------------------------------------------------------------- /src/test/resources/embedded.in: -------------------------------------------------------------------------------- 1 | cons Int[int] 2 | cons String[ref] 3 | cons Dummy[ref] 4 | 5 | # Automatic move of payload 6 | def _ + y = r 7 | | Int[i] => intAdd[i](y) 8 | 9 | #match Int[i] + y => intAdd[i](y) 10 | 11 | # Explicit computation on payloads 12 | def intAdd[int a](_) = r 13 | | Int[b] => [add(a, b, c)]; Int[c] 14 | 15 | # Value-returning method call instead of out parameter 16 | def strlen(_) = r 17 | | String[s] => Int[strlen(s)] 18 | 19 | let simple = Int[42] 20 | x = Int[5]; y = Int[3]; sum = x + y 21 | str = String["foo"] 22 | len = 
strlen(String["12345"]) 23 | dummy = Dummy[d]; [de.szeiger.interact.ManagedDummy.create("d", d)] 24 | dummy2 = Dummy[d2]; [de.szeiger.interact.ManagedDummy.create("d2", d2)] 25 | (dummy2a, dummy2b) = dup(dummy2) 26 | erase(dummy2a) 27 | 28 | # Currying 29 | def _ * _ = r 30 | | Int[a], Int[b] => [mult(a, b, c)]; Int[c] 31 | 32 | let mult = Int[3] * Int[2] 33 | 34 | # Conditional matching 35 | def fib(_) = r 36 | | Int[i] if [i == 0] => Int[1] 37 | if [i == 1] => Int[1] 38 | else => fib(Int[i-1]) + fib(Int[i-2]) 39 | 40 | let fib10 = fib(Int[10]) 41 | 42 | 43 | 44 | 45 | cons Z 46 | cons S(x) 47 | 48 | cons Func(x, fx) 49 | 50 | def apply(l, in) = out 51 | | Func(i, o) => in = i; o 52 | 53 | cons Pair(a, b) 54 | 55 | let f = Func(x, Pair(x1, x2)) 56 | (x1, x2) = dup[l1](x) 57 | (f1, f2) = dup[l2](f) 58 | r1 = apply(f1, 1n) 59 | r2 = apply(f2, 2n) 60 | 61 | let i0 = Int[42] 62 | (i1, i2) = dup(i0) 63 | -------------------------------------------------------------------------------- /src/test/resources/fib.check: -------------------------------------------------------------------------------- 1 | res = 89n 2 | -------------------------------------------------------------------------------- /src/test/resources/fib.in: -------------------------------------------------------------------------------- 1 | cons Z 2 | cons S(n) 3 | 4 | def _ + y = r 5 | | Z => y 6 | | S(x) => S(x + y) 7 | 8 | def fib(_) = r 9 | | Z => 1n 10 | | S(Z) => 1n 11 | | S(S(n)) => (n1, n2) = dup(n) 12 | fib(S(n1)) + fib(n2) 13 | 14 | let res = fib(10n) 15 | -------------------------------------------------------------------------------- /src/test/resources/inlining.check: -------------------------------------------------------------------------------- 1 | res = D 2 | res1 = Int[1] 3 | res2 = Int[2] 4 | res22 = Int[22] 5 | res3 = 9n 6 | 7 | Irreducible pairs: 8 | A <-> C 9 | B <-> B 10 | B <-> B 11 | B <-> B 12 | B <-> B 13 | B <-> B 14 | 
-------------------------------------------------------------------------------- /src/test/resources/inlining.in: -------------------------------------------------------------------------------- 1 | cons A(x) 2 | cons B 3 | cons C(x) 4 | cons D 5 | 6 | match A(x) = B => C(x) = D; A(D) = C(D) 7 | match C(x) = D => x = D 8 | 9 | let A(res) = B 10 | 11 | 12 | cons Int[int] 13 | 14 | def f(_) = r 15 | | Int[i] if [i == 0] => B = B; g(Int[i + 10]) 16 | else => B = B; B = B; g(Int[i + 20]) 17 | 18 | def g(_) = r 19 | | Int[i] if [i == 10] => Int[1] 20 | if [i == 21] => Int[2] 21 | else => Int[i] 22 | 23 | let 24 | res1 = f(Int[0]) 25 | res2 = f(Int[1]) 26 | res22 = f(Int[2]) 27 | 28 | 29 | cons Z 30 | cons S(n) 31 | def f1(_, a) = b 32 | | S(x) => g1(a, x) 33 | def g1(_, a) = b 34 | | S(y) => H1(y, b) = a 35 | cons H1(y, b) = a 36 | match H1(y, b) = S(t) => b = f1(S(t), y) 37 | 38 | 39 | def pred(_) = r 40 | | Z => Z 41 | | S(x) => x 42 | def ack2(_, a) = b 43 | | Z => S(a) 44 | | S(x) => ack2b(a, S(x)) 45 | def ack2b(_, a) = b 46 | | Z => ack2(pred(a), S(Z)) 47 | | S(y), S(t) => (t1, t2) = dup(t); ack2(pred(S(t1)), ack2(S(t2), y)) 48 | let res3 = ack2(2n, 3n) 49 | -------------------------------------------------------------------------------- /src/test/resources/lists.check: -------------------------------------------------------------------------------- 1 | flatMapped = 11n :: 12n :: 21n :: 22n :: 31n :: 32n :: Nil 2 | idMapped = 1n :: 2n :: 3n :: Nil 3 | l0_length = 3n 4 | l0_mapped = 3n :: 4n :: 5n :: Nil 5 | l0_mapped_lambda = 3n :: 4n :: 5n :: Nil 6 | listCons = 1n :: 2n :: 3n :: 4n :: 5n ::: Nil 7 | 8 | Irreducible pairs: 9 | ::: <-> S 10 | 11 | -------------------------------------------------------------------------------- /src/test/resources/lists.in: -------------------------------------------------------------------------------- 1 | # Natural numbers 2 | cons Z 3 | cons S(n) 4 | 5 | # Addition 6 | def _ + y = r 7 | | Z => y 8 | | S(x) => x + S(y) 9 | 10 | # 
Lists 11 | cons Nil 12 | cons head :: tail = l 13 | 14 | def length(list) = r 15 | | Nil => Z 16 | | x :: xs => erase(x); S(length(xs)) 17 | 18 | def map(list, fi, fo) = r 19 | | Nil => erase(fi); erase(fo); Nil 20 | | x :: xs => (x, fi2) = dup(fi) 21 | (fo1, fo2) = dup(fo) 22 | fo1 :: map(xs, fi2, fo2) 23 | 24 | def _ ::: ys = r 25 | | Nil => ys 26 | | x :: xs => x :: (xs ::: ys) 27 | 28 | def flatMap(list, fi, fo) = r 29 | | Nil => erase(fi); erase(fo); Nil 30 | | x :: xs => (x, fi2) = dup(fi) 31 | (fo1, fo2) = dup(fo) 32 | fo1 ::: flatMap(xs, fi2, fo2) 33 | 34 | # Example: List operations 35 | let l0 = 1n :: 2n :: 3n :: Nil 36 | (l0a, l0b) = dup(l0) 37 | l0_length = length(l0a) 38 | l0_mapped = map(l0b, x, x + 2n) 39 | 40 | let listCons = (1n ::2n :: 3n :: Nil) ::: (4n :: 5n ::: Nil) 41 | 42 | let idMapped = map(1n :: 2n :: 3n :: Nil, y, y) 43 | 44 | def mkList(_) = r 45 | | Z => Z :: Z :: Nil 46 | | S(x) => (s1, s2) = dup(S(x)) 47 | s1 + 1n :: s2 + 2n :: Nil 48 | 49 | let flatMapped = 50 | flatMap(10n :: 20n :: 30n :: Nil, x, mkList(x)) 51 | 52 | # Explicit lambdas 53 | cons in |> out 54 | def apply(l, in) = out 55 | | i |> o => in = i; o 56 | 57 | # Example: List mapping with lambdas 58 | def map2(l, f) = r 59 | | Nil => erase(f); Nil 60 | | x :: xs => (f1, f2) = dup(f) 61 | apply(f1, x) :: map2(xs, f2) 62 | 63 | let l0 = 1n :: 2n :: 3n :: Nil 64 | l0_mapped_lambda = map2(l0, x |> x + 2n) 65 | -------------------------------------------------------------------------------- /src/test/resources/par-mult.check: -------------------------------------------------------------------------------- 1 | res = 10000n 2 | resI = Int[10000] 3 | -------------------------------------------------------------------------------- /src/test/resources/par-mult.in: -------------------------------------------------------------------------------- 1 | cons Z 2 | cons S(n) 3 | 4 | def _ + y = r 5 | | Z => y 6 | | S(x) => x + S(y) 7 | 8 | def _ * y = r 9 | | Z => erase(y); Z 10 | | S(x) 
=> (a, b) = dup(y) 11 | r = b + x * a 12 | 13 | let res = 100n * 100n 14 | 15 | 16 | cons Int[int] 17 | def add(_, _) = r 18 | | Int[x], Int[y] if [x == 0] => Int[y] 19 | else => add(Int[x-1], Int[y+1]) 20 | def mult(_, _) = r 21 | | Int[x], Int[y] if [x == 0] => Int[0] 22 | else => add(Int[y], mult(Int[x-1], Int[y])) 23 | let resI = mult(Int[100], Int[100]) 24 | -------------------------------------------------------------------------------- /src/test/resources/seq-def.check: -------------------------------------------------------------------------------- 1 | example_3_plus_5 = 8n 2 | example_3_times_2 = 6n 3 | foo_curry = 1n 4 | -------------------------------------------------------------------------------- /src/test/resources/seq-def.in: -------------------------------------------------------------------------------- 1 | # Natural numbers 2 | cons Z 3 | cons S(n) 4 | 5 | # Erasure and Duplication 6 | # def erase(_) 7 | # def dup(_) = (a, b) 8 | # | dup(_) = (c, d) => (c, d) 9 | 10 | # Addition 11 | def _ + y = r 12 | | Z => y 13 | | S(x) => x + S(y) 14 | 15 | # Separate reductions 16 | #match Z + y => y 17 | #match S(x) + y => x + S(y) 18 | 19 | #match erase(Z) => () 20 | #match erase(S(x)) => erase(x) 21 | #match dup(Z) = (a, b) => a = Z; b = Z 22 | #match dup(S(x)) = (a, b) => (sa, sb) = dup(x); s = S(sa); b = S(sb) 23 | #match dup(dup(_) = (c, d)) = (a, b) => (c, d) 24 | #match S(x) = S(y) => x = y 25 | 26 | # Multiplication 27 | def _ * y = r 28 | # | Z => erase(y), Z 29 | # | S(x) => (y1, y2) = dup(y); x * y1 + y2 30 | 31 | match Z * y => erase(y); Z 32 | match S(x) * y => (y1, y2) = dup(y); x * y1 + y2 33 | 34 | # Example: Computations on church numerals 35 | let y = 5n 36 | x = 3n 37 | example_3_plus_5 = x + y 38 | 39 | let example_3_times_2 = 3n * 2n 40 | 41 | # Currying on additional args: 42 | def foo(a, b) = r 43 | | Z => erase(b); Z 44 | | S(x), S(y) => x + y 45 | let foo_curry = foo(1n, 2n) 46 | 
--------------------------------------------------------------------------------
/src/test/scala/BitOpsTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Assert._
4 | import org.junit.Test
5 |
6 | import BitOps._
7 |
8 | /** Round-trip tests: each value packed into one byte lane by `checkedIntOfBytes` must be read back unchanged by `byte0`..`byte3`. */
9 | class BitOpsTest {
10 |
11 |   // Signed byte values: both extremes, their neighbours, and the values around 0.
12 |   val vals = Seq(-128, -127, -1, 0, 1, 126, 127)
13 |
14 |   @Test
15 |   def testByte0(): Unit = {
16 |     for {
17 |       v0 <- vals
18 |     } {
19 |       val i = checkedIntOfBytes(v0, 0, 0, 0)
20 |       assertEquals(v0, byte0(i))
21 |     }
22 |   }
23 |
24 |   @Test
25 |   def testByte1(): Unit = {
26 |     for {
27 |       v1 <- vals
28 |     } {
29 |       val i = checkedIntOfBytes(0, v1, 0, 0)
30 |       assertEquals(v1, byte1(i))
31 |     }
32 |   }
33 |
34 |   @Test
35 |   def testByte2(): Unit = {
36 |     for {
37 |       v2 <- vals
38 |     } {
39 |       val i = checkedIntOfBytes(0, 0, v2, 0)
40 |       assertEquals(v2, byte2(i))
41 |     }
42 |   }
43 |
44 |   @Test
45 |   def testByte3(): Unit = {
46 |     for {
47 |       v3 <- vals
48 |     } {
49 |       val i = checkedIntOfBytes(0, 0, 0, v3)
50 |       assertEquals(v3, byte3(i))
51 |     }
52 |   }
53 |
54 |   @Test
55 |   def testMixed(): Unit = {
56 |     // Every combination of lane values (7^4 cases) must decode independently of its neighbours.
57 |     for {
58 |       v0 <- vals
59 |       v1 <- vals
60 |       v2 <- vals
61 |       v3 <- vals
62 |     } {
63 |       val i = checkedIntOfBytes(v0, v1, v2, v3)
64 |       assertEquals(v0, byte0(i))
65 |       assertEquals(v1, byte1(i))
66 |       assertEquals(v2, byte2(i))
67 |       assertEquals(v3, byte3(i))
68 |     }
69 |   }
70 | }
71 |
--------------------------------------------------------------------------------
/src/test/scala/LongBitOpsTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Assert._
4 | import org.junit.Test
5 |
6 | import LongBitOps._
7 |
8 | class LongBitOpsTest {
9 |
10 |   val vals = Seq(Short.MinValue, Short.MinValue+1, -128, -127, -1, 0, 1, 126, 127, Short.MaxValue-1, Short.MaxValue)
11 |
12 |   @Test
13 |   def testShort0(): Unit = {
14 |     for {
15 |       v0 <- vals
16 |     } {
17 |       val i = checkedLongOfShorts(v0, 0, 0, 0)
18 |       assertEquals(v0, short0(i))
19 |     }
20 |   }
21 |
22 |   @Test
23 |   def testShort1(): Unit = {
24 |     for {
25 |       v1 <- vals
26 |     } {
27 |       val i = checkedLongOfShorts(0, v1, 0, 0)
28 |       assertEquals(v1, short1(i))
29 |     }
30 |   }
31 |
32 |   @Test
33 |   def testShort2(): Unit = {
34 |     for {
35 |       v2 <- vals
36 |     } {
37 |       val i = checkedLongOfShorts(0, 0, v2, 0)
38 |       assertEquals(v2, short2(i))
39 |     }
40 |   }
41 |
42 |   @Test
43 |   def testShort3(): Unit = {
44 |     for {
45 |       v3 <- vals
46 |     } {
47 |       val i = checkedLongOfShorts(0, 0, 0, v3)
48 |       assertEquals(v3, short3(i))
49 |     }
50 |   }
51 |
52 |   @Test
53 |   def testMixed(): Unit = {
54 |     for {
55 |       v0 <- vals
56 |       v1 <- vals
57 |       v2 <- vals
58 |       v3 <- vals
59 |     } {
60 |       val i = checkedLongOfShorts(v0, v1, v2, v3)
61 |       assertEquals(v0, short0(i))
62 |       assertEquals(v1, short1(i))
63 |       assertEquals(v2, short2(i))
64 |       assertEquals(v3, short3(i))
65 |     }
66 |   }
67 | }
68 |
--------------------------------------------------------------------------------
/src/test/scala/MainTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Test
4 | import org.junit.runner.RunWith
5 | import org.junit.runners.Parameterized
6 | import org.junit.runners.Parameterized.Parameters
7 |
8 | import java.nio.file.Path
9 | import scala.jdk.CollectionConverters._
10 |
11 | /** Runs the programs in `src/test/resources` through [[TestUtils.check]] on each backend spec listed in [[MainTest.interpreters]]. */
12 | @RunWith(classOf[Parameterized])
13 | class MainTest(spec: String) {
14 |   // Repeat multiplier: 0 runs each test once, otherwise a test runs SCALE * scaleFactor times.
15 |   val SCALE = 0
16 |   val conf = Config.defaultConfig.withSpec(spec).copy(showAfter = Set(""), phaseLog = Set(""),
17 |     //writeOutput = Some(Path.of("bench/gen-classes")), writeJava = Some(Path.of("bench/gen-src"))
18 |   )
19 |
20 |   /** Runs [[TestUtils.check]] for `testName`, repeated according to `SCALE` and `scaleFactor`. */
21 |   def check(testName: String, scaleFactor: Int = 1, expectedSteps: Int = -1, fail: Boolean = false, config: Config = conf): Unit =
22 |     for(_ <- 1 to (if(SCALE == 0) 1 else SCALE * scaleFactor)) TestUtils.check(testName, expectedSteps, fail, config)
23 |
24 |   @Test def testSeqDef: Unit = check("seq-def", scaleFactor = 50, expectedSteps = 32)
25 |   @Test def testLists: Unit = check("lists")
26 |   @Test def testParMult: Unit = check("par-mult")
27 |   // testInlining and testAck expect different step counts depending on conf.backend.allowPayloadTemp.
28 |   @Test def testInlining: Unit = check("inlining", expectedSteps = if(conf.backend.allowPayloadTemp) 99 else 102)
29 |   @Test def testFib: Unit = check("fib")
30 |   @Test def testEmbedded: Unit = check("embedded")
31 |   @Test def testAck: Unit = check("ack", expectedSteps = if(conf.backend.allowPayloadTemp) 18114077 else 23686073)
32 |   @Test def testDiverging: Unit = check("diverging", fail = true)
33 | }
34 |
35 | object MainTest {
36 |   @Parameters(name = "{0}")
37 |   def interpreters = Seq(
38 |     "sti", "stc1", "stc2",
39 |     //"mt0.i", "mt1.i", "mt8.i",
40 |     //"mt1000.i", "mt1001.i", "mt1008.i",
41 |     //"mt0.c", "mt1.c", "mt8.c",
42 |     //"mt1000.c", "mt1001.c", "mt1008.c",
43 |   ).map(s => Array[AnyRef](s)).asJava
44 |
45 |   // used by ack.in:
46 |   def is0(i: java.lang.Integer): Boolean = i.intValue() == 0
47 |   def box(i: Int): java.lang.Integer = Integer.valueOf(i)
48 |   def inc(i: java.lang.Integer): java.lang.Integer = box(i.intValue() + 1)
49 |   def dec(i: java.lang.Integer): java.lang.Integer = box(i.intValue() - 1)
50 |   def ackHelper(i: java.lang.Integer, o1: RefOutput, o2: RefOutput): Unit = {
51 |     o1.setValue(dec(i))
52 |     o2.setValue(i)
53 |   }
54 | }
55 |
--------------------------------------------------------------------------------
/src/test/scala/TestUtils.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Assert
4 | import org.junit.Assert._
5 |
6 | import java.io.{ByteArrayOutputStream, PrintStream}
7 | import java.nio.charset.StandardCharsets
8 | import java.nio.file.{Files, Path}
9 | import java.util.concurrent.atomic.AtomicInteger
10 |
11 | object TestUtils {
12 |   /** Compiles and reduces `src/test/resources/<testName>.in`, then compares the analyzer output (or, if
13 |     * compilation throws a CompilerResult, its color-stripped message) against `<testName>.check` when present.
14 |     *
15 |     * @param expectedSteps if >= 0 and compilation succeeded, the exact number of reduction steps expected
16 |     * @param fail          whether compilation is expected to fail with a CompilerResult
17 |     * @param config        compiler configuration; `collectStats` is always forced on
18 |     */
19 |   def check(testName: String, expectedSteps: Int = -1, fail: Boolean = false, config: Config = Config.defaultConfig): Unit = {
20 |     val basePath = s"src/test/resources/$testName"
21 |     val statements = Parser.parse(Path.of(basePath+".in"))
22 |     val (result, success, steps) = try {
23 |       val model = new Compiler(statements, config.copy(collectStats = true))
24 |       val inter = model.createInterpreter()
25 |       // NOTE(review): the interpreter is never disposed; if its type exposes dispose() (the stc2 one does),
26 |       // calling it in a finally block would release its allocator between repeated runs. Confirm first.
27 |       inter.initData()
28 |       inter.reduce()
29 |       // Read the metrics once and guard every use, so a backend without metrics yields a readable
30 |       // step-count mismatch instead of a NullPointerException.
31 |       val metrics = inter.getMetrics
32 |       if(metrics != null) metrics.logStats()
33 |       val out = new ByteArrayOutputStream()
34 |       inter.getAnalyzer.log(new PrintStream(out, true, StandardCharsets.UTF_8), color = false)
35 |       (out.toString(StandardCharsets.UTF_8), true, if(metrics != null) metrics.getSteps else -1)
36 |     } catch {
37 |       case ex: CompilerResult => (Colors.stripColors(ex.getMessage).replace("src\\test\\resources\\", "src/test/resources/"), false, -1)
38 |     }
39 |     val checkFile = Path.of(basePath+".check")
40 |     if(Files.exists(checkFile)) {
41 |       val check = Files.readString(checkFile, StandardCharsets.UTF_8)
42 |       // Compare modulo surrounding whitespace and CR characters, so CRLF checkouts still match.
43 |       val expected = normalize(check)
44 |       val actual = normalize(result)
45 |       if(expected != actual) {
46 |         println("---- Expected ----")
47 |         println(check)
48 |         println("---- Actual ----")
49 |         println(result)
50 |         println("---- End ----")
51 |         assertEquals(expected, actual)
52 |       }
53 |     }
54 |     if(fail && success) Assert.fail("Failure expected")
55 |     else if(!fail && !success) Assert.fail("Unexpected failure")
56 |     if(expectedSteps >= 0 && success) assertEquals(s"Expected $expectedSteps steps, but fully reduced after $steps", expectedSteps, steps)
57 |   }
58 |
59 |   /** Trims `s` and removes every carriage return (a literal replacement; no regex involved). */
60 |   private def normalize(s: String): String = s.trim.replace("\r", "")
61 | }
62 |
63 | class ManagedDummy(name: String) extends LifecycleManaged {
64 |   private[this] val copied, erased = new AtomicInteger()
65 |   override def toString: String = s"$name:$copied:$erased"
66 |   override def erase(): Unit = erased.incrementAndGet()
67 |   override def copy(): LifecycleManaged = { copied.incrementAndGet(); this }
68 | }
69 | object ManagedDummy {
70 |   def create(name: String, res: RefOutput): Unit = res.setValue(new ManagedDummy(name))
71 | }
72 |
--------------------------------------------------------------------------------
/src/test/scala/WorkersTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.mt
2 |
3 | import de.szeiger.interact.mt.workers.{Worker, Workers}
4 | import org.junit.Assert._
5 | import org.junit.Test
6 |
7 | import java.util.concurrent.atomic.AtomicInteger
8 |
9 | /** Checks that Workers runs every queued task, including sub-tasks that a worker adds while processing. */
10 | class WorkersTest {
11 |
12 |   @Test
13 |   def test1(): Unit = {
14 |     // Tasks added but not yet fully processed; the test thread waits for this to return to 0.
15 |     val unfinished = new AtomicInteger(0)
16 |     // Monitor for waiting on `unfinished`. Unlike a one-shot CountDownLatch it can be waited on
17 |     // repeatedly: awaitIdle() blocks until the count is 0 no matter how often it was 0 before.
18 |     val idle = new Object
19 |     val count = new AtomicInteger(0)
20 |     case class Task(sub: Task*)
21 |     def add(t: Task): Unit = {
22 |       unfinished.incrementAndGet()
23 |       ws.add(t)
24 |     }
25 |     // Blocks until every added task and all of its sub-tasks have been processed.
26 |     def awaitIdle(): Unit = idle.synchronized {
27 |       while(unfinished.get() != 0) idle.wait()
28 |     }
29 |     class Processor extends Worker[Task] {
30 |       def apply(t: Task): Unit = {
31 |         count.incrementAndGet()
32 |         t.sub.foreach { t2 =>
33 |           Thread.sleep(50)
34 |           unfinished.incrementAndGet()
35 |           this.add(t2)
36 |         }
37 |         // Notify while holding the monitor so the wake-up cannot fall between awaitIdle's check and its wait().
38 |         if(unfinished.decrementAndGet() == 0) idle.synchronized(idle.notifyAll())
39 |       }
40 |     }
41 |     lazy val ws = new Workers[Task](8, 1024, _ => new Processor)
42 |     add(Task())
43 |     ws.start()
44 |     // 1 task added above plus the 10 tasks in this tree.
45 |     add(Task(Task(Task(), Task(Task(), Task())), Task(Task(), Task(), Task())))
46 |     awaitIdle()
47 |     assertEquals(11, count.get())
48 |     // Second batch: 10 more tasks, submitted after the pool has already been idle once.
49 |     add(Task(Task(Task(), Task(Task(), Task())), Task(Task(), Task(), Task())))
50 |     awaitIdle()
51 |     assertEquals(21, count.get())
52 |     ws.shutdown()
53 |   }
54 | }
55 |
--------------------------------------------------------------------------------