├── .gitignore ├── CHANGELOG.md ├── Cargo.lock ├── Cargo.toml ├── DESIGN.md ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── clippy.toml ├── rustfmt.toml ├── src ├── array.rs ├── dim.rs ├── expr │ ├── adapters.rs │ ├── buffer.rs │ ├── expression.rs │ ├── into_expr.rs │ ├── iter.rs │ ├── mod.rs │ └── sources.rs ├── index │ ├── axis.rs │ ├── mod.rs │ ├── permutation.rs │ ├── slice.rs │ └── view.rs ├── layout.rs ├── lib.rs ├── macros.rs ├── mapping.rs ├── ops.rs ├── raw_slice.rs ├── raw_tensor.rs ├── serde.rs ├── shape.rs ├── slice.rs ├── tensor.rs ├── traits.rs └── view.rs └── tests ├── aligned_alloc └── mod.rs └── test.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [0.7.0] - 2025-02-23 9 | 10 | - Update dependencies. 11 | - Add Apply and BorrowMut for Owned, and remove T: Default bound for Apply. 12 | - Add Ord for Shape and Copy for ConstShape. 13 | - Add Owned trait for deriving type in conversions from slices and expressions. 14 | - Make axis indexing and iteration more consistent and complete. 15 | - Update to edition 2024. 16 | - Implement IntoShape for array references, see #6. 17 | - Remove Concat and bound in Reverse, and hide types for indexing. 18 | - Improve documentation of identity permutation, see #6. 19 | - Update Axis trait and helper types. 20 | - Reorganize modules and cleanup. 21 | - Add array and to_array methods, and fix issues with view indexing and ZST layout. 
22 | - Support dynamic axis arguments and permutations, simplify Dyn and add axis indexing. 23 | - Improve conversion functions, see issue #6. 24 | - Add placeholder dimension in reshape methods, see #6. 25 | - Implement row, column and diagonal methods for generic shape. 26 | - Simplify shape and layout types for subarrays. 27 | - Minor fixes and added more checking. 28 | - Simplify module structure. 29 | - Remove extern type and PhantomPinned due to lack of noalias. 30 | - Add support for dynamic rank, see #4. 31 | - Remove slice methods and dereferencing for IntoExpr, and update documentation. 32 | - Rename array types, see proposal in #2. 33 | - Switch to row-major order, and simplify layout types to Dense and Strided. 34 | - Remove derivation of output type in eval(), to align with collect(). 35 | - Improved conversions to/from Array. 36 | - Added design notes. 37 | 38 | ## [0.6.1] - 2024-08-31 39 | 40 | - Added missing public items. 41 | 42 | ## [0.6.0] - 2024-08-03 43 | 44 | - Added array methods, and some further updates. 45 | - Added array type with inline storage. 46 | - Fix long compile times for release build. 47 | - Update to Rust 1.79, and avoid dependency to ATPIT. 48 | - Simplify apply and zip_with for arrays, and readd map method. 49 | - Introduced array shape for a list of dimensions, each having static or dynamic size. 50 | - Remove array trait, and merge array expression and view types. 51 | - Changed expression to a trait and some further cleanup. 52 | - Use associated type bounds and minor cleanup. 53 | - Return expressions instead of new arrays for operators, depends on rust-lang/rust#63063. 54 | 55 | ## [0.5.0] - 2023-12-02 56 | 57 | - Added iteration over all but one dimension, and methods for row, column and diagonal. 58 | - Added expressions for multidimensional iteration. 59 | - Added permutation of dimensions. 60 | - Create views from multiple arguments instead of a tuple, and changed ranges with negative step size. 
61 | - Refactoring and added spare_capacity_mut. 62 | - Simplify array layout types. 63 | - Add contains method and remove implicit borrowing. 64 | - Remove must_use annotations and enable unused_results warning. 65 | - Improve zero-dimensional array support and minor cleanup. 66 | - Add missing file. 67 | - Added macros for creating arrays and array views. 68 | - Use fixed element order. 69 | - Merge GridBase and SpanBase to common Array type. 70 | - Remove generic parameters for layout. 71 | - Make DenseSpan public. 72 | 73 | ## [0.4.0] - 2022-11-03 74 | 75 | - Change SpanBase from ZST slice to extern type. 76 | - Fix feature attributes for tests. 77 | - Add must_use attributes. 78 | - Improve interface and add debug, hashing and from_elem/from_elem_in. 79 | - Remove dependencies to nightly features. 80 | - Use const generics for the dimension in array view iterator and split methods. 81 | - Remove attribute to enable GAT. 82 | - Add span index trait, including unchecked get functions. 83 | - Move element order into dimension trait. 84 | - Avoid redundant format types for rank < 2, and simplify indexing. 85 | - Fix clippy warnings. 86 | - Move indexing into submodule. 87 | - Simplify layout methods. 88 | - Reorganize mapping file. 89 | - Replace &self with self for copy types. 90 | - Rename linear to flat format. 91 | - Replace methods not needing self with associated constants or functions. 92 | - Replace generic parameter with associated type. 93 | - Update documentation. 94 | - Add support for permissive provenance, and remove ptr_metadata feature. 95 | - Rename split functions and add split for any axis. 96 | - Added into_split_at and into_view for array views. 97 | - Improve functions for flattening, reformatting and reshaping. 98 | - Add indexing for shape and strides types, and ensure inner stride is maintained in reshape. 99 | - Further refactoring and added more methods/operators and serde support. 
100 | 101 | ## [0.3.0] - 2022-01-06 102 | 103 | - Major refactoring including type-level constants for rank, and dense/general/strided layout. 104 | - Add into_array/into_vec and AsMut for arrays. 105 | - Refactor static layout and add AsRef for arrays. 106 | - Added comparisons, conversions, debug and iterators. 107 | - Add type for aligned memory allocation. 108 | - Avoid separate types for subarrays, remove deref from ViewBase to slice and refactoring. 109 | - Fix generic resize. 110 | - Store layout inline for dense 1-dimensional views, and added missing asserts. 111 | - Minimum alignment based on target features. 112 | 113 | ## [0.2.0] - 2021-09-11 114 | 115 | - Renamed array types, added license files and updated version. 116 | - Refactor and added subarrays. 117 | 118 | ## [0.1.0] - 2021-08-24 119 | 120 | - Initial version. 121 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 4 4 | 5 | [[package]] 6 | name = "mdarray" 7 | version = "0.7.0" 8 | dependencies = [ 9 | "serde", 10 | "serde_test", 11 | ] 12 | 13 | [[package]] 14 | name = "proc-macro2" 15 | version = "1.0.93" 16 | source = "registry+https://github.com/rust-lang/crates.io-index" 17 | checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" 18 | dependencies = [ 19 | "unicode-ident", 20 | ] 21 | 22 | [[package]] 23 | name = "quote" 24 | version = "1.0.38" 25 | source = "registry+https://github.com/rust-lang/crates.io-index" 26 | checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" 27 | dependencies = [ 28 | "proc-macro2", 29 | ] 30 | 31 | [[package]] 32 | name = "serde" 33 | version = "1.0.218" 34 | source = "registry+https://github.com/rust-lang/crates.io-index" 35 | checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" 36 | dependencies = [ 37 | "serde_derive", 38 | ] 39 | 40 | [[package]] 41 | name = "serde_derive" 42 | version = "1.0.218" 43 | source = "registry+https://github.com/rust-lang/crates.io-index" 44 | checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" 45 | dependencies = [ 46 | "proc-macro2", 47 | "quote", 48 | "syn", 49 | ] 50 | 51 | [[package]] 52 | name = "serde_test" 53 | version = "1.0.177" 54 | source = "registry+https://github.com/rust-lang/crates.io-index" 55 | checksum = "7f901ee573cab6b3060453d2d5f0bae4e6d628c23c0a962ff9b5f1d7c8d4f1ed" 56 | dependencies = [ 57 | "serde", 58 | ] 59 | 60 | [[package]] 61 | name = "syn" 62 | version = "2.0.98" 63 | source = "registry+https://github.com/rust-lang/crates.io-index" 64 | checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" 65 | dependencies = [ 66 | "proc-macro2", 67 | "quote", 68 | "unicode-ident", 69 | ] 70 | 71 | [[package]] 72 | name = "unicode-ident" 73 | version = "1.0.17" 74 | source = "registry+https://github.com/rust-lang/crates.io-index" 75 | checksum = 
"00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" 76 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mdarray" 3 | version = "0.7.0" 4 | edition = "2024" 5 | rust-version = "1.85" 6 | description = "Multidimensional array for Rust" 7 | repository = "https://github.com/fre-hu/mdarray" 8 | license = "MIT OR Apache-2.0" 9 | keywords = ["array", "matrix", "multidimensional"] 10 | categories = ["data-structures", "mathematics", "science"] 11 | 12 | [dependencies] 13 | serde = { version = "1.0", optional = true } 14 | 15 | [dev-dependencies] 16 | serde_test = "1.0" 17 | 18 | [features] 19 | nightly = [] # Required for testing in Miri by using extern types, see: 20 | # https://github.com/rust-lang/unsafe-code-guidelines/issues/256 21 | -------------------------------------------------------------------------------- /DESIGN.md: -------------------------------------------------------------------------------- 1 | # Design notes 2 | 3 | Below are some design notes, to better understand how the library is working 4 | internally and choices that are made. It is a complement to the documentation 5 | in `lib.rs`. 6 | 7 | ## The `Slice` type for array references 8 | 9 | The starting point of the design is to mimic the `Vec`, `slice` and `array` 10 | types in Rust when possible. This includes having a similar interface and 11 | implementing the same traits etc. 12 | 13 | One difference is that for normal arrays in Rust, a single `slice` type is 14 | sufficient as an array reference when dereferencing from `Vec` and `array`. 15 | However for multi-dimensional arrays, a reference is larger than `slice` and 16 | does not fit in the two words in a fat pointer. 
There have been suggestions 17 | to have custom dynamically sized types (DSTs) that could be larger, but 18 | unfortunately this seems to be far away in the future. 19 | 20 | It is solved by having separate view types that can contain metadata, and that 21 | a reference is simply a pointer to the internal metadata structure `RawSlice`. 22 | The owned array type `Tensor` has the same metadata structure, which makes it 23 | possible to dereference from both the owned array and view types to a single 24 | reference type. 25 | 26 | The reference type is implemented as a zero sized type (ZST), to disallow any 27 | mutation of the metadata. Otherwise one could modify the internal state of 28 | arrays, creating undefined behavior. Internally there are type casts to the 29 | metadata structure to access its contents. 30 | 31 | ## Fixed sized arrays 32 | 33 | An array shape can be defined using a combination of static and/or dynamic 34 | dimensions. Static dimensions are included in the type and do not take up space 35 | in the metadata. This makes it possible to have fixed sized arrays without any 36 | metadata, except for a pointer to the array elements. One can then dereference 37 | to the `Slice` type also for fixed sized arrays that are allocated on the stack. 38 | 39 | When there is no metadata, a reference to `Slice` points to the array elements 40 | and not to the metadata structure. This is handled automatically depending on 41 | the size of the metadata. 42 | 43 | ## Array view and expression types 44 | 45 | There are two types for array views: `View` and `ViewMut`. These are created 46 | with the methods `expr` and `expr_mut` in `Slice`, and with other methods that 47 | give subarray views. 48 | 49 | In addition to being array views, these types are also used as iterator types. 50 | The normal iterator types cannot be used, since they do not contain information 51 | about multiple dimensions. 
This is an issue for example with the `map` and `zip` 52 | adaptors, since the result type is internal to Rust and cannot be extended. 53 | Furthermore, iteration over multiple dimensions with the `next` method is not 54 | efficient. 55 | 56 | A solution is to create a separate `Expression` trait in parallel to `Iterator`. 57 | The trait has methods similar to those of the iterator trait, and it is the combinators 58 | that are important. An expression can be encapsulated in the `Iter` type to get 59 | a regular iterator if needed. 60 | 61 | One observation is that expressions are similar to array views, and instead of 62 | having separate types they are merged. This both reduces complexity and avoids 63 | unnecessary type conversions. 64 | 65 | When iterating over an expression, the value is consumed so that one cannot 66 | have a partially evaluated expression. It is needed to be able to merge the 67 | expression and view types as above, and simplifies expression building. 68 | 69 | The `Expression` trait is not implemented for the `Array` and `Tensor` types. 70 | The reason is that it would give the wrong behavior, so that e.g. the result 71 | from the `map` method is an expression and not an array. One would then also 72 | expect the input array to be consumed, but that is not useful as the default. 73 | 74 | The `Expression` trait is also not implemented for `&Slice` and `&mut Slice`. 75 | While it could make sense and be convenient, it unfortunately deviates from 76 | how `Iterator` and `IntoIterator` are implemented for normal array types. 77 | 78 | ## Conversion to an expression 79 | 80 | The `IntoExpression` trait is implemented for owned arrays and array references, 81 | similar to `IntoIterator`. It makes it possible to automatically convert to an 82 | expression for example in function arguments. 83 | 84 | Additionally, there is a trait `Apply` that is implemented for the same types 85 | as `IntoExpression`. 
It acts as a combination of a conversion to an expression, 86 | applying a function and evaluating the result to an array. This is useful to 87 | implement unary and binary operators, where the result is an array if one of the 88 | arguments is an array as described in `lib.rs`. It makes it possible to reuse 89 | the same memory for heap allocated arrays. 90 | 91 | ## Comparison to C++ mdarray/mdspan 92 | 93 | The design borrows a lot from the new C++ mdarray and mdspan types. These are 94 | very well defined and give a standard to be followed. Some deviations are made 95 | to align with Rust naming and conventions. 96 | 97 | Below are the larger differences to C++ mdarray/mdspan: 98 | 99 | - There is no accessor policy for array views and references. The reason is to 100 | simplify and focus on the case when array elements are directly addressable. 101 | 102 | One use case of the accessor policy is to have custom element alignment e.g. 103 | to optimize for SIMD. However an alternative is to use `Simd` as element type. 104 | Another use case is to have scaling and/or conjugation of elements, but this 105 | is left for higher level libraries. 106 | 107 | - The owned array type is parameterized by an allocator instead of a container. 108 | The main reason is to be able to define the `RawSlice` structure internally 109 | and support dereferencing to `Slice`. 110 | 111 | - The fixed size array type is different from the generic array type with heap 112 | allocation. This is to align with Rust array types, and that the interface 113 | for the fixed size array type is quite different. With separate types the 114 | documentation becomes more clear. 115 | 116 | - Indexing is done with `usize` and is not parameterized. This follows how 117 | indexing is done in Rust, and could be extended if there is a need. 
118 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multidimensional array for Rust 2 | 3 | ## Overview 4 | 5 | The mdarray crate provides a multidimensional array for Rust. Its main target 6 | is numeric types; however, generic types are supported as well. The purpose 7 | is to provide a generic container type that is simple and flexible to use, 8 | with interworking to other crates for e.g. BLAS/LAPACK functionality. 9 | 10 | Here are the main features of mdarray: 11 | 12 | - Dense array type, where the rank is known at compile time. 13 | - Static or dynamic array dimensions, with optional stack allocation. 14 | - Standard Rust mechanisms are used for e.g. indexing and iteration. 15 | - Generic expressions for multidimensional iteration. 16 | 17 | The design is inspired by other Rust crates (ndarray, nalgebra, bitvec, dfdx 18 | and candle), the proposed C++ mdarray and mdspan types, and multidimensional 19 | arrays in other languages. 20 | 21 | ## License 22 | 23 | Licensed under either of 24 | 25 | * Apache License, Version 2.0 26 | ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 27 | * MIT license 28 | ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 29 | 30 | at your option. 31 | 32 | ## Contribution 33 | 34 | Unless you explicitly state otherwise, any contribution intentionally submitted 35 | for inclusion in the work by you, as defined in the Apache-2.0 license, shall be 36 | dual licensed as above, without any additional terms or conditions. 
37 | -------------------------------------------------------------------------------- /clippy.toml: -------------------------------------------------------------------------------- 1 | type-complexity-threshold = 1000 2 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | format_code_in_doc_comments = true 2 | use_small_heuristics = "Max" 3 | -------------------------------------------------------------------------------- /src/array.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::{Borrow, BorrowMut}; 2 | use std::fmt::{Debug, Formatter, Result}; 3 | use std::hash::{Hash, Hasher}; 4 | use std::mem::{self, ManuallyDrop, MaybeUninit}; 5 | use std::ops::{Deref, DerefMut, Index, IndexMut}; 6 | use std::ptr; 7 | 8 | use crate::dim::Const; 9 | use crate::expr::{self, IntoExpr, Iter, Map, Zip}; 10 | use crate::expr::{Apply, Expression, FromExpression, IntoExpression}; 11 | use crate::index::SliceIndex; 12 | use crate::layout::{Dense, Layout}; 13 | use crate::shape::{ConstShape, Shape}; 14 | use crate::slice::Slice; 15 | use crate::tensor::Tensor; 16 | use crate::traits::Owned; 17 | use crate::view::{View, ViewMut}; 18 | 19 | /// Multidimensional array with constant-sized dimensions and inline allocation. 20 | #[derive(Clone, Copy, Default)] 21 | #[repr(transparent)] 22 | pub struct Array(pub S::Inner); // NOTE(review): generic parameter lists look stripped in this listing (here and in the `impl` headers below) - confirm against the original file. 23 | 24 | impl Array { 25 | /// Creates an array from the given element. 26 | pub fn from_elem(elem: T) -> Self 27 | where 28 | T: Clone, 29 | { 30 | Self::from_expr(expr::from_elem(S::default(), elem)) 31 | } 32 | 33 | /// Creates an array with the results from the given function. 34 | pub fn from_fn T>(f: F) -> Self { 35 | Self::from_expr(expr::from_fn(S::default(), f)) 36 | } 37 | 38 | /// Converts an array with a single element into the contained value. 
39 | /// 40 | /// # Panics 41 | /// 42 | /// Panics if the array length is not equal to one. 43 | pub fn into_scalar(self) -> T { 44 | assert!(self.len() == 1, "invalid length"); 45 | 46 | self.into_shape::<()>().0 47 | } 48 | 49 | /// Converts the array into a reshaped array, which must have the same length. 50 | /// 51 | /// # Panics 52 | /// 53 | /// Panics if the array length is changed. 54 | pub fn into_shape(self) -> Array { 55 | assert!(I::default().len() == self.len(), "length must not change"); 56 | 57 | let me = ManuallyDrop::new(self); // suppress the destructor of the source, since its bytes are copied out below 58 | 59 | unsafe { mem::transmute_copy(&me) } // NOTE(review): presumably relies on shapes of equal length having identical storage layout - confirm. 60 | } 61 | 62 | /// Returns an array with the same shape, and the given closure applied to each element. 63 | pub fn map U>(self, f: F) -> Array { 64 | self.apply(f) 65 | } 66 | 67 | /// Creates an array with uninitialized elements. 68 | pub fn uninit() -> Array, S> { 69 | let array = >::uninit(); 70 | 71 | unsafe { mem::transmute_copy(&array) } // bitwise copy of the uninitialized storage into the returned array 72 | } 73 | 74 | /// Creates an array with elements set to zero. 75 | /// 76 | /// Zero elements are created using `Default::default()`. 77 | pub fn zeros() -> Self 78 | where 79 | T: Default, 80 | { 81 | let mut array = Self::uninit(); 82 | 83 | array.expr_mut().for_each(|x| { 84 | _ = x.write(T::default()); 85 | }); 86 | 87 | unsafe { array.assume_init() } // elements visited by `expr_mut()` were written in the loop above 88 | } 89 | 90 | fn from_expr>(expr: E) -> Self { 91 | struct DropGuard<'a, T, S: ConstShape> { 92 | array: &'a mut MaybeUninit>, 93 | index: usize, // number of elements written so far; if the guard is dropped, only this prefix is dropped 94 | } 95 | 96 | impl Drop for DropGuard<'_, T, S> { 97 | fn drop(&mut self) { 98 | let ptr = self.array.as_mut_ptr() as *mut T; 99 | 100 | unsafe { 101 | ptr::slice_from_raw_parts_mut(ptr, self.index).drop_in_place(); 102 | } 103 | } 104 | } 105 | 106 | // Ensure that the shape is valid. 
107 | _ = expr.shape().with_dims(|dims| S::from_dims(dims)); 108 | 109 | let mut array = MaybeUninit::uninit(); 110 | let mut guard = DropGuard { array: &mut array, index: 0 }; 111 | 112 | let ptr = guard.array.as_mut_ptr() as *mut E::Item; 113 | 114 | expr.for_each(|x| unsafe { 115 | ptr.add(guard.index).write(x); 116 | guard.index += 1; 117 | }); 118 | 119 | mem::forget(guard); 120 | 121 | unsafe { array.assume_init() } 122 | } 123 | } 124 | 125 | impl Array, S> { 126 | /// Converts the array element type from `MaybeUninit` to `T`. 127 | /// 128 | /// # Safety 129 | /// 130 | /// All elements in the array must be initialized, or the behavior is undefined. 131 | pub unsafe fn assume_init(self) -> Array { 132 | unsafe { mem::transmute_copy(&self) } // sound only under the caller contract stated in `# Safety` above 133 | } 134 | } 135 | 136 | impl<'a, T, U, S: ConstShape> Apply for &'a Array { 137 | type Output U> = Map; 138 | type ZippedWith U> = 139 | Map, F>; 140 | 141 | fn apply U>(self, f: F) -> Self::Output { 142 | self.expr().map(f) 143 | } 144 | 145 | fn zip_with(self, expr: I, f: F) -> Self::ZippedWith 146 | where 147 | F: FnMut((&'a T, I::Item)) -> U, 148 | { 149 | self.expr().zip(expr).map(f) 150 | } 151 | } 152 | 153 | impl<'a, T, U, S: ConstShape> Apply for &'a mut Array { 154 | type Output U> = Map; 155 | type ZippedWith U> = 156 | Map, F>; 157 | 158 | fn apply U>(self, f: F) -> Self::Output { 159 | self.expr_mut().map(f) 160 | } 161 | 162 | fn zip_with(self, expr: I, f: F) -> Self::ZippedWith 163 | where 164 | F: FnMut((&'a mut T, I::Item)) -> U, 165 | { 166 | self.expr_mut().zip(expr).map(f) 167 | } 168 | } 169 | 170 | impl Apply for Array { 171 | type Output U> = Array; 172 | type ZippedWith U> = Array; 173 | 174 | fn apply U>(self, f: F) -> Array { 175 | Array::from_expr(self.into_expr().map(f)) // owned impl builds a new `Array`; the reference impls above return `Map` expressions instead 176 | } 177 | 178 | fn zip_with(self, expr: I, f: F) -> Array 179 | where 180 | F: FnMut((T, I::Item)) -> U, 181 | { 182 | Array::from_expr(self.into_expr().zip(expr).map(f)) 183 | } 184 | } 185 | 186 | impl AsMut for
Array 187 | where 188 | Slice: AsMut, 189 | { 190 | fn as_mut(&mut self) -> &mut U { 191 | (**self).as_mut() 192 | } 193 | } 194 | 195 | impl AsRef for Array 196 | where 197 | Slice: AsRef, 198 | { 199 | fn as_ref(&self) -> &U { 200 | (**self).as_ref() 201 | } 202 | } 203 | 204 | macro_rules! impl_as_mut_ref { 205 | (($($xyz:tt),+), $array:tt) => { 206 | impl AsMut,)+)>> for $array { 207 | fn as_mut(&mut self) -> &mut Array,)+)> { 208 | unsafe { &mut *(self as *mut Self as *mut Array,)+)>) } 209 | } 210 | } 211 | 212 | impl AsRef,)+)>> for $array { 213 | fn as_ref(&self) -> &Array,)+)> { 214 | unsafe { &*(self as *const Self as *const Array,)+)>) } 215 | } 216 | } 217 | }; 218 | } 219 | 220 | impl_as_mut_ref!((X), [T; X]); 221 | impl_as_mut_ref!((X, Y), [[T; Y]; X]); 222 | impl_as_mut_ref!((X, Y, Z), [[[T; Z]; Y]; X]); 223 | impl_as_mut_ref!((X, Y, Z, W), [[[[T; W]; Z]; Y]; X]); 224 | impl_as_mut_ref!((X, Y, Z, W, U), [[[[[T; U]; W]; Z]; Y]; X]); 225 | impl_as_mut_ref!((X, Y, Z, W, U, V), [[[[[[T; V]; U]; W]; Z]; Y]; X]); 226 | 227 | impl Borrow> for Array { 228 | fn borrow(&self) -> &Slice { 229 | self 230 | } 231 | } 232 | 233 | impl BorrowMut> for Array { 234 | fn borrow_mut(&mut self) -> &mut Slice { 235 | self 236 | } 237 | } 238 | 239 | impl Debug for Array { 240 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 241 | (**self).fmt(f) 242 | } 243 | } 244 | 245 | impl Deref for Array { 246 | type Target = Slice; 247 | 248 | fn deref(&self) -> &Self::Target { 249 | _ = S::default().checked_len().expect("invalid length"); 250 | 251 | unsafe { &*(self as *const Self as *const Slice) } 252 | } 253 | } 254 | 255 | impl DerefMut for Array { 256 | fn deref_mut(&mut self) -> &mut Self::Target { 257 | _ = S::default().checked_len().expect("invalid length"); 258 | 259 | unsafe { &mut *(self as *mut Self as *mut Slice) } 260 | } 261 | } 262 | 263 | impl From> for Array { 264 | fn from(value: Tensor) -> Self { 265 | Self::from_expr(value.into_expr()) 266 | } 267 | } 268 | 
269 | impl<'a, T: 'a + Clone, S: ConstShape, L: Layout, I> From for Array 270 | where 271 | I: IntoExpression>, 272 | { 273 | fn from(value: I) -> Self { 274 | Self::from_expr(value.into_expr().cloned()) 275 | } 276 | } 277 | 278 | macro_rules! impl_from_array { 279 | (($($xyz:tt),+), $array:tt) => { 280 | impl From<&$array> for Array,)+)> { 281 | fn from(array: &$array) -> Self { 282 | Self(array.clone()) 283 | } 284 | } 285 | 286 | impl From,)+)>> for $array { 287 | fn from(array: Array,)+)>) -> Self { 288 | array.0 289 | } 290 | } 291 | 292 | impl From<$array> for Array,)+)> { 293 | fn from(array: $array) -> Self { 294 | Self(array) 295 | } 296 | } 297 | }; 298 | } 299 | 300 | impl_from_array!((X), [T; X]); 301 | impl_from_array!((X, Y), [[T; Y]; X]); 302 | impl_from_array!((X, Y, Z), [[[T; Z]; Y]; X]); 303 | impl_from_array!((X, Y, Z, W), [[[[T; W]; Z]; Y]; X]); 304 | impl_from_array!((X, Y, Z, W, U), [[[[[T; U]; W]; Z]; Y]; X]); 305 | impl_from_array!((X, Y, Z, W, U, V), [[[[[[T; V]; U]; W]; Z]; Y]; X]); 306 | 307 | impl FromExpression for Array { 308 | fn from_expr>(expr: I) -> Self { 309 | Self::from_expr(expr.into_expr()) 310 | } 311 | } 312 | 313 | impl Hash for Array { 314 | fn hash(&self, state: &mut H) { 315 | (**self).hash(state) 316 | } 317 | } 318 | 319 | impl> Index for Array { 320 | type Output = I::Output; 321 | 322 | fn index(&self, index: I) -> &I::Output { 323 | index.index(self) 324 | } 325 | } 326 | 327 | impl> IndexMut for Array { 328 | fn index_mut(&mut self, index: I) -> &mut I::Output { 329 | index.index_mut(self) 330 | } 331 | } 332 | 333 | impl<'a, T, S: ConstShape> IntoExpression for &'a Array { 334 | type Shape = S; 335 | type IntoExpr = View<'a, T, S>; 336 | 337 | fn into_expr(self) -> Self::IntoExpr { 338 | self.expr() 339 | } 340 | } 341 | 342 | impl<'a, T, S: ConstShape> IntoExpression for &'a mut Array { 343 | type Shape = S; 344 | type IntoExpr = ViewMut<'a, T, S>; 345 | 346 | fn into_expr(self) -> Self::IntoExpr { 347 | 
self.expr_mut() 348 | } 349 | } 350 | 351 | impl IntoExpression for Array { 352 | type Shape = S; 353 | type IntoExpr = IntoExpr, S>>; 354 | 355 | fn into_expr(self) -> Self::IntoExpr { 356 | _ = S::default().checked_len().expect("invalid length"); 357 | 358 | let me = ManuallyDrop::new(self); 359 | 360 | unsafe { IntoExpr::new(mem::transmute_copy(&me)) } 361 | } 362 | } 363 | 364 | impl<'a, T, S: ConstShape> IntoIterator for &'a Array { 365 | type Item = &'a T; 366 | type IntoIter = Iter>; 367 | 368 | fn into_iter(self) -> Self::IntoIter { 369 | self.iter() 370 | } 371 | } 372 | 373 | impl<'a, T, S: ConstShape> IntoIterator for &'a mut Array { 374 | type Item = &'a mut T; 375 | type IntoIter = Iter>; 376 | 377 | fn into_iter(self) -> Self::IntoIter { 378 | self.iter_mut() 379 | } 380 | } 381 | 382 | impl IntoIterator for Array { 383 | type Item = T; 384 | type IntoIter = Iter, S>>>; 385 | 386 | fn into_iter(self) -> Self::IntoIter { 387 | self.into_expr().into_iter() 388 | } 389 | } 390 | 391 | impl Owned for Array { 392 | type WithConst = S::WithConst; 393 | 394 | fn clone_from_slice(&mut self, slice: &Slice) 395 | where 396 | T: Clone, 397 | { 398 | self.assign(slice); 399 | } 400 | } 401 | -------------------------------------------------------------------------------- /src/dim.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{self, Debug, Formatter}; 2 | use std::hash::Hash; 3 | 4 | use crate::shape::Shape; 5 | use crate::tensor::Tensor; 6 | use crate::traits::Owned; 7 | 8 | /// Array dimension trait. 9 | pub trait Dim: Copy + Debug + Default + Hash + Ord + Send + Sync { 10 | /// Merge dimensions, where constant size is preferred over dynamic. 11 | type Merge: Dim; 12 | 13 | #[doc(hidden)] 14 | type Owned: Owned>; 15 | 16 | /// Dimension size if known statically, or `None` if dynamic. 17 | const SIZE: Option; 18 | 19 | /// Creates an array dimension with the given size. 
20 | /// 21 | /// # Panics 22 | /// 23 | /// Panics if the size is not matching a constant-sized dimension. 24 | fn from_size(size: usize) -> Self; 25 | 26 | /// Returns the number of elements in the dimension. 27 | fn size(self) -> usize; 28 | } 29 | 30 | #[allow(unreachable_pub)] 31 | pub trait Dims: 32 | AsMut<[T]> 33 | + AsRef<[T]> 34 | + Clone 35 | + Debug 36 | + Default 37 | + Eq 38 | + Hash 39 | + Send 40 | + Sync 41 | + for<'a> TryFrom<&'a [T], Error: Debug> 42 | { 43 | fn new(len: usize) -> Self; 44 | } 45 | 46 | /// Type-level constant. 47 | #[derive(Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] 48 | pub struct Const; 49 | 50 | /// Dynamically-sized dimension type. 51 | pub type Dyn = usize; 52 | 53 | impl Debug for Const { 54 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 55 | f.debug_tuple("Const").field(&N).finish() 56 | } 57 | } 58 | 59 | impl Dim for Const { 60 | type Merge = Self; 61 | type Owned = as Owned>::WithConst; 62 | 63 | const SIZE: Option = Some(N); 64 | 65 | fn from_size(size: usize) -> Self { 66 | assert!(size == N, "invalid size"); 67 | 68 | Self 69 | } 70 | 71 | fn size(self) -> usize { 72 | N 73 | } 74 | } 75 | 76 | impl Dim for Dyn { 77 | type Merge = D; 78 | type Owned = Tensor>; 79 | 80 | const SIZE: Option = None; 81 | 82 | fn from_size(size: usize) -> Self { 83 | size 84 | } 85 | 86 | fn size(self) -> usize { 87 | self 88 | } 89 | } 90 | 91 | macro_rules! 
impl_dims { 92 | ($($n:tt),+) => { 93 | $( 94 | impl Dims for [T; $n] { 95 | fn new(len: usize) -> Self { 96 | assert!(len == $n, "invalid length"); 97 | 98 | Self::default() 99 | } 100 | } 101 | )+ 102 | }; 103 | } 104 | 105 | impl_dims!(0, 1, 2, 3, 4, 5, 6); 106 | 107 | impl Dims for Box<[T]> { 108 | fn new(len: usize) -> Self { 109 | vec![T::default(); len].into() 110 | } 111 | } 112 | 113 | impl From> for Dyn { 114 | fn from(_: Const) -> Self { 115 | N 116 | } 117 | } 118 | 119 | impl TryFrom for Const { 120 | type Error = Dyn; 121 | 122 | fn try_from(value: Dyn) -> Result { 123 | if value.size() == N { Ok(Self) } else { Err(value) } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/expr/adapters.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Formatter, Result}; 2 | 3 | use crate::expr::expression::{Expression, IntoExpression}; 4 | use crate::expr::iter::Iter; 5 | use crate::shape::Shape; 6 | 7 | /// Expression that clones the elements of an underlying expression. 8 | #[derive(Clone, Debug)] 9 | pub struct Cloned { 10 | expr: E, 11 | } 12 | 13 | /// Expression that copies the elements of an underlying expression. 14 | #[derive(Clone, Debug)] 15 | pub struct Copied { 16 | expr: E, 17 | } 18 | 19 | /// Expression that gives the current count and the element during iteration. 20 | #[derive(Clone)] 21 | pub struct Enumerate { 22 | expr: E, 23 | count: usize, 24 | } 25 | 26 | /// Expression that calls a closure on each element. 27 | #[derive(Clone)] 28 | pub struct Map { 29 | expr: E, 30 | f: F, 31 | } 32 | 33 | /// Expression that gives tuples `(x, y)` of the elements from each expression. 34 | #[derive(Clone)] 35 | pub struct Zip { 36 | a: A, 37 | b: B, 38 | shape: ::Shape, 39 | } 40 | 41 | /// Creates an expression that clones the elements of the argument. 
42 | /// 43 | /// # Examples 44 | /// 45 | /// ``` 46 | /// use mdarray::{expr, expr::Expression, view}; 47 | /// 48 | /// let v = view![0, 1, 2]; 49 | /// 50 | /// assert_eq!(expr::cloned(v).eval(), v); 51 | /// ``` 52 | pub fn cloned<'a, T: 'a + Clone, I: IntoExpression>(expr: I) -> Cloned { 53 | expr.into_expr().cloned() 54 | } 55 | 56 | /// Creates an expression that copies the elements of the argument. 57 | /// 58 | /// # Examples 59 | /// 60 | /// ``` 61 | /// use mdarray::{expr, expr::Expression, view}; 62 | /// 63 | /// let v = view![0, 1, 2]; 64 | /// 65 | /// assert_eq!(expr::copied(v).eval(), v); 66 | /// ``` 67 | pub fn copied<'a, T: 'a + Copy, I: IntoExpression>(expr: I) -> Copied { 68 | expr.into_expr().copied() 69 | } 70 | 71 | /// Creates an expression that enumerates the elements of the argument. 72 | /// 73 | /// # Examples 74 | /// 75 | /// ``` 76 | /// use mdarray::{expr, expr::Expression, tensor, view}; 77 | /// 78 | /// let t = tensor![3, 4, 5]; 79 | /// 80 | /// assert_eq!(expr::enumerate(t).eval(), view![(0, 3), (1, 4), (2, 5)]); 81 | /// ``` 82 | pub fn enumerate(expr: I) -> Enumerate { 83 | expr.into_expr().enumerate() 84 | } 85 | 86 | /// Creates an expression that calls a closure on each element of the argument. 87 | /// 88 | /// # Examples 89 | /// 90 | /// ``` 91 | /// use mdarray::{expr, expr::Expression, view}; 92 | /// 93 | /// let v = view![0, 1, 2]; 94 | /// 95 | /// assert_eq!(expr::map(v, |x| 2 * x).eval(), view![0, 2, 4]); 96 | /// ``` 97 | pub fn map T>(expr: I, f: F) -> Map { 98 | expr.into_expr().map(f) 99 | } 100 | 101 | /// Converts the arguments to expressions and zips them. 102 | /// 103 | /// # Panics 104 | /// 105 | /// Panics if the expressions cannot be broadcast to a common shape. 
106 | /// 107 | /// # Examples 108 | /// 109 | /// ``` 110 | /// use mdarray::{expr, expr::Expression, tensor, view}; 111 | /// 112 | /// let a = tensor![0, 1, 2]; 113 | /// let b = tensor![3, 4, 5]; 114 | /// 115 | /// assert_eq!(expr::zip(a, b).eval(), view![(0, 3), (1, 4), (2, 5)]); 116 | /// ``` 117 | pub fn zip(a: A, b: B) -> Zip { 118 | a.into_expr().zip(b) 119 | } 120 | 121 | impl Cloned { 122 | pub(crate) fn new(expr: E) -> Self { 123 | Self { expr } 124 | } 125 | } 126 | 127 | impl<'a, T: 'a + Clone, E: Expression> Expression for Cloned { 128 | type Shape = E::Shape; 129 | 130 | const IS_REPEATABLE: bool = E::IS_REPEATABLE; 131 | 132 | fn shape(&self) -> &E::Shape { 133 | self.expr.shape() 134 | } 135 | 136 | unsafe fn get_unchecked(&mut self, index: usize) -> T { 137 | unsafe { self.expr.get_unchecked(index).clone() } 138 | } 139 | 140 | fn inner_rank(&self) -> usize { 141 | self.expr.inner_rank() 142 | } 143 | 144 | unsafe fn reset_dim(&mut self, index: usize, count: usize) { 145 | unsafe { 146 | self.expr.reset_dim(index, count); 147 | } 148 | } 149 | 150 | unsafe fn step_dim(&mut self, index: usize) { 151 | unsafe { 152 | self.expr.step_dim(index); 153 | } 154 | } 155 | } 156 | 157 | impl<'a, T: 'a + Clone, E: Expression> IntoIterator for Cloned { 158 | type Item = T; 159 | type IntoIter = Iter; 160 | 161 | fn into_iter(self) -> Self::IntoIter { 162 | Iter::new(self) 163 | } 164 | } 165 | 166 | impl Copied { 167 | pub(crate) fn new(expr: E) -> Self { 168 | Self { expr } 169 | } 170 | } 171 | 172 | impl<'a, T: 'a + Copy, E: Expression> Expression for Copied { 173 | type Shape = E::Shape; 174 | 175 | const IS_REPEATABLE: bool = E::IS_REPEATABLE; 176 | 177 | fn shape(&self) -> &E::Shape { 178 | self.expr.shape() 179 | } 180 | 181 | unsafe fn get_unchecked(&mut self, index: usize) -> T { 182 | unsafe { *self.expr.get_unchecked(index) } 183 | } 184 | 185 | fn inner_rank(&self) -> usize { 186 | self.expr.inner_rank() 187 | } 188 | 189 | unsafe fn 
reset_dim(&mut self, index: usize, count: usize) { 190 | unsafe { 191 | self.expr.reset_dim(index, count); 192 | } 193 | } 194 | 195 | unsafe fn step_dim(&mut self, index: usize) { 196 | unsafe { 197 | self.expr.step_dim(index); 198 | } 199 | } 200 | } 201 | 202 | impl<'a, T: 'a + Copy, E: Expression> IntoIterator for Copied { 203 | type Item = T; 204 | type IntoIter = Iter; 205 | 206 | fn into_iter(self) -> Self::IntoIter { 207 | Iter::new(self) 208 | } 209 | } 210 | 211 | impl Enumerate { 212 | pub(crate) fn new(expr: E) -> Self { 213 | Self { expr, count: 0 } 214 | } 215 | } 216 | 217 | impl Debug for Enumerate { 218 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 219 | f.debug_struct("Enumerate").field("expr", &self.expr).finish() 220 | } 221 | } 222 | 223 | impl Expression for Enumerate { 224 | type Shape = E::Shape; 225 | 226 | const IS_REPEATABLE: bool = E::IS_REPEATABLE; 227 | 228 | fn shape(&self) -> &E::Shape { 229 | self.expr.shape() 230 | } 231 | 232 | unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item { 233 | self.count += 1; 234 | 235 | unsafe { (self.count - 1, self.expr.get_unchecked(index)) } 236 | } 237 | 238 | fn inner_rank(&self) -> usize { 239 | self.expr.inner_rank() 240 | } 241 | 242 | unsafe fn reset_dim(&mut self, index: usize, count: usize) { 243 | unsafe { 244 | self.expr.reset_dim(index, count); 245 | } 246 | } 247 | 248 | unsafe fn step_dim(&mut self, index: usize) { 249 | unsafe { 250 | self.expr.step_dim(index); 251 | } 252 | } 253 | } 254 | 255 | impl IntoIterator for Enumerate { 256 | type Item = (usize, E::Item); 257 | type IntoIter = Iter; 258 | 259 | fn into_iter(self) -> Self::IntoIter { 260 | Iter::new(self) 261 | } 262 | } 263 | 264 | impl Map { 265 | pub(crate) fn new(expr: E, f: F) -> Self { 266 | Self { expr, f } 267 | } 268 | } 269 | 270 | impl Debug for Map { 271 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 272 | f.debug_struct("Map").field("expr", &self.expr).finish() 273 | } 274 | } 275 | 276 | impl 
T> Expression for Map { 277 | type Shape = E::Shape; 278 | 279 | const IS_REPEATABLE: bool = E::IS_REPEATABLE; 280 | 281 | fn shape(&self) -> &E::Shape { 282 | self.expr.shape() 283 | } 284 | 285 | unsafe fn get_unchecked(&mut self, index: usize) -> T { 286 | unsafe { (self.f)(self.expr.get_unchecked(index)) } 287 | } 288 | 289 | fn inner_rank(&self) -> usize { 290 | self.expr.inner_rank() 291 | } 292 | 293 | unsafe fn reset_dim(&mut self, index: usize, count: usize) { 294 | unsafe { 295 | self.expr.reset_dim(index, count); 296 | } 297 | } 298 | 299 | unsafe fn step_dim(&mut self, index: usize) { 300 | unsafe { 301 | self.expr.step_dim(index); 302 | } 303 | } 304 | } 305 | 306 | impl T> IntoIterator for Map { 307 | type Item = T; 308 | type IntoIter = Iter; 309 | 310 | fn into_iter(self) -> Self::IntoIter { 311 | Iter::new(self) 312 | } 313 | } 314 | 315 | impl Zip { 316 | pub(crate) fn new(a: A, b: B) -> Self { 317 | assert!(A::IS_REPEATABLE || a.rank() >= b.rank(), "expression not repeatable"); 318 | assert!(B::IS_REPEATABLE || b.rank() >= a.rank(), "expression not repeatable"); 319 | 320 | let shape = a.shape().with_dims(|a_dims| { 321 | b.shape().with_dims(|b_dims| { 322 | let dims = if a_dims.len() < b_dims.len() { b_dims } else { a_dims }; 323 | let inner_match = 324 | a_dims[dims.len() - b_dims.len()..] 
== b_dims[dims.len() - a_dims.len()..]; 325 | 326 | assert!(inner_match, "inner dimensions mismatch"); 327 | 328 | Shape::from_dims(dims) 329 | }) 330 | }); 331 | 332 | Self { a, b, shape } 333 | } 334 | } 335 | 336 | impl Debug for Zip { 337 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 338 | f.debug_struct("Zip").field("a", &self.a).field("b", &self.b).finish() 339 | } 340 | } 341 | 342 | impl Expression for Zip 343 | where 344 | A: Expression, 345 | B: Expression, 346 | { 347 | type Shape = <::Merge as Shape>::Reverse; 348 | 349 | const IS_REPEATABLE: bool = A::IS_REPEATABLE && B::IS_REPEATABLE; 350 | 351 | fn shape(&self) -> &Self::Shape { 352 | &self.shape 353 | } 354 | 355 | unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item { 356 | unsafe { (self.a.get_unchecked(index), self.b.get_unchecked(index)) } 357 | } 358 | 359 | fn inner_rank(&self) -> usize { 360 | self.a.inner_rank().min(self.b.inner_rank()) 361 | } 362 | 363 | unsafe fn reset_dim(&mut self, index: usize, count: usize) { 364 | let delta = self.shape.rank() - index; 365 | 366 | unsafe { 367 | if delta <= self.a.rank() { 368 | self.a.reset_dim(self.a.rank() - delta, count); 369 | } 370 | 371 | if delta <= self.b.rank() { 372 | self.b.reset_dim(self.b.rank() - delta, count); 373 | } 374 | } 375 | } 376 | 377 | unsafe fn step_dim(&mut self, index: usize) { 378 | let delta = self.shape.rank() - index; 379 | 380 | unsafe { 381 | if delta <= self.a.rank() { 382 | self.a.step_dim(self.a.rank() - delta); 383 | } 384 | 385 | if delta <= self.b.rank() { 386 | self.b.step_dim(self.b.rank() - delta); 387 | } 388 | } 389 | } 390 | } 391 | 392 | impl IntoIterator for Zip { 393 | type Item = (A::Item, B::Item); 394 | type IntoIter = Iter; 395 | 396 | fn into_iter(self) -> Self::IntoIter { 397 | Iter::new(self) 398 | } 399 | } 400 | -------------------------------------------------------------------------------- /src/expr/buffer.rs: 
-------------------------------------------------------------------------------- 1 | #[cfg(feature = "nightly")] 2 | use std::alloc::{Allocator, Global}; 3 | use std::mem::ManuallyDrop; 4 | use std::ptr; 5 | 6 | #[cfg(not(feature = "nightly"))] 7 | use crate::alloc::{Allocator, Global}; 8 | use crate::array::Array; 9 | use crate::dim::Const; 10 | use crate::index::Axis; 11 | use crate::mapping::Mapping; 12 | use crate::shape::{ConstShape, Shape}; 13 | use crate::slice::Slice; 14 | use crate::tensor::Tensor; 15 | use crate::view::ViewMut; 16 | 17 | /// Array buffer trait, for moving elements out of an array. 18 | pub trait Buffer { 19 | /// Array element type. 20 | type Item; 21 | 22 | /// Array shape type. 23 | type Shape: Shape; 24 | 25 | #[doc(hidden)] 26 | fn as_mut_slice(&mut self) -> &mut Slice, Self::Shape>; 27 | 28 | #[doc(hidden)] 29 | fn as_slice(&self) -> &Slice, Self::Shape>; 30 | } 31 | 32 | /// Buffer for moving elements out of an array range. 33 | pub struct Drain<'a, T, S: Shape, A: Allocator = Global> { 34 | tensor: &'a mut Tensor, 35 | view: ViewMut<'a, ManuallyDrop, S>, 36 | new_size: usize, 37 | tail: usize, 38 | } 39 | 40 | impl<'a, T, S: Shape, A: Allocator> Drain<'a, T, S, A> { 41 | pub(crate) fn new(tensor: &'a mut Tensor, start: usize, end: usize) -> Self { 42 | assert!(start <= end && end <= tensor.dim(0), "invalid range"); 43 | 44 | let new_size = tensor.dim(0) - (end - start); 45 | let tail = Axis::resize(Const::<0>, tensor.mapping(), new_size - start).len(); 46 | 47 | // Shrink the array, to be safe in case Drain is leaked. 
48 | unsafe { 49 | tensor.set_mapping(Mapping::resize_dim(tensor.mapping(), 0, start)); 50 | } 51 | 52 | let ptr = unsafe { tensor.as_mut_ptr().add(tensor.len()) as *mut ManuallyDrop }; 53 | let mapping = Mapping::resize_dim(tensor.mapping(), 0, end - start); 54 | 55 | let view = unsafe { ViewMut::new_unchecked(ptr, mapping) }; 56 | 57 | Self { tensor, view, new_size, tail } 58 | } 59 | } 60 | 61 | impl Buffer for Drain<'_, T, S, A> { 62 | type Item = T; 63 | type Shape = S; 64 | 65 | fn as_mut_slice(&mut self) -> &mut Slice, S> { 66 | &mut self.view 67 | } 68 | 69 | fn as_slice(&self) -> &Slice, S> { 70 | &self.view 71 | } 72 | } 73 | 74 | impl Drop for Drain<'_, T, S, A> { 75 | fn drop(&mut self) { 76 | let mapping = Mapping::resize_dim(self.tensor.mapping(), 0, self.new_size); 77 | 78 | unsafe { 79 | ptr::copy(self.view.as_ptr().add(self.view.len()), self.view.as_mut_ptr(), self.tail); 80 | self.tensor.set_mapping(mapping); 81 | } 82 | } 83 | } 84 | 85 | impl Buffer for Array, S> { 86 | type Item = T; 87 | type Shape = S; 88 | 89 | fn as_mut_slice(&mut self) -> &mut Slice, S> { 90 | self 91 | } 92 | 93 | fn as_slice(&self) -> &Slice, S> { 94 | self 95 | } 96 | } 97 | 98 | impl Buffer for Tensor, S, A> { 99 | type Item = T; 100 | type Shape = S; 101 | 102 | fn as_mut_slice(&mut self) -> &mut Slice, S> { 103 | self 104 | } 105 | 106 | fn as_slice(&self) -> &Slice, S> { 107 | self 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/expr/expression.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "nightly")] 2 | use std::alloc::Allocator; 3 | 4 | #[cfg(not(feature = "nightly"))] 5 | use crate::alloc::Allocator; 6 | use crate::expr::adapters::{Cloned, Copied, Enumerate, Map, Zip}; 7 | use crate::expr::iter::Iter; 8 | use crate::shape::Shape; 9 | use crate::tensor::Tensor; 10 | use crate::traits::IntoCloned; 11 | 12 | /// Trait for applying a closure and 
returning an existing array or an expression. 13 | pub trait Apply: IntoExpression { 14 | /// The resulting type after applying a closure. 15 | type Output T>: IntoExpression; 16 | 17 | /// The resulting type after zipping elements and applying a closure. 18 | type ZippedWith: IntoExpression 19 | where 20 | F: FnMut((Self::Item, I::Item)) -> T; 21 | 22 | /// Returns the array or an expression with the given closure applied to each element. 23 | fn apply T>(self, f: F) -> Self::Output; 24 | 25 | /// Returns the array or an expression with the given closure applied to zipped element pairs. 26 | fn zip_with(self, expr: I, f: F) -> Self::ZippedWith 27 | where 28 | F: FnMut((Self::Item, I::Item)) -> T; 29 | } 30 | 31 | /// Expression trait, for multidimensional iteration. 32 | pub trait Expression: IntoIterator { 33 | /// Array shape type. 34 | type Shape: Shape; 35 | 36 | /// True if the expression can be restarted from the beginning after the last element. 37 | const IS_REPEATABLE: bool; 38 | 39 | /// Returns the array shape. 40 | fn shape(&self) -> &Self::Shape; 41 | 42 | /// Creates an expression which clones all of its elements. 43 | fn cloned<'a, T: 'a + Clone>(self) -> Cloned 44 | where 45 | Self: Expression + Sized, 46 | { 47 | Cloned::new(self) 48 | } 49 | 50 | /// Creates an expression which copies all of its elements. 51 | fn copied<'a, T: 'a + Copy>(self) -> Copied 52 | where 53 | Self: Expression + Sized, 54 | { 55 | Copied::new(self) 56 | } 57 | 58 | /// Returns the number of elements in the specified dimension. 59 | /// 60 | /// # Panics 61 | /// 62 | /// Panics if the dimension is out of bounds. 63 | fn dim(&self, index: usize) -> usize { 64 | self.shape().dim(index) 65 | } 66 | 67 | /// Creates an expression which gives tuples of the current count and the element. 
68 | fn enumerate(self) -> Enumerate 69 | where 70 | Self: Sized, 71 | { 72 | Enumerate::new(self) 73 | } 74 | 75 | /// Determines if the elements of the expression are equal to those of another. 76 | fn eq(self, other: I) -> bool 77 | where 78 | Self: Expression> + Sized, 79 | { 80 | self.eq_by(other, |x, y| x == y) 81 | } 82 | 83 | /// Determines if the elements of the expression are equal to those of another 84 | /// with respect to the specified equality function. 85 | fn eq_by(self, other: I, mut eq: F) -> bool 86 | where 87 | Self: Sized, 88 | F: FnMut(Self::Item, I::Item) -> bool, 89 | { 90 | let other = other.into_expr(); 91 | 92 | self.shape().with_dims(|dims| other.shape().with_dims(|other| dims == other)) 93 | && self.zip(other).into_iter().all(|(x, y)| eq(x, y)) 94 | } 95 | 96 | /// Evaluates the expression into a new array. 97 | /// 98 | /// The resulting type is `Array` if the shape has constant-sized dimensions, or 99 | /// otherwise `Tensor`. If the shape type is generic, `FromExpression::from_expr` 100 | /// can be used to evaluate the expression into a specific array type. 101 | fn eval(self) -> ::Owned 102 | where 103 | Self: Sized, 104 | { 105 | FromExpression::from_expr(self) 106 | } 107 | 108 | /// Evaluates the expression with broadcasting and appends to the given array 109 | /// along the first dimension. 110 | /// 111 | /// If the array is empty, it is reshaped to match the shape of the expression. 112 | /// 113 | /// # Panics 114 | /// 115 | /// Panics if the inner dimensions do not match, if the rank is not the same and 116 | /// at least 1, or if the first dimension is not dynamically-sized. 117 | fn eval_into( 118 | self, 119 | tensor: &mut Tensor, 120 | ) -> &mut Tensor 121 | where 122 | Self: Sized, 123 | { 124 | tensor.expand(self); 125 | tensor 126 | } 127 | 128 | /// Folds all elements into an accumulator by applying an operation, and returns the result. 
129 | fn fold T>(self, init: T, f: F) -> T 130 | where 131 | Self: Sized, 132 | { 133 | Iter::new(self).fold(init, f) 134 | } 135 | 136 | /// Calls a closure on each element of the expression. 137 | fn for_each(self, mut f: F) 138 | where 139 | Self: Sized, 140 | { 141 | self.fold((), |(), x| f(x)); 142 | } 143 | 144 | /// Returns `true` if the array contains no elements. 145 | fn is_empty(&self) -> bool { 146 | self.shape().is_empty() 147 | } 148 | 149 | /// Returns the number of elements in the array. 150 | fn len(&self) -> usize { 151 | self.shape().len() 152 | } 153 | 154 | /// Creates an expression that calls a closure on each element. 155 | fn map T>(self, f: F) -> Map 156 | where 157 | Self: Sized, 158 | { 159 | Map::new(self, f) 160 | } 161 | 162 | /// Determines if the elements of the expression are not equal to those of another. 163 | fn ne(self, other: I) -> bool 164 | where 165 | Self: Expression> + Sized, 166 | { 167 | !self.eq(other) 168 | } 169 | 170 | /// Returns the array rank, i.e. the number of dimensions. 171 | fn rank(&self) -> usize { 172 | self.shape().rank() 173 | } 174 | 175 | /// Creates an expression that gives tuples `(x, y)` of the elements from each expression. 176 | /// 177 | /// # Panics 178 | /// 179 | /// Panics if the expressions cannot be broadcast to a common shape. 
180 | fn zip(self, other: I) -> Zip 181 | where 182 | Self: Sized, 183 | { 184 | Zip::new(self, other.into_expr()) 185 | } 186 | 187 | #[doc(hidden)] 188 | unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item; 189 | 190 | #[doc(hidden)] 191 | fn inner_rank(&self) -> usize; 192 | 193 | #[doc(hidden)] 194 | unsafe fn reset_dim(&mut self, index: usize, count: usize); 195 | 196 | #[doc(hidden)] 197 | unsafe fn step_dim(&mut self, index: usize); 198 | 199 | #[cfg(not(feature = "nightly"))] 200 | #[doc(hidden)] 201 | fn clone_into_vec(self, vec: &mut Vec) 202 | where 203 | Self: Expression> + Sized, 204 | { 205 | assert!(self.len() <= vec.capacity() - vec.len(), "length exceeds capacity"); 206 | 207 | self.for_each(|x| unsafe { 208 | vec.as_mut_ptr().add(vec.len()).write(x.into_cloned()); 209 | vec.set_len(vec.len() + 1); 210 | }); 211 | } 212 | 213 | #[cfg(feature = "nightly")] 214 | #[doc(hidden)] 215 | fn clone_into_vec(self, vec: &mut Vec) 216 | where 217 | Self: Expression> + Sized, 218 | { 219 | assert!(self.len() <= vec.capacity() - vec.len(), "length exceeds capacity"); 220 | 221 | self.for_each(|x| unsafe { 222 | vec.as_mut_ptr().add(vec.len()).write(x.into_cloned()); 223 | vec.set_len(vec.len() + 1); 224 | }); 225 | } 226 | } 227 | 228 | /// Conversion trait from an expression. 229 | pub trait FromExpression: Sized { 230 | /// Creates an array from an expression. 231 | fn from_expr>(expr: I) -> Self; 232 | } 233 | 234 | /// Conversion trait into an expression. 235 | pub trait IntoExpression: IntoIterator { 236 | /// Array shape type. 237 | type Shape: Shape; 238 | 239 | /// Which kind of expression are we turning this into? 240 | type IntoExpr: Expression; 241 | 242 | /// Creates an expression from a value. 
243 | fn into_expr(self) -> Self::IntoExpr; 244 | } 245 | 246 | impl Apply for E { 247 | type Output T> = Map; 248 | type ZippedWith T> = 249 | Map, F>; 250 | 251 | fn apply T>(self, f: F) -> Self::Output { 252 | self.map(f) 253 | } 254 | 255 | fn zip_with(self, expr: I, f: F) -> Self::ZippedWith 256 | where 257 | F: FnMut((Self::Item, I::Item)) -> T, 258 | { 259 | self.zip(expr).map(f) 260 | } 261 | } 262 | 263 | impl IntoExpression for E { 264 | type Shape = E::Shape; 265 | type IntoExpr = E; 266 | 267 | fn into_expr(self) -> Self { 268 | self 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /src/expr/into_expr.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Formatter, Result}; 2 | use std::mem::ManuallyDrop; 3 | use std::ptr; 4 | 5 | use crate::expr::buffer::Buffer; 6 | use crate::expr::expression::Expression; 7 | use crate::expr::iter::Iter; 8 | use crate::slice::Slice; 9 | 10 | /// Expression that moves elements out of an array. 
11 | pub struct IntoExpr { 12 | buffer: B, 13 | index: usize, 14 | } 15 | 16 | impl IntoExpr { 17 | pub(crate) fn new(buffer: B) -> Self { 18 | Self { buffer, index: 0 } 19 | } 20 | } 21 | 22 | impl AsMut> for IntoExpr { 23 | fn as_mut(&mut self) -> &mut Slice { 24 | debug_assert!(self.index == 0, "expression in use"); 25 | 26 | unsafe { 27 | &mut *(self.buffer.as_mut_slice() as *mut Slice, B::Shape> 28 | as *mut Slice) 29 | } 30 | } 31 | } 32 | 33 | impl AsRef> for IntoExpr { 34 | fn as_ref(&self) -> &Slice { 35 | debug_assert!(self.index == 0, "expression in use"); 36 | 37 | unsafe { 38 | &*(self.buffer.as_slice() as *const Slice, B::Shape> 39 | as *const Slice) 40 | } 41 | } 42 | } 43 | 44 | impl Clone for IntoExpr { 45 | fn clone(&self) -> Self { 46 | assert!(self.index == 0, "expression in use"); 47 | 48 | Self { buffer: self.buffer.clone(), index: 0 } 49 | } 50 | 51 | fn clone_from(&mut self, source: &Self) { 52 | assert!(self.index == 0 && source.index == 0, "expression in use"); 53 | 54 | self.buffer.clone_from(&source.buffer); 55 | } 56 | } 57 | 58 | impl> Debug for IntoExpr { 59 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 60 | f.debug_tuple("IntoExpr").field(&self.as_ref()).finish() 61 | } 62 | } 63 | 64 | impl Default for IntoExpr { 65 | fn default() -> Self { 66 | Self { buffer: Default::default(), index: 0 } 67 | } 68 | } 69 | 70 | impl Drop for IntoExpr { 71 | fn drop(&mut self) { 72 | unsafe { 73 | let ptr = self.buffer.as_mut_slice().as_mut_ptr().add(self.index) as *mut B::Item; 74 | let len = self.buffer.as_slice().len() - self.index; 75 | 76 | ptr::slice_from_raw_parts_mut(ptr, len).drop_in_place(); 77 | } 78 | } 79 | } 80 | 81 | impl Expression for IntoExpr { 82 | type Shape = B::Shape; 83 | 84 | const IS_REPEATABLE: bool = false; 85 | 86 | fn shape(&self) -> &Self::Shape { 87 | self.buffer.as_slice().shape() 88 | } 89 | 90 | unsafe fn get_unchecked(&mut self, _: usize) -> B::Item { 91 | debug_assert!(self.index < 
self.buffer.as_slice().len(), "index out of bounds"); 92 | 93 | self.index += 1; // Keep track of that the element is moved out. 94 | 95 | unsafe { 96 | ManuallyDrop::take(&mut *self.buffer.as_mut_slice().as_mut_ptr().add(self.index - 1)) 97 | } 98 | } 99 | 100 | fn inner_rank(&self) -> usize { 101 | usize::MAX 102 | } 103 | 104 | unsafe fn reset_dim(&mut self, _: usize, _: usize) {} 105 | unsafe fn step_dim(&mut self, _: usize) {} 106 | } 107 | 108 | impl IntoIterator for IntoExpr { 109 | type Item = B::Item; 110 | type IntoIter = Iter; 111 | 112 | fn into_iter(self) -> Iter { 113 | Iter::new(self) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/expr/iter.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Formatter, Result}; 2 | use std::iter::FusedIterator; 3 | 4 | use crate::dim::Dims; 5 | use crate::expr::expression::Expression; 6 | use crate::shape::Shape; 7 | 8 | /// Iterator type for array expressions. 
9 | #[derive(Clone)] 10 | pub struct Iter { 11 | expr: E, 12 | inner_index: usize, 13 | inner_limit: usize, 14 | outer_index: ::Dims, 15 | outer_limit: ::Dims, 16 | } 17 | 18 | impl Iter { 19 | pub(crate) fn new(expr: E) -> Self { 20 | let outer_rank = expr.rank().saturating_sub(expr.inner_rank()); 21 | 22 | let inner_index = 0; 23 | let inner_limit = expr.shape().with_dims(|dims| dims[outer_rank..].iter().product()); 24 | 25 | let mut outer_index = Default::default(); 26 | let mut outer_limit = Default::default(); 27 | 28 | if outer_rank > 0 { 29 | outer_index = Dims::new(expr.rank()); 30 | outer_limit = 31 | expr.shape().with_dims(|dims| TryFrom::try_from(dims).expect("invalid rank")); 32 | } 33 | 34 | Self { expr, inner_index, inner_limit, outer_index, outer_limit } 35 | } 36 | 37 | unsafe fn step_outer(&mut self) -> bool { 38 | let outer_rank = self.expr.rank().saturating_sub(self.expr.inner_rank()); 39 | 40 | unsafe { 41 | // If the inner rank is >0, reset the last dimension when stepping outer dimensions. 42 | // This is needed in the FromFn implementation. 43 | if outer_rank < self.expr.rank() { 44 | self.expr.reset_dim(self.expr.rank() - 1, 0); 45 | } 46 | 47 | for i in (0..outer_rank).rev() { 48 | if self.outer_index.as_ref()[i] + 1 < self.outer_limit.as_ref()[i] { 49 | self.expr.step_dim(i); 50 | self.outer_index.as_mut()[i] += 1; 51 | 52 | return true; 53 | } 54 | 55 | self.expr.reset_dim(i, self.outer_index.as_ref()[i]); 56 | self.outer_index.as_mut()[i] = 0; 57 | } 58 | } 59 | 60 | self.outer_index.as_mut().fill(0); // Ensure that following calls return false. 
61 | 62 | false 63 | } 64 | } 65 | 66 | impl Debug for Iter { 67 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 68 | assert!(self.inner_index == 0, "iterator in use"); 69 | 70 | f.debug_tuple("Iter").field(&self.expr).finish() 71 | } 72 | } 73 | 74 | impl ExactSizeIterator for Iter {} 75 | impl FusedIterator for Iter {} 76 | 77 | impl Iterator for Iter { 78 | type Item = E::Item; 79 | 80 | fn fold T>(mut self, init: T, mut f: F) -> T { 81 | let mut accum = init; 82 | 83 | loop { 84 | for i in self.inner_index..self.inner_limit { 85 | accum = f(accum, unsafe { self.expr.get_unchecked(i) }); 86 | } 87 | 88 | if unsafe { !self.step_outer() } { 89 | return accum; 90 | } 91 | 92 | self.inner_index = 0; 93 | } 94 | } 95 | 96 | fn next(&mut self) -> Option { 97 | if self.inner_index == self.inner_limit { 98 | if unsafe { !self.step_outer() } { 99 | return None; 100 | } 101 | 102 | self.inner_index = 0; 103 | } 104 | 105 | self.inner_index += 1; 106 | 107 | unsafe { Some(self.expr.get_unchecked(self.inner_index - 1)) } 108 | } 109 | 110 | fn size_hint(&self) -> (usize, Option) { 111 | let outer_rank = self.expr.rank().saturating_sub(self.expr.inner_rank()); 112 | let mut len = 1; 113 | 114 | for i in 0..outer_rank { 115 | len = len * self.outer_limit.as_ref()[i] - self.outer_index.as_ref()[i]; 116 | } 117 | 118 | len = len * self.inner_limit - self.inner_index; 119 | 120 | (len, Some(len)) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/expr/mod.rs: -------------------------------------------------------------------------------- 1 | //! Expression module, for multidimensional iteration. 
2 | 3 | mod adapters; 4 | mod buffer; 5 | mod expression; 6 | mod into_expr; 7 | mod iter; 8 | mod sources; 9 | 10 | pub use adapters::{Cloned, Copied, Enumerate, Map, Zip, cloned, copied, enumerate, map, zip}; 11 | pub use buffer::{Buffer, Drain}; 12 | pub use expression::{Apply, Expression, FromExpression, IntoExpression}; 13 | pub use into_expr::IntoExpr; 14 | pub use iter::Iter; 15 | pub use sources::{AxisExpr, AxisExprMut, Lanes, LanesMut}; 16 | pub use sources::{Fill, FillWith, FromElem, FromFn, fill, fill_with, from_elem, from_fn}; 17 | 18 | /// Folds all elements of the argument into an accumulator by applying an operation, 19 | /// and returns the result. 20 | /// 21 | /// # Examples 22 | /// 23 | /// ``` 24 | /// use mdarray::{expr, view}; 25 | /// 26 | /// let v = view![0, 1, 2]; 27 | /// 28 | /// assert_eq!(expr::fold(v, 0, |acc, x| acc + x), 3); 29 | /// ``` 30 | pub fn fold T>(expr: I, init: T, f: F) -> T { 31 | expr.into_expr().fold(init, f) 32 | } 33 | 34 | /// Calls a closure on each element of the argument. 35 | /// 36 | /// # Examples 37 | /// 38 | /// ``` 39 | /// use mdarray::{expr, tensor, view}; 40 | /// 41 | /// let mut t = tensor![0, 1, 2]; 42 | /// 43 | /// expr::for_each(&mut t, |x| *x *= 2); 44 | /// 45 | /// assert_eq!(t, view![0, 2, 4]); 46 | /// ``` 47 | pub fn for_each(expr: I, f: F) { 48 | expr.into_expr().for_each(f); 49 | } 50 | -------------------------------------------------------------------------------- /src/expr/sources.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Formatter, Result}; 2 | 3 | use crate::expr::expression::Expression; 4 | use crate::expr::iter::Iter; 5 | use crate::index::{Axis, Keep, Split}; 6 | use crate::layout::Layout; 7 | use crate::mapping::Mapping; 8 | use crate::shape::{IntoShape, Shape}; 9 | use crate::slice::Slice; 10 | use crate::view::{View, ViewMut}; 11 | 12 | /// Array axis expression. 
13 | pub struct AxisExpr<'a, T, S: Shape, L: Layout, A: Axis> { 14 | slice: &'a Slice, 15 | axis: A, 16 | mapping: as Layout>::Mapping<(A::Dim,)>, 17 | offset: isize, 18 | } 19 | 20 | /// Mutable array axis expression. 21 | pub struct AxisExprMut<'a, T, S: Shape, L: Layout, A: Axis> { 22 | slice: &'a mut Slice, 23 | axis: A, 24 | mapping: as Layout>::Mapping<(A::Dim,)>, 25 | offset: isize, 26 | } 27 | 28 | /// Expression that repeats an element by cloning. 29 | #[derive(Clone)] 30 | pub struct Fill { 31 | value: T, 32 | } 33 | 34 | /// Expression that gives elements by calling a closure repeatedly. 35 | #[derive(Clone)] 36 | pub struct FillWith { 37 | f: F, 38 | } 39 | 40 | /// Expression with a defined shape that repeats an element by cloning. 41 | #[derive(Clone)] 42 | pub struct FromElem { 43 | shape: S, 44 | elem: T, 45 | } 46 | 47 | /// Expression with a defined shape and elements from the given function. 48 | #[derive(Clone)] 49 | pub struct FromFn { 50 | shape: S, 51 | f: F, 52 | index: S::Dims, 53 | } 54 | 55 | /// Array lanes expression. 56 | pub struct Lanes<'a, T, S: Shape, L: Layout, A: Axis> { 57 | slice: &'a Slice, 58 | axis: A, 59 | mapping: as Layout>::Mapping>, 60 | offset: isize, 61 | } 62 | 63 | /// Mutable array lanes expression. 64 | pub struct LanesMut<'a, T, S: Shape, L: Layout, A: Axis> { 65 | slice: &'a mut Slice, 66 | axis: A, 67 | mapping: as Layout>::Mapping>, 68 | offset: isize, 69 | } 70 | 71 | /// Creates an expression with elements by cloning `value`. 72 | /// 73 | /// # Examples 74 | /// 75 | /// ``` 76 | /// use mdarray::{expr, tensor, view}; 77 | /// 78 | /// let mut t = tensor![0; 3]; 79 | /// 80 | /// t.assign(expr::fill(1)); 81 | /// 82 | /// assert_eq!(t, view![1; 3]); 83 | /// ``` 84 | pub fn fill(value: T) -> Fill { 85 | Fill::new(value) 86 | } 87 | 88 | /// Creates an expression with elements returned by calling a closure repeatedly. 
89 | /// 90 | /// # Examples 91 | /// 92 | /// ``` 93 | /// use mdarray::{expr, tensor, view}; 94 | /// 95 | /// let mut t = tensor![0; 3]; 96 | /// 97 | /// t.assign(expr::fill_with(|| 1)); 98 | /// 99 | /// assert_eq!(t, view![1; 3]); 100 | /// ``` 101 | pub fn fill_with T>(f: F) -> FillWith { 102 | FillWith::new(f) 103 | } 104 | 105 | /// Creates an expression with the given shape and elements by cloning `value`. 106 | /// 107 | /// # Examples 108 | /// 109 | /// ``` 110 | /// use mdarray::{view, expr, expr::Expression}; 111 | /// 112 | /// assert_eq!(expr::from_elem([2, 3], 1).eval(), view![[1; 3]; 2]); 113 | /// ``` 114 | pub fn from_elem(shape: I, elem: T) -> FromElem { 115 | FromElem::new(shape.into_shape(), elem) 116 | } 117 | 118 | /// Creates an expression with the given shape and elements from the given function. 119 | /// 120 | /// # Examples 121 | /// 122 | /// ``` 123 | /// use mdarray::{expr, expr::Expression, view}; 124 | /// 125 | /// assert_eq!(expr::from_fn([2, 3], |i| 3 * i[0] + i[1] + 1).eval(), view![[1, 2, 3], [4, 5, 6]]); 126 | /// ``` 127 | pub fn from_fn(shape: I, f: F) -> FromFn 128 | where 129 | F: FnMut(&[usize]) -> T, 130 | { 131 | FromFn::new(shape.into_shape(), f) 132 | } 133 | 134 | macro_rules! impl_axis_expr { 135 | ($name:tt, $expr:tt, $as_ptr:tt, {$($mut:tt)?}, $repeatable:tt) => { 136 | impl<'a, T, S: Shape, L: Layout, A: Axis> $name<'a, T, S, L, A> { 137 | pub(crate) fn new( 138 | slice: &'a $($mut)? 
Slice, 139 | axis: A, 140 | ) -> Self { 141 | let mapping = axis.get(slice.mapping()); 142 | 143 | Self { slice, axis, mapping, offset: 0 } 144 | } 145 | } 146 | 147 | impl<'a, T: Debug, S: Shape, L: Layout, A: Axis> Debug for $name<'a, T, S, L, A> { 148 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 149 | let index = self.axis.index(self.slice.rank()); 150 | 151 | f.debug_tuple(stringify!($name)).field(&index).field(&self.slice).finish() 152 | } 153 | } 154 | 155 | impl<'a, T, S: Shape, L: Layout, A: Axis> Expression for $name<'a, T, S, L, A> { 156 | type Shape = (A::Dim,); 157 | 158 | const IS_REPEATABLE: bool = $repeatable; 159 | 160 | fn shape(&self) -> &Self::Shape { 161 | self.mapping.shape() 162 | } 163 | 164 | unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item { 165 | let offset = self.offset + self.mapping.inner_stride() * index as isize; 166 | 167 | let mapping = self.axis.remove(self.slice.mapping()); 168 | let len = mapping.shape().checked_len().expect("invalid length"); 169 | 170 | // If the view is empty, we must not offset the pointer. 
171 | let count = if len == 0 { 0 } else { offset }; 172 | 173 | unsafe { $expr::new_unchecked(self.slice.$as_ptr().offset(count), mapping) } 174 | } 175 | 176 | fn inner_rank(&self) -> usize { 177 | 1 178 | } 179 | 180 | unsafe fn reset_dim(&mut self, _: usize, _: usize) { 181 | self.offset = 0; 182 | } 183 | 184 | unsafe fn step_dim(&mut self, _: usize) { 185 | self.offset += self.mapping.inner_stride(); 186 | } 187 | } 188 | 189 | impl<'a, T, S: Shape, L: Layout, A: Axis> IntoIterator for $name<'a, T, S, L, A> { 190 | type Item = $expr<'a, T, A::Remove, Split>; 191 | type IntoIter = Iter; 192 | 193 | fn into_iter(self) -> Iter { 194 | Iter::new(self) 195 | } 196 | } 197 | }; 198 | } 199 | 200 | impl_axis_expr!(AxisExpr, View, as_ptr, {}, true); 201 | impl_axis_expr!(AxisExprMut, ViewMut, as_mut_ptr, {mut}, false); 202 | 203 | impl Clone for AxisExpr<'_, T, S, L, A> { 204 | fn clone(&self) -> Self { 205 | Self { 206 | slice: self.slice, 207 | axis: self.axis, 208 | mapping: self.mapping.clone(), 209 | offset: self.offset, 210 | } 211 | } 212 | 213 | fn clone_from(&mut self, source: &Self) { 214 | self.slice = source.slice; 215 | self.axis = source.axis; 216 | self.mapping.clone_from(&source.mapping); 217 | self.offset = source.offset; 218 | } 219 | } 220 | 221 | impl Fill { 222 | pub(crate) fn new(value: T) -> Self { 223 | Self { value } 224 | } 225 | } 226 | 227 | impl Debug for Fill { 228 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 229 | f.debug_tuple("Fill").field(&self.value).finish() 230 | } 231 | } 232 | 233 | impl Expression for Fill { 234 | type Shape = (); 235 | 236 | const IS_REPEATABLE: bool = true; 237 | 238 | fn shape(&self) -> &() { 239 | &() 240 | } 241 | 242 | unsafe fn get_unchecked(&mut self, _: usize) -> T { 243 | self.value.clone() 244 | } 245 | 246 | fn inner_rank(&self) -> usize { 247 | usize::MAX 248 | } 249 | 250 | unsafe fn reset_dim(&mut self, _: usize, _: usize) {} 251 | unsafe fn step_dim(&mut self, _: usize) {} 252 | } 253 | 
254 | impl IntoIterator for Fill { 255 | type Item = T; 256 | type IntoIter = Iter; 257 | 258 | fn into_iter(self) -> Iter { 259 | Iter::new(self) 260 | } 261 | } 262 | 263 | impl FillWith { 264 | pub(crate) fn new(f: F) -> Self { 265 | Self { f } 266 | } 267 | } 268 | 269 | impl T> Debug for FillWith { 270 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 271 | f.debug_tuple("FillWith").finish() 272 | } 273 | } 274 | 275 | impl T> Expression for FillWith { 276 | type Shape = (); 277 | 278 | const IS_REPEATABLE: bool = true; 279 | 280 | fn shape(&self) -> &() { 281 | &() 282 | } 283 | 284 | unsafe fn get_unchecked(&mut self, _: usize) -> T { 285 | (self.f)() 286 | } 287 | 288 | fn inner_rank(&self) -> usize { 289 | usize::MAX 290 | } 291 | 292 | unsafe fn reset_dim(&mut self, _: usize, _: usize) {} 293 | unsafe fn step_dim(&mut self, _: usize) {} 294 | } 295 | 296 | impl T> IntoIterator for FillWith { 297 | type Item = T; 298 | type IntoIter = Iter; 299 | 300 | fn into_iter(self) -> Iter { 301 | Iter::new(self) 302 | } 303 | } 304 | 305 | impl FromElem { 306 | pub(crate) fn new(shape: S, elem: T) -> Self { 307 | _ = shape.checked_len().expect("invalid length"); 308 | 309 | Self { shape, elem } 310 | } 311 | } 312 | 313 | impl Debug for FromElem { 314 | fn fmt(&self, f: &mut Formatter<'_>) -> Result { 315 | f.debug_tuple("FromElem").field(&self.shape).field(&self.elem).finish() 316 | } 317 | } 318 | 319 | impl Expression for FromElem { 320 | type Shape = S; 321 | 322 | const IS_REPEATABLE: bool = true; 323 | 324 | fn shape(&self) -> &S { 325 | &self.shape 326 | } 327 | 328 | unsafe fn get_unchecked(&mut self, _: usize) -> T { 329 | self.elem.clone() 330 | } 331 | 332 | fn inner_rank(&self) -> usize { 333 | usize::MAX 334 | } 335 | 336 | unsafe fn reset_dim(&mut self, _: usize, _: usize) {} 337 | unsafe fn step_dim(&mut self, _: usize) {} 338 | } 339 | 340 | impl IntoIterator for FromElem { 341 | type Item = T; 342 | type IntoIter = Iter; 343 | 344 | fn 
into_iter(self) -> Iter { 345 | Iter::new(self) 346 | } 347 | }

impl<S: Shape, F> FromFn<S, F> {
    /// Creates the expression, with the multidimensional index starting at the origin.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements given by `shape` is invalid.
    pub(crate) fn new(shape: S, f: F) -> Self {
        _ = shape.checked_len().expect("invalid length");

        Self { shape, f, index: S::Dims::default() }
    }
}

impl<S: Shape, F> Debug for FromFn<S, F> {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        f.debug_tuple("FromFn").field(&self.shape).finish()
    }
}

