├── .gitignore
├── README.md
├── bench
└── src
│ └── main
│ └── scala
│ ├── ComparisonBenchmark.scala
│ ├── FFMBenchmark.scala
│ └── InterpreterBenchmark.scala
├── build.sbt
├── example.in
├── hs_err_pid31620.log
├── hs_err_pid32029.log
├── project
├── build.properties
└── plugins.sbt
├── run.sh
└── src
├── main
├── java
│ ├── Foo.java
│ └── de
│ │ └── szeiger
│ │ └── interact
│ │ ├── codegen
│ │ └── LocalClassLoader.java
│ │ └── stc2
│ │ └── MetaClass.java
└── scala
│ ├── Analyzer.scala
│ ├── BitOps.scala
│ ├── CleanEmbedded.scala
│ ├── Colors.scala
│ ├── Compiler.scala
│ ├── CreateWiring.scala
│ ├── Curry.scala
│ ├── Debug.scala
│ ├── ExecutionMetrics.scala
│ ├── ExpandRules.scala
│ ├── Global.scala
│ ├── Inline.scala
│ ├── Main.scala
│ ├── NormalizeCondition.scala
│ ├── Parser.scala
│ ├── PlanRules.scala
│ ├── Prepare.scala
│ ├── ResolveEmbedded.scala
│ ├── Runtime.scala
│ ├── SymCounts.scala
│ ├── ast
│ ├── AST.scala
│ ├── Symbols.scala
│ └── Transform.scala
│ ├── codegen
│ ├── AbstractCodeGen.scala
│ ├── BoxOps.scala
│ ├── ClassWriter.scala
│ ├── ParSupport.scala
│ └── dsl
│ │ ├── Acc.scala
│ │ ├── DSL.scala
│ │ ├── Desc.scala
│ │ └── TypedDSL.scala
│ ├── mt
│ ├── CodeGen.scala
│ ├── Interpreter.scala
│ └── workers
│ │ └── Workers.scala
│ ├── offheap
│ ├── Allocator.scala
│ └── MemoryDebugger.scala
│ ├── stc1
│ ├── CodeGen.scala
│ ├── GenStaticReduce.scala
│ ├── Interpreter.scala
│ └── PTOps.scala
│ ├── stc2
│ ├── CodeGen.scala
│ ├── GenStaticReduce.scala
│ ├── Interpreter.scala
│ └── PTOps.scala
│ └── sti
│ └── Interpreter.scala
└── test
├── resources
├── ack.check
├── ack.in
├── diverging.check
├── diverging.in
├── embedded.check
├── embedded.in
├── fib.check
├── fib.in
├── inlining.check
├── inlining.in
├── lists.check
├── lists.in
├── par-mult.check
├── par-mult.in
├── seq-def.check
└── seq-def.in
└── scala
├── BitOpsTest.scala
├── LongBitOpsTest.scala
├── MainTest.scala
├── TestUtils.scala
└── WorkersTest.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | .bsp
3 | .idea
4 | bench/gen-classes
5 | bench/gen-src
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Language, interpreter and compiler for interaction nets
2 |
3 | The language implements interaction nets and combinators as defined in https://core.ac.uk/download/pdf/81113716.pdf.
4 |
5 | ## Running
6 |
7 | Launch `sbt` and then `runMain Main example.in` in the sbt shell to run [example.in](./example.in) using a single-threaded interpreter.
8 |
9 | Launch `sbt` and then `runMain Debug example.in` in the sbt shell to run [example.in](./example.in) using the step-by-step debugger.
10 |
11 | ## Language
12 |
13 | There are 4 basic statements: `cons` defines a data constructor, `def` defines a function (with optional rules), `match` defines a detached rule, and `let` creates a part of the net that should be evaluated. Comments start with `#` and extend until the end of the line. Expression lists use Haskell-style significant indentation.
14 |
15 | ### Constructors
16 |
17 | Constructors are written as `Id(aux_1, aux_2, ..., aux_n) = principal`. The parentheses are optional for arity 0. The prinpical port is only used for documentation purposes and can always be omitted. For example:
18 |
19 | ```
20 | # Natural numbers
21 | cons Z
22 | cons S(n)
23 | ```
24 |
25 | ### Functions
26 |
27 | Semantically there is no difference between data and code, and it is possible to define and use functions with the normal constructor syntax and detached rules (or conversely use function syntax for data). A separate syntax exists because it makes sense for data to have the principal port be the return value of the expression, but it is more useful for functions to have the principal port as a parameter:
28 |
29 | ```
30 | def add(x, y) = r
31 | ```
32 |
33 | The first argument of the function is always assigned to the principal port while the rest of the arguments and return value(s) make up the auxiliary ports. Note that return values must always be specified for functions because their arity is not fixed (like it is for constructors). Tuple syntax is used for multiple return values, e.g.:
34 |
35 | ```
36 | def dup(x) = (a, b)
37 | ```
38 |
39 | In order to define rules together with the function, the parameters and return values must be named. The first argument (i.e. the principal port) can be always omitted by using a `_` wildcard:
40 |
41 | ```
42 | def add(_, y) = r
43 | ```
44 |
45 | Other parameters can be omitted if the rules don't need them (because they use curried patterns or are defined as detached rules).
46 |
47 | ### Data
48 |
49 | The interaction nets that should be reduced by the interpreter are defined using `let` statements:
50 |
51 | ```
52 | let two = S(S(Z))
53 | y = S(S(S(S(S(Z)))))
54 | x = S(S(S(Z)))
55 | example_3_plus_5 = add(x, y)
56 | ```
57 |
58 | The body of a `let` statement contains a block of expressions consisting of assignments and nested function / constructor applications. Tuples can be used to group individual values, but they cannot be nested. Variables that are used exactly once are defined as free wires. Additional temporary variables can be introduced, but they must be used exactly twice. The order of expressions and the direction of assignments is irrelevant.
59 |
60 | The syntax for expression blocks uses Haskell-style layout: Additional expressions must start at the same indentation level as the first one. Lines with larger indentation continue the current expression. Multiple exressions on the same line must be separated by `;`. A trailing `;` at the end of the line is optional.
61 |
62 | ### Operators
63 |
64 | Both, constructors and functions, can be written as symbolic binary operators. An operator consists of an arbitrary combination of the symbols `*/%+-:<>&^|`. Precedence is based on the first character (the same as in Scala), operators ending with `:` are right-associative, all others left-associative (again, like in Scala). All operators in a chain of same-precedence operations must have the same associativity. Operator definitions are written in infix syntax:
65 |
66 | ```
67 | # Linked lists
68 | cons Nil
69 | cons head :: tail
70 |
71 | def _ + y = r
72 | ```
73 |
74 | The same infix notation is used for applying them in expressions:
75 |
76 | ```
77 | let example_3_plus_2 = S(S(S(Z))) + S(S(Z))
78 | ```
79 |
80 | ### Rules
81 |
82 | Reduction rules for functions can be specified together with the definition using a pattern matching syntax which matches on the first argument:
83 |
84 | ```
85 | # Addition on natural numbers
86 | def _ + y = r
87 | | Z => y
88 | | S(x) => r = x + S(y)
89 | ```
90 |
91 | The right-hand side contains an expression list, similar to a `let` clause. All function/constructor arguments and return values (except the principal port on each side of the match) must be used exactly once in the reduction.
92 |
93 | ```
94 | def _ * y = r
95 | | Z => erase(y)
96 | Z = r
97 | | S(x) => (y1, y2) = dup(y)
98 | x * y1 + y2 = r
99 | ```
100 |
101 | If the last expression in the block is missing an assignment, it is implicitly assigned to the return value of the function:
102 |
103 | ```
104 | def _ + y = r
105 | | Z => y # same as y = r
106 | | S(x) => x + S(y) # same as x + S(y) = r
107 | ```
108 |
109 | The standard `dup` and `erase` functions are pre-defined, and combinators with all user-defined constructors and functions are derived automatically. The pre-defined functions are equivalent to the following syntax:
110 |
111 | ```
112 | def erase(_) = ()
113 | def dup[label l](x) = (x1, x2)
114 | ```
115 |
116 | When matching on another function instead of a constructor, a `_` wildcard must be used to mark the first argument (i.e. the principal port) as the designated return value of an assignment expression. The wildcard always expands to the return value of the nearest enclosing assignment. For example:
117 |
118 | ```
119 | def dup(_) = (a, b)
120 | | dup(_) = (c, d) => (c, d)
121 | ```
122 |
123 | When matching on a function with no return value (like `erase`), an assignment to an empty tuple can be used to correctly expand the wildcard:
124 |
125 | ```
126 | def erase(_)
127 | | erase(_) = () => ()
128 | ```
129 |
130 | ### Currying
131 |
132 | It is possible to use both, nested patterns and additional matches on auxiliary ports, to define curried functions, corresponding to currying on the left-hand side and right-hand side of a match. For example:
133 |
134 | ```
135 | def fib(_) = r
136 | | Z => 1n
137 | | S(Z) => 1n
138 | | S(S(n)) => (n1, n2) = dup(n)
139 | fib(S(n1)) + fib(n2)
140 | ```
141 |
142 | This expands to a definition similar to this one (modulo the generated name of the curried function):
143 |
144 | ```
145 | def fib(_) = r
146 | | Z => 1n
147 | | S(n) => fib2(n)
148 |
149 | def fib2(_) = r
150 | | Z => 1n
151 | | S(n) => (n1, n2) = dup(n)
152 | fib(S(n1)) + fib(n2)
153 | ```
154 |
155 | Matching on auxiliary ports is done by specifying a comma-separated list in a `def` rule. In the following example `b` is matched by `S(y)` in the second rule after successfully matching `a` with `S(x)`:
156 |
157 | ```
158 | def foo(a, b) = r
159 | | Z => erase(b), Z
160 | | S(x), S(y) => x + y
161 | ```
162 |
163 | Restrictions on curried rules:
164 | - All additional matches must be done on the same port of the original match.
165 | - Nested matches must not conflict with another match at the outer layer (e.g. you cannot match on both `f(S(x))` and `f(S(S(x)))`).
166 |
167 | ### Detached Rules
168 |
169 | A rule can be defined independently of a function definition using a `match` statement. These rules can also be defined for `cons`-style constructors (which do not have a special rule syntax like `def`). The expression on the left-hand side is interpreted as a pattern which must correspond to two cells connected via their principal ports. For example:
170 |
171 | ```
172 | match add(Z, y) => y
173 | match add(S(x), y) => add(x, S(y))
174 | ```
175 |
176 | A combination of two constructors can be matched with an assignment, e.g.:
177 |
178 | ```
179 | # Assimilation rules for S and dup
180 | match S(x) = S(y) => x = y
181 | match dup(dup(_) = (c, d)) = (a, b) => (c, d)
182 | ```
183 |
184 | Currying works the same as in rules attached to a `def` statement.
185 |
186 | ### Natural Numbers
187 |
188 | There is syntactic support for parsing and printing natural numbers, e.g.:
189 |
190 | ```
191 | let example_3_times_2 = mult(3n, 2n)
192 | ```
193 |
194 | The snippet expands to:
195 |
196 | ```
197 | let example_3_times_2 = mult(S(S(S(Z))), S(S(Z)))
198 | ```
199 |
200 | This assumes that you have suitable definitions of `Z` and `S` like:
201 |
202 | ```
203 | cons Z
204 | cons S(n)
205 | ```
206 |
207 | ### Embedded Values
208 |
209 | A cell (`cons` or `def`) can optionally contain a primitive JVM value of type `int`, `ref` (`java.lang.Object`) or `label`. The type is placed in square brackets after the constructor name:
210 |
211 | ```
212 | cons Int[int]
213 | cons String[ref]
214 | ```
215 |
216 | Any match on such a constructor or its use in an expression must associate a variable name with the embedded value. These embedded variables share the same scope as regular variables but cannot be used interchangeably. If the same variable is used in a match and in the expansion, the value is automatically moved:
217 |
218 | ```
219 | def _ + y = r
220 | | Int[i] => intAdd[i](y)
221 | ```
222 |
223 | `int` and `label` values can also be copied or deleted implicitly by using the variable assigned to them in the match more or less than once in the expansion. `ref` values can only be moved implicitly.
224 |
225 | A static JVM method (or a method in a Scala object) can be invoked to perform a computation on embedded values by calling the method with its fully qualified name in an embedded expression in square brackets:
226 | ```
227 | def intAdd[int a](_) = r
228 | | Int[b] => [de.szeiger.interact.Runtime.add(a, b, c)]
229 | Int[c]
230 | ```
231 |
232 | Input parameters (corresponding to variables used in the match) must be of type `int` or a reference type, output parameters (corresponding to variables used in the expansion) must be of type `IntOutput` or `RefOutput`:
233 |
234 | ```
235 | // Implementation in Scala:
236 | object Runtime {
237 | def add(a: Int, b: Int, res: IntOutput): Unit =
238 | res.setValue(a + b)
239 | }
240 | ```
241 |
242 | It is up to the implementation of such a method to handle copying and deleting in an appropriate way. Since embedded `ref` values can also be copied and deleted by the `dup` and `erase` functions, the values may implement the `LifecycleManaged` interface to implement these methods. Otherwise, references will be shared or dropped from scope as usual on the JVM.
243 |
244 | Values of type `label` can be created implicitly by using the same variable in one or more cells without an expression to create it:
245 |
246 | ```
247 | let (x1, x2) = dup[label1](x) # dup(x) and dup(y) get the same label
248 | (y1, y2) = dup[label1](y)
249 | (z1, z2) = dup[label2](z) # dup(z) gets a different label
250 | (a1, a1) = dup(a) # dup(a) gets the default label
251 | ```
252 |
253 | All labels created in this way are unique per rule reduction or `let` statement.
254 |
255 | An implementation method may directly return a value instead of using an `IntOutput` / `RefOutput` output parameter. Such a method can be used directly as an embedded computation of a cell:
256 |
257 | ```
258 | def strlen(_) = r
259 | | String[s] => Int[de.szeiger.interact.Runtime.strlen(s)]
260 | ```
261 |
262 | Int literals and some basic operators are available in embedded computations, e.g.:
263 |
264 | ```
265 | let x = Int[3 + 5]
266 | ```
267 |
268 | Embedded values can be checked in a match to select different expansions for a rule depending on the embedded values:
269 |
270 | ```
271 | def fib(_) = r
272 | | Int[i] if [i == 0] => Int[1]
273 | if [i == 1] => Int[1]
274 | else => fib(Int[i-1]) + fib(Int[i-2])
275 | ```
276 |
277 | The `else` clause is required, all branches must be defined together, and the order of definition is significant.
278 |
279 | Current implementation limitations:
280 | - Currying is not allowed when both sides of the match contain embedded values.
281 |
--------------------------------------------------------------------------------
/bench/src/main/scala/ComparisonBenchmark.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.openjdk.jmh.annotations._
4 | import org.openjdk.jmh.infra._
5 |
6 | import java.util.concurrent.TimeUnit
7 |
8 | @BenchmarkMode(Array(Mode.Throughput))
9 | @Fork(value = 1, jvmArgsAppend = Array("-Xmx12g", "-Xss32M", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseZGC"))
10 | @Threads(1)
11 | @Warmup(iterations = 15, time = 1)
12 | @Measurement(iterations = 15, time = 1)
13 | @OutputTimeUnit(TimeUnit.SECONDS)
14 | @State(Scope.Benchmark)
15 | class ComparisonBenchmark {
16 |
17 | @Param(Array("stc2"))
18 | var spec: String = _
19 |
20 | private val prelude =
21 | """cons Z
22 | |cons S(n)
23 | |def add(_, y) = r
24 | | | Z => y
25 | | | S(x) => add(x, S(y))
26 | |""".stripMargin
27 |
28 | private val mult1Src = prelude +
29 | """def mult(_, y) = r
30 | | | Z => erase(y); Z
31 | | | S(x) => (y1, y2) = dup(y); add(mult(x, y1), y2)
32 | |let res = mult(100n, 100n)
33 | |""".stripMargin
34 |
35 | private val fib22Src = prelude +
36 | """def add2(_, y) = r
37 | | | Z => y
38 | | | S(x) => S(add2(x, y))
39 | |def fib(_) = r
40 | | | Z => 1n
41 | | | S(n) => fib2(n)
42 | |def fib2(_) = r
43 | | | Z => 1n
44 | | | S(n) => (n1, n2) = dup(n); add2(fib(S(n1)), fib(n2))
45 | |let res = fib(22n)
46 | |""".stripMargin
47 |
48 | private val intAck38Src =
49 | """cons Int[int]
50 | |
51 | |def ackU(a, b) = r
52 | | | Int[x], Int[y]
53 | | if [x == 0] => Int[y + 1]
54 | | if [y == 0] => ackU(Int[x - 1], Int[1])
55 | | else => ackU(Int[x - 1], ackU(Int[x], Int[y - 1]))
56 | |
57 | |let resU = ackU(Int[3], Int[9])
58 | |""".stripMargin
59 |
60 | private val intMult3Src =
61 | """cons Int[int]
62 | |def add(_, _) = r
63 | | | Int[x], Int[y] if [x == 0] => Int[y]
64 | | else => add(Int[x-1], Int[y+1])
65 | |def mult(_, _) = r
66 | | | Int[x], Int[y] if [x == 0] => Int[0]
67 | | else => add(Int[y], mult(Int[x-1], Int[y]))
68 | |let res = mult(Int[1000], Int[1000])
69 | |""".stripMargin
70 |
71 | class PreparedInterpreter(source: String) {
72 | val model: Compiler = new Compiler(Parser.parse(source), Config(spec))
73 | val inter = model.createInterpreter()
74 | def setup(): BaseInterpreter = {
75 | inter.initData()
76 | inter
77 | }
78 | }
79 |
80 | private lazy val mult1Inter: PreparedInterpreter = new PreparedInterpreter(mult1Src)
81 | private lazy val fib22Inter: PreparedInterpreter = new PreparedInterpreter(fib22Src)
82 | private lazy val intAck38Inter: PreparedInterpreter = new PreparedInterpreter(intAck38Src)
83 | private lazy val intMult3Inter: PreparedInterpreter = new PreparedInterpreter(intMult3Src)
84 |
85 | @Benchmark
86 | def mult1(bh: Blackhole): Unit =
87 | bh.consume(mult1Inter.setup().reduce())
88 |
89 | @Benchmark
90 | def fib22(bh: Blackhole): Unit =
91 | bh.consume(fib22Inter.setup().reduce())
92 |
93 | @Benchmark
94 | def intAck39(bh: Blackhole): Unit =
95 | bh.consume(intAck38Inter.setup().reduce())
96 |
97 | @Benchmark
98 | def intMult3(bh: Blackhole): Unit =
99 | bh.consume(intAck38Inter.setup().reduce())
100 |
101 | @Benchmark
102 | def mult1Scala(bh: Blackhole): Unit = {
103 | sealed abstract class Nat
104 | case object Z extends Nat
105 | case class S(pred: Nat) extends Nat
106 | def nat(i: Int): Nat = i match {
107 | case 0 => Z
108 | case n => S(nat(n-1))
109 | }
110 | def add(x: Nat, y: Nat): Nat = x match {
111 | case Z => y
112 | case S(x) => add(x, S(y))
113 | }
114 | def mult(x: Nat, y: Nat): Nat = x match {
115 | case Z => Z
116 | case S(x) => add(mult(x, y), y)
117 | }
118 | bh.consume(mult(nat(100), nat(100)))
119 | }
120 |
121 | @Benchmark
122 | def fib22Scala(bh: Blackhole): Unit = {
123 | sealed abstract class Nat
124 | case object Z extends Nat
125 | case class S(pred: Nat) extends Nat
126 | def nat(i: Int): Nat = i match {
127 | case 0 => Z
128 | case n => S(nat(n-1))
129 | }
130 | def add2(x: Nat, y: Nat): Nat = x match {
131 | case Z => y
132 | case S(x) => S(add2(x, y))
133 | }
134 | def fib(x: Nat): Nat = x match {
135 | case Z => nat(1)
136 | case S(n) => fib2(n)
137 | }
138 | def fib2(x: Nat): Nat = x match {
139 | case Z => nat(1)
140 | case S(n) => add2(fib(S(n)), fib(n))
141 | }
142 | bh.consume(fib(nat(22)))
143 | }
144 |
145 | @Benchmark
146 | def intAck39Scala(bh: Blackhole): Unit = {
147 | def ack(x: Int, y: Int): Int =
148 | if(x == 0) y + 1
149 | else if(y == 0) ack(x-1, 1)
150 | else ack(x-1, ack(x, y-1))
151 | bh.consume(ack(3, 9))
152 | }
153 |
154 | @Benchmark
155 | def intMult3Scala(bh: Blackhole): Unit = {
156 | def add(x: Int, y: Int): Int =
157 | if(x == 0) y
158 | else add(x-1, y+1)
159 | def mult(x: Int, y: Int): Int =
160 | if(x == 0) 0
161 | else add(y, mult(x-1, y))
162 | bh.consume(mult(1000, 1000))
163 | }
164 | }
165 |
--------------------------------------------------------------------------------
/bench/src/main/scala/FFMBenchmark.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.offheap.{Allocator, ArenaAllocator}
4 | import org.openjdk.jmh.annotations._
5 | import org.openjdk.jmh.infra._
6 |
7 | //import java.lang.foreign.MemoryLayout.PathElement
8 | import java.util.concurrent.TimeUnit
9 | //import java.lang.foreign._
10 | import java.nio.ByteBuffer
11 |
12 | @BenchmarkMode(Array(Mode.AverageTime))
13 | @Fork(value = 1, jvmArgsAppend = Array("-Xmx12g", "-Xss32M", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseZGC", "--enable-native-access=ALL-UNNAMED"))
14 | @Threads(1)
15 | @Warmup(iterations = 10, time = 1)
16 | @Measurement(iterations = 10, time = 1)
17 | @OutputTimeUnit(TimeUnit.MILLISECONDS)
18 | @State(Scope.Benchmark)
19 | class FFMBenchmark {
20 | val u = {
21 | val f = classOf[sun.misc.Unsafe].getDeclaredField("theUnsafe")
22 | f.setAccessible(true)
23 | f.get(null).asInstanceOf[sun.misc.Unsafe]
24 | }
25 |
26 | println()
27 | println("heap: " + heap0())
28 | // println("heapByteByffer: " + heapByteBuffer0())
29 | //println("autoArena: " + autoArena0())
30 | // println("customAllocator: " + customAllocator0())
31 | // println("customAllocatorDirect: " + customAllocatorDirect0())
32 | //println("ffmByteBuffer: " + ffmByteBuffer0())
33 | //println("unsafe: " + unsafe0())
34 | println("unsafeCustom: " + unsafeCustom0())
35 | println("unsafeAllocator: " + unsafeAllocator0())
36 |
37 | @Benchmark
38 | def heap(bh: Blackhole): Unit = bh.consume(heap0())
39 |
40 | // @Benchmark
41 | // def heapByteBuffer(bh: Blackhole): Unit = bh.consume(heapByteBuffer0())
42 |
43 | //@Benchmark
44 | //def autoArena(bh: Blackhole): Unit = bh.consume(autoArena0())
45 |
46 | // @Benchmark
47 | // def customAllocator(bh: Blackhole): Unit = bh.consume(customAllocator0())
48 | //
49 | // @Benchmark
50 | // def customAllocatorDirect(bh: Blackhole): Unit = bh.consume(customAllocatorDirect0())
51 |
52 | //@Benchmark
53 | //def ffmByteBuffer(bh: Blackhole): Unit = bh.consume(ffmByteBuffer0())
54 |
55 | //@Benchmark
56 | //def unsafe(bh: Blackhole): Unit = bh.consume(unsafe0())
57 |
58 | @Benchmark
59 | def unsafeCustom(bh: Blackhole): Unit = bh.consume(unsafeCustom0())
60 |
61 | @Benchmark
62 | def unsafeAllocator(bh: Blackhole): Unit = bh.consume(unsafeAllocator0())
63 |
64 | def heap0(): Long = {
65 | final class C(var c0: C, var p0: Int, var c1: C, var p1: Int)
66 | val buf = new Array[C](10000000)
67 | for(i <- buf.indices) buf(i) = new C(null, 0, null, 0)
68 | for(i <- 1 until buf.length-1) {
69 | buf(i).c0 = buf(i - 1)
70 | buf(i).p0 = -1
71 | buf(i).c1 = buf(i + 1)
72 | buf(i).p1 = i
73 | }
74 | var sum = 0L
75 | for(i <- 1 until buf.length-1) {
76 | sum += buf(i).p1
77 | sum -= buf(i).c0.p1
78 | }
79 | sum
80 | }
81 |
82 | def heapByteBuffer0(): Long = {
83 | val buf = new Array[Int](10000000)
84 | val bb = ByteBuffer.allocateDirect(buf.length * 24)
85 | var next = 0
86 | def alloc(len: Int): Int = {
87 | val a = next
88 | next += len
89 | a
90 | }
91 | for(i <- buf.indices) buf(i) = alloc(24)
92 | for(i <- 1 until buf.length-1) {
93 | bb.putLong(buf(i), buf(i-1))
94 | bb.putLong(buf(i)+8, buf(i+1))
95 | bb.putInt(buf(i)+16, -1)
96 | bb.putInt(buf(i)+20, i)
97 | }
98 | var sum = 0L
99 | for(i <- 1 until buf.length-1) {
100 | sum += bb.getInt(buf(i)+20)
101 | sum -= bb.getInt(bb.getLong(buf(i)).toInt+20)
102 | }
103 | sum
104 | }
105 |
106 | // def autoArena0(): Long = {
107 | // import LayoutGlobals._
108 | // val arena = Arena.ofAuto()
109 | // val buf = new Array[MemorySegment](10000000)
110 | // for(i <- buf.indices) buf(i) = arena.allocate(layout)
111 | // for(i <- 1 until buf.length-1) {
112 | // c0vh.set(buf(i), 0, buf(i-1))
113 | // c1vh.set(buf(i), 0, buf(i+1))
114 | // p0vh.set(buf(i), 0, -1)
115 | // p1vh.set(buf(i), 0, i)
116 | // }
117 | // var sum = 0L
118 | // for(i <- 1 until buf.length-1) {
119 | // sum += p1vh.get(buf(i), 0).asInstanceOf[Int]
120 | // sum -= p1vh.get(c0vh.get(buf(i), 0).asInstanceOf[MemorySegment].reinterpret(layout.byteSize()), 0).asInstanceOf[Int]
121 | // }
122 | // sum
123 | // }
124 | //
125 | // def customAllocator0(): Long = {
126 | // import LayoutGlobals2S._
127 | // val arena = Arena.ofAuto()
128 | // val buf = new Array[Long](10000000)
129 | // val _block = arena.allocate(buf.length * 24)
130 | // val root = MemorySegment.NULL.reinterpret(Long.MaxValue)
131 | // var next = _block.address()
132 | // def alloc(len: Int): Long = {
133 | // val a = next
134 | // next += len
135 | // a
136 | // }
137 | // for(i <- buf.indices) buf(i) = alloc(24)
138 | // for(i <- 1 until buf.length-1) {
139 | // c0vh.set(root, buf(i), buf(i-1))
140 | // c1vh.set(root, buf(i), buf(i+1))
141 | // p0vh.set(root, buf(i), -1)
142 | // p1vh.set(root, buf(i), i)
143 | // }
144 | // var sum = 0L
145 | // for(i <- 1 until buf.length-1) {
146 | // sum += p1vh.get(root, buf(i)).asInstanceOf[Int]
147 | // sum -= p1vh.get(root, c0vh.get(root, buf(i)).asInstanceOf[Long]).asInstanceOf[Int]
148 | // }
149 | // sum
150 | // }
151 | //
152 | // def customAllocatorDirect0(): Long = {
153 | // val arena = Arena.ofAuto()
154 | // val buf = new Array[Long](10000000)
155 | // val _block = arena.allocate(buf.length * 24)
156 | // val off = _block.address()
157 | // val root = MemorySegment.NULL.reinterpret(Long.MaxValue)
158 | // var next = off
159 | // def alloc(len: Int): Long = {
160 | // val a = next
161 | // next += len
162 | // a
163 | // }
164 | // for(i <- buf.indices) buf(i) = alloc(24)
165 | // for(i <- 1 until buf.length-1) {
166 | // root.set(ValueLayout.JAVA_LONG, buf(i), buf(i-1))
167 | // root.set(ValueLayout.JAVA_LONG, buf(i)+8, buf(i+1))
168 | // root.set(ValueLayout.JAVA_INT, buf(i)+16, -1)
169 | // root.set(ValueLayout.JAVA_INT, buf(i)+20, i)
170 | // }
171 | // var sum = 0L
172 | // for(i <- 1 until buf.length-1) {
173 | // sum += root.get(ValueLayout.JAVA_INT, buf(i)+20)
174 | // sum -= root.get(ValueLayout.JAVA_INT, root.get(ValueLayout.JAVA_LONG, buf(i))+20)
175 | // }
176 | // sum
177 | // }
178 | //
179 | // def ffmByteBuffer0(): Long = {
180 | // val arena = Arena.ofAuto()
181 | // val buf = new Array[Int](10000000)
182 | // val bb = arena.allocate(buf.length * 24).asByteBuffer()
183 | // var next = 0
184 | // def alloc(len: Int): Int = {
185 | // val a = next
186 | // next += len
187 | // a
188 | // }
189 | // for(i <- buf.indices) buf(i) = alloc(24)
190 | // for(i <- 1 until buf.length-1) {
191 | // bb.putLong(buf(i), buf(i-1))
192 | // bb.putLong(buf(i)+8, buf(i+1))
193 | // bb.putInt(buf(i)+16, -1)
194 | // bb.putInt(buf(i)+20, i)
195 | // }
196 | // var sum = 0L
197 | // for(i <- 1 until buf.length-1) {
198 | // sum += bb.getInt(buf(i)+20)
199 | // sum -= bb.getInt(bb.getLong(buf(i)).toInt+20)
200 | // }
201 | // sum
202 | // }
203 |
204 | def unsafe0(): Long = {
205 | val buf = new Array[Long](10000000)
206 | for(i <- buf.indices) buf(i) = u.allocateMemory(24)
207 | for(i <- 1 until buf.length-1) {
208 | u.putLong(buf(i), buf(i-1))
209 | u.putLong(buf(i)+8, buf(i+1))
210 | u.putInt(buf(i)+16, -1)
211 | u.putInt(buf(i)+20, i)
212 | }
213 | var sum = 0L
214 | for(i <- 1 until buf.length-1) {
215 | sum += u.getInt(buf(i)+20)
216 | sum -= u.getInt(u.getAddress(buf(i))+20)
217 | }
218 | for(i <- buf.indices) u.freeMemory(buf(i))
219 | sum
220 | }
221 |
222 | def unsafeCustom0(): Long = {
223 | val buf = new Array[Long](10000000)
224 | val block = u.allocateMemory(buf.length * 24)
225 | var next = block
226 | def alloc(len: Int): Long = {
227 | val a = next
228 | next += len
229 | a
230 | }
231 | for(i <- buf.indices) buf(i) = alloc(24)
232 | for(i <- 1 until buf.length-1) {
233 | u.putLong(buf(i), buf(i-1))
234 | u.putLong(buf(i)+8, buf(i+1))
235 | u.putInt(buf(i)+16, -1)
236 | u.putInt(buf(i)+20, i)
237 | }
238 | var sum = 0L
239 | for(i <- 1 until buf.length-1) {
240 | sum += u.getInt(buf(i)+20)
241 | sum -= u.getInt(u.getAddress(buf(i))+20)
242 | }
243 | u.freeMemory(block)
244 | sum
245 | }
246 |
247 | def unsafeAllocator0(): Long = {
248 | val buf = new Array[Long](10000000)
249 | val a = new ArenaAllocator()
250 | for(i <- buf.indices) buf(i) = a.alloc(24)
251 | for(i <- 1 until buf.length-1) {
252 | Allocator.putLong(buf(i), buf(i-1))
253 | Allocator.putLong(buf(i)+8, buf(i+1))
254 | Allocator.putInt(buf(i)+16, -1)
255 | Allocator.putInt(buf(i)+20, i)
256 | }
257 | var sum = 0L
258 | for(i <- 1 until buf.length-1) {
259 | sum += Allocator.getInt(buf(i)+20)
260 | sum -= Allocator.getInt(Allocator.getLong(buf(i))+20)
261 | }
262 | a.dispose()
263 | sum
264 | }
265 | }
266 |
267 | //object LayoutGlobals {
268 | // val layout = MemoryLayout.structLayout(
269 | // ValueLayout.ADDRESS.withName("c0"),
270 | // ValueLayout.ADDRESS.withName("c1"),
271 | // ValueLayout.JAVA_INT.withName("p0"),
272 | // ValueLayout.JAVA_INT.withName("p1"),
273 | // )
274 | // val c0vh = layout.varHandle(PathElement.groupElement("c0"))
275 | // val p0vh = layout.varHandle(PathElement.groupElement("p0"))
276 | // val c1vh = layout.varHandle(PathElement.groupElement("c1"))
277 | // val p1vh = layout.varHandle(PathElement.groupElement("p1"))
278 | //}
279 | //
280 | //object LayoutGlobals2S {
281 | // val layout = MemoryLayout.structLayout(
282 | // ValueLayout.JAVA_LONG.withName("c0"),
283 | // ValueLayout.JAVA_LONG.withName("c1"),
284 | // ValueLayout.JAVA_INT.withName("p0"),
285 | // ValueLayout.JAVA_INT.withName("p1"),
286 | // )
287 | // val c0vh = layout.varHandle(PathElement.groupElement("c0"))
288 | // val p0vh = layout.varHandle(PathElement.groupElement("p0"))
289 | // val c1vh = layout.varHandle(PathElement.groupElement("c1"))
290 | // val p1vh = layout.varHandle(PathElement.groupElement("p1"))
291 | //}
292 |
--------------------------------------------------------------------------------
/bench/src/main/scala/InterpreterBenchmark.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.openjdk.jmh.annotations._
4 | import org.openjdk.jmh.infra._
5 | import org.openjdk.jmh.runner.{Defaults, Runner}
6 | import org.openjdk.jmh.runner.format.OutputFormatFactory
7 | import org.openjdk.jmh.runner.options.CommandLineOptions
8 | import org.openjdk.jmh.util.Optional
9 |
10 | import java.nio.file.Path
11 | import scala.jdk.CollectionConverters._
12 | import java.util.concurrent.TimeUnit
13 | import scala.util.control.NonFatal
14 |
15 | // bench/jmh:runMain de.szeiger.interact.InterpreterBenchmark
16 |
17 | @BenchmarkMode(Array(Mode.Throughput))
18 | @Fork(value = 1, jvmArgsAppend = Array("-Xmx12g", "-Xss32M", "-XX:+UnlockExperimentalVMOptions", "-XX:+UseZGC",
19 | "-XX:+UnlockDiagnosticVMOptions"))
20 | @Threads(1)
21 | @Warmup(iterations = 11, time = 1)
22 | @Measurement(iterations = 11, time = 1)
23 | @OutputTimeUnit(TimeUnit.MICROSECONDS)
24 | @State(Scope.Benchmark)
25 | class InterpreterBenchmark {
26 |
27 | @Param(Array(
28 | //"sti",
29 | "stc1",
30 | "stc2",
31 | //"mt0.i", //"mt1.i", "mt8.i",
32 | //"mt1000.i", "mt1001.i", "mt1008.i",
33 | //"mt0.c", //"mt1.c", "mt8.c",
34 | //"mt1000.c", "mt1001.c", "mt1008.c",
35 | ))
36 | var spec: String = _
37 |
38 | @Param(Array(
39 | "ack38",
40 | "ack38b",
41 | "intAck38",
42 | "boxedAck38",
43 | "fib22",
44 | "mult1",
45 | "mult2",
46 | "mult3",
47 | "intMult3",
48 | "intFib29",
49 | // "fib29",
50 | ))
51 | var benchmark: String = _
52 |
53 | private[this] var inter: BaseInterpreter = _
54 |
55 | @Setup(Level.Trial)
56 | def setup(): Unit = inter = InterpreterBenchmark.setup(spec, benchmark)
57 |
58 | @Setup(Level.Invocation)
59 | def prepare(): Unit = inter.initData()
60 |
61 | @Benchmark
62 | def run(bh: Blackhole): Unit = bh.consume(inter.reduce())
63 |
64 | @TearDown(Level.Invocation)
65 | def cleanup(): Unit = inter.dispose()
66 | }
67 |
68 | object InterpreterBenchmark {
69 | private val prelude =
70 | """cons Z
71 | |cons S(n)
72 | |def add(_, y) = r
73 | | | Z => y
74 | | | S(x) => add(x, S(y))
75 | |""".stripMargin
76 |
77 | private val mult1Src = prelude +
78 | """def mult(_, y) = r
79 | | | Z => erase(y); Z
80 | | | S(x) => (y1, y2) = dup(y); add(mult(x, y1), y2)
81 | |let res = mult(100n, 100n)
82 | |""".stripMargin
83 |
84 | private val mult2Src = prelude +
85 | """def mult(_, y) = r
86 | | | Z => erase(y); Z
87 | | | S(x) => (y1, y2) = dup(y); add(mult(x, y1), y2)
88 | |let res1 = mult(100n, 100n)
89 | | res2 = mult(100n, 100n)
90 | | res3 = mult(100n, 100n)
91 | | res4 = mult(100n, 100n)
92 | |""".stripMargin
93 |
94 | private val mult3Src = prelude +
95 | """def mult(_, y) = r
96 | | | Z => erase(y); Z
97 | | | S(x) => (a, b) = dup(y); add(b, mult(x, a))
98 | |let res = mult(1000n, 1000n)
99 | |""".stripMargin
100 |
101 | private val intMult3Src =
102 | """cons Int[int]
103 | |def add(_, _) = r
104 | | | Int[x], Int[y] if [x == 0] => Int[y]
105 | | else => add(Int[x-1], Int[y+1])
106 | |def mult(_, _) = r
107 | | | Int[x], Int[y] if [x == 0] => Int[0]
108 | | else => add(Int[y], mult(Int[x-1], Int[y]))
109 | |let res = mult(Int[1000], Int[1000])
110 | |""".stripMargin
111 |
112 | private val fib22Src = prelude +
113 | """def add2(_, y) = r
114 | | | Z => y
115 | | | S(x) => S(add2(x, y))
116 | |def fib(_) = r
117 | | | Z => 1n
118 | | | S(n) => fib2(n)
119 | |def fib2(_) = r
120 | | | Z => 1n
121 | | | S(n) => (n1, n2) = dup(n); add2(fib(S(n1)), fib(n2))
122 | |let res = fib(22n)
123 | |""".stripMargin
124 |
125 | private val fib29Src = prelude +
126 | """def add2(_, y) = r
127 | | | Z => y
128 | | | S(x) => S(add2(x, y))
129 | |def fib(_) = r
130 | | | Z => 1n
131 | | | S(n) => fib2(n)
132 | |def fib2(_) = r
133 | | | Z => 1n
134 | | | S(n) => (n1, n2) = dup(n); add2(fib(S(n1)), fib(n2))
135 | |let res = fib(29n)
136 | |""".stripMargin
137 |
138 | private val intFib29Src =
139 | """cons Int[int]
140 | |def _ + _ = r
141 | | | Int[x], Int[y] => Int[x + y]
142 | |def fib(_) = r
143 | | | Int[x] if [x == 0] => Int[1]
144 | | if [x == 1] => Int[1]
145 | | else => fib(Int[x-1]) + fib(Int[x-2])
146 | |let res = fib(Int[29])
147 | |""".stripMargin
148 |
149 | private val ack38Src = prelude +
150 | """def ack(_, y) = r
151 | | | Z => S(y)
152 | | | S(x) => ack_Sx(y, x)
153 | |def ack_Sx(_, x) = r
154 | | | Z => ack(x, S(Z))
155 | | | S(y) => (x1, x2) = dup(x); ack(x1, ack_Sx(y, x2))
156 | |let res = ack(3n, 8n)
157 | |""".stripMargin
158 |
159 | private val ack38bSrc = prelude +
160 | """def pred(_) = r
161 | | | Z => Z
162 | | | S(x) => x
163 | |def ack2(_, a) = b
164 | | | Z => S(a)
165 | | | S(x) => ack2b(a, S(x))
166 | |def ack2b(_, a) = b
167 | | | Z => ack2(pred(a), S(Z))
168 | | | S(y) => (a1, a2) = dup(a); ack2(pred(a1), ack2(a2, y))
169 | |let res2 = ack2(3n, 8n)
170 | |""".stripMargin
171 |
172 | private val intAck38Src =
173 | """cons Int[int]
174 | |
175 | |def ackU(a, b) = r
176 | | | Int[x], Int[y]
177 | | if [x == 0] => Int[y + 1]
178 | | if [y == 0] => ackU(Int[x - 1], Int[1])
179 | | else => ackU(Int[x - 1], ackU(Int[x], Int[y - 1]))
180 | |
181 | |let resU = ackU(Int[3], Int[8])
182 | |""".stripMargin
183 |
184 | private val boxedAck38Src =
185 | """cons BoxedInt[ref]
186 | |
187 | |def ackB(a, b) = r
188 | | | BoxedInt[x], BoxedInt[y]
189 | | if [de.szeiger.interact.InterpreterBenchmark.is0(x)] =>
190 | | BoxedInt[de.szeiger.interact.InterpreterBenchmark.inc(y)]
191 | | [eraseRef(x)]
192 | | if [de.szeiger.interact.InterpreterBenchmark.is0(y)] =>
193 | | ackB(BoxedInt[de.szeiger.interact.InterpreterBenchmark.dec(x)], BoxedInt[de.szeiger.interact.InterpreterBenchmark.box(1)])
194 | | [eraseRef(y)]
195 | | else =>
196 | | [de.szeiger.interact.InterpreterBenchmark.ackHelper(x, x1, x2)]
197 | | ackB(BoxedInt[x1], ackB(BoxedInt[x2], BoxedInt[de.szeiger.interact.InterpreterBenchmark.dec(y)]))
198 | |
199 | |let resB = ackB(BoxedInt[de.szeiger.interact.InterpreterBenchmark.box(3)], BoxedInt[de.szeiger.interact.InterpreterBenchmark.box(8)])
200 | |""".stripMargin
201 | def is0(i: java.lang.Integer): Boolean = i.intValue() == 0
202 | def box(i: Int): java.lang.Integer = Integer.valueOf(i)
203 | def inc(i: java.lang.Integer): java.lang.Integer = box(i.intValue() + 1)
204 | def dec(i: java.lang.Integer): java.lang.Integer = box(i.intValue() - 1)
205 | def ackHelper(i: java.lang.Integer, o1: RefOutput, o2: RefOutput): Unit = {
206 | o1.setValue(dec(i))
207 | o2.setValue(i)
208 | }
209 |
210 | val testCases = Map(
211 | "ack38" -> ack38Src,
212 | "ack38b" -> ack38bSrc,
213 | "boxedAck38" -> boxedAck38Src,
214 | "intAck38" -> intAck38Src,
215 | "fib22" -> fib22Src,
216 | "intFib29" -> intFib29Src,
217 | "fib29" -> fib29Src,
218 | "mult1" -> mult1Src,
219 | "mult2" -> mult2Src,
220 | "mult3" -> mult3Src,
221 | "intMult3" -> intMult3Src,
222 | )
223 |
224 | val prepareConfig: Config => Config =
225 | _.copy(collectStats = true, logCodeGenSummary = true, showAfter = Set("PlanRules"))
226 |
227 | val benchConfig: Config => Config =
228 | //identity
229 | _.copy(writeOutput = Some(Path.of("gen-classes")), writeJava = Some(Path.of("gen-src")), logGeneratedClasses = None, showAfter = Set(""))
230 | //_.copy(skipCodeGen = Set(""))
231 |
232 | def setup(spec: String, benchmark: String): BaseInterpreter =
233 | new Compiler(Parser.parse(testCases(benchmark)), benchConfig(Config(spec))).createInterpreter()
234 |
235 | def main(args: Array[String]): Unit = {
236 | val cls = classOf[InterpreterBenchmark]
237 |
238 | def run1(testCase: String, spec: String) = {
239 | try {
240 | println(s"-------------------- Running $testCase $spec:")
241 | val i = new Compiler(Parser.parse(testCases(testCase)), prepareConfig(Config(spec))).createInterpreter()
242 | i.initData()
243 | println()
244 | i.reduce()
245 | val m = i.getMetrics
246 | m.log()
247 | val steps = m.getSteps
248 | i.dispose()
249 | println()
250 | val opts = new CommandLineOptions(cls.getName, s"-pbenchmark=$testCase", s"-pspec=$spec") {
251 | override def getOperationsPerInvocation = Optional.of(steps)
252 | }
253 | val runner = new Runner(opts)
254 | runner.run().asScala
255 | } catch { case NonFatal(ex) =>
256 | ex.printStackTrace()
257 | Iterable.empty
258 | }
259 | }
260 |
261 | val res = for {
262 | testCase <- cls.getDeclaredField("benchmark").getAnnotation(classOf[Param]).value().toVector
263 | spec <- cls.getDeclaredField("spec").getAnnotation(classOf[Param]).value()
264 | res <- run1(testCase, spec)
265 | } yield res
266 |
267 | println("-------------------- Results")
268 | System.out.flush()
269 | val out = OutputFormatFactory.createFormatInstance(System.out, Defaults.VERBOSITY)
270 | out.endRun(res.asJava)
271 | }
272 | }
273 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | //cancelable in Global := false
2 |
3 | scalacOptions ++= Seq("-feature")
4 |
5 | Test / fork := true
6 | run / fork := true
7 | run / connectInput := true
8 |
9 | Global / scalaVersion := "2.13.12"
10 |
11 | lazy val main = (project in file("."))
12 | .settings(
13 | libraryDependencies += "com.github.sbt" % "junit-interface" % "0.13.2" % "test",
14 | libraryDependencies += "com.lihaoyi" %% "fastparse" % "3.0.1",
15 | libraryDependencies ++= Seq("asm", "asm-tree", "asm-util", "asm-commons", "asm-analysis").map(a => "org.ow2.asm" % a % "9.5"),
16 | libraryDependencies += "com.jetbrains.intellij.java" % "java-decompiler-engine" % "233.14015.106",
17 | testOptions += Tests.Argument(TestFrameworks.JUnit, "-a", "-v"),
18 | resolvers += "IntelliJ Releases" at "https://www.jetbrains.com/intellij-repository/releases/",
19 | //scalacOptions ++= Seq("-feature", "-opt:l:inline", "-opt-inline-from:de.szeiger.interact.*", "-opt-inline-from:de.szeiger.interact.**"),
20 | )
21 |
22 | lazy val bench = (project in file("bench"))
23 | .dependsOn(main)
24 | .enablePlugins(JmhPlugin)
25 | .settings(
26 | Jmh / javaOptions ++= Seq("-Xss32M"),
27 | )
28 |
--------------------------------------------------------------------------------
/example.in:
--------------------------------------------------------------------------------
1 | # Natural numbers
2 | cons Z
3 | cons S(n)
4 |
5 | # Erasure and Duplication
6 | # def erase(_)
7 | # def dup(_) = (a, b)
8 | # | dup(_) = (c, d) => (c, d)
9 |
10 | # Addition
11 | def _ + y = r
12 | | Z => y
13 | | S(x) => x + S(y)
14 |
15 | # Multiplication
16 | def _ * y = r
17 | | Z => erase(y); Z
18 | | S(x) => (y1, y2) = dup(y); x * y1 + y2
19 |
20 | # Example: Computations on natural numbers
21 | let y = 5n
22 | x = 3n
23 | example_3_plus_5 = x + y
24 |
25 | let example_3_times_2 = 3n * 2n
26 |
27 | # Lists
28 | cons Nil
29 | cons head :: tail = l
30 |
31 | def length(list) = r
32 | | Nil => Z
33 | | x :: xs => erase(x); S(length(xs))
34 |
35 | def map(list, fi, fo) = r
36 | | Nil => erase(fi); erase(fo); Nil
37 | | x :: xs => (x, fi2) = dup(fi)
38 | (fo1, fo2) = dup(fo)
39 | fo1 :: map(xs, fi2, fo2)
40 |
41 | # Example: List operations
42 | let l0 = 1n :: 2n :: 3n :: Nil
43 | (l0a, l0b) = dup(l0)
44 | l0_length = length(l0a)
45 | l0_mapped = map(l0b, x, x + 2n)
46 |
47 | # Explicit lambdas
48 | cons in |> out
49 | def apply(l, in) = out
50 | | i |> o => in = i; o
51 |
52 | # Example: List mapping with lambdas
53 | def map2(l, f) = r
54 | | Nil => erase(f); Nil
55 | | x :: xs => (f1, f2) = dup(f)
56 | apply(f1, x) :: map2(xs, f2)
57 |
58 | let l0 = 1n :: 2n :: 3n :: Nil
59 | l0_mapped_lambda = map2(l0, x |> x + 2n)
60 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.9.8
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.3")
2 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | sbt 'bench/jmh:run de.szeiger.interact.InterpreterBenchmark'
3 |
--------------------------------------------------------------------------------
/src/main/java/Foo.java:
--------------------------------------------------------------------------------
1 | public class Foo {
2 | public Object f() {
3 | return de.szeiger.interact.ast.PayloadType.VOID();
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/src/main/java/de/szeiger/interact/codegen/LocalClassLoader.java:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen;
2 |
3 | public class LocalClassLoader extends ClassLoader implements ClassWriter {
4 | static { registerAsParallelCapable(); }
5 |
6 | public final Class> defineClass(String name, byte[] b) throws ClassFormatError {
7 | return defineClass(name, b, 0, b.length);
8 | }
9 |
10 | public void writeClass(String javaName, byte[] classFile) {
11 | defineClass(javaName, classFile);
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/java/de/szeiger/interact/stc2/MetaClass.java:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.stc2;
2 |
3 | import de.szeiger.interact.ast.Symbol;
4 |
5 | public abstract class MetaClass {
6 | public final Symbol cellSymbol;
7 | public final int symId;
8 | protected MetaClass(Symbol cellSymbol, int symId) {
9 | this.cellSymbol = cellSymbol;
10 | this.symId = symId;
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/src/main/scala/Analyzer.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import java.io.PrintStream
4 | import scala.annotation.tailrec
5 | import scala.collection.mutable
6 | import de.szeiger.interact.ast._
7 |
8 | trait Analyzer[Cell] { self =>
9 | def rootCells: IterableOnce[Cell]
10 | def irreduciblePairs: IterableOnce[(Cell, Cell)]
11 |
12 | def getSymbol(c: Cell): Symbol
13 | def getConnected(c: Cell, port: Int): (Cell, Int)
14 | def isFreeWire(c: Cell): Boolean
15 |
16 | def symbolName(c: Cell): String = getSymbol(c).id
17 | def getArity(c: Cell): Int = getSymbol(c).arity
18 |
19 | def getPayload(c: Cell): Any = c match {
20 | case c: IntBox => c.getValue
21 | case c: RefBox => c.getValue
22 | case c => "???"
23 | }
24 |
25 | private[this] def getAllConnected(c: Cell): Iterator[(Cell, Int)] =
26 | if(isFreeWire(c)) Iterator(getConnected(c, 0))
27 | else (-1 until getArity(c)).iterator.map(getConnected(c, _))
28 |
29 | private[this] object Nat {
30 | def unapply(c: Cell): Option[Int] = unapply(c, 0)
31 | @tailrec private[this] def unapply(c: Cell, acc: Int): Option[Int] = (symbolName(c), getArity(c)) match {
32 | case ("Z", 0) => Some(acc)
33 | case ("S", 1) =>
34 | val (c2, p2) = getConnected(c, 0)
35 | if(p2 != -1) None else unapply(c2, acc+1)
36 | case _ => None
37 | }
38 | }
39 |
40 | def reachableCells: Iterator[Cell] = {
41 | val freeWires = rootCells.iterator.filter(isFreeWire).toVector
42 | val s = mutable.HashSet.empty[Cell]
43 | val q = mutable.ArrayBuffer.from(freeWires.flatMap(getAllConnected(_).filter(_ != null).map(_._1)))
44 | while(q.nonEmpty) {
45 | val w = q.last
46 | q.dropRightInPlace(1)
47 | if(s.add(w)) q.addAll(getAllConnected(w).map(_._1))
48 | }
49 | s.iterator
50 | }
51 |
52 | private[this] def allConnections(): (mutable.HashMap[(Cell, Int), (Cell, Int)], mutable.HashSet[Cell]) = {
53 | val m = mutable.HashMap.empty[(Cell, Int), (Cell, Int)]
54 | val s = mutable.HashSet.empty[Cell]
55 | val q = mutable.ArrayBuffer.from(rootCells)
56 | while(q.nonEmpty) {
57 | val c1 = q.last
58 | q.dropRightInPlace(1)
59 | if(s.add(c1)) {
60 | val isWire = isFreeWire(c1)
61 | val conn = getAllConnected(c1).toVector
62 | conn.zipWithIndex.foreach {
63 | case (null, _) =>
64 | case ((c2, p2), _p1) =>
65 | val p1 = if(isWire) 0 else _p1 - 1
66 | m.put((c1, p1), (c2, p2))
67 | m.put((c2, p2), (c1, p1))
68 | }
69 | q.addAll(conn.iterator.filter(_ != null).map(_._1))
70 | }
71 | }
72 | (m, s)
73 | }
74 |
75 | def log(out: PrintStream, prefix: String = " ", markCut: (Cell, Cell) => Boolean = (_, _) => false, color: Boolean = true): mutable.ArrayBuffer[(Cell, Cell)] = {
76 | val colors = if(color) MaybeColors else NoColors
77 | import colors._
78 | val cuts = mutable.ArrayBuffer.empty[(Cell, Cell)]
79 | def singleRet(s: Symbol): Int = if(!s.isDef) -1 else if(s.returnArity == 1) s.callArity-1 else -2
80 | val freeWires = rootCells.iterator.filter(isFreeWire).toVector
81 | val stack = mutable.Stack.from(freeWires.sortBy(c => symbolName(c)))
82 | val all = allConnections()._1
83 | val shown = mutable.HashSet.empty[Cell]
84 | var lastTmp = 0
85 | def tmp(): String = { lastTmp += 1; s"$$s$lastTmp" }
86 | val subst = mutable.HashMap.from(freeWires.iterator.map(c1 => ((c1, 0), symbolName(c1))))
87 | //println(s"**** $subst")
88 | //def id(c: Cell): String = if(c == null) "null" else s"${getSymbol(c)}#${System.identityHashCode(c)}"
89 | //all.foreach { case ((c1, p1), (c2, p2)) => println(s" ${id(c1)}:$p1 . ${id(c2)}:$p2") }
90 | def nameOrSubst(c1: Cell, p1: Int, c2: Cell, p2: Int): String = subst.get(c2, p2) match {
91 | case Some(s) => s
92 | case None =>
93 | val mark = if(p1 == -1 && p2 == -1 && markCut(c1, c2)) {
94 | cuts.addOne((c1, c2))
95 | s"${cBlue}<${cuts.length-1}>${cNormal}"
96 | } else ""
97 | if(singleRet(getSymbol(c2)) == p2) mark + show(c2, false)
98 | else {
99 | if(!shown.contains(c2)) stack += c2
100 | val t = tmp()
101 | subst.put((c1, p1), t)
102 | mark + t
103 | }
104 | }
105 | def show(_c1: Cell, withRet: Boolean): String = {
106 | val (c1, freeP) = if(isFreeWire(_c1)) getConnected(_c1, 0) else (_c1, -2)
107 | shown += _c1
108 | val sym = getSymbol(c1)
109 | def list(poss: IndexedSeq[Int]) = poss.map { p1 =>
110 | if(p1 == freeP && isFreeWire(_c1)) (getSymbol(_c1), nameOrSubst(c1, p1, _c1, 0))
111 | else all.get(c1, p1) match {
112 | case Some((c2, p2)) => (getSymbol(c2), nameOrSubst(c1, p1, c2, p2))
113 | case None => (Symbol.NoSymbol, "?")
114 | }
115 | }
116 | def needsParens(sym1: Symbol, pre1: Int, sym2: Symbol, sym2IsRight: Boolean): Boolean = {
117 | val pre2 = Lexical.precedenceOf(sym2.id)
118 | val r1 = Lexical.isRightAssoc(sym1.id)
119 | val r2 = Lexical.isRightAssoc(sym2.id)
120 | pre2 > pre1 || (pre2 >= 0 && (r1 != r2)) || (pre1 == pre2 && r1 != sym2IsRight && r2 != sym2IsRight)
121 | }
122 | val call = c1 match {
123 | case Nat(v) => s"${v}n"
124 | case _ =>
125 | val aposs = if(sym.isDef) -1 +: (0 until sym.callArity-1) else 0 until sym.arity
126 | val as0 = list(aposs)
127 | val pr1 = Lexical.precedenceOf(sym.id)
128 | val nameAndValue = sym.payloadType match {
129 | case PayloadType.VOID => s"$cYellow${sym.id}$cNormal"
130 | case _ =>
131 | val s = getPayload(c1) match {
132 | case s: String => s"\"$s\""
133 | case o => String.valueOf(o)
134 | }
135 | s"$cYellow${sym.id}$cNormal[$s]"
136 | }
137 | if(pr1 >= 0 && sym.arity == 2) {
138 | val as1 = as0.zipWithIndex.map { case ((asym, s), idx) => if(needsParens(sym, pr1, asym, idx == 1)) s"($s)" else s }
139 | s"${as1(0)} $nameAndValue ${as1(1)}"
140 | } else {
141 | val as = if(as0.isEmpty) "" else as0.iterator.map(_._2).mkString("(", ", ", ")")
142 | s"$nameAndValue$as"
143 | }
144 | }
145 | if(withRet) {
146 | val rposs = if(sym.isDef) sym.callArity-1 until sym.callArity+sym.returnArity-1 else IndexedSeq(-1)
147 | val rs0 = list(rposs).map(_._2)
148 | rs0.size match {
149 | case 0 => call
150 | case 1 => s"${rs0.head} = $call"
151 | case _ => rs0.mkString("(", ", ", s") = $call")
152 | }
153 | } else call
154 | }
155 | while(stack.nonEmpty) {
156 | val c1 = stack.pop()
157 | if(!shown.contains(c1)) {
158 | val s = show(c1, true)
159 | out.println(s"$prefix$s")
160 | }
161 | }
162 | val irr = irreduciblePairs.iterator.filter { case (c1, c2) => c1 != null && c2 != null }.map { case (c1, c2) => Seq(symbolName(c1), symbolName(c2)).sorted.mkString(" <-> ") }.toVector.sorted
163 | if(irr.nonEmpty) {
164 | out.println()
165 | out.println("Irreducible pairs:")
166 | irr.foreach(s => out.println(s" $s"))
167 | }
168 | cuts
169 | }
170 |
171 | def toDot(out: PrintStream): Unit = {
172 | var lastIdx = 0
173 | def mk(): String = { lastIdx += 1; s"n$lastIdx" }
174 | val cells = allConnections()._2.map(c => (c, mk())).toMap
175 | out.println("graph G {")
176 | out.println(" node [shape=plain];")
177 | cells.foreachEntry { (c, l) =>
178 | val sym = getSymbol(c)
179 | if(sym.arity == 0)
180 | out.println(
181 | s""" $l [shape=circle label=<${sym.id}>];""".stripMargin
182 | )
183 | else {
184 | val ports = (sym.arity to 1 by -1).map(i => s"""
| """).mkString
185 | out.println(
186 | s""" $l [shape=plain label=<
187 | | | | |
188 | | ${sym.id} |
189 | | |
190 | |
>];""".stripMargin
191 | )
192 | }
193 | }
194 | val done = mutable.HashSet.empty[(Cell, Int)]
195 | cells.foreachEntry { (c1, l1) =>
196 | getAllConnected(c1).zipWithIndex.foreach { case ((c2, _p2), p1) =>
197 | if(!done.contains((c1, p1))) {
198 | val p2 = _p2 + 1
199 | val l2 = cells(c2)
200 | val st =
201 | if(p1 == 0 && p2 == 0) " [style=bold]"
202 | else if(p1 != 0 && p2 != 0) " [style=dashed]"
203 | else ""
204 | out.println(s""" $l1:$p1 -- $l2:$p2$st;""")
205 | done += ((c2, p2))
206 | }
207 | }
208 | }
209 | out.println("}")
210 | }
211 | }
212 |
--------------------------------------------------------------------------------
/src/main/scala/BitOps.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | object BitOps {
4 | @inline def byte0(i: Int): Int = (i & 0xFF).toByte
5 | @inline def byte1(i: Int): Int = ((i >>> 8) & 0xFF).toByte
6 | @inline def byte2(i: Int): Int = ((i >>> 16) & 0xFF).toByte
7 | @inline def byte3(i: Int): Int = ((i >>> 24) & 0xFF).toByte
8 | @inline def intOfBytes(b0: Int, b1: Int, b2: Int, b3: Int): Int = b0.toByte&0xFF | ((b1.toByte&0xFF) << 8) | ((b2.toByte&0xFF) << 16) | ((b3.toByte&0xFF) << 24)
9 | def checkedIntOfBytes(b0: Int, b1: Int, b2: Int, b3: Int): Int = {
10 | assert(b0 >= -128 && b0 <= 127)
11 | assert(b1 >= -128 && b1 <= 127)
12 | assert(b2 >= -128 && b2 <= 127)
13 | assert(b3 >= -128 && b3 <= 127)
14 | intOfBytes(b0, b1, b2, b3)
15 | }
16 | object IntOfBytes {
17 | @inline def unapply(i: Int): Some[(Int, Int, Int, Int)] = Some((byte0(i), byte1(i), byte2(i), byte3(i)))
18 | }
19 |
20 | @inline def short0(i: Int): Int = (i & 0xFFFF).toShort
21 | @inline def short1(i: Int): Int = ((i >>> 16) & 0xFFFF).toShort
22 | @inline def intOfShorts(s0: Int, s1: Int): Int = s0.toShort&0xFFFF | ((s1.toShort&0xFFFF) << 16)
23 | def checkedIntOfShorts(s0: Int, s1: Int): Int = {
24 | assert(s0 >= Short.MinValue && s0 <= Short.MaxValue)
25 | assert(s1 >= Short.MinValue && s1 <= Short.MaxValue)
26 | intOfShorts(s0, s1)
27 | }
28 | object IntOfShorts {
29 | @inline def unapply(i: Int): Some[(Int, Int)] = Some((short0(i), short1(i)))
30 | }
31 |
32 | @inline def intOfShortByteByte(s0: Int, b2: Int, b3: Int): Int = s0.toShort&0xFFFF | ((b2.toByte&0xFF) << 16) | ((b3.toByte&0xFF) << 24)
33 | def checkedIntOfShortByteByte(s0: Int, b2: Int, b3: Int): Int = {
34 | assert(s0 >= Short.MinValue && s0 <= Short.MaxValue)
35 | assert(b2 >= -128 && b2 <= 127)
36 | assert(b3 >= -128 && b3 <= 127)
37 | intOfShortByteByte(s0, b2, b3)
38 | }
39 | object IntOfShortByteByte {
40 | @inline def unapply(i: Int): Some[(Int, Int)] = Some((short0(i), short1(i)))
41 | }
42 | }
43 |
44 | object LongBitOps {
45 | @inline def short0(l: Long): Int = (l & 0xFFFFL).toShort
46 | @inline def short1(l: Long): Int = ((l >>> 16L) & 0xFFFFL).toShort
47 | @inline def short2(l: Long): Int = ((l >>> 32L) & 0xFFFFL).toShort
48 | @inline def short3(l: Long): Int = ((l >>> 48L) & 0xFFFFL).toShort
49 | @inline def longOfShorts(s0: Int, s1: Int, s2: Int, s3: Int): Long = s0.toShort&0xFFFFL | ((s1.toShort&0xFFFFL) << 16L) | ((s2.toShort&0xFFFFL) << 32L) | ((s3.toShort&0xFFFFL) << 48L)
50 | def checkedLongOfShorts(s0: Int, s1: Int, s2: Int, s3: Int): Long = {
51 | assert(s0 >= Short.MinValue && s0 <= Short.MaxValue)
52 | assert(s1 >= Short.MinValue && s1 <= Short.MaxValue)
53 | assert(s2 >= Short.MinValue && s2 <= Short.MaxValue)
54 | assert(s3 >= Short.MinValue && s3 <= Short.MaxValue)
55 | longOfShorts(s0, s1, s2, s3)
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/scala/CleanEmbedded.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | import scala.collection.mutable
6 | import scala.collection.mutable.ArrayBuffer
7 |
8 | /**
9 | * Check usage of variables and assign payload types to embedded variables. After this phase:
10 | * - All Apply/ApplyCons constructor calls that need an embedded value contain an Ident
11 | * which either refers to an Ident in the pattern or is computed by an EmbeddedExpr.
12 | * - Implicit creation and copying of labels is made explicit.
13 | * - PayloadTypes are assigned to all embedded Symbols.
14 | * - Embedded variable usage is linear.
15 | * - Embedded expressions are valid assignments and method calls.
16 | * - Method/operator calls are typed but not yet resolved.
17 | */
18 | class CleanEmbedded(val global: Global) extends Transform with Phase {
19 | import global._
20 |
21 | override def apply(n: Statement): Vector[Statement] = n match {
22 | case n: MatchRule => apply(n)
23 | case n: Let => apply(n)
24 | case n => Vector(n)
25 | }
26 |
27 | override def apply(mr: MatchRule): Vector[Statement] = {
28 | val emb1IdOpt = checkPatternEmbSym(mr.id1, mr.emb1)
29 | val emb2IdOpt = checkPatternEmbSym(mr.id2, mr.emb2)
30 | val patternIds = Iterator(emb1IdOpt, emb2IdOpt).flatten.toSet
31 | patternIds.foreach { _.sym.setPattern() }
32 | val branches = mr.branches.map { b =>
33 | val (reduced2, embRed2) = transformReduction(b.reduced, b.embRed, patternIds)
34 | val cond2 = b.cond.map(checkCond(_))
35 | b.copy(cond = cond2, reduced = reduced2, embRed = embRed2)
36 | }
37 | Vector(mr.copy(branches = branches))
38 | }
39 |
40 | override def apply(l: Let): Vector[Statement] = {
41 | val (es, ees) = transformReduction(l.defs, l.embDefs, Set.empty)
42 | Vector(l.copy(defs = es, embDefs = ees))
43 | }
44 |
45 | // check embedded symbol in pattern and assign payload type to is
46 | private[this] def checkPatternEmbSym(consId: Ident, embId: Option[Ident]): Option[Ident] = {
47 | val consPT = consId.sym.payloadType
48 | embId match {
49 | case s @ Some(id) =>
50 | assert(id.sym.isDefined)
51 | id.sym.payloadType = consPT
52 | if(consPT.isEmpty) {
53 | error(s"Constructor has no embedded value", consId)
54 | None
55 | } else s
56 | case None =>
57 | if(consPT.isDefined && !consPT.canErase)
58 | error(s"Embedded value of type $consPT must be extracted in pattern match", consId)
59 | None
60 | }
61 | }
62 |
63 | // Create new unique embedded symbols in Apply / ApplyCons and assign their PayloadTypes.
64 | // Ensure that all targets have / do not have embedded symbols based on type.
65 | // Returns updated exprs, extracted embedded computations, old symbol to new idents mapping, all new symbols
66 | private[this] def extractConsEmbComps(exprs: Vector[Expr]): (Vector[Expr], Vector[EmbeddedExpr], mutable.HashMap[Symbol, ArrayBuffer[Ident]], Set[Symbol]) = {
67 | val local = new SymbolGen("$e$", isEmbedded = true)
68 | val symbolMap = mutable.HashMap.empty[Symbol, ArrayBuffer[Ident]]
69 | val newEmbComps = Vector.newBuilder[EmbeddedExpr]
70 | val defaultCreate = ArrayBuffer.empty[Symbol]
71 | val allEmbCompSyms = Set.newBuilder[Symbol]
72 | val proc: Transform = new Transform {
73 | def tr[T <: AnyApply](n: T): T = {
74 | val emb2 = n.embedded match {
75 | case s @ Some(e) =>
76 | if(n.target.sym.payloadType.isEmpty)
77 | error("Constructor has no embedded value", e)
78 | val prefix = e match {
79 | case i: Ident => i.s
80 | case _ => n.target.sym.id
81 | }
82 | val id = local.id(isEmbedded = true, payloadType = n.target.sym.payloadType, prefix = prefix).setPos(e.pos)
83 | e match {
84 | case oldId: Ident if !oldId.sym.isPattern =>
85 | symbolMap.getOrElseUpdate(oldId.sym, ArrayBuffer.empty) += id
86 | case _ =>
87 | newEmbComps += EmbeddedAssignment(id, e)
88 | allEmbCompSyms += id.sym
89 | }
90 | Some(id)
91 | case None if n.target.sym.payloadType.canCreate =>
92 | val id = local.id(isEmbedded = true, payloadType = n.target.sym.payloadType, prefix = "cr_"+n.target.sym.id).setPos(n.pos)
93 | defaultCreate += id.sym
94 | allEmbCompSyms += id.sym
95 | Some(id)
96 | case None =>
97 | if(n.target.sym.payloadType.isDefined)
98 | error(s"Embedded value of type ${n.target.sym.payloadType} must be created", n)
99 | None
100 | }
101 | n.copy(embedded = emb2).asInstanceOf[T]
102 | }
103 | override def apply(n: Apply): Apply = tr(super.apply(n))
104 | override def apply(n: ApplyCons): ApplyCons = tr(super.apply(n))
105 | }
106 | val exprs2 = exprs.map(proc(_))
107 | if(defaultCreate.nonEmpty)
108 | newEmbComps += CreateLabels(local(isEmbedded = true, payloadType = PayloadType.REF, prefix = "defCr"), defaultCreate.toVector)
109 | (exprs2, newEmbComps.result(), symbolMap, allEmbCompSyms.result())
110 | }
111 |
112 | private[this] def checkCond(cond: EmbeddedExpr): EmbeddedExpr = {
113 | val proc: Transform = new Transform {
114 | override def apply(n: Ident): Ident = {
115 | if(n.sym.isDefined && !n.sym.isPattern)
116 | error("Unknown symbol (not in pattern)", n)
117 | n
118 | }
119 | }
120 | proc(cond)
121 | }
122 |
123 | private[this] def checkLinearity(patternIds: Iterable[Ident], consIds: Iterable[Ident], usageCount: mutable.HashMap[Symbol, Int]): Unit = {
124 | patternIds.foreach { id =>
125 | val pt = id.sym.payloadType
126 | usageCount(id.sym) match {
127 | case 0 => if(!pt.canErase) error(s"Cannot implicitly erase value of type $pt", id)
128 | case 1 =>
129 | case _ => if(!pt.canCopy) error(s"Cannot implicitly copy value of type $pt", id)
130 | }
131 | }
132 | consIds.foreach { id =>
133 | val pt = id.sym.payloadType
134 | usageCount(id.sym) match {
135 | case 0 => if(!pt.canCreate) error(s"Cannot implicitly create value of type $pt", id)
136 | case 1 =>
137 | case _ => error(s"Duplicate assignment", id)
138 | }
139 | }
140 | }
141 |
142 | private[this] def transformReduction(reduced: Vector[Expr], embComps: Vector[EmbeddedExpr],
143 | patternIds: Set[Ident]): (Vector[Expr], Vector[EmbeddedExpr]) = {
144 | val (reduced2, newEmbComps, newEmbIds, assignedEmbCompSyms) = extractConsEmbComps(reduced)
145 | val allEmbCompSyms = assignedEmbCompSyms ++ newEmbIds.values.flatten.map(_.sym)
146 | val usageCount = mutable.HashMap.from((patternIds.map(_.sym) ++ allEmbCompSyms).map((_ , 0)))
147 | def checkLHS(i: Ident): Unit =
148 | if(!allEmbCompSyms.contains(i.sym)) error("Assignment must be to an embedded variable of the reduction", i)
149 | val mapIdents: Transform = new Transform {
150 | override def apply(n: Ident): Ident = {
151 | if(n.sym.isCons || n.sym.isEmpty) n
152 | else {
153 | val n2 = newEmbIds.get(n.sym) match {
154 | case Some(ids) if ids.length > 1 =>
155 | error(s"Multiple occurrences of symbol ${n.sym} in cell constructors", n)
156 | n
157 | case Some(ids) => ids.head
158 | case _ => n
159 | }
160 | if (!n2.sym.isPattern && !allEmbCompSyms.contains(n2.sym)) {
161 | error("Unknown symbol (not in pattern or cell constructor)", n)
162 | n
163 | }
164 | usageCount.update(n2.sym, usageCount(n2.sym) + 1)
165 | n2
166 | }
167 | }
168 | override def apply(n: EmbeddedAssignment): EmbeddedAssignment = {
169 | val n2 = super.apply(n)
170 | checkLHS(n2.lhs)
171 | n2
172 | }
173 | }
174 | val mappedEmbComps = ArrayBuffer.empty[EmbeddedExpr]
175 | embComps.foreach(mappedEmbComps += mapIdents(_))
176 | newEmbComps.foreach(mappedEmbComps += mapIdents(_))
177 | val unusedNewEmbIds =
178 | newEmbIds.iterator.map { case (oldSym, newIds) => (oldSym, newIds.filter(i => usageCount(i.sym) == 0)) }.filter(_._2.nonEmpty).toVector
179 | unusedNewEmbIds.foreach { case (oldSym, newIds) =>
180 | mappedEmbComps += CreateLabels(oldSym, newIds.iterator.map(_.sym).toVector).setPos(newIds.head.pos)
181 | newIds.foreach { i => usageCount.update(i.sym, usageCount(i.sym) + 1) }
182 | }
183 | checkLinearity(patternIds, newEmbIds.iterator.flatMap(_._2).toVector, usageCount)
184 |
185 | // println("************* Reduction:")
186 | // mappedEmbComps.foreach(ShowableNode.print(_))
187 | // println("newEmbIds: " + newEmbIds.iterator.map { case (k, v) => s"$k -> ${v.mkString("[", ",", "]")}"}.mkString(", "))
188 | // println("unusedNewEmbIds: " + unusedNewEmbIds.iterator.map { case (k, v) => s"$k -> ${v.mkString("[", ",", "]")}"}.mkString(", "))
189 | // println("usageCount: " + usageCount.iterator.map { case (k, v) => s"$k -> $v"}.mkString(", "))
190 |
191 | (reduced2, mappedEmbComps.toVector)
192 | }
193 | }
194 |
--------------------------------------------------------------------------------
/src/main/scala/Colors.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | abstract class Colors {
4 | def cNormal: String
5 | def cBlack: String
6 | def cRed: String
7 | def cGreen: String
8 | def cYellow: String
9 | def cBlue: String
10 | def cMagenta: String
11 | def cCyan: String
12 | def bRed: String
13 | def bGreen: String
14 | def bYellow: String
15 | def bBlue: String
16 | def bMagenta: String
17 | def bCyan: String
18 | }
19 |
20 | object MaybeColors extends Colors {
21 | val useColors: Boolean = System.getProperty("interact.colors", "true").toBoolean
22 |
23 | val (cNormal, cBlack, cRed, cGreen, cYellow, cBlue, cMagenta, cCyan) =
24 | if(useColors) ("\u001B[0m", "\u001B[30m", "\u001B[31m", "\u001B[32m", "\u001B[33m", "\u001B[34m", "\u001B[35m", "\u001B[36m")
25 | else ("", "", "", "", "", "", "", "")
26 | val (bRed, bGreen, bYellow, bBlue, bMagenta, bCyan) =
27 | if(useColors) ("\u001B[41m", "\u001B[42m", "\u001B[43m", "\u001B[44m", "\u001B[45m", "\u001B[46m")
28 | else ("", "", "", "", "", "")
29 | }
30 |
31 | object NoColors extends Colors {
32 | val cNormal, cBlack, cRed, cGreen, cYellow, cBlue, cMagenta, cCyan, bRed, bGreen, bYellow, bBlue, bMagenta, bCyan = ""
33 | }
34 |
35 | object Colors {
36 | def stripColors(s: String): String = {
37 | val b = new StringBuilder(s.length)
38 | var i = 0
39 | while(i < s.length) {
40 | s.charAt(i) match {
41 | case '\u001B' =>
42 | val end = s.indexOf('m', i+1)
43 | if(end == -1) i += 1
44 | else i = end+1
45 | case c =>
46 | b.append(c)
47 | i += 1
48 | }
49 | }
50 | b.result()
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/scala/Compiler.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 | import de.szeiger.interact.offheap.{Allocator, ProxyAllocator, SliceAllocator}
5 |
6 | import java.nio.file.Path
7 | import scala.collection.mutable
8 |
9 | class Compiler(unit0: CompilationUnit, _config: Config = Config.defaultConfig) {
10 | val global = new Global(_config)
11 | import global._
12 |
13 | private[this] val phases: Vector[Phase] = Vector(
14 | new Prepare(global),
15 | new ExpandRules(global),
16 | new Curry(global),
17 | new CleanEmbedded(global),
18 | new ResolveEmbedded(global),
19 | new CreateWiring(global),
20 | new Inline(global),
21 | ) ++ (if(config.backend.planRules) Vector(
22 | new PlanRules(global),
23 | ) else Vector.empty)
24 |
25 | private[this] val unit1 = if(config.addEraseDup) {
26 | val erase = globalSymbols.define("erase", isCons = true, isDef = true, returnArity = 0)
27 | val dup = globalSymbols.define("dup", isCons = true, isDef = true, arity = 2, returnArity = 2, payloadType = PayloadType.LABEL)
28 | unit0.copy(statements = Vector(DerivedRule(erase, erase), DerivedRule(erase, dup), DerivedRule(dup, dup)) ++ unit0.statements).setPos(unit0.pos)
29 | } else unit0
30 |
31 | val unit = phases.foldLeft(unit1) { case (u, p) =>
32 | val u2 = p(u)
33 | if(config.showAfter.contains(p.phaseName) || config.showAfter.contains("*"))
34 | ShowableNode.print(u2, name = s"After phase ${p.phaseName}")
35 | checkThrow()
36 | u2
37 | }
38 |
39 | def createInterpreter(): BaseInterpreter = config.backend.createInterpreter(this)
40 | }
41 |
42 | trait Phase extends (CompilationUnit => CompilationUnit) {
43 | val global: Global
44 | val phaseName: String = getClass.getName.replaceAll(".*\\.", "")
45 | val phaseLogEnabled: Boolean = global.config.phaseLog.contains(phaseName) || global.config.phaseLog.contains("*")
46 | override def toString: String = phaseName
47 |
48 | @inline final def phaseLog(@inline msg: => String): Unit = if(phaseLogEnabled) global.phaseLog(phaseName, msg)
49 | @inline final def phaseLog(n: ShowableNode, name: => String, prefix: String = ""): Unit = if(phaseLogEnabled) global.phaseLog(phaseName, n, name, prefix)
50 | }
51 |
52 | case class Config(
53 | // Frontend
54 | defaultDerive: Seq[String] = Seq("erase", "dup"),
55 | addEraseDup: Boolean = true,
56 | phaseLog: Set[String] = Set.empty, // show debug log of these phases
57 | showAfter: Set[String] = Set.empty, // log AST after these phases
58 | inlineFull: Boolean = true, // inline rules that can be merged into a single branch
59 | inlineFullAll: Boolean = true, // inline simple matches even when duplicating a parent rule
60 | inlineBranching: Boolean = true, // inline rules that cannot be merged into a single branch (st.c)
61 | inlineUniqueContinuations: Boolean = true, // st.c
62 | loop: Boolean = true,
63 | repeatedInliningLimit: Int = 0,
64 |
65 | // Backend
66 | backend: Backend = STC1Backend,
67 | numThreads: Int = 0, // mt
68 | collectStats: Boolean = false,
69 | useCellCache: Boolean = false, // stc*
70 | biasForCommonDispatch: Boolean = true, // optimize for invokevirtual dispatch of statically known cell types (stc1)
71 | logCodeGenSummary: Boolean = false, // stc*, mt.c
72 | logGeneratedClasses: Option[String] = None, // Log generated classes containing this string (stc*, mt.c)
73 | compilerParallelism: Int = 1,
74 | allCommon: Boolean = false, // compile all methods into CommonCell, not just shared ones (stc1)
75 | reuseCells: Boolean = true, // stc*
76 | writeOutput: Option[Path] = None, // write generated classes to dir or jar file (stc*)
77 | writeJava: Option[Path] = None, // write decompiled classes to dir (stc*)
78 | skipCodeGen: Set[String] = Set.empty, // do not generate classfiles for these Java class names (stc*)
79 | tailCallDepth: Int = 32, // stc2
80 | newAllocator: () => ProxyAllocator = () => new SliceAllocator(), // stc2
81 | debugMemory: Boolean = false, // stc2
82 | unboxedPrimitives: Boolean = true, // Use unboxed 32-bit primitives and singletons instead of pointers (stc2)
83 | ) {
84 | def withSpec(spec: String): Config = spec match {
85 | case s"sti" => copy(backend = STIBackend)
86 | case s"stc1" => copy(backend = STC1Backend)
87 | case s"stc2" => copy(backend = STC2Backend)
88 | case s"mt${mode}.i" => copy(backend = MTIBackend, numThreads = mode.toInt)
89 | case s"mt${mode}.c" => copy(backend = MTCBackend, numThreads = mode.toInt)
90 | }
91 | }
92 |
93 | abstract class Backend(val name: String) {
94 | def createInterpreter(comp: Compiler): BaseInterpreter
95 | def planRules: Boolean
96 | def inlineBranching: Boolean
97 | def inlineUniqueContinuations: Boolean
98 | def allowPayloadTemp: Boolean
99 | def canReuseLabels: Boolean
100 |
101 | def storageClass(sym: Symbol): Any = sym
102 | def canUnbox(sym: Symbol): Boolean = false
103 | }
104 |
105 | object STIBackend extends Backend("sti") {
106 | def createInterpreter(comp: Compiler): BaseInterpreter =
107 | new sti.Interpreter(comp.global.globalSymbols, comp.unit, comp.global.config)
108 | def planRules: Boolean = false
109 | def inlineBranching: Boolean = false
110 | def inlineUniqueContinuations: Boolean = false
111 | def allowPayloadTemp: Boolean = false
112 | def canReuseLabels: Boolean = true
113 | }
114 |
115 | object STC1Backend extends Backend("stc1") {
116 | def createInterpreter(comp: Compiler): BaseInterpreter =
117 | new stc1.Interpreter(comp.global.globalSymbols, comp.unit, comp.global.config)
118 | def planRules: Boolean = true
119 | def inlineBranching: Boolean = true
120 | def inlineUniqueContinuations: Boolean = true
121 | def allowPayloadTemp: Boolean = true
122 | def canReuseLabels: Boolean = true
123 | }
124 |
125 | object STC2Backend extends Backend("stc2") {
126 | def createInterpreter(comp: Compiler): BaseInterpreter =
127 | new stc2.Interpreter(comp.global.globalSymbols, comp.unit, comp.global.config)
128 | def planRules: Boolean = true
129 | def inlineBranching: Boolean = true
130 | def inlineUniqueContinuations: Boolean = true
131 | def allowPayloadTemp: Boolean = true
132 | def canReuseLabels: Boolean = false
133 | override def storageClass(sym: Symbol) = (stc2.Interpreter.cellSize(sym.arity, sym.payloadType), sym.payloadType == PayloadType.REF)
134 | override def canUnbox(sym: Symbol): Boolean = stc2.Interpreter.canUnbox(sym, sym.arity)
135 | }
136 |
137 | class MTBackend(name: String, compile: Boolean) extends Backend(name) {
138 | def createInterpreter(comp: Compiler): BaseInterpreter = {
139 | import comp.global._
140 | val rulePlans = mutable.Map.empty[RuleKey, RuleWiring]
141 | val initialPlans = mutable.ArrayBuffer.empty[InitialRuleWiring]
142 | comp.unit.statements.foreach {
143 | case i: InitialRuleWiring => initialPlans += i
144 | case g: RuleWiring => rulePlans.put(g.key, g)
145 | }
146 | new mt.Interpreter(globalSymbols, rulePlans.values, config, mutable.ArrayBuffer.empty[Let], initialPlans, compile)
147 | }
148 | def planRules: Boolean = compile
149 | def inlineBranching: Boolean = compile
150 | def inlineUniqueContinuations: Boolean = compile
151 | def allowPayloadTemp: Boolean = compile
152 | def canReuseLabels: Boolean = compile
153 | }
154 |
155 | object MTIBackend extends MTBackend("mti", false)
156 | object MTCBackend extends MTBackend("mti", true)
157 |
158 | object Config {
159 | val defaultConfig: Config = Config()
160 | def apply(spec: String): Config = defaultConfig.withSpec(spec)
161 | }
162 |
--------------------------------------------------------------------------------
/src/main/scala/Curry.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | import scala.collection.mutable
6 |
7 | /**
8 | * Create curried and derived rules, and remove all Cons and Def statements.
9 | */
10 | class Curry(val global: Global) extends Transform with Phase {
11 | import global._
12 |
13 | private[this] lazy val defaultDeriveSyms =
14 | config.defaultDerive.iterator.map(globalSymbols.get).filter(_.exists(_.isCons)).map(_.get).toVector
15 |
16 | private[this] def derivedRules(syms1: Vector[Symbol], sym2: Symbol, pos: Position): Vector[DerivedRule] =
17 | syms1.flatMap { sym =>
18 | if(sym.id == "erase" || sym.id == "dup") Vector(DerivedRule(sym, sym2).setPos(pos))
19 | else { error(s"Don't know how to derive '$sym'", pos); Vector.empty }
20 | }
21 |
22 | private[this] def singleNonIdentIdx(es: Seq[Expr]): Int = {
23 | val i1 = es.indexWhere(e => !e.isInstanceOf[Ident])
24 | if(i1 == -1) -1
25 | else {
26 | val i2 = es.lastIndexWhere(e => !e.isInstanceOf[Ident])
27 | if(i2 == i1) i1 else -2
28 | }
29 | }
30 |
31 | private[this] def createCurriedDef(lid: Ident, rid: Ident, idx: Int, rhs: Boolean, at: Position): (Ident, Vector[CheckedRule]) = {
32 | val ls = lid.sym
33 | val rs = rid.sym
34 | val curryId = Ident(s"${ls.id}$$${if(rhs) "ac" else "nc"}$$${rs.id}").setPos(lid.pos)
35 | val rules = globalSymbols.get(curryId) match {
36 | case Some(sym) =>
37 | if(sym.matchContinuationPort != idx) error(s"Port mismatch in curried ${ls.id} -> ${rs.id} match", at)
38 | curryId.sym = sym
39 | Vector.empty
40 | case None if ls.hasPayload && rs.hasPayload =>
41 | error("Implementation limitation: Curried definitions cannot have payload on both sides", at)
42 | Vector.empty
43 | case None =>
44 | val curriedPtp = if(ls.hasPayload) ls.payloadType else rs.payloadType
45 | val emb1 = if(ls.hasPayload) Some(mkLocalId("$l", true).setPos(lid.pos)) else None
46 | val emb2 = if(rs.hasPayload) Some(mkLocalId("$r", true).setPos(rid.pos)) else None
47 | curryId.sym = globalSymbols.define(curryId.s, isCons = true, arity = ls.arity + rs.arity - 1, payloadType = curriedPtp, matchContinuationPort = idx)
48 | val largs = (0 until ls.callArity).map(i => mkLocalId(s"$$l$i").setPos(lid.pos)).toVector
49 | val rargs = (0 until rs.callArity).map(i => mkLocalId(s"$$r$i").setPos(rid.pos)).toVector
50 | val (keepArgs, splitArgs) = if(rhs) (rargs, largs) else (largs, rargs)
51 | val curryArgs = keepArgs ++ splitArgs.zipWithIndex.filter(_._2 != idx).map(_._1)
52 | val der = derivedRules(defaultDeriveSyms, curryId.sym, Position.unknown)
53 | val fwd = Assignment(Apply(curryId, emb1.orElse(emb2), curryArgs).setPos(at), splitArgs(idx)).setPos(at)
54 | der :+ MatchRule(lid, rid, largs, rargs, emb1, emb2, Vector(Branch(None, Vector.empty, Vector(fwd)).setPos(at))).setPos(at)
55 | }
56 | (curryId, rules)
57 | }
58 |
59 | private[this] def curry(mr: MatchRule): Vector[Statement] = mr.args1.indexWhere(e => !e.isInstanceOf[Ident]) match {
60 | case -1 =>
61 | singleNonIdentIdx(mr.args2) match {
62 | case -2 =>
63 | error(s"Only one nested match allowed", mr.pos)
64 | Vector.empty
65 | case -1 =>
66 | mr.args1.toSet.intersect(mr.args2.toSet).foreach { case i: Ident =>
67 | error(s"Duplicate variable '${i.s}' on both sides of a match", i)
68 | }
69 | Vector(mr)
70 | case idx =>
71 | val (curryId, curryRules) = createCurriedDef(mr.id1, mr.id2, idx, false, mr.args2(idx).pos)
72 | val ApplyCons(cid, cemb, crargs) = mr.args2(idx)
73 | val clargs = mr.args1 ++ mr.args2.zipWithIndex.filter(_._2 != idx).map(_._1.asInstanceOf[Ident])
74 | curryRules ++ curry(mr.copy(curryId, cid, clargs, crargs, mr.emb1.orElse(mr.emb2), checkCEmb(cemb), mr.branches))
75 | }
76 | case idx =>
77 | val (curryId, curryRules) = createCurriedDef(mr.id1, mr.id2, idx, true, mr.args1(idx).pos)
78 | val ApplyCons(cid, cemb, clargs) = mr.args1(idx)
79 | val crargs = mr.args2 ++ mr.args1.zipWithIndex.filter(_._2 != idx).map(_._1.asInstanceOf[Ident])
80 | curryRules ++ curry(mr.copy(curryId, cid, crargs, clargs, mr.emb1.orElse(mr.emb2), checkCEmb(cemb), mr.branches))
81 | }
82 |
83 | private[this] def checkCEmb(o: Option[EmbeddedExpr]): Option[Ident] = o.flatMap {
84 | case i: Ident => Some(i)
85 | case e =>
86 | error("Embedded expression in pattern match must be a variable", e)
87 | None
88 | }
89 |
90 | override def apply(n: Statement): Vector[Statement] = n match {
91 | case n: MatchRule => curry(n)
92 | case c: Cons => derivedRules(c.der.map(_.map(_.sym)).getOrElse(defaultDeriveSyms), c.name.sym, c.name.pos)
93 | case d: Def => derivedRules(defaultDeriveSyms, d.name.sym, d.name.pos)
94 | case n => Vector(n)
95 | }
96 |
97 | override def apply(n: CompilationUnit): CompilationUnit = {
98 | val n2 = super.apply(n)
99 | val keys = mutable.HashSet.empty[RuleKey]
100 | n2.statements.foreach {
101 | case c: CheckedRule =>
102 | if(!keys.add(c.key)) error(s"Duplicate rule ${c.sym1} <-> ${c.sym2}", c)
103 | case _ =>
104 | }
105 | n2
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/src/main/scala/Debug.scala:
--------------------------------------------------------------------------------
1 | import de.szeiger.interact._
2 | import de.szeiger.interact.sti.Cell
3 |
4 | import java.nio.file.Path
5 | import scala.annotation.tailrec
6 | import scala.collection.mutable
7 |
8 | object Debug extends App {
9 | val statements = Parser.parse(Path.of(args(0)))
10 | val model = new Compiler(statements, Config(backend = STIBackend))
11 | val inter = model.createInterpreter().asInstanceOf[sti.Interpreter]
12 | inter.initData()
13 |
14 | var steps = 0
15 | var cuts: mutable.ArrayBuffer[(Cell, Cell)] = _
16 |
17 | @tailrec
18 | def readLine(): Option[Int] = {
19 | print("> ")
20 | val in = Console.in.readLine()
21 | if(in == "q") None
22 | else in.toIntOption.filter(i => i >= 0 && i < cuts.length) match {
23 | case None => readLine()
24 | case o => o
25 | }
26 | }
27 |
28 | @tailrec def step(): Unit = {
29 | println(s"${MaybeColors.cGreen}At step $steps:${MaybeColors.cNormal}")
30 | cuts = inter.getAnalyzer.log(System.out, markCut = (c1, c2) => inter.getRuleImpl(c1, c2) != null)
31 | if(cuts.isEmpty)
32 | println(s"${MaybeColors.cGreen}Irreducible after $steps reductions.${MaybeColors.cNormal}")
33 | else {
34 | steps += 1
35 | readLine() match {
36 | case None => ()
37 | case Some(idx) =>
38 | inter.reduce1(cuts(idx)._1, cuts(idx)._2)
39 | step()
40 | }
41 | }
42 | }
43 |
44 | step()
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/scala/ExecutionMetrics.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import scala.collection.mutable
4 |
5 | class ExecutionMetrics {
6 | private[this] final class MutInt(var i: Int)
7 | private[this] var steps, cellAlloc, proxyAlloc, cellReuse, singletonUse, unboxedCells, loopSave, directTail, singleDispatchTail, labelCreate = 0
8 | private[this] var metrics = mutable.Map.empty[String, MutInt]
9 |
10 | def getSteps: Int = steps
11 |
12 | def recordStats(steps: Int, cellAllocations: Int, proxyAllocations: Int, cachedCellReuse: Int, singletonUse: Int,
13 | unboxedCells: Int, loopSave: Int, directTail: Int, singleDispatchTail: Int, labelCreate: Int): Unit = {
14 | this.steps += steps
15 | this.cellAlloc += cellAllocations
16 | this.proxyAlloc += proxyAllocations
17 | this.cellReuse += cachedCellReuse
18 | this.singletonUse += singletonUse
19 | this.unboxedCells += unboxedCells
20 | this.loopSave += loopSave
21 | this.directTail += directTail
22 | this.singleDispatchTail += singleDispatchTail
23 | this.labelCreate += labelCreate
24 | }
25 |
26 | def recordStats(steps: Int, cellAllocations: Int): Unit = recordStats(steps, cellAllocations, 0, 0, 0, 0, 0, 0, 0, 0)
27 |
28 | def recordMetric(metric: String, inc: Int = 1): Unit = {
29 | val m = metrics.getOrElseUpdate(metric, new MutInt(0))
30 | m.i += inc
31 | }
32 |
33 | def log(): Unit = {
34 | logStats()
35 | logMetrics()
36 | }
37 |
38 | def logStats(): Unit = {
39 | println(s"Steps: $steps ($loopSave loop, $directTail tail ($singleDispatchTail single-dispatch), ${steps-loopSave-directTail} other)")
40 | println(s" Cells created: $cellAlloc new ($proxyAlloc proxied), $cellReuse cached, $singletonUse singleton, $unboxedCells unboxed; Labels created: $labelCreate")
41 | }
42 |
43 | def logMetrics(): Unit = {
44 | val data = metrics.toVector.sortBy(_._1).map { case (k, v) => (k, v.i.toString) }
45 | val maxLen = data.iterator.map { case (k, v) => k.length + v.length }.max
46 | data.foreach { case (k, v) =>
47 | val pad = " " * (maxLen-k.length-v.length)
48 | println(s" $k $pad$v")
49 | }
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/scala/ExpandRules.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | /**
6 | * Complete rules by synthesizing omitted return wires in conditions and reductions.
7 | */
8 | class ExpandRules(val global: Global) extends Transform with Phase {
9 | import global._
10 |
11 | val normalizeCond = new NormalizeCondition(global)
12 |
13 | private[this] def wildcardCount(e: Expr): Int = e match {
14 | case _: Wildcard => 1
15 | case e: Apply => e.args.iterator.map(wildcardCount).sum
16 | case _ => 0
17 | }
18 |
19 | private[this] def returnArity(e: Expr): Int = e match {
20 | case e: Apply => e.target.sym.returnArity
21 | case Assignment(lhs, rhs) => wildcardCount(lhs) + wildcardCount(rhs)
22 | case _ => 0
23 | }
24 |
25 | override def apply(n: Statement): Vector[Statement] = n match {
26 | case n: Match => apply(n)
27 | case n: Def => apply(n)
28 | case n => Vector(n)
29 | }
30 |
31 | override def apply(m: Match): Vector[Statement] = returnArity(m.on.head) match {
32 | case 0 => checkedMatch(m.on, m.reduced, m.pos)
33 | case n =>
34 | assert(m.on.length == 1)
35 | val p = m.on.head.pos
36 | checkedMatch(Vector(Assignment(m.on.head, Tuple((1 to n).map(i => mkLocalId(s"$$ret$i").setPos(p)).toVector).setPos(p)).setPos(p)), m.reduced, m.pos)
37 | }
38 |
39 | override def apply(d: Def): Vector[Statement] = {
40 | d.copy(rules = Vector.empty) +: d.rules.flatMap { r =>
41 | val dret = Tuple(d.ret).setPos(d.pos)
42 | checkedMatch(Vector(Assignment(Apply(d.name, d.embeddedId, r.on ++ d.args.drop(r.on.length)).setPos(d.pos), dret).setPos(r.pos)), r.reduced, r.pos)
43 | }
44 | }
45 |
46 | private[this] def connectLastStatement(e: Expr, extraRhs: Vector[Ident]): Expr = e match {
47 | case e: Assignment => e
48 | case e: Tuple =>
49 | if(e.exprs.length != extraRhs.length)
50 | error(s"Expected return arity ${extraRhs.length} for reduction but got ${e.exprs.length}", e)
51 | Assignment(Tuple(extraRhs).setPos(extraRhs.head.pos), e).setPos(extraRhs.head.pos)
52 | case e: Apply =>
53 | val sym = globalSymbols(e.target.s)
54 | if(sym.returnArity == 0) e
55 | else {
56 | if(sym.returnArity != extraRhs.length)
57 | error(s"Expected return arity ${extraRhs.length} for reduction but got ${sym.returnArity}", e)
58 | Assignment(if(extraRhs.length == 1) extraRhs.head else Tuple(extraRhs).setPos(extraRhs.head.pos), e).setPos(extraRhs.head.pos)
59 | }
60 | case e: NatLit =>
61 | if(extraRhs.length != 1)
62 | error(s"Expected return arity ${extraRhs.length} for reduction but got 1", e)
63 | Assignment(extraRhs.head, e).setPos(extraRhs.head.pos)
64 | case e: Ident =>
65 | if(extraRhs.length != 1)
66 | error(s"Expected return arity ${extraRhs.length} for reduction but got 1", e)
67 | Assignment(extraRhs.head, e).setPos(e.pos)
68 | }
69 |
70 | private[this] def checkedMatch(on: Vector[Expr], red: Vector[Branch], pos: Position): Vector[Statement] = {
71 | val inlined = normalizeCond.toInline(normalizeCond.toANF(on).map(normalizeCond.toConsOrder))
72 | inlined match {
73 | case Seq(Assignment(ApplyCons(lid, lemb, largs: Seq[Expr]), ApplyCons(rid, remb, rargs))) =>
74 | val compl = if(lid.sym.isDef) largs.takeRight(lid.sym.returnArity) else Vector.empty
75 | val connected = red.map { r =>
76 | r.copy(reduced = r.reduced.init :+ connectLastStatement(r.reduced.last, compl.asInstanceOf[Vector[Ident]]))
77 | }
78 | val mr = MatchRule(lid, rid, largs, rargs, lemb.map(_.asInstanceOf[Ident]), remb.map(_.asInstanceOf[Ident]), connected).setPos(pos)
79 | Vector(mr)
80 | case _ =>
81 | error(s"Invalid match", pos)
82 | Vector.empty
83 | }
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/src/main/scala/Global.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | import java.io.{PrintWriter, StringWriter}
6 | import scala.collection.mutable.ArrayBuffer
7 | import scala.util.control.NonFatal
8 |
9 | final class Global(val config: Config) {
10 | final val globalSymbols = new Symbols
11 |
12 | private[this] var hasErrors: Boolean = false
13 | private[this] val accumulated = ArrayBuffer.empty[Notice]
14 |
15 | def dependencyLoader: ClassLoader = getClass.getClassLoader
16 |
17 | def warning(msg: String, at: Node): Unit = error(msg, at.pos)
18 | def warning(msg: String, at: Position): Unit =
19 | accumulated += new Notice(msg, at, Severity.Warning)
20 |
21 | def error(msg: String, at: Node): Unit = error(msg, at.pos)
22 | def error(msg: String, at: Position): Unit = {
23 | accumulated += new Notice(msg, at, Severity.Error)
24 | hasErrors = true
25 | }
26 |
27 | def throwError(msg: String, at: Node): Nothing = throwError(msg, at.pos)
28 | def throwError(msg: String, at: Position): Nothing = {
29 | error(msg, at)
30 | throw new ReportedError
31 | }
32 | def tryError[T](f: => T): Option[T] = try Some(f) catch { case _: ReportedError => None }
33 |
34 | def fatal(msg: String, at: Node): Nothing = fatal(msg, at.pos)
35 | def fatal(msg: String, at: Position): Nothing = {
36 | accumulated += new Notice(msg, at, Severity.Fatal)
37 | throw getCompilerResult()
38 | }
39 |
40 | def internalError(msg: String, at: Position, parent: Throwable = null, atNode: Node = null): Unit = {
41 | accumulated += new Notice(msg, if(at == null && atNode != null) atNode.pos else at, Severity.Fatal, atNode, internal = true)
42 | throw getCompilerResult(parent)
43 | }
44 |
45 | def mkLocalId(name: String, isEmbedded: Boolean = false, payloadType: PayloadType = PayloadType.VOID): Ident = {
46 | val i = Ident(name)
47 | i.sym = Symbol(name, isEmbedded = isEmbedded, payloadType = payloadType)
48 | i
49 | }
50 |
51 | def getCompilerResult(parent: Throwable = null): CompilerResult = new CompilerResult(accumulated.toIndexedSeq, parent)
52 |
53 | def checkThrow(): Unit =
54 | if(hasErrors) throw getCompilerResult()
55 |
56 | def phaseLog(phase: String, msg: String): Unit =
57 | println(s"<$phase> $msg")
58 |
59 | def phaseLog(phase: String, n: ShowableNode, name: String, prefix: String): Unit = {
60 | val p = s"<$phase> "
61 | ShowableNode.print(n, name = name, prefix = p+prefix, prefix1 = p+prefix, highlightTopLevel = false)
62 | }
63 | }
64 |
65 | class Notice(msg: String, at: Position, severity: Severity, atNode: ShowableNode = null, internal: Boolean = false, mark: Node = null) {
66 | def formatted: String = {
67 | import MaybeColors._
68 | import Notice._
69 | val b = new StringBuilder
70 | val sev = if(internal) s"${cRed}Internal Error" else if(isError) s"${cRed}Error" else s"${cYellow}Warning"
71 | val msgLines = msg.split('\n')
72 | if(at.isDefined) {
73 | val (line, col) = at.input.find(at.offset)
74 | b.append(s"$sev: $cNormal${at.file}$cCyan:${line+1}:${col+1}$cNormal: ${msgLines.head}$eol")
75 | msgLines.tail.foreach { m => b.append(s"$cBlue| $cNormal$m$eol") }
76 | b.append(s"$cBlue| $cNormal${at.input.getLine(line).stripTrailing()}$eol")
77 | b.append(s"$cBlue| $cGreen${at.input.getCaret(col)}$cNormal")
78 | } else {
79 | b.append(s"$sev: $cNormal ${msgLines.head}$eol")
80 | msgLines.tail.foreach { m => b.append(s"$cBlue| $cNormal$m$eol") }
81 | }
82 | if(internal && atNode != null) {
83 | val out = new StringWriter()
84 | val outw = new PrintWriter(out)
85 | outw.print(s"\n$cBlue| ${cRed}AST Context:$cNormal\n")
86 | ShowableNode.print(atNode, outw, prefix = s"$cBlue| ", mark = mark)
87 | b.append(out.toString)
88 | }
89 | b.result()
90 | }
91 | override def toString: String = formatted
92 | def isError: Boolean = severity != Severity.Warning
93 | }
94 |
95 | object Notice {
96 | val eol: String = sys.props("line.separator")
97 | }
98 |
99 | sealed abstract class Severity
100 | object Severity {
101 | case object Warning extends Severity
102 | case object Error extends Severity
103 | case object Fatal extends Severity
104 | }
105 |
106 | class CompilerResult(val notices: IndexedSeq[Notice], parent: Throwable = null) extends Exception(parent) {
107 | lazy val hasErrors = notices.exists(_.isError)
108 | lazy val summary: String = {
109 | val errs = notices.count(_.isError)
110 | val warns = notices.length - errs
111 | def fmt(i: Int, s: String) = if(i == 1) s"1 $s" else s"$i ${s}s"
112 | if(warns > 0) s"${fmt(errs, "error")}, ${fmt(warns, "warnings")} found."
113 | else s"${fmt(errs, "error")} found."
114 | }
115 | override def getMessage: String = {
116 | import Notice._
117 | val b = (new StringBuilder).append(eol)
118 | notices.foreach(n => b.append(n.formatted).append(eol).append(eol))
119 | b.append(summary).result()
120 | }
121 | }
122 | object CompilerResult {
123 | def tryInternal[T](at: Position)(f: => T): T = try f catch {
124 | case e: CompilerResult => throw e
125 | case e: AssertionError => throw new CompilerResult(Vector(new Notice(e.toString, at, Severity.Fatal, internal = true)), e)
126 | case NonFatal(e) => throw new CompilerResult(Vector(new Notice(e.toString, at, Severity.Fatal, internal = true)), e)
127 | }
128 | def tryInternal[T](atNode: Node)(f: => T): T = try f catch {
129 | case e: CompilerResult => throw e
130 | case e: AssertionError => throw new CompilerResult(Vector(new Notice(e.toString, atNode.pos, Severity.Fatal, internal = true, atNode = atNode)), e)
131 | case NonFatal(e) => throw new CompilerResult(Vector(new Notice(e.toString, atNode.pos, Severity.Fatal, internal = true, atNode = atNode)), e)
132 | }
133 |
134 | def fail(msg: String, at: Position = null, parent: Throwable = null, atNode: Node = null, internal: Boolean = true, mark: Node = null): Nothing =
135 | throw new CompilerResult(Vector(new Notice(msg, if(at == null && atNode != null) atNode.pos else at, Severity.Fatal, atNode, internal, mark)))
136 | }
137 |
138 | class ReportedError extends Exception("Internal error: ReportedError thrown by throwError must be caught by a surrounding tryError")
139 |
140 | trait BaseInterpreter {
141 | def getAnalyzer: Analyzer[_]
142 | def initData(): Unit
143 | def reduce(): Unit
144 | def dispose(): Unit
145 | def getMetrics: ExecutionMetrics
146 | }
147 |
--------------------------------------------------------------------------------
/src/main/scala/Main.scala:
--------------------------------------------------------------------------------
1 | import java.nio.file.Path
2 | import de.szeiger.interact._
3 | import de.szeiger.interact.ast.ShowableNode
4 |
5 | object Main extends App {
6 | def handleRes(res: CompilerResult, full: Boolean): Unit = {
7 | if(full) res.printStackTrace(System.out)
8 | else {
9 | res.notices.foreach(println)
10 | println(res.summary)
11 | }
12 | if(res.hasErrors) sys.exit(1)
13 | }
14 | try {
15 | val unit = Parser.parse(Path.of(args(0)))
16 | ShowableNode.print(unit)
17 | //statements.foreach(println)
18 | val model = new Compiler(unit, Config(backend = STIBackend, collectStats = true))
19 |
20 | //println("Constructors:")
21 | //model.constrs.foreach(c => println(s" ${c.show}"))
22 | //println("Defs:")
23 | //model.defs.foreach(d => println(s" ${d.show}"))
24 | //println("Rules:")
25 | //model.rules.foreach(r => if(!r.isInstanceOf[DerivedRule]) println(s" ${r.show}"))
26 | //println("Data:")
27 | //model.data.foreach(r => println(s" ${r.show}"))
28 | //ShowableNode.print(model.unit)
29 |
30 | ShowableNode.print(model.unit)
31 | val inter = model.createInterpreter()
32 | inter.initData()
33 | println("Initial state:")
34 | inter.getAnalyzer.log(System.out)
35 | inter.reduce()
36 | if(inter.getMetrics != null) inter.getMetrics.log()
37 | println(s"Irreducible after ${inter.getMetrics.getSteps} reductions.")
38 | inter.getAnalyzer.log(System.out)
39 | handleRes(model.global.getCompilerResult(), false)
40 | } catch { case ex: CompilerResult => handleRes(ex, true) }
41 | }
42 |
--------------------------------------------------------------------------------
/src/main/scala/NormalizeCondition.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | import scala.collection.mutable
6 |
7 | // Normalize expressions:
8 | // - all compound expressions are unnested (ANF)
9 | // - constructor Idents are converted to Ap
10 | // - only non-constructor Idents can be nested
11 | // - all Ap assignments have the Ap on the RHS
12 | // - all direct assignments are untupled
13 | // - wildcards in assignments are resolved
14 | // - only the last expr can be a non-assignment
15 | // - NatLits are expanded
16 | class NormalizeCondition(global: Global) {
17 | import global._
18 |
19 | private var lastTmp = 0
20 | private def mk(): Ident = { lastTmp += 1; mkLocalId(s"$$u${lastTmp}") }
21 |
22 | def toANF(exprs: Seq[Expr]): Seq[Expr] = {
23 | val assigned = mutable.HashSet.empty[Ident]
24 | val buf = mutable.ArrayBuffer.empty[Expr]
25 | def expandWildcards(e: Expr, wild: Ident): Expr = e match {
26 | case Assignment(ls, rs) =>
27 | val wild2 = mk()
28 | val ass2 = Assignment(expandWildcards(ls, wild2), expandWildcards(rs, wild2)).setPos(e.pos)
29 | if(assigned.contains(wild2)) { reorder(unnest(ass2, false)); wild2 }
30 | else ass2
31 | case Tuple(es) => Tuple(es.map(expandWildcards(_, wild))).setPos(e.pos)
32 | case Apply(t, emb, args) => Apply(t, emb, args.map(expandWildcards(_, wild))).setPos(e.pos)
33 | case Wildcard() =>
34 | if(assigned.contains(wild)) error(s"Duplicate wildcard in assignment", e)
35 | assigned += wild
36 | wild.setPos(e.pos)
37 | case e: Ident => e
38 | case e: NatLit => e.expand
39 | }
40 | def unnest(e: Expr, nested: Boolean): Expr = e match {
41 | case Assignment(ls, rs) =>
42 | if(nested) error("Unexpected nested assignment without wilcard", e)
43 | Assignment(unnest(ls, false), unnest(rs, false)).setPos(e.pos)
44 | case e: Tuple => e.copy(e.exprs.map(unnest(_, true)))
45 | case IdentOrAp(id, emb, args) =>
46 | if(id.sym.isCons || args.nonEmpty) {
47 | val ap = Apply(id, emb, args.map(unnest(_, true))).setPos(e.pos)
48 | if(nested) {
49 | val v = mk().setPos(e.pos)
50 | reorder(Assignment(v, ap).setPos(e.pos))
51 | v
52 | } else ap
53 | } else id
54 | }
55 | def reorder(e: Expr): Unit = e match {
56 | case Assignment(ls: Apply, rs: Apply) =>
57 | val sym1 = ls.target.sym
58 | val sym2 = rs.target.sym
59 | if(sym1.returnArity != sym2.returnArity)
60 | error(s"Arity mismatch in assignment: ${sym1.returnArity} != ${sym2.returnArity}", e)
61 | if(sym1.returnArity == 1) {
62 | val v = mk().setPos(e.pos)
63 | buf += Assignment(v, ls).setPos(ls.pos)
64 | buf += Assignment(v, rs).setPos(rs.pos)
65 | } else {
66 | val vs = (0 until sym1.returnArity).map(_ => mk().setPos(e.pos)).toVector
67 | buf += Assignment(Tuple(vs).setPos(ls.pos), ls).setPos(ls.pos)
68 | buf += Assignment(Tuple(vs).setPos(rs.pos), rs).setPos(rs.pos)
69 | }
70 | case e @ Assignment(ls: Apply, rs) => reorder(e.copy(rs, ls))
71 | case Assignment(Tuple(ls), Tuple(rs)) =>
72 | ls.zip(rs).foreach { case (l, r) => reorder(Assignment(l, r).setPos(e.pos)) }
73 | case e => buf += e
74 | }
75 | exprs.foreach { e =>
76 | val wild = mk()
77 | reorder(unnest(expandWildcards(e, wild), false))
78 | if(assigned.contains(wild)) error("Unexpected wildcard outside of assignment", e)
79 | }
80 | buf.toSeq
81 | }
82 |
83 | // reorder assignments as if every def was a cons
84 | def toConsOrder(e: Expr): Expr = e match {
85 | case Assignment(id @ IdentOrTuple(es), a @ Apply(t, emb, args)) =>
86 | if(!t.sym.isDef) Assignment(id, ApplyCons(t, emb, args).setPos(a.pos)).setPos(e.pos) else Assignment(args.head, ApplyCons(t, emb, args.tail ++ es).setPos(a.pos)).setPos(e.pos)
87 | case Apply(t, emb, args) =>
88 | if(!t.sym.isDef) ApplyCons(t, emb, args).setPos(e.pos) else Assignment(args.head, ApplyCons(t, emb, args.tail).setPos(e.pos))
89 | case e => e
90 | }
91 |
92 | // convert from cons-order ANF back to inlined expressions
93 | def toInline(es: Seq[Expr]): Seq[Expr] = {
94 | if(es.isEmpty) es
95 | else {
96 | val vars = mutable.HashMap.from(es.init.map { case a: Assignment => (a.lhs, a) })
97 | def f(e: Expr): Expr = e match {
98 | case e: Ident => vars.remove(e).map { a => f(a.rhs) }.getOrElse(e)
99 | case e: Tuple => vars.remove(e).map { a => f(a.rhs) }.getOrElse(e)
100 | case e @ ApplyCons(target, emb, args) => e.copy(args = args.map(f))
101 | case e @ Assignment(l, r) => e.copy(f(r), f(l))
102 | }
103 | val e2 = f(es.last)
104 | (vars.valuesIterator ++ Iterator.single(e2)).toSeq
105 | }
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/src/main/scala/Parser.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import fastparse._
4 |
5 | import java.nio.charset.StandardCharsets
6 | import java.nio.file.{Files, Path}
7 | import scala.collection.mutable
8 | import de.szeiger.interact.ast._
9 |
10 |
11 | import scala.collection.mutable.ArrayBuffer
12 |
13 | object Lexical {
14 | import NoWhitespace._
15 |
16 | val reservedTokens = Set("cons", "let", "deriving", "def", "if", "else", "match", "_", ":", "=", "=>", "|")
17 | private val operatorStart = IndexedSeq(Set('*', '/', '%'), Set('+', '-'), Set(':'), Set('<', '>'), Set('=', '!'), Set('&'), Set('^'), Set('|'))
18 | private val operatorCont = operatorStart.iterator.flatten.toSet
19 | val MaxPrecedence = operatorStart.length-1
20 |
21 | def precedenceOf(s: String): Int =
22 | if(reservedTokens.contains(s) || !s.forall(operatorCont.contains)) -1
23 | else { val c = s.charAt(0); operatorStart.indexWhere(_.contains(c)) }
24 | def isRightAssoc(s: String): Boolean = s.endsWith(":")
25 |
26 | def ident[_: P]: P[String] =
27 | P( (letter|"_") ~ (letter | digit | "_").rep ).!.filter(!reservedTokens.contains(_))
28 | def kw[_: P](s: String) = P( s ~ !(letter | digit | "_") )
29 | def letter[_: P] = P( CharIn("a-z") | CharIn("A-Z") )
30 | def digit[_: P] = P( CharIn("0-9") )
31 | def natLit[_: P] = P( digit.rep(1).! ~ "n" ).map(_.toInt)
32 | def intLit[_: P] = P( (("-").? ~ digit.rep(1)).! ).map(_.toInt)
33 |
34 | def operator[_: P](precedence: Int): P[String] =
35 | P( CharPred(operatorStart(precedence).contains) ~ CharPred(operatorCont.contains).rep ).!.filter(s => !reservedTokens.contains(s))
36 |
37 | def anyOperator[_: P]: P[Ident] =
38 | P( CharPred(operatorCont.contains).rep(1).!.filter(s => !reservedTokens.contains(s)).map(Ident(_)) )
39 |
40 | def stringLit[_: P]: P[String] = P( "\"" ~ (stringChar | stringEscape).rep.! ~ "\"" )
41 | def stringChar[_: P]: P[Unit] = P( CharsWhile(!s"\\\n\"}".contains(_)) )
42 | def stringEscape[_: P]: P[Unit] = P( "\\" ~ AnyChar )
43 |
44 | def comment[$: P] = P( "#" ~ CharsWhile(_ != '\n', 0) )
45 | def whitespace[$: P](indent: Int) = P(
46 | CharsWhile(_ == ' ', 0) ~ (CharsWhile(_ == ' ', 0) ~ comment.? ~ "\r".? ~ "\n" ~ " ".rep(indent)).rep(0)
47 | )
48 | }
49 |
50 | trait EmbeddedSyntax { this: Parser =>
51 | import Lexical._
52 |
53 | def identExpr[_: P]: P[Ident] = positioned(ident.map(Ident))
54 |
55 | def embeddedExpr[_: P]: P[EmbeddedExpr] =
56 | P( positioned(embeddedOperatorExpr(MaxPrecedence)) )
57 |
58 | def embeddedAssignment[_: P]: P[EmbeddedAssignment] =
59 | P( positioned((identExpr ~ "=" ~ embeddedExpr).map { case (id, ee) => EmbeddedAssignment(id, ee) }) )
60 |
61 | def embeddedExprOrAssignment[_: P]: P[EmbeddedExpr] =
62 | P( embeddedAssignment | embeddedExpr )
63 |
64 | def embeddedAp[_: P]: P[EmbeddedApply] =
65 | P( identExpr.rep(1, ".").map(_.toVector) ~ "(" ~ embeddedExpr.rep(0, ",").map(_.toVector) ~ ")" ).map { case (method, args) =>
66 | EmbeddedApply(method, args, false, EmbeddedType.Unknown)
67 | }
68 |
69 | def bracketedEmbeddedExpr[_: P]: P[EmbeddedExpr] =
70 | P( "[" ~ embeddedExprOrAssignment ~ "]" )
71 |
72 | def simpleEmbeddedExpr[_: P]: P[EmbeddedExpr] =
73 | P( embeddedAp | identExpr | intLit.map(IntLit) | stringLit.map(StringLit) | ("(" ~ embeddedExpr ~ ")") )
74 |
75 | def operatorEmbeddedIdent[_: P](precedence: Int): P[Ident] = positioned(operator(precedence).map(Ident))
76 |
77 | def embeddedOperatorExpr[_: P](precedence: Int): P[EmbeddedExpr] = {
78 | def next = if(precedence == 0) positioned(simpleEmbeddedExpr) else embeddedOperatorExpr(precedence - 1)
79 | P( next ~ (operatorEmbeddedIdent(precedence) ~ next).rep ).map {
80 | case (e, Seq()) => e
81 | case (e, ts) =>
82 | val right = ts.count(_._1.s.endsWith(":"))
83 | if(right == 0)
84 | ts.foldLeft(e) { case (z, (o, a)) => EmbeddedApply(Vector(o), Vector(z, a), true, EmbeddedType.Unknown).setPos(o.pos) }
85 | else if(right == ts.length) {
86 | val e2 = ts.last._2
87 | val ts2 = ts.map(_._1).zip(e +: ts.map(_._2).init)
88 | ts2.foldRight(e2) { case ((o, a), z) => EmbeddedApply(Vector(o), Vector(a, z), true, EmbeddedType.Unknown).setPos(o.pos) }
89 | } else sys.error("Chained binary operators must have the same associativity")
90 | }
91 | }
92 | }
93 |
94 | trait Syntax { this: Parser =>
95 | import Lexical._
96 |
97 | def wildcard[_: P]: P[Wildcard] = P("_").map(_ => Wildcard())
98 |
99 | def nat[_: P]: P[Expr] = P( positioned(natLit.map(NatLit)) )
100 |
101 | def appOrIdent[_: P]: P[Expr] =
102 | P(
103 | positioned((identExpr ~ bracketedEmbeddedExpr.? ~ ("(" ~ expr.rep(sep = ",") ~ ")").?).map { case (id, embO, argsO) =>
104 | if(embO.isDefined || argsO.isDefined) Apply(id, embO, argsO.getOrElse(Vector.empty).toVector) else id
105 | })
106 | )
107 |
108 | def tuple[_: P]: P[Tuple] =
109 | P( "(" ~ expr.rep(min = 0, sep = ",").map(_.toVector) ~ ")" ).map(Tuple)
110 |
111 | def simpleExpr[_: P]: P[Expr] =
112 | P( (appOrIdent | wildcard | nat | tuple) )
113 |
114 | def operatorIdent[_: P](precedence: Int): P[Ident] = positioned(operator(precedence).map(Ident(_)))
115 |
116 | def operatorEx[_: P](precedence: Int): P[Expr] = {
117 | def next = if(precedence == 0) positioned(simpleExpr) else operatorEx(precedence - 1)
118 | P( next ~ (operatorIdent(precedence) ~ bracketedEmbeddedExpr.? ~ next).rep ).map {
119 | case (e, Seq()) => e
120 | case (e, ts) =>
121 | val right = ts.count(_._1.s.endsWith(":"))
122 | if(right == 0)
123 | ts.foldLeft(e) { case (z, (o, oe, a)) => Apply(o, oe, Vector(z, a)).setPos(o.pos) }
124 | else if(right == ts.length) {
125 | val e2 = ts.last._3
126 | val ts2 = ts.map(t => (t._1, t._2)).zip(e +: ts.map(_._3).init)
127 | ts2.foldRight(e2) { case (((o, oe), a), z) => Apply(o, oe, Vector(a, z)).setPos(o.pos) }
128 | } else sys.error("Chained binary operators must have the same associativity")
129 | }
130 | }
131 |
132 | def expr[_: P]: P[Expr] =
133 | P( positioned(operatorEx(MaxPrecedence)) ~ ("=" ~ positioned(operatorEx(MaxPrecedence))).? ).map {
134 | case (e1, None) => e1
135 | case (e1, Some(e2)) => Assignment(e1, e2).setPos(e1.pos)
136 | }
137 |
138 | def params[_: P](min: Int): P[Vector[IdentOrWildcard]] =
139 | P( ("(" ~ param.rep(min = min, sep = ",").map(_.toVector) ~ ")") )
140 |
141 | def param[_: P]: P[IdentOrWildcard] =
142 | P( positioned(identExpr | wildcard) )
143 |
144 | def defReturn[_: P]: P[Vector[IdentOrWildcard]] =
145 | P( params(1) | identExpr.map(Vector(_)) )
146 |
147 | def deriving[_ : P]: P[Vector[Ident]] =
148 | P( kw("deriving") ~/ "(" ~ identExpr.rep(0, sep=",").map(_.toVector) ~ ")" )
149 |
150 | def payloadType[_: P]: P[PayloadType] =
151 | P(ident).map {
152 | case "int" => (PayloadType.INT)
153 | case "ref" => (PayloadType.REF)
154 | case "label" => (PayloadType.LABEL)
155 | case tpe => sys.error(s"Illegal payload type: $tpe")
156 | }
157 |
158 | def embeddedSpecOpt[_: P]: P[(PayloadType, Option[Ident])] =
159 | P( ("[" ~ payloadType ~ identExpr.? ~ "]" ).? ).map(_.getOrElse((PayloadType.VOID, None)))
160 |
161 | def cons[_: P]: P[Cons] =
162 | P( kw("cons") ~/ (operatorDef | namedCons) ~ ("=" ~ identExpr).? ~ deriving.? ).map(Cons.tupled)
163 |
164 | def definition[_: P]: P[Def] =
165 | P( kw("def") ~/ (operatorDef | namedDef) ~ ("=" ~ defReturn).?.map(_.getOrElse(Vector.empty)) ~ defRule.rep.map(_.toVector) ).map(Def.tupled)
166 |
167 | def namedCons[_: P]: P[(Ident, Vector[IdentOrWildcard], Boolean, PayloadType, Option[Ident])] =
168 | P( identExpr ~ embeddedSpecOpt ~ params(0).?.map(_.getOrElse(Vector.empty)) ).map { case (n, (pt, eid), as) => (n, as, false, pt, eid) }
169 |
170 | def operatorDef[_: P]: P[(Ident, Vector[IdentOrWildcard], Boolean, PayloadType, Option[Ident])] =
171 | P( param ~ positioned(anyOperator) ~ embeddedSpecOpt ~ param ).map { case (a1, o, (pt, eid), a2) => (o, Vector(a1, a2), true, pt, eid) }
172 |
173 | def namedDef[_: P]: P[(Ident, Vector[IdentOrWildcard], Boolean, PayloadType, Option[Ident])] =
174 | P( identExpr ~ embeddedSpecOpt ~ params(1) ).map { case (n, (pt, eid), as) => (n, as, false, pt, eid) }
175 |
176 | def anyExpr[_ : P]: P[Either[EmbeddedExpr, Expr]] =
177 | P( bracketedEmbeddedExpr.map(Left(_)) | expr.map(Right(_)) )
178 |
179 | def anyExprs[_: P]: P[Vector[Either[EmbeddedExpr, Expr]]] =
180 | P( anyExpr.rep(1, sep = ";").map(_.toVector) ~ ";".? )
181 |
182 | def anyExprBlock[_: P]: P[(Vector[Expr], Vector[EmbeddedExpr])] =
183 | P( pos.flatMapX(p => forIndent(p.column).anyExprBlock2) )
184 |
185 | def anyExprBlock2[_: P]: P[(Vector[Expr], Vector[EmbeddedExpr])] =
186 | P( (forIndent(indent+1).anyExprs).rep(1) ).map { es =>
187 | (es.iterator.flatten.collect { case Right(e) => e }.toVector, es.iterator.flatten.collect { case Left(e) => e }.toVector)
188 | }
189 |
190 | def simpleReduction[_: P]: P[Branch] =
191 | P( positioned("=>" ~ anyExprBlock.map { case (es, ees) => Branch(None, ees, es) }) )
192 |
193 | def conditionalReductions[_: P]: P[Vector[Branch]] =
194 | P( ("if" ~ (bracketedEmbeddedExpr ~ simpleReduction).map { case (p, r) => r.copy(cond = Some(p))}).rep(1).map(_.toVector) ~
195 | "else" ~ simpleReduction
196 | ).map { case (rs, r) => rs :+ r }
197 |
198 | def reductions[_: P]: P[Vector[Branch]] =
199 | P( conditionalReductions | simpleReduction.map(Vector(_)) )
200 |
201 | def condition[_: P]: P[EmbeddedExpr] =
202 | P( "if" ~ bracketedEmbeddedExpr )
203 |
204 | def defRule[_: P]: P[DefRule] =
205 | P( positioned(("|" ~ expr.rep(1, ",").map(_.toVector) ~ reductions).map(DefRule.tupled)) )
206 |
207 | def matchStatement[_: P]: P[Match] =
208 | P( "match" ~ expr ~ reductions ).map { case (on, red) => Match(Vector(on), red) }
209 |
210 | def let[_: P]: P[Let] =
211 | P( kw("let") ~/ anyExprBlock ).map { case (defs, emb) => Let(defs, emb, Vector.empty) }
212 |
213 | def unit[_: P]: P[CompilationUnit] =
214 | P( Start ~ pos ~ positioned(cons | let | definition | matchStatement ).rep ~ End ).map { case (p, es) => CompilationUnit(es.toVector).setPos(p) }
215 | }
216 |
217 | class Parser(file: String, indexed: ConvenientParserInput, val indent: Int) extends EmbeddedSyntax with Syntax {
218 | implicit val whitespace: fastparse.Whitespace =
219 | (ctx: P[_]) => Lexical.whitespace(indent)(ctx)
220 |
221 | private[this] val positions = mutable.LongMap.empty[Position]
222 |
223 | def pos[_: P]: P[Position] = Index.map { offset => positions.getOrElseUpdate(offset, new Position(offset, file, indexed)) }
224 |
225 | def positioned[T <: Node, _: P](n: => P[T]): P[T] =
226 | P( (pos ~~ n) ).map { case (p, e) => if(!e.pos.isDefined) e.setPos(p) else e }
227 |
228 | def forIndent(i: Int): Parser = new Parser(file, indexed, i)
229 | }
230 |
231 | object Parser {
232 | def parse(input: String, file: String = ""): CompilationUnit = {
233 | val in = new ConvenientParserInput(input)
234 | val p = new Parser(file, in, 0)
235 | fastparse.parse(in, p.unit(_), verboseFailures = true).get.value
236 | }
237 |
238 | def parse(file: Path): CompilationUnit =
239 | parse(new String(Files.readAllBytes(file), StandardCharsets.UTF_8), file.toString)
240 | }
241 |
242 | class ConvenientParserInput(val data: String) extends ParserInput {
243 | def apply(index: Int): Char = data.charAt(index)
244 | def dropBuffer(index: Int): Unit = ()
245 | def slice(from: Int, until: Int): String = data.slice(from, until)
246 | def length: Int = data.length
247 | def innerLength: Int = length
248 | def isReachable(index: Int): Boolean = index < length
249 | def checkTraceable(): Unit = ()
250 |
251 | private[this] lazy val lineBreaks = {
252 | val b = ArrayBuffer.empty[Int]
253 | for(i <- data.indices)
254 | if(data.charAt(i) == '\n') b += i
255 | b
256 | }
257 |
258 | def find(idx: Int): (Int, Int) = {
259 | val line = lineBreaks.indexWhere(_ > idx) match {
260 | case -1 => lineBreaks.length
261 | case n => n
262 | }
263 | val col = if(line == 0) idx else idx - lineBreaks(line-1) - 1
264 | (line, col)
265 | }
266 |
267 | def prettyIndex(idx: Int): String = {
268 | val (line, pos) = find(idx)
269 | s"${line+1}:${pos+1}"
270 | }
271 |
272 | def getLine(line: Int): String = {
273 | val start = if(line == 0) 0 else lineBreaks(line-1)
274 | val end = if(line == lineBreaks.length) data.length else lineBreaks(line)
275 | data.substring(start, end).dropWhile(c => c == '\r' || c == '\n')
276 | }
277 |
278 | def getCaret(col: Int): String = (" " * col) + "^"
279 | }
280 |
--------------------------------------------------------------------------------
/src/main/scala/Prepare.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | /**
6 | * Create all symbols and check function calls and linearity.
7 | */
8 | class Prepare(val global: Global) extends Phase {
9 | import global._
10 |
11 | def apply(unit: CompilationUnit): CompilationUnit = {
12 | unit.statements.foreach {
13 | case c: Cons =>
14 | if(globalSymbols.contains(c.name)) error(s"Duplicate cons/def: ${c.name.s}", c.name)
15 | else c.name.sym = globalSymbols.defineCons(c.name.s, c.args.length, c.payloadType)
16 | case d: Def =>
17 | if(globalSymbols.contains(d.name)) error(s"Duplicate cons/def: ${d.name.s}", d.name)
18 | else d.name.sym = globalSymbols.defineDef(d.name.s, d.args.length, d.ret.length, d.payloadType)
19 | case _ =>
20 | }
21 | val st2 = Transform.mapC(unit.statements)(assign(_, globalSymbols))
22 | if(st2 eq unit.statements) unit else unit.copy(st2)
23 | }
24 |
25 | private[this] def assign(n: Statement, scope: Symbols): Statement = n match {
26 | case Cons(_, args, _, _, embId, ret, der) =>
27 | val sc = scope.sub()
28 | args.foreach(assign(_, sc))
29 | args.foreach(a => if(a.isWildcard) error("No wildcard parameters allowed in cons", a))
30 | ret.foreach(assign(_: Expr, sc))
31 | embId.foreach(assign(_: EmbeddedExpr, sc))
32 | der.foreach(_.foreach { i =>
33 | val symO = scope.get(i)
34 | if(!symO.exists(_.isCons))
35 | error(s"No constructor '${i.s}' defined", i)
36 | else i.sym = symO.get
37 | })
38 | n
39 | case Def(_, args, _, _, embId, ret, rules) =>
40 | val sc = scope.sub()
41 | args.foreach(assign(_, sc))
42 | ret.foreach(assign(_, sc))
43 | embId.foreach(assign(_: EmbeddedExpr, sc))
44 | val rm = new RefsMap
45 | args.foreach(rm.collectAll)
46 | ret.foreach(rm.collectAll)
47 | embId.foreach(rm.collectAll)
48 | if(rm.hasNonFree)
49 | error(s"Duplicate variable(s) in def pattern ${rm.nonFree.map(s => s"'$s'").mkString(", ")}", n)
50 | rules.foreach(assign(_, sc, rm))
51 | n
52 | case Match(on, reduced) =>
53 | val sc = scope.sub()
54 | on.foreach(assign(_, sc))
55 | val rm = new RefsMap
56 | on.foreach(rm.collectAll)
57 | if(rm.hasNonFree)
58 | error(s"Duplicate variable(s) in match pattern ${rm.nonFree.map(s => s"'$s'").mkString(", ")}", n)
59 | reduced.foreach(assign(_, sc))
60 | n
61 | case n @ Let(defs, embDefs, _) =>
62 | val sc = scope.sub()
63 | defs.foreach(assign(_, sc))
64 | embDefs.foreach(assign(_, sc))
65 | val refs = new RefsMap
66 | defs.foreach(refs.collectAll)
67 | if(refs.hasError)
68 | error(s"Non-linear use of variable(s) ${refs.err.map(s => s"'$s'").mkString(", ")}", n)
69 | n.copy(free = refs.free.filterNot(_.isEmbedded).map(sym => Ident(sym.id).setSym(sym)).toVector.sortBy(_.s))
70 | case _: CheckedRule => n
71 | }
72 |
73 | private[this] def assign(n: DefRule, scope: Symbols, defRefs: RefsMap): Unit = {
74 | val sc = scope.sub()
75 | n.on.foreach(assign(_, sc))
76 | val rm = defRefs.sub()
77 | n.on.foreach(rm.collectAll)
78 | if(rm.hasNonFree)
79 | error(s"Duplicate variable(s) in def rule pattern ${rm.nonFree.map(s => s"'$s'").mkString(", ")}", n)
80 | n.reduced.foreach(assign(_, sc))
81 | }
82 |
83 | private[this] def assign(n: Branch, scope: Symbols): Unit = {
84 | val sc = scope.sub()
85 | n.cond.foreach(assign(_, sc))
86 | n.reduced.foreach(assign(_, sc))
87 | n.embRed.foreach(assign(_, sc))
88 | }
89 |
90 | private[this] def define(n: Ident, scope: Symbols, embedded: Boolean): Symbol = {
91 | scope.get(n) match {
92 | case Some(s) =>
93 | if(s.isEmbedded && !embedded)
94 | error(s"Embedded variable '$s' used in non-embedded context", n)
95 | else if(!s.isEmbedded && embedded)
96 | error(s"Non-embedded variable '$s' used in embedded context", n)
97 | s
98 | case None =>
99 | scope.define(n.s, isEmbedded = embedded)
100 | }
101 | }
102 |
103 | private[this] def assign(n: Expr, scope: Symbols): Unit = n match {
104 | case n: Ident =>
105 | assert(n.sym.isEmpty)
106 | n.sym = define(n, scope, false)
107 | case n: NatLit =>
108 | n.sSym = define(Ident("S"), scope, false)
109 | n.zSym = define(Ident("Z"), scope, false)
110 | if(!n.sSym.isCons || n.sSym.arity != 1 || !n.zSym.isCons || n.zSym.arity != 0)
111 | error(s"Nat literal requires appropriate Z() and S(_) constructors", n)
112 | case Apply(id, emb, args) =>
113 | assign(id: Expr, scope)
114 | emb.foreach(assign(_: EmbeddedExpr, scope))
115 | val l = args.length
116 | if(!id.sym.isCons)
117 | error(s"Symbol '${id.sym}' in function call is not a constructor", id)
118 | else if(id.sym.callArity != l)
119 | error(s"Wrong number of arguments for '${id.sym}': got $l, expected ${id.sym.callArity}", n)
120 | if(l == 1) assign(args.head, scope) // tail-recursive call
121 | else if(l > 0) args.foreach(assign(_, scope))
122 | case n => n.nodeChildren.foreach {
123 | case ch: Expr => assign(ch, scope)
124 | }
125 | }
126 |
127 | private[this] def assign(n: EmbeddedExpr, scope: Symbols): Unit = n match {
128 | case n: Ident =>
129 | assert(n.sym.isEmpty)
130 | n.sym = define(n, scope, true)
131 | case EmbeddedApply(_, args, _, _) =>
132 | args.foreach(assign(_, scope))
133 | case n => n.nodeChildren.foreach {
134 | case ch: EmbeddedExpr => assign(ch, scope)
135 | }
136 | }
137 | }
138 |
--------------------------------------------------------------------------------
/src/main/scala/ResolveEmbedded.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast._
4 |
5 | import java.lang.reflect.Method
6 |
7 | /** Resolve embedded methods and assign embedded types. */
8 | class ResolveEmbedded(val global: Global) extends Transform with Phase {
9 | import global._
10 |
11 | override def apply(_n: Branch): Branch = {
12 | val n2 = super.apply(_n)
13 | n2.cond.foreach { ee =>
14 | val (tp, _) = getReturnType(ee)
15 | if(tp != EmbeddedType.Bool)
16 | error(s"Expected boolean type for condition expression, got $tp", ee)
17 | }
18 | n2
19 | }
20 |
21 | override def apply(_n: EmbeddedApply): EmbeddedApply = {
22 | val n2 = super.apply(_n)
23 | val args = n2.args.map(getReturnType)
24 | val (clsName, mName, qn) = {
25 | if(n2.op && args.length == 2 && args(0)._1 == EmbeddedType.PayloadLabel && args(1)._1 == EmbeddedType.PayloadLabel) ResolveEmbedded.eqLabel
26 | else if(n2.op) ResolveEmbedded.operators(n2.methodQN.head)
27 | else (n2.className, n2.methodName, n2.methodQNIds)
28 | }
29 | val methTp = resolveMethod(n2, clsName, mName, args)
30 | val n3 = n2.copy(methodQNIds = qn, embTp = methTp, op = false)
31 | //ShowableNode.print(n3, name = "Resolved")
32 | n3
33 | }
34 |
35 | private[this] def getReturnType(ee: EmbeddedExpr): (EmbeddedType, Boolean /* out */) = ee match {
36 | case _: StringLit => (EmbeddedType.PayloadRef, false)
37 | case _: IntLit => (EmbeddedType.PayloadInt, false)
38 | case ee: Ident => (EmbeddedType.Payload(ee.sym.payloadType), !ee.sym.isPattern)
39 | case ee: EmbeddedApply => (ee.embTp match {
40 | case t: EmbeddedType.Method => t.ret
41 | case t @ EmbeddedType.Unknown => t
42 | }, false)
43 | case _ => (EmbeddedType.PayloadVoid, false)
44 | }
45 |
46 | private[this] def resolveMethod(e: EmbeddedApply, cln: String, mn: String, args: Vector[(EmbeddedType, Boolean)]): EmbeddedType = {
47 | val c = dependencyLoader.loadClass(cln)
48 | def toPT(cl: Class[_]): (EmbeddedType, Boolean) = cl.getName match {
49 | case "void" => (EmbeddedType.PayloadVoid, false)
50 | case "int" => (EmbeddedType.PayloadInt, false)
51 | case "boolean" => (EmbeddedType.Bool, false)
52 | case s if s == classOf[IntOutput].getName => (EmbeddedType.PayloadInt, true)
53 | case s if s == classOf[RefOutput].getName => (EmbeddedType.PayloadRef, true)
54 | case _ if !cl.isPrimitive => (EmbeddedType.PayloadRef, false)
55 | case s => (EmbeddedType.Unknown, false)
56 | }
57 | tryError {
58 | val nameCandidates = c.getDeclaredMethods.filter(_.getName == mn).map { m =>
59 | EmbeddedType.Method(m, toPT(m.getReturnType)._1, m.getParameterTypes.iterator.map(toPT).toVector)
60 | }
61 | val typeCandidates = nameCandidates.filter { m =>
62 | m.args == args.map {
63 | case (EmbeddedType.Payload(PayloadType.LABEL), o) => (EmbeddedType.PayloadRef, o)
64 | case (t, o) => (t, o)
65 | }
66 | }
67 | if(nameCandidates.isEmpty) throwError(s"Method $cln.$mn not found.", e)
68 | else if(typeCandidates.isEmpty) {
69 | val exp = showEmbeddedSignature(args, EmbeddedType.Unknown, mn)
70 | val found = nameCandidates.map(m => showJavaSignature(m.method)).mkString(" ", "\n ", "")
71 | throwError(s"No applicable overload of method $cln.$mn.\nExpected:\n $exp\nFound:\n$found", e)
72 | }
73 | else if(typeCandidates.length > 1) {
74 | val exp = showEmbeddedSignature(args, EmbeddedType.Unknown, mn)
75 | val found = nameCandidates.map(m => showJavaSignature(m.method)).mkString(" ", "\n ", "")
76 | throwError(s"${typeCandidates.length} ambiguous overloads of method $cln.$mn.\nExpected:\n $exp\nAmbiguous:\n$found", e)
77 | }
78 | else typeCandidates.head
79 | }.getOrElse(EmbeddedType.Unknown)
80 | }
81 |
82 | private[this] def showJavaSignature(m: Method): String = {
83 | s"${m.getReturnType.getName} ${m.getName}(${m.getParameterTypes.map(_.getName).mkString(", ")})"
84 | }
85 |
86 | private[this] def showEmbeddedSignature(args: Vector[(EmbeddedType, Boolean)], ret: EmbeddedType, name: String): String = {
87 | def f(t: EmbeddedType, out: Boolean): String = (if (out) "out " else "") + (t match {
88 | case EmbeddedType.Payload(pt) => pt.toString
89 | case EmbeddedType.Bool => "boolean"
90 | case EmbeddedType.Unknown => "?"
91 | case t => t.toString
92 | })
93 | s"${f(ret, false)} ${name}(${args.map{ case (t, o) => f(t, o) }.mkString(", ")})"
94 | }
95 | }
96 |
97 | object ResolveEmbedded {
98 | private[this] val runtimeName = Runtime.getClass.getName
99 | private[this] val intrinsicsQN = runtimeName.split('.').toVector
100 | private[this] def mkOp(m: String) = (runtimeName, m, (intrinsicsQN :+ m).map(Ident(_)))
101 | private val eqLabel = mkOp("eqLabel")
102 | private val operators = Map(
103 | "==" -> mkOp("eq"),
104 | "+" -> mkOp("intAdd"),
105 | "-" -> mkOp("intSub"),
106 | "*" -> mkOp("intMult"),
107 | )
108 | }
109 |
--------------------------------------------------------------------------------
/src/main/scala/Runtime.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | trait IntOutput { def setValue(i: Int): Unit }
4 | trait LongOutput { def setValue(l: Long): Unit }
5 | trait IntBox extends IntOutput { def getValue: Int }
6 | trait LongBox extends LongOutput { def getValue: Long }
7 | trait RefOutput { def setValue(o: AnyRef): Unit }
8 | trait RefBox extends RefOutput { def getValue: AnyRef } // Also used for Label
9 | trait LifecycleManaged { def erase(): Unit; def copy(): LifecycleManaged }
10 |
11 | // Standalone boxes used for boxed temporary values in inlined payload computations
12 | final class IntBoxImpl extends IntBox {
13 | private[this] var value: Int = _
14 | def getValue: Int = value
15 | def setValue(v: Int): Unit = value = v
16 | }
17 | final class LongBoxImpl extends LongBox {
18 | private[this] var value: Long = _
19 | def getValue: Long = value
20 | def setValue(v: Long): Unit = value = v
21 | }
22 | final class RefBoxImpl extends RefBox {
23 | private[this] var value: AnyRef = _
24 | def getValue: AnyRef = value
25 | def setValue(v: AnyRef): Unit = value = v
26 | }
27 |
28 | object Runtime {
29 | def add(a: Int, b: Int, res: IntOutput): Unit = res.setValue(a + b)
30 | def mult(a: Int, b: Int, res: IntOutput): Unit = res.setValue(a * b)
31 | def strlen(s: String): Int = s.length
32 |
33 | def eraseRef(o: AnyRef): Unit = o match {
34 | case o: LifecycleManaged => o.erase()
35 | case _ =>
36 | }
37 |
38 | def dupRef(o: AnyRef, r1: RefOutput, r2: RefOutput): Unit = {
39 | r1.setValue(o)
40 | r2.setValue(o match {
41 | case o: LifecycleManaged => o.copy()
42 | case o => o
43 | })
44 | }
45 |
46 | def eq(a: AnyRef, b: AnyRef): Boolean = a == b
47 | def eqLabel(a: AnyRef, b: AnyRef): Boolean = a eq b
48 | def eq(a: Int, b: Int): Boolean = a == b
49 | def intAdd(a: Int, b: Int): Int = a + b
50 | def intSub(a: Int, b: Int): Int = a - b
51 | def intMult(a: Int, b: Int): Int = a * b
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/scala/SymCounts.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import de.szeiger.interact.ast.Symbol
4 |
5 | import scala.collection.mutable
6 |
7 | final class SymCounts(private val m: mutable.Map[Symbol, Int]) extends AnyVal {
8 | override def toString: String = {
9 | m.iterator.filter(_._2 > 0).toVector.sortBy(_._1.id).map {
10 | case (s, 1) => s.id
11 | case (s, i) => s"$i ${s.id}"
12 | }.mkString("[", ", ", "]")
13 | }
14 |
15 | def size = m.size
16 |
17 | def + (s: Symbol): SymCounts = {
18 | val n = mutable.Map.from(m)
19 | SymCounts.add(n, s, 1)
20 | new SymCounts(n)
21 | }
22 |
23 | def - (s: Symbol): SymCounts = {
24 | val n = mutable.Map.from(m)
25 | SymCounts.add(n, s, -1)
26 | new SymCounts(n)
27 | }
28 |
29 | def ++ (o: SymCounts): SymCounts = {
30 | val n = mutable.Map.from(m)
31 | o.m.foreach { case (k, v) => SymCounts.add(n, k, v) }
32 | new SymCounts(n)
33 | }
34 |
35 | def ++ (it: IterableOnce[Symbol]): SymCounts = {
36 | val n = mutable.Map.from(m)
37 | it.iterator.foreach(SymCounts.add(n, _, 1))
38 | new SymCounts(n)
39 | }
40 |
41 | def -- (o: SymCounts): SymCounts = {
42 | val n = mutable.Map.from(m)
43 | o.m.foreach { case (k, v) => SymCounts.add(n, k, -v) }
44 | new SymCounts(n)
45 | }
46 |
47 | def count(s: Symbol): Int = m.getOrElse(s, 0)
48 |
49 | def contains(s: Symbol): Boolean = count(s) > 0
50 | }
51 |
52 | object SymCounts {
53 | private def add(m: mutable.Map[Symbol, Int], s: Symbol, count: Int): Unit =
54 | m.updateWith(s) {
55 | case Some(i) =>
56 | val j = i+count
57 | if(j > 0) Some(j) else None
58 | case None if count > 0 => Some(count)
59 | case None => None
60 | }
61 |
62 | private def add(m: mutable.Map[Symbol, Int], it: IterableOnce[Symbol]): Unit = it.iterator.foreach(add(m, _, 1))
63 |
64 | def apply(ss: Symbol*): SymCounts = {
65 | val m = mutable.Map.empty[Symbol, Int]
66 | add(m, ss)
67 | new SymCounts(m)
68 | }
69 |
70 | def from(syms: IterableOnce[Symbol]): SymCounts = {
71 | val m = mutable.Map.empty[Symbol, Int]
72 | add(m, syms)
73 | new SymCounts(m)
74 | }
75 |
76 | val empty = from(Nil)
77 | }
78 |
--------------------------------------------------------------------------------
/src/main/scala/ast/Symbols.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.ast
2 |
3 | import scala.collection.mutable
4 |
5 | final class Symbol(val id: String, val arity: Int = 0, val returnArity: Int = 1, var payloadType: PayloadType = PayloadType.VOID,
6 | val matchContinuationPort: Int = -2, private[this] var _flags: Int = 0) {
7 | import Symbol._
8 |
9 | def flags: Int = _flags
10 | def isCons: Boolean = (_flags & FLAG_CONS) != 0
11 | def isDef: Boolean = (_flags & FLAG_DEF) != 0
12 | def isEmbedded: Boolean = (_flags & FLAG_EMBEDDED) != 0
13 | def isPattern: Boolean = (_flags & FLAG_PATTERN) != 0
14 |
15 | def setPattern(): Unit = _flags = (_flags | FLAG_PATTERN)
16 |
17 | def isContinuation: Boolean = matchContinuationPort != -2
18 | def callArity: Int = arity + 1 - returnArity
19 | def hasPayload: Boolean = payloadType != PayloadType.VOID
20 | override def toString: String = id
21 | def isDefined: Boolean = this != Symbol.NoSymbol
22 | def isEmpty: Boolean = !isDefined
23 | def isSingleton: Boolean = arity == 0 && payloadType.isEmpty
24 | def uniqueStr: String = if(isDefined) s"$id:${System.identityHashCode(this)}" else ""
25 | def show: String = s"$uniqueStr<$payloadType>"
26 | }
27 |
28 | object Symbol {
29 | def apply(id: String, arity: Int = 0, returnArity: Int = 1,
30 | isCons: Boolean = false, isDef: Boolean = false,
31 | payloadType: PayloadType = PayloadType.VOID, matchContinuationPort: Int = -2,
32 | isEmbedded: Boolean = false, isPattern: Boolean = false): Symbol =
33 | new Symbol(id, arity, returnArity, payloadType, matchContinuationPort,
34 | (if(isCons) FLAG_CONS else 0) | (if(isDef) FLAG_DEF else 0) | (if(isEmbedded) FLAG_EMBEDDED else 0) | (if(isPattern) FLAG_PATTERN else 0)
35 | )
36 |
37 | val FLAG_CONS = 1 << 0
38 | val FLAG_DEF = 1 << 1
39 | val FLAG_EMBEDDED = 1 << 2
40 | val FLAG_PATTERN = 1 << 3
41 |
42 | val NoSymbol = new Symbol("")
43 | }
44 |
45 | class SymbolGen(prefix2: String, isEmbedded: Boolean = false, payloadType: PayloadType = PayloadType.VOID) {
46 | private[this] var last = 0
47 | def apply(isEmbedded: Boolean = isEmbedded, payloadType: PayloadType = payloadType, prefix: String = ""): Symbol = {
48 | last += 1
49 | Symbol(prefix+prefix2+last, isEmbedded = isEmbedded, payloadType = payloadType)
50 | }
51 | def id(isEmbedded: Boolean = isEmbedded, payloadType: PayloadType = payloadType, prefix: String = ""): Ident = {
52 | val s = apply(isEmbedded, payloadType, prefix)
53 | val i = Ident(s.id)
54 | i.sym = s
55 | i
56 | }
57 | }
58 |
59 | class Symbols(parent: Option[Symbols] = None) {
60 | private[this] val syms = mutable.HashMap.empty[String, Symbol]
61 | def define(id: String, isCons: Boolean = false, isDef: Boolean = false, arity: Int = 0, returnArity: Int = 1,
62 | payloadType: PayloadType = PayloadType.VOID, matchContinuationPort: Int = -2,
63 | isEmbedded: Boolean = false): Symbol = {
64 | assert(get(id).isEmpty)
65 | val sym = Symbol(id, arity, returnArity, isCons, isDef, payloadType, matchContinuationPort, isEmbedded)
66 | syms.put(id, sym)
67 | sym
68 | }
69 | def defineCons(id: String, arity: Int, payloadType: PayloadType): Symbol =
70 | define(id, isCons = true, arity = arity, payloadType = payloadType)
71 | def defineDef(id: String, argLen: Int, retLen: Int, payloadType: PayloadType): Symbol =
72 | define(id, isCons = true, isDef = true, arity = argLen + retLen - 1, returnArity = retLen, payloadType = payloadType)
73 | def getOrAdd(id: Ident): Symbol = get(id).getOrElse(define(id.s))
74 | def contains(id: Ident): Boolean = get(id).isDefined
75 | def containsLocal(id: Ident): Boolean = syms.contains(id.s)
76 | def get(id: Ident): Option[Symbol] = get(id.s)
77 | def get(id: String): Option[Symbol] = syms.get(id).orElse(parent.flatMap(_.get(id)))
78 | def apply(id: Ident): Symbol = apply(id.s)
79 | def apply(id: String): Symbol =
80 | get(id).getOrElse(sys.error(s"No symbol found for $id"))
81 | def symbols: Iterator[Symbol] = syms.valuesIterator ++ parent.map(_.symbols).getOrElse(Iterator.empty)
82 | def sub(): Symbols = new Symbols(Some(this))
83 | override def toString: String = s"Symbols(${syms.map { case (_, v) => s"$v"}.mkString(", ")})"
84 | }
85 |
86 | class RefsMap(parent: Option[RefsMap] = None) {
87 | private[this] val data = mutable.Map.empty[Symbol, Int]
88 | private[this] val hasErr = mutable.Set.empty[Symbol]
89 | def inc(s: Symbol): Unit = data.update(s, {
90 | val c = apply(s) + 1
91 | if(c == 3) {
92 | if(!s.isEmbedded || !s.payloadType.canCopy) hasErr += s
93 | }
94 | c
95 | })
96 | def local: Iterator[(Symbol, Int)] = data.iterator.map {case (s, c) =>
97 | (s, c - parent.map(_(s)).getOrElse(0))
98 | }.filter(_._2 > 0)
99 | def iterator: Iterator[(Symbol, Int)] = parent match {
100 | case Some(r) => r.iterator.filter(t => !data.contains(t._1)) ++ data.iterator
101 | case None => data.iterator
102 | }
103 | def apply(s: Symbol): Int = data.getOrElse(s, parent match {
104 | case Some(r) => r(s)
105 | case None => 0
106 | })
107 | def free: Iterator[Symbol] = iterator.filter(_._2 == 1).map(_._1)
108 | def linear: Iterator[Symbol] = iterator.filter(_._2 == 2).map(_._1)
109 | def err: Iterator[Symbol] = parent match {
110 | case Some(r) => r.err ++ hasErr.iterator
111 | case None => hasErr.iterator
112 | }
113 | def nonFree: Iterator[Symbol] = iterator.filter(_._2 > 1).map(_._1)
114 | def hasNonFree: Boolean = hasErr.nonEmpty || linear.hasNext || parent.exists(_.hasNonFree)
115 | def hasError: Boolean = hasErr.nonEmpty || parent.exists(_.hasError)
116 | private[this] def collect0(n: Node, embedded: Boolean, regular: Boolean): Unit = n match {
117 | case n: Ident =>
118 | val use = (n.sym.isEmbedded && embedded) || (!n.sym.isEmbedded && regular)
119 | if(use && !n.sym.isEmpty && !n.sym.isCons) inc(n.sym)
120 | case n => n.nodeChildren.foreach(collect0(_, embedded, regular))
121 | }
122 | def collectRegular(n: Node): Unit = collect0(n, false, true)
123 | def collectEmbedded(n: Node): Unit = collect0(n, true, false)
124 | def collectAll(n: Node): Unit = collect0(n, true, true)
125 | def sub(): RefsMap = parent match {
126 | case Some(r) if data.isEmpty => r.sub()
127 | case _ => new RefsMap(Some(this))
128 | }
129 | def allSymbols: Iterator[Symbol] = iterator.map(_._1).filter(_.isDefined)
130 | }
131 |
--------------------------------------------------------------------------------
/src/main/scala/ast/Transform.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.ast
2 |
3 | import de.szeiger.interact.{AllocateTemp, BranchWiring, CheckPrincipal, Connection, CreateLabelsComp, EmbArg, InitialRuleWiring, PayloadAssignment, PayloadComputation, PayloadComputationPlan, PayloadMethodApplication, PayloadMethodApplicationWithReturn, ReuseLabelsComp, RuleWiring}
4 |
5 | abstract class Transform {
6 | import Transform._
7 |
8 | def apply(n: Node): Node = n match {
9 | case n: Branch => apply(n)
10 | case n: DefRule => apply(n)
11 | case n: Expr => apply(n)
12 | case n: EmbeddedExpr => apply(n)
13 | case n: Statement =>
14 | val v = apply(n: Statement)
15 | assert(v.length == 1)
16 | v.head
17 | case n: CompilationUnit => apply(n)
18 | }
19 |
20 | def apply(n: DefRule): DefRule = {
21 | val on2 = mapC(n.on)(apply)
22 | val red2 = mapC(n.reduced)(apply)
23 | if((on2 eq n.on) && (red2 eq n.reduced)) n
24 | else n.copy(on2, red2)
25 | }
26 |
27 | def apply(n: IdentOrWildcard): IdentOrWildcard = n match {
28 | case n: Ident => apply(n)
29 | case n: Wildcard => apply(n)
30 | }
31 |
32 | def apply(n: Apply): Apply = {
33 | val t2 = apply(n.target)
34 | val emb2 = mapOC(n.embedded)(apply)
35 | val args2 = mapC(n.args)(apply)
36 | if((t2 eq n.target) && (emb2 eq n.embedded) && (args2 eq n.args)) n
37 | else n.copy(t2, emb2, args2)
38 | }
39 |
40 | def apply(n: ApplyCons): ApplyCons = {
41 | val t2 = apply(n.target)
42 | val emb2 = mapOC(n.embedded)(apply)
43 | val args2 = mapC(n.args)(apply)
44 | if((t2 eq n.target) && (emb2 eq n.embedded) && (args2 eq n.args)) n
45 | else n.copy(t2, emb2, args2)
46 | }
47 |
48 | def apply(n: Tuple): Tuple = {
49 | val e2 = mapC(n.exprs)(apply)
50 | if(e2 eq n.exprs) n
51 | else n.copy(e2)
52 | }
53 |
54 | def apply(n: Assignment): Assignment = {
55 | val l2 = apply(n.lhs)
56 | val r2 = apply(n.rhs)
57 | if((l2 eq n.lhs) && (r2 eq n.rhs)) n
58 | else n.copy(l2, r2)
59 | }
60 |
61 | def apply(n: Expr): Expr = n match {
62 | case n: IdentOrWildcard => apply(n)
63 | case n: Apply => apply(n)
64 | case n: ApplyCons => apply(n)
65 | case n: Tuple => apply(n)
66 | case n: Assignment => apply(n)
67 | case n: NatLit => apply(n)
68 | }
69 |
70 | def apply(n: NatLit): NatLit = n
71 |
72 | def apply(n: StringLit): StringLit = n
73 |
74 | def apply(n: IntLit): IntLit = n
75 |
76 | def apply(n: Ident): Ident = n
77 |
78 | def apply(n: Wildcard): Wildcard = n
79 |
80 | def apply(n: EmbeddedApply): EmbeddedApply = {
81 | val m2 = mapC(n.methodQNIds)(apply)
82 | val args2 = mapC(n.args)(apply)
83 | if((m2 eq n.methodQNIds) && (args2 eq n.args)) n
84 | else n.copy(m2, args2, n.op, n.embTp)
85 | }
86 |
87 | def apply(n: EmbeddedAssignment): EmbeddedAssignment = {
88 | val l2 = apply(n.lhs)
89 | val r2 = apply(n.rhs)
90 | if((l2 eq n.lhs) && (r2 eq n.rhs)) n
91 | else n.copy(l2, r2)
92 | }
93 |
94 | def apply(n: CreateLabels): CreateLabels = n
95 |
96 | def apply(n: EmbeddedExpr): EmbeddedExpr = n match {
97 | case n: StringLit => apply(n)
98 | case n: IntLit => apply(n)
99 | case n: Ident => apply(n)
100 | case n: EmbeddedApply => apply(n)
101 | case n: EmbeddedAssignment => apply(n)
102 | case n: CreateLabels => apply(n)
103 | }
104 |
105 | def apply(n: Match): Vector[Statement] = Vector({
106 | val on2 = mapC(n.on)(apply)
107 | val red2 = mapC(n.reduced)(apply)
108 | if((on2 eq n.on) && (red2 eq n.reduced)) n
109 | else n.copy(on2, red2)
110 | })
111 |
112 | def apply(n: Cons): Vector[Statement] = Vector({
113 | val n2 = apply(n.name)
114 | val a2 = mapC(n.args)(apply)
115 | val e2 = mapOC(n.embeddedId)(apply)
116 | val r2 = mapOC(n.ret)(apply)
117 | val d2 = mapOC(n.der)(mapC(_)(apply))
118 | if((n2 eq n.name) && (a2 eq n.args) && (e2 eq n.embeddedId) && (r2 eq n.ret) && (d2 eq n.der)) n
119 | else n.copy(n2, a2, n.operator, n.payloadType, e2, r2, d2)
120 | })
121 |
122 | def apply(n: Let): Vector[Statement] = Vector({
123 | val d2 = mapC(n.defs)(apply)
124 | val e2 = mapC(n.embDefs)(apply)
125 | val f2 = mapC(n.free)(apply)
126 | if((d2 eq n.defs) && (e2 eq n.embDefs) && (f2 eq n.free)) n
127 | else n.copy(d2, e2, f2)
128 | })
129 |
130 | def apply(n: Def): Vector[Statement] = Vector({
131 | val n2 = apply(n.name)
132 | val a2 = mapC(n.args)(apply)
133 | val e2 = mapOC(n.embeddedId)(apply)
134 | val r2 = mapC(n.ret)(apply)
135 | val u2 = mapC(n.rules)(apply)
136 | if((n2 eq n.name) && (a2 eq n.args) && (e2 eq n.embeddedId) && (r2 eq n.ret) && (u2 eq n.rules)) n
137 | else n.copy(n2, a2, n.operator, n.payloadType, e2, r2, u2)
138 | })
139 |
140 | def apply(n: Statement): Vector[Statement] = n match {
141 | case n: Match => apply(n)
142 | case n: Cons => apply(n)
143 | case n: Let => apply(n)
144 | case n: Def => apply(n)
145 | case n: CheckedRule => apply(n)
146 | case n: RuleWiring => apply(n)
147 | case n: InitialRuleWiring => apply(n)
148 | }
149 |
150 | def apply(n: CheckedRule): Vector[Statement] = n match {
151 | case n: DerivedRule => apply(n)
152 | case n: MatchRule => apply(n)
153 | }
154 |
155 | def apply(n: DerivedRule): Vector[Statement] = Vector(n)
156 |
157 | def apply(n: MatchRule): Vector[Statement] = Vector({
158 | val i1 = apply(n.id1)
159 | val i2 = apply(n.id2)
160 | val a12 = mapC(n.args1)(apply)
161 | val a22 = mapC(n.args2)(apply)
162 | val emb12 = mapOC(n.emb1)(apply)
163 | val emb22 = mapOC(n.emb2)(apply)
164 | val red2 = mapC(n.branches)(apply)
165 | if((i1 eq n.id1) && (i2 eq n.id2) && (a12 eq n.args1) && (a22 eq n.args2) && (emb12 eq n.emb1) && (emb22 eq n.emb2) && (red2 eq n.branches)) n
166 | else n.copy(i1, i2, a12, a22, emb12, emb22, red2)
167 | })
168 |
169 | def apply(n: RuleWiring): Vector[Statement] = Vector({
170 | val b2 = mapC(n.branches)(apply)
171 | if(b2 eq n.branches) n
172 | else n.copy(n.sym1, n.sym2, b2, n.derived)
173 | })
174 |
175 | def apply(n: InitialRuleWiring): Vector[Statement] = Vector({
176 | val b2 = apply(n.branch)
177 | if(b2 eq n.branch) n
178 | else n.copy(branch = b2)
179 | })
180 |
181 | def apply(n: BranchWiring): BranchWiring = {
182 | val co2 = flatMapC(n.conns)(apply)
183 | val pc2 = flatMapOC(n.payloadComps)(apply)
184 | val cn2 = flatMapOC(n.cond)(apply)
185 | val bw2 = mapC(n.branches)(apply)
186 | if((co2 eq n.conns) && (pc2 eq n.payloadComps) && (cn2 eq n.cond) && (bw2 eq n.branches)) n
187 | else n.copy(n.cellOffset, n.cells, co2, pc2, cn2, bw2)
188 | }
189 |
190 | def apply(n: Connection): Set[Connection] = Set(n)
191 |
192 | def apply(n: CompilationUnit): CompilationUnit = {
193 | val st2 = flatMapC(n.statements)(apply)
194 | if(st2 eq n.statements) n
195 | else n.copy(st2)
196 | }
197 |
198 | def apply(n: Branch): Branch = {
199 | val cond2 = mapOC(n.cond)(apply)
200 | val embRed2 = mapC(n.embRed)(apply)
201 | val red2 = mapC(n.reduced)(apply)
202 | if((cond2 eq n.cond) && (embRed2 eq n.embRed) && (red2 eq n.reduced)) n
203 | else n.copy(cond2, embRed2, red2)
204 | }
205 |
206 | def apply(n: EmbArg): EmbArg = n
207 |
208 | def apply(n: PayloadComputationPlan): Option[PayloadComputationPlan] = n match {
209 | case n: PayloadComputation => apply(n)
210 | case n: ReuseLabelsComp => Some(apply(n))
211 | case n: AllocateTemp => Some(apply(n))
212 | }
213 |
214 | def apply(n: PayloadComputation): Option[PayloadComputation] = Some(n match {
215 | case n: PayloadMethodApplication => apply(n)
216 | case n: PayloadMethodApplicationWithReturn => apply(n)
217 | case n: PayloadAssignment => apply(n)
218 | case n: CreateLabelsComp => apply(n)
219 | case n: CheckPrincipal => apply(n)
220 | })
221 |
222 | def apply(n: CheckPrincipal): CheckPrincipal = n
223 |
224 | def apply(n: PayloadMethodApplication): PayloadMethodApplication = {
225 | val ea2 = mapC(n.embArgs)(apply)
226 | if (ea2 eq n.embArgs) n
227 | else n.copy(embArgs = ea2)
228 | }
229 |
230 | def apply(n: PayloadMethodApplicationWithReturn): PayloadComputation = {
231 | val m2 = apply(n.method)
232 | val ri2 = apply(n.retIndex)
233 | if((m2 eq n.method) && (ri2 eq n.retIndex)) n
234 | else n.copy(method = m2, retIndex = ri2)
235 | }
236 |
237 | def apply(n: PayloadAssignment): PayloadComputation = {
238 | val si2 = apply(n.sourceIdx)
239 | val ti2 = apply(n.targetIdx)
240 | if((si2 eq n.sourceIdx) && (ti2 eq n.targetIdx)) n
241 | else n.copy(sourceIdx = si2, targetIdx = ti2)
242 | }
243 |
244 | def apply(n: CreateLabelsComp): PayloadComputation = {
245 | val ea2 = mapC(n.embArgs)(apply)
246 | if (ea2 eq n.embArgs) n
247 | else n.copy(embArgs = ea2)
248 | }
249 |
250 | def apply(n: ReuseLabelsComp): PayloadComputationPlan = {
251 | val ea2 = mapC(n.embArgs)(apply)
252 | if (ea2 eq n.embArgs) n
253 | else n.copy(embArgs = ea2)
254 | }
255 |
256 | def apply(n: AllocateTemp): PayloadComputationPlan = {
257 | val ea2 = apply(n.ea)
258 | if (ea2 eq n.ea) n
259 | else n.copy(ea = ea2.asInstanceOf[EmbArg.Temp])
260 | }
261 | }
262 |
263 | object Transform {
264 | def mapC[T,R](xs: Vector[T])(f: T => R): Vector[R] = {
265 | var changed = false
266 | val xs2 = xs.map { x =>
267 | val x2 = f(x)
268 | if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) changed = true
269 | x2
270 | }
271 | if(changed) xs2 else xs.asInstanceOf[Vector[R]]
272 | }
273 |
274 | def flatMapC[T,R](xs: Vector[T])(f: T => Vector[R]): Vector[R] = {
275 | var changed = false
276 | val xs2 = xs.flatMap { x =>
277 | val x2 = f(x)
278 | if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) changed = true
279 | x2
280 | }
281 | if(changed) xs2 else xs.asInstanceOf[Vector[R]]
282 | }
283 |
284 | def flatMapC[T,R](xs: Set[T])(f: T => Set[R]): Set[R] = {
285 | var changed = false
286 | val xs2 = xs.flatMap { x =>
287 | val x2 = f(x)
288 | if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) changed = true
289 | x2
290 | }
291 | if(changed) xs2 else xs.asInstanceOf[Set[R]]
292 | }
293 |
294 | def flatMapOC[T,R](xs: Vector[T])(f: T => Option[R]): Vector[R] = {
295 | var changed = false
296 | val xs2 = xs.flatMap { x =>
297 | val x2 = f(x)
298 | if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) changed = true
299 | x2
300 | }
301 | if(changed) xs2 else xs.asInstanceOf[Vector[R]]
302 | }
303 |
304 | def mapOC[T,R](o: Option[T])(f: T => R): Option[R] = o match {
305 | case None => o.asInstanceOf[Option[R]]
306 | case Some(x) =>
307 | val x2 = f(x)
308 | if(x2.asInstanceOf[AnyRef] ne x.asInstanceOf[AnyRef]) Some(x2)
309 | else o.asInstanceOf[Option[R]]
310 | }
311 |
312 | def flatMapOC[T,R](o: Option[T])(f: T => Option[R]): Option[R] = o match {
313 | case None => o.asInstanceOf[Option[R]]
314 | case Some(x) =>
315 | f(x) match {
316 | case s @ Some(x2) if x.asInstanceOf[AnyRef] ne x2.asInstanceOf[AnyRef] => s
317 | case _ => o.asInstanceOf[Option[R]]
318 | }
319 | }
320 | }
321 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/AbstractCodeGen.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen
2 |
3 | import de.szeiger.interact.Config
4 | import de.szeiger.interact.ast.Symbol
5 | import de.szeiger.interact.codegen.dsl.{Desc => tp, _}
6 | import org.objectweb.asm.util.{CheckClassAdapter, Textifier, TraceClassVisitor}
7 | import org.objectweb.asm.{ClassReader, ClassWriter => AClassWriter}
8 |
9 | import java.io.{OutputStreamWriter, PrintWriter}
10 | import java.util.zip.CRC32
11 |
12 | abstract class AbstractCodeGen(config: Config) {
13 | import AbstractCodeGen._
14 |
15 | private[this] def getCRC32(a: Array[Byte]): Long = {
16 | val crc = new CRC32
17 | crc.update(a)
18 | crc.getValue
19 | }
20 |
21 | protected def addClass(cl: ClassWriter, cls: ClassDSL): Unit = {
22 | val cw = new AClassWriter(AClassWriter.COMPUTE_FRAMES)
23 | val ca = new CheckClassAdapter(cw)
24 | cls.accept(ca)
25 | val raw = cw.toByteArray
26 | if(config.logCodeGenSummary) println(s"Generated class ${cls.name} (${raw.length} bytes, crc ${getCRC32(raw)})")
27 | if(config.logGeneratedClasses.exists(s => s == "*" || cls.name.contains(s))) {
28 | val cr = new ClassReader(raw)
29 | cr.accept(new TraceClassVisitor(cw, new Textifier(), new PrintWriter(new OutputStreamWriter(System.out))), 0)
30 | }
31 | cl.writeClass(cls.javaName, raw)
32 | }
33 |
34 | // Create a new Symbol instance that matches the given Symbol and place it on the stack
35 | protected def reifySymbol(m: MethodDSL, sym: Symbol): MethodDSL = {
36 | if(sym.isEmpty) m.invokestatic(symbol_NoSymbol)
37 | else m.newInitDup(new_Symbol) {
38 | m.ldc(sym.id).iconst(sym.arity).iconst(sym.returnArity)
39 | m.iconst(sym.payloadType.value).iconst(sym.matchContinuationPort)
40 | m.iconst(sym.flags)
41 | }
42 | }
43 | }
44 |
45 | object AbstractCodeGen {
46 | val symbolT = tp.c[Symbol]
47 | val symbol_NoSymbol = symbolT.method("NoSymbol", tp.m()(symbolT))
48 | val new_Symbol = symbolT.constr(tp.m(tp.c[String], tp.I, tp.I, tp.I, tp.I, tp.I).V)
49 |
50 | private[this] def encodeName(s: String): String = {
51 | val b = new StringBuilder()
52 | s.foreach {
53 | case '|' => b.append("$bar")
54 | case '^' => b.append("$up")
55 | case '&' => b.append("$amp")
56 | case '<' => b.append("$less")
57 | case '>' => b.append("$greater")
58 | case ':' => b.append("$colon")
59 | case '+' => b.append("$plus")
60 | case '-' => b.append("$minus")
61 | case '*' => b.append("$times")
62 | case '/' => b.append("$div")
63 | case '%' => b.append("$percent")
64 | case c => b.append(c)
65 | }
66 | b.result()
67 | }
68 |
69 | def encodeName(s: Symbol): String = {
70 | assert(s.isDefined)
71 | encodeName(s.id)
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/BoxOps.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen
2 |
3 | import de.szeiger.interact.{IntBox, IntBoxImpl, LongBox, LongBoxImpl, RefBox, RefBoxImpl}
4 | import de.szeiger.interact.codegen.dsl.{Desc => tp, _}
5 |
6 | class BoxDesc private (val boxT: InterfaceOwner, val implT: ClassOwner, val unboxedT: ValDesc) {
7 | lazy val newBox = implT.constr(tp.m().V)
8 | lazy val getValue = boxT.method("getValue", tp.m()(unboxedT))
9 | lazy val setValue = boxT.method("setValue", tp.m(unboxedT).V)
10 |
11 | def implementInterface(c: ClassDSL, owner: ClassOwner): Unit = {
12 | if(unboxedT != tp.V) {
13 | val field = owner.field("value", unboxedT)
14 | c.field(Acc.PUBLIC, field)
15 | c.setter(field)
16 | c.getter(field)
17 | }
18 | }
19 | }
20 |
21 | object BoxDesc {
22 | val intDesc = new BoxDesc(tp.i[IntBox], tp.c[IntBoxImpl], tp.I)
23 | val longDesc = new BoxDesc(tp.i[LongBox], tp.c[LongBoxImpl], tp.J)
24 | val refDesc = new BoxDesc(tp.i[RefBox], tp.c[RefBoxImpl], tp.Object)
25 | val voidDesc = new BoxDesc(null, null, tp.V)
26 | }
27 |
28 | class BoxOps(m: MethodDSL, val boxDesc: BoxDesc) {
29 | protected[this] lazy val t: TypedDSL = TypedDSL(boxDesc.unboxedT, m)
30 |
31 | def unboxedT: ValDesc = boxDesc.unboxedT
32 |
33 | def load(v: VarIdx) = t.xload(v)
34 | def store(v: VarIdx) = t.xstore(v)
35 |
36 | def unboxedClass: Class[_] = boxDesc.unboxedT.jvmClass.get
37 | def boxedClass: Class[_] = boxDesc.boxT.jvmClass.get
38 |
39 | def getBoxValue = m.invoke(boxDesc.getValue)
40 | def setBoxValue = m.invoke(boxDesc.setValue)
41 | def newBoxStore(name: String): VarIdx = m.newInitDup(boxDesc.newBox)().storeLocal(boxDesc.implT, name)
42 | def newBoxStoreDup: VarIdx = m.newInitDup(boxDesc.newBox)().dup.storeLocal(boxDesc.implT)
43 | }
44 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/ClassWriter.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen
2 |
3 | import de.szeiger.interact.Config
4 | import org.jetbrains.java.decompiler.main.decompiler.ConsoleDecompiler
5 |
6 | import java.io.{File, FileOutputStream}
7 | import java.nio.file.{Files, Path}
8 | import java.util.jar.{JarEntry, JarOutputStream}
9 |
10 | trait ClassWriter { self =>
11 | def writeClass(javaName: String, classFile: Array[Byte]): Unit
12 | def close(): Unit = ()
13 |
14 | def and(w: ClassWriter): ClassWriter = new ClassWriter {
15 | def writeClass(javaName: String, classFile: Array[Byte]): Unit = {
16 | self.writeClass(javaName, classFile)
17 | w.writeClass(javaName, classFile)
18 | }
19 | override def close(): Unit = {
20 | self.close()
21 | w.close()
22 | }
23 | }
24 | }
25 |
26 | final class NullClassWriter extends ClassWriter {
27 | def writeClass(javaName: String, classFile: Array[Byte]): Unit = ()
28 | }
29 |
30 | final class JarClassWriter(file: Path) extends ClassWriter {
31 | if(Files.exists(file)) Files.delete(file)
32 | private[this] val out = new JarOutputStream(new FileOutputStream(file.toFile))
33 | private[this] var closed = false
34 |
35 | def writeClass(javaName: String, classFile: Array[Byte]): Unit = {
36 | assert(!closed)
37 | out.putNextEntry(new JarEntry(javaName.replace('.', '/')+".class"))
38 | out.write(classFile)
39 | out.closeEntry()
40 | }
41 |
42 | override def close(): Unit = {
43 | out.close()
44 | closed = true
45 | }
46 | }
47 |
48 | final class ClassDirWriter(dir: Path) extends ClassWriter {
49 | ClassWriter.delete(dir)
50 |
51 | def writeClass(javaName: String, classFile: Array[Byte]): Unit = {
52 | val p = javaName.split('.').foldLeft(dir) { case (p, s) => p.resolve(s) }
53 | Files.createDirectories(p.getParent)
54 | val f = p.resolveSibling(p.getFileName + ".class")
55 | Files.write(f, classFile)
56 | }
57 | }
58 |
59 | private final class DecompileAdapter(parent: ClassWriter, classes: Path, out: Path) extends ClassWriter {
60 | def writeClass(javaName: String, classFile: Array[Byte]): Unit = parent.writeClass(javaName, classFile)
61 |
62 | override def close(): Unit = {
63 | parent.close()
64 | ClassWriter.delete(out)
65 | Files.createDirectories(out)
66 | val args = Array(classes.toAbsolutePath.toFile.getPath, out.toAbsolutePath.toFile.getPath)
67 | ConsoleDecompiler.main(args)
68 | }
69 | }
70 |
71 | private final class SkipClassWriter(parent: ClassWriter, skip: Set[String]) extends ClassWriter {
72 | def writeClass(javaName: String, classFile: Array[Byte]): Unit =
73 | if(!skip.contains(javaName) && !skip.contains("*")) parent.writeClass(javaName, classFile)
74 | override def close(): Unit = parent.close()
75 | }
76 |
77 | object ClassWriter {
78 | private[codegen] def delete(p: Path): Unit = if(Files.exists(p)) {
79 | if(Files.isDirectory(p)) Files.list(p).forEach(delete)
80 | Files.delete(p)
81 | }
82 |
83 | def apply(config: Config, parent: ClassWriter): ClassWriter = {
84 | val p = if(config.skipCodeGen.nonEmpty) new SkipClassWriter(parent, config.skipCodeGen) else parent
85 | val cw: ClassWriter = config.writeOutput match {
86 | case Some(f) if f.getFileName.toString.endsWith(".jar") => p.and(new JarClassWriter(f))
87 | case Some(f) => p.and(new ClassDirWriter(f))
88 | case None => p
89 | }
90 | (config.writeJava, config.writeOutput) match {
91 | case (Some(classes), None) => ???
92 | case (Some(sources), Some(classes)) => new DecompileAdapter(cw, classes, sources)
93 | case _ => cw
94 | }
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/ParSupport.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen
2 |
3 | import java.lang.invoke.VarHandle
4 | import java.util.concurrent.atomic.AtomicInteger
5 | import java.util.concurrent.{CountedCompleter, ForkJoinPool}
6 | import scala.annotation.tailrec
7 |
8 | object ParSupport {
9 | val defaultParallelism: Int = math.min(ForkJoinPool.commonPool().getParallelism, 8)
10 |
11 | final class AtomicCounter {
12 | private[this] val ai = new AtomicInteger(0)
13 |
14 | @tailrec def max(i: Int): Unit = {
15 | val v = ai.get()
16 | if(i > v && !ai.compareAndSet(v, i)) max(i)
17 | }
18 |
19 | def get: Int = ai.get()
20 | }
21 |
22 | def foreach[T >: Null <: AnyRef](a: IterableOnce[T], parallelism: Int)(f: T => Unit): Unit = {
23 | if(parallelism <= 1) a.iterator.foreach(f)
24 | else {
25 | val it = a.iterator
26 | def getNext: T = it.synchronized { if(it.hasNext) it.next() else null }
27 | final class Task(parent: CountedCompleter[_]) extends CountedCompleter[Null](parent) {
28 | @tailrec def compute: Unit = getNext match {
29 | case null =>
30 | VarHandle.releaseFence()
31 | propagateCompletion()
32 | case v =>
33 | f(v)
34 | compute
35 | }
36 | }
37 | (new CountedCompleter[Null](null, parallelism) {
38 | override def compute(): Unit = {
39 | for(i <- 1 to parallelism) new Task(this).fork()
40 | propagateCompletion()
41 | }
42 | }).invoke()
43 | }
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/dsl/Acc.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen.dsl
2 |
3 | import org.objectweb.asm.Opcodes._
4 |
5 | trait AccessFlags extends Any {
6 | protected[this] def acc: Int
7 | @inline def PUBLIC = new Acc(acc | ACC_PUBLIC)
8 | @inline def PRIVATE = new Acc(acc | ACC_PRIVATE)
9 | @inline def PROTECTED = new Acc(acc | ACC_PROTECTED)
10 | @inline def STATIC = new Acc(acc | ACC_STATIC)
11 | @inline def FINAL = new Acc(acc | ACC_FINAL)
12 | @inline def SUPER = new Acc(acc | ACC_SUPER)
13 | @inline def SYNCHRONIZED = new Acc(acc | ACC_SYNCHRONIZED)
14 | @inline def OPEN = new Acc(acc | ACC_OPEN)
15 | @inline def TRANSITIVE = new Acc(acc | ACC_TRANSITIVE)
16 | @inline def VOLATILE = new Acc(acc | ACC_VOLATILE)
17 | @inline def BRIDGE = new Acc(acc | ACC_BRIDGE)
18 | @inline def STATIC_PHASE = new Acc(acc | ACC_STATIC_PHASE)
19 | @inline def VARARGS = new Acc(acc | ACC_VARARGS)
20 | @inline def TRANSIENT = new Acc(acc | ACC_TRANSIENT)
21 | @inline def NATIVE = new Acc(acc | ACC_NATIVE)
22 | @inline def INTERFACE = new Acc(acc | ACC_INTERFACE)
23 | @inline def ABSTRACT = new Acc(acc | ACC_ABSTRACT)
24 | @inline def STRICT = new Acc(acc | ACC_STRICT)
25 | @inline def SYNTHETIC = new Acc(acc | ACC_SYNTHETIC)
26 | @inline def ANNOTATION = new Acc(acc | ACC_ANNOTATION)
27 | @inline def ENUM = new Acc(acc | ACC_ENUM)
28 | @inline def MANDATED = new Acc(acc | ACC_MANDATED)
29 | @inline def MODULE = new Acc(acc | ACC_MODULE)
30 | @inline def RECORD = new Acc(acc | ACC_RECORD)
31 | @inline def DEPRECATED = new Acc(acc | ACC_DEPRECATED)
32 | }
33 |
34 | final class Acc(val acc: Int) extends AnyVal with AccessFlags {
35 | def | (other: Acc): Acc = new Acc(acc | other.acc)
36 | def has (other: Acc): Boolean = (acc & other.acc) == other.acc
37 | }
38 |
39 | object Acc extends AccessFlags {
40 | protected[this] final val acc = 0
41 | @inline def none: Acc = new Acc(0)
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/dsl/Desc.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen.dsl
2 |
3 | import org.objectweb.asm.Type
4 |
5 | import scala.reflect.ClassTag
6 |
7 | abstract class Desc {
8 | def desc: String
9 | def isArray: Boolean = desc.startsWith("[")
10 | def isMethod: Boolean = desc.startsWith("(")
11 | def isClass: Boolean = desc.startsWith("L")
12 | override def hashCode() = desc.hashCode
13 | override def equals(obj: Any) = obj match {
14 | case o: Desc => desc == o.desc
15 | case _ => false
16 | }
17 | }
18 | abstract class MethodDesc extends Desc
19 | abstract class ValDesc extends Desc {
20 | def a: ValDesc = new Desc.ValDescImpl("["+desc, jvmClass.map(c => c.arrayType()))
21 | def width: Int = {
22 | val d = desc
23 | if(d == "J" || d == "D") 2
24 | else 1
25 | }
26 | def jvmClass: Option[Class[_]]
27 | }
28 | final class PrimitiveValDesc(val desc: String, val jvmType: Class[_]) extends ValDesc {
29 | def jvmClass = Some(jvmType)
30 | }
31 |
32 | object Desc {
33 | import java.lang.{String => JString}
34 | private[dsl] class ValDescImpl(val desc: JString, val jvmClass: Option[Class[_]]) extends ValDesc
35 | private[this] class MethodDescImpl(val desc: JString) extends MethodDesc
36 | class MethodArgs private[Desc] (params: Seq[ValDesc]) {
37 | private[this] def d = params.iterator.map(_.desc).mkString("(", "", ")")
38 | def B: MethodDesc = new MethodDescImpl(s"${d}B")
39 | def Z: MethodDesc = new MethodDescImpl(s"${d}Z")
40 | def C: MethodDesc = new MethodDescImpl(s"${d}C")
41 | def I: MethodDesc = new MethodDescImpl(s"${d}I")
42 | def S: MethodDesc = new MethodDescImpl(s"${d}S")
43 | def D: MethodDesc = new MethodDescImpl(s"${d}D")
44 | def F: MethodDesc = new MethodDescImpl(s"${d}F")
45 | def J: MethodDesc = new MethodDescImpl(s"${d}J")
46 | def V: MethodDesc = new MethodDescImpl(s"${d}V")
47 | def apply(ret: ValDesc): MethodDesc = new MethodDescImpl(s"${d}${ret.desc}")
48 | }
49 | val B: PrimitiveValDesc = new PrimitiveValDesc("B", java.lang.Byte.TYPE)
50 | val Z: PrimitiveValDesc = new PrimitiveValDesc("Z", java.lang.Boolean.TYPE)
51 | val C: PrimitiveValDesc = new PrimitiveValDesc("C", java.lang.Character.TYPE)
52 | val I: PrimitiveValDesc = new PrimitiveValDesc("I", java.lang.Integer.TYPE)
53 | val S: PrimitiveValDesc = new PrimitiveValDesc("S", java.lang.Short.TYPE)
54 | val D: PrimitiveValDesc = new PrimitiveValDesc("D", java.lang.Double.TYPE)
55 | val F: PrimitiveValDesc = new PrimitiveValDesc("F", java.lang.Float.TYPE)
56 | val J: PrimitiveValDesc = new PrimitiveValDesc("J", java.lang.Long.TYPE)
57 | val V: PrimitiveValDesc = new PrimitiveValDesc("V", java.lang.Void.TYPE)
58 | val Object: ClassOwner = c[AnyRef]
59 | val String: ClassOwner = c[String]
60 | def m(desc: JString): MethodDesc = new MethodDescImpl(desc)
61 | def m(jMethod: java.lang.reflect.Method): MethodDesc = m(Type.getMethodDescriptor(jMethod))
62 | def m(params: ValDesc*): MethodArgs = new MethodArgs(params)
63 | def c(className: JString): ClassOwner = new ClassOwner(className, None)
64 | def c(cls: Class[_]): ClassOwner = ClassOwner(cls)
65 | def c[T : ClassTag]: ClassOwner = ClassOwner.apply[T]
66 | def i(className: JString): InterfaceOwner = new InterfaceOwner(className, None)
67 | def i(cls: Class[_]): InterfaceOwner = InterfaceOwner(cls)
68 | def i[T : ClassTag]: InterfaceOwner = InterfaceOwner.apply[T]
69 | def o(cls: Class[_]): Owner = Owner(cls)
70 | def o[T : ClassTag]: Owner = Owner.apply[T]
71 | }
72 |
73 | sealed abstract class Owner extends ValDesc {
74 | def className: String
75 | def isInterface: Boolean
76 | final override def toString: String = className
77 | def desc: String = s"L$className;"
78 |
79 | def method(name: String, desc: MethodDesc): MethodRef = new MethodRef(this, name, desc)
80 | def field(name: String, desc: ValDesc): FieldRef = new FieldRef(this, name, desc)
81 | }
82 | object Owner {
83 | def apply[T](implicit ct: ClassTag[T]): Owner = apply(ct.runtimeClass)
84 | def apply(cls: Class[_]): Owner =
85 | if(cls.isInterface) InterfaceOwner(cls) else ClassOwner(cls)
86 | }
87 | class ClassOwner(val className: String, val jvmClass: Option[Class[_]]) extends Owner {
88 | def javaName = className.replace('/', '.')
89 | def isInterface = false
90 | def constr(desc: MethodDesc): ConstructorRef = new ConstructorRef(this, desc)
91 | }
92 | object ClassOwner {
93 | def apply[T](implicit ct: ClassTag[T]): ClassOwner = apply(ct.runtimeClass)
94 | def apply(cls: Class[_]): ClassOwner = {
95 | assert(!cls.isInterface)
96 | new ClassOwner(cls.getName.replace('.', '/'), Some(cls))
97 | }
98 | }
99 | class InterfaceOwner(val className: String, val jvmClass: Option[Class[_]]) extends Owner {
100 | def isInterface = true
101 | }
102 | object InterfaceOwner {
103 | def apply[T](implicit ct: ClassTag[T]): InterfaceOwner = apply(ct.runtimeClass)
104 | def apply(cls: Class[_]): InterfaceOwner = {
105 | assert(cls.isInterface)
106 | new InterfaceOwner(cls.getName.replace('.', '/'), Some(cls))
107 | }
108 | }
109 |
110 | case class MethodRef(owner: Owner, name: String, desc: MethodDesc) {
111 | def on(owner: Owner): MethodRef = new MethodRef(owner, name, desc)
112 | }
113 | object MethodRef {
114 | def apply(jMethod: java.lang.reflect.Method): MethodRef =
115 | Desc.o(jMethod.getDeclaringClass).method(jMethod.getName, Desc.m(jMethod))
116 | }
117 |
118 | case class ConstructorRef(tpe: ClassOwner, desc: MethodDesc)
119 |
120 | case class FieldRef(owner: Owner, name: String, desc: ValDesc)
121 |
--------------------------------------------------------------------------------
/src/main/scala/codegen/dsl/TypedDSL.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.codegen.dsl
2 |
3 | import org.objectweb.asm.Opcodes._
4 | import org.objectweb.asm.Label
5 |
6 | object TypedDSL {
7 | def apply(d: ValDesc, m: MethodDSL): TypedDSL = d.desc match {
8 | case "I" => new IntTypedDSL(m)
9 | case "J" => new LongTypedDSL(m)
10 | case "F" => new FloatTypedDSL(m)
11 | case "D" => new DoubleTypedDSL(m)
12 | case _ if d.isClass || d.isArray => new RefTypedDSL(m, d)
13 | }
14 | }
15 |
16 | abstract class TypedDSL(m: MethodDSL) {
17 | def desc: ValDesc
18 | def wide: Boolean
19 | def xload(v: VarIdx): this.type
20 | def xstore(v: VarIdx): this.type
21 | def xaload: this.type
22 | def xastore: this.type
23 | def xreturn: this.type
24 | def zero: this.type
25 | def xdup: this.type = { if(wide) m.dup2 else m.dup; this }
26 | def xpop: this.type = { if(wide) m.pop2 else m.pop; this }
27 | def if_== : ThenDSL
28 | def if_!= : ThenDSL
29 | }
30 |
31 | class RefTypedDSL(m: MethodDSL, val desc: ValDesc) extends TypedDSL(m) {
32 | def wide: Boolean = false
33 |
34 | def xload(v: VarIdx) : this.type = { m.aload(v); this }
35 | def xstore(v: VarIdx): this.type = { m.astore(v); this }
36 | def xaload : this.type = { m.aaload; this }
37 | def xastore : this.type = { m.aastore; this }
38 | def xreturn : this.type = { m.areturn; this }
39 | def zero : this.type = { m.aconst_null; this }
40 | def if_== : ThenDSL = { m.aconst_null; new ThenDSL(IF_ACMPEQ, IF_ACMPNE, m, new Label) }
41 | def if_!= : ThenDSL = { m.aconst_null; new ThenDSL(IF_ACMPNE, IF_ACMPEQ, m, new Label) }
42 | }
43 |
44 | abstract class NumericTypedDSL(m: MethodDSL) extends TypedDSL(m) {
45 | def desc: PrimitiveValDesc
46 |
47 | def xadd: this.type
48 | def xmul: this.type
49 | def xsub: this.type
50 | def xdiv: this.type
51 | def xneg: this.type
52 | def xrem: this.type
53 | def x2f: this.type
54 | def x2i: this.type
55 | def x2l: this.type
56 | def x2d: this.type
57 |
58 | protected[this] def xcmp: Unit
59 | protected[this] def if0(pos: Int, neg: Int): ThenDSL = {
60 | zero
61 | ifX(pos, neg)
62 | }
63 | protected[this] def ifX(pos: Int, neg: Int): ThenDSL = {
64 | xcmp
65 | new ThenDSL(pos, neg, m, new Label)
66 | }
67 |
68 | def if_== : ThenDSL = if0(IF_ICMPEQ, IF_ICMPNE)
69 | def if_!= : ThenDSL = if0(IF_ICMPNE, IF_ICMPEQ)
70 | def if_< : ThenDSL = if0(IF_ICMPLT, IF_ICMPGE)
71 | def if_> : ThenDSL = if0(IF_ICMPGT, IF_ICMPLE)
72 | def if_<= : ThenDSL = if0(IF_ICMPLE, IF_ICMPGT)
73 | def if_>= : ThenDSL = if0(IF_ICMPGE, IF_ICMPLT)
74 |
75 | def ifX_== : ThenDSL = ifX(IF_ICMPEQ, IF_ICMPNE)
76 | def ifX_!= : ThenDSL = ifX(IF_ICMPNE, IF_ICMPEQ)
77 | def ifX_< : ThenDSL = ifX(IF_ICMPLT, IF_ICMPGE)
78 | def ifX_> : ThenDSL = ifX(IF_ICMPGT, IF_ICMPLE)
79 | def ifX_<= : ThenDSL = ifX(IF_ICMPLE, IF_ICMPGT)
80 | def ifX_>= : ThenDSL = ifX(IF_ICMPGE, IF_ICMPLT)
81 | }
82 |
83 | class IntTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) {
84 | protected[this] def xcmp: Unit = ()
85 |
86 | def desc: PrimitiveValDesc = Desc.I
87 | def wide: Boolean = false
88 |
89 | def xload(v: VarIdx) : this.type = { m.iload(v); this }
90 | def xstore(v: VarIdx): this.type = { m.istore(v); this }
91 | def xaload : this.type = { m.iaload; this }
92 | def xastore : this.type = { m.iastore; this }
93 | def xreturn : this.type = { m.ireturn; this }
94 | def zero : this.type = { m.iconst(0); this }
95 | def xadd : this.type = { m.iadd; this }
96 | def xmul : this.type = { m.imul; this }
97 | def xsub : this.type = { m.isub; this }
98 | def xdiv : this.type = { m.idiv; this }
99 | def xneg : this.type = { m.ineg; this }
100 | def xrem : this.type = { m.irem; this }
101 | def x2f : this.type = { m.i2f; this }
102 | def x2i : this.type = this
103 | def x2l : this.type = { m.i2l; this }
104 | def x2d : this.type = { m.l2d; this }
105 |
106 | override def if_== : ThenDSL = new ThenDSL(IFEQ, IFNE, m, new Label)
107 | override def if_!= : ThenDSL = new ThenDSL(IFNE, IFEQ, m, new Label)
108 | override def if_< : ThenDSL = new ThenDSL(IFLT, IFGE, m, new Label)
109 | override def if_> : ThenDSL = new ThenDSL(IFGT, IFLE, m, new Label)
110 | override def if_<= : ThenDSL = new ThenDSL(IFLE, IFGT, m, new Label)
111 | override def if_>= : ThenDSL = new ThenDSL(IFGE, IFLT, m, new Label)
112 | }
113 |
114 | class LongTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) {
115 | protected[this] def xcmp: Unit = m.lcmp
116 |
117 | def desc: PrimitiveValDesc = Desc.J
118 | def wide: Boolean = true
119 |
120 | def xload(v: VarIdx) : this.type = { m.lload(v); this }
121 | def xstore(v: VarIdx): this.type = { m.lstore(v); this }
122 | def xaload : this.type = { m.laload; this }
123 | def xastore : this.type = { m.lastore; this }
124 | def xreturn : this.type = { m.lreturn; this }
125 | def zero : this.type = { m.lconst(0); this }
126 | def xadd : this.type = { m.ladd; this }
127 | def xmul : this.type = { m.lmul; this }
128 | def xsub : this.type = { m.lsub; this }
129 | def xdiv : this.type = { m.ldiv; this }
130 | def xneg : this.type = { m.lneg; this }
131 | def xrem : this.type = { m.lrem; this }
132 | def x2f : this.type = { m.l2f; this }
133 | def x2i : this.type = { m.l2i; this }
134 | def x2l : this.type = this
135 | def x2d : this.type = { m.l2d; this }
136 | }
137 |
138 | class FloatTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) {
139 | protected[this] def xcmp: Unit = m.fcmpl
140 |
141 | def desc: PrimitiveValDesc = Desc.F
142 | def wide: Boolean = false
143 |
144 | def xload(v: VarIdx) : this.type = { m.fload(v); this }
145 | def xstore(v: VarIdx): this.type = { m.fstore(v); this }
146 | def xaload : this.type = { m.faload; this }
147 | def xastore : this.type = { m.fastore; this }
148 | def xreturn : this.type = { m.freturn; this }
149 | def zero : this.type = { m.fconst(0); this }
150 | def xadd : this.type = { m.fadd; this }
151 | def xmul : this.type = { m.fmul; this }
152 | def xsub : this.type = { m.fsub; this }
153 | def xdiv : this.type = { m.fdiv; this }
154 | def xneg : this.type = { m.fneg; this }
155 | def xrem : this.type = { m.frem; this }
156 | def x2f : this.type = this
157 | def x2i : this.type = { m.f2i; this }
158 | def x2l : this.type = { m.f2l; this }
159 | def x2d : this.type = { m.f2d; this }
160 | }
161 |
162 | class DoubleTypedDSL(m: MethodDSL) extends NumericTypedDSL(m) {
163 | protected[this] def xcmp: Unit = m.dcmpl
164 |
165 | def desc: PrimitiveValDesc = Desc.D
166 | def wide: Boolean = true
167 |
168 | def xload(v: VarIdx) : this.type = { m.dload(v); this }
169 | def xstore(v: VarIdx): this.type = { m.dstore(v); this }
170 | def xaload : this.type = { m.daload; this }
171 | def xastore : this.type = { m.dastore; this }
172 | def xreturn : this.type = { m.dreturn; this }
173 | def zero : this.type = { m.dconst(0); this }
174 | def xadd : this.type = { m.dadd; this }
175 | def xmul : this.type = { m.dmul; this }
176 | def xsub : this.type = { m.dsub; this }
177 | def xdiv : this.type = { m.ddiv; this }
178 | def xneg : this.type = { m.dneg; this }
179 | def xrem : this.type = { m.drem; this }
180 | def x2f : this.type = { m.d2f; this }
181 | def x2i : this.type = { m.d2i; this }
182 | def x2l : this.type = { m.d2l; this }
183 | def x2d : this.type = this
184 | }
185 |
--------------------------------------------------------------------------------
/src/main/scala/mt/workers/Workers.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.mt.workers
2 |
3 | import java.lang.invoke.MethodHandles
4 | import scala.annotation.tailrec
5 |
6 | object Workers {
7 | private[workers] final val endMarker = new AnyRef
8 | }
9 |
10 | abstract class Worker[T >: Null] extends Thread {
11 | import Worker._
12 |
13 | private[workers] final var _workers: Workers[T] = _
14 | private[workers] final var _idx: Int = _
15 | private[workers] final var page: Page = _
16 |
17 | @inline final def workers = _workers
18 | @inline final def idx = _idx
19 |
20 | @volatile private[this] var _state: Int = 0
21 |
22 | def apply(v: T): Unit
23 |
24 | final override def run(): Unit = {
25 | while(true) {
26 | var item = tryTake()
27 | if(item == null) {
28 | maybeEmpty()
29 | item = take()
30 | }
31 | if(item.asInstanceOf[AnyRef] eq Workers.endMarker) return
32 | apply(item)
33 | }
34 | }
35 |
36 | def maybeEmpty(): Unit = ()
37 |
38 | @inline private def tryLock(): Boolean = STATE.weakCompareAndSetAcquire(this, 0, 1)
39 |
40 | @tailrec @inline private final def lock(): Boolean = {
41 | if(tryLock()) true
42 | else {
43 | Thread.onSpinWait()
44 | lock()
45 | }
46 | }
47 |
48 | @inline private final def unlock(): Unit = STATE.setRelease(this, 0)
49 |
50 | final def add(v: T): Unit = {
51 | //println(s"add($v)")
52 | lock()
53 | val p = page
54 | if(p.size < p.data.length) {
55 | p.data(p.size) = v.asInstanceOf[AnyRef]
56 | p.size += 1
57 | } else {
58 | val p2 = new Page(workers.pageSize)
59 | p2.next = p
60 | p2.data(0) = v.asInstanceOf[AnyRef]
61 | p2.size = 1
62 | page = p2
63 | }
64 | unlock()
65 | //println(s"add($v) - done")
66 | }
67 |
68 | private final def tryTake(): T = {
69 | if(tryLock()) {
70 | val v = page.takeOrNull()
71 | val ret = if(v != null) {
72 | v
73 | } else if(page.next != null) {
74 | page = page.next
75 | page.takeOrNull()
76 | } else {
77 | val p = tryStealFromOther(page, (idx+1) % workers.numThreads, workers.numThreads-1)
78 | if(p != null) {
79 | page = p
80 | p.takeOrNull()
81 | } else null
82 | }
83 | unlock()
84 | ret.asInstanceOf[T]
85 | } else null
86 | }
87 |
88 | @tailrec private final def take(): T = {
89 | val ret = tryTake()
90 | if(ret != null) ret
91 | else {
92 | Thread.onSpinWait()
93 | take()
94 | }
95 | }
96 |
97 | private final def trySteal(into: Page): Page = {
98 | if(!tryLock()) null
99 | else {
100 | val ret = if(page.next != null) {
101 | val p = page.next
102 | page.next = null
103 | p
104 | } else if(page.size == 0) null
105 | else if(page.size == 1) {
106 | val p = page
107 | page = into
108 | p
109 | } else {
110 | val stealCount = page.size/2
111 | page.size -= stealCount
112 | into.size = stealCount
113 | System.arraycopy(page.data, page.size, into.data, 0, stealCount)
114 | into
115 | }
116 | unlock()
117 | ret
118 | }
119 | }
120 |
121 | @tailrec private final def tryStealFromOther(into: Page, cur: Int, left: Int): Page = {
122 | val p = workers.workers(cur).trySteal(into)
123 | if(p != null) p
124 | else if(left > 0) tryStealFromOther(into, (cur+1) % workers.numThreads, left-1)
125 | else null
126 | }
127 | }
128 |
129 | object Worker {
130 | final val STATE =
131 | MethodHandles.privateLookupIn(classOf[Worker[_]], MethodHandles.lookup).findVarHandle(classOf[Worker[_]], "_state", classOf[Int])
132 | }
133 |
134 | class Workers[T >: Null](val numThreads: Int, val pageSize: Int, createWorker: Int => Worker[T]) {
135 | val workers = (0 until numThreads).iterator.map { i =>
136 | val w = createWorker(i)
137 | w.setDaemon(true)
138 | w.page = new Page(pageSize)
139 | w._workers = this
140 | w._idx = i
141 | w
142 | }.toArray
143 | private[this] var started = false
144 | private[this] var addIdx = 0
145 |
146 | def add(item: T): Unit = {
147 | workers(addIdx).add(item)
148 | addIdx = (addIdx+1) % numThreads
149 | }
150 |
151 | def start(): Unit = {
152 | if(!started) {
153 | workers.foreach(_.start)
154 | started = true
155 | }
156 | }
157 |
158 | def shutdown(): Unit = {
159 | (1 to workers.length).foreach { _ => add(Workers.endMarker.asInstanceOf[T]) }
160 | workers.foreach(_.join())
161 | }
162 | }
163 |
164 | final class Page(_pageSize: Int) {
165 | val data = new Array[AnyRef](_pageSize)
166 | var size: Int = 0
167 | var next: Page = _
168 |
169 | def takeOrNull(): AnyRef =
170 | if(size > 0) { size -= 1; data(size) } else null
171 | }
172 |
--------------------------------------------------------------------------------
/src/main/scala/offheap/Allocator.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.offheap
2 |
3 | import java.util.Arrays
4 |
5 | object Allocator {
6 | val UNSAFE = {
7 | val f = classOf[sun.misc.Unsafe].getDeclaredField("theUnsafe")
8 | f.setAccessible(true)
9 | f.get(null).asInstanceOf[sun.misc.Unsafe]
10 | }
11 |
12 | def proxyElemOffset = -4L
13 |
14 | val (objArrayOffset, objArrayScale) = {
15 | val cl = classOf[Array[AnyRef]]
16 | (UNSAFE.arrayBaseOffset(cl), UNSAFE.arrayIndexScale(cl))
17 | }
18 |
19 | val (objArrayArrayOffset, objArrayArrayScale) = {
20 | val cl = classOf[Array[Array[AnyRef]]]
21 | (UNSAFE.arrayBaseOffset(cl), UNSAFE.arrayIndexScale(cl))
22 | }
23 |
24 | // used by code generator:
25 | def putInt(address: Long, value: Int): Unit = UNSAFE.putInt(address, value)
26 | def getInt(address: Long): Int = UNSAFE.getInt(address)
27 | def putLong(address: Long, value: Long): Unit = UNSAFE.putLong(address, value)
28 | def getLong(address: Long): Long = UNSAFE.getLong(address)
29 | def getObject(o: AnyRef, offset: Long): AnyRef = UNSAFE.getObject(o, offset)
30 | def putObject(o: AnyRef, offset: Long, v: AnyRef): Unit = UNSAFE.putObject(o, offset, v)
31 | }
32 |
33 | abstract class Allocator {
34 | def dispose(): Unit
35 |
36 | def alloc(len: Long): Long
37 | def alloc8(): Long = alloc(8)
38 | def alloc16(): Long = alloc(16)
39 | def alloc24(): Long = alloc(24)
40 | def alloc32(): Long = alloc(32)
41 |
42 | def free(address: Long, len: Long): Unit
43 | def free8(o: Long): Unit = free(o, 8)
44 | def free16(o: Long): Unit = free(o, 16)
45 | def free24(o: Long): Unit = free(o, 24)
46 | def free32(o: Long): Unit = free(o, 32)
47 | }
48 |
49 | abstract class ProxyAllocator extends Allocator {
50 | def getProxyPage(o: Long): AnyRef
51 | def getProxy(o: Long): AnyRef
52 | def setProxy(o: Long, v: AnyRef): Unit
53 |
54 | def allocProxied(len: Long): Long
55 | def alloc8p(): Long = allocProxied(8)
56 | def alloc16p(): Long = allocProxied(16)
57 | def alloc24p(): Long = allocProxied(24)
58 | def alloc32p(): Long = allocProxied(32)
59 |
60 | def freeProxied(o: Long, len: Long): Unit
61 | def free8p(o: Long): Unit = freeProxied(o, 8)
62 | def free16p(o: Long): Unit = freeProxied(o, 16)
63 | def free24p(o: Long): Unit = freeProxied(o, 24)
64 | def free32p(o: Long): Unit = freeProxied(o, 32)
65 | }
66 |
67 | object SystemAllocator extends Allocator {
68 | import Allocator._
69 |
70 | def dispose(): Unit = ()
71 | def alloc(len: Long): Long = UNSAFE.allocateMemory(len)
72 | def free(address: Long, len: Long): Unit = UNSAFE.freeMemory(address)
73 | }
74 |
75 | final class ArenaAllocator(blockSize: Long = 1024L*1024L*8L) extends Allocator {
76 | import Allocator._
77 | private[this] var block, end, next = 0L
78 |
79 | def dispose(): Unit = {
80 | while(block != 0L) {
81 | val n = UNSAFE.getLong(block)
82 | UNSAFE.freeMemory(block)
83 | block = n
84 | }
85 | }
86 |
87 | def alloc(len: Long): Long = {
88 | if(next + len >= end) {
89 | allocBlock()
90 | assert(next + len < end)
91 | }
92 | val o = next
93 | next += len
94 | o
95 | }
96 |
97 | def free(address: Long, len: Long): Unit = ()
98 |
99 | private[this] def allocBlock(): Unit = {
100 | val b = UNSAFE.allocateMemory(blockSize)
101 | UNSAFE.putLong(b, block)
102 | block = b
103 | next = b + 8
104 | end = b + blockSize
105 | }
106 | }
107 |
108 | final class SliceAllocator(blockSize: Long = 1024L*64L, maxSliceSize: Int = 256, arenaSize: Long = 1024L*1024L*8L) extends ProxyAllocator {
109 | import Allocator._
110 | assert(blockSize % 8 == 0)
111 | assert(maxSliceSize % 8 == 0)
112 | assert(blockSize >= maxSliceSize)
113 |
114 | private[this] val blockAllocator = new ArenaAllocator(arenaSize)
115 | private[this] val slices: Array[Slice] = Array.tabulate(maxSliceSize >> 3)(i => new Slice((i+1) << 3))
116 | private[this] val proxySlices: Array[ProxySlice] = Array.tabulate(maxSliceSize >> 3)(i => new ProxySlice((i+1) << 3))
117 | private[this] var proxyPages: Array[Array[AnyRef]] = new Array[Array[AnyRef]](64)
118 | private[this] var proxyPagesLen: Int= 0
119 |
120 | private[this] val slice8: Slice = slices(8 >> 3)
121 | private[this] val slice16: Slice = slices(16 >> 3)
122 | private[this] val slice24: Slice = slices(24 >> 3)
123 | private[this] val slice32: Slice = slices(32 >> 3)
124 | private[this] val slice8p: ProxySlice = proxySlices(8 >> 3)
125 | private[this] val slice16p: ProxySlice = proxySlices(16 >> 3)
126 | private[this] val slice24p: ProxySlice = proxySlices(24 >> 3)
127 | private[this] val slice32p: ProxySlice = proxySlices(32 >> 3)
128 |
129 | override def alloc8(): Long = slice8.alloc()
130 | override def alloc16(): Long = slice16.alloc()
131 | override def alloc24(): Long = slice24.alloc()
132 | override def alloc32(): Long = slice32.alloc()
133 | override def free8(o: Long): Unit = slice8.free(o)
134 | override def free16(o: Long): Unit = slice16.free(o)
135 | override def free24(o: Long): Unit = slice24.free(o)
136 | override def free32(o: Long): Unit = slice32.free(o)
137 |
138 | override def alloc8p(): Long = slice8p.alloc()
139 | override def alloc16p(): Long = slice16p.alloc()
140 | override def alloc24p(): Long = slice24p.alloc()
141 | override def alloc32p(): Long = slice32p.alloc()
142 | override def free8p(o: Long): Unit = slice8p.free(o)
143 | override def free16p(o: Long): Unit = slice16p.free(o)
144 | override def free24p(o: Long): Unit = slice24p.free(o)
145 | override def free32p(o: Long): Unit = slice32p.free(o)
146 |
147 | def dispose(): Unit = blockAllocator.dispose()
148 | def alloc(len: Long): Long = slices((len >> 3).toInt).alloc()
149 | def free(o: Long, len: Long): Unit = slices((len >> 3).toInt).free(o)
150 | def allocProxied(len: Long): Long = proxySlices((len >> 3).toInt).alloc()
151 | def freeProxied(o: Long, len: Long): Unit = proxySlices((len >> 3).toInt).free(o)
152 |
153 | def getProxyPage(o: Long): AnyRef = {
154 | val off = UNSAFE.getInt(o-8)
155 | UNSAFE.getObject(proxyPages, off)
156 | }
157 |
158 | def getProxy(o: Long): AnyRef = {
159 | val pp = getProxyPage(o)
160 | val off = UNSAFE.getInt(o-4)
161 | UNSAFE.getObject(pp, off)
162 | }
163 |
164 | def setProxy(o: Long, v: AnyRef): Unit = {
165 | val pp = getProxyPage(o)
166 | val off = UNSAFE.getInt(o-4)
167 | UNSAFE.putObject(pp, off, v)
168 | }
169 |
170 | private[this] final class Slice(sliceSize: Int) {
171 | private[this] val allocSize = ((blockSize / sliceSize) * sliceSize) + 8
172 | private[this] var block, last, next, freeSlice = 0L
173 |
174 | private[this] def allocBlock(): Unit = {
175 | val b = blockAllocator.alloc(allocSize)
176 | UNSAFE.putLong(b, block)
177 | block = b
178 | next = b + 8
179 | last = b + allocSize - sliceSize
180 | }
181 |
182 | def alloc(): Long = {
183 | if(freeSlice != 0L) {
184 | val o = freeSlice
185 | freeSlice = UNSAFE.getLong(o)
186 | o
187 | } else {
188 | if(next >= last) allocBlock()
189 | val o = next
190 | next += sliceSize
191 | o
192 | }
193 | }
194 |
195 | def free(o: Long): Unit = {
196 | UNSAFE.putLong(o, freeSlice)
197 | freeSlice = o
198 | }
199 | }
200 |
201 | private[this] final class ProxySlice(_sliceSize: Int) {
202 | private[this] val sliceAllocSize = _sliceSize + 8
203 | private[this] val numBlocks = (blockSize / sliceAllocSize).toInt
204 | private[this] val allocSize = (numBlocks * sliceAllocSize) + 8
205 | private[this] var block, last, next, freeSlice = 0L
206 |
207 | private[this] def allocBlock(): Unit = {
208 | val b = blockAllocator.alloc(allocSize)
209 | UNSAFE.putLong(b, block)
210 | block = b
211 | next = b + 8
212 | last = b + allocSize - sliceAllocSize
213 | proxyPagesLen += 1
214 | if(proxyPagesLen == proxyPages.length)
215 | proxyPages = Arrays.copyOf(proxyPages, proxyPages.length * 2)
216 | proxyPages(proxyPagesLen-1) = new Array[AnyRef](numBlocks)
217 | }
218 |
219 | def alloc(): Long = {
220 | if(freeSlice != 0L) {
221 | val o = freeSlice
222 | freeSlice = UNSAFE.getLong(o)
223 | o
224 | } else {
225 | if(next >= last) allocBlock()
226 | val p = proxyPagesLen-1
227 | UNSAFE.putInt(next, objArrayArrayOffset + p*objArrayArrayScale)
228 | val i = (next-block-8).toInt/sliceAllocSize
229 | UNSAFE.putInt(next + 4, objArrayOffset + i*objArrayScale)
230 | val o = next + 8
231 | next += sliceAllocSize
232 | o
233 | }
234 | }
235 |
236 | def free(o: Long): Unit = {
237 | setProxy(o, null)
238 | UNSAFE.putLong(o, freeSlice)
239 | freeSlice = o
240 | }
241 | }
242 | }
243 |
--------------------------------------------------------------------------------
/src/main/scala/offheap/MemoryDebugger.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.offheap
2 |
3 | import scala.collection.mutable
4 |
5 | object MemoryDebugger extends ProxyAllocator {
6 | private[this] var parent: ProxyAllocator = _
7 |
8 | def setParent(p: ProxyAllocator): Unit = {
9 | objects.clear()
10 | parent = p
11 | }
12 |
13 | private[this] final class Obj(val address: Long, val length: Long, val proxied: Boolean) {
14 | override def toString = s"[$address..${address+length}["
15 | }
16 |
17 | private[this] val objects = mutable.TreeMap.empty[Long, Obj](implicitly[Ordering[Long]].reverse)
18 |
19 | private[this] def find(o: Long, accessLen: Int): Obj = {
20 | val it = objects.valuesIteratorFrom(o)
21 | if(!it.hasNext) throw new AssertionError(s"Address $o is before all allocated objects")
22 | val v = it.next()
23 | if(o + accessLen > v.address + v.length)
24 | throw new AssertionError(s"Address $o with access length $accessLen is outside of object $v")
25 | v
26 | }
27 |
28 | def dispose(): Unit = {
29 | parent.dispose()
30 | objects.clear()
31 | }
32 |
33 | def alloc(len: Long): Long = {
34 | val o = parent.alloc(len)
35 | objects.put(o, new Obj(o, len, false))
36 | if(o % 8 != 0) throw new AssertionError(s"alloc($len) returned non-aligned address $o")
37 | o
38 | }
39 |
40 | def free(address: Long, len: Long): Unit = {
41 | val obj = find(address, 0)
42 | if(obj.address != address) throw new AssertionError(s"free($address, $len) not called with base address of $obj")
43 | if(obj.proxied) throw new AssertionError(s"free($address, $len) called on proxied object")
44 | parent.free(address, len)
45 | objects.remove(address)
46 | }
47 |
48 | def allocProxied(len: Long): Long = {
49 | val o = parent.allocProxied(len)
50 | objects.put(o, new Obj(o, len, true))
51 | if(o % 8 != 0) throw new AssertionError(s"allocProxied($len) returned non-aligned address $o")
52 | o
53 | }
54 |
55 | def freeProxied(address: Long, len: Long): Unit = {
56 | val obj = find(address, 0)
57 | if(obj.address != address) throw new AssertionError(s"freeProxied($address, $len) not called with base address of $obj")
58 | if(!obj.proxied) throw new AssertionError(s"freeProxied($address, $len) called on non-proxied object")
59 | parent.freeProxied(address, len)
60 | objects.remove(address)
61 | }
62 |
63 | def getProxyPage(address: Long): AnyRef = {
64 | val obj = find(address, 0)
65 | if(obj.address != address) throw new AssertionError(s"getProxyPage($address) not called with base address of $obj")
66 | if(!obj.proxied) throw new AssertionError(s"getProxyPage($address) called on non-proxied object")
67 | parent.getProxyPage(address)
68 | }
69 |
70 | def getProxy(address: Long): AnyRef = {
71 | val obj = find(address, 0)
72 | if(obj.address != address) throw new AssertionError(s"getProxy($address) not called with base address of $obj")
73 | if(!obj.proxied) throw new AssertionError(s"getProxy($address) called on non-proxied object")
74 | parent.getProxy(address)
75 | }
76 |
77 | def setProxy(address: Long, value: AnyRef): Unit = {
78 | val obj = find(address, 0)
79 | if(obj.address != address) throw new AssertionError(s"setProxy($address, $obj) not called with base address of $obj")
80 | if(!obj.proxied) throw new AssertionError(s"setProxy($address, $obj) called on non-proxied object")
81 | parent.setProxy(address, value)
82 | }
83 |
84 | def putInt(address: Long, value: Int): Unit = {
85 | val obj = find(address, 4)
86 | if(obj.proxied && address == obj.address + 4) throw new AssertionError(s"Overwriting 32-bit payload of proxied object")
87 | Allocator.putInt(address, value)
88 | }
89 |
90 | def getInt(address: Long): Int = {
91 | val obj = find(address, 4)
92 | Allocator.getInt(address)
93 | }
94 |
95 | def putLong(address: Long, value: Long): Unit = {
96 | val obj = find(address, 8)
97 | if(obj.proxied && address == obj.address) throw new AssertionError(s"Overwriting 32-bit payload of proxied object")
98 | Allocator.putLong(address, value)
99 | }
100 |
101 | def getLong(address: Long): Long = {
102 | val obj = find(address, 8)
103 | Allocator.getLong(address)
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/src/main/scala/stc1/Interpreter.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.stc1
2 |
3 | import de.szeiger.interact.codegen.{ClassWriter, LocalClassLoader}
4 | import de.szeiger.interact._
5 | import de.szeiger.interact.ast.{CompilationUnit, Symbol, Symbols}
6 |
7 | import java.util.Arrays
8 | import scala.collection.mutable
9 |
10 | abstract class Cell {
11 | def arity: Int
12 | def auxCell(p: Int): Cell
13 | def auxPort(p: Int): Int
14 | def setAux(p: Int, c2: Cell, p2: Int): Unit
15 | def reduce(c: Cell, ptw: Interpreter): Unit
16 | def cellSymbol: Symbol
17 | }
18 |
19 | final class DynamicCell(val cellSymbol: Symbol, val arity: Int) extends Cell {
20 | private[this] final val auxCells = new Array[Cell](arity)
21 | private[this] final val auxPorts = new Array[Int](arity)
22 | def auxCell(p: Int): Cell = auxCells(p)
23 | def auxPort(p: Int): Int = auxPorts(p)
24 | def setAux(p: Int, c2: Cell, p2: Int): Unit = { auxCells(p) = c2; auxPorts(p) = p2 }
25 | def reduce(c: Cell, ptw: Interpreter): Unit = ()
26 | }
27 |
28 | abstract class InitialRuleImpl {
29 | def reduce(a0: Cell, a1: Cell, ptw: Interpreter): Unit
30 | def freeWires: Array[Symbol]
31 | }
32 |
33 | final class Interpreter(globals: Symbols, compilationUnit: CompilationUnit, config: Config) extends BaseInterpreter { self =>
34 | private[this] val initialRuleImpls: Vector[InitialRuleImpl] = {
35 | val lcl = new LocalClassLoader
36 | val cw = ClassWriter(config, lcl)
37 | val initial = new CodeGen("generated", cw, config, compilationUnit, globals).compile()
38 | cw.close()
39 | initial.map { cln => lcl.loadClass(cln).getDeclaredConstructor().newInstance().asInstanceOf[InitialRuleImpl] }
40 | }
41 | private[this] val cutBuffer, irreducible = new CutBuffer(16)
42 | private[this] val freeWires = mutable.HashSet.empty[Cell]
43 | private[this] var metrics: ExecutionMetrics = _
44 | private[this] var active0, active1: Cell = _
45 |
46 | def getMetrics: ExecutionMetrics = metrics
47 |
48 | def getAnalyzer: Analyzer[Cell] = new Analyzer[Cell] {
49 | val principals = (cutBuffer.iterator ++ irreducible.iterator).flatMap { case (c1, c2) => Seq((c1, c2), (c2, c1)) }.toMap
50 | def irreduciblePairs: IterableOnce[(Cell, Cell)] = irreducible.iterator
51 | def rootCells = (self.freeWires.iterator ++ principals.keysIterator).toSet
52 | def getSymbol(c: Cell): Symbol = Option(c.cellSymbol).getOrElse(Symbol.NoSymbol)
53 | def getConnected(c: Cell, port: Int): (Cell, Int) =
54 | if(port == -1) principals.get(c).map((_, -1)).orNull else (c.auxCell(port), c.auxPort(port))
55 | def isFreeWire(c: Cell): Boolean = c.isInstanceOf[DynamicCell] && c.cellSymbol.isDefined
56 | }
57 |
58 | def initData(): Unit = {
59 | cutBuffer.clear()
60 | irreducible.clear()
61 | freeWires.clear()
62 | if(config.collectStats) metrics = new ExecutionMetrics
63 | initialRuleImpls.foreach { rule =>
64 | val free = rule.freeWires.map(new DynamicCell(_, 1))
65 | freeWires.addAll(free)
66 | val lhs = new DynamicCell(Symbol.NoSymbol, freeWires.size)
67 | free.iterator.zipWithIndex.foreach { case (c, p) => lhs.setAux(p, c, 0) }
68 | rule.reduce(lhs, new DynamicCell(Symbol.NoSymbol, 0), this)
69 | }
70 | if(config.collectStats) metrics = new ExecutionMetrics
71 | }
72 |
73 | def dispose(): Unit = ()
74 |
75 | def reduce(): Unit =
76 | while(true) {
77 | while(active0 != null) {
78 | val a0 = active0
79 | active0 = null
80 | a0.reduce(active1, this)
81 | }
82 | if(cutBuffer.isEmpty) return
83 | val (a0, a1) = cutBuffer.pop()
84 | a0.reduce(a1, this)
85 | }
86 |
87 | // ptw methods:
88 |
89 | def addActive(a0: Cell, a1: Cell): Unit =
90 | if(active0 == null) { active0 = a0; active1 = a1 } else cutBuffer.addOne(a0, a1)
91 |
92 | def addIrreducible(a0: Cell, a1: Cell): Unit = irreducible.addOne(a0, a1)
93 |
94 | def recordStats(steps: Int, cellAllocations: Int, cachedCellReuse: Int, singletonUse: Int, loopSave: Int, labelCreate: Int): Unit =
95 | metrics.recordStats(steps, cellAllocations, 0, cachedCellReuse, singletonUse, 0, loopSave, 0, 0, labelCreate)
96 |
97 | def recordMetric(metric: String, inc: Int): Unit = metrics.recordMetric(metric, inc)
98 | }
99 |
100 |
101 | final class CutBuffer(initialSize: Int) {
102 | private[this] var pairs = new Array[Cell](initialSize*2)
103 | private[this] var len = 0
104 | @inline def addOne(c1: Cell, c2: Cell): Unit = {
105 | if(len == pairs.length)
106 | pairs = Arrays.copyOf(pairs, pairs.length*2)
107 | pairs(len) = c1
108 | pairs(len+1) = c2
109 | len += 2
110 | }
111 | @inline def isEmpty: Boolean = len == 0
112 | @inline def pop(): (Cell, Cell) = {
113 | len -= 2
114 | val c1 = pairs(len)
115 | val c2 = pairs(len+1)
116 | pairs(len) = null
117 | pairs(len+1) = null
118 | (c1, c2)
119 | }
120 | def clear(): Unit =
121 | if(len > 0) {
122 | pairs = new Array[Cell](initialSize * 2)
123 | len = 0
124 | }
125 | def iterator: Iterator[(Cell, Cell)] = pairs.iterator.take(len).grouped(2).map { case Seq(c1, c2) => (c1, c2) }
126 | }
127 |
--------------------------------------------------------------------------------
/src/main/scala/stc1/PTOps.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.stc1
2 |
3 | import de.szeiger.interact.ast.PayloadType
4 | import de.szeiger.interact.codegen.{BoxDesc, BoxOps}
5 | import de.szeiger.interact.codegen.dsl.{Desc => tp, _}
6 |
7 | object PTOps {
8 | def boxDesc(pt: PayloadType): BoxDesc = pt match {
9 | case PayloadType.INT => BoxDesc.intDesc
10 | case PayloadType.REF => BoxDesc.refDesc
11 | case PayloadType.LABEL => BoxDesc.refDesc
12 | }
13 | def apply(m: MethodDSL, pt: PayloadType): BoxOps = new BoxOps(m, boxDesc(pt))
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/scala/stc2/Interpreter.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.stc2
2 |
3 | import de.szeiger.interact.codegen.{ClassWriter, LocalClassLoader}
4 | import de.szeiger.interact._
5 | import de.szeiger.interact.ast.{CompilationUnit, PayloadType, Symbol, Symbols}
6 | import de.szeiger.interact.offheap.{Allocator, MemoryDebugger, ProxyAllocator}
7 |
8 | import java.util.Arrays
9 | import scala.collection.mutable
10 |
11 | object Defs {
12 | type Cell = Long
13 | }
14 | import Defs._
15 |
16 | abstract class InitialRuleImpl {
17 | def reduce(a0: Cell, a1: Cell, ptw: Interpreter): Unit
18 | def freeWires: Array[Symbol]
19 | }
20 |
21 | abstract class Dispatch {
22 | def reduce(c1: Cell, c2: Cell, level: Int, ptw: Interpreter): Unit
23 | }
24 |
25 | final class Interpreter(globals: Symbols, compilationUnit: CompilationUnit, config: Config) extends BaseInterpreter { self =>
26 | import Interpreter._
27 |
28 | private[this] val symIds = globals.symbols.filter(_.isCons).toVector.sortBy(_.id).iterator.zipWithIndex.toMap
29 | private[this] val reverseSymIds = symIds.map { case (sym, idx) => (idx, sym) }
30 | private[this] val (initialRuleImpls: Vector[InitialRuleImpl], dispatch: Dispatch) = {
31 | val lcl = new LocalClassLoader
32 | val cw = ClassWriter(config, lcl)
33 | val (initial, dispN) = new CodeGen("generated", cw, compilationUnit, globals, symIds, config).compile()
34 | cw.close()
35 | val irs = initial.map { cln => lcl.loadClass(cln).getDeclaredConstructor().newInstance().asInstanceOf[InitialRuleImpl] }
36 | val disp = lcl.loadClass(dispN).getDeclaredConstructor().newInstance().asInstanceOf[Dispatch]
37 | (irs, disp)
38 | }
39 | private[this] val cutBuffer, irreducible = new CutBuffer(16)
40 | private[this] val freeWires = mutable.HashSet.empty[Cell]
41 | private[this] val freeWireLookup = mutable.HashMap.empty[Int, Symbol]
42 | private[this] var metrics: ExecutionMetrics = _
43 | private[this] var allocator: ProxyAllocator = _
44 | private[this] val singletons: Array[Cell] = new Array(symIds.size)
45 | private[this] var nextLabel = Long.MinValue
46 | private[this] val tailCallDepth = config.tailCallDepth
47 |
48 | def getMetrics: ExecutionMetrics = metrics
49 |
50 | def getAnalyzer: Analyzer[Cell] = new Analyzer[Cell] {
51 | val principals = (cutBuffer.iterator ++ irreducible.iterator).flatMap { case (c1, c2) => Seq((c1, c2), (c2, c1)) }.toMap
52 | def irreduciblePairs: IterableOnce[(Cell, Cell)] = irreducible.iterator
53 | def rootCells = (self.freeWires.iterator ++ principals.keysIterator).toSet
54 | def getSymbol(c: Cell): Symbol = {
55 | val sid = getSymId(c)
56 | reverseSymIds.getOrElse(sid, freeWireLookup.getOrElse(getSymId(c), new Symbol(s"")))
57 | }
58 | def getConnected(c: Cell, port: Int): (Cell, Int) =
59 | if(port == -1) principals.get(c).map((_, -1)).orNull else findCellAndPort(c, port)
60 | def isFreeWire(c: Cell): Boolean = freeWireLookup.contains(getSymId(c))
61 | override def getPayload(ptr: Long): Any = {
62 | val sym = getSymbol(ptr)
63 | if((ptr & TAGMASK) == TAG_UNBOXED) {
64 | sym.payloadType match {
65 | case PayloadType.INT => (ptr >>> 32).toInt
66 | }
67 | } else {
68 | val address = ptr + payloadOffset(sym.arity, sym.payloadType)
69 | sym.payloadType match {
70 | case PayloadType.INT => Allocator.getInt(address)
71 | case PayloadType.LABEL => "label@" + Allocator.getLong(address)
72 | case PayloadType.REF => getProxy(ptr)
73 | }
74 | }
75 | }
76 | }
77 |
78 | def dispose(): Unit = {
79 | if(allocator != null) {
80 | allocator.dispose()
81 | allocator = null
82 | if(config.debugMemory) MemoryDebugger.setParent(null)
83 | }
84 | }
85 |
86 | def initData(): Unit = {
87 | cutBuffer.clear()
88 | irreducible.clear()
89 | freeWires.clear()
90 | freeWireLookup.clear()
91 | //dispose()
92 | nextLabel = Long.MinValue
93 | if(allocator == null) {
94 | allocator = config.newAllocator()
95 | if(config.debugMemory) {
96 | MemoryDebugger.setParent(allocator)
97 | allocator = MemoryDebugger
98 | }
99 | singletons.indices.foreach { i =>
100 | val s = reverseSymIds(i)
101 | if(s.isSingleton) singletons(i) = newCell(i, s.arity)
102 | }
103 | }
104 | if(config.collectStats) metrics = new ExecutionMetrics
105 | initialRuleImpls.foreach { rule =>
106 | val fws = rule.freeWires
107 | val off = reverseSymIds.size + freeWireLookup.size
108 | val lhs = newCell(-1, fws.length)
109 | fws.iterator.zipWithIndex.foreach { case (s, i) =>
110 | freeWireLookup += ((i+off, s))
111 | val c = newCell(i+off, 1)
112 | freeWires += c
113 | setAux(lhs, i, c, 0)
114 | }
115 | val rhs = newCell(-1, 0)
116 | rule.reduce(lhs, rhs, this)
117 | }
118 | if(config.collectStats) metrics = new ExecutionMetrics
119 | }
120 |
121 | def reduce(): Unit =
122 | while(!cutBuffer.isEmpty) {
123 | val (a0, a1) = cutBuffer.pop()
124 | dispatch.reduce(a0, a1, tailCallDepth, this)
125 | }
126 |
127 | private final def newCell(symId: Int, arity: Int, pt: PayloadType = PayloadType.VOID): Long = {
128 | val o = allocator.alloc(cellSize(arity, pt))
129 | Allocator.putInt(o, mkHeader(symId))
130 | o
131 | }
132 |
133 | // ptw methods:
134 |
135 | def addActive(a0: Cell, a1: Cell): Unit = cutBuffer.addOne(a0, a1)
136 |
137 | def addIrreducible(a0: Cell, a1: Cell): Unit = irreducible.addOne(a0, a1)
138 |
139 | def recordStats(steps: Int, cellAllocations: Int, proxyAllocations: Int, cachedCellReuse: Int, singletonUse: Int,
140 | unboxedCells: Int, loopSave: Int, directTail: Int, singleDispatchTail: Int, labelCreate: Int): Unit =
141 | metrics.recordStats(steps, cellAllocations, proxyAllocations, cachedCellReuse, singletonUse, unboxedCells,
142 | loopSave, directTail, singleDispatchTail, labelCreate)
143 |
144 | def recordMetric(metric: String, inc: Int): Unit = metrics.recordMetric(metric, inc)
145 |
146 | def getSingleton(symId: Int): Cell = singletons(symId)
147 |
148 | def allocCell(length: Int): Cell = allocator.alloc(length)
149 | def freeCell(address: Cell, length: Int): Unit = allocator.free(address, length)
150 | def allocProxied(length: Int): Cell = allocator.allocProxied(length)
151 | def freeProxied(address: Cell, length: Int): Unit = allocator.freeProxied(address, length)
152 |
153 | def alloc8(): Long = allocator.alloc8()
154 | def alloc16(): Long = allocator.alloc16()
155 | def alloc24(): Long = allocator.alloc24()
156 | def alloc32(): Long = allocator.alloc32()
157 | def free8(o: Long): Unit = allocator.free8(o)
158 | def free16(o: Long): Unit = allocator.free16(o)
159 | def free24(o: Long): Unit = allocator.free24(o)
160 | def free32(o: Long): Unit = allocator.free32(o)
161 |
162 | def alloc8p(): Long = allocator.alloc8p()
163 | def alloc16p(): Long = allocator.alloc16p()
164 | def alloc24p(): Long = allocator.alloc24p()
165 | def alloc32p(): Long = allocator.alloc32p()
166 | def free8p(o: Long): Unit = allocator.free8p(o)
167 | def free16p(o: Long): Unit = allocator.free16p(o)
168 | def free24p(o: Long): Unit = allocator.free24p(o)
169 | def free32p(o: Long): Unit = allocator.free32p(o)
170 |
171 | def getProxyPage(o: Long): AnyRef = allocator.getProxyPage(o)
172 | def getProxy(o: Long): AnyRef = allocator.getProxy(o)
173 | def setProxy(o: Long, v: AnyRef): Unit = allocator.setProxy(o, v)
174 |
175 | def newLabel: Long = {
176 | val r = nextLabel
177 | nextLabel += 1
178 | r
179 | }
180 | }
181 |
182 | object Interpreter {
183 | // Cell layout:
184 | // 64-bit header
185 | // n * 64-bit pointers for arity n
186 | // optional 64-bit payload depending on PayloadType
187 | //
188 | // Header layout:
189 | // LSB ... MSB
190 | // 012 3456789abcdef 0123456789abcdef 0123456789abcdef 0123456789abcdef
191 | // --------------------------------------------------------------------
192 | // 110 [29-bit symId ] [ padding or 32-bit payload ]
193 | //
194 | // Pointer layouts:
195 | // LSB ... MSB
196 | // 012 3456789abcdef 0123456789abcdef 0123456789abcdef 0123456789abcdef
197 | // --------------------------------------------------------------------
198 | // 000 [ 64-bit aligned address >> 3: pointer to cell (principal) ]
199 | // 100 [ 64-bit aligned address >> 3: pointer to aux port ]
200 | // 010 [ 29-bit symId ] [ unboxed 32-bit payload ]
201 |
202 | final val TAGWIDTH = 3
203 | final val TAGMASK = 7L
204 | final val ADDRMASK = -8L
205 | final val TAG_HEADER = 3
206 | final val TAG_PRINC_PTR = 0
207 | final val TAG_AUX_PTR = 1
208 | final val TAG_UNBOXED = 2
209 |
210 | final val SYMIDMASK = ((1L << 29)-1) << TAGWIDTH
211 |
212 | def showPtr(l: Long, symIds: Map[Int, Symbol] = null): String = {
213 | var raw = l.toBinaryString
214 | raw = "0"*(64-raw.length) + raw
215 | raw = raw.substring(0, 32) + ":" + raw.substring(32, 61) + ":" + raw.substring(61)
216 | def symStr(sid: Int) =
217 | if(symIds == null) s"$sid"
218 | else s"$sid(${symIds.getOrElse(sid, Symbol.NoSymbol).id})"
219 | val decoded = (l & TAGMASK) match {
220 | case _ if l == 0L => "NULL"
221 | case TAG_HEADER =>
222 | val sid = ((l & SYMIDMASK) >>> TAGWIDTH).toInt
223 | s"HEADER:${symStr(sid)}:${l >>> 32}"
224 | case TAG_AUX_PTR => s"AUX_PTR:${l & ADDRMASK}"
225 | case TAG_PRINC_PTR => s"PRINC_PTR:${l & ADDRMASK}"
226 | case TAG_UNBOXED =>
227 | val sid = ((l & SYMIDMASK) >>> TAGWIDTH).toInt
228 | s"UNBOXED:${symStr(sid)}:${l >>> 32}"
229 | case tag => s"Invalid-Tag-$tag:$l"
230 | }
231 | s"$raw::$decoded"
232 | }
233 |
234 | def cellSize(arity: Int, pt: PayloadType) = arity*8 + 8 + (if(pt == PayloadType.LABEL) 8 else 0)
235 | def auxPtrOffset(p: Int): Int = 8 + (p * 8)
236 | def payloadOffset(arity: Int, pt: PayloadType): Int = if(pt == PayloadType.LABEL) 8 + (arity * 8) else 4
237 | def mkHeader(sid: Int): Int = (sid << TAGWIDTH) | TAG_HEADER
238 | def mkUnboxed(sid: Int): Int = (sid << TAGWIDTH) | TAG_UNBOXED
239 |
240 | private def setAux(c: Long, p: Int, c2: Long, p2: Int): Unit = {
241 | var l = c2 + auxPtrOffset(p2)
242 | if(p2 >= 0) l |= TAG_AUX_PTR
243 | Allocator.putLong(c + auxPtrOffset(p), l)
244 | }
245 |
246 | private def getSymId(ptr: Long) =
247 | if((ptr & TAGMASK) == TAG_UNBOXED) ((ptr & SYMIDMASK) >>> TAGWIDTH).toInt
248 | else Allocator.getInt(ptr) >> TAGWIDTH
249 |
250 | private def findCellAndPort(cellAddress: Long, cellPort: Int): (Long, Int) = {
251 | var ptr = Allocator.getLong(cellAddress + auxPtrOffset(cellPort))
252 | if((ptr & TAGMASK) != TAG_AUX_PTR) {
253 | (ptr, -1)
254 | } else {
255 | ptr = ptr & ADDRMASK
256 | var p = -1
257 | while((Allocator.getInt(ptr - auxPtrOffset(p)) & TAGMASK) != TAG_HEADER)
258 | p += 1
259 | (ptr - auxPtrOffset(p), p)
260 | }
261 | }
262 |
263 | def canUnbox(sym: Symbol, arity: Int): Boolean =
264 | arity == 0 && (sym.payloadType == PayloadType.INT || sym.payloadType == PayloadType.VOID)
265 | }
266 |
267 | final class CutBuffer(initialSize: Int) {
268 | private[this] var pairs = new Array[Cell](initialSize*2)
269 | private[this] var len = 0
270 | @inline def addOne(c1: Cell, c2: Cell): Unit = {
271 | if(len == pairs.length)
272 | pairs = Arrays.copyOf(pairs, pairs.length*2)
273 | pairs(len) = c1
274 | pairs(len+1) = c2
275 | len += 2
276 | }
277 | @inline def isEmpty: Boolean = len == 0
278 | @inline def pop(): (Cell, Cell) = {
279 | len -= 2
280 | val c1 = pairs(len)
281 | val c2 = pairs(len+1)
282 | (c1, c2)
283 | }
284 | def clear(): Unit =
285 | if(len > 0) {
286 | pairs = new Array[Cell](initialSize * 2)
287 | len = 0
288 | }
289 | def iterator: Iterator[(Cell, Cell)] = pairs.iterator.take(len).grouped(2).map { case Seq(c1, c2) => (c1, c2) }
290 | }
291 |
--------------------------------------------------------------------------------
/src/main/scala/stc2/PTOps.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.stc2
2 |
3 | import de.szeiger.interact.ast.PayloadType
4 | import de.szeiger.interact.codegen.{BoxDesc, BoxOps}
5 | import de.szeiger.interact.codegen.dsl.{Desc => tp, _}
6 | import de.szeiger.interact.stc2.PTOps.boxDesc
7 |
8 | class PTOps(m: MethodDSL, pt: PayloadType, codeGen: CodeGen) extends BoxOps(m, PTOps.boxDesc(pt)) {
9 | import codeGen._
10 |
11 | def isVoid: Boolean = boxDesc.boxT == null
12 |
13 | def loadConst0: this.type = {
14 | pt match {
15 | case PayloadType.INT =>
16 | m.iconst(0)
17 | }
18 | this
19 | }
20 |
21 | def setCellPayload(ptw: VarIdx, arity: Int)(loadCell: => Unit)(loadUnboxedPayload: => Unit): this.type = {
22 | pt match {
23 | case PayloadType.LABEL =>
24 | loadCell
25 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd
26 | loadUnboxedPayload
27 | m.invokestatic(allocator_putLong)
28 | case PayloadType.INT =>
29 | loadCell
30 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd
31 | loadUnboxedPayload
32 | m.invokestatic(allocator_putInt)
33 | case PayloadType.REF =>
34 | m.aload(ptw)
35 | loadCell
36 | loadUnboxedPayload
37 | m.invoke(ptw_setProxy)
38 | }
39 | this
40 | }
41 |
42 | def getCellPayload(ptw: VarIdx, arity: Int)(loadCell: => Unit): this.type = {
43 | pt match {
44 | case PayloadType.REF =>
45 | m.aload(ptw)
46 | loadCell
47 | m.invoke(ptw_getProxy)
48 | case PayloadType.INT =>
49 | loadCell
50 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd
51 | m.invokestatic(allocator_getInt)
52 | case PayloadType.LABEL =>
53 | loadCell
54 | m.lconst(Interpreter.payloadOffset(arity, pt)).ladd
55 | m.invokestatic(allocator_getLong)
56 | }
57 | this
58 | }
59 |
60 | def untag(loadValue: => Unit): this.type = {
61 | pt match {
62 | case PayloadType.INT =>
63 | loadValue
64 | m.iconst(32).lushr.l2i
65 | }
66 | this
67 | }
68 |
69 | def tag(symId: Int)(loadValue: => Unit): this.type = {
70 | loadValue
71 | m.i2l.iconst(32).lshl
72 | m.lconst((symId.toLong << Interpreter.TAGWIDTH) | Interpreter.TAG_UNBOXED).lor
73 | this
74 | }
75 | }
76 |
77 | object PTOps {
78 | def boxDesc(pt: PayloadType): BoxDesc = pt match {
79 | case PayloadType.INT => BoxDesc.intDesc
80 | case PayloadType.REF => BoxDesc.refDesc
81 | case PayloadType.LABEL => BoxDesc.longDesc
82 | case PayloadType.VOID => BoxDesc.voidDesc
83 | }
84 | def apply(m: MethodDSL, pt: PayloadType)(implicit codeGen: CodeGen): PTOps = new PTOps(m, pt, codeGen)
85 | }
86 |
--------------------------------------------------------------------------------
/src/test/resources/ack.check:
--------------------------------------------------------------------------------
1 | res = 2045n
2 | res2 = 2045n
3 | resB = BoxedInt[2045]
4 | resI = Int[2045]
5 |
--------------------------------------------------------------------------------
/src/test/resources/ack.in:
--------------------------------------------------------------------------------
1 | cons Z
2 | cons S(x)
3 |
4 | # Direct encoding as in https://github.com/inpla/inpla/blob/main/sample/AckSZ-3_5.in
5 | # 4182049 reductions (with pre-reduced rhs)
6 |
7 | def ack(a, b) = r
8 | | Z, y => S(y)
9 | | S(x), Z => ack(x, S(Z))
10 | | S(x), S(y) => (x1, x2) = dup(x); ack(x1, ack(S(x2),y))
11 |
12 | #def ack(_, y) = r
13 | # | Z => S(y)
14 | # | S(x) => ack_Sx(y, x)
15 | #
16 | #def ack_Sx(_, x) = r
17 | # | Z => ack(x, S(Z))
18 | # | S(y) => (x1, x2) = dup(x); ack(x1, ack_Sx(y, x2))
19 |
20 | let res = ack(3n, 8n)
21 |
22 |
23 | # Encoding with pred from https://www.user.tu-berlin.de/o.runge/tfs/workshops/gtvmt08/Program/paper_38.pdf
24 | # 8360028 reductions
25 |
26 | def pred(_) = r
27 | | Z => Z
28 | | S(x) => x
29 | def ack2(_, a) = b
30 | | Z => S(a)
31 | | S(x) => ack2b(a, S(x))
32 | def ack2b(_, a) = b
33 | | Z => ack2(pred(a), S(Z))
34 | | S(y) => (a1, a2) = dup(a); ack2(pred(a1), ack2(a2, y))
35 |
36 | let res2 = ack2(3n, 8n)
37 |
38 |
39 |
40 | # Int-based encoding
41 | # 2786001 reductions
42 |
43 | cons Int[int]
44 |
45 | def ackI(a, b) = r
46 | | Int[x], Int[y] if [x == 0] => Int[y + 1]
47 | if [y == 0] => ackI(Int[x-1], Int[1])
48 | else => ackI(Int[x-1], ackI(Int[x], Int[y-1]))
49 |
50 | let resI = ackI(Int[3], Int[8])
51 |
52 |
53 | # Boxed Integer version using ref payloads
54 | # 2786001 reductions
55 |
56 | cons BoxedInt[ref]
57 |
58 | def ackB(a, b) = r
59 | | BoxedInt[x], BoxedInt[y]
60 | if [de.szeiger.interact.MainTest.is0(x)] =>
61 | BoxedInt[de.szeiger.interact.MainTest.inc(y)]
62 | [eraseRef(x)]
63 | if [de.szeiger.interact.MainTest.is0(y)] =>
64 | ackB(BoxedInt[de.szeiger.interact.MainTest.dec(x)], BoxedInt[de.szeiger.interact.MainTest.box(1)])
65 | [eraseRef(y)]
66 | else =>
67 | [de.szeiger.interact.MainTest.ackHelper(x, x1, x2)]
68 | ackB(BoxedInt[x1], ackB(BoxedInt[x2], BoxedInt[de.szeiger.interact.MainTest.dec(y)]))
69 |
70 | let resB = ackB(BoxedInt[de.szeiger.interact.MainTest.box(3)], BoxedInt[de.szeiger.interact.MainTest.box(8)])
71 |
--------------------------------------------------------------------------------
/src/test/resources/diverging.check:
--------------------------------------------------------------------------------
1 | Error: src/test/resources/diverging.in:7:7: Circular expansion (f <-> A) => (f <-> B) => (f <-> C) => (f <-> D)
2 | | | A => f(B)
3 | | ^
4 |
5 | Error: src/test/resources/diverging.in:8:7: Circular expansion (f <-> B) => (f <-> C) => (f <-> D) => (f <-> A)
6 | | | B => f(C)
7 | | ^
8 |
9 | Error: src/test/resources/diverging.in:9:7: Circular expansion (f <-> C) => (f <-> D) => (f <-> A) => (f <-> B)
10 | | | C => f(D)
11 | | ^
12 |
13 | Error: src/test/resources/diverging.in:10:7: Circular expansion (f <-> D) => (f <-> A) => (f <-> B) => (f <-> C)
14 | | | D => f(A)
15 | | ^
16 |
17 | 4 errors found.
18 |
--------------------------------------------------------------------------------
/src/test/resources/diverging.in:
--------------------------------------------------------------------------------
1 | cons A
2 | cons B
3 | cons C
4 | cons D
5 |
6 | def f(_) = r
7 | | A => f(B)
8 | | B => f(C)
9 | | C => f(D)
10 | | D => f(A)
11 |
--------------------------------------------------------------------------------
/src/test/resources/embedded.check:
--------------------------------------------------------------------------------
1 | dummy = Dummy[d:0:0]
2 | dummy2b = Dummy[d2:1:1]
3 | fib10 = Int[89]
4 | i1 = Int[42]
5 | i2 = Int[42]
6 | len = Int[5]
7 | mult = Int[6]
8 | r1 = Pair(1n, 1n)
9 | r2 = Pair(2n, 2n)
10 | simple = Int[42]
11 | str = String["foo"]
12 | sum = Int[8]
13 |
--------------------------------------------------------------------------------
/src/test/resources/embedded.in:
--------------------------------------------------------------------------------
1 | cons Int[int]
2 | cons String[ref]
3 | cons Dummy[ref]
4 |
5 | # Automatic move of payload
6 | def _ + y = r
7 | | Int[i] => intAdd[i](y)
8 |
9 | #match Int[i] + y => intAdd[i](y)
10 |
11 | # Explicit computation on payloads
12 | def intAdd[int a](_) = r
13 | | Int[b] => [add(a, b, c)]; Int[c]
14 |
15 | # Value-returning method call instead of out parameter
16 | def strlen(_) = r
17 | | String[s] => Int[strlen(s)]
18 |
19 | let simple = Int[42]
20 | x = Int[5]; y = Int[3]; sum = x + y
21 | str = String["foo"]
22 | len = strlen(String["12345"])
23 | dummy = Dummy[d]; [de.szeiger.interact.ManagedDummy.create("d", d)]
24 | dummy2 = Dummy[d2]; [de.szeiger.interact.ManagedDummy.create("d2", d2)]
25 | (dummy2a, dummy2b) = dup(dummy2)
26 | erase(dummy2a)
27 |
28 | # Currying
29 | def _ * _ = r
30 | | Int[a], Int[b] => [mult(a, b, c)]; Int[c]
31 |
32 | let mult = Int[3] * Int[2]
33 |
34 | # Conditional matching
35 | def fib(_) = r
36 | | Int[i] if [i == 0] => Int[1]
37 | if [i == 1] => Int[1]
38 | else => fib(Int[i-1]) + fib(Int[i-2])
39 |
40 | let fib10 = fib(Int[10])
41 |
42 |
43 |
44 |
45 | cons Z
46 | cons S(x)
47 |
48 | cons Func(x, fx)
49 |
50 | def apply(l, in) = out
51 | | Func(i, o) => in = i; o
52 |
53 | cons Pair(a, b)
54 |
55 | let f = Func(x, Pair(x1, x2))
56 | (x1, x2) = dup[l1](x)
57 | (f1, f2) = dup[l2](f)
58 | r1 = apply(f1, 1n)
59 | r2 = apply(f2, 2n)
60 |
61 | let i0 = Int[42]
62 | (i1, i2) = dup(i0)
63 |
--------------------------------------------------------------------------------
/src/test/resources/fib.check:
--------------------------------------------------------------------------------
1 | res = 89n
2 |
--------------------------------------------------------------------------------
/src/test/resources/fib.in:
--------------------------------------------------------------------------------
1 | cons Z
2 | cons S(n)
3 |
4 | def _ + y = r
5 | | Z => y
6 | | S(x) => S(x + y)
7 |
8 | def fib(_) = r
9 | | Z => 1n
10 | | S(Z) => 1n
11 | | S(S(n)) => (n1, n2) = dup(n)
12 | fib(S(n1)) + fib(n2)
13 |
14 | let res = fib(10n)
15 |
--------------------------------------------------------------------------------
/src/test/resources/inlining.check:
--------------------------------------------------------------------------------
1 | res = D
2 | res1 = Int[1]
3 | res2 = Int[2]
4 | res22 = Int[22]
5 | res3 = 9n
6 |
7 | Irreducible pairs:
8 | A <-> C
9 | B <-> B
10 | B <-> B
11 | B <-> B
12 | B <-> B
13 | B <-> B
14 |
--------------------------------------------------------------------------------
/src/test/resources/inlining.in:
--------------------------------------------------------------------------------
1 | cons A(x)
2 | cons B
3 | cons C(x)
4 | cons D
5 |
6 | match A(x) = B => C(x) = D; A(D) = C(D)
7 | match C(x) = D => x = D
8 |
9 | let A(res) = B
10 |
11 |
12 | cons Int[int]
13 |
14 | def f(_) = r
15 | | Int[i] if [i == 0] => B = B; g(Int[i + 10])
16 | else => B = B; B = B; g(Int[i + 20])
17 |
18 | def g(_) = r
19 | | Int[i] if [i == 10] => Int[1]
20 | if [i == 21] => Int[2]
21 | else => Int[i]
22 |
23 | let
24 | res1 = f(Int[0])
25 | res2 = f(Int[1])
26 | res22 = f(Int[2])
27 |
28 |
29 | cons Z
30 | cons S(n)
31 | def f1(_, a) = b
32 | | S(x) => g1(a, x)
33 | def g1(_, a) = b
34 | | S(y) => H1(y, b) = a
35 | cons H1(y, b) = a
36 | match H1(y, b) = S(t) => b = f1(S(t), y)
37 |
38 |
39 | def pred(_) = r
40 | | Z => Z
41 | | S(x) => x
42 | def ack2(_, a) = b
43 | | Z => S(a)
44 | | S(x) => ack2b(a, S(x))
45 | def ack2b(_, a) = b
46 | | Z => ack2(pred(a), S(Z))
47 | | S(y), S(t) => (t1, t2) = dup(t); ack2(pred(S(t1)), ack2(S(t2), y))
48 | let res3 = ack2(2n, 3n)
49 |
--------------------------------------------------------------------------------
/src/test/resources/lists.check:
--------------------------------------------------------------------------------
1 | flatMapped = 11n :: 12n :: 21n :: 22n :: 31n :: 32n :: Nil
2 | idMapped = 1n :: 2n :: 3n :: Nil
3 | l0_length = 3n
4 | l0_mapped = 3n :: 4n :: 5n :: Nil
5 | l0_mapped_lambda = 3n :: 4n :: 5n :: Nil
6 | listCons = 1n :: 2n :: 3n :: 4n :: 5n ::: Nil
7 |
8 | Irreducible pairs:
9 | ::: <-> S
10 |
11 |
--------------------------------------------------------------------------------
/src/test/resources/lists.in:
--------------------------------------------------------------------------------
1 | # Natural numbers
2 | cons Z
3 | cons S(n)
4 |
5 | # Addition
6 | def _ + y = r
7 | | Z => y
8 | | S(x) => x + S(y)
9 |
10 | # Lists
11 | cons Nil
12 | cons head :: tail = l
13 |
14 | def length(list) = r
15 | | Nil => Z
16 | | x :: xs => erase(x); S(length(xs))
17 |
18 | def map(list, fi, fo) = r
19 | | Nil => erase(fi); erase(fo); Nil
20 | | x :: xs => (x, fi2) = dup(fi)
21 | (fo1, fo2) = dup(fo)
22 | fo1 :: map(xs, fi2, fo2)
23 |
24 | def _ ::: ys = r
25 | | Nil => ys
26 | | x :: xs => x :: (xs ::: ys)
27 |
28 | def flatMap(list, fi, fo) = r
29 | | Nil => erase(fi); erase(fo); Nil
30 | | x :: xs => (x, fi2) = dup(fi)
31 | (fo1, fo2) = dup(fo)
32 | fo1 ::: flatMap(xs, fi2, fo2)
33 |
34 | # Example: List operations
35 | let l0 = 1n :: 2n :: 3n :: Nil
36 | (l0a, l0b) = dup(l0)
37 | l0_length = length(l0a)
38 | l0_mapped = map(l0b, x, x + 2n)
39 |
40 | let listCons = (1n ::2n :: 3n :: Nil) ::: (4n :: 5n ::: Nil)
41 |
42 | let idMapped = map(1n :: 2n :: 3n :: Nil, y, y)
43 |
44 | def mkList(_) = r
45 | | Z => Z :: Z :: Nil
46 | | S(x) => (s1, s2) = dup(S(x))
47 | s1 + 1n :: s2 + 2n :: Nil
48 |
49 | let flatMapped =
50 | flatMap(10n :: 20n :: 30n :: Nil, x, mkList(x))
51 |
52 | # Explicit lambdas
53 | cons in |> out
54 | def apply(l, in) = out
55 | | i |> o => in = i; o
56 |
57 | # Example: List mapping with lambdas
58 | def map2(l, f) = r
59 | | Nil => erase(f); Nil
60 | | x :: xs => (f1, f2) = dup(f)
61 | apply(f1, x) :: map2(xs, f2)
62 |
63 | let l0 = 1n :: 2n :: 3n :: Nil
64 | l0_mapped_lambda = map2(l0, x |> x + 2n)
65 |
--------------------------------------------------------------------------------
/src/test/resources/par-mult.check:
--------------------------------------------------------------------------------
1 | res = 10000n
2 | resI = Int[10000]
3 |
--------------------------------------------------------------------------------
/src/test/resources/par-mult.in:
--------------------------------------------------------------------------------
1 | cons Z
2 | cons S(n)
3 |
4 | def _ + y = r
5 | | Z => y
6 | | S(x) => x + S(y)
7 |
8 | def _ * y = r
9 | | Z => erase(y); Z
10 | | S(x) => (a, b) = dup(y)
11 | r = b + x * a
12 |
13 | let res = 100n * 100n
14 |
15 |
16 | cons Int[int]
17 | def add(_, _) = r
18 | | Int[x], Int[y] if [x == 0] => Int[y]
19 | else => add(Int[x-1], Int[y+1])
20 | def mult(_, _) = r
21 | | Int[x], Int[y] if [x == 0] => Int[0]
22 | else => add(Int[y], mult(Int[x-1], Int[y]))
23 | let resI = mult(Int[100], Int[100])
24 |
--------------------------------------------------------------------------------
/src/test/resources/seq-def.check:
--------------------------------------------------------------------------------
1 | example_3_plus_5 = 8n
2 | example_3_times_2 = 6n
3 | foo_curry = 1n
4 |
--------------------------------------------------------------------------------
/src/test/resources/seq-def.in:
--------------------------------------------------------------------------------
1 | # Natural numbers
2 | cons Z
3 | cons S(n)
4 |
5 | # Erasure and Duplication
6 | # def erase(_)
7 | # def dup(_) = (a, b)
8 | # | dup(_) = (c, d) => (c, d)
9 |
10 | # Addition
11 | def _ + y = r
12 | | Z => y
13 | | S(x) => x + S(y)
14 |
15 | # Separate reductions
16 | #match Z + y => y
17 | #match S(x) + y => x + S(y)
18 |
19 | #match erase(Z) => ()
20 | #match erase(S(x)) => erase(x)
21 | #match dup(Z) = (a, b) => a = Z; b = Z
22 | #match dup(S(x)) = (a, b) => (sa, sb) = dup(x); s = S(sa); b = S(sb)
23 | #match dup(dup(_) = (c, d)) = (a, b) => (c, d)
24 | #match S(x) = S(y) => x = y
25 |
26 | # Multiplication
27 | def _ * y = r
28 | # | Z => erase(y), Z
29 | # | S(x) => (y1, y2) = dup(y); x * y1 + y2
30 |
31 | match Z * y => erase(y); Z
32 | match S(x) * y => (y1, y2) = dup(y); x * y1 + y2
33 |
34 | # Example: Computations on church numerals
35 | let y = 5n
36 | x = 3n
37 | example_3_plus_5 = x + y
38 |
39 | let example_3_times_2 = 3n * 2n
40 |
41 | # Currying on additional args:
42 | def foo(a, b) = r
43 | | Z => erase(b); Z
44 | | S(x), S(y) => x + y
45 | let foo_curry = foo(1n, 2n)
46 |
--------------------------------------------------------------------------------
/src/test/scala/BitOpsTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Assert._
4 | import org.junit.Test
5 |
6 | import BitOps._
7 |
8 | class BitOpsTest {
9 |
10 | val vals = Seq(-128, -127, -1, 0, 1, 126, 127)
11 |
12 | @Test
13 | def testByte0(): Unit = {
14 | for {
15 | v0 <- vals
16 | } {
17 | val i = checkedIntOfBytes(v0, 0, 0, 0)
18 | assertEquals(v0, byte0(i))
19 | }
20 | }
21 |
22 | @Test
23 | def testByte1(): Unit = {
24 | for {
25 | v1 <- vals
26 | } {
27 | val i = checkedIntOfBytes(0, v1, 0, 0)
28 | assertEquals(v1, byte1(i))
29 | }
30 | }
31 |
32 | @Test
33 | def testByte2(): Unit = {
34 | for {
35 | v2 <- vals
36 | } {
37 | val i = checkedIntOfBytes(0, 0, v2, 0)
38 | assertEquals(v2, byte2(i))
39 | }
40 | }
41 |
42 | @Test
43 | def testByte3(): Unit = {
44 | for {
45 | v3 <- vals
46 | } {
47 | val i = checkedIntOfBytes(0, 0, 0, v3)
48 | assertEquals(v3, byte3(i))
49 | }
50 | }
51 |
52 | @Test
53 | def testMixed(): Unit = {
54 | val vals = Seq(-128, -127, -1, 0, 1, 126, 127)
55 | for {
56 | v0 <- vals
57 | v1 <- vals
58 | v2 <- vals
59 | v3 <- vals
60 | } {
61 | val i = checkedIntOfBytes(v0, v1, v2, v3)
62 | assertEquals(v0, byte0(i))
63 | assertEquals(v1, byte1(i))
64 | assertEquals(v2, byte2(i))
65 | assertEquals(v3, byte3(i))
66 | }
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/src/test/scala/LongBitOpsTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Assert._
4 | import org.junit.Test
5 |
6 | import LongBitOps._
7 |
8 | class LongBitOpsTest {
9 |
10 | val vals = Seq(Short.MinValue, Short.MinValue+1, -128, -127, -1, 0, 1, 126, 127, Short.MaxValue-1, Short.MaxValue)
11 |
12 | @Test
13 | def testShort0(): Unit = {
14 | for {
15 | v0 <- vals
16 | } {
17 | val i = checkedLongOfShorts(v0, 0, 0, 0)
18 | assertEquals(v0, short0(i))
19 | }
20 | }
21 |
22 | @Test
23 | def testShort1(): Unit = {
24 | for {
25 | v1 <- vals
26 | } {
27 | val i = checkedLongOfShorts(0, v1, 0, 0)
28 | assertEquals(v1, short1(i))
29 | }
30 | }
31 |
32 | @Test
33 | def testShort2(): Unit = {
34 | for {
35 | v2 <- vals
36 | } {
37 | val i = checkedLongOfShorts(0, 0, v2, 0)
38 | assertEquals(v2, short2(i))
39 | }
40 | }
41 |
42 | @Test
43 | def testShort3(): Unit = {
44 | for {
45 | v3 <- vals
46 | } {
47 | val i = checkedLongOfShorts(0, 0, 0, v3)
48 | assertEquals(v3, short3(i))
49 | }
50 | }
51 |
52 | @Test
53 | def testMixed(): Unit = {
54 | for {
55 | v0 <- vals
56 | v1 <- vals
57 | v2 <- vals
58 | v3 <- vals
59 | } {
60 | val i = checkedLongOfShorts(v0, v1, v2, v3)
61 | assertEquals(v0, short0(i))
62 | assertEquals(v1, short1(i))
63 | assertEquals(v2, short2(i))
64 | assertEquals(v3, short3(i))
65 | }
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/src/test/scala/MainTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Test
4 | import org.junit.runner.RunWith
5 | import org.junit.runners.Parameterized
6 | import org.junit.runners.Parameterized.Parameters
7 |
8 | import java.nio.file.Path
9 | import scala.jdk.CollectionConverters._
10 |
11 | @RunWith(classOf[Parameterized])
12 | class MainTest(spec: String) {
13 | val SCALE = 0
14 | val conf = Config.defaultConfig.withSpec(spec).copy(showAfter = Set(""), phaseLog = Set(""),
15 | //writeOutput = Some(Path.of("bench/gen-classes")), writeJava = Some(Path.of("bench/gen-src"))
16 | )
17 |
18 | def check(testName: String, scaleFactor: Int = 1, expectedSteps: Int = -1, fail: Boolean = false, config: Config = conf): Unit =
19 | for(i <- 1 to (if(SCALE == 0) 1 else SCALE * scaleFactor)) TestUtils.check(testName, expectedSteps, fail, config)
20 |
21 | @Test def testSeqDef = check("seq-def", scaleFactor = 50, expectedSteps = 32)
22 | @Test def testLists = check("lists")
23 | @Test def testParMult = check("par-mult")
24 | @Test def testInlining = check("inlining", expectedSteps = if(conf.backend.allowPayloadTemp) 99 else 102)
25 | @Test def testFib = check("fib")
26 | @Test def testEmbedded = check("embedded")
27 | @Test def testAck = check("ack", expectedSteps = if(conf.backend.allowPayloadTemp) 18114077 else 23686073)
28 | @Test def testDiverging = check("diverging", fail = true)
29 | }
30 |
31 | object MainTest {
32 | @Parameters(name = "{0}")
33 | def interpreters = Seq(
34 | "sti", "stc1", "stc2",
35 | //"mt0.i", "mt1.i", "mt8.i",
36 | //"mt1000.i", "mt1001.i", "mt1008.i",
37 | //"mt0.c", "mt1.c", "mt8.c",
38 | //"mt1000.c", "mt1001.c", "mt1008.c",
39 | ).map(s => Array[AnyRef](s)).asJava
40 |
41 | // used by ack.in:
42 | def is0(i: java.lang.Integer): Boolean = i.intValue() == 0
43 | def box(i: Int): java.lang.Integer = Integer.valueOf(i)
44 | def inc(i: java.lang.Integer): java.lang.Integer = box(i.intValue() + 1)
45 | def dec(i: java.lang.Integer): java.lang.Integer = box(i.intValue() - 1)
46 | def ackHelper(i: java.lang.Integer, o1: RefOutput, o2: RefOutput): Unit = {
47 | o1.setValue(dec(i))
48 | o2.setValue(i)
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/src/test/scala/TestUtils.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact
2 |
3 | import org.junit.Assert
4 | import org.junit.Assert._
5 |
6 | import java.io.{ByteArrayOutputStream, PrintStream}
7 | import java.nio.charset.StandardCharsets
8 | import java.nio.file.{Files, Path}
9 | import java.util.concurrent.atomic.AtomicInteger
10 |
11 | object TestUtils {
12 | def check(testName: String, expectedSteps: Int = -1, fail: Boolean = false, config: Config = Config.defaultConfig): Unit = {
13 | val basePath = s"src/test/resources/$testName"
14 | val statements = Parser.parse(Path.of(basePath+".in"))
15 | val (result, success, steps) = try {
16 | val model = new Compiler(statements, config.copy(collectStats = true))
17 | val inter = model.createInterpreter()
18 | inter.initData()
19 | inter.reduce()
20 | if(inter.getMetrics != null) inter.getMetrics.logStats()
21 | val out = new ByteArrayOutputStream()
22 | inter.getAnalyzer.log(new PrintStream(out, true, StandardCharsets.UTF_8), color = false)
23 | (out.toString(StandardCharsets.UTF_8), true, inter.getMetrics.getSteps)
24 | } catch {
25 | case ex: CompilerResult => (Colors.stripColors(ex.getMessage).replace("src\\test\\resources\\", "src/test/resources/"), false, -1)
26 | }
27 | val checkFile = Path.of(basePath+".check")
28 | if(Files.exists(checkFile)) {
29 | val check = Files.readString(checkFile, StandardCharsets.UTF_8)
30 | if(check.trim.replaceAll("\r", "") != result.trim.replaceAll("\r", "")) {
31 | println("---- Expected ----")
32 | println(check)
33 | println("---- Actual ----")
34 | println(result)
35 | println("---- End ----")
36 | assertEquals(check.trim.replaceAll("\r", ""), result.trim.replaceAll("\r", ""))
37 | }
38 | }
39 | if(fail && success) Assert.fail("Failure expected")
40 | else if(!fail && !success) Assert.fail("Unexpected failure")
41 | if(expectedSteps >= 0 && success) assertEquals(s"Expected $expectedSteps steps, but fully reduced after $steps", expectedSteps, steps)
42 | }
43 | }
44 |
45 | class ManagedDummy(name: String) extends LifecycleManaged {
46 | private[this] val copied, erased = new AtomicInteger()
47 | override def toString: String = s"$name:$copied:$erased"
48 | override def erase(): Unit = erased.incrementAndGet()
49 | override def copy(): LifecycleManaged = { copied.incrementAndGet(); this }
50 | }
51 | object ManagedDummy {
52 | def create(name: String, res: RefOutput): Unit = res.setValue(new ManagedDummy(name))
53 | }
54 |
--------------------------------------------------------------------------------
/src/test/scala/WorkersTest.scala:
--------------------------------------------------------------------------------
1 | package de.szeiger.interact.mt
2 |
3 | import de.szeiger.interact.mt.workers.{Worker, Workers}
4 | import org.junit.Assert._
5 | import org.junit.Test
6 |
7 | import java.util.concurrent.CountDownLatch
8 | import java.util.concurrent.atomic.AtomicInteger
9 |
10 | class WorkersTest {
11 |
12 | @Test
13 | def test1(): Unit = {
14 | val unfinished = new AtomicInteger(0)
15 | var latch: CountDownLatch = new CountDownLatch(1)
16 | val count = new AtomicInteger(0)
17 | case class Task(sub: Task*)
18 | def add(t: Task): Unit = {
19 | unfinished.incrementAndGet()
20 | ws.add(t)
21 | }
22 | class Processor extends Worker[Task] {
23 | def apply(t: Task): Unit = {
24 | count.incrementAndGet()
25 | t.sub.foreach { t2 =>
26 | Thread.sleep(50)
27 | unfinished.incrementAndGet()
28 | this.add(t2)
29 | }
30 | if(unfinished.decrementAndGet() == 0) latch.countDown()
31 | }
32 | }
33 | lazy val ws = new Workers[Task](8, 1024, _ => new Processor)
34 | add(Task())
35 | ws.start()
36 | add(Task(Task(Task(), Task(Task(), Task())), Task(Task(), Task(), Task())))
37 | while(unfinished.get() != 0) latch.await()
38 | assertEquals(11, count.get())
39 | latch = new CountDownLatch(1)
40 | add(Task(Task(Task(), Task(Task(), Task())), Task(Task(), Task(), Task())))
41 | while(unfinished.get() != 0) latch.await()
42 | assertEquals(21, count.get())
43 | ws.shutdown()
44 | }
45 | }
46 |
--------------------------------------------------------------------------------