├── test └── custom │ ├── array.out │ ├── basic.out │ ├── break.out │ ├── fib.out │ ├── float.out │ ├── ltcmp.in │ ├── nested.out │ ├── sum.in │ ├── bigarray.out │ ├── pointer.out │ ├── reduction.in │ ├── tco.out │ ├── floatarr.out │ ├── ltcmp.out │ ├── return.out │ ├── short.out │ ├── sum.out │ ├── timer.out │ ├── manyargs.out │ ├── reduction.out │ ├── bigarray.sy │ ├── return.sy │ ├── fib.sy │ ├── reduction.sy │ ├── floatarr.sy │ ├── sum.sy │ ├── float.sy │ ├── tco.sy │ ├── nested.sy │ ├── basic.sy │ ├── ltcmp.sy │ ├── array.sy │ ├── short.sy │ ├── timer.sy │ ├── pointer.sy │ ├── break.sy │ └── manyargs.sy ├── src ├── rt │ ├── arm-join.s │ └── arm-clone.s ├── pre-opt │ ├── MoveAlloca.cpp │ ├── NoStore.cpp │ ├── Remerge.cpp │ ├── PreAttrs.cpp │ ├── Unroll.cpp │ ├── LoopDCE.cpp │ ├── PreAnalysis.h │ ├── TidyMemory.cpp │ ├── PreAttrs.h │ ├── Base.cpp │ ├── Lower.cpp │ ├── Localize.cpp │ ├── PrePasses.h │ ├── ArrayAccess.cpp │ ├── PreLoopPasses.h │ └── Parallelizable.cpp ├── utils │ ├── DynamicCast.h │ ├── presburger │ │ ├── BasicSet.h │ │ └── BasicSet.cpp │ ├── smt │ │ ├── BvMatcher.h │ │ ├── BvExpr.h │ │ ├── Simplify.cpp │ │ ├── SMT.h │ │ └── CDCL.h │ ├── Matcher.h │ └── Exec.h ├── arm │ ├── ArmLoopPasses.h │ ├── ArmMatcher.h │ ├── LateLegalize.cpp │ ├── ArmDCE.cpp │ ├── ArmPasses.h │ ├── Regs.h │ ├── PostIncr.cpp │ ├── ArmAttrs.h │ └── InstCombine.cpp ├── opt │ ├── LoopInfo.cpp │ ├── SMTPasses.h │ ├── PassManager.h │ ├── LowerPasses.h │ ├── CallGraph.cpp │ ├── Pureness.cpp │ ├── AtMostOnce.cpp │ ├── AggressiveDCE.cpp │ ├── RemoveEmptyLoop.cpp │ ├── Pass.h │ ├── Reassociate.cpp │ ├── Verify.cpp │ ├── RangeAwareFold.cpp │ ├── Analysis.h │ ├── SimplifyCFG.cpp │ ├── Pass.cpp │ ├── HoistConstArray.cpp │ ├── CleanupPasses.h │ ├── PassManager.cpp │ ├── Specialize.cpp │ ├── Alias.cpp │ ├── DAE.cpp │ ├── Cached.cpp │ ├── Splice.cpp │ ├── InlineStore.cpp │ └── Mem2Reg.cpp ├── main │ ├── Options.h │ └── Options.cpp ├── rv │ ├── RvDupPasses.h │ ├── RvDCE.cpp │ ├── RvPasses.h 
│ ├── Regs.h │ ├── RvAttrs.h │ └── RvOps.h ├── parse │ ├── Type.cpp │ ├── Sema.h │ ├── Lexer.h │ ├── Type.h │ ├── TypeContext.h │ └── Parser.h └── codegen │ └── Ops.h ├── .gitignore ├── mca.sh ├── package.json ├── fuzz.py ├── scoreboard.ts └── fuzzer.py /test/custom/array.out: -------------------------------------------------------------------------------- 1 | 14 2 | -------------------------------------------------------------------------------- /test/custom/basic.out: -------------------------------------------------------------------------------- 1 | 16 2 | -------------------------------------------------------------------------------- /test/custom/break.out: -------------------------------------------------------------------------------- 1 | 25 2 | -------------------------------------------------------------------------------- /test/custom/fib.out: -------------------------------------------------------------------------------- 1 | 34 2 | -------------------------------------------------------------------------------- /test/custom/float.out: -------------------------------------------------------------------------------- 1 | 33 2 | -------------------------------------------------------------------------------- /test/custom/ltcmp.in: -------------------------------------------------------------------------------- 1 | 45 2 | -------------------------------------------------------------------------------- /test/custom/nested.out: -------------------------------------------------------------------------------- 1 | 7 2 | -------------------------------------------------------------------------------- /test/custom/sum.in: -------------------------------------------------------------------------------- 1 | 100 2 | -------------------------------------------------------------------------------- /test/custom/bigarray.out: -------------------------------------------------------------------------------- 1 | 1 2 | 
-------------------------------------------------------------------------------- /test/custom/pointer.out: -------------------------------------------------------------------------------- 1 | 91 2 | -------------------------------------------------------------------------------- /test/custom/reduction.in: -------------------------------------------------------------------------------- 1 | -105 2 | -------------------------------------------------------------------------------- /test/custom/tco.out: -------------------------------------------------------------------------------- 1 | 1 2 | 0 3 | -------------------------------------------------------------------------------- /test/custom/floatarr.out: -------------------------------------------------------------------------------- 1 | 12 2 | 4 3 | -------------------------------------------------------------------------------- /test/custom/ltcmp.out: -------------------------------------------------------------------------------- 1 | 2415 2 | 0 3 | -------------------------------------------------------------------------------- /test/custom/return.out: -------------------------------------------------------------------------------- 1 | -5 2 | 0 3 | -------------------------------------------------------------------------------- /test/custom/short.out: -------------------------------------------------------------------------------- 1 | 122 2 | 0 3 | -------------------------------------------------------------------------------- /test/custom/sum.out: -------------------------------------------------------------------------------- 1 | 4950 2 | 0 3 | -------------------------------------------------------------------------------- /test/custom/timer.out: -------------------------------------------------------------------------------- 1 | -5000 2 | 0 3 | -------------------------------------------------------------------------------- /test/custom/manyargs.out: 
-------------------------------------------------------------------------------- 1 | 66 2 | 75 3 | 0 4 | -------------------------------------------------------------------------------- /test/custom/reduction.out: -------------------------------------------------------------------------------- 1 | -35 2 | -21 3 | -15 4 | 0 5 | -------------------------------------------------------------------------------- /test/custom/bigarray.sy: -------------------------------------------------------------------------------- 1 | int main() { 2 | int big[10000]; 3 | big[9999] = 1; 4 | return big[9999]; 5 | } 6 | -------------------------------------------------------------------------------- /test/custom/return.sy: -------------------------------------------------------------------------------- 1 | void f(int x) { 2 | if (x > 0) 3 | return; 4 | putint(x); 5 | } 6 | 7 | int main() { 8 | f(3); 9 | f(-5); 10 | return 0; 11 | } 12 | -------------------------------------------------------------------------------- /test/custom/fib.sy: -------------------------------------------------------------------------------- 1 | int fib(int n) { 2 | if (n < 2) 3 | return 1; 4 | 5 | return fib(n - 2) + fib(n - 1); 6 | } 7 | 8 | int main() { 9 | return fib(8); 10 | } 11 | -------------------------------------------------------------------------------- /src/rt/arm-join.s: -------------------------------------------------------------------------------- 1 | ; R"( 2 | # Arg 0: Lock address 3 | spinlock_wait: 4 | 1: 5 | ldaxr w1, [x0] 6 | cbz w1, 2f 7 | clrex 8 | wfe 9 | b 1b 10 | 2: 11 | dmb ish 12 | ret 13 | )" 14 | -------------------------------------------------------------------------------- /test/custom/reduction.sy: -------------------------------------------------------------------------------- 1 | int main() { 2 | int x = getint(); 3 | putint(x / 3); 4 | putch(10); 5 | putint(x / 5); 6 | putch(10); 7 | putint(x / 7); 8 | putch(10); 9 | return 0; 10 | } 11 | 
-------------------------------------------------------------------------------- /test/custom/floatarr.sy: -------------------------------------------------------------------------------- 1 | float f(float x[]) { 2 | putint(x[0]); 3 | putint(x[1]); 4 | return x[0] + x[1]; 5 | } 6 | 7 | int main() { 8 | float a[2] = { 1.1, 2.2 }; 9 | return a[0] + f(a); // 4 10 | } 11 | -------------------------------------------------------------------------------- /test/custom/sum.sy: -------------------------------------------------------------------------------- 1 | int main() { 2 | int i = 0; 3 | int sum = 0; 4 | int n = getint(); 5 | while (i < n) { 6 | sum = sum + i; 7 | i = i + 1; 8 | } 9 | putint(sum); 10 | return 0; 11 | } 12 | -------------------------------------------------------------------------------- /test/custom/float.sy: -------------------------------------------------------------------------------- 1 | float square(float x) { 2 | return x * x; 3 | } 4 | 5 | int main() { 6 | float x = 1.23; 7 | int y = x + 2.46; // y == 3 8 | return square(y) * 4.3 - 5.5; // 9 * 4.3 - 5.5 == 33.2 9 | } 10 | -------------------------------------------------------------------------------- /test/custom/tco.sy: -------------------------------------------------------------------------------- 1 | int even_odd(int x, int result) { 2 | if (x == 0) 3 | return result; 4 | 5 | return even_odd(x - 1, !result); 6 | } 7 | 8 | int main() { 9 | putint(even_odd(5, 0)); 10 | return 0; 11 | } 12 | -------------------------------------------------------------------------------- /test/custom/nested.sy: -------------------------------------------------------------------------------- 1 | int main() { 2 | if (1) { 3 | 1; 4 | 2; 5 | } 6 | else if (1) { 7 | 3; 8 | if (1) 9 | 4; 10 | } 11 | else 12 | 5; 13 | if (1) 14 | 6; 15 | return 7; 16 | } 17 | -------------------------------------------------------------------------------- /test/custom/basic.sy: 
-------------------------------------------------------------------------------- 1 | int count; 2 | 3 | int main() { 4 | int a = 7; 5 | while (a != 1) { 6 | count = count + 1; 7 | if (a % 2 == 0) { 8 | a = a / 2; 9 | } else { 10 | a = a * 3 + 1; 11 | } 12 | } 13 | return count; 14 | } 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | bin 3 | test/official 4 | test/temp 5 | test/fuzz 6 | test/cnf 7 | test/cnf_performance 8 | test/client-linux 9 | temp 10 | final_performance 11 | functional 12 | h_functional 13 | performance 14 | node_modules 15 | rank 16 | arm_performance 17 | -------------------------------------------------------------------------------- /test/custom/ltcmp.sy: -------------------------------------------------------------------------------- 1 | int main() { 2 | int i = 0; 3 | int n = getint(); 4 | int sum = 0; 5 | while (i < n) { 6 | if (i < 30) 7 | sum = sum + i; 8 | 9 | sum = sum + i * 2; 10 | i = i + 1; 11 | } 12 | putint(sum); 13 | return 0; 14 | } 15 | -------------------------------------------------------------------------------- /test/custom/array.sy: -------------------------------------------------------------------------------- 1 | int a[4][2] = { 1, 2, { 3 }, 4, 5, { 6 } }; 2 | int largezero[10000]; 3 | int main() { 4 | int b[4][2][3] = { 1, 2, 3, { 4 }, 5 }; 5 | int c = a[1][0]; // 3 6 | a[1][0] = 4; 7 | b[0][0][0] = 5; 8 | return c + a[1][0] + b[0][0][0] + b[0][0][1]; // 14 9 | } 10 | -------------------------------------------------------------------------------- /test/custom/short.sy: -------------------------------------------------------------------------------- 1 | int f() { 2 | putint(1); 3 | return 1; 4 | } 5 | 6 | int g() { 7 | putint(2); 8 | return 0; 9 | } 10 | 11 | int main() { 12 | int total = 0; 13 | if (f() && g()) 14 | total = total + 1; 15 | while (g() && f()) 16 | total = 
total + 1; 17 | return total; 18 | } -------------------------------------------------------------------------------- /mca.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | # I have multiple LLVM's installed: 4 | # one for Clang IR, one for mainstream, one from `apt install`. 5 | # Only the mainstream one has `llvm-mca` with AArch64 and RISC-V enabled. 6 | # Hence this script, used as an alias of `llvm-mca` from mainstream though it's not on path. 7 | 8 | ~/llvm/llvm-project/build/bin/llvm-mca $@ -march=riscv64 -mcpu=xiangshan-nanhu 9 | -------------------------------------------------------------------------------- /test/custom/timer.sy: -------------------------------------------------------------------------------- 1 | int loop() { 2 | int i = 0; 3 | int total = 0; 4 | while (i < 10000) { 5 | if (i % 2 == 0) 6 | total = total + i; 7 | else 8 | total = total - i; 9 | i = i + 1; 10 | } 11 | return total; 12 | } 13 | 14 | int main() { 15 | int val; 16 | starttime(); 17 | val = loop(); 18 | stoptime(); 19 | putint(val); 20 | return 0; 21 | } 22 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "devDependencies": { 3 | "@types/node": "^22.15.21", 4 | "axios": "^1.9.0", 5 | "cheerio": "^1.0.0", 6 | "p-queue": "^8.1.0", 7 | "ts-node": "^10.9.2", 8 | "tsx": "^4.19.4", 9 | "yargs": "^17.7.2" 10 | }, 11 | "packageManager": "pnpm@10.11.0+sha512.6540583f41cc5f628eb3d9773ecee802f4f9ef9923cc45b69890fb47991d4b092964694ec3a4f738a420c918a333062c8b925d312f42e4f0c263eb603551f977" 12 | } 13 | -------------------------------------------------------------------------------- /src/pre-opt/MoveAlloca.cpp: -------------------------------------------------------------------------------- 1 | #include "PrePasses.h" 2 | 3 | using namespace sys; 4 | 5 | void MoveAlloca::run() { 6 | auto funcs = 
collectFuncs(); 7 | 8 | for (auto func : funcs) { 9 | auto allocas = func->findAll(); 10 | auto region = func->getRegion(); 11 | auto begin = region->insert(region->getFirstBlock()); 12 | for (auto alloca : allocas) 13 | alloca->moveToEnd(begin); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /test/custom/pointer.sy: -------------------------------------------------------------------------------- 1 | int f(int a[]) { 2 | int b[3] = { 11, 12 }; 3 | b[2] = 13; 4 | int total = 0; 5 | int i = 0; 6 | while (i < 10) { 7 | total = total + a[i]; 8 | i = i + 1; 9 | } 10 | while (i < 13) { 11 | total = total + b[i - 10]; 12 | i = i + 1; 13 | } 14 | return total; 15 | } 16 | 17 | int main() { 18 | int a[10] = { 19 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 20 | }; 21 | return f(a); 22 | } -------------------------------------------------------------------------------- /test/custom/break.sy: -------------------------------------------------------------------------------- 1 | int main() { 2 | int a = 1; 3 | int b = 10; 4 | int total = 0; 5 | while (a < 10) { 6 | if (a > 5) 7 | break; 8 | if (a == 7) 9 | continue; 10 | 11 | while (b > 0) { 12 | if (b < 5) 13 | break; 14 | b = b - 1; 15 | total = total + 1; 16 | continue; 17 | total = total + 1; 18 | } 19 | total = total + 1; 20 | b = a + 5; 21 | a = a + 1; 22 | } 23 | return total; 24 | } 25 | -------------------------------------------------------------------------------- /src/utils/DynamicCast.h: -------------------------------------------------------------------------------- 1 | #ifndef DYNAMIC_CAST_H 2 | #define DYNAMIC_CAST_H 3 | 4 | #include 5 | 6 | namespace sys { 7 | 8 | template 9 | bool isa(U *t) { 10 | return T::classof(t); 11 | } 12 | 13 | template 14 | T *cast(U *t) { 15 | assert(isa(t)); 16 | return (T*) t; 17 | } 18 | 19 | template 20 | T *dyn_cast(U *t) { 21 | if (!isa(t)) 22 | return nullptr; 23 | return cast(t); 24 | } 25 | 26 | } 27 | 28 | #endif 29 | 
-------------------------------------------------------------------------------- /src/utils/presburger/BasicSet.h: -------------------------------------------------------------------------------- 1 | #ifndef BASIC_SET_H 2 | #define BASIC_SET_H 3 | 4 | #include 5 | #include 6 | 7 | namespace pres { 8 | 9 | using AffineExpr = std::vector; 10 | 11 | class BasicSet { 12 | // [A I -b] [x 1]^T = 0 13 | std::vector tableau; 14 | std::vector denom; 15 | public: 16 | BasicSet(const std::vector &tableau): tableau(tableau) {} 17 | void dump(std::ostream &os); 18 | 19 | bool empty(); 20 | }; 21 | 22 | } 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /src/pre-opt/NoStore.cpp: -------------------------------------------------------------------------------- 1 | #include "PreAnalysis.h" 2 | 3 | using namespace sys; 4 | 5 | void NoStore::runImpl(Op *func) { 6 | auto stores = func->findAll(); 7 | for (auto store : stores) { 8 | auto addr = store->DEF(1); 9 | if (!addr->has()) 10 | return; 11 | auto base = BASE(addr); 12 | if (isa(base)) 13 | return; 14 | } 15 | if (!func->has()) 16 | func->add(); 17 | } 18 | 19 | void NoStore::run() { 20 | Base(module).run(); 21 | 22 | auto funcs = collectFuncs(); 23 | for (auto func : funcs) 24 | runImpl(func); 25 | } 26 | -------------------------------------------------------------------------------- /test/custom/manyargs.sy: -------------------------------------------------------------------------------- 1 | int f(int a1, int a2, int a3, int a4, int a5, int a6, int a7, int a8, int a9, int a10, int a11) { 2 | return a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11; 3 | } 4 | 5 | int fWithAlloca(int a1, int a2, int a3, int a4, int a5, int a6, int a7, int a8, int a9, int a10, int a11) { 6 | int arr[200]; 7 | arr[150] = 9; 8 | return arr[150] + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11; 9 | } 10 | 11 | 12 | int main() { 13 | putint(f(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)); 14 | 
putch(10); 15 | putint(fWithAlloca(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)); 16 | putch(10); 17 | return 0; 18 | } 19 | -------------------------------------------------------------------------------- /src/arm/ArmLoopPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef ARM_LOOP_PASSES_H 2 | #define ARM_LOOP_PASSES_H 3 | 4 | #include "../opt/Pass.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Ops.h" 7 | #include "../codegen/Attrs.h" 8 | #include "ArmOps.h" 9 | #include "ArmAttrs.h" 10 | #include "../opt/LoopPasses.h" 11 | 12 | namespace sys::arm { 13 | 14 | // Convert SCEV addresses to post-increment. 15 | class PostIncr : public Pass { 16 | void runImpl(LoopInfo *info); 17 | public: 18 | PostIncr(ModuleOp *module): Pass(module) {} 19 | 20 | std::string name() override { return "arm-post-incr"; }; 21 | std::map stats() override { return {}; } 22 | void run() override; 23 | }; 24 | 25 | } 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /src/opt/LoopInfo.cpp: -------------------------------------------------------------------------------- 1 | #include "LoopPasses.h" 2 | 3 | using namespace sys; 4 | 5 | void LoopInfo::dump(std::ostream &os) { 6 | os << "Blocks: "; 7 | for (auto bb : bbs) 8 | os << bbmap[bb] << " "; 9 | os << "\n"; 10 | 11 | os << "Preheader: " << (preheader ? 
std::to_string(bbmap[preheader]) : "none") << "\n"; 12 | os << "Header: " << bbmap[header] << "\n"; 13 | 14 | os << "Exits: "; 15 | for (auto bb : exits) 16 | os << bbmap[bb] << " "; 17 | os << "\n"; 18 | 19 | os << "Latches: "; 20 | for (auto bb : latches) 21 | os << bbmap[bb] << " "; 22 | os << "\n"; 23 | } 24 | 25 | void LoopForest::dump(std::ostream &os) { 26 | for (auto loop : loops) { 27 | loop->dump(os); 28 | os << "\n\n"; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/opt/SMTPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef SMT_PASSES_H 2 | #define SMT_PASSES_H 3 | 4 | #include "Pass.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Attrs.h" 7 | #include "../utils/smt/SMT.h" 8 | 9 | namespace sys { 10 | 11 | // Use SMT solver to guess a formula for constant arrys. 12 | class SynthConstArray : public Pass { 13 | smt::BvExprContext ctx; 14 | 15 | std::vector candidates; 16 | Builder builder; 17 | 18 | Op *reconstruct(smt::BvExpr *expr, Op *subscript, int c0, int c1); 19 | public: 20 | SynthConstArray(ModuleOp *module); 21 | 22 | std::string name() override { return "synth-const-array"; }; 23 | std::map stats() override { return {}; } 24 | void run() override; 25 | }; 26 | 27 | } 28 | 29 | #endif -------------------------------------------------------------------------------- /src/main/Options.h: -------------------------------------------------------------------------------- 1 | #ifndef OPTIONS_H 2 | #define OPTIONS_H 3 | 4 | #include 5 | 6 | namespace sys { 7 | 8 | struct Options { 9 | using option = unsigned char; 10 | 11 | struct { 12 | option dumpAST : 1; 13 | option noLink : 1; 14 | option dumpMidIR : 1; 15 | option o1 : 1; 16 | option arm : 1; 17 | option rv : 1; 18 | option verbose : 1; 19 | option stats : 1; 20 | option verify : 1; 21 | option bv : 1; 22 | option sat : 1; 23 | }; 24 | 25 | std::string inputFile; 26 | std::string 
outputFile; 27 | std::string printAfter; 28 | std::string printBefore; 29 | std::string compareWith; 30 | std::string simulateInput; 31 | 32 | Options(); 33 | }; 34 | 35 | Options parseArgs(int argc, char **argv); 36 | 37 | } 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /src/opt/PassManager.h: -------------------------------------------------------------------------------- 1 | #ifndef PASS_MANAGER_H 2 | #define PASS_MANAGER_H 3 | 4 | #include "Pass.h" 5 | #include "../main/Options.h" 6 | 7 | namespace sys { 8 | 9 | class PassManager { 10 | std::vector passes; 11 | ModuleOp *module; 12 | 13 | bool pastFlatten; 14 | bool pastMem2Reg; 15 | bool inBackend; 16 | int exitcode; 17 | 18 | std::string input; 19 | std::string truth; 20 | 21 | Options opts; 22 | public: 23 | PassManager(ModuleOp *module, const Options &opts); 24 | ~PassManager(); 25 | 26 | void run(); 27 | ModuleOp *getModule() { return module; } 28 | 29 | template 30 | void addPass(Args... args) { 31 | auto pass = new T(module, std::forward(args)...); 32 | passes.push_back(pass); 33 | } 34 | }; 35 | 36 | } 37 | 38 | #endif 39 | -------------------------------------------------------------------------------- /src/rv/RvDupPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef RV_DUP_PASSES_H 2 | #define RV_DUP_PASSES_H 3 | 4 | #include "../opt/Pass.h" 5 | #include "RvAttrs.h" 6 | #include "RvOps.h" 7 | #include "../codegen/Ops.h" 8 | #include "../codegen/Attrs.h" 9 | #include "../codegen/CodeGen.h" 10 | 11 | namespace sys::rv { 12 | 13 | // The only difference with opt/DCE is that `isImpure` behaves differently. 
14 | class RvDCE : public Pass { 15 | std::vector removeable; 16 | int elim = 0; 17 | 18 | bool isImpure(Op *op); 19 | void markImpure(Region *region); 20 | void runOnRegion(Region *region); 21 | public: 22 | RvDCE(ModuleOp *module): Pass(module) {} 23 | 24 | std::string name() override { return "rv-dce"; }; 25 | std::map stats() override; 26 | void run() override; 27 | }; 28 | 29 | } 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /src/pre-opt/Remerge.cpp: -------------------------------------------------------------------------------- 1 | #include "PrePasses.h" 2 | 3 | using namespace sys; 4 | 5 | void Remerge::runImpl(Region *region) { 6 | auto entry = region->getFirstBlock(); 7 | const auto &bbs = region->getBlocks(); 8 | for (auto bb : bbs) { 9 | if (bb != entry) 10 | bb->inlineToEnd(entry); 11 | } 12 | for (auto it = --bbs.end(); it != bbs.begin();) { 13 | auto next = it; --next; 14 | (*it)->erase(); 15 | it = next; 16 | } 17 | 18 | // Recursively find operations with regions. 19 | for (auto op : entry->getOps()) { 20 | if (op->getRegionCount()) { 21 | for (auto x : op->getRegions()) 22 | runImpl(x); 23 | } 24 | } 25 | } 26 | 27 | void Remerge::run() { 28 | auto funcs = collectFuncs(); 29 | 30 | for (auto func : funcs) 31 | runImpl(func->getRegion()); 32 | 33 | MoveAlloca(module).run(); 34 | } 35 | -------------------------------------------------------------------------------- /src/arm/ArmMatcher.h: -------------------------------------------------------------------------------- 1 | #ifndef ARM_MATCHER_H 2 | #define ARM_MATCHER_H 3 | 4 | #include "../utils/Matcher.h" 5 | 6 | namespace sys { 7 | 8 | // The difference is that opcode is interpreted differently. 
9 | class ArmRule { 10 | std::map binding; 11 | std::map blockBinding; 12 | std::map imms; 13 | 14 | std::string_view text; 15 | Expr *pattern; 16 | Builder builder; 17 | int loc = 0; 18 | bool failed = false; 19 | 20 | std::string_view nextToken(); 21 | Expr *parse(); 22 | 23 | bool matchExpr(Expr *expr, Op *op); 24 | int evalExpr(Expr *expr); 25 | Op *buildExpr(Expr *expr); 26 | 27 | void dump(Expr *expr, std::ostream &os); 28 | public: 29 | ArmRule(const char *text); 30 | bool rewrite(Op *op); 31 | 32 | void dump(std::ostream &os); 33 | }; 34 | 35 | } 36 | 37 | #endif 38 | -------------------------------------------------------------------------------- /src/parse/Type.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "Type.h" 4 | 5 | using namespace sys; 6 | 7 | std::string interleave(const std::vector &types) { 8 | std::stringstream ss; 9 | for (auto x : types) 10 | ss << x->toString() << ", "; 11 | auto str = ss.str(); 12 | // Remove the extra ", " at the end 13 | if (str.size() > 2) { 14 | str.pop_back(); 15 | str.pop_back(); 16 | } 17 | return str; 18 | } 19 | 20 | std::string FunctionType::toString() const { 21 | return "(" + interleave(params) + ") -> " + ret->toString(); 22 | } 23 | 24 | std::string ArrayType::toString() const { 25 | std::stringstream ss(base->toString()); 26 | for (auto x : dims) 27 | ss << "[" << x << "]"; 28 | return ss.str(); 29 | } 30 | 31 | int ArrayType::getSize() const { 32 | int size = 1; 33 | for (auto x : dims) 34 | size *= x; 35 | return size; 36 | } 37 | -------------------------------------------------------------------------------- /src/pre-opt/PreAttrs.cpp: -------------------------------------------------------------------------------- 1 | #include "PreAttrs.h" 2 | #include 3 | 4 | using namespace sys; 5 | 6 | std::string SubscriptAttr::toString() { 7 | std::stringstream ss; 8 | ss << " 0) 10 | ss << subscript[0]; 11 | for (int i = 1; i < subscript.size(); 
i++) 12 | ss << ", " << subscript[i]; 13 | ss << ">"; 14 | return ss.str(); 15 | } 16 | 17 | // Defined in OpBase.cpp. 18 | std::string getValueNumber(Value value); 19 | 20 | std::string BaseAttr::toString() { 21 | std::stringstream ss; 22 | ss << "getResult()) << ">"; 23 | return ss.str(); 24 | } 25 | 26 | std::string ParallelizableAttr::toString() { 27 | std::stringstream ss; 28 | if (!accum) 29 | return ""; 30 | ss << ""; 31 | return ss.str(); 32 | } 33 | -------------------------------------------------------------------------------- /src/rt/arm-clone.s: -------------------------------------------------------------------------------- 1 | ; R"( 2 | # x0: Function pointer 3 | # x1: Stack top 4 | instantiate_worker: 5 | sub sp, sp, #16 6 | stp x19, x20, [sp, #0] 7 | dmb ish 8 | # Syscall 220: 9 | # clone(flags, stack_top, parent_tid_ptr, child_tid_ptr, tls) 10 | mov x19, x0 11 | mov x20, x1 12 | 13 | # CLONE_VM | CLONE_FS | CLONE_FILES | CLONE_SIGHAND | CLONE_THREAD | CLONE_SYSVSEM 14 | mov x0, #3840 15 | movk x0, #5, lsl 16 16 | mov x2, #0 17 | mov x3, #0 18 | mrs x4, tpidr_el0 19 | mov x8, #220 20 | svc #0 21 | 22 | # For parent process, tid != 0, so diretly returns. 23 | cbnz x0, 1f 24 | 25 | # For child process, call the function. 26 | mov sp, x20 27 | blr x19 28 | 29 | # Exit child process when the function completes. 
30 | # Syscall 93: 31 | # exit(value) 32 | mov x0, #0 33 | mov x8, #93 34 | svc #0 35 | 36 | 1: 37 | ldp x19, x20, [sp, #0] 38 | add sp, sp, #16 39 | ret 40 | )" 41 | -------------------------------------------------------------------------------- /src/opt/LowerPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef LOWER_PASSES_H 2 | #define LOWER_PASSES_H 3 | 4 | #include "Pass.h" 5 | #include "../codegen/Attrs.h" 6 | #include "../codegen/Ops.h" 7 | 8 | namespace sys { 9 | 10 | class FlattenCFG : public Pass { 11 | public: 12 | FlattenCFG(ModuleOp *module): Pass(module) {} 13 | 14 | std::string name() override { return "flatten-cfg"; }; 15 | std::map stats() override { return {}; }; 16 | void run() override; 17 | }; 18 | 19 | // A weak scheduler that only works on basic blocks. 20 | // This can't be in backend, because backends require that writereg-call-readreg must stay together. 21 | class InstSchedule : public Pass { 22 | void runImpl(BasicBlock *bb); 23 | public: 24 | InstSchedule(ModuleOp *module): Pass(module) {} 25 | 26 | std::string name() override { return "inst-schedule"; }; 27 | std::map stats() override { return {}; } 28 | void run() override; 29 | }; 30 | 31 | } 32 | 33 | #endif 34 | -------------------------------------------------------------------------------- /src/parse/Sema.h: -------------------------------------------------------------------------------- 1 | #ifndef SEMA_H 2 | #define SEMA_H 3 | 4 | #include "ASTNode.h" 5 | #include "Type.h" 6 | #include "TypeContext.h" 7 | #include 8 | #include 9 | 10 | namespace sys { 11 | 12 | // We don't need to do type inference, hence no memory management needed 13 | class Sema { 14 | TypeContext &ctx; 15 | // The current function we're in. Mainly used for deducing return type. 
16 | Type *currentFunc; 17 | 18 | using SymbolTable = std::map; 19 | SymbolTable symbols; 20 | 21 | class SemanticScope { 22 | Sema &sema; 23 | SymbolTable symbols; 24 | public: 25 | SemanticScope(Sema &sema): sema(sema), symbols(sema.symbols) {} 26 | ~SemanticScope() { sema.symbols = symbols; } 27 | }; 28 | 29 | PointerType *decay(ArrayType *arrTy); 30 | ArrayType *raise(PointerType *ptr); 31 | 32 | Type *infer(ASTNode *node); 33 | public: 34 | // This modifies `node` inplace. 35 | Sema(ASTNode *node, TypeContext &ctx); 36 | }; 37 | 38 | } 39 | 40 | #endif 41 | -------------------------------------------------------------------------------- /src/opt/CallGraph.cpp: -------------------------------------------------------------------------------- 1 | #include "Analysis.h" 2 | 3 | using namespace sys; 4 | // For every function, records the names of the functions that call it (directly, or through clone()), replacing any previously stored record. 5 | void CallGraph::run() { 6 | // Construct a call graph. 7 | // Actually Pureness can rely on this, but as it runs I wouldn't bother to change. 8 | std::map> calledBy; 9 | 10 | auto calls = module->findAll(); 11 | // We consider `clone()` syscall also as calling the worker function. 12 | auto workers = module->findAll(); 13 | calls.reserve(calls.size() + workers.size()); 14 | std::copy(workers.begin(), workers.end(), std::back_inserter(calls)); 15 | for (auto call : calls) { 16 | auto func = call->getParentOp(); 17 | auto calledName = NAME(call); 18 | if (!isExtern(calledName)) 19 | calledBy[calledName].insert(NAME(func)); 20 | } 21 | 22 | auto funcs = collectFuncs(); 23 | for (auto func : funcs) { 24 | // Remove the old version.
25 | func->remove(); 26 | 27 | const auto &name = NAME(func); 28 | const auto &callersSet = calledBy[name]; 29 | std::vector callers(callersSet.begin(), callersSet.end()); 30 | func->add(callers); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/utils/smt/BvMatcher.h: -------------------------------------------------------------------------------- 1 | #ifndef BV_MATCHER_H 2 | #define BV_MATCHER_H 3 | 4 | #include "BvExpr.h" 5 | #include 6 | #include "../Matcher.h" 7 | 8 | namespace smt { 9 | 10 | using sys::Expr; 11 | using sys::Atom; 12 | using sys::List; 13 | 14 | class BvRule { 15 | std::map binding; 16 | std::string_view text; 17 | std::vector externalStrs; 18 | Expr *pattern; 19 | int loc = 0; 20 | bool failed = false; 21 | 22 | std::string_view nextToken(); 23 | Expr *parse(); 24 | 25 | bool matchExpr(Expr *expr, BvExpr *bvexpr); 26 | int evalExpr(Expr *expr); 27 | float evalFExpr(Expr *expr); 28 | BvExpr *buildExpr(Expr *expr); 29 | 30 | void dump(Expr *expr, std::ostream &os); 31 | void release(Expr *expr); 32 | BvExpr *rewriteRoot(BvExpr *expr); 33 | public: 34 | using Binding = std::map; 35 | BvExprContext *ctx = nullptr; 36 | 37 | BvRule(const BvRule &other) = delete; 38 | 39 | BvRule(const char *text); 40 | ~BvRule(); 41 | BvExpr *rewrite(BvExpr *expr); 42 | BvExpr *extract(const std::string &name); 43 | 44 | void dump(std::ostream &os = std::cerr); 45 | }; 46 | 47 | } 48 | 49 | #endif 50 | -------------------------------------------------------------------------------- /src/arm/LateLegalize.cpp: -------------------------------------------------------------------------------- 1 | #include "ArmPasses.h" 2 | 3 | using namespace sys; 4 | using namespace sys::arm; 5 | // Rewrites instructions the ARM assembler cannot encode directly into encodable equivalents. 6 | void LateLegalize::run() { 7 | Builder builder; 8 | 9 | // ARM does not support `add x0, xzr, 1`.
10 | runRewriter([&](AddXIOp *op) { 11 | if (RS(op) == Reg::xzr) 12 | builder.replace(op, { RDC(RD(op)), new IntAttr(V(op)) }); 13 | 14 | return false; 15 | }); 16 | 17 | runRewriter([&](AddWIOp *op) { 18 | if (RS(op) == Reg::xzr) 19 | builder.replace(op, { RDC(RD(op)), new IntAttr(V(op)) }); 20 | 21 | return false; 22 | }); 23 | 24 | // Use `mov` and `movk` for an out-of-range `mov`. 25 | runRewriter([&](MovIOp *op) { 26 | int v = V(op); 27 | if (v >= 65536) { 28 | builder.setBeforeOp(op); 29 | builder.create({ RDC(RD(op)), new IntAttr(v & 0xffff) }); 30 | builder.replace(op, { RDC(RD(op)), new IntAttr(((unsigned) v) >> 16), new LslAttr(16) }); 31 | } 32 | if (v < -65536) { 33 | unsigned u = v; 34 | 35 | builder.setBeforeOp(op); 36 | builder.create({ RDC(RD(op)), new IntAttr((uint16_t)(~(uint16_t)(u & 0xffff))) }); 37 | builder.replace(op, { RDC(RD(op)), new IntAttr(u >> 16), new LslAttr(16) }); 38 | } 39 | return false; 40 | }); 41 | } 42 | -------------------------------------------------------------------------------- /src/pre-opt/Unroll.cpp: -------------------------------------------------------------------------------- 1 | #include "PreLoopPasses.h" 2 | 3 | using namespace sys; 4 | 5 | std::map Unroll::stats() { 6 | return { 7 | { "unrolled-loops", unrolled }, 8 | }; 9 | } 10 | 11 | namespace { 12 | 13 | bool innermost(Op *loop) { 14 | auto region = loop->getRegion(); 15 | auto entry = region->getFirstBlock(); 16 | for (auto op : entry->getOps()) { 17 | if (isa(op) || isa(op)) 18 | return false; 19 | } 20 | return true; 21 | } 22 | 23 | } 24 | 25 | // Defined in Unswitch.cpp. 26 | void unroll(Op *loop, int vi); 27 | 28 | // Defined in EarlyInline.cpp. 29 | int opcount(Region *region); 30 | 31 | void Unroll::run() { 32 | auto loops = module->findAll(); 33 | for (auto loop : loops) { 34 | // Only unroll innermost loops. 35 | if (!innermost(loop)) 36 | continue; 37 | 38 | // Don't unroll large loops.
39 | auto region = loop->getRegion(); 40 | if (opcount(region) >= 50) 41 | continue; 42 | 43 | // Don't unroll loops with calls in it. 44 | if (loop->findAll().size() > 0) 45 | continue; 46 | 47 | // unroll() requires that step is 1. 48 | auto step = loop->DEF(2); 49 | if (!isa(step) || V(step) != 1) 50 | continue; 51 | 52 | // Unroll it twice. 53 | unroll(loop, 2); 54 | unrolled++; 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/opt/Pureness.cpp: -------------------------------------------------------------------------------- 1 | #include "Analysis.h" 2 | 3 | using namespace sys; 4 | // Marks functions impure: those calling external functions, those accessing globals, and (transitively) those calling an impure function. 5 | void Pureness::run() { 6 | auto funcs = collectFuncs(); 7 | 8 | // Construct a call graph. 9 | auto fnMap = getFunctionMap(); 10 | auto calls = module->findAll(); 11 | for (auto call : calls) { 12 | auto func = call->getParentOp(); 13 | auto calledName = NAME(call); 14 | if (!isExtern(calledName)) 15 | callGraph[func].insert(fnMap[calledName]); 16 | else if (!func->has()) 17 | // External functions are impure. 18 | func->add(); 19 | } 20 | 21 | // Every function that accesses globals is impure. 22 | for (auto func : funcs) { 23 | if (!func->has() && !func->findAll().empty()) 24 | func->add(); 25 | } 26 | 27 | // Propagate impureness across functions: 28 | // if a function calls any impure function then it becomes impure.
29 | bool changed; 30 | do { 31 | changed = false; 32 | for (auto func : funcs) { 33 | bool impure = false; 34 | for (auto v : callGraph[func]) { 35 | if (v->has()) { 36 | impure = true; 37 | break; 38 | } 39 | } 40 | if (!func->has() && impure) { 41 | changed = true; 42 | func->add(); 43 | } 44 | } 45 | } while (changed); 46 | } 47 | -------------------------------------------------------------------------------- /src/parse/Lexer.h: -------------------------------------------------------------------------------- 1 | #ifndef LEXER_H 2 | #define LEXER_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace sys { 9 | 10 | struct Token { 11 | enum Type { 12 | // Literals 13 | LInt, LFloat, Ident, 14 | 15 | // Operators 16 | Plus, Minus, Mul, Div, Mod, 17 | PlusEq, MinusEq, MulEq, DivEq, ModEq, 18 | Le, Ge, Gt, Lt, Eq, Ne, 19 | And, Or, Semicolon, Assign, Not, 20 | LPar, RPar, LBrak, RBrak, LBrace, RBrace, 21 | Comma, 22 | 23 | // Keywords 24 | If, Else, While, For, Return, Int, Float, Void, 25 | Const, Break, Continue, 26 | 27 | // EOF 28 | End, 29 | } type; 30 | 31 | // We don't use std::string here to save space - 32 | // We have to manually free space anyway. (Destructors won't be called inside union.)
33 | union { 34 | int vi; 35 | float vf; 36 | char *vs; 37 | }; 38 | 39 | /* implicit */ Token(Type t): type(t) {} 40 | /* implicit */ Token(int vi): type(LInt), vi(vi) {} 41 | /* implicit */ Token(float vf): type(LFloat), vf(vf) {} 42 | /* implicit */ Token(const std::string &str): type(Ident), vs(new char[str.size() + 1]) { 43 | strcpy(vs, str.c_str()); 44 | } 45 | }; 46 | 47 | class Lexer { 48 | std::string input; 49 | 50 | // Index of `input` 51 | size_t loc = 0; 52 | size_t lineno = 1; 53 | public: 54 | Lexer(const std::string &input): input(input) {} 55 | 56 | Token nextToken(); 57 | bool hasMore() const; 58 | }; 59 | 60 | } 61 | 62 | #endif 63 | -------------------------------------------------------------------------------- /src/utils/Matcher.h: -------------------------------------------------------------------------------- 1 | #ifndef MATCHER_H 2 | #define MATCHER_H 3 | 4 | #include "../codegen/CodeGen.h" 5 | 6 | namespace sys { 7 | 8 | struct Expr { 9 | int id; 10 | Expr(int id): id(id) {} 11 | virtual ~Expr() {} 12 | }; 13 | 14 | struct Atom : Expr { 15 | template 16 | static bool classof(T *t) { return t->id == 1; } 17 | 18 | std::string_view value; 19 | Atom(std::string_view value): Expr(1), value(value) {} 20 | }; 21 | 22 | struct List : Expr { 23 | template 24 | static bool classof(T *t) { return t->id == 2; } 25 | 26 | std::vector elements; 27 | List(): Expr(2) {} 28 | }; 29 | 30 | 31 | class Rule { 32 | std::map binding; 33 | std::string_view text; 34 | std::vector externalStrs; 35 | Expr *pattern; 36 | Builder builder; 37 | int loc = 0; 38 | bool failed = false; 39 | 40 | std::string_view nextToken(); 41 | Expr *parse(); 42 | 43 | bool matchExpr(Expr *expr, Op *op); 44 | int evalExpr(Expr *expr); 45 | float evalFExpr(Expr *expr); 46 | Op *buildExpr(Expr *expr); 47 | 48 | void dump(Expr *expr, std::ostream &os); 49 | void release(Expr *expr); 50 | public: 51 | using Binding = std::map; 52 | 53 | Rule(const Rule &other) = delete; 54 |
Rule(const char *text); 56 | ~Rule(); 57 | bool rewrite(Op *op); 58 | bool match(Op *op, const Binding &external = {}); 59 | Op *extract(const std::string &name); 60 | 61 | void dump(std::ostream &os = std::cerr); 62 | }; 63 | 64 | } 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /src/rv/RvDCE.cpp: -------------------------------------------------------------------------------- 1 | #include "RvDupPasses.h" 2 | 3 | using namespace sys::rv; 4 | using namespace sys; 5 | 6 | std::map RvDCE::stats() { 7 | return { 8 | { "eliminated-ops", elim } 9 | }; 10 | } 11 | 12 | bool RvDCE::isImpure(Op *op) { 13 | if (isa(op) || isa(op) || 14 | isa(op) || isa(op) || 15 | isa(op) || isa(op) || isa(op) || 16 | isa(op) || isa(op) || 17 | isa(op)) 18 | return true; 19 | 20 | return false; 21 | } 22 | 23 | // Here no nested regions are possible. 24 | void RvDCE::markImpure(Region *region) { 25 | for (auto bb : region->getBlocks()) { 26 | for (auto op : bb->getOps()) { 27 | if (isImpure(op) && !op->has()) 28 | op->add(); 29 | } 30 | } 31 | } 32 | 33 | void RvDCE::runOnRegion(Region *region) { 34 | for (auto bb : region->getBlocks()) { 35 | for (auto op : bb->getOps()) { 36 | if (!op->has() && op->getUses().size() == 0) 37 | removeable.push_back(op); 38 | else for (auto r : op->getRegions()) 39 | runOnRegion(r); 40 | } 41 | } 42 | } 43 | 44 | void RvDCE::run() { 45 | auto funcs = collectFuncs(); 46 | 47 | for (auto func : funcs) 48 | markImpure(func->getRegion()); 49 | 50 | do { 51 | removeable.clear(); 52 | for (auto func : funcs) 53 | runOnRegion(func->getRegion()); 54 | 55 | elim += removeable.size(); 56 | for (auto op : removeable) 57 | op->erase(); 58 | } while (removeable.size()); 59 | } 60 | -------------------------------------------------------------------------------- /src/opt/AtMostOnce.cpp: -------------------------------------------------------------------------------- 1 | #include "Analysis.h" 2 | 3 | using namespace
sys; 4 | // Marks functions that run at most once: those with no callers, and those invoked from exactly one call site, not nested in a WhileOp, inside a single other function. 5 | // This runs before Flatten CFG. 6 | void AtMostOnce::run() { 7 | CallGraph(module).run(); 8 | auto funcs = collectFuncs(); 9 | auto fnMap = getFunctionMap(); 10 | 11 | for (auto func : funcs) { 12 | if (func->has()) 13 | continue; 14 | 15 | const auto &callers = CALLER(func); 16 | 17 | if (callers.size() == 0) { 18 | func->add(); 19 | continue; 20 | } 21 | 22 | if (callers.size() >= 2) 23 | continue; 24 | 25 | FuncOp *caller = fnMap[callers[0]]; 26 | const auto &selfName = NAME(func); 27 | // Recursive functions aren't candidates. 28 | if (caller == func) 29 | continue; 30 | 31 | auto calls = caller->findAll(); 32 | bool good = true; 33 | Op *call = nullptr; 34 | 35 | // First, make sure there's only one call that calls the function. 36 | for (auto op : calls) { 37 | if (NAME(op) == selfName) { 38 | if (call) { 39 | good = false; 40 | break; 41 | } 42 | call = op; 43 | } 44 | } 45 | 46 | if (!good || !call) // `call` stays null when no op collected above names this function (presumably the caller reaches it only through the clone() edge CallGraph also records); walking its parents below would dereference null, so skip conservatively. 47 | continue; 48 | 49 | // Next, make sure the call isn't enclosed in a WhileOp. 50 | Op *father = call; 51 | while (!isa(father)) { 52 | father = father->getParentOp(); 53 | if (isa(father)) { 54 | good = false; 55 | break; 56 | } 57 | } 58 | 59 | if (!good) 60 | continue; 61 | 62 | // Now we know the function is called at most once.
63 | func->add(); 64 | } 65 | } -------------------------------------------------------------------------------- /src/pre-opt/LoopDCE.cpp: -------------------------------------------------------------------------------- 1 | #include "PreLoopPasses.h" 2 | 3 | using namespace sys; 4 | 5 | std::map LoopDCE::stats() { 6 | return { 7 | { "erased-loops", erased }, 8 | }; 9 | } 10 | 11 | namespace { 12 | 13 | bool pure(Region *region) { 14 | if (!region->getBlocks().size()) 15 | return true; 16 | 17 | auto entry = region->getFirstBlock(); 18 | for (auto op : entry->getOps()) { 19 | if (op->has()) 20 | return false; 21 | for (auto x : op->getRegions()) { 22 | if (!pure(x)) 23 | return false; 24 | } 25 | } 26 | return true; 27 | } 28 | 29 | } 30 | // Removes loops that only repeat unobservable work: pure loops become a single store to `ivAddr`, and unused parallel loops are replaced by one inlined copy of their body. 31 | void LoopDCE::run() { 32 | Builder builder; 33 | bool changed; 34 | do { 35 | auto loops = module->findAll(); 36 | changed = false; 37 | for (auto loop : loops) { 38 | auto region = loop->getRegion(); 39 | if (pure(region)) { 40 | auto step = loop->DEF(2); 41 | if (!isa(step) || V(step) != 1) 42 | continue; 43 | 44 | // Replace with a store of `ivAddr`. 45 | builder.setAfterOp(loop); 46 | builder.create({ loop->getOperand(1), loop->getOperand(3) }, { new SizeAttr(4) }); 47 | loop->erase(), changed = true, erased++; 48 | break; // Erasing `loop` also destroys any loops nested inside it, which `loops` may still hold; rescan with a fresh list (as the parallel case below does) instead of visiting stale pointers. 49 | } 50 | 51 | // For a parallel loop (no loop-carried dependency) whose indvar is never used, 52 | // it is simply doing repeated work.
53 | if (loop->has() && !loop->getUses().size()) { 54 | if (region->getBlocks().size()) { 55 | auto entry = region->getFirstBlock(); 56 | entry->inlineBefore(loop); 57 | } 58 | loop->erase(), changed = true, erased++; 59 | break; 60 | } 61 | } 62 | } while (changed); 63 | } 64 | -------------------------------------------------------------------------------- /src/pre-opt/PreAnalysis.h: -------------------------------------------------------------------------------- 1 | #ifndef PRE_ANALYSIS 2 | #define PRE_ANALYSIS 3 | 4 | #include "../opt/Pass.h" 5 | #include "../codegen/Ops.h" 6 | #include "../codegen/Attrs.h" 7 | #include "PreAttrs.h" 8 | 9 | namespace sys { 10 | 11 | // Marks addresses, loads and stores with `SubscriptAttr`. 12 | class ArrayAccess : public Pass { 13 | // Takes all induction variables outside the current loop, 14 | // including that of the loop we're inspecting. 15 | // (In other words, outer.size() >= 1.) 16 | void runImpl(Op *loop, std::vector outer); 17 | public: 18 | ArrayAccess(ModuleOp *module): Pass(module) {} 19 | 20 | std::string name() override { return "array-access"; }; 21 | std::map stats() override { return {}; }; 22 | void run() override; 23 | }; 24 | 25 | // Marks base of an array. 26 | class Base : public Pass { 27 | void runImpl(Region *region); 28 | public: 29 | Base(ModuleOp *module): Pass(module) {} 30 | 31 | std::string name() override { return "base"; }; 32 | std::map stats() override { return {}; }; 33 | void run() override; 34 | }; 35 | 36 | // Checks whether a loop is parallelizable. 37 | class Parallelizable : public Pass { 38 | void runImpl(Op *loop, int depth); 39 | public: 40 | Parallelizable(ModuleOp *module): Pass(module) {} 41 | 42 | std::string name() override { return "parallelizable"; }; 43 | std::map stats() override { return {}; }; 44 | void run() override; 45 | }; 46 | 47 | // Checks whether a function does not store to global variable.
48 | class NoStore : public Pass { 49 | void runImpl(Op *func); 50 | public: 51 | NoStore(ModuleOp *module): Pass(module) {} 52 | 53 | std::string name() override { return "no-store"; }; 54 | std::map stats() override { return {}; }; 55 | void run() override; 56 | }; 57 | 58 | } 59 | 60 | #endif 61 | -------------------------------------------------------------------------------- /src/parse/Type.h: -------------------------------------------------------------------------------- 1 | #ifndef TYPE_H 2 | #define TYPE_H 3 | 4 | #include 5 | #include 6 | namespace sys { 7 | 8 | class Type { 9 | const int id; 10 | public: 11 | int getID() const { return id; } 12 | virtual std::string toString() const = 0; 13 | 14 | virtual ~Type() {} 15 | Type(int id): id(id) {} 16 | }; 17 | 18 | template 19 | class TypeImpl : public Type { 20 | public: 21 | static bool classof(Type *ty) { 22 | return ty->getID() == TypeID; 23 | } 24 | 25 | TypeImpl(): Type(TypeID) {} 26 | }; 27 | 28 | class IntType : public TypeImpl { 29 | public: 30 | std::string toString() const override { return "int"; } 31 | }; 32 | 33 | class FloatType : public TypeImpl { 34 | public: 35 | std::string toString() const override { return "float"; } 36 | }; 37 | 38 | class VoidType : public TypeImpl { 39 | public: 40 | std::string toString() const override { return "void"; } 41 | }; 42 | 43 | class PointerType : public TypeImpl { 44 | public: 45 | Type *pointee; 46 | 47 | PointerType(Type *pointee): pointee(pointee) {} 48 | 49 | std::string toString() const override { return pointee->toString() + "*"; } 50 | }; 51 | 52 | class FunctionType : public TypeImpl { 53 | public: 54 | Type *ret; 55 | std::vector params; 56 | 57 | FunctionType(Type *ret, std::vector params): 58 | ret(ret), params(params) {} 59 | 60 | std::string toString() const override; 61 | }; 62 | 63 | class ASTNode; 64 | class ArrayType : public TypeImpl { 65 | public: 66 | Type *base; 67 | std::vector dims; 68 | 69 | ArrayType(Type *base, std::vector
dims): 70 | base(base), dims(dims) {} 71 | 72 | std::string toString() const override; 73 | int getSize() const; 74 | }; 75 | 76 | } 77 | 78 | #endif 79 | -------------------------------------------------------------------------------- /src/pre-opt/TidyMemory.cpp: -------------------------------------------------------------------------------- 1 | #include "PrePasses.h" 2 | #include "PreAnalysis.h" 3 | 4 | using namespace sys; 5 | 6 | std::map TidyMemory::stats() { 7 | return { 8 | { "tidied-ops", tidied } 9 | }; 10 | } 11 | // Forwards values stored to a 4-byte base into later loads of the same address, conservatively forgetting everything at nested regions, (presumably impure) calls, and stores whose address has no known base. 12 | void TidyMemory::runImpl(Region *region) { 13 | // Maps each tracked store address to the value most recently stored there. 14 | std::unordered_map values; 15 | for (auto bb : region->getBlocks()) { 16 | auto ops = bb->getOps(); 17 | for (auto op : ops) { 18 | // Conservatively invalidates all stores. 19 | if (op->getRegionCount()) { 20 | values.clear(); 21 | for (auto r : op->getRegions()) 22 | runImpl(r); 23 | continue; 24 | } 25 | 26 | if (isa(op) && op->has()) { 27 | values.clear(); 28 | continue; 29 | } 30 | 31 | if (isa(op)) { 32 | auto addr = op->DEF(1); 33 | auto val = op->DEF(0); 34 | if (isa(val) || isa(val) || isa(val)) { 35 | values.erase(addr); 36 | continue; 37 | } 38 | 39 | // Have to make sure it's not an array; 40 | // otherwise there's no way to detect whether another thing aliases with it.
41 | if (!addr->has()) { 42 | values.clear(); 43 | continue; 44 | } 45 | auto base = BASE(addr); 46 | if (!isa(base) || SIZE(base) != 4) 47 | continue; 48 | 49 | values[addr] = val; 50 | continue; 51 | } 52 | 53 | if (isa(op) && values.count(op->DEF())) { 54 | op->replaceAllUsesWith(values[op->DEF()]); 55 | op->erase(); 56 | tidied++; 57 | continue; 58 | } 59 | } 60 | } 61 | } 62 | 63 | void TidyMemory::run() { 64 | Base(module).run(); 65 | auto funcs = collectFuncs(); 66 | 67 | for (auto func : funcs) 68 | runImpl(func->getRegion()); 69 | } -------------------------------------------------------------------------------- /src/opt/AggressiveDCE.cpp: -------------------------------------------------------------------------------- 1 | #include "CleanupPasses.h" 2 | #include 3 | 4 | using namespace sys; 5 | 6 | std::map AggressiveDCE::stats() { 7 | return { 8 | { "removed-ops", elim }, 9 | }; 10 | } 11 | 12 | #define PRESERVED(Ty) || isa(op) 13 | static bool preserved(Op *op) { 14 | return (isa(op) && op->has()) 15 | PRESERVED(BranchOp) 16 | PRESERVED(GotoOp) 17 | PRESERVED(StoreOp) 18 | PRESERVED(ReturnOp) 19 | PRESERVED(CloneOp) 20 | PRESERVED(JoinOp) 21 | PRESERVED(WakeOp); 22 | } 23 | // Mark-and-sweep DCE: every op transitively used by a root (returns, impure calls, stores, branches, and any other preserved op) is live; the remaining ops in the function's top-level blocks are erased. 24 | void AggressiveDCE::runImpl(FuncOp *fn) { 25 | auto rets = fn->findAll(); 26 | auto calls = fn->findAll(); 27 | auto stores = fn->findAll(); 28 | auto branches = fn->findAll(); 29 | 30 | std::unordered_set live; 31 | std::vector queue(rets.begin(), rets.end()); 32 | for (auto call : calls) { 33 | if (call->has()) 34 | queue.push_back(call); 35 | } 36 | for (auto store : stores) 37 | queue.push_back(store); 38 | for (auto branch : branches) 39 | queue.push_back(branch); 40 | for (auto bb : fn->getRegion()->getBlocks()) for (auto op : bb->getOps()) if (preserved(op)) queue.push_back(op); // Preserved ops (e.g. GotoOp, CloneOp, JoinOp, WakeOp) are never erased below, so whatever they use must be kept live as well. 41 | // Find all reachable operations from the outside-used ones.
42 | while (!queue.empty()) { 43 | auto op = queue.back(); 44 | queue.pop_back(); 45 | 46 | if (live.count(op)) 47 | continue; 48 | live.insert(op); 49 | 50 | for (auto operand : op->getOperands()) 51 | queue.push_back(operand.defining); 52 | } 53 | 54 | // Remove every operation that isn't live. 55 | auto region = fn->getRegion(); 56 | std::vector remove; 57 | for (auto bb : region->getBlocks()) { 58 | for (auto op : bb->getOps()) { 59 | if (!preserved(op) && !live.count(op)) { 60 | op->removeAllOperands(); 61 | remove.push_back(op); 62 | } 63 | } 64 | } 65 | 66 | elim += remove.size(); 67 | for (auto op : remove) 68 | op->erase(); 69 | } 70 | 71 | void AggressiveDCE::run() { 72 | auto funcs = collectFuncs(); 73 | for (auto func : funcs) 74 | runImpl(func); 75 | } 76 | -------------------------------------------------------------------------------- /src/arm/ArmDCE.cpp: -------------------------------------------------------------------------------- 1 | #include "ArmOps.h" 2 | #include "ArmPasses.h" 3 | 4 | using namespace sys::arm; 5 | using namespace sys; 6 | 7 | std::map ArmDCE::stats() { 8 | return { 9 | { "eliminated-ops", elim } 10 | }; 11 | } 12 | 13 | #define IMPURE(Ty) || isa(op) 14 | bool ArmDCE::isImpure(Op *op) { 15 | return isa(op) 16 | IMPURE(StrFOp) 17 | IMPURE(StrWOp) 18 | IMPURE(StrXOp) 19 | IMPURE(StrFROp) 20 | IMPURE(StrWROp) 21 | IMPURE(StrXROp) 22 | IMPURE(StrFPOp) 23 | IMPURE(StrWPOp) 24 | IMPURE(StrXPOp) 25 | IMPURE(St1Op) 26 | IMPURE(BlOp) 27 | IMPURE(BgtOp) 28 | IMPURE(BltOp) 29 | IMPURE(BleOp) 30 | IMPURE(BeqOp) 31 | IMPURE(BneOp) 32 | IMPURE(BgeOp) 33 | IMPURE(BplOp) 34 | IMPURE(BmiOp) 35 | IMPURE(RetOp) 36 | IMPURE(CbzOp) 37 | IMPURE(CbnzOp) 38 | IMPURE(BOp) 39 | IMPURE(WriteRegOp) 40 | IMPURE(SubSpOp) 41 | IMPURE(PlaceHolderOp) 42 | IMPURE(CloneOp) 43 | IMPURE(JoinOp) 44 | IMPURE(WakeOp) 45 | ; 46 | } 47 | 48 | // Here no nested regions are possible.
49 | void ArmDCE::markImpure(Region *region) { 50 | for (auto bb : region->getBlocks()) { 51 | for (auto op : bb->getOps()) { 52 | if (isImpure(op) && !op->has()) 53 | op->add(); 54 | } 55 | } 56 | } 57 | 58 | void ArmDCE::runOnRegion(Region *region) { 59 | for (auto bb : region->getBlocks()) { 60 | for (auto op : bb->getOps()) { 61 | if (!op->has() && op->getUses().size() == 0) 62 | removeable.push_back(op); 63 | else for (auto r : op->getRegions()) 64 | runOnRegion(r); 65 | } 66 | } 67 | } 68 | 69 | void ArmDCE::run() { 70 | auto funcs = collectFuncs(); 71 | for (auto func : funcs) 72 | markImpure(func->getRegion()); 73 | 74 | do { 75 | removeable.clear(); 76 | for (auto func : funcs) 77 | runOnRegion(func->getRegion()); 78 | 79 | elim += removeable.size(); 80 | for (auto op : removeable) 81 | op->erase(); 82 | } while (removeable.size()); 83 | } 84 | -------------------------------------------------------------------------------- /src/opt/RemoveEmptyLoop.cpp: -------------------------------------------------------------------------------- 1 | #include "LoopPasses.h" 2 | 3 | using namespace sys; 4 | 5 | std::map RemoveEmptyLoop::stats() { 6 | return { 7 | { "removed-loops", removed } 8 | }; 9 | } 10 | 11 | #define PINNED(Ty) || isa(op) 12 | static bool pinned(Op *op) { 13 | return (isa(op) && op->has()) 14 | PINNED(StoreOp); 15 | } 16 | // Deletes a single-exit loop whose blocks contain no pinned op and define nothing used outside the loop, redirecting the header's predecessors to the exit; returns true if the loop was removed. 17 | bool RemoveEmptyLoop::runImpl(LoopInfo *info) { 18 | if (info->exits.size() != 1) 19 | return false; 20 | 21 | for (auto bb : info->getBlocks()) { 22 | for (auto op : bb->getOps()) { 23 | // Side-effect. 24 | if (pinned(op)) 25 | return false; 26 | 27 | // Something is used outside, cannot remove. 28 | for (auto use : op->getUses()) { 29 | if (!info->contains(use->getParent())) 30 | return false; 31 | } 32 | } 33 | } 34 | 35 | // Safe to remove. 36 | // All header's predecessors should now connect to the exit.
37 | auto header = info->header; 38 | auto exit = info->getExit(); 39 | for (auto pred : header->preds) { 40 | auto term = pred->getLastOp(); 41 | if (TARGET(term) == header) 42 | TARGET(term) = exit; 43 | if (term->has() && ELSE(term) == header) 44 | ELSE(term) = exit; 45 | } 46 | 47 | for (auto bb : info->getBlocks()) { 48 | for (auto op : bb->getOps()) 49 | op->removeAllOperands(); 50 | } 51 | auto bbs = info->getBlocks(); 52 | for (auto bb : bbs) 53 | bb->forceErase(); 54 | 55 | removed++; 56 | return true; 57 | } 58 | 59 | void RemoveEmptyLoop::run() { 60 | LoopAnalysis analysis(module); 61 | auto funcs = collectFuncs(); 62 | 63 | for (auto func : funcs) { 64 | auto region = func->getRegion(); 65 | auto forest = analysis.runImpl(region); 66 | 67 | bool changed; 68 | do { 69 | changed = false; 70 | for (auto loop : forest.getLoops()) { 71 | if (!runImpl(loop)) 72 | continue; 73 | 74 | forest = analysis.runImpl(region); 75 | changed = true; 76 | break; 77 | } 78 | } while (changed); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/opt/Pass.h: -------------------------------------------------------------------------------- 1 | #ifndef PASS_H 2 | #define PASS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "../codegen/Ops.h" 12 | 13 | namespace sys { 14 | 15 | using DomTree = std::unordered_map>; 16 | 17 | bool isExtern(const std::string &name); 18 | 19 | class Pass { 20 | template 21 | static A helper(Ret (F::*)(A) const); 22 | 23 | template 24 | using argument_t = decltype(helper(&F::operator())); 25 | protected: 26 | ModuleOp *module; 27 | // Repeatedly applies `rewriter` to every op under `op` of the type its parameter accepts, until a full sweep reports no change. 28 | template 29 | void runRewriter(Op *op, F rewriter) { 30 | using T = std::remove_pointer_t>; 31 | 32 | bool success; 33 | int total = 0; 34 | do { 35 | // Probably hit an infinite loop.
36 | if (++total > 10000) 37 | assert(false); // NOTE(review): compiled out under NDEBUG, where a rewriter that never converges keeps looping instead of failing. 38 | 39 | auto ts = op->findAll(); 40 | success = false; 41 | for (auto t : ts) 42 | success |= rewriter(cast(t)); 43 | } while (success); 44 | } 45 | 46 | template 47 | void runRewriter(F rewriter) { 48 | runRewriter(module, rewriter); 49 | } 50 | 51 | // This will be faster than module->findAll, 52 | // as it doesn't need to iterate through the contents of functions. 53 | std::vector collectFuncs(); 54 | // Same as above, only that it's for global variables. 55 | std::vector collectGlobals(); 56 | std::map getFunctionMap(); 57 | std::map getGlobalMap(); 58 | DomTree getDomTree(Region *region); 59 | 60 | // Find the first op that isn't an AllocaOp. 61 | Op *nonalloca(Region *region); 62 | // Find the first op that isn't a PhiOp. 63 | Op *nonphi(BasicBlock *bb); 64 | public: 65 | Pass(ModuleOp *module): module(module) {} 66 | void cleanup(); 67 | virtual ~Pass() {} 68 | virtual std::string name() = 0; 69 | virtual std::map stats() = 0; 70 | virtual void run() = 0; 71 | }; 72 | 73 | } 74 | 75 | #endif 76 | -------------------------------------------------------------------------------- /src/pre-opt/PreAttrs.h: -------------------------------------------------------------------------------- 1 | #ifndef PRE_ATTRS_H 2 | #define PRE_ATTRS_H 3 | 4 | #include "../codegen/OpBase.h" 5 | #include 6 | 7 | #define PREOPTLINE (__LINE__ + 8388608) // Parenthesized so the offset stays bound to __LINE__ in any expansion context. 8 | 9 | namespace sys { 10 | 11 | using AffineExpr = std::vector; 12 | 13 | // It only stores the coefficients. They are to be multiplied with loop induction variables. 14 | // subscript[0] is the coefficient for the outermost loop. 15 | // subscript.back() is a constant, hence `subscript.size()` is the loop nest depth plus 1.
16 | class SubscriptAttr : public AttrImpl { 17 | public: 18 | AffineExpr subscript; 19 | SubscriptAttr(const AffineExpr &subscript): 20 | subscript(subscript) {} 21 | 22 | std::string toString() override; 23 | SubscriptAttr *clone() override { return new SubscriptAttr(subscript); } 24 | }; 25 | 26 | class BaseAttr : public AttrImpl { 27 | public: 28 | Op *base; 29 | BaseAttr(Op *base): base(base) {} 30 | 31 | std::string toString() override; 32 | BaseAttr *clone() override { return new BaseAttr(base); } 33 | }; 34 | 35 | // Shows whether the loop is parallelizable; 36 | // For loops without loop-carried dependencies, it's trivial, 37 | // but if there is a single dependency of a scalar, then it's considered an accumulator. 38 | class ParallelizableAttr : public AttrImpl { 39 | public: 40 | Op *accum; 41 | ParallelizableAttr(): accum(nullptr) {} 42 | ParallelizableAttr(Op *accum): accum(accum) {} 43 | 44 | std::string toString() override; 45 | ParallelizableAttr *clone() override { return new ParallelizableAttr(accum); } 46 | }; 47 | 48 | class NoStoreAttr : public AttrImpl { 49 | public: 50 | std::string toString() override { return ""; } 51 | NoStoreAttr *clone() override { return new NoStoreAttr; } 52 | }; 53 | 54 | } 55 | 56 | #define SUBSCRIPT(op) (op)->get()->subscript 57 | #define BASE(op) (op)->get()->base 58 | #define PARALLEL(op) (op)->has() 59 | #define ACCUM(op) (op)->get()->accum 60 | 61 | #endif 62 | -------------------------------------------------------------------------------- /src/main/Options.cpp: -------------------------------------------------------------------------------- 1 | #include "Options.h" 2 | #include 3 | #include 4 | 5 | using namespace sys; 6 | 7 | #define PARSEOPT(str, field) \ 8 | if (strcmp(argv[i], str) == 0) { \ 9 | opts.field = true; \ 10 | continue; \ 11 | } 12 | 13 | Options::Options() { 14 | noLink = false; 15 | dumpAST = false; 16 | dumpMidIR = false; 17 | o1 = false; 18 | arm = false; 19 | rv = false; 20 | verbose = 
false; 21 | stats = false; 22 | verify = false; 23 | sat = false; 24 | bv = false; 25 | } 26 | 27 | Options sys::parseArgs(int argc, char **argv) { 28 | Options opts; 29 | 30 | for (int i = 1; i < argc; i++) { 31 | if (strcmp(argv[i], "-o") == 0) { 32 | opts.outputFile = argv[i + 1]; 33 | i++; 34 | continue; 35 | } 36 | 37 | if (strcmp(argv[i], "--print-after") == 0) { 38 | opts.printAfter = argv[i + 1]; 39 | i++; 40 | continue; 41 | } 42 | 43 | if (strcmp(argv[i], "--print-before") == 0) { 44 | opts.printBefore = argv[i + 1]; 45 | i++; 46 | continue; 47 | } 48 | 49 | if (strcmp(argv[i], "--compare") == 0) { 50 | opts.compareWith = argv[i + 1]; 51 | i++; 52 | continue; 53 | } 54 | 55 | if (strcmp(argv[i], "-i") == 0) { 56 | opts.simulateInput = argv[i + 1]; 57 | i++; 58 | continue; 59 | } 60 | 61 | PARSEOPT("--dump-ast", dumpAST); 62 | PARSEOPT("--dump-mid-ir", dumpMidIR); 63 | PARSEOPT("--rv", rv); 64 | PARSEOPT("--arm", arm); 65 | PARSEOPT("-O1", o1); 66 | PARSEOPT("-S", noLink); 67 | PARSEOPT("-v", verbose); 68 | PARSEOPT("--stats", stats); 69 | PARSEOPT("-s", stats); 70 | PARSEOPT("--verify", verify); 71 | PARSEOPT("--bv", bv); 72 | PARSEOPT("--sat", sat); 73 | 74 | if (opts.inputFile != "") { 75 | std::cerr << "error: multiple inputs\n"; 76 | exit(1); 77 | } 78 | 79 | opts.inputFile = argv[i]; 80 | } 81 | 82 | if (opts.rv && opts.arm) { 83 | std::cerr << "error: multiple target\n"; 84 | exit(1); 85 | } 86 | 87 | if (!opts.rv && !opts.arm) 88 | opts.rv = true; 89 | 90 | return opts; 91 | } -------------------------------------------------------------------------------- /src/utils/Exec.h: -------------------------------------------------------------------------------- 1 | #ifndef EXEC_H 2 | #define EXEC_H 3 | 4 | #include "../codegen/Ops.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace sys::exec { 11 | 12 | const int CACHE_3_N = 32; 13 | const int CACHE_2_N = 1024; 14 | 15 | using cache_3 = int[CACHE_3_N][CACHE_3_N][CACHE_3_N]; 16 | 
using cache_2 = int[CACHE_2_N][CACHE_2_N]; 17 | 18 | using cache_3_ptr = int(*)[CACHE_3_N][CACHE_3_N]; 19 | using cache_2_ptr = int(*)[CACHE_2_N]; 20 | 21 | const int CACHE_3_TOTAL = sizeof(cache_3) / sizeof(int); 22 | const int CACHE_2_TOTAL = sizeof(cache_2) / sizeof(int); 23 | 24 | class Interpreter { 25 | union Value { 26 | intptr_t vi; 27 | float vf; 28 | }; 29 | 30 | using SymbolTable = std::unordered_map; 31 | 32 | std::stringstream outbuf, inbuf; 33 | std::map fnMap; 34 | std::set fpGlobals; 35 | std::map globalMap; 36 | 37 | SymbolTable value; 38 | // Used for phi functions. 39 | BasicBlock *prev; 40 | // Instruction pointer. 41 | Op *ip; 42 | 43 | intptr_t eval(Op *op); 44 | float evalf(Op *op); 45 | 46 | void store(Op *op, float v); 47 | void store(Op *op, intptr_t v); 48 | 49 | void exec(Op *op); 50 | Value execf(Region *region, const std::vector &args); 51 | 52 | Value applyExtern(const std::string &name, const std::vector &args); 53 | 54 | unsigned retcode; 55 | int *cache = nullptr; 56 | int cache_type = 0; 57 | 58 | struct SemanticScope { 59 | Interpreter &parent; 60 | SymbolTable table; 61 | public: 62 | SemanticScope(Interpreter &itp): parent(itp), table(itp.value) {} 63 | ~SemanticScope() { parent.value = table; } 64 | }; 65 | public: 66 | Interpreter(ModuleOp *module); 67 | ~Interpreter(); 68 | 69 | void run(std::istream &input); 70 | void runFunction(const std::string &func, const std::vector &args); 71 | void useCache(cache_3 cache) { this->cache = (int*) cache; cache_type = 3; } 72 | void useCache(cache_2 cache) { this->cache = (int*) cache; cache_type = 2; } 73 | std::string out() { return outbuf.str(); } 74 | int exitcode() { return retcode & 0xff; } 75 | }; 76 | 77 | } 78 | 79 | #endif 80 | -------------------------------------------------------------------------------- /src/rv/RvPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef RV_PASSES_H 2 | #define RV_PASSES_H 3 | 4 | #include 
"../opt/Pass.h" 5 | #include "RvAttrs.h" 6 | #include "RvOps.h" 7 | #include "../codegen/Ops.h" 8 | #include "../codegen/Attrs.h" 9 | #include "../codegen/CodeGen.h" 10 | 11 | namespace sys { 12 | 13 | namespace rv { 14 | 15 | class Lower : public Pass { 16 | public: 17 | Lower(ModuleOp *module): Pass(module) {} 18 | 19 | std::string name() override { return "rv-lower"; }; 20 | std::map stats() override { return {}; }; 21 | void run() override; 22 | }; 23 | 24 | class StrengthReduct : public Pass { 25 | int convertedTotal = 0; 26 | 27 | int runImpl(); 28 | public: 29 | StrengthReduct(ModuleOp *module): Pass(module) {} 30 | 31 | std::string name() override { return "strength-reduction"; }; 32 | std::map stats() override; 33 | void run() override; 34 | }; 35 | 36 | class InstCombine : public Pass { 37 | int combined = 0; 38 | public: 39 | InstCombine(ModuleOp *module): Pass(module) {} 40 | 41 | std::string name() override { return "rv-inst-combine"; }; 42 | std::map stats() override; 43 | void run() override; 44 | }; 45 | 46 | class RegAlloc : public Pass { 47 | int spilled = 0; 48 | int convertedTotal = 0; 49 | 50 | std::map> usedRegisters; 51 | std::map fnMap; 52 | 53 | void runImpl(Region *region, bool isLeaf); 54 | // Create both prologue and epilogue of a function. 55 | void proEpilogue(FuncOp *funcOp, bool isLeaf); 56 | int latePeephole(Op *funcOp); 57 | void tidyup(Region *region); 58 | public: 59 | RegAlloc(ModuleOp *module): Pass(module) {} 60 | 61 | std::string name() override { return "rv-regalloc"; }; 62 | std::map stats() override; 63 | void run() override; 64 | }; 65 | 66 | // Dumps the output. 
67 | class Dump : public Pass { 68 | std::string out; 69 | 70 | void dump(std::ostream &os); 71 | public: 72 | Dump(ModuleOp *module, const std::string &out): Pass(module), out(out) {} 73 | 74 | std::string name() override { return "rv-dump"; }; 75 | std::map stats() override { return {}; } 76 | void run() override; 77 | }; 78 | 79 | } 80 | 81 | } 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /src/pre-opt/Base.cpp: -------------------------------------------------------------------------------- 1 | #include "PreAnalysis.h" 2 | #include "../codegen/CodeGen.h" 3 | 4 | using namespace sys; 5 | 6 | namespace { 7 | 8 | void remove(Region *region) { 9 | for (auto bb : region->getBlocks()) { 10 | for (auto op : bb->getOps()) { 11 | op->remove(); 12 | for (auto r : op->getRegions()) 13 | remove(r); 14 | } 15 | } 16 | } 17 | 18 | } 19 | 20 | void Base::runImpl(Region *region) { 21 | for (auto bb : region->getBlocks()) { 22 | for (auto op : bb->getOps()) { 23 | // Recursively deal with inner regions. 24 | for (auto r : op->getRegions()) 25 | runImpl(r); 26 | 27 | // The base of an alloca/getglobal is itself. 28 | if (isa(op) || isa(op)) { 29 | op->add(op); 30 | continue; 31 | } 32 | 33 | // For addl, the base is that of the address. 34 | if (isa(op)) { 35 | auto x = op->DEF(0); 36 | auto y = op->DEF(1); 37 | if (!x->has()) { 38 | if (y->has()) 39 | std::swap(x, y); 40 | else continue; 41 | } 42 | 43 | op->add(BASE(x)); 44 | continue; 45 | } 46 | } 47 | } 48 | } 49 | 50 | void Base::run() { 51 | auto funcs = collectFuncs(); 52 | 53 | for (auto func : funcs) { 54 | Region *region = func->getRegion(); 55 | // First remove all existing BaseAttrs. 56 | remove(region); 57 | 58 | // Find the place to hoist get-globals. 59 | auto bb = region->getFirstBlock(); 60 | if (bb->getOpCount() && isa(bb->getFirstOp())) 61 | bb = bb->nextBlock(); 62 | 63 | // Hoist GetGlobalOp to the front. 
64 | auto gets = func->findAll(); 65 | std::unordered_map hoisted; 66 | Builder builder; 67 | 68 | for (auto get : gets) { 69 | const auto &name = NAME(get); 70 | if (!hoisted.count(name)) { 71 | builder.setToBlockStart(bb); 72 | auto newget = builder.create({ new NameAttr(name) }); 73 | hoisted[name] = newget; 74 | } 75 | get->replaceAllUsesWith(hoisted[name]); 76 | get->erase(); 77 | } 78 | 79 | runImpl(region); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/opt/Reassociate.cpp: -------------------------------------------------------------------------------- 1 | #include "CleanupPasses.h" 2 | 3 | using namespace sys; 4 | 5 | struct Associated { 6 | bool ref; 7 | std::vector mem; 8 | }; 9 | 10 | void Reassociate::runImpl(Region *region) { 11 | std::map data; 12 | auto domtree = getDomTree(region); 13 | 14 | std::vector queue { region->getFirstBlock() }; 15 | while (!queue.empty()) { 16 | auto bb = queue.back(); 17 | queue.pop_back(); 18 | 19 | for (auto op : bb->getOps()) { 20 | if (!isa(op)) 21 | continue; 22 | 23 | auto x = op->DEF(0), y = op->DEF(1); 24 | std::vector mem; 25 | 26 | if (data.count(x)) 27 | std::copy(data[x].mem.begin(), data[x].mem.end(), std::back_inserter(mem)), 28 | data[x].ref = true; 29 | else 30 | mem.push_back(x); 31 | 32 | if (data.count(y)) 33 | std::copy(data[y].mem.begin(), data[y].mem.end(), std::back_inserter(mem)), 34 | data[y].ref = true; 35 | else 36 | mem.push_back(y); 37 | 38 | data[op] = { false, mem }; 39 | } 40 | 41 | for (auto child : domtree[bb]) 42 | queue.push_back(child); 43 | } 44 | 45 | for (auto [k, v] : data) { 46 | if (v.ref || v.mem.size() == 2) 47 | continue; 48 | 49 | // We require every addition is used only once. 50 | auto mem = v.mem; 51 | bool good = true; 52 | for (auto op : mem) { 53 | if (op->getUses().size() > 1) { 54 | good = false; 55 | break; 56 | } 57 | } 58 | if (!good) 59 | continue; 60 | 61 | // Reassociate. 
62 | std::vector copy; 63 | Builder builder; 64 | while (mem.size() != 1) { 65 | builder.setBeforeOp(k); 66 | for (int i = 0; i + 1 < mem.size(); i += 2) { 67 | auto add = builder.create({ mem[i]->getResult(), mem[i + 1] }); 68 | copy.push_back(add); 69 | } 70 | if (mem.size() & 1) 71 | copy.push_back(mem.back()); 72 | mem = copy; 73 | copy.clear(); 74 | } 75 | auto fulladd = mem[0]; 76 | k->replaceAllUsesWith(fulladd); 77 | k->erase(); 78 | } 79 | } 80 | 81 | void Reassociate::run() { 82 | auto funcs = collectFuncs(); 83 | 84 | for (auto func : funcs) 85 | runImpl(func->getRegion()); 86 | } 87 | -------------------------------------------------------------------------------- /src/opt/Verify.cpp: -------------------------------------------------------------------------------- 1 | #include "Passes.h" 2 | 3 | using namespace sys; 4 | 5 | // Checks whether every operation dominates its uses. 6 | static void checkDom(Region *region, Op *module) { 7 | // Only find reachable blocks. 8 | std::set reachable; 9 | std::vector queue { region->getFirstBlock() }; 10 | while (!queue.empty()) { 11 | auto bb = queue.back(); 12 | queue.pop_back(); 13 | 14 | if (reachable.count(bb)) 15 | continue; 16 | reachable.insert(bb); 17 | for (auto succ : bb->succs) 18 | queue.push_back(succ); 19 | } 20 | 21 | for (auto bb : reachable) { 22 | for (auto op : bb->getOps()) { 23 | // Phi's are checked later on. 
24 | if (isa(op)) 25 | continue; 26 | 27 | for (auto operand : op->getOperands()) { 28 | auto def = operand.defining; 29 | if (!def->getParent()->dominates(bb)) { 30 | std::cerr << module << "non-dominating: " << op << "operand: " << def; 31 | assert(false); 32 | } 33 | } 34 | } 35 | } 36 | }; 37 | 38 | void Verify::run() { 39 | auto funcs = collectFuncs(); 40 | for (auto func : funcs) { 41 | auto region = func->getRegion(); 42 | region->updateDoms(); 43 | checkDom(region, module); 44 | } 45 | 46 | auto phis = module->findAll(); 47 | bool nonfrom = false; 48 | for (auto phi : phis) { 49 | // Check the number of phi's must be equal to the number of predecessors. 50 | auto parent = phi->getParent(); 51 | if (parent->preds.size() != phi->getOperandCount()) { 52 | std::cerr << module << "phi with " << phi->getOperandCount() << 53 | " operand(s), but expected " << parent->preds.size() << ":\n " << phi; 54 | assert(false); 55 | } 56 | 57 | // Check that all operands from Phi must come from the immediate predecessor. 58 | auto bb = phi->getParent(); 59 | for (auto attr : phi->getAttrs()) { 60 | if (!isa(attr)) { 61 | nonfrom = true; 62 | continue; 63 | } 64 | if (!bb->preds.count(FROM(attr))) { 65 | std::cerr << module << "phi operands are not from predecessor:\n " << phi; 66 | assert(false); 67 | } 68 | } 69 | } 70 | if (nonfrom) 71 | std::cerr << "warning: phi has non-FromAttr\n"; 72 | } 73 | -------------------------------------------------------------------------------- /src/opt/RangeAwareFold.cpp: -------------------------------------------------------------------------------- 1 | #include "CleanupPasses.h" 2 | #include "../utils/Matcher.h" 3 | 4 | using namespace sys; 5 | 6 | std::map RangeAwareFold::stats() { 7 | return { 8 | { "folded-ops", folded } 9 | }; 10 | } 11 | 12 | // Defined in Specialize.cpp. 13 | void removeRange(Region *region); 14 | 15 | void RangeAwareFold::run() { 16 | Builder builder; 17 | 18 | // Fold left/right shifts early. 
19 | runRewriter([&](DivIOp *op) { 20 | auto x = op->DEF(0); 21 | auto y = op->DEF(1); 22 | if (!isa(y) || V(y) < 0 || !x->has()) 23 | return false; 24 | 25 | auto [low, high] = RANGE(x); 26 | if (low < 0) 27 | return false; 28 | 29 | if (__builtin_popcount(V(y)) != 1) 30 | return false; 31 | 32 | // This can be replaced to an ordinary right-shift. 33 | folded++; 34 | builder.setBeforeOp(op); 35 | auto vi = builder.create({ new IntAttr(__builtin_ctz(V(y))) }); 36 | builder.replace(op, { x, vi }); 37 | return false; 38 | }); 39 | 40 | runRewriter([&](ModIOp *op) { 41 | auto x = op->DEF(0); 42 | auto y = op->DEF(1); 43 | if (!isa(y) || !x->has()) 44 | return false; 45 | 46 | if (V(y) < 0) 47 | V(y) = -V(y); 48 | 49 | auto [low, high] = RANGE(x); 50 | if (low < 0) 51 | return false; 52 | 53 | if (__builtin_popcount(V(y)) != 1) 54 | return false; 55 | 56 | // Replace with bit-and. 57 | folded++; 58 | builder.setBeforeOp(op); 59 | auto vi = builder.create({ new IntAttr(V(y) - 1) }); 60 | builder.replace(op, { x, vi }); 61 | return false; 62 | }); 63 | 64 | Rule eq_or("(or (eq x 1) (not x))"); 65 | runRewriter([&](OrIOp *op) { 66 | if (eq_or.match(op)) { 67 | auto x = eq_or.extract("x"); 68 | if (!x->has()) 69 | return false; 70 | auto [low, high] = RANGE(x); 71 | if (low == 0) { 72 | folded++; 73 | builder.setBeforeOp(op); 74 | auto two = builder.create({ new IntAttr(2) }); 75 | builder.replace(op, { x, two }); 76 | return false; 77 | } 78 | } 79 | return false; 80 | }); 81 | 82 | auto funcs = collectFuncs(); 83 | 84 | for (auto func : funcs) { 85 | auto region = func->getRegion(); 86 | removeRange(region); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/pre-opt/Lower.cpp: -------------------------------------------------------------------------------- 1 | #include "PreLoopPasses.h" 2 | 3 | using namespace sys; 4 | 5 | void Lower::run() { 6 | auto loops = module->findAll(); 7 | 8 | Builder builder; 9 | // Destruct 
fors and turn them into whiles. 10 | for (auto loop : loops) { 11 | builder.setBeforeOp(loop); 12 | auto ivAddr = loop->getOperand(3); 13 | 14 | // Put a load of the address at the entry. 15 | auto region = loop->getRegion(); 16 | builder.setToRegionStart(region); 17 | auto iv = builder.create(Value::i32, { ivAddr }, { new SizeAttr(4) }); 18 | 19 | // Replace the induction variable with the load. 20 | loop->replaceAllUsesWith(iv); 21 | 22 | // Before every break/continue, insert a store `iv = iv + 'a`. 23 | auto terms = loop->findAll(); 24 | auto conts = loop->findAll(); 25 | std::copy(conts.begin(), conts.end(), std::back_inserter(terms)); 26 | auto incr = loop->getOperand(2); 27 | 28 | for (auto op : terms) { 29 | builder.setBeforeOp(op); 30 | auto add = builder.create({ iv, incr }); 31 | builder.create({ add, ivAddr }, { new SizeAttr(4) }); 32 | } 33 | 34 | // Also do it at the end. 35 | auto last = region->getLastBlock(); 36 | builder.setToBlockEnd(last); 37 | auto add = builder.create({ iv, incr }); 38 | builder.create({ add, ivAddr }, { new SizeAttr(4) }); 39 | 40 | // Create a while loop. 41 | builder.setBeforeOp(loop); 42 | auto wloop = builder.create(); 43 | auto before = wloop->appendRegion(); 44 | auto after = wloop->appendRegion(); 45 | 46 | // Move all blocks in the for to the after region. 47 | for (auto it = region->begin(); it != region->end();) { 48 | auto next = it; next++; 49 | (*it)->moveToEnd(after); 50 | it = next; 51 | } 52 | 53 | // Create the condition at the before region. 54 | before->appendBlock(); 55 | builder.setToRegionStart(before); 56 | auto load = builder.create(Value::i32, { ivAddr }, { new SizeAttr(4) }); 57 | auto stop = loop->getOperand(1); 58 | auto lt = builder.create({ load, stop }); 59 | builder.create({ lt }); 60 | 61 | // Create the start value before the while. 62 | builder.setBeforeOp(wloop); 63 | auto start = loop->getOperand(0); 64 | builder.create({ start, ivAddr }, { new SizeAttr(4) }); 65 | 66 | // Erase the ForOp. 
67 | loop->erase(); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/opt/Analysis.h: -------------------------------------------------------------------------------- 1 | #ifndef ANALYSIS_H 2 | #define ANALYSIS_H 3 | 4 | #include "Passes.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Attrs.h" 7 | 8 | namespace sys { 9 | 10 | // Analysis pass. 11 | // Detects whether a function is pure. If it isn't, give it an ImpureAttr. 12 | class Pureness : public Pass { 13 | // Maps a function to all functions that it might call. 14 | std::map> callGraph; 15 | 16 | public: 17 | Pureness(ModuleOp *module): Pass(module) {} 18 | 19 | std::string name() override { return "pureness"; }; 20 | std::map stats() override { return {}; } 21 | void run() override; 22 | }; 23 | 24 | // Puts CallerAttr to each function. 25 | class CallGraph : public Pass { 26 | public: 27 | CallGraph(ModuleOp *module): Pass(module) {} 28 | 29 | std::string name() override { return "call-graph"; }; 30 | std::map stats() override { return {}; } 31 | void run() override; 32 | }; 33 | 34 | // Gives an AliasAttr to values, if they are addresses. 35 | class Alias : public Pass { 36 | std::map gMap; 37 | void runImpl(Region *region); 38 | public: 39 | Alias(ModuleOp *module): Pass(module) {} 40 | 41 | std::string name() override { return "alias"; }; 42 | std::map stats() override { return {}; } 43 | void run() override; 44 | }; 45 | 46 | // Integer range analysis. 47 | class Range : public Pass { 48 | // The set of all loop headers in a function. 49 | // We should apply widening at these blocks, otherwise it would take forever to converge. 50 | std::set headers; 51 | 52 | // Reorder the blocks so that they have a single exit. 53 | void postdom(Region *region); 54 | // Split a single operation into two for comparison branches. 55 | void split(Region *region); 56 | // Give RangeAttr to operations. 
57 | void analyze(Region *region); 58 | public: 59 | Range(ModuleOp *module): Pass(module) {} 60 | 61 | std::string name() override { return "range"; } 62 | std::map stats() override { return {}; } 63 | void run() override; 64 | }; 65 | 66 | // Mark functions that are called at most once. 67 | class AtMostOnce : public Pass { 68 | public: 69 | AtMostOnce(ModuleOp *module): Pass(module) {} 70 | 71 | std::string name() override { return "at-most-once"; }; 72 | std::map stats() override { return {}; } 73 | void run() override; 74 | }; 75 | 76 | } 77 | 78 | #endif 79 | -------------------------------------------------------------------------------- /src/parse/TypeContext.h: -------------------------------------------------------------------------------- 1 | #ifndef TYPE_CONTEXT_H 2 | #define TYPE_CONTEXT_H 3 | 4 | #include 5 | 6 | #include "Type.h" 7 | #include "../utils/DynamicCast.h" 8 | 9 | namespace sys { 10 | 11 | // Manages memory of types. 12 | class TypeContext { 13 | struct Hash { 14 | size_t operator()(Type *ty) const { 15 | size_t hash = ty->getID(); 16 | 17 | if (auto arr = dyn_cast(ty)) { 18 | hash = (hash << 4) + Hash()(arr->base); 19 | for (auto x : arr->dims) 20 | hash *= (x + 1); 21 | } 22 | 23 | if (auto ptr = dyn_cast(ty)) 24 | hash = (hash << 4) + Hash()(ptr->pointee); 25 | 26 | if (auto fn = dyn_cast(ty)) { 27 | hash = (hash << 4) + Hash()(fn->ret); 28 | for (auto x : fn->params) { 29 | hash <<= 1; 30 | hash += Hash()(x); 31 | } 32 | } 33 | 34 | return hash; 35 | } 36 | }; 37 | 38 | struct Eq { 39 | bool operator()(Type *a, Type *b) const { 40 | if (a->getID() != b->getID()) 41 | return false; 42 | 43 | if (auto arr = dyn_cast(a)) { 44 | auto arrb = cast(b); 45 | if (arr->dims.size() != arrb->dims.size()) 46 | return false; 47 | 48 | for (int i = 0; i < arr->dims.size(); i++) { 49 | if (arr->dims[i] != arrb->dims[i]) 50 | return false; 51 | } 52 | 53 | return Eq()(arr->base, arrb->base); 54 | } 55 | 56 | if (auto ptr = dyn_cast(a)) { 57 | auto ptrb 
= cast(b); 58 | return Eq()(ptr->pointee, ptrb->pointee); 59 | } 60 | 61 | if (auto fn = dyn_cast(a)) { 62 | auto fnb = cast(b); 63 | if (fn->params.size() != fnb->params.size()) 64 | return false; 65 | 66 | for (int i = 0; i < fn->params.size(); i++) { 67 | if (!Eq()(fn->params[i], fnb->params[i])) 68 | return false; 69 | } 70 | 71 | return Eq()(fn->ret, fnb->ret); 72 | } 73 | 74 | return true; 75 | } 76 | }; 77 | 78 | std::unordered_set content; 79 | public: 80 | template 81 | T *create(Args... args) { 82 | auto ptr = new T(std::forward(args)...); 83 | if (auto [it, absent] = content.insert(ptr); !absent) { 84 | delete ptr; 85 | return cast(*it); 86 | } 87 | return ptr; 88 | } 89 | 90 | ~TypeContext() { 91 | for (auto x : content) 92 | delete x; 93 | } 94 | }; 95 | 96 | } 97 | 98 | #endif 99 | -------------------------------------------------------------------------------- /src/pre-opt/Localize.cpp: -------------------------------------------------------------------------------- 1 | #include "PrePasses.h" 2 | 3 | using namespace sys; 4 | 5 | void Localize::run() { 6 | auto funcs = collectFuncs(); 7 | auto fnMap = getFunctionMap(); 8 | 9 | auto getglobs = module->findAll(); 10 | auto gMap = getGlobalMap(); 11 | std::map> accessed; 12 | 13 | Builder builder; 14 | 15 | for (auto get : getglobs) { 16 | const auto &name = NAME(get); 17 | accessed[gMap[name]].insert(get->getParentOp()); 18 | } 19 | 20 | for (auto [name, k] : gMap) { 21 | // We don't want to localize an array. In fact, we hope to globalize them. 22 | if (SIZE(k) != 4) 23 | continue; 24 | 25 | if (!accessed.count(k)) { 26 | // The global variable is never accessed. Remove it. 27 | k->erase(); 28 | continue; 29 | } 30 | 31 | auto v = accessed[k]; 32 | if (v.size() > 1) 33 | continue; 34 | 35 | if (!(*v.begin())->has()) 36 | continue; 37 | 38 | // Now we can replace the global with a local variable. 
39 | auto user = *v.begin(); 40 | auto region = user->getRegion(); 41 | 42 | auto entry = region->getFirstBlock(); 43 | Op *addr; 44 | if (beforeFlatten) { 45 | builder.setToBlockEnd(entry); 46 | addr = builder.create({ new SizeAttr(4) }); 47 | } else { 48 | builder.setBeforeOp(entry->getLastOp()); 49 | addr = builder.create({ new SizeAttr(4) }); 50 | } 51 | 52 | auto bb = region->insertAfter(entry); 53 | // We must make sure the whole entry block contains only alloca. 54 | // That's why we inserted a new block here. 55 | // This is also for further transformations that append allocas to the first block. 56 | builder.setToBlockStart(bb); 57 | Value init; 58 | if (auto intArr = k->find()) { 59 | init = builder.create({ 60 | new IntAttr(intArr->vi[0]) 61 | }); 62 | } else { 63 | init = builder.create({ 64 | new FloatAttr(k->get()->vf[0]) 65 | }); 66 | } 67 | builder.create({ init, addr }, { new SizeAttr(4) }); 68 | 69 | if (!beforeFlatten) { 70 | // Remember to supply terminators for after FlattenCFG. 71 | entry->getLastOp()->moveToEnd(bb); 72 | 73 | builder.setToBlockEnd(entry); 74 | builder.create({ new TargetAttr(bb) }); 75 | } 76 | 77 | // Replace all "getglobal" to use the addr instead. 78 | auto gets = user->findAll(); 79 | for (auto get : gets) { 80 | if (NAME(get) == name) { 81 | get->replaceAllUsesWith(addr); 82 | get->erase(); 83 | } 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/pre-opt/PrePasses.h: -------------------------------------------------------------------------------- 1 | #ifndef PREPASSES_H 2 | #define PREPASSES_H 3 | 4 | #include "../opt/Pass.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Attrs.h" 7 | 8 | namespace sys { 9 | 10 | // Moves all alloca to the beginning. 
class MoveAlloca : public Pass {
public:
  MoveAlloca(ModuleOp *module): Pass(module) {}

