├── .gitignore ├── CHANGES.md ├── LICENSE.md ├── README.md ├── dune-project ├── examples ├── dune ├── linalg.ml ├── ma.ml └── readme_example.ml ├── lib ├── dune ├── mask.ml ├── mask.mli ├── misc.ml ├── misc.mli ├── multidim_array.ml ├── multidim_array.mli ├── nat.ml ├── nat.mli ├── nat_defs.ml ├── nat_defs.mli ├── range.ml ├── range.mli ├── shape.ml ├── shape.mli ├── signatures.ml ├── signatures.mli ├── small_matrix.ml ├── small_matrix.mli ├── small_tensor.ml ├── small_tensor.mli ├── small_unified.ml ├── small_unified.mli ├── small_vec.ml ├── small_vec.mli ├── stencil.ml ├── stencil.mli ├── stride.ml ├── stride.mli ├── tensor.ml └── tensor.mli ├── ppx ├── dune └── ppx_tensority.ml ├── tensority.opam └── tensority_ppx.opam /.gitignore: -------------------------------------------------------------------------------- 1 | _build/* 2 | *~ 3 | *.merlin 4 | -------------------------------------------------------------------------------- /CHANGES.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Octachron/tensority/6d1399957f5377813e08584c8b93b35e61b8eb46/CHANGES.md -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The Tensority library is distributed under the terms of the 2 | GNU Lesser General Public License (LGPL) version 2.1 (included below). 3 | 4 | As a special exception to the GNU Lesser General Public License, you 5 | may link, statically or dynamically, a "work that uses the OCaml Core 6 | System" with a publicly distributed version of the OCaml Core System 7 | to produce an executable file containing portions of the OCaml Core 8 | System, and distribute that executable file under terms of your 9 | choice, without any of the additional requirements listed in clause 6 10 | of the GNU Lesser General Public License. 
By "a publicly distributed 11 | version of the OCaml Core System", we mean either the unmodified OCaml 12 | Core System as distributed by INRIA, or a modified version of the 13 | OCaml Core System that is distributed under the conditions defined in 14 | clause 2 of the GNU Lesser General Public License. This exception 15 | does not however invalidate any other reasons why the executable file 16 | might be covered by the GNU Lesser General Public License. 17 | 18 | ---------------------------------------------------------------------- 19 | 20 | GNU LESSER GENERAL PUBLIC LICENSE 21 | 22 | Version 2.1, February 1999 23 | 24 | Copyright (C) 1991, 1999 Free Software Foundation, Inc. 25 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 26 | Everyone is permitted to copy and distribute verbatim copies 27 | of this license document, but changing it is not allowed. 28 | 29 | [This is the first released version of the Lesser GPL. It also counts 30 | as the successor of the GNU Library Public License, version 2, hence 31 | the version number 2.1.] 32 | 33 | Preamble 34 | 35 | The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public Licenses are intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. 36 | 37 | This license, the Lesser General Public License, applies to some specially designated software packages--typically libraries--of the Free Software Foundation and other authors who decide to use it. You can use it too, but we suggest you first think carefully about whether this license or the ordinary General Public License is the better strategy to use in any particular case, based on the explanations below. 38 | 39 | When we speak of free software, we are referring to freedom of use, not price. 
Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish); that you receive source code or can get it if you want it; that you can change the software and use pieces of it in new free programs; and that you are informed that you can do these things. 40 | 41 | To protect your rights, we need to make restrictions that forbid distributors to deny you these rights or to ask you to surrender these rights. These restrictions translate to certain responsibilities for you if you distribute copies of the library or if you modify it. 42 | 43 | For example, if you distribute copies of the library, whether gratis or for a fee, you must give the recipients all the rights that we gave you. You must make sure that they, too, receive or can get the source code. If you link other code with the library, you must provide complete object files to the recipients, so that they can relink them with the library after making changes to the library and recompiling it. And you must show them these terms so they know their rights. 44 | 45 | We protect your rights with a two-step method: (1) we copyright the library, and (2) we offer you this license, which gives you legal permission to copy, distribute and/or modify the library. 46 | 47 | To protect each distributor, we want to make it very clear that there is no warranty for the free library. Also, if the library is modified by someone else and passed on, the recipients should know that what they have is not the original version, so that the original author's reputation will not be affected by problems that might be introduced by others. 48 | 49 | Finally, software patents pose a constant threat to the existence of any free program. We wish to make sure that a company cannot effectively restrict the users of a free program by obtaining a restrictive license from a patent holder. 
Therefore, we insist that any patent license obtained for a version of the library must be consistent with the full freedom of use specified in this license. 50 | 51 | Most GNU software, including some libraries, is covered by the ordinary GNU General Public License. This license, the GNU Lesser General Public License, applies to certain designated libraries, and is quite different from the ordinary General Public License. We use this license for certain libraries in order to permit linking those libraries into non-free programs. 52 | 53 | When a program is linked with a library, whether statically or using a shared library, the combination of the two is legally speaking a combined work, a derivative of the original library. The ordinary General Public License therefore permits such linking only if the entire combination fits its criteria of freedom. The Lesser General Public License permits more lax criteria for linking other code with the library. 54 | 55 | We call this license the "Lesser" General Public License because it does Less to protect the user's freedom than the ordinary General Public License. It also provides other free software developers Less of an advantage over competing non-free programs. These disadvantages are the reason we use the ordinary General Public License for many libraries. However, the Lesser license provides advantages in certain special circumstances. 56 | 57 | For example, on rare occasions, there may be a special need to encourage the widest possible use of a certain library, so that it becomes a de-facto standard. To achieve this, non-free programs must be allowed to use the library. A more frequent case is that a free library does the same job as widely used non-free libraries. In this case, there is little to gain by limiting the free library to free software only, so we use the Lesser General Public License. 
58 | 59 | In other cases, permission to use a particular library in non-free programs enables a greater number of people to use a large body of free software. For example, permission to use the GNU C Library in non-free programs enables many more people to use the whole GNU operating system, as well as its variant, the GNU/Linux operating system. 60 | 61 | Although the Lesser General Public License is Less protective of the users' freedom, it does ensure that the user of a program that is linked with the Library has the freedom and the wherewithal to run that program using a modified version of the Library. 62 | 63 | The precise terms and conditions for copying, distribution and modification follow. Pay close attention to the difference between a "work based on the library" and a "work that uses the library". The former contains code derived from the library, whereas the latter must be combined with the library in order to run. 64 | 65 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 66 | 67 | 0. This License Agreement applies to any software library or other program which contains a notice placed by the copyright holder or other authorized party saying it may be distributed under the terms of this Lesser General Public License (also called "this License"). Each licensee is addressed as "you". 68 | 69 | A "library" means a collection of software functions and/or data prepared so as to be conveniently linked with application programs (which use some of those functions and data) to form executables. 70 | 71 | The "Library", below, refers to any such software library or work which has been distributed under these terms. A "work based on the Library" means either the Library or any derivative work under copyright law: that is to say, a work containing the Library or a portion of it, either verbatim or with modifications and/or translated straightforwardly into another language. 
(Hereinafter, translation is included without limitation in the term "modification".) 72 | 73 | "Source code" for a work means the preferred form of the work for making modifications to it. For a library, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the library. 74 | 75 | Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running a program using the Library is not restricted, and output from such a program is covered only if its contents constitute a work based on the Library (independent of the use of the Library in a tool for writing it). Whether that is true depends on what the Library does and what the program that uses the Library does. 76 | 77 | 1. You may copy and distribute verbatim copies of the Library's complete source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and distribute a copy of this License along with the Library. 78 | 79 | You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. 80 | 81 | 2. You may modify your copy or copies of the Library or any portion of it, thus forming a work based on the Library, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: 82 | 83 | a) The modified work must itself be a software library. 84 | b) You must cause the files modified to carry prominent notices stating that you changed the files and the date of any change. 
85 | c) You must cause the whole of the work to be licensed at no charge to all third parties under the terms of this License. 86 | d) If a facility in the modified Library refers to a function or a table of data to be supplied by an application program that uses the facility, other than as an argument passed when the facility is invoked, then you must make a good faith effort to ensure that, in the event an application does not supply such function or table, the facility still operates, and performs whatever part of its purpose remains meaningful. 87 | 88 | (For example, a function in a library to compute square roots has a purpose that is entirely well-defined independent of the application. Therefore, Subsection 2d requires that any application-supplied function or table used by this function must be optional: if the application does not supply it, the square root function must still compute square roots.) 89 | 90 | These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Library, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Library, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. 91 | 92 | Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Library. 
93 | 94 | In addition, mere aggregation of another work not based on the Library with the Library (or with a work based on the Library) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. 95 | 96 | 3. You may opt to apply the terms of the ordinary GNU General Public License instead of this License to a given copy of the Library. To do this, you must alter all the notices that refer to this License, so that they refer to the ordinary GNU General Public License, version 2, instead of to this License. (If a newer version than version 2 of the ordinary GNU General Public License has appeared, then you can specify that version instead if you wish.) Do not make any other change in these notices. 97 | 98 | Once this change is made in a given copy, it is irreversible for that copy, so the ordinary GNU General Public License applies to all subsequent copies and derivative works made from that copy. 99 | 100 | This option is useful when you wish to copy part of the code of the Library into a program that is not a library. 101 | 102 | 4. You may copy and distribute the Library (or a portion or derivative of it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange. 103 | 104 | If distribution of object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place satisfies the requirement to distribute the source code, even though third parties are not compelled to copy the source along with the object code. 105 | 106 | 5. 
A program that contains no derivative of any portion of the Library, but is designed to work with the Library by being compiled or linked with it, is called a "work that uses the Library". Such a work, in isolation, is not a derivative work of the Library, and therefore falls outside the scope of this License. 107 | 108 | However, linking a "work that uses the Library" with the Library creates an executable that is a derivative of the Library (because it contains portions of the Library), rather than a "work that uses the library". The executable is therefore covered by this License. Section 6 states terms for distribution of such executables. 109 | 110 | When a "work that uses the Library" uses material from a header file that is part of the Library, the object code for the work may be a derivative work of the Library even though the source code is not. Whether this is true is especially significant if the work can be linked without the Library, or if the work is itself a library. The threshold for this to be true is not precisely defined by law. 111 | 112 | If such an object file uses only numerical parameters, data structure layouts and accessors, and small macros and small inline functions (ten lines or less in length), then the use of the object file is unrestricted, regardless of whether it is legally a derivative work. (Executables containing this object code plus portions of the Library will still fall under Section 6.) 113 | 114 | Otherwise, if the work is a derivative of the Library, you may distribute the object code for the work under the terms of Section 6. Any executables containing that work also fall under Section 6, whether or not they are linked directly with the Library itself. 115 | 116 | 6. 
As an exception to the Sections above, you may also combine or link a "work that uses the Library" with the Library to produce a work containing portions of the Library, and distribute that work under terms of your choice, provided that the terms permit modification of the work for the customer's own use and reverse engineering for debugging such modifications. 117 | 118 | You must give prominent notice with each copy of the work that the Library is used in it and that the Library and its use are covered by this License. You must supply a copy of this License. If the work during execution displays copyright notices, you must include the copyright notice for the Library among them, as well as a reference directing the user to the copy of this License. Also, you must do one of these things: 119 | 120 | a) Accompany the work with the complete corresponding machine-readable source code for the Library including whatever changes were used in the work (which must be distributed under Sections 1 and 2 above); and, if the work is an executable linked with the Library, with the complete machine-readable "work that uses the Library", as object code and/or source code, so that the user can modify the Library and then relink to produce a modified executable containing the modified Library. (It is understood that the user who changes the contents of definitions files in the Library will not necessarily be able to recompile the application to use the modified definitions.) 121 | b) Use a suitable shared library mechanism for linking with the Library. A suitable mechanism is one that (1) uses at run time a copy of the library already present on the user's computer system, rather than copying library functions into the executable, and (2) will operate properly with a modified version of the library, if the user installs one, as long as the modified version is interface-compatible with the version that the work was made with. 
122 | c) Accompany the work with a written offer, valid for at least three years, to give the same user the materials specified in Subsection 6a, above, for a charge no more than the cost of performing this distribution. 123 | d) If distribution of the work is made by offering access to copy from a designated place, offer equivalent access to copy the above specified materials from the same place. 124 | e) Verify that the user has already received a copy of these materials or that you have already sent this user a copy. 125 | 126 | For an executable, the required form of the "work that uses the Library" must include any data and utility programs needed for reproducing the executable from it. However, as a special exception, the materials to be distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. 127 | 128 | It may happen that this requirement contradicts the license restrictions of other proprietary libraries that do not normally accompany the operating system. Such a contradiction means you cannot use both them and the Library together in an executable that you distribute. 129 | 130 | 7. You may place library facilities that are a work based on the Library side-by-side in a single library together with other library facilities not covered by this License, and distribute such a combined library, provided that the separate distribution of the work based on the Library and of the other library facilities is otherwise permitted, and provided that you do these two things: 131 | 132 | a) Accompany the combined library with a copy of the same work based on the Library, uncombined with any other library facilities. This must be distributed under the terms of the Sections above. 
133 | b) Give prominent notice with the combined library of the fact that part of it is a work based on the Library, and explaining where to find the accompanying uncombined form of the same work. 134 | 135 | 8. You may not copy, modify, sublicense, link with, or distribute the Library except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense, link with, or distribute the Library is void, and will automatically terminate your rights under this License. However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. 136 | 137 | 9. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Library or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Library (or any work based on the Library), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Library or works based on it. 138 | 139 | 10. Each time you redistribute the Library (or any work based on the Library), the recipient automatically receives a license from the original licensor to copy, distribute, link with or modify the Library subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties with this License. 140 | 141 | 11. If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. 
If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Library at all. For example, if a patent license would not permit royalty-free redistribution of the Library by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Library. 142 | 143 | If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply, and the section as a whole is intended to apply in other circumstances. 144 | 145 | It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. 146 | 147 | This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. 148 | 149 | 12. If the distribution and/or use of the Library is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Library under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. 150 | 151 | 13. 
The Free Software Foundation may publish revised and/or new versions of the Lesser General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. 152 | 153 | Each version is given a distinguishing version number. If the Library specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Library does not specify a license version number, you may choose any version ever published by the Free Software Foundation. 154 | 155 | 14. If you wish to incorporate parts of the Library into other free programs whose distribution conditions are incompatible with these, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. 156 | 157 | NO WARRANTY 158 | 159 | 15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 160 | 161 | 16. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 162 | END OF TERMS AND CONDITIONS 163 | 164 | How to Apply These Terms to Your New Libraries 165 | 166 | If you develop a new library, and you want it to be of the greatest possible use to the public, we recommend making it free software that everyone can redistribute and change. You can do so by permitting redistribution under these terms (or, alternatively, under the terms of the ordinary General Public License). 167 | 168 | To apply these terms, attach the following notices to the library. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. 169 | 170 | one line to give the library's name and an idea of what it does. 171 | Copyright (C) year name of author 172 | 173 | This library is free software; you can redistribute it and/or 174 | modify it under the terms of the GNU Lesser General Public 175 | License as published by the Free Software Foundation; either 176 | version 2.1 of the License, or (at your option) any later version. 177 | 178 | This library is distributed in the hope that it will be useful, 179 | but WITHOUT ANY WARRANTY; without even the implied warranty of 180 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 181 | Lesser General Public License for more details. 
182 | 183 | You should have received a copy of the GNU Lesser General Public 184 | License along with this library; if not, write to the Free Software 185 | Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 186 | 187 | Also add information on how to contact you by electronic and paper mail. 188 | 189 | You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the library, if necessary. Here is a sample; alter the names: 190 | 191 | Yoyodyne, Inc., hereby disclaims all copyright interest in 192 | the library `Frob' (a library for tweaking knobs) written 193 | by James Random Hacker. 194 | 195 | signature of Ty Coon, 1 April 1990 196 | Ty Coon, President of Vice 197 | 198 | That's all there is to it! 199 | 200 | -------------------------------------------------- 201 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Beware: This is a work in progress on a highly experimental prototype. 2 | Do not expect backward compatibility, performance, or correctness. 3 | 4 | 5 | Tensority is an experiment in designing a library for strongly-typed 6 | multidimensional array and tensor manipulation.
Tensority aims to 7 | cover three levels of compile-time safety: 8 | 9 | * tensor order level: 10 | * the type of multidimensional arrays should distinguish between 11 | arrays of different dimensions 12 | * the type of tensors should moreover distinguish between a vector and a 1-form, 13 | i.e. between `(1 + 0)` tensors and `(0 + 1)` tensors 14 | 15 | * dimension level: 16 | * adding two vectors of different dimensions should be a type error 17 | 18 | * index level: 19 | * trying to access the `(k+1)`th element of an array of size `k` should be 20 | an error 21 | 22 | Each level of safety restricts the number of functions implementable using 23 | Tensority's safe interface. However, non-trivial functions can still be 24 | implemented using this safe interface. 25 | 26 | 27 | ## Examples 28 | 29 | With the included ppx extension, Tensority code looks like this: 30 | 31 | ```OCaml 32 | open Tensority 33 | open Multidim_array 34 | open Shape 35 | 36 | (* Multidimensional array literals *) 37 | let array4 = [%array 4 38 | [ 39 | [1, 2; 3, 4], [5, 6; 7, 8], [9, 10; 11, 12]; 40 | [1, 2; 8, 9], [9, 1; 2, 3], [19, 12; 14, 17] 41 | ] 42 | ];; 43 | 44 | (* Definition using function *) 45 | let array = init_sh [502103s] (function [k] -> Nat.to_int k) 46 | 47 | (* accessing an element *) 48 | let one = array4.( 1i, 0i, 0i, 0i ) ;; 49 | 50 | (* accessing an element with an out-of-bound type error *) 51 | let one = array4.( 1i, 5i, 0i, 1i ) ;; 52 | 53 | (* element assignment *) 54 | array4.(0i,0i,0i,0i) <- 0;; 55 | 56 | (* slicing *) 57 | let matrix = array4.!( All, All, 1j, 1j );; 58 | (* matrix is [9, 1; 2, 3] *) 59 | 60 | (* slice assignment *) 61 | let row = [%array (2,3) ];; 62 | matrix.!(All, 0j) <- row;; 63 | (* matrix is now [2, 3; 2, 3] *) 64 | 65 | (* range slice *) 66 | let array' = array.!([%range 25 50 ~by:5]) 67 | (* or *) 68 | let array' = array.!( 25 #-># 50 ## 5 ) 69 | ``` 70 | 71 | ### Examples without ppx 72 | 73 | Note that most of the ppx transformations are local; without
74 | the rewriter, the previous example would become: 75 | 76 | ```OCaml 77 | open Tensority 78 | open Multidim_array 79 | open Shape 80 | open Nat_defs 81 | 82 | (* Multidimensional array literals *) 83 | let array4 = Unsafe.create [ _2; _2; _3; _2 ] 84 | [| 1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11; 12; 85 | 1; 2; 8; 9; 9; 1; 2; 3; 19; 12; 14; 17 86 | |] 87 | 88 | (* Definition using function *) 89 | let array = init_sh Size.[nat _5 @ _0 @ _2 @ _1 @ _0 @ _3n] 90 | (function [k] -> Nat.to_int k) 91 | 92 | (* accessing an element *) 93 | let one = array4.%( [ _1i; _0i; _0i; _0i] ) ;; 94 | 95 | (* accessing an element with an out-of-bound type error *) 96 | let one = array4.%([ _1i; _5i; _0i; _1i ] ) ;; 97 | 98 | (* element assignment *) 99 | array4.%([ _0i; _0i; _0i; _0i ]) <- 0;; 100 | 101 | (* slicing *) 102 | let matrix = array4.%[[ All; All; Elt _1i; Elt _1i ]];; 103 | (* matrix is [9, 1; 2, 3] *) 104 | 105 | (* slice assignment *) 106 | let row = Unsafe.create [_2] [| 2; 3 |];; 107 | matrix.%[[ All; Elt _0i]] <- row 108 | (* matrix is now [2, 3; 2, 3] *) 109 | 110 | (* range slice *) 111 | let r= Range.create ~by: 5 ~start:Indices.(nat _2 @ 5n) 112 | ~stop:Indices.(nat _5 @ 0n) ~len:_10 113 | 114 | let array' = array.%[[Range r]] 115 | 116 | ``` 117 | -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 2.0) 2 | -------------------------------------------------------------------------------- /examples/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (name linalg) 3 | (modules Linalg) 4 | (libraries tensority) 5 | (preprocess (pps tensority_ppx)) 6 | ) 7 | 8 | (executable 9 | (name ma) 10 | (modules Ma) 11 | (libraries tensority) 12 | (preprocess (pps tensority_ppx)) 13 | ) -------------------------------------------------------------------------------- /examples/linalg.ml:
-------------------------------------------------------------------------------- 1 | open Tensority 2 | open Tensor 3 | 4 | let rotation theta = 5 | [%matrix 6 | [ cos theta, -.sin theta ; 7 | sin theta, cos theta ] 8 | ] 9 | 10 | let m = rotation 1. 11 | let c1 = m.!( All; 1j ) 12 | let one = [%vec ( 1., 1. ) ] 13 | 14 | let d = det @@ rotation 0.95 15 | 16 | let v = rotation 2. * one 17 | 18 | let c = (rotation 0.45).%( [1i], [0i]) 19 | let x = v.( 1i; __ ) 20 | 21 | let v1 = (transpose v).( __ ; 1i ) 22 | -------------------------------------------------------------------------------- /examples/ma.ml: -------------------------------------------------------------------------------- 1 | open Tensority 2 | 3 | open Shape 4 | open Multidim_array 5 | 6 | let m = init_sh [253s; 253s] Shape.( function [k; l] -> 7 | Nat.to_int k + Nat.to_int l) 8 | 9 | let t = [%array 4 10 | [ 11 | [1, 2; 3, 4], [5, 6; 7, 8], [9, 10; 11, 12]; 12 | [1, 2; 8, 9], [9, 1; 2, 3], [19, 12; 14, 17] 13 | ] 14 | ] 15 | 16 | 17 | let mat = t.!( 1j, 1j, All, All ) 18 | let k = mat.!(0j, 1j) 19 | 20 | let w = [%array (0, 2) ] 21 | 22 | let z = init_sh [502103s] Shape.( function [k] -> Nat.to_int k) 23 | 24 | let v = init_sh [253s] Shape.( fun [k] -> Nat.to_int k) 25 | 26 | ;; m.!( 252j, All ) <- v 27 | ;; let x = v.(0i) 28 | 29 | ;; w.!( 1j ) 30 | 31 | (* let r = [%range 0 -- 50 ] *) 32 | 33 | let zr = z.!( 0 #-># 50 ) 34 | let zr' = z.!([%range 0 50]) 35 | 36 | let xr = zr.!(10j) 37 | 38 | let tx = t.!(1j,0 #-># 1, All, 0j) 39 | -------------------------------------------------------------------------------- /examples/readme_example.ml: -------------------------------------------------------------------------------- 1 | 2 | #require "tensority.ppx";; 3 | 4 | open Tensority 5 | open Multidim_array 6 | open Shape 7 | 8 | (* Multidimensional array literals *) 9 | let array4 = [%array 4 10 | [ 11 | [1, 2; 3, 4], [5, 6; 7, 8], [9, 10; 11, 12]; 12 | [1, 2; 8, 9], [9, 1; 2, 3], [19, 12; 14, 17] 13 | ] 14 
| ];; 15 | 16 | (* Definition using function *) 17 | let array = init_sh [502103s] (function [k] -> Nat.to_int k);; 18 | 19 | (* accessing an element *) 20 | let one = array4.( 1i, 0i, 0i, 0i ) ;; 21 | 22 | (* accessing an element with an out-of-bound type error *) 23 | (* let one = array4.( 1i, 5i, 0i, 1i ) ;; *) 24 | 25 | (* element assignment *) 26 | array4.(0i,0i,0i,0i) <- 0;; 27 | 28 | (* slicing *) 29 | let matrix = array4.[ All, All, 1j, 1j ];; 30 | (* matrix is [9, 1; 2, 3] *) 31 | 32 | (* slice assignment *) 33 | let row = [%array (2,3) ];; 34 | matrix.[All, 0j] <- row;; 35 | (* matrix is now [2, 3; 2, 3] *) 36 | 37 | (* range slice *) 38 | let array' = array.[[%range 25 50 ~by:5]];; 39 | (* or *) 40 | let array' = array.[ 25 #-># 50 ## 5];; 41 | -------------------------------------------------------------------------------- /lib/dune: -------------------------------------------------------------------------------- 1 | (library (public_name tensority)) -------------------------------------------------------------------------------- /lib/mask.ml: -------------------------------------------------------------------------------- 1 | 2 | 3 | type ( 'kind, 'nat, 'l, 'out ) abs = 4 | < 5 | k_in:'kind; 6 | x: < l_in:'l; out: 'out >; 7 | fx: ; 8 | > 9 | 10 | 11 | type _ elt = 12 | | Elt: ('nat,'kind) Nat.t -> 13 | ('kind, 'nat, 'l, 'out ) abs elt 14 | | All : 15 | < 16 | k_in: 'k; 17 | x: < l_in: 'l; out: 'n2 * 'l2 >; 18 | fx: < l_in: 'any * 'l; out: 'n2 Nat.succ * ('any * 'l2) > 19 | > elt 20 | | Range : 21 | ('in_, 'out) Range.t -> 22 | < 23 | k_in:'k; 24 | x: < l_in: 'l; out: 'n2 * 'l2 >; 25 | fx: < l_in: 'in_ * 'l; out:'n2 Nat.succ * ( 'out * 'l2 ) > 26 | > elt 27 | 28 | let pp_elt: type a. 
Format.formatter -> a elt -> unit = fun ppf -> function 29 | | Elt nat -> Format.fprintf ppf "%d" @@ Nat.to_int nat 30 | | All -> Format.fprintf ppf "All" 31 | | Range r -> Format.fprintf ppf "%a" Range.pp r 32 | 33 | type ('k1, 'k2) empty_2 = 34 | < kind :'k1 * 'k2; in_ : Shape.empty; out : Shape.empty > 35 | 36 | 37 | type _ list = 38 | | [] : ('a, 'b) empty_2 list 39 | | (::) : 40 | < k_in:'k; x: < l_in:'l; out:'out >; fx : > elt 41 | * list -> 42 | < in_:'n Nat.succ * 'fl; out:'f_out; kind: 'k * 'ko > list 43 | 44 | type ('a, 'b, 'k ) gen_s = 45 | < kind : 'k ; in_ : 'a; out : 'b > list 46 | 47 | type ('a, 'b) eq_s = ('a,'b, [`Eq] * [`Eq] ) gen_s 48 | type ('a, 'b) lt_s = ('a,'b, [`Lt] * [`Lt] ) gen_s 49 | type ('a, 'b) s_to_lt = ('a,'b, [`Eq] * [`Lt] ) gen_s 50 | type ('a, 'b) s_to_eq = ('a,'b, [`Lt] * [`Eq] ) gen_s 51 | type ('a, 'b) s = ('a, 'b) s_to_eq 52 | type ('a, 'b) t = ('a, 'b) s 53 | 54 | 55 | let rec order_in: type sh. sh list -> int = 56 | function 57 | | _ :: q -> 1 + order_in q 58 | | [] -> 0 59 | 60 | let rec order_out: type sh. sh list -> int = 61 | function 62 | | Elt _ :: q -> order_out q 63 | | Range _ :: q -> 1 + order_out q 64 | | All :: q -> 1 + order_out q 65 | | [] -> 0 66 | 67 | let rec filter: type sh sh2. sh Shape.eq -> (sh, sh2) s -> sh2 Shape.eq = 68 | let open Shape in 69 | fun sh sl -> match sh,sl with 70 | | [], [] -> [] 71 | | _ :: q, Elt _ :: sq -> filter q sq 72 | | _ :: q, Range r :: sq -> (Range.len r) :: filter q sq 73 | | e :: q, All :: sq -> e::filter q sq 74 | 75 | let rec iter_extended_dual: type sh sh'. 
76 | (sh Shape.lt -> sh' Shape.lt -> unit ) -> sh Shape.eq -> (sh',sh) s -> unit= 77 | fun f sh mask -> 78 | let open Shape in 79 | match mask, sh with 80 | | [], [] -> f [] [] 81 | | Elt a :: mask, _ -> 82 | iter_extended_dual (fun sh sh' -> f sh (a :: sh') ) sh mask 83 | | All :: mask, a :: sh -> 84 | Nat.iter_on a (fun nat -> 85 | let f sh sh' = f (nat::sh) (nat::sh') in 86 | iter_extended_dual f sh mask 87 | ) 88 | | Range r :: mask, a :: sh -> 89 | Nat.iter_on a (fun nat -> 90 | let f sh sh' = 91 | f (nat::sh) (Range.transpose r nat::sh') in 92 | iter_extended_dual f sh mask 93 | ) 94 | 95 | let rec iter_masked_dual: type sh sh'. 96 | (sh Shape.lt -> sh' Shape.lt -> unit ) -> sh Shape.l -> (sh,sh') s_to_eq -> unit= 97 | fun f sh mask -> 98 | let open Shape in 99 | match mask, sh with 100 | | [], [] -> f [] [] 101 | | Elt a :: mask, _ :: sh -> 102 | iter_masked_dual (fun sh sh' -> f (a :: sh) sh' ) sh mask 103 | | All :: mask, a :: sh -> 104 | Nat.iter_on a (fun nat -> 105 | let f sh sh' = f (nat::sh) (nat::sh') in 106 | iter_masked_dual f sh mask 107 | ) 108 | | Range r :: mask, _ :: sh -> 109 | Nat.iter_on (Range.len r) (fun nat -> 110 | let f sh sh' = 111 | f (Range.transpose r nat::sh) (nat ::sh') in 112 | iter_masked_dual f sh mask 113 | ) 114 | 115 | let rec join: type li lm lo ni nm no. 
116 | ( ni * li as 'i, nm * lm as 'm) s -> 117 | ('m, no * lo as 'o) s -> 118 | ('i,'o) s 119 | = fun slice1 slice2 -> 120 | match slice1, slice2 with 121 | | [], [] -> [] 122 | | Elt k :: slice1, _ -> Elt k :: (join slice1 slice2) 123 | | All :: slice1, All::slice2 -> All :: (join slice1 slice2) 124 | | All :: slice1, Elt k :: slice2 -> (Elt k) :: (join slice1 slice2) 125 | | All :: slice1, Range r :: slice2 -> Range r :: (join slice1 slice2) 126 | | (Range _ as r) :: slice1, All::slice2 -> r :: (join slice1 slice2) 127 | | Range r :: slice1, Elt k :: slice2 -> 128 | Elt (Range.transpose r k) :: (join slice1 slice2) 129 | | Range r :: slice1, Range r2 :: slice2 -> 130 | Range (Range.compose r r2) :: (join slice1 slice2) 131 | 132 | let (>>) = join 133 | 134 | [@@@warning "-32"] 135 | let rec position_gen: 136 | type sh filt. 137 | mult:int -> sum:int 138 | -> (sh, filt) s 139 | -> sh Shape.eq 140 | -> filt Shape.lt -> int * int = 141 | let open Shape in 142 | fun ~mult ~sum filter shape indices -> 143 | match[@warning "-4"] filter, shape, indices with 144 | | All :: filter , dim :: shape, nat :: indices -> 145 | position_gen ~mult:(mult * Nat.to_int dim) ~sum:(sum + mult * Nat.to_int nat) 146 | filter shape indices 147 | | Elt nat :: filter, dim :: shape, _ -> 148 | position_gen ~sum:(sum + mult * Nat.to_int nat) 149 | ~mult:(Nat.to_int dim * mult) filter shape indices 150 | | Range r :: filter, dim :: shape, nat :: indices -> 151 | let nat = Range.transpose r nat in 152 | position_gen ~sum:(sum + mult * Nat.to_int nat) 153 | ~mult:(Nat.to_int dim * mult) filter shape indices 154 | | [], [], _ -> mult, sum 155 | 156 | [@@@warning "+32"] 157 | 158 | let pp ppf shape = 159 | let rec inner: type sh. 
Format.formatter -> sh list -> unit = 160 | fun ppf -> 161 | function 162 | | [] -> () 163 | | [a] -> Format.fprintf ppf "%a" pp_elt a 164 | | a :: q -> 165 | (* no trailing space: the separator is only emitted between elements *) Format.fprintf ppf "%a;@ %a" pp_elt a inner q 166 | in 167 | Format.fprintf ppf "@[(%a)@]" inner shape 168 | -------------------------------------------------------------------------------- /lib/mask.mli: -------------------------------------------------------------------------------- 1 | (** The [Mask] module can be used to transform shapes defined 2 | in the {!Shape} module and strides defined over the 3 | {!Stride} module. 4 | 5 | At the type level, a mask is a combination of possible input shapes 6 | and output shapes. At the value level, a mask is a list of mask elements 7 | that can be either 8 | 9 | 10 | - [All] : Take all elements 11 | - [Elt nat]: Take only elements with indices equal to [nat] 12 | - [Range r]: Take only elements belonging to the range r 13 | 14 | {2 Types} 15 | {3 Helper type constructors} *) 16 | 17 | type ('k1, 'k2) empty_2 = 18 | < kind :'k1 * 'k2; in_ : Shape.empty; out : Shape.empty > 19 | (** [empty_2] is the type argument of an empty mask, whose input and output shapes are both empty *) 20 | 21 | type ( 'kind, 'nat, 'l, 'out ) abs = 22 | < 23 | k_in:'kind; 24 | x: < l_in:'l; out: 'out >; 25 | fx: ; 26 | > 27 | (** [('kind,'nat,'l,'out) abs] is the type of an absolute element 28 | of logical size ['nat] with kind ['kind] that is applied to an inner 29 | list ['l] and a filtered shape ['out]. 30 | A value of type [('kind,'nat,'l,'out) abs elt] represents the description 31 | of a slice of a multidimensional array. 32 | Such a slice description can be either concrete (like [Elt]) or relative. 33 | Concrete elements represent the logical and physical size of a 34 | multidimensional array in a given dimension. Relative elements describe 35 | the size of a slice relative to another source slice.
36 | *) 37 | 38 | (** {3 Main types} *) 39 | type _ elt = 40 | | Elt: ('nat,'kind) Nat.t -> 41 | ('kind, 'nat, 'l, 'out ) abs elt 42 | (** [Elt k]: Take only the [k]-th element from the original source *) 43 | | All : 44 | < 45 | k_in: 'k; 46 | x: < l_in: 'l; out: 'n2 * 'l2 >; 47 | fx: < l_in: 'any * 'l; out: 'n2 Nat.succ * ('any * 'l2) > 48 | > elt 49 | (** [All]: Take all the elements from the source *) 50 | | Range : 51 | ('in_, 'out) Range.t -> 52 | < 53 | k_in:'k; 54 | x: < l_in: 'l; out: 'n2 * 'l2 >; 55 | fx: < l_in: 'in_ * 'l; out:'n2 Nat.succ * ( 'out * 'l2 ) > 56 | > elt 57 | (** [Range r]: Take an affine subset from the source, 58 | see the [Range] module for more information. *) 59 | 60 | (** The generic type of mask *) 61 | type _ list = 62 | | [] : ('a, 'b) empty_2 list 63 | (** A quite complicated empty type *) 64 | | (::) : 65 | < k_in:'k; x: < l_in:'l; out:'out >; fx : > elt 66 | * list -> 67 | < in_:'n Nat.succ * 'fl; out:'f_out; kind: 'k * 'ko > list 68 | (** The cons operator, note that most of type-level operations 69 | are delegated to the elt constructor, using the method [x] has 70 | argument and the method [fx] as output of the type-level function 71 | *) 72 | 73 | 74 | (** Generic subtype generator for slice shape *) 75 | type ('a, 'b, 'k ) gen_s = 76 | < kind : 'k ; in_ : 'a; out : 'b > list 77 | 78 | (** Slice mapping size to size *) 79 | type ('a, 'b) eq_s = ('a,'b, [`Eq] * [`Eq] ) gen_s 80 | 81 | (** Slice mapping multi-index to multi-index *) 82 | type ('a, 'b) lt_s = ('a,'b, [`Lt] * [`Lt] ) gen_s 83 | 84 | (** Slice mapping size to multi-index *) 85 | type ('a, 'b) s_to_lt = ('a,'b, [`Eq] * [`Lt] ) gen_s 86 | 87 | (** Slice mapping index to size *) 88 | type ('a, 'b) s_to_eq = ('a,'b, [`Lt] * [`Eq] ) gen_s 89 | 90 | (** Standard slice *) 91 | type ('a, 'b) s = ('a, 'b) s_to_eq 92 | type ('a, 'b) t = ('a, 'b) s 93 | 94 | (** {2 Functions } *) 95 | 96 | 97 | val order_in: 'sh list -> int 98 | (** Dimension of the input shape *) 99 | 100 
| val order_out: 'sh list -> int 101 | (** Order (number of dimensions) of the output shape *) 102 | 103 | val filter : 'sh Shape.eq -> ('sh, 'sh2) s -> 'sh2 Shape.eq 104 | (** [filter sh mask] applies [mask] to the shape [sh]: [Elt] dimensions are removed, [All] dimensions are kept and [Range r] dimensions are replaced by the length of [r] *) 105 | 106 | val join : ('ni * 'li as 'i, 'nm * 'lm as 'm) s -> 107 | ('m, 'no * 'lo as 'o) s -> ('i, 'o) s 108 | (** Compose two masks: [join m1 m2] maps the input shape of [m1] to the output shape of [m2] *) 109 | 110 | val ( >> ) : 111 | ('ni * 'li as 'i, 'nm * 'lm as 'm) s -> 112 | ('m, 'no * 'lo as 'o) s -> ('i, 'o) s 113 | (** Operator for mask composition: [m1 >> m2] is [join m1 m2] *) 114 | 115 | (** {2 Iteration} *) 116 | 117 | (** [iter_extended_dual f sh mask] iterates over the multi-indices of [sh] (the output shape of [mask]); [f] receives each multi-index of [sh] together with the corresponding multi-index of the input shape of [mask] *) 118 | val iter_extended_dual : 119 | ('sh Shape.lt -> 'sh2 Shape.lt -> unit) -> 'sh Shape.l -> ('sh2, 'sh) s -> unit 120 | (** [iter_masked_dual f sh mask] iterates over the multi-indices of the masked shape; [f] receives the corresponding multi-index of the original shape [sh] first and the masked multi-index second *) 121 | val iter_masked_dual : 122 | ('sh Shape.lt -> 'sh2 Shape.lt -> unit) -> 'sh Shape.l -> ('sh, 'sh2) s -> unit 123 | 124 | (** {2 Pretty printing} *) 125 | 126 | (** Pretty printer for mask elements *) 127 | val pp_elt : Format.formatter -> 'a elt -> unit 128 | 129 | (** Pretty printer for masks *) 130 | val pp : Format.formatter -> ('a,'b) t -> unit 131 | -------------------------------------------------------------------------------- /lib/misc.ml: -------------------------------------------------------------------------------- 1 | 2 | let delta i j = if i = j then 1. else 0.
3 | -------------------------------------------------------------------------------- /lib/misc.mli: -------------------------------------------------------------------------------- 1 | val delta : 'a -> 'a -> float 2 | -------------------------------------------------------------------------------- /lib/multidim_array.ml: -------------------------------------------------------------------------------- 1 | module A = Array 2 | exception Dimension_error = Signatures.Dimension_error 3 | 4 | let (.!()) = Array.unsafe_get 5 | 6 | type 'x t = { 7 | shape: ('n * 'sh) Shape.l; 8 | offset:int; 9 | strides: 'n Stride.t; 10 | array: 'elt array 11 | } 12 | constraint 'x = 13 | 14 | let size m = Shape.size m.shape 15 | let is_sparse m = m.offset <> 0 || (Shape.size m.shape < Stride.size m.strides) 16 | 17 | module Unsafe_0 = struct 18 | 19 | let create shape array = 20 | let strides = Stride.create shape in 21 | let len = A.length array and size = Stride.size strides in 22 | if len <> size then 23 | raise @@ Dimension_error("Multidim_array.create_unsafe", size, len) 24 | else 25 | {shape; array; strides; offset = 0 } 26 | end 27 | 28 | let position (m: t) (indices:'sh Shape.lt) = 29 | m.offset + Stride.position ~strides:m.strides ~indices 30 | 31 | 32 | let ( .%() ): 33 | t -> 'sh Shape.lt -> 'elt = fun t indices -> 34 | t.array.!(position t indices) 35 | 36 | let ( .%() <- ): t -> 'sh Shape.lt -> 'elt -> unit = 37 | fun t indices value -> 38 | let p = position t indices in 39 | A.unsafe_set t.array p value 40 | 41 | 42 | let physical_size t = A.length t.array 43 | let shape t = t.shape 44 | 45 | 46 | let init_sh shape f = 47 | let strides = Stride.create shape in 48 | let size = Stride.size strides in 49 | let z = Shape.zero shape in 50 | let array = A.make size @@ f z in 51 | let m = {shape; array; strides; offset = 0 } in 52 | Shape.iter_on shape (fun sh -> m.%(sh) <- f sh); 53 | m 54 | 55 | 56 | let ordinal (nat: 'a Nat.eq) : t = 57 | Unsafe_0.create Shape.[nat] @@ A.init 
(Nat.to_int nat) Nat.Unsafe.create 58 | 59 | let slice_first (nat:'a Nat.lt) (m: t) = 60 | let strides, shape = Stride.slice_1 m.strides, Shape.tail_1 m.shape in 61 | let offset = m.offset + (Stride.first m.strides) * (Nat.to_int nat) in 62 | { m with shape; strides; offset } 63 | 64 | let slice s m = 65 | let shape = Mask.filter m.shape s in 66 | let offset, strides = Stride.filter m.strides s in 67 | { m with shape; strides; offset= offset + m.offset } 68 | (* Fast paths for arrays that are not [is_sparse]: they work directly on the physical storage *) 69 | module Dense = struct 70 | let copy ?(deep_copy=fun x -> x) : 'sh t -> 'sh t = fun ma -> 71 | { ma with array = A.map deep_copy ma.array } 72 | 73 | (* Copy every element of [from] into [to_], same direction as [Sparse.blit] *) 74 | let blit: from:'sh t -> to_:'sh t -> unit = 75 | fun ~from ~to_ -> 76 | (* bounds-checked: raises Invalid_argument instead of writing past [to_.array] *) A.blit from.array 0 to_.array 0 (A.length from.array) 77 | 78 | let map f m = 79 | let array = A.map f m.array in 80 | { m with array } 81 | 82 | let map2 f (m: t) (m2: t) = 83 | let array = A.init (physical_size m) 84 | (fun i -> f m.array.!(i) m2.array.!(i)) in 85 | { m with array } 86 | 87 | let iter f m = 88 | A.iter f m.array 89 | 90 | let iter2 f m n = 91 | A.iter2 f m.array n.array 92 | 93 | let fold_all_left f acc m = 94 | A.fold_left f acc m.array 95 | 96 | 97 | (** Unsafe: reinterpret a dense array with shape [sh2]; raises [Dimension_error] if the sizes differ *) 98 | let reshape_inplace: 99 | 'sh2 Shape.l -> t -> t = 100 | fun sh2 m -> 101 | let s = size m and s2 = Shape.size sh2 in 102 | if s <> s2 then 103 | raise @@ Dimension_error ("Multidim_array.reshape", s, s2) 104 | else 105 | { m with shape = sh2 } 106 | 107 | end 108 | 109 | module Sparse = struct 110 | 111 | let copy ?(deep_copy=(fun x->x)) m = 112 | let size = Shape.size m.shape and shape = m.shape in 113 | if size = 0 then 114 | Unsafe_0.create shape [| |] 115 | else 116 | let m' = Unsafe_0.create shape @@ A.make size m.array.!(0) in 117 | Shape.iter_on m.shape (fun sh -> m'.%(sh) <- deep_copy m.%(sh)) 118 | ; m' 119 | 120 | let partial_blit: from: t -> to_: t 121 | -> ('sh,'sh2) Mask.t -> unit = 122 | fun ~from ~to_ filter -> 123 | Mask.iter_extended_dual 124 | (fun sh sh' -> 125 |
to_.%(sh') <- from.%(sh) ) 126 | from.shape filter 127 | 128 | let iter_sh f m = 129 | Shape.iter (fun sh -> f sh m.%(sh)) m.shape 130 | 131 | let map_sh f m = 132 | init_sh m.shape (fun sh -> f sh m.%(sh) ) 133 | 134 | let blit: from:'sh t -> to_:'sh t -> unit = 135 | fun ~from ~to_ -> 136 | Shape.iter_on from.shape (fun sh -> 137 | to_.%(sh) <- from.%(sh) 138 | ) 139 | 140 | let map f m = 141 | init_sh m.shape (fun sh -> f m.%(sh) ) 142 | 143 | let map2 f (m: t) (m2: t) = 144 | init_sh m.shape (fun sh -> f m.%(sh) m2.%(sh) ) 145 | 146 | let iter f m = 147 | Shape.iter_on m.shape (fun sh -> f m.%(sh) ) 148 | 149 | let iter2 f m n = 150 | Shape.iter_on m.shape (fun sh -> f m.%(sh) n.%(sh) ) 151 | 152 | let fold_all_left f acc m = 153 | let acc =ref acc in 154 | Shape.iter_on m.shape (fun sh -> acc := f !acc m.%(sh)) 155 | ; !acc 156 | 157 | end 158 | 159 | let copy ?(deep_copy= fun x -> x ) m = 160 | if is_sparse m then 161 | Sparse.copy ~deep_copy m 162 | else Dense.copy ~deep_copy m 163 | 164 | let blit ~from ~to_ = 165 | if is_sparse from || is_sparse to_ then 166 | Sparse.blit ~from ~to_ 167 | else 168 | Dense.blit ~from ~to_ 169 | 170 | let map f m = 171 | (if is_sparse m then 172 | Sparse.map 173 | else 174 | Dense.map 175 | ) f m 176 | 177 | let map_first f m = 178 | let nat, _ = Shape.split_1 m.shape in 179 | let open Shape in 180 | init_sh [nat] (fun [n] -> f @@ slice_first n m) 181 | 182 | let map2 f m m2 = 183 | ( if is_sparse m || is_sparse m2 then Sparse.map2 else Dense.map2) f m m2 184 | 185 | let iter f m = 186 | if is_sparse m then Sparse.iter f m else Dense.iter f m 187 | 188 | let iter2 f m m2 = 189 | ( if is_sparse m || is_sparse m2 then Sparse.iter2 else Dense.iter2) f m m2 190 | 191 | 192 | let iter_sh f = Sparse.iter_sh f 193 | 194 | let map_sh f = Sparse.map_sh f 195 | 196 | let fold_all_left f acc m = 197 | (if is_sparse m then 198 | Sparse.fold_all_left 199 | else 200 | Dense.fold_all_left) 201 | f acc m 202 | 203 | let fold_top_left f acc 
m = 204 | let k, _ = Shape.split_1 m.shape in 205 | Nat.fold_on k acc (fun acc nat -> 206 | f acc (slice_first nat m) 207 | ) 208 | 209 | let partial_copy ?(deep_copy=fun x -> x) s m = 210 | Sparse.copy ~deep_copy @@ slice s m 211 | 212 | let partial_blit = Sparse.partial_blit 213 | 214 | let ( .%[] ) m f = slice f m 215 | and ( .%[]<- ) to_ filter from = partial_blit ~from ~to_ filter 216 | 217 | (** Full unsafe module *) 218 | module Unsafe = struct 219 | include Unsafe_0 220 | let reshape_inplace dims t = 221 | if is_sparse t then 222 | None 223 | else 224 | Some (Dense.reshape_inplace dims t) 225 | 226 | let reshape dims t = 227 | Dense.reshape_inplace dims @@ copy t 228 | 229 | end 230 | 231 | 232 | (** Scanning functions *) 233 | let for_all p x = 234 | fold_all_left (fun b x -> b && p x ) true x 235 | 236 | let exists p x = 237 | fold_all_left (fun b x -> b || p x ) false x 238 | 239 | let mem x m = 240 | fold_all_left (fun b y -> b || x = y ) false m 241 | 242 | let memq x m = 243 | fold_all_left (fun b y -> b || x == y ) false m 244 | 245 | 246 | let find predicate ma = 247 | Shape.fold_left (fun l sh -> 248 | let x = ma.%(sh) in 249 | if predicate x then sh :: l else l 250 | ) [] ma.shape 251 | 252 | 253 | let rec _repeat k ppf s = if k=0 then () else 254 | (Format.fprintf ppf "%s" s; _repeat (k-1) ppf s) 255 | 256 | let pp elt_pp ppf ma = 257 | let up k = Format.fprintf ppf "@[%s" (if k mod 2 = 0 then "[" else "" ) 258 | and down k = Format.fprintf ppf "%s@]" (if k mod 2 = 0 then "]" else "") 259 | and sep k = 260 | if k mod 2 = 0 then 261 | Format.fprintf ppf ", " 262 | else 263 | Format.fprintf ppf "; " in 264 | let f sh = elt_pp ppf (ma.%(sh)) in 265 | Format.fprintf ppf "@[(" 266 | ; Shape.iter_sep ~up ~down ~sep ~f ma.shape 267 | ; Format.fprintf ppf ")@]" 268 | -------------------------------------------------------------------------------- /lib/multidim_array.mli: -------------------------------------------------------------------------------- 1 | 
type 'a t constraint 'a = < elt : 'elt; shape : 'n * 'sh > 2 | (** Values of type [ t] 3 | are multidimensional arrays with dimensions a_1, ..., a_n where a_1 to 4 | a_n are type-level representations of integers *) 5 | 6 | (** {2 Size and shape} *) 7 | 8 | (** [size m] is the number of elements in the array *) 9 | val size : < elt : 'a; shape : 'b > t -> int 10 | 11 | (** [physical_size m] is the length of the underlying physical storage *) 12 | val physical_size : < elt : 'a; shape : 'b > t -> int 13 | 14 | (** [shape m] is the shape of the array *) 15 | val shape : < elt : 'a; shape : 'b > t -> 'b Shape.l 16 | 17 | (** [ is_sparse m ] is true when the array starts at a non-zero offset of its physical storage, or when its logical size 18 | is less than the size spanned by its strides *) 19 | val is_sparse : < elt : 'a; shape : 'b > t -> bool 20 | 21 | (** {2 Generic indexing operators} *) 22 | val (.%()): < elt : 'elt; shape : 'sh > t -> 'sh Shape.lt -> 'elt 23 | 24 | val (.%()<-): < elt : 'elt; shape : 'sh > t -> 'sh Shape.lt -> 'elt -> unit 25 | 26 | 27 | (** {2 Unsafe functions} *) 28 | module Unsafe : sig 29 | 30 | (** [create shape array] creates an array of shape [shape] and physical contents [array]. 31 | The function raises a [Dimension_error] exception if the physical size of [shape] is 32 | not the same as the length of [array]. 33 | *) 34 | val create : 'a Shape.eq -> 'b array -> < elt : 'b; shape : 'a > t 35 | 36 | (** Reshape functions *) 37 | 38 | (** [reshape sh m] copies the array [m] and reshapes the copy if the shape [sh] 39 | is compatible with the shape of [m]. 40 | Raises a [Dimension_error] otherwise. 41 | @todo Multiplication proof. 42 | *) 43 | val reshape: 'sh Shape.l -> < elt : 'e; shape : 'sh > t 44 | -> < elt : 'e; shape : 'sh > t 45 | 46 | (** [reshape_inplace sh m] reinterprets the dense multidimensional array 47 | [m] as an array of shape [sh]. The function returns [None] if 48 | [m] is not dense. 49 | Otherwise, raises a [Dimension_error] if the sizes of [sh] and [m] differ. 50 | @todo Multiplication proof.
51 | *) 52 | val reshape_inplace: 'sh Shape.l 53 | -> < elt : 'e; shape : 'sh > t 54 | -> < elt : 'e; shape : 'sh > t option 55 | 56 | end 57 | 58 | (** {2 Creation functions} *) 59 | val init_sh : 60 | 'a Shape.eq -> ('a Shape.lt -> 'b) -> < elt : 'b; shape : 'a > t 61 | 62 | (** [ordinal n] is the sorted unidimensional array of all integers less than [n] 63 | equipped with a type bound [ [%nat n] Nat.lt] *) 64 | val ordinal : 'a Nat.eq -> < elt : 'a Nat.lt; shape : 'a Shape.single > t 65 | 66 | (** {2 Copy functions} *) 67 | (** [copy ?deep_copy m] returns a new array with the same shape and elements as [m], each element being passed through [deep_copy] (the identity by default) *) 68 | val copy: 69 | ?deep_copy:('a -> 'a) -> 70 | < elt : 'a; shape : 'b > t -> < elt : 'a; shape : 'b > t 71 | 72 | val blit : 73 | from:< elt : 'a; shape : 'b > t -> to_:< elt : 'a; shape : 'b > t -> unit 74 | 75 | (** {2 Slicing functions} *) 76 | 77 | (** [slice_first nat m] creates a slice [s] such that [s.(a_1,...,a_n)] corresponds to [ m.(nat,a_1,...,a_n) ] *) 78 | val slice_first: 79 | 'a Nat.lt -> t -> 80 | t 81 | 82 | (** [slice filter m] takes a sparse shape [ filter = [s_1; ...; s_n] ] with [ s_k = Elt k | Range r | All ] 83 | and creates an array whose shape is made of the remaining free indices in [shape/filter] *) 84 | val slice : 85 | ('a, 'b) Mask.t -> 86 | < elt : 'c; shape : 'a > t -> < elt : 'c; shape : 'b > t 87 | 88 | (** {2 Slicing indexing operators} *) 89 | val (.%[]<-): 90 | < elt : 'a; shape : 'b > t -> 91 | ('b, 'c) Mask.t -> < elt : 'a; shape : 'c > t -> unit 92 | 93 | val (.%[]) : 94 | < elt : 'a; shape : 'b > t -> 95 | ('b, 'c) Mask.t -> < elt : 'a; shape : 'c > t 96 | 97 | (** [partial_copy filter m] creates a fresh copy of the slice [slice filter m] *) 98 | val partial_copy : 99 | ?deep_copy:('a -> 'a) -> 100 | ('b, 'c) Mask.t -> 101 | < elt : 'a; shape : 'b > t -> < elt : 'a; shape : 'c > t 102 | 103 | (** [partial_blit ~from ~to_ filter] copies the values of [from] into the slice [slice filter to_] *) 104 | val partial_blit : 105 | from:< elt : 'a; shape : 'b > t -> 106 | to_:< elt : 'a; shape : 'c > t -> ('c, 'b)
Mask.t -> unit 107 | 108 | 109 | (** {2 Map, iter and fold functions} *) 110 | (** {3 Map functions} *) 111 | 112 | (** [map f m] applies [f] to every elements of [m] while preserving the shape 113 | of the array *) 114 | val map : 115 | ('a -> 'b) -> < elt : 'a; shape : 'c > t -> < elt : 'b; shape : 'c > t 116 | 117 | (** [map_sh f m] applies [f index] to every elements of [m], where [index] is 118 | multi-index of the element. The shape of the array is preserved *) 119 | val map_sh : 120 | ('sh Shape.lt -> 'a -> 'b) -> < elt : 'a; shape : 'sh > t -> 121 | < elt : 'b; shape : 'sh > t 122 | 123 | (** [map_first f m] computes the multidimensional array 124 | [ [| f s_1; ...; f s_n|] ] where [s_k] is the k-th first-index slice 125 | of m 126 | *) 127 | val map_first: (< elt : 'a; shape : 'rank * 'l > t -> 'd) -> 128 | < elt : 'a; shape : 'rank Nat.succ * ( 'n * 'l ) > t -> 129 | < elt : 'd; shape : Nat.z Nat.succ * ('n * Shape.nil) > t 130 | 131 | (** [map f m_1 m_2] computes [f e_1 e_2] for every pair of elements of the 132 | arrays [m_1] and [m_2] and returns the result as an array of the same shape 133 | as its input *) 134 | val map2 : 135 | ('a -> 'b -> 'c) -> 136 | < elt : 'a; shape : 'd > t -> 137 | < elt : 'b; shape : 'd > t -> < elt : 'c; shape : 'd > t 138 | 139 | 140 | (** {2 Iter functions} *) 141 | 142 | (** [iter f m] computes [f e] for every elements of [m] *) 143 | val iter : ('a -> unit) -> < elt : 'a; shape : 'b > t -> unit 144 | 145 | (** [iter2 f m n] computes [f e_1 e_2] for every elements of [m] and [n] *) 146 | val iter2 : ('a -> 'b -> unit) -> < elt : 'a; shape : 'sh > t -> 147 | t -> unit 148 | 149 | (** [iter_sh f m] computes [f index e] for every elements of [m] *) 150 | val iter_sh: ( 'b Shape.lt -> 'a -> unit) -> < elt : 'a; shape : 'b > t -> unit 151 | 152 | (** {2 Fold functions} *) 153 | 154 | (** [fold_all_left f acc m] for an array [ m = (a_0, ..., a_n)] computes 155 | [f (...f ( f acc a_0 ) a_1 )...) 
a_n ] *) 156 | val fold_all_left : 157 | ('a -> 'b -> 'a) -> 'a -> < elt : 'b; shape : 'c > t -> 'a 158 | 159 | (** 160 | [fold_top_left f acc m] computes a fold_left on the array of slices 161 | [(s_0, ..., s_n)] with [ s_k = slice_first k m ] 162 | *) 163 | val fold_top_left: 164 | ('acc -> < elt:'elt; shape: 'n * 'l > t -> 'acc) -> 'acc -> 165 | t -> 166 | 'acc 167 | 168 | 169 | (** {2 Scanning functions} *) 170 | 171 | (** [for_all predicate m] is true if and only if [predicate] is true for 172 | every element of the array.*) 173 | val for_all: ('a -> bool) -> t -> bool 174 | 175 | (** [exists predicate m] is true if and only if there is an element within the 176 | array for which [predicate] is true.*) 177 | val exists: ('a -> bool) -> t -> bool 178 | 179 | (** [mem x m] is true if and only if there is an element of the array 180 | equal to x (in the sense of the structural equality (=) ).*) 181 | val mem: 'a -> t -> bool 182 | 183 | (** [memq x m] is true if and only if there is an element of the array 184 | physically equal to x (i.e. [ ∃s, x == m.(s) ] ).*) 185 | val memq: 'a -> t -> bool 186 | 187 | (** {2 Predicate functions}*) 188 | 189 | (** [find predicate m] returns all the multi-indices of the elements [e] 190 | of the array such that [predicate e] is true *) 191 | val find : ('a -> bool) -> < elt : 'a; shape : 'b > t -> 'b Shape.lt list
lt = ('a,ltm) t 14 | type +'a eq = ('a,eqm) t 15 | type +'a le = ('a,lem) t 16 | 17 | (* Unsafe functions, to be used rarely and cautiously *) 18 | module Unsafe = struct 19 | type unsafe 20 | let create n = n 21 | let magic n = n 22 | let eq n = n 23 | let lt n = n 24 | end 25 | 26 | (* Safe conversion functions *) 27 | let to_int n = n 28 | let pp ppf n = Format.fprintf ppf "%d" n 29 | let show = string_of_int 30 | 31 | (* Utility functions *) 32 | let zero = Unsafe.create 0 33 | let succ nat = succ @@ to_int nat 34 | 35 | (* Functor for dynamic natural *) 36 | module type dynamic = sig type t val dim: t eq end 37 | module Dynamic(D: sig val dim: int end)= struct 38 | [@@@warning "-37"] 39 | type t = private T 40 | let dim: t eq = Unsafe.create D.dim 41 | end 42 | let dynamic dim = (module Dynamic(struct let dim=dim end): dynamic ) 43 | 44 | (* Peano types ?? *) 45 | type z = private Z 46 | type nz = private NZ 47 | type +'a succ = private Succ 48 | 49 | (* Iters, folds and map *) 50 | 51 | (* Iter functions *) 52 | let iter (f:'a lt -> unit) (n:'a eq) : unit = 53 | for i = 0 to (to_int n - 1) do 54 | f @@ Unsafe.create i 55 | done 56 | let iter_on n f = iter f n 57 | 58 | let partial_iter ~start ~(stop: 'a eq) (f:'a lt -> unit): unit = 59 | for i=start to (to_int stop - 1) do 60 | f @@ Unsafe.create i 61 | done 62 | 63 | let typed_partial_iter ~(start: 'a lt) ~(stop: 'a eq) (f:'a lt -> unit): unit = 64 | for i= to_int start to (to_int stop - 1) do 65 | f @@ Unsafe.create i 66 | done 67 | 68 | (* Map function *) 69 | let map (f:'a lt -> 'b ) (dim:'a eq) = 70 | let n = to_int dim in 71 | Array.init n (fun i -> f @@ Unsafe.create i) 72 | 73 | 74 | (* Fold functions *) 75 | let fold f acc n = 76 | let acc = ref acc in 77 | iter (fun n -> acc := f !acc n) n; 78 | !acc 79 | 80 | let fold_on n acc f = fold f acc n 81 | 82 | let partial_fold 83 | ~start 84 | ~(stop:'a eq) 85 | ~acc 86 | (f:'acc -> 'a lt -> 'acc) = 87 | let acc = ref acc in 88 | partial_iter ~start ~stop 
(fun i -> acc := f !acc i); 89 | !acc 90 | 91 | (* Random generator *) 92 | let rand state (n: 'a eq) :'a lt = 93 | Unsafe.create @@ Random.State.int state @@ to_int n 94 | 95 | (* Predicate generators *) 96 | type truth = Truth 97 | let (%<%): 'a lt -> 'a eq -> truth = fun _ _ -> Truth 98 | 99 | module Sum = struct 100 | exception Erroneous_arithmetic of 101 | {fn:string; summand:int list; erroneous_sum:int } 102 | 103 | type 'a summand = int list 104 | type ('a, 'c) t = Witness of {summand:'a summand; result: 'c eq} 105 | 106 | let create: 'a summand -> 'c eq -> ('a,'c) t option = 107 | fun s result -> 108 | let sum = List.fold_left (fun s x -> s + (x:>int) ) 0 s in 109 | if sum = (result:>int) then 110 | Some (Witness {summand=s; result }) 111 | else None 112 | 113 | let create_exn: 'a summand -> 'c eq -> ('a,'c) t = 114 | fun s result -> 115 | let sum = List.fold_left (fun s x -> s + (x:>int) ) 0 s in 116 | if sum = (result:>int) then 117 | Witness {summand=s; result } 118 | else 119 | raise @@ Erroneous_arithmetic 120 | { fn="create_exn"; summand=s; erroneous_sum = result } 121 | 122 | let adder: ('a * 'b, 'c) t -> 'a lt -> 'b le -> 'c lt = 123 | fun _p x y -> Unsafe.magic @@ (x:>int) + (y:>int) 124 | 125 | 126 | let ( + ) (x:'a eq) (y:'b eq) : ('a * 'b) summand = [x; y] 127 | let ( =? ) = create 128 | let ( =! 
) = create_exn 129 | end 130 | 131 | let (% 'b eq -> 'b lt option 132 | fun k l -> if to_int k < to_int l then Some(Unsafe.magic k) else None 133 | 134 | let if_ opt f g = match opt with 135 | | Some x -> f x 136 | | None -> g () 137 | 138 | let (%?): 'a lt -> 'a eq -> 'a lt = fun x _y -> x 139 | 140 | let if_inferior (n:int) (nat: 'a eq) (f:'a lt -> 'b) (default:'b) = 141 | (* n must lie in 0 .. nat - 1: a negative n would otherwise become an invalid, negative lt natural *) if 0 <= n && n < to_int nat then 142 | f @@ Unsafe.create n 143 | else 144 | default 145 | -------------------------------------------------------------------------------- /lib/nat.mli: -------------------------------------------------------------------------------- 1 | (** 2 | This module provides natural numbers (non-negative integers) extended with 3 | a type-level predicate on the associated value. 4 | 5 | More precisely, this module assumes that there is a type-level 6 | embedding of integer intervals [%nat S] and that any value of type 7 | nat satisfies the following invariants: 8 | 9 | * For any value [n: ([%nat S ], mark ) nat], 10 | ** if [ mark = [`Lt] ], n < min S 11 | ** if [ mark = [`Lt|`Eq] ], n ≤ min S 12 | ** if [ mark = [`Eq] ], S = \{n\} 13 | 14 | The module {!Nat_defs} provides functions that can create natural numbers 15 | that automatically respect these invariants. Integers up to [10] are also 16 | predefined within this module using a "_" prefix, and a suffix "i" and "p" 17 | for strictly bounded and bounded natural numbers, i.e.: 18 | [ _1 : (_, [`Eq] ) nat], [ _1i : (_, [`Lt] ) nat], [ _1p : (_, [`Eq|`Lt] ) nat] 19 | 20 | Another option is to use the {!Tensority} ppx extension, which comes with literal 21 | support for Nat: 22 | 23 | * k-suffixed literals (e.g. [2k] ) are translated to [( 'n , [`Eq] ) nat] 24 | with the right type representation (e.g. [ `_2 of [`T] ] ) 25 | * j-suffixed literals (e.g. [1j] ) are translated to [('n, [`Lt]) nat ] 26 | * p-suffixed literals (e.g. 
[1p] ) are translated to [('n, [`Lt | `Eq ]) nat ] 27 | 28 | Note that the tensority ppx extension provides two other literal types that are 29 | detailed in the {!Shape} module. 30 | 31 | Unsafe functions can be used to directly create a natural with [Unsafe.create] 32 | or convert a natural from one type to the other [Unsafe.magic]. If the 33 | previous invariants are not respected, the behavior of any function using 34 | these broken naturals is unspecified. 35 | 36 | If possible, it is therefore recommended to create type- and value-checked 37 | dynamic values using the [Dynamic] functor. 38 | *) 39 | 40 | 41 | (** {2 Natural numbers with type-level reflection} *) 42 | 43 | type (+'a, -'b) t = private int 44 | type (+'a,-'b) nat = ('a,'b) t 45 | (** Underneath, a [nat] or [Nat.t] is just an integer *) 46 | 47 | 48 | (** {3 Helper types} *) 49 | 50 | type empty = private Empty_set 51 | type z = private Z 52 | type nz = private NZ 53 | type +'a succ = private Succ 54 | 55 | (** {2 Specialized types} *) 56 | 57 | type 'a lt = ('a, [`Lt]) t 58 | (** A natural [k : [%nat n] lt] is a couple of value [k] and integer interval 59 | type [[%nat n]] such that [ k < min n ] *) 60 | 61 | type 'a eq = ('a, [`Eq]) t 62 | (** A natural [k : [%nat n] eq] is a couple of value [k] and type [[%nat n]] such 63 | that [ k = n ] *) 64 | 65 | type 'a le = ('a, [`Eq|`Lt]) t 66 | (** A natural [k : [%nat n] le] is a couple of value [k] and integer interval 67 | type [[%nat n]] such that [ k ≤ min n ] *) 68 | 69 | 70 | (** {3 Unsafe functions} *) 71 | module Unsafe: sig 72 | type unsafe 73 | val create : int -> ('a, 'b) t 74 | val magic : ('a, 'b) t -> ('c, 'd) t 75 | val eq: int -> unsafe eq 76 | val lt: int -> unsafe lt 77 | end 78 | 79 | (** {3 Conversion and printing } *) 80 | val to_int : ('a, 'b) t -> int 81 | val pp : Format.formatter -> ('a, 'b) t -> unit 82 | val show : ('a, 'b) t -> string 83 | 84 | (** {3 Utility function} *) 85 | val zero : ('a, 'b) t 86 | val succ : ('a, 'b) t 
-> int 87 | 88 | (** {3 Dynamic natural} 89 | The [Dynamic] functor allows one to safely use naturals not known at 90 | compile time. The safety of dynamic naturals is guaranteed by 91 | disabling the possibility to construct ['a lt] natural. 92 | *) 93 | module type dynamic = sig type t val dim: t eq end 94 | module Dynamic : sig val dim : int end -> dynamic 95 | val dynamic: int -> (module dynamic) 96 | 97 | 98 | (** {2 Iter, map, fold functions} *) 99 | (** {3 Iter functions }*) 100 | 101 | (** [iter f nat] computes [f (0: '(< nat)); ...; f (nat - 1: '(< nat) )] *) 102 | val iter : ('a lt -> unit) -> 'a eq -> unit 103 | 104 | (** [iter_on nat f] is [iter f nat] *) 105 | val iter_on : 'a eq -> ('a lt -> unit) -> unit 106 | 107 | (** [map f nat] computes [| f (0: '(< nat)); ...; f (nat - 1: '(< nat) )|] *) 108 | val map : ('a lt -> 'b) -> 'a eq -> 'b array 109 | 110 | 111 | (** [partial_iter ~start ~stop f] computes 112 | [f (start: '(< nat)); ...; f (stop - 1: '(< nat) )]; [start] itself is not checked *) 113 | val partial_iter : start:int -> stop:'a eq -> ('a lt -> unit) -> unit 114 | 115 | (** [typed_partial_iter] requires a proof that [start stop:'a eq -> ('a lt -> unit) -> unit 117 | 118 | (** {3 Fold functions }*) 119 | val fold : ('a -> 'b lt -> 'a) -> 'a -> 'b eq -> 'a 120 | val fold_on: 'b eq -> 'a -> ('a -> 'b lt -> 'a) -> 'a 121 | val partial_fold : 122 | start:int -> stop:'a eq -> acc:'acc -> ('acc -> 'a lt -> 'acc) -> 'acc 123 | 124 | (** {3 Random generator }*) 125 | val rand: Random.State.t -> 'n eq -> 'n lt 126 | 127 | (** {2 Extended proofs } *) 128 | type truth = Truth 129 | 130 | (** Type constraints *) 131 | 132 | val ( %<% ) : 'a lt -> 'a eq -> truth 133 | val ( %? 
) : 'a lt -> 'a eq -> 'a lt 134 | 135 | (** Proofs for arithmetic propositions of the form a + b = c *) 136 | module Sum : sig 137 | (** Raised by [create_exn] when the summands do not add up to the claimed result *) 138 | exception Erroneous_arithmetic of 139 | {fn:string; summand:int list; erroneous_sum:int } 140 | 141 | (** type for summand *) 142 | type 'a summand 143 | 144 | (** type for proof witness: the successful construction of a value 145 | of type [ ([s_1,..s_n],r) t ] implies that ∑ s_i = r *) 146 | type ('a, 'c) t 147 | 148 | (** Safe construction function: returns [None] when the summands do not add up to the result *) 149 | val create : 'a summand -> 'c eq -> ('a, 'c) t option 150 | 151 | (** Operator version of [create] *) 152 | val ( =? ) : 'a summand -> 'b eq -> ('a, 'b) t option 153 | 154 | (** Exception-raising construction function; raises [Erroneous_arithmetic] on mismatch *) 155 | val create_exn : 'a summand -> 'c eq -> ('a, 'c) t 156 | 157 | (** Operator version of [create_exn] *) 158 | val ( =! ) : 'a summand -> 'b eq -> ('a, 'b) t 159 | 160 | (** Pair summation *) 161 | val ( + ) : 'a eq -> 'b eq -> ('a * 'b) summand 162 | 163 | (** Let k, l and n be such that [k + l = n]; then [k' < k] and 164 | [l'≤l] ⇒ [k' + l' < n]. 165 | With a proof that [k + l = n], we can therefore safely add 166 | [k:k lt] and [l: l le] to obtain a natural number [n : n lt]. 
*) 168 | val adder : ('a * 'b ,'c) t -> 'a lt -> 'b le -> 'c lt 169 | end 170 | 171 | val ( % 'b eq -> 'b lt option 172 | (** [if_inferior n nat f default] applies [f] to [n] seen as an ['a lt] when the bound check of [n] against [nat] succeeds, and returns [default] otherwise *) val if_inferior : int -> 'a eq -> ('a lt -> 'b) -> 'b -> 'b 173 | 174 | (** [if_ opt f g] is [f x] when [opt] is [Some x], and [g ()] otherwise *) val if_ : 'a option -> ('a -> 'b) -> (unit -> 'b) -> 'b 175 | -------------------------------------------------------------------------------- /lib/nat_defs.ml: -------------------------------------------------------------------------------- 1 | open Nat 2 | 3 | (* Gt._k: the digit tags k .. 9, without payload *) 4 | module Gt = struct 5 | type _9 = [ `_9 ] 6 | type _8 = [ `_8 | _9 ] 7 | type _7 = [ `_7 | _8 ] 8 | type _6 = [ `_6 | _7 ] 9 | type _5 = [ `_5 | _6 ] 10 | type _4 = [ `_4 | _5 ] 11 | type _3 = [ `_3 | _4 ] 12 | type _2 = [ `_2 | _3 ] 13 | type _1 = [ `_1 | _2 ] 14 | type _0 = [ `_0 | _1 ] 15 | end 16 | (* Lep._k: the digit tags 0 .. k, each carrying a payload *) 17 | module Lep = struct 18 | type +'a _0 = [ `_0 of 'a ] 19 | type +'a _1 = [ `_1 of 'a | 'a _0 ] 20 | type +'a _2 = [ `_2 of 'a | 'a _1 ] 21 | type +'a _3 = [ `_3 of 'a | 'a _2 ] 22 | type +'a _4 = [ `_4 of 'a | 'a _3 ] 23 | type +'a _5 = [ `_5 of 'a | 'a _4 ] 24 | type +'a _6 = [ `_6 of 'a | 'a _5 ] 25 | type +'a _7 = [ `_7 of 'a | 'a _6 ] 26 | type +'a _8 = [ `_8 of 'a | 'a _7 ] 27 | type +'a _9 = [ `_9 of 'a | 'a _8 ] 28 | end 29 | (* Sp_lep._k: the digit tags 1 .. k, each carrying a payload *) 30 | module Sp_lep = struct 31 | type +'a _1 = [ `_1 of 'a ] 32 | type +'a _2 = [ `_2 of 'a | 'a _1 ] 33 | type +'a _3 = [ `_3 of 'a | 'a _2 ] 34 | type +'a _4 = [ `_4 of 'a | 'a _3 ] 35 | type +'a _5 = [ `_5 of 'a | 'a _4 ] 36 | type +'a _6 = [ `_6 of 'a | 'a _5 ] 37 | type +'a _7 = [ `_7 of 'a | 'a _6 ] 38 | type +'a _8 = [ `_8 of 'a | 'a _7 ] 39 | type +'a _9 = [ `_9 of 'a | 'a _8 ] 40 | end 41 | (* Gtp._k: the digit tags k .. 9, each carrying a payload *) 42 | module Gtp = struct 43 | type +'a _9 = [ `_9 of 'a ] 44 | type +'a _8 = [ `_8 of 'a | 'a _9 ] 45 | type +'a _7 = [ `_7 of 'a | 'a _8 ] 46 | type +'a _6 = [ `_6 of 'a | 'a _7 ] 47 | type +'a _5 = [ `_5 of 'a | 'a _6 ] 48 | type +'a _4 = [ `_4 of 'a | 'a _5 ] 49 | type +'a _3 = [ `_3 of 'a | 'a _4 ] 50 | type +'a _2 = [ `_2 of 'a | 'a _3 ] 51 | type +'a _1 = [ `_1 of 'a | 'a _2 ] 
52 | type +'a _0 = [ `_0 of 'a | 'a _1 ] 53 | end 54 | (* all: a subset of the digit tags 0 .. 9, each with a payload; end_ additionally allows the terminal tag `T *) 55 | type (+'a, +'b) all = [< 'a Gtp._0 ] as 'b 56 | type (+'a, +'b) end_ = [< 'a Gtp._0 | `T ] as 'b 57 | type +'args at_least_1 = (('a,'x) end_, 'y) all 58 | constraint 'args = 'a * 'x * 'y 59 | type (+'a,+'res) filter_zero = 60 | [< `_1 of 'b | `_2 of 'c | `_3 of 'd | `_4 of 'e | `_5 of 'f | `_6 of 'g 61 | | `_7 of 'h | `_8 of 'i | `_9 of 'j ] as 'res 62 | constraint 63 | 'a = 'b * 'c *'d *'e *'f *'g *'h * 'i * 'j 64 | 65 | (* Shifter: digit-by-digit construction of naturals whose kind is fixed by K.m *) 66 | module Shifter(K:sig type m end) = struct 67 | type +'a t = ('a, K.m) Nat.t 68 | type (+'d,+'x) s = ('d,'x) all 69 | 70 | let shift k (d,x) = d * 10, Unsafe.create (k*d + Nat.to_int x) (* prepends digit k at weight d: new value k*d + x, next weight d*10 *) 71 | 72 | type ('args,'fx,'aux,'lead) f_gen = 73 | int * ('x * 'd * 'l ) t -> int * ('fx * ('d, 'any) s * 'lead ) t 74 | constraint 75 | 'args = 'x * 'd 76 | constraint 77 | 'aux = 'l * 'any 78 | 79 | type ('x,'fx,'aux) f = ('x,'fx,'aux,nz) f_gen 80 | type ('x,'fx,'aux) f0 = ('x,'fx,'aux,z) f_gen 81 | 82 | (**) 83 | let _9 : ('a * 'd, [< `_9 of 'a | ('d,_) s Lep._8 ], _ ) f = 84 | fun x -> shift 9 x 85 | let _8 : ('a * 'd, [< `_8 of 'a | 'd Gtp._9 | ('d,_) s Lep._7 ], _ ) f = 86 | fun x -> shift 8 x 87 | let _7 : ('a * 'd, [< `_7 of 'a | 'd Gtp._8 | ('d,_) s Lep._6 ], _ ) f = 88 | fun x -> shift 7 x 89 | let _6 : ('a * 'd, [< `_6 of 'a | 'd Gtp._7 | ('d,_) s Lep._5 ], _ ) f = 90 | fun x -> shift 6 x 91 | let _5 : ('a * 'd, [< `_5 of 'a | 'd Gtp._6 | ('d,_) s Lep._4 ], _ ) f = 92 | fun x -> shift 5 x 93 | let _4 : ('a * 'd, [< `_4 of 'a | 'd Gtp._5 | ('d,_) s Lep._3 ], _ ) f = 94 | fun x -> shift 4 x 95 | let _3 : ('a * 'd, [< `_3 of 'a | 'd Gtp._4 | ('d,_) s Lep._2 ], _ ) f = 96 | fun x -> shift 3 x 97 | let _2 : ('a * 'd, [< `_2 of 'a | 'd Gtp._3 | ('d,_) s Lep._1 ], _ ) f = 98 | fun x -> shift 2 x 99 | let _1 : ('a * 'd, [< `_1 of 'a | 'd Gtp._2 | ('d,_) s Lep._0 ], _ ) f = 100 | fun x -> shift 1 x 101 | let _0 : ('a * 'd, [< `_0 of 'a | 'd Gtp._1 ], _ ) f0 = fun x -> shift 0 x 102 | 103 | let nat: 
int * ((_,'x) filter_zero * 'd * nz) t -> 'x t = 104 | fun (_m,n) -> Unsafe.magic n 105 | let nat_z : int * ('x * 'd * nz) t -> 'x t = 106 | fun (_,n) -> Unsafe.magic n 107 | 108 | let (@) f x = f x 109 | end 110 | 111 | module Indices = struct 112 | module K = struct type m = [`Lt] end 113 | let make n = 10, Unsafe.create n 114 | type (+'a,+'any) b = int * ('a * 'any at_least_1 * nz) lt 115 | 116 | let _9n : ( _ at_least_1,'a) b = make 9 117 | let _8n : ([< _ end_ Gtp._9 | _ all Lep._8 ],_) b = make 8 118 | let _7n : ([< _ end_ Gtp._8 | _ all Lep._7],_) b = make 7 119 | let _6n : ([< _ end_ Gtp._7 | _ all Lep._6],_) b = make 6 120 | let _5n : ([< _ end_ Gtp._6 | _ all Lep._5],_) b = make 5 121 | let _4n : ([< _ end_ Gtp._5 | _ all Lep._4],_) b = make 4 122 | let _3n : ([< _ end_ Gtp._4 | _ all Lep._3],_) b = make 3 123 | let _2n : ([< _ end_ Gtp._3 | _ all Lep._2],_) b = make 2 124 | let _1n : ([< _ end_ Gtp._2 | _ all Lep._1],_) b = make 1 125 | let _0n : ([< _ end_ Gtp._1],_) b = make 0 126 | 127 | include Shifter(K) 128 | end 129 | 130 | module Adder = struct 131 | 132 | module K = struct type m = [ `Eq | `Lt ] end 133 | 134 | let make n = 10, Unsafe.create n (* pair: weight of the next prepended digit, current value *) 135 | type (+'a,+'any) b = int * ('a * 'any at_least_1 * nz) le 136 | 137 | let _9n : ([< _ end_ Gtp._9 | _ all Lep._8 ],_) b = make 9 138 | let _8n : ([< _ end_ Gtp._8 | _ all Lep._7],_) b = make 8 139 | let _7n : ([< _ end_ Gtp._7 | _ all Lep._6],_) b = make 7 140 | let _6n : ([< _ end_ Gtp._6 | _ all Lep._5],_) b = make 6 141 | let _5n : ([< _ end_ Gtp._5 | _ all Lep._4],_) b = make 5 142 | let _4n : ([< _ end_ Gtp._4 | _ all Lep._3],_) b = make 4 143 | let _3n : ([< _ end_ Gtp._3 | _ all Lep._2],_) b = make 3 144 | let _2n : ([< _ end_ Gtp._2 | _ all Lep._1],_) b = make 2 145 | let _1n : ([< _ end_ Gtp._1],_) b = make 1 146 | let _0n : ('any,_) b = make 0 (* digit zero carries the value 0, as Indices._0n and Size._0n do; this feeds _0p and _10p *) 147 | 148 | 149 | include Shifter(K) 150 | 151 | end 152 | 153 | module Size = struct 154 | let make n = 10, Unsafe.create n 155 | type 'a s = int * 
('a * nz) eq 156 | let _9n : [ `_9 of [`T] ] s = make 9 157 | let _8n : [ `_8 of [`T] ] s = make 8 158 | let _7n : [ `_7 of [`T] ] s = make 7 159 | let _6n : [ `_6 of [`T] ] s = make 6 160 | let _5n : [ `_5 of [`T] ] s = make 5 161 | let _4n : [ `_4 of [`T] ] s = make 4 162 | let _3n : [ `_3 of [`T] ] s = make 3 163 | let _2n : [ `_2 of [`T] ] s = make 2 164 | let _1n : [ `_1 of [`T] ] s = make 1 165 | let _0n : [ `_0 of [`T] ] s = make 0 166 | 167 | 168 | let shift k (d,x) = 10 * d, Unsafe.create (k*d + Nat.to_int x) (* same step as Shifter.shift: new value k*d + x, next weight 10*d *) 169 | 170 | type ('x,'fx,'any) d = 171 | int * ('x * 'l) eq -> int * ('fx * nz) eq 172 | constraint 'any = 'l 173 | 174 | type ('x,'fx,'any) d0 = 175 | int * ('x * 'l) eq -> int * ('fx * z) eq 176 | constraint 'any = 'l 177 | 178 | let _9 : ('a,[ `_9 of 'a ],_) d = fun x -> shift 9 x 179 | let _8 : ('a,[ `_8 of 'a ],_) d = fun x -> shift 8 x 180 | let _7 : ('a,[ `_7 of 'a ],_) d = fun x -> shift 7 x 181 | let _6 : ('a,[ `_6 of 'a ],_) d = fun x -> shift 6 x 182 | let _5 : ('a,[ `_5 of 'a ],_) d = fun x -> shift 5 x 183 | let _4 : ('a,[ `_4 of 'a ],_) d = fun x -> shift 4 x 184 | let _3 : ('a,[ `_3 of 'a ],_) d = fun x -> shift 3 x 185 | let _2 : ('a,[ `_2 of 'a ],_) d = fun x -> shift 2 x 186 | let _1 : ('a,[ `_1 of 'a ],_) d = fun x -> shift 1 x 187 | let _0 : ('a,[ `_0 of 'a ],_) d0 = fun x -> shift 0 x 188 | 189 | 190 | let nat: int * ('digits * nz) eq 191 | -> 'digits eq 192 | = 193 | fun (_m,n) -> Unsafe.magic n 194 | 195 | let (@) f x = f x 196 | end 197 | (* _10 prepends digit 1 to _0n: shift 1 (10, 0) gives the value 10 *) 198 | let _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10 = 199 | let open Size in 200 | let (!) x = nat x in 201 | !_0n, !_1n, ! _2n, !_3n, !_4n, !_5n, !_6n, !_7n, !_8n, !_9n, !(_1 @ _0n) 202 | 203 | (* x, y, z, t are the index literals 0 .. 3 built from Indices *) 204 | let x, y, z, t = 205 | let open Indices in 206 | let (!) x = (nat x) and (!!) x = (nat_z x) in 207 | Indices.( !!_0n, !_1n, !_2n, !_3n ) 208 | 209 | (* _0i .. _3i reuse x, y, z, t *) 210 | let _0i, _1i, _2i, _3i, _4i, _5i, _6i, _7i, _8i, _9i, _10i = 211 | let open Indices in 212 | let (!) 
x= nat x in 213 | x, y, z, t, !_4n, !_5n, !_6n, !_7n, !_8n, ! _9n, !(_1 @ _0n) 214 | (* Bounded (le) literals built from the Adder digits; _10p prepends digit 1 to Adder._0n *) 215 | let _0p, _1p, _2p, _3p, _4p, _5p, _6p, _7p, _8p, _9p, _10p = 216 | let open Adder in 217 | let (!) = Adder.nat 218 | and (!!) = Adder.nat_z 219 | in 220 | !!_0n, !_1n, !_2n, !_3n, !_4n, !_5n, !_6n, !_7n, !_8n, !_9n, !(_1 @ _0n) 221 | -------------------------------------------------------------------------------- /lib/nat_defs.mli: -------------------------------------------------------------------------------- 1 | open Nat 2 | module Gt : 3 | sig 4 | type _9 = [ `_9 ] 5 | type _8 = [ `_8 | _9 ] 6 | type _7 = [ `_7 | _8 ] 7 | type _6 = [ `_6 | _7 ] 8 | type _5 = [ `_5 | _6 ] 9 | type _4 = [ `_4 | _5 ] 10 | type _3 = [ `_3 | _4 ] 11 | type _2 = [ `_2 | _3 ] 12 | type _1 = [ `_1 | _2 ] 13 | type _0 = [ `_0 | _1 ] 14 | end 15 | 16 | module Lep : 17 | sig 18 | type +'a _0 = [ `_0 of 'a ] 19 | type +'a _1 = [ `_1 of 'a | 'a _0 ] 20 | type +'a _2 = [ `_2 of 'a | 'a _1 ] 21 | type +'a _3 = [ `_3 of 'a | 'a _2 ] 22 | type +'a _4 = [ `_4 of 'a | 'a _3 ] 23 | type +'a _5 = [ `_5 of 'a | 'a _4 ] 24 | type +'a _6 = [ `_6 of 'a | 'a _5 ] 25 | type +'a _7 = [ `_7 of 'a | 'a _6 ] 26 | type +'a _8 = [ `_8 of 'a | 'a _7 ] 27 | type +'a _9 = [ `_9 of 'a | 'a _8 ] 28 | end 29 | 30 | module Sp_lep : 31 | sig 32 | type +'a _1 = [ `_1 of 'a ] 33 | type +'a _2 = [ `_2 of 'a | 'a _1 ] 34 | type +'a _3 = [ `_3 of 'a | 'a _2 ] 35 | type +'a _4 = [ `_4 of 'a | 'a _3 ] 36 | type +'a _5 = [ `_5 of 'a | 'a _4 ] 37 | type +'a _6 = [ `_6 of 'a | 'a _5 ] 38 | type +'a _7 = [ `_7 of 'a | 'a _6 ] 39 | type +'a _8 = [ `_8 of 'a | 'a _7 ] 40 | type +'a _9 = [ `_9 of 'a | 'a _8 ] 41 | end 42 | 43 | module Gtp : 44 | sig 45 | type +'a _9 = [ `_9 of 'a ] 46 | type +'a _8 = [ `_8 of 'a | 'a _9 ] 47 | type +'a _7 = [ `_7 of 'a | 'a _8 ] 48 | type +'a _6 = [ `_6 of 'a | 'a _7 ] 49 | type +'a _5 = [ `_5 of 'a | 'a _6 ] 50 | type +'a _4 = [ `_4 of 'a | 'a _5 ] 51 | type +'a _3 = [ `_3 of 'a | 'a _4 ] 52 | type +'a _2 = 
[ `_2 of 'a | 'a _3 ] 53 | type +'a _1 = [ `_1 of 'a | 'a _2 ] 54 | type +'a _0 = [ `_0 of 'a | 'a _1 ] 55 | end 56 | 57 | type ('a, +'b) all = [< 'a Gtp._0 ] as 'b 58 | type ('a, +'b) end_ = [< 'a Gtp._0 | `T ] as 'b 59 | 60 | type +'args at_least_1 = (('a,'x) end_, 'y) all 61 | constraint 'args = 'a * 'x * 'y 62 | 63 | type (+'a,+'res) filter_zero = 64 | [< `_1 of 'b | `_2 of 'c | `_3 of 'd | `_4 of 'e | `_5 of 'f | `_6 of 'g 65 | | `_7 of 'h | `_8 of 'i | `_9 of 'j ] as 'res 66 | constraint 67 | 'a = 'b * 'c *'d *'e *'f *'g *'h * 'i * 'j 68 | (* Digit builders for strictly bounded (lt) naturals, used for x, y, z, t and the _ki literals *) 69 | module Indices : 70 | sig 71 | 72 | type ('a, +'b) b = int * ('a * 'b at_least_1 * nz) lt 73 | 74 | val _9n : ( _ at_least_1,'a) b 75 | val _8n : ([< _ end_ Gtp._9 | _ all Lep._8 ],_) b 76 | val _7n : ([< _ end_ Gtp._8 | _ all Lep._7],_) b 77 | val _6n : ([< _ end_ Gtp._7 | _ all Lep._6],_) b 78 | val _5n : ([< _ end_ Gtp._6 | _ all Lep._5],_) b 79 | val _4n : ([< _ end_ Gtp._5 | _ all Lep._4],_) b 80 | val _3n : ([< _ end_ Gtp._4 | _ all Lep._3],_) b 81 | val _2n : ([< _ end_ Gtp._3 | _ all Lep._2],_) b 82 | val _1n : ([< _ end_ Gtp._2 | _ all Lep._1],_) b 83 | val _0n : ([< _ end_ Gtp._1],_) b 84 | 85 | type ('d, +'a) s = ('d, 'a) all constraint 'a = [< 'd Gtp._0 ] 86 | 87 | type ('a, 'fx, 'b, 'lead) f_gen = 88 | int * ('x * 'd * 'l) lt -> int * ('fx * ('d, 'c) s * 'lead) lt 89 | constraint 'a = 'x * 'd 90 | constraint 'b = 'l * ([< 'd Gtp._0 ] as 'c) 91 | type ('a, 'fx, 'b) f = ('a, 'fx, 'b, nz) f_gen constraint 'a = 'g * 'f 92 | constraint 'b = 'c * ([< 'f Gtp._0 ] as 'e) 93 | type ('a, 'fx, 'b) f0 = ('a, 'fx, 'b, z) f_gen constraint 'a = 'g * 'f 94 | constraint 'b = 'c * ([< 'f Gtp._0 ] as 'e) 95 | 96 | 97 | val _9 : ('a * 'd, [< `_9 of 'a | ('d,_) s Lep._8 ], _ ) f 98 | val _8 : ('a * 'd, [< `_8 of 'a | 'd Gtp._9 | ('d,_) s Lep._7 ], _ ) f 99 | val _7 : ('a * 'd, [< `_7 of 'a | 'd Gtp._8 | ('d,_) s Lep._6 ], _ ) f 100 | val _6 : ('a * 'd, [< `_6 of 'a | 'd Gtp._7 | ('d,_) s Lep._5 ], _ ) f 101 | val _5 : ('a * 
'd, [< `_5 of 'a | 'd Gtp._6 | ('d,_) s Lep._4 ], _ ) f 102 | val _4 : ('a * 'd, [< `_4 of 'a | 'd Gtp._5 | ('d,_) s Lep._3 ], _ ) f 103 | val _3 : ('a * 'd, [< `_3 of 'a | 'd Gtp._4 | ('d,_) s Lep._2 ], _ ) f 104 | val _2 : ('a * 'd, [< `_2 of 'a | 'd Gtp._3 | ('d,_) s Lep._1 ], _ ) f 105 | val _1 : ('a * 'd, [< `_1 of 'a | 'd Gtp._2 | ('d,_) s Lep._0 ], _ ) f 106 | val _0 : ('a * 'd, [< `_0 of 'a | 'd Gtp._1 ], _ ) f0 107 | 108 | val nat: int * ((_,'x) filter_zero * 'd * nz) lt -> 'x lt 109 | val nat_z : int * ('x * 'd * nz) lt -> 'x lt 110 | 111 | val ( @ ) : ('a -> 'b) -> 'a -> 'b 112 | end 113 | (* Digit builders for bounded (le) naturals, used for the _kp literals *) 114 | module Adder : 115 | sig 116 | 117 | type (+'a,+'any) b = int * ('a * 'any at_least_1 * nz) le 118 | 119 | 120 | val _9n : ([< _ end_ Gtp._9 | _ all Lep._8 ],_) b 121 | val _8n : ([< _ end_ Gtp._8 | _ all Lep._7],_) b 122 | val _7n : ([< _ end_ Gtp._7 | _ all Lep._6],_) b 123 | val _6n : ([< _ end_ Gtp._6 | _ all Lep._5],_) b 124 | val _5n : ([< _ end_ Gtp._5 | _ all Lep._4],_) b 125 | val _4n : ([< _ end_ Gtp._4 | _ all Lep._3],_) b 126 | val _3n : ([< _ end_ Gtp._3 | _ all Lep._2],_) b 127 | val _2n : ([< _ end_ Gtp._2 | _ all Lep._1],_) b 128 | val _1n : ([< _ end_ Gtp._1],_) b 129 | val _0n : ('any,_) b 130 | 131 | type ('d, +'a) s = ('d, 'a) all constraint 'a = [< 'd Gtp._0 ] 132 | 133 | type ('a, 'fx, 'b, 'lead) f_gen = 134 | int * ('x * 'd * 'l) le -> int * ('fx * ('d, 'c) s * 'lead) le 135 | constraint 'a = 'x * 'd 136 | constraint 'b = 'l * ([< 'd Gtp._0 ] as 'c) 137 | type ('a, 'fx, 'b) f = ('a, 'fx, 'b, nz) f_gen constraint 'a = 'g * 'f 138 | constraint 'b = 'c * ([< 'f Gtp._0 ] as 'e) 139 | type ('a, 'fx, 'b) f0 = ('a, 'fx, 'b, z) f_gen constraint 'a = 'g * 'f 140 | constraint 'b = 'c * ([< 'f Gtp._0 ] as 'e) 141 | 142 | 143 | val _9 : ('a * 'd, [< `_9 of 'a | ('d,_) s Lep._8 ], _ ) f 144 | val _8 : ('a * 'd, [< `_8 of 'a | 'd Gtp._9 | ('d,_) s Lep._7 ], _ ) f 145 | val _7 : ('a * 'd, [< `_7 of 'a | 'd Gtp._8 | ('d,_) s Lep._6 ], _ ) f 146 | val _6 : 
('a * 'd, [< `_6 of 'a | 'd Gtp._7 | ('d,_) s Lep._5 ], _ ) f 147 | val _5 : ('a * 'd, [< `_5 of 'a | 'd Gtp._6 | ('d,_) s Lep._4 ], _ ) f 148 | val _4 : ('a * 'd, [< `_4 of 'a | 'd Gtp._5 | ('d,_) s Lep._3 ], _ ) f 149 | val _3 : ('a * 'd, [< `_3 of 'a | 'd Gtp._4 | ('d,_) s Lep._2 ], _ ) f 150 | val _2 : ('a * 'd, [< `_2 of 'a | 'd Gtp._3 | ('d,_) s Lep._1 ], _ ) f 151 | val _1 : ('a * 'd, [< `_1 of 'a | 'd Gtp._2 | ('d,_) s Lep._0 ], _ ) f 152 | val _0 : ('a * 'd, [< `_0 of 'a | 'd Gtp._1 ], _ ) f0 153 | 154 | val nat: int * ((_,'x) filter_zero * 'd * nz) le -> 'x le 155 | val nat_z : int * ('x * 'd * nz) le -> 'x le 156 | 157 | val ( @ ) : ('a -> 'b) -> 'a -> 'b 158 | end 159 | (* Digit builders for exact (eq) sizes, used for the _0 .. _10 literals *) 160 | module Size : 161 | sig 162 | type 'a s = int * ('a * nz) eq 163 | val _9n : [ `_9 of [ `T ] ] s 164 | val _8n : [ `_8 of [ `T ] ] s 165 | val _7n : [ `_7 of [ `T ] ] s 166 | val _6n : [ `_6 of [ `T ] ] s 167 | val _5n : [ `_5 of [ `T ] ] s 168 | val _4n : [ `_4 of [ `T ] ] s 169 | val _3n : [ `_3 of [ `T ] ] s 170 | val _2n : [ `_2 of [ `T ] ] s 171 | val _1n : [ `_1 of [ `T ] ] s 172 | val _0n : [ `_0 of [ `T ] ] s 173 | 174 | type ('x, 'fx, 'l) d = int * ('x * 'l) eq -> int * ('fx * nz) eq 175 | type ('x, 'fx, 'l) d0 = int * ('x * 'l) eq -> int * ('fx * z) eq 176 | val _9 : ('a, [ `_9 of 'a ], 'b) d 177 | val _8 : ('a, [ `_8 of 'a ], 'b) d 178 | val _7 : ('a, [ `_7 of 'a ], 'b) d 179 | val _6 : ('a, [ `_6 of 'a ], 'b) d 180 | val _5 : ('a, [ `_5 of 'a ], 'b) d 181 | val _4 : ('a, [ `_4 of 'a ], 'b) d 182 | val _3 : ('a, [ `_3 of 'a ], 'b) d 183 | val _2 : ('a, [ `_2 of 'a ], 'b) d 184 | val _1 : ('a, [ `_1 of 'a ], 'b) d 185 | val _0 : ('a, [ `_0 of 'a ], 'b) d0 186 | val nat : int * ('digits * nz) eq -> 'digits eq 187 | val ( @ ) : ('a -> 'b) -> 'a -> 'b 188 | end 189 | 190 | (* Exact size literals 0 .. 10 *) 191 | val _0 : [ `_0 of [ `T ] ] eq 192 | val _1 : [ `_1 of [ `T ] ] eq 193 | val _2 : [ `_2 of [ `T ] ] eq 194 | val _3 : [ `_3 of [ `T ] ] eq 195 | val _4 : [ `_4 of [ `T ] ] eq 196 | val _5 : [ `_5 of [ 
`T ] ] eq 197 | val _6 : [ `_6 of [ `T ] ] eq 198 | val _7 : [ `_7 of [ `T ] ] eq 199 | val _8 : [ `_8 of [ `T ] ] eq 200 | val _9 : [ `_9 of [ `T ] ] eq 201 | val _10 : [ `_1 of [ `_0 of [ `T ] ] ] eq 202 | (* x, y, z, t are the index literals 0 .. 3, with the same types as _0i .. _3i *) 203 | val x : 204 | [< (_,_ ) end_ Gtp._1 ] lt 205 | 206 | val y : 207 | [< `_1 of (_, _) all | (_,_) end_ Gtp._2] lt 208 | 209 | val z : 210 | [< (_, _) all Sp_lep._2 | (_,_) end_ Gtp._3] lt 211 | 212 | val t : 213 | [< (_, _) all Sp_lep._3 | (_,_) end_ Gtp._4] lt 214 | 215 | val _0i : 216 | [< (_,_ ) end_ Gtp._1 ] lt 217 | 218 | val _1i : 219 | [< `_1 of (_, _) all | (_,_) end_ Gtp._2] lt 220 | 221 | val _2i : 222 | [< (_, _) all Sp_lep._2 | (_,_) end_ Gtp._3] lt 223 | 224 | val _3i : 225 | [< (_, _) all Sp_lep._3 | (_,_) end_ Gtp._4] lt 226 | 227 | val _4i : 228 | [< (_, _) all Sp_lep._4 | (_,_) end_ Gtp._5] lt 229 | 230 | val _5i : 231 | [< (_, _) all Sp_lep._5 | (_,_) end_ Gtp._6] lt 232 | 233 | val _6i : 234 | [< (_, _) all Sp_lep._6 | (_,_) end_ Gtp._7] lt 235 | 236 | val _7i : 237 | [< (_, _) all Sp_lep._7 | (_,_) end_ Gtp._8] lt 238 | 239 | val _8i : 240 | [< (_, _) all Sp_lep._8 | (_,_) end_ Gtp._9] lt 241 | 242 | val _9i : 243 | [< (_, _) all Sp_lep._9 ] lt 244 | 245 | val _10i : 246 | [< `_1 of [< (_,_) end_ Gtp._1] | _ at_least_1 Gtp._2 ] lt 247 | (* Bounded (le) literals *) 248 | val _0p : 'a le 249 | val _1p : 250 | [< (_,_) end_ Gtp._1 ] le 251 | val _2p : 252 | [< `_1 of (_,_) all | (_,_) end_ Gtp._2 ] le 253 | val _3p : 254 | [< (_,_) all Sp_lep._2 | (_,_) end_ Gtp._3 ] le 255 | val _4p : 256 | [< (_,_) all Sp_lep._3 | (_,_) end_ Gtp._4 ] le 257 | val _5p : 258 | [< (_,_) all Sp_lep._4 | (_,_) end_ Gtp._5 ] le 259 | val _6p : 260 | [< (_,_) all Sp_lep._5 | (_,_) end_ Gtp._6 ] le 261 | val _7p : 262 | [< (_,_) all Sp_lep._6 | (_,_) end_ Gtp._7 ] le 263 | val _8p : 264 | [< (_,_) all Sp_lep._7 | (_,_) end_ Gtp._8 ] le 265 | val _9p : 266 | [< (_,_) all Sp_lep._8 | (_,_) end_ Gtp._9 ] le 267 | val _10p : 268 | [< `_1 of _ | _ at_least_1 Gtp._1 ] le 269 | 
-------------------------------------------------------------------------------- /lib/range.ml: -------------------------------------------------------------------------------- 1 | type ('n_in, 'n_out) t = { start:int; stop:int; step:int } 2 | 3 | let create ~start ~stop ~step ~len = 4 | let diff = Nat.to_int stop - Nat.to_int start in 5 | let dyn_len = 1 + diff / step and len = Nat.to_int len in 6 | if len <> dyn_len then 7 | raise @@ Signatures.Dimension_error 8 | ("Slices.range.create", dyn_len , len ) 9 | else 10 | {start= Nat.to_int start ; stop= Nat.to_int stop; step } 11 | 12 | let start r = Nat.Unsafe.create r.start 13 | let stop r = Nat.Unsafe.create r.stop 14 | let step r = r.step 15 | let len r = Nat.Unsafe.create @@ 1 + (r.stop - r.start) / r.step 16 | (* composition maps an index of r2 through r2 and then r1, so the endpoints of r2 are transposed by r1: both offsets must be scaled by r1.step, which also keeps the length equal to len r2 *) let compose r1 r2 = 17 | { start = r1.start + r1.step * r2.start 18 | ; stop = r1.start + r1.step * r2.stop 19 | ; step = r1.step * r2.step 20 | } 21 | let transpose r p = Nat.Unsafe.create @@ r.start + r.step * Nat.to_int p 22 | 23 | let (--) start stop len = create ~start ~stop ~len ~step:1 24 | let (-->) start stop (step,len) = create ~start ~stop ~len ~step 25 | 26 | let pp ppf r = 27 | Format.fprintf ppf "@[[%a->%a by %d (%a)]@]" 28 | Nat.pp (start r) Nat.pp (stop r) 29 | (step r) 30 | Nat.pp (len r) 31 | 32 | let show r= 33 | Format.asprintf "%a" pp r 34 | -------------------------------------------------------------------------------- /lib/range.mli: -------------------------------------------------------------------------------- 1 | 2 | type (+'in_,+'out) t 3 | 4 | val create: start:'a Nat.lt -> stop: 'a Nat.lt -> step:int -> len:'b Nat.eq 5 | -> ('a,'b) t 6 | 7 | val start: ('a,_) t -> 'a Nat.lt 8 | val stop: ('a,_) t -> 'a Nat.lt 9 | val step: _ t -> int 10 | val len: (_,'b) t -> 'b Nat.eq 11 | (** [compose r1 r2] is the range such that [transpose (compose r1 r2) p = transpose r1 (transpose r2 p)] *) 12 | val compose: ('a,'b) t -> ('b,'c) t -> ('a,'c) t 13 | val transpose: ('a,'b) t -> 'b Nat.lt -> 'a Nat.lt 14 | 15 | val (--): 'a Nat.lt -> 'a Nat.lt -> 'b Nat.eq -> ('a,'b) t 16 | val (-->): 'a Nat.lt -> 'a Nat.lt -> 
(int * 'b Nat.eq) -> ('a,'b) t 17 | 18 | val pp: Format.formatter -> ('a,'b) t -> unit 19 | val show: ('a,'b) t -> string 20 | -------------------------------------------------------------------------------- /lib/shape.ml: -------------------------------------------------------------------------------- 1 | type 'a succ = 'a Nat.succ 2 | type z = Nat.z 3 | type nil = private Nil 4 | 5 | type empty = z * nil 6 | 7 | type (_,_) t = 8 | (::): ('nat,'kind) Nat.t * ('n * 'l, 'kind ) t 9 | ->( 'n succ * ('nat * 'l), 'kind) t 10 | | []: (empty, 'any) t 11 | 12 | type 'a eq = ('a, [`Eq]) t 13 | type 'a lt = ('a, [`Lt]) t 14 | type 'a l = 'a eq 15 | 16 | type 'a single = z succ * ('a * nil) 17 | type ('a, 'b) pair = z succ succ * ( 'a * ('b * nil) ) 18 | type ('a, 'b, 'c) triple = z succ succ succ * ( 'a * ('b * ('c * nil))) 19 | 20 | let rec order:type sh. (sh,'k) t -> int = function 21 | | [] -> 0 22 | | _::q -> 1 + order q 23 | (* Product of all dimensions; 1 for the empty shape *) 24 | let rec size: type sh. sh eq -> int = function 25 | | [] -> 1 26 | | nat::sh -> Nat.to_int nat * size sh 27 | (* Same shape with every coordinate set to Nat.zero *) 28 | let rec zero: type sh. sh eq -> sh lt = function 29 | | _ :: q -> Nat.zero :: zero q 30 | | [] -> [] 31 | 32 | (** {2 Splitting } *) 33 | 34 | 35 | let split_1 = 36 | function 37 | | nat :: q -> nat, q 38 | 39 | let tail_1 = function 40 | | _ :: q -> q 41 | 42 | 43 | (** {2 Iter, map and fold } *) 44 | (* Visits every multi-index in lexicographic order, last coordinate varying fastest *) 45 | let rec iter: type sh. (sh lt -> unit) -> sh eq -> unit = fun f sh -> 46 | match sh with 47 | | [] -> f [] 48 | | a :: sh -> 49 | Nat.iter_on a ( fun nat -> 50 | iter (fun sh -> f (nat :: sh) ) sh 51 | ) 52 | 53 | let iter_on shape f = iter f shape 54 | (* Folds f over the dimensions as plain ints, first to last *) 55 | let rec fold: type l. ('a -> int -> 'a) -> 'a -> l eq -> 'a = 56 | fun f acc -> function 57 | | [] -> acc 58 | | n::q -> fold f (f acc @@ Nat.to_int n) q 59 | 60 | let rec fold_left: type sh. 
('a -> sh lt -> 'a ) -> 'a -> sh eq -> 'a = 61 | fun f acc -> function 62 | | [] -> acc 63 | | n::q -> 64 | let inner acc n = fold_left (fun acc sh -> f acc (n::sh)) acc q in 65 | Nat.fold inner acc n 66 | 67 | (* Like iter, with down level called before and up level called after each traversal of the dimension at depth level *) 68 | let iter_jmp ~up ~down ~f shape = 69 | let rec iter: type sh. up:(int -> unit) -> down:(int->unit) ->f:(sh lt -> unit) 70 | -> level:int -> sh eq -> unit = 71 | fun ~up ~down ~f ~level -> 72 | function 73 | | [] -> f [] 74 | | a :: sh -> 75 | down level 76 | ; Nat.iter_on a 77 | (fun nat -> iter ~up ~down ~level:(level + 1) 78 | ~f:(fun sh -> f (nat::sh) ) sh ) 79 | ; up level 80 | 81 | in 82 | iter ~f ~up ~down ~level:0 shape 83 | 84 | (* NOTE(review): the if_ condition below looks truncated; also, as written, sep level runs before every f call of sub-iterations 1 .. n-1 rather than once between consecutive sub-iterations; confirm intent *) 85 | let iter_sep ~up ~down ~sep ~f shape = 86 | let rec iter: type sh. 87 | sep:(int -> unit) -> f:(sh lt -> unit) -> level:int -> sh eq -> unit = 88 | fun ~sep ~f ~level -> 89 | let one = Nat_defs._1 in 90 | function 91 | | [] -> f [] 92 | | n :: sh -> 93 | let sub_iter f nat = 94 | iter ~level:(level+1) ~sep ~f:(fun sh -> f @@ (nat) :: sh) sh 95 | in 96 | down level 97 | ; Nat.(if_ (one % 98 | sub_iter f Nat.zero 99 | ; Nat.typed_partial_iter ~start:one ~stop:n 100 | (sub_iter (fun sh -> sep level; f sh)) 101 | ) 102 | ( fun () -> sub_iter f Nat.zero) 103 | ; up level in 104 | iter ~sep ~f ~level:0 shape 105 | 106 | 107 | let pp ppf shape = 108 | let rec inner: type sh. 
Format.formatter -> (sh,'k) t -> unit = 109 | fun ppf -> 110 | function 111 | | [] -> () 112 | | [a] -> Format.fprintf ppf "%a" Nat.pp a 113 | | a :: q -> (* separator only between elements, no trailing space before the closing parenthesis *) 114 | Format.fprintf ppf "%a;@ %a" Nat.pp a inner q 115 | in 116 | Format.fprintf ppf "@[(%a)@]" inner shape 117 | -------------------------------------------------------------------------------- /lib/shape.mli: -------------------------------------------------------------------------------- 1 | (** 2 | 3 | Building upon the {!Tensority.Nat} module, this module extends the 4 | mapping between integer intervals and types to a mapping [[%shape]] between 5 | fixed-size lists of integer intervals and types, and provides a generic 6 | type [('a,'b) t] for fixed-size integer lists with type-level 7 | predicates. 8 | 9 | More precisely a value [ [a_1:(α_1,'k) Nat.t; ...; a_n:(α_n, 'k) Nat.t] ] 10 | is mapped to the type 11 | [ (Nat.z Nat.succ^{(n)} * (α_1 * (α_2 * ... * (α_n * nil)...)), 'k) t ]. 12 | 13 | 14 | At the predicate level, it is useful to define the partial order [(≺)] 15 | over fixed-size lists defined by 16 | {[ 17 | let rec (≺) l l' = match l, l' with 18 | | a :: q, b :: q' -> a < b && q ≺ q' 19 | | [], [] -> true 20 | ]} 21 | 22 | This module then implements two types of predicates: 23 | 24 | - The equality predicate is associated with the subtype [ 'a eq = ('a, [`Eq]) t ]. 25 | For this subtype, for any value [( l : 'a eq) ], we have [ [%shape 'a] = l ]. 26 | Equality lists are useful for representing the shape of a multidimensional 27 | array at both the type and value level. 28 | 29 | - The inferiority predicate is associated with the subtype 30 | [ 'a lt = ('a, [`Lt]) t]. For this subtype, for any value [(l : 'a lt)], we 31 | have the property that [ l ≺ [%shape 'a] ]. Such a subtype is useful to represent 32 | multi-indices for multidimensional arrays: given two values [(s:'a eq)] 33 | and [(i:'a lt)], the previous property ensures that [( i ≺ s )]. 
34 | 35 | *) 36 | 37 | (** {2 Main type definitions} 38 | {3 Auxiliary types} *) 39 | 40 | type nil = private Nil 41 | (** Phantom utility type: its only constructor is private, so client code cannot build values of it; it marks the end of type-level lists (see [empty] and [single]) *) 42 | 43 | type empty = Nat.z * nil 44 | (** [empty] is the type-level equivalent to 0, [] *) 45 | 46 | 47 | (** {3 Main types} *) 48 | 49 | 50 | type (_,_) t = 51 | (::) : 52 | ('nat,'kind) Nat.t 53 | * ('n * 'l, 'kind ) t 54 | -> ( 'n Nat.succ * ('nat * 'l), 'kind) t 55 | | []: (empty, 'any) t 56 | 57 | (** Size subtype *) 58 | type 'a eq = ('a, [`Eq] ) t 59 | type 'a l = 'a eq 60 | 61 | (** Index subtype *) 62 | type 'a lt = ('a, [`Lt] ) t 63 | 64 | (**{3 Helper types}*) 65 | 66 | type 'a single = Nat.z Nat.succ * ('a * nil) 67 | (** ['a single] is the inner type of a size or index shape with only 68 | one element *) 69 | 70 | type ('a, 'b) pair = Nat.z Nat.succ Nat.succ * ( 'a * ('b * nil) ) 71 | (** [('a,'b) pair] is the inner type of a size or index shape with 72 | two elements *) 73 | 74 | 75 | type ('a, 'b, 'c) triple = Nat.z Nat.succ Nat.succ Nat.succ 76 | * ( 'a * ('b * ('c * nil))) 77 | (** [('a,'b,'c) triple] is the inner type of a size or index shape with 78 | three elements *) 79 | 80 | (** Compute the order, i.e. 
the number of elements of a shape *) 81 | val order : ('sh,'any) t -> int 82 | 83 | (** Compute the total physical size associated with a size shape *) 84 | val size : 'sh eq -> int 85 | 86 | (** Computes the first multi-index associated to a shape *) 87 | val zero : 'sh eq -> 'sh lt 88 | 89 | 90 | (** {2 Splitting } *) 91 | 92 | 93 | (** Split a size shape into its first element and the remaining elements *) 94 | val split_1: ('n Nat.succ * ('a * 'b)) eq -> 'a Nat.eq * ('n * 'b) eq 95 | 96 | 97 | (** Take the tail of a shape after discarding the first element *) 98 | val tail_1: ('n Nat.succ * ('a * 'b)) eq -> ('n * 'b) eq 99 | 100 | (** {2 Iter, map and fold } *) 101 | 102 | val iter : ('sh lt -> unit) -> 'sh eq -> unit 103 | val iter_on : 'a eq -> ('a lt -> unit) -> unit 104 | 105 | val iter_jmp : 106 | up:(int -> unit) -> 107 | down:(int -> unit) -> f:('a lt -> unit) -> 'a eq -> unit 108 | 109 | val iter_sep : 110 | up:(int -> unit) -> down:(int -> unit) -> sep:(int -> unit) 111 | -> f:('a lt -> unit) -> 'a eq -> unit 112 | 113 | 114 | val fold : ('a -> int -> 'a) -> 'a -> 'l eq -> 'a 115 | val fold_left : ('a -> 'sh lt -> 'a) -> 'a -> 'sh eq -> 'a 116 | 117 | (** {2 Pretty printing } *) 118 | 119 | val pp: Format.formatter -> ('a, 'k) t -> unit 120 | -------------------------------------------------------------------------------- /lib/signatures.ml: -------------------------------------------------------------------------------- 1 | module type base_operators = 2 | sig 3 | type 'a t 4 | val ( + ) : 'a t -> 'a t -> 'a t 5 | val ( - ) : 'a t -> 'a t -> 'a t 6 | val ( |*| ) : 'a t -> 'a t -> float 7 | val ( *. ) : float -> 'a t -> 'a t 8 | val ( /.
) : 'a t -> float -> 'a t 9 | val ( ~- ) : 'a t -> 'a t 10 | end 11 | 12 | module type vec_operators= 13 | sig 14 | include base_operators 15 | val (.%()): 'a t -> 'a Nat.lt -> float 16 | val (.%()<-): 'a t -> 'a Nat.lt -> float -> unit 17 | end 18 | 19 | module type matrix_specific_operators = sig 20 | type 'a t constraint 'a = 'b * 'c 21 | type 'a vec 22 | val ( @ ) : ('a * 'b) t -> 'b vec -> 'a vec 23 | val ( * ) : ('a * 'b) t -> ('b * 'c) t -> ('a * 'c) t 24 | val ( **^ ): ('a * 'a ) t -> int -> ('a * 'a ) t 25 | end 26 | 27 | module type matrix_operators = 28 | sig 29 | include matrix_specific_operators 30 | module Matrix_specific: matrix_specific_operators with 31 | type 'a vec := 'a vec and type 'a t := 'a t 32 | val ( + ) : 'a t -> 'a t -> 'a t 33 | val ( - ) : 'a t -> 'a t -> 'a t 34 | val ( |*| ) : 'a t -> 'a t -> float 35 | val ( *. ) : float -> 'a t -> 'a t 36 | val ( /. ) : 'a t -> float -> 'a t 37 | val ( ~- ) : 'a t -> 'a t 38 | val (.%()): ('a*'b) t -> ('a Nat.lt * 'b Nat.lt) -> 39 | float 40 | val (.%()<-): ('a*'b) t -> ('a Nat.lt * 'b Nat.lt) -> 41 | float -> unit 42 | end 43 | 44 | module type tensor_operators = 45 | sig 46 | type 'a t constraint 'a = 47 | type ('a,'b) matrix 48 | val ( * ) : 49 | < contr : 'a; cov : 'b > t -> 50 | < contr : 'b; cov : 'c > t -> < contr : 'a; cov : 'c > t 51 | val ( |*| ) : 52 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t -> float 53 | val ( + ) : 54 | < contr : 'a; cov : 'b > t -> 55 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 56 | val ( - ) : 57 | < contr : 'a; cov : 'b > t -> 58 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 59 | val ( *. ) : 60 | float -> < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 61 | val ( /.
) : 62 | float -> < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 63 | val ( ** ) : ('a, 'a) matrix -> int -> ('a, 'a) matrix 64 | end 65 | 66 | exception Dimension_error of string * int * int 67 | -------------------------------------------------------------------------------- /lib/signatures.mli: -------------------------------------------------------------------------------- 1 | module type base_operators = 2 | sig 3 | type 'a t 4 | val ( + ) : 'a t -> 'a t -> 'a t 5 | val ( - ) : 'a t -> 'a t -> 'a t 6 | val ( |*| ) : 'a t -> 'a t -> float 7 | val ( *. ) : float -> 'a t -> 'a t 8 | val ( /. ) : 'a t -> float -> 'a t 9 | val ( ~- ) : 'a t -> 'a t 10 | end 11 | 12 | module type vec_operators= 13 | sig 14 | include base_operators 15 | val (.%()): 'a t -> 'a Nat.lt -> float 16 | val (.%()<-): 'a t -> 'a Nat.lt -> float -> unit 17 | end 18 | 19 | module type matrix_specific_operators = sig 20 | type 'a t constraint 'a = 'b * 'c 21 | type 'a vec 22 | val ( @ ) : ('a * 'b) t -> 'b vec -> 'a vec 23 | val ( * ) : ('a * 'b) t -> ('b * 'c) t -> ('a * 'c) t 24 | val ( **^ ): ('a * 'a ) t -> int -> ('a * 'a ) t 25 | end 26 | 27 | module type matrix_operators = 28 | sig 29 | include matrix_specific_operators 30 | module Matrix_specific: matrix_specific_operators with 31 | type 'a vec := 'a vec and type 'a t := 'a t 32 | val ( + ) : 'a t -> 'a t -> 'a t 33 | val ( - ) : 'a t -> 'a t -> 'a t 34 | val ( |*| ) : 'a t -> 'a t -> float 35 | val ( *. ) : float -> 'a t -> 'a t 36 | val ( /.
) : 'a t -> float -> 'a t 37 | val ( ~- ) : 'a t -> 'a t 38 | val ( .%() ): ('a*'b) t -> ('a Nat.lt * 'b Nat.lt) -> 39 | float 40 | val ( .%()<- ): ('a*'b) t -> ('a Nat.lt * 'b Nat.lt) -> 41 | float -> unit 42 | end 43 | 44 | module type tensor_operators = 45 | sig 46 | type 'a t constraint 'a = 47 | type ('a,'b) matrix 48 | val ( * ) : 49 | < contr : 'a; cov : 'b > t -> 50 | < contr : 'b; cov : 'c > t -> < contr : 'a; cov : 'c > t 51 | val ( |*| ) : 52 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t -> float 53 | val ( + ) : 54 | < contr : 'a; cov : 'b > t -> 55 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 56 | val ( - ) : 57 | < contr : 'a; cov : 'b > t -> 58 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 59 | val ( *. ) : 60 | float -> < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 61 | val ( /. ) : 62 | float -> < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 63 | val ( ** ) : ('a, 'a) matrix -> int -> ('a, 'a) matrix 64 | end 65 | 66 | exception Dimension_error of string * int * int 67 | -------------------------------------------------------------------------------- /lib/small_matrix.ml: -------------------------------------------------------------------------------- 1 | module A = Array 2 | open Signatures 3 | let delta = Misc.delta 4 | 5 | let (@?)
a n = A.unsafe_get a n 6 | let ( % ) a n x = A.unsafe_set a n x 7 | let ( =: ) = (@@) 8 | 9 | type 'a t = {lines: 'b Nat.eq; array:float array} 10 | constraint 'a = 'b * 'c 11 | 12 | let unsafe_create (lines:'a Nat.eq) ( _ : 'b Nat.eq) array: ('a * 'b) t = 13 | { lines; array} 14 | 15 | let create l r a = 16 | if Nat.to_int l * Nat.to_int r = A.length a then 17 | unsafe_create l r a 18 | else 19 | raise @@ Dimension_error("Matrix.create",Nat.to_int l * Nat.to_int r, A.length a) 20 | 21 | let init lines (rows:' b Nat.eq) f : ('a * 'b) t = 22 | let nl = Nat.to_int lines and nr = Nat.to_int rows in 23 | let array = A.create_float (nl * nr ) in 24 | let pos = ref 0 in 25 | for j = 0 to nr - 1 do 26 | for i = 0 to nl - 1 do 27 | array.(!pos) <- f i j; incr pos 28 | done; 29 | done; 30 | { lines; array } 31 | 32 | let square dim f = f dim dim 33 | 34 | let get (mat:('a * 'b) t) (i:'a Nat.lt) (j:'b Nat.lt)= 35 | Array.unsafe_get mat.array (Nat.to_int j * Nat.to_int mat.lines + Nat.to_int i) 36 | 37 | let set (mat:('a * 'b) t) (i:'a Nat.lt) (j:'b Nat.lt) x = 38 | Array.unsafe_set mat.array (Nat.to_int j * Nat.to_int mat.lines + Nat.to_int i) x 39 | 40 | let dims mat = let l = Nat.to_int mat.lines in 41 | l, Array.length mat.array / l 42 | 43 | let typed_dims (mat:('a * 'b) t) : 'a Nat.eq * 'b Nat.eq = 44 | let l, r = dims mat in 45 | Nat.Unsafe.create l, Nat.Unsafe.create r 46 | 47 | let size mat = Array.length mat.array 48 | 49 | let map f m = { m with array = A.map f m.array } 50 | 51 | let map2 ( <@> ) (m:'a t) (n:'a t): 'a t = 52 | let array = Array.mapi (fun i x -> x <@> n.array @? i) m.array in 53 | { m with array } 54 | 55 | let fold_2 f acc (m:' a t) (n:'a t) = 56 | let acc = ref acc in 57 | for i = 0 to size m - 1 do 58 | acc := f !acc (m.array @? i) (n.array @?
i) 59 | done; 60 | !acc 61 | (* [base ~dim_l ~i ~dim_r ~j] is the dim_l by dim_r matrix with 1. at (i,j) and 0. elsewhere *) 62 | let base ~(dim_l:'a Nat.eq) ~i ~(dim_r:'b Nat.eq) ~j : ('a * 'b) t = 63 | let open Nat in 64 | let Truth = i %<% dim_l 65 | and Truth = j %<% dim_r in 66 | let array = Array.make (to_int dim_l * to_int dim_r) 0. in 67 | array.( to_int j * to_int dim_l + to_int i) <- 1.; (* fortran layout, as in [get] *) 68 | {lines = dim_l; array } 69 | 70 | let zero l r = create l r @@ Array.make (Nat.to_int l* Nat.to_int r) 0. 71 | let diag v = 72 | let dim = Small_vec.typed_dim v in 73 | let n = Nat.to_int dim in 74 | let a = Array.make (n * n) 0. in (* entry (k,k) lives at k * n + k *) 75 | Nat.iter_on dim ( fun k -> a % (Nat.to_int k * (n + 1)) =: Small_vec.( v.%(k) ) ) ; 76 | create dim dim a 77 | 78 | let id dim = square dim init delta 79 | (* [transpose mat] copies the l by r matrix [mat] into a fresh r by l matrix, both in fortran layout *) 80 | let transpose (mat:('a *' b) t) : ('b * 'a ) t = 81 | let array = Array.create_float (size mat) in 82 | let l, r = dims mat in 83 | let lines = Nat.Unsafe.create @@ r in 84 | let dir = ref 0 and tr = ref 0 in 85 | for j = 0 to r - 1 do 86 | tr := j; 87 | for _i = 0 to l - 1 do 88 | array % !tr @@ mat.array @? !dir; (* read from the source matrix *) 89 | incr dir; 90 | tr := !tr + r (* the result has r lines *) 91 | done; 92 | done; 93 | { lines ; array } 94 | 95 | 96 | module Operators = struct 97 | module Matrix_specific = struct 98 | 99 | (** matrix application: fortran layout *) 100 | let (@) (m: ('a * 'b) t) (v:'b Small_vec.t) : 'a Small_vec.t = 101 | let l = Nat.to_int m.lines in 102 | let array = Array.make l 0. in 103 | let pos = ref 0 in 104 | Nat.iter_on (Small_vec.typed_dim v) (fun k -> 105 | for i = 0 to l - 1 do 106 | array % i =: (array @? i) +. m.array.(!pos) *. Small_vec.(v.%(k)); 107 | incr pos 108 | done 109 | ) 110 | ; Small_vec.unsafe_create m.lines array 111 | 112 | (** matrix multiplication: fortran layout *) 113 | let ( * ) (m: ('a * 'b) t) (n: ('b * 'c) t): ('a * 'c) t = 114 | let l, c = typed_dims n in 115 | init m.lines c (fun i j -> 116 | let sum = ref 0. in let m_l = Nat.to_int m.lines in 117 | let n_k = Nat.to_int l in 118 | let off = j * n_k in (* n.(k,j) lives at j * n_k + k *) 119 | let off_n = ref i in (* m.(i,k) lives at k * m_l + i *) 120 | for k = 0 to n_k - 1 do 121 | sum := !sum +.
122 | (m.array @? !off_n ) 123 | *. ( n.array @? off + k); 124 | off_n := !off_n + m_l 125 | done; 126 | !sum 127 | ) 128 | 129 | 130 | let ( **^ ) x k = 131 | let rec aux x m k = match k with 132 | | 0 -> m 133 | | 1 -> m * x 134 | | k when k land 1 = 1 -> aux (x*x) (x*m) (k lsr 1) 135 | | k -> aux (x*x) m (k lsr 1) in 136 | aux x (id @@ fst @@ typed_dims x) k 137 | 138 | end 139 | include Matrix_specific 140 | 141 | let (+) m n = map2 (+.) m n 142 | let (-) m n = map2 (-.) m n 143 | let (~-) (m:'sh t) : 'sh t = { m with array= A.map (~-.) m.array } 144 | 145 | let ( |*| ) m n = fold_2 (fun sum x y -> sum +. x *. y ) 0. m n 146 | 147 | let ( *. ) scalar m = { m with array = A.map ( ( *. ) scalar ) m.array } 148 | let ( /. ) m scalar = 149 | { m with array = A.map (fun x -> x /. scalar ) m.array } 150 | 151 | let (.%()) m (x,y) = get m x y and ( .% () <- ) m (x,y) = set m x y 152 | end 153 | 154 | include Operators 155 | -------------------------------------------------------------------------------- /lib/small_matrix.mli: -------------------------------------------------------------------------------- 1 | 2 | type 'a t constraint 'a = 'b * 'c 3 | 4 | val unsafe_create : 'a Nat.eq -> 'b Nat.eq -> float array -> ('a * 'b) t 5 | val create : 'a Nat.eq -> 'b Nat.eq -> float array -> ('a * 'b) t 6 | val init : 'a Nat.eq -> 'b Nat.eq -> (int -> int -> float) -> ('a * 'b) t 7 | val square : 'a -> ('a -> 'a -> 'b) -> 'b 8 | 9 | val dims : ('a * 'b) t -> int * int 10 | val typed_dims : ('a * 'b) t -> 'a Nat.eq * 'b Nat.eq 11 | val size : ('a * 'b) t -> int 12 | 13 | val map: (float -> float) -> 'a t -> 'a t 14 | val map2 : 15 | (float -> float -> float) -> ('a * 'b) t -> ('a * 'b) t -> ('a * 'b) t 16 | val fold_2 : 17 | ('a -> float -> float -> 'a) -> 18 | 'a -> ('b * 'c) t -> ('b * 'c) t -> 'a 19 | 20 | val base : 21 | dim_l:'a Nat.eq -> i:'a Nat.lt -> dim_r:'b Nat.eq -> j:'b Nat.lt -> ('a * 'b) t 22 | val zero: 'a Nat.eq -> 'b Nat.eq -> ('a*'b) t 23 | val id : 'a Nat.eq
-> ('a * 'a) t 24 | val diag: 'a Small_vec.t -> ('a * 'a ) t 25 | val transpose : ('a * 'b) t -> ('b * 'a) t 26 | 27 | module Operators: Signatures.matrix_operators with 28 | type 'a t := 'a t and type 'a vec := 'a Small_vec.t 29 | 30 | include Signatures.matrix_operators with 31 | type 'a t := 'a t and type 'a vec := 'a Small_vec.t 32 | -------------------------------------------------------------------------------- /lib/small_tensor.ml: -------------------------------------------------------------------------------- 1 | (** 2 | This module gathers tensor definitions for which storing the chain 3 | of indices takes too much space 4 | 5 | *) 6 | 7 | 8 | module Vec= Small_vec 9 | module Matrix = Small_matrix 10 | module Unified = Small_unified 11 | 12 | include Vec.Operators 13 | include Matrix.Operators.Matrix_specific 14 | 15 | let vector = Vec.create 16 | let matrix = Matrix.create 17 | -------------------------------------------------------------------------------- /lib/small_tensor.mli: -------------------------------------------------------------------------------- 1 | 2 | 3 | module Vec = Small_vec 4 | module Matrix = Small_matrix 5 | 6 | include Signatures.vec_operators with type 'a t := 'a Vec.t 7 | include Signatures.matrix_specific_operators 8 | with type 'a t := 'a Matrix.t and type 'a vec := 'a Vec.t 9 | 10 | val vector : 'a Nat.eq -> float array -> 'a Vec.t 11 | val matrix : 'a Nat.eq -> 'b Nat.eq -> float array 12 | -> ('a * 'b) Matrix.t 13 | -------------------------------------------------------------------------------- /lib/small_unified.ml: -------------------------------------------------------------------------------- 1 | module V = Small_vec 2 | module M = Small_matrix 3 | module T = Tensor 4 | 5 | type _ t = 6 | | Scalar: float ref -> < contr:Shape.empty; cov: Shape.empty > t 7 | | Vec: 'a V.t -> < contr: 'a Shape.single; cov: Shape.empty > t 8 | | Matrix: ('a * 'b) M.t -> < contr:'a Shape.single; cov:'b Shape.single > t 9 | 10 | let scalar f
= Scalar (ref f) 11 | let vector n array = Vec(V.create n array) 12 | let matrix n m array = Matrix(M.create n m array) 13 | 14 | module Operators = struct 15 | 16 | let (+) (type a) (x:a t) (y:a t): a t = match x, y with 17 | | Scalar x, Scalar y -> Scalar ( ref @@ !x +. !y ) 18 | | Vec x, Vec y -> Vec V.(x + y) 19 | | Matrix x, Matrix y -> Matrix M.( x + y ) 20 | 21 | let (-) (type a) (x: a t)(y: a t): a t = match x, y with 22 | | Scalar x, Scalar y -> Scalar ( ref @@ !x -. !y ) 23 | | Vec x, Vec y -> Vec V.(x - y) 24 | | Matrix x, Matrix y -> Matrix M.( x - y ) 25 | 26 | let (~-) (type a) (t:a t) : a t = match t with 27 | | Scalar f -> Scalar ( ref @@ -. !f) 28 | | Vec v -> Vec V.( - v) 29 | | Matrix m -> Matrix M.( - m ) 30 | 31 | let ( |*| ) (type a) (t: a t) (u: a t) = match t, u with 32 | | Scalar x, Scalar y -> !x *. !y 33 | | Vec u, Vec v -> V.( u |*| v ) 34 | | Matrix m, Matrix n -> M.( m |*| n ) 35 | 36 | 37 | let ( * ) (type a) (type b) (type c) 38 | (x: t)(y: t): t = 39 | match x, y with 40 | | Scalar x, Scalar y -> Scalar ( ref @@ !x *. !y ) 41 | | Matrix m, Matrix n -> Matrix M.( m * n) 42 | | Matrix m, Vec v -> Vec M.( m @ v ) 43 | | Vec v, Scalar f -> Vec V.( !f *. v ) 44 | 45 | 46 | let one (type a): t -> t = function 47 | | Scalar _ -> Scalar(ref 1.) 48 | | Matrix m -> Matrix M.(id @@ fst @@ typed_dims m) 49 | 50 | let ( **^ ) (type a) (t: t) k = 51 | let rec aux: type a. 52 | acc:( t as 'te) -> t:'te -> int -> 'te = 53 | fun ~acc ~t k -> 54 | match k with 55 | | 0 -> acc 56 | | 1 -> acc * t 57 | | k when k land 1 = 1 -> aux ~acc:( acc * t ) ~t:(t * t) (k lsr 1) 58 | | k -> aux ~acc ~t:(t*t) (k lsr 1) in 59 | aux ~acc:(one t) ~t k 60 | 61 | let ( *. ) (type a) s (t : a t) : a t = match t with 62 | | Scalar x -> Scalar ( ref @@ s *. !x ) 63 | | Vec v -> Vec V.( s *. v ) 64 | | Matrix m -> Matrix M.( s *. m ) 65 | 66 | let ( /. ) (type a) (t : a t) s : a t = match t with 67 | | Scalar x -> Scalar ( ref @@ !x /. s ) 68 | | Vec v -> Vec V.( v /.
s ) 69 | | Matrix m -> Matrix M.( m /. s) 70 | 71 | end 72 | 73 | let (.%()): type a b. t -> (a Shape.lt * b Shape.lt) 74 | -> float = fun t (contr,cov) -> 75 | let open Shape in 76 | match[@warning "-4"] t, contr, cov with 77 | | Scalar f, [] , [] -> !f 78 | | Vec v, [a], [] -> v.V.%(a) 79 | | Matrix m, [i], [j] -> m.M.%(i,j) 80 | | _ -> . 81 | 82 | and (.%()<-): type a b. 83 | t -> (a Shape.lt * b Shape.lt) -> float -> unit 84 | = fun t (contr,cov) x -> 85 | let open Shape in 86 | match[@warning "-4"] t, contr, cov with 87 | | Scalar f, [] , [] -> f := x 88 | | Vec v, [a] , [] -> v.V.%(a) <- x 89 | | Matrix m, [i], [j] -> m.M.%(i,j) <- x 90 | | _ -> . 91 | -------------------------------------------------------------------------------- /lib/small_unified.mli: -------------------------------------------------------------------------------- 1 | 2 | type _ t = 3 | | Scalar: float ref -> < contr:Shape.empty; cov: Shape.empty > t 4 | | Vec: 'a Small_vec.t -> < contr: 'a Shape.single; cov: Shape.empty > t 5 | | Matrix: ('a * 'b) Small_matrix.t -> < contr:'a Shape.single; cov:'b Shape.single > t 6 | 7 | val scalar: float -> t 8 | val vector : 9 | 'a Nat.eq -> 10 | float array -> < contr : 'a Shape.single; cov : Shape.empty > t 11 | val matrix : 12 | 'a Nat.eq -> 13 | 'b Nat.eq -> 14 | float array -> < contr : 'a Shape.single; cov : 'b Shape.single > t 15 | 16 | module Operators: sig 17 | include Signatures.base_operators with type 'a t := 'a t 18 | val ( * ): t -> t -> 19 | t 20 | val ( **^ ): t -> int -> t 21 | end 22 | 23 | val (.%()): t -> ('a Shape.lt * 'b Shape.lt) -> float 24 | val (.%()<-): 25 | t -> ( 'a Shape.lt * 'b Shape.lt) -> float -> unit 26 | -------------------------------------------------------------------------------- /lib/small_vec.ml: -------------------------------------------------------------------------------- 1 | 2 | open Signatures 3 | open Misc 4 | module A = Array 5 | 6 | type +'a t = float array 7 | 8 | let unsafe_create (_nat:'a Nat.eq)
(array: float array) : 'a t = 9 | array 10 | 11 | let create nat array = 12 | if Nat.to_int nat <> A.length array then 13 | raise @@ Dimension_error( "Vec.create", Nat.to_int nat, A.length array) 14 | else 15 | unsafe_create nat array 16 | 17 | let init (nat:'a Nat.eq) f : 'a t = 18 | unsafe_create nat @@ A.init (Nat.to_int nat) f 19 | 20 | let const nat c = 21 | Array.make (Nat.to_int nat) c 22 | 23 | let zero nat = const nat 0. 24 | (* [pad_right nat array] copies [array] into a zero vector of dimension [nat]; raises [Dimension_error] when [array] is longer than [nat] *) 25 | let pad_right nat array = 26 | let n = Array.length array in 27 | if n > Nat.to_int nat then 28 | raise @@ Dimension_error("Vec.pad_right", Nat.to_int nat, n) 29 | else 30 | let v = zero nat in 31 | A.blit array 0 v 0 n; 32 | v 33 | 34 | let get: 'a t -> 'a Nat.lt -> float = fun vec nat -> 35 | A.unsafe_get vec (Nat.to_int nat) 36 | 37 | let set: 'a t -> 'a Nat.lt -> float -> unit = fun vec nat x -> 38 | A.unsafe_set vec (Nat.to_int nat) x 39 | 40 | let dim = Array.length 41 | let typed_dim (v: 'a t) : 'a Nat.eq = Nat.Unsafe.create @@ dim v 42 | 43 | let map f v = Array.map f v 44 | (* [map_nat f v] builds a fresh vector whose k-th entry is [f k v.(k)] *) 45 | let map_nat f v = 46 | let a = Array.create_float (dim v) in 47 | Nat.iter_on (typed_dim v) (fun k -> A.unsafe_set a (Nat.to_int k) @@ 48 | f k @@ A.unsafe_get v @@ Nat.to_int k ); 49 | a 50 | 51 | let map2 ( <@> ) (v:'a t) (w:'a t) : 'a t = 52 | A.mapi ( fun i x -> x <@> A.unsafe_get w (i) ) v 53 | 54 | 55 | let fold_2 f acc (v:' a t) (w:'a t) = 56 | let acc = ref acc in 57 | for i = 0 to dim v - 1 do 58 | acc := f !acc (A.unsafe_get v i) (A.unsafe_get w i) 59 | done; 60 | !acc 61 | 62 | let scalar_prod v w = fold_2 (fun sum x y -> sum +. x *. y ) 0. v w 63 | 64 | module Operators = struct 65 | let (+) v w = map2 (+.) v w 66 | let (-) v w = map2 (-.) v w 67 | 68 | let ( |*| ) = scalar_prod 69 | let ( *. ) scalar vec = A.map ( ( *. ) scalar ) vec 70 | let ( /. ) vec scalar = A.map (fun x -> x /. scalar ) vec 71 | let (~-) v = A.map ( ~-.)
v 72 | let ( .%() ) = get and ( .%() <- ) = set 73 | end 74 | 75 | let prenorm v = Operators.( v |*| v ) 76 | let norm v = sqrt @@ prenorm v 77 | 78 | let proj v w = Operators.( *. ) ((scalar_prod v w)/.prenorm v) v 79 | 80 | let base dim p = 81 | let open Nat in 82 | let Truth = p %<% dim in 83 | init dim @@ delta @@ to_int p 84 | 85 | 86 | include Operators 87 | -------------------------------------------------------------------------------- /lib/small_vec.mli: -------------------------------------------------------------------------------- 1 | 2 | type 'a t 3 | 4 | val unsafe_create : 'a Nat.eq -> float array -> 'a t 5 | val create : 'a Nat.eq -> float array -> 'a t 6 | val init : 'a Nat.eq -> (int -> float) -> 'a t 7 | 8 | val const: 'a Nat.eq -> float -> 'a t 9 | val zero: 'a Nat.eq -> 'a t 10 | val pad_right: 'a Nat.eq -> float array -> 'a t 11 | 12 | val dim : 'a t -> int 13 | val typed_dim : 'a t -> 'a Nat.eq 14 | 15 | val map_nat: ('a Nat.lt -> float -> float) -> 'a t -> 'a t 16 | val map: (float -> float) -> 'a t -> 'a t 17 | val map2 : (float -> float -> float) -> 'a t -> 'a t -> 'a t 18 | val fold_2 : ('b -> float -> float -> 'b) -> 'b -> 'a t -> 'a t -> 'b 19 | val scalar_prod : 'a t -> 'a t -> float 20 | 21 | 22 | module Operators: Signatures.vec_operators with type 'a t := 'a t 23 | 24 | val prenorm : 'a t -> float 25 | val norm : 'a t -> float 26 | val proj : 'a t -> 'a t -> float array 27 | val base : 'a Nat.eq -> 'a Nat.lt -> 'a t 28 | 29 | include Signatures.vec_operators with type 'a t := 'a t 30 | -------------------------------------------------------------------------------- /lib/stencil.ml: -------------------------------------------------------------------------------- 1 | 2 | type t = {translation:int; linear:int} 3 | (** [{translation=l; linear=k}] corresponds to the affine subspace l + k ℕ *) 4 | 5 | type stencil = t 6 | 7 | type ideal = int 8 | type integer_set = N 9 | 10 | 11 | let affine ~translation ~linear = {translation; linear} 12 | let ( +: )
translation (linear:ideal) = affine ~translation ~linear 13 | (* let f = k +: a *: N *) 14 | 15 | let ( *: ) x N : ideal = x 16 | let ( ~* ) x = { translation = 0; linear = x } 17 | let ( ~+ ) x = { translation = x; linear = 1 } 18 | 19 | 20 | let ( % ) f g = 21 | (f.translation + f.linear * g.translation) 22 | +: (f.linear * g.linear) *: N 23 | (* 24 | { 25 | linear = f.linear * g.linear; 26 | translation = f.translation + f.linear * g.translation 27 | } 28 | *) 29 | 30 | let first s = s.translation 31 | let translation s = { s with linear = 1 } 32 | 33 | let (.%[] ) stencil n = 34 | stencil.translation + n * stencil.linear 35 | 36 | let id = { linear = 1; translation = 0 } 37 | let all = id 38 | let is_all = (=) id 39 | -------------------------------------------------------------------------------- /lib/stencil.mli: -------------------------------------------------------------------------------- 1 | (** Stencil defines types and functions for handling enumerable 2 | subsets of integers. Such subsets are used to represent sparse 3 | sub-arrays within arrays in Multidim_array and Tensor 4 | *) 5 | 6 | (** The type for affine stencils [s] = [translation] + [linear] ℕ *) 7 | type t = {translation:int; linear:int} 8 | type stencil = t 9 | 10 | (** Specialized type used to represent the ideal k ℕ, i.e. all 11 | multiples of k within natural numbers.
*) 12 | type ideal = private int 13 | type integer_set = N 14 | 15 | (** Smart constructor for affine stencil *) 16 | val affine : translation:int -> linear:int -> t 17 | 18 | (** Infix constructor: [ k +: a *: N ] is the affine stencil k + a ℕ *) 19 | val ( +: ) : int -> ideal -> t 20 | 21 | (** [ k *: N ] converts [k] to the ideal [k ℕ] *) 22 | val ( *: ): int -> integer_set -> ideal 23 | val ( ~+ ): int -> t (* [~+ x] is the stencil x + ℕ *) 24 | val ( ~* ): int -> t (* [~* x] is the stencil x ℕ *) 25 | 26 | (** Stencil enumeration [ s.%[n] ] is the n-th integer within s *) 27 | val ( .%[] ): t -> int -> int 28 | 29 | (** [first s] ≡ s.%[0] *) 30 | val first: t -> int 31 | 32 | (** Stencil composition: [ s_1 % s_2 ] is the stencil such that 33 | [ (s_1 % s_2).%[n] = s_1.%[ s_2.%[n] ] ] *) 34 | val (%): t -> t -> t 35 | 36 | (** The stencil [s] = ℕ *) 37 | val id: t 38 | val all: t 39 | 40 | (** equality test with [all] *) 41 | val is_all: t -> bool 42 | 43 | (** [translation s] is the stencil s' such that 44 | [ s'.%[k] ≡ s.%[0] + k ] *) 45 | val translation: t -> t 46 | -------------------------------------------------------------------------------- /lib/stride.ml: -------------------------------------------------------------------------------- 1 | type 'a t = int array 2 | 3 | let size s = s.(Array.length s - 1) 4 | let first s = s.(0) 5 | 6 | (* [create sh] sets entry p to the product of the first p dimensions of [sh] *) 7 | let create: ('n * 'sh) Shape.eq -> 'n t = fun sh -> 8 | let s = Array.make (1 + Shape.order sh) 0 in 9 | let rec fill: type sh. pos:int -> m:int -> sh Shape.eq -> unit = 10 | let open Shape in 11 | fun ~pos ~m -> function 12 | | a :: q -> Array.unsafe_set s pos m; 13 | fill ~pos:(pos+1) ~m:(m * Nat.to_int a) q 14 | | [] -> Array.unsafe_set s pos m in 15 | fill ~pos:0 ~m:1 sh; s 16 | 17 | let create_2: ('n * 'sh) Shape.eq -> ('n2 * 'sh2) Shape.eq -> 'n t * 'n2 t = 18 | fun sh sh2 -> 19 | let s = Array.make (1+Shape.order sh) 0 in 20 | let s2 = Array.make (1+Shape.order sh2) 0 in 21 | let rec fill: type sh.
'a t -> pos:int -> m:int -> sh Shape.eq -> int = 22 | let open Shape in 23 | fun s ~pos ~m -> function 24 | | a :: q -> Array.unsafe_set s pos m; 25 | fill s ~pos:(pos+1) ~m:(m * Nat.to_int a) q 26 | | [] -> Array.unsafe_set s pos m; m in 27 | let m = fill s ~pos:0 ~m:1 sh in 28 | let _ = fill s2 ~pos:0 ~m sh2 in 29 | s, s2 30 | 31 | (* [filter_scan] writes the strides kept by the mask into [s'] and returns the offset contributed by fixed ([Elt]) indices and range starts *) 32 | let rec filter_scan: 33 | type sh sh2. _ t -> _ t -> pos_in:int -> pos_out:int -> 34 | (sh, sh2) Mask.t -> int = 35 | let open Mask in 36 | fun s s' ~pos_in ~pos_out m -> 37 | match m with 38 | | [] -> s'.(pos_out) <- s.(pos_in); 0 39 | | All :: q -> 40 | s'.(pos_out) <- s.(pos_in); 41 | filter_scan s s' ~pos_in:(succ pos_in) ~pos_out:(succ pos_out) q 42 | | Elt k :: q -> 43 | let t = Nat.to_int k * s.(pos_in) in 44 | t + filter_scan s s' ~pos_in:(succ pos_in) ~pos_out q 45 | | Range r :: q -> 46 | let t = Nat.to_int (Range.start r) * s.(pos_in) in 47 | s'.(pos_out) <- Range.step r * s.(pos_in); 48 | t + filter_scan s s' ~pos_in:(succ pos_in) ~pos_out:(succ pos_out) q 49 | 50 | let filter: 'n t -> ('n * _, 'n2 * _ ) Mask.t -> int * 'n2 t = 51 | fun s m -> 52 | let s' = Array.make (1 + Mask.order_out m) 1 in 53 | let offset = filter_scan s s' ~pos_in:0 ~pos_out:0 m in 54 | offset, s' 55 | 56 | let slice_1 (s: 'n Nat.succ t) : 'n t = 57 | Array.init (Array.length s - 1 ) (fun i -> s.(i+1) ) 58 | 59 | (** Note: fortran layout *) 60 | let position ~(strides:'n t) ~(indices: ('n*'sh) Shape.lt) = 61 | let rec descent: type sh. int -> 'n t -> sh Shape.lt -> int= 62 | fun n strides shape -> let open Shape in 63 | match shape with 64 | | [] -> 0 65 | | a :: q -> (a:>int) * strides.(n) + descent (n+1) strides q in 66 | descent 0 strides indices 67 | 68 | (** Note: fortran layout *) 69 | let position_2 ( st1, st2 : 'n t * 'm t) 70 | (i1:('n*'sh) Shape.lt) (i2: ('m*_) Shape.lt ) = 71 | let rec descent: type sh.
int -> 'n t -> sh Shape.lt -> int= 72 | fun n strides shape -> let open Shape in 73 | match shape with 74 | | [] -> 0 75 | | a :: q -> (a:>int) * strides.(n) + descent (n+1) strides q in 76 | descent 0 st1 i1 + descent 0 st2 i2 77 | -------------------------------------------------------------------------------- /lib/stride.mli: -------------------------------------------------------------------------------- 1 | (** Strides represent injective mappings between cylinders of ℕ^d and 2 | integer subsets \{0,...,n\} 3 | 4 | *) 5 | type 'a t 6 | 7 | (** The size of the stride image, stored as the last stride coefficient *) 8 | val size: 'a t -> int 9 | 10 | (** The first stride coefficient, i.e. the multiplier of the first index *) 11 | val first: 'a t -> int 12 | 13 | (** Create a stride array from a shape *) 14 | val create: ('n * 'sh) Shape.eq -> 'n t 15 | 16 | (** Create two stride arrays from two consecutive shapes *) 17 | val create_2: ('n * 'sh) Shape.eq -> ('n2 * 'sh2) Shape.eq -> 'n t * 'n2 t 18 | 19 | (** Apply a mask to a stride by computing the resulting substride 20 | and offset *) 21 | val filter: 'n t -> ('n * 'sh, 'n2 * 'sh2 ) Mask.t -> int * 'n2 t 22 | 23 | (** Remove the first dimension of the stride *) 24 | val slice_1 : 'n Nat.succ t -> 'n t 25 | 26 | (** Apply the strides to compute the integer position 27 | of a multi-index *) 28 | val position: strides:'n t -> indices:('n*'sh) Shape.lt -> int 29 | 30 | (** Apply a pair of strides to a pair of multi-indices and return 31 | the sum of the two resulting positions *) 32 | val position_2: ('n t * 'm t) -> ('n*'sh) Shape.lt -> ('m * 'sh2) Shape.lt 33 | -> int 34 | -------------------------------------------------------------------------------- /lib/tensor.ml: -------------------------------------------------------------------------------- 1 | module MA = Multidim_array 2 | module A=Array 3 | let (.!()) = A.unsafe_get 4 | let (.!()<-) = A.unsafe_set 5 | 6 | 7 | type 'x t = { contr:('n * 'a) Shape.eq 8 | ; cov:('n2 * 'b) Shape.eq 9 | ; strides: 'n Stride.t * 'n2 Stride.t 10 | ; offset: int 11 |
; array : float array 12 | } 13 | constraint 'x = < contr:'n * 'a; cov:'n2 * 'b > 14 | 15 | type 'dim vec = t 16 | type ('l,'c) matrix = t 17 | type ('d1,'d2,'d3) t3 = t 18 | 19 | module Unsafe = struct 20 | let create ~contr ~cov array = 21 | let len = (Shape.size cov) * (Shape.size contr) in 22 | let len' = A.length array in 23 | if len <> len' then 24 | raise @@ Signatures.Dimension_error( "Tensor.unsafe_create", len, len' ) 25 | else 26 | {cov;contr; array; strides= Stride.create_2 contr cov; offset = 0 } 27 | 28 | end 29 | (* [t.%(contr,cov)] reads the entry at the multi-indices [contr] and [cov], located through [t.offset] and [t.strides] *) 30 | let (.%()): t -> ('a Shape.lt * 'b Shape.lt ) -> float = 31 | fun t (contr,cov) -> 32 | let p = 33 | t.offset + Stride.position_2 t.strides contr cov in 34 | t.array.!(p) 35 | 36 | 37 | let (.%()<-): 38 | < contr:'a; cov:'b > t -> ('a Shape.lt * 'b Shape.lt ) -> float -> unit 39 | = fun t (contr,cov) value -> 40 | let p = 41 | t.offset + Stride.position_2 t.strides contr cov in 42 | t.array.!(p) <- value 43 | 44 | 45 | let cov_size t = Shape.size t.cov 46 | let contr_size t = Shape.size t.contr 47 | let size t = cov_size t * contr_size t 48 | 49 | let len t = A.length t.array 50 | let contr_dims t = t.contr 51 | let cov_dims t = t.cov 52 | let is_sparse t = len t <> size t 53 | 54 | (* [const ~contr ~cov x] allocates a fresh array of [Shape.size cov * Shape.size contr] entries, all equal to [x] *) 55 | let const ~contr ~cov x= 56 | let len = Shape.size cov * Shape.size contr in 57 | let array = A.make len x in 58 | Unsafe.create ~contr ~cov array 59 | 60 | let zero ~contr ~cov = const ~contr ~cov 0.
(** [init_sh f ~contr ~cov] builds a fresh tensor [r] with contravariant
    shape [contr] and covariant shape [cov] such that
    [r.%(s, s') = f s s'] for every index [s] of [contr] and [s'] of [cov].
    The tensor is zero-filled by [zero] first, then every coefficient is
    written through [.%()<-]. *)
let init_sh f ~contr ~cov =
  let r = zero ~contr ~cov in
  Shape.iter_on contr (fun contr ->
      Shape.iter_on cov (fun cov ->
          r.%(contr, cov) <- f contr cov
        )
    );
  r

(** Pretty-printer: prints the contravariant shape, the covariant shape and
    the coefficients. Coefficients are read through [.%()], so views produced
    by [slice] are printed element by element rather than as raw storage. *)
let pp ppf t =
  let order = Shape.order t.cov in
  (* [n] is supplied by [Shape.iter_sep] (presumably the level of the
     dimension being separated — TODO confirm). Contravariant separators are
     shifted by the covariant order so they use different delimiters. *)
  let sep ?(start = 0) ppf n =
    match n + start with
    | 0 -> Format.fprintf ppf ",@ "
    | 1 -> Format.fprintf ppf ";@ "
    | _ -> Format.fprintf ppf "@," in
  let up _ = Format.fprintf ppf "@["
  and down _ = Format.fprintf ppf "@]" in
  let pp_scalar ppf x = Format.fprintf ppf "%f" x in
  let pp_cov ppf t contr =
    Shape.iter_sep ~up ~down ~sep:(sep ppf) t.cov ~f:(fun cov ->
        pp_scalar ppf t.%(contr, cov)
      ) in
  let pp_array ppf t =
    Shape.iter_sep ~up ~down
      ~sep:(sep ~start:order ppf) ~f:(pp_cov ppf t) t.contr in
  Format.fprintf ppf "@[{contr=%a;@ cov=%a;@ array=%a}@]"
    Shape.pp t.contr Shape.pp t.cov
    pp_array t

(** [show t] renders [t] with [pp] into a string. *)
let show t = Format.asprintf "%a" pp t

(** [reshape t (contr, cov)] reinterprets the backing array of [t] with new
    shapes, recomputing the strides with [Stride.create_2].
    @raise Invalid_argument if [t] is sparse (a view whose array is larger
    than its shape).
    @raise Signatures.Dimension_error if the new shapes do not describe
    exactly [len t] coefficients. *)
let reshape t (contr, cov) =
  if is_sparse t then
    raise @@ Invalid_argument "Tensor.reshape: sparse tensor cannot be reshaped"
  else
    let l = len t and dim = Shape.(size contr * size cov) in
    if l <> dim then
      raise @@ Signatures.Dimension_error("Tensor.reshape", l, dim)
    else
      { t with contr; cov;
               strides = Stride.create_2 contr cov }

(** [matrix dim_row dim_col f] is the [dim_row × dim_col] matrix [m] with
    [m.%([i],[j]) = f i j].

    [Unsafe.create] attaches the strides of [Stride.create_2]: the row
    (contravariant) index has stride 1 and the column (covariant) index has
    stride [dim_row]. The coefficient [(i, j)] is therefore stored at
    [i + dim_row * j] (column-major). Writing it row by row, as done
    previously, made [m.%([i],[j])] differ from [f i j]. *)
let matrix dim_row dim_col f : ('a, 'b) matrix =
  let rows = Nat.to_int dim_row in
  let array = A.create_float (rows * Nat.to_int dim_col) in
  Nat.iter_on dim_col (fun j ->
      Nat.iter_on dim_row (fun i ->
          array.!(Nat.to_int i + rows * Nat.to_int j) <- f i j
        )
    );
  Unsafe.create ~contr:[dim_row] ~cov:[dim_col] array

(** [sq_matrix dim f] is the square matrix [matrix dim dim f]. *)
let sq_matrix dim f = matrix dim dim f

(** [vector dim f] is the vector [v] with [v.%([i],[]) = f i]; its storage is
    [Nat.map f dim]. *)
let vector (dim:'a Nat.eq) f :' a vec=
  Unsafe.create ~contr:[dim] ~cov:[] @@ Nat.map f dim

(** Kronecker delta: [1.] when [i] and [j] have the same integer value,
    [0.] otherwise. *)
let delta i j = if Nat.to_int i = Nat.to_int j then 1. else 0.

(** Identity matrix of dimension [dim]. *)
let id dim = sq_matrix dim delta

(** [base dim p] is the [p]-th canonical basis vector of dimension [dim].
    The [Truth] pattern guards [p %<% dim]; presumably a failed check
    surfaces as [Match_failure] — TODO confirm against [Nat]. *)
let base dim p =
  let open Nat in
  let Truth = p %<% dim in
  vector dim @@ delta p

(** Dimension of a square matrix, read from its single contravariant
    dimension. *)
let endo_dim (mat: ('a,'a) matrix) =
  let open Shape in
  match mat.contr with
  | [dim] -> dim

(** Generic implementations that read and write every coefficient through
    [.%()] / [.%()<-]. They are valid for any tensor, including sparse views
    produced by [slice]. *)
module Sparse = struct

  (** [transpose t] satisfies [(transpose t).%(s', s) = t.%(s, s')]. *)
  let transpose t =
    let tt = zero ~cov:t.contr ~contr:t.cov in
    Shape.iter_on t.contr (fun contr ->
        Shape.iter_on t.cov (fun cov ->
            tt.%(cov, contr) <- t.%(contr, cov)
          )
      );
    tt

  (** [mult t1 t2] contracts the covariant indices of [t1] with the
      contravariant indices of [t2]. *)
  let mult t1 t2 =
    let r = zero ~contr:t1.contr ~cov:t2.cov in
    Shape.iter_on t1.contr (fun i ->
        Shape.iter_on t1.cov (fun k ->
            Shape.iter_on t2.cov (fun j ->
                r.%(i, j) <- r.%(i, j) +. t1.%(i, k) *. t2.%(k, j)
              )
          )
      );
    r

  (** Sum of the diagonal coefficients [t.%(s, s)]. *)
  let trace t =
    let s = ref 0. in
    Shape.iter_on t.contr (fun sh ->
        s := !s +. t.%(sh, sh)
      );
    !s

  (** [full_contraction t1 t2] is the sum over [(s, s')] of
      [t1.%(s, s') *. t2.%(s', s)]. *)
  let full_contraction t1 t2 =
    let s = ref 0. in
    Shape.iter_on t1.contr (fun contr ->
        Shape.iter_on t1.cov (fun cov ->
            s := !s +. t1.%(contr, cov) *. t2.%(cov, contr)
          )
      );
    !s

  (** Sum of the coefficient-wise products of two same-shaped tensors. *)
  let scalar_product t1 t2 =
    let s = ref 0. in
    Shape.iter_on t1.contr (fun contr ->
        Shape.iter_on t1.cov (fun cov ->
            s := !s +. t1.%(contr, cov) *. t2.%(contr, cov)
          )
      );
    !s

  (** Coefficient-wise combination of two same-shaped tensors into a fresh
      tensor. *)
  let map2 ( <+> ) t1 t2 =
    init_sh ~contr:t1.contr ~cov:t1.cov (fun contr cov ->
        t1.%(contr, cov) <+> t2.%(contr, cov)
      )

  (** Coefficient-wise image of a tensor by [f], into a fresh tensor. *)
  let scalar_map f t1 =
    init_sh ~contr:t1.contr ~cov:t1.cov
      (fun contr cov -> f t1.%(contr, cov))

end


(** Optimized versions for tensors that are not sparse. They work directly
    on the backing array, which uses the strides of [Stride.create_2]: all
    contravariant indices first (linearized index [i]), then the covariant
    ones (linearized index [j]). The coefficient [(i, j)] of a tensor with
    contravariant size [left] therefore lives at [i + left * j]
    (column-major). *)
module Full = struct

  (** Coefficient-wise combination; both arrays share the same layout. *)
  let map2 ( <@> ) (t1:'sh t) (t2:'sh t) : 'sh t =
    let array = A.mapi (fun i x -> x <@> t2.array.!(i)) t1.array in
    { t1 with array }

  (* Small combinators for nested loops:
     [iter_on (n1 ^ n2 ^ stop) (fun i1 i2 -> ...)] runs the body for every
     [0 <= i1 < n1] (outer) and [0 <= i2 < n2] (inner). *)
  let iter_int n kont f =
    for i = 0 to n - 1 do
      kont (f i)
    done
  let (^) = iter_int
  let stop () = ()
  let iter_on f = f

  (** [transpose t1] satisfies [(transpose t1).%(s', s) = t1.%(s, s')].
      The previous version wrote into [t1]'s own array, corrupting the input,
      and returned the zero-filled result untouched. *)
  let transpose: < contr:'left; cov:'right > t -> < contr:'right; cov:'left > t =
    fun t1 ->
    let left = contr_size t1
    and right = cov_size t1 in
    let array = A.make (len t1) 0. in
    iter_on (left ^ right ^ stop) (fun i j ->
        (* (i;j) of t1 sits at [i + left * j]; its image (j;i) in the
           result, whose contravariant size is [right], sits at
           [j + right * i]. *)
        array.!(j + right * i) <- t1.array.!(i + left * j)
      );
    Unsafe.create ~contr:t1.cov ~cov:t1.contr array

  (** [mult t1 t2] is [r] with
      [r.(i;j) = Σ_k t1.(i;k) *. t2.(k;j)], using the column-major layout
      described above. The previous version indexed the arrays row-major,
      which disagreed with the strides and with [Sparse.mult]. *)
  let mult (t1: < contr:'a; cov:'b > t) (t2: < contr:'b; cov:'c > t)
    : < contr:'a; cov:'c > t =
    let left_dim = contr_size t1
    and middle_dim = cov_size t1
    and right_dim = cov_size t2 in
    let l = t1.array and r = t2.array in
    let array = A.make (left_dim * right_dim) 0. in
    iter_on (left_dim ^ middle_dim ^ right_dim ^ stop)
      (fun i k j ->
         let pos = i + left_dim * j in
         array.!(pos) <-
           array.!(pos) +.
           l.!(i + left_dim * k) *. r.!(k + middle_dim * j)
      );
    Unsafe.create ~contr:t1.contr ~cov:t2.cov array

  (** Sum of the diagonal coefficients, stored at [i + size * i]. *)
  let trace (t1: < contr:'a; cov:'a > t) =
    let size = contr_size t1 in
    let s = ref 0. in
    iter_on (size ^ stop) (fun i ->
        s := !s +. t1.array.!(i + size * i)
      );
    !s

  (** [full_contraction t1 t2] is [Σ_{i,j} t1.(i;j) *. t2.(j;i)]. *)
  let full_contraction (t1: < contr:'a; cov:'b > t) (t2: < contr:'b; cov:'a > t) =
    let left = contr_size t1 and right = cov_size t1 in
    let s = ref 0. in
    iter_on (left ^ right ^ stop) (fun i j ->
        s := !s +. t1.array.!(i + left * j) *. t2.array.!(j + right * i)
      );
    !s

  (** Dot product of the two backing arrays. *)
  let scalar_product (t1: 'sh t) (t2: 'sh t) =
    let l = len t1 in
    let s = ref 0. in
    for i = 0 to l - 1 do
      s := !s +. t1.array.!(i) *. t2.array.!(i)
    done;
    !s

  (** Coefficient-wise image by [f] of the backing array. *)
  let scalar_map f t =
    { t with array = A.map f t.array }

end

(** Transposition; dispatches to [Full] when [t] is not sparse. *)
let transpose t =
  if is_sparse t then
    Sparse.transpose t
  else
    Full.transpose t

(** Tensor contraction; uses [Full] only when neither operand is sparse. *)
let mult t1 t2 =
  if is_sparse t1 || is_sparse t2 then
    Sparse.mult t1 t2
  else
    Full.mult t1 t2

(** Trace of a square tensor. *)
let trace t =
  if is_sparse t then
    Sparse.trace t
  else
    Full.trace t

(** Full contraction of two tensors into a scalar. *)
let full_contraction t1 t2 =
  if is_sparse t1 || is_sparse t2 then
    Sparse.full_contraction t1 t2
  else
    Full.full_contraction t1 t2

(** Canonical scalar product of two same-shaped tensors. *)
let scalar_product t1 t2 =
  if is_sparse t1 || is_sparse t2 then
    Sparse.scalar_product t1 t2
  else
    Full.scalar_product t1 t2

(** Coefficient-wise combination of two same-shaped tensors. *)
let map2 ( <+> ) t1 t2 =
  if is_sparse t1 || is_sparse t2 then
    Sparse.map2 (<+>) t1 t2
  else
    Full.map2 (<+>) t1 t2

(** Coefficient-wise image of a tensor by [f]. *)
let scalar_map f t =
  if is_sparse t then
    Sparse.scalar_map f t
  else
    Full.scalar_map f t

(** [pow_int x k] is the [k]-th power of the square matrix [x], computed by
    binary exponentiation. [aux x m k] maintains the invariant
    [result = m * x^k]. [m] is always a power of the original matrix, so
    multiplying on either side is equivalent. *)
let pow_int x k =
  let ( * ) = mult in
  let rec aux x m k = match k with
    | 0 -> m
    | 1 -> m * x
    | k when k land 1 = 1 -> aux (x*x) (x*m) (k lsr 1)
    | k -> aux (x*x) m (k lsr 1) in
  let id = id @@ endo_dim x in
  aux x id k

(** Infix linear algebra operators. Note that [l /. t] divides every
    coefficient of the tensor [t] by the scalar [l]. *)
module Operators = struct
  let ( * ) x y = mult x y
  let ( |*| ) x y = scalar_product x y
  let ( + ) t1 t2 = map2 ( +. ) t1 t2
  let ( - ) t1 t2 = map2 ( -. ) t1 t2

  let ( *. ) l t = scalar_map ( ( *. ) l ) t

  let ( /. ) l t = scalar_map ( fun x -> x /. l ) t

  let ( ** ) = pow_int

end

(* to do:
   * moving indices up/down
*)

(*
let full_up: type left right tl.
  (, ) t -> (, 'a Shape.empty ) t =
  fun t1 ->
  Shape.{ t1 with contr=t1.contr @ t1.cov; cov = [] }

let up1: type left right dim tl tl2.
  (tl >, right; tl:tl2> ) t ->
  (, ) t = fun t ->
  let open Shape in
  match t.cov with
  | dim::right -> { t with contr = t.contr @ [dim] ; cov = right }
*)

(** Independent copy of [t]. A sparse view is compacted into a fresh dense
    tensor; otherwise the backing array is duplicated with the same
    strides. *)
let copy t =
  if is_sparse t then
    init_sh ~contr:t.contr ~cov:t.cov
      (fun i j -> t.%(i, j))
  else
    { t with array = A.copy t.array }

(** [partial_copy t (f1, f2)] copies the part of [t] selected by the masks
    [f1] (contravariant) and [f2] (covariant) into a fresh tensor. *)
let partial_copy (type na a nb b nc c nd d)
    (t: < contr: na * a; cov: nb * b > t)
    (f1,f2: (na * a, nc * c) Mask.s * ( nb * b, nd * d) Mask.s)
  : < contr: nc * c; cov: nd * d > t
  =
  let contr = Mask.filter t.contr f1 in
  let cov = Mask.filter t.cov f2 in
  let tnew = zero ~contr ~cov in
  Mask.iter_extended_dual
    (fun sh2 sh2' ->
       Mask.iter_extended_dual (
         fun sh1 sh1' ->
           tnew.%(sh1, sh2) <- t.%(sh1', sh2')
       ) contr f1
    ) cov f2;
  tnew

(** [slice t (f1, f2)] is a view of [t] that shares its backing array.
    The view's strides and offset come from [Stride.filter]. The offsets
    returned by [Stride.filter] are relative to the strides of [t], so they
    are added to [t.offset]. The previous version dropped [t.offset], which
    broke slices of slices. *)
let slice t (f1, f2) =
  let s1, s2 = t.strides in
  let offset, s1 = Stride.filter s1 f1 in
  let offset_2, s2 = Stride.filter s2 f2 in
  let contr, cov = Mask.filter t.contr f1, Mask.filter t.cov f2 in
  { t with contr; cov;
           offset = t.offset + offset + offset_2;
           strides = (s1, s2) }

(** [blit t t2] copies every coefficient of [t2] into [t]; the iteration
    follows the shapes of [t]. *)
let blit t t2 =
  Shape.iter (fun sh' ->
      Shape.iter (fun sh ->
          t.%(sh, sh') <- t2.%(sh, sh')
        ) t.contr
    ) t.cov

(** [partial_blit t (f1, f2) t2] writes the coefficients of [t2] into the
    part of [t] selected by the masks [f1] and [f2]. *)
let partial_blit t (f1, f2) t2 =
  Mask.iter_masked_dual (fun sh2 sh2' ->
      Mask.iter_masked_dual (fun sh sh' ->
          t.%(sh, sh2) <- t2.%(sh', sh2')
        ) t.contr f1
    ) t.cov f2
| 404 | 405 | let (.%[]) = slice and (.%[]<-) = partial_blit 406 | 407 | exception Break 408 | 409 | let det ( mat : t): float= 410 | let abs = abs_float in 411 | let dim = endo_dim mat in 412 | let mat = copy mat in 413 | let sign = ref 1. in 414 | let perm = MA.ordinal dim in 415 | let ( ! ) k = perm.MA.%([k]) in 416 | let swap i i' = 417 | if i <> i' then 418 | let tmp = !i in 419 | let open MA in 420 | perm.%([i]) <- !i' 421 | ; perm.%([i']) <- tmp 422 | ; sign.contents<- -.sign.contents 423 | in 424 | let pivot j = 425 | let find_max (i,max) k = 426 | let abs_k = abs_float mat.%([!k], [j] ) in 427 | if abs_k > max then (k,abs_k) else (i,max) in 428 | let start = Nat.succ j and acc = j, abs mat.%([!j], [j]) in 429 | let i, max = 430 | Nat.partial_fold ~stop:dim ~start ~acc find_max in 431 | if max > 0. then swap j i else raise Break in 432 | let transl ?(start=0) ~from ~to_ coeff = 433 | Nat.partial_iter ~start ~stop:dim (fun j -> 434 | mat.%([!to_],[j]) <- 435 | mat.%([!to_],[j]) +. coeff *. mat.%([!from],[j]) 436 | ) 437 | in 438 | try 439 | Nat.iter_on dim (fun i -> 440 | pivot i; 441 | let c = mat.%([!i],[i]) in 442 | Nat.partial_iter ~start:(Nat.succ i) ~stop:dim 443 | (fun to_ -> transl ~start:(Nat.to_int i) ~from:i 444 | ~to_ (-. mat.%([!to_],[i])/. c) ) 445 | ) 446 | ; Nat.fold (fun p k -> p *. mat.%([!k],[k])) sign.contents dim 447 | with Break -> 0. 448 | 449 | (** Given (n-1) vectors of dimension n, compute the normal to the hyperplane 450 | defined by these vectors with norm equal to their (n-1)-volume; 451 | sometimes erroneously called vector or cross product in dimension 3. 
452 | * raise Invalid_argument if array = [||] 453 | * raise dimension_error if array lenght and vector dimension disagrees 454 | *) 455 | let normal (array: 'dim vec array): 'dim vec = 456 | let nvec = A.length array in 457 | if nvec = 0 then raise @@ 458 | Invalid_argument "Tensor.normal expects array of size >0"; 459 | let open Shape in 460 | let [dim] = array.!(0).contr in 461 | let module Dyn = Nat.Dynamic(struct let dim = nvec end) in 462 | let open Nat_defs in 463 | match Nat.Sum.( Dyn.dim + _1 =? dim ) with 464 | | None -> raise @@ 465 | Signatures.Dimension_error( "Tensor.normal", nvec + 1 , Nat.to_int dim ) 466 | | Some proof -> 467 | let (%+%) = Nat.Sum.adder proof in 468 | let minor k = det @@ sq_matrix Dyn.dim (fun i j -> 469 | let offset = 470 | if Nat.( to_int i < to_int k) then _0p else _1p in 471 | array.!(Nat.to_int j).%([i %+% offset], []) 472 | ) 473 | in 474 | vector dim minor 475 | 476 | include Operators 477 | -------------------------------------------------------------------------------- /lib/tensor.mli: -------------------------------------------------------------------------------- 1 | (** Compared to multidimensional arrays, tensors carry supplementary 2 | information to facilitate common linear algebra information. 3 | 4 | More precisely, the dimensions of a tensor are divided between covariant 5 | and contravariant dimensions. A vector is for instance a (1,0) tensor, 6 | and a matrix is an (1,1) tensor. Matrix multiplication and matrix - vector 7 | multiplication is then a special case of tensor contraction: 8 | A (full) tensor contraction of a (i,k) tensor with a (k,j) tensor yields 9 | a (i,j) tensors. Matrix multiplication corresponds then to the case 10 | (1,1) tensor * (1,1) tensor → (1,1) tensor and matrix-vector multiplication 11 | is the case (1,1) tensor * (1,0) tensor → (1,0) tensor. 
12 | 13 | Note that this module implements for now an euclidian geometry, so the 14 | distinction between contravariant and covariant indices is mainly formal. 15 | 16 | *) 17 | 18 | (** {2 Type definitions }*) 19 | 20 | (** [ t] encodes a tensor 21 | with contravariant dimensions [dims] and covariant dimension [dims'] *) 22 | type 'c t constraint 'c = < contr:'n * 'a; cov:'n2 * 'b > 23 | 24 | 25 | (** Shortcut type for vectors = (1,0) tensors *) 26 | type 'dim vec = < contr : 'dim Shape.single; cov : Shape.empty > t 27 | 28 | (** Shortcut type for matrices = (1,1) tensors *) 29 | type ('l, 'c) matrix = < contr : 'l Shape.single; cov : 'c Shape.single > t 30 | 31 | (** Shortcut type for the (2,1)-tensors *) 32 | type ('d1, 'd2, 'd3) t3 = 33 | < contr : ('d1, 'd2) Shape.pair; cov : 'd3 Shape.single > t 34 | 35 | (** {2 Printing} *) 36 | 37 | (** Pretty-printer for tensor *) 38 | val pp : Format.formatter -> < contr : 'a; cov : 'b > t -> unit 39 | 40 | (** Conversion to string *) 41 | val show : < contr : 'a; cov : 'b > t -> string 42 | 43 | (** {2 Access operators} *) 44 | 45 | (** tensor access: 46 | with ppx_tensority: [ t.(i_1, ..., i_n; j_1, ..., j_n) ] 47 | without ppx [ t.([i_1;...;i_n], [j_1;...,j_n]) ] 48 | *) 49 | val (.%()): < contr : 'a; cov : 'b > t -> 'a Shape.lt * 'b Shape.lt -> float 50 | 51 | (** tensor access: 52 | with ppx_tensority: [ t.%(i_1, ..., i_n; j_1, ..., j_n) <- x ] 53 | without ppx [ t.%([i_1;...;i_n], [j_1;...,j_n]) <- x ] 54 | *) 55 | val (.%()<-): 56 | < contr : 'a; cov : 'b > t -> 'a Shape.lt * 'b Shape.lt -> float -> unit 57 | 58 | (** {2 Shape functions} *) 59 | 60 | (** total size of the covariant dimensions *) 61 | val cov_size : < contr : 'a; cov : 'b > t -> int 62 | 63 | (** total size of the contravariant dimensions *) 64 | val contr_size : < contr : 'a; cov : 'b > t -> int 65 | 66 | (** total size of the tensor *) 67 | val len : < contr : 'a; cov : 'b > t -> int 68 | 69 | (** Contravariant shape *) 70 | val contr_dims : < 
contr : 'a; cov : 'b > t -> 'a Shape.eq 71 | 72 | (** Covariant shape *) 73 | val cov_dims : < contr : 'a; cov : 'b > t -> 'b Shape.eq 74 | 75 | (** Dimension of an endomorphism *) 76 | val endo_dim : ('a, 'a) matrix -> 'a Nat.eq 77 | 78 | (** A tensor is sparse if the number of elements within 79 | addressed by the tensor is strictly less than 80 | the lenght of the underlying array *) 81 | val is_sparse : < contr : 'a; cov : 'b > t -> bool 82 | 83 | (** {2 Construction function } *) 84 | 85 | (** Unsafe module for function with a runtime check *) 86 | module Unsafe : sig 87 | 88 | (** Create a tensor from an array. 89 | @raise Dimension_error if the array lenght is not compatible 90 | with the given dimension*) 91 | val create : 92 | contr:'a Shape.eq -> 93 | cov:'b Shape.eq -> float array -> < contr : 'a; cov : 'b > t 94 | end 95 | 96 | (** Creates a tensor with constant coefficients *) 97 | val const : 98 | contr:'a Shape.eq -> cov:'b Shape.eq -> float -> < contr : 'a; cov : 'b > t 99 | 100 | (** zero tensor *) 101 | val zero : contr:'a Shape.eq -> cov:'b Shape.eq -> < contr : 'a; cov : 'b > t 102 | 103 | (** [init_sh f sh1 sh2] creates a tensor [ t ] such 104 | that for all s≺sh1 and s2≺sh2 [ t.(s,s2) = f s s2 ] 105 | *) 106 | val init_sh : 107 | ('a Shape.lt -> 'b Shape.lt -> float) -> 108 | contr:'a Shape.eq -> cov:'b Shape.eq -> < contr : 'a; cov : 'b > t 109 | 110 | (** [vector n f] computes the n vector [ v_{i} = f i ] *) 111 | val vector : 'a Nat.eq -> ('a Nat.lt -> float) -> 'a vec 112 | 113 | (** [matrix k l f] computes the k×l matrix [ m_{i,j} = f i j ] *) 114 | val matrix : 115 | 'a Nat.eq -> 'b Nat.eq -> ('a Nat.lt -> 'b Nat.lt -> float) -> ('a, 'b) matrix 116 | 117 | (** [sq_matrix n f] computes the square matrix [matrix n n f] *) 118 | val sq_matrix : 119 | 'a Nat.eq -> ('a Nat.lt -> 'a Nat.lt -> float) -> ('a, 'a) matrix 120 | 121 | (** copy a tensor *) 122 | val copy : < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 123 | 124 | (** {2 
Reshaping and indice manipulation} *) 125 | val reshape : 126 | < contr : 'a; cov : 'b > t -> 127 | 'c Shape.eq * 'd Shape.eq -> < contr : 'c; cov : 'd > t 128 | 129 | (** [transpose t] is the totally transposed tensor such that 130 | (transpose t).(s;s') = t.(s';s). 131 | Note that this transposition supposes that space is flat to be geometrically 132 | valid. 133 | *) 134 | val transpose : < contr : 'a; cov : 'b > t -> < contr : 'b; cov : 'a > t 135 | 136 | 137 | (** {2 Linear basis function} *) 138 | 139 | (** [delta i j = 1. ] if [i = j], [0] otherwise *) 140 | val delta : ('a, 'b) Nat.t -> ('c, 'd) Nat.t -> float 141 | 142 | (** [id n] is the identity matrix in dimension n *) 143 | val id : 'a Nat.eq -> ('a, 'a) matrix 144 | 145 | (** [base n i] is the n vector such that [ v.{j} = delta i j ] *) 146 | val base : 'a Nat.eq -> 'a Nat.lt -> 'a vec 147 | 148 | (** {2 Tensor level operation} *) 149 | 150 | (** Tensor contraction (or multiplication) of a (d,d') tensor and 151 | a (d',d'') tensor: 152 | [ (mult t1 t2).[s;s''] = ∑_{s'≺d'} t1.(s;s') * t2.(s';s'') ] 153 | *) 154 | val mult : 155 | < contr : 'a; cov : 'b > t -> 156 | < contr : 'b; cov : 'c > t -> < contr : 'a; cov : 'c > t 157 | 158 | 159 | (** Contract a (d,d) tensor with itself 160 | [trace t = ∑_{s≺d} t.(s,s)] 161 | *) 162 | val trace : 163 | < contr : 'a; cov : 'a > t -> float 164 | 165 | (** Fully contracts two tensors to obtains a scalar 166 | [full contraction a b = trace (mult a b)] 167 | *) 168 | val full_contraction : 169 | < contr : 'a; cov : 'b > t -> < contr : 'b; cov : 'a > t -> float 170 | 171 | (** Canonical scalar product of two tensors of dimensions (d, d'): 172 | [scalar_product t1 t2 = ∑_{s≺d,s'≺d'} t1.[s;s'] * t2.[s;s'] ] *) 173 | val scalar_product : 174 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t -> float 175 | 176 | (** [map2 f t1 t2] is the tensor [ t.(s) = f t1.(s) t2.(s) ] *) 177 | val map2 : 178 | (float -> float -> float) -> 179 | < contr : 'a; cov : 'b > t -> 
180 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 181 | 182 | (** [scalar_map f t1] is the tensor [ t.(s) = f t.(s) ] *) 183 | val scalar_map : 184 | (float -> float) -> 185 | < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t 186 | 187 | (** matrix power: 188 | [pow_int m 0 = id ] 189 | [pow_int m n = m * pow_int m (n-1) ] 190 | *) 191 | val pow_int : ('a,'a) matrix -> int -> ('a,'a) matrix 192 | 193 | (** Linear algebra basic operations *) 194 | module Operators: Signatures.tensor_operators with 195 | type 'a t := 'a t and 196 | type ('a,'b) matrix := ('a,'b) matrix 197 | include Signatures.tensor_operators with 198 | type 'a t := 'a t and 199 | type ('a,'b) matrix := ('a,'b) matrix 200 | 201 | (** {2 Slice related operations} *) 202 | 203 | 204 | val partial_copy : 205 | < contr : 'a; cov : 'b > t -> ('a, 'c) Mask.t * ('b, 'd) Mask.t -> 206 | < contr : 'c; cov : 'd > t 207 | 208 | val slice : 209 | < contr : 'a; cov : 'b > t -> 210 | ('a, 'c) Mask.t * ('b, 'd) Mask.t -> < contr : 'c; cov : 'd > t 211 | 212 | val blit : < contr : 'a; cov : 'b > t -> < contr : 'a; cov : 'b > t -> unit 213 | val partial_blit : 214 | < contr : 'a; cov : 'b > t -> 215 | ('a, 'c) Mask.s_to_eq * ('b, 'd) Mask.s_to_eq -> 216 | < contr : 'c; cov : 'd > t -> unit 217 | 218 | val (.%[]) : 219 | < contr : 'a; cov : 'b > t -> 220 | ('a, 'c) Mask.t * ('b, 'd) Mask.t -> 221 | < contr : 'c; cov : 'd > t 222 | 223 | val (.%[]<-) : 224 | < contr : 'a; cov : 'b > t -> 225 | ('a, 'c) Mask.s_to_eq * ('b, 'd) Mask.s_to_eq -> 226 | < contr : 'c; cov : 'd > t -> unit 227 | 228 | val det : < contr : 'a Shape.single; cov : 'a Shape.single > t -> float 229 | val normal : 'dim vec array -> 'dim vec 230 | -------------------------------------------------------------------------------- /ppx/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (public_name tensority_ppx) 3 | (kind ppx_rewriter) 4 | (preprocess (pps ppxlib.metaquot)) 5 | 
(libraries ppxlib) 6 | ) -------------------------------------------------------------------------------- /ppx/ppx_tensority.ml: -------------------------------------------------------------------------------- 1 | open Ppxlib 2 | 3 | module H = Ast_helper 4 | module T = H.Typ 5 | 6 | type tensor_kind = { fn: Location.t -> Parsetree.expression; name:string } 7 | let mk_array = 8 | { fn = (fun loc -> [%expr Multidim_array.Unsafe.create]); name = "array" } 9 | 10 | let mk_tensor = 11 | { fn = (fun loc -> [%expr Tensor.Unsafe.create]); name = "tensor" } 12 | 13 | 14 | type 'a loc = 'a Location.loc 15 | type label = Asttypes.label 16 | type closed_flag = Asttypes.closed_flag = Closed | Open 17 | 18 | let mkloc ~loc txt = Location.{txt; loc } 19 | 20 | let error loc format = 21 | Location.raise_errorf ~loc ("ppx_tensority:" ^^ format) 22 | 23 | 24 | module Lid = struct 25 | open Longident 26 | let (!) s = Lident s 27 | let (<*>) m s = Ldot ( m, s) 28 | let ( $ ) f x = Lapply(f,x) 29 | end 30 | 31 | module Polyvar = struct 32 | type row_tag = { 33 | label: label 34 | ; attributes: attributes 35 | ; empty_type: bool 36 | ; conjunction: core_type list 37 | } 38 | 39 | let tag ~loc ?(empty_type=false) ?(conjunction=[]) label = 40 | { prf_desc = Rtag (label, empty_type, conjunction ); 41 | prf_attributes = []; 42 | prf_loc = loc; 43 | } 44 | 45 | let set ~loc typ = { prf_desc = Rinherit typ; prf_attributes = []; prf_loc = loc } 46 | 47 | type simple_var = core_type list 48 | 49 | let var_simple loc ?(closed=Closed) ?lower_bound_opt upper_bound = 50 | { 51 | ptyp_desc = Ptyp_variant (upper_bound, closed, lower_bound_opt) 52 | ; ptyp_loc = loc 53 | ; ptyp_loc_stack = [] 54 | ; ptyp_attributes = [] 55 | } 56 | 57 | 58 | let var loc types = 59 | if List.length types > 1 then 60 | var_simple loc ~closed:Closed ~lower_bound_opt:[] types 61 | else 62 | var_simple loc ~closed:Closed types 63 | 64 | let var_low types = 65 | var_simple ~closed:Closed ~lower_bound_opt:[] types 66 | (* 
67 | let var ?(conjunction=[]) label = variant @@ tag ~conjunction label 68 | *) 69 | end 70 | 71 | module Expr = struct 72 | let rec sequence loc = function 73 | | [x] -> x 74 | | a :: q -> H.Exp.sequence ~loc a (sequence loc q) 75 | | _ -> assert false 76 | 77 | let rec extract_sequence =function 78 | | [%expr [%e? h]; [%e? r] ] -> h :: (extract_sequence r) 79 | | e -> [e] 80 | 81 | let rec extract_list =function 82 | | [%expr [%e? h] :: [%e? r] ] -> h :: (extract_list r) 83 | | [%expr [] ] -> [] 84 | | e -> error e.pexp_loc "wrong kind of expression, a list was expected" 85 | 86 | 87 | let rec to_list loc = function 88 | | [] -> [%expr [] ][@metaloc loc] 89 | | [e] -> [%expr [[%e e]] ][@metaloc loc] 90 | | a::q -> [%expr [%e a]::[%e to_list loc q] ][@metaloc loc] 91 | end 92 | 93 | module Ints = struct 94 | 95 | 96 | let to_label n = "_" ^ string_of_int n 97 | 98 | let t loc = Polyvar.tag ~loc ~empty_type:true (Loc.make ~loc "T") 99 | let eq loc = [%type: [`Eq] ] 100 | let lt loc = [%type: [`Lt] ] 101 | let le loc = [%type: [`Lt|`Eq] ] 102 | 103 | module Expr = struct 104 | let nat loc kind typ value = 105 | [%expr let open Nat in 106 | (Unsafe.create [%e value]: ([%t typ],[%t kind]) Nat.t) ] 107 | [@metaloc loc] 108 | 109 | let shape loc kind typ value = 110 | [%expr Mask.Elt [%e nat loc kind typ value]][@@metaloc loc] 111 | 112 | let int ~loc n = H.Exp.constant ~loc (H.Const.int n) 113 | end 114 | 115 | (* a digit [k] followed by [t] *) 116 | let digit ~loc k t = 117 | Polyvar.tag ~loc ~conjunction:[t] (Loc.make ~loc @@ to_label k) 118 | 119 | (* number of digits *) 120 | let size n = int_of_float @@ log10 @@ float n 121 | 122 | 123 | let rec digit_split n = 124 | if n < 10 then n, 0 125 | else 126 | let d, k = digit_split (n/10) in 127 | d, k * 10 + n mod 10 128 | 129 | module Eq = struct 130 | (** '( =n ) nat *) 131 | module Type = struct 132 | let rec int_rec loc n inner = 133 | let open Polyvar in 134 | if n < 10 then 135 | var loc [ digit ~loc n inner ] 
136 | else 137 | let l, k = n mod 10, n / 10 in 138 | int_rec loc k @@ var loc [ digit ~loc l inner ] 139 | 140 | let int loc n = int_rec loc n @@ Polyvar.var loc [t loc] 141 | end 142 | 143 | let int loc n = 144 | Expr.shape loc (eq loc) (Type.int loc n) (Expr.int ~loc n) 145 | 146 | let nat loc n = 147 | Expr.nat loc (eq loc) (Type.int loc n) (Expr.int ~loc n) 148 | end 149 | 150 | module Types = struct 151 | 152 | let ($) f t = T.constr f t 153 | 154 | let gtp loc k t = 155 | let lid = Lid.( !"Nat_defs" <*> "Gtp" <*> to_label k ) in 156 | let lid = mkloc ~loc lid in 157 | (lid $ [t]) 158 | 159 | let lep loc k t = 160 | let lid = Lid.( !"Nat_defs" <*> "Lep" <*> to_label k ) in 161 | let lid = mkloc ~loc lid in 162 | (lid $ [t]) 163 | 164 | let all loc t = gtp loc 0 t 165 | let ending loc = 166 | let open Polyvar in 167 | var loc [t loc; set ~loc @@ all loc @@ T.any () ] 168 | 169 | let rec digits loc k = if k = 0 then 170 | ending loc 171 | else 172 | let open Polyvar in 173 | var_low loc [ set ~loc @@ all loc (digits loc @@ k - 1) ] 174 | end 175 | 176 | module L = struct 177 | 178 | let rec int_aux loc k len inner = 179 | let open Types in 180 | let inner d = 181 | let open Polyvar in 182 | let l = [ digit ~loc d inner ] in 183 | let l = if d<9 then 184 | (set ~loc @@ gtp loc (d+1) @@ digits loc len ) :: l 185 | else 186 | l 187 | in 188 | if d > 0 then 189 | var loc @@ set ~loc (lep loc (d-1) @@ digits loc @@ 1 + len) :: l 190 | else 191 | var_low loc l 192 | in 193 | if k < 10 then 194 | inner k 195 | else 196 | let d, k = k mod 10, k / 10 in 197 | int_aux loc k (len + 1) (inner d) 198 | end 199 | 200 | 201 | module Lt = struct 202 | 203 | module Type = struct 204 | open Types 205 | let int loc k = 206 | if k = 0 then 207 | ending loc 208 | else 209 | let open Polyvar in 210 | L.int_aux loc k 0 (var_low loc [ set ~loc @@ all loc @@ ending loc ] ) 211 | end 212 | 213 | let int loc k = 214 | Expr.shape loc (lt loc) (Type.int loc k) (Expr.int ~loc k) 215 | 216 | 
let nat loc k = 217 | Expr.nat loc (lt loc) (Type.int loc k) (Expr.int ~loc k) 218 | end 219 | 220 | module Le = struct 221 | 222 | module Type = struct 223 | open Types 224 | let int loc k = 225 | if k = 0 then 226 | ending loc 227 | else 228 | let open Polyvar in 229 | L.int_aux loc k 0 (var_low loc [ set ~loc @@ ending loc ] ) 230 | end 231 | 232 | let int loc k = 233 | Expr.shape loc (le loc) (Type.int loc k) (Expr.int ~loc k) 234 | 235 | let nat loc k = 236 | Expr.nat loc (le loc) (Type.int loc k) (Expr.int ~loc k) 237 | end 238 | 239 | 240 | end 241 | (* [expect_int name e] returns the value of [e] when it is an integer literal without suffix; any other expression is reported as an error of the [%name] extension. *) 242 | let expect_int name = function 243 | | { pexp_desc = Pexp_constant Pconst_integer(n, None); _ } -> 244 | int_of_string n 245 | | e -> error e.pexp_loc "[%%%s] expected an integer as first argument" 246 | name 247 | (* [nat_constant (k, f)] is a rule for integer literals with suffix [k]: the literal digits are converted with int_of_string and passed to [f] together with the literal location. *) 248 | let nat_constant (k,f) = 249 | Context_free.Rule.( 250 | constant Integer k (fun loc s -> f loc (int_of_string s)) 251 | ) 252 | (* Integer literal suffixes recognised by the ppx, each paired with an encoder from the Ints module (Eq, Lt or Le). NOTE(review): k and s both map to Eq.nat while i and j map to Lt.nat and Lt.int respectively; confirm that s was not meant to be Eq.int. *) 253 | 254 | let constants = List.map nat_constant 255 | Ints.[ 'k', Eq.nat; 256 | 's', Eq.nat ; 257 | 'i', Lt.nat; 258 | 'j', Lt.int; 259 | 'p', Le.nat 260 | ] 261 | (* Rewriting of index expressions for the indexing rules of the Index module. *) 262 | module Index_rewriter = struct 263 | (* A tuple index is turned into Expr.to_list of its components; any other expression into Expr.to_list of a single element. *) 264 | let tuples = function 265 | | {pexp_desc = Pexp_tuple l; _ } as e -> 266 | Expr.to_list e.pexp_loc @@ l 267 | | e -> Expr.to_list e.pexp_loc [ e ] 268 | (* [seq inner s] applies [inner] to every element of the sequence [s]; a single element is returned directly, several are packed into a tuple. *) 269 | let seq inner seq = 270 | let l = Expr.extract_sequence seq in 271 | match l with 272 | | [a] -> inner a 273 | | l -> H.Exp.tuple ~loc:seq.pexp_loc (List.map inner l) 274 | (* Like [seq], but for an array literal index, whose elements are always packed into a tuple; any other expression is an error. *) 275 | let array inner = function 276 | | { pexp_desc = Pexp_array l; _ } as a -> 277 | H.Exp.tuple ~loc:a.pexp_loc (List.map inner l) 278 | | a -> 279 | error a.pexp_loc "[.!(;..)] expected an array literal as index" 280 | (* [simple] rewrites a plain index; [all] also recognises the forms [i; __] and [__; i], which become the pairs (tuples i, []) and ([], tuples i). *) 281 | 282 | let simple x = seq tuples x 283 | let all = 284 | function 285 | | [%expr [%e? i ]; __ ] -> 286 | let loc = i.pexp_loc in 287 | [%expr [%e tuples i], []] 288 | | [%expr __ ; [%e? 
i ] ] -> 289 | let loc = i.pexp_loc in 290 | [%expr [], [%e tuples i]] 291 | | e -> simple e 292 | (* Array-literal counterpart of [all]: [| i; __ |] and [| __; i |] become a one-element list holding the corresponding pair; other literals go through [array tuples]. *) 293 | let all_array = 294 | function 295 | | [%expr [| [%e? i ]; __ |] ] -> 296 | let loc = i.pexp_loc in 297 | [%expr [ [%e tuples i], [] ] ] 298 | | [%expr [| __ ; [%e? i ] |] ] -> 299 | let loc = i.pexp_loc in 300 | [%expr [ [], [%e tuples i] ] ] 301 | | e -> array tuples e 302 | 303 | 304 | end 305 | (* Payload pattern: a structure made of exactly one expression item. *) 306 | let single_str_expr = Ast_pattern.(pstr (pstr_eval __ __ ^:: nil) ) 307 | (* Literal extensions [%array], [%tensor], [%vec] and [%matrix]. *) 308 | module Array_lit = struct 309 | (* NOTE(review): [_vec] is not referenced in this part of the file and is not one of the registered [rules]. *) 310 | let _vec loc = function 311 | | { pexp_desc = Pexp_tuple s; _ } as e -> 312 | let a = {e with pexp_desc = Pexp_array s} in 313 | let nat = Ints.Eq.nat loc @@ List.length s in 314 | [%expr Unsafe.create Shape.[[%e nat]] [%e a]][@metaloc loc] 315 | | _ -> 316 | error loc "expected tuple in [%%vec ...]" 317 | (* A parsed nested literal: a leaf, or a first sub-literal followed by the remaining ones, so that a [Nested] node is never empty. *) 318 | type 'a nested_list = 319 | | Elt of 'a 320 | | Nested of 'a nested_list loc * 'a nested_list loc list 321 | (* [extract_nested kind loc level e] parses [e] as a literal nested [level] times: even levels must be non-empty list literals, odd levels comma-separated tuples, and level 0 is a leaf. Error messages mention [kind.name]. *) 322 | let rec extract_nested ({name; _ } as k) loc level e = 323 | if level = 0 then 324 | mkloc ~loc @@ Elt e 325 | else 326 | let loc, a, q = 327 | if level mod 2 = 0 then 328 | match e with 329 | | [%expr [%e? a] :: [%e? 
b] ] -> 330 | e.pexp_loc, a, Expr.extract_list b 331 | | [%expr [] ] -> 332 | error e.pexp_loc 333 | "[%%%s] invalid input: a non-empty list was expected" 334 | name 335 | | e -> error e.pexp_loc 336 | "[%%%s] invalid input: a list of %d-tensors was expected" 337 | name level 338 | else 339 | match e with 340 | | {pexp_desc = Pexp_tuple (a::q) ; _ } as e -> e.pexp_loc, a, q 341 | | e -> error e.pexp_loc 342 | "[%%%s] invalid input: a list of comma separated %d-tensors \ 343 | was expected" name level 344 | in 345 | let extr e = extract_nested k loc (level - 1) e in 346 | mkloc ~loc @@ Nested( extr a, List.map extr q) 347 | [@@warning "-4"] 348 | (* Shape of a nested literal, outermost dimension first; fails when the sub-literals of a node do not all share the same shape. *) 349 | let rec compute_and_check_shape kind loc level = 350 | let error_ppx = error in 351 | let open Location in 352 | function 353 | | { txt = Elt _ ; _ } -> [] 354 | | { txt = Nested (a,q); _ } -> 355 | let n = 1 + List.length q in 356 | let shape0 = 357 | compute_and_check_shape kind a.loc (level - 1) a in 358 | let test (e:_ loc) = 359 | shape0 = compute_and_check_shape kind e.loc (level - 1) e in 360 | if List.for_all test q 361 | then n :: shape0 362 | else error_ppx loc "[%%%s]: non-valid sub-tensor shape at level %d" 363 | kind.name level 364 | (* [flatten_nested n l] prepends the leaves of [n], in left-to-right order, to [l]. *) 365 | let rec flatten_nested n l = 366 | let open Location in 367 | match n.txt with 368 | | Elt e -> e :: l 369 | | Nested(a, q) -> 370 | flatten_nested a @@ 371 | List.fold_right flatten_nested q l 372 | (* [transpose sh l] reads [l] as laid out with shape [sh], first index varying fastest (see [pos]), and moves every element to the position of its reversed multi-index under the reversed shape. Raises Invalid_argument on an empty [l]. *) 373 | let transpose sh l = 374 | let a = Array.of_list l in 375 | let a' = Array.make (Array.length a) a.(0) in 376 | let sh' = List.rev sh in 377 | let rec pos l k = match l, k with 378 | | [], [] -> 0 379 | | n :: q, p :: qp -> p + n * pos q qp 380 | | _ -> raise (Invalid_argument "transpose") in 381 | let rec shape_iter f = function 382 | | [] -> f [] 383 | | a :: q -> 384 | for i = 0 to a - 1 do 385 | shape_iter (fun l -> f (i::l) ) q done 386 | in 387 | let () = (* do the transposition *) 388 | shape_iter (fun k -> 389 | a'.(pos sh' (List.rev k) ) <- a.(pos sh k) 
390 | ) 391 | sh in 392 | Array.to_list a' 393 | (* Expand an [%array] literal with [level] nesting levels into [kind.fn shape elements]; the shape list is built with List.rev_map, hence innermost dimension first. *) 394 | let array loc level e = 395 | let kind = mk_array in 396 | let nested = extract_nested kind loc level e in 397 | let shape_int = compute_and_check_shape kind loc level nested in 398 | let l = flatten_nested nested [] in 399 | let shape = Expr.to_list loc @@ List.rev_map (Ints.Eq.nat loc) shape_int in 400 | let array = H.Exp.array ~loc l in 401 | [%expr [%e kind.fn loc] [%e shape] [%e array] ] 402 | (* [split n l] returns the first [n] elements of [l] and the remaining ones; raises Invalid_argument when [l] is shorter than [n]. *) 403 | let rec split n l = 404 | if n = 0 then [], l else 405 | match l with 406 | | a :: q -> 407 | let left, right = split (n-1) q in 408 | a :: left, right 409 | | [] -> raise @@ Invalid_argument ( 410 | Printf.sprintf "split %d [] is not valid" n 411 | ) 412 | (* Expand a literal with [contr + cov] nesting levels: the first [contr] dimensions form the ~contr shape and the others the ~cov shape, each listed innermost dimension first. *) 413 | let tensor loc ~contr ~cov e = 414 | let kind = mk_tensor in 415 | let level = contr + cov in 416 | let nested = extract_nested kind loc level e in 417 | let shape_int = compute_and_check_shape kind loc level nested in 418 | let l = flatten_nested nested [] in 419 | let contr, cov = split contr shape_int in 420 | let shape l = Expr.to_list loc @@ List.rev_map (Ints.Eq.nat loc) l in 421 | let array = H.Exp.array ~loc l in 422 | [%expr [%e kind.fn loc] 423 | ~contr:[%e shape contr] 424 | ~cov:[%e shape cov] 425 | [%e array] 426 | ] 427 | (* [%array n lit] uses [n] nesting levels, with [n] an integer literal; a payload that is not a one-argument application is read with one level. *) 428 | let array_mapper loc = function 429 | | [%expr [%e? n] [%e? arr] ] -> 430 | array loc (expect_int mk_array.name n) arr 431 | | arr -> array loc 1 arr 432 | 433 | let array = 434 | let open Extension in 435 | declare "array" Context.expression single_str_expr 436 | (fun ~loc ~path:_ e _ -> array_mapper loc e) 437 | (* [%tensor contr cov lit], where [contr] and [cov] are integer literals; any other payload is rejected. *) 438 | 439 | let tensor_mapper loc = 440 | function 441 | | [%expr [%e? contr] [%e? cov] [%e? 
array] ] -> 442 | tensor loc 443 | ~contr:(expect_int mk_tensor.name contr) 444 | ~cov:(expect_int mk_tensor.name cov) 445 | array 446 | | e -> error e.pexp_loc "[%%%s] expected two integer literals followed by a nested literal" mk_tensor.name 447 | (* [%vec] literals are tensors with ~contr:1 ~cov:0, [%matrix] literals with ~contr:1 ~cov:1. *) 448 | let vec_mapper loc array = 449 | tensor loc ~contr:1 ~cov:0 array 450 | 451 | let matrix_mapper loc array = 452 | tensor loc ~contr:1 ~cov:1 array 453 | (* [ext name f] declares the expression extension [%name], whose single-expression payload is rewritten by [f loc]. *) 454 | 455 | let ext name f = 456 | let open Extension in 457 | declare name Context.expression 458 | Ast_pattern.(pstr (pstr_eval __ __ ^:: nil) ) 459 | (fun ~loc ~path:_ e _ -> f loc e) 460 | 461 | let tensor = ext "tensor" tensor_mapper 462 | let vec = ext "vec" vec_mapper 463 | let matrix = ext "matrix" matrix_mapper 464 | (* Context-free rules for the literal extensions. *) 465 | 466 | let rules = 467 | List.map Context_free.Rule.extension [array; tensor; vec ; matrix ] 468 | 469 | end 470 | (* [%range] extension and #-># operator building Mask.Range values from integer literals. *) 471 | module Range = struct (* [range loc ?by start stop] requires start <= stop and a positive step, 1 by default, then builds Mask.Range (Range.create ...) with len = 1 + (stop - start) / step. *) 472 | let range loc ?by start stop = 473 | let int = expect_int "range" in 474 | let start = int start and stop = int stop in 475 | let step = match by with Some step -> int step | None -> 1 in 476 | if stop < start then 477 | error loc "[%%range]: invalid argument start indice %d > stop indice %d" 478 | start stop 479 | else if step <= 0 then 480 | error loc "[%%range]: invalid argument step %d ≤ 0" 481 | step 482 | ; let len = 1 + (stop - start) / step in 483 | let start = Ints.Lt.nat loc start and stop = Ints.Lt.nat loc stop 484 | and len = Ints.Eq.nat loc len and step = Ints.Expr.int ~loc step in 485 | [%expr Mask.Range( 486 | Range.create 487 | ~start:[%e start] ~stop:[%e stop] ~step:[%e step] ~len:[%e len] 488 | ) 489 | ] 490 | (* Payload of [%range]: two integer literals, optionally followed by ~by:step. *) 491 | let extension_match loc = function 492 | | [%expr [%e? start] [%e? stop] ~by:[%e? step] ] as e -> 493 | range e.pexp_loc ~by:step start stop 494 | | [%expr [%e? start] [%e? 
stop] ] as e -> 495 | range e.pexp_loc start stop 496 | | _ -> error loc "Unsupported range expression" 497 | let extension = let open Extension in 498 | declare "range" Context.expression single_str_expr 499 | (fun ~loc ~path:_ e _ -> extension_match loc e) 500 | 501 | let rule_extension = Context_free.Rule.extension extension 502 | (* Rewrites applications of the #-># operator. NOTE(review): #-operators bind tighter than application and than %, so [a #-># b % s] parses as [(a #-># b) % s], and [(a #-># b) ~by:s] keeps the operator application nested; this rule presumably only ever sees the inner [a #-># b], leaving the first case unreachable. Verify. *) 503 | let special_fn = 504 | Context_free.Rule.special_function 505 | "#->#" (function 506 | | [%expr ([%e? start] #-># [%e? stop]) ~by:[%e? step] ] 507 | | [%expr [%e? start] #-># [%e? stop] % [%e? step] ] as e -> 508 | Some(range e.pexp_loc ~by:step start stop) 509 | | [%expr [%e? start] #-># [%e? stop] ] as e -> 510 | Some(range e.pexp_loc start stop) 511 | | _ -> None 512 | ) 513 | 514 | let rules = [ rule_extension; special_fn] 515 | 516 | end 517 | (* Indexing rules: .!() slicing is rewritten to .%[ ] over Tensority.Mask, and .() access to .%( ) over Tensority.Shape. *) 518 | module Index = struct 519 | (* a.!(i) becomes a.%[Tensority.Mask.(i2)], with i2 the index rewritten by Index_rewriter.all. *) 520 | let slice = 521 | Context_free.Rule.special_function 522 | "(.!())" 523 | (function 524 | | [%expr [%e? a].!( [%e? i] ) ] as e -> 525 | let loc = e.pexp_loc in 526 | let i = Index_rewriter.all i in 527 | Some [%expr [%e a].%[ Tensority.Mask.( [%e i] )] ] 528 | | _ -> None 529 | ) 530 | (* Multi-index form a.!(i; j), whose indices arrive as an array literal rewritten by Index_rewriter.all_array. *) 531 | let slice_bis = 532 | Context_free.Rule.special_function 533 | "(.!(;..))" 534 | (function 535 | | [%expr (.!(;..)) [%e? a] [%e? i] ] as e -> 536 | let loc = e.pexp_loc in 537 | let i = Index_rewriter.all_array i in 538 | Some [%expr [%e a].%[ Tensority.Mask.( [%e i] )] ] 539 | | _ -> None 540 | ) 541 | (* a.(i) becomes a.%(Tensority.Shape.(i2)), with i2 the index rewritten by Index_rewriter.all. *) 542 | 543 | let access = 544 | Context_free.Rule.special_function 545 | "Array.get" 546 | ( function 547 | | [%expr [%e? a].([%e? i]) ] as e -> 548 | let loc = e.pexp_loc in 549 | let i = Index_rewriter.all i in 550 | Some [%expr [%e a].%( Tensority.Shape.( [%e i] ) ) ] 551 | | _ -> None 552 | ) 553 | (* Assignment counterpart of [slice]: a.!(i) <- v. *) 554 | let blit = 555 | Context_free.Rule.special_function 556 | "(.!()<-)" 557 | (function 558 | | [%expr [%e? a].!([%e? i] ) <- [%e? 
v] ] as e -> 559 | let loc = e.pexp_loc in 560 | let i = Index_rewriter.all i in 561 | Some [%expr [%e a].%[Tensority.Mask.([%e i])]<- [%e v] ] 562 | | _ -> None 563 | ) 564 | (* Assignment counterpart of [slice_bis]: a.!(i; j) <- v. *) 565 | let blit_bis = 566 | Context_free.Rule.special_function 567 | "(.!(;..)<-)" 568 | (function 569 | | [%expr (.!(;..)<-) [%e? a] [%e? i] [%e? v] ] as e -> 570 | let loc = e.pexp_loc in 571 | let i = Index_rewriter.all_array i in 572 | Some [%expr [%e a].%[Tensority.Mask.([%e i])]<- [%e v] ] 573 | | _ -> None 574 | ) 575 | (* a.(i) <- v becomes a.%(Tensority.Shape.(i2)) <- v; the index is rewritten by Index_rewriter.all exactly as in [access], so reads and writes accept the same index syntax. *) 576 | 577 | let assign = 578 | Context_free.Rule.special_function 579 | "Array.set" 580 | (function 581 | | [%expr [%e? a].([%e? i]) <- [%e? v] ] as e -> 582 | let loc = e.pexp_loc in let i = Index_rewriter.all i in 583 | Some [%expr [%e a].%(Tensority.Shape.([%e i]))<- [%e v] ] 584 | | _ -> None 585 | ) 586 | 587 | 588 | let rules = [slice; slice_bis; blit; blit_bis; access; assign] 589 | 590 | end 591 | (* Register all rules of the ppx: ranges, literal extensions, indexing and suffixed integer literals. *) 592 | let () = Driver.register_transformation 593 | ~rules:(Range.rules 594 | @ Array_lit.rules 595 | @ Index.rules 596 | @ constants) 597 | "ppx_tensority" 598 | -------------------------------------------------------------------------------- /tensority.opam: -------------------------------------------------------------------------------- 1 | opam-version: "2.0" 2 | name: "tensority" 3 | version: "0.1" 4 | maintainer: "octachron " 5 | authors:"octachron " 6 | homepage: "https://github.com/Octachron/tensority" 7 | dev-repo: "git+https://github.com/Octachron/tensority.git" 8 | bug-reports: "https://github.com/Octachron/tensority/issues" 9 | license: "LGPL-2.1-only WITH OCaml-LGPL-linking-exception" 10 | build: ["dune" "build" "-p" name ] 11 | depends: [ "dune" {build} ] 12 | -------------------------------------------------------------------------------- /tensority_ppx.opam: -------------------------------------------------------------------------------- 1 | opam-version: "2.0" 2 | name: "tensority_ppx" 3 | version: "0.1" 4 | maintainer: "octachron " 5 | authors:"octachron " 6 | homepage: 
"https://github.com/Octachron/tensority" 7 | dev-repo: "git+https://github.com/Octachron/tensority.git" 8 | bug-reports: "https://github.com/Octachron/tensority/issues" 9 | license: "LGPL-2.1-only WITH OCaml-LGPL-linking-exception" 10 | build: ["dune" "build" "-p" name ] 11 | depends: [ "dune" {build} "ppxlib" ] 12 | --------------------------------------------------------------------------------