├── CMakeLists.txt ├── LICENSE ├── README.md ├── apex.cxx ├── examples ├── autodiff.md ├── autodiff_errors.png ├── formula.json ├── formula2.json ├── grad1.cxx ├── grad2.cxx ├── grad3.cxx ├── grad4.cxx └── test_errors.png ├── include └── apex │ ├── autodiff.hxx │ ├── autodiff_codegen.hxx │ ├── parse.hxx │ ├── tokenizer.hxx │ ├── tokens.hxx │ ├── util.hxx │ └── value.hxx └── src ├── autodiff ├── autodiff.cxx └── untitled ├── core └── value.cxx ├── parse └── grammar.cxx ├── tokenizer ├── lexer.cxx ├── number.cxx ├── operators.cxx ├── tokenizer.cxx └── tokens.cxx └── util ├── format.cxx └── utf.cxx /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5.1) 2 | project(apex) 3 | 4 | set(CMAKE_CXX_STANDARD 17) 5 | include_directories(include) 6 | 7 | set(SOURCE_FILES 8 | # apex.cxx 9 | src/util/utf.cxx 10 | src/util/format.cxx 11 | 12 | src/core/value.cxx 13 | 14 | src/parse/grammar.cxx 15 | 16 | src/tokenizer/tokens.cxx 17 | src/tokenizer/lexer.cxx 18 | src/tokenizer/operators.cxx 19 | src/tokenizer/tokenizer.cxx 20 | src/tokenizer/number.cxx 21 | 22 | src/autodiff/autodiff.cxx 23 | ) 24 | 25 | add_library(apex SHARED 26 | ${SOURCE_FILES} 27 | ) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Sean Baxter 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Apex DSL Library 2 | 3 | Write embedded domain-specific languages for [Circle](https://www.circle-lang.org). 4 | 5 | Follow the [reverse-mode autodiff example](examples/autodiff.md). 
6 | -------------------------------------------------------------------------------- /apex.cxx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanbaxter/apex/f07a92857efd0d7a23d174227b8154f4cbaf01b6/apex.cxx -------------------------------------------------------------------------------- /examples/autodiff.md: -------------------------------------------------------------------------------- 1 | # Reverse-mode automatic differentiation with Circle and Apex 2 | 3 | ```cpp 4 | struct terms_t { 5 | double x; 6 | double y; 7 | }; 8 | 9 | terms_t my_grad(terms_t input) { 10 | return apex::autodiff_grad("sq(x / y) * sin(x * y)", input); 11 | } 12 | ``` 13 | 14 | This example shows how to leverage the Apex DSL library to implement reverse-mode automatic differentation. The expression to differentiate is passed in as a compile-time string, and the primary inputs are provided as a vector of strings. 15 | 16 | This illustrates shared object library development. There are three things you can count on: 17 | 1. No template metaprogramming. 18 | 1. No operator overloading. 19 | 1. No performance compromise. 20 | 21 | Do all your development in an ordinary C++/Circle shared object project. Call into this shared object library during source translation and capture the returned IR. Lower the IR to code using Circle macros. This is a new way forward for DSLs in C++. 22 | 23 | ## Expression templates 24 | 25 | In Standard C++, code must either be included in textual form or compiled to binary and linked with the program. Only the former form is generic--template libraries may be specialized to solve an application-specific problem. 26 | 27 | EDSLs have been attempted by the C++ community for almost twenty years. In Standard C++, the [expression template](https://en.wikipedia.org/wiki/Expression_templates) idiom is used to repurpose the C++ syntax as a domain-specific language. Operator overloading and template metaprogramming are combined to capture the result of a subexpression as a _type_. For example, if either terminal in the subexpression `a + b` is an EDSL type, the result object of the addition expression is a type that includes details of the operator and of each operand. For example, `op_add_t`, where the template arguments are EDSL operator types that recursively specify their operand types. The type of the result object for the full expression contains the information of a parse tree over that same expression input. The expression template may be traversed (as if one were traversing a parse tree) and some calculation performed at each node. 28 | 29 | Expression templates are extremely difficult to write, error messages are opaque (mostly due to the hierarchical nature of the involved types) and build times are long. Most critically, expression-template EDSLs don't allow very complex compile-time transformations on the parse tree content. Once the expression template is built, the user remains limited by C++'s lack of compile-time imperative programming support. The user cannot lower the expression template to a rational IR, or build tree data structures, or run the content through optimizers or analyzers. 30 | 31 | ## The Apex vision for libraries 32 | 33 | Circle's integrated interpreter and code reflection mechanisms establish a larger design space for libraries. _What is a library with Circle?_ **Any code that provides a service**. 
34 | 35 | As demonstrated with the [Tensor Compiler example](https://github.com/seanbaxter/circle/blob/master/gems/taco.md), a Circle program can dynamically link to a shared object library _at compile time_, use that library to solve an intricate problem (tensor contraction scheduling), then lower the resulting solution (the intermediate representation) to code using Circle's macro system. 36 | 37 | Apex is a collection of services to help programmers develop this new form of library. Currently it includes a tokenizer and parser for a C++ subset (called the Apex grammar), as well as a reverse-mode automatic differentiation package that serves as both an example for building new libraries and an ingredient for additional numerical computing development. 38 | 39 | Functionality built using Apex presents itself to the application developer as an _embedded domain-specific language_. But the design of Apex EDSLs is very different from the design of expression templates: there is no operator overloading; there is no template metaprogramming; we don't try to co-opt C++ tokens into carrying DSL information. 40 | 41 | The client communicates with the library by way of compile-time strings. The contents may be provided as literals or assembled from code and data using Circle's usual compile-time mechanisms (a short sketch of this appears after the pipeline summary below). The library transforms the input text into code: 42 | 43 | 1. The Apex tokenizer lexes the text into tokens. The tokenizer code is in the shared object library `libapex.so`. 44 | 1. The Apex parser processes the tokens into a parse tree. The parser code is also in `libapex.so`. The parse tree is a lightweight class hierarchy. There are node types for unary and binary operators, function calls, terminals, and so on. It is not a template library. 45 | Parse errors are formed by the parser--you don't get C++ frontend errors when the input is malformed, but Apex backend errors, which are cleaner and more relevant to the problem of parsing. 46 | 1. The EDSL library traverses the parse tree and performs semantic analysis. This is where the library intelligence resides. All usual programming tools are available to the library. It can interact with the file system, host interpreters to execute scripts, and so on. The library intelligence should be compiled into a shared object; Apex's autodiff package is compiled into `libapex.so`. 47 | The output of the intelligence layer is the problem solution in an intermediate representation. This IR may be expressed using any C++ data structure. Because the library's shared object is loaded into the compiler's process during source translation, objects created by the intelligence layer occupy the same address space as the Circle interpreter, allowing those objects to be consumed and modified by meta code. 48 | 1. A codegen header supplies the interface between the client, the Apex tokenizer and parser, and the library intelligence. This header provides Circle macros to lower the EDSL operation from IR to expression code. 49 | 50 | Although this seems like an involved four-step pipeline, the first two components are reusable and provided by libraries. Even if you choose a different tokenizer or parser, you can use them from libraries. The intelligence layer establishes a nice separation of concerns, as you can develop it independently of the code generator. Finally, the codegen layer is very small, as all difficult work was pushed down and handled by shared object libraries.
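Since the formula is only a compile-time string, it does not have to be written out as a single literal. The fragment below is a minimal sketch of the client-side view, reusing the `autodiff_grad` entry point and the `terms_t` type from the snippet at the top of this document; the only new idea is that the string is assembled from pieces at compile time before being handed to the library.

```cpp
// Minimal client-side sketch. autodiff_grad and terms_t are the entry point
// and input type used throughout this document; the formula string is built
// at compile time from two pieces rather than written as one literal.
terms_t my_grad(terms_t input) {
  @meta std::string f = std::string("sq(x / y) * ") + "sin(x * y)";
  return apex::autodiff_grad(f.c_str(), input);
}
```

The same pattern is what makes the JSON-driven examples later in this document work: the formula text can come from anywhere the compile-time interpreter can reach.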
51 | 52 | A strength of this approach is that it requires very little Circle code, only a thin layer of macros for code generation. All the intelligence can be written with Standard C++ for portability and to ease migration into this new language. 53 | 54 | ## Autodiff for Circle 55 | 56 | [**grad1.cxx**](grad1.cxx) 57 | ```cpp 58 | #include 59 | 60 | struct terms_t { 61 | double x; 62 | double y; 63 | }; 64 | 65 | terms_t my_grad(terms_t input) { 66 | return apex::autodiff_grad("sq(x / y) * sin(x * y)", input); 67 | } 68 | 69 | int main() { 70 | terms_t grad = my_grad( { .3, .5 } ); 71 | printf("%f %f\n", grad.x, grad.y); 72 | return 0; 73 | } 74 | ``` 75 | ``` 76 | $ circle grad1.cxx -I ../include -M ../Debug/libapex.so 77 | $ ./grad1 78 | 1.053170 0.425549 79 | ``` 80 | 81 | To use Apex's autodiff, pass the formula to differentiate as a string, followed by a class object with the values of each primary input. The names referenced in the formula must correspond to the member names in the class object. We'll use Circle's type introspection to help relate the type information to the variable names. 82 | 83 | The result object is another instance of the class type, this time holding the partial derivatives rather than the values of the independent variables. 84 | 85 | After just two days of programming, this package supports these expressions and elementary functions: 86 | * Binary + - * and /. 87 | * Unary -. 88 | * sq, sqrt, exp, log, sin, cos, tan, sinh, cosh, tanh, pow and norm functions. 89 | 90 | The call to `autodiff_grad` has distinct compile-time and runtime phases. At compile time, the formula is tokenized and parsed; the parse tree is lowered by `make_autodiff` to an IR called a "tape," and that tape is lowered by `autodiff_codegen.hxx` to code using Circle macros. At runtime, the independent variables are evaluated and the tape-generated code is executed, yielding the gradient. All scheduling is performed at compile time, and there is no runtime dependency on any part of the `libapex.so` library. 91 | 92 | Reverse-mode differentation is essentially a sparse matrix problem. Each dependent variable/subexpression is a row in the sparse matrix (an item in the tape) with a non-zero column for each partial derivative we'll compute to complete the chain rule. When considered as a DAG traversal, the chain rule calculation involves propagating partials from the root down each edge, and incrementing a component of the gradient vector by the concatenated chain rule coefficient. When viewed as linear algebra, the entire gradient pass is a sparse back-propagation operation. 93 | 94 | The Apex autodiff example adopts the DAG view of the problem. The implementation performs best when the size of the DAG is small enough so that the gains of explicit scheduling of each individual back-propagation term more than offset the parallelism left on the table by not using an optimized sparse matrix back-propagation code. 95 | 96 | However, the separation of autodiff intelligence and code generation permits selection of a back-propagation treatment most suitable for the particular primary inputs and expression graph. Calls into the autodiff library with different expressions may generate implementations utilizing different strategies, without the vexations of template metaprogramming. 97 | 98 | ## Writing an embedded DSL for Circle programs 99 | 100 | How do we implement the autodiff DSL? 
We basically write our own small compiler frontend--it takes text input, performs syntax and semantic analysis, and emits IR, just like a general-purpose compiler. 101 | 102 | This would be too much work if written from scratch for each DSL library. We'll use the language components available in Apex to tokenize and parse the input text into a parse tree for the _Apex grammar_ (a C++-inspired expression grammar), and consume the parse tree as configuration for the library. 103 | 104 | ### The tokenizer 105 | 106 | Apex includes a tokenizer that breaks an input text into operators (all the ones recognized by C++) and identifiers. 107 | 108 | [**tokens.hxx**](../include/apex/tokens.hxx) 109 | ```cpp 110 | struct token_t { 111 | tk_kind_t kind : 8; 112 | int store : 24; 113 | const char* begin, *end; 114 | 115 | operator tk_kind_t() const { return kind; } 116 | }; 117 | typedef const token_t* token_it; 118 | ``` 119 | 120 | The token structure holds an enumeration defining the kind of token (eg '+' token, integer token or identifier token) and an index into a store to retrieve a resource, like a string, integer or floating-point value. 121 | 122 | [**tokenizer.hxx**](../include/apex/tokenizer.hxx) 123 | ```cpp 124 | struct tokenizer_t { 125 | std::vector strings; 126 | std::vector ints; 127 | std::vector floats; 128 | 129 | // Byte offset for each line start. 130 | std::vector line_offsets; 131 | 132 | // Original text we tokenized. 133 | std::string text; 134 | 135 | // The text divided into tokens. 136 | std::vector tokens; 137 | 138 | parse::range_t token_range() const; 139 | 140 | int reg_string(range_t range); 141 | int find_string(range_t range) const; 142 | 143 | // Return 0-indexed line and column offsets for the token at 144 | // the specified byte offset. This performs UCS decoding to support 145 | // multibyte characters. 146 | int token_offset(source_loc_t loc) const; 147 | int token_line(int offset) const; 148 | int token_col(int offset, int line) const; 149 | std::pair token_linecol(int offset) const; 150 | std::pair token_linecol(source_loc_t loc) const; 151 | 152 | void tokenize(); 153 | }; 154 | ``` 155 | 156 | The `tokenizer_t` class holds the input text, the array of tokens, the resources, and an array of line offsets to ease mapping between tokens and line/column positions within the input text. The tokenizer expects UTF-8 input, so characters may consume between one and four bytes; the `token_linecol` functions map token indices and byte offsets within the text to the correct line/column positions, accounting for these multi-byte characters. 157 | 158 | To use the tokenizer, set the `text` data member and call the `tokenize` member function. 159 | 160 | ### The parser 161 | 162 | The parser consumes tokens from left-to-right and constructs a parse tree from bottom-to-top. Apex includes a hand-written recursive descent (RD) parser, which is the most practical and flexible approach to parsing. 
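If you have never written one, the shape of a recursive descent parser is easy to picture: one function per precedence level, each parsing its operands by calling the next-higher level and then folding any operators it finds at its own level. The fragment below is only an illustrative sketch under assumed names--`tk_plus`, `make_binary` and `parse_multiplicative` are hypothetical--and is not Apex's actual implementation; the real grammar lives in [grammar.cxx](../src/parse/grammar.cxx).

```cpp
// Illustrative recursive-descent sketch; not Apex's grammar.cxx.
// tk_plus, make_binary and parse_multiplicative are assumed names.
// range_t and advance_if mirror the token-range helpers in apex/parse.hxx.
node_ptr_t parse_additive(range_t& range) {
  // Parse the left operand at the next-higher precedence level.
  node_ptr_t lhs = parse_multiplicative(range);

  // Fold left-associative '+' operators at this level.
  while(range.advance_if(tk_plus)) {
    node_ptr_t rhs = parse_multiplicative(range);
    lhs = make_binary("+", std::move(lhs), std::move(rhs));
  }
  return lhs;
}
```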
163 | 164 | [**parse.hxx**](../include/apex/parse.hxx) 165 | ```cpp 166 | struct node_t { 167 | enum kind_t { 168 | kind_ident, 169 | kind_unary, 170 | kind_binary, 171 | kind_assign, 172 | kind_ternary, 173 | kind_call, 174 | kind_char, 175 | kind_string, 176 | kind_number, 177 | kind_bool, 178 | kind_subscript, 179 | kind_member, 180 | kind_braced, 181 | }; 182 | 183 | kind_t kind; 184 | source_loc_t loc; 185 | 186 | node_t(kind_t kind, source_loc_t loc) : kind(kind), loc(loc) { } 187 | virtual ~node_t() { } 188 | 189 | template 190 | derived_t* as() { 191 | return derived_t::classof(this) ? 192 | static_cast(this) : 193 | nullptr; 194 | } 195 | 196 | template 197 | const derived_t* as() const { 198 | return derived_t::classof(this) ? 199 | static_cast(this) : 200 | nullptr; 201 | } 202 | }; 203 | typedef std::unique_ptr node_ptr_t; 204 | typedef std::vector node_list_t; 205 | ``` 206 | 207 | Each parse tree node derives `apex::parse::node_t`. `source_loc_t` is the integer index of the token from which the parse node was constructed, and we include one in each parse node. The tokenizer object can map `source_loc_t` objects back to line/column numbers for error reporting. 208 | 209 | The full implementation of the parser is in [grammar.cxx](../src/parse/grammar.cxx). We'll run the parser at compile time from a meta context in the Circle program. But unlike a template, which is C++ generic programming offering, _we don't need to see the source of the parser_ from the source code of the client. The parser is compiled into `libapex.so`, and the Circle interpreter will make a foreign function call to run the parser and retrieve the parse tree. We don't even need access to `libapex.so` at runtime--the IR from the DSL library is lowered to Circle code during compile time, and the resulting binary retains no evidence of `libapex.so`'s role in its generation. 210 | 211 | [**parse.hxx**](../include/apex/parse.hxx) 212 | ```cpp 213 | struct parse_t { 214 | tok::tokenizer_t tokenizer; 215 | node_ptr_t root; 216 | }; 217 | 218 | parse_t parse_expression(const char* str); 219 | ``` 220 | 221 | Calling `apex::parse::parse_expression` tokenizes and parses an input text and returns both the tokenizer (which has line-mapping content) and the root node of the parse tree. `node_ptr_t` is an `std::unique_ptr`; when the user destroys the root object from meta code, the entire tree is recursively destroyed from the smart pointers' destructors. 222 | 223 | ### The autodiff IR 224 | 225 | The autodiff library traverses the parse tree and builds a data structure called a _tape_ or _Wengert list_, which includes instructions for evaluating the value and partial derivatives for each subexpression. 226 | 227 | [**autodiff.hxx**](../include/apex/autodiff.hxx) 228 | ```cpp 229 | struct autodiff_t { 230 | struct item_t { 231 | // The dimension of the tape item. 232 | // 0 == dim for scalar. dim > 0 for vector. 233 | int dim; 234 | 235 | // The expression to execute to compute this dependent variable's value. 236 | // This is evaluated during the upsweep when creating the tape from the 237 | // independent variables and moving through all subexpressions. 238 | ad_ptr_t val; 239 | 240 | // When updating the gradient of the parent, this tape item loops over each 241 | // of its dependent variables and performs a chain rule increment. 242 | // It calls grad(index, coef) on each index. This recurses, down to the 243 | // independent vars, multiplying in the coef at each recurse. 
244 | 245 | // When we hit an independent var, the grads array is empty (although it 246 | // may be empty otherwise) and we simply perform += coef into the slot 247 | // corresponding to the independent variable in the gradient array. 248 | struct grad_t { 249 | int index; 250 | ad_ptr_t coef; 251 | }; 252 | std::vector grads; 253 | }; 254 | 255 | // The first var_names.size() items encode independent variables. 256 | std::vector vars; 257 | std::vector tape; 258 | }; 259 | 260 | autodiff_t make_autodiff(const std::string& formula, 261 | const std::vector& vars); 262 | ``` 263 | 264 | The result object of `make_autodiff` is an object of type `autodiff_t`. This holds the _tape_, and each tape item holds expressions to evaluating the tape's subexpression and that subexpression's gradient. The index in each gradient component refers to a position within the tape corresponding to the variable (dependent or independent) that the partial derivative is computed with respect to. When traversing the tape DAG, we concatenate partial derivatives; when we hit a terminal node (an independent variable), we increment the output gradient by the total derivative--this is the chain rule in action. 265 | 266 | Although the DSL doesn't yet support it, the tape is designed to accomodate vector types in addition to scalar types. 267 | 268 | The autodiff IR needs to be comprehensive enough to encode any operations found in the expression to differentiate. We chose the design for easy lowering using intrinsics like `@op` and `@expression` to generate code from strings. 269 | 270 | [**autodiff.hxx**](../include/apex/autodiff.hxx) 271 | ```cpp 272 | struct ad_t { 273 | enum kind_t { 274 | kind_tape, 275 | kind_literal, 276 | kind_unary, 277 | kind_binary, 278 | kind_func 279 | }; 280 | kind_t kind; 281 | 282 | ad_t(kind_t kind) : kind(kind) { } 283 | 284 | template 285 | derived_t* as() { 286 | return derived_t::classof(this) ? 287 | static_cast(this) : 288 | nullptr; 289 | } 290 | 291 | template 292 | const derived_t* as() const { 293 | return derived_t::classof(this) ? 294 | static_cast(this) : 295 | nullptr; 296 | } 297 | }; 298 | typedef std::unique_ptr ad_ptr_t; 299 | 300 | struct ad_tape_t : ad_t { 301 | ad_tape_t(int index) : ad_t(kind_tape), index(index) { } 302 | static bool classof(const ad_t* ad) { return kind_tape == ad->kind; } 303 | 304 | int index; 305 | }; 306 | 307 | struct ad_literal_t : ad_t { 308 | ad_literal_t(double x) : ad_t(kind_literal), x(x) { } 309 | static bool classof(const ad_t* ad) { return kind_literal == ad->kind; } 310 | 311 | double x; 312 | }; 313 | 314 | struct ad_unary_t : ad_t { 315 | ad_unary_t(const char* op, ad_ptr_t a) : 316 | ad_t(kind_unary), op(op), a(std::move(a)) { } 317 | static bool classof(const ad_t* ad) { return kind_unary == ad->kind; } 318 | 319 | const char* op; 320 | ad_ptr_t a; 321 | }; 322 | 323 | struct ad_binary_t : ad_t { 324 | ad_binary_t(const char* op, ad_ptr_t a, ad_ptr_t b) : 325 | ad_t(kind_binary), op(op), a(std::move(a)), b(std::move(b)) { } 326 | static bool classof(const ad_t* ad) { return kind_binary == ad->kind; } 327 | 328 | const char* op; 329 | ad_ptr_t a, b; 330 | }; 331 | 332 | struct ad_func_t : ad_t { 333 | ad_func_t(std::string f) : ad_t(kind_func), f(std::move(f)) { } 334 | static bool classof(const ad_t* ad) { return kind_func == ad->kind; } 335 | 336 | std::string f; 337 | std::vector args; 338 | }; 339 | ``` 340 | 341 | The autodiff code in `libapex.so` generates `ad_t` trees into the tape data structure. 
Each tree node is allocated on the heap and stored in an `std::unique_ptr`. Because the shared object is loaded into the address space of the compiler, the result object of the foreign-function library call is fully accessible to meta code in the translation unit by way of the Circle interpreter. 342 | 343 | As with the tokenizer and parser, the implementation of the autodiff library is totally abstracted from the library's caller. 344 | 345 | The tape-building class `ad_builder_t` in [autodiff.cxx](../src/autodiff/autodiff.cxx) has member functions for each operation and elementary function supported by the DSL. For example, to support multiplication we implement the product rule of calculus: 346 | 347 | ```cpp 348 | int ad_builder_t::mul(int a, int b) { 349 | // The sq operator is memoized, so prefer that. 350 | if(a == b) 351 | return sq(a); 352 | 353 | // grad (a * b) = a grad b + b grad a. 354 | item_t item { }; 355 | item.val = mul(val(a), val(b)); 356 | item.grads.push_back({ 357 | b, // a * grad b 358 | val(a) 359 | }); 360 | item.grads.push_back({ 361 | a, // b * grad a 362 | val(b) 363 | }); 364 | return push_item(std::move(item)); 365 | } 366 | ``` 367 | 368 | The operands are indices to lower nodes in the tape. We use function overloads like `mul` and `val` to create `ad_t` nodes, which are assembled recursively into expression trees. 369 | 370 | ```cpp 371 | int ad_builder_t::sin(int a) { 372 | item_t item { }; 373 | item.val = func("std::sin", val(a)); 374 | item.grads.push_back({ 375 | a, 376 | func("std::cos", val(a)) 377 | }); 378 | return push_item(std::move(item)); 379 | } 380 | ``` 381 | 382 | The sine function is supported with a similar member function. We generate `ad_func_t` nodes which identify the functions to call by string name. When the IR is lowered to code in [autodiff_codegen.hxx](../include/apex/autodiff_codegen.hxx), we'll use `@expression` to perform name lookup and convert these qualified names to function lvalues. 383 | 384 | Note that we can deliver a rich calculus package without having to define a type system to interact with the rest of the C++ application. We don't have to require that `sin` and `cos` implement any particular concept or interface to participate in differentiation, because these are first-class functions supported by the DSL. 385 | 386 | To allow user-extension to the autodiff library, such as user-defined functions, any convention may be used to communicate between the library and the client. A participating function and its derivative could adopt a particular naming convention (e.g., the function ends with `_f` and the derivative ends with `_grad`); the function and derivative could be member functions of a class that is named in the input string (e.g., naming "sinc" in the formula string performs name lookup for class `sinc_t` and calls member functions `f` and `grad`) 387 | 388 | The strength of this design is that you aren't relying on C++'s overload resolution and type systems to coordinate between the library's implementation and its users; the library can introduce its own conventions for interoperability. 389 | 390 | ## The autodiff code generator 391 | 392 | [**autodiff_codegen.hxx**](../include/apex/autodiff_codegen.hxx) 393 | ```cpp 394 | template 395 | @meta type_t autodiff_grad(@meta const char* formula, type_t input) { 396 | 397 | // Parse out the names from the inputs. 
398 | static_assert(std::is_class::value, 399 | "argument to autodiff_eval must be a class object"); 400 | 401 | // Collect the name of each primary input. 402 | @meta std::vector vars; 403 | @meta size_t num_vars = @member_count(type_t); 404 | 405 | // Copy the values of the independent variables into the tape. 406 | double tape_values[num_vars]; 407 | @meta for(int i = 0; i < num_vars; ++i) { 408 | // Confirm that we have a scalar double-precision term. 409 | static_assert(std::is_same<@member_type(type_t, i), double>::value, 410 | std::string("member ") + @member_name(type_t, i) + " must be type double"); 411 | 412 | // Push the primary input name. 413 | @meta vars.push_back({ 414 | @member_name(type_t, i), 415 | 0 416 | }); 417 | 418 | // Set the primary input's value. 419 | tape_values[i] = @member_ref(input, i); 420 | } 421 | 422 | // Construct the tape. This makes a foreign function call into libapex.so. 423 | @meta apex::autodiff_t autodiff = apex::make_autodiff(formula, vars); 424 | @meta size_t count = autodiff.tape.size(); 425 | 426 | // Compute the values for the whole tape. This is the forward-mode pass. 427 | // It propagates values from the terminals (independent variables) through 428 | // the subexpressions and up to the root of the function. 429 | 430 | // Evaluate the subexpressions. 431 | @meta for(size_t i = num_vars; i < count; ++i) 432 | tape_values[i] = autodiff_expr(autodiff.tape[i].val.get()); 433 | 434 | // Evaluate the gradients. This is a top-down reverse-mode traversal of 435 | // the autodiff DAG. The partial derivatives are parsed along edges, starting 436 | // from the root and towards each terminal. When a terminal is visited, the 437 | // corresponding component of the gradient is incremented by the product of 438 | // all the partial derivatives from the root of the DAG down to that 439 | // terminal. 440 | double coef[num_vars]; 441 | type_t grad { }; 442 | 443 | // Visit each child of the root node. 444 | @meta int root = count - 1; 445 | @meta for(const auto& g : autodiff.tape[root].grads) { 446 | // Evaluate the coefficient into the stack. 447 | coef[root] = autodiff_expr(g.coef.get()); 448 | 449 | // Recurse on the child. 450 | @macro autodiff_tape(g.index, root); 451 | } 452 | 453 | return std::move(grad); 454 | } 455 | ``` 456 | 457 | `autodiff_grad` is implemented as a metafunction. Recall that these are like super function templates: some parameters are marked `@meta`, requiring compile-time arguments be provided. By making the input formula string `@meta`, we can pass the string to `apex::make_autodiff` at compile time to generate an IR. This function is implemented in `libapex.so`, so `-M libapex.so` must be specified as a `circle` argument to load this shared object library as a dependency. 458 | 459 | After the tape has been computed and returned, we first initialize the tape values in forward order (that is, from the leaves of the expression tree up to the root). Then, we may the reverse pass, propagating partial derivatives from the root of the expression tree down to the terminals, where the gradient is finally updated. 460 | 461 | Although the values in the tape will be used again during the top-down gradient pass, their storage may be a performance limiter in problems with a very large number of temporary nodes. Because the library defines its own IR and scheduling intelligence, it's feasible to extend the IR and emit instructions to rematerialize temporary values to alleviate storage pressure. 
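Before looking at the code-generation macros themselves, it helps to see what the lowering amounts to for a tiny formula. The function below is a hand-written approximation of what `autodiff_grad` effectively emits for the formula "x * y" with the `terms_t` inputs from the first example; the real generated code evaluates chain-rule coefficients through the `coef` array and the macros shown next, so treat this only as a mental model.

```cpp
// Hand-written approximation of the code generated for "x * y".
// Not the literal output of autodiff_grad; for exposition only.
terms_t my_grad_xy(terms_t input) {
  // Forward pass: slots 0 and 1 hold the independent variables x and y,
  // slot 2 holds the subexpression x * y.
  double tape_values[3];
  tape_values[0] = input.x;
  tape_values[1] = input.y;
  tape_values[2] = tape_values[0] * tape_values[1];

  // Reverse pass: walk down from the root, multiplying partial derivatives
  // along each edge and incrementing the gradient at each terminal.
  terms_t grad { };
  grad.x += tape_values[1];   // d(x*y)/dx = y
  grad.y += tape_values[0];   // d(x*y)/dy = x
  return grad;
}
```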
462 | 463 | [**autodiff_codegen.hxx**](../include/apex/autodiff_codegen.hxx) 464 | ```cpp 465 | @macro auto autodiff_expr(const ad_t* ad) { 466 | @meta+ if(const auto* tape = ad->as()) { 467 | @emit return tape_values[tape->index]; 468 | 469 | } else if(const auto* literal = ad->as()) { 470 | @emit return literal->x; 471 | 472 | } else if(const auto* unary = ad->as()) { 473 | @emit return @op( 474 | unary->op, 475 | autodiff_expr(unary->a.get()) 476 | ); 477 | 478 | } else if(const auto* binary = ad->as()) { 479 | @emit return @op( 480 | binary->op, 481 | autodiff_expr(binary->a.get()), 482 | autodiff_expr(binary->b.get()) 483 | ); 484 | 485 | } else if(const auto* func = ad->as()) { 486 | // Support 1- and 2-parameter function calls. 487 | if(1 == func->args.size()) { 488 | @emit return @expression(func->f)(autodiff_expr(func->args[0].get())); 489 | 490 | } else if(2 == func->args.size()) { 491 | @emit return @expression(func->f)(autodiff_expr(func->args[0].get()), 492 | autodiff_expr(func->args[1].get())); 493 | } 494 | } 495 | } 496 | ``` 497 | 498 | The expression macro `autodiff_expr` recurses an `ad_t` tree and switches on each node kind. 499 | 500 | * The macro is expanded in the scope of the caller, so the `tape_values` object is visible; this provides access to the value of each subexpression. 501 | * The unary and binary nodes hold strings with the operator names, such as "+" or "/". We can pass these strings to `@op` along with the expression arguments to synthesize the corresponding kind of expression. 502 | * Function call nodes have the _name_ of the function stored as a string. When evaluated with `@expression`, name lookup is performed on the qualified name (eg, "std::cos") and returns a function lvalue or overload set. 503 | 504 | Each tape item (corresponding to sparse matrix row) includes one `ad_t` tree that renders the value of the subexpression, and one `ad_t` per child node in the DAG to compute partial derivatives. The values are computed in bottom-up order (forward through the tape), and the partial derivatives are computed in top-down order (reverse mode through the tape). An optimization potential may be exposed by evaluating all partial derivatives in parallel (there are no data dependencies between them), and using a parallelized sparse back-propagation code to concatenate the partial derivatives. Again, these choices should be made by the intelligence of the library, which is well-separated from the metaprogramming concerns of the code generator. 505 | 506 | ## DSL error reporting 507 | 508 | Circle has been adding capability for better integration of compile-time exceptions with compiler errors. If an exception is thrown either from the Circle interpreter or from a foreign function call to a dynamically-loaded shared object, a backtrace for the unhandled exception is printed along with the exception's error. This helps us understand exactly where and why the error was generated. 509 | 510 | Consider breaking the syntax of our Apex autodiff function: 511 | ```cpp 512 | terms_t my_grad(terms_t input) { 513 | return apex::autodiff_grad("sq(x / y) * 2 sin(x)", input); 514 | } 515 | ``` 516 | 517 | The Apex grammar doesn't support this kind of implicit multiplication. 
The `sin` call is simply an unexpected token, so our parser throws an exception indicating such: 518 | 519 | [**grammar.cxx**](../src/parse/grammar.cxx) 520 | ```cpp 521 | void grammar_t::unexpected_token(token_it pos, const char* rule) { 522 | const char* begin = pos->begin; 523 | const char* end = pos->end; 524 | int len = end - begin; 525 | 526 | std::string msg = format("unexpected token '%.*s' in %s", len, begin, rule); 527 | 528 | throw parse_exception_t(msg); 529 | } 530 | ``` 531 | 532 | This exception originates in compiled code, and unrolls the stack through the foreign function call to `apex::make_autodiff`, through the entire Circle compiler, until it is finally caught from Circle's `main` function and printed as an uncaught exception. Wonderfully, Circle constructs backtrace information as this exception unrolls through function calls and macro expansions, and presents the backtrace as compiler errors: 533 | 534 | ![grad1.cxx errors](autodiff_errors.png) 535 | 536 | The text saved in the `parse_exception_t` class (which derives `std::runtime_error`, which is how we access its error message) is printed in the most-indented part of the error. Around that, we're shown that it's thrown from `apex::make_autodiff` from `libapex.so`. In turn, that function call is made from inside the instantiation of the metafunction `apex::autodiff_grad`. The offending string is shown here, and we can see the call to `sin` which corresponds to the parse error thrown from inside Apex. 537 | 538 | The backtrace will print the full path of functions called, regardless of whether they are called from interpreted or compiled code. Additionally, the _throw-expression_ generating the exception is flagged: 539 | 540 | **test.cxx** 541 | ```cpp 542 | #include 543 | 544 | void func1() { 545 | throw std::runtime_error("A bad bad thing"); 546 | } 547 | 548 | void func2() { 549 | func1(); 550 | } 551 | 552 | void func3() { 553 | func2(); 554 | } 555 | 556 | @meta func3(); 557 | ``` 558 | 559 | ![test.cxx errors](test_errors.png) 560 | 561 | ## Circle as a build system 562 | 563 | Circle integrates with the host environment and provides the functionality of scripting languages and build systems. We can use this compile-time capability to drive program generation from resources. To extend the gradient example, let's specify functions to differentiate in JSON files in the working directory: 564 | 565 | [**formula.json**](formula.json) 566 | ``` 567 | { 568 | "F1" : "sin(x / y + z) / sq(x + y + z)", 569 | "F2" : "tanh(sin(x) * exp(y / z))" 570 | } 571 | ``` 572 | 573 | The translation unit will open this file using the header-only [json.hpp](https://github.com/nlohmann/json) library, read the contents, inject functions for evaluating both formulas, and call into the Apex library for code to compute the derivatives of each. Finally, we'll generate a driver program that takes command-line arguments for evaluating one of the functions and its gradient. 574 | 575 | [**grad2.cxx**](grad2.cxx) 576 | ```cpp 577 | #include 578 | #include 579 | #include 580 | #include 581 | #include 582 | 583 | // Parse the JSON file and keep it open in j. 584 | using nlohmann::json; 585 | using apex::sq; 586 | 587 | struct vec3_t { 588 | double x, y, z; 589 | }; 590 | 591 | // Record the function names encountered in here! 592 | @meta std::vector func_names; 593 | 594 | @macro void gen_functions(const char* filename) { 595 | // Open this file at compile time and parse as JSON.
596 | @meta std::ifstream json_file(filename); 597 | @meta json j; 598 | @meta json_file>> j; 599 | 600 | @meta for(auto& item : j.items()) { 601 | // For each item in the json... 602 | @meta std::string name = item.key(); 603 | @meta std::string f = item.value(); 604 | @meta std::cout<< "Injecting '"<< name<< "' : '"<< f<< "' from "<< 605 | filename<< "\n"; 606 | 607 | // Generate a function from the expression. 608 | extern "C" double @("f_" + name)(vec3_t v) { 609 | double x = v.x, y = v.y, z = v.z; 610 | return @expression(f); 611 | } 612 | 613 | // Generate a function to return the gradient. 614 | extern "C" vec3_t @("grad_" + name)(vec3_t v) { 615 | return apex::autodiff_grad(f.c_str(), v); 616 | } 617 | 618 | @meta func_names.push_back(name); 619 | } 620 | } 621 | ``` 622 | 623 | First, we'll write a statement macro `gen_functions` which takes a filename (since this is a macro, the value of the argument must be known at compile time). Expanding the macro creates a new meta scope but reuses the real scope of the call site. This way, we can create new meta objects like the file handle and JSON object without polluting the declarative region we're expanding the macro into. 624 | 625 | We open the JSON file at compile time and range-for through its contents. We'll echo back to the user the functions being generated, then define functions `"f_" + name` and `"grad_" + name`. The Circle operator `@()` transforms strings to tokens. If we expand this statement macro into a namespace scope, the real declarations are interpreted as function definitions. 626 | 627 | Circle macros allow you to programmatically inject code into any statement- or expression-accepting scope. Unlike preprocessor macros, these macros follow argument deduction and overload resolution rules. They establish a new meta scope to allow you to use your own tooling. Only the _real statements_ fall through the macro and are inserted into the target scope. 628 | 629 | After the value and gradient functions are defined, we simply push the name of the function to the `func_names` array which sits in the global namespace. 630 | 631 | ```cpp 632 | @macro gen_functions("formula.json"); 633 | 634 | std::pair eval(const char* name, vec3_t v) { 635 | @meta for(const std::string& f : func_names) { 636 | if(!strcmp(name, @string(f))) { 637 | return { 638 | @("f_" + f)(v), 639 | @("grad_" + f)(v) 640 | }; 641 | } 642 | } 643 | 644 | printf("Unknown function %s\n", name); 645 | exit(1); 646 | } 647 | 648 | void print_usage() { 649 | printf(" Usage: grad2 name x y z\n"); 650 | exit(1); 651 | } 652 | 653 | int main(int argc, char** argv) { 654 | if(5 != argc) 655 | print_usage(); 656 | 657 | const char* f = argv[1]; 658 | double x = atof(argv[2]); 659 | double y = atof(argv[3]); 660 | double z = atof(argv[4]); 661 | vec3_t v { x, y, z }; 662 | 663 | auto result = eval(f, v); 664 | double val = result.first; 665 | vec3_t grad = result.second; 666 | 667 | printf(" f: %f\n", val); 668 | printf(" grad: { %f, %f, %f }\n", grad.x, grad.y, grad.z); 669 | 670 | return 0; 671 | } 672 | ``` 673 | 674 | In the second half of the program we generate a driver capability. `eval` is a normal runtime function that takes the name of the function as a string and the primary inputs as a `vec3_t`. We then loop over all the function names pushed in `gen_functions`. The Circle extension `@string` converts a compile-time `std::string` back into a string literal, so we can `strcmp` it against our runtime function name. 
If we have a match, we'll evaluate the function and its gradient and return the results in an `std::pair`. 675 | 676 | Note that we have to expand `gen_functions` over the input file `formula.json` _prior_ to entering the definition for `eval`, because `gen_functions` populates the `func_names` array. If we were to change the order of operations here, `func_names` would be empty when `eval` is translated, resulting in a broken driver program. 677 | 678 | ``` 679 | $ circle grad2.cxx -I ../include/ -M ../Debug/libapex.so 680 | Injecting 'F1' : 'sin(x / y + z) / sq(x + y + z)' from formula.json 681 | Injecting 'F2' : 'tanh(sin(x) * exp(y / z))' from formula.json 682 | 683 | $ ./grad2 F1 .5 .6. .7 684 | f: 0.308425 685 | grad: { 0.361961, 0.358750, 0.354255 } 686 | ``` 687 | 688 | ## More build options 689 | 690 | We've built both diagnostics and a client program into our translation unit, driven by a simple portable JSON resource file sitting in the source directory. 691 | 692 | With a one-line change we can turn `grad2` into `grad3`, a command-line tool that gets pointed at a JSON configuration file and generates code from that. How do we pass arguments through the Circle compiler to the translation unit? Preprocessor macros! We can still get some use from them: 693 | 694 | [**grad3.cxx**](grad3.cxx) 695 | ```cpp 696 | @macro gen_functions(INPUT_FILENAME); 697 | ``` 698 | 699 | Instead of expanding the `gen_functions` statement macro on "formula.json", let's expand it on a macro defined to a string literal. How do we build? Just use the -D command-line argument to bind macros. Keep in mind the escaped quotes \" to satisfy the requirements of the Linux terminal: 700 | 701 | ``` 702 | $ circle grad3.cxx -I ../include/ -M ../Debug/libapex.so -DINPUT_FILENAME=\"formula2.json\" 703 | Injecting 'F3' : 'exp(x * x) + y / z' from formula2.json 704 | Injecting 'F4' : 'sqrt(sq(x) + sq(y) + sq(z))' from formula2.json 705 | 706 | $ ./grad3 F4 .4 .5 .6 707 | f: 0.877496 708 | grad: { 0.455842, 0.569803, 0.683763 } 709 | ``` 710 | 711 | This is a cute change. We've just separated the resource from the source code of the program. 712 | 713 | But can we really exploit the build-system qualities of the compiler? What about having the translation unit scrape all the JSON files in the working directory, and generating code from all the functions found in those? 714 | 715 | [**grad4.cxx**](grad4.cxx) 716 | ```cpp 717 | #include 718 | #include 719 | #include 720 | #include 721 | #include 722 | #include 723 | 724 | inline std::string get_extension(const std::string& filename) { 725 | return filename.substr(filename.find_last_of(".") + 1); 726 | } 727 | 728 | inline bool match_extension(const char* filename, const char* ext) { 729 | return ext == get_extension(filename); 730 | } 731 | 732 | // ... Snipped out the same gen_functions macro as grad2.cxx. 733 | 734 | // Use Circle like a build system: 735 | 736 | // Open the current directory. 737 | @meta DIR* dir = opendir("."); 738 | 739 | // Loop over all files in the current directory. 740 | @meta while(dirent* ent = readdir(dir)) { 741 | 742 | // Match .json files. 743 | @meta if(match_extension(ent->d_name, "json")) { 744 | 745 | // Generate functions for all entries in this json file. 
746 | @macro gen_functions(ent->d_name); 747 | } 748 | } 749 | 750 | @meta closedir(dir); 751 | ``` 752 | 753 | ``` 754 | $ circle grad4.cxx -I ../include -M ../Debug/libapex.so 755 | Injecting 'F1' : 'sin(x / y + z) / sq(x + y + z)' from formula.json 756 | Injecting 'F2' : 'tanh(sin(x) * exp(y / z))' from formula.json 757 | Injecting 'F3' : 'exp(x * x) + y / z' from formula2.json 758 | Injecting 'F4' : 'sqrt(sq(x) + sq(y) + sq(z))' from formula2.json 759 | 760 | $ ./grad4 F3 .2 .3 .4 761 | f: 1.790811 762 | grad: { 0.416324, 2.500000, 1.875000 } 763 | ``` 764 | 765 | POSIX systems put their filesystem APIs in `dirent.h`. Let's include this and write our own function to test the extension of a filename, a simple operation curiously missing from the POSIX API. 766 | 767 | Now, _at compile time_, use `opendir` to open the current directory. Keep hitting `readdir` to return a descriptor for the next file in the directory. If the filename ends with ".json", expand the `gen_functions` macro on the JSON's filename. When we've exhausted the directory, be a good citizen and `closedir` the directory handle. 768 | 769 | Even though we're nested in a while and an if statement, the _real scope_ of the program at the site of the macro expansion is the global namespace, so the macro will inject its generated functions in the global namespace. 770 | 771 | One of the biggest strengths of using Circle as a build system over trying to express equivalent operations in CMake or Make or any other tool is _familiarity_. Even if you didn't know the _dirent.h_ API, you can look it up and in a minute or two know exactly how to use it. Perusing the CMake reference for a directory enumeration command isn't the same help, because you're still stuck trying to express yourself in a language you probably don't know very well (and a language that is much less expressive than C++). 772 | 773 | The design goal of Circle is to extend the _compiler_ to let you do much more with the same _language_. 
-------------------------------------------------------------------------------- /examples/autodiff_errors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanbaxter/apex/f07a92857efd0d7a23d174227b8154f4cbaf01b6/examples/autodiff_errors.png -------------------------------------------------------------------------------- /examples/formula.json: -------------------------------------------------------------------------------- 1 | { 2 | "F1" : "sin(x / y + z) / sq(x + y + z)", 3 | "F2" : "tanh(sin(x) * exp(y / z))" 4 | } -------------------------------------------------------------------------------- /examples/formula2.json: -------------------------------------------------------------------------------- 1 | { 2 | "F3" : "exp(x * x) + y / z", 3 | "F4" : "sqrt(sq(x) + sq(y) + sq(z))" 4 | } -------------------------------------------------------------------------------- /examples/grad1.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | struct terms_t { 4 | double x; 5 | double y; 6 | }; 7 | 8 | terms_t my_grad(terms_t input) { 9 | return apex::autodiff_grad("sq(x / y) * sin(x * y)", input); 10 | } 11 | 12 | int main() { 13 | terms_t grad = my_grad( { .3, .5 } ); 14 | printf("%f %f\n", grad.x, grad.y); 15 | return 0; 16 | } -------------------------------------------------------------------------------- /examples/grad2.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // Parse the JSON file and keep it open in j. 8 | using nlohmann::json; 9 | using apex::sq; 10 | 11 | struct vec3_t { 12 | double x, y, z; 13 | }; 14 | 15 | // Record the function names encountered in here! 16 | @meta std::vector func_names; 17 | 18 | @macro void gen_functions(const char* filename) { 19 | // Open this file at compile time and parse as JSON. 20 | @meta std::ifstream json_file(filename); 21 | @meta json j; 22 | @meta json_file>> j; 23 | 24 | @meta for(auto& item : j.items()) { 25 | // For each item in the json... 26 | @meta std::string name = item.key(); 27 | @meta std::string f = item.value(); 28 | @meta std::cout<< "Injecting '"<< name<< "' : '"<< f<< "' from "<< 29 | filename<< "\n"; 30 | 31 | // Generate a function from the expression. 32 | extern "C" double @("f_" + name)(vec3_t v) { 33 | double x = v.x, y = v.y, z = v.z; 34 | return @expression(f); 35 | } 36 | 37 | // Generate a function to return the gradient. 
38 | extern "C" vec3_t @("grad_" + name)(vec3_t v) { 39 | return apex::autodiff_grad(f.c_str(), v); 40 | } 41 | 42 | @meta func_names.push_back(name); 43 | } 44 | } 45 | 46 | @macro gen_functions("formula.json"); 47 | 48 | std::pair eval(const char* name, vec3_t v) { 49 | @meta for(const std::string& f : func_names) { 50 | if(!strcmp(name, @string(f))) { 51 | return { 52 | @("f_" + f)(v), 53 | @("grad_" + f)(v) 54 | }; 55 | } 56 | } 57 | 58 | printf("Unknown function %s\n", name); 59 | exit(1); 60 | } 61 | 62 | void print_usage() { 63 | printf(" Usage: grad2 name x y z\n"); 64 | exit(1); 65 | } 66 | 67 | int main(int argc, char** argv) { 68 | if(5 != argc) 69 | print_usage(); 70 | 71 | const char* f = argv[1]; 72 | double x = atof(argv[2]); 73 | double y = atof(argv[3]); 74 | double z = atof(argv[4]); 75 | vec3_t v { x, y, z }; 76 | 77 | auto result = eval(f, v); 78 | double val = result.first; 79 | vec3_t grad = result.second; 80 | 81 | printf(" f: %f\n", val); 82 | printf(" grad: { %f, %f, %f }\n", grad.x, grad.y, grad.z); 83 | 84 | return 0; 85 | } -------------------------------------------------------------------------------- /examples/grad3.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // Parse the JSON file and keep it open in j. 8 | using nlohmann::json; 9 | using apex::sq; 10 | 11 | struct vec3_t { 12 | double x, y, z; 13 | }; 14 | 15 | // Record the function names encountered in here! 16 | @meta std::vector func_names; 17 | 18 | @macro void gen_functions(const char* filename) { 19 | // Open this file at compile time and parse as JSON. 20 | @meta std::ifstream json_file(filename); 21 | @meta json j; 22 | @meta json_file>> j; 23 | 24 | @meta for(auto& item : j.items()) { 25 | // For each item in the json... 26 | @meta std::string name = item.key(); 27 | @meta std::string f = item.value(); 28 | @meta std::cout<< "Injecting '"<< name<< "' : '"<< f<< "' from "<< 29 | filename<< "\n"; 30 | 31 | // Generate a function from the expression. 32 | extern "C" double @("f_" + name)(vec3_t v) { 33 | double x = v.x, y = v.y, z = v.z; 34 | return @expression(f); 35 | } 36 | 37 | // Generate a function to return the gradient. 
38 | extern "C" vec3_t @("grad_" + name)(vec3_t v) { 39 | return apex::autodiff_grad(f.c_str(), v); 40 | } 41 | 42 | @meta func_names.push_back(name); 43 | } 44 | } 45 | 46 | @macro gen_functions(INPUT_FILENAME); 47 | 48 | std::pair eval(const char* name, vec3_t v) { 49 | @meta for(const std::string& f : func_names) { 50 | if(!strcmp(name, @string(f))) { 51 | return { 52 | @("f_" + f)(v), 53 | @("grad_" + f)(v) 54 | }; 55 | } 56 | } 57 | 58 | printf("Unknown function %s\n", name); 59 | exit(1); 60 | } 61 | 62 | void print_usage() { 63 | printf(" Usage: grad3 name x y z\n"); 64 | exit(1); 65 | } 66 | 67 | int main(int argc, char** argv) { 68 | if(5 != argc) 69 | print_usage(); 70 | 71 | const char* f = argv[1]; 72 | double x = atof(argv[2]); 73 | double y = atof(argv[3]); 74 | double z = atof(argv[4]); 75 | vec3_t v { x, y, z }; 76 | 77 | auto result = eval(f, v); 78 | double val = result.first; 79 | vec3_t grad = result.second; 80 | 81 | printf(" f: %f\n", val); 82 | printf(" grad: { %f, %f, %f }\n", grad.x, grad.y, grad.z); 83 | return 0; 84 | } 85 | -------------------------------------------------------------------------------- /examples/grad4.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | inline std::string get_extension(const std::string& filename) { 9 | return filename.substr(filename.find_last_of(".") + 1); 10 | } 11 | 12 | inline bool match_extension(const char* filename, const char* ext) { 13 | return ext == get_extension(filename); 14 | } 15 | 16 | // Parse the JSON file and keep it open in j. 17 | using nlohmann::json; 18 | using apex::sq; 19 | 20 | struct vec3_t { 21 | double x, y, z; 22 | }; 23 | 24 | // Record the function names encountered in here! 25 | @meta std::vector func_names; 26 | 27 | @macro void gen_functions(const char* filename) { 28 | // Open this file at compile time and parse as JSON. 29 | @meta std::ifstream json_file(filename); 30 | @meta json j; 31 | @meta json_file>> j; 32 | 33 | @meta for(auto& item : j.items()) { 34 | // For each item in the json... 35 | @meta std::string name = item.key(); 36 | @meta std::string f = item.value(); 37 | @meta std::cout<< "Injecting '"<< name<< "' : '"<< f<< "' from "<< 38 | filename<< "\n"; 39 | 40 | // Generate a function from the expression. 41 | extern "C" double @("f_" + name)(vec3_t v) { 42 | double x = v.x, y = v.y, z = v.z; 43 | return @expression(f); 44 | } 45 | 46 | // Generate a function to return the gradient. 47 | extern "C" vec3_t @("grad_" + name)(vec3_t v) { 48 | return apex::autodiff_grad(f.c_str(), v); 49 | } 50 | 51 | @meta func_names.push_back(name); 52 | } 53 | } 54 | 55 | // Use Circle like a build system: 56 | 57 | // Open the current directory. 58 | @meta DIR* dir = opendir("."); 59 | 60 | // Loop over all files in the current directory. 61 | @meta while(dirent* ent = readdir(dir)) { 62 | 63 | // Match .json files. 64 | @meta if(match_extension(ent->d_name, "json")) { 65 | 66 | // Generate functions for all entries in this json file. 
67 | @macro gen_functions(ent->d_name); 68 | } 69 | } 70 | 71 | @meta closedir(dir); 72 | 73 | std::pair eval(const char* name, vec3_t v) { 74 | @meta for(const std::string& f : func_names) { 75 | if(!strcmp(name, @string(f))) { 76 | return { 77 | @("f_" + f)(v), 78 | @("grad_" + f)(v) 79 | }; 80 | } 81 | } 82 | 83 | printf("Unknown function %s\n", name); 84 | exit(1); 85 | } 86 | 87 | void print_usage() { 88 | printf(" Usage: grad3 name x y z\n"); 89 | exit(1); 90 | } 91 | 92 | int main(int argc, char** argv) { 93 | if(5 != argc) 94 | print_usage(); 95 | 96 | const char* f = argv[1]; 97 | double x = atof(argv[2]); 98 | double y = atof(argv[3]); 99 | double z = atof(argv[4]); 100 | vec3_t v { x, y, z }; 101 | 102 | auto result = eval(f, v); 103 | double val = result.first; 104 | vec3_t grad = result.second; 105 | 106 | printf(" f: %f\n", val); 107 | printf(" grad: { %f, %f, %f }\n", grad.x, grad.y, grad.z); 108 | return 0; 109 | } 110 | -------------------------------------------------------------------------------- /examples/test_errors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanbaxter/apex/f07a92857efd0d7a23d174227b8154f4cbaf01b6/examples/test_errors.png -------------------------------------------------------------------------------- /include/apex/autodiff.hxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | struct ad_exeption_t : std::runtime_error { 7 | ad_exeption_t(const std::string& err) : std::runtime_error(err) { } 8 | }; 9 | 10 | struct ad_t { 11 | enum kind_t { 12 | kind_tape, 13 | kind_component, 14 | kind_literal, 15 | kind_unary, 16 | kind_binary, 17 | kind_func 18 | }; 19 | kind_t kind; 20 | 21 | ad_t(kind_t kind) : kind(kind) { } 22 | 23 | template 24 | derived_t* as() { 25 | return derived_t::classof(this) ? 26 | static_cast(this) : 27 | nullptr; 28 | } 29 | 30 | template 31 | const derived_t* as() const { 32 | return derived_t::classof(this) ? 
33 | static_cast(this) : 34 | nullptr; 35 | } 36 | }; 37 | typedef std::unique_ptr ad_ptr_t; 38 | 39 | struct ad_tape_t : ad_t { 40 | ad_tape_t(int index) : ad_t(kind_tape), index(index) { } 41 | static bool classof(const ad_t* ad) { return kind_tape == ad->kind; } 42 | 43 | int index; 44 | }; 45 | 46 | struct ad_component_t : ad_t { 47 | ad_component_t(int index) : ad_t(kind_component), index(index) { } 48 | static bool classof(const ad_t* ad) { return kind_component == ad->kind; } 49 | 50 | int index; 51 | }; 52 | 53 | struct ad_literal_t : ad_t { 54 | ad_literal_t(double x) : ad_t(kind_literal), x(x) { } 55 | static bool classof(const ad_t* ad) { return kind_literal == ad->kind; } 56 | 57 | double x; 58 | }; 59 | 60 | struct ad_unary_t : ad_t { 61 | ad_unary_t(const char* op, ad_ptr_t a) : 62 | ad_t(kind_unary), op(op), a(std::move(a)) { } 63 | static bool classof(const ad_t* ad) { return kind_unary == ad->kind; } 64 | 65 | const char* op; 66 | ad_ptr_t a; 67 | }; 68 | 69 | struct ad_binary_t : ad_t { 70 | ad_binary_t(const char* op, ad_ptr_t a, ad_ptr_t b) : 71 | ad_t(kind_binary), op(op), a(std::move(a)), b(std::move(b)) { } 72 | static bool classof(const ad_t* ad) { return kind_binary == ad->kind; } 73 | 74 | const char* op; 75 | ad_ptr_t a, b; 76 | }; 77 | 78 | struct ad_func_t : ad_t { 79 | ad_func_t(std::string f) : ad_t(kind_func), f(std::move(f)) { } 80 | static bool classof(const ad_t* ad) { return kind_func == ad->kind; } 81 | 82 | std::string f; 83 | std::vector args; 84 | }; 85 | 86 | // Each primary input may be a scalar (dim 0) or a vector (dim > 0). 87 | // autodiff_codegen.hxx uses introspection to parse these out of 88 | // the argument type. 89 | struct autodiff_var_t { 90 | std::string name; 91 | int dim; 92 | }; 93 | 94 | struct autodiff_t { 95 | struct item_t { 96 | // The dimension of the tape item. 97 | // 0 == dim for scalar. dim > 0 for vector. 98 | int dim; 99 | 100 | // The expression to execute to compute this dependent variable's value. 101 | // This is evaluated during the upsweep when creating the tape from the 102 | // independent variables and moving through all subexpressions. 103 | ad_ptr_t val; 104 | 105 | // When updating the gradient of the parent, this tape item loops over each 106 | // of its dependent variables and performs a chain rule increment. 107 | // It calls grad(index, coef) on each index. This recurses, down to the 108 | // independent vars, multiplying in the coef at each recurse. 109 | 110 | // When we hit an independent var, the grads array is empty (although it 111 | // may be empty otherwise) and we simply perform += coef into the slot 112 | // corresponding to the independent variable in the gradient array. 113 | struct grad_t { 114 | int index; 115 | ad_ptr_t coef; 116 | }; 117 | std::vector grads; 118 | }; 119 | 120 | // The first var_names.size() items encode independent variables. 
121 | std::vector vars; 122 | std::vector tape; 123 | }; 124 | 125 | autodiff_t make_autodiff(const std::string& formula, 126 | const std::vector& vars); 127 | 128 | std::string print_ad(const ad_t* ad, int indent = 0); 129 | std::string print_autodiff(const autodiff_t& autodiff); 130 | 131 | END_APEX_NAMESPACE 132 | -------------------------------------------------------------------------------- /include/apex/autodiff_codegen.hxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | BEGIN_APEX_NAMESPACE 6 | 7 | // STL doesn't include a sq function, but it's really useful, because it 8 | // memoizes the argument, so we don't have to evaluate its subexpression twice 9 | // to take its square. 10 | inline double sq(double x) { 11 | return x * x; 12 | } 13 | 14 | @macro auto autodiff_expr(const ad_t* ad) { 15 | @meta+ if(const auto* tape = ad->as()) { 16 | @emit return tape_values[tape->index]; 17 | 18 | } else if(const auto* literal = ad->as()) { 19 | @emit return literal->x; 20 | 21 | } else if(const auto* unary = ad->as()) { 22 | @emit return @op( 23 | unary->op, 24 | autodiff_expr(unary->a.get()) 25 | ); 26 | 27 | } else if(const auto* binary = ad->as()) { 28 | @emit return @op( 29 | binary->op, 30 | autodiff_expr(binary->a.get()), 31 | autodiff_expr(binary->b.get()) 32 | ); 33 | 34 | } else if(const auto* func = ad->as()) { 35 | // Evaluate and expand the arguments parameter pack. 36 | 37 | // TODO: Can't currently expand a parameter pack through a macro 38 | // invocation. Why not? Because we expand the macro immediately so we 39 | // can learn its return type, which we need to continue parsing the 40 | // expression. But if the pack expansion is outside of the macro expansion, 41 | // we'll need to expand the macro prior to even parsing the expansion. 42 | 43 | // Two paths forward: 44 | // 1) Speculatively expand the macro expansion. 45 | // 2) Make the macro expansion a dependent expression when its passed an 46 | // unexpanded pack argument. Parse through to the end of the statement, 47 | // then perform substitution and expansion. This is probably what should 48 | // happen. 49 | 50 | // That feature will eliminate the need to switch over the 51 | // argument counts. 52 | // @emit return @expression(func->f)( 53 | // autodiff_expr(func->args[__integer_pack(func->args.size())].get())... 54 | // ); 55 | 56 | if(1 == func->args.size()) { 57 | @emit return @expression(func->f)(autodiff_expr(func->args[0].get())); 58 | 59 | } else if(2 == func->args.size()) { 60 | @emit return @expression(func->f)(autodiff_expr(func->args[0].get()), 61 | autodiff_expr(func->args[1].get())); 62 | } 63 | } 64 | } 65 | 66 | @macro void autodiff_tape(int index, int parent) { 67 | @meta if(index < num_vars) { 68 | // We've hit a terminal, which corresponds to an independent variable. 69 | // Increment the gradient array by the coefficient at parent. 70 | @member_ref(grad, index) += coef[parent]; 71 | 72 | } else { 73 | // We're in a subexpression. Evaluate each of the child nodes. 74 | @meta for(const auto& g : autodiff.tape[index].grads) { 75 | // Evaluate the coefficient into the stack. 76 | coef[index] = coef[parent] * autodiff_expr(g.coef.get()); 77 | @macro autodiff_tape(g.index, index); 78 | } 79 | } 80 | } 81 | 82 | template 83 | @meta type_t autodiff_grad(@meta const char* formula, type_t input) { 84 | 85 | // Parse out the names from the inputs. 
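  // For an argument type declaring, say, `double x; double y;`, the
  // introspection below (@member_count / @member_name) recovers the member
  // names "x" and "y" as the independent-variable names handed to
  // make_autodiff.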
86 | static_assert(std::is_class::value, 87 | "argument to autodiff_eval must be a class object"); 88 | 89 | // Collect the name of each primary input. 90 | @meta std::vector vars; 91 | @meta size_t num_vars = @member_count(type_t); 92 | 93 | // Copy the values of the independent variables into the tape. 94 | double tape_values[num_vars]; 95 | @meta for(int i = 0; i < num_vars; ++i) { 96 | // Confirm that we have a scalar double-precision term. 97 | static_assert(std::is_same<@member_type(type_t, i), double>::value, 98 | std::string("member ") + @member_name(type_t, i) + " must be type double"); 99 | 100 | // Push the primary input name. 101 | @meta vars.push_back({ 102 | @member_name(type_t, i), 103 | 0 104 | }); 105 | 106 | // Set the primary input's value. 107 | tape_values[i] = @member_ref(input, i); 108 | } 109 | 110 | // Construct the tape. This makes a foreign function call into libapex.so. 111 | @meta apex::autodiff_t autodiff = apex::make_autodiff(formula, vars); 112 | @meta size_t count = autodiff.tape.size(); 113 | 114 | // Compute the values for the whole tape. This is the forward-mode pass. 115 | // It propagates values from the terminals (independent variables) through 116 | // the subexpressions and up to the root of the function. 117 | 118 | // Evaluate the subexpressions. 119 | @meta for(size_t i = num_vars; i < count; ++i) 120 | tape_values[i] = autodiff_expr(autodiff.tape[i].val.get()); 121 | 122 | // Evaluate the gradients. This is a top-down reverse-mode traversal of 123 | // the autodiff DAG. The partial derivatives are parsed along edges, starting 124 | // from the root and towards each terminal. When a terminal is visited, the 125 | // corresponding component of the gradient is incremented by the product of 126 | // all the partial derivatives from the root of the DAG down to that 127 | // terminal. 128 | double coef[num_vars]; 129 | type_t grad { }; 130 | 131 | // Visit each child of the root node. 132 | @meta int root = count - 1; 133 | @meta for(const auto& g : autodiff.tape[root].grads) { 134 | // Evaluate the coefficient into the stack. 135 | coef[root] = autodiff_expr(g.coef.get()); 136 | 137 | // Recurse on the child. 138 | @macro autodiff_tape(g.index, root); 139 | } 140 | 141 | return std::move(grad); 142 | } 143 | 144 | END_APEX_NAMESPACE -------------------------------------------------------------------------------- /include/apex/parse.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | BEGIN_APEX_NAMESPACE 6 | 7 | namespace parse { 8 | 9 | struct range_t { 10 | token_it begin, end; 11 | explicit operator bool() const { return begin < end; } 12 | 13 | token_t peek() const { 14 | return (begin < end) ? *begin : token_t { }; 15 | } 16 | token_t next() { 17 | return (begin < end) ? *begin++ : token_t { }; 18 | } 19 | token_t advance_if(tk_kind_t kind) { 20 | return (begin < end && kind == begin->kind) ? 
*begin++ : token_t { }; 21 | } 22 | 23 | void advance(token_it it) { 24 | begin = it; 25 | } 26 | void advance(range_t range) { 27 | begin = range.end; 28 | } 29 | template 30 | void advance(const type_t& result) { 31 | if(result) 32 | advance(result->range.end); 33 | } 34 | }; 35 | 36 | template 37 | using result_t = result_template_t; 38 | 39 | template 40 | result_t make_result(range_t range, attr_t attr = { }) { 41 | return { range, std::move(attr) }; 42 | } 43 | 44 | template 45 | result_t make_result(token_it begin, token_it end, attr_t attr = { }) { 46 | return make_result(range_t { begin, end }, std::move(attr)); 47 | } 48 | 49 | struct parse_exception_t : std::runtime_error { 50 | parse_exception_t(const std::string& err) : std::runtime_error(err) { } 51 | }; 52 | 53 | struct node_t { 54 | enum kind_t { 55 | kind_ident, 56 | kind_unary, 57 | kind_binary, 58 | kind_assign, 59 | kind_ternary, 60 | kind_call, 61 | kind_char, 62 | kind_string, 63 | kind_number, 64 | kind_bool, 65 | kind_subscript, 66 | kind_member, 67 | kind_braced, 68 | }; 69 | 70 | kind_t kind; 71 | source_loc_t loc; 72 | 73 | node_t(kind_t kind, source_loc_t loc) : kind(kind), loc(loc) { } 74 | virtual ~node_t() { } 75 | 76 | template 77 | derived_t* as() { 78 | return derived_t::classof(this) ? 79 | static_cast(this) : 80 | nullptr; 81 | } 82 | 83 | template 84 | const derived_t* as() const { 85 | return derived_t::classof(this) ? 86 | static_cast(this) : 87 | nullptr; 88 | } 89 | }; 90 | typedef std::unique_ptr node_ptr_t; 91 | typedef std::vector node_list_t; 92 | 93 | struct parse_t { 94 | tok::tokenizer_t tokenizer; 95 | node_ptr_t root; 96 | }; 97 | 98 | parse_t parse_expression(const char* str); 99 | 100 | //////////////////////////////////////////////////////////////////////////////// 101 | 102 | struct node_ident_t : node_t { 103 | node_ident_t(source_loc_t loc) : node_t(kind_ident, loc) { } 104 | static bool classof(const node_t* p) { return kind_ident == p->kind; } 105 | 106 | std::string s; 107 | }; 108 | 109 | struct node_unary_t : node_t { 110 | node_unary_t(source_loc_t loc) : node_t(kind_unary, loc) { } 111 | static bool classof(const node_t* p) { return kind_unary == p->kind; } 112 | 113 | expr_op_t op; 114 | node_ptr_t a; 115 | }; 116 | 117 | struct node_binary_t : node_t { 118 | node_binary_t(source_loc_t loc) : node_t(kind_binary, loc) { } 119 | static bool classof(const node_t* p) { return kind_binary == p->kind; } 120 | 121 | expr_op_t op; 122 | node_ptr_t a, b; 123 | }; 124 | 125 | struct node_assign_t : node_t { 126 | node_assign_t(source_loc_t loc) : node_t(kind_assign, loc) { } 127 | static bool classof(const node_t* p) { return kind_assign == p->kind; } 128 | 129 | expr_op_t op; 130 | node_ptr_t a, b; 131 | }; 132 | 133 | struct node_ternary_t : node_t { 134 | node_ternary_t(source_loc_t loc) : node_t(kind_ternary, loc) { } 135 | static bool classof(const node_t* p) { return kind_ternary == p->kind; } 136 | 137 | node_ptr_t a, b, c; 138 | }; 139 | 140 | struct node_char_t : node_t { 141 | node_char_t(char32_t c, source_loc_t loc) : node_t(kind_char, loc), c(c) { } 142 | static bool classof(const node_t* p) { return kind_char == p->kind; } 143 | 144 | // UCS code for character. Caller should use UTF to_utf8 to convert back 145 | // to a UTF-8 string. 
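  // Illustrative conversion back to UTF-8:
  //   char buf[8]; int n = to_utf8(buf, c);  // n characters encoded, 0 on error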
146 | char32_t c; 147 | }; 148 | 149 | struct node_string_t : node_t { 150 | node_string_t(std::string s, source_loc_t loc) : 151 | node_t(kind_string, loc), s(std::move(s)) { } 152 | static bool classof(const node_t* p) { return kind_string == p->kind; } 153 | 154 | std::string s; 155 | }; 156 | 157 | struct node_bool_t : node_t { 158 | node_bool_t(bool b, source_loc_t loc) : node_t(kind_bool, loc), b(b) { } 159 | static bool classof(const node_t* p) { return kind_bool == p->kind; } 160 | 161 | bool b; 162 | }; 163 | 164 | struct node_number_t : node_t { 165 | node_number_t(number_t number, source_loc_t loc) : 166 | node_t(kind_number, loc), x(number) { } 167 | static bool classof(const node_t* p) { return kind_number == p->kind; } 168 | 169 | number_t x; 170 | }; 171 | 172 | struct node_call_t : node_t { 173 | node_call_t(source_loc_t loc) : node_t(kind_call, loc) { } 174 | static bool classof(const node_t* p) { return kind_call == p->kind; } 175 | 176 | node_ptr_t f; 177 | std::vector args; 178 | }; 179 | 180 | struct node_subscript_t : node_t { 181 | node_subscript_t(source_loc_t loc) : node_t(kind_subscript, loc) { } 182 | static bool classof(const node_t* p) { return kind_subscript == p->kind; } 183 | 184 | node_ptr_t lhs; 185 | std::vector args; 186 | }; 187 | 188 | struct node_member_t : node_t { 189 | node_member_t(source_loc_t loc) : node_t(kind_member, loc) { } 190 | static bool classof(const node_t* p) { return kind_member == p->kind; } 191 | 192 | tk_kind_t tk; // dot or arrow 193 | node_ptr_t lhs; 194 | std::string member; 195 | }; 196 | 197 | struct node_braced_t : node_t { 198 | node_braced_t(source_loc_t loc) : node_t(kind_braced, loc) { } 199 | static bool classof(const node_t* p) { return kind_braced == p->kind; } 200 | 201 | std::vector args; 202 | }; 203 | 204 | } // namespace parse 205 | 206 | 207 | END_APEX_NAMESPACE 208 | -------------------------------------------------------------------------------- /include/apex/tokenizer.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | namespace parse { 7 | 8 | struct range_t; 9 | 10 | } 11 | 12 | namespace tok { 13 | 14 | struct range_t { 15 | const char* begin, *end; 16 | explicit operator bool() const { 17 | return begin < end; 18 | } 19 | void advance(const char* p) { 20 | begin = p; 21 | } 22 | char advance_if(char c) { 23 | return (begin < end && *begin == c) ? *begin++ : 0; 24 | } 25 | template 26 | char advance_if(func_t f) { 27 | return (begin < end && f(*begin)) ? *begin++ : 0; 28 | } 29 | template 30 | void advance(const type_t& result) { 31 | if(result) 32 | advance(result->range.end); 33 | } 34 | 35 | char operator[](ptrdiff_t index) { 36 | return (begin + index < end) ? begin[index] : 0; 37 | } 38 | 39 | char peek() const { 40 | return (begin < end) ? *begin : 0; 41 | } 42 | char next() { 43 | return (begin < end) ? 
*begin++ : 0; 44 | } 45 | 46 | bool match(const char* s) const { 47 | const char* p = begin; 48 | while(*s && p < end && *p++ == *s) ++s; 49 | return !*s; 50 | } 51 | 52 | bool match_advance(const char* s) { 53 | const char* p = begin; 54 | while(*s && p < end && *p++ == *s) ++s; 55 | bool success = !*s; 56 | if(success) begin = p; 57 | return success; 58 | } 59 | }; 60 | 61 | template 62 | using result_t = result_template_t; 63 | 64 | template 65 | result_t make_result(range_t range, attr_t attr = { }) { 66 | return { range, std::move(attr) }; 67 | } 68 | 69 | template 70 | result_t make_result(const char* begin, const char* end, 71 | attr_t attr = { }) { 72 | return make_result(range_t { begin, end }, std::move(attr)); 73 | } 74 | 75 | // operators.cxx. Match the longest operator. 76 | result_t match_operator(range_t range); 77 | 78 | struct tokenizer_t; 79 | 80 | struct lexer_t { 81 | lexer_t(tokenizer_t& tokenizer) : tokenizer(tokenizer) { } 82 | 83 | result_t char_literal(range_t range); 84 | result_t c_char(range_t range); 85 | 86 | result_t string_literal(range_t range); 87 | result_t s_char(range_t range); 88 | 89 | // Match a-zA-Z or a UCS. If digit is true, also match a digit. 90 | result_t identifier_char(range_t range, bool digit); 91 | 92 | // Read an extended character. 93 | result_t ucs(range_t range); 94 | 95 | // Read a character sequence matching any number. 96 | // This conforms to the C++17 definition pp-number. 97 | result_t pp_number(range_t range); 98 | result_t decimal_sequence(range_t range); 99 | result_t decimal_number(range_t range); 100 | result_t exponent_part(range_t range); 101 | 102 | result_t integer_literal(range_t range); 103 | result_t floating_point_literal(range_t range); 104 | result_t number(range_t range); 105 | 106 | result_t literal(range_t range); 107 | result_t identifier(range_t range); 108 | result_t operator_(range_t range); 109 | result_t token(range_t range); 110 | 111 | const char* skip_comment(range_t range); 112 | bool advance_skip(range_t& range); 113 | 114 | 115 | void throw_error(const char* pos, const char* msg); 116 | 117 | tokenizer_t& tokenizer; 118 | }; 119 | 120 | struct tokenizer_t { 121 | std::vector strings; 122 | std::vector ints; 123 | std::vector floats; 124 | 125 | // Byte offset for each line start. 126 | std::vector line_offsets; 127 | 128 | // Original text we tokenized. 129 | std::string text; 130 | 131 | // The text divided into tokens. 132 | std::vector tokens; 133 | 134 | parse::range_t token_range() const; 135 | 136 | int reg_string(range_t range); 137 | int find_string(range_t range) const; 138 | 139 | // Return 0-indexed line and column offsets for the token at 140 | // the specified byte offset. This performs UCS decoding to support 141 | // multibyte characters. 
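  // Typical use (as in ad_builder_t::throw_error):
  //   auto lc = tokenizer.token_linecol(loc);            // 0-indexed pair
  //   printf("line %d col %d\n", lc.first + 1, lc.second + 1);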
142 | int token_offset(source_loc_t loc) const; 143 | int token_line(int offset) const; 144 | int token_col(int offset, int line) const; 145 | std::pair token_linecol(int offset) const; 146 | std::pair token_linecol(source_loc_t loc) const; 147 | 148 | void tokenize(); 149 | }; 150 | 151 | } // namespace tok 152 | 153 | END_APEX_NAMESPACE 154 | -------------------------------------------------------------------------------- /include/apex/tokens.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | enum tk_kind_t : uint8_t { 7 | tk_none = 0, 8 | tk_ident, 9 | tk_int, 10 | tk_float, 11 | tk_char, 12 | tk_string, 13 | tk_kw_false, 14 | tk_kw_true, 15 | tk_sym_amp, 16 | tk_sym_ampamp, 17 | tk_sym_ampeq, 18 | tk_sym_arrow, 19 | tk_sym_arrowstar, 20 | tk_sym_at, 21 | tk_sym_attrib_l, 22 | tk_sym_bang, 23 | tk_sym_bangeq, 24 | tk_sym_brace_l, 25 | tk_sym_brace_r, 26 | tk_sym_bracket_l, 27 | tk_sym_bracket_r, 28 | tk_sym_caret, 29 | tk_sym_careteq, 30 | tk_sym_chevron_l, 31 | tk_sym_chevron_r, 32 | tk_sym_col, 33 | tk_sym_colcol, 34 | tk_sym_comma, 35 | tk_sym_dot, 36 | tk_sym_dotstar, 37 | tk_sym_ellipsis, 38 | tk_sym_eq, 39 | tk_sym_eqeq, 40 | tk_sym_gt, 41 | tk_sym_gteq, 42 | tk_sym_gtgt, 43 | tk_sym_gtgteq, 44 | tk_sym_hash, 45 | tk_sym_hashhash, 46 | tk_sym_lt, 47 | tk_sym_lteq, 48 | tk_sym_ltlt, 49 | tk_sym_ltlteq, 50 | tk_sym_minus, 51 | tk_sym_minuseq, 52 | tk_sym_minusminus, 53 | tk_sym_paren_l, 54 | tk_sym_paren_r, 55 | tk_sym_percent, 56 | tk_sym_percenteq, 57 | tk_sym_pipe, 58 | tk_sym_pipeeq, 59 | tk_sym_pipepipe, 60 | tk_sym_plus, 61 | tk_sym_pluseq, 62 | tk_sym_plusplus, 63 | tk_sym_question, 64 | tk_sym_semi, 65 | tk_sym_slash, 66 | tk_sym_slasheq, 67 | tk_sym_star, 68 | tk_sym_stareq, 69 | tk_sym_tilde, 70 | }; 71 | 72 | struct token_t { 73 | tk_kind_t kind : 8; 74 | int store : 24; 75 | const char* begin, *end; 76 | 77 | operator tk_kind_t() const { return kind; } 78 | }; 79 | typedef const token_t* token_it; 80 | 81 | // Index of the token within the token stream. 82 | struct source_loc_t { 83 | int index; 84 | }; 85 | 86 | END_APEX_NAMESPACE 87 | -------------------------------------------------------------------------------- /include/apex/util.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define BEGIN_APEX_NAMESPACE namespace apex { 9 | #define END_APEX_NAMESPACE } 10 | 11 | BEGIN_APEX_NAMESPACE 12 | 13 | // Encodes ucs into the UTF-8 buffer at s. Returns the number of characters 14 | // encoded. 0 indicates error. 15 | int to_utf8(char* s, int ucs); 16 | 17 | // Returns the number of code-units consumed and the value of the character. 18 | // 0 indicates error. 19 | std::pair from_utf8(const char* s); 20 | 21 | // sprintf into a std::string 22 | std::string format(const char* pattern, ...); 23 | std::string vformat(const char* pattern, va_list args); 24 | 25 | //////////////////////////////////////////////////////////////////////////////// 26 | // Reusable types for result_t<> returns. Combines range and attribute. 
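// A rule returns a result that converts to bool to report success; on success,
// operator-> exposes the matched range and the attribute. A typical caller
// pattern (some_rule is a placeholder for any parse or lex rule):
//   if(auto r = some_rule(range)) { range.advance(r); use(std::move(r->attr)); }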
27 | 28 | template 29 | struct result_base_t { 30 | result_base_t() { } 31 | result_base_t(range_t range, attr_t attr) : 32 | range(range), attr(std::move(attr)), success(true) { } 33 | 34 | template 35 | result_base_t(result_base_t&& rhs) : 36 | range(rhs.range), attr(std::move(rhs.attr)), success(rhs.success) { } 37 | 38 | range_t range; 39 | attr_t attr; 40 | bool success = false; 41 | }; 42 | 43 | template 44 | class result_template_t : protected result_base_t { 45 | public: 46 | typedef result_base_t base_t; 47 | 48 | result_template_t() { } 49 | result_template_t(range_t range, attr_t attr) : 50 | base_t(range, std::move(attr)) { } 51 | 52 | template 53 | result_template_t(result_template_t&& rhs) : 54 | base_t(std::move(rhs.get_base())) { } 55 | 56 | explicit operator bool() const { return this->success; } 57 | 58 | base_t* operator->() { 59 | assert(this->success); 60 | return this; 61 | } 62 | 63 | const base_t* operator->() const { 64 | assert(this->success); 65 | return this; 66 | } 67 | 68 | // Ugly hack to let the move constructor upcast through a protected base 69 | // class. 70 | base_t& get_base() { return *this; } 71 | }; 72 | 73 | struct unused_t { }; 74 | 75 | 76 | END_APEX_NAMESPACE 77 | -------------------------------------------------------------------------------- /include/apex/value.hxx: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | enum number_kind_t { 7 | number_kind_none, 8 | number_kind_bool, 9 | number_kind_int, 10 | number_kind_float 11 | }; 12 | 13 | struct number_t { 14 | number_kind_t kind; 15 | union { 16 | bool b; 17 | int64_t i; 18 | double d; 19 | }; 20 | 21 | number_t() : kind(number_kind_none) { } 22 | number_t(bool b) : kind(number_kind_bool), b(b) { } 23 | number_t(int64_t i) : kind(number_kind_int), i(i) { } 24 | number_t(double d) : kind(number_kind_float), d(d) { } 25 | 26 | bool is_boolean() const { return number_kind_bool == kind; } 27 | bool is_integral() const { return number_kind_int == kind; } 28 | bool is_floating() const { return number_kind_float == kind; } 29 | bool is_arithmetic() const { return is_integral() || is_floating(); } 30 | 31 | explicit operator bool() const { 32 | return number_kind_none != kind; 33 | } 34 | 35 | number_t to_boolean() const; 36 | number_t to_integral() const; 37 | number_t to_floating() const; 38 | 39 | number_t to_kind(number_kind_t kind2) const; 40 | 41 | std::string to_string() const; 42 | 43 | 44 | template 45 | auto switch_numeric(func_t f) const { 46 | switch(kind) { 47 | case number_kind_int: 48 | return f(i); 49 | 50 | case number_kind_float: 51 | return f(d); 52 | } 53 | } 54 | 55 | template 56 | auto switch_all(func_t f) const { 57 | switch(kind) { 58 | case number_kind_bool: 59 | return f(b); 60 | 61 | case number_kind_int: 62 | return f(i); 63 | 64 | case number_kind_float: 65 | return f(d); 66 | } 67 | } 68 | 69 | template 70 | type_t convert() const { 71 | return switch_all([](auto x) { return (type_t)x; }); 72 | } 73 | }; 74 | 75 | number_kind_t common_arithmetic_kind(number_kind_t left, number_kind_t right); 76 | 77 | //////////////////////////////////////////////////////////////////////////////// 78 | 79 | enum expr_op_t : uint8_t { 80 | expr_op_none = 0, 81 | 82 | // postfix. 83 | expr_op_inc_post, 84 | expr_op_dec_post, 85 | 86 | // prefix. 
87 | expr_op_inc_pre, // ++x 88 | expr_op_dec_pre, // --x 89 | expr_op_complement, // ~x 90 | expr_op_negate, // !x 91 | expr_op_plus, // +x 92 | expr_op_minus, // -x 93 | expr_op_addressof, // &x 94 | expr_op_indirection, // *x 95 | 96 | // Right-associative binary operators. 97 | expr_op_ptrmem_dot, 98 | expr_op_ptrmem_arrow, 99 | 100 | // Left-associative operations. 101 | expr_op_mul, 102 | expr_op_div, 103 | expr_op_mod, 104 | expr_op_add, 105 | expr_op_sub, 106 | expr_op_shl, 107 | expr_op_shr, 108 | expr_op_lt, 109 | expr_op_gt, 110 | expr_op_lte, 111 | expr_op_gte, 112 | expr_op_eq, 113 | expr_op_ne, 114 | expr_op_bit_and, 115 | expr_op_bit_xor, 116 | expr_op_bit_or, 117 | expr_op_log_and, 118 | expr_op_log_or, 119 | 120 | // Right-associative operations. 121 | expr_op_assign, 122 | expr_op_assign_mul, 123 | expr_op_assign_div, 124 | expr_op_assign_mod, 125 | expr_op_assign_add, 126 | expr_op_assign_sub, 127 | expr_op_assign_shl, 128 | expr_op_assign_shr, 129 | expr_op_assign_and, 130 | expr_op_assign_or, 131 | expr_op_assign_xor, 132 | 133 | expr_op_ternary, 134 | expr_op_sequence, 135 | }; 136 | 137 | extern const char* expr_op_names[]; 138 | 139 | number_t value_unary(expr_op_t op, number_t value); 140 | number_t value_binary(expr_op_t op, number_t left, number_t right); 141 | 142 | END_APEX_NAMESPACE 143 | -------------------------------------------------------------------------------- /src/autodiff/autodiff.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | BEGIN_APEX_NAMESPACE 8 | 9 | using namespace parse; 10 | 11 | 12 | struct ad_builder_t : autodiff_t { 13 | typedef autodiff_t::item_t item_t; 14 | typedef autodiff_t::item_t::grad_t grad_t; 15 | 16 | int literal_node(double x); 17 | 18 | // Operators 19 | int add(int a, int b); 20 | int sub(int a, int b); 21 | int mul(int a, int b); 22 | int div(int a, int b); 23 | int negate(int a); 24 | 25 | // Calls to elementary functions. 26 | int sq(int a); 27 | int sqrt(int a); 28 | int exp(int a); 29 | int log(int a); 30 | int sin(int a); 31 | int cos(int a); 32 | int tan(int a); 33 | int sinh(int a); 34 | int cosh(int a); 35 | int tanh(int a); 36 | int abs(int a); 37 | int pow(int a, int b); 38 | int norm(const int* p, int count); 39 | 40 | ad_ptr_t val(int index); 41 | ad_ptr_t literal(double x); 42 | ad_ptr_t add(ad_ptr_t a, ad_ptr_t b); 43 | ad_ptr_t sub(ad_ptr_t a, ad_ptr_t b); 44 | ad_ptr_t mul(ad_ptr_t a, ad_ptr_t b); 45 | ad_ptr_t div(ad_ptr_t a, ad_ptr_t b); 46 | ad_ptr_t rcp(ad_ptr_t a); 47 | ad_ptr_t sq(ad_ptr_t a); 48 | ad_ptr_t func(const char* name, ad_ptr_t a, ad_ptr_t b = nullptr); 49 | 50 | std::string str(const parse::node_t* node); 51 | 52 | int recurse(const parse::node_ident_t* node); 53 | int recurse(const parse::node_member_t* node); 54 | int recurse(const parse::node_subscript_t* node); 55 | int recurse(const parse::node_unary_t* node); 56 | int recurse(const parse::node_binary_t* node); 57 | int recurse(const parse::node_call_t* node); 58 | int recurse(const parse::node_t* node); 59 | 60 | void throw_error(const parse::node_t* node, const char* fmt, ...); 61 | 62 | int push_item(item_t item) { 63 | int count = tape.size(); 64 | tape.push_back(std::move(item)); 65 | return count; 66 | } 67 | 68 | int find_var(const parse::node_t* node, std::string name); 69 | 70 | // If the tokenizer is provided we can print error messages that are 71 | // line/col specific. 
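  // throw_error() below prepends the formula text and a "line %d col %d"
  // location prefix to its message whenever this pointer is set.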
72 | const tok::tokenizer_t* tokenizer = nullptr; 73 | 74 | // Store each literal value once. This doesn't effect the computation 75 | // directly, but is helpful for subexpression elimination. 76 | std::map literal_map; 77 | 78 | enum op_name_t { 79 | op_name_tape, 80 | op_name_literal, 81 | op_name_add, 82 | op_name_sub, 83 | op_name_mul, 84 | op_name_div, 85 | op_name_negate, 86 | op_name_sq, 87 | op_name_sqrt, 88 | op_name_exp, 89 | op_name_log, 90 | op_name_sin, 91 | op_name_cos, 92 | op_name_tan, 93 | op_name_sinh, 94 | op_name_cosh, 95 | op_name_tanh, 96 | op_name_abs, 97 | op_name_pow, 98 | }; 99 | 100 | union op_key_t { 101 | struct { 102 | op_name_t name : 8; 103 | uint a : 28; 104 | uint b : 28; 105 | }; 106 | uint64_t bits; 107 | }; 108 | 109 | std::optional find_cse(op_name_t op_name, int a, int b = -1); 110 | std::optional find_literal(double x); 111 | 112 | // Map each operation to the location in the tape where its value is stored. 113 | // We only build this structure during the upsweep when computing the tape 114 | // values. We won't necessarily match common subexpressions in partial 115 | // derivatives, because we don't want to memoize all those fragments as it 116 | // will consume more storage than we're prepared to give. 117 | std::map cse_map; 118 | }; 119 | 120 | 121 | //////////////////////////////////////////////////////////////////////////////// 122 | // TODO: Register each tape insertion with the common subexpression elimination 123 | // (CSE) map, so that find_cse will work. 124 | 125 | int ad_builder_t::literal_node(double x) { 126 | item_t item { }; 127 | item.val = literal(x); 128 | return push_item(std::move(item)); 129 | } 130 | 131 | int ad_builder_t::add(int a, int b) { 132 | if(auto cse = find_cse(op_name_add, a, b)) 133 | return *cse; 134 | 135 | item_t item { }; 136 | item.val = add(val(a), val(b)); 137 | item.grads.push_back({ 138 | a, 139 | literal(1) 140 | }); 141 | item.grads.push_back({ 142 | b, 143 | literal(1) 144 | }); 145 | return push_item(std::move(item)); 146 | } 147 | 148 | int ad_builder_t::sub(int a, int b) { 149 | // Nip this in the bud. 150 | if(a == b) 151 | return literal_node(0); 152 | 153 | if(auto cse = find_cse(op_name_sub, a, b)) 154 | return *cse; 155 | 156 | item_t item { }; 157 | item.val = sub(val(a), val(b)); 158 | item.grads.push_back({ 159 | a, 160 | literal(1) 161 | }); 162 | item.grads.push_back({ 163 | b, 164 | literal(-1) 165 | }); 166 | return push_item(std::move(item)); 167 | } 168 | 169 | int ad_builder_t::mul(int a, int b) { 170 | if(auto cse = find_cse(op_name_mul, a, b)) 171 | return *cse; 172 | 173 | // The sq operator is memoized, so prefer that. 174 | if(a == b) 175 | return sq(a); 176 | 177 | // grad (a * b) = a grad b + b grad a. 178 | item_t item { }; 179 | item.val = mul(val(a), val(b)); 180 | item.grads.push_back({ 181 | b, // a * grad b 182 | val(a) 183 | }); 184 | item.grads.push_back({ 185 | a, // b * grad a 186 | val(b) 187 | }); 188 | return push_item(std::move(item)); 189 | } 190 | 191 | int ad_builder_t::div(int a, int b) { 192 | if(auto cse = find_cse(op_name_div, a, b)) 193 | return *cse; 194 | 195 | // grad (a / b) = 1 / b * grad a - a / b^2 * grad b. 196 | item_t item { }; 197 | item.val = div(val(a), val(b)); 198 | item.grads.push_back({ 199 | // 1 / b * grad a. 200 | a, 201 | rcp(val(b)) 202 | }); 203 | item.grads.push_back({ 204 | // a / b^2 * grad b. 
205 | b, 206 | div(val(a), sq(val(b))) 207 | }); 208 | return push_item(std::move(item)); 209 | } 210 | 211 | int ad_builder_t::negate(int a) { 212 | 213 | item_t item { }; 214 | item.val = mul(literal(-1), val(a)); 215 | item.grads.push_back({ 216 | a, 217 | literal(-1) 218 | }); 219 | return push_item(std::move(item)); 220 | } 221 | 222 | //////////////////////////////////////////////////////////////////////////////// 223 | // Elementary functions 224 | 225 | int ad_builder_t::sq(int a) { 226 | item_t item { }; 227 | item.val = sq(val(a)); 228 | item.grads.push_back({ 229 | // grad (a^2) = 2 * a grad a 230 | a, 231 | mul(literal(2), val(a)) 232 | }); 233 | return push_item(std::move(item)); 234 | } 235 | 236 | int ad_builder_t::sqrt(int a) { 237 | item_t item { }; 238 | item.val = func("std::sqrt", val(a)); 239 | item.grads.push_back({ 240 | // .5 / sqrt(a) * grad a 241 | a, 242 | div(literal(.5), func("std::sqrt", val(a))) 243 | }); 244 | return push_item(std::move(item)); 245 | } 246 | 247 | int ad_builder_t::exp(int a) { 248 | item_t item { }; 249 | item.val = func("std::exp", val(a)); 250 | item.grads.push_back({ 251 | // exp(a) * grad a 252 | a, 253 | func("std::exp", val(a)) 254 | }); 255 | return push_item(std::move(item)); 256 | } 257 | 258 | int ad_builder_t::log(int a) { 259 | // grad (ln a) = grad a / a 260 | item_t item { }; 261 | item.val = func("std::log", val(a)); 262 | item.grads.push_back({ 263 | a, 264 | rcp(val(a)) 265 | }); 266 | return push_item(std::move(item)); 267 | } 268 | 269 | int ad_builder_t::sin(int a) { 270 | item_t item { }; 271 | item.val = func("std::sin", val(a)); 272 | item.grads.push_back({ 273 | a, 274 | func("std::cos", val(a)) 275 | }); 276 | return push_item(std::move(item)); 277 | } 278 | 279 | int ad_builder_t::cos(int a) { 280 | item_t item { }; 281 | item.val = func("std::cos", val(a)); 282 | item.grads.push_back({ 283 | a, 284 | mul(literal(-1), func("std::sin", val(a))) 285 | }); 286 | return push_item(std::move(item)); 287 | } 288 | 289 | int ad_builder_t::tan(int a) { 290 | item_t item { }; 291 | item.val = func("std::tan", val(a)); 292 | item.grads.push_back({ 293 | a, 294 | sq(rcp(func("std::cos", val(a)))) 295 | }); 296 | return push_item(std::move(item)); 297 | } 298 | 299 | int ad_builder_t::sinh(int a) { 300 | item_t item { }; 301 | item.val = func("std::sinh", val(a)); 302 | item.grads.push_back({ 303 | a, 304 | func("std::cosh", val(a)) 305 | }); 306 | return push_item(std::move(item)); 307 | } 308 | 309 | int ad_builder_t::cosh(int a) { 310 | item_t item { }; 311 | item.val = func("std::cosh", val(a)); 312 | item.grads.push_back({ 313 | a, 314 | func("std::sinh", val(a)) 315 | }); 316 | return push_item(std::move(item)); 317 | } 318 | 319 | int ad_builder_t::tanh(int a) { 320 | item_t item { }; 321 | item.val = func("std::tanh", val(a)); 322 | item.grads.push_back({ 323 | a, 324 | sub(literal(1), sq(func("std::tanh", val(a)))) 325 | }); 326 | return push_item(std::move(item)); 327 | } 328 | 329 | int ad_builder_t::abs(int a) { 330 | item_t item { }; 331 | item.val = func("std::abs", val(a)); 332 | item.grads.push_back({ 333 | a, // d/dx abs(x) = x / abs(x) 334 | div(val(a), func("std::abs", val(a))) 335 | }); 336 | } 337 | 338 | int ad_builder_t::pow(int a, int b) { 339 | item_t item { }; 340 | item.val = func("std::pow", val(a), val(b)); 341 | item.grads.push_back({ 342 | // d/dx (a**b) = b a**(b - 1) da/dx 343 | a, 344 | mul(val(b), func("std::pow", val(a), sub(val(b), literal(1)))) 345 | }); 346 | item.grads.push_back({ 347 | 
// d/dx (a**b) = a**b ln a db/dx 348 | b, 349 | mul(func("std::pow", val(a), val(b)), func("std::log", val(a))) 350 | }); 351 | return push_item(std::move(item)); 352 | } 353 | 354 | int ad_builder_t::norm(const int* p, int count) { 355 | item_t item { }; 356 | 357 | // Square and accumulate each argument. 358 | ad_ptr_t x = sq(val(p[0])); 359 | for(int i = 1; i < count; ++i) 360 | x = add(std::move(x), sq(val(p[i]))); 361 | 362 | // Take its sqrt. 363 | item.val = func("std::sqrt", std::move(x)); 364 | 365 | // Differentiate with respect to each argument. 366 | // The derivative is f_i * grad f_i / norm(f). 367 | // We compute the norm in this tape item during the upsweep, so load it. 368 | // We have a 1 / norm common subexpression--this can be eliminated by the 369 | // optimizer, but may be added to the tape as its own value. 370 | int index = tape.size(); 371 | for(int i = 0; i < count; ++i) { 372 | item.grads.push_back({ 373 | p[i], 374 | div(val(p[i]), val(index)) 375 | }); 376 | } 377 | return push_item(std::move(item)); 378 | } 379 | 380 | std::string ad_builder_t::str(const node_t* node) { 381 | switch(node->kind) { 382 | case node_t::kind_ident: 383 | return static_cast(node)->s; 384 | 385 | case node_t::kind_member: { 386 | const auto* member = static_cast(node); 387 | return str(member->lhs.get()) + "." + member->member; 388 | } 389 | 390 | case node_t::kind_subscript: { 391 | const auto* subscript = static_cast(node); 392 | if(1 != subscript->args.size()) 393 | throw_error(node, "subscript must have 1 index"); 394 | return str(subscript->lhs.get()) + 395 | "[" + str(subscript->args[0].get()) + "]"; 396 | } 397 | 398 | case node_t::kind_number: { 399 | const auto* number = static_cast(node); 400 | return number->x.to_string(); 401 | } 402 | 403 | default: 404 | throw_error(node, "unsupported identifier kind"); 405 | } 406 | } 407 | 408 | int ad_builder_t::recurse(const node_unary_t* node) { 409 | int a = recurse(node->a.get()); 410 | int c = -1; 411 | switch(node->op) { 412 | case expr_op_negate: 413 | c = negate(a); 414 | break; 415 | 416 | default: 417 | throw_error(node, "unsupported unary %s", expr_op_names[node->op]); 418 | } 419 | return c; 420 | } 421 | 422 | int ad_builder_t::recurse(const node_binary_t* node) { 423 | int a = recurse(node->a.get()); 424 | int b = recurse(node->b.get()); 425 | int c = -1; 426 | 427 | switch(node->op) { 428 | case expr_op_add: 429 | c = add(a, b); 430 | break; 431 | 432 | case expr_op_sub: 433 | c = sub(a, b); 434 | break; 435 | 436 | case expr_op_mul: 437 | c = mul(a, b); 438 | break; 439 | 440 | case expr_op_div: 441 | c = div(a, b); 442 | break; 443 | 444 | default: 445 | throw_error(node, "unsupported binary %s", expr_op_names[node->op]); 446 | } 447 | return c; 448 | } 449 | 450 | int ad_builder_t::recurse(const node_call_t* node) { 451 | std::string func_name = str(node->f.get()); 452 | std::vector args(node->args.size()); 453 | for(int i = 0; i < node->args.size(); ++i) 454 | args[i] = recurse(node->args[i].get()); 455 | 456 | #define GEN_CALL_1(s) \ 457 | if(#s == func_name) { \ 458 | if(1 != args.size()) \ 459 | throw_error(node, #s "() requires 1 argument"); \ 460 | return s(args[0]); \ 461 | } 462 | 463 | GEN_CALL_1(sq) 464 | GEN_CALL_1(sqrt) 465 | GEN_CALL_1(exp) 466 | GEN_CALL_1(log) 467 | GEN_CALL_1(sin) 468 | GEN_CALL_1(cos) 469 | GEN_CALL_1(tan) 470 | GEN_CALL_1(sinh) 471 | GEN_CALL_1(cosh) 472 | GEN_CALL_1(tanh) 473 | GEN_CALL_1(abs) 474 | 475 | #undef GEN_CALL_1 476 | 477 | if("pow" == func_name) { 478 | if(2 != 
node->args.size()) 479 | throw_error(node, "pow() requires 2 arguments"); 480 | return pow(args[0], args[1]); 481 | 482 | } else if("norm" == func_name) { 483 | // Allow 1 or more arguments. 484 | if(!node->args.size()) 485 | throw_error(node, "norm() requires 1 or more arguments"); 486 | return norm(args.data(), args.size()); 487 | 488 | } else { 489 | throw_error(node, "unknown function '%s'", func_name.c_str()); 490 | } 491 | } 492 | 493 | int ad_builder_t::recurse(const node_t* node) { 494 | int result = -1; 495 | switch(node->kind) { 496 | case node_t::kind_number: { 497 | auto* number = node->as(); 498 | result = literal_node(number->x.convert()); 499 | break; 500 | } 501 | 502 | case node_t::kind_ident: 503 | case node_t::kind_member: 504 | case node_t::kind_subscript: 505 | // Don't add a new tape item for independent variables--these get 506 | // provisioned in order at the start. 507 | result = find_var(node, str(node)); 508 | break; 509 | 510 | case node_t::kind_unary: 511 | result = recurse(static_cast(node)); 512 | break; 513 | 514 | case node_t::kind_binary: 515 | result = recurse(static_cast(node)); 516 | break; 517 | 518 | case node_t::kind_call: 519 | result = recurse(static_cast(node)); 520 | break; 521 | 522 | default: 523 | break; 524 | } 525 | return result; 526 | } 527 | 528 | autodiff_t make_autodiff(const parse_t& parse, 529 | const std::vector& vars) { 530 | 531 | ad_builder_t ad_builder; 532 | ad_builder.tokenizer = &parse.tokenizer; 533 | ad_builder.vars = vars; 534 | ad_builder.tape.resize(ad_builder.vars.size()); 535 | ad_builder.recurse(parse.root.get()); 536 | 537 | return std::move(ad_builder); 538 | } 539 | 540 | autodiff_t make_autodiff(const std::string& formula, 541 | const std::vector& vars) { 542 | 543 | auto p = parse::parse_expression(formula.c_str()); 544 | return make_autodiff(p, std::move(vars)); 545 | } 546 | 547 | 548 | //////////////////////////////////////////////////////////////////////////////// 549 | 550 | ad_ptr_t ad_builder_t::val(int index) { 551 | // Return a value from the tape. 
552 | return std::make_unique(index); 553 | } 554 | 555 | ad_ptr_t ad_builder_t::literal(double x) { 556 | return std::make_unique(x); 557 | } 558 | 559 | ad_ptr_t ad_builder_t::add(ad_ptr_t a, ad_ptr_t b) { 560 | auto* a2 = a->as(); 561 | auto* b2 = b->as(); 562 | if(a2 && b2) 563 | return literal(a2->x + b2->x); 564 | else 565 | return std::make_unique("+", std::move(a), std::move(b)); 566 | } 567 | 568 | ad_ptr_t ad_builder_t::sub(ad_ptr_t a, ad_ptr_t b) { 569 | auto* a2 = a->as(); 570 | auto* b2 = b->as(); 571 | if(a2 && b2) 572 | return literal(a2->x - b2->x); 573 | return std::make_unique("-", std::move(a), std::move(b)); 574 | } 575 | 576 | ad_ptr_t ad_builder_t::mul(ad_ptr_t a, ad_ptr_t b) { 577 | auto* a2 = a->as(); 578 | auto* b2 = b->as(); 579 | if(a2 && b2) 580 | return literal(a2->x * b2->x); 581 | return std::make_unique("*", std::move(a), std::move(b)); 582 | } 583 | 584 | ad_ptr_t ad_builder_t::div(ad_ptr_t a, ad_ptr_t b) { 585 | auto* a2 = a->as(); 586 | auto* b2 = b->as(); 587 | if(a2 && b2) 588 | return literal(a2->x / b2->x); 589 | return std::make_unique("/", std::move(a), std::move(b)); 590 | } 591 | 592 | ad_ptr_t ad_builder_t::rcp(ad_ptr_t a) { 593 | if(auto* a2 = a->as()) 594 | return literal(1 / a2->x); 595 | else 596 | return div(literal(1), std::move(a)); 597 | } 598 | 599 | ad_ptr_t ad_builder_t::sq(ad_ptr_t a) { 600 | if(auto* a2 = a->as()) 601 | return literal(a2->x * a2->x); 602 | else 603 | return func("apex::sq", std::move(a)); 604 | } 605 | 606 | ad_ptr_t ad_builder_t::func(const char* f, ad_ptr_t a, ad_ptr_t b) { 607 | // TODO: Perform constant folding? 608 | 609 | auto node = std::make_unique(f); 610 | node->args.push_back(std::move(a)); 611 | if(b) node->args.push_back(std::move(b)); 612 | return node; 613 | } 614 | 615 | //////////////////////////////////////////////////////////////////////////////// 616 | 617 | void ad_builder_t::throw_error(const node_t* node, const char* fmt, ...) { 618 | // Get the user's error message. 619 | va_list args; 620 | va_start(args, fmt); 621 | std::string msg = vformat(fmt, args); 622 | va_end(args); 623 | 624 | // If the tokenizer is available, print a location message. 625 | if(tokenizer) { 626 | std::pair linecol = tokenizer->token_linecol(node->loc); 627 | msg = format( 628 | "autodiff formula \"%s\"\n" 629 | "line %d col %d\n" 630 | "%s", 631 | tokenizer->text.c_str(), 632 | linecol.first + 1, 633 | linecol.second + 1, 634 | msg.c_str() 635 | ); 636 | } 637 | 638 | throw ad_exeption_t(msg); 639 | } 640 | 641 | int ad_builder_t::find_var(const node_t* node, std::string name) { 642 | auto p = [&](const auto& var) { return var.name == name; }; 643 | auto it = std::find_if(vars.begin(), vars.end(), p); 644 | if(vars.end() == it) 645 | throw_error(node, "unknown variable '%s'", name.c_str()); 646 | return it - vars.begin(); 647 | } 648 | 649 | std::optional ad_builder_t::find_cse(op_name_t op_name, int a, int b) { 650 | switch(op_name) { 651 | case op_name_add: 652 | case op_name_mul: 653 | // For these commutative operators, put the lower index on the left. 654 | // This improves CSE performance. 
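    // e.g. add(3, 5) and add(5, 3) normalize to the same (op, a, b) key.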
655 | if(a > b) 656 | std::swap(a, b); 657 | break; 658 | 659 | default: 660 | break; 661 | } 662 | 663 | op_key_t op_key { op_name, (uint)a, (uint)b }; 664 | auto it = cse_map.find(op_key.bits); 665 | std::optional index; 666 | if(cse_map.end() != it) { 667 | index = it->second; 668 | } 669 | return index; 670 | } 671 | 672 | std::optional ad_builder_t::find_literal(double x) { 673 | auto it = literal_map.find(x); 674 | std::optional index; 675 | if(literal_map.end() != it) { 676 | index = it->second; 677 | } 678 | return index; 679 | } 680 | 681 | //////////////////////////////////////////////////////////////////////////////// 682 | 683 | void print_ad(const ad_t* ad, std::ostringstream& oss, int indent) { 684 | for(int i = 0; i < indent; ++i) 685 | oss.write(" ", 2); 686 | 687 | if(auto* tape = ad->as()) { 688 | oss<< "tape "<< tape->index<< "\n"; 689 | 690 | } else if(auto* literal = ad->as()) { 691 | oss<< "literal "<< literal->x<< "\n"; 692 | 693 | } else if(auto* unary = ad->as()) { 694 | oss<< "unary "<< unary->op<< "\n"; 695 | print_ad(unary->a.get(), oss, indent + 1); 696 | 697 | } else if(auto* binary = ad->as()) { 698 | oss<< "binary "<< binary->op<< "\n"; 699 | print_ad(binary->a.get(), oss, indent + 1); 700 | print_ad(binary->b.get(), oss, indent + 1); 701 | 702 | } else if(auto* func = ad->as()) { 703 | oss<< func->f<< "()\n"; 704 | for(const auto& arg : func->args) 705 | print_ad(arg.get(), oss, indent + 1); 706 | } 707 | } 708 | 709 | std::string print_ad(const ad_t* ad, int indent) { 710 | std::ostringstream oss; 711 | print_ad(ad, oss, indent); 712 | return oss.str(); 713 | } 714 | 715 | std::string print_autodiff(const autodiff_t& autodiff) { 716 | // Print all non-terminal tape items. 717 | std::ostringstream oss; 718 | 719 | for(int i = autodiff.vars.size(); i < autodiff.tape.size(); ++i) { 720 | const auto& item = autodiff.tape[i]; 721 | 722 | oss<< "tape "<< i<< ":\n"; 723 | 724 | // Print the value. 725 | oss<< " value =\n"; 726 | oss<< print_ad(item.val.get(), 2); 727 | 728 | // Print each gradient. 
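    // For each operand: the tape index whose gradient it increments and the
    // coefficient expression multiplied in during the reverse sweep.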
729 | for(const auto& grad : item.grads) { 730 | oss<< " grad "<< grad.index<< " = \n"; 731 | oss<< print_ad(grad.coef.get(), 2); 732 | } 733 | } 734 | 735 | return oss.str(); 736 | } 737 | 738 | END_APEX_NAMESPACE 739 | -------------------------------------------------------------------------------- /src/autodiff/untitled: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanbaxter/apex/f07a92857efd0d7a23d174227b8154f4cbaf01b6/src/autodiff/untitled -------------------------------------------------------------------------------- /src/core/value.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | BEGIN_APEX_NAMESPACE 4 | 5 | number_t number_t::to_boolean() const { 6 | return switch_all([](auto x) { 7 | return (bool)x; 8 | }); 9 | } 10 | 11 | number_t number_t::to_integral() const { 12 | return switch_all([](auto x) { 13 | return (int64_t)x; 14 | }); 15 | } 16 | 17 | number_t number_t::to_floating() const { 18 | return switch_all([](auto x) { 19 | return (double)x; 20 | }); 21 | } 22 | 23 | number_t number_t::to_kind(number_kind_t kind2) const { 24 | number_t result { }; 25 | switch(kind2) { 26 | case number_kind_bool: 27 | result = to_boolean(); 28 | break; 29 | 30 | case number_kind_int: 31 | result = to_integral(); 32 | break; 33 | 34 | case number_kind_float: 35 | result = to_floating(); 36 | break; 37 | } 38 | return result; 39 | } 40 | 41 | std::string number_t::to_string() const { 42 | if(number_kind_bool == kind) { 43 | return b ? "true" : "false"; 44 | 45 | } else { 46 | return switch_numeric([](auto x) { 47 | return std::to_string(x); 48 | }); 49 | } 50 | } 51 | 52 | number_kind_t common_arithmetic_kind(number_kind_t left, number_kind_t right) { 53 | if(number_kind_float == left || number_kind_float == right) 54 | return number_kind_float; 55 | else 56 | return number_kind_int; 57 | } 58 | 59 | //////////////////////////////////////////////////////////////////////////////// 60 | 61 | number_t value_unary(expr_op_t op, number_t value) { 62 | number_t result { }; 63 | switch(op) { 64 | case expr_op_inc_post: 65 | case expr_op_dec_post: 66 | case expr_op_inc_pre: 67 | case expr_op_dec_pre: 68 | break; 69 | 70 | case expr_op_complement: 71 | if(!value.is_floating()) 72 | result = ~value.convert(); 73 | break; 74 | 75 | case expr_op_negate: 76 | result = !value.convert(); 77 | break; 78 | 79 | case expr_op_plus: 80 | result = value; 81 | break; 82 | 83 | case expr_op_minus: 84 | switch(value.kind) { 85 | case number_kind_bool: 86 | break; 87 | 88 | case number_kind_int: 89 | result = number_t(-value.i); 90 | break; 91 | 92 | case number_kind_float: 93 | result = number_t(-value.d); 94 | break; 95 | } 96 | break; 97 | } 98 | 99 | return result; 100 | } 101 | 102 | number_t value_binary(expr_op_t op, number_t left, number_t right) { 103 | number_t result { }; 104 | 105 | switch(op) { 106 | case expr_op_add: 107 | case expr_op_sub: 108 | case expr_op_mul: 109 | case expr_op_div: 110 | // Promote to a common type. 111 | 112 | case expr_op_shl: 113 | case expr_op_shr: 114 | case expr_op_bit_and: 115 | case expr_op_bit_xor: 116 | case expr_op_bit_or: 117 | // Integer only. 
118 | if(left.is_integral() && right.is_integral()) { 119 | int64_t x = 0; 120 | switch(op) { 121 | case expr_op_shl: x = left.i<< right.i; break; 122 | case expr_op_shr: x = left.i>> right.i; break; 123 | case expr_op_bit_and: x = left.i & right.i; break; 124 | case expr_op_bit_xor: x = left.i ^ right.i; break; 125 | case expr_op_bit_or: x = left.i | right.i; break; 126 | } 127 | result = number_t(x); 128 | } 129 | break; 130 | 131 | case expr_op_lt: 132 | case expr_op_gt: 133 | case expr_op_lte: 134 | case expr_op_gte: 135 | // Integer or float. 136 | if(left.is_arithmetic() && right.is_arithmetic()) { 137 | bool x = false; 138 | 139 | if(left.is_floating() || right.is_floating()) { 140 | double a = left.convert(); 141 | double b = right.convert(); 142 | switch(op) { 143 | case expr_op_lt: x = a < b; break; 144 | case expr_op_gt: x = a > b; break; 145 | case expr_op_lte: x = a <= b; break; 146 | case expr_op_gte: x = a >= b; break; 147 | } 148 | 149 | } else { 150 | assert(left.is_integral() && right.is_integral()); 151 | switch(op) { 152 | case expr_op_lt: x = left.i < right.i; break; 153 | case expr_op_gt: x = left.i > right.i; break; 154 | case expr_op_lte: x = left.i <= right.i; break; 155 | case expr_op_gte: x = left.i >= right.i; break; 156 | } 157 | } 158 | result = number_t(x); 159 | } 160 | break; 161 | 162 | case expr_op_eq: 163 | case expr_op_ne: { 164 | number_kind_t kind = common_arithmetic_kind(left.kind, right.kind); 165 | if(kind != left.kind) left = left.to_kind(kind); 166 | if(kind != right.kind) right = right.to_kind(kind); 167 | bool x = false; 168 | if(number_kind_bool == kind) { 169 | switch(op) { 170 | case expr_op_eq: x = left.b == right.b; break; 171 | case expr_op_ne: x = left.b != right.b; break; 172 | } 173 | 174 | } else if(number_kind_int == kind) { 175 | switch(op) { 176 | case expr_op_eq: x = left.i == right.i; break; 177 | case expr_op_ne: x = left.i != right.i; break; 178 | } 179 | 180 | } else { 181 | switch(op) { 182 | case expr_op_eq: x = left.d == right.d; break; 183 | case expr_op_ne: x = left.d != right.d; break; 184 | } 185 | } 186 | result = number_t(x); 187 | } 188 | break; 189 | 190 | case expr_op_log_and: 191 | case expr_op_log_or: { 192 | left = left.to_boolean(); 193 | right = right.to_boolean(); 194 | bool x = false; 195 | switch(op) { 196 | case expr_op_log_and: x = left.b && right.b; break; 197 | case expr_op_log_or: x = left.b && right.b; break; 198 | } 199 | result = number_t(x); 200 | break; 201 | } 202 | 203 | case expr_op_sequence: 204 | result = right; 205 | break; 206 | } 207 | 208 | return result; 209 | } 210 | 211 | 212 | END_APEX_NAMESPACE 213 | -------------------------------------------------------------------------------- /src/parse/grammar.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | BEGIN_APEX_NAMESPACE 7 | 8 | const char* expr_op_names[] { 9 | "none", 10 | 11 | // Postfix 12 | "++", 13 | "--", 14 | 15 | // Prefix operators 16 | "++", 17 | "--", 18 | "~", 19 | "!", 20 | "+", 21 | "-", 22 | "&", 23 | "*", 24 | 25 | // Pointer-to-member 26 | ".*", 27 | "->*", 28 | 29 | // Binary operators 30 | "*", 31 | "/", 32 | "%", 33 | "+", 34 | "-", 35 | "<<", 36 | ">>", 37 | "<", 38 | ">", 39 | "<=", 40 | ">=", 41 | "==", 42 | "!=", 43 | "&=", 44 | "^=", 45 | "|=", 46 | "&&", 47 | "||", 48 | 49 | // Assignment operators. 
50 | "=", 51 | "*=", 52 | "/=", 53 | "%=", 54 | "+=", 55 | "-=", 56 | "<<=", 57 | ">>=", 58 | "&=" 59 | "|=", 60 | "^=", 61 | 62 | "?:", 63 | "," 64 | }; 65 | 66 | 67 | namespace parse { 68 | 69 | struct grammar_t { 70 | token_it advance_brace(range_t range); 71 | token_it advance_paren(range_t range); 72 | token_it advance_bracket(range_t range); 73 | 74 | result_t parse_paren(range_t range); 75 | result_t parse_brace(range_t range); 76 | result_t parse_bracket(range_t range); 77 | 78 | template 79 | auto parse_switch(range_t range, F f) { 80 | result_t result; 81 | if(auto x = f(range.peek())) 82 | result = make_result(range.begin, range.begin + 1, x); 83 | return result; 84 | } 85 | 86 | result_t entity(range_t range, bool expect); 87 | result_t literal(range_t range); 88 | result_t primary_expression(range_t range, bool expect); 89 | result_t expression(range_t range, bool expect); 90 | result_t postfix_expression(range_t range, bool expect); 91 | result_t postfix_operator(range_t range, node_ptr_t& node); 92 | result_t paren_initializer(range_t range); 93 | result_t unary_expression(range_t range, bool expect); 94 | result_t binary_expression(range_t range, bool expect); 95 | result_t logical_and_expression(range_t range, bool expect); 96 | result_t logical_or_expression(range_t range, bool expect); 97 | result_t assignment_expression(range_t range, bool expect); 98 | 99 | result_t paren_expression(range_t range); 100 | result_t braced_init_list(range_t range); 101 | result_t initializer_clause(range_t range, bool expect); 102 | result_t init_list(range_t range); 103 | 104 | node_ptr_t make_number(number_t number); 105 | node_ptr_t make_unary(expr_op_t op, node_ptr_t a, source_loc_t loc); 106 | node_ptr_t make_binary(expr_op_t op, node_ptr_t a, node_ptr_t b, 107 | source_loc_t loc); 108 | 109 | void throw_error(token_it pos, const char* fmt, ...); 110 | void throw_error(source_loc_t loc, const char* fmt, ...); 111 | void unexpected_token(token_it pos, const char* rule); 112 | 113 | source_loc_t loc(token_it it) const; 114 | 115 | const tok::tokenizer_t& tokenizer; 116 | }; 117 | 118 | //////////////////////////////////////////////////////////////////////////////// 119 | 120 | token_it grammar_t::advance_brace(range_t range) { 121 | int count = 1; 122 | while(token_t token = range.next()) { 123 | if(tk_sym_paren_l == token) 124 | range.begin = advance_paren(range); 125 | else if(tk_sym_paren_r == token) 126 | throw_error(range.begin - 1, "unbalanced ')' in brace set { }"); 127 | 128 | else if(tk_sym_bracket_l == token) 129 | range.begin = advance_bracket(range); 130 | else if(tk_sym_bracket_r == token) 131 | throw_error(range.begin - 1, "unbalanced ']' in brace set { }"); 132 | 133 | else if(tk_sym_brace_l == token) 134 | ++count; 135 | else if(tk_sym_brace_r == token) 136 | --count; 137 | 138 | if(!count) break; 139 | } 140 | 141 | if(count) 142 | throw_error(range.begin, "no closing '}' in brace set { }"); 143 | 144 | return range.begin; 145 | } 146 | 147 | token_it grammar_t::advance_paren(range_t range) { 148 | int count = 1; 149 | while(token_t token = range.next()) { 150 | if(tk_sym_bracket_l == token) 151 | range.begin = advance_bracket(range); 152 | else if(tk_sym_bracket_r == token) 153 | throw_error(range.begin - 1, "unbalanced ']' in paren set ( )"); 154 | 155 | else if(tk_sym_brace_l == token) 156 | range.begin = advance_brace(range); 157 | else if(tk_sym_brace_r == token) 158 | throw_error(range.begin - 1, "unbalanced '}' in paren set ( )"); 159 | 160 | else 
if(tk_sym_paren_l == token) 161 | ++count; 162 | else if(tk_sym_paren_r == token) 163 | --count; 164 | 165 | if(!count) break; 166 | } 167 | 168 | if(count) 169 | throw_error(range.begin, "no closing ')' in paren set ( )"); 170 | 171 | return range.begin; 172 | } 173 | 174 | token_it grammar_t::advance_bracket(range_t range) { 175 | int count = 1; 176 | while(token_t token = range.next()) { 177 | if(tk_sym_brace_l == token) 178 | range.begin = advance_brace(range); 179 | else if(tk_sym_brace_r == token) 180 | throw_error(range.begin - 1, "unbalanced '}' in bracket set [ ]"); 181 | 182 | else if(tk_sym_paren_l == token) 183 | range.begin = advance_paren(range); 184 | else if(tk_sym_paren_r == token) 185 | throw_error(range.begin - 1, "unbalanced ')' in bracket set [ ]"); 186 | 187 | else if(tk_sym_bracket_l == token ) 188 | ++count; 189 | else if(tk_sym_bracket_r == token) 190 | --count; 191 | 192 | if(!count) break; 193 | } 194 | 195 | if(count) 196 | throw_error(range.begin, "no closing ']' in bracket set [ ]"); 197 | 198 | return range.begin; 199 | } 200 | 201 | result_t grammar_t::parse_brace(range_t range) { 202 | result_t result; 203 | token_it begin = range.begin; 204 | if(range.advance_if(tk_sym_brace_l)) { 205 | token_it end = advance_brace(range); 206 | result = make_result(begin, end, range_t { range.begin, end - 1 }); 207 | } 208 | return result; 209 | } 210 | 211 | result_t grammar_t::parse_paren(range_t range) { 212 | result_t result; 213 | token_it begin = range.begin; 214 | if(range.advance_if(tk_sym_paren_l)) { 215 | token_it end = advance_paren(range); 216 | result = make_result(begin, end, range_t { range.begin, end - 1 }); 217 | } 218 | return result; 219 | } 220 | 221 | result_t grammar_t::parse_bracket(range_t range) { 222 | result_t result; 223 | token_it begin = range.begin; 224 | if(range.advance_if(tk_sym_bracket_l)) { 225 | token_it end = advance_bracket(range); 226 | result = make_result(begin, end, range_t { range.begin, end - 1 }); 227 | } 228 | return result; 229 | } 230 | 231 | //////////////////////////////////////////////////////////////////////////////// 232 | 233 | result_t grammar_t::entity(range_t range, bool expect) { 234 | result_t result; 235 | token_it begin = range.begin; 236 | if(token_t token = range.advance_if(tk_ident)) { 237 | auto ident = std::make_unique(loc(begin)); 238 | ident->s = tokenizer.strings[token.store]; 239 | result = make_result(begin, range.begin, std::move(ident)); 240 | 241 | } else if(expect) 242 | throw_error(range.begin, "expected entity in expression"); 243 | 244 | return result; 245 | } 246 | 247 | result_t grammar_t::literal(range_t range) { 248 | token_it begin = range.begin; 249 | node_ptr_t node; 250 | switch(token_t token = range.next()) { 251 | case tk_int: { 252 | int64_t i = tokenizer.ints[token.store]; 253 | node = std::make_unique(i, loc(begin)); 254 | break; 255 | } 256 | 257 | case tk_float: { 258 | double d = tokenizer.floats[token.store]; 259 | node = std::make_unique(d, loc(begin)); 260 | break; 261 | } 262 | 263 | case tk_char: 264 | node = std::make_unique((char32_t)token.store, loc(begin)); 265 | break; 266 | 267 | case tk_string: { 268 | const std::string& s = tokenizer.strings[token.store]; 269 | node = std::make_unique(s, loc(begin)); 270 | break; 271 | } 272 | 273 | case tk_kw_false: 274 | node = std::make_unique(false, loc(begin)); 275 | break; 276 | 277 | case tk_kw_true: 278 | node = std::make_unique(true, loc(begin)); 279 | break; 280 | 281 | default: 282 | break; 283 | } 284 | 285 | return 
make_result(begin, range.begin, std::move(node)); 286 | } 287 | 288 | result_t grammar_t::primary_expression(range_t range, bool expect) { 289 | result_t result; 290 | 291 | switch(range.peek()) { 292 | case tk_kw_false: 293 | case tk_kw_true: 294 | case tk_int: 295 | case tk_float: 296 | case tk_char: 297 | case tk_string: 298 | result = literal(range); 299 | break; 300 | 301 | case tk_sym_paren_l: 302 | result = paren_expression(range); 303 | break; 304 | 305 | default: 306 | result = entity(range, expect); 307 | break; 308 | } 309 | 310 | return result; 311 | } 312 | 313 | result_t grammar_t::postfix_expression(range_t range, bool expect) { 314 | token_it begin = range.begin; 315 | result_t result; 316 | if(auto primary = primary_expression(range, expect)) { 317 | range.advance(primary); 318 | node_ptr_t node = std::move(primary->attr); 319 | 320 | // Consume postfix operators until there are no more. 321 | while(auto op = postfix_operator(range, node)) { 322 | range.advance(op); 323 | assert(op->attr); 324 | node = std::move(op->attr); 325 | } 326 | 327 | assert(node); 328 | result = make_result(begin, range.begin, std::move(node)); 329 | } 330 | 331 | return result; 332 | } 333 | 334 | result_t grammar_t::postfix_operator(range_t range, 335 | node_ptr_t& node) { 336 | 337 | token_it begin = range.begin; 338 | switch(token_t token = range.next()) { 339 | case tk_sym_minusminus: { 340 | case tk_sym_plusplus: 341 | expr_op_t op = tk_sym_plusplus == token ? 342 | expr_op_inc_post : expr_op_dec_post; 343 | node = make_unary(op, std::move(node), loc(begin)); 344 | break; 345 | } 346 | 347 | case tk_sym_bracket_l: { 348 | // Subscript operation. 349 | break; 350 | } 351 | 352 | case tk_sym_paren_l: { 353 | --range.begin; 354 | auto paren = paren_initializer(range); 355 | range.advance(paren); 356 | 357 | auto call = std::make_unique(loc(begin)); 358 | call->f = std::move(node); 359 | call->args = std::move(paren->attr); 360 | node = std::move(call); 361 | break; 362 | } 363 | 364 | case tk_sym_arrow: 365 | case tk_sym_dot: { 366 | 367 | } 368 | 369 | default: 370 | // We don't match any of the postfix expressions, so break the loop and 371 | // return to the caller. 
372 | return { }; 373 | } 374 | 375 | return make_result(begin, range.begin, std::move(node)); 376 | } 377 | 378 | //////////////////////////////////////////////////////////////////////////////// 379 | 380 | expr_op_t switch_unary(tk_kind_t kind) { 381 | expr_op_t op = expr_op_none; 382 | switch(kind) { 383 | case tk_sym_plusplus: op = expr_op_inc_pre; break; 384 | case tk_sym_minusminus: op = expr_op_dec_pre; break; 385 | case tk_sym_tilde: op = expr_op_complement; break; 386 | case tk_sym_bang: op = expr_op_negate; break; 387 | case tk_sym_plus: op = expr_op_plus; break; 388 | case tk_sym_minus: op = expr_op_minus; break; 389 | case tk_sym_amp: op = expr_op_addressof; break; 390 | case tk_sym_star: op = expr_op_indirection; break; 391 | default: break; 392 | } 393 | return op; 394 | } 395 | 396 | result_t grammar_t::unary_expression(range_t range, bool expect) { 397 | token_it begin = range.begin; 398 | result_t result; 399 | 400 | if(auto op = parse_switch(range, switch_unary)) { 401 | range.advance(op); 402 | 403 | auto rhs = unary_expression(range, true); 404 | range.advance(rhs); 405 | 406 | node_ptr_t unary = make_unary(op->attr, std::move(rhs->attr), loc(begin)); 407 | result = make_result(begin, range.begin, std::move(unary)); 408 | 409 | } else 410 | result = postfix_expression(range, expect); 411 | 412 | return result; 413 | } 414 | 415 | //////////////////////////////////////////////////////////////////////////////// 416 | 417 | enum ast_prec_t : uint8_t { 418 | // lowest precedence. 419 | ast_prec_any = 0, 420 | ast_prec_comma, 421 | ast_prec_assign, 422 | ast_prec_log_or, 423 | ast_prec_log_and, 424 | ast_prec_bit_or, 425 | ast_prec_bit_xor, 426 | ast_prec_bit_and, 427 | ast_prec_eq, 428 | ast_prec_cmp, 429 | ast_prec_shift, 430 | ast_prec_add, 431 | ast_prec_mul, 432 | ast_prec_ptr_to_mem, 433 | // highest precedence. 434 | }; 435 | 436 | struct binary_desc_t { 437 | expr_op_t op; 438 | ast_prec_t prec; 439 | 440 | explicit operator bool() const { return op; } 441 | }; 442 | 443 | binary_desc_t switch_binary(tk_kind_t kind) { 444 | binary_desc_t desc { }; 445 | switch(kind) { 446 | // binary ->* and .* 447 | case tk_sym_arrowstar: desc = { expr_op_ptrmem_arrow, ast_prec_ptr_to_mem }; break; 448 | case tk_sym_dotstar: desc = { expr_op_ptrmem_dot, ast_prec_ptr_to_mem }; break; 449 | 450 | // binary *, /, % with the same precedence. 451 | case tk_sym_star: desc = { expr_op_mul, ast_prec_mul }; break; 452 | case tk_sym_slash: desc = { expr_op_div, ast_prec_mul }; break; 453 | case tk_sym_percent: desc = { expr_op_mod, ast_prec_mul }; break; 454 | 455 | // binary + and - with the same precedence. 456 | case tk_sym_plus: desc = { expr_op_add, ast_prec_add }; break; 457 | case tk_sym_minus: desc = { expr_op_sub, ast_prec_add }; break; 458 | 459 | // <<, >> with the same precedence. 460 | case tk_sym_ltlt: desc = { expr_op_shl, ast_prec_shift }; break; 461 | case tk_sym_gtgt: desc = { expr_op_shr, ast_prec_shift }; break; 462 | 463 | // <, >, <=, >= with the same precedence. 464 | case tk_sym_lt: desc = { expr_op_lt, ast_prec_cmp }; break; 465 | case tk_sym_gt: desc = { expr_op_gt, ast_prec_cmp }; break; 466 | case tk_sym_lteq: desc = { expr_op_lte, ast_prec_cmp }; break; 467 | case tk_sym_gteq: desc = { expr_op_gte, ast_prec_cmp }; break; 468 | 469 | // == and != with the same precedence. 
470 | case tk_sym_eqeq: desc = { expr_op_eq, ast_prec_eq }; break; 471 | case tk_sym_bangeq: desc = { expr_op_ne, ast_prec_eq }; break; 472 | 473 | // bitwise AND & 474 | case tk_sym_amp: desc = { expr_op_bit_and, ast_prec_bit_and }; break; 475 | 476 | // bitwise XOR ^ 477 | case tk_sym_caret: desc = { expr_op_bit_xor, ast_prec_bit_xor }; break; 478 | 479 | // bitwise OR | 480 | case tk_sym_pipe: desc = { expr_op_bit_or, ast_prec_bit_or }; break; 481 | 482 | default: break; 483 | } 484 | return desc; 485 | } 486 | 487 | struct item_t { 488 | node_ptr_t node; 489 | source_loc_t loc; 490 | binary_desc_t desc; 491 | }; 492 | 493 | result_t grammar_t::binary_expression(range_t range, bool expect) { 494 | 495 | std::vector stack; 496 | auto fold = [&]() { 497 | while(stack.size() >= 2) { 498 | size_t size = stack.size(); 499 | auto& lhs = stack[size - 2]; 500 | auto& rhs = stack[size - 1]; 501 | 502 | if(lhs.desc.prec >= rhs.desc.prec) { 503 | // Fold the two right-most expressions together. 504 | lhs.node = make_binary(lhs.desc.op, std::move(lhs.node), 505 | std::move(rhs.node), lhs.loc); 506 | 507 | // Use the descriptor for the rhs for this subexpression. 508 | lhs.loc = rhs.loc; 509 | lhs.desc = rhs.desc; 510 | 511 | // Pop the rhs. 512 | stack.pop_back(); 513 | } else 514 | break; 515 | } 516 | }; 517 | 518 | token_it begin = range.begin; 519 | result_t result; 520 | if(auto lhs = unary_expression(range, false)) { 521 | range.advance(lhs); 522 | stack.push_back({ std::move(lhs->attr), loc(lhs->range.begin) }); 523 | 524 | while(true) { 525 | item_t& item = stack.back(); 526 | if(auto op = parse_switch(range, switch_binary)) { 527 | range.advance(op); 528 | item.desc = op->attr; 529 | } else 530 | // No operator found. This is the end of the binary-expression. 531 | break; 532 | 533 | // Fold the expressions to the left with equal or lesser precedence. 534 | fold(); 535 | 536 | // Read the next expression. 537 | auto rhs = unary_expression(range, true); 538 | assert(rhs); 539 | range.advance(rhs); 540 | stack.push_back({ std::move(rhs->attr) }); 541 | assert(stack.back().node); 542 | } 543 | 544 | // Fold all the remaining expressions. 
545 | stack.back().desc.prec = ast_prec_any; 546 | fold(); 547 | 548 | assert(1 == stack.size()); 549 | result = make_result(begin, range.begin, std::move(stack[0].node)); 550 | } 551 | 552 | return result; 553 | } 554 | 555 | //////////////////////////////////////////////////////////////////////////////// 556 | 557 | result_t grammar_t::logical_and_expression(range_t range, 558 | bool expect) { 559 | 560 | token_it begin = range.begin; 561 | result_t result = binary_expression(range, expect); 562 | 563 | while(result) { 564 | range.advance(result); 565 | 566 | if(range.advance_if(tk_sym_ampamp)) { 567 | auto rhs = binary_expression(range, true); 568 | range.advance(rhs); 569 | 570 | auto binary = make_binary(expr_op_log_and, std::move(result->attr), 571 | std::move(rhs->attr), loc(begin)); 572 | result = make_result(begin, range.begin, std::move(binary)); 573 | 574 | } else 575 | break; 576 | } 577 | 578 | return result; 579 | } 580 | 581 | result_t grammar_t::logical_or_expression(range_t range, 582 | bool expect) { 583 | 584 | token_it begin = range.begin; 585 | result_t result = logical_and_expression(range, expect); 586 | while(result) { 587 | range.advance(result); 588 | 589 | if(range.advance_if(tk_sym_pipepipe)) { 590 | auto rhs = logical_and_expression(range, true); 591 | range.advance(rhs); 592 | 593 | auto binary = make_binary(expr_op_log_or, std::move(result->attr), 594 | std::move(rhs->attr), loc(begin)); 595 | result = make_result(begin, range.begin, std::move(binary)); 596 | 597 | } else 598 | break; 599 | } 600 | 601 | return result; 602 | } 603 | 604 | 605 | //////////////////////////////////////////////////////////////////////////////// 606 | 607 | expr_op_t switch_assign(tk_kind_t kind) { 608 | expr_op_t op = expr_op_none; 609 | switch(kind) { 610 | case tk_sym_eq: op = expr_op_assign; break; 611 | case tk_sym_stareq: op = expr_op_assign_mul; break; 612 | case tk_sym_slasheq: op = expr_op_assign_div; break; 613 | case tk_sym_percenteq: op = expr_op_assign_mod; break; 614 | case tk_sym_pluseq: op = expr_op_assign_add; break; 615 | case tk_sym_minuseq: op = expr_op_assign_sub; break; 616 | case tk_sym_ltlteq: op = expr_op_assign_shl; break; 617 | case tk_sym_gtgteq: op = expr_op_assign_shr; break; 618 | case tk_sym_ampeq: op = expr_op_assign_and; break; 619 | case tk_sym_pipeeq: op = expr_op_assign_or; break; 620 | case tk_sym_careteq: op = expr_op_assign_xor; break; 621 | default: break; 622 | } 623 | return op; 624 | } 625 | 626 | result_t grammar_t::assignment_expression(range_t range, 627 | bool expect) { 628 | 629 | token_it begin = range.begin; 630 | result_t result; 631 | if(auto a = logical_or_expression(range, expect)) { 632 | range.advance(a); 633 | 634 | if(auto op = parse_switch(range, switch_assign)) { 635 | range.advance(op); 636 | 637 | // Match an initializer clause. 638 | auto b = initializer_clause(range, true); 639 | range.advance(b); 640 | 641 | auto assign = std::make_unique(loc(begin)); 642 | assign->a = std::move(a->attr); 643 | assign->b = std::move(b->attr); 644 | 645 | a->attr = std::move(assign); 646 | 647 | } else if(range.advance_if(tk_sym_question)) { 648 | // Start of a ternary expression ? 
: 649 | auto b = assignment_expression(range, true); 650 | range.advance(b); 651 | 652 | if(!range.advance_if(tk_sym_col)) 653 | throw_error(range.begin, "expected ':' in conditional-expression"); 654 | 655 | auto c = assignment_expression(range, true); 656 | range.advance(c); 657 | 658 | auto ternary = std::make_unique(loc(begin)); 659 | ternary->a = std::move(a->attr); 660 | ternary->b = std::move(b->attr); 661 | ternary->c = std::move(c->attr); 662 | 663 | a->attr = std::move(ternary); 664 | } 665 | 666 | result = make_result(begin, range.begin, std::move(a->attr)); 667 | } 668 | 669 | return result; 670 | } 671 | 672 | result_t grammar_t::expression(range_t range, bool expect) { 673 | token_it begin = range.begin; 674 | result_t result; 675 | 676 | if(auto expr = assignment_expression(range, expect)) { 677 | range.advance(expr); 678 | 679 | while(range.advance_if(tk_sym_comma)) { 680 | auto expr2 = assignment_expression(range, true); 681 | range.advance(expr2); 682 | 683 | auto binary = make_binary(expr_op_sequence, std::move(expr->attr), 684 | std::move(expr2->attr), loc(begin)); 685 | expr->attr = std::move(binary); 686 | } 687 | 688 | result = make_result(begin, range.begin, std::move(expr->attr)); 689 | } 690 | return result; 691 | } 692 | 693 | //////////////////////////////////////////////////////////////////////////////// 694 | 695 | result_t grammar_t::paren_initializer(range_t range) { 696 | token_it begin = range.begin; 697 | result_t result; 698 | if(auto paren = parse_paren(range)) { 699 | range.advance(paren); 700 | 701 | auto list = init_list(paren->attr); 702 | result = make_result(begin, range.begin, std::move(list->attr)); 703 | } 704 | return result; 705 | } 706 | 707 | result_t grammar_t::paren_expression(range_t range) { 708 | token_it begin = range.begin; 709 | result_t result; 710 | 711 | if(auto paren = parse_paren(range)) { 712 | range.advance(paren); 713 | range_t range2 = paren->attr; 714 | 715 | if(auto expr = expression(range2, true)) { 716 | range2.advance(expr); 717 | 718 | if(range2) 719 | unexpected_token(range2.begin, "expression"); 720 | 721 | result = make_result(begin, range.begin, std::move(expr->attr)); 722 | 723 | } else 724 | throw_error(range.begin, "expected expression"); 725 | } 726 | 727 | return result; 728 | } 729 | 730 | result_t grammar_t::braced_init_list(range_t range) { 731 | token_it begin = range.begin; 732 | result_t result; 733 | 734 | if(auto brace = parse_brace(range)) { 735 | range.advance(brace); 736 | range_t range2 = brace->attr; 737 | 738 | // Support a braced initializer with a trailing , as long as there are 739 | // other tokens. 740 | if(range2.end - 1 > range2.begin && tk_sym_comma == range2.end[-1]) 741 | --range2.end; 742 | 743 | node_list_t init_list; 744 | } 745 | return result; 746 | } 747 | 748 | result_t grammar_t::initializer_clause(range_t range, bool expect) { 749 | result_t result = braced_init_list(range); 750 | if(!result) result = assignment_expression(range, expect); 751 | return result; 752 | } 753 | 754 | result_t grammar_t::init_list(range_t range) { 755 | // Must consume all elements. 
756 | token_it begin = range.begin; 757 | node_list_t list; 758 | if(auto expr = initializer_clause(range, false)) { 759 | range.advance(expr); 760 | list.push_back(std::move(expr->attr)); 761 | 762 | while(range.advance_if(tk_sym_comma)) { 763 | auto expr2 = initializer_clause(range, true); 764 | range.advance(expr2); 765 | list.push_back(std::move(expr2->attr)); 766 | } 767 | } 768 | 769 | if(range) 770 | unexpected_token(range.begin, "initializer-list"); 771 | 772 | return make_result(begin, range.begin, std::move(list)); 773 | } 774 | 775 | //////////////////////////////////////////////////////////////////////////////// 776 | 777 | node_ptr_t grammar_t::make_unary(expr_op_t op, node_ptr_t a, source_loc_t loc) { 778 | if(auto* a2 = a->as()) { 779 | if(number_t n = value_unary(op, a2->x)) { 780 | a2->x = n; 781 | a2->loc = loc; 782 | return std::move(a); 783 | 784 | } else { 785 | throw_error(loc, "illegal constant folding operation"); 786 | } 787 | 788 | } else { 789 | auto result = std::make_unique(loc); 790 | result->op = op; 791 | result->a = std::move(a); 792 | return result; 793 | } 794 | } 795 | 796 | node_ptr_t grammar_t::make_binary(expr_op_t op, node_ptr_t a, node_ptr_t b, 797 | source_loc_t loc) { 798 | 799 | auto* a2 = a->as(); 800 | auto* b2 = b->as(); 801 | if(a2 && b2) { 802 | if(number_t n = value_binary(op, a2->x, b2->x)) { 803 | a2->x = n; 804 | a2->loc = loc; 805 | return std::move(a); 806 | 807 | } else { 808 | throw_error(loc, "illegal constant folding operation"); 809 | } 810 | 811 | } else { 812 | auto result = std::make_unique(loc); 813 | result->op = op; 814 | result->a = std::move(a); 815 | result->b = std::move(b); 816 | return result; 817 | } 818 | } 819 | 820 | //////////////////////////////////////////////////////////////////////////////// 821 | 822 | void grammar_t::throw_error(source_loc_t loc, const char* fmt, ...) { 823 | va_list args; 824 | va_start(args, fmt); 825 | std::string msg = vformat(fmt, args); 826 | va_end(args); 827 | 828 | throw parse_exception_t(msg); 829 | 830 | } 831 | 832 | void grammar_t::throw_error(token_it pos, const char* fmt, ...) { 833 | 834 | } 835 | 836 | void grammar_t::unexpected_token(token_it pos, const char* rule) { 837 | const char* begin = pos->begin; 838 | const char* end = pos->end; 839 | int len = end - begin; 840 | 841 | std::string msg = format("unexpected token '%.*s' in %s", len, begin, rule); 842 | 843 | throw parse_exception_t(msg); 844 | } 845 | 846 | source_loc_t grammar_t::loc(token_it it) const { 847 | return { (int)(it - tokenizer.tokens.data()) }; 848 | } 849 | 850 | //////////////////////////////////////////////////////////////////////////////// 851 | 852 | //////////////////////////////////////////////////////////////////////////////// 853 | 854 | parse_t parse_expression(const char* begin, const char* end) { 855 | // Tokenize the input. 856 | parse_t parse; 857 | parse.tokenizer.text = std::string(begin, end); 858 | parse.tokenizer.tokenize(); 859 | 860 | // Parse the tokens. 
861 | grammar_t g { parse.tokenizer }; 862 | range_t range = parse.tokenizer.token_range(); 863 | 864 | auto expr = g.expression(range, true); 865 | range.advance(expr); 866 | if(range) 867 | g.unexpected_token(range.begin, "expression"); 868 | parse.root = std::move(expr->attr); 869 | 870 | return std::move(parse); 871 | } 872 | 873 | 874 | parse_t parse_expression(const char* str) { 875 | return parse_expression(str, str + strlen(str)); 876 | } 877 | 878 | } // namespace parse 879 | 880 | END_APEX_NAMESPACE 881 | 882 | -------------------------------------------------------------------------------- /src/tokenizer/lexer.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | namespace tok { 7 | 8 | result_t lexer_t::char_literal(range_t range) { 9 | const char* begin = range.begin; 10 | result_t result; 11 | 12 | if(range.advance_if('\'')) { 13 | char32_t char_; 14 | if(auto c = c_char(range)) { 15 | char_ = c->attr; 16 | 17 | } else 18 | throw_error(range.begin, "expected character in literal"); 19 | 20 | if(!range.advance_if('\'')) 21 | throw_error(range.begin, "expected \"'\" to end character literal"); 22 | 23 | result = make_result(begin, range.begin, token_t { 24 | tk_char, (int)char_, begin, range.begin 25 | }); 26 | } 27 | return result; 28 | } 29 | 30 | result_t lexer_t::c_char(range_t range) { 31 | // Any character except '. 32 | // if(range[0] == '') 33 | return { }; 34 | } 35 | 36 | result_t lexer_t::string_literal(range_t range) { 37 | return { }; 38 | } 39 | 40 | result_t lexer_t::s_char(range_t range) { 41 | // Any character except ". 42 | return { }; 43 | } 44 | 45 | result_t lexer_t::identifier_char(range_t range, bool digit) { 46 | const char* begin = range.begin; 47 | result_t result; 48 | 49 | if(char c = range.next()) { 50 | bool c1 = digit && isdigit(c); 51 | if(c1 || isalpha(c) || '_' == c) 52 | result = make_result(begin, range.begin, (char32_t)c); 53 | else 54 | result = ucs(range); 55 | } 56 | return result; 57 | } 58 | 59 | result_t lexer_t::ucs(range_t range) { 60 | const char* begin = range.begin; 61 | result_t result; 62 | 63 | if(range && (0x80 & range.begin[0])) { 64 | std::pair p = from_utf8(range.begin); 65 | range.begin += p.first; 66 | result = make_result(begin, range.begin, (char32_t)p.second); 67 | } 68 | return result; 69 | } 70 | 71 | const char* lexer_t::skip_comment(range_t range) { 72 | while(true) { 73 | // Eat the blank characters. 74 | while(range.advance_if(isblank)); 75 | 76 | const char* begin = range.begin; 77 | if(range.match_advance("//")) { 78 | // Match a C++-style comment. 79 | auto f = [](char c) { 80 | return '\n' != c; 81 | }; 82 | 83 | while(range.advance_if(f)); 84 | 85 | } else if(range.match_advance("/*")) { 86 | // Match a C-style comment. 
87 | while(!range.match("*/") && range.next()); 88 | 89 | if(!range.match_advance("*/")) 90 | throw_error(begin, "unterminated C-style comment: expected */"); 91 | 92 | } else 93 | break; 94 | } 95 | 96 | return range.begin; 97 | } 98 | 99 | result_t lexer_t::literal(range_t range) { 100 | result_t result = number(range); 101 | if(!result) result = char_literal(range); 102 | if(!result) result = string_literal(range); 103 | return result; 104 | } 105 | 106 | result_t lexer_t::operator_(range_t range) { 107 | result_t result; 108 | if(auto match = match_operator(range)) { 109 | token_t token { match->attr, 0, match->range.begin, match->range.end }; 110 | return make_result(match->range, token); 111 | } 112 | return result; 113 | } 114 | 115 | result_t lexer_t::identifier(range_t range) { 116 | const char* begin = range.begin; 117 | result_t result; 118 | 119 | if(auto c = identifier_char(range, false)) { 120 | range.advance(c); 121 | 122 | while(auto c = identifier_char(range, true)) 123 | range.advance(c); 124 | 125 | int ident = tokenizer.reg_string(range_t { begin, range.begin }); 126 | token_t token { tk_ident, ident, begin, range.begin }; 127 | result = make_result(begin, range.begin, token); 128 | } 129 | return result; 130 | } 131 | 132 | result_t lexer_t::token(range_t range) { 133 | result_t result = literal(range); 134 | if(!result) result = identifier(range); 135 | if(!result) result = operator_(range); 136 | return result; 137 | } 138 | 139 | 140 | bool lexer_t::advance_skip(range_t& range) { 141 | const char* next = skip_comment(range); 142 | bool advance = next != range.begin; 143 | range.begin = next; 144 | return advance; 145 | } 146 | 147 | void lexer_t::throw_error(const char* p, const char* msg) { 148 | printf("Error thrown %s\n", msg); 149 | exit(0); 150 | } 151 | 152 | } // namespace tok 153 | 154 | END_APEX_NAMESPACE 155 | -------------------------------------------------------------------------------- /src/tokenizer/number.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | namespace tok { 7 | 8 | result_t lexer_t::pp_number(range_t range) { 9 | result_t result; 10 | const char* begin = range.begin; 11 | 12 | // [lex.ppnumber] 13 | // pp-number: 14 | // digit 15 | // . digit 16 | range.advance_if('.'); 17 | if(range.advance_if(isdigit)) { 18 | while(char c0 = toupper(range[0])) { 19 | char c1 = range[1]; 20 | if(('E' == c0 || 'P' == c0) && ('+' == c1 || '-' == c1)) { 21 | // pp-number e sign 22 | // pp-number E sign 23 | // pp-number p sign 24 | // pp-number P sign 25 | range.begin += 2; 26 | continue; 27 | } 28 | 29 | if('\'' == c0 && (isalnum(c1) || '_' == c1)) { 30 | // pp-number ' digit 31 | // pp-number ' non-digit 32 | range.begin += 2; 33 | continue; 34 | } 35 | 36 | if('.' == c0) { 37 | // pp-number . 38 | ++range.begin; 39 | continue; 40 | } 41 | 42 | if(auto c = identifier_char(range, true)) { 43 | // pp-number digit 44 | // pp-number identifier-nondigit 45 | range.advance(c); 46 | continue; 47 | } 48 | 49 | break; 50 | } 51 | 52 | result = make_result(begin, range.begin, { }); 53 | } 54 | 55 | return result; 56 | } 57 | 58 | struct floating_parts_t { 59 | range_t integer; // digits before the . 60 | range_t fractional; // digits after the . 61 | int exponent; // after the exponent. 
62 | }; 63 | 64 | result_t lexer_t::decimal_sequence(range_t range) { 65 | const char* begin = range.begin; 66 | while(isdigit(range.peek())) 67 | ++range.begin; 68 | return make_result(begin, range.begin, { }); 69 | } 70 | 71 | result_t lexer_t::decimal_number(range_t range) { 72 | result_t result; 73 | if(auto digits = decimal_sequence(range)) { 74 | range.advance(digits); 75 | 76 | uint64_t x = 0; 77 | for(const char* p = digits->range.begin; p != digits->range.end; ++p) { 78 | int y = *p - '0'; 79 | uint64_t x2 = 10 * x + y; 80 | if(x2 < x) 81 | throw_error(p, "integer overflow in decimal literal"); 82 | x = x2; 83 | } 84 | result = make_result(digits->range, x); 85 | } 86 | return result; 87 | } 88 | 89 | result_t lexer_t::exponent_part(range_t range) { 90 | result_t result; 91 | auto begin = range.begin; 92 | if(range.advance_if('e') || range.advance_if('E')) { 93 | bool sign = false; 94 | if(range.advance_if('-')) sign = true; 95 | else range.advance_if('+'); 96 | 97 | // Expect a digit-sequence here. 98 | if(auto exp = decimal_number(range)) { 99 | range.advance(exp); 100 | if(exp->attr > INT_MAX) 101 | throw_error(exp->range.begin, "exponent is too large"); 102 | 103 | int exponent = exp->attr; 104 | if(sign) exponent = -exponent; 105 | 106 | result = make_result(begin, range.begin, exponent); 107 | 108 | } else 109 | throw_error(range.begin, "expected digit-sequence in exponent-part"); 110 | } 111 | return result; 112 | } 113 | 114 | result_t lexer_t::floating_point_literal(range_t range) { 115 | const char* begin = range.begin; 116 | floating_parts_t parts { }; 117 | if(auto leading = decimal_sequence(range)) { 118 | range.advance(leading); 119 | parts.integer = leading->range; 120 | 121 | if(range.advance_if('.')) { 122 | // We've matched fractional-constant, so both the trailing digit-sequence 123 | // and exponent are optional. 124 | if(auto fractional = decimal_sequence(range)) { 125 | range.advance(fractional); 126 | parts.fractional = fractional->range; 127 | } 128 | 129 | if(auto exp = exponent_part(range)) { 130 | range.advance(exp); 131 | parts.exponent = exp->attr; 132 | } 133 | 134 | } else if(auto exp = exponent_part(range)) { 135 | range.advance(exp); 136 | parts.exponent = exp->attr; 137 | 138 | } else 139 | // A leading decimal sequence with no fraction or exp is an integer 140 | return { }; 141 | 142 | } else if(range.advance_if('.')) { 143 | if(auto fractional = decimal_sequence(range)) { 144 | range.advance(fractional); 145 | parts.fractional = fractional->range; 146 | 147 | if(auto exp = exponent_part(range)) { 148 | range.advance(exp); 149 | parts.exponent = exp->attr; 150 | } 151 | } 152 | 153 | } else 154 | return { }; 155 | 156 | // TODO: Assemble the floating-point literal by hand. 157 | // sscanf the floating point literal into double. 158 | std::string s(begin, range.begin); 159 | double x; 160 | sscanf(s.c_str(), "%lf", &x); 161 | 162 | return make_result(begin, range.begin, x); 163 | } 164 | 165 | result_t lexer_t::integer_literal(range_t range) { 166 | result_t result; 167 | 168 | // For now parse all numbers as base 10. 169 | if(auto number = decimal_number(range)) 170 | result = number; 171 | 172 | return result; 173 | }; 174 | 175 | result_t lexer_t::number(range_t range) { 176 | result_t result; 177 | if(auto num = pp_number(range)) { 178 | // The pp-number must be a floating-point-literal or integer-literal. 
179 | range_t range = num->range; 180 | token_t token { }; 181 | if(auto floating = floating_point_literal(range)) { 182 | range.advance(floating); 183 | 184 | result = make_result(range, token_t { 185 | tk_float, 186 | (int)tokenizer.floats.size(), 187 | floating->range.begin, 188 | floating->range.end 189 | }); 190 | tokenizer.floats.push_back(floating->attr); 191 | 192 | } else if(auto integer = integer_literal(range)) { 193 | range.advance(integer); 194 | 195 | result = make_result(range, token_t { 196 | tk_int, 197 | (int)tokenizer.ints.size(), 198 | integer->range.begin, 199 | integer->range.end 200 | }); 201 | tokenizer.ints.push_back(integer->attr); 202 | } 203 | 204 | if(range) 205 | throw_error(range.begin, "unexpected character in numeric literal"); 206 | } 207 | return result; 208 | } 209 | 210 | } // namespace tok 211 | 212 | END_APEX_NAMESPACE 213 | -------------------------------------------------------------------------------- /src/tokenizer/operators.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | BEGIN_APEX_NAMESPACE 6 | 7 | namespace tok { 8 | 9 | struct tk_symbol_t { 10 | const char* symbol; 11 | tk_kind_t kind; 12 | }; 13 | 14 | const tk_symbol_t tk_op_symbols[] = { 15 | // Standard tokens. 16 | tk_symbol_t { "&" , tk_sym_amp }, 17 | tk_symbol_t { "&&" , tk_sym_ampamp }, 18 | tk_symbol_t { "&=" , tk_sym_ampeq }, 19 | tk_symbol_t { "->" , tk_sym_arrow }, 20 | tk_symbol_t { "->*" , tk_sym_arrowstar }, 21 | tk_symbol_t { "[[" , tk_sym_attrib_l }, 22 | tk_symbol_t { "!" , tk_sym_bang }, 23 | tk_symbol_t { "!=" , tk_sym_bangeq }, 24 | tk_symbol_t { "{" , tk_sym_brace_l }, 25 | tk_symbol_t { "}" , tk_sym_brace_r }, 26 | tk_symbol_t { "[" , tk_sym_bracket_l }, 27 | tk_symbol_t { "]" , tk_sym_bracket_r }, 28 | tk_symbol_t { "^" , tk_sym_caret }, 29 | tk_symbol_t { "^=" , tk_sym_careteq }, 30 | // tk_symbol_t { "<<<" , tk_sym_chevron_l }, 31 | // tk_symbol_t { ">>>" , tk_sym_chevron_r }, 32 | tk_symbol_t { ":" , tk_sym_col }, 33 | tk_symbol_t { "::" , tk_sym_colcol }, 34 | tk_symbol_t { "," , tk_sym_comma }, 35 | tk_symbol_t { "." , tk_sym_dot }, 36 | tk_symbol_t { ".*" , tk_sym_dotstar }, 37 | tk_symbol_t { "..." , tk_sym_ellipsis }, 38 | tk_symbol_t { "=" , tk_sym_eq }, 39 | tk_symbol_t { "==" , tk_sym_eqeq }, 40 | tk_symbol_t { ">" , tk_sym_gt }, 41 | tk_symbol_t { ">=" , tk_sym_gteq }, 42 | tk_symbol_t { ">>" , tk_sym_gtgt }, 43 | tk_symbol_t { ">>=" , tk_sym_gtgteq }, 44 | tk_symbol_t { "<" , tk_sym_lt }, 45 | tk_symbol_t { "<=" , tk_sym_lteq }, 46 | tk_symbol_t { "<<" , tk_sym_ltlt }, 47 | tk_symbol_t { "<<=" , tk_sym_ltlteq }, 48 | tk_symbol_t { "-" , tk_sym_minus }, 49 | tk_symbol_t { "-=" , tk_sym_minuseq }, 50 | tk_symbol_t { "--" , tk_sym_minusminus }, 51 | tk_symbol_t { "(" , tk_sym_paren_l }, 52 | tk_symbol_t { ")" , tk_sym_paren_r }, 53 | tk_symbol_t { "%" , tk_sym_percent }, 54 | tk_symbol_t { "%=" , tk_sym_percenteq }, 55 | tk_symbol_t { "|" , tk_sym_pipe }, 56 | tk_symbol_t { "|=" , tk_sym_pipeeq }, 57 | tk_symbol_t { "||" , tk_sym_pipepipe }, 58 | tk_symbol_t { "+" , tk_sym_plus }, 59 | tk_symbol_t { "+=" , tk_sym_pluseq }, 60 | tk_symbol_t { "++" , tk_sym_plusplus }, 61 | tk_symbol_t { "?" 
, tk_sym_question }, 62 | tk_symbol_t { ";" , tk_sym_semi }, 63 | tk_symbol_t { "/" , tk_sym_slash }, 64 | tk_symbol_t { "/=" , tk_sym_slasheq }, 65 | tk_symbol_t { "*" , tk_sym_star }, 66 | tk_symbol_t { "*=" , tk_sym_stareq }, 67 | tk_symbol_t { "~" , tk_sym_tilde }, 68 | }; 69 | const size_t num_op_symbols = sizeof(tk_op_symbols) / sizeof(tk_symbol_t); 70 | 71 | //////////////////////////////////////////////////////////////////////////////// 72 | 73 | typedef std::pair pair_t; 74 | 75 | class match_operator_t { 76 | public: 77 | match_operator_t(); 78 | result_t substring(range_t range) const; 79 | 80 | private: 81 | // Return the range of operators matching the first character. 82 | pair_t first_char(size_t c) const; 83 | 84 | // Return the range of operators matching a subsequent character. 85 | pair_t next_char(pair_t pair, int pos, char c) const; 86 | 87 | std::vector tokens; 88 | std::vector kinds; 89 | std::vector first_char_map; 90 | }; 91 | 92 | match_operator_t::match_operator_t() { 93 | std::vector symbols(tk_op_symbols, 94 | tk_op_symbols + num_op_symbols); 95 | auto cmp = [](tk_symbol_t a, tk_symbol_t b) { 96 | return strcmp(a.symbol, b.symbol) < 0; 97 | }; 98 | std::sort(symbols.begin(), symbols.end(), cmp); 99 | 100 | tokens.resize(num_op_symbols); 101 | kinds.resize(num_op_symbols); 102 | for(size_t i = 0; i < num_op_symbols; ++i) { 103 | tokens[i] = symbols[i].symbol; 104 | kinds[i] = symbols[i].kind; 105 | } 106 | 107 | first_char_map.resize(257); 108 | for(size_t i = 0; i < 256; ++i) { 109 | auto cmp = [](const char* p, char c) { 110 | return (uint8_t)p[0] < (uint8_t)c; 111 | }; 112 | 113 | auto it = std::lower_bound(tokens.begin(), tokens.end(), (char)i, cmp); 114 | first_char_map[i] = it - tokens.begin(); 115 | } 116 | first_char_map[256] = tokens.size(); 117 | } 118 | 119 | inline pair_t match_operator_t::first_char(size_t c) const { 120 | // Build a map so the range of the first character is a direct lookup. 121 | size_t begin = first_char_map[c]; 122 | size_t end = first_char_map[c + 1]; 123 | return std::make_pair(begin, end); 124 | } 125 | 126 | inline pair_t match_operator_t::next_char(pair_t pair, int pos, char c) const { 127 | assert('\0' != c); 128 | 129 | // Scan from left-to-right until we hit a match. 130 | size_t begin2 = pair.first; 131 | while(begin2 != pair.second && c != tokens[begin2][pos]) 132 | ++begin2; 133 | 134 | // Scan from left-to-right until we hit a miss. 135 | size_t end2 = begin2; 136 | while(end2 != pair.second && c == tokens[end2][pos]) 137 | ++end2; 138 | 139 | return std::make_pair(begin2, end2); 140 | } 141 | 142 | result_t match_operator_t::substring(range_t range) const { 143 | const char* begin = range.begin; 144 | result_t result; 145 | 146 | pair_t match = first_char((uint8_t)range.peek()); 147 | if(match.first < match.second) { 148 | int pos = 0; 149 | pair_t match2 = match; 150 | while(match2.first < match2.second && range) { 151 | ++pos; 152 | ++range.begin; 153 | 154 | match = match2; 155 | if(char c = range.peek()) 156 | match2 = next_char(match, pos, c); 157 | else 158 | match2.first = match2.second; 159 | } 160 | 161 | assert(match.first < match.second); 162 | if('\0' == tokens[match.first][pos]) { 163 | // We've run out of matches. Test the first operator from the input 164 | // range. 
165 | result = make_result(begin, range.begin, kinds[match.first]); 166 | } 167 | } 168 | 169 | return result; 170 | } 171 | 172 | result_t match_operator(range_t range) { 173 | static match_operator_t match; 174 | return match.substring(range); 175 | } 176 | 177 | } // namespace tok 178 | 179 | END_APEX_NAMESPACE 180 | -------------------------------------------------------------------------------- /src/tokenizer/tokenizer.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | BEGIN_APEX_NAMESPACE 6 | 7 | namespace tok { 8 | 9 | apex::parse::range_t tokenizer_t::token_range() const { 10 | return { tokens.data(), tokens.data() + tokens.size() }; 11 | } 12 | 13 | int tokenizer_t::reg_string(range_t range) { 14 | int id = find_string(range); 15 | if(-1 == id) { 16 | id = (int)strings.size(); 17 | strings.push_back(std::string(range.begin, range.end)); 18 | } 19 | return id; 20 | } 21 | 22 | int tokenizer_t::find_string(range_t range) const { 23 | for(size_t i = 0; i < strings.size(); ++i) { 24 | if(0 == strings[i].compare(range.begin)) 25 | return i; 26 | } 27 | return -1; 28 | } 29 | 30 | void tokenizer_t::tokenize() { 31 | // Mark the byte of each line offset. 32 | size_t len = text.size(); 33 | line_offsets.push_back(0); 34 | for(size_t i = 0; i < len; ++i) { 35 | if('\n' == text[i]) 36 | line_offsets.push_back(i); 37 | } 38 | line_offsets.push_back(len); 39 | 40 | lexer_t lexer(*this); 41 | range_t range { text.data(), text.data() + text.size() }; 42 | 43 | while(true) { 44 | // Skip past whitespace and comments. 45 | lexer.advance_skip(range); 46 | 47 | if(auto token = lexer.token(range)) { 48 | range.advance(token); 49 | tokens.push_back(token->attr); 50 | } else 51 | break; 52 | } 53 | } 54 | 55 | int tokenizer_t::token_offset(source_loc_t loc) const { 56 | return tokens[loc.index].begin - text.c_str(); 57 | } 58 | 59 | int tokenizer_t::token_line(int offset) const { 60 | // Binary search to find the line for this byte offset. 61 | auto it = std::upper_bound(line_offsets.begin(), line_offsets.end(), offset); 62 | int line = it - line_offsets.begin() - 1; 63 | return line; 64 | } 65 | 66 | int tokenizer_t::token_col(int offset, int line) const { 67 | // Walk forward, decoding UTF-8 and count the column adjustment to reach 68 | // the offset from the line offset. 69 | int col = 0; 70 | int pos = line_offsets[line]; 71 | while(pos < offset) { 72 | std::pair ucs = from_utf8(text.data() + pos); 73 | 74 | // Advance by the number of bytes in the character. 75 | pos += ucs.first; 76 | 77 | // Advance by one column. 
78 | ++col; 79 | } 80 | return col; 81 | } 82 | 83 | std::pair tokenizer_t::token_linecol(int offset) const { 84 | int line = token_line(offset); 85 | int col = token_col(offset, line); 86 | return { line, col }; 87 | } 88 | 89 | std::pair tokenizer_t::token_linecol(source_loc_t loc) const { 90 | return token_linecol(token_offset(loc)); 91 | } 92 | 93 | } // namespace tok 94 | 95 | END_APEX_NAMESPACE 96 | 97 | -------------------------------------------------------------------------------- /src/tokenizer/tokens.cxx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seanbaxter/apex/f07a92857efd0d7a23d174227b8154f4cbaf01b6/src/tokenizer/tokens.cxx -------------------------------------------------------------------------------- /src/util/format.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | BEGIN_APEX_NAMESPACE 5 | 6 | std::string format(const char* pattern, ...) { 7 | va_list args; 8 | va_start(args, pattern); 9 | std::string s = vformat(pattern, args); 10 | va_end(args); 11 | return s; 12 | } 13 | 14 | std::string vformat(const char* pattern, va_list args) { 15 | va_list args_copy; 16 | va_copy(args_copy, args); 17 | 18 | int len = std::vsnprintf(nullptr, 0, pattern, args); 19 | std::string result(len, ' '); 20 | std::vsnprintf(result.data(), len + 1, pattern, args_copy); 21 | 22 | va_end(args_copy); 23 | return result; 24 | } 25 | 26 | END_APEX_NAMESPACE 27 | -------------------------------------------------------------------------------- /src/util/utf.cxx: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | BEGIN_APEX_NAMESPACE 4 | 5 | // Encodes ucs into the UTF-8 buffer at s. Returs the number of characters 6 | // encoded. 0 indicates error. 7 | int to_utf8(char* s, int ucs) { 8 | if(ucs <= 0x007f) { 9 | s[0] = (char)ucs; 10 | return 1; 11 | 12 | } else if(ucs <= 0x07ff) { 13 | s[0] = 0xc0 | (ucs>> 6); 14 | s[1] = 0x80 | (0x3f & ucs); 15 | return 2; 16 | 17 | } else if(ucs <= 0xffff) { 18 | s[0] = 0xe0 | (ucs>> 12); 19 | s[1] = 0x80 | (0x3f & (ucs>> 6)); 20 | s[2] = 0x80 | (0x3f & ucs); 21 | return 3; 22 | 23 | } else if(ucs <= 0x10ffff) { 24 | s[0] = 0xf0 | (ucs>> 18); 25 | s[1] = 0x80 | (0x3f & (ucs>> 12)); 26 | s[2] = 0x80 | (0x3f & (ucs>> 6)); 27 | s[3] = 0x80 | (0x3f & ucs); 28 | return 4; 29 | } 30 | return 0; 31 | } 32 | 33 | // Returns the number of code-units consumed and the value of the character. 34 | // 0 indicates error. 35 | std::pair from_utf8(const char* s) { 36 | std::pair result { }; 37 | 38 | if(0 == (0x80 & s[0])) { 39 | result = std::make_pair(1, s[0]); 40 | 41 | } else if(0xc0 == (0xe0 & s[0])) { 42 | if(0x80 == (0xc0 & s[1])) { 43 | int ucs = (0x3f & s[1]) + ((0x1f & s[0])<< 6); 44 | result = std::make_pair(2, ucs); 45 | } 46 | 47 | } else if(0xe0 == (0xf0 & s[0])) { 48 | if(0x80 == (0xc0 & s[1]) && 49 | 0x80 == (0xc0 & s[2])) { 50 | int ucs = (0x3f & s[2]) + ((0x3f & s[1])<< 6) + ((0x0f & s[0])<< 12); 51 | result = std::make_pair(3, ucs); 52 | } 53 | 54 | } else if(0xf0 == (0xf8 & s[0])) { 55 | if(0x80 == (0xc0 & s[1]) && 56 | 0x80 == (0xc0 & s[2]) && 57 | 0x80 == (0xc0 & s[3])) { 58 | int ucs = (0x3f & s[3]) + ((0x3f & s[2])<< 6) + 59 | ((0x3f & s[1])<< 12) + ((0x07 & s[0])<< 18); 60 | result = std::make_pair(4, ucs); 61 | } 62 | } 63 | return result; 64 | } 65 | 66 | END_APEX_NAMESPACE 67 | --------------------------------------------------------------------------------
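Usage sketch (illustrative only, not part of the repository): the parser entry point defined in `src/parse/grammar.cxx` can be driven directly from client code. The snippet below assumes that `include/apex/parse.hxx` declares `parse_t` and `parse_expression` in the `apex::parse` namespace with the public members shown above; the header path, namespace spelling, and member access are assumptions, not verified against the headers.

```cpp
// Hypothetical usage sketch -- assumes <apex/parse.hxx> exposes the parse_t
// and parse_expression declarations seen in src/parse/grammar.cxx.
#include <apex/parse.hxx>
#include <cstdio>

int main() {
  // Tokenize and parse a C-like expression into an AST.
  apex::parse::parse_t p = apex::parse::parse_expression("sq(x / y) * sin(x * y)");

  // parse_t keeps the tokenizer state alongside the AST root, so token data
  // and source locations stay available to later passes (e.g. autodiff).
  std::printf("parsed %zu tokens\n", p.tokenizer.tokens.size());
  return 0;
}
```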