impl<T, S: Shape, F: FnMut(&[usize]) -> T> Expression for FromFn<S, F> {
    type Shape = S;

    const IS_REPEATABLE: bool = true;

    fn shape(&self) -> &S {
        &self.shape
    }

    /// Calls the closure with the current multidimensional index; the linear index
    /// argument is ignored because the position is tracked in `self.index`.
    unsafe fn get_unchecked(&mut self, _: usize) -> T {
        let value = (self.f)(self.index.as_ref());

        // Increment the last dimension, which will be reset by reset_dim().
        if self.rank() > 0 {
            self.index.as_mut()[self.shape.rank() - 1] += 1;
        }

        value
    }

    fn inner_rank(&self) -> usize {
        if self.shape.rank() > 0 { 1 } else { usize::MAX }
    }

    unsafe fn reset_dim(&mut self, index: usize, _: usize) {
        self.index.as_mut()[index] = 0;
    }

    unsafe fn step_dim(&mut self, index: usize) {
        // Don't increment the last dimension, since it is done in get_unchecked().
        if index + 1 < self.rank() {
            self.index.as_mut()[index] += 1;
        }
    }
}

impl<T, S: Shape, F: FnMut(&[usize]) -> T> IntoIterator for FromFn<S, F> {
    type Item = T;
    type IntoIter = Iter<Self>;

