├── License.txt
├── README.md
├── compiler.lyx
├── design.txt
└── src
    ├── examples
    │   ├── goertzel.lisp
    │   └── lerp.lisp
    ├── remora-front
    │   ├── sexp_files.ml
    │   └── sexp_files.mli
    └── remora-internal
        ├── annotation.ml
        ├── annotation.mli
        ├── basic_ast.ml
        ├── basic_ast.mli
        ├── closures.ml
        ├── closures.mli
        ├── erased_ast.ml
        ├── erased_ast.mli
        ├── frame_notes.ml
        ├── frame_notes.mli
        ├── globals.ml
        ├── globals.mli
        ├── map_replicate_ast.ml
        ├── map_replicate_ast.mli
        ├── run_tests.ml
        ├── substitution.ml
        ├── substitution.mli
        ├── test_basic_ast.ml
        ├── test_basic_ast.mli
        ├── test_closures.ml
        ├── test_closures.mli
        ├── test_erased_ast.ml
        ├── test_frame_notes.ml
        ├── test_map_replicate_ast.ml
        ├── test_typechecker.ml
        ├── typechecker.ml
        └── typechecker.mli
/License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of NVIDIA CORPORATION nor the names of its 12 | contributors may be used to endorse or promote products derived 13 | from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the in-progress compiler for Remora. 2 | 3 | For more on the language itself, see [the full version of the ESOP 2014 paper](http://www.ccs.neu.edu/home/jrslepak/esop14-full.pdf). Alternatively, try an untyped prototype of the language, implemented as a [Racket #lang](https://github.com/jrslepak/Remora/tree/master/remora). 4 | -------------------------------------------------------------------------------- /design.txt: -------------------------------------------------------------------------------- 1 | Syntax Tree 2 | In Remora, expressions and array elements are syntactically separated. 3 | Every expression must represent an array, whereas array elements may be non-array values (such as base data and functions). 4 | This is normally described with two mutually recursive tree types, one for expressions and one for elements. 5 | Here, we use delayed recursion to gain some flexibility in how the AST is annotated (annotations are added by later passes). 6 | The shape of a syntax tree node is defined as a parameterized variant (expr_form for expressions, elt_form for elements), with one branch for each possible syntactic form. 7 | 8 | Two ways to complete the recursion are offered. 9 | A rem_expr or rem_elt simply uses rem_expr and rem_elt as the expression and element child node types.
10 | The term (foo bar baz) would be represented as RExpr (App (RExpr (Var "foo"), [RExpr (Var "bar"); RExpr (Var "baz")])). 11 | This form is meant for ASTs constructed either by hand (e.g., in test data) or by a parser, neither of which has any information to include as annotations. 12 | In an ann_expr or ann_elt, the constructor also includes a polymorphic annotation field. 13 | 14 | There is a pair of procedures for mapping a function over the AST. 15 | The mapped function specifies the behavior at each node; the map procedure handles the recursion as appropriate. 16 | This lets AST traversals explicitly handle only the non-trivial cases and leave simple pass-through cases to map_expr_form/map_elt_form. 17 | For example, annot_expr_app_frame in Frame_notes annotates App forms with their frame shapes and marks "not applicable" on other forms. 18 | 19 | A non-annotated AST can be converted to an annotated one in which every node is marked with a designated "blank" annotation. 20 | Any function can also be mapped over the annotations in an AST, and two ASTs that differ only in their annotations can have their annotations merged using a user-provided function (e.g., one that constructs pairs). 21 | These procedures reuse as much of the input AST structure as possible (e.g., merging does not copy annotations or AST leaves; it just creates new references to them). 22 | Because the AST is treated as a persistent data structure, annotation passes can safely "forget" their input annotations. 23 | If the input annotations are needed later, they can be merged back in. 24 | 25 | 26 | No type inference 27 | 28 | 29 | Annotating with types 30 | 31 | 32 | Annotating with frames 33 | Once type annotations are present, the frame shape of an application form can be determined by comparing the function's result cell type with the actual type ascribed to the application form. 34 | The function for making this comparison is the same frame_contribution used in type checking.
35 | This frame shape is the information that will later be needed to determine which dimensions a Map form should look past. 36 | Map treats all of its arguments the same way, so the pass that emits Map forms must also emit Replicate operations to ensure that all arguments in an application form have the same frame shape. 37 | A Replicate must add dimensions to its argument, expanding the argument's own frame shape to match the frame shape of the surrounding application form. 38 | To prepare for this, every argument within an App form is annotated with the necessary frame expansion. 39 | This pass takes extra arguments giving the expected cell type for the current expression and the overall frame shape of the enclosing application. 40 | When processing an App form, the recursive calls for the arguments are given the corresponding piece of the function's type annotation as well as the App form's frame shape annotation. 41 | Recursive calls for subterms of non-App forms are given None and NotApp, indicating that they should produce NotArg as their expansion annotation.
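As a rough illustration of the frame computations described above, here is a minimal OCaml sketch. It assumes concrete shapes represented as `int list`, whereas the compiler works with symbolic `idx` values and its own frame_contribution; the names `drop_suffix`, `drop_prefix`, `app_frame` and `arg_expansion` are hypothetical. It also assumes Remora's prefix-agreement rule, under which each argument's frame must be a prefix of the application's frame, so the expansion a Replicate must add is whatever part of the application frame lies beyond the argument's own frame.

```ocaml
(* Hypothetical sketch: shapes are concrete int lists here, whereas the
   compiler itself works over symbolic idx values. *)

(* If [suffix] is a suffix of [shape], return the dimensions before it. *)
let drop_suffix ~(suffix : int list) (shape : int list) : int list option =
  let n = List.length shape and k = List.length suffix in
  if k > n then None
  else
    let front = List.filteri (fun i _ -> i < n - k) shape in
    let back = List.filteri (fun i _ -> i >= n - k) shape in
    if back = suffix then Some front else None

(* If [prefix] is a prefix of [shape], return the dimensions after it. *)
let rec drop_prefix ~(prefix : int list) (shape : int list) : int list option =
  match prefix, shape with
  | [], rest -> Some rest
  | p :: ps, s :: ss when p = s -> drop_prefix ~prefix:ps ss
  | _ -> None

(* Frame of an App form: the shape ascribed to the whole application with
   the function's result cell shape removed from the end. *)
let app_frame ~(result_cell : int list) ~(app_shape : int list)
    : int list option =
  drop_suffix ~suffix:result_cell app_shape

(* Frame expansion for one argument: its own frame is its shape minus the
   expected cell shape; a Replicate must add the rest of the App's frame
   (passed here as [frame]). *)
let arg_expansion ~(expected_cell : int list) ~(arg_shape : int list)
    ~(frame : int list) : int list option =
  match drop_suffix ~suffix:expected_cell arg_shape with
  | None -> None
  | Some arg_frame -> drop_prefix ~prefix:arg_frame frame
```

For the main expression of lerp.lisp, which applies lerp to two 2-element vectors and a scalar, every cell shape is scalar, so the application frame is `[2]`; under this sketch the vector arguments need no expansion (`Some []`), while the scalar `0.5` needs expansion `[2]`. Returning `None` models a shape mismatch, which a well-typed program should never reach.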
42 | 43 | Erasing explicit types 44 | 45 | 46 | Map/Replicate IR 47 | -------------------------------------------------------------------------------- /src/examples/goertzel.lisp: -------------------------------------------------------------------------------- 1 | (RProg 2 | ;; Top-level definitions 3 | ((RDefn sample-rate (TArray (IShape ()) TFloat) 4 | (RExpr (Arr () ((RElt (Float 8000.)))))) 5 | (RDefn tau (TArray (IShape ()) TFloat) 6 | (RExpr 7 | (Arr () 8 | ((RElt 9 | (Float 10 | 6.283185307179586)))))) 11 | (RDefn 12 | sinusoid* 13 | (TDProd ((wl SNat)) 14 | (TAll (wt) 15 | (TArray 16 | (IShape ()) 17 | (TFun ((TArray (IShape ((IVar wl (Some SNat)))) (TVar wt)) 18 | (TArray (IShape ()) TFloat) 19 | (TArray (IShape ()) TFloat)) 20 | (TArray (IShape ((IVar wl (Some SNat)))) TFloat))))) 21 | (RExpr 22 | (ILam 23 | ((wl SNat)) 24 | (RExpr 25 | (TLam 26 | (wt) 27 | (RExpr 28 | (Arr 29 | () 30 | (#; 31 | (RElt 32 | (Lam 33 | ;; Maybe using '&' suffix to identify witness args is 34 | ;; a useful naming convention? 35 | ((length& (TArray (IShape ((IVar wl (Some SNat)))) (TVar wt))) 36 | (freq (TArray (IShape ()) TFloat)) 37 | (phase (TArray (IShape ()) TFloat))) 38 | (RExpr 39 | (App 40 | (RExpr (Var cos)) 41 | ((RExpr 42 | (App 43 | (RExpr (Var +.)) 44 | ((RExpr (Var phase)) 45 | (RExpr 46 | (App 47 | (RExpr (Var *.)) 48 | (;; [0 1 2 ...] 
49 | (RExpr (App (RExpr (Var "float")) 50 | ((RExpr 51 | (App 52 | (RExpr 53 | (TApp 54 | (RExpr 55 | (IApp 56 | (RExpr (Var iota*)) 57 | ((IShape ((IVar wl (Some SNat))))))) 58 | ((TVar wt)))) 59 | ((RExpr (Var length&)))))))) 60 | (RExpr 61 | (App 62 | (RExpr (Var *.)) 63 | ((RExpr 64 | (Var freq)) 65 | (RExpr 66 | (Var tau)))))))))))))))))))))))) 67 | (RDefn 68 | goertzel-iir-step 69 | (TArray 70 | (IShape ()) 71 | (TFun ((TArray (IShape ()) TFloat)) 72 | (TArray (IShape ()) 73 | (TFun ((TArray (IShape ()) TFloat) 74 | (TArray (IShape ((INat 2))) TFloat)) 75 | (TArray (IShape ((INat 2))) TFloat))))) 76 | (RExpr 77 | (Arr 78 | () 79 | ((RElt 80 | (Lam 81 | ((freq (TArray (IShape ()) TFloat))) 82 | (RExpr 83 | (Arr 84 | () 85 | ((RElt 86 | (Lam 87 | ((next (TArray (IShape ()) TFloat)) 88 | (accum (TArray (IShape ((INat 2))) TFloat))) 89 | (RExpr 90 | (Arr 91 | (2) 92 | ((RElt 93 | (Expr 94 | (RExpr 95 | (App 96 | (RExpr (Var -.)) 97 | ((RExpr 98 | (App 99 | (RExpr (Var +.)) 100 | ((RExpr (Var next)) 101 | (RExpr 102 | (App 103 | (RExpr (Var *.)) 104 | ((RExpr (Arr () ((RElt (Float 2.))))) 105 | (RExpr 106 | (App 107 | (RExpr (Var *.)) 108 | ((RExpr 109 | (App 110 | (RExpr (Var cos)) 111 | ((RExpr 112 | (App 113 | (RExpr (Var *.)) 114 | ((RExpr (Var tau)) 115 | (RExpr (Var freq)))))))) 116 | (RExpr 117 | (App 118 | (RExpr 119 | (TApp 120 | (RExpr 121 | (IApp 122 | (RExpr (Var head)) 123 | ((INat 1) 124 | (IShape ())))) 125 | (TFloat))) 126 | ((RExpr (Var accum)))))))))))))) 127 | (RExpr 128 | (App 129 | (RExpr 130 | (TApp 131 | (RExpr 132 | (IApp 133 | (RExpr (Var tail)) 134 | ((INat 1) 135 | (IShape ())))) 136 | (TFloat))) 137 | ((RExpr (Var accum)))))))))) 138 | (RElt 139 | (Expr 140 | (RExpr 141 | (App 142 | (RExpr 143 | (TApp 144 | (RExpr 145 | (IApp 146 | (RExpr (Var head)) 147 | ((INat 1) 148 | (IShape ())))) 149 | (TFloat))) 150 | ((RExpr (Var accum))))))))))))))))))))) 151 | (RDefn 152 | goertzel-iir 153 | (TDProd 154 | ((len SNat)) 155 | (TArray (IShape 
()) 156 | (TFun ((TArray (IShape ()) TFloat) 157 | (TArray (IShape ((IVar len (Some SNat)))) TFloat)) 158 | (TArray (IShape ((IVar len (Some SNat)))) TFloat)))) 159 | (RExpr 160 | (ILam 161 | ((len SNat)) 162 | (RExpr 163 | (Arr 164 | () 165 | ((RElt 166 | (Lam 167 | ((freq (TArray (IShape ()) TFloat)) 168 | (signal (TArray (IShape ((IVar len (Some SNat)))) TFloat))) 169 | (RExpr 170 | (App 171 | (RExpr 172 | (TApp 173 | (RExpr 174 | (IApp 175 | (RExpr (Var head)) 176 | ((INat 1) 177 | (IShape ())))) 178 | (TFloat))) 179 | ((RExpr 180 | (App 181 | (RExpr 182 | (TApp 183 | (RExpr 184 | (IApp 185 | (RExpr (Var scanl)) 186 | ((IVar len (Some SNat)) 187 | (IShape ((INat 2))) 188 | (IShape ())))) 189 | (TFloat TFloat))) 190 | ((RExpr (App (RExpr (Var goertzel-iir-step)) 191 | ((RExpr (Var freq))))) 192 | (RExpr (Arr (2) ((RElt (Float 0.)) (RElt (Float 0.))))) 193 | (RExpr (Var signal))))))))))))))))) 194 | ;; Main expression: Read chosen frequencies and signal, apply IIR stage 195 | ;; of Goertzel filters at those frequencies. 196 | (RExpr 197 | (App 198 | (RExpr 199 | (Arr 200 | () 201 | ((RElt 202 | (Lam 203 | ((freq (TArray (IShape ()) TFloat))) 204 | (RExpr 205 | (Unpack 206 | (l) signal 207 | (RExpr (App (RExpr (Var readvec_f)) ())) 208 | (RExpr 209 | (Pack 210 | ((IVar l (Some SNat))) 211 | (RExpr 212 | (App 213 | (RExpr 214 | (IApp 215 | (RExpr (Var goertzel-iir)) 216 | ((IVar l (Some SNat))))) 217 | ((RExpr (Var freq)) 218 | (RExpr (Var signal))))) 219 | (TDSum ((sig-length SNat)) 220 | (TArray (IShape ((IVar sig-length (Some SNat)))) 221 | TFloat))))))))))) 222 | ((RExpr (App (RExpr (Var readscal_f)) ())))))) 223 | 224 | -------------------------------------------------------------------------------- /src/examples/lerp.lisp: -------------------------------------------------------------------------------- 1 | ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; 2 | ;; Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved.
;; 3 | ;; ;; 4 | ;; Redistribution and use in source and binary forms, with or without ;; 5 | ;; modification, are permitted provided that the following conditions ;; 6 | ;; are met: ;; 7 | ;; * Redistributions of source code must retain the above copyright ;; 8 | ;; notice, this list of conditions and the following disclaimer. ;; 9 | ;; * Redistributions in binary form must reproduce the above copyright ;; 10 | ;; notice, this list of conditions and the following disclaimer in the ;; 11 | ;; documentation and/or other materials provided with the distribution. ;; 12 | ;; * Neither the name of NVIDIA CORPORATION nor the names of its ;; 13 | ;; contributors may be used to endorse or promote products derived ;; 14 | ;; from this software without specific prior written permission. ;; 15 | ;; ;; 16 | ;; THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY ;; 17 | ;; EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE ;; 18 | ;; IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR ;; 19 | ;; PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR ;; 20 | ;; CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, ;; 21 | ;; EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, ;; 22 | ;; PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR ;; 23 | ;; PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY ;; 24 | ;; OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT ;; 25 | ;; (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE ;; 26 | ;; OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
;; 27 | ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; 28 | 29 | (RProg 30 | ((RDefn 31 | lerp 32 | (TArray (IShape ()) 33 | (TFun ((TArray (IShape ()) TFloat) 34 | (TArray (IShape ()) TFloat) 35 | (TArray (IShape ()) TFloat)) 36 | (TArray (IShape ()) TFloat))) 37 | (RExpr 38 | (Arr () 39 | ((RElt 40 | (Lam ((lo (TArray (IShape ()) TFloat)) 41 | (hi (TArray (IShape ()) TFloat)) 42 | (mid (TArray (IShape ()) TFloat))) 43 | (RExpr 44 | (App (RExpr (Var +.)) 45 | ((RExpr (App (RExpr (Var *.)) 46 | ((RExpr (Var mid)) 47 | (RExpr (Var hi))))) 48 | (RExpr (App (RExpr (Var *.)) 49 | ((RExpr (App 50 | (RExpr (Var -.)) 51 | ((RExpr 52 | (Arr () 53 | ((RElt (Float 1.0))))) 54 | (RExpr (Var mid))))) 55 | (RExpr (Var lo))))))))))))))) 56 | (RExpr (App (RExpr (Var lerp)) 57 | ((RExpr (Arr (2) ((RElt (Float 0.0)) (RElt (Float 1.0))))) 58 | (RExpr (Arr (2) ((RElt (Float 7.0)) (RElt (Float 4.0))))) 59 | (RExpr (Arr () ((RElt (Float 0.5))))))))) 60 | -------------------------------------------------------------------------------- /src/remora-front/sexp_files.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. 
*) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Basic_ast 31 | 32 | (* Load (and s-expression deserialize) an AST from file *) 33 | let load_ast filename = Sexp.load_sexp filename |> rem_prog_of_sexp 34 | -------------------------------------------------------------------------------- /src/remora-front/sexp_files.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. 
*) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*) 27 | (******************************************************************************) 28 | 29 | val load_ast : string -> Basic_ast.rem_prog 30 | -------------------------------------------------------------------------------- /src/remora-internal/annotation.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Basic_ast 31 | open Core.Option 32 | open Core.Option.Monad_infix 33 | 34 | (* Not sure why there's only map2_exn in Core and no total map2 *) 35 | let map2 ~f xs ys = 36 | if List.length xs = List.length ys then Some (List.map2_exn ~f:f xs ys) 37 | else None 38 | 39 | (* Given two matching ASTs, merge their annotations *) 40 | let rec annot_elt_merge 41 | (f: 'a -> 'b -> 'c) 42 | (ast1: 'a ann_elt) 43 | (ast2: 'b ann_elt) : 'c ann_elt option = 44 | let (AnnRElt (annot1, elt1), AnnRElt (annot2, elt2)) = (ast1, ast2) in 45 | let new_annot = f annot1 annot2 46 | and (new_elt: ('c ann_elt, 'c ann_expr) elt_form option) = 47 | match (elt1, elt2) with 48 | | (Lam (bind1, body1), Lam (bind2, body2)) -> 49 | if (bind1 = bind2) 50 | then (annot_expr_merge f body1 body2 >>= fun (new_body: 'c ann_expr) -> 51 | Some (Lam (bind1, new_body))) 52 | else None 53 | | (Expr e1, Expr e2) -> 54 | annot_expr_merge f e1 e2 >>= fun (new_expr: 'c ann_expr) -> 55 | Some (Expr new_expr) 56 | (* In these cases, must reconstruct elt1 to use it at a different type.
*) 57 | | ((Float c1) as v1, Float c2) -> Option.some_if (c1 = c2) v1 58 | | ((Int c1) as v1, Int c2) -> Option.some_if (c1 = c2) v1 59 | | ((Bool c1) as v1, Bool c2) -> Option.some_if (c1 = c2) v1 60 | (* Anything else with an already-handled form means mismatching ASTs. *) 61 | | ((Lam _ | Expr _ | Float _ | Int _ | Bool _), _) -> None 62 | in new_elt >>= fun valid_new_elt -> 63 | return (AnnRElt (new_annot, valid_new_elt)) 64 | and annot_expr_merge 65 | (f: 'a -> 'b -> 'c) 66 | (ast1: 'a ann_expr) 67 | (ast2: 'b ann_expr) : 'c ann_expr option = 68 | let (AnnRExpr (annot1, expr1), AnnRExpr (annot2, expr2)) = (ast1, ast2) in 69 | let new_annot = f annot1 annot2 70 | and (new_expr: ('c ann_expr, 'c ann_elt) expr_form option) = 71 | match (expr1, expr2) with 72 | | (App (fn1, args1), App (fn2, args2)) -> 73 | annot_expr_merge f fn1 fn2 >>= fun new_fn -> 74 | map2 ~f:(annot_expr_merge f) args1 args2 >>= fun merged -> 75 | Option.all merged >>= fun new_args -> 76 | return (App (new_fn, new_args)) 77 | | (TApp (fn1, t_args1), TApp (fn2, t_args2)) -> 78 | annot_expr_merge f fn1 fn2 >>= fun new_fn -> 79 | Option.some_if (t_args1 = t_args2) (TApp (new_fn, t_args1)) 80 | | (TLam (bind1, body1), TLam (bind2, body2)) -> 81 | if bind1 = bind2 82 | then annot_expr_merge f body1 body2 >>= fun new_body -> 83 | Some (TLam (bind1, new_body)) 84 | else None 85 | | (IApp (fn1, i_args1), IApp (fn2, i_args2)) -> 86 | annot_expr_merge f fn1 fn2 >>= fun new_fn -> 87 | Option.some_if (i_args1 = i_args2) (IApp (new_fn, i_args1)) 88 | | (ILam (bind1, body1), ILam (bind2, body2)) -> 89 | if bind1 = bind2 90 | then annot_expr_merge f body1 body2 >>= fun new_body -> 91 | Some (ILam (bind1, new_body)) 92 | else None 93 | | (Arr (dims1, elts1) , Arr (dims2, elts2)) -> 94 | if dims1 = dims2 95 | (* then (try Some (List.map2_exn ~f:annot_elt_merge elts1 elts2) with *) 96 | (* | Invalid_argument _ -> None) >>= fun merged -> *) 97 | then map2 ~f:(annot_elt_merge f) elts1 elts2 >>= fun merged -> 98 
| Option.all merged >>= fun new_elts -> 99 | return (Arr (dims2, new_elts)) 100 | else None 101 | | (Var v1 as v, Var v2) -> Option.some_if (v1 = v2) v 102 | | (Pack (idxs1, value1, type1), Pack (idxs2, value2, type2)) -> 103 | if idxs1 = idxs2 && type1 = type2 104 | then (annot_expr_merge f value1 value2 >>= fun new_value -> 105 | Some (Pack (idxs1, new_value, type1))) 106 | else None 107 | | (Unpack (i_vars1, v1, dsum1, body1), 108 | Unpack (i_vars2, v2, dsum2, body2)) -> 109 | if i_vars1 = i_vars2 && v1 = v2 110 | then annot_expr_merge f dsum1 dsum2 >>= fun new_dsum -> 111 | annot_expr_merge f body1 body2 >>= fun new_body -> 112 | return (Unpack (i_vars1, v1, new_dsum, new_body)) 113 | else None 114 | | (Let (var1, bound1, body1), Let (var2, bound2, body2)) -> 115 | if var1 = var2 116 | then annot_expr_merge f bound1 bound2 >>= fun new_bound -> 117 | annot_expr_merge f body1 body2 >>= fun new_body -> 118 | return (Let (var1, new_bound, new_body)) 119 | else None 120 | | (Tuple elts1, Tuple elts2) -> 121 | map2 ~f:(annot_expr_merge f) elts1 elts2 |> 122 | Option.map ~f:Option.all |> Option.join >>= fun new_elts -> 123 | return (Tuple new_elts) 124 | | (Field (n1, tup1), Field (n2, tup2)) -> 125 | if n1 = n2 126 | then annot_expr_merge f tup1 tup2 >>= fun new_tup -> 127 | return (Field (n1, new_tup)) 128 | else None 129 | | (LetTup (vars1, bound1, body1), LetTup (vars2, bound2, body2)) -> 130 | if vars1 = vars2 131 | then annot_expr_merge f bound1 bound2 >>= fun new_bound -> 132 | annot_expr_merge f body1 body2 >>= fun new_body -> 133 | return (LetTup (vars1, new_bound, new_body)) 134 | else None 135 | (* Anything else with an already-handled form means mismatching ASTs. 
*) 136 | | ((App _ | TApp _ | TLam _ | IApp _ | ILam _ 137 | | Arr _ | Var _ | Pack _ | Unpack _ 138 | | Let _ | Tuple _ | Field _ | LetTup _), _) -> None 139 | in new_expr >>= fun valid_new_expr -> 140 | return (AnnRExpr (new_annot, valid_new_expr)) 141 | ;; 142 | let annot_defn_merge 143 | (f: 'a -> 'b -> 'c) 144 | (ast1: 'a ann_defn) 145 | (ast2: 'b ann_defn) : 'c ann_defn option = 146 | let (AnnRDefn (n1, t1, b1), AnnRDefn (n2, t2, b2)) = (ast1, ast2) in 147 | if (n1 = n2 && t1 = t2) 148 | then annot_expr_merge f b1 b2 >>= fun body -> AnnRDefn (n1, t1, body) 149 | |> return 150 | else None 151 | let annot_prog_merge 152 | (f: 'a -> 'b -> 'c) 153 | (ast1: 'a ann_prog) 154 | (ast2: 'b ann_prog) : 'c ann_prog option = 155 | let (AnnRProg (annot1, defs1, expr1), AnnRProg (annot2, defs2, expr2)) 156 | = (ast1, ast2) in 157 | map2 ~f:(annot_defn_merge f) defs1 defs2 158 | >>| Option.all |> Option.join 159 | >>= fun (ds: 'c ann_defn list) -> 160 | annot_expr_merge f expr1 expr2 >>= fun e -> 161 | AnnRProg (f annot1 annot2, ds, e) |> return 162 | 163 | 164 | (* Given an annotated AST, apply a function to its annotations *) 165 | let rec annot_elt_fmap 166 | ~(f: 'a -> 'b) 167 | (AnnRElt (annot, elt): 'a ann_elt) : 'b ann_elt = 168 | AnnRElt (f annot, (map_elt_form ~f_expr:(annot_expr_fmap ~f:f) elt)) 169 | and annot_expr_fmap 170 | ~(f: 'a -> 'b) 171 | (AnnRExpr (annot, expr): 'a ann_expr) : 'b ann_expr = 172 | AnnRExpr (f annot, (map_expr_form 173 | ~f_expr:(annot_expr_fmap ~f:f) 174 | ~f_elt:(annot_elt_fmap ~f:f) expr)) 175 | ;; 176 | 177 | let annot_defn_fmap 178 | ~(f: 'a -> 'b) 179 | (ast: 'a ann_defn) : 'b ann_defn = 180 | let AnnRDefn (n, t, v) = ast in AnnRDefn (n, t, annot_expr_fmap ~f:f v) 181 | 182 | let annot_prog_fmap 183 | ~(f: 'a -> 'b) 184 | (ast: 'a ann_prog) : 'b ann_prog = 185 | let AnnRProg (annot, defns, expr) = ast in 186 | AnnRProg (f annot, 187 | List.map ~f:(annot_defn_fmap ~f:f) defns, 188 | annot_expr_fmap ~f:f expr) 189 | 
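The merge and fmap functions in annotation.ml follow a pattern that can be seen in miniature on a toy tree. The following OCaml sketch is hypothetical and self-contained: the types `form` and `ann` and the functions `map_form`, `fmap` and `merge` are illustrative stand-ins, not part of the compiler. As with `expr_form`, the node type is parameterized over its child type, and, as with `ann_expr`, the annotated tree ties the recursive knot. Merging returns `None` on any structural mismatch, mirroring `annot_expr_merge`.

```ocaml
(* Toy node shape, parameterized over the type of its children
   (a stand-in for expr_form). *)
type 'self form =
  | Lit of int
  | Add of 'self * 'self

(* Annotated tree: each node pairs an annotation with a form whose
   children are again annotated trees (a stand-in for ann_expr). *)
type 'a ann = Ann of 'a * 'a ann form

(* Analogue of map_expr_form: apply f to each child, rebuild leaves. *)
let map_form ~(f : 'a -> 'b) (e : 'a form) : 'b form =
  match e with
  | Lit n -> Lit n
  | Add (l, r) -> Add (f l, f r)

(* Analogue of annot_expr_fmap: transform every annotation. *)
let rec fmap ~(f : 'a -> 'b) (Ann (a, e) : 'a ann) : 'b ann =
  Ann (f a, map_form ~f:(fmap ~f) e)

(* Analogue of annot_expr_merge: combine annotations of two trees with
   identical structure, or return None if the structures differ. *)
let rec merge (f : 'a -> 'b -> 'c) (Ann (a1, e1) : 'a ann)
    (Ann (a2, e2) : 'b ann) : 'c ann option =
  let merged_form =
    match e1, e2 with
    | Lit n1, Lit n2 -> if n1 = n2 then Some (Lit n1) else None
    | Add (l1, r1), Add (l2, r2) ->
      (match merge f l1 l2, merge f r1 r2 with
       | Some l, Some r -> Some (Add (l, r))
       | _ -> None)
    | (Lit _ | Add _), _ -> None
  in
  match merged_form with
  | Some e -> Some (Ann (f a1 a2, e))
  | None -> None
```

Mapping `String.length` over a string-annotated tree and merging the result back with a pairing function yields each node's original annotation paired with its length; merging two trees of different shapes yields `None`. This is the "forget, then merge back in" workflow that design.txt describes for annotation passes.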
-------------------------------------------------------------------------------- /src/remora-internal/annotation.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Basic_ast 30 | 31 | val annot_elt_merge : 32 | ('a -> 'b -> 'c) 33 | -> 'a ann_elt 34 | -> 'b ann_elt 35 | -> 'c ann_elt option 36 | val annot_expr_merge : 37 | ('a -> 'b -> 'c) 38 | -> 'a ann_expr 39 | -> 'b ann_expr 40 | -> 'c ann_expr option 41 | val annot_defn_merge : 42 | ('a -> 'b -> 'c) 43 | -> 'a ann_defn 44 | -> 'b ann_defn 45 | -> 'c ann_defn option 46 | val annot_prog_merge : 47 | ('a -> 'b -> 'c) 48 | -> 'a ann_prog 49 | -> 'b ann_prog 50 | -> 'c ann_prog option 51 | 52 | val annot_elt_fmap : f:('a -> 'b) -> 'a ann_elt -> 'b ann_elt 53 | val annot_expr_fmap : f:('a -> 'b) -> 'a ann_expr -> 'b ann_expr 54 | val annot_defn_fmap : f:('a -> 'b) -> 'a ann_defn -> 'b ann_defn 55 | val annot_prog_fmap : f:('a -> 'b) -> 'a ann_prog -> 'b ann_prog 56 | -------------------------------------------------------------------------------- /src/remora-internal/basic_ast.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. 
*) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | 31 | (* representation of variables, may later want to add more info *) 32 | type var = string with sexp 33 | 34 | let gensym_counter = ref 0 35 | let gensym_reset () = gensym_counter := 0;; 36 | let gensym (v: string) : var = 37 | let id_number = !gensym_counter in 38 | let new_name = String.concat [v; string_of_int id_number] 39 | and _ = gensym_counter := 1 + id_number 40 | in new_name;; 41 | 42 | (* Remora index sorts *) 43 | type srt = SNat | SShape with sexp 44 | 45 | (* Remora indices *) 46 | type idx = 47 | | INat of int 48 | | IShape of idx list 49 | | ISum of idx * idx 50 | | IVar of var * srt option 51 | with sexp 52 | 53 | (* Shorthand for constructing index var nodes *) 54 | let ivar n = IVar (n, None) 55 | let nvar n = IVar (n, Some SNat) 56 | let svar n = IVar (n, Some SShape) 57 | 58 | (* Remora types *) 59 | type typ = 60 | | TFloat 61 | | TInt 62 | | TBool 63 | | TDProd of (var * srt) list * typ 64 | | TDSum of (var * srt) list * typ 65 | | TFun of typ list * typ 66 | | TArray of idx * typ 67 | | TProd of typ list 68 | | TAll of var list * typ 69 | | TVar of var 70 | with sexp 71 | 72 | (* This is a two-way mutually recursive version of a trick described at 73 | http://lambda-the-ultimate.org/node/4170#comment-63836 74 | 75 | First, the syntax tree node type is parameterized over what a subnode 76 | (i.e., subexpression) looks like. Then, we create an explicitly recursive 77 | type for the syntax tree, where a subexpression is a tree node and an 78 | annotation (again, polymorphic in the annotation's type). This allows the 79 | structure to be built with `typ option' annotations at first, and then 80 | annotated later with (non-option) `typ', optional frame shape, etc.
*) 81 | (* General shape of a Remora expression *) 82 | type ('self_t, 'elt_t) expr_form = 83 | | App of 'self_t * 'self_t list 84 | | TApp of 'self_t * typ list 85 | | TLam of var list * 'self_t 86 | | IApp of 'self_t * idx list 87 | | ILam of (var * srt) list * 'self_t 88 | | Let of var * 'self_t * 'self_t 89 | | Arr of int list * 'elt_t list 90 | | Tuple of 'self_t list 91 | | Field of int * 'self_t 92 | | LetTup of var list * 'self_t * 'self_t 93 | | Var of var 94 | (* Unfortunately, this typ becomes redundant once the AST is type-annotated. *) 95 | | Pack of idx list * 'self_t * typ 96 | | Unpack of var list * var * 'self_t * 'self_t 97 | (* General shape of a Remora array element *) 98 | and ('self_t, 'expr_t) elt_form = 99 | | Float of float 100 | | Int of int 101 | | Bool of bool 102 | | Lam of (var * typ) list * 'expr_t 103 | | Expr of 'expr_t 104 | with sexp 105 | 106 | let map_expr_form 107 | ~(f_expr: 'old_self_t -> 'new_self_t) 108 | ~(f_elt: 'old_elt_t -> 'new_elt_t) 109 | (e: ('old_self_t, 'old_elt_t) expr_form) 110 | : ('new_self_t, 'new_elt_t) expr_form = 111 | match e with 112 | | App (fn, args) -> App (f_expr fn, 113 | List.map ~f:f_expr args) 114 | | TApp (fn, t_args) -> TApp (f_expr fn, t_args) 115 | | TLam (t_vars, body) -> TLam (t_vars, f_expr body) 116 | | IApp (fn, i_args) -> IApp (f_expr fn, i_args) 117 | | ILam (i_vars, body) -> ILam (i_vars, f_expr body) 118 | | Let (var, bound, body) -> Let (var, f_expr bound, f_expr body) 119 | | Arr (dims, elts) -> Arr (dims, List.map ~f:f_elt elts) 120 | | Tuple elts -> Tuple (List.map ~f:f_expr elts) 121 | | Field (num, tup) -> Field (num, f_expr tup) 122 | | LetTup (vars, tup, body) -> LetTup (vars, f_expr tup, f_expr body) 123 | | Var _ as v -> v 124 | | Pack (idxs, v, t) -> Pack (idxs, f_expr v, t) 125 | | Unpack (ivars, v, dsum, body) -> Unpack (ivars, v, f_expr dsum, f_expr body) 126 | 127 | let map_elt_form 128 | ~(f_expr: 'old_expr_t -> 'new_expr_t) 129 | (* ~(f_elt: 'old_self_t -> 
'new_self_t) *) 130 | (l: ('old_self_t, 'old_expr_t) elt_form) 131 | : ('new_self_t, 'new_expr_t) elt_form = 132 | match l with 133 | | Float _ as f -> f 134 | | Int _ as i -> i 135 | | Bool _ as b -> b 136 | | Lam (vars, body) -> Lam (vars, f_expr body) 137 | | Expr e -> Expr (f_expr e) 138 | 139 | (* Remora terms with no extra annotation field *) 140 | type rem_expr = 141 | | RExpr of (rem_expr, rem_elt) expr_form 142 | and rem_elt = 143 | | RElt of (rem_elt, rem_expr) elt_form 144 | with sexp 145 | type rem_defn = RDefn of var * typ * rem_expr with sexp 146 | type rem_prog = RProg of rem_defn list * rem_expr with sexp 147 | 148 | (* Annotated Remora expression (parameterized over annotation type) *) 149 | type 'annot ann_expr = 150 | | AnnRExpr of 'annot * (('annot ann_expr), ('annot ann_elt)) expr_form 151 | and 'annot ann_elt = 152 | | AnnRElt of 'annot * (('annot ann_elt), ('annot ann_expr)) elt_form 153 | with sexp 154 | type 'annot ann_defn = AnnRDefn of var * typ * 'annot ann_expr with sexp 155 | type 'annot ann_prog = 156 | | AnnRProg of 'annot * 'annot ann_defn list * 'annot ann_expr 157 | with sexp 158 | 159 | 160 | (* Fully type-annotated Remora terms *) 161 | type t_expr = typ ann_expr with sexp 162 | type t_elt = typ ann_elt with sexp 163 | type t_defn = typ ann_defn with sexp 164 | type t_prog = typ ann_prog with sexp 165 | (* Partially type-annotated Remora terms *) 166 | type pt_expr = typ option ann_expr with sexp 167 | type pt_elt = typ option ann_elt with sexp 168 | type pt_defn = typ option ann_defn with sexp 169 | type pt_prog = typ option ann_prog with sexp 170 | 171 | 172 | 173 | (* For example, 174 | AnnRExpr ((TArray (IShape [INat 2], TInt)), 175 | (Arr ([2], [AnnRElt (TInt, Int 3); 176 | AnnRElt (TInt, Int 2)])));; 177 | is the type-annotated version of the 2-vector [3, 2]. 178 | 179 | With blank annotations (i.e.
all annotations are ()), 180 | AnnRExpr ((), (Arr ([2], [AnnRElt ((), Int 3); AnnRElt ((), Int 2)]))) 181 | 182 | With no annotations, 183 | RExpr (Arr ([2], [RElt (Int 3); RElt (Int 2)])) 184 | *) 185 | 186 | (* Set up a designated "blank" annotation at every AST node *) 187 | let rec annot_expr_init ~(init: 'a) (expr: rem_expr) : 'a ann_expr = 188 | match expr with RExpr node -> 189 | AnnRExpr (init, 190 | map_expr_form 191 | ~f_expr:(annot_expr_init ~init:init) 192 | ~f_elt:(annot_elt_init ~init:init) 193 | node) 194 | and annot_elt_init ~(init: 'a) (elt: rem_elt) : 'a ann_elt = 195 | match elt with RElt node -> 196 | AnnRElt (init, 197 | map_elt_form 198 | ~f_expr:(annot_expr_init ~init:init) 199 | node) 200 | ;; 201 | 202 | (* Set up "blank" annotations in a definition *) 203 | let annot_defn_init ~(init: 'a) (defn: rem_defn) : 'a ann_defn = 204 | let RDefn (n, t, v) = defn 205 | in AnnRDefn(n, t, annot_expr_init ~init:init v) 206 | 207 | (* Set up "blank" annotations in a program *) 208 | let annot_prog_init ~(init: 'a) (prog: rem_prog) : 'a ann_prog = 209 | let RProg (defns, expr) = prog in 210 | AnnRProg (init, 211 | List.map ~f:(annot_defn_init ~init:init) defns, 212 | annot_expr_init ~init:init expr) 213 | 214 | (* Extract the non-annotated version of an AST node *) 215 | let rec annot_expr_drop (expr: 'a ann_expr) : rem_expr = 216 | match expr with AnnRExpr (_, node) -> 217 | RExpr (map_expr_form ~f_expr:annot_expr_drop ~f_elt:annot_elt_drop node) 218 | and annot_elt_drop (elt: 'a ann_elt) : rem_elt = 219 | match elt with AnnRElt (_, node) -> 220 | RElt (map_elt_form ~f_expr:annot_expr_drop node) 221 | ;; 222 | 223 | (* Extract non-annotated version of a definition *) 224 | let annot_defn_drop (defn: 'a ann_defn) : rem_defn = 225 | let AnnRDefn (n, t, v) = defn in 226 | RDefn (n, t, annot_expr_drop v) 227 | 228 | (* Extract non-annotated version of a program *) 229 | let annot_prog_drop (prog: 'a ann_prog) : rem_prog = 230 | let AnnRProg (_, defns, expr) 
= prog in 231 | RProg (List.map ~f:annot_defn_drop defns, 232 | annot_expr_drop expr) 233 | 234 | let annot_of_expr ((AnnRExpr (annot, _)): 'a ann_expr) : 'a = annot 235 | let annot_of_elt ((AnnRElt (annot, _)): 'a ann_elt) : 'a = annot 236 | let annot_of_defn ((AnnRDefn (_, _, AnnRExpr (annot, _))): 'a ann_defn) : 'a 237 | = annot 238 | let annot_of_prog ((AnnRProg (annot, _, _)): 'a ann_prog) : 'a = annot 239 | 240 | (* Collect the passes which are essential to compilation. *) 241 | module Passes : sig 242 | val prog : rem_prog -> unit ann_prog 243 | val defn : rem_defn -> unit ann_defn 244 | val expr : rem_expr -> unit ann_expr 245 | val elt : rem_elt -> unit ann_elt 246 | 247 | val prog_all : rem_prog -> unit ann_prog 248 | val defn_all : rem_defn -> unit ann_defn 249 | val expr_all : rem_expr -> unit ann_expr 250 | val elt_all : rem_elt -> unit ann_elt 251 | end = struct 252 | let prog = annot_prog_init ~init:() 253 | let prog_all = prog 254 | 255 | let defn = annot_defn_init ~init:() 256 | let defn_all = defn 257 | 258 | let expr = annot_expr_init ~init:() 259 | let expr_all = expr 260 | 261 | let elt = annot_elt_init ~init:() 262 | let elt_all = elt 263 | end 264 | -------------------------------------------------------------------------------- /src/remora-internal/basic_ast.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. 
*) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | 31 | type var = string with sexp 32 | 33 | val gensym_reset : unit -> unit 34 | val gensym : string -> var 35 | 36 | type srt = SNat | SShape with sexp 37 | 38 | type idx = 39 | | INat of int 40 | | IShape of idx list 41 | | ISum of idx * idx 42 | | IVar of var * srt option 43 | with sexp 44 | 45 | val ivar : var -> idx 46 | val nvar : var -> idx 47 | val svar : var -> idx 48 | 49 | type typ = 50 | | TFloat 51 | | TInt 52 | | TBool 53 | | TDProd of (var * srt) list * typ 54 | | TDSum of (var * srt) list * typ 55 | | TFun of typ list * typ 56 | | TArray of idx * typ 57 | | TProd of typ list 58 | | TAll of var list * typ 59 | | TVar of var 60 | with sexp 61 | 62 | type ('self_t, 'elt_t) expr_form = 63 | | App of 'self_t * 'self_t list 64 | | TApp of 'self_t * typ list 65 | | TLam of var list * 'self_t 66 | | IApp of 'self_t * idx list 67 | | ILam of (var * srt) list * 'self_t 68 | | Let of var * 'self_t * 'self_t 69 | | Arr of int list * 'elt_t list 70 | | Tuple of 'self_t list 71 | | Field of int * 'self_t 72 | | LetTup of var list * 'self_t * 'self_t 73 | | Var of var 74 | | Pack of idx list * 'self_t * typ 75 | | Unpack of var list * var * 'self_t * 'self_t 76 | and ('self_t, 'expr_t) elt_form = 77 | | Float of float 78 | | Int of int 79 | | Bool of bool 80 | | Lam of (var * typ) list * 'expr_t 81 | | Expr of 'expr_t 82 | with sexp 83 | 84 | val map_expr_form : 85 | f_expr:('old_self_t -> 'new_self_t) 86 | -> f_elt:('old_elt_t -> 'new_elt_t) 87 | -> ('old_self_t, 'old_elt_t) expr_form 88 | -> ('new_self_t, 'new_elt_t) expr_form 89 | val map_elt_form : 90 | f_expr:('old_expr_t -> 'new_expr_t) 91 | -> ('old_self_t, 'old_expr_t) elt_form 92 | -> ('new_self_t, 'new_expr_t) elt_form 93 | 94 | type 'annot ann_expr = 95 | AnnRExpr of 'annot * ('annot ann_expr, 'annot ann_elt) expr_form 96 | and 'annot ann_elt = 97 | AnnRElt of 'annot * ('annot ann_elt,
'annot ann_expr) elt_form 98 | with sexp 99 | type 'annot ann_defn = AnnRDefn of var * typ * 'annot ann_expr with sexp 100 | type 'annot ann_prog = 101 | | AnnRProg of 'annot * 'annot ann_defn list * 'annot ann_expr 102 | with sexp 103 | 104 | type t_expr = typ ann_expr with sexp 105 | type pt_expr = typ option ann_expr with sexp 106 | 107 | type t_elt = typ ann_elt with sexp 108 | type pt_elt = typ option ann_elt with sexp 109 | 110 | type t_defn = typ ann_defn with sexp 111 | type pt_defn = typ option ann_defn with sexp 112 | 113 | type t_prog = typ ann_prog with sexp 114 | type pt_prog = typ option ann_prog with sexp 115 | 116 | type rem_expr = RExpr of (rem_expr, rem_elt) expr_form 117 | and rem_elt = RElt of (rem_elt, rem_expr) elt_form 118 | with sexp 119 | 120 | type rem_defn = RDefn of var * typ * rem_expr with sexp 121 | 122 | type rem_prog = RProg of rem_defn list * rem_expr with sexp 123 | 124 | val annot_expr_init : init:'a -> rem_expr -> 'a ann_expr 125 | val annot_elt_init : init:'a -> rem_elt -> 'a ann_elt 126 | val annot_defn_init : init:'a -> rem_defn -> 'a ann_defn 127 | val annot_prog_init : init:'a -> rem_prog -> 'a ann_prog 128 | val annot_expr_drop : 'a ann_expr -> rem_expr 129 | val annot_elt_drop : 'a ann_elt -> rem_elt 130 | val annot_defn_drop : 'a ann_defn -> rem_defn 131 | val annot_prog_drop : 'a ann_prog -> rem_prog 132 | 133 | val annot_of_expr : 'a ann_expr -> 'a 134 | val annot_of_elt : 'a ann_elt -> 'a 135 | val annot_of_defn : 'a ann_defn -> 'a 136 | val annot_of_prog : 'a ann_prog -> 'a 137 | 138 | module Passes : sig 139 | val prog : rem_prog -> unit ann_prog 140 | val defn : rem_defn -> unit ann_defn 141 | val expr : rem_expr -> unit ann_expr 142 | val elt : rem_elt -> unit ann_elt 143 | 144 | val prog_all : rem_prog -> unit ann_prog 145 | val defn_all : rem_defn -> unit ann_defn 146 | val expr_all : rem_expr -> unit ann_expr 147 | val elt_all : rem_elt -> unit ann_elt 148 | end 149 | 
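The delayed-recursion trick used in basic_ast.ml can be seen in miniature in the sketch below. It is a hypothetical, single-sort simplification: the real module ties two mutually recursive forms (expressions and elements), and the names `form`, `plain`, `ann`, `annot_init`, `annot_drop`, `annot_fmap` and `annot_of` are illustrative stand-ins for `expr_form`, `rem_expr`, `ann_expr`, `annot_expr_init`, `annot_expr_drop`, `annot_expr_fmap` and `annot_of_expr`. It uses only the standard library rather than Core.Std and `with sexp`.

```ocaml
(* Hypothetical mini-AST: node shape with recursive positions abstracted as 'self_t. *)
type 'self_t form =
  | Int of int
  | Add of 'self_t * 'self_t
  | Vec of 'self_t list

(* Apply f to each immediate child; the only function that knows node shapes. *)
let map_form ~(f : 'a -> 'b) (e : 'a form) : 'b form =
  match e with
  | Int n -> Int n
  | Add (x, y) -> Add (f x, f y)
  | Vec xs -> Vec (List.map f xs)

(* Tie the knot with no annotation (cf. rem_expr)... *)
type plain = Plain of plain form

(* ...or with an 'annot value at every node (cf. 'annot ann_expr). *)
type 'annot ann = Ann of 'annot * 'annot ann form

(* Put the same blank annotation on every node (cf. annot_expr_init). *)
let rec annot_init ~(blank : 'a) (Plain node : plain) : 'a ann =
  Ann (blank, map_form ~f:(annot_init ~blank) node)

(* Forget all annotations (cf. annot_expr_drop). *)
let rec annot_drop (Ann (_, node) : 'a ann) : plain =
  Plain (map_form ~f:annot_drop node)

(* Transform every annotation, keeping the tree shape (cf. annot_expr_fmap). *)
let rec annot_fmap ~(f : 'a -> 'b) (Ann (a, node) : 'a ann) : 'b ann =
  Ann (f a, map_form ~f:(annot_fmap ~f) node)

(* Read the annotation at the root (cf. annot_of_expr). *)
let annot_of (Ann (a, _) : 'a ann) : 'a = a

(* The tree 3 + [2, 1], with no annotations. *)
let example =
  Plain (Add (Plain (Int 3), Plain (Vec [Plain (Int 2); Plain (Int 1)])))
```

Because `map_form` is the only function that inspects node shapes, adding a syntactic form means updating it once; initialization, dropping and re-annotation are all written generically on top of it, which mirrors how `map_expr_form` and `map_elt_form` serve the passes in basic_ast.ml.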
-------------------------------------------------------------------------------- /src/remora-internal/closures.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | module MR = Map_replicate_ast;; 31 | module B = Basic_ast;; 32 | module E = Erased_ast;; 33 | open Frame_notes 34 | 35 | type var = Basic_ast.var with sexp 36 | 37 | type 'a cl_app_t = {closure: 'a; args: 'a list;} with sexp 38 | type 'a closure_t = {code: 'a; env: 'a;} with sexp 39 | 40 | type 'a expr_form = 41 | | App of 'a cl_app_t 42 | | Vec of 'a MR.vec_t 43 | | Map of 'a MR.map_t 44 | | Rep of 'a MR.rep_t 45 | | Tup of 'a MR.tup_t 46 | | LetTup of 'a MR.lettup_t 47 | | Fld of 'a MR.fld_t 48 | | Let of 'a MR.let_t 49 | | Cls of 'a closure_t 50 | | Lam of 'a MR.lam_t 51 | | Var of var 52 | | Int of int 53 | | Float of float 54 | | Bool of bool 55 | with sexp 56 | 57 | let map_expr_form 58 | ~(f: 'a -> 'b) 59 | (e: 'a expr_form) : 'b expr_form = 60 | match e with 61 | | App {closure = c; args = a} -> App {closure = f c; 62 | args = List.map ~f:f a} 63 | | Vec {MR.dims = d; MR.elts = e} -> Vec {MR.dims = d; 64 | MR.elts = List.map ~f:f e} 65 | | Map {MR.frame = fr; MR.fn = fn; MR.args = a; MR.shp = shp} -> 66 | Map {MR.frame = f fr; 67 | MR.fn = f fn; 68 | MR.args = List.map ~f:f a; 69 | MR.shp = f shp} 70 | | Rep {MR.arg = a; MR.new_frame = n; MR.old_frame = o} -> 71 | Rep {MR.arg = f a; MR.new_frame = f n; MR.old_frame = f o} 72 | | Tup e -> Tup (List.map ~f:f e) 73 
| | LetTup {MR.vars = v; MR.bound = bn; MR.body = bd} -> 74 | LetTup {MR.vars = v; MR.bound = f bn; MR.body = f bd} 75 | | Fld {MR.field = n; MR.tuple = tup} -> 76 | Fld {MR.field = n; MR.tuple = f tup} 77 | | Let {MR.var = v; MR.bound = nd; MR.body = bd} -> 78 | Let {MR.var = v; MR.bound = f nd; MR.body = f bd} 79 | | Cls {code = c; env = a} -> Cls {code = f c; env = f a} 80 | | Lam {MR.bindings = v; MR.body = e} -> 81 | Lam {MR.bindings = v; MR.body = f e} 82 | | Var _ | Int _ | Float _ | Bool _ as v -> v 83 | 84 | type expr = Expr of expr expr_form with sexp 85 | type defn = Defn of var * expr with sexp 86 | type prog = Prog of defn list * expr with sexp 87 | 88 | type 'annot ann_expr = AExpr of 'annot * ('annot ann_expr) expr_form with sexp 89 | type 'annot ann_defn = ADefn of var * 'annot ann_expr with sexp 90 | type 'annot ann_prog = 91 | AProg of 'annot * 'annot ann_defn list * 'annot ann_expr with sexp 92 | 93 | (* Convert an erased-lang type which uses function types to a version that 94 | uses closure types. *) 95 | let rec convert_type = function 96 | | E.TFun (i, o) | E.TCls (i, o) -> 97 | E.TCls (List.map ~f:convert_type i, convert_type o) 98 | | E.TDProd (bind, body) -> E.TDProd (bind, convert_type body) 99 | | E.TDSum (bind, body) -> E.TDSum (bind, convert_type body) 100 | | E.TArray (shp, elt) -> E.TArray (shp, convert_type elt) 101 | | E.TTuple elts -> E.TTuple (List.map ~f:convert_type elts) 102 | | E.TUnknown | E.TShape | E.TInt | E.TFloat | E.TBool | E.TVar as t -> t 103 | 104 | 105 | 106 | (* Closure-convert a MapRep AST. *) 107 | let rec expr_of_maprep 108 | (bound_vars: var list) 109 | (MR.AExpr ((typ, arg, app), e): 110 | (E.typ * arg_frame * app_frame) MR.ann_expr) 111 | : (E.typ * arg_frame * app_frame) ann_expr = 112 | (* Any term that stood for a function in the Map/Replicate IR now stands for 113 | a closure. 
*) 114 | let new_note = (convert_type typ, arg, app) in 115 | match e with 116 | | MR.App {MR.fn = f; MR.args = a} -> 117 | AExpr (new_note, App {closure = expr_of_maprep bound_vars f; 118 | args = List.map ~f:(expr_of_maprep bound_vars) a}) 119 | | MR.Vec {MR.dims = d; MR.elts = e} -> 120 | AExpr (new_note, Vec {MR.dims = d; 121 | MR.elts = List.map ~f:(expr_of_maprep bound_vars) e}) 122 | | MR.Map {MR.frame = fr; MR.fn = fn; MR.args = a; MR.shp = s} -> 123 | AExpr (new_note, Map {MR.frame = expr_of_maprep bound_vars fr; 124 | MR.fn = expr_of_maprep bound_vars fn; 125 | MR.args = List.map ~f:(expr_of_maprep bound_vars) a; 126 | MR.shp = expr_of_maprep bound_vars s}) 127 | | MR.Rep {MR.arg = a; MR.new_frame = n; MR.old_frame = o} -> 128 | AExpr (new_note, Rep {MR.arg = expr_of_maprep bound_vars a; 129 | MR.new_frame = expr_of_maprep bound_vars n; 130 | MR.old_frame = expr_of_maprep bound_vars o}) 131 | | MR.Tup e -> AExpr (new_note, 132 | Tup (List.map ~f:(expr_of_maprep bound_vars) e)) 133 | | MR.LetTup {MR.vars = v; MR.bound = bn; MR.body = bd} -> 134 | AExpr (new_note, LetTup {MR.vars = v; 135 | MR.bound = expr_of_maprep bound_vars bn; 136 | MR.body = expr_of_maprep bound_vars bd}) 137 | | MR.Fld {MR.field = n; MR.tuple = tup} -> 138 | AExpr (new_note, 139 | Fld {MR.field = n; MR.tuple = expr_of_maprep bound_vars tup}) 140 | | MR.Let {MR.var = v; MR.bound = bn; MR.body = bd} -> 141 | AExpr (new_note, Let {MR.var = v; 142 | MR.bound = expr_of_maprep bound_vars bn; 143 | MR.body = expr_of_maprep bound_vars bd}) 144 | | MR.Lam {MR.bindings = v; MR.body = b} -> 145 | let env_name = Basic_ast.gensym "__ENV_" 146 | (* Exclude variables bound by this lambda from the resulting free 147 | variable list. *) 148 | and bound_for_body = List.append v bound_vars in 149 | (* Need the list of free vars and their types in order to construct 150 | the type of the environment. 
*) 151 | let typed_free_vars = 152 | List.map ~f:(fun (MR.AExpr ((t,_,_), e)) -> 153 | match e with | MR.Var n -> Some (t, n) | _ -> None) 154 | (MR.get_annotated_free_vars bound_for_body b) |> 155 | List.filter_opt in 156 | let free_vars = List.map ~f:snd typed_free_vars 157 | and fv_types = List.map ~f:fst typed_free_vars in 158 | (* Figure out the type of the env component. *) 159 | let env_typ = E.TTuple fv_types in 160 | (* Pick apart the function type so it can be built up with one 161 | extra arg type (for the env) and the output type can be used for 162 | the new Lam's body. *) 163 | let (out_typ, code_typ) = (match typ with 164 | | E.TFun (i, o) -> (o, E.TFun (env_typ :: i, o)) 165 | | _ -> 166 | print_string 167 | "CConv Warning: generated Lam with non-TFun type annotation\n"; 168 | (E.TUnknown, typ)) in 169 | AExpr (new_note, 170 | Cls {code = AExpr 171 | ((code_typ, arg, app), 172 | Lam {MR.bindings = env_name :: v; 173 | MR.body = AExpr 174 | ((out_typ, arg, app), 175 | LetTup {MR.vars = free_vars; 176 | MR.bound = AExpr ((env_typ, arg, app), 177 | Var env_name); 178 | MR.body = expr_of_maprep bound_vars b})}); 179 | env = AExpr ((env_typ, arg, app), 180 | Tup (List.map ~f:(fun (t,v) -> 181 | AExpr ((t, NotArg, NotApp), Var v)) 182 | typed_free_vars))}) 183 | | MR.Var v -> AExpr (new_note, Var v) 184 | | MR.Int i -> AExpr (new_note, Int i) 185 | | MR.Float f -> AExpr (new_note, Float f) 186 | | MR.Bool b -> AExpr (new_note, Bool b) 187 | let defn_of_maprep 188 | (bound_vars: var list) 189 | (MR.ADefn (name, body): 'a MR.ann_defn) 190 | : 'a ann_defn = 191 | (* We include the defn-bound name just in case it wasn't passed in. 
*) 192 | ADefn (name, expr_of_maprep (name :: bound_vars) body) 193 | (* Can optionally pass in a list of built-in names *) 194 | let prog_of_maprep 195 | ?(bound_vars = []) 196 | (MR.AProg (a, defns, expr): 'a MR.ann_prog) 197 | : 'a ann_prog = 198 | let top_level_names = 199 | List.append 200 | (List.map ~f:(fun (MR.ADefn (n, _)) -> n) defns) 201 | bound_vars in 202 | AProg (a, List.map ~f:(defn_of_maprep top_level_names) defns, 203 | expr_of_maprep top_level_names expr) 204 | 205 | let rec annot_expr_drop (AExpr (_, e): 'a ann_expr) : expr = 206 | Expr (map_expr_form ~f:annot_expr_drop e) 207 | let annot_defn_drop (ADefn (name, body): 'a ann_defn) : defn = 208 | Defn (name, annot_expr_drop body) 209 | let annot_prog_drop (AProg (_, defns, expr): 'a ann_prog) : prog = 210 | Prog (List.map ~f:annot_defn_drop defns, annot_expr_drop expr) 211 | let rec annot_expr_map 212 | ~(f: 'a -> 'b) (AExpr (a, e): 'a ann_expr) 213 | : 'b ann_expr = 214 | AExpr (f a, map_expr_form ~f:(annot_expr_map ~f:f) e) 215 | let annot_defn_map 216 | ~(f: 'a -> 'b) 217 | (ADefn (name, body): 'a ann_defn) 218 | : 'b ann_defn = 219 | ADefn (name, annot_expr_map ~f:f body) 220 | let annot_prog_map 221 | ~(f: 'a -> 'b) 222 | (AProg (a, defns, expr): 'a ann_prog) 223 | : 'b ann_prog = 224 | AProg (f a, 225 | List.map ~f:(annot_defn_map ~f:f) defns, 226 | annot_expr_map ~f:f expr) 227 | 228 | (* A computed value of type 'v, along with a list of definitions accumulated 229 | in the process of computing it, and a monadic interface for working with 230 | these structures. 231 | TODO: May want to use something other than a list for accumulating 232 | definitions if appending gets too slow. 
*) 233 | module Defn_writer : sig 234 | type ('v, 'a) t = 'v * 'a ann_defn list 235 | val (>>=) : ('v, 'a) t -> ('v -> ('w, 'a) t) -> ('w, 'a) t 236 | val (>>|) : ('v, 'a) t -> ('v -> 'w) -> ('w, 'a) t 237 | val (>>) : ('v, 'a) t -> ('w, 'a) t -> ('w, 'a) t 238 | val bind : ('v, 'a) t -> ('v -> ('w, 'a) t) -> ('w, 'a) t 239 | val return : 'v -> ('v, 'a) t 240 | val map : ('v, 'a) t -> f:('v -> 'w) -> ('w, 'a) t 241 | val join : (('v, 'a) t, 'a) t -> ('v, 'a) t 242 | val all : ('v, 'a) t list -> ('v list, 'a) t 243 | val tell : 'a ann_defn list -> (unit, 'a) t 244 | end = struct 245 | type ('v, 'a) t = 'v * 'a ann_defn list 246 | let (>>=) 247 | ((val_in, defns_in): ('v, 'a) t) 248 | (f: 'v -> ('w, 'a) t) : ('w, 'a) t = 249 | let (val_ret, defns_ret) = f val_in in 250 | (val_ret, List.append defns_in defns_ret) 251 | let (>>|) 252 | ((val_in, defns_in): ('v, 'a) t) 253 | (f: 'v -> 'w) : ('w, 'a) t = 254 | (f val_in, defns_in) 255 | let (>>) x y = x >>= (fun _ -> y) 256 | let bind v f = v >>= f 257 | let return v = (v, []) 258 | let map t ~f = t >>| f 259 | let join t = t >>= (fun t' -> t') 260 | let all ts = (List.map ~f:fst ts, List.join (List.map ~f:snd ts)) 261 | let tell new_defns = ((), new_defns) 262 | end 263 | 264 | (* Traverse an expression, replacing Lam forms with fresh Var forms and 265 | generating a list of (annotated) definitions for those variables. *) 266 | let rec expr_hoist_lambdas 267 | ((AExpr (a, e): 'a ann_expr) as expr) 268 | : ('a ann_expr, 'a) Defn_writer.t = 269 | let open Defn_writer in 270 | (* In almost all cases, we just recur on all subexpressions and merge 271 | their results together. 
*) 272 | match e with 273 | | App {closure = clos; args = args} -> 274 | expr_hoist_lambdas clos >>= fun new_clos -> 275 | List.map ~f:expr_hoist_lambdas args |> all >>= fun new_args -> 276 | AExpr (a, App {closure = new_clos; args = new_args}) |> return 277 | | Vec {MR.dims = dims; MR.elts = elts} -> 278 | List.map ~f:expr_hoist_lambdas elts |> all >>= fun new_elts -> 279 | AExpr (a, Vec {MR.dims = dims; MR.elts = new_elts}) |> return 280 | | Map {MR.frame = frame; MR.fn = fn; MR.args = args; MR.shp = shp} -> 281 | expr_hoist_lambdas frame >>= fun new_frame -> 282 | expr_hoist_lambdas fn >>= fun new_fn -> 283 | List.map ~f:expr_hoist_lambdas args |> all >>= fun new_args -> 284 | expr_hoist_lambdas shp >>= fun new_shp -> 285 | AExpr (a, Map {MR.frame = new_frame; 286 | MR.fn = new_fn; 287 | MR.args = new_args; 288 | MR.shp = new_shp}) |> return 289 | | Rep {MR.arg = arg; MR.old_frame = oldf; MR.new_frame = newf} -> 290 | expr_hoist_lambdas arg >>= fun new_arg -> 291 | expr_hoist_lambdas oldf >>= fun new_oldf -> 292 | expr_hoist_lambdas newf >>= fun new_newf -> 293 | AExpr (a, Rep {MR.arg = new_arg; 294 | MR.old_frame = new_oldf; 295 | MR.new_frame = new_newf}) |> return 296 | | Tup elts -> 297 | List.map ~f:expr_hoist_lambdas elts |> all >>= fun new_elts -> 298 | AExpr (a, Tup new_elts) |> return 299 | | LetTup {MR.vars = vars; MR.bound = bound; MR.body = body} -> 300 | expr_hoist_lambdas bound >>= fun new_bound -> 301 | expr_hoist_lambdas body >>= fun new_body -> 302 | AExpr (a, LetTup {MR.vars = vars; 303 | MR.bound = new_bound; 304 | MR.body = new_body}) |> return 305 | | Fld {MR.field = n; MR.tuple = tup} -> 306 | expr_hoist_lambdas tup >>= fun new_tup -> 307 | AExpr (a, Fld {MR.field = n; MR.tuple = new_tup}) |> return 308 | | Let {MR.var = v; MR.bound = bound; MR.body = body} -> 309 | expr_hoist_lambdas bound >>= fun new_bound -> 310 | expr_hoist_lambdas body >>= fun new_body -> 311 | AExpr (a, Let {MR.var = v; 312 | MR.bound = new_bound; 313 | MR.body = 
new_body}) |> return 314 | | Cls {code = code; env = env} -> 315 | expr_hoist_lambdas code >>= fun new_code -> 316 | expr_hoist_lambdas env >>= fun new_env -> 317 | AExpr (a, Cls {code = new_code; env = new_env}) |> return 318 | (* This is the only interesting case. *) 319 | | Lam {MR.bindings = vars; MR.body = body} -> 320 | (* Generate a new name to use as a global variable. *) 321 | let new_global = B.gensym "__HOIST_" in 322 | (* Hoist any lambdas that appear within the function body. *) 323 | expr_hoist_lambdas body >>= fun new_body -> 324 | (* Emit a new definition for this function, with the converted body. *) 325 | tell [ADefn (new_global, AExpr (a, Lam {MR.bindings = vars; 326 | MR.body = new_body}))] >> 327 | (* Replace this function with the global variable. *) 328 | (AExpr (a, Var new_global) |> return) 329 | | Var _ | Int _ | Float _ | Bool _ -> return expr 330 | 331 | let defn_hoist_lambdas 332 | (ADefn (n, e) as d) : 'a ann_defn list = 333 | match e with 334 | | AExpr (_, Lam _) -> [d] 335 | | _ -> let (new_body, new_defns) = expr_hoist_lambdas e in 336 | ADefn (n, new_body) :: new_defns 337 | 338 | let prog_hoist_lambdas 339 | (AProg (annot, defns, expr)) : 'a ann_prog = 340 | let (new_expr, expr_defns) = expr_hoist_lambdas expr 341 | and new_defns = List.join (List.map ~f:defn_hoist_lambdas defns) in 342 | AProg (annot, List.append new_defns expr_defns, new_expr) 343 | 344 | module Passes : sig 345 | val prog : (E.typ * arg_frame * app_frame) MR.ann_prog 346 | -> (E.typ * arg_frame * app_frame) ann_prog 347 | val defn : (E.typ * arg_frame * app_frame) MR.ann_defn 348 | -> (E.typ * arg_frame * app_frame) ann_defn 349 | val expr : (E.typ * arg_frame * app_frame) MR.ann_expr 350 | -> (E.typ * arg_frame * app_frame) ann_expr 351 | 352 | val prog_all : B.rem_prog -> (E.typ * arg_frame * app_frame) ann_prog option 353 | val defn_all : B.rem_defn -> (E.typ * arg_frame * app_frame) ann_defn option 354 | val expr_all : B.rem_expr -> (E.typ * arg_frame * 
app_frame) ann_expr option 355 | val elt_all : B.rem_elt -> (E.typ * arg_frame * app_frame) ann_expr option 356 | end = struct 357 | let lib_vars = [] 358 | let prog remora = remora 359 | |> prog_of_maprep ~bound_vars:lib_vars 360 | |> prog_hoist_lambdas 361 | let defn remora = remora |> defn_of_maprep lib_vars 362 | let expr remora = remora |> expr_of_maprep lib_vars 363 | open Option.Monad_infix 364 | let prog_all remora = remora |> MR.Passes.prog_all >>| prog 365 | let defn_all remora = remora |> MR.Passes.defn_all >>| defn 366 | let expr_all remora = remora |> MR.Passes.expr_all >>| expr 367 | let elt_all remora = remora |> MR.Passes.elt_all >>| expr 368 | end 369 | -------------------------------------------------------------------------------- /src/remora-internal/closures.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. 
*) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | module MR = Map_replicate_ast;; 30 | module B = Basic_ast;; 31 | module E = Erased_ast;; 32 | open Frame_notes 33 | 34 | type var = Basic_ast.var with sexp 35 | 36 | type 'a cl_app_t = {closure: 'a; args: 'a list;} with sexp 37 | type 'a closure_t = {code: 'a; env: 'a} with sexp 38 | 39 | type 'a expr_form = 40 | | App of 'a cl_app_t 41 | | Vec of 'a MR.vec_t 42 | | Map of 'a MR.map_t 43 | | Rep of 'a MR.rep_t 44 | | Tup of 'a MR.tup_t 45 | | LetTup of 'a MR.lettup_t 46 | | Fld of 'a MR.fld_t 47 | | Let of 'a MR.let_t 48 | | Cls of 'a closure_t 49 | | Lam of 'a MR.lam_t 50 | | Var of var 51 | | Int of int 52 | | Float of float 53 | | Bool of bool 54 | with sexp 55 | 56 | val map_expr_form : f:('a -> 'b) -> 'a expr_form -> 'b expr_form 57 | 58 | type expr = Expr of expr expr_form with sexp 59 | type defn = Defn of var * expr with sexp 60 | type prog = Prog of defn list * expr with sexp 61 | 62 | type 'annot ann_expr = AExpr of 'annot * ('annot ann_expr) expr_form with sexp 63 | type 'annot ann_defn = ADefn of var 
* 'annot ann_expr with sexp 64 | type 'annot ann_prog = 65 | AProg of 'annot * 'annot ann_defn list * 'annot ann_expr with sexp 66 | 67 | val expr_of_maprep : 68 | var list 69 | -> (E.typ * arg_frame * app_frame) MR.ann_expr 70 | -> (E.typ * arg_frame * app_frame) ann_expr 71 | 72 | val annot_expr_drop : 'a ann_expr -> expr 73 | val annot_defn_drop : 'a ann_defn -> defn 74 | val annot_prog_drop : 'a ann_prog -> prog 75 | 76 | val annot_expr_map : f:('a -> 'b) -> 'a ann_expr -> 'b ann_expr 77 | val annot_defn_map : f:('a -> 'b) -> 'a ann_defn -> 'b ann_defn 78 | val annot_prog_map : f:('a -> 'b) -> 'a ann_prog -> 'b ann_prog 79 | 80 | module Defn_writer : sig 81 | type ('v, 'a) t = 'v * 'a ann_defn list 82 | val (>>=) : ('v, 'a) t -> ('v -> ('w, 'a) t) -> ('w, 'a) t 83 | val (>>|) : ('v, 'a) t -> ('v -> 'w) -> ('w, 'a) t 84 | val (>>) : ('v, 'a) t -> ('w, 'a) t -> ('w, 'a) t 85 | val bind : ('v, 'a) t -> ('v -> ('w, 'a) t) -> ('w, 'a) t 86 | val return : 'v -> ('v, 'a) t 87 | val map : ('v, 'a) t -> f:('v -> 'w) -> ('w, 'a) t 88 | val join : (('v, 'a) t, 'a) t -> ('v, 'a) t 89 | val all : ('v, 'a) t list -> ('v list, 'a) t 90 | val tell : 'a ann_defn list -> (unit, 'a) t 91 | end 92 | 93 | val expr_hoist_lambdas : 'a ann_expr -> ('a ann_expr, 'a) Defn_writer.t 94 | 95 | module Passes : sig 96 | val prog : (E.typ * arg_frame * app_frame) MR.ann_prog 97 | -> (E.typ * arg_frame * app_frame) ann_prog 98 | val defn : (E.typ * arg_frame * app_frame) MR.ann_defn 99 | -> (E.typ * arg_frame * app_frame) ann_defn 100 | val expr : (E.typ * arg_frame * app_frame) MR.ann_expr 101 | -> (E.typ * arg_frame * app_frame) ann_expr 102 | 103 | val prog_all : B.rem_prog -> (E.typ * arg_frame * app_frame) ann_prog option 104 | val defn_all : B.rem_defn -> (E.typ * arg_frame * app_frame) ann_defn option 105 | val expr_all : B.rem_expr -> (E.typ * arg_frame * app_frame) ann_expr option 106 | val elt_all : B.rem_elt -> (E.typ * arg_frame * app_frame) ann_expr option 107 | end 108 | 
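The `Defn_writer` signature above describes a writer monad: a computation yields a value paired with the list of definitions it has emitted, `>>=` concatenates those lists in order, and `tell` emits definitions without producing a meaningful value. `expr_hoist_lambdas` uses it to replace each `Lam` node with a reference to a fresh global and to emit that global's definition. The following is a minimal sketch of the same pattern using only the OCaml standard library. The toy `expr`/`defn` types, the `Writer` module, and the counter-based `gensym` are hypothetical stand-ins for `ann_expr`, `Defn_writer`, and `B.gensym`; annotations are omitted, and the sketch assumes lambda bodies have no free local variables (in the compiler this is presumably guaranteed by closure conversion, which introduces the `Cls` form), so nothing needs to be captured.

```ocaml
(* Hypothetical toy AST standing in for ann_expr (annotations omitted). *)
type expr =
  | Var of string
  | Int of int
  | Lam of string list * expr
  | App of expr * expr list

(* A hoisted global definition: its name and its body. *)
type defn = Defn of string * expr

(* Writer monad over emitted definitions, shaped like Defn_writer.t. *)
module Writer = struct
  type 'v t = 'v * defn list

  let return (v : 'v) : 'v t = (v, [])

  (* Feed m's value to f, keeping m's definitions before f's. *)
  let ( >>= ) ((v, d1) : 'v t) (f : 'v -> 'w t) : 'w t =
    let (w, d2) = f v in
    (w, d1 @ d2)

  (* Sequence two computations, discarding the first (unit) result. *)
  let ( >> ) (m : unit t) (k : 'w t) : 'w t = m >>= fun () -> k

  (* Emit definitions without producing a meaningful value. *)
  let tell (ds : defn list) : unit t = ((), ds)

  (* Combine a list of computations, keeping values and definitions in list order. *)
  let all (ms : 'v t list) : 'v list t =
    List.fold_right
      (fun m acc -> m >>= fun v -> acc >>= fun vs -> return (v :: vs))
      ms (return [])
end

(* Counter-based fresh-name generator, a stand-in for B.gensym. *)
let gensym : string -> string =
  let counter = ref 0 in
  fun prefix ->
    incr counter;
    prefix ^ string_of_int !counter

(* Replace every lambda with a fresh global variable, emitting its definition. *)
let rec hoist (e : expr) : expr Writer.t =
  let open Writer in
  match e with
  | Var _ | Int _ -> return e
  | App (f, args) ->
    hoist f >>= fun f' ->
    all (List.map hoist args) >>= fun args' ->
    return (App (f', args'))
  | Lam (params, body) ->
    let name = gensym "__HOIST_" in
    (* Hoist nested lambdas first, so their definitions precede this one. *)
    hoist body >>= fun body' ->
    tell [Defn (name, Lam (params, body'))] >>
    return (Var name)
```

For example, hoisting `Lam (["x"], App (Lam (["y"], Var "y"), [Var "x"]))` with a fresh counter yields `Var "__HOIST_1"` together with the definition of `__HOIST_2` (the inner lambda) followed by that of `__HOIST_1`. Inner lambdas come first because the body is hoisted before the enclosing lambda's definition is emitted with `tell`, which mirrors the order used in `expr_hoist_lambdas`.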
-------------------------------------------------------------------------------- /src/remora-internal/erased_ast.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | module B = Basic_ast 30 | open Frame_notes 31 | 32 | type var = B.var with sexp 33 | 34 | type idx = B.idx with sexp 35 | 36 | type srt = B.srt with sexp 37 | 38 | type typ = 39 | | TInt 40 | | TFloat 41 | | TBool 42 | | TVar 43 | | TUnknown 44 | | TFun of (typ list * typ) 45 | | TCls of (typ list * typ) 46 | | TDProd of ((var * srt) list * typ) 47 | | TDSum of ((var * srt) list * typ) 48 | | TArray of (idx * typ) 49 | | TTuple of typ list 50 | | TShape 51 | with sexp 52 | val of_typ : B.typ -> typ 53 | 54 | val shape_of_typ : typ -> idx list option 55 | val elt_of_typ : typ -> typ option 56 | val typ_of_shape : idx list -> typ -> typ 57 | 58 | type ('self_t, 'elt_t) expr_form = 59 | | App of ('self_t * 'self_t list * typ) 60 | | IApp of ('self_t * idx list) 61 | | ILam of ((var * srt) list * 'self_t) 62 | | Arr of (int list * 'elt_t list) 63 | | Var of var 64 | | Pack of (idx list * 'self_t) 65 | | Unpack of (var list * var * 'self_t * 'self_t) 66 | | Let of (var * 'self_t * 'self_t) 67 | | Tuple of 'self_t list 68 | | Field of int * 'self_t 69 | | LetTup of (var list * 'self_t * 'self_t) 70 | and ('self_t, 'expr_t) elt_form = 71 | | Float of float 72 | | Int of int 73 | | Bool of bool 74 | | Lam of (var list * 'expr_t) 75 | | Expr of 'expr_t 76 | with sexp 77 | 78 | val map_expr_form : 
79 | f_expr: ('old_self_t -> 'new_self_t) 80 | -> f_elt: ('old_elt_t -> 'new_elt_t) 81 | -> ('old_self_t, 'old_elt_t) expr_form 82 | -> ('new_self_t, 'new_elt_t) expr_form 83 | 84 | val map_elt_form : 85 | f_expr: ('old_expr_t -> 'new_expr_t) 86 | -> ('old_self_t, 'old_expr_t) elt_form 87 | -> ('new_self_t, 'new_expr_t) elt_form 88 | 89 | type erased_expr = EExpr of (erased_expr, erased_elt) expr_form 90 | and erased_elt = EElt of (erased_elt, erased_expr) elt_form 91 | with sexp 92 | type erased_defn = EDefn of var * typ * erased_expr with sexp 93 | type erased_prog = EProg of erased_defn list * erased_expr with sexp 94 | 95 | type 'annot ann_expr = 96 | | AnnEExpr of 'annot * ('annot ann_expr, 'annot ann_elt) expr_form 97 | and 'annot ann_elt = 98 | | AnnEElt of 'annot * ('annot ann_elt, 'annot ann_expr) elt_form 99 | with sexp 100 | type 'annot ann_defn = AnnEDefn of var * typ * 'annot ann_expr with sexp 101 | type 'annot ann_prog = 102 | | AnnEProg of 'annot * 'annot ann_defn list * 'annot ann_expr 103 | with sexp 104 | 105 | val of_expr : B.rem_expr -> erased_expr 106 | val of_elt : B.rem_elt -> erased_elt 107 | val of_defn : B.rem_defn -> erased_defn 108 | val of_prog : B.rem_prog -> erased_prog 109 | 110 | val of_ann_expr : 111 | ?merge:('annot -> 'annot -> 'annot) 112 | -> 'annot B.ann_expr 113 | -> 'annot ann_expr 114 | val of_ann_elt : 115 | ?merge:('annot -> 'annot -> 'annot) 116 | -> 'annot B.ann_elt 117 | -> 'annot ann_elt 118 | val of_ann_defn : 119 | ?merge:('annot -> 'annot -> 'annot) 120 | -> 'annot B.ann_defn 121 | -> 'annot ann_defn 122 | val of_ann_prog : 123 | ?merge:('annot -> 'annot -> 'annot) 124 | -> 'annot B.ann_prog 125 | -> 'annot ann_prog 126 | 127 | val annot_expr_drop : 'annot ann_expr -> erased_expr 128 | val annot_elt_drop : 'annot ann_elt -> erased_elt 129 | val annot_defn_drop : 'annot ann_defn -> erased_defn 130 | val annot_prog_drop : 'annot ann_prog -> erased_prog 131 | 132 | val fix_expr_app_type : typ ann_expr -> typ ann_expr 
133 | val fix_elt_app_type : typ ann_elt -> typ ann_elt 134 | val fix_defn_app_type : typ ann_defn -> typ ann_defn 135 | val fix_prog_app_type : typ ann_prog -> typ ann_prog 136 | 137 | val annot_of_expr : 'annot ann_expr -> 'annot 138 | val annot_of_elt : 'annot ann_elt -> 'annot 139 | val annot_of_defn : 'annot ann_defn -> 'annot 140 | val annot_of_prog : 'annot ann_prog -> 'annot 141 | 142 | val annot_expr_merge : 143 | ('a -> 'b -> 'c) -> 'a ann_expr -> 'b ann_expr -> 'c ann_expr option 144 | val annot_elt_merge : 145 | ('a -> 'b -> 'c) -> 'a ann_elt -> 'b ann_elt -> 'c ann_elt option 146 | val annot_defn_merge : 147 | ('a -> 'b -> 'c) -> 'a ann_defn -> 'b ann_defn -> 'c ann_defn option 148 | val annot_prog_merge : 149 | ('a -> 'b -> 'c) -> 'a ann_prog -> 'b ann_prog -> 'c ann_prog option 150 | 151 | val annot_expr_fmap : f:('a -> 'b) -> 'a ann_expr -> 'b ann_expr 152 | val annot_elt_fmap : f:('a -> 'b) -> 'a ann_elt -> 'b ann_elt 153 | val annot_defn_fmap : f:('a -> 'b) -> 'a ann_defn -> 'b ann_defn 154 | val annot_prog_fmap : f:('a -> 'b) -> 'a ann_prog -> 'b ann_prog 155 | 156 | module Passes : sig 157 | val prog : 158 | (B.typ * arg_frame * app_frame) B.ann_prog 159 | -> (typ * arg_frame * app_frame) ann_prog 160 | val defn : 161 | (B.typ * arg_frame * app_frame) B.ann_defn 162 | -> (typ * arg_frame * app_frame) ann_defn 163 | val expr : 164 | (B.typ * arg_frame * app_frame) B.ann_expr 165 | -> (typ * arg_frame * app_frame) ann_expr 166 | val elt : 167 | (B.typ * arg_frame * app_frame) B.ann_elt 168 | -> (typ * arg_frame * app_frame) ann_elt 169 | 170 | val prog_all : B.rem_prog -> (typ * arg_frame * app_frame) ann_prog option 171 | val defn_all : B.rem_defn -> (typ * arg_frame * app_frame) ann_defn option 172 | val expr_all : B.rem_expr -> (typ * arg_frame * app_frame) ann_expr option 173 | val elt_all : B.rem_elt -> (typ * arg_frame * app_frame) ann_elt option 174 | end 175 | -------------------------------------------------------------------------------- 
/src/remora-internal/frame_notes.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Basic_ast 31 | open Annotation 32 | open Typechecker 33 | 34 | type app_frame = 35 | | AppFrame of idx list 36 | | NotApp 37 | with sexp 38 | 39 | let app_frame_of_option = function 40 | | Some dims -> AppFrame dims | None -> NotApp 41 | let option_of_app_frame = function 42 | | AppFrame dims -> Some dims | NotApp -> None 43 | let idxs_of_app_frame_exn e = match e with 44 | | AppFrame idxs -> idxs 45 | | _ -> raise (Failure ((e |> sexp_of_app_frame |> string_of_sexp) 46 | ^ " is not an App form")) 47 | 48 | (* Annotate an AST with the frame shape of each application form. *) 49 | let rec annot_expr_app_frame 50 | ((AnnRExpr (node_type, expr)): typ ann_expr) : app_frame ann_expr = 51 | match expr with 52 | (* TODO: broaden this case to match nested arrays of functions *) 53 | | App ((AnnRExpr (fn_position_typ, _) as fn_expr), 54 | args)-> 55 | let fn_typ = elt_of_typ fn_position_typ in 56 | let ret_type = 57 | (match fn_typ with 58 | | Some (TFun (_, ret)) -> ret 59 | (* The element type of the array in function position is not a function 60 | type (should not happen in well-typed AST). *) 61 | | _ -> assert false) in 62 | (match frame_contribution ret_type node_type with 63 | | Some idxs -> AnnRExpr (AppFrame idxs, 64 | App (annot_expr_app_frame fn_expr, 65 | List.map ~f:annot_expr_app_frame args)) 66 | (* We have an app form whose type is not a frame around its function's 67 | return type (should not happen in well-typed AST). 
*) 68 | | None -> assert false) 69 | | _ -> AnnRExpr (NotApp, (map_expr_form 70 | ~f_expr:annot_expr_app_frame 71 | ~f_elt:annot_elt_app_frame 72 | expr)) 73 | and annot_elt_app_frame 74 | ((AnnRElt (_, elt)): typ ann_elt) : app_frame ann_elt = 75 | AnnRElt (NotApp, map_elt_form ~f_expr:annot_expr_app_frame elt) 76 | let annot_defn_app_frame 77 | ((AnnRDefn (n, t, e)): typ ann_defn) : app_frame ann_defn = 78 | AnnRDefn (n, t, annot_expr_app_frame e) 79 | let annot_prog_app_frame 80 | ((AnnRProg (_, defns, expr)): typ ann_prog) : app_frame ann_prog = 81 | let new_expr = annot_expr_app_frame expr in 82 | let new_annot = annot_of_expr new_expr in 83 | AnnRProg (new_annot, 84 | List.map ~f:annot_defn_app_frame defns, 85 | new_expr) 86 | 87 | 88 | type arg_frame_rec = {frame: idx list; expansion: idx list} with sexp 89 | type arg_frame = 90 | | ArgFrame of arg_frame_rec 91 | | NotArg 92 | with sexp 93 | 94 | let arg_frame_of_option = function 95 | | Some dims -> ArgFrame dims | None -> NotArg 96 | let option_of_arg_frame = function 97 | | ArgFrame dims -> Some dims | NotArg -> None 98 | let frame_of_arg_exn = function 99 | | ArgFrame {frame = f; expansion = _} -> f 100 | | NotArg -> raise (Failure "Not an arg") 101 | let expansion_of_arg_exn = function 102 | | ArgFrame {frame = _; expansion = e} -> e 103 | | NotArg -> raise (Failure "Not an arg") 104 | 105 | 106 | 107 | (* Annotate subnodes of function application nodes with their argument frame 108 | shapes. We have to track whether we're being called on a node that is 109 | directly part of an application form (e.g., function position in a type/index 110 | application or the body of an abstraction form) and if so, what the enclosing app 111 | form's principal frame is. If we are currently at an argument, we note how 112 | its frame must expand in order to match the application's frame. Otherwise, 113 | we mark it as NotArg.
*) 114 | let rec annot_expr_arg_expansion 115 | ((AnnRExpr ((node_app_frame, node_type), expr)): (app_frame * typ) ann_expr) 116 | ~(outer_expectation: typ option) 117 | ~(outer_frame : app_frame) : arg_frame ann_expr = 118 | let open Option.Monad_infix in 119 | let my_expansion : arg_frame = 120 | (option_of_app_frame outer_frame >>= fun outer_frame_shape -> 121 | outer_expectation >>= fun outer_t -> 122 | frame_contribution outer_t node_type >>= fun my_frame -> 123 | shape_drop my_frame outer_frame_shape >>= fun missing_dims -> 124 | Option.return (ArgFrame {expansion = missing_dims; frame = my_frame})) 125 | |> (Option.value ~default:NotArg) in 126 | match expr with 127 | | App ((AnnRExpr ((_, fn_type), _)) as fn, args) -> 128 | let arg_expected_typs = match elt_of_typ fn_type with 129 | | (Some (TFun (typs, _))) -> typs 130 | (* In a well-typed AST, the array in function position should have 131 | functions as its elements. *) 132 | | _ -> assert false in 133 | AnnRExpr (my_expansion, App (annot_expr_arg_expansion 134 | ~outer_expectation:(elt_of_typ fn_type) 135 | ~outer_frame:node_app_frame fn, 136 | List.map2_exn 137 | ~f:(fun expect arg -> 138 | annot_expr_arg_expansion 139 | ~outer_expectation:(Some expect) 140 | ~outer_frame:node_app_frame arg) 141 | arg_expected_typs args)) 142 | | _ -> AnnRExpr (my_expansion, (map_expr_form 143 | ~f_expr:(annot_expr_arg_expansion 144 | ~outer_expectation:None 145 | ~outer_frame:NotApp) 146 | ~f_elt:annot_elt_arg_expansion expr)) 147 | 148 | and annot_elt_arg_expansion 149 | ((AnnRElt (_, elt)): (app_frame * typ) ann_elt) 150 | : arg_frame ann_elt = 151 | match elt with 152 | | Expr e -> AnnRElt (NotArg, 153 | Expr (annot_expr_arg_expansion 154 | ~outer_expectation:None 155 | ~outer_frame:NotApp e)) 156 | | Lam (bindings, body) -> 157 | AnnRElt (NotArg, Lam (bindings, 158 | annot_expr_arg_expansion 159 | ~outer_expectation:None 160 | ~outer_frame:NotApp body)) 161 | | Float _ | Int _ | Bool _ as l -> AnnRElt (NotArg, l) 
162 | 163 | let annot_defn_arg_expansion 164 | ((AnnRDefn (n, t, e)): (app_frame * typ) ann_defn) : arg_frame ann_defn = 165 | AnnRDefn (n, t, annot_expr_arg_expansion 166 | ~outer_expectation:None 167 | ~outer_frame:NotApp e) 168 | 169 | let annot_prog_arg_expansion 170 | ((AnnRProg (_, defns, expr)): (app_frame * typ) ann_prog) 171 | : arg_frame ann_prog = 172 | let (AnnRExpr (new_annot, _)) as new_expr = 173 | annot_expr_arg_expansion 174 | ~outer_expectation:None 175 | ~outer_frame:NotApp 176 | expr 177 | and new_defns = List.map ~f:annot_defn_arg_expansion defns in 178 | AnnRProg (new_annot, new_defns, new_expr) 179 | 180 | module Passes : sig 181 | val prog : typ ann_prog -> (typ * arg_frame * app_frame) ann_prog 182 | val defn : typ ann_defn -> (typ * arg_frame * app_frame) ann_defn 183 | val expr : typ ann_expr -> (typ * arg_frame * app_frame) ann_expr 184 | val elt : typ ann_elt -> (typ * arg_frame * app_frame) ann_elt 185 | 186 | val prog_all : rem_prog -> (typ * arg_frame * app_frame) ann_prog option 187 | val defn_all : rem_defn -> (typ * arg_frame * app_frame) ann_defn option 188 | val expr_all : rem_expr -> (typ * arg_frame * app_frame) ann_expr option 189 | val elt_all : rem_elt -> (typ * arg_frame * app_frame) ann_elt option 190 | end = struct 191 | open Annotation 192 | open Option.Monad_infix 193 | let triple x (y, z) = (x, y, z) 194 | let prog typ_ast = 195 | let app_ast = annot_prog_app_frame typ_ast in 196 | let app_typ_ast = Option.value_exn 197 | ~message:"Failed to merge type and app-frame annotations" 198 | (annot_prog_merge Tuple2.create app_ast typ_ast) in 199 | let arg_ast = annot_prog_arg_expansion app_typ_ast in 200 | let arg_app_ast = Option.value_exn 201 | ~message:"Failed to merge app-frame and arg-expansion annotations" 202 | (annot_prog_merge Tuple2.create arg_ast app_ast) in 203 | Option.value_exn 204 | ~message:"Failed to merge typ and app/arg annotations" 205 | (annot_prog_merge triple typ_ast arg_app_ast) 206 | let prog_all 
ast = 207 | ast |> Typechecker.Passes.prog_all >>| prog 208 | 209 | let defn typ_ast = 210 | let app_ast = annot_defn_app_frame typ_ast in 211 | let app_typ_ast = Option.value_exn 212 | ~message:"Failed to merge type and app-frame annotations" 213 | (annot_defn_merge Tuple2.create app_ast typ_ast) in 214 | let arg_ast = annot_defn_arg_expansion app_typ_ast in 215 | let arg_app_ast = Option.value_exn 216 | ~message:"Failed to merge app-frame and arg-expansion annotations" 217 | (annot_defn_merge Tuple2.create arg_ast app_ast) in 218 | Option.value_exn 219 | ~message:"Failed to merge typ and app/arg annotations" 220 | (annot_defn_merge triple typ_ast arg_app_ast) 221 | let defn_all ast = 222 | ast |> Typechecker.Passes.defn_all >>| defn 223 | 224 | let expr typ_ast = 225 | let app_ast = annot_expr_app_frame typ_ast in 226 | let app_typ_ast = Option.value_exn 227 | ~message:"Failed to merge type and app-frame annotations" 228 | (annot_expr_merge Tuple2.create app_ast typ_ast) in 229 | let arg_ast = annot_expr_arg_expansion 230 | ~outer_expectation:None 231 | ~outer_frame:NotApp 232 | app_typ_ast in 233 | let arg_app_ast = Option.value_exn 234 | ~message:"Failed to merge app-frame and arg-expansion annotations" 235 | (annot_expr_merge Tuple2.create arg_ast app_ast) in 236 | Option.value_exn 237 | ~message:"Failed to merge typ and app/arg annotations" 238 | (annot_expr_merge triple typ_ast arg_app_ast) 239 | let expr_all ast = 240 | ast |> Typechecker.Passes.expr_all >>| expr 241 | 242 | let elt typ_ast = 243 | let app_ast = annot_elt_app_frame typ_ast in 244 | let app_typ_ast = Option.value_exn 245 | ~message:"Failed to merge type and app-frame annotations" 246 | (annot_elt_merge Tuple2.create app_ast typ_ast) in 247 | let arg_ast = annot_elt_arg_expansion app_typ_ast in 248 | let arg_app_ast = Option.value_exn 249 | ~message:"Failed to merge app-frame and arg-expansion annotations" 250 | (annot_elt_merge Tuple2.create arg_ast app_ast) in 251 | Option.value_exn 252 | 
~message:"Failed to merge typ and app/arg annotations" 253 | (annot_elt_merge triple typ_ast arg_app_ast) 254 | let elt_all ast = 255 | ast |> Typechecker.Passes.elt_all >>| elt 256 | end 257 | -------------------------------------------------------------------------------- /src/remora-internal/frame_notes.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Basic_ast 30 | 31 | type app_frame = 32 | | AppFrame of idx list 33 | | NotApp 34 | with sexp 35 | 36 | val app_frame_of_option : idx list option -> app_frame 37 | val option_of_app_frame : app_frame -> idx list option 38 | val idxs_of_app_frame_exn : app_frame -> idx list 39 | 40 | val annot_expr_app_frame : typ ann_expr -> app_frame ann_expr 41 | val annot_elt_app_frame : typ ann_elt -> app_frame ann_elt 42 | val annot_defn_app_frame : typ ann_defn -> app_frame ann_defn 43 | val annot_prog_app_frame : typ ann_prog -> app_frame ann_prog 44 | 45 | 46 | type arg_frame_rec = {frame: idx list; expansion: idx list} with sexp 47 | type arg_frame = 48 | | ArgFrame of arg_frame_rec 49 | | NotArg 50 | with sexp 51 | 52 | val arg_frame_of_option : arg_frame_rec option -> arg_frame 53 | val option_of_arg_frame : arg_frame -> arg_frame_rec option 54 | val frame_of_arg_exn : arg_frame -> idx list 55 | val expansion_of_arg_exn : arg_frame -> idx list 56 | 57 | val annot_expr_arg_expansion : 58 | (app_frame * typ) ann_expr 59 | -> outer_expectation: typ option 60 | -> outer_frame: app_frame 61 | -> arg_frame ann_expr 62 | val annot_elt_arg_expansion : 63 | (app_frame * typ) ann_elt 64 | -> arg_frame ann_elt 65 | val annot_defn_arg_expansion : 66 | (app_frame * typ) ann_defn 67 | -> arg_frame 
ann_defn 68 | val annot_prog_arg_expansion : 69 | (app_frame * typ) ann_prog 70 | -> arg_frame ann_prog 71 | 72 | module Passes : sig 73 | val prog : typ ann_prog -> (typ * arg_frame * app_frame) ann_prog 74 | val defn : typ ann_defn -> (typ * arg_frame * app_frame) ann_defn 75 | val expr : typ ann_expr -> (typ * arg_frame * app_frame) ann_expr 76 | val elt : typ ann_elt -> (typ * arg_frame * app_frame) ann_elt 77 | 78 | val prog_all : rem_prog -> (typ * arg_frame * app_frame) ann_prog option 79 | val defn_all : rem_defn -> (typ * arg_frame * app_frame) ann_defn option 80 | val expr_all : rem_expr -> (typ * arg_frame * app_frame) ann_expr option 81 | val elt_all : rem_elt -> (typ * arg_frame * app_frame) ann_elt option 82 | end 83 | -------------------------------------------------------------------------------- /src/remora-internal/globals.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. 
*) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Basic_ast 31 | 32 | let scalar t = TArray (IShape [], t) 33 | let vec l t = TArray (IShape [l], t) 34 | let vecv l t = vec (nvar l) t 35 | let vecn l t = vec (INat l) t 36 | let func ins out = scalar (TFun (ins, out)) 37 | 38 | let arith_unary t = 39 | func [scalar t] (scalar t) 40 | let arith_binary t = 41 | func [scalar t; scalar t] (scalar t) 42 | let compare t = 43 | func [scalar t; scalar t] (scalar TBool) 44 | 45 | let any i t = TArray (svar i, TVar t) 46 | 47 | let assoc_all names typ = List.map ~f:(fun n -> (n,typ)) names 48 | 49 | let int_arith_binary = 50 | assoc_all ["+"; "-"; "*"; "/"; "^"; "rand"] 51 | (arith_binary TInt) 52 | 53 | let float_arith_binary = 54 | assoc_all ["+."; "-."; "*."; "/."; "^.";] 55 | (arith_binary TFloat) 56 | let float_arith_unary = 57 | assoc_all ["sin"; "cos"; "tan"; "log"; "lg"; "ln"; "sqrt"; "e^"] 58 | (arith_unary TFloat) 59 | 60 | let logic_binary = 61 | assoc_all ["and"; "or"; "xor"] 62 | (arith_binary TBool) 63 | let logic_unary = 64 | 
assoc_all ["not"] 65 | (arith_unary TBool) 66 | 67 | let int_compare = 68 | assoc_all [">"; ">="; "<"; "<="; "="; "!="] 69 | (compare TInt) 70 | 71 | let float_compare = 72 | assoc_all [">."; ">=."; "<."; "<=."; "=."; "!=."] 73 | (compare TFloat) 74 | 75 | let num_coerce = 76 | ("float", func [TInt] TFloat) :: 77 | (assoc_all ["round"; "floor"; "ceil"] (func [TFloat] TInt)) 78 | 79 | let choice = 80 | let any = any "s" "t" in 81 | ["choice", 82 | TDProd (["s", SShape], TAll (["t"], func [scalar TBool; any; any] any))] 83 | 84 | let head_tail = 85 | let any = any "s" "t" in 86 | assoc_all ["head"; "tail"] 87 | (TDProd (["l", SNat; "s", SShape], 88 | TAll (["t"], 89 | func [TArray (IShape [ISum (nvar "l", 90 | INat 1)], any)] any))) 91 | 92 | let behead_curtail = 93 | let any = any "s" "t" in 94 | assoc_all ["behead"; "curtail"] 95 | (TDProd (["l", SNat; "s", SShape], 96 | TAll (["t"], 97 | func [TArray (IShape [ISum (nvar "l", 98 | INat 1)], any)] 99 | (TArray (IShape [nvar "l"], any))))) 100 | 101 | let take_drop = 102 | let any = any "s" "t" in 103 | assoc_all ["take"; "take-right"; "drop"; "drop-right"] 104 | (TDProd (["l", SNat; "s", SShape], 105 | TAll (["t"], func [scalar TInt; vecv "l" any] 106 | (TDSum (["n", SNat], vecv "n" any))))) 107 | 108 | let take_witness = 109 | let any = any "s" "t" in 110 | assoc_all ["take*"; "take-right*"] 111 | (TDProd (["l", SNat; "m", SNat; "s", SShape], 112 | TAll (["u"; "t"], func [vecv "l" (TVar "u"); 113 | vec (ISum (nvar "l", nvar "m")) any] 114 | (vecv "l" any)))) 115 | 116 | let drop_witness = 117 | let any = any "s" "t" in 118 | assoc_all ["drop*"; "drop-right*"] 119 | (TDProd (["l", SNat; "m", SNat; "s", SShape], 120 | TAll (["u"; "t"], func [vecv "l" (TVar "u"); 121 | vec (ISum (nvar "l", nvar "m")) any] 122 | (vecv "m" any)))) 123 | 124 | let reverse = 125 | let any = any "s" "t" in 126 | ["reverse", TDProd (["s", SShape], 127 | TAll (["t"], 128 | func [any] any))] 129 | 130 | let rotate = 131 | let any = any "s" "t" in
  ["rotate", TDProd (["l", SNat; "s", SShape],
                     TAll (["t"],
                           func [scalar TInt; vecv "l" any] (vecv "l" any)))]

let append =
  let any = any "s" "t" in
  ["append", TDProd (["l", SNat; "m", SNat; "s", SShape],
                     TAll (["t"],
                           func [vecv "l" any; vecv "m" any]
                             (vec (ISum (nvar "l", nvar "m")) any)))]

let itemize =
  let any = any "s" "t" in
  ["itemize", TDProd (["s", SShape],
                      TAll (["t"],
                            func [any]
                              (vecn 1 any)))]

let ravel_shape =
  (* Both accept an array of any shape. ravel returns its atoms as a vector
     of "t"; shape returns its dimensions, so its result holds integers. The
     result length "l" must be bound by the dependent sum. *)
  let to_vec ret_elt =
    TDProd (["s", SShape],
            TAll (["t"],
                  func [any "s" "t"]
                    (TDSum (["l", SNat], vecv "l" ret_elt)))) in
  ["ravel", to_vec (TVar "t");
   "shape", to_vec TInt]

let length =
  ["length", TDProd (["s", SShape],
                     TAll (["t"],
                           func [any "s" "t"] (scalar TInt)))]

let left_fold =
  let anyl = any "sl" "tl"
  and anyr = any "sr" "tr" in
  ["foldl", TDProd (["l", SNat; "sl", SShape; "sr", SShape],
                    TAll (["tl"; "tr"],
                          func [func [anyr; anyl] anyl;
                                anyl;
                                vecv "l" anyr]
                            anyl))]

let right_fold =
  let anyl = any "sl" "tl"
  and anyr = any "sr" "tr" in
  ["foldr", TDProd (["l", SNat; "sl", SShape; "sr", SShape],
                    TAll (["tl"; "tr"],
                          func [func [anyr; anyl] anyr;
                                anyr;
                                vecv "l" anyl]
                            anyr))]

let left_scan =
  let anyl = any "sl" "tl"
  and anyr = any "sr" "tr" in
  ["scanl", TDProd (["l", SNat; "sl", SShape; "sr", SShape],
                    TAll (["tl"; "tr"],
                          func [func [anyr; anyl] anyl;
                                anyl;
                                vecv "l" anyr]
                            (vecv "l" anyl)))]

let right_scan =
  let anyl = any "sl" "tl"
  and anyr = any "sr" "tr" in
  ["scanr", TDProd (["l", SNat; "sl", SShape; "sr", SShape],
                    TAll (["tl"; "tr"],
                          func [func [anyr; anyl] anyr;
                                anyr;
                                vecv "l" anyl]
                            (vecv "l" anyr)))]

let reduce =
  let any = any "s" "t" in
  ["reduce", TDProd (["l", SNat; "s", SShape],
                     TAll (["t"],
                           func [func [any; any] any;
                                 vecv "l" any]
                             any))]

let filter =
  let any = any "s" "t" in
  ["filter", TDProd (["l", SNat; "s", SShape],
                     TAll (["t"],
                           func [vecv "l" (scalar TBool);
                                 vecv "l" any]
                             (TDSum (["m", SNat], vecv "m" any))))]

let iota =
  ["iota", TDProd (["l", SNat],
                   func [vecv "l" TInt]
                     (TDSum (["s", SShape], TArray (svar "s", TInt))))]

let iota_vector =
  ["iotavec", func [scalar TInt] (TDSum (["l", SNat], vecv "l" TInt))]

let iota_witness =
  ["iota*", TDProd (["s", SShape],
                    TAll (["t"],
                          func [any "s" "t"] (TArray (svar "s", TInt))))]

let read name t =
  [name, func [] (TDSum (["s", SShape], TArray (svar "s", t)))]
let readvec name t =
  [name, func [] (TDSum (["l", SNat], vecv "l" t))]
let readscal name t =
  [name, func [] t]
let write name t =
  [name, TDProd (["s", SShape],
                 func [TArray (svar "s", t)] TBool)]

let read_basetype =
  List.join [read "read_i" TInt;
             read "read_f" TFloat;
             read "read_b" TBool]
let readvec_basetype =
  List.join [readvec "readvec_i" TInt;
             readvec "readvec_f" TFloat;
             readvec "readvec_b" TBool]
let readscal_basetype =
  List.join [readscal "readscal_i" TInt;
             readscal "readscal_f" TFloat;
             readscal "readscal_b" TBool]
let write_basetype =
  List.join [write "write_i" TInt;
             write "write_f" TFloat;
             write "write_b" TBool]

let builtins =
  List.join [int_arith_binary;
             float_arith_binary;
             float_arith_unary;
             logic_binary;
             logic_unary;
             int_compare;
             float_compare;
             num_coerce;
             choice;
             head_tail;
             behead_curtail;
             take_drop;
             take_witness;
             drop_witness;
             reverse;
             rotate;
             append;
             itemize;
             ravel_shape;
             length;
             left_fold;
             right_fold;
             left_scan;
             right_scan;
             reduce;
             filter;
             iota;
             iota_vector;
             iota_witness;
             read_basetype;
             readvec_basetype;
             readscal_basetype;
             write_basetype]

let lam_version (n, t) : (var * typ) option =
  match t with
  | TArray (IShape [], tfun) -> Some ("__lam_" ^ n, tfun)
  (* This seems unlikely to work *)
  (* | TDProd (bindings, TArray (IShape [], tfun)) *)
  (*   -> Some ("__lam_" ^ n, *)
  (*            TDProd (bindings, tfun)) *)
  | _ -> None
let builtin_lams =
  List.filter_opt (List.map ~f:lam_version builtins)

let rec atomlevel_version (n, t) : (var * typ) option =
  let open Option.Monad_infix in
  match t with
  | TArray (IShape [], TFun (ins, out)) ->
    List.map ~f:(function
        | TArray (_, TArray _) -> None
        | TArray (IShape [], elt) -> Some elt
        | _ -> None) ins |> Option.all >>= fun in_elts ->
    Some ("__atm_" ^ n, (TFun (in_elts, out)))
  (* | TDProd (bindings, (TArray _ as arrtype)) -> *)
  (*   atomlevel_version (n, arrtype) >>= fun (new_name, new_arrtype) -> *)
  (*   Some (new_name, TDProd (bindings, new_arrtype)) *)
  | _ -> None
let builtin_atomlevels =
  List.filter_opt (List.map ~f:atomlevel_version builtins)

(* TODO: need some way to get type-specific versions of polymorphic primops *)
--------------------------------------------------------------------------------
/src/remora-internal/globals.mli:
--------------------------------------------------------------------------------
(******************************************************************************)
(* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *)
(* *)
(* Redistribution and use in source and binary forms, with or without *)
(* modification, are permitted provided that the following conditions *)
(* are met: *)
(* * Redistributions of source code must retain the above copyright *)
(* notice, this list of conditions and the following disclaimer. *)
(* * Redistributions in binary form must reproduce the above copyright *)
(* notice, this list of conditions and the following disclaimer in the *)
(* documentation and/or other materials provided with the distribution. *)
(* * Neither the name of NVIDIA CORPORATION nor the names of its *)
(* contributors may be used to endorse or promote products derived *)
(* from this software without specific prior written permission. *)
(* *)
(* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *)
(* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *)
(* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *)
(* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *)
(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *)
(* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *)
(* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *)
(* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *)
(* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *)
(******************************************************************************)

open Core.Std
open Basic_ast

val builtins : (string, typ) List.Assoc.t
val builtin_lams : (string, typ) List.Assoc.t
val builtin_atomlevels : (string, typ) List.Assoc.t
--------------------------------------------------------------------------------
/src/remora-internal/map_replicate_ast.mli:
--------------------------------------------------------------------------------
(******************************************************************************)
(* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *)
(* *)
(* Redistribution and use in source and binary forms, with or without *)
(* modification, are permitted provided that the following conditions *)
(* are met: *)
(* * Redistributions of source code must retain the above copyright *)
(* notice, this list of conditions and the following disclaimer. *)
(* * Redistributions in binary form must reproduce the above copyright *)
(* notice, this list of conditions and the following disclaimer in the *)
(* documentation and/or other materials provided with the distribution. *)
(* * Neither the name of NVIDIA CORPORATION nor the names of its *)
(* contributors may be used to endorse or promote products derived *)
(* from this software without specific prior written permission. *)
(* *)
(* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *)
(* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *)
(* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *)
(* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *)
(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *)
(* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *)
(* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *)
(* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *)
(* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *)
(******************************************************************************)

module B = Basic_ast
module E = Erased_ast
open Frame_notes

type var = B.var with sexp

type 'a app_t = {fn : 'a; args: 'a list;} with sexp
type 'a vec_t = {dims: int; elts: 'a list;} with sexp
type 'a map_t = {frame: 'a; fn: 'a; args: 'a list; shp: 'a;} with sexp
type 'a rep_t = {arg: 'a; new_frame: 'a; old_frame: 'a;} with sexp
type 'a tup_t = 'a list with sexp
type 'a fld_t = {field: int; tuple: 'a} with sexp
type 'a let_t = {var: var; bound: 'a; body: 'a;} with sexp
type 'a lettup_t = {vars: var list; bound: 'a; body: 'a} with sexp
type 'a lam_t = {bindings: var list; body: 'a;} with sexp

type 'a expr_form =
    App of 'a app_t
  | Vec of 'a vec_t
  | Map of 'a map_t
  | Rep of 'a rep_t
  | Tup of 'a tup_t
  | LetTup of 'a lettup_t
  | Fld of 'a fld_t
  | Let of 'a let_t
  | Lam of 'a lam_t
  | Var of var
  | Int of int
  | Float of float
  | Bool of bool
with sexp

val map_expr_form : f:('a -> 'b) -> 'a expr_form -> 'b expr_form

type expr = Expr of expr expr_form with sexp
type defn = Defn of var * expr with sexp
type prog = Prog of defn tup_t * expr with sexp

type 'annot ann_expr = AExpr of 'annot * 'annot ann_expr expr_form with sexp
type 'annot ann_defn = ADefn of var * 'annot ann_expr with sexp
type 'annot ann_prog =
  AProg of 'annot * 'annot ann_defn tup_t * 'annot ann_expr with sexp

val op_name_plus : var
val op_name_shape_append : var

val idx_name_mangle : var -> B.srt option -> var

val of_erased_idx :
  E.idx -> (E.typ * arg_frame * app_frame) ann_expr

val of_nested_shape :
  E.idx tup_t -> (E.typ * arg_frame * app_frame) ann_expr

val defunctionalized_map :
  fn:(E.typ * arg_frame * app_frame) ann_expr ->
  args:(E.typ * arg_frame * app_frame) ann_expr tup_t ->
  shp:(E.typ * arg_frame * app_frame) ann_expr ->
  frame:(E.typ * arg_frame * app_frame) ann_expr ->
  (E.typ * arg_frame * app_frame) ann_expr expr_form

val of_erased_expr :
  (E.typ * arg_frame * app_frame) E.ann_expr
  -> (E.typ * arg_frame * app_frame) ann_expr

val of_erased_elt :
  (E.typ * arg_frame * app_frame) E.ann_elt
  -> (E.typ * arg_frame * app_frame) ann_expr

val of_erased_defn :
  (E.typ * arg_frame * app_frame) E.ann_defn
  -> (E.typ * arg_frame * app_frame) ann_defn

val of_erased_prog :
  (E.typ * arg_frame * app_frame) E.ann_prog
  -> (E.typ * arg_frame * app_frame) ann_prog

val annot_expr_drop : 'a ann_expr -> expr
val annot_defn_drop : 'a ann_defn -> defn
val annot_prog_drop : 'a ann_prog -> prog

val get_free_vars :
  var list -> (var list -> 'a -> var list) -> 'a expr_form -> var list
val aexpr_free_vars : var list -> 'a ann_expr -> var list
val get_annotated_free_vars :
  var list -> 'a ann_expr -> 'a ann_expr list

module Passes : sig
  val prog :
    (E.typ * arg_frame * app_frame) E.ann_prog
    -> (E.typ * arg_frame * app_frame) ann_prog
  val defn :
    (E.typ * arg_frame * app_frame) E.ann_defn
    -> (E.typ * arg_frame * app_frame) ann_defn
  val expr :
    (E.typ * arg_frame * app_frame) E.ann_expr
    -> (E.typ * arg_frame * app_frame) ann_expr
  val elt :
    (E.typ * arg_frame * app_frame) E.ann_elt
    -> (E.typ * arg_frame * app_frame) ann_expr

  val prog_all : B.rem_prog -> (E.typ * arg_frame * app_frame) ann_prog option
  val defn_all : B.rem_defn -> (E.typ * arg_frame * app_frame) ann_defn option
  val expr_all : B.rem_expr -> (E.typ * arg_frame * app_frame) ann_expr option
  val elt_all : B.rem_elt -> (E.typ * arg_frame * app_frame) ann_expr option
end
--------------------------------------------------------------------------------
/src/remora-internal/run_tests.ml:
--------------------------------------------------------------------------------
(******************************************************************************)
(* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *)
(* *)
(* Redistribution and use in source and binary forms, with or without *)
(* modification, are permitted provided that the following conditions *)
(* are met: *)
(* * Redistributions of source code must retain the above copyright *)
(* notice, this list of conditions and the following disclaimer. *)
(* * Redistributions in binary form must reproduce the above copyright *)
(* notice, this list of conditions and the following disclaimer in the *)
(* documentation and/or other materials provided with the distribution. *)
(* * Neither the name of NVIDIA CORPORATION nor the names of its *)
(* contributors may be used to endorse or promote products derived *)
(* from this software without specific prior written permission. *)
(* *)
(* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *)
(* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *)
(* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *)
(* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *)
(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *)
(* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *)
(* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *)
(* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *)
(* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *)
(******************************************************************************)

module AST = Test_basic_ast.UnitTests;;
module Typecheck = Test_typechecker.UnitTests;;
module Frame_annotate = Test_frame_notes.UnitTests;;
module Erase = Test_erased_ast.UnitTests;;
module MapRep = Test_map_replicate_ast.UnitTests;;
module Clos = Test_closures.UnitTests;;
open OUnit2

let () =
  run_test_tt_main AST.suite_init_drop;
  run_test_tt_main Typecheck.tests;
  run_test_tt_main Frame_annotate.tests;
  run_test_tt_main Erase.tests;
  run_test_tt_main MapRep.tests;
  run_test_tt_main Clos.tests
;;
--------------------------------------------------------------------------------
/src/remora-internal/substitution.ml:
--------------------------------------------------------------------------------
(******************************************************************************)
(* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *)
(* *)
(* Redistribution and use in source and binary forms, with or without *)
(* modification, are permitted provided that the following conditions *)
(* are met: *)
(* * Redistributions of source code must retain the above copyright *)
(* notice, this list of conditions and the following disclaimer. *)
(* * Redistributions in binary form must reproduce the above copyright *)
(* notice, this list of conditions and the following disclaimer in the *)
(* documentation and/or other materials provided with the distribution. *)
(* * Neither the name of NVIDIA CORPORATION nor the names of its *)
(* contributors may be used to endorse or promote products derived *)
(* from this software without specific prior written permission. *)
(* *)
(* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *)
(* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *)
(* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *)
(* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *)
(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *)
(* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *)
(* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *)
(* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *)
(* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *)
(* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *)
(* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *)
(******************************************************************************)

open Basic_ast
open Core.Std

(* Finite mapping from variable to something -- same as an environment *)
type 'a subst = (var, 'a) List.Assoc.t with sexp

(* May later need...
36 | let rec exp_into_expr (sub: rem_expr subst) (x: rem_expr) : rem_expr 37 | let rec exp_into_elt (sub: rem_expr subst) (l: rem_elt) : rem_elt 38 | let rec elt_into_expr (sub: rem_elt subst) (x: rem_expr) : rem_expr 39 | let rec elt_into_elt (sub: rem_elt subst) (l: rem_elt) : rem_elt 40 | let rec typ_into_expr (typs: typ subst) (x: rem_expr) : rem_expr 41 | let rec typ_into_elt (typs: typ subst) (l: rem_elt) : rem_elt 42 | let rec idx_into_expr (sub: idx subst) (x: rem_expr) : rem_expr 43 | let rec idx_into_elt (sub: idx subst) (l: rem_elt) : rem_elt 44 | *) 45 | 46 | let rec typ_into_typ (sub: typ subst) (t: typ) : typ = 47 | match t with 48 | | (TFloat | TInt | TBool) as t_ -> t_ 49 | | TDProd (ivars, body) -> TDProd (ivars, typ_into_typ sub body) 50 | | TDSum (ivars, body) -> TDSum (ivars, typ_into_typ sub body) 51 | | TFun (ins, out) -> TFun ((List.map ~f:(typ_into_typ sub) ins), 52 | typ_into_typ sub out) 53 | | TArray (shape, elts) -> TArray (shape, typ_into_typ sub elts) 54 | | TAll (tvars, body) 55 | -> TAll (tvars, 56 | typ_into_typ (List.fold ~init:sub 57 | ~f:(List.Assoc.remove ~equal:(=)) tvars) 58 | body) 59 | | TVar v -> Option.value ~default:(TVar v) (List.Assoc.find sub v) 60 | | TProd elts -> TProd (List.map ~f:(typ_into_typ sub) elts) 61 | 62 | let rec idx_into_idx (sub: idx subst) (i: idx) : idx = 63 | match i with 64 | | INat _ as i_ -> i_ 65 | | IShape idxs -> IShape (List.map ~f:(idx_into_idx sub) idxs) 66 | | ISum (idx1, idx2) -> ISum (idx_into_idx sub idx1, idx_into_idx sub idx2) 67 | | IVar (name, _) as v -> Option.value ~default:v (List.Assoc.find sub name) 68 | 69 | let rec idx_into_typ (sub: idx subst) (t: typ) : typ = 70 | match t with 71 | | TFloat | TInt | TBool as t_ -> t_ 72 | | TDProd (ivars, body) 73 | -> TDProd (ivars, (idx_into_typ 74 | (List.fold ~init:sub ~f:(List.Assoc.remove ~equal:(=)) 75 | (List.map ~f:fst ivars)) 76 | body)) 77 | | TDSum (ivars, body) 78 | -> TDSum (ivars, (idx_into_typ 79 | (List.fold ~init:sub 
~f:(List.Assoc.remove ~equal:(=)) 80 | (List.map ~f:fst ivars)) 81 | body)) 82 | | TFun (ins, out) -> TFun ((List.map ~f:(idx_into_typ sub) ins), 83 | idx_into_typ sub out) 84 | | TArray (shape, elts) -> TArray (idx_into_idx sub shape, 85 | idx_into_typ sub elts) 86 | | TAll (tvars, body) -> TAll (tvars, idx_into_typ sub body) 87 | | TVar _ as tv -> tv 88 | | TProd elts -> TProd (List.map ~f:(idx_into_typ sub) elts) 89 | 90 | -------------------------------------------------------------------------------- /src/remora-internal/substitution.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | 31 | type 'a subst = (Basic_ast.var, 'a) List.Assoc.t with sexp 32 | 33 | val typ_into_typ : Basic_ast.typ subst -> Basic_ast.typ -> Basic_ast.typ 34 | val idx_into_typ : Basic_ast.idx subst -> Basic_ast.typ -> Basic_ast.typ 35 | val idx_into_idx : Basic_ast.idx subst -> Basic_ast.idx -> Basic_ast.idx 36 | -------------------------------------------------------------------------------- /src/remora-internal/test_basic_ast.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. 
*) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Basic_ast 31 | module U = OUnit2;; 32 | 33 | let test_t_int = TInt 34 | let test_t_float = TFloat 35 | 36 | 37 | let ta = 38 | AnnRExpr ((TArray (IShape [INat 2], TInt)), 39 | (Arr ([2], [AnnRElt (TInt, Int 3); 40 | AnnRElt (TInt, Int 2)])));; 41 | 42 | let scalar_of_elt e = RExpr (Arr ([], [e])) 43 | let scalar_of_expr e = scalar_of_elt (RElt (Expr e)) 44 | let scalar_of_elt_form e = scalar_of_elt (RElt e) 45 | let scalar_of_expr_form e = scalar_of_expr (RExpr e) 46 | 47 | let flat_arr_2_3 = 48 | RExpr (Arr ([2;3], [RElt (Int 4); RElt (Int 1); RElt (Int 6); 49 | RElt (Int 2); RElt (Int 3); RElt (Int 5)])) 50 | 51 | let flat_arr_0_4 = RExpr (Arr ([0;4], [])) 52 | 53 | let arr_2 = RExpr (Arr ([2], [RElt (Bool false); RElt (Bool true)])) 54 | 55 | let arr_wrong = 56 | RExpr (Arr ([2;3], [RElt (Int 4); RElt (Int 1); RElt (Int 6); 57 | RElt (Int 2); RElt (Int 3)])) 58 | 59 | let nest_arr_2_3 = 60 | RExpr (Arr ([2], 61 | [RElt (Expr (RExpr (Arr ([3], [RElt (Int 4); 62 | RElt (Int 1); 63 | RElt (Int 6)])))); 64 | RElt (Expr (RExpr (Arr ([3], [RElt (Int 2); 65 | RElt (Int 3); 66 | RElt (Int 5)]))))])) 67 | 68 | let unary_lambda = 69 | RElt (Lam ([("x", TArray (IShape [], TInt))], scalar_of_elt_form (Int 3))) 70 | 71 | let binary_lambda = 72 | RElt (Lam ([("x", TArray (IShape [INat 3], TFloat)); 73 | ("y", TArray (IShape [INat 1], TBool))], 74 | scalar_of_elt_form (Int 3))) 75 | 76 | let unary_app = 77 | RExpr (App (RExpr (Arr ([], [unary_lambda])), 78 | [scalar_of_elt_form (Int 4)])) 79 | 80 | let binary_app = 81 | RExpr (App (RExpr (Arr ([], [binary_lambda])), 82 | [RExpr (Arr ([3], [RElt (Float 1.0); 83 | RElt (Float 2.0); 84 | RElt (Float 3.0)])); 85 | RExpr (Arr ([1], [RElt (Bool false)]))])) 86 | 87 | let elt_of_expr_form e = RElt (Expr (RExpr e)) 88 | 89 | let unary_to_nested_app = 90 | RExpr 91 | (App 92 | (RExpr 93 | (Arr ([], 94 | [(RElt (Lam 
([("x", TArray (IShape [INat 2; INat 3], TInt))], 95 | RExpr (Var "x"))))])), 96 | [RExpr (Arr ([2], [elt_of_expr_form 97 | (Arr ([3], [RElt (Int 1); 98 | RElt (Int 2); 99 | RElt (Int 3)])); 100 | elt_of_expr_form 101 | (Arr ([3], [RElt (Int 4); 102 | RElt (Int 5); 103 | RElt (Int 6)]))]))])) 104 | 105 | let nested_to_unary_app = 106 | RExpr 107 | (App 108 | (RExpr 109 | (Arr ([3], 110 | [elt_of_expr_form 111 | (Arr ([2], 112 | [RElt (Lam ([("x", TArray (IShape [], TInt))], 113 | scalar_of_elt_form (Int 1))); 114 | RElt (Lam ([("x", TArray (IShape [], TInt))], 115 | scalar_of_elt_form (Int 2)))])); 116 | elt_of_expr_form 117 | (Arr ([2], 118 | [RElt (Lam ([("x", TArray (IShape [], TInt))], 119 | scalar_of_elt_form (Int 3))); 120 | RElt (Lam ([("x", TArray (IShape [], TInt))], 121 | scalar_of_elt_form (Int 4)))])); 122 | elt_of_expr_form 123 | (Arr ([2], 124 | [RElt (Lam ([("x", TArray (IShape [], TInt))], 125 | scalar_of_elt_form (Int 5))); 126 | RElt (Lam ([("x", TArray (IShape [], TInt))], 127 | scalar_of_elt_form (Int 6)))]))])), 128 | [RExpr (Arr ([3], [RElt (Int 7); RElt (Int 23523); RElt (Int 245)]))])) 129 | 130 | let type_abst = 131 | RExpr (TLam (["elt"], 132 | scalar_of_elt_form (Lam ([("x", TArray (IShape [], TVar "elt"))], 133 | RExpr (Var "x"))))) 134 | 135 | let type_abst_bad = 136 | RExpr (TLam (["elt"], 137 | scalar_of_elt_form (Lam ([("x", TArray (IShape [], TVar "foo"))], 138 | RExpr (Var "x"))))) 139 | 140 | let type_app = RExpr (TApp (type_abst, [TBool])) 141 | 142 | let index_abst = 143 | RExpr (ILam (["d", SNat], 144 | RExpr (Arr ([], 145 | [RElt (Lam ([("l", TArray (IShape [ivar "d"], 146 | TInt))], 147 | RExpr (Var "l")))])))) 148 | 149 | let index_app = RExpr (IApp (index_abst, [INat 6])) 150 | 151 | let dep_sum_create = 152 | RExpr (Pack ([INat 3], 153 | RExpr (Arr ([3], [RElt (Int 0); RElt (Int 1); RElt (Int 2)])), 154 | TDSum ([("d", SNat)], TArray (IShape [ivar "d"], TInt)))) 155 | 156 | 157 | let dep_sum_project = 158 | RExpr (Unpack 
(["l"], "c", dep_sum_create, scalar_of_elt_form (Int 0))) 159 | 160 | 161 | let remora_compose = 162 | let inner_lam = 163 | RElt (Lam (["x", TArray (ivar "s1", TVar "alpha")], 164 | RExpr (App (RExpr (Var "g"), 165 | [(RExpr (App (RExpr (Var "f"), 166 | [RExpr (Var "x")])))])))) in 167 | let outer_lam = 168 | RElt (Lam (["f", TArray (IShape [], 169 | TFun ([TArray (ivar "s1", TVar "alpha")], 170 | TArray (ivar "s2", TVar "beta"))); 171 | "g", TArray (IShape [], 172 | TFun ([TArray (ivar "s2", TVar "beta")], 173 | TArray (ivar "s3", TVar "gamma")))], 174 | scalar_of_elt inner_lam)) in 175 | let type_lam = 176 | RExpr (TLam (["alpha"; "beta"; "gamma"], 177 | scalar_of_elt outer_lam)) in 178 | RExpr (ILam (["s1", SShape; "s2", SShape; "s3", SShape], type_lam)) 179 | 180 | 181 | let fork_compose = 182 | let inner_lam = 183 | RElt (Lam (["x", TArray (ivar "s-li", TVar "t-li"); 184 | "y", TArray (ivar "s-ri", TVar "t-ri")], 185 | RExpr (App (RExpr (Var "f-j"), 186 | [RExpr (App (RExpr (Var "f-l"), 187 | [RExpr (Var "x")])); 188 | RExpr (App (RExpr (Var "f-r"), 189 | [RExpr (Var "y")]))])))) in 190 | let outer_lam = 191 | RElt (Lam (["f-l", TArray (IShape [], 192 | TFun ([TArray (ivar "s-li", TVar "t-li")], 193 | TArray (ivar "s-lo", TVar "t-lo"))); 194 | "f-r", TArray (IShape [], 195 | TFun ([TArray (ivar "s-ri", TVar "t-ri")], 196 | TArray (ivar "s-ro", TVar "t-ro"))); 197 | "f-j", TArray (IShape [], 198 | TFun ([TArray (ivar "s-lo", TVar "t-lo"); 199 | TArray (ivar "s-ro", TVar "t-ro")], 200 | TArray (ivar "s-jo", TVar "t-jo")))], 201 | scalar_of_elt inner_lam)) in 202 | let type_lam = 203 | RExpr (TLam (["t-li"; "t-lo"; "t-ri"; "t-ro"; "t-jo"], 204 | scalar_of_elt outer_lam)) in 205 | RExpr (ILam (["s-li", SShape; "s-lo", SShape; 206 | "s-ri", SShape; "s-ro", SShape; 207 | "s-jo", SShape], 208 | type_lam)) 209 | 210 | let define_compose = 211 | RDefn ("compose", 212 | (TDProd 213 | (["s1", SShape; "s2", SShape; "s3", SShape], 214 | TAll 215 | (["alpha"; "beta"; 
"gamma"], 216 | TArray (IShape [], 217 | TFun ([TArray 218 | (IShape [], 219 | TFun ([TArray (ivar "s1", TVar "alpha")], 220 | TArray (ivar "s2", TVar "beta"))); 221 | TArray 222 | (IShape [], 223 | TFun ([TArray (ivar "s2", TVar "beta")], 224 | TArray (ivar "s3", TVar "gamma")))], 225 | TArray 226 | (IShape [], 227 | TFun ([TArray (ivar "s1", TVar "alpha")], 228 | TArray (ivar "s3", TVar "gamma")))))))), 229 | remora_compose) 230 | 231 | let use_compose = 232 | RExpr (App (RExpr 233 | (App (RExpr 234 | (TApp (RExpr 235 | (IApp (RExpr (Var "compose"), 236 | [IShape []; 237 | IShape []; 238 | IShape []])), 239 | [TInt; TInt; TInt])), 240 | [scalar_of_elt unary_lambda; 241 | scalar_of_elt unary_lambda])), 242 | [scalar_of_elt_form (Int 0)])) 243 | 244 | let prog_compose = 245 | RProg ([define_compose], use_compose) 246 | 247 | let curried_add = 248 | let inner_app = RExpr (App (RExpr (Var "+"), 249 | [RExpr (Var "x"); RExpr (Var "y")])) in 250 | let inner_lambda = RElt (Lam (["y", TArray (IShape [], TInt)], 251 | inner_app)) in 252 | let outer_lambda = RElt (Lam (["x", TArray (IShape [], TInt)], 253 | scalar_of_elt inner_lambda)) in 254 | RExpr (Arr ([], [outer_lambda])) 255 | 256 | let define_curried_add = 257 | RDefn ("c+", 258 | (TArray (IShape [], 259 | TFun ([TArray (IShape [], TInt)], 260 | TArray (IShape [], 261 | TFun ([TArray (IShape [], TInt)], 262 | TArray (IShape [], TInt)))))), 263 | curried_add) 264 | 265 | let lift_curried_add = 266 | RExpr (App (RExpr (App (RExpr (Var "c+"), 267 | [RExpr (Arr ([2], [RElt (Int 10); RElt (Int 20)]))])), 268 | [RExpr (Arr ([2; 3], [RElt (Int 1); RElt (Int 2); 269 | RElt (Int 3); RElt (Int 4); 270 | RElt (Int 5); RElt (Int 6)]))])) 271 | 272 | let prog_curried_add = 273 | RProg ([define_curried_add], lift_curried_add) 274 | 275 | let expr_vec_add = 276 | RExpr 277 | (ILam (["l", SNat], 278 | RExpr 279 | (Arr ([], 280 | [RElt (Lam (["xs", TArray (IShape [IVar ("l", None)], TInt); 281 | "ys", TArray (IShape [IVar ("l", 
None)], TInt)], 282 | RExpr (App (RExpr (Var "+"), 283 | [RExpr (Var "xs"); 284 | RExpr (Var "ys")]))))])))) 285 | let defn_vec_add = 286 | RDefn ("v+", 287 | TDProd (["l", SNat], 288 | TArray (IShape [], 289 | TFun ([TArray (IShape [IVar ("l", None)], TInt); 290 | TArray (IShape [IVar ("l", None)], TInt)], 291 | TArray (IShape [IVar ("l", None)], TInt)))), 292 | expr_vec_add) 293 | let prog_vec_add = 294 | RProg ([defn_vec_add], 295 | (RExpr (App ((RExpr (IApp (RExpr (Var "v+"), 296 | [INat 3]))), 297 | [RExpr (Arr ([3], 298 | [RElt (Int 10); 299 | RElt (Int 20); 300 | RElt (Int 30)])); 301 | RExpr (Arr ([2; 3], 302 | [RElt (Int 1); RElt (Int 2); 303 | RElt (Int 3); RElt (Int 4); 304 | RElt (Int 5); RElt (Int 6)]))])))) 305 | 306 | let saxpy_let = 307 | RElt 308 | (Lam (["axy", TProd [TArray (IShape [], TFloat); 309 | TArray (IShape [], TFloat); 310 | TArray (IShape [], TFloat)]], 311 | (RExpr (LetTup (["a";"x";"y"], 312 | RExpr (Var "axy"), 313 | RExpr (App (RExpr (Var "+."), 314 | [RExpr (App (RExpr (Var "*."), 315 | [(RExpr (Var "a")); 316 | (RExpr (Var "x"))])); 317 | RExpr (Var "y")]))))))) 318 | let saxpy_field = 319 | RElt 320 | (Lam (["axy", TProd [TArray (IShape [], TFloat); 321 | TArray (IShape [], TFloat); 322 | TArray (IShape [], TFloat)]], 323 | (RExpr (App (RExpr (Var "+."), 324 | [RExpr (App (RExpr (Var "*."), 325 | [(RExpr (Field (0, RExpr (Var "axy")))); 326 | (RExpr (Field (1, RExpr (Var "axy"))))])); 327 | (RExpr (Field (2, RExpr (Var "axy"))))]))))) 328 | let saxpy_type = TFun ([TProd [TArray (IShape [], TFloat); 329 | TArray (IShape [], TFloat); 330 | TArray (IShape [], TFloat)]], 331 | TArray (IShape [], TFloat)) 332 | let saxpy_let_defn = RDefn ("saxpy_let", TArray (IShape [], saxpy_type), 333 | scalar_of_elt saxpy_let) 334 | let saxpy_field_defn = RDefn ("saxpy_field", TArray (IShape [], saxpy_type), 335 | scalar_of_elt saxpy_field) 336 | let saxpy_prog = 337 | RProg ([saxpy_let_defn; saxpy_field_defn], 338 | RExpr (App (RExpr (Arr ([2],
[RElt (Expr (RExpr (Var "saxpy_field"))); 339 | RElt (Expr (RExpr (Var "saxpy_let")))])), 340 | [RExpr 341 | (Tuple [scalar_of_elt_form (Float 2.3); 342 | scalar_of_elt_form (Float (-1.1)); 343 | scalar_of_elt_form (Float 0.2)])]))) 344 | 345 | (* For any rem_expr, adding blank annotations and then dropping annotations 346 | should lead back to the same rem_expr *) 347 | module UnitTests : sig 348 | val suite_init_drop : U.test 349 | end = struct 350 | (* Initializing and removing annotations should give the same thing back *) 351 | let test_expr_init_drop (x: rem_expr) (_: U.test_ctxt) = 352 | U.assert_equal x (x |> annot_expr_init ~init:() |> annot_expr_drop) 353 | let test_elt_init_drop (l: rem_elt) (_: U.test_ctxt) = 354 | U.assert_equal l (l |> annot_elt_init ~init:() |> annot_elt_drop) 355 | let test_defn_init_drop (d: rem_defn) (_: U.test_ctxt) = 356 | U.assert_equal d (d |> annot_defn_init ~init:() |> annot_defn_drop) 357 | let test_prog_init_drop (d: rem_prog) (_: U.test_ctxt) = 358 | U.assert_equal d (d |> annot_prog_init ~init:() |> annot_prog_drop) 359 | (* let flat_arr_2_3 = test_expr_init_drop flat_arr_2_3 *) 360 | (* let flat_arr_0_4 = test_expr_init_drop flat_arr_0_4 *) 361 | (* let arr_2 = test_expr_init_drop arr_2 *) 362 | (* let arr_wrong = test_expr_init_drop arr_wrong *) 363 | (* let nest_arr_2_3 = test_expr_init_drop nest_arr_2_3 *) 364 | (* let unary_lambda = test_elt_init_drop unary_lambda *) 365 | (* let binary_lambda = test_elt_init_drop binary_lambda *) 366 | (* let unary_app = test_expr_init_drop unary_app *) 367 | (* let binary_app = test_expr_init_drop binary_app *) 368 | (* let unary_to_nested_app = test_expr_init_drop unary_to_nested_app *) 369 | (* let nested_to_unary_app = test_expr_init_drop nested_to_unary_app *) 370 | (* let type_abst = test_expr_init_drop type_abst *) 371 | (* let type_abst_bad = test_expr_init_drop type_abst_bad *) 372 | (* let type_app = test_expr_init_drop type_app *) 373 | (* let index_abst = 
test_expr_init_drop index_abst *) 374 | (* let index_app = test_expr_init_drop index_app *) 375 | (* let dep_sum_create = test_expr_init_drop dep_sum_create *) 376 | (* let dep_sum_project = test_expr_init_drop dep_sum_project *) 377 | (* let remora_compose = test_expr_init_drop remora_compose *) 378 | (* let fork_compose = test_expr_init_drop fork_compose *) 379 | open OUnit2 380 | let suite_init_drop = 381 | "annotation init-drop">::: 382 | ["Flat 2x3">:: test_expr_init_drop flat_arr_2_3; 383 | "Flat 0x4">:: test_expr_init_drop flat_arr_0_4; 384 | "Flat 2">:: test_expr_init_drop arr_2; 385 | "Ill-formed array">:: test_expr_init_drop arr_wrong; 386 | "Nested 2x3">:: test_expr_init_drop nest_arr_2_3; 387 | "Unary lambda">:: test_elt_init_drop unary_lambda; 388 | "Binary lambda">:: test_elt_init_drop binary_lambda; 389 | "Apply unary to nested">:: test_expr_init_drop unary_to_nested_app; 390 | "Apply nested to unary">:: test_expr_init_drop nested_to_unary_app; 391 | "Type abstraction">:: test_expr_init_drop type_abst; 392 | "Ill-formed type abstraction">:: test_expr_init_drop type_abst_bad; 393 | "Type application">:: test_expr_init_drop type_app; 394 | "Index abstraction">:: test_expr_init_drop index_abst; 395 | "Index application">:: test_expr_init_drop index_app; 396 | "Intro dependent sum">:: test_expr_init_drop dep_sum_create; 397 | "Elim dependent sum">:: test_expr_init_drop dep_sum_project; 398 | "Straight line composition">:: test_expr_init_drop remora_compose; 399 | "Fork composition">:: test_expr_init_drop fork_compose; 400 | "Define compose in program">:: test_defn_init_drop define_compose; 401 | "Use compose in program">:: test_prog_init_drop prog_compose; 402 | "Curried addition">:: test_expr_init_drop curried_add; 403 | "Define curried addition in program">:: test_defn_init_drop 404 | define_curried_add; 405 | "Apply curried addition to overranked arguments">:: test_expr_init_drop 406 | lift_curried_add; 407 | "Use curried addition in program">:: 
test_prog_init_drop 408 | prog_curried_add; 409 | "Destruct a tuple by let-binding">:: test_elt_init_drop saxpy_let; 410 | "Destruct a tuple by field access">:: test_elt_init_drop saxpy_field; 411 | "Applying two versions of saxpy">:: test_prog_init_drop saxpy_prog] 412 | end 413 | -------------------------------------------------------------------------------- /src/remora-internal/test_basic_ast.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | val test_t_int : Basic_ast.typ 30 | val test_t_float : Basic_ast.typ 31 | val ta : Basic_ast.typ Basic_ast.ann_expr 32 | val scalar_of_elt : Basic_ast.rem_elt -> Basic_ast.rem_expr 33 | val scalar_of_expr : Basic_ast.rem_expr -> Basic_ast.rem_expr 34 | val scalar_of_elt_form : 35 | (Basic_ast.rem_elt, Basic_ast.rem_expr) Basic_ast.elt_form -> 36 | Basic_ast.rem_expr 37 | val scalar_of_expr_form : 38 | (Basic_ast.rem_expr, Basic_ast.rem_elt) Basic_ast.expr_form -> 39 | Basic_ast.rem_expr 40 | val flat_arr_2_3 : Basic_ast.rem_expr 41 | val flat_arr_0_4 : Basic_ast.rem_expr 42 | val arr_2 : Basic_ast.rem_expr 43 | val arr_wrong : Basic_ast.rem_expr 44 | val nest_arr_2_3 : Basic_ast.rem_expr 45 | val unary_lambda : Basic_ast.rem_elt 46 | val binary_lambda : Basic_ast.rem_elt 47 | val unary_app : Basic_ast.rem_expr 48 | val binary_app : Basic_ast.rem_expr 49 | val elt_of_expr_form : 50 | (Basic_ast.rem_expr, Basic_ast.rem_elt) Basic_ast.expr_form -> 51 | Basic_ast.rem_elt 52 | val unary_to_nested_app : Basic_ast.rem_expr 53 | val nested_to_unary_app : Basic_ast.rem_expr 54 | val type_abst : Basic_ast.rem_expr 55 | val type_abst_bad : Basic_ast.rem_expr 56 | val type_app : Basic_ast.rem_expr 57 | val index_abst : Basic_ast.rem_expr 58 | val index_app : Basic_ast.rem_expr 59 | val dep_sum_create 
: Basic_ast.rem_expr 60 | val dep_sum_project : Basic_ast.rem_expr 61 | val remora_compose : Basic_ast.rem_expr 62 | val fork_compose : Basic_ast.rem_expr 63 | val define_compose : Basic_ast.rem_defn 64 | val use_compose : Basic_ast.rem_expr 65 | val prog_compose : Basic_ast.rem_prog 66 | val curried_add : Basic_ast.rem_expr 67 | val define_curried_add : Basic_ast.rem_defn 68 | val lift_curried_add : Basic_ast.rem_expr 69 | val prog_curried_add : Basic_ast.rem_prog 70 | val prog_vec_add : Basic_ast.rem_prog 71 | val saxpy_let : Basic_ast.rem_elt 72 | val saxpy_field : Basic_ast.rem_elt 73 | val saxpy_prog : Basic_ast.rem_prog 74 | 75 | module UnitTests : sig 76 | val suite_init_drop : OUnit2.test 77 | end 78 | -------------------------------------------------------------------------------- /src/remora-internal/test_closures.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. 
*) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Closures 31 | module MR = Map_replicate_ast;; 32 | module B = Basic_ast;; 33 | module U = OUnit2;; 34 | open Frame_notes 35 | module E = Erased_ast;; 36 | 37 | let unary_lam = 38 | Expr (Cls 39 | {code = 40 | Expr (Lam {MR.bindings = ["__ENV_1"; "x"]; 41 | MR.body = Expr (LetTup 42 | {MR.vars = []; 43 | MR.bound = Expr 44 | (Var "__ENV_1"); 45 | MR.body = Expr 46 | (Vec {MR.dims = 1; 47 | MR.elts = 48 | [Expr (Int 3)]})})}); 49 | env = Expr (Tup [])}) 50 | 51 | 52 | let mr_wrap e = MR.AExpr ((E.TUnknown, NotArg, NotApp), e) 53 | let fn_wrap e = MR.AExpr ((E.TFun ([], E.TUnknown), NotArg, NotApp), e) 54 | 55 | let escaping_function = 56 | mr_wrap (MR.LetTup {MR.vars = ["f"]; 57 | MR.bound = mr_wrap 58 | (MR.LetTup 59 | {MR.vars = ["x"]; 60 | MR.bound = mr_wrap (MR.Tup [mr_wrap (MR.Int 5)]); 61 | MR.body = (fn_wrap 62 | (MR.Lam {bindings = ["l"]; 63 | MR.body = mr_wrap 64 | (MR.App {MR.fn = fn_wrap (MR.Var "+"); 65 | MR.args = [mr_wrap (MR.Var "l"); 66 | mr_wrap (MR.Var "x")]})}))}); 67 | 
MR.body = (mr_wrap 68 | (MR.App {MR.fn = fn_wrap (MR.Var "f"); 69 | MR.args = [mr_wrap (MR.Int 6)]}))}) 70 | let converted = 71 | Expr 72 | (LetTup {MR.vars = ["f"]; 73 | MR.bound = 74 | Expr 75 | (LetTup {MR.vars = ["x"]; 76 | MR.bound = Expr (Tup [Expr (Int 5)]); 77 | MR.body = 78 | Expr (Cls {code = 79 | Expr (Lam {MR.bindings = ["env";"l"]; 80 | MR.body = Expr 81 | (LetTup 82 | {MR.vars = ["x"]; 83 | MR.bound = Expr (Var "env"); 84 | MR.body = Expr (App {closure = Expr (Var "+"); 85 | args = [Expr (Var "l"); 86 | Expr (Var "x")]})})}); 87 | env = Expr (Tup [Expr (Var "x")])})}); 88 | MR.body = Expr (App {closure = Expr (Var "f"); 89 | args = [Expr (Int 6)]})}) 90 | 91 | let destr_dsum = 92 | mr_wrap 93 | (MR.LetTup {MR.vars = ["l";"n"]; 94 | MR.bound = mr_wrap 95 | (MR.App {MR.fn = fn_wrap (MR.Var "iota"); 96 | MR.args = [mr_wrap 97 | (MR.Vec {MR.dims = 1; 98 | MR.elts = [mr_wrap (MR.Int 5)]})]}); 99 | MR.body = mr_wrap 100 | (MR.Tup [mr_wrap (MR.Var "l"); 101 | mr_wrap (MR.Map {MR.frame = mr_wrap (MR.Var "l"); 102 | MR.fn = fn_wrap 103 | (MR.Lam {MR.bindings = ["x"]; 104 | MR.body = mr_wrap 105 | (MR.App {MR.fn = fn_wrap (MR.Var "+"); 106 | MR.args = []})}); 107 | MR.args = []; 108 | MR.shp = mr_wrap (MR.Var "l")})])}) 109 | 110 | let vec_scal_add = 111 | fn_wrap 112 | (MR.Lam {MR.bindings = ["l"]; 113 | MR.body = fn_wrap 114 | (MR.Lam {MR.bindings = ["x";"y"]; 115 | MR.body = mr_wrap 116 | (MR.Map {MR.frame = mr_wrap (MR.Var "l"); 117 | MR.fn = fn_wrap (MR.Var "+"); 118 | MR.args = 119 | [mr_wrap (MR.Var "x"); 120 | mr_wrap (MR.Rep {MR.arg = mr_wrap (MR.Var "y"); 121 | MR.old_frame = mr_wrap 122 | (MR.Vec {MR.dims = 0; MR.elts = []}); 123 | MR.new_frame = mr_wrap 124 | (MR.Vec {MR.dims = 1; 125 | MR.elts = [mr_wrap (MR.Var "l")]})})]; 126 | MR.shp = mr_wrap (MR.Var "l")})})}) 127 | 128 | let rec subst (s : (var, expr) List.Assoc.t) ((Expr e) as exp: expr) : expr = 129 | match e with 130 | | LetTup {MR.vars = vars; MR.bound = bound; MR.body = body} -> 131 | (* 
Shadow every substitution var that appears in vars. *) 132 | let new_s = 133 | List.fold_right ~init:s ~f:(fun l r -> List.Assoc.remove r l) vars in 134 | Expr (LetTup {MR.vars = vars; 135 | MR.bound = subst s bound; 136 | MR.body = subst new_s body}) 137 | | Lam {MR.bindings = vars; MR.body = body} -> 138 | let new_s = 139 | List.fold_right ~init:s ~f:(fun l r -> List.Assoc.remove r l) vars in 140 | Expr (Lam {MR.bindings = vars; 141 | MR.body = subst new_s body}) 142 | | Var v -> List.Assoc.find s v |> (Option.value ~default:exp) 143 | | _ -> Expr (map_expr_form ~f:(subst s) e) 144 | 145 | let rec alpha_eqv 146 | (Expr e1: expr) (Expr e2: expr) : bool = 147 | match (e1, e2) with 148 | | (App {closure = c1; args = a1}, App {closure = c2; args = a2}) -> 149 | (alpha_eqv c1 c2) && (List.length a1 = List.length a2) && 150 | (List.for_all2_exn ~f:alpha_eqv a1 a2) 151 | | (Vec {MR.dims = d1; MR.elts = l1}, Vec {MR.dims = d2; MR.elts = l2}) -> 152 | (d1 = d2) && (List.length l1 = List.length l2) && 153 | (List.for_all2_exn ~f:alpha_eqv l1 l2) 154 | | (Map {MR.frame = fr1; MR.fn = fn1; MR.args = a1; MR.shp = s1}, 155 | Map {MR.frame = fr2; MR.fn = fn2; MR.args = a2; MR.shp = s2}) -> 156 | (alpha_eqv fr1 fr2) && (alpha_eqv fn1 fn2) && 157 | (List.length a1 = List.length a2) && 158 | (List.for_all2_exn ~f:alpha_eqv a1 a2) && 159 | (alpha_eqv s1 s2) 160 | | (Rep {MR.arg = a1; MR.old_frame = o1; MR.new_frame = n1}, 161 | Rep {MR.arg = a2; MR.old_frame = o2; MR.new_frame = n2}) -> 162 | (alpha_eqv a1 a2) && (alpha_eqv o1 o2) && (alpha_eqv n1 n2) 163 | | (Tup l1, Tup l2) -> 164 | (* Tuples must have the same arity. *) 165 | (List.length l1 = List.length l2) && 166 | (* Corresponding elements must be alpha-equivalent. *) 167 | (List.for_all2_exn ~f:alpha_eqv l1 l2) 168 | | (LetTup {MR.vars = v1; MR.bound = bn1; MR.body = bd1}, 169 | LetTup {MR.vars = v2; MR.bound = bn2; MR.body = bd2}) -> 170 | (* Both lets must bind the same number of variables. 
*) 171 | (List.length v1 = List.length v2) && 172 | (* Both bound terms must be alpha equivalent *) 173 | (alpha_eqv bn1 bn2) && 174 | (* Changing the let-bound vars to fresh ones must give alpha-equivalent 175 | bodies. *) 176 | (let fresh_vars = List.map ~f:(fun _ -> Expr (Var (B.gensym "__="))) v1 in 177 | (alpha_eqv 178 | (subst (List.map2_exn ~f:Tuple2.create v1 fresh_vars) bd1) 179 | (subst (List.map2_exn ~f:Tuple2.create v2 fresh_vars) bd2))) 180 | | (Fld {MR.field = n1; MR.tuple = t1}, Fld {MR.field = n2; MR.tuple = t2}) -> 181 | n1 = n2 && (alpha_eqv t1 t2) 182 | | (Let {MR.var = v1; MR.bound = bn1; MR.body = bd1}, 183 | Let {MR.var = v2; MR.bound = bn2; MR.body = bd2}) -> 184 | (alpha_eqv bn1 bn2) && 185 | let fresh_var = Expr (Var (B.gensym "__=")) in 186 | (alpha_eqv (subst [v1, fresh_var] bd1) (subst [v2, fresh_var] bd2)) 187 | | (Cls {code = c1; env = n1}, Cls {code = c2; env = n2}) -> 188 | (alpha_eqv c1 c2) && (alpha_eqv n1 n2) 189 | | (Lam {MR.bindings = vars1; MR.body = body1}, 190 | Lam {MR.bindings = vars2; MR.body = body2}) -> 191 | (* Both lambdas must bind the same number of variables. *) 192 | (List.length vars1 = List.length vars2) && 193 | (* Changing those vars to fresh ones must give alpha-equivalent bodies. *) 194 | (let fresh_vars = List.map 195 | ~f:(fun _ -> Expr (Var (B.gensym "__="))) 196 | vars1 in 197 | alpha_eqv 198 | (subst (List.map2_exn ~f:Tuple2.create vars1 fresh_vars) body1) 199 | (subst (List.map2_exn ~f:Tuple2.create vars2 fresh_vars) body2)) 200 | | (Var v1, Var v2) -> v1 = v2 201 | | (Int v1, Int v2) -> v1 = v2 202 | | (Float v1, Float v2) -> v1 = v2 203 | | (Bool v1, Bool v2) -> v1 = v2 204 | | ((App _ | Vec _ | Map _ | Rep _ | Tup _ | LetTup _ | Fld _ | Let _ | 205 | Cls _ | Lam _ | Var _ | Int _ | Float _ | Bool _), _) -> false 206 | 207 | module Test_closure_conversion : sig 208 | val tests : U.test 209 | end = struct 210 | let test_1 _ = 211 | U.assert_bool "Non-equivalent result from closure conversion!" 
212 | (alpha_eqv converted 213 | (escaping_function |> expr_of_maprep ["+"] |> annot_expr_drop)) 214 | let tests = 215 | let open OUnit2 in 216 | "translate from Map/Rep AST into AST with explicit closures">::: 217 | ["lambda escaping with a let-bound var">:: test_1] 218 | end 219 | 220 | (* Make sure we can undo hoisting by substituting defn bodies back in. *) 221 | let expr_unhoist ((expr, defns): ('a ann_expr, 'a) Defn_writer.t) : expr = 222 | subst 223 | (List.map ~f:(fun (ADefn (name, value)) -> 224 | (name, annot_expr_drop value)) defns) 225 | (annot_expr_drop expr) 226 | let hoist_unhoist (e: 'a ann_expr) : bool = 227 | (e |> expr_hoist_lambdas |> expr_unhoist) = (e |> annot_expr_drop) 228 | 229 | (* Make sure we generate code with no inline lambdas. *) 230 | let rec lambda_free (AExpr (_, e): 'a ann_expr) : bool = 231 | match e with 232 | | App {closure = c; args = a} -> 233 | lambda_free c && List.for_all ~f:lambda_free a 234 | | Vec {MR.dims = _; MR.elts = elts} -> List.for_all ~f:lambda_free elts 235 | | Map {MR.frame = fr; MR.fn = fn; MR.args = a; MR.shp} -> 236 | List.for_all ~f:lambda_free (fr :: fn :: shp :: a) 237 | | Rep {MR.arg = a; MR.old_frame = o; MR.new_frame = n} -> 238 | lambda_free a && lambda_free o && lambda_free n 239 | | Tup elts -> List.for_all ~f:lambda_free elts 240 | | LetTup {MR.vars = _; MR.bound = bn; MR.body = bd} -> 241 | lambda_free bn && lambda_free bd 242 | | Fld {MR.field = _; MR.tuple = tup} -> lambda_free tup 243 | | Let {MR.var = _; MR.bound = bn; MR.body = bd} -> 244 | lambda_free bn && lambda_free bd 245 | | Lam _ -> false 246 | | Cls {code = c; env = n} -> lambda_free c && lambda_free n 247 | | Var _ | Int _ | Float _ | Bool _ -> true 248 | (* It's ok to define a function, but make sure no lambdas appear in the 249 | function body. 
*) 250 | let defn_lambda_free (ADefn (_, (AExpr (_, e) as expr)): 'a ann_defn) : bool = 251 | match e with 252 | | Lam {MR.bindings = _; MR.body = b} -> lambda_free b 253 | | _ -> lambda_free expr 254 | let writer_lambda_free ((expr, defns): ('a ann_expr, 'a) Defn_writer.t) 255 | : bool = 256 | List.for_all ~f:defn_lambda_free defns && lambda_free expr 257 | 258 | module Test_lambda_hoisting : sig 259 | val tests : U.test 260 | end = struct 261 | let test_hoist_unhoist e = 262 | e |> hoist_unhoist |> U.assert_bool "Hoist-unhoist match check" 263 | let test_lambda_free expr = 264 | expr |> expr_hoist_lambdas |> writer_lambda_free |> 265 | U.assert_bool "Generate lambda-free code" 266 | 267 | (* TODO: More test cases (code samples) *) 268 | let test_1 _ = 269 | escaping_function |> expr_of_maprep ["+"] |> test_hoist_unhoist 270 | let test_2 _ = 271 | escaping_function |> expr_of_maprep ["+"] |> test_lambda_free 272 | let test_3 _ = 273 | destr_dsum |> expr_of_maprep ["+"; "iota"] |> test_hoist_unhoist 274 | let test_4 _ = 275 | destr_dsum |> expr_of_maprep ["+"; "iota"] |> test_lambda_free 276 | (* Unhoisting does not work with lambda-producing lambdas (and making it 277 | work would risk infinitely unrolling a recursive call), so eliding this 278 | test. 
*) 279 | (* let test_5 _ = *) 280 | (* vec_scal_add |> expr_of_maprep ["+"] |> test_hoist_unhoist *) 281 | let test_6 _ = 282 | vec_scal_add |> expr_of_maprep ["+"] |> test_lambda_free 283 | let tests = 284 | let open OUnit2 in 285 | "hoist lambdas out to top-level definitions">::: 286 | ["lambda escaping with a let-bound var (hoist-unhoist)">:: test_1; 287 | "lambda escaping with a let-bound var (lambda-free)">:: test_2; 288 | "destructing a dependent sum (hoist-unhoist)">:: test_3; 289 | "destructing a dependent sum (lambda-free)">:: test_4; 290 | "adding vector and scalar (lambda-free)">:: test_6] 291 | end 292 | 293 | module UnitTests : sig 294 | val tests : U.test 295 | end = struct 296 | let tests = 297 | let open OUnit2 in 298 | "Explicit-closure AST tests">::: 299 | [Test_closure_conversion.tests; 300 | Test_lambda_hoisting.tests] 301 | end 302 | -------------------------------------------------------------------------------- /src/remora-internal/test_closures.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission.
*) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Closures 30 | open Core.Std 31 | module MR = Map_replicate_ast;; 32 | module B = Basic_ast;; 33 | module U = OUnit2;; 34 | open Frame_notes;; 35 | 36 | val unary_lam : expr 37 | 38 | val mr_wrap : (E.typ * arg_frame * app_frame) MR.ann_expr MR.expr_form 39 | -> (E.typ * arg_frame * app_frame) MR.ann_expr 40 | 41 | val escaping_function : (E.typ * arg_frame * app_frame) MR.ann_expr 42 | 43 | val converted : expr 44 | 45 | val destr_dsum : (E.typ * arg_frame * app_frame) MR.ann_expr 46 | val vec_scal_add : (E.typ * arg_frame * app_frame) MR.ann_expr 47 | 48 | val subst : (var, expr) List.Assoc.t -> expr -> expr 49 | 50 | val alpha_eqv : expr -> expr -> bool 51 | 52 | val expr_unhoist : ('a ann_expr, 'a) Defn_writer.t -> expr 53 | val hoist_unhoist : 'a ann_expr -> bool 54 | 55 | val lambda_free : 'a ann_expr -> bool 56 | val defn_lambda_free : 'a ann_defn -> bool 57 | val writer_lambda_free : ('a ann_expr, 'a) Defn_writer.t -> bool 58 | 59 | module Test_closure_conversion : sig 60 | val tests : U.test 61 | end 
62 | 63 | module Test_lambda_hoisting : sig 64 | val tests : U.test 65 | end 66 | 67 | module UnitTests : sig 68 | val tests : U.test 69 | end 70 | -------------------------------------------------------------------------------- /src/remora-internal/test_erased_ast.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. *) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | module E = Erased_ast;; 31 | open E 32 | module TB = Test_basic_ast;; 33 | module B = Basic_ast;; 34 | module U = OUnit2;; 35 | 36 | let scalar_of_elt e = EExpr (Arr ([], [e])) 37 | let scalar_of_elt_form e = scalar_of_elt (EElt e) 38 | 39 | let flat_arr_2_3 = 40 | EExpr (Arr ([2; 3], 41 | [EElt (Int 4); EElt (Int 1); EElt (Int 6); 42 | EElt (Int 2); EElt (Int 3); EElt (Int 5)])) 43 | 44 | let nest_arr_2_3 = 45 | EExpr 46 | (Arr ([2], 47 | [EElt (Expr 48 | (EExpr (Arr ([3], 49 | [EElt (Int 4); EElt (Int 1); EElt (Int 6)])))); 50 | EElt (Expr 51 | (EExpr (Arr ([3], 52 | [EElt (Int 2); EElt (Int 3); EElt (Int 5)]))))])) 53 | 54 | let unary_lambda = EElt (Lam (["x"], (EExpr (Arr ([], [EElt (Int 3)]))))) 55 | 56 | let binary_lambda = EElt (Lam (["x"; "y"], 57 | EExpr (Arr ([], [EElt (Int 3)])))) 58 | 59 | let unary_app = EExpr (App (EExpr (Arr ([], [unary_lambda])), 60 | [EExpr (Arr ([], [EElt (Int 4)]))], 61 | TUnknown)) 62 | 63 | let binary_app = EExpr (App (EExpr (Arr ([], [binary_lambda])), 64 | [EExpr (Arr ([3], [EElt (Float 1.0); 65 | EElt (Float 2.0); 66 | EElt (Float 3.0)])); 67 | EExpr (Arr ([1], [EElt (Bool false)]))], 68 | TUnknown)) 69 | 70 | let type_abst = EExpr (Arr ([], [EElt (Lam (["x"], EExpr (Var "x")))])) 71 | 72 | let index_abst = EExpr (ILam (["d", 
B.SNat], 73 | EExpr (Arr ([], 74 | [EElt (Lam (["l"], 75 | EExpr (Var "l")))])))) 76 | 77 | let index_app = EExpr (IApp (index_abst, 78 | [B.INat 6])) 79 | 80 | let dep_sum_create = 81 | EExpr (Pack ([B.INat 3], 82 | EExpr (Arr ([3], 83 | [EElt (Int 0); EElt (Int 1); EElt (Int 2)])))) 84 | 85 | let dep_sum_project = 86 | EExpr (Unpack (["l"], "c", 87 | dep_sum_create, 88 | EExpr (Arr ([], [EElt (Int 0)])))) 89 | 90 | let remora_compose = 91 | EExpr 92 | (ILam 93 | (["s1", B.SShape; "s2", B.SShape; "s3", B.SShape], 94 | scalar_of_elt_form 95 | (Lam (["f"; "g"], 96 | scalar_of_elt_form 97 | (Lam (["x"], 98 | EExpr 99 | (App (EExpr (Var "g"), 100 | [EExpr (App (EExpr (Var "f"), 101 | [EExpr (Var "x")], 102 | TUnknown))], 103 | TUnknown)))))))) 104 | 105 | let define_compose = 106 | EDefn ("compose", 107 | TDProd 108 | (["s1", B.SShape; "s2", B.SShape; "s3", B.SShape], 109 | TArray (B.IShape [], 110 | TFun 111 | ([(TArray (B.IShape [], 112 | TFun ([TArray (B.ivar "s1", TVar)], 113 | TArray (B.ivar "s2", TVar)))); 114 | (TArray (B.IShape [], 115 | TFun ([TArray (B.ivar "s2", TVar)], 116 | TArray (B.ivar "s3", TVar))))], 117 | TArray (B.IShape [], 118 | TFun ([TArray (B.ivar "s1", TVar)], 119 | TArray (B.ivar "s3", TVar)))))), 120 | remora_compose) 121 | 122 | let use_compose = 123 | EExpr 124 | (App (EExpr (App (EExpr (IApp (EExpr (Var "compose"), 125 | [B.IShape []; B.IShape []; B.IShape []])), 126 | [scalar_of_elt unary_lambda; 127 | scalar_of_elt unary_lambda], 128 | TUnknown)), 129 | [scalar_of_elt_form (Int 0)], 130 | TUnknown)) 131 | 132 | let prog_compose = EProg ([define_compose], use_compose) 133 | 134 | module Test_erasure : sig 135 | val tests : U.test 136 | end = struct 137 | let assert_expr_erasure pre post _ = 138 | U.assert_equal post 139 | (pre |> B.annot_expr_init ~init:() 140 | |> E.of_ann_expr ~merge:const 141 | |> E.annot_expr_drop) 142 | let assert_elt_erasure pre post _ = 143 | U.assert_equal post 144 | (pre |> B.annot_elt_init ~init:() 145 | |> 
E.of_ann_elt ~merge:const 146 | |> E.annot_elt_drop) 147 | let assert_defn_erasure pre post _ = 148 | U.assert_equal post 149 | (pre |> B.annot_defn_init ~init:() 150 | |> E.of_ann_defn ~merge:const 151 | |> E.annot_defn_drop) 152 | let assert_prog_erasure pre post _ = 153 | U.assert_equal post 154 | (pre |> B.annot_prog_init ~init:() 155 | |> E.of_ann_prog ~merge:const 156 | |> E.annot_prog_drop) 157 | let tests = 158 | let open OUnit2 in 159 | "Generate type annotations, type-erase, drop annotations">::: 160 | ["flat 2x3">:: assert_expr_erasure TB.flat_arr_2_3 flat_arr_2_3; 161 | "nested 2x3">:: assert_expr_erasure TB.nest_arr_2_3 nest_arr_2_3; 162 | "unary lambda">:: assert_elt_erasure TB.unary_lambda unary_lambda; 163 | "binary lambda">:: assert_elt_erasure TB.binary_lambda binary_lambda; 164 | "unary app">:: assert_expr_erasure TB.unary_app unary_app; 165 | "binary app">:: assert_expr_erasure TB.binary_app binary_app; 166 | (* The TLam and TApp examples should both erase to the same AST *) 167 | "type abstraction">:: assert_expr_erasure TB.type_abst type_abst; 168 | "type application">:: assert_expr_erasure TB.type_app type_abst; 169 | "index abstraction">:: assert_expr_erasure TB.index_abst index_abst; 170 | "index application">:: assert_expr_erasure TB.index_app index_app; 171 | "construct dependent sum">:: 172 | assert_expr_erasure TB.dep_sum_create dep_sum_create; 173 | "destruct dependent sum">:: 174 | assert_expr_erasure TB.dep_sum_project dep_sum_project; 175 | "composition">:: 176 | assert_expr_erasure TB.remora_compose remora_compose; 177 | "defining composition">:: 178 | assert_defn_erasure TB.define_compose define_compose; 179 | "program with composition">:: 180 | assert_prog_erasure TB.prog_compose prog_compose] 181 | end 182 | 183 | module UnitTests : sig 184 | val tests : U.test 185 | end = struct 186 | let tests = 187 | let open OUnit2 in 188 | "erasure tests">::: 189 | [Test_erasure.tests] 190 | end 191 | 
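The erasure tests in test_erased_ast.ml expect both `TB.type_abst` and `TB.type_app` to erase to the same `type_abst` AST. The sketch below illustrates why, using a hypothetical toy AST (its constructors are illustrative assumptions, not the project's `Basic_ast` or `Erased_ast` types): erasure drops binder type annotations and replaces type abstraction and type application nodes with the term underneath, so a polymorphic function and its instantiation become indistinguishable.

```ocaml
(* Hypothetical toy types for illustration only; they are not the
   project's Basic_ast or Erased_ast definitions. *)
type typ = TInt | TVar of string

(* A typed term: lambda calculus plus type abstraction and application. *)
type typed =
  | Var of string
  | Lam of string * typ * typed     (* binder with a type annotation *)
  | App of typed * typed
  | TLam of string list * typed     (* type abstraction *)
  | TApp of typed * typ list        (* type application *)

(* The erased form keeps only run-time structure. *)
type erased =
  | EVar of string
  | ELam of string * erased
  | EApp of erased * erased

let rec erase (t : typed) : erased =
  match t with
  | Var x -> EVar x
  | Lam (x, _, body) -> ELam (x, erase body)
  | App (f, a) -> EApp (erase f, erase a)
  (* Type abstraction and application have no run-time effect,
     so only the underlying term survives. *)
  | TLam (_, body) -> erase body
  | TApp (e, _) -> erase e

(* Polymorphic identity and its instantiation at TInt. *)
let id_poly = TLam (["t"], Lam ("x", TVar "t", Var "x"))
let id_inst = TApp (id_poly, [TInt])

let () =
  assert (erase id_poly = erase id_inst);
  print_endline "type abstraction and application erase identically"
```

The same reasoning is why the erased `type_abst` in the tests is simply the scalar array holding `Lam (["x"], EExpr (Var "x"))`.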
-------------------------------------------------------------------------------- /src/remora-internal/test_frame_notes.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright 2015 NVIDIA Corporation. All rights reserved. *) 3 | (* *) 4 | (* NOTICE TO USER: The source code, and related code and software *) 5 | (* ("Code"), is copyrighted under U.S. and international laws. *) 6 | (* *) 7 | (* NVIDIA Corporation owns the copyright and any patents issued or *) 8 | (* pending for the Code. *) 9 | (* *) 10 | (* NVIDIA CORPORATION MAKES NO REPRESENTATION ABOUT THE SUITABILITY *) 11 | (* OF THIS CODE FOR ANY PURPOSE. IT IS PROVIDED "AS-IS" WITHOUT EXPRESS *) 12 | (* OR IMPLIED WARRANTY OF ANY KIND. NVIDIA CORPORATION DISCLAIMS ALL *) 13 | (* WARRANTIES WITH REGARD TO THE CODE, INCLUDING NON-INFRINGEMENT, AND *) 14 | (* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 15 | (* PURPOSE. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY *) 16 | (* DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES *) 17 | (* WHATSOEVER ARISING OUT OF OR IN ANY WAY RELATED TO THE USE OR *) 18 | (* PERFORMANCE OF THE CODE, INCLUDING, BUT NOT LIMITED TO, INFRINGEMENT, *) 19 | (* LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, *) 20 | (* NEGLIGENCE OR OTHER TORTIOUS ACTION, AND WHETHER OR NOT THE *) 21 | (* POSSIBILITY OF SUCH DAMAGES WERE KNOWN OR MADE KNOWN TO NVIDIA *) 22 | (* CORPORATION. 
*) 23 | (******************************************************************************) 24 | 25 | open Basic_ast 26 | open Typechecker 27 | open Frame_notes 28 | open Core.Std 29 | open Core.Option 30 | open Core.Option.Monad_infix 31 | module U = OUnit2;; 32 | 33 | module Test_annot_expr_app_frame : sig 34 | val tests: U.test 35 | end = struct 36 | open Test_basic_ast 37 | let test_1 _ = 38 | U.assert_equal 39 | ((nested_to_unary_app 40 | |> Typechecker.Passes.expr_all) >>= fun typed_ast -> 41 | (typed_ast 42 | |> annot_expr_app_frame 43 | |> annot_of_expr 44 | |> return)) 45 | (Some (AppFrame [IShape [INat 3]; IShape [INat 2]])) 46 | let tests = 47 | let open OUnit2 in 48 | "add application frame shape annotation">::: 49 | ["apply array of functions to scalar">:: test_1] 50 | end 51 | 52 | (* Some pieces shared by multiple test modules later *) 53 | let typed_lifted_curried_addition = 54 | Option.value_exn 55 | (Test_basic_ast.lift_curried_add 56 | |> annot_expr_init ~init:() 57 | |> annot_expr_type [] [] 58 | ["c+", 59 | TArray (IShape [], 60 | TFun ([TArray (IShape [], TInt)], 61 | TArray (IShape [], 62 | TFun ([TArray (IShape [], TInt)], 63 | TArray (IShape [], TInt)))))] 64 | |> well_typed_of_expr) 65 | let vec_2 = AnnRExpr (ArgFrame {frame = [IShape [INat 2]]; 66 | expansion = []}, 67 | Arr ([2], [AnnRElt (NotArg, Int 10); 68 | AnnRElt (NotArg, Int 20)])) 69 | let mat_2_3 = 70 | AnnRExpr (ArgFrame {frame = [IShape [INat 2]; IShape [INat 3]]; 71 | expansion = []}, 72 | Arr ([2; 3], 73 | [AnnRElt (NotArg, Int 1); AnnRElt (NotArg, Int 2); 74 | AnnRElt (NotArg, Int 3); AnnRElt (NotArg, Int 4); 75 | AnnRElt (NotArg, Int 5); AnnRElt (NotArg, Int 6)])) 76 | let annotated_expr = 77 | AnnRExpr (NotArg, 78 | App (AnnRExpr (ArgFrame 79 | {frame = [IShape [INat 2]]; 80 | expansion = [IShape [INat 3]]}, 81 | App (AnnRExpr 82 | (ArgFrame {frame = [IShape []]; 83 | expansion = [IShape [INat 2]]}, 84 | Var "c+"), 85 | [vec_2])), 86 | [mat_2_3])) 87 | 88 | module 
Test_annot_expr_arg_expansion : sig 89 | val tests: U.test 90 | end = struct 91 | open Test_basic_ast 92 | let expr_arg_notes base target = 93 | let app = annot_expr_app_frame base in 94 | let app_typ_ = (Annotation.annot_expr_merge Tuple2.create app base) in 95 | let app_typ = match app_typ_ with 96 | | Some x -> x 97 | | None -> U.assert_failure 98 | "+----\n+ could not merge app/typ\n+-----\n" in 99 | let arg = annot_expr_arg_expansion 100 | ~outer_frame:NotApp 101 | ~outer_expectation:None 102 | app_typ in 103 | U.assert_equal arg target 104 | let test_1 _ = 105 | expr_arg_notes typed_lifted_curried_addition annotated_expr 106 | let tests = 107 | let open OUnit2 in 108 | "add argument frame shape annotations to expr">::: 109 | ["lifting curried addition">:: test_1] 110 | end 111 | 112 | module Test_annot_prog_arg_frame : sig 113 | val tests: U.test 114 | end = struct 115 | open Test_basic_ast 116 | open Typechecker 117 | let test_1 _ = 118 | let typed_curried_addition = 119 | Option.value_exn 120 | (prog_curried_add 121 | |> (annot_prog_init ~init:()) 122 | |> (annot_prog_type [] [] 123 | ["+", TArray (IShape [], 124 | TFun ([TArray (IShape [], TInt); 125 | TArray (IShape [], TInt)], 126 | TArray (IShape [], TInt)))]) 127 | |> well_typed_of_prog) in 128 | let app_expr = AnnRExpr (NotArg, 129 | App (AnnRExpr (ArgFrame 130 | {frame = [IShape []]; 131 | expansion = []}, 132 | Var "+"), 133 | [AnnRExpr (ArgFrame 134 | {frame = []; 135 | expansion = []}, 136 | Var "x"); 137 | AnnRExpr (ArgFrame 138 | {frame = []; 139 | expansion = []}, 140 | Var "y")])) in 141 | let inner_fun = 142 | AnnRExpr (NotArg, 143 | Arr ([], [AnnRElt (NotArg, 144 | Lam ([("y", TArray (IShape [], TInt))], 145 | app_expr))])) in 146 | let annotated_curried_add = 147 | AnnRExpr (NotArg, Arr ([], 148 | [AnnRElt 149 | (NotArg, 150 | Lam 151 | ([("x", TArray (IShape [], TInt))], 152 | inner_fun))])) in 153 | let curried_add_type = 154 | TArray (IShape [], 155 | TFun ([TArray (IShape [], TInt)], 
156 | TArray (IShape [], 157 | TFun ([TArray (IShape [], TInt)], 158 | TArray (IShape [], TInt))))) in 159 | let annotated_defn = 160 | AnnRDefn ("c+", curried_add_type, annotated_curried_add) in 161 | let test_case = (annot_prog_arg_expansion 162 | (Option.value_exn 163 | (Annotation.annot_prog_merge 164 | Tuple2.create 165 | (annot_prog_app_frame typed_curried_addition) 166 | typed_curried_addition))) 167 | and target = (AnnRProg 168 | (NotArg, 169 | [annotated_defn], 170 | annotated_expr)) in 171 | U.assert_equal test_case target 172 | let tests = 173 | let open OUnit2 in 174 | "add argument frame shape annotations to program">::: 175 | ["lifting curried addition">:: test_1] 176 | end 177 | 178 | module UnitTests : sig 179 | val tests: U.test 180 | end = struct 181 | let tests = 182 | let open OUnit2 in 183 | "frame notes tests">::: 184 | [Test_annot_expr_app_frame.tests; 185 | Test_annot_expr_arg_expansion.tests; 186 | Test_annot_prog_arg_frame.tests] 187 | end 188 | -------------------------------------------------------------------------------- /src/remora-internal/test_map_replicate_ast.ml: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright 2015 NVIDIA Corporation. All rights reserved. *) 3 | (* *) 4 | (* NOTICE TO USER: The source code, and related code and software *) 5 | (* ("Code"), is copyrighted under U.S. and international laws. *) 6 | (* *) 7 | (* NVIDIA Corporation owns the copyright and any patents issued or *) 8 | (* pending for the Code. *) 9 | (* *) 10 | (* NVIDIA CORPORATION MAKES NO REPRESENTATION ABOUT THE SUITABILITY *) 11 | (* OF THIS CODE FOR ANY PURPOSE. IT IS PROVIDED "AS-IS" WITHOUT EXPRESS *) 12 | (* OR IMPLIED WARRANTY OF ANY KIND. 
NVIDIA CORPORATION DISCLAIMS ALL *) 13 | (* WARRANTIES WITH REGARD TO THE CODE, INCLUDING NON-INFRINGEMENT, AND *) 14 | (* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 15 | (* PURPOSE. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY *) 16 | (* DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES *) 17 | (* WHATSOEVER ARISING OUT OF OR IN ANY WAY RELATED TO THE USE OR *) 18 | (* PERFORMANCE OF THE CODE, INCLUDING, BUT NOT LIMITED TO, INFRINGEMENT, *) 19 | (* LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, *) 20 | (* NEGLIGENCE OR OTHER TORTIOUS ACTION, AND WHETHER OR NOT THE *) 21 | (* POSSIBILITY OF SUCH DAMAGES WERE KNOWN OR MADE KNOWN TO NVIDIA *) 22 | (* CORPORATION. *) 23 | (******************************************************************************) 24 | open Core.Std 25 | open Core.Option 26 | open Option.Monad_infix 27 | open Map_replicate_ast 28 | module T = Typechecker;; 29 | module U = OUnit2;; 30 | 31 | let subst_expr_form 32 | (recur: 'value T.env -> 'subexp -> 'subexp) 33 | (subst: 'value T.env) 34 | (ef: 'subexp expr_form) : 'subexp expr_form = 35 | match ef with 36 | | Var name -> List.Assoc.find subst name |> Option.value ~default:(Var name) 37 | | Lam {bindings = bindings; body = body} -> 38 | Lam {bindings = bindings; 39 | body = (recur 40 | (List.fold ~init:subst 41 | ~f:(List.Assoc.remove ~equal:(=)) bindings) body)} 42 | | LetTup {vars = vars; bound = bound; body = body} -> 43 | LetTup {vars = vars; 44 | bound = (recur subst bound); 45 | body = (recur (List.fold ~init:subst 46 | ~f:(List.Assoc.remove ~equal:(=)) vars) body)} 47 | | _ -> map_expr_form ~f:(recur subst) ef 48 | 49 | let rec split_list (increment: int) (xs: 'a list) : 'a list list = 50 | match List.take xs increment with 51 | | [] -> [] 52 | | first_block -> 53 | first_block :: split_list increment (List.drop xs increment) 54 | 55 | let rec repeat_list (count: int) (xs: 'a list) : 'a list = 56 | if count <= 0 57 | 
then [] 58 | else List.append xs (repeat_list (count - 1) xs) 59 | 60 | let known_shape (Expr e) : int = 61 | match e with 62 | | Vec {dims = dims; elts = _} -> dims 63 | (* Built-ins. If these are shadowed, evaluator may misbehave. *) 64 | | Var ("+" | "*" | "+."| "*.") -> 1 65 | | _ -> -1 66 | 67 | let is_Int (Expr e) = (match e with | Int _ -> true | _ -> false) 68 | 69 | let to_int (Expr e) = (match e with | Int i -> i | _ -> assert false) 70 | 71 | let value_as_vector (Expr e) = 72 | match e with 73 | | Tup _ | Lam _ | Int _ | Float _ | Bool _ -> 74 | Expr (Vec {dims = 1; elts = [Expr e]}) 75 | | _ -> Expr e 76 | 77 | (* TODO: for scalar frame, arg cells should not get array-wrapped *) 78 | let cell_split ~(frame: int) ((Expr array): expr) : expr list option = 79 | match array with 80 | | Vec {dims = dims; elts = elts} -> 81 | (* let cell_shape = List.drop dims (List.length frame) in *) 82 | (* let cell_size = List.fold_right ~f:( * ) ~init:1 cell_shape in *) 83 | let cell_size = dims / frame in 84 | let cell_elts = split_list cell_size elts in 85 | if cell_size = 1 86 | then List.map ~f:List.hd cell_elts |> Option.all 87 | else Some (List.map 88 | ~f:(fun c -> Expr (Vec {dims = cell_size; elts = c})) 89 | cell_elts) 90 | | _ -> None 91 | 92 | exception Stuck_eval;; 93 | let apply_primop (opname: var) (args: expr list) : expr = 94 | match opname with 95 | | "append" -> 96 | (* Make sure we have args in vector form. 
*) 97 | (try (let vec_args = 98 | List.map ~f:(fun (Expr v) -> match v with 99 | | Vec _ -> Expr v 100 | (* Coerce non-vector "value forms" to singleton vectors *) 101 | | Tup _ | Lam _ | Int _ | Float _ | Bool _ -> 102 | Expr (Vec {dims = 1; elts = [Expr v]}) 103 | | _ -> raise Stuck_eval) args in 104 | let arg_dims = 105 | List.map ~f:(fun (Expr v) -> match v with 106 | | (Vec {dims = d; elts = _}) -> d 107 | | _ -> raise Stuck_eval) vec_args in 108 | let arg_elts = 109 | List.map ~f:(fun (Expr v) -> match v with 110 | | (Vec {dims = _; elts = e}) -> e 111 | | _ -> raise Stuck_eval) vec_args in 112 | Expr (Vec {dims = List.fold ~init:0 ~f:(+) arg_dims; (* appending concatenates elts, so lengths add *) 113 | elts = List.join arg_elts})) with 114 | (* Some arg wasn't reduced to a Vec form. *) 115 | | _ -> Expr (App {fn = Expr (Var opname); args = args})) 116 | | "+" -> (match args with 117 | | [Expr (Int n1); Expr (Int n2)] -> Expr (Int (n1 + n2)) 118 | | _ -> Expr (App {fn = Expr (Var opname); args = args})) 119 | | "*" -> (match args with 120 | | [Expr (Int n1); Expr (Int n2)] -> Expr (Int (n1 * n2)) 121 | | _ -> Expr (App {fn = Expr (Var opname); args = args})) 122 | | "+." -> (match args with 123 | | [Expr (Float n1); Expr (Float n2)] -> Expr (Float (n1 +. n2)) 124 | | _ -> Expr (App {fn = Expr (Var opname); args = args})) 125 | | "*." -> (match args with 126 | | [Expr (Float n1); Expr (Float n2)] -> Expr (Float (n1 *. n2)) 127 | | _ -> Expr (App {fn = Expr (Var opname); args = args})) 128 | | _ -> Expr (App {fn = Expr (Var opname); args = args}) 129 | 130 | (* Expression evaluator, to allow more flexible tests.
*) 131 | let rec eval_expr ~(env: expr T.env) (Expr e) : expr = 132 | match e with 133 | | App {fn = fn; args = args} -> 134 | let fn_val = eval_expr ~env:env fn 135 | and args_val = List.map ~f:(eval_expr ~env:env) args in 136 | (match fn_val with 137 | | Expr (Lam {bindings = bindings; body = body}) -> 138 | eval_expr ~env:(List.append (List.zip_exn bindings args_val) env) body 139 | (* Special recognition for a few primitive ops *) 140 | | Expr (Var op) -> 141 | (* TODO: apply_primop may return the same thing we gave it, so don't 142 | blindly recur on what it returns. *) 143 | apply_primop op args_val 144 | (* Unrecognized/un-evaluated operator *) 145 | | _ -> Expr (App {fn = fn_val; args = args_val})) 146 | | Vec {dims = dims; elts = elts} -> 147 | let elts_val = List.map ~f:(eval_expr ~env:env) elts in 148 | (* If the evaluated elts all match in shape, collapse one nest level. *) 149 | (match elts_val with 150 | | (Expr (Vec el)) :: els -> 151 | if (List.for_all 152 | ~f:(fun (Expr v) -> match v with 153 | | (Vec i) -> i.dims = el.dims | _ -> false) 154 | els) 155 | then let (joined_dims, joined_elts) = 156 | (dims * el.dims, 157 | List.join (List.map 158 | ~f:(fun (Expr v) -> match v with 159 | (* They should all be Vec if this is reached *) 160 | | (Vec i) -> i.elts | _ -> assert false) 161 | elts_val)) 162 | in eval_expr ~env:env (Expr (Vec {dims = joined_dims; 163 | elts = joined_elts})) 164 | else Expr (Vec {dims = dims; elts = elts_val}) 165 | | _ -> Expr (Vec {dims = dims; elts = elts_val})) (* in *) 166 | (* Expr (Vec {dims = joined_dims; elts = joined_elts}) *) 167 | | Map {frame = frame; fn = fn; args = args; shp = shp} -> 168 | let frame_val = eval_expr ~env:env frame 169 | and fn_val = eval_expr ~env:env fn 170 | and args_val = List.map ~f:(eval_expr ~env:env) args 171 | and shp_val = eval_expr ~env:env shp |> value_as_vector in 172 | let eval_stuck = Expr (Map {frame = frame_val; 173 | fn = fn_val; 174 | args = args_val; 175 | shp = shp_val}) in 
176 | let args_axes = List.map ~f:known_shape args_val in 177 | (match (fn_val, frame_val, shp_val) with 178 | (* We need fn_val to be a Lam and frame and shp to be valid shapes. *) 179 | | ((Expr (Lam {bindings = _; body = _}) as fn_lam | 180 | Expr (Vec {dims = 1; 181 | elts = [Expr (Lam {bindings = _; body = _}) as fn_lam]})), 182 | Expr (Int frame_size), 183 | (* Expr (Vec {dims = [_]; elts = frame_elts}), *) 184 | Expr (Vec {dims = _; elts = shp_elts})) -> 185 | (* let frame_axes = List.map ~f:to_int frame_elts in *) 186 | (* We need args to be arrays big enough to split into cells. *) 187 | if (List.for_all ~f:(fun s -> frame_size <= s) args_axes) 188 | then 189 | ((Option.all 190 | (List.map ~f:(cell_split ~frame:frame_size) args_val) 191 | >>= fun args_cells -> 192 | List.transpose args_cells >>= fun transp_cells -> 193 | let apps = List.map 194 | ~f:(fun cells -> Expr (App {fn = fn_lam; args = cells})) 195 | transp_cells in 196 | return 197 | (if List.length transp_cells = 0 198 | (* No result cells, so use the declared result shape. *) 199 | then eval_expr ~env:env 200 | (Expr (Vec {dims = List.fold ~init:1 ~f:( * ) 201 | (List.map ~f:to_int shp_elts); 202 | elts = []})) 203 | (* Evaluate the vector of result cells. *) 204 | else eval_expr ~env:env 205 | (Expr (Vec {dims = frame_size; elts = apps})))) 206 | |> Option.value ~default:eval_stuck) 207 | else eval_stuck 208 | | _ -> eval_stuck) 209 | | Rep {arg = arg; old_frame = old_frame; new_frame = new_frame} -> 210 | let arg_val = eval_expr ~env:env arg 211 | and old_frame_val = eval_expr ~env:env old_frame 212 | and new_frame_val = eval_expr ~env:env new_frame in 213 | let eval_stuck = Expr (Rep {arg = arg_val; 214 | old_frame = old_frame_val; 215 | new_frame = new_frame_val}) in 216 | (* Make sure we have evaluated everything far enough that we know the 217 | entire old_frame and new_frame. 
*) 218 | (match (arg_val, old_frame_val, new_frame_val) with 219 | | (Expr (Vec {dims = arg_val_dims; elts = arg_val_elts}), 220 | Expr (Int old_frame), 221 | Expr (Int new_frame)) -> 222 | (* Make sure that the old_frame is a prefix of the new_frame, and 223 | that the old_frame is a prefix of arg's shape. *) 224 | if (arg_val_dims >= old_frame) 225 | (* Copy each cell of the argument. *) 226 | (* 1. identify (visible portion of) cell shape *) 227 | (* 2. split arg_val_elts into cells *) 228 | then 229 | let expansion_size = (new_frame / old_frame) 230 | and cell_size = (arg_val_dims / old_frame) in 231 | (* let cell_size = List.fold_right ~init:1 ~f:( * ) cell_shape in *) 232 | let cells = split_list cell_size arg_val_elts in 233 | let more_cells: expr list = 234 | List.join (List.map ~f:(repeat_list expansion_size) cells) in 235 | let more_dims = new_frame * cell_size in 236 | Expr (Vec {dims = more_dims; elts = more_cells}) 237 | else eval_stuck 238 | (* Or if it's a non-array value (tuple) with scalar original frame and 239 | known target frame, copy it as needed. 
*) 240 | | (Expr (Tup elts), 241 | Expr (Int 1), Expr (Int n)) -> 242 | Expr (Vec {dims = n; elts = List.init n ~f:(fun _ -> arg_val)}) 243 | | _ -> eval_stuck 244 | ) 245 | | Tup elts -> Expr( Tup (List.map ~f:(eval_expr ~env:env) elts)) 246 | | Lam {bindings = _; body = _} as lam -> Expr lam 247 | | Fld {field = n; tuple = tup} -> 248 | let tup_val = eval_expr ~env:env tup in 249 | (match tup_val with 250 | | Expr Tup elts when List.length elts > n -> 251 | List.nth_exn elts n 252 | | _ -> tup_val) 253 | | LetTup {vars = vars; bound = bound; body = body} -> 254 | let bound_val = eval_expr ~env:env bound in 255 | (match bound_val with 256 | | Expr Tup elts -> 257 | eval_expr ~env:(List.append (List.zip_exn vars elts) env) body 258 | | _ -> Expr (LetTup {vars = vars; bound = bound_val; body = body})) 259 | | Let {var = var; bound = bound; body = body} -> 260 | let bound_val = eval_expr ~env:env bound in 261 | eval_expr ~env:((var, bound_val) :: env) body 262 | | Var name -> 263 | List.Assoc.find env name |> Option.value ~default:(Expr (Var name)) 264 | | Bool _ | Float _ | Int _ as c -> Expr c 265 | ;; 266 | 267 | let flat_arr_2_3 = 268 | Expr (Vec {dims = 6; elts = [Expr (Int 4); Expr (Int 1); Expr (Int 6); 269 | Expr (Int 2); Expr (Int 3); Expr (Int 5)]}) 270 | let arr_2 = Expr (Vec {dims = 2; 271 | elts = [Expr (Bool false); Expr (Bool true)]}) 272 | let unary_lambda = Expr (Lam {bindings = ["x"]; 273 | body = Expr (Vec {dims = 1; 274 | elts = [Expr (Int 3)]})}) 275 | let binary_lambda = Expr (Lam {bindings = ["x"; "y"]; 276 | body = Expr (Vec {dims = 1; 277 | elts = [Expr (Int 3)]})}) 278 | let unary_app = Expr (Vec {dims = 1; elts = [Expr (Int 3)]}) 279 | let unary_to_nested_app = 280 | Expr (Vec {dims = 6; elts = [Expr (Int 1); Expr (Int 2); Expr (Int 3); 281 | Expr (Int 4); Expr (Int 5); Expr (Int 6)]}) 282 | let nested_to_unary_app = 283 | Expr (Vec {dims = 6; elts = [Expr (Int 1); Expr (Int 2); Expr (Int 3); 284 | Expr (Int 4); Expr (Int 5); Expr (Int 
6)]}) 285 | let type_abst = Expr (Vec {dims = 1; 286 | elts = [Expr (Lam {bindings = ["x"]; 287 | body = Expr (Var "x")})]}) 288 | let index_abst = 289 | Expr (Lam {bindings = [idx_name_mangle "d" (Some B.SNat)]; 290 | body = Expr (Vec {dims = 1; 291 | elts = [Expr (Lam {bindings = ["l"]; 292 | body = Expr (Var "l")})]})}) 293 | let index_app = 294 | Expr (Vec {dims = 1; 295 | elts = [Expr (Lam {bindings = ["l"]; 296 | body = Expr (Var "l")})]}) 297 | let dep_sum_create = 298 | Expr (Tup [Expr (Vec {dims = 3; 299 | elts = [Expr (Int 0); Expr (Int 1); Expr (Int 2)]}); 300 | Expr (Int 3)]) 301 | let dep_sum_project = 302 | Expr (Vec {dims = 1; elts = [Expr (Int 0)]}) 303 | 304 | module Test_translation : sig 305 | val tests : U.test 306 | end = struct 307 | let assert_translate_expr original final = 308 | U.assert_equal 309 | (original 310 | |> Passes.expr_all 311 | >>| annot_expr_drop 312 | >>| (eval_expr ~env:[])) 313 | (Some final) 314 | let assert_translate_elt original final = 315 | U.assert_equal 316 | (original 317 | |> Passes.elt_all 318 | >>| annot_expr_drop 319 | >>| (eval_expr ~env:[])) 320 | (Some final) 321 | module TB = Test_basic_ast;; 322 | let test_1 _ = 323 | assert_translate_expr TB.flat_arr_2_3 flat_arr_2_3 324 | let test_2 _ = 325 | assert_translate_expr TB.arr_2 arr_2 326 | let test_3 _ = 327 | assert_translate_expr TB.nest_arr_2_3 flat_arr_2_3 328 | let test_4 _ = 329 | assert_translate_elt TB.unary_lambda unary_lambda 330 | let test_5 _ = 331 | assert_translate_elt TB.binary_lambda binary_lambda 332 | let test_6 _ = 333 | assert_translate_expr TB.unary_app unary_app 334 | let test_7 _ = 335 | assert_translate_expr TB.binary_app unary_app 336 | let test_8 _ = 337 | assert_translate_expr TB.unary_to_nested_app unary_to_nested_app 338 | let test_9 _ = 339 | assert_translate_expr TB.nested_to_unary_app nested_to_unary_app 340 | let test_10 _ = 341 | assert_translate_expr TB.type_abst type_abst 342 | let test_11 _ = 343 | assert_translate_expr 
TB.type_app type_abst 344 | let test_12 _ = 345 | assert_translate_expr TB.index_abst index_abst 346 | let test_13 _ = 347 | assert_translate_expr TB.index_app index_app 348 | let test_14 _ = 349 | assert_translate_expr TB.dep_sum_create dep_sum_create 350 | let test_15 _ = 351 | assert_translate_expr TB.dep_sum_project dep_sum_project 352 | let tests = 353 | let open OUnit2 in 354 | "translate from basic AST to Map/Replicate & evaluate">::: 355 | ["flat 2x3">:: test_1; 356 | "2-vector">:: test_2; 357 | "collapse nested 2x3 to flat">:: test_3; 358 | "unary lambda">:: test_4; 359 | "binary lambda">:: test_5; 360 | "unary app producing 3-vector">:: test_6; 361 | "binary app producing 3-vector">:: test_7; 362 | "unary-to-nested produces flat 2x3">:: test_8; 363 | "nested-to-unary produces flat 3x2">:: test_9; 364 | "type abstraction erased">:: test_10; 365 | "type application erased">:: test_11; 366 | "index abstraction becomes term abstraction">:: test_12; 367 | "index application becomes term application">:: test_13; 368 | "dependent sum becomes tuple">:: test_14; 369 | "destruct tuple by let-binding">:: test_15] 370 | end 371 | 372 | module UnitTests : sig 373 | val tests : U.test 374 | end = struct 375 | let tests = 376 | let open OUnit2 in 377 | "Map-Replicate AST tests">::: 378 | [Test_translation.tests] 379 | end 380 | -------------------------------------------------------------------------------- /src/remora-internal/typechecker.mli: -------------------------------------------------------------------------------- 1 | (******************************************************************************) 2 | (* Copyright (c) 2015, NVIDIA CORPORATION. All rights reserved. 
*) 3 | (* *) 4 | (* Redistribution and use in source and binary forms, with or without *) 5 | (* modification, are permitted provided that the following conditions *) 6 | (* are met: *) 7 | (* * Redistributions of source code must retain the above copyright *) 8 | (* notice, this list of conditions and the following disclaimer. *) 9 | (* * Redistributions in binary form must reproduce the above copyright *) 10 | (* notice, this list of conditions and the following disclaimer in the *) 11 | (* documentation and/or other materials provided with the distribution. *) 12 | (* * Neither the name of NVIDIA CORPORATION nor the names of its *) 13 | (* contributors may be used to endorse or promote products derived *) 14 | (* from this software without specific prior written permission. *) 15 | (* *) 16 | (* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY *) 17 | (* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE *) 18 | (* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR *) 19 | (* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR *) 20 | (* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, *) 21 | (* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, *) 22 | (* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR *) 23 | (* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY *) 24 | (* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *) 25 | (* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE *) 26 | (* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*) 27 | (******************************************************************************) 28 | 29 | open Core.Std 30 | open Basic_ast 31 | 32 | type 'a env = (var, 'a) Core.Std.List.Assoc.t with sexp 33 | type kind = unit with sexp 34 | 35 | val env_update : 'a env -> 'a env -> 'a env 36 | 37 | val srt_of_idx : srt env -> idx -> srt option 38 | 39 | val kind_of_typ : srt env -> kind env -> typ -> kind option 40 | 41 | val uniq_typ : typ list -> typ option 42 | 43 | val expand_shape : idx -> idx list option 44 | 45 | val shape_drop : idx list -> idx list -> idx list option 46 | 47 | val shape_of_typ : typ -> idx list option 48 | val elt_of_typ : typ -> typ option 49 | 50 | val idx_equal : idx -> idx -> bool 51 | 52 | val canonicalize_typ : typ -> typ option 53 | 54 | val prefix_of : 'a list -> 'a list -> bool option 55 | 56 | val typ_equal : typ -> typ -> bool 57 | 58 | val frame_contribution : typ -> typ -> idx list option 59 | 60 | val typ_of_shape : typ -> idx list -> typ 61 | val canonical_typ_of_shape : typ -> idx list -> typ option 62 | 63 | val annot_elt_type : srt env -> kind env -> typ env -> 'a ann_elt 64 | -> typ option ann_elt 65 | val annot_expr_type : srt env -> kind env -> typ env -> 'a ann_expr 66 | -> typ option ann_expr 67 | val annot_defn_type : srt env -> kind env -> typ env -> 'a ann_defn 68 | -> typ option ann_defn 69 | val annot_prog_type : srt env -> kind env -> typ env -> 'a ann_prog 70 | -> pt_prog 71 | 72 | val well_typed_of_expr : typ option ann_expr -> typ ann_expr option 73 | val well_typed_of_elt : typ option ann_elt -> typ ann_elt option 74 | val well_typed_of_defn: typ option ann_defn -> typ ann_defn option 75 | val well_typed_of_prog : typ option ann_prog -> typ ann_prog option 76 | 77 | module Passes : sig 78 | val prog : 'a ann_prog -> typ ann_prog option 79 | val defn : 'a ann_defn -> typ ann_defn option 80 | val expr : 'a ann_expr -> typ ann_expr option 81 | val elt : 'a ann_elt -> typ ann_elt option 82 | 83 | val prog_all : 
rem_prog -> typ ann_prog option 84 | val defn_all : rem_defn -> typ ann_defn option 85 | val expr_all : rem_expr -> typ ann_expr option 86 | val elt_all : rem_elt -> typ ann_elt option 87 | end 88 | --------------------------------------------------------------------------------
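typechecker.mli exports shape utilities such as `shape_drop`, `prefix_of` and `frame_contribution` over symbolic `idx` shapes, but their exact behavior is not documented in the interface. As a rough illustration of the Remora frame rule they support, here is a hypothetical sketch on concrete `int list` shapes (all names below are assumptions, not the module's API): an argument's frame is its shape minus the expected cell shape at the end, the principal frame is the longest frame, and every other frame must be a prefix of it. The final check mirrors the expectation in test_frame_notes.ml, where the 2-vector of partially applied `c+` functions is annotated with frame `[IShape [INat 2]]` and expansion `[IShape [INat 3]]` against the 2x3 matrix.

```ocaml
(* Hypothetical helpers on concrete shapes (int lists). They illustrate the
   Remora frame rule only; they are not the typechecker's idx-based
   shape_drop / prefix_of / frame_contribution. *)

(* [split_at n l] splits [l] into its first [n] elements and the rest. *)
let rec split_at n l =
  if n <= 0 then ([], l)
  else
    match l with
    | [] -> ([], [])
    | x :: rest ->
      let (front, back) = split_at (n - 1) rest in
      (x :: front, back)

(* [is_prefix p l] holds when [p] is a prefix of [l]. *)
let rec is_prefix p l =
  match p, l with
  | [], _ -> true
  | _, [] -> false
  | x :: p', y :: l' -> x = y && is_prefix p' l'

(* An argument's frame is its shape with the expected cell shape removed
   from the end; None if the cell shape is not a suffix of the shape. *)
let frame_of ~(arg_shape : int list) ~(cell_shape : int list) =
  let n = List.length arg_shape - List.length cell_shape in
  if n < 0 then None
  else
    let (frame, suffix) = split_at n arg_shape in
    if suffix = cell_shape then Some frame else None

(* The principal frame is the longest frame; every frame must prefix it. *)
let principal_frame (frames : int list list) =
  let longest =
    List.fold_left
      (fun acc f -> if List.length f > List.length acc then f else acc)
      [] frames
  in
  if List.for_all (fun f -> is_prefix f longest) frames
  then Some longest
  else None

(* Axes along which an argument's cells must be replicated to fill the
   principal frame. *)
let expansion ~(frame : int list) ~(principal : int list) =
  if is_prefix frame principal
  then Some (snd (split_at (List.length frame) principal))
  else None

let () =
  (* A 2-vector and a 2x3 matrix, both consumed as scalar cells. *)
  match frame_of ~arg_shape:[2] ~cell_shape:[],
        frame_of ~arg_shape:[2; 3] ~cell_shape:[] with
  | Some fv, Some fm ->
    assert (principal_frame [fv; fm] = Some [2; 3]);
    assert (expansion ~frame:fv ~principal:[2; 3] = Some [3])
  | _ -> assert false
```

Returning `option` rather than raising follows the convention visible in typechecker.mli, where shape functions such as `shape_drop` and `prefix_of` also return options.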