├── .clang-format ├── .gitignore ├── Builtins.h ├── CMakeLists.txt ├── Codegen.cpp ├── Codegen.h ├── Exp.h ├── FindLLVM.cmake ├── FuncDef.h ├── LICENSE ├── Lexer.h ├── Lexer.re ├── Parser.cpp ├── Parser.h ├── Printer.cpp ├── Printer.h ├── Program.h ├── README.md ├── Scope.h ├── SimpleJIT.h ├── Stmt.h ├── Syntax.h ├── Token.cpp ├── Token.h ├── TokenStream.h ├── Type.h ├── Typechecker.cpp ├── Typechecker.h ├── VarDecl.h ├── Visitor.h ├── examples ├── factorial.in └── sum.in ├── grammar.txt └── main.cpp /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: LLVM 2 | AccessModifierOffset: '-2' 3 | AlignAfterOpenBracket: Align 4 | AlignConsecutiveAssignments: 'true' 5 | AlignConsecutiveDeclarations: 'true' 6 | AlignOperands: 'true' 7 | AlignTrailingComments: 'true' 8 | AllowAllParametersOfDeclarationOnNextLine: 'false' 9 | AllowShortBlocksOnASingleLine: 'false' 10 | AllowShortCaseLabelsOnASingleLine: 'false' 11 | AllowShortFunctionsOnASingleLine: Inline 12 | AllowShortIfStatementsOnASingleLine: 'false' 13 | AllowShortLoopsOnASingleLine: 'false' 14 | AlwaysBreakAfterReturnType: None 15 | AlwaysBreakBeforeMultilineStrings: 'true' 16 | AlwaysBreakTemplateDeclarations: 'true' 17 | BinPackArguments: 'true' 18 | BinPackParameters: 'false' 19 | ExperimentalAutoDetectBinPacking: 'false' 20 | BreakBeforeBinaryOperators: NonAssignment 21 | BreakBeforeBraces: Custom 22 | BreakBeforeTernaryOperators: 'false' 23 | BreakConstructorInitializersBeforeComma: 'true' 24 | ColumnLimit: '120' 25 | ConstructorInitializerAllOnOneLineOrOnePerLine: 'false' 26 | Cpp11BracedListStyle: 'true' 27 | IndentCaseLabels: 'true' 28 | IndentWidth: '4' 29 | KeepEmptyLinesAtTheStartOfBlocks: 'true' 30 | Language: Cpp 31 | MaxEmptyLinesToKeep: '2' 32 | NamespaceIndentation: None 33 | ObjCSpaceBeforeProtocolList: 'true' 34 | PointerAlignment: Left 35 | SpaceAfterCStyleCast: 'false' 36 | SpaceBeforeAssignmentOperators: 'true' 37 | SpaceBeforeParens: 
Never 38 | SpaceInEmptyParentheses: 'false' 39 | SpacesBeforeTrailingComments: '2' 40 | SpacesInAngles: 'false' 41 | SpacesInCStyleCastParentheses: 'false' 42 | SpacesInParentheses: 'true' 43 | SpacesInSquareBrackets: 'false' 44 | Standard: Cpp11 45 | TabWidth: '4' 46 | UseTab: Never 47 | SortIncludes: 'true' 48 | ReflowComments: 'false' 49 | BraceWrapping: { 50 | AfterClass: 'true' 51 | AfterControlStatement: 'true' 52 | AfterEnum: 'true' 53 | AfterFunction: 'true' 54 | AfterNamespace: 'false' 55 | AfterStruct: 'true' 56 | AfterUnion: 'true' 57 | BeforeCatch: 'true' 58 | BeforeElse: 'true' 59 | IndentBraces: 'false' 60 | } 61 | PenaltyExcessCharacter: 1 62 | PenaltyBreakBeforeFirstCallParameter: 40 63 | PenaltyBreakFirstLessLess: 1 64 | PenaltyBreakComment: 30 65 | PenaltyBreakString: 30 66 | PenaltyReturnTypeOnItsOwnLine: 9999 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.d 2 | *.o 3 | *~ 4 | Lexer.cpp 5 | -------------------------------------------------------------------------------- /Builtins.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Get builtin function declarations (for typechecking purposes). 4 | inline const char* GetBuiltins() 5 | { 6 | return 7 | // Arithmetic 8 | "int operator+ ( int x, int y ); " 9 | "int operator- ( int x, int y ); " 10 | "int operator* ( int x, int y ); " 11 | "int operator/ ( int x, int y ); " 12 | "int operator% ( int x, int y ); " 13 | // Equality 14 | "bool operator== ( int x, int y ); " 15 | "bool operator!= ( int x, int y ); " 16 | "bool operator== ( bool x, bool y ); " 17 | "bool operator!= ( bool x, bool y ); " 18 | // Comparisons 19 | "bool operator< ( int x, int y ); " 20 | "bool operator<= ( int x, int y ); " 21 | "bool operator> ( int x, int y ); " 22 | "bool operator>= ( int x, int y ); " 23 | // Unary operations. 
24 | "bool operator! ( bool x ); " 25 | "int operator- ( int x ); " 26 | // Logical operations 27 | "bool operator&& ( bool x, bool y ); " 28 | "bool operator|| ( bool x, bool y ); " 29 | // Type conversions 30 | "bool operator bool ( int x ); " 31 | "int operator int ( bool x ); " 32 | ; 33 | } 34 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1) 2 | 3 | project(WeekendCompiler VERSION 1.0 4 | LANGUAGES CXX) 5 | 6 | find_program(RE2C_EXE re2c) 7 | if(NOT RE2C_EXE) 8 | message(WARNING "re2c executable not found. Please specify RE2C_EXE.") 9 | endif() 10 | 11 | add_custom_command( 12 | OUTPUT Lexer.cpp 13 | COMMAND "${RE2C_EXE}" -o Lexer.cpp "${CMAKE_SOURCE_DIR}/Lexer.re" 14 | DEPENDS Lexer.re 15 | ) 16 | 17 | list(INSERT CMAKE_MODULE_PATH 0 "${CMAKE_SOURCE_DIR}") 18 | set(LLVM_STATIC TRUE) 19 | find_package(LLVM REQUIRED) 20 | 21 | add_executable(weekend 22 | main.cpp 23 | Codegen.cpp 24 | Parser.cpp 25 | Printer.cpp 26 | Token.cpp 27 | Typechecker.cpp 28 | ${CMAKE_CURRENT_BINARY_DIR}/Lexer.cpp 29 | ) 30 | 31 | target_include_directories(weekend PUBLIC "${CMAKE_SOURCE_DIR}" "${LLVM_INCLUDE_DIRS}") 32 | target_link_libraries(weekend PUBLIC "${LLVM_LDFLAGS}" ${LLVM_LIBRARY}) 33 | -------------------------------------------------------------------------------- /Codegen.cpp: -------------------------------------------------------------------------------- 1 | #include "Codegen.h" 2 | #include "Exp.h" 3 | #include "FuncDef.h" 4 | #include "Program.h" 5 | #include "Stmt.h" 6 | #include "Visitor.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace llvm; 15 | 16 | // The symbol table maps variable declarations to LLVM values. Local variables are mapped to alloca 17 | // pointers, while function parameters are mapped to their LLVM equivalents. 
18 | using SymbolTable = std::map; 19 | 20 | // The function table maps function definitions to their LLVM equivalents. 21 | using FunctionTable = std::map; 22 | 23 | namespace { 24 | 25 | // Base class for expression and statement code generators, which holds the LLVM context, module, 26 | // and IR builder, providing various helper routines. 27 | class CodegenBase 28 | { 29 | public: 30 | // In addition to the LLVM context, module, and builder, the base class holds boolean and integer 31 | // types. 32 | CodegenBase( LLVMContext* context, Module* module, IRBuilder<>* builder ) 33 | : m_context( context ) 34 | , m_module( module ) 35 | , m_builder( builder ) 36 | , m_boolType( IntegerType::get( *m_context, 1 ) ) 37 | , m_intType( IntegerType::get( *m_context, 32 ) ) 38 | { 39 | } 40 | 41 | LLVMContext* GetContext() { return m_context; } 42 | 43 | Module* GetModule() { return m_module; } 44 | 45 | IRBuilder<>* GetBuilder() { return m_builder; } 46 | 47 | // Convert a type to its LLVM equivalent. 48 | llvm::Type* ConvertType(::Type type) 49 | { 50 | switch (type) 51 | { 52 | case kTypeBool: 53 | return m_boolType; 54 | case kTypeInt: 55 | return m_intType; 56 | case kTypeUnknown: 57 | assert( false && "Invalid type" ); 58 | return m_intType; 59 | } 60 | assert( false && "Invalid type" ); 61 | return m_intType; 62 | } 63 | 64 | llvm::Type* GetBoolType() const { return m_boolType; } 65 | 66 | llvm::Type* GetIntType() const { return m_intType; } 67 | 68 | // Generate LLVM IR for a constant boolean. 69 | Constant* GetBool( bool b ) const { return ConstantInt::get( GetBoolType(), int(b), false /*isSigned*/ ); } 70 | 71 | // Generate LLVM IR for a constant integer. 72 | Constant* GetInt( int i ) const { return ConstantInt::get( GetIntType(), i, true /*isSigned*/ ); } 73 | 74 | protected: 75 | LLVMContext* m_context; 76 | Module* m_module; 77 | IRBuilder<>* m_builder; 78 | llvm::Type* m_boolType; 79 | llvm::Type* m_intType; 80 | }; 81 | 82 | // Expression code generator. 
83 | class CodegenExp : public ExpVisitor, CodegenBase 84 | { 85 | public: 86 | CodegenExp( LLVMContext* context, Module* module, IRBuilder<>* builder, 87 | SymbolTable* symbols, FunctionTable* functions ) 88 | : CodegenBase( context, module, builder ) 89 | , m_symbols( symbols ) 90 | , m_functions( functions ) 91 | { 92 | } 93 | 94 | // Helper routine to codegen a subexpression. The visitor operates on 95 | // non-const expressions, so we must const_cast when dispatching. 96 | Value* Codegen( const Exp& exp ) { return reinterpret_cast( const_cast( exp ).Dispatch( *this ) ); } 97 | 98 | void* Visit( BoolExp& exp ) override { return GetBool( exp.GetValue() ); } 99 | 100 | void* Visit( IntExp& exp ) override { return GetInt( exp.GetValue() ); } 101 | 102 | // Generate code for a variable reference. 103 | void* Visit( VarExp& exp ) override 104 | { 105 | // The typechecker linked variable references to their declarations. 106 | const VarDecl* varDecl = exp.GetVarDecl(); 107 | assert( varDecl ); 108 | 109 | // An llvm::Value was associated with the variable when its declaration was processed. 110 | SymbolTable::const_iterator it = m_symbols->find( varDecl ); 111 | assert( it != m_symbols->end() ); 112 | Value* value = it->second; 113 | 114 | // The value is either a function parameter or a pointer to storage for a local variable. 115 | switch( varDecl->GetKind() ) 116 | { 117 | case VarDecl::kParam: 118 | return value; 119 | case VarDecl::kLocal: 120 | return GetBuilder()->CreateLoad( value, varDecl->GetName() ); 121 | } 122 | assert(false && "unreachable"); 123 | return nullptr; 124 | } 125 | 126 | // Generate code for a function call. 127 | void* Visit( CallExp& exp ) override 128 | { 129 | // Convert the arguments to LLVM values. 130 | std::vector args; 131 | args.reserve( exp.GetArgs().size() ); 132 | for( const ExpPtr& arg : exp.GetArgs() ) 133 | { 134 | args.push_back( Codegen( *arg ) ); 135 | } 136 | 137 | // Builtin definition? 
TODO: use an enum for builtin functions, rather than matching the name. 138 | const std::string& funcName = exp.GetFuncName(); 139 | if( funcName == "+" ) 140 | return GetBuilder()->CreateAdd( args.at( 0 ), args.at( 1 ) ); 141 | else if( funcName == "-" ) 142 | { 143 | if( args.size() == 1 ) 144 | return GetBuilder()->CreateNeg( args.at( 0 ) ); 145 | else 146 | return GetBuilder()->CreateSub( args.at( 0 ), args.at( 1 ) ); 147 | } 148 | else if( funcName == "*" ) 149 | return GetBuilder()->CreateMul( args.at( 0 ), args.at( 1 ) ); 150 | else if( funcName == "/" ) 151 | return GetBuilder()->CreateSDiv( args.at( 0 ), args.at( 1 ) ); 152 | else if( funcName == "%" ) 153 | return GetBuilder()->CreateSRem( args.at( 0 ), args.at( 1 ) ); 154 | else if( funcName == "==" ) 155 | return GetBuilder()->CreateICmpEQ( args.at( 0 ), args.at( 1 ) ); 156 | else if( funcName == "!=" ) 157 | return GetBuilder()->CreateICmpNE( args.at( 0 ), args.at( 1 ) ); 158 | else if( funcName == "<" ) 159 | return GetBuilder()->CreateICmpSLT( args.at( 0 ), args.at( 1 ) ); 160 | else if( funcName == "<=" ) 161 | return GetBuilder()->CreateICmpSLE( args.at( 0 ), args.at( 1 ) ); 162 | else if( funcName == ">" ) 163 | return GetBuilder()->CreateICmpSGT( args.at( 0 ), args.at( 1 ) ); 164 | else if( funcName == ">=" ) 165 | return GetBuilder()->CreateICmpSGE( args.at( 0 ), args.at( 1 ) ); 166 | else if( funcName == "!" ) 167 | return GetBuilder()->CreateICmpEQ( args.at( 0 ), GetBool( false ) ); 168 | else if( funcName == "bool" ) 169 | return GetBuilder()->CreateICmpNE( args.at( 0 ), GetInt( 0 ) ); 170 | else if( funcName == "int" ) 171 | return GetBuilder()->CreateZExt( args.at( 0 ), GetIntType() ); 172 | // TODO: proper short-circuiting for && and ||. 
173 | else if (funcName == "&&") 174 | return GetBuilder()->CreateSelect( args.at( 0 ), args.at( 1 ), GetBool( false ) ); 175 | else if (funcName == "||") 176 | return GetBuilder()->CreateSelect( args.at( 0 ), GetBool( true ), args.at( 1 ) ); 177 | 178 | // The typechecker linked function call sites to their definitions. 179 | const FuncDef* funcDef = exp.GetFuncDef(); 180 | assert( funcDef ); 181 | 182 | // An llvm::Function was associated with the function when its definition was processed. 183 | FunctionTable::const_iterator it = m_functions->find( funcDef ); 184 | assert( it != m_functions->end() ); 185 | Function* function = it->second; 186 | 187 | // Generate LLVM function call. 188 | return GetBuilder()->CreateCall( function, args, funcDef->GetName() ); 189 | } 190 | 191 | private: 192 | SymbolTable* m_symbols; 193 | FunctionTable* m_functions; 194 | }; 195 | 196 | 197 | // Statement code generator. 198 | class CodegenStmt : public StmtVisitor, CodegenBase 199 | { 200 | public: 201 | CodegenStmt( LLVMContext* context, Module* module, IRBuilder<>* builder, 202 | SymbolTable* symbols, FunctionTable* functions, Function* currentFunction ) 203 | : CodegenBase( context, module, builder ) 204 | , m_symbols( symbols ) 205 | , m_functions( functions ) 206 | , m_currentFunction( currentFunction ) 207 | , m_codegenExp( context, module, builder, symbols, functions ) 208 | { 209 | } 210 | 211 | // Helper routine to codegen a statement. The visitor operates on 212 | // non-const statements, so we must const_cast when dispatching. 213 | void Codegen( const Stmt& stmt ) 214 | { 215 | const_cast( stmt ).Dispatch( *this ); 216 | } 217 | 218 | 219 | // Generate code for a function call statement. 220 | void Visit( CallStmt& stmt ) override 221 | { 222 | m_codegenExp.Codegen( stmt.GetCallExp() ); 223 | } 224 | 225 | 226 | // Generate code for an assignment statement.
227 | void Visit( AssignStmt& stmt ) override 228 | { 229 | // The typechecker links assignments to variable declarations. Assignments to function 230 | // parameters are prohibited by the typechecker. 231 | const VarDecl* varDecl = stmt.GetVarDecl(); 232 | assert( varDecl && varDecl->GetKind() == VarDecl::kLocal ); 233 | 234 | // The symbol table maps local variables to stack-allocated storage. 235 | SymbolTable::const_iterator it = m_symbols->find( varDecl ); 236 | assert( it != m_symbols->end() ); 237 | Value* location = it->second; 238 | 239 | // Generate code for the rvalue and store it. 240 | Value* rvalue = m_codegenExp.Codegen( stmt.GetRvalue() ); 241 | GetBuilder()->CreateStore(rvalue, location); 242 | } 243 | 244 | 245 | // Generate code for a local variable declaration. 246 | void Visit( DeclStmt& stmt ) override 247 | { 248 | const VarDecl* varDecl = stmt.GetVarDecl(); 249 | llvm::Type* type = ConvertType(varDecl->GetType()); 250 | 251 | // Generate an "alloca" instruction, which goes in entry block of the current function. 252 | IRBuilder<> allocaBuilder( &m_currentFunction->getEntryBlock(), 253 | m_currentFunction->getEntryBlock().getFirstInsertionPt() ); 254 | Value* location = allocaBuilder.CreateAlloca( type, nullptr /*arraySize*/, varDecl->GetName() ); 255 | 256 | // Store the variable location in the symbol table. 257 | m_symbols->insert( SymbolTable::value_type( varDecl, location ) ); 258 | 259 | // Generate code for the initializer (if any) and store it. 260 | if (stmt.HasInitExp()) 261 | { 262 | Value* rvalue = m_codegenExp.Codegen( stmt.GetInitExp() ); 263 | GetBuilder()->CreateStore(rvalue, location); 264 | } 265 | } 266 | 267 | // Generate code for a return statement. 268 | void Visit( ReturnStmt& stmt ) override 269 | { 270 | Value* result = m_codegenExp.Codegen( stmt.GetExp() ); 271 | GetBuilder()->CreateRet( result ); 272 | } 273 | 274 | // Generate code for a sequence of statements. 
275 | void Visit( SeqStmt& seq ) override 276 | { 277 | for( const StmtPtr& stmt : seq.Get() ) 278 | { 279 | Codegen( *stmt ); 280 | } 281 | } 282 | 283 | // Generate code for an "if" statement. 284 | void Visit( IfStmt& stmt ) override 285 | { 286 | // Generate code for the conditional expression. 287 | Value* condition = codegenCondExp( stmt.GetCondExp() ); 288 | 289 | // Create basic blocks for "then" branch, "else" branch (if any), and the join point. 290 | BasicBlock* thenBlock = BasicBlock::Create( *GetContext(), "then", m_currentFunction ); 291 | BasicBlock* elseBlock = stmt.HasElseStmt() ? BasicBlock::Create( *GetContext(), "else", m_currentFunction ) : nullptr; 292 | BasicBlock* joinBlock = BasicBlock::Create( *GetContext(), "join", m_currentFunction ); 293 | 294 | // Create a conditional branch. 295 | GetBuilder()->CreateCondBr( condition, thenBlock, elseBlock ? elseBlock : joinBlock ); 296 | 297 | // Generate code for "then" branch 298 | GetBuilder()->SetInsertPoint( thenBlock ); 299 | Codegen( stmt.GetThenStmt() ); 300 | 301 | // Create an unconditional branch to the "join" block, unless the block already ends 302 | // in a return instruction. 303 | if( !GetBuilder()->GetInsertBlock()->getTerminator() ) 304 | GetBuilder()->CreateBr( joinBlock ); 305 | 306 | // If present, generate code for "else" branch. 307 | if( stmt.HasElseStmt() ) 308 | { 309 | GetBuilder()->SetInsertPoint( elseBlock ); 310 | Codegen( stmt.GetElseStmt() ); 311 | 312 | // Create an unconditional branch to the "join" block, unless the block already ends 313 | // in a return instruction. 314 | if( !GetBuilder()->GetInsertBlock()->getTerminator() ) 315 | GetBuilder()->CreateBr( joinBlock ); 316 | } 317 | 318 | // Set the builder insertion point in the join block. 319 | GetBuilder()->SetInsertPoint( joinBlock ); 320 | } 321 | 322 | // Generate code for a while loop. 323 | void Visit( WhileStmt& stmt ) override 324 | { 325 | // Create a basic block for the start of the loop.
326 | BasicBlock* loopBlock = BasicBlock::Create( *GetContext(), "loop", m_currentFunction ); 327 | GetBuilder()->CreateBr( loopBlock ); 328 | GetBuilder()->SetInsertPoint( loopBlock ); 329 | 330 | // Generate code for the loop condition. 331 | Value* condition = codegenCondExp( stmt.GetCondExp() ); 332 | 333 | // Create basic blocks for the loop body and the join point. 334 | BasicBlock* bodyBlock = BasicBlock::Create( *GetContext(), "body", m_currentFunction ); 335 | BasicBlock* joinBlock = BasicBlock::Create( *GetContext(), "join", m_currentFunction ); 336 | 337 | // Create a conditional branch. 338 | GetBuilder()->CreateCondBr( condition, bodyBlock, joinBlock ); 339 | 340 | // Generate code for the loop body, then branch back to the loop head unless the body's final block already ends in a terminator (e.g. a return), mirroring the "if" statement above. 341 | GetBuilder()->SetInsertPoint( bodyBlock ); 342 | Codegen( stmt.GetBodyStmt() ); 343 | if( !GetBuilder()->GetInsertBlock()->getTerminator() ) GetBuilder()->CreateBr( loopBlock ); 344 | 345 | // Set the builder insertion point in the join block. 346 | GetBuilder()->SetInsertPoint( joinBlock ); 347 | } 348 | 349 | private: 350 | SymbolTable* m_symbols; 351 | FunctionTable* m_functions; 352 | Function* m_currentFunction; 353 | CodegenExp m_codegenExp; 354 | 355 | // Generate code for the condition expression in an "if" statement or a while loop. 356 | Value* codegenCondExp( const Exp& exp ) 357 | { 358 | Value* condition = m_codegenExp.Codegen( exp ); 359 | if ( exp.GetType() == kTypeBool ) 360 | return condition; 361 | 362 | // Convert the integer conditional expression to a boolean (i1) using a comparison. 363 | assert( exp.GetType() == kTypeInt ); 364 | return GetBuilder()->CreateICmpNE( condition, GetInt( 0 ) ); 365 | } 366 | }; 367 | 368 | 369 | // Function definition code generator.
370 | class CodegenFunc : public CodegenBase 371 | { 372 | public: 373 | CodegenFunc( LLVMContext* context, Module* module, FunctionTable* functions ) 374 | : CodegenBase( context, module, &m_builder ) 375 | , m_builder( *context ) 376 | , m_functions( functions ) 377 | { 378 | } 379 | 380 | // Generate code for a function definition. 381 | void Codegen( const FuncDef* funcDef ) 382 | { 383 | // Don't generate code for builtin function declarations. 384 | if( !funcDef->HasBody() ) 385 | return; 386 | 387 | // Convert parameter types to LLVM types. 388 | const std::vector& params = funcDef->GetParams(); 389 | std::vector paramTypes; 390 | paramTypes.reserve( params.size() ); 391 | for( const VarDeclPtr& param : params ) 392 | { 393 | paramTypes.push_back( ConvertType( param->GetType() ) ); 394 | } 395 | 396 | // Construct LLVM function type and function definition. 397 | llvm::Type* returnType = ConvertType( funcDef->GetReturnType() ); 398 | FunctionType* funcType = FunctionType::get( returnType, paramTypes, false /*isVarArg*/ ); 399 | Function* function = Function::Create( funcType, Function::ExternalLinkage, funcDef->GetName(), GetModule() ); 400 | 401 | // The main function has external linkage. Other functions are 402 | // "internal", which encourages inlining. 403 | function->setLinkage( funcDef->GetName() == "main" ? Function::ExternalLinkage : Function::InternalLinkage ); 404 | 405 | // Update the function table. 406 | m_functions->insert( FunctionTable::value_type( funcDef, function ) ); 407 | 408 | // Construct a symbol table that maps the parameter declarations to the LLVM function parameters. 409 | SymbolTable symbols; 410 | size_t i = 0; 411 | for( Argument& arg : function->args() ) 412 | { 413 | symbols.insert( SymbolTable::value_type( params[i].get(), &arg ) ); 414 | ++i; 415 | } 416 | 417 | // Create entry block and use it as the builder's insertion point.
418 | BasicBlock* block = BasicBlock::Create(*GetContext(), "entry", function); 419 | GetBuilder()->SetInsertPoint(block); 420 | 421 | // Generate code for the body of the function. 422 | CodegenStmt codegen( GetContext(), GetModule(), GetBuilder(), &symbols, m_functions, function ); 423 | codegen.Codegen( funcDef->GetBody() ); 424 | 425 | // Add a return instruction if the user neglected to do so. The value is a zero of the function's own return type, so a bool function returns an i1 rather than a mismatched i32. 426 | if( !GetBuilder()->GetInsertBlock()->getTerminator() ) 427 | GetBuilder()->CreateRet( Constant::getNullValue( function->getReturnType() ) ); 428 | } 429 | 430 | private: 431 | IRBuilder<> m_builder; 432 | FunctionTable* m_functions; 433 | }; 434 | 435 | } // anonymous namespace 436 | 437 | 438 | // Generate code for a program. 439 | std::unique_ptr Codegen(LLVMContext* context, const Program& program) 440 | { 441 | // Construct LLVM module. 442 | std::unique_ptr module( new Module( "module", *context ) ); 443 | 444 | // The function table maps function definitions to their LLVM equivalents. 445 | FunctionTable functions; 446 | 447 | // Generate code for each function, adding LLVM functions to the module. 448 | for( const FuncDefPtr& funcDef : program.GetFunctions() ) 449 | { 450 | CodegenFunc( context, module.get(), &functions ).Codegen( funcDef.get() ); 451 | } 452 | return module; 453 | } 454 | -------------------------------------------------------------------------------- /Codegen.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | class Program; 6 | namespace llvm { class LLVMContext; class Module; } 7 | 8 | // Generate LLVM IR for the given program.
9 | std::unique_ptr Codegen( llvm::LLVMContext* context, const Program& program ); 10 | -------------------------------------------------------------------------------- /Exp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Type.h" 4 | #include "Visitor.h" 5 | #include 6 | #include 7 | #include 8 | 9 | /// Base class for an expression, which holds its type. 10 | class Exp 11 | { 12 | public: 13 | /// Construct expression. Most expression types are unknown until typechecking, 14 | /// except for constants. 15 | explicit Exp( Type type = kTypeUnknown ) 16 | : m_type( type ) 17 | { 18 | } 19 | 20 | /// The destructor is virtual, ensuring that the destructor for a derived 21 | /// class will be properly invoked. 22 | virtual ~Exp() {} 23 | 24 | /// Get the expression type (usually kTypeUnknown if not yet typechecked). 25 | Type GetType() const { return m_type; } 26 | 27 | /// Set the expression type. 28 | void SetType( Type type ) { m_type = type; } 29 | 30 | /// Dispatch to a visitor. \see ExpVisitor 31 | virtual void* Dispatch( ExpVisitor& visitor ) = 0; 32 | 33 | private: 34 | Type m_type; 35 | }; 36 | 37 | /// Unique pointer to expression. 38 | using ExpPtr = std::unique_ptr; 39 | 40 | /// Boolean constant expression. 41 | class BoolExp : public Exp 42 | { 43 | public: 44 | /// Construct boolean constant expression. 45 | BoolExp( bool value ) 46 | : Exp( kTypeBool ) 47 | , m_value( value ) 48 | { 49 | } 50 | 51 | /// Get the value of this constant. 52 | bool GetValue() const { return m_value; } 53 | 54 | /// Dispatch to visitor. 55 | void* Dispatch( ExpVisitor& visitor ) override { return visitor.Visit( *this ); } 56 | 57 | private: 58 | bool m_value; 59 | }; 60 | 61 | 62 | /// Integer constant expression 63 | class IntExp : public Exp 64 | { 65 | public: 66 | /// Construct integer constant expression. 
67 | IntExp( int value ) 68 | : Exp( kTypeInt ) 69 | , m_value( value ) 70 | { 71 | } 72 | 73 | /// Get the value of this constant. 74 | int GetValue() const { return m_value; } 75 | 76 | /// Dispatch to visitor. 77 | void* Dispatch( ExpVisitor& visitor ) override { return visitor.Visit( *this ); } 78 | 79 | private: 80 | int m_value; 81 | }; 82 | 83 | 84 | /// Variable expression. 85 | class VarExp : public Exp 86 | { 87 | public: 88 | /// Construct variable expression. The declaration link starts out null and is filled in by SetVarDecl. 89 | VarExp( const std::string& name ) 90 | : m_name( name ), m_varDecl( nullptr ) 91 | { 92 | } 93 | 94 | /// Get the variable name. 95 | const std::string& GetName() const { return m_name; } 96 | 97 | /// Get the variable's declaration (null if not yet typechecked). 98 | const VarDecl* GetVarDecl() const { return m_varDecl; } 99 | 100 | /// Link this variable expression to the variable's declaration. Called from the typechecker. 101 | void SetVarDecl( const VarDecl* varDecl ) { m_varDecl = varDecl; } 102 | 103 | /// Dispatch to a visitor. 104 | void* Dispatch( ExpVisitor& visitor ) override { return visitor.Visit( *this ); } 105 | 106 | private: 107 | std::string m_name; 108 | const VarDecl* m_varDecl; // null until assigned by the typechecker. 109 | }; 110 | 111 | 112 | /// Function call expression. 113 | class CallExp : public Exp 114 | { 115 | public: 116 | /// Construct function call expression with arbitrary arguments. 117 | CallExp( const std::string& funcName, std::vector&& args ) 118 | : m_funcName( funcName ) 119 | , m_args( std::move( args ) ) 120 | , m_funcDef( nullptr ) 121 | { 122 | } 123 | 124 | /// Construct a unary function call (for convenience). 125 | CallExp( const std::string& funcName, ExpPtr exp ) 126 | : m_funcName( funcName ) 127 | , m_args( 1 ) 128 | , m_funcDef( nullptr ) 129 | { 130 | m_args[0] = std::move( exp ); 131 | } 132 | 133 | /// Construct a binary function call (for convenience).
134 | CallExp( const std::string& funcName, ExpPtr leftExp, ExpPtr rightExp ) 135 | : m_funcName( funcName ) 136 | , m_args( 2 ) 137 | , m_funcDef( nullptr ) 138 | { 139 | m_args[0] = std::move( leftExp ); 140 | m_args[1] = std::move( rightExp ); 141 | } 142 | 143 | /// Get the function name. 144 | const std::string& GetFuncName() const { return m_funcName; } 145 | 146 | /// Get the argument expressions. 147 | const std::vector& GetArgs() const { return m_args; } 148 | 149 | /// Get the function definition (null until typechecked). 150 | const FuncDef* GetFuncDef() const { return m_funcDef; } 151 | 152 | /// Link this function call to the function definition. Called from typechecker. 153 | void SetFuncDef( const FuncDef* funcDef ) { m_funcDef = funcDef; } 154 | 155 | /// Dispatch to visitor. 156 | void* Dispatch( ExpVisitor& visitor ) override { return visitor.Visit( *this ); } 157 | 158 | private: 159 | std::string m_funcName; 160 | std::vector m_args; 161 | const FuncDef* m_funcDef; // set by typechecker. 162 | }; 163 | 164 | /// Unique pointer to function call expression 165 | using CallExpPtr = std::unique_ptr; 166 | 167 | -------------------------------------------------------------------------------- /FindLLVM.cmake: -------------------------------------------------------------------------------- 1 | # - Find LLVM library 2 | # Find the native LLVM includes and library 3 | # This module defines 4 | # LLVM_INCLUDE_DIRS, where to find LLVM.h, Set when LLVM_INCLUDE_DIR is found. 5 | # LLVM_LIBRARIES, libraries to link against to use LLVM. 6 | # LLVM_ROOT_DIR, The base directory to search for LLVM. 7 | # This can also be an environment variable. 8 | # LLVM_FOUND, If false, do not try to use LLVM. 9 | # 10 | # also defined, but not for general use are 11 | # LLVM_LIBRARY, where to find the LLVM library. 12 | 13 | #============================================================================= 14 | # Copyright 2015 Blender Foundation. 
15 | # 16 | # Distributed under the OSI-approved BSD License (the "License"); 17 | # see accompanying file Copyright.txt for details. 18 | # 19 | # This software is distributed WITHOUT ANY WARRANTY; without even the 20 | # implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 21 | # See the License for more information. 22 | #============================================================================= 23 | 24 | if(LLVM_ROOT_DIR) 25 | if(DEFINED LLVM_VERSION) 26 | find_program(LLVM_CONFIG llvm-config-${LLVM_VERSION} HINTS ${LLVM_ROOT_DIR}/bin) 27 | endif() 28 | if(NOT LLVM_CONFIG) 29 | find_program(LLVM_CONFIG llvm-config HINTS ${LLVM_ROOT_DIR}/bin) 30 | endif() 31 | else() 32 | if(DEFINED LLVM_VERSION) 33 | message(running llvm-config-${LLVM_VERSION}) 34 | find_program(LLVM_CONFIG llvm-config-${LLVM_VERSION}) 35 | endif() 36 | if(NOT LLVM_CONFIG) 37 | find_program(LLVM_CONFIG llvm-config) 38 | endif() 39 | endif() 40 | 41 | if(NOT DEFINED LLVM_VERSION) 42 | execute_process(COMMAND ${LLVM_CONFIG} --version 43 | OUTPUT_VARIABLE LLVM_VERSION 44 | OUTPUT_STRIP_TRAILING_WHITESPACE) 45 | set(LLVM_VERSION ${LLVM_VERSION} CACHE STRING "Version of LLVM to use") 46 | endif() 47 | if(NOT LLVM_ROOT_DIR) 48 | execute_process(COMMAND ${LLVM_CONFIG} --prefix 49 | OUTPUT_VARIABLE LLVM_ROOT_DIR 50 | OUTPUT_STRIP_TRAILING_WHITESPACE) 51 | set(LLVM_ROOT_DIR ${LLVM_ROOT_DIR} CACHE PATH "Path to the LLVM installation") 52 | endif() 53 | if(NOT LLVM_LIBPATH) 54 | execute_process(COMMAND ${LLVM_CONFIG} --libdir 55 | OUTPUT_VARIABLE LLVM_LIBPATH 56 | OUTPUT_STRIP_TRAILING_WHITESPACE) 57 | set(LLVM_LIBPATH ${LLVM_LIBPATH} CACHE PATH "Path to the LLVM library path") 58 | mark_as_advanced(LLVM_LIBPATH) 59 | endif() 60 | 61 | if(LLVM_STATIC) 62 | find_library(LLVM_LIBRARY 63 | NAMES LLVMAnalysis # first of a whole bunch of libs to get 64 | PATHS ${LLVM_LIBPATH}) 65 | else() 66 | find_library(LLVM_LIBRARY 67 | NAMES 68 | LLVM-${LLVM_VERSION} 69 | LLVMAnalysis # check for 
the static library as a fall-back 70 | PATHS ${LLVM_LIBPATH}) 71 | endif() 72 | 73 | 74 | if(LLVM_LIBRARY AND LLVM_ROOT_DIR AND LLVM_LIBPATH) 75 | execute_process(COMMAND ${LLVM_CONFIG} --includedir 76 | OUTPUT_VARIABLE LLVM_INCLUDE_DIRS 77 | OUTPUT_STRIP_TRAILING_WHITESPACE) 78 | if(LLVM_STATIC) 79 | # if static LLVM libraries were requested, use llvm-config to generate 80 | # the list of what libraries we need, and substitute that in the right 81 | # way for LLVM_LIBRARY. 82 | execute_process(COMMAND ${LLVM_CONFIG} --libfiles 83 | OUTPUT_VARIABLE LLVM_LIBRARY 84 | OUTPUT_STRIP_TRAILING_WHITESPACE) 85 | string(REPLACE " " ";" LLVM_LIBRARY "${LLVM_LIBRARY}") 86 | endif() 87 | endif() 88 | 89 | 90 | # handle the QUIETLY and REQUIRED arguments and set LLVM_FOUND to TRUE if 91 | # all listed variables are TRUE 92 | INCLUDE(FindPackageHandleStandardArgs) 93 | FIND_PACKAGE_HANDLE_STANDARD_ARGS(LLVM DEFAULT_MSG 94 | LLVM_LIBRARY) 95 | 96 | MARK_AS_ADVANCED( 97 | LLVM_LIBRARY 98 | ) 99 | -------------------------------------------------------------------------------- /FuncDef.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Stmt.h" 4 | #include "Syntax.h" 5 | #include "Type.h" 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | /// Syntax for function definition. 12 | class FuncDef 13 | { 14 | public: 15 | /// Construct function definition syntax. 16 | FuncDef( const Type& returnType, const std::string& name, std::vector&& params, SeqStmtPtr body ) 17 | : m_returnType( returnType ) 18 | , m_name( name ) 19 | , m_params( std::move( params ) ) 20 | , m_body( std::move( body ) ) 21 | { 22 | } 23 | 24 | /// Get the function return type. 25 | const Type& GetReturnType() const { return m_returnType; } 26 | 27 | /// Get the function name. 28 | const std::string& GetName() const { return m_name; } 29 | 30 | /// Get the parameter declarations.
31 | const std::vector& GetParams() const { return m_params; } 32 | 33 | /// Check whether the function definition has a body. (Builtin function declarations do not.) 34 | bool HasBody() const { return bool( m_body ); } 35 | 36 | /// Get the function body, which is a sequence of statements. 37 | const SeqStmt& GetBody() const 38 | { 39 | assert( HasBody() && "Expected function body" ); 40 | return *m_body; 41 | } 42 | 43 | private: 44 | Type m_returnType; 45 | std::string m_name; 46 | std::vector m_params; 47 | SeqStmtPtr m_body; 48 | }; 49 | 50 | /// Unique pointer to a function definition. 51 | using FuncDefPtr = std::unique_ptr; 52 | 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Mark Leone 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Lexer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Token.h" 4 | 5 | /// Scan the given string for the next token (discarding whitespace). 6 | /// The string pointer is passed by reference; it is advanced to the character 7 | /// following the token. Discards invalid characters (with a warning). 8 | /// Returns kTokenEOF if the string contains no token. 9 | Token Lexer( const char*& source ); 10 | 11 | 12 | -------------------------------------------------------------------------------- /Lexer.re: -------------------------------------------------------------------------------- 1 | #include "Lexer.h" 2 | #include 3 | 4 | // This file is processed by re2c (http://re2c.org) to generate a finite state 5 | // machine that matches various regular expressions. 6 | 7 | // Scan the given string for the next token (discarding whitespace). 8 | // The string pointer is passed by reference; it is advanced to the character 9 | // following the token. Discards invalid characters (with a warning). 10 | // Returns kTokenEOF if the string contains no token. 
11 | Token Lexer(const char*& source) 12 | { 13 | start: 14 | const char* begin = source; 15 | /*!re2c 16 | re2c:define:YYCTYPE = char; 17 | re2c:define:YYCURSOR = source; 18 | re2c:yyfill:enable = 0; 19 | 20 | integer = "-"?[0-9]+; 21 | id = [a-zA-Z_][a-zA-Z_0-9]*; 22 | space = [ \t\r\n]+; 23 | eof = "\x00"; 24 | 25 | integer { return Token( atoi( begin ) ); } 26 | "bool" { return kTokenBool; } 27 | "true" { return kTokenTrue; } 28 | "false" { return kTokenFalse; } 29 | "int" { return kTokenInt; } 30 | "if" { return kTokenIf; } 31 | "else" { return kTokenElse; } 32 | "operator" { return kTokenOperator; } 33 | "return" { return kTokenReturn; } 34 | "while" { return kTokenWhile; } 35 | id { return Token( std::string( begin, source ) ); } 36 | "+" { return kTokenPlus; } 37 | "-" { return kTokenMinus; } 38 | "*" { return kTokenTimes; } 39 | "/" { return kTokenDiv; } 40 | "%" { return kTokenMod; } 41 | "==" { return kTokenEQ; } 42 | "!=" { return kTokenNE; } 43 | "<" { return kTokenLT; } 44 | "<=" { return kTokenLE; } 45 | ">" { return kTokenGT; } 46 | ">=" { return kTokenGE; } 47 | "&&" { return kTokenAnd; } 48 | "||" { return kTokenOr; } 49 | "!" { return kTokenNot; } 50 | "(" { return kTokenLparen; } 51 | ")" { return kTokenRparen; } 52 | "{" { return kTokenLbrace; } 53 | "}" { return kTokenRbrace; } 54 | "," { return kTokenComma; } 55 | "=" { return kTokenAssign; } 56 | ";" { return kTokenSemicolon; } 57 | space { goto start; } 58 | eof { return Token( kTokenEOF ); } 59 | . 
{ std::cerr << "Discarding unexpected character '" 60 | << *begin << "'" << std::endl; } 61 | */ 62 | } 63 | -------------------------------------------------------------------------------- /Parser.cpp: -------------------------------------------------------------------------------- 1 | #include "Parser.h" 2 | #include "Exp.h" 3 | #include "FuncDef.h" 4 | #include "Program.h" 5 | #include "Stmt.h" 6 | #include "TokenStream.h" 7 | 8 | namespace { 9 | 10 | // Exceptions are used internally by the parser to simplify error checking. 11 | // Any parse error is caught by the top-level parsing routine, which reports 12 | // an error and returns an error status. 13 | class ParseError : public std::runtime_error 14 | { 15 | public: 16 | explicit ParseError( const std::string& msg ) 17 | : std::runtime_error( msg ) 18 | { 19 | } 20 | }; 21 | 22 | // Forward declarations 23 | ExpPtr parseExp( TokenStream& tokens ); 24 | std::vector parseArgs( TokenStream& tokens ); 25 | SeqStmtPtr parseSeq( TokenStream& tokens ); 26 | ExpPtr parseRemainingExp( ExpPtr leftExp, int leftPrecedence, TokenStream& tokens ); 27 | int getPrecedence( const Token& token ); 28 | 29 | 30 | // Skip the specified token, throwing ParseError if it's not present. 31 | void skipToken( const Token& expected, TokenStream& tokens ) 32 | { 33 | Token token( *tokens++ ); 34 | if( token != expected ) 35 | throw ParseError( std::string( "Expected '" ) + expected.ToString() + "'" ); 36 | } 37 | 38 | 39 | // PrimaryExp -> true | false 40 | // | Num 41 | // | Id 42 | // | Id ( Args ) 43 | // | ( Exp ) 44 | // | UnaryOp PrimaryExp 45 | ExpPtr parsePrimaryExp( TokenStream& tokens ) 46 | { 47 | // Fetch the next token, advancing the token stream. (Note that this 48 | // dereferences, then increments the TokenStream.) 49 | Token token( *tokens++ ); 50 | switch( token.GetTag() ) 51 | { 52 | // Boolean constant? 
53 | case kTokenTrue: 54 | return std::make_unique( true ); 55 | case kTokenFalse: 56 | return std::make_unique( false ); 57 | // Integer constant? 58 | case kTokenNum: 59 | return std::make_unique( token.GetNum() ); 60 | // An identifier might be a variable or the start of a function call. 61 | case kTokenId: 62 | { 63 | // If the next token is a left paren, it's a function call. 64 | if( *tokens == kTokenLparen ) 65 | // Parse argument expressions and construct CallExp. 66 | return std::make_unique( token.GetId(), parseArgs( tokens ) ); 67 | else 68 | // Construct VarExp 69 | return std::make_unique( token.GetId() ); 70 | } 71 | // Type conversion? 72 | case kTokenBool: 73 | case kTokenInt: 74 | { 75 | return std::make_unique( token.ToString(), parseArgs( tokens ) ); 76 | } 77 | // Parenthesized expression? 78 | case kTokenLparen: 79 | { 80 | skipToken( kTokenLparen, tokens ); 81 | ExpPtr exp( parseExp( tokens ) ); 82 | skipToken( kTokenRparen, tokens ); 83 | return exp; 84 | } 85 | // Prefix minus? 86 | case kTokenMinus: 87 | { 88 | Token unaryOp( *tokens++ ); 89 | ExpPtr exp( parsePrimaryExp( tokens ) ); 90 | return std::make_unique( unaryOp.ToString(), std::move( exp ) ); 91 | } 92 | default: 93 | throw ParseError( std::string( "Unexpected token: " ) + token.ToString() ); 94 | } 95 | } 96 | 97 | 98 | // Args -> ( ArgList ) 99 | // ArgList -> Exp 100 | // | Exp , ArgList 101 | std::vector parseArgs( TokenStream& tokens ) 102 | { 103 | skipToken( kTokenLparen, tokens ); 104 | std::vector exps; 105 | if( *tokens != kTokenRparen ) 106 | { 107 | exps.push_back( parseExp( tokens ) ); 108 | while( *tokens == kTokenComma ) 109 | { 110 | exps.push_back( parseExp( ++tokens ) ); 111 | } 112 | } 113 | skipToken( kTokenRparen, tokens ); 114 | 115 | return std::move( exps ); 116 | } 117 | 118 | 119 | // Parse an expression with infix operators. 120 | ExpPtr parseExp( TokenStream& tokens ) 121 | { 122 | // First, parse a primary expression, which contains no infix operators. 
123 | ExpPtr leftExp( parsePrimaryExp( tokens ) ); 124 | 125 | // The next token might be an operator. Call a helper routine 126 | // to parse the remainder of the expression. 127 | return parseRemainingExp( std::move( leftExp ), 0 /*initial precedence*/, tokens ); 128 | } 129 | 130 | // This routine implements an operator precedence expression parser. 131 | // It assembles primary expressions into call expressions based 132 | // on the precedence of the operators it encounters. For example, 133 | // "1 + 2 * 3" is parsed as "1 + (2 * 3)" because multiplication has 134 | // higher precedence than addition. 135 | // 136 | // After parsing an expression (leftExp) whose operator has the given 137 | // precedence (or zero if it has no operator), parse the remainder of 138 | // the expression from the given token stream. 139 | ExpPtr parseRemainingExp( ExpPtr leftExp, int leftPrecedence, TokenStream& tokens ) 140 | { 141 | while( true ) 142 | { 143 | // If the previous operator has higher precedence than the current one, 144 | // it claims the previously parsed expression. 145 | int precedence = getPrecedence( *tokens ); 146 | if( leftPrecedence > precedence ) 147 | return leftExp; 148 | 149 | // Parse the current operator and the current primary expression. 150 | Token opToken( *tokens++ ); 151 | ExpPtr rightExp = parsePrimaryExp( tokens ); 152 | 153 | // If the next operator has higher precedence, it claims the current expression. 154 | int rightPrecedence = getPrecedence( *tokens ); 155 | if( rightPrecedence > precedence ) 156 | { 157 | rightExp = parseRemainingExp( std::move( rightExp ), precedence + 1, tokens ); 158 | } 159 | 160 | // Construct a call expression with the left and right expressions. 161 | leftExp = std::make_unique( opToken.ToString(), 162 | std::move( leftExp ), std::move( rightExp ) ); 163 | } 164 | } 165 | 166 | // If the given token is an operator, return its precedence (from 0 to 5). 167 | // Otherwise return -1. 
168 | int getPrecedence( const Token& token ) 169 | { 170 | switch( token.GetTag() ) 171 | { 172 | case kTokenTimes: 173 | case kTokenDiv: 174 | return 5; 175 | case kTokenMod: 176 | case kTokenPlus: 177 | case kTokenMinus: 178 | return 4; 179 | case kTokenLT: 180 | case kTokenLE: 181 | case kTokenGT: 182 | case kTokenGE: 183 | return 3; 184 | case kTokenEQ: 185 | case kTokenNE: 186 | return 2; 187 | case kTokenAnd: 188 | return 1; 189 | case kTokenOr: 190 | return 0; 191 | default: 192 | return -1; 193 | } 194 | } 195 | 196 | 197 | // Type -> bool | int 198 | Type parseType( TokenStream& tokens ) 199 | { 200 | Token typeName( *tokens++ ); 201 | switch( typeName.GetTag() ) 202 | { 203 | case kTokenBool: 204 | return kTypeBool; 205 | case kTokenInt: 206 | return kTypeInt; 207 | default: 208 | throw ParseError( "Expected type name" ); 209 | } 210 | } 211 | 212 | // Parse an identifier. 213 | std::string parseId( TokenStream& tokens ) 214 | { 215 | Token id( *tokens++ ); 216 | if( id.GetTag() != kTokenId ) 217 | throw ParseError( "Invalid declaration (expected identifier)" ); 218 | return id.GetId(); 219 | } 220 | 221 | 222 | // VarDecl -> Type Id 223 | VarDeclPtr parseVarDecl( VarDecl::Kind kind, TokenStream& tokens ) 224 | { 225 | Type type( parseType( tokens ) ); 226 | std::string id( parseId( tokens ) ); 227 | return std::make_unique( kind, type, id ); 228 | } 229 | 230 | 231 | // Stmt -> Id = Exp ; 232 | // | Id ( Args ) ; 233 | // | VarDecl ; 234 | // | Seq 235 | // | return Exp ; 236 | // | if ( Exp ) Stmt 237 | // | if ( Exp ) Stmt else Stmt 238 | // | while ( Exp ) Stmt 239 | StmtPtr parseStmt( TokenStream& tokens ) 240 | { 241 | Token token( *tokens ); 242 | switch( token.GetTag() ) 243 | { 244 | case kTokenId: 245 | { 246 | Token id( *tokens++ ); 247 | if( *tokens == kTokenAssign ) 248 | { 249 | // Assignment 250 | ExpPtr rvalue( parseExp( ++tokens ) ); 251 | skipToken( kTokenSemicolon, tokens ); 252 | return std::make_unique( id.GetId(), std::move( rvalue 
) ); 253 | } 254 | else 255 | { 256 | // Call 257 | std::vector args( parseArgs( tokens ) ); 258 | CallExpPtr callExp( std::make_unique( id.GetId(), std::move( args ) ) ); 259 | skipToken( kTokenSemicolon, tokens ); 260 | return std::make_unique( std::move( callExp ) ); 261 | } 262 | } 263 | case kTokenInt: 264 | case kTokenBool: 265 | { 266 | // Declaration 267 | VarDeclPtr varDecl( parseVarDecl( VarDecl::kLocal, tokens ) ); 268 | ExpPtr initExp; 269 | if( *tokens == kTokenAssign ) 270 | { 271 | initExp = parseExp( ++tokens ); 272 | } 273 | skipToken( kTokenSemicolon, tokens ); 274 | return std::make_unique( std::move( varDecl ), std::move( initExp ) ); 275 | } 276 | case kTokenLbrace: 277 | { 278 | // Sequence 279 | return parseSeq( tokens ); 280 | } 281 | case kTokenReturn: 282 | { 283 | ++tokens; // skip "return" 284 | ExpPtr returnExp( parseExp( tokens ) ); 285 | skipToken( kTokenSemicolon, tokens ); 286 | return std::make_unique( std::move( returnExp ) ); 287 | } 288 | case kTokenIf: 289 | { 290 | ++tokens; // skip "if" 291 | skipToken( kTokenLparen, tokens ); 292 | ExpPtr condExp( parseExp( tokens ) ); 293 | skipToken( kTokenRparen, tokens ); 294 | 295 | StmtPtr thenStmt( parseStmt( tokens ) ); 296 | StmtPtr elseStmt; 297 | if( *tokens == kTokenElse ) 298 | { 299 | ++tokens; // skip "else" 300 | elseStmt = parseStmt( tokens ); 301 | } 302 | return std::make_unique( std::move( condExp ), std::move( thenStmt ), std::move( elseStmt ) ); 303 | } 304 | case kTokenWhile: 305 | { 306 | ++tokens; // skip "while" 307 | skipToken( kTokenLparen, tokens ); 308 | ExpPtr condExp( parseExp( tokens ) ); 309 | skipToken( kTokenRparen, tokens ); 310 | 311 | StmtPtr bodyStmt( parseStmt( tokens ) ); 312 | return std::make_unique( std::move( condExp ), std::move( bodyStmt ) ); 313 | } 314 | default: 315 | throw ParseError( std::string( "Unexpected token: " ) + token.ToString() ); 316 | } 317 | } 318 | 319 | // Seq -> { Stmt* } 320 | SeqStmtPtr parseSeq( TokenStream& tokens ) 321 
| { 322 | skipToken( kTokenLbrace, tokens ); 323 | std::vector stmts; 324 | while( *tokens != kTokenRbrace ) 325 | { 326 | stmts.push_back( parseStmt( tokens ) ); 327 | } 328 | skipToken( kTokenRbrace, tokens ); 329 | return std::make_unique( std::move( stmts ) ); 330 | } 331 | 332 | // FuncId -> Id | operator BinaryOp 333 | std::string parseFuncId( TokenStream& tokens ) 334 | { 335 | if( *tokens == kTokenOperator ) 336 | { 337 | ++tokens; 338 | Token op( *tokens++ ); 339 | if( !op.IsOperator() ) 340 | throw ParseError( "Invalid operator" ); 341 | return op.ToString(); 342 | } 343 | else 344 | return parseId( tokens ); 345 | } 346 | 347 | // FuncDef -> Type FuncId ( VarDecl* ) Seq 348 | FuncDefPtr parseFuncDef( TokenStream& tokens ) 349 | { 350 | // Parse return type and function id. 351 | Type returnType( parseType( tokens ) ); 352 | std::string id( parseFuncId( tokens ) ); 353 | 354 | // Parse parameter declarations 355 | skipToken( kTokenLparen, tokens ); 356 | std::vector params; 357 | if( *tokens != kTokenRparen ) 358 | { 359 | params.push_back( parseVarDecl( VarDecl::kParam, tokens ) ); 360 | while( *tokens != kTokenRparen ) 361 | { 362 | skipToken( kTokenComma, tokens ); 363 | params.push_back( parseVarDecl( VarDecl::kParam, tokens ) ); 364 | } 365 | } 366 | skipToken( kTokenRparen, tokens ); 367 | 368 | // Parse function body (if any); 369 | SeqStmtPtr body; 370 | if( *tokens == kTokenLbrace ) 371 | body = parseSeq( tokens ); 372 | else 373 | skipToken( kTokenSemicolon, tokens ); 374 | 375 | return std::make_unique( returnType, id, std::move( params ), std::move( body ) ); 376 | } 377 | 378 | } // anonymous namespace 379 | 380 | // Parse the given tokens, adding function definitions to the given program. 381 | // Returns zero for success (otherwise an error message is reported). 
382 | int ParseProgram( TokenStream& tokens, Program* program ) 383 | { 384 | try 385 | { 386 | do { 387 | FuncDefPtr function( parseFuncDef( tokens ) ); 388 | program->GetFunctions().push_back( std::move( function ) ); 389 | } while( *tokens != kTokenEOF ); 390 | return 0; 391 | } 392 | catch( const ParseError& error ) 393 | { 394 | std::cerr << "Error: " << error.what() << std::endl; 395 | return -1; 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /Parser.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | class Program; 4 | class TokenStream; 5 | 6 | /// Parse the given tokens, adding function definitions to the given program. 7 | /// Returns zero for success (otherwise an error message is reported). 8 | int ParseProgram( TokenStream& tokens, Program* program ); 9 | 10 | 11 | -------------------------------------------------------------------------------- /Printer.cpp: -------------------------------------------------------------------------------- 1 | #include "Printer.h" 2 | #include "Exp.h" 3 | #include "FuncDef.h" 4 | #include "Program.h" 5 | #include "Stmt.h" 6 | #include "Visitor.h" 7 | 8 | class ExpPrinter : public ExpVisitor 9 | { 10 | public: 11 | ExpPrinter( std::ostream& out ) 12 | : m_out( out ) 13 | { 14 | } 15 | 16 | void Print( const Exp& exp ) { const_cast( exp ).Dispatch( *this ); } 17 | 18 | void* Visit( BoolExp& exp ) override 19 | { 20 | m_out << (exp.GetValue() ? 
"true" : "false"); 21 | return nullptr; 22 | } 23 | 24 | void* Visit( IntExp& exp ) override 25 | { 26 | m_out << exp.GetValue(); 27 | return nullptr; 28 | } 29 | 30 | void* Visit( VarExp& exp ) override 31 | { 32 | m_out << exp.GetName(); 33 | return nullptr; 34 | } 35 | 36 | void* Visit( CallExp& exp ) override 37 | { 38 | m_out << exp.GetFuncName() << '('; 39 | for( size_t i = 0; i < exp.GetArgs().size(); ++i ) 40 | { 41 | if( i > 0 ) 42 | m_out << ", "; 43 | exp.GetArgs()[i]->Dispatch( *this ); 44 | } 45 | m_out << ')'; 46 | return nullptr; 47 | } 48 | 49 | private: 50 | std::ostream& m_out; 51 | }; 52 | 53 | 54 | class StmtPrinter : public StmtVisitor 55 | { 56 | public: 57 | StmtPrinter( std::ostream& out ) 58 | : m_out( out ) 59 | { 60 | } 61 | 62 | void Print( const Stmt& stmt ) { const_cast( stmt ).Dispatch( *this ); } 63 | 64 | void Visit( CallStmt& stmt ) override { m_out << stmt.GetCallExp() << ';'; } 65 | 66 | void Visit( AssignStmt& stmt ) override { m_out << stmt.GetVarName() << " = " << stmt.GetRvalue() << ';'; } 67 | 68 | void Visit( DeclStmt& stmt ) override 69 | { 70 | m_out << *stmt.GetVarDecl(); 71 | if( stmt.HasInitExp() ) 72 | m_out << " = " << stmt.GetInitExp(); 73 | m_out << ';'; 74 | } 75 | 76 | void Visit( ReturnStmt& stmt ) override { m_out << "return " << stmt.GetExp() << ';'; } 77 | 78 | void Visit( SeqStmt& seq ) override 79 | { 80 | m_out << "{" << std::endl; 81 | for( const StmtPtr& stmt : seq.Get() ) 82 | { 83 | Print( *stmt ); 84 | m_out << std::endl; 85 | } 86 | m_out << "}"; 87 | } 88 | 89 | void Visit( IfStmt& stmt ) override 90 | { 91 | m_out << "if (" << stmt.GetCondExp() << ")" << std::endl; 92 | Print( stmt.GetThenStmt() ); 93 | if( stmt.HasElseStmt() ) 94 | { 95 | m_out << std::endl << "else" << std::endl; 96 | Print( stmt.GetElseStmt() ); 97 | } 98 | } 99 | 100 | void Visit( WhileStmt& stmt ) override 101 | { 102 | m_out << "while (" << stmt.GetCondExp() << ")" << std::endl; 103 | Print( stmt.GetBodyStmt() ); 104 | } 105 
| 106 | private: 107 | std::ostream& m_out; 108 | }; 109 | 110 | std::ostream& operator<<( std::ostream& out, const Exp& exp ) 111 | { 112 | ExpPrinter( out ).Print( exp ); 113 | return out; 114 | } 115 | 116 | std::ostream& operator<<( std::ostream& out, const Stmt& stmt ) 117 | { 118 | StmtPrinter( out ).Print( stmt ); 119 | return out; 120 | } 121 | 122 | std::ostream& operator<<( std::ostream& out, const FuncDef& def ) 123 | { 124 | out << ToString( def.GetReturnType() ) << ' ' << def.GetName() << '('; 125 | for( size_t i = 0; i < def.GetParams().size(); ++i ) 126 | { 127 | if( i > 0 ) 128 | out << ", "; 129 | out << *def.GetParams()[i]; 130 | } 131 | out << ')' << std::endl; 132 | if( def.HasBody() ) 133 | out << def.GetBody(); 134 | return out; 135 | } 136 | 137 | std::ostream& operator<<( std::ostream& out, const Program& program ) 138 | { 139 | for( const FuncDefPtr& funcDef : program.GetFunctions() ) 140 | { 141 | if( funcDef->HasBody() ) 142 | out << *funcDef << std::endl; 143 | } 144 | return out; 145 | } 146 | -------------------------------------------------------------------------------- /Printer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | class Exp; 6 | class Stmt; 7 | class FuncDef; 8 | class Program; 9 | 10 | /// Output an expression. 11 | std::ostream& operator<<( std::ostream& out, const Exp& exp ); 12 | 13 | /// Output a statement. 14 | std::ostream& operator<<( std::ostream& out, const Stmt& stmt ); 15 | 16 | /// Output a function definition. 17 | std::ostream& operator<<( std::ostream& out, const FuncDef& def ); 18 | 19 | /// Output a program. 
20 | std::ostream& operator<<( std::ostream& out, const Program& program ); 21 | 22 | 23 | -------------------------------------------------------------------------------- /Program.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | class FuncDef; 6 | using FuncDefPtr = std::unique_ptr; 7 | 8 | /// Syntax for a program, which is simply a vector of function definitions. 9 | class Program 10 | { 11 | public: 12 | const std::vector& GetFunctions() const { return m_functions; } 13 | 14 | std::vector& GetFunctions() { return m_functions; } 15 | 16 | private: 17 | std::vector m_functions; 18 | }; 19 | 20 | using ProgramPtr = std::unique_ptr; 21 | 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # One Weekend Compiler 3 | 4 | The LLVM compiler toolkit makes it easy to implement a compiler, and the [LLVM 5 | Tutorial](https://llvm.org/docs/tutorial/) (called "Kaleidoscope") is a good 6 | place to start. 7 | 8 | This example compiler provides an alternative introduction, with a bit more 9 | focus on design details like the use of a visitor pattern for syntax traversal. 10 | It consists of the following components: 11 | 12 | - A regexp-based lexer, employing [re2c](http://re2c.org/) to generate an efficient state machine. 13 | 14 | - A recursive-descent parser with an operator-precedence strategy for parsing 15 | expressions with infix operators. 16 | 17 | - Well-engineered abstract syntax classes that provide a good foundation for extending the source language. 18 | 19 | - A typechecker that supports simple overloading (without implicit type 20 | conversions). A key aspect of the typechecker is that it resolves lexical 21 | scoping, linking variable references and function calls to the corresponding 22 | definitions. 
This allows subsequent passes to operate without any knowledge 23 | of scoping rules. 24 | 25 | - A simple code generator that constructs LLVM IR. 26 | 27 | - Optimization and JIT code generation using off-the-shelf LLVM components. 28 | 29 | # Grammar 30 | 31 | The source language is a subset of C with the following grammar. 32 | 33 | ``` 34 | Prog -> FuncDef+ 35 | 36 | FuncDef -> Type FuncId ( VarDecl* ) Seq 37 | 38 | Type -> bool | int 39 | 40 | FuncId -> Id | operator BinaryOp 41 | 42 | VarDecl -> Type Id 43 | 44 | Seq -> { Stmt* } 45 | 46 | Stmt -> Id = Exp ; 47 | | Id ( Args ) ; 48 | | VarDecl ; 49 | | Seq 50 | | return Exp ; 51 | | if ( Exp ) Stmt 52 | | if ( Exp ) Stmt else Stmt 53 | | while ( Exp ) Stmt 54 | 55 | Args -> Exp 56 | | Exp , Args 57 | 58 | Exp -> true | false 59 | | Num 60 | | Id 61 | | Id ( Args ) 62 | | ( Exp ) 63 | | UnaryOp Exp 64 | | Exp BinaryOp Exp 65 | 66 | UnaryOp -> - | ! 67 | BinaryOp -> * | / | % 68 | | + | - 69 | | < | <= | > | >= 70 | | == | != 71 | ``` 72 | 73 | Notation: 74 | - `Prog -> FuncDef+` indicates that a program consists of one or more function definitions. 75 | - `Exp -> true | false | ...` indicates that an expression can be a `true` or `false` keyword, etc. 76 | - `VarDecl*` indicates zero or more variable declarations. 77 | - Other punctuation characters are program literals (e.g. parentheses, braces, semicolon, comma). 78 | 79 | Example: 80 | ``` 81 | int main(int x) 82 | { 83 | int sum = 0; 84 | int i = 1; 85 | while (i <= x) 86 | { 87 | sum = sum + i; 88 | i = i + 1; 89 | } 90 | } 91 | 92 | ``` 93 | 94 | # Source files 95 | 96 | Here is an overview of the source files: 97 | 98 | - `main.cpp`: calls the lexer, parser, typechecker, and code generator 99 | - `Token.h`: lexical tokens, e.g. constants, identifiers, and keywords. 100 | - `Lexer.re`: regular expressions for lexical tokens (compiled by re2c) 101 | - `TokenStream.h`: adapter that calls Lexer to produce a stream of tokens. 
102 | - `Parser.cpp`: recursive descent parser, which reads token stream and produces a syntax tree. 103 | - `Exp.h Stmt.h VarDecl.h FuncDef.h Program.h`: syntax trees for expressions, statements, functions, etc. 104 | - `Visitor.h`: visitor pattern for syntax traversal 105 | - `Printer.h`: print syntax tree using Visitor 106 | - `Typechecker.h`: a typechecker that supports overloading. 107 | - `Scope.h`: scoped symbol table used by the typechecker. 108 | - `Builtins.h`: declarations of built-in operators 109 | - `Codegen.cpp`: generates LLVM IR from syntax tree 110 | - `SimpleJIT.h`: encapsulates LLVM ORC JIT engine 111 | 112 | # Building 113 | 114 | Prerequisites: 115 | 116 | - CMake 3.1+: https://cmake.org/download/ 117 | - LLVM 7.0: http://releases.llvm.org/download.html 118 | - re2c 1.2: https://github.com/skvadrik/re2c/releases 119 | 120 | - (Note that building re2c under Windows is easiest via cygwin or mingw.) 121 | 122 | The Weekend Compiler project is built with CMake. From the command line, the locations of LLVM and re2c can be specified as follows: 123 | 124 | cmake -D LLVM_ROOT_DIR=C:/LLVM-7.1 -D RE2C_EXE=C:/re2c-1.2.1/bin/re2c.exe 125 | 126 | Alternatively these locations can be specified via cmake-gui. 127 | -------------------------------------------------------------------------------- /Scope.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "VarDecl.h" 4 | #include 5 | #include 6 | 7 | /// The typechecker uses a Scope to resolve lexical scoping. A Scope maps 8 | /// variable names to variable declarations. A Scope has a parent pointer 9 | /// that links to the enclosing scope. This allows local variables to shadow 10 | /// function parameters, etc. 11 | class Scope 12 | { 13 | public: 14 | /// Construct a scope with an optional parent scope. 
15 | explicit Scope( const Scope* parent = nullptr ) 16 | : m_map() 17 | , m_parent( parent ) 18 | { 19 | } 20 | 21 | /// Look up the variable with the specified name, delegating to the parent 22 | /// scope if not found. Returns null if the variable is not defined. 23 | const VarDecl* Find( const std::string& name ) const 24 | { 25 | MapType::const_iterator it = m_map.find( name ); 26 | if( it != m_map.end() ) 27 | return it->second; 28 | else 29 | return m_parent ? m_parent->Find( name ) : nullptr; 30 | } 31 | 32 | /// Add the given variable declaration to this scope. The variable 33 | /// declaration might be a function parameter or a local variable. 34 | /// Returns true for success, or false if the variable is already defined 35 | /// in this scope. (Note that a variable can be shadowed in an enclosing 36 | /// scope, but it cannot be declared twice in the same scope.) 37 | bool Insert( const VarDecl* varDecl ) 38 | { 39 | return m_map.insert( MapType::value_type( varDecl->GetName(), varDecl ) ).second; 40 | } 41 | 42 | private: 43 | using MapType = std::unordered_map; 44 | MapType m_map; 45 | const Scope* m_parent; 46 | }; 47 | 48 | 49 | -------------------------------------------------------------------------------- /SimpleJIT.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | using namespace llvm; 19 | using namespace llvm::orc; 20 | 21 | /// A simple JIT engine that encapsulates the LLVM ORC JIT API. 22 | /// (JIT = Just In Time, ORC = On Request Compilation) 23 | /// Adapted from the LLVM KaleidoscopeJIT example. 24 | class SimpleJIT { 25 | public: 26 | /// Construct JIT engine, initializing the resolver, object layer, and compile layer. 
27 | SimpleJIT() : 28 | m_initialized( init() ), 29 | m_resolver 30 | (createLegacyLookupResolver 31 | ( m_session, 32 | [this](const std::string& name) { 33 | return m_objectLayer.findSymbol(name, true); 34 | }, 35 | [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })), 36 | m_target(EngineBuilder().selectTarget()), m_dataLayout(m_target->createDataLayout()), 37 | m_objectLayer(m_session, 38 | [this](VModuleKey) 39 | { return ObjLayerT::Resources 40 | { std::make_shared(), m_resolver}; 41 | }), 42 | m_compileLayer(m_objectLayer, SimpleCompiler(*m_target)) 43 | { 44 | llvm::sys::DynamicLibrary::LoadLibraryPermanently( nullptr ); 45 | } 46 | 47 | /// Get the TargetMachine, which can be used for target-specific optimizations. 48 | TargetMachine& getTargetMachine() { return *m_target; } 49 | 50 | /// Add the given module to the JIT engine, yielding a key that can be 51 | /// used for subsequent symbol lookups. 52 | VModuleKey addModule( std::unique_ptr module ) 53 | { 54 | VModuleKey key = m_session.allocateVModule(); 55 | cantFail( m_compileLayer.addModule( key, std::move( module ) ) ); 56 | return key; 57 | } 58 | 59 | /// Remove the module with the specified key from the JIT engine. 60 | void removeModule( VModuleKey key ) 61 | { 62 | cantFail( m_compileLayer.removeModule( key ) ); 63 | } 64 | 65 | /// Find the specified symbol in the module with the given key. 66 | JITSymbol findSymbol( VModuleKey key, const std::string name ) 67 | { 68 | return m_compileLayer.findSymbolIn( key, name, false /*ExportedSymbolsOnly*/ ); 69 | } 70 | 71 | private: 72 | using ObjLayerT = RTDyldObjectLinkingLayer; 73 | using CompileLayerT = IRCompileLayer; 74 | 75 | bool m_initialized; 76 | ExecutionSession m_session; 77 | std::shared_ptr m_resolver; 78 | std::unique_ptr m_target; 79 | const DataLayout m_dataLayout; 80 | ObjLayerT m_objectLayer; 81 | CompileLayerT m_compileLayer; 82 | 83 | // Perform prerequisite initialization. 
84 | bool init() 85 | { 86 | InitializeNativeTarget(); 87 | InitializeNativeTargetAsmPrinter(); 88 | InitializeNativeTargetAsmParser(); 89 | return true; 90 | } 91 | 92 | }; 93 | -------------------------------------------------------------------------------- /Stmt.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Exp.h" 4 | #include "VarDecl.h" 5 | 6 | /// Base class for statement syntax. 7 | class Stmt 8 | { 9 | public: 10 | /// The destructor is virtual, ensuring that the destructor for a derived 11 | /// class will be properly invoked. 12 | virtual ~Stmt() {} 13 | 14 | /// Dispatch to a visitor. \see StmtVisitor 15 | virtual void Dispatch( StmtVisitor& visitor ) = 0; 16 | }; 17 | 18 | /// Unique pointer to statement. 19 | using StmtPtr = std::unique_ptr; 20 | 21 | 22 | /// Function call statement, which simply holds a CallExp. 23 | class CallStmt : public Stmt 24 | { 25 | public: 26 | /// Construct from the given function call expression. 27 | CallStmt( CallExpPtr&& callExp ) 28 | : m_callExp( std::move( callExp ) ) 29 | { 30 | } 31 | 32 | /// Get the function call expression. 33 | const CallExp& GetCallExp() const { return *m_callExp; } 34 | 35 | /// Dispatch to a visitor. 36 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 37 | 38 | private: 39 | CallExpPtr m_callExp; 40 | }; 41 | 42 | 43 | // Assignment statement. 44 | class AssignStmt : public Stmt 45 | { 46 | public: 47 | /// Construct assignment statement. The lvalue is a variable, and the 48 | /// rvalue is an arbitrary expression. 49 | AssignStmt( const std::string& varName, ExpPtr&& rvalue ) 50 | : m_varName( varName ) 51 | , m_rvalue( std::move( rvalue ) ) 52 | { 53 | } 54 | 55 | /// Get the variable name (lvalue). 56 | const std::string& GetVarName() const { return m_varName; } 57 | 58 | /// Get the rvalue (the right-hand side of the assignment). 
59 | const Exp& GetRvalue() const { return *m_rvalue; } 60 | 61 | /// Get the declaration of the assigned variable (null until typechecked). 62 | const VarDecl* GetVarDecl() const { return m_varDecl; } 63 | 64 | /// Link the assignment to the declaration of the assigned variable 65 | /// (called by the typechecker). 66 | void SetVarDecl( const VarDecl* varDecl ) { m_varDecl = varDecl; } 67 | 68 | /// Dispatch to a visitor. 69 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 70 | 71 | private: 72 | std::string m_varName; 73 | ExpPtr m_rvalue; 74 | const VarDecl* m_varDecl; 75 | }; 76 | 77 | 78 | /// A declaration statement (e.g. "int x = 0;") declares a local variable with 79 | /// an optional initializer. 80 | class DeclStmt : public Stmt 81 | { 82 | public: 83 | /// Construct a declaration statement from the specified variable declaration 84 | /// and optional initializer expression. 85 | DeclStmt( VarDeclPtr&& varDecl, ExpPtr&& initExp = ExpPtr() ) 86 | : m_varDecl( std::move( varDecl ) ) 87 | , m_initExp( std::move( initExp ) ) 88 | { 89 | } 90 | 91 | /// Get pointer to variable declaration, which is stored at use sites by the typechecker. 92 | const VarDecl* GetVarDecl() const { return m_varDecl.get(); } 93 | 94 | /// Check whether this declaration has an initializer expression. 95 | bool HasInitExp() const { return bool( m_initExp ); } 96 | 97 | /// Get the initializer expression. Check HasInitExp() before calling. 98 | const Exp& GetInitExp() const 99 | { 100 | assert( HasInitExp() && "Expected initializer expression in variable declaration" ); 101 | return *m_initExp; 102 | } 103 | 104 | /// Dispatch to a visitor. 105 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 106 | 107 | private: 108 | VarDeclPtr m_varDecl; 109 | ExpPtr m_initExp; 110 | }; 111 | 112 | 113 | /// Return statement. 
114 | class ReturnStmt : public Stmt 115 | { 116 | public: 117 | /// Construct return statement with the given return value expression. 118 | /// (Note that void functions are not permitted, so the return value is required.) 119 | ReturnStmt( ExpPtr&& exp ) 120 | : m_exp( std::move( exp ) ) 121 | { 122 | } 123 | 124 | /// Get the return value expression. 125 | const Exp& GetExp() const { return *m_exp; } 126 | 127 | /// Dispatch to a visitor. 128 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 129 | 130 | private: 131 | ExpPtr m_exp; 132 | }; 133 | 134 | 135 | /// A sequence of statements. 136 | class SeqStmt : public Stmt 137 | { 138 | public: 139 | /// Construct sequence of statements from a vector of unique pointers. 140 | SeqStmt( std::vector&& stmts ) 141 | : m_stmts( std::move( stmts ) ) 142 | { 143 | } 144 | 145 | /// Get the sequence of statements. 146 | const std::vector& Get() const { return m_stmts; } 147 | 148 | /// Dispatch to a visitor. 149 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 150 | 151 | private: 152 | std::vector m_stmts; 153 | }; 154 | 155 | /// Unique pointer to a sequence of statements. 156 | using SeqStmtPtr = std::unique_ptr; 157 | 158 | 159 | /// If statement syntax. 160 | class IfStmt : public Stmt 161 | { 162 | public: 163 | /// Construct "if" statement with conditional expression, "then" statement 164 | /// (which might be a sequence), and an optional "else" statement. 165 | IfStmt( ExpPtr condExp, StmtPtr thenStmt, StmtPtr elseStmt = StmtPtr() ) 166 | : m_condExp( std::move( condExp ) ) 167 | , m_thenStmt( std::move( thenStmt ) ) 168 | , m_elseStmt( std::move( elseStmt ) ) 169 | { 170 | } 171 | 172 | /// Get the conditional expression. 173 | const Exp& GetCondExp() const { return *m_condExp; } 174 | 175 | /// Get the "then" statement, which might be a sequence. 
176 | const Stmt& GetThenStmt() const { return *m_thenStmt; } 177 | 178 | /// Check whether this "if" statement has an "else" statement. 179 | bool HasElseStmt() const { return bool( m_elseStmt ); } 180 | 181 | /// Get the "else" statement. 182 | const Stmt& GetElseStmt() const 183 | { 184 | assert( HasElseStmt() && "Expected else statement" ); 185 | return *m_elseStmt; 186 | } 187 | 188 | /// Dispatch to a visitor. 189 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 190 | 191 | private: 192 | ExpPtr m_condExp; 193 | StmtPtr m_thenStmt; 194 | StmtPtr m_elseStmt; 195 | }; 196 | 197 | 198 | /// While statement. 199 | class WhileStmt : public Stmt 200 | { 201 | public: 202 | /// Construct while statement from a conditional expression and the loop body 203 | /// statement (which might be a sequence). 204 | WhileStmt( ExpPtr condExp, StmtPtr bodyStmt ) 205 | : m_condExp( std::move( condExp ) ) 206 | , m_bodyStmt( std::move( bodyStmt ) ) 207 | { 208 | } 209 | 210 | /// Get the conditional expression. 211 | const Exp& GetCondExp() const { return *m_condExp; } 212 | 213 | /// Get the loop body statement (which might be a sequence). 214 | const Stmt& GetBodyStmt() const { return *m_bodyStmt; } 215 | 216 | /// Dispatch to a visitor. 
217 | void Dispatch( StmtVisitor& visitor ) override { visitor.Visit( *this ); } 218 | 219 | private: 220 | ExpPtr m_condExp; 221 | StmtPtr m_bodyStmt; 222 | }; 223 | 224 | 225 | -------------------------------------------------------------------------------- /Syntax.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // Forward declarations for syntax classes 6 | 7 | class Exp; 8 | class BoolExp; 9 | class IntExp; 10 | class CallExp; 11 | class VarExp; 12 | 13 | class Stmt; 14 | class AssignStmt; 15 | class CallStmt; 16 | class DeclStmt; 17 | class IfStmt; 18 | class ReturnStmt; 19 | class SeqStmt; 20 | class WhileStmt; 21 | 22 | class FuncDef; 23 | class Program; 24 | class VarDecl; 25 | 26 | using ExpPtr = std::unique_ptr; 27 | using CallExpPtr = std::unique_ptr; 28 | using FuncDefPtr = std::unique_ptr; 29 | using ProgramPtr = std::unique_ptr; 30 | using SeqStmtPtr = std::unique_ptr; 31 | using StmtPtr = std::unique_ptr; 32 | using VarDeclPtr = std::unique_ptr; 33 | 34 | 35 | -------------------------------------------------------------------------------- /Token.cpp: -------------------------------------------------------------------------------- 1 | #include "Token.h" 2 | #include 3 | 4 | std::string Token::ToString() const 5 | { 6 | switch( GetTag() ) 7 | { 8 | case kTokenNum: 9 | { 10 | std::stringstream stream; 11 | stream << GetNum(); 12 | return stream.str(); 13 | } 14 | case kTokenId: return GetId(); 15 | case kTokenBool: return "bool"; 16 | case kTokenTrue: return "true"; 17 | case kTokenFalse: return "false"; 18 | case kTokenInt: return "int"; 19 | case kTokenIf: return "if"; 20 | case kTokenElse: return "else"; 21 | case kTokenReturn: return "return"; 22 | case kTokenWhile: return "while"; 23 | case kTokenOperator: return "operator"; 24 | case kTokenPlus: return "+"; 25 | case kTokenMinus: return "-"; 26 | case kTokenTimes: return "*"; 27 | case kTokenDiv: return "/"; 28 | case 
kTokenMod: return "%"; 29 | case kTokenEQ: return "=="; 30 | case kTokenNE: return "!="; 31 | case kTokenLT: return "<"; 32 | case kTokenLE: return "<="; 33 | case kTokenGT: return ">"; 34 | case kTokenGE: return ">="; 35 | case kTokenAnd: return "&&"; 36 | case kTokenOr: return "||"; 37 | case kTokenNot: return "!"; 38 | case kTokenLbrace: return "{"; 39 | case kTokenRbrace: return "}"; 40 | case kTokenLparen: return "("; 41 | case kTokenRparen: return ")"; 42 | case kTokenComma: return ","; 43 | case kTokenAssign: return "="; 44 | case kTokenSemicolon: return ";"; 45 | case kTokenEOF: return ""; 46 | } 47 | assert(false && "Unhandled token kind"); 48 | return ""; 49 | } 50 | -------------------------------------------------------------------------------- /Token.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | /// Tag of lexical token, e.g. integer, id, keyword, or punctuation. 9 | enum TokenTag 10 | { 11 | // Value-carrying tokens: 12 | kTokenNum, 13 | kTokenId, 14 | 15 | // Keywords: 16 | kTokenBool, 17 | kTokenTrue, 18 | kTokenFalse, 19 | kTokenInt, 20 | kTokenIf, 21 | kTokenElse, 22 | kTokenReturn, 23 | kTokenWhile, 24 | kTokenOperator, 25 | 26 | // Operators: 27 | kTokenPlus, 28 | kTokenMinus, 29 | kTokenTimes, 30 | kTokenDiv, 31 | kTokenMod, 32 | kTokenEQ, 33 | kTokenNE, 34 | kTokenLT, 35 | kTokenLE, 36 | kTokenGT, 37 | kTokenGE, 38 | kTokenAnd, 39 | kTokenOr, 40 | kTokenNot, 41 | 42 | // Punctuation: 43 | kTokenLbrace, 44 | kTokenRbrace, 45 | kTokenLparen, 46 | kTokenRparen, 47 | kTokenComma, 48 | kTokenAssign, 49 | kTokenSemicolon, 50 | kTokenEOF 51 | }; 52 | 53 | /// The lexer converts sequences of characters into tokens. A token has a tag 54 | /// (e.g. integer vs. id) and a value (e.g integer value or identifier name). 55 | class Token 56 | { 57 | public: 58 | /// Construct an integer token. 
59 | explicit Token( int value ) 60 | : m_tag( kTokenNum ) 61 | , m_int( value ) 62 | { 63 | } 64 | 65 | /// Construct an identifier token. 66 | explicit Token( const std::string& id ) 67 | : m_tag( kTokenId ) 68 | , m_id( id ) 69 | { 70 | } 71 | 72 | /// Construct a non-value-carrying token. Implicit conversion is allowed. 73 | Token( TokenTag tag ) 74 | : m_tag( tag ) 75 | { 76 | assert( tag != kTokenNum && tag != kTokenId && "Value required for integer and id tokens" ); 77 | } 78 | 79 | /// Get the token's tag. 80 | TokenTag GetTag() const { return m_tag; } 81 | 82 | /// Get the value of a numeric token. 83 | int GetNum() const 84 | { 85 | assert( GetTag() == kTokenNum && "Expected numeric token" ); 86 | return m_int; 87 | } 88 | 89 | /// Get identifier. 90 | const std::string& GetId() const 91 | { 92 | assert( GetTag() == kTokenId && "Expected identifier token" ); 93 | return m_id; 94 | } 95 | 96 | /// Get token text, e.g. operator name. 97 | std::string ToString() const; 98 | 99 | /// Equality considers token value for numeric and identifier tokens. 100 | bool operator==( const Token& other ) 101 | { 102 | if( GetTag() != other.GetTag() ) 103 | return false; 104 | if( GetTag() == kTokenNum ) 105 | return GetNum() == other.GetNum(); 106 | if( GetTag() == kTokenId ) 107 | return GetId() == other.GetId(); 108 | return true; 109 | } 110 | 111 | /// Inequality operator. 112 | bool operator!=( const Token& other ) { return !( *this == other ); } 113 | 114 | /// Check whether this token can be preceded by an "operator" keyword 115 | /// (e.g. operator+). 
This allows operators to be defined by ordinary 116 | /// function definitions (see Builtins.h) 117 | bool IsOperator() const 118 | { 119 | switch( GetTag() ) 120 | { 121 | case kTokenPlus: 122 | case kTokenMinus: 123 | case kTokenTimes: 124 | case kTokenDiv: 125 | case kTokenMod: 126 | case kTokenEQ: 127 | case kTokenNE: 128 | case kTokenLT: 129 | case kTokenLE: 130 | case kTokenGT: 131 | case kTokenGE: 132 | case kTokenAnd: 133 | case kTokenOr: 134 | case kTokenNot: 135 | case kTokenBool: 136 | case kTokenInt: 137 | return true; 138 | default: 139 | return false; 140 | } 141 | } 142 | 143 | private: 144 | TokenTag m_tag; // Tag of token, e.g. int, id, keyword. 145 | int m_int; // Integer value, if tag is kTokenNum. 146 | std::string m_id; // Identifier value, if tag is kTokenId. 147 | }; 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /TokenStream.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Lexer.h" 4 | 5 | /// The Lexer returns a single token. This class wraps the Lexer to provide a 6 | /// stream-like interface to the Parser. A single token of lookahead is 7 | /// provided (via operator*), and the token stream can be advanced using the 8 | /// increment operator. For example, a token is typically consumed via 9 | /// "Token token(*tokens++);" (Note that the dereference operator has higher 10 | /// precedence than the increment operator.) 11 | class TokenStream 12 | { 13 | public: 14 | /// Construct token stream for the given source code, which must be null 15 | /// terminated. 16 | TokenStream( const char* source ) 17 | : m_source( source ) 18 | , m_token( kTokenEOF ) 19 | { 20 | ++*this; // Lex the first token 21 | } 22 | 23 | /// Inspect the next token, without advancing the token stream. 24 | Token operator*() { return m_token; } 25 | 26 | /// Advance the token stream, calling the Lexer to obtain the next token. 
27 | TokenStream& operator++() 28 | { 29 | m_token = Lexer( m_source ); 30 | return *this; 31 | } 32 | 33 | /// Postfix increment operator. 34 | TokenStream operator++( int ) 35 | { 36 | TokenStream before( *this ); 37 | this->operator++(); 38 | return before; 39 | } 40 | 41 | private: 42 | const char* m_source; 43 | Token m_token; 44 | }; 45 | 46 | 47 | -------------------------------------------------------------------------------- /Type.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | /// Only bool and int types are supported, so a Type is simply an enum value. 6 | /// Supporting arrays and structs would require a structured representation. 7 | enum Type 8 | { 9 | kTypeUnknown, 10 | kTypeBool, 11 | kTypeInt 12 | }; 13 | 14 | /// Convert the given type to a string. 15 | inline const char* ToString( Type type ) 16 | { 17 | switch( type ) 18 | { 19 | case kTypeUnknown: 20 | return ""; 21 | case kTypeBool: 22 | return "bool"; 23 | case kTypeInt: 24 | return "int"; 25 | } 26 | assert( false && "Unhandled type" ); 27 | return ""; 28 | } 29 | 30 | 31 | -------------------------------------------------------------------------------- /Typechecker.cpp: -------------------------------------------------------------------------------- 1 | #include "Typechecker.h" 2 | #include "Exp.h" 3 | #include "FuncDef.h" 4 | #include "Program.h" 5 | #include "Scope.h" 6 | #include "Stmt.h" 7 | #include "VarDecl.h" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace { 14 | 15 | // The function table is a multimap, mapping function names to overloaded definitions. 16 | using FuncTable = std::multimap; 17 | 18 | // Exceptions are used internally by the typechecker, but they do not 19 | // propagate beyond the top-level typechecking routine. 
20 | class TypeError : public std::runtime_error 21 | { 22 | public: 23 | explicit TypeError( const std::string& msg ) 24 | : std::runtime_error( msg ) 25 | { 26 | } 27 | }; 28 | 29 | // The expression typechecker is a visitor. It holds a Scope, which maps 30 | // variable names to their declarations, and a function table, which maps 31 | // function names to definitions. The typechecker decorates each expression 32 | // with its type, and it resolves lexical scoping, linking variable references 33 | // and function calls to the corresponding definitions. This allows 34 | // subsequent passes (e.g. Codegen) to operate without any knowledge of 35 | // scoping rules. 36 | class ExpTypechecker : public ExpVisitor 37 | { 38 | public: 39 | // Construct typecheck from scope and function table. 40 | ExpTypechecker( const Scope& scope, const FuncTable& funcTable ) 41 | : m_scope( scope ) 42 | , m_funcTable( funcTable ) 43 | { 44 | } 45 | 46 | // Helper routine to typecheck a subexpression. The visitor operates on 47 | // non-const expressions, so we must const_cast when dispatching. 48 | void Check( const Exp& exp ) { const_cast( exp ).Dispatch( *this ); } 49 | 50 | // Typecheck a boolean constant. 51 | void* Visit( BoolExp& exp ) override 52 | { 53 | assert( exp.GetType() == kTypeBool ); 54 | return nullptr; 55 | } 56 | 57 | // Typecheck an integer constant. 58 | void* Visit( IntExp& exp ) override 59 | { 60 | assert( exp.GetType() == kTypeInt ); 61 | return nullptr; 62 | } 63 | 64 | // Typecheck a variable reference. 65 | void* Visit( VarExp& exp ) override 66 | { 67 | // Look up the variable name in the current scope. 68 | const VarDecl* decl = m_scope.Find( exp.GetName() ); 69 | if( decl ) 70 | { 71 | // Set the expression type to the type specified in the declaration. 72 | exp.SetType( decl->GetType() ); 73 | 74 | // Link the variable use to its declaration. This allows subsequent passes 75 | // to operate without knowledge of scoping. 
76 | exp.SetVarDecl( decl ); 77 | } 78 | else 79 | throw TypeError( std::string( "Undefined variable: " ) + exp.GetName() ); 80 | return nullptr; 81 | } 82 | 83 | // Typecheck a function call. 84 | void* Visit( CallExp& exp ) override 85 | { 86 | // Typecheck the arguments. 87 | const std::vector& args = exp.GetArgs(); 88 | for( const ExpPtr& arg : args ) 89 | Check( *arg ); 90 | 91 | // Look up the function definition, which might be overloaded. 92 | const std::string& funcName = exp.GetFuncName(); 93 | const FuncDef* funcDef = findFunc( funcName, args ); 94 | if( !funcDef ) 95 | // TODO: better error message, including candidates. 96 | throw TypeError( std::string( "No match for function: " ) + funcName ); 97 | 98 | // Set expression type and link it to the function definition. 99 | exp.SetType( funcDef->GetReturnType() ); 100 | exp.SetFuncDef( funcDef ); 101 | return nullptr; 102 | } 103 | 104 | private: 105 | const Scope& m_scope; 106 | const FuncTable& m_funcTable; 107 | 108 | // Find a (possibly overloaded) function definition with the specified 109 | // name whose parameters match the types of the given arguments. 110 | // TODO: generalize this and use it to check for duplicate definitions. 111 | const FuncDef* findFunc( std::string name, const std::vector& args ) const 112 | { 113 | auto range = m_funcTable.equal_range( name ); 114 | for( auto it = range.first; it != range.second; ++it ) 115 | { 116 | const FuncDef* funcDef = it->second; 117 | if (argsMatch(funcDef->GetParams(), args)) 118 | return funcDef; 119 | } 120 | return nullptr; 121 | } 122 | 123 | // Check whether the given function parameters match the given argument types. 
124 | static bool argsMatch( const std::vector& params, const std::vector& args ) 125 | { 126 | if( params.size() != args.size() ) 127 | return false; 128 | for( size_t i = 0; i < params.size(); ++i ) 129 | { 130 | if( params[i]->GetType() != args[i]->GetType() ) 131 | return false; 132 | } 133 | return true; 134 | } 135 | }; 136 | 137 | 138 | // The statement typechecker holds a scope and a function table, along with a pointer to the current function 139 | // (for typechecking return statements). The scope is extended as nested lexical scopes are encountered. 140 | class StmtTypechecker : public StmtVisitor 141 | { 142 | public: 143 | StmtTypechecker( Scope* scope, const FuncTable& funcTable, const FuncDef& enclosingFunction ) 144 | : m_scope( scope ) 145 | , m_funcTable( funcTable ) 146 | , m_enclosingFunction( enclosingFunction ) 147 | { 148 | } 149 | 150 | 151 | // Helper routine to typecheck a sub-statement. The visitor operates on 152 | // non-const expressions, so we must const_cast when dispatching. 153 | void CheckStmt( const Stmt& stmt ) { const_cast( stmt ).Dispatch( *this ); } 154 | 155 | // Helper routine to typecheck an expression. We construct an expression 156 | // typechecker on the fly (which is cheap) that contains the current scope 157 | // and function table. 158 | void CheckExp( const Exp& exp ) const { ExpTypechecker( *m_scope, m_funcTable ).Check( exp ); } 159 | 160 | // Typecheck a function call statement. 161 | void Visit( CallStmt& stmt ) override { CheckExp( stmt.GetCallExp() ); } 162 | 163 | // Typecheck an assignment statement. 164 | void Visit( AssignStmt& stmt ) override 165 | { 166 | // Check the rvalue (the right hand side). 167 | CheckExp( stmt.GetRvalue() ); 168 | 169 | // Look up the declaration of the variable on the left hand side of the assignment. 
170 | const std::string& varName = stmt.GetVarName(); 171 | const VarDecl* varDecl = m_scope->Find( varName ); 172 | if( !varDecl ) 173 | throw TypeError( std::string( "Undefined variable in assignment: " ) + varName ); 174 | 175 | // Check that the type of the rvalue matches the lvalue. 176 | if( varDecl->GetType() != stmt.GetRvalue().GetType() ) 177 | throw TypeError( std::string( "Type mismatch in assignment to " ) + varName ); 178 | 179 | // Prohibit assignment to function parameters. 180 | if( varDecl->GetKind() != VarDecl::kLocal ) 181 | throw TypeError( std::string( "Expected local variable in assignment to " ) + varName ); 182 | 183 | // Link the assignment to the variable declaration. 184 | stmt.SetVarDecl( varDecl ); 185 | } 186 | 187 | // Typecheck a declaration statement (e.g. "int x = 1;") 188 | void Visit( DeclStmt& stmt ) override 189 | { 190 | // Add the variable declaration to the current scope. Declaring the same variable twice in 191 | // a given scope is prohibited. 192 | const VarDecl* varDecl = stmt.GetVarDecl(); 193 | const std::string& varName = varDecl->GetName(); 194 | if( !m_scope->Insert( varDecl ) ) 195 | throw TypeError( std::string( "Variable already defined in this scope: " ) + varName ); 196 | 197 | // Typecheck the initializer expression (if any) and verify that its type matches the declaration. 198 | if( stmt.HasInitExp() ) 199 | { 200 | CheckExp( stmt.GetInitExp() ); 201 | if( stmt.GetInitExp().GetType() != varDecl->GetType() ) 202 | throw TypeError( std::string( "Type mismatch in initialization of " ) + varName ); 203 | } 204 | 205 | } 206 | 207 | // Typecheck a return statement, ensuring that the type of the return value matches the current 208 | // function definition. 
209 | void Visit( ReturnStmt& stmt ) override 210 | { 211 | CheckExp( stmt.GetExp() ); 212 | if( stmt.GetExp().GetType() != m_enclosingFunction.GetReturnType() ) 213 | throw TypeError( "Type mismatch in return statement" ); 214 | } 215 | 216 | // Typecheck a sequence of statements in a nested lexical scope. 217 | void Visit( SeqStmt& seq ) override 218 | { 219 | // Create a nested scope for any local variable declarations, saving the parent scope. 220 | Scope* parentScope = m_scope; 221 | Scope localScope( parentScope ); 222 | m_scope = &localScope; 223 | 224 | // Typecheck each statement in the sequence 225 | for( const StmtPtr& stmt : seq.Get() ) 226 | { 227 | CheckStmt( *stmt ); 228 | } 229 | 230 | // Restore the parent scope. 231 | m_scope = parentScope; 232 | } 233 | 234 | // Typecheck an "if" statement. 235 | void Visit( IfStmt& stmt ) override 236 | { 237 | CheckCondExp( stmt.GetCondExp() ); 238 | CheckStmt( stmt.GetThenStmt() ); 239 | if( stmt.HasElseStmt() ) 240 | CheckStmt( stmt.GetElseStmt() ); 241 | } 242 | 243 | // Typechek a while loop. 244 | void Visit( WhileStmt& stmt ) override 245 | { 246 | CheckCondExp( stmt.GetCondExp() ); 247 | CheckStmt( stmt.GetBodyStmt() ); 248 | } 249 | 250 | // Typecheck the conditional expression from an "if" statement or while 251 | // loop, ensuring that it has type bool or int. 252 | void CheckCondExp( const Exp& exp) 253 | { 254 | CheckExp( exp ); 255 | switch (exp.GetType()) 256 | { 257 | case kTypeBool: 258 | case kTypeInt: 259 | return; 260 | default: 261 | throw TypeError( "Expected integer condition expression" ); 262 | } 263 | } 264 | 265 | private: 266 | Scope* m_scope; 267 | const FuncTable& m_funcTable; 268 | const FuncDef& m_enclosingFunction; 269 | }; 270 | 271 | 272 | // Typecheck a function definition, adding it to the given function table. 
273 | void checkFunction( FuncDef* funcDef, FuncTable* funcTable ) 274 | { 275 | // To permit recursion, we add the definition to the function table 276 | // before typechecking the body. TODO: check for duplicate definitions. 277 | funcTable->insert( FuncTable::value_type( funcDef->GetName(), funcDef ) ); 278 | 279 | // Construct a scope and add the function parameters. 280 | Scope scope; 281 | for( const VarDeclPtr& param : funcDef->GetParams() ) 282 | { 283 | if( !scope.Insert( param.get() ) ) 284 | throw TypeError( "Parameter already defined: " + param->GetName() ); 285 | } 286 | 287 | // Typecheck the function body. 288 | if( funcDef->HasBody() ) 289 | StmtTypechecker( &scope, *funcTable, *funcDef ).CheckStmt( funcDef->GetBody() ); 290 | } 291 | 292 | } // anonymous namespace 293 | 294 | 295 | // Typecheck a program, returning zero for success. If a TypeError exception 296 | // is caught, an error message is reported and a non-zero value is returned. 297 | int Typecheck( Program& program ) 298 | { 299 | FuncTable funcTable; 300 | 301 | for( const FuncDefPtr& funcDef : program.GetFunctions() ) 302 | { 303 | try 304 | { 305 | checkFunction( funcDef.get(), &funcTable ); 306 | } 307 | catch( const TypeError& e ) 308 | { 309 | std::cerr << "Error: " << e.what() << std::endl; 310 | return -1; 311 | } 312 | } 313 | return 0; 314 | } 315 | -------------------------------------------------------------------------------- /Typechecker.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | class Program; 4 | 5 | /// Typecheck the given program. Reports an error and returns a non-zero 6 | /// value if a type error is encountered. The typechecker decorates each 7 | /// expression with its type, and it resolves lexical scoping, linking 8 | /// variable references and function calls to the corresponding definitions. 9 | /// This allows subsequent passes (e.g. 
Codegen) to operate without any 10 | /// knowledge of scoping rules. 11 | int Typecheck( Program& program ); 12 | 13 | 14 | -------------------------------------------------------------------------------- /VarDecl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Type.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | /// Variable declaration syntax. VarDecl is used to represent function 10 | /// parameters (\see FuncDef) and local variable declarations (\see DeclStmt). 11 | class VarDecl 12 | { 13 | public: 14 | /// The declaration kind is needed by the code generator, since it 15 | /// determines how the variable value is stored. 16 | enum Kind 17 | { 18 | kLocal, 19 | kParam 20 | }; 21 | 22 | /// Construct a variable declaration of the specified kind with the given 23 | /// type and name. Note that the initializer for a local variable is not 24 | /// part of the declaration; it is stored in the DeclStmt. 25 | VarDecl( Kind kind, Type type, const std::string& name ) 26 | : m_kind( kind ) 27 | , m_type( type ) 28 | , m_name( name ) 29 | { 30 | } 31 | 32 | /// Get the kind of this variable declaration (kLocal vs. kParam). 33 | Kind GetKind() const { return m_kind; } 34 | 35 | /// Get the variable's type. 36 | const Type& GetType() const { return m_type; } 37 | 38 | /// Get the variable name. 39 | const std::string& GetName() const { return m_name; } 40 | 41 | private: 42 | Kind m_kind; 43 | Type m_type; 44 | std::string m_name; 45 | }; 46 | 47 | /// Unique pointer to variable declaration. 48 | using VarDeclPtr = std::unique_ptr; 49 | 50 | /// Output a variable declaration. 
51 | inline std::ostream& operator<<( std::ostream& out, const VarDecl& varDecl ) 52 | { 53 | return out << ToString( varDecl.GetType() ) << ' ' << varDecl.GetName(); 54 | } 55 | 56 | 57 | -------------------------------------------------------------------------------- /Visitor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Syntax.h" 4 | 5 | /// Expression visitor base class. The typechecker and code generator are 6 | /// implemented as visitors, which allows all of the logic for a pass to be 7 | /// defined in a single class, rather than being scattered throughout separate 8 | /// methods in the the various syntax classes. A visitor can also retain 9 | /// state in member variables. For example, the typechecker visitor contains 10 | /// a symbol table (\see Scope) that is extended as variable declarations are 11 | /// processed. 12 | class ExpVisitor 13 | { 14 | public: 15 | virtual void* Visit( BoolExp& exp ) = 0; 16 | virtual void* Visit( IntExp& exp ) = 0; 17 | virtual void* Visit( VarExp& exp ) = 0; 18 | virtual void* Visit( CallExp& exp ) = 0; 19 | }; 20 | 21 | 22 | /// Statement visitor base class. 
23 | class StmtVisitor 24 | { 25 | public: 26 | virtual void Visit( CallStmt& exp ) = 0; 27 | virtual void Visit( AssignStmt& exp ) = 0; 28 | virtual void Visit( DeclStmt& exp ) = 0; 29 | virtual void Visit( ReturnStmt& exp ) = 0; 30 | virtual void Visit( SeqStmt& exp ) = 0; 31 | virtual void Visit( IfStmt& exp ) = 0; 32 | virtual void Visit( WhileStmt& exp ) = 0; 33 | }; 34 | 35 | 36 | -------------------------------------------------------------------------------- /examples/factorial.in: -------------------------------------------------------------------------------- 1 | 2 | int fact(int x) 3 | { 4 | if (x == 1) 5 | return x; 6 | else 7 | return x * fact(x - 1); 8 | } 9 | 10 | int main(int x) 11 | { 12 | return fact(x); 13 | } 14 | -------------------------------------------------------------------------------- /examples/sum.in: -------------------------------------------------------------------------------- 1 | 2 | int main(int x) 3 | { 4 | int sum = 0; 5 | int i = 1; 6 | while (i <= x) 7 | { 8 | sum = sum + i; 9 | i = i + 1; 10 | } 11 | return sum; 12 | } 13 | -------------------------------------------------------------------------------- /grammar.txt: -------------------------------------------------------------------------------- 1 | Prog -> FuncDef+ 2 | 3 | FuncDef -> Type FuncId ( VarDecl* ) Seq 4 | 5 | Type -> bool | int 6 | 7 | FuncId -> Id | operator BinaryOp 8 | 9 | VarDecl -> Type Id 10 | 11 | Seq -> { Stmt* } 12 | 13 | Stmt -> Id = Exp ; 14 | | Id ( Args ) ; 15 | | VarDecl ; 16 | | Seq 17 | | return Exp ; 18 | | if ( Exp ) Stmt 19 | | if ( Exp ) Stmt else Stmt 20 | | while ( Exp ) Stmt 21 | 22 | Args -> Exp 23 | | Exp , Args 24 | 25 | ---------------------------------------------------------------------- 26 | 27 | Exp -> true | false 28 | | Num 29 | | Id 30 | | Id ( Args ) 31 | | ( Exp ) 32 | | UnaryOp Exp 33 | | Exp BinaryOp Exp 34 | 35 | UnaryOp -> - | ! 
36 | BinaryOp -> * | / | % 37 | | + | - 38 | | < | <= | > | >= 39 | | == | != 40 | 41 | ---------------------------------------------------------------------- 42 | 43 | // Alternative: precedence climbing 44 | 45 | PrimaryExp -> true | false 46 | | Num 47 | | Id 48 | | Id ( Args ) 49 | | ( Exp ) 50 | 51 | UnaryOp -> - | + | ! 52 | MulOp -> * | / | % 53 | AddOp -> + | - 54 | CmpOp -> < | <= | > | >= 55 | EqOp -> == | != 56 | 57 | UnaryExp -> UnaryOp UnaryExp 58 | | PrimaryExp 59 | 60 | MulExp -> MulExp MulOp UnaryExp 61 | | UnaryExp 62 | 63 | AddExp -> AddExp AddOp MulExp 64 | | MulExp 65 | 66 | CmpExp -> CmpExp CmpOp AddExp 67 | | AddExp 68 | 69 | AndExp -> AndExp && CmpExp 70 | | CmpExp 71 | 72 | OrExp -> OrExp || AndExp 73 | | AndExp 74 | 75 | Exp -> OrExp 76 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "Builtins.h" 2 | #include "Codegen.h" 3 | #include "FuncDef.h" 4 | #include "Parser.h" 5 | #include "Printer.h" 6 | #include "Program.h" 7 | #include "SimpleJIT.h" 8 | #include "TokenStream.h" 9 | #include "Typechecker.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include 21 | #include 22 | 23 | #ifndef OPT_LEVEL 24 | /// Optimization level, which defaults to -O2. 25 | #define OPT_LEVEL 2 26 | #endif 27 | 28 | namespace { 29 | 30 | // Forward declarations. 31 | void optimize( Module* module, int optLevel ); 32 | int readFile( const char* filename, std::vector* buffer ); 33 | void dumpSyntax( const Program& program, const char* srcFilename ); 34 | void dumpIR( llvm::Module& module, const char* srcFilename, const char* what ); 35 | 36 | // Parse and typecheck the given source code, adding definitions to the given Program. 37 | // This is used to process both builtin definitions and user code. 
38 | int parseAndTypecheck( const char* source, Program* program ) 39 | { 40 | // Construct token stream, which encapsulates the lexer. \see TokenStream. 41 | TokenStream tokens( source ); 42 | 43 | // Parse the token stream into a program. 44 | int status = ParseProgram( tokens, program ); 45 | 46 | // If the parser succeeded, typecheck the program. 47 | if( status == 0 ) 48 | status = Typecheck( *program ); 49 | return status; 50 | } 51 | 52 | } // anonymous namespace 53 | 54 | 55 | int main( int argc, const char* const* argv ) 56 | { 57 | // Get command-line arguments. 58 | if( argc != 3 ) 59 | { 60 | std::cerr << "Usage: " << argv[0] << " " << std::endl; 61 | return -1; 62 | } 63 | const char* filename = argv[1]; 64 | int inputValue = atoi( argv[2] ); 65 | 66 | // Read source file. TODO: use an input stream, rather than reading the entire file. 67 | std::vector source; 68 | int status = readFile( argv[1], &source ); 69 | if( status != 0 ) 70 | { 71 | std::cerr << "Unable to open input file: " << argv[1] << std::endl; 72 | return status; 73 | } 74 | 75 | // Parse and typecheck builtin functions. 76 | ProgramPtr program( new Program ); 77 | status = parseAndTypecheck( GetBuiltins(), program.get() ); 78 | assert(status == 0); 79 | 80 | // Parse and typecheck user source code. 81 | status = parseAndTypecheck( source.data(), program.get()); 82 | if( status ) 83 | return status; 84 | dumpSyntax( *program, filename ); 85 | 86 | // Generate LLVM IR. 87 | llvm::LLVMContext context; 88 | std::unique_ptr module( Codegen( &context, *program ) ); 89 | dumpIR( *module, filename, "initial" ); 90 | 91 | // Verify the module, which catches malformed instructions and type errors. 92 | assert(!verifyModule(*module, &llvm::errs())); 93 | 94 | // Construct JIT engine and use data layout for target-specific optimizations. 95 | SimpleJIT jit; 96 | module->setDataLayout( jit.getTargetMachine().createDataLayout() ); 97 | 98 | // Optimize the module. 
99 | optimize( module.get(), OPT_LEVEL ); 100 | dumpIR( *module, filename, "optimized" ); 101 | 102 | // Use the JIT engine to generate native code. 103 | VModuleKey key = jit.addModule( std::move(module) ); 104 | 105 | // Get the main function pointer. 106 | JITSymbol mainSymbol = jit.findSymbol( key, "main" ); 107 | typedef int ( *MainFunc )( int ); 108 | MainFunc mainFunc = reinterpret_cast( cantFail( mainSymbol.getAddress() ) ); 109 | 110 | // Call the main function using the input value from the command line. 111 | int result = mainFunc(inputValue); 112 | std::cout << result << std::endl; 113 | 114 | return 0; 115 | } 116 | 117 | namespace { 118 | 119 | // Optimize the module using the given optimization level (0 - 3). 120 | void optimize( Module* module, int optLevel ) 121 | { 122 | // Construct the function and module pass managers, which are populated 123 | // with standard optimizations (e.g. constant propagation, inlining, etc.) 124 | legacy::FunctionPassManager functionPasses( module ); 125 | legacy::PassManager modulePasses; 126 | 127 | // Populate the pass managers based on the optimization level. 128 | PassManagerBuilder builder; 129 | builder.OptLevel = optLevel; 130 | builder.populateFunctionPassManager( functionPasses ); 131 | builder.populateModulePassManager( modulePasses ); 132 | 133 | // Run the function passes, then the module passes. 134 | for( Function& function : *module ) 135 | functionPasses.run( function ); 136 | modulePasses.run( *module ); 137 | } 138 | 139 | // Read file into the given buffer. Returns zero for success. 140 | int readFile( const char* filename, std::vector* buffer ) 141 | { 142 | // Open the stream at the end, get file size, and allocate data. 
143 | std::ifstream in( filename, std::ifstream::ate | std::ifstream::binary ); 144 | if( in.fail() ) 145 | return -1; 146 | size_t length = static_cast( in.tellg() ); 147 | 148 | buffer->resize( length + 1 ); 149 | 150 | // Rewind and read entire file 151 | in.clear(); // clear EOF 152 | in.seekg( 0, std::ios::beg ); 153 | in.read( buffer->data(), length ); 154 | 155 | // The buffer is null-terminated (for the benefit of the Lexer). 156 | (*buffer)[length] = '\0'; 157 | return 0; 158 | } 159 | 160 | // Dump syntax for debugging if the "ENABLE_DUMP" environment variable is set. 161 | void dumpSyntax( const Program& program, const char* srcFilename ) 162 | { 163 | if ( !getenv("ENABLE_DUMP") ) 164 | return; 165 | std::string filename( std::string( srcFilename ) + ".syn" ); 166 | std::ofstream out( filename ); 167 | out << program << std::endl; 168 | } 169 | 170 | // Dump LLVM IR for debugging if the "ENABLE_DUMP" environment variable is set. 171 | void dumpIR( llvm::Module& module, const char* srcFilename, const char* what ) 172 | { 173 | if ( !getenv("ENABLE_DUMP") ) 174 | return; 175 | std::string filename( std::string( srcFilename ) + "." + what + ".ll" ); 176 | std::ofstream stream( filename ); 177 | llvm::raw_os_ostream out( stream ); 178 | out << module; 179 | } 180 | 181 | } // anonymous namespace 182 | --------------------------------------------------------------------------------