    fn into_iter(self) -> Iter<Self> {
        Iter::new(self)
    }
}

// Implements the lanes expressions, which give one-dimensional views along `axis`
// for every position in the remaining dimensions.
macro_rules! impl_lanes {
    ($name:tt, $expr:tt, $as_ptr:tt, {$($mut:tt)?}, $repeatable:tt) => {
        impl<'a, T, S: Shape, L: Layout, A: Axis> $name<'a, T, S, L, A> {
            pub(crate) fn new(
                slice: &'a $($mut)? Slice<T, S, L>,
                axis: A,
            ) -> Self {
                // Mapping over the positions of the lanes, i.e. with the axis removed.
                let mapping = axis.remove(slice.mapping());

                // Ensure that the subarray is valid.
                _ = mapping.shape().checked_len().expect("invalid length");

                Self { slice, axis, mapping, offset: 0 }
            }
        }

        impl<'a, T: Debug, S: Shape, L: Layout, A: Axis> Debug for $name<'a, T, S, L, A> {
            fn fmt(&self, f: &mut Formatter<'_>) -> Result {
                let index = self.axis.index(self.slice.rank());

                f.debug_tuple(stringify!($name)).field(&index).field(&self.slice).finish()
            }
        }

        impl<'a, T, S: Shape, L: Layout, A: Axis> Expression for $name<'a, T, S, L, A> {
            type Shape = A::Remove<S>;

            const IS_REPEATABLE: bool = $repeatable;

            fn shape(&self) -> &Self::Shape {
                self.mapping.shape()
            }

            unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item {
                // The outer position is accumulated in `self.offset` by step_dim() and
                // reset_dim(), and must be added to the offset within the inner run.
                // Without it, every outer position would give the lanes of the first one.
                let offset = self.offset + self.mapping.inner_stride() * index as isize;
                let mapping = self.axis.get(self.slice.mapping());

                // If the view is empty, we must not offset the pointer.
                let count = if mapping.is_empty() { 0 } else { offset };

                unsafe { $expr::new_unchecked(self.slice.$as_ptr().offset(count), mapping) }
            }

            fn inner_rank(&self) -> usize {
                if Split::<S, L, A>::IS_DENSE {
                    // For static rank 0, the inner stride is 0 so we allow inner rank >0.
                    if A::Remove::<S>::RANK == Some(0) { usize::MAX } else { self.mapping.rank() }
                } else {
                    // For rank 0, the inner stride is always 0 so we can allow inner rank >0.
                    if self.mapping.rank() > 0 { 1 } else { usize::MAX }
                }
            }

            unsafe fn reset_dim(&mut self, index: usize, count: usize) {
                // Undo `count` earlier steps along the dimension.
                self.offset -= self.mapping.stride(index) * count as isize;
            }

            unsafe fn step_dim(&mut self, index: usize) {
                self.offset += self.mapping.stride(index);
            }
        }

        impl<'a, T, S: Shape, L: Layout, A: Axis> IntoIterator for $name<'a, T, S, L, A> {
            type Item = $expr<'a, T, (A::Dim<S>,), Keep<S, L, A>>;
            type IntoIter = Iter<Self>;

            fn into_iter(self) -> Iter<Self> {
                Iter::new(self)
            }
        }
    };
}