  std::string name() override { return "move-alloca"; };
  // This pass reports no statistics.
  std::map stats() override { return {}; };
  void run() override;
};

// Folds before flattening CFG.
class EarlyConstFold : public Pass {
  // Running count of folds; presumably accumulated from foldImpl() and
  // reported by stats() — TODO confirm in the implementation.
  int foldedTotal = 0;
  // Set by the constructor. NOTE(review): presumably marks a run that happens
  // before the Pureness analysis; its effect lives in the .cpp — confirm there.
  bool beforePureness;

  // One folding sweep; the int result is presumably the number of folds — verify.
  int foldImpl();
public:
  EarlyConstFold(ModuleOp *module, bool beforePureness): Pass(module), beforePureness(beforePureness) {}

  std::string name() override { return "early-const-fold"; };
  std::map stats() override;
  void run() override;
};

// Folds memory, similar to DLE in later passes.
class TidyMemory : public Pass {
  // Counter of tidied operations, presumably reported by stats() — TODO confirm.
  int tidied = 0;

  // Per-region worker, presumably invoked from run().
  void runImpl(Region *region);
public:
  TidyMemory(ModuleOp *module): Pass(module) {}

  std::string name() override { return "tidy-memory"; };
  std::map stats() override;
  void run() override;
};

// Localizes global variables.
class Localize : public Pass {
  // Selects how Localize::run places the new alloca:
  //   true  - appended to the end of the entry block;
  //   false - inserted before the entry block's terminator, which is then moved
  //           into the newly inserted block and replaced by a jump to it.
  bool beforeFlatten;
public:
  Localize(ModuleOp *module, bool beforeFlatten):
    Pass(module), beforeFlatten(beforeFlatten) {}

