├── example.c ├── .gitignore ├── COPYING ├── README.md └── cc.py /example.c: -------------------------------------------------------------------------------- 1 | int example(long a1) { 2 | long i = 42; 3 | long j = 10; 4 | long y = i * j; 5 | return y; 6 | } 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | *.o 3 | *.db 4 | *.txt 5 | a.out 6 | mcc 7 | watch 8 | *.dSYM 9 | *.gch 10 | debug-assembly 11 | unit_tests 12 | .DS_Store 13 | .cache 14 | compile_commands.json 15 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, James W M Barford-Evans 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C compiler in Python 2 | 3 | 4 | This is a toy compiler attempting to compile a subset of C in fewer than 1000 5 | lines of code. Not much is supported as a result. I wrote this over a *weekend 6 | as a challenge, or more realistically procrastinating from what I should be doing. 7 | 8 | As a result, this can only handle basic integer operations! The 9 | parser does support a `while` with break and continue, an `if` and 10 | function calls, but I've not built these out as I ran out of time. (When I next 11 | procrastinate I might add them :D). 12 | 13 | I use `sys` to be able to access the command line arguments, `cast` to 14 | hack around Python's type system and `NoReturn` for my implementation of `panic`. 15 | Other than that there are no Python libraries used. 16 | 17 | # Usage 18 | ```sh 19 | ./cc.py <file>.c 20 | ``` 21 | 22 | ## Output: 23 | This shows all of the stages used to generate the x86 code.
24 | ``` 25 | C ====== 26 | int main() { 27 | int x = 4 + 5; 28 | int y = x + 9; 29 | return y; 30 | } 31 | 32 | LEX ===== 33 | 'int' 'main' '(' ')' '{' 'int' 'x' '=' '4' '+' '5' ';' 'int' 'y' '=' 'x' '+' '9' ';' 'return' 'y' ';' '}' 34 | 35 | AST ===== 36 | here 37 | int main 38 | () 39 | 40 | int 41 | x = 42 | (4 + 5) 43 | int 44 | y = 45 | ( x + 9) 46 | y 47 | TAC ===== 48 | main:: 49 | R0 = ADD @i64::4, @i64::5 50 | R1 = R0 51 | R2 = ADD R1, @i64::9 52 | R3 = R2 53 | RET R3 54 | 55 | x86 ===== 56 | _main:: 57 | PUSH RBP 58 | MOVQ RBP, RSP 59 | SUB RSP,16 60 | MOVQ RAX, 4 61 | MOVQ RCX, 5 62 | ADD RAX, RCX 63 | MOVQ -8[RBP], RAX 64 | MOVQ RCX, 9 65 | MOVQ RAX, -8[RBP] 66 | ADD RAX, RCX 67 | MOVQ -8[RBP], RAX 68 | MOVQ -16[RBP], RAX 69 | ADD RSP, 16 70 | MOVQ RAX, -16[RBP] 71 | LEAVE 72 | RET 73 | ``` 74 | 75 | # Components 76 | 77 | ## Lexer 78 | - This is a classical lexer splitting the code into a list of tokens without 79 | using Python's `re` library. 80 | 81 | ## Parser 82 | - Essentially an LL recursive descent parser; operator precedence is handled by 83 | precedence climbing. 84 | 85 | ## IR - TAC 86 | - A three-address code (TAC) intermediate representation is used to flatten out the 87 | AST into something that is easier to convert to assembly. 88 | 89 | ## x86_64 90 | - A semi-realistic Intel-style assembly; I'm not too worried that the x86_64 91 | may not run, as that is outside the scope of this project.
92 | 93 | # Inspirations & Resources 94 | - [Compiler design in c](https://holub.com/compiler/) 95 | - [8cc](https://github.com/rui314/8cc) 96 | - [parsing expressions by precidence climbing](https://eli.thegreenplace.net/2012/08/02/parsing-expressions-by-precedence-climbing) 97 | - [Crafting interpreters](https://craftinginterpreters.com/) 98 | - [Engineering a compiler - 3rd edition](https://www.amazon.com/Engineering-Compiler-Keith-D-Cooper/dp/0128154128) 99 | - [tcc](http://bellard.org/tcc/) 100 | - [cc65](https://cc65.github.io/) 101 | - [TempleOS](https://templeos.org/) 102 | - [JS c compiler](https://github.com/Captainarash/CaptCC) 103 | 104 | ## Twitch 105 | I semi-frequently stream on twitch at: https://twitch.tv/Jamesbarford mostly 106 | c and mostly compiler related. 107 | 108 | *Taking a weekend - 109 | I'm building a more _real_ compiler in `c` for my masters. There's months of 110 | learning crammed into these rather shoddy ~900 of code. 111 | -------------------------------------------------------------------------------- /cc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from typing import cast, NoReturn 3 | import sys 4 | 5 | def panic(argv) -> NoReturn: 6 | print(f"ERROR: {argv}") 7 | exit(1) 8 | 9 | OP_PLUS = ord("+") 10 | OP_SUB = ord("-") 11 | OP_MUL = ord("*") 12 | OP_DIV = ord("/") 13 | OP_SHL = 0x20 14 | OP_SHR = 0x21 15 | 16 | op_to_alu = { OP_SHR: "SHR",OP_SHL: "SHL",OP_MUL: "MUL", OP_PLUS: "ADD", OP_SUB: "SUB", OP_DIV: "DIV" } 17 | op_to_str = { OP_SHR: ">>",OP_SHL: "<<",OP_MUL: "*", OP_PLUS: "+", OP_SUB: "-", OP_DIV: "/" } 18 | str_to_op = {">>": OP_SHR, "<<": OP_SHL, "+": OP_PLUS, "-": OP_SUB, "*": OP_MUL, "/": OP_DIV} 19 | 20 | TK_IDENT = 0 21 | TK_PUNCT = 1 22 | TK_KW = 2 23 | TK_IDENT = 3 24 | TK_I64 = 4 25 | TK_F64 = 5 26 | keywords = {"void","int","long","while","break","return","continue","if","else"} 27 | mult_tk = {"<<",">>","++","--","->","=="} 28 | 29 | class 
Token: 30 | def __init__(self, kind: int, lineno: int) -> None: 31 | self.kind = kind 32 | self.lineno = lineno 33 | 34 | class TokenIdent(Token): 35 | def __init__(self, ident: str, lineno: int) -> None: 36 | super().__init__(TK_IDENT,lineno) 37 | self.ident = ident 38 | def __str__(self) -> str: return self.ident 39 | 40 | class TokenPunct(Token): 41 | def __init__(self, punct: str, lineno: int) -> None: 42 | super().__init__(TK_PUNCT,lineno) 43 | self.punct = punct 44 | def __str__(self) -> str: return self.punct 45 | 46 | class TokenI64(Token): 47 | def __init__(self, i64: int, lineno: int) -> None: 48 | super().__init__(TK_I64, lineno) 49 | self.i64 = i64 50 | def __str__(self) -> str: return str(self.i64) 51 | 52 | class TokenF64(Token): 53 | def __init__(self, f64: float, lineno: int) -> None: 54 | super().__init__(TK_F64, lineno) 55 | self.f64 = f64 56 | def __str__(self) -> str: return str(self.f64) 57 | 58 | class TokenKeyWord(Token): 59 | def __init__(self, ident: str, lineno: int) -> None: 60 | super().__init__(TK_KW, lineno) 61 | self.ident = ident 62 | def __str__(self) -> str: return self.ident 63 | 64 | class Lexer: 65 | def __init__(self, code: str): 66 | self.code = code 67 | self.idx = 0 68 | self.code_len = len(code) 69 | self.lineno = 1 70 | 71 | def get_next(self) -> str: 72 | if self.idx == self.code_len: 73 | return '\0' 74 | ch = self.code[self.idx] 75 | if ch == '\n': self.lineno += 1 76 | self.idx+=1 77 | return ch 78 | def peek(self) -> str: return self.code[self.idx] 79 | 80 | def rewind(self): self.idx -= 1 81 | 82 | def lexident(lexer: Lexer, ch) -> str: 83 | ident = ch 84 | while ch := lexer.get_next(): 85 | if not str.isalpha(ch) and not str.isdigit(ch) and ch != '\0': 86 | break 87 | ident += ch 88 | lexer.rewind() 89 | return ident 90 | 91 | def lexnum(lexer: Lexer, ch) -> tuple[bool, int | float]: 92 | strnum = ch 93 | while str.isdigit((ch := lexer.get_next())): 94 | strnum += ch 95 | if ch == '.': 96 | strnum += '.' 
97 | while str.isdigit((ch := lexer.get_next())): 98 | strnum += ch 99 | lexer.rewind() 100 | return True, float(strnum) 101 | lexer.rewind() 102 | return False, int(strnum) 103 | 104 | def lexc(code: str) -> list[Token]: 105 | tokens = [] 106 | lexer = Lexer(code) 107 | while ch := lexer.get_next(): 108 | if ch == '\0': break 109 | elif str.isalpha(ch): 110 | ident = lexident(lexer,ch) 111 | if ident in keywords: tokens.append(TokenKeyWord(ident,lexer.lineno)) 112 | else: tokens.append(TokenIdent(ident,lexer.lineno)) 113 | elif str.isdigit(ch): 114 | is_float,num = lexnum(lexer,ch) 115 | if is_float: tokens.append(TokenF64(float(num),lexer.lineno)) 116 | else: tokens.append(TokenI64(int(num),lexer.lineno)) 117 | elif ch in {"+","-","*","/",";","{","}","(",")"}: tokens.append(TokenPunct(ch,lexer.lineno)) 118 | elif ch == '>': 119 | if lexer.peek() == '>': 120 | lexer.get_next() 121 | tokens.append(TokenPunct(">>",lexer.lineno)) 122 | else: 123 | tokens.append(TokenPunct(ch,lexer.lineno)) 124 | elif ch == '<': 125 | if lexer.peek() == '<': 126 | lexer.get_next() 127 | tokens.append(TokenPunct("<<",lexer.lineno)) 128 | else: 129 | tokens.append(TokenPunct(ch,lexer.lineno)) 130 | elif ch == '=': 131 | if lexer.peek() == '=': 132 | lexer.get_next() 133 | tokens.append(TokenPunct("==",lexer.lineno)) 134 | else: 135 | tokens.append(TokenPunct(ch,lexer.lineno)) 136 | return tokens 137 | 138 | # AST =========== 139 | # This limited implementation can do arithmetic operations given an Ast 140 | AST_INT = 0 141 | AST_FLOAT = 1 142 | AST_COMPOUND = 2 143 | AST_LITERAL = 3 144 | AST_LVAR = 4 145 | AST_DECL = 5 146 | AST_FUN = 5 147 | AST_FUN_CALL = 6 148 | AST_PTR = 7 149 | AST_RETURN = 8 150 | AST_BREAK = 9 151 | AST_CONTINUE = 10 152 | AST_WHILE = 11 153 | AST_IF = 12 154 | 155 | class AstType: 156 | def __init__(self, size: int, kind: int = 0, ptr = None) -> None: 157 | self.kind: int = kind 158 | self.issigned: bool = False 159 | self.size: int = size 160 | self.ptr: 
AstType | None = ptr 161 | def __str__(self) -> str: 162 | if self.kind == AST_INT: return "int" 163 | elif self.kind == AST_FLOAT: return "float" 164 | elif self.kind == AST_COMPOUND: return "compound" 165 | elif self.kind == AST_LITERAL: return "literal" 166 | elif self.kind == AST_LVAR: return "lvar" 167 | elif self.kind == AST_DECL: return "decl" 168 | elif self.kind == AST_FUN: return "function" 169 | elif self.kind == AST_FUN_CALL: return "function_call" 170 | elif self.kind == AST_PTR: return "pointer" 171 | elif self.kind == AST_RETURN: return "return" 172 | elif self.kind == AST_BREAK: return "break" 173 | elif self.kind == AST_CONTINUE: return "continue" 174 | elif self.kind == AST_WHILE: return "while" 175 | elif self.kind == AST_IF: return "if" 176 | else: return "unknown" 177 | 178 | ast_type_i32 = AstType(size=4, kind=AST_INT) 179 | ast_type_i64 = AstType(size=8, kind=AST_INT) 180 | ast_type_f64 = AstType(size=8, kind=AST_FLOAT) 181 | 182 | class AstTypePtr(AstType): 183 | def __init__(self, base: AstType) -> None: 184 | super().__init__(8,AST_PTR,base) 185 | 186 | class Ast: 187 | kind: int 188 | def __init__(self, ast_type: AstType | None = None) -> None: 189 | self.type = ast_type 190 | self.offset = 0 191 | 192 | class AstI32(Ast): 193 | def __init__(self, i32: int) -> None: 194 | super().__init__(ast_type_i32) 195 | self.kind = AST_LITERAL 196 | self.i64 = i32 197 | def __str__(self) -> str: return str(self.i64) 198 | 199 | class AstI64(Ast): 200 | def __init__(self, i64: int) -> None: 201 | super().__init__(ast_type_i64) 202 | self.kind = AST_LITERAL 203 | self.i64 = i64 204 | def __str__(self) -> str: return str(self.i64) 205 | 206 | class AstF64(Ast): 207 | def __init__(self, f64: float) -> None: 208 | super().__init__(ast_type_f64) 209 | self.kind = AST_LITERAL 210 | self.f64 = f64 211 | def __str__(self) -> str: return str(self.f64) 212 | 213 | class AstBinaryOp(Ast): 214 | def __init__(self, ast_type: AstType, left: Ast, op: int, right: 
Ast) -> None: 215 | super().__init__(ast_type) 216 | self.left = left 217 | self.right = right 218 | self.kind = op 219 | def __str__(self) -> str: return f"\n\t\t({self.left} {op_to_str[self.kind]} {self.right})" 220 | 221 | class AstCompound(Ast): 222 | # list of Ast's 223 | def __init__(self, argv: list[Ast]) -> None: 224 | super().__init__(None) 225 | self.kind = AST_COMPOUND 226 | self.stmts = argv 227 | def __str__(self) -> str: return "\n" + "\n".join(f"\t{node}" for node in cast(list[Ast], self.stmts)) 228 | 229 | class AstFunction(Ast): 230 | def __init__(self, ret_ast_type: AstType, fname: str, params: list[Ast], body: AstCompound, local_defs: list[Ast]) -> None: 231 | super().__init__(ret_ast_type) 232 | self.kind = AST_FUN 233 | self.fname = fname 234 | self.params = params 235 | self.body = body 236 | self.locals = local_defs 237 | 238 | def __str__(self) -> str: 239 | params = ", ".join(f"{node.type} {node}" for node in cast(list[Ast], self.params)) 240 | body = str(self.body) # "\n".join(f"\t{node}" for node in cast(list[Ast], self.body.stmts)) 241 | return f" {self.type} {self.fname}\n({params})\n{body}" 242 | 243 | class AstFunctionCall(Ast): 244 | def __init__(self, ast_type: AstType, argv: list[Ast], fname: str)-> None: 245 | super().__init__(ast_type) 246 | self.kind = AST_FUN_CALL 247 | self.argv = argv 248 | self.fname = fname 249 | 250 | class AstDecl(Ast): 251 | def __init__(self, var: Ast, init: Ast | None)-> None: 252 | super().__init__(None) 253 | self.kind = AST_DECL 254 | self.var = var 255 | self.init = init 256 | def __str__(self) -> str: return f" {self.var.type} \n\t\t{self.var} = {self.init}" 257 | 258 | class AstLVar(Ast): 259 | def __init__(self, ast_type: AstType, name: str) -> None: 260 | super().__init__(ast_type) 261 | self.kind = AST_LVAR 262 | self.name = name 263 | def __str__(self) -> str: return f" {self.name}" 264 | 265 | class AstReturn(Ast): 266 | def __init__(self, ast_type: AstType, retval: Ast | None) -> None: 267 
| super().__init__(ast_type) 268 | self.kind = AST_RETURN 269 | self.retval = retval 270 | def __str__(self) -> str: return f" {self.retval}" 271 | 272 | class AstWhile(Ast): 273 | def __init__(self, cond: Ast|None, body: Ast, begin_label: str, end_label: str) -> None: 274 | super().__init__(None) 275 | self.kind = AST_WHILE 276 | self.cond = cond 277 | self.body = body 278 | self.begin_label = begin_label 279 | self.end_label = end_label 280 | 281 | class AstBreak(Ast): 282 | def __init__(self,label: str) -> None: 283 | super().__init__(None) 284 | self.label = label 285 | self.kind = AST_BREAK 286 | 287 | class AstContinue(Ast): 288 | def __init__(self,label: str) -> None: 289 | super().__init__(None) 290 | self.label = label 291 | self.kind = AST_CONTINUE 292 | 293 | class AstIf(Ast): 294 | def __init__(self, cond: Ast,then: Ast,els: Ast | None) -> None: 295 | super().__init__(None) 296 | self.kind = AST_IF 297 | self.cond = cond 298 | self.then = then 299 | self.els = els 300 | 301 | # I've just made this up 302 | def get_priority(tok: TokenPunct) -> int: 303 | if tok.punct in {'[','.','->'}: return 1 304 | elif tok.punct == '/': return 2 305 | elif tok.punct == '*': return 3 306 | elif tok.punct == '+': return 4 307 | elif tok.punct == '-': return 5 308 | elif tok.punct in {'&','|','>>','<<'}: return 6 309 | elif tok.punct == '==': return 7 310 | else: return -1 311 | 312 | label_count = 1 313 | def create_label(): 314 | global label_count 315 | label_count += 1 316 | return f".L{label_count}" 317 | 318 | class Parser: 319 | def __init__(self, tokens: list[Token]) -> None: 320 | self.tokens = tokens 321 | self.tokens_len = len(tokens) 322 | self.ptr = 0 323 | self.env: dict = {} 324 | self.types = {"int": ast_type_i64, "long": ast_type_i32, "float": ast_type_f64} 325 | self.tmp_env: dict = {} 326 | self.tmp_func: AstFunction 327 | self.tmp_locals: list[Ast] 328 | self.tmp_ret_type: AstType 329 | self.tmp_loop_end: str | None 330 | self.tmp_loop_begin: str | 
None 331 | 332 | def get_type(self,name:str) -> AstType: 333 | if (val := self.types.get(name)) is not None: return val 334 | panic(f"Invalid type name: {name}") 335 | 336 | def is_type(self,tok:Token) -> bool: return (isinstance(tok,TokenIdent) or isinstance(tok,TokenKeyWord)) and tok.ident in self.types 337 | 338 | def is_punct_match(self,tok: Token | None, punct: str) -> bool: return isinstance(tok,TokenPunct) and tok.punct == punct 339 | 340 | def rewind(self) -> None: self.ptr -= 1 341 | 342 | def peek(self) -> Token | None: return None if self.ptr == self.tokens_len else self.tokens[self.ptr] 343 | 344 | def get_next(self) -> Token | None: 345 | if self.ptr == self.tokens_len: 346 | return None 347 | tok = self.tokens[self.ptr] 348 | self.ptr += 1 349 | return tok 350 | 351 | def expect_tok_next(self, expected: str) -> bool: 352 | tok = self.get_next() 353 | if isinstance(tok, TokenPunct) and tok.punct == expected: return True 354 | if tok: panic(f"Parser error: Missmatched characters: {tok} != {expected} at line: {tok.lineno}") 355 | else: panic(f"Parser error: expected {expected} ran out of input") 356 | 357 | def env_get(self, name: str) -> Ast | None: 358 | cur_env = self.tmp_env 359 | while cur_env: 360 | if isinstance((ast := cur_env.get(name)), Ast): return ast 361 | cur_env = cur_env.get("parent") 362 | return None 363 | 364 | def func_get(self,name:str) -> AstFunction | None: 365 | func = self.env_get(name) 366 | return func if isinstance(func,AstFunction) else None 367 | 368 | def parse_function_arguments(self, name: str) -> Ast: 369 | func: AstFunction | None = self.func_get(name) 370 | argv = [] 371 | tok = self.peek() 372 | while tok and not self.is_punct_match(tok,')'): 373 | ast = self.parse_expr() 374 | tok = self.get_next() 375 | argv.append(ast) 376 | if self.is_punct_match(tok,')'): break 377 | elif not self.is_punct_match(tok,','): panic(f"Expected ',' got: {tok}") 378 | tok = self.peek() 379 | if len(argv) == 0: self.get_next() # move 
passed '(' as we've not parsed anything :( 380 | if func: return AstFunctionCall(cast(AstType,func.type),argv,name) 381 | return AstFunctionCall(ast_type_i64,argv,name) 382 | 383 | def parse_function_call_or_identifier(self, name: TokenIdent) -> None | Ast: 384 | tok = self.get_next() 385 | if self.is_punct_match(tok,'('): 386 | return self.parse_function_arguments(name.ident) 387 | self.rewind() 388 | if (ast := self.env_get(name.ident)) is None: panic(f"Identifier: {name.ident} is undefined at line: {name.lineno}") 389 | return ast 390 | 391 | def parse_primary(self) -> Ast | None: 392 | tok = self.get_next() 393 | if tok is None: panic("Ran out of input while parsing primary expression") 394 | if isinstance(tok,TokenIdent): return self.parse_function_call_or_identifier(tok) 395 | elif isinstance(tok,TokenI64): return AstI64(tok.i64) 396 | elif isinstance(tok,TokenF64): return AstF64(tok.f64) 397 | elif isinstance(tok,TokenPunct): 398 | self.rewind() 399 | return None 400 | 401 | def parse_expr(self, prec: int = 16) -> Ast | None: 402 | lhs = self.parse_primary() 403 | if lhs is None: return None 404 | while 1: 405 | if (tok := self.get_next()) is None: return lhs 406 | if not isinstance(tok,TokenPunct): 407 | self.rewind() 408 | return lhs 409 | prec2 = get_priority(tok) 410 | 411 | if prec2 < 0 or prec <= prec2: 412 | self.rewind() 413 | return lhs 414 | 415 | if self.is_punct_match(tok,'='): 416 | if not lhs.kind in {AST_LVAR}: panic(f"{lhs} is not an lvalue") 417 | 418 | next_prec = prec2 419 | if self.is_punct_match(tok,'='): 420 | next_prec += 1 421 | rhs = self.parse_expr(next_prec) 422 | if rhs is None: panic(f"lefthand lvar missing right hand value at line: {tok.lineno}") 423 | 424 | lhs = AstBinaryOp(ast_type_i64,lhs,str_to_op[tok.punct],rhs) 425 | 426 | def parse_declaration_initialiser(self, var: Ast, terminators: set[str]) -> None | Ast: 427 | tok = self.get_next() 428 | if self.is_punct_match(tok,'='): 429 | init = self.parse_expr() 430 | tok = 
self.get_next() 431 | assert isinstance(tok,TokenPunct) and tok.punct in terminators 432 | return AstDecl(var,init) 433 | self.rewind() 434 | tok = self.get_next() 435 | if isinstance(tok,TokenPunct) and tok.punct in terminators: 436 | return AstDecl(var,None) 437 | panic(f"Invalid variable initaliser: {tok}") 438 | 439 | def parse_statement(self): 440 | tok = self.get_next() 441 | if isinstance(tok,TokenKeyWord): 442 | if tok.ident == "if": 443 | self.expect_tok_next('(') 444 | cond = self.parse_expr() 445 | if cond is None: panic("if cannot be None") 446 | self.expect_tok_next(')') 447 | then = self.parse_statement() 448 | tok = self.get_next() 449 | if isinstance(tok,TokenKeyWord) and tok.ident == "else": 450 | els = self.parse_statement() 451 | return AstIf(cond,then,els) 452 | self.rewind() 453 | return AstIf(cond,then,None) 454 | elif tok.ident == "return": 455 | print("here") 456 | retval = self.parse_expr() 457 | self.expect_tok_next(';') 458 | return AstReturn(self.tmp_ret_type,retval) 459 | elif tok.ident == "while": 460 | while_begin = create_label() 461 | while_end = create_label() 462 | self.tmp_loop_end = while_end 463 | self.tmp_loop_begin = while_begin 464 | self.tmp_env = {"parent": self.tmp_env} 465 | while_cond = self.parse_expr(16) 466 | self.expect_tok_next(')') 467 | while_body = self.parse_statement() 468 | self.tmp_env = self.tmp_env["parent"] 469 | self.tmp_loop_begin = None 470 | self.tmp_loop_end = None 471 | return AstWhile(while_cond,while_body,while_begin,while_end) 472 | elif tok.ident == "break": 473 | if self.tmp_loop_end is None: panic(f"Floating 'break' statement at line: {tok.lineno}") 474 | return AstBreak(self.tmp_loop_end) 475 | elif tok.ident == "continue": 476 | if self.tmp_loop_begin is None: panic(f"Floating 'continue' statement at line: {tok.lineno}") 477 | return AstContinue(self.tmp_loop_begin) 478 | 479 | def parse_compound(self) -> Ast: 480 | statements = [] 481 | self.tmp_env = {"parent": self.env} 482 | tok = 
self.peek() 483 | while tok and not self.is_punct_match(tok,'}'): 484 | if self.is_type(tok): 485 | base_type = self.parse_base_type() 486 | while True: 487 | next_type = self.parse_ptr(base_type) 488 | varname = self.get_next() 489 | if varname is None or not isinstance(varname,TokenIdent): break 490 | var = AstLVar(next_type,varname.ident) 491 | self.tmp_env[var.name] = var 492 | self.tmp_locals.append(var) 493 | statement = self.parse_declaration_initialiser(var,{',',';'}) 494 | if statement is not None: statements.append(statement) 495 | self.rewind() 496 | tok = self.get_next() 497 | if self.is_punct_match(tok,';'): break 498 | elif self.is_punct_match(tok,','): continue 499 | else: panic(f"Unexpected token: {tok}") 500 | else: 501 | stmt = self.parse_statement() 502 | if stmt: statements.append(stmt) 503 | else: break 504 | tok = self.peek() 505 | self.tmp_env = self.tmp_env["parent"] 506 | self.expect_tok_next('}') 507 | return AstCompound(statements) 508 | 509 | def parse_base_type(self) -> AstType: 510 | tok = self.get_next() 511 | if tok is None: panic("Ran out of tokens while parsing base_type") 512 | if isinstance(tok,TokenIdent) or isinstance(tok,TokenKeyWord): 513 | return self.get_type(tok.ident) 514 | panic(f"undefined type {tok}") 515 | 516 | def parse_ptr(self, base_type: AstType) -> AstType: 517 | ptr_type = base_type 518 | while True: 519 | tok = self.get_next() 520 | if not self.is_punct_match(tok,'*'): 521 | self.rewind() 522 | return ptr_type 523 | ptr_type = AstTypePtr(ptr_type) 524 | 525 | def parse_type(self) -> AstType: 526 | base_type = self.parse_base_type() 527 | return self.parse_ptr(base_type) 528 | 529 | def parse_params(self) -> list[Ast]: 530 | params = [] 531 | self.expect_tok_next('(') 532 | while tok := self.peek(): 533 | if isinstance(tok, TokenPunct) and tok.punct == ')': 534 | self.get_next() 535 | break 536 | param_type = self.parse_type() 537 | name = self.get_next() 538 | if name is None: panic(f"Expected a named variable 
while parsing function parameters of {self.tmp_func.fname}") 539 | elif not isinstance(name,TokenIdent): panic("Expected Identifier got {name} at line {name.lineno}") 540 | params.append(AstLVar(param_type,name.ident)) 541 | return params 542 | 543 | def parse_function(self, ret_type: AstType, tok_ident: TokenIdent) -> Ast: 544 | self.tmp_env = { 545 | "parent": self.env 546 | } 547 | self.tmp_locals = [] 548 | self.tmp_ret_type = ret_type 549 | tmp_compound = AstCompound([]) 550 | params = self.parse_params() 551 | self.expect_tok_next('{') 552 | func = AstFunction(ret_type, tok_ident.ident, params, tmp_compound, []) 553 | body = self.parse_compound() 554 | func.body = cast(AstCompound, body) 555 | func.locals = self.tmp_locals 556 | self.tmp_locals = [] 557 | self.tmp_func = func 558 | return func 559 | 560 | def parse_top(self) -> Ast | None: 561 | tok = self.peek() 562 | if tok is None: return None 563 | if isinstance(tok,TokenKeyWord) or isinstance(tok,TokenIdent): 564 | ret_type = self.parse_type() 565 | name = self.peek() 566 | if isinstance(name,TokenIdent): 567 | self.get_next() 568 | ast = self.parse_function(ret_type,name) 569 | return ast 570 | else: 571 | panic(f"Error expected function definition at top level def got {tok} at line {tok.lineno}") 572 | 573 | def parse(self) -> list[Ast]: return list(iter(self.parse_top, None)) 574 | 575 | ## TAC ======================= 576 | # Three Adress Code IR START 577 | TAC_NULL = -1 578 | TAC_REG = 0 579 | TAC_ALU = 1 580 | TAC_INT = 2 581 | TAC_FLOAT = 3 582 | TAC_BINOP = 4 583 | TAC_LIST = 5 584 | TAC_FUNC = 6 585 | TAC_SAVE = 7 586 | TAC_LOAD = 8 587 | TAC_RETURN = 9 588 | 589 | class TACNode: 590 | def __init__(self, kind: int) -> None: 591 | self.kind = kind 592 | 593 | class TAClist(TACNode): 594 | def __init__(self, tac_list: list[TACNode]) -> None: 595 | super().__init__(TAC_LIST) 596 | self.tac_list = tac_list 597 | def __str__(self) -> str: 598 | buf = "" 599 | for tac in self.tac_list: 600 | buf += 
f"\t{tac}\n" 601 | return buf 602 | 603 | class TACNull(TACNode): 604 | def __init__(self) -> None: 605 | super().__init__(TAC_NULL) 606 | def __str__(self) -> str: return "NULL" 607 | 608 | class TACInt(TACNode): 609 | def __init__(self, num: int, size: int) -> None: 610 | super().__init__(TAC_INT) 611 | self.num = num 612 | self.size = size 613 | def __str__(self) -> str: return str(f"@i{self.size*8}::{self.num}") 614 | 615 | class TACFloat(TACNode): 616 | def __init__(self, num: float, size: int) -> None: 617 | super().__init__(TAC_FLOAT) 618 | self.num = num 619 | self.size = size 620 | def __str__(self) -> str: return str(f"@f{self.size*8}::{self.num}") 621 | 622 | class TACReg(TACNode): 623 | def __init__(self, reg: int) -> None: 624 | super().__init__(TAC_REG) 625 | self.reg = reg 626 | def __str__(self) -> str: return str(f"R{self.reg}") 627 | 628 | class TACAlu(TACNode): 629 | def __init__(self, alu: int) -> None: 630 | super().__init__(TAC_ALU) 631 | self.alu = alu 632 | def __str__(self) -> str: return f"{op_to_alu[self.alu]}" 633 | 634 | class TACBinOp(TACNode): 635 | def __init__(self, alu: TACNode, op1: TACNode, op2: TACNode, result: TACNode) -> None: 636 | super().__init__(TAC_BINOP) 637 | self.alu = alu 638 | self.op1 = op1 639 | self.op2 = op2 640 | self.result = result 641 | def __str__(self) -> str: 642 | # weird; kinda TAC, kinda ASM o7 643 | return f"{self.result} = {self.alu} {self.op1}, {self.op2}" 644 | 645 | class TACFunc(TACNode): 646 | def __init__(self, name: str, body: TACNode, local_vars: list[Ast], params: list[Ast]): 647 | super().__init__(TAC_FUNC) 648 | self.name = name 649 | self.locals = local_vars 650 | self.body = body 651 | self.params = params 652 | def __str__(self) -> str: return f"{self.name}::\n{self.body}" 653 | 654 | class TACSave(TACNode): 655 | def __init__(self, variable: Ast, init: TACNode | None, reg: TACNode) -> None: 656 | super().__init__(TAC_SAVE) 657 | self.var = variable 658 | self.init = init 659 | self.reg 
= reg 660 | self.offset = variable.offset 661 | def __str__(self) -> str: return f"{self.reg} = {self.init}" 662 | 663 | class TACLoad(TACNode): 664 | def __init__(self, variable: Ast, reg: TACNode) -> None: 665 | super().__init__(TAC_LOAD) 666 | self.var = variable 667 | self.reg = reg 668 | def __str__(self) -> str: return f"{self.reg}" 669 | 670 | class TACReturn(TACNode): 671 | def __init__(self, reg: TACReg, expr: Ast|None) -> None: 672 | super().__init__(TAC_RETURN) 673 | self.reg = reg 674 | self.expr = expr 675 | def __str__(self) -> str: return f"RET {self.reg}" 676 | 677 | class Register: 678 | __reg: int = 0 679 | 680 | @staticmethod 681 | def get_next() -> TACNode: 682 | reg = Register.__reg 683 | Register.__reg += 1 684 | return TACReg(reg) 685 | 686 | @staticmethod 687 | def reset() -> None: Register.__reg = 0 688 | 689 | class IR: 690 | def __init__(self) -> None: 691 | self.var_to_reg = {} 692 | self.ops: list[TACNode] = [] 693 | 694 | def ir_literal(ast: AstF64|AstI32|AstI64) -> TACNode: 695 | if ast.type: 696 | if isinstance(ast,AstI32) or isinstance(ast,AstI64): return TACInt(ast.i64,ast.type.size) 697 | elif isinstance(ast,AstF64): return TACFloat(ast.f64,ast.type.size) 698 | else: panic(f"unknown kind: {ast.type.kind}") 699 | else: panic(f"kind: {ast.type} is NULL") 700 | 701 | def ir_compound(ir: IR, stmts: list[Ast] = []) -> TACNode: 702 | return TAClist([ir_expr(ir,stmt) for stmt in stmts]) 703 | 704 | def ir_save(ir: IR, ast: AstDecl) -> TACNode: 705 | if ast.init: 706 | init = ir_expr(ir,ast.init) 707 | save = TACSave(ast.var,init,Register.get_next()) 708 | else: 709 | save = TACSave(ast.var,None,Register.get_next()) 710 | lvar = cast(AstLVar,ast.var) 711 | ir.var_to_reg[lvar.name] = save.reg 712 | ir.ops.append(save) 713 | return save 714 | 715 | def ir_load(ir:IR, ast: AstLVar) -> TACNode: return TACLoad(ast,ir.var_to_reg[ast.name]) 716 | 717 | def ir_return(ir: IR, ast: AstReturn) -> TACNode: 718 | expr = ir_expr(ir,ast.retval) 719 | 
reg = cast(TACReg, expr) 720 | if ast.retval: ret = TACReturn(reg,ast.retval) 721 | else: ret = TACReturn(reg,None) 722 | ir.ops.append(ret) 723 | return ret 724 | 725 | def ir_expr(ir: IR, ast: Ast | None) -> TACNode: 726 | if ast is None: return TACNull() 727 | if isinstance(ast,AstI32) or \ 728 | isinstance (ast,AstI64) or \ 729 | isinstance(ast,AstF64): return ir_literal(ast) 730 | elif isinstance(ast,AstCompound): return ir_compound(ir,ast.stmts) 731 | elif isinstance(ast,AstDecl): return ir_save(ir,ast) # This will be a local 732 | elif isinstance(ast,AstLVar): 733 | load = ir_load(ir,ast) 734 | return load 735 | elif isinstance(ast,AstReturn): return ir_return(ir,ast) 736 | elif isinstance(ast,AstFunction): 737 | body = ir_compound(ir,ast.body.stmts) 738 | return TACFunc(ast.fname,body,ast.locals,ast.params) 739 | elif ast.kind in op_to_alu: 740 | quad = cast(TACBinOp, ir_binop(ir,ast)) 741 | ir.ops.append(quad) 742 | return quad.result 743 | else: panic(f"kind: {ast} not handled") 744 | 745 | def ir_binop(ir: IR, ast: Ast) -> TACNode: 746 | binop = cast(AstBinaryOp,ast) 747 | return TACBinOp( 748 | TACAlu(ast.kind), 749 | ir_expr(ir,binop.left), 750 | ir_expr(ir,binop.right), 751 | Register.get_next()) 752 | 753 | def ir_gen(funcs: list[Ast]) -> list[TACNode]: 754 | ir_funcs = [] 755 | for func in funcs: 756 | if not isinstance(func,AstFunction): panic("Can only generate ir from function definitions") 757 | ir = IR() 758 | ir_fn = ir_expr(ir,func) 759 | cast(TACFunc,ir_fn).body = TAClist(ir.ops) 760 | #print(ir.ops) 761 | ir_funcs.append(ir_fn) 762 | return ir_funcs 763 | 764 | ## CODE GEN 765 | x86_registers = ["RDI","RSI","RDX","RCX","R8","R9","R10","R11","R12","R13","R14","R15"] 766 | def x86(ops: list[TACNode], stack_space: int) -> str: 767 | x86_code = [] 768 | for op in ops: 769 | if isinstance(op,TACReturn): 770 | if stack_space > 0: 771 | x86_code.append(f"ADD\tRSP, {stack_space}\n\t") 772 | if op.expr: 773 | x86_code.append(f"MOVQ\tRAX, 
{op.expr.offset}[RBP]\n\t") 774 | x86_code.append(f"LEAVE\n\tRET\n\n") 775 | if isinstance(op,TACSave): 776 | if isinstance(op.init,TACInt): 777 | x86_code.append(f"MOVQ\t{op.var.offset}[RBP], {op.init.num}\n\t") 778 | else: 779 | x86_code.append(f"MOVQ\t{op.var.offset}[RBP], RAX\n\t") 780 | if isinstance(op,TACBinOp): 781 | binop = cast(TACBinOp,op) 782 | alu = cast(TACAlu, binop.alu) 783 | mnemonic = op_to_alu[alu.alu] 784 | # XXX: optimiser would nuke this as it is a constant expression 785 | if isinstance(binop.op1,TACInt) and isinstance(binop.op2,TACInt): 786 | x86_code.append(f"MOVQ\tRAX, {binop.op1.num}\n\t") 787 | x86_code.append(f"MOVQ\tRCX, {binop.op2.num}\n\t") 788 | x86_code.append(f"{mnemonic}\tRAX, RCX\n\t") 789 | elif isinstance(binop.op1,TACLoad) and isinstance(binop.op2,TACInt): 790 | if alu.alu == OP_SHL or alu.alu == OP_SHR: 791 | x86_code.append(f"{mnemonic}\tAL, {binop.op2.num}\n\t") 792 | else: 793 | x86_code.append(f"MOVQ\tRCX, {binop.op2.num}\n\t") 794 | x86_code.append(f"MOVQ\tRAX, {binop.op1.var.offset}[RBP]\n\t") 795 | x86_code.append(f"{mnemonic}\tRAX, RCX\n\t") 796 | x86_code.append(f"MOVQ\t{binop.op1.var.offset}[RBP], RAX\n\t") 797 | elif isinstance(binop.op1,TACLoad) and isinstance(binop.op2, TACLoad): 798 | x86_code.append(f"MOVQ\tRAX, {binop.op1.var.offset}[RBP]\n\t") 799 | x86_code.append(f"MOVQ\tRCX, {binop.op2.var.offset}[RBP]\n\t") 800 | x86_code.append(f"{mnemonic}\tRAX, RCX\n\t") 801 | x86_code.append(f"MOVQ\t{binop.op1.var.offset}[RBP], RAX\n\t") 802 | return ''.join(x86_code) 803 | 804 | def align(n: int, m: int) -> int: 805 | r = n % m 806 | return n if r == 0 else n-r+m 807 | 808 | def x86_func(func: TACFunc) -> str: 809 | asm_func = f"_{func.name}::\n\tPUSH\tRBP\n\tMOVQ\tRBP, RSP\n\t" 810 | body = cast(TAClist,func.body) 811 | local_size = 0 812 | stack_space = 0 813 | new_offset = 0 814 | offset = 0 815 | for locl in func.locals: 816 | if locl.type: 817 | local_size += align(locl.type.size,8) 818 | new_offset -= 
align(locl.type.size,8) 819 | locl.offset = new_offset 820 | for param in func.params: 821 | if param.type: 822 | local_size += align(param.type.size,8) 823 | if local_size > 0: 824 | stack_space = align(local_size,16) 825 | asm_func += f"SUB\tRSP,{stack_space}\n\t" 826 | offset = stack_space 827 | arg = 2 828 | ireg = 0 829 | for _ in func.params: 830 | if ireg == 6: 831 | off_ = arg * 8 832 | asm_func += f"MOVQ\tRAX, {off_}[RBP]\n\tMOVQ\t{-offset}[RBP], RAX\n\t" 833 | arg += 1 834 | else: 835 | asm_func += f"MOVQ\t{-offset}[RBP], {x86_registers[ireg]}\n\t" 836 | ireg += 1 837 | return f"{asm_func}{x86(body.tac_list,stack_space)}" 838 | 839 | def make_add_i64(a: Ast, b: Ast) -> Ast: return AstBinaryOp(ast_type_i64,a,OP_PLUS,b) 840 | def make_shl_i64(a: Ast, b: Ast) -> Ast: return AstBinaryOp(ast_type_i64,a,OP_SHL,b) 841 | 842 | def main(): 843 | if len(sys.argv) < 2: 844 | panic(f"Usage: {sys.argv[0]} .c") 845 | with open(sys.argv[1]) as f: 846 | code = f.read() 847 | print("C ======") 848 | print(code) 849 | 850 | print("LEX =====") 851 | tokens = lexc(code) 852 | for tok in tokens: print(f"'{tok}' ", end='') 853 | print('\n') 854 | 855 | print("AST =====") 856 | parser = Parser(tokens) 857 | ast_list = parser.parse() 858 | for ast in ast_list: 859 | print(ast) 860 | 861 | print("TAC =====") 862 | ir_list = ir_gen(ast_list) 863 | for it in ir_list: 864 | print(it) 865 | 866 | print("x86 =====") 867 | asm_funcs = [] 868 | for it in ir_list: 869 | if not isinstance(it,TACFunc): 870 | panic("Only ir functions are supported") 871 | asm_funcs.append(x86_func(it)) 872 | for asm in asm_funcs: 873 | print(asm) 874 | 875 | if __name__ == "__main__": 876 | main() 877 | --------------------------------------------------------------------------------