impl_lanes!(Lanes, View, as_ptr, {}, true);
impl_lanes!(LanesMut, ViewMut, as_mut_ptr, {mut}, false);

// Implemented by hand so that no `Clone` bounds are placed on the type parameters.
impl<T, S: Shape, L: Layout, A: Axis> Clone for Lanes<'_, T, S, L, A> {
    fn clone(&self) -> Self {
        Self {
            slice: self.slice,
            axis: self.axis,
            mapping: self.mapping.clone(),
            offset: self.offset,
        }
    }

    fn clone_from(&mut self, source: &Self) {
        self.slice = source.slice;
        self.axis = source.axis;
        self.mapping.clone_from(&source.mapping);
        self.offset = source.offset;
    }
}

-------------------------------------------------------------------------------- /src/index/axis.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::hash::Hash; 3 | 4 | use crate::dim::{Const, Dim, Dyn}; 5 | use crate::layout::Layout; 6 | use crate::mapping::{DenseMapping, Mapping}; 7 | use crate::shape::{DynRank, Shape}; 8 | 9 | /// Array axis trait, for subarray shapes. 10 | pub trait Axis: Copy + Debug + Default + Hash + Ord + Send + Sync { 11 | /// Corresponding dimension. 12 | type Dim: Dim; 13 | 14 | /// Shape for the previous dimensions excluding the current dimension. 15 | type Init: Shape; 16 | 17 | /// Shape for the next dimensions excluding the current dimension.
18 | type Rest: Shape; 19 | 20 | /// Remove the dimension from the shape. 21 | type Remove: Shape; 22 | 23 | /// Insert the dimension into the shape. 24 | type Insert: Shape; 25 | 26 | /// Returns the dimension index. 27 | fn index(self, rank: usize) -> usize; 28 | 29 | #[doc(hidden)] 30 | fn get( 31 | self, 32 | mapping: &M, 33 | ) -> as Layout>::Mapping<(Self::Dim,)> { 34 | let index = self.index(mapping.rank()); 35 | 36 | Mapping::prepend_dim(&DenseMapping::new(()), mapping.dim(index), mapping.stride(index)) 37 | } 38 | 39 | #[doc(hidden)] 40 | fn remove( 41 | self, 42 | mapping: &M, 43 | ) -> as Layout>::Mapping> { 44 | Mapping::remove_dim::(mapping, self.index(mapping.rank())) 45 | } 46 | 47 | #[doc(hidden)] 48 | fn resize( 49 | self, 50 | mapping: &M, 51 | new_size: usize, 52 | ) -> as Layout>::Mapping> { 53 | Mapping::resize_dim::(mapping, self.index(mapping.rank()), new_size) 54 | } 55 | } 56 | 57 | /// Column axis type, for the second last dimension. 58 | #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] 59 | pub struct Cols; 60 | 61 | /// Row axis type, for the last dimension. 62 | #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] 63 | pub struct Rows; 64 | 65 | // 66 | // These types are public to improve documentation, but hidden since 67 | // they are not considered part of the API. 68 | // 69 | 70 | #[doc(hidden)] 71 | pub type Resize = ::Insert::Remove>; 72 | 73 | #[doc(hidden)] 74 | pub type Keep = <::Rest as Shape>::Layout; 75 | 76 | #[doc(hidden)] 77 | pub type Split = <::Init as Shape>::Layout; 78 | 79 | // 80 | // The tables below give the resulting layout depending on the rank and axis. 81 | // 82 | // Keep: 83 | // 84 | // Rank \ Axis 0 1 2 ... Dyn 85 | // ------------------------------------------------------------------------- 86 | // 1 L - - - Strided 87 | // 2 Strided L - - Strided 88 | // 3 Strided Strided L - Strided 89 | // ... 90 | // DynRank Strided Strided Strided ... 
Strided 91 | // 92 | // Split: 93 | // 94 | // Rank \ Axis 0 1 2 ... Dyn 95 | // ------------------------------------------------------------------------- 96 | // 1 L - - - Strided 97 | // 2 L Strided - - Strided 98 | // 3 L Strided Strided - Strided 99 | // ... 100 | // DynRank L Strided Strided ... Strided 101 | // 102 | 103 | impl Axis for Const<0> { 104 | type Dim = S::Head; 105 | 106 | type Init = (); 107 | type Rest = S::Tail; 108 | 109 | type Remove = S::Tail; 110 | type Insert = S::Prepend; 111 | 112 | fn index(self, rank: usize) -> usize { 113 | assert!(rank > 0, "invalid dimension"); 114 | 115 | 0 116 | } 117 | } 118 | 119 | macro_rules! impl_axis { 120 | (($($n:tt),*), ($($k:tt),*)) => { 121 | $( 122 | impl Axis for Const<$n> { 123 | type Dim = as Axis>::Dim; 124 | 125 | type Init = 126 | < as Axis>::Init as Shape>::Prepend; 127 | type Rest = as Axis>::Rest; 128 | 129 | type Remove = 130 | < as Axis>::Remove as Shape>::Prepend; 131 | type Insert = 132 | < as Axis>::Insert as Shape>::Prepend; 133 | 134 | fn index(self, rank: usize) -> usize { 135 | assert!(rank > $n, "invalid dimension"); 136 | 137 | $n 138 | } 139 | } 140 | )* 141 | }; 142 | } 143 | 144 | impl_axis!((1, 2, 3, 4, 5), (0, 1, 2, 3, 4)); 145 | 146 | macro_rules! 
impl_cols_rows { 147 | ($name:tt, $n:tt) => { 148 | impl Axis for $name { 149 | type Dim = as Axis>::Dim; 150 | 151 | type Init = < as Axis>::Rest as Shape>::Reverse; 152 | type Rest = < as Axis>::Init as Shape>::Reverse; 153 | 154 | type Remove = < as Axis>::Remove as Shape>::Reverse; 155 | type Insert = 156 | < as Axis>::Insert as Shape>::Reverse; 157 | 158 | fn index(self, rank: usize) -> usize { 159 | assert!(rank > $n, "invalid dimension"); 160 | 161 | rank - $n - 1 162 | } 163 | } 164 | }; 165 | } 166 | 167 | impl_cols_rows!(Cols, 1); 168 | impl_cols_rows!(Rows, 0); 169 | 170 | impl Axis for Dyn { 171 | type Dim = Dyn; 172 | 173 | type Init = DynRank; 174 | type Rest = DynRank; 175 | 176 | type Remove = ::Dyn; 177 | type Insert = ::Prepend; 178 | 179 | fn index(self, rank: usize) -> usize { 180 | assert!(self < rank, "invalid dimension"); 181 | 182 | self 183 | } 184 | } 185 | -------------------------------------------------------------------------------- /src/index/mod.rs: -------------------------------------------------------------------------------- 1 | //! Module for array slice and view indexing, and for array axis subarray types. 
2 | 3 | mod axis; 4 | mod permutation; 5 | mod slice; 6 | mod view; 7 | 8 | pub use axis::{Axis, Cols, Rows}; 9 | pub use permutation::Permutation; 10 | pub use slice::SliceIndex; 11 | pub use view::{DimIndex, ViewIndex}; 12 | 13 | #[doc(hidden)] 14 | pub use axis::{Keep, Resize, Split}; 15 | 16 | #[cfg(not(feature = "nightly"))] 17 | pub(crate) fn range(range: R, bounds: std::ops::RangeTo) -> std::ops::Range 18 | where 19 | R: std::ops::RangeBounds, 20 | { 21 | let len = bounds.end; 22 | 23 | let start: std::ops::Bound<&usize> = range.start_bound(); 24 | let start = match start { 25 | std::ops::Bound::Included(&start) => start, 26 | std::ops::Bound::Excluded(start) => start 27 | .checked_add(1) 28 | .unwrap_or_else(|| panic!("attempted to index slice from after maximum usize")), 29 | std::ops::Bound::Unbounded => 0, 30 | }; 31 | 32 | let end: std::ops::Bound<&usize> = range.end_bound(); 33 | let end = match end { 34 | std::ops::Bound::Included(end) => end 35 | .checked_add(1) 36 | .unwrap_or_else(|| panic!("attempted to index slice up to maximum usize")), 37 | std::ops::Bound::Excluded(&end) => end, 38 | std::ops::Bound::Unbounded => len, 39 | }; 40 | 41 | assert!(start <= end, "slice index starts at {start} but ends at {end}"); 42 | assert!(end <= len, "range end index {end} out of range for slice of length {len}"); 43 | 44 | std::ops::Range { start, end } 45 | } 46 | 47 | #[cold] 48 | #[inline(never)] 49 | #[track_caller] 50 | pub(crate) fn panic_bounds_check(index: usize, len: usize) -> ! 
{ 51 | panic!("index out of bounds: the len is {len} but the index is {index}") 52 | } 53 | -------------------------------------------------------------------------------- /src/index/permutation.rs: -------------------------------------------------------------------------------- 1 | use crate::index::axis::Axis; 2 | use crate::layout::{Layout, Strided}; 3 | use crate::shape::{DynRank, Shape}; 4 | 5 | /// Array permutation trait, for array types after permutation of dimensions. 6 | pub trait Permutation { 7 | /// Shape after permuting dimensions. 8 | type Shape: Shape; 9 | 10 | /// Layout after permuting dimensions. 11 | type Layout: Layout; 12 | 13 | #[doc(hidden)] 14 | type Init: Shape; 15 | } 16 | 17 | impl Permutation for (X,) { 18 | type Shape = (X::Dim,); 19 | type Layout = L; 20 | 21 | type Init = X::Init<()>; 22 | } 23 | 24 | macro_rules! impl_permutation { 25 | (($($jk:tt),+), ($($yz:tt),+)) => { 26 | impl Permutation for (X $(,$yz)+) 27 | where 28 | ($($yz,)+): Permutation 29 | { 30 | type Shape = 31 | <<($($yz,)+) as Permutation>::Shape as Shape>::Prepend>; 32 | type Layout = ::Layout; 33 | 34 | type Init = 35 | <<<($($yz,)+) as Permutation>::Init as Shape>::Tail as Shape>::Merge>; 36 | } 37 | }; 38 | } 39 | 40 | impl_permutation!((1), (Y)); 41 | impl_permutation!((1, 2), (Y, Z)); 42 | impl_permutation!((1, 2, 3), (Y, Z, W)); 43 | impl_permutation!((1, 2, 3, 4), (Y, Z, W, U)); 44 | impl_permutation!((1, 2, 3, 4, 5), (Y, Z, W, U, V)); 45 | 46 | impl Permutation for DynRank { 47 | type Shape = S::Dyn; 48 | type Layout = Strided; 49 | 50 | type Init = Self; 51 | } 52 | -------------------------------------------------------------------------------- /src/index/slice.rs: -------------------------------------------------------------------------------- 1 | use std::ops::{ 2 | Bound, Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, 3 | }; 4 | 5 | use crate::index; 6 | use crate::layout::{Dense, Layout}; 7 | use 
crate::mapping::Mapping; 8 | use crate::shape::Shape; 9 | use crate::slice::Slice; 10 | 11 | /// Array slice index trait, for an element or a subslice. 12 | pub trait SliceIndex { 13 | /// Array element or subslice type. 14 | type Output: ?Sized; 15 | 16 | #[doc(hidden)] 17 | unsafe fn get_unchecked(self, slice: &Slice) -> &Self::Output; 18 | 19 | #[doc(hidden)] 20 | unsafe fn get_unchecked_mut(self, slice: &mut Slice) -> &mut Self::Output; 21 | 22 | #[doc(hidden)] 23 | fn index(self, slice: &Slice) -> &Self::Output; 24 | 25 | #[doc(hidden)] 26 | fn index_mut(self, slice: &mut Slice) -> &mut Self::Output; 27 | } 28 | 29 | impl SliceIndex for &[usize] { 30 | type Output = T; 31 | 32 | unsafe fn get_unchecked(self, slice: &Slice) -> &T { 33 | unsafe { &*slice.as_ptr().offset(slice.mapping().offset(self)) } 34 | } 35 | 36 | unsafe fn get_unchecked_mut(self, slice: &mut Slice) -> &mut T { 37 | unsafe { &mut *slice.as_mut_ptr().offset(slice.mapping().offset(self)) } 38 | } 39 | 40 | fn index(self, slice: &Slice) -> &T { 41 | assert!(self.len() == slice.rank(), "invalid rank"); 42 | 43 | for i in 0..self.len() { 44 | if self[i] >= slice.dim(i) { 45 | index::panic_bounds_check(self[i], slice.dim(i)); 46 | } 47 | } 48 | 49 | unsafe { SliceIndex::get_unchecked(self, slice) } 50 | } 51 | 52 | fn index_mut(self, slice: &mut Slice) -> &mut T { 53 | assert!(self.len() == slice.rank(), "invalid rank"); 54 | 55 | for i in 0..self.len() { 56 | if self[i] >= slice.dim(i) { 57 | index::panic_bounds_check(self[i], slice.dim(i)); 58 | } 59 | } 60 | 61 | unsafe { SliceIndex::get_unchecked_mut(self, slice) } 62 | } 63 | } 64 | 65 | impl SliceIndex for [usize; N] { 66 | type Output = T; 67 | 68 | unsafe fn get_unchecked(self, slice: &Slice) -> &T { 69 | unsafe { SliceIndex::get_unchecked(&self[..], slice) } 70 | } 71 | 72 | unsafe fn get_unchecked_mut(self, slice: &mut Slice) -> &mut T { 73 | unsafe { SliceIndex::get_unchecked_mut(&self[..], slice) } 74 | } 75 | 76 | fn index(self, 
slice: &Slice) -> &T { 77 | SliceIndex::index(&self[..], slice) 78 | } 79 | 80 | fn index_mut(self, slice: &mut Slice) -> &mut T { 81 | SliceIndex::index_mut(&self[..], slice) 82 | } 83 | } 84 | 85 | impl SliceIndex for usize { 86 | type Output = T; 87 | 88 | unsafe fn get_unchecked(self, slice: &Slice) -> &T { 89 | unsafe { &*slice.as_ptr().offset(slice.mapping().linear_offset(self)) } 90 | } 91 | 92 | unsafe fn get_unchecked_mut(self, slice: &mut Slice) -> &mut T { 93 | unsafe { &mut *slice.as_mut_ptr().offset(slice.mapping().linear_offset(self)) } 94 | } 95 | 96 | fn index(self, slice: &Slice) -> &T { 97 | if self >= slice.len() { 98 | index::panic_bounds_check(self, slice.len()); 99 | } 100 | 101 | unsafe { SliceIndex::get_unchecked(self, slice) } 102 | } 103 | 104 | fn index_mut(self, slice: &mut Slice) -> &mut T { 105 | if self >= slice.len() { 106 | index::panic_bounds_check(self, slice.len()); 107 | } 108 | 109 | unsafe { SliceIndex::get_unchecked_mut(self, slice) } 110 | } 111 | } 112 | 113 | macro_rules! 
impl_slice_index { 114 | ($type:ty) => { 115 | impl SliceIndex for $type { 116 | type Output = [T]; 117 | 118 | unsafe fn get_unchecked(self, slice: &Slice) -> &[T] { 119 | unsafe { <&[T]>::from(slice.flatten()).get_unchecked(self) } 120 | } 121 | 122 | unsafe fn get_unchecked_mut(self, slice: &mut Slice) -> &mut [T] { 123 | unsafe { <&mut [T]>::from(slice.flatten_mut()).get_unchecked_mut(self) } 124 | } 125 | 126 | fn index(self, slice: &Slice) -> &[T] { 127 | <&[T]>::from(slice.flatten()).index(self) 128 | } 129 | 130 | fn index_mut(self, slice: &mut Slice) -> &mut [T] { 131 | <&mut [T]>::from(slice.flatten_mut()).index_mut(self) 132 | } 133 | } 134 | }; 135 | } 136 | 137 | impl_slice_index!((Bound, Bound)); 138 | impl_slice_index!(Range); 139 | impl_slice_index!(RangeFrom); 140 | impl_slice_index!(RangeFull); 141 | impl_slice_index!(RangeInclusive); 142 | impl_slice_index!(RangeTo); 143 | impl_slice_index!(RangeToInclusive); 144 | -------------------------------------------------------------------------------- /src/index/view.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "nightly")] 2 | use std::slice; 3 | 4 | use std::ops::{ 5 | Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, 6 | }; 7 | 8 | use crate::dim::Dyn; 9 | use crate::index; 10 | use crate::layout::{Layout, Strided}; 11 | use crate::mapping::Mapping; 12 | use crate::ops::StepRange; 13 | use crate::shape::Shape; 14 | 15 | /// Helper trait for array indexing, for a single index. 16 | pub trait DimIndex { 17 | #[doc(hidden)] 18 | type Shape: Shape; 19 | 20 | #[doc(hidden)] 21 | type Layout: Layout; 22 | 23 | #[doc(hidden)] 24 | type Outer: Layout; 25 | 26 | #[doc(hidden)] 27 | fn dim_index = Self::Shape>>( 28 | self, 29 | tail: I, 30 | mapping: &M, 31 | ) -> (isize, as Layout>::Mapping>); 32 | } 33 | 34 | /// Array view index trait, for a multidimensional index. 
35 | pub trait ViewIndex { 36 | /// Array view shape. 37 | type Shape: Shape; 38 | 39 | /// Array view layout. 40 | type Layout: Layout; 41 | 42 | #[doc(hidden)] 43 | type Outer: Layout; 44 | 45 | #[doc(hidden)] 46 | const RANK: usize; 47 | 48 | #[doc(hidden)] 49 | fn view_index( 50 | self, 51 | mapping: &M, 52 | ) -> (isize, as Layout>::Mapping>); 53 | } 54 | 55 | impl DimIndex for usize { 56 | type Shape = I::Shape; 57 | type Layout = I::Layout; 58 | type Outer = Strided; 59 | 60 | fn dim_index = Self::Shape>>( 61 | self, 62 | tail: I, 63 | mapping: &M, 64 | ) -> (isize, as Layout>::Mapping>) { 65 | let (offset, inner) = tail.view_index::(mapping); 66 | 67 | let size = mapping.dim(mapping.rank() - 1 - I::RANK); 68 | let stride = mapping.stride(mapping.rank() - 1 - I::RANK); 69 | 70 | if self >= size { 71 | index::panic_bounds_check(self, size) 72 | } 73 | 74 | (offset + stride * self as isize, Mapping::remap(&inner)) 75 | } 76 | } 77 | 78 | impl DimIndex for RangeFull { 79 | type Shape = as Shape>::Prepend; 80 | type Layout = I::Outer; 81 | type Outer = I::Outer; 82 | 83 | fn dim_index = Self::Shape>>( 84 | self, 85 | tail: I, 86 | mapping: &M, 87 | ) -> (isize, as Layout>::Mapping>) { 88 | let (offset, inner) = tail.view_index::(mapping); 89 | 90 | let size = mapping.dim(mapping.rank() - 1 - I::RANK); 91 | let stride = mapping.stride(mapping.rank() - 1 - I::RANK); 92 | 93 | (offset, Mapping::prepend_dim(&inner, size, stride)) 94 | } 95 | } 96 | 97 | macro_rules! 
impl_dim_index { 98 | ($type:ty) => { 99 | impl DimIndex for $type { 100 | type Shape = as Shape>::Prepend; 101 | type Layout = I::Outer; 102 | type Outer = Strided; 103 | 104 | fn dim_index( 105 | self, 106 | tail: I, 107 | mapping: &M, 108 | ) -> (isize, as Layout>::Mapping>) 109 | where 110 | J: ViewIndex = Self::Shape>, 111 | { 112 | let (offset, inner) = tail.view_index::(mapping); 113 | 114 | let size = mapping.dim(mapping.rank() - 1 - I::RANK); 115 | let stride = mapping.stride(mapping.rank() - 1 - I::RANK); 116 | 117 | #[cfg(not(feature = "nightly"))] 118 | let range = crate::index::range(self, ..size); 119 | #[cfg(feature = "nightly")] 120 | let range = slice::range(self, ..size); 121 | let count = stride * range.start as isize; 122 | 123 | (offset + count, Mapping::prepend_dim(&inner, range.len(), stride)) 124 | } 125 | } 126 | }; 127 | } 128 | 129 | impl_dim_index!((Bound, Bound)); 130 | impl_dim_index!(Range); 131 | impl_dim_index!(RangeFrom); 132 | impl_dim_index!(RangeInclusive); 133 | impl_dim_index!(RangeTo); 134 | impl_dim_index!(RangeToInclusive); 135 | 136 | impl> DimIndex for StepRange { 137 | type Shape = as Shape>::Prepend; 138 | type Layout = Strided; 139 | type Outer = Strided; 140 | 141 | fn dim_index = Self::Shape>>( 142 | self, 143 | tail: I, 144 | mapping: &M, 145 | ) -> (isize, as Layout>::Mapping>) { 146 | let (offset, inner) = tail.view_index::(mapping); 147 | 148 | let size = mapping.dim(mapping.rank() - 1 - I::RANK); 149 | let stride = mapping.stride(mapping.rank() - 1 - I::RANK); 150 | 151 | #[cfg(not(feature = "nightly"))] 152 | let range = crate::index::range(self.range, ..size); 153 | #[cfg(feature = "nightly")] 154 | let range = slice::range(self.range, ..size); 155 | let len = range.len().div_ceil(self.step.abs_diff(0)); 156 | 157 | let delta = if self.step < 0 && !range.is_empty() { range.end - 1 } else { range.start }; 158 | 159 | (offset + stride * delta as isize, Mapping::prepend_dim(&inner, len, stride * self.step)) 160 | 
} 161 | } 162 | 163 | impl ViewIndex for () { 164 | type Shape = (); 165 | type Layout = L; 166 | type Outer = L; 167 | 168 | const RANK: usize = 0; 169 | 170 | fn view_index( 171 | self, 172 | _: &M, 173 | ) -> (isize, as Layout>::Mapping>) { 174 | (0, Default::default()) 175 | } 176 | } 177 | 178 | macro_rules! impl_view_index { 179 | ($n:tt, ($($jk:tt),*), ($($yz:tt),*)) => { 180 | impl ViewIndex for (X, $($yz),*) { 181 | type Shape = X::Shape; 182 | type Layout = X::Layout; 183 | type Outer = X::Outer; 184 | 185 | const RANK: usize = $n; 186 | 187 | fn view_index( 188 | self, 189 | mapping: &M, 190 | ) -> (isize, as Layout>::Mapping>) { 191 | self.0.dim_index::(($(self.$jk,)*), mapping) 192 | } 193 | } 194 | }; 195 | } 196 | 197 | impl_view_index!(1, (), ()); 198 | impl_view_index!(2, (1), (Y)); 199 | impl_view_index!(3, (1, 2), (Y, Z)); 200 | impl_view_index!(4, (1, 2, 3), (Y, Z, W)); 201 | impl_view_index!(5, (1, 2, 3, 4), (Y, Z, W, U)); 202 | impl_view_index!(6, (1, 2, 3, 4, 5), (Y, Z, W, U, V)); 203 | -------------------------------------------------------------------------------- /src/layout.rs: -------------------------------------------------------------------------------- 1 | use crate::mapping::{DenseMapping, Mapping, StridedMapping}; 2 | use crate::shape::Shape; 3 | 4 | /// Array memory layout trait. 5 | pub trait Layout { 6 | /// Array layout mapping type. 7 | type Mapping: Mapping; 8 | 9 | /// True if the layout type is dense. 10 | const IS_DENSE: bool; 11 | } 12 | 13 | /// Dense array layout type. 14 | pub struct Dense; 15 | 16 | /// Strided array layout type. 
17 | pub struct Strided; 18 | 19 | impl Layout for Dense { 20 | type Mapping = DenseMapping; 21 | 22 | const IS_DENSE: bool = true; 23 | } 24 | 25 | impl Layout for Strided { 26 | type Mapping = StridedMapping; 27 | 28 | const IS_DENSE: bool = false; 29 | } 30 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Multidimensional array for Rust 2 | //! 3 | //! ## Overview 4 | //! 5 | //! The mdarray crate provides a multidimensional array for Rust. Its main target 6 | //! is for numeric types, however generic types are supported as well. The purpose 7 | //! is to provide a generic container type that is simple and flexible to use, 8 | //! with interworking to other crates for e.g. BLAS/LAPACK functionality. 9 | //! 10 | //! Here are the main features of mdarray: 11 | //! 12 | //! - Dense array type, where the rank is known at compile time. 13 | //! - Static or dynamic array dimensions, with optional inline storage. 14 | //! - Standard Rust mechanisms are used for e.g. indexing and iteration. 15 | //! - Generic expressions for multidimensional iteration. 16 | //! 17 | //! The design is inspired from other Rust crates (ndarray, nalgebra, bitvec, dfdx 18 | //! and candle), the proposed C++ mdarray and mdspan types, and multidimensional 19 | //! arrays in other languages. 20 | //! 21 | //! ## Array types 22 | //! 23 | //! The basic array type is `Tensor` for a dense array that owns the storage, 24 | //! similar to the Rust `Vec` type. It is parameterized by the element type, 25 | //! the shape (i.e. the size of each dimension) and optionally an allocator. 26 | //! 27 | //! `Array` is a dense array which stores elements inline, similar to the Rust 28 | //! `array` type. The shape must consist of dimensions with constant size. 29 | //! 30 | //! `View` and `ViewMut` are array types that refer to a parent array. They are 31 | //! 
used for example when creating array views without duplicating elements. 32 | //! 33 | //! `Slice` is a generic array reference, similar to the Rust `slice` type. 34 | //! It consists of a pointer to an internal structure that holds the storage 35 | //! and the layout mapping. All arrays can be dereferenced to an array slice. 36 | //! 37 | //! The following type aliases are provided: 38 | //! 39 | //! - `DTensor` for a dense array with a given rank. 40 | //! - `DSlice` for an array slice with a given rank. 41 | //! 42 | //! The rank can be dynamic using the `DynRank` shape type. This is the default 43 | //! for array types if no shape is specified. 44 | //! 45 | //! The layout mapping describes how elements are stored in memory. The mapping 46 | //! is parameterized by the shape and the layout. It contains the dynamic size 47 | //! and stride per dimension when needed. 48 | //! 49 | //! The layout is `Dense` if elements are stored contiguously without gaps, and 50 | //! it is `Strided` if all dimensions can have arbitrary strides. 51 | //! 52 | //! The array elements are stored in row-major or C order, where the first 53 | //! dimension is the outermost one. 54 | //! 55 | //! ## Indexing and views 56 | //! 57 | //! Scalar indexing is done using the normal square-bracket index operator and 58 | //! an array of `usize` per dimension as index. A scalar `usize` can be used for 59 | //! linear indexing. If the layout is `Dense`, a range can also be used to select 60 | //! a slice. 61 | //! 62 | //! An array view can be created with the `view` and `view_mut` methods, which 63 | //! take indices per dimension as arguments. Each index can be either a range 64 | //! or `usize`. The resulting array layout depends on both the layout inferred 65 | //! from the indices and the input layout. 66 | //! 67 | //! For two-dimensional arrays, a view of one column or row can be created with 68 | //! the `col`, `col_mut`, `row` and `row_mut` methods, and a view of the diagonal 69 | //! 
with `diag` and `diag_mut`. 70 | //! 71 | //! If the array layout is not known, `remap`, `remap_mut` and `into_mapping` can 72 | //! be used to change layout. 73 | //! 74 | //! ## Iteration 75 | //! 76 | //! An iterator can be created from an array with the `iter`, `iter_mut` and 77 | //! `into_iter` methods like for `Vec` and `slice`. 78 | //! 79 | //! Expressions are similar to iterators, but support multidimensional iteration 80 | //! and have consistency checking of shapes. An expression is created with the 81 | //! `expr`, `expr_mut` and `into_expr` methods. Note that the array types `View` 82 | //! and `ViewMut` are also expressions. 83 | //! 84 | //! There are methods for evaluating expressions or converting into other 85 | //! expressions, such as `eval`, `for_each` and `map`. Two expressions can be 86 | //! merged to an expression of tuples with the `zip` method or free function. 87 | //! 88 | //! When merging expressions, if the rank differs the expression with the lower 89 | //! rank is broadcast into the larger shape by adding outer dimensions. It is not 90 | //! possible to broadcast mutable arrays or when moving elements out of an array. 91 | //! 92 | //! For multidimensional arrays, iteration over a single dimension can be done 93 | //! with `outer_expr`, `outer_expr_mut`, `axis_expr` and `axis_expr_mut`. 94 | //! The resulting expressions give array views of the remaining dimensions. 95 | //! 96 | //! It is also possible to iterate over all except one dimension with `cols`, 97 | //! `cols_mut`, `lanes`, `lanes_mut`, `rows` and `rows_mut`. 98 | //! 99 | //! ## Operators 100 | //! 101 | //! Arithmetic, logical, negation, comparison and compound assignment operators 102 | //! are supported for arrays and expressions. 103 | //! 104 | //! If at least one of the inputs is an array that is passed by value, the 105 | //! operation is evaluated directly and the input array is reused for the result. 106 | //!
Otherwise, if all input parameters are array references or expressions, an 107 | //! expression is returned. In the latter case, the result may have a different 108 | //! element type. 109 | //! 110 | //! For comparison operators, the parameters must always be arrays that are passed 111 | //! by reference. For compound assignment operators, the first parameter is always 112 | //! a mutable reference to an array where the result is stored. 113 | //! 114 | //! Scalar parameters must be passed using the `fill` function that wraps a value in 115 | //! a `Fill` expression. If a type does not implement the `Copy` trait, the 116 | //! parameter must be passed by reference. 117 | //! 118 | //! ## Example 119 | //! 120 | //! This example implements matrix multiplication and addition `C = A * B + C`. 121 | //! The matrices use row-major ordering, and the inner loop runs over one row in 122 | //! `B` and `C`. By using iterator-like expressions the array bounds checking is 123 | //! avoided, and the compiler is able to vectorize the inner loop. 124 | //! 125 | //! ``` 126 | //! use mdarray::{expr::Expression, tensor, view, DSlice}; 127 | //! 128 | //! fn matmul(a: &DSlice, b: &DSlice, c: &mut DSlice) { 129 | //! for (mut ci, ai) in c.rows_mut().zip(a.rows()) { 130 | //! for (aik, bk) in ai.expr().zip(b.rows()) { 131 | //! for (cij, bkj) in ci.expr_mut().zip(bk) { 132 | //! *cij = aik.mul_add(*bkj, *cij); 133 | //! } 134 | //! } 135 | //! } 136 | //! } 137 | //! 138 | //! let a = view![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; 139 | //! let b = view![[0.0, 1.0], [1.0, 1.0]]; 140 | //! 141 | //! let mut c = tensor![[0.0; 2]; 3]; 142 | //! 143 | //! matmul(&a, &b, &mut c); 144 | //! 145 | //! assert_eq!(c, view![[4.0, 5.0], [5.0, 7.0], [6.0, 9.0]]); 146 | //!
``` 147 | 148 | #![allow(clippy::comparison_chain)] 149 | #![allow(clippy::needless_range_loop)] 150 | #![cfg_attr(feature = "nightly", feature(allocator_api))] 151 | #![cfg_attr(feature = "nightly", feature(extern_types))] 152 | #![cfg_attr(feature = "nightly", feature(hasher_prefixfree_extras))] 153 | #![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))] 154 | #![cfg_attr(feature = "nightly", feature(macro_metavar_expr))] 155 | #![cfg_attr(feature = "nightly", feature(slice_range))] 156 | #![warn(missing_docs)] 157 | #![warn(unreachable_pub)] 158 | #![warn(unused_results)] 159 | 160 | pub mod expr; 161 | pub mod index; 162 | 163 | mod array; 164 | mod dim; 165 | mod layout; 166 | mod macros; 167 | mod mapping; 168 | mod ops; 169 | mod raw_slice; 170 | mod raw_tensor; 171 | mod shape; 172 | mod slice; 173 | mod tensor; 174 | mod traits; 175 | mod view; 176 | 177 | #[cfg(feature = "serde")] 178 | mod serde; 179 | 180 | #[cfg(not(feature = "nightly"))] 181 | mod alloc { 182 | pub trait Allocator {} 183 | 184 | #[derive(Copy, Clone, Default, Debug)] 185 | pub struct Global; 186 | 187 | impl Allocator for Global {} 188 | } 189 | 190 | pub use array::Array; 191 | pub use dim::{Const, Dim, Dyn}; 192 | pub use layout::{Dense, Layout, Strided}; 193 | pub use mapping::{DenseMapping, Mapping, StridedMapping}; 194 | pub use ops::{StepRange, step}; 195 | pub use shape::{ConstShape, DynRank, IntoShape, Rank, Shape}; 196 | pub use slice::{DSlice, Slice}; 197 | pub use tensor::{DTensor, Tensor}; 198 | pub use traits::{IntoCloned, Owned}; 199 | pub use view::{DView, DViewMut, View, ViewMut}; 200 | -------------------------------------------------------------------------------- /src/macros.rs: -------------------------------------------------------------------------------- 1 | /// Creates an inline multidimensional array containing the arguments. 2 | /// 3 | /// This macro is used to create an array, similar to the `vec!` macro for vectors. 
4 | /// There are two forms of this macro: 5 | /// 6 | /// - Create an array containing a given list of elements: 7 | /// 8 | /// ``` 9 | /// use mdarray::{Array, array}; 10 | /// 11 | /// let a = array![[1, 2, 3], [4, 5, 6]]; 12 | /// 13 | /// assert_eq!(a, Array::from([[1, 2, 3], [4, 5, 6]])); 14 | /// ``` 15 | /// 16 | /// - Create an array from a given element and shape: 17 | /// 18 | /// ``` 19 | /// use mdarray::{array, Const, Array}; 20 | /// 21 | /// let a = array![[1; 3]; 2]; 22 | /// 23 | /// assert_eq!(a, Array::from([[1; 3]; 2])); 24 | /// ``` 25 | /// 26 | /// In the second form, the argument must be an array repeat expression with constant shape. 27 | #[macro_export] 28 | macro_rules! array { 29 | ($([$([$([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 30 | $crate::Array::<_, (_, _, _, _, _, _)>::from([$([$([$([$([$([$($x),*]),+]),+]),+]),+]),+]) 31 | ); 32 | ($([$([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 33 | $crate::Array::<_, (_, _, _, _, _)>::from([$([$([$([$([$($x),*]),+]),+]),+]),+]) 34 | ); 35 | ($([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 36 | $crate::Array::<_, (_, _, _, _)>::from([$([$([$([$($x),*]),+]),+]),+]) 37 | ); 38 | ($([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?) => ( 39 | $crate::Array::<_, (_, _, _)>::from([$([$([$($x),*]),+]),+]) 40 | ); 41 | ($([$($x:expr),* $(,)?]),+ $(,)?) => ( 42 | $crate::Array::<_, (_, _)>::from([$([$($x),*]),+]) 43 | ); 44 | ($($x:expr),* $(,)?) 
=> ( 45 | $crate::Array::<_, (_,)>::from([$($x),*]) 46 | ); 47 | ([[[[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr]; $m:expr]; $n:expr) => ( 48 | $crate::Array::<_, (_, _, _, _, _, _)>::from([[[[[[$elem; $i]; $j]; $k]; $l]; $m]; $n]) 49 | ); 50 | ([[[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr]; $m:expr) => ( 51 | $crate::Array::<_, (_, _, _, _, _)>::from([[[[[$elem; $i]; $j]; $k]; $l]; $m]) 52 | ); 53 | ([[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr) => ( 54 | $crate::Array::<_, (_, _, _, _)>::from([[[[$elem; $i]; $j]; $k]; $l]) 55 | ); 56 | ([[$elem:expr; $i:expr]; $j:expr]; $k:expr) => ( 57 | $crate::Array::<_, (_, _, _)>::from([[[$elem; $i]; $j]; $k]) 58 | ); 59 | ([$elem:expr; $i:expr]; $j:expr) => ( 60 | $crate::Array::<_, (_, _)>::from([[$elem; $i]; $j]) 61 | ); 62 | ($elem:expr; $i:expr) => ( 63 | $crate::Array::<_, (_,)>::from([$elem; $i]) 64 | ); 65 | } 66 | 67 | /// Creates a dense multidimensional array containing the arguments. 68 | /// 69 | /// This macro is used to create an array, similar to the `vec!` macro for vectors. 70 | /// There are two forms of this macro: 71 | /// 72 | /// - Create an array containing a given list of elements: 73 | /// 74 | /// ``` 75 | /// use mdarray::{DTensor, tensor}; 76 | /// 77 | /// let a = tensor![[1, 2, 3], [4, 5, 6]]; 78 | /// 79 | /// assert_eq!(a, DTensor::<_, 2>::from([[1, 2, 3], [4, 5, 6]])); 80 | /// ``` 81 | /// 82 | /// - Create an array from a given element and shape by cloning the element: 83 | /// 84 | /// ``` 85 | /// use mdarray::{tensor, Tensor}; 86 | /// 87 | /// let a = tensor![[1; 3]; 2]; 88 | /// 89 | /// assert_eq!(a, Tensor::from_elem([2, 3], 1)); 90 | /// ``` 91 | /// 92 | /// In the second form, like for vectors the shape does not have to be constant. 93 | #[macro_export] 94 | macro_rules! tensor { 95 | ($([$([$([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) 
=> ( 96 | $crate::DTensor::<_, 6>::from([$([$([$([$([$([$($x),*]),+]),+]),+]),+]),+]) 97 | ); 98 | ($([$([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 99 | $crate::DTensor::<_, 5>::from([$([$([$([$([$($x),*]),+]),+]),+]),+]) 100 | ); 101 | ($([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 102 | $crate::DTensor::<_, 4>::from([$([$([$([$($x),*]),+]),+]),+]) 103 | ); 104 | ($([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?) => ( 105 | $crate::DTensor::<_, 3>::from([$([$([$($x),*]),+]),+]) 106 | ); 107 | ($([$($x:expr),* $(,)?]),+ $(,)?) => ( 108 | $crate::DTensor::<_, 2>::from([$([$($x),*]),+]) 109 | ); 110 | ($($x:expr),* $(,)?) => ( 111 | $crate::DTensor::<_, 1>::from([$($x),*]) 112 | ); 113 | ([[[[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr]; $m:expr]; $n:expr) => ( 114 | $crate::DTensor::<_, 6>::from_elem([$n, $m, $l, $k, $j, $i], $elem) 115 | ); 116 | ([[[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr]; $m:expr) => ( 117 | $crate::DTensor::<_, 5>::from_elem([$m, $l, $k, $j, $i], $elem) 118 | ); 119 | ([[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr) => ( 120 | $crate::DTensor::<_, 4>::from_elem([$l, $k, $j, $i], $elem) 121 | ); 122 | ([[$elem:expr; $i:expr]; $j:expr]; $k:expr) => ( 123 | $crate::DTensor::<_, 3>::from_elem([$k, $j, $i], $elem) 124 | ); 125 | ([$elem:expr; $i:expr]; $j:expr) => ( 126 | $crate::DTensor::<_, 2>::from_elem([$j, $i], $elem) 127 | ); 128 | ($elem:expr; $i:expr) => ( 129 | $crate::DTensor::<_, 1>::from_elem([$i], $elem) 130 | ); 131 | } 132 | 133 | /// Creates a multidimensional array view containing the arguments. 134 | /// 135 | /// This macro is used to create an array view, similar to the `vec!` macro for vectors. 
136 | /// There are two forms of this macro: 137 | /// 138 | /// - Create an array view containing a given list of elements: 139 | /// 140 | /// ``` 141 | /// use mdarray::{DView, view}; 142 | /// 143 | /// let a = view![[1, 2, 3], [4, 5, 6]]; 144 | /// 145 | /// assert_eq!(a, DView::<_, 2>::from(&[[1, 2, 3], [4, 5, 6]])); 146 | /// ``` 147 | /// 148 | /// - Create an array view from a given element and shape: 149 | /// 150 | /// ``` 151 | /// use mdarray::{view, DView}; 152 | /// 153 | /// let a = view![[1; 3]; 2]; 154 | /// 155 | /// assert_eq!(a, DView::<_, 2>::from(&[[1; 3]; 2])); 156 | /// ``` 157 | /// 158 | /// In the second form, the argument must be an array repeat expression with constant shape. 159 | #[macro_export] 160 | macro_rules! view { 161 | ($([$([$([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 162 | $crate::View::<_, $crate::Rank<6>>::from(&[$([$([$([$([$([$($x),*]),+]),+]),+]),+]),+]) 163 | ); 164 | ($([$([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 165 | $crate::View::<_, $crate::Rank<5>>::from(&[$([$([$([$([$($x),*]),+]),+]),+]),+]) 166 | ); 167 | ($([$([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?]),+ $(,)?) => ( 168 | $crate::View::<_, $crate::Rank<4>>::from(&[$([$([$([$($x),*]),+]),+]),+]) 169 | ); 170 | ($([$([$($x:expr),* $(,)?]),+ $(,)?]),+ $(,)?) => ( 171 | $crate::View::<_, $crate::Rank<3>>::from(&[$([$([$($x),*]),+]),+]) 172 | ); 173 | ($([$($x:expr),* $(,)?]),+ $(,)?) => ( 174 | $crate::View::<_, $crate::Rank<2>>::from(&[$([$($x),*]),+]) 175 | ); 176 | ($($x:expr),* $(,)?) 
=> ( 177 | $crate::View::<_, $crate::Rank<1>>::from(&[$($x),*]) 178 | ); 179 | ([[[[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr]; $m:expr]; $n:expr) => ( 180 | $crate::View::<_, $crate::Rank<6>>::from(&[[[[[[$elem; $i]; $j]; $k]; $l]; $m]; $n]) 181 | ); 182 | ([[[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr]; $m:expr) => ( 183 | $crate::View::<_, $crate::Rank<5>>::from(&[[[[[$elem; $i]; $j]; $k]; $l]; $m]) 184 | ); 185 | ([[[$elem:expr; $i:expr]; $j:expr]; $k:expr]; $l:expr) => ( 186 | $crate::View::<_, $crate::Rank<4>>::from(&[[[[$elem; $i]; $j]; $k]; $l]) 187 | ); 188 | ([[$elem:expr; $i:expr]; $j:expr]; $k:expr) => ( 189 | $crate::View::<_, $crate::Rank<3>>::from(&[[[$elem; $i]; $j]; $k]) 190 | ); 191 | ([$elem:expr; $i:expr]; $j:expr) => ( 192 | $crate::View::<_, $crate::Rank<2>>::from(&[[$elem; $i]; $j]) 193 | ); 194 | ($elem:expr; $i:expr) => ( 195 | $crate::View::<_, $crate::Rank<1>>::from(&[$elem; $i]) 196 | ); 197 | } 198 | -------------------------------------------------------------------------------- /src/mapping.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::hash::Hash; 3 | 4 | use crate::dim::Dims; 5 | use crate::layout::{Dense, Layout, Strided}; 6 | use crate::shape::{DynRank, Shape}; 7 | 8 | /// Array layout mapping trait, including shape and strides. 9 | pub trait Mapping: Clone + Debug + Default + Eq + Hash + Send + Sync { 10 | /// Array shape type. 11 | type Shape: Shape; 12 | 13 | /// Array layout type. 14 | type Layout: Layout = Self>; 15 | 16 | /// Returns `true` if the array strides are consistent with contiguous memory layout. 17 | fn is_contiguous(&self) -> bool; 18 | 19 | /// Returns the array shape. 20 | fn shape(&self) -> &Self::Shape; 21 | 22 | /// Returns the distance between elements in the specified dimension. 23 | /// 24 | /// # Panics 25 | /// 26 | /// Panics if the dimension is out of bounds. 
27 | fn stride(&self, index: usize) -> isize; 28 | 29 | /// Returns the number of elements in the specified dimension. 30 | /// 31 | /// # Panics 32 | /// 33 | /// Panics if the dimension is out of bounds. 34 | fn dim(&self, index: usize) -> usize { 35 | self.shape().dim(index) 36 | } 37 | 38 | /// Returns the number of elements in each dimension. 39 | fn dims(&self) -> &[usize] 40 | where 41 | Self: Mapping, 42 | { 43 | self.shape().dims() 44 | } 45 | 46 | /// Returns `true` if the array contains no elements. 47 | fn is_empty(&self) -> bool { 48 | self.shape().is_empty() 49 | } 50 | 51 | /// Returns the number of elements in the array. 52 | fn len(&self) -> usize { 53 | self.shape().len() 54 | } 55 | 56 | /// Returns the array rank, i.e. the number of dimensions. 57 | fn rank(&self) -> usize { 58 | self.shape().rank() 59 | } 60 | 61 | #[doc(hidden)] 62 | fn for_each_stride(&self, f: F); 63 | 64 | #[doc(hidden)] 65 | fn inner_stride(&self) -> isize; 66 | 67 | #[doc(hidden)] 68 | fn linear_offset(&self, index: usize) -> isize; 69 | 70 | #[doc(hidden)] 71 | fn permute(mapping: &M, perm: &[usize]) -> Self; 72 | 73 | #[doc(hidden)] 74 | fn prepend_dim(mapping: &M, size: usize, stride: isize) -> Self; 75 | 76 | #[doc(hidden)] 77 | fn remap(mapping: &M) -> Self; 78 | 79 | #[doc(hidden)] 80 | fn remove_dim(mapping: &M, index: usize) -> Self; 81 | 82 | #[doc(hidden)] 83 | fn reshape(&self, new_shape: S) -> ::Mapping; 84 | 85 | #[doc(hidden)] 86 | fn resize_dim(mapping: &M, index: usize, new_size: usize) -> Self; 87 | 88 | #[doc(hidden)] 89 | fn shape_mut(&mut self) -> &mut Self::Shape; 90 | 91 | #[doc(hidden)] 92 | fn transpose>>(mapping: &M) -> Self; 93 | 94 | #[doc(hidden)] 95 | fn offset(&self, index: &[usize]) -> isize { 96 | debug_assert!(index.len() == self.rank(), "invalid rank"); 97 | 98 | let mut offset = 0; 99 | 100 | self.for_each_stride(|i, stride| { 101 | debug_assert!(index[i] < self.dim(i), "index out of bounds"); 102 | 103 | offset += stride * index[i] as 
isize; 104 | }); 105 | 106 | offset 107 | } 108 | } 109 | 110 | /// Dense layout mapping type. 111 | #[derive(Debug, Default, Eq, Hash, PartialEq)] 112 | pub struct DenseMapping { 113 | shape: S, 114 | } 115 | 116 | /// Strided layout mapping type. 117 | #[derive(Debug, Eq, Hash, PartialEq)] 118 | pub struct StridedMapping { 119 | shape: S, 120 | strides: S::Dims, 121 | } 122 | 123 | impl DenseMapping { 124 | /// Creates a new, dense layout mapping with the specified shape. 125 | pub fn new(shape: S) -> Self { 126 | Self { shape } 127 | } 128 | } 129 | 130 | impl Clone for DenseMapping { 131 | fn clone(&self) -> Self { 132 | Self::new(self.shape.clone()) 133 | } 134 | 135 | fn clone_from(&mut self, source: &Self) { 136 | self.shape.clone_from(&source.shape); 137 | } 138 | } 139 | 140 | impl Copy for DenseMapping {} 141 | 142 | impl Mapping for DenseMapping { 143 | type Shape = S; 144 | type Layout = Dense; 145 | 146 | fn is_contiguous(&self) -> bool { 147 | true 148 | } 149 | 150 | fn shape(&self) -> &S { 151 | &self.shape 152 | } 153 | 154 | fn stride(&self, index: usize) -> isize { 155 | assert!(index < self.rank(), "invalid dimension"); 156 | 157 | let mut stride = 1; 158 | 159 | for i in index + 1..self.rank() { 160 | stride *= self.dim(i); 161 | } 162 | 163 | stride as isize 164 | } 165 | 166 | fn for_each_stride(&self, mut f: F) { 167 | let mut stride = 1; 168 | 169 | for i in (0..self.rank()).rev() { 170 | f(i, stride as isize); 171 | stride *= self.dim(i); 172 | } 173 | } 174 | 175 | fn inner_stride(&self) -> isize { 176 | // The inner stride should be a compile time constant with dense layout. 177 | // For static rank 0, we set it to 0 to allow inner rank >0 in iteration. 
178 | if S::RANK == Some(0) { 0 } else { 1 } 179 | } 180 | 181 | fn linear_offset(&self, index: usize) -> isize { 182 | debug_assert!(index < self.len(), "index out of bounds"); 183 | 184 | index as isize 185 | } 186 | 187 | fn permute(mapping: &M, perm: &[usize]) -> Self { 188 | assert!(perm.len() == mapping.rank(), "invalid permutation"); 189 | 190 | for i in 0..mapping.rank() { 191 | assert!(perm[i] == i, "invalid permutation"); 192 | } 193 | 194 | Self::remap(mapping) 195 | } 196 | 197 | fn prepend_dim(mapping: &M, size: usize, stride: isize) -> Self { 198 | assert!(M::Layout::IS_DENSE, "invalid layout"); 199 | assert!(stride == mapping.len() as isize, "invalid stride"); 200 | 201 | Self::new(mapping.shape().prepend_dim(size)) 202 | } 203 | 204 | fn remap(mapping: &M) -> Self { 205 | assert!(mapping.is_contiguous(), "mapping not contiguous"); 206 | 207 | Self::new(mapping.shape().with_dims(Shape::from_dims)) 208 | } 209 | 210 | fn remove_dim(mapping: &M, index: usize) -> Self { 211 | assert!(M::Layout::IS_DENSE, "invalid layout"); 212 | assert!(index == 0, "invalid dimension"); 213 | 214 | Self::new(mapping.shape().remove_dim(index)) 215 | } 216 | 217 | fn reshape(&self, new_shape: R) -> DenseMapping { 218 | DenseMapping::new(self.shape.reshape(new_shape)) 219 | } 220 | 221 | fn resize_dim(mapping: &M, index: usize, new_size: usize) -> Self { 222 | assert!(M::Layout::IS_DENSE, "invalid layout"); 223 | assert!(index == 0, "invalid dimension"); 224 | 225 | Self::new(mapping.shape().resize_dim(index, new_size)) 226 | } 227 | 228 | fn shape_mut(&mut self) -> &mut S { 229 | &mut self.shape 230 | } 231 | 232 | fn transpose>>(mapping: &M) -> Self { 233 | assert!(mapping.rank() < 2 && M::Layout::IS_DENSE, "invalid layout"); 234 | 235 | Self::new(mapping.shape().reverse()) 236 | } 237 | } 238 | 239 | impl StridedMapping { 240 | /// Creates a new, strided layout mapping with the specified shape and strides. 
241 | pub fn new(shape: S, strides: &[isize]) -> Self { 242 | assert!(shape.rank() == strides.len(), "length mismatch"); 243 | 244 | Self { shape, strides: TryFrom::try_from(strides).expect("invalid rank") } 245 | } 246 | 247 | /// Returns the distance between elements in each dimension. 248 | pub fn strides(&self) -> &[isize] { 249 | self.strides.as_ref() 250 | } 251 | } 252 | 253 | impl Default for StridedMapping { 254 | fn default() -> Self { 255 | Self { shape: S::default(), strides: S::Dims::new(S::default().rank()) } 256 | } 257 | } 258 | 259 | impl Clone for StridedMapping { 260 | fn clone(&self) -> Self { 261 | Self { shape: self.shape.clone(), strides: self.strides.clone() } 262 | } 263 | 264 | fn clone_from(&mut self, source: &Self) { 265 | self.shape.clone_from(&source.shape); 266 | self.strides.clone_from(&source.strides); 267 | } 268 | } 269 | 270 | impl: Copy> + Copy> Copy for StridedMapping {} 271 | 272 | impl Mapping for StridedMapping { 273 | type Shape = S; 274 | type Layout = Strided; 275 | 276 | fn is_contiguous(&self) -> bool { 277 | let mut stride = 1; 278 | 279 | for i in (0..self.rank()).rev() { 280 | if self.strides.as_ref()[i] != stride { 281 | return false; 282 | } 283 | 284 | stride *= self.dim(i) as isize; 285 | } 286 | 287 | true 288 | } 289 | 290 | fn shape(&self) -> &S { 291 | &self.shape 292 | } 293 | 294 | fn stride(&self, index: usize) -> isize { 295 | assert!(index < self.rank(), "invalid dimension"); 296 | 297 | self.strides.as_ref()[index] 298 | } 299 | 300 | fn for_each_stride(&self, mut f: F) { 301 | for i in 0..self.rank() { 302 | f(i, self.strides.as_ref()[i]) 303 | } 304 | } 305 | 306 | fn inner_stride(&self) -> isize { 307 | if self.rank() > 0 { self.strides.as_ref()[self.rank() - 1] } else { 0 } 308 | } 309 | 310 | fn linear_offset(&self, index: usize) -> isize { 311 | debug_assert!(index < self.len(), "index out of bounds"); 312 | 313 | let mut dividend = index; 314 | let mut offset = 0; 315 | 316 | for i in 
(0..self.rank()).rev() { 317 | offset += self.strides.as_ref()[i] * (dividend % self.dim(i)) as isize; 318 | dividend /= self.dim(i); 319 | } 320 | 321 | offset 322 | } 323 | 324 | fn permute(mapping: &M, perm: &[usize]) -> Self { 325 | assert!(perm.len() == mapping.rank(), "invalid permutation"); 326 | 327 | let mut index_mask = 0; 328 | 329 | for i in 0..mapping.rank() { 330 | assert!(perm[i] < mapping.rank(), "invalid permutation"); 331 | 332 | index_mask |= 1 << perm[i]; 333 | } 334 | 335 | assert!(index_mask == !(usize::MAX << mapping.rank()), "invalid permutation"); 336 | 337 | let mut shape = S::new(mapping.rank()); 338 | let mut strides = S::Dims::new(mapping.rank()); 339 | 340 | shape.with_mut_dims(|dims| { 341 | // Calculate inverse permutation 342 | for i in 0..mapping.rank() { 343 | dims[perm[i]] = i; 344 | } 345 | 346 | // Permute strides 347 | mapping.for_each_stride(|i, stride| strides.as_mut()[dims[i]] = stride); 348 | 349 | // Permute shape 350 | for i in 0..mapping.rank() { 351 | dims[i] = mapping.dim(perm[i]); 352 | } 353 | }); 354 | 355 | Self { shape, strides } 356 | } 357 | 358 | fn prepend_dim(mapping: &M, size: usize, stride: isize) -> Self { 359 | let mut strides = S::Dims::new(mapping.rank() + 1); 360 | 361 | strides.as_mut()[0] = stride; 362 | mapping.for_each_stride(|i, stride| strides.as_mut()[i + 1] = stride); 363 | 364 | Self { shape: mapping.shape().prepend_dim(size), strides } 365 | } 366 | 367 | fn remap(mapping: &M) -> Self { 368 | let mut strides = S::Dims::new(mapping.rank()); 369 | 370 | mapping.for_each_stride(|i, stride| strides.as_mut()[i] = stride); 371 | 372 | Self { shape: mapping.shape().with_dims(Shape::from_dims), strides } 373 | } 374 | 375 | fn remove_dim(mapping: &M, index: usize) -> Self { 376 | assert!(index < mapping.rank(), "invalid dimension"); 377 | 378 | let mut strides = S::Dims::new(mapping.rank() - 1); 379 | 380 | mapping.for_each_stride(|i, stride| { 381 | if i < index { 382 | strides.as_mut()[i] = 
stride; 383 | } else if i > index { 384 | strides.as_mut()[i - 1] = stride; 385 | } 386 | }); 387 | 388 | Self { shape: mapping.shape().remove_dim(index), strides } 389 | } 390 | 391 | fn reshape(&self, new_shape: R) -> StridedMapping { 392 | let new_shape = self.shape.reshape(new_shape); 393 | let mut new_strides = R::Dims::new(new_shape.rank()); 394 | 395 | let mut old_len = 1usize; 396 | let mut new_len = 1usize; 397 | 398 | let mut old_stride = 1; 399 | let mut new_stride = 1; 400 | 401 | let mut valid_layout = true; 402 | 403 | let mut j = new_shape.rank(); 404 | 405 | for i in (0..self.rank()).rev() { 406 | // Set strides for the next region or extend the current region. 407 | if old_len == new_len { 408 | old_stride = self.strides.as_ref()[i]; 409 | new_stride = old_stride; 410 | } else { 411 | valid_layout &= old_stride == self.strides.as_ref()[i]; 412 | } 413 | 414 | old_len *= self.dim(i); 415 | old_stride *= self.dim(i) as isize; 416 | 417 | // Add dimensions within the current region. 418 | while j > 0 { 419 | if new_len * new_shape.dim(j - 1) > old_len { 420 | break; 421 | } 422 | 423 | j -= 1; 424 | 425 | new_strides.as_mut()[j] = new_stride; 426 | 427 | new_len *= new_shape.dim(j); 428 | new_stride *= new_shape.dim(j) as isize; 429 | } 430 | } 431 | 432 | // Add remaining dimensions. 
433 | while j > 0 { 434 | j -= 1; 435 | 436 | new_strides.as_mut()[j] = new_stride; 437 | 438 | new_len *= new_shape.dim(j); 439 | new_stride *= new_shape.dim(j) as isize; 440 | } 441 | 442 | assert!(new_len == 0 || valid_layout, "memory layout not compatible"); 443 | 444 | StridedMapping { shape: new_shape, strides: new_strides } 445 | } 446 | 447 | fn resize_dim(mapping: &M, index: usize, new_size: usize) -> Self { 448 | let mut strides = S::Dims::new(mapping.rank()); 449 | 450 | mapping.for_each_stride(|i, stride| strides.as_mut()[i] = stride); 451 | 452 | Self { shape: mapping.shape().resize_dim(index, new_size), strides } 453 | } 454 | 455 | fn shape_mut(&mut self) -> &mut S { 456 | &mut self.shape 457 | } 458 | 459 | fn transpose>>(mapping: &M) -> Self { 460 | let mut strides = S::Dims::new(mapping.rank()); 461 | 462 | mapping.for_each_stride(|i, stride| strides.as_mut()[mapping.rank() - 1 - i] = stride); 463 | 464 | Self { shape: mapping.shape().reverse(), strides } 465 | } 466 | } 467 | -------------------------------------------------------------------------------- /src/ops.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "nightly")] 2 | use std::alloc::Allocator; 3 | 4 | use std::ops::{ 5 | Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, 6 | Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, 7 | }; 8 | 9 | #[cfg(not(feature = "nightly"))] 10 | use crate::alloc::Allocator; 11 | use crate::array::Array; 12 | use crate::expr::{Apply, Buffer, Expression, IntoExpression}; 13 | use crate::expr::{Fill, FillWith, FromElem, FromFn, IntoExpr, Map}; 14 | use crate::layout::Layout; 15 | use crate::shape::{ConstShape, Shape}; 16 | use crate::slice::Slice; 17 | use crate::tensor::Tensor; 18 | use crate::view::{View, ViewMut}; 19 | 20 | /// Range constructed from a unit spaced range with the given step size. 
21 | #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] 22 | pub struct StepRange { 23 | /// Unit spaced range. 24 | pub range: R, 25 | 26 | /// Step size. 27 | pub step: S, 28 | } 29 | 30 | /// Creates a range with the given step size from a unit spaced range. 31 | /// 32 | /// If the step size is negative, the result is obtained by reversing the input range 33 | /// and stepping by the absolute value of the step size. 34 | /// 35 | /// # Examples 36 | /// 37 | /// ``` 38 | /// use mdarray::{step, view}; 39 | /// 40 | /// let v = view![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; 41 | /// 42 | /// assert_eq!(v.view(step(0..10, 2)).to_vec(), [0, 2, 4, 6, 8]); 43 | /// assert_eq!(v.view(step(0..10, -2)).to_vec(), [9, 7, 5, 3, 1]); 44 | /// ``` 45 | pub fn step(range: R, step: S) -> StepRange { 46 | StepRange { range, step } 47 | } 48 | 49 | impl Eq for Array {} 50 | impl Eq for Slice {} 51 | impl Eq for Tensor {} 52 | impl Eq for View<'_, T, S, L> {} 53 | impl Eq for ViewMut<'_, T, S, L> {} 54 | 55 | impl PartialEq for Array 56 | where 57 | for<'a> &'a I: IntoExpression>, 58 | T: PartialEq, 59 | { 60 | fn eq(&self, other: &I) -> bool { 61 | (**self).eq(other) 62 | } 63 | } 64 | 65 | impl PartialEq for Slice 66 | where 67 | for<'a> &'a I: IntoExpression>, 68 | T: PartialEq, 69 | { 70 | fn eq(&self, other: &I) -> bool { 71 | let other = other.into_expr(); 72 | 73 | if self.shape().with_dims(|dims| other.shape().with_dims(|other| dims == other)) { 74 | // Avoid very long compile times for release build with MIR inlining, 75 | // by avoiding recursion until types are known. 76 | // 77 | // This is a workaround until const if is available, see #3582 and #122301. 
78 | 79 | fn compare_dense( 80 | this: &Slice, 81 | other: &Slice, 82 | ) -> bool 83 | where 84 | T: PartialEq, 85 | { 86 | this.remap::()[..].eq(&other.remap::()[..]) 87 | } 88 | 89 | fn compare_strided( 90 | this: &Slice, 91 | other: &Slice, 92 | ) -> bool 93 | where 94 | T: PartialEq, 95 | { 96 | if this.rank() < 2 { 97 | this.iter().eq(other) 98 | } else { 99 | this.outer_expr().into_iter().eq(other.outer_expr()) 100 | } 101 | } 102 | 103 | let f = 104 | const { if L::IS_DENSE && K::IS_DENSE { compare_dense } else { compare_strided } }; 105 | 106 | f(self, &other) 107 | } else { 108 | false 109 | } 110 | } 111 | } 112 | 113 | impl PartialEq for Tensor 114 | where 115 | for<'a> &'a I: IntoExpression>, 116 | T: PartialEq, 117 | { 118 | fn eq(&self, other: &I) -> bool { 119 | (**self).eq(other) 120 | } 121 | } 122 | 123 | impl PartialEq for View<'_, T, S, L> 124 | where 125 | for<'a> &'a I: IntoExpression>, 126 | T: PartialEq, 127 | { 128 | fn eq(&self, other: &I) -> bool { 129 | (**self).eq(other) 130 | } 131 | } 132 | 133 | impl PartialEq 134 | for ViewMut<'_, T, S, L> 135 | where 136 | for<'a> &'a I: IntoExpression>, 137 | T: PartialEq, 138 | { 139 | fn eq(&self, other: &I) -> bool { 140 | (**self).eq(other) 141 | } 142 | } 143 | 144 | macro_rules! 
impl_binary_op { 145 | ($trt:tt, $fn:tt) => { 146 | impl<'a, T, U, S: ConstShape, I: Apply> $trt for &'a Array 147 | where 148 | &'a T: $trt, 149 | { 150 | #[cfg(not(feature = "nightly"))] 151 | type Output = I::ZippedWith U>; 152 | 153 | #[cfg(feature = "nightly")] 154 | type Output = I::ZippedWith U>; 155 | 156 | fn $fn(self, rhs: I) -> Self::Output { 157 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 158 | } 159 | } 160 | 161 | impl<'a, T, U, S: Shape, L: Layout, I: Apply> $trt for &'a Slice 162 | where 163 | &'a T: $trt, 164 | { 165 | #[cfg(not(feature = "nightly"))] 166 | type Output = I::ZippedWith U>; 167 | 168 | #[cfg(feature = "nightly")] 169 | type Output = I::ZippedWith U>; 170 | 171 | fn $fn(self, rhs: I) -> Self::Output { 172 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 173 | } 174 | } 175 | 176 | impl<'a, T, U, S: Shape, A: Allocator, I: Apply> $trt for &'a Tensor 177 | where 178 | &'a T: $trt, 179 | { 180 | #[cfg(not(feature = "nightly"))] 181 | type Output = I::ZippedWith U>; 182 | 183 | #[cfg(feature = "nightly")] 184 | type Output = I::ZippedWith U>; 185 | 186 | fn $fn(self, rhs: I) -> Self::Output { 187 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 188 | } 189 | } 190 | 191 | impl<'a, T, U, S: Shape, L: Layout, I: Apply> $trt for &'a View<'_, T, S, L> 192 | where 193 | &'a T: $trt, 194 | { 195 | #[cfg(not(feature = "nightly"))] 196 | type Output = I::ZippedWith U>; 197 | 198 | #[cfg(feature = "nightly")] 199 | type Output = I::ZippedWith U>; 200 | 201 | fn $fn(self, rhs: I) -> Self::Output { 202 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 203 | } 204 | } 205 | 206 | impl<'a, T, U, S: Shape, L: Layout, I: Apply> $trt for &'a ViewMut<'_, T, S, L> 207 | where 208 | &'a T: $trt, 209 | { 210 | #[cfg(not(feature = "nightly"))] 211 | type Output = I::ZippedWith U>; 212 | 213 | #[cfg(feature = "nightly")] 214 | type Output = I::ZippedWith U>; 215 | 216 | fn $fn(self, rhs: I) -> Self::Output { 217 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 218 | } 219 | } 220 | 221 | impl $trt 
for Array 222 | where 223 | T: $trt, 224 | { 225 | type Output = Array; 226 | 227 | fn $fn(self, rhs: I) -> Self::Output { 228 | self.zip_with(rhs, |(x, y)| x.$fn(y)) 229 | } 230 | } 231 | 232 | impl> $trt for Fill 233 | where 234 | T: $trt, 235 | { 236 | #[cfg(not(feature = "nightly"))] 237 | type Output = I::ZippedWith U>; 238 | 239 | #[cfg(feature = "nightly")] 240 | type Output = I::ZippedWith U>; 241 | 242 | fn $fn(self, rhs: I) -> Self::Output { 243 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 244 | } 245 | } 246 | 247 | impl T, I: Apply> $trt for FillWith 248 | where 249 | T: $trt, 250 | { 251 | #[cfg(not(feature = "nightly"))] 252 | type Output = I::ZippedWith U>; 253 | 254 | #[cfg(feature = "nightly")] 255 | type Output = I::ZippedWith U>; 256 | 257 | fn $fn(self, rhs: I) -> Self::Output { 258 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 259 | } 260 | } 261 | 262 | impl> $trt for FromElem 263 | where 264 | T: $trt, 265 | { 266 | #[cfg(not(feature = "nightly"))] 267 | type Output = I::ZippedWith U>; 268 | 269 | #[cfg(feature = "nightly")] 270 | type Output = I::ZippedWith U>; 271 | 272 | fn $fn(self, rhs: I) -> Self::Output { 273 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 274 | } 275 | } 276 | 277 | impl T, I: Apply> $trt for FromFn 278 | where 279 | T: $trt, 280 | { 281 | #[cfg(not(feature = "nightly"))] 282 | type Output = I::ZippedWith U>; 283 | 284 | #[cfg(feature = "nightly")] 285 | type Output = I::ZippedWith U>; 286 | 287 | fn $fn(self, rhs: I) -> Self::Output { 288 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 289 | } 290 | } 291 | 292 | impl> $trt for IntoExpr 293 | where 294 | B::Item: $trt, 295 | { 296 | #[cfg(not(feature = "nightly"))] 297 | type Output = I::ZippedWith T>; 298 | 299 | #[cfg(feature = "nightly")] 300 | type Output = I::ZippedWith T>; 301 | 302 | fn $fn(self, rhs: I) -> Self::Output { 303 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 304 | } 305 | } 306 | 307 | impl T, I: Apply> $trt for Map 308 | where 309 | T: $trt, 310 | { 311 | #[cfg(not(feature = 
"nightly"))] 312 | type Output = I::ZippedWith U>; 313 | 314 | #[cfg(feature = "nightly")] 315 | type Output = I::ZippedWith U>; 316 | 317 | fn $fn(self, rhs: I) -> Self::Output { 318 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 319 | } 320 | } 321 | 322 | impl $trt for Tensor 323 | where 324 | T: $trt, 325 | { 326 | type Output = Self; 327 | 328 | fn $fn(self, rhs: I) -> Self { 329 | self.zip_with(rhs, |(x, y)| x.$fn(y)) 330 | } 331 | } 332 | 333 | impl<'a, T, U, S: Shape, L: Layout, I: Apply> $trt for View<'a, T, S, L> 334 | where 335 | &'a T: $trt, 336 | { 337 | #[cfg(not(feature = "nightly"))] 338 | type Output = I::ZippedWith U>; 339 | 340 | #[cfg(feature = "nightly")] 341 | type Output = I::ZippedWith U>; 342 | 343 | fn $fn(self, rhs: I) -> Self::Output { 344 | rhs.zip_with(self, |(x, y)| y.$fn(x)) 345 | } 346 | } 347 | }; 348 | } 349 | 350 | impl_binary_op!(Add, add); 351 | impl_binary_op!(Sub, sub); 352 | impl_binary_op!(Mul, mul); 353 | impl_binary_op!(Div, div); 354 | impl_binary_op!(Rem, rem); 355 | impl_binary_op!(BitAnd, bitand); 356 | impl_binary_op!(BitOr, bitor); 357 | impl_binary_op!(BitXor, bitxor); 358 | impl_binary_op!(Shl, shl); 359 | impl_binary_op!(Shr, shr); 360 | 361 | macro_rules! 
impl_op_assign { 362 | ($trt:tt, $fn:tt) => { 363 | impl $trt for Array 364 | where 365 | T: $trt, 366 | { 367 | fn $fn(&mut self, rhs: I) { 368 | self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y)); 369 | } 370 | } 371 | 372 | impl $trt for Slice 373 | where 374 | T: $trt, 375 | { 376 | fn $fn(&mut self, rhs: I) { 377 | self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y)); 378 | } 379 | } 380 | 381 | impl $trt for Tensor 382 | where 383 | T: $trt, 384 | { 385 | fn $fn(&mut self, rhs: I) { 386 | self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y)); 387 | } 388 | } 389 | 390 | impl $trt for ViewMut<'_, T, S, L> 391 | where 392 | T: $trt, 393 | { 394 | fn $fn(&mut self, rhs: I) { 395 | self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y)); 396 | } 397 | } 398 | }; 399 | } 400 | 401 | impl_op_assign!(AddAssign, add_assign); 402 | impl_op_assign!(SubAssign, sub_assign); 403 | impl_op_assign!(MulAssign, mul_assign); 404 | impl_op_assign!(DivAssign, div_assign); 405 | impl_op_assign!(RemAssign, rem_assign); 406 | impl_op_assign!(BitAndAssign, bitand_assign); 407 | impl_op_assign!(BitOrAssign, bitor_assign); 408 | impl_op_assign!(BitXorAssign, bitxor_assign); 409 | impl_op_assign!(ShlAssign, shl_assign); 410 | impl_op_assign!(ShrAssign, shr_assign); 411 | 412 | macro_rules! 
impl_unary_op { 413 | ($trt:tt, $fn:tt) => { 414 | impl<'a, T, U, S: ConstShape> $trt for &'a Array 415 | where 416 | &'a T: $trt, 417 | { 418 | #[cfg(not(feature = "nightly"))] 419 | type Output = >::Output U>; 420 | 421 | #[cfg(feature = "nightly")] 422 | type Output = >::Output U>; 423 | 424 | fn $fn(self) -> Self::Output { 425 | self.apply(|x| x.$fn()) 426 | } 427 | } 428 | 429 | impl<'a, T, U, S: Shape, L: Layout> $trt for &'a Slice 430 | where 431 | &'a T: $trt, 432 | { 433 | #[cfg(not(feature = "nightly"))] 434 | type Output = >::Output U>; 435 | 436 | #[cfg(feature = "nightly")] 437 | type Output = >::Output U>; 438 | 439 | fn $fn(self) -> Self::Output { 440 | self.apply(|x| x.$fn()) 441 | } 442 | } 443 | 444 | impl<'a, T, U, S: Shape, A: Allocator> $trt for &'a Tensor 445 | where 446 | &'a T: $trt, 447 | { 448 | #[cfg(not(feature = "nightly"))] 449 | type Output = >::Output U>; 450 | 451 | #[cfg(feature = "nightly")] 452 | type Output = >::Output U>; 453 | 454 | fn $fn(self) -> Self::Output { 455 | self.apply(|x| x.$fn()) 456 | } 457 | } 458 | 459 | impl<'a, T, U, S: Shape, L: Layout> $trt for &'a View<'_, T, S, L> 460 | where 461 | &'a T: $trt, 462 | { 463 | #[cfg(not(feature = "nightly"))] 464 | type Output = >::Output U>; 465 | 466 | #[cfg(feature = "nightly")] 467 | type Output = >::Output U>; 468 | 469 | fn $fn(self) -> Self::Output { 470 | self.apply(|x| x.$fn()) 471 | } 472 | } 473 | 474 | impl<'a, T, U, S: Shape, L: Layout> $trt for &'a ViewMut<'_, T, S, L> 475 | where 476 | &'a T: $trt, 477 | { 478 | #[cfg(not(feature = "nightly"))] 479 | type Output = >::Output U>; 480 | 481 | #[cfg(feature = "nightly")] 482 | type Output = >::Output U>; 483 | 484 | fn $fn(self) -> Self::Output { 485 | self.apply(|x| x.$fn()) 486 | } 487 | } 488 | 489 | impl $trt for Array 490 | where 491 | T: $trt, 492 | { 493 | type Output = Array; 494 | 495 | fn $fn(self) -> Self::Output { 496 | self.apply(|x| x.$fn()) 497 | } 498 | } 499 | 500 | impl $trt for Fill 501 | where 
502 | T: $trt, 503 | { 504 | #[cfg(not(feature = "nightly"))] 505 | type Output = >::Output U>; 506 | 507 | #[cfg(feature = "nightly")] 508 | type Output = >::Output U>; 509 | 510 | fn $fn(self) -> Self::Output { 511 | self.apply(|x| x.$fn()) 512 | } 513 | } 514 | 515 | impl T> $trt for FillWith 516 | where 517 | T: $trt, 518 | { 519 | #[cfg(not(feature = "nightly"))] 520 | type Output = >::Output U>; 521 | 522 | #[cfg(feature = "nightly")] 523 | type Output = >::Output U>; 524 | 525 | fn $fn(self) -> Self::Output { 526 | self.apply(|x| x.$fn()) 527 | } 528 | } 529 | 530 | impl $trt for FromElem 531 | where 532 | T: $trt, 533 | { 534 | #[cfg(not(feature = "nightly"))] 535 | type Output = >::Output U>; 536 | 537 | #[cfg(feature = "nightly")] 538 | type Output = >::Output U>; 539 | 540 | fn $fn(self) -> Self::Output { 541 | self.apply(|x| x.$fn()) 542 | } 543 | } 544 | 545 | impl T> $trt for FromFn 546 | where 547 | T: $trt, 548 | { 549 | #[cfg(not(feature = "nightly"))] 550 | type Output = >::Output U>; 551 | 552 | #[cfg(feature = "nightly")] 553 | type Output = >::Output U>; 554 | 555 | fn $fn(self) -> Self::Output { 556 | self.apply(|x| x.$fn()) 557 | } 558 | } 559 | 560 | impl $trt for IntoExpr 561 | where 562 | B::Item: $trt, 563 | { 564 | #[cfg(not(feature = "nightly"))] 565 | type Output = >::Output T>; 566 | 567 | #[cfg(feature = "nightly")] 568 | type Output = >::Output T>; 569 | 570 | fn $fn(self) -> Self::Output { 571 | self.apply(|x| x.$fn()) 572 | } 573 | } 574 | 575 | impl T> $trt for Map 576 | where 577 | T: $trt, 578 | { 579 | #[cfg(not(feature = "nightly"))] 580 | type Output = >::Output U>; 581 | 582 | #[cfg(feature = "nightly")] 583 | type Output = >::Output U>; 584 | 585 | fn $fn(self) -> Self::Output { 586 | self.apply(|x| x.$fn()) 587 | } 588 | } 589 | 590 | impl $trt for Tensor 591 | where 592 | T: $trt, 593 | { 594 | type Output = Self; 595 | 596 | fn $fn(self) -> Self { 597 | self.apply(|x| x.$fn()) 598 | } 599 | } 600 | 601 | impl<'a, T, U, 
S: Shape, L: Layout> $trt for View<'a, T, S, L> 602 | where 603 | &'a T: $trt, 604 | { 605 | #[cfg(not(feature = "nightly"))] 606 | type Output = >::Output U>; 607 | 608 | #[cfg(feature = "nightly")] 609 | type Output = >::Output U>; 610 | 611 | fn $fn(self) -> Self::Output { 612 | self.apply(|x| x.$fn()) 613 | } 614 | } 615 | }; 616 | } 617 | 618 | impl_unary_op!(Neg, neg); 619 | impl_unary_op!(Not, not); 620 | -------------------------------------------------------------------------------- /src/raw_slice.rs: -------------------------------------------------------------------------------- 1 | use std::mem; 2 | use std::ptr::NonNull; 3 | 4 | use crate::layout::Layout; 5 | use crate::shape::Shape; 6 | use crate::slice::Slice; 7 | 8 | pub(crate) struct RawSlice { 9 | ptr: NonNull, 10 | mapping: L::Mapping, 11 | } 12 | 13 | impl RawSlice { 14 | pub(crate) fn as_mut_ptr(&mut self) -> *mut T { 15 | self.ptr.as_ptr() 16 | } 17 | 18 | pub(crate) fn as_mut_slice(&mut self) -> &mut Slice { 19 | if mem::size_of::>() > 0 { 20 | unsafe { &mut *(self as *mut Self as *mut Slice) } 21 | } else { 22 | unsafe { &mut *(self.ptr.as_ptr() as *mut Slice) } 23 | } 24 | } 25 | 26 | pub(crate) fn as_ptr(&self) -> *const T { 27 | self.ptr.as_ptr() 28 | } 29 | 30 | pub(crate) fn as_slice(&self) -> &Slice { 31 | if mem::size_of::>() > 0 { 32 | unsafe { &*(self as *const Self as *const Slice) } 33 | } else { 34 | unsafe { &*(self.ptr.as_ptr() as *const Slice) } 35 | } 36 | } 37 | 38 | pub(crate) fn from_mut_slice(slice: &mut Slice) -> &mut Self { 39 | assert!(mem::size_of::>() > 0, "ZST not allowed"); 40 | 41 | unsafe { &mut *(slice as *mut Slice as *mut Self) } 42 | } 43 | 44 | pub(crate) fn from_slice(slice: &Slice) -> &Self { 45 | assert!(mem::size_of::>() > 0, "ZST not allowed"); 46 | 47 | unsafe { &*(slice as *const Slice as *const Self) } 48 | } 49 | 50 | pub(crate) fn mapping(&self) -> &L::Mapping { 51 | &self.mapping 52 | } 53 | 54 | pub(crate) unsafe fn mapping_mut(&mut self) -> &mut 
L::Mapping { 55 | &mut self.mapping 56 | } 57 | 58 | pub(crate) unsafe fn new_unchecked(ptr: *mut T, mapping: L::Mapping) -> Self { 59 | unsafe { Self { ptr: NonNull::new_unchecked(ptr), mapping } } 60 | } 61 | 62 | pub(crate) unsafe fn set_ptr(&mut self, new_ptr: *mut T) { 63 | self.ptr = unsafe { NonNull::new_unchecked(new_ptr) }; 64 | } 65 | } 66 | 67 | impl Clone for RawSlice { 68 | fn clone(&self) -> Self { 69 | Self { ptr: self.ptr, mapping: self.mapping.clone() } 70 | } 71 | 72 | fn clone_from(&mut self, source: &Self) { 73 | self.ptr = source.ptr; 74 | self.mapping.clone_from(&source.mapping); 75 | } 76 | } 77 | 78 | impl: Copy>> Copy for RawSlice {} 79 | -------------------------------------------------------------------------------- /src/raw_tensor.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "nightly")] 2 | use std::alloc::Allocator; 3 | use std::marker::PhantomData; 4 | use std::mem::{self, ManuallyDrop}; 5 | use std::ptr; 6 | 7 | #[cfg(not(feature = "nightly"))] 8 | use crate::alloc::Allocator; 9 | use crate::layout::Dense; 10 | use crate::mapping::{DenseMapping, Mapping}; 11 | use crate::raw_slice::RawSlice; 12 | use crate::shape::Shape; 13 | use crate::slice::Slice; 14 | 15 | #[cfg(not(feature = "nightly"))] 16 | macro_rules! vec_t { 17 | ($type:ty, $alloc:ty) => { 18 | Vec<$type> 19 | }; 20 | } 21 | 22 | #[cfg(feature = "nightly")] 23 | macro_rules! 
vec_t { 24 | ($type:ty, $alloc:ty) => { 25 | Vec<$type, $alloc> 26 | }; 27 | } 28 | 29 | pub(crate) struct RawTensor { 30 | slice: RawSlice, 31 | capacity: usize, 32 | #[cfg(not(feature = "nightly"))] 33 | phantom: PhantomData, 34 | #[cfg(feature = "nightly")] 35 | alloc: ManuallyDrop, 36 | } 37 | 38 | struct DropGuard<'a, T, A: Allocator> { 39 | ptr: *mut T, 40 | len: usize, 41 | #[cfg(not(feature = "nightly"))] 42 | phantom: PhantomData<(&'a mut Vec, &'a A)>, 43 | #[cfg(feature = "nightly")] 44 | phantom: PhantomData<&'a mut Vec>, 45 | } 46 | 47 | impl RawTensor { 48 | #[cfg(feature = "nightly")] 49 | pub(crate) fn allocator(&self) -> &A { 50 | &self.alloc 51 | } 52 | 53 | pub(crate) fn as_mut_slice(&mut self) -> &mut Slice { 54 | self.slice.as_mut_slice() 55 | } 56 | 57 | pub(crate) fn as_slice(&self) -> &Slice { 58 | self.slice.as_slice() 59 | } 60 | 61 | pub(crate) fn capacity(&self) -> usize { 62 | if mem::size_of::() > 0 { self.capacity } else { usize::MAX } 63 | } 64 | 65 | #[cfg(not(feature = "nightly"))] 66 | pub(crate) unsafe fn from_parts(vec: Vec, mapping: DenseMapping) -> Self { 67 | debug_assert!(Some(vec.len()) == mapping.shape().checked_len(), "length mismatch"); 68 | 69 | let mut vec = ManuallyDrop::new(vec); 70 | 71 | Self { 72 | slice: unsafe { RawSlice::new_unchecked(vec.as_mut_ptr(), mapping) }, 73 | capacity: vec.capacity(), 74 | phantom: PhantomData, 75 | } 76 | } 77 | 78 | #[cfg(feature = "nightly")] 79 | pub(crate) unsafe fn from_parts(vec: Vec, mapping: DenseMapping) -> Self { 80 | debug_assert!(Some(vec.len()) == mapping.shape().checked_len(), "length mismatch"); 81 | 82 | let (ptr, _, capacity, alloc) = vec.into_raw_parts_with_alloc(); 83 | 84 | Self { 85 | slice: unsafe { RawSlice::new_unchecked(ptr, mapping) }, 86 | capacity, 87 | alloc: ManuallyDrop::new(alloc), 88 | } 89 | } 90 | 91 | pub(crate) fn into_parts(self) -> (vec_t!(T, A), DenseMapping) { 92 | let mut me = ManuallyDrop::new(self); 93 | 94 | #[cfg(not(feature = "nightly"))] 
95 | let vec = unsafe { 96 | Vec::from_raw_parts(me.slice.as_mut_ptr(), me.slice.mapping().len(), me.capacity) 97 | }; 98 | #[cfg(feature = "nightly")] 99 | let vec = unsafe { 100 | Vec::from_raw_parts_in( 101 | me.slice.as_mut_ptr(), 102 | me.slice.mapping().len(), 103 | me.capacity, 104 | ManuallyDrop::take(&mut me.alloc), 105 | ) 106 | }; 107 | 108 | unsafe { (vec, ptr::read(me.slice.mapping())) } 109 | } 110 | 111 | pub(crate) fn resize_with T>(&mut self, new_dims: &[usize], mut f: F) 112 | where 113 | A: Clone, 114 | { 115 | assert!(new_dims.len() == self.slice.mapping().rank(), "invalid rank"); 116 | 117 | if !new_dims.is_empty() { 118 | let new_len = new_dims.iter().try_fold(1usize, |acc, &x| acc.checked_mul(x)); 119 | let new_len = new_len.expect("invalid length"); 120 | 121 | unsafe { 122 | self.with_mut_parts(|vec, old_mapping| { 123 | old_mapping.shape().with_dims(|old_dims| { 124 | if new_len == 0 { 125 | vec.clear(); 126 | } else if new_dims[1..] == old_dims[1..] { 127 | vec.resize_with(new_len, &mut f); 128 | } else { 129 | #[cfg(not(feature = "nightly"))] 130 | let mut new_vec = Vec::with_capacity(new_len); 131 | #[cfg(feature = "nightly")] 132 | let mut new_vec = 133 | Vec::with_capacity_in(new_len, vec.allocator().clone()); 134 | 135 | copy_dim::( 136 | &mut DropGuard::new(vec), 137 | &mut new_vec, 138 | old_dims, 139 | new_dims, 140 | &mut f, 141 | ); 142 | 143 | *vec = new_vec; 144 | } 145 | }); 146 | 147 | old_mapping.shape_mut().with_mut_dims(|dims| dims.copy_from_slice(new_dims)); 148 | }); 149 | } 150 | } 151 | } 152 | 153 | pub(crate) unsafe fn set_mapping(&mut self, new_mapping: DenseMapping) { 154 | debug_assert!(new_mapping.shape().checked_len().is_some(), "invalid length"); 155 | debug_assert!(new_mapping.len() <= self.capacity, "length exceeds capacity"); 156 | 157 | unsafe { 158 | *self.slice.mapping_mut() = new_mapping; 159 | } 160 | } 161 | 162 | #[cfg(not(feature = "nightly"))] 163 | pub(crate) unsafe fn with_mut_parts(&mut self, f: 
F) -> U 164 | where 165 | F: FnOnce(&mut Vec, &mut DenseMapping) -> U, 166 | { 167 | struct DropGuard<'a, T, S: Shape, A: Allocator> { 168 | tensor: &'a mut RawTensor, 169 | vec: ManuallyDrop>, 170 | } 171 | 172 | impl Drop for DropGuard<'_, T, S, A> { 173 | fn drop(&mut self) { 174 | unsafe { 175 | self.tensor.slice.set_ptr(self.vec.as_mut_ptr()); 176 | self.tensor.capacity = self.vec.capacity(); 177 | 178 | // Cleanup in case of length mismatch (e.g. due to allocation failure) 179 | if self.vec.len() != self.tensor.slice.mapping().len() { 180 | assert!(S::default().len() == 0, "default length not zero"); 181 | 182 | *self.tensor.slice.mapping_mut() = DenseMapping::default(); 183 | ptr::drop_in_place(self.vec.as_mut_slice()); 184 | } 185 | } 186 | } 187 | } 188 | 189 | let vec = unsafe { 190 | Vec::from_raw_parts(self.slice.as_mut_ptr(), self.slice.mapping().len(), self.capacity) 191 | }; 192 | 193 | let mut guard = DropGuard { tensor: self, vec: ManuallyDrop::new(vec) }; 194 | 195 | let mapping = unsafe { guard.tensor.slice.mapping_mut() }; 196 | let result = f(&mut guard.vec, mapping); 197 | 198 | debug_assert!(Some(guard.vec.len()) == mapping.shape().checked_len(), "length mismatch"); 199 | 200 | unsafe { 201 | guard.tensor.slice.set_ptr(guard.vec.as_mut_ptr()); 202 | guard.tensor.capacity = guard.vec.capacity(); 203 | } 204 | 205 | mem::forget(guard); 206 | 207 | result 208 | } 209 | 210 | #[cfg(feature = "nightly")] 211 | pub(crate) unsafe fn with_mut_parts(&mut self, f: F) -> U 212 | where 213 | F: FnOnce(&mut Vec, &mut DenseMapping) -> U, 214 | { 215 | struct DropGuard<'a, T, S: Shape, A: Allocator> { 216 | tensor: &'a mut RawTensor, 217 | vec: ManuallyDrop>, 218 | } 219 | 220 | impl Drop for DropGuard<'_, T, S, A> { 221 | fn drop(&mut self) { 222 | unsafe { 223 | self.tensor.slice.set_ptr(self.vec.as_mut_ptr()); 224 | self.tensor.capacity = self.vec.capacity(); 225 | self.tensor.alloc = ManuallyDrop::new(ptr::read(self.vec.allocator())); 226 | 227 | // 
Cleanup in case of length mismatch (e.g. due to allocation failure) 228 | if self.vec.len() != self.tensor.slice.mapping().len() { 229 | *self.tensor.slice.mapping_mut() = DenseMapping::default(); 230 | ptr::drop_in_place(self.vec.as_mut_slice()); 231 | } 232 | } 233 | } 234 | } 235 | 236 | let vec = unsafe { 237 | Vec::from_raw_parts_in( 238 | self.slice.as_mut_ptr(), 239 | self.slice.mapping().len(), 240 | self.capacity, 241 | ManuallyDrop::take(&mut self.alloc), 242 | ) 243 | }; 244 | 245 | let mut guard = DropGuard { tensor: self, vec: ManuallyDrop::new(vec) }; 246 | 247 | let mapping = unsafe { guard.tensor.slice.mapping_mut() }; 248 | let result = f(&mut guard.vec, mapping); 249 | 250 | debug_assert!(Some(guard.vec.len()) == mapping.shape().checked_len(), "length mismatch"); 251 | 252 | unsafe { 253 | guard.tensor.slice.set_ptr(guard.vec.as_mut_ptr()); 254 | guard.tensor.capacity = guard.vec.capacity(); 255 | guard.tensor.alloc = ManuallyDrop::new(ptr::read(guard.vec.allocator())); 256 | } 257 | 258 | mem::forget(guard); 259 | 260 | result 261 | } 262 | 263 | pub(crate) fn with_vec U>(&self, f: F) -> U { 264 | #[cfg(not(feature = "nightly"))] 265 | let vec = unsafe { 266 | Vec::from_raw_parts( 267 | self.slice.as_ptr() as *mut T, 268 | self.slice.mapping().len(), 269 | self.capacity, 270 | ) 271 | }; 272 | #[cfg(feature = "nightly")] 273 | let vec = unsafe { 274 | Vec::from_raw_parts_in( 275 | self.slice.as_ptr() as *mut T, 276 | self.slice.mapping().len(), 277 | self.capacity, 278 | ptr::read(&*self.alloc), 279 | ) 280 | }; 281 | 282 | f(&ManuallyDrop::new(vec)) 283 | } 284 | } 285 | 286 | impl Clone for RawTensor { 287 | fn clone(&self) -> Self { 288 | unsafe { Self::from_parts(self.with_vec(|vec| vec.clone()), self.slice.mapping().clone()) } 289 | } 290 | 291 | fn clone_from(&mut self, source: &Self) { 292 | unsafe { 293 | self.with_mut_parts(|dst, mapping| { 294 | source.with_vec(|src| dst.clone_from(src)); 295 | 
mapping.clone_from(source.slice.mapping()); 296 | }); 297 | } 298 | } 299 | } 300 | 301 | impl Drop for RawTensor { 302 | #[cfg(not(feature = "nightly"))] 303 | fn drop(&mut self) { 304 | _ = unsafe { 305 | Vec::from_raw_parts(self.slice.as_mut_ptr(), self.slice.mapping().len(), self.capacity) 306 | }; 307 | } 308 | 309 | #[cfg(feature = "nightly")] 310 | fn drop(&mut self) { 311 | _ = unsafe { 312 | Vec::from_raw_parts_in( 313 | self.slice.as_mut_ptr(), 314 | self.slice.mapping().len(), 315 | self.capacity, 316 | ManuallyDrop::take(&mut self.alloc), 317 | ) 318 | }; 319 | } 320 | } 321 | 322 | unsafe impl Send for RawTensor {} 323 | unsafe impl Sync for RawTensor {} 324 | 325 | impl<'a, T, A: Allocator> DropGuard<'a, T, A> { 326 | fn new(vec: &'a mut vec_t!(T, A)) -> Self { 327 | let len = vec.len(); 328 | 329 | unsafe { 330 | vec.set_len(0); 331 | } 332 | 333 | Self { ptr: vec.as_mut_ptr(), len, phantom: PhantomData } 334 | } 335 | } 336 | 337 | impl Drop for DropGuard<'_, T, A> { 338 | fn drop(&mut self) { 339 | unsafe { 340 | ptr::slice_from_raw_parts_mut(self.ptr, self.len).drop_in_place(); 341 | } 342 | } 343 | } 344 | 345 | unsafe fn copy_dim( 346 | old_vec: &mut DropGuard, 347 | new_vec: &mut vec_t!(T, A), 348 | old_dims: &[usize], 349 | new_dims: &[usize], 350 | f: &mut impl FnMut() -> T, 351 | ) { 352 | let old_stride: usize = old_dims[1..].iter().product(); 353 | let new_stride: usize = new_dims[1..].iter().product(); 354 | 355 | let old_size = old_dims[0]; 356 | let new_size = new_dims[0]; 357 | 358 | let min_size = old_size.min(new_size); 359 | 360 | unsafe { 361 | if old_dims.len() > 1 { 362 | // Avoid very long compile times for release build with MIR inlining, 363 | // by avoiding recursion until types are known. 364 | // 365 | // This is a workaround until const if is available, see #3582 and #122301. 
366 | 367 | unsafe fn dummy( 368 | _: &mut DropGuard, 369 | _: &mut vec_t!(T, A), 370 | _: &[usize], 371 | _: &[usize], 372 | _: &mut impl FnMut() -> T, 373 | ) { 374 | unreachable!(); 375 | } 376 | 377 | let g = const { 378 | match S::RANK { 379 | Some(..2) => dummy::, 380 | _ => copy_dim::, 381 | } 382 | }; 383 | 384 | for _ in 0..min_size { 385 | g(old_vec, new_vec, &old_dims[1..], &new_dims[1..], f); 386 | } 387 | } else { 388 | debug_assert!(old_vec.len >= min_size, "slice exceeds remainder"); 389 | debug_assert!(new_vec.len() + min_size <= new_vec.capacity(), "slice exceeds capacity"); 390 | 391 | ptr::copy_nonoverlapping( 392 | old_vec.ptr, 393 | new_vec.as_mut_ptr().add(new_vec.len()), 394 | min_size, 395 | ); 396 | 397 | old_vec.ptr = old_vec.ptr.add(min_size); 398 | old_vec.len -= min_size; 399 | 400 | new_vec.set_len(new_vec.len() + min_size); 401 | } 402 | 403 | if old_size > min_size { 404 | let count = (old_size - min_size) * old_stride; 405 | let slice = ptr::slice_from_raw_parts_mut(old_vec.ptr, count); 406 | 407 | debug_assert!(old_vec.len >= count, "slice exceeds remainder"); 408 | 409 | old_vec.ptr = old_vec.ptr.add(count); 410 | old_vec.len -= count; 411 | 412 | ptr::drop_in_place(slice); 413 | } 414 | 415 | let additional = (new_size - min_size) * new_stride; 416 | 417 | debug_assert!(new_vec.len() + additional <= new_vec.capacity(), "slice exceeds capacity"); 418 | 419 | for _ in 0..additional { 420 | new_vec.as_mut_ptr().add(new_vec.len()).write(f()); 421 | new_vec.set_len(new_vec.len() + 1); 422 | } 423 | } 424 | } 425 | -------------------------------------------------------------------------------- /src/serde.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "nightly")] 2 | use std::alloc::Allocator; 3 | use std::fmt::{self, Formatter}; 4 | use std::marker::PhantomData; 5 | 6 | use serde::de::{Error, SeqAccess, Visitor}; 7 | use serde::ser::SerializeSeq; 8 | use serde::{Deserialize, 
Deserializer, Serialize, Serializer}; 9 | 10 | #[cfg(not(feature = "nightly"))] 11 | use crate::alloc::Allocator; 12 | use crate::array::Array; 13 | use crate::dim::Dim; 14 | use crate::layout::Layout; 15 | use crate::shape::{ConstShape, Shape}; 16 | use crate::slice::Slice; 17 | use crate::tensor::Tensor; 18 | use crate::view::{View, ViewMut}; 19 | use crate::{array, tensor}; 20 |

// NOTE(review): the generic parameter lists in this span were missing from this
// copy of the file. They are reconstructed from the serde trait signatures and
// from how the type parameters are used in the bodies; confirm against upstream.

/// Clamps a deserializer-provided element count before it is used to preallocate.
///
/// `SeqAccess::size_hint` comes from the input and must not be trusted: a corrupt
/// or hostile value would otherwise request an arbitrarily large allocation before
/// a single element has been read. The reservation is limited to a fixed byte
/// budget; the vector still grows on demand if the sequence really is longer.
fn cautious_capacity<T>(hint: usize) -> usize {
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    match std::mem::size_of::<T>() {
        // Zero-sized elements need no storage, so there is nothing to reserve.
        0 => 0,
        elem_size => hint.min(MAX_PREALLOC_BYTES / elem_size),
    }
}

/// Visitor that builds a `Tensor` of static rank from a (nested) sequence.
struct TensorVisitor<T, S: Shape> {
    phantom: PhantomData<(T, S)>,
}

impl<'a, T: Deserialize<'a>, S: Shape> Visitor<'a> for TensorVisitor<T, S> {
    type Value = Tensor<T, S>;

    /// Describes the expected input; panics with "invalid rank" if the rank is dynamic.
    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        write!(formatter, "an array of rank {}", S::RANK.expect("invalid rank"))
    }

    /// Reads the outermost sequence.
    ///
    /// For rank 1 the elements are read directly. For higher ranks every element
    /// is itself deserialized as a tensor of the tail shape; the first one fixes
    /// the inner dimensions and all following ones must match them.
    ///
    /// # Errors
    ///
    /// Returns a custom error if the inner dimensions differ between elements, or
    /// if the first dimension is constant-sized and the element count differs.
    ///
    /// # Panics
    ///
    /// Panics if the rank is dynamic or zero.
    fn visit_seq<A: SeqAccess<'a>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
        assert!(S::RANK.is_some_and(|rank| rank > 0), "invalid rank");

        let mut vec = Vec::new();
        let mut shape = S::default();
        let mut size = 0;

        let size_hint = seq.size_hint().unwrap_or(0);

        if S::RANK == Some(1) {
            // The hint is untrusted input: clamp it before reserving.
            vec.reserve(cautious_capacity::<T>(size_hint));

            while let Some(value) = seq.next_element()? {
                vec.push(value);
                size += 1;
            }
        } else {
            while let Some(value) = seq.next_element::<Tensor<T, S::Tail>>()? {
                if size == 0 {
                    // Saturate instead of overflowing, then clamp as above.
                    let wanted = value.len().saturating_mul(size_hint);

                    vec.reserve(cautious_capacity::<T>(wanted));
                    shape.with_mut_dims(|dims| {
                        value.shape().with_dims(|src| dims[1..].copy_from_slice(src));
                    });
                } else {
                    shape.with_dims(|dims| {
                        value.shape().with_dims(|src| {
                            let dst = &dims[1..];

                            if src != dst {
                                let msg = format!("invalid dimensions {src:?}, expected {dst:?}");

                                Err(A::Error::custom(msg))
                            } else {
                                Ok(())
                            }
                        })
                    })?;
                }

                vec.append(&mut value.into_vec());
                size += 1;
            }
        }

        if S::Head::SIZE.is_none() {
            // Dynamically-sized first dimension: take it from the element count.
            shape.with_mut_dims(|dims| dims[0] = size);
        } else if size != shape.dim(0) {
            let msg = format!("invalid dimension {size:?}, expected {:?}", shape.dim(0));

            return Err(A::Error::custom(msg));
        }

        Ok(Tensor::from(vec).into_shape(shape))
    }
}

