├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── .mailmap ├── README.md ├── build.zig ├── build.zig.zon ├── frontends ├── arocc │ ├── .gitignore │ ├── README.md │ ├── build.zig │ ├── build.zig.zon │ └── src │ │ ├── CodeGen.zig │ │ └── main.zig └── scc │ ├── .gitignore │ ├── build.zig │ └── src │ ├── Ast.zig │ ├── CodeGen.zig │ ├── Parser.zig │ ├── Tokenizer.zig │ └── main.zig ├── render.sh ├── src ├── Oir.zig ├── Oir │ ├── SimpleExtractor.zig │ ├── extraction.zig │ ├── print_oir.zig │ └── z3.zig ├── codegen │ └── p2.zig ├── cost.zig ├── lib.zig ├── passes │ ├── constant_fold.zig │ ├── rewrite.zig │ └── rewrite │ │ ├── SExpr.zig │ │ ├── machine.zig │ │ └── table.zon └── trace.zig └── test.c /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | pull_request: 3 | push: 4 | branches: 5 | - master 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | with: 13 | fetch-depth: 0 14 | 15 | - name: setup-zig 16 | uses: mlugg/setup-zig@v1 17 | with: 18 | version: 0.15.0-dev.460+f4e9846bc 19 | 20 | - name: build 21 | run: zig build 22 | 23 | - name: test 24 | run: zig build test --summary all -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .zig-cache/ 2 | zig-out/ 3 | .vscode/ 4 | out.dot 5 | graphs/* 6 | trace.json 7 | input.ir 8 | test.dot 9 | input.ir 10 | -------------------------------------------------------------------------------- /.mailmap: -------------------------------------------------------------------------------- 1 | David Rubin 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Zig Optimizing Backend 2 | 3 | This is sort of two projects in one. 
The first goal is to develop a prototype optimizing backend 4 | for the Zig compiler, testing out different IRs and representations, to find the one that will 5 | suit Zig the best. A second goal is to create a general optimizing backend, one that can be used 6 | for simple compiled languages, test out new and interesting optimizations, and just see how Zig 7 | fares as a language for writing optimizers. 8 | 9 | The current implementation is a mix of E-Graphs and Sea of Nodes. At first, I was keen on using RVSDG 10 | since given Zig's very structural nature it seemed like it would be a good fit, however after a few 11 | months of fiddling around with it, RVSDG is just too restrictive. This is likely just my own experience, 12 | but I found it extremely difficult to develop algorithms to convert from Zig's SSA to RVSDG. 13 | Maybe I'll come back to it one day, but for now, I'd like to use something a bit less esoteric, such as 14 | SoN, and move further in the implementation. 15 | 16 | Implementation goals are: 17 | - Implement an abstract optimizing IR based on e-graphs, rvsdg, and SoN. 18 | - Implement a RISC-V machine code backend to provide a simpler method 19 | of register allocation and backend generalization. 
20 | -------------------------------------------------------------------------------- /build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | const frontends = .{ 4 | .{ "scc", "C" }, 5 | .{ "arocc", "C-aro" }, 6 | }; 7 | 8 | pub fn build(b: *std.Build) !void { 9 | const target = b.standardTargetOptions(.{}); 10 | const optimize = b.standardOptimizeOption(.{}); 11 | 12 | const use_z3 = b.option(bool, "use_z3", "Use Z3 as the MILP solver for Oir extraction") orelse false; 13 | const filters = b.option([]const []const u8, "filter", "Filter test cases"); 14 | const trace = b.option(bool, "trace", "Enable tracing output to trace.json") orelse false; 15 | 16 | var options = b.addOptions(); 17 | options.addOption(bool, "has_z3", use_z3); 18 | options.addOption(bool, "enable_trace", trace); 19 | 20 | const zob_mod = b.addModule("zob", .{ 21 | .target = target, 22 | .optimize = optimize, 23 | .root_source_file = b.path("src/lib.zig"), 24 | }); 25 | zob_mod.addOptions("build_options", options); 26 | 27 | const test_lib = b.addTest(.{ 28 | .root_source_file = b.path("src/lib.zig"), 29 | .target = target, 30 | .optimize = optimize, 31 | .filters = filters orelse &.{}, 32 | }); 33 | test_lib.root_module.addOptions("build_options", options); 34 | 35 | const test_step = b.step("test", "Run the tests"); 36 | test_step.dependOn(&b.addRunArtifact(test_lib).step); 37 | 38 | if (use_z3) { 39 | const z3 = b.lazyDependency("z3", .{ .target = target, .optimize = optimize }) orelse return; 40 | const z3_mod = z3.module("z3_bindings"); 41 | zob_mod.addImport("z3", z3_mod); 42 | test_lib.root_module.addImport("z3", z3_mod); 43 | } 44 | 45 | const test_frontends = b.step("test-frontends", "Runs frontend tests"); 46 | if (filters == null) test_step.dependOn(test_frontends); 47 | 48 | inline for (frontends) |frontend| { 49 | const name, const lang = frontend; 50 | 51 | const step = b.step(name, lang ++ " language 
compiler"); 52 | const dep = b.dependency(name, .{ 53 | .target = target, 54 | .optimize = optimize, 55 | }); 56 | 57 | const artifact = dep.artifact(name); 58 | artifact.root_module.addImport("zob", zob_mod); 59 | b.installArtifact(artifact); 60 | 61 | const run = b.addRunArtifact(artifact); 62 | if (b.args) |args| run.addArgs(args); 63 | step.dependOn(&run.step); 64 | 65 | const test_exe = b.addTest(.{ 66 | .name = name, 67 | .root_module = dep.module(name), 68 | }); 69 | test_frontends.dependOn(&b.addRunArtifact(test_exe).step); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /build.zig.zon: -------------------------------------------------------------------------------- 1 | .{ 2 | .name = .zob, 3 | .version = "0.0.0", 4 | .paths = .{""}, 5 | .fingerprint = 0xe7fe43888bef583b, 6 | .minimum_zig_version = "0.15.0-dev.460+f4e9846bc", 7 | .dependencies = .{ 8 | .z3 = .{ 9 | .lazy = true, 10 | .url = "git+https://github.com/Rexicon226/zig-z3#e1e2037479923c2519adf48217356ec98819e656", 11 | .hash = "z3-0.1.0-bVhhr3MvCwDuxOJJNOkyUm9dU_SnEomjMFFcTlQjwJHy", 12 | }, 13 | .arocc = .{ .path = "frontends/arocc" }, 14 | .scc = .{ .path = "frontends/scc" }, 15 | }, 16 | } 17 | -------------------------------------------------------------------------------- /frontends/arocc/.gitignore: -------------------------------------------------------------------------------- 1 | .zig-cache/ 2 | zig-out/ -------------------------------------------------------------------------------- /frontends/arocc/README.md: -------------------------------------------------------------------------------- 1 | # Aro C Frontend 2 | 3 | C Frontend, similar to SCC, but uses Aro. 
4 | https://github.com/Vexu/arocc -------------------------------------------------------------------------------- /frontends/arocc/build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn build(b: *std.Build) !void { 4 | const target = b.standardTargetOptions(.{}); 5 | const optimize = b.standardOptimizeOption(.{}); 6 | 7 | const aro = b.dependency("aro", .{ 8 | .target = target, 9 | .optimize = optimize, 10 | }); 11 | 12 | const main = b.addModule("arocc", .{ 13 | .target = target, 14 | .optimize = optimize, 15 | .root_source_file = b.path("src/main.zig"), 16 | }); 17 | main.addImport("aro", aro.module("aro")); 18 | 19 | const exe = b.addExecutable(.{ 20 | .name = "arocc", 21 | .root_module = main, 22 | }); 23 | b.installArtifact(exe); 24 | } 25 | -------------------------------------------------------------------------------- /frontends/arocc/build.zig.zon: -------------------------------------------------------------------------------- 1 | .{ 2 | .name = .arocc, 3 | .version = "0.0.0", 4 | .fingerprint = 0x7aa6456c37a4380a, 5 | .minimum_zig_version = "0.14.0", 6 | .dependencies = .{ 7 | .aro = .{ 8 | .url = "git+https://github.com/Vexu/arocc#348624998092bf4173f365c21fc8c0f43851ce97", 9 | .hash = "aro-0.0.0-JSD1QtGvJgCLKX4yRsgVVtmQu3tgoC6a-byQhxAA6011", 10 | }, 11 | }, 12 | .paths = .{ 13 | "build.zig", 14 | "build.zig.zon", 15 | "src", 16 | }, 17 | } 18 | -------------------------------------------------------------------------------- /frontends/arocc/src/CodeGen.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const aro = @import("aro"); 3 | const zob = @import("zob"); 4 | const CodeGen = @This(); 5 | 6 | const Tree = aro.Tree; 7 | const Oir = zob.Oir; 8 | const Recursive = zob.Recursive; 9 | 10 | gpa: std.mem.Allocator, 11 | oir: *Oir, 12 | tree: *const Tree, 13 | ctrl_class: ?Oir.Class.Index, 14 | exits: 
*std.ArrayListUnmanaged(Oir.Class.Index), 15 | node_to_class: std.AutoHashMapUnmanaged(Tree.Node.Index, Oir.Class.Index), 16 | symbol_table: SymbolTable, 17 | scratch: std.ArrayListUnmanaged(Oir.Class.Index), 18 | 19 | const Error = error{OutOfMemory}; 20 | const SymbolTable = std.ArrayListUnmanaged(std.StringHashMapUnmanaged(Oir.Class.Index)); 21 | 22 | pub fn init( 23 | oir: *Oir, 24 | gpa: std.mem.Allocator, 25 | tree: *const Tree, 26 | ) !CodeGen { 27 | const start_class = try oir.add(.{ 28 | .tag = .start, 29 | .data = .{ .list = .{ .start = 0, .end = 0 } }, 30 | }); 31 | const ctrl_class = try oir.add(.project(0, start_class, .ctrl)); 32 | 33 | var symbol_table: SymbolTable = .{}; 34 | try symbol_table.append(gpa, .{}); 35 | 36 | return .{ 37 | .gpa = gpa, 38 | .oir = oir, 39 | .tree = tree, 40 | .ctrl_class = ctrl_class, 41 | .node_to_class = .{}, 42 | .exits = &oir.exit_list, 43 | .scratch = .{}, 44 | .symbol_table = symbol_table, 45 | }; 46 | } 47 | 48 | pub fn build(cg: *CodeGen) !Recursive { 49 | const stdout = std.io.getStdOut().writer(); 50 | 51 | const tree = cg.tree; 52 | const node_tags = tree.nodes.items(.tag); 53 | 54 | for (cg.tree.root_decls.items) |node| { 55 | switch (cg.tree.nodes.items(.tag)[@intFromEnum(node)]) { 56 | .fn_def => try cg.buildFn(node), 57 | .typedef => {}, 58 | else => std.debug.panic("TODO: {s}", .{@tagName(node_tags[@intFromEnum(node)])}), 59 | } 60 | } 61 | 62 | try cg.oir.rebuild(); 63 | 64 | try stdout.writeAll("unoptimized OIR:\n"); 65 | try cg.oir.print(stdout); 66 | try stdout.writeAll("end OIR\n"); 67 | 68 | try cg.oir.optimize(.saturate, false); 69 | 70 | return cg.oir.extract(.auto); 71 | } 72 | 73 | fn buildFn(cg: *CodeGen, decl: Tree.Node.Index) !void { 74 | const tree = cg.tree; 75 | const node_tags = tree.nodes.items(.tag); 76 | 77 | switch (decl.get(tree)) { 78 | // TODO: Oir should only represent one function - currently all functions 79 | // are put into the same Oir, which would easily create an invalid 
graph. 80 | .fn_def => |def| { 81 | const func_ty = def.qt.base(tree.comp).type.func; 82 | for (func_ty.params, 0..) |param, i| { 83 | const name = cg.tree.tokSlice(param.name_tok); 84 | const node = try cg.oir.add(.project(@intCast(i + 1), .start, .data)); 85 | 86 | const latest = &cg.symbol_table.items[cg.symbol_table.items.len - 1]; 87 | try latest.put(cg.gpa, name, node); 88 | } 89 | try cg.buildStmt(def.body); 90 | }, 91 | .typedef => {}, 92 | else => std.debug.panic("TODO: {s}", .{@tagName(node_tags[@intFromEnum(decl)])}), 93 | } 94 | } 95 | 96 | fn buildStmt(cg: *CodeGen, stmt: Tree.Node.Index) !void { 97 | const scratch_top = cg.scratch.items.len; 98 | const tree = cg.tree; 99 | const node_tags = tree.nodes.items(.tag); 100 | 101 | switch (stmt.get(tree)) { 102 | .return_stmt => |ret| { 103 | const operand: Oir.Class.Index = switch (ret.operand) { 104 | .expr => |idx| try cg.buildExpr(idx), 105 | .implicit => |zeroes| { 106 | _ = zeroes; 107 | @panic("TODO"); 108 | }, 109 | .none => @panic("TODO"), 110 | }; 111 | 112 | const node = try cg.oir.add(.binOp( 113 | .ret, 114 | cg.ctrl_class.?, 115 | operand, 116 | )); 117 | 118 | try cg.exits.append(cg.gpa, node); 119 | try cg.node_to_class.put(cg.gpa, stmt, node); 120 | cg.ctrl_class = null; // nothing can exist after return 121 | }, 122 | .if_stmt => |cond_br| { 123 | const predicate = try cg.buildExpr(cond_br.cond); 124 | 125 | const branch = try cg.oir.add(.branch(cg.ctrl_class.?, predicate)); 126 | const then_project = try cg.oir.add(.project(0, branch, .ctrl)); 127 | const else_project = try cg.oir.add(.project(1, branch, .ctrl)); 128 | 129 | cg.ctrl_class = then_project; 130 | try cg.buildStmt(cond_br.then_body); 131 | const latest_then_ctrl = cg.ctrl_class; 132 | 133 | cg.ctrl_class = else_project; 134 | try cg.buildStmt(cond_br.else_body orelse @panic("TODO")); 135 | const latest_else_ctrl = cg.ctrl_class; 136 | 137 | if (latest_then_ctrl == null and latest_else_ctrl == null) { 138 | // this region is 
dead, we can ignore it 139 | return; 140 | } 141 | 142 | if (latest_then_ctrl) |ctrl| { 143 | try cg.scratch.append(cg.gpa, ctrl); 144 | } 145 | if (latest_else_ctrl) |ctrl| { 146 | try cg.scratch.append(cg.gpa, ctrl); 147 | } 148 | 149 | const items = cg.scratch.items[scratch_top..]; 150 | const list = try cg.oir.listToSpan(items); 151 | cg.ctrl_class = try cg.oir.add(.region(list)); 152 | }, 153 | .compound_stmt => |compound| { 154 | for (compound.body) |s| try cg.buildStmt(s); 155 | return; // nothing to add to the oir 156 | }, 157 | else => std.debug.panic("TODO: {s}", .{@tagName(node_tags[@intFromEnum(stmt)])}), 158 | } 159 | } 160 | 161 | fn buildExpr(cg: *CodeGen, expr: Tree.Node.Index) !Oir.Class.Index { 162 | const tree = cg.tree; 163 | const node_tags = tree.nodes.items(.tag); 164 | 165 | if (cg.node_to_class.get(expr)) |c| return c; 166 | if (tree.value_map.get(expr)) |val| { 167 | return cg.buildConstant(expr, val); 168 | } 169 | 170 | const class = switch (expr.get(tree)) { 171 | .add_expr => |bin| bin: { 172 | const lhs = try cg.buildExpr(bin.lhs); 173 | const rhs = try cg.buildExpr(bin.rhs); 174 | break :bin try cg.oir.add(.binOp( 175 | .add, 176 | lhs, 177 | rhs, 178 | )); 179 | }, 180 | .int_literal => unreachable, // handled in the value_map above 181 | .cast => |cast| switch (cast.kind) { 182 | .lval_to_rval => c: { 183 | const operand = try cg.buildLval(cast.operand); 184 | break :c try cg.oir.add(.init(.load, operand)); 185 | }, 186 | else => std.debug.panic("TODO: cast {s}", .{@tagName(cast.kind)}), 187 | }, 188 | else => std.debug.panic("TODO: {s}", .{@tagName(node_tags[@intFromEnum(expr)])}), 189 | }; 190 | 191 | try cg.node_to_class.put(cg.gpa, expr, class); 192 | return class; 193 | } 194 | 195 | fn buildLval(cg: *CodeGen, idx: Tree.Node.Index) !Oir.Class.Index { 196 | const tree = cg.tree; 197 | const node_tags = tree.nodes.items(.tag); 198 | 199 | if (cg.node_to_class.get(idx)) |c| return c; 200 | 201 | const class = switch 
(idx.get(tree)) { 202 | .decl_ref_expr => |decl_ref| ref: { 203 | const name = tree.tokSlice(decl_ref.name_tok); 204 | if (cg.findIdentifier(name)) |ref_idx| { 205 | break :ref ref_idx.*; 206 | } else { 207 | @panic("TODO"); 208 | } 209 | }, 210 | else => std.debug.panic("TODO: {s}", .{@tagName(node_tags[@intFromEnum(idx)])}), 211 | }; 212 | 213 | try cg.node_to_class.put(cg.gpa, idx, class); 214 | return class; 215 | } 216 | 217 | fn findIdentifier(cg: *CodeGen, ident: []const u8) ?*Oir.Class.Index { 218 | for (0..cg.symbol_table.items.len) |i| { 219 | const rev = cg.symbol_table.items.len - i - 1; 220 | if (cg.symbol_table.items[rev].getPtr(ident)) |class| return class; 221 | } 222 | return null; 223 | } 224 | 225 | fn buildConstant(cg: *CodeGen, idx: Tree.Node.Index, val: aro.Value) !Oir.Class.Index { 226 | const tree = cg.tree; 227 | const key = tree.comp.interner.get(val.ref()); 228 | 229 | const class = switch (key) { 230 | .int => if (val.toInt(i64, tree.comp)) |int| 231 | try cg.oir.add(.constant(int)) 232 | else { 233 | @panic("TODO"); 234 | }, 235 | else => @panic("TODO"), 236 | }; 237 | 238 | try cg.node_to_class.put(cg.gpa, idx, class); 239 | return class; 240 | } 241 | 242 | pub fn deinit(cg: *CodeGen, allocator: std.mem.Allocator) void { 243 | cg.node_to_class.deinit(allocator); 244 | cg.scratch.deinit(allocator); 245 | for (cg.symbol_table.items) |*table| { 246 | table.deinit(allocator); 247 | } 248 | cg.symbol_table.deinit(allocator); 249 | } 250 | -------------------------------------------------------------------------------- /frontends/arocc/src/main.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const aro = @import("aro"); 3 | const zob = @import("zob"); 4 | const builtin = @import("builtin"); 5 | 6 | const CodeGen = @import("CodeGen.zig"); 7 | 8 | pub const std_options: std.Options = .{ 9 | .log_level = .err, 10 | }; 11 | 12 | pub fn main() !void { 13 | var 
arena_instance = std.heap.ArenaAllocator.init(std.heap.page_allocator); 14 | defer arena_instance.deinit(); 15 | const arena = arena_instance.allocator(); 16 | 17 | var general_purpose_allocator: std.heap.GeneralPurposeAllocator(.{}) = .init; 18 | const gpa = general_purpose_allocator.allocator(); 19 | 20 | const args = try std.process.argsAlloc(arena); 21 | 22 | const stderr_file = std.io.getStdErr(); 23 | var diagnostics: aro.Diagnostics = .{ 24 | .output = .{ .to_file = .{ 25 | .config = std.io.tty.detectConfig(stderr_file), 26 | .file = stderr_file, 27 | } }, 28 | }; 29 | 30 | var comp = try aro.Compilation.initDefault(gpa, &diagnostics, std.fs.cwd()); 31 | defer comp.deinit(); 32 | 33 | var driver: aro.Driver = .{ 34 | .comp = &comp, 35 | .diagnostics = &diagnostics, 36 | }; 37 | defer driver.deinit(); 38 | 39 | var macro_buf = std.ArrayList(u8).init(gpa); 40 | defer macro_buf.deinit(); 41 | 42 | std.debug.assert(!try driver.parseArgs(std.io.null_writer, macro_buf.writer(), args)); 43 | std.debug.assert(driver.inputs.items.len == 1); 44 | const source = driver.inputs.items[0]; 45 | 46 | const builtin_macros = try comp.generateBuiltinMacros(.include_system_defines, null); 47 | const user_macros = try comp.addSourceFromBuffer("", macro_buf.items); 48 | 49 | var pp = try aro.Preprocessor.initDefault(&comp); 50 | defer pp.deinit(); 51 | 52 | try pp.preprocessSources(&.{ source, builtin_macros, user_macros }); 53 | 54 | var tree = try pp.parse(); 55 | defer tree.deinit(); 56 | 57 | var oir: zob.Oir = .init(gpa); 58 | defer oir.deinit(); 59 | 60 | var cg = try CodeGen.init(&oir, gpa, &tree); 61 | defer cg.deinit(gpa); 62 | 63 | var recv = try cg.build(); 64 | defer recv.deinit(gpa); 65 | 66 | try zob.p2.generate(&recv); 67 | } 68 | 69 | fn fail(comptime fmt: []const u8, args: anytype) noreturn { 70 | const stderr = std.io.getStdErr().writer(); 71 | stderr.print(fmt ++ "\n", args) catch @panic("failed to print the stderr"); 72 | std.posix.abort(); 73 | } 74 | 
-------------------------------------------------------------------------------- /frontends/scc/.gitignore: -------------------------------------------------------------------------------- 1 | .zig-cache 2 | zig-out 3 | 4 | traces/* 5 | !traces/.keep 6 | 7 | .vscode 8 | test.dot 9 | out.png 10 | -------------------------------------------------------------------------------- /frontends/scc/build.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | 3 | pub fn build(b: *std.Build) !void { 4 | const target = b.standardTargetOptions(.{}); 5 | const optimize = b.standardOptimizeOption(.{}); 6 | 7 | const main = b.addModule("scc", .{ 8 | .target = target, 9 | .optimize = optimize, 10 | .root_source_file = b.path("src/main.zig"), 11 | }); 12 | 13 | const exe = b.addExecutable(.{ 14 | .name = "scc", 15 | .root_module = main, 16 | }); 17 | b.installArtifact(exe); 18 | } 19 | -------------------------------------------------------------------------------- /frontends/scc/src/Ast.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const Ast = @This(); 3 | 4 | const Parser = @import("Parser.zig"); 5 | const Tokenizer = @import("Tokenizer.zig"); 6 | 7 | source: [:0]const u8, 8 | source_path: []const u8, 9 | tokens: std.MultiArrayList(Token).Slice, 10 | nodes: std.MultiArrayList(Node).Slice, 11 | extra_data: []const Node.Index, 12 | errors: []const Error, 13 | 14 | /// Takes ownership of the `source`. 
15 | pub fn parse( 16 | gpa: std.mem.Allocator, 17 | source: [:0]const u8, 18 | source_path: []const u8, 19 | ) !Ast { 20 | var tokenizer: Tokenizer = .{ .source = source }; 21 | 22 | var tokens: std.MultiArrayList(Token) = .{}; 23 | defer tokens.deinit(gpa); 24 | while (true) { 25 | const token = tokenizer.next(); 26 | try tokens.append(gpa, token); 27 | if (token.tag == .eof) break; 28 | } 29 | 30 | var parser: Parser = .{ 31 | .gpa = gpa, 32 | .source = source, 33 | .tokens = tokens, 34 | .token_index = 0, 35 | .nodes = .{}, 36 | .errors = .{}, 37 | .scratch = .{}, 38 | .extra_data = .{}, 39 | }; 40 | defer { 41 | parser.nodes.deinit(gpa); 42 | parser.errors.deinit(gpa); 43 | parser.extra_data.deinit(gpa); 44 | parser.scratch.deinit(gpa); 45 | } 46 | 47 | try parser.parse(); 48 | 49 | const errors = try parser.errors.toOwnedSlice(gpa); 50 | errdefer gpa.free(errors); 51 | const extra_data = try parser.extra_data.toOwnedSlice(gpa); 52 | 53 | return .{ 54 | .source = source, 55 | .source_path = source_path, 56 | .tokens = tokens.toOwnedSlice(), 57 | .nodes = parser.nodes.toOwnedSlice(), 58 | .errors = errors, 59 | .extra_data = extra_data, 60 | }; 61 | } 62 | 63 | pub fn deinit(ast: *Ast, allocator: std.mem.Allocator) void { 64 | ast.tokens.deinit(allocator); 65 | ast.nodes.deinit(allocator); 66 | allocator.free(ast.errors); 67 | allocator.free(ast.extra_data); 68 | } 69 | 70 | pub fn getNode(ast: Ast, idx: Node.Index) Node { 71 | return ast.nodes.get(@intFromEnum(idx)); 72 | } 73 | 74 | fn getToken(ast: Ast, idx: Token.Index) Token { 75 | return ast.tokens.get(@intFromEnum(idx)); 76 | } 77 | 78 | pub fn ident(ast: *const Ast, idx: Token.Index) []const u8 { 79 | const token = ast.getToken(idx); 80 | return ast.source[token.loc.start..token.loc.end]; 81 | } 82 | 83 | pub fn spanToList(ast: Ast, idx: Node.Index) []const Node.Index { 84 | const root = ast.nodes.items(.data)[@intFromEnum(idx)].span; 85 | return 
ast.extra_data[@intFromEnum(root.start)..@intFromEnum(root.end)]; 86 | } 87 | 88 | pub const Token = struct { 89 | tag: Tag, 90 | loc: Loc, 91 | 92 | pub const Index = enum(u32) { _ }; 93 | 94 | pub const OptionalIndex = enum(u32) { 95 | none = std.math.maxInt(u32), 96 | _, 97 | 98 | pub fn unwrap(o: OptionalIndex) ?Index { 99 | return switch (o) { 100 | .none => null, 101 | else => @enumFromInt(@intFromEnum(o)), 102 | }; 103 | } 104 | 105 | pub fn wrap(i: Index) OptionalIndex { 106 | const wrapped: OptionalIndex = @enumFromInt(@intFromEnum(i)); 107 | std.debug.assert(wrapped != .none); 108 | return wrapped; 109 | } 110 | }; 111 | 112 | const Loc = struct { 113 | start: usize, 114 | end: usize, 115 | }; 116 | 117 | pub const keywords = std.StaticStringMap(Tag).initComptime(&.{ 118 | .{ "return", .keyword_return }, 119 | .{ "if", .keyword_if }, 120 | .{ "else", .keyword_else }, 121 | .{ "int", .keyword_int }, 122 | }); 123 | 124 | pub const Tag = enum { 125 | string_literal, 126 | number_literal, 127 | identifier, 128 | semicolon, 129 | l_brace, 130 | r_brace, 131 | l_paren, 132 | r_paren, 133 | keyword_return, 134 | keyword_if, 135 | keyword_else, 136 | keyword_int, 137 | plus, 138 | minus, 139 | asterisk, 140 | slash, 141 | equal, 142 | equal_equal, 143 | angle_bracket_left_equal, 144 | angle_bracket_left, 145 | angle_bracket_right_equal, 146 | angle_bracket_right, 147 | eof, 148 | invalid, 149 | 150 | pub fn lexeme(tag: Tag) ?[]const u8 { 151 | return switch (tag) { 152 | .semicolon => ";", 153 | .l_brace => "{", 154 | .r_brace => "}", 155 | .l_paren => "(", 156 | .r_paren => ")", 157 | .plus => "+", 158 | .minus => "-", 159 | .asterisk => "*", 160 | .slash => "/", 161 | .equal => "=", 162 | .equal_equal => "==", 163 | .angle_bracket_right => ">", 164 | .angle_bracket_right_equal => ">=", 165 | .angle_bracket_left_equal => "<=", 166 | .angle_bracket_left => "<", 167 | else => null, 168 | }; 169 | } 170 | 171 | pub fn symbol(tag: Tag) []const u8 { 172 | return 
tag.lexeme() orelse switch (tag) { 173 | .string_literal => "a string literal", 174 | .number_literal => "a number literal", 175 | .keyword_int => "int", 176 | .keyword_return => "return", 177 | .keyword_if => "if", 178 | .keyword_else => "else", 179 | .eof => "EOF", 180 | .identifier => "an identifier", 181 | .invalid => "invalid", 182 | else => std.debug.panic("tag: {s}", .{@tagName(tag)}), 183 | }; 184 | } 185 | }; 186 | }; 187 | 188 | pub const Node = struct { 189 | tag: Tag, 190 | main_token: Token.OptionalIndex, 191 | data: Data, 192 | 193 | pub const Index = enum(u32) { 194 | root, 195 | _, 196 | }; 197 | 198 | pub const OptionalIndex = enum(u32) { 199 | none = std.math.maxInt(u32), 200 | _, 201 | 202 | pub fn unwrap(o: OptionalIndex) ?Index { 203 | return switch (o) { 204 | .none => null, 205 | else => @enumFromInt(@intFromEnum(o)), 206 | }; 207 | } 208 | 209 | pub fn wrap(i: Index) OptionalIndex { 210 | return @enumFromInt(@intFromEnum(i)); 211 | } 212 | }; 213 | 214 | pub const Tag = enum { 215 | root, 216 | @"return", 217 | @"if", 218 | block, 219 | number_literal, 220 | add, 221 | sub, 222 | mul, 223 | div, 224 | group, 225 | assign, 226 | identifier, 227 | equal, 228 | greater_than, 229 | greater_or_equal, 230 | less_or_equal, 231 | less_than, 232 | }; 233 | 234 | pub const Span = struct { 235 | start: Node.Index, 236 | end: Node.Index, 237 | }; 238 | 239 | pub const Data = union(enum) { 240 | un_op: Node.Index, 241 | bin_op: struct { 242 | lhs: Node.Index, 243 | rhs: Node.Index, 244 | }, 245 | cond_br: struct { 246 | pred: Node.Index, 247 | then: Node.Index, 248 | @"else": Node.Index, 249 | }, 250 | token_and_node: struct { 251 | Token.Index, 252 | Node.Index, 253 | }, 254 | span: Span, 255 | int: i64, 256 | }; 257 | }; 258 | 259 | pub const Error = struct { 260 | tag: Tag, 261 | token: Token.Index, 262 | extra: Extra = .{ .none = {} }, 263 | 264 | pub const Tag = enum { 265 | expected_expression, 266 | expected_statement, 267 | expected_token, 268 | 
chained_comparison_operators, 269 | }; 270 | 271 | const Extra = union { 272 | none: void, 273 | expected_tag: Token.Tag, 274 | }; 275 | 276 | pub fn render(err: Error, ast: Ast, stderr: anytype) !void { 277 | const ttyconf = std.zig.Color.get_tty_conf(.auto); 278 | try ttyconf.setColor(stderr, .bold); 279 | 280 | // Somehow an invalid token. 281 | if (@intFromEnum(err.token) >= ast.tokens.len) { 282 | try ttyconf.setColor(stderr, .red); 283 | try stderr.writeAll("error: "); 284 | try ttyconf.setColor(stderr, .reset); 285 | try ttyconf.setColor(stderr, .bold); 286 | try stderr.writeAll("unexpected EOF\n"); 287 | try ttyconf.setColor(stderr, .reset); 288 | return; 289 | } 290 | 291 | const token = ast.getToken(err.token); 292 | const byte_offset = token.loc.start; 293 | const err_loc = std.zig.findLineColumn(ast.source, byte_offset); 294 | 295 | try stderr.print("{s}:{d}:{d}: ", .{ 296 | ast.source_path, 297 | err_loc.line + 1, 298 | err_loc.column + 1, 299 | }); 300 | try ttyconf.setColor(stderr, .red); 301 | try stderr.writeAll("error: "); 302 | try ttyconf.setColor(stderr, .reset); 303 | 304 | try ttyconf.setColor(stderr, .bold); 305 | try err.write(ast, stderr); 306 | try stderr.writeByte('\n'); 307 | try ttyconf.setColor(stderr, .reset); 308 | } 309 | 310 | fn write(err: Error, ast: Ast, stderr: anytype) !void { 311 | const token_tags = ast.tokens.items(.tag); 312 | switch (err.tag) { 313 | .expected_expression => { 314 | const found_tag = token_tags[@intFromEnum(err.token)]; 315 | return stderr.print( 316 | "expected expression, found '{s}'", 317 | .{found_tag.symbol()}, 318 | ); 319 | }, 320 | .expected_statement => { 321 | const found_tag = token_tags[@intFromEnum(err.token)]; 322 | return stderr.print( 323 | "expected statement, found '{s}'", 324 | .{found_tag.symbol()}, 325 | ); 326 | }, 327 | .chained_comparison_operators => { 328 | return stderr.writeAll("comparison operators cannot be chained"); 329 | }, 330 | .expected_token => { 331 | const found_tag 
= token_tags[@intFromEnum(err.token)]; 332 | const expected_symbol = err.extra.expected_tag.symbol(); 333 | switch (found_tag) { 334 | .invalid => return stderr.print( 335 | "expected '{s}', found invalid bytes", 336 | .{expected_symbol}, 337 | ), 338 | else => return stderr.print( 339 | "expected '{s}', found '{s}'", 340 | .{ expected_symbol, found_tag.symbol() }, 341 | ), 342 | } 343 | }, 344 | } 345 | } 346 | }; 347 | -------------------------------------------------------------------------------- /frontends/scc/src/CodeGen.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const zob = @import("zob"); 3 | const Ast = @import("Ast.zig"); 4 | const CodeGen = @This(); 5 | 6 | const Oir = zob.Oir; 7 | const Recursive = zob.Recursive; 8 | 9 | gpa: std.mem.Allocator, 10 | oir: *Oir, 11 | ast: *const Ast, 12 | ctrl_class: ?Oir.Class.Index, 13 | exits: *std.ArrayListUnmanaged(Oir.Class.Index), 14 | node_to_class: std.AutoHashMapUnmanaged(Ast.Node.Index, Oir.Class.Index), 15 | symbol_table: SymbolTable, 16 | scratch: std.ArrayListUnmanaged(Oir.Class.Index), 17 | 18 | const Error = error{OutOfMemory}; 19 | const SymbolTable = std.ArrayListUnmanaged(std.StringHashMapUnmanaged(Oir.Class.Index)); 20 | 21 | pub fn init(oir: *Oir, gpa: std.mem.Allocator, ast: *const Ast) !CodeGen { 22 | const start_class = try oir.add(.{ 23 | .tag = .start, 24 | .data = .{ .list = .{ .start = 0, .end = 0 } }, 25 | }); 26 | const ctrl_class = try oir.add(.project(0, start_class, .ctrl)); 27 | 28 | var symbol_table: SymbolTable = .{}; 29 | try symbol_table.append(gpa, .{}); 30 | 31 | return .{ 32 | .gpa = gpa, 33 | .oir = oir, 34 | .ast = ast, 35 | .ctrl_class = ctrl_class, 36 | .node_to_class = .{}, 37 | .exits = &oir.exit_list, 38 | .scratch = .{}, 39 | .symbol_table = symbol_table, 40 | }; 41 | } 42 | 43 | pub fn build(cg: *CodeGen) !Recursive { 44 | const stdout = std.io.getStdOut().writer(); 45 | 46 | try 
cg.buildBlock(.root); 47 | try cg.oir.rebuild(); 48 | 49 | try stdout.writeAll("unoptimized OIR:\n"); 50 | try cg.oir.print(stdout); 51 | try stdout.writeAll("end OIR\n"); 52 | 53 | try cg.oir.optimize(.saturate, false); 54 | try cg.oir.dump("test.dot"); 55 | 56 | const extracted = try cg.oir.extract(.auto); 57 | 58 | try stdout.writeAll("optimized OIR:\n"); 59 | try extracted.print(stdout); 60 | try stdout.writeAll("end OIR\n"); 61 | 62 | return extracted; 63 | } 64 | 65 | fn buildStatement(cg: *CodeGen, idx: Ast.Node.Index) Error!void { 66 | const ast = cg.ast; 67 | const tag = ast.nodes.items(.tag)[@intFromEnum(idx)]; 68 | 69 | switch (tag) { 70 | .block => try cg.buildBlock(idx), 71 | .@"return" => try cg.buildReturn(idx), 72 | .@"if" => try cg.buildIf(idx), 73 | .assign => try cg.buildAssign(idx), 74 | else => std.debug.panic("TODO: buildStatement {s}", .{@tagName(tag)}), 75 | } 76 | } 77 | 78 | fn buildExpression(cg: *CodeGen, idx: Ast.Node.Index) Error!Oir.Class.Index { 79 | const ast = cg.ast; 80 | const tag = ast.nodes.items(.tag)[@intFromEnum(idx)]; 81 | const main_token = ast.nodes.items(.main_token)[@intFromEnum(idx)]; 82 | const data = ast.nodes.items(.data)[@intFromEnum(idx)]; 83 | 84 | if (cg.node_to_class.get(idx)) |c| return c; 85 | 86 | const class = switch (tag) { 87 | .add, 88 | .sub, 89 | .mul, 90 | .div, 91 | .equal, 92 | .greater_than, 93 | => c: { 94 | const bin_op = data.bin_op; 95 | const lhs = try cg.buildExpression(bin_op.lhs); 96 | const rhs = try cg.buildExpression(bin_op.rhs); 97 | const class = try cg.oir.add(.binOp( 98 | switch (tag) { 99 | .add => .add, 100 | .sub => .sub, 101 | .mul => .mul, 102 | .div => .div_trunc, 103 | .equal => .cmp_eq, 104 | .greater_than => .cmp_gt, 105 | else => unreachable, 106 | }, 107 | lhs, 108 | rhs, 109 | )); 110 | break :c class; 111 | }, 112 | .group => try cg.buildExpression(data.un_op), 113 | .number_literal => try cg.oir.add(.constant(data.int)), 114 | .identifier => c: { 115 | const ident = 
/// Resolves `ident` against the scope stack, innermost scope first.
/// Returns a mutable pointer into the owning scope's map (so callers can
/// rebind the identifier in place), or null when the name is unbound.
fn findIdentifier(cg: *CodeGen, ident: []const u8) ?*Oir.Class.Index {
    const scopes = cg.symbol_table.items;
    var remaining = scopes.len;
    while (remaining > 0) {
        remaining -= 1;
        if (scopes[remaining].getPtr(ident)) |class| return class;
    }
    return null;
}
/// Lowers an assignment: evaluates the right-hand side, then binds the
/// result to the identifier. An existing binding in any enclosing scope is
/// overwritten in place; otherwise the name is declared in the innermost
/// scope.
fn buildAssign(cg: *CodeGen, idx: Ast.Node.Index) Error!void {
    const ident_token, const value_node = cg.ast.getNode(idx).data.token_and_node;

    const value_class = try cg.buildExpression(value_node);
    const name = cg.ast.ident(ident_token);

    if (cg.findIdentifier(name)) |slot| {
        // Rebind in whichever scope already owns the name.
        slot.* = value_class;
        return;
    }
    const innermost = &cg.symbol_table.items[cg.symbol_table.items.len - 1];
    try innermost.put(cg.gpa, name, value_class);
}
/// Entry point: parses the whole token stream into `p.nodes`.
/// Node 0 is the root; its span is patched at the end to cover every
/// top-level statement. A recoverable parse error returns normally with at
/// least one message recorded in `p.errors`; other errors propagate.
pub fn parse(p: *Parser) !void {
    // Reserve index 0 for the root node; its data is filled in below.
    try p.nodes.append(p.gpa, .{
        .tag = .root,
        .main_token = .none,
        .data = undefined,
    });

    var top_level: std.ArrayListUnmanaged(Node.Index) = .{};
    defer top_level.deinit(p.gpa);

    while (p.tokens.get(p.token_index).tag != .eof) {
        if (p.parseStatement()) |stmt| {
            try top_level.append(p.gpa, stmt);
        } else |err| switch (err) {
            error.ParserError => {
                // The failing path is required to have recorded a message.
                std.debug.assert(p.errors.items.len > 0);
                return;
            },
            else => |e| return e,
        }
    }

    p.nodes.items(.data)[0] = .{ .span = try p.listToSpan(top_level.items) };
}
| 80 | return p.addNode(.{ 81 | .tag = .@"if", 82 | .main_token = .wrap(main_token), 83 | .data = .{ .cond_br = .{ 84 | .pred = predicate, 85 | .then = then_body, 86 | .@"else" = else_body, 87 | } }, 88 | }); 89 | }, 90 | .keyword_int => { 91 | _ = try p.expectToken(.keyword_int); 92 | const ident_token = try p.expectToken(.identifier); 93 | const assign_token = try p.expectToken(.equal); 94 | 95 | const rvalue = try p.parseExpression(0); 96 | _ = try p.expectToken(.semicolon); 97 | 98 | return p.addNode(.{ 99 | .tag = .assign, 100 | .main_token = .wrap(assign_token), 101 | .data = .{ .token_and_node = .{ 102 | ident_token, 103 | rvalue, 104 | } }, 105 | }); 106 | }, 107 | .identifier => { 108 | const ident_token = try p.expectToken(.identifier); 109 | const assign_token = try p.expectToken(.equal); 110 | const rvalue = try p.parseExpression(0); 111 | _ = try p.expectToken(.semicolon); 112 | 113 | return p.addNode(.{ 114 | .tag = .assign, 115 | .main_token = .wrap(assign_token), 116 | .data = .{ .token_and_node = .{ 117 | ident_token, 118 | rvalue, 119 | } }, 120 | }); 121 | }, 122 | .eof => break, 123 | else => return p.failMsg(.{ 124 | .tag = .expected_statement, 125 | .token = @enumFromInt(p.token_index), 126 | }), 127 | } 128 | } 129 | 130 | @panic("TODO"); 131 | } 132 | 133 | const Assoc = enum { 134 | left, 135 | none, 136 | }; 137 | 138 | const OperatorInfo = struct { 139 | prec: i8, 140 | tag: Node.Tag, 141 | assoc: Assoc = Assoc.left, 142 | }; 143 | 144 | const operatorTable = std.enums.directEnumArrayDefault( 145 | Token.Tag, 146 | OperatorInfo, 147 | .{ .prec = -1, .tag = Node.Tag.root }, 148 | 0, 149 | .{ 150 | .equal_equal = .{ .prec = 30, .tag = .equal, .assoc = Assoc.none }, 151 | .angle_bracket_left = .{ .prec = 30, .tag = .less_than, .assoc = Assoc.none }, 152 | .angle_bracket_right = .{ .prec = 30, .tag = .greater_than, .assoc = Assoc.none }, 153 | .angle_bracket_left_equal = .{ .prec = 30, .tag = .less_or_equal, .assoc = Assoc.none }, 154 | 
/// Precedence-climbing binary expression parser.
/// `min_prec` is the lowest operator precedence this call may consume;
/// non-associative operators (the comparisons) may not chain, which is
/// tracked via `banned_prec` and reported as `chained_comparison_operators`.
fn parseExpression(p: *Parser, min_prec: i32) Error!Node.Index {
    std.debug.assert(min_prec >= 0);

    var lhs = try p.parsePrimaryExpression();
    // Precedence level at which another operator would illegally chain a
    // non-associative comparison (e.g. `a == b == c`); -1 means none yet.
    var banned_prec: i8 = -1;

    while (true) {
        const next_tag = p.tokens.items(.tag)[p.token_index];
        const op = operatorTable[@intFromEnum(next_tag)];
        if (op.prec < min_prec) break;
        if (op.prec == banned_prec) return p.fail(.chained_comparison_operators);

        const op_token = p.nextToken();
        // `+ 1` makes equal-precedence operators associate to the left.
        const rhs = try p.parseExpression(op.prec + 1);

        lhs = try p.addNode(.{
            .tag = op.tag,
            .main_token = .wrap(op_token),
            .data = .{ .bin_op = .{
                .lhs = lhs,
                .rhs = rhs,
            } },
        });

        if (op.assoc == Assoc.none) banned_prec = op.prec;
    }

    return lhs;
}
.data = undefined, 221 | }), 222 | else => return p.failMsg(.{ 223 | .tag = .expected_expression, 224 | .token = @enumFromInt(p.token_index), 225 | }), 226 | } 227 | } 228 | 229 | fn parseNumber(p: *Parser, node: Token.Index) Error!i64 { 230 | const string = p.ident(node); 231 | return std.fmt.parseInt(i64, string, 10); 232 | } 233 | 234 | fn parseCompoundStatement(p: *Parser) Error!Node.Index { 235 | const scratch_top = p.scratch.items.len; 236 | defer p.scratch.shrinkRetainingCapacity(scratch_top); 237 | 238 | _ = try p.expectToken(.l_brace); 239 | 240 | while (true) { 241 | if (p.tokens.get(p.token_index).tag == .r_brace) break; 242 | const stmt = try p.parseStatement(); 243 | try p.scratch.append(p.gpa, stmt); 244 | } 245 | 246 | _ = try p.expectToken(.r_brace); 247 | 248 | const items = p.scratch.items[scratch_top..]; 249 | return p.addNode(.{ 250 | .tag = .block, 251 | .main_token = .none, 252 | .data = .{ .span = try p.listToSpan(items) }, 253 | }); 254 | } 255 | 256 | fn eatToken(p: *Parser, tag: Token.Tag) ?Token.Index { 257 | return if (p.tokens.get(p.token_index).tag == tag) p.nextToken() else null; 258 | } 259 | 260 | fn expectToken(p: *Parser, tag: Token.Tag) !Token.Index { 261 | const token = p.tokens.get(p.token_index); 262 | if (token.tag != tag) { 263 | return p.failExpecting(tag); 264 | } 265 | return p.nextToken(); 266 | } 267 | 268 | fn nextToken(p: *Parser) Token.Index { 269 | const result = p.token_index; 270 | p.token_index += 1; 271 | return @enumFromInt(result); 272 | } 273 | 274 | fn getToken(p: *Parser, idx: Token.Index) Token { 275 | return p.tokens.get(@intFromEnum(idx)); 276 | } 277 | 278 | fn ident(p: *Parser, idx: Token.Index) []const u8 { 279 | const token = p.getToken(idx); 280 | return p.source[token.loc.start..token.loc.end]; 281 | } 282 | 283 | fn fail(p: *Parser, msg: Ast.Error.Tag) error{ ParserError, OutOfMemory } { 284 | return p.failMsg(.{ .tag = msg, .token = @enumFromInt(p.token_index) }); 285 | } 286 | 287 | fn 
/// Appends `list` to the shared `extra_data` array and returns the span
/// covering the freshly appended run.
fn listToSpan(p: *Parser, list: []const Node.Index) !Node.Span {
    const first = p.extra_data.items.len;
    try p.extra_data.appendSlice(p.gpa, list);
    const past_last = p.extra_data.items.len;
    return .{
        .start = @enumFromInt(first),
        .end = @enumFromInt(past_last),
    };
}
.invalid; 45 | }, 46 | 'a'...'z', 'A'...'Z', '_' => { 47 | result.tag = .identifier; 48 | continue :state .identifier; 49 | }, 50 | ';' => { 51 | result.tag = .semicolon; 52 | self.index += 1; 53 | }, 54 | '{' => { 55 | result.tag = .l_brace; 56 | self.index += 1; 57 | }, 58 | '}' => { 59 | result.tag = .r_brace; 60 | self.index += 1; 61 | }, 62 | '(' => { 63 | result.tag = .l_paren; 64 | self.index += 1; 65 | }, 66 | ')' => { 67 | result.tag = .r_paren; 68 | self.index += 1; 69 | }, 70 | '+' => continue :state .plus, 71 | '-' => continue :state .minus, 72 | '/' => continue :state .slash, 73 | '*' => continue :state .asterisk, 74 | '=' => continue :state .equal, 75 | '>' => continue :state .angle_bracket_right, 76 | '<' => continue :state .angle_bracket_left, 77 | ' ', '\n', '\t', '\r' => { 78 | self.index += 1; 79 | result.loc.start = self.index; 80 | continue :state .start; 81 | }, 82 | '0'...'9' => { 83 | result.tag = .number_literal; 84 | self.index += 1; 85 | continue :state .int; 86 | }, 87 | else => continue :state .invalid, 88 | }, 89 | .identifier => { 90 | self.index += 1; 91 | switch (self.source[self.index]) { 92 | 'a'...'z', 'A'...'Z', '_', '0'...'9' => continue :state .identifier, 93 | else => { 94 | if (Token.keywords.get(self.source[result.loc.start..self.index])) |tag| { 95 | result.tag = tag; 96 | } 97 | }, 98 | } 99 | }, 100 | .plus, 101 | .minus, 102 | .asterisk, 103 | .slash, 104 | => |t| { 105 | self.index += 1; 106 | switch (self.source[self.index]) { 107 | else => result.tag = switch (t) { 108 | .plus => .plus, 109 | .minus => .minus, 110 | .asterisk => .asterisk, 111 | .slash => .slash, 112 | else => unreachable, 113 | }, 114 | } 115 | }, 116 | .equal => { 117 | self.index += 1; 118 | switch (self.source[self.index]) { 119 | '=' => { 120 | result.tag = .equal_equal; 121 | self.index += 1; 122 | }, 123 | else => result.tag = .equal, 124 | } 125 | }, 126 | .int => switch (self.source[self.index]) { 127 | '0'...'9' => { 128 | self.index += 1; 
129 | continue :state .int; 130 | }, 131 | else => {}, 132 | }, 133 | .angle_bracket_left => { 134 | self.index += 1; 135 | switch (self.source[self.index]) { 136 | '=' => { 137 | result.tag = .angle_bracket_left_equal; 138 | self.index += 1; 139 | }, 140 | else => result.tag = .angle_bracket_left, 141 | } 142 | }, 143 | .angle_bracket_right => { 144 | self.index += 1; 145 | switch (self.source[self.index]) { 146 | '=' => { 147 | result.tag = .angle_bracket_right_equal; 148 | self.index += 1; 149 | }, 150 | else => result.tag = .angle_bracket_right, 151 | } 152 | }, 153 | .invalid => { 154 | self.index += 1; 155 | switch (self.source[self.index]) { 156 | 0 => if (self.index == self.source.len) { 157 | result.tag = .invalid; 158 | } else continue :state .invalid, 159 | '\n' => result.tag = .invalid, 160 | else => continue :state .invalid, 161 | } 162 | }, 163 | } 164 | 165 | result.loc.end = self.index; 166 | return result; 167 | } 168 | 169 | fn testTokenize(source: [:0]const u8, expected_token_tags: []const Token.Tag) !void { 170 | var tokenizer: Tokenizer = .{ .source = source }; 171 | for (expected_token_tags) |expected_token_tag| { 172 | const token = tokenizer.next(); 173 | try std.testing.expectEqual(expected_token_tag, token.tag); 174 | } 175 | const last_token = tokenizer.next(); 176 | try std.testing.expectEqual(Token.Tag.eof, last_token.tag); 177 | try std.testing.expectEqual(source.len, last_token.loc.start); 178 | try std.testing.expectEqual(source.len, last_token.loc.end); 179 | } 180 | 181 | test "basic" { 182 | try testTokenize("10", &.{.number_literal}); 183 | } 184 | 185 | test "keywords" { 186 | try testTokenize("return", &.{.keyword_return}); 187 | } 188 | 189 | test "block" { 190 | try testTokenize("{ return 10; }", &.{ 191 | .l_brace, 192 | .keyword_return, 193 | .number_literal, 194 | .semicolon, 195 | .r_brace, 196 | }); 197 | } 198 | 199 | test "operation" { 200 | try testTokenize("{ return 10 + 20; }", &.{ 201 | .l_brace, 202 | 
.keyword_return, 203 | .number_literal, 204 | .plus, 205 | .number_literal, 206 | .semicolon, 207 | .r_brace, 208 | }); 209 | 210 | try testTokenize("{ return 10 == 20; }", &.{ 211 | .l_brace, 212 | .keyword_return, 213 | .number_literal, 214 | .equal_equal, 215 | .number_literal, 216 | .semicolon, 217 | .r_brace, 218 | }); 219 | } 220 | -------------------------------------------------------------------------------- /frontends/scc/src/main.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const zob = @import("zob"); 3 | const builtin = @import("builtin"); 4 | 5 | const Ast = @import("Ast.zig"); 6 | const CodeGen = @import("CodeGen.zig"); 7 | 8 | pub const std_options: std.Options = .{ 9 | .log_level = .err, 10 | }; 11 | 12 | pub fn main() !void { 13 | var gpa = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 16 }){}; 14 | defer _ = gpa.deinit(); 15 | const allocator = gpa.allocator(); 16 | 17 | var args = try std.process.argsWithAllocator(allocator); 18 | defer args.deinit(); 19 | _ = args.skip(); 20 | 21 | var input_path: ?[]const u8 = null; 22 | while (args.next()) |arg| { 23 | if (input_path != null) @panic("two file inputs"); 24 | input_path = arg; 25 | } 26 | 27 | if (input_path == null) @panic("no file provided"); 28 | const source = try std.fs.cwd().readFileAllocOptions( 29 | allocator, 30 | input_path.?, 31 | 10 * 1024 * 1024, 32 | null, 33 | .@"1", 34 | 0, 35 | ); 36 | defer allocator.free(source); 37 | 38 | var ast = try Ast.parse(allocator, source, input_path.?); 39 | defer ast.deinit(allocator); 40 | 41 | if (ast.errors.len != 0) { 42 | const stderr = std.io.getStdErr(); 43 | for (ast.errors) |err| { 44 | try err.render(ast, stderr.writer()); 45 | } 46 | fail("failed with {d} error(s)", .{ast.errors.len}); 47 | } 48 | 49 | var oir: zob.Oir = .init(allocator); 50 | defer oir.deinit(); 51 | 52 | var cg: CodeGen = try .init(&oir, allocator, &ast); 53 | defer 
cg.deinit(allocator); 54 | 55 | var recv = try cg.build(); 56 | defer recv.deinit(allocator); 57 | } 58 | 59 | fn fail(comptime fmt: []const u8, args: anytype) noreturn { 60 | const stderr = std.io.getStdErr().writer(); 61 | stderr.print(fmt ++ "\n", args) catch @panic("failed to print the stderr"); 62 | std.posix.abort(); 63 | } 64 | 65 | test { 66 | _ = std.testing.refAllDecls(Ast); 67 | _ = std.testing.refAllDecls(CodeGen); 68 | } 69 | -------------------------------------------------------------------------------- /render.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -r graphs/*.png 4 | 5 | for file in graphs/*.dot; do 6 | [ -e "$file" ] || continue 7 | basename="${file%.dot}" 8 | output="${basename}.png" 9 | dot "$file" -Tpng -o "$output" 10 | echo "Converted $file to $output" 11 | done 12 | -------------------------------------------------------------------------------- /src/Oir.zig: -------------------------------------------------------------------------------- 1 | //! Optimizable Intermediate Representation 2 | 3 | allocator: std.mem.Allocator, 4 | 5 | /// The list of all E-Nodes in the graph. Each E-Node represents a potential state of the E-Class 6 | /// they are in. After all optimizations we want have completed, the extractor will be used to 7 | /// iterate through all E-Classes and extract the best node within. 8 | nodes: std.AutoArrayHashMapUnmanaged(Node, void), 9 | 10 | /// Used for storing dynamic and temporary data. Things like the Node `list` payload are stored 11 | /// on this array. We can assume that it'll live as long as the OIR does. 12 | extra: std.ArrayListUnmanaged(u32), 13 | 14 | /// Represents the list of all E-Classes in the graph. 15 | /// 16 | /// Each E-Class contains a bundle of nodes which are equivalent to each other. 17 | classes: std.AutoHashMapUnmanaged(Class.Index, Class), 18 | 19 | /// A map relating nodes to the classes they are in. 
Used as a fast way to determine 20 | /// what "parent" class a node is in. 21 | node_to_class: std.HashMapUnmanaged( 22 | Node.Index, 23 | Class.Index, 24 | NodeContext, 25 | std.hash_map.default_max_load_percentage, 26 | ), 27 | 28 | union_find: UnionFind, 29 | 30 | /// A list of pending `Pair`s which have made the E-Graph unclean. This is a part of incremental 31 | /// rebuilding and lets the graph process faster. `add` and `union` dirty the graph, marking `clean` 32 | /// as false, and then `rebuild` will iterate through the pending items to analyze and mark `clean` 33 | /// as true. 34 | pending: std.ArrayListUnmanaged(Pair), 35 | 36 | /// Indicates whether or not reading type operations are allowed on the E-Graph. 37 | /// 38 | /// Mutating operations set this to `false`, and `rebuild` will set it back to `true`. 39 | clean: bool, 40 | trace: Trace, 41 | 42 | /// A list of classes/nodes which act as exits from the function. This will usually 43 | /// be `ret` nodes. We use it later in the extraction to understand where to start 44 | /// looking for the best path. 45 | exit_list: std.ArrayListUnmanaged(Class.Index), 46 | 47 | const UnionFind = struct { 48 | parents: std.ArrayListUnmanaged(Class.Index) = .{}, 49 | 50 | fn makeSet(f: *UnionFind, gpa: std.mem.Allocator) !Class.Index { 51 | const id: Class.Index = @enumFromInt(f.parents.items.len); 52 | try f.parents.append(gpa, id); 53 | return id; 54 | } 55 | 56 | pub fn find(f: *const UnionFind, idx: Class.Index) Class.Index { 57 | var current = idx; 58 | while (current != f.parent(current)) { 59 | current = f.parent(current); 60 | } 61 | return current; 62 | } 63 | 64 | /// Same thing as `find` but performs path-compression. 
/// Same as `find` but performs path compression (path halving): every class
/// visited on the way to the root is re-pointed at its grandparent, so later
/// lookups traverse shorter chains.
fn findMutable(f: *UnionFind, idx: Class.Index) Class.Index {
    var current = idx;
    while (current != f.parent(current)) {
        const grandparent = f.parent(f.parent(current));
        // BUGFIX: compress the node we are standing on, not the original
        // argument. Writing through `idx` on every iteration only shortened
        // the first link and left the rest of the chain uncompressed.
        f.parents.items[@intFromEnum(current)] = grandparent;
        current = grandparent;
    }
    return current;
}
131 | start, 132 | _, 133 | }; 134 | 135 | const Type = enum { 136 | ctrl, 137 | data, 138 | }; 139 | 140 | pub const Tag = enum(u8) { 141 | /// Constant integer. 142 | constant, 143 | /// Projection extracts a field from a tuple. 144 | project, 145 | 146 | // Control flow 147 | /// There can only ever be one `start` node in the function. 148 | /// The inputs to the start node is a list of the return values. 149 | /// The output of the start node is a list of the arguments to the function. 150 | start, 151 | /// The return nodes are input to the `start` node in the function 152 | /// 153 | /// This node uses the `bin_op` payload, where the first item is the preceding 154 | /// control node, and the second item is the data node which represents 155 | /// the return value. 156 | ret, 157 | branch, 158 | region, 159 | 160 | // Integer arthimatics. 161 | add, 162 | @"and", 163 | sub, 164 | mul, 165 | shl, 166 | shr, 167 | div_trunc, 168 | div_exact, 169 | 170 | // Compare 171 | cmp_eq, 172 | cmp_gt, 173 | 174 | load, 175 | store, 176 | 177 | dead, 178 | 179 | pub fn isCanonical(tag: Tag) bool { 180 | return switch (tag) { 181 | .constant, 182 | .start, 183 | => true, 184 | else => false, 185 | }; 186 | } 187 | 188 | pub fn dataType(tag: Tag) std.meta.FieldEnum(Data) { 189 | return switch (tag) { 190 | .constant, 191 | => .constant, 192 | .region, 193 | .start, 194 | => .list, 195 | .project, 196 | => .project, 197 | .cmp_gt, 198 | .cmp_eq, 199 | .@"and", 200 | .add, 201 | .sub, 202 | .mul, 203 | .shl, 204 | .shr, 205 | .div_trunc, 206 | .div_exact, 207 | .store, 208 | .ret, 209 | .branch, 210 | => .bin_op, 211 | .load, 212 | => .un_op, 213 | .dead, 214 | => .none, 215 | }; 216 | } 217 | }; 218 | 219 | const Data = union(enum) { 220 | none: void, 221 | constant: i64, 222 | /// NOTE: For future reference, we use an array here so that the operands() 223 | /// function can return a slice, otherwise padding between struct elements 224 | /// would be undefined and it 
/// A half-open [start, end) range into a container's "extra" array of u32s.
pub const Span = struct {
    start: u32,
    end: u32,

    // The zero-length span; resolves to an empty slice.
    pub const empty: Span = .{ .start = 0, .end = 0 };

    /// Resolves the span against any value exposing `extra.items`
    /// (duck-typed so both Oir and extracted representations work).
    pub fn toSlice(span: Span, repr: anytype) []const u32 {
        const all = repr.extra.items;
        return all[span.start..span.end];
    }
};
nodeType(node: Node) Type { 291 | return switch (node.tag) { 292 | .constant, 293 | .load, 294 | .store, 295 | .cmp_gt, 296 | .cmp_eq, 297 | .@"and", 298 | .add, 299 | .sub, 300 | .mul, 301 | .shl, 302 | .shr, 303 | .div_trunc, 304 | .div_exact, 305 | => .data, 306 | .start, 307 | .ret, 308 | .branch, 309 | .region, 310 | => .ctrl, 311 | .project => node.data.project.type, 312 | .dead => unreachable, 313 | }; 314 | } 315 | 316 | pub fn isVolatile(node: Node) bool { 317 | // TODO: this isn't necessarily true, but just to be safe for now. 318 | return switch (node.tag) { 319 | .start, 320 | .ret, 321 | => true, 322 | else => false, 323 | }; 324 | } 325 | 326 | pub fn mapNode( 327 | old: Node, 328 | oir: *const Oir, 329 | map: *std.AutoHashMapUnmanaged(Class.Index, Class.Index), 330 | ) !Node { 331 | var copy = old; 332 | for (copy.mutableOperands(oir)) |*op| { 333 | op.* = map.get(oir.union_find.find(op.*)).?; 334 | } 335 | return copy; 336 | } 337 | 338 | // Helper functions 339 | pub fn branch(ctrl: Class.Index, pred: Class.Index) Node { 340 | return binOp(.branch, ctrl, pred); 341 | } 342 | 343 | pub fn project(index: u32, tuple: Class.Index, ty: Type) Node { 344 | return .{ 345 | .tag = .project, 346 | .data = .{ .project = .{ 347 | .index = index, 348 | .tuple = tuple, 349 | .type = ty, 350 | } }, 351 | }; 352 | } 353 | 354 | pub fn region(span: Span) Node { 355 | return .{ 356 | .tag = .region, 357 | .data = .{ .list = span }, 358 | }; 359 | } 360 | 361 | pub fn binOp(tag: Tag, lhs: Class.Index, rhs: Class.Index) Node { 362 | assert(tag.dataType() == .bin_op); 363 | return .{ 364 | .tag = tag, 365 | .data = .{ .bin_op = .{ lhs, rhs } }, 366 | }; 367 | } 368 | 369 | pub fn constant(value: i64) Node { 370 | return .{ .tag = .constant, .data = .{ .constant = value } }; 371 | } 372 | 373 | pub fn format( 374 | _: Node, 375 | comptime _: []const u8, 376 | _: std.fmt.FormatOptions, 377 | _: anytype, 378 | ) !void { 379 | @compileError("don't format nodes directly, 
use Node.fmt"); 380 | } 381 | 382 | pub fn fmt(node: Node, oir: *const Oir) std.fmt.Formatter(format2) { 383 | return .{ .data = .{ 384 | .node = node, 385 | .oir = oir, 386 | } }; 387 | } 388 | 389 | const FormatContext = struct { 390 | node: Node, 391 | oir: *const Oir, 392 | }; 393 | 394 | pub fn format2( 395 | ctx: FormatContext, 396 | comptime _: []const u8, 397 | _: std.fmt.FormatOptions, 398 | stream: anytype, 399 | ) !void { 400 | const node = ctx.node; 401 | var writer: print_oir.Writer = .{ .nodes = ctx.oir.nodes.keys() }; 402 | try writer.printNode(node, ctx.oir, stream); 403 | } 404 | }; 405 | 406 | const Pair = struct { Node.Index, Class.Index }; 407 | 408 | /// A Class contains an N amount of Nodes as children. 409 | pub const Class = struct { 410 | index: Index, 411 | bag: std.ArrayListUnmanaged(Node.Index) = .{}, 412 | parents: std.ArrayListUnmanaged(Pair) = .{}, 413 | 414 | pub const Index = enum(u32) { 415 | /// The start node is always in the first class, and alone as it's canonical. 
416 | start, 417 | _, 418 | 419 | pub fn format( 420 | idx: Index, 421 | comptime fmt: []const u8, 422 | _: std.fmt.FormatOptions, 423 | writer: anytype, 424 | ) !void { 425 | assert(fmt.len == 0); 426 | try writer.print("%{d}", .{@intFromEnum(idx)}); 427 | } 428 | }; 429 | 430 | pub fn deinit(class: *Class, allocator: std.mem.Allocator) void { 431 | class.bag.deinit(allocator); 432 | class.parents.deinit(allocator); 433 | } 434 | }; 435 | 436 | const Pass = struct { 437 | const Error = error{ OutOfMemory, Overflow, InvalidCharacter }; 438 | 439 | name: []const u8, 440 | func: *const fn (oir: *Oir) Error!bool, 441 | }; 442 | 443 | const passes: []const Pass = &.{ 444 | .{ 445 | .name = "constant-fold", 446 | .func = @import("passes/constant_fold.zig").run, 447 | }, 448 | .{ 449 | .name = "common-rewrites", 450 | .func = @import("passes/rewrite.zig").run, 451 | }, 452 | }; 453 | 454 | pub fn optimize( 455 | oir: *Oir, 456 | mode: enum { 457 | /// Optimize until running all passes creates no new changes. 458 | /// NOTE: likely will be very slow for any large input 459 | saturate, 460 | }, 461 | /// Prints dumps a graphviz of the current OIR state after each pass iteration. 462 | output_graph: bool, 463 | ) !void { 464 | switch (mode) { 465 | .saturate => { 466 | try oir.rebuild(); 467 | assert(oir.clean); 468 | 469 | var i: u32 = 0; 470 | while (true) { 471 | var new_change: bool = false; 472 | inline for (passes) |pass| { 473 | if (output_graph) { 474 | const name = try std.fmt.allocPrint( 475 | oir.allocator, 476 | "graphs/pre_{s}_{}.dot", 477 | .{ pass.name, i }, 478 | ); 479 | defer oir.allocator.free(name); 480 | try oir.dump(name); 481 | } 482 | 483 | const trace = oir.trace.start(@src(), "{s}", .{pass.name}); 484 | defer trace.end(); 485 | 486 | if (try pass.func(oir)) new_change = true; 487 | // TODO: in theory we don't actually need to rebuild after every pass 488 | // maybe we should look into rebuilding on-demand? 
489 | if (!oir.clean) try oir.rebuild(); 490 | } 491 | 492 | i += 1; 493 | if (!new_change) break; 494 | } 495 | }, 496 | } 497 | } 498 | 499 | pub fn init(allocator: std.mem.Allocator) Oir { 500 | return .{ 501 | .allocator = allocator, 502 | .nodes = .{}, 503 | .node_to_class = .{}, 504 | .classes = .{}, 505 | .extra = .{}, 506 | .union_find = .{}, 507 | .pending = .{}, 508 | .trace = .init(), 509 | .exit_list = .{}, 510 | .clean = true, 511 | }; 512 | } 513 | 514 | pub fn dump(oir: *Oir, name: []const u8) !void { 515 | const graphviz_file = try std.fs.cwd().createFile(name, .{}); 516 | defer graphviz_file.close(); 517 | try print_oir.dumpOirGraph(oir, graphviz_file.writer()); 518 | } 519 | 520 | pub fn print(oir: *Oir, stream: anytype) !void { 521 | try print_oir.print(oir, stream); 522 | } 523 | 524 | /// Reference becomes invalid when new classes are added to the graph. 525 | pub fn getClassPtr(oir: *Oir, idx: Class.Index) *Class { 526 | const found = oir.union_find.findMutable(idx); 527 | return oir.classes.getPtr(found).?; 528 | } 529 | 530 | pub fn getClass(oir: *const Oir, idx: Class.Index) Class { 531 | const found = oir.union_find.find(idx); 532 | return oir.classes.get(found).?; 533 | } 534 | 535 | pub fn findClass(oir: *const Oir, node_idx: Node.Index) Class.Index { 536 | const memo_idx = oir.node_to_class.getContext( 537 | node_idx, 538 | .{ .oir = oir }, 539 | ).?; 540 | return oir.union_find.find(memo_idx); 541 | } 542 | 543 | pub fn findNode(oir: *const Oir, node: Node) ?Node.Index { 544 | const idx = oir.nodes.getIndex(node) orelse return null; 545 | return @enumFromInt(idx); 546 | } 547 | 548 | pub fn getNode(oir: *const Oir, idx: Node.Index) Node { 549 | return oir.nodes.keys()[@intFromEnum(idx)]; 550 | } 551 | 552 | pub fn getNodes(oir: *const Oir) []const Node { 553 | return oir.nodes.keys(); 554 | } 555 | 556 | /// Reference becomes invalid when new nodes are added to the graph. 
fn getNodePtr(oir: *const Oir, idx: Node.Index) *Node {
    const all_nodes = oir.nodes.keys();
    return &all_nodes[@intFromEnum(idx)];
}

/// Returns the type of the class.
/// If the class contains a ctrl node, all other nodes must also be control.
pub fn getClassType(oir: *const Oir, idx: Class.Index) Node.Type {
    // Every node in a class shares the same type, so the first member
    // of the bag is representative.
    const representative = oir.classes.get(idx).?.bag.items[0];
    return oir.getNode(representative).nodeType();
}

/// Adds an ENode to the EGraph, giving the node its own class.
/// Returns the EClass index the ENode was placed in.
pub fn add(oir: *Oir, node: Node) !Class.Index {
    const gop = try oir.nodes.getOrPut(oir.allocator, node);
    if (gop.found_existing) {
        // Deduplicated: resolve the canonical class of the existing node.
        const existing_class = oir.findClass(@enumFromInt(gop.index));
        return oir.union_find.find(existing_class);
    }

    const node_idx: Node.Index = @enumFromInt(gop.index);

    log.debug("adding node {} {}", .{ node.fmt(oir), node_idx });

    const new_class = try oir.addInternal(node_idx);
    return oir.union_find.find(new_class);
}

/// An internal function to simplify adding nodes to the Oir.
///
/// It should be used carefully as it invalidates the equality invariant of the graph.
fn addInternal(oir: *Oir, node: Node.Index) !Class.Index {
    if (oir.node_to_class.getContext(
        node,
        .{ .oir = oir },
    )) |class_idx| {
        // Hashcons hit: the node already belongs to a class.
        return class_idx;
    } else {
        const id = try oir.makeClass(node);
        // A new class was created, so graph invariants may be violated
        // until the next `rebuild`.
        oir.clean = false;
        return id;
    }
}

/// Creates a new singleton class containing `node_idx`, registering the node
/// as a parent of each of its operand classes and scheduling it for repair.
fn makeClass(oir: *Oir, node_idx: Node.Index) !Class.Index {
    const id = try oir.union_find.makeSet(oir.allocator);
    log.debug("adding {} to {}", .{ node_idx, id });

    var class: Class = .{
        .index = id,
        .bag = .{},
    };

    try class.bag.append(oir.allocator, node_idx);

    // Record this node as a parent of every operand class, so later unions
    // know which parents need their hashes repaired.
    const node = oir.getNode(node_idx);
    for (node.operands(oir)) |child| {
        const class_ptr = oir.getClassPtr(child);
        try class_ptr.parents.append(oir.allocator, .{ node_idx, id });
    }

    try oir.pending.append(oir.allocator, .{ node_idx, id });
    try oir.classes.put(oir.allocator, id, class);
    try oir.node_to_class.putNoClobberContext(oir.allocator, node_idx, id, .{ .oir = oir });

    return id;
}

/// Performs the "union" operation on the graph.
///
/// Returns `true` if the two classes were merged, and `false` if they were
/// already equivalent so no merge was needed.
///
/// This can be thought of as "merging" two classes when they were proven to be equivalent.
pub fn @"union"(oir: *Oir, a_idx: Class.Index, b_idx: Class.Index) !bool {
    oir.clean = false;
    var a = oir.union_find.findMutable(a_idx);
    var b = oir.union_find.findMutable(b_idx);
    if (a == b) return false;

    log.debug("union on {} -> {}", .{ b, a });

    // Merging a ctrl class with a data class would corrupt the graph.
    assert(oir.getClassType(a) == oir.getClassType(b));

    const a_parents = oir.classes.get(a).?.parents.items.len;
    const b_parents = oir.classes.get(b).?.parents.items.len;

    // Keep the class with more parents as the leader so less data moves.
    if (a_parents < b_parents) {
        std.mem.swap(Class.Index, &a, &b);
    }

    // make `a` the leader class
    _ = oir.union_find.@"union"(a, b);

    var b_class = oir.classes.fetchRemove(b).?.value;
    defer b_class.deinit(oir.allocator);

    const a_class = oir.classes.getPtr(a).?;
    assert(a == a_class.index);

    // The absorbed class's parents must be re-canonicalized later, and its
    // nodes and parents now belong to the leader.
    try oir.pending.appendSlice(oir.allocator, b_class.parents.items);
    try a_class.bag.appendSlice(oir.allocator, b_class.bag.items);
    try a_class.parents.appendSlice(oir.allocator, b_class.parents.items);

    return true;
}

/// Performs a rebuild of the E-Graph to ensure that its invariants are maintained.
///
/// This looks over hashes of the nodes and merges duplicate nodes.
/// We can hash based on the class indices themselves, as they don't change during the
/// rebuild.
pub fn rebuild(oir: *Oir) !void {
    const trace = oir.trace.start(@src(), "rebuilding", .{});
    defer trace.end();
    log.debug("rebuilding", .{});

    while (oir.pending.pop()) |pair| {
        const node_idx, const class_idx = pair;

        // before modifying the node in-place, we must remove it from the hashmap
        // in order to not get a stale hash.
        assert(oir.node_to_class.removeContext(node_idx, .{ .oir = oir }));

        // Canonicalize every operand to its current class leader.
        const node = oir.getNodePtr(node_idx);
        for (node.mutableOperands(oir)) |*id| {
            id.* = oir.union_find.findMutable(id.*);
        }

        // Re-inserting may collide with an equal node; merge their classes.
        if (try oir.node_to_class.fetchPutContext(
            oir.allocator,
            node_idx,
            class_idx,
            .{ .oir = oir },
        )) |prev| {
            _ = try oir.@"union"(prev.value, class_idx);
        }
    }

    var iter = oir.classes.iterator();
    while (iter.next()) |entry| {
        for (entry.value_ptr.bag.items) |node_idx| {
            // NOTE: if this assert fails, you've modified the underlying data of a node
            assert(oir.node_to_class.removeContext(node_idx, .{ .oir = oir }));

            const node = oir.getNodePtr(node_idx);
            for (node.mutableOperands(oir)) |*child| {
                child.* = oir.union_find.findMutable(child.*);
            }

            // place the newly changed node back on the map
            try oir.node_to_class.putNoClobberContext(
                oir.allocator,
                node_idx,
                entry.key_ptr.*,
                .{ .oir = oir },
            );
        }
    }

    try oir.verifyNodes();
    assert(oir.pending.items.len == 0);
    oir.clean = true;
}

/// Finds cyclic nodes via an iterative three-color DFS over the classes.
/// The key of the returned map is a cyclic node's index; the value is the
/// class the node was found in. Caller owns the returned map.
pub fn findCycles(oir: *const Oir) !std.AutoHashMapUnmanaged(Node.Index, Class.Index) {
    const allocator = oir.allocator;

    const Color = enum {
        white,
        gray,
        black,
    };

    // Each stack entry is (enter?, class): `true` on first visit,
    // `false` when unwinding to mark the class fully explored.
    var stack = try std.ArrayList(struct {
        bool,
        Class.Index,
    }).initCapacity(allocator, oir.classes.size);
    defer stack.deinit();

    var color = std.AutoHashMap(Class.Index, Color).init(allocator);
    defer color.deinit();

    var iter = oir.classes.valueIterator();
    while (iter.next()) |class| {
        stack.appendAssumeCapacity(.{ true, class.index });
        try color.put(class.index, .white);
    }

    var cycles: std.AutoHashMapUnmanaged(Node.Index, Class.Index) = .{};
    while (stack.pop()) |entry| {
        const enter, const id = entry;
        if (enter) {
            color.getPtr(id).?.* = .gray;
            try stack.append(.{ false, id });

            const class = oir.getClass(id);
            for (class.bag.items) |node_idx| {
                const node = oir.getNode(node_idx);
                for (node.operands(oir)) |child| {
                    const child_color = color.get(child).?;
                    switch (child_color) {
                        .white => try stack.append(.{ true, child }),
                        // Edge back into a class still on the DFS path: a cycle.
                        .gray => try cycles.put(allocator, node_idx, id),
                        .black => {},
                    }
                }
            }
        } else color.getPtr(id).?.* = .black;
    }

    return cycles;
}

/// Sanity-checks the graph: exactly one start node exists, and no node
/// appears in two non-equivalent classes. Panics on violation.
fn verifyNodes(oir: *Oir) !void {
    var found_start: bool = false;

    var temporary: std.HashMapUnmanaged(
        Node.Index,
        Class.Index,
        NodeContext,
        std.hash_map.default_max_load_percentage,
    ) = .{};
    defer temporary.deinit(oir.allocator);

    var iter = oir.classes.iterator();
    while (iter.next()) |entry| {
        const id = entry.key_ptr.*;
        const class = entry.value_ptr.*;
        for (class.bag.items) |node| {
            if (oir.getNode(node).tag == .start) {
                if (found_start) @panic("second start node found in OIR");
                found_start = true;
            }

            const gop = try temporary.getOrPutContext(
                oir.allocator,
                node,
                .{ .oir = oir },
            );
            if (gop.found_existing) {
                const found_id = oir.union_find.find(id);
                const found_old = oir.union_find.find(gop.value_ptr.*);
                if (found_id != found_old) {
                    std.debug.panic(
                        "found unexpected equivalence for {}\n{any}\nvs\n{any}",
                        .{
                            node,
                            oir.getClassPtr(found_id).bag.items,
                            oir.getClassPtr(found_old).bag.items,
                        },
                    );
                }
            } else gop.value_ptr.* = id;
        }
    }

    if (!found_start) @panic("no start node found in OIR");

    // Every recorded class index must already be canonical.
    var temp_iter = temporary.iterator();
    while (temp_iter.next()) |entry| {
        const e = entry.value_ptr.*;
        assert(e == oir.union_find.find(e));
    }
}

pub fn extract(oir: *Oir, strat: extraction.CostStrategy) !extraction.Recursive {
    return extraction.extract(oir, strat);
}

pub fn deinit(oir: *Oir) void {
    const allocator = oir.allocator;

    {
        var iter = oir.classes.valueIterator();
        while (iter.next()) |class| class.deinit(allocator);
        oir.classes.deinit(allocator);
    }

    oir.trace.deinit();
    oir.node_to_class.deinit(allocator);
    oir.nodes.deinit(allocator);

    oir.union_find.deinit(allocator);
    oir.pending.deinit(allocator);
    oir.extra.deinit(allocator);
    oir.exit_list.deinit(allocator);
}

/// Checks if a class contains a constant equivalence node, and returns it.
/// Otherwise returns `null`.
///
/// Can only return canonical element types such as `constant`.
pub fn classContains(oir: *const Oir, idx: Class.Index, comptime tag: Node.Tag) ?Node.Index {
    comptime assert(tag.isCanonical());
    assert(oir.clean);

    const class = oir.classes.get(idx) orelse return null;
    for (class.bag.items) |node_idx| {
        const node = oir.getNode(node_idx);
        // Since the node is absorbing, we can return early as no other
        // instances of it are allowed in the same class.
        if (node.tag == tag) return node_idx;
    }

    return null;
}

/// Similar to `classContains` but instead of returning a specific node that matches
/// the tag, it just tells us whether the class in general contains a node of that tag.
867 | pub fn classContainsAny(oir: *const Oir, idx: Class.Index, tag: Node.Tag) bool { 868 | assert(oir.clean); 869 | const class = oir.classes.get(idx).?; 870 | for (class.bag.items) |node_idx| { 871 | const node = oir.getNode(node_idx); 872 | if (node.tag == tag) return true; 873 | } 874 | return false; 875 | } 876 | 877 | pub fn listToSpan(oir: *Oir, list: []const Class.Index) !Node.Span { 878 | try oir.extra.appendSlice(oir.allocator, @ptrCast(list)); 879 | return .{ 880 | .start = @intCast(oir.extra.items.len - list.len), 881 | .end = @intCast(oir.extra.items.len), 882 | }; 883 | } 884 | 885 | const Oir = @This(); 886 | const std = @import("std"); 887 | const print_oir = @import("Oir/print_oir.zig"); 888 | pub const extraction = @import("Oir/extraction.zig"); 889 | const Trace = @import("trace.zig").Trace; 890 | 891 | const log = std.log.scoped(.oir); 892 | const assert = std.debug.assert; 893 | -------------------------------------------------------------------------------- /src/Oir/SimpleExtractor.zig: -------------------------------------------------------------------------------- 1 | //! Super basic Oir extractor implementation. 2 | 3 | const std = @import("std"); 4 | const Oir = @import("../Oir.zig"); 5 | const cost = @import("../cost.zig"); 6 | const extraction = @import("extraction.zig"); 7 | 8 | const Class = Oir.Class; 9 | const Node = Oir.Node; 10 | const SimpleExtractor = @This(); 11 | const Recursive = extraction.Recursive; 12 | 13 | const log = std.log.scoped(.simple_extractor); 14 | const assert = std.debug.assert; 15 | 16 | oir: *const Oir, 17 | 18 | /// Describes cycles found in the OIR. EGraphs are allowed to have cycles, 19 | /// they are not DAGs. However, it's impossible to extract a "best node" 20 | /// from a cyclic class pattern so we must skip them. If after iterating through 21 | /// all of the nodes in a class we can't find one that doesn't cycle, this means 22 | /// the class itself cycles and the graph is unsolvable. 
23 | /// 24 | /// The key is a cyclic node index and the value is the index of the class 25 | /// which references the class the node is in. 26 | cycles: std.AutoHashMapUnmanaged(Node.Index, Class.Index), 27 | 28 | /// Relates class indicies to the best node in them. Since the classes 29 | /// are immutable after the OIR optimization passes, we can confidently 30 | /// reuse the extraction. This amortization makes our extraction strategy 31 | /// just barely usable. 32 | cost_memo: std.AutoHashMapUnmanaged(Class.Index, NodeCost), 33 | 34 | start_class: ?Class.Index, 35 | 36 | best_node: std.AutoHashMapUnmanaged(Class.Index, Node.Index), 37 | map: std.AutoHashMapUnmanaged(Class.Index, Class.Index), 38 | 39 | exit_list: std.ArrayListUnmanaged(Class.Index), 40 | 41 | const NodeCost = struct { 42 | u32, 43 | Node.Index, 44 | }; 45 | 46 | /// Extracts the best pattern of Oir from the E-Graph given a cost model. 47 | pub fn extract(oir: *const Oir) !Recursive { 48 | var e: SimpleExtractor = .{ 49 | .oir = oir, 50 | .cycles = try oir.findCycles(), 51 | .cost_memo = .empty, 52 | .best_node = .{}, 53 | .map = .{}, 54 | .exit_list = .{}, 55 | .start_class = null, 56 | }; 57 | defer e.deinit(); 58 | 59 | log.debug("cycles found: {}", .{e.cycles.count()}); 60 | 61 | { 62 | var iter = oir.classes.valueIterator(); 63 | while (iter.next()) |class| { 64 | const best_node = try e.getBestNode(class.index); 65 | try e.best_node.put(oir.allocator, class.index, best_node); 66 | } 67 | } 68 | 69 | var recv: Recursive = .{}; 70 | for (oir.exit_list.items) |exit| { 71 | _ = try e.extractClass(exit, &recv); 72 | } 73 | recv.exit_list = try e.exit_list.clone(oir.allocator); 74 | 75 | return recv; 76 | } 77 | 78 | fn extractClass(e: *SimpleExtractor, class_idx: Class.Index, recv: *Recursive) !Class.Index { 79 | const oir = e.oir; 80 | const gpa = oir.allocator; 81 | 82 | if (e.map.get(class_idx)) |memo| return memo; 83 | 84 | const best_node_idx = 
e.best_node.get(e.oir.union_find.find(class_idx)).?; 85 | const best_node = oir.getNode(best_node_idx); 86 | 87 | switch (best_node.tag) { 88 | .start => { 89 | const new_node: Node = .{ .tag = .start, .data = .{ .list = .empty } }; 90 | const idx = try recv.addNode(gpa, new_node); 91 | if (e.start_class != null) @panic("found two start nodes?"); 92 | e.start_class = idx; 93 | try e.map.put(gpa, class_idx, idx); 94 | return idx; 95 | }, 96 | .project => { 97 | const project = best_node.data.project; 98 | 99 | const tuple = try e.extractClass(project.tuple, recv); 100 | 101 | const new_node: Node = .project(project.index, tuple, project.type); 102 | const new_node_idx = try recv.addNode(gpa, new_node); 103 | try e.map.put(gpa, class_idx, new_node_idx); 104 | return new_node_idx; 105 | }, 106 | .load => { 107 | const un_op = best_node.data.un_op; 108 | 109 | const operand = try e.extractClass(un_op, recv); 110 | 111 | const new_node: Node = .{ 112 | .tag = best_node.tag, 113 | .data = .{ .un_op = operand }, 114 | }; 115 | const new_node_idx = try recv.addNode(gpa, new_node); 116 | try e.map.put(gpa, class_idx, new_node_idx); 117 | return new_node_idx; 118 | }, 119 | .ret, 120 | .branch, 121 | .cmp_gt, 122 | .add, 123 | .sub, 124 | .shl, 125 | .shr, 126 | => { 127 | const bin_op = best_node.data.bin_op; 128 | 129 | const lhs = try e.extractClass(bin_op[0], recv); 130 | const rhs = try e.extractClass(bin_op[1], recv); 131 | 132 | const new_node: Node = .binOp(best_node.tag, lhs, rhs); 133 | const new_node_idx = try recv.addNode(gpa, new_node); 134 | try e.map.put(gpa, class_idx, new_node_idx); 135 | switch (best_node.tag) { 136 | .ret => try e.exit_list.append(gpa, new_node_idx), 137 | else => {}, 138 | } 139 | return new_node_idx; 140 | }, 141 | .constant => { 142 | const idx = try recv.addNode(gpa, best_node); 143 | try e.map.put(gpa, class_idx, idx); 144 | return idx; 145 | }, 146 | .dead => unreachable, 147 | else => std.debug.panic("TODO: extractClass {s}\n", 
.{@tagName(best_node.tag)}), 148 | } 149 | } 150 | 151 | fn getBestNode(e: *SimpleExtractor, class_idx: Class.Index) !Node.Index { 152 | _, const best_node = try e.extractBestNode(class_idx); 153 | 154 | log.debug("best node for class {} is {s}", .{ 155 | class_idx, 156 | @tagName(e.oir.getNode(best_node).tag), 157 | }); 158 | 159 | return best_node; 160 | } 161 | 162 | /// Given a class, extract the "best" node from it. 163 | fn extractBestNode(e: *SimpleExtractor, class_idx: Class.Index) !NodeCost { 164 | const oir = e.oir; 165 | const class = oir.classes.get(class_idx).?; 166 | assert(class.bag.items.len > 0); 167 | 168 | if (e.cost_memo.get(class_idx)) |entry| return entry; 169 | 170 | var best_cost: u32 = std.math.maxInt(u32); 171 | var best_node: Node.Index = class.bag.items[0]; 172 | 173 | for (class.bag.items) |node_idx| { 174 | // the node is known to cycle, we must skip it. 175 | if (e.cycles.get(node_idx) != null) continue; 176 | 177 | const node = oir.getNode(node_idx); 178 | 179 | const base_cost = cost.getCost(node.tag); 180 | var child_cost: u32 = 0; 181 | for (node.operands(oir)) |sub_class_idx| { 182 | assert(sub_class_idx != class_idx); // checked for cycles above 183 | 184 | const extracted_cost, _ = try e.extractBestNode(sub_class_idx); 185 | child_cost += extracted_cost; 186 | } 187 | 188 | const node_cost = base_cost + child_cost; 189 | 190 | if (node_cost < best_cost) { 191 | best_cost = node_cost; 192 | best_node = node_idx; 193 | } 194 | } 195 | if (best_cost == std.math.maxInt(u32)) { 196 | std.debug.panic("extracted cyclic terms, no best node could be found! 
{}", .{class_idx}); 197 | } 198 | 199 | const entry: NodeCost = .{ best_cost, best_node }; 200 | try e.cost_memo.putNoClobber(oir.allocator, class_idx, entry); 201 | return entry; 202 | } 203 | 204 | pub fn deinit(e: *SimpleExtractor) void { 205 | const allocator = e.oir.allocator; 206 | e.cost_memo.deinit(allocator); 207 | e.cycles.deinit(allocator); 208 | e.best_node.deinit(allocator); 209 | e.map.deinit(allocator); 210 | e.exit_list.deinit(allocator); 211 | } 212 | -------------------------------------------------------------------------------- /src/Oir/extraction.zig: -------------------------------------------------------------------------------- 1 | const std = @import("std"); 2 | const build_options = @import("build_options"); 3 | const Oir = @import("../Oir.zig"); 4 | 5 | const SimpleExtractor = @import("SimpleExtractor.zig"); 6 | const z3 = @import("z3.zig"); 7 | 8 | const log = std.log.scoped(.extraction); 9 | const Node = Oir.Node; 10 | const Class = Oir.Class; 11 | 12 | /// A form of OIR where nodes reference other nodes. 13 | pub const Recursive = struct { 14 | nodes: std.ArrayListUnmanaged(Node) = .{}, 15 | exit_list: std.ArrayListUnmanaged(Class.Index) = .{}, 16 | extra: std.ArrayListUnmanaged(u32) = .{}, 17 | 18 | // TODO: Explore making this its own unique type. Currently we can't do that because 19 | // the Node data payload types use Class.Index to reference other Classes, which isn't 20 | // compatible with this. Maybe we can bitcast safely between them? 
21 | // pub const Index = enum(u32) { 22 | // start, 23 | // _, 24 | // }; 25 | 26 | pub fn getNode(r: *const Recursive, idx: Class.Index) Node { 27 | return r.nodes.items[@intFromEnum(idx)]; 28 | } 29 | 30 | pub fn getNodes(r: *const Recursive) []const Node { 31 | return r.nodes.items; 32 | } 33 | 34 | pub fn addNode(r: *Recursive, allocator: std.mem.Allocator, node: Node) !Class.Index { 35 | const idx: Class.Index = @enumFromInt(r.nodes.items.len); 36 | try r.nodes.append(allocator, node); 37 | return idx; 38 | } 39 | 40 | pub fn dump(recv: Recursive, name: []const u8) !void { 41 | const graphviz_file = try std.fs.cwd().createFile(name, .{}); 42 | defer graphviz_file.close(); 43 | try @import("print_oir.zig").dumpRecvGraph(recv, graphviz_file.writer()); 44 | } 45 | 46 | pub fn deinit(r: *Recursive, allocator: std.mem.Allocator) void { 47 | r.nodes.deinit(allocator); 48 | r.extra.deinit(allocator); 49 | r.exit_list.deinit(allocator); 50 | } 51 | 52 | pub fn listToSpan( 53 | r: *Recursive, 54 | list: []const Class.Index, 55 | gpa: std.mem.Allocator, 56 | ) !Oir.Node.Span { 57 | try r.extra.appendSlice(gpa, @ptrCast(list)); 58 | return .{ 59 | .start = @intCast(r.extra.items.len - list.len), 60 | .end = @intCast(r.extra.items.len), 61 | }; 62 | } 63 | 64 | pub fn print(recv: Recursive, writer: anytype) !void { 65 | try @import("print_oir.zig").print(recv, writer); 66 | } 67 | }; 68 | 69 | pub const CostStrategy = enum { 70 | /// A super basic cost strategy that simply looks at the number of child nodes 71 | /// a particular node has to determine its cost. 72 | simple_latency, 73 | /// Uses Z3 and a column approach to find the optimal solution. 
74 | z3, 75 | 76 | pub const auto: CostStrategy = if (build_options.has_z3) .z3 else .simple_latency; 77 | }; 78 | 79 | pub fn extract(oir: *Oir, strat: CostStrategy) !Recursive { 80 | const trace = oir.trace.start(@src(), "extracting", .{}); 81 | defer trace.end(); 82 | 83 | switch (strat) { 84 | .simple_latency => return SimpleExtractor.extract(oir), 85 | .z3 => return z3.extract(oir), 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/Oir/print_oir.zig: -------------------------------------------------------------------------------- 1 | //! Prints OIR to graphviz 2 | 3 | const Oir = @import("../Oir.zig"); 4 | const Recursive = @import("extraction.zig").Recursive; 5 | 6 | pub fn dumpOirGraph( 7 | oir: *const Oir, 8 | stream: anytype, 9 | ) !void { 10 | try stream.writeAll( 11 | \\digraph G { 12 | \\ compound=true 13 | \\ clusterrank=local 14 | \\ graph [fontsize=14 compound=true] 15 | \\ node [shape=box, style=filled]; 16 | \\ rankdir=BT; 17 | \\ ordering="in"; 18 | // https://gitlab.com/graphviz/graphviz/-/issues/1949 19 | // \\ concentrate=true; 20 | \\ 21 | \\ 22 | ); 23 | 24 | { 25 | var class_iter = oir.classes.iterator(); 26 | while (class_iter.next()) |entry| { 27 | const class_idx: u32 = @intFromEnum(entry.key_ptr.*); 28 | const class = entry.value_ptr.*; 29 | try stream.print( 30 | \\ subgraph cluster_{d} {{ 31 | \\ style=dotted 32 | \\ 33 | , .{class_idx}); 34 | 35 | for (class.bag.items, 0..) 
|node_idx, i| { 36 | const node = oir.getNode(node_idx); 37 | try stream.print(" {}.{} [label=\"", .{ class_idx, i }); 38 | try printNodeLabel(stream, node); 39 | try stream.print(" {}", .{class_idx}); 40 | const color = switch (node.nodeType()) { 41 | .ctrl => "orange", 42 | .data => "grey", 43 | }; 44 | try stream.print("\", color=\"{s}\"];\n", .{color}); 45 | } 46 | try stream.writeAll(" }\n"); 47 | } 48 | } 49 | 50 | var class_iter = oir.classes.iterator(); 51 | while (class_iter.next()) |entry| { 52 | const class_idx = entry.key_ptr.*; 53 | const class = entry.value_ptr.*; 54 | for (class.bag.items, 0..) |node_idx, i| { 55 | const node = oir.getNode(node_idx); 56 | switch (node.tag) { 57 | .ret, 58 | .branch, 59 | => { 60 | const ctrl, const data = node.data.bin_op; 61 | try printClassEdge(stream, class_idx, i, ctrl, .red); 62 | try printClassEdge(stream, class_idx, i, data, .black); 63 | }, 64 | .project, 65 | => { 66 | const project = node.data.project; 67 | const target = project.tuple; 68 | try printClassEdge(stream, class_idx, i, target, switch (project.type) { 69 | .ctrl => .red, 70 | .data => .black, 71 | }); 72 | }, 73 | .region => { 74 | const list = node.data.list; 75 | for (list.toSlice(oir)) |item| { 76 | try printClassEdge( 77 | stream, 78 | class_idx, 79 | i, 80 | @enumFromInt(item), 81 | .red, 82 | ); 83 | } 84 | }, 85 | else => for (node.operands(oir)) |child_idx| { 86 | try printClassEdge(stream, class_idx, i, child_idx, .black); 87 | }, 88 | } 89 | } 90 | } 91 | 92 | try stream.writeAll("}\n"); 93 | } 94 | 95 | fn printClassEdge( 96 | stream: anytype, 97 | class_idx: Oir.Class.Index, 98 | i: usize, 99 | idx: Oir.Class.Index, 100 | color: enum { black, red }, 101 | ) !void { 102 | if (class_idx == idx) return; // We can't print arrows inside of a class. 
103 | try stream.print( 104 | " {}.{} -> {}.0 [lhead = cluster_{} color=\"{s}\"]\n", 105 | .{ 106 | @intFromEnum(class_idx), 107 | i, 108 | @intFromEnum(idx), 109 | @intFromEnum(idx), 110 | @tagName(color), 111 | }, 112 | ); 113 | } 114 | 115 | fn printEdge( 116 | stream: anytype, 117 | i: usize, 118 | child: Oir.Class.Index, 119 | color: enum { black, red }, 120 | ) !void { 121 | try stream.print(" {d} -> {d} [color=\"{s}\"];\n", .{ 122 | i, 123 | @intFromEnum(child), 124 | @tagName(color), 125 | }); 126 | } 127 | 128 | pub fn dumpRecvGraph( 129 | recv: Recursive, 130 | stream: anytype, 131 | ) !void { 132 | try stream.writeAll( 133 | \\digraph G { 134 | \\ compound=true 135 | \\ clusterrank=local 136 | \\ graph [fontsize=14 compound=true] 137 | \\ node [shape=box, style=filled]; 138 | \\ rankdir=BT; 139 | \\ ordering="in"; 140 | \\ concentrate="true"; 141 | \\ 142 | \\ 143 | ); 144 | 145 | for (recv.nodes.items, 0..) |node, i| { 146 | try stream.print(" {} [label=\"", .{i}); 147 | try printNodeLabel(stream, node); 148 | const color = switch (node.nodeType()) { 149 | .ctrl => "orange", 150 | .data => "grey", 151 | }; 152 | try stream.print("\", color=\"{s}\"];\n", .{color}); 153 | } 154 | try stream.writeAll("\n"); 155 | 156 | for (recv.nodes.items, 0..) 
|node, i| { 157 | switch (node.tag) { 158 | .ret, .branch => { 159 | const ctrl, const data = node.data.bin_op; 160 | try printEdge(stream, i, ctrl, .red); 161 | try printEdge(stream, i, data, .black); 162 | }, 163 | .project => { 164 | const project = node.data.project; 165 | const target = project.tuple; 166 | try printEdge(stream, i, target, switch (project.type) { 167 | .ctrl => .red, 168 | .data => .black, 169 | }); 170 | }, 171 | .region => { 172 | const list = node.data.list; 173 | for (list.toSlice(recv)) |item| { 174 | try printEdge(stream, i, @enumFromInt(item), .red); 175 | } 176 | }, 177 | else => for (node.operands(recv)) |idx| { 178 | try printEdge(stream, i, idx, .black); 179 | }, 180 | } 181 | } 182 | 183 | try stream.writeAll("}\n"); 184 | } 185 | 186 | /// NOTE: Printing this with a "full" OIR graph is basically useless, since it just iterates 187 | /// through the node list. It only makes sense to use on recursive expressions and just created 188 | /// OIR graphs, for debugging. 
/// Prints every node of `repr` (anything exposing `getNodes()`) to `stream`,
/// one `%index = tag(...)` line per node.
pub fn print(
    repr: anytype,
    stream: anytype,
) !void {
    var writer: Writer = .{ .nodes = repr.getNodes() };
    try writer.printBody(repr, stream);
}

/// Textual pretty-printer over a flat node list. Dispatches on `Node.tag`
/// to a per-shape helper (bin-op, un-op, project, constant, ...).
pub const Writer = struct {
    indent: u32 = 0,
    nodes: []const Oir.Node,

    /// Emits each node on its own line as `%{index} = <node>`.
    fn printBody(w: *Writer, repr: anytype, stream: anytype) !void {
        for (0..w.nodes.len) |i| {
            try stream.print("%{d} = ", .{i});
            try w.printNode(w.nodes[i], repr, stream);
            try stream.writeByte('\n');
        }
    }

    /// Prints a single node as `tag(operands...)`. Unknown tags fall through
    /// to a `TODO:` placeholder rather than erroring.
    pub fn printNode(w: *Writer, node: Oir.Node, repr: anytype, stream: anytype) !void {
        try stream.print("{s}(", .{@tagName(node.tag)});
        switch (node.tag) {
            .ret,
            .@"and",
            .sub,
            .shl,
            .shr,
            .mul,
            .div_exact,
            .div_trunc,
            .add,
            .cmp_gt,
            .cmp_eq,
            => try w.printBinOp(node, stream),
            .load,
            => try w.printUnOp(node, stream),
            .project => try w.printProject(node, stream),
            .constant => try w.printConstant(node, stream),
            .branch => try w.printCtrlDataOp(node, stream),
            .start => try w.printStart(node, repr, stream),
            .region => try w.printCtrlList(node, repr, stream),
            else => try stream.print("TODO: {s}", .{@tagName(node.tag)}),
        }
        try stream.writeAll(")");
    }

    // Single-operand node: prints the one `un_op` operand.
    fn printUnOp(_: *Writer, node: Oir.Node, stream: anytype) !void {
        const op = node.data.un_op;
        try stream.print("{}", .{op});
    }

    // Two-operand node: prints `lhs, rhs` from the `bin_op` pair.
    fn printBinOp(_: *Writer, node: Oir.Node, stream: anytype) !void {
        const bin_op = node.data.bin_op;
        try stream.print("{}, {}", .{ bin_op[0], bin_op[1] });
    }

    // Projection: prints the projected element index and the tuple it reads from.
    fn printProject(_: *Writer, node: Oir.Node, stream: anytype) !void {
        const project = node.data.project;
        try stream.print("{d} {}", .{ project.index, project.tuple });
    }

    // Constant: prints the literal integer payload.
    fn printConstant(_: *Writer, node: Oir.Node, stream: anytype) !void {
        const constant = node.data.constant;
        try stream.print("{d}", .{constant});
    }

    // Control+data pair (e.g. `branch`): stored as a `bin_op`, printed the same way.
    fn printCtrlDataOp(_: *Writer, node: Oir.Node, stream: anytype) !void {
        const bin_op = node.data.bin_op;
        try stream.print("{}, {}", .{ bin_op[0], bin_op[1] });
    }

    // Start node: prints the comma-separated exits from `repr.exit_list`
    // (the node itself carries no printable payload here).
    fn printStart(_: *Writer, _: Oir.Node, repr: anytype, stream: anytype) !void {
        for (repr.exit_list.items, 0..) |exit, i| {
            try stream.writeAll(if (i == 0) "" else ", ");
            try stream.print("{}", .{exit});
        }
    }

    // Region: prints the node's extra-data span as a comma-separated `%index` list.
    fn printCtrlList(_: *Writer, node: Oir.Node, repr: anytype, stream: anytype) !void {
        const span = node.data.list;
        for (repr.extra.items[span.start..span.end], 0..) |item, i| {
            try stream.writeAll(if (i == 0) "" else ", ");
            try stream.print("%{d}", .{item});
        }
    }
};

/// Short label for a node, used for graph-node captions: constants and
/// projects include their payload, everything else is just the tag name.
fn printNodeLabel(
    stream: anytype,
    node: Oir.Node,
) !void {
    switch (node.tag) {
        .constant => {
            const val = node.data.constant;
            try stream.print("constant:{d}", .{val});
        },
        .project => {
            const project = node.data.project;
            try stream.print("project({d})", .{project.index});
        },
        else => try stream.writeAll(@tagName(node.tag)),
    }
}
--------------------------------------------------------------------------------
/src/Oir/z3.zig:
--------------------------------------------------------------------------------
//! Uses Z3 in order to extract an optimal pattern from the E-Graph.

const std = @import("std");
const build_options = @import("build_options");
const z3 = if (build_options.has_z3) @import("z3") else {};
const Oir = @import("../Oir.zig");
const cost = @import("../cost.zig");
const Recursive = @import("extraction.zig").Recursive;

const log = std.log.scoped(.z3_extractor);

const Class = Oir.Class;

/// Per-class Z3 variables: whether the class is selected at all, a topological
/// order variable, and one boolean per node in the class's bag.
const ClassVars = struct {
    active: z3.Bool,
    order: z3.Int,
    nodes: []const z3.Bool,
};

/// Encodes extraction as a Z3 optimization problem: pick exactly the cheapest
/// set of nodes (by `cost.getCost`) such that every selected node's operand
/// classes are also selected and all roots stay alive, then walk the solution
/// bottom-up into a `Recursive` expression.
/// Panics if the build was configured without Z3 or if no solution exists.
pub fn extract(oir: *const Oir) !Recursive {
    if (!build_options.has_z3) @panic("need z3 enabled to use z3 extractor");

    // All solver bookkeeping lives in an arena; the returned `Recursive`
    // is allocated with `oir.allocator` and survives the arena.
    var arena: std.heap.ArenaAllocator = .init(oir.allocator);
    defer arena.deinit();
    const gpa = arena.allocator();

    var cycles = try oir.findCycles();
    defer cycles.deinit(oir.allocator);

    var model = z3.Model.init(.optimize);
    defer model.deinit();

    var vars: std.AutoHashMapUnmanaged(Class.Index, ClassVars) = .{};
    defer vars.deinit(gpa);

    // Create the Z3 variables for every e-class up front.
    var iter = oir.classes.iterator();
    while (iter.next()) |entry| {
        const class = entry.value_ptr;
        const active = model.constant(.bool, null);
        const order = model.constant(.int, null);

        // Bound the order variable; 10x the node count is a loose upper limit.
        const max_order = model.int(@intCast(oir.nodes.count() * 10));
        model.assert(model.le(order, max_order));

        const nodes = try gpa.alloc(z3.Bool, class.bag.items.len);
        for (nodes) |*node| node.* = model.constant(.bool, null);

        try vars.put(gpa, class.index, .{
            .active = active,
            .order = order,
            .nodes = nodes,
        });
    }

    // Emit the structural constraints relating classes to their nodes.
    var var_iter = vars.iterator();
    while (var_iter.next()) |entry| {
        const id = entry.key_ptr.*;
        const class = entry.value_ptr;

        // Class is active if and only if at least one of the nodes is active.
        const class_or = model.@"or"(class.nodes);
        const equiv = model.iff(class.active, class_or);
        model.assert(equiv);

        for (oir.getClass(id).bag.items, class.nodes) |node, node_active| {
            // If there's a cycle through this node, it can never be chosen,
            // so we just deactivate it.
            if (cycles.contains(node)) {
                model.assert(model.not(node_active));
            }

            // node_active == true implies that child_active == true
            for (oir.getNode(node).operands(oir)) |child| {
                const child_active = vars.get(child).?.active;
                const implication = model.implies(node_active, child_active);
                model.assert(implication);
            }
        }
    }

    // Each node in the graph is a term in the objective. Each term has a
    // weight, which is 0 if it isn't active, or 1 * cost(tag) if it is.
    // The goal of the optimizer is to reduce this number to the smallest possible
    // cost of the total graph, while keeping the root nodes alive.
    var terms: std.ArrayListUnmanaged(z3.Int) = .{};
    defer terms.deinit(gpa);

    var class_iter = oir.classes.iterator();
    while (class_iter.next()) |entry| {
        for (
            entry.value_ptr.bag.items,
            vars.get(entry.value_ptr.index).?.nodes,
        ) |node, node_active| {
            const one = model.int(1);
            const zero = model.int(0);
            const int = model.ite(node_active, one, zero);
            const weight = model.int(@intCast(cost.getCost(oir.getNode(node).tag)));
            try terms.append(gpa, model.mul(&.{ weight, int }));
        }
    }

    const objective = model.add(terms.items);
    model.minimize(objective);

    // Force active == true for the roots. Otherwise, we'd just optimize into nothing!
    const exit_list = oir.getNode(.start).data.list.toSlice(oir);
    for (exit_list) |exit| {
        const root = vars.get(@enumFromInt(exit)).?.active;
        const eq = model.eq(root, model.true());
        model.assert(eq);
    }

    log.debug("solver:\n{s}\n", .{model.toString()});

    const result = model.check();

    if (result == .true) {
        var partial_model = model.getLastModel();
        defer partial_model.deinit();
        log.debug("found solution model:\n{s}\n", .{partial_model.toString()});

        var recv: Recursive = .{};
        var start_class: ?Class.Index = null;

        var new_exit_list: std.ArrayListUnmanaged(Class.Index) = .{};
        defer new_exit_list.deinit(gpa);

        // Worklist of classes still to materialize into `recv`.
        var queue: std.ArrayListUnmanaged(Class.Index) = .{};
        defer queue.deinit(gpa);

        // Maps old (OIR) class indices to their new indices in `recv`.
        var map: std.AutoHashMapUnmanaged(Class.Index, Class.Index) = .{};
        defer map.deinit(gpa);

        for (exit_list) |exit| {
            try queue.append(gpa, oir.union_find.find(@enumFromInt(exit)));
        }

        // Depth-first materialization: a class is emitted only once all of its
        // chosen node's operands have already been emitted.
        while (queue.getLastOrNull()) |id| {
            if (map.contains(id)) {
                _ = queue.pop();
                continue;
            }
            const v = vars.get(id).?;
            std.debug.assert(partial_model.isTrue(v.active));

            // Pick the node the solver marked active within this class.
            const node_idx = for (v.nodes, 0..) |node, i| {
                if (partial_model.isTrue(node)) break i;
            } else @panic("TODO");
            const node = oir.getNode(oir.getClass(id).bag.items[node_idx]);

            // Check whether all operands are in the memo map.
            var all: bool = true;
            for (node.operands(oir)) |child| {
                if (!map.contains(child)) all = false;
            }
            if (all) {
                const new_id = try recv.addNode(oir.allocator, try node.mapNode(oir, &map));
                switch (node.tag) {
                    .ret => try new_exit_list.append(gpa, new_id),
                    .start => start_class = new_id,
                    else => {},
                }
                try map.put(gpa, id, new_id);
                _ = queue.pop();
            } else {
                // Some operand is missing; push them all and revisit this class later.
                try queue.appendSlice(gpa, node.operands(oir));
            }
        }

        if (start_class == null) @panic("no start class?");
        // Patch the new start node with the rebuilt exit list.
        recv.nodes.items[@intFromEnum(start_class.?)].data = .{
            .list = try recv.listToSpan(
                new_exit_list.items,
                oir.allocator,
            ),
        };

        return recv;
    } else {
        std.debug.panic("no solution found!!!! what??", .{});
    }
}
--------------------------------------------------------------------------------
/src/codegen/p2.zig:
--------------------------------------------------------------------------------
//! Backend for emitting P2 (Parallax Propeller 2) assembly

const Oir = @import("../Oir.zig");
const Recursive = Oir.extraction.Recursive;

/// Entry point of the P2 backend. Currently a stub: accepts an extracted
/// expression and emits nothing.
pub fn generate(recv: *const Recursive) !void {
    _ = recv;
}
--------------------------------------------------------------------------------
/src/cost.zig:
--------------------------------------------------------------------------------
//! Defines simple cost information about MIR instructions

/// Returns the abstract cost of a node tag, used as the weight in extraction
/// objectives. Higher means more expensive to select.
pub fn getCost(tag: Oir.Node.Tag) u32 {
    return switch (tag) {
        // ALU operations
        .add,
        .sub,
        .mul,
        .div_trunc,
        .div_exact,
        => 2,

        .@"and",
        .shl,
        .shr,
        => 1,

        // Basic memory operations
        .load,
        .store,
        => 1,

        // Compare
        .cmp_eq,
        .cmp_gt,
        .branch,
        => 1,

        .start,
        .ret,
        .dead,
        .region,
        // constants have zero latency so that we bias towards
        // selecting the "free" canonical element.
        .constant,
        .project,
        => 0,
    };
}

const Oir = @import("Oir.zig");
--------------------------------------------------------------------------------
/src/lib.zig:
--------------------------------------------------------------------------------
const std = @import("std");
pub const Oir = @import("Oir.zig");
pub const Trace = @import("trace.zig").Trace;

pub const Recursive = Oir.extraction.Recursive;

pub const p2 = @import("codegen/p2.zig");

pub const options: std.Options = .{
    .log_level = .err,
};

test {
    std.testing.refAllDeclsRecursive(@This());
}
--------------------------------------------------------------------------------
/src/passes/constant_fold.zig:
--------------------------------------------------------------------------------
//! The Constant Folding pass

const std = @import("std");
const Oir = @import("../Oir.zig");

const Node = Oir.Node;
const assert = std.debug.assert;

/// Iterates through all nodes in the E-Graph, checking if it's possible to evaluate them now.
///
/// If a node is found with "comptime-known" children, it's evaluated and the new
/// "comptime-known" result is added to that node's class.
///
/// Returns `true` if a fold was committed (the graph was rebuilt and iteration
/// must restart), `false` if a full pass completed without changes.
pub fn run(oir: *Oir) !bool {
    // A buffer of constant nodes found in operand classes.
    // Not a BoundedArray, since there are certain nodes that can have a variable
    // amount of operands.
    var constants: std.ArrayListUnmanaged(Node.Index) = .{};
    defer constants.deinit(oir.allocator);

    outer: for (oir.nodes.keys(), 0..) |node, i| {
        const node_idx: Node.Index = @enumFromInt(i);
        const class_idx = oir.findClass(node_idx);

        // If this node is volatile, we cannot fold it away.
        if (node.isVolatile()) continue;

        switch (node.tag) {
            .add,
            .sub,
            .mul,
            .div_exact,
            .div_trunc,
            .shl,
            .cmp_eq,
            .cmp_gt,
            => {
                // the class has already been solved for a constant, no need to do anything else!
                if (oir.classContains(class_idx, .constant) != null) continue;
                assert(node.tag != .constant);
                defer constants.clearRetainingCapacity();

                // Every operand class must already contain a constant, otherwise
                // this node isn't foldable yet.
                for (node.operands(oir)) |child_idx| {
                    if (oir.classContains(child_idx, .constant)) |constant| {
                        try constants.append(oir.allocator, constant);
                    } else continue :outer;
                }

                // NOTE(review): this destructure assumes exactly two collected
                // constants — true for the binary tags matched above.
                const lhs, const rhs = constants.items[0..2].*;
                const lhs_value = oir.getNode(lhs).data.constant;
                const rhs_value = oir.getNode(rhs).data.constant;

                const result = switch (node.tag) {
                    .add => lhs_value + rhs_value,
                    .sub => lhs_value - rhs_value,
                    .mul => lhs_value * rhs_value,
                    .div_exact => @divExact(lhs_value, rhs_value),
                    .div_trunc => @divTrunc(lhs_value, rhs_value),
                    .shl => lhs_value << @intCast(rhs_value),
                    .cmp_eq => @intFromBool(lhs_value == rhs_value),
                    .cmp_gt => @intFromBool(lhs_value > rhs_value),
                    else => unreachable,
                };

                const new_class = try oir.add(.{
                    .tag = .constant,
                    .data = .{ .constant = result },
                });
                _ = try oir.@"union"(new_class, class_idx);
                try oir.rebuild();

                // We can't continue this iteration since the rebuild could have modified
                // the `nodes` list.
                // TODO: figure out a better way to continue running, even after a rebuild
                // has affected the graph.
                return true;
            },
            .constant => {}, // already folded!
            .project => {}, // TODO
            .load => {}, // TODO: GVN load elision
            .store => {}, // ^
            .branch => {
                // We can fold away branches if we know what the predicate is.
                const predicate = node.data.bin_op[1];
                if (oir.classContains(predicate, .constant)) |idx| {
                    _ = idx;
                    // TODO: no clear way to extract this in a valid way yet
                    // const value = oir.getNode(idx).data.constant;
                    // assert(value == 0 or value == 1);

                    // var true_project: ?Oir.Class.Index = null;
                    // for (oir.nodes.keys(), 0..) |sub_node, j| {
                    //     if (sub_node.tag == .project) {
                    //         const project = sub_node.data.project;

                    //         if (project.tuple == class_idx and
                    //             // NOTE: project(0, ...) is the then case, and 1 is "true",
                    //             // so we are comparing inverse.
                    //             project.index != value)
                    //         {
                    //             true_project = oir.findClass(@enumFromInt(j));
                    //         }
                    //     }
                    // }

                    // if (true_project) |prt| {
                    //     for (oir.nodes.keys()) |sub_node| {
                    //         switch (sub_node.tag) {
                    //             .ret => {
                    //                 if (sub_node.data.bin_op[0] == prt) {
                    //                     const new_ret = try oir.add(.binOp(
                    //                         .ret,
                    //                         node.data.bin_op[0],
                    //                         sub_node.data.bin_op[1],
                    //                     ));
                    //                     try oir.exit_list.insert(oir.allocator, 0, new_ret);
                    //                     return false;
                    //                 }
                    //             },
                    //             else => {},
                    //         }
                    //     }
                    //     return false;
                    // } else return false;
                }
            },
            else => std.debug.panic("TODO: constant fold {s}", .{@tagName(node.tag)}),
        }
    }

    return false;
}
--------------------------------------------------------------------------------
/src/passes/rewrite.zig:
--------------------------------------------------------------------------------
//! Contains the common rewrites pass. This pass will find basic patterns in
//! the graph, and convert them to another one. These rewrites won't always
//! be strict improvements in the graph, but they expose future passes to
//! find more advanced patterns.

const std = @import("std");
const SExpr = @import("rewrite/SExpr.zig");
const Oir = @import("../Oir.zig");
const machine = @import("rewrite/machine.zig");

const log = std.log.scoped(.rewrite);

const Node = Oir.Node;
const Class = Oir.Class;

/// A single named rewrite rule: match `from`, union in `to`.
pub const Rewrite = struct {
    name: []const u8,
    from: SExpr,
    to: SExpr,
};

/// A rewrite built from several patterns that must match simultaneously.
pub const MultiRewrite = struct {
    name: []const u8,
    from: []const MultiPattern,
};

/// One component of a `MultiRewrite`: the atom this pattern binds plus the
/// pattern itself.
pub const MultiPattern = struct {
    atom: []const u8,
    pattern: SExpr,
};

/// A successful match produced by the search machine: the variable bindings,
/// the e-class the match was rooted at, and the destination pattern to apply.
pub const Result = struct {
    bindings: Bindings,
    class: Class.Index,
    pattern: SExpr,

    pub const Bindings = std.StringHashMapUnmanaged(Class.Index);
    pub const Error = error{ OutOfMemory, InvalidCharacter, Overflow };

    fn deinit(result: *const Result, gpa: std.mem.Allocator) void {
        var copy = result.*;
        copy.bindings.deinit(gpa);
    }
};

// The rewrite table, parsed from `table.zon` entirely at comptime.
const rewrites: []const Rewrite = blk: {
    const table: []const struct {
        name: []const u8,
        from: []const u8,
        to: []const u8,
    } = @import("rewrite/table.zon");
    @setEvalBranchQuota(table.len * 20_000);
    var list: [table.len]Rewrite = undefined;
    for (&list, table) |*entry, op| {
        entry.* = Rewrite{
            .name = op.name,
            .from = SExpr.parse(op.from),
            .to = SExpr.parse(op.to),
        };
    }
    const copy = list;
    break :blk &copy;
};

/// Runs one search-then-apply iteration of every rule in the table.
/// Returns `true` if any union changed the graph.
pub fn run(oir: *Oir) !bool {
    var matches: std.ArrayListUnmanaged(Result) = .{};
    defer {
        for (matches.items) |*m| m.deinit(oir.allocator);
        matches.deinit(oir.allocator);
    }

    {
        const trace = oir.trace.start(@src(), "searching", .{});
        defer trace.end();

        for (rewrites) |rewrite| {
            try machine.search(oir, .{
                .from = rewrite.from,
                .to = rewrite.to,
                .name = rewrite.name,
            }, &matches);
        }
    }

    {
        const trace = oir.trace.start(@src(), "applying matches", .{});
        defer trace.end();

        return try applyMatches(
            oir,
            matches.items,
        );
    }
}

/// Materializes each match's destination pattern into the graph (atoms resolve
/// through the bindings, constants/nodes/builtins are added), then unions the
/// result with the matched class. Returns `true` if any union changed the graph.
fn applyMatches(oir: *Oir, matches: []const Result) !bool {
    // Per-pattern scratch: maps pattern-entry position to the class it produced.
    var ids: std.ArrayListUnmanaged(Class.Index) = .{};
    defer ids.deinit(oir.allocator);

    var changed: bool = false;
    for (matches) |m| {
        ids.clearRetainingCapacity();

        // Pattern entries are stored bottom-up, so operands are always in `ids`
        // before the node that uses them.
        for (m.pattern.nodes) |entry| {
            const id = switch (entry) {
                .atom => |v| m.bindings.get(v).?,
                .constant => |c| try oir.add(.constant(c)),
                .node => |n| b: {
                    var new = switch (n.tag) {
                        .region => unreachable, // TODO
                        inline else => |t| Node.init(t, undefined),
                    };
                    for (new.mutableOperands(oir), n.list) |*op, child| {
                        op.* = ids.items[@intFromEnum(child)];
                    }
                    break :b try oir.add(new);
                },
                .builtin => |b| b: {
                    if (b.tag.location() != .dst)
                        @panic("have non-dst builtin in destination pattern");

                    switch (b.tag) {
                        .log2 => {
                            // TODO: I'd like to figure out a way to safely access `classContains`
                            // for constants without having to rebuild the graph. In theory it should
                            // be possible, but my concern right now is that if the class index gets
                            // merged into a larger class, it will cause issues. Maybe union find
                            // makes up for that? Need to do more testing.
                            try oir.rebuild();

                            const idx = m.bindings.get(b.expr).?;

                            const node_idx = oir.classContains(idx, .constant) orelse
                                @panic("@log2 binding isn't a power of two?");
                            const node = oir.getNode(node_idx);
                            const value = node.data.constant;
                            if (value < 1) @panic("how do we handle @log2 of a negative?");
                            const log_value = std.math.log2_int(u64, @intCast(value));
                            break :b try oir.add(.constant(log_value));
                        },
                        else => unreachable,
                    }
                },
            };

            try ids.append(oir.allocator, id);
        }

        // The last entry is the pattern's root; union it with the matched class.
        const last = ids.getLast();
        if (try oir.@"union"(m.class, last)) changed = true;
    }
    return changed;
}

const expectEqual = std.testing.expectEqual;

/// Test helper: searches `oir` for `buffer` and expects exactly `num_matches` hits.
fn testSearch(oir: *const Oir, comptime buffer: []const u8, num_matches: u64) !void {
    std.debug.assert(oir.clean); // must be clean before searching

    const apply = SExpr.parse("?x");
    const pattern = SExpr.parse(buffer);

    var matches: std.ArrayListUnmanaged(Result) = .{};
    defer {
        for (matches.items) |*m| m.deinit(oir.allocator);
        matches.deinit(oir.allocator);
    }
    try machine.search(oir, .{
        .from = pattern,
        .to = apply,
        .name = "test",
    }, &matches);

    try expectEqual(num_matches, matches.items.len);
}

test "basic match" {
    const allocator = std.testing.allocator;
    var oir: Oir = .init(allocator);
    defer oir.deinit();

    // (add (mul 10 20) 30)
    _ = try oir.add(try .create(.start, &oir, &.{}));
    const a = try oir.add(.init(.constant, 10));
    const b = try oir.add(.init(.constant, 20));
    const mul = try oir.add(.binOp(.mul, a, b));
    const c = try oir.add(.init(.constant, 30));
    _ = try oir.add(.binOp(.add, mul, c));
    try oir.rebuild();

    try testSearch(&oir, "(mul 10 20)", 1);
    try testSearch(&oir, "(mul ?x 20)", 1);
    try testSearch(&oir, "(add ?x ?x)", 0);
    try testSearch(&oir, "(add 10 20)", 0);
    try testSearch(&oir, "(add ?x ?y)", 1);
    try testSearch(&oir, "(add (mul 10 20) 30)", 1);
}

test "builtin function match" {
    const allocator = std.testing.allocator;
    var oir: Oir = .init(allocator);
    defer oir.deinit();

    // (mul 37 16)
    _ = try oir.add(try .create(.start, &oir, &.{}));
    const a = try oir.add(.init(.constant, 37));
    const b = try oir.add(.init(.constant, 16));
    _ = try oir.add(.binOp(.mul, a, b));
    try oir.rebuild();

    try testSearch(&oir, "(mul ?x @known_pow2(?y))", 1);
    try testSearch(&oir, "(add ?x @known_pow2(?y))", 0);
}

test "negative known_pow2" {
    const allocator = std.testing.allocator;
    var oir: Oir = .init(allocator);
    defer oir.deinit();

    // (mul 5 -2)
    _ = try oir.add(try .create(.start, &oir, &.{}));
    const a = try oir.add(.init(.constant, 5));
    const b = try oir.add(.init(.constant, -2));
    _ = try oir.add(.binOp(.mul, a, b));
    try oir.rebuild();

    try testSearch(&oir, "(mul ?x @known_pow2(?y))", 0);
}

// test "basic multi-pattern match" {
//     const allocator = std.testing.allocator;
//     var oir: Oir = .init(allocator, &trace);
//     defer oir.deinit();

//     // (add (mul 10 20) 30)
//     _ = try oir.add(try .create(.start, &oir, &.{}));
//     const a = try oir.add(.init(.constant, 10));
//     const b = try oir.add(.init(.constant, 20));
//     const add = try oir.add(.binOp(.mul, a, b));
//     const c = try oir.add(.init(.constant, 30));
//     _ = try oir.add(.binOp(.add, add, c));
//     try oir.rebuild();

//     try machine.multiSearch(&oir, .{
//         .from = &.{
//             .{
//                 .atom = "?x",
//                 .pattern = SExpr.parse("(mul 10 20)"),
//             },
//             .{
//                 .atom = "?y",
//                 .pattern = SExpr.parse("(add ?a 30)"),
//             },
//         },
//         .name = "test",
//     });
//     // defer {
//     //     for (matches) |*m| m.deinit(oir.allocator);
//     //     oir.allocator.free(matches);
//     // }

//     // std.debug.print("n matches: {}\n", .{matches.len});
// }
--------------------------------------------------------------------------------
/src/passes/rewrite/SExpr.zig:
--------------------------------------------------------------------------------
//! Describes an S-Expression used for describing graph rewrites.
//!
//! S-Expressions can contain three things:
//!
//! - An identifier. An identifier signifies a unique value.
//! To denote an identifier, you would write any single letter, [a-zA-Z]
//! prepended with a question mark (`?`). Given the law of equivalence that E-Graphs
//! must maintain before committing rewrites, identifiers with the same letter will
//! be required to match.
//! An example usage of this could be removing exact divisions of the same values.
//! The expression for such a rewrite would look like, `(mul ?x ?x)`, and for the e-match
//! to succeed, the optimizer would need to prove that both the RHS and the LHS of
//! the nodes here are equivalent.
//!
//! - A constant. This is a numerical number, which requires the node intending to match the constraint
//! to have been proven to be equivalent to this constant. These can be written in the Zig fashion of,
//! 0x10, 0b100, 0o10, or 10
//! TODO: these values can only be `i64` currently.
//!
//! - Other S-Expressions. S-expressions are intended to be nested, and matching will consider
//! the absolute structure of the expression when pairing.
22 | 23 | nodes: []const Entry, 24 | 25 | pub const Entry = union(enum) { 26 | atom: []const u8, 27 | constant: i64, 28 | node: Node, 29 | builtin: BuiltinFn, 30 | 31 | pub const Node = struct { 32 | tag: NodeTag, 33 | list: []const Index, 34 | }; 35 | 36 | const FormatCtx = struct { 37 | entry: Entry, 38 | expr: SExpr, 39 | }; 40 | 41 | pub fn operands(e: Entry) []const Index { 42 | return switch (e) { 43 | .builtin, .atom, .constant => &.{}, 44 | .node => |n| n.list, 45 | }; 46 | } 47 | 48 | pub fn tag(e: Entry) NodeTag { 49 | return switch (e) { 50 | .atom => unreachable, 51 | .builtin => unreachable, 52 | .constant => .constant, 53 | .node => |n| n.tag, 54 | }; 55 | } 56 | 57 | pub fn matches(e: Entry, n: Oir.Node, oir: *const Oir) bool { 58 | if (e == .builtin) { 59 | switch (e.builtin.tag) { 60 | .known_pow2 => { 61 | if (n.tag != .constant) return false; 62 | const value = n.data.constant; 63 | if (value > 0 and std.math.isPowerOfTwo(value)) return true; 64 | return false; 65 | }, 66 | else => @panic("TODO"), 67 | } 68 | } 69 | if (n.tag != e.tag()) return false; 70 | if (n.operands(oir).len != e.operands().len) return false; 71 | if (n.tag == .constant and n.data.constant != e.constant) return false; 72 | return true; 73 | } 74 | 75 | pub fn map( 76 | e: Entry, 77 | allocator: std.mem.Allocator, 78 | m: *const std.AutoHashMapUnmanaged(Index, Index), 79 | ) !Entry { 80 | return switch (e) { 81 | .atom, 82 | .constant, 83 | .builtin, 84 | => e, 85 | .node => |n| n: { 86 | const new_operands = try allocator.dupe(Index, n.list); 87 | for (new_operands) |*op| op.* = m.get(op.*).?; 88 | break :n .{ .node = .{ .tag = n.tag, .list = new_operands } }; 89 | }, 90 | }; 91 | } 92 | 93 | pub fn format2( 94 | ctx: FormatCtx, 95 | comptime _: []const u8, 96 | _: std.fmt.FormatOptions, 97 | writer: anytype, 98 | ) !void { 99 | switch (ctx.entry) { 100 | .atom => |atom| try writer.writeAll(atom), 101 | .constant => |constant| try writer.print("{}", .{constant}), 102 | 
.builtin => |b| try writer.print("@{s}({s})", .{ @tagName(b.tag), b.expr }), 103 | .node => |node| { 104 | try writer.print("({s}", .{@tagName(node.tag)}); 105 | for (node.list) |index| { 106 | try writer.print( 107 | " {}", 108 | .{ctx.expr.nodes[@intFromEnum(index)].fmt(ctx.expr)}, 109 | ); 110 | } 111 | try writer.writeAll(")"); 112 | }, 113 | } 114 | } 115 | 116 | pub fn fmt(entry: Entry, expr: SExpr) std.fmt.Formatter(format2) { 117 | return .{ .data = .{ 118 | .expr = expr, 119 | .entry = entry, 120 | } }; 121 | } 122 | 123 | pub fn format( 124 | entry: Entry, 125 | comptime _: []const u8, 126 | _: std.fmt.FormatOptions, 127 | writer: anytype, 128 | ) !void { 129 | switch (entry) { 130 | .atom => |v| try writer.writeAll(v), 131 | .constant => |c| try writer.print("{}", .{c}), 132 | .builtin => |b| try writer.print("@{s}({s})", .{ @tagName(b.tag), b.expr }), 133 | .node => |list| { 134 | try writer.print("({s} ", .{@tagName(list.tag)}); 135 | for (list.list, 0..) |child, i| { 136 | try writer.print("%{d}", .{@intFromEnum(child)}); 137 | if (i != list.list.len - 1) try writer.writeAll(", "); 138 | } 139 | try writer.writeByte(')'); 140 | }, 141 | } 142 | } 143 | }; 144 | 145 | pub const Index = enum(u32) { 146 | _, 147 | }; 148 | 149 | const BuiltinFn = struct { 150 | tag: Tag, 151 | expr: []const u8, 152 | 153 | const Tag = enum { 154 | known_pow2, 155 | log2, 156 | 157 | pub fn location(tag: Tag) Location { 158 | return switch (tag) { 159 | .known_pow2 => .src, 160 | .log2 => .dst, 161 | }; 162 | } 163 | }; 164 | 165 | /// Describes where this builtin can be used. 166 | const Location = enum { 167 | /// This builtin can be used in the source expression, during matching. 168 | /// 169 | /// Its parameter is the name of the identifier that will be set in the bindings 170 | /// when it finds that node/constant. 171 | /// 172 | /// `(mul ?x @known_pow2(y))` will create bindings where `y` is the constant node 173 | /// that was proven to be a known power of two. 
174 | src, 175 | /// This builtin can be used in the destination expression, during applying. 176 | /// 177 | /// Its parameter is a link to the name of the identifier that was found during matching. 178 | /// 179 | /// `(shl ?x @log2(y))` will search up for `y` in the bindings and take the log2 of 180 | /// the constant node that was found. 181 | dst, 182 | }; 183 | }; 184 | 185 | /// TODO: better error reporting! 186 | pub const Parser = struct { 187 | nodes: []Entry = &.{}, 188 | buffer: []const u8, 189 | index: u32 = 0, 190 | 191 | pub fn parseInternal(comptime parser: *Parser) Index { 192 | @setEvalBranchQuota(parser.buffer.len * 1_000); 193 | while (parser.index < parser.buffer.len) { 194 | const c = parser.eat(); 195 | switch (c) { 196 | // the start of an expression. we're expecting to 197 | // have the expression tag next, i.e (mul ... 198 | '(' => { 199 | const tag_start = parser.index; 200 | // the space is what seperates the tag from the rest of the expression 201 | try parser.eatUntilDelimiter(' '); 202 | const tag_end = parser.index; 203 | const tag_string = parser.buffer[tag_start..tag_end]; 204 | const tag = std.meta.stringToEnum(NodeTag, tag_string) orelse 205 | @compileError("unknown tag"); 206 | // now there will be a list of parameters to this expression 207 | // i.e (mul ?x ?y), where ?x and ?y are the parameters. 208 | // these are delimited by the right paren. 209 | var list: []const Index = &.{}; 210 | while (parser.peak() != ')') { 211 | if (parser.index == parser.buffer.len) { 212 | @compileError("no closing paren"); 213 | } 214 | // this should only happen when an expression was parsed 215 | // but a space wasn't provided after it. so, (mul ?x?y). 216 | // the issue here is that identifiers are only allowed to 217 | // have single letter names, so `?y?` would be in a second 218 | // loop. 
if (parser.peak() != ' ') {
                            @compileLog("no space after arg");
                        }
                        // Eat the separating space before parsing the next argument.
                        assert(parser.eat() == ' ');
                        const expr = parser.parseInternal();
                        list = list ++ .{expr};
                    }
                    // Close off the expression with its matching parenthesis.
                    assert(parser.eat() == ')');
                    if (list.len == 0) {
                        @compileLog("no expression arguments");
                    }

                    return parser.addEntry(.{ .node = .{ .tag = tag, .list = list } });
                },
                // The start of an identifier, e.g. `?x`.
                '?' => {
                    if (!std.ascii.isAlphabetic(parser.peak())) {
                        // the next character must be a letter, since only
                        // identifiers start with question marks
                        @compileLog("question mark without letter");
                    }

                    // The `- 1` includes the `?` itself, which is used later in
                    // the pipeline to recognize pattern variables.
                    const ident_start = parser.index - 1;
                    while (parser.index < parser.buffer.len and
                        std.mem.indexOfScalar(u8, ident_delim, parser.peak()) == null)
                    {
                        parser.index += 1;
                    }
                    const ident_end = parser.index;

                    const ident = parser.buffer[ident_start..ident_end];
                    if (ident.len != 2) {
                        // Identifiers are limited to a single letter; together
                        // with the leading `?` that is a total length of 2.
                        @compileLog("ident too long");
                    }

                    return parser.addEntry(.{ .atom = ident });
                },
                // An integer constant.
                '0'...'9' => {
                    // The `- 1` includes the digit that was already eaten.
                    const constant_start = parser.index - 1;
                    while (parser.index < parser.buffer.len and
                        std.mem.indexOfScalar(u8, ident_delim, parser.peak()) == null)
                    {
                        parser.index += 1;
                    }
                    const constant_end = parser.index;
                    const constant = parser.buffer[constant_start..constant_end];

                    // Base 0 also accepts prefixed literals such as `0x10`.
                    const value: i64 = try std.fmt.parseInt(i64, constant, 0);
                    return parser.addEntry(.{ .constant = value });
                },
                // The start of a builtin function, e.g. `@known_pow2(y)`.
                '@' => {
                    const builtin_start = parser.index;
                    try parser.eatUntilDelimiter('(');
                    const builtin_end = parser.index;
                    _ = parser.eat();

                    const builtin_name = parser.buffer[builtin_start..builtin_end];
                    const builtin_tag = std.meta.stringToEnum(BuiltinFn.Tag, builtin_name) orelse
                        @compileError("unknown builtin function");

                    const param_start = parser.index;
                    try parser.eatUntilDelimiter(')');
                    const param_end = parser.index;

                    const param = parser.buffer[param_start..param_end];

                    return parser.addEntry(.{ .builtin = .{
                        .tag = builtin_tag,
                        .expr = param,
                    } });
                },
                else => @compileError("unknown character: '" ++ .{c} ++ "'"),
            }
        }
        @compileError("unexpected end of expression");
    }

    /// Appends `entry` to the comptime node list and returns its index.
    fn addEntry(parser: *Parser, entry: Entry) Index {
        const index: Index = @enumFromInt(parser.nodes.len);
        var copy = (parser.nodes ++ (&entry)[0..1])[0..].*;
        parser.nodes = &copy;
        return index;
    }

    /// Consumes and returns the current character, advancing the cursor.
    fn eat(parser: *Parser) u8 {
        const char = parser.peak();
        parser.index += 1;
        return char;
    }

    /// Returns the current character without advancing the cursor.
    /// Asserts the cursor is in bounds.
    /// NOTE(review): "peak" is a typo for "peek"; kept, since callers earlier
    /// in this file (outside this hunk) may use the existing name.
    fn peak(parser: *Parser) u8 {
        return parser.buffer[parser.index];
    }

    /// Advances the cursor until it points at `delem` (the cursor is left
    /// *on* the delimiter). Returns `error.OutOfBounds` if `delem` never
    /// appears in the remaining input.
    fn eatUntilDelimiter(parser: *Parser, delem: u8) !void {
        // Check the bound *before* peeking: the previous version indexed
        // `buffer[index]` in the loop condition first, so a missing
        // delimiter tripped an out-of-bounds access and the
        // `error.OutOfBounds` below could never actually be returned.
        while (parser.index < parser.buffer.len) : (parser.index += 1) {
            if (parser.peak() == delem) return;
        }
        return error.OutOfBounds;
    }

    /// The characters that can delimit an identifier.
    const ident_delim: []const u8 = &.{ ' ', ')' };
};

/// Parses `buffer` into a comptime-constructed SExpr.
pub inline fn parse(comptime buffer: []const u8) SExpr {
    comptime {
        var parser: Parser = .{ .buffer = buffer };
        _ = parser.parseInternal();
        const copy = parser.nodes[0..].*;
        return .{ .nodes = &copy };
    }
}

/// Index of the root node. Entries are appended bottom-up, so the root is
/// always the last entry.
pub fn root(expr: SExpr) Index {
    return @enumFromInt(expr.nodes.len - 1);
}

pub fn get(expr: SExpr, idx: Index) Entry {
    return expr.nodes[@intFromEnum(idx)];
}

/// Frees the node list and each interior node's operand list.
pub fn deinit(expr: SExpr, allocator: std.mem.Allocator) void {
    for (expr.nodes) |node| {
        switch (node) {
            .node => |n| allocator.free(n.list),
            else => {},
        }
    }
    allocator.free(expr.nodes);
}

pub fn format(
    expr: SExpr,
    comptime fmt: []const u8,
    _: std.fmt.FormatOptions,
    writer: anytype,
) !void {
    comptime assert(fmt.len == 0);

    const r: Entry = expr.get(expr.root());
    try writer.print("{}", .{r.fmt(expr)});
}

/// Whether the expression is a pattern variable such as `?x`.
/// NOTE(review): this reads `expr.data`, but SExpr stores its entries in
/// `nodes` everywhere else in this file — this looks stale/dead; confirm it
/// is still referenced (Zig's lazy analysis would hide the mismatch).
pub fn isIdent(expr: *const SExpr) bool {
    return expr.data == .atom and expr.data.atom[0] == '?';
}

test "single-layer, multi-variable" {
    const expr = comptime SExpr.parse("(mul ?x ?y)");

    const root_node = expr.get(expr.root());

    try expect(root_node == .node and root_node.node.tag == .mul);

    const lhs = expr.nodes[0];
    const rhs = expr.nodes[1];

    try expect(lhs == .atom);
    try expect(std.mem.eql(u8, lhs.atom, "?x"));

    try expect(rhs == .atom);
    try expect(std.mem.eql(u8, rhs.atom, "?y"));
}

test "single-layer, single variable single constant" {
    const expr = comptime SExpr.parse("(mul 10 ?y)");

    const root_node = expr.get(expr.root());

    try expect(root_node == .node and root_node.node.tag == .mul);

    const lhs = expr.nodes[0];
    const rhs = expr.nodes[1];

    try expect(lhs == .constant);
    try expect(lhs.constant == 10);

    try expect(rhs == .atom);
    try expect(std.mem.eql(u8, rhs.atom, "?y"));
}

test "multi-layer, multi-variable" {
    @setEvalBranchQuota(20_000);
    const expr = comptime SExpr.parse("(div_exact ?z (mul ?x ?y))");

    const root_node = expr.get(expr.root());

    try expect(root_node == .node and root_node.node.tag == .div_exact);

    const lhs = expr.get(root_node.node.list[0]);
    const rhs = expr.get(root_node.node.list[1]);

    try expect(lhs == .atom);
    try expect(std.mem.eql(u8, lhs.atom, "?z"));

    try expect(rhs == .node);
    try expect(rhs.node.tag == .mul);

    const mul_lhs = expr.get(rhs.node.list[0]);
    const mul_rhs = expr.get(rhs.node.list[1]);

    try expect(mul_lhs == .atom);
    try expect(std.mem.eql(u8, mul_lhs.atom, "?x"));

    try expect(mul_rhs == .atom);
    try expect(std.mem.eql(u8, mul_rhs.atom, "?y"));
}

test "builtin function" {
    const expr = comptime SExpr.parse("(mul ?x @known_pow2(y))");

    const root_node = expr.get(expr.root());

    try expect(root_node == .node and root_node.node.tag == .mul);

    const lhs = expr.get(root_node.node.list[0]);
    const rhs = expr.get(root_node.node.list[1]);

    try expect(lhs == .atom);
    try expect(std.mem.eql(u8, lhs.atom, "?x"));

    try expect(rhs == .builtin);
    try expect(rhs.builtin.tag == .known_pow2);
    try expect(std.mem.eql(u8, "y", rhs.builtin.expr));
}

const SExpr = @This();
const Oir = @import("../../Oir.zig");
const NodeTag = Oir.Node.Tag;
const std = @import("std");
const expect = std.testing.expect;
const assert = std.debug.assert;

// -----------------------------------------------------------------------------
// /src/passes/rewrite/machine.zig
// -----------------------------------------------------------------------------

//! Implements E-matching through an Abstract Virtual Machine.
//!
//! Based on this paper: https://leodemoura.github.io/files/ematching.pdf

const std = @import("std");
const Oir = @import("../../Oir.zig");
const rewrite = @import("../rewrite.zig");
const SExpr = @import("SExpr.zig");

const Node = Oir.Node;
const Class = Oir.Class;
const Result = rewrite.Result;
const Rewrite = rewrite.Rewrite;
const MultiRewrite = rewrite.MultiRewrite;

/// Compiles a rewrite pattern into a linear `Program` of abstract-machine
/// instructions (`bind` / `lookup` / `compare` / `scan`).
const Compiler = struct {
    /// Next unallocated virtual register.
    next_reg: Reg,
    instructions: std.ArrayListUnmanaged(Instruction),
    /// Pattern nodes still awaiting compilation, keyed by (node index, the
    /// register that will hold the e-class they must match).
    todo_nodes: std.AutoHashMapUnmanaged(struct { SExpr.Index, Reg }, SExpr.Entry),
    /// Maps pattern variables (e.g. "?x") to the register holding their class.
    v2r: std.StringHashMapUnmanaged(Reg),
    /// `free_vars.items[i]` = the set of variables occurring in the subtree
    /// rooted at pattern node `i` (filled by `loadPattern`).
    free_vars: std.ArrayListUnmanaged(std.StringArrayHashMapUnmanaged(void)),
    /// `subtree_size.items[i]` = number of operator/constant nodes in the
    /// subtree rooted at pattern node `i` (filled by `loadPattern`).
    subtree_size: std.ArrayListUnmanaged(u64),

    const Todo = struct { SExpr.Index, Reg };

    /// Compiles `pattern` into `c.instructions`. `bind`, when non-null, names
    /// a variable the pattern's root is bound to (shared across patterns of a
    /// multi-pattern via `v2r`).
    fn compile(
        c: *Compiler,
        allocator: std.mem.Allocator,
        bind: ?[]const u8,
        pattern: SExpr,
    ) !void {
        try c.loadPattern(allocator, pattern);
        const root = pattern.root();

        if (bind) |v| {
            if (c.v2r.get(v)) |i| {
                // The bound variable already has a register; constrain the
                // root against it rather than scanning fresh.
                try c.addTodo(allocator, pattern, root, i);
            } else {
                try c.addPattern(allocator, pattern, root);
                try c.v2r.put(allocator, v, c.next_reg);
                c.next_reg.add(1);
            }
        } else {
            try c.addPattern(allocator, pattern, root);
            c.next_reg.add(1);
        }

        // Repeatedly pick the "best" pending node (see `next`) and emit an
        // instruction for it.
        while (c.next()) |entry| {
            const todo, const node = entry;
            const id, const reg = todo;

            if (c.isGrounded(id) and node.operands().len != 0) {
                // Every variable under this node is already bound: the term
                // can be fully instantiated, so emit a direct lookup instead
                // of enumerating e-nodes.
                const new_node = try newRoot(pattern, allocator, id);
                try c.instructions.append(allocator, .{ .lookup = .{
                    .i = reg,
                    .term = new_node,
                } });
            } else {
                // Enumerate matching e-nodes, binding the children into a
                // fresh contiguous run of registers starting at `out`.
                const out = c.next_reg;
                c.next_reg.add(@intCast(node.operands().len));

                try c.instructions.append(allocator, .{ .bind = .{
                    .node = node,
                    .i = reg,
                    .out = out,
                } });

                for (node.operands(), 0..) |child, i| {
                    try c.addTodo(
                        allocator,
                        pattern,
                        child,
                        @enumFromInt(@intFromEnum(out) + i),
                    );
                }
            }
        }
    }

    /// Queues `root` for compilation into a fresh register; a `scan` over all
    /// e-classes is emitted for every pattern after the first.
    fn addPattern(
        c: *Compiler,
        allocator: std.mem.Allocator,
        pattern: SExpr,
        root: SExpr.Index,
    ) !void {
        if (c.instructions.items.len != 0) {
            try c.instructions.append(allocator, .{ .scan = c.next_reg });
        }
        try c.addTodo(allocator, pattern, root, c.next_reg);
    }

    /// Whether every variable in the subtree at `id` already has a register.
    fn isGrounded(c: *Compiler, id: SExpr.Index) bool {
        for (c.free_vars.items[@intFromEnum(id)].keys()) |v| {
            if (!c.v2r.contains(v)) return false;
        }
        return true;
    }

    /// Clones and owner must free Program.
    ///
    /// TODO: should just ref and deinit with Compiler, unsure if there's anything stopping that.
    fn extract(c: *Compiler, allocator: std.mem.Allocator) !Program {
        return .{
            .instructions = try c.instructions.toOwnedSlice(allocator),
            .map = try c.v2r.clone(allocator),
        };
    }

    /// Precomputes `free_vars` and `subtree_size` for every pattern node.
    /// Relies on the SExpr invariant that children precede their parents.
    fn loadPattern(c: *Compiler, allocator: std.mem.Allocator, pattern: SExpr) !void {
        const len = pattern.nodes.len;
        try c.free_vars.ensureTotalCapacityPrecise(allocator, len);
        try c.subtree_size.ensureTotalCapacityPrecise(allocator, len);

        for (pattern.nodes) |node| {
            var free: std.StringArrayHashMapUnmanaged(void) = .{};
            var size: usize = 0;

            switch (node) {
                .node => |n| {
                    size = 1;
                    for (n.list) |child| {
                        // Children were already processed (bottom-up order).
                        for (c.free_vars.items[@intFromEnum(child)].keys()) |fv| {
                            try free.put(allocator, fv, {});
                        }
                        size += c.subtree_size.items[@intFromEnum(child)];
                    }
                },
                .constant => size = 1,
                .builtin => |b| try free.put(allocator, b.expr, {}),
                .atom => |v| try free.put(allocator, v, {}),
            }
            try c.free_vars.append(allocator, free);
            try c.subtree_size.append(allocator, size);
        }
    }

    /// Queues pattern node `id` to be matched against the class in `reg`.
    /// A variable seen a second time emits a `compare` against its first
    /// register instead of a new todo entry.
    fn addTodo(
        c: *Compiler,
        allocator: std.mem.Allocator,
        pattern: SExpr,
        id: SExpr.Index,
        reg: Reg,
    ) !void {
        const node = pattern.get(id);
        switch (node) {
            inline .builtin, .atom => |sub, t| {
                const v = switch (t) {
                    .builtin => sub.expr,
                    .atom => sub,
                    else => unreachable,
                };
                if (c.v2r.get(v)) |j| {
                    // Repeated variable: both registers must hold the same
                    // e-class.
                    try c.instructions.append(allocator, .{ .compare = .{
                        .i = reg,
                        .j = j,
                    } });
                } else {
                    try c.v2r.put(allocator, v, reg);
                    switch (t) {
                        .builtin => try c.todo_nodes.put(allocator, .{ id, reg }, node),
                        else => {},
                    }
                }
            },
            .node,
            .constant,
            => try c.todo_nodes.put(allocator, .{ id, reg }, node),
        }
    }

    /// Removes and returns the next pending node to compile, chosen by the
    /// heuristic in the body of `next` (continued below).
    fn next(c: *Compiler) ?struct { Todo,
SExpr.Entry } {
        if (c.todo_nodes.count() == 0) return null;

        // Ranking record for choosing which pending node to compile next.
        const Fill = struct {
            grounded: bool,
            free: u64,
            size: u64,
            node: ?Todo,

            // NOTE(review): when `new` is grounded and `old` is not, this
            // still falls through to the free/size checks rather than
            // returning true immediately — confirm that matches the intended
            // "prefer grounded first" ordering.
            fn better(new: @This(), old: @This()) bool {
                // Prefer grounded.
                if (old.grounded == true and new.grounded == false) return false;
                if (new.free < old.free) return false;
                if (new.size > old.size) return false;
                return true;
            }
        };

        var best: Fill = .{
            .grounded = false,
            // Prefer more free variables.
            .free = 0,
            // Prefer smaller terms.
            .size = std.math.maxInt(u64),
            .node = null,
        };

        var iter = c.todo_nodes.keyIterator();
        while (iter.next()) |node| {
            const id = node.@"0";
            const vars = c.free_vars.items[@intFromEnum(id)];
            // Count how many of this subtree's variables are already bound.
            var n_bound: usize = 0;
            for (vars.keys()) |v| {
                if (c.v2r.contains(v)) n_bound += 1;
            }
            const n_free = vars.count() - n_bound;
            const size = c.subtree_size.items[@intFromEnum(id)];

            const new: Fill = .{
                .grounded = n_free == 0,
                .free = n_free,
                .size = size,
                .node = node.*,
            };

            if (best.node == null or new.better(best)) {
                best = new;
            }
        }

        // `best.node` is non-null here: the map was non-empty, so at least
        // one candidate was considered.
        const removed = c.todo_nodes.fetchRemove(best.node.?).?.value;
        return .{ best.node.?, removed };
    }

    fn deinit(c: *Compiler, allocator: std.mem.Allocator) void {
        c.instructions.deinit(allocator);
        c.v2r.deinit(allocator);
        c.todo_nodes.deinit(allocator);
        for (c.free_vars.items) |*set| {
            set.deinit(allocator);
        }
        c.free_vars.deinit(allocator);
        c.subtree_size.deinit(allocator);
    }
};

/// Extracts an S-expression that starts from new_root and contains its children.
///
/// Caller must free the expr.
pub fn newRoot(
    expr: SExpr,
    allocator: std.mem.Allocator,
    new_root_idx: SExpr.Index,
) !SExpr {
    var list: std.ArrayListUnmanaged(SExpr.Entry) = .{};
    defer list.deinit(allocator);

    // Worklist of nodes whose children must be copied before they are;
    // `map` translates old indices into the new (compacted) index space.
    var queue: std.ArrayListUnmanaged(SExpr.Index) = .{};
    defer queue.deinit(allocator);
    var map: std.AutoHashMapUnmanaged(SExpr.Index, SExpr.Index) = .{};
    defer map.deinit(allocator);

    const new_root = expr.get(new_root_idx);

    try queue.appendSlice(
        allocator,
        new_root.operands(),
    );

    while (queue.getLastOrNull()) |id| {
        if (map.contains(id)) {
            _ = queue.pop();
            continue;
        }

        const node = expr.get(id);

        // Only emit a node once every child has been remapped; otherwise
        // push the missing children and revisit.
        var resolved: bool = true;
        for (node.operands()) |child| {
            if (!map.contains(child)) {
                resolved = false;
                try queue.append(allocator, child);
            }
        }

        if (resolved) {
            const new_node = try node.map(allocator, &map);
            const new_id: SExpr.Index = @enumFromInt(list.items.len);
            try list.append(allocator, new_node);
            try map.put(allocator, id, new_id);
            _ = queue.pop();
        }
    }

    // The root goes last, preserving the "root is the final entry" invariant.
    const new_root_node = try new_root.map(allocator, &map);
    try list.append(allocator, new_root_node);
    return .{ .nodes = try list.toOwnedSlice(allocator) };
}

/// A compiled pattern: the instruction sequence plus the variable->register
/// assignment produced by the `Compiler`.
const Program = struct {
    instructions: []const Instruction,
    map: std.StringHashMapUnmanaged(Reg),

    /// Runs the program against every e-class in `oir`, appending all
    /// matches of `rw.from` (paired with `rw.to`) to `matches`.
    fn search(
        p: *Program,
        rw: Rewrite,
        oir: *const Oir,
        matches: *std.ArrayListUnmanaged(Result),
    ) !void {
        const pattern = rw.from;
        var iter = oir.classes.valueIterator();
        const root = pattern.get(pattern.root());
        while (iter.next()) |class| {
            switch (root) {
                // A bare-constant pattern matches any class containing that
                // constant; no machine run is needed.
                .constant => |value| if (oir.classContains(class.index, .constant)) |idx| {
                    const node_value = oir.getNode(idx).data.constant;
                    if (value
== node_value) {
                        try matches.append(oir.allocator, .{
                            .bindings = .{},
                            .class = class.index,
                            .pattern = rw.to,
                        });
                    }
                },
                .node => |n| if (oir.classContainsAny(class.index, n.tag)) {
                    try p.searchClass(oir, class.index, rw.to, matches);
                },
                .atom => @panic("TODO: non-node root"),
                .builtin => @panic("can't have root be a builtin function"),
            }
        }
    }

    // fn searchMulti(
    //     p: *Program,
    //     oir: *const Oir,
    //     matches: *std.ArrayListUnmanaged(Result),
    // ) !void {
    //     var iter = oir.classes.valueIterator();
    //     while (iter.next()) |class| {
    //         try p.searchClass(oir, class.index, SExpr.parse("10"), matches);
    //     }
    // }

    /// Runs the abstract machine rooted at `class`, recording one `Result`
    /// (bindings + target `pattern`) per successful match.
    fn searchClass(
        p: *Program,
        oir: *const Oir,
        class: Class.Index,
        pattern: SExpr,
        matches: *std.ArrayListUnmanaged(Result),
    ) !void {
        std.debug.assert(oir.clean); // must be clean to search
        const allocator = oir.allocator;

        var machine: Machine = .{
            .registers = .{},
            .v2r = &p.map,
            .lookup = .{},
        };
        defer machine.deinit(allocator);
        // Register 0 always holds the class the pattern root must match.
        try machine.registers.append(allocator, class);

        var results: std.ArrayListUnmanaged(Result.Bindings) = .{};
        defer results.deinit(allocator);

        try machine.run(oir, p.instructions, p.map, &results);

        for (results.items) |result| {
            try matches.append(allocator, .{
                .bindings = result,
                .class = class,
                .pattern = pattern,
            });
        }
    }

    fn deinit(p: *Program, allocator: std.mem.Allocator) void {
        for (p.instructions) |inst| inst.deinit(allocator);
        allocator.free(p.instructions);
        p.map.deinit(allocator);
    }
};

/// The abstract VM that executes a compiled `Program` against the e-graph.
/// Backtracking is implemented by recursive calls to `run` on the remaining
/// instruction slice.
const Machine = struct {
    /// Virtual registers; each holds an e-class index.
    registers: std.ArrayListUnmanaged(Class.Index),
    /// Owned by the parent `Program`.
    v2r: *const std.StringHashMapUnmanaged(Reg),
    /// Scratch space for evaluating `lookup` terms bottom-up.
    lookup: std.ArrayListUnmanaged(Class.Index),

    /// Executes `insts`; every complete execution appends one set of
    /// variable bindings to `matches`. Returning early means "no match on
    /// this branch".
    fn run(
        m: *Machine,
        oir: *const Oir,
        insts: []const Instruction,
        map: std.StringHashMapUnmanaged(Reg),
        matches: *std.ArrayListUnmanaged(Result.Bindings),
    ) !void {
        for (insts, 1..) |inst, i| {
            switch (inst) {
                .bind => |bind| {
                    const class = oir.getClass(m.registers.items[@intFromEnum(bind.i)]);
                    // Backtrack over *every* e-node in the class: each match
                    // re-runs the remaining instructions with the node's
                    // operands bound starting at `bind.out`.
                    for (class.bag.items) |node_idx| {
                        const node = oir.getNode(node_idx);
                        if (bind.node.matches(node, oir)) {
                            m.registers.shrinkRetainingCapacity(@intFromEnum(bind.out));
                            for (node.operands(oir)) |id| {
                                try m.registers.append(oir.allocator, id);
                            }
                            // run for remaining instructions
                            try m.run(oir, insts[i..], map, matches);
                        }
                    }
                    // BUGFIX: this `return` used to sit inside the loop
                    // above, so only the first e-node of each class was ever
                    // considered and further matches were silently dropped.
                    // The bind instruction must try all candidates before
                    // giving up on this branch.
                    return;
                },
                .lookup => |lookup| {
                    m.lookup.clearRetainingCapacity();

                    // Evaluate the grounded term bottom-up; `m.lookup` acts
                    // as the value stack, indexed by the term's own indices.
                    for (lookup.term.nodes) |node| {
                        switch (node) {
                            .atom => |v| {
                                const reg = m.v2r.get(v).?;
                                try m.lookup.append(
                                    oir.allocator,
                                    oir.union_find.find(m.registers.items[@intFromEnum(reg)]),
                                );
                            },
                            .constant => |c| {
                                const found_idx = oir.findNode(.constant(c)) orelse return; // can't match
                                const class_id = oir.findClass(found_idx);
                                try m.lookup.append(oir.allocator, class_id);
                            },
                            .builtin => unreachable, // NOTE: not really unreachable i think, dunno
                            .node => |n| {
                                var new_node = switch (n.tag) {
                                    .region => unreachable,
                                    inline else => |t| Node.init(t, undefined),
                                };
                                for (new_node.mutableOperands(oir), n.list) |*op, l| {
                                    op.* = m.lookup.items[@intFromEnum(l)];
                                }
                                const found_idx = oir.findNode(new_node) orelse return; // can't match
                                const class_id = oir.findClass(found_idx);
                                try m.lookup.append(oir.allocator, class_id);
                            },
                        }
                    }

                    // The instantiated term must land in the class held by
                    // `lookup.i` for this branch to match.
                    const id = oir.union_find.find(m.registers.items[@intFromEnum(lookup.i)]);
                    if (m.lookup.getLastOrNull() != id) {
                        return; // no match
                    }
                },
                .compare => |compare| {
                    const a = m.registers.items[@intFromEnum(compare.i)];
                    const b = m.registers.items[@intFromEnum(compare.j)];
                    if (oir.union_find.find(a) != oir.union_find.find(b)) {
                        return; // no match
                    }
                },
                .scan => |scan| {
                    // Backtrack over every e-class in the graph.
                    var iter = oir.classes.valueIterator();
                    while (iter.next()) |class| {
                        m.registers.shrinkRetainingCapacity(@intFromEnum(scan));
                        try m.registers.append(oir.allocator, class.index);
                        try m.run(oir, insts[i..], map, matches);
                    }
                    return;
                },
            }
        }

        // matched! All instructions ran without bailing: snapshot the
        // variable bindings from the registers.

        var result: std.StringHashMapUnmanaged(Class.Index) = .{};
        var iter = map.iterator();
        while (iter.next()) |entry| {
            const class_id = m.registers.items[@intFromEnum(entry.value_ptr.*)];
            try result.put(oir.allocator, entry.key_ptr.*, class_id);
        }
        try matches.append(oir.allocator, result);
    }

    fn deinit(m: *Machine, allocator: std.mem.Allocator) void {
        m.registers.deinit(allocator);
        m.lookup.deinit(allocator);
    }
};

/// One abstract-machine instruction; see the e-matching paper for semantics.
const Instruction = union(enum) {
    /// Enumerate e-nodes of the class in `i` matching `node`; bind operands
    /// into registers starting at `out`.
    bind: struct { node: SExpr.Entry, i: Reg, out: Reg },
    /// Instantiate the grounded `term` and require it to be in class `i`.
    lookup: struct { term: SExpr, i: Reg },
    /// Require registers `i` and `j` to hold the same e-class.
    compare: struct { i: Reg, j: Reg },
    /// Enumerate all e-classes into the given register.
    scan: Reg,

    pub fn format(
        inst: Instruction,
        comptime _: []const u8,
        _: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        switch (inst) {
            .bind => |b| try writer.print("bind({}, ${}, ${})", .{
                b.node,
                @intFromEnum(b.i),
                @intFromEnum(b.out),
            }),
            .lookup => |l| try writer.print("lookup({}, ${})", .{
                l.term,
                @intFromEnum(l.i),
            }),
            .compare => |c| try writer.print("compare(${} vs ${})", .{
                @intFromEnum(c.i),
                @intFromEnum(c.j),
            }),
            .scan => |s| try writer.print("scan(${})", .{@intFromEnum(s)}),
        }
    }

    fn deinit(inst: Instruction, allocator: std.mem.Allocator) void {
        switch (inst) {
            // `lookup` owns its instantiated term (see `newRoot`).
            .lookup => |l| l.term.deinit(allocator),
            else => {},
        }
    }
};

/// A virtual register index.
const Reg = enum(u32) {
    _,

    pub fn add(r: *Reg, n: u32) void {
        r.* = @enumFromInt(@intFromEnum(r.*) + n);
    }
};

/// Compiles `rw.from` and searches the whole e-graph for matches, appending
/// them (paired with `rw.to`) to `matches`.
pub fn search(
    oir: *const Oir,
    rw: Rewrite,
    matches: *std.ArrayListUnmanaged(Result),
) Result.Error!void {
    const allocator = oir.allocator;

    var compiler: Compiler = .{
        .next_reg = @enumFromInt(0),
        .instructions = .{},
        .todo_nodes = .{},
        .v2r = .{},
        .free_vars = .{},
        .subtree_size = .{},
    };
    defer compiler.deinit(allocator);

    try compiler.compile(allocator, null, rw.from);

    var program = try compiler.extract(allocator);
    defer program.deinit(allocator);

    try program.search(rw, oir, matches);
}

// pub fn multiSearch(oir: *const Oir, mrw: MultiRewrite) Result.Error!void {
//     const allocator = oir.allocator;

//     var compiler: Compiler = .{
//         .next_reg = @enumFromInt(0),
//         .instructions = .{},
//         .todo_nodes = .{},
//         .v2r = .{},
//         .free_vars = .{},
//         .subtree_size = .{},
//     };
//     defer compiler.deinit(allocator);

//     for (mrw.from) |rw| {
//         try compiler.compile(allocator, rw.atom, rw.pattern);
//     }

//     var program = try compiler.extract(allocator);
//     defer program.deinit(allocator);

//     std.debug.print("program: {any}\n", .{program.instructions});

//     var matches: std.ArrayListUnmanaged(Result) = .{};
//     try program.searchMulti(oir, &matches);

//     std.debug.print("num: {}\n", .{matches.items.len});
// }
// -----------------------------------------------------------------------------
// /src/passes/rewrite/table.zon
// -----------------------------------------------------------------------------

// Rewrite rules consumed by the rewrite pass: each entry rewrites occurrences
// of the `from` pattern into the `to` pattern (`?x`-style pattern variables,
// `@`-builtins as defined in SExpr.zig).
.{
    .{
        .name = "comm-mul",
        .from = "(mul ?x ?y)",
        .to = "(mul ?y ?x)",
    },
    .{
        .name = "comm-add",
        .from = "(add ?x ?y)",
        .to = "(add ?y ?x)",
    },
    .{
        .name = "mul-to-shl",
        .from = "(mul ?x @known_pow2(y))",
        .to = "(shl ?x @log2(y))",
    },
    .{
        .name = "zero-add",
        .from = "(add ?x 0)",
        .to = "?x",
    },
    .{
        .name = "double",
        .from = "(add ?x ?x)",
        .to = "(mul ?x 2)",
    },
    .{
        .name = "zero-mul",
        .from = "(mul ?x 0)",
        .to = "0",
    },
    .{
        .name = "one-mul",
        .from = "(mul ?x 1)",
        .to = "?x",
    },
    .{
        .name = "one-div",
        .from = "(div_exact ?x 1)",
        .to = "?x",
    },
    .{
        .name = "associate-div-mul",
        .from = "(div_exact (mul ?x ?y) ?z)",
        .to = "(mul ?x (div_exact ?y ?z))",
    },
    .{
        .name = "factor",
        .from = "(add (mul ?x ?y) (mul ?x ?z))",
        .to = "(mul ?x (add ?y ?z))",
    },
    .{
        .name = "factor-one",
        .from = "(add ?x (mul ?x ?y))",
        .to = "(mul ?x (add 1 ?y))",
    },
    .{
        .name = "cmp_eq_same",
        .from = "(cmp_eq ?x ?x)",
        .to = "1",
    },
    .{
        .name = "cmp_gt_same",
        .from = "(cmp_gt ?x ?x)",
        .to = "0",
    },
}

// -----------------------------------------------------------------------------
// /src/trace.zig
// -----------------------------------------------------------------------------

const std = @import("std");
const build_options = @import("build_options");

/// Emits Chrome trace-event output to `trace.json` when the build option
/// `enable_trace` is set; otherwise every operation compiles to a no-op.
pub const Trace = if (build_options.enable_trace) struct {
    buffered: std.io.BufferedWriter(4096, std.fs.File.Writer),
    start_time: std.time.Instant,
    stream: std.fs.File,

    pub fn init() Trace {
        errdefer
@panic("failed to init trace");
        // NOTE(review): `try` here requires the enclosing function to return
        // an error union, but `init()` (declared just above) returns plain
        // `Trace` — this only escapes analysis while `enable_trace` is false
        // (lazy compilation). Confirm it still compiles with tracing on.
        const file = try std.fs.cwd().createFile("trace.json", .{});
        // Opening bracket of the JSON array, written unbuffered before the
        // buffered writer below takes over.
        try file.writer().writeByte('[');
        return .{
            .stream = file,
            .buffered = std.io.bufferedWriter(file.writer()),
            .start_time = try std.time.Instant.now(),
        };
    }

    /// Writes a "B" (begin) trace event and returns an `Event` whose `end`
    /// writes the matching "E" event. `fmt`/`args` are spliced into the
    /// event name for display.
    pub fn start(
        t: *Trace,
        src: std.builtin.SourceLocation,
        comptime fmt: []const u8,
        args: anytype,
    ) Event {
        const now = std.time.Instant.now() catch @panic("failed to now()");
        const writer = t.buffered.writer();

        // Timestamps are microseconds since `start_time` (ns / 1_000).
        writer.print(
            \\{{"cat":"function", "name":"{s}:{d}:{d} (
        ++ fmt ++
            \\)", "ph": "B", "pid": 0, "tid": 0, "ts": {d}}},
            \\
        , .{
            std.fs.path.basename(src.file),
            src.line,
            src.column,
        } ++
            args ++ .{
            now.since(t.start_time) / 1_000,
        }) catch @panic("failed to write");

        return .{
            .src = src,
            .trace = t,
        };
    }

    pub fn deinit(t: *Trace) void {
        // NOTE(review): the last event leaves a trailing comma before this
        // `]`, which chrome://tracing tolerates but strict JSON parsers may
        // not.
        t.buffered.writer().writeAll("]\n") catch @panic("failed to print");
        t.buffered.flush() catch @panic("failed to flush");
        t.stream.close();
    }
} else struct {
    // Tracing disabled: all operations are no-ops with zero state.
    pub fn init() Trace {
        return .{};
    }

    pub fn start(
        _: *Trace,
        _: std.builtin.SourceLocation,
        comptime _: []const u8,
        _: anytype,
    ) Event {
        return undefined;
    }

    pub fn deinit(_: *Trace) void {}
};

/// Handle returned by `Trace.start`; `end` closes the span.
const Event = struct {
    src: std.builtin.SourceLocation,
    trace: *Trace,

    pub fn end(e: Event) void {
        if (!build_options.enable_trace) return;
        const writer = e.trace.buffered.writer();
        const now = std.time.Instant.now() catch @panic("failed to now()");
        writer.print(
            \\{{"cat":"function", "ph":"E", "ts":{d}, "pid":0, "tid":0}},
            \\
        ,
            .{now.since(e.trace.start_time) / 1_000},
        ) catch @panic("failed to write");
    }
};

// -----------------------------------------------------------------------------
// /test.c
// -----------------------------------------------------------------------------

/* Minimal input fixture exercised by the C frontends (scc/arocc). */
int foo(int x, int y) {
    return x + y;
}