├── README.md ├── build.bash ├── codegen.bash ├── lib ├── Makefile ├── antlr-4.7-complete.jar ├── antlr4-runtime.tar.gz ├── antlr_generated.tar.gz ├── boost_mini_1_64_0.tar.gz ├── boost_sha1.tar.gz ├── libboost_program_options.a └── option_parser.tar.gz ├── optim.bash ├── precompile_antlr_generated.bash ├── precompile_option_parser.bash ├── rules ├── MxLexer.g4 └── MxParser.g4 ├── semantic.bash ├── slide.pdf └── src ├── ASM.h ├── ASM_x64.cpp ├── AST.cpp ├── AST.h ├── ASTConstructor.cpp ├── ASTConstructor.h ├── ASTVisualizer.cpp ├── ASTVisualizer.h ├── CodeGenerator.cpp ├── CodeGenerator.h ├── CodeGeneratorBasic.cpp ├── CodeGeneratorBasic.h ├── ConstantFold.cpp ├── ConstantFold.h ├── DeadCodeElimination.cpp ├── DeadCodeElimination.h ├── GVN.cpp ├── GVN.h ├── GlobalSymbol.cpp ├── GlobalSymbol.h ├── IR.cpp ├── IR.h ├── IRGenerator.cpp ├── IRGenerator.h ├── IRVisualizer.cpp ├── IRVisualizer.h ├── InlineOptimizer.cpp ├── InlineOptimizer.h ├── InstructionSelect.cpp ├── InstructionSelect.h ├── IssueCollector.cpp ├── IssueCollector.h ├── LoadCombine.cpp ├── LoadCombine.h ├── LoopDetector.cpp ├── LoopDetector.h ├── LoopInvariantOptimizer.cpp ├── LoopInvariantOptimizer.h ├── MxBuiltin.cpp ├── MxBuiltin.h ├── MxProgram.cpp ├── MxProgram.h ├── RegisterAllocatorSSA.cpp ├── RegisterAllocatorSSA.h ├── SSAConstructor.cpp ├── SSAConstructor.h ├── SSAReconstructor.cpp ├── SSAReconstructor.h ├── StaticTypeChecker.cpp ├── StaticTypeChecker.h ├── common.h ├── common_headers.h ├── main.cpp ├── option_parser.cpp ├── option_parser.h └── utils ├── CycleEquiv.cpp ├── CycleEquiv.h ├── DepGraph.cpp ├── DepGraph.h ├── DispatchLength.h ├── DomTree.cpp ├── DomTree.h ├── ElementAdapter.h ├── JoinIterator.h ├── MaxClique.cpp ├── MaxClique.h └── UnionFindSet.h /README.md: -------------------------------------------------------------------------------- 1 | # MxCompiler 2 | A Mx-language compiler focused on backend optimization and code generation for x86-64. 
3 | 4 | This is the codebase for the compiler project of the ACM class at SJTU. 5 | 6 | Useful links: 7 | - [wiki](https://acm.sjtu.edu.cn/wiki/Compiler_2017) 8 | - [language reference](https://acm.sjtu.edu.cn/w/images/3/30/M_language_manual.pdf?20170401) 9 | 10 | ## Main Optimizations 11 | - Function Inlining 12 | - Register Allocation based on SSA 13 | - Loop-invariant Code Motion 14 | - Dead Code Elimination 15 | - Global Value Numbering 16 | -------------------------------------------------------------------------------- /build.bash: -------------------------------------------------------------------------------- 1 | set -e 2 | grep "model name" /proc/cpuinfo || true # informational only: the field may be missing on some platforms, and a failed grep must not abort the build under set -e 3 | grep "MemTotal" /proc/meminfo || true 4 | rm -rf build 5 | mkdir -p build 6 | cd build 7 | echo extracting antlr generated files... 8 | time tar -zxf ../lib/antlr_generated.tar.gz 9 | echo extracting option parser... 10 | time tar -zxf ../lib/option_parser.tar.gz 11 | echo extracting antlr4... 12 | time tar -zxf ../lib/antlr4-runtime.tar.gz 13 | cp ../lib/libboost_program_options.a . 14 | echo extracting boost... 15 | time tar -zxf ../lib/boost_sha1.tar.gz 16 | cp ../lib/Makefile . 17 | echo compiling...
18 | time make "$@" # forward this script's arguments to make; quoting keeps arguments that contain spaces intact 19 | -------------------------------------------------------------------------------- /codegen.bash: -------------------------------------------------------------------------------- 1 | cd build # expects build/mxcompiler, produced by build.bash 2 | cat > program.mx 3 | ./mxcompiler program.mx -o program.asm --fdisable-access-protect --optim-reg-alloc --optim-inline --optim-loop-invariant --optim-dead-code --optim-gvn 4 | cat program.asm 5 | -------------------------------------------------------------------------------- /lib/Makefile: -------------------------------------------------------------------------------- 1 | CPP := g++ 2 | 3 | CFLAGS_O2 := -Isrc/ -Igenerated/ -Iinclude/ -std=c++14 -O2 -DNDEBUG 4 | CFLAGS := -Isrc/ -Igenerated/ -Iinclude/ -std=c++14 -O1 -DNDEBUG 5 | CFLAGS_O0 := -Isrc/ -Igenerated/ -Iinclude/ -std=c++14 -DNDEBUG 6 | LDFLAGS := 7 | 8 | mxcompiler: libantlr4-runtime.a antlr_generated.a libboost_program_options.a common_headers.h.gch option_parser.o ASM_x64.o AST.o ASTConstructor.o CodeGenerator.o CodeGeneratorBasic.o ConstantFold.o DeadCodeElimination.o GlobalSymbol.o GVN.o InlineOptimizer.o InstructionSelect.o IR.o IRGenerator.o IssueCollector.o LoadCombine.o LoopDetector.o LoopInvariantOptimizer.o main.o MxBuiltin.o MxProgram.o RegisterAllocatorSSA.o SSAConstructor.o SSAReconstructor.o StaticTypeChecker.o CycleEquiv.o DomTree.o MaxClique.o 9 | $(CPP) $(LDFLAGS) option_parser.o ASM_x64.o AST.o ASTConstructor.o CodeGenerator.o CodeGeneratorBasic.o ConstantFold.o DeadCodeElimination.o GlobalSymbol.o GVN.o InlineOptimizer.o InstructionSelect.o IR.o IRGenerator.o IssueCollector.o LoadCombine.o LoopDetector.o LoopInvariantOptimizer.o main.o MxBuiltin.o MxProgram.o RegisterAllocatorSSA.o SSAConstructor.o SSAReconstructor.o StaticTypeChecker.o CycleEquiv.o DomTree.o MaxClique.o antlr_generated.a libantlr4-runtime.a libboost_program_options.a -o mxcompiler 10 | 11 | common_headers.h.gch: ../src/common_headers.h 12 | $(CPP) ../src/common_headers.h -o common_headers.h.gch $(CFLAGS)
13 | # Object rules: each source gets an explicit optimization level - CFLAGS_O0 (no -O flag), CFLAGS (-O1) or CFLAGS_O2 (-O2). Only the .cpp is listed as a prerequisite (no headers); build.bash always starts from a freshly emptied build/ directory. 14 | ASM_x64.o: ../src/ASM_x64.cpp 15 | $(CPP) -c ../src/ASM_x64.cpp -o ASM_x64.o $(CFLAGS_O0) 16 | AST.o: ../src/AST.cpp 17 | $(CPP) -c ../src/AST.cpp -o AST.o $(CFLAGS) 18 | ASTConstructor.o: ../src/ASTConstructor.cpp 19 | $(CPP) -c ../src/ASTConstructor.cpp -o ASTConstructor.o $(CFLAGS) 20 | CodeGenerator.o: ../src/CodeGenerator.cpp 21 | $(CPP) -c ../src/CodeGenerator.cpp -o CodeGenerator.o $(CFLAGS) 22 | CodeGeneratorBasic.o: ../src/CodeGeneratorBasic.cpp 23 | $(CPP) -c ../src/CodeGeneratorBasic.cpp -o CodeGeneratorBasic.o $(CFLAGS) 24 | ConstantFold.o: ../src/ConstantFold.cpp 25 | $(CPP) -c ../src/ConstantFold.cpp -o ConstantFold.o $(CFLAGS_O0) 26 | DeadCodeElimination.o: ../src/DeadCodeElimination.cpp 27 | $(CPP) -c ../src/DeadCodeElimination.cpp -o DeadCodeElimination.o $(CFLAGS_O2) 28 | GlobalSymbol.o: ../src/GlobalSymbol.cpp 29 | $(CPP) -c ../src/GlobalSymbol.cpp -o GlobalSymbol.o $(CFLAGS_O0) 30 | GVN.o: ../src/GVN.cpp 31 | $(CPP) -c ../src/GVN.cpp -o GVN.o $(CFLAGS_O2) 32 | InlineOptimizer.o: ../src/InlineOptimizer.cpp 33 | $(CPP) -c ../src/InlineOptimizer.cpp -o InlineOptimizer.o $(CFLAGS) 34 | InstructionSelect.o: ../src/InstructionSelect.cpp 35 | $(CPP) -c ../src/InstructionSelect.cpp -o InstructionSelect.o $(CFLAGS) 36 | IR.o: ../src/IR.cpp 37 | $(CPP) -c ../src/IR.cpp -o IR.o $(CFLAGS) 38 | IRGenerator.o: ../src/IRGenerator.cpp 39 | $(CPP) -c ../src/IRGenerator.cpp -o IRGenerator.o $(CFLAGS) 40 | IssueCollector.o: ../src/IssueCollector.cpp 41 | $(CPP) -c ../src/IssueCollector.cpp -o IssueCollector.o $(CFLAGS_O0) 42 | LoadCombine.o: ../src/LoadCombine.cpp 43 | $(CPP) -c ../src/LoadCombine.cpp -o LoadCombine.o $(CFLAGS) 44 | LoopDetector.o: ../src/LoopDetector.cpp 45 | $(CPP) -c ../src/LoopDetector.cpp -o LoopDetector.o $(CFLAGS) 46 | LoopInvariantOptimizer.o: ../src/LoopInvariantOptimizer.cpp 47 | $(CPP) -c ../src/LoopInvariantOptimizer.cpp -o LoopInvariantOptimizer.o $(CFLAGS_O2) 48 | main.o: ../src/main.cpp 49 | $(CPP) -c ../src/main.cpp -o
main.o $(CFLAGS_O0) 50 | MxBuiltin.o: ../src/MxBuiltin.cpp 51 | $(CPP) -c ../src/MxBuiltin.cpp -o MxBuiltin.o $(CFLAGS) 52 | MxProgram.o: ../src/MxProgram.cpp 53 | $(CPP) -c ../src/MxProgram.cpp -o MxProgram.o $(CFLAGS_O0) 54 | RegisterAllocatorSSA.o: ../src/RegisterAllocatorSSA.cpp 55 | $(CPP) -c ../src/RegisterAllocatorSSA.cpp -o RegisterAllocatorSSA.o $(CFLAGS_O2) 56 | SSAConstructor.o: ../src/SSAConstructor.cpp 57 | $(CPP) -c ../src/SSAConstructor.cpp -o SSAConstructor.o $(CFLAGS) 58 | SSAReconstructor.o: ../src/SSAReconstructor.cpp 59 | $(CPP) -c ../src/SSAReconstructor.cpp -o SSAReconstructor.o $(CFLAGS) 60 | StaticTypeChecker.o: ../src/StaticTypeChecker.cpp 61 | $(CPP) -c ../src/StaticTypeChecker.cpp -o StaticTypeChecker.o $(CFLAGS) 62 | CycleEquiv.o: ../src/utils/CycleEquiv.cpp 63 | $(CPP) -c ../src/utils/CycleEquiv.cpp -o CycleEquiv.o $(CFLAGS_O2) 64 | DomTree.o: ../src/utils/DomTree.cpp 65 | $(CPP) -c ../src/utils/DomTree.cpp -o DomTree.o $(CFLAGS_O2) 66 | MaxClique.o: ../src/utils/MaxClique.cpp 67 | $(CPP) -c ../src/utils/MaxClique.cpp -o MaxClique.o $(CFLAGS_O2) 68 | -------------------------------------------------------------------------------- /lib/antlr-4.7-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/antlr-4.7-complete.jar -------------------------------------------------------------------------------- /lib/antlr4-runtime.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/antlr4-runtime.tar.gz -------------------------------------------------------------------------------- /lib/antlr_generated.tar.gz: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/antlr_generated.tar.gz -------------------------------------------------------------------------------- /lib/boost_mini_1_64_0.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/boost_mini_1_64_0.tar.gz -------------------------------------------------------------------------------- /lib/boost_sha1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/boost_sha1.tar.gz -------------------------------------------------------------------------------- /lib/libboost_program_options.a: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/libboost_program_options.a -------------------------------------------------------------------------------- /lib/option_parser.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/lib/option_parser.tar.gz -------------------------------------------------------------------------------- /optim.bash: -------------------------------------------------------------------------------- 1 | cd build # same steps as codegen.bash: read the program from stdin, compile it with all optimizations enabled, print the assembly 2 | cat > program.mx 3 | ./mxcompiler program.mx -o program.asm --fdisable-access-protect --optim-reg-alloc --optim-inline --optim-loop-invariant --optim-dead-code --optim-gvn 4 | cat program.asm 5 | -------------------------------------------------------------------------------- /precompile_antlr_generated.bash: -------------------------------------------------------------------------------- 1 | set -e 2 | rm -rf build
3 | mkdir -p build/generated 4 | cd rules 5 | java -jar ../lib/antlr-4.7-complete.jar -Dlanguage=Cpp MxLexer.g4 MxParser.g4 -o ../build/generated/ -no-listener -visitor 6 | cd ../build 7 | tar -zxf ../lib/antlr4-runtime.tar.gz 8 | g++ -c generated/*.cpp -Igenerated/ -Isrc/ -std=c++14 -O2 9 | ar -r antlr_generated.a *.o 10 | rm generated/*.cpp 11 | tar -zcf antlr_generated.tar.gz generated/ antlr_generated.a -------------------------------------------------------------------------------- /precompile_option_parser.bash: -------------------------------------------------------------------------------- 1 | set -e 2 | rm -rf build 3 | mkdir -p build 4 | cd build 5 | tar -zxf ../lib/antlr4-runtime.tar.gz 6 | tar -zxf ../lib/boost_mini_1_64_0.tar.gz 7 | g++ -c ../src/option_parser.cpp -o option_parser.o -Isrc/ -Iinclude/ -std=c++14 -O2 8 | tar -zcf option_parser.tar.gz option_parser.o -------------------------------------------------------------------------------- /rules/MxLexer.g4: -------------------------------------------------------------------------------- 1 | lexer grammar MxLexer; 2 | // Keywords. They are listed before Id: on an equal-length match ANTLR picks the rule defined first, so 'int' lexes as IntType rather than Id. 3 | BoolType: 'bool'; 4 | IntType: 'int'; 5 | StringType: 'string'; 6 | Null: 'null'; 7 | Void: 'void'; 8 | True: 'true'; 9 | False: 'false'; 10 | If: 'if'; 11 | Else: 'else'; 12 | For: 'for'; 13 | While: 'while'; 14 | Break: 'break'; 15 | Continue: 'continue'; 16 | Return: 'return'; 17 | New: 'new'; 18 | Class: 'class'; 19 | This: 'this'; 20 | // Operators and punctuation. Longer tokens such as '<=' or '<<' win over '<' because the lexer prefers the longest match. 21 | Plus: '+'; 22 | Minus: '-'; 23 | Multi: '*'; 24 | Div: '/'; 25 | Mod: '%'; 26 | GreaterThan: '>'; 27 | LessThan: '<'; 28 | Equal: '=='; 29 | NotEqual: '!='; 30 | GreaterEqual: '>='; 31 | LessEqual: '<='; 32 | 33 | And: '&&'; 34 | Or: '||'; 35 | Not: '!'; 36 | ShiftLeft: '<<'; 37 | ShiftRight: '>>'; 38 | BitNot: '~'; 39 | BitOr: '|'; 40 | BitAnd: '&'; 41 | BitXor: '^'; 42 | 43 | Assign: '='; 44 | 45 | Increment: '++'; 46 | Decrement: '--'; 47 | 48 | Dot: '.'; 49 | 50 | OpenPar: '('; 51 | ClosePar: ')'; 52 | OpenSqu: '['; 53 | CloseSqu: ']';
54 | OpenCurly: '{'; 55 | CloseCurly: '}'; 56 | 57 | Semicolon: ';'; 58 | 59 | Comma: ','; 60 | // Identifiers, literals, comments and whitespace. 61 | Id: [A-Za-z_][A-Za-z0-9_]*; 62 | 63 | IntegerDec: [1-9][0-9]*|'0'; 64 | // String literal: a backslash escapes the following character; the token text keeps the raw escape sequences. 65 | String: '"' (~'\\'|'\\'.)*? '"'; 66 | // A line comment stops before the line break, which Whitespace then skips. Requiring '\n' here would make a line comment on the last line of a file without a trailing newline fail to lex. 67 | Comment: '//' ~[\r\n]* -> skip; 68 | CommentBlock: '/*' .*? '*/' ->skip; 69 | Whitespace: [ \t\r\n]+ -> skip; -------------------------------------------------------------------------------- /rules/MxParser.g4: -------------------------------------------------------------------------------- 1 | parser grammar MxParser; 2 | 3 | options 4 | { 5 | tokenVocab = MxLexer; 6 | } 7 | 8 | exprList: (expr (Comma expr)*)?; 9 | exprNewDim: OpenSqu expr? CloseSqu; 10 | 11 | exprPrimary: Id 12 | | String 13 | | IntegerDec 14 | | True 15 | | False 16 | | This 17 | | Null 18 | | exprPar 19 | ; 20 | 21 | exprPar: OpenPar expr ClosePar; 22 | subexprPostfix: exprPrimary #subexpr0 23 | | subexprPostfix Increment #exprIncrementPostfix 24 | | subexprPostfix Decrement #exprDecrementPostfix 25 | | subexprPostfix Dot Id #exprMember 26 | | subexprPostfix OpenPar exprList ClosePar #exprFuncCall 27 | | subexprPostfix OpenSqu expr CloseSqu #exprSubscript 28 | ; 29 | 30 | subexprPrefix: subexprPostfix #subexpr1 31 | | Increment subexprPrefix #exprIncrementPrefix 32 | | Decrement subexprPrefix #exprDecrementPrefix 33 | | Plus subexprPrefix #exprPositive 34 | | Minus subexprPrefix #exprNegative 35 | | Not subexprPrefix #exprNot 36 | | BitNot subexprPrefix #exprBitNot 37 | | New typeNotArray ((OpenPar exprList ClosePar) | (exprNewDim)*) #exprNew 38 | ; 39 | subexprMultiDiv: subexprPrefix #subexpr2 40 | | subexprMultiDiv Multi subexprPrefix #exprMulti 41 | | subexprMultiDiv Div subexprPrefix #exprDiv 42 | | subexprMultiDiv Mod subexprPrefix #exprMod 43 | ; 44 | subexprPlusMinus: subexprMultiDiv #subexpr3 45 | | subexprPlusMinus Plus subexprMultiDiv #exprPlus 46 | | subexprPlusMinus Minus subexprMultiDiv #exprMinus 47 | ; 48 | 49 | subexprShift: subexprPlusMinus #subexpr4
50 | | subexprShift ShiftLeft subexprPlusMinus #exprShiftLeft 51 | | subexprShift ShiftRight subexprPlusMinus #exprShiftRight 52 | ; 53 | subexprCompareRel: subexprShift #subexpr5 54 | | subexprCompareRel LessThan subexprShift #exprLessThan 55 | | subexprCompareRel LessEqual subexprShift #exprLessEqual 56 | | subexprCompareRel GreaterThan subexprShift #exprGreaterThan 57 | | subexprCompareRel GreaterEqual subexprShift #exprGreaterEqual 58 | ; 59 | subexprCompareEqu: subexprCompareRel #subexpr6 60 | | subexprCompareEqu Equal subexprCompareRel #exprEqual 61 | | subexprCompareEqu NotEqual subexprCompareRel #exprNotEqual 62 | ; 63 | // Bitwise and logical operators, from tightest to loosest binding: &, ^, |, &&, ||. 64 | subexprBitand: subexprCompareEqu #subexpr7 65 | | subexprBitand BitAnd subexprCompareEqu #exprBitand 66 | ; 67 | subexprXor: subexprBitand #subexpr8 68 | | subexprXor BitXor subexprBitand #exprXor 69 | ; 70 | subexprBitor: subexprXor #subexpr9 71 | | subexprBitor BitOr subexprXor #exprBitor 72 | ; 73 | subexprAnd: subexprBitor #subexpr10 74 | | subexprAnd And subexprBitor #exprAnd 75 | ; 76 | subexprOr: subexprAnd #subexpr11 77 | | subexprOr Or subexprAnd #exprOr 78 | ; 79 | // Assignment binds loosest and is right-associative: its right operand recurses into subexprAssignment. 80 | subexprAssignment: subexprOr #subexpr12 81 | | subexprOr Assign subexprAssignment #exprAssignment 82 | ; 83 | 84 | expr: subexprAssignment; 85 | // Types: a built-in type or a class name, optionally followed by [] pairs for array dimensions. 86 | typeInternal: IntType 87 | | StringType 88 | | BoolType 89 | ; 90 | 91 | typeNotArray: (typeInternal | Id); 92 | type: typeNotArray (OpenSqu CloseSqu)*; 93 | 94 | varDecl: type Id (Assign expr)? (Comma Id (Assign expr)?)* Semicolon; 95 | paramList: (type Id (Comma type Id)*)?; 96 | funcDecl: (type | Void)? Id OpenPar paramList ClosePar block; 97 | memberList: (varDecl | funcDecl)*; 98 | classDecl: Class Id OpenCurly memberList CloseCurly; 99 | if_statement: If OpenPar expr ClosePar statement (Else statement)?; 100 | // for (init; cond; step): each of the three parts may be empty. 101 | for_exprIn: (varDecl | expr? Semicolon); 102 | for_exprCond: expr?
Semicolon; 103 | for_exprStep: expr?; 104 | for_statement: For OpenPar for_exprIn for_exprCond for_exprStep ClosePar statement; 105 | 106 | while_statement: While OpenPar expr ClosePar statement; 107 | statement: block | if_statement | for_statement | while_statement | varDecl | ((expr | Continue | Break | Return expr?)? Semicolon); 108 | block: OpenCurly statement* CloseCurly; 109 | prog: (classDecl | funcDecl | varDecl)*; -------------------------------------------------------------------------------- /semantic.bash: -------------------------------------------------------------------------------- 1 | cd build 2 | cat > program.mx 3 | ./mxcompiler program.mx -o program.asm --fdisable-access-protect 4 | -------------------------------------------------------------------------------- /slide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/slide.pdf -------------------------------------------------------------------------------- /src/ASM.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_ASM_H 2 | #define MX_COMPILER_ASM_H 3 | 4 | #include "common.h" 5 | /* x86-64 register-name and operand-size-name helpers, defined in ASM_x64.cpp. */ 6 | std::string regName(int id, size_t size); 7 | std::string sizeName(size_t size); 8 | 9 | #endif -------------------------------------------------------------------------------- /src/ASM_x64.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "ASM.h" 3 | /* regName: name of register `id` used as a `size`-byte operand (size is expected to be 1, 2, 4 or 8, but this is only checked by assert, and lib/Makefile builds with -DNDEBUG). ids 0-7 come from regTable (al/cl/dl/bl/spl/bpl/sil/dil and their wider forms). Larger ids are spelled r<id> with suffix b/w/d for 1/2/4 bytes and no suffix for 8 bytes. NOTE(review): a negative id is not rejected and would index regTable out of bounds. */ 4 | std::string regName(int id, size_t size) 5 | { 6 | static const std::string regTable[8][4] = { 7 | { "al", "ax", "eax", "rax" }, 8 | { "cl", "cx", "ecx", "rcx" }, 9 | { "dl", "dx", "edx", "rdx" }, 10 | { "bl", "bx", "ebx", "rbx" }, 11 | { "spl", "sp", "esp", "rsp" }, 12 | { "bpl", "bp", "ebp", "rbp" }, 13 | { "sil", "si", "esi", "rsi" }, 14 | { "dil", "di", "edi", "rdi" }, 15 | }; 16 | if (id <
8) 17 | { 18 | if (size == 1) 19 | return regTable[id][0]; 20 | if (size == 2) 21 | return regTable[id][1]; 22 | if (size == 4) 23 | return regTable[id][2]; 24 | assert(size == 8); 25 | return regTable[id][3]; 26 | } 27 | std::stringstream ss; 28 | ss << "r" << id; 29 | if (size == 1) 30 | ss << "b"; 31 | else if (size == 2) 32 | ss << "w"; 33 | else if (size == 4) 34 | ss << "d"; 35 | else 36 | { 37 | assert(size == 8); 38 | } 39 | return ss.str(); 40 | } 41 | 42 | /* sizeName: operand-size keyword (byte/word/dword/qword) for a 1/2/4/8-byte operand; any other size falls through to the 8-byte keyword when asserts are disabled. */ 43 | std::string sizeName(size_t size) 44 | { 45 | if (size == 1) 46 | return "byte"; 47 | if (size == 2) 48 | return "word"; 49 | if (size == 4) 50 | return "dword"; 51 | assert(size == 8); 52 | return "qword"; 53 | } -------------------------------------------------------------------------------- /src/AST.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "AST.h" 3 | /* Default ASTVisitor behaviour: visit(ASTNode *) does nothing. Every other overload either forwards the node to a more general overload via static_cast, or calls defaultVisit() (if/while/for/block statements and local variable declarations). */ 4 | namespace MxAST 5 | { 6 | void ASTVisitor::visit(ASTNode *node) {} 7 | void ASTVisitor::visit(ASTBlock *block) { visit(static_cast(block)); } 8 | void ASTVisitor::visit(ASTRoot *root) { visit(static_cast(root)); } 9 | void ASTVisitor::visit(ASTDeclClass *declClass) { visit(static_cast(declClass)); } 10 | void ASTVisitor::visit(ASTDeclVar *declVar) { visit(static_cast(declVar)); } 11 | void ASTVisitor::visit(ASTDeclFunc *declFunc) { visit(static_cast(declFunc)); } 12 | void ASTVisitor::visit(ASTExpr *expr) { visit(static_cast(expr)); } 13 | void ASTVisitor::visit(ASTExprImm *imm) { visit(static_cast(imm)); } 14 | void ASTVisitor::visit(ASTExprVar *var) { visit(static_cast(var)); } 15 | void ASTVisitor::visit(ASTExprUnary *unary) { visit(static_cast(unary)); } 16 | void ASTVisitor::visit(ASTExprBinary *binary) { visit(static_cast(binary)); } 17 | void ASTVisitor::visit(ASTExprAssignment *assign) { visit(static_cast(assign)); } 18 | void ASTVisitor::visit(ASTExprNew *exprNew) { visit(static_cast(exprNew)); } 19 | void ASTVisitor::visit(ASTExprSubscriptAccess
*exprSub) { visit(static_cast(exprSub)); } 20 | void ASTVisitor::visit(ASTExprMemberAccess *expr) { visit(static_cast(expr)); } 21 | void ASTVisitor::visit(ASTExprFuncCall *expr) { visit(static_cast(expr)); } 22 | void ASTVisitor::visit(ASTStatement *stat) { visit(static_cast(stat)); } 23 | void ASTVisitor::visit(ASTStatementReturn *stat) { visit(static_cast(stat)); } 24 | void ASTVisitor::visit(ASTStatementBreak *stat) { visit(static_cast(stat)); } 25 | void ASTVisitor::visit(ASTStatementContinue *stat) { visit(static_cast(stat)); } 26 | void ASTVisitor::visit(ASTStatementIf *stat) { defaultVisit(stat); } 27 | void ASTVisitor::visit(ASTStatementWhile *stat) { defaultVisit(stat); } 28 | void ASTVisitor::visit(ASTStatementFor *stat) { defaultVisit(stat); } 29 | void ASTVisitor::visit(ASTStatementBlock *block) { defaultVisit(block); } 30 | void ASTVisitor::visit(ASTStatementExpr *stat) { visit(static_cast(stat)); } 31 | void ASTVisitor::visit(ASTDeclVarLocal *declVar) { defaultVisit(declVar); } 32 | void ASTVisitor::visit(ASTDeclVarGlobal *declVar) { visit(static_cast(declVar)); } 33 | /* ASTListener::enter: the ASTBlock, ASTDeclClass, ASTDeclVar, ASTExpr and ASTStatement overloads return the node itself. The rest forward to a more general overload via static_cast, or call defaultEnter(). NOTE(review): the returned pointer presumably lets a listener replace the node - confirm in AST.h. */ 34 | ASTNode * ASTListener::enter(ASTBlock *block) { return block; } 35 | ASTNode * ASTListener::enter(ASTDeclClass *declClass) { return declClass; } 36 | ASTNode * ASTListener::enter(ASTDeclVar *declVar) { return declVar; } 37 | ASTNode * ASTListener::enter(ASTDeclFunc *declFunc) { return enter(static_cast(declFunc)); } 38 | ASTNode * ASTListener::enter(ASTExpr *expr) { return expr; } 39 | ASTNode * ASTListener::enter(ASTExprImm *expr) { return enter(static_cast(expr)); } 40 | ASTNode * ASTListener::enter(ASTExprVar *expr) { return enter(static_cast(expr)); } 41 | ASTNode * ASTListener::enter(ASTExprUnary *expr) { return enter(static_cast(expr)); } 42 | ASTNode * ASTListener::enter(ASTExprBinary *expr) { return enter(static_cast(expr)); } 43 | ASTNode * ASTListener::enter(ASTExprAssignment *expr) { return enter(static_cast(expr)); } 44 | ASTNode * ASTListener::enter(ASTExprNew *expr)
{ return enter(static_cast(expr)); } 45 | ASTNode * ASTListener::enter(ASTExprSubscriptAccess *expr) { return enter(static_cast(expr)); } 46 | ASTNode * ASTListener::enter(ASTExprMemberAccess *expr) { return enter(static_cast(expr)); } 47 | ASTNode * ASTListener::enter(ASTExprFuncCall *expr) { return enter(static_cast(expr)); } 48 | ASTNode * ASTListener::enter(ASTStatement *stat) { return stat; } 49 | ASTNode * ASTListener::enter(ASTStatementReturn *stat) { return enter(static_cast(stat)); } 50 | ASTNode * ASTListener::enter(ASTStatementBreak *stat) { return enter(static_cast(stat)); } 51 | ASTNode * ASTListener::enter(ASTStatementContinue *stat) { return enter(static_cast(stat)); } 52 | ASTNode * ASTListener::enter(ASTStatementIf *stat) { return defaultEnter(stat); } 53 | ASTNode * ASTListener::enter(ASTStatementWhile *stat) { return defaultEnter(stat); } 54 | ASTNode * ASTListener::enter(ASTStatementFor *stat) { return defaultEnter(stat); } 55 | ASTNode * ASTListener::enter(ASTStatementBlock *stat) { return defaultEnter(stat); } 56 | ASTNode * ASTListener::enter(ASTStatementExpr *stat) { return enter(static_cast(stat)); } 57 | ASTNode * ASTListener::enter(ASTDeclVarLocal *declVar) { return defaultEnter(declVar); } 58 | ASTNode * ASTListener::enter(ASTDeclVarGlobal *declVar) { return enter(static_cast(declVar)); } 59 | /* ASTListener::leave mirrors enter overload for overload, using defaultLeave() where enter uses defaultEnter(). */ 60 | ASTNode * ASTListener::leave(ASTBlock *block) { return block; } 61 | ASTNode * ASTListener::leave(ASTDeclClass *declClass) { return declClass; } 62 | ASTNode * ASTListener::leave(ASTDeclVar *declVar) { return declVar;} 63 | ASTNode * ASTListener::leave(ASTDeclFunc *declFunc) { return leave(static_cast(declFunc)); } 64 | ASTNode * ASTListener::leave(ASTExpr *expr) { return expr; } 65 | ASTNode * ASTListener::leave(ASTExprImm *expr) { return leave(static_cast(expr)); } 66 | ASTNode * ASTListener::leave(ASTExprVar *expr) { return leave(static_cast(expr)); } 67 | ASTNode * ASTListener::leave(ASTExprUnary *expr) { return leave(static_cast(expr)); }
68 | ASTNode * ASTListener::leave(ASTExprBinary *expr) { return leave(static_cast(expr)); } 69 | ASTNode * ASTListener::leave(ASTExprAssignment *expr) { return leave(static_cast(expr)); } 70 | ASTNode * ASTListener::leave(ASTExprNew *expr) { return leave(static_cast(expr)); } 71 | ASTNode * ASTListener::leave(ASTExprSubscriptAccess *expr) { return leave(static_cast(expr)); } 72 | ASTNode * ASTListener::leave(ASTExprMemberAccess *expr) { return leave(static_cast(expr)); } 73 | ASTNode * ASTListener::leave(ASTExprFuncCall *expr) { return leave(static_cast(expr)); } 74 | ASTNode * ASTListener::leave(ASTStatement *stat) { return stat; } 75 | ASTNode * ASTListener::leave(ASTStatementReturn *stat) { return leave(static_cast(stat)); } 76 | ASTNode * ASTListener::leave(ASTStatementBreak *stat) { return leave(static_cast(stat)); } 77 | ASTNode * ASTListener::leave(ASTStatementContinue *stat) { return leave(static_cast(stat)); } 78 | ASTNode * ASTListener::leave(ASTStatementIf *stat) { return defaultLeave(stat); } 79 | ASTNode * ASTListener::leave(ASTStatementWhile *stat) { return defaultLeave(stat); } 80 | ASTNode * ASTListener::leave(ASTStatementFor *stat) { return defaultLeave(stat); } 81 | ASTNode * ASTListener::leave(ASTStatementBlock *stat) { return defaultLeave(stat); } 82 | ASTNode * ASTListener::leave(ASTStatementExpr *stat) { return leave(static_cast(stat)); } 83 | ASTNode * ASTListener::leave(ASTDeclVarLocal *declVar) { return defaultLeave(declVar); } 84 | ASTNode * ASTListener::leave(ASTDeclVarGlobal *declVar) { return leave(static_cast(declVar)); } 85 | /* Constant-folding evaluators. The wrapping operations (+, -, *, unary -, ++/--, <<) are computed on unsigned and converted back to int. Overflow therefore yields the two's-complement wrapped value instead of undefined behaviour inside the compiler. (The unsigned-to-int conversion is implementation-defined before C++20; GCC, which lib/Makefile uses, wraps modulo 2^N.) NOTE(review): division or modulo by zero, INT_MIN / -1, and shift counts that are negative or >= the bit width of int are still undefined here - confirm callers never fold them. */ 86 | int ASTExprUnary::evaluate(Operator oper, int val) 87 | { 88 | switch (oper) 89 | { 90 | case IncPostfix: case Increment: 91 | return static_cast<int>(static_cast<unsigned>(val) + 1u); 92 | case DecPostfix: case Decrement: 93 | return static_cast<int>(static_cast<unsigned>(val) - 1u); 94 | case Positive: 95 | return val; 96 | case Negative: 97 | return static_cast<int>(0u - static_cast<unsigned>(val)); 98 | case BitNot: 99 | return ~val; 100 | default: 101 | assert(false); 102 | } 103 | return val; 104 | } 105 | bool
ASTExprUnary::evaluate(Operator oper, bool val) 106 | { 107 | if (oper == Not) 108 | return !val; 109 | assert(false); 110 | return val; 111 | } 112 | /* Binary evaluators for int, bool and string operands; comparison results are returned as 0/1 (int) or bool. */ 113 | int ASTExprBinary::evaluate(int valL, Operator oper, int valR) 114 | { 115 | switch (oper) 116 | { 117 | case Plus: 118 | return static_cast<int>(static_cast<unsigned>(valL) + static_cast<unsigned>(valR)); 119 | case Minus: 120 | return static_cast<int>(static_cast<unsigned>(valL) - static_cast<unsigned>(valR)); 121 | case Multiple: 122 | return static_cast<int>(static_cast<unsigned>(valL) * static_cast<unsigned>(valR)); 123 | case Divide: 124 | return valL / valR; 125 | case Mod: 126 | return valL % valR; 127 | case ShiftLeft: 128 | return static_cast<int>(static_cast<unsigned>(valL) << valR); 129 | case ShiftRight: 130 | return valL >> valR; 131 | case BitAnd: 132 | return valL & valR; 133 | case BitXor: 134 | return valL ^ valR; 135 | case BitOr: 136 | return valL | valR; 137 | case Equal: 138 | return valL == valR; 139 | case NotEqual: 140 | return valL != valR; 141 | case GreaterThan: 142 | return valL > valR; 143 | case GreaterEqual: 144 | return valL >= valR; 145 | case LessThan: 146 | return valL < valR; 147 | case LessEqual: 148 | return valL <= valR; 149 | default: 150 | assert(false); 151 | } 152 | return 0; 153 | } 154 | bool ASTExprBinary::evaluate(bool valL, Operator oper, bool valR) 155 | { 156 | switch (oper) 157 | { 158 | case BitAnd: case And: 159 | return valL && valR; 160 | case BitOr: case Or: 161 | return valL || valR; 162 | case Equal: 163 | return valL == valR; 164 | case BitXor: case NotEqual: 165 | return valL != valR; 166 | case GreaterThan: 167 | return valL > valR; 168 | case GreaterEqual: 169 | return valL >= valR; 170 | case LessThan: 171 | return valL < valR; 172 | case LessEqual: 173 | return valL <= valR; 174 | default: 175 | assert(false); 176 | } 177 | return false; 178 | } 179 | 180 | bool ASTExprBinary::stringCompare(const std::string &valL, Operator oper, const std::string &valR) 181 | { 182 | switch (oper) 183 | { 184 | case Equal: 185 | return valL == valR; 186 | case NotEqual: 187 | return valL != valR; 188 | case GreaterThan: 189 | return valL > valR; 190 | case GreaterEqual: 191 | return valL >=
valR; 192 | case LessThan: 193 | return valL < valR; 194 | case LessEqual: 195 | return valL <= valR; 196 | default: 197 | assert(false); 198 | } 199 | return false; 200 | } 201 | 202 | } -------------------------------------------------------------------------------- /src/ASTConstructor.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_AST_CONSTRUCTOR_H 2 | #define MX_COMPILER_AST_CONSTRUCTOR_H 3 | 4 | #include "AST.h" 5 | #include "GlobalSymbol.h" 6 | #include "IssueCollector.h" 7 | #include 8 | 9 | class ASTConstructor : protected MxParserBaseVisitor 10 | { 11 | public: 12 | ASTConstructor(IssueCollector *issues) : issues(issues) {} 13 | MxAST::ASTRoot * constructAST(MxParser::ProgContext *prog, GlobalSymbol *symbolTable); 14 | 15 | protected: 16 | template 17 | static T * newNode(antlr4::tree::ParseTree *tree, Tparam&&... param) 18 | { 19 | T *ret = new T(std::forward(param)...); 20 | ret->tokenL = tree->getSourceInterval().a; 21 | ret->tokenR = tree->getSourceInterval().b; 22 | return ret; 23 | } 24 | antlrcpp::Any visitPrefixUnaryExpr(antlr4::ParserRuleContext *ctx, MxAST::ASTExprUnary::Operator oper); 25 | antlrcpp::Any visitBinaryExpr(antlr4::ParserRuleContext *ctx, MxAST::ASTExprBinary::Operator oper); 26 | 27 | MxType getType(MxParser::TypeNotArrayContext *ctx); 28 | MxType getType(MxParser::TypeContext *ctx); 29 | std::vector> getParamList(MxParser::ParamListContext *ctx); 30 | std::vector> getVarDecl(MxParser::VarDeclContext *ctx, bool isGlobal); 31 | std::vector> getStatment(MxParser::StatementContext *ctx); 32 | std::vector> getExprList(MxParser::ExprListContext *ctx); 33 | /* visitVarDecl, visitParamList and visitStatement below only assert(false); presumably those rules are consumed through getVarDecl / getParamList / getStatment instead - confirm in ASTConstructor.cpp. */ 34 | virtual antlrcpp::Any visitBlock(MxParser::BlockContext *ctx) override; 35 | virtual antlrcpp::Any visitProg(MxParser::ProgContext *ctx) override; 36 | virtual antlrcpp::Any visitClassDecl(MxParser::ClassDeclContext *ctx) override; 37 | virtual antlrcpp::Any visitFuncDecl(MxParser::FuncDeclContext *ctx) override; 38 |
virtual antlrcpp::Any visitVarDecl(MxParser::VarDeclContext *ctx) override { assert(false); return nullptr; } 39 | virtual antlrcpp::Any visitParamList(MxParser::ParamListContext *ctx) override { assert(false); return nullptr; } 40 | virtual antlrcpp::Any visitStatement(MxParser::StatementContext *ctx) override { assert(false); return nullptr; } 41 | virtual antlrcpp::Any visitIf_statement(MxParser::If_statementContext *ctx) override; 42 | virtual antlrcpp::Any visitWhile_statement(MxParser::While_statementContext *ctx) override; 43 | virtual antlrcpp::Any visitFor_statement(MxParser::For_statementContext *ctx) override; 44 | virtual antlrcpp::Any visitExprPrimary(MxParser::ExprPrimaryContext *ctx) override; 45 | virtual antlrcpp::Any visitExprIncrementPostfix(MxParser::ExprIncrementPostfixContext *ctx) override; 46 | virtual antlrcpp::Any visitExprDecrementPostfix(MxParser::ExprDecrementPostfixContext *ctx) override; 47 | virtual antlrcpp::Any visitExprMember(MxParser::ExprMemberContext *ctx) override; 48 | virtual antlrcpp::Any visitExprFuncCall(MxParser::ExprFuncCallContext *ctx) override; 49 | virtual antlrcpp::Any visitExprSubscript(MxParser::ExprSubscriptContext *ctx) override; 50 | virtual antlrcpp::Any visitExprIncrementPrefix(MxParser::ExprIncrementPrefixContext *ctx) override { return visitPrefixUnaryExpr(ctx, MxAST::ASTExprUnary::Increment); } 51 | virtual antlrcpp::Any visitExprDecrementPrefix(MxParser::ExprDecrementPrefixContext *ctx) override { return visitPrefixUnaryExpr(ctx, MxAST::ASTExprUnary::Decrement); } 52 | virtual antlrcpp::Any visitExprPositive(MxParser::ExprPositiveContext *ctx) override { return visitPrefixUnaryExpr(ctx, MxAST::ASTExprUnary::Positive); } 53 | virtual antlrcpp::Any visitExprNegative(MxParser::ExprNegativeContext *ctx) override { return visitPrefixUnaryExpr(ctx, MxAST::ASTExprUnary::Negative); } 54 | virtual antlrcpp::Any visitExprNot(MxParser::ExprNotContext *ctx) override { return visitPrefixUnaryExpr(ctx,
MxAST::ASTExprUnary::Not); } 55 | virtual antlrcpp::Any visitExprBitNot(MxParser::ExprBitNotContext *ctx) override { return visitPrefixUnaryExpr(ctx, MxAST::ASTExprUnary::BitNot); } 56 | virtual antlrcpp::Any visitExprNew(MxParser::ExprNewContext *ctx) override; 57 | virtual antlrcpp::Any visitExprMulti(MxParser::ExprMultiContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Multiple); } 58 | virtual antlrcpp::Any visitExprDiv(MxParser::ExprDivContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Divide); } 59 | virtual antlrcpp::Any visitExprMod(MxParser::ExprModContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Mod); } 60 | virtual antlrcpp::Any visitExprPlus(MxParser::ExprPlusContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Plus); } 61 | virtual antlrcpp::Any visitExprMinus(MxParser::ExprMinusContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Minus); } 62 | virtual antlrcpp::Any visitExprShiftLeft(MxParser::ExprShiftLeftContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::ShiftLeft); } 63 | virtual antlrcpp::Any visitExprShiftRight(MxParser::ExprShiftRightContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::ShiftRight); } 64 | virtual antlrcpp::Any visitExprLessThan(MxParser::ExprLessThanContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::LessThan); } 65 | virtual antlrcpp::Any visitExprLessEqual(MxParser::ExprLessEqualContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::LessEqual); } 66 | virtual antlrcpp::Any visitExprGreaterThan(MxParser::ExprGreaterThanContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::GreaterThan); } 67 | virtual antlrcpp::Any visitExprGreaterEqual(MxParser::ExprGreaterEqualContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::GreaterEqual); } 68 | virtual antlrcpp::Any
visitExprEqual(MxParser::ExprEqualContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Equal); } 69 | virtual antlrcpp::Any visitExprNotEqual(MxParser::ExprNotEqualContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::NotEqual); } 70 | virtual antlrcpp::Any visitExprBitand(MxParser::ExprBitandContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::BitAnd); } 71 | virtual antlrcpp::Any visitExprXor(MxParser::ExprXorContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::BitXor); } 72 | virtual antlrcpp::Any visitExprBitor(MxParser::ExprBitorContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::BitOr); } 73 | virtual antlrcpp::Any visitExprAnd(MxParser::ExprAndContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::And); } 74 | virtual antlrcpp::Any visitExprOr(MxParser::ExprOrContext *ctx) override { return visitBinaryExpr(ctx, MxAST::ASTExprBinary::Or); } 75 | virtual antlrcpp::Any visitExprAssignment(MxParser::ExprAssignmentContext *ctx) override; 76 | /* NOTE(review): presumably decodes the escape sequences of a string literal. tokenL/tokenR look like token indices (newNode stores getSourceInterval().a/.b into nodes) - confirm in ASTConstructor.cpp. */ 77 | std::string transferString(const std::string &in, ssize_t tokenL, ssize_t tokenR); 78 | 79 | protected: 80 | IssueCollector *issues; 81 | GlobalSymbol *symbols; 82 | std::unique_ptr node; 83 | size_t curClass; 84 | IF_DEBUG(std::string strCurClass); 85 | }; 86 | 87 | #endif -------------------------------------------------------------------------------- /src/ASTVisualizer.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_AST_VISUALIZER_H 2 | #define MX_COMPILER_AST_VISUALIZER_H 3 | 4 | #include "AST.h" 5 | #include "GlobalSymbol.h" 6 | 7 | class ASTVisualizer : public MxAST::ASTVisitor 8 | { 9 | public: 10 | ASTVisualizer(std::ostream &out, GlobalSymbol &symbol) : out(out), cntNode(0), symbol(symbol), inClass(false) {} 11 | void printHead(); 12 | void printFoot(); 13 | 14 | //virtual void visit(MxAST::ASTNode *node) override;
virtual void visit(MxAST::ASTRoot *root) override; 16 | virtual void visit(MxAST::ASTDeclClass *declClass) override; 17 | virtual void visit(MxAST::ASTDeclFunc *declFunc) override; 18 | virtual void visit(MxAST::ASTDeclVar *declVar) override; 19 | virtual void visit(MxAST::ASTExprImm *expr) override; 20 | virtual void visit(MxAST::ASTExprVar *expr) override; 21 | virtual void visit(MxAST::ASTExprUnary *expr) override; 22 | virtual void visit(MxAST::ASTExprBinary *expr) override; 23 | virtual void visit(MxAST::ASTExprAssignment *expr) override; 24 | virtual void visit(MxAST::ASTExprNew *expr) override; 25 | virtual void visit(MxAST::ASTExprSubscriptAccess *expr) override; 26 | virtual void visit(MxAST::ASTExprMemberAccess *expr) override; 27 | virtual void visit(MxAST::ASTExprFuncCall *expr) override; 28 | virtual void visit(MxAST::ASTStatementReturn *stat) override; 29 | virtual void visit(MxAST::ASTStatementBreak *stat) override; 30 | virtual void visit(MxAST::ASTStatementContinue *stat) override; 31 | virtual void visit(MxAST::ASTStatementIf *stat) override; 32 | virtual void visit(MxAST::ASTStatementWhile *stat) override; 33 | virtual void visit(MxAST::ASTStatementFor *stat) override; 34 | virtual void visit(MxAST::ASTStatementBlock *stat) override; 35 | virtual void visit(MxAST::ASTStatementExpr *stat) override; 36 | 37 | 38 | protected: 39 | std::string type2html(MxType type); 40 | //static std::string transferHTML(const std::string &in); 41 | static std::string exprColor(MxAST::ASTExpr *expr); 42 | 43 | protected: 44 | bool inClass; 45 | std::ostream &out; 46 | GlobalSymbol &symbol; 47 | size_t cntNode, lastNode; 48 | }; 49 | 50 | #endif -------------------------------------------------------------------------------- /src/CodeGenerator.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_CODE_GENERATOR_H 2 | #define MX_COMPILER_CODE_GENERATOR_H 3 | 4 | #include "common.h" 5 | #include 
"CodeGeneratorBasic.h" 6 | 7 | class CodeGenerator : public CodeGeneratorBasic 8 | { 9 | public: 10 | CodeGenerator(std::ostream &out) : CodeGeneratorBasic(out) {} 11 | 12 | protected: 13 | virtual void generateFunc(MxProgram::funcInfo &finfo, const std::string &label) override; 14 | virtual void translateIns(MxIR::Instruction ins) override; 15 | 16 | void initFuncEntryExit(); 17 | void regularizeInsnPre(); 18 | void setRegisterConstrains(); 19 | void setRegisterPrefer(); 20 | void regularizeInsnPost(); 21 | void allocateStackFrame(); 22 | std::string getOperand(MxIR::Operand operand); 23 | 24 | protected: 25 | bool popRBP; 26 | size_t varID; 27 | MxIR::Function *func; 28 | std::map stackFrame; 29 | size_t stackSize; 30 | static const std::vector regCallerSave, regCalleeSave, regParam; 31 | }; 32 | 33 | #endif -------------------------------------------------------------------------------- /src/CodeGeneratorBasic.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_CODE_GENERATOR_BASIC_H 2 | #define MX_COMPILER_CODE_GENERATOR_BASIC_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "MxProgram.h" 7 | #include "GlobalSymbol.h" 8 | 9 | class CodeGeneratorBasic 10 | { 11 | public: 12 | CodeGeneratorBasic(std::ostream &out) : 13 | program(MxProgram::getDefault()), symbol(GlobalSymbol::getDefault()), out(out), cntLocalLabel(0) {} 14 | 15 | void generateProgram(); 16 | 17 | protected: 18 | void createLabel(); 19 | virtual std::string decorateFuncName(const MxProgram::funcInfo &finfo); 20 | virtual void generateFunc(MxProgram::funcInfo &finfo, const std::string &label); 21 | virtual void generateConst(const MxProgram::constInfo &cinfo, const std::string &label); 22 | virtual void generateVar(const MxProgram::varInfo &vinfo, const std::string &label); 23 | 24 | virtual std::vector sortBlocks(MxIR::Block *inBlock); 25 | virtual void translateBlocks(const std::vector &vBlocks); 26 | virtual void 
translateIns(MxIR::Instruction ins); 27 | 28 | std::string loadOperand(int id, MxIR::Operand src); 29 | void loadOperand(std::string reg, MxIR::Operand src); 30 | /* NOTE(review): tempreg defaults to 11, presumably the scratch register r11 - confirm. */ std::string getConst(MxIR::Operand src, bool immSigned = true, int tempreg = 11); 31 | std::string getVReg(MxIR::Operand src); 32 | std::string getVRegAddr(MxIR::Operand src); 33 | std::string storeOperand(MxIR::Operand dst, int id); 34 | 35 | template 36 | /* Writes one output line: a tab, the arguments via prints(), then std::endl. */ void writeCode(T&&... val) 37 | { 38 | out << "\t"; 39 | prints(out, std::forward(val)...); 40 | out << std::endl; 41 | } 42 | template 43 | /* Writes the arguments followed by a colon and std::endl. */ void writeLabel(T&&... val) 44 | { 45 | prints(out, std::forward(val)...); 46 | out << ":" << std::endl; 47 | } 48 | 49 | protected: 50 | MxProgram *program; 51 | GlobalSymbol *symbol; 52 | std::ostream &out; 53 | std::vector labelFunc, labelVar; 54 | size_t cntLocalLabel; 55 | std::vector regAddr; 56 | std::list> allocAddr; 57 | 58 | static const std::string paramReg[]; 59 | static const int paramRegID[]; 60 | }; 61 | 62 | #endif -------------------------------------------------------------------------------- /src/ConstantFold.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "ConstantFold.h" 3 | using namespace MxAST; 4 | 5 | namespace ASTOptimizer 6 | { 7 | /* Folds a unary operator applied to an immediate operand and returns the new immediate; returns expr unchanged when the operand is not an immediate. A null operand is only accepted under Not, which folds to true. */ ASTNode * ConstantFold::leave(ASTExprUnary *expr) 8 | { 9 | ASTExprImm *imm = dynamic_cast(expr->operand.get()); 10 | if (!imm) 11 | return expr; 12 | 13 | std::unique_ptr newImm; 14 | if (imm->exprType.mainType == MxType::Bool) 15 | newImm.reset(new ASTExprImm(expr->evaluate(imm->exprVal.bvalue))); 16 | else if (imm->exprType.mainType == MxType::Integer) 17 | newImm.reset(new ASTExprImm(expr->evaluate(imm->exprVal.ivalue))); 18 | else 19 | { 20 | assert(imm->exprType.mainType == MxType::Object && imm->exprType.className == size_t(-1)); // null 21 | assert(expr->oper == ASTExprUnary::Not); 22 | newImm.reset(new ASTExprImm(true)); 23 | } 24 | newImm->tokenL = expr->tokenL; 25 |
newImm->tokenR = expr->tokenR; 26 | return newImm.release(); 27 | } 28 | 29 | /* Folds a binary operator whose operands are both immediates (int, bool, null or string) and returns the new immediate. Returns expr unchanged when an operand is not an immediate, on division or modulo by zero (also reported as an error), or when a string concatenation would exceed MAX_STRINGSIZE or MAX_STRINGMEMUSAGE. Folded string literals release their references in GlobalSymbol. */ ASTNode * ConstantFold::leave(ASTExprBinary *expr) 30 | { 31 | ASTExprImm *operandL = dynamic_cast(expr->operandL.get()); 32 | ASTExprImm *operandR = dynamic_cast(expr->operandR.get()); 33 | if (!operandL || !operandR) 34 | return expr; 35 | 36 | if (expr->oper == ASTExprBinary::Divide || expr->oper == ASTExprBinary::Mod) 37 | { 38 | assert(operandL->exprType.mainType == MxType::Integer); 39 | assert(operandR->exprType.mainType == MxType::Integer); 40 | if (operandR->exprVal.ivalue == 0) 41 | { 42 | issues->error(expr->tokenL, expr->tokenR, 43 | "Divided by zero"); 44 | return expr; 45 | } 46 | } 47 | 48 | std::unique_ptr newImm; 49 | if (operandL->exprType.mainType == MxType::Integer) 50 | { 51 | assert(operandR->exprType.mainType == MxType::Integer); 52 | newImm.reset(new ASTExprImm(expr->evaluate(operandL->exprVal.ivalue, operandR->exprVal.ivalue))); 53 | } 54 | else if (operandL->exprType.mainType == MxType::Bool) 55 | { 56 | assert(operandR->exprType.mainType == MxType::Bool); 57 | newImm.reset(new ASTExprImm(expr->evaluate(operandL->exprVal.bvalue, operandR->exprVal.bvalue))); 58 | } 59 | else if (operandL->exprType.isNull()) 60 | { 61 | assert(operandR->exprType.isNull()); 62 | /* null == null holds, so the comparison folds to true except for NotEqual, which folds to false. */ newImm.reset(new ASTExprImm(expr->oper != ASTExprBinary::NotEqual)); 63 | } 64 | else 65 | { 66 | assert(operandL->exprType.mainType == MxType::String); 67 | assert(operandR->exprType.mainType == MxType::String); 68 | std::string strL = symbols->vString[operandL->exprVal.strId]; 69 | std::string strR = symbols->vString[operandR->exprVal.strId]; 70 | if (expr->oper == ASTExprBinary::Plus) 71 | { 72 | if (symbols->sumStringSize + strL.size() + strR.size() > MAX_STRINGSIZE) 73 | return expr; 74 | if (symbols->memoryUsage + strL.size() + strR.size() > MAX_STRINGMEMUSAGE) 75 | return expr; 76 | size_t sid = symbols->addString(strL + strR); 77 | symbols->decStringRef(operandL->exprVal.strId); 78 | symbols->decStringRef(operandR->exprVal.strId); 79 |
newImm.reset(new ASTExprImm); 80 | newImm->exprType = MxType{ MxType::String }; 81 | newImm->exprVal.strId = sid; 82 | IF_DEBUG(newImm->strContent = strL + strR); 83 | } 84 | else 85 | { /* The operand literals are dropped here as well, so release their references exactly as the Plus branch does. */ newImm.reset(new ASTExprImm(expr->stringCompare(strL, strR))); symbols->decStringRef(operandL->exprVal.strId); symbols->decStringRef(operandR->exprVal.strId); } 86 | } 87 | newImm->tokenL = expr->tokenL; 88 | newImm->tokenR = expr->tokenR; 89 | return newImm.release(); 90 | } 91 | 92 | } -------------------------------------------------------------------------------- /src/ConstantFold.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_CONSTANT_FOLDER_H 2 | #define MX_COMPILER_CONSTANT_FOLDER_H 3 | 4 | #include "AST.h" 5 | #include "IssueCollector.h" 6 | #include "GlobalSymbol.h" 7 | 8 | namespace ASTOptimizer 9 | { 10 | /* AST listener that folds unary and binary operators whose operands are immediates; the default constructor uses IssueCollector::getDefault() and GlobalSymbol::getDefault(). */ class ConstantFold : public MxAST::ASTListener 11 | { 12 | public: 13 | ConstantFold() : issues(IssueCollector::getDefault()), symbols(GlobalSymbol::getDefault()) {} 14 | ConstantFold(IssueCollector *issues, GlobalSymbol *symbols) : issues(issues), symbols(symbols) {} 15 | 16 | virtual MxAST::ASTNode * leave(MxAST::ASTExprUnary *expr) override; 17 | virtual MxAST::ASTNode * leave(MxAST::ASTExprBinary *expr) override; 18 | 19 | protected: 20 | IssueCollector *issues; 21 | GlobalSymbol *symbols; 22 | }; 23 | } 24 | 25 | #endif -------------------------------------------------------------------------------- /src/DeadCodeElimination.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "DeadCodeElimination.h" 3 | 4 | namespace MxIR 5 | { 6 | /* Builds the PST, then alternates dead-value elimination with removal of dead PST regions, repeating while a region pass made progress (at most maxIter = 50 rounds). */ void DeadCodeElimination::work() 7 | { 8 | static const size_t maxIter = 50; 9 | 10 | func.splitProgramRegion(); 11 | func.constructPST(); 12 | 13 | collectVars(); 14 | 15 | size_t i = 0; 16 | do 17 | { 18 | eliminateVars(); 19 | countGlobalUse(); 20 | 21 | regionUpdated = false; 22 | eliminateRegions(func.pstRoot.get()); 23 | } while (regionUpdated && ++i < maxIter); 24 | } 25 | 26 | bool
DeadCodeElimination::hasSideEffect(const Instruction &ins) 27 | { 28 | /* A Call is live unless its callee is a known function (funcID operand) whose attribute includes NoSideEffect; Store, StoreA and Return are always live. Note that && binds tighter than || here. */ return ins.oper == Call && (ins.src1.type != Operand::funcID || !(program->vFuncs[ins.src1.val].attribute & NoSideEffect)) 29 | || ins.oper == Store || ins.oper == StoreA 30 | || ins.oper == Return; 31 | } 32 | 33 | /* Records, for every phi destination and instruction output register, its defining block and an iterator to the defining phi (isPhi) or instruction. */ void DeadCodeElimination::collectVars() 34 | { 35 | func.inBlock->traverse([this](Block *block) -> bool 36 | { 37 | for (auto iter = block->phi.begin(); iter != block->phi.end(); ++iter) 38 | { 39 | vars[iter->second.dst].block = block; 40 | vars[iter->second.dst].isPhi = true; 41 | vars[iter->second.dst].iterPhi = iter; 42 | } 43 | for (auto iter = block->ins.begin(); iter != block->ins.end(); ++iter) 44 | { 45 | for (Operand *operand : iter->getOutputReg()) 46 | { 47 | vars[*operand].block = block; 48 | vars[*operand].isPhi = false; 49 | vars[*operand].iterInsn = iter; 50 | } 51 | } 52 | return true; 53 | }); 54 | } 55 | 56 | /* Mark-and-sweep over values. Seeds: every operand of Br and of side-effecting instructions; blocks holding a side effect are blacklisted from region removal. Liveness then propagates backwards through phi sources and instruction inputs. Finally, unmarked phis are erased, as is every side-effect-free instruction other than Jump/Br that has no marked output. NOTE(review): vars is only filled by collectVars, so entries of erased instructions keep stale iterators; this relies on such values never being marked later. */ void DeadCodeElimination::eliminateVars() 57 | { 58 | std::set marked; 59 | func.inBlock->traverse([&marked, this](Block *block) -> bool 60 | { 61 | bool flag = false; 62 | for (auto &ins : block->ins) 63 | { 64 | if (hasSideEffect(ins) || ins.oper == Br) 65 | { 66 | for (Operand *operand : join(ins.getInputReg(), ins.getOutputReg())) 67 | marked.insert(*operand); 68 | if(ins.oper != Br) 69 | flag = true; 70 | } 71 | } 72 | if (flag) 73 | blockBlacklist.insert(block); 74 | return true; 75 | }); 76 | 77 | std::queue worklist; 78 | for (auto &op : marked) 79 | worklist.push(op); 80 | while (!worklist.empty()) 81 | { 82 | Operand cur = worklist.front(); 83 | worklist.pop(); 84 | 85 | if (!vars.count(cur)) 86 | continue; 87 | 88 | if (vars[cur].isPhi) 89 | { 90 | for (auto &src : vars[cur].iterPhi->second.srcs) 91 | { 92 | if (src.first.isReg()) 93 | if (!marked.count(src.first)) 94 | { 95 | marked.insert(src.first); 96 | worklist.push(src.first); 97 | } 98 | } 99 | } 100 | else 101 | { 102 | for (Operand *operand : vars[cur].iterInsn->getInputReg()) 103 | { 104 |
if (!marked.count(*operand)) 105 | { 106 | marked.insert(*operand); 107 | worklist.push(*operand); 108 | } 109 | } 110 | } 111 | } 112 | 113 | func.inBlock->traverse([&marked, this](Block *block) -> bool 114 | { 115 | for (auto iter = block->phi.begin(); iter != block->phi.end();) 116 | { 117 | if (!marked.count(iter->second.dst)) 118 | iter = block->phi.erase(iter); 119 | else 120 | ++iter; 121 | } 122 | for (auto iter = block->ins.begin(); iter != block->ins.end();) 123 | { 124 | if (hasSideEffect(*iter) || iter->oper == Jump || iter->oper == Br) 125 | { 126 | ++iter; 127 | continue; 128 | } 129 | 130 | bool flag = false; 131 | for (Operand *operand : iter->getOutputReg()) 132 | if (marked.count(*operand)) 133 | { 134 | flag = true; 135 | break; 136 | } 137 | if (!flag) 138 | iter = block->ins.erase(iter); 139 | else 140 | ++iter; 141 | } 142 | return true; 143 | }); 144 | } 145 | 146 | /* Resets, then recomputes, useCount for every known value from the input registers of all instructions() in the function. */ void DeadCodeElimination::countGlobalUse() 147 | { 148 | for (auto &kv : vars) 149 | kv.second.useCount = 0; 150 | func.inBlock->traverse([this](Block *block) -> bool 151 | { 152 | for (auto &ins : block->instructions()) 153 | for (Operand *operand : ins.getInputReg()) 154 | { 155 | if (!vars.count(*operand)) 156 | continue; 157 | vars[*operand].useCount++; 158 | } 159 | return true; 160 | }); 161 | } 162 | 163 | /* Input-register use counts restricted to the given blocks; only values present in vars are counted. */ std::map DeadCodeElimination::countUse(const std::set &blocks) 164 | { 165 | std::map ret; 166 | for (Block *block : blocks) 167 | { 168 | for (auto &ins : block->instructions()) 169 | for (Operand *operand : ins.getInputReg()) 170 | { 171 | if (!vars.count(*operand)) 172 | continue; 173 | ret[*operand]++; 174 | } 175 | } 176 | return ret; 177 | } 178 | 179 | /* Set of output registers written by instructions() of the given blocks. */ std::set DeadCodeElimination::findDef(const std::set &blocks) 180 | { 181 | std::set ret; 182 | for (Block *block : blocks) 183 | { 184 | for (auto &ins : block->instructions()) 185 | for (Operand *operand : ins.getOutputReg()) 186 | ret.insert(*operand); 187 | } 188 | return ret; 189 | } 190 | 191 | int
DeadCodeElimination::eliminateRegions(PSTNode *node) 192 | { 193 | /* Returns -1 when this region, and therefore every ancestor, must be kept: a child returned -1, one of its blocks is blacklisted, or its entry/exit is the function's entry/exit block. Returns 0 when it is kept because a value defined inside is used outside. Returns 1 when it has been unlinked from the CFG; the caller then erases it from its children. */ bool blacklisted = false; 194 | for (auto iter = node->children.begin(); iter != node->children.end();) 195 | { 196 | int ret = eliminateRegions(iter->get()); 197 | if (ret == 1) 198 | iter = node->children.erase(iter); 199 | else 200 | ++iter; 201 | 202 | if (ret == -1) 203 | blacklisted = true; 204 | } 205 | 206 | if (blacklisted) 207 | return -1; 208 | 209 | for (Block *block : node->blocks) 210 | if (blockBlacklist.count(block)) 211 | return -1; 212 | 213 | if (node->inBlock == func.inBlock.get() || node->outBlock == func.outBlock.get()) 214 | return -1; 215 | 216 | std::set blocks = node->getBlocks(); 217 | 218 | auto def = findDef(blocks); 219 | auto use = countUse(blocks); 220 | for (auto &op : def) 221 | { 222 | assert(vars.count(op)); 223 | if (vars[op].useCount != use[op]) 224 | return 0; 225 | } 226 | 227 | 228 | 229 | /*std::cerr << "eliminate region with " << blocks.size() << " blocks" << std::endl; 230 | IRVisualizer visual(std::cerr); 231 | std::cerr << visual.toString(**blocks.begin(), false) << std::endl;*/ 232 | 233 | /* Find the single predecessor outside the region and the single successor outside it. */ Block *pred = nullptr, *next = nullptr; 234 | for (Block *block : node->inBlock->preds) 235 | if (!blocks.count(block)) 236 | { 237 | assert(!pred); 238 | pred = block; 239 | } 240 | for (Block *block : { node->outBlock->brTrue.get(), node->outBlock->brFalse.get() }) 241 | if (block && !blocks.count(block)) 242 | { 243 | assert(!next); 244 | next = block; 245 | } 246 | 247 | std::shared_ptr inBlock = node->inBlock->self.lock(); 248 | 249 | /* If pred already branches to next, the redirected edge goes through a new jump-only block (presumably so that pred's two successors stay distinct). When the removed region was itself one jump-only block this is no net change, so flag keeps it from counting as progress. */ bool flag = false; 250 | std::shared_ptr tmp; 251 | if (pred->brTrue.get() == next || pred->brFalse.get() == next) 252 | { 253 | if (blocks.size() == 1 && (*blocks.begin())->ins.size() == 1 && (*blocks.begin())->ins.front().oper == Jump) 254 | flag = true; 255 | 256 | tmp = Block::construct(); 257 | tmp->ins = { IRJump() }; 258 | if (inBlock.get() == pred->brTrue.get()) 259 | pred->brTrue = tmp; 260 | if (inBlock.get() ==
pred->brFalse.get()) 261 | pred->brFalse = tmp; 262 | tmp->brTrue = inBlock; 263 | pred = tmp.get(); 264 | } 265 | 266 | if (inBlock.get() == pred->brTrue.get()) 267 | pred->brTrue = next->self.lock(); 268 | if (inBlock.get() == pred->brFalse.get()) 269 | pred->brFalse = next->self.lock(); 270 | next->redirectPhiSrc(node->outBlock, pred); 271 | 272 | if (node->outBlock->brTrue.get() == next) 273 | node->outBlock->brTrue.reset(); 274 | else 275 | node->outBlock->brFalse.reset(); 276 | 277 | if(!flag) 278 | regionUpdated = true; 279 | 280 | return 1; 281 | } 282 | } -------------------------------------------------------------------------------- /src/DeadCodeElimination.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_DEAD_CODE_ELIMINATION_H 2 | #define MX_COMPILER_DEAD_CODE_ELIMINATION_H 3 | 4 | #include "common.h" 5 | #include "MxProgram.h" 6 | #include "IR.h" 7 | 8 | namespace MxIR 9 | { 10 | /* Dead code elimination for one MxIR::Function; work() runs the whole pass. */ class DeadCodeElimination 11 | { 12 | public: 13 | DeadCodeElimination(Function &func) : func(func), program(MxProgram::getDefault()) {} 14 | void work(); 15 | 16 | protected: 17 | bool hasSideEffect(const Instruction &ins); 18 | void collectVars(); 19 | void eliminateVars(); 20 | 21 | void countGlobalUse(); 22 | std::map countUse(const std::set &blocks); 23 | static std::set findDef(const std::set &blocks); 24 | int eliminateRegions(PSTNode *node); 25 | 26 | protected: 27 | /* Definition site of a value: its block plus an iterator to its phi (isPhi) or instruction; useCount is refreshed by countGlobalUse. */ struct VarProperty 28 | { 29 | bool isPhi; 30 | Block *block; 31 | std::list::iterator iterInsn; 32 | std::map::iterator iterPhi; 33 | 34 | size_t useCount = 0; 35 | }; 36 | 37 | MxProgram *program; 38 | Function &func; 39 | std::map vars; 40 | std::set blockBlacklist; 41 | /* No initializer; work() sets it before each region pass. */ bool regionUpdated; 42 | }; 43 | } 44 | 45 | #endif -------------------------------------------------------------------------------- /src/GVN.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_GVN_H 2 | #define
MX_COMPILER_GVN_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "MxProgram.h" 7 | #include "utils/DomTree.h" 8 | 9 | namespace MxIR 10 | { 11 | /* Global value numbering for one MxIR::Function: instructions and phis are numbered into hashed ValueNode trees; judging by computeVarGroup/renameVar, values with equal numbers are grouped and renamed - confirm in GVN.cpp. */ class GVN 12 | { 13 | public: 14 | GVN(Function &func) : func(func), program(MxProgram::getDefault()) {} 15 | void work(); 16 | 17 | protected: 18 | enum NodeType : std::uint8_t 19 | { 20 | Imm, Var, Const, OperCommAssoc, OperBinary, OperUnary, OperFuncCall, OperPhiIf, OperPhi 21 | }; 22 | /* Byte accumulator for a value description. process() appends the raw object bytes of its argument, so it is only meaningful for trivially-copyable types without padding; cacluateHash() presumably digests buffer into the five 32-bit words of hash (SHA-1 sized) - confirm. hash has no initializer here. */ class ValueHash 23 | { 24 | public: 25 | template 26 | void process(const T &value) 27 | { 28 | process_bytes(&value, sizeof(value)); 29 | } 30 | void process_bytes(const void *buffer, size_t length); 31 | void cacluateHash(); 32 | 33 | bool operator<(const ValueHash &rhs) const; 34 | bool operator==(const ValueHash &rhs) const; 35 | bool operator!=(const ValueHash &rhs) const { return !(*this == rhs); } 36 | 37 | void printHash(std::ostream &out); 38 | const std::uint32_t (&getHash() const)[5] { return hash; } 39 | 40 | 41 | protected: 42 | std::uint32_t hash[5]; 43 | std::vector buffer; 44 | }; 45 | struct ValueNode 46 | { 47 | const NodeType nodeType; 48 | ValueHash hash; 49 | /* Defaults to 0 so that nodes built with the single-argument constructor never expose an indeterminate length to equal_impl. */ size_t length = 0; 50 | 51 | ValueNode(NodeType nodeType) : nodeType(nodeType) 52 | { 53 | hash.process(char(nodeType)); 54 | } 55 | ValueNode(NodeType nodeType, size_t length) : nodeType(nodeType), length(length) 56 | { 57 | hash.process(char(nodeType)); 58 | hash.process(char(length)); 59 | } 60 | virtual ~ValueNode() {} 61 | 62 | bool operator<(const ValueNode &rhs) const { return hash < rhs.hash; } 63 | virtual bool equal(ValueNode &rhs) = 0; 64 | 65 | protected: 66 | template 67 | /* Cheap checks first (node type, length, hash); then rhs is cast to T and the listed members are compared in order via equal_recursive. */ static bool equal_impl(T &lhs, ValueNode &rhs, Tparam ...param) 68 | { 69 | if (lhs.nodeType != rhs.nodeType || lhs.length != rhs.length || lhs.hash != rhs.hash) 70 | return false; 71 | T &node = dynamic_cast(rhs); 72 | return equal_recursive(lhs, node, param...); 73 | } 74 | 75 | private: 76 | template 77 | static bool equal_recursive(T &lhs, T &rhs) { return true; } 78
| template 79 | /* Member-pointer overloads: plain members compare with !=, shared_ptr members and vectors of them compare structurally through equal(). */ static bool equal_recursive(T &lhs, T &rhs, Tparam param, Tother ...other) 80 | { 81 | if (lhs.*param != rhs.*param) 82 | return false; 83 | return equal_recursive(lhs, rhs, other...); 84 | } 85 | template 86 | static bool equal_recursive(T &lhs, T &rhs, std::shared_ptr T::*param, Tother ...other) 87 | { 88 | if (!((lhs.*param).get()->equal(*(rhs.*param)))) 89 | return false; 90 | return equal_recursive(lhs, rhs, other...); 91 | } 92 | template 93 | static bool equal_recursive(T &lhs, T &rhs, std::vector> T::*param, Tother ...other) 94 | { 95 | if ((lhs.*param).size() != (rhs.*param).size()) 96 | return false; 97 | for (size_t i = 0; i < (lhs.*param).size(); i++) 98 | if (!(lhs.*param)[i]->equal(*(rhs.*param)[i])) 99 | return false; 100 | return equal_recursive(lhs, rhs, other...); 101 | } 102 | }; 103 | struct ValueImm : public ValueNode 104 | { 105 | std::uint64_t val; 106 | 107 | ValueImm(Operand operand); 108 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs, &ValueImm::val); } 109 | }; 110 | struct ValueConst : public ValueNode 111 | { 112 | Operand::OperandType type; 113 | std::uint64_t val; 114 | 115 | ValueConst(Operand operand); 116 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs, &ValueConst::type, &ValueConst::val); } 117 | }; 118 | struct ValueVar : public ValueNode 119 | { 120 | size_t varID, ver; 121 | 122 | ValueVar(Operand operand); 123 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs, &ValueVar::varID, &ValueVar::ver); } 124 | }; 125 | struct ValueCommAssoc : public ValueNode 126 | { 127 | enum Operator { Add, Mult, And, Or, Xor }; 128 | Operator oper; 129 | std::vector> varValue; 130 | ValueImm immValue; 131 | 132 | ValueCommAssoc(Operator oper, size_t length, const std::vector> &varValue, ValueImm immValue); 133 | virtual bool equal(ValueNode &rhs) override; 134 | static std::uint64_t calculate(std::uint64_t val1, Operator oper, std::uint64_t
val2); 135 | }; 136 | struct ValueBinary : public ValueNode 137 | { 138 | enum Operator { Div, Mod, Shl, Shr, Shlu, Shru, Seq, Slt, Sltu }; 139 | Operator oper; 140 | std::shared_ptr valueL, valueR; 141 | 142 | ValueBinary(Operator oper, size_t length, std::shared_ptr valueL, std::shared_ptr valueR); 143 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs, &ValueBinary::oper, &ValueBinary::valueL, &ValueBinary::valueR); } 144 | 145 | static std::uint64_t calculate(std::uint64_t val1, size_t length1, Operator oper, std::uint64_t val2, size_t length2, size_t lengthResult); 146 | }; 147 | struct ValueUnary : public ValueNode 148 | { 149 | enum Operator { Not, Neg, Sext, Zext, NotBool }; 150 | Operator oper; 151 | std::shared_ptr operand; 152 | 153 | ValueUnary(Operator oper, size_t length, std::shared_ptr operand); 154 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs, &ValueUnary::oper, &ValueUnary::operand); } 155 | 156 | static std::uint64_t calculate(std::uint64_t val, size_t length, Operator oper, size_t lengthResult); 157 | }; 158 | struct ValueFuncCall : public ValueNode 159 | { 160 | size_t funcID; 161 | std::vector> params; 162 | 163 | ValueFuncCall(size_t length, size_t funcID, const std::vector> &params); 164 | virtual bool equal(ValueNode &rhs) override; 165 | }; 166 | struct ValuePhiIf : public ValueNode 167 | { 168 | std::shared_ptr cond; 169 | std::shared_ptr valueTrue, valueFalse; 170 | ValuePhiIf(size_t length, std::shared_ptr cond, std::shared_ptr valueTrue, std::shared_ptr valueFalse); 171 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs, &ValuePhiIf::cond, &ValuePhiIf::valueTrue, &ValuePhiIf::valueFalse); } 172 | }; 173 | struct ValuePhi : public ValueNode 174 | { 175 | size_t blockID; 176 | std::vector> srcs; 177 | ValuePhi(size_t length, size_t blockID, const std::vector> &srcs); 178 | virtual bool equal(ValueNode &rhs) override { return equal_impl(*this, rhs,
&ValuePhi::blockID, &ValuePhi::srcs); } 179 | }; 180 | 181 | protected: 182 | bool isConstExpr(const Instruction &insn); 183 | std::shared_ptr getOperand(Operand operand); 184 | 185 | static std::shared_ptr reduceValue(std::shared_ptr value); 186 | static std::shared_ptr reduceValue(ValueNode *value) { return reduceValue(std::shared_ptr(value)); } 187 | 188 | void computeDomTree(); 189 | void computeDomTreeDepth(size_t idx, size_t curdepth); 190 | bool isPostDom(size_t parent, size_t child); 191 | std::shared_ptr numberInstruction(Instruction insn); 192 | std::shared_ptr numberPhiInstruction(Block::PhiIns phiins, Block *block); 193 | void computeVarGroup(); 194 | 195 | void renameVar(size_t idx, std::map &avaliableGroup); 196 | void renameVar(); 197 | 198 | protected: 199 | Function &func; 200 | MxProgram *program; 201 | 202 | DomTree dtree, postdtree; 203 | std::vector depth; 204 | 205 | std::set blacklist; 206 | std::map blockIndex; 207 | std::vector vBlocks; 208 | 209 | std::map> opNumber; 210 | 211 | std::vector> varGroups; 212 | std::map groupID; 213 | }; 214 | } 215 | 216 | 217 | #endif -------------------------------------------------------------------------------- /src/GlobalSymbol.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "GlobalSymbol.h" 3 | 4 | /* Process-wide default instance; null until some instance calls setDefault(). */ GlobalSymbol * GlobalSymbol::defGS = nullptr; -------------------------------------------------------------------------------- /src/GlobalSymbol.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_GLOBAL_SYMBOL_H 2 | #define MX_COMPILER_GLOBAL_SYMBOL_H 3 | 4 | #include "common.h" 5 | 6 | /* Interning tables: symbol names and string literals each map to a stable index. vStringRefCount tracks references per string so that sumStringSize counts only referenced strings; memoryUsage is only ever increased by these methods. */ struct GlobalSymbol 7 | { 8 | std::unordered_map mapSymbol, mapString; 9 | std::vector vSymbol, vString; 10 | std::vector vStringRefCount; 11 | size_t sumStringSize, memoryUsage; 12 | 13 | static GlobalSymbol *defGS; 14 | 15 | GlobalSymbol() : sumStringSize(0), memoryUsage(0) {} 16 | 17 |
void setDefault() { defGS = this; } 18 | static GlobalSymbol * getDefault() { return defGS; } 19 | 20 | size_t addSymbol(const std::string &name) 21 | { 22 | auto iter = mapSymbol.find(name); 23 | if (iter == mapSymbol.end()) 24 | { 25 | vSymbol.push_back(name); 26 | mapSymbol.insert({ name, vSymbol.size() - 1 }); 27 | return vSymbol.size() - 1; 28 | } 29 | return iter->second; 30 | } 31 | size_t addString(const std::string &name) 32 | { 33 | auto iter = mapString.find(name); 34 | if (iter == mapString.end()) 35 | { 36 | vString.push_back(name); 37 | vStringRefCount.push_back(1); 38 | sumStringSize += name.size(); 39 | memoryUsage += name.size(); 40 | mapString.insert({ name, vString.size() - 1 }); 41 | return vString.size() - 1; 42 | } 43 | if (incStringRef(iter->second) == 1) 44 | sumStringSize += name.size(); 45 | return iter->second; 46 | } 47 | size_t incStringRef(size_t strId) 48 | { 49 | return ++vStringRefCount.at(strId); 50 | } 51 | size_t decStringRef(size_t strId) 52 | { 53 | size_t ref = --vStringRefCount.at(strId); 54 | if (ref == 0) 55 | sumStringSize -= vString.at(strId).size(); 56 | return ref; 57 | } 58 | }; 59 | 60 | #endif 61 | -------------------------------------------------------------------------------- /src/IR.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "IR.h" 3 | #include "utils/CycleEquiv.h" 4 | #include "LoopDetector.h" 5 | 6 | namespace MxIR 7 | { 8 | Block::block_ptr::~block_ptr() 9 | { 10 | if (ptr) 11 | { 12 | //std::cerr << "Destructing " << curBlock << ":" << this << "->" << ptr.get() << std::endl; 13 | ptr->removePred(iterPred); 14 | } 15 | } 16 | void Block::block_ptr::reset() 17 | { 18 | if (ptr) 19 | ptr->removePred(iterPred); 20 | ptr.reset(); 21 | } 22 | Block::block_ptr & Block::block_ptr::operator=(const block_ptr &other) 23 | { 24 | (*this) = other.ptr; 25 | return *this; 26 | } 27 | Block::block_ptr & Block::block_ptr::operator=(const 
std::shared_ptr &block) 28 | { 29 | if (ptr) 30 | ptr->removePred(iterPred); 31 | ptr = block; 32 | if (ptr) 33 | iterPred = ptr->newPred(curBlock); 34 | return *this; 35 | } 36 | Block::block_ptr & Block::block_ptr::operator=(std::shared_ptr &&block) 37 | { 38 | if (ptr) 39 | ptr->removePred(iterPred); 40 | ptr = std::move(block); 41 | if (ptr) 42 | iterPred = ptr->newPred(curBlock); 43 | return *this; 44 | } 45 | 46 | std::shared_ptr Block::construct() 47 | { 48 | std::shared_ptr ptr(new Block); 49 | ptr->self = ptr; 50 | return ptr; 51 | } 52 | void Block::traverse(std::function func) 53 | { 54 | std::set visited; 55 | std::queue q; 56 | q.push(this); 57 | visited.insert(this); 58 | while (!q.empty()) 59 | { 60 | Block *cur = q.front(); 61 | q.pop(); 62 | if (!func(cur)) 63 | return; 64 | if (cur->brTrue && visited.find(cur->brTrue.get()) == visited.end()) 65 | { 66 | visited.insert(cur->brTrue.get()); 67 | q.push(cur->brTrue.get()); 68 | } 69 | if (cur->brFalse && visited.find(cur->brFalse.get()) == visited.end()) 70 | { 71 | visited.insert(cur->brFalse.get()); 72 | q.push(cur->brFalse.get()); 73 | } 74 | } 75 | } 76 | void Block::traverse_preorder(std::function func) 77 | { 78 | std::set visited; 79 | std::function dfs; 80 | dfs = [&func, &dfs, &visited](Block *block) -> bool 81 | { 82 | visited.insert(block); 83 | if (!func(block)) 84 | return false; 85 | if (block->brTrue && !visited.count(block->brTrue.get())) 86 | if (!dfs(block->brTrue.get())) 87 | return false; 88 | if (block->brFalse && !visited.count(block->brFalse.get())) 89 | if (!dfs(block->brFalse.get())) 90 | return false; 91 | return true; 92 | }; 93 | dfs(this); 94 | } 95 | void Block::traverse_postorder(std::function func) 96 | { 97 | std::set visited; 98 | std::function dfs; 99 | dfs = [&func, &dfs, &visited](Block *block) -> bool 100 | { 101 | visited.insert(block); 102 | if (block->brTrue && !visited.count(block->brTrue.get())) 103 | if (!dfs(block->brTrue.get())) 104 | return false; 105 | 
if (block->brFalse && !visited.count(block->brFalse.get())) 106 | if (!dfs(block->brFalse.get())) 107 | return false; 108 | if (!func(block)) 109 | return false; 110 | return true; 111 | }; 112 | dfs(this); 113 | } 114 | void Block::traverse_rev_postorder(std::function func) 115 | { 116 | std::vector postOrder; 117 | traverse_postorder([&postOrder](Block *block) { postOrder.push_back(block); return true; }); 118 | for (auto iter = postOrder.rbegin(); iter != postOrder.rend(); ++iter) 119 | if (!func(*iter)) 120 | return; 121 | } 122 | std::list::iterator Block::newPred(Block *pred) 123 | { 124 | //std::cerr << "link: " << pred << " -> " << this << std::endl; 125 | preds.push_back(pred); 126 | return std::prev(preds.end()); 127 | } 128 | void Block::removePred(std::list::iterator iterPred) 129 | { 130 | //std::cerr << "cut " << this << std::endl; 131 | //std::cerr << "cut: " << *iterPred << " -> " << this << std::endl; 132 | preds.erase(iterPred); 133 | } 134 | 135 | bool Block::checkPhiSrc() 136 | { 137 | if (phi.empty()) 138 | return true; 139 | 140 | std::set phisrc; 141 | for(auto &kv : phi) 142 | for (auto &src : kv.second.srcs) 143 | phisrc.insert(src.second.lock().get()); 144 | for (Block *pred : preds) 145 | { 146 | if (!phisrc.count(pred)) 147 | return false; 148 | phisrc.erase(pred); 149 | } 150 | return phisrc.empty(); 151 | } 152 | 153 | Function Function::clone() 154 | { 155 | std::map> mapNewBlock; // old block -> new block 156 | inBlock->traverse([&mapNewBlock](Block *block) -> bool 157 | { 158 | std::shared_ptr newBlock = Block::construct(); 159 | newBlock->phi = block->phi; 160 | newBlock->ins = block->ins; 161 | newBlock->sigma = block->sigma; 162 | mapNewBlock[block] = std::move(newBlock); 163 | return true; 164 | }); 165 | inBlock->traverse([&mapNewBlock](Block *block) -> bool 166 | { 167 | if (block->brTrue) 168 | mapNewBlock[block]->brTrue = mapNewBlock[block->brTrue.get()]; 169 | if (block->brFalse) 170 | mapNewBlock[block]->brFalse = 
mapNewBlock[block->brFalse.get()]; 171 | return true; 172 | }); 173 | Function ret; 174 | ret.params = params; 175 | ret.inBlock = mapNewBlock[inBlock.get()]; 176 | ret.outBlock = mapNewBlock[outBlock.get()]; 177 | return ret; 178 | } 179 | 180 | void PSTNode::traverse(std::function func) 181 | { 182 | std::function dfs; 183 | dfs = [&dfs, &func](PSTNode *node) 184 | { 185 | for (auto &child : node->children) 186 | dfs(child.get()); 187 | func(node); 188 | }; 189 | dfs(this); 190 | } 191 | 192 | std::set PSTNode::getBlocks() 193 | { 194 | std::set ret; 195 | traverse([&ret](PSTNode *node) 196 | { 197 | for (Block *block : node->blocks) 198 | ret.insert(block); 199 | }); 200 | return ret; 201 | } 202 | 203 | void Function::constructPST() 204 | { 205 | std::map blockID; 206 | std::vector vBlocks; 207 | std::vector> vEdges; 208 | inBlock->traverse_preorder([&](Block *block) -> bool 209 | { 210 | block->pstNode.reset(); 211 | blockID[block] = vBlocks.size(); 212 | vBlocks.push_back(block); 213 | if (block->brTrue) 214 | vEdges.push_back(std::make_pair(block, block->brTrue.get())); 215 | if (block->brFalse) 216 | vEdges.push_back(std::make_pair(block, block->brFalse.get())); //note that the ids of edges are in dfs order 217 | return true; 218 | }); 219 | 220 | vEdges.push_back(std::make_pair(outBlock.get(), inBlock.get())); 221 | CycleEquiv solver(vBlocks.size()); 222 | for (auto &e : vEdges) 223 | solver.addEdge(blockID[e.first], blockID[e.second]); 224 | 225 | std::vector> result = solver.work(); 226 | std::vector> nodes; 227 | std::map outBlockNext; 228 | for (auto &equClass : result) 229 | { 230 | std::vector tmp; 231 | if (*std::prev(equClass.end()) == vEdges.size() - 1) 232 | tmp.push_back(vEdges.size() - 1); 233 | for (size_t e : equClass) 234 | tmp.push_back(e); 235 | 236 | for (auto iter = tmp.begin(); iter != std::prev(tmp.end()); ++iter) 237 | { 238 | std::shared_ptr cur(new PSTNode); 239 | cur->inBlock = vEdges[*iter].second; 240 | cur->outBlock = 
vEdges[*std::next(iter)].first; 241 | outBlockNext[cur.get()] = vEdges[*std::next(iter)].second; 242 | cur->inBlock->pstNode = cur; 243 | cur->outBlock->pstNode = cur; 244 | cur->blocks.insert(cur->inBlock); 245 | cur->blocks.insert(cur->outBlock); 246 | cur->self = cur; 247 | nodes.emplace_back(std::move(cur)); 248 | } 249 | } 250 | 251 | std::shared_ptr root(new PSTNode); 252 | 253 | std::stack> stkPST; 254 | std::set visited; 255 | std::function dfs; 256 | stkPST.push(root); 257 | dfs = [&outBlockNext, &stkPST, &visited, &dfs](Block *block) 258 | { 259 | visited.insert(block); 260 | int flag = block->pstNode.expired() ? 0 : 261 | stkPST.top() == block->pstNode.lock() ? -1 : 1; 262 | 263 | if (flag == 1) 264 | { 265 | std::shared_ptr curNode = block->pstNode.lock(); 266 | stkPST.top()->children.push_back(curNode); 267 | curNode->iterParent = std::prev(stkPST.top()->children.end()); 268 | curNode->parent = stkPST.top(); 269 | stkPST.push(block->pstNode.lock()); 270 | } 271 | else if(flag == 0) 272 | { 273 | block->pstNode = stkPST.top(); 274 | stkPST.top()->blocks.insert(block); 275 | } 276 | 277 | std::shared_ptr oldTop; 278 | if (block->brTrue && !visited.count(block->brTrue.get())) 279 | { 280 | if (block->brTrue.get() == outBlockNext[stkPST.top().get()]) 281 | { 282 | oldTop = stkPST.top(); 283 | stkPST.pop(); 284 | } 285 | dfs(block->brTrue.get()); 286 | if (oldTop) 287 | stkPST.emplace(std::move(oldTop)); 288 | } 289 | if (block->brFalse && !visited.count(block->brFalse.get())) 290 | { 291 | if (block->brFalse.get() == outBlockNext[stkPST.top().get()]) 292 | { 293 | oldTop = stkPST.top(); 294 | stkPST.pop(); 295 | } 296 | dfs(block->brFalse.get()); 297 | if (oldTop) 298 | stkPST.emplace(std::move(oldTop)); 299 | } 300 | 301 | if (flag == 1) 302 | stkPST.pop(); 303 | }; 304 | 305 | dfs(inBlock.get()); 306 | pstRoot = root; 307 | } 308 | 309 | void Function::splitProgramRegion() 310 | { 311 | std::map maxVer; 312 | std::vector vBlocks; 313 | 
inBlock->traverse([&maxVer, &vBlocks, this](Block *block) -> bool 314 | { 315 | if(block != outBlock.get()) 316 | vBlocks.push_back(block); 317 | for (auto &ins : block->instructions()) 318 | for (Operand *operand : join(ins.getInputReg(), ins.getOutputReg())) 319 | maxVer[operand->val] = std::max(maxVer[operand->val], operand->ver); 320 | return true; 321 | }); 322 | 323 | LoopDetector detector(*this); 324 | detector.findLoops(); 325 | auto &loops = detector.getLoops(); 326 | 327 | for (Block *block : vBlocks) 328 | { 329 | if (loops.count(block)) 330 | { 331 | auto &loopBody = loops.find(block)->second; 332 | std::vector entryPred; 333 | for (Block *pred : block->preds) 334 | if (!loopBody.count(pred)) 335 | entryPred.push_back(pred); 336 | 337 | if (entryPred.size() > 1) 338 | { 339 | std::shared_ptr blockPreheader = Block::construct(); 340 | blockPreheader->ins = { IRJump() }; 341 | blockPreheader->brTrue = block->self.lock(); 342 | for (Block *pred : entryPred) 343 | { 344 | if (pred->brTrue.get() == block) 345 | pred->brTrue = blockPreheader; 346 | if (pred->brFalse.get() == block) 347 | pred->brFalse = blockPreheader; 348 | } 349 | 350 | for (auto iter = block->phi.begin(); iter != block->phi.end();) 351 | { 352 | Block::PhiIns upperPhi, remainPhi; 353 | bool bRedundantRemainPhi = true, bRedundantUpperPhi = true; 354 | for (auto &src : iter->second.srcs) 355 | { 356 | if (!loopBody.count(src.second.lock().get())) 357 | { 358 | upperPhi.srcs.push_back(src); 359 | if (src.first.isReg()) 360 | bRedundantUpperPhi = false; 361 | } 362 | else 363 | { 364 | remainPhi.srcs.push_back(src); 365 | assert(src.first.isReg()); 366 | if (src.first.val != iter->second.dst.val || src.first.ver != iter->second.dst.ver) 367 | bRedundantRemainPhi = false; 368 | } 369 | } 370 | 371 | if (bRedundantRemainPhi) 372 | { 373 | upperPhi.dst = iter->second.dst; 374 | blockPreheader->phi[upperPhi.dst.val] = upperPhi; 375 | iter = block->phi.erase(iter); 376 | } 377 | else if 
(bRedundantUpperPhi) 378 | { 379 | remainPhi.dst = iter->second.dst; 380 | remainPhi.srcs.push_back({ EmptyOperand(), blockPreheader }); 381 | iter->second = remainPhi; 382 | ++iter; 383 | } 384 | else 385 | { 386 | upperPhi.dst = iter->second.dst; 387 | upperPhi.dst.ver = ++maxVer[upperPhi.dst.val]; 388 | blockPreheader->phi[upperPhi.dst.val] = upperPhi; 389 | remainPhi.dst = iter->second.dst; 390 | remainPhi.srcs.push_back({ upperPhi.dst, blockPreheader }); 391 | iter->second = remainPhi; 392 | ++iter; 393 | } 394 | } 395 | } 396 | else if (entryPred.front()->brTrue && entryPred.front()->brFalse) 397 | { 398 | Block *pred = entryPred.front(); 399 | 400 | std::shared_ptr blockPreheader = Block::construct(); 401 | blockPreheader->ins = { IRJump() }; 402 | blockPreheader->brTrue = block->self.lock(); 403 | 404 | if (pred->brTrue.get() == block) 405 | pred->brTrue = blockPreheader->self.lock(); 406 | if (pred->brFalse.get() == block) 407 | pred->brFalse = blockPreheader->self.lock(); 408 | 409 | block->redirectPhiSrc(pred, blockPreheader.get()); 410 | } 411 | continue; 412 | } 413 | if (block->preds.size() > 1) 414 | { 415 | auto preds = block->preds; 416 | std::shared_ptr tmp = Block::construct(); 417 | tmp->phi = std::move(block->phi); 418 | tmp->brTrue = block->self.lock(); 419 | tmp->ins = { IRJump() }; 420 | for (Block *pred : preds) 421 | { 422 | if (pred->brTrue.get() == block) 423 | pred->brTrue = tmp; 424 | if (pred->brFalse.get() == block) 425 | pred->brFalse = tmp; 426 | } 427 | } 428 | if (block->brTrue && block->brFalse) 429 | { 430 | assert(block->phi.empty()); 431 | auto spliceEnd = std::prev(block->ins.end()); 432 | if (block->ins.size() >= 2 433 | && block->ins.back().oper == Br 434 | && std::prev(block->ins.end(), 2)->dst.val == block->ins.back().src1.val 435 | && std::prev(block->ins.end(), 2)->dst.ver == block->ins.back().src1.ver) 436 | { 437 | --spliceEnd; 438 | } 439 | if (spliceEnd != block->ins.begin()) 440 | { 441 | auto preds = 
block->preds; 442 | std::shared_ptr tmp = Block::construct(); 443 | tmp->ins.splice(tmp->ins.end(), block->ins, block->ins.begin(), spliceEnd); 444 | tmp->ins.push_back(IRJump()); 445 | tmp->brTrue = block->self.lock(); 446 | for (Block *pred : preds) 447 | { 448 | if (pred->brTrue.get() == block) 449 | pred->brTrue = tmp; 450 | if (pred->brFalse.get() == block) 451 | pred->brFalse = tmp; 452 | } 453 | if (block == inBlock.get()) 454 | inBlock = tmp; 455 | } 456 | } 457 | } 458 | } 459 | 460 | void Block::redirectPhiSrc(Block *from, Block *to) 461 | { 462 | for (auto &kv : phi) 463 | for (auto &src : kv.second.srcs) 464 | if (src.second.lock().get() == from) 465 | src.second = to->self; 466 | } 467 | 468 | void Function::mergeBlocks() 469 | { 470 | std::set blocks; 471 | inBlock->traverse([&blocks](Block *block) -> bool 472 | { 473 | blocks.insert(block); 474 | return true; 475 | }); 476 | 477 | for (auto iter = blocks.begin(); iter != blocks.end(); ) 478 | { 479 | std::shared_ptr block = (*iter)->self.lock(); 480 | if (block == outBlock) 481 | { 482 | ++iter; 483 | continue; 484 | } 485 | if (block->preds.size() == 1) 486 | { 487 | Block *pred = block->preds.front(); 488 | if (pred->brTrue.get() == block.get() && !pred->brFalse) 489 | { 490 | pred->ins.pop_back(); 491 | pred->ins.splice(pred->ins.end(), block->ins); 492 | 493 | if(block->brTrue) 494 | block->brTrue->redirectPhiSrc(block.get(), pred); 495 | if (block->brFalse) 496 | block->brFalse->redirectPhiSrc(block.get(), pred); 497 | 498 | pred->brTrue = block->brTrue; 499 | pred->brFalse = block->brFalse; 500 | block->brTrue.reset(); 501 | block->brFalse.reset(); 502 | iter = blocks.erase(iter); 503 | continue; 504 | } 505 | } 506 | if (block->ins.size() == 1 && block->ins.back().oper == Jump && block != inBlock) 507 | { 508 | std::map phi; 509 | for (auto &kv : block->phi) 510 | phi[kv.second.dst] = kv.second; 511 | 512 | Block *next = block->brTrue.get(); 513 | for (auto &kv : next->phi) //TODO: check if it 
is correct 514 | { 515 | decltype(kv.second.srcs) newSrc; 516 | for (auto iterSrc = kv.second.srcs.begin(); iterSrc != kv.second.srcs.end(); ++iterSrc) 517 | { 518 | if (iterSrc->second.lock() == block) 519 | { 520 | if (!phi.count(iterSrc->first)) 521 | { 522 | for (Block *pred : block->preds) 523 | newSrc.push_back(std::make_pair(EmptyOperand(), pred->self)); 524 | } 525 | else 526 | { 527 | for (auto &src : phi[iterSrc->first].srcs) 528 | newSrc.push_back(src); 529 | } 530 | } 531 | else 532 | newSrc.push_back(*iterSrc); 533 | } 534 | kv.second.srcs = newSrc; 535 | } 536 | 537 | for (Block *pred : std::list(block->preds)) 538 | { 539 | if (pred->brTrue.get() == block.get()) 540 | pred->brTrue = next->self.lock(); 541 | if (pred->brFalse.get() == block.get()) 542 | pred->brFalse = next->self.lock(); 543 | } 544 | block->brTrue.reset(); 545 | iter = blocks.erase(iter); 546 | continue; 547 | } 548 | ++iter; 549 | } 550 | } 551 | } -------------------------------------------------------------------------------- /src/IR.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_IR_H 2 | #define MX_COMPILER_IR_H 3 | 4 | #include "common.h" 5 | #include "utils/JoinIterator.h" 6 | #include "utils/ElementAdapter.h" 7 | 8 | namespace MxIR 9 | { 10 | enum Operation 11 | { 12 | Nop, 13 | Add, Sub, //dst = src1 op src2, dst & src1 & src2 must have the same size 14 | Mult, Div, Mod, 15 | Shlu, Shru, 16 | Shl, Shr, //FIXME: signed shift 17 | And, Or, Xor, 18 | Neg, Not, 19 | Load, Store, //dst = load [src1]; store dst [src1] (src1 can be reg / imm / funcID / stringID / globalVarID / externalSymbolName) 20 | LoadA, StoreA, //dst = load [src1 + src2 (signed)]; store dst [src1 + src2] 21 | Move, //dst = src, data moving from a small register to a large register is NOT allowed, use Sext/Zext instead 22 | Sext, Zext, //sign extend & zero extend 23 | Slt, Sle, Seq, Sge, Sgt, Sne, //dst = src1 op src2 ? 
true : false 24 | Sltu, Sleu, Sgeu, Sgtu, 25 | //Blt, Ble, Beq, Bge, Bgt, Bne, //branch if src1 op src2 26 | Br, //branch to brTrue if src1 == true (likely when src2 == 1, unlikely when src2 == 2) 27 | Jump, //jump to brTrue 28 | Call, //src1 should be func id / address (funcID / reg / imm) 29 | //CallExternal, //src1 should be func name (in symbol table) (DEPRECATED) 30 | Return, //return with value src1 31 | Allocate, //alloc a var in stack with size of src1 bytes and align to src2 bytes, alloc hint store in paramExt[0] 32 | 33 | TestZero, //dst = test_zero src1, src2: if src1 == 0, dst = src2; else dst = src1 34 | 35 | //Operations below are only avaliable when allocating register & generating code 36 | LockReg, 37 | UnlockReg, 38 | ParallelMove, //paramExt[0..n/2-1] : dst, paramExt[n/2, n] : src 39 | MoveToRegister, 40 | ExternalVar, //dst = external(hint), define a var without loading it to register; alloc hint store in src1; must be in inBlock 41 | PushParam, 42 | Placeholder, 43 | Xchg, //xchg(src1, src2) 44 | LoadAddr, //lea(src1 + src2) 45 | }; 46 | 47 | struct Operand 48 | { 49 | enum OperandType : std::uint8_t 50 | { 51 | empty, 52 | reg64, reg32, reg16, reg8, 53 | imm64, imm32, imm16, imm8, 54 | funcID, constID, globalVarID, externalSymbolName 55 | }; 56 | OperandType type; 57 | std::uint64_t val; 58 | std::size_t ver; //in ssa form, ver >= 1 59 | int pregid; //physical register id 60 | bool noreg; 61 | 62 | static const std::uint64_t InvalidID = std::uint64_t(-1); 63 | 64 | Operand() : type(empty), pregid(-1), ver(0), noreg(false) {} 65 | Operand(OperandType type, std::uint64_t val) : type(type), val(val), pregid(-1), ver(0), noreg(false) {} 66 | 67 | Operand clone() const { return Operand(*this); } 68 | Operand & setVer(size_t newVer) { ver = newVer; return *this; } 69 | Operand & setVal(std::uint64_t newVal) { val = newVal; return *this; } 70 | Operand & setPRegID(int newRegID) { pregid = newRegID; return *this; } 71 | Operand & setNOReg(bool 
newFlag) { noreg = newFlag; return *this; } 72 | Operand & setSize(size_t size) 73 | { 74 | assert(isImm() || isReg()); 75 | switch (size) 76 | { 77 | case 8: 78 | type = isImm() ? imm64 : reg64; 79 | break; 80 | case 4: 81 | type = isImm() ? imm32 : reg32; 82 | break; 83 | case 2: 84 | type = isImm() ? imm16 : reg16; 85 | break; 86 | case 1: 87 | type = isImm() ? imm8 : reg8; 88 | break; 89 | default: 90 | assert(false); 91 | } 92 | return *this; 93 | } 94 | 95 | bool isImm() const { return type == imm64 || type == imm32 || type == imm16 || type == imm8; } 96 | bool isReg() const { return type == reg64 || type == reg32 || type == reg16 || type == reg8; } 97 | bool isConst() const { return type != empty && !isReg(); } 98 | size_t size() const 99 | { 100 | switch (type) 101 | { 102 | case reg64: case imm64: 103 | return 8; 104 | case reg32: case imm32: 105 | return 4; 106 | case reg16: case imm16: 107 | return 2; 108 | case reg8: case imm8: 109 | return 1; 110 | default: 111 | return POINTER_SIZE; 112 | } 113 | } 114 | 115 | //used in set/map of Operand 116 | bool operator<(const Operand &rhs) const 117 | { 118 | //assert(isReg() && rhs.isReg()); 119 | if (type == rhs.type || isReg() && rhs.isReg()) 120 | { 121 | if (val == rhs.val) 122 | return ver < rhs.ver; 123 | return val < rhs.val; 124 | } 125 | return type < rhs.type; 126 | } 127 | }; 128 | //NOTE: when writing one part of the 64-bit register, the value in the other part of the register is UNDEFINED 129 | inline Operand Reg64(std::uint64_t regid) { return Operand{ Operand::reg64, regid }; } 130 | inline Operand Reg32(std::uint64_t regid) { return Operand{ Operand::reg32, regid }; } 131 | inline Operand Reg16(std::uint64_t regid) { return Operand{ Operand::reg16, regid }; } 132 | inline Operand Reg8(std::uint64_t regid) { return Operand{ Operand::reg8, regid }; } 133 | inline Operand RegPtr(std::uint64_t regid) { return Reg64(regid); } 134 | inline Operand RegSize(std::uint64_t regid, size_t size) 135 | { 136 
| switch (size) 137 | { 138 | case 1: 139 | return Reg8(regid); 140 | case 2: 141 | return Reg16(regid); 142 | case 4: 143 | return Reg32(regid); 144 | case 8: 145 | return Reg64(regid); 146 | default: 147 | assert(false); 148 | return Operand(); 149 | } 150 | } 151 | 152 | inline Operand Imm64(std::int64_t num) { return Operand{ Operand::imm64, std::uint64_t(num) }; } 153 | inline Operand Imm32(std::int32_t num) { return Operand{ Operand::imm32, std::uint64_t(num) }; } 154 | inline Operand Imm16(std::int16_t num) { return Operand{ Operand::imm16, std::uint64_t(num) }; } 155 | inline Operand Imm8(std::int8_t num) { return Operand{ Operand::imm8, std::uint64_t(num) }; } 156 | inline Operand ImmPtr(std::int64_t num) { return Imm64(num); } 157 | inline Operand ImmSize(std::uint64_t val, size_t size) 158 | { 159 | Operand op; 160 | op.val = val; 161 | if (size == 1) 162 | op.type = Operand::imm8; 163 | else if (size == 2) 164 | op.type = Operand::imm16; 165 | else if (size == 4) 166 | op.type = Operand::imm32; 167 | else if (size == 8) 168 | op.type = Operand::imm64; 169 | else 170 | { 171 | assert(false); 172 | } 173 | return op; 174 | } 175 | 176 | inline Operand IDFunc(size_t funcID) { return Operand{ Operand::funcID, funcID }; } 177 | inline Operand IDConst(size_t constID) { return Operand{ Operand::constID, constID }; } 178 | inline Operand IDGlobalVar(size_t varID) { return Operand{ Operand::globalVarID, varID }; } 179 | inline Operand IDExtSymbol(size_t symbolID) { return Operand{ Operand::externalSymbolName, symbolID }; } 180 | inline Operand EmptyOperand() { return Operand(); } 181 | 182 | struct InstructionBase 183 | { 184 | enum registerHint { NoPrefer, PreferAnyOfOperands, PreferOperands, PreferCorrespondingOperand }; 185 | registerHint hint; 186 | 187 | InstructionBase() : hint(NoPrefer) {} 188 | virtual std::vector getInputReg() = 0; 189 | virtual std::vector getOutputReg() = 0; 190 | }; 191 | 192 | struct Instruction : public InstructionBase 193 | { 194 
| Operation oper; 195 | Operand dst, src1, src2; 196 | std::vector paramExt; 197 | 198 | Instruction() { autoPrefer(); } 199 | Instruction(Operation oper) : oper(oper) { autoPrefer(); } 200 | Instruction(Operation oper, Operand dst, Operand src1) : oper(oper), dst(dst), src1(src1) { autoPrefer(); } 201 | Instruction(Operation oper, Operand dst, Operand src1, Operand src2) : oper(oper), dst(dst), src1(src1), src2(src2) { autoPrefer(); } 202 | Instruction(Operation oper, Operand dst, Operand src1, Operand src2, const std::vector ¶mExt) : oper(oper), dst(dst), src1(src1), src2(src2), paramExt(paramExt) { autoPrefer(); } 203 | 204 | void autoPrefer() 205 | { 206 | if (oper == Move) 207 | hint = PreferAnyOfOperands; 208 | else if (oper == ParallelMove) 209 | hint = PreferCorrespondingOperand; 210 | } 211 | 212 | virtual std::vector getInputReg() override final 213 | { 214 | std::vector ret; 215 | if (oper == ParallelMove) 216 | { 217 | assert(paramExt.size() % 2 == 0); 218 | for (size_t i = paramExt.size() / 2; i < paramExt.size(); i++) 219 | { 220 | assert(paramExt[i].isReg()); 221 | ret.push_back(¶mExt[i]); 222 | } 223 | return ret; 224 | } 225 | if ((oper == Store || oper == StoreA) && dst.isReg()) 226 | ret.push_back(&dst); 227 | if (src1.isReg()) 228 | ret.push_back(&src1); 229 | if (src2.isReg()) 230 | ret.push_back(&src2); 231 | for (auto ¶m : paramExt) 232 | if (param.isReg()) 233 | ret.push_back(¶m); 234 | return ret; 235 | } 236 | virtual std::vector getOutputReg() override final 237 | { 238 | std::vector ret; 239 | if (oper == ParallelMove) 240 | { 241 | assert(paramExt.size() % 2 == 0); 242 | for (size_t i = 0; i < paramExt.size() / 2; i++) 243 | { 244 | assert(paramExt[i].isReg()); 245 | ret.push_back(¶mExt[i]); 246 | } 247 | return ret; 248 | } 249 | if (oper != Store && oper != StoreA && dst.isReg()) 250 | ret.push_back(&dst); 251 | return ret; 252 | } 253 | }; 254 | 255 | enum BranchFreq : std::uint64_t 256 | { 257 | normal = 0, likely = 1, unlikely = 2 
258 | }; 259 | 260 | inline Instruction IRNop() { return Instruction{ Nop }; } 261 | inline Instruction IR(Operand dst, Operation oper, Operand src) { return Instruction{ oper, dst, src }; } 262 | inline Instruction IR(Operand dst, Operation oper, Operand src1, Operand src2) { return Instruction{ oper, dst, src1, src2 }; } 263 | //[[deprecated]] inline Instruction IR(Operand dst, Operation oper, Operand src, std::vector paramExt) { return Instruction{ oper, dst, src, Operand(), paramExt }; } 264 | inline Instruction IRCall(Operand dst, Operand func, std::vector paramExt) { return Instruction{ Call, dst, func, Operand(), paramExt }; } 265 | inline Instruction IRJump() { return Instruction{ Jump }; } 266 | inline Instruction IRReturn() { return Instruction{ Return , EmptyOperand(), EmptyOperand()}; } 267 | inline Instruction IRReturn(Operand ret) { return Instruction{ Return, EmptyOperand(), ret }; } 268 | inline Instruction IRBranch(Operand src, BranchFreq freq = normal) { return Instruction{ Br, Operand{}, src, Imm64(freq)}; } 269 | inline Instruction IRStore(Operand data, Operand addr) { return Instruction{ Store, data, addr }; } 270 | inline Instruction IRStoreA(Operand data, Operand base, Operand offset) { return Instruction{ StoreA, data, base, offset }; } 271 | 272 | inline Instruction IRParallelMove(const std::vector &dst, const std::vector &src) 273 | { 274 | assert(dst.size() == src.size()); 275 | Instruction insn(ParallelMove); 276 | for (Operand operand : dst) 277 | insn.paramExt.push_back(operand); 278 | for (Operand operand : src) 279 | insn.paramExt.push_back(operand); 280 | return insn; 281 | } 282 | inline Instruction IRMoveToRegister(const std::vector ¶m) { return Instruction{ MoveToRegister, EmptyOperand(), EmptyOperand(), EmptyOperand(), param }; } 283 | inline Instruction IRLockRegister(const std::vector ®s) 284 | { 285 | Instruction insn(LockReg); 286 | for (int reg : regs) 287 | 
insn.paramExt.push_back(RegPtr(Operand::InvalidID).setPRegID(reg)); 288 | return insn; 289 | } 290 | inline Instruction IRUnlockRegister() { return Instruction(UnlockReg); } 291 | 292 | class Block; 293 | struct PSTNode 294 | { 295 | Block *inBlock, *outBlock; 296 | std::set blocks; 297 | 298 | std::list> children; 299 | std::list>::iterator iterParent; 300 | std::weak_ptr parent, self; 301 | 302 | void traverse(std::function func); 303 | std::set getBlocks(); 304 | 305 | bool isSibling(PSTNode *other) const { return parent.lock() == other->parent.lock(); } 306 | bool isSingleBlock() const 307 | { 308 | return inBlock == outBlock && blocks.size() == 1 && children.empty(); 309 | } 310 | 311 | }; 312 | //FIXME: Possible Memory Leak 313 | class Block 314 | { 315 | public: 316 | struct PhiIns : public InstructionBase 317 | { 318 | Operand dst; 319 | std::vector>> srcs; 320 | 321 | PhiIns() { hint = PreferOperands; } 322 | PhiIns(Operand dst) : dst(dst) { hint = PreferOperands; } 323 | 324 | virtual std::vector getInputReg() override final 325 | { 326 | std::vector ret; 327 | for (auto &src : srcs) 328 | { 329 | if(src.first.isReg()) 330 | ret.push_back(&src.first); 331 | } 332 | return ret; 333 | } 334 | virtual std::vector getOutputReg() override final 335 | { 336 | assert(dst.isReg()); 337 | return { &dst }; 338 | } 339 | }; 340 | struct SigmaIns : public InstructionBase 341 | { 342 | Operand dstTrue, dstFalse; 343 | Operand src; 344 | 345 | SigmaIns() {} 346 | SigmaIns(Operand src) : src(src) {} 347 | 348 | virtual std::vector getInputReg() override final 349 | { 350 | assert(src.isReg()); 351 | return { &src }; 352 | } 353 | virtual std::vector getOutputReg() override final 354 | { 355 | assert(dstTrue.isReg()); 356 | assert(dstFalse.isReg()); 357 | return { &dstTrue, &dstFalse }; 358 | } 359 | }; 360 | class block_ptr 361 | { 362 | public: 363 | explicit block_ptr(Block *curBlock) : curBlock(curBlock) 364 | { 365 | assert(this == &curBlock->brTrue || this == 
&curBlock->brFalse); 366 | } 367 | block_ptr(const block_ptr &other) = delete; 368 | block_ptr(block_ptr &&other) = delete; 369 | 370 | ~block_ptr(); 371 | void reset(); 372 | block_ptr & operator=(const block_ptr &other); 373 | block_ptr & operator=(const std::shared_ptr &block); 374 | block_ptr & operator=(std::shared_ptr &&block); 375 | bool operator==(const block_ptr &rhs) const { return ptr == rhs.ptr; } 376 | bool operator!=(const block_ptr &rhs) const { return ptr != rhs.ptr; } 377 | Block * operator->() const { return get(); } 378 | Block * get() const { return ptr.get(); } 379 | operator bool() const { return bool(ptr); } 380 | 381 | protected: 382 | std::shared_ptr ptr; 383 | std::list::iterator iterPred; 384 | Block * const curBlock; 385 | }; 386 | 387 | public: 388 | std::list ins; 389 | std::map phi; //map register id to phi instruction 390 | std::map sigma; 391 | 392 | block_ptr brTrue, brFalse; 393 | std::weak_ptr self; 394 | std::list preds; 395 | 396 | std::weak_ptr pstNode; 397 | 398 | IF_DEBUG(std::string dbgInfo); 399 | 400 | public: 401 | static std::shared_ptr construct(); 402 | void traverse(std::function func); 403 | void traverse_preorder(std::function func); 404 | void traverse_postorder(std::function func); 405 | void traverse_rev_postorder(std::function func); 406 | auto instructions() 407 | { 408 | auto pair_second = [](auto &p) -> auto& { return p.second; }; 409 | return join(element_adapter(phi, pair_second), ins, element_adapter(sigma, pair_second)); 410 | } 411 | void redirectPhiSrc(Block *from, Block *to); 412 | bool checkPhiSrc(); 413 | 414 | protected: 415 | std::list::iterator newPred(Block *pred); 416 | void removePred(std::list::iterator iterPred); 417 | 418 | private: 419 | Block() : brTrue(this), brFalse(this) {} 420 | }; 421 | 422 | class Function 423 | { 424 | public: 425 | std::vector params; 426 | std::shared_ptr inBlock, outBlock; 427 | 428 | std::shared_ptr pstRoot; 429 | 430 | void constructPST(); 431 | void 
splitProgramRegion(); 432 | void mergeBlocks(); 433 | Function clone(); 434 | }; 435 | } 436 | 437 | #endif -------------------------------------------------------------------------------- /src/IRGenerator.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_IR_GENERATOR_H 2 | #define MX_COMPILER_IR_GENERATOR_H 3 | 4 | #include "AST.h" 5 | #include "IR.h" 6 | #include "MxProgram.h" 7 | #include "IssueCollector.h" 8 | #include "MxBuiltin.h" 9 | 10 | class IRGenerator : protected MxAST::ASTVisitor 11 | { 12 | public: 13 | IRGenerator() : program(MxProgram::getDefault()), symbol(GlobalSymbol::getDefault()), issues(IssueCollector::getDefault()) {} 14 | IRGenerator(MxProgram *program, GlobalSymbol *symbol, IssueCollector *issues) : program(program), symbol(symbol), issues(issues) {} 15 | 16 | MxIR::Function generate(MxAST::ASTDeclFunc *declFunc); 17 | 18 | //generate the ir code of the full program and store it to MxProgram 19 | //note that it will 20 | // 1. fillin the size and offset field of MxProgram::vClass 21 | // 2. add all string in GlobalSymbol to const list 22 | // 3. fillin the __initialize function 23 | // 4. add a function call of __initialize to the very beginning of main function 24 | // 5. 
add 'Export' attribute to main function 25 | void generateProgram(MxAST::ASTRoot *root); 26 | 27 | protected: 28 | static void redirectReturn(std::shared_ptr inBlock, std::shared_ptr outBlock); //link all blocks that are ended with return to outBlock 29 | static MxIR::Operand RegByType(size_t regid, MxType type); 30 | static MxIR::Operand RegByType(size_t regid, MxIR::Operand other); 31 | static MxIR::Operand ImmByType(std::int64_t imm, MxType type); 32 | static MxIR::Operand ImmByType(std::int64_t imm, MxIR::Operand other); 33 | static void merge(std::shared_ptr ¤tBlock, std::shared_ptr &blkIn, std::shared_ptr &blkOut); 34 | void merge(std::shared_ptr ¤tBlock); //merge last block / lastIns to current block 35 | static MxIR::Instruction releaseXValue(MxIR::Operand addr, MxType type); 36 | 37 | virtual void visit(MxAST::ASTDeclVar *declVar) override; 38 | virtual void visit(MxAST::ASTExprImm *imm) override; 39 | virtual void visit(MxAST::ASTExprVar *var) override; 40 | virtual void visit(MxAST::ASTExprUnary *unary) override; 41 | virtual void visit(MxAST::ASTExprBinary *binary) override; 42 | virtual void visit(MxAST::ASTExprAssignment *assign) override; 43 | virtual void visit(MxAST::ASTExprNew *exprNew) override; 44 | virtual void visit(MxAST::ASTExprSubscriptAccess *exprSub) override; 45 | virtual void visit(MxAST::ASTExprMemberAccess *expr) override; 46 | virtual void visit(MxAST::ASTExprFuncCall *expr) override; 47 | virtual void visit(MxAST::ASTStatementReturn *stat) override; 48 | virtual void visit(MxAST::ASTStatementBreak *stat) override; 49 | virtual void visit(MxAST::ASTStatementContinue *stat) override; 50 | virtual void visit(MxAST::ASTStatementIf *stat) override; 51 | virtual void visit(MxAST::ASTStatementWhile *stat) override; 52 | virtual void visit(MxAST::ASTStatementFor *stat) override; 53 | virtual void visit(MxAST::ASTStatementExpr *stat) override; 54 | virtual void visit(MxAST::ASTBlock *block) override; 55 | 56 | void generateFuncCall(size_t 
funcID, const std::vector ¶m); 57 | void visitExprRec(MxAST::ASTNode *node); //must be called by visit(ASTExpr* *) 58 | void visitExpr(MxAST::ASTNode *node); 59 | void clearXValueStack(); 60 | 61 | void releaseLocalVar(); 62 | 63 | size_t findMain(); 64 | 65 | protected: 66 | MxProgram *program; 67 | GlobalSymbol *symbol; 68 | IssueCollector *issues; 69 | std::list lastIns; 70 | std::shared_ptr lastBlockIn, lastBlockOut; 71 | std::shared_ptr loopContinue, loopBreak; 72 | std::shared_ptr returnBlock; 73 | std::stack> stkXValues; 74 | 75 | MxIR::Operand lastOperand, lastWriteAddr; //when Write flag is set and lastWriteAddr.type == empty, we can write to lastOperand directly 76 | size_t regNum, regThis; 77 | 78 | std::map mapStringConstID; 79 | 80 | size_t funcID; 81 | std::vector declaredVar; 82 | 83 | enum vflag : std::uint32_t 84 | { 85 | Read = 1, Write = 2 86 | }; 87 | std::uint32_t visFlag; 88 | std::stack stkFlags; 89 | void setFlag(std::uint32_t newFlag) { stkFlags.push(visFlag); visFlag = newFlag; } 90 | void resumeFlag() { visFlag = stkFlags.top(); stkFlags.pop(); } 91 | }; 92 | 93 | #endif -------------------------------------------------------------------------------- /src/IRVisualizer.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sxtyzhangzk/MxCompiler/38225cf5ce6a7efebfaea6f8834e76496880333c/src/IRVisualizer.cpp -------------------------------------------------------------------------------- /src/IRVisualizer.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_IR_VISUALIZER_H 2 | #define MX_COMPILER_IR_VISUALIZER_H 3 | 4 | #include "IR.h" 5 | #include "GlobalSymbol.h" 6 | #include "MxProgram.h" 7 | 8 | class IRVisualizer 9 | { 10 | public: 11 | IRVisualizer(std::ostream &out) : out(out), symbol(GlobalSymbol::getDefault()), program(MxProgram::getDefault()), cntBlock(0), cntCluster(0) {} 12 | std::string toString(const 
MxIR::Operand &operand, bool isHTML); 13 | std::string toString(const MxIR::Instruction &ins, bool isHTML); 14 | std::string toString(const MxIR::Block &block, bool isHTML); 15 | std::string toHTML(const MxIR::Block &block, int flag, const std::string &funcName); //flag: 1 for in block, 2 for out block 16 | void print(const MxIR::Function &func, const std::string &funcName, bool noPST = false); 17 | void printHead() { out << "digraph mxprog {" << std::endl; } 18 | void printFoot() { out << "}" << std::endl; } 19 | void reset() { cntBlock = 0; } 20 | 21 | void printAll(); 22 | 23 | protected: 24 | std::ostream &out; 25 | GlobalSymbol *symbol; 26 | MxProgram *program; 27 | bool enableColor; 28 | size_t cntBlock, cntCluster; 29 | }; 30 | 31 | #endif -------------------------------------------------------------------------------- /src/InlineOptimizer.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "InlineOptimizer.h" 3 | #include "utils/JoinIterator.h" 4 | 5 | namespace MxIR 6 | { 7 | size_t InlineOptimizer::statFunc::penalty() const 8 | { 9 | if (forceInline) 10 | return 0; 11 | bool hasFuncCall = externalCall ? true : !callTo.empty(); 12 | return nInsn + nBlock * nBlock + (hasFuncCall ? 
100 : 0); 13 | } 14 | 15 | void InlineOptimizer::work() 16 | { 17 | const size_t maxPenalty = CompileFlags::getInstance()->inline_param; 18 | const int threshold = CompileFlags::getInstance()->inline_param2; 19 | struct Tqueue 20 | { 21 | size_t idx, nUpdate, penalty; 22 | bool operator<(const Tqueue &rhs) const { return penalty > rhs.penalty; } 23 | }; 24 | std::priority_queue Q; 25 | analyzeProgram(); 26 | for (size_t i = 0; i < stats.size(); i++) 27 | { 28 | if (program->vFuncs[i].attribute & NoInline) 29 | continue; 30 | Q.push(Tqueue{ i, stats[i].nUpdate, stats[i].penalty() }); 31 | } 32 | while (!Q.empty()) 33 | { 34 | Tqueue cur = Q.top(); 35 | Q.pop(); 36 | if (cur.nUpdate != stats[cur.idx].nUpdate) 37 | continue; 38 | if (cur.penalty > maxPenalty) 39 | break; 40 | Function content = program->vFuncs[cur.idx].content.clone(); 41 | if (stats[cur.idx].callTo.count(cur.idx)) 42 | { 43 | size_t incBlock = (stats[cur.idx].nBlock - 1) * stats[cur.idx].callTo[cur.idx]; 44 | if (incBlock > threshold + stats[cur.idx].nBlock * sqrt(threshold)) 45 | continue; 46 | 47 | applyInline(cur.idx, cur.idx, content); 48 | stats[cur.idx].nUpdate++; 49 | Q.push(Tqueue{ cur.idx, stats[cur.idx].nUpdate, stats[cur.idx].penalty() }); 50 | } 51 | else 52 | { 53 | bool flag = true; 54 | for (size_t i = 0; i < stats.size(); i++) 55 | if (stats[i].callTo.count(cur.idx)) 56 | { 57 | size_t incBlock = (stats[cur.idx].nBlock - 1) * stats[i].callTo[cur.idx]; 58 | if (incBlock > threshold + stats[i].nBlock * sqrt(threshold)) 59 | { 60 | flag = false; 61 | continue; 62 | } 63 | 64 | applyInline(cur.idx, i, content); 65 | assert(!stats[i].callTo.count(cur.idx)); 66 | stats[i].nUpdate++; 67 | 68 | if(!(program->vFuncs[i].attribute & NoInline)) 69 | Q.push(Tqueue{ i, stats[i].nUpdate, stats[i].penalty() }); 70 | } 71 | if(flag && !(program->vFuncs[cur.idx].attribute & Export)) 72 | program->vFuncs[cur.idx].disabled = true; 73 | } 74 | } 75 | } 76 | 77 | void InlineOptimizer::analyzeProgram() 78 | 
{ 79 | stats.resize(program->vFuncs.size()); 80 | for (size_t i = 0; i < program->vFuncs.size(); i++) 81 | { 82 | stats[i] = analyzeFunc(program->vFuncs[i].content); 83 | if (program->vFuncs[i].attribute & ForceInline) 84 | stats[i].forceInline = true; 85 | else 86 | stats[i].forceInline = false; 87 | } 88 | } 89 | 90 | InlineOptimizer::statFunc InlineOptimizer::analyzeFunc(Function &func) 91 | { 92 | statFunc stat; 93 | stat.nVar = 0; 94 | for (auto ¶m : func.params) 95 | { 96 | assert(param.isReg()); 97 | stat.nVar = std::max(stat.nVar, param.val + 1); 98 | } 99 | func.inBlock->traverse([&stat](Block *block) -> bool 100 | { 101 | stat.nBlock++; 102 | stat.nInsn += block->ins.size(); 103 | for (auto &ins : block->ins) 104 | { 105 | if (ins.oper == Call) 106 | { 107 | if (ins.src1.type == Operand::funcID) 108 | stat.callTo[ins.src1.val]++; 109 | else if (ins.src1.type == Operand::externalSymbolName) 110 | stat.externalCall = true; 111 | } 112 | for (Operand *operand : join(ins.getInputReg(), ins.getOutputReg())) 113 | stat.nVar = std::max(stat.nVar, operand->val + 1); 114 | } 115 | return true; 116 | }); 117 | return stat; 118 | } 119 | 120 | void InlineOptimizer::applyInline(size_t callee, size_t caller, Function &content) 121 | { 122 | std::vector vBlocks; 123 | program->vFuncs[caller].content.inBlock->traverse([&vBlocks](Block *block) -> bool 124 | { 125 | vBlocks.push_back(block); 126 | return true; 127 | }); 128 | auto callTo = stats[callee].callTo; 129 | 130 | for(Block *block : vBlocks) 131 | { 132 | //auto endIter = block->ins.end(); 133 | for (auto iter = block->ins.begin(); iter != block->ins.end(); ) 134 | { 135 | if (iter->oper == Call && iter->src1.type == Operand::funcID && iter->src1.val == callee) 136 | { 137 | Operand retVar = iter->dst; 138 | Function child = content.clone(); 139 | assert(child.outBlock->ins.empty()); 140 | size_t offsetVarID = stats[caller].nVar; 141 | 142 | stats[caller].nVar += stats[callee].nVar; 143 | stats[caller].nInsn += 
stats[callee].nInsn + child.params.size(); 144 | if (stats[callee].nBlock > 2) 145 | stats[caller].nBlock += stats[callee].nBlock - 1; 146 | for (auto &kv : callTo) 147 | stats[caller].callTo[kv.first] += kv.second; 148 | stats[caller].callTo[callee]--; 149 | 150 | 151 | child.inBlock->traverse([offsetVarID, &retVar, &child](Block *block) -> bool 152 | { 153 | for (auto iter = block->ins.begin(); iter != block->ins.end(); ++iter) 154 | { 155 | for (Operand *operand : join(iter->getInputReg(), iter->getOutputReg())) 156 | operand->val += offsetVarID; 157 | if (iter->oper == Return) 158 | { 159 | assert(std::next(iter) == block->ins.end() && block->brTrue.get() == child.outBlock.get()); 160 | if (retVar.type == Operand::empty) 161 | *iter = IRJump(); 162 | else if (iter->src1.type == Operand::empty) 163 | { 164 | block->ins.insert(iter, IR(retVar, Move, ImmSize(0, retVar.size()))); 165 | *iter = IRJump(); 166 | } 167 | else 168 | { 169 | block->ins.insert(iter, IR(retVar, Move, iter->src1)); 170 | *iter = IRJump(); 171 | } 172 | } 173 | } 174 | return true; 175 | }); 176 | assert(child.params.size() == iter->paramExt.size()); 177 | for (auto iterJ = child.inBlock->ins.cbegin(); iterJ != child.inBlock->ins.cend();) 178 | { 179 | if (iterJ->oper == Allocate) 180 | { 181 | program->vFuncs[caller].content.inBlock->ins.push_front(*iterJ); 182 | iterJ = child.inBlock->ins.erase(iterJ); 183 | } 184 | else 185 | ++iterJ; 186 | } 187 | for (auto ¶m : child.params) 188 | param.val += offsetVarID; 189 | for (size_t i = 0; i < child.params.size(); i++) 190 | child.inBlock->ins.push_front(IR(child.params[i], Move, iter->paramExt[i])); 191 | 192 | if (child.inBlock->brTrue.get() == child.outBlock.get() && !child.inBlock->brFalse) 193 | { 194 | assert(child.inBlock->ins.back().oper == Jump); 195 | child.inBlock->ins.pop_back(); 196 | block->ins.splice(iter, child.inBlock->ins); 197 | iter = block->ins.erase(iter); 198 | } 199 | else 200 | { 201 | iter = block->ins.erase(iter); 202 
| child.outBlock->ins.splice(child.outBlock->ins.end(), block->ins, iter, block->ins.end()); 203 | block->ins.splice(block->ins.end(), child.inBlock->ins); 204 | child.outBlock->brTrue = block->brTrue; 205 | child.outBlock->brFalse = block->brFalse; 206 | block->brTrue = child.inBlock->brTrue; 207 | block->brFalse = child.inBlock->brFalse; 208 | 209 | block = child.outBlock.get(); 210 | iter = block->ins.begin(); 211 | } 212 | } 213 | else 214 | ++iter; 215 | } 216 | } 217 | 218 | //assert(caller == callee || stats[caller].callTo[callee] == 0); 219 | if(stats[caller].callTo[callee] == 0) 220 | stats[caller].callTo.erase(callee); 221 | } 222 | 223 | } -------------------------------------------------------------------------------- /src/InlineOptimizer.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_INLINE_OPTIMIZER_H 2 | #define MX_COMPILER_INLINE_OPTIMIZER_H 3 | 4 | #include "common.h" 5 | #include "MxProgram.h" 6 | 7 | namespace MxIR 8 | { 9 | class InlineOptimizer 10 | { 11 | public: 12 | InlineOptimizer() : program(MxProgram::getDefault()) {} 13 | InlineOptimizer(MxProgram *program) : program(program) {} 14 | void work(); 15 | 16 | protected: 17 | struct statFunc 18 | { 19 | size_t nInsn, nBlock; 20 | size_t nVar; 21 | bool externalCall, forceInline; 22 | std::map callTo; //func ID -> call times 23 | 24 | size_t nUpdate; 25 | 26 | statFunc() : nInsn(0), nBlock(0), nVar(0), externalCall(false), nUpdate(0) {} 27 | size_t penalty() const; 28 | }; 29 | 30 | protected: 31 | void analyzeProgram(); 32 | statFunc analyzeFunc(Function &func); 33 | void applyInline(size_t callee, size_t caller, Function &content); 34 | 35 | protected: 36 | std::vector stats; 37 | MxProgram *program; 38 | }; 39 | } 40 | 41 | #endif -------------------------------------------------------------------------------- /src/InstructionSelect.cpp: -------------------------------------------------------------------------------- 1 | #include 
"common_headers.h" 2 | #include "InstructionSelect.h" 3 | 4 | namespace MxIR 5 | { 6 | void InstructionSelect::computeUseCount() 7 | { 8 | func.inBlock->traverse([this](Block *block) -> bool 9 | { 10 | for (auto &ins : block->instructions()) 11 | for (Operand *operand : ins.getInputReg()) 12 | useCount[*operand]++; 13 | return true; 14 | }); 15 | } 16 | 17 | void InstructionSelect::selectInsn(Block *block) 18 | { 19 | static const std::set alterInsn = { 20 | Slt, Sle, Seq, Sgt, Sge, Sne, 21 | Sltu, Sleu, Sgtu, Sgeu 22 | }; 23 | if (block->ins.back().oper == Br && block->ins.size() >= 2) 24 | { 25 | auto iter = std::prev(block->ins.end(), 2); 26 | if (alterInsn.count(iter->oper)) 27 | { 28 | if (iter->dst.val == block->ins.back().src1.val && iter->dst.ver == block->ins.back().src1.ver 29 | && useCount[iter->dst] == 1) 30 | { 31 | iter->dst = EmptyOperand(); 32 | block->ins.back().src1 = EmptyOperand(); 33 | } 34 | } 35 | } 36 | } 37 | 38 | void InstructionSelect::work() 39 | { 40 | computeUseCount(); 41 | func.inBlock->traverse([this](Block *block) -> bool 42 | { 43 | if(block != func.outBlock.get()) 44 | selectInsn(block); 45 | return true; 46 | }); 47 | } 48 | } -------------------------------------------------------------------------------- /src/InstructionSelect.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_INSTRUCTION_SELECT_H 2 | #define MX_COMPILER_INSTRUCTION_SELECT_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | 7 | namespace MxIR 8 | { 9 | class InstructionSelect 10 | { 11 | public: 12 | InstructionSelect(Function &func) : func(func) {} 13 | void work(); 14 | 15 | protected: 16 | void computeUseCount(); 17 | void selectInsn(Block *block); 18 | 19 | protected: 20 | Function &func; 21 | std::map useCount; 22 | }; 23 | } 24 | 25 | #endif -------------------------------------------------------------------------------- /src/IssueCollector.cpp: 
-------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "IssueCollector.h" 3 | #include 4 | 5 | IssueCollector *IssueCollector::defIC = nullptr; 6 | 7 | IssueCollector::IssueCollector(issueLevel printLevel, std::ostream *printTarget, const antlr4::TokenStream *tokenStream, const std::string &fileName) : 8 | printLevel(printLevel), printTarget(printTarget), tokenStream(tokenStream), fileName(fileName), cntError(0) 9 | { 10 | std::string fileContent = tokenStream->getTokenSource()->getInputStream()->toString(); 11 | std::string temp; 12 | for (char c : fileContent) 13 | { 14 | if (c == '\r') 15 | continue; 16 | else if (c == '\n') 17 | { 18 | lineContent.push_back(temp); 19 | temp = ""; 20 | } 21 | else 22 | temp += c; 23 | } 24 | if (!temp.empty()) 25 | lineContent.push_back(temp); 26 | } 27 | 28 | void IssueCollector::notice(ssize_t tokenL, ssize_t tokenR, const std::string &description) 29 | { 30 | issue e{ NOTICE, tokenL, tokenR, description }; 31 | vIssues.push_back(e); 32 | printIssue(e); 33 | } 34 | void IssueCollector::warning(ssize_t tokenL, ssize_t tokenR, const std::string &description) 35 | { 36 | issue e{ WARNING, tokenL, tokenR, description }; 37 | vIssues.push_back(e); 38 | printIssue(e); 39 | } 40 | void IssueCollector::error(ssize_t tokenL, ssize_t tokenR, const std::string &description) 41 | { 42 | issue e{ ERROR, tokenL, tokenR, description }; 43 | vIssues.push_back(e); 44 | printIssue(e); 45 | cntError++; 46 | if (cntError >= MAX_ERROR) 47 | fatal(0, -1, "Maximum error limit exceeded. 
Stop compiling."); 48 | } 49 | void IssueCollector::fatal(ssize_t tokenL, ssize_t tokenR, const std::string &description) 50 | { 51 | issue e{ FATAL, tokenL, tokenR, description }; 52 | vIssues.push_back(e); 53 | throw FatalErrorException{ e }; 54 | } 55 | 56 | void IssueCollector::printIssue(const issue &e) 57 | { 58 | static const size_t maxLine = 10; 59 | static const std::vector issueLevelName = { "Notice", "Warning", "Error", "Fatal Error" }; 60 | if (e.level >= printLevel && printTarget && tokenStream) 61 | { 62 | if (!fileName.empty()) 63 | *printTarget << fileName << ":"; 64 | if (e.tokenL > e.tokenR) 65 | *printTarget << " " << issueLevelName[e.level] << ": " << e.description << std::endl; 66 | else 67 | { 68 | //TODO: Tabstop 69 | size_t startLine = tokenStream->get(e.tokenL)->getLine(); 70 | size_t startPos = tokenStream->get(e.tokenL)->getCharPositionInLine(); 71 | *printTarget << startLine << ":" << startPos << ": " << issueLevelName[e.level] << ": "; 72 | *printTarget << e.description << std::endl; 73 | 74 | auto *endToken = tokenStream->get(e.tokenR); 75 | size_t endLine = endToken->getLine(); 76 | size_t endPos = endToken->getCharPositionInLine() + endToken->getStopIndex() - endToken->getStartIndex(); 77 | 78 | if (startLine == endLine) 79 | printLine(lineContent.at(startLine - 1), startPos, endPos); 80 | else 81 | { 82 | printLine(lineContent.at(startLine - 1), startPos, startPos); 83 | for (size_t i = startLine + 1; i <= endLine && i <= startLine + maxLine; i++) 84 | printLine(lineContent.at(i - 1), 1, 0); 85 | } 86 | } 87 | } 88 | } 89 | 90 | void IssueCollector::printLine(const std::string &line, size_t l, size_t r) 91 | { 92 | static const size_t tabStop = 4; 93 | auto getHighlightChar = [l, r](size_t pos) 94 | { 95 | if (pos < l) 96 | return ' '; 97 | if (pos == l) 98 | return '^'; 99 | if (pos > l && pos <= r) 100 | return '~'; 101 | return ' '; 102 | }; 103 | std::string highlight; 104 | *printTarget << ' '; 105 | for (size_t i = 0; i < 
line.size(); i++) 106 | { 107 | if (line[i] == '\t') 108 | { 109 | for (size_t j = i; j > i && j % tabStop == 0; j++) 110 | { 111 | *printTarget << ' '; 112 | highlight += getHighlightChar(i); 113 | } 114 | } 115 | else 116 | { 117 | *printTarget << line[i]; 118 | highlight += getHighlightChar(i); 119 | } 120 | } 121 | *printTarget << std::endl; 122 | if (l <= r) 123 | *printTarget << ' ' << highlight << std::endl; 124 | } -------------------------------------------------------------------------------- /src/IssueCollector.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_ERROR_COLLECTOR_H 2 | #define MX_COMPILER_ERROR_COLLECTOR_H 3 | 4 | #include "common.h" 5 | 6 | namespace antlr4 7 | { 8 | class TokenStream; 9 | } 10 | 11 | class IssueCollector 12 | { 13 | public: 14 | enum issueLevel : int 15 | { 16 | NOTICE = 0, WARNING = 1, ERROR = 2, FATAL = 3 17 | }; 18 | struct issue 19 | { 20 | issueLevel level; 21 | ssize_t tokenL, tokenR; 22 | std::string description; 23 | }; 24 | struct FatalErrorException { issue e; }; 25 | 26 | std::vector vIssues; 27 | size_t cntError; 28 | 29 | public: 30 | IssueCollector() : printLevel(FATAL), printTarget(nullptr), tokenStream(nullptr), cntError(0) {} 31 | IssueCollector(issueLevel printLevel, std::ostream *printTarget, const antlr4::TokenStream *tokenStream, const std::string &fileName); 32 | 33 | void notice(ssize_t tokenL, ssize_t tokenR, const std::string &description); 34 | void warning(ssize_t tokenL, ssize_t tokenR, const std::string &description); 35 | void error(ssize_t tokenL, ssize_t tokenR, const std::string &description); 36 | void fatal(ssize_t tokenL, ssize_t tokenR, const std::string &description); 37 | 38 | void setDefault() { defIC = this; } 39 | static IssueCollector * getDefault() { return defIC; } 40 | 41 | protected: 42 | void printIssue(const issue &e); 43 | void printLine(const std::string &line, size_t l, size_t r); 44 | 45 | protected: 46 | 
issueLevel printLevel; 47 | std::ostream *printTarget; 48 | const antlr4::TokenStream *tokenStream; 49 | std::string fileName; 50 | std::vector lineContent; 51 | 52 | static IssueCollector *defIC; 53 | }; 54 | 55 | #endif -------------------------------------------------------------------------------- /src/LoadCombine.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "LoadCombine.h" 3 | 4 | namespace MxIR 5 | { 6 | void LoadCombine::combine(Block *block) 7 | { 8 | std::map loadOp; //addr -> var 9 | std::map, Operand> loadAOp; 10 | 11 | for (auto &ins : block->ins) 12 | { 13 | if (ins.oper == Call || ins.oper == Store || ins.oper == StoreA) 14 | { 15 | loadOp.clear(); 16 | loadAOp.clear(); 17 | continue; 18 | } 19 | if (ins.oper == Load) 20 | { 21 | if (loadOp.count(ins.src1) && loadOp[ins.src1].size() >= ins.dst.size()) 22 | ins = IR(ins.dst, Move, loadOp[ins.src1].clone().setSize(ins.dst.size())); 23 | else 24 | loadOp[ins.src1] = ins.dst; 25 | } 26 | else if (ins.oper == LoadA) 27 | { 28 | auto src = std::make_pair(ins.src1, ins.src2); 29 | if (loadAOp.count(src) && loadAOp[src].size() >= ins.dst.size()) 30 | ins = IR(ins.dst, Move, loadAOp[src].clone().setSize(ins.dst.size())); 31 | else 32 | loadAOp[src] = ins.dst; 33 | } 34 | } 35 | } 36 | 37 | void LoadCombine::work() 38 | { 39 | func.inBlock->traverse([](Block *block) -> bool 40 | { 41 | combine(block); 42 | return true; 43 | }); 44 | } 45 | } -------------------------------------------------------------------------------- /src/LoadCombine.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_LOAD_COMBINE_H 2 | #define MX_COMPILER_LOAD_COMBINE_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | 7 | namespace MxIR 8 | { 9 | class LoadCombine 10 | { 11 | public: 12 | LoadCombine(Function &func) : func(func) {} 13 | void work(); 14 | 15 | protected: 16 | static void 
combine(Block *block); 17 | 18 | protected: 19 | Function &func; 20 | }; 21 | } 22 | 23 | #endif -------------------------------------------------------------------------------- /src/LoopDetector.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "LoopDetector.h" 3 | 4 | namespace MxIR 5 | { 6 | void LoopDetector::findLoops() 7 | { 8 | func.inBlock->traverse([this](Block *block) -> bool 9 | { 10 | mapBlock[block] = vBlock.size(); 11 | vBlock.push_back(block); 12 | return true; 13 | }); 14 | dtree = DomTree(vBlock.size()); 15 | for (Block *blk : vBlock) 16 | { 17 | if (blk->brTrue) 18 | dtree.link(mapBlock[blk], mapBlock[blk->brTrue.get()]); 19 | if (blk->brFalse) 20 | dtree.link(mapBlock[blk], mapBlock[blk->brFalse.get()]); 21 | } 22 | assert(vBlock[0] == func.inBlock.get()); 23 | dtree.buildTree(0); 24 | 25 | std::set predecessors; 26 | dfs_dtree(0, predecessors); 27 | } 28 | 29 | void LoopDetector::dfs_dtree(size_t idx, std::set &predecessors) 30 | { 31 | Block *blk = vBlock[idx]; 32 | for (auto *child : { &blk->brTrue, &blk->brFalse }) 33 | if (*child && predecessors.count(child->get())) 34 | dfs_backward(blk, child->get()); 35 | predecessors.insert(blk); 36 | for (size_t child_dtree : dtree.getDomChildren(idx)) 37 | dfs_dtree(child_dtree, predecessors); 38 | predecessors.erase(blk); 39 | } 40 | 41 | void LoopDetector::dfs_backward(Block *blk, Block *target) 42 | { 43 | loops[target].insert(blk); 44 | if (blk == target) 45 | return; 46 | for (Block *pred : blk->preds) 47 | if (!loops[target].count(pred)) 48 | dfs_backward(pred, target); 49 | } 50 | } -------------------------------------------------------------------------------- /src/LoopDetector.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_LOOP_DETECTOR 2 | #define MX_COMPILER_LOOP_DETECTOR 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "utils/DomTree.h" 7 
| 8 | namespace MxIR 9 | { 10 | class LoopDetector 11 | { 12 | public: 13 | LoopDetector(Function &func) : func(func) {} 14 | void findLoops(); 15 | const std::map> &getLoops() const { return loops; } 16 | 17 | protected: 18 | void dfs_dtree(size_t idx, std::set &predecessors); 19 | void dfs_backward(Block *blk, Block *target); 20 | 21 | protected: 22 | Function &func; 23 | 24 | std::map mapBlock; 25 | std::vector vBlock; 26 | DomTree dtree; 27 | 28 | std::map> loops; //loop header -> set of loop body 29 | }; 30 | } 31 | 32 | #endif -------------------------------------------------------------------------------- /src/LoopInvariantOptimizer.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "LoopInvariantOptimizer.h" 3 | #include "LoopDetector.h" 4 | #include "IRVisualizer.h" 5 | 6 | namespace MxIR 7 | { 8 | void LoopInvariantOptimizer::work() 9 | { 10 | func.splitProgramRegion(); 11 | func.constructPST(); 12 | 13 | computeMaxVer(); 14 | 15 | LoopDetector detector(func); 16 | detector.findLoops(); 17 | for (auto &kv : detector.getLoops()) 18 | { 19 | loop cur; 20 | cur.header = kv.first; 21 | cur.body = kv.second; 22 | 23 | loops.push_back(cur); 24 | } 25 | std::sort(loops.begin(), loops.end(), [](const loop &a, const loop &b) { return a.body.size() > b.body.size(); }); 26 | 27 | for (size_t i = 0; i < loops.size(); i++) 28 | { 29 | for (Block *block : loops[i].body) 30 | loopID.insert(std::make_pair(block, i)); 31 | } 32 | 33 | //static size_t count = 0; 34 | 35 | for (loop &lp : loops) 36 | { 37 | /*count++; 38 | std::ofstream loptim_tmp("loop_" + std::to_string(count) + ".graph"); 39 | IRVisualizer irv(loptim_tmp); 40 | irv.printHead(); 41 | irv.print(func, "function", true); 42 | irv.printFoot(); 43 | loptim_tmp.close();*/ 44 | 45 | mapVar.clear(); 46 | mapRegion.clear(); 47 | failedVar.clear(); 48 | vInvar.clear(); 49 | 50 | insertPoint = nullptr; 51 | for (Block *pred : 
lp.header->preds) 52 | if (!lp.body.count(pred)) 53 | { 54 | assert(!insertPoint); 55 | insertPoint = pred; 56 | } 57 | 58 | if (insertPoint->brTrue.get() != lp.header) 59 | throw 233; 60 | assert(insertPoint->brTrue.get() == lp.header); 61 | 62 | assert(!insertPoint->brFalse); 63 | 64 | createInvars(lp); 65 | 66 | for (size_t i = 0; i < vInvar.size(); i++) 67 | { 68 | if (!vInvar[i]->failed) 69 | vInvar[i]->init(); 70 | } 71 | 72 | for (size_t i = 0; i < vInvar.size(); i++) 73 | { 74 | InvariantRegion *region = dynamic_cast(vInvar[i].get()); 75 | if (!region) 76 | break; 77 | if (region->failed) 78 | continue; 79 | 80 | InvariantRegion *upmostParent = region; 81 | while (mapRegion.count(upmostParent->node->parent.lock().get()) && !vInvar[mapRegion[upmostParent->node->parent.lock().get()]]->failed) 82 | upmostParent = dynamic_cast(vInvar[mapRegion[upmostParent->node->parent.lock().get()]].get()); 83 | 84 | if (upmostParent != region) 85 | { 86 | region->worked = true; 87 | for (size_t dependency : region->dependOn) 88 | { 89 | if (dynamic_cast(vInvar[dependency].get())) 90 | continue; 91 | 92 | if (!upmostParent->dependBy.count(dependency)) 93 | { 94 | upmostParent->dependOn.insert(dependency); 95 | vInvar[dependency]->dependBy.insert(upmostParent->index); 96 | } 97 | vInvar[dependency]->dependBy.erase(region->index); 98 | } 99 | for (size_t i : region->dependBy) 100 | { 101 | if (auto ptr = dynamic_cast(vInvar[i].get())) 102 | { 103 | ptr->dependOn.erase(region->index); 104 | continue; 105 | } 106 | if (upmostParent->dependOn.count(i)) 107 | { 108 | upmostParent->dependOn.erase(i); 109 | vInvar[i]->dependBy.erase(upmostParent->index); 110 | } 111 | upmostParent->dependBy.insert(i); 112 | vInvar[i]->dependOn.insert(upmostParent->index); 113 | vInvar[i]->dependOn.erase(region->index); 114 | } 115 | } 116 | } 117 | 118 | std::queue workList; 119 | for (size_t i = 0; i < vInvar.size(); i++) 120 | { 121 | if (!vInvar[i]->failed && vInvar[i]->dependOn.empty()) 122 | 
workList.push(i); 123 | } 124 | //std::cerr << "IN" << std::endl; 125 | while (!workList.empty()) 126 | { 127 | size_t current = workList.front(); 128 | workList.pop(); 129 | if (vInvar[current]->failed) 130 | continue; 131 | 132 | vInvar[current]->work(); 133 | for (size_t dependency : vInvar[current]->dependBy) 134 | { 135 | assert(!vInvar[dependency]->failed); 136 | vInvar[dependency]->dependOn.erase(current); 137 | if (!vInvar[dependency]->worked && vInvar[dependency]->dependOn.empty()) 138 | workList.push(dependency); 139 | } 140 | } 141 | //std::cerr << "OUT" << std::endl; 142 | } 143 | } 144 | 145 | void LoopInvariantOptimizer::computeMaxVer() 146 | { 147 | func.inBlock->traverse([this](Block *block) -> bool 148 | { 149 | for (auto &ins : block->instructions()) 150 | { 151 | for (Operand *operand : join(ins.getInputReg(), ins.getOutputReg())) 152 | maxVer[operand->val] = std::max(maxVer[operand->val], operand->ver); 153 | } 154 | return true; 155 | }); 156 | } 157 | 158 | void LoopInvariantOptimizer::createInvars(const loop &lp) 159 | { 160 | std::shared_ptr node(lp.header->pstNode.lock()); 161 | std::function traverse; 162 | traverse = [&traverse, this](PSTNode *node) 163 | { 164 | for (auto &child : node->children) 165 | traverse(child.get()); 166 | InvariantRegion::construct(*this, node); 167 | }; 168 | for (auto &child : node->children) 169 | { 170 | if (lp.body.count(child->inBlock)) 171 | traverse(child.get()); 172 | } 173 | 174 | failedVar.clear(); 175 | for (Block *block : lp.body) 176 | { 177 | bool blockFailed = false; 178 | for (auto iter = block->ins.begin(); iter != block->ins.end(); ++iter) 179 | { 180 | auto &ins = *iter; 181 | auto output = ins.getOutputReg(); 182 | if (ins.oper == Call && (ins.src1.type != Operand::funcID || !(program->vFuncs[ins.src1.val].attribute & ConstExpr)) 183 | || ins.oper == Load || ins.oper == LoadA 184 | || ins.oper == Store || ins.oper == StoreA) 185 | { 186 | blockFailed = true; 187 | for (Operand *operand : 
output) 188 | failedVar.insert(*operand); 189 | } 190 | else if (!output.empty()) 191 | { 192 | InvariantVar::construct(*this, block, iter); 193 | } 194 | } 195 | for (auto &kv : block->phi) 196 | InvariantVar::construct(*this, block, kv.second); 197 | if (blockFailed) 198 | failedBlock.insert(block); 199 | } 200 | } 201 | 202 | void LoopInvariantOptimizer::InvariantVar::init() 203 | { 204 | if (inited) 205 | return; 206 | if (isPhi) 207 | return fail(); 208 | std::set dependency; 209 | for (Operand *operand : insn->getInputReg()) 210 | { 211 | if (parent.failedVar.count(*operand)) 212 | return fail(); 213 | if (parent.mapVar.count(*operand)) 214 | { 215 | assert(dynamic_cast(parent.vInvar[parent.mapVar[*operand]].get())); 216 | dependency.insert(parent.mapVar[*operand]); 217 | } 218 | } 219 | if (!setDependOn(dependency)) 220 | return fail(); 221 | inited = true; 222 | } 223 | 224 | void LoopInvariantOptimizer::InvariantVar::fail() 225 | { 226 | Invariant::fail(); 227 | if (isPhi) 228 | parent.failedVar.insert(phiDst); 229 | else 230 | { 231 | for (Operand *operand : insn->getOutputReg()) 232 | parent.failedVar.insert(*operand); 233 | } 234 | } 235 | 236 | void LoopInvariantOptimizer::InvariantRegion::init() 237 | { 238 | if (!parent.mapRegion.count(node->parent.lock().get()) && node->isSingleBlock()) 239 | return fail(); 240 | for (Block *block : node->blocks) 241 | { 242 | if (parent.failedBlock.count(block)) 243 | return fail(); 244 | } 245 | std::set dependency; 246 | for (auto &child : node->children) 247 | { 248 | if (parent.vInvar[parent.mapRegion[child.get()]]->failed) 249 | return fail(); 250 | dependency.insert(parent.mapRegion[child.get()]); 251 | } 252 | 253 | std::vector children; 254 | std::function dfs; 255 | dfs = [&children, &dfs, this](PSTNode *node) 256 | { 257 | assert(parent.vInvar[parent.mapRegion[node]]->inited); 258 | children.push_back(parent.mapRegion[node]); 259 | for(auto &child : node->children) 260 | dfs(child.get()); 261 | }; 262 | 
for (auto &child : node->children) 263 | dfs(child.get()); 264 | 265 | auto isDefined = [&dependency, &children, this](Operand operand) 266 | { 267 | assert(!parent.failedVar.count(operand)); 268 | if (!parent.mapVar.count(operand)) 269 | return true; 270 | if (dependBy.count(parent.mapVar[operand])) 271 | return true; 272 | for (size_t child : children) 273 | { 274 | if (parent.vInvar[child]->dependBy.count(parent.mapVar[operand])) 275 | return true; 276 | } 277 | return false; 278 | }; 279 | 280 | for (Block *block : node->blocks) 281 | for (auto &ins : block->instructions()) 282 | { 283 | for (Operand *operand : ins.getOutputReg()) 284 | { 285 | parent.vInvar[parent.mapVar[*operand]]->setDependOn({ index }); 286 | parent.vInvar[parent.mapVar[*operand]]->inited = true; 287 | } 288 | } 289 | for(Block *block : node->blocks) 290 | for (auto &ins : block->instructions()) 291 | { 292 | for (Operand *operand : ins.getInputReg()) 293 | { 294 | if (parent.failedVar.count(*operand)) 295 | return fail(); 296 | if (!isDefined(*operand)) 297 | dependency.insert(parent.mapVar[*operand]); 298 | } 299 | } 300 | if (!setDependOn(dependency)) 301 | return fail(); 302 | inited = true; 303 | } 304 | 305 | void LoopInvariantOptimizer::InvariantRegion::releaseVars() 306 | { 307 | std::vector cand; 308 | for (size_t dependency : dependBy) 309 | { 310 | if (InvariantVar *var = dynamic_cast(parent.vInvar[dependency].get())) 311 | { 312 | var->inited = false; 313 | cand.push_back(var); 314 | } 315 | } 316 | for (InvariantVar *var : cand) 317 | { 318 | var->init(); 319 | } 320 | } 321 | 322 | void LoopInvariantOptimizer::InvariantRegion::fail() 323 | { 324 | releaseVars(); 325 | for (auto &child : node->children) 326 | { 327 | if (child->isSingleBlock()) 328 | { 329 | if (!parent.vInvar[parent.mapRegion[child.get()]]->failed) 330 | parent.vInvar[parent.mapRegion[child.get()]]->fail(); 331 | } 332 | } 333 | Invariant::fail(); 334 | } 335 | 336 | void 
LoopInvariantOptimizer::protectDivisor(Block *block, std::list::iterator insn) 337 | { 338 | if (insn->oper == Div || insn->oper == Mod) 339 | { 340 | if (insn->src2.isImm()) 341 | return; 342 | assert(insn->src2.isReg()); 343 | Operand tmp = RegSize(maxVer.empty() ? 0 : maxVer.rbegin()->first + 1, insn->src2.size()); 344 | maxVer[tmp.val] = 0; 345 | block->ins.insert(insn, IR(tmp, TestZero, insn->src2, ImmSize(1, insn->src1.size()))); 346 | insn->src2 = tmp; 347 | } 348 | } 349 | 350 | void LoopInvariantOptimizer::InvariantVar::work() 351 | { 352 | if (worked) 353 | return; 354 | assert(!isPhi); 355 | static const std::set blacklist = { Seq, Sne, Sgt, Sge, Slt, Sle, Sgtu, Sgeu, Sltu, Sleu, Move }; 356 | if (dependBy.empty() && blacklist.count(insn->oper)) 357 | return; 358 | auto iter = parent.insertPoint->ins.insert(std::prev(parent.insertPoint->ins.end()), *insn); 359 | parent.protectDivisor(parent.insertPoint, iter); 360 | block->ins.erase(insn); 361 | worked = true; 362 | } 363 | 364 | void LoopInvariantOptimizer::InvariantRegion::work() 365 | { 366 | if (worked) 367 | return; 368 | std::shared_ptr pstParent = node->parent.lock(); 369 | 370 | std::set innerBlocks; 371 | node->traverse([&innerBlocks](PSTNode *node) 372 | { 373 | for (Block *block : node->blocks) 374 | innerBlocks.insert(block); 375 | }); 376 | 377 | Block *pred = nullptr; 378 | for (Block *block : node->inBlock->preds) 379 | { 380 | if (!innerBlocks.count(block)) 381 | { 382 | assert(!pred); 383 | pred = block; 384 | } 385 | } 386 | 387 | Block::block_ptr &next = 388 | innerBlocks.count(node->outBlock->brTrue.get()) ? 
node->outBlock->brFalse : node->outBlock->brTrue; 389 | assert(next); 390 | 391 | std::shared_ptr tmpBlock(Block::construct()); 392 | 393 | auto range = parent.loopID.equal_range(node->inBlock); 394 | for (auto iter = range.first; iter != range.second; ++iter) 395 | { 396 | if (parent.loops[iter->second].header != node->inBlock) 397 | { 398 | for (Block *block : innerBlocks) 399 | parent.loops[iter->second].body.erase(block); 400 | parent.loops[iter->second].body.insert(tmpBlock.get()); 401 | } 402 | } 403 | 404 | std::shared_ptr inBlock = node->inBlock->self.lock(); 405 | 406 | tmpBlock->pstNode = pstParent; 407 | pstParent->blocks.insert(tmpBlock.get()); 408 | tmpBlock->brTrue = next->self.lock(); 409 | tmpBlock->ins = { IRJump() }; 410 | if (pred->brTrue.get() == inBlock.get()) 411 | pred->brTrue = tmpBlock; 412 | if (pred->brFalse.get() == inBlock.get()) 413 | pred->brFalse = tmpBlock; 414 | 415 | std::shared_ptr dummyBlockAfterOut(Block::construct()); 416 | pstParent->blocks.insert(dummyBlockAfterOut.get()); 417 | dummyBlockAfterOut->pstNode = pstParent; 418 | dummyBlockAfterOut->ins = { IRJump() }; 419 | dummyBlockAfterOut->brTrue = parent.insertPoint->brTrue; 420 | 421 | for (auto &kv : next->phi) 422 | for (auto &src : kv.second.srcs) 423 | if (src.second.lock().get() == node->outBlock) 424 | src.second = tmpBlock; 425 | 426 | for (auto &kv : parent.insertPoint->brTrue->phi) 427 | for (auto &src : kv.second.srcs) 428 | if (src.second.lock().get() == parent.insertPoint) 429 | src.second = dummyBlockAfterOut; 430 | //src.second = node->outBlock->self; 431 | 432 | for (auto &kv : node->inBlock->phi) 433 | for (auto &src : kv.second.srcs) 434 | if (src.second.lock().get() == pred) 435 | src.second = parent.insertPoint->self; 436 | 437 | // next = parent.insertPoint->brTrue; 438 | next = dummyBlockAfterOut; 439 | parent.insertPoint->brTrue = inBlock; 440 | // parent.insertPoint = node->outBlock; 441 | parent.insertPoint = dummyBlockAfterOut.get(); 442 | 443 | 
std::function dfs; 444 | dfs = [&dfs, this](PSTNode *node) 445 | { 446 | for (Block *block : node->blocks) 447 | for (auto iter = block->ins.begin(); iter != block->ins.end(); ++iter) 448 | parent.protectDivisor(block, iter); 449 | 450 | for (auto &child : node->children) 451 | dfs(child.get()); 452 | }; 453 | dfs(node); 454 | 455 | std::shared_ptr ptr = node->self.lock(); 456 | pstParent->children.erase(ptr->iterParent); 457 | parent.func.pstRoot->children.push_back(ptr); 458 | ptr->iterParent = std::prev(parent.func.pstRoot->children.end()); 459 | 460 | for (size_t dependency : dependBy) 461 | parent.vInvar[dependency]->worked = true; 462 | 463 | worked = true; 464 | } 465 | } -------------------------------------------------------------------------------- /src/LoopInvariantOptimizer.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_LOOP_INVARIANT_OPTIMIZER_H 2 | #define MX_COMPILER_LOOP_INVARIANT_OPTIMIZER_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "MxProgram.h" 7 | 8 | namespace MxIR 9 | { 10 | class LoopInvariantOptimizer 11 | { 12 | public: 13 | LoopInvariantOptimizer(Function &func) : func(func), program(MxProgram::getDefault()) {} 14 | void work(); 15 | 16 | protected: 17 | struct loop 18 | { 19 | Block *header; 20 | std::set body; 21 | }; 22 | struct Invariant 23 | { 24 | LoopInvariantOptimizer &parent; 25 | std::set dependBy; 26 | std::set dependOn; 27 | bool failed, inited, worked; 28 | const size_t index; 29 | 30 | Invariant(LoopInvariantOptimizer &parent, size_t index) : parent(parent), failed(false), inited(false), worked(false), index(index) {} 31 | virtual ~Invariant() {} 32 | virtual void init() = 0; 33 | virtual void work() = 0; 34 | virtual void fail() 35 | { 36 | if (failed) 37 | return; 38 | failed = true; 39 | for (size_t invar : std::set(dependBy)) 40 | parent.vInvar[invar]->fail(); 41 | for (size_t invar : dependOn) 42 | parent.vInvar[invar]->dependBy.erase(index); 
43 | } 44 | bool setDependOn(const std::set &newDepend) 45 | { 46 | for (size_t i : dependOn) 47 | parent.vInvar[i]->dependBy.erase(index); 48 | dependOn = newDepend; 49 | for (size_t i : dependOn) 50 | { 51 | if (parent.vInvar[i]->failed) 52 | return false; 53 | } 54 | for (size_t i : dependOn) 55 | parent.vInvar[i]->dependBy.insert(index); 56 | return true; 57 | } 58 | }; 59 | struct InvariantVar : public Invariant 60 | { 61 | Block *block; 62 | std::list::iterator insn; 63 | Operand phiDst; 64 | bool isPhi; 65 | 66 | InvariantVar(LoopInvariantOptimizer &parent, size_t index, Block *block, const Block::PhiIns &phi) : Invariant(parent, index), block(block), isPhi(true), phiDst(phi.dst) {} 67 | InvariantVar(LoopInvariantOptimizer &parent, size_t index, Block *block, std::list::iterator insn) : 68 | Invariant(parent, index), block(block), insn(insn), isPhi(false) {} 69 | 70 | static InvariantVar *construct(LoopInvariantOptimizer &parent, Block *block, const Block::PhiIns &phi) 71 | { 72 | parent.vInvar.emplace_back(new InvariantVar(parent, parent.vInvar.size(), block, phi)); 73 | InvariantVar *ptr = dynamic_cast(parent.vInvar.back().get()); 74 | parent.mapVar[ptr->phiDst] = ptr->index; 75 | return ptr; 76 | } 77 | static InvariantVar *construct(LoopInvariantOptimizer &parent, Block *block, std::list::iterator insn) 78 | { 79 | parent.vInvar.emplace_back(new InvariantVar(parent, parent.vInvar.size(), block, insn)); 80 | InvariantVar *ptr = dynamic_cast(parent.vInvar.back().get()); 81 | parent.mapVar[ptr->insn->dst] = ptr->index; 82 | return ptr; 83 | } 84 | 85 | virtual void init() override; 86 | virtual void fail() override; 87 | virtual void work() override; 88 | }; 89 | struct InvariantRegion : public Invariant 90 | { 91 | PSTNode *node; 92 | 93 | InvariantRegion(LoopInvariantOptimizer &parent, size_t index, PSTNode *node) : Invariant(parent, index), node(node) {} 94 | 95 | static InvariantRegion *construct(LoopInvariantOptimizer &parent, PSTNode *node) 96 | { 97 
| parent.vInvar.emplace_back(new InvariantRegion(parent, parent.vInvar.size(), node)); 98 | InvariantRegion *ptr = dynamic_cast(parent.vInvar.back().get()); 99 | parent.mapRegion[ptr->node] = ptr->index; 100 | return ptr; 101 | } 102 | 103 | virtual void init() override; 104 | virtual void fail() override; 105 | virtual void work() override; 106 | 107 | void releaseVars(); 108 | }; 109 | 110 | protected: 111 | void computeMaxVer(); 112 | void createInvars(const loop &lp); 113 | void protectDivisor(Block *block, std::list::iterator insn); 114 | 115 | protected: 116 | Function &func; 117 | MxProgram *program; 118 | std::vector> vInvar; 119 | std::map mapVar; 120 | std::map mapRegion; 121 | 122 | Block *insertPoint; 123 | 124 | std::set failedVar; 125 | std::set failedBlock; 126 | 127 | std::map maxVer; 128 | 129 | std::vector loops; 130 | std::multimap loopID; 131 | }; 132 | } 133 | 134 | #endif -------------------------------------------------------------------------------- /src/MxBuiltin.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_MX_BUILTIN_H 2 | #define MX_COMPILER_MX_BUILTIN_H 3 | 4 | #include "common.h" 5 | #include "GlobalSymbol.h" 6 | #include "MxProgram.h" 7 | #include "IR.h" 8 | 9 | class MxBuiltin 10 | { 11 | public: 12 | enum class BuiltinFunc : size_t 13 | { 14 | print = 0, 15 | println = 1, 16 | getString = 2, 17 | getInt = 3, 18 | toString = 4, 19 | length = 5, 20 | substring = 6, 21 | parseInt = 7, 22 | ord = 8, 23 | size = 9, 24 | runtime_error = 10, //void __runtime_error(const char *) 25 | strcat = 11, //string __strcat(string, string) 26 | strcmp = 12, //int __strcmp(string, string) 27 | subscript_bool = 13,//ptr (bool[])::operator[](int) 28 | subscript_int = 14, //ptr (int[])::operator[](int) 29 | subscript_object = 15, //ptr (object[])::operator[](int) 30 | newobject = 16, //object __newobject(size_t size, size_t typeid) 31 | release_string = 17,//void __release_string(string) 32 | 
release_array_internal = 18,//void __release_array_internal(array, int dim) //release reference of array that has internal type (int or bool) 33 | release_array_string = 19, //void __release_array_string(array, int dim) //release reference of string array 34 | release_array_object = 20, //void __release_array_object(array, int dim) //release reference of array that has any object type 35 | release_object = 21, //void __release_object(object) //release reference of any object 36 | addref_object = 22, //void __addref_object(object) 37 | newobject_zero = 23, //object __newobject_zero(size_t size, size_t typeid) //new object and memset to zero 38 | initialize = 24, //void __initialize() 39 | }; 40 | 41 | public: 42 | MxBuiltin() : symbol(GlobalSymbol::getDefault()), program(MxProgram::getDefault()) {} 43 | MxBuiltin(GlobalSymbol *symbol, MxProgram *program) : symbol(symbol), program(program) {} 44 | 45 | static size_t getBuiltinClassByType(MxType type); 46 | void init(); 47 | 48 | static MxProgram::constInfo string2Const(const std::string &str); 49 | 50 | void setDefault() { defBI = this; } 51 | static MxBuiltin * getDefault() { return defBI; } 52 | 53 | protected: 54 | enum class BuiltinSymbol : size_t 55 | { 56 | print = 0, 57 | println, 58 | getString, 59 | getInt, 60 | toString, 61 | length, 62 | substring, 63 | parseInt, 64 | ord, 65 | size, 66 | Cstring, 67 | Carray, 68 | strcmp, 69 | malloc, 70 | free, 71 | realloc, 72 | scanf, 73 | puts, 74 | printf, 75 | putchar, 76 | getchar, 77 | snprintf, 78 | strlen, 79 | memcpy, 80 | sscanf, 81 | fputs, 82 | Stderr, //stderr 83 | exit, 84 | memset, 85 | 86 | Hruntime_error, 87 | Hstrcat, 88 | Hstrcmp, 89 | Hsubscript_bool, 90 | Hsubscript_int, 91 | Hsubscript_object, 92 | Hnewobject, 93 | Hrelease_string, 94 | Hrelease_array_internal, 95 | Hrelease_array_string, 96 | Hrelease_array_object, 97 | Hrelease_object, 98 | Haddref_object, 99 | Hnewobject_zero, 100 | 101 | Hinitialize, 102 | }; 103 | enum class BuiltinConst : 
size_t 104 | { 105 | Percent_d = 0, //"%d" 106 | runtime_error = 1, //"Runtime Error: " 107 | subscript_out_of_range = 2, //"Subscript out of range" 108 | null_ptr = 3, //"Null pointer" 109 | bad_allocation = 4, //"Bad allocation" 110 | line_break = 5, //"\n" 111 | Percend_s = 6, 112 | }; 113 | enum class BuiltinClass : size_t 114 | { 115 | string = 10, 116 | array = 11 117 | }; 118 | 119 | protected: 120 | void fillBuiltinSymbol(); 121 | void fillBuiltinMemberTable(); 122 | MxIR::Function builtin_print(); 123 | MxIR::Function builtin_println(); 124 | MxIR::Function builtin_getString(); 125 | MxIR::Function builtin_getInt(); 126 | MxIR::Function builtin_toString(); 127 | MxIR::Function builtin_length(); 128 | MxIR::Function builtin_substring(); 129 | MxIR::Function builtin_parseInt(); 130 | MxIR::Function builtin_ord_safe(); 131 | MxIR::Function builtin_ord_unsafe(); 132 | MxIR::Function builtin_size(); 133 | MxIR::Function builtin_size_unsafe(); 134 | MxIR::Function builtin_runtime_error(); 135 | MxIR::Function builtin_strcat(); 136 | MxIR::Function builtin_strcmp(); 137 | MxIR::Function builtin_subscript_safe(size_t size); 138 | MxIR::Function builtin_subscript_unsafe(size_t size); 139 | MxIR::Function builtin_newobject(); 140 | MxIR::Function builtin_newobject_zero(); 141 | 142 | MxIR::Function builtin_addref_object(); 143 | MxIR::Function builtin_release_string(); 144 | MxIR::Function builtin_release_array(bool internal); 145 | 146 | MxIR::Function builtin_stub(const std::vector ¶m); 147 | 148 | protected: 149 | GlobalSymbol *symbol; 150 | MxProgram *program; 151 | static MxBuiltin *defBI; 152 | static const size_t objectHeader = 2 * POINTER_SIZE; 153 | static const size_t stringHeader = POINTER_SIZE, arrayHeader = POINTER_SIZE; 154 | }; 155 | 156 | #endif -------------------------------------------------------------------------------- /src/MxProgram.cpp: -------------------------------------------------------------------------------- 1 | #include 
"common_headers.h" 2 | #include "MxProgram.h" 3 | 4 | MxProgram * MxProgram::defProg = nullptr; -------------------------------------------------------------------------------- /src/MxProgram.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_MEMBER_TABLE_H 2 | #define MX_COMPILER_MEMBER_TABLE_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | 7 | enum FuncAttribute : std::uint32_t 8 | { 9 | NoSideEffect = 1, 10 | ConstExpr = 2, //No global var read/write 11 | Linear = 4, //No cycle 12 | Builtin = 8, 13 | ForceInline = 16, 14 | NoInline = 32, 15 | Export = 64, //use 'global' to export in NASM; export name is the same as funcName; multiple export function with the same name is not allowed 16 | }; 17 | 18 | class MxProgram 19 | { 20 | public: 21 | struct funcInfo 22 | { 23 | size_t funcName; 24 | MxType retType; 25 | std::vector paramType; 26 | std::uint32_t attribute; 27 | bool isThiscall; 28 | std::set dependency; //builtin function not included 29 | 30 | MxIR::Function content; 31 | bool disabled = false; 32 | }; 33 | struct varInfo 34 | { 35 | size_t varName; 36 | MxType varType; 37 | std::int64_t offset; //offset in class 38 | }; 39 | struct classInfo 40 | { 41 | size_t classSize; 42 | std::vector members; 43 | }; 44 | struct constInfo 45 | { 46 | size_t labelOffset; //must be in range [0, data.size() ) 47 | size_t align; 48 | std::vector data; 49 | }; 50 | 51 | std::vector vFuncs; 52 | std::vector> vOverloadedFuncs; 53 | std::vector vGlobalVars; //Note: A function is also a variable 54 | std::map vClass; //TODO: Compare performance with unordered_map 55 | std::vector> vLocalVars; 56 | std::vector vConst; 57 | 58 | void setDefault() { defProg = this; } 59 | static MxProgram * getDefault() { return defProg; } 60 | 61 | protected: 62 | static MxProgram *defProg; 63 | }; 64 | 65 | #endif -------------------------------------------------------------------------------- /src/RegisterAllocatorSSA.h: 
-------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_REGISTER_ALLOCATOR_SSA_H 2 | #define MX_COMPILER_REGISTER_ALLOCATOR_SSA_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "utils/DomTree.h" 7 | 8 | namespace MxIR 9 | { 10 | class RegisterAllocatorSSA 11 | { 12 | public: 13 | explicit RegisterAllocatorSSA(Function &func, const std::vector &phyReg) : func(func), phyReg(phyReg) {} 14 | void work(); 15 | 16 | protected: 17 | void splitCriticalEdge(); 18 | void computeDomTree(); 19 | void computeLoop(); 20 | void relabelVReg(); 21 | void computeDefUses(); 22 | void computeNextUse(); 23 | void computeMaxPressure(); 24 | void computeVarOp(); 25 | void computeVarGroup(); 26 | void computeExternalVar(); 27 | 28 | void spillRegister(); 29 | void spillRegisterBlock(Block *block); 30 | Instruction getLoadInsn(size_t regid); 31 | Instruction getSpillInsn(size_t regid); 32 | void initLoopHeader(Block *block); //init the W set (set of the v-registers not spilled) of loop header 33 | void initUsualBlock(Block *block); 34 | 35 | bool isSpill(const Instruction &insn); 36 | bool isReload(const Instruction &insn); 37 | void eliminateSpillCode(); 38 | void eliminateSpillCode(size_t idx, std::set &spilled); 39 | void insertAllocateCode(); 40 | 41 | void reconstructSSA(); 42 | // --- vars will be relabeled here --- 43 | // --- varOp and varGroup is invalid here --- 44 | void analysisLiveness(); 45 | void buildInterferenceGraph(); 46 | void allocateRegister(); 47 | void allocateRegisterDFS(size_t idx); 48 | int chooseRegister(size_t vreg, const std::vector &prefer); 49 | 50 | int getPReg(Operand operand); 51 | void coalesce(); 52 | 53 | void destructSSA(); 54 | void writeRegInfo(); 55 | 56 | size_t findVertexRoot(size_t vtx); 57 | void mergeVertices(const std::set &vtxList); 58 | 59 | protected: 60 | struct BlockProperty 61 | { 62 | size_t idx; 63 | std::map nextUseBegin, nextUseEnd; //global next use distance at the
begin/end of the block 64 | std::set loopBorder; //edges leading out of loop 65 | std::set loopBody; 66 | 67 | std::set definedVar; //virtual register that is defined in this block 68 | std::set usedVar; 69 | // std::set updateDistance; //virtual register of which the next use distance at end is updated. used in computeNextUse 70 | 71 | size_t maxPressure; 72 | 73 | std::set Wentry, Wexit; 74 | 75 | bool visited; //visited in spill stage 76 | 77 | std::set liveIn, liveOut; //used when building inference graph 78 | 79 | /* 80 | ------- liveIn --------- 81 | a = phi(...) 82 | b = phi(...) 83 | ... 84 | ----- nextUseBegin ----- 85 | c = d op e 86 | ... 87 | - nextUseEnd & liveOut - 88 | */ 89 | 90 | BlockProperty() : idx(size_t(-1)), maxPressure(0), visited(false) {} 91 | }; 92 | struct GraphVertex 93 | { 94 | std::set neighbor; 95 | std::shared_ptr> forbiddenReg; 96 | int preg; 97 | bool pinned; 98 | size_t root; 99 | 100 | GraphVertex() : preg(-1), pinned(false) {} 101 | }; 102 | struct OptimUnit 103 | { 104 | RegisterAllocatorSSA &allocator; 105 | size_t keyVertex; 106 | int targetRegister; 107 | std::set S; 108 | std::map changedReg; 109 | 110 | std::set cand; 111 | const size_t minCand; //if cand < minCand, this OptimUnit will be discarded 112 | const ssize_t bias; 113 | 114 | OptimUnit(RegisterAllocatorSSA &allocator, size_t minCand, ssize_t bias) : allocator(allocator), minCand(minCand), bias(bias) {} virtual ~OptimUnit() {} //polymorphic base (pure virtual fail/conflict): a virtual dtor makes destroying a derived unit through an OptimUnit pointer well-defined 115 | 116 | virtual void fail(size_t u) = 0; //u != keyVertex 117 | virtual void conflict(size_t u, size_t v) = 0; //u, v != keyVertex 118 | virtual void apply(); 119 | 120 | bool operator<(const OptimUnit &rhs) const { return ssize_t(S.size()) + bias < ssize_t(rhs.S.size()) + rhs.bias; } //compare in signed arithmetic: with size_t + ssize_t a negative bias would wrap to a huge unsigned value and misorder units 121 | int work(); //>0 for success; =0 for retry; <0 for abandoned 122 | ssize_t adjust(size_t src); //return -1 if succeed. return the id of conflicted var otherwise 123 | ssize_t adjust(size_t src, int target); 124 | int getCurrentReg(size_t idx) { return changedReg.count(idx) ? 
changedReg[idx] : allocator.ifGraph[idx].preg; } 125 | }; 126 | struct OptimUnitAll : public OptimUnit 127 | { 128 | std::set vertices; //including keyVertex 129 | std::set> edges; 130 | 131 | OptimUnitAll(RegisterAllocatorSSA &allocator, size_t key, const std::vector &another); 132 | 133 | virtual void fail(size_t u) override; 134 | virtual void conflict(size_t u, size_t v) override; 135 | 136 | void findMaxStableSet(); 137 | }; 138 | struct OptimUnitOne : public OptimUnit 139 | { 140 | size_t another; 141 | 142 | OptimUnitOne(RegisterAllocatorSSA &allocator, size_t u, size_t v); 143 | 144 | virtual void fail(size_t u) override { S.clear(); } 145 | virtual void conflict(size_t u, size_t v) override { S.clear(); } 146 | }; 147 | struct OptimUnitSingle : public OptimUnit //set key var to a specific register 148 | { 149 | OptimUnitSingle(RegisterAllocatorSSA &allocator, size_t key, int target) : OptimUnit(allocator, 1, 0) 150 | { 151 | targetRegister = target; 152 | keyVertex = allocator.findVertexRoot(key); 153 | S.insert(keyVertex); 154 | } 155 | virtual void apply() override 156 | { 157 | assert(S.size() == 1); 158 | allocator.ifGraph[*S.begin()].pinned = true; 159 | } 160 | virtual void fail(size_t u) override { S.clear(); } 161 | virtual void conflict(size_t u, size_t v) override { S.clear(); } 162 | }; 163 | std::vector ifGraph; 164 | std::map property; 165 | std::vector vBlock; 166 | DomTree dtree; 167 | 168 | std::vector varOp; 169 | std::vector varGroup; //vregid -> groupid. 
note that groupid is also the store address of the register 170 | std::map externalVarHint; //varid -> alloc hint 171 | size_t nVar; 172 | 173 | Function &func; 174 | std::vector phyReg; 175 | 176 | static const size_t outLoopPenalty = 1000; 177 | }; 178 | } 179 | 180 | #endif -------------------------------------------------------------------------------- /src/SSAConstructor.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "SSAConstructor.h" 3 | #include "utils/ElementAdapter.h" 4 | 5 | namespace MxIR 6 | { 7 | void SSAConstructor::constructSSA() 8 | { 9 | preprocess(); 10 | 11 | DomTree domTree = getDomTree(false); 12 | 13 | for (auto &kv : vars) 14 | { 15 | size_t varID = kv.first; 16 | const globalVar &var = kv.second; 17 | 18 | std::queue workList; //block ids to propagate phi placement from: the defs of var, then blocks that just received a new phi 19 | for (auto &i : var.def) 20 | workList.push(i); 21 | 22 | while (!workList.empty()) 23 | { 24 | auto cur = workList.front(); 25 | workList.pop(); 26 | for (size_t frontier : domTree.getDomFrontier(cur)) 27 | { 28 | if (!blocks[frontier]->phi.count(varID)) 29 | { 30 | blocks[frontier]->phi.insert({ varID, Block::PhiIns{ var.operand } }); 31 | if (!var.def.count(frontier)) 32 | workList.push(frontier); 33 | } 34 | } 35 | } 36 | } 37 | 38 | renameVar(); 39 | } 40 | //Place phi nodes on iterated dominance frontiers and sigma nodes on iterated post-dominance frontiers until both worklists drain; unlike constructSSA, renameVar() is not called here 41 | void SSAConstructor::constructSSIFull() 42 | { 43 | preprocess(); 44 | 45 | DomTree domTree = getDomTree(false); 46 | DomTree postDomTree = getDomTree(true); 47 | 48 | for (auto &kv : vars) 49 | { 50 | size_t varID = kv.first; 51 | const globalVar &var = kv.second; 52 | 53 | std::queue workListPhi; //block ids whose dominance frontier still needs phi placement 54 | std::queue workListSigma; //block ids whose post-dominance frontier still needs sigma placement 55 | for (size_t i : var.def) 56 | workListPhi.push(i); 57 | for (size_t i : var.uses) 58 | workListSigma.push(i); 59 | 60 | while (!workListPhi.empty() || !workListSigma.empty()) //run until BOTH lists are empty: the phi pass can refill the sigma list and vice versa 61 | { 62 | while (!workListPhi.empty()) 63 | { 64 | size_t curBlock = workListPhi.front(); 65 | 
workListPhi.pop(); 66 | for (size_t frontier : domTree.getDomFrontier(curBlock)) 67 | { 68 | if (!blocks[frontier]->phi.count(varID)) 69 | { 70 | blocks[frontier]->phi.insert({ varID, Block::PhiIns(var.operand) }); 71 | if (!var.def.count(frontier) && !blocks[frontier]->sigma.count(varID)) 72 | workListPhi.push(frontier); 73 | if (!var.uses.count(frontier) && !blocks[frontier]->sigma.count(varID)) 74 | workListSigma.push(frontier); 75 | } 76 | } 77 | } 78 | while (!workListSigma.empty()) 79 | { 80 | size_t curBlock = workListSigma.front(); 81 | workListSigma.pop(); 82 | for (size_t frontier : postDomTree.getDomFrontier(curBlock)) 83 | { 84 | if (!blocks[frontier]->sigma.count(varID)) 85 | { 86 | blocks[frontier]->sigma.insert({ varID, Block::SigmaIns(var.operand) }); 87 | if (!var.def.count(frontier) && !blocks[frontier]->phi.count(varID)) 88 | workListPhi.push(frontier); 89 | if (!var.uses.count(frontier) && !blocks[frontier]->phi.count(varID)) 90 | workListSigma.push(frontier); 91 | } 92 | } 93 | } 94 | } 95 | } 96 | } 97 | //Number the blocks, collect every var that is read before being written inside some block (tracking maxVarID), then record the ids of the blocks that use/define those vars 98 | void SSAConstructor::preprocess() 99 | { 100 | maxVarID = 0; 101 | for (auto &param : func.params) 102 | maxVarID = std::max(maxVarID, param.val); 103 | func.inBlock->traverse([this](Block *block) -> bool 104 | { 105 | size_t myID = blocks.size(); 106 | mapBlock.insert({ block, myID }); 107 | blocks.push_back(block); 108 | 109 | std::set killed; 110 | for (auto &ins : block->ins) 111 | { 112 | auto inputReg = ins.getInputReg(); 113 | for (Operand *in : inputReg) 114 | { 115 | maxVarID = std::max(maxVarID, in->val); 116 | if (!killed.count(in->val)) 117 | vars.insert({ in->val, globalVar{} }); 118 | } 119 | 120 | auto outputReg = ins.getOutputReg(); 121 | for (Operand *out : outputReg) 122 | { 123 | maxVarID = std::max(maxVarID, out->val); 124 | killed.insert(out->val); 125 | } 126 | } 127 | return true; 128 | }); 129 | func.inBlock->traverse([this](Block *block) -> bool 130 | { 131 | size_t myID = mapBlock[block]; 132 | 133 | std::set killed; //NOTE(review): nothing is inserted into killed in this pass, so uses gets every block that reads the var, not only the upward-exposed reads the first pass filters on — confirm intended 
134 | for (auto &ins : block->ins) 135 | { 136 | auto inputReg = ins.getInputReg(); 137 | for (Operand *in : inputReg) 138 | { 139 | if (!vars.count(in->val) || killed.count(in->val)) 140 | continue; 141 | vars[in->val].uses.insert(myID); 142 | } 143 | 144 | auto outputReg = ins.getOutputReg(); 145 | for (Operand *out : outputReg) 146 | { 147 | if (!vars.count(out->val)) 148 | continue; 149 | if (vars[out->val].operand.type == Operand::empty) 150 | vars[out->val].operand = *out; 151 | else 152 | { 153 | assert(vars[out->val].operand.type == out->type); 154 | } 155 | vars[out->val].def.insert(myID); 156 | } 157 | } 158 | return true; 159 | }); 160 | } 161 | //Dominator tree of the CFG; when postDom is true every CFG edge is linked in reverse 162 | DomTree SSAConstructor::getDomTree(bool postDom) 163 | { 164 | DomTree domTree(blocks.size()); 165 | func.inBlock->traverse([&domTree, postDom, this](Block *block) -> bool 166 | { 167 | size_t myID = mapBlock.find(block)->second; 168 | if (block->brTrue) 169 | postDom ? domTree.link(mapBlock.find(block->brTrue.get())->second, myID) 170 | : domTree.link(myID, mapBlock.find(block->brTrue.get())->second); 171 | if (block->brFalse) 172 | postDom ? 
domTree.link(mapBlock.find(block->brFalse.get())->second, myID) 173 | : domTree.link(myID, mapBlock.find(block->brFalse.get())->second); 174 | return true; 175 | }); 176 | domTree.buildTree(mapBlock.find(func.inBlock.get())->second); //NOTE(review): rooted at func.inBlock even when postDom is true; a post-dominator tree is normally rooted at the exit block — confirm 177 | return domTree; 178 | } 179 | //Assign SSA versions: params first, then every phi dst / instruction def during a DFS of the CFG from func.inBlock 180 | void SSAConstructor::renameVar() 181 | { 182 | std::set visited; 183 | std::vector varCount(maxVarID + 1); 184 | std::vector> varCurVersion(maxVarID + 1); 185 | 186 | for (auto &param : func.params) 187 | { 188 | param.ver = ++varCount[param.val]; 189 | varCurVersion[param.val].push(param); 190 | } 191 | for (auto &var : varCurVersion) 192 | { 193 | if (var.empty()) 194 | var.push(EmptyOperand()); 195 | } 196 | 197 | renameVar(func.inBlock.get(), visited, varCount, varCurVersion); 198 | } 199 | 200 | void SSAConstructor::renameVar(Block *block, std::set &visited, std::vector &varCount, std::vector> &varCurVersion) 201 | { 202 | visited.insert(block); 203 | std::set killed; 204 | for (auto &phi : block->phi) 205 | { 206 | assert(phi.first == phi.second.dst.val); 207 | phi.second.dst.ver = ++varCount[phi.second.dst.val]; 208 | killed.insert(phi.second.dst.val); 209 | varCurVersion[phi.second.dst.val].push(phi.second.dst); 210 | } 211 | 212 | for (auto &ins : block->ins) 213 | { 214 | auto inputReg = ins.getInputReg(); 215 | for (auto &in : inputReg) 216 | { 217 | //assert(!varCurVersion[in->val].empty()); 218 | assert(varCurVersion[in->val].top().val == in->val); 219 | *in = varCurVersion[in->val].top().clone().setSize(in->size()); 220 | } 221 | 222 | for (Operand *out : ins.getOutputReg()) 223 | { 224 | out->ver = ++varCount[out->val]; 225 | 226 | if (!killed.count(out->val)) 227 | killed.insert(out->val); 228 | else 229 | varCurVersion[out->val].pop(); 230 | varCurVersion[out->val].push(*out); 231 | } 232 | } 233 | 234 | if (block->brTrue && block->brTrue == block->brFalse) 235 | { 236 | assert(!block->ins.empty()); 237 | assert(block->ins.back().oper == Br); 238 | block->ins.back() = IRJump(); 239 
| block->brFalse.reset(); 240 | } 241 | 242 | auto visitChild = [&block, &varCount, &varCurVersion, &visited](Block::block_ptr &child, Operand Block::SigmaIns::*sigmaDst) 243 | { 244 | for (auto &sigma : block->sigma) 245 | { 246 | Operand var = sigma.second.src; 247 | var.ver = ++varCount[var.val]; 248 | varCurVersion[var.val].push(var); 249 | sigma.second.*sigmaDst = var; 250 | } 251 | for (auto &phi : child->phi) 252 | { 253 | assert(phi.first == phi.second.dst.val); 254 | //assert(!varCurVersion[phi.second.dst.val].empty()); 255 | phi.second.srcs.push_back({ varCurVersion[phi.second.dst.val].top(), block->self }); 256 | } 257 | if (!visited.count(child.get())) 258 | renameVar(child.get(), visited, varCount, varCurVersion); 259 | for (auto &sigma : block->sigma) 260 | varCurVersion[sigma.second.src.val].pop(); 261 | }; 262 | if (block->brTrue) 263 | visitChild(block->brTrue, &Block::SigmaIns::dstTrue); 264 | if (block->brFalse) 265 | visitChild(block->brFalse, &Block::SigmaIns::dstFalse); 266 | 267 | for (size_t varid : killed) 268 | varCurVersion[varid].pop(); 269 | } 270 | //Run constructSSA on the content of every function in the program 271 | void SSAConstructor::constructSSA(MxProgram *program) 272 | { 273 | for (auto &func : program->vFuncs) 274 | { 275 | SSAConstructor ssa(func.content); 276 | ssa.constructSSA(); 277 | } 278 | } 279 | //Map each operand to its unique def (asserted) and all of its uses 280 | std::map SSAConstructor::calculateDefUse() 281 | { 282 | std::map ret; 283 | func.inBlock->traverse([&ret](Block *block) -> bool 284 | { 285 | for (InstructionBase &ins : block->instructions()) 286 | { 287 | for (Operand *operand : ins.getInputReg()) 288 | ret[*operand].uses.insert(std::make_pair(&ins, operand)); 289 | for (Operand *operand : ins.getOutputReg()) 290 | { 291 | assert(!ret[*operand].def.first); 292 | ret[*operand].def = { &ins, operand }; 293 | } 294 | } 295 | return true; 296 | }); 297 | return ret; 298 | } 299 | } -------------------------------------------------------------------------------- /src/SSAConstructor.h:
-------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_SSA_H 2 | #define MX_COMPILER_SSA_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "MxProgram.h" 7 | #include "utils/DomTree.h" 8 | 9 | namespace MxIR 10 | { 11 | class SSAConstructor 12 | { 13 | public: 14 | SSAConstructor(Function &func) : func(func) {} 15 | 16 | public: 17 | struct varDefUse 18 | { 19 | std::pair def; 20 | std::multimap uses; 21 | varDefUse() : def({ nullptr, nullptr }) {} 22 | }; 23 | void constructSSA(); 24 | void constructSSIFull(); 25 | std::map calculateDefUse(); 26 | 27 | static void constructSSA(MxProgram *program); 28 | 29 | protected: 30 | void preprocess(); 31 | DomTree getDomTree(bool postDom); 32 | void renameVar(); 33 | static void renameVar(Block *block, std::set &visited, std::vector &varCount, std::vector> &varCurVersion); 34 | 35 | protected: 36 | struct globalVar //global: in function 37 | { 38 | Operand operand; //type must be reg 39 | std::set def; //block id of definition 40 | std::set uses; 41 | }; 42 | 43 | size_t maxVarID; 44 | std::map vars; //varid (regid) -> globalVar 45 | std::vector blocks; //id -> block 46 | std::map mapBlock; //block -> id 47 | Function &func; 48 | }; 49 | } 50 | 51 | #endif -------------------------------------------------------------------------------- /src/SSAReconstructor.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include "SSAReconstructor.h" 3 | 4 | namespace MxIR 5 | { 6 | void SSAReconstructor::preprocess() 7 | { 8 | func.inBlock->traverse([this](Block *block) -> bool 9 | { 10 | property[block].idx = vBlock.size(); 11 | vBlock.push_back(block); 12 | return true; 13 | }); 14 | dtree = DomTree(vBlock.size()); 15 | for (Block *block : vBlock) 16 | for (Block *next : { block->brTrue.get(), block->brFalse.get() }) 17 | { 18 | if(next) 19 | dtree.link(property[block].idx, property[next].idx); 20 | } 21 
| assert(property[func.inBlock.get()].idx == 0); 22 | dtree.buildTree(0); 23 | } 24 | 25 | void SSAReconstructor::calcIDF(Block *block) 26 | { 27 | std::queue Q; 28 | for (size_t frontier : dtree.getDomFrontier(property[block].idx)) 29 | { 30 | Q.push(vBlock[frontier]); 31 | property[block].idf.insert(vBlock[frontier]); 32 | } 33 | 34 | while (!Q.empty()) 35 | { 36 | Block *cur = Q.front(); 37 | Q.pop(); 38 | for (size_t frontier : dtree.getDomFrontier(property[cur].idx)) 39 | if (!property[block].idf.count(vBlock[frontier])) 40 | { 41 | Q.push(vBlock[frontier]); 42 | property[block].idf.insert(vBlock[frontier]); 43 | } 44 | } 45 | 46 | property[block].visited = true; 47 | } 48 | 49 | void SSAReconstructor::reconstruct(const std::vector &vars) 50 | { 51 | curVer.clear(); 52 | defIDF.clear(); 53 | func.inBlock->traverse([&vars, this](Block *block) -> bool 54 | { 55 | std::set defs; 56 | property[block].latestVer.clear(); 57 | for (auto &ins : block->instructions()) 58 | { 59 | for (Operand *operand : ins.getOutputReg()) 60 | { 61 | for (size_t var : vars) 62 | if (operand->val == var) 63 | { 64 | if(!varOp.count(var)) 65 | varOp[var] = operand->clone().setPRegID(-1); 66 | else 67 | { 68 | assert(varOp[var].type == operand->type); 69 | } 70 | property[block].latestVer[var] = operand->ver = ++curVer[var]; 71 | defs.insert(var); 72 | } 73 | } 74 | } 75 | if (!defs.empty() && !property[block].visited) 76 | calcIDF(block); 77 | for (size_t def : defs) 78 | { 79 | for (Block *idf : property[block].idf) 80 | defIDF[def].insert(idf); 81 | } 82 | return true; 83 | }); 84 | 85 | func.inBlock->traverse([&vars, this](Block *block) -> bool 86 | { 87 | std::map latestVer; 88 | for (auto &kv : block->phi) 89 | { 90 | if (kv.second.dst.ver != 0) 91 | latestVer[kv.second.dst.val] = kv.second.dst.ver; 92 | for (auto &src : kv.second.srcs) 93 | { 94 | for(size_t var : vars) 95 | if (src.first.val == var) 96 | { 97 | size_t version = findDefFromBottom(var, src.second.lock().get()); 98 
| src.first.ver = version; 99 | } 100 | } 101 | } 102 | 103 | for (auto &ins : block->ins) 104 | { 105 | for (Operand *operand : ins.getInputReg()) 106 | { 107 | for(size_t var : vars) 108 | if (operand->val == var) 109 | { 110 | if (latestVer.count(var)) 111 | operand->ver = latestVer[var]; 112 | else 113 | operand->ver = findDefFromTop(var, block); 114 | } 115 | } 116 | for (Operand *operand : ins.getOutputReg()) 117 | { 118 | if (operand->ver != 0) 119 | latestVer[operand->val] = operand->ver; 120 | } 121 | } 122 | return true; 123 | }); 124 | } 125 | 126 | size_t SSAReconstructor::findDefFromBottom(size_t var, Block *block) 127 | { 128 | if (property[block].latestVer.count(var)) 129 | return property[block].latestVer[var]; 130 | else 131 | return findDefFromTop(var, block); 132 | } 133 | 134 | size_t SSAReconstructor::findDefFromTop(size_t var, Block *block) 135 | { 136 | if (defIDF[var].count(block)) 137 | { 138 | if (block->phi.count(var)) 139 | return block->phi[var].dst.ver; 140 | Block::PhiIns phi; 141 | phi.dst = varOp[var]; 142 | phi.dst.ver = ++curVer[var]; 143 | if(!property[block].latestVer.count(var)) 144 | property[block].latestVer[var] = phi.dst.ver; 145 | for (Block *pred : block->preds) 146 | { 147 | Operand src = varOp[var]; 148 | src.ver = findDefFromBottom(var, pred); 149 | phi.srcs.push_back(std::make_pair(src, pred->self)); 150 | } 151 | block->phi.insert(std::make_pair(var, phi)); 152 | return phi.dst.ver; 153 | } 154 | else 155 | return findDefFromBottom(var, vBlock[dtree.getIdom(property[block].idx)]); 156 | } 157 | 158 | void SSAReconstructor::reconstructAuto() 159 | { 160 | std::vector vars; 161 | std::map freq; 162 | func.inBlock->traverse([&freq, &vars](Block *block) -> bool 163 | { 164 | for (auto &ins : block->instructions()) 165 | { 166 | for (Operand *operand : ins.getOutputReg()) 167 | { 168 | if (operand->val == Operand::InvalidID) 169 | continue; 170 | freq[operand->val]++; 171 | if (freq[operand->val] == 2) 172 | 
vars.push_back(operand->val); 173 | } 174 | } 175 | return true; 176 | }); 177 | 178 | reconstruct(vars); 179 | } 180 | } -------------------------------------------------------------------------------- /src/SSAReconstructor.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_SSA_RECONSTRUCTOR_H 2 | #define MX_COMPILER_SSA_RECONSTRUCTOR_H 3 | 4 | #include "common.h" 5 | #include "IR.h" 6 | #include "utils/DomTree.h" 7 | 8 | namespace MxIR 9 | { 10 | class SSAReconstructor 11 | { 12 | public: 13 | explicit SSAReconstructor(Function &func) : func(func) {} 14 | void preprocess(); 15 | void reconstruct(const std::vector &vars); //assume all vreg in IR has no version information 16 | void reconstructAuto(); //reconstruct ssa by all vars that have duplicated definitions 17 | 18 | protected: 19 | void calcIDF(Block *block); 20 | size_t findDefFromBottom(size_t var, Block *block); //return the latest version 21 | size_t findDefFromTop(size_t var, Block *block); 22 | 23 | protected: 24 | struct BlockProperty 25 | { 26 | size_t idx; 27 | std::set idf; //iterated dominance frontier 28 | bool visited; //visited when calculating idf 29 | std::map latestVer; 30 | 31 | BlockProperty() : idx(size_t(-1)), visited(false) {} 32 | }; 33 | std::map property; 34 | std::vector vBlock; 35 | 36 | DomTree dtree; 37 | Function &func; 38 | 39 | std::map varOp; 40 | std::map curVer; 41 | std::map> defIDF; //the union of idf of each block that defines the target var 42 | }; 43 | } 44 | 45 | #endif -------------------------------------------------------------------------------- /src/StaticTypeChecker.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_VARIABLE_CHECKER_H 2 | #define MX_COMPILER_VARIABLE_CHECKER_H 3 | 4 | #include "MxProgram.h" 5 | #include "IssueCollector.h" 6 | #include "GlobalSymbol.h" 7 | #include "AST.h" 8 | 9 | class StaticTypeChecker : public MxAST::ASTListener 10 
| { 11 | public: 12 | StaticTypeChecker(MxProgram *memTable, GlobalSymbol *symbols, IssueCollector *issues); 13 | bool preCheck(MxAST::ASTRoot *root); 14 | 15 | protected: 16 | bool checkFunc(MxAST::ASTDeclFunc *declFunc, 17 | std::map &mapVarId, std::vector &varTable, 18 | size_t className); 19 | bool checkVar(MxAST::ASTDeclVar *declVar, std::map &mapVarId, std::vector &varTable); 20 | bool checkType(MxType type, ssize_t tokenL, ssize_t tokenR); 21 | 22 | size_t findConstructor(size_t className, const std::vector &vTypes); 23 | size_t findOverloadedFunc(size_t olid, const std::vector &vTypes); //return -1 for not found, -2 for ambiguous call 24 | 25 | virtual MxAST::ASTNode * enter(MxAST::ASTDeclClass *declClass) override; 26 | virtual MxAST::ASTNode * leave(MxAST::ASTDeclClass *declClass) override; 27 | virtual MxAST::ASTNode * enter(MxAST::ASTDeclFunc *declFunc) override; 28 | virtual MxAST::ASTNode * leave(MxAST::ASTDeclFunc *declFunc) override; 29 | virtual MxAST::ASTNode * enter(MxAST::ASTDeclVarGlobal *declVar) override; 30 | virtual MxAST::ASTNode * enter(MxAST::ASTDeclVarLocal *declVar) override; 31 | virtual MxAST::ASTNode * leave(MxAST::ASTDeclVar *declVar) override; 32 | virtual MxAST::ASTNode * enter(MxAST::ASTBlock *block) override; 33 | virtual MxAST::ASTNode * leave(MxAST::ASTBlock *block) override; 34 | virtual MxAST::ASTNode * enter(MxAST::ASTExprVar *var) override; 35 | virtual MxAST::ASTNode * leave(MxAST::ASTExprImm *imm) override; 36 | virtual MxAST::ASTNode * leave(MxAST::ASTExprUnary *expr) override; 37 | virtual MxAST::ASTNode * leave(MxAST::ASTExprBinary *expr) override; 38 | virtual MxAST::ASTNode * leave(MxAST::ASTExprAssignment *expr) override; 39 | virtual MxAST::ASTNode * leave(MxAST::ASTExprNew *expr) override; 40 | virtual MxAST::ASTNode * leave(MxAST::ASTExprSubscriptAccess *expr) override; 41 | virtual MxAST::ASTNode * leave(MxAST::ASTExprMemberAccess *expr) override; 42 | virtual MxAST::ASTNode * leave(MxAST::ASTExprFuncCall 
*expr) override; 43 | virtual MxAST::ASTNode * leave(MxAST::ASTStatementReturn *stat) override; 44 | virtual MxAST::ASTNode * leave(MxAST::ASTStatementBreak *stat) override; 45 | virtual MxAST::ASTNode * leave(MxAST::ASTStatementContinue *stat) override; 46 | virtual MxAST::ASTNode * leave(MxAST::ASTStatementIf *stat) override; 47 | virtual MxAST::ASTNode * enter(MxAST::ASTStatementWhile *stat) override; 48 | virtual MxAST::ASTNode * leave(MxAST::ASTStatementWhile *stat) override; 49 | virtual MxAST::ASTNode * enter(MxAST::ASTStatementFor *stat) override; 50 | virtual MxAST::ASTNode * leave(MxAST::ASTStatementFor *stat) override; 51 | protected: 52 | MxProgram *program; 53 | IssueCollector *issues; 54 | GlobalSymbol *symbols; 55 | 56 | std::map> mapClassMemberId; //class name -> { member name -> var id } 57 | std::map mapGlobalVar; 58 | 59 | std::vector vLocalVar; 60 | std::map> mapLocalVar; //name -> local var id 61 | std::stack> stkCurrentBlockVar; 62 | 63 | size_t curClass, curFunc; 64 | size_t depthLoop; 65 | }; 66 | 67 | #endif -------------------------------------------------------------------------------- /src/common.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_COMMON_H 2 | #define MX_COMPILER_COMMON_H 3 | 4 | #if defined(_DEBUG) && !defined(NDEBUG) 5 | #define IS_DEBUG 6 | #define IF_DEBUG(x) x 7 | #define IFNOT_DEBUG(x) 8 | #else 9 | #define IF_DEBUG(x) 10 | #define IFNOT_DEBUG(x) x 11 | #endif 12 | 13 | #include "common_headers.h" 14 | 15 | constexpr size_t MAX_ERROR = 100; 16 | constexpr size_t MAX_STRINGSIZE = 100000; 17 | constexpr size_t MAX_STRINGMEMUSAGE = 10000000; 18 | constexpr size_t POINTER_SIZE = 8; 19 | 20 | enum ValueType 21 | { 22 | lvalue, 23 | xvalue, //expiring value 24 | rvalue //rvalue or constant 25 | }; 26 | 27 | struct MxType 28 | { 29 | enum type 30 | { 31 | Void, Bool, Integer, String, Object, Function 32 | }; 33 | type mainType; 34 | size_t arrayDim; //0 for non-array 
35 | size_t className; //-1 for undetermined object (null) 36 | size_t funcOLID; 37 | IF_DEBUG(std::string strClassName); 38 | 39 | bool isNull() const 40 | { 41 | return mainType == Object && className == size_t(-1); 42 | } 43 | bool isObject() const 44 | { 45 | return mainType == Object || mainType == String || arrayDim > 0; 46 | } 47 | size_t getSize() const 48 | { 49 | if (arrayDim > 0) 50 | return POINTER_SIZE; 51 | if (mainType == Void) 52 | return 0; 53 | if (mainType == Bool) 54 | return 1; 55 | if (mainType == Integer) 56 | return 4; 57 | return POINTER_SIZE; 58 | } 59 | bool operator==(const MxType &rhs) const 60 | { 61 | if (rhs.isNull()) 62 | return rhs == *this; 63 | if (isNull() && (rhs.mainType == Object || rhs.arrayDim > 0)) 64 | return true; 65 | if (mainType != rhs.mainType || arrayDim != rhs.arrayDim) 66 | return false; 67 | if (mainType == Object) 68 | return className == rhs.className; 69 | if (mainType == Function) 70 | return funcOLID == rhs.funcOLID; 71 | return true; 72 | } 73 | bool operator!=(const MxType &rhs) const 74 | { 75 | return !(*this == rhs); 76 | } 77 | 78 | static MxType Null() 79 | { 80 | return MxType{ Object, 0, size_t(-1) }; 81 | } 82 | }; 83 | 84 | class CompileFlags 85 | { 86 | public: 87 | bool disable_access_protect = false; 88 | bool optim_register_allocation = false; 89 | bool optim_inline = false; 90 | bool optim_loop_invariant = false; 91 | bool optim_dead_code = false; 92 | bool optim_gvn = false; 93 | bool gvn_strict_equal = false; 94 | int inline_param = 1000, inline_param2 = 25; 95 | 96 | static CompileFlags * getInstance() 97 | { 98 | static CompileFlags instance; 99 | return &instance; 100 | } 101 | private: 102 | CompileFlags() : disable_access_protect(false) {} 103 | CompileFlags(const CompileFlags &other) = delete; 104 | }; 105 | 106 | inline std::string transferHTML(const std::string &in) 107 | { 108 | std::string ret; 109 | for (char c : in) 110 | { 111 | if (c == '<') 112 | ret += "<"; 113 | else if (c 
== '>') 114 | ret += ">"; 115 | else if (c == ' ') 116 | ret += " "; 117 | else if (c == '"') 118 | ret += """; 119 | else if (c == '&') 120 | ret += "&"; 121 | else if (c == '\n') 122 | ret += "
"; 123 | else 124 | ret += c; 125 | } 126 | return ret; 127 | } 128 | 129 | //align addr to align bytes 130 | constexpr std::uint64_t alignAddr(std::uint64_t addr, std::uint64_t align) 131 | { 132 | return (addr + align - 1) / align * align; 133 | } 134 | 135 | inline void prints(std::ostream &out) {} 136 | 137 | template 138 | void prints(std::ostream &out, Tnow &&now, T&&... val) 139 | { 140 | out << std::forward(now); 141 | prints(out, std::forward(val)...); 142 | } 143 | 144 | 145 | 146 | #endif -------------------------------------------------------------------------------- /src/common_headers.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include "common_headers.h" 2 | #include 3 | #include 4 | #include "AST.h" 5 | #include "ASTConstructor.h" 6 | #include "IssueCollector.h" 7 | #include "StaticTypeChecker.h" 8 | #include "MxBuiltin.h" 9 | #include "ConstantFold.h" 10 | #include "IRGenerator.h" 11 | #include "CodeGeneratorBasic.h" 12 | #include "option_parser.h" 13 | #include "SSAConstructor.h" 14 | #include "CodeGenerator.h" 15 | #include "InlineOptimizer.h" 16 | #include "LoopInvariantOptimizer.h" 17 | #include "DeadCodeElimination.h" 18 | #include "GVN.h" 19 | #include "LoadCombine.h" 20 | using namespace std; 21 | 22 | int compile(const std::string &fileName, const std::string &output) 23 | { 24 | antlr4::ANTLRFileStream fin(fileName); 25 | MxLexer lexer(&fin); 26 | antlr4::CommonTokenStream tokens(&lexer); 27 | MxParser parser(&tokens); 
28 | auto prog = parser.prog(); 29 | 30 | if (lexer.getNumberOfSyntaxErrors() > 0 || parser.getNumberOfSyntaxErrors() > 0) 31 | return 1; 32 | IssueCollector ic(IssueCollector::NOTICE, &cerr, &tokens, fileName); 33 | ic.setDefault(); 34 | 35 | ASTConstructor constructor(&ic); 36 | GlobalSymbol symbol; 37 | symbol.setDefault(); 38 | 39 | try 40 | { 41 | MxProgram program; 42 | program.setDefault(); 43 | 44 | MxBuiltin builtin; 45 | builtin.setDefault(); 46 | builtin.init(); 47 | 48 | std::unique_ptr root(constructor.constructAST(prog, &symbol)); 49 | 50 | StaticTypeChecker checker(&program, &symbol, &ic); 51 | if (!checker.preCheck(root.get())) 52 | return 2; 53 | root->recursiveAccess(&checker); 54 | 55 | if (ic.cntError > 0) 56 | return 2; 57 | 58 | ASTOptimizer::ConstantFold cfold; 59 | root->recursiveAccess(&cfold); 60 | 61 | IRGenerator irgen; 62 | irgen.generateProgram(root.get()); 63 | 64 | if (ic.cntError > 0) 65 | return 2; 66 | 67 | if (CompileFlags::getInstance()->optim_inline) 68 | { 69 | MxIR::InlineOptimizer optim; 70 | optim.work(); 71 | } 72 | 73 | if (CompileFlags::getInstance()->optim_register_allocation) 74 | { 75 | MxIR::SSAConstructor::constructSSA(&program); 76 | if (CompileFlags::getInstance()->optim_gvn) 77 | { 78 | for (auto &func : program.vFuncs) 79 | { 80 | MxIR::GVN optim(func.content); 81 | optim.work(); 82 | 83 | MxIR::LoadCombine loadcombine(func.content); 84 | loadcombine.work(); 85 | 86 | MxIR::GVN optim2(func.content); 87 | optim2.work(); 88 | } 89 | } 90 | if (CompileFlags::getInstance()->optim_dead_code) 91 | { 92 | for (auto &func : program.vFuncs) 93 | { 94 | MxIR::DeadCodeElimination optim(func.content); 95 | optim.work(); 96 | } 97 | } 98 | if (CompileFlags::getInstance()->optim_loop_invariant) 99 | { 100 | for (auto &func : program.vFuncs) 101 | { 102 | MxIR::LoopInvariantOptimizer optim(func.content); 103 | optim.work(); 104 | } 105 | } 106 | std::ofstream fout(output); 107 | CodeGenerator codegen(fout); 108 | 
codegen.generateProgram(); 109 | } 110 | else 111 | { 112 | std::ofstream fout(output); 113 | CodeGeneratorBasic codegen(fout); 114 | codegen.generateProgram(); 115 | } 116 | } 117 | catch (IssueCollector::FatalErrorException &) 118 | { 119 | return 2; 120 | } 121 | if (ic.cntError > 0) 122 | return 2; 123 | return 0; 124 | } 125 | 126 | int main(int argc, char *argv[]) 127 | { 128 | int ret; 129 | std::string input, output; 130 | std::tie(ret, input, output) = ParseOptions(argc, argv); 131 | if (input.empty() || output.empty()) 132 | return ret; 133 | ret = compile(input, output); 134 | return ret; 135 | } -------------------------------------------------------------------------------- /src/option_parser.cpp: -------------------------------------------------------------------------------- 1 | #include "option_parser.h" 2 | #include 3 | 4 | std::tuple ParseOptions(int argc, char *argv[]) 5 | { 6 | using namespace boost::program_options; 7 | options_description options("Options:"); 8 | options.add_options() 9 | ("help,h", "Display this information") 10 | ("input", value>()->value_name("file"), "Input file") 11 | ("output,o", value()->value_name("file"), "Place the output into ") 12 | ("fdisable-access-protect", "Set the flag of disable access protect") 13 | ("optim-reg-alloc", "Optimize the register allocation") 14 | ("optim-inline", "enable inline expansion") 15 | ("optim-loop-invariant", "enable loop invariant optimization") 16 | ("optim-dead-code", "enable dead code elimination") 17 | ("optim-gvn", "enable global value numbering") 18 | ("inline-param", value()->value_name("param"), "the parameter for inline optimizer") 19 | ("inline-param2", value()->value_name("param"), "the parameter 2 for inline optimizer");; 20 | 21 | positional_options_description po; 22 | po.add("input", 1); 23 | 24 | variables_map vm; 25 | 26 | try 27 | { 28 | store(command_line_parser(argc, argv).options(options).positional(po).run(), vm); 29 | notify(vm); 30 | } 31 | catch 
(std::exception &e) 32 | { 33 | std::cerr << e.what() << std::endl; 34 | return std::make_tuple(1, "", ""); 35 | } 36 | 37 | if (vm.count("help")) 38 | { 39 | std::cout << "Usage: " << argv[0] << " [options] file" << std::endl; 40 | std::cout << options; 41 | return std::make_tuple(0, "", ""); 42 | } 43 | if (vm.count("fdisable-access-protect")) 44 | CompileFlags::getInstance()->disable_access_protect = true; 45 | if (vm.count("optim-reg-alloc")) 46 | CompileFlags::getInstance()->optim_register_allocation = true; 47 | if (vm.count("optim-inline")) 48 | CompileFlags::getInstance()->optim_inline = true; 49 | if (vm.count("inline-param")) 50 | CompileFlags::getInstance()->inline_param = vm["inline-param"].as(); 51 | if (vm.count("inline-param2")) 52 | CompileFlags::getInstance()->inline_param2 = vm["inline-param2"].as(); 53 | if (vm.count("optim-loop-invariant")) 54 | CompileFlags::getInstance()->optim_loop_invariant = true; 55 | if (vm.count("optim-dead-code")) 56 | CompileFlags::getInstance()->optim_dead_code = true; 57 | if (vm.count("optim-gvn")) 58 | CompileFlags::getInstance()->optim_gvn = true; 59 | if (!vm.count("input")) 60 | { 61 | std::cerr << argv[0] << ": no input file" << std::endl; 62 | return std::make_tuple(1, "", ""); 63 | } 64 | std::string input = vm["input"].as>()[0]; 65 | std::string output = vm.count("output") ? 
vm["output"].as() : "a.out"; 66 | return std::make_tuple(0, input, output); 67 | } -------------------------------------------------------------------------------- /src/option_parser.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_OPTION_PARSER_H 2 | #define MX_COMPILER_OPTION_PARSER_H 3 | 4 | #include "common.h" 5 | 6 | std::tuple ParseOptions(int argc, char *argv[]); 7 | 8 | #endif -------------------------------------------------------------------------------- /src/utils/CycleEquiv.cpp: -------------------------------------------------------------------------------- 1 | #include "../common_headers.h" 2 | #include "CycleEquiv.h" 3 | 4 | void CycleEquiv::dfs(size_t idx, size_t parent) 5 | { 6 | V[idx].visited = true; 7 | V[idx].min_dfn = V[idx].dfn = dfv.size(); 8 | dfv.push_back(idx); 9 | for (size_t i = 0; i < V[idx].to.size(); i++) 10 | { 11 | size_t child = V[idx].to[i]; 12 | if (child == parent) 13 | continue; 14 | if (V[child].visited) 15 | { 16 | E[V[idx].edges[i]].backward = true; 17 | E[V[idx].edges[i]].min_dfn = std::min(V[idx].dfn, V[child].dfn); 18 | V[idx].min_dfn = std::min(V[idx].min_dfn, V[child].dfn); 19 | continue; 20 | } 21 | dfs(child, idx); 22 | if (V[child].min_dfn < V[idx].min_dfn) 23 | V[idx].min_dfn = V[child].min_dfn; 24 | } 25 | } 26 | 27 | void CycleEquiv::dfs2(size_t idx, size_t upperEdge, std::list &bracketList) 28 | { 29 | V[idx].visited = true; 30 | size_t nChild = 0; 31 | for (size_t i = 0; i < V[idx].to.size(); i++) 32 | { 33 | size_t child = V[idx].to[i]; 34 | if (V[child].visited) 35 | continue; 36 | nChild++; 37 | std::list tmp; 38 | dfs2(child, V[idx].edges[i], tmp); 39 | bracketList.splice(bracketList.end(), tmp); 40 | } 41 | for (size_t e : V[idx].edges) 42 | { 43 | if (!E[e].backward) 44 | continue; 45 | if (E[e].min_dfn == V[idx].dfn) 46 | bracketList.erase(E[e].iter); 47 | else 48 | { 49 | assert(E[e].min_dfn < V[idx].dfn); 50 | bracketList.push_front(e); 51 | 
E[e].iter = bracketList.begin(); 52 | } 53 | } 54 | if (nChild > 1) 55 | { 56 | size_t min_dfn = SIZE_MAX, min_dfn2 = SIZE_MAX; 57 | for (size_t i = 0; i < V[idx].to.size(); i++) 58 | { 59 | size_t child = V[idx].to[i]; 60 | if (E[V[idx].edges[i]].backward || V[child].dfn < V[idx].dfn) 61 | continue; 62 | if (V[child].min_dfn < min_dfn) 63 | min_dfn2 = min_dfn, min_dfn = V[child].min_dfn; 64 | else if (V[child].min_dfn < min_dfn2) 65 | min_dfn2 = V[child].min_dfn; 66 | } 67 | if (min_dfn2 < V[idx].dfn) 68 | { 69 | size_t e = E.size(); 70 | E.push_back(edge{ true, min_dfn2 }); 71 | V[dfv[min_dfn2]].edges.push_back(e); 72 | bracketList.push_front(e); 73 | E[e].iter = bracketList.begin(); 74 | } 75 | } 76 | E[upperEdge].name = { bracketList.front(), bracketList.size() }; 77 | } 78 | 79 | std::vector> CycleEquiv::work() 80 | { 81 | nEdge = E.size(); 82 | dfv.clear(); 83 | for (vertex &vtx : V) 84 | vtx.visited = false; 85 | dfs(0, SIZE_MAX); 86 | for (vertex &vtx : V) 87 | vtx.visited = false; 88 | V[0].visited = true; 89 | 90 | for (size_t i = 0; i < V[0].to.size(); i++) 91 | { 92 | if (V[V[0].to[i]].visited) 93 | continue; 94 | std::list bracketList; 95 | dfs2(V[0].to[i], V[0].edges[i], bracketList); 96 | } 97 | 98 | std::map> equClass; 99 | for (size_t i = 0; i < nEdge; i++) 100 | { 101 | if (E[i].backward) 102 | equClass[edge::compactName{ i, 1 }].insert(i); 103 | else 104 | equClass[E[i].name].insert(i); 105 | } 106 | 107 | std::vector> ret; 108 | for (auto &kv : equClass) 109 | ret.emplace_back(std::move(kv.second)); 110 | return ret; 111 | } -------------------------------------------------------------------------------- /src/utils/CycleEquiv.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_CYCLE_EQUIV_H 2 | #define MX_COMPILER_UTILS_CYCLE_EQUIV_H 3 | 4 | #include "../common.h" 5 | 6 | class CycleEquiv 7 | { 8 | public: 9 | CycleEquiv(size_t nVertex) : V(nVertex), dfv(nVertex) {} 10 | size_t 
addEdge(size_t u, size_t v) 11 | { 12 | if (u != v) 13 | { 14 | V[u].to.push_back(v); 15 | V[u].edges.push_back(E.size()); 16 | V[v].to.push_back(u); 17 | V[v].edges.push_back(E.size()); 18 | } 19 | E.emplace_back(); 20 | return E.size() - 1; 21 | } 22 | std::vector> work(); 23 | 24 | protected: 25 | void dfs(size_t idx, size_t parent); 26 | void dfs2(size_t idx, size_t upperEdge, std::list &bracketList); 27 | 28 | protected: 29 | struct edge 30 | { 31 | bool backward = false; 32 | size_t min_dfn; 33 | std::list::iterator iter; 34 | 35 | struct compactName 36 | { 37 | size_t listFront = size_t(-1); size_t listSize = size_t(-1); 38 | bool operator<(const compactName &rhs) const 39 | { 40 | if (listFront == rhs.listFront) 41 | return listSize < rhs.listSize; 42 | return listFront < rhs.listFront; 43 | } 44 | }; 45 | compactName name; 46 | }; 47 | struct vertex 48 | { 49 | std::vector edges; 50 | std::vector to; 51 | size_t dfn; 52 | size_t min_dfn; 53 | bool visited; 54 | }; 55 | std::vector V; 56 | std::vector dfv; //dfv[i]: vertex with dfn i 57 | std::vector E; 58 | std::list bracketList; 59 | size_t nEdge; 60 | }; 61 | 62 | #endif -------------------------------------------------------------------------------- /src/utils/DepGraph.cpp: -------------------------------------------------------------------------------- 1 | #include "../common_headers.h" 2 | #include "DepGraph.h" 3 | 4 | void DepGraph::work() 5 | { 6 | std::stack S; 7 | trajan(0, S); 8 | sort(); 9 | } 10 | 11 | void DepGraph::trajan(size_t idx, std::stack &S) 12 | { 13 | V[idx].visited = true; 14 | V[idx].lowlink = V[idx].dfn = dfsclock++; 15 | S.push(idx); 16 | 17 | for (size_t next : V[idx].to) 18 | { 19 | if (!V[next].visited) 20 | { 21 | trajan(next, S); 22 | V[idx].lowlink = std::min(V[idx].lowlink, V[next].lowlink); 23 | } 24 | else if (V[next].groupID == -1) 25 | { 26 | V[idx].lowlink = std::min(V[idx].lowlink, V[next].dfn); 27 | } 28 | } 29 | 30 | if (V[idx].lowlink == V[idx].dfn) 31 | { 32 | 
VGroup.emplace_back(); 33 | while (true) 34 | { 35 | size_t v = S.top(); 36 | S.pop(); 37 | VGroup.back().vtxs.push_back(v); 38 | V[v].groupID = VGroup.size() - 1; 39 | 40 | if (v == idx) 41 | break; 42 | } 43 | } 44 | } 45 | 46 | void DepGraph::sort() 47 | { 48 | for (vertex &v : V) 49 | for (size_t next : v.to) 50 | VGroup[V[next].groupID].indegree++; 51 | 52 | std::queue worklist; 53 | for (size_t i = 0; i < VGroup.size(); i++) 54 | if (VGroup[i].indegree == 0) 55 | worklist.push(i); 56 | 57 | size_t sorted = 0; 58 | while (!worklist.empty()) 59 | { 60 | size_t cur = worklist.front(); 61 | worklist.pop(); 62 | 63 | for (size_t vtx : VGroup[cur].vtxs) 64 | for (size_t next : V[vtx].to) 65 | { 66 | VGroup[V[next].groupID].indegree--; 67 | if (VGroup[V[next].groupID].indegree == 0) 68 | worklist.push(V[next].groupID); 69 | } 70 | 71 | std::swap(VGroup[sorted++], VGroup[cur]); 72 | } 73 | } -------------------------------------------------------------------------------- /src/utils/DepGraph.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_DEP_GRAPH_H 2 | #define MX_COMPILER_UTILS_DEP_GRAPH_H 3 | 4 | #include "common.h" 5 | #include 6 | 7 | class DepGraph 8 | { 9 | public: 10 | DepGraph(size_t cntVertex) : V(cntVertex) {} 11 | void link(size_t u, size_t v) 12 | { 13 | V.at(u).to.push_back(v); 14 | V.at(v).to.push_back(u); 15 | } 16 | void work(); 17 | const std::vector & getVertex(size_t i) 18 | { 19 | return VGroup.at(i).vtxs; 20 | } 21 | 22 | protected: 23 | struct vertex 24 | { 25 | std::vector to; 26 | 27 | bool visited = false; 28 | size_t dfn, lowlink; 29 | ssize_t groupID = -1; 30 | }; 31 | struct vtxGroup 32 | { 33 | std::vector vtxs; 34 | size_t indegree = 0; 35 | }; 36 | 37 | protected: 38 | void trajan(size_t idx, std::stack &S); 39 | void sort(); 40 | 41 | protected: 42 | std::vector V; 43 | 44 | size_t dfsclock; 45 | std::vector VGroup; 46 | }; 47 | 48 | #endif 
-------------------------------------------------------------------------------- /src/utils/DispatchLength.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_DISPATCH_LENGTH_H 2 | #define MX_COMPILER_UTILS_DISPATCH_LENGTH_H 3 | 4 | #include "../common.h" 5 | 6 | template 7 | auto dispatch_length(F &&f, uint64_t value, size_t length, bool sign) 8 | { 9 | if (!sign) 10 | { 11 | if (length == 1) 12 | return f((uint8_t)value); 13 | else if (length == 2) 14 | return f((uint16_t)value); 15 | else if (length == 4) 16 | return f((uint32_t)value); 17 | else 18 | { 19 | assert(length == 8); 20 | return f((uint64_t)value); 21 | } 22 | } 23 | else 24 | { 25 | if (length == 1) 26 | return f((int8_t)value); 27 | else if (length == 2) 28 | return f((int16_t)value); 29 | else if (length == 4) 30 | return f((int32_t)value); 31 | else 32 | { 33 | assert(length == 8); 34 | return f((int64_t)value); 35 | } 36 | } 37 | } 38 | 39 | template 40 | auto dispatch_length(F &&f, uint64_t value, size_t length, bool sign, T &&...other) 41 | { 42 | return dispatch_length([&](auto ...valN) 43 | { 44 | return dispatch_length([&](auto val1) 45 | { 46 | return f(val1, valN...); 47 | }, value, length, sign); 48 | }, other...); 49 | } 50 | 51 | #endif -------------------------------------------------------------------------------- /src/utils/DomTree.cpp: -------------------------------------------------------------------------------- 1 | #include "../common_headers.h" 2 | #include "DomTree.h" 3 | 4 | void DomTree::buildTree(size_t root) 5 | { 6 | for (size_t idx = 0; idx < V.size(); idx++) 7 | V[idx].visited = false, V[idx].df.clear(); 8 | dfs(root); 9 | for (size_t idx = 0; idx < V.size(); idx++) 10 | idv[V[idx].dfn] = idx; 11 | calcIdom(); 12 | for (size_t idx = 0; idx < V.size(); idx++) 13 | if(V[idx].idom != idx) 14 | V[V[idx].idom].children.push_back(idx); 15 | IF_DEBUG(verifyIdom()); 16 | calcDomFrontier(); 17 | } 18 | 19 | void 
DomTree::dfs(size_t idx) 20 | { 21 | V[idx].dfn = idv.size(); 22 | V[idx].visited = true; 23 | V[idx].semi = size_t(-1); 24 | V[idx].idom = size_t(-1); 25 | idv.push_back(idx); 26 | for (size_t next : V[idx].to) 27 | { 28 | if (V[next].visited) 29 | continue; 30 | dfs(next); 31 | } 32 | } 33 | 34 | void DomTree::calcIdom() 35 | { 36 | struct ufsnode //NOTE: non-standard UFS 37 | { 38 | size_t root; 39 | size_t min_semi_point; 40 | }; 41 | std::vector ufs(V.size()); 42 | for (size_t i = 0; i < V.size(); i++) 43 | ufs[i] = { i, i }; 44 | 45 | std::function ufs_find_root; 46 | ufs_find_root = [&ufs, &ufs_find_root, this](size_t idx) 47 | { 48 | if (ufs[idx].root == idx) 49 | return idx; 50 | size_t root = ufs_find_root(ufs[idx].root); 51 | size_t curSemi = V[ufs[idx].min_semi_point].semi; 52 | size_t parSemi = V[ufs[ufs[idx].root].min_semi_point].semi; 53 | if (V[parSemi].dfn < V[curSemi].dfn) 54 | ufs[idx].min_semi_point = ufs[ufs[idx].root].min_semi_point; 55 | return ufs[idx].root = root; 56 | }; 57 | auto ufs_merge = [&ufs, &ufs_find_root](size_t fa, size_t child) 58 | { 59 | assert(ufs[child].root == child); 60 | ufs[child].root = fa; 61 | }; 62 | 63 | std::vector> semiDomList(V.size()); 64 | 65 | for (ssize_t dfn = idv.size() - 1; dfn >= 0; dfn--) 66 | { 67 | size_t idx = idv[dfn]; 68 | if (dfn != 0) 69 | { 70 | size_t semi_dfn = SIZE_MAX; 71 | for (size_t prev : V[idx].from) 72 | { 73 | if (V[prev].dfn <= V[idx].dfn) 74 | semi_dfn = std::min(semi_dfn, V[prev].dfn); 75 | else 76 | { 77 | ufs_find_root(prev); 78 | size_t minSemi = V[ufs[prev].min_semi_point].semi; 79 | if (V[minSemi].dfn < semi_dfn) 80 | semi_dfn = V[minSemi].dfn; 81 | } 82 | } 83 | assert(semi_dfn != SIZE_MAX); 84 | V[idx].semi = idv[semi_dfn]; 85 | } 86 | else 87 | V[idx].semi = idx; 88 | 89 | semiDomList[V[idx].semi].push_back(idx); 90 | 91 | for (size_t dom_child : semiDomList[idx]) 92 | { 93 | ufs_find_root(dom_child); 94 | size_t u = ufs[dom_child].min_semi_point; 95 | if (V[u].semi == idx) 
96 | { 97 | V[dom_child].idom = idx; 98 | assert(V[dom_child].semi == V[dom_child].idom); 99 | } 100 | else 101 | { 102 | V[dom_child].idom = u; 103 | assert(V[dom_child].semi != V[dom_child].idom); 104 | } 105 | } 106 | 107 | for (size_t next : V[idx].to) 108 | { 109 | if (V[next].dfn < V[idx].dfn) //backward edge & cross edge 110 | continue; 111 | if (ufs[next].root != next) //forward edge 112 | continue; 113 | ufs_merge(idx, next); //tree edge 114 | } 115 | } 116 | 117 | for (size_t dfn = 0; dfn < V.size(); dfn++) 118 | { 119 | size_t idx = idv[dfn]; 120 | if (V[idx].semi != V[idx].idom) 121 | V[idx].idom = V[V[idx].idom].idom; 122 | } 123 | } 124 | 125 | void DomTree::calcDomFrontier() 126 | { 127 | std::vector> frontier(V.size()); 128 | for (size_t idx = 0; idx < V.size(); idx++) 129 | { 130 | for (size_t prev : V[idx].from) 131 | { 132 | size_t cur = prev; 133 | while (cur != V[idx].idom) 134 | { 135 | frontier[cur].insert(idx); 136 | cur = V[cur].idom; 137 | } 138 | } 139 | } 140 | for (size_t idx = 0; idx < V.size(); idx++) 141 | std::copy(frontier[idx].begin(), frontier[idx].end(), std::back_inserter(V[idx].df)); 142 | } 143 | 144 | void DomTree::verifyIdom() 145 | { 146 | std::vector idomTrue(V.size(), idv[0]); 147 | for (size_t dfn = 1; dfn < V.size(); dfn++) 148 | { 149 | for (size_t idx = 0; idx < V.size(); idx++) 150 | V[idx].visited = false; 151 | std::function traverse; 152 | traverse = [dfn, &traverse, this](size_t idx) 153 | { 154 | V[idx].visited = true; 155 | for (size_t next : V[idx].to) 156 | if (!V[next].visited && V[next].dfn != dfn) 157 | traverse(next); 158 | }; 159 | traverse(idv[0]); 160 | for (size_t idx = 0; idx < V.size(); idx++) 161 | if (!V[idx].visited && idx != idv[dfn]) 162 | idomTrue[idx] = idv[dfn]; 163 | } 164 | for (size_t idx = 0; idx < V.size(); idx++) 165 | { 166 | assert(idomTrue[idx] == V[idx].idom); 167 | } 168 | } -------------------------------------------------------------------------------- /src/utils/DomTree.h: 
-------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_DOM_TREE_H 2 | #define MX_COMPILER_UTILS_DOM_TREE_H 3 | 4 | #include "../common.h" 5 | 6 | class DomTree 7 | { 8 | public: 9 | DomTree() {} 10 | DomTree(size_t nVertex) : V(nVertex) {} 11 | 12 | size_t addVertex() 13 | { 14 | V.emplace_back(); 15 | return V.size() - 1; 16 | } 17 | DomTree & link(size_t u, size_t v) 18 | { 19 | V.at(u).to.push_back(v); 20 | V.at(v).from.push_back(u); 21 | return *this; 22 | } 23 | size_t getIdom(size_t idx) 24 | { 25 | return V.at(idx).idom; 26 | } 27 | const std::vector & getDomFrontier(size_t idx) 28 | { 29 | return V.at(idx).df; 30 | } 31 | const std::vector & getDomChildren(size_t idx) 32 | { 33 | return V.at(idx).children; 34 | } 35 | void buildTree(size_t root); 36 | 37 | 38 | protected: 39 | void dfs(size_t idx); 40 | void calcIdom(); 41 | void calcDomFrontier(); 42 | void verifyIdom(); 43 | 44 | protected: 45 | struct vertex 46 | { 47 | size_t dfn; 48 | size_t semi; 49 | size_t idom; 50 | std::vector to, from; 51 | std::vector df; //Dominance Frontier 52 | std::vector children; //Children on dominator tree 53 | bool visited; 54 | }; 55 | std::vector V; 56 | std::vector idv; //dfn -> idx 57 | }; 58 | 59 | #endif -------------------------------------------------------------------------------- /src/utils/ElementAdapter.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_ELEMENT_ADAPTER_H 2 | #define MX_COMPILER_UTILS_ELEMENT_ADAPTER_H 3 | 4 | #include 5 | 6 | template 7 | class ElementAdapter; 8 | template 9 | ElementAdapter element_adapter(Container &&container, AdapterFunc adapter); 10 | 11 | template 12 | class ElementAdapter 13 | { 14 | friend ElementAdapter element_adapter(Container &&container, AdapterFunc adapter); 15 | //typedef decltype(std::declval()(*(((typename std::remove_reference::type *)nullptr)->begin()))) DstType; 16 | public: 17 | class 
iterator 18 | { 19 | friend class ElementAdapter; 20 | public: 21 | iterator & operator++() 22 | { 23 | ++iter; 24 | return *this; 25 | } 26 | iterator & operator--() 27 | { 28 | --iter; 29 | return *this; 30 | } 31 | iterator operator++(int) 32 | { 33 | iterator ret = *this; 34 | ++(*this); 35 | return ret; 36 | } 37 | iterator operator--(int) 38 | { 39 | iterator ret = *this; 40 | --(*this); 41 | return ret; 42 | } 43 | bool operator==(const iterator &rhs) const 44 | { 45 | return iter == rhs.iter; 46 | } 47 | bool operator!=(const iterator &rhs) const 48 | { 49 | return iter != rhs.iter; 50 | } 51 | auto & operator*() 52 | { 53 | return adapter->adapter(*iter); 54 | } 55 | auto * operator->() 56 | { 57 | return &adapter->adapter(*iter); 58 | } 59 | private: 60 | decltype(((typename std::remove_reference::type *)nullptr)->begin()) iter; 61 | ElementAdapter *adapter; 62 | }; 63 | public: 64 | 65 | iterator begin() 66 | { 67 | iterator iter; 68 | iter.iter = container.begin(); 69 | iter.adapter = this; 70 | return iter; 71 | } 72 | iterator end() 73 | { 74 | iterator iter; 75 | iter.iter = container.end(); 76 | iter.adapter = this; 77 | return iter; 78 | } 79 | 80 | private: 81 | ElementAdapter(Container &&container, AdapterFunc adapter) : container(std::forward(container)), adapter(adapter) {} 82 | 83 | private: 84 | Container container; 85 | AdapterFunc adapter; 86 | }; 87 | 88 | template 89 | ElementAdapter element_adapter(Container &&container, AdapterFunc adapter) 90 | { 91 | return ElementAdapter(std::forward(container), adapter); 92 | } 93 | #endif -------------------------------------------------------------------------------- /src/utils/JoinIterator.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_JOIN_ITERATOR 2 | #define MX_COMPILER_UTILS_JOIN_ITERATOR 3 | 4 | #include 5 | 6 | template 7 | class JoinWrapper; 8 | template 9 | class JoinWrapperPair; 10 | template 11 | class JoinIterator; 12 | 
template 13 | JoinWrapper join(Container&& ...containers); 14 | 15 | template 16 | class JoinWrapper 17 | { 18 | }; 19 | 20 | template 21 | class JoinWrapper 22 | { 23 | template 24 | friend JoinWrapper join(C&& ...); 25 | template 26 | friend class JoinWrapper; 27 | 28 | Container container; 29 | public: 30 | typedef decltype(container.begin()) iterator; 31 | iterator begin() { return container.begin(); } 32 | iterator end() { return container.end(); } 33 | 34 | private: 35 | JoinWrapper(Container &&container) : container(container) {} 36 | }; 37 | 38 | template 39 | class JoinWrapper 40 | { 41 | template 42 | friend JoinWrapper join(C&& ...); 43 | template 44 | friend class JoinWrapper; 45 | public: 46 | typedef JoinIterator> iterator; 47 | iterator begin() { return wrapper.begin(); } 48 | iterator end() { return wrapper.end(); } 49 | 50 | private: 51 | JoinWrapper(Container1 &&container1, ContainerOther &&...other) : wrapper(std::forward(container1), join(std::forward(other)...)) {} 52 | private: 53 | JoinWrapperPair> wrapper; 54 | }; 55 | 56 | 57 | template 58 | class JoinWrapperPair 59 | { 60 | friend class JoinIterator; 61 | template 62 | friend class JoinWrapper; 63 | public: 64 | JoinIterator begin(); 65 | JoinIterator end(); 66 | 67 | private: 68 | JoinWrapperPair(Container1 &&container1, Container2 &&container2) : container1(std::forward(container1)), container2(std::forward(container2)) {} 69 | private: 70 | Container1 container1; 71 | Container2 container2; 72 | }; 73 | 74 | template 75 | class JoinIterator 76 | { 77 | friend class JoinWrapperPair; 78 | typedef JoinIterator ThisType; 79 | public: 80 | ThisType & operator++() 81 | { 82 | if (!wrapper) 83 | return *this; 84 | if (inContainer1) 85 | { 86 | ++iter1; 87 | if (iter1 == wrapper->container1.end()) 88 | { 89 | inContainer1 = false; 90 | iter2 = wrapper->container2.begin(); 91 | } 92 | } 93 | else 94 | ++iter2; 95 | return *this; 96 | } 97 | ThisType & operator--() 98 | { 99 | if (!wrapper) 100 | 
return *this; 101 | if (!inContainer1) 102 | { 103 | if (iter2 == wrapper->container2.begin()) 104 | { 105 | inContainer1 = true; 106 | iter1 = --wrapper->container1.end(); 107 | } 108 | } 109 | else 110 | --iter1; 111 | return *this; 112 | } 113 | ThisType operator++(int) 114 | { 115 | ThisType ret = *this; 116 | ++(*this); 117 | return ret; 118 | } 119 | ThisType operator--(int) 120 | { 121 | ThisType ret = *this; 122 | --(*this); 123 | return ret; 124 | } 125 | bool operator==(const ThisType &rhs) const 126 | { 127 | if (inContainer1 != rhs.inContainer1) 128 | return false; 129 | if (inContainer1) 130 | return iter1 == rhs.iter1; 131 | return iter2 == rhs.iter2; 132 | } 133 | bool operator!=(const ThisType &rhs) const 134 | { 135 | return !((*this) == rhs); 136 | } 137 | 138 | RetType & operator*() 139 | { 140 | if (inContainer1) 141 | return static_cast(*iter1); 142 | return static_cast(*iter2); 143 | } 144 | RetType * operator->() 145 | { 146 | if (inContainer1) 147 | return static_cast(&(*iter1)); 148 | return static_cast(&(*iter2)); 149 | } 150 | 151 | private: 152 | JoinWrapperPair *wrapper; 153 | decltype(((typename std::remove_reference::type *)nullptr)->begin()) iter1; 154 | decltype(((typename std::remove_reference::type *)nullptr)->begin()) iter2; 155 | bool inContainer1; 156 | }; 157 | 158 | template 159 | JoinIterator JoinWrapperPair::begin() 160 | { 161 | JoinIterator iter; 162 | iter.wrapper = this; 163 | if (container1.begin() == container1.end()) 164 | { 165 | iter.iter2 = container2.begin(); 166 | iter.inContainer1 = false; 167 | } 168 | else 169 | { 170 | iter.iter1 = container1.begin(); 171 | iter.inContainer1 = true; 172 | } 173 | return iter; 174 | } 175 | 176 | template 177 | JoinIterator JoinWrapperPair::end() 178 | { 179 | JoinIterator iter; 180 | iter.wrapper = this; 181 | iter.iter2 = container2.end(); 182 | iter.inContainer1 = false; 183 | return iter; 184 | } 185 | 186 | template 187 | JoinWrapper join(Container&& ...containers) 188 | 
{ 189 | return JoinWrapper(std::forward(containers)...); 190 | } 191 | 192 | #endif -------------------------------------------------------------------------------- /src/utils/MaxClique.cpp: -------------------------------------------------------------------------------- 1 | #include "../common_headers.h" 2 | #include "MaxClique.h" 3 | 4 | bool MaxClique::BronKerbosch() 5 | { 6 | if (P.empty() && X.empty()) 7 | return true; 8 | size_t u = P.empty() ? *X.begin() : *P.begin(); 9 | 10 | size_t j = 0; 11 | for (size_t i = 0; i < P.size(); i++) 12 | if (V[u].neighbor.count(P[i])) 13 | std::swap(P[j++], P[i]); 14 | for (size_t k = P.size() - 1; k >= j; k--) 15 | { 16 | size_t v = P[k]; 17 | std::vector oldP = std::move(P), oldX = std::move(X); 18 | for (size_t t : oldP) 19 | if (V[v].neighbor.count(t)) 20 | P.push_back(t); 21 | for (size_t t : oldX) 22 | if (V[v].neighbor.count(t)) 23 | X.push_back(t); 24 | R.push_back(v); 25 | if (BronKerbosch()) 26 | return true; 27 | R.pop_back(); 28 | P = std::move(oldP), X = std::move(oldX); 29 | assert(k == P.size() - 1 && P[k] == v); 30 | P.pop_back(); 31 | X.push_back(v); 32 | } 33 | return false; 34 | } -------------------------------------------------------------------------------- /src/utils/MaxClique.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_MAX_CLIQUE_H 2 | #define MX_COMPILER_UTILS_MAX_CLIQUE_H 3 | 4 | #include "../common.h" 5 | 6 | class MaxClique 7 | { 8 | public: 9 | MaxClique(size_t nVertex) : V(nVertex) {} 10 | void link(size_t u, size_t v) 11 | { 12 | V.at(u).neighbor.insert(v); 13 | V.at(v).neighbor.insert(u); 14 | } 15 | const std::vector & findMaxClique() 16 | { 17 | R.clear(); 18 | X.clear(); 19 | P.resize(V.size()); 20 | std::iota(P.begin(), P.end(), 0); 21 | BronKerbosch(); 22 | return R; 23 | } 24 | bool BronKerbosch(); 25 | 26 | protected: 27 | struct vertex 28 | { 29 | std::set neighbor; 30 | }; 31 | std::vector R, P, X; 32 | std::vector 
V; 33 | }; 34 | 35 | #endif -------------------------------------------------------------------------------- /src/utils/UnionFindSet.h: -------------------------------------------------------------------------------- 1 | #ifndef MX_COMPILER_UTILS_UNION_FIND_SET_H 2 | #define MX_COMPILER_UTILS_UNION_FIND_SET_H 3 | 4 | #include "../common.h" 5 | 6 | class UnionFindSet 7 | { 8 | public: 9 | UnionFindSet() {} 10 | UnionFindSet(size_t size) : vNodes(size) 11 | { 12 | for (size_t i = 0; i < size; i++) 13 | vNodes[i].father = i, vNodes[i].size = 1; 14 | } 15 | 16 | size_t addNode() 17 | { 18 | size_t idx = vNodes.size(); 19 | vNodes.push_back({ idx, 1 }); 20 | return idx; 21 | } 22 | size_t merge(size_t u, size_t v) 23 | { 24 | u = findRoot(u), v = findRoot(v); 25 | if (u == v) 26 | return u; 27 | if (vNodes.at(u).size < vNodes.at(v).size) 28 | return merge(v, u); 29 | 30 | vNodes[v].father = vNodes[u].father; 31 | size_t root = findRoot(v); 32 | vNodes[root].size += vNodes[v].size; 33 | return root; 34 | } 35 | size_t findRoot(size_t idx) 36 | { 37 | size_t &father = vNodes.at(idx).father; 38 | if (vNodes[father].father == father) 39 | return father; 40 | return father = findRoot(father); 41 | } 42 | 43 | protected: 44 | struct node 45 | { 46 | size_t father; 47 | size_t size; 48 | }; 49 | std::vector vNodes; 50 | }; 51 | 52 | #endif --------------------------------------------------------------------------------