impl<'a, T: Deserialize<'a>, S: ConstShape> Deserialize<'a> for Array<T, S> {
    /// Deserializes via `Tensor` for rank > 0, or from a single value for rank 0.
    fn deserialize<R: Deserializer<'a>>(deserializer: R) -> Result<Self, R::Error> {
        if S::RANK != Some(0) {
            Ok(<Tensor<T, S> as Deserialize>::deserialize(deserializer)?.into())
        } else {
            let value = <T as Deserialize>::deserialize(deserializer)?;

            Ok(array![value].into_shape())
        }
    }
}

impl<'a, T: Deserialize<'a>, S: Shape> Deserialize<'a> for Tensor<T, S> {
    /// Deserializes a sequence for rank > 0, or a single value for rank 0.
    ///
    /// # Panics
    ///
    /// Panics with "dynamic rank not supported" if the shape has dynamic rank.
    fn deserialize<R: Deserializer<'a>>(deserializer: R) -> Result<Self, R::Error> {
        let rank = S::RANK.expect("dynamic rank not supported");

        if rank > 0 {
            let visitor = TensorVisitor { phantom: PhantomData };

            deserializer.deserialize_seq(visitor)
        } else {
            let value = <T as Deserialize>::deserialize(deserializer)?;

            Ok(tensor![value].into_shape(S::default()))
        }
    }
}