  std::string name() override { return "localize"; };
  // This pass reports no statistics.
  std::map stats() override { return {}; }
  void run() override;
};

// NOTE(review): inlining policy is defined in the implementation, not here.
class EarlyInline : public Pass {
public:
  EarlyInline(ModuleOp *module): Pass(module) {}

  std::string name() override { return "early-inline"; };
  // This pass reports no statistics.
  std::map stats() override { return {}; }
  void run() override;
};

// Tail call optimization.
69 | class TCO : public Pass { 70 | int uncalled = 0; 71 | 72 | bool runImpl(FuncOp *func); 73 | bool runAdd(FuncOp *func); 74 | public: 75 | TCO(ModuleOp *module): Pass(module) {} 76 | 77 | std::string name() override { return "tco"; }; 78 | std::map stats() override; 79 | void run() override; 80 | }; 81 | 82 | // Remerge basic blocks. 83 | // It is always possible to ensure each region has one block (except allocas), 84 | // since before FlattenCFG there's no jumps. 85 | class Remerge : public Pass { 86 | void runImpl(Region *region); 87 | public: 88 | Remerge(ModuleOp *module): Pass(module) {} 89 | 90 | std::string name() override { return "remerge"; }; 91 | std::map stats() override { return {}; } 92 | void run() override; 93 | }; 94 | 95 | } 96 | 97 | #endif 98 | -------------------------------------------------------------------------------- /fuzz.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | import tempfile; 3 | import os; 4 | import subprocess as proc; 5 | import random; 6 | 7 | operators = [ 8 | '+', '-', '*', '/', '%' 9 | ] 10 | comparisons = [ 11 | '<', '>', '==', '<=', '>=' 12 | ] 13 | error_cnt = 0 14 | 15 | def run(file: str, input: bytes): 16 | basename = os.path.splitext(os.path.basename(file))[0] 17 | try: 18 | proc.run(["build/sysc", file, "-o", f"temp/{basename}.s"], check=True, timeout=5) 19 | except: 20 | print("Compiler internal error.") 21 | with open(f"temp/bad_program.txt", "w") as f: 22 | f.write(open(file, "r").read()) 23 | exit(1) 24 | 25 | gcc = "riscv64-linux-gnu-gcc" 26 | proc.run([gcc, f"temp/{basename}.s", "test/official/sylib.c", "-static", "-o", f"temp/{basename}"]) 27 | 28 | # Run the file. 
29 | qemu = "qemu-riscv64-static" 30 | try: 31 | return proc.run([qemu, f"temp/{basename}"], input=input, stdout=proc.PIPE, timeout=5) 32 | except: 33 | print("Program timeout.") 34 | with open(f"temp/bad_program.txt", "w") as f: 35 | f.write(open(file, "r").read()) 36 | exit(1) 37 | 38 | def run_actual(file: str, input: bytes): 39 | proc.run(["clang", file, "-o", f"{file}_clang"]) 40 | 41 | return proc.run([f"{file}_clang"], stdout=proc.PIPE, input=input) 42 | 43 | def fuzz_arithmetic_fold(dir: str): 44 | global error_cnt 45 | 46 | testcases = [] 47 | for i in range(0, 20): 48 | c1 = random.randint(-10, 50) 49 | c2 = random.randint(-10, 50) 50 | op = random.choice(operators) 51 | comp = random.choice(comparisons) 52 | # No division by zero. 53 | if (op == '/' or op == '%') and c1 == 0: 54 | c1 = 1 55 | 56 | testcases.append(f"x {op} {c1} {comp} {c2}") 57 | testcases.append(f"{c1} {op} x {comp} {c2}") 58 | 59 | sy = os.path.join(dir, "file.sy") 60 | c = os.path.join(dir, "file.c") 61 | with open(sy, "w") as f: 62 | f.write("int main() {\n int x = getint();\n ") 63 | f.write('\n '.join([f"putint({x}); putch(10);" for x in testcases])) 64 | f.write("\n}\n") 65 | 66 | with open(c, "w") as f: 67 | f.write("#include \n") 68 | f.write('int main() {\n int x; scanf("%d", &x);\n ') 69 | f.write('\n '.join([f'printf("%d\\n", {x});' for x in testcases])) 70 | f.write("\n}\n") 71 | 72 | data = (str(random.randint(-10000, 10000)) + "\n").encode('utf-8') 73 | expect_out = run_actual(c, input=data).stdout.decode('utf-8') 74 | actual_out = run(sy, input=data).stdout.decode('utf-8') 75 | if actual_out != expect_out: 76 | print(f"Error! 
Current error count: {error_cnt}") 77 | with open(f"temp/{error_cnt}_expected.txt", "w") as f: 78 | f.write(expect_out) 79 | with open(f"temp/{error_cnt}_actual.txt", "w") as f: 80 | f.write(actual_out) 81 | with open(f"temp/{error_cnt}_program.txt", "w") as f: 82 | f.write('\n'.join(testcases)) 83 | error_cnt += 1 84 | 85 | with tempfile.TemporaryDirectory() as dir: 86 | for i in range(0, 20): 87 | fuzz_arithmetic_fold(dir) 88 | -------------------------------------------------------------------------------- /src/utils/smt/BvExpr.h: -------------------------------------------------------------------------------- 1 | #ifndef BVEXPR_H 2 | #define BVEXPR_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace smt { 10 | 11 | #define TYPES \ 12 | X(Var) X(Const) X(And) X(Or) X(Xor) X(Not) X(Add) X(Eq) X(Ne) X(Mul) X(Sub) \ 13 | X(Lsh) X(Rsh) X(Div) X(Mod) X(Ite) X(Hole) X(Le) X(Lt) X(Abs) X(Minus) X(MulMod) \ 14 | X(Extr) 15 | 16 | class BvExpr { 17 | public: 18 | #define X(x) x, 19 | enum Type { 20 | TYPES 21 | } ty; 22 | #undef X 23 | 24 | #define X(x) #x, 25 | static constexpr const char *names[] = { 26 | TYPES 27 | }; 28 | #undef X 29 | 30 | BvExpr *l = nullptr, *r = nullptr, *cond = nullptr; 31 | int vi = 0; 32 | std::string name; 33 | 34 | BvExpr(Type ty): ty(ty) {} 35 | BvExpr(Type ty, int vi): ty(ty), vi(vi) {} 36 | BvExpr(Type ty, BvExpr *l): ty(ty), l(l) {} 37 | BvExpr(Type ty, BvExpr *l, int vi): ty(ty), l(l), vi(vi) {} 38 | BvExpr(Type ty, BvExpr *l, BvExpr *r): ty(ty), l(l), r(r) {} 39 | BvExpr(Type ty, BvExpr *cond, BvExpr *l, BvExpr *r): ty(ty), l(l), r(r), cond(cond) {} 40 | BvExpr(Type ty, const std::string &name): ty(ty), name(name) {} 41 | BvExpr(Type ty, int vi, const std::string &name, BvExpr *cond, BvExpr *l, BvExpr *r): 42 | ty(ty), l(l), r(r), cond(cond), vi(vi), name(name) {} 43 | 44 | void dump(std::ostream &os = std::cerr); 45 | }; 46 | 47 | inline std::ostream &operator<<(std::ostream &os, BvExpr *expr) { 48 | 
expr->dump(os); 49 | return os; 50 | } 51 | 52 | #undef TYPES 53 | 54 | class BvExprContext { 55 | struct Eq { 56 | bool operator()(BvExpr *a, BvExpr *b) const { 57 | return a->ty == b->ty && a->l == b->l && a->r == b->r && a->cond == b->cond && a->name == b->name && a->vi == b->vi; 58 | } 59 | }; 60 | 61 | struct Hash { 62 | // From boost::hash_combine. 63 | static void hash_combine(size_t &a, size_t b) { 64 | a ^= b + 0x9e3779b9 + (a << 6) + (a >> 2); 65 | } 66 | 67 | size_t operator()(BvExpr *a) const { 68 | size_t result = a->ty; 69 | hash_combine(result, a->vi); 70 | if (a->l) 71 | hash_combine(result, (uintptr_t) (a->l)); 72 | if (a->r) 73 | hash_combine(result, (uintptr_t) (a->r)); 74 | if (a->cond) 75 | hash_combine(result, (uintptr_t) (a->cond)); 76 | if (a->ty == BvExpr::Var) 77 | hash_combine(result, std::hash()(a->name)); 78 | return result; 79 | } 80 | }; 81 | 82 | std::unordered_set set; 83 | public: 84 | template 85 | BvExpr *create(BvExpr::Type ty, Args... args) { 86 | BvExpr *p = new BvExpr(ty, args...); 87 | if (auto it = set.find(p); it != set.end()) { 88 | delete p; 89 | return *it; 90 | } 91 | set.insert(p); 92 | return p; 93 | } 94 | 95 | ~BvExprContext() { 96 | for (auto x : set) 97 | delete x; 98 | } 99 | }; 100 | 101 | } 102 | 103 | #endif 104 | -------------------------------------------------------------------------------- /src/arm/ArmPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef ARM_PASSES_H 2 | #define ARM_PASSES_H 3 | 4 | #include "../opt/Pass.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Ops.h" 7 | #include "../codegen/Attrs.h" 8 | #include "ArmOps.h" 9 | #include "ArmAttrs.h" 10 | 11 | namespace sys::arm { 12 | 13 | class Lower : public Pass { 14 | public: 15 | Lower(ModuleOp *module): Pass(module) {} 16 | 17 | std::string name() override { return "arm-lower"; }; 18 | std::map stats() override { return {}; }; 19 | void run() override; 20 | }; 21 | 22 | class 
StrengthReduct : public Pass { 23 | int convertedTotal = 0; 24 | 25 | int runImpl(); 26 | public: 27 | StrengthReduct(ModuleOp *module): Pass(module) {} 28 | 29 | std::string name() override { return "arm-strength-reduct"; }; 30 | std::map stats() override; 31 | void run() override; 32 | }; 33 | 34 | class InstCombine : public Pass { 35 | int combined = 0; 36 | public: 37 | InstCombine(ModuleOp *module): Pass(module) {} 38 | 39 | std::string name() override { return "arm-inst-combine"; }; 40 | std::map stats() override; 41 | void run() override; 42 | }; 43 | 44 | // The only difference with opt/DCE is that `isImpure` behaves differently. 45 | class ArmDCE : public Pass { 46 | std::vector removeable; 47 | int elim = 0; 48 | 49 | bool isImpure(Op *op); 50 | void markImpure(Region *region); 51 | void runOnRegion(Region *region); 52 | public: 53 | ArmDCE(ModuleOp *module): Pass(module) {} 54 | 55 | std::string name() override { return "arm-dce"; }; 56 | std::map stats() override; 57 | void run() override; 58 | }; 59 | 60 | class RegAlloc : public Pass { 61 | int spilled = 0; 62 | int convertedTotal = 0; 63 | 64 | std::map> usedRegisters; 65 | std::map fnMap; 66 | 67 | void runImpl(Region *region, bool isLeaf); 68 | void proEpilogue(FuncOp *funcOp, bool isLeaf); 69 | int latePeephole(Op *funcOp); 70 | void tidyup(Region *region); 71 | public: 72 | 73 | RegAlloc(ModuleOp *module): Pass(module) {} 74 | 75 | std::string name() override { return "arm-regalloc"; }; 76 | std::map stats() override; 77 | void run() override; 78 | }; 79 | 80 | class LateLegalize : public Pass { 81 | public: 82 | LateLegalize(ModuleOp *module): Pass(module) {} 83 | 84 | std::string name() override { return "arm-late-legalize"; }; 85 | std::map stats() override { return {}; } 86 | void run() override; 87 | }; 88 | 89 | // Dumps the output. 
90 | class Dump : public Pass { 91 | std::string out; 92 | 93 | void dump(std::ostream &os); 94 | void dumpBody(Region *region, std::ostream &os); 95 | void dumpOp(Op *op, std::ostream &os); 96 | public: 97 | Dump(ModuleOp *module, const std::string &out): Pass(module), out(out) {} 98 | 99 | std::string name() override { return "arm-dump"; }; 100 | std::map stats() override { return {}; } 101 | void run() override; 102 | }; 103 | 104 | }; 105 | 106 | #endif 107 | -------------------------------------------------------------------------------- /src/opt/SimplifyCFG.cpp: -------------------------------------------------------------------------------- 1 | #include "CleanupPasses.h" 2 | 3 | using namespace sys; 4 | 5 | std::map SimplifyCFG::stats() { 6 | return { 7 | { "inlined-blocks", inlined }, 8 | }; 9 | } 10 | 11 | /* Repeatedly merges a block `succ` into `bb` when `bb` has `succ` as its only successor and `succ` has `bb` as its only predecessor. After each merge the scan restarts (the block list changed), until no merge applies. */ void SimplifyCFG::runImpl(Region *region) { 12 | region->updatePreds(); 13 | bool changed; 14 | do { 15 | changed = false; 16 | const auto &bbs = region->getBlocks(); 17 | for (auto bb : bbs) { 18 | if (bb->succs.size() != 1) 19 | continue; 20 | 21 | auto succ = *bb->succs.begin(); 22 | /* A block is never merged into itself (an isolated self-loop), and the region's entry block is never erased. */ if (succ == bb || succ == region->getFirstBlock() || succ->preds.size() != 1) 23 | continue; 24 | 25 | // Now we can safely inline `succ` into `bb`. 26 | // If `succ` has phis, each must be single-operand (its only predecessor is `bb`). Inline them. 27 | auto phis = succ->getPhis(); 28 | for (auto phi : phis) { 29 | auto def = phi->getOperand().defining; 30 | phi->replaceAllUsesWith(def); 31 | phi->erase(); 32 | } 33 | 34 | // Remove the jump to `succ`. 35 | bb->getLastOp()->erase(); 36 | 37 | // Then move all instructions in `succ` to `bb`. 38 | auto ops = succ->getOps(); 39 | for (auto op : ops) 40 | op->moveToEnd(bb); 41 | 42 | // All successors of `succ` now have pred `bb`. 43 | // `bb` also regards them as successors. 44 | for (auto s : succ->succs) { 45 | s->preds.erase(succ); 46 | s->preds.insert(bb); 47 | bb->succs.insert(s); 48 | } 49 | // Don't forget to remove `succ` from the successors of `bb`.
50 | bb->succs.erase(succ); 51 | 52 | // The phis at the beginning of some successor need to refer to `bb`. 53 | for (auto s : succ->succs) { 54 | auto phis = s->getPhis(); 55 | for (auto phi : phis) { 56 | const auto &attrs = phi->getAttrs(); 57 | 58 | for (size_t i = 0; i < attrs.size(); i++) { 59 | auto &from = FROM(attrs[i]); 60 | if (from == succ) 61 | from = bb; 62 | } 63 | } 64 | } 65 | 66 | succ->forceErase(); 67 | inlined++; 68 | changed = true; 69 | break; 70 | } 71 | } while (changed); 72 | } 73 | 74 | /* Runs the straight-line block merge on every function. The if / if-not fold described below is not implemented yet. */ void SimplifyCFG::run() { 75 | auto funcs = collectFuncs(); 76 | for (auto func : funcs) 77 | runImpl(func->getRegion()); 78 | 79 | // TODO (not implemented yet): simplify the following construction: 80 | // if (x) ...; if (!x) ... 81 | // becomes a single if-else. 82 | // 83 | // In IR, it should be like 84 | // 85 | // bb1: 86 | // br %1, , 87 | // bb2: 88 | // goto 89 | // exit: 90 | // br (not %1), , 91 | // bb3: 92 | // goto 93 | // exit2: 94 | // ... 95 | // 96 | // As multiple linear goto's have been combined, there's less risk of missed optimization. 97 | // To fold it: 98 | // 99 | // bb1: 100 | // br %1, , 101 | // bb2: 102 | // goto 103 | // exit (now dead): 104 | // br (not %1), , 105 | // bb3: 106 | // goto 107 | // exit2: 108 | // ... 109 | // Note that `exit` must only contain a NotOp and a branch.
110 | } 111 | -------------------------------------------------------------------------------- /src/codegen/Ops.h: -------------------------------------------------------------------------------- 1 | #ifndef OPS_H 2 | #define OPS_H 3 | 4 | #include "OpBase.h" 5 | 6 | #define OPBASE(ValueTy, Ty) \ 7 | class Ty : public OpImpl { \ 8 | public: \ 9 | explicit Ty(const std::vector &values): OpImpl(ValueTy, values) { \ 10 | setName(#Ty); \ 11 | } \ 12 | Ty(): OpImpl(ValueTy, {}) { \ 13 | setName(#Ty); \ 14 | } \ 15 | explicit Ty(const std::vector &attrs): OpImpl(ValueTy, {}, attrs) { \ 16 | setName(#Ty); \ 17 | } \ 18 | Ty(const std::vector &values, const std::vector &attrs): OpImpl(ValueTy, values, attrs) { \ 19 | setName(#Ty); \ 20 | } \ 21 | } 22 | 23 | // Ops whose result type must be passed explicitly to the constructor. 24 | #define OPE(Ty) \ 25 | class Ty : public OpImpl { \ 26 | public: \ 27 | Ty(Value::Type resultTy, const std::vector &values): OpImpl(resultTy, values) { \ 28 | setName(#Ty); \ 29 | } \ 30 | explicit Ty(Value::Type resultTy): OpImpl(resultTy, {}) { \ 31 | setName(#Ty); \ 32 | } \ 33 | Ty(Value::Type resultTy, const std::vector &attrs): OpImpl(resultTy, {}, attrs) { \ 34 | setName(#Ty); \ 35 | } \ 36 | Ty(Value::Type resultTy, const std::vector &values, const std::vector &attrs): OpImpl(resultTy, values, attrs) { \ 37 | setName(#Ty); \ 38 | } \ 39 | } 40 | 41 | #define OP(Ty) OPBASE(Value::i32, Ty) 42 | #define OPF(Ty) OPBASE(Value::f32, Ty) 43 | #define OPL(Ty) OPBASE(Value::i64, Ty) 44 | #define OPV(Ty) OPBASE(Value::i128, Ty) 45 | 46 | namespace sys { 47 | 48 | OP(ModuleOp); 49 | OP(AddIOp); 50 | OP(SubIOp); 51 | OP(MulIOp); 52 | OP(DivIOp); 53 | OP(ModIOp); 54 | OP(AndIOp); 55 | OP(OrIOp); 56 | OP(XorIOp); 57 | OPF(AddFOp); 58 | OPF(SubFOp); 59 | OPF(MulFOp); 60 | OPF(DivFOp); 61 | OPF(ModFOp); 62 | OPL(AddLOp); 63 | OPL(SubLOp); 64 | OPL(MulLOp); 65 | OPL(DivLOp); 66 | OPL(ModLOp); 67 | OP(EqOp); 68 | OP(NeOp); 69 | OP(LtOp); 70 | OP(LeOp); 71 | OP(EqFOp); 72 |
/* Each line below declares one op class; the macro chosen fixes its default result type (OP: i32, OPF: f32, OPL: i64, OPV: i128; OPE: given at construction). */ OP(NeFOp); 73 | OP(LtFOp); 74 | OP(LeFOp); 75 | OP(FuncOp); 76 | OP(IntOp); 77 | OPF(FloatOp); 78 | /* Allocas produce an i64 result (OPL). */ OPL(AllocaOp); 79 | /* OPE: the result type is passed to the constructor. */ OPE(GetArgOp); 80 | OP(StoreOp); // Operand order: value, dst 81 | OPE(LoadOp); 82 | OP(ReturnOp); 83 | OP(IfOp); 84 | OP(WhileOp); 85 | OP(ForOp); 86 | OP(ProceedOp); 87 | OP(GotoOp); // Jumps unconditionally. 88 | OP(BranchOp); // Branches according to the only operand. 89 | OP(GlobalOp); 90 | OP(GetGlobalOp); 91 | OPE(CallOp); 92 | OP(PhiOp); 93 | OP(F2IOp); 94 | OPF(I2FOp); 95 | OP(MinusOp); // for input x, returns -x. Don't confuse with SubI/SubF. 96 | OPF(MinusFOp); 97 | OP(NotOp); 98 | OP(LShiftOp); 99 | OPL(LShiftLOp); 100 | OP(RShiftOp); 101 | OPL(RShiftLOp); 102 | /* Presumably the high half of a signed / unsigned 32-bit multiply — confirm. */ OP(MulshOp); 103 | OP(MuluhOp); 104 | OP(SetNotZeroOp); 105 | OP(BreakOp); 106 | OP(ContinueOp); 107 | OP(SelectOp); 108 | 109 | // ====== Vectorized ====== 110 | 111 | /* OPV: result type is Value::i128. */ OPV(AddVOp); 112 | OPV(SubVOp); 113 | OPV(MulVOp); 114 | 115 | OPV(BroadcastOp); 116 | 117 | // ====== Multi-threaded ====== 118 | 119 | OP(CloneOp); 120 | OP(JoinOp); 121 | OP(WakeOp); 122 | 123 | // vectorized load/store is detected by size.
124 | 125 | } 126 | 127 | #undef OP 128 | #define DEF(i) getOperand(i).defining 129 | 130 | #endif 131 | -------------------------------------------------------------------------------- /src/utils/smt/Simplify.cpp: -------------------------------------------------------------------------------- 1 | #include "BvMatcher.h" 2 | #include "SMT.h" 3 | 4 | using namespace smt; 5 | 6 | namespace { 7 | 8 | BvRule rules[] = { 9 | // Add 10 | "(change (add 'a 'b) (!add 'a 'b))", 11 | "(change (add x 0) x)", 12 | "(change (add 'a x) (add x 'a))", 13 | "(change (add x x) (lsh x 1))", 14 | 15 | // Sub 16 | "(change (sub 'a 'b) (!sub 'a 'b))", 17 | "(change (sub x 0) x)", 18 | "(change (sub x (minus y)) (add x y))", 19 | 20 | // Mul 21 | "(change (mul 'a 'b) (!mul 'a 'b))", 22 | 23 | // Div 24 | "(change (div 'a 'b) (!div 'a 'b))", 25 | 26 | // And 27 | "(change (and 'a 'b) (!and 'a 'b))", 28 | "(change (and x (eq x 'a)) (!only-if (!ne 'a 0) (eq x 'a)))", 29 | 30 | // Mod 31 | "(change (mod 'a 'b) (!mod 'a 'b))", 32 | 33 | // Lsh 34 | "(change (lsh 'a 'b) (!lsh 'a 'b))", 35 | 36 | // Eq 37 | "(change (eq 'a 'b) (!eq 'a 'b))", 38 | "(change (eq x 0) (not x))", 39 | 40 | // Not 41 | "(change (not 'a) (!not 'a))", 42 | "(change (not (eq x y)) (ne x y))", 43 | 44 | // Lt 45 | "(change (lt 'a 'b) (!lt 'a 'b))", 46 | 47 | // Ite 48 | "(change (ite (not x) y z) (ite x z y))", 49 | "(change (ite 0 y z) z)", 50 | "(change (ite 'a y z) (!only-if (!ne 'a 0) y))", 51 | 52 | // Mulmod 53 | "(change (mulmod 'a 'b 'c) (!mulmod 'a 'b 'c))", 54 | "(change (mulmod x y 1) 0)", 55 | "(change (mulmod x y -1) 0)", 56 | }; 57 | 58 | BvExpr *rewriteRoot(BvExpr *expr, BvExprContext &ctx) { 59 | // x % 2 == x - (x[31] + x) & (-2) 60 | if (expr->ty == BvExpr::Mod && expr->r->ty == BvExpr::Const && expr->r->vi == 2) { 61 | auto _1 = expr->l; 62 | auto _2 = ctx.create(BvExpr::Extr, _1, 31); 63 | auto _3 = ctx.create(BvExpr::Add, _1, _2); 64 | auto _4 = ctx.create(BvExpr::Const, -2); 65 | auto _5 = 
ctx.create(BvExpr::And, _3, _4); 66 | auto _6 = ctx.create(BvExpr::Sub, _1, _5); 67 | return _6; 68 | } 69 | 70 | // extr of constant 71 | if (expr->ty == BvExpr::Extr && expr->l->ty == BvExpr::Const) { 72 | return ctx.create(BvExpr::Const, (((unsigned) expr->l->vi) >> expr->vi) & 1); 73 | } 74 | 75 | return expr; 76 | } 77 | 78 | BvExpr *rewrite(BvExpr *expr, BvExprContext &ctx) { 79 | if (!expr) 80 | return nullptr; 81 | 82 | BvExpr* newcond = rewrite(expr->cond, ctx); 83 | BvExpr* newl = rewrite(expr->l, ctx); 84 | BvExpr* newr = rewrite(expr->r, ctx); 85 | 86 | BvExpr* updated = ctx.create(expr->ty, expr->vi, expr->name, newcond, newl, newr); 87 | return rewriteRoot(updated, ctx); 88 | } 89 | 90 | } 91 | 92 | [[nodiscard]] 93 | BvExpr *smt::simplify(BvExpr *expr, BvExprContext &ctx) { 94 | BvExpr *result = expr; 95 | bool changed; 96 | for (auto &rule : rules) 97 | rule.ctx = &ctx; 98 | do { 99 | changed = false; 100 | for (auto &rule : rules) { 101 | if (auto rewritten = rule.rewrite(expr); rewritten != expr) 102 | changed = true, expr = rewritten; 103 | if (auto rewritten = rewrite(expr, ctx); rewritten != expr) 104 | changed = true, expr = rewritten; 105 | } 106 | result = expr; 107 | } while (changed); 108 | 109 | std::cerr << "simplified: " << expr << "\n"; 110 | return result; 111 | } 112 | -------------------------------------------------------------------------------- /src/parse/Parser.h: -------------------------------------------------------------------------------- 1 | #ifndef PARSER_H 2 | #define PARSER_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "ASTNode.h" 10 | #include "Lexer.h" 11 | #include "TypeContext.h" 12 | 13 | namespace sys { 14 | 15 | // A compile-time integer constant, used when early-folding array dimensions. 16 | // Only integer is allowed, as the language specifies. 
17 | /* Holds a pointer to int or float data (selected by isFloat) together with its dimensions `dims`. There is no destructor: the memory is freed only by an explicit release(). */ class ConstValue { 18 | union { 19 | int *vi; 20 | float *vf; 21 | }; 22 | std::vector dims; 23 | public: 24 | bool isFloat; 25 | 26 | /* Initializes the pointer and isFloat so a default-constructed value (null int data, no dims) is never read indeterminately. */ ConstValue(): vi(nullptr), isFloat(false) {} 27 | ConstValue(int *vi, const std::vector &dims): vi(vi), dims(dims), isFloat(false) {} 28 | ConstValue(float *vf, const std::vector &dims): vf(vf), dims(dims), isFloat(true) {} 29 | 30 | ConstValue operator[](int i); 31 | int getInt(); 32 | float getFloat(); 33 | const auto &getDims() { return dims; } 34 | 35 | int size(); 36 | int stride(); 37 | 38 | // Returns a newly allocated copy of the data; the original buffer is not released. 39 | int *getRaw(); 40 | float *getRawFloat(); 41 | void *getRawRef() { return vi; } 42 | 43 | // Note that we don't use a destructor, 44 | // because the Parser object will live till end of program. 45 | void release(); 46 | }; 47 | 48 | class Parser { 49 | using SymbolTable = std::map; 50 | SymbolTable symbols; 51 | 52 | /* RAII guard: copies parser.symbols on construction and writes the copy back on destruction, so symbols declared inside the scope are dropped when it ends. */ class SemanticScope { 53 | Parser &parser; 54 | SymbolTable symbols; 55 | public: 56 | SemanticScope(Parser &parser): parser(parser), symbols(parser.symbols) {}; 57 | ~SemanticScope() { parser.symbols = symbols; } 58 | }; 59 | 60 | std::vector tokens; 61 | size_t loc; 62 | TypeContext &ctx; 63 | 64 | std::string currentFunc; 65 | 66 | Token last(); 67 | Token peek(); 68 | Token consume(); 69 | 70 | bool peek(Token::Type t); 71 | Token expect(Token::Type t); 72 | 73 | // Prints tokens in range [loc-5, loc+5]. For debugging purposes. 74 | void printSurrounding(); 75 | 76 | template 77 | /* True if peek(t) holds for any of the given token types. */ bool peek(Token::Type t, Rest... ts) { 78 | return peek(t) || peek(ts...); 79 | } 80 | 81 | template 82 | /* Like peek(ts...), but also consumes the token (advances loc) on a match. */ bool test(T... ts) { 83 | if (peek(ts...)) { 84 | loc++; 85 | return true; 86 | } 87 | return false; 88 | } 89 | 90 | // Parses only void, int and float. 91 | Type *parseSimpleType(); 92 | 93 | // Const-fold the node.
94 | ConstValue earlyFold(ASTNode *node); 95 | 96 | ASTNode *primary(); 97 | ASTNode *unary(); 98 | ASTNode *mul(); 99 | ASTNode *add(); 100 | ASTNode *rel(); 101 | ASTNode *eq(); 102 | ASTNode *land(); 103 | ASTNode *lor(); 104 | ASTNode *expr(); 105 | ASTNode *stmt(); 106 | BlockNode *block(); 107 | TransparentBlockNode *varDecl(bool global); 108 | FnDeclNode *fnDecl(); 109 | BlockNode *compUnit(); 110 | 111 | // Global array is guaranteed to be constexpr, so we return a list of int/floats. 112 | // Local array isn't; so we return a list of ASTNodes. 113 | void *getArrayInit(const std::vector &dims, bool expectFloat, bool doFold); 114 | 115 | public: 116 | Parser(const std::string &input, TypeContext &ctx); 117 | ASTNode *parse(); 118 | }; 119 | 120 | } 121 | 122 | #endif 123 | -------------------------------------------------------------------------------- /src/rv/Regs.h: -------------------------------------------------------------------------------- 1 | #ifndef REGS_H 2 | #define REGS_H 3 | 4 | #include "RvAttrs.h" 5 | 6 | namespace sys::rv { 7 | 8 | // We use dedicated registers as the "spill" register, for simplicity. 9 | const Reg spillReg = Reg::s10; 10 | const Reg spillReg2 = Reg::s11; 11 | const Reg fspillReg = Reg::fs10; 12 | const Reg fspillReg2 = Reg::fs11; 13 | 14 | // Order for leaf functions. Prioritize temporaries. 15 | const Reg leafOrder[] = { 16 | Reg::a0, Reg::a1, Reg::a2, Reg::a3, 17 | Reg::a4, Reg::a5, Reg::a6, Reg::a7, 18 | 19 | Reg::t0, Reg::t1, Reg::t2, Reg::t3, 20 | Reg::t4, Reg::t5, Reg::t6, 21 | 22 | Reg::s0, Reg::s1, Reg::s2, Reg::s3, 23 | Reg::s4, Reg::s5, Reg::s6, Reg::s7, 24 | Reg::s8, Reg::s9, 25 | }; 26 | // Order for non-leaf functions.
27 | /* Same registers as leafOrder, plus ra, which is allocatable here. */ const Reg normalOrder[] = { 28 | Reg::a0, Reg::a1, Reg::a2, Reg::a3, 29 | Reg::a4, Reg::a5, Reg::a6, Reg::a7, 30 | Reg::ra, 31 | 32 | Reg::t0, Reg::t1, Reg::t2, Reg::t3, 33 | Reg::t4, Reg::t5, Reg::t6, 34 | 35 | Reg::s0, Reg::s1, Reg::s2, Reg::s3, 36 | Reg::s4, Reg::s5, Reg::s6, Reg::s7, 37 | Reg::s8, Reg::s9, 38 | }; 39 | /* Integer argument registers, in argument order. */ const Reg argRegs[] = { 40 | Reg::a0, Reg::a1, Reg::a2, Reg::a3, 41 | Reg::a4, Reg::a5, Reg::a6, Reg::a7, 42 | }; 43 | /* Registers a call may clobber; matches the RISC-V psABI caller-saved set (t*, a*, ra, ft*, fa*). */ const std::set callerSaved = { 44 | Reg::t0, Reg::t1, Reg::t2, Reg::t3, 45 | Reg::t4, Reg::t5, Reg::t6, 46 | 47 | Reg::a0, Reg::a1, Reg::a2, Reg::a3, 48 | Reg::a4, Reg::a5, Reg::a6, Reg::a7, 49 | Reg::ra, 50 | 51 | Reg::ft0, Reg::ft1, Reg::ft2, Reg::ft3, 52 | Reg::ft4, Reg::ft5, Reg::ft6, Reg::ft7, 53 | Reg::ft8, Reg::ft9, Reg::ft10, Reg::ft11, 54 | 55 | Reg::fa0, Reg::fa1, Reg::fa2, Reg::fa3, 56 | Reg::fa4, Reg::fa5, Reg::fa6, Reg::fa7, 57 | }; 58 | 59 | /* Registers a callee must preserve (s0-s11, fs0-fs11). */ const std::set calleeSaved = { 60 | Reg::s0, Reg::s1, Reg::s2, Reg::s3, 61 | Reg::s4, Reg::s5, Reg::s6, Reg::s7, 62 | Reg::s8, Reg::s9, Reg::s10, Reg::s11, 63 | 64 | Reg::fs0, Reg::fs1, Reg::fs2, Reg::fs3, 65 | Reg::fs4, Reg::fs5, Reg::fs6, Reg::fs7, 66 | Reg::fs8, Reg::fs9, Reg::fs10, Reg::fs11, 67 | }; 68 | constexpr int leafRegCnt = sizeof(leafOrder) / sizeof(Reg); 69 | constexpr int normalRegCnt = sizeof(normalOrder) / sizeof(Reg); 70 | 71 | /* FP allocation order for leaf functions: argument registers, then temporaries, then saved registers (fs10/fs11 are left out; they are the FP spill registers). */ const Reg leafOrderf[] = { 72 | Reg::fa0, Reg::fa1, Reg::fa2, Reg::fa3, 73 | Reg::fa4, Reg::fa5, Reg::fa6, Reg::fa7, 74 | 75 | Reg::ft0, Reg::ft1, Reg::ft2, Reg::ft3, 76 | Reg::ft4, Reg::ft5, Reg::ft6, Reg::ft7, 77 | Reg::ft8, Reg::ft9, Reg::ft10, Reg::ft11, 78 | 79 | Reg::fs0, Reg::fs1, Reg::fs2, Reg::fs3, 80 | Reg::fs4, Reg::fs5, Reg::fs6, Reg::fs7, 81 | Reg::fs8, Reg::fs9, 82 | }; 83 | // Order for non-leaf functions.
84 | const Reg normalOrderf[] = { 85 | Reg::ft0, Reg::ft1, Reg::ft2, Reg::ft3, 86 | Reg::ft4, Reg::ft5, Reg::ft6, Reg::ft7, 87 | Reg::ft8, Reg::ft9, Reg::ft10, Reg::ft11, 88 | 89 | Reg::fa0, Reg::fa1, Reg::fa2, Reg::fa3, 90 | Reg::fa4, Reg::fa5, Reg::fa6, Reg::fa7, 91 | 92 | Reg::fs0, Reg::fs1, Reg::fs2, Reg::fs3, 93 | Reg::fs4, Reg::fs5, Reg::fs6, Reg::fs7, 94 | Reg::fs8, Reg::fs9, 95 | }; 96 | /* FP argument registers, in argument order. */ const Reg fargRegs[] = { 97 | Reg::fa0, Reg::fa1, Reg::fa2, Reg::fa3, 98 | Reg::fa4, Reg::fa5, Reg::fa6, Reg::fa7, 99 | }; 100 | /* Derived from the arrays (like leafRegCnt/normalRegCnt) instead of hard-coded, so they cannot drift. */ constexpr int leafRegCntf = sizeof(leafOrderf) / sizeof(Reg); 101 | constexpr int normalRegCntf = sizeof(normalOrderf) / sizeof(Reg); 102 | 103 | /* Whether a value of type `ty` is kept in an FP register (only f32 is). */ inline bool fpreg(Value::Type ty) { 104 | return ty == Value::f32; 105 | } 106 | 107 | } 108 | 109 | #endif 110 | -------------------------------------------------------------------------------- /src/opt/Pass.cpp: -------------------------------------------------------------------------------- 1 | #include "Pass.h" 2 | #include "../codegen/Attrs.h" 3 | 4 | using namespace sys; 5 | 6 | /* Returns whether `name` is in the fixed list of runtime-library function names below. */ bool sys::isExtern(const std::string &name) { 7 | static std::set externs = { 8 | "getint", 9 | "getch", 10 | "getfloat", 11 | "getarray", 12 | "getfarray", 13 | "putint", 14 | "putch", 15 | "putfloat", 16 | "putarray", 17 | "putfarray", 18 | "_sysy_starttime", 19 | "_sysy_stoptime", 20 | "starttime", 21 | "stoptime", 22 | }; 23 | return externs.count(name); 24 | } 25 | 26 | /* Maps the name of each function op in the module's top-level block to that op. */ std::map Pass::getFunctionMap() { 27 | std::map funcs; 28 | 29 | auto region = module->getRegion(); 30 | auto block = region->getFirstBlock(); 31 | for (auto op : block->getOps()) { 32 | if (auto func = dyn_cast(op)) 33 | funcs[NAME(op)] = func; 34 | } 35 | 36 | return funcs; 37 | } 38 | 39 | /* Maps the name of each global op in the module's top-level block to that op. */ std::map Pass::getGlobalMap() { 40 | std::map globals; 41 | 42 | auto region = module->getRegion(); 43 | auto block = region->getFirstBlock(); 44 | for (auto op : block->getOps()) { 45 | if (auto glob = dyn_cast(op)) 46 | globals[NAME(op)] = glob; 47 | } 48 | 49 | return globals; 50 | } 51 | 52 | std::vector Pass::collectFuncs() {
53 | std::vector result; 54 | auto toplevel = module->getRegion()->getFirstBlock()->getOps(); 55 | for (auto op : toplevel) { 56 | if (auto fn = dyn_cast(op)) 57 | result.push_back(fn); 58 | } 59 | return result; 60 | } 61 | 62 | std::vector Pass::collectGlobals() { 63 | std::vector result; 64 | auto toplevel = module->getRegion()->getFirstBlock()->getOps(); 65 | for (auto op : toplevel) { 66 | if (auto glob = dyn_cast(op)) 67 | result.push_back(glob); 68 | } 69 | return result; 70 | } 71 | 72 | /* Recomputes dominators and returns the dominator tree as a map from each idom to the blocks it immediately dominates. */ DomTree Pass::getDomTree(Region *region) { 73 | region->updateDoms(); 74 | 75 | DomTree tree; 76 | for (auto bb : region->getBlocks()) { 77 | if (auto idom = bb->getIdom()) 78 | tree[idom].push_back(bb); 79 | } 80 | return tree; 81 | } 82 | 83 | /* Calls Op::release(), then re-types phis: a phi becomes f32 as soon as any operand is f32. */ void Pass::cleanup() { 84 | Op::release(); 85 | 86 | // Put phi's types right. 87 | runRewriter([&](PhiOp *op) { 88 | if (op->getResultType() == Value::f32) 89 | return false; 90 | 91 | for (auto operand : op->getOperands()) { 92 | if (operand.defining->getResultType() == Value::f32) { 93 | op->setResultType(Value::f32); 94 | return true; 95 | } 96 | } 97 | 98 | return false; 99 | }); 100 | } 101 | 102 | /* Returns the first non-alloca op of the entry block; if that candidate is the block's last op, returns the first op of the next block instead. NOTE(review): entry->nextBlock() is dereferenced without a check, which presumably fails for a single-block function — confirm. */ Op *Pass::nonalloca(Region *region) { 103 | auto entry = region->getFirstBlock(); 104 | Op *nonalloca = entry->getFirstOp(); 105 | while (!nonalloca->atBack()) { 106 | if (isa(nonalloca)) 107 | nonalloca = nonalloca->nextOp(); 108 | else break; 109 | } 110 | if (nonalloca->atBack()) 111 | nonalloca = entry->nextBlock()->getFirstOp(); 112 | return nonalloca; 113 | } 114 | 115 | /* Returns the first non-phi op of `bb`, or its last op if every op is a phi. */ Op *Pass::nonphi(BasicBlock *bb) { 116 | Op *nonphi = bb->getFirstOp(); 117 | while (!nonphi->atBack()) { 118 | if (isa(nonphi)) 119 | nonphi = nonphi->nextOp(); 120 | else break; 121 | } 122 | // A basic block should have at least one op, so it's safe.
123 | return nonphi; 124 | } 125 | -------------------------------------------------------------------------------- /src/arm/Regs.h: -------------------------------------------------------------------------------- 1 | #ifndef ARM_REGS_H 2 | #define ARM_REGS_H 3 | 4 | #include "ArmAttrs.h" 5 | 6 | namespace sys::arm { 7 | 8 | 9 | // Argument registers (floating-point, then integer), in argument order. 10 | static const Reg fargRegs[] = { 11 | Reg::v0, Reg::v1, Reg::v2, Reg::v3, 12 | Reg::v4, Reg::v5, Reg::v6, Reg::v7, 13 | }; 14 | static const Reg argRegs[] = { 15 | Reg::x0, Reg::x1, Reg::x2, Reg::x3, 16 | Reg::x4, Reg::x5, Reg::x6, Reg::x7, 17 | }; 18 | 19 | /* We use dedicated registers as the "spill" registers, for simplicity; none of them appears in the allocation orders below. NOTE(review): x29 is the AAPCS64 frame pointer and is not listed in calleeSaved — confirm it is preserved wherever it is clobbered. */ static const Reg spillReg = Reg::x28; 20 | static const Reg spillReg2 = Reg::x15; 21 | static const Reg spillReg3 = Reg::x29; 22 | static const Reg fspillReg = Reg::v31; 23 | static const Reg fspillReg2 = Reg::v15; 24 | static const Reg fspillReg3 = Reg::v30; 25 | 26 | // Order for leaf functions. Prioritize temporaries. 27 | static const Reg leafOrder[] = { 28 | Reg::x0, Reg::x1, Reg::x2, Reg::x3, 29 | Reg::x4, Reg::x5, Reg::x6, Reg::x7, 30 | 31 | Reg::x8, Reg::x9, Reg::x10, Reg::x11, 32 | Reg::x12, Reg::x13, Reg::x14, 33 | Reg::x16, Reg::x17, 34 | 35 | Reg::x19, Reg::x20, Reg::x21, Reg::x22, 36 | Reg::x23, Reg::x24, Reg::x25, Reg::x26, 37 | Reg::x27, 38 | }; 39 | // Order for non-leaf functions. 40 | static const Reg normalOrder[] = { 41 | Reg::x0, Reg::x1, Reg::x2, Reg::x3, 42 | Reg::x4, Reg::x5, Reg::x6, Reg::x7, 43 | 44 | Reg::x8, Reg::x9, Reg::x10, Reg::x11, 45 | Reg::x12, Reg::x13, Reg::x14, 46 | Reg::x16, Reg::x17, 47 | 48 | Reg::x19, Reg::x20, Reg::x21, Reg::x22, 49 | Reg::x23, Reg::x24, Reg::x25, Reg::x26, 50 | Reg::x27, 51 | }; 52 | 53 | // The same, but for floating point registers.
54 | /* FP allocation order for leaf functions; v15, v30 and v31 are left out (they are the FP spill registers). */ static const Reg leafOrderf[] = { 55 | Reg::v0, Reg::v1, Reg::v2, Reg::v3, 56 | Reg::v4, Reg::v5, Reg::v6, Reg::v7, 57 | 58 | Reg::v8, Reg::v9, Reg::v10, Reg::v11, 59 | Reg::v12, Reg::v13, Reg::v14, 60 | 61 | Reg::v16, Reg::v17, Reg::v18, 62 | Reg::v19, Reg::v20, Reg::v21, Reg::v22, 63 | Reg::v23, Reg::v24, Reg::v25, Reg::v26, 64 | Reg::v27, Reg::v28, Reg::v29, 65 | }; 66 | // Order for non-leaf functions. 67 | static const Reg normalOrderf[] = { 68 | Reg::v0, Reg::v1, Reg::v2, Reg::v3, 69 | Reg::v4, Reg::v5, Reg::v6, Reg::v7, 70 | 71 | Reg::v8, Reg::v9, Reg::v10, Reg::v11, 72 | Reg::v12, Reg::v13, Reg::v14, 73 | 74 | Reg::v16, Reg::v17, Reg::v18, 75 | Reg::v19, Reg::v20, Reg::v21, Reg::v22, 76 | Reg::v23, Reg::v24, Reg::v25, Reg::v26, 77 | Reg::v27, Reg::v28, Reg::v29, 78 | }; 79 | 80 | /* Registers treated as clobbered by calls. NOTE(review): AAPCS64 makes v8-v15 callee-saved (low 64 bits) and v16-v31 caller-saved, the opposite of the FP split here and in calleeSaved. This is presumably harmless between functions this compiler emits, but FP values kept in v16-v29 across a call into external code (e.g. the runtime library) may be clobbered — confirm before relying on it. */ static const std::set callerSaved = { 81 | Reg::x0, Reg::x1, Reg::x2, Reg::x3, 82 | Reg::x4, Reg::x5, Reg::x6, Reg::x7, 83 | 84 | Reg::x8, Reg::x9, Reg::x10, Reg::x11, 85 | Reg::x12, Reg::x13, Reg::x14, Reg::x15, 86 | Reg::x16, Reg::x17, 87 | 88 | Reg::v0, Reg::v1, Reg::v2, Reg::v3, 89 | Reg::v4, Reg::v5, Reg::v6, Reg::v7, 90 | 91 | Reg::v8, Reg::v9, Reg::v10, Reg::v11, 92 | Reg::v12, Reg::v13, Reg::v14, Reg::v15, 93 | }; 94 | 95 | /* Registers treated as callee-saved. */ static const std::set calleeSaved = { 96 | Reg::x19, Reg::x20, Reg::x21, Reg::x22, 97 | Reg::x23, Reg::x24, Reg::x25, Reg::x26, 98 | Reg::x27, Reg::x28, 99 | 100 | Reg::v16, Reg::v17, Reg::v18, 101 | Reg::v19, Reg::v20, Reg::v21, Reg::v22, 102 | Reg::v23, Reg::v24, Reg::v25, Reg::v26, 103 | Reg::v27, Reg::v28, Reg::v29, Reg::v30, 104 | }; 105 | 106 | /* Derived from the arrays instead of hard-coded, so they cannot drift from the allocation orders. */ constexpr int leafRegCnt = sizeof(leafOrder) / sizeof(Reg); 107 | constexpr int leafRegCntf = sizeof(leafOrderf) / sizeof(Reg); 108 | constexpr int normalRegCnt = sizeof(normalOrder) / sizeof(Reg); 109 | constexpr int normalRegCntf = sizeof(normalOrderf) / sizeof(Reg); 110 | 111 | } 112 | 113 | #endif 114 | -------------------------------------------------------------------------------- /src/opt/HoistConstArray.cpp: -------------------------------------------------------------------------------- 1 | #include "Passes.h" 2 |
#include 3 | 4 | using namespace sys; 5 | 6 | std::map HoistConstArray::stats() { 7 | return { 8 | { "hoisted-arrays", hoisted } 9 | }; 10 | } 11 | 12 | // Warning: experimental. Only StoreOps inside the alloca's own function are examined, so writes made by a callee that receives the array are not detected. 13 | void HoistConstArray::attemptHoist(Op *op) { 14 | // An alloca is deemed constant if we statically know that all its elements are stored to, 15 | // and are stored to only once, and all values are constants. 16 | // (That's too restrictive; perhaps this won't be much gain?) 17 | auto func = op->getParentOp(); 18 | auto stores = func->findAll(); 19 | auto size = SIZE(op); 20 | int elem = size / 4; 21 | std::vector value(elem); 22 | std::vector fvalue(elem); 23 | std::vector visited(elem); 24 | std::vector toErase; 25 | 26 | for (auto store : stores) { 27 | auto addr = store->DEF(1); 28 | // An unknown store. 29 | if (!addr->has()) 30 | return; 31 | 32 | auto alias = ALIAS(addr); 33 | if (!alias->location.count(op)) 34 | continue; 35 | 36 | // An unsure store. 37 | if (alias->location.size() > 1) 38 | return; 39 | 40 | // Also an unsure store: no offset, several offsets, or an unknown (-1) one. 41 | const auto &offsets = alias->location[op]; 42 | if (offsets.size() != 1 || offsets[0] == -1) 43 | return; 44 | 45 | // Now we know which element has been stored to. 46 | int offset = offsets[0]; 47 | /* Give up on misaligned or out-of-range offsets, and on a second store to the same element (each element must be stored to only once). */ if (offset < 0 || offset % 4 != 0 || offset / 4 >= elem || visited[offset / 4]) return; visited[offset / 4] = 1; 48 | 49 | // The stored value must be a constant too; otherwise the array isn't constant. 50 | auto def = store->DEF(0); 51 | if (isa(def)) { 52 | value[offset / 4] = V(def); 53 | toErase.push_back(store); 54 | } 55 | else if (isa(def)) { 56 | fvalue[offset / 4] = F(def); 57 | toErase.push_back(store); 58 | } else return; 59 | } 60 | 61 | for (int i = 0; i < size / 4; i++) { 62 | // Not every place has been stored. 63 | if (!visited[i]) 64 | return; 65 | } 66 | 67 | bool fp = op->has(); 68 | 69 | // Now safe to transform.
70 | Builder builder; 71 | builder.setToRegionStart(module->getRegion()); 72 | auto name = "__const_" + NAME(func) + "_" + std::to_string(hoisted++); 73 | auto global = builder.create({ new NameAttr(name), new SizeAttr(elem * 4) }); 74 | 75 | // Add init value. 76 | if (fp) { 77 | float *vf = new float[elem]; 78 | memcpy(vf, fvalue.data(), elem * sizeof(float)); 79 | global->add(vf, elem); 80 | } else { 81 | int *vi = new int[elem]; 82 | memcpy(vi, value.data(), elem * sizeof(int)); 83 | global->add(vi, elem); 84 | } 85 | 86 | // Replace the alloca with getglobal. 87 | auto entry = func->getRegion()->getFirstBlock(); 88 | // Don't disrupt the consecutive allocas at the front. 89 | Op *firstNonAlloca = entry->getFirstOp(); 90 | while (isa(firstNonAlloca)) 91 | firstNonAlloca = firstNonAlloca->nextOp(); 92 | 93 | builder.setBeforeOp(firstNonAlloca); 94 | auto getglobal = builder.create({ new NameAttr(name) }); 95 | op->replaceAllUsesWith(getglobal); 96 | op->erase(); 97 | 98 | // Erase redundant stores. 99 | for (auto x : toErase) 100 | x->erase(); 101 | } 102 | 103 | /* Tries to hoist every alloca in the module. */ void HoistConstArray::run() { 104 | auto allocas = module->findAll(); 105 | for (auto alloca : allocas) 106 | attemptHoist(alloca); 107 | } 108 | -------------------------------------------------------------------------------- /src/arm/PostIncr.cpp: -------------------------------------------------------------------------------- 1 | #include "ArmLoopPasses.h" 2 | 3 | using namespace sys::arm; 4 | using namespace sys; 5 | 6 | /* For a single-latch loop, finds the last zero-offset LdrW/StrW through an induction phi on the straight-line block chain ending at the latch, rebuilds it with the phi's constant increment as an IntAttr (presumably a post-indexed access), and redirects users of the separate increment to the phi. */ void PostIncr::runImpl(LoopInfo *info) { 7 | if (info->latches.size() > 1) 8 | return; 9 | 10 | auto header = info->header; 11 | auto latch = info->getLatch(); 12 | auto phis = header->getPhis(); 13 | 14 | // Find a series of straight lines from the latch.
15 | std::vector final; 16 | auto runner = latch; 17 | do { 18 | final.push_back(runner); 19 | if (runner->preds.size() > 1) 20 | break; 21 | runner = *runner->preds.begin(); 22 | } while (runner != header); 23 | 24 | Builder builder; 25 | 26 | for (auto phi : phis) { 27 | // First ensure the latch value is `phi` itself incremented by an in-range constant; otherwise `latchval` cannot be replaced by the post-incremented `phi` below. 28 | auto latchval = Op::getPhiFrom(phi, latch); 29 | if (!isa(latchval) || latchval->DEF(0) != phi) 30 | continue; 31 | auto vi = V(latchval); 32 | 33 | // Out of range. 34 | if (vi >= 256 || vi < -256) 35 | continue; 36 | 37 | // Find the last use of this phi, except `latchval`. 38 | Op *lastuse = nullptr; 39 | auto lastbb = final.end(); 40 | 41 | for (auto use : phi->getUses()) { 42 | if (use == latchval) 43 | continue; 44 | 45 | auto parent = use->getParent(); 46 | if (!info->contains(parent)) 47 | continue; 48 | auto it = std::find(final.begin(), final.end(), parent); 49 | if (it == final.end()) 50 | continue; 51 | 52 | if (!lastuse || it < lastbb) { 53 | // Either that's the first use, or found a later use. 54 | lastuse = use; 55 | lastbb = it; 56 | } else if (it == lastbb) { 57 | // The same block; find which's later. 58 | bool later = use->atBack(); 59 | if (!later) { 60 | for (auto p = lastuse; !p->atBack(); p = p->nextOp()) { 61 | if (p == use) { 62 | later = true; 63 | break; 64 | } 65 | } 66 | } 67 | 68 | // If the current `use` is later, then update it.
69 | if (later) { 70 | lastuse = use; 71 | lastbb = it; 72 | } 73 | } 74 | } 75 | 76 | if (!lastuse || !lastuse->has() || V(lastuse) != 0) 77 | continue; 78 | 79 | builder.setBeforeOp(lastuse); 80 | Op *replace = nullptr; 81 | switch (lastuse->opid) { 82 | case LdrWOp::id: 83 | replace = builder.create(lastuse->getOperands(), { new IntAttr(vi) }); 84 | break; 85 | case StrWOp::id: 86 | replace = builder.create(lastuse->getOperands(), { new IntAttr(vi) }); 87 | break; 88 | } 89 | if (!replace) 90 | continue; 91 | 92 | lastuse->replaceAllUsesWith(replace); 93 | lastuse->erase(); 94 | 95 | // Make sure the value of `phi` lives correctly. 96 | builder.setAfterOp(latchval); 97 | auto placeholder = builder.create({ phi }); 98 | phi->replaceOperand(latchval, placeholder); 99 | 100 | // Now `phi`'s value is already increased, so `latchval` is now equal to `phi`. 101 | latchval->replaceAllUsesWith(phi); 102 | } 103 | } 104 | 105 | void PostIncr::run() { 106 | LoopAnalysis analysis(module); 107 | analysis.run(); 108 | auto forests = analysis.getResult(); 109 | 110 | for (const auto &[_, forest] : forests) { 111 | for (auto loop : forest.getLoops()) { 112 | // Only consider innermost loops. 113 | if (loop->subloops.size() == 0) 114 | runImpl(loop); 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/opt/CleanupPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef CLEANUP_PASSES_H 2 | #define CLEANUP_PASSES_H 3 | 4 | #include "Pass.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Attrs.h" 7 | 8 | namespace sys { 9 | 10 | // Dead code elimination. Deals with functions, basic blocks and variables.
11 | /* elimOp/elimFn/elimBB presumably count removed ops, functions and blocks for stats(); elimBlocks is fixed at construction. */ class DCE : public Pass { 12 | std::vector removeable; 13 | int elimOp = 0; 14 | int elimFn = 0; 15 | int elimBB = 0; 16 | bool elimBlocks; 17 | 18 | bool isImpure(Op *op); 19 | bool markImpure(Region *region); 20 | void runOnRegion(Region *region); 21 | 22 | std::map fnMap; 23 | public: 24 | // If DCE is called before flatten cfg, then it shouldn't eliminate blocks, 25 | // since the blocks aren't actually well-formed. 26 | DCE(ModuleOp *module, bool elimBlocks = true): Pass(module), elimBlocks(elimBlocks) {} 27 | 28 | std::string name() override { return "dce"; }; 29 | std::map stats() override; 30 | void run() override; 31 | }; 32 | 33 | // Assume every operation is dead unless proved otherwise. 34 | class AggressiveDCE : public Pass { 35 | int elim = 0; 36 | 37 | void runImpl(FuncOp *fn); 38 | public: 39 | AggressiveDCE(ModuleOp *module): Pass(module) {} 40 | 41 | std::string name() override { return "aggressive-dce"; }; 42 | std::map stats() override; 43 | void run() override; 44 | }; 45 | 46 | // Dead (actually, redundant) load elimination. 47 | class DLE : public Pass { 48 | int elim = 0; 49 | 50 | void runImpl(Region *region); 51 | public: 52 | DLE(ModuleOp *module): Pass(module) {} 53 | 54 | std::string name() override { return "dle"; } 55 | std::map stats() override; 56 | void run() override; 57 | }; 58 | 59 | // Dead argument elimination. 60 | /* elim and elimRet presumably count removed arguments and removed return values — confirm in DAE.cpp. */ class DAE : public Pass { 61 | int elim = 0; 62 | int elimRet = 0; 63 | 64 | void runImpl(Region *region); 65 | public: 66 | DAE(ModuleOp *module): Pass(module) {} 67 | 68 | std::string name() override { return "dae"; } 69 | std::map stats() override; 70 | void run() override; 71 | }; 72 | 73 | // Dead store elimination.
// NOTE(review): the next two lines are a collapsed repository dump, not
// compilable source: newlines were replaced by spaces, "NN |" line-number
// prefixes are kept, and angle-bracketed text (template arguments, #include
// targets) appears to have been stripped.
//
// Contents:
//   * end of the pass-declaration header: DSE (`used` map, dfs() over a
//     DomTree carrying a live set, removeUnread()), SimplifyCFG (`inlined`
//     counter), RangeAwareFold (`folded` counter), Reassociate (empty stats()),
//     then the closing namespace brace and #endif;
//   * start of /src/rv/RvAttrs.h: RVLINE, the REGS X-macro listing the RISC-V
//     registers, `enum class Reg` generated from it, showReg(), isFP(), and the
//     RegAttr / RdAttr / RsAttr / Rs2Attr register attributes.
//
// Change in this revision: RVLINE is parenthesised, "(__LINE__ + 524288)", so
// the macro always expands as a single operand (previously `x - RVLINE`
// expanded to `x - __LINE__ + 524288`).
//
// NOTE: isFP() treats every enumerator from Reg::ft0 through Reg::fa7 as a
// floating-point register, so the FP registers must stay contiguous at the end
// of REGS with ft0 first and fa7 last.
74 | class DSE : public Pass { 75 | std::map used; 76 | 77 | int elim = 0; 78 | 79 | void dfs(BasicBlock *current, DomTree &dom, std::set live); 80 | void runImpl(Region *region); 81 | void removeUnread(Op *op, const std::vector &gets); 82 | public: 83 | DSE(ModuleOp *module): Pass(module) {} 84 | 85 | std::string name() override { return "dse"; }; 86 | std::map stats() override; 87 | void run() override; 88 | }; 89 | 90 | class SimplifyCFG : public Pass { 91 | int inlined = 0; 92 | 93 | void runImpl(Region *region); 94 | public: 95 | SimplifyCFG(ModuleOp *module): Pass(module) {} 96 | 97 | std::string name() override { return "simplify-cfg"; }; 98 | std::map stats() override; 99 | void run() override; 100 | }; 101 | 102 | class RangeAwareFold : public Pass { 103 | int folded = 0; 104 | public: 105 | RangeAwareFold(ModuleOp *module): Pass(module) {} 106 | 107 | std::string name() override { return "range-aware-fold"; }; 108 | std::map stats() override; 109 | void run() override; 110 | }; 111 | 112 | class Reassociate : public Pass { 113 | void runImpl(Region *region); 114 | public: 115 | Reassociate(ModuleOp *module): Pass(module) {} 116 | 117 | std::string name() override { return "reassociate"; }; 118 | std::map stats() override { return {}; } 119 | void run() override; 120 | }; 121 | 122 | } 123 | 124 | #endif 125 | -------------------------------------------------------------------------------- /src/rv/RvAttrs.h: -------------------------------------------------------------------------------- 1 | #ifndef RVATTRS_H 2 | #define RVATTRS_H 3 | 4 | #include "../codegen/OpBase.h" 5 | #include 6 | #define RVLINE (__LINE__ + 524288) 7 | 8 | namespace sys { 9 | 10 | namespace rv { 11 | 12 | #define REGS \ 13 | X(zero) \ 14 | X(ra) \ 15 | X(sp) \ 16 | X(gp) \ 17 | X(tp) \ 18 | X(t0) \ 19 | X(t1) \ 20 | X(t2) \ 21 | X(t3) \ 22 | X(t4) \ 23 | X(t5) \ 24 | X(t6) \ 25 | X(s0) \ 26 | X(s1) \ 27 | X(s2) \ 28 | X(s3) \ 29 | X(s4) \ 30 | X(s5) \ 31 | X(s6) \ 32 | X(s7) \ 33 |
X(s8) \ 34 | X(s9) \ 35 | X(s10) \ 36 | X(s11) \ 37 | X(a0) \ 38 | X(a1) \ 39 | X(a2) \ 40 | X(a3) \ 41 | X(a4) \ 42 | X(a5) \ 43 | X(a6) \ 44 | X(a7) \ 45 | X(ft0) \ 46 | X(ft1) \ 47 | X(ft2) \ 48 | X(ft3) \ 49 | X(ft4) \ 50 | X(ft5) \ 51 | X(ft6) \ 52 | X(ft7) \ 53 | X(ft8) \ 54 | X(ft9) \ 55 | X(ft10) \ 56 | X(ft11) \ 57 | X(fs0) \ 58 | X(fs1) \ 59 | X(fs2) \ 60 | X(fs3) \ 61 | X(fs4) \ 62 | X(fs5) \ 63 | X(fs6) \ 64 | X(fs7) \ 65 | X(fs8) \ 66 | X(fs9) \ 67 | X(fs10) \ 68 | X(fs11) \ 69 | X(fa0) \ 70 | X(fa1) \ 71 | X(fa2) \ 72 | X(fa3) \ 73 | X(fa4) \ 74 | X(fa5) \ 75 | X(fa6) \ 76 | X(fa7) 77 | 78 | #define X(name) name, 79 | enum class Reg : signed { 80 | REGS 81 | }; 82 | 83 | #undef X 84 | 85 | inline std::string showReg(Reg reg) { 86 | switch (reg) { 87 | #define X(name) case Reg::name: return #name; 88 | REGS 89 | #undef X 90 | } 91 | return ""; 92 | } 93 | 94 | #undef REGS 95 | 96 | inline bool isFP(Reg reg) { 97 | return (int) Reg::ft0 <= (int) reg && (int) Reg::fa7 >= (int) reg; 98 | } 99 | 100 | class RegAttr : public AttrImpl { 101 | public: 102 | Reg reg; 103 | 104 | RegAttr(Reg reg): reg(reg) {} 105 | 106 | std::string toString() override { return ""; } 107 | RegAttr *clone() override { return new RegAttr(reg); } 108 | }; 109 | 110 | class RdAttr : public AttrImpl { 111 | public: 112 | Reg reg; 113 | 114 | RdAttr(Reg reg): reg(reg) {} 115 | 116 | std::string toString() override { return ""; } 117 | RdAttr *clone() override { return new RdAttr(reg); } 118 | }; 119 | 120 | class RsAttr : public AttrImpl { 121 | public: 122 | Reg reg; 123 | 124 | RsAttr(Reg reg): reg(reg) {} 125 | 126 | std::string toString() override { return ""; } 127 | RsAttr *clone() override { return new RsAttr(reg); } 128 | }; 129 | 130 | class Rs2Attr : public AttrImpl { 131 | public: 132 | Reg reg; 133 | 134 | Rs2Attr(Reg reg): reg(reg) {} 135 | 136 | std::string toString() override { return ""; } 137 | Rs2Attr *clone() override { return new Rs2Attr(reg); } 138 | }; 139 | 140
// NOTE(review): the next nine lines are a collapsed repository dump, not
// compilable source: newlines were replaced by spaces, "NN |" line-number
// prefixes are kept, and angle-bracketed text (template arguments, #include
// targets) appears to have been stripped.
//
// Contents:
//   * rest of /src/rv/RvAttrs.h: StackOffsetAttr ("stack offset from bp") and
//     the STACKOFF / RD / RS / RS2 / REG / RDC / RSC / RS2C accessor macros;
//   * /src/opt/PassManager.cpp: the constructor reads `opts.compareWith`
//     (expected output whose last line is the exit code) and
//     `opts.simulateInput`; run() executes each pass and, depending on the
//     options, dumps IR, runs Verify (once mem2reg has run), re-interprets the
//     module to compare output and exit code (after flatten-cfg, before
//     rv-lower / arm-lower), and prints pass stats;
//   * /src/arm/ArmAttrs.h: the AArch64 REGS X-macro, `enum class Reg`,
//     showReg(), operator<<, isFP(), StackOffsetAttr, LslAttr, the RATTR macro
//     with its RegAttr / RdAttr / RsAttr / Rs2Attr / Rs3Attr instances, and the
//     accessor macros;
//   * /src/pre-opt/ArrayAccess.cpp: attaches affine subscripts (one
//     coefficient per enclosing loop plus a trailing constant term) to ops
//     inside loop bodies; the op kinds are hidden by the stripped isa<> types;
//   * scoreboard.ts: fetches a contest ranking page, and compares two
//     rank/NAME.txt files, printing entries whose time changed by at least
//     `threshold` percent;
//   * the header of /src/opt/Specialize.cpp.
//
// Changes in this revision:
//   * ARMLINE is parenthesised, "(__LINE__ + 1048576)", so it always expands as
//     a single operand.
//   * ArmAttrs.h StackOffsetAttr::toString() / clone() are marked `override`,
//     matching LslAttr and the RATTR classes, so a signature mismatch with the
//     base class becomes a compile error instead of a silently non-overriding
//     member.
//   * scoreboard.ts compare(): the name-width reduce() now starts from 0;
//     without an initial value reduce() throws a TypeError when both rank
//     files are empty (main() only checks that the lengths are equal).
//
// Remaining review notes (behaviour intentionally unchanged):
//   * PassManager::run() reports mismatches with assert(false), which is
//     compiled out under NDEBUG, so release builds continue silently.
//   * The constructor calls std::stoi on the last line of the comparison file;
//     an empty file or a non-numeric last line throws std::invalid_argument.
//   * std::isspace is called with a plain char; a negative char value is
//     undefined behaviour (cast to unsigned char if non-ASCII input is
//     possible).
//   * ArmAttrs.h labels v8-v15 "caller saved (temps)" and v16-v31 "callee
//     saved". AAPCS64 specifies the opposite (v8-v15 callee-saved, low 64 bits
//     only; v16-v31 caller-saved). TODO confirm which convention the register
//     allocator actually implements.
//   * compare() divides by `a.time`; a zero baseline yields Infinity/NaN.
| // Stack offset from bp. 141 | class StackOffsetAttr : public AttrImpl { 142 | public: 143 | int offset; 144 | 145 | StackOffsetAttr(int offset): offset(offset) {} 146 | 147 | std::string toString() override { return ""; } 148 | StackOffsetAttr *clone() override { return new StackOffsetAttr(offset); } 149 | }; 150 | 151 | } 152 | 153 | #define STACKOFF(op) (op)->get()->offset 154 | #define RD(op) (op)->get()->reg 155 | #define RS(op) (op)->get()->reg 156 | #define RS2(op) (op)->get()->reg 157 | #define REG(op) (op)->get()->reg 158 | #define RDC(x) new RdAttr(x) 159 | #define RSC(x) new RsAttr(x) 160 | #define RS2C(x) new Rs2Attr(x) 161 | 162 | } 163 | 164 | #endif 165 | -------------------------------------------------------------------------------- /src/opt/PassManager.cpp: -------------------------------------------------------------------------------- 1 | #include "PassManager.h" 2 | #include "Passes.h" 3 | #include "../utils/Exec.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace sys; 10 | 11 | PassManager::PassManager(ModuleOp *module, const Options &opts): 12 | module(module), opts(opts) { 13 | if (opts.compareWith.size()) { 14 | std::ifstream ifs(opts.compareWith); 15 | std::stringstream ss; 16 | ss << ifs.rdbuf(); 17 | truth = ss.str(); 18 | 19 | // Strip the string. 20 | while (truth.size() && std::isspace(truth.back())) 21 | truth.pop_back(); 22 | 23 | // We need to separate the final line. 24 | auto pos = truth.rfind('\n'); 25 | if (pos == std::string::npos) { 26 | // This is the only line of file. 27 | exitcode = std::stoi(truth); 28 | truth.clear(); 29 | } else { 30 | exitcode = std::stoi(truth.substr(pos + 1)); 31 | truth.erase(pos); 32 | } 33 | 34 | // Strip the output again.
35 | while (truth.size() && std::isspace(truth.back())) 36 | truth.pop_back(); 37 | } 38 | 39 | if (opts.simulateInput.size()) { 40 | std::ifstream ifs(opts.simulateInput); 41 | std::stringstream ss; 42 | ss << ifs.rdbuf(); 43 | input = ss.str(); 44 | } 45 | } 46 | 47 | PassManager::~PassManager() { 48 | for (auto pass : passes) 49 | delete pass; 50 | } 51 | 52 | void PassManager::run() { 53 | pastFlatten = false; 54 | pastMem2Reg = false; 55 | inBackend = false; 56 | 57 | for (auto pass : passes) { 58 | if (pass->name() == "flatten-cfg") 59 | pastFlatten = true; 60 | if (pass->name() == "mem2reg") 61 | pastMem2Reg = true; 62 | if (pass->name() == "rv-lower" || pass->name() == "arm-lower") 63 | inBackend = true; 64 | 65 | if (pass->name() == opts.printBefore) { 66 | std::cerr << "===== Before " << pass->name() << " =====\n\n"; 67 | module->dump(std::cerr); 68 | std::cerr << "\n\n"; 69 | } 70 | 71 | pass->run(); 72 | pass->cleanup(); 73 | 74 | if (opts.verbose || pass->name() == opts.printAfter) { 75 | std::cerr << "===== After " << pass->name() << " =====\n\n"; 76 | module->dump(std::cerr); 77 | std::cerr << "\n\n"; 78 | } 79 | 80 | // Before mem2reg, we don't have phis. 81 | // Verify pass only checks phis; so no point running it before that. 82 | if (opts.verify && pastMem2Reg) { 83 | std::cerr << "checking " << pass->name() << "..."; 84 | Verify(module).run(); 85 | std::cerr << " passed\n"; 86 | } 87 | 88 | // We can't simulate for backend. 89 | // Technically we have the capacity, but it's too much work. 90 | if (opts.compareWith.size() && pastFlatten && !inBackend) { 91 | std::cerr << "checking " << pass->name() << "\n"; 92 | exec::Interpreter itp(module); 93 | std::stringstream buffer(input); 94 | itp.run(buffer); 95 | std::string str = itp.out(); 96 | // Strip output.
97 | while (str.size() && std::isspace(str.back())) 98 | str.pop_back(); 99 | 100 | if (str != truth) { 101 | std::cerr << "output mismatch:\n" << str << "\n"; 102 | std::cerr << "after pass: " << pass->name() << "\n"; 103 | assert(false); 104 | } 105 | if (exitcode != itp.exitcode()) { 106 | std::cerr << "exit code mismatch:" << itp.exitcode() << " (expected " << exitcode << ")\n"; 107 | std::cerr << "after pass: " << pass->name() << "\n"; 108 | assert(false); 109 | } 110 | } 111 | 112 | if (opts.stats) { 113 | std::cerr << pass->name() << ":\n"; 114 | 115 | auto stats = pass->stats(); 116 | if (!stats.size()) 117 | std::cerr << " \n"; 118 | 119 | for (auto [k, v] : stats) 120 | std::cerr << " " << k << " : " << v << "\n"; 121 | } 122 | } 123 | } 124 | 125 | -------------------------------------------------------------------------------- /src/arm/ArmAttrs.h: -------------------------------------------------------------------------------- 1 | #ifndef ARM_ATTRS_H 2 | #define ARM_ATTRS_H 3 | 4 | #include "../codegen/Attrs.h" 5 | #define ARMLINE (__LINE__ + 1048576) 6 | 7 | namespace sys { 8 | 9 | namespace arm { 10 | 11 | #define REGS \ 12 | /* x0 - x7: arguments */ \ 13 | X(x0) \ 14 | X(x1) \ 15 | X(x2) \ 16 | X(x3) \ 17 | X(x4) \ 18 | X(x5) \ 19 | X(x6) \ 20 | X(x7) \ 21 | /* x8: indirect result (we don't need it) */ \ 22 | X(x8) \ 23 | /* x9 - x15: caller saved (temps) */ \ 24 | X(x9) \ 25 | X(x10) \ 26 | X(x11) \ 27 | X(x12) \ 28 | X(x13) \ 29 | X(x14) \ 30 | X(x15) \ 31 | /* x16 - x18: reserved. Avoid them. */ \ 32 | X(x16) \ 33 | X(x17) \ 34 | X(x18) \ 35 | /* x19 - x29: callee saved.
(x29 can be `fp`) */ \ 36 | X(x19) \ 37 | X(x20) \ 38 | X(x21) \ 39 | X(x22) \ 40 | X(x23) \ 41 | X(x24) \ 42 | X(x25) \ 43 | X(x26) \ 44 | X(x27) \ 45 | X(x28) \ 46 | X(x29) \ 47 | /* x30: ra */ \ 48 | X(x30) \ 49 | /* x31: either sp or zero, based on context; we consider it as two separate ones */ \ 50 | X(sp) \ 51 | X(xzr) \ 52 | /* v0 - v7: arguments */ \ 53 | X(v0) \ 54 | X(v1) \ 55 | X(v2) \ 56 | X(v3) \ 57 | X(v4) \ 58 | X(v5) \ 59 | X(v6) \ 60 | X(v7) \ 61 | /* v8 - v15: caller saved (temps) */ \ 62 | X(v8) \ 63 | X(v9) \ 64 | X(v10) \ 65 | X(v11) \ 66 | X(v12) \ 67 | X(v13) \ 68 | X(v14) \ 69 | X(v15) \ 70 | /* v16 - v31: callee saved */ \ 71 | X(v16) \ 72 | X(v17) \ 73 | X(v18) \ 74 | X(v19) \ 75 | X(v20) \ 76 | X(v21) \ 77 | X(v22) \ 78 | X(v23) \ 79 | X(v24) \ 80 | X(v25) \ 81 | X(v26) \ 82 | X(v27) \ 83 | X(v28) \ 84 | X(v29) \ 85 | X(v30) \ 86 | X(v31) 87 | 88 | #define X(name) name, 89 | enum class Reg : signed int { 90 | REGS 91 | }; 92 | 93 | #undef X 94 | 95 | inline std::string showReg(Reg reg) { 96 | switch (reg) { 97 | #define X(name) case Reg::name: return #name; 98 | REGS 99 | #undef X 100 | } 101 | return ""; 102 | } 103 | 104 | inline std::ostream &operator<<(std::ostream &os, Reg reg) { 105 | return os << showReg(reg); 106 | } 107 | 108 | #undef REGS 109 | 110 | inline bool isFP(Reg reg) { 111 | return (int) Reg::v0 <= (int) reg && (int) Reg::v31 >= (int) reg; 112 | } 113 | 114 | class StackOffsetAttr : public AttrImpl { 115 | public: 116 | int offset; 117 | 118 | StackOffsetAttr(int offset): offset(offset) {} 119 | 120 | std::string toString() override { return ""; } 121 | StackOffsetAttr *clone() override { return new StackOffsetAttr(offset); } 122 | }; 123 | 124 | class LslAttr : public AttrImpl { 125 | public: 126 | int vi; 127 | 128 | LslAttr(int vi): vi(vi) {} 129 | 130 | std::string toString() override { return ""; } 131 | LslAttr *clone() override { return new LslAttr(vi); } 132 | }; 133 | 134 | #define RATTR(Ty, name) \ 135 | class Ty : public
AttrImpl { \ 136 | public: \ 137 | Reg reg; \ 138 | Ty(Reg reg): reg(reg) {} \ 139 | std::string toString() override { return "<" name + showReg(reg) + ">"; } \ 140 | Ty *clone() override { return new Ty(reg); } \ 141 | }; 142 | 143 | RATTR(RegAttr, ""); 144 | RATTR(RdAttr, "rd = "); 145 | RATTR(RsAttr, "rs = "); 146 | RATTR(Rs2Attr, "rs2 = "); 147 | RATTR(Rs3Attr, "rs3 = "); 148 | 149 | } 150 | 151 | } 152 | 153 | #define STACKOFF(op) (op)->get()->offset 154 | #define REG(op) (op)->get()->reg 155 | #define RD(op) (op)->get()->reg 156 | #define RS(op) (op)->get()->reg 157 | #define RS2(op) (op)->get()->reg 158 | #define RS3(op) (op)->get()->reg 159 | #define RDC(x) new RdAttr(x) 160 | #define RSC(x) new RsAttr(x) 161 | #define RS2C(x) new Rs2Attr(x) 162 | #define RS3C(x) new Rs3Attr(x) 163 | #define LSL(op) (op)->get()->vi 164 | 165 | #endif 166 | -------------------------------------------------------------------------------- /src/pre-opt/ArrayAccess.cpp: -------------------------------------------------------------------------------- 1 | #include "PreAnalysis.h" 2 | 3 | using namespace sys; 4 | 5 | namespace { 6 | 7 | AffineExpr make(const std::vector &outer) { 8 | return AffineExpr(outer.size() + 1); 9 | } 10 | 11 | // Lengthen the affine expression to match the current depth. 12 | void lengthen(AffineExpr &x, const std::vector &outer) { 13 | x.reserve(outer.size() + 1); 14 | 15 | // Remove the constant. 16 | auto back = x.back(); 17 | x.pop_back(); 18 | // Add zeroes at end. 19 | x.resize(outer.size()); 20 | // Put the constant back.
21 | x.push_back(back); 22 | } 23 | 24 | AffineExpr lengthened(Op *op, const std::vector &outer) { 25 | auto val = SUBSCRIPT(op); 26 | lengthen(val, outer); 27 | return val; 28 | } 29 | 30 | void remove(Region *region) { 31 | for (auto bb : region->getBlocks()) { 32 | for (auto op : bb->getOps()) { 33 | op->remove(); 34 | for (auto r : op->getRegions()) 35 | remove(r); 36 | } 37 | } 38 | } 39 | 40 | } 41 | 42 | void ArrayAccess::runImpl(Op *loop, std::vector outer) { 43 | auto region = loop->getRegion(); 44 | auto bb = region->getFirstBlock(); 45 | 46 | for (auto op : bb->getOps()) { 47 | if (isa(op)) { 48 | auto val = make(outer); 49 | val.back() = V(op); 50 | op->add(val); 51 | continue; 52 | } 53 | 54 | if (isa(op)) { 55 | auto k = outer; 56 | k.push_back(op); 57 | 58 | // Add subscript to this loop's induction variable. 59 | auto val = make(k); 60 | val[val.size() - 2] = 1; 61 | op->add(val); 62 | 63 | runImpl(op, k); 64 | continue; 65 | } 66 | 67 | if (isa(op)) { 68 | runImpl(op, outer); 69 | continue; 70 | } 71 | 72 | // Though WhileOp has regions, it is not considered here. 73 | // It's because induction variable isn't clear. 74 | 75 | if (isa(op)) { 76 | auto x = op->DEF(0); 77 | auto y = op->DEF(1); 78 | if (!x->has() || !y->has()) 79 | continue; 80 | 81 | auto vx = lengthened(x, outer); 82 | auto vy = lengthened(y, outer); 83 | for (int i = 0; i < vx.size(); i++) 84 | vx[i] += vy[i]; 85 | op->add(vx); 86 | continue; 87 | } 88 | 89 | if (isa(op)) { 90 | auto x = op->DEF(0); 91 | auto y = op->DEF(1); 92 | if (!isa(y) || !x->has()) 93 | continue; 94 | 95 | auto val = lengthened(x, outer); 96 | for (auto &coeff : val) 97 | coeff *= V(y); 98 | op->add(val); 99 | continue; 100 | } 101 | 102 | // Also tag the address, though it technically isn't a subscript.
103 | if (isa(op)) { 104 | auto x = op->DEF(0); 105 | auto y = op->DEF(1); 106 | if (!x->has()) { 107 | if (y->has()) 108 | std::swap(x, y); 109 | else continue; 110 | } 111 | 112 | auto vx = lengthened(x, outer); 113 | if (y->has()) { 114 | auto vy = lengthened(y, outer); 115 | for (int i = 0; i < vx.size(); i++) 116 | vx[i] += vy[i]; 117 | } 118 | op->add(vx); 119 | continue; 120 | } 121 | } 122 | } 123 | 124 | void ArrayAccess::run() { 125 | auto funcs = collectFuncs(); 126 | 127 | for (auto func : funcs) { 128 | auto region = func->getRegion(); 129 | 130 | // Remove all existing subscripts first. 131 | remove(region); 132 | 133 | for (auto bb : region->getBlocks()) { 134 | for (auto op : bb->getOps()) { 135 | if (!isa(op)) 136 | continue; 137 | 138 | // Start marking. 139 | op->add(std::vector { 1, 0 }); 140 | runImpl(op, { op }); 141 | } 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /scoreboard.ts: -------------------------------------------------------------------------------- 1 | #!pnpm tsx 2 | import axios from 'axios'; 3 | import * as cheerio from 'cheerio'; 4 | import * as fs from 'fs'; 5 | import * as readline from 'readline'; 6 | 7 | interface ScoreEntry { 8 | name: string; 9 | time: number; 10 | } 11 | 12 | let contest_id = "y9s9zPhwJPE"; 13 | let task_id = "7090546"; 14 | let page = "https://course.educg.net//pages/contest/contest_rank_more.jsp"; 15 | let url = `${page}?contestID=${contest_id}&taskID=${task_id}`; 16 | 17 | async function fetch(url: string) { 18 | // Fetch the content.
19 | const response = await axios.get(url); 20 | const html: string = response.data; 21 | const $: cheerio.CheerioAPI = cheerio.load(html); 22 | console.log("Fetched url:\n"); 23 | console.log(html); 24 | 25 | const rows = $("table tbody tr"); 26 | const result: [string, number, number][] = []; 27 | rows.each((index: number, row) => { 28 | const cells = $(row).find("td"); 29 | const name: string = $(cells.get(0)).text(); 30 | const self_time = parseFloat($(cells.get(2)).text()); 31 | const best_time = parseFloat($(cells.get(3)).text()); 32 | result.push([ name, self_time, best_time ]); 33 | }); 34 | 35 | for (let [name, self, best] of result) 36 | console.log(`${name},${self},${best}`); 37 | } 38 | 39 | function parseLine(line: string): ScoreEntry | null { 40 | const parts = line.trim().split(/\s+/); 41 | if (parts.length < 6) 42 | return null; 43 | 44 | return { 45 | name: parts[1], 46 | time: parseFloat(parts[3]), 47 | }; 48 | } 49 | 50 | function parseTest(line: string): ScoreEntry | null { 51 | const parts = line.trim().split(/\s+/); 52 | if (parts.length < 2) 53 | return null; 54 | 55 | return { 56 | name: parts[0], 57 | time: parseFloat(parts[1]), 58 | }; 59 | } 60 | 61 | async function read(filePath: string, parser: (string) => ScoreEntry | null): Promise { 62 | const fileStream = fs.createReadStream(filePath); 63 | const rl = readline.createInterface({ 64 | input: fileStream, 65 | crlfDelay: Infinity, 66 | }); 67 | 68 | const entries: ScoreEntry[] = []; 69 | for await (const line of rl) { 70 | const parsed = parser(line); 71 | if (parsed) { 72 | entries.push(parsed); 73 | } 74 | } 75 | 76 | return entries; 77 | } 78 | 79 | function compare(f1: ScoreEntry[], f2: ScoreEntry[], threshold = 1) { 80 | console.log("Significant changes:"); 81 | const namelen = f1.map((x) => x.name.length).reduce((x, cur) => Math.max(x, cur), 0); 82 | 83 | f1.forEach((a, i) => { 84 | const b = f2[i]; 85 | const delta = (b.time - a.time) / a.time * 100; 86 | if (Math.abs(delta) >=
threshold) { 87 | const plus = delta > 0 ? "+" : ""; 88 | const change = `${a.time.toFixed(2)} -> ${b.time.toFixed(2)}`.padEnd(16); 89 | console.log( 90 | `${a.name.padEnd(namelen + 1)} ${change} (${plus}${delta.toFixed(2)}%)` 91 | ); 92 | } 93 | }); 94 | } 95 | 96 | async function main() { 97 | // The first two arguments are node-path and script path. 98 | let name1: string, name2: string; 99 | 100 | const len = process.argv.length; 101 | if (len < 3) { 102 | console.log("usage: scoreboard.ts "); 103 | return; 104 | } 105 | 106 | if (len == 3) { 107 | const count = parseInt(process.argv[2]); 108 | name1 = count.toString(); 109 | name2 = (count + 1).toString(); 110 | } else if (len == 4) { 111 | name1 = process.argv[2]; 112 | name2 = process.argv[3]; 113 | } else { 114 | console.log("usage: scoreboard.ts ") 115 | return; 116 | } 117 | 118 | let parser = len == 3 ? parseLine : parseTest; 119 | 120 | const file1 = `rank/${name1}.txt`; 121 | const file2 = `rank/${name2}.txt`; 122 | 123 | const data1 = await read(file1, parser); 124 | const data2 = await read(file2, parser); 125 | 126 | if (data1.length !== data2.length) { 127 | console.error(`different entry count: ${data1.length} != ${data2.length}`); 128 | return; 129 | } 130 | 131 | compare(data1, data2); 132 | } 133 | 134 | main().catch(err => console.error(err)); 135 | -------------------------------------------------------------------------------- /src/opt/Specialize.cpp: -------------------------------------------------------------------------------- 1 | #include "Passes.h" 2 | #include "Analysis.h" 3 | #include "CleanupPasses.h" 4 | 5 | using namespace sys; 6 | 7 | // NOTE: we consider 0 as both positive and negative here. 8 | 9 | namespace { 10 | 11 | // E.g. __pos_1_fib means argument 1 (second) is specialized to be positive.
// NOTE(review): the next ten lines are a collapsed repository dump, not
// compilable source: newlines were replaced by spaces, "NN |" line-number
// prefixes are kept, and angle-bracketed text (template arguments, #include
// targets) appears to have been stripped.
//
// Contents:
//   * rest of /src/opt/Specialize.cpp:
//     - posname();
//     - the namespace-scope `produced` / `processed` maps, which persist
//       across specialize() rounds;
//     - copy(), which clones a function body and then rewires operands,
//       branch targets and phi sources;
//     - removeRange(), which walks the leading ops of each block that pass the
//       type-stripped isa<> check, presumably phis per the comment in run();
//     - Specialize::specialize(), which renames a call to posname(name, i)
//       for the first operand whose range has low >= 0 and clones the callee
//       under that name;
//     - Specialize::run(), which repeats specialize() until nothing changes;
//   * /src/arm/InstCombine.cpp: the ArmRule rewrite table and a driver that
//     applies, for each op, the first rule whose rewrite() succeeds, repeating
//     whole sweeps until one folds nothing (`combined` accumulates the count);
//   * /src/opt/Alias.cpp: local alias tagging in reverse post-order of the
//     dominator tree, then a worklist that propagates argument alias info into
//     callees and re-queues a callee whenever its info changes;
//   * /src/utils/smt/SMT.h: the BvSolver bit-blasting interface;
//   * the start of fuzzer.py (symbol table and expression generators); the
//     gen_expr definition is cut off at the end of this chunk.
//
// Changes in this revision:
//   * posname(): removed a stray second semicolon.
//   * SMT.h: the comment above addAnd/addOr/addXor/addNot now names the real
//     output parameter `out` (it previously said `o`).
//   * fuzzer.py gen_arr_access(): the array size is looked up in the same table
//     the name was drawn from (`table.const_arrs` when `const` is true).
//     Previously it always indexed `table.arrs`, which raises KeyError for a
//     constant array that is not also registered in `arrs`.
//
// Remaining review notes (behaviour intentionally unchanged):
//   * `produced` / `processed` are namespace-scope state, so they are not reset
//     between separate runs of the Specialize pass in one process.
//   * fuzzer.py vars() / arrs() are annotated `-> set[str]` but return dict
//     key views.
12 | std::string posname(const std::string &name, int i) { 13 | return "__pos_" + std::to_string(i) + "_" + name; 14 | } 15 | 16 | std::map> produced; 17 | std::map> processed; 18 | 19 | void copy(Region *tgt, Region *src) { 20 | std::unordered_map rewireMap; 21 | std::unordered_map cloneMap; 22 | Builder builder; 23 | 24 | for (auto x : src->getBlocks()) 25 | rewireMap[x] = tgt->appendBlock(); 26 | 27 | for (auto [k, v] : rewireMap) { 28 | builder.setToBlockStart(v); 29 | for (auto op : k->getOps()) { 30 | auto copied = builder.copy(op); 31 | cloneMap[op] = copied; 32 | if (processed.count(op)) 33 | processed[copied] = processed[op]; 34 | } 35 | } 36 | 37 | // Rewire operands. 38 | for (auto [_, v] : cloneMap) { 39 | for (int i = 0; i < v->getOperandCount(); i++) 40 | v->setOperand(i, cloneMap[v->DEF(i)]); 41 | } 42 | 43 | // Rewire basic blocks. 44 | for (auto [_, v] : rewireMap) { 45 | auto term = v->getLastOp(); 46 | if (auto target = term->find(); target && target->bb) 47 | target->bb = rewireMap[target->bb]; 48 | if (auto ifnot = term->find(); ifnot && ifnot->bb) 49 | ifnot->bb = rewireMap[ifnot->bb]; 50 | } 51 | 52 | // Rewire phis. 53 | for (auto [_, v] : cloneMap) { 54 | if (!isa(v)) 55 | continue; 56 | 57 | // RangeAttr isn't deleted yet.
58 | for (auto attr : v->getAttrs()) { 59 | if (!isa(attr)) 60 | continue; 61 | auto &from = FROM(attr); 62 | from = rewireMap[from]; 63 | } 64 | } 65 | } 66 | 67 | } 68 | 69 | void removeRange(Region *region) { 70 | for (auto bb : region->getBlocks()) { 71 | for (auto op : bb->getOps()) { 72 | if (!isa(op)) 73 | break; 74 | 75 | op->remove(); 76 | } 77 | } 78 | } 79 | 80 | bool Specialize::specialize() { 81 | auto calls = module->findAll(); 82 | std::unordered_map> posinfo; 83 | 84 | for (auto call : calls) { 85 | auto &name = NAME(call); 86 | if (isExtern(name)) 87 | continue; 88 | 89 | for (int i = 0; i < call->getOperandCount(); i++) { 90 | auto def = call->DEF(i); 91 | if (!def->has()) 92 | continue; 93 | auto [low, high] = RANGE(def); 94 | if (low >= 0) { 95 | posinfo[name].insert(i); 96 | if (!processed[call].count(i)) 97 | name = posname(name, i); 98 | processed[call].insert(i); 99 | break; 100 | } 101 | 102 | // TODO: negative 103 | // TODO: perhaps more possible ranges? 104 | } 105 | } 106 | 107 | auto fmap = getFunctionMap(); 108 | Builder builder; 109 | 110 | bool changed = false; 111 | for (auto [name, info] : posinfo) { 112 | for (auto i : info) { 113 | if (produced[name].count(i)) 114 | continue; 115 | 116 | auto pname = posname(name, i); 117 | produced[name].insert(i); 118 | produced[pname] = produced[name]; 119 | 120 | changed = true; 121 | auto fn = fmap[name]; 122 | builder.setToRegionStart(module->getRegion()); 123 | auto copied = builder.create({ 124 | new NameAttr(pname), 125 | fn->get() 126 | }); 127 | copied->appendRegion(); 128 | 129 | copy(copied->getRegion(), fn->getRegion()); 130 | } 131 | } 132 | 133 | Range(module).run(); 134 | RangeAwareFold(module).run(); 135 | // Perform split again. 136 | Range(module).run(); 137 | return changed; 138 | } 139 | 140 | void Specialize::run() { 141 | Range(module).run(); 142 | while (specialize()); 143 | CallGraph(module).run(); 144 | 145 | // Remove all range attributes for phi.
146 | auto funcs = collectFuncs(); 147 | 148 | for (auto func : funcs) { 149 | auto region = func->getRegion(); 150 | removeRange(region); 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /src/arm/InstCombine.cpp: -------------------------------------------------------------------------------- 1 | #include "ArmPasses.h" 2 | #include "ArmMatcher.h" 3 | 4 | using namespace sys; 5 | using namespace sys::arm; 6 | 7 | std::map InstCombine::stats() { 8 | return { 9 | { "combined-ops", combined } 10 | }; 11 | } 12 | 13 | static ArmRule rules[] = { 14 | // ADD 15 | "(change (addw x (mov #a)) (!only-if (!inbit 12 #a) (addwi x #a)))", 16 | "(change (addx x (mov #a)) (!only-if (!inbit 12 #a) (addxi x #a)))", 17 | "(change (addw x (lslwi y #a)) (addwl x y #a))", 18 | "(change (addw (lslwi y #a) x) (addwl x y #a))", 19 | "(change (addx x (lslwi y #a)) (addxl x y #a))", 20 | "(change (addx (lslwi y #a) x) (addxl x y #a))", 21 | "(change (addw x (lslxi y #a)) (addwl x y #a))", 22 | "(change (addw (lslxi y #a) x) (addwl x y #a))", 23 | "(change (addx x (lslxi y #a)) (addxl x y #a))", 24 | "(change (addx (lslxi y #a) x) (addxl x y #a))", 25 | "(change (addw x (asrwi y #a)) (addwar x y #a))", 26 | "(change (addw (asrwi y #a) x) (addwar x y #a))", 27 | "(change (addw (mulw x y) z) (maddw x y z))", 28 | "(change (addw z (mulw x y)) (maddw x y z))", 29 | 30 | // FADD, FSUB: precision changes unexpectedly 31 | // "(change (fadd (fmul x y) z) (fmadd x y z))", 32 | // "(change (fadd z (fmul x y)) (fmadd x y z))", 33 | 34 | // SUB 35 | "(change (subw x (mov #a)) (!only-if (!inbit 12 (!minus #a)) (addwi x (!minus #a))))", 36 | "(change (subx x (mov #a)) (!only-if (!inbit 12 (!minus #a)) (addxi x (!minus #a))))", 37 | 38 | // CBZ 39 | "(change (cbz (csetlt x y) >ifso >ifnot) (blt x y >ifnot >ifso))", 40 | "(change (cbz (csetle x y) >ifso >ifnot) (ble x y >ifnot >ifso))", 41 | "(change (cbz (csetne x y) >ifso >ifnot) (beq x y >ifso
>ifnot))", 42 | "(change (cbz (cseteq x y) >ifso >ifnot) (bne x y >ifso >ifnot))", 43 | 44 | // CBNZ 45 | "(change (cbnz (csetlt x y) >ifso >ifnot) (blt x y >ifso >ifnot))", 46 | "(change (cbnz (csetle x y) >ifso >ifnot) (ble x y >ifso >ifnot))", 47 | "(change (cbnz (csetne x y) >ifso >ifnot) (bne x y >ifso >ifnot))", 48 | "(change (cbnz (cseteq x y) >ifso >ifnot) (beq x y >ifso >ifnot))", 49 | 50 | // LDR 51 | "(change (ldrw (addxi x #a) #b) (!only-if (!inbit 12 (!add #a #b)) (ldrw x (!add #a #b))))", 52 | "(change (ldrx (addxi x #a) #b) (!only-if (!inbit 12 (!add #a #b)) (ldrx x (!add #a #b))))", 53 | "(change (ldrf (addxi x #a) #b) (!only-if (!inbit 12 (!add #a #b)) (ldrf x (!add #a #b))))", 54 | "(change (ldrw (addx x y) #a) (!only-if (!eq #a 0) (ldrwr x y #a)))", 55 | "(change (ldrx (addx x y) #a) (!only-if (!eq #a 0) (ldrxr x y #a)))", 56 | "(change (ldrf (addx x y) #a) (!only-if (!eq #a 0) (ldrfr x y #a)))", 57 | "(change (ldrwr x (lslxi y #a) #b) (!only-if (!eq (!add #a #b) 2) (ldrwr x y 2)))", 58 | "(change (ldrxr x (lslxi y #a) #b) (!only-if (!eq (!add #a #b) 3) (ldrxr x y 3)))", 59 | "(change (ldrfr x (lslxi y #a) #b) (!only-if (!eq (!add #a #b) 2) (ldrfr x y 2)))", 60 | 61 | // STR 62 | "(change (strw y (addxi x #a) #b) (!only-if (!inbit 12 (!add #a #b)) (strw y x (!add #a #b))))", 63 | "(change (strx y (addxi x #a) #b) (!only-if (!inbit 12 (!add #a #b)) (strx y x (!add #a #b))))", 64 | "(change (strf y (addxi x #a) #b) (!only-if (!inbit 12 (!add #a #b)) (strf y x (!add #a #b))))", 65 | "(change (strw z (addx x y) #a) (!only-if (!eq #a 0) (strwr z x y #a)))", 66 | "(change (strx z (addx x y) #a) (!only-if (!eq #a 0) (strxr z x y #a)))", 67 | "(change (strf z (addx x y) #a) (!only-if (!eq #a 0) (strfr z x y #a)))", 68 | "(change (strwr z x (lslxi y #a) #b) (!only-if (!eq (!add #a #b) 2) (strwr z x y 2)))", 69 | "(change (strxr z x (lslxi y #a) #b) (!only-if (!eq (!add #a #b) 3) (strxr z x y 3)))", 70 | "(change (strfr z x (lslxi y #a) #b) (!only-if (!eq
(!add #a #b) 2) (strfr z x y 2)))", 71 | }; 72 | 73 | void InstCombine::run() { 74 | auto funcs = collectFuncs(); 75 | int folded; 76 | do { 77 | folded = 0; 78 | for (auto func : funcs) { 79 | auto region = func->getRegion(); 80 | 81 | for (auto bb : region->getBlocks()) { 82 | auto ops = bb->getOps(); 83 | for (auto op : ops) { 84 | for (auto &rule : rules) { 85 | bool success = rule.rewrite(op); 86 | if (success) { 87 | folded++; 88 | break; 89 | } 90 | } 91 | } 92 | } 93 | } 94 | 95 | combined += folded; 96 | } while (folded); 97 | } 98 | -------------------------------------------------------------------------------- /src/opt/Alias.cpp: -------------------------------------------------------------------------------- 1 | #include "Analysis.h" 2 | 3 | using namespace sys; 4 | 5 | static void postorder(BasicBlock *current, DomTree &tree, std::vector &order) { 6 | for (auto bb : tree[current]) 7 | postorder(bb, tree, order); 8 | order.push_back(current); 9 | } 10 | 11 | void Alias::runImpl(Region *region) { 12 | // Run local analysis over RPO of the dominator tree. 13 | 14 | // First calculate RPO. 15 | DomTree tree = getDomTree(region); 16 | 17 | BasicBlock *entry = region->getFirstBlock(); 18 | std::vector rpo; 19 | postorder(entry, tree, rpo); 20 | std::reverse(rpo.begin(), rpo.end()); 21 | 22 | // Then traverse the CFG in that order. 23 | // This should guarantee definition comes before all uses. 24 | for (auto bb : rpo) { 25 | for (auto op : bb->getOps()) { 26 | if (isa(op)) { 27 | op->remove(); 28 | op->add(op, 0); 29 | continue; 30 | } 31 | 32 | if (isa(op)) { 33 | op->remove(); 34 | op->add(gMap[NAME(op)], 0); 35 | continue; 36 | } 37 | 38 | if (isa(op)) { 39 | op->remove(); 40 | auto x = op->getOperand(0).defining; 41 | auto y = op->getOperand(1).defining; 42 | if (!x->has() && !y->has()) { 43 | op->add(/*unknown*/); 44 | continue; 45 | } 46 | 47 | if (!x->has()) 48 | std::swap(x, y); 49 | 50 | // Now `x` is the address and `y` is the offset.
51 | // Note this swap won't affect the original op. 52 | auto alias = ALIAS(x)->clone(); 53 | if (isa(y)) { 54 | auto delta = V(y); 55 | for (auto &[_, offset] : alias->location) { 56 | for (auto &value : offset) { 57 | if (value != -1) 58 | value += delta; 59 | } 60 | } 61 | } else { 62 | // Unknown offset. Set all offsets to -1. 63 | for (auto &[_, offset] : alias->location) 64 | offset = { -1 }; 65 | } 66 | 67 | if (ALIAS(x)->unknown) 68 | op->add(/*unknown*/); 69 | else 70 | op->add(alias->location); 71 | delete alias; 72 | continue; 73 | } 74 | } 75 | } 76 | } 77 | 78 | // This has better precision after Mem2Reg, because less `int**` is possible. 79 | // Before Mem2Reg, we can store the address of an array in an alloca. 80 | // (Though it won't be fully eliminated after the pass; see 66_exgcd.sy) 81 | // Moreover, it's more useful when all unnecessary alloca's have been removed. 82 | // 83 | // In addition, remember to update the information after Globalize and Localize. 84 | void Alias::run() { 85 | auto funcs = collectFuncs(); 86 | gMap = getGlobalMap(); 87 | 88 | for (auto func : funcs) 89 | runImpl(func->getRegion()); 90 | 91 | // Now do a dataflow analysis on call graph. 92 | auto fnMap = getFunctionMap(); 93 | std::vector worklist; 94 | for (auto [_, v] : fnMap) 95 | worklist.push_back(v); 96 | 97 | while (!worklist.empty()) { 98 | auto func = worklist.back(); 99 | worklist.pop_back(); 100 | 101 | // Update local alias info. 102 | runImpl(func->getRegion()); 103 | 104 | // Find all CallOps from the function. 105 | auto calls = func->findAll(); 106 | 107 | for (auto call : calls) { 108 | const auto &name = NAME(call); 109 | bool changed = false; 110 | 111 | // Propagate alias info for each argument.
112 | if (isExtern(name)) 113 | continue; 114 | 115 | runRewriter(fnMap[name], [&](GetArgOp *op) { 116 | int index = V(op); 117 | auto def = call->getOperand(index).defining; 118 | if (!def->has()) 119 | return false; 120 | 121 | auto defLoc = ALIAS(def); 122 | if (op->has()) 123 | changed |= ALIAS(op)->addAll(defLoc); 124 | else { 125 | op->add(*defLoc); 126 | changed = true; 127 | } 128 | 129 | // Do it only once. 130 | return false; 131 | }); 132 | 133 | if (changed) 134 | worklist.push_back(fnMap[name]); 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/utils/smt/SMT.h: -------------------------------------------------------------------------------- 1 | #ifndef SMT_H 2 | #define SMT_H 3 | 4 | #include "BvExpr.h" 5 | #include "CDCL.h" 6 | #include "../../main/Options.h" 7 | #include 8 | #include 9 | 10 | namespace smt { 11 | 12 | // Bitvector[0] is the least significant bit. 13 | using Bitvector = std::vector; 14 | 15 | class BvSolver { 16 | using Clause = std::vector; 17 | 18 | SATContext ctx; 19 | Solver solver; 20 | std::unordered_map bindings; 21 | std::unordered_map cache; 22 | std::vector reserved; 23 | std::vector assignments; 24 | 25 | // These are literals false and true in SAT solver. 26 | Variable _false; 27 | Variable _true; 28 | 29 | void reserve(const Clause &clause) { reserved.push_back(clause); } 30 | 31 | // This means that `out` is `a op b`. 32 | void addAnd(Variable out, Variable a, Variable b); 33 | void addOr (Variable out, Variable a, Variable b); 34 | void addXor(Variable out, Variable a, Variable b); 35 | void addNot(Variable out, Variable a); 36 | 37 | // Combined operations. 38 | 39 | // a & !b 40 | void addAndNot(Variable out, Variable a, Variable b); 41 | // !(a ^ b) 42 | void addXnor (Variable out, Variable a, Variable b); 43 | 44 | // These blast functions will add clauses to solver.
45 | Bitvector blastConst(int vi); 46 | Bitvector blastVar(const std::string &name); 47 | 48 | // Add 32-bit numbers. 49 | Bitvector blastAdd(const Bitvector &a, const Bitvector &b, bool withCin = false); 50 | 51 | // Subtract with borrow bit `c[n]`. 52 | Bitvector blastSubBorrowed(const Bitvector &a, const Bitvector &b, Variable borrow); 53 | 54 | // Bitwise operations. 55 | Bitvector blastAnd(const Bitvector &a, const Bitvector &b); 56 | Bitvector blastOr(const Bitvector &a, const Bitvector &b); 57 | Bitvector blastXor(const Bitvector &a, const Bitvector &b); 58 | Bitvector blastNot(const Bitvector &a); 59 | 60 | // Left shift by constant. 61 | Bitvector blastLsh(const Bitvector &a, int x); 62 | 63 | // Absolute value. 64 | Bitvector blastAbs(const Bitvector &a); 65 | // Minus. 66 | Bitvector blastMinus(const Bitvector &a); 67 | 68 | // This gives a 64-bit long vector. 69 | Bitvector blastFullMul(const Bitvector &a, const Bitvector &b); 70 | // This gives a 64-bit long vector, and performs signed multiplication. 71 | Bitvector blastFullSMul(const Bitvector &a, const Bitvector &b); 72 | // Multiplies 64-bit vectors and get lower 64-bits. 73 | Bitvector blastFullLMul(const Bitvector &a, const Bitvector &b); 74 | // Multiplies 64-bit vectors and get lower 64-bits, and performs signed multiplication. 75 | Bitvector blastFullSLMul(const Bitvector &a, const Bitvector &b); 76 | 77 | // Unsigned division. 78 | Bitvector blastDiv(const Bitvector &a, const Bitvector &b); 79 | // Signed division. 80 | Bitvector blastSDiv(const Bitvector &a, const Bitvector &b); 81 | 82 | // This gives a full multiplication and then modulus constant x. 83 | // When `x` is zero, this modulus is 2^32, i.e. take the least significant 32 bits. 84 | Bitvector blastMulMod(const Bitvector &a, const Bitvector &b, int x); 85 | 86 | // If-then-else.
87 | Bitvector blastIte(Variable c, const Bitvector &a, const Bitvector &b); 88 | 89 | void blastEq(const Bitvector &a, const Bitvector &b); 90 | void blastNe(const Bitvector &a, const Bitvector &b); 91 | 92 | // Blast operators that have a value. 93 | Bitvector blastOp(BvExpr *expr); 94 | // Blast operators that don't have a value. This means it's top-level. 95 | void blast(BvExpr *expr); 96 | // Blast boolean-valued operators. 97 | Variable blastCond(BvExpr *expr); 98 | 99 | int eval(BvExpr *expr); 100 | public: 101 | using Model = std::unordered_map; 102 | 103 | BvSolver(); 104 | BvSolver(const sys::Options &opts); 105 | sys::Options opts; 106 | 107 | bool infer(BvExpr *expr); 108 | int extract(const std::string &name); 109 | bool has(const std::string &name) { return bindings.count(name); } 110 | int eval(BvExpr *expr, const std::unordered_map &external); 111 | void assign(const std::string &name, int value); 112 | void unassign() { bindings.clear(); } 113 | Model model(); 114 | }; 115 | 116 | BvExpr *simplify(BvExpr *expr, BvExprContext &ctx); 117 | 118 | } 119 | 120 | #endif 121 | -------------------------------------------------------------------------------- /fuzzer.py: -------------------------------------------------------------------------------- 1 | #!/bin/python3 2 | import io 3 | import random as r 4 | import collections as c 5 | 6 | Type = c.namedtuple("Type", ["base", "size"]) 7 | 8 | class Symbols: 9 | def __init__(self): 10 | self.arrs: dict[str, Type] = {} # Maps name to size.
11 | self.const_arrs: dict[str, Type] = {} 12 | self.vars: dict[str, Type] = {} 13 | self.const_vars: dict[str, Type] = {} 14 | 15 | def reset(self): 16 | self.arrs = {} 17 | self.const_arrs = {} 18 | self.vars = {} 19 | self.const_vars = {} 20 | 21 | 22 | table = Symbols() 23 | f: io.TextIOWrapper = None 24 | cnt = 0 25 | 26 | def rand() -> bool: 27 | return r.choice([True, False]) 28 | 29 | def randx(x: float): 30 | return r.random() < x 31 | 32 | def basety() -> str: 33 | return r.choice(["int", "float"]) 34 | 35 | def name() -> str: 36 | global cnt 37 | cnt += 1 38 | return f"v{cnt}" 39 | 40 | def vars(const: bool) -> set[str]: 41 | return table.vars.keys() if not const else table.const_vars.keys() 42 | 43 | def arrs(const: bool) -> set[str]: 44 | return table.arrs.keys() if not const else table.const_arrs.keys() 45 | 46 | def gen_var(const: bool): 47 | var = r.choice(list(vars(const))) 48 | f.write(var) 49 | 50 | def gen_arr_access(const: bool): 51 | arr = r.choice(list(arrs(const))) 52 | size = (table.const_arrs if const else table.arrs)[arr].size 53 | f.write(f"{arr}[{r.randint(0, size - 1)}]") 54 | 55 | def gen_expr_helper(const: bool, remain: int, ops: str): 56 | if remain == 0: 57 | if randx(0.4) and len(arrs(const)): 58 | gen_arr_access(const) 59 | elif randx(0.571428) and len(vars(const)): 60 | gen_var(const) 61 | else: 62 | f.write(str(r.randint(-17, 60))) 63 | return 64 | 65 | if randx(0.1): 66 | op = r.choice(["!", "-"]) # Unary 67 | par = randx(0.2) 68 | 69 | f.write(op) 70 | if par: 71 | f.write('(') 72 | gen_expr_helper(const, remain - 1, ops) 73 | if par: 74 | f.write(')') 75 | return; 76 | 77 | op = r.choice(ops) 78 | parl = randx(0.2) 79 | parr = randx(0.2) 80 | if parl: 81 | f.write('(') 82 | gen_expr_helper(const, remain - 1, ops) 83 | if parl: 84 | f.write(')') 85 | f.write(f" {op} ") 86 | if parr: 87 | f.write('(') 88 | gen_expr_helper(const, remain - 1, ops) 89 | if parr: 90 | f.write(')') 91 | 92 | 93 | def gen_expr(const: bool, **kwargs): 94 | depth = r.randint(1, 5) if
kwargs.get("depth") is None else kwargs["depth"] 95 | ops = ["+", "-", "*", "/", "%"] if kwargs.get("compare") is None else \ 96 | ["+", "-", "*", "/", "%", "!=", "==", ">=", "<=", "<", ">", "&&", "||"] 97 | gen_expr_helper(const, depth, ops) 98 | 99 | def gen_arr(const: bool): 100 | ty = basety() 101 | n = name() 102 | size = r.randint(1, 256) 103 | f.write(f"{ty} {n}[{size}]") 104 | if rand(): 105 | f.write(" = {") 106 | gen_expr(const, depth=r.randint(1, 2)) 107 | for _ in range(1, r.randint(1, size)): 108 | f.write(", ") 109 | gen_expr(const, depth=r.randint(1, 2)) 110 | f.write("}") 111 | 112 | f.write(";\n") 113 | table.arrs[n] = Type(ty, size) 114 | 115 | def gen_v(const: bool): 116 | ty = basety() 117 | n = name() 118 | f.write(f"{ty} {n}") 119 | if rand(): 120 | f.write(" = ") 121 | gen_expr(const) 122 | 123 | f.write(";\n") 124 | table.vars[n] = Type(ty, None) 125 | 126 | def gen_var_assign(): 127 | gen_var(False) 128 | f.write(" = ") 129 | gen_expr(False) 130 | f.write(";\n") 131 | 132 | def gen_arr_assign(): 133 | gen_arr_access(False) 134 | f.write(" = ") 135 | gen_expr(False) 136 | f.write(";\n") 137 | 138 | def gen_body(): 139 | for _ in range(r.randint(12, 60)): 140 | f.write(" ") 141 | x = r.random() 142 | if x < 0.05: 143 | gen_arr(False) 144 | elif x < 0.2: 145 | gen_v(False) 146 | elif x < 0.7 and len(vars(False)): 147 | gen_var_assign() 148 | elif x < 0.8: 149 | f.write("putint(") 150 | gen_expr(False) 151 | f.write(");\n") 152 | elif len(arrs(False)): 153 | gen_arr_assign() 154 | else: 155 | f.write(";\n") # Empty statement 156 | 157 | def gen_program(): 158 | global table 159 | 160 | for _ in range(r.randint(1, 3)): 161 | r.choice([gen_arr, gen_v])(True) 162 | 163 | 164 | f.write("int main() {\n") 165 | gen_body() 166 | f.write(" return 0;\n}") 167 | 168 | if __name__ == "__main__": 169 | for i in range(0, 3): 170 | cnt = 0 171 | table.reset() 172 | 173 | f = open(f"test/fuzz/{i}.sy", "w") 174 | gen_program() 175 | f.close() 176 | 
-------------------------------------------------------------------------------- /src/rv/RvOps.h: -------------------------------------------------------------------------------- 1 | #ifndef RV_OPS_H 2 | #define RV_OPS_H 3 | 4 | #include "../codegen/Ops.h" 5 | 6 | // Don't forget that we actually rely on OpID, and __LINE__ can duplicate with codegen/Ops.h. 7 | #define RVOPBASE(ValueTy, Ty) \ 8 | class Ty : public OpImpl { \ 9 | public: \ 10 | explicit Ty(const std::vector &values): OpImpl(ValueTy, values) { \ 11 | setName("rv."#Ty); \ 12 | } \ 13 | Ty(): OpImpl(ValueTy, {}) { \ 14 | setName("rv."#Ty); \ 15 | } \ 16 | explicit Ty(const std::vector &attrs): OpImpl(ValueTy, {}, attrs) { \ 17 | setName("rv."#Ty); \ 18 | } \ 19 | Ty(const std::vector &values, const std::vector &attrs): OpImpl(ValueTy, values, attrs) { \ 20 | setName("rv."#Ty); \ 21 | } \ 22 | } 23 | 24 | // Ops that must be explicitly set a result type. 25 | #define RVOPE(Ty) \ 26 | class Ty : public OpImpl { \ 27 | public: \ 28 | Ty(Value::Type resultTy, const std::vector &values): OpImpl(resultTy, values) { \ 29 | setName("rv."#Ty); \ 30 | } \ 31 | explicit Ty(Value::Type resultTy): OpImpl(resultTy, {}) { \ 32 | setName("rv."#Ty); \ 33 | } \ 34 | Ty(Value::Type resultTy, const std::vector &attrs): OpImpl(resultTy, {}, attrs) { \ 35 | setName("rv."#Ty); \ 36 | } \ 37 | Ty(Value::Type resultTy, const std::vector &values, const std::vector &attrs): OpImpl(resultTy, values, attrs) { \ 38 | setName("rv."#Ty); \ 39 | } \ 40 | } 41 | 42 | #define RVOP(Ty) RVOPBASE(Value::i32, Ty) 43 | #define RVOPL(Ty) RVOPBASE(Value::i64, Ty) 44 | #define RVOPF(Ty) RVOPBASE(Value::f32, Ty) 45 | 46 | namespace sys { 47 | 48 | namespace rv { 49 | 50 | // To add an op: 51 | // 1) Check RegAlloc.cpp, the list of LOWER(...) 52 | // 2) Check the function hasRd(). 53 | 54 | RVOP(LiOp); 55 | RVOPL(LaOp); 56 | RVOPL(AddOp); 57 | RVOP(AddwOp); 58 | RVOP(AddiwOp); 59 | RVOPL(AddiOp); // Note that pointers can't be `addiw`'d. 
60 | RVOPL(SubOp); 61 | RVOP(SubwOp); 62 | RVOP(MulwOp); 63 | RVOPL(MulOp); 64 | RVOP(DivwOp); // Signed; divu for unsigned. 65 | RVOPL(DivOp); 66 | RVOP(RemwOp); 67 | RVOPL(RemOp); 68 | RVOP(SlliwOp); // Shift left. 69 | RVOPL(SlliOp); // Shift left (64 bit). 70 | RVOP(SrliwOp); // Shift right, unsigned. 71 | RVOPL(SrliOp); // Shift right (64 bit), unsigned. 72 | RVOP(SraiwOp); // Shift right, signed. 73 | RVOPL(SraiOp); // Shift right (64 bit), signed. 74 | RVOP(SllwOp); // Shift left. 75 | RVOPL(SllOp); // Shift left (64 bit). 76 | RVOP(SrlwOp); // Shift right, unsigned. 77 | RVOP(SrlOp); // Shift right (64 bit), unsigned. 78 | RVOP(SrawOp); // Shift right, signed. 79 | RVOPL(SraOp); // Shift right (64 bit), signed. 80 | RVOP(MulhOp); // Higher bits of mul, signed. 81 | RVOP(MulhuOp); // Higher bits of mul, unsigned. 82 | RVOP(AndOp); 83 | RVOP(OrOp); 84 | RVOP(XorOp); 85 | RVOP(AndiOp); 86 | RVOP(OriOp); 87 | RVOP(XoriOp); 88 | RVOP(BneOp); 89 | RVOP(BeqOp); 90 | RVOP(BltOp); 91 | RVOP(BgeOp); 92 | RVOP(BleOp); 93 | RVOP(BgtOp); 94 | RVOP(SeqzOp); // Set equal to zero (pseudo, = sltiu) 95 | RVOP(SnezOp); // Set not equal to zero (pseudo, 2 ops) 96 | RVOP(SltOp); // Set less than 97 | RVOP(SltiOp); 98 | RVOP(JOp); 99 | RVOP(MvOp); 100 | RVOP(RetOp); 101 | RVOPE(LoadOp); 102 | RVOP(StoreOp); 103 | RVOP(SubSpOp); // Allocate stack space: sub sp, sp, 104 | RVOPE(ReadRegOp); // Read from real register 105 | RVOP(WriteRegOp); // Write to real register; the SSA value is used and pre-colored in RegAlloc. 106 | RVOP(CallOp); 107 | RVOP(PlaceHolderOp); // See regalloc; holds a place to denote a register isn't available. 108 | RVOPF(FcvtswOp); // i32 -> f32 109 | RVOP(FcvtwsRtzOp); // f32 -> i32, round to zero 110 | RVOPF(FmvwxOp); // copies bit pattern from i32 to f32 111 | RVOPF(FmvdxOp); 112 | RVOPL(FmvxdOp); 113 | RVOP(FldOp); // These are only used in stack save/restore. 114 | RVOP(FsdOp); 115 | RVOP(FeqOp); // Note these Ops must have been added with a `.s` in Dump. 
116 | RVOP(FltOp); 117 | RVOP(FleOp); 118 | RVOPF(FaddOp); 119 | RVOPF(FsubOp); 120 | RVOPF(FmulOp); 121 | RVOPF(FdivOp); 122 | RVOPF(FmvOp); 123 | 124 | inline bool hasRd(Op *op) { 125 | return !( 126 | isa(op) || 127 | isa(op) || 128 | isa(op) || 129 | isa(op) || 130 | isa(op) || 131 | isa(op) || 132 | isa(op) || 133 | isa(op) || 134 | isa(op) || 135 | isa(op) || 136 | isa(op) 137 | ); 138 | } 139 | 140 | } 141 | 142 | } 143 | 144 | #undef RVOP 145 | 146 | #endif 147 | -------------------------------------------------------------------------------- /src/opt/DAE.cpp: -------------------------------------------------------------------------------- 1 | #include "CleanupPasses.h" 2 | 3 | using namespace sys; 4 | 5 | std::map DAE::stats() { 6 | return { 7 | { "removed-arguments", elim }, 8 | { "removed-return-values", elimRet }, 9 | }; 10 | } 11 | 12 | void DAE::run() { 13 | // Try to infer constants at each call site. 14 | auto calls = module->findAll(); 15 | auto fnMap = getFunctionMap(); 16 | 17 | // value[fn] gives a map `m`, which: 18 | // m[x] = y iff. all calls to `fn` gives `y` to the `x`th argument (counting from 0). 19 | std::map> value; 20 | std::map> forbidden; 21 | // All functions whose return value has been used. 22 | std::set resultUsed; 23 | 24 | for (auto call : calls) { 25 | FuncOp *fn = fnMap[NAME(call)]; 26 | const auto &operands = call->getOperands(); 27 | 28 | if (call->getUses().size() > 0) 29 | resultUsed.insert(fn); 30 | 31 | for (size_t i = 0; i < operands.size(); i++) { 32 | auto operand = operands[i]; 33 | auto def = operand.defining; 34 | 35 | if (isa(def)) { 36 | if (value[fn].count(i) && value[fn][i] == V(def)) 37 | continue; 38 | 39 | // This means the value isn't always the same across all call sites. 
40 | if (value[fn].count(i)) { 41 | value[fn].erase(i); 42 | forbidden[fn].insert(i); 43 | continue; 44 | } 45 | 46 | if (forbidden[fn].count(i)) 47 | continue; 48 | 49 | value[fn][i] = V(def); 50 | continue; 51 | } 52 | 53 | // This is not an integer. 54 | value[fn].erase(i); 55 | forbidden[fn].insert(i); 56 | } 57 | } 58 | 59 | // Replace GetArgOp for each of the argument in `value`. 60 | Builder builder; 61 | 62 | auto funcs = collectFuncs(); 63 | for (auto func : funcs) { 64 | std::set> toRemove; 65 | std::set visited; 66 | int &argcnt = func->get()->count; 67 | 68 | auto getargs = func->findAll(); 69 | const auto &invariant = value[func]; 70 | 71 | for (auto getarg : getargs) { 72 | int index = V(getarg); 73 | visited.insert(index); 74 | 75 | if (invariant.count(index)) { 76 | builder.replace(getarg, { new IntAttr(invariant.at(index)) }); 77 | toRemove.insert(index); 78 | continue; 79 | } 80 | 81 | // Also check for AliasAttr. 82 | // If it always alias the same global with offset 0, then just replace it with GetGlobal. 83 | if (auto alias = getarg->find()) { 84 | if (alias->location.size() != 1) 85 | continue; 86 | 87 | const auto &[base, offsets] = *alias->location.begin(); 88 | if (offsets.size() == 1 && offsets[0] == 0 && isa(base)) { 89 | builder.replace(getarg, { new NameAttr(NAME(base)) }); 90 | toRemove.insert(index); 91 | continue; 92 | } 93 | } 94 | } 95 | 96 | // Remove those completely unused arguments. 97 | for (int i = 0; i < argcnt; i++) { 98 | if (!visited.count(i)) 99 | toRemove.insert(i); 100 | } 101 | 102 | if (!toRemove.size()) 103 | continue; 104 | 105 | // Decrease argument count of the function. 106 | argcnt -= toRemove.size(); 107 | 108 | // For all remaining getargs, decrease their count as well. 109 | getargs = func->findAll(); 110 | for (auto index : toRemove) { 111 | // As toRemove is in descending order, this should be correct. 
112 | for (auto getarg : getargs) { 113 | if (V(getarg) > index) 114 | V(getarg)--; 115 | } 116 | } 117 | 118 | // For all calls, replace their operands. 119 | const auto &name = NAME(func); 120 | for (auto call : calls) { 121 | if (NAME(call) != name) 122 | continue; 123 | 124 | auto operands = call->getOperands(); 125 | call->removeAllOperands(); 126 | for (size_t i = 0; i < operands.size(); i++) { 127 | if (toRemove.count(i)) 128 | continue; 129 | 130 | call->pushOperand(operands[i]); 131 | } 132 | } 133 | 134 | elim += toRemove.size(); 135 | 136 | // Now let's move return values. 137 | // If the function's return value is never used, then no need to return. 138 | // (We can't remove `return xx` in main even though it doesn't seem used.) 139 | if (resultUsed.count(func) || NAME(func) == "main") 140 | continue; 141 | 142 | auto rets = func->findAll(); 143 | for (auto ret : rets) 144 | ret->removeAllOperands(); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/pre-opt/PreLoopPasses.h: -------------------------------------------------------------------------------- 1 | #ifndef PRE_LOOP_PASSES_H 2 | #define PRE_LOOP_PASSES_H 3 | 4 | #include "../opt/Pass.h" 5 | #include "../codegen/CodeGen.h" 6 | #include "../codegen/Ops.h" 7 | #include "../codegen/Attrs.h" 8 | #include "PreAttrs.h" 9 | #include 10 | 11 | namespace sys { 12 | 13 | // Raise whiles to fors whenever possible. 14 | class RaiseToFor : public Pass { 15 | int raised = 0; 16 | public: 17 | RaiseToFor(ModuleOp *module): Pass(module) {} 18 | 19 | std::string name() override { return "raise-to-for"; } 20 | std::map stats() override; 21 | void run() override; 22 | }; 23 | 24 | // Determine whether a const array is a view of another. 25 | // In that case, inline it. 
26 | class View : public Pass { 27 | int inlined = 0; 28 | 29 | std::unordered_map> usedIn; 30 | void runImpl(Op *func); 31 | public: 32 | View(ModuleOp *module): Pass(module) {} 33 | 34 | std::string name() override { return "view"; } 35 | std::map stats() override; 36 | void run() override; 37 | }; 38 | 39 | // Erase useless loops. 40 | class LoopDCE : public Pass { 41 | int erased = 0; 42 | public: 43 | LoopDCE(ModuleOp *module): Pass(module) {} 44 | 45 | std::string name() override { return "loop-dce"; } 46 | std::map stats() override; 47 | void run() override; 48 | }; 49 | 50 | // Loop fusion. 51 | class Fusion : public Pass { 52 | int fused = 0; 53 | 54 | void runImpl(FuncOp *func); 55 | public: 56 | Fusion(ModuleOp *module): Pass(module) {} 57 | 58 | std::string name() override { return "fusion"; } 59 | std::map stats() override; 60 | void run() override; 61 | }; 62 | 63 | // Loop unswitch. 64 | // Unswitch branches related to induction variable. 65 | class Unswitch : public Pass { 66 | int unswitched = 0; 67 | 68 | bool runImpl(Op *loop); 69 | bool cmpmod(Op *loop, Op *cond); 70 | bool ltconst(Op *loop, Op *cond); 71 | bool gtconst(Op *loop, Op *cond); 72 | bool invariant(Op *loop, Op *cond); 73 | public: 74 | Unswitch(ModuleOp *module): Pass(module) {} 75 | 76 | std::string name() override { return "unswitch"; } 77 | std::map stats() override; 78 | void run() override; 79 | }; 80 | 81 | // For 2D arrays, if their access pattern suggests so, 82 | // then reorder them into column major format. 83 | class ColumnMajor : public Pass { 84 | struct AccessData { 85 | int depth = 0; 86 | // `true` when there exists a loop such that this array's last dimension is 87 | // contiguously accessed on the deepest dimension; 88 | // i.e. it's the smallest entry in SubscriptAttr. 89 | bool contiguous = false; 90 | // `true` when there exists a loop such that this array's last dimension is 91 | // NOT contiguously accessed on the deepest dimension. 
92 | bool jumping = false; 93 | // `false` when some accesses to the array do not have SubscriptAttr. 94 | bool valid = true; 95 | }; 96 | std::unordered_map data; 97 | 98 | void collectDepth(Region *region, int depth); 99 | public: 100 | ColumnMajor(ModuleOp *module): Pass(module) {} 101 | 102 | std::string name() override { return "column-major"; } 103 | std::map stats() override { return {}; } 104 | void run() override; 105 | }; 106 | 107 | class Parallelize : public Pass { 108 | public: 109 | Parallelize(ModuleOp *module): Pass(module) {} 110 | 111 | std::string name() override { return "parallelize"; } 112 | std::map stats() override { return {}; } 113 | void run() override; 114 | }; 115 | 116 | class Unroll : public Pass { 117 | int unrolled = 0; 118 | public: 119 | Unroll(ModuleOp *module): Pass(module) {} 120 | 121 | std::string name() override { return "unroll"; } 122 | std::map stats() override; 123 | void run() override; 124 | }; 125 | 126 | // Try to hoist out access to adjacent elements of an array out of a loop. 127 | // This is done by GVN in LLVM, but I have no idea how it works there. 128 | class Adjacency : public Pass { 129 | void runImpl(Op *loop); 130 | public: 131 | Adjacency(ModuleOp *module): Pass(module) {} 132 | 133 | std::string name() override { return "adjacency"; } 134 | std::map stats() override { return {}; } 135 | void run() override; 136 | }; 137 | 138 | // Lower operations back to its original form. 
139 | class Lower : public Pass { 140 | public: 141 | Lower(ModuleOp *module): Pass(module) {} 142 | 143 | std::string name() override { return "lower"; } 144 | std::map stats() override { return {}; } 145 | void run() override; 146 | }; 147 | 148 | } 149 | 150 | #endif 151 | -------------------------------------------------------------------------------- /src/utils/smt/CDCL.h: -------------------------------------------------------------------------------- 1 | #ifndef CLAUSE_H 2 | #define CLAUSE_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace smt { 11 | 12 | // Atomic propositions. 13 | // Encoding: 2x means the variable `x`, 2x+1 means the negation of it. 14 | // We need to index vectors with literals, so we can't use the obvious `-x`. 15 | using Variable = int; 16 | using Atomic = int; 17 | 18 | // A clause is a logical disjuntion of formulae. 19 | struct Clause { 20 | std::vector content; 21 | Atomic watch[2]; 22 | void dump(std::ostream &os = std::cerr) const; 23 | }; 24 | 25 | // https://www.cs.princeton.edu/~zkincaid/courses/fall18/readings/SATHandbook-CDCL.pdf 26 | // See this book for CDCL. 27 | class Solver { 28 | using Boolean = signed char; 29 | constexpr static Boolean False = -1; 30 | constexpr static Boolean True = 1; 31 | constexpr static Boolean Unassigned = 0; 32 | 33 | std::vector clauses; 34 | 35 | // We use -1 for false, 1 for true, 0 for unassigned. Note the difference from bools in C++. 36 | // Also note std::vector is not a vector of bools. 37 | std::vector assignment; 38 | std::vector antecedent; 39 | std::vector decisionLevel; 40 | 41 | // All assigned variable in order. It includes decision variables and variables updated by implication. 42 | std::vector trail; 43 | // trailLim[level] contains the index `i` of `trail` s.t. forall j >= i, decisionLevel[trail[j]] >= level. 44 | // In human-readable words, trail_lim[level] = index in trail where that level begins. 
45 | std::vector trailLim; 46 | 47 | // Every clause watches 2 literals. See Chaff's paper: 48 | // https://www.princeton.edu/~chaff/publication/DAC2001v56.pdf 49 | // Here watched[i] contains all clauses that watch on literal `i`. 50 | // 51 | // The basic idea is, in unit propagation, we only care about the clauses (of size N) that have N-1 falses; 52 | // In other words, we would only want to be notified when an assignment might cause the false-count goes 53 | // from N-2 to N-1. 54 | // To approximately do this, we can pick 2 non-false literals at each clause. We don't need to care about any 55 | // updates to the clause if they do not touch those literals; only when one of them gets assigned to false, 56 | // it would become possible that the clause now have N-1 falses. 57 | std::vector> watched; 58 | 59 | // Current decision level. 60 | int dl = 0; 61 | // Total variable count. 62 | int varcnt = -1; 63 | // Conflict count since last decay. 64 | int conflict = 0; 65 | // The current place we should start dealing with unit propagation. 66 | size_t qhead = 0; 67 | 68 | // Set to true when `addClause` detects a conflict. 69 | bool addedConflict = false; 70 | // Set to true when unit propagation detects a conflict. 71 | bool unitConflict = false; 72 | 73 | Clause *conflictClause; 74 | 75 | // Activity scores for each variable. 76 | std::vector activity; 77 | // Heap of variables. 78 | std::priority_queue> vheap; 79 | // The phase (the preferred value of exploration) of variables. 80 | std::vector phase; 81 | 82 | // Increment of activity. 83 | double inc = 1.0; 84 | // Decay. 85 | static constexpr double decay = 0.95; 86 | 87 | Boolean value(Atomic atom); 88 | 89 | void unitPropagation(); 90 | void backtrack(int level); 91 | 92 | // Find the only unassigned variable within a clause. 93 | Atomic findUnit(const std::vector &unit); 94 | 95 | bool allAssigned(); 96 | 97 | // Returns the backtrack level and learnt clause. 
98 | std::pair> analyzeConflict(); 99 | // Pick up a variable and a value to assign. 100 | std::pair pickPivot(); 101 | 102 | // Push `atom` to `trail` for further unit propagation, where `atom` must be true. 103 | void enqueue(Atomic atom, Clause *antecedent); 104 | 105 | Clause *addLearnt(std::vector clause); 106 | public: 107 | void dump(std::ostream &os = std::cerr) const; 108 | 109 | // Add external clause. Does extra tidying. 110 | Clause *addClause(const std::vector &clause); 111 | // Returns true if satisfiable. 112 | bool solve(std::vector &assignments); 113 | 114 | void init(int varcnt); 115 | }; 116 | 117 | class SATContext { 118 | int total = 0; 119 | public: 120 | int getTotal() { return total; } 121 | 122 | Variable create() { return total++; } 123 | Atomic neg(Variable x) { return (x << 1) + 1; } 124 | Atomic pos(Variable x) { return (x << 1); } 125 | void reset() { total = 0; } 126 | }; 127 | 128 | inline std::ostream &operator<<(std::ostream &os, const Clause *clause) { 129 | clause->dump(os); 130 | return os; 131 | } 132 | 133 | inline std::ostream &operator<<(std::ostream &os, const Solver &solver) { 134 | solver.dump(os); 135 | return os; 136 | } 137 | 138 | } 139 | 140 | #endif 141 | -------------------------------------------------------------------------------- /src/utils/presburger/BasicSet.cpp: -------------------------------------------------------------------------------- 1 | #include "BasicSet.h" 2 | #include 3 | #include 4 | #include 5 | 6 | using namespace pres; 7 | 8 | // See this tutorial of dual simplex; I wrote it. 9 | // https://gbc.xq.gl/polyhedral/maths/dual-simplex/ 10 | bool BasicSet::empty() { 11 | if (!denom.size()) 12 | denom = std::vector(tableau.size(), 1); 13 | 14 | for (;;) { 15 | // Choose a variable to evict from base. 16 | // We typically choose the x_r with minimal b_r. 
17 | auto row = -1; 18 | auto min = 0; 19 | for (int i = 0; i < tableau.size(); ++i) { 20 | if (tableau[i].back() < min) { 21 | row = i; 22 | min = tableau[i].back(); 23 | } 24 | } 25 | 26 | // Now every element in `b` is non-negative. 27 | // We've found a feasible solution. 28 | if (row == -1) 29 | return false; 30 | 31 | // We need to find a variable to put autoo the base, whose value has to be positive. 32 | // This means we must make sure A_{rj} < 0. 33 | // We typically choose the A_j with minimal c_j / A_{rj}; but we don't have a target `c` here. 34 | // Instead, we'll choose the first one encountered. 35 | auto col = -1; 36 | // Don't take `b_r` autoo account. That column (1) can't be pivoted. 37 | for (int j = 0; j < tableau[row].size() - 1; ++j) { 38 | if (tableau[row][j] < 0) { 39 | col = j; 40 | break; 41 | } 42 | } 43 | 44 | // No valid pivot. Infeasible. 45 | if (col == -1) 46 | return true; 47 | 48 | // Pivot. 49 | auto pivot = tableau[row][col]; 50 | auto width = tableau[row].size(); 51 | 52 | // Normalize the pivot row. 53 | // We would write: 54 | // for (auto &x : tableau[row]) 55 | // x /= pivot; 56 | // 57 | // However, currently the pivot value is actually `pivot / denom[row]`, 58 | // so we'd multiply the denominator by `pivot` and everything else by denom[row]. 59 | // (We've had gcd(tableau[row], pivot) == 1 after each pivot step, so no simplification possible here.) 60 | for (auto &x : tableau[row]) 61 | x *= denom[row]; 62 | // Don't change it now. We'll need it further. 63 | denom[row] *= pivot; 64 | 65 | // Eliminate pivot column in all other rows. 
66 | for (int i = 0; i < tableau.size(); ++i) { 67 | if (i == row) 68 | continue; 69 | 70 | // We would write: 71 | // auto factor = tableau[i][col]; 72 | // for (auto j = 0; j < width; ++j) 73 | // tableau[i][j] -= factor * tableau[row][j]; 74 | // 75 | // Then we're actually having 76 | // 77 | // tableau[i][j] = (tableau[i][j] / denom[i] - tableau[i][col] / denom[i] * tableau[row][j] / denom[row]) 78 | // 79 | // To clear it up: 80 | // 81 | // tableau[i][j] = (tableau[i][j] * denom[row] - tableau[i][col] * tableau[row][j]) / (denom[i] * denom[row]) 82 | // 83 | // So we just update according to that. 84 | 85 | denom[i] *= denom[row]; 86 | for (auto j = 0; j < tableau[0].size(); ++j) 87 | tableau[i][j] = tableau[i][j] * denom[row] - tableau[i][col] * tableau[row][j]; 88 | } 89 | 90 | // Calculate GCD to make numbers smaller. 91 | for (int i = 0; i < tableau.size(); i++) { 92 | auto &row = tableau[i]; 93 | auto &de = denom[i]; 94 | 95 | auto gcd = abs(de); 96 | for (auto x : row) { 97 | if ((gcd = std::gcd(gcd, abs(x))) == 1) 98 | break; 99 | } 100 | 101 | if (gcd != 1) { 102 | de /= gcd; 103 | for (auto &x : row) 104 | x /= gcd; 105 | } 106 | } 107 | } 108 | } 109 | 110 | void BasicSet::dump(std::ostream &os) { 111 | if (tableau.empty()) { 112 | os << "\n"; 113 | return; 114 | } 115 | 116 | const int rows = tableau.size(); 117 | const int cols = tableau[0].size(); 118 | 119 | // Compute max width per column for alignment. 120 | std::vector colWidths(cols, 0); 121 | for (int j = 0; j < cols; ++j) { 122 | size_t maxWidth = 0; 123 | for (int i = 0; i < rows; ++i) { 124 | auto num = tableau[i][j]; 125 | auto den = denom[i]; 126 | auto g = std::gcd(num, den); 127 | num /= g; 128 | den /= g; 129 | std::ostringstream ss; 130 | if (den == 1) 131 | ss << num; 132 | else 133 | ss << num << "/" << den; 134 | maxWidth = std::max(maxWidth, ss.str().size()); 135 | } 136 | colWidths[j] = maxWidth; 137 | } 138 | 139 | // Print. 
140 | for (int i = 0; i < rows; ++i) { 141 | for (int j = 0; j < cols; ++j) { 142 | auto num = tableau[i][j]; 143 | auto den = denom[i]; 144 | auto g = std::gcd(num, den); 145 | num /= g; 146 | den /= g; 147 | 148 | std::ostringstream ss; 149 | if (den == 1) 150 | ss << num; 151 | else 152 | ss << num << "/" << den; 153 | 154 | os << std::setw(colWidths[j]) << ss.str() << (j + 1 < cols ? " " : ""); 155 | } 156 | os << "\n"; 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /src/opt/Cached.cpp: -------------------------------------------------------------------------------- 1 | #include "Passes.h" 2 | #include "../utils/Matcher.h" 3 | #include "../utils/Exec.h" 4 | 5 | using namespace sys; 6 | 7 | // Defined in Inline.cpp. 8 | bool isRecursive(Op *op); 9 | 10 | namespace { 11 | 12 | Rule minusone("(sub x 1)"); 13 | 14 | } 15 | 16 | void Cached::run() { 17 | // Identify candidate functions. 18 | auto funcs = collectFuncs(); 19 | for (auto func : funcs) { 20 | if (!isRecursive(func) || func->has()) 21 | continue; 22 | 23 | const auto &name = NAME(func); 24 | 25 | // Find the induction variables. 26 | auto calls = func->findAll(); 27 | for (auto call : calls) { 28 | // We're calling other functions, and it isn't generally possible to emulate. 29 | if (NAME(call) != name) 30 | return; 31 | } 32 | 33 | int argnum = func->get()->count; 34 | if (argnum > 3 || argnum <= 1) 35 | continue; 36 | 37 | // Only accept integers. 38 | auto getargs = func->findAll(); 39 | std::vector args(argnum); 40 | for (auto getarg : getargs) { 41 | if (getarg->getResultType() != Value::i32) 42 | return; 43 | args[V(getarg)] = getarg; 44 | } 45 | 46 | auto ret = func->findAll(); 47 | if (ret.empty() || ret[0]->DEF()->getResultType() != Value::i32) 48 | return; 49 | 50 | // Create a cache. 
51 | using namespace exec; 52 | Interpreter interp(module); 53 | Builder builder; 54 | builder.setToRegionStart(module->getRegion()); 55 | const auto cachename = "__cache_" + name; 56 | if (argnum == 3) { 57 | auto vi = new cache_3; 58 | interp.useCache(vi); 59 | for (int i = 0; i < CACHE_3_N; i++) { 60 | for (int j = 0; j < CACHE_3_N; j++) { 61 | for (int k = 0; k < CACHE_3_N; k++) 62 | interp.runFunction(name, { i, j, k }); 63 | } 64 | } 65 | builder.create({ new NameAttr(cachename), 66 | new IntArrayAttr((int*) vi, CACHE_3_TOTAL), 67 | new SizeAttr(CACHE_3_TOTAL * 4) 68 | }); 69 | } 70 | 71 | if (argnum == 2) { 72 | auto vi = new cache_2; 73 | interp.useCache(vi); 74 | for (int i = 0; i < CACHE_2_N; i++) { 75 | for (int j = 0; j < CACHE_2_N; j++) 76 | interp.runFunction(name, { i, j }); 77 | } 78 | builder.create({ new NameAttr(cachename), 79 | new IntArrayAttr((int*) vi, CACHE_2_TOTAL), 80 | new SizeAttr(CACHE_2_TOTAL * 4) 81 | }); 82 | } 83 | 84 | // Utilize the cache. 85 | auto region = func->getRegion(); 86 | auto first = region->getFirstBlock(); 87 | builder.setToBlockStart(first); 88 | 89 | // Find the last `getarg`. 90 | auto insert = first->getFirstOp(); 91 | const auto &ops = first->getOps(); 92 | for (auto it = ops.rbegin(); it != ops.rend(); it++) { 93 | auto op = *it; 94 | if (isa(op)) { 95 | insert = op->nextOp(); 96 | break; 97 | } 98 | } 99 | 100 | // Insert an `if` check for it. 101 | std::vector inrange; 102 | builder.setBeforeOp(insert); 103 | int size = argnum == 2 ? CACHE_2_N : CACHE_3_N; 104 | auto max = builder.create({ new IntAttr(size) }); 105 | auto zero = builder.create({ new IntAttr(0) }); 106 | for (auto arg : getargs) { 107 | auto lt = builder.create({ arg, max }); 108 | auto ge = builder.create({ zero, arg }); 109 | inrange.push_back(lt); 110 | inrange.push_back(ge); 111 | } 112 | 113 | // Chain them with `and`. 
114 | auto cond = inrange[0]; 115 | for (int i = 1; i < inrange.size(); i++) 116 | cond = builder.create({ cond, inrange[i]->getResult() }); 117 | 118 | // Create a branch to check cache hit. 119 | auto body = region->insertAfter(first); 120 | auto exit = region->insertAfter(body); 121 | first->splitOpsAfter(exit, cond); 122 | cond->moveToEnd(first); // `cond` itself will also be moved over. 123 | builder.setToBlockEnd(first); 124 | builder.create({ cond }, { new TargetAttr(body), new ElseAttr(exit) }); 125 | 126 | // On `body`, retrieve the cache content. 127 | builder.setToBlockStart(body); 128 | auto addr = builder.create({ new NameAttr(cachename) }); 129 | auto four = builder.create({ new IntAttr(4) }); 130 | if (argnum == 3) { 131 | auto stride1 = builder.create({ new IntAttr(size * size * 4) }); 132 | auto stride2 = builder.create({ new IntAttr(size * 4) }); 133 | 134 | auto mul1 = builder.create({ stride1, args[0] }); 135 | auto mul2 = builder.create({ stride2, args[1] }); 136 | auto mul3 = builder.create({ four, args[2] }); 137 | 138 | auto add1 = builder.create({ addr, mul1 }); 139 | auto add2 = builder.create({ mul2, mul3->getResult() }); 140 | auto add3 = builder.create({ add1, add2 }); 141 | 142 | auto value = builder.create(Value::i32, { add3 }, { new SizeAttr(4) }); 143 | builder.create({ value }); 144 | } 145 | 146 | if (argnum == 2) { 147 | // TODO 148 | } 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/opt/Splice.cpp: -------------------------------------------------------------------------------- 1 | #include "LoopPasses.h" 2 | #include "Analysis.h" 3 | 4 | using namespace sys; 5 | 6 | // Defined in Specialize.cpp. 7 | void removeRange(Region *region); 8 | 9 | // Identify this pattern: 10 | // %1 = phi %2 %3 11 | // ... 12 | // %2 = ... 13 | // %3 = ... 14 | // 15 | // where `a` and `b` are constants. 
16 | // This strongly suggests that the first iteration is different 17 | // and is worth splicing. 18 | void Splice::runImpl(LoopInfo *loop) { 19 | if (loop->latches.size() > 1 || !loop->preheader) 20 | return; 21 | 22 | auto latch = loop->getLatch(); 23 | // The loop should be rotated. 24 | if (!isa(latch->getLastOp())) 25 | return; 26 | 27 | auto induction = loop->getInduction(); 28 | if (!induction || !isa(loop->step)) 29 | return; 30 | 31 | auto header = loop->header; 32 | auto preheader = loop->preheader; 33 | auto phis = header->getPhis(); 34 | 35 | Op *splicer = nullptr; 36 | int vp, vl; 37 | for (auto phi : phis) { 38 | if (phi->getOperandCount() != 2) 39 | continue; 40 | 41 | auto l = Op::getPhiFrom(phi, latch), r = Op::getPhiFrom(phi, preheader); 42 | if (!l->has() || !r->has()) 43 | continue; 44 | 45 | auto [a1, a2] = RANGE(l); 46 | auto [b1, b2] = RANGE(r); 47 | if (a1 != a2 || b1 != b2) 48 | continue; 49 | 50 | splicer = phi; 51 | vl = a1, vp = b1; 52 | break; 53 | } 54 | if (!splicer) 55 | return; 56 | 57 | // Now copy the loop. 58 | std::unordered_map cloneMap; 59 | std::unordered_map rewireMap; 60 | 61 | auto region = header->getParent(); 62 | for (auto x : loop->getBlocks()) 63 | rewireMap[x] = region->insert(header); 64 | 65 | auto newpreheader = region->insert(header); 66 | 67 | // The new preheader should connect to header. 68 | Builder builder; 69 | builder.setToBlockEnd(newpreheader); 70 | builder.create({ new TargetAttr(header) }); 71 | 72 | // Shallow copy ops. 73 | for (auto [k, v] : rewireMap) { 74 | builder.setToBlockEnd(v); 75 | for (auto op : k->getOps()) { 76 | Op *cloned = builder.copy(op); 77 | cloneMap[op] = cloned; 78 | } 79 | } 80 | 81 | // Rewire operands. 82 | for (auto &[old, cloned] : cloneMap) { 83 | for (int i = 0; i < old->getOperandCount(); i++) { 84 | if (cloneMap.count(old->DEF(i))) 85 | cloned->setOperand(i, cloneMap[old->DEF(i)]); 86 | } 87 | } 88 | 89 | // Rewire blocks. 
90 | for (auto [_, v] : rewireMap) { 91 | auto term = v->getLastOp(); 92 | if (auto attr = term->find(); attr && rewireMap.count(attr->bb)) 93 | attr->bb = rewireMap[attr->bb]; 94 | if (auto attr = term->find(); attr && rewireMap.count(attr->bb)) 95 | attr->bb = rewireMap[attr->bb]; 96 | } 97 | 98 | // In the inner loops, the phis must be fixed. 99 | for (auto [_, v] : rewireMap) { 100 | auto phis = v->getPhis(); 101 | for (auto phi : phis) { 102 | for (auto attr : phi->getAttrs()) { 103 | if (!isa(attr)) 104 | continue; 105 | auto &from = FROM(attr); 106 | if (rewireMap.count(from)) 107 | from = rewireMap[from]; 108 | } 109 | } 110 | } 111 | 112 | // The new latch should now get to the new preheader. 113 | auto term = rewireMap[latch]->getLastOp(); 114 | builder.replace(term, { new TargetAttr(newpreheader) }); 115 | 116 | // The preheader should go to the cloned header. 117 | builder.replace(preheader->getLastOp(), { new TargetAttr(rewireMap[header]) }); 118 | 119 | // The phi itself can be replaced with an IntOp. 120 | builder.setBeforeOp(preheader->getLastOp()); 121 | auto pi = builder.create({ new IntAttr(vp) }); 122 | cloneMap[splicer]->replaceAllUsesWith(pi); 123 | cloneMap[splicer]->erase(); 124 | 125 | auto li = builder.create({ new IntAttr(vl) }); 126 | splicer->replaceAllUsesWith(li); 127 | splicer->erase(); 128 | 129 | // Every phi of the new header should be replaced by the value from old preheader. 130 | phis = rewireMap[header]->getPhis(); 131 | for (auto phi : phis) { 132 | phi->replaceAllUsesWith(Op::getPhiFrom(phi, preheader)); 133 | phi->erase(); 134 | } 135 | 136 | // Every phi of the old header should be rewired to the new preheader. 137 | phis = header->getPhis(); 138 | for (auto phi : phis) { 139 | // The value of the old header's phi should be rewired from the new latch. 
140 | phi->replaceOperand(Op::getPhiFrom(phi, preheader), cloneMap[Op::getPhiFrom(phi, latch)]); 141 | for (auto attr : phi->getAttrs()) { 142 | if (isa(attr) && FROM(attr) == preheader) 143 | FROM(attr) = newpreheader; 144 | } 145 | } 146 | } 147 | 148 | void Splice::run() { 149 | LoopAnalysis analysis(module); 150 | analysis.run(); 151 | auto forests = analysis.getResult(); 152 | 153 | Range(module).run(); 154 | for (const auto &[func, forest] : forests) { 155 | for (auto loop : forest.getLoops()) 156 | runImpl(loop); 157 | } 158 | 159 | auto funcs = collectFuncs(); 160 | for (auto func : funcs) { 161 | auto region = func->getRegion(); 162 | removeRange(region); 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /src/pre-opt/Parallelizable.cpp: -------------------------------------------------------------------------------- 1 | #include "PreAnalysis.h" 2 | 3 | using namespace sys; 4 | 5 | #define BAD(cond) if (cond) return 6 | 7 | void Parallelizable::runImpl(Op *loop, int depth) { 8 | NoStore(module).run(); 9 | auto fnMap = getFunctionMap(); 10 | // Find deeper loops inside the current one. 11 | auto region = loop->getRegion(); 12 | auto entry = region->getFirstBlock(); 13 | for (auto op : entry->getOps()) { 14 | if (isa(op)) 15 | runImpl(op, depth + 1); 16 | } 17 | // Cannot call impure functions, unless the function never stores. 18 | auto calls = loop->findAll(); 19 | for (auto call : calls) { 20 | const auto &name = NAME(call); 21 | BAD(call->has() && (isExtern(name) || !fnMap[name]->has())); 22 | } 23 | // Cannot have early return. 24 | if (loop->findAll().size()) 25 | return; 26 | 27 | // The subscript for this variable `loop` must be the same. 
28 | std::unordered_map>> access, ops; 29 | auto stores = loop->findAll(); 30 | for (auto store : stores) { 31 | auto addr = store->DEF(1); 32 | BAD(!addr->has()); 33 | access[BASE(addr)].emplace_back(addr, true); 34 | ops[BASE(addr)].emplace_back(store, true); 35 | } 36 | 37 | auto loads = loop->findAll(); 38 | for (auto load : loads) { 39 | auto addr = load->DEF(); 40 | BAD(!addr->has()); 41 | access[BASE(addr)].emplace_back(addr, false); 42 | ops[BASE(addr)].emplace_back(load, false); 43 | } 44 | 45 | // Stores that aren't nested in inner loops. 46 | std::set directStores; 47 | for (auto op : region->getFirstBlock()->getOps()) { 48 | if (isa(op)) 49 | directStores.insert(BASE(op->DEF(1))); 50 | } 51 | 52 | for (const auto &[base, access] : access) { 53 | // Check subscript. 54 | assert(access.size()); 55 | auto [addr, isStore] = access[0]; 56 | if (!addr->has()) { 57 | // Acceptable if this scalar is only accessed in nested loops. 58 | if (!directStores.count(addr)) 59 | continue; 60 | // We can accept loads as long as there's no stores into it. 61 | for (auto [_, isStore] : access) 62 | BAD(isStore); 63 | continue; 64 | } 65 | 66 | // Similarly, we can accept loads to arrays when there are no stores. 67 | bool hasStore = false; 68 | for (auto [_, isStore] : access) { 69 | if (isStore) { 70 | hasStore = true; 71 | break; 72 | } 73 | } 74 | if (!hasStore) 75 | continue; 76 | 77 | // The stride and constant of current loop. 78 | const auto &subscript = SUBSCRIPT(addr); 79 | auto n = subscript[depth]; 80 | auto vi = n ? subscript.back() / (n / 4) : -1; 81 | 82 | for (auto [addr, _] : access) { 83 | BAD(!addr->has()); 84 | const auto &subscript = SUBSCRIPT(addr); 85 | auto n2 = subscript[depth]; 86 | auto vi2 = n2 ? subscript.back() / (n2 / 4) : -1; 87 | BAD(n2 != n || vi2 != vi); 88 | } 89 | } 90 | 91 | // The first load must have a preceding store, or no store at all. 
92 | for (const auto &[base, access] : ops) { 93 | Op *load = nullptr; 94 | for (auto [op, isStore] : access) { 95 | if (!isStore) { 96 | load = op; 97 | break; 98 | } 99 | } 100 | // Alright. No stores, or no loads. 101 | if (!load || load == access[0].first) 102 | continue; 103 | 104 | auto [store, _] = access[0]; 105 | std::vector parents; 106 | for (auto runner = store; runner != loop; runner = runner->getParentOp()) 107 | parents.push_back(runner); 108 | for (auto runner = load; runner != loop; runner = runner->getParentOp()) { 109 | // check whether `runner` and anything in `parents` are on the same layer. 110 | bool decided = false, good = false; 111 | for (auto parent : parents) { 112 | if (runner->getParent() != parent->getParent()) 113 | continue; 114 | 115 | decided = true; 116 | // We expect `parent` to be front of `runner`. 117 | for (auto w = parent; !w->atBack(); w = w->nextOp()) { 118 | if (w == runner) { 119 | good = true; 120 | break; 121 | } 122 | } 123 | if (parent->getParent()->getLastOp() == runner) { 124 | good = true; 125 | break; 126 | } 127 | } 128 | if (decided && !good) 129 | return; 130 | if (decided) 131 | break; 132 | } 133 | } 134 | 135 | // Now it's parallelizable. 136 | loop->add(); 137 | } 138 | 139 | void Parallelizable::run() { 140 | ArrayAccess(module).run(); 141 | Base(module).run(); 142 | 143 | auto funcs = collectFuncs(); 144 | for (auto func : funcs) { 145 | auto region = func->getRegion(); 146 | // Clear stale data. 
147 | for (auto bb : region->getBlocks()) { 148 | for (auto op : bb->getOps()) 149 | op->remove(); 150 | } 151 | 152 | for (auto bb : region->getBlocks()) { 153 | for (auto op : bb->getOps()) { 154 | if (isa(op)) 155 | runImpl(op, 0); 156 | } 157 | } 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /src/opt/InlineStore.cpp: -------------------------------------------------------------------------------- 1 | #include "Passes.h" 2 | #include 3 | 4 | using namespace sys; 5 | 6 | std::map InlineStore::stats() { 7 | return { 8 | { "inlined-stores", inlined } 9 | }; 10 | } 11 | 12 | namespace { 13 | 14 | bool hasStore(BasicBlock *bb) { 15 | for (auto x : bb->getOps()) { 16 | if (isa(x)) 17 | return true; 18 | } 19 | return false; 20 | } 21 | 22 | } 23 | 24 | #define BAD { bad = true; break; } 25 | 26 | void InlineStore::run() { 27 | auto gets = module->findAll(); 28 | auto gMap = getGlobalMap(); 29 | auto fMap = getFunctionMap(); 30 | 31 | // For each global, records in which functions they're used. 32 | std::unordered_map> used; 33 | for (auto get : gets) 34 | used[NAME(get)].insert(NAME(get->getParentOp())); 35 | 36 | // Remove unused globals, and find out ones only used in functions. 
37 | std::vector queue; 38 | for (auto [k, v] : gMap) { 39 | if (used[k].empty() && !v->has()) 40 | v->erase(); 41 | if (used[k].size() == 1) { 42 | auto name = *used[k].begin(); 43 | if (fMap[name]->has()) 44 | queue.push_back(k); 45 | } 46 | } 47 | 48 | for (auto [_, v] : fMap) 49 | v->getRegion()->updateDoms(); 50 | 51 | for (auto gname : queue) { 52 | auto funcname = *used[gname].begin(); 53 | Op *func = fMap[funcname]; 54 | Builder builder; 55 | 56 | auto region = func->getRegion(); 57 | auto entry = region->getFirstBlock(); 58 | auto glob = gMap[gname]; 59 | bool fp = glob->has(); 60 | bool bad = false; 61 | 62 | for (auto runner = entry; runner->succs.size();) { 63 | auto ops = runner->getOps(); 64 | for (auto op : ops) { 65 | if (isa(op)) { 66 | if (!op->DEF()->has()) 67 | BAD 68 | 69 | auto alias = ALIAS(op->DEF()); 70 | if (alias->location.size() > 1) 71 | BAD 72 | 73 | auto [base, offsets] = *alias->location.begin(); 74 | if (offsets.size() > 1 || offsets[0] == -1) 75 | BAD 76 | if (base != glob) 77 | BAD 78 | 79 | auto offset = offsets[0]; 80 | builder.setBeforeOp(op); 81 | if (fp) { 82 | auto attr = glob->get(); 83 | auto f = builder.create({ new FloatAttr(attr->vf[offset / 4]) }); 84 | op->replaceAllUsesWith(f); 85 | } else { 86 | auto attr = glob->get(); 87 | auto i = builder.create({ new IntAttr(attr->vi[offset / 4]) }); 88 | op->replaceAllUsesWith(i); 89 | } 90 | op->erase(); 91 | } 92 | 93 | if (isa(op)) { 94 | if (!op->DEF(1)->has()) 95 | BAD 96 | 97 | auto alias = ALIAS(op->DEF(1)); 98 | if (alias->location.size() > 1) 99 | BAD 100 | 101 | auto [base, offsets] = *alias->location.begin(); 102 | if (offsets.size() > 1 || offsets[0] == -1) 103 | BAD 104 | if (base != glob) 105 | BAD 106 | 107 | auto offset = offsets[0]; 108 | if (fp) { 109 | if (!isa(op->DEF(0))) 110 | continue; 111 | 112 | auto attr = glob->get(); 113 | attr->vf[offset / 4] = F(op->DEF(0)); 114 | } else { 115 | if (!isa(op->DEF(0))) 116 | continue; 117 | 118 | auto attr = 
glob->get(); 119 | attr->vi[offset / 4] = V(op->DEF(0)); 120 | } 121 | inlined++; 122 | op->erase(); 123 | } 124 | } 125 | 126 | if (bad) 127 | break; 128 | 129 | auto term = runner->getLastOp(); 130 | if (isa(term) && TARGET(term)->preds.size() == 1) 131 | runner = TARGET(term); 132 | else if (isa(term)) { 133 | // These globals behave as local variables. 134 | // So if all successors of one branch doesn't have any stores, 135 | // and it isn't a loop-back edge, 136 | // then it's alright to inline stores in another branch. 137 | auto ifso = TARGET(term); 138 | auto ifnot = ELSE(term); 139 | // Don't check too far (haven't thought of how to handle loops yet.) 140 | // Just check the next block. 141 | if (isa(ifnot->getLastOp()) && !hasStore(ifnot) && !ifso->dominates(runner)) 142 | runner = ifso; 143 | else if (isa(ifso->getLastOp()) && !hasStore(ifso) && !ifnot->dominates(runner)) 144 | runner = ifnot; 145 | else break; 146 | } else break; 147 | } 148 | 149 | // Update allzero attribute. 
150 | if (fp) { 151 | auto attr = glob->get(); 152 | for (int i = 0; i < attr->size; i++) { 153 | if (attr->vf[i] != 0) { 154 | attr->allZero = false; 155 | break; 156 | } 157 | } 158 | } else { 159 | auto attr = glob->get(); 160 | for (int i = 0; i < attr->size; i++) { 161 | if (attr->vi[i] != 0) { 162 | attr->allZero = false; 163 | break; 164 | } 165 | } 166 | } 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/opt/Mem2Reg.cpp: -------------------------------------------------------------------------------- 1 | #include "Passes.h" 2 | #include "../codegen/CodeGen.h" 3 | #include "../codegen/Attrs.h" 4 | #include 5 | 6 | using namespace sys; 7 | 8 | std::map Mem2Reg::stats() { 9 | return { 10 | { "lowered-alloca", count }, 11 | { "missed-alloca", missed }, 12 | }; 13 | } 14 | 15 | // See explanation at https://longfangsong.github.io/en/mem2reg-made-simple/ 16 | void Mem2Reg::runImpl(FuncOp *func) { 17 | converted.clear(); 18 | visited.clear(); 19 | phiFrom.clear(); 20 | domtree.clear(); 21 | 22 | auto region = func->getRegion(); 23 | region->updateDomFront(); 24 | domtree = getDomTree(region); 25 | 26 | Builder builder; 27 | 28 | // We need to put PhiOp at places where a StoreOp doesn't dominate, 29 | // because it means at least 2 possible values. 30 | auto allocas = func->findAll(); 31 | for (auto alloca : allocas) { 32 | bool good = true; 33 | 34 | // If the alloca is used for, as an example, AddOp, then 35 | // it's an array and can't be promoted to registers. 36 | for (auto use : alloca->getUses()) { 37 | if (!isa(use) && !isa(use)) { 38 | good = false; 39 | break; 40 | } 41 | // If the alloca is used as a value in a StoreOp, then it has to be an array. 42 | if (isa(use) && use->DEF(0) == alloca) { 43 | good = false; 44 | break; 45 | } 46 | } 47 | 48 | if (!good) { 49 | missed++; 50 | continue; 51 | } 52 | count++; 53 | converted.insert(alloca); 54 | 55 | // Now find all blocks where stores reside in. 
Use set to de-duplicate. 56 | std::set bbs; 57 | for (auto use : alloca->getUses()) { 58 | if (isa(use)) 59 | bbs.insert(use->getParent()); 60 | } 61 | 62 | std::vector worklist; 63 | std::copy(bbs.begin(), bbs.end(), std::back_inserter(worklist)); 64 | 65 | std::set visited; 66 | 67 | while (!worklist.empty()) { 68 | auto bb = worklist.back(); 69 | worklist.pop_back(); 70 | 71 | for (auto dom : bb->getDominanceFrontier()) { 72 | if (visited.count(dom)) 73 | continue; 74 | visited.insert(dom); 75 | 76 | // Insert a PhiOp at the dominance frontier of each StoreOp, as described above. 77 | // The PhiOp is broken; we only record which AllocaOp it's from. 78 | // We'll fill it in later. 79 | builder.setToBlockStart(dom); 80 | auto phi = builder.create(); 81 | phiFrom[phi] = alloca; 82 | worklist.push_back(dom); 83 | } 84 | } 85 | } 86 | 87 | fillPhi(func->getRegion()->getFirstBlock(), {}); 88 | 89 | for (auto alloca : converted) 90 | alloca->erase(); 91 | } 92 | 93 | void Mem2Reg::fillPhi(BasicBlock *bb, SymbolTable symbols) { 94 | if (visited.count(bb)) 95 | return; 96 | visited.insert(bb); 97 | 98 | Builder builder; 99 | 100 | std::vector removed; 101 | for (auto op : bb->getOps()) { 102 | // Loads are now ordinary reads. 103 | if (auto load = dyn_cast(op)) { 104 | auto alloca = load->getOperand().defining; 105 | if (!converted.count(alloca)) 106 | continue; 107 | 108 | if (!symbols.count(alloca)) { 109 | builder.setBeforeOp(load); 110 | bool fp = alloca->has(); 111 | symbols[alloca] = fp 112 | ? (Op*) builder.create({ new FloatAttr(0) }) 113 | : (Op*) builder.create({ new IntAttr(0) }); 114 | } 115 | 116 | load->replaceAllUsesWith(symbols[alloca].defining); 117 | removed.push_back(load); 118 | } 119 | 120 | // Stores are now mutating symbol table. 
121 | if (auto store = dyn_cast(op)) { 122 | auto value = store->getOperand(0); 123 | auto alloca = store->getOperand(1).defining; 124 | if (!converted.count(alloca)) 125 | continue; 126 | symbols[alloca] = value; 127 | 128 | removed.push_back(store); 129 | } 130 | 131 | if (auto phi = dyn_cast(op)) { 132 | if (!phiFrom.count(phi)) 133 | continue; 134 | auto alloca = phiFrom[phi]; 135 | symbols[alloca] = phi; 136 | } 137 | } 138 | 139 | for (auto succ : bb->succs) { 140 | auto phis = succ->getPhis(); 141 | for (auto op : phis) { 142 | auto alloca = phiFrom[cast(op)]; 143 | 144 | // We meet a PhiOp. This means the promoted register might hold value `symbols[alloca]` when it reaches here. 145 | // So this PhiOp should have that value as operand as well. 146 | Value value; 147 | 148 | // It doesn't have an initial value from this path. 149 | // It's acceptable (for example a variable defined only in a loop) 150 | // Treat it as zero from this branch. 151 | if (!symbols.count(alloca)) { 152 | // Create a zero at the back of the incoming edge. 153 | auto term = bb->getLastOp(); 154 | builder.setBeforeOp(term); 155 | value = builder.create({ new IntAttr(0) }); 156 | } else 157 | value = symbols[alloca]; 158 | 159 | op->pushOperand(value); 160 | op->add(bb); 161 | } 162 | } 163 | 164 | for (auto x : removed) 165 | x->erase(); 166 | 167 | for (auto child : domtree[bb]) 168 | fillPhi(child, symbols); 169 | } 170 | 171 | void Mem2Reg::run() { 172 | auto funcs = collectFuncs(); 173 | for (auto func : funcs) 174 | runImpl(func); 175 | } 176 | --------------------------------------------------------------------------------