├── .cargo └── config ├── .github └── workflows │ └── check.yml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── container ├── Dockerfile └── Makefile ├── ks ├── forloop.ks └── ifelse.ks └── src ├── codegen.rs ├── lexer.rs ├── lib.rs ├── llvm ├── basic_block.rs ├── builder.rs ├── lljit.rs ├── mod.rs ├── module.rs ├── pass_manager.rs ├── type_.rs └── value.rs ├── main.rs └── parser.rs /.cargo/config: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = ["-C", "link-args=-rdynamic"] 3 | -------------------------------------------------------------------------------- /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | name: Build, Test and generate Doc 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | PODMAN_RUN: podman run --rm -t -v $PWD:/work -w /work ks-rs 12 | 13 | jobs: 14 | ci: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | - run: make -C container 19 | 20 | - run: cargo fmt -- --check 21 | - run: eval $PODMAN_RUN cargo build --verbose 22 | - run: eval $PODMAN_RUN cargo test --verbose 23 | 24 | - name: Generate doc 25 | run: | 26 | eval $PODMAN_RUN cargo doc --no-deps 27 | echo "" > target/doc/index.html 28 | - name: Upload doc to gh pages 29 | uses: peaceiris/actions-gh-pages@v3 30 | with: 31 | github_token: ${{ secrets.GITHUB_TOKEN }} 32 | publish_dir: ./target/doc 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 
2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "aho-corasick" 7 | version = "0.7.19" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "b4f55bd91a0978cbfd91c457a164bab8b4001c833b7f323132c0a4e1922dd44e" 10 | dependencies = [ 11 | "memchr", 12 | ] 13 | 14 | [[package]] 15 | name = "cc" 16 | version = "1.0.73" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" 19 | 20 | [[package]] 21 | name = "lazy_static" 22 | version = "1.4.0" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 25 | 26 | [[package]] 27 | name = "libc" 28 | version = "0.2.133" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "c0f80d65747a3e43d1596c7c5492d95d5edddaabd45a7fcdb02b95f644164966" 31 | 32 | [[package]] 33 | name = "llvm-kaleidoscope-rs" 34 | version = "0.1.0" 35 | dependencies = [ 36 | "libc", 37 | "llvm-sys", 38 | ] 39 | 40 | [[package]] 41 | name = "llvm-sys" 42 | version = "160.2.0" 43 | source = "registry+https://github.com/rust-lang/crates.io-index" 44 | checksum = "0438b23666723e3851fe336c1757acd3c915a0a14b41cc41284143609ca6cdce" 45 | dependencies = [ 46 | "cc", 47 | "lazy_static", 48 | "libc", 49 | "regex", 50 | "semver", 51 | ] 52 | 53 | [[package]] 54 | name = "memchr" 55 | version = "2.5.0" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" 58 | 59 | [[package]] 60 | name = "regex" 61 | version = "1.6.0" 62 | source = "registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b" 64 | dependencies = [ 65 | "aho-corasick", 66 | "memchr", 67 | "regex-syntax", 68 | ] 69 | 70 | [[package]] 71 | 
name = "regex-syntax" 72 | version = "0.6.27" 73 | source = "registry+https://github.com/rust-lang/crates.io-index" 74 | checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" 75 | 76 | [[package]] 77 | name = "semver" 78 | version = "1.0.16" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" 81 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llvm-kaleidoscope-rs" 3 | version = "0.1.0" 4 | edition = "2018" 5 | 6 | [dependencies] 7 | libc = "0.2" 8 | llvm-sys = {version = "160.0", features = ["strict-versioning"]} 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Johannes Stoelp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llvm-kaleidoscope-rs 2 | 3 | [![Rust][wf-badge]][wf-output] [![Rustdoc][doc-badge]][doc-html] 4 | 5 | [wf-output]: https://github.com/johannst/llvm-kaleidoscope-rs/actions/workflows/check.yml 6 | [wf-badge]: https://github.com/johannst/llvm-kaleidoscope-rs/actions/workflows/check.yml/badge.svg 7 | [doc-html]: https://johannst.github.io/llvm-kaleidoscope-rs/llvm_kaleidoscope_rs/index.html 8 | [doc-badge]: https://img.shields.io/badge/llvm__kaleidoscope__rs-rustdoc-blue.svg?style=flat&logo=rust 9 | 10 | The purpose of this repository is to learn about the [`llvm`][llvm] compiler 11 | infrastructure and practice some [`rust-lang`][rust]. 12 | 13 | To reach the goals set, we follow the official llvm tutorial [`Kaleidoscope: 14 | Implementing a Language with LLVM`][llvm-tutorial]. This tutorial is written in 15 | `C++` and structured in multiple chapters, we will try to follow along and 16 | implement every chapter in rust. 17 | 18 | The topics of the chapters are as follows: 19 | 20 | - Chapter 1: [Kaleidoscope Introduction and the Lexer][llvm-ch1] 21 | - Chapter 2: [Implementing a Parser and AST][llvm-ch2] 22 | - Chapter 3: [Code generation to LLVM IR][llvm-ch3] 23 | - Chapter 4: [Adding JIT and Optimizer Support][llvm-ch4] 24 | - Chapter 5: [Extending the Language: Control Flow][llvm-ch5] 25 | 26 | The implementation after each chapter can be compiled and executed by checking 27 | out the corresponding tag for the chapter. 
28 | ```bash 29 | > git tag -l 30 | chapter1 31 | chapter2 32 | chapter3 33 | chapter4 34 | chapter5 35 | ``` 36 | 37 | Names of variables and functions as well as the structure of the functions are 38 | mainly kept aligned with the official tutorial. This aims to make it easy to 39 | map the `rust` implementation onto the `C++` implementation when following the 40 | tutorial. 41 | 42 | One further note on the llvm API: instead of using the llvm `C++` API we are 43 | going to use the llvm `C` API and build our own safe wrapper specialized for 44 | this tutorial. The wrapper offers a similar interface as the `C++` API and is 45 | implemented in [`src/llvm/`](src/llvm/). 46 | 47 | ## Demo 48 | 49 | ```bash 50 | # Run kaleidoscope program from file. 51 | cargo run ks/forloop.ks 52 | 53 | # Run REPL loop, parsing from stdin. 54 | cargo run 55 | ``` 56 | 57 | ## Documentation 58 | 59 | Rustdoc for this crate is available at 60 | [johannst.github.io/llvm-kaleidoscope-rs][gh-pages]. 61 | 62 | ## Build with provided container file 63 | 64 | The provided [Dockerfile](container/Dockerfile) documents the required 65 | dependencies for an Ubuntu-based system and serves as a build environment with 66 | the correct llvm version as specified in the [Cargo.toml](Cargo.toml) file. 67 | 68 | ```bash 69 | ## Either use podman .. 70 | 71 | # Build the image *ks-rs*. Depending on the downlink this may take some minutes. 72 | make -C container 73 | 74 | podman run --rm -it -v $PWD:/work -w /work ks-rs 75 | # Drops into a shell in the container, just use cargo build / run ... 76 | 77 | ## .. or docker. 78 | 79 | # Build the image *ks-rs*. Depending on the downlink this may take some minutes. 80 | make -C container docker 81 | 82 | docker run --rm -it -v $PWD:/work -w /work ks-rs 83 | # Drops into a shell in the container, just use cargo build / run ... 84 | ``` 85 | 86 | ## License 87 | 88 | This project is licensed under the [MIT](LICENSE) license. 
89 | 90 | [llvm]: https://llvm.org 91 | [llvm-tutorial]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html 92 | [llvm-ch1]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl01.html 93 | [llvm-ch2]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl02.html 94 | [llvm-ch3]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl03.html 95 | [llvm-ch4]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl04.html 96 | [llvm-ch5]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl05.html 97 | [rust]: https://www.rust-lang.org 98 | [gh-pages]: https://johannst.github.io/llvm-kaleidoscope-rs/llvm_kaleidoscope_rs/index.html 99 | -------------------------------------------------------------------------------- /container/Dockerfile: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: MIT 2 | # 3 | # Copyright (c) 2021, Johannes Stoelp 4 | 5 | FROM ubuntu 6 | 7 | RUN apt update && \ 8 | DEBIAN_FRONTEND=noninteractive \ 9 | apt install \ 10 | --yes \ 11 | --no-install-recommends \ 12 | ca-certificates \ 13 | build-essential \ 14 | cargo \ 15 | llvm-16-dev \ 16 | # For polly dependency. 17 | # https://gitlab.com/taricorp/llvm-sys.rs/-/issues/13 18 | libpolly-16-dev \ 19 | libz-dev \ 20 | libzstd-dev \ 21 | && \ 22 | rm -rf /var/lib/apt/lists/* && \ 23 | apt-get clean 24 | -------------------------------------------------------------------------------- /container/Makefile: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: MIT 2 | # 3 | # Copyright (c) 2021, Johannes Stoelp 4 | 5 | podman: build-podman 6 | docker: build-docker 7 | 8 | build-%: 9 | $* build -t ks-rs . 
10 | -------------------------------------------------------------------------------- /ks/forloop.ks: -------------------------------------------------------------------------------- 1 | extern putchard(char); 2 | def printstar(n) 3 | for i = 1, i < n, 1.0 in 4 | putchard(42); # ascii 42 = '*' 5 | 6 | # print 100 '*' characters 7 | printstar(100); 8 | -------------------------------------------------------------------------------- /ks/ifelse.ks: -------------------------------------------------------------------------------- 1 | def fib(x) 2 | if x < 3 then 3 | 1 4 | else 5 | fib(x-1)+fib(x-2); 6 | 7 | fib(10); 8 | -------------------------------------------------------------------------------- /src/codegen.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use std::collections::HashMap; 6 | 7 | use crate::llvm::{FnValue, FunctionPassManager, IRBuilder, Module, Value}; 8 | use crate::parser::{ExprAST, FunctionAST, PrototypeAST}; 9 | use crate::Either; 10 | 11 | type CodegenResult = Result; 12 | 13 | /// Code generator from kaleidoscope AST to LLVM IR. 14 | pub struct Codegen<'llvm, 'a> { 15 | module: &'llvm Module, 16 | builder: &'a IRBuilder<'llvm>, 17 | fpm: &'a FunctionPassManager<'llvm>, 18 | fn_protos: &'a mut HashMap, 19 | } 20 | 21 | impl<'llvm, 'a> Codegen<'llvm, 'a> { 22 | /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`. 
23 | pub fn compile( 24 | module: &'llvm Module, 25 | fn_protos: &mut HashMap, 26 | compilee: Either<&PrototypeAST, &FunctionAST>, 27 | ) -> CodegenResult> { 28 | let mut cg = Codegen { 29 | module, 30 | builder: &IRBuilder::with_ctx(module), 31 | fpm: &FunctionPassManager::with_ctx(module), 32 | fn_protos, 33 | }; 34 | let mut variables = HashMap::new(); 35 | 36 | match compilee { 37 | Either::A(proto) => Ok(cg.codegen_prototype(proto)), 38 | Either::B(func) => cg.codegen_function(func, &mut variables), 39 | } 40 | } 41 | 42 | fn codegen_expr( 43 | &self, 44 | expr: &ExprAST, 45 | named_values: &mut HashMap>, 46 | ) -> CodegenResult> { 47 | match expr { 48 | ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)), 49 | ExprAST::Variable(name) => match named_values.get(name.as_str()) { 50 | Some(value) => Ok(*value), 51 | None => Err("Unknown variable name".into()), 52 | }, 53 | ExprAST::Binary(binop, lhs, rhs) => { 54 | let l = self.codegen_expr(lhs, named_values)?; 55 | let r = self.codegen_expr(rhs, named_values)?; 56 | 57 | match binop { 58 | '+' => Ok(self.builder.fadd(l, r)), 59 | '-' => Ok(self.builder.fsub(l, r)), 60 | '*' => Ok(self.builder.fmul(l, r)), 61 | '<' => { 62 | let res = self.builder.fcmpult(l, r); 63 | // Turn bool into f64. 64 | Ok(self.builder.uitofp(res, self.module.type_f64())) 65 | } 66 | _ => Err("invalid binary operator".into()), 67 | } 68 | } 69 | ExprAST::Call(callee, args) => match self.get_function(callee) { 70 | Some(callee) => { 71 | if callee.args() != args.len() { 72 | return Err("Incorrect # arguments passed".into()); 73 | } 74 | 75 | // Generate code for function argument expressions. 
76 | let mut args: Vec> = args 77 | .iter() 78 | .map(|arg| self.codegen_expr(arg, named_values)) 79 | .collect::>()?; 80 | 81 | Ok(self.builder.call(callee, &mut args)) 82 | } 83 | None => Err("Unknown function referenced".into()), 84 | }, 85 | ExprAST::If { cond, then, else_ } => { 86 | // For 'if' expressions we are building the following CFG. 87 | // 88 | // ; cond 89 | // br 90 | // | 91 | // +-----+------+ 92 | // v v 93 | // ; then ; else 94 | // | | 95 | // +-----+------+ 96 | // v 97 | // ; merge 98 | // phi then, else 99 | // ret phi 100 | 101 | let cond_v = { 102 | // Codgen 'cond' expression. 103 | let v = self.codegen_expr(cond, named_values)?; 104 | // Compare 'v' against '0' as 'one = ordered not equal'. 105 | self.builder 106 | .fcmpone(v, self.module.type_f64().const_f64(0f64)) 107 | }; 108 | 109 | // Get the function we are currently inserting into. 110 | let the_function = self.builder.get_insert_block().get_parent(); 111 | 112 | // Create basic blocks for the 'then' / 'else' expressions as well as the return 113 | // instruction ('merge'). 114 | // 115 | // Append the 'then' basic block to the function, don't insert the 'else' and 116 | // 'merge' basic blocks yet. 117 | let then_bb = self.module.append_basic_block(the_function); 118 | let else_bb = self.module.create_basic_block(); 119 | let merge_bb = self.module.create_basic_block(); 120 | 121 | // Create a conditional branch based on the result of the 'cond' expression. 122 | self.builder.cond_br(cond_v, then_bb, else_bb); 123 | 124 | // Move to 'then' basic block and codgen the 'then' expression. 125 | self.builder.pos_at_end(then_bb); 126 | let then_v = self.codegen_expr(then, named_values)?; 127 | // Create unconditional branch to 'merge' block. 128 | self.builder.br(merge_bb); 129 | // Update reference to current basic block (in case the 'then' expression added new 130 | // basic blocks). 
131 | let then_bb = self.builder.get_insert_block(); 132 | 133 | // Now append the 'else' basic block to the function. 134 | the_function.append_basic_block(else_bb); 135 | // Move to 'else' basic block and codgen the 'else' expression. 136 | self.builder.pos_at_end(else_bb); 137 | let else_v = self.codegen_expr(else_, named_values)?; 138 | // Create unconditional branch to 'merge' block. 139 | self.builder.br(merge_bb); 140 | // Update reference to current basic block (in case the 'else' expression added new 141 | // basic blocks). 142 | let else_bb = self.builder.get_insert_block(); 143 | 144 | // Now append the 'merge' basic block to the function. 145 | the_function.append_basic_block(merge_bb); 146 | // Move to 'merge' basic block. 147 | self.builder.pos_at_end(merge_bb); 148 | // Codegen the phi node returning the appropriate value depending on the branch 149 | // condition. 150 | let phi = self.builder.phi( 151 | self.module.type_f64(), 152 | &[(then_v, then_bb), (else_v, else_bb)], 153 | ); 154 | 155 | Ok(*phi) 156 | } 157 | ExprAST::For { 158 | var, 159 | start, 160 | end, 161 | step, 162 | body, 163 | } => { 164 | // For 'for' expression we build the following structure. 165 | // 166 | // entry: 167 | // init = start expression 168 | // br loop 169 | // loop: 170 | // i = phi [%init, %entry], [%new_i, %loop] 171 | // ; loop body ... 172 | // new_i = increment %i by step expression 173 | // ; check end condition and branch 174 | // end: 175 | 176 | // Compute initial value for the loop variable. 177 | let start_val = self.codegen_expr(start, named_values)?; 178 | 179 | let the_function = self.builder.get_insert_block().get_parent(); 180 | // Get current basic block (used in the loop variable phi node). 181 | let entry_bb = self.builder.get_insert_block(); 182 | // Add new basic block to emit loop body. 
183 | let loop_bb = self.module.append_basic_block(the_function); 184 | 185 | self.builder.br(loop_bb); 186 | self.builder.pos_at_end(loop_bb); 187 | 188 | // Build phi not to pick loop variable in case we come from the 'entry' block. 189 | // Which is the case when we enter the loop for the first time. 190 | // We will add another incoming value once we computed the updated loop variable 191 | // below. 192 | let variable = self 193 | .builder 194 | .phi(self.module.type_f64(), &[(start_val, entry_bb)]); 195 | 196 | // Insert the loop variable into the named values map that it can be referenced 197 | // from the body as well as the end condition. 198 | // In case the loop variable shadows an existing variable remember the shared one. 199 | let old_val = named_values.insert(var.into(), *variable); 200 | 201 | // Generate the loop body. 202 | self.codegen_expr(body, named_values)?; 203 | 204 | // Generate step value expression if available else use '1'. 205 | let step_val = if let Some(step) = step { 206 | self.codegen_expr(step, named_values)? 207 | } else { 208 | self.module.type_f64().const_f64(1f64) 209 | }; 210 | 211 | // Increment loop variable. 212 | let next_var = self.builder.fadd(*variable, step_val); 213 | 214 | // Generate the loop end condition. 215 | let end_cond = self.codegen_expr(end, named_values)?; 216 | let end_cond = self 217 | .builder 218 | .fcmpone(end_cond, self.module.type_f64().const_f64(0f64)); 219 | 220 | // Get current basic block. 221 | let loop_end_bb = self.builder.get_insert_block(); 222 | // Add new basic block following the loop. 223 | let after_bb = self.module.append_basic_block(the_function); 224 | 225 | // Register additional incoming value for the loop variable. This will choose the 226 | // updated loop variable if we are iterating in the loop. 227 | variable.add_incoming(next_var, loop_end_bb); 228 | 229 | // Branch depending on the loop end condition. 
230 | self.builder.cond_br(end_cond, loop_bb, after_bb); 231 | 232 | self.builder.pos_at_end(after_bb); 233 | 234 | // Restore the shadowed variable if there was one. 235 | if let Some(old_val) = old_val { 236 | // We inserted 'var' above so it must exist. 237 | *named_values.get_mut(var).unwrap() = old_val; 238 | } else { 239 | named_values.remove(var); 240 | } 241 | 242 | // Loops just always return 0. 243 | Ok(self.module.type_f64().const_f64(0f64)) 244 | } 245 | } 246 | } 247 | 248 | fn codegen_prototype(&self, PrototypeAST(name, args): &PrototypeAST) -> FnValue<'llvm> { 249 | let type_f64 = self.module.type_f64(); 250 | 251 | let mut doubles = Vec::new(); 252 | doubles.resize(args.len(), type_f64); 253 | 254 | // Build the function type: fn(f64, f64, ..) -> f64 255 | let ft = self.module.type_fn(&mut doubles, type_f64); 256 | 257 | // Create the function declaration. 258 | let f = self.module.add_fn(name, ft); 259 | 260 | // Set the names of the function arguments. 261 | for idx in 0..f.args() { 262 | f.arg(idx).set_name(&args[idx]); 263 | } 264 | 265 | f 266 | } 267 | 268 | fn codegen_function( 269 | &mut self, 270 | FunctionAST(proto, body): &FunctionAST, 271 | named_values: &mut HashMap>, 272 | ) -> CodegenResult> { 273 | // Insert the function prototype into the `fn_protos` map to keep track for re-generating 274 | // declarations in other modules. 275 | self.fn_protos.insert(proto.0.clone(), proto.clone()); 276 | 277 | let the_function = self.get_function(&proto.0) 278 | .expect("If proto not already generated, get_function will do for us since we updated fn_protos before-hand!"); 279 | 280 | if the_function.basic_blocks() > 0 { 281 | return Err("Function cannot be redefined.".into()); 282 | } 283 | 284 | // Create entry basic block to insert code. 285 | let bb = self.module.append_basic_block(the_function); 286 | self.builder.pos_at_end(bb); 287 | 288 | // New scope, clear the map with the function args. 
289 | named_values.clear(); 290 | 291 | // Update the map with the current functions args. 292 | for idx in 0..the_function.args() { 293 | let arg = the_function.arg(idx); 294 | named_values.insert(arg.get_name().into(), arg); 295 | } 296 | 297 | // Codegen function body. 298 | if let Ok(ret) = self.codegen_expr(body, named_values) { 299 | self.builder.ret(ret); 300 | assert!(the_function.verify()); 301 | 302 | // Run the optimization passes on the function. 303 | self.fpm.run(the_function); 304 | 305 | Ok(the_function) 306 | } else { 307 | todo!("Failed to codegen function body, erase from module!"); 308 | } 309 | } 310 | 311 | /// Lookup function with `name` in the LLVM module and return the corresponding value reference. 312 | /// If the function is not available in the module, check if the prototype is known and codegen 313 | /// it. 314 | /// Return [`None`] if the prototype is not known. 315 | fn get_function(&self, name: &str) -> Option> { 316 | let callee = match self.module.get_fn(name) { 317 | Some(callee) => callee, 318 | None => { 319 | let proto = self.fn_protos.get(name)?; 320 | self.codegen_prototype(proto) 321 | } 322 | }; 323 | 324 | Some(callee) 325 | } 326 | } 327 | -------------------------------------------------------------------------------- /src/lexer.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | #[derive(Debug, PartialEq)] 6 | pub enum Token { 7 | Eof, 8 | Def, 9 | Extern, 10 | Identifier(String), 11 | Number(f64), 12 | Char(char), 13 | If, 14 | Then, 15 | Else, 16 | For, 17 | In, 18 | } 19 | 20 | pub struct Lexer 21 | where 22 | I: Iterator, 23 | { 24 | input: I, 25 | last_char: Option, 26 | } 27 | 28 | impl Lexer 29 | where 30 | I: Iterator, 31 | { 32 | pub fn new(mut input: I) -> Lexer { 33 | let last_char = input.next(); 34 | Lexer { input, last_char } 35 | } 36 | 37 | fn step(&mut self) -> Option { 38 | 
self.last_char = self.input.next(); 39 | self.last_char 40 | } 41 | 42 | /// Lex and return the next token. 43 | /// 44 | /// Implement `int gettok();` from the tutorial. 45 | pub fn gettok(&mut self) -> Token { 46 | // Eat up whitespaces. 47 | while matches!(self.last_char, Some(c) if c.is_ascii_whitespace()) { 48 | self.step(); 49 | } 50 | 51 | // Unpack last char or return EOF. 52 | let last_char = if let Some(c) = self.last_char { 53 | c 54 | } else { 55 | return Token::Eof; 56 | }; 57 | 58 | // Identifier: [a-zA-Z][a-zA-Z0-9]* 59 | if last_char.is_ascii_alphabetic() { 60 | let mut ident = String::new(); 61 | ident.push(last_char); 62 | 63 | while let Some(c) = self.step() { 64 | if c.is_ascii_alphanumeric() { 65 | ident.push(c) 66 | } else { 67 | break; 68 | } 69 | } 70 | 71 | match ident.as_ref() { 72 | "def" => return Token::Def, 73 | "extern" => return Token::Extern, 74 | "if" => return Token::If, 75 | "then" => return Token::Then, 76 | "else" => return Token::Else, 77 | "for" => return Token::For, 78 | "in" => return Token::In, 79 | _ => {} 80 | } 81 | 82 | return Token::Identifier(ident); 83 | } 84 | 85 | // Number: [0-9.]+ 86 | if last_char.is_ascii_digit() || last_char == '.' { 87 | let mut num = String::new(); 88 | num.push(last_char); 89 | 90 | while let Some(c) = self.step() { 91 | if c.is_ascii_digit() || c == '.' { 92 | num.push(c) 93 | } else { 94 | break; 95 | } 96 | } 97 | 98 | let num: f64 = num.parse().unwrap_or_default(); 99 | return Token::Number(num); 100 | } 101 | 102 | // Eat up comment. 103 | if last_char == '#' { 104 | loop { 105 | match self.step() { 106 | Some(c) if c == '\r' || c == '\n' => return self.gettok(), 107 | None => return Token::Eof, 108 | _ => { /* consume comment */ } 109 | } 110 | } 111 | } 112 | 113 | // Advance last char and return currently last char. 
114 | self.step(); 115 | Token::Char(last_char) 116 | } 117 | } 118 | 119 | #[cfg(test)] 120 | mod test { 121 | use super::{Lexer, Token}; 122 | 123 | #[test] 124 | fn test_identifier() { 125 | let mut lex = Lexer::new("a b c".chars()); 126 | assert_eq!(Token::Identifier("a".into()), lex.gettok()); 127 | assert_eq!(Token::Identifier("b".into()), lex.gettok()); 128 | assert_eq!(Token::Identifier("c".into()), lex.gettok()); 129 | assert_eq!(Token::Eof, lex.gettok()); 130 | } 131 | 132 | #[test] 133 | fn test_keyword() { 134 | let mut lex = Lexer::new("def extern".chars()); 135 | assert_eq!(Token::Def, lex.gettok()); 136 | assert_eq!(Token::Extern, lex.gettok()); 137 | assert_eq!(Token::Eof, lex.gettok()); 138 | } 139 | 140 | #[test] 141 | fn test_number() { 142 | let mut lex = Lexer::new("12.34".chars()); 143 | assert_eq!(Token::Number(12.34f64), lex.gettok()); 144 | assert_eq!(Token::Eof, lex.gettok()); 145 | 146 | let mut lex = Lexer::new(" 1.0 2.0 3.0".chars()); 147 | assert_eq!(Token::Number(1.0f64), lex.gettok()); 148 | assert_eq!(Token::Number(2.0f64), lex.gettok()); 149 | assert_eq!(Token::Number(3.0f64), lex.gettok()); 150 | assert_eq!(Token::Eof, lex.gettok()); 151 | 152 | let mut lex = Lexer::new("12.34.56".chars()); 153 | assert_eq!(Token::Number(0f64), lex.gettok()); 154 | assert_eq!(Token::Eof, lex.gettok()); 155 | } 156 | 157 | #[test] 158 | fn test_comment() { 159 | let mut lex = Lexer::new("# some comment".chars()); 160 | assert_eq!(Token::Eof, lex.gettok()); 161 | 162 | let mut lex = Lexer::new("abc # some comment \n xyz".chars()); 163 | assert_eq!(Token::Identifier("abc".into()), lex.gettok()); 164 | assert_eq!(Token::Identifier("xyz".into()), lex.gettok()); 165 | assert_eq!(Token::Eof, lex.gettok()); 166 | } 167 | 168 | #[test] 169 | fn test_chars() { 170 | let mut lex = Lexer::new("a+b-c".chars()); 171 | assert_eq!(Token::Identifier("a".into()), lex.gettok()); 172 | assert_eq!(Token::Char('+'), lex.gettok()); 173 | 
assert_eq!(Token::Identifier("b".into()), lex.gettok()); 174 | assert_eq!(Token::Char('-'), lex.gettok()); 175 | assert_eq!(Token::Identifier("c".into()), lex.gettok()); 176 | assert_eq!(Token::Eof, lex.gettok()); 177 | } 178 | 179 | #[test] 180 | fn test_whitespaces() { 181 | let mut lex = Lexer::new(" +a b c! ".chars()); 182 | assert_eq!(Token::Char('+'), lex.gettok()); 183 | assert_eq!(Token::Identifier("a".into()), lex.gettok()); 184 | assert_eq!(Token::Identifier("b".into()), lex.gettok()); 185 | assert_eq!(Token::Identifier("c".into()), lex.gettok()); 186 | assert_eq!(Token::Char('!'), lex.gettok()); 187 | assert_eq!(Token::Eof, lex.gettok()); 188 | 189 | let mut lex = Lexer::new("\n a \n\r b \r \n c \r\r \n ".chars()); 190 | assert_eq!(Token::Identifier("a".into()), lex.gettok()); 191 | assert_eq!(Token::Identifier("b".into()), lex.gettok()); 192 | assert_eq!(Token::Identifier("c".into()), lex.gettok()); 193 | assert_eq!(Token::Eof, lex.gettok()); 194 | } 195 | 196 | #[test] 197 | fn test_ite() { 198 | let mut lex = Lexer::new("if then else".chars()); 199 | assert_eq!(Token::If, lex.gettok()); 200 | assert_eq!(Token::Then, lex.gettok()); 201 | assert_eq!(Token::Else, lex.gettok()); 202 | } 203 | 204 | #[test] 205 | fn test_for() { 206 | let mut lex = Lexer::new("for in".chars()); 207 | assert_eq!(Token::For, lex.gettok()); 208 | assert_eq!(Token::In, lex.gettok()); 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use std::convert::TryFrom; 6 | 7 | pub mod codegen; 8 | pub mod lexer; 9 | pub mod llvm; 10 | pub mod parser; 11 | 12 | /// Fixed size of [`SmallCStr`] including the trailing `\0` byte. 13 | pub const SMALL_STR_SIZE: usize = 16; 14 | 15 | /// Small C string on the stack with fixed size [`SMALL_STR_SIZE`]. 
16 | /// 17 | /// This is specially crafted to interact with the LLVM C API and get rid of some heap allocations. 18 | #[derive(Debug, PartialEq)] 19 | pub struct SmallCStr([u8; SMALL_STR_SIZE]); 20 | 21 | impl SmallCStr { 22 | /// Create a new C string from `src`. 23 | /// Returns [`None`] if `src` exceeds the fixed size or contains any `\0` bytes. 24 | pub fn new>(src: &T) -> Option { 25 | let src = src.as_ref(); 26 | let len = src.len(); 27 | 28 | // Check for \0 bytes. 29 | let contains_null = unsafe { !libc::memchr(src.as_ptr().cast(), 0, len).is_null() }; 30 | 31 | if contains_null || len > SMALL_STR_SIZE - 1 { 32 | None 33 | } else { 34 | let mut dest = [0; SMALL_STR_SIZE]; 35 | dest[..len].copy_from_slice(src); 36 | Some(SmallCStr(dest)) 37 | } 38 | } 39 | 40 | /// Return pointer to C string. 41 | pub const fn as_ptr(&self) -> *const libc::c_char { 42 | self.0.as_ptr().cast() 43 | } 44 | } 45 | 46 | impl TryFrom<&str> for SmallCStr { 47 | type Error = (); 48 | 49 | fn try_from(value: &str) -> Result { 50 | SmallCStr::new(&value).ok_or(()) 51 | } 52 | } 53 | 54 | /// Either type, for APIs accepting two types. 
55 | pub enum Either { 56 | A(A), 57 | B(B), 58 | } 59 | 60 | #[cfg(test)] 61 | mod test { 62 | use super::{SmallCStr, SMALL_STR_SIZE}; 63 | use std::convert::TryInto; 64 | 65 | #[test] 66 | fn test_create() { 67 | let src = "\x30\x31\x32\x33"; 68 | let scs = SmallCStr::new(&src).unwrap(); 69 | assert_eq!(&scs.0[..5], &[0x30, 0x31, 0x32, 0x33, 0x00]); 70 | 71 | let src = b"abcd1234"; 72 | let scs = SmallCStr::new(&src).unwrap(); 73 | assert_eq!( 74 | &scs.0[..9], 75 | &[0x61, 0x62, 0x63, 0x64, 0x31, 0x32, 0x33, 0x34, 0x00] 76 | ); 77 | } 78 | 79 | #[test] 80 | fn test_contain_null() { 81 | let src = "\x30\x00\x32\x33"; 82 | let scs = SmallCStr::new(&src); 83 | assert_eq!(scs, None); 84 | 85 | let src = "\x30\x31\x32\x33\x00"; 86 | let scs = SmallCStr::new(&src); 87 | assert_eq!(scs, None); 88 | } 89 | 90 | #[test] 91 | fn test_too_large() { 92 | let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::(); 93 | let scs = SmallCStr::new(&src); 94 | assert_eq!(scs, None); 95 | 96 | let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::(); 97 | let scs = SmallCStr::new(&src); 98 | assert_eq!(scs, None); 99 | } 100 | 101 | #[test] 102 | fn test_try_into() { 103 | let src = "\x30\x31\x32\x33"; 104 | let scs: Result = src.try_into(); 105 | assert!(scs.is_ok()); 106 | 107 | let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::(); 108 | let scs: Result = src.as_str().try_into(); 109 | assert!(scs.is_err()); 110 | 111 | let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::(); 112 | let scs: Result = src.as_str().try_into(); 113 | assert!(scs.is_err()); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/llvm/basic_block.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_sys::{core::LLVMGetBasicBlockParent, prelude::LLVMBasicBlockRef}; 6 | 7 | use std::marker::PhantomData; 8 | 9 | use 
super::FnValue; 10 | 11 | /// Wrapper for an LLVM Basic Block. 12 | #[derive(Copy, Clone)] 13 | pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>); 14 | 15 | impl<'llvm> BasicBlock<'llvm> { 16 | /// Create a new BasicBlock instance. 17 | /// 18 | /// # Panics 19 | /// 20 | /// Panics if `bb_ref` is a null pointer. 21 | pub(super) fn new(bb_ref: LLVMBasicBlockRef) -> BasicBlock<'llvm> { 22 | assert!(!bb_ref.is_null()); 23 | BasicBlock(bb_ref, PhantomData) 24 | } 25 | 26 | /// Get the raw LLVM basic block reference. 27 | #[inline] 28 | pub(super) fn bb_ref(&self) -> LLVMBasicBlockRef { 29 | self.0 30 | } 31 | 32 | /// Get the function to which the basic block belongs. 33 | /// 34 | /// # Panics 35 | /// 36 | /// Panics if LLVM API returns a `null` pointer. 37 | pub fn get_parent(&self) -> FnValue<'llvm> { 38 | let value_ref = unsafe { LLVMGetBasicBlockParent(self.bb_ref()) }; 39 | assert!(!value_ref.is_null()); 40 | 41 | FnValue::new(value_ref) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/llvm/builder.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_sys::{ 6 | core::{ 7 | LLVMAddIncoming, LLVMBuildBr, LLVMBuildCondBr, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul, 8 | LLVMBuildFSub, LLVMBuildPhi, LLVMBuildRet, LLVMBuildUIToFP, LLVMCreateBuilderInContext, 9 | LLVMDisposeBuilder, LLVMGetInsertBlock, LLVMPositionBuilderAtEnd, 10 | }, 11 | prelude::{LLVMBuilderRef, LLVMValueRef}, 12 | LLVMRealPredicate, 13 | }; 14 | 15 | use std::marker::PhantomData; 16 | 17 | use super::{BasicBlock, FnValue, Module, PhiValue, Type, Value}; 18 | 19 | // Definition of LLVM C API functions using our `repr(transparent)` types.
20 | extern "C" { 21 | fn LLVMBuildCall2( 22 | arg1: LLVMBuilderRef, 23 | arg2: Type<'_>, 24 | Fn: FnValue<'_>, 25 | Args: *mut Value<'_>, 26 | NumArgs: ::libc::c_uint, 27 | Name: *const ::libc::c_char, 28 | ) -> LLVMValueRef; 29 | } 30 | 31 | /// Wrapper for an LLVM IR Builder. 32 | pub struct IRBuilder<'llvm> { 33 | builder: LLVMBuilderRef, 34 | _ctx: PhantomData<&'llvm ()>, 35 | } 36 | 37 | impl<'llvm> IRBuilder<'llvm> { 38 | /// Create a new LLVM IR Builder in the `module`'s context. 39 | /// 40 | /// # Panics 41 | /// 42 | /// Panics if creating the IR Builder fails. 43 | pub fn with_ctx(module: &'llvm Module) -> IRBuilder<'llvm> { 44 | let builder = unsafe { LLVMCreateBuilderInContext(module.ctx()) }; 45 | assert!(!builder.is_null()); 46 | 47 | IRBuilder { 48 | builder, 49 | _ctx: PhantomData, 50 | } 51 | } 52 | 53 | /// Position the IR Builder at the end of the given Basic Block. 54 | pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) { 55 | unsafe { 56 | LLVMPositionBuilderAtEnd(self.builder, bb.bb_ref()); 57 | } 58 | } 59 | 60 | /// Get the BasicBlock the IRBuilder currently inserts into. 61 | /// 62 | /// # Panics 63 | /// 64 | /// Panics if LLVM API returns a `null` pointer. 65 | pub fn get_insert_block(&self) -> BasicBlock<'llvm> { 66 | let bb_ref = unsafe { LLVMGetInsertBlock(self.builder) }; 67 | assert!(!bb_ref.is_null()); 68 | 69 | BasicBlock::new(bb_ref) 70 | } 71 | 72 | /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction. 73 | /// 74 | /// # Panics 75 | /// 76 | /// Panics if LLVM API returns a `null` pointer.
77 | pub fn fadd(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { 78 | debug_assert!(lhs.is_f64(), "fadd: Expected f64 as lhs operand!"); 79 | debug_assert!(rhs.is_f64(), "fadd: Expected f64 as rhs operand!"); 80 | 81 | let value_ref = unsafe { 82 | LLVMBuildFAdd( 83 | self.builder, 84 | lhs.value_ref(), 85 | rhs.value_ref(), 86 | b"fadd\0".as_ptr().cast(), 87 | ) 88 | }; 89 | Value::new(value_ref) 90 | } 91 | 92 | /// Emit a [fsub](https://llvm.org/docs/LangRef.html#fsub-instruction) instruction. 93 | /// 94 | /// # Panics 95 | /// 96 | /// Panics if LLVM API returns a `null` pointer. 97 | pub fn fsub(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { 98 | debug_assert!(lhs.is_f64(), "fsub: Expected f64 as lhs operand!"); 99 | debug_assert!(rhs.is_f64(), "fsub: Expected f64 as rhs operand!"); 100 | 101 | let value_ref = unsafe { 102 | LLVMBuildFSub( 103 | self.builder, 104 | lhs.value_ref(), 105 | rhs.value_ref(), 106 | b"fsub\0".as_ptr().cast(), 107 | ) 108 | }; 109 | Value::new(value_ref) 110 | } 111 | 112 | /// Emit a [fmul](https://llvm.org/docs/LangRef.html#fmul-instruction) instruction. 113 | /// 114 | /// # Panics 115 | /// 116 | /// Panics if LLVM API returns a `null` pointer. 117 | pub fn fmul(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { 118 | debug_assert!(lhs.is_f64(), "fmul: Expected f64 as lhs operand!"); 119 | debug_assert!(rhs.is_f64(), "fmul: Expected f64 as rhs operand!"); 120 | 121 | let value_ref = unsafe { 122 | LLVMBuildFMul( 123 | self.builder, 124 | lhs.value_ref(), 125 | rhs.value_ref(), 126 | b"fmul\0".as_ptr().cast(), 127 | ) 128 | }; 129 | Value::new(value_ref) 130 | } 131 | 132 | /// Emit a [fcmpult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction. 133 | /// 134 | /// # Panics 135 | /// 136 | /// Panics if LLVM API returns a `null` pointer. 
137 | pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { 138 | debug_assert!(lhs.is_f64(), "fcmpult: Expected f64 as lhs operand!"); 139 | debug_assert!(rhs.is_f64(), "fcmpult: Expected f64 as rhs operand!"); 140 | 141 | let value_ref = unsafe { 142 | LLVMBuildFCmp( 143 | self.builder, 144 | LLVMRealPredicate::LLVMRealULT, 145 | lhs.value_ref(), 146 | rhs.value_ref(), 147 | b"fcmpult\0".as_ptr().cast(), 148 | ) 149 | }; 150 | Value::new(value_ref) 151 | } 152 | 153 | /// Emit a [fcmpone](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction (`fcmp` built with the `LLVMRealONE` predicate). 154 | /// 155 | /// # Panics 156 | /// 157 | /// Panics if LLVM API returns a `null` pointer. 158 | pub fn fcmpone(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> { 159 | debug_assert!(lhs.is_f64(), "fcmpone: Expected f64 as lhs operand!"); 160 | debug_assert!(rhs.is_f64(), "fcmpone: Expected f64 as rhs operand!"); 161 | 162 | let value_ref = unsafe { 163 | LLVMBuildFCmp( 164 | self.builder, 165 | LLVMRealPredicate::LLVMRealONE, 166 | lhs.value_ref(), 167 | rhs.value_ref(), 168 | b"fcmpone\0".as_ptr().cast(), 169 | ) 170 | }; 171 | Value::new(value_ref) 172 | } 173 | 174 | /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction. 175 | /// 176 | /// # Panics 177 | /// 178 | /// Panics if LLVM API returns a `null` pointer. 179 | pub fn uitofp(&self, val: Value<'llvm>, dest_type: Type<'llvm>) -> Value<'llvm> { 180 | debug_assert!(val.is_int(), "uitofp: Expected integer operand!"); 181 | 182 | let value_ref = unsafe { 183 | LLVMBuildUIToFP( 184 | self.builder, 185 | val.value_ref(), 186 | dest_type.type_ref(), 187 | b"uitofp\0".as_ptr().cast(), 188 | ) 189 | }; 190 | Value::new(value_ref) 191 | } 192 | 193 | /// Emit a [call](https://llvm.org/docs/LangRef.html#call-instruction) instruction. 194 | /// 195 | /// # Panics 196 | /// 197 | /// Panics if LLVM API returns a `null` pointer.
198 | pub fn call(&self, fn_value: FnValue<'llvm>, args: &mut [Value<'llvm>]) -> Value<'llvm> { 199 | let value_ref = unsafe { 200 | LLVMBuildCall2( 201 | self.builder, 202 | fn_value.fn_type(), 203 | fn_value, 204 | args.as_mut_ptr(), 205 | args.len() as libc::c_uint, 206 | b"call\0".as_ptr().cast(), 207 | ) 208 | }; 209 | Value::new(value_ref) 210 | } 211 | 212 | /// Emit a [ret](https://llvm.org/docs/LangRef.html#ret-instruction) instruction. 213 | /// 214 | /// # Panics 215 | /// 216 | /// Panics if LLVM API returns a `null` pointer. 217 | pub fn ret(&self, ret: Value<'llvm>) { 218 | let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) }; 219 | assert!(!ret.is_null()); 220 | } 221 | 222 | /// Emit an unconditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction. 223 | /// 224 | /// # Panics 225 | /// 226 | /// Panics if LLVM API returns a `null` pointer. 227 | pub fn br(&self, dest: BasicBlock<'llvm>) { 228 | let br_ref = unsafe { LLVMBuildBr(self.builder, dest.bb_ref()) }; 229 | assert!(!br_ref.is_null()); 230 | } 231 | 232 | /// Emit a conditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction. 233 | /// 234 | /// # Panics 235 | /// 236 | /// Panics if LLVM API returns a `null` pointer. 237 | pub fn cond_br(&self, cond: Value<'llvm>, then: BasicBlock<'llvm>, else_: BasicBlock<'llvm>) { 238 | let br_ref = unsafe { 239 | LLVMBuildCondBr( 240 | self.builder, 241 | cond.value_ref(), 242 | then.bb_ref(), 243 | else_.bb_ref(), 244 | ) 245 | }; 246 | assert!(!br_ref.is_null()); 247 | } 248 | 249 | /// Emit a [phi](https://llvm.org/docs/LangRef.html#phi-instruction) instruction. 250 | /// 251 | /// # Panics 252 | /// 253 | /// Panics if LLVM API returns a `null` pointer. 
254 | pub fn phi( 255 | &self, 256 | phi_type: Type<'llvm>, 257 | incoming: &[(Value<'llvm>, BasicBlock<'llvm>)], 258 | ) -> PhiValue<'llvm> { 259 | let phi_ref = 260 | unsafe { LLVMBuildPhi(self.builder, phi_type.type_ref(), b"phi\0".as_ptr().cast()) }; 261 | assert!(!phi_ref.is_null()); 262 | 263 | for (val, bb) in incoming { 264 | debug_assert_eq!( 265 | val.type_of().kind(), 266 | phi_type.kind(), 267 | "Type of incoming phi value must be the same as the type used to build the phi node." 268 | ); 269 | 270 | unsafe { 271 | LLVMAddIncoming(phi_ref, &mut val.value_ref() as _, &mut bb.bb_ref() as _, 1); 272 | } 273 | } 274 | 275 | PhiValue::new(phi_ref) 276 | } 277 | } 278 | 279 | impl Drop for IRBuilder<'_> { 280 | fn drop(&mut self) { 281 | unsafe { LLVMDisposeBuilder(self.builder) } 282 | } 283 | } 284 | -------------------------------------------------------------------------------- /src/llvm/lljit.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_sys::orc2::{ 6 | lljit::{ 7 | LLVMOrcCreateLLJIT, LLVMOrcLLJITAddLLVMIRModuleWithRT, LLVMOrcLLJITGetGlobalPrefix, 8 | LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup, LLVMOrcLLJITRef, 9 | }, 10 | LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess, LLVMOrcDefinitionGeneratorRef, 11 | LLVMOrcJITDylibAddGenerator, LLVMOrcJITDylibCreateResourceTracker, LLVMOrcJITDylibRef, 12 | LLVMOrcReleaseResourceTracker, LLVMOrcResourceTrackerRef, LLVMOrcResourceTrackerRemove, 13 | }; 14 | 15 | use std::convert::TryFrom; 16 | use std::marker::PhantomData; 17 | 18 | use super::{Error, Module}; 19 | use crate::SmallCStr; 20 | 21 | /// Marker trait to constrain function signatures that can be looked up in the JIT. 22 | pub trait JitFn {} 23 | 24 | impl JitFn for unsafe extern "C" fn() -> f64 {} 25 | 26 | /// Wrapper for a LLVM [LLJIT](https://www.llvm.org/docs/ORCv2.html#lljit-and-lllazyjit). 
27 | pub struct LLJit { 28 | jit: LLVMOrcLLJITRef, 29 | dylib: LLVMOrcJITDylibRef, 30 | } 31 | 32 | impl LLJit { 33 | /// Create a new LLJit instance. 34 | /// 35 | /// # Panics 36 | /// 37 | /// Panics if LLVM API returns a `null` pointer or an error. 38 | pub fn new() -> LLJit { 39 | let (jit, dylib) = unsafe { 40 | let mut jit = std::ptr::null_mut(); 41 | let err = LLVMOrcCreateLLJIT( 42 | &mut jit as _, 43 | std::ptr::null_mut(), /* builder: nullptr -> default */ 44 | ); 45 | 46 | if let Some(err) = Error::from(err) { 47 | panic!("Error: {}", err.as_str()); 48 | } 49 | 50 | let dylib = LLVMOrcLLJITGetMainJITDylib(jit); 51 | assert!(!dylib.is_null()); 52 | 53 | (jit, dylib) 54 | }; 55 | 56 | LLJit { jit, dylib } 57 | } 58 | 59 | /// Add an LLVM IR module to the JIT. Return a [`ResourceTracker`], which when dropped, will 60 | /// remove the code of the LLVM IR module from the JIT. 61 | /// 62 | /// # Panics 63 | /// 64 | /// Panics if LLVM API returns a `null` pointer or an error. 65 | pub fn add_module(&self, module: Module) -> ResourceTracker<'_> { 66 | let tsmod = module.into_raw_thread_safe_module(); 67 | 68 | let rt = unsafe { 69 | let rt = LLVMOrcJITDylibCreateResourceTracker(self.dylib); 70 | let err = LLVMOrcLLJITAddLLVMIRModuleWithRT(self.jit, rt, tsmod); 71 | 72 | if let Some(err) = Error::from(err) { 73 | panic!("Error: {}", err.as_str()); 74 | } 75 | 76 | rt 77 | }; 78 | 79 | ResourceTracker::new(rt) 80 | } 81 | 82 | /// Find the symbol with the name `sym` in the JIT. 83 | /// 84 | /// # Panics 85 | /// 86 | /// Panics if `sym` could not be converted to a [`SmallCStr`] or the symbol lookup in the JIT returns an error.
87 | pub fn find_symbol(&self, sym: &str) -> F { 88 | let sym = 89 | SmallCStr::try_from(sym).expect("Failed to convert 'sym' argument to small C string!"); 90 | 91 | unsafe { 92 | let mut addr = 0u64; 93 | let err = LLVMOrcLLJITLookup(self.jit, &mut addr as _, sym.as_ptr()); 94 | 95 | if let Some(err) = Error::from(err) { 96 | panic!("Error: {}", err.as_str()); 97 | } 98 | 99 | debug_assert_eq!(core::mem::size_of_val(&addr), core::mem::size_of::()); 100 | std::mem::transmute_copy(&addr) 101 | } 102 | } 103 | 104 | /// Enable lookup of dynamic symbols available in the current process from the JIT. 105 | /// 106 | /// # Panics 107 | /// 108 | /// Panics if LLVM API returns an error. 109 | pub fn enable_process_symbols(&self) { 110 | unsafe { 111 | let mut proc_syms_gen: LLVMOrcDefinitionGeneratorRef = std::ptr::null_mut(); 112 | let err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess( 113 | &mut proc_syms_gen as _, 114 | self.global_prefix(), 115 | None, /* filter */ 116 | std::ptr::null_mut(), /* filter ctx */ 117 | ); 118 | 119 | if let Some(err) = Error::from(err) { 120 | panic!("Error: {}", err.as_str()); 121 | } 122 | 123 | LLVMOrcJITDylibAddGenerator(self.dylib, proc_syms_gen); 124 | } 125 | } 126 | 127 | /// Return the global prefix character according to the LLJITs data layout. 128 | fn global_prefix(&self) -> libc::c_char { 129 | unsafe { LLVMOrcLLJITGetGlobalPrefix(self.jit) } 130 | } 131 | } 132 | 133 | /// A resource handle for code added to an [`LLJit`] instance. 134 | /// 135 | /// When a `ResourceTracker` handle is dropped, the code corresponding to the handle will be 136 | /// removed from the JIT. 
137 | pub struct ResourceTracker<'jit>(LLVMOrcResourceTrackerRef, PhantomData<&'jit ()>); 138 | 139 | impl<'jit> ResourceTracker<'jit> { 140 | fn new(rt: LLVMOrcResourceTrackerRef) -> ResourceTracker<'jit> { 141 | assert!(!rt.is_null()); 142 | ResourceTracker(rt, PhantomData) 143 | } 144 | } 145 | 146 | impl Drop for ResourceTracker<'_> { 147 | fn drop(&mut self) { 148 | unsafe { 149 | let err = LLVMOrcResourceTrackerRemove(self.0); 150 | 151 | if let Some(err) = Error::from(err) { 152 | panic!("Error: {}", err.as_str()); 153 | } 154 | 155 | LLVMOrcReleaseResourceTracker(self.0); 156 | }; 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /src/llvm/mod.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | //! Safe wrapper around the LLVM C API. 6 | //! 7 | //! References returned from the LLVM API are tied to the `'llvm` lifetime which is bound to the 8 | //! context in which the objects are created. 9 | //! We do not offer wrappers to remove or delete any objects in the context and therefore all the 10 | //! references will be valid for the lifetime of the context. 11 | //! 12 | //! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM 13 | //! API calls.
14 | 15 | use llvm_sys::{ 16 | core::LLVMShutdown, 17 | error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage}, 18 | target::{ 19 | LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter, 20 | LLVM_InitializeNativeTarget, 21 | }, 22 | }; 23 | 24 | use std::ffi::CStr; 25 | 26 | mod basic_block; 27 | mod builder; 28 | mod lljit; 29 | mod module; 30 | mod pass_manager; 31 | mod type_; 32 | mod value; 33 | 34 | pub use basic_block::BasicBlock; 35 | pub use builder::IRBuilder; 36 | pub use lljit::{LLJit, ResourceTracker}; 37 | pub use module::Module; 38 | pub use pass_manager::FunctionPassManager; 39 | pub use type_::Type; 40 | pub use value::{FnValue, PhiValue, Value}; 41 | 42 | struct Error<'llvm>(&'llvm mut libc::c_char); 43 | 44 | impl<'llvm> Error<'llvm> { 45 | fn from(err: LLVMErrorRef) -> Option> { 46 | (!err.is_null()).then(|| Error(unsafe { &mut *LLVMGetErrorMessage(err) })) 47 | } 48 | 49 | fn as_str(&self) -> &str { 50 | unsafe { CStr::from_ptr(self.0) } 51 | .to_str() 52 | .expect("Expected valid UTF8 string from LLVM API") 53 | } 54 | } 55 | 56 | impl Drop for Error<'_> { 57 | fn drop(&mut self) { 58 | unsafe { 59 | LLVMDisposeErrorMessage(self.0 as *mut libc::c_char); 60 | } 61 | } 62 | } 63 | 64 | /// Initialize the native target corresponding to the host (useful for jitting); panics if any of the LLVM initialization calls returns a non-zero value. 65 | pub fn initialize_native_taget() { 66 | unsafe { 67 | assert_eq!(LLVM_InitializeNativeTarget(), 0); 68 | assert_eq!(LLVM_InitializeNativeAsmParser(), 0); 69 | assert_eq!(LLVM_InitializeNativeAsmPrinter(), 0); 70 | } 71 | } 72 | 73 | /// Deallocate and destroy all "ManagedStatic" variables.
74 | pub fn shutdown() { 75 | unsafe { 76 | LLVMShutdown(); 77 | }; 78 | } 79 | -------------------------------------------------------------------------------- /src/llvm/module.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_sys::{ 6 | core::{ 7 | LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMCreateBasicBlockInContext, 8 | LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMGetNamedFunction, 9 | LLVMModuleCreateWithNameInContext, 10 | }, 11 | orc2::{ 12 | LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule, 13 | LLVMOrcDisposeThreadSafeContext, LLVMOrcThreadSafeContextGetContext, 14 | LLVMOrcThreadSafeContextRef, LLVMOrcThreadSafeModuleRef, 15 | }, 16 | prelude::{LLVMBool, LLVMContextRef, LLVMModuleRef, LLVMTypeRef}, 17 | LLVMTypeKind, 18 | }; 19 | 20 | use std::convert::TryFrom; 21 | 22 | use super::{BasicBlock, FnValue, Type}; 23 | use crate::SmallCStr; 24 | 25 | // Definition of LLVM C API functions using our `repr(transparent)` types. 26 | extern "C" { 27 | fn LLVMFunctionType( 28 | ReturnType: Type<'_>, 29 | ParamTypes: *mut Type<'_>, 30 | ParamCount: ::libc::c_uint, 31 | IsVarArg: LLVMBool, 32 | ) -> LLVMTypeRef; 33 | } 34 | 35 | /// Wrapper for a LLVM Module with its own LLVM Context. 36 | pub struct Module { 37 | tsctx: LLVMOrcThreadSafeContextRef, 38 | ctx: LLVMContextRef, 39 | module: LLVMModuleRef, 40 | } 41 | 42 | impl<'llvm> Module { 43 | /// Create a new Module instance. 44 | /// 45 | /// # Panics 46 | /// 47 | /// Panics if creating the context or the module fails. 48 | pub fn new() -> Self { 49 | let (tsctx, ctx, module) = unsafe { 50 | // We generate a thread safe context because we are going to jit this IR module and 51 | // there is no method to create a thread safe context wrapper from an existing context 52 | // reference (at the time of writing this). 
53 | // 54 | // ThreadSafeContext has shared ownership (start with ref count 1). 55 | // We must explicitly dispose our reference (dec ref count). 56 | let tc = LLVMOrcCreateNewThreadSafeContext(); 57 | assert!(!tc.is_null()); 58 | 59 | let c = LLVMOrcThreadSafeContextGetContext(tc); 60 | let m = LLVMModuleCreateWithNameInContext(b"module\0".as_ptr().cast(), c); 61 | assert!(!c.is_null() && !m.is_null()); 62 | (tc, c, m) 63 | }; 64 | 65 | Module { tsctx, ctx, module } 66 | } 67 | 68 | /// Get the raw LLVM context reference. 69 | #[inline] 70 | pub(super) fn ctx(&self) -> LLVMContextRef { 71 | self.ctx 72 | } 73 | 74 | /// Get the raw LLVM module reference. 75 | #[inline] 76 | pub(super) fn module(&self) -> LLVMModuleRef { 77 | self.module 78 | } 79 | 80 | /// Consume the module and turn it into a raw LLVM ThreadSafeModule reference. 81 | /// 82 | /// If ownership of the raw reference is not transferred (e.g. to the JIT), memory will be leaked 83 | /// unless the reference is disposed explicitly with LLVMOrcDisposeThreadSafeModule. 84 | #[inline] 85 | pub(super) fn into_raw_thread_safe_module(mut self) -> LLVMOrcThreadSafeModuleRef { 86 | let m = std::mem::replace(&mut self.module, std::ptr::null_mut()); 87 | 88 | // ThreadSafeModule has unique ownership. 89 | // Takes ownership of module and increments ThreadSafeContext ref count. 90 | // 91 | // We must not reference/dispose `m` after this call, but we need to dispose our `tsctx` 92 | // reference. 93 | let tm = unsafe { LLVMOrcCreateNewThreadSafeModule(m, self.tsctx) }; 94 | assert!(!tm.is_null()); 95 | 96 | tm 97 | } 98 | 99 | /// Dump LLVM IR emitted into the Module (NOTE(review): the LLVM-C docs describe LLVMDumpModule as writing to stderr, not stdout - confirm). 100 | pub fn dump(&self) { 101 | unsafe { LLVMDumpModule(self.module) }; 102 | } 103 | 104 | /// Get a type reference representing a `f64` float. 105 | /// 106 | /// # Panics 107 | /// 108 | /// Panics if LLVM API returns a `null` pointer.
109 | pub fn type_f64(&self) -> Type<'llvm> { 110 | let type_ref = unsafe { LLVMDoubleTypeInContext(self.ctx) }; 111 | Type::new(type_ref) 112 | } 113 | 114 | /// Get a type reference representing a `fn(args) -> ret` function. 115 | /// 116 | /// # Panics 117 | /// 118 | /// Panics if LLVM API returns a `null` pointer. 119 | pub fn type_fn(&'llvm self, args: &mut [Type<'llvm>], ret: Type<'llvm>) -> Type<'llvm> { 120 | let type_ref = unsafe { 121 | LLVMFunctionType( 122 | ret, 123 | args.as_mut_ptr(), 124 | args.len() as libc::c_uint, 125 | 0, /* IsVarArg */ 126 | ) 127 | }; 128 | Type::new(type_ref) 129 | } 130 | 131 | /// Add a function with the given `name` and `fn_type` to the module and return a value 132 | /// reference representing the function. 133 | /// 134 | /// # Panics 135 | /// 136 | /// Panics if LLVM API returns a `null` pointer or `name` could not be converted to a 137 | /// [`SmallCStr`]. 138 | pub fn add_fn(&'llvm self, name: &str, fn_type: Type<'llvm>) -> FnValue<'llvm> { 139 | debug_assert_eq!( 140 | fn_type.kind(), 141 | LLVMTypeKind::LLVMFunctionTypeKind, 142 | "Expected a function type when adding a function!" 143 | ); 144 | 145 | let name = SmallCStr::try_from(name) 146 | .expect("Failed to convert 'name' argument to small C string!"); 147 | 148 | let value_ref = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type.type_ref()) }; 149 | FnValue::new(value_ref) 150 | } 151 | 152 | /// Get a function value reference to the function with the given `name` if it was previously 153 | /// added to the module with [`add_fn`][Module::add_fn]. 154 | /// 155 | /// # Panics 156 | /// 157 | /// Panics if `name` could not be converted to a [`SmallCStr`]. 
158 | pub fn get_fn(&'llvm self, name: &str) -> Option> { 159 | let name = SmallCStr::try_from(name) 160 | .expect("Failed to convert 'name' argument to small C string!"); 161 | 162 | let value_ref = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; 163 | 164 | (!value_ref.is_null()).then(|| FnValue::new(value_ref)) 165 | } 166 | 167 | /// Append a Basic Block to the end of the function referenced by the value reference 168 | /// `fn_value`. 169 | /// 170 | /// # Panics 171 | /// 172 | /// Panics if LLVM API returns a `null` pointer. 173 | pub fn append_basic_block(&'llvm self, fn_value: FnValue<'llvm>) -> BasicBlock<'llvm> { 174 | let block = unsafe { 175 | LLVMAppendBasicBlockInContext( 176 | self.ctx, 177 | fn_value.value_ref(), 178 | b"block\0".as_ptr().cast(), 179 | ) 180 | }; 181 | assert!(!block.is_null()); 182 | 183 | BasicBlock::new(block) 184 | } 185 | 186 | /// Create a free-standing Basic Block without adding it to a function. 187 | /// This can be added to a function at a later point in time with 188 | /// [`FnValue::append_basic_block`]. 189 | /// 190 | /// # Panics 191 | /// 192 | /// Panics if LLVM API returns a `null` pointer. 193 | pub fn create_basic_block(&self) -> BasicBlock<'llvm> { 194 | let block = unsafe { LLVMCreateBasicBlockInContext(self.ctx, b"block\0".as_ptr().cast()) }; 195 | assert!(!block.is_null()); 196 | 197 | BasicBlock::new(block) 198 | } 199 | } 200 | 201 | impl Drop for Module { 202 | fn drop(&mut self) { 203 | unsafe { 204 | // In case we turned the module into a ThreadSafeModule, we must not dispose the module 205 | // reference because ThreadSafeModule took ownership! 206 | if !self.module.is_null() { 207 | LLVMDisposeModule(self.module); 208 | } 209 | 210 | // Dispose ThreadSafeContext reference (dec ref count) in any case. 
211 | LLVMOrcDisposeThreadSafeContext(self.tsctx); 212 | } 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /src/llvm/pass_manager.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_sys::{ 6 | core::{ 7 | LLVMCreateFunctionPassManagerForModule, LLVMDisposePassManager, 8 | LLVMInitializeFunctionPassManager, LLVMRunFunctionPassManager, 9 | }, 10 | prelude::LLVMPassManagerRef, 11 | transforms::{ 12 | instcombine::LLVMAddInstructionCombiningPass, 13 | scalar::{LLVMAddCFGSimplificationPass, LLVMAddNewGVNPass, LLVMAddReassociatePass}, 14 | }, 15 | }; 16 | 17 | use std::marker::PhantomData; 18 | 19 | use super::{FnValue, Module}; 20 | 21 | /// Wrapper for a LLVM Function PassManager (legacy). 22 | pub struct FunctionPassManager<'llvm> { 23 | fpm: LLVMPassManagerRef, 24 | _ctx: PhantomData<&'llvm ()>, 25 | } 26 | 27 | impl<'llvm> FunctionPassManager<'llvm> { 28 | /// Create a new Function PassManager with the following optimization passes 29 | /// - InstructionCombiningPass 30 | /// - ReassociatePass 31 | /// - NewGVNPass 32 | /// - CFGSimplificationPass 33 | /// 34 | /// The list of selected optimization passes is taken from the tutorial chapter [LLVM 35 | /// Optimization Passes](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl04.html#id3). 36 | pub fn with_ctx(module: &'llvm Module) -> FunctionPassManager<'llvm> { 37 | let fpm = unsafe { 38 | // Borrows module reference. 39 | LLVMCreateFunctionPassManagerForModule(module.module()) 40 | }; 41 | assert!(!fpm.is_null()); 42 | 43 | unsafe { 44 | // Do simple "peephole" optimizations and bit-twiddling optzns. 45 | LLVMAddInstructionCombiningPass(fpm); 46 | // Reassociate expressions. 47 | LLVMAddReassociatePass(fpm); 48 | // Eliminate Common SubExpressions. 
49 | LLVMAddNewGVNPass(fpm); 50 | // Simplify the control flow graph (deleting unreachable blocks, etc). 51 | LLVMAddCFGSimplificationPass(fpm); 52 | 53 | let fail = LLVMInitializeFunctionPassManager(fpm); 54 | assert_eq!(fail, 0); 55 | } 56 | 57 | FunctionPassManager { 58 | fpm, 59 | _ctx: PhantomData, 60 | } 61 | } 62 | 63 | /// Run the optimization passes registered with the Function PassManager on the function 64 | /// referenced by `fn_value`. 65 | pub fn run(&'llvm self, fn_value: FnValue<'llvm>) { 66 | unsafe { 67 | // Returns 1 if any of the passes modified the function, 0 otherwise; the result is ignored here. 68 | LLVMRunFunctionPassManager(self.fpm, fn_value.value_ref()); 69 | } 70 | } 71 | } 72 | 73 | impl Drop for FunctionPassManager<'_> { 74 | fn drop(&mut self) { 75 | unsafe { 76 | LLVMDisposePassManager(self.fpm); 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/llvm/type_.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_sys::{ 6 | core::{LLVMConstReal, LLVMDumpType, LLVMGetTypeKind}, 7 | prelude::LLVMTypeRef, 8 | LLVMTypeKind, 9 | }; 10 | 11 | use std::marker::PhantomData; 12 | 13 | use super::Value; 14 | 15 | /// Wrapper for an LLVM Type Reference. 16 | #[derive(Copy, Clone)] 17 | #[repr(transparent)] 18 | pub struct Type<'llvm>(LLVMTypeRef, PhantomData<&'llvm ()>); 19 | 20 | impl<'llvm> Type<'llvm> { 21 | /// Create a new Type instance. 22 | /// 23 | /// # Panics 24 | /// 25 | /// Panics if `type_ref` is a null pointer. 26 | pub(super) fn new(type_ref: LLVMTypeRef) -> Self { 27 | assert!(!type_ref.is_null()); 28 | Type(type_ref, PhantomData) 29 | } 30 | 31 | /// Get the raw LLVM type reference. 32 | #[inline] 33 | pub(super) fn type_ref(&self) -> LLVMTypeRef { 34 | self.0 35 | } 36 | 37 | /// Get the LLVM type kind for the given type reference.
38 | pub(super) fn kind(&self) -> LLVMTypeKind { 39 | unsafe { LLVMGetTypeKind(self.type_ref()) } 40 | } 41 | 42 | /// Dump the LLVM Type (NOTE(review): the LLVM-C docs describe LLVMDumpType as writing to stderr, not stdout - confirm). 43 | pub fn dump(&self) { 44 | unsafe { LLVMDumpType(self.type_ref()) }; 45 | } 46 | 47 | /// Get a value reference representing the constant `f64` value `n`. 48 | /// 49 | /// # Panics 50 | /// 51 | /// Panics if LLVM API returns a `null` pointer. 52 | pub fn const_f64(self, n: f64) -> Value<'llvm> { 53 | debug_assert_eq!( 54 | self.kind(), 55 | LLVMTypeKind::LLVMDoubleTypeKind, 56 | "Expected a double type when creating const f64 value!" 57 | ); 58 | 59 | let value_ref = unsafe { LLVMConstReal(self.type_ref(), n) }; 60 | Value::new(value_ref) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/llvm/value.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | #![allow(unused)] 6 | 7 | use llvm_sys::{ 8 | analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction}, 9 | core::{ 10 | LLVMAddIncoming, LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams, 11 | LLVMDumpValue, LLVMGetParam, LLVMGetValueKind, LLVMGetValueName2, LLVMGlobalGetValueType, 12 | LLVMIsAFunction, LLVMIsAPHINode, LLVMSetValueName2, LLVMTypeOf, 13 | }, 14 | prelude::LLVMValueRef, 15 | LLVMTypeKind, LLVMValueKind, 16 | }; 17 | use std::ffi::CStr; 18 | use std::marker::PhantomData; 19 | use std::ops::Deref; 20 | 21 | use super::BasicBlock; 22 | use super::Type; 23 | 24 | /// Wrapper for a LLVM Value Reference. 25 | #[derive(Copy, Clone)] 26 | #[repr(transparent)] 27 | pub struct Value<'llvm>(LLVMValueRef, PhantomData<&'llvm ()>); 28 | 29 | impl<'llvm> Value<'llvm> { 30 | /// Create a new Value instance. 31 | /// 32 | /// # Panics 33 | /// 34 | /// Panics if `value_ref` is a null pointer.
35 | pub(super) fn new(value_ref: LLVMValueRef) -> Self { 36 | assert!(!value_ref.is_null()); 37 | Value(value_ref, PhantomData) 38 | } 39 | 40 | /// Get the raw LLVM value reference. 41 | #[inline] 42 | pub(super) fn value_ref(&self) -> LLVMValueRef { 43 | self.0 44 | } 45 | 46 | /// Get the LLVM value kind for the given value reference. 47 | pub(super) fn kind(&self) -> LLVMValueKind { 48 | unsafe { LLVMGetValueKind(self.value_ref()) } 49 | } 50 | 51 | /// Check if value is `function` type. 52 | pub(super) fn is_function(&self) -> bool { 53 | let cast = unsafe { LLVMIsAFunction(self.value_ref()) }; 54 | !cast.is_null() 55 | } 56 | 57 | /// Check if value is `phinode` type. 58 | pub(super) fn is_phinode(&self) -> bool { 59 | let cast = unsafe { LLVMIsAPHINode(self.value_ref()) }; 60 | !cast.is_null() 61 | } 62 | 63 | /// Dump the LLVM Value (NOTE(review): the LLVM-C docs describe LLVMDumpValue as writing to stderr, not stdout - confirm). 64 | pub fn dump(&self) { 65 | unsafe { LLVMDumpValue(self.value_ref()) }; 66 | } 67 | 68 | /// Get a type reference representing the type of the given value reference. 69 | /// 70 | /// # Panics 71 | /// 72 | /// Panics if LLVM API returns a `null` pointer. 73 | pub fn type_of(&self) -> Type<'llvm> { 74 | let type_ref = unsafe { LLVMTypeOf(self.value_ref()) }; 75 | Type::new(type_ref) 76 | } 77 | 78 | /// Set the name for the given value reference. 79 | /// 80 | /// # Notes 81 | /// 82 | /// The name is passed as a pointer plus its byte length (`name.len()`); no result is checked here. 83 | pub fn set_name(&self, name: &str) { 84 | unsafe { LLVMSetValueName2(self.value_ref(), name.as_ptr().cast(), name.len()) }; 85 | } 86 | 87 | /// Get the name for the given value reference. 88 | /// 89 | /// # Panics 90 | /// 91 | /// Panics if LLVM API returns a `null` pointer or the name is not valid UTF-8. 92 | pub fn get_name(&self) -> &'llvm str { 93 | let name = unsafe { 94 | let mut len: libc::size_t = 0; 95 | let name = LLVMGetValueName2(self.0, &mut len as _); 96 | assert!(!name.is_null()); 97 | 98 | CStr::from_ptr(name) 99 | }; 100 | 101 | // TODO: Does this string live for the time of the LLVM context?!
102 | name.to_str() 103 | .expect("Expected valid UTF8 string from LLVM API") 104 | } 105 | 106 | /// Check if value is of `f64` type. 107 | pub fn is_f64(&self) -> bool { 108 | self.type_of().kind() == LLVMTypeKind::LLVMDoubleTypeKind 109 | } 110 | 111 | /// Check if value is of integer type. 112 | pub fn is_int(&self) -> bool { 113 | self.type_of().kind() == LLVMTypeKind::LLVMIntegerTypeKind 114 | } 115 | } 116 | 117 | /// Wrapper for a LLVM Value Reference specialized for contexts where function values are needed. 118 | #[derive(Copy, Clone)] 119 | #[repr(transparent)] 120 | pub struct FnValue<'llvm>(Value<'llvm>); 121 | 122 | impl<'llvm> Deref for FnValue<'llvm> { 123 | type Target = Value<'llvm>; 124 | fn deref(&self) -> &Self::Target { 125 | &self.0 126 | } 127 | } 128 | 129 | impl<'llvm> FnValue<'llvm> { 130 | /// Create a new FnValue instance. 131 | /// 132 | /// # Panics 133 | /// 134 | /// Panics if `value_ref` is a null pointer. 135 | pub(super) fn new(value_ref: LLVMValueRef) -> Self { 136 | let value = Value::new(value_ref); 137 | debug_assert!( 138 | value.is_function(), 139 | "Expected a fn value when constructing FnValue!" 140 | ); 141 | 142 | FnValue(value) 143 | } 144 | 145 | /// Get a type reference representing the function type (return + args) of the given function 146 | /// value. 147 | /// 148 | /// # Panics 149 | /// 150 | /// Panics if LLVM API returns a `null` pointer. 151 | pub fn fn_type(&self) -> Type<'llvm> { 152 | // https://github.com/llvm/llvm-project/issues/72798 153 | let type_ref = unsafe { LLVMGlobalGetValueType(self.value_ref()) }; 154 | Type::new(type_ref) 155 | } 156 | 157 | /// Get the number of function arguments for the given function value. 158 | pub fn args(&self) -> usize { 159 | unsafe { LLVMCountParams(self.value_ref()) as usize } 160 | } 161 | 162 | /// Get a value reference for the function argument at index `idx`. 
163 | /// 164 | /// # Panics 165 | /// 166 | /// Panics if LLVM API returns a `null` pointer or indexed out of bounds. 167 | pub fn arg(&self, idx: usize) -> Value<'llvm> { 168 | assert!(idx < self.args()); 169 | 170 | let value_ref = unsafe { LLVMGetParam(self.value_ref(), idx as libc::c_uint) }; 171 | Value::new(value_ref) 172 | } 173 | 174 | /// Get the number of Basic Blocks for the given function value. 175 | pub fn basic_blocks(&self) -> usize { 176 | unsafe { LLVMCountBasicBlocks(self.value_ref()) as usize } 177 | } 178 | 179 | /// Append a Basic Block to the end of the function value. 180 | pub fn append_basic_block(&self, bb: BasicBlock<'llvm>) { 181 | unsafe { 182 | LLVMAppendExistingBasicBlock(self.value_ref(), bb.bb_ref()); 183 | } 184 | } 185 | 186 | /// Verify that the given function is valid. 187 | pub fn verify(&self) -> bool { 188 | unsafe { 189 | LLVMVerifyFunction( 190 | self.value_ref(), 191 | LLVMVerifierFailureAction::LLVMPrintMessageAction, 192 | ) == 0 193 | } 194 | } 195 | } 196 | 197 | /// Wrapper for a LLVM Value Reference specialized for contexts where phi values are needed. 198 | #[derive(Copy, Clone)] 199 | #[repr(transparent)] 200 | pub struct PhiValue<'llvm>(Value<'llvm>); 201 | 202 | impl<'llvm> Deref for PhiValue<'llvm> { 203 | type Target = Value<'llvm>; 204 | fn deref(&self) -> &Self::Target { 205 | &self.0 206 | } 207 | } 208 | 209 | impl<'llvm> PhiValue<'llvm> { 210 | /// Create a new PhiValue instance. 211 | /// 212 | /// # Panics 213 | /// 214 | /// Panics if `value_ref` is a null pointer. 215 | pub(super) fn new(value_ref: LLVMValueRef) -> Self { 216 | let value = Value::new(value_ref); 217 | debug_assert!( 218 | value.is_phinode(), 219 | "Expected a phinode value when constructing PhiValue!" 220 | ); 221 | 222 | PhiValue(value) 223 | } 224 | 225 | /// Add an incoming value to the end of a PHI list. 
226 | pub fn add_incoming(&self, ival: Value<'llvm>, ibb: BasicBlock<'llvm>) { 227 | debug_assert_eq!( 228 | ival.type_of().kind(), 229 | self.type_of().kind(), 230 | "Type of incoming phi value must be the same as the type used to build the phi node." 231 | ); 232 | 233 | unsafe { 234 | LLVMAddIncoming( 235 | self.value_ref(), 236 | &mut ival.value_ref() as _, 237 | &mut ibb.bb_ref() as _, 238 | 1, 239 | ); 240 | } 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use llvm_kaleidoscope_rs::{ 6 | codegen::Codegen, 7 | lexer::{Lexer, Token}, 8 | llvm, 9 | parser::{Parser, PrototypeAST}, 10 | Either, 11 | }; 12 | 13 | use std::collections::HashMap; 14 | use std::io::{Read, Write}; 15 | 16 | #[no_mangle] 17 | #[inline(never)] 18 | pub extern "C" fn putchard(c: libc::c_double) -> f64 { 19 | std::io::stdout() 20 | .write(&[c as u8]) 21 | .expect("Failed to write to stdout!"); 22 | 0f64 23 | } 24 | 25 | fn main_loop(mut parser: Parser) 26 | where 27 | I: Iterator, 28 | { 29 | // Initialize LLVM module with its own context. 30 | // We will emit LLVM IR into this module. 31 | let mut module = llvm::Module::new(); 32 | 33 | // Create a new JIT, based on the LLVM LLJIT. 34 | let jit = llvm::LLJit::new(); 35 | 36 | // Enable lookup of dynamic symbols in the current process from the JIT. 37 | jit.enable_process_symbols(); 38 | 39 | // Keep track of prototype names to their respective ASTs. 40 | // 41 | // This is useful since we jit every function definition into its own LLVM module. 42 | // To allow calling functions defined in previous LLVM modules we keep track of their 43 | // prototypes and generate IR for their declarations when they are called from another module. 
44 | let mut fn_protos: HashMap = HashMap::new(); 45 | 46 | // When adding an IR module to the JIT, it will hand out a ResourceTracker. When the 47 | // ResourceTracker is dropped, the code generated from the corresponding module will be removed 48 | // from the JIT. 49 | // 50 | // For each function we want to keep the code generated for the last definition, hence we need 51 | // to keep their ResourceTracker alive. 52 | let mut fn_jit_rt: HashMap = HashMap::new(); 53 | 54 | loop { 55 | match parser.cur_tok() { 56 | Token::Eof => break, 57 | Token::Char(';') => { 58 | // Ignore top-level semicolon. 59 | parser.get_next_token(); 60 | } 61 | Token::Def => match parser.parse_definition() { 62 | Ok(func) => { 63 | println!("Parse 'def'"); 64 | let func_name = &func.0 .0; 65 | 66 | // If we already jitted that function, remove the last definition from the JIT 67 | // by dropping the corresponding ResourceTracker. 68 | fn_jit_rt.remove(func_name); 69 | 70 | if let Ok(func_ir) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) 71 | { 72 | func_ir.dump(); 73 | 74 | // Add module to the JIT. 75 | let rt = jit.add_module(module); 76 | 77 | // Keep track of the ResourceTracker to keep the module code in the JIT. 78 | fn_jit_rt.insert(func_name.to_string(), rt); 79 | 80 | // Initialize a new module. 81 | module = llvm::Module::new(); 82 | } 83 | } 84 | Err(err) => { 85 | eprintln!("Error: {:?}", err); 86 | parser.get_next_token(); 87 | } 88 | }, 89 | Token::Extern => match parser.parse_extern() { 90 | Ok(proto) => { 91 | println!("Parse 'extern'"); 92 | if let Ok(proto_ir) = 93 | Codegen::compile(&module, &mut fn_protos, Either::A(&proto)) 94 | { 95 | proto_ir.dump(); 96 | 97 | // Keep track of external function declaration. 
98 | fn_protos.insert(proto.0.clone(), proto); 99 | } 100 | } 101 | Err(err) => { 102 | eprintln!("Error: {:?}", err); 103 | parser.get_next_token(); 104 | } 105 | }, 106 | _ => match parser.parse_top_level_expr() { 107 | Ok(func) => { 108 | println!("Parse top-level expression"); 109 | if let Ok(func) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) { 110 | func.dump(); 111 | 112 | // Add module to the JIT. Code will be removed when `_rt` is dropped. 113 | let _rt = jit.add_module(module); 114 | 115 | // Initialize a new module. 116 | module = llvm::Module::new(); 117 | 118 | // Call the top level expression. 119 | let fp = jit.find_symbol:: f64>("__anon_expr"); 120 | unsafe { 121 | println!("Evaluated to {}", fp()); 122 | } 123 | } 124 | } 125 | Err(err) => { 126 | eprintln!("Error: {:?}", err); 127 | parser.get_next_token(); 128 | } 129 | }, 130 | }; 131 | } 132 | 133 | // Dump all the emitted LLVM IR to stdout. 134 | module.dump(); 135 | } 136 | 137 | fn run_kaleidoscope(lexer: Lexer) 138 | where 139 | I: Iterator, 140 | { 141 | // Create parser for kaleidoscope. 142 | let mut parser = Parser::new(lexer); 143 | 144 | // Throw first coin and initialize cur_tok. 145 | parser.get_next_token(); 146 | 147 | // Initialize native target for jitting. 148 | llvm::initialize_native_taget(); 149 | 150 | main_loop(parser); 151 | 152 | // De-allocate managed static LLVM data. 153 | llvm::shutdown(); 154 | } 155 | 156 | fn main() { 157 | match std::env::args().nth(1) { 158 | Some(file) => { 159 | println!("Parse {}.", file); 160 | 161 | // Create lexer over file. 
162 | let lexer = Lexer::new( 163 | std::fs::File::open(&file) 164 | .expect(&format!("Failed to open file {}!", file)) 165 | .bytes() 166 | .filter_map(|v| { 167 | let v = v.ok()?; 168 | Some(v.into()) 169 | }), 170 | ); 171 | run_kaleidoscope(lexer); 172 | } 173 | None => { 174 | println!("Parse stdin."); 175 | println!("ENTER to parse current input."); 176 | println!("C-d to exit."); 177 | 178 | // Create lexer over stdin. 179 | let lexer = Lexer::new(std::io::stdin().bytes().filter_map(|v| { 180 | let v = v.ok()?; 181 | Some(v.into()) 182 | })); 183 | run_kaleidoscope(lexer); 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/parser.rs: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 3 | // Copyright (c) 2021, Johannes Stoelp 4 | 5 | use crate::lexer::{Lexer, Token}; 6 | 7 | #[derive(Debug, PartialEq)] 8 | pub enum ExprAST { 9 | /// Number - Expression class for numeric literals like "1.0". 10 | Number(f64), 11 | 12 | /// Variable - Expression class for referencing a variable, like "a". 13 | Variable(String), 14 | 15 | /// Binary - Expression class for a binary operator. 16 | Binary(char, Box, Box), 17 | 18 | /// Call - Expression class for function calls. 19 | Call(String, Vec), 20 | 21 | /// If - Expression class for if/then/else. 22 | If { 23 | cond: Box, 24 | then: Box, 25 | else_: Box, 26 | }, 27 | 28 | /// ForExprAST - Expression class for for/in. 29 | For { 30 | var: String, 31 | start: Box, 32 | end: Box, 33 | step: Option>, 34 | body: Box, 35 | }, 36 | } 37 | 38 | /// PrototypeAST - This class represents the "prototype" for a function, 39 | /// which captures its name, and its argument names (thus implicitly the number 40 | /// of arguments the function takes). 
41 | #[derive(Debug, PartialEq, Clone)] 42 | pub struct PrototypeAST(pub String, pub Vec); 43 | 44 | /// FunctionAST - This class represents a function definition itself. 45 | #[derive(Debug, PartialEq)] 46 | pub struct FunctionAST(pub PrototypeAST, pub ExprAST); 47 | 48 | /// Parse result with String as Error type (to be compliant with tutorial). 49 | type ParseResult = Result; 50 | 51 | /// Parser for the `kaleidoscope` language. 52 | pub struct Parser 53 | where 54 | I: Iterator, 55 | { 56 | lexer: Lexer, 57 | cur_tok: Option, 58 | } 59 | 60 | impl Parser 61 | where 62 | I: Iterator, 63 | { 64 | pub fn new(lexer: Lexer) -> Self { 65 | Parser { 66 | lexer, 67 | cur_tok: None, 68 | } 69 | } 70 | 71 | // ----------------------- 72 | // Simple Token Buffer 73 | // ----------------------- 74 | 75 | /// Implement the global variable `int CurTok;` from the tutorial. 76 | /// 77 | /// # Panics 78 | /// Panics if the parser doesn't have a current token. 79 | pub fn cur_tok(&self) -> &Token { 80 | self.cur_tok.as_ref().expect("Parser: Expected cur_token!") 81 | } 82 | 83 | /// Advance the `cur_tok` by getting the next token from the lexer. 84 | /// 85 | /// Implement the fucntion `int getNextToken();` from the tutorial. 86 | pub fn get_next_token(&mut self) { 87 | self.cur_tok = Some(self.lexer.gettok()); 88 | } 89 | 90 | // ---------------------------- 91 | // Basic Expression Parsing 92 | // ---------------------------- 93 | 94 | /// numberexpr ::= number 95 | /// 96 | /// Implement `std::unique_ptr ParseNumberExpr();` from the tutorial. 97 | fn parse_num_expr(&mut self) -> ParseResult { 98 | match *self.cur_tok() { 99 | Token::Number(num) => { 100 | // Consume the number token. 101 | self.get_next_token(); 102 | Ok(ExprAST::Number(num)) 103 | } 104 | _ => unreachable!(), 105 | } 106 | } 107 | 108 | /// parenexpr ::= '(' expression ')' 109 | /// 110 | /// Implement `std::unique_ptr ParseParenExpr();` from the tutorial. 
111 | fn parse_paren_expr(&mut self) -> ParseResult { 112 | // Eat '(' token. 113 | assert_eq!(*self.cur_tok(), Token::Char('(')); 114 | self.get_next_token(); 115 | 116 | let v = self.parse_expression()?; 117 | 118 | if *self.cur_tok() == Token::Char(')') { 119 | // Eat ')' token. 120 | self.get_next_token(); 121 | Ok(v) 122 | } else { 123 | Err("expected ')'".into()) 124 | } 125 | } 126 | 127 | /// identifierexpr 128 | /// ::= identifier 129 | /// ::= identifier '(' expression* ')' 130 | /// 131 | /// Implement `std::unique_ptr ParseIdentifierExpr();` from the tutorial. 132 | fn parse_identifier_expr(&mut self) -> ParseResult { 133 | let id_name = match self.cur_tok.take() { 134 | Some(Token::Identifier(id)) => { 135 | // Consume identifier. 136 | self.get_next_token(); 137 | id 138 | } 139 | _ => unreachable!(), 140 | }; 141 | 142 | if *self.cur_tok() != Token::Char('(') { 143 | // Simple variable reference. 144 | Ok(ExprAST::Variable(id_name)) 145 | } else { 146 | // Call. 147 | 148 | // Eat '(' token. 149 | self.get_next_token(); 150 | 151 | let mut args: Vec = Vec::new(); 152 | 153 | // If there are arguments collect them. 154 | if *self.cur_tok() != Token::Char(')') { 155 | loop { 156 | let arg = self.parse_expression()?; 157 | args.push(arg); 158 | 159 | if *self.cur_tok() == Token::Char(')') { 160 | break; 161 | } 162 | 163 | if *self.cur_tok() != Token::Char(',') { 164 | return Err("Expected ')' or ',' in argument list".into()); 165 | } 166 | 167 | self.get_next_token(); 168 | } 169 | } 170 | 171 | assert_eq!(*self.cur_tok(), Token::Char(')')); 172 | // Eat ')' token. 173 | self.get_next_token(); 174 | 175 | Ok(ExprAST::Call(id_name, args)) 176 | } 177 | } 178 | 179 | /// ifexpr ::= 'if' expression 'then' expression 'else' expression 180 | /// 181 | /// Implement `std::unique_ptr ParseIfExpr();` from the tutorial. 182 | fn parse_if_expr(&mut self) -> ParseResult { 183 | // Consume 'if' token. 
184 | assert_eq!(*self.cur_tok(), Token::If); 185 | self.get_next_token(); 186 | 187 | let cond = self.parse_expression()?; 188 | 189 | if *dbg!(self.cur_tok()) != Token::Then { 190 | return Err("Expected 'then'".into()); 191 | } 192 | // Consume 'then' token. 193 | self.get_next_token(); 194 | 195 | let then = self.parse_expression()?; 196 | 197 | if *self.cur_tok() != Token::Else { 198 | return Err("Expected 'else'".into()); 199 | } 200 | // Consume 'else' token. 201 | self.get_next_token(); 202 | 203 | let else_ = self.parse_expression()?; 204 | 205 | Ok(ExprAST::If { 206 | cond: Box::new(cond), 207 | then: Box::new(then), 208 | else_: Box::new(else_), 209 | }) 210 | } 211 | 212 | /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression 213 | /// 214 | /// Implement `std::unique_ptr ParseForExpr();` from the tutorial. 215 | fn parse_for_expr(&mut self) -> ParseResult { 216 | // Consume the 'for' token. 217 | assert_eq!(*self.cur_tok(), Token::For); 218 | self.get_next_token(); 219 | 220 | let var = match self 221 | .parse_identifier_expr() 222 | .map_err(|_| String::from("expected identifier after 'for'"))? 223 | { 224 | ExprAST::Variable(var) => var, 225 | _ => unreachable!(), 226 | }; 227 | 228 | // Consume the '=' token. 229 | if *self.cur_tok() != Token::Char('=') { 230 | return Err("expected '=' after for".into()); 231 | } 232 | self.get_next_token(); 233 | 234 | let start = self.parse_expression()?; 235 | 236 | // Consume the ',' token. 237 | if *self.cur_tok() != Token::Char(',') { 238 | return Err("expected ',' after for start value".into()); 239 | } 240 | self.get_next_token(); 241 | 242 | let end = self.parse_expression()?; 243 | 244 | let step = if *self.cur_tok() == Token::Char(',') { 245 | // Consume the ',' token. 246 | self.get_next_token(); 247 | 248 | Some(self.parse_expression()?) 249 | } else { 250 | None 251 | }; 252 | 253 | // Consume the 'in' token. 
254 | if *self.cur_tok() != Token::In { 255 | return Err("expected 'in' after for".into()); 256 | } 257 | self.get_next_token(); 258 | 259 | let body = self.parse_expression()?; 260 | 261 | Ok(ExprAST::For { 262 | var, 263 | start: Box::new(start), 264 | end: Box::new(end), 265 | step: step.map(|s| Box::new(s)), 266 | body: Box::new(body), 267 | }) 268 | } 269 | 270 | /// primary 271 | /// ::= identifierexpr 272 | /// ::= numberexpr 273 | /// ::= parenexpr 274 | /// 275 | /// Implement `std::unique_ptr ParsePrimary();` from the tutorial. 276 | fn parse_primary(&mut self) -> ParseResult { 277 | match *self.cur_tok() { 278 | Token::Identifier(_) => self.parse_identifier_expr(), 279 | Token::Number(_) => self.parse_num_expr(), 280 | Token::Char('(') => self.parse_paren_expr(), 281 | Token::If => self.parse_if_expr(), 282 | Token::For => self.parse_for_expr(), 283 | _ => Err("unknown token when expecting an expression".into()), 284 | } 285 | } 286 | 287 | // ----------------------------- 288 | // Binary Expression Parsing 289 | // ----------------------------- 290 | 291 | /// /// expression 292 | /// ::= primary binoprhs 293 | /// 294 | /// Implement `std::unique_ptr ParseExpression();` from the tutorial. 295 | fn parse_expression(&mut self) -> ParseResult { 296 | let lhs = self.parse_primary()?; 297 | self.parse_bin_op_rhs(0, lhs) 298 | } 299 | 300 | /// binoprhs 301 | /// ::= ('+' primary)* 302 | /// 303 | /// Implement `std::unique_ptr ParseBinOpRHS(int ExprPrec, std::unique_ptr LHS);` from the tutorial. 304 | fn parse_bin_op_rhs(&mut self, expr_prec: isize, mut lhs: ExprAST) -> ParseResult { 305 | loop { 306 | let tok_prec = get_tok_precedence(self.cur_tok()); 307 | 308 | // Not a binary operator or precedence is too small. 309 | if tok_prec < expr_prec { 310 | return Ok(lhs); 311 | } 312 | 313 | let binop = match self.cur_tok.take() { 314 | Some(Token::Char(c)) => { 315 | // Eat binary operator. 
316 | self.get_next_token(); 317 | c 318 | } 319 | _ => unreachable!(), 320 | }; 321 | 322 | // lhs BINOP1 rhs BINOP2 remrhs 323 | // ^^^^^^ ^^^^^^ 324 | // tok_prec next_prec 325 | // 326 | // In case BINOP1 has higher precedence, we are done here and can build a 'Binary' AST 327 | // node between 'lhs' and 'rhs'. 328 | // 329 | // In case BINOP2 has higher precedence, we take 'rhs' as 'lhs' and recurse into the 330 | // 'remrhs' expression first. 331 | 332 | // Parse primary expression after binary operator. 333 | let mut rhs = self.parse_primary()?; 334 | 335 | let next_prec = get_tok_precedence(self.cur_tok()); 336 | if tok_prec < next_prec { 337 | // BINOP2 has higher precedence thatn BINOP1, recurse into 'remhs'. 338 | rhs = self.parse_bin_op_rhs(tok_prec + 1, rhs)? 339 | } 340 | 341 | lhs = ExprAST::Binary(binop, Box::new(lhs), Box::new(rhs)); 342 | } 343 | } 344 | 345 | // -------------------- 346 | // Parsing the Rest 347 | // -------------------- 348 | 349 | /// prototype 350 | /// ::= id '(' id* ')' 351 | /// 352 | /// Implement `std::unique_ptr ParsePrototype();` from the tutorial. 353 | fn parse_prototype(&mut self) -> ParseResult { 354 | let id_name = match self.cur_tok.take() { 355 | Some(Token::Identifier(id)) => { 356 | // Consume the identifier. 357 | self.get_next_token(); 358 | id 359 | } 360 | other => { 361 | // Plug back current token. 
362 | self.cur_tok = other; 363 | return Err("Expected function name in prototype".into()); 364 | } 365 | }; 366 | 367 | if *self.cur_tok() != Token::Char('(') { 368 | return Err("Expected '(' in prototype".into()); 369 | } 370 | 371 | let mut args: Vec = Vec::new(); 372 | loop { 373 | self.get_next_token(); 374 | 375 | match self.cur_tok.take() { 376 | Some(Token::Identifier(arg)) => args.push(arg), 377 | Some(Token::Char(',')) => {} 378 | other => { 379 | self.cur_tok = other; 380 | break; 381 | } 382 | } 383 | } 384 | 385 | if *self.cur_tok() != Token::Char(')') { 386 | return Err("Expected ')' in prototype".into()); 387 | } 388 | 389 | // Consume ')'. 390 | self.get_next_token(); 391 | 392 | Ok(PrototypeAST(id_name, args)) 393 | } 394 | 395 | /// definition ::= 'def' prototype expression 396 | /// 397 | /// Implement `std::unique_ptr ParseDefinition();` from the tutorial. 398 | pub fn parse_definition(&mut self) -> ParseResult { 399 | // Consume 'def' token. 400 | assert_eq!(*self.cur_tok(), Token::Def); 401 | self.get_next_token(); 402 | 403 | let proto = self.parse_prototype()?; 404 | let expr = self.parse_expression()?; 405 | 406 | Ok(FunctionAST(proto, expr)) 407 | } 408 | 409 | /// external ::= 'extern' prototype 410 | /// 411 | /// Implement `std::unique_ptr ParseExtern();` from the tutorial. 412 | pub fn parse_extern(&mut self) -> ParseResult { 413 | // Consume 'extern' token. 414 | assert_eq!(*self.cur_tok(), Token::Extern); 415 | self.get_next_token(); 416 | 417 | self.parse_prototype() 418 | } 419 | 420 | /// toplevelexpr ::= expression 421 | /// 422 | /// Implement `std::unique_ptr ParseTopLevelExpr();` from the tutorial. 423 | pub fn parse_top_level_expr(&mut self) -> ParseResult { 424 | let e = self.parse_expression()?; 425 | let proto = PrototypeAST("__anon_expr".into(), Vec::new()); 426 | Ok(FunctionAST(proto, e)) 427 | } 428 | } 429 | 430 | /// Get the binary operator precedence. 
431 | /// 432 | /// Implement `int GetTokPrecedence();` from the tutorial. 433 | fn get_tok_precedence(tok: &Token) -> isize { 434 | match tok { 435 | Token::Char('<') => 10, 436 | Token::Char('+') => 20, 437 | Token::Char('-') => 20, 438 | Token::Char('*') => 40, 439 | _ => -1, 440 | } 441 | } 442 | 443 | #[cfg(test)] 444 | mod test { 445 | use super::{ExprAST, FunctionAST, Parser, PrototypeAST}; 446 | use crate::lexer::Lexer; 447 | 448 | fn parser(input: &str) -> Parser { 449 | let l = Lexer::new(input.chars()); 450 | let mut p = Parser::new(l); 451 | 452 | // Drop initial coin, initialize cur_tok. 453 | p.get_next_token(); 454 | 455 | p 456 | } 457 | 458 | #[test] 459 | fn parse_number() { 460 | let mut p = parser("13.37"); 461 | 462 | assert_eq!(p.parse_num_expr(), Ok(ExprAST::Number(13.37f64))); 463 | } 464 | 465 | #[test] 466 | fn parse_variable() { 467 | let mut p = parser("foop"); 468 | 469 | assert_eq!( 470 | p.parse_identifier_expr(), 471 | Ok(ExprAST::Variable("foop".into())) 472 | ); 473 | } 474 | 475 | #[test] 476 | fn parse_if() { 477 | let mut p = parser("if 1 then 2 else 3"); 478 | 479 | let cond = Box::new(ExprAST::Number(1f64)); 480 | let then = Box::new(ExprAST::Number(2f64)); 481 | let else_ = Box::new(ExprAST::Number(3f64)); 482 | 483 | assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ })); 484 | 485 | let mut p = parser("if foo() then bar(2) else baz(3)"); 486 | 487 | let cond = Box::new(ExprAST::Call("foo".into(), vec![])); 488 | let then = Box::new(ExprAST::Call("bar".into(), vec![ExprAST::Number(2f64)])); 489 | let else_ = Box::new(ExprAST::Call("baz".into(), vec![ExprAST::Number(3f64)])); 490 | 491 | assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ })); 492 | } 493 | 494 | #[test] 495 | fn parse_for() { 496 | let mut p = parser("for i = 1, 2, 3 in 4"); 497 | 498 | let var = String::from("i"); 499 | let start = Box::new(ExprAST::Number(1f64)); 500 | let end = Box::new(ExprAST::Number(2f64)); 501 | let step = 
Some(Box::new(ExprAST::Number(3f64))); 502 | let body = Box::new(ExprAST::Number(4f64)); 503 | 504 | assert_eq!( 505 | p.parse_for_expr(), 506 | Ok(ExprAST::For { 507 | var, 508 | start, 509 | end, 510 | step, 511 | body 512 | }) 513 | ); 514 | } 515 | 516 | #[test] 517 | fn parse_for_no_step() { 518 | let mut p = parser("for i = 1, 2 in 4"); 519 | 520 | let var = String::from("i"); 521 | let start = Box::new(ExprAST::Number(1f64)); 522 | let end = Box::new(ExprAST::Number(2f64)); 523 | let step = None; 524 | let body = Box::new(ExprAST::Number(4f64)); 525 | 526 | assert_eq!( 527 | p.parse_for_expr(), 528 | Ok(ExprAST::For { 529 | var, 530 | start, 531 | end, 532 | step, 533 | body 534 | }) 535 | ); 536 | } 537 | 538 | #[test] 539 | fn parse_primary() { 540 | let mut p = parser("1337 foop \n bla(123) \n if a then b else c \n for x=1,2 in 3"); 541 | 542 | assert_eq!(p.parse_primary(), Ok(ExprAST::Number(1337f64))); 543 | 544 | assert_eq!(p.parse_primary(), Ok(ExprAST::Variable("foop".into()))); 545 | 546 | assert_eq!( 547 | p.parse_primary(), 548 | Ok(ExprAST::Call("bla".into(), vec![ExprAST::Number(123f64)])) 549 | ); 550 | 551 | assert_eq!( 552 | p.parse_primary(), 553 | Ok(ExprAST::If { 554 | cond: Box::new(ExprAST::Variable("a".into())), 555 | then: Box::new(ExprAST::Variable("b".into())), 556 | else_: Box::new(ExprAST::Variable("c".into())), 557 | }) 558 | ); 559 | 560 | assert_eq!( 561 | p.parse_primary(), 562 | Ok(ExprAST::For { 563 | var: String::from("x"), 564 | start: Box::new(ExprAST::Number(1f64)), 565 | end: Box::new(ExprAST::Number(2f64)), 566 | step: None, 567 | body: Box::new(ExprAST::Number(3f64)), 568 | }) 569 | ); 570 | } 571 | 572 | #[test] 573 | fn parse_binary_op() { 574 | // Operator before RHS has higher precedence, expected AST 575 | // 576 | // - 577 | // / \ 578 | // + c 579 | // / \ 580 | // a b 581 | let mut p = parser("a + b - c"); 582 | 583 | let binexpr_ab = ExprAST::Binary( 584 | '+', 585 | Box::new(ExprAST::Variable("a".into())), 
586 | Box::new(ExprAST::Variable("b".into())), 587 | ); 588 | 589 | let binexpr_abc = ExprAST::Binary( 590 | '-', 591 | Box::new(binexpr_ab), 592 | Box::new(ExprAST::Variable("c".into())), 593 | ); 594 | 595 | assert_eq!(p.parse_expression(), Ok(binexpr_abc)); 596 | } 597 | 598 | #[test] 599 | fn parse_binary_op2() { 600 | // Operator after RHS has higher precedence, expected AST 601 | // 602 | // + 603 | // / \ 604 | // a * 605 | // / \ 606 | // b c 607 | let mut p = parser("a + b * c"); 608 | 609 | let binexpr_bc = ExprAST::Binary( 610 | '*', 611 | Box::new(ExprAST::Variable("b".into())), 612 | Box::new(ExprAST::Variable("c".into())), 613 | ); 614 | 615 | let binexpr_abc = ExprAST::Binary( 616 | '+', 617 | Box::new(ExprAST::Variable("a".into())), 618 | Box::new(binexpr_bc), 619 | ); 620 | 621 | assert_eq!(p.parse_expression(), Ok(binexpr_abc)); 622 | } 623 | 624 | #[test] 625 | fn parse_prototype() { 626 | let mut p = parser("foo(a,b)"); 627 | 628 | let proto = PrototypeAST("foo".into(), vec!["a".into(), "b".into()]); 629 | 630 | assert_eq!(p.parse_prototype(), Ok(proto)); 631 | } 632 | 633 | #[test] 634 | fn parse_definition() { 635 | let mut p = parser("def bar( arg0 , arg1 ) arg0 + arg1"); 636 | 637 | let proto = PrototypeAST("bar".into(), vec!["arg0".into(), "arg1".into()]); 638 | 639 | let body = ExprAST::Binary( 640 | '+', 641 | Box::new(ExprAST::Variable("arg0".into())), 642 | Box::new(ExprAST::Variable("arg1".into())), 643 | ); 644 | 645 | let func = FunctionAST(proto, body); 646 | 647 | assert_eq!(p.parse_definition(), Ok(func)); 648 | } 649 | 650 | #[test] 651 | fn parse_extern() { 652 | let mut p = parser("extern baz()"); 653 | 654 | let proto = PrototypeAST("baz".into(), vec![]); 655 | 656 | assert_eq!(p.parse_extern(), Ok(proto)); 657 | } 658 | } 659 | --------------------------------------------------------------------------------