impl<T: Serialize, S: ConstShape> Serialize for Array<T, S> {
    /// Serializes by delegating to the value the array dereferences to.
    fn serialize<R: Serializer>(&self, serializer: R) -> Result<R::Ok, R::Error> {
        (**self).serialize(serializer)
    }
}

122 | impl Serialize for Slice { 123 | fn serialize(&self, serializer: R) -> Result { 124 | let rank = S::RANK.expect("dynamic rank not supported"); 125 |
126 | if rank == 0 { 127 | self[[]].serialize(serializer) 128 | } else { 129 | let mut seq = serializer.serialize_seq(Some(self.dim(0)))?; 130 | 131 | for x in self.outer_expr() { 132 | seq.serialize_element(&x)?; 133 | } 134 | 135 | seq.end() 136 | } 137 | } 138 | } 139 | 140 | impl Serialize for Tensor { 141 | fn serialize(&self, serializer: R) -> Result { 142 | (**self).serialize(serializer) 143 | } 144 | } 145 | 146 | impl Serialize for View<'_, T, S, L> { 147 | fn serialize(&self, serializer: R) -> Result { 148 | (**self).serialize(serializer) 149 | } 150 | } 151 | 152 | impl Serialize for ViewMut<'_, T, S, L> { 153 | fn serialize(&self, serializer: R) -> Result { 154 | (**self).serialize(serializer) 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /src/shape.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::Ordering; 2 | use std::fmt::Debug; 3 | use std::hash::{Hash, Hasher}; 4 | use std::slice; 5 | 6 | use crate::array::Array; 7 | use crate::dim::{Const, Dim, Dims, Dyn}; 8 | use crate::layout::{Layout, Strided}; 9 | use crate::tensor::Tensor; 10 | use crate::traits::Owned; 11 | 12 | /// Array shape trait. 13 | pub trait Shape: Clone + Debug + Default + Hash + Ord + Send + Sync { 14 | /// First dimension. 15 | type Head: Dim; 16 | 17 | /// Shape excluding the first dimension. 18 | type Tail: Shape; 19 | 20 | /// Shape with the reverse ordering of dimensions. 21 | type Reverse: Shape; 22 | 23 | /// Prepend the dimension to the shape. 24 | type Prepend: Shape; 25 | 26 | /// Corresponding shape with dynamically-sized dimensions. 27 | type Dyn: Shape; 28 | 29 | /// Merge each dimension pair, where constant size is preferred over dynamic. 30 | /// The result has dynamic rank if at least one of the inputs has dynamic rank. 31 | type Merge: Shape; 32 | 33 | /// Select layout `L` for rank 0, or `Strided` for rank >0 or dynamic. 
34 | type Layout: Layout; 35 | 36 | /// Corresponding array type owning its contents. 37 | type Owned: Owned; 38 | 39 | #[doc(hidden)] 40 | type Dims: Dims; 41 | 42 | /// Array rank if known statically, or `None` if dynamic. 43 | const RANK: Option; 44 | 45 | /// Returns the number of elements in the specified dimension. 46 | /// 47 | /// # Panics 48 | /// 49 | /// Panics if the dimension is out of bounds. 50 | fn dim(&self, index: usize) -> usize { 51 | assert!(index < self.rank(), "invalid dimension"); 52 | 53 | self.with_dims(|dims| dims[index]) 54 | } 55 | 56 | /// Creates an array shape with the given dimensions. 57 | /// 58 | /// # Panics 59 | /// 60 | /// Panics if the dimensions are not matching static rank or constant-sized dimensions. 61 | fn from_dims(dims: &[usize]) -> Self { 62 | let mut shape = Self::new(dims.len()); 63 | 64 | shape.with_mut_dims(|dst| dst.copy_from_slice(dims)); 65 | shape 66 | } 67 | 68 | /// Returns `true` if the array contains no elements. 69 | fn is_empty(&self) -> bool { 70 | self.len() == 0 71 | } 72 | 73 | /// Returns the number of elements in the array. 74 | fn len(&self) -> usize { 75 | self.with_dims(|dims| dims.iter().product()) 76 | } 77 | 78 | /// Returns the array rank, i.e. the number of dimensions. 
79 | fn rank(&self) -> usize { 80 | self.with_dims(|dims| dims.len()) 81 | } 82 | 83 | #[doc(hidden)] 84 | fn new(rank: usize) -> Self; 85 | 86 | #[doc(hidden)] 87 | fn with_dims T>(&self, f: F) -> T; 88 | 89 | #[doc(hidden)] 90 | fn with_mut_dims T>(&mut self, f: F) -> T; 91 | 92 | #[doc(hidden)] 93 | fn checked_len(&self) -> Option { 94 | self.with_dims(|dims| dims.iter().try_fold(1usize, |acc, &x| acc.checked_mul(x))) 95 | } 96 | 97 | #[doc(hidden)] 98 | fn prepend_dim(&self, size: usize) -> S { 99 | let mut shape = S::new(self.rank() + 1); 100 | 101 | shape.with_mut_dims(|dims| { 102 | dims[0] = size; 103 | self.with_dims(|src| dims[1..].copy_from_slice(src)); 104 | }); 105 | 106 | shape 107 | } 108 | 109 | #[doc(hidden)] 110 | fn remove_dim(&self, index: usize) -> S { 111 | assert!(index < self.rank(), "invalid dimension"); 112 | 113 | let mut shape = S::new(self.rank() - 1); 114 | 115 | shape.with_mut_dims(|dims| { 116 | self.with_dims(|src| { 117 | dims[..index].copy_from_slice(&src[..index]); 118 | dims[index..].copy_from_slice(&src[index + 1..]); 119 | }); 120 | }); 121 | 122 | shape 123 | } 124 | 125 | #[doc(hidden)] 126 | fn reshape(&self, mut new_shape: S) -> S { 127 | let mut inferred = None; 128 | 129 | new_shape.with_mut_dims(|dims| { 130 | for i in 0..dims.len() { 131 | if dims[i] == usize::MAX { 132 | assert!(inferred.is_none(), "at most one dimension can be inferred"); 133 | 134 | dims[i] = 1; 135 | inferred = Some(i); 136 | } 137 | } 138 | }); 139 | 140 | let old_len = self.len(); 141 | let new_len = new_shape.checked_len().expect("invalid length"); 142 | 143 | if let Some(i) = inferred { 144 | assert!(old_len % new_len == 0, "length not divisible by the new dimensions"); 145 | 146 | new_shape.with_mut_dims(|dims| dims[i] = old_len / new_len); 147 | } else { 148 | assert!(new_len == old_len, "length must not change"); 149 | } 150 | 151 | new_shape 152 | } 153 | 154 | #[doc(hidden)] 155 | fn resize_dim(&self, index: usize, new_size: usize) -> S { 
156 | assert!(index < self.rank(), "invalid dimension"); 157 | 158 | let mut shape = S::new(self.rank()); 159 | 160 | shape.with_mut_dims(|dims| { 161 | self.with_dims(|src| dims[..].copy_from_slice(src)); 162 | dims[index] = new_size; 163 | }); 164 | 165 | shape 166 | } 167 | 168 | #[doc(hidden)] 169 | fn reverse(&self) -> Self::Reverse { 170 | let mut shape = Self::Reverse::new(self.rank()); 171 | 172 | shape.with_mut_dims(|dims| { 173 | self.with_dims(|src| dims.copy_from_slice(src)); 174 | dims.reverse(); 175 | }); 176 | 177 | shape 178 | } 179 | } 180 | 181 | /// Trait for array shape where all dimensions are constant-sized. 182 | pub trait ConstShape: Copy + Shape { 183 | #[doc(hidden)] 184 | type Inner; 185 | 186 | #[doc(hidden)] 187 | type WithConst>: Owned>>; 188 | } 189 | 190 | /// Conversion trait into an array shape. 191 | pub trait IntoShape { 192 | /// Which kind of array shape are we turning this into? 193 | type IntoShape: Shape; 194 | 195 | /// Creates an array shape from a value. 196 | fn into_shape(self) -> Self::IntoShape; 197 | 198 | #[doc(hidden)] 199 | fn into_dims T>(self, f: F) -> T; 200 | } 201 | 202 | /// Array shape type with dynamic rank. 203 | /// 204 | /// If the rank is 0 or 1, no heap allocation is necessary. The default value 205 | /// will have rank 1 and contain no elements. 206 | pub enum DynRank { 207 | /// Shape variant with dynamic rank. 208 | Dyn(Box<[usize]>), 209 | /// Shape variant with rank 1. 210 | One(usize), 211 | } 212 | 213 | /// Array shape type with dynamically-sized dimensions. 214 | pub type Rank = <[usize; N] as IntoShape>::IntoShape; 215 | 216 | impl DynRank { 217 | /// Returns the number of elements in each dimension. 
218 | pub fn dims(&self) -> &[usize] { 219 | match self { 220 | Self::Dyn(dims) => dims, 221 | Self::One(size) => slice::from_ref(size), 222 | } 223 | } 224 | } 225 | 226 | impl Clone for DynRank { 227 | fn clone(&self) -> Self { 228 | match self { 229 | Self::One(dim) => Self::One(*dim), 230 | Self::Dyn(dims) => { 231 | if dims.len() == 1 { 232 | Self::One(dims[0]) 233 | } else { 234 | Self::Dyn(dims.clone()) 235 | } 236 | } 237 | } 238 | } 239 | 240 | fn clone_from(&mut self, source: &Self) { 241 | if let Self::Dyn(dims) = self { 242 | if let Self::Dyn(src) = source { 243 | if dims.len() == src.len() { 244 | dims.clone_from_slice(src); 245 | 246 | return; 247 | } 248 | } 249 | } 250 | 251 | *self = source.clone(); 252 | } 253 | } 254 | 255 | impl Debug for DynRank { 256 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 257 | self.with_dims(|dims| f.debug_tuple("DynRank").field(&dims).finish()) 258 | } 259 | } 260 | 261 | impl Default for DynRank { 262 | fn default() -> Self { 263 | Self::One(0) 264 | } 265 | } 266 | 267 | impl Eq for DynRank {} 268 | 269 | impl Hash for DynRank { 270 | fn hash(&self, state: &mut H) { 271 | self.with_dims(|dims| dims.hash(state)) 272 | } 273 | } 274 | 275 | impl Ord for DynRank { 276 | fn cmp(&self, other: &Self) -> Ordering { 277 | self.with_dims(|dims| other.with_dims(|other| dims.cmp(other))) 278 | } 279 | } 280 | 281 | impl PartialEq for DynRank { 282 | fn eq(&self, other: &Self) -> bool { 283 | self.with_dims(|dims| other.with_dims(|other| dims.eq(other))) 284 | } 285 | } 286 | 287 | impl PartialOrd for DynRank { 288 | fn partial_cmp(&self, other: &Self) -> Option { 289 | Some(self.cmp(other)) 290 | } 291 | } 292 | 293 | impl Shape for DynRank { 294 | type Head = Dyn; 295 | type Tail = Self; 296 | 297 | type Reverse = Self; 298 | type Prepend = Self; 299 | 300 | type Dyn = Self; 301 | type Merge = Self; 302 | 303 | type Layout = Strided; 304 | type Owned = Tensor; 305 | 306 | type Dims = Box<[T]>; 307 | 
308 | const RANK: Option = None; 309 | 310 | fn new(rank: usize) -> Self { 311 | if rank == 1 { Self::One(0) } else { Self::Dyn(Dims::new(rank)) } 312 | } 313 | 314 | fn with_dims T>(&self, f: F) -> T { 315 | let dims = match self { 316 | Self::Dyn(dims) => dims, 317 | Self::One(size) => slice::from_ref(size), 318 | }; 319 | 320 | f(dims) 321 | } 322 | 323 | fn with_mut_dims T>(&mut self, f: F) -> T { 324 | let dims = match self { 325 | Self::Dyn(dims) => dims, 326 | Self::One(size) => slice::from_mut(size), 327 | }; 328 | 329 | f(dims) 330 | } 331 | } 332 | 333 | impl Shape for () { 334 | type Head = Dyn; 335 | type Tail = Self; 336 | 337 | type Reverse = Self; 338 | type Prepend = (D,); 339 | 340 | type Dyn = Self; 341 | type Merge = S; 342 | 343 | type Layout = L; 344 | type Owned = Array; 345 | 346 | type Dims = [T; 0]; 347 | 348 | const RANK: Option = Some(0); 349 | 350 | fn new(rank: usize) { 351 | assert!(rank == 0, "invalid rank"); 352 | } 353 | 354 | fn with_dims T>(&self, f: F) -> T { 355 | f(&[]) 356 | } 357 | 358 | fn with_mut_dims T>(&mut self, f: F) -> T { 359 | f(&mut []) 360 | } 361 | } 362 | 363 | impl Shape for (X,) { 364 | type Head = X; 365 | type Tail = (); 366 | 367 | type Reverse = Self; 368 | type Prepend = (D, X); 369 | 370 | type Dyn = (Dyn,); 371 | type Merge = ::Prepend>; 372 | 373 | type Layout = Strided; 374 | type Owned = X::Owned; 375 | 376 | type Dims = [T; 1]; 377 | 378 | const RANK: Option = Some(1); 379 | 380 | fn new(rank: usize) -> Self { 381 | assert!(rank == 1, "invalid rank"); 382 | 383 | Self::default() 384 | } 385 | 386 | fn with_dims T>(&self, f: F) -> T { 387 | f(&[self.0.size()]) 388 | } 389 | 390 | fn with_mut_dims T>(&mut self, f: F) -> T { 391 | let mut dims = [self.0.size()]; 392 | let value = f(&mut dims); 393 | 394 | *self = (X::from_size(dims[0]),); 395 | 396 | value 397 | } 398 | } 399 | 400 | #[cfg(not(feature = "nightly"))] 401 | macro_rules! 
dyn_shape { 402 | ($($yz:tt),+) => { 403 | <::Dyn as Shape>::Prepend 404 | }; 405 | } 406 | 407 | #[cfg(feature = "nightly")] 408 | macro_rules! dyn_shape { 409 | ($($yz:tt),+) => { 410 | (Dyn $(,${ignore($yz)} Dyn)+) 411 | }; 412 | } 413 | 414 | macro_rules! impl_shape { 415 | ($n:tt, ($($jk:tt),+), ($($yz:tt),+), $reverse:tt, $prepend:tt) => { 416 | impl Shape for (X $(,$yz)+) { 417 | type Head = X; 418 | type Tail = ($($yz,)+); 419 | 420 | type Reverse = $reverse; 421 | type Prepend = $prepend; 422 | 423 | type Dyn = dyn_shape!($($yz),+); 424 | type Merge = 425 | <::Merge as Shape>::Prepend>; 426 | 427 | type Layout = Strided; 428 | type Owned = X::Owned; 429 | 430 | type Dims = [T; $n]; 431 | 432 | const RANK: Option = Some($n); 433 | 434 | fn new(rank: usize) -> Self { 435 | assert!(rank == $n, "invalid rank"); 436 | 437 | Self::default() 438 | } 439 | 440 | fn with_dims T>(&self, f: F) -> T { 441 | f(&[self.0.size() $(,self.$jk.size())+]) 442 | } 443 | 444 | fn with_mut_dims T>(&mut self, f: F) -> T { 445 | let mut dims = [self.0.size() $(,self.$jk.size())+]; 446 | let value = f(&mut dims); 447 | 448 | *self = (X::from_size(dims[0]) $(,$yz::from_size(dims[$jk]))+); 449 | 450 | value 451 | } 452 | } 453 | }; 454 | } 455 | 456 | impl_shape!(2, (1), (Y), (Y, X), (D, X, Y)); 457 | impl_shape!(3, (1, 2), (Y, Z), (Z, Y, X), (D, X, Y, Z)); 458 | impl_shape!(4, (1, 2, 3), (Y, Z, W), (W, Z, Y, X), (D, X, Y, Z, W)); 459 | impl_shape!(5, (1, 2, 3, 4), (Y, Z, W, U), (U, W, Z, Y, X), (D, X, Y, Z, W, U)); 460 | impl_shape!(6, (1, 2, 3, 4, 5), (Y, Z, W, U, V), (V, U, W, Z, Y, X), DynRank); 461 | 462 | macro_rules! 
impl_const_shape { 463 | (($($xyz:tt),*), $inner:ty, $with_const:tt) => { 464 | impl<$(const $xyz: usize),*> ConstShape for ($(Const<$xyz>,)*) { 465 | type Inner = $inner; 466 | type WithConst> = 467 | $with_const>>; 468 | } 469 | }; 470 | } 471 | 472 | impl_const_shape!((), T, Array); 473 | impl_const_shape!((X), [T; X], Array); 474 | impl_const_shape!((X, Y), [[T; Y]; X], Array); 475 | impl_const_shape!((X, Y, Z), [[[T; Z]; Y]; X], Array); 476 | impl_const_shape!((X, Y, Z, W), [[[[T; W]; Z]; Y]; X], Array); 477 | impl_const_shape!((X, Y, Z, W, U), [[[[[T; U]; W]; Z]; Y]; X], Array); 478 | impl_const_shape!((X, Y, Z, W, U, V), [[[[[[T; V]; U]; W]; Z]; Y]; X], Tensor); 479 | 480 | impl IntoShape for S { 481 | type IntoShape = S; 482 | 483 | fn into_shape(self) -> S { 484 | self 485 | } 486 | 487 | fn into_dims T>(self, f: F) -> T { 488 | self.with_dims(f) 489 | } 490 | } 491 | 492 | impl IntoShape for &[usize; N] { 493 | type IntoShape = DynRank; 494 | 495 | fn into_shape(self) -> DynRank { 496 | Shape::from_dims(self) 497 | } 498 | 499 | fn into_dims T>(self, f: F) -> T { 500 | f(self) 501 | } 502 | } 503 | 504 | impl IntoShape for &[usize] { 505 | type IntoShape = DynRank; 506 | 507 | fn into_shape(self) -> DynRank { 508 | Shape::from_dims(self) 509 | } 510 | 511 | fn into_dims T>(self, f: F) -> T { 512 | f(self) 513 | } 514 | } 515 | 516 | impl IntoShape for Box<[usize]> { 517 | type IntoShape = DynRank; 518 | 519 | fn into_shape(self) -> DynRank { 520 | DynRank::Dyn(self) 521 | } 522 | 523 | fn into_dims T>(self, f: F) -> T { 524 | f(&self) 525 | } 526 | } 527 | 528 | impl IntoShape for Const { 529 | type IntoShape = (Self,); 530 | 531 | fn into_shape(self) -> Self::IntoShape { 532 | (self,) 533 | } 534 | 535 | fn into_dims T>(self, f: F) -> T { 536 | f(&[N]) 537 | } 538 | } 539 | 540 | impl IntoShape for Dyn { 541 | type IntoShape = (Self,); 542 | 543 | fn into_shape(self) -> Self::IntoShape { 544 | (self,) 545 | } 546 | 547 | fn into_dims T>(self, f: F) -> T 
{ 548 | f(&[self]) 549 | } 550 | } 551 | 552 | impl IntoShape for Vec { 553 | type IntoShape = DynRank; 554 | 555 | fn into_shape(self) -> DynRank { 556 | DynRank::Dyn(self.into()) 557 | } 558 | 559 | fn into_dims T>(self, f: F) -> T { 560 | f(&self) 561 | } 562 | } 563 | 564 | macro_rules! impl_into_shape { 565 | ($n:tt, $shape:ty) => { 566 | impl IntoShape for [usize; $n] { 567 | type IntoShape = $shape; 568 | 569 | fn into_shape(self) -> Self::IntoShape { 570 | Shape::from_dims(&self) 571 | } 572 | 573 | fn into_dims T>(self, f: F) -> T { 574 | f(&self) 575 | } 576 | } 577 | }; 578 | } 579 | 580 | impl_into_shape!(0, ()); 581 | impl_into_shape!(1, (Dyn,)); 582 | impl_into_shape!(2, (Dyn, Dyn)); 583 | impl_into_shape!(3, (Dyn, Dyn, Dyn)); 584 | impl_into_shape!(4, (Dyn, Dyn, Dyn, Dyn)); 585 | impl_into_shape!(5, (Dyn, Dyn, Dyn, Dyn, Dyn)); 586 | impl_into_shape!(6, (Dyn, Dyn, Dyn, Dyn, Dyn, Dyn)); 587 | -------------------------------------------------------------------------------- /src/traits.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::BorrowMut; 2 | 3 | use crate::dim::Const; 4 | use crate::expr::{Apply, FromExpression}; 5 | use crate::shape::Shape; 6 | use crate::slice::Slice; 7 | 8 | /// Trait for generalization of `Clone` that can reuse an existing object. 9 | pub trait IntoCloned { 10 | /// Moves an existing object or clones from a reference to the target object. 11 | fn clone_to(self, target: &mut T); 12 | 13 | /// Returns an existing object or a new clone from a reference. 
14 | fn into_cloned(self) -> T; 15 | } 16 | 17 | impl IntoCloned for &T { 18 | fn clone_to(self, target: &mut T) { 19 | target.clone_from(self); 20 | } 21 | 22 | fn into_cloned(self) -> T { 23 | self.clone() 24 | } 25 | } 26 | 27 | impl IntoCloned for T { 28 | fn clone_to(self, target: &mut T) { 29 | *target = self; 30 | } 31 | 32 | fn into_cloned(self) -> T { 33 | self 34 | } 35 | } 36 | 37 | /// Trait for a multidimensional array owning its contents. 38 | pub trait Owned: Apply + BorrowMut> + FromExpression { 39 | #[doc(hidden)] 40 | type WithConst: Owned>>; 41 | 42 | #[doc(hidden)] 43 | fn clone_from_slice(&mut self, slice: &Slice) 44 | where 45 | T: Clone; 46 | } 47 | -------------------------------------------------------------------------------- /tests/aligned_alloc/mod.rs: -------------------------------------------------------------------------------- 1 | use std::alloc::{AllocError, Allocator, Global, Layout}; 2 | use std::cmp; 3 | use std::ptr::NonNull; 4 | 5 | /// Aligned memory allocator, using the global allocator as default. 6 | #[derive(Clone, Copy, Debug, Default)] 7 | pub struct AlignedAlloc { 8 | alloc: A, 9 | } 10 | 11 | impl AlignedAlloc { 12 | /// Creates a new aligned allocator based on the specified allocator. 
pub fn new(alloc: A) -> Self {
    assert!(N.is_power_of_two(), "alignment must be power of two");

    Self { alloc }
}
} // closes the inherent `impl` block opened before this span

