├── .cargo
│   └── config
├── .github
│   └── workflows
│       └── check.yml
├── .gitignore
├── Cargo.lock
├── Cargo.toml
├── LICENSE
├── README.md
├── container
│   ├── Dockerfile
│   └── Makefile
├── ks
│   ├── forloop.ks
│   └── ifelse.ks
└── src
    ├── codegen.rs
    ├── lexer.rs
    ├── lib.rs
    ├── llvm
    │   ├── basic_block.rs
    │   ├── builder.rs
    │   ├── lljit.rs
    │   ├── mod.rs
    │   ├── module.rs
    │   ├── pass_manager.rs
    │   ├── type_.rs
    │   └── value.rs
    ├── main.rs
    └── parser.rs
/.cargo/config:
--------------------------------------------------------------------------------
1 | [build]
2 | rustflags = ["-C", "link-args=-rdynamic"]
3 |
--------------------------------------------------------------------------------
/.github/workflows/check.yml:
--------------------------------------------------------------------------------
1 | name: Build, Test and generate Doc
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | env:
10 | CARGO_TERM_COLOR: always
11 | PODMAN_RUN: podman run --rm -t -v $PWD:/work -w /work ks-rs
12 |
13 | jobs:
14 | ci:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v2
18 | - run: make -C container
19 |
20 | - run: cargo fmt -- --check
21 | - run: eval $PODMAN_RUN cargo build --verbose
22 | - run: eval $PODMAN_RUN cargo test --verbose
23 |
24 | - name: Generate doc
25 | run: |
26 | eval $PODMAN_RUN cargo doc --no-deps
27 | echo "" > target/doc/index.html
28 | - name: Upload doc to gh pages
29 | uses: peaceiris/actions-gh-pages@v3
30 | with:
31 | github_token: ${{ secrets.GITHUB_TOKEN }}
32 | publish_dir: ./target/doc
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
--------------------------------------------------------------------------------
/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3
4 |
5 | [[package]]
6 | name = "aho-corasick"
7 | version = "0.7.19"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "b4f55bd91a0978cbfd91c457a164bab8b4001c833b7f323132c0a4e1922dd44e"
10 | dependencies = [
11 | "memchr",
12 | ]
13 |
14 | [[package]]
15 | name = "cc"
16 | version = "1.0.73"
17 | source = "registry+https://github.com/rust-lang/crates.io-index"
18 | checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
19 |
20 | [[package]]
21 | name = "lazy_static"
22 | version = "1.4.0"
23 | source = "registry+https://github.com/rust-lang/crates.io-index"
24 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
25 |
26 | [[package]]
27 | name = "libc"
28 | version = "0.2.133"
29 | source = "registry+https://github.com/rust-lang/crates.io-index"
30 | checksum = "c0f80d65747a3e43d1596c7c5492d95d5edddaabd45a7fcdb02b95f644164966"
31 |
32 | [[package]]
33 | name = "llvm-kaleidoscope-rs"
34 | version = "0.1.0"
35 | dependencies = [
36 | "libc",
37 | "llvm-sys",
38 | ]
39 |
40 | [[package]]
41 | name = "llvm-sys"
42 | version = "160.2.0"
43 | source = "registry+https://github.com/rust-lang/crates.io-index"
44 | checksum = "0438b23666723e3851fe336c1757acd3c915a0a14b41cc41284143609ca6cdce"
45 | dependencies = [
46 | "cc",
47 | "lazy_static",
48 | "libc",
49 | "regex",
50 | "semver",
51 | ]
52 |
53 | [[package]]
54 | name = "memchr"
55 | version = "2.5.0"
56 | source = "registry+https://github.com/rust-lang/crates.io-index"
57 | checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
58 |
59 | [[package]]
60 | name = "regex"
61 | version = "1.6.0"
62 | source = "registry+https://github.com/rust-lang/crates.io-index"
63 | checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b"
64 | dependencies = [
65 | "aho-corasick",
66 | "memchr",
67 | "regex-syntax",
68 | ]
69 |
70 | [[package]]
71 | name = "regex-syntax"
72 | version = "0.6.27"
73 | source = "registry+https://github.com/rust-lang/crates.io-index"
74 | checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244"
75 |
76 | [[package]]
77 | name = "semver"
78 | version = "1.0.16"
79 | source = "registry+https://github.com/rust-lang/crates.io-index"
80 | checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a"
81 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "llvm-kaleidoscope-rs"
3 | version = "0.1.0"
4 | edition = "2018"
5 |
6 | [dependencies]
7 | libc = "0.2"
8 | llvm-sys = {version = "160.0", features = ["strict-versioning"]}
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Johannes Stoelp
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # llvm-kaleidoscope-rs
2 |
3 | [![Rust][wf-badge]][wf-output] [![Rustdoc][doc-badge]][doc-html]
4 |
5 | [wf-output]: https://github.com/johannst/llvm-kaleidoscope-rs/actions/workflows/check.yml
6 | [wf-badge]: https://github.com/johannst/llvm-kaleidoscope-rs/actions/workflows/check.yml/badge.svg
7 | [doc-html]: https://johannst.github.io/llvm-kaleidoscope-rs/llvm_kaleidoscope_rs/index.html
8 | [doc-badge]: https://img.shields.io/badge/llvm__kaleidoscope__rs-rustdoc-blue.svg?style=flat&logo=rust
9 |
10 | The purpose of this repository is to learn about the [`llvm`][llvm] compiler
11 | infrastructure and practice some [`rust-lang`][rust].
12 |
13 | To reach the goals set, we follow the official llvm tutorial [`Kaleidoscope:
14 | Implementing a Language with LLVM`][llvm-tutorial]. This tutorial is written in
15 | `C++` and structured in multiple chapters; we will try to follow along and
16 | implement every chapter in Rust.
17 |
18 | The topics of the chapters are as follows:
19 |
20 | - Chapter 1: [Kaleidoscope Introduction and the Lexer][llvm-ch1]
21 | - Chapter 2: [Implementing a Parser and AST][llvm-ch2]
22 | - Chapter 3: [Code generation to LLVM IR][llvm-ch3]
23 | - Chapter 4: [Adding JIT and Optimizer Support][llvm-ch4]
24 | - Chapter 5: [Extending the Language: Control Flow][llvm-ch5]
25 |
26 | The implementation after each chapter can be compiled and executed by checking
27 | out the corresponding tag for the chapter.
28 | ```bash
29 | > git tag -l
30 | chapter1
31 | chapter2
32 | chapter3
33 | chapter4
34 | chapter5
35 | ```
36 |
37 | Names of variables and functions as well as the structure of the functions are
38 | mainly kept aligned with the official tutorial. This aims to make it easy to
39 | map the `rust` implementation onto the `C++` implementation when following the
40 | tutorial.
41 |
42 | One further note on the LLVM API: instead of using the LLVM `C++` API, we are
43 | going to use the LLVM `C` API and build our own safe wrapper specialized for
44 | this tutorial. The wrapper offers an interface similar to the `C++` API and is
45 | implemented in [`src/llvm/`](src/llvm/).
46 |
47 | ## Demo
48 |
49 | ```bash
50 | # Run a kaleidoscope program from a file, e.g. one of the examples in ks/.
51 | cargo run ks/forloop.ks
52 |
53 | # Run REPL loop, parsing from stdin.
54 | cargo run
55 | ```
56 |
57 | ## Documentation
58 |
59 | Rustdoc for this crate is available at
60 | [johannst.github.io/llvm-kaleidoscope-rs][gh-pages].
61 |
62 | ## Build with provided container file
63 |
64 | The provided [Dockerfile](container/Dockerfile) documents the required
65 | dependencies for an ubuntu based system and serves as a build environment with
66 | the correct llvm version as specified in the [Cargo.toml](Cargo.toml) file.
67 |
68 | ```bash
69 | ## Either use podman ..
70 |
71 | # Build the image *ks-rs*. Depending on your download speed this may take a few minutes.
72 | make -C container
73 |
74 | podman run --rm -it -v $PWD:/work -w /work ks-rs
75 | # Drops into a shell in the container, just use cargo build / run ...
76 |
77 | ## .. or docker.
78 |
79 | # Build the image *ks-rs*. Depending on your download speed this may take a few minutes.
80 | make -C container docker
81 |
82 | docker run --rm -it -v $PWD:/work -w /work ks-rs
83 | # Drops into a shell in the container, just use cargo build / run ...
84 | ```
85 |
86 | ## License
87 |
88 | This project is licensed under the [MIT](LICENSE) license.
89 |
90 | [llvm]: https://llvm.org
91 | [llvm-tutorial]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/index.html
92 | [llvm-ch1]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl01.html
93 | [llvm-ch2]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl02.html
94 | [llvm-ch3]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl03.html
95 | [llvm-ch4]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl04.html
96 | [llvm-ch5]: https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl05.html
97 | [rust]: https://www.rust-lang.org
98 | [gh-pages]: https://johannst.github.io/llvm-kaleidoscope-rs/llvm_kaleidoscope_rs/index.html
99 |
--------------------------------------------------------------------------------
/container/Dockerfile:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: MIT
2 | #
3 | # Copyright (c) 2021, Johannes Stoelp
4 |
5 | FROM ubuntu
6 |
7 | RUN apt update && \
8 | DEBIAN_FRONTEND=noninteractive \
9 | apt install \
10 | --yes \
11 | --no-install-recommends \
12 | ca-certificates \
13 | build-essential \
14 | cargo \
15 | llvm-16-dev \
16 | # For polly dependency.
17 | # https://gitlab.com/taricorp/llvm-sys.rs/-/issues/13
18 | libpolly-16-dev \
19 | libz-dev \
20 | libzstd-dev \
21 | && \
22 | rm -rf /var/lib/apt/lists/* && \
23 | apt-get clean
24 |
--------------------------------------------------------------------------------
/container/Makefile:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: MIT
2 | #
3 | # Copyright (c) 2021, Johannes Stoelp
4 |
5 | podman: build-podman
6 | docker: build-docker
7 |
8 | build-%:
9 | $* build -t ks-rs .
10 |
--------------------------------------------------------------------------------
/ks/forloop.ks:
--------------------------------------------------------------------------------
# Declare the external function 'putchard' so it can be called below. Like every
# Kaleidoscope value, its single argument 'char' is a double.
extern putchard(char);

# Print one '*' per loop iteration. The loop variable i starts at 1 and is incremented
# by 1.0; the end condition 'i < n' is checked after each iteration, so the body runs
# at least once.
def printstar(n)
  for i = 1, i < n, 1.0 in
    putchard(42); # ascii 42 = '*'

# print 100 '*' characters
printstar(100);
8 |
--------------------------------------------------------------------------------
/ks/ifelse.ks:
--------------------------------------------------------------------------------
# Compute the x-th Fibonacci number recursively; fib(x) is 1 for every x < 3.
def fib(x)
  if x < 3 then
    1
  else
    fib(x-1)+fib(x-2);

# Evaluates to 55.
fib(10);
8 |
--------------------------------------------------------------------------------
/src/codegen.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use std::collections::HashMap;
6 |
7 | use crate::llvm::{FnValue, FunctionPassManager, IRBuilder, Module, Value};
8 | use crate::parser::{ExprAST, FunctionAST, PrototypeAST};
9 | use crate::Either;
10 |
11 | type CodegenResult = Result;
12 |
13 | /// Code generator from kaleidoscope AST to LLVM IR.
14 | pub struct Codegen<'llvm, 'a> {
15 | module: &'llvm Module,
16 | builder: &'a IRBuilder<'llvm>,
17 | fpm: &'a FunctionPassManager<'llvm>,
18 | fn_protos: &'a mut HashMap,
19 | }
20 |
21 | impl<'llvm, 'a> Codegen<'llvm, 'a> {
22 | /// Compile either a [`PrototypeAST`] or a [`FunctionAST`] into the LLVM `module`.
23 | pub fn compile(
24 | module: &'llvm Module,
25 | fn_protos: &mut HashMap,
26 | compilee: Either<&PrototypeAST, &FunctionAST>,
27 | ) -> CodegenResult> {
28 | let mut cg = Codegen {
29 | module,
30 | builder: &IRBuilder::with_ctx(module),
31 | fpm: &FunctionPassManager::with_ctx(module),
32 | fn_protos,
33 | };
34 | let mut variables = HashMap::new();
35 |
36 | match compilee {
37 | Either::A(proto) => Ok(cg.codegen_prototype(proto)),
38 | Either::B(func) => cg.codegen_function(func, &mut variables),
39 | }
40 | }
41 |
42 | fn codegen_expr(
43 | &self,
44 | expr: &ExprAST,
45 | named_values: &mut HashMap>,
46 | ) -> CodegenResult> {
47 | match expr {
48 | ExprAST::Number(num) => Ok(self.module.type_f64().const_f64(*num)),
49 | ExprAST::Variable(name) => match named_values.get(name.as_str()) {
50 | Some(value) => Ok(*value),
51 | None => Err("Unknown variable name".into()),
52 | },
53 | ExprAST::Binary(binop, lhs, rhs) => {
54 | let l = self.codegen_expr(lhs, named_values)?;
55 | let r = self.codegen_expr(rhs, named_values)?;
56 |
57 | match binop {
58 | '+' => Ok(self.builder.fadd(l, r)),
59 | '-' => Ok(self.builder.fsub(l, r)),
60 | '*' => Ok(self.builder.fmul(l, r)),
61 | '<' => {
62 | let res = self.builder.fcmpult(l, r);
63 | // Turn bool into f64.
64 | Ok(self.builder.uitofp(res, self.module.type_f64()))
65 | }
66 | _ => Err("invalid binary operator".into()),
67 | }
68 | }
69 | ExprAST::Call(callee, args) => match self.get_function(callee) {
70 | Some(callee) => {
71 | if callee.args() != args.len() {
72 | return Err("Incorrect # arguments passed".into());
73 | }
74 |
75 | // Generate code for function argument expressions.
76 | let mut args: Vec> = args
77 | .iter()
78 | .map(|arg| self.codegen_expr(arg, named_values))
79 | .collect::>()?;
80 |
81 | Ok(self.builder.call(callee, &mut args))
82 | }
83 | None => Err("Unknown function referenced".into()),
84 | },
85 | ExprAST::If { cond, then, else_ } => {
86 | // For 'if' expressions we are building the following CFG.
87 | //
88 | // ; cond
89 | // br
90 | // |
91 | // +-----+------+
92 | // v v
93 | // ; then ; else
94 | // | |
95 | // +-----+------+
96 | // v
97 | // ; merge
98 | // phi then, else
99 | // ret phi
100 |
101 | let cond_v = {
102 | // Codgen 'cond' expression.
103 | let v = self.codegen_expr(cond, named_values)?;
104 | // Compare 'v' against '0' as 'one = ordered not equal'.
105 | self.builder
106 | .fcmpone(v, self.module.type_f64().const_f64(0f64))
107 | };
108 |
109 | // Get the function we are currently inserting into.
110 | let the_function = self.builder.get_insert_block().get_parent();
111 |
112 | // Create basic blocks for the 'then' / 'else' expressions as well as the return
113 | // instruction ('merge').
114 | //
115 | // Append the 'then' basic block to the function, don't insert the 'else' and
116 | // 'merge' basic blocks yet.
117 | let then_bb = self.module.append_basic_block(the_function);
118 | let else_bb = self.module.create_basic_block();
119 | let merge_bb = self.module.create_basic_block();
120 |
121 | // Create a conditional branch based on the result of the 'cond' expression.
122 | self.builder.cond_br(cond_v, then_bb, else_bb);
123 |
124 | // Move to 'then' basic block and codgen the 'then' expression.
125 | self.builder.pos_at_end(then_bb);
126 | let then_v = self.codegen_expr(then, named_values)?;
127 | // Create unconditional branch to 'merge' block.
128 | self.builder.br(merge_bb);
129 | // Update reference to current basic block (in case the 'then' expression added new
130 | // basic blocks).
131 | let then_bb = self.builder.get_insert_block();
132 |
133 | // Now append the 'else' basic block to the function.
134 | the_function.append_basic_block(else_bb);
135 | // Move to 'else' basic block and codgen the 'else' expression.
136 | self.builder.pos_at_end(else_bb);
137 | let else_v = self.codegen_expr(else_, named_values)?;
138 | // Create unconditional branch to 'merge' block.
139 | self.builder.br(merge_bb);
140 | // Update reference to current basic block (in case the 'else' expression added new
141 | // basic blocks).
142 | let else_bb = self.builder.get_insert_block();
143 |
144 | // Now append the 'merge' basic block to the function.
145 | the_function.append_basic_block(merge_bb);
146 | // Move to 'merge' basic block.
147 | self.builder.pos_at_end(merge_bb);
148 | // Codegen the phi node returning the appropriate value depending on the branch
149 | // condition.
150 | let phi = self.builder.phi(
151 | self.module.type_f64(),
152 | &[(then_v, then_bb), (else_v, else_bb)],
153 | );
154 |
155 | Ok(*phi)
156 | }
157 | ExprAST::For {
158 | var,
159 | start,
160 | end,
161 | step,
162 | body,
163 | } => {
164 | // For 'for' expression we build the following structure.
165 | //
166 | // entry:
167 | // init = start expression
168 | // br loop
169 | // loop:
170 | // i = phi [%init, %entry], [%new_i, %loop]
171 | // ; loop body ...
172 | // new_i = increment %i by step expression
173 | // ; check end condition and branch
174 | // end:
175 |
176 | // Compute initial value for the loop variable.
177 | let start_val = self.codegen_expr(start, named_values)?;
178 |
179 | let the_function = self.builder.get_insert_block().get_parent();
180 | // Get current basic block (used in the loop variable phi node).
181 | let entry_bb = self.builder.get_insert_block();
182 | // Add new basic block to emit loop body.
183 | let loop_bb = self.module.append_basic_block(the_function);
184 |
185 | self.builder.br(loop_bb);
186 | self.builder.pos_at_end(loop_bb);
187 |
188 | // Build phi not to pick loop variable in case we come from the 'entry' block.
189 | // Which is the case when we enter the loop for the first time.
190 | // We will add another incoming value once we computed the updated loop variable
191 | // below.
192 | let variable = self
193 | .builder
194 | .phi(self.module.type_f64(), &[(start_val, entry_bb)]);
195 |
196 | // Insert the loop variable into the named values map that it can be referenced
197 | // from the body as well as the end condition.
198 | // In case the loop variable shadows an existing variable remember the shared one.
199 | let old_val = named_values.insert(var.into(), *variable);
200 |
201 | // Generate the loop body.
202 | self.codegen_expr(body, named_values)?;
203 |
204 | // Generate step value expression if available else use '1'.
205 | let step_val = if let Some(step) = step {
206 | self.codegen_expr(step, named_values)?
207 | } else {
208 | self.module.type_f64().const_f64(1f64)
209 | };
210 |
211 | // Increment loop variable.
212 | let next_var = self.builder.fadd(*variable, step_val);
213 |
214 | // Generate the loop end condition.
215 | let end_cond = self.codegen_expr(end, named_values)?;
216 | let end_cond = self
217 | .builder
218 | .fcmpone(end_cond, self.module.type_f64().const_f64(0f64));
219 |
220 | // Get current basic block.
221 | let loop_end_bb = self.builder.get_insert_block();
222 | // Add new basic block following the loop.
223 | let after_bb = self.module.append_basic_block(the_function);
224 |
225 | // Register additional incoming value for the loop variable. This will choose the
226 | // updated loop variable if we are iterating in the loop.
227 | variable.add_incoming(next_var, loop_end_bb);
228 |
229 | // Branch depending on the loop end condition.
230 | self.builder.cond_br(end_cond, loop_bb, after_bb);
231 |
232 | self.builder.pos_at_end(after_bb);
233 |
234 | // Restore the shadowed variable if there was one.
235 | if let Some(old_val) = old_val {
236 | // We inserted 'var' above so it must exist.
237 | *named_values.get_mut(var).unwrap() = old_val;
238 | } else {
239 | named_values.remove(var);
240 | }
241 |
242 | // Loops just always return 0.
243 | Ok(self.module.type_f64().const_f64(0f64))
244 | }
245 | }
246 | }
247 |
248 | fn codegen_prototype(&self, PrototypeAST(name, args): &PrototypeAST) -> FnValue<'llvm> {
249 | let type_f64 = self.module.type_f64();
250 |
251 | let mut doubles = Vec::new();
252 | doubles.resize(args.len(), type_f64);
253 |
254 | // Build the function type: fn(f64, f64, ..) -> f64
255 | let ft = self.module.type_fn(&mut doubles, type_f64);
256 |
257 | // Create the function declaration.
258 | let f = self.module.add_fn(name, ft);
259 |
260 | // Set the names of the function arguments.
261 | for idx in 0..f.args() {
262 | f.arg(idx).set_name(&args[idx]);
263 | }
264 |
265 | f
266 | }
267 |
268 | fn codegen_function(
269 | &mut self,
270 | FunctionAST(proto, body): &FunctionAST,
271 | named_values: &mut HashMap>,
272 | ) -> CodegenResult> {
273 | // Insert the function prototype into the `fn_protos` map to keep track for re-generating
274 | // declarations in other modules.
275 | self.fn_protos.insert(proto.0.clone(), proto.clone());
276 |
277 | let the_function = self.get_function(&proto.0)
278 | .expect("If proto not already generated, get_function will do for us since we updated fn_protos before-hand!");
279 |
280 | if the_function.basic_blocks() > 0 {
281 | return Err("Function cannot be redefined.".into());
282 | }
283 |
284 | // Create entry basic block to insert code.
285 | let bb = self.module.append_basic_block(the_function);
286 | self.builder.pos_at_end(bb);
287 |
288 | // New scope, clear the map with the function args.
289 | named_values.clear();
290 |
291 | // Update the map with the current functions args.
292 | for idx in 0..the_function.args() {
293 | let arg = the_function.arg(idx);
294 | named_values.insert(arg.get_name().into(), arg);
295 | }
296 |
297 | // Codegen function body.
298 | if let Ok(ret) = self.codegen_expr(body, named_values) {
299 | self.builder.ret(ret);
300 | assert!(the_function.verify());
301 |
302 | // Run the optimization passes on the function.
303 | self.fpm.run(the_function);
304 |
305 | Ok(the_function)
306 | } else {
307 | todo!("Failed to codegen function body, erase from module!");
308 | }
309 | }
310 |
311 | /// Lookup function with `name` in the LLVM module and return the corresponding value reference.
312 | /// If the function is not available in the module, check if the prototype is known and codegen
313 | /// it.
314 | /// Return [`None`] if the prototype is not known.
315 | fn get_function(&self, name: &str) -> Option> {
316 | let callee = match self.module.get_fn(name) {
317 | Some(callee) => callee,
318 | None => {
319 | let proto = self.fn_protos.get(name)?;
320 | self.codegen_prototype(proto)
321 | }
322 | };
323 |
324 | Some(callee)
325 | }
326 | }
327 |
--------------------------------------------------------------------------------
/src/lexer.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
/// Tokens produced by [`Lexer::gettok`].
#[derive(Debug, PartialEq)]
pub enum Token {
    /// End of input.
    Eof,
    /// Keyword `def`.
    Def,
    /// Keyword `extern`.
    Extern,
    /// Identifier `[a-zA-Z][a-zA-Z0-9]*` which is not a keyword.
    Identifier(String),
    /// Number literal `[0-9.]+`; malformed literals such as `1.2.3` are lexed as `0.0`.
    Number(f64),
    /// Any other single character, e.g. an operator or a parenthesis.
    Char(char),
    /// Keyword `if`.
    If,
    /// Keyword `then`.
    Then,
    /// Keyword `else`.
    Else,
    /// Keyword `for`.
    For,
    /// Keyword `in`.
    In,
}

/// Lexer turning a stream of characters into [`Token`]s.
pub struct Lexer<I>
where
    I: Iterator<Item = char>,
{
    /// Remaining input characters.
    input: I,
    /// Current lookahead character, [`None`] once the input is exhausted.
    last_char: Option<char>,
}

impl<I> Lexer<I>
where
    I: Iterator<Item = char>,
{
    /// Create a new lexer reading characters from `input`.
    pub fn new(mut input: I) -> Lexer<I> {
        let last_char = input.next();
        Lexer { input, last_char }
    }

    /// Advance the lookahead to the next input character and return it.
    fn step(&mut self) -> Option<char> {
        self.last_char = self.input.next();
        self.last_char
    }

    /// Lex and return the next token.
    ///
    /// Implement `int gettok();` from the tutorial.
    ///
    /// Whitespace and comments (`#` up to the end of the line) are skipped. Returns
    /// [`Token::Eof`] once the input is exhausted.
    pub fn gettok(&mut self) -> Token {
        // Loop instead of recursing after a comment, so an input with many consecutive comment
        // lines cannot overflow the stack.
        loop {
            // Eat up whitespaces.
            while matches!(self.last_char, Some(c) if c.is_ascii_whitespace()) {
                self.step();
            }

            // Unpack last char or return EOF.
            let last_char = match self.last_char {
                Some(c) => c,
                None => return Token::Eof,
            };

            // Identifier: [a-zA-Z][a-zA-Z0-9]*
            if last_char.is_ascii_alphabetic() {
                let mut ident = String::new();
                ident.push(last_char);

                while let Some(c) = self.step() {
                    if c.is_ascii_alphanumeric() {
                        ident.push(c)
                    } else {
                        break;
                    }
                }

                match ident.as_ref() {
                    "def" => return Token::Def,
                    "extern" => return Token::Extern,
                    "if" => return Token::If,
                    "then" => return Token::Then,
                    "else" => return Token::Else,
                    "for" => return Token::For,
                    "in" => return Token::In,
                    _ => {}
                }

                return Token::Identifier(ident);
            }

            // Number: [0-9.]+
            if last_char.is_ascii_digit() || last_char == '.' {
                let mut num = String::new();
                num.push(last_char);

                while let Some(c) = self.step() {
                    if c.is_ascii_digit() || c == '.' {
                        num.push(c)
                    } else {
                        break;
                    }
                }

                // Malformed numbers such as "1.2.3" are mapped to 0.
                let num: f64 = num.parse().unwrap_or_default();
                return Token::Number(num);
            }

            // Eat up comment until the end of the line, then lex the following token.
            if last_char == '#' {
                loop {
                    match self.step() {
                        Some('\r') | Some('\n') => break,
                        None => return Token::Eof,
                        Some(_) => { /* consume comment */ }
                    }
                }
                continue;
            }

            // Advance last char and return currently last char.
            self.step();
            return Token::Char(last_char);
        }
    }
}
118 |
#[cfg(test)]
mod test {
    use super::{Lexer, Token};