// NOTE(review): the generic parameter lists were missing from this copy of the
// file. `N` is used as a `usize` constant and `A` as the wrapped allocator type;
// the parameter order `<N, A>` is an assumption — confirm against the struct
// definition before applying.
unsafe impl<const N: usize, A: Allocator> Allocator for AlignedAlloc<N, A> {
    /// Allocates from the wrapped allocator using the raised alignment.
    fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.allocate(aligned_layout::<N>(layout))
    }

    /// Deallocates with the same adjusted layout that `allocate` used.
    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
        unsafe {
            self.alloc.deallocate(ptr, aligned_layout::<N>(layout));
        }
    }

    /// Zero-initialized variant of `allocate`.
    fn allocate_zeroed(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
        self.alloc.allocate_zeroed(aligned_layout::<N>(layout))
    }

    /// Grows a block; both layouts are adjusted the same way as in `allocate`.
    unsafe fn grow(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        unsafe {
            self.alloc.grow(ptr, aligned_layout::<N>(old_layout), aligned_layout::<N>(new_layout))
        }
    }

    /// Zero-initialized variant of `grow`.
    unsafe fn grow_zeroed(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        unsafe {
            self.alloc.grow_zeroed(
                ptr,
                aligned_layout::<N>(old_layout),
                aligned_layout::<N>(new_layout),
            )
        }
    }

    /// Shrinks a block; both layouts are adjusted the same way as in `allocate`.
    unsafe fn shrink(
        &self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, AllocError> {
        unsafe {
            self.alloc.shrink(ptr, aligned_layout::<N>(old_layout), aligned_layout::<N>(new_layout))
        }
    }
}

/// Returns `layout` with its alignment raised towards `N`.
///
/// The alignment is `N`, but not larger than the layout size rounded to the next
/// power of two and not smaller than the alignment the layout already has. The
/// size is unchanged.
fn aligned_layout<const N: usize>(layout: Layout) -> Layout {
    // The allocator can be constructed without going through `new` (the struct
    // derives `Default`), so the runtime check there does not cover every
    // instance. Enforce the requirement at compile time where `N` is used.
    const { assert!(N.is_power_of_two(), "alignment must be power of two") };

    let wanted = cmp::min(N, layout.size().next_power_of_two());
    let align = cmp::max(layout.align(), wanted);

    // SAFETY: `align` is a non-zero power of two, because `N`, the rounded size
    // and `layout.align()` all are, and min/max of powers of two is a power of two.
    // NOTE(review): this still assumes that `layout.size()` rounded up to `align`
    // does not exceed `isize::MAX`; nothing here checks that for extreme sizes.
    unsafe { Layout::from_size_align_unchecked(layout.size(), align) }
}
80 | --------------------------------------------------------------------------------