    #[test]
    fn test_identifier() {
        let mut lex = Lexer::new("a b c".chars());
        assert_eq!(Token::Identifier("a".into()), lex.gettok());
        assert_eq!(Token::Identifier("b".into()), lex.gettok());
        assert_eq!(Token::Identifier("c".into()), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_keyword() {
        let mut lex = Lexer::new("def extern".chars());
        assert_eq!(Token::Def, lex.gettok());
        assert_eq!(Token::Extern, lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_number() {
        let mut lex = Lexer::new("12.34".chars());
        assert_eq!(Token::Number(12.34f64), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());

        let mut lex = Lexer::new(" 1.0 2.0 3.0".chars());
        assert_eq!(Token::Number(1.0f64), lex.gettok());
        assert_eq!(Token::Number(2.0f64), lex.gettok());
        assert_eq!(Token::Number(3.0f64), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());

        let mut lex = Lexer::new("12.34.56".chars());
        assert_eq!(Token::Number(0f64), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_comment() {
        let mut lex = Lexer::new("# some comment".chars());
        assert_eq!(Token::Eof, lex.gettok());

        let mut lex = Lexer::new("abc # some comment \n xyz".chars());
        assert_eq!(Token::Identifier("abc".into()), lex.gettok());
        assert_eq!(Token::Identifier("xyz".into()), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());

        // Consecutive comment lines, with both '\n' and '\r\n' line endings.
        let mut lex = Lexer::new("# first\n# second\r\n  # third\nxyz".chars());
        assert_eq!(Token::Identifier("xyz".into()), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_chars() {
        let mut lex = Lexer::new("a+b-c".chars());
        assert_eq!(Token::Identifier("a".into()), lex.gettok());
        assert_eq!(Token::Char('+'), lex.gettok());
        assert_eq!(Token::Identifier("b".into()), lex.gettok());
        assert_eq!(Token::Char('-'), lex.gettok());
        assert_eq!(Token::Identifier("c".into()), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_whitespaces() {
        let mut lex = Lexer::new(" +a b c! ".chars());
        assert_eq!(Token::Char('+'), lex.gettok());
        assert_eq!(Token::Identifier("a".into()), lex.gettok());
        assert_eq!(Token::Identifier("b".into()), lex.gettok());
        assert_eq!(Token::Identifier("c".into()), lex.gettok());
        assert_eq!(Token::Char('!'), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());

        let mut lex = Lexer::new("\n a \n\r b \r \n c \r\r \n ".chars());
        assert_eq!(Token::Identifier("a".into()), lex.gettok());
        assert_eq!(Token::Identifier("b".into()), lex.gettok());
        assert_eq!(Token::Identifier("c".into()), lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_ite() {
        let mut lex = Lexer::new("if then else".chars());
        assert_eq!(Token::If, lex.gettok());
        assert_eq!(Token::Then, lex.gettok());
        assert_eq!(Token::Else, lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }

    #[test]
    fn test_for() {
        let mut lex = Lexer::new("for in".chars());
        assert_eq!(Token::For, lex.gettok());
        assert_eq!(Token::In, lex.gettok());
        assert_eq!(Token::Eof, lex.gettok());
    }
}
211 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use std::convert::TryFrom;
6 |
7 | pub mod codegen;
8 | pub mod lexer;
9 | pub mod llvm;
10 | pub mod parser;
11 |
12 | /// Fixed size of [`SmallCStr`] including the trailing `\0` byte.
13 | pub const SMALL_STR_SIZE: usize = 16;
14 |
15 | /// Small C string on the stack with fixed size [`SMALL_STR_SIZE`].
16 | ///
17 | /// This is specially crafted to interact with the LLVM C API and get rid of some heap allocations.
18 | #[derive(Debug, PartialEq)]
19 | pub struct SmallCStr([u8; SMALL_STR_SIZE]);
20 |
21 | impl SmallCStr {
22 | /// Create a new C string from `src`.
23 | /// Returns [`None`] if `src` exceeds the fixed size or contains any `\0` bytes.
24 | pub fn new>(src: &T) -> Option {
25 | let src = src.as_ref();
26 | let len = src.len();
27 |
28 | // Check for \0 bytes.
29 | let contains_null = unsafe { !libc::memchr(src.as_ptr().cast(), 0, len).is_null() };
30 |
31 | if contains_null || len > SMALL_STR_SIZE - 1 {
32 | None
33 | } else {
34 | let mut dest = [0; SMALL_STR_SIZE];
35 | dest[..len].copy_from_slice(src);
36 | Some(SmallCStr(dest))
37 | }
38 | }
39 |
40 | /// Return pointer to C string.
41 | pub const fn as_ptr(&self) -> *const libc::c_char {
42 | self.0.as_ptr().cast()
43 | }
44 | }
45 |
46 | impl TryFrom<&str> for SmallCStr {
47 | type Error = ();
48 |
49 | fn try_from(value: &str) -> Result {
50 | SmallCStr::new(&value).ok_or(())
51 | }
52 | }
53 |
/// Either type, for APIs accepting two types.
pub enum Either<A, B> {
    /// A value of the first type.
    A(A),
    /// A value of the second type.
    B(B),
}
59 |
#[cfg(test)]
mod test {
    use super::{SmallCStr, SMALL_STR_SIZE};
    use std::convert::TryInto;

    #[test]
    fn test_create() {
        let src = "\x30\x31\x32\x33";
        let scs = SmallCStr::new(&src).unwrap();
        assert_eq!(&scs.0[..5], &[0x30, 0x31, 0x32, 0x33, 0x00]);

        let src = b"abcd1234";
        let scs = SmallCStr::new(&src).unwrap();
        assert_eq!(
            &scs.0[..9],
            &[0x61, 0x62, 0x63, 0x64, 0x31, 0x32, 0x33, 0x34, 0x00]
        );
    }

    #[test]
    fn test_empty() {
        // An empty string is a valid C string consisting only of the trailing \0 byte.
        let scs = SmallCStr::new(&"").unwrap();
        assert_eq!(scs.0, [0; SMALL_STR_SIZE]);
    }

    #[test]
    fn test_max_len() {
        // The longest accepted string leaves exactly one byte for the trailing \0.
        let src = (0..SMALL_STR_SIZE - 1).map(|_| 'a').collect::<String>();
        let scs = SmallCStr::new(&src).unwrap();
        assert_eq!(&scs.0[..SMALL_STR_SIZE - 1], src.as_bytes());
        assert_eq!(scs.0[SMALL_STR_SIZE - 1], 0);
    }

    #[test]
    fn test_contain_null() {
        let src = "\x30\x00\x32\x33";
        let scs = SmallCStr::new(&src);
        assert_eq!(scs, None);

        let src = "\x30\x31\x32\x33\x00";
        let scs = SmallCStr::new(&src);
        assert_eq!(scs, None);
    }

    #[test]
    fn test_too_large() {
        let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::<String>();
        let scs = SmallCStr::new(&src);
        assert_eq!(scs, None);

        let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::<String>();
        let scs = SmallCStr::new(&src);
        assert_eq!(scs, None);
    }

    #[test]
    fn test_try_into() {
        let src = "\x30\x31\x32\x33";
        let scs: Result<SmallCStr, ()> = src.try_into();
        assert!(scs.is_ok());

        let src = (0..SMALL_STR_SIZE).map(|_| 'a').collect::<String>();
        let scs: Result<SmallCStr, ()> = src.as_str().try_into();
        assert!(scs.is_err());

        let src = (0..SMALL_STR_SIZE + 10).map(|_| 'a').collect::<String>();
        let scs: Result<SmallCStr, ()> = src.as_str().try_into();
        assert!(scs.is_err());
    }
}
116 |
--------------------------------------------------------------------------------
/src/llvm/basic_block.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_sys::{core::LLVMGetBasicBlockParent, prelude::LLVMBasicBlockRef};
6 |
7 | use std::marker::PhantomData;
8 |
9 | use super::FnValue;
10 |
11 | /// Wrapper for a LLVM Basic Block.
12 | #[derive(Copy, Clone)]
13 | pub struct BasicBlock<'llvm>(LLVMBasicBlockRef, PhantomData<&'llvm ()>);
14 |
15 | impl<'llvm> BasicBlock<'llvm> {
16 | /// Create a new BasicBlock instance.
17 | ///
18 | /// # Panics
19 | ///
20 | /// Panics if `bb_ref` is a null pointer.
21 | pub(super) fn new(bb_ref: LLVMBasicBlockRef) -> BasicBlock<'llvm> {
22 | assert!(!bb_ref.is_null());
23 | BasicBlock(bb_ref, PhantomData)
24 | }
25 |
26 | /// Get the raw LLVM value reference.
27 | #[inline]
28 | pub(super) fn bb_ref(&self) -> LLVMBasicBlockRef {
29 | self.0
30 | }
31 |
32 | /// Get the function to which the basic block belongs.
33 | ///
34 | /// # Panics
35 | ///
36 | /// Panics if LLVM API returns a `null` pointer.
37 | pub fn get_parent(&self) -> FnValue<'llvm> {
38 | let value_ref = unsafe { LLVMGetBasicBlockParent(self.bb_ref()) };
39 | assert!(!value_ref.is_null());
40 |
41 | FnValue::new(value_ref)
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/src/llvm/builder.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_sys::{
6 | core::{
7 | LLVMAddIncoming, LLVMBuildBr, LLVMBuildCondBr, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFMul,
8 | LLVMBuildFSub, LLVMBuildPhi, LLVMBuildRet, LLVMBuildUIToFP, LLVMCreateBuilderInContext,
9 | LLVMDisposeBuilder, LLVMGetInsertBlock, LLVMPositionBuilderAtEnd,
10 | },
11 | prelude::{LLVMBuilderRef, LLVMValueRef},
12 | LLVMRealPredicate,
13 | };
14 |
15 | use std::marker::PhantomData;
16 |
17 | use super::{BasicBlock, FnValue, Module, PhiValue, Type, Value};
18 |
19 | // Definition of LLVM C API functions using our `repr(transparent)` types.
20 | extern "C" {
21 | fn LLVMBuildCall2(
22 | arg1: LLVMBuilderRef,
23 | arg2: Type<'_>,
24 | Fn: FnValue<'_>,
25 | Args: *mut Value<'_>,
26 | NumArgs: ::libc::c_uint,
27 | Name: *const ::libc::c_char,
28 | ) -> LLVMValueRef;
29 | }
30 |
31 | /// Wrapper for a LLVM IR Builder.
32 | pub struct IRBuilder<'llvm> {
33 | builder: LLVMBuilderRef,
34 | _ctx: PhantomData<&'llvm ()>,
35 | }
36 |
37 | impl<'llvm> IRBuilder<'llvm> {
38 | /// Create a new LLVM IR Builder with the `module`s context.
39 | ///
40 | /// # Panics
41 | ///
42 | /// Panics if creating the IR Builder fails.
43 | pub fn with_ctx(module: &'llvm Module) -> IRBuilder<'llvm> {
44 | let builder = unsafe { LLVMCreateBuilderInContext(module.ctx()) };
45 | assert!(!builder.is_null());
46 |
47 | IRBuilder {
48 | builder,
49 | _ctx: PhantomData,
50 | }
51 | }
52 |
53 | /// Position the IR Builder at the end of the given Basic Block.
54 | pub fn pos_at_end(&self, bb: BasicBlock<'llvm>) {
55 | unsafe {
56 | LLVMPositionBuilderAtEnd(self.builder, bb.bb_ref());
57 | }
58 | }
59 |
60 | /// Get the BasicBlock the IRBuilder currently inputs into.
61 | ///
62 | /// # Panics
63 | ///
64 | /// Panics if LLVM API returns a `null` pointer.
65 | pub fn get_insert_block(&self) -> BasicBlock<'llvm> {
66 | let bb_ref = unsafe { LLVMGetInsertBlock(self.builder) };
67 | assert!(!bb_ref.is_null());
68 |
69 | BasicBlock::new(bb_ref)
70 | }
71 |
72 | /// Emit a [fadd](https://llvm.org/docs/LangRef.html#fadd-instruction) instruction.
73 | ///
74 | /// # Panics
75 | ///
76 | /// Panics if LLVM API returns a `null` pointer.
77 | pub fn fadd(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
78 | debug_assert!(lhs.is_f64(), "fadd: Expected f64 as lhs operand!");
79 | debug_assert!(rhs.is_f64(), "fadd: Expected f64 as rhs operand!");
80 |
81 | let value_ref = unsafe {
82 | LLVMBuildFAdd(
83 | self.builder,
84 | lhs.value_ref(),
85 | rhs.value_ref(),
86 | b"fadd\0".as_ptr().cast(),
87 | )
88 | };
89 | Value::new(value_ref)
90 | }
91 |
92 | /// Emit a [fsub](https://llvm.org/docs/LangRef.html#fsub-instruction) instruction.
93 | ///
94 | /// # Panics
95 | ///
96 | /// Panics if LLVM API returns a `null` pointer.
97 | pub fn fsub(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
98 | debug_assert!(lhs.is_f64(), "fsub: Expected f64 as lhs operand!");
99 | debug_assert!(rhs.is_f64(), "fsub: Expected f64 as rhs operand!");
100 |
101 | let value_ref = unsafe {
102 | LLVMBuildFSub(
103 | self.builder,
104 | lhs.value_ref(),
105 | rhs.value_ref(),
106 | b"fsub\0".as_ptr().cast(),
107 | )
108 | };
109 | Value::new(value_ref)
110 | }
111 |
112 | /// Emit a [fmul](https://llvm.org/docs/LangRef.html#fmul-instruction) instruction.
113 | ///
114 | /// # Panics
115 | ///
116 | /// Panics if LLVM API returns a `null` pointer.
117 | pub fn fmul(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
118 | debug_assert!(lhs.is_f64(), "fmul: Expected f64 as lhs operand!");
119 | debug_assert!(rhs.is_f64(), "fmul: Expected f64 as rhs operand!");
120 |
121 | let value_ref = unsafe {
122 | LLVMBuildFMul(
123 | self.builder,
124 | lhs.value_ref(),
125 | rhs.value_ref(),
126 | b"fmul\0".as_ptr().cast(),
127 | )
128 | };
129 | Value::new(value_ref)
130 | }
131 |
132 | /// Emit a [fcmpult](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction.
133 | ///
134 | /// # Panics
135 | ///
136 | /// Panics if LLVM API returns a `null` pointer.
137 | pub fn fcmpult(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
138 | debug_assert!(lhs.is_f64(), "fcmpult: Expected f64 as lhs operand!");
139 | debug_assert!(rhs.is_f64(), "fcmpult: Expected f64 as rhs operand!");
140 |
141 | let value_ref = unsafe {
142 | LLVMBuildFCmp(
143 | self.builder,
144 | LLVMRealPredicate::LLVMRealULT,
145 | lhs.value_ref(),
146 | rhs.value_ref(),
147 | b"fcmpult\0".as_ptr().cast(),
148 | )
149 | };
150 | Value::new(value_ref)
151 | }
152 |
153 | /// Emit a [fcmpone](https://llvm.org/docs/LangRef.html#fcmp-instruction) instruction.
154 | ///
155 | /// # Panics
156 | ///
157 | /// Panics if LLVM API returns a `null` pointer.
158 | pub fn fcmpone(&self, lhs: Value<'llvm>, rhs: Value<'llvm>) -> Value<'llvm> {
159 | debug_assert!(lhs.is_f64(), "fcmone: Expected f64 as lhs operand!");
160 | debug_assert!(rhs.is_f64(), "fcmone: Expected f64 as rhs operand!");
161 |
162 | let value_ref = unsafe {
163 | LLVMBuildFCmp(
164 | self.builder,
165 | LLVMRealPredicate::LLVMRealONE,
166 | lhs.value_ref(),
167 | rhs.value_ref(),
168 | b"fcmpone\0".as_ptr().cast(),
169 | )
170 | };
171 | Value::new(value_ref)
172 | }
173 |
174 | /// Emit a [uitofp](https://llvm.org/docs/LangRef.html#uitofp-to-instruction) instruction.
175 | ///
176 | /// # Panics
177 | ///
178 | /// Panics if LLVM API returns a `null` pointer.
179 | pub fn uitofp(&self, val: Value<'llvm>, dest_type: Type<'llvm>) -> Value<'llvm> {
180 | debug_assert!(val.is_int(), "uitofp: Expected integer operand!");
181 |
182 | let value_ref = unsafe {
183 | LLVMBuildUIToFP(
184 | self.builder,
185 | val.value_ref(),
186 | dest_type.type_ref(),
187 | b"uitofp\0".as_ptr().cast(),
188 | )
189 | };
190 | Value::new(value_ref)
191 | }
192 |
193 | /// Emit a [call](https://llvm.org/docs/LangRef.html#call-instruction) instruction.
194 | ///
195 | /// # Panics
196 | ///
197 | /// Panics if LLVM API returns a `null` pointer.
198 | pub fn call(&self, fn_value: FnValue<'llvm>, args: &mut [Value<'llvm>]) -> Value<'llvm> {
199 | let value_ref = unsafe {
200 | LLVMBuildCall2(
201 | self.builder,
202 | fn_value.fn_type(),
203 | fn_value,
204 | args.as_mut_ptr(),
205 | args.len() as libc::c_uint,
206 | b"call\0".as_ptr().cast(),
207 | )
208 | };
209 | Value::new(value_ref)
210 | }
211 |
212 | /// Emit a [ret](https://llvm.org/docs/LangRef.html#ret-instruction) instruction.
213 | ///
214 | /// # Panics
215 | ///
216 | /// Panics if LLVM API returns a `null` pointer.
217 | pub fn ret(&self, ret: Value<'llvm>) {
218 | let ret = unsafe { LLVMBuildRet(self.builder, ret.value_ref()) };
219 | assert!(!ret.is_null());
220 | }
221 |
222 | /// Emit an unconditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction.
223 | ///
224 | /// # Panics
225 | ///
226 | /// Panics if LLVM API returns a `null` pointer.
227 | pub fn br(&self, dest: BasicBlock<'llvm>) {
228 | let br_ref = unsafe { LLVMBuildBr(self.builder, dest.bb_ref()) };
229 | assert!(!br_ref.is_null());
230 | }
231 |
232 | /// Emit a conditional [br](https://llvm.org/docs/LangRef.html#br-instruction) instruction.
233 | ///
234 | /// # Panics
235 | ///
236 | /// Panics if LLVM API returns a `null` pointer.
237 | pub fn cond_br(&self, cond: Value<'llvm>, then: BasicBlock<'llvm>, else_: BasicBlock<'llvm>) {
238 | let br_ref = unsafe {
239 | LLVMBuildCondBr(
240 | self.builder,
241 | cond.value_ref(),
242 | then.bb_ref(),
243 | else_.bb_ref(),
244 | )
245 | };
246 | assert!(!br_ref.is_null());
247 | }
248 |
249 | /// Emit a [phi](https://llvm.org/docs/LangRef.html#phi-instruction) instruction.
250 | ///
251 | /// # Panics
252 | ///
253 | /// Panics if LLVM API returns a `null` pointer.
254 | pub fn phi(
255 | &self,
256 | phi_type: Type<'llvm>,
257 | incoming: &[(Value<'llvm>, BasicBlock<'llvm>)],
258 | ) -> PhiValue<'llvm> {
259 | let phi_ref =
260 | unsafe { LLVMBuildPhi(self.builder, phi_type.type_ref(), b"phi\0".as_ptr().cast()) };
261 | assert!(!phi_ref.is_null());
262 |
263 | for (val, bb) in incoming {
264 | debug_assert_eq!(
265 | val.type_of().kind(),
266 | phi_type.kind(),
267 | "Type of incoming phi value must be the same as the type used to build the phi node."
268 | );
269 |
270 | unsafe {
271 | LLVMAddIncoming(phi_ref, &mut val.value_ref() as _, &mut bb.bb_ref() as _, 1);
272 | }
273 | }
274 |
275 | PhiValue::new(phi_ref)
276 | }
277 | }
278 |
279 | impl Drop for IRBuilder<'_> {
280 | fn drop(&mut self) {
281 | unsafe { LLVMDisposeBuilder(self.builder) }
282 | }
283 | }
284 |
--------------------------------------------------------------------------------
/src/llvm/lljit.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_sys::orc2::{
6 | lljit::{
7 | LLVMOrcCreateLLJIT, LLVMOrcLLJITAddLLVMIRModuleWithRT, LLVMOrcLLJITGetGlobalPrefix,
8 | LLVMOrcLLJITGetMainJITDylib, LLVMOrcLLJITLookup, LLVMOrcLLJITRef,
9 | },
10 | LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess, LLVMOrcDefinitionGeneratorRef,
11 | LLVMOrcJITDylibAddGenerator, LLVMOrcJITDylibCreateResourceTracker, LLVMOrcJITDylibRef,
12 | LLVMOrcReleaseResourceTracker, LLVMOrcResourceTrackerRef, LLVMOrcResourceTrackerRemove,
13 | };
14 |
15 | use std::convert::TryFrom;
16 | use std::marker::PhantomData;
17 |
18 | use super::{Error, Module};
19 | use crate::SmallCStr;
20 |
/// Marker trait to constrain function signatures that can be looked up in the JIT.
///
/// Only signatures with an explicit implementation of this trait can be requested from
/// [`LLJit::find_symbol`].
pub trait JitFn {}

/// Functions without arguments returning a `f64`.
impl JitFn for unsafe extern "C" fn() -> f64 {}
25 |
26 | /// Wrapper for a LLVM [LLJIT](https://www.llvm.org/docs/ORCv2.html#lljit-and-lllazyjit).
27 | pub struct LLJit {
28 | jit: LLVMOrcLLJITRef,
29 | dylib: LLVMOrcJITDylibRef,
30 | }
31 |
32 | impl LLJit {
33 | /// Create a new LLJit instance.
34 | ///
35 | /// # Panics
36 | ///
37 | /// Panics if LLVM API returns a `null` pointer or an error.
38 | pub fn new() -> LLJit {
39 | let (jit, dylib) = unsafe {
40 | let mut jit = std::ptr::null_mut();
41 | let err = LLVMOrcCreateLLJIT(
42 | &mut jit as _,
43 | std::ptr::null_mut(), /* builder: nullptr -> default */
44 | );
45 |
46 | if let Some(err) = Error::from(err) {
47 | panic!("Error: {}", err.as_str());
48 | }
49 |
50 | let dylib = LLVMOrcLLJITGetMainJITDylib(jit);
51 | assert!(!dylib.is_null());
52 |
53 | (jit, dylib)
54 | };
55 |
56 | LLJit { jit, dylib }
57 | }
58 |
59 | /// Add an LLVM IR module to the JIT. Return a [`ResourceTracker`], which when dropped, will
60 | /// remove the code of the LLVM IR module from the JIT.
61 | ///
62 | /// # Panics
63 | ///
64 | /// Panics if LLVM API returns a `null` pointer or an error.
65 | pub fn add_module(&self, module: Module) -> ResourceTracker<'_> {
66 | let tsmod = module.into_raw_thread_safe_module();
67 |
68 | let rt = unsafe {
69 | let rt = LLVMOrcJITDylibCreateResourceTracker(self.dylib);
70 | let err = LLVMOrcLLJITAddLLVMIRModuleWithRT(self.jit, rt, tsmod);
71 |
72 | if let Some(err) = Error::from(err) {
73 | panic!("Error: {}", err.as_str());
74 | }
75 |
76 | rt
77 | };
78 |
79 | ResourceTracker::new(rt)
80 | }
81 |
82 | /// Find the symbol with the name `sym` in the JIT.
83 | ///
84 | /// # Panics
85 | ///
86 | /// Panics if the symbol is not found in the JIT.
87 | pub fn find_symbol(&self, sym: &str) -> F {
88 | let sym =
89 | SmallCStr::try_from(sym).expect("Failed to convert 'sym' argument to small C string!");
90 |
91 | unsafe {
92 | let mut addr = 0u64;
93 | let err = LLVMOrcLLJITLookup(self.jit, &mut addr as _, sym.as_ptr());
94 |
95 | if let Some(err) = Error::from(err) {
96 | panic!("Error: {}", err.as_str());
97 | }
98 |
99 | debug_assert_eq!(core::mem::size_of_val(&addr), core::mem::size_of::());
100 | std::mem::transmute_copy(&addr)
101 | }
102 | }
103 |
104 | /// Enable lookup of dynamic symbols available in the current process from the JIT.
105 | ///
106 | /// # Panics
107 | ///
108 | /// Panics if LLVM API returns an error.
109 | pub fn enable_process_symbols(&self) {
110 | unsafe {
111 | let mut proc_syms_gen: LLVMOrcDefinitionGeneratorRef = std::ptr::null_mut();
112 | let err = LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
113 | &mut proc_syms_gen as _,
114 | self.global_prefix(),
115 | None, /* filter */
116 | std::ptr::null_mut(), /* filter ctx */
117 | );
118 |
119 | if let Some(err) = Error::from(err) {
120 | panic!("Error: {}", err.as_str());
121 | }
122 |
123 | LLVMOrcJITDylibAddGenerator(self.dylib, proc_syms_gen);
124 | }
125 | }
126 |
127 | /// Return the global prefix character according to the LLJITs data layout.
128 | fn global_prefix(&self) -> libc::c_char {
129 | unsafe { LLVMOrcLLJITGetGlobalPrefix(self.jit) }
130 | }
131 | }
132 |
133 | /// A resource handle for code added to an [`LLJit`] instance.
134 | ///
135 | /// When a `ResourceTracker` handle is dropped, the code corresponding to the handle will be
136 | /// removed from the JIT.
137 | pub struct ResourceTracker<'jit>(LLVMOrcResourceTrackerRef, PhantomData<&'jit ()>);
138 |
139 | impl<'jit> ResourceTracker<'jit> {
140 | fn new(rt: LLVMOrcResourceTrackerRef) -> ResourceTracker<'jit> {
141 | assert!(!rt.is_null());
142 | ResourceTracker(rt, PhantomData)
143 | }
144 | }
145 |
146 | impl Drop for ResourceTracker<'_> {
147 | fn drop(&mut self) {
148 | unsafe {
149 | let err = LLVMOrcResourceTrackerRemove(self.0);
150 |
151 | if let Some(err) = Error::from(err) {
152 | panic!("Error: {}", err.as_str());
153 | }
154 |
155 | LLVMOrcReleaseResourceTracker(self.0);
156 | };
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
/src/llvm/mod.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | //! Safe wrapper around the LLVM C API.
6 | //!
7 | //! References returned from the LLVM API are tied to the `'llvm` lifetime which is bound to the
8 | //! context where the objects are created in.
9 | //! We do not offer wrappers to remove or delete any objects in the context and therefore all the
10 | //! references will be valid for the lifetime of the context.
11 | //!
12 | //! For the scope of this tutorial we mainly use assertions to validate the results from the LLVM
13 | //! API calls.
14 |
15 | use llvm_sys::{
16 | core::LLVMShutdown,
17 | error::{LLVMDisposeErrorMessage, LLVMErrorRef, LLVMGetErrorMessage},
18 | target::{
19 | LLVM_InitializeNativeAsmParser, LLVM_InitializeNativeAsmPrinter,
20 | LLVM_InitializeNativeTarget,
21 | },
22 | };
23 |
24 | use std::ffi::CStr;
25 |
26 | mod basic_block;
27 | mod builder;
28 | mod lljit;
29 | mod module;
30 | mod pass_manager;
31 | mod type_;
32 | mod value;
33 |
34 | pub use basic_block::BasicBlock;
35 | pub use builder::IRBuilder;
36 | pub use lljit::{LLJit, ResourceTracker};
37 | pub use module::Module;
38 | pub use pass_manager::FunctionPassManager;
39 | pub use type_::Type;
40 | pub use value::{FnValue, PhiValue, Value};
41 |
42 | struct Error<'llvm>(&'llvm mut libc::c_char);
43 |
44 | impl<'llvm> Error<'llvm> {
45 | fn from(err: LLVMErrorRef) -> Option> {
46 | (!err.is_null()).then(|| Error(unsafe { &mut *LLVMGetErrorMessage(err) }))
47 | }
48 |
49 | fn as_str(&self) -> &str {
50 | unsafe { CStr::from_ptr(self.0) }
51 | .to_str()
52 | .expect("Expected valid UTF8 string from LLVM API")
53 | }
54 | }
55 |
56 | impl Drop for Error<'_> {
57 | fn drop(&mut self) {
58 | unsafe {
59 | LLVMDisposeErrorMessage(self.0 as *mut libc::c_char);
60 | }
61 | }
62 | }
63 |
64 | /// Initialize native target for corresponding to host (useful for jitting).
65 | pub fn initialize_native_taget() {
66 | unsafe {
67 | assert_eq!(LLVM_InitializeNativeTarget(), 0);
68 | assert_eq!(LLVM_InitializeNativeAsmParser(), 0);
69 | assert_eq!(LLVM_InitializeNativeAsmPrinter(), 0);
70 | }
71 | }
72 |
73 | /// Deallocate and destroy all "ManagedStatic" variables.
74 | pub fn shutdown() {
75 | unsafe {
76 | LLVMShutdown();
77 | };
78 | }
79 |
--------------------------------------------------------------------------------
/src/llvm/module.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_sys::{
6 | core::{
7 | LLVMAddFunction, LLVMAppendBasicBlockInContext, LLVMCreateBasicBlockInContext,
8 | LLVMDisposeModule, LLVMDoubleTypeInContext, LLVMDumpModule, LLVMGetNamedFunction,
9 | LLVMModuleCreateWithNameInContext,
10 | },
11 | orc2::{
12 | LLVMOrcCreateNewThreadSafeContext, LLVMOrcCreateNewThreadSafeModule,
13 | LLVMOrcDisposeThreadSafeContext, LLVMOrcThreadSafeContextGetContext,
14 | LLVMOrcThreadSafeContextRef, LLVMOrcThreadSafeModuleRef,
15 | },
16 | prelude::{LLVMBool, LLVMContextRef, LLVMModuleRef, LLVMTypeRef},
17 | LLVMTypeKind,
18 | };
19 |
20 | use std::convert::TryFrom;
21 |
22 | use super::{BasicBlock, FnValue, Type};
23 | use crate::SmallCStr;
24 |
25 | // Definition of LLVM C API functions using our `repr(transparent)` types.
26 | extern "C" {
27 | fn LLVMFunctionType(
28 | ReturnType: Type<'_>,
29 | ParamTypes: *mut Type<'_>,
30 | ParamCount: ::libc::c_uint,
31 | IsVarArg: LLVMBool,
32 | ) -> LLVMTypeRef;
33 | }
34 |
35 | /// Wrapper for a LLVM Module with its own LLVM Context.
36 | pub struct Module {
37 | tsctx: LLVMOrcThreadSafeContextRef,
38 | ctx: LLVMContextRef,
39 | module: LLVMModuleRef,
40 | }
41 |
42 | impl<'llvm> Module {
43 | /// Create a new Module instance.
44 | ///
45 | /// # Panics
46 | ///
47 | /// Panics if creating the context or the module fails.
48 | pub fn new() -> Self {
49 | let (tsctx, ctx, module) = unsafe {
50 | // We generate a thread safe context because we are going to jit this IR module and
51 | // there is no method to create a thread safe context wrapper from an existing context
52 | // reference (at the time of writing this).
53 | //
54 | // ThreadSafeContext has shared ownership (start with ref count 1).
55 | // We must explicitly dispose our reference (dec ref count).
56 | let tc = LLVMOrcCreateNewThreadSafeContext();
57 | assert!(!tc.is_null());
58 |
59 | let c = LLVMOrcThreadSafeContextGetContext(tc);
60 | let m = LLVMModuleCreateWithNameInContext(b"module\0".as_ptr().cast(), c);
61 | assert!(!c.is_null() && !m.is_null());
62 | (tc, c, m)
63 | };
64 |
65 | Module { tsctx, ctx, module }
66 | }
67 |
68 | /// Get the raw LLVM context reference.
69 | #[inline]
70 | pub(super) fn ctx(&self) -> LLVMContextRef {
71 | self.ctx
72 | }
73 |
74 | /// Get the raw LLVM module reference.
75 | #[inline]
76 | pub(super) fn module(&self) -> LLVMModuleRef {
77 | self.module
78 | }
79 |
80 | /// Consume the module and turn in into a raw LLVM ThreadSafeModule reference.
81 | ///
82 | /// If ownership of the raw reference is not transferred (eg to the JIT), memory will be leaked
83 | /// in case the reference is disposed explicitly with LLVMOrcDisposeThreadSafeModule.
84 | #[inline]
85 | pub(super) fn into_raw_thread_safe_module(mut self) -> LLVMOrcThreadSafeModuleRef {
86 | let m = std::mem::replace(&mut self.module, std::ptr::null_mut());
87 |
88 | // ThreadSafeModule has unique ownership.
89 | // Takes ownership of module and increments ThreadSafeContext ref count.
90 | //
91 | // We must not reference/dispose `m` after this call, but we need to dispose our `tsctx`
92 | // reference.
93 | let tm = unsafe { LLVMOrcCreateNewThreadSafeModule(m, self.tsctx) };
94 | assert!(!tm.is_null());
95 |
96 | tm
97 | }
98 |
99 | /// Dump LLVM IR emitted into the Module to stdout.
100 | pub fn dump(&self) {
101 | unsafe { LLVMDumpModule(self.module) };
102 | }
103 |
104 | /// Get a type reference representing a `f64` float.
105 | ///
106 | /// # Panics
107 | ///
108 | /// Panics if LLVM API returns a `null` pointer.
109 | pub fn type_f64(&self) -> Type<'llvm> {
110 | let type_ref = unsafe { LLVMDoubleTypeInContext(self.ctx) };
111 | Type::new(type_ref)
112 | }
113 |
114 | /// Get a type reference representing a `fn(args) -> ret` function.
115 | ///
116 | /// # Panics
117 | ///
118 | /// Panics if LLVM API returns a `null` pointer.
119 | pub fn type_fn(&'llvm self, args: &mut [Type<'llvm>], ret: Type<'llvm>) -> Type<'llvm> {
120 | let type_ref = unsafe {
121 | LLVMFunctionType(
122 | ret,
123 | args.as_mut_ptr(),
124 | args.len() as libc::c_uint,
125 | 0, /* IsVarArg */
126 | )
127 | };
128 | Type::new(type_ref)
129 | }
130 |
131 | /// Add a function with the given `name` and `fn_type` to the module and return a value
132 | /// reference representing the function.
133 | ///
134 | /// # Panics
135 | ///
136 | /// Panics if LLVM API returns a `null` pointer or `name` could not be converted to a
137 | /// [`SmallCStr`].
138 | pub fn add_fn(&'llvm self, name: &str, fn_type: Type<'llvm>) -> FnValue<'llvm> {
139 | debug_assert_eq!(
140 | fn_type.kind(),
141 | LLVMTypeKind::LLVMFunctionTypeKind,
142 | "Expected a function type when adding a function!"
143 | );
144 |
145 | let name = SmallCStr::try_from(name)
146 | .expect("Failed to convert 'name' argument to small C string!");
147 |
148 | let value_ref = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type.type_ref()) };
149 | FnValue::new(value_ref)
150 | }
151 |
152 | /// Get a function value reference to the function with the given `name` if it was previously
153 | /// added to the module with [`add_fn`][Module::add_fn].
154 | ///
155 | /// # Panics
156 | ///
157 | /// Panics if `name` could not be converted to a [`SmallCStr`].
158 | pub fn get_fn(&'llvm self, name: &str) -> Option> {
159 | let name = SmallCStr::try_from(name)
160 | .expect("Failed to convert 'name' argument to small C string!");
161 |
162 | let value_ref = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
163 |
164 | (!value_ref.is_null()).then(|| FnValue::new(value_ref))
165 | }
166 |
167 | /// Append a Basic Block to the end of the function referenced by the value reference
168 | /// `fn_value`.
169 | ///
170 | /// # Panics
171 | ///
172 | /// Panics if LLVM API returns a `null` pointer.
173 | pub fn append_basic_block(&'llvm self, fn_value: FnValue<'llvm>) -> BasicBlock<'llvm> {
174 | let block = unsafe {
175 | LLVMAppendBasicBlockInContext(
176 | self.ctx,
177 | fn_value.value_ref(),
178 | b"block\0".as_ptr().cast(),
179 | )
180 | };
181 | assert!(!block.is_null());
182 |
183 | BasicBlock::new(block)
184 | }
185 |
186 | /// Create a free-standing Basic Block without adding it to a function.
187 | /// This can be added to a function at a later point in time with
188 | /// [`FnValue::append_basic_block`].
189 | ///
190 | /// # Panics
191 | ///
192 | /// Panics if LLVM API returns a `null` pointer.
193 | pub fn create_basic_block(&self) -> BasicBlock<'llvm> {
194 | let block = unsafe { LLVMCreateBasicBlockInContext(self.ctx, b"block\0".as_ptr().cast()) };
195 | assert!(!block.is_null());
196 |
197 | BasicBlock::new(block)
198 | }
199 | }
200 |
201 | impl Drop for Module {
202 | fn drop(&mut self) {
203 | unsafe {
204 | // In case we turned the module into a ThreadSafeModule, we must not dispose the module
205 | // reference because ThreadSafeModule took ownership!
206 | if !self.module.is_null() {
207 | LLVMDisposeModule(self.module);
208 | }
209 |
210 | // Dispose ThreadSafeContext reference (dec ref count) in any case.
211 | LLVMOrcDisposeThreadSafeContext(self.tsctx);
212 | }
213 | }
214 | }
215 |
--------------------------------------------------------------------------------
/src/llvm/pass_manager.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_sys::{
6 | core::{
7 | LLVMCreateFunctionPassManagerForModule, LLVMDisposePassManager,
8 | LLVMInitializeFunctionPassManager, LLVMRunFunctionPassManager,
9 | },
10 | prelude::LLVMPassManagerRef,
11 | transforms::{
12 | instcombine::LLVMAddInstructionCombiningPass,
13 | scalar::{LLVMAddCFGSimplificationPass, LLVMAddNewGVNPass, LLVMAddReassociatePass},
14 | },
15 | };
16 |
17 | use std::marker::PhantomData;
18 |
19 | use super::{FnValue, Module};
20 |
21 | /// Wrapper for a LLVM Function PassManager (legacy).
22 | pub struct FunctionPassManager<'llvm> {
23 | fpm: LLVMPassManagerRef,
24 | _ctx: PhantomData<&'llvm ()>,
25 | }
26 |
27 | impl<'llvm> FunctionPassManager<'llvm> {
28 | /// Create a new Function PassManager with the following optimization passes
29 | /// - InstructionCombiningPass
30 | /// - ReassociatePass
31 | /// - NewGVNPass
32 | /// - CFGSimplificationPass
33 | ///
34 | /// The list of selected optimization passes is taken from the tutorial chapter [LLVM
35 | /// Optimization Passes](https://llvm.org/docs/tutorial/MyFirstLanguageFrontend/LangImpl04.html#id3).
36 | pub fn with_ctx(module: &'llvm Module) -> FunctionPassManager<'llvm> {
37 | let fpm = unsafe {
38 | // Borrows module reference.
39 | LLVMCreateFunctionPassManagerForModule(module.module())
40 | };
41 | assert!(!fpm.is_null());
42 |
43 | unsafe {
44 | // Do simple "peephole" optimizations and bit-twiddling optzns.
45 | LLVMAddInstructionCombiningPass(fpm);
46 | // Reassociate expressions.
47 | LLVMAddReassociatePass(fpm);
48 | // Eliminate Common SubExpressions.
49 | LLVMAddNewGVNPass(fpm);
50 | // Simplify the control flow graph (deleting unreachable blocks, etc).
51 | LLVMAddCFGSimplificationPass(fpm);
52 |
53 | let fail = LLVMInitializeFunctionPassManager(fpm);
54 | assert_eq!(fail, 0);
55 | }
56 |
57 | FunctionPassManager {
58 | fpm,
59 | _ctx: PhantomData,
60 | }
61 | }
62 |
63 | /// Run the optimization passes registered with the Function PassManager on the function
64 | /// referenced by `fn_value`.
65 | pub fn run(&'llvm self, fn_value: FnValue<'llvm>) {
66 | unsafe {
67 | // Returns 1 if any of the passes modified the function, false otherwise.
68 | LLVMRunFunctionPassManager(self.fpm, fn_value.value_ref());
69 | }
70 | }
71 | }
72 |
73 | impl Drop for FunctionPassManager<'_> {
74 | fn drop(&mut self) {
75 | unsafe {
76 | LLVMDisposePassManager(self.fpm);
77 | }
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/src/llvm/type_.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_sys::{
6 | core::{LLVMConstReal, LLVMDumpType, LLVMGetTypeKind},
7 | prelude::LLVMTypeRef,
8 | LLVMTypeKind,
9 | };
10 |
11 | use std::marker::PhantomData;
12 |
13 | use super::Value;
14 |
15 | /// Wrapper for a LLVM Type Reference.
16 | #[derive(Copy, Clone)]
17 | #[repr(transparent)]
18 | pub struct Type<'llvm>(LLVMTypeRef, PhantomData<&'llvm ()>);
19 |
20 | impl<'llvm> Type<'llvm> {
21 | /// Create a new Type instance.
22 | ///
23 | /// # Panics
24 | ///
25 | /// Panics if `type_ref` is a null pointer.
26 | pub(super) fn new(type_ref: LLVMTypeRef) -> Self {
27 | assert!(!type_ref.is_null());
28 | Type(type_ref, PhantomData)
29 | }
30 |
31 | /// Get the raw LLVM type reference.
32 | #[inline]
33 | pub(super) fn type_ref(&self) -> LLVMTypeRef {
34 | self.0
35 | }
36 |
37 | /// Get the LLVM type kind for the given type reference.
38 | pub(super) fn kind(&self) -> LLVMTypeKind {
39 | unsafe { LLVMGetTypeKind(self.type_ref()) }
40 | }
41 |
42 | /// Dump the LLVM Type to stdout.
43 | pub fn dump(&self) {
44 | unsafe { LLVMDumpType(self.type_ref()) };
45 | }
46 |
47 | /// Get a value reference representing the const `f64` value.
48 | ///
49 | /// # Panics
50 | ///
51 | /// Panics if LLVM API returns a `null` pointer.
52 | pub fn const_f64(self, n: f64) -> Value<'llvm> {
53 | debug_assert_eq!(
54 | self.kind(),
55 | LLVMTypeKind::LLVMDoubleTypeKind,
56 | "Expected a double type when creating const f64 value!"
57 | );
58 |
59 | let value_ref = unsafe { LLVMConstReal(self.type_ref(), n) };
60 | Value::new(value_ref)
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/src/llvm/value.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | #![allow(unused)]
6 |
7 | use llvm_sys::{
8 | analysis::{LLVMVerifierFailureAction, LLVMVerifyFunction},
9 | core::{
10 | LLVMAddIncoming, LLVMAppendExistingBasicBlock, LLVMCountBasicBlocks, LLVMCountParams,
11 | LLVMDumpValue, LLVMGetParam, LLVMGetValueKind, LLVMGetValueName2, LLVMGlobalGetValueType,
12 | LLVMIsAFunction, LLVMIsAPHINode, LLVMSetValueName2, LLVMTypeOf,
13 | },
14 | prelude::LLVMValueRef,
15 | LLVMTypeKind, LLVMValueKind,
16 | };
17 | use std::ffi::CStr;
18 | use std::marker::PhantomData;
19 | use std::ops::Deref;
20 |
21 | use super::BasicBlock;
22 | use super::Type;
23 |
24 | /// Wrapper for a LLVM Value Reference.
25 | #[derive(Copy, Clone)]
26 | #[repr(transparent)]
27 | pub struct Value<'llvm>(LLVMValueRef, PhantomData<&'llvm ()>);
28 |
29 | impl<'llvm> Value<'llvm> {
30 |     /// Create a new Value instance.
31 |     ///
32 |     /// # Panics
33 |     ///
34 |     /// Panics if `value_ref` is a null pointer.
35 |     pub(super) fn new(value_ref: LLVMValueRef) -> Self {
36 |         assert!(!value_ref.is_null());
37 |         Value(value_ref, PhantomData)
38 |     }
39 |
40 |     /// Get the raw LLVM value reference.
41 |     #[inline]
42 |     pub(super) fn value_ref(&self) -> LLVMValueRef {
43 |         self.0
44 |     }
45 |
46 |     /// Get the LLVM value kind for the given value reference.
47 |     pub(super) fn kind(&self) -> LLVMValueKind {
48 |         unsafe { LLVMGetValueKind(self.value_ref()) }
49 |     }
50 |
51 |     /// Check if value is `function` type.
52 |     pub(super) fn is_function(&self) -> bool {
53 |         let cast = unsafe { LLVMIsAFunction(self.value_ref()) };
54 |         !cast.is_null()
55 |     }
56 |
57 |     /// Check if value is `phinode` type.
58 |     pub(super) fn is_phinode(&self) -> bool {
59 |         let cast = unsafe { LLVMIsAPHINode(self.value_ref()) };
60 |         !cast.is_null()
61 |     }
62 |
63 |     /// Dump the LLVM Value to stderr (via `LLVMDumpValue`).
64 |     pub fn dump(&self) {
65 |         unsafe { LLVMDumpValue(self.value_ref()) };
66 |     }
67 |
68 |     /// Get a type reference representing the type of the given value reference.
69 |     ///
70 |     /// # Panics
71 |     ///
72 |     /// Panics if LLVM API returns a `null` pointer.
73 |     pub fn type_of(&self) -> Type<'llvm> {
74 |         let type_ref = unsafe { LLVMTypeOf(self.value_ref()) };
75 |         Type::new(type_ref)
76 |     }
77 |
78 |     /// Set the name for the given value reference.
79 |     ///
80 |     /// `name` is passed to LLVM as pointer + length, so it does not have to be
81 |     /// NUL-terminated (a Rust `&str` is not guaranteed to be). No result is
82 |     /// checked, so this function does not panic.
83 |     pub fn set_name(&self, name: &str) {
84 |         unsafe { LLVMSetValueName2(self.value_ref(), name.as_ptr().cast(), name.len()) };
85 |     }
86 |
87 |     /// Get the name for the given value reference.
88 |     ///
89 |     /// # Panics
90 |     ///
91 |     /// Panics if LLVM API returns a `null` pointer or the name is not valid UTF-8.
92 |     pub fn get_name(&self) -> &'llvm str {
93 |         let name = unsafe {
94 |             let mut len: libc::size_t = 0;
95 |             let name = LLVMGetValueName2(self.value_ref(), &mut len as _);
96 |             assert!(!name.is_null());
97 |
98 |             // Honour the length reported by LLVM instead of re-scanning for a NUL byte.
99 |             std::slice::from_raw_parts(name.cast::<u8>(), len)
100 |         };
101 |
102 |         // TODO: Does this string live for the time of the LLVM context?!
103 |         std::str::from_utf8(name).expect("Expected valid UTF8 string from LLVM API")
104 |     }
105 |
106 |     /// Check if value is of `f64` type.
107 |     pub fn is_f64(&self) -> bool {
108 |         self.type_of().kind() == LLVMTypeKind::LLVMDoubleTypeKind
109 |     }
110 |
111 |     /// Check if value is of integer type.
112 |     pub fn is_int(&self) -> bool {
113 |         self.type_of().kind() == LLVMTypeKind::LLVMIntegerTypeKind
114 |     }
115 | }
116 |
117 | /// Function-flavoured [`Value`]: same representation, used where a function value is required.
118 | #[repr(transparent)]
119 | #[derive(Clone, Copy)]
120 | pub struct FnValue<'llvm>(Value<'llvm>);
121 |
122 | impl<'llvm> Deref for FnValue<'llvm> {
123 |     type Target = Value<'llvm>;
124 |
125 |     /// Borrow the wrapped generic [`Value`], making its methods available on `FnValue`.
126 |     fn deref(&self) -> &Value<'llvm> {
127 |         let Self(inner) = self;
128 |         inner
129 |     }
130 | }
128 |
129 | impl<'llvm> FnValue<'llvm> {
130 |     /// Wrap `value_ref` as a function value.
131 |     ///
132 |     /// # Panics
133 |     ///
134 |     /// Panics if `value_ref` is a null pointer. Debug builds additionally panic if the value
135 |     /// is not an LLVM function.
136 |     pub(super) fn new(value_ref: LLVMValueRef) -> Self {
137 |         let inner = Value::new(value_ref);
138 |         debug_assert!(
139 |             inner.is_function(),
140 |             "Expected a fn value when constructing FnValue!"
141 |         );
142 |         Self(inner)
143 |     }
144 |
145 |     /// Function type (return type plus argument types) of this function value.
146 |     ///
147 |     /// # Panics
148 |     ///
149 |     /// Panics if LLVM API returns a `null` pointer.
150 |     pub fn fn_type(&self) -> Type<'llvm> {
151 |         // Background for using `LLVMGlobalGetValueType` here:
152 |         // https://github.com/llvm/llvm-project/issues/72798
153 |         Type::new(unsafe { LLVMGlobalGetValueType(self.value_ref()) })
154 |     }
155 |
156 |     /// Number of arguments declared by this function value.
157 |     pub fn args(&self) -> usize {
158 |         let count = unsafe { LLVMCountParams(self.value_ref()) };
159 |         count as usize
160 |     }
161 |
162 |     /// Value representing the function argument at position `idx`.
163 |     ///
164 |     /// # Panics
165 |     ///
166 |     /// Panics if `idx` is out of bounds or LLVM API returns a `null` pointer.
167 |     pub fn arg(&self, idx: usize) -> Value<'llvm> {
168 |         assert!(idx < self.args());
169 |         // `idx < args()` and `args()` stems from a `c_uint`, hence the cast is lossless.
170 |         let param = unsafe { LLVMGetParam(self.value_ref(), idx as libc::c_uint) };
171 |         Value::new(param)
172 |     }
173 |
174 |     /// Number of basic blocks currently contained in this function value.
175 |     pub fn basic_blocks(&self) -> usize {
176 |         let count = unsafe { LLVMCountBasicBlocks(self.value_ref()) };
177 |         count as usize
178 |     }
179 |
180 |     /// Append the already created basic block `bb` at the end of this function value.
181 |     pub fn append_basic_block(&self, bb: BasicBlock<'llvm>) {
182 |         let (func, block) = (self.value_ref(), bb.bb_ref());
183 |         unsafe { LLVMAppendExistingBasicBlock(func, block) };
184 |     }
185 |
186 |     /// Run the LLVM verifier on this function value.
187 |     ///
188 |     /// Returns `true` if the function is valid. Verifier diagnostics are printed
189 |     /// (`LLVMPrintMessageAction`) rather than aborting the process.
190 |     pub fn verify(&self) -> bool {
191 |         let action = LLVMVerifierFailureAction::LLVMPrintMessageAction;
192 |         let status = unsafe { LLVMVerifyFunction(self.value_ref(), action) };
193 |         status == 0
194 |     }
195 | }
196 |
197 | /// Phi-node-flavoured [`Value`]: same representation, used where a phi node is required.
198 | #[repr(transparent)]
199 | #[derive(Clone, Copy)]
200 | pub struct PhiValue<'llvm>(Value<'llvm>);
201 |
202 | impl<'llvm> Deref for PhiValue<'llvm> {
203 |     type Target = Value<'llvm>;
204 |
205 |     /// Borrow the wrapped generic [`Value`], making its methods available on `PhiValue`.
206 |     fn deref(&self) -> &Value<'llvm> {
207 |         let Self(inner) = self;
208 |         inner
209 |     }
210 | }
208 |
209 | impl<'llvm> PhiValue<'llvm> {
210 |     /// Wrap `value_ref` as a phi node value.
211 |     ///
212 |     /// # Panics
213 |     ///
214 |     /// Panics if `value_ref` is a null pointer. Debug builds additionally panic if the value
215 |     /// is not a phi node.
216 |     pub(super) fn new(value_ref: LLVMValueRef) -> Self {
217 |         let inner = Value::new(value_ref);
218 |         debug_assert!(
219 |             inner.is_phinode(),
220 |             "Expected a phinode value when constructing PhiValue!"
221 |         );
222 |         Self(inner)
223 |     }
224 |
225 |     /// Append `ival`, flowing in from basic block `ibb`, to the end of this phi node's
226 |     /// incoming list.
227 |     ///
228 |     /// Debug builds assert that `ival` has the same type kind as the phi node.
229 |     pub fn add_incoming(&self, ival: Value<'llvm>, ibb: BasicBlock<'llvm>) {
230 |         debug_assert_eq!(
231 |             ival.type_of().kind(),
232 |             self.type_of().kind(),
233 |             "Type of incoming phi value must be the same as the type used to build the phi node."
234 |         );
235 |
236 |         // LLVM takes parallel arrays of values and blocks; a single pair is added here.
237 |         let mut values = [ival.value_ref()];
238 |         let mut blocks = [ibb.bb_ref()];
239 |         unsafe {
240 |             LLVMAddIncoming(self.value_ref(), values.as_mut_ptr(), blocks.as_mut_ptr(), 1);
241 |         }
242 |     }
243 | }
243 |
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use llvm_kaleidoscope_rs::{
6 | codegen::Codegen,
7 | lexer::{Lexer, Token},
8 | llvm,
9 | parser::{Parser, PrototypeAST},
10 | Either,
11 | };
12 |
13 | use std::collections::HashMap;
14 | use std::io::{Read, Write};
15 |
16 | /// Write `c`, converted to a single byte, to stdout and return `0`.
17 | ///
18 | /// Exported unmangled with the C ABI so JIT-compiled Kaleidoscope code can call it through
19 | /// an `extern putchard(x)` declaration (process symbols are enabled in `main_loop`).
20 | ///
21 | /// The `as u8` cast truncates fractions and saturates: values below 0 (and NaN) become 0,
22 | /// values above 255 become 255.
23 | #[no_mangle]
24 | #[inline(never)]
25 | pub extern "C" fn putchard(c: libc::c_double) -> f64 {
26 |     // `write_all` instead of `write`: `write` may accept fewer bytes than given (even 0),
27 |     // which would silently drop the character.
28 |     std::io::stdout()
29 |         .write_all(&[c as u8])
30 |         .expect("Failed to write to stdout!");
31 |     0f64
32 | }
24 |
25 | /// Main REPL loop: parse top-level items from `parser`, emit LLVM IR for them and hand the
26 | /// resulting modules to the JIT.
27 | ///
28 | /// Every `def` and every top-level expression is compiled into its own module, which is added
29 | /// to the JIT before a fresh module is created. Top-level expressions are executed right away
30 | /// by calling the `__anon_expr` symbol.
31 | fn main_loop<I>(mut parser: Parser<I>)
32 | where
33 |     I: Iterator<Item = char>,
34 | {
35 |     // Initialize LLVM module with its own context.
36 |     // We will emit LLVM IR into this module.
37 |     let mut module = llvm::Module::new();
38 |
39 |     // Create a new JIT, based on the LLVM LLJIT.
40 |     let jit = llvm::LLJit::new();
41 |
42 |     // Enable lookup of dynamic symbols in the current process from the JIT.
43 |     jit.enable_process_symbols();
44 |
45 |     // Keep track of prototype names to their respective ASTs.
46 |     //
47 |     // This is useful since we jit every function definition into its own LLVM module.
48 |     // To allow calling functions defined in previous LLVM modules we keep track of their
49 |     // prototypes and generate IR for their declarations when they are called from another module.
50 |     let mut fn_protos: HashMap<String, PrototypeAST> = HashMap::new();
51 |
52 |     // When adding an IR module to the JIT, it will hand out a ResourceTracker. When the
53 |     // ResourceTracker is dropped, the code generated from the corresponding module will be removed
54 |     // from the JIT.
55 |     //
56 |     // For each function we want to keep the code generated for the last definition, hence we need
57 |     // to keep their ResourceTracker alive.
58 |     let mut fn_jit_rt: HashMap<String, _> = HashMap::new();
59 |
60 |     loop {
61 |         match parser.cur_tok() {
62 |             Token::Eof => break,
63 |             Token::Char(';') => {
64 |                 // Ignore top-level semicolon.
65 |                 parser.get_next_token();
66 |             }
67 |             Token::Def => match parser.parse_definition() {
68 |                 Ok(func) => {
69 |                     println!("Parse 'def'");
70 |                     let func_name = &func.0 .0;
71 |
72 |                     if let Ok(func_ir) = Codegen::compile(&module, &mut fn_protos, Either::B(&func))
73 |                     {
74 |                         func_ir.dump();
75 |
76 |                         // If we already jitted that function, remove the last definition from
77 |                         // the JIT by dropping the corresponding ResourceTracker. This happens
78 |                         // only after the new definition compiled, so a failed redefinition keeps
79 |                         // the previous code alive.
80 |                         fn_jit_rt.remove(func_name);
81 |
82 |                         // Add module to the JIT.
83 |                         let rt = jit.add_module(module);
84 |
85 |                         // Keep track of the ResourceTracker to keep the module code in the JIT.
86 |                         fn_jit_rt.insert(func_name.to_string(), rt);
87 |
88 |                         // Initialize a new module.
89 |                         module = llvm::Module::new();
90 |                     }
91 |                 }
92 |                 Err(err) => {
93 |                     eprintln!("Error: {:?}", err);
94 |                     parser.get_next_token();
95 |                 }
96 |             },
97 |             Token::Extern => match parser.parse_extern() {
98 |                 Ok(proto) => {
99 |                     println!("Parse 'extern'");
100 |                     if let Ok(proto_ir) =
101 |                         Codegen::compile(&module, &mut fn_protos, Either::A(&proto))
102 |                     {
103 |                         proto_ir.dump();
104 |
105 |                         // Keep track of external function declaration.
106 |                         fn_protos.insert(proto.0.clone(), proto);
107 |                     }
108 |                 }
109 |                 Err(err) => {
110 |                     eprintln!("Error: {:?}", err);
111 |                     parser.get_next_token();
112 |                 }
113 |             },
114 |             _ => match parser.parse_top_level_expr() {
115 |                 Ok(func) => {
116 |                     println!("Parse top-level expression");
117 |                     if let Ok(func) = Codegen::compile(&module, &mut fn_protos, Either::B(&func)) {
118 |                         func.dump();
119 |
120 |                         // Add module to the JIT. Binding to `_rt` (not `_`) keeps the tracker,
121 |                         // and thus the code, alive until the end of this block, i.e. past the
122 |                         // call below.
123 |                         let _rt = jit.add_module(module);
124 |
125 |                         // Initialize a new module.
126 |                         module = llvm::Module::new();
127 |
128 |                         // Call the top level expression.
129 |                         let fp = jit.find_symbol::<unsafe extern "C" fn() -> f64>("__anon_expr");
130 |                         unsafe {
131 |                             println!("Evaluated to {}", fp());
132 |                         }
133 |                     }
134 |                 }
135 |                 Err(err) => {
136 |                     eprintln!("Error: {:?}", err);
137 |                     parser.get_next_token();
138 |                 }
139 |             },
140 |         };
141 |     }
142 |
143 |     // Dump the IR left in the current module. Modules handed to the JIT above were replaced
144 |     // by fresh ones, so this only holds IR emitted after the last JIT submission (e.g. for
145 |     // `extern` declarations).
146 |     module.dump();
147 | }
136 |
137 | /// Build a parser on top of `lexer`, initialize the native target and run the REPL loop.
138 | ///
139 | /// LLVM's managed static data is shut down once the loop returns.
140 | fn run_kaleidoscope<I>(lexer: Lexer<I>)
141 | where
142 |     I: Iterator<Item = char>,
143 | {
144 |     // Create parser for kaleidoscope.
145 |     let mut parser = Parser::new(lexer);
146 |
147 |     // Read the first token into `cur_tok`.
148 |     parser.get_next_token();
149 |
150 |     // Initialize native target for jitting.
151 |     llvm::initialize_native_taget();
152 |
153 |     main_loop(parser);
154 |
155 |     // De-allocate managed static LLVM data.
156 |     llvm::shutdown();
157 | }
155 |
156 | /// Turn a byte reader into the stream of `char`s consumed by the lexer.
157 | ///
158 | /// Each byte maps to the `char` with the same code point (input is not UTF-8 decoded).
159 | /// Iteration stops at the first read error instead of skipping it: `Read::bytes` keeps
160 | /// yielding errors from a reader that keeps failing (e.g. a directory opened as a file),
161 | /// so skipping them could spin forever.
162 | fn chars_of<R: Read>(reader: R) -> impl Iterator<Item = char> {
163 |     reader.bytes().map_while(Result::ok).map(char::from)
164 | }
165 |
166 | fn main() {
167 |     match std::env::args().nth(1) {
168 |         Some(file) => {
169 |             println!("Parse {}.", file);
170 |
171 |             // Create lexer over file. The file is wrapped in a `BufReader`, otherwise every
172 |             // byte produced by `bytes()` would cost a separate read system call.
173 |             let handle = std::fs::File::open(&file)
174 |                 .unwrap_or_else(|err| panic!("Failed to open file {}!: {:?}", file, err));
175 |             run_kaleidoscope(Lexer::new(chars_of(std::io::BufReader::new(handle))));
176 |         }
177 |         None => {
178 |             println!("Parse stdin.");
179 |             println!("ENTER to parse current input.");
180 |             println!("C-d to exit.");
181 |
182 |             // Create lexer over stdin.
183 |             run_kaleidoscope(Lexer::new(chars_of(std::io::stdin())));
184 |         }
185 |     }
186 | }
187 |
--------------------------------------------------------------------------------
/src/parser.rs:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: MIT
2 | //
3 | // Copyright (c) 2021, Johannes Stoelp
4 |
5 | use crate::lexer::{Lexer, Token};
6 |
7 | /// Expression node of the Kaleidoscope AST.
8 | #[derive(Debug, PartialEq)]
9 | pub enum ExprAST {
10 |     /// Number - Expression class for numeric literals like "1.0".
11 |     Number(f64),
12 |
13 |     /// Variable - Expression class for referencing a variable, like "a".
14 |     Variable(String),
15 |
16 |     /// Binary - Expression class for a binary operator: (operator, lhs, rhs).
17 |     Binary(char, Box<ExprAST>, Box<ExprAST>),
18 |
19 |     /// Call - Expression class for function calls: (callee name, arguments).
20 |     Call(String, Vec<ExprAST>),
21 |
22 |     /// If - Expression class for if/then/else.
23 |     If {
24 |         cond: Box<ExprAST>,
25 |         then: Box<ExprAST>,
26 |         else_: Box<ExprAST>,
27 |     },
28 |
29 |     /// ForExprAST - Expression class for for/in.
30 |     For {
31 |         /// Loop variable name (identifier after `for`).
32 |         var: String,
33 |         /// Expression after `=`.
34 |         start: Box<ExprAST>,
35 |         /// Expression after the first `,`.
36 |         end: Box<ExprAST>,
37 |         /// Optional expression after the second `,`; `None` if omitted.
38 |         step: Option<Box<ExprAST>>,
39 |         /// Expression after `in`.
40 |         body: Box<ExprAST>,
41 |     },
42 | }
37 |
38 | /// PrototypeAST - This class represents the "prototype" for a function,
39 | /// which captures its name, and its argument names (thus implicitly the number
40 | /// of arguments the function takes).
41 | #[derive(Debug, PartialEq, Clone)]
42 | pub struct PrototypeAST(
43 |     /// Function name.
44 |     pub String,
45 |     /// Argument names, in declaration order.
46 |     pub Vec<String>,
47 | );
43 |
44 | /// FunctionAST - A complete function definition: prototype (`.0`) and body expression (`.1`).
45 | #[derive(PartialEq, Debug)]
46 | pub struct FunctionAST(pub PrototypeAST, pub ExprAST);
47 |
48 | /// Parse result with String as Error type (to be compliant with tutorial).
49 | type ParseResult<T> = Result<T, String>;
50 |
51 | /// Parser for the `kaleidoscope` language.
52 | ///
53 | /// Pulls tokens from a [`Lexer`] over a stream of `char`s, keeping one token of lookahead.
54 | pub struct Parser<I>
55 | where
56 |     I: Iterator<Item = char>,
57 | {
58 |     lexer: Lexer<I>,
59 |     /// Current token (`CurTok` in the tutorial); `None` until `get_next_token` ran once.
60 |     cur_tok: Option<Token>,
61 | }
59 |
60 | impl<I> Parser<I>
61 | where
62 |     I: Iterator<Item = char>,
63 | {
64 |     /// Create a parser that reads its tokens from `lexer`.
65 |     ///
66 |     /// The parser starts without a current token; call [`Parser::get_next_token`] once
67 |     /// before invoking any of the parse functions.
68 |     pub fn new(lexer: Lexer<I>) -> Self {
69 |         Parser {
70 |             lexer,
71 |             cur_tok: None,
72 |         }
73 |     }
74 |
75 |     // -----------------------
76 |     // Simple Token Buffer
77 |     // -----------------------
78 |
79 |     /// Implement the global variable `int CurTok;` from the tutorial.
80 |     ///
81 |     /// # Panics
82 |     /// Panics if the parser doesn't have a current token.
83 |     pub fn cur_tok(&self) -> &Token {
84 |         self.cur_tok.as_ref().expect("Parser: Expected cur_token!")
85 |     }
86 |
87 |     /// Advance the `cur_tok` by getting the next token from the lexer.
88 |     ///
89 |     /// Implement the function `int getNextToken();` from the tutorial.
90 |     pub fn get_next_token(&mut self) {
91 |         self.cur_tok = Some(self.lexer.gettok());
92 |     }
93 |
94 |     // ----------------------------
95 |     // Basic Expression Parsing
96 |     // ----------------------------
97 |
98 |     /// numberexpr ::= number
99 |     ///
100 |     /// Implement `std::unique_ptr<ExprAST> ParseNumberExpr();` from the tutorial.
101 |     ///
102 |     /// # Panics
103 |     /// Panics if the current token is not a number.
104 |     fn parse_num_expr(&mut self) -> ParseResult<ExprAST> {
105 |         match *self.cur_tok() {
106 |             Token::Number(num) => {
107 |                 // Consume the number token.
108 |                 self.get_next_token();
109 |                 Ok(ExprAST::Number(num))
110 |             }
111 |             _ => unreachable!(),
112 |         }
113 |     }
114 |
115 |     /// parenexpr ::= '(' expression ')'
116 |     ///
117 |     /// Implement `std::unique_ptr<ExprAST> ParseParenExpr();` from the tutorial.
118 |     ///
119 |     /// # Panics
120 |     /// Panics if the current token is not `(`.
121 |     fn parse_paren_expr(&mut self) -> ParseResult<ExprAST> {
122 |         // Eat '(' token.
123 |         assert_eq!(*self.cur_tok(), Token::Char('('));
124 |         self.get_next_token();
125 |
126 |         let v = self.parse_expression()?;
127 |
128 |         if *self.cur_tok() == Token::Char(')') {
129 |             // Eat ')' token.
130 |             self.get_next_token();
131 |             Ok(v)
132 |         } else {
133 |             Err("expected ')'".into())
134 |         }
135 |     }
136 |
137 |     /// identifierexpr
138 |     ///   ::= identifier
139 |     ///   ::= identifier '(' expression* ')'
140 |     ///
141 |     /// Call arguments are separated by ','.
142 |     ///
143 |     /// Implement `std::unique_ptr<ExprAST> ParseIdentifierExpr();` from the tutorial.
144 |     ///
145 |     /// # Panics
146 |     /// Panics if the current token is not an identifier.
147 |     fn parse_identifier_expr(&mut self) -> ParseResult<ExprAST> {
148 |         let id_name = match self.cur_tok.take() {
149 |             Some(Token::Identifier(id)) => {
150 |                 // Consume identifier.
151 |                 self.get_next_token();
152 |                 id
153 |             }
154 |             _ => unreachable!(),
155 |         };
156 |
157 |         if *self.cur_tok() != Token::Char('(') {
158 |             // Simple variable reference.
159 |             Ok(ExprAST::Variable(id_name))
160 |         } else {
161 |             // Call.
162 |
163 |             // Eat '(' token.
164 |             self.get_next_token();
165 |
166 |             let mut args: Vec<ExprAST> = Vec::new();
167 |
168 |             // If there are arguments collect them.
169 |             if *self.cur_tok() != Token::Char(')') {
170 |                 loop {
171 |                     let arg = self.parse_expression()?;
172 |                     args.push(arg);
173 |
174 |                     if *self.cur_tok() == Token::Char(')') {
175 |                         break;
176 |                     }
177 |
178 |                     if *self.cur_tok() != Token::Char(',') {
179 |                         return Err("Expected ')' or ',' in argument list".into());
180 |                     }
181 |
182 |                     // Eat ',' token.
183 |                     self.get_next_token();
184 |                 }
185 |             }
186 |
187 |             assert_eq!(*self.cur_tok(), Token::Char(')'));
188 |             // Eat ')' token.
189 |             self.get_next_token();
190 |
191 |             Ok(ExprAST::Call(id_name, args))
192 |         }
193 |     }
194 |
195 |     /// ifexpr ::= 'if' expression 'then' expression 'else' expression
196 |     ///
197 |     /// Implement `std::unique_ptr<ExprAST> ParseIfExpr();` from the tutorial.
198 |     ///
199 |     /// # Panics
200 |     /// Panics if the current token is not `if`.
201 |     fn parse_if_expr(&mut self) -> ParseResult<ExprAST> {
202 |         // Consume 'if' token.
203 |         assert_eq!(*self.cur_tok(), Token::If);
204 |         self.get_next_token();
205 |
206 |         let cond = self.parse_expression()?;
207 |
208 |         if *self.cur_tok() != Token::Then {
209 |             return Err("Expected 'then'".into());
210 |         }
211 |         // Consume 'then' token.
212 |         self.get_next_token();
213 |
214 |         let then = self.parse_expression()?;
215 |
216 |         if *self.cur_tok() != Token::Else {
217 |             return Err("Expected 'else'".into());
218 |         }
219 |         // Consume 'else' token.
220 |         self.get_next_token();
221 |
222 |         let else_ = self.parse_expression()?;
223 |
224 |         Ok(ExprAST::If {
225 |             cond: Box::new(cond),
226 |             then: Box::new(then),
227 |             else_: Box::new(else_),
228 |         })
229 |     }
230 |
231 |     /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
232 |     ///
233 |     /// Implement `std::unique_ptr<ExprAST> ParseForExpr();` from the tutorial.
234 |     ///
235 |     /// # Panics
236 |     /// Panics if the current token is not `for`.
237 |     fn parse_for_expr(&mut self) -> ParseResult<ExprAST> {
238 |         // Consume the 'for' token.
239 |         assert_eq!(*self.cur_tok(), Token::For);
240 |         self.get_next_token();
241 |
242 |         // Consume the loop variable. The token is matched directly instead of going through
243 |         // `parse_identifier_expr`, which panics on a non-identifier token and would parse
244 |         // `i(..)` as a call.
245 |         let var = match self.cur_tok.take() {
246 |             Some(Token::Identifier(var)) => {
247 |                 self.get_next_token();
248 |                 var
249 |             }
250 |             other => {
251 |                 // Plug back current token.
252 |                 self.cur_tok = other;
253 |                 return Err("expected identifier after 'for'".into());
254 |             }
255 |         };
256 |
257 |         // Consume the '=' token.
258 |         if *self.cur_tok() != Token::Char('=') {
259 |             return Err("expected '=' after for".into());
260 |         }
261 |         self.get_next_token();
262 |
263 |         let start = self.parse_expression()?;
264 |
265 |         // Consume the ',' token.
266 |         if *self.cur_tok() != Token::Char(',') {
267 |             return Err("expected ',' after for start value".into());
268 |         }
269 |         self.get_next_token();
270 |
271 |         let end = self.parse_expression()?;
272 |
273 |         let step = if *self.cur_tok() == Token::Char(',') {
274 |             // Consume the ',' token.
275 |             self.get_next_token();
276 |
277 |             Some(self.parse_expression()?)
278 |         } else {
279 |             None
280 |         };
281 |
282 |         // Consume the 'in' token.
283 |         if *self.cur_tok() != Token::In {
284 |             return Err("expected 'in' after for".into());
285 |         }
286 |         self.get_next_token();
287 |
288 |         let body = self.parse_expression()?;
289 |
290 |         Ok(ExprAST::For {
291 |             var,
292 |             start: Box::new(start),
293 |             end: Box::new(end),
294 |             step: step.map(Box::new),
295 |             body: Box::new(body),
296 |         })
297 |     }
298 |
299 |     /// primary
300 |     ///   ::= identifierexpr
301 |     ///   ::= numberexpr
302 |     ///   ::= parenexpr
303 |     ///   ::= ifexpr
304 |     ///   ::= forexpr
305 |     ///
306 |     /// Implement `std::unique_ptr<ExprAST> ParsePrimary();` from the tutorial.
307 |     fn parse_primary(&mut self) -> ParseResult<ExprAST> {
308 |         match *self.cur_tok() {
309 |             Token::Identifier(_) => self.parse_identifier_expr(),
310 |             Token::Number(_) => self.parse_num_expr(),
311 |             Token::Char('(') => self.parse_paren_expr(),
312 |             Token::If => self.parse_if_expr(),
313 |             Token::For => self.parse_for_expr(),
314 |             _ => Err("unknown token when expecting an expression".into()),
315 |         }
316 |     }
317 |
318 |     // -----------------------------
319 |     // Binary Expression Parsing
320 |     // -----------------------------
321 |
322 |     /// expression
323 |     ///   ::= primary binoprhs
324 |     ///
325 |     /// Implement `std::unique_ptr<ExprAST> ParseExpression();` from the tutorial.
326 |     fn parse_expression(&mut self) -> ParseResult<ExprAST> {
327 |         let lhs = self.parse_primary()?;
328 |         self.parse_bin_op_rhs(0, lhs)
329 |     }
330 |
331 |     /// binoprhs
332 |     ///   ::= (binop primary)*
333 |     ///
334 |     /// Implement `std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
335 |     /// std::unique_ptr<ExprAST> LHS);` from the tutorial.
336 |     ///
337 |     /// Only binary operators with a precedence of at least `expr_prec` are consumed. Parsing
338 |     /// stops and returns the accumulated `lhs` at the first token that is not such an operator.
339 |     fn parse_bin_op_rhs(&mut self, expr_prec: isize, mut lhs: ExprAST) -> ParseResult<ExprAST> {
340 |         loop {
341 |             let tok_prec = get_tok_precedence(self.cur_tok());
342 |
343 |             // Not a binary operator or precedence is too small.
344 |             if tok_prec < expr_prec {
345 |                 return Ok(lhs);
346 |             }
347 |
348 |             let binop = match self.cur_tok.take() {
349 |                 Some(Token::Char(c)) => {
350 |                     // Eat binary operator.
351 |                     self.get_next_token();
352 |                     c
353 |                 }
354 |                 // Only `Token::Char` has a non-negative precedence, see `get_tok_precedence`.
355 |                 _ => unreachable!(),
356 |             };
357 |
358 |             // lhs BINOP1 rhs BINOP2 remrhs
359 |             //     ^^^^^^     ^^^^^^
360 |             //    tok_prec   next_prec
361 |             //
362 |             // In case BINOP1 has higher precedence, we are done here and can build a 'Binary' AST
363 |             // node between 'lhs' and 'rhs'.
364 |             //
365 |             // In case BINOP2 has higher precedence, we take 'rhs' as 'lhs' and recurse into the
366 |             // 'remrhs' expression first.
367 |
368 |             // Parse primary expression after binary operator.
369 |             let mut rhs = self.parse_primary()?;
370 |
371 |             let next_prec = get_tok_precedence(self.cur_tok());
372 |             if tok_prec < next_prec {
373 |                 // BINOP2 has higher precedence than BINOP1, recurse into 'remrhs'.
374 |                 rhs = self.parse_bin_op_rhs(tok_prec + 1, rhs)?;
375 |             }
376 |
377 |             lhs = ExprAST::Binary(binop, Box::new(lhs), Box::new(rhs));
378 |         }
379 |     }
380 |
381 |     // --------------------
382 |     // Parsing the Rest
383 |     // --------------------
384 |
385 |     /// prototype
386 |     ///   ::= id '(' id* ')'
387 |     ///
388 |     /// Commas between argument names are accepted but not required.
389 |     ///
390 |     /// Implement `std::unique_ptr<PrototypeAST> ParsePrototype();` from the tutorial.
391 |     fn parse_prototype(&mut self) -> ParseResult<PrototypeAST> {
392 |         let id_name = match self.cur_tok.take() {
393 |             Some(Token::Identifier(id)) => {
394 |                 // Consume the identifier.
395 |                 self.get_next_token();
396 |                 id
397 |             }
398 |             other => {
399 |                 // Plug back current token.
400 |                 self.cur_tok = other;
401 |                 return Err("Expected function name in prototype".into());
402 |             }
403 |         };
404 |
405 |         if *self.cur_tok() != Token::Char('(') {
406 |             return Err("Expected '(' in prototype".into());
407 |         }
408 |
409 |         let mut args: Vec<String> = Vec::new();
410 |         loop {
411 |             // Consumes '(' in the first iteration, the previous argument or ',' afterwards.
412 |             self.get_next_token();
413 |
414 |             match self.cur_tok.take() {
415 |                 Some(Token::Identifier(arg)) => args.push(arg),
416 |                 Some(Token::Char(',')) => {}
417 |                 other => {
418 |                     // Plug back current token.
419 |                     self.cur_tok = other;
420 |                     break;
421 |                 }
422 |             }
423 |         }
424 |
425 |         if *self.cur_tok() != Token::Char(')') {
426 |             return Err("Expected ')' in prototype".into());
427 |         }
428 |
429 |         // Consume ')'.
430 |         self.get_next_token();
431 |
432 |         Ok(PrototypeAST(id_name, args))
433 |     }
434 |
435 |     /// definition ::= 'def' prototype expression
436 |     ///
437 |     /// Implement `std::unique_ptr<FunctionAST> ParseDefinition();` from the tutorial.
438 |     ///
439 |     /// # Panics
440 |     /// Panics if the current token is not `def`.
441 |     pub fn parse_definition(&mut self) -> ParseResult<FunctionAST> {
442 |         // Consume 'def' token.
443 |         assert_eq!(*self.cur_tok(), Token::Def);
444 |         self.get_next_token();
445 |
446 |         let proto = self.parse_prototype()?;
447 |         let expr = self.parse_expression()?;
448 |
449 |         Ok(FunctionAST(proto, expr))
450 |     }
451 |
452 |     /// external ::= 'extern' prototype
453 |     ///
454 |     /// Implement `std::unique_ptr<PrototypeAST> ParseExtern();` from the tutorial.
455 |     ///
456 |     /// # Panics
457 |     /// Panics if the current token is not `extern`.
458 |     pub fn parse_extern(&mut self) -> ParseResult<PrototypeAST> {
459 |         // Consume 'extern' token.
460 |         assert_eq!(*self.cur_tok(), Token::Extern);
461 |         self.get_next_token();
462 |
463 |         self.parse_prototype()
464 |     }
465 |
466 |     /// toplevelexpr ::= expression
467 |     ///
468 |     /// The expression is wrapped into an argument-less function named `__anon_expr`.
469 |     ///
470 |     /// Implement `std::unique_ptr<FunctionAST> ParseTopLevelExpr();` from the tutorial.
471 |     pub fn parse_top_level_expr(&mut self) -> ParseResult<FunctionAST> {
472 |         let e = self.parse_expression()?;
473 |         let proto = PrototypeAST("__anon_expr".into(), Vec::new());
474 |         Ok(FunctionAST(proto, e))
475 |     }
476 | }
429 |
430 | /// Binary operator precedence of `tok`, or `-1` if `tok` is not a binary operator.
431 | ///
432 | /// Higher values bind tighter. Implement `int GetTokPrecedence();` from the tutorial.
433 | fn get_tok_precedence(tok: &Token) -> isize {
434 |     let op = match tok {
435 |         Token::Char(op) => *op,
436 |         _ => return -1,
437 |     };
438 |
439 |     match op {
440 |         '<' => 10,
441 |         '+' | '-' => 20,
442 |         '*' => 40,
443 |         _ => -1,
444 |     }
445 | }
442 |
443 | #[cfg(test)]
444 | mod test {
445 |     use super::{ExprAST, FunctionAST, Parser, PrototypeAST};
446 |     use crate::lexer::Lexer;
447 |
448 |     /// Create a parser over `input` with the first token already loaded into `cur_tok`.
449 |     fn parser(input: &str) -> Parser<std::str::Chars<'_>> {
450 |         let l = Lexer::new(input.chars());
451 |         let mut p = Parser::new(l);
452 |
453 |         // Load the first token into cur_tok.
454 |         p.get_next_token();
455 |
456 |         p
457 |     }
458 |
459 |     #[test]
460 |     fn parse_number() {
461 |         let mut p = parser("13.37");
462 |
463 |         assert_eq!(p.parse_num_expr(), Ok(ExprAST::Number(13.37f64)));
464 |     }
465 |
466 |     #[test]
467 |     fn parse_variable() {
468 |         let mut p = parser("foop");
469 |
470 |         assert_eq!(
471 |             p.parse_identifier_expr(),
472 |             Ok(ExprAST::Variable("foop".into()))
473 |         );
474 |     }
475 |
476 |     #[test]
477 |     fn parse_if() {
478 |         let mut p = parser("if 1 then 2 else 3");
479 |
480 |         let cond = Box::new(ExprAST::Number(1f64));
481 |         let then = Box::new(ExprAST::Number(2f64));
482 |         let else_ = Box::new(ExprAST::Number(3f64));
483 |
484 |         assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ }));
485 |
486 |         let mut p = parser("if foo() then bar(2) else baz(3)");
487 |
488 |         let cond = Box::new(ExprAST::Call("foo".into(), vec![]));
489 |         let then = Box::new(ExprAST::Call("bar".into(), vec![ExprAST::Number(2f64)]));
490 |         let else_ = Box::new(ExprAST::Call("baz".into(), vec![ExprAST::Number(3f64)]));
491 |
492 |         assert_eq!(p.parse_if_expr(), Ok(ExprAST::If { cond, then, else_ }));
493 |     }
494 |
495 |     #[test]
496 |     fn parse_for() {
497 |         let mut p = parser("for i = 1, 2, 3 in 4");
498 |
499 |         let var = String::from("i");
500 |         let start = Box::new(ExprAST::Number(1f64));
501 |         let end = Box::new(ExprAST::Number(2f64));
502 |         let step = Some(Box::new(ExprAST::Number(3f64)));
503 |         let body = Box::new(ExprAST::Number(4f64));
504 |
505 |         assert_eq!(
506 |             p.parse_for_expr(),
507 |             Ok(ExprAST::For {
508 |                 var,
509 |                 start,
510 |                 end,
511 |                 step,
512 |                 body
513 |             })
514 |         );
515 |     }
516 |
517 |     #[test]
518 |     fn parse_for_no_step() {
519 |         let mut p = parser("for i = 1, 2 in 4");
520 |
521 |         let var = String::from("i");
522 |         let start = Box::new(ExprAST::Number(1f64));
523 |         let end = Box::new(ExprAST::Number(2f64));
524 |         let step = None;
525 |         let body = Box::new(ExprAST::Number(4f64));
526 |
527 |         assert_eq!(
528 |             p.parse_for_expr(),
529 |             Ok(ExprAST::For {
530 |                 var,
531 |                 start,
532 |                 end,
533 |                 step,
534 |                 body
535 |             })
536 |         );
537 |     }
538 |
539 |     #[test]
540 |     fn parse_primary() {
541 |         let mut p = parser("1337 foop \n bla(123) \n if a then b else c \n for x=1,2 in 3");
542 |
543 |         assert_eq!(p.parse_primary(), Ok(ExprAST::Number(1337f64)));
544 |
545 |         assert_eq!(p.parse_primary(), Ok(ExprAST::Variable("foop".into())));
546 |
547 |         assert_eq!(
548 |             p.parse_primary(),
549 |             Ok(ExprAST::Call("bla".into(), vec![ExprAST::Number(123f64)]))
550 |         );
551 |
552 |         assert_eq!(
553 |             p.parse_primary(),
554 |             Ok(ExprAST::If {
555 |                 cond: Box::new(ExprAST::Variable("a".into())),
556 |                 then: Box::new(ExprAST::Variable("b".into())),
557 |                 else_: Box::new(ExprAST::Variable("c".into())),
558 |             })
559 |         );
560 |
561 |         assert_eq!(
562 |             p.parse_primary(),
563 |             Ok(ExprAST::For {
564 |                 var: String::from("x"),
565 |                 start: Box::new(ExprAST::Number(1f64)),
566 |                 end: Box::new(ExprAST::Number(2f64)),
567 |                 step: None,
568 |                 body: Box::new(ExprAST::Number(3f64)),
569 |             })
570 |         );
571 |     }
572 |
573 |     #[test]
574 |     fn parse_binary_op() {
575 |         // Operator before RHS has higher precedence, expected AST
576 |         //
577 |         //       -
578 |         //      / \
579 |         //     +   c
580 |         //    / \
581 |         //   a   b
582 |         let mut p = parser("a + b - c");
583 |
584 |         let binexpr_ab = ExprAST::Binary(
585 |             '+',
586 |             Box::new(ExprAST::Variable("a".into())),
587 |             Box::new(ExprAST::Variable("b".into())),
588 |         );
589 |
590 |         let binexpr_abc = ExprAST::Binary(
591 |             '-',
592 |             Box::new(binexpr_ab),
593 |             Box::new(ExprAST::Variable("c".into())),
594 |         );
595 |
596 |         assert_eq!(p.parse_expression(), Ok(binexpr_abc));
597 |     }
598 |
599 |     #[test]
600 |     fn parse_binary_op2() {
601 |         // Operator after RHS has higher precedence, expected AST
602 |         //
603 |         //       +
604 |         //      / \
605 |         //     a   *
606 |         //        / \
607 |         //       b   c
608 |         let mut p = parser("a + b * c");
609 |
610 |         let binexpr_bc = ExprAST::Binary(
611 |             '*',
612 |             Box::new(ExprAST::Variable("b".into())),
613 |             Box::new(ExprAST::Variable("c".into())),
614 |         );
615 |
616 |         let binexpr_abc = ExprAST::Binary(
617 |             '+',
618 |             Box::new(ExprAST::Variable("a".into())),
619 |             Box::new(binexpr_bc),
620 |         );
621 |
622 |         assert_eq!(p.parse_expression(), Ok(binexpr_abc));
623 |     }
624 |
625 |     #[test]
626 |     fn parse_prototype() {
627 |         let mut p = parser("foo(a,b)");
628 |
629 |         let proto = PrototypeAST("foo".into(), vec!["a".into(), "b".into()]);
630 |
631 |         assert_eq!(p.parse_prototype(), Ok(proto));
632 |     }
633 |
634 |     #[test]
635 |     fn parse_definition() {
636 |         let mut p = parser("def bar( arg0 , arg1 ) arg0 + arg1");
637 |
638 |         let proto = PrototypeAST("bar".into(), vec!["arg0".into(), "arg1".into()]);
639 |
640 |         let body = ExprAST::Binary(
641 |             '+',
642 |             Box::new(ExprAST::Variable("arg0".into())),
643 |             Box::new(ExprAST::Variable("arg1".into())),
644 |         );
645 |
646 |         let func = FunctionAST(proto, body);
647 |
648 |         assert_eq!(p.parse_definition(), Ok(func));
649 |     }
650 |
651 |     #[test]
652 |     fn parse_extern() {
653 |         let mut p = parser("extern baz()");
654 |
655 |         let proto = PrototypeAST("baz".into(), vec![]);
656 |
657 |         assert_eq!(p.parse_extern(), Ok(proto));
658 |     }
659 |
660 |     #[test]
661 |     fn parse_call_args() {
662 |         let mut p = parser("foo(1, bar)");
663 |
664 |         let call = ExprAST::Call(
665 |             "foo".into(),
666 |             vec![ExprAST::Number(1f64), ExprAST::Variable("bar".into())],
667 |         );
668 |
669 |         assert_eq!(p.parse_identifier_expr(), Ok(call));
670 |     }
671 |
672 |     #[test]
673 |     fn parse_prototype_missing_paren() {
674 |         let mut p = parser("foo a");
675 |
676 |         assert_eq!(
677 |             p.parse_prototype(),
678 |             Err(String::from("Expected '(' in prototype"))
679 |         );
680 |     }
681 |
682 |     #[test]
683 |     fn parse_top_level_expr() {
684 |         let mut p = parser("1 + 2");
685 |
686 |         let proto = PrototypeAST("__anon_expr".into(), vec![]);
687 |         let body = ExprAST::Binary(
688 |             '+',
689 |             Box::new(ExprAST::Number(1f64)),
690 |             Box::new(ExprAST::Number(2f64)),
691 |         );
692 |
693 |         assert_eq!(p.parse_top_level_expr(), Ok(FunctionAST(proto, body)));
694 |     }
695 | }
659 |
--------------------------------------------